diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 000000000..4847406f9 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,17 @@ +PRs welcome! But please file bugs first and explain the problem or +motivation. For new or changed functionality, strike up a discussion +and get agreement on the design/solution before spending too much time writing +code. + +Commit messages should [reference +bugs](https://docs.github.com/en/github/writing-on-github/autolinked-references-and-urls). + +We require [Developer Certificate of +Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) (DCO) +`Signed-off-by` lines in commits. (`git commit -s`) + +Please squash your code review edits & force push. Multiple commits in +a PR are fine, but only if they're each logically separate and all tests pass +at each stage. No fixup commits. + +See [commit-messages.md](docs/commit-messages.md) (or skim `git log`) for our commit message style. diff --git a/.github/workflows/checklocks.yml b/.github/workflows/checklocks.yml index 064797c88..5957e6925 100644 --- a/.github/workflows/checklocks.yml +++ b/.github/workflows/checklocks.yml @@ -10,7 +10,7 @@ on: - '.github/workflows/checklocks.yml' concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -18,7 +18,7 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Build checklocks run: ./tool/go build -o /tmp/checklocks gvisor.dev/gvisor/tools/checklocks/cmd/checklocks diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 9dad75d91..2f5ae7d92 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -23,7 +23,7 @@ on: - cron: '31 14 * * 5' concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -45,17 +45,17 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 # Install a more recent Go that understands modern go.mod content. - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 with: go-version-file: go.mod # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@461ef6c76dfe95d5c364de2f431ddbd31a417628 # v3.26.9 + uses: github/codeql-action/init@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -66,7 +66,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@461ef6c76dfe95d5c364de2f431ddbd31a417628 # v3.26.9 + uses: github/codeql-action/autobuild@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5 # â„šī¸ Command-line programs to run using the OS shell. # 📚 https://git.io/JvXDl @@ -80,4 +80,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@461ef6c76dfe95d5c364de2f431ddbd31a417628 # v3.26.9 + uses: github/codeql-action/analyze@76621b61decf072c1cee8dd1ce2d2a82d33c17ed # v3.29.5 diff --git a/.github/workflows/docker-file-build.yml b/.github/workflows/docker-file-build.yml index c53575572..c61680a34 100644 --- a/.github/workflows/docker-file-build.yml +++ b/.github/workflows/docker-file-build.yml @@ -4,12 +4,10 @@ on: branches: - main pull_request: - branches: - - "*" jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: "Build Docker image" run: docker build . diff --git a/.github/workflows/flakehub-publish-tagged.yml b/.github/workflows/flakehub-publish-tagged.yml index 60fdba91c..50bb8b9f7 100644 --- a/.github/workflows/flakehub-publish-tagged.yml +++ b/.github/workflows/flakehub-publish-tagged.yml @@ -17,11 +17,11 @@ jobs: id-token: "write" contents: "read" steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: "${{ (inputs.tag != null) && format('refs/tags/{0}', inputs.tag) || '' }}" - - uses: "DeterminateSystems/nix-installer-action@main" - - uses: "DeterminateSystems/flakehub-push@main" + - uses: DeterminateSystems/nix-installer-action@786fff0690178f1234e4e1fe9b536e94f5433196 # v20 + - uses: DeterminateSystems/flakehub-push@71f57208810a5d299fc6545350981de98fdbc860 # v6 with: visibility: "public" tag: "${{ inputs.tag }}" diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9c34debc5..bcf17f8e6 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -15,7 +15,7 @@ permissions: pull-requests: read concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -23,18 +23,17 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + - uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 with: go-version-file: go.mod cache: false - name: golangci-lint - # Note: this is the 'v6.1.0' tag as of 2024-08-21 - uses: golangci/golangci-lint-action@aaa42aa0628b4ae2578232a66b541047968fac86 + uses: golangci/golangci-lint-action@1481404843c368bc19ca9406f87d6e0fc97bdcfd # v7.0.0 with: - version: v1.60 + version: v2.4.0 # Show only new issues if it's a pull request. only-new-issues: true diff --git a/.github/workflows/govulncheck.yml b/.github/workflows/govulncheck.yml index 4a5ad54f3..c7560983a 100644 --- a/.github/workflows/govulncheck.yml +++ b/.github/workflows/govulncheck.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Check out code into the Go module directory - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install govulncheck run: ./tool/go install golang.org/x/vuln/cmd/govulncheck@latest @@ -24,13 +24,13 @@ jobs: - name: Post to slack if: failure() && github.event_name == 'schedule' - uses: slackapi/slack-github-action@37ebaef184d7626c5f204ab8d3baff4262dd30f0 # v1.27.0 - env: - SLACK_BOT_TOKEN: ${{ secrets.GOVULNCHECK_BOT_TOKEN }} + uses: slackapi/slack-github-action@91efab103c0de0a537f72a35f6b8cda0ee76bf0a # v2.1.1 with: - channel-id: 'C05PXRM304B' + method: chat.postMessage + token: ${{ secrets.GOVULNCHECK_BOT_TOKEN }} payload: | { + "channel": "C08FGKZCQTW", "blocks": [ { "type": "section", diff --git a/.github/workflows/installer.yml b/.github/workflows/installer.yml index 48b29c6ec..bafa9925a 100644 --- a/.github/workflows/installer.yml +++ b/.github/workflows/installer.yml @@ -1,16 +1,18 @@ name: test installer.sh on: + schedule: + - cron: '0 15 * * *' # 10am EST (UTC-4/5) push: branches: - "main" paths: - scripts/installer.sh + - .github/workflows/installer.yml pull_request: - branches: - - "*" paths: - scripts/installer.sh + - .github/workflows/installer.yml jobs: test: @@ -29,13 +31,11 @@ jobs: - "debian:stable-slim" - "debian:testing-slim" - "debian:sid-slim" - - "ubuntu:18.04" - "ubuntu:20.04" - "ubuntu:22.04" - - "ubuntu:23.04" + - "ubuntu:24.04" - "elementary/docker:stable" - "elementary/docker:unstable" - - "parrotsec/core:lts-amd64" - "parrotsec/core:latest" - "kalilinux/kali-rolling" - "kalilinux/kali-dev" @@ -48,7 +48,7 @@ jobs: - "opensuse/leap:latest" - "opensuse/tumbleweed:latest" - "archlinux:latest" - - "alpine:3.14" + - "alpine:3.21" - "alpine:latest" - "alpine:edge" deps: @@ -58,10 +58,6 @@ jobs: # Check a few images with wget rather than curl. - { image: "debian:oldstable-slim", deps: "wget" } - { image: "debian:sid-slim", deps: "wget" } - - { image: "ubuntu:23.04", deps: "wget" } - # Ubuntu 16.04 also needs apt-transport-https installed. - - { image: "ubuntu:16.04", deps: "curl apt-transport-https" } - - { image: "ubuntu:16.04", deps: "wget apt-transport-https" } runs-on: ubuntu-latest container: image: ${{ matrix.image }} @@ -76,10 +72,10 @@ jobs: # tar and gzip are needed by the actions/checkout below. run: yum install -y --allowerasing tar gzip ${{ matrix.deps }} if: | - contains(matrix.image, 'centos') - || contains(matrix.image, 'oraclelinux') - || contains(matrix.image, 'fedora') - || contains(matrix.image, 'amazonlinux') + contains(matrix.image, 'centos') || + contains(matrix.image, 'oraclelinux') || + contains(matrix.image, 'fedora') || + contains(matrix.image, 'amazonlinux') - name: install dependencies (zypper) # tar and gzip are needed by the actions/checkout below. run: zypper --non-interactive install tar gzip ${{ matrix.deps }} @@ -89,16 +85,13 @@ jobs: apt-get update apt-get install -y ${{ matrix.deps }} if: | - contains(matrix.image, 'debian') - || contains(matrix.image, 'ubuntu') - || contains(matrix.image, 'elementary') - || contains(matrix.image, 'parrotsec') - || contains(matrix.image, 'kalilinux') + contains(matrix.image, 'debian') || + contains(matrix.image, 'ubuntu') || + contains(matrix.image, 'elementary') || + contains(matrix.image, 'parrotsec') || + contains(matrix.image, 'kalilinux') - name: checkout - # We cannot use v4, as it requires a newer glibc version than some of the - # tested images provide. See - # https://github.com/actions/checkout/issues/1487 - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: run installer run: scripts/installer.sh # Package installation can fail in docker because systemd is not running @@ -107,3 +100,30 @@ jobs: continue-on-error: true - name: check tailscale version run: tailscale --version + notify-slack: + needs: test + runs-on: ubuntu-latest + steps: + - name: Notify Slack of failure on scheduled runs + if: failure() && github.event_name == 'schedule' + uses: slackapi/slack-github-action@91efab103c0de0a537f72a35f6b8cda0ee76bf0a # v2.1.1 + with: + webhook: ${{ secrets.SLACK_WEBHOOK_URL }} + webhook-type: incoming-webhook + payload: | + { + "attachments": [{ + "title": "Tailscale installer test failed", + "title_link": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "text": "One or more OSes in the test matrix failed. See the run for details.", + "fields": [ + { + "title": "Ref", + "value": "${{ github.ref_name }}", + "short": true + } + ], + "footer": "${{ github.workflow }} on schedule", + "color": "danger" + }] + } diff --git a/.github/workflows/kubemanifests.yaml b/.github/workflows/kubemanifests.yaml index f943ccb52..4cffea02f 100644 --- a/.github/workflows/kubemanifests.yaml +++ b/.github/workflows/kubemanifests.yaml @@ -9,7 +9,7 @@ on: # Cancel workflow run if there is a newer push to the same PR for which it is # running concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -17,7 +17,7 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Build and lint Helm chart run: | eval `./tool/go run ./cmd/mkversion` diff --git a/.github/workflows/natlab-integrationtest.yml b/.github/workflows/natlab-integrationtest.yml new file mode 100644 index 000000000..99d58717b --- /dev/null +++ b/.github/workflows/natlab-integrationtest.yml @@ -0,0 +1,27 @@ +# Run some natlab integration tests. +# See https://github.com/tailscale/tailscale/issues/13038 +name: "natlab-integrationtest" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +on: + pull_request: + paths: + - "tstest/integration/nat/nat_test.go" +jobs: + natlab-integrationtest: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Install qemu + run: | + sudo rm /var/lib/man-db/auto-update + sudo apt-get -y update + sudo apt-get -y remove man-db + sudo apt-get install -y qemu-system-x86 qemu-utils + - name: Run natlab integration tests + run: | + ./tool/go test -v -run=^TestEasyEasy$ -timeout=3m -count=1 ./tstest/integration/nat --run-vm-tests diff --git a/.github/workflows/pin-github-actions.yml b/.github/workflows/pin-github-actions.yml new file mode 100644 index 000000000..cb6673993 --- /dev/null +++ b/.github/workflows/pin-github-actions.yml @@ -0,0 +1,29 @@ +# Pin images used in github actions to a hash instead of a version tag. +name: pin-github-actions +on: + pull_request: + branches: + - main + paths: + - ".github/workflows/**" + + workflow_dispatch: + +permissions: + contents: read + pull-requests: read + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + run: + name: pin-github-actions + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: pin + run: make pin-github-actions + - name: check for changed workflow files + run: git diff --no-ext-diff --exit-code .github/workflows || (echo "Some github actions versions need pinning, run make pin-github-actions."; exit 1) diff --git a/.github/workflows/request-dataplane-review.yml b/.github/workflows/request-dataplane-review.yml new file mode 100644 index 000000000..7ae5668c3 --- /dev/null +++ b/.github/workflows/request-dataplane-review.yml @@ -0,0 +1,30 @@ +name: request-dataplane-review + +on: + pull_request: + paths: + - ".github/workflows/request-dataplane-review.yml" + - "**/*derp*" + - "**/derp*/**" + - "!**/depaware.txt" + +jobs: + request-dataplane-review: + name: Request Dataplane Review + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Get access token + uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6 + id: generate-token + with: + # Get token for app: https://github.com/apps/change-visibility-bot + app-id: ${{ secrets.VISIBILITY_BOT_APP_ID }} + private-key: ${{ secrets.VISIBILITY_BOT_APP_PRIVATE_KEY }} + - name: Add reviewers + env: + GH_TOKEN: ${{ steps.generate-token.outputs.token }} + url: ${{ github.event.pull_request.html_url }} + run: | + gh pr edit "$url" --add-reviewer tailscale/dataplane diff --git a/.github/workflows/ssh-integrationtest.yml b/.github/workflows/ssh-integrationtest.yml index a82696307..463f4bdd4 100644 --- a/.github/workflows/ssh-integrationtest.yml +++ b/.github/workflows/ssh-integrationtest.yml @@ -3,7 +3,7 @@ name: "ssh-integrationtest" concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true on: @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run SSH integration tests run: | make sshintegrationtest \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5cfd86c40..35b4ea3ef 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,6 +15,10 @@ env: # - false: we expect fuzzing to be happy, and should report failure if it's not. # - true: we expect fuzzing is broken, and should report failure if it start working. TS_FUZZ_CURRENTLY_BROKEN: false + # GOMODCACHE is the same definition on all OSes. Within the workspace, we use + # toplevel directories "src" (for the checked out source code), and "gomodcache" + # and other caches as siblings to follow. + GOMODCACHE: ${{ github.workspace }}/gomodcache on: push: @@ -38,8 +42,42 @@ concurrency: cancel-in-progress: true jobs: + gomod-cache: + runs-on: ubuntu-24.04 + outputs: + cache-key: ${{ steps.hash.outputs.key }} + steps: + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Compute cache key from go.{mod,sum} + id: hash + run: echo "key=gomod-cross3-${{ hashFiles('src/go.mod', 'src/go.sum') }}" >> $GITHUB_OUTPUT + # See if the cache entry already exists to avoid downloading it + # and doing the cache write again. + - id: check-cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4 + with: + path: gomodcache # relative to workspace; see env note at top of file + key: ${{ steps.hash.outputs.key }} + lookup-only: true + enableCrossOsArchive: true + - name: Download modules + if: steps.check-cache.outputs.cache-hit != 'true' + working-directory: src + run: go mod download + - name: Cache Go modules + if: steps.check-cache.outputs.cache-hit != 'true' + uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache # relative to workspace; see env note at top of file + key: ${{ steps.hash.outputs.key }} + enableCrossOsArchive: true + race-root-integration: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache strategy: fail-fast: false # don't abort the entire matrix if one element fails matrix: @@ -50,10 +88,20 @@ jobs: - shard: '4/4' steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: build test wrapper + working-directory: src run: ./tool/go build -o /tmp/testwrapper ./cmd/testwrapper - name: integration tests as root + working-directory: src run: PATH=$PWD/tool:$PATH /tmp/testwrapper -exec "sudo -E" -race ./tstest/integration/ env: TS_TEST_SHARD: ${{ matrix.shard }} @@ -64,7 +112,6 @@ jobs: matrix: include: - goarch: amd64 - coverflags: "-coverprofile=/tmp/coverage.out" - goarch: amd64 buildflags: "-race" shard: '1/3' @@ -75,12 +122,21 @@ jobs: buildflags: "-race" shard: '3/3' - goarch: "386" # thanks yaml - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: Restore Cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -88,7 +144,6 @@ jobs: # fetched and extracted by tar path: | ~/.cache/go-build - ~/go/pkg/mod/cache ~\AppData\Local\go-build # The -2- here should be incremented when the scheme of data to be # cached changes (e.g. path above changes). @@ -98,13 +153,14 @@ jobs: ${{ github.job }}-${{ runner.os }}-${{ matrix.goarch }}-${{ matrix.buildflags }}-go-2- - name: build all if: matrix.buildflags == '' # skip on race builder + working-directory: src run: ./tool/go build ${{matrix.buildflags}} ./... env: GOARCH: ${{ matrix.goarch }} - name: build variant CLIs if: matrix.buildflags == '' # skip on race builder + working-directory: src run: | - export TS_USE_TOOLCHAIN=1 ./build_dist.sh --extra-small ./cmd/tailscaled ./build_dist.sh --box ./cmd/tailscaled ./build_dist.sh --extra-small --box ./cmd/tailscaled @@ -117,24 +173,24 @@ jobs: sudo apt-get -y update sudo apt-get -y install qemu-user - name: build test wrapper + working-directory: src run: ./tool/go build -o /tmp/testwrapper ./cmd/testwrapper - name: test all - run: NOBASHDEBUG=true PATH=$PWD/tool:$PATH /tmp/testwrapper ${{matrix.coverflags}} ./... ${{matrix.buildflags}} + working-directory: src + run: NOBASHDEBUG=true NOPWSHDEBUG=true PATH=$PWD/tool:$PATH /tmp/testwrapper ./... ${{matrix.buildflags}} env: GOARCH: ${{ matrix.goarch }} TS_TEST_SHARD: ${{ matrix.shard }} - - name: Publish to coveralls.io - if: matrix.coverflags != '' # only publish results if we've tracked coverage - uses: shogo82148/actions-goveralls@v1 - with: - path-to-profile: /tmp/coverage.out - name: bench all + working-directory: src run: ./tool/go test ${{matrix.buildflags}} -bench=. -benchtime=1x -run=^$ $(for x in $(git grep -l "^func Benchmark" | xargs dirname | sort | uniq); do echo "./$x"; done) env: GOARCH: ${{ matrix.goarch }} - name: check that no tracked files changed + working-directory: src run: git diff --no-ext-diff --name-only --exit-code || (echo "Build/test modified the files above."; exit 1) - name: check that no new files were added + working-directory: src run: | # Note: The "error: pathspec..." you see below is normal! # In the success case in which there are no new untracked files, @@ -145,82 +201,141 @@ jobs: echo "Build/test created untracked files in the repo (file names above)." exit 1 fi + - name: Tidy cache + working-directory: src + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete windows: - runs-on: windows-2022 + # windows-8vpu is a 2022 GitHub-managed runner in our + # org with 8 cores and 32 GB of RAM: + # https://github.com/organizations/tailscale/settings/actions/github-hosted-runners/1 + runs-on: windows-8vcpu + needs: gomod-cache + name: Windows (${{ matrix.name || matrix.shard}}) + strategy: + fail-fast: false # don't abort the entire matrix if one element fails + matrix: + include: + - key: "win-bench" + name: "benchmarks" + - key: "win-tool-go" + name: "./tool/go" + - key: "win-shard-1-2" + shard: "1/2" + - key: "win-shard-2-2" + shard: "2/2" steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + if: matrix.key != 'win-tool-go' + uses: actions/setup-go@d35c59abb061a4a6fb18e82ac0862c26744d6ab5 # v5.5.0 with: - go-version-file: go.mod + go-version-file: src/go.mod cache: false + - name: Restore Go module cache + if: matrix.key != 'win-tool-go' + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true + - name: Restore Cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + if: matrix.key != 'win-tool-go' + uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 with: - # Note: unlike the other setups, this is only grabbing the mod download - # cache, rather than the whole mod directory, as the download cache - # contains zips that can be unpacked in parallel faster than they can be - # fetched and extracted by tar path: | ~/.cache/go-build - ~/go/pkg/mod/cache ~\AppData\Local\go-build # The -2- here should be incremented when the scheme of data to be # cached changes (e.g. path above changes). - key: ${{ github.job }}-${{ runner.os }}-go-2-${{ hashFiles('**/go.sum') }}-${{ github.run_id }} + key: ${{ github.job }}-${{ matrix.key }}-go-2-${{ hashFiles('**/go.sum') }}-${{ github.run_id }} restore-keys: | - ${{ github.job }}-${{ runner.os }}-go-2-${{ hashFiles('**/go.sum') }} - ${{ github.job }}-${{ runner.os }}-go-2- + ${{ github.job }}-${{ matrix.key }}-go-2-${{ hashFiles('**/go.sum') }} + ${{ github.job }}-${{ matrix.key }}-go-2- + + - name: test-tool-go + if: matrix.key == 'win-tool-go' + working-directory: src + run: ./tool/go version + - name: test - run: go run ./cmd/testwrapper ./... + if: matrix.key != 'win-bench' && matrix.key != 'win-tool-go' # skip on bench builder + working-directory: src + run: go run ./cmd/testwrapper sharded:${{ matrix.shard }} + - name: bench all + if: matrix.key == 'win-bench' + working-directory: src # Don't use -bench=. -benchtime=1x. # Somewhere in the layers (powershell?) # the equals signs cause great confusion. run: go test ./... -bench . -benchtime 1x -run "^$" + - name: Tidy cache + if: matrix.key != 'win-tool-go' + working-directory: src + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete + privileged: - runs-on: ubuntu-22.04 + needs: gomod-cache + runs-on: ubuntu-24.04 container: image: golang:latest options: --privileged steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: chown + working-directory: src run: chown -R $(id -u):$(id -g) $PWD - name: privileged tests + working-directory: src run: ./tool/go test ./util/linuxfw ./derp/xdp vm: + needs: gomod-cache runs-on: ["self-hosted", "linux", "vm"] # VM tests run with some privileges, don't let them run on 3p PRs. if: github.repository == 'tailscale/tailscale' steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: Run VM tests - run: ./tool/go test ./tstest/integration/vms -v -no-s3 -run-vm-tests -run=TestRunUbuntu2004 + working-directory: src + run: ./tool/go test ./tstest/integration/vms -v -no-s3 -run-vm-tests -run=TestRunUbuntu2404 env: HOME: "/var/lib/ghrunner/home" TMPDIR: "/tmp" XDG_CACHE_HOME: "/var/lib/ghrunner/cache" - race-build: - runs-on: ubuntu-22.04 - steps: - - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - name: build all - run: ./tool/go install -race ./cmd/... - - name: build tests - run: ./tool/go test -race -exec=true ./... - cross: # cross-compile checks, build only. + needs: gomod-cache strategy: fail-fast: false # don't abort the entire matrix if one element fails matrix: @@ -255,12 +370,14 @@ jobs: - goos: openbsd goarch: amd64 - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src - name: Restore Cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -268,7 +385,6 @@ jobs: # fetched and extracted by tar path: | ~/.cache/go-build - ~/go/pkg/mod/cache ~\AppData\Local\go-build # The -2- here should be incremented when the scheme of data to be # cached changes (e.g. path above changes). @@ -276,7 +392,14 @@ jobs: restore-keys: | ${{ github.job }}-${{ runner.os }}-${{ matrix.goos }}-${{ matrix.goarch }}-go-2-${{ hashFiles('**/go.sum') }} ${{ github.job }}-${{ runner.os }}-${{ matrix.goos }}-${{ matrix.goarch }}-go-2- + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: build all + working-directory: src run: ./tool/go build ./cmd/... env: GOOS: ${{ matrix.goos }} @@ -284,25 +407,42 @@ jobs: GOARM: ${{ matrix.goarm }} CGO_ENABLED: "0" - name: build tests + working-directory: src run: ./tool/go test -exec=true ./... env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} CGO_ENABLED: "0" + - name: Tidy cache + working-directory: src + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete ios: # similar to cross above, but iOS can't build most of the repo. So, just - #make it build a few smoke packages. - runs-on: ubuntu-22.04 + # make it build a few smoke packages. + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: build some - run: ./tool/go build ./ipn/... ./wgengine/ ./types/... ./control/controlclient + working-directory: src + run: ./tool/go build ./ipn/... ./ssh/tailssh ./wgengine/ ./types/... ./control/controlclient env: GOOS: ios GOARCH: arm64 crossmin: # cross-compile for platforms where we only check cmd/tailscale{,d} + needs: gomod-cache strategy: fail-fast: false # don't abort the entire matrix if one element fails matrix: @@ -313,13 +453,21 @@ jobs: # AIX - goos: aix goarch: ppc64 + # Solaris + - goos: solaris + goarch: amd64 + # illumos + - goos: illumos + goarch: amd64 - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src - name: Restore Cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -327,7 +475,6 @@ jobs: # fetched and extracted by tar path: | ~/.cache/go-build - ~/go/pkg/mod/cache ~\AppData\Local\go-build # The -2- here should be incremented when the scheme of data to be # cached changes (e.g. path above changes). @@ -335,39 +482,64 @@ jobs: restore-keys: | ${{ github.job }}-${{ runner.os }}-${{ matrix.goos }}-${{ matrix.goarch }}-go-2-${{ hashFiles('**/go.sum') }} ${{ github.job }}-${{ runner.os }}-${{ matrix.goos }}-${{ matrix.goarch }}-go-2- + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: build core + working-directory: src run: ./tool/go build ./cmd/tailscale ./cmd/tailscaled env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} GOARM: ${{ matrix.goarm }} CGO_ENABLED: "0" + - name: Tidy cache + working-directory: src + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete android: # similar to cross above, but android fails to build a few pieces of the # repo. We should fix those pieces, they're small, but as a stepping stone, # only test the subset of android that our past smoke test checked. - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src # Super minimal Android build that doesn't even use CGO and doesn't build everything that's needed # and is only arm64. But it's a smoke build: it's not meant to catch everything. But it'll catch # some Android breakages early. # TODO(bradfitz): better; see https://github.com/tailscale/tailscale/issues/4482 + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: build some - run: ./tool/go install ./net/netns ./ipn/ipnlocal ./wgengine/magicsock/ ./wgengine/ ./wgengine/router/ ./wgengine/netstack ./util/dnsname/ ./ipn/ ./net/netmon ./wgengine/router/ ./tailcfg/ ./types/logger/ ./net/dns ./hostinfo ./version + working-directory: src + run: ./tool/go install ./net/netns ./ipn/ipnlocal ./wgengine/magicsock/ ./wgengine/ ./wgengine/router/ ./wgengine/netstack ./util/dnsname/ ./ipn/ ./net/netmon ./wgengine/router/ ./tailcfg/ ./types/logger/ ./net/dns ./hostinfo ./version ./ssh/tailssh env: GOOS: android GOARCH: arm64 wasm: # builds tsconnect, which is the only wasm build we support - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src - name: Restore Cache - uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + uses: actions/cache@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -375,7 +547,6 @@ jobs: # fetched and extracted by tar path: | ~/.cache/go-build - ~/go/pkg/mod/cache ~\AppData\Local\go-build # The -2- here should be incremented when the scheme of data to be # cached changes (e.g. path above changes). @@ -383,23 +554,45 @@ jobs: restore-keys: | ${{ github.job }}-${{ runner.os }}-go-2-${{ hashFiles('**/go.sum') }} ${{ github.job }}-${{ runner.os }}-go-2- + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: build tsconnect client + working-directory: src run: ./tool/go build ./cmd/tsconnect/wasm ./cmd/tailscale/cli env: GOOS: js GOARCH: wasm - name: build tsconnect server + working-directory: src # Note, no GOOS/GOARCH in env on this build step, we're running a build # tool that handles the build itself. run: | ./tool/go run ./cmd/tsconnect --fast-compression build ./tool/go run ./cmd/tsconnect --fast-compression build-pkg + - name: Tidy cache + working-directory: src + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete tailscale_go: # Subset of tests that depend on our custom Go toolchain. - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set GOMODCACHE env + run: echo "GOMODCACHE=$HOME/.cache/go-mod" >> $GITHUB_ENV + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: test tailscale_go run: ./tool/go test -tags=tailscale_go,ts_enable_sockstats ./net/sockstats/... @@ -416,11 +609,13 @@ jobs: # explicit 'if' condition, because the default condition for steps is # 'success()', meaning "only run this if no previous steps failed". if: github.event_name == 'pull_request' - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: build fuzzers id: build - uses: google/oss-fuzz/infra/cifuzz/actions/build_fuzzers@master + # As of 21 October 2025, this repo doesn't tag releases, so this commit + # hash is just the tip of master. + uses: google/oss-fuzz/infra/cifuzz/actions/build_fuzzers@1242ccb5b6352601e73c00f189ac2ae397242264 # continue-on-error makes steps.build.conclusion be 'success' even if # steps.build.outcome is 'failure'. This means this step does not # contribute to the job's overall pass/fail evaluation. @@ -450,10 +645,12 @@ jobs: # report a failure because TS_FUZZ_CURRENTLY_BROKEN is set to the wrong # value. if: steps.build.outcome == 'success' - uses: google/oss-fuzz/infra/cifuzz/actions/run_fuzzers@master + # As of 21 October 2025, this repo doesn't tag releases, so this commit + # hash is just the tip of master. + uses: google/oss-fuzz/infra/cifuzz/actions/run_fuzzers@1242ccb5b6352601e73c00f189ac2ae397242264 with: oss-fuzz-project-name: 'tailscale' - fuzz-seconds: 300 + fuzz-seconds: 150 dry-run: false language: go - name: Set artifacts_path in env (workaround for actions/upload-artifact#176) @@ -461,78 +658,154 @@ jobs: run: | echo "artifacts_path=$(realpath .)" >> $GITHUB_ENV - name: upload crash - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 if: steps.run.outcome != 'success' && steps.build.outcome == 'success' with: name: artifacts path: ${{ env.artifacts_path }}/out/artifacts depaware: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Set GOMODCACHE env + run: echo "GOMODCACHE=$HOME/.cache/go-mod" >> $GITHUB_ENV + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: check depaware - run: | - export PATH=$(./tool/go env GOROOT)/bin:$PATH - find . -name 'depaware.txt' | xargs -n1 dirname | xargs ./tool/go run github.com/tailscale/depaware --check + working-directory: src + run: make depaware go_generate: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: check that 'go generate' is clean + working-directory: src run: | pkgs=$(./tool/go list ./... | grep -Ev 'dnsfallback|k8s-operator|xdp') ./tool/go generate $pkgs + git add -N . # ensure untracked files are noticed echo echo git diff --name-only --exit-code || (echo "The files above need updating. Please run 'go generate'."; exit 1) go_mod_tidy: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: check that 'go mod tidy' is clean + working-directory: src run: | - ./tool/go mod tidy + make tidy echo echo - git diff --name-only --exit-code || (echo "Please run 'go mod tidy'."; exit 1) + git diff --name-only --exit-code || (echo "Please run 'make tidy'"; exit 1) licenses: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true - name: check licenses - run: ./scripts/check_license_headers.sh . + working-directory: src + run: | + grep -q TestLicenseHeaders *.go || (echo "Expected a test named TestLicenseHeaders"; exit 1) + ./tool/go test -v -run=TestLicenseHeaders staticcheck: - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 + needs: gomod-cache + name: staticcheck (${{ matrix.name }}) strategy: fail-fast: false # don't abort the entire matrix if one element fails matrix: - goos: ["linux", "windows", "darwin"] - goarch: ["amd64"] include: - - goos: "windows" - goarch: "386" + - name: "macOS" + goos: "darwin" + goarch: "arm64" + flags: "--with-tags-all=darwin" + - name: "Windows" + goos: "windows" + goarch: "amd64" + flags: "--with-tags-all=windows" + - name: "Linux" + goos: "linux" + goarch: "amd64" + flags: "--with-tags-all=linux" + - name: "Portable (1/4)" + goos: "linux" + goarch: "amd64" + flags: "--without-tags-any=windows,darwin,linux --shard=1/4" + - name: "Portable (2/4)" + goos: "linux" + goarch: "amd64" + flags: "--without-tags-any=windows,darwin,linux --shard=2/4" + - name: "Portable (3/4)" + goos: "linux" + goarch: "amd64" + flags: "--without-tags-any=windows,darwin,linux --shard=3/4" + - name: "Portable (4/4)" + goos: "linux" + goarch: "amd64" + flags: "--without-tags-any=windows,darwin,linux --shard=4/4" + steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - name: install staticcheck - run: GOBIN=~/.local/bin ./tool/go install honnef.co/go/tools/cmd/staticcheck - - name: run staticcheck + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + - name: Restore Go module cache + uses: actions/cache/restore@0400d5f644dc74513175e3cd8d07132dd4860809 # v4.2.4 + with: + path: gomodcache + key: ${{ needs.gomod-cache.outputs.cache-key }} + enableCrossOsArchive: true + - name: run staticcheck (${{ matrix.name }}) + working-directory: src run: | export GOROOT=$(./tool/go env GOROOT) - export PATH=$GOROOT/bin:$PATH - staticcheck -- $(./tool/go list ./... | grep -v tempfork) - env: - GOOS: ${{ matrix.goos }} - GOARCH: ${{ matrix.goarch }} + ./tool/go run -exec \ + "env GOOS=${{ matrix.goos }} GOARCH=${{ matrix.goarch }}" \ + honnef.co/go/tools/cmd/staticcheck -- \ + $(./tool/go run ./tool/listpkgs --ignore-3p --goos=${{ matrix.goos }} --goarch=${{ matrix.goarch }} ${{ matrix.flags }} ./...) notify_slack: if: always() @@ -552,7 +825,7 @@ jobs: - go_mod_tidy - licenses - staticcheck - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 steps: - name: notify # Only notify slack for merged commits, not PR failures. @@ -563,8 +836,10 @@ jobs: # By having the job always run, but skipping its only step as needed, we # let the CI output collapse nicely in PRs. if: failure() && github.event_name == 'push' - uses: slackapi/slack-github-action@37ebaef184d7626c5f204ab8d3baff4262dd30f0 # v1.27.0 + uses: slackapi/slack-github-action@91efab103c0de0a537f72a35f6b8cda0ee76bf0a # v2.1.1 with: + webhook: ${{ secrets.SLACK_WEBHOOK_URL }} + webhook-type: incoming-webhook payload: | { "attachments": [{ @@ -576,13 +851,10 @@ jobs: "color": "danger" }] } - env: - SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} - SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK - check_mergeability: + merge_blocker: if: always() - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 needs: - android - test @@ -604,3 +876,46 @@ jobs: uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # v1.2.2 with: jobs: ${{ toJSON(needs) }} + + # This waits on all the jobs which must never fail. Branch protection rules + # enforce these. No flaky tests are allowed in these jobs. (We don't want flaky + # tests anywhere, really, but a flaky test here prevents merging.) + check_mergeability_strict: + if: always() + runs-on: ubuntu-24.04 + needs: + - android + - cross + - crossmin + - ios + - tailscale_go + - depaware + - go_generate + - go_mod_tidy + - licenses + - staticcheck + steps: + - name: Decide if change is okay to merge + if: github.event_name != 'push' + uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # v1.2.2 + with: + jobs: ${{ toJSON(needs) }} + + check_mergeability: + if: always() + runs-on: ubuntu-24.04 + needs: + - check_mergeability_strict + - test + - windows + - vm + - wasm + - fuzz + - race-root-integration + - privileged + steps: + - name: Decide if change is okay to merge + if: github.event_name != 'push' + uses: re-actors/alls-green@05ac9388f0aebcb5727afa17fcccfecd6f8ec5fe # v1.2.2 + with: + jobs: ${{ toJSON(needs) }} diff --git a/.github/workflows/update-flake.yml b/.github/workflows/update-flake.yml index f79248c1e..1968c6830 100644 --- a/.github/workflows/update-flake.yml +++ b/.github/workflows/update-flake.yml @@ -8,11 +8,11 @@ on: - main paths: - go.mod - - .github/workflows/update-flakes.yml + - .github/workflows/update-flake.yml workflow_dispatch: concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -21,22 +21,21 @@ jobs: steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run update-flakes run: ./update-flake.sh - name: Get access token - uses: tibdex/github-app-token@3beb63f4bd073e61482598c45c71c1019b59b73a # v2.1.0 + uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6 id: generate-token with: - app_id: ${{ secrets.LICENSING_APP_ID }} - installation_retrieval_mode: "id" - installation_retrieval_payload: ${{ secrets.LICENSING_APP_INSTALLATION_ID }} - private_key: ${{ secrets.LICENSING_APP_PRIVATE_KEY }} + # Get token for app: https://github.com/apps/tailscale-code-updater + app-id: ${{ secrets.CODE_UPDATER_APP_ID }} + private-key: ${{ secrets.CODE_UPDATER_APP_PRIVATE_KEY }} - name: Send pull request - uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f #v7.0.5 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e #v7.0.8 with: token: ${{ steps.generate-token.outputs.token }} author: Flakes Updater diff --git a/.github/workflows/update-webclient-prebuilt.yml b/.github/workflows/update-webclient-prebuilt.yml index a0ae95cd7..5565b8c86 100644 --- a/.github/workflows/update-webclient-prebuilt.yml +++ b/.github/workflows/update-webclient-prebuilt.yml @@ -5,7 +5,7 @@ on: workflow_dispatch: concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -14,7 +14,7 @@ jobs: steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run go get run: | @@ -23,19 +23,16 @@ jobs: ./tool/go mod tidy - name: Get access token - uses: tibdex/github-app-token@3beb63f4bd073e61482598c45c71c1019b59b73a # v2.1.0 + uses: actions/create-github-app-token@df432ceedc7162793a195dd1713ff69aefc7379e # v2.0.6 id: generate-token with: - # TODO(will): this should use the code updater app rather than licensing. - # It has the same permissions, so not a big deal, but still. - app_id: ${{ secrets.LICENSING_APP_ID }} - installation_retrieval_mode: "id" - installation_retrieval_payload: ${{ secrets.LICENSING_APP_INSTALLATION_ID }} - private_key: ${{ secrets.LICENSING_APP_PRIVATE_KEY }} + # Get token for app: https://github.com/apps/tailscale-code-updater + app-id: ${{ secrets.CODE_UPDATER_APP_ID }} + private-key: ${{ secrets.CODE_UPDATER_APP_PRIVATE_KEY }} - name: Send pull request id: pull-request - uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f #v7.0.5 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e #v7.0.8 with: token: ${{ steps.generate-token.outputs.token }} author: OSS Updater diff --git a/.github/workflows/vet.yml b/.github/workflows/vet.yml new file mode 100644 index 000000000..7eff6b45f --- /dev/null +++ b/.github/workflows/vet.yml @@ -0,0 +1,38 @@ +name: tailscale.com/cmd/vet + +env: + HOME: ${{ github.workspace }} + # GOMODCACHE is the same definition on all OSes. Within the workspace, we use + # toplevel directories "src" (for the checked out source code), and "gomodcache" + # and other caches as siblings to follow. + GOMODCACHE: ${{ github.workspace }}/gomodcache + +on: + push: + branches: + - main + - "release-branch/*" + paths: + - "**.go" + pull_request: + paths: + - "**.go" + +jobs: + vet: + runs-on: [ self-hosted, linux ] + timeout-minutes: 5 + + steps: + - name: Check out code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: src + + - name: Build 'go vet' tool + working-directory: src + run: ./tool/go build -o /tmp/vettool tailscale.com/cmd/vet + + - name: Run 'go vet' + working-directory: src + run: ./tool/go vet -vettool=/tmp/vettool tailscale.com/... diff --git a/.github/workflows/webclient.yml b/.github/workflows/webclient.yml index 9afb7730d..bcec1f52d 100644 --- a/.github/workflows/webclient.yml +++ b/.github/workflows/webclient.yml @@ -3,8 +3,6 @@ on: workflow_dispatch: # For now, only run on requests, not the main branches. pull_request: - branches: - - "*" paths: - "client/web/**" - ".github/workflows/webclient.yml" @@ -15,7 +13,7 @@ on: # - main concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: @@ -24,7 +22,7 @@ jobs: steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install deps run: ./tool/yarn --cwd client/web - name: Run lint diff --git a/.gitignore b/.gitignore index 47d2bbe95..3941fd06e 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,6 @@ client/web/build/assets *.xcworkspacedata /tstest/tailmac/bin /tstest/tailmac/build + +# Ignore personal IntelliJ settings +.idea/ diff --git a/.golangci.yml b/.golangci.yml index 45248de16..eb34f9d9e 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,104 +1,110 @@ +version: "2" +# Configuration for how we run golangci-lint +# Timeout of 5m was the default in v1. +run: + timeout: 5m linters: # Don't enable any linters by default; just the ones that we explicitly # enable in the list below. - disable-all: true + default: none enable: - bidichk - - gofmt - - goimports - govet - misspell - revive - -# Configuration for how we run golangci-lint -run: - timeout: 5m - -issues: - # Excluding configuration per-path, per-linter, per-text and per-source - exclude-rules: - # These are forks of an upstream package and thus are exempt from stylistic - # changes that would make pulling in upstream changes harder. - - path: tempfork/.*\.go - text: "File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'`" - - path: util/singleflight/.*\.go - text: "File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'`" - -# Per-linter settings are contained in this top-level key -linters-settings: - # Enable all rules by default; we don't use invisible unicode runes. - bidichk: - - gofmt: - rewrite-rules: - - pattern: 'interface{}' - replacement: 'any' - - goimports: - - govet: + settings: # Matches what we use in corp as of 2023-12-07 - enable: - - asmdecl - - assign - - atomic - - bools - - buildtag - - cgocall - - copylocks - - deepequalerrors - - errorsas - - framepointer - - httpresponse - - ifaceassert - - loopclosure - - lostcancel - - nilfunc - - nilness - - printf - - reflectvaluecompare - - shift - - sigchanyzer - - sortslice - - stdmethods - - stringintconv - - structtag - - testinggoroutine - - tests - - unmarshal - - unreachable - - unsafeptr - - unusedresult - settings: - printf: - # List of print function names to check (in addition to default) - funcs: - - github.com/tailscale/tailscale/types/logger.Discard - # NOTE(andrew-d): this doesn't currently work because the printf - # analyzer doesn't support type declarations - #- github.com/tailscale/tailscale/types/logger.Logf - - misspell: - - revive: - enable-all-rules: false - ignore-generated-header: true + govet: + enable: + - asmdecl + - assign + - atomic + - bools + - buildtag + - cgocall + - copylocks + - deepequalerrors + - errorsas + - framepointer + - httpresponse + - ifaceassert + - loopclosure + - lostcancel + - nilfunc + - nilness + - printf + - reflectvaluecompare + - shift + - sigchanyzer + - sortslice + - stdmethods + - stringintconv + - structtag + - testinggoroutine + - tests + - unmarshal + - unreachable + - unsafeptr + - unusedresult + settings: + printf: + # List of print function names to check (in addition to default) + funcs: + - github.com/tailscale/tailscale/types/logger.Discard + # NOTE(andrew-d): this doesn't currently work because the printf + # analyzer doesn't support type declarations + #- github.com/tailscale/tailscale/types/logger.Logf + revive: + enable-all-rules: false + rules: + - name: atomic + - name: context-keys-type + - name: defer + arguments: [[ + # Calling 'recover' at the time a defer is registered (i.e. "defer recover()") has no effect. + "immediate-recover", + # Calling 'recover' outside of a deferred function has no effect + "recover", + # Returning values from a deferred function has no effect + "return", + ]] + - name: duplicated-imports + - name: errorf + - name: string-of-int + - name: time-equal + - name: unconditional-recursion + - name: useless-break + - name: waitgroup-by-value + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling rules: - - name: atomic - - name: context-keys-type - - name: defer - arguments: [[ - # Calling 'recover' at the time a defer is registered (i.e. "defer recover()") has no effect. - "immediate-recover", - # Calling 'recover' outside of a deferred function has no effect - "recover", - # Returning values from a deferred function has no effect - "return", - ]] - - name: duplicated-imports - - name: errorf - - name: string-of-int - - name: time-equal - - name: unconditional-recursion - - name: useless-break - - name: waitgroup-by-value + # These are forks of an upstream package and thus are exempt from stylistic + # changes that would make pulling in upstream changes harder. + - path: tempfork/.*\.go + text: File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'` + - path: util/singleflight/.*\.go + text: File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'` + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + enable: + - gofmt + - goimports + settings: + gofmt: + rewrite-rules: + - pattern: interface{} + replacement: any + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/ALPINE.txt b/ALPINE.txt index 55b698c77..93a84c380 100644 --- a/ALPINE.txt +++ b/ALPINE.txt @@ -1 +1 @@ -3.18 \ No newline at end of file +3.22 \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index be5564ef4..348483df5 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,135 +1,103 @@ -# Contributor Covenant Code of Conduct +# Tailscale Community Code of Conduct ## Our Pledge -We as members, contributors, and leaders pledge to make participation -in our community a harassment-free experience for everyone, regardless -of age, body size, visible or invisible disability, ethnicity, sex -characteristics, gender identity and expression, level of experience, -education, socio-economic status, nationality, personal appearance, -race, religion, or sexual identity and orientation. - -We pledge to act and interact in ways that contribute to an open, -welcoming, diverse, inclusive, and healthy community. +We are committed to creating an open, welcoming, diverse, inclusive, healthy and respectful community. +Unacceptable, harmful and inappropriate behavior will not be tolerated. ## Our Standards -Examples of behavior that contributes to a positive environment for -our community include: - -* Demonstrating empathy and kindness toward other people -* Being respectful of differing opinions, viewpoints, and experiences -* Giving and gracefully accepting constructive feedback -* Accepting responsibility and apologizing to those affected by our - mistakes, and learning from the experience -* Focusing on what is best not just for us as individuals, but for the - overall community - -Examples of unacceptable behavior include: +Examples of behavior that contributes to a positive environment for our community include: -* The use of sexualized language or imagery, and sexual attention or - advances of any kind -* Trolling, insulting or derogatory comments, and personal or - political attacks -* Public or private harassment -* Publishing others' private information, such as a physical or email - address, without their explicit permission -* Other conduct which could reasonably be considered inappropriate in - a professional setting +- Demonstrating empathy and kindness toward other people. +- Being respectful of differing opinions, viewpoints, and experiences. +- Giving and gracefully accepting constructive feedback. +- Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience. +- Focusing on what is best not just for us as individuals, but for the overall community. -## Enforcement Responsibilities +Examples of unacceptable behavior include without limitation: -Community leaders are responsible for clarifying and enforcing our -standards of acceptable behavior and will take appropriate and fair -corrective action in response to any behavior that they deem -inappropriate, threatening, offensive, or harmful. +- The use of language, imagery or emojis (collectively "content") that is racist, sexist, homophobic, transphobic, or otherwise harassing or discriminatory based on any protected characteristic. +- The use of sexualized content and sexual attention or advances of any kind. +- The use of violent, intimidating or bullying content. +- Trolling, concern trolling, insulting or derogatory comments, and personal or political attacks. +- Public or private harassment. +- Publishing others' personal information, such as a photo, physical address, email address, online profile information, or other personal information, without their explicit permission or with the intent to bully or harass the other person. +- Posting deep fake or other AI generated content about or involving another person without the explicit permission. +- Spamming community channels and members, such as sending repeat messages, low-effort content, or automated messages. +- Phishing or any similar activity. +- Distributing or promoting malware. +- The use of any coded or suggestive content to hide or provoke otherwise unacceptable behavior. +- Other conduct which could reasonably be considered harmful, illegal, or inappropriate in a professional setting. -Community leaders have the right and responsibility to remove, edit, -or reject comments, commits, code, wiki edits, issues, and other -contributions that are not aligned to this Code of Conduct, and will -communicate reasons for moderation decisions when appropriate. +Please also see the Tailscale Acceptable Use Policy, available at [tailscale.com/tailscale-aup](https://tailscale.com/tailscale-aup). -## Scope +## Reporting Incidents -This Code of Conduct applies within all community spaces, and also -applies when an individual is officially representing the community in -public spaces. Examples of representing our community include using an -official e-mail address, posting via an official social media account, -or acting as an appointed representative at an online or offline -event. +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to Tailscale directly via , or to the community leaders or moderators via DM or similar. +All complaints will be reviewed and investigated promptly and fairly. +We will respect the privacy and safety of the reporter of any issues. -## Enforcement +Please note that this community is not moderated by staff 24/7, and we do not have, and do not undertake, any obligation to prescreen, monitor, edit, or remove any content or data, or to actively seek facts or circumstances indicating illegal activity. +While we strive to keep the community safe and welcoming, moderation may not be immediate at all hours. +If you encounter any issues, report them using the appropriate channels. -Instances of abusive, harassing, or otherwise unacceptable behavior -may be reported to the community leaders responsible for enforcement -at [info@tailscale.com](mailto:info@tailscale.com). All complaints -will be reviewed and investigated promptly and fairly. +## Enforcement Guidelines -All community leaders are obligated to respect the privacy and -security of the reporter of any incident. +Community leaders and moderators are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. -## Enforcement Guidelines +Community leaders and moderators have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Community Code of Conduct. +Tailscale retains full discretion to take action (or not) in response to a violation of these guidelines with or without notice or liability to you. +We will interpret our policies and resolve disputes in favor of protecting users, customers, the public, our community and our company, as a whole. -Community leaders will follow these Community Impact Guidelines in -determining the consequences for any action they deem in violation of -this Code of Conduct: +Community leaders will follow these community enforcement guidelines in determining the consequences for any action they deem in violation of this Code of Conduct, +and retain full discretion to apply the enforcement guidelines as necessary depending on the circumstances: ### 1. Correction -**Community Impact**: Use of inappropriate language or other behavior -deemed unprofessional or unwelcome in the community. +Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. -**Consequence**: A private, written warning from community leaders, -providing clarity around the nature of the violation and an -explanation of why the behavior was inappropriate. A public apology -may be requested. +Consequence: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. +A public apology may be requested. ### 2. Warning -**Community Impact**: A violation through a single incident or series -of actions. +Community Impact: A violation through a single incident or series of actions. -**Consequence**: A warning with consequences for continued -behavior. No interaction with the people involved, including -unsolicited interaction with those enforcing the Code of Conduct, for -a specified period of time. This includes avoiding interactions in -community spaces as well as external channels like social -media. Violating these terms may lead to a temporary or permanent ban. +Consequence: A warning with consequences for continued behavior. +No interaction with the people involved, including unsolicited interaction with those enforcing this Community Code of Conduct, for a specified period of time. +This includes avoiding interactions in community spaces as well as external channels like social media. +Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban -**Community Impact**: A serious violation of community standards, -including sustained inappropriate behavior. +Community Impact: A serious violation of community standards, including sustained inappropriate behavior. -**Consequence**: A temporary ban from any sort of interaction or -public communication with the community for a specified period of -time. No public or private interaction with the people involved, -including unsolicited interaction with those enforcing the Code of -Conduct, is allowed during this period. Violating these terms may lead -to a permanent ban. +Consequence: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. +No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. Permanent Ban -**Community Impact**: Demonstrating a pattern of violation of -community standards, including sustained inappropriate behavior, -harassment of an individual, or aggression toward or disparagement of -classes of individuals. +Community Impact: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. -**Consequence**: A permanent ban from any sort of public interaction -within the community. +Consequence: A permanent ban from any sort of public interaction within the community. + +## Acceptable Use Policy + +Violation of this Community Code of Conduct may also violate the Tailscale Acceptable Use Policy, which may result in suspension or termination of your Tailscale account. +For more information, please see the Tailscale Acceptable Use Policy, available at [tailscale.com/tailscale-aup](https://tailscale.com/tailscale-aup). + +## Privacy + +Please see the Tailscale [Privacy Policy](https://tailscale.com/privacy-policy) for more information about how Tailscale collects, uses, discloses and protects information. ## Attribution -This Code of Conduct is adapted from the [Contributor -Covenant][homepage], version 2.0, available at -https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at . -Community Impact Guidelines were inspired by [Mozilla's code of -conduct enforcement ladder](https://github.com/mozilla/diversity). +Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org -For answers to common questions about this code of conduct, see the -FAQ at https://www.contributor-covenant.org/faq. Translations are -available at https://www.contributor-covenant.org/translations. - +For answers to common questions about this code of conduct, see the FAQ at . +Translations are available at . diff --git a/Dockerfile b/Dockerfile index 4ad3d88d9..c546cf657 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,15 @@ # Tailscale images are currently built using https://github.com/tailscale/mkctr, # and the build script can be found in ./build_docker.sh. # +# If you want to build local images for testing, you can use make. +# +# To build a Tailscale image and push to the local docker registry: +# +# $ REPO=local/tailscale TAGS=v0.0.1 PLATFORM=local make publishdevimage +# +# To build a Tailscale image and push to a remote docker registry: +# +# $ REPO=//tailscale TAGS=v0.0.1 make publishdevimage # # This Dockerfile includes all the tailscale binaries. # @@ -27,7 +36,7 @@ # $ docker exec tailscaled tailscale status -FROM golang:1.23-alpine AS build-env +FROM golang:1.25-alpine AS build-env WORKDIR /go/src/tailscale @@ -62,8 +71,10 @@ RUN GOARCH=$TARGETARCH go install -ldflags="\ -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \ -v ./cmd/tailscale ./cmd/tailscaled ./cmd/containerboot -FROM alpine:3.18 +FROM alpine:3.22 RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables +RUN ln -s /sbin/iptables-legacy /sbin/iptables +RUN ln -s /sbin/ip6tables-legacy /sbin/ip6tables COPY --from=build-env /go/bin/* /usr/local/bin/ # For compat with the previous run.sh, although ideally you should be diff --git a/Dockerfile.base b/Dockerfile.base index eb4f0a02a..6c3c8ed08 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -1,5 +1,12 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -FROM alpine:3.18 -RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables iputils +FROM alpine:3.22 +RUN apk add --no-cache ca-certificates iptables iptables-legacy iproute2 ip6tables iputils +# Alpine 3.19 replaced legacy iptables with nftables based implementation. We +# can't be certain that all hosts that run Tailscale containers currently +# suppport nftables, so link back to legacy for backwards compatibility reasons. +# TODO(irbekrm): add some way how to determine if we still run on nodes that +# don't support nftables, so that we can eventually remove these symlinks. +RUN ln -s /sbin/iptables-legacy /sbin/iptables +RUN ln -s /sbin/ip6tables-legacy /sbin/ip6tables diff --git a/Makefile b/Makefile index 98c3d36cc..b78ef0469 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,9 @@ PLATFORM ?= "flyio" ## flyio==linux/amd64. Set to "" to build all platforms. vet: ## Run go vet ./tool/go vet ./... -tidy: ## Run go mod tidy +tidy: ## Run go mod tidy and update nix flake hashes ./tool/go mod tidy + ./update-flake.sh lint: ## Run golangci-lint ./tool/go run github.com/golangci/golangci-lint/cmd/golangci-lint run @@ -17,22 +18,36 @@ lint: ## Run golangci-lint updatedeps: ## Update depaware deps # depaware (via x/tools/go/packages) shells back to "go", so make sure the "go" # it finds in its $$PATH is the right one. - PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update \ + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update --vendor --internal \ tailscale.com/cmd/tailscaled \ tailscale.com/cmd/tailscale \ tailscale.com/cmd/derper \ tailscale.com/cmd/k8s-operator \ - tailscale.com/cmd/stund + tailscale.com/cmd/stund \ + tailscale.com/cmd/tsidp + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update --goos=linux,darwin,windows,android,ios --vendor --internal \ + tailscale.com/tsnet + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update --file=depaware-minbox.txt --goos=linux --tags="$$(./tool/go run ./cmd/featuretags --min --add=cli)" --vendor --internal \ + tailscale.com/cmd/tailscaled + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update --file=depaware-min.txt --goos=linux --tags="$$(./tool/go run ./cmd/featuretags --min)" --vendor --internal \ + tailscale.com/cmd/tailscaled depaware: ## Run depaware checks # depaware (via x/tools/go/packages) shells back to "go", so make sure the "go" # it finds in its $$PATH is the right one. - PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check \ + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check --vendor --internal \ tailscale.com/cmd/tailscaled \ tailscale.com/cmd/tailscale \ tailscale.com/cmd/derper \ tailscale.com/cmd/k8s-operator \ - tailscale.com/cmd/stund + tailscale.com/cmd/stund \ + tailscale.com/cmd/tsidp + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check --goos=linux,darwin,windows,android,ios --vendor --internal \ + tailscale.com/tsnet + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check --file=depaware-minbox.txt --goos=linux --tags="$$(./tool/go run ./cmd/featuretags --min --add=cli)" --vendor --internal \ + tailscale.com/cmd/tailscaled + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check --file=depaware-min.txt --goos=linux --tags="$$(./tool/go run ./cmd/featuretags --min)" --vendor --internal \ + tailscale.com/cmd/tailscaled buildwindows: ## Build tailscale CLI for windows/amd64 GOOS=windows GOARCH=amd64 ./tool/go install tailscale.com/cmd/tailscale tailscale.com/cmd/tailscaled @@ -58,7 +73,7 @@ buildmultiarchimage: ## Build (and optionally push) multiarch docker image check: staticcheck vet depaware buildwindows build386 buildlinuxarm buildwasm ## Perform basic checks and compilation tests staticcheck: ## Run staticcheck.io checks - ./tool/go run honnef.co/go/tools/cmd/staticcheck -- $$(./tool/go list ./... | grep -v tempfork) + ./tool/go run honnef.co/go/tools/cmd/staticcheck -- $$(./tool/go run ./tool/listpkgs --ignore-3p ./...) kube-generate-all: kube-generate-deepcopy ## Refresh generated files for Tailscale Kubernetes Operator ./tool/go generate ./cmd/k8s-operator @@ -86,43 +101,60 @@ pushspk: spk ## Push and install synology package on ${SYNO_HOST} host scp tailscale.spk root@${SYNO_HOST}: ssh root@${SYNO_HOST} /usr/syno/bin/synopkg install tailscale.spk -publishdevimage: ## Build and publish tailscale image to location specified by ${REPO} - @test -n "${REPO}" || (echo "REPO=... required; e.g. REPO=ghcr.io/${USER}/tailscale" && exit 1) - @test "${REPO}" != "tailscale/tailscale" || (echo "REPO=... must not be tailscale/tailscale" && exit 1) - @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) - @test "${REPO}" != "tailscale/k8s-operator" || (echo "REPO=... must not be tailscale/k8s-operator" && exit 1) - @test "${REPO}" != "ghcr.io/tailscale/k8s-operator" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-operator" && exit 1) +.PHONY: check-image-repo +check-image-repo: + @if [ -z "$(REPO)" ]; then \ + echo "REPO=... required; e.g. REPO=ghcr.io/$$USER/tailscale" >&2; \ + exit 1; \ + fi + @for repo in tailscale/tailscale ghcr.io/tailscale/tailscale \ + tailscale/k8s-operator ghcr.io/tailscale/k8s-operator \ + tailscale/k8s-nameserver ghcr.io/tailscale/k8s-nameserver \ + tailscale/tsidp ghcr.io/tailscale/tsidp \ + tailscale/k8s-proxy ghcr.io/tailscale/k8s-proxy; do \ + if [ "$(REPO)" = "$$repo" ]; then \ + echo "REPO=... must not be $$repo" >&2; \ + exit 1; \ + fi; \ + done + +publishdevimage: check-image-repo ## Build and publish tailscale image to location specified by ${REPO} TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=client ./build_docker.sh -publishdevoperator: ## Build and publish k8s-operator image to location specified by ${REPO} - @test -n "${REPO}" || (echo "REPO=... required; e.g. REPO=ghcr.io/${USER}/tailscale" && exit 1) - @test "${REPO}" != "tailscale/tailscale" || (echo "REPO=... must not be tailscale/tailscale" && exit 1) - @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) - @test "${REPO}" != "tailscale/k8s-operator" || (echo "REPO=... must not be tailscale/k8s-operator" && exit 1) - @test "${REPO}" != "ghcr.io/tailscale/k8s-operator" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-operator" && exit 1) - TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=operator ./build_docker.sh - -publishdevnameserver: ## Build and publish k8s-nameserver image to location specified by ${REPO} - @test -n "${REPO}" || (echo "REPO=... required; e.g. REPO=ghcr.io/${USER}/tailscale" && exit 1) - @test "${REPO}" != "tailscale/tailscale" || (echo "REPO=... must not be tailscale/tailscale" && exit 1) - @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) - @test "${REPO}" != "tailscale/k8s-nameserver" || (echo "REPO=... must not be tailscale/k8s-nameserver" && exit 1) - @test "${REPO}" != "ghcr.io/tailscale/k8s-nameserver" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-nameserver" && exit 1) +publishdevoperator: check-image-repo ## Build and publish k8s-operator image to location specified by ${REPO} + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=k8s-operator ./build_docker.sh + +publishdevnameserver: check-image-repo ## Build and publish k8s-nameserver image to location specified by ${REPO} TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=k8s-nameserver ./build_docker.sh +publishdevtsidp: check-image-repo ## Build and publish tsidp image to location specified by ${REPO} + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=tsidp ./build_docker.sh + +publishdevproxy: check-image-repo ## Build and publish k8s-proxy image to location specified by ${REPO} + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=k8s-proxy ./build_docker.sh + .PHONY: sshintegrationtest sshintegrationtest: ## Run the SSH integration tests in various Docker containers - @GOOS=linux GOARCH=amd64 ./tool/go test -tags integrationtest -c ./ssh/tailssh -o ssh/tailssh/testcontainers/tailssh.test && \ - GOOS=linux GOARCH=amd64 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled && \ + @GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go test -tags integrationtest -c ./ssh/tailssh -o ssh/tailssh/testcontainers/tailssh.test && \ + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled && \ echo "Testing on ubuntu:focal" && docker build --build-arg="BASE=ubuntu:focal" -t ssh-ubuntu-focal ssh/tailssh/testcontainers && \ echo "Testing on ubuntu:jammy" && docker build --build-arg="BASE=ubuntu:jammy" -t ssh-ubuntu-jammy ssh/tailssh/testcontainers && \ - echo "Testing on ubuntu:mantic" && docker build --build-arg="BASE=ubuntu:mantic" -t ssh-ubuntu-mantic ssh/tailssh/testcontainers && \ echo "Testing on ubuntu:noble" && docker build --build-arg="BASE=ubuntu:noble" -t ssh-ubuntu-noble ssh/tailssh/testcontainers && \ echo "Testing on alpine:latest" && docker build --build-arg="BASE=alpine:latest" -t ssh-alpine-latest ssh/tailssh/testcontainers +.PHONY: generate +generate: ## Generate code + ./tool/go generate ./... + +.PHONY: pin-github-actions +pin-github-actions: + ./tool/go tool github.com/stacklok/frizbee actions .github/workflows + help: ## Show this help - @echo "\nSpecify a command. The choices are:\n" - @grep -hE '^[0-9a-zA-Z_-]+:.*?## .*$$' ${MAKEFILE_LIST} | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[0;36m%-20s\033[m %s\n", $$1, $$2}' + @echo "" + @echo "Specify a command. The choices are:" + @echo "" + @grep -hE '^[0-9a-zA-Z_-]+:.*?## .*$$' ${MAKEFILE_LIST} | sort | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[0;36m%-20s\033[m %s\n", $$1, $$2}' @echo "" .PHONY: help diff --git a/README.md b/README.md index 4627d9780..70b92d411 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ not open source. ## Building -We always require the latest Go release, currently Go 1.23. (While we build +We always require the latest Go release, currently Go 1.25. (While we build releases with our [Go fork](https://github.com/tailscale/go/), its use is not required.) @@ -71,8 +71,7 @@ We require [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) `Signed-off-by` lines in commits. -See `git log` for our commit message style. It's basically the same as -[Go's style](https://github.com/golang/go/wiki/CommitMessage). +See [commit-messages.md](docs/commit-messages.md) (or skim `git log`) for our commit message style. ## About Us diff --git a/VERSION.txt b/VERSION.txt index 7c7053aa2..6979a6c06 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.75.0 +1.91.0 diff --git a/appc/appconnector.go b/appc/appconnector.go index 671ced953..d41f9e8ba 100644 --- a/appc/appconnector.go +++ b/appc/appconnector.go @@ -12,20 +12,20 @@ package appc import ( "context" "fmt" + "maps" "net/netip" "slices" "strings" - "sync" "time" - xmaps "golang.org/x/exp/maps" - "golang.org/x/net/dns/dnsmessage" + "tailscale.com/syncs" + "tailscale.com/types/appctype" "tailscale.com/types/logger" "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus" "tailscale.com/util/execqueue" - "tailscale.com/util/mak" "tailscale.com/util/slicesx" ) @@ -116,19 +116,6 @@ func metricStoreRoutes(rate, nRoutes int64) { recordMetric(nRoutes, metricStoreRoutesNBuckets, metricStoreRoutesN) } -// RouteInfo is a data structure used to persist the in memory state of an AppConnector -// so that we can know, even after a restart, which routes came from ACLs and which were -// learned from domains. -type RouteInfo struct { - // Control is the routes from the 'routes' section of an app connector acl. - Control []netip.Prefix `json:",omitempty"` - // Domains are the routes discovered by observing DNS lookups for configured domains. - Domains map[string][]netip.Addr `json:",omitempty"` - // Wildcards are the configured DNS lookup domains to observe. When a DNS query matches Wildcards, - // its result is added to Domains. - Wildcards []string `json:",omitempty"` -} - // AppConnector is an implementation of an AppConnector that performs // its function as a subsystem inside of a tailscale node. At the control plane // side App Connector routing is configured in terms of domains rather than IP @@ -139,14 +126,20 @@ type RouteInfo struct { // routes not yet served by the AppConnector the local node configuration is // updated to advertise the new route. type AppConnector struct { + // These fields are immutable after initialization. logf logger.Logf + eventBus *eventbus.Bus routeAdvertiser RouteAdvertiser + pubClient *eventbus.Client + updatePub *eventbus.Publisher[appctype.RouteUpdate] + storePub *eventbus.Publisher[appctype.RouteInfo] - // storeRoutesFunc will be called to persist routes if it is not nil. - storeRoutesFunc func(*RouteInfo) error + // hasStoredRoutes records whether the connector was initialized with + // persisted route information. + hasStoredRoutes bool // mu guards the fields that follow - mu sync.Mutex + mu syncs.Mutex // domains is a map of lower case domain names with no trailing dot, to an // ordered list of resolved IP addresses. @@ -165,53 +158,83 @@ type AppConnector struct { writeRateDay *rateLogger } +// Config carries the settings for an [AppConnector]. +type Config struct { + // Logf is the logger to which debug logs from the connector will be sent. + // It must be non-nil. + Logf logger.Logf + + // EventBus receives events when the collection of routes maintained by the + // connector is updated. It must be non-nil. + EventBus *eventbus.Bus + + // RouteAdvertiser allows the connector to update the set of advertised routes. + RouteAdvertiser RouteAdvertiser + + // RouteInfo, if non-nil, use used as the initial set of routes for the + // connector. If nil, the connector starts empty. + RouteInfo *appctype.RouteInfo + + // HasStoredRoutes indicates that the connector should assume stored routes. + HasStoredRoutes bool +} + // NewAppConnector creates a new AppConnector. -func NewAppConnector(logf logger.Logf, routeAdvertiser RouteAdvertiser, routeInfo *RouteInfo, storeRoutesFunc func(*RouteInfo) error) *AppConnector { +func NewAppConnector(c Config) *AppConnector { + switch { + case c.Logf == nil: + panic("missing logger") + case c.EventBus == nil: + panic("missing event bus") + } + ec := c.EventBus.Client("appc.AppConnector") + ac := &AppConnector{ - logf: logger.WithPrefix(logf, "appc: "), - routeAdvertiser: routeAdvertiser, - storeRoutesFunc: storeRoutesFunc, + logf: logger.WithPrefix(c.Logf, "appc: "), + eventBus: c.EventBus, + pubClient: ec, + updatePub: eventbus.Publish[appctype.RouteUpdate](ec), + storePub: eventbus.Publish[appctype.RouteInfo](ec), + routeAdvertiser: c.RouteAdvertiser, + hasStoredRoutes: c.HasStoredRoutes, } - if routeInfo != nil { - ac.domains = routeInfo.Domains - ac.wildcards = routeInfo.Wildcards - ac.controlRoutes = routeInfo.Control + if c.RouteInfo != nil { + ac.domains = c.RouteInfo.Domains + ac.wildcards = c.RouteInfo.Wildcards + ac.controlRoutes = c.RouteInfo.Control } - ac.writeRateMinute = newRateLogger(time.Now, time.Minute, func(c int64, s time.Time, l int64) { - ac.logf("routeInfo write rate: %d in minute starting at %v (%d routes)", c, s, l) - metricStoreRoutes(c, l) + ac.writeRateMinute = newRateLogger(time.Now, time.Minute, func(c int64, s time.Time, ln int64) { + ac.logf("routeInfo write rate: %d in minute starting at %v (%d routes)", c, s, ln) + metricStoreRoutes(c, ln) }) - ac.writeRateDay = newRateLogger(time.Now, 24*time.Hour, func(c int64, s time.Time, l int64) { - ac.logf("routeInfo write rate: %d in 24 hours starting at %v (%d routes)", c, s, l) + ac.writeRateDay = newRateLogger(time.Now, 24*time.Hour, func(c int64, s time.Time, ln int64) { + ac.logf("routeInfo write rate: %d in 24 hours starting at %v (%d routes)", c, s, ln) }) return ac } // ShouldStoreRoutes returns true if the appconnector was created with the controlknob on // and is storing its discovered routes persistently. -func (e *AppConnector) ShouldStoreRoutes() bool { - return e.storeRoutesFunc != nil -} +func (e *AppConnector) ShouldStoreRoutes() bool { return e.hasStoredRoutes } // storeRoutesLocked takes the current state of the AppConnector and persists it -func (e *AppConnector) storeRoutesLocked() error { - if !e.ShouldStoreRoutes() { - return nil - } - - // log write rate and write size - numRoutes := int64(len(e.controlRoutes)) - for _, rs := range e.domains { - numRoutes += int64(len(rs)) +func (e *AppConnector) storeRoutesLocked() { + if e.storePub.ShouldPublish() { + // log write rate and write size + numRoutes := int64(len(e.controlRoutes)) + for _, rs := range e.domains { + numRoutes += int64(len(rs)) + } + e.writeRateMinute.update(numRoutes) + e.writeRateDay.update(numRoutes) + + e.storePub.Publish(appctype.RouteInfo{ + // Clone here, as the subscriber will handle these outside our lock. + Control: slices.Clone(e.controlRoutes), + Domains: maps.Clone(e.domains), + Wildcards: slices.Clone(e.wildcards), + }) } - e.writeRateMinute.update(numRoutes) - e.writeRateDay.update(numRoutes) - - return e.storeRoutesFunc(&RouteInfo{ - Control: e.controlRoutes, - Domains: e.domains, - Wildcards: e.wildcards, - }) } // ClearRoutes removes all route state from the AppConnector. @@ -221,7 +244,8 @@ func (e *AppConnector) ClearRoutes() error { e.controlRoutes = nil e.domains = nil e.wildcards = nil - return e.storeRoutesLocked() + e.storeRoutesLocked() + return nil } // UpdateDomainsAndRoutes starts an asynchronous update of the configuration @@ -250,6 +274,18 @@ func (e *AppConnector) Wait(ctx context.Context) { e.queue.Wait(ctx) } +// Close closes the connector and cleans up resources associated with it. +// It is safe (and a noop) to call Close on nil. +func (e *AppConnector) Close() { + if e == nil { + return + } + e.mu.Lock() + defer e.mu.Unlock() + e.queue.Shutdown() // TODO(creachadair): Should we wait for it too? + e.pubClient.Close() +} + func (e *AppConnector) updateDomains(domains []string) { e.mu.Lock() defer e.mu.Unlock() @@ -281,21 +317,29 @@ func (e *AppConnector) updateDomains(domains []string) { } } - // Everything left in oldDomains is a domain we're no longer tracking - // and if we are storing route info we can unadvertise the routes - if e.ShouldStoreRoutes() { + // Everything left in oldDomains is a domain we're no longer tracking and we + // can unadvertise the routes. + if e.hasStoredRoutes { toRemove := []netip.Prefix{} for _, addrs := range oldDomains { for _, a := range addrs { toRemove = append(toRemove, netip.PrefixFrom(a, a.BitLen())) } } - if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { - e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", xmaps.Keys(oldDomains), toRemove, err) + + if len(toRemove) != 0 { + if ra := e.routeAdvertiser; ra != nil { + e.queue.Add(func() { + if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { + e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", slicesx.MapKeys(oldDomains), toRemove, err) + } + }) + } + e.updatePub.Publish(appctype.RouteUpdate{Unadvertise: toRemove}) } } - e.logf("handling domains: %v and wildcards: %v", xmaps.Keys(e.domains), e.wildcards) + e.logf("handling domains: %v and wildcards: %v", slicesx.MapKeys(e.domains), e.wildcards) } // updateRoutes merges the supplied routes into the currently configured routes. The routes supplied @@ -311,18 +355,12 @@ func (e *AppConnector) updateRoutes(routes []netip.Prefix) { return } - if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { - e.logf("failed to advertise routes: %v: %v", routes, err) - return - } - var toRemove []netip.Prefix - // If we're storing routes and know e.controlRoutes is a good - // representation of what should be in AdvertisedRoutes we can stop - // advertising routes that used to be in e.controlRoutes but are not - // in routes. - if e.ShouldStoreRoutes() { + // If we know e.controlRoutes is a good representation of what should be in + // AdvertisedRoutes we can stop advertising routes that used to be in + // e.controlRoutes but are not in routes. + if e.hasStoredRoutes { toRemove = routesWithout(e.controlRoutes, routes) } @@ -339,14 +377,23 @@ nextRoute: } } - if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { - e.logf("failed to unadvertise routes: %v: %v", toRemove, err) + if e.routeAdvertiser != nil { + e.queue.Add(func() { + if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { + e.logf("failed to advertise routes: %v: %v", routes, err) + } + if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { + e.logf("failed to unadvertise routes: %v: %v", toRemove, err) + } + }) } + e.updatePub.Publish(appctype.RouteUpdate{ + Advertise: routes, + Unadvertise: toRemove, + }) e.controlRoutes = routes - if err := e.storeRoutesLocked(); err != nil { - e.logf("failed to store route info: %v", err) - } + e.storeRoutesLocked() } // Domains returns the currently configured domain list. @@ -354,7 +401,7 @@ func (e *AppConnector) Domains() views.Slice[string] { e.mu.Lock() defer e.mu.Unlock() - return views.SliceOf(xmaps.Keys(e.domains)) + return views.SliceOf(slicesx.MapKeys(e.domains)) } // DomainRoutes returns a map of domains to resolved IP @@ -371,123 +418,6 @@ func (e *AppConnector) DomainRoutes() map[string][]netip.Addr { return drCopy } -// ObserveDNSResponse is a callback invoked by the DNS resolver when a DNS -// response is being returned over the PeerAPI. The response is parsed and -// matched against the configured domains, if matched the routeAdvertiser is -// advised to advertise the discovered route. -func (e *AppConnector) ObserveDNSResponse(res []byte) { - var p dnsmessage.Parser - if _, err := p.Start(res); err != nil { - return - } - if err := p.SkipAllQuestions(); err != nil { - return - } - - // cnameChain tracks a chain of CNAMEs for a given query in order to reverse - // a CNAME chain back to the original query for flattening. The keys are - // CNAME record targets, and the value is the name the record answers, so - // for www.example.com CNAME example.com, the map would contain - // ["example.com"] = "www.example.com". - var cnameChain map[string]string - - // addressRecords is a list of address records found in the response. - var addressRecords map[string][]netip.Addr - - for { - h, err := p.AnswerHeader() - if err == dnsmessage.ErrSectionDone { - break - } - if err != nil { - return - } - - if h.Class != dnsmessage.ClassINET { - if err := p.SkipAnswer(); err != nil { - return - } - continue - } - - switch h.Type { - case dnsmessage.TypeCNAME, dnsmessage.TypeA, dnsmessage.TypeAAAA: - default: - if err := p.SkipAnswer(); err != nil { - return - } - continue - - } - - domain := strings.TrimSuffix(strings.ToLower(h.Name.String()), ".") - if len(domain) == 0 { - continue - } - - if h.Type == dnsmessage.TypeCNAME { - res, err := p.CNAMEResource() - if err != nil { - return - } - cname := strings.TrimSuffix(strings.ToLower(res.CNAME.String()), ".") - if len(cname) == 0 { - continue - } - mak.Set(&cnameChain, cname, domain) - continue - } - - switch h.Type { - case dnsmessage.TypeA: - r, err := p.AResource() - if err != nil { - return - } - addr := netip.AddrFrom4(r.A) - mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) - case dnsmessage.TypeAAAA: - r, err := p.AAAAResource() - if err != nil { - return - } - addr := netip.AddrFrom16(r.AAAA) - mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) - default: - if err := p.SkipAnswer(); err != nil { - return - } - continue - } - } - - e.mu.Lock() - defer e.mu.Unlock() - - for domain, addrs := range addressRecords { - domain, isRouted := e.findRoutedDomainLocked(domain, cnameChain) - - // domain and none of the CNAMEs in the chain are routed - if !isRouted { - continue - } - - // advertise each address we have learned for the routed domain, that - // was not already known. - var toAdvertise []netip.Prefix - for _, addr := range addrs { - if !e.isAddrKnownLocked(domain, addr) { - toAdvertise = append(toAdvertise, netip.PrefixFrom(addr, addr.BitLen())) - } - } - - if len(toAdvertise) > 0 { - e.logf("[v2] observed new routes for %s: %s", domain, toAdvertise) - e.scheduleAdvertisement(domain, toAdvertise...) - } - } -} - // starting from the given domain that resolved to an address, find it, or any // of the domains in the CNAME chain toward resolving it, that are routed // domains, returning the routed domain name and a bool indicating whether a @@ -542,10 +472,13 @@ func (e *AppConnector) isAddrKnownLocked(domain string, addr netip.Addr) bool { // associated with the given domain. func (e *AppConnector) scheduleAdvertisement(domain string, routes ...netip.Prefix) { e.queue.Add(func() { - if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { - e.logf("failed to advertise routes for %s: %v: %v", domain, routes, err) - return + if e.routeAdvertiser != nil { + if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { + e.logf("failed to advertise routes for %s: %v: %v", domain, routes, err) + return + } } + e.updatePub.Publish(appctype.RouteUpdate{Advertise: routes}) e.mu.Lock() defer e.mu.Unlock() @@ -559,9 +492,7 @@ func (e *AppConnector) scheduleAdvertisement(domain string, routes ...netip.Pref e.logf("[v2] advertised route for %v: %v", domain, addr) } } - if err := e.storeRoutesLocked(); err != nil { - e.logf("failed to store route info: %v", err) - } + e.storeRoutesLocked() }) } @@ -579,8 +510,8 @@ func (e *AppConnector) addDomainAddrLocked(domain string, addr netip.Addr) { slices.SortFunc(e.domains[domain], compareAddr) } -func compareAddr(l, r netip.Addr) int { - return l.Compare(r) +func compareAddr(a, b netip.Addr) int { + return a.Compare(b) } // routesWithout returns a without b where a and b diff --git a/appc/appconnector_test.go b/appc/appconnector_test.go index 7dba8cebd..5c362d6fd 100644 --- a/appc/appconnector_test.go +++ b/appc/appconnector_test.go @@ -4,35 +4,40 @@ package appc import ( - "context" + stdcmp "cmp" + "fmt" "net/netip" "reflect" "slices" + "sync/atomic" "testing" "time" - xmaps "golang.org/x/exp/maps" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/net/dns/dnsmessage" "tailscale.com/appc/appctest" "tailscale.com/tstest" + "tailscale.com/types/appctype" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/slicesx" ) -func fakeStoreRoutes(*RouteInfo) error { return nil } - func TestUpdateDomains(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, &appctest.RouteCollector{}, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, &appctest.RouteCollector{}, nil, nil) - } - a.UpdateDomains([]string{"example.com"}) + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + a.UpdateDomains([]string{"example.com"}) a.Wait(ctx) if got, want := a.Domains().AsSlice(), []string{"example.com"}; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) @@ -50,26 +55,32 @@ func TestUpdateDomains(t *testing.T) { // domains are explicitly downcased on set. a.UpdateDomains([]string{"UP.EXAMPLE.COM"}) a.Wait(ctx) - if got, want := xmaps.Keys(a.domains), []string{"up.example.com"}; !slices.Equal(got, want) { + if got, want := slicesx.MapKeys(a.domains), []string{"up.example.com"}; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } } } func TestUpdateRoutes(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + a.updateDomains([]string{"*.example.com"}) // This route should be collapsed into the range - a.ObserveDNSResponse(dnsResponse("a.example.com.", "192.0.2.1")) + if err := a.ObserveDNSResponse(dnsResponse("a.example.com.", "192.0.2.1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if !slices.Equal(rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.2.1/32")}) { @@ -77,11 +88,14 @@ func TestUpdateRoutes(t *testing.T) { } // This route should not be collapsed or removed - a.ObserveDNSResponse(dnsResponse("b.example.com.", "192.0.0.1")) + if err := a.ObserveDNSResponse(dnsResponse("b.example.com.", "192.0.0.1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24"), netip.MustParsePrefix("192.0.0.1/32")} a.updateRoutes(routes) + a.Wait(ctx) slices.SortFunc(rc.Routes(), prefixCompare) rc.SetRoutes(slices.Compact(rc.Routes())) @@ -97,41 +111,76 @@ func TestUpdateRoutes(t *testing.T) { if !slices.EqualFunc(rc.RemovedRoutes(), wantRemoved, prefixEqual) { t.Fatalf("unexpected removed routes: %v", rc.RemovedRoutes()) } + + if err := eventbustest.Expect(w, + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.2.1/32")}), + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.0.1/32")}), + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{ + Advertise: prefixes("192.0.0.1/32", "192.0.2.0/24"), + Unadvertise: prefixes("192.0.2.1/32"), + }), + eventbustest.Type[appctype.RouteInfo](), + ); err != nil { + t.Error(err) + } } } func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + mak.Set(&a.domains, "example.com", []netip.Addr{netip.MustParseAddr("192.0.2.1")}) rc.SetRoutes([]netip.Prefix{netip.MustParsePrefix("192.0.2.1/32")}) routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")} a.updateRoutes(routes) + a.Wait(ctx) if !slices.EqualFunc(routes, rc.Routes(), prefixEqual) { t.Fatalf("got %v, want %v", rc.Routes(), routes) } + + if err := eventbustest.ExpectExactly(w, + eqUpdate(appctype.RouteUpdate{ + Advertise: prefixes("192.0.2.0/24"), + Unadvertise: prefixes("192.0.2.1/32"), + }), + eventbustest.Type[appctype.RouteInfo](), + ); err != nil { + t.Error(err) + } } } func TestDomainRoutes(t *testing.T) { + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) a.updateDomains([]string{"example.com"}) - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) - a.Wait(context.Background()) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } + a.Wait(t.Context()) want := map[string][]netip.Addr{ "example.com": {netip.MustParseAddr("192.0.0.8")}, @@ -140,22 +189,34 @@ func TestDomainRoutes(t *testing.T) { if got := a.DomainRoutes(); !reflect.DeepEqual(got, want) { t.Fatalf("DomainRoutes: got %v, want %v", got, want) } + + if err := eventbustest.ExpectExactly(w, + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.0.8/32")}), + eventbustest.Type[appctype.RouteInfo](), + ); err != nil { + t.Error(err) + } } } func TestObserveDNSResponse(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) // a has no domains configured, so it should not advertise any routes - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } if got, want := rc.Routes(), ([]netip.Prefix)(nil); !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } @@ -163,7 +224,9 @@ func TestObserveDNSResponse(t *testing.T) { wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} a.updateDomains([]string{"example.com"}) - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) @@ -172,7 +235,9 @@ func TestObserveDNSResponse(t *testing.T) { // a CNAME record chain should result in a route being added if the chain // matches a routed domain. a.updateDomains([]string{"www.example.com", "example.com"}) - a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com.")) + if err := a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com.")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.9/32")) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { @@ -181,7 +246,9 @@ func TestObserveDNSResponse(t *testing.T) { // a CNAME record chain should result in a route being added if the chain // even if only found in the middle of the chain - a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org.")) + if err := a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org.")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.10/32")) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { @@ -190,14 +257,18 @@ func TestObserveDNSResponse(t *testing.T) { wantRoutes = append(wantRoutes, netip.MustParsePrefix("2001:db8::1/128")) - a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } // don't re-advertise routes that have already been advertised - a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes) @@ -207,7 +278,9 @@ func TestObserveDNSResponse(t *testing.T) { pfx := netip.MustParsePrefix("192.0.2.0/24") a.updateRoutes([]netip.Prefix{pfx}) wantRoutes = append(wantRoutes, pfx) - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes) @@ -215,22 +288,43 @@ func TestObserveDNSResponse(t *testing.T) { if !slices.Contains(a.domains["example.com"], netip.MustParseAddr("192.0.2.1")) { t.Errorf("missing %v from %v", "192.0.2.1", a.domains["exmaple.com"]) } + + if err := eventbustest.ExpectExactly(w, + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.0.8/32")}), // from initial DNS response, via example.com + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.0.9/32")}), // from CNAME response + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.0.10/32")}), // from CNAME response, mid-chain + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("2001:db8::1/128")}), // v6 DNS response + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.2.0/24")}), // additional prefix + eventbustest.Type[appctype.RouteInfo](), + // N.B. no update for 192.0.2.1 as it is already covered + ); err != nil { + t.Error(err) + } } } func TestWildcardDomains(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) a.updateDomains([]string{"*.example.com"}) - a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if got, want := rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")}; !slices.Equal(got, want) { t.Errorf("routes: got %v; want %v", got, want) @@ -252,6 +346,13 @@ func TestWildcardDomains(t *testing.T) { if len(a.wildcards) != 1 { t.Errorf("expected only one wildcard domain, got %v", a.wildcards) } + + if err := eventbustest.ExpectExactly(w, + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("192.0.0.8/32")}), + eventbustest.Type[appctype.RouteInfo](), + ); err != nil { + t.Error(err) + } } } @@ -367,8 +468,10 @@ func prefixes(in ...string) []netip.Prefix { } func TestUpdateRouteRouteRemoval(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} assertRoutes := func(prefix string, routes, removedRoutes []netip.Prefix) { @@ -380,12 +483,14 @@ func TestUpdateRouteRouteRemoval(t *testing.T) { } } - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + // nothing has yet been advertised assertRoutes("appc init", []netip.Prefix{}, []netip.Prefix{}) @@ -408,12 +513,21 @@ func TestUpdateRouteRouteRemoval(t *testing.T) { wantRemovedRoutes = prefixes("1.2.3.2/32") } assertRoutes("removal", wantRoutes, wantRemovedRoutes) + + if err := eventbustest.Expect(w, + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.1/32", "1.2.3.2/32")}), // no duplicates here + eventbustest.Type[appctype.RouteInfo](), + ); err != nil { + t.Error(err) + } } } func TestUpdateDomainRouteRemoval(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} assertRoutes := func(prefix string, routes, removedRoutes []netip.Prefix) { @@ -425,12 +539,14 @@ func TestUpdateDomainRouteRemoval(t *testing.T) { } } - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + assertRoutes("appc init", []netip.Prefix{}, []netip.Prefix{}) a.UpdateDomainsAndRoutes([]string{"a.example.com", "b.example.com"}, []netip.Prefix{}) @@ -438,10 +554,16 @@ func TestUpdateDomainRouteRemoval(t *testing.T) { // adding domains doesn't immediately cause any routes to be advertised assertRoutes("update domains", []netip.Prefix{}, []netip.Prefix{}) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.1")) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.2")) - a.ObserveDNSResponse(dnsResponse("b.example.com.", "1.2.3.3")) - a.ObserveDNSResponse(dnsResponse("b.example.com.", "1.2.3.4")) + for _, res := range [][]byte{ + dnsResponse("a.example.com.", "1.2.3.1"), + dnsResponse("a.example.com.", "1.2.3.2"), + dnsResponse("b.example.com.", "1.2.3.3"), + dnsResponse("b.example.com.", "1.2.3.4"), + } { + if err := a.ObserveDNSResponse(res); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } + } a.Wait(ctx) // observing dns responses causes routes to be advertised assertRoutes("observed dns", prefixes("1.2.3.1/32", "1.2.3.2/32", "1.2.3.3/32", "1.2.3.4/32"), []netip.Prefix{}) @@ -457,12 +579,30 @@ func TestUpdateDomainRouteRemoval(t *testing.T) { wantRemovedRoutes = prefixes("1.2.3.3/32", "1.2.3.4/32") } assertRoutes("removal", wantRoutes, wantRemovedRoutes) + + wantEvents := []any{ + // Each DNS record observed triggers an update. + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.1/32")}), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.2/32")}), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.3/32")}), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.4/32")}), + } + if shouldStore { + wantEvents = append(wantEvents, eqUpdate(appctype.RouteUpdate{ + Unadvertise: prefixes("1.2.3.3/32", "1.2.3.4/32"), + })) + } + if err := eventbustest.Expect(w, wantEvents...); err != nil { + t.Error(err) + } } } func TestUpdateWildcardRouteRemoval(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) for _, shouldStore := range []bool{false, true} { - ctx := context.Background() + w := eventbustest.NewWatcher(t, bus) rc := &appctest.RouteCollector{} assertRoutes := func(prefix string, routes, removedRoutes []netip.Prefix) { @@ -474,12 +614,14 @@ func TestUpdateWildcardRouteRemoval(t *testing.T) { } } - var a *AppConnector - if shouldStore { - a = NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) - } else { - a = NewAppConnector(t.Logf, rc, nil, nil) - } + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + assertRoutes("appc init", []netip.Prefix{}, []netip.Prefix{}) a.UpdateDomainsAndRoutes([]string{"a.example.com", "*.b.example.com"}, []netip.Prefix{}) @@ -487,10 +629,16 @@ func TestUpdateWildcardRouteRemoval(t *testing.T) { // adding domains doesn't immediately cause any routes to be advertised assertRoutes("update domains", []netip.Prefix{}, []netip.Prefix{}) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.1")) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.2")) - a.ObserveDNSResponse(dnsResponse("1.b.example.com.", "1.2.3.3")) - a.ObserveDNSResponse(dnsResponse("2.b.example.com.", "1.2.3.4")) + for _, res := range [][]byte{ + dnsResponse("a.example.com.", "1.2.3.1"), + dnsResponse("a.example.com.", "1.2.3.2"), + dnsResponse("1.b.example.com.", "1.2.3.3"), + dnsResponse("2.b.example.com.", "1.2.3.4"), + } { + if err := a.ObserveDNSResponse(res); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } + } a.Wait(ctx) // observing dns responses causes routes to be advertised assertRoutes("observed dns", prefixes("1.2.3.1/32", "1.2.3.2/32", "1.2.3.3/32", "1.2.3.4/32"), []netip.Prefix{}) @@ -506,6 +654,22 @@ func TestUpdateWildcardRouteRemoval(t *testing.T) { wantRemovedRoutes = prefixes("1.2.3.3/32", "1.2.3.4/32") } assertRoutes("removal", wantRoutes, wantRemovedRoutes) + + wantEvents := []any{ + // Each DNS record observed triggers an update. + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.1/32")}), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.2/32")}), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.3/32")}), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("1.2.3.4/32")}), + } + if shouldStore { + wantEvents = append(wantEvents, eqUpdate(appctype.RouteUpdate{ + Unadvertise: prefixes("1.2.3.3/32", "1.2.3.4/32"), + })) + } + if err := eventbustest.Expect(w, wantEvents...); err != nil { + t.Error(err) + } } } @@ -602,3 +766,107 @@ func TestMetricBucketsAreSorted(t *testing.T) { t.Errorf("metricStoreRoutesNBuckets must be in order") } } + +// TestUpdateRoutesDeadlock is a regression test for a deadlock in +// LocalBackend<->AppConnector interaction. When using real LocalBackend as the +// routeAdvertiser, calls to Advertise/UnadvertiseRoutes can end up calling +// back into AppConnector via authReconfig. If everything is called +// synchronously, this results in a deadlock on AppConnector.mu. +// +// TODO(creachadair, 2025-09-18): Remove this along with the advertiser +// interface once the LocalBackend is switched to use the event bus and the +// tests have been updated not to need it. +func TestUpdateRoutesDeadlock(t *testing.T) { + ctx := t.Context() + bus := eventbustest.NewBus(t) + w := eventbustest.NewWatcher(t, bus) + rc := &appctest.RouteCollector{} + a := NewAppConnector(Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: true, + }) + t.Cleanup(a.Close) + + advertiseCalled := new(atomic.Bool) + unadvertiseCalled := new(atomic.Bool) + rc.AdvertiseCallback = func() { + // Call something that requires a.mu to be held. + a.DomainRoutes() + advertiseCalled.Store(true) + } + rc.UnadvertiseCallback = func() { + // Call something that requires a.mu to be held. + a.DomainRoutes() + unadvertiseCalled.Store(true) + } + + a.updateDomains([]string{"example.com"}) + a.Wait(ctx) + + // Trigger rc.AdveriseRoute. + a.updateRoutes( + []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + netip.MustParsePrefix("127.0.0.2/32"), + }, + ) + a.Wait(ctx) + // Trigger rc.UnadveriseRoute. + a.updateRoutes( + []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + }, + ) + a.Wait(ctx) + + if !advertiseCalled.Load() { + t.Error("AdvertiseRoute was not called") + } + if !unadvertiseCalled.Load() { + t.Error("UnadvertiseRoute was not called") + } + + if want := []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")}; !slices.Equal(slices.Compact(rc.Routes()), want) { + t.Fatalf("got %v, want %v", rc.Routes(), want) + } + + if err := eventbustest.ExpectExactly(w, + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("127.0.0.1/32", "127.0.0.2/32")}), + eventbustest.Type[appctype.RouteInfo](), + eqUpdate(appctype.RouteUpdate{Advertise: prefixes("127.0.0.1/32"), Unadvertise: prefixes("127.0.0.2/32")}), + eventbustest.Type[appctype.RouteInfo](), + ); err != nil { + t.Error(err) + } +} + +type textUpdate struct { + Advertise []string + Unadvertise []string +} + +func routeUpdateToText(u appctype.RouteUpdate) textUpdate { + var out textUpdate + for _, p := range u.Advertise { + out.Advertise = append(out.Advertise, p.String()) + } + for _, p := range u.Unadvertise { + out.Unadvertise = append(out.Unadvertise, p.String()) + } + return out +} + +// eqUpdate generates an eventbus test filter that matches a appctype.RouteUpdate +// message equal to want, or reports an error giving a human-readable diff. +func eqUpdate(want appctype.RouteUpdate) func(appctype.RouteUpdate) error { + return func(got appctype.RouteUpdate) error { + if diff := cmp.Diff(routeUpdateToText(got), routeUpdateToText(want), + cmpopts.SortSlices(stdcmp.Less[string]), + ); diff != "" { + return fmt.Errorf("wrong update (-got, +want):\n%s", diff) + } + return nil + } +} diff --git a/appc/appctest/appctest.go b/appc/appctest/appctest.go index aa77bc3b4..9726a2b97 100644 --- a/appc/appctest/appctest.go +++ b/appc/appctest/appctest.go @@ -11,12 +11,22 @@ import ( // RouteCollector is a test helper that collects the list of routes advertised type RouteCollector struct { + // AdvertiseCallback (optional) is called synchronously from + // AdvertiseRoute. + AdvertiseCallback func() + // UnadvertiseCallback (optional) is called synchronously from + // UnadvertiseRoute. + UnadvertiseCallback func() + routes []netip.Prefix removedRoutes []netip.Prefix } func (rc *RouteCollector) AdvertiseRoute(pfx ...netip.Prefix) error { rc.routes = append(rc.routes, pfx...) + if rc.AdvertiseCallback != nil { + rc.AdvertiseCallback() + } return nil } @@ -30,6 +40,9 @@ func (rc *RouteCollector) UnadvertiseRoute(toRemove ...netip.Prefix) error { rc.removedRoutes = append(rc.removedRoutes, r) } } + if rc.UnadvertiseCallback != nil { + rc.UnadvertiseCallback() + } return nil } diff --git a/appc/ippool.go b/appc/ippool.go new file mode 100644 index 000000000..a2e86a7c2 --- /dev/null +++ b/appc/ippool.go @@ -0,0 +1,61 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appc + +import ( + "errors" + "net/netip" + + "go4.org/netipx" +) + +// errPoolExhausted is returned when there are no more addresses to iterate over. +var errPoolExhausted = errors.New("ip pool exhausted") + +// ippool allows for iteration over all the addresses within a netipx.IPSet. +// netipx.IPSet has a Ranges call that returns the "minimum and sorted set of IP ranges that covers [the set]". +// netipx.IPRange is "an inclusive range of IP addresses from the same address family.". So we can iterate over +// all the addresses in the set by keeping a track of the last address we returned, calling Next on the last address +// to get the new one, and if we run off the edge of the current range, starting on the next one. +type ippool struct { + // ranges defines the addresses in the pool + ranges []netipx.IPRange + // last is internal tracking of which the last address provided was. + last netip.Addr + // rangeIdx is internal tracking of which netipx.IPRange from the IPSet we are currently on. + rangeIdx int +} + +func newIPPool(ipset *netipx.IPSet) *ippool { + if ipset == nil { + return &ippool{} + } + return &ippool{ranges: ipset.Ranges()} +} + +// next returns the next address from the set, or errPoolExhausted if we have +// iterated over the whole set. +func (ipp *ippool) next() (netip.Addr, error) { + if ipp.rangeIdx >= len(ipp.ranges) { + // ipset is empty or we have iterated off the end + return netip.Addr{}, errPoolExhausted + } + if !ipp.last.IsValid() { + // not initialized yet + ipp.last = ipp.ranges[0].From() + return ipp.last, nil + } + currRange := ipp.ranges[ipp.rangeIdx] + if ipp.last == currRange.To() { + // then we need to move to the next range + ipp.rangeIdx++ + if ipp.rangeIdx >= len(ipp.ranges) { + return netip.Addr{}, errPoolExhausted + } + ipp.last = ipp.ranges[ipp.rangeIdx].From() + return ipp.last, nil + } + ipp.last = ipp.last.Next() + return ipp.last, nil +} diff --git a/appc/ippool_test.go b/appc/ippool_test.go new file mode 100644 index 000000000..64b76738f --- /dev/null +++ b/appc/ippool_test.go @@ -0,0 +1,60 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package appc + +import ( + "errors" + "net/netip" + "testing" + + "go4.org/netipx" + "tailscale.com/util/must" +) + +func TestNext(t *testing.T) { + a := ippool{} + _, err := a.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + + var isb netipx.IPSetBuilder + ipset := must.Get(isb.IPSet()) + b := newIPPool(ipset) + _, err = b.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.0.2"))) + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("200.0.0.0"), netip.MustParseAddr("200.0.0.0"))) + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("201.0.0.0"), netip.MustParseAddr("201.0.0.1"))) + ipset = must.Get(isb.IPSet()) + c := newIPPool(ipset) + expected := []string{ + "192.168.0.0", + "192.168.0.1", + "192.168.0.2", + "200.0.0.0", + "201.0.0.0", + "201.0.0.1", + } + for i, want := range expected { + addr, err := c.next() + if err != nil { + t.Fatal(err) + } + if addr != netip.MustParseAddr(want) { + t.Fatalf("next call %d want: %s, got: %v", i, want, addr) + } + } + _, err = c.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + _, err = c.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } +} diff --git a/appc/observe.go b/appc/observe.go new file mode 100644 index 000000000..06dc04f9d --- /dev/null +++ b/appc/observe.go @@ -0,0 +1,132 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_appconnectors + +package appc + +import ( + "net/netip" + "strings" + + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/util/mak" +) + +// ObserveDNSResponse is a callback invoked by the DNS resolver when a DNS +// response is being returned over the PeerAPI. The response is parsed and +// matched against the configured domains, if matched the routeAdvertiser is +// advised to advertise the discovered route. +func (e *AppConnector) ObserveDNSResponse(res []byte) error { + var p dnsmessage.Parser + if _, err := p.Start(res); err != nil { + return err + } + if err := p.SkipAllQuestions(); err != nil { + return err + } + + // cnameChain tracks a chain of CNAMEs for a given query in order to reverse + // a CNAME chain back to the original query for flattening. The keys are + // CNAME record targets, and the value is the name the record answers, so + // for www.example.com CNAME example.com, the map would contain + // ["example.com"] = "www.example.com". + var cnameChain map[string]string + + // addressRecords is a list of address records found in the response. + var addressRecords map[string][]netip.Addr + + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return err + } + + if h.Class != dnsmessage.ClassINET { + if err := p.SkipAnswer(); err != nil { + return err + } + continue + } + + switch h.Type { + case dnsmessage.TypeCNAME, dnsmessage.TypeA, dnsmessage.TypeAAAA: + default: + if err := p.SkipAnswer(); err != nil { + return err + } + continue + + } + + domain := strings.TrimSuffix(strings.ToLower(h.Name.String()), ".") + if len(domain) == 0 { + continue + } + + if h.Type == dnsmessage.TypeCNAME { + res, err := p.CNAMEResource() + if err != nil { + return err + } + cname := strings.TrimSuffix(strings.ToLower(res.CNAME.String()), ".") + if len(cname) == 0 { + continue + } + mak.Set(&cnameChain, cname, domain) + continue + } + + switch h.Type { + case dnsmessage.TypeA: + r, err := p.AResource() + if err != nil { + return err + } + addr := netip.AddrFrom4(r.A) + mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) + case dnsmessage.TypeAAAA: + r, err := p.AAAAResource() + if err != nil { + return err + } + addr := netip.AddrFrom16(r.AAAA) + mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) + default: + if err := p.SkipAnswer(); err != nil { + return err + } + continue + } + } + + e.mu.Lock() + defer e.mu.Unlock() + + for domain, addrs := range addressRecords { + domain, isRouted := e.findRoutedDomainLocked(domain, cnameChain) + + // domain and none of the CNAMEs in the chain are routed + if !isRouted { + continue + } + + // advertise each address we have learned for the routed domain, that + // was not already known. + var toAdvertise []netip.Prefix + for _, addr := range addrs { + if !e.isAddrKnownLocked(domain, addr) { + toAdvertise = append(toAdvertise, netip.PrefixFrom(addr, addr.BitLen())) + } + } + + if len(toAdvertise) > 0 { + e.logf("[v2] observed new routes for %s: %s", domain, toAdvertise) + e.scheduleAdvertisement(domain, toAdvertise...) + } + } + return nil +} diff --git a/appc/observe_disabled.go b/appc/observe_disabled.go new file mode 100644 index 000000000..45aa285ea --- /dev/null +++ b/appc/observe_disabled.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_appconnectors + +package appc + +func (e *AppConnector) ObserveDNSResponse(res []byte) error { return nil } diff --git a/assert_ts_toolchain_match.go b/assert_ts_toolchain_match.go new file mode 100644 index 000000000..40b24b334 --- /dev/null +++ b/assert_ts_toolchain_match.go @@ -0,0 +1,27 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go + +package tailscaleroot + +import ( + "fmt" + "os" + "strings" +) + +func init() { + tsRev, ok := tailscaleToolchainRev() + if !ok { + panic("binary built with tailscale_go build tag but failed to read build info or find tailscale.toolchain.rev in build info") + } + want := strings.TrimSpace(GoToolchainRev) + if tsRev != want { + if os.Getenv("TS_PERMIT_TOOLCHAIN_MISMATCH") == "1" { + fmt.Fprintf(os.Stderr, "tailscale.toolchain.rev = %q, want %q; but ignoring due to TS_PERMIT_TOOLCHAIN_MISMATCH=1\n", tsRev, want) + return + } + panic(fmt.Sprintf("binary built with tailscale_go build tag but Go toolchain %q doesn't match github.com/tailscale/tailscale expected value %q; override this failure with TS_PERMIT_TOOLCHAIN_MISMATCH=1", tsRev, want)) + } +} diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go index 5c18e85a8..9cae9bb75 100644 --- a/atomicfile/atomicfile.go +++ b/atomicfile/atomicfile.go @@ -15,8 +15,9 @@ import ( ) // WriteFile writes data to filename+some suffix, then renames it into filename. -// The perm argument is ignored on Windows. If the target filename already -// exists but is not a regular file, WriteFile returns an error. +// The perm argument is ignored on Windows, but if the target filename already +// exists then the target file's attributes and ACLs are preserved. If the target +// filename already exists but is not a regular file, WriteFile returns an error. func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { fi, err := os.Stat(filename) if err == nil && !fi.Mode().IsRegular() { @@ -47,5 +48,9 @@ func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { if err := f.Close(); err != nil { return err } - return os.Rename(tmpName, filename) + return Rename(tmpName, filename) } + +// Rename srcFile to dstFile, similar to [os.Rename] but preserving file +// attributes and ACLs on Windows. +func Rename(srcFile, dstFile string) error { return rename(srcFile, dstFile) } diff --git a/atomicfile/atomicfile_notwindows.go b/atomicfile/atomicfile_notwindows.go new file mode 100644 index 000000000..1ce2bb8ac --- /dev/null +++ b/atomicfile/atomicfile_notwindows.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package atomicfile + +import ( + "os" +) + +func rename(srcFile, destFile string) error { + return os.Rename(srcFile, destFile) +} diff --git a/atomicfile/atomicfile_test.go b/atomicfile/atomicfile_test.go index 78c93e664..a081c9040 100644 --- a/atomicfile/atomicfile_test.go +++ b/atomicfile/atomicfile_test.go @@ -31,11 +31,11 @@ func TestDoesNotOverwriteIrregularFiles(t *testing.T) { // The least troublesome thing to make that is not a file is a unix socket. // Making a null device sadly requires root. - l, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) + ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"}) if err != nil { t.Fatal(err) } - defer l.Close() + defer ln.Close() err = WriteFile(path, []byte("hello"), 0644) if err == nil { diff --git a/atomicfile/atomicfile_windows.go b/atomicfile/atomicfile_windows.go new file mode 100644 index 000000000..c67762df2 --- /dev/null +++ b/atomicfile/atomicfile_windows.go @@ -0,0 +1,33 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package atomicfile + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func rename(srcFile, destFile string) error { + // Use replaceFile when possible to preserve the original file's attributes and ACLs. + if err := replaceFile(destFile, srcFile); err == nil || err != windows.ERROR_FILE_NOT_FOUND { + return err + } + // destFile doesn't exist. Just do a normal rename. + return os.Rename(srcFile, destFile) +} + +func replaceFile(destFile, srcFile string) error { + destFile16, err := windows.UTF16PtrFromString(destFile) + if err != nil { + return err + } + + srcFile16, err := windows.UTF16PtrFromString(srcFile) + if err != nil { + return err + } + + return replaceFileW(destFile16, srcFile16, nil, 0, nil, nil) +} diff --git a/atomicfile/atomicfile_windows_test.go b/atomicfile/atomicfile_windows_test.go new file mode 100644 index 000000000..4dec1493e --- /dev/null +++ b/atomicfile/atomicfile_windows_test.go @@ -0,0 +1,146 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package atomicfile + +import ( + "os" + "testing" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _SECURITY_RESOURCE_MANAGER_AUTHORITY = windows.SidIdentifierAuthority{[6]byte{0, 0, 0, 0, 0, 9}} + +// makeRandomSID generates a SID derived from a v4 GUID. +// This is basically the same algorithm used by browser sandboxes for generating +// random SIDs. +func makeRandomSID() (*windows.SID, error) { + guid, err := windows.GenerateGUID() + if err != nil { + return nil, err + } + + rids := *((*[4]uint32)(unsafe.Pointer(&guid))) + + var pSID *windows.SID + if err := windows.AllocateAndInitializeSid(&_SECURITY_RESOURCE_MANAGER_AUTHORITY, 4, rids[0], rids[1], rids[2], rids[3], 0, 0, 0, 0, &pSID); err != nil { + return nil, err + } + defer windows.FreeSid(pSID) + + // Make a copy that lives on the Go heap + return pSID.Copy() +} + +func getExistingFileSD(name string) (*windows.SECURITY_DESCRIPTOR, error) { + const infoFlags = windows.DACL_SECURITY_INFORMATION + return windows.GetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, infoFlags) +} + +func getExistingFileDACL(name string) (*windows.ACL, error) { + sd, err := getExistingFileSD(name) + if err != nil { + return nil, err + } + + dacl, _, err := sd.DACL() + return dacl, err +} + +func addDenyACEForRandomSID(dacl *windows.ACL) (*windows.ACL, error) { + randomSID, err := makeRandomSID() + if err != nil { + return nil, err + } + + randomSIDTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_UNKNOWN, + windows.TrusteeValueFromSID(randomSID)} + + entries := []windows.EXPLICIT_ACCESS{ + { + windows.GENERIC_ALL, + windows.DENY_ACCESS, + windows.NO_INHERITANCE, + randomSIDTrustee, + }, + } + + return windows.ACLFromEntries(entries, dacl) +} + +func setExistingFileDACL(name string, dacl *windows.ACL) error { + return windows.SetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, + windows.DACL_SECURITY_INFORMATION, nil, nil, dacl, nil) +} + +// makeOrigFileWithCustomDACL creates a new, temporary file with a custom +// DACL that we can check for later. It returns the name of the temporary +// file and the security descriptor for the file in SDDL format. +func makeOrigFileWithCustomDACL() (name, sddl string, err error) { + f, err := os.CreateTemp("", "foo*.tmp") + if err != nil { + return "", "", err + } + name = f.Name() + if err := f.Close(); err != nil { + return "", "", err + } + f = nil + defer func() { + if err != nil { + os.Remove(name) + } + }() + + dacl, err := getExistingFileDACL(name) + if err != nil { + return "", "", err + } + + // Add a harmless, deny-only ACE for a random SID that isn't used for anything + // (but that we can check for later). + dacl, err = addDenyACEForRandomSID(dacl) + if err != nil { + return "", "", err + } + + if err := setExistingFileDACL(name, dacl); err != nil { + return "", "", err + } + + sd, err := getExistingFileSD(name) + if err != nil { + return "", "", err + } + + return name, sd.String(), nil +} + +func TestPreserveSecurityInfo(t *testing.T) { + // Make a test file with a custom ACL. + origFileName, want, err := makeOrigFileWithCustomDACL() + if err != nil { + t.Fatalf("makeOrigFileWithCustomDACL returned %v", err) + } + t.Cleanup(func() { + os.Remove(origFileName) + }) + + if err := WriteFile(origFileName, []byte{}, 0); err != nil { + t.Fatalf("WriteFile returned %v", err) + } + + // We expect origFileName's security descriptor to be unchanged despite + // the WriteFile call. + sd, err := getExistingFileSD(origFileName) + if err != nil { + t.Fatalf("getExistingFileSD(%q) returned %v", origFileName, err) + } + + if got := sd.String(); got != want { + t.Errorf("security descriptor comparison failed: got %q, want %q", got, want) + } +} diff --git a/atomicfile/mksyscall.go b/atomicfile/mksyscall.go new file mode 100644 index 000000000..d8951a77c --- /dev/null +++ b/atomicfile/mksyscall.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package atomicfile + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go + +//sys replaceFileW(replaced *uint16, replacement *uint16, backup *uint16, flags uint32, exclude unsafe.Pointer, reserved unsafe.Pointer) (err error) [int32(failretval)==0] = kernel32.ReplaceFileW diff --git a/atomicfile/zsyscall_windows.go b/atomicfile/zsyscall_windows.go new file mode 100644 index 000000000..bd1bf8113 --- /dev/null +++ b/atomicfile/zsyscall_windows.go @@ -0,0 +1,52 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package atomicfile + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procReplaceFileW = modkernel32.NewProc("ReplaceFileW") +) + +func replaceFileW(replaced *uint16, replacement *uint16, backup *uint16, flags uint32, exclude unsafe.Pointer, reserved unsafe.Pointer) (err error) { + r1, _, e1 := syscall.SyscallN(procReplaceFileW.Addr(), uintptr(unsafe.Pointer(replaced)), uintptr(unsafe.Pointer(replacement)), uintptr(unsafe.Pointer(backup)), uintptr(flags), uintptr(exclude), uintptr(reserved)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} diff --git a/build_dist.sh b/build_dist.sh index 66afa8f74..c05644711 100755 --- a/build_dist.sh +++ b/build_dist.sh @@ -18,7 +18,7 @@ fi eval `CGO_ENABLED=0 GOOS=$($go env GOHOSTOS) GOARCH=$($go env GOHOSTARCH) $go run ./cmd/mkversion` -if [ "$1" = "shellvars" ]; then +if [ "$#" -ge 1 ] && [ "$1" = "shellvars" ]; then cat <//tailscale TAGS=v0.0.1 make publishdevimage set -eu @@ -16,13 +26,22 @@ eval "$(./build_dist.sh shellvars)" DEFAULT_TARGET="client" DEFAULT_TAGS="v${VERSION_SHORT},v${VERSION_MINOR}" -DEFAULT_BASE="tailscale/alpine-base:3.18" +DEFAULT_BASE="tailscale/alpine-base:3.22" +# Set a few pre-defined OCI annotations. The source annotation is used by tools such as Renovate that scan the linked +# Github repo to find release notes for any new image tags. Note that for official Tailscale images the default +# annotations defined here will be overriden by release scripts that call this script. +# https://github.com/opencontainers/image-spec/blob/main/annotations.md#pre-defined-annotation-keys +DEFAULT_ANNOTATIONS="org.opencontainers.image.source=https://github.com/tailscale/tailscale/blob/main/build_docker.sh,org.opencontainers.image.vendor=Tailscale" PUSH="${PUSH:-false}" TARGET="${TARGET:-${DEFAULT_TARGET}}" TAGS="${TAGS:-${DEFAULT_TAGS}}" BASE="${BASE:-${DEFAULT_BASE}}" PLATFORM="${PLATFORM:-}" # default to all platforms +FILES="${FILES:-}" # default to no extra files +# OCI annotations that will be added to the image. +# https://github.com/opencontainers/image-spec/blob/main/annotations.md +ANNOTATIONS="${ANNOTATIONS:-${DEFAULT_ANNOTATIONS}}" case "$TARGET" in client) @@ -43,9 +62,11 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ + --files="${FILES}" \ /usr/local/bin/containerboot ;; - operator) + k8s-operator) DEFAULT_REPOS="tailscale/k8s-operator" REPOS="${REPOS:-${DEFAULT_REPOS}}" go run github.com/tailscale/mkctr \ @@ -56,9 +77,12 @@ case "$TARGET" in -X tailscale.com/version.gitCommitStamp=${VERSION_GIT_HASH}" \ --base="${BASE}" \ --tags="${TAGS}" \ + --gotags="ts_kube,ts_package_container" \ --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ + --files="${FILES}" \ /usr/local/bin/operator ;; k8s-nameserver) @@ -72,11 +96,52 @@ case "$TARGET" in -X tailscale.com/version.gitCommitStamp=${VERSION_GIT_HASH}" \ --base="${BASE}" \ --tags="${TAGS}" \ + --gotags="ts_kube,ts_package_container" \ --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ + --files="${FILES}" \ /usr/local/bin/k8s-nameserver ;; + tsidp) + DEFAULT_REPOS="tailscale/tsidp" + REPOS="${REPOS:-${DEFAULT_REPOS}}" + go run github.com/tailscale/mkctr \ + --gopaths="tailscale.com/cmd/tsidp:/usr/local/bin/tsidp" \ + --ldflags=" \ + -X tailscale.com/version.longStamp=${VERSION_LONG} \ + -X tailscale.com/version.shortStamp=${VERSION_SHORT} \ + -X tailscale.com/version.gitCommitStamp=${VERSION_GIT_HASH}" \ + --base="${BASE}" \ + --tags="${TAGS}" \ + --gotags="ts_package_container" \ + --repos="${REPOS}" \ + --push="${PUSH}" \ + --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ + --files="${FILES}" \ + /usr/local/bin/tsidp + ;; + k8s-proxy) + DEFAULT_REPOS="tailscale/k8s-proxy" + REPOS="${REPOS:-${DEFAULT_REPOS}}" + go run github.com/tailscale/mkctr \ + --gopaths="tailscale.com/cmd/k8s-proxy:/usr/local/bin/k8s-proxy" \ + --ldflags=" \ + -X tailscale.com/version.longStamp=${VERSION_LONG} \ + -X tailscale.com/version.shortStamp=${VERSION_SHORT} \ + -X tailscale.com/version.gitCommitStamp=${VERSION_GIT_HASH}" \ + --base="${BASE}" \ + --tags="${TAGS}" \ + --gotags="ts_kube,ts_package_container" \ + --repos="${REPOS}" \ + --push="${PUSH}" \ + --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ + --files="${FILES}" \ + /usr/local/bin/k8s-proxy + ;; *) echo "unknown target: $TARGET" exit 1 diff --git a/chirp/chirp_test.go b/chirp/chirp_test.go index 2549c163f..c545c277d 100644 --- a/chirp/chirp_test.go +++ b/chirp/chirp_test.go @@ -1,5 +1,6 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause + package chirp import ( @@ -23,7 +24,7 @@ type fakeBIRD struct { func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) + ln, err := net.Listen("unix", sock) if err != nil { t.Fatal(err) } @@ -32,7 +33,7 @@ func newFakeBIRD(t *testing.T, protocols ...string) *fakeBIRD { pe[p] = false } return &fakeBIRD{ - Listener: l, + Listener: ln, protocolsEnabled: pe, sock: sock, } @@ -122,12 +123,12 @@ type hangingListener struct { func newHangingListener(t *testing.T) *hangingListener { sock := filepath.Join(t.TempDir(), "sock") - l, err := net.Listen("unix", sock) + ln, err := net.Listen("unix", sock) if err != nil { t.Fatal(err) } return &hangingListener{ - Listener: l, + Listener: ln, t: t, done: make(chan struct{}), sock: sock, diff --git a/client/local/cert.go b/client/local/cert.go new file mode 100644 index 000000000..bfaac7303 --- /dev/null +++ b/client/local/cert.go @@ -0,0 +1,151 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !ts_omit_acme + +package local + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net/url" + "strings" + "time" + + "go4.org/mem" +) + +// SetDNS adds a DNS TXT record for the given domain name, containing +// the provided TXT value. The intended use case is answering +// LetsEncrypt/ACME dns-01 challenges. +// +// The control plane will only permit SetDNS requests with very +// specific names and values. The name should be +// "_acme-challenge." + your node's MagicDNS name. It's expected that +// clients cache the certs from LetsEncrypt (or whichever CA is +// providing them) and only request new ones as needed; the control plane +// rate limits SetDNS requests. +// +// This is a low-level interface; it's expected that most Tailscale +// users use a higher level interface to getting/using TLS +// certificates. +func (lc *Client) SetDNS(ctx context.Context, name, value string) error { + v := url.Values{} + v.Set("name", name) + v.Set("value", value) + _, err := lc.send(ctx, "POST", "/localapi/v0/set-dns?"+v.Encode(), 200, nil) + return err +} + +// CertPair returns a cert and private key for the provided DNS domain. +// +// It returns a cached certificate from disk if it's still valid. +// +// Deprecated: use [Client.CertPair]. +func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return defaultClient.CertPair(ctx, domain) +} + +// CertPair returns a cert and private key for the provided DNS domain. +// +// It returns a cached certificate from disk if it's still valid. +// +// API maturity: this is considered a stable API. +func (lc *Client) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return lc.CertPairWithValidity(ctx, domain, 0) +} + +// CertPairWithValidity returns a cert and private key for the provided DNS +// domain. +// +// It returns a cached certificate from disk if it's still valid. +// When minValidity is non-zero, the returned certificate will be valid for at +// least the given duration, if permitted by the CA. If the certificate is +// valid, but for less than minValidity, it will be synchronously renewed. +// +// API maturity: this is considered a stable API. +func (lc *Client) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) { + res, err := lc.send(ctx, "GET", fmt.Sprintf("/localapi/v0/cert/%s?type=pair&min_validity=%s", domain, minValidity), 200, nil) + if err != nil { + return nil, nil, err + } + // with ?type=pair, the response PEM is first the one private + // key PEM block, then the cert PEM blocks. + i := mem.Index(mem.B(res), mem.S("--\n--")) + if i == -1 { + return nil, nil, fmt.Errorf("unexpected output: no delimiter") + } + i += len("--\n") + keyPEM, certPEM = res[:i], res[i:] + if mem.Contains(mem.B(certPEM), mem.S(" PRIVATE KEY-----")) { + return nil, nil, fmt.Errorf("unexpected output: key in cert") + } + return certPEM, keyPEM, nil +} + +// GetCertificate fetches a TLS certificate for the TLS ClientHello in hi. +// +// It returns a cached certificate from disk if it's still valid. +// +// It's the right signature to use as the value of +// [tls.Config.GetCertificate]. +// +// Deprecated: use [Client.GetCertificate]. +func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return defaultClient.GetCertificate(hi) +} + +// GetCertificate fetches a TLS certificate for the TLS ClientHello in hi. +// +// It returns a cached certificate from disk if it's still valid. +// +// It's the right signature to use as the value of +// [tls.Config.GetCertificate]. +// +// API maturity: this is considered a stable API. +func (lc *Client) GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + if hi == nil || hi.ServerName == "" { + return nil, errors.New("no SNI ServerName") + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + name := hi.ServerName + if !strings.Contains(name, ".") { + if v, ok := lc.ExpandSNIName(ctx, name); ok { + name = v + } + } + certPEM, keyPEM, err := lc.CertPair(ctx, name) + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, err + } + return &cert, nil +} + +// ExpandSNIName expands bare label name into the most likely actual TLS cert name. +// +// Deprecated: use [Client.ExpandSNIName]. +func ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { + return defaultClient.ExpandSNIName(ctx, name) +} + +// ExpandSNIName expands bare label name into the most likely actual TLS cert name. +func (lc *Client) ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { + st, err := lc.StatusWithoutPeers(ctx) + if err != nil { + return "", false + } + for _, d := range st.CertDomains { + if len(d) > len(name)+1 && strings.HasPrefix(d, name) && d[len(name)] == '.' { + return d, true + } + } + return "", false +} diff --git a/client/local/debugportmapper.go b/client/local/debugportmapper.go new file mode 100644 index 000000000..04ed1c109 --- /dev/null +++ b/client/local/debugportmapper.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_debugportmapper + +package local + +import ( + "cmp" + "context" + "fmt" + "io" + "net/http" + "net/netip" + "net/url" + "strconv" + "time" + + "tailscale.com/client/tailscale/apitype" +) + +// DebugPortmapOpts contains options for the [Client.DebugPortmap] command. +type DebugPortmapOpts struct { + // Duration is how long the mapping should be created for. It defaults + // to 5 seconds if not set. + Duration time.Duration + + // Type is the kind of portmap to debug. The empty string instructs the + // portmap client to perform all known types. Other valid options are + // "pmp", "pcp", and "upnp". + Type string + + // GatewayAddr specifies the gateway address used during portmapping. + // If set, SelfAddr must also be set. If unset, it will be + // autodetected. + GatewayAddr netip.Addr + + // SelfAddr specifies the gateway address used during portmapping. If + // set, GatewayAddr must also be set. If unset, it will be + // autodetected. + SelfAddr netip.Addr + + // LogHTTP instructs the debug-portmap endpoint to print all HTTP + // requests and responses made to the logs. + LogHTTP bool +} + +// DebugPortmap invokes the debug-portmap endpoint, and returns an +// io.ReadCloser that can be used to read the logs that are printed during this +// process. +// +// opts can be nil; if so, default values will be used. +func (lc *Client) DebugPortmap(ctx context.Context, opts *DebugPortmapOpts) (io.ReadCloser, error) { + vals := make(url.Values) + if opts == nil { + opts = &DebugPortmapOpts{} + } + + vals.Set("duration", cmp.Or(opts.Duration, 5*time.Second).String()) + vals.Set("type", opts.Type) + vals.Set("log_http", strconv.FormatBool(opts.LogHTTP)) + + if opts.GatewayAddr.IsValid() != opts.SelfAddr.IsValid() { + return nil, fmt.Errorf("both GatewayAddr and SelfAddr must be provided if one is") + } else if opts.GatewayAddr.IsValid() { + vals.Set("gateway_and_self", fmt.Sprintf("%s/%s", opts.GatewayAddr, opts.SelfAddr)) + } + + req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/debug-portmap?"+vals.Encode(), nil) + if err != nil { + return nil, err + } + res, err := lc.doLocalRequestNiceError(req) + if err != nil { + return nil, err + } + if res.StatusCode != 200 { + body, _ := io.ReadAll(res.Body) + res.Body.Close() + return nil, fmt.Errorf("HTTP %s: %s", res.Status, body) + } + + return res.Body, nil +} diff --git a/client/tailscale/localclient.go b/client/local/local.go similarity index 61% rename from client/tailscale/localclient.go rename to client/local/local.go index df51dc1ca..72ddbb55f 100644 --- a/client/tailscale/localclient.go +++ b/client/local/local.go @@ -1,19 +1,20 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build go1.19 - -package tailscale +// Package local contains a Go client for the Tailscale LocalAPI. +package local import ( + "bufio" "bytes" "cmp" "context" - "crypto/tls" + "encoding/base64" "encoding/json" "errors" "fmt" "io" + "iter" "net" "net/http" "net/http/httptrace" @@ -26,27 +27,30 @@ import ( "sync" "time" - "go4.org/mem" "tailscale.com/client/tailscale/apitype" "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/netutil" + "tailscale.com/net/udprelay/status" "tailscale.com/paths" "tailscale.com/safesocket" + "tailscale.com/syncs" "tailscale.com/tailcfg" - "tailscale.com/tka" + "tailscale.com/types/appctype" "tailscale.com/types/dnstype" "tailscale.com/types/key" - "tailscale.com/types/tkatype" + "tailscale.com/util/eventbus" ) -// defaultLocalClient is the default LocalClient when using the legacy +// defaultClient is the default Client when using the legacy // package-level functions. -var defaultLocalClient LocalClient +var defaultClient Client -// LocalClient is a client to Tailscale's "LocalAPI", communicating with the +// Client is a client to Tailscale's "LocalAPI", communicating with the // Tailscale daemon on the local machine. Its API is not necessarily stable and // subject to changes between releases. Some API calls have stricter // compatibility guarantees, once they've been widely adopted. See method docs @@ -56,11 +60,17 @@ var defaultLocalClient LocalClient // // Any exported fields should be set before using methods on the type // and not changed thereafter. -type LocalClient struct { +type Client struct { // Dial optionally specifies an alternate func that connects to the local // machine's tailscaled or equivalent. If nil, a default is used. Dial func(ctx context.Context, network, addr string) (net.Conn, error) + // Transport optionally specifies an alternate [http.RoundTripper] + // used to execute HTTP requests. If nil, a default [http.Transport] is used, + // potentially with custom dialing logic from [Dial]. + // It is primarily used for testing. + Transport http.RoundTripper + // Socket specifies an alternate path to the local Tailscale socket. // If empty, a platform-specific default is used. Socket string @@ -84,21 +94,21 @@ type LocalClient struct { tsClientOnce sync.Once } -func (lc *LocalClient) socket() string { +func (lc *Client) socket() string { if lc.Socket != "" { return lc.Socket } return paths.DefaultTailscaledSocket() } -func (lc *LocalClient) dialer() func(ctx context.Context, network, addr string) (net.Conn, error) { +func (lc *Client) dialer() func(ctx context.Context, network, addr string) (net.Conn, error) { if lc.Dial != nil { return lc.Dial } return lc.defaultDialer } -func (lc *LocalClient) defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) { +func (lc *Client) defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) { if addr != "local-tailscaled.sock:80" { return nil, fmt.Errorf("unexpected URL address %q", addr) } @@ -124,13 +134,13 @@ func (lc *LocalClient) defaultDialer(ctx context.Context, network, addr string) // authenticating to the local Tailscale daemon vary by platform. // // DoLocalRequest may mutate the request to add Authorization headers. -func (lc *LocalClient) DoLocalRequest(req *http.Request) (*http.Response, error) { +func (lc *Client) DoLocalRequest(req *http.Request) (*http.Response, error) { req.Header.Set("Tailscale-Cap", strconv.Itoa(int(tailcfg.CurrentCapabilityVersion))) lc.tsClientOnce.Do(func() { lc.tsClient = &http.Client{ - Transport: &http.Transport{ - DialContext: lc.dialer(), - }, + Transport: cmp.Or(lc.Transport, http.RoundTripper( + &http.Transport{DialContext: lc.dialer()}), + ), } }) if !lc.OmitAuth { @@ -141,7 +151,7 @@ func (lc *LocalClient) DoLocalRequest(req *http.Request) (*http.Response, error) return lc.tsClient.Do(req) } -func (lc *LocalClient) doLocalRequestNiceError(req *http.Request) (*http.Response, error) { +func (lc *Client) doLocalRequestNiceError(req *http.Request) (*http.Response, error) { res, err := lc.DoLocalRequest(req) if err == nil { if server := res.Header.Get("Tailscale-Version"); server != "" && server != envknob.IPCVersion() && onVersionMismatch != nil { @@ -230,12 +240,17 @@ func SetVersionMismatchHandler(f func(clientVer, serverVer string)) { onVersionMismatch = f } -func (lc *LocalClient) send(ctx context.Context, method, path string, wantStatus int, body io.Reader) ([]byte, error) { - slurp, _, err := lc.sendWithHeaders(ctx, method, path, wantStatus, body, nil) +func (lc *Client) send(ctx context.Context, method, path string, wantStatus int, body io.Reader) ([]byte, error) { + var headers http.Header + if reason := apitype.RequestReasonKey.Value(ctx); reason != "" { + reasonBase64 := base64.StdEncoding.EncodeToString([]byte(reason)) + headers = http.Header{apitype.RequestReasonHeader: {reasonBase64}} + } + slurp, _, err := lc.sendWithHeaders(ctx, method, path, wantStatus, body, headers) return slurp, err } -func (lc *LocalClient) sendWithHeaders( +func (lc *Client) sendWithHeaders( ctx context.Context, method, path string, @@ -274,15 +289,15 @@ type httpStatusError struct { HTTPStatus int } -func (lc *LocalClient) get200(ctx context.Context, path string) ([]byte, error) { +func (lc *Client) get200(ctx context.Context, path string) ([]byte, error) { return lc.send(ctx, "GET", path, 200, nil) } // WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port. // -// Deprecated: use LocalClient.WhoIs. +// Deprecated: use [Client.WhoIs]. func WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { - return defaultLocalClient.WhoIs(ctx, remoteAddr) + return defaultClient.WhoIs(ctx, remoteAddr) } func decodeJSON[T any](b []byte) (ret T, err error) { @@ -295,12 +310,12 @@ func decodeJSON[T any](b []byte) (ret T, err error) { // WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port. // -// If not found, the error is ErrPeerNotFound. +// If not found, the error is [ErrPeerNotFound]. // // For connections proxied by tailscaled, this looks up the owner of the given // address as TCP first, falling back to UDP; if you want to only check a // specific address family, use WhoIsProto. -func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { +func (lc *Client) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr)) if err != nil { if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { @@ -311,13 +326,14 @@ func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.W return decodeJSON[*apitype.WhoIsResponse](body) } -// ErrPeerNotFound is returned by WhoIs and WhoIsNodeKey when a peer is not found. +// ErrPeerNotFound is returned by [Client.WhoIs], [Client.WhoIsNodeKey] and +// [Client.WhoIsProto] when a peer is not found. var ErrPeerNotFound = errors.New("peer not found") // WhoIsNodeKey returns the owner of the given wireguard public key. // // If not found, the error is ErrPeerNotFound. -func (lc *LocalClient) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*apitype.WhoIsResponse, error) { +func (lc *Client) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*apitype.WhoIsResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(key.String())) if err != nil { if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { @@ -331,8 +347,8 @@ func (lc *LocalClient) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*a // WhoIsProto returns the owner of the remoteAddr, which must be an IP or // IP:port, for the given protocol (tcp or udp). // -// If not found, the error is ErrPeerNotFound. -func (lc *LocalClient) WhoIsProto(ctx context.Context, proto, remoteAddr string) (*apitype.WhoIsResponse, error) { +// If not found, the error is [ErrPeerNotFound]. +func (lc *Client) WhoIsProto(ctx context.Context, proto, remoteAddr string) (*apitype.WhoIsResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/whois?proto="+url.QueryEscape(proto)+"&addr="+url.QueryEscape(remoteAddr)) if err != nil { if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { @@ -344,19 +360,19 @@ func (lc *LocalClient) WhoIsProto(ctx context.Context, proto, remoteAddr string) } // Goroutines returns a dump of the Tailscale daemon's current goroutines. -func (lc *LocalClient) Goroutines(ctx context.Context) ([]byte, error) { +func (lc *Client) Goroutines(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/goroutines") } // DaemonMetrics returns the Tailscale daemon's metrics in // the Prometheus text exposition format. -func (lc *LocalClient) DaemonMetrics(ctx context.Context) ([]byte, error) { +func (lc *Client) DaemonMetrics(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/metrics") } // UserMetrics returns the user metrics in // the Prometheus text exposition format. -func (lc *LocalClient) UserMetrics(ctx context.Context) ([]byte, error) { +func (lc *Client) UserMetrics(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/usermetrics") } @@ -365,7 +381,10 @@ func (lc *LocalClient) UserMetrics(ctx context.Context) ([]byte, error) { // metric is created and initialized to delta. // // IncrementCounter does not support gauge metrics or negative delta values. -func (lc *LocalClient) IncrementCounter(ctx context.Context, name string, delta int) error { +func (lc *Client) IncrementCounter(ctx context.Context, name string, delta int) error { + if !buildfeatures.HasClientMetrics { + return nil + } type metricUpdate struct { Name string `json:"name"` Type string `json:"type"` @@ -382,9 +401,26 @@ func (lc *LocalClient) IncrementCounter(ctx context.Context, name string, delta return err } +// IncrementGauge increments the value of a Tailscale daemon's gauge +// metric by the given delta. If the metric has yet to exist, a new gauge +// metric is created and initialized to delta. The delta value can be negative. +func (lc *Client) IncrementGauge(ctx context.Context, name string, delta int) error { + type metricUpdate struct { + Name string `json:"name"` + Type string `json:"type"` + Value int `json:"value"` // amount to increment by + } + _, err := lc.send(ctx, "POST", "/localapi/v0/upload-client-metrics", 200, jsonBody([]metricUpdate{{ + Name: name, + Type: "gauge", + Value: delta, + }})) + return err +} + // TailDaemonLogs returns a stream the Tailscale daemon's logs as they arrive. // Close the context to stop the stream. -func (lc *LocalClient) TailDaemonLogs(ctx context.Context) (io.Reader, error) { +func (lc *Client) TailDaemonLogs(ctx context.Context) (io.Reader, error) { req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/logtap", nil) if err != nil { return nil, err @@ -399,8 +435,52 @@ func (lc *LocalClient) TailDaemonLogs(ctx context.Context) (io.Reader, error) { return res.Body, nil } +// EventBusGraph returns a graph of active publishers and subscribers in the eventbus +// as a [eventbus.DebugTopics] +func (lc *Client) EventBusGraph(ctx context.Context) ([]byte, error) { + return lc.get200(ctx, "/localapi/v0/debug-bus-graph") +} + +// StreamBusEvents returns an iterator of Tailscale bus events as they arrive. +// Each pair is a valid event and a nil error, or a zero event a non-nil error. +// In case of error, the iterator ends after the pair reporting the error. +// Iteration stops if ctx ends. +func (lc *Client) StreamBusEvents(ctx context.Context) iter.Seq2[eventbus.DebugEvent, error] { + return func(yield func(eventbus.DebugEvent, error) bool) { + req, err := http.NewRequestWithContext(ctx, "GET", + "http://"+apitype.LocalAPIHost+"/localapi/v0/debug-bus-events", nil) + if err != nil { + yield(eventbus.DebugEvent{}, err) + return + } + res, err := lc.doLocalRequestNiceError(req) + if err != nil { + yield(eventbus.DebugEvent{}, err) + return + } + if res.StatusCode != http.StatusOK { + yield(eventbus.DebugEvent{}, errors.New(res.Status)) + return + } + defer res.Body.Close() + dec := json.NewDecoder(bufio.NewReader(res.Body)) + for { + var evt eventbus.DebugEvent + if err := dec.Decode(&evt); err == io.EOF { + return + } else if err != nil { + yield(eventbus.DebugEvent{}, err) + return + } + if !yield(evt, nil) { + return + } + } + } +} + // Pprof returns a pprof profile of the Tailscale daemon. -func (lc *LocalClient) Pprof(ctx context.Context, pprofType string, sec int) ([]byte, error) { +func (lc *Client) Pprof(ctx context.Context, pprofType string, sec int) ([]byte, error) { var secArg string if sec < 0 || sec > 300 { return nil, errors.New("duration out of range") @@ -433,7 +513,7 @@ type BugReportOpts struct { // // The opts type specifies options to pass to the Tailscale daemon when // generating this bug report. -func (lc *LocalClient) BugReportWithOpts(ctx context.Context, opts BugReportOpts) (string, error) { +func (lc *Client) BugReportWithOpts(ctx context.Context, opts BugReportOpts) (string, error) { qparams := make(url.Values) if opts.Note != "" { qparams.Set("note", opts.Note) @@ -476,15 +556,15 @@ func (lc *LocalClient) BugReportWithOpts(ctx context.Context, opts BugReportOpts // BugReport logs and returns a log marker that can be shared by the user with support. // -// This is the same as calling BugReportWithOpts and only specifying the Note +// This is the same as calling [Client.BugReportWithOpts] and only specifying the Note // field. -func (lc *LocalClient) BugReport(ctx context.Context, note string) (string, error) { +func (lc *Client) BugReport(ctx context.Context, note string) (string, error) { return lc.BugReportWithOpts(ctx, BugReportOpts{Note: note}) } // DebugAction invokes a debug action, such as "rebind" or "restun". // These are development tools and subject to change or removal over time. -func (lc *LocalClient) DebugAction(ctx context.Context, action string) error { +func (lc *Client) DebugAction(ctx context.Context, action string) error { body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) if err != nil { return fmt.Errorf("error %w: %s", err, body) @@ -492,9 +572,20 @@ func (lc *LocalClient) DebugAction(ctx context.Context, action string) error { return nil } +// DebugActionBody invokes a debug action with a body parameter, such as +// "debug-force-prefer-derp". +// These are development tools and subject to change or removal over time. +func (lc *Client) DebugActionBody(ctx context.Context, action string, rbody io.Reader) error { + body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, rbody) + if err != nil { + return fmt.Errorf("error %w: %s", err, body) + } + return nil +} + // DebugResultJSON invokes a debug action and returns its result as something JSON-able. // These are development tools and subject to change or removal over time. -func (lc *LocalClient) DebugResultJSON(ctx context.Context, action string) (any, error) { +func (lc *Client) DebugResultJSON(ctx context.Context, action string) (any, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) if err != nil { return nil, fmt.Errorf("error %w: %s", err, body) @@ -506,73 +597,22 @@ func (lc *LocalClient) DebugResultJSON(ctx context.Context, action string) (any, return x, nil } -// DebugPortmapOpts contains options for the DebugPortmap command. -type DebugPortmapOpts struct { - // Duration is how long the mapping should be created for. It defaults - // to 5 seconds if not set. - Duration time.Duration - - // Type is the kind of portmap to debug. The empty string instructs the - // portmap client to perform all known types. Other valid options are - // "pmp", "pcp", and "upnp". - Type string - - // GatewayAddr specifies the gateway address used during portmapping. - // If set, SelfAddr must also be set. If unset, it will be - // autodetected. - GatewayAddr netip.Addr - - // SelfAddr specifies the gateway address used during portmapping. If - // set, GatewayAddr must also be set. If unset, it will be - // autodetected. - SelfAddr netip.Addr - - // LogHTTP instructs the debug-portmap endpoint to print all HTTP - // requests and responses made to the logs. - LogHTTP bool -} - -// DebugPortmap invokes the debug-portmap endpoint, and returns an -// io.ReadCloser that can be used to read the logs that are printed during this -// process. -// -// opts can be nil; if so, default values will be used. -func (lc *LocalClient) DebugPortmap(ctx context.Context, opts *DebugPortmapOpts) (io.ReadCloser, error) { - vals := make(url.Values) - if opts == nil { - opts = &DebugPortmapOpts{} - } - - vals.Set("duration", cmp.Or(opts.Duration, 5*time.Second).String()) - vals.Set("type", opts.Type) - vals.Set("log_http", strconv.FormatBool(opts.LogHTTP)) - - if opts.GatewayAddr.IsValid() != opts.SelfAddr.IsValid() { - return nil, fmt.Errorf("both GatewayAddr and SelfAddr must be provided if one is") - } else if opts.GatewayAddr.IsValid() { - vals.Set("gateway_and_self", fmt.Sprintf("%s/%s", opts.GatewayAddr, opts.SelfAddr)) - } - - req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/debug-portmap?"+vals.Encode(), nil) +// QueryOptionalFeatures queries the optional features supported by the Tailscale daemon. +func (lc *Client) QueryOptionalFeatures(ctx context.Context) (*apitype.OptionalFeatures, error) { + body, err := lc.send(ctx, "POST", "/localapi/v0/debug-optional-features", 200, nil) if err != nil { - return nil, err + return nil, fmt.Errorf("error %w: %s", err, body) } - res, err := lc.doLocalRequestNiceError(req) - if err != nil { + var x apitype.OptionalFeatures + if err := json.Unmarshal(body, &x); err != nil { return nil, err } - if res.StatusCode != 200 { - body, _ := io.ReadAll(res.Body) - res.Body.Close() - return nil, fmt.Errorf("HTTP %s: %s", res.Status, body) - } - - return res.Body, nil + return &x, nil } // SetDevStoreKeyValue set a statestore key/value. It's only meant for development. // The schema (including when keys are re-read) is not a stable interface. -func (lc *LocalClient) SetDevStoreKeyValue(ctx context.Context, key, value string) error { +func (lc *Client) SetDevStoreKeyValue(ctx context.Context, key, value string) error { body, err := lc.send(ctx, "POST", "/localapi/v0/dev-set-state-store?"+(url.Values{ "key": {key}, "value": {value}, @@ -586,7 +626,10 @@ func (lc *LocalClient) SetDevStoreKeyValue(ctx context.Context, key, value strin // SetComponentDebugLogging sets component's debug logging enabled for // the provided duration. If the duration is in the past, the debug logging // is disabled. -func (lc *LocalClient) SetComponentDebugLogging(ctx context.Context, component string, d time.Duration) error { +func (lc *Client) SetComponentDebugLogging(ctx context.Context, component string, d time.Duration) error { + if !buildfeatures.HasDebug { + return feature.ErrUnavailable + } body, err := lc.send(ctx, "POST", fmt.Sprintf("/localapi/v0/component-debug-logging?component=%s&secs=%d", url.QueryEscape(component), int64(d.Seconds())), 200, nil) @@ -607,25 +650,25 @@ func (lc *LocalClient) SetComponentDebugLogging(ctx context.Context, component s // Status returns the Tailscale daemon's status. func Status(ctx context.Context) (*ipnstate.Status, error) { - return defaultLocalClient.Status(ctx) + return defaultClient.Status(ctx) } // Status returns the Tailscale daemon's status. -func (lc *LocalClient) Status(ctx context.Context) (*ipnstate.Status, error) { +func (lc *Client) Status(ctx context.Context) (*ipnstate.Status, error) { return lc.status(ctx, "") } // StatusWithoutPeers returns the Tailscale daemon's status, without the peer info. func StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { - return defaultLocalClient.StatusWithoutPeers(ctx) + return defaultClient.StatusWithoutPeers(ctx) } // StatusWithoutPeers returns the Tailscale daemon's status, without the peer info. -func (lc *LocalClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { +func (lc *Client) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { return lc.status(ctx, "?peers=false") } -func (lc *LocalClient) status(ctx context.Context, queryString string) (*ipnstate.Status, error) { +func (lc *Client) status(ctx context.Context, queryString string) (*ipnstate.Status, error) { body, err := lc.get200(ctx, "/localapi/v0/status"+queryString) if err != nil { return nil, err @@ -636,7 +679,7 @@ func (lc *LocalClient) status(ctx context.Context, queryString string) (*ipnstat // IDToken is a request to get an OIDC ID token for an audience. // The token can be presented to any resource provider which offers OIDC // Federation. -func (lc *LocalClient) IDToken(ctx context.Context, aud string) (*tailcfg.TokenResponse, error) { +func (lc *Client) IDToken(ctx context.Context, aud string) (*tailcfg.TokenResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/id-token?aud="+url.QueryEscape(aud)) if err != nil { return nil, err @@ -648,14 +691,14 @@ func (lc *LocalClient) IDToken(ctx context.Context, aud string) (*tailcfg.TokenR // received by the Tailscale daemon in its staging/cache directory but not yet // transferred by the user's CLI or GUI client and written to a user's home // directory somewhere. -func (lc *LocalClient) WaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { +func (lc *Client) WaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { return lc.AwaitWaitingFiles(ctx, 0) } -// AwaitWaitingFiles is like WaitingFiles but takes a duration to await for an answer. +// AwaitWaitingFiles is like [Client.WaitingFiles] but takes a duration to await for an answer. // If the duration is 0, it will return immediately. The duration is respected at second // granularity only. If no files are available, it returns (nil, nil). -func (lc *LocalClient) AwaitWaitingFiles(ctx context.Context, d time.Duration) ([]apitype.WaitingFile, error) { +func (lc *Client) AwaitWaitingFiles(ctx context.Context, d time.Duration) ([]apitype.WaitingFile, error) { path := "/localapi/v0/files/?waitsec=" + fmt.Sprint(int(d.Seconds())) body, err := lc.get200(ctx, path) if err != nil { @@ -664,12 +707,12 @@ func (lc *LocalClient) AwaitWaitingFiles(ctx context.Context, d time.Duration) ( return decodeJSON[[]apitype.WaitingFile](body) } -func (lc *LocalClient) DeleteWaitingFile(ctx context.Context, baseName string) error { +func (lc *Client) DeleteWaitingFile(ctx context.Context, baseName string) error { _, err := lc.send(ctx, "DELETE", "/localapi/v0/files/"+url.PathEscape(baseName), http.StatusNoContent, nil) return err } -func (lc *LocalClient) GetWaitingFile(ctx context.Context, baseName string) (rc io.ReadCloser, size int64, err error) { +func (lc *Client) GetWaitingFile(ctx context.Context, baseName string) (rc io.ReadCloser, size int64, err error) { req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/files/"+url.PathEscape(baseName), nil) if err != nil { return nil, 0, err @@ -690,7 +733,7 @@ func (lc *LocalClient) GetWaitingFile(ctx context.Context, baseName string) (rc return res.Body, res.ContentLength, nil } -func (lc *LocalClient) FileTargets(ctx context.Context) ([]apitype.FileTarget, error) { +func (lc *Client) FileTargets(ctx context.Context) ([]apitype.FileTarget, error) { body, err := lc.get200(ctx, "/localapi/v0/file-targets") if err != nil { return nil, err @@ -702,7 +745,7 @@ func (lc *LocalClient) FileTargets(ctx context.Context) ([]apitype.FileTarget, e // // A size of -1 means unknown. // The name parameter is the original filename, not escaped. -func (lc *LocalClient) PushFile(ctx context.Context, target tailcfg.StableNodeID, size int64, name string, r io.Reader) error { +func (lc *Client) PushFile(ctx context.Context, target tailcfg.StableNodeID, size int64, name string, r io.Reader) error { req, err := http.NewRequestWithContext(ctx, "PUT", "http://"+apitype.LocalAPIHost+"/localapi/v0/file-put/"+string(target)+"/"+url.PathEscape(name), r) if err != nil { return err @@ -725,7 +768,10 @@ func (lc *LocalClient) PushFile(ctx context.Context, target tailcfg.StableNodeID // CheckIPForwarding asks the local Tailscale daemon whether it looks like the // machine is properly configured to forward IP packets as a subnet router // or exit node. -func (lc *LocalClient) CheckIPForwarding(ctx context.Context) error { +func (lc *Client) CheckIPForwarding(ctx context.Context) error { + if !buildfeatures.HasAdvertiseRoutes { + return nil + } body, err := lc.get200(ctx, "/localapi/v0/check-ip-forwarding") if err != nil { return err @@ -745,7 +791,7 @@ func (lc *LocalClient) CheckIPForwarding(ctx context.Context) error { // CheckUDPGROForwarding asks the local Tailscale daemon whether it looks like // the machine is optimally configured to forward UDP packets as a subnet router // or exit node. -func (lc *LocalClient) CheckUDPGROForwarding(ctx context.Context) error { +func (lc *Client) CheckUDPGROForwarding(ctx context.Context) error { body, err := lc.get200(ctx, "/localapi/v0/check-udp-gro-forwarding") if err != nil { return err @@ -762,11 +808,30 @@ func (lc *LocalClient) CheckUDPGROForwarding(ctx context.Context) error { return nil } +// CheckReversePathFiltering asks the local Tailscale daemon whether strict +// reverse path filtering is enabled, which would break exit node usage on Linux. +func (lc *Client) CheckReversePathFiltering(ctx context.Context) error { + body, err := lc.get200(ctx, "/localapi/v0/check-reverse-path-filtering") + if err != nil { + return err + } + var jres struct { + Warning string + } + if err := json.Unmarshal(body, &jres); err != nil { + return fmt.Errorf("invalid JSON from check-reverse-path-filtering: %w", err) + } + if jres.Warning != "" { + return errors.New(jres.Warning) + } + return nil +} + // SetUDPGROForwarding enables UDP GRO forwarding for the main interface of this // node. This can be done to improve performance of tailnet nodes acting as exit // nodes or subnet routers. // See https://tailscale.com/kb/1320/performance-best-practices#linux-optimizations-for-subnet-routers-and-exit-nodes -func (lc *LocalClient) SetUDPGROForwarding(ctx context.Context) error { +func (lc *Client) SetUDPGROForwarding(ctx context.Context) error { body, err := lc.get200(ctx, "/localapi/v0/set-udp-gro-forwarding") if err != nil { return err @@ -789,12 +854,12 @@ func (lc *LocalClient) SetUDPGROForwarding(ctx context.Context) error { // work. Currently (2022-04-18) this only checks for SSH server compatibility. // Note that EditPrefs does the same validation as this, so call CheckPrefs before // EditPrefs is not necessary. -func (lc *LocalClient) CheckPrefs(ctx context.Context, p *ipn.Prefs) error { +func (lc *Client) CheckPrefs(ctx context.Context, p *ipn.Prefs) error { _, err := lc.send(ctx, "POST", "/localapi/v0/check-prefs", http.StatusOK, jsonBody(p)) return err } -func (lc *LocalClient) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { +func (lc *Client) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { body, err := lc.get200(ctx, "/localapi/v0/prefs") if err != nil { return nil, err @@ -806,7 +871,12 @@ func (lc *LocalClient) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { return &p, nil } -func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) { +// EditPrefs updates the [ipn.Prefs] of the current Tailscale profile, applying the changes in mp. +// It returns an error if the changes cannot be applied, such as due to the caller's access rights +// or a policy restriction. An optional reason or justification for the request can be +// provided as a context value using [apitype.RequestReasonKey]. If permitted by policy, +// access may be granted, and the reason will be logged for auditing purposes. +func (lc *Client) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) { body, err := lc.send(ctx, "PATCH", "/localapi/v0/prefs", http.StatusOK, jsonBody(mp)) if err != nil { return nil, err @@ -816,7 +886,10 @@ func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn // GetDNSOSConfig returns the system DNS configuration for the current device. // That is, it returns the DNS configuration that the system would use if Tailscale weren't being used. -func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) { +func (lc *Client) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) { + if !buildfeatures.HasDNS { + return nil, feature.ErrUnavailable + } body, err := lc.get200(ctx, "/localapi/v0/dns-osconfig") if err != nil { return nil, err @@ -831,7 +904,10 @@ func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig // QueryDNS executes a DNS query for a name (`google.com.`) and query type (`CNAME`). // It returns the raw DNS response bytes and the resolvers that were used to answer the query // (often just one, but can be more if we raced multiple resolvers). -func (lc *LocalClient) QueryDNS(ctx context.Context, name string, queryType string) (bytes []byte, resolvers []*dnstype.Resolver, err error) { +func (lc *Client) QueryDNS(ctx context.Context, name string, queryType string) (bytes []byte, resolvers []*dnstype.Resolver, err error) { + if !buildfeatures.HasDNS { + return nil, nil, feature.ErrUnavailable + } body, err := lc.get200(ctx, fmt.Sprintf("/localapi/v0/dns-query?name=%s&type=%s", url.QueryEscape(name), queryType)) if err != nil { return nil, nil, err @@ -844,53 +920,31 @@ func (lc *LocalClient) QueryDNS(ctx context.Context, name string, queryType stri } // StartLoginInteractive starts an interactive login. -func (lc *LocalClient) StartLoginInteractive(ctx context.Context) error { +func (lc *Client) StartLoginInteractive(ctx context.Context) error { _, err := lc.send(ctx, "POST", "/localapi/v0/login-interactive", http.StatusNoContent, nil) return err } // Start applies the configuration specified in opts, and starts the // state machine. -func (lc *LocalClient) Start(ctx context.Context, opts ipn.Options) error { +func (lc *Client) Start(ctx context.Context, opts ipn.Options) error { _, err := lc.send(ctx, "POST", "/localapi/v0/start", http.StatusNoContent, jsonBody(opts)) return err } // Logout logs out the current node. -func (lc *LocalClient) Logout(ctx context.Context) error { +func (lc *Client) Logout(ctx context.Context) error { _, err := lc.send(ctx, "POST", "/localapi/v0/logout", http.StatusNoContent, nil) return err } -// SetDNS adds a DNS TXT record for the given domain name, containing -// the provided TXT value. The intended use case is answering -// LetsEncrypt/ACME dns-01 challenges. -// -// The control plane will only permit SetDNS requests with very -// specific names and values. The name should be -// "_acme-challenge." + your node's MagicDNS name. It's expected that -// clients cache the certs from LetsEncrypt (or whichever CA is -// providing them) and only request new ones as needed; the control plane -// rate limits SetDNS requests. -// -// This is a low-level interface; it's expected that most Tailscale -// users use a higher level interface to getting/using TLS -// certificates. -func (lc *LocalClient) SetDNS(ctx context.Context, name, value string) error { - v := url.Values{} - v.Set("name", name) - v.Set("value", value) - _, err := lc.send(ctx, "POST", "/localapi/v0/set-dns?"+v.Encode(), 200, nil) - return err -} - // DialTCP connects to the host's port via Tailscale. // // The host may be a base DNS name (resolved from the netmap inside // tailscaled), a FQDN, or an IP address. // -// The ctx is only used for the duration of the call, not the lifetime of the net.Conn. -func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) { +// The ctx is only used for the duration of the call, not the lifetime of the [net.Conn]. +func (lc *Client) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) { return lc.UserDial(ctx, "tcp", host, port) } @@ -900,8 +954,8 @@ func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (n // a FQDN, or an IP address. // // The ctx is only used for the duration of the call, not the lifetime of the -// net.Conn. -func (lc *LocalClient) UserDial(ctx context.Context, network, host string, port uint16) (net.Conn, error) { +// [net.Conn]. +func (lc *Client) UserDial(ctx context.Context, network, host string, port uint16) (net.Conn, error) { connCh := make(chan net.Conn, 1) trace := httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { @@ -952,7 +1006,7 @@ func (lc *LocalClient) UserDial(ctx context.Context, network, host string, port // CurrentDERPMap returns the current DERPMap that is being used by the local tailscaled. // It is intended to be used with netcheck to see availability of DERPs. -func (lc *LocalClient) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) { +func (lc *Client) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) { var derpMap tailcfg.DERPMap res, err := lc.send(ctx, "GET", "/localapi/v0/derpmap", 200, nil) if err != nil { @@ -964,117 +1018,6 @@ func (lc *LocalClient) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, er return &derpMap, nil } -// CertPair returns a cert and private key for the provided DNS domain. -// -// It returns a cached certificate from disk if it's still valid. -// -// Deprecated: use LocalClient.CertPair. -func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { - return defaultLocalClient.CertPair(ctx, domain) -} - -// CertPair returns a cert and private key for the provided DNS domain. -// -// It returns a cached certificate from disk if it's still valid. -// -// API maturity: this is considered a stable API. -func (lc *LocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { - return lc.CertPairWithValidity(ctx, domain, 0) -} - -// CertPairWithValidity returns a cert and private key for the provided DNS -// domain. -// -// It returns a cached certificate from disk if it's still valid. -// When minValidity is non-zero, the returned certificate will be valid for at -// least the given duration, if permitted by the CA. If the certificate is -// valid, but for less than minValidity, it will be synchronously renewed. -// -// API maturity: this is considered a stable API. -func (lc *LocalClient) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) { - res, err := lc.send(ctx, "GET", fmt.Sprintf("/localapi/v0/cert/%s?type=pair&min_validity=%s", domain, minValidity), 200, nil) - if err != nil { - return nil, nil, err - } - // with ?type=pair, the response PEM is first the one private - // key PEM block, then the cert PEM blocks. - i := mem.Index(mem.B(res), mem.S("--\n--")) - if i == -1 { - return nil, nil, fmt.Errorf("unexpected output: no delimiter") - } - i += len("--\n") - keyPEM, certPEM = res[:i], res[i:] - if mem.Contains(mem.B(certPEM), mem.S(" PRIVATE KEY-----")) { - return nil, nil, fmt.Errorf("unexpected output: key in cert") - } - return certPEM, keyPEM, nil -} - -// GetCertificate fetches a TLS certificate for the TLS ClientHello in hi. -// -// It returns a cached certificate from disk if it's still valid. -// -// It's the right signature to use as the value of -// tls.Config.GetCertificate. -// -// Deprecated: use LocalClient.GetCertificate. -func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - return defaultLocalClient.GetCertificate(hi) -} - -// GetCertificate fetches a TLS certificate for the TLS ClientHello in hi. -// -// It returns a cached certificate from disk if it's still valid. -// -// It's the right signature to use as the value of -// tls.Config.GetCertificate. -// -// API maturity: this is considered a stable API. -func (lc *LocalClient) GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - if hi == nil || hi.ServerName == "" { - return nil, errors.New("no SNI ServerName") - } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - name := hi.ServerName - if !strings.Contains(name, ".") { - if v, ok := lc.ExpandSNIName(ctx, name); ok { - name = v - } - } - certPEM, keyPEM, err := lc.CertPair(ctx, name) - if err != nil { - return nil, err - } - cert, err := tls.X509KeyPair(certPEM, keyPEM) - if err != nil { - return nil, err - } - return &cert, nil -} - -// ExpandSNIName expands bare label name into the most likely actual TLS cert name. -// -// Deprecated: use LocalClient.ExpandSNIName. -func ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { - return defaultLocalClient.ExpandSNIName(ctx, name) -} - -// ExpandSNIName expands bare label name into the most likely actual TLS cert name. -func (lc *LocalClient) ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { - st, err := lc.StatusWithoutPeers(ctx) - if err != nil { - return "", false - } - for _, d := range st.CertDomains { - if len(d) > len(name)+1 && strings.HasPrefix(d, name) && d[len(name)] == '.' { - return d, true - } - } - return "", false -} - // PingOpts contains options for the ping request. // // The zero value is valid, which means to use defaults. @@ -1090,7 +1033,7 @@ type PingOpts struct { // Ping sends a ping of the provided type to the provided IP and waits // for its response. The opts type specifies additional options. -func (lc *LocalClient) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType, opts PingOpts) (*ipnstate.PingResult, error) { +func (lc *Client) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType, opts PingOpts) (*ipnstate.PingResult, error) { v := url.Values{} v.Set("ip", ip.String()) v.Set("size", strconv.Itoa(opts.Size)) @@ -1104,235 +1047,21 @@ func (lc *LocalClient) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype // Ping sends a ping of the provided type to the provided IP and waits // for its response. -func (lc *LocalClient) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType) (*ipnstate.PingResult, error) { +func (lc *Client) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType) (*ipnstate.PingResult, error) { return lc.PingWithOpts(ctx, ip, pingtype, PingOpts{}) } -// NetworkLockStatus fetches information about the tailnet key authority, if one is configured. -func (lc *LocalClient) NetworkLockStatus(ctx context.Context) (*ipnstate.NetworkLockStatus, error) { - body, err := lc.send(ctx, "GET", "/localapi/v0/tka/status", 200, nil) - if err != nil { - return nil, fmt.Errorf("error: %w", err) - } - return decodeJSON[*ipnstate.NetworkLockStatus](body) -} - -// NetworkLockInit initializes the tailnet key authority. -// -// TODO(tom): Plumb through disablement secrets. -func (lc *LocalClient) NetworkLockInit(ctx context.Context, keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) (*ipnstate.NetworkLockStatus, error) { - var b bytes.Buffer - type initRequest struct { - Keys []tka.Key - DisablementValues [][]byte - SupportDisablement []byte - } - - if err := json.NewEncoder(&b).Encode(initRequest{Keys: keys, DisablementValues: disablementValues, SupportDisablement: supportDisablement}); err != nil { - return nil, err - } - - body, err := lc.send(ctx, "POST", "/localapi/v0/tka/init", 200, &b) - if err != nil { - return nil, fmt.Errorf("error: %w", err) - } - return decodeJSON[*ipnstate.NetworkLockStatus](body) -} - -// NetworkLockWrapPreauthKey wraps a pre-auth key with information to -// enable unattended bringup in the locked tailnet. -func (lc *LocalClient) NetworkLockWrapPreauthKey(ctx context.Context, preauthKey string, tkaKey key.NLPrivate) (string, error) { - encodedPrivate, err := tkaKey.MarshalText() - if err != nil { - return "", err - } - - var b bytes.Buffer - type wrapRequest struct { - TSKey string - TKAKey string // key.NLPrivate.MarshalText - } - if err := json.NewEncoder(&b).Encode(wrapRequest{TSKey: preauthKey, TKAKey: string(encodedPrivate)}); err != nil { - return "", err - } - - body, err := lc.send(ctx, "POST", "/localapi/v0/tka/wrap-preauth-key", 200, &b) - if err != nil { - return "", fmt.Errorf("error: %w", err) - } - return string(body), nil -} - -// NetworkLockModify adds and/or removes key(s) to the tailnet key authority. -func (lc *LocalClient) NetworkLockModify(ctx context.Context, addKeys, removeKeys []tka.Key) error { - var b bytes.Buffer - type modifyRequest struct { - AddKeys []tka.Key - RemoveKeys []tka.Key - } - - if err := json.NewEncoder(&b).Encode(modifyRequest{AddKeys: addKeys, RemoveKeys: removeKeys}); err != nil { - return err - } - - if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/modify", 204, &b); err != nil { - return fmt.Errorf("error: %w", err) - } - return nil -} - -// NetworkLockSign signs the specified node-key and transmits that signature to the control plane. -// rotationPublic, if specified, must be an ed25519 public key. -func (lc *LocalClient) NetworkLockSign(ctx context.Context, nodeKey key.NodePublic, rotationPublic []byte) error { - var b bytes.Buffer - type signRequest struct { - NodeKey key.NodePublic - RotationPublic []byte - } - - if err := json.NewEncoder(&b).Encode(signRequest{NodeKey: nodeKey, RotationPublic: rotationPublic}); err != nil { - return err - } - - if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/sign", 200, &b); err != nil { - return fmt.Errorf("error: %w", err) - } - return nil -} - -// NetworkLockAffectedSigs returns all signatures signed by the specified keyID. -func (lc *LocalClient) NetworkLockAffectedSigs(ctx context.Context, keyID tkatype.KeyID) ([]tkatype.MarshaledSignature, error) { - body, err := lc.send(ctx, "POST", "/localapi/v0/tka/affected-sigs", 200, bytes.NewReader(keyID)) - if err != nil { - return nil, fmt.Errorf("error: %w", err) - } - return decodeJSON[[]tkatype.MarshaledSignature](body) -} - -// NetworkLockLog returns up to maxEntries number of changes to network-lock state. -func (lc *LocalClient) NetworkLockLog(ctx context.Context, maxEntries int) ([]ipnstate.NetworkLockUpdate, error) { - v := url.Values{} - v.Set("limit", fmt.Sprint(maxEntries)) - body, err := lc.send(ctx, "GET", "/localapi/v0/tka/log?"+v.Encode(), 200, nil) +// DisconnectControl shuts down all connections to control, thus making control consider this node inactive. This can be +// run on HA subnet router or app connector replicas before shutting them down to ensure peers get told to switch over +// to another replica whilst there is still some grace period for the existing connections to terminate. +func (lc *Client) DisconnectControl(ctx context.Context) error { + _, _, err := lc.sendWithHeaders(ctx, "POST", "/localapi/v0/disconnect-control", 200, nil, nil) if err != nil { - return nil, fmt.Errorf("error %w: %s", err, body) - } - return decodeJSON[[]ipnstate.NetworkLockUpdate](body) -} - -// NetworkLockForceLocalDisable forcibly shuts down network lock on this node. -func (lc *LocalClient) NetworkLockForceLocalDisable(ctx context.Context) error { - // This endpoint expects an empty JSON stanza as the payload. - var b bytes.Buffer - if err := json.NewEncoder(&b).Encode(struct{}{}); err != nil { - return err - } - - if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/force-local-disable", 200, &b); err != nil { - return fmt.Errorf("error: %w", err) + return fmt.Errorf("error disconnecting control: %w", err) } return nil } -// NetworkLockVerifySigningDeeplink verifies the network lock deeplink contained -// in url and returns information extracted from it. -func (lc *LocalClient) NetworkLockVerifySigningDeeplink(ctx context.Context, url string) (*tka.DeeplinkValidationResult, error) { - vr := struct { - URL string - }{url} - - body, err := lc.send(ctx, "POST", "/localapi/v0/tka/verify-deeplink", 200, jsonBody(vr)) - if err != nil { - return nil, fmt.Errorf("sending verify-deeplink: %w", err) - } - - return decodeJSON[*tka.DeeplinkValidationResult](body) -} - -// NetworkLockGenRecoveryAUM generates an AUM for recovering from a tailnet-lock key compromise. -func (lc *LocalClient) NetworkLockGenRecoveryAUM(ctx context.Context, removeKeys []tkatype.KeyID, forkFrom tka.AUMHash) ([]byte, error) { - vr := struct { - Keys []tkatype.KeyID - ForkFrom string - }{removeKeys, forkFrom.String()} - - body, err := lc.send(ctx, "POST", "/localapi/v0/tka/generate-recovery-aum", 200, jsonBody(vr)) - if err != nil { - return nil, fmt.Errorf("sending generate-recovery-aum: %w", err) - } - - return body, nil -} - -// NetworkLockCosignRecoveryAUM co-signs a recovery AUM using the node's tailnet lock key. -func (lc *LocalClient) NetworkLockCosignRecoveryAUM(ctx context.Context, aum tka.AUM) ([]byte, error) { - r := bytes.NewReader(aum.Serialize()) - body, err := lc.send(ctx, "POST", "/localapi/v0/tka/cosign-recovery-aum", 200, r) - if err != nil { - return nil, fmt.Errorf("sending cosign-recovery-aum: %w", err) - } - - return body, nil -} - -// NetworkLockSubmitRecoveryAUM submits a recovery AUM to the control plane. -func (lc *LocalClient) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka.AUM) error { - r := bytes.NewReader(aum.Serialize()) - _, err := lc.send(ctx, "POST", "/localapi/v0/tka/submit-recovery-aum", 200, r) - if err != nil { - return fmt.Errorf("sending cosign-recovery-aum: %w", err) - } - return nil -} - -// SetServeConfig sets or replaces the serving settings. -// If config is nil, settings are cleared and serving is disabled. -func (lc *LocalClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error { - h := make(http.Header) - if config != nil { - h.Set("If-Match", config.ETag) - } - _, _, err := lc.sendWithHeaders(ctx, "POST", "/localapi/v0/serve-config", 200, jsonBody(config), h) - if err != nil { - return fmt.Errorf("sending serve config: %w", err) - } - return nil -} - -// NetworkLockDisable shuts down network-lock across the tailnet. -func (lc *LocalClient) NetworkLockDisable(ctx context.Context, secret []byte) error { - if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/disable", 200, bytes.NewReader(secret)); err != nil { - return fmt.Errorf("error: %w", err) - } - return nil -} - -// GetServeConfig return the current serve config. -// -// If the serve config is empty, it returns (nil, nil). -func (lc *LocalClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { - body, h, err := lc.sendWithHeaders(ctx, "GET", "/localapi/v0/serve-config", 200, nil, nil) - if err != nil { - return nil, fmt.Errorf("getting serve config: %w", err) - } - sc, err := getServeConfigFromJSON(body) - if err != nil { - return nil, err - } - if sc == nil { - sc = new(ipn.ServeConfig) - } - sc.ETag = h.Get("Etag") - return sc, nil -} - -func getServeConfigFromJSON(body []byte) (sc *ipn.ServeConfig, err error) { - if err := json.Unmarshal(body, &sc); err != nil { - return nil, err - } - return sc, nil -} - // tailscaledConnectHint gives a little thing about why tailscaled (or // platform equivalent) is not answering localapi connections. // @@ -1385,7 +1114,7 @@ func (r jsonReader) Read(p []byte) (n int, err error) { } // ProfileStatus returns the current profile and the list of all profiles. -func (lc *LocalClient) ProfileStatus(ctx context.Context) (current ipn.LoginProfile, all []ipn.LoginProfile, err error) { +func (lc *Client) ProfileStatus(ctx context.Context) (current ipn.LoginProfile, all []ipn.LoginProfile, err error) { body, err := lc.send(ctx, "GET", "/localapi/v0/profiles/current", 200, nil) if err != nil { return @@ -1403,7 +1132,7 @@ func (lc *LocalClient) ProfileStatus(ctx context.Context) (current ipn.LoginProf } // ReloadConfig reloads the config file, if possible. -func (lc *LocalClient) ReloadConfig(ctx context.Context) (ok bool, err error) { +func (lc *Client) ReloadConfig(ctx context.Context) (ok bool, err error) { body, err := lc.send(ctx, "POST", "/localapi/v0/reload-config", 200, nil) if err != nil { return @@ -1421,22 +1150,22 @@ func (lc *LocalClient) ReloadConfig(ctx context.Context) (ok bool, err error) { // SwitchToEmptyProfile creates and switches to a new unnamed profile. The new // profile is not assigned an ID until it is persisted after a successful login. // In order to login to the new profile, the user must call LoginInteractive. -func (lc *LocalClient) SwitchToEmptyProfile(ctx context.Context) error { +func (lc *Client) SwitchToEmptyProfile(ctx context.Context) error { _, err := lc.send(ctx, "PUT", "/localapi/v0/profiles/", http.StatusCreated, nil) return err } // SwitchProfile switches to the given profile. -func (lc *LocalClient) SwitchProfile(ctx context.Context, profile ipn.ProfileID) error { +func (lc *Client) SwitchProfile(ctx context.Context, profile ipn.ProfileID) error { _, err := lc.send(ctx, "POST", "/localapi/v0/profiles/"+url.PathEscape(string(profile)), 204, nil) return err } // DeleteProfile removes the profile with the given ID. // If the profile is the current profile, an empty profile -// will be selected as if SwitchToEmptyProfile was called. -func (lc *LocalClient) DeleteProfile(ctx context.Context, profile ipn.ProfileID) error { - _, err := lc.send(ctx, "DELETE", "/localapi/v0/profiles"+url.PathEscape(string(profile)), http.StatusNoContent, nil) +// will be selected as if [Client.SwitchToEmptyProfile] was called. +func (lc *Client) DeleteProfile(ctx context.Context, profile ipn.ProfileID) error { + _, err := lc.send(ctx, "DELETE", "/localapi/v0/profiles/"+url.PathEscape(string(profile)), http.StatusNoContent, nil) return err } @@ -1452,7 +1181,7 @@ func (lc *LocalClient) DeleteProfile(ctx context.Context, profile ipn.ProfileID) // to block until the feature has been enabled. // // 2023-08-09: Valid feature values are "serve" and "funnel". -func (lc *LocalClient) QueryFeature(ctx context.Context, feature string) (*tailcfg.QueryFeatureResponse, error) { +func (lc *Client) QueryFeature(ctx context.Context, feature string) (*tailcfg.QueryFeatureResponse, error) { v := url.Values{"feature": {feature}} body, err := lc.send(ctx, "POST", "/localapi/v0/query-feature?"+v.Encode(), 200, nil) if err != nil { @@ -1461,7 +1190,7 @@ func (lc *LocalClient) QueryFeature(ctx context.Context, feature string) (*tailc return decodeJSON[*tailcfg.QueryFeatureResponse](body) } -func (lc *LocalClient) DebugDERPRegion(ctx context.Context, regionIDOrCode string) (*ipnstate.DebugDERPRegionReport, error) { +func (lc *Client) DebugDERPRegion(ctx context.Context, regionIDOrCode string) (*ipnstate.DebugDERPRegionReport, error) { v := url.Values{"region": {regionIDOrCode}} body, err := lc.send(ctx, "POST", "/localapi/v0/debug-derp-region?"+v.Encode(), 200, nil) if err != nil { @@ -1471,7 +1200,7 @@ func (lc *LocalClient) DebugDERPRegion(ctx context.Context, regionIDOrCode strin } // DebugPacketFilterRules returns the packet filter rules for the current device. -func (lc *LocalClient) DebugPacketFilterRules(ctx context.Context) ([]tailcfg.FilterRule, error) { +func (lc *Client) DebugPacketFilterRules(ctx context.Context) ([]tailcfg.FilterRule, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/debug-packet-filter-rules", 200, nil) if err != nil { return nil, fmt.Errorf("error %w: %s", err, body) @@ -1482,17 +1211,27 @@ func (lc *LocalClient) DebugPacketFilterRules(ctx context.Context) ([]tailcfg.Fi // DebugSetExpireIn marks the current node key to expire in d. // // This is meant primarily for debug and testing. -func (lc *LocalClient) DebugSetExpireIn(ctx context.Context, d time.Duration) error { +func (lc *Client) DebugSetExpireIn(ctx context.Context, d time.Duration) error { v := url.Values{"expiry": {fmt.Sprint(time.Now().Add(d).Unix())}} _, err := lc.send(ctx, "POST", "/localapi/v0/set-expiry-sooner?"+v.Encode(), 200, nil) return err } +// DebugPeerRelaySessions returns debug information about the current peer +// relay sessions running through this node. +func (lc *Client) DebugPeerRelaySessions(ctx context.Context) (*status.ServerStatus, error) { + body, err := lc.send(ctx, "GET", "/localapi/v0/debug-peer-relay-sessions", 200, nil) + if err != nil { + return nil, fmt.Errorf("error %w: %s", err, body) + } + return decodeJSON[*status.ServerStatus](body) +} + // StreamDebugCapture streams a pcap-formatted packet capture. // // The provided context does not determine the lifetime of the -// returned io.ReadCloser. -func (lc *LocalClient) StreamDebugCapture(ctx context.Context) (io.ReadCloser, error) { +// returned [io.ReadCloser]. +func (lc *Client) StreamDebugCapture(ctx context.Context) (io.ReadCloser, error) { req, err := http.NewRequestWithContext(ctx, "POST", "http://"+apitype.LocalAPIHost+"/localapi/v0/debug-capture", nil) if err != nil { return nil, err @@ -1514,11 +1253,11 @@ func (lc *LocalClient) StreamDebugCapture(ctx context.Context) (io.ReadCloser, e // The context is used for the life of the watch, not just the call to // WatchIPNBus. // -// The returned IPNBusWatcher's Close method must be called when done to release +// The returned [IPNBusWatcher]'s Close method must be called when done to release // resources. // // A default set of ipn.Notify messages are returned but the set can be modified by mask. -func (lc *LocalClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*IPNBusWatcher, error) { +func (lc *Client) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*IPNBusWatcher, error) { req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/watch-ipn-bus?mask="+fmt.Sprint(mask), nil) @@ -1541,10 +1280,10 @@ func (lc *LocalClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) }, nil } -// CheckUpdate returns a tailcfg.ClientVersion indicating whether or not an update is available +// CheckUpdate returns a [*tailcfg.ClientVersion] indicating whether or not an update is available // to be installed via the LocalAPI. In case the LocalAPI can't install updates, it returns a // ClientVersion that says that we are up to date. -func (lc *LocalClient) CheckUpdate(ctx context.Context) (*tailcfg.ClientVersion, error) { +func (lc *Client) CheckUpdate(ctx context.Context) (*tailcfg.ClientVersion, error) { body, err := lc.get200(ctx, "/localapi/v0/update/check") if err != nil { return nil, err @@ -1560,7 +1299,7 @@ func (lc *LocalClient) CheckUpdate(ctx context.Context) (*tailcfg.ClientVersion, // To turn it on, there must have been a previously used exit node. // The most previously used one is reused. // This is a convenience method for GUIs. To select an actual one, update the prefs. -func (lc *LocalClient) SetUseExitNode(ctx context.Context, on bool) error { +func (lc *Client) SetUseExitNode(ctx context.Context, on bool) error { _, err := lc.send(ctx, "POST", "/localapi/v0/set-use-exit-node-enabled?enabled="+strconv.FormatBool(on), http.StatusOK, nil) return err } @@ -1568,7 +1307,7 @@ func (lc *LocalClient) SetUseExitNode(ctx context.Context, on bool) error { // DriveSetServerAddr instructs Taildrive to use the server at addr to access // the filesystem. This is used on platforms like Windows and MacOS to let // Taildrive know to use the file server running in the GUI app. -func (lc *LocalClient) DriveSetServerAddr(ctx context.Context, addr string) error { +func (lc *Client) DriveSetServerAddr(ctx context.Context, addr string) error { _, err := lc.send(ctx, "PUT", "/localapi/v0/drive/fileserver-address", http.StatusCreated, strings.NewReader(addr)) return err } @@ -1576,14 +1315,14 @@ func (lc *LocalClient) DriveSetServerAddr(ctx context.Context, addr string) erro // DriveShareSet adds or updates the given share in the list of shares that // Taildrive will serve to remote nodes. If a share with the same name already // exists, the existing share is replaced/updated. -func (lc *LocalClient) DriveShareSet(ctx context.Context, share *drive.Share) error { +func (lc *Client) DriveShareSet(ctx context.Context, share *drive.Share) error { _, err := lc.send(ctx, "PUT", "/localapi/v0/drive/shares", http.StatusCreated, jsonBody(share)) return err } // DriveShareRemove removes the share with the given name from the list of // shares that Taildrive will serve to remote nodes. -func (lc *LocalClient) DriveShareRemove(ctx context.Context, name string) error { +func (lc *Client) DriveShareRemove(ctx context.Context, name string) error { _, err := lc.send( ctx, "DELETE", @@ -1594,7 +1333,7 @@ func (lc *LocalClient) DriveShareRemove(ctx context.Context, name string) error } // DriveShareRename renames the share from old to new name. -func (lc *LocalClient) DriveShareRename(ctx context.Context, oldName, newName string) error { +func (lc *Client) DriveShareRename(ctx context.Context, oldName, newName string) error { _, err := lc.send( ctx, "POST", @@ -1606,7 +1345,7 @@ func (lc *LocalClient) DriveShareRename(ctx context.Context, oldName, newName st // DriveShareList returns the list of shares that drive is currently serving // to remote nodes. -func (lc *LocalClient) DriveShareList(ctx context.Context) ([]*drive.Share, error) { +func (lc *Client) DriveShareList(ctx context.Context) ([]*drive.Share, error) { result, err := lc.get200(ctx, "/localapi/v0/drive/shares") if err != nil { return nil, err @@ -1617,7 +1356,7 @@ func (lc *LocalClient) DriveShareList(ctx context.Context) ([]*drive.Share, erro } // IPNBusWatcher is an active subscription (watch) of the local tailscaled IPN bus. -// It's returned by LocalClient.WatchIPNBus. +// It's returned by [Client.WatchIPNBus]. // // It must be closed when done. type IPNBusWatcher struct { @@ -1625,7 +1364,7 @@ type IPNBusWatcher struct { httpRes *http.Response dec *json.Decoder - mu sync.Mutex + mu syncs.Mutex closed bool } @@ -1641,7 +1380,7 @@ func (w *IPNBusWatcher) Close() error { } // Next returns the next ipn.Notify from the stream. -// If the context from LocalClient.WatchIPNBus is done, that error is returned. +// If the context from Client.WatchIPNBus is done, that error is returned. func (w *IPNBusWatcher) Next() (ipn.Notify, error) { var n ipn.Notify if err := w.dec.Decode(&n); err != nil { @@ -1654,10 +1393,41 @@ func (w *IPNBusWatcher) Next() (ipn.Notify, error) { } // SuggestExitNode requests an exit node suggestion and returns the exit node's details. -func (lc *LocalClient) SuggestExitNode(ctx context.Context) (apitype.ExitNodeSuggestionResponse, error) { +func (lc *Client) SuggestExitNode(ctx context.Context) (apitype.ExitNodeSuggestionResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/suggest-exit-node") if err != nil { return apitype.ExitNodeSuggestionResponse{}, err } return decodeJSON[apitype.ExitNodeSuggestionResponse](body) } + +// CheckSOMarkInUse reports whether the socket mark option is in use. This will only +// be true if tailscale is running on Linux and tailscaled uses SO_MARK. +func (lc *Client) CheckSOMarkInUse(ctx context.Context) (bool, error) { + body, err := lc.get200(ctx, "/localapi/v0/check-so-mark-in-use") + if err != nil { + return false, err + } + var res struct { + UseSOMark bool `json:"useSoMark"` + } + + if err := json.Unmarshal(body, &res); err != nil { + return false, fmt.Errorf("invalid JSON from check-so-mark-in-use: %w", err) + } + return res.UseSOMark, nil +} + +// ShutdownTailscaled requests a graceful shutdown of tailscaled. +func (lc *Client) ShutdownTailscaled(ctx context.Context) error { + _, err := lc.send(ctx, "POST", "/localapi/v0/shutdown", 200, nil) + return err +} + +func (lc *Client) GetAppConnectorRouteInfo(ctx context.Context) (appctype.RouteInfo, error) { + body, err := lc.get200(ctx, "/localapi/v0/appc-route-info") + if err != nil { + return appctype.RouteInfo{}, err + } + return decodeJSON[appctype.RouteInfo](body) +} diff --git a/client/tailscale/localclient_test.go b/client/local/local_test.go similarity index 87% rename from client/tailscale/localclient_test.go rename to client/local/local_test.go index 950a22f47..0e01e74cd 100644 --- a/client/tailscale/localclient_test.go +++ b/client/local/local_test.go @@ -3,16 +3,16 @@ //go:build go1.19 -package tailscale +package local import ( "context" "net" "net/http" - "net/http/httptest" "testing" "tailscale.com/tstest/deptest" + "tailscale.com/tstest/nettest" "tailscale.com/types/key" ) @@ -36,15 +36,15 @@ func TestGetServeConfigFromJSON(t *testing.T) { } func TestWhoIsPeerNotFound(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nw := nettest.GetNetwork(t) + ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) })) defer ts.Close() - lc := &LocalClient{ + lc := &Client{ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { - var std net.Dialer - return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String()) + return nw.Dial(ctx, network, ts.Listener.Addr().String()) }, } var k key.NodePublic diff --git a/client/local/serve.go b/client/local/serve.go new file mode 100644 index 000000000..51d15e7e5 --- /dev/null +++ b/client/local/serve.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_serve + +package local + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/ipn" +) + +// GetServeConfig return the current serve config. +// +// If the serve config is empty, it returns (nil, nil). +func (lc *Client) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { + body, h, err := lc.sendWithHeaders(ctx, "GET", "/localapi/v0/serve-config", 200, nil, nil) + if err != nil { + return nil, fmt.Errorf("getting serve config: %w", err) + } + sc, err := getServeConfigFromJSON(body) + if err != nil { + return nil, err + } + if sc == nil { + sc = new(ipn.ServeConfig) + } + sc.ETag = h.Get("Etag") + return sc, nil +} + +func getServeConfigFromJSON(body []byte) (sc *ipn.ServeConfig, err error) { + if err := json.Unmarshal(body, &sc); err != nil { + return nil, err + } + return sc, nil +} + +// SetServeConfig sets or replaces the serving settings. +// If config is nil, settings are cleared and serving is disabled. +func (lc *Client) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error { + h := make(http.Header) + if config != nil { + h.Set("If-Match", config.ETag) + } + _, _, err := lc.sendWithHeaders(ctx, "POST", "/localapi/v0/serve-config", 200, jsonBody(config), h) + if err != nil { + return fmt.Errorf("sending serve config: %w", err) + } + return nil +} diff --git a/client/local/syspolicy.go b/client/local/syspolicy.go new file mode 100644 index 000000000..6eff17783 --- /dev/null +++ b/client/local/syspolicy.go @@ -0,0 +1,40 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_syspolicy + +package local + +import ( + "context" + "net/http" + + "tailscale.com/util/syspolicy/setting" +) + +// GetEffectivePolicy returns the effective policy for the specified scope. +func (lc *Client) GetEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) { + scopeID, err := scope.MarshalText() + if err != nil { + return nil, err + } + body, err := lc.get200(ctx, "/localapi/v0/policy/"+string(scopeID)) + if err != nil { + return nil, err + } + return decodeJSON[*setting.Snapshot](body) +} + +// ReloadEffectivePolicy reloads the effective policy for the specified scope +// by reading and merging policy settings from all applicable policy sources. +func (lc *Client) ReloadEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) { + scopeID, err := scope.MarshalText() + if err != nil { + return nil, err + } + body, err := lc.send(ctx, "POST", "/localapi/v0/policy/"+string(scopeID), 200, http.NoBody) + if err != nil { + return nil, err + } + return decodeJSON[*setting.Snapshot](body) +} diff --git a/client/local/tailnetlock.go b/client/local/tailnetlock.go new file mode 100644 index 000000000..9d37d2f35 --- /dev/null +++ b/client/local/tailnetlock.go @@ -0,0 +1,204 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tailnetlock + +package local + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/url" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tka" + "tailscale.com/types/key" + "tailscale.com/types/tkatype" +) + +// NetworkLockStatus fetches information about the tailnet key authority, if one is configured. +func (lc *Client) NetworkLockStatus(ctx context.Context) (*ipnstate.NetworkLockStatus, error) { + body, err := lc.send(ctx, "GET", "/localapi/v0/tka/status", 200, nil) + if err != nil { + return nil, fmt.Errorf("error: %w", err) + } + return decodeJSON[*ipnstate.NetworkLockStatus](body) +} + +// NetworkLockInit initializes the tailnet key authority. +// +// TODO(tom): Plumb through disablement secrets. +func (lc *Client) NetworkLockInit(ctx context.Context, keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) (*ipnstate.NetworkLockStatus, error) { + var b bytes.Buffer + type initRequest struct { + Keys []tka.Key + DisablementValues [][]byte + SupportDisablement []byte + } + + if err := json.NewEncoder(&b).Encode(initRequest{Keys: keys, DisablementValues: disablementValues, SupportDisablement: supportDisablement}); err != nil { + return nil, err + } + + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/init", 200, &b) + if err != nil { + return nil, fmt.Errorf("error: %w", err) + } + return decodeJSON[*ipnstate.NetworkLockStatus](body) +} + +// NetworkLockWrapPreauthKey wraps a pre-auth key with information to +// enable unattended bringup in the locked tailnet. +func (lc *Client) NetworkLockWrapPreauthKey(ctx context.Context, preauthKey string, tkaKey key.NLPrivate) (string, error) { + encodedPrivate, err := tkaKey.MarshalText() + if err != nil { + return "", err + } + + var b bytes.Buffer + type wrapRequest struct { + TSKey string + TKAKey string // key.NLPrivate.MarshalText + } + if err := json.NewEncoder(&b).Encode(wrapRequest{TSKey: preauthKey, TKAKey: string(encodedPrivate)}); err != nil { + return "", err + } + + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/wrap-preauth-key", 200, &b) + if err != nil { + return "", fmt.Errorf("error: %w", err) + } + return string(body), nil +} + +// NetworkLockModify adds and/or removes key(s) to the tailnet key authority. +func (lc *Client) NetworkLockModify(ctx context.Context, addKeys, removeKeys []tka.Key) error { + var b bytes.Buffer + type modifyRequest struct { + AddKeys []tka.Key + RemoveKeys []tka.Key + } + + if err := json.NewEncoder(&b).Encode(modifyRequest{AddKeys: addKeys, RemoveKeys: removeKeys}); err != nil { + return err + } + + if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/modify", 204, &b); err != nil { + return fmt.Errorf("error: %w", err) + } + return nil +} + +// NetworkLockSign signs the specified node-key and transmits that signature to the control plane. +// rotationPublic, if specified, must be an ed25519 public key. +func (lc *Client) NetworkLockSign(ctx context.Context, nodeKey key.NodePublic, rotationPublic []byte) error { + var b bytes.Buffer + type signRequest struct { + NodeKey key.NodePublic + RotationPublic []byte + } + + if err := json.NewEncoder(&b).Encode(signRequest{NodeKey: nodeKey, RotationPublic: rotationPublic}); err != nil { + return err + } + + if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/sign", 200, &b); err != nil { + return fmt.Errorf("error: %w", err) + } + return nil +} + +// NetworkLockAffectedSigs returns all signatures signed by the specified keyID. +func (lc *Client) NetworkLockAffectedSigs(ctx context.Context, keyID tkatype.KeyID) ([]tkatype.MarshaledSignature, error) { + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/affected-sigs", 200, bytes.NewReader(keyID)) + if err != nil { + return nil, fmt.Errorf("error: %w", err) + } + return decodeJSON[[]tkatype.MarshaledSignature](body) +} + +// NetworkLockLog returns up to maxEntries number of changes to network-lock state. +func (lc *Client) NetworkLockLog(ctx context.Context, maxEntries int) ([]ipnstate.NetworkLockUpdate, error) { + v := url.Values{} + v.Set("limit", fmt.Sprint(maxEntries)) + body, err := lc.send(ctx, "GET", "/localapi/v0/tka/log?"+v.Encode(), 200, nil) + if err != nil { + return nil, fmt.Errorf("error %w: %s", err, body) + } + return decodeJSON[[]ipnstate.NetworkLockUpdate](body) +} + +// NetworkLockForceLocalDisable forcibly shuts down network lock on this node. +func (lc *Client) NetworkLockForceLocalDisable(ctx context.Context) error { + // This endpoint expects an empty JSON stanza as the payload. + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(struct{}{}); err != nil { + return err + } + + if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/force-local-disable", 200, &b); err != nil { + return fmt.Errorf("error: %w", err) + } + return nil +} + +// NetworkLockVerifySigningDeeplink verifies the network lock deeplink contained +// in url and returns information extracted from it. +func (lc *Client) NetworkLockVerifySigningDeeplink(ctx context.Context, url string) (*tka.DeeplinkValidationResult, error) { + vr := struct { + URL string + }{url} + + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/verify-deeplink", 200, jsonBody(vr)) + if err != nil { + return nil, fmt.Errorf("sending verify-deeplink: %w", err) + } + + return decodeJSON[*tka.DeeplinkValidationResult](body) +} + +// NetworkLockGenRecoveryAUM generates an AUM for recovering from a tailnet-lock key compromise. +func (lc *Client) NetworkLockGenRecoveryAUM(ctx context.Context, removeKeys []tkatype.KeyID, forkFrom tka.AUMHash) ([]byte, error) { + vr := struct { + Keys []tkatype.KeyID + ForkFrom string + }{removeKeys, forkFrom.String()} + + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/generate-recovery-aum", 200, jsonBody(vr)) + if err != nil { + return nil, fmt.Errorf("sending generate-recovery-aum: %w", err) + } + + return body, nil +} + +// NetworkLockCosignRecoveryAUM co-signs a recovery AUM using the node's tailnet lock key. +func (lc *Client) NetworkLockCosignRecoveryAUM(ctx context.Context, aum tka.AUM) ([]byte, error) { + r := bytes.NewReader(aum.Serialize()) + body, err := lc.send(ctx, "POST", "/localapi/v0/tka/cosign-recovery-aum", 200, r) + if err != nil { + return nil, fmt.Errorf("sending cosign-recovery-aum: %w", err) + } + + return body, nil +} + +// NetworkLockSubmitRecoveryAUM submits a recovery AUM to the control plane. +func (lc *Client) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka.AUM) error { + r := bytes.NewReader(aum.Serialize()) + _, err := lc.send(ctx, "POST", "/localapi/v0/tka/submit-recovery-aum", 200, r) + if err != nil { + return fmt.Errorf("sending cosign-recovery-aum: %w", err) + } + return nil +} + +// NetworkLockDisable shuts down network-lock across the tailnet. +func (lc *Client) NetworkLockDisable(ctx context.Context, secret []byte) error { + if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/disable", 200, bytes.NewReader(secret)); err != nil { + return fmt.Errorf("error: %w", err) + } + return nil +} diff --git a/client/systray/logo.go b/client/systray/logo.go new file mode 100644 index 000000000..3467d1b74 --- /dev/null +++ b/client/systray/logo.go @@ -0,0 +1,327 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +package systray + +import ( + "bytes" + "context" + "image" + "image/color" + "image/png" + "runtime" + "sync" + "time" + + "fyne.io/systray" + ico "github.com/Kodeworks/golang-image-ico" + "github.com/fogleman/gg" +) + +// tsLogo represents the Tailscale logo displayed as the systray icon. +type tsLogo struct { + // dots represents the state of the 3x3 dot grid in the logo. + // A 0 represents a gray dot, any other value is a white dot. + dots [9]byte + + // dotMask returns an image mask to be used when rendering the logo dots. + dotMask func(dc *gg.Context, borderUnits int, radius int) *image.Alpha + + // overlay is called after the dots are rendered to draw an additional overlay. + overlay func(dc *gg.Context, borderUnits int, radius int) +} + +var ( + // disconnected is all gray dots + disconnected = tsLogo{dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 0, 0, 0, + }} + + // connected is the normal Tailscale logo + connected = tsLogo{dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 1, 0, + }} + + // loading is a special tsLogo value that is not meant to be rendered directly, + // but indicates that the loading animation should be shown. + loading = tsLogo{dots: [9]byte{'l', 'o', 'a', 'd', 'i', 'n', 'g'}} + + // loadingIcons are shown in sequence as an animated loading icon. + loadingLogos = []tsLogo{ + {dots: [9]byte{ + 0, 1, 1, + 1, 0, 1, + 0, 0, 1, + }}, + {dots: [9]byte{ + 0, 1, 1, + 0, 0, 1, + 0, 1, 0, + }}, + {dots: [9]byte{ + 0, 1, 1, + 0, 0, 0, + 0, 0, 1, + }}, + {dots: [9]byte{ + 0, 0, 1, + 0, 1, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 1, 0, + 0, 0, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 1, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 1, + 0, 0, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 1, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 1, 1, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 0, 0, + 1, 1, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 1, 0, + 0, 1, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 1, 0, + 0, 1, 1, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 0, 1, + }}, + {dots: [9]byte{ + 0, 1, 0, + 0, 1, 1, + 1, 0, 1, + }}, + } + + // exitNodeOnline is the Tailscale logo with an additional arrow overlay in the corner. + exitNodeOnline = tsLogo{ + dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 1, 0, + }, + // draw an arrow mask in the bottom right corner with a reasonably thick line width. + dotMask: func(dc *gg.Context, borderUnits int, radius int) *image.Alpha { + bu, r := float64(borderUnits), float64(radius) + + x1 := r * (bu + 3.5) + y := r * (bu + 7) + x2 := x1 + (r * 5) + + mc := gg.NewContext(dc.Width(), dc.Height()) + mc.DrawLine(x1, y, x2, y) // arrow center line + mc.DrawLine(x2-(1.5*r), y-(1.5*r), x2, y) // top of arrow tip + mc.DrawLine(x2-(1.5*r), y+(1.5*r), x2, y) // bottom of arrow tip + mc.SetLineWidth(r * 3) + mc.Stroke() + return mc.AsMask() + }, + // draw an arrow in the bottom right corner over the masked area. + overlay: func(dc *gg.Context, borderUnits int, radius int) { + bu, r := float64(borderUnits), float64(radius) + + x1 := r * (bu + 3.5) + y := r * (bu + 7) + x2 := x1 + (r * 5) + + dc.DrawLine(x1, y, x2, y) // arrow center line + dc.DrawLine(x2-(1.5*r), y-(1.5*r), x2, y) // top of arrow tip + dc.DrawLine(x2-(1.5*r), y+(1.5*r), x2, y) // bottom of arrow tip + dc.SetColor(fg) + dc.SetLineWidth(r) + dc.Stroke() + }, + } + + // exitNodeOffline is the Tailscale logo with a red "x" in the corner. + exitNodeOffline = tsLogo{ + dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 1, 0, + }, + // Draw a square that hides the four dots in the bottom right corner, + dotMask: func(dc *gg.Context, borderUnits int, radius int) *image.Alpha { + bu, r := float64(borderUnits), float64(radius) + x := r * (bu + 3) + + mc := gg.NewContext(dc.Width(), dc.Height()) + mc.DrawRectangle(x, x, r*6, r*6) + mc.Fill() + return mc.AsMask() + }, + // draw a red "x" over the bottom right corner. + overlay: func(dc *gg.Context, borderUnits int, radius int) { + bu, r := float64(borderUnits), float64(radius) + + x1 := r * (bu + 4) + x2 := x1 + (r * 3.5) + dc.DrawLine(x1, x1, x2, x2) // top-left to bottom-right stroke + dc.DrawLine(x1, x2, x2, x1) // bottom-left to top-right stroke + dc.SetColor(red) + dc.SetLineWidth(r) + dc.Stroke() + }, + } +) + +var ( + bg = color.NRGBA{0, 0, 0, 255} + fg = color.NRGBA{255, 255, 255, 255} + gray = color.NRGBA{255, 255, 255, 102} + red = color.NRGBA{229, 111, 74, 255} +) + +// render returns a PNG image of the logo. +func (logo tsLogo) render() *bytes.Buffer { + const borderUnits = 1 + return logo.renderWithBorder(borderUnits) +} + +// renderWithBorder returns a PNG image of the logo with the specified border width. +// One border unit is equal to the radius of a tailscale logo dot. +func (logo tsLogo) renderWithBorder(borderUnits int) *bytes.Buffer { + const radius = 25 + dim := radius * (8 + borderUnits*2) + + dc := gg.NewContext(dim, dim) + dc.DrawRectangle(0, 0, float64(dim), float64(dim)) + dc.SetColor(bg) + dc.Fill() + + if logo.dotMask != nil { + mask := logo.dotMask(dc, borderUnits, radius) + dc.SetMask(mask) + dc.InvertMask() + } + + for y := 0; y < 3; y++ { + for x := 0; x < 3; x++ { + px := (borderUnits + 1 + 3*x) * radius + py := (borderUnits + 1 + 3*y) * radius + col := fg + if logo.dots[y*3+x] == 0 { + col = gray + } + dc.DrawCircle(float64(px), float64(py), radius) + dc.SetColor(col) + dc.Fill() + } + } + + if logo.overlay != nil { + dc.ResetClip() + logo.overlay(dc, borderUnits, radius) + } + + b := bytes.NewBuffer(nil) + + // Encode as ICO format on Windows, PNG on all other platforms. + if runtime.GOOS == "windows" { + _ = ico.Encode(b, dc.Image()) + } else { + _ = png.Encode(b, dc.Image()) + } + return b +} + +// setAppIcon renders logo and sets it as the systray icon. +func setAppIcon(icon tsLogo) { + if icon.dots == loading.dots { + startLoadingAnimation() + } else { + stopLoadingAnimation() + systray.SetIcon(icon.render().Bytes()) + } +} + +var ( + loadingMu sync.Mutex // protects loadingCancel + + // loadingCancel stops the loading animation in the systray icon. + // This is nil if the animation is not currently active. + loadingCancel func() +) + +// startLoadingAnimation starts the animated loading icon in the system tray. +// The animation continues until [stopLoadingAnimation] is called. +// If the loading animation is already active, this func does nothing. +func startLoadingAnimation() { + loadingMu.Lock() + defer loadingMu.Unlock() + + if loadingCancel != nil { + // loading icon already displayed + return + } + + ctx := context.Background() + ctx, loadingCancel = context.WithCancel(ctx) + + go func() { + t := time.NewTicker(500 * time.Millisecond) + var i int + for { + select { + case <-ctx.Done(): + return + case <-t.C: + systray.SetIcon(loadingLogos[i].render().Bytes()) + i++ + if i >= len(loadingLogos) { + i = 0 + } + } + } + }() +} + +// stopLoadingAnimation stops the animated loading icon in the system tray. +// If the loading animation is not currently active, this func does nothing. +func stopLoadingAnimation() { + loadingMu.Lock() + defer loadingMu.Unlock() + + if loadingCancel != nil { + loadingCancel() + loadingCancel = nil + } +} diff --git a/client/systray/startup-creator.go b/client/systray/startup-creator.go new file mode 100644 index 000000000..cb354856d --- /dev/null +++ b/client/systray/startup-creator.go @@ -0,0 +1,76 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +// Package systray provides a minimal Tailscale systray application. +package systray + +import ( + "bufio" + "bytes" + _ "embed" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" +) + +//go:embed tailscale-systray.service +var embedSystemd string + +func InstallStartupScript(initSystem string) error { + switch initSystem { + case "systemd": + return installSystemd() + default: + return fmt.Errorf("unsupported init system '%s'", initSystem) + } +} + +func installSystemd() error { + // Find the path to tailscale, just in case it's not where the example file + // has it placed, and replace that before writing the file. + tailscaleBin, err := exec.LookPath("tailscale") + if err != nil { + return fmt.Errorf("failed to find tailscale binary %w", err) + } + + var output bytes.Buffer + scanner := bufio.NewScanner(strings.NewReader(embedSystemd)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "ExecStart=") { + line = fmt.Sprintf("ExecStart=%s systray", tailscaleBin) + } + output.WriteString(line + "\n") + } + + configDir, err := os.UserConfigDir() + if err != nil { + homeDir, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("unable to locate user home: %w", err) + } + configDir = filepath.Join(homeDir, ".config") + } + + systemdDir := filepath.Join(configDir, "systemd", "user") + if err := os.MkdirAll(systemdDir, 0o755); err != nil { + return fmt.Errorf("failed creating systemd uuser dir: %w", err) + } + + serviceFile := filepath.Join(systemdDir, "tailscale-systray.service") + + if err := os.WriteFile(serviceFile, output.Bytes(), 0o755); err != nil { + return fmt.Errorf("failed writing systemd user service: %w", err) + } + + fmt.Printf("Successfully installed systemd service to: %s\n", serviceFile) + fmt.Println("To enable and start the service, run:") + fmt.Println(" systemctl --user daemon-reload") + fmt.Println(" systemctl --user enable --now tailscale-systray") + + return nil +} diff --git a/client/systray/systray.go b/client/systray/systray.go new file mode 100644 index 000000000..bc099a1ec --- /dev/null +++ b/client/systray/systray.go @@ -0,0 +1,801 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +// Package systray provides a minimal Tailscale systray application. +package systray + +import ( + "bytes" + "context" + "errors" + "fmt" + "image" + "io" + "log" + "net/http" + "os" + "os/signal" + "runtime" + "slices" + "strings" + "sync" + "syscall" + "time" + + "fyne.io/systray" + ico "github.com/Kodeworks/golang-image-ico" + "github.com/atotto/clipboard" + dbus "github.com/godbus/dbus/v5" + "github.com/toqueteos/webbrowser" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/util/slicesx" + "tailscale.com/util/stringsx" +) + +var ( + // newMenuDelay is the amount of time to sleep after creating a new menu, + // but before adding items to it. This works around a bug in some dbus implementations. + newMenuDelay time.Duration + + // if true, treat all mullvad exit node countries as single-city. + // Instead of rendering a submenu with cities, just select the highest-priority peer. + hideMullvadCities bool +) + +// Run starts the systray menu and blocks until the menu exits. +// If client is nil, a default local.Client is used. +func (menu *Menu) Run(client *local.Client) { + if client == nil { + client = &local.Client{} + } + menu.lc = client + menu.updateState() + + // exit cleanly on SIGINT and SIGTERM + go func() { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) + select { + case <-interrupt: + menu.onExit() + case <-menu.bgCtx.Done(): + } + }() + go menu.lc.IncrementGauge(menu.bgCtx, "systray_running", 1) + defer menu.lc.IncrementGauge(menu.bgCtx, "systray_running", -1) + + systray.Run(menu.onReady, menu.onExit) +} + +// Menu represents the systray menu, its items, and the current Tailscale state. +type Menu struct { + mu sync.Mutex // protects the entire Menu + + lc *local.Client + status *ipnstate.Status + curProfile ipn.LoginProfile + allProfiles []ipn.LoginProfile + + // readonly is whether the systray app is running in read-only mode. + // This is set if LocalAPI returns a permission error, + // typically because the user needs to run `tailscale set --operator=$USER`. + readonly bool + + bgCtx context.Context // ctx for background tasks not involving menu item clicks + bgCancel context.CancelFunc + + // Top-level menu items + connect *systray.MenuItem + disconnect *systray.MenuItem + self *systray.MenuItem + exitNodes *systray.MenuItem + more *systray.MenuItem + rebuildMenu *systray.MenuItem + quit *systray.MenuItem + + rebuildCh chan struct{} // triggers a menu rebuild + accountsCh chan ipn.ProfileID + exitNodeCh chan tailcfg.StableNodeID // ID of selected exit node + + eventCancel context.CancelFunc // cancel eventLoop + + notificationIcon *os.File // icon used for desktop notifications +} + +func (menu *Menu) init() { + if menu.bgCtx != nil { + // already initialized + return + } + + menu.rebuildCh = make(chan struct{}, 1) + menu.accountsCh = make(chan ipn.ProfileID) + menu.exitNodeCh = make(chan tailcfg.StableNodeID) + + // dbus wants a file path for notification icons, so copy to a temp file. + menu.notificationIcon, _ = os.CreateTemp("", "tailscale-systray.png") + io.Copy(menu.notificationIcon, connected.renderWithBorder(3)) + + menu.bgCtx, menu.bgCancel = context.WithCancel(context.Background()) + go menu.watchIPNBus() +} + +func init() { + if runtime.GOOS != "linux" { + // so far, these tweaks are only needed on Linux + return + } + + desktop := strings.ToLower(os.Getenv("XDG_CURRENT_DESKTOP")) + switch desktop { + case "gnome", "ubuntu:gnome": + // GNOME expands submenus downward in the main menu, rather than flyouts to the side. + // Either as a result of that or another limitation, there seems to be a maximum depth of submenus. + // Mullvad countries that have a city submenu are not being rendered, and so can't be selected. + // Handle this by simply treating all mullvad countries as single-city and select the best peer. + hideMullvadCities = true + case "kde": + // KDE doesn't need a delay, and actually won't render submenus + // if we delay for more than about 400Âĩs. + newMenuDelay = 0 + default: + // Add a slight delay to ensure the menu is created before adding items. + // + // Systray implementations that use libdbusmenu sometimes process messages out of order, + // resulting in errors such as: + // (waybar:153009): LIBDBUSMENU-GTK-WARNING **: 18:07:11.551: Children but no menu, someone's been naughty with their 'children-display' property: 'submenu' + // + // See also: https://github.com/fyne-io/systray/issues/12 + newMenuDelay = 10 * time.Millisecond + } +} + +// onReady is called by the systray package when the menu is ready to be built. +func (menu *Menu) onReady() { + log.Printf("starting") + if os.Getuid() == 0 || os.Getuid() != os.Geteuid() || os.Getenv("SUDO_USER") != "" || os.Getenv("DOAS_USER") != "" { + fmt.Fprintln(os.Stderr, ` +It appears that you might be running the systray with sudo/doas. +This can lead to issues with D-Bus, and should be avoided. + +The systray application should be run with the same user as your desktop session. +This usually means that you should run the application like: + +tailscale systray + +See https://tailscale.com/kb/1597/linux-systray for more information.`) + } + setAppIcon(disconnected) + menu.rebuild() + + menu.mu.Lock() + if menu.readonly { + fmt.Fprintln(os.Stderr, ` +No permission to manage Tailscale. Set operator by running: + +sudo tailscale set --operator=$USER + +See https://tailscale.com/s/cli-operator for more information.`) + } + menu.mu.Unlock() +} + +// updateState updates the Menu state from the Tailscale local client. +func (menu *Menu) updateState() { + menu.mu.Lock() + defer menu.mu.Unlock() + menu.init() + + menu.readonly = false + + var err error + menu.status, err = menu.lc.Status(menu.bgCtx) + if err != nil { + log.Print(err) + } + menu.curProfile, menu.allProfiles, err = menu.lc.ProfileStatus(menu.bgCtx) + if err != nil { + if local.IsAccessDeniedError(err) { + menu.readonly = true + } + log.Print(err) + } +} + +// rebuild the systray menu based on the current Tailscale state. +// +// We currently rebuild the entire menu because it is not easy to update the existing menu. +// You cannot iterate over the items in a menu, nor can you remove some items like separators. +// So for now we rebuild the whole thing, and can optimize this later if needed. +func (menu *Menu) rebuild() { + menu.mu.Lock() + defer menu.mu.Unlock() + menu.init() + + if menu.eventCancel != nil { + menu.eventCancel() + } + ctx := context.Background() + ctx, menu.eventCancel = context.WithCancel(ctx) + + systray.ResetMenu() + + if menu.readonly { + const readonlyMsg = "No permission to manage Tailscale.\nSee tailscale.com/s/cli-operator" + m := systray.AddMenuItem(readonlyMsg, "") + onClick(ctx, m, func(_ context.Context) { + webbrowser.Open("https://tailscale.com/s/cli-operator") + }) + systray.AddSeparator() + } + + menu.connect = systray.AddMenuItem("Connect", "") + menu.disconnect = systray.AddMenuItem("Disconnect", "") + menu.disconnect.Hide() + systray.AddSeparator() + + // delay to prevent race setting icon on first start + time.Sleep(newMenuDelay) + + // Set systray menu icon and title. + // Also adjust connect/disconnect menu items if needed. + var backendState string + if menu.status != nil { + backendState = menu.status.BackendState + } + switch backendState { + case ipn.Running.String(): + if menu.status.ExitNodeStatus != nil && !menu.status.ExitNodeStatus.ID.IsZero() { + if menu.status.ExitNodeStatus.Online { + setTooltip("Using exit node") + setAppIcon(exitNodeOnline) + } else { + setTooltip("Exit node offline") + setAppIcon(exitNodeOffline) + } + } else { + setTooltip(fmt.Sprintf("Connected to %s", menu.status.CurrentTailnet.Name)) + setAppIcon(connected) + } + menu.connect.SetTitle("Connected") + menu.connect.Disable() + menu.disconnect.Show() + menu.disconnect.Enable() + case ipn.Starting.String(): + setTooltip("Connecting") + setAppIcon(loading) + default: + setTooltip("Disconnected") + setAppIcon(disconnected) + } + + if menu.readonly { + menu.connect.Disable() + menu.disconnect.Disable() + } + + account := "Account" + if pt := profileTitle(menu.curProfile); pt != "" { + account = pt + } + if !menu.readonly { + accounts := systray.AddMenuItem(account, "") + setRemoteIcon(accounts, menu.curProfile.UserProfile.ProfilePicURL) + time.Sleep(newMenuDelay) + for _, profile := range menu.allProfiles { + title := profileTitle(profile) + var item *systray.MenuItem + if profile.ID == menu.curProfile.ID { + item = accounts.AddSubMenuItemCheckbox(title, "", true) + } else { + item = accounts.AddSubMenuItem(title, "") + } + setRemoteIcon(item, profile.UserProfile.ProfilePicURL) + onClick(ctx, item, func(ctx context.Context) { + select { + case <-ctx.Done(): + case menu.accountsCh <- profile.ID: + } + }) + } + } + + if menu.status != nil && menu.status.Self != nil && len(menu.status.Self.TailscaleIPs) > 0 { + title := fmt.Sprintf("This Device: %s (%s)", menu.status.Self.HostName, menu.status.Self.TailscaleIPs[0]) + menu.self = systray.AddMenuItem(title, "") + } else { + menu.self = systray.AddMenuItem("This Device: not connected", "") + menu.self.Disable() + } + systray.AddSeparator() + + if !menu.readonly { + menu.rebuildExitNodeMenu(ctx) + } + + menu.more = systray.AddMenuItem("More settings", "") + if menu.status != nil && menu.status.BackendState == "Running" { + // web client is only available if backend is running + onClick(ctx, menu.more, func(_ context.Context) { + webbrowser.Open("http://100.100.100.100/") + }) + } else { + menu.more.Disable() + } + + // TODO(#15528): this menu item shouldn't be necessary at all, + // but is at least more discoverable than having users switch profiles or exit nodes. + menu.rebuildMenu = systray.AddMenuItem("Rebuild menu", "Fix missing menu items") + onClick(ctx, menu.rebuildMenu, func(ctx context.Context) { + select { + case <-ctx.Done(): + case menu.rebuildCh <- struct{}{}: + } + }) + menu.rebuildMenu.Enable() + + menu.quit = systray.AddMenuItem("Quit", "Quit the app") + menu.quit.Enable() + + go menu.eventLoop(ctx) +} + +// profileTitle returns the title string for a profile menu item. +func profileTitle(profile ipn.LoginProfile) string { + title := profile.Name + if profile.NetworkProfile.DomainName != "" { + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + // windows and mac don't support multi-line menu + title += " (" + profile.NetworkProfile.DisplayNameOrDefault() + ")" + } else { + title += "\n" + profile.NetworkProfile.DisplayNameOrDefault() + } + } + return title +} + +var ( + cacheMu sync.Mutex + httpCache = map[string][]byte{} // URL => response body +) + +// setRemoteIcon sets the icon for menu to the specified remote image. +// Remote images are fetched as needed and cached. +func setRemoteIcon(menu *systray.MenuItem, urlStr string) { + if menu == nil || urlStr == "" { + return + } + + cacheMu.Lock() + b, ok := httpCache[urlStr] + if !ok { + resp, err := http.Get(urlStr) + if err == nil && resp.StatusCode == http.StatusOK { + b, _ = io.ReadAll(resp.Body) + + // Convert image to ICO format on Windows + if runtime.GOOS == "windows" { + im, _, err := image.Decode(bytes.NewReader(b)) + if err != nil { + return + } + buf := bytes.NewBuffer(nil) + if err := ico.Encode(buf, im); err != nil { + return + } + b = buf.Bytes() + } + + httpCache[urlStr] = b + resp.Body.Close() + } + } + cacheMu.Unlock() + + if len(b) > 0 { + menu.SetIcon(b) + } +} + +// setTooltip sets the tooltip text for the systray icon. +func setTooltip(text string) { + if runtime.GOOS == "darwin" || runtime.GOOS == "windows" { + systray.SetTooltip(text) + } else { + // on Linux, SetTitle actually sets the tooltip + systray.SetTitle(text) + } +} + +// eventLoop is the main event loop for handling click events on menu items +// and responding to Tailscale state changes. +// This method does not return until ctx.Done is closed. +func (menu *Menu) eventLoop(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-menu.rebuildCh: + menu.updateState() + menu.rebuild() + case <-menu.connect.ClickedCh: + _, err := menu.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + WantRunning: true, + }, + WantRunningSet: true, + }) + if err != nil { + log.Printf("error connecting: %v", err) + } + + case <-menu.disconnect.ClickedCh: + _, err := menu.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + WantRunning: false, + }, + WantRunningSet: true, + }) + if err != nil { + log.Printf("error disconnecting: %v", err) + } + + case <-menu.self.ClickedCh: + menu.copyTailscaleIP(menu.status.Self) + + case id := <-menu.accountsCh: + if err := menu.lc.SwitchProfile(ctx, id); err != nil { + log.Printf("error switching to profile ID %v: %v", id, err) + } + + case exitNode := <-menu.exitNodeCh: + if exitNode.IsZero() { + log.Print("disable exit node") + if err := menu.lc.SetUseExitNode(ctx, false); err != nil { + log.Printf("error disabling exit node: %v", err) + } + } else { + log.Printf("enable exit node: %v", exitNode) + mp := &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: exitNode, + }, + ExitNodeIDSet: true, + } + if _, err := menu.lc.EditPrefs(ctx, mp); err != nil { + log.Printf("error setting exit node: %v", err) + } + } + + case <-menu.quit.ClickedCh: + systray.Quit() + } + } +} + +// onClick registers a click handler for a menu item. +func onClick(ctx context.Context, item *systray.MenuItem, fn func(ctx context.Context)) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-item.ClickedCh: + fn(ctx) + } + } + }() +} + +// watchIPNBus subscribes to the tailscale event bus and sends state updates to chState. +// This method does not return. +func (menu *Menu) watchIPNBus() { + for { + if err := menu.watchIPNBusInner(); err != nil { + log.Println(err) + if errors.Is(err, context.Canceled) { + // If the context got canceled, we will never be able to + // reconnect to IPN bus, so exit the process. + log.Fatalf("watchIPNBus: %v", err) + } + } + // If our watch connection breaks, wait a bit before reconnecting. No + // reason to spam the logs if e.g. tailscaled is restarting or goes + // down. + time.Sleep(3 * time.Second) + } +} + +func (menu *Menu) watchIPNBusInner() error { + watcher, err := menu.lc.WatchIPNBus(menu.bgCtx, 0) + if err != nil { + return fmt.Errorf("watching ipn bus: %w", err) + } + defer watcher.Close() + for { + select { + case <-menu.bgCtx.Done(): + return nil + default: + n, err := watcher.Next() + if err != nil { + return fmt.Errorf("ipnbus error: %w", err) + } + var rebuild bool + if n.State != nil { + log.Printf("new state: %v", n.State) + rebuild = true + } + if n.Prefs != nil { + rebuild = true + } + if rebuild { + menu.rebuildCh <- struct{}{} + } + } + } +} + +// copyTailscaleIP copies the first Tailscale IP of the given device to the clipboard +// and sends a notification with the copied value. +func (menu *Menu) copyTailscaleIP(device *ipnstate.PeerStatus) { + if device == nil || len(device.TailscaleIPs) == 0 { + return + } + name := strings.Split(device.DNSName, ".")[0] + ip := device.TailscaleIPs[0].String() + err := clipboard.WriteAll(ip) + if err != nil { + log.Printf("clipboard error: %v", err) + } else { + menu.sendNotification(fmt.Sprintf("Copied Address for %v", name), ip) + } +} + +// sendNotification sends a desktop notification with the given title and content. +func (menu *Menu) sendNotification(title, content string) { + conn, err := dbus.SessionBus() + if err != nil { + log.Printf("dbus: %v", err) + return + } + timeout := 3 * time.Second + obj := conn.Object("org.freedesktop.Notifications", "/org/freedesktop/Notifications") + call := obj.Call("org.freedesktop.Notifications.Notify", 0, "Tailscale", uint32(0), + menu.notificationIcon.Name(), title, content, []string{}, map[string]dbus.Variant{}, int32(timeout.Milliseconds())) + if call.Err != nil { + log.Printf("dbus: %v", call.Err) + } +} + +func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { + if menu.status == nil { + return + } + + status := menu.status + menu.exitNodes = systray.AddMenuItem("Exit Nodes", "") + time.Sleep(newMenuDelay) + + // register a click handler for a menu item to set nodeID as the exit node. + setExitNodeOnClick := func(item *systray.MenuItem, nodeID tailcfg.StableNodeID) { + onClick(ctx, item, func(ctx context.Context) { + select { + case <-ctx.Done(): + case menu.exitNodeCh <- nodeID: + } + }) + } + + noExitNodeMenu := menu.exitNodes.AddSubMenuItemCheckbox("None", "", status.ExitNodeStatus == nil) + setExitNodeOnClick(noExitNodeMenu, "") + + // Show recommended exit node if available. + if status.Self.CapMap.Contains(tailcfg.NodeAttrSuggestExitNodeUI) { + sugg, err := menu.lc.SuggestExitNode(ctx) + if err == nil { + title := "Recommended: " + if loc := sugg.Location; loc.Valid() && loc.Country() != "" { + flag := countryFlag(loc.CountryCode()) + title += fmt.Sprintf("%s %s: %s", flag, loc.Country(), loc.City()) + } else { + title += strings.Split(sugg.Name, ".")[0] + } + menu.exitNodes.AddSeparator() + rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", false) + setExitNodeOnClick(rm, sugg.ID) + if status.ExitNodeStatus != nil && sugg.ID == status.ExitNodeStatus.ID { + rm.Check() + } + } + } + + // Add tailnet exit nodes if present. + var tailnetExitNodes []*ipnstate.PeerStatus + for _, ps := range status.Peer { + if ps.ExitNodeOption && ps.Location == nil { + tailnetExitNodes = append(tailnetExitNodes, ps) + } + } + if len(tailnetExitNodes) > 0 { + menu.exitNodes.AddSeparator() + menu.exitNodes.AddSubMenuItem("Tailnet Exit Nodes", "").Disable() + for _, ps := range status.Peer { + if !ps.ExitNodeOption || ps.Location != nil { + continue + } + name := strings.Split(ps.DNSName, ".")[0] + if !ps.Online { + name += " (offline)" + } + sm := menu.exitNodes.AddSubMenuItemCheckbox(name, "", false) + if !ps.Online { + sm.Disable() + } + if status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID { + sm.Check() + } + setExitNodeOnClick(sm, ps.ID) + } + } + + // Add mullvad exit nodes if present. + var mullvadExitNodes mullvadPeers + if status.Self.CapMap.Contains("mullvad") { + mullvadExitNodes = newMullvadPeers(status) + } + if len(mullvadExitNodes.countries) > 0 { + menu.exitNodes.AddSeparator() + menu.exitNodes.AddSubMenuItem("Location-based Exit Nodes", "").Disable() + mullvadMenu := menu.exitNodes.AddSubMenuItemCheckbox("Mullvad VPN", "", false) + + for _, country := range mullvadExitNodes.sortedCountries() { + flag := countryFlag(country.code) + countryMenu := mullvadMenu.AddSubMenuItemCheckbox(flag+" "+country.name, "", false) + + // single-city country, no submenu + if len(country.cities) == 1 || hideMullvadCities { + setExitNodeOnClick(countryMenu, country.best.ID) + if status.ExitNodeStatus != nil { + for _, city := range country.cities { + for _, ps := range city.peers { + if status.ExitNodeStatus.ID == ps.ID { + mullvadMenu.Check() + countryMenu.Check() + } + } + } + } + continue + } + + // multi-city country, build submenu with "best available" option and cities. + time.Sleep(newMenuDelay) + bm := countryMenu.AddSubMenuItemCheckbox("Best Available", "", false) + setExitNodeOnClick(bm, country.best.ID) + countryMenu.AddSeparator() + + for _, city := range country.sortedCities() { + cityMenu := countryMenu.AddSubMenuItemCheckbox(city.name, "", false) + setExitNodeOnClick(cityMenu, city.best.ID) + if status.ExitNodeStatus != nil { + for _, ps := range city.peers { + if status.ExitNodeStatus.ID == ps.ID { + mullvadMenu.Check() + countryMenu.Check() + cityMenu.Check() + } + } + } + } + } + } + + // TODO: "Allow Local Network Access" and "Run Exit Node" menu items +} + +// mullvadPeers contains all mullvad peer nodes, sorted by country and city. +type mullvadPeers struct { + countries map[string]*mvCountry // country code (uppercase) => country +} + +// sortedCountries returns countries containing mullvad nodes, sorted by name. +func (mp mullvadPeers) sortedCountries() []*mvCountry { + countries := slicesx.MapValues(mp.countries) + slices.SortFunc(countries, func(a, b *mvCountry) int { + return stringsx.CompareFold(a.name, b.name) + }) + return countries +} + +type mvCountry struct { + code string + name string + best *ipnstate.PeerStatus // highest priority peer in the country + cities map[string]*mvCity // city code => city +} + +// sortedCities returns cities containing mullvad nodes, sorted by name. +func (mc *mvCountry) sortedCities() []*mvCity { + cities := slicesx.MapValues(mc.cities) + slices.SortFunc(cities, func(a, b *mvCity) int { + return stringsx.CompareFold(a.name, b.name) + }) + return cities +} + +// countryFlag takes a 2-character ASCII string and returns the corresponding emoji flag. +// It returns the empty string on error. +func countryFlag(code string) string { + if len(code) != 2 { + return "" + } + runes := make([]rune, 0, 2) + for i := range 2 { + b := code[i] | 32 // lowercase + if b < 'a' || b > 'z' { + return "" + } + // https://en.wikipedia.org/wiki/Regional_indicator_symbol + runes = append(runes, 0x1F1E6+rune(b-'a')) + } + return string(runes) +} + +type mvCity struct { + name string + best *ipnstate.PeerStatus // highest priority peer in the city + peers []*ipnstate.PeerStatus +} + +func newMullvadPeers(status *ipnstate.Status) mullvadPeers { + countries := make(map[string]*mvCountry) + for _, ps := range status.Peer { + if !ps.ExitNodeOption || ps.Location == nil { + continue + } + loc := ps.Location + country, ok := countries[loc.CountryCode] + if !ok { + country = &mvCountry{ + code: loc.CountryCode, + name: loc.Country, + cities: make(map[string]*mvCity), + } + countries[loc.CountryCode] = country + } + city, ok := countries[loc.CountryCode].cities[loc.CityCode] + if !ok { + city = &mvCity{ + name: loc.City, + } + countries[loc.CountryCode].cities[loc.CityCode] = city + } + city.peers = append(city.peers, ps) + if city.best == nil || ps.Location.Priority > city.best.Location.Priority { + city.best = ps + } + if country.best == nil || ps.Location.Priority > country.best.Location.Priority { + country.best = ps + } + } + return mullvadPeers{countries} +} + +// onExit is called by the systray package when the menu is exiting. +func (menu *Menu) onExit() { + log.Printf("exiting") + if menu.bgCancel != nil { + menu.bgCancel() + } + if menu.eventCancel != nil { + menu.eventCancel() + } + + os.Remove(menu.notificationIcon.Name()) +} diff --git a/client/systray/tailscale-systray.service b/client/systray/tailscale-systray.service new file mode 100644 index 000000000..a4d987563 --- /dev/null +++ b/client/systray/tailscale-systray.service @@ -0,0 +1,10 @@ +[Unit] +Description=Tailscale System Tray +After=systemd.service + +[Service] +Type=simple +ExecStart=/usr/bin/tailscale systray + +[Install] +WantedBy=default.target diff --git a/client/tailscale/acl.go b/client/tailscale/acl.go index 8d8bdfc86..929ec2b3b 100644 --- a/client/tailscale/acl.go +++ b/client/tailscale/acl.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" "net/netip" + "net/url" ) // ACLRow defines a rule that grants access by a set of users or groups to a set @@ -83,7 +84,7 @@ func (c *Client) ACL(ctx context.Context) (acl *ACL, err error) { } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl") req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -97,7 +98,7 @@ func (c *Client) ACL(ctx context.Context) (acl *ACL, err error) { // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } // Otherwise, try to decode the response. @@ -126,7 +127,7 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl?details=1", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl", url.Values{"details": {"1"}}) req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -138,7 +139,7 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { } if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } data := struct { @@ -146,7 +147,7 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { Warnings []string `json:"warnings"` }{} if err := json.Unmarshal(b, &data); err != nil { - return nil, err + return nil, fmt.Errorf("json.Unmarshal %q: %w", b, err) } acl = &ACLHuJSON{ @@ -184,7 +185,7 @@ func (e ACLTestError) Error() string { } func (c *Client) aclPOSTRequest(ctx context.Context, body []byte, avoidCollisions bool, etag, acceptHeader string) ([]byte, string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(body)) if err != nil { return nil, "", err @@ -328,7 +329,7 @@ type ACLPreview struct { } func (c *Client) previewACLPostRequest(ctx context.Context, body []byte, previewType string, previewFor string) (res *ACLPreviewResponse, err error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl/preview", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl", "preview") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(body)) if err != nil { return nil, err @@ -350,7 +351,7 @@ func (c *Client) previewACLPostRequest(ctx context.Context, body []byte, preview // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } if err = json.Unmarshal(b, &res); err != nil { return nil, err @@ -488,7 +489,7 @@ func (c *Client) ValidateACLJSON(ctx context.Context, source, dest string) (test return nil, err } - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl/validate", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl", "validate") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(postData)) if err != nil { return nil, err diff --git a/client/tailscale/apitype/apitype.go b/client/tailscale/apitype/apitype.go index b1c273a4f..6d239d082 100644 --- a/client/tailscale/apitype/apitype.go +++ b/client/tailscale/apitype/apitype.go @@ -7,11 +7,29 @@ package apitype import ( "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/util/ctxkey" ) // LocalAPIHost is the Host header value used by the LocalAPI. const LocalAPIHost = "local-tailscaled.sock" +// RequestReasonHeader is the header used to pass justification for a LocalAPI request, +// such as when a user wants to perform an action they don't have permission for, +// and a policy allows it with justification. As of 2025-01-29, it is only used to +// allow a user to disconnect Tailscale when the "always-on" mode is enabled. +// +// The header value is base64-encoded using the standard encoding defined in RFC 4648. +// +// See tailscale/corp#26146. +const RequestReasonHeader = "X-Tailscale-Reason" + +// RequestReasonKey is the context key used to pass the request reason +// when making a LocalAPI request via [local.Client]. +// It's value is a raw string. An empty string means no reason was provided. +// +// See tailscale/corp#26146. +var RequestReasonKey = ctxkey.New(RequestReasonHeader, "") + // WhoIsResponse is the JSON type returned by tailscaled debug server's /whois?ip=$IP handler. // In successful whois responses, Node and UserProfile are never nil. type WhoIsResponse struct { @@ -76,3 +94,13 @@ type DNSQueryResponse struct { // Resolvers is the list of resolvers that the forwarder deemed able to resolve the query. Resolvers []*dnstype.Resolver } + +// OptionalFeatures describes which optional features are enabled in the build. +type OptionalFeatures struct { + // Features is the map of optional feature names to whether they are + // enabled. + // + // Disabled features may be absent from the map. (That is, false values + // are not guaranteed to be present.) + Features map[string]bool +} diff --git a/client/tailscale/apitype/controltype.go b/client/tailscale/apitype/controltype.go index 9a623be31..d9d79f0ad 100644 --- a/client/tailscale/apitype/controltype.go +++ b/client/tailscale/apitype/controltype.go @@ -3,17 +3,50 @@ package apitype +// DNSConfig is the DNS configuration for a tailnet +// used in /tailnet/{tailnet}/dns/config. type DNSConfig struct { - Resolvers []DNSResolver `json:"resolvers"` - FallbackResolvers []DNSResolver `json:"fallbackResolvers"` - Routes map[string][]DNSResolver `json:"routes"` - Domains []string `json:"domains"` - Nameservers []string `json:"nameservers"` - Proxied bool `json:"proxied"` - TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` + // Resolvers are the global DNS resolvers to use + // overriding the local OS configuration. + Resolvers []DNSResolver `json:"resolvers"` + + // FallbackResolvers are used as global resolvers when + // the client is unable to determine the OS's preferred DNS servers. + FallbackResolvers []DNSResolver `json:"fallbackResolvers"` + + // Routes map DNS name suffixes to a set of DNS resolvers, + // used for Split DNS and other advanced routing overlays. + Routes map[string][]DNSResolver `json:"routes"` + + // Domains are the search domains to use. + Domains []string `json:"domains"` + + // Proxied means MagicDNS is enabled. + Proxied bool `json:"proxied"` + + // TempCorpIssue13969 is from an internal hack day prototype, + // See tailscale/corp#13969. + TempCorpIssue13969 string `json:"TempCorpIssue13969,omitempty"` + + // Nameservers are the IP addresses of global nameservers to use. + // This is a deprecated format but may still be found in tailnets + // that were configured a long time ago. When making updates, + // set Resolvers and leave Nameservers empty. + Nameservers []string `json:"nameservers"` } +// DNSResolver is a DNS resolver in a DNS configuration. type DNSResolver struct { - Addr string `json:"addr"` + // Addr is the address of the DNS resolver. + // It is usually an IP address or a DoH URL. + // See dnstype.Resolver.Addr for full details. + Addr string `json:"addr"` + + // BootstrapResolution is an optional suggested resolution for + // the DoT/DoH resolver. BootstrapResolution []string `json:"bootstrapResolution,omitempty"` + + // UseWithExitNode signals this resolver should be used + // even when a tailscale exit node is configured on a device. + UseWithExitNode bool `json:"useWithExitNode,omitempty"` } diff --git a/client/tailscale/cert.go b/client/tailscale/cert.go new file mode 100644 index 000000000..4f351ab99 --- /dev/null +++ b/client/tailscale/cert.go @@ -0,0 +1,34 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !ts_omit_acme + +package tailscale + +import ( + "context" + "crypto/tls" + + "tailscale.com/client/local" +) + +// GetCertificate is an alias for [tailscale.com/client/local.GetCertificate]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.GetCertificate]. +func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return local.GetCertificate(hi) +} + +// CertPair is an alias for [tailscale.com/client/local.CertPair]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.CertPair]. +func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return local.CertPair(ctx, domain) +} + +// ExpandSNIName is an alias for [tailscale.com/client/local.ExpandSNIName]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.ExpandSNIName]. +func ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { + return local.ExpandSNIName(ctx, name) +} diff --git a/client/tailscale/devices.go b/client/tailscale/devices.go index 9008d4d0d..0664f9e63 100644 --- a/client/tailscale/devices.go +++ b/client/tailscale/devices.go @@ -79,6 +79,13 @@ type Device struct { // Tailscale have attempted to collect this from the device but it has not // opted in, PostureIdentity will have Disabled=true. PostureIdentity *DevicePostureIdentity `json:"postureIdentity"` + + // TailnetLockKey is the tailnet lock public key of the node as a hex string. + TailnetLockKey string `json:"tailnetLockKey,omitempty"` + + // TailnetLockErr indicates an issue with the tailnet lock node-key signature + // on this device. This field is only populated when tailnet lock is enabled. + TailnetLockErr string `json:"tailnetLockError,omitempty"` } type DevicePostureIdentity struct { @@ -131,7 +138,7 @@ func (c *Client) Devices(ctx context.Context, fields *DeviceFieldsOpts) (deviceL } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s/devices", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("devices") req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -149,7 +156,7 @@ func (c *Client) Devices(ctx context.Context, fields *DeviceFieldsOpts) (deviceL // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var devices GetDevicesResponse @@ -188,7 +195,7 @@ func (c *Client) Device(ctx context.Context, deviceID string, fields *DeviceFiel // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } err = json.Unmarshal(b, &device) @@ -221,7 +228,7 @@ func (c *Client) DeleteDevice(ctx context.Context, deviceID string) (err error) // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil } @@ -253,7 +260,7 @@ func (c *Client) SetAuthorized(ctx context.Context, deviceID string, authorized // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil @@ -281,7 +288,7 @@ func (c *Client) SetTags(ctx context.Context, deviceID string, tags []string) er // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil diff --git a/client/tailscale/dns.go b/client/tailscale/dns.go index f198742b3..bbdc7c56c 100644 --- a/client/tailscale/dns.go +++ b/client/tailscale/dns.go @@ -44,7 +44,7 @@ type DNSPreferences struct { } func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + path := c.BuildTailnetURL("dns", endpoint) req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -57,14 +57,14 @@ func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, er // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } return b, nil } func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + path := c.BuildTailnetURL("dns", endpoint) data, err := json.Marshal(&postData) if err != nil { return nil, err @@ -84,7 +84,7 @@ func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData a // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } return b, nil diff --git a/client/tailscale/example/servetls/servetls.go b/client/tailscale/example/servetls/servetls.go index f48e90d16..0ade42088 100644 --- a/client/tailscale/example/servetls/servetls.go +++ b/client/tailscale/example/servetls/servetls.go @@ -11,13 +11,14 @@ import ( "log" "net/http" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" ) func main() { + var lc local.Client s := &http.Server{ TLSConfig: &tls.Config{ - GetCertificate: tailscale.GetCertificate, + GetCertificate: lc.GetCertificate, }, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "

Hello from Tailscale!

It works.") diff --git a/client/tailscale/keys.go b/client/tailscale/keys.go index 84bcdfae6..79e19e998 100644 --- a/client/tailscale/keys.go +++ b/client/tailscale/keys.go @@ -40,7 +40,7 @@ type KeyDeviceCreateCapabilities struct { // Keys returns the list of keys for the current user. func (c *Client) Keys(ctx context.Context) ([]string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("keys") req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -51,7 +51,7 @@ func (c *Client) Keys(ctx context.Context) ([]string, error) { return nil, err } if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var keys struct { @@ -99,7 +99,7 @@ func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, return "", nil, err } - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("keys") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) if err != nil { return "", nil, err @@ -110,7 +110,7 @@ func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, return "", nil, err } if resp.StatusCode != http.StatusOK { - return "", nil, handleErrorResponse(b, resp) + return "", nil, HandleErrorResponse(b, resp) } var key struct { @@ -126,7 +126,7 @@ func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, // Key returns the metadata for the given key ID. Currently, capabilities are // only returned for auth keys, API keys only return general metadata. func (c *Client) Key(ctx context.Context, id string) (*Key, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + path := c.BuildTailnetURL("keys", id) req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -137,7 +137,7 @@ func (c *Client) Key(ctx context.Context, id string) (*Key, error) { return nil, err } if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var key Key @@ -149,7 +149,7 @@ func (c *Client) Key(ctx context.Context, id string) (*Key, error) { // DeleteKey deletes the key with the given ID. func (c *Client) DeleteKey(ctx context.Context, id string) error { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + path := c.BuildTailnetURL("keys", id) req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) if err != nil { return err @@ -160,7 +160,7 @@ func (c *Client) DeleteKey(ctx context.Context, id string) error { return err } if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil } diff --git a/client/tailscale/localclient_aliases.go b/client/tailscale/localclient_aliases.go new file mode 100644 index 000000000..e3492e841 --- /dev/null +++ b/client/tailscale/localclient_aliases.go @@ -0,0 +1,79 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "context" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn/ipnstate" +) + +// ErrPeerNotFound is an alias for [tailscale.com/client/local.ErrPeerNotFound]. +// +// Deprecated: import [tailscale.com/client/local] instead. +var ErrPeerNotFound = local.ErrPeerNotFound + +// LocalClient is an alias for [tailscale.com/client/local.Client]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type LocalClient = local.Client + +// IPNBusWatcher is an alias for [tailscale.com/client/local.IPNBusWatcher]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type IPNBusWatcher = local.IPNBusWatcher + +// BugReportOpts is an alias for [tailscale.com/client/local.BugReportOpts]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type BugReportOpts = local.BugReportOpts + +// PingOpts is an alias for [tailscale.com/client/local.PingOpts]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type PingOpts = local.PingOpts + +// SetVersionMismatchHandler is an alias for [tailscale.com/client/local.SetVersionMismatchHandler]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func SetVersionMismatchHandler(f func(clientVer, serverVer string)) { + local.SetVersionMismatchHandler(f) +} + +// IsAccessDeniedError is an alias for [tailscale.com/client/local.IsAccessDeniedError]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func IsAccessDeniedError(err error) bool { + return local.IsAccessDeniedError(err) +} + +// IsPreconditionsFailedError is an alias for [tailscale.com/client/local.IsPreconditionsFailedError]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func IsPreconditionsFailedError(err error) bool { + return local.IsPreconditionsFailedError(err) +} + +// WhoIs is an alias for [tailscale.com/client/local.WhoIs]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.WhoIs]. +func WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + return local.WhoIs(ctx, remoteAddr) +} + +// Status is an alias for [tailscale.com/client/local.Status]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func Status(ctx context.Context) (*ipnstate.Status, error) { + return local.Status(ctx) +} + +// StatusWithoutPeers is an alias for [tailscale.com/client/local.StatusWithoutPeers]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { + return local.StatusWithoutPeers(ctx) +} diff --git a/client/tailscale/routes.go b/client/tailscale/routes.go index 5912fc46c..b72f2743f 100644 --- a/client/tailscale/routes.go +++ b/client/tailscale/routes.go @@ -44,7 +44,7 @@ func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, e // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var sr Routes @@ -84,7 +84,7 @@ func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var srr *Routes diff --git a/client/tailscale/tailnet.go b/client/tailscale/tailnet.go index 2539e7f23..9453962c9 100644 --- a/client/tailscale/tailnet.go +++ b/client/tailscale/tailnet.go @@ -9,7 +9,6 @@ import ( "context" "fmt" "net/http" - "net/url" "tailscale.com/util/httpm" ) @@ -22,7 +21,7 @@ func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (er } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) + path := c.BuildTailnetURL("tailnet") req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) if err != nil { return err @@ -35,7 +34,7 @@ func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (er } if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil diff --git a/client/tailscale/tailscale.go b/client/tailscale/tailscale.go index 894561965..76e44454b 100644 --- a/client/tailscale/tailscale.go +++ b/client/tailscale/tailscale.go @@ -3,11 +3,12 @@ //go:build go1.19 -// Package tailscale contains Go clients for the Tailscale LocalAPI and -// Tailscale control plane API. +// Package tailscale contains a Go client for the Tailscale control plane API. // -// Warning: this package is in development and makes no API compatibility -// promises as of 2022-04-29. It is subject to change at any time. +// This package is only intended for internal and transitional use. +// +// Deprecated: the official control plane client is available at +// [tailscale.com/client/tailscale/v2]. package tailscale import ( @@ -16,13 +17,12 @@ import ( "fmt" "io" "net/http" + "net/url" + "path" ) // I_Acknowledge_This_API_Is_Unstable must be set true to use this package -// for now. It was added 2022-04-29 when it was moved to this git repo -// and will be removed when the public API has settled. -// -// TODO(bradfitz): remove this after the we're happy with the public API. +// for now. This package is being replaced by [tailscale.com/client/tailscale/v2]. var I_Acknowledge_This_API_Is_Unstable = false // TODO: use url.PathEscape() for deviceID and tailnets when constructing requests. @@ -34,8 +34,10 @@ const maxReadSize = 10 << 20 // Client makes API calls to the Tailscale control plane API server. // -// Use NewClient to instantiate one. Exported fields should be set before +// Use [NewClient] to instantiate one. Exported fields should be set before // the client is used and not changed thereafter. +// +// Deprecated: use [tailscale.com/client/tailscale/v2] instead. type Client struct { // tailnet is the globally unique identifier for a Tailscale network, such // as "example.com" or "user@gmail.com". @@ -49,8 +51,11 @@ type Client struct { BaseURL string // HTTPClient optionally specifies an alternate HTTP client to use. - // If nil, http.DefaultClient is used. + // If nil, [http.DefaultClient] is used. HTTPClient *http.Client + + // UserAgent optionally specifies an alternate User-Agent header + UserAgent string } func (c *Client) httpClient() *http.Client { @@ -60,6 +65,46 @@ func (c *Client) httpClient() *http.Client { return http.DefaultClient } +// BuildURL builds a url to http(s):///api/v2/ +// using the given pathElements. It url escapes each path element, so the +// caller doesn't need to worry about that. The last item of pathElements can +// be of type url.Values to add a query string to the URL. +// +// For example, BuildURL(devices, 5) with the default server URL would result in +// https://api.tailscale.com/api/v2/devices/5. +func (c *Client) BuildURL(pathElements ...any) string { + elem := make([]string, 1, len(pathElements)+1) + elem[0] = "/api/v2" + var query string + for i, pathElement := range pathElements { + if uv, ok := pathElement.(url.Values); ok && i == len(pathElements)-1 { + query = uv.Encode() + } else { + elem = append(elem, url.PathEscape(fmt.Sprint(pathElement))) + } + } + url := c.baseURL() + path.Join(elem...) + if query != "" { + url += "?" + query + } + return url +} + +// BuildTailnetURL builds a url to http(s):///api/v2/tailnet// +// using the given pathElements. It url escapes each path element, so the +// caller doesn't need to worry about that. The last item of pathElements can +// be of type url.Values to add a query string to the URL. +// +// For example, BuildTailnetURL(policy, validate) with the default server URL and a tailnet of "example.com" +// would result in https://api.tailscale.com/api/v2/tailnet/example.com/policy/validate. +func (c *Client) BuildTailnetURL(pathElements ...any) string { + allElements := make([]any, 2, len(pathElements)+2) + allElements[0] = "tailnet" + allElements[1] = c.tailnet + allElements = append(allElements, pathElements...) + return c.BuildURL(allElements...) +} + func (c *Client) baseURL() string { if c.BaseURL != "" { return c.BaseURL @@ -74,7 +119,7 @@ type AuthMethod interface { modifyRequest(req *http.Request) } -// APIKey is an AuthMethod for NewClient that authenticates requests +// APIKey is an [AuthMethod] for [NewClient] that authenticates requests // using an authkey. type APIKey string @@ -88,17 +133,20 @@ func (c *Client) setAuth(r *http.Request) { } } -// NewClient is a convenience method for instantiating a new Client. +// NewClient is a convenience method for instantiating a new [Client]. // // tailnet is the globally unique identifier for a Tailscale network, such // as "example.com" or "user@gmail.com". -// If httpClient is nil, then http.DefaultClient is used. +// If httpClient is nil, then [http.DefaultClient] is used. // "api.tailscale.com" is set as the BaseURL for the returned client // and can be changed manually by the user. +// +// Deprecated: use [tailscale.com/client/tailscale/v2] instead. func NewClient(tailnet string, auth AuthMethod) *Client { return &Client{ - tailnet: tailnet, - auth: auth, + tailnet: tailnet, + auth: auth, + UserAgent: "tailscale-client-oss", } } @@ -110,17 +158,16 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { return nil, errors.New("use of Client without setting I_Acknowledge_This_API_Is_Unstable") } c.setAuth(req) + if c.UserAgent != "" { + req.Header.Set("User-Agent", c.UserAgent) + } return c.httpClient().Do(req) } // sendRequest add the authentication key to the request and sends it. It // receives the response and reads up to 10MB of it. func (c *Client) sendRequest(req *http.Request) ([]byte, *http.Response, error) { - if !I_Acknowledge_This_API_Is_Unstable { - return nil, nil, errors.New("use of Client without setting I_Acknowledge_This_API_Is_Unstable") - } - c.setAuth(req) - resp, err := c.httpClient().Do(req) + resp, err := c.Do(req) if err != nil { return nil, resp, err } @@ -145,12 +192,14 @@ func (e ErrResponse) Error() string { return fmt.Sprintf("Status: %d, Message: %q", e.Status, e.Message) } -// handleErrorResponse decodes the error message from the server and returns -// an ErrResponse from it. -func handleErrorResponse(b []byte, resp *http.Response) error { +// HandleErrorResponse decodes the error message from the server and returns +// an [ErrResponse] from it. +// +// Deprecated: use [tailscale.com/client/tailscale/v2] instead. +func HandleErrorResponse(b []byte, resp *http.Response) error { var errResp ErrResponse if err := json.Unmarshal(b, &errResp); err != nil { - return err + return fmt.Errorf("json.Unmarshal %q: %w", b, err) } errResp.Status = resp.StatusCode return errResp diff --git a/client/tailscale/tailscale_test.go b/client/tailscale/tailscale_test.go new file mode 100644 index 000000000..67379293b --- /dev/null +++ b/client/tailscale/tailscale_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "net/url" + "testing" +) + +func TestClientBuildURL(t *testing.T) { + c := Client{BaseURL: "http://127.0.0.1:1234"} + for _, tt := range []struct { + desc string + elements []any + want string + }{ + { + desc: "single-element", + elements: []any{"devices"}, + want: "http://127.0.0.1:1234/api/v2/devices", + }, + { + desc: "multiple-elements", + elements: []any{"tailnet", "example.com"}, + want: "http://127.0.0.1:1234/api/v2/tailnet/example.com", + }, + { + desc: "escape-element", + elements: []any{"tailnet", "example dot com?foo=bar"}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example%20dot%20com%3Ffoo=bar`, + }, + { + desc: "url.Values", + elements: []any{"tailnet", "example.com", "acl", url.Values{"details": {"1"}}}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/acl?details=1`, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + got := c.BuildURL(tt.elements...) + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestClientBuildTailnetURL(t *testing.T) { + c := Client{ + BaseURL: "http://127.0.0.1:1234", + tailnet: "example.com", + } + for _, tt := range []struct { + desc string + elements []any + want string + }{ + { + desc: "single-element", + elements: []any{"devices"}, + want: "http://127.0.0.1:1234/api/v2/tailnet/example.com/devices", + }, + { + desc: "multiple-elements", + elements: []any{"devices", 123}, + want: "http://127.0.0.1:1234/api/v2/tailnet/example.com/devices/123", + }, + { + desc: "escape-element", + elements: []any{"foo bar?baz=qux"}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/foo%20bar%3Fbaz=qux`, + }, + { + desc: "url.Values", + elements: []any{"acl", url.Values{"details": {"1"}}}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/acl?details=1`, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + got := c.BuildTailnetURL(tt.elements...) + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} diff --git a/client/web/auth.go b/client/web/auth.go index 8b195a417..27eb24ee4 100644 --- a/client/web/auth.go +++ b/client/web/auth.go @@ -192,7 +192,7 @@ func (s *Server) controlSupportsCheckMode(ctx context.Context) bool { if err != nil { return true } - controlURL, err := url.Parse(prefs.ControlURLOrDefault()) + controlURL, err := url.Parse(prefs.ControlURLOrDefault(s.polc)) if err != nil { return true } diff --git a/client/web/package.json b/client/web/package.json index 4b3afb1df..c45f7d6a8 100644 --- a/client/web/package.json +++ b/client/web/package.json @@ -3,7 +3,7 @@ "version": "0.0.1", "license": "BSD-3-Clause", "engines": { - "node": "18.20.4", + "node": "22.14.0", "yarn": "1.22.19" }, "type": "module", @@ -20,7 +20,7 @@ "zustand": "^4.4.7" }, "devDependencies": { - "@types/node": "^18.16.1", + "@types/node": "^22.14.0", "@types/react": "^18.0.20", "@types/react-dom": "^18.0.6", "@vitejs/plugin-react-swc": "^3.6.0", diff --git a/client/web/src/api.ts b/client/web/src/api.ts index 9414e2d5d..e780c7645 100644 --- a/client/web/src/api.ts +++ b/client/web/src/api.ts @@ -249,7 +249,6 @@ export function useAPI() { return api } -let csrfToken: string let synoToken: string | undefined // required for synology API requests let unraidCsrfToken: string | undefined // required for unraid POST requests (#8062) @@ -298,12 +297,10 @@ export function apiFetch( headers: { Accept: "application/json", "Content-Type": contentType, - "X-CSRF-Token": csrfToken, }, body: body, }) .then((r) => { - updateCsrfToken(r) if (!r.ok) { return r.text().then((err) => { throw new Error(err) @@ -322,13 +319,6 @@ export function apiFetch( }) } -function updateCsrfToken(r: Response) { - const tok = r.headers.get("X-CSRF-Token") - if (tok) { - csrfToken = tok - } -} - export function setSynoToken(token?: string) { synoToken = token } diff --git a/client/web/src/components/views/login-view.tsx b/client/web/src/components/views/login-view.tsx index b2868bb46..f8c15b16d 100644 --- a/client/web/src/components/views/login-view.tsx +++ b/client/web/src/components/views/login-view.tsx @@ -1,13 +1,11 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -import React, { useState } from "react" +import React from "react" import { useAPI } from "src/api" import TailscaleIcon from "src/assets/icons/tailscale-icon.svg?react" import { NodeData } from "src/types" import Button from "src/ui/button" -import Collapsible from "src/ui/collapsible" -import Input from "src/ui/input" /** * LoginView is rendered when the client is not authenticated @@ -15,8 +13,6 @@ import Input from "src/ui/input" */ export default function LoginView({ data }: { data: NodeData }) { const api = useAPI() - const [controlURL, setControlURL] = useState("") - const [authKey, setAuthKey] = useState("") return (
@@ -88,8 +84,6 @@ export default function LoginView({ data }: { data: NodeData }) { action: "up", data: { Reauthenticate: true, - ControlURL: controlURL, - AuthKey: authKey, }, }) } @@ -98,34 +92,6 @@ export default function LoginView({ data }: { data: NodeData }) { > Log In - -

Auth Key

-

- Connect with a pre-authenticated key.{" "} - - Learn more → - -

- setAuthKey(e.target.value)} - placeholder="tskey-auth-XXX" - /> -

Server URL

-

Base URL of control server.

- setControlURL(e.target.value)} - placeholder="https://login.tailscale.com/" - /> -
)}
diff --git a/client/web/src/hooks/exit-nodes.ts b/client/web/src/hooks/exit-nodes.ts index b3ce0a9fa..5e47fbc22 100644 --- a/client/web/src/hooks/exit-nodes.ts +++ b/client/web/src/hooks/exit-nodes.ts @@ -66,7 +66,7 @@ export default function useExitNodes(node: NodeData, filter?: string) { // match from a list of exit node `options` to `nodes`. const addBestMatchNode = ( options: ExitNode[], - name: (l: ExitNodeLocation) => string + name: (loc: ExitNodeLocation) => string ) => { const bestNode = highestPriorityNode(options) if (!bestNode || !bestNode.Location) { @@ -86,7 +86,7 @@ export default function useExitNodes(node: NodeData, filter?: string) { locationNodesMap.forEach( // add one node per country (countryNodes) => - addBestMatchNode(flattenMap(countryNodes), (l) => l.Country) + addBestMatchNode(flattenMap(countryNodes), (loc) => loc.Country) ) } else { // Otherwise, show the best match on a city-level, @@ -97,12 +97,12 @@ export default function useExitNodes(node: NodeData, filter?: string) { countryNodes.forEach( // add one node per city (cityNodes) => - addBestMatchNode(cityNodes, (l) => `${l.Country}: ${l.City}`) + addBestMatchNode(cityNodes, (loc) => `${loc.Country}: ${loc.City}`) ) // add the "Country: Best Match" node addBestMatchNode( flattenMap(countryNodes), - (l) => `${l.Country}: Best Match` + (loc) => `${loc.Country}: Best Match` ) }) } diff --git a/client/web/web.go b/client/web/web.go index 04ba2d086..dbd3d5df0 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -5,8 +5,8 @@ package web import ( + "cmp" "context" - "crypto/rand" "encoding/json" "errors" "fmt" @@ -14,18 +14,20 @@ import ( "log" "net/http" "net/netip" + "net/url" "os" "path" - "path/filepath" + "slices" "strings" "sync" "time" - "github.com/gorilla/csrf" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" - "tailscale.com/clientupdate" "tailscale.com/envknob" + "tailscale.com/envknob/featureknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -36,6 +38,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/views" "tailscale.com/util/httpm" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/version" "tailscale.com/version/distro" ) @@ -49,7 +52,8 @@ type Server struct { mode ServerMode logf logger.Logf - lc *tailscale.LocalClient + polc policyclient.Client // must be non-nil + lc *local.Client timeNow func() time.Time // devMode indicates that the server run with frontend assets @@ -59,6 +63,12 @@ type Server struct { cgiMode bool pathPrefix string + // originOverride is the origin that the web UI is accessible from. + // This value is used in the fallback CSRF checks when Sec-Fetch-Site is not + // available. In this case the application will compare Host and Origin + // header values to determine if the request is from the same origin. + originOverride string + apiHandler http.Handler // serves api endpoints; csrf-protected assetsHandler http.Handler // serves frontend assets assetsCleanup func() // called from Server.Shutdown @@ -88,8 +98,8 @@ type Server struct { type ServerMode string const ( - // LoginServerMode serves a readonly login client for logging a - // node into a tailnet, and viewing a readonly interface of the + // LoginServerMode serves a read-only login client for logging a + // node into a tailnet, and viewing a read-only interface of the // node's current Tailscale settings. // // In this mode, API calls are authenticated via platform auth. @@ -109,7 +119,7 @@ const ( // This mode restricts the app to only being assessible over Tailscale, // and API calls are authenticated via browser sessions associated with // the source's Tailscale identity. If the source browser does not have - // a valid session, a readonly version of the app is displayed. + // a valid session, a read-only version of the app is displayed. ManageServerMode ServerMode = "manage" ) @@ -124,18 +134,22 @@ type ServerOpts struct { // PathPrefix is the URL prefix added to requests by CGI or reverse proxy. PathPrefix string - // LocalClient is the tailscale.LocalClient to use for this web server. + // LocalClient is the local.Client to use for this web server. // If nil, a new one will be created. - LocalClient *tailscale.LocalClient + LocalClient *local.Client // TimeNow optionally provides a time function. // time.Now is used as default. TimeNow func() time.Time // Logf optionally provides a logger function. - // log.Printf is used as default. + // If nil, log.Printf is used as default. Logf logger.Logf + // PolicyClient, if non-nil, will be used to fetch policy settings. + // If nil, the default policy client will be used. + PolicyClient policyclient.Client + // The following two fields are required and used exclusively // in ManageServerMode to facilitate the control server login // check step for authorizing browser sessions. @@ -149,6 +163,9 @@ type ServerOpts struct { // as completed. // This field is required for ManageServerMode mode. WaitAuthURL func(ctx context.Context, id string, src tailcfg.NodeID) (*tailcfg.WebClientAuthResponse, error) + + // OriginOverride specifies the origin that the web UI will be accessible from if hosted behind a reverse proxy or CGI. + OriginOverride string } // NewServer constructs a new Tailscale web client server. @@ -165,18 +182,20 @@ func NewServer(opts ServerOpts) (s *Server, err error) { return nil, fmt.Errorf("invalid Mode provided") } if opts.LocalClient == nil { - opts.LocalClient = &tailscale.LocalClient{} + opts.LocalClient = &local.Client{} } s = &Server{ - mode: opts.Mode, - logf: opts.Logf, - devMode: envknob.Bool("TS_DEBUG_WEB_CLIENT_DEV"), - lc: opts.LocalClient, - cgiMode: opts.CGIMode, - pathPrefix: opts.PathPrefix, - timeNow: opts.TimeNow, - newAuthURL: opts.NewAuthURL, - waitAuthURL: opts.WaitAuthURL, + mode: opts.Mode, + polc: cmp.Or(opts.PolicyClient, policyclient.Get()), + logf: opts.Logf, + devMode: envknob.Bool("TS_DEBUG_WEB_CLIENT_DEV"), + lc: opts.LocalClient, + cgiMode: opts.CGIMode, + pathPrefix: opts.PathPrefix, + timeNow: opts.TimeNow, + newAuthURL: opts.NewAuthURL, + waitAuthURL: opts.WaitAuthURL, + originOverride: opts.OriginOverride, } if opts.PathPrefix != "" { // Enforce that path prefix always has a single leading '/' @@ -202,25 +221,9 @@ func NewServer(opts ServerOpts) (s *Server, err error) { } s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode) - var metric string // clientmetric to report on startup - - // Create handler for "/api" requests with CSRF protection. - // We don't require secure cookies, since the web client is regularly used - // on network appliances that are served on local non-https URLs. - // The client is secured by limiting the interface it listens on, - // or by authenticating requests before they reach the web client. - csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) - switch s.mode { - case LoginServerMode: - s.apiHandler = csrfProtect(http.HandlerFunc(s.serveLoginAPI)) - metric = "web_login_client_initialization" - case ReadOnlyServerMode: - s.apiHandler = csrfProtect(http.HandlerFunc(s.serveLoginAPI)) - metric = "web_readonly_client_initialization" - case ManageServerMode: - s.apiHandler = csrfProtect(http.HandlerFunc(s.serveAPI)) - metric = "web_client_initialization" - } + var metric string + s.apiHandler, metric = s.modeAPIHandler(s.mode) + s.apiHandler = s.csrfProtect(s.apiHandler) // Don't block startup on reporting metric. // Report in separate go routine with 5 second timeout. @@ -233,6 +236,80 @@ func NewServer(opts ServerOpts) (s *Server, err error) { return s, nil } +func (s *Server) csrfProtect(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // CSRF is not required for GET, HEAD, or OPTIONS requests. + if slices.Contains([]string{"GET", "HEAD", "OPTIONS"}, r.Method) { + h.ServeHTTP(w, r) + return + } + + // first attempt to use Sec-Fetch-Site header (sent by all modern + // browsers to "potentially trustworthy" origins i.e. localhost or those + // served over HTTPS) + secFetchSite := r.Header.Get("Sec-Fetch-Site") + if secFetchSite == "same-origin" { + h.ServeHTTP(w, r) + return + } else if secFetchSite != "" { + http.Error(w, fmt.Sprintf("CSRF request denied with Sec-Fetch-Site %q", secFetchSite), http.StatusForbidden) + return + } + + // if Sec-Fetch-Site is not available we presume we are operating over HTTP. + // We fall back to comparing the Origin & Host headers. + + // use the Host header to determine the expected origin + // (use the override if set to allow for reverse proxying) + host := r.Host + if host == "" { + http.Error(w, "CSRF request denied with no Host header", http.StatusForbidden) + return + } + if s.originOverride != "" { + host = s.originOverride + } + + originHeader := r.Header.Get("Origin") + if originHeader == "" { + http.Error(w, "CSRF request denied with no Origin header", http.StatusForbidden) + return + } + parsedOrigin, err := url.Parse(originHeader) + if err != nil { + http.Error(w, fmt.Sprintf("CSRF request denied with invalid Origin %q", r.Header.Get("Origin")), http.StatusForbidden) + return + } + origin := parsedOrigin.Host + if origin == "" { + http.Error(w, "CSRF request denied with no host in the Origin header", http.StatusForbidden) + return + } + + if origin != host { + http.Error(w, fmt.Sprintf("CSRF request denied with mismatched Origin %q and Host %q", origin, host), http.StatusForbidden) + return + } + + h.ServeHTTP(w, r) + + }) +} + +func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) { + switch mode { + case LoginServerMode: + return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization" + case ReadOnlyServerMode: + return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization" + case ManageServerMode: + return http.HandlerFunc(s.serveAPI), "web_client_initialization" + default: // invalid mode + log.Fatalf("invalid mode: %v", mode) + } + return nil, "" +} + func (s *Server) Shutdown() { s.logf("web.Server: shutting down") if s.assetsCleanup != nil { @@ -317,7 +394,8 @@ func (s *Server) requireTailscaleIP(w http.ResponseWriter, r *http.Request) (han ipv6ServiceHost = "[" + tsaddr.TailscaleServiceIPv6String + "]" ) // allow requests on quad-100 (or ipv6 equivalent) - if r.Host == ipv4ServiceHost || r.Host == ipv6ServiceHost { + host := strings.TrimSuffix(r.Host, ":80") + if host == ipv4ServiceHost || host == ipv6ServiceHost { return false } @@ -419,6 +497,10 @@ func (s *Server) authorizeRequest(w http.ResponseWriter, r *http.Request) (ok bo // Client using system-specific auth. switch distro.Get() { case distro.Synology: + if !buildfeatures.HasSynology { + // Synology support not built in. + return false + } authorized, _ := authorizeSynology(r) return authorized case distro.QNAP: @@ -433,7 +515,6 @@ func (s *Server) authorizeRequest(w http.ResponseWriter, r *http.Request) (ok bo // It should only be called by Server.ServeHTTP, via Server.apiHandler, // which protects the handler using gorilla csrf. func (s *Server) serveLoginAPI(w http.ResponseWriter, r *http.Request) { - w.Header().Set("X-CSRF-Token", csrf.Token(r)) switch { case r.URL.Path == "/api/data" && r.Method == httpm.GET: s.serveGetNodeData(w, r) @@ -556,7 +637,6 @@ func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) { } } - w.Header().Set("X-CSRF-Token", csrf.Token(r)) path := strings.TrimPrefix(r.URL.Path, "/api") switch { case path == "/data" && r.Method == httpm.GET: @@ -694,16 +774,16 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) { switch { case sErr != nil && errors.Is(sErr, errNotUsingTailscale): s.lc.IncrementCounter(r.Context(), "web_client_viewing_local", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && errors.Is(sErr, errNotOwner): s.lc.IncrementCounter(r.Context(), "web_client_viewing_not_owner", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && errors.Is(sErr, errTaggedLocalSource): s.lc.IncrementCounter(r.Context(), "web_client_viewing_local_tag", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && errors.Is(sErr, errTaggedRemoteSource): s.lc.IncrementCounter(r.Context(), "web_client_viewing_remote_tag", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && !errors.Is(sErr, errNoSession): // Any other error. http.Error(w, sErr.Error(), http.StatusInternalServerError) @@ -803,8 +883,8 @@ type nodeData struct { DeviceName string TailnetName string // TLS cert name DomainName string - IPv4 string - IPv6 string + IPv4 netip.Addr + IPv6 netip.Addr OS string IPNVersion string @@ -863,10 +943,14 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { return } filterRules, _ := s.lc.DebugPacketFilterRules(r.Context()) + ipv4, ipv6 := s.selfNodeAddresses(r, st) + data := &nodeData{ ID: st.Self.ID, Status: st.BackendState, DeviceName: strings.Split(st.Self.DNSName, ".")[0], + IPv4: ipv4, + IPv6: ipv6, OS: st.Self.OS, IPNVersion: strings.Split(st.Version, "-")[0], Profile: st.User[st.Self.UserID], @@ -879,17 +963,13 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { UnraidToken: os.Getenv("UNRAID_CSRF_TOKEN"), RunningSSHServer: prefs.RunSSH, URLPrefix: strings.TrimSuffix(s.pathPrefix, "/"), - ControlAdminURL: prefs.AdminPageURL(), + ControlAdminURL: prefs.AdminPageURL(s.polc), LicensesURL: licenses.LicensesURL(), Features: availableFeatures(), ACLAllowsAnyIncomingTraffic: s.aclsAllowAccess(filterRules), } - ipv4, ipv6 := s.selfNodeAddresses(r, st) - data.IPv4 = ipv4.String() - data.IPv6 = ipv6.String() - if hostinfo.GetEnvType() == hostinfo.HomeAssistantAddOn && data.URLPrefix == "" { // X-Ingress-Path is the path prefix in use for Home Assistant // https://developers.home-assistant.io/docs/add-ons/presentation#ingress @@ -903,9 +983,18 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { data.ClientVersion = cv } - if st.CurrentTailnet != nil { - data.TailnetName = st.CurrentTailnet.MagicDNSSuffix - data.DomainName = st.CurrentTailnet.Name + profile, _, err := s.lc.ProfileStatus(r.Context()) + if err != nil { + s.logf("error fetching profiles: %v", err) + // If for some reason we can't fetch profiles, + // continue to use st.CurrentTailnet if set. + if st.CurrentTailnet != nil { + data.TailnetName = st.CurrentTailnet.MagicDNSSuffix + data.DomainName = st.CurrentTailnet.Name + } + } else { + data.TailnetName = profile.NetworkProfile.MagicDNSName + data.DomainName = profile.NetworkProfile.DisplayNameOrDefault() } if st.Self.Tags != nil { data.Tags = st.Self.Tags.AsSlice() @@ -960,37 +1049,16 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { } func availableFeatures() map[string]bool { - env := hostinfo.GetEnvType() features := map[string]bool{ "advertise-exit-node": true, // available on all platforms "advertise-routes": true, // available on all platforms - "use-exit-node": canUseExitNode(env) == nil, - "ssh": envknob.CanRunTailscaleSSH() == nil, - "auto-update": version.IsUnstableBuild() && clientupdate.CanAutoUpdate(), - } - if env == hostinfo.HomeAssistantAddOn { - // Setting SSH on Home Assistant causes trouble on startup - // (since the flag is not being passed to `tailscale up`). - // Although Tailscale SSH does work here, - // it's not terribly useful since it's running in a separate container. - features["ssh"] = false + "use-exit-node": featureknob.CanUseExitNode() == nil, + "ssh": featureknob.CanRunTailscaleSSH() == nil, + "auto-update": version.IsUnstableBuild() && feature.CanAutoUpdate(), } return features } -func canUseExitNode(env hostinfo.EnvType) error { - switch dist := distro.Get(); dist { - case distro.Synology, // see https://github.com/tailscale/tailscale/issues/1995 - distro.QNAP, - distro.Unraid: - return fmt.Errorf("Tailscale exit nodes cannot be used on %s.", dist) - } - if env == hostinfo.HomeAssistantAddOn { - return errors.New("Tailscale exit nodes cannot be used on Home Assistant.") - } - return nil -} - // aclsAllowAccess returns whether tailnet ACLs (as expressed in the provided filter rules) // permit any devices to access the local web client. // This does not currently check whether a specific device can connect, just any device. @@ -1278,37 +1346,6 @@ func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request) } } -// csrfKey returns a key that can be used for CSRF protection. -// If an error occurs during key creation, the error is logged and the active process terminated. -// If the server is running in CGI mode, the key is cached to disk and reused between requests. -// If an error occurs during key storage, the error is logged and the active process terminated. -func (s *Server) csrfKey() []byte { - csrfFile := filepath.Join(os.TempDir(), "tailscale-web-csrf.key") - - // if running in CGI mode, try to read from disk, but ignore errors - if s.cgiMode { - key, _ := os.ReadFile(csrfFile) - if len(key) == 32 { - return key - } - } - - // create a new key - key := make([]byte, 32) - if _, err := rand.Read(key); err != nil { - log.Fatalf("error generating CSRF key: %v", err) - } - - // if running in CGI mode, try to write the newly created key to disk, and exit if it fails. - if s.cgiMode { - if err := os.WriteFile(csrfFile, key, 0600); err != nil { - log.Fatalf("unable to store CSRF key: %v", err) - } - } - - return key -} - // enforcePrefix returns a HandlerFunc that enforces a given path prefix is used in requests, // then strips it before invoking h. // Unlike http.StripPrefix, it does not return a 404 if the prefix is not present. diff --git a/client/web/web_test.go b/client/web/web_test.go index 3c5543c12..9ba16bccf 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -20,7 +20,7 @@ import ( "time" "github.com/google/go-cmp/cmp" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -28,6 +28,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/views" "tailscale.com/util/httpm" + "tailscale.com/util/syspolicy/policyclient" ) func TestQnapAuthnURL(t *testing.T) { @@ -120,7 +121,7 @@ func TestServeAPI(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: time.Now, } @@ -288,7 +289,7 @@ func TestGetTailscaleBrowserSession(t *testing.T) { s := &Server{ timeNow: time.Now, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, } // Add some browser sessions to cache state. @@ -457,7 +458,7 @@ func TestAuthorizeRequest(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: time.Now, } validCookie := "ts-cookie" @@ -572,10 +573,11 @@ func TestServeAuth(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: func() time.Time { return timeNow }, newAuthURL: mockNewAuthURL, waitAuthURL: mockWaitAuthURL, + polc: policyclient.NoPolicyClient{}, } successCookie := "ts-cookie-success" @@ -914,7 +916,7 @@ func TestServeAPIAuthMetricLogging(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: func() time.Time { return timeNow }, newAuthURL: mockNewAuthURL, waitAuthURL: mockWaitAuthURL, @@ -1126,7 +1128,7 @@ func TestRequireTailscaleIP(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: time.Now, logf: t.Logf, } @@ -1175,6 +1177,16 @@ func TestRequireTailscaleIP(t *testing.T) { target: "http://[fd7a:115c:a1e0::53]/", wantHandled: false, }, + { + name: "quad-100:80", + target: "http://100.100.100.100:80/", + wantHandled: false, + }, + { + name: "ipv6-service-addr:80", + target: "http://[fd7a:115c:a1e0::53]:80/", + wantHandled: false, + }, } for _, tt := range tests { @@ -1477,3 +1489,101 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg return nil, errors.New("unknown id") } } + +func TestCSRFProtect(t *testing.T) { + tests := []struct { + name string + method string + secFetchSite string + host string + origin string + originOverride string + wantError bool + }{ + { + name: "GET requests with no header are allowed", + method: "GET", + }, + { + name: "POST requests with same-origin are allowed", + method: "POST", + secFetchSite: "same-origin", + }, + { + name: "POST requests with cross-site are not allowed", + method: "POST", + secFetchSite: "cross-site", + wantError: true, + }, + { + name: "POST requests with unknown sec-fetch-site values are not allowed", + method: "POST", + secFetchSite: "new-unknown-value", + wantError: true, + }, + { + name: "POST requests with none are not allowed", + method: "POST", + secFetchSite: "none", + wantError: true, + }, + { + name: "POST requests with no sec-fetch-site header but matching host and origin are allowed", + method: "POST", + host: "example.com", + origin: "https://example.com", + }, + { + name: "POST requests with no sec-fetch-site and non-matching host and origin are not allowed", + method: "POST", + host: "example.com", + origin: "https://example.net", + wantError: true, + }, + { + name: "POST requests with no sec-fetch-site and and origin that matches the override are allowed", + method: "POST", + originOverride: "example.net", + host: "internal.example.foo", // Host can be changed by reverse proxies + origin: "http://example.net", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "OK") + }) + + s := &Server{ + originOverride: tt.originOverride, + } + withCSRF := s.csrfProtect(handler) + + r := httptest.NewRequest(tt.method, "http://example.com/", nil) + if tt.secFetchSite != "" { + r.Header.Set("Sec-Fetch-Site", tt.secFetchSite) + } + if tt.host != "" { + r.Host = tt.host + } + if tt.origin != "" { + r.Header.Set("Origin", tt.origin) + } + + w := httptest.NewRecorder() + withCSRF.ServeHTTP(w, r) + res := w.Result() + defer res.Body.Close() + if tt.wantError { + if res.StatusCode != http.StatusForbidden { + t.Errorf("expected status forbidden, got %v", res.StatusCode) + } + return + } + if res.StatusCode != http.StatusOK { + t.Errorf("expected status ok, got %v", res.StatusCode) + } + }) + } +} diff --git a/client/web/yarn.lock b/client/web/yarn.lock index 2c8fca5e5..7c9d9222e 100644 --- a/client/web/yarn.lock +++ b/client/web/yarn.lock @@ -1087,11 +1087,9 @@ integrity sha512-x/rqGMdzj+fWZvCOYForTghzbtqPDZ5gPwaoNGHdgDfF2QA/XZbCBp4Moo5scrkAMPhB7z26XM/AaHuIJdgauA== "@babel/runtime@^7.12.5", "@babel/runtime@^7.13.10", "@babel/runtime@^7.16.3", "@babel/runtime@^7.23.2", "@babel/runtime@^7.8.4": - version "7.23.4" - resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.23.4.tgz#36fa1d2b36db873d25ec631dcc4923fdc1cf2e2e" - integrity sha512-2Yv65nlWnWlSpe3fXEyX5i7fx5kIKo4Qbcj+hMO0odwaneFjfXw5fdum+4yL20O0QiaHpia0cYQ9xpNMqrBwHg== - dependencies: - regenerator-runtime "^0.14.0" + version "7.28.2" + resolved "https://registry.yarnpkg.com/@babel/runtime/-/runtime-7.28.2.tgz#2ae5a9d51cc583bd1f5673b3bb70d6d819682473" + integrity sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA== "@babel/template@^7.22.15": version "7.22.15" @@ -1880,12 +1878,12 @@ resolved "https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee" integrity sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ== -"@types/node@^18.16.1": - version "18.19.18" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.19.18.tgz#7526471b28828d1fef1f7e4960fb9477e6e4369c" - integrity sha512-80CP7B8y4PzZF0GWx15/gVWRrB5y/bIjNI84NK3cmQJu0WZwvmj2WMA5LcofQFVfLqqCSp545+U2LsrVzX36Zg== +"@types/node@^22.14.0": + version "22.14.0" + resolved "https://registry.yarnpkg.com/@types/node/-/node-22.14.0.tgz#d3bfa3936fef0dbacd79ea3eb17d521c628bb47e" + integrity sha512-Kmpl+z84ILoG+3T/zQFyAJsU6EPTmOCj8/2+83fSN6djd6I4o7uOuGIH6vq3PrjY5BGitSbFuMN18j3iknubbA== dependencies: - undici-types "~5.26.4" + undici-types "~6.21.0" "@types/parse-json@^4.0.0": version "4.0.2" @@ -2450,6 +2448,14 @@ cac@^6.7.14: resolved "https://registry.yarnpkg.com/cac/-/cac-6.7.14.tgz#804e1e6f506ee363cb0e3ccbb09cad5dd9870959" integrity sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ== +call-bind-apply-helpers@^1.0.1, call-bind-apply-helpers@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz#4b5428c222be985d79c3d82657479dbe0b59b2d6" + integrity sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ== + dependencies: + es-errors "^1.3.0" + function-bind "^1.1.2" + call-bind@^1.0.0, call-bind@^1.0.2, call-bind@^1.0.4, call-bind@^1.0.5: version "1.0.5" resolved "https://registry.yarnpkg.com/call-bind/-/call-bind-1.0.5.tgz#6fa2b7845ce0ea49bf4d8b9ef64727a2c2e2e513" @@ -2767,6 +2773,15 @@ dot-case@^3.0.4: no-case "^3.0.4" tslib "^2.0.3" +dunder-proto@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/dunder-proto/-/dunder-proto-1.0.1.tgz#d7ae667e1dc83482f8b70fd0f6eefc50da30f58a" + integrity sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A== + dependencies: + call-bind-apply-helpers "^1.0.1" + es-errors "^1.3.0" + gopd "^1.2.0" + electron-to-chromium@^1.4.535: version "1.4.596" resolved "https://registry.yarnpkg.com/electron-to-chromium/-/electron-to-chromium-1.4.596.tgz#6752d1aa795d942d49dfc5d3764d6ea283fab1d7" @@ -2834,6 +2849,16 @@ es-abstract@^1.22.1: unbox-primitive "^1.0.2" which-typed-array "^1.1.13" +es-define-property@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/es-define-property/-/es-define-property-1.0.1.tgz#983eb2f9a6724e9303f61addf011c72e09e0b0fa" + integrity sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g== + +es-errors@^1.3.0: + version "1.3.0" + resolved "https://registry.yarnpkg.com/es-errors/-/es-errors-1.3.0.tgz#05f75a25dab98e4fb1dcd5e1472c0546d5057c8f" + integrity sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw== + es-iterator-helpers@^1.0.12, es-iterator-helpers@^1.0.15: version "1.0.15" resolved "https://registry.yarnpkg.com/es-iterator-helpers/-/es-iterator-helpers-1.0.15.tgz#bd81d275ac766431d19305923707c3efd9f1ae40" @@ -2854,6 +2879,13 @@ es-iterator-helpers@^1.0.12, es-iterator-helpers@^1.0.15: iterator.prototype "^1.1.2" safe-array-concat "^1.0.1" +es-object-atoms@^1.0.0, es-object-atoms@^1.1.1: + version "1.1.1" + resolved "https://registry.yarnpkg.com/es-object-atoms/-/es-object-atoms-1.1.1.tgz#1c4f2c4837327597ce69d2ca190a7fdd172338c1" + integrity sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA== + dependencies: + es-errors "^1.3.0" + es-set-tostringtag@^2.0.1: version "2.0.2" resolved "https://registry.yarnpkg.com/es-set-tostringtag/-/es-set-tostringtag-2.0.2.tgz#11f7cc9f63376930a5f20be4915834f4bc74f9c9" @@ -2863,6 +2895,16 @@ es-set-tostringtag@^2.0.1: has-tostringtag "^1.0.0" hasown "^2.0.0" +es-set-tostringtag@^2.1.0: + version "2.1.0" + resolved "https://registry.yarnpkg.com/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz#f31dbbe0c183b00a6d26eb6325c810c0fd18bd4d" + integrity sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA== + dependencies: + es-errors "^1.3.0" + get-intrinsic "^1.2.6" + has-tostringtag "^1.0.2" + hasown "^2.0.2" + es-shim-unscopables@^1.0.0: version "1.0.2" resolved "https://registry.yarnpkg.com/es-shim-unscopables/-/es-shim-unscopables-1.0.2.tgz#1f6942e71ecc7835ed1c8a83006d8771a63a3763" @@ -3270,12 +3312,14 @@ for-each@^0.3.3: is-callable "^1.1.3" form-data@^4.0.0: - version "4.0.0" - resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.0.tgz#93919daeaf361ee529584b9b31664dc12c9fa452" - integrity sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww== + version "4.0.4" + resolved "https://registry.yarnpkg.com/form-data/-/form-data-4.0.4.tgz#784cdcce0669a9d68e94d11ac4eea98088edd2c4" + integrity sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow== dependencies: asynckit "^0.4.0" combined-stream "^1.0.8" + es-set-tostringtag "^2.1.0" + hasown "^2.0.2" mime-types "^2.1.12" fraction.js@^4.2.0: @@ -3333,11 +3377,35 @@ get-intrinsic@^1.0.2, get-intrinsic@^1.1.1, get-intrinsic@^1.1.3, get-intrinsic@ has-symbols "^1.0.3" hasown "^2.0.0" +get-intrinsic@^1.2.6: + version "1.3.0" + resolved "https://registry.yarnpkg.com/get-intrinsic/-/get-intrinsic-1.3.0.tgz#743f0e3b6964a93a5491ed1bffaae054d7f98d01" + integrity sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ== + dependencies: + call-bind-apply-helpers "^1.0.2" + es-define-property "^1.0.1" + es-errors "^1.3.0" + es-object-atoms "^1.1.1" + function-bind "^1.1.2" + get-proto "^1.0.1" + gopd "^1.2.0" + has-symbols "^1.1.0" + hasown "^2.0.2" + math-intrinsics "^1.1.0" + get-nonce@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/get-nonce/-/get-nonce-1.0.1.tgz#fdf3f0278073820d2ce9426c18f07481b1e0cdf3" integrity sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q== +get-proto@^1.0.1: + version "1.0.1" + resolved "https://registry.yarnpkg.com/get-proto/-/get-proto-1.0.1.tgz#150b3f2743869ef3e851ec0c49d15b1d14d00ee1" + integrity sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g== + dependencies: + dunder-proto "^1.0.1" + es-object-atoms "^1.0.0" + get-stream@^8.0.1: version "8.0.1" resolved "https://registry.yarnpkg.com/get-stream/-/get-stream-8.0.1.tgz#def9dfd71742cd7754a7761ed43749a27d02eca2" @@ -3437,6 +3505,11 @@ gopd@^1.0.1: dependencies: get-intrinsic "^1.1.3" +gopd@^1.2.0: + version "1.2.0" + resolved "https://registry.yarnpkg.com/gopd/-/gopd-1.2.0.tgz#89f56b8217bdbc8802bd299df6d7f1081d7e51a1" + integrity sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg== + graphemer@^1.4.0: version "1.4.0" resolved "https://registry.yarnpkg.com/graphemer/-/graphemer-1.4.0.tgz#fb2f1d55e0e3a1849aeffc90c4fa0dd53a0e66c6" @@ -3474,6 +3547,11 @@ has-symbols@^1.0.2, has-symbols@^1.0.3: resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.3.tgz#bb7b2c4349251dce87b125f7bdf874aa7c8b39f8" integrity sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A== +has-symbols@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.1.0.tgz#fc9c6a783a084951d0b971fe1018de813707a338" + integrity sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ== + has-tostringtag@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/has-tostringtag/-/has-tostringtag-1.0.0.tgz#7e133818a7d394734f941e73c3d3f9291e658b25" @@ -3481,6 +3559,13 @@ has-tostringtag@^1.0.0: dependencies: has-symbols "^1.0.2" +has-tostringtag@^1.0.2: + version "1.0.2" + resolved "https://registry.yarnpkg.com/has-tostringtag/-/has-tostringtag-1.0.2.tgz#2cdc42d40bef2e5b4eeab7c01a73c54ce7ab5abc" + integrity sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw== + dependencies: + has-symbols "^1.0.3" + hasown@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/hasown/-/hasown-2.0.0.tgz#f4c513d454a57b7c7e1650778de226b11700546c" @@ -3488,6 +3573,13 @@ hasown@^2.0.0: dependencies: function-bind "^1.1.2" +hasown@^2.0.2: + version "2.0.2" + resolved "https://registry.yarnpkg.com/hasown/-/hasown-2.0.2.tgz#003eaf91be7adc372e84ec59dc37252cedb80003" + integrity sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ== + dependencies: + function-bind "^1.1.2" + html-encoding-sniffer@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/html-encoding-sniffer/-/html-encoding-sniffer-4.0.0.tgz#696df529a7cfd82446369dc5193e590a3735b448" @@ -3992,6 +4084,11 @@ magic-string@^0.30.5: dependencies: "@jridgewell/sourcemap-codec" "^1.4.15" +math-intrinsics@^1.1.0: + version "1.1.0" + resolved "https://registry.yarnpkg.com/math-intrinsics/-/math-intrinsics-1.1.0.tgz#a0dd74be81e2aa5c2f27e65ce283605ee4e2b7f9" + integrity sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g== + merge-stream@^2.0.0: version "2.0.0" resolved "https://registry.yarnpkg.com/merge-stream/-/merge-stream-2.0.0.tgz#52823629a14dd00c9770fb6ad47dc6310f2c1f60" @@ -4543,11 +4640,6 @@ regenerate@^1.4.2: resolved "https://registry.yarnpkg.com/regenerate/-/regenerate-1.4.2.tgz#b9346d8827e8f5a32f7ba29637d398b69014848a" integrity sha512-zrceR/XhGYU/d/opr2EKO7aRHUeiBI8qjtfHqADTwZd6Szfy16la6kqD0MIUs5z5hx6AaKa+PixpPrR289+I0A== -regenerator-runtime@^0.14.0: - version "0.14.0" - resolved "https://registry.yarnpkg.com/regenerator-runtime/-/regenerator-runtime-0.14.0.tgz#5e19d68eb12d486f797e15a3c6a918f7cec5eb45" - integrity sha512-srw17NI0TUWHuGa5CFGGmhfNIeja30WMBfbslPNhf6JrqQlLN5gcrvig1oqPxiVaXb0oW0XRKtH6Nngs5lKCIA== - regenerator-transform@^0.15.2: version "0.15.2" resolved "https://registry.yarnpkg.com/regenerator-transform/-/regenerator-transform-0.15.2.tgz#5bbae58b522098ebdf09bca2f83838929001c7a4" @@ -5124,10 +5216,10 @@ unbox-primitive@^1.0.2: has-symbols "^1.0.3" which-boxed-primitive "^1.0.2" -undici-types@~5.26.4: - version "5.26.5" - resolved "https://registry.yarnpkg.com/undici-types/-/undici-types-5.26.5.tgz#bcd539893d00b56e964fd2657a4866b221a65617" - integrity sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA== +undici-types@~6.21.0: + version "6.21.0" + resolved "https://registry.yarnpkg.com/undici-types/-/undici-types-6.21.0.tgz#691d00af3909be93a7faa13be61b3a5b50ef12cb" + integrity sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ== unicode-canonical-property-names-ecmascript@^2.0.0: version "2.0.0" diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 67edce05b..3a0a8d03e 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -27,11 +27,11 @@ import ( "strconv" "strings" - "github.com/google/uuid" - "tailscale.com/clientupdate/distsign" + "tailscale.com/feature" + "tailscale.com/hostinfo" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/util/cmpver" - "tailscale.com/util/winutil" "tailscale.com/version" "tailscale.com/version/distro" ) @@ -172,6 +172,12 @@ func NewUpdater(args Arguments) (*Updater, error) { type updateFunction func() error func (up *Updater) getUpdateFunction() (fn updateFunction, canAutoUpdate bool) { + hi := hostinfo.New() + // We don't know how to update custom tsnet binaries, it's up to the user. + if hi.Package == "tsnet" { + return nil, false + } + switch runtime.GOOS { case "windows": return up.updateWindows, true @@ -245,9 +251,17 @@ func (up *Updater) getUpdateFunction() (fn updateFunction, canAutoUpdate bool) { return nil, false } -// CanAutoUpdate reports whether auto-updating via the clientupdate package +var canAutoUpdateCache lazy.SyncValue[bool] + +func init() { + feature.HookCanAutoUpdate.Set(canAutoUpdate) +} + +// canAutoUpdate reports whether auto-updating via the clientupdate package // is supported for the current os/distro. -func CanAutoUpdate() bool { +func canAutoUpdate() bool { return canAutoUpdateCache.Get(canAutoUpdateUncached) } + +func canAutoUpdateUncached() bool { if version.IsMacSysExt() { // Macsys uses Sparkle for auto-updates, which doesn't have an update // function in this package. @@ -404,13 +418,13 @@ func parseSynoinfo(path string) (string, error) { // Extract the CPU in the middle (88f6282 in the above example). s := bufio.NewScanner(f) for s.Scan() { - l := s.Text() - if !strings.HasPrefix(l, "unique=") { + line := s.Text() + if !strings.HasPrefix(line, "unique=") { continue } - parts := strings.SplitN(l, "_", 3) + parts := strings.SplitN(line, "_", 3) if len(parts) != 3 { - return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, l) + return "", fmt.Errorf(`malformed %q: found %q, expected format like 'unique="synology_$cpu_$model'`, path, line) } return parts[1], nil } @@ -756,164 +770,6 @@ func (up *Updater) updateMacAppStore() error { return nil } -const ( - // winMSIEnv is the environment variable that, if set, is the MSI file for - // the update command to install. It's passed like this so we can stop the - // tailscale.exe process from running before the msiexec process runs and - // tries to overwrite ourselves. - winMSIEnv = "TS_UPDATE_WIN_MSI" - // winExePathEnv is the environment variable that is set along with - // winMSIEnv and carries the full path of the calling tailscale.exe binary. - // It is used to re-launch the GUI process (tailscale-ipn.exe) after - // install is complete. - winExePathEnv = "TS_UPDATE_WIN_EXE_PATH" -) - -var ( - verifyAuthenticode func(string) error // set non-nil only on Windows - markTempFileFunc func(string) error // set non-nil only on Windows -) - -func (up *Updater) updateWindows() error { - if msi := os.Getenv(winMSIEnv); msi != "" { - // stdout/stderr from this part of the install could be lost since the - // parent tailscaled is replaced. Create a temp log file to have some - // output to debug with in case update fails. - close, err := up.switchOutputToFile() - if err != nil { - up.Logf("failed to create log file for installation: %v; proceeding with existing outputs", err) - } else { - defer close.Close() - } - - up.Logf("installing %v ...", msi) - if err := up.installMSI(msi); err != nil { - up.Logf("MSI install failed: %v", err) - return err - } - - up.Logf("success.") - return nil - } - - if !winutil.IsCurrentProcessElevated() { - return errors.New(`update must be run as Administrator - -you can run the command prompt as Administrator one of these ways: -* right-click cmd.exe, select 'Run as administrator' -* press Windows+x, then press a -* press Windows+r, type in "cmd", then press Ctrl+Shift+Enter`) - } - ver, err := requestedTailscaleVersion(up.Version, up.Track) - if err != nil { - return err - } - arch := runtime.GOARCH - if arch == "386" { - arch = "x86" - } - if !up.confirm(ver) { - return nil - } - - tsDir := filepath.Join(os.Getenv("ProgramData"), "Tailscale") - msiDir := filepath.Join(tsDir, "MSICache") - if fi, err := os.Stat(tsDir); err != nil { - return fmt.Errorf("expected %s to exist, got stat error: %w", tsDir, err) - } else if !fi.IsDir() { - return fmt.Errorf("expected %s to be a directory; got %v", tsDir, fi.Mode()) - } - if err := os.MkdirAll(msiDir, 0700); err != nil { - return err - } - up.cleanupOldDownloads(filepath.Join(msiDir, "*.msi")) - pkgsPath := fmt.Sprintf("%s/tailscale-setup-%s-%s.msi", up.Track, ver, arch) - msiTarget := filepath.Join(msiDir, path.Base(pkgsPath)) - if err := up.downloadURLToFile(pkgsPath, msiTarget); err != nil { - return err - } - - up.Logf("verifying MSI authenticode...") - if err := verifyAuthenticode(msiTarget); err != nil { - return fmt.Errorf("authenticode verification of %s failed: %w", msiTarget, err) - } - up.Logf("authenticode verification succeeded") - - up.Logf("making tailscale.exe copy to switch to...") - up.cleanupOldDownloads(filepath.Join(os.TempDir(), "tailscale-updater-*.exe")) - selfOrig, selfCopy, err := makeSelfCopy() - if err != nil { - return err - } - defer os.Remove(selfCopy) - up.Logf("running tailscale.exe copy for final install...") - - cmd := exec.Command(selfCopy, "update") - cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget, winExePathEnv+"="+selfOrig) - cmd.Stdout = up.Stderr - cmd.Stderr = up.Stderr - cmd.Stdin = os.Stdin - if err := cmd.Start(); err != nil { - return err - } - // Once it's started, exit ourselves, so the binary is free - // to be replaced. - os.Exit(0) - panic("unreachable") -} - -func (up *Updater) switchOutputToFile() (io.Closer, error) { - var logFilePath string - exePath, err := os.Executable() - if err != nil { - logFilePath = filepath.Join(os.TempDir(), "tailscale-updater.log") - } else { - logFilePath = strings.TrimSuffix(exePath, ".exe") + ".log" - } - - up.Logf("writing update output to %q", logFilePath) - logFile, err := os.Create(logFilePath) - if err != nil { - return nil, err - } - - up.Logf = func(m string, args ...any) { - fmt.Fprintf(logFile, m+"\n", args...) - } - up.Stdout = logFile - up.Stderr = logFile - return logFile, nil -} - -func (up *Updater) installMSI(msi string) error { - var err error - for tries := 0; tries < 2; tries++ { - cmd := exec.Command("msiexec.exe", "/i", filepath.Base(msi), "/quiet", "/norestart", "/qn") - cmd.Dir = filepath.Dir(msi) - cmd.Stdout = up.Stdout - cmd.Stderr = up.Stderr - cmd.Stdin = os.Stdin - err = cmd.Run() - if err == nil { - break - } - up.Logf("Install attempt failed: %v", err) - uninstallVersion := up.currentVersion - if v := os.Getenv("TS_DEBUG_UNINSTALL_VERSION"); v != "" { - uninstallVersion = v - } - // Assume it's a downgrade, which msiexec won't permit. Uninstall our current version first. - up.Logf("Uninstalling current version %q for downgrade...", uninstallVersion) - cmd = exec.Command("msiexec.exe", "/x", msiUUIDForVersion(uninstallVersion), "/norestart", "/qn") - cmd.Stdout = up.Stdout - cmd.Stderr = up.Stderr - cmd.Stdin = os.Stdin - err = cmd.Run() - up.Logf("msiexec uninstall: %v", err) - } - return err -} - // cleanupOldDownloads removes all files matching glob (see filepath.Glob). // Only regular files are removed, so the glob must match specific files and // not directories. @@ -938,53 +794,6 @@ func (up *Updater) cleanupOldDownloads(glob string) { } } -func msiUUIDForVersion(ver string) string { - arch := runtime.GOARCH - if arch == "386" { - arch = "x86" - } - track, err := versionToTrack(ver) - if err != nil { - track = UnstableTrack - } - msiURL := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale-setup-%s-%s.msi", track, ver, arch) - return "{" + strings.ToUpper(uuid.NewSHA1(uuid.NameSpaceURL, []byte(msiURL)).String()) + "}" -} - -func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { - selfExe, err := os.Executable() - if err != nil { - return "", "", err - } - f, err := os.Open(selfExe) - if err != nil { - return "", "", err - } - defer f.Close() - f2, err := os.CreateTemp("", "tailscale-updater-*.exe") - if err != nil { - return "", "", err - } - if f := markTempFileFunc; f != nil { - if err := f(f2.Name()); err != nil { - return "", "", err - } - } - if _, err := io.Copy(f2, f); err != nil { - f2.Close() - return "", "", err - } - return selfExe, f2.Name(), f2.Close() -} - -func (up *Updater) downloadURLToFile(pathSrc, fileDst string) (ret error) { - c, err := distsign.NewClient(up.Logf, up.PkgsAddr) - if err != nil { - return err - } - return c.Download(context.Background(), pathSrc, fileDst) -} - func (up *Updater) updateFreeBSD() (err error) { if up.Version != "" { return errors.New("installing a specific version on FreeBSD is not supported") diff --git a/clientupdate/clientupdate_downloads.go b/clientupdate/clientupdate_downloads.go new file mode 100644 index 000000000..18d3176b4 --- /dev/null +++ b/clientupdate/clientupdate_downloads.go @@ -0,0 +1,20 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (linux && !android) || windows + +package clientupdate + +import ( + "context" + + "tailscale.com/clientupdate/distsign" +) + +func (up *Updater) downloadURLToFile(pathSrc, fileDst string) (ret error) { + c, err := distsign.NewClient(up.Logf, up.PkgsAddr) + if err != nil { + return err + } + return c.Download(context.Background(), pathSrc, fileDst) +} diff --git a/clientupdate/clientupdate_not_downloads.go b/clientupdate/clientupdate_not_downloads.go new file mode 100644 index 000000000..057b4f2cd --- /dev/null +++ b/clientupdate/clientupdate_not_downloads.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !((linux && !android) || windows) + +package clientupdate + +func (up *Updater) downloadURLToFile(pathSrc, fileDst string) (ret error) { + panic("unreachable") +} diff --git a/clientupdate/clientupdate_notwindows.go b/clientupdate/clientupdate_notwindows.go new file mode 100644 index 000000000..edadc210c --- /dev/null +++ b/clientupdate/clientupdate_notwindows.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package clientupdate + +func (up *Updater) updateWindows() error { + panic("unreachable") +} diff --git a/clientupdate/clientupdate_windows.go b/clientupdate/clientupdate_windows.go index 2f6899a60..5faeda6dd 100644 --- a/clientupdate/clientupdate_windows.go +++ b/clientupdate/clientupdate_windows.go @@ -7,13 +7,59 @@ package clientupdate import ( + "errors" + "fmt" + "io" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "strings" + "time" + + "github.com/google/uuid" "golang.org/x/sys/windows" + "tailscale.com/util/winutil" "tailscale.com/util/winutil/authenticode" ) -func init() { - markTempFileFunc = markTempFileWindows - verifyAuthenticode = verifyTailscale +const ( + // winMSIEnv is the environment variable that, if set, is the MSI file for + // the update command to install. It's passed like this so we can stop the + // tailscale.exe process from running before the msiexec process runs and + // tries to overwrite ourselves. + winMSIEnv = "TS_UPDATE_WIN_MSI" + // winVersionEnv is the environment variable that is set along with + // winMSIEnv and carries the version of tailscale that is being installed. + // It is used for logging purposes. + winVersionEnv = "TS_UPDATE_WIN_VERSION" + // updaterPrefix is the prefix for the temporary executable created by [makeSelfCopy]. + updaterPrefix = "tailscale-updater" +) + +func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { + selfExe, err := os.Executable() + if err != nil { + return "", "", err + } + f, err := os.Open(selfExe) + if err != nil { + return "", "", err + } + defer f.Close() + f2, err := os.CreateTemp("", updaterPrefix+"-*.exe") + if err != nil { + return "", "", err + } + if err := markTempFileWindows(f2.Name()); err != nil { + return "", "", err + } + if _, err := io.Copy(f2, f); err != nil { + f2.Close() + return "", "", err + } + return selfExe, f2.Name(), f2.Close() } func markTempFileWindows(name string) error { @@ -23,6 +69,236 @@ func markTempFileWindows(name string) error { const certSubjectTailscale = "Tailscale Inc." -func verifyTailscale(path string) error { +func verifyAuthenticode(path string) error { return authenticode.Verify(path, certSubjectTailscale) } + +func isTSGUIPresent() bool { + us, err := os.Executable() + if err != nil { + return false + } + + tsgui := filepath.Join(filepath.Dir(us), "tsgui.dll") + _, err = os.Stat(tsgui) + return err == nil +} + +func (up *Updater) updateWindows() error { + if msi := os.Getenv(winMSIEnv); msi != "" { + // stdout/stderr from this part of the install could be lost since the + // parent tailscaled is replaced. Create a temp log file to have some + // output to debug with in case update fails. + close, err := up.switchOutputToFile() + if err != nil { + up.Logf("failed to create log file for installation: %v; proceeding with existing outputs", err) + } else { + defer close.Close() + } + + up.Logf("installing %v ...", msi) + if err := up.installMSI(msi); err != nil { + up.Logf("MSI install failed: %v", err) + return err + } + + up.Logf("success.") + return nil + } + + if !winutil.IsCurrentProcessElevated() { + return errors.New(`update must be run as Administrator + +you can run the command prompt as Administrator one of these ways: +* right-click cmd.exe, select 'Run as administrator' +* press Windows+x, then press a +* press Windows+r, type in "cmd", then press Ctrl+Shift+Enter`) + } + ver, err := requestedTailscaleVersion(up.Version, up.Track) + if err != nil { + return err + } + arch := runtime.GOARCH + if arch == "386" { + arch = "x86" + } + if !up.confirm(ver) { + return nil + } + + tsDir := filepath.Join(os.Getenv("ProgramData"), "Tailscale") + msiDir := filepath.Join(tsDir, "MSICache") + if fi, err := os.Stat(tsDir); err != nil { + return fmt.Errorf("expected %s to exist, got stat error: %w", tsDir, err) + } else if !fi.IsDir() { + return fmt.Errorf("expected %s to be a directory; got %v", tsDir, fi.Mode()) + } + if err := os.MkdirAll(msiDir, 0700); err != nil { + return err + } + up.cleanupOldDownloads(filepath.Join(msiDir, "*.msi")) + + qualifiers := []string{ver, arch} + // TODO(aaron): Temporary hack so autoupdate still works on winui builds; + // remove when we enable winui by default on the unstable track. + if isTSGUIPresent() { + qualifiers = append(qualifiers, "winui") + } + + pkgsPath := fmt.Sprintf("%s/tailscale-setup-%s.msi", up.Track, strings.Join(qualifiers, "-")) + msiTarget := filepath.Join(msiDir, path.Base(pkgsPath)) + if err := up.downloadURLToFile(pkgsPath, msiTarget); err != nil { + return err + } + + up.Logf("verifying MSI authenticode...") + if err := verifyAuthenticode(msiTarget); err != nil { + return fmt.Errorf("authenticode verification of %s failed: %w", msiTarget, err) + } + up.Logf("authenticode verification succeeded") + + up.Logf("making tailscale.exe copy to switch to...") + up.cleanupOldDownloads(filepath.Join(os.TempDir(), updaterPrefix+"-*.exe")) + _, selfCopy, err := makeSelfCopy() + if err != nil { + return err + } + defer os.Remove(selfCopy) + up.Logf("running tailscale.exe copy for final install...") + + cmd := exec.Command(selfCopy, "update") + cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget, winVersionEnv+"="+ver) + cmd.Stdout = up.Stderr + cmd.Stderr = up.Stderr + cmd.Stdin = os.Stdin + if err := cmd.Start(); err != nil { + return err + } + // Once it's started, exit ourselves, so the binary is free + // to be replaced. + os.Exit(0) + panic("unreachable") +} + +func (up *Updater) installMSI(msi string) error { + var err error + for tries := 0; tries < 2; tries++ { + // msiexec.exe requires exclusive access to the log file, so create a dedicated one for each run. + installLogPath := up.startNewLogFile("tailscale-installer", os.Getenv(winVersionEnv)) + up.Logf("Install log: %s", installLogPath) + cmd := exec.Command("msiexec.exe", "/i", filepath.Base(msi), "/quiet", "/norestart", "/qn", "/L*v", installLogPath) + cmd.Dir = filepath.Dir(msi) + cmd.Stdout = up.Stdout + cmd.Stderr = up.Stderr + cmd.Stdin = os.Stdin + err = cmd.Run() + switch err := err.(type) { + case nil: + // Success. + return nil + case *exec.ExitError: + // For possible error codes returned by Windows Installer, see + // https://web.archive.org/web/20250409144914/https://learn.microsoft.com/en-us/windows/win32/msi/error-codes + switch windows.Errno(err.ExitCode()) { + case windows.ERROR_SUCCESS_REBOOT_REQUIRED: + // In most cases, updating Tailscale should not require a reboot. + // If it does, it might be because we failed to close the GUI + // and the installer couldn't replace its executable. + // The old GUI will continue to run until the next reboot. + // Not ideal, but also not a retryable error. + up.Logf("[unexpected] reboot required") + return nil + case windows.ERROR_SUCCESS_REBOOT_INITIATED: + // Same as above, but perhaps the device is configured to prompt + // the user to reboot and the user has chosen to reboot now. + up.Logf("[unexpected] reboot initiated") + return nil + case windows.ERROR_INSTALL_ALREADY_RUNNING: + // The Windows Installer service is currently busy. + // It could be our own install initiated by user/MDM/GP, another MSI install or perhaps a Windows Update install. + // Anyway, we can't do anything about it right now. The user (or tailscaled) can retry later. + // Retrying now will likely fail, and is risky since we might uninstall the current version + // and then fail to install the new one, leaving the user with no Tailscale at all. + // + // TODO(nickkhyl,awly): should we check if this is actually a downgrade before uninstalling the current version? + // Also, maybe keep retrying the install longer if we uninstalled the current version due to a failed install attempt? + up.Logf("another installation is already in progress") + return err + } + default: + // Everything else is a retryable error. + } + + up.Logf("Install attempt failed: %v", err) + uninstallVersion := up.currentVersion + if v := os.Getenv("TS_DEBUG_UNINSTALL_VERSION"); v != "" { + uninstallVersion = v + } + uninstallLogPath := up.startNewLogFile("tailscale-uninstaller", uninstallVersion) + // Assume it's a downgrade, which msiexec won't permit. Uninstall our current version first. + up.Logf("Uninstalling current version %q for downgrade...", uninstallVersion) + up.Logf("Uninstall log: %s", uninstallLogPath) + cmd = exec.Command("msiexec.exe", "/x", msiUUIDForVersion(uninstallVersion), "/norestart", "/qn", "/L*v", uninstallLogPath) + cmd.Stdout = up.Stdout + cmd.Stderr = up.Stderr + cmd.Stdin = os.Stdin + err = cmd.Run() + up.Logf("msiexec uninstall: %v", err) + } + return err +} + +func msiUUIDForVersion(ver string) string { + arch := runtime.GOARCH + if arch == "386" { + arch = "x86" + } + track, err := versionToTrack(ver) + if err != nil { + track = UnstableTrack + } + msiURL := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale-setup-%s-%s.msi", track, ver, arch) + return "{" + strings.ToUpper(uuid.NewSHA1(uuid.NameSpaceURL, []byte(msiURL)).String()) + "}" +} + +func (up *Updater) switchOutputToFile() (io.Closer, error) { + var logFilePath string + exePath, err := os.Executable() + if err != nil { + logFilePath = up.startNewLogFile(updaterPrefix, os.Getenv(winVersionEnv)) + } else { + // Use the same suffix as the self-copy executable. + suffix := strings.TrimSuffix(strings.TrimPrefix(filepath.Base(exePath), updaterPrefix), ".exe") + logFilePath = up.startNewLogFile(updaterPrefix, os.Getenv(winVersionEnv)+suffix) + } + + up.Logf("writing update output to: %s", logFilePath) + logFile, err := os.Create(logFilePath) + if err != nil { + return nil, err + } + + up.Logf = func(m string, args ...any) { + fmt.Fprintf(logFile, m+"\n", args...) + } + up.Stdout = logFile + up.Stderr = logFile + return logFile, nil +} + +// startNewLogFile returns a name for a new log file. +// It cleans up any old log files with the same baseNamePrefix. +func (up *Updater) startNewLogFile(baseNamePrefix, baseNameSuffix string) string { + baseName := fmt.Sprintf("%s-%s-%s.log", baseNamePrefix, + time.Now().Format("20060102T150405"), baseNameSuffix) + + dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") + if err := os.MkdirAll(dir, 0700); err != nil { + up.Logf("failed to create log directory: %v", err) + return filepath.Join(os.TempDir(), baseName) + } + + // TODO(nickkhyl): preserve up to N old log files? + up.cleanupOldDownloads(filepath.Join(dir, baseNamePrefix+"-*.log")) + return filepath.Join(dir, baseName) +} diff --git a/clientupdate/distsign/distsign.go b/clientupdate/distsign/distsign.go index eba4b9267..270ee4c1f 100644 --- a/clientupdate/distsign/distsign.go +++ b/clientupdate/distsign/distsign.go @@ -55,7 +55,7 @@ import ( "github.com/hdevalence/ed25519consensus" "golang.org/x/crypto/blake2s" - "tailscale.com/net/tshttpproxy" + "tailscale.com/feature" "tailscale.com/types/logger" "tailscale.com/util/httpm" "tailscale.com/util/must" @@ -330,7 +330,7 @@ func fetch(url string, limit int64) ([]byte, error) { // limit bytes. On success, the returned value is a BLAKE2s hash of the file. func (c *Client) download(ctx context.Context, url, dst string, limit int64) ([]byte, int64, error) { tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Proxy = tshttpproxy.ProxyFromEnvironment + tr.Proxy = feature.HookProxyFromEnvironment.GetOrNil() defer tr.CloseIdleConnections() hc := &http.Client{Transport: tr} diff --git a/cmd/addlicense/main.go b/cmd/addlicense/main.go index a8fd9dd4a..1cd1b0f19 100644 --- a/cmd/addlicense/main.go +++ b/cmd/addlicense/main.go @@ -18,12 +18,12 @@ var ( ) func usage() { - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` usage: addlicense -file FILE `[1:]) flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` addlicense adds a Tailscale license to the beginning of file. It is intended for use with 'go generate', so it also runs a subcommand, diff --git a/cmd/checkmetrics/checkmetrics.go b/cmd/checkmetrics/checkmetrics.go new file mode 100644 index 000000000..fb9e8ab4c --- /dev/null +++ b/cmd/checkmetrics/checkmetrics.go @@ -0,0 +1,131 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// checkmetrics validates that all metrics in the tailscale client-metrics +// are documented in a given path or URL. +package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "time" + + "tailscale.com/ipn/store/mem" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/util/httpm" +) + +var ( + kbPath = flag.String("kb-path", "", "filepath to the client-metrics knowledge base") + kbUrl = flag.String("kb-url", "", "URL to the client-metrics knowledge base page") +) + +func main() { + flag.Parse() + if *kbPath == "" && *kbUrl == "" { + log.Fatalf("either -kb-path or -kb-url must be set") + } + + var control testcontrol.Server + ts := httptest.NewServer(&control) + defer ts.Close() + + td, err := os.MkdirTemp("", "testcontrol") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(td) + + // tsnet is used not used as a Tailscale client, but as a way to + // boot up Tailscale, have all the metrics registered, and then + // verifiy that all the metrics are documented. + tsn := &tsnet.Server{ + Dir: td, + Store: new(mem.Store), + UserLogf: log.Printf, + Ephemeral: true, + ControlURL: ts.URL, + } + if err := tsn.Start(); err != nil { + log.Fatal(err) + } + defer tsn.Close() + + log.Printf("checking that all metrics are documented, looking for: %s", tsn.Sys().UserMetricsRegistry().MetricNames()) + + if *kbPath != "" { + kb, err := readKB(*kbPath) + if err != nil { + log.Fatalf("reading kb: %v", err) + } + missing := undocumentedMetrics(kb, tsn.Sys().UserMetricsRegistry().MetricNames()) + + if len(missing) > 0 { + log.Fatalf("found undocumented metrics in %q: %v", *kbPath, missing) + } + } + + if *kbUrl != "" { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + kb, err := getKB(ctx, *kbUrl) + if err != nil { + log.Fatalf("getting kb: %v", err) + } + missing := undocumentedMetrics(kb, tsn.Sys().UserMetricsRegistry().MetricNames()) + + if len(missing) > 0 { + log.Fatalf("found undocumented metrics in %q: %v", *kbUrl, missing) + } + } +} + +func readKB(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("reading file: %w", err) + } + + return string(b), nil +} + +func getKB(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, httpm.GET, url, nil) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("getting kb page: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading body: %w", err) + } + return string(b), nil +} + +func undocumentedMetrics(b string, metrics []string) []string { + var missing []string + for _, metric := range metrics { + if !strings.Contains(b, metric) { + missing = append(missing, metric) + } + } + return missing +} diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index a1ffc30fe..a81bd10bd 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -121,7 +121,12 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { continue } if !hasBasicUnderlying(ft) { - writef("dst.%s = *src.%s.Clone()", fname, fname) + // don't dereference if the underlying type is an interface + if _, isInterface := ft.Underlying().(*types.Interface); isInterface { + writef("if src.%s != nil { dst.%s = src.%s.Clone() }", fname, fname, fname) + } else { + writef("dst.%s = *src.%s.Clone()", fname, fname) + } continue } } @@ -136,13 +141,13 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) if codegen.ContainsPointers(ptr.Elem()) { if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { - it.Import("tailscale.com/types/ptr") + it.Import("", "tailscale.com/types/ptr") writef("\tdst.%s[i] = ptr.To((*src.%s[i]).Clone())", fname, fname) } else { writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) } } else { - it.Import("tailscale.com/types/ptr") + it.Import("", "tailscale.com/types/ptr") writef("\tdst.%s[i] = ptr.To(*src.%s[i])", fname, fname) } writef("}") @@ -165,7 +170,7 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("dst.%s = src.%s.Clone()", fname, fname) continue } - it.Import("tailscale.com/types/ptr") + it.Import("", "tailscale.com/types/ptr") writef("if dst.%s != nil {", fname) if _, isIface := base.Underlying().(*types.Interface); isIface && hasPtrs { writef("\tdst.%s = ptr.To((*src.%s).Clone())", fname, fname) @@ -187,45 +192,34 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) writef("\t}") writef("}") - } else if codegen.ContainsPointers(elem) { + } else if codegen.IsViewType(elem) || !codegen.ContainsPointers(elem) { + // If the map values are view types (which are + // immutable and don't need cloning) or don't + // themselves contain pointers, we can just + // clone the map itself. + it.Import("", "maps") + writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) + } else { + // Otherwise we need to clone each element of + // the map using our recursive helper. writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) writef("\tfor k, v := range src.%s {", fname) - switch elem := elem.Underlying().(type) { - case *types.Pointer: - writef("\t\tif v == nil { dst.%s[k] = nil } else {", fname) - if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) { - if _, isIface := base.(*types.Interface); isIface { - it.Import("tailscale.com/types/ptr") - writef("\t\t\tdst.%s[k] = ptr.To((*v).Clone())", fname) - } else { - writef("\t\t\tdst.%s[k] = v.Clone()", fname) - } - } else { - it.Import("tailscale.com/types/ptr") - writef("\t\t\tdst.%s[k] = ptr.To(*v)", fname) - } - writef("}") - case *types.Interface: - if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil { - if _, isPtr := cloneResultType.(*types.Pointer); isPtr { - writef("\t\tdst.%s[k] = *(v.Clone())", fname) - } else { - writef("\t\tdst.%s[k] = v.Clone()", fname) - } - } else { - writef(`panic("%s (%v) does not have a Clone method")`, fname, elem) - } - default: - writef("\t\tdst.%s[k] = *(v.Clone())", fname) - } - + // Use a recursive helper here; this handles + // arbitrarily nested maps in addition to + // simpler types. + writeMapValueClone(mapValueCloneParams{ + Buf: buf, + It: it, + Elem: elem, + SrcExpr: "v", + DstExpr: fmt.Sprintf("dst.%s[k]", fname), + BaseIndent: "\t", + Depth: 1, + }) writef("\t}") writef("}") - } else { - it.Import("maps") - writef("\tdst.%s = maps.Clone(src.%s)", fname, fname) } case *types.Interface: // If ft is an interface with a "Clone() ft" method, it can be used to clone the field. @@ -266,3 +260,99 @@ func methodResultType(typ types.Type, method string) types.Type { } return sig.Results().At(0).Type() } + +type mapValueCloneParams struct { + // Buf is the buffer to write generated code to + Buf *bytes.Buffer + // It is the import tracker for managing imports. + It *codegen.ImportTracker + // Elem is the type of the map value to clone + Elem types.Type + // SrcExpr is the expression for the source value (e.g., "v", "v2", "v3") + SrcExpr string + // DstExpr is the expression for the destination (e.g., "dst.Field[k]", "dst.Field[k][k2]") + DstExpr string + // BaseIndent is the "base" indentation string for the generated code + // (i.e. 1 or more tabs). Additional indentation will be added based on + // the Depth parameter. + BaseIndent string + // Depth is the current nesting depth (1 for first level, 2 for second, etc.) + Depth int +} + +// writeMapValueClone generates code to clone a map value recursively. +// It handles arbitrary nesting of maps, pointers, and interfaces. +func writeMapValueClone(params mapValueCloneParams) { + indent := params.BaseIndent + strings.Repeat("\t", params.Depth) + writef := func(format string, args ...any) { + fmt.Fprintf(params.Buf, indent+format+"\n", args...) + } + + switch elem := params.Elem.Underlying().(type) { + case *types.Pointer: + writef("if %s == nil { %s = nil } else {", params.SrcExpr, params.DstExpr) + if base := elem.Elem().Underlying(); codegen.ContainsPointers(base) { + if _, isIface := base.(*types.Interface); isIface { + params.It.Import("", "tailscale.com/types/ptr") + writef("\t%s = ptr.To((*%s).Clone())", params.DstExpr, params.SrcExpr) + } else { + writef("\t%s = %s.Clone()", params.DstExpr, params.SrcExpr) + } + } else { + params.It.Import("", "tailscale.com/types/ptr") + writef("\t%s = ptr.To(*%s)", params.DstExpr, params.SrcExpr) + } + writef("}") + + case *types.Map: + // Recursively handle nested maps + innerElem := elem.Elem() + if codegen.IsViewType(innerElem) || !codegen.ContainsPointers(innerElem) { + // Inner map values don't need deep cloning + params.It.Import("", "maps") + writef("%s = maps.Clone(%s)", params.DstExpr, params.SrcExpr) + } else { + // Inner map values need cloning + keyType := params.It.QualifiedName(elem.Key()) + valueType := params.It.QualifiedName(innerElem) + // Generate unique variable names for nested loops based on depth + keyVar := fmt.Sprintf("k%d", params.Depth+1) + valVar := fmt.Sprintf("v%d", params.Depth+1) + + writef("if %s == nil {", params.SrcExpr) + writef("\t%s = nil", params.DstExpr) + writef("\tcontinue") + writef("}") + writef("%s = map[%s]%s{}", params.DstExpr, keyType, valueType) + writef("for %s, %s := range %s {", keyVar, valVar, params.SrcExpr) + + // Recursively generate cloning code for the nested map value + nestedDstExpr := fmt.Sprintf("%s[%s]", params.DstExpr, keyVar) + writeMapValueClone(mapValueCloneParams{ + Buf: params.Buf, + It: params.It, + Elem: innerElem, + SrcExpr: valVar, + DstExpr: nestedDstExpr, + BaseIndent: params.BaseIndent, + Depth: params.Depth + 1, + }) + + writef("}") + } + + case *types.Interface: + if cloneResultType := methodResultType(elem, "Clone"); cloneResultType != nil { + if _, isPtr := cloneResultType.(*types.Pointer); isPtr { + writef("%s = *(%s.Clone())", params.DstExpr, params.SrcExpr) + } else { + writef("%s = %s.Clone()", params.DstExpr, params.SrcExpr) + } + } else { + writef(`panic("map value (%%v) does not have a Clone method")`, elem) + } + + default: + writef("%s = *(%s.Clone())", params.DstExpr, params.SrcExpr) + } +} diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go index d8d5df3cb..754a4ac49 100644 --- a/cmd/cloner/cloner_test.go +++ b/cmd/cloner/cloner_test.go @@ -1,5 +1,6 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause + package main import ( @@ -58,3 +59,158 @@ func TestSliceContainer(t *testing.T) { }) } } + +func TestInterfaceContainer(t *testing.T) { + examples := []struct { + name string + in *clonerex.InterfaceContainer + }{ + { + name: "nil", + in: nil, + }, + { + name: "zero", + in: &clonerex.InterfaceContainer{}, + }, + { + name: "with_interface", + in: &clonerex.InterfaceContainer{ + Interface: &clonerex.CloneableImpl{Value: 42}, + }, + }, + { + name: "with_nil_interface", + in: &clonerex.InterfaceContainer{ + Interface: nil, + }, + }, + } + + for _, ex := range examples { + t.Run(ex.name, func(t *testing.T) { + out := ex.in.Clone() + if !reflect.DeepEqual(ex.in, out) { + t.Errorf("Clone() = %v, want %v", out, ex.in) + } + + // Verify no aliasing: modifying the clone should not affect the original + if ex.in != nil && ex.in.Interface != nil { + if impl, ok := out.Interface.(*clonerex.CloneableImpl); ok { + impl.Value = 999 + if origImpl, ok := ex.in.Interface.(*clonerex.CloneableImpl); ok { + if origImpl.Value == 999 { + t.Errorf("Clone() aliased memory with original") + } + } + } + } + }) + } +} + +func TestMapWithPointers(t *testing.T) { + num1, num2 := 42, 100 + orig := &clonerex.MapWithPointers{ + Nested: map[string]*int{ + "foo": &num1, + "bar": &num2, + }, + WithCloneMethod: map[string]*clonerex.SliceContainer{ + "container1": {Slice: []*int{&num1, &num2}}, + "container2": {Slice: []*int{&num1}}, + }, + CloneInterface: map[string]clonerex.Cloneable{ + "impl1": &clonerex.CloneableImpl{Value: 123}, + "impl2": &clonerex.CloneableImpl{Value: 456}, + }, + } + + cloned := orig.Clone() + if !reflect.DeepEqual(orig, cloned) { + t.Errorf("Clone() = %v, want %v", cloned, orig) + } + + // Mutate cloned.Nested pointer values + *cloned.Nested["foo"] = 999 + if *orig.Nested["foo"] == 999 { + t.Errorf("Clone() aliased memory in Nested: original was modified") + } + + // Mutate cloned.WithCloneMethod slice values + *cloned.WithCloneMethod["container1"].Slice[0] = 888 + if *orig.WithCloneMethod["container1"].Slice[0] == 888 { + t.Errorf("Clone() aliased memory in WithCloneMethod: original was modified") + } + + // Mutate cloned.CloneInterface values + if impl, ok := cloned.CloneInterface["impl1"].(*clonerex.CloneableImpl); ok { + impl.Value = 777 + if origImpl, ok := orig.CloneInterface["impl1"].(*clonerex.CloneableImpl); ok { + if origImpl.Value == 777 { + t.Errorf("Clone() aliased memory in CloneInterface: original was modified") + } + } + } +} + +func TestDeeplyNestedMap(t *testing.T) { + num := 123 + orig := &clonerex.DeeplyNestedMap{ + ThreeLevels: map[string]map[string]map[string]int{ + "a": { + "b": {"c": 1, "d": 2}, + "e": {"f": 3}, + }, + "g": { + "h": {"i": 4}, + }, + }, + FourLevels: map[string]map[string]map[string]map[string]*clonerex.SliceContainer{ + "l1a": { + "l2a": { + "l3a": { + "l4a": {Slice: []*int{&num}}, + "l4b": {Slice: []*int{&num, &num}}, + }, + }, + }, + }, + } + + cloned := orig.Clone() + if !reflect.DeepEqual(orig, cloned) { + t.Errorf("Clone() = %v, want %v", cloned, orig) + } + + // Mutate the clone's ThreeLevels map + cloned.ThreeLevels["a"]["b"]["c"] = 777 + if orig.ThreeLevels["a"]["b"]["c"] == 777 { + t.Errorf("Clone() aliased memory in ThreeLevels: original was modified") + } + + // Mutate the clone's FourLevels map at the deepest pointer level + *cloned.FourLevels["l1a"]["l2a"]["l3a"]["l4a"].Slice[0] = 666 + if *orig.FourLevels["l1a"]["l2a"]["l3a"]["l4a"].Slice[0] == 666 { + t.Errorf("Clone() aliased memory in FourLevels: original was modified") + } + + // Add a new top-level key to the clone's FourLevels map + newNum := 999 + cloned.FourLevels["l1b"] = map[string]map[string]map[string]*clonerex.SliceContainer{ + "l2b": { + "l3b": { + "l4c": {Slice: []*int{&newNum}}, + }, + }, + } + if _, exists := orig.FourLevels["l1b"]; exists { + t.Errorf("Clone() aliased FourLevels map: new top-level key appeared in original") + } + + // Add a new nested key to the clone's FourLevels map + cloned.FourLevels["l1a"]["l2a"]["l3a"]["l4c"] = &clonerex.SliceContainer{Slice: []*int{&newNum}} + if _, exists := orig.FourLevels["l1a"]["l2a"]["l3a"]["l4c"]; exists { + t.Errorf("Clone() aliased FourLevels map: new nested key appeared in original") + } +} diff --git a/cmd/cloner/clonerex/clonerex.go b/cmd/cloner/clonerex/clonerex.go index 96bf8a0bd..b9f6d60de 100644 --- a/cmd/cloner/clonerex/clonerex.go +++ b/cmd/cloner/clonerex/clonerex.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap // Package clonerex is an example package for the cloner tool. package clonerex @@ -9,3 +9,38 @@ package clonerex type SliceContainer struct { Slice []*int } + +// Cloneable is an interface with a Clone method. +type Cloneable interface { + Clone() Cloneable +} + +// CloneableImpl is a concrete type that implements Cloneable. +type CloneableImpl struct { + Value int +} + +func (c *CloneableImpl) Clone() Cloneable { + if c == nil { + return nil + } + return &CloneableImpl{Value: c.Value} +} + +// InterfaceContainer has a pointer to an interface field, which tests +// the special handling for interface types in the cloner. +type InterfaceContainer struct { + Interface Cloneable +} + +type MapWithPointers struct { + Nested map[string]*int + WithCloneMethod map[string]*SliceContainer + CloneInterface map[string]Cloneable +} + +// DeeplyNestedMap tests arbitrary depth of map nesting (3+ levels) +type DeeplyNestedMap struct { + ThreeLevels map[string]map[string]map[string]int + FourLevels map[string]map[string]map[string]map[string]*SliceContainer +} diff --git a/cmd/cloner/clonerex/clonerex_clone.go b/cmd/cloner/clonerex/clonerex_clone.go index e334a4e3a..13e1276c4 100644 --- a/cmd/cloner/clonerex/clonerex_clone.go +++ b/cmd/cloner/clonerex/clonerex_clone.go @@ -6,6 +6,8 @@ package clonerex import ( + "maps" + "tailscale.com/types/ptr" ) @@ -35,9 +37,133 @@ var _SliceContainerCloneNeedsRegeneration = SliceContainer(struct { Slice []*int }{}) +// Clone makes a deep copy of InterfaceContainer. +// The result aliases no memory with the original. +func (src *InterfaceContainer) Clone() *InterfaceContainer { + if src == nil { + return nil + } + dst := new(InterfaceContainer) + *dst = *src + if src.Interface != nil { + dst.Interface = src.Interface.Clone() + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _InterfaceContainerCloneNeedsRegeneration = InterfaceContainer(struct { + Interface Cloneable +}{}) + +// Clone makes a deep copy of MapWithPointers. +// The result aliases no memory with the original. +func (src *MapWithPointers) Clone() *MapWithPointers { + if src == nil { + return nil + } + dst := new(MapWithPointers) + *dst = *src + if dst.Nested != nil { + dst.Nested = map[string]*int{} + for k, v := range src.Nested { + if v == nil { + dst.Nested[k] = nil + } else { + dst.Nested[k] = ptr.To(*v) + } + } + } + if dst.WithCloneMethod != nil { + dst.WithCloneMethod = map[string]*SliceContainer{} + for k, v := range src.WithCloneMethod { + if v == nil { + dst.WithCloneMethod[k] = nil + } else { + dst.WithCloneMethod[k] = v.Clone() + } + } + } + if dst.CloneInterface != nil { + dst.CloneInterface = map[string]Cloneable{} + for k, v := range src.CloneInterface { + dst.CloneInterface[k] = v.Clone() + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _MapWithPointersCloneNeedsRegeneration = MapWithPointers(struct { + Nested map[string]*int + WithCloneMethod map[string]*SliceContainer + CloneInterface map[string]Cloneable +}{}) + +// Clone makes a deep copy of DeeplyNestedMap. +// The result aliases no memory with the original. +func (src *DeeplyNestedMap) Clone() *DeeplyNestedMap { + if src == nil { + return nil + } + dst := new(DeeplyNestedMap) + *dst = *src + if dst.ThreeLevels != nil { + dst.ThreeLevels = map[string]map[string]map[string]int{} + for k, v := range src.ThreeLevels { + if v == nil { + dst.ThreeLevels[k] = nil + continue + } + dst.ThreeLevels[k] = map[string]map[string]int{} + for k2, v2 := range v { + dst.ThreeLevels[k][k2] = maps.Clone(v2) + } + } + } + if dst.FourLevels != nil { + dst.FourLevels = map[string]map[string]map[string]map[string]*SliceContainer{} + for k, v := range src.FourLevels { + if v == nil { + dst.FourLevels[k] = nil + continue + } + dst.FourLevels[k] = map[string]map[string]map[string]*SliceContainer{} + for k2, v2 := range v { + if v2 == nil { + dst.FourLevels[k][k2] = nil + continue + } + dst.FourLevels[k][k2] = map[string]map[string]*SliceContainer{} + for k3, v3 := range v2 { + if v3 == nil { + dst.FourLevels[k][k2][k3] = nil + continue + } + dst.FourLevels[k][k2][k3] = map[string]*SliceContainer{} + for k4, v4 := range v3 { + if v4 == nil { + dst.FourLevels[k][k2][k3][k4] = nil + } else { + dst.FourLevels[k][k2][k3][k4] = v4.Clone() + } + } + } + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _DeeplyNestedMapCloneNeedsRegeneration = DeeplyNestedMap(struct { + ThreeLevels map[string]map[string]map[string]int + FourLevels map[string]map[string]map[string]map[string]*SliceContainer +}{}) + // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of SliceContainer. +// where T is one of SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap. func Clone(dst, src any) bool { switch src := src.(type) { case *SliceContainer: @@ -49,6 +175,33 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *InterfaceContainer: + switch dst := dst.(type) { + case *InterfaceContainer: + *dst = *src.Clone() + return true + case **InterfaceContainer: + *dst = src.Clone() + return true + } + case *MapWithPointers: + switch dst := dst.(type) { + case *MapWithPointers: + *dst = *src.Clone() + return true + case **MapWithPointers: + *dst = src.Clone() + return true + } + case *DeeplyNestedMap: + switch dst := dst.(type) { + case *DeeplyNestedMap: + *dst = *src.Clone() + return true + case **DeeplyNestedMap: + *dst = src.Clone() + return true + } } return false } diff --git a/cmd/containerboot/services.go b/cmd/containerboot/egressservices.go similarity index 69% rename from cmd/containerboot/services.go rename to cmd/containerboot/egressservices.go index e46c7c015..fe835a69e 100644 --- a/cmd/containerboot/services.go +++ b/cmd/containerboot/egressservices.go @@ -11,24 +11,33 @@ import ( "errors" "fmt" "log" + "net/http" "net/netip" "os" "path/filepath" "reflect" + "strconv" "strings" + "sync" "time" "github.com/fsnotify/fsnotify" + "tailscale.com/client/local" "tailscale.com/ipn" "tailscale.com/kube/egressservices" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" + "tailscale.com/util/httpm" "tailscale.com/util/linuxfw" "tailscale.com/util/mak" ) const tailscaleTunInterface = "tailscale0" +// Modified using a build flag to speed up tests. +var testSleepDuration string + // This file contains functionality to run containerboot as a proxy that can // route cluster traffic to one or more tailnet targets, based on portmapping // rules read from a configfile. Currently (9/2024) this is only used for the @@ -37,16 +46,18 @@ const tailscaleTunInterface = "tailscale0" // egressProxy knows how to configure firewall rules to route cluster traffic to // one or more tailnet services. type egressProxy struct { - cfgPath string // path to egress service config file + cfgPath string // path to a directory with egress services config files nfr linuxfw.NetfilterRunner // never nil kc kubeclient.Client // never nil stateSecret string // name of the kube state Secret + tsClient *local.Client // never nil + netmapChan chan ipn.Notify // chan to receive netmap updates on - podIP string // never empty string + podIPv4 string // never empty string, currently only IPv4 is supported // tailnetFQDNs is the egress service FQDN to tailnet IP mappings that // were last used to configure firewall rules for this proxy. @@ -55,15 +66,29 @@ type egressProxy struct { // memory at all. targetFQDNs map[string][]netip.Prefix - // used to configure firewall rules. - tailnetAddrs []netip.Prefix + tailnetAddrs []netip.Prefix // tailnet IPs of this tailnet device + + // shortSleep is the backoff sleep between healthcheck endpoint calls - can be overridden in tests. + shortSleep time.Duration + // longSleep is the time to sleep after the routing rules are updated to increase the chance that kube + // proxies on all nodes have updated their routing configuration. It can be configured to 0 in + // tests. + longSleep time.Duration + // client is a client that can send HTTP requests. + client httpClient +} + +// httpClient is a client that can send HTTP requests and can be mocked in tests. +type httpClient interface { + Do(*http.Request) (*http.Response, error) } // run configures egress proxy firewall rules and ensures that the firewall rules are reconfigured when: // - the mounted egress config has changed // - the proxy's tailnet IP addresses have changed // - tailnet IPs have changed for any backend targets specified by tailnet FQDN -func (ep *egressProxy) run(ctx context.Context, n ipn.Notify) error { +func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRunOpts) error { + ep.configure(opts) var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event // TODO (irbekrm): take a look if this can be pulled into a single func @@ -75,7 +100,7 @@ func (ep *egressProxy) run(ctx context.Context, n ipn.Notify) error { tickChan = ticker.C } else { defer w.Close() - if err := w.Add(filepath.Dir(ep.cfgPath)); err != nil { + if err := w.Add(ep.cfgPath); err != nil { return fmt.Errorf("failed to add fsnotify watch: %w", err) } eventChan = w.Events @@ -85,28 +110,57 @@ func (ep *egressProxy) run(ctx context.Context, n ipn.Notify) error { return err } for { - var err error select { case <-ctx.Done(): return nil case <-tickChan: - err = ep.sync(ctx, n) + log.Printf("periodic sync, ensuring firewall config is up to date...") case <-eventChan: log.Printf("config file change detected, ensuring firewall config is up to date...") - err = ep.sync(ctx, n) case n = <-ep.netmapChan: shouldResync := ep.shouldResync(n) - if shouldResync { - log.Printf("netmap change detected, ensuring firewall config is up to date...") - err = ep.sync(ctx, n) + if !shouldResync { + continue } + log.Printf("netmap change detected, ensuring firewall config is up to date...") } - if err != nil { + if err := ep.sync(ctx, n); err != nil { return fmt.Errorf("error syncing egress service config: %w", err) } } } +type egressProxyRunOpts struct { + cfgPath string + nfr linuxfw.NetfilterRunner + kc kubeclient.Client + tsClient *local.Client + stateSecret string + netmapChan chan ipn.Notify + podIPv4 string + tailnetAddrs []netip.Prefix +} + +// applyOpts configures egress proxy using the provided options. +func (ep *egressProxy) configure(opts egressProxyRunOpts) { + ep.cfgPath = opts.cfgPath + ep.nfr = opts.nfr + ep.kc = opts.kc + ep.tsClient = opts.tsClient + ep.stateSecret = opts.stateSecret + ep.netmapChan = opts.netmapChan + ep.podIPv4 = opts.podIPv4 + ep.tailnetAddrs = opts.tailnetAddrs + ep.client = &http.Client{} // default HTTP client + sleepDuration := time.Second + if d, err := time.ParseDuration(testSleepDuration); err == nil && d > 0 { + log.Printf("using test sleep duration %v", d) + sleepDuration = d + } + ep.shortSleep = sleepDuration + ep.longSleep = sleepDuration * 10 +} + // sync triggers an egress proxy config resync. The resync calculates the diff between config and status to determine if // any firewall rules need to be updated. Currently using status in state Secret as a reference for what is the current // firewall configuration is good enough because - the status is keyed by the Pod IP - we crash the Pod on errors such @@ -235,7 +289,7 @@ func updatesForCfg(svcName string, cfg egressservices.Config, status *egressserv log.Printf("tailnet target for egress service %s does not have any backend addresses, deleting all rules", svcName) for _, ip := range currentConfig.TailnetTargetIPs { for ports := range currentConfig.Ports { - rulesToDelete = append(rulesToAdd, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) + rulesToDelete = append(rulesToDelete, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) } } return rulesToAdd, rulesToDelete, nil @@ -327,7 +381,8 @@ func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, s // getConfigs gets the mounted egress service configuration. func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) { - j, err := os.ReadFile(ep.cfgPath) + svcsCfg := filepath.Join(ep.cfgPath, egressservices.KeyEgressServices) + j, err := os.ReadFile(svcsCfg) if os.IsNotExist(err) { return nil, nil } @@ -361,7 +416,7 @@ func (ep *egressProxy) getStatus(ctx context.Context) (*egressservices.Status, e if err := json.Unmarshal([]byte(raw), status); err != nil { return nil, fmt.Errorf("error unmarshalling previous config: %w", err) } - if reflect.DeepEqual(status.PodIP, ep.podIP) { + if reflect.DeepEqual(status.PodIPv4, ep.podIPv4) { return status, nil } return nil, nil @@ -374,7 +429,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta if status == nil { status = &egressservices.Status{} } - status.PodIP = ep.podIP + status.PodIPv4 = ep.podIPv4 secret, err := ep.kc.GetSecret(ctx, ep.stateSecret) if err != nil { return fmt.Errorf("error retrieving state Secret: %w", err) @@ -389,7 +444,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta Path: fmt.Sprintf("/data/%s", egressservices.KeyEgressServices), Value: bs, } - if err := ep.kc.JSONPatchSecret(ctx, ep.stateSecret, []kubeclient.JSONPatch{patch}); err != nil { + if err := ep.kc.JSONPatchResource(ctx, ep.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil { return fmt.Errorf("error patching state Secret: %w", err) } ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() @@ -470,8 +525,9 @@ func (ep *egressProxy) shouldResync(n ipn.Notify) bool { if equalFQDNs(nn.Name(), fqdn) { if !reflect.DeepEqual(ips, nn.Addresses().AsSlice()) { log.Printf("backend addresses for egress target %q have changed old IPs %v, new IPs %v trigger egress config resync", nn.Name(), ips, nn.Addresses().AsSlice()) + return true } - return true + break } } } @@ -514,7 +570,7 @@ func ensureRulesAdded(rulesPerSvc map[string][]rule, nfr linuxfw.NetfilterRunner } // ensureRulesDeleted ensures that the given rules are deleted from the firewall -// configuration. For any rules that do not exist, calling this funcion is a +// configuration. For any rules that do not exist, calling this function is a // no-op. func ensureRulesDeleted(rulesPerSvc map[string][]rule, nfr linuxfw.NetfilterRunner) error { for svc, rules := range rulesPerSvc { @@ -565,7 +621,145 @@ func servicesStatusIsEqual(st, st1 *egressservices.Status) bool { if st == nil || st1 == nil { return false } - st.PodIP = "" - st1.PodIP = "" + st.PodIPv4 = "" + st1.PodIPv4 = "" return reflect.DeepEqual(*st, *st1) } + +// registerHandlers adds a new handler to the provided ServeMux that can be called as a Kubernetes prestop hook to +// delay shutdown till it's safe to do so. +func (ep *egressProxy) registerHandlers(mux *http.ServeMux) { + mux.Handle(fmt.Sprintf("GET %s", kubetypes.EgessServicesPreshutdownEP), ep) +} + +// ServeHTTP serves /internal-egress-services-preshutdown endpoint, when it receives a request, it periodically polls +// the configured health check endpoint for each egress service till it the health check endpoint no longer hits this +// proxy Pod. It uses the Pod-IPv4 header to verify if health check response is received from this Pod. +func (ep *egressProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cfgs, err := ep.getConfigs() + if err != nil { + http.Error(w, fmt.Sprintf("error retrieving egress services configs: %v", err), http.StatusInternalServerError) + return + } + if cfgs == nil { + if _, err := w.Write([]byte("safe to terminate")); err != nil { + http.Error(w, fmt.Sprintf("error writing termination status: %v", err), http.StatusInternalServerError) + } + return + } + hp, err := ep.getHEPPings() + if err != nil { + http.Error(w, fmt.Sprintf("error determining the number of times health check endpoint should be pinged: %v", err), http.StatusInternalServerError) + return + } + ep.waitTillSafeToShutdown(r.Context(), cfgs, hp) +} + +// waitTillSafeToShutdown looks up all egress targets configured to be proxied via this instance and, for each target +// whose configuration includes a healthcheck endpoint, pings the endpoint till none of the responses +// are returned by this instance or till the HTTP request times out. In practice, the endpoint will be a Kubernetes Service for whom one of the backends +// would normally be this Pod. When this Pod is being deleted, the operator should have removed it from the Service +// backends and eventually kube proxy routing rules should be updated to no longer route traffic for the Service to this +// Pod. +func (ep *egressProxy) waitTillSafeToShutdown(ctx context.Context, cfgs *egressservices.Configs, hp int) { + if cfgs == nil || len(*cfgs) == 0 { // avoid sleeping if no services are configured + return + } + log.Printf("Ensuring that cluster traffic for egress targets is no longer routed via this Pod...") + var wg sync.WaitGroup + for s, cfg := range *cfgs { + hep := cfg.HealthCheckEndpoint + if hep == "" { + log.Printf("Tailnet target %q does not have a cluster healthcheck specified, unable to verify if cluster traffic for the target is still routed via this Pod", s) + continue + } + svc := s + wg.Go(func() { + log.Printf("Ensuring that cluster traffic is no longer routed to %q via this Pod...", svc) + for { + if ctx.Err() != nil { // kubelet's HTTP request timeout + log.Printf("Cluster traffic for %s did not stop being routed to this Pod.", svc) + return + } + found, err := lookupPodRoute(ctx, hep, ep.podIPv4, hp, ep.client) + if err != nil { + log.Printf("unable to reach endpoint %q, assuming the routing rules for this Pod have been deleted: %v", hep, err) + break + } + if !found { + log.Printf("service %q is no longer routed through this Pod", svc) + break + } + log.Printf("service %q is still routed through this Pod, waiting...", svc) + time.Sleep(ep.shortSleep) + } + }) + } + wg.Wait() + // The check above really only checked that the routing rules are updated on this node. Sleep for a bit to + // ensure that the routing rules are updated on other nodes. TODO(irbekrm): this may or may not be good enough. + // If it's not good enough, we'd probably want to do something more complex, where the proxies check each other. + log.Printf("Sleeping for %s before shutdown to ensure that kube proxies on all nodes have updated routing configuration", ep.longSleep) + time.Sleep(ep.longSleep) +} + +// lookupPodRoute calls the healthcheck endpoint repeat times and returns true if the endpoint returns with the podIP +// header at least once. +func lookupPodRoute(ctx context.Context, hep, podIP string, repeat int, client httpClient) (bool, error) { + for range repeat { + f, err := lookup(ctx, hep, podIP, client) + if err != nil { + return false, err + } + if f { + return true, nil + } + } + return false, nil +} + +// lookup calls the healthcheck endpoint and returns true if the response contains the podIP header. +func lookup(ctx context.Context, hep, podIP string, client httpClient) (bool, error) { + req, err := http.NewRequestWithContext(ctx, httpm.GET, hep, nil) + if err != nil { + return false, fmt.Errorf("error creating new HTTP request: %v", err) + } + + // Close the TCP connection to ensure that the next request is routed to a different backend. + req.Close = true + + resp, err := client.Do(req) + if err != nil { + log.Printf("Endpoint %q can not be reached: %v, likely because there are no (more) healthy backends", hep, err) + return true, nil + } + defer resp.Body.Close() + gotIP := resp.Header.Get(kubetypes.PodIPv4Header) + return strings.EqualFold(podIP, gotIP), nil +} + +// getHEPPings gets the number of pings that should be sent to a health check endpoint to ensure that each configured +// backend is hit. This assumes that a health check endpoint is a Kubernetes Service and traffic to backend Pods is +// round robin load balanced. +func (ep *egressProxy) getHEPPings() (int, error) { + hepPingsPath := filepath.Join(ep.cfgPath, egressservices.KeyHEPPings) + j, err := os.ReadFile(hepPingsPath) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return -1, err + } + if len(j) == 0 || string(j) == "" { + return 0, nil + } + hp, err := strconv.Atoi(string(j)) + if err != nil { + return -1, fmt.Errorf("error parsing hep pings as int: %v", err) + } + if hp < 0 { + log.Printf("[unexpected] hep pings is negative: %d", hp) + return 0, nil + } + return hp, nil +} diff --git a/cmd/containerboot/services_test.go b/cmd/containerboot/egressservices_test.go similarity index 61% rename from cmd/containerboot/services_test.go rename to cmd/containerboot/egressservices_test.go index 46f6db1cf..724626b07 100644 --- a/cmd/containerboot/services_test.go +++ b/cmd/containerboot/egressservices_test.go @@ -6,11 +6,18 @@ package main import ( + "context" + "fmt" + "io" + "net/http" "net/netip" "reflect" + "strings" + "sync" "testing" "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubetypes" ) func Test_updatesForSvc(t *testing.T) { @@ -173,3 +180,145 @@ func Test_updatesForSvc(t *testing.T) { }) } } + +// A failure of this test will most likely look like a timeout. +func TestWaitTillSafeToShutdown(t *testing.T) { + podIP := "10.0.0.1" + anotherIP := "10.0.0.2" + + tests := []struct { + name string + // services is a map of service name to the number of calls to make to the healthcheck endpoint before + // returning a response that does NOT contain this Pod's IP in headers. + services map[string]int + replicas int + healthCheckSet bool + }{ + { + name: "no_configs", + }, + { + name: "one_service_immediately_safe_to_shutdown", + services: map[string]int{ + "svc1": 0, + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_immediately_safe_to_shutdown", + services: map[string]int{ + "svc1": 0, + "svc2": 0, + "svc3": 0, + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_no_healthcheck_endpoints", + services: map[string]int{ + "svc1": 0, + "svc2": 0, + "svc3": 0, + }, + replicas: 2, + }, + { + name: "one_service_eventually_safe_to_shutdown", + services: map[string]int{ + "svc1": 3, // After 3 calls to health check endpoint, no longer returns this Pod's IP + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_eventually_safe_to_shutdown", + services: map[string]int{ + "svc1": 1, // After 1 call to health check endpoint, no longer returns this Pod's IP + "svc2": 3, // After 3 calls to health check endpoint, no longer returns this Pod's IP + "svc3": 5, // After 5 calls to the health check endpoint, no longer returns this Pod's IP + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_eventually_safe_to_shutdown_with_higher_replica_count", + services: map[string]int{ + "svc1": 7, + "svc2": 10, + }, + replicas: 5, + healthCheckSet: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgs := &egressservices.Configs{} + switches := make(map[string]int) + + for svc, callsToSwitch := range tt.services { + endpoint := fmt.Sprintf("http://%s.local", svc) + if tt.healthCheckSet { + (*cfgs)[svc] = egressservices.Config{ + HealthCheckEndpoint: endpoint, + } + } + switches[endpoint] = callsToSwitch + } + + ep := &egressProxy{ + podIPv4: podIP, + client: &mockHTTPClient{ + podIP: podIP, + anotherIP: anotherIP, + switches: switches, + }, + } + + ep.waitTillSafeToShutdown(context.Background(), cfgs, tt.replicas) + }) + } +} + +// mockHTTPClient is a client that receives an HTTP call for an egress service endpoint and returns a response with an +// IP address in a 'Pod-IPv4' header. It can be configured to return one IP address for N calls, then switch to another +// IP address to simulate a scenario where an IP is eventually no longer a backend for an endpoint. +// TODO(irbekrm): to test this more thoroughly, we should have the client take into account the number of replicas and +// return as if traffic was round robin load balanced across different Pods. +type mockHTTPClient struct { + // podIP - initial IP address to return, that matches the current proxy's IP address. + podIP string + anotherIP string + // after how many calls to an endpoint, the client should start returning 'anotherIP' instead of 'podIP. + switches map[string]int + mu sync.Mutex // protects the following + // calls tracks the number of calls received. + calls map[string]int +} + +func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + m.mu.Lock() + if m.calls == nil { + m.calls = make(map[string]int) + } + + endpoint := req.URL.String() + m.calls[endpoint]++ + calls := m.calls[endpoint] + m.mu.Unlock() + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + } + + if calls <= m.switches[endpoint] { + resp.Header.Set(kubetypes.PodIPv4Header, m.podIP) // Pod is still routable + } else { + resp.Header.Set(kubetypes.PodIPv4Header, m.anotherIP) // Pod is no longer routable + } + return resp, nil +} diff --git a/cmd/containerboot/healthz.go b/cmd/containerboot/healthz.go deleted file mode 100644 index fb7fccd96..000000000 --- a/cmd/containerboot/healthz.go +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package main - -import ( - "log" - "net" - "net/http" - "sync" -) - -// healthz is a simple health check server, if enabled it returns 200 OK if -// this tailscale node currently has at least one tailnet IP address else -// returns 503. -type healthz struct { - sync.Mutex - hasAddrs bool -} - -func (h *healthz) ServeHTTP(w http.ResponseWriter, r *http.Request) { - h.Lock() - defer h.Unlock() - if h.hasAddrs { - w.Write([]byte("ok")) - } else { - http.Error(w, "node currently has no tailscale IPs", http.StatusInternalServerError) - } -} - -// runHealthz runs a simple HTTP health endpoint on /healthz, listening on the -// provided address. A containerized tailscale instance is considered healthy if -// it has at least one tailnet IP address. -func runHealthz(addr string, h *healthz) { - lis, err := net.Listen("tcp", addr) - if err != nil { - log.Fatalf("error listening on the provided health endpoint address %q: %v", addr, err) - } - mux := http.NewServeMux() - mux.Handle("/healthz", h) - log.Printf("Running healthcheck endpoint at %s/healthz", addr) - hs := &http.Server{Handler: mux} - - go func() { - if err := hs.Serve(lis); err != nil { - log.Fatalf("failed running health endpoint: %v", err) - } - }() -} diff --git a/cmd/containerboot/ingressservices.go b/cmd/containerboot/ingressservices.go new file mode 100644 index 000000000..1a2da9567 --- /dev/null +++ b/cmd/containerboot/ingressservices.go @@ -0,0 +1,331 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/netip" + "os" + "path/filepath" + "reflect" + "time" + + "github.com/fsnotify/fsnotify" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubeclient" + "tailscale.com/util/linuxfw" + "tailscale.com/util/mak" +) + +// ingressProxy corresponds to a Kubernetes Operator's network layer ingress +// proxy. It configures firewall rules (iptables or nftables) to proxy tailnet +// traffic to Kubernetes Services. Currently this is only used for network +// layer proxies in HA mode. +type ingressProxy struct { + cfgPath string // path to ingress configfile. + + // nfr is the netfilter runner used to configure firewall rules. + // This is going to be either iptables or nftables based runner. + // Never nil. + nfr linuxfw.NetfilterRunner + + kc kubeclient.Client // never nil + stateSecret string // Secret that holds Tailscale state + + // Pod's IP addresses are used as an identifier of this particular Pod. + podIPv4 string // empty if Pod does not have IPv4 address + podIPv6 string // empty if Pod does not have IPv6 address +} + +// run starts the ingress proxy and ensures that firewall rules are set on start +// and refreshed as ingress config changes. +func (p *ingressProxy) run(ctx context.Context, opts ingressProxyOpts) error { + log.Printf("starting ingress proxy...") + p.configure(opts) + var tickChan <-chan time.Time + var eventChan <-chan fsnotify.Event + if w, err := fsnotify.NewWatcher(); err != nil { + log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + tickChan = ticker.C + } else { + defer w.Close() + dir := filepath.Dir(p.cfgPath) + if err := w.Add(dir); err != nil { + return fmt.Errorf("failed to add fsnotify watch for %v: %w", dir, err) + } + eventChan = w.Events + } + + if err := p.sync(ctx); err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return nil + case <-tickChan: + log.Printf("periodic sync, ensuring firewall config is up to date...") + case <-eventChan: + log.Printf("config file change detected, ensuring firewall config is up to date...") + } + if err := p.sync(ctx); err != nil { + return fmt.Errorf("error syncing ingress service config: %w", err) + } + } +} + +// sync reconciles proxy's firewall rules (iptables or nftables) on ingress config changes: +// - ensures that new firewall rules are added +// - ensures that old firewall rules are deleted +// - updates ingress proxy's status in the state Secret +func (p *ingressProxy) sync(ctx context.Context) error { + // 1. Get the desired firewall configuration + cfgs, err := p.getConfigs() + if err != nil { + return fmt.Errorf("ingress proxy: error retrieving configs: %w", err) + } + + // 2. Get the recorded firewall status + status, err := p.getStatus(ctx) + if err != nil { + return fmt.Errorf("ingress proxy: error retrieving current status: %w", err) + } + + // 3. Ensure that firewall configuration is up to date + if err := p.syncIngressConfigs(cfgs, status); err != nil { + return fmt.Errorf("ingress proxy: error syncing configs: %w", err) + } + var existingConfigs *ingressservices.Configs + if status != nil { + existingConfigs = &status.Configs + } + + // 4. Update the recorded firewall status + if !(ingressServicesStatusIsEqual(cfgs, existingConfigs) && p.isCurrentStatus(status)) { + if err := p.recordStatus(ctx, cfgs); err != nil { + return fmt.Errorf("ingress proxy: error setting status: %w", err) + } + } + return nil +} + +// getConfigs returns the desired ingress service configuration from the mounted +// configfile. +func (p *ingressProxy) getConfigs() (*ingressservices.Configs, error) { + j, err := os.ReadFile(p.cfgPath) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, err + } + if len(j) == 0 || string(j) == "" { + return nil, nil + } + cfg := &ingressservices.Configs{} + if err := json.Unmarshal(j, &cfg); err != nil { + return nil, err + } + return cfg, nil +} + +// getStatus gets the recorded status of the configured firewall. The status is +// stored in the proxy's state Secret. Note that the recorded status might not +// be the current status of the firewall if it belongs to a previous Pod- we +// take that into account further down the line when determining if the desired +// rules are actually present. +func (p *ingressProxy) getStatus(ctx context.Context) (*ingressservices.Status, error) { + secret, err := p.kc.GetSecret(ctx, p.stateSecret) + if err != nil { + return nil, fmt.Errorf("error retrieving state Secret: %w", err) + } + status := &ingressservices.Status{} + raw, ok := secret.Data[ingressservices.IngressConfigKey] + if !ok { + return nil, nil + } + if err := json.Unmarshal([]byte(raw), status); err != nil { + return nil, fmt.Errorf("error unmarshalling previous config: %w", err) + } + return status, nil +} + +// syncIngressConfigs takes the desired firewall configuration and the recorded +// status and ensures that any missing rules are added and no longer needed +// rules are deleted. +func (p *ingressProxy) syncIngressConfigs(cfgs *ingressservices.Configs, status *ingressservices.Status) error { + rulesToAdd := p.getRulesToAdd(cfgs, status) + rulesToDelete := p.getRulesToDelete(cfgs, status) + + if err := ensureIngressRulesDeleted(rulesToDelete, p.nfr); err != nil { + return fmt.Errorf("error deleting ingress rules: %w", err) + } + if err := ensureIngressRulesAdded(rulesToAdd, p.nfr); err != nil { + return fmt.Errorf("error adding ingress rules: %w", err) + } + return nil +} + +// recordStatus writes the configured firewall status to the proxy's state +// Secret. This allows the Kubernetes Operator to determine whether this proxy +// Pod has setup firewall rules to route traffic for an ingress service. +func (p *ingressProxy) recordStatus(ctx context.Context, newCfg *ingressservices.Configs) error { + status := &ingressservices.Status{} + if newCfg != nil { + status.Configs = *newCfg + } + // Pod IPs are used to determine if recorded status applies to THIS proxy Pod. + status.PodIPv4 = p.podIPv4 + status.PodIPv6 = p.podIPv6 + secret, err := p.kc.GetSecret(ctx, p.stateSecret) + if err != nil { + return fmt.Errorf("error retrieving state Secret: %w", err) + } + bs, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("error marshalling status: %w", err) + } + secret.Data[ingressservices.IngressConfigKey] = bs + patch := kubeclient.JSONPatch{ + Op: "replace", + Path: fmt.Sprintf("/data/%s", ingressservices.IngressConfigKey), + Value: bs, + } + if err := p.kc.JSONPatchResource(ctx, p.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil { + return fmt.Errorf("error patching state Secret: %w", err) + } + return nil +} + +// getRulesToAdd takes the desired firewall configuration and the recorded +// firewall status and returns a map of missing Tailscale Services and rules. +func (p *ingressProxy) getRulesToAdd(cfgs *ingressservices.Configs, status *ingressservices.Status) map[string]ingressservices.Config { + if cfgs == nil { + return nil + } + var rulesToAdd map[string]ingressservices.Config + for tsSvc, wantsCfg := range *cfgs { + if status == nil || !p.isCurrentStatus(status) { + mak.Set(&rulesToAdd, tsSvc, wantsCfg) + continue + } + gotCfg := status.Configs.GetConfig(tsSvc) + if gotCfg == nil || !reflect.DeepEqual(wantsCfg, *gotCfg) { + mak.Set(&rulesToAdd, tsSvc, wantsCfg) + } + } + return rulesToAdd +} + +// getRulesToDelete takes the desired firewall configuration and the recorded +// status and returns a map of Tailscale Services and rules that need to be deleted. +func (p *ingressProxy) getRulesToDelete(cfgs *ingressservices.Configs, status *ingressservices.Status) map[string]ingressservices.Config { + if status == nil || !p.isCurrentStatus(status) { + return nil + } + var rulesToDelete map[string]ingressservices.Config + for tsSvc, gotCfg := range status.Configs { + if cfgs == nil { + mak.Set(&rulesToDelete, tsSvc, gotCfg) + continue + } + wantsCfg := cfgs.GetConfig(tsSvc) + if wantsCfg != nil && reflect.DeepEqual(*wantsCfg, gotCfg) { + continue + } + mak.Set(&rulesToDelete, tsSvc, gotCfg) + } + return rulesToDelete +} + +// ensureIngressRulesAdded takes a map of Tailscale Services and rules and ensures that the firewall rules are added. +func ensureIngressRulesAdded(cfgs map[string]ingressservices.Config, nfr linuxfw.NetfilterRunner) error { + for serviceName, cfg := range cfgs { + if cfg.IPv4Mapping != nil { + if err := addDNATRuleForSvc(nfr, serviceName, cfg.IPv4Mapping.TailscaleServiceIP, cfg.IPv4Mapping.ClusterIP); err != nil { + return fmt.Errorf("error adding ingress rule for %s: %w", serviceName, err) + } + } + if cfg.IPv6Mapping != nil { + if err := addDNATRuleForSvc(nfr, serviceName, cfg.IPv6Mapping.TailscaleServiceIP, cfg.IPv6Mapping.ClusterIP); err != nil { + return fmt.Errorf("error adding ingress rule for %s: %w", serviceName, err) + } + } + } + return nil +} + +func addDNATRuleForSvc(nfr linuxfw.NetfilterRunner, serviceName string, tsIP, clusterIP netip.Addr) error { + log.Printf("adding DNAT rule for Tailscale Service %s with IP %s to Kubernetes Service IP %s", serviceName, tsIP, clusterIP) + return nfr.EnsureDNATRuleForSvc(serviceName, tsIP, clusterIP) +} + +// ensureIngressRulesDeleted takes a map of Tailscale Services and rules and ensures that the firewall rules are deleted. +func ensureIngressRulesDeleted(cfgs map[string]ingressservices.Config, nfr linuxfw.NetfilterRunner) error { + for serviceName, cfg := range cfgs { + if cfg.IPv4Mapping != nil { + if err := deleteDNATRuleForSvc(nfr, serviceName, cfg.IPv4Mapping.TailscaleServiceIP, cfg.IPv4Mapping.ClusterIP); err != nil { + return fmt.Errorf("error deleting ingress rule for %s: %w", serviceName, err) + } + } + if cfg.IPv6Mapping != nil { + if err := deleteDNATRuleForSvc(nfr, serviceName, cfg.IPv6Mapping.TailscaleServiceIP, cfg.IPv6Mapping.ClusterIP); err != nil { + return fmt.Errorf("error deleting ingress rule for %s: %w", serviceName, err) + } + } + } + return nil +} + +func deleteDNATRuleForSvc(nfr linuxfw.NetfilterRunner, serviceName string, tsIP, clusterIP netip.Addr) error { + log.Printf("deleting DNAT rule for Tailscale Service %s with IP %s to Kubernetes Service IP %s", serviceName, tsIP, clusterIP) + return nfr.DeleteDNATRuleForSvc(serviceName, tsIP, clusterIP) +} + +// isCurrentStatus returns true if the status of an ingress proxy as read from +// the proxy's state Secret is the status of the current proxy Pod. We use +// Pod's IP addresses to determine that the status is for this Pod. +func (p *ingressProxy) isCurrentStatus(status *ingressservices.Status) bool { + if status == nil { + return true + } + return status.PodIPv4 == p.podIPv4 && status.PodIPv6 == p.podIPv6 +} + +type ingressProxyOpts struct { + cfgPath string + nfr linuxfw.NetfilterRunner // never nil + kc kubeclient.Client // never nil + stateSecret string + podIPv4 string + podIPv6 string +} + +// configure sets the ingress proxy's configuration. It is called once on start +// so we don't care about concurrent access to fields. +func (p *ingressProxy) configure(opts ingressProxyOpts) { + p.cfgPath = opts.cfgPath + p.nfr = opts.nfr + p.kc = opts.kc + p.stateSecret = opts.stateSecret + p.podIPv4 = opts.podIPv4 + p.podIPv6 = opts.podIPv6 +} + +func ingressServicesStatusIsEqual(st, st1 *ingressservices.Configs) bool { + if st == nil && st1 == nil { + return true + } + if st == nil || st1 == nil { + return false + } + return reflect.DeepEqual(*st, *st1) +} diff --git a/cmd/containerboot/ingressservices_test.go b/cmd/containerboot/ingressservices_test.go new file mode 100644 index 000000000..228bbb159 --- /dev/null +++ b/cmd/containerboot/ingressservices_test.go @@ -0,0 +1,223 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "net/netip" + "testing" + + "tailscale.com/kube/ingressservices" + "tailscale.com/util/linuxfw" +) + +func TestSyncIngressConfigs(t *testing.T) { + tests := []struct { + name string + currentConfigs *ingressservices.Configs + currentStatus *ingressservices.Status + wantServices map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + } + }{ + { + name: "add_new_rules_when_no_existing_config", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("100.64.0.1", "10.0.0.1"), + }, + }, + { + name: "add_multiple_services", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + "svc:bar": makeServiceConfig("100.64.0.2", "10.0.0.2", "", ""), + "svc:baz": makeServiceConfig("100.64.0.3", "10.0.0.3", "", ""), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("100.64.0.1", "10.0.0.1"), + "svc:bar": makeWantService("100.64.0.2", "10.0.0.2"), + "svc:baz": makeWantService("100.64.0.3", "10.0.0.3"), + }, + }, + { + name: "add_both_ipv4_and_ipv6_rules", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "2001:db8::1", "2001:db8::2"), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("2001:db8::1", "2001:db8::2"), + }, + }, + { + name: "add_ipv6_only_rules", + currentConfigs: &ingressservices.Configs{ + "svc:ipv6": makeServiceConfig("", "", "2001:db8::10", "2001:db8::20"), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:ipv6": makeWantService("2001:db8::10", "2001:db8::20"), + }, + }, + { + name: "delete_all_rules_when_config_removed", + currentConfigs: nil, + currentStatus: &ingressservices.Status{ + Configs: ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + "svc:bar": makeServiceConfig("100.64.0.2", "10.0.0.2", "", ""), + }, + PodIPv4: "10.0.0.2", // Current pod IPv4 + PodIPv6: "2001:db8::2", // Current pod IPv6 + }, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{}, + }, + { + name: "add_remove_modify", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.2", "", ""), // Changed cluster IP + "svc:new": makeServiceConfig("100.64.0.4", "10.0.0.4", "", ""), + }, + currentStatus: &ingressservices.Status{ + Configs: ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + "svc:bar": makeServiceConfig("100.64.0.2", "10.0.0.2", "", ""), + "svc:baz": makeServiceConfig("100.64.0.3", "10.0.0.3", "", ""), + }, + PodIPv4: "10.0.0.2", // Current pod IPv4 + PodIPv6: "2001:db8::2", // Current pod IPv6 + }, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("100.64.0.1", "10.0.0.2"), + "svc:new": makeWantService("100.64.0.4", "10.0.0.4"), + }, + }, + { + name: "update_with_outdated_status", + currentConfigs: &ingressservices.Configs{ + "svc:web": makeServiceConfig("100.64.0.10", "10.0.0.10", "", ""), + "svc:web-ipv6": { + IPv6Mapping: &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr("2001:db8::10"), + ClusterIP: netip.MustParseAddr("2001:db8::20"), + }, + }, + "svc:api": makeServiceConfig("100.64.0.20", "10.0.0.20", "", ""), + }, + currentStatus: &ingressservices.Status{ + Configs: ingressservices.Configs{ + "svc:web": makeServiceConfig("100.64.0.10", "10.0.0.10", "", ""), + "svc:web-ipv6": { + IPv6Mapping: &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr("2001:db8::10"), + ClusterIP: netip.MustParseAddr("2001:db8::20"), + }, + }, + "svc:old": makeServiceConfig("100.64.0.30", "10.0.0.30", "", ""), + }, + PodIPv4: "10.0.0.1", // Outdated pod IP + PodIPv6: "2001:db8::1", // Outdated pod IP + }, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:web": makeWantService("100.64.0.10", "10.0.0.10"), + "svc:web-ipv6": makeWantService("2001:db8::10", "2001:db8::20"), + "svc:api": makeWantService("100.64.0.20", "10.0.0.20"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var nfr linuxfw.NetfilterRunner = linuxfw.NewFakeNetfilterRunner() + + ep := &ingressProxy{ + nfr: nfr, + podIPv4: "10.0.0.2", // Current pod IPv4 + podIPv6: "2001:db8::2", // Current pod IPv6 + } + + err := ep.syncIngressConfigs(tt.currentConfigs, tt.currentStatus) + if err != nil { + t.Fatalf("syncIngressConfigs failed: %v", err) + } + + fake := nfr.(*linuxfw.FakeNetfilterRunner) + gotServices := fake.GetServiceState() + if len(gotServices) != len(tt.wantServices) { + t.Errorf("got %d services, want %d", len(gotServices), len(tt.wantServices)) + } + for svc, want := range tt.wantServices { + got, ok := gotServices[svc] + if !ok { + t.Errorf("service %s not found", svc) + continue + } + if got.TailscaleServiceIP != want.TailscaleServiceIP { + t.Errorf("service %s: got TailscaleServiceIP %v, want %v", svc, got.TailscaleServiceIP, want.TailscaleServiceIP) + } + if got.ClusterIP != want.ClusterIP { + t.Errorf("service %s: got ClusterIP %v, want %v", svc, got.ClusterIP, want.ClusterIP) + } + } + }) + } +} + +func makeServiceConfig(tsIP, clusterIP string, tsIP6, clusterIP6 string) ingressservices.Config { + cfg := ingressservices.Config{} + if tsIP != "" && clusterIP != "" { + cfg.IPv4Mapping = &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr(tsIP), + ClusterIP: netip.MustParseAddr(clusterIP), + } + } + if tsIP6 != "" && clusterIP6 != "" { + cfg.IPv6Mapping = &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr(tsIP6), + ClusterIP: netip.MustParseAddr(clusterIP6), + } + } + return cfg +} + +func makeWantService(tsIP, clusterIP string) struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr +} { + return struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + TailscaleServiceIP: netip.MustParseAddr(tsIP), + ClusterIP: netip.MustParseAddr(clusterIP), + } +} diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go index 908cc01ef..e566fa483 100644 --- a/cmd/containerboot/kube.go +++ b/cmd/containerboot/kube.go @@ -8,31 +8,67 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "log" "net/http" "net/netip" "os" + "strings" + "time" + "tailscale.com/ipn" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/ingressservices" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/backoff" + "tailscale.com/util/set" ) -// storeDeviceID writes deviceID to 'device_id' data field of the named -// Kubernetes Secret. -func storeDeviceID(ctx context.Context, secretName string, deviceID tailcfg.StableNodeID) error { +// kubeClient is a wrapper around Tailscale's internal kube client that knows how to talk to the kube API server. We use +// this rather than any of the upstream Kubernetes client libaries to avoid extra imports. +type kubeClient struct { + kubeclient.Client + stateSecret string + canPatch bool // whether the client has permissions to patch Kubernetes Secrets +} + +func newKubeClient(root string, stateSecret string) (*kubeClient, error) { + if root != "/" { + // If we are running in a test, we need to set the root path to the fake + // service account directory. + kubeclient.SetRootPathForTesting(root) + } + var err error + kc, err := kubeclient.New("tailscale-container") + if err != nil { + return nil, fmt.Errorf("Error creating kube client: %w", err) + } + if (root != "/") || os.Getenv("TS_KUBERNETES_READ_API_SERVER_ADDRESS_FROM_ENV") == "true" { + // Derive the API server address from the environment variables + // Used to set http server in tests, or optionally enabled by flag + kc.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) + } + return &kubeClient{Client: kc, stateSecret: stateSecret}, nil +} + +// storeDeviceID writes deviceID to 'device_id' data field of the client's state Secret. +func (kc *kubeClient) storeDeviceID(ctx context.Context, deviceID tailcfg.StableNodeID) error { s := &kubeapi.Secret{ Data: map[string][]byte{ - "device_id": []byte(deviceID), + kubetypes.KeyDeviceID: []byte(deviceID), }, } - return kc.StrategicMergePatchSecret(ctx, secretName, s, "tailscale-container") + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") } -// storeDeviceEndpoints writes device's tailnet IPs and MagicDNS name to fields -// 'device_ips', 'device_fqdn' of the named Kubernetes Secret. -func storeDeviceEndpoints(ctx context.Context, secretName string, fqdn string, addresses []netip.Prefix) error { +// storeDeviceEndpoints writes device's tailnet IPs and MagicDNS name to fields 'device_ips', 'device_fqdn' of client's +// state Secret. +func (kc *kubeClient) storeDeviceEndpoints(ctx context.Context, fqdn string, addresses []netip.Prefix) error { var ips []string for _, addr := range addresses { ips = append(ips, addr.Addr().String()) @@ -44,16 +80,28 @@ func storeDeviceEndpoints(ctx context.Context, secretName string, fqdn string, a s := &kubeapi.Secret{ Data: map[string][]byte{ - "device_fqdn": []byte(fqdn), - "device_ips": deviceIPs, + kubetypes.KeyDeviceFQDN: []byte(fqdn), + kubetypes.KeyDeviceIPs: deviceIPs, + }, + } + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") +} + +// storeHTTPSEndpoint writes an HTTPS endpoint exposed by this device via 'tailscale serve' to the client's state +// Secret. In practice this will be the same value that gets written to 'device_fqdn', but this should only be called +// when the serve config has been successfully set up. +func (kc *kubeClient) storeHTTPSEndpoint(ctx context.Context, ep string) error { + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyHTTPSEndpoint: []byte(ep), }, } - return kc.StrategicMergePatchSecret(ctx, secretName, s, "tailscale-container") + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") } // deleteAuthKey deletes the 'authkey' field of the given kube // secret. No-op if there is no authkey in the secret. -func deleteAuthKey(ctx context.Context, secretName string) error { +func (kc *kubeClient) deleteAuthKey(ctx context.Context) error { // m is a JSON Patch data structure, see https://jsonpatch.com/ or RFC 6902. m := []kubeclient.JSONPatch{ { @@ -61,7 +109,7 @@ func deleteAuthKey(ctx context.Context, secretName string) error { Path: "/data/authkey", }, } - if err := kc.JSONPatchSecret(ctx, secretName, m); err != nil { + if err := kc.JSONPatchResource(ctx, kc.stateSecret, kubeclient.TypeSecrets, m); err != nil { if s, ok := err.(*kubeapi.Status); ok && s.Code == http.StatusUnprocessableEntity { // This is kubernetes-ese for "the field you asked to // delete already doesn't exist", aka no-op. @@ -72,22 +120,100 @@ func deleteAuthKey(ctx context.Context, secretName string) error { return nil } -var kc kubeclient.Client +// resetContainerbootState resets state from previous runs of containerboot to +// ensure the operator doesn't use stale state when a Pod is first recreated. +func (kc *kubeClient) resetContainerbootState(ctx context.Context, podUID string) error { + existingSecret, err := kc.GetSecret(ctx, kc.stateSecret) + switch { + case kubeclient.IsNotFoundErr(err): + // In the case that the Secret doesn't exist, we don't have any state to reset and can return early. + return nil + case err != nil: + return fmt.Errorf("failed to read state Secret %q to reset state: %w", kc.stateSecret, err) + } + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyCapVer: fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion), + }, + } + if podUID != "" { + s.Data[kubetypes.KeyPodUID] = []byte(podUID) + } -func initKubeClient(root string) { - if root != "/" { - // If we are running in a test, we need to set the root path to the fake - // service account directory. - kubeclient.SetRootPathForTesting(root) + toClear := set.SetOf([]string{ + kubetypes.KeyDeviceID, + kubetypes.KeyDeviceFQDN, + kubetypes.KeyDeviceIPs, + kubetypes.KeyHTTPSEndpoint, + egressservices.KeyEgressServices, + ingressservices.IngressConfigKey, + }) + for key := range existingSecret.Data { + if toClear.Contains(key) { + // It's fine to leave the key in place as a debugging breadcrumb, + // it should get a new value soon. + s.Data[key] = nil + } } - var err error - kc, err = kubeclient.New() - if err != nil { - log.Fatalf("Error creating kube client: %v", err) + + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") +} + +// waitForConsistentState waits for tailscaled to finish writing state if it +// looks like it's started. It is designed to reduce the likelihood that +// tailscaled gets shut down in the window between authenticating to control +// and finishing writing state. However, it's not bullet proof because we can't +// atomically authenticate and write state. +func (kc *kubeClient) waitForConsistentState(ctx context.Context) error { + var logged bool + + bo := backoff.NewBackoff("", logger.Discard, 2*time.Second) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + secret, err := kc.GetSecret(ctx, kc.stateSecret) + if ctx.Err() != nil || kubeclient.IsNotFoundErr(err) { + return nil + } + if err != nil { + return fmt.Errorf("getting Secret %q: %v", kc.stateSecret, err) + } + + if hasConsistentState(secret.Data) { + return nil + } + + if !logged { + log.Printf("Waiting for tailscaled to finish writing state to Secret %q", kc.stateSecret) + logged = true + } + bo.BackOff(ctx, errors.New("")) // Fake error to trigger actual sleep. } - if (root != "/") || os.Getenv("TS_KUBERNETES_READ_API_SERVER_ADDRESS_FROM_ENV") == "true" { - // Derive the API server address from the environment variables - // Used to set http server in tests, or optionally enabled by flag - kc.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) +} + +// hasConsistentState returns true is there is either no state or the full set +// of expected keys are present. +func hasConsistentState(d map[string][]byte) bool { + var ( + _, hasCurrent = d[string(ipn.CurrentProfileStateKey)] + _, hasKnown = d[string(ipn.KnownProfilesStateKey)] + _, hasMachine = d[string(ipn.MachineKeyStateKey)] + hasProfile bool + ) + + for k := range d { + if strings.HasPrefix(k, "profile-") { + if hasProfile { + return false // We only expect one profile. + } + hasProfile = true + } } + + // Approximate check, we don't want to reimplement all of profileManager. + return (hasCurrent && hasKnown && hasMachine && hasProfile) || + (!hasCurrent && !hasKnown && !hasMachine && !hasProfile) } diff --git a/cmd/containerboot/kube_test.go b/cmd/containerboot/kube_test.go index 1a5730548..c33714ed1 100644 --- a/cmd/containerboot/kube_test.go +++ b/cmd/containerboot/kube_test.go @@ -8,11 +8,18 @@ package main import ( "context" "errors" + "fmt" "testing" + "time" "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/ingressservices" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" ) func TestSetupKube(t *testing.T) { @@ -21,7 +28,7 @@ func TestSetupKube(t *testing.T) { cfg *settings wantErr bool wantCfg *settings - kc kubeclient.Client + kc *kubeClient }{ { name: "TS_AUTHKEY set, state Secret exists", @@ -29,14 +36,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, nil }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -48,14 +55,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, true, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -67,14 +74,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -87,14 +94,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 403} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -111,11 +118,11 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, errors.New("broken") }, - }, + }}, wantErr: true, }, { @@ -127,14 +134,14 @@ func TestSetupKube(t *testing.T) { wantCfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, true, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, }, { // Interactive login using URL in Pod logs @@ -145,28 +152,28 @@ func TestSetupKube(t *testing.T) { wantCfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{}, nil }, - }, + }}, }, { name: "TS_AUTHKEY not set, state Secret contains auth key, we do not have RBAC to patch it", cfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{Data: map[string][]byte{"authkey": []byte("foo")}}, nil }, - }, + }}, wantCfg: &settings{ KubeSecret: "foo", }, @@ -177,14 +184,14 @@ func TestSetupKube(t *testing.T) { cfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return true, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{Data: map[string][]byte{"authkey": []byte("foo")}}, nil }, - }, + }}, wantCfg: &settings{ KubeSecret: "foo", AuthKey: "foo", @@ -194,9 +201,9 @@ func TestSetupKube(t *testing.T) { } for _, tt := range tests { - kc = tt.kc + kc := tt.kc t.Run(tt.name, func(t *testing.T) { - if err := tt.cfg.setupKube(context.Background()); (err != nil) != tt.wantErr { + if err := tt.cfg.setupKube(context.Background(), kc); (err != nil) != tt.wantErr { t.Errorf("settings.setupKube() error = %v, wantErr %v", err, tt.wantErr) } if diff := cmp.Diff(*tt.cfg, *tt.wantCfg); diff != "" { @@ -205,3 +212,109 @@ func TestSetupKube(t *testing.T) { }) } } + +func TestWaitForConsistentState(t *testing.T) { + data := map[string][]byte{ + // Missing _current-profile. + string(ipn.KnownProfilesStateKey): []byte(""), + string(ipn.MachineKeyStateKey): []byte(""), + "profile-foo": []byte(""), + } + kc := &kubeClient{ + Client: &kubeclient.FakeClient{ + GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{ + Data: data, + }, nil + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := kc.waitForConsistentState(ctx); err != context.DeadlineExceeded { + t.Fatalf("expected DeadlineExceeded, got %v", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + data[string(ipn.CurrentProfileStateKey)] = []byte("") + if err := kc.waitForConsistentState(ctx); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} + +func TestResetContainerbootState(t *testing.T) { + capver := fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion) + for name, tc := range map[string]struct { + podUID string + initial map[string][]byte + expected map[string][]byte + }{ + "empty_initial": { + podUID: "1234", + initial: map[string][]byte{}, + expected: map[string][]byte{ + kubetypes.KeyCapVer: capver, + kubetypes.KeyPodUID: []byte("1234"), + }, + }, + "empty_initial_no_pod_uid": { + initial: map[string][]byte{}, + expected: map[string][]byte{ + kubetypes.KeyCapVer: capver, + }, + }, + "only_relevant_keys_updated": { + podUID: "1234", + initial: map[string][]byte{ + kubetypes.KeyCapVer: []byte("1"), + kubetypes.KeyPodUID: []byte("5678"), + kubetypes.KeyDeviceID: []byte("device-id"), + kubetypes.KeyDeviceFQDN: []byte("device-fqdn"), + kubetypes.KeyDeviceIPs: []byte(`["192.0.2.1"]`), + kubetypes.KeyHTTPSEndpoint: []byte("https://example.com"), + egressservices.KeyEgressServices: []byte("egress-services"), + ingressservices.IngressConfigKey: []byte("ingress-config"), + "_current-profile": []byte("current-profile"), + "_machinekey": []byte("machine-key"), + "_profiles": []byte("profiles"), + "_serve_e0ce": []byte("serve-e0ce"), + "profile-e0ce": []byte("profile-e0ce"), + }, + expected: map[string][]byte{ + kubetypes.KeyCapVer: capver, + kubetypes.KeyPodUID: []byte("1234"), + // Cleared keys. + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + kubetypes.KeyHTTPSEndpoint: nil, + egressservices.KeyEgressServices: nil, + ingressservices.IngressConfigKey: nil, + // Tailscaled keys not included in patch. + }, + }, + } { + t.Run(name, func(t *testing.T) { + var actual map[string][]byte + kc := &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ + GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{ + Data: tc.initial, + }, nil + }, + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, secret *kubeapi.Secret, _ string) error { + actual = secret.Data + return nil + }, + }} + if err := kc.resetContainerbootState(context.Background(), tc.podUID); err != nil { + t.Fatalf("resetContainerbootState() error = %v", err) + } + if diff := cmp.Diff(tc.expected, actual); diff != "" { + t.Errorf("resetContainerbootState() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 86612d1a6..f056d26f3 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -52,11 +52,17 @@ // ${TS_CERT_DOMAIN}, it will be replaced with the value of the available FQDN. // It cannot be used in conjunction with TS_DEST_IP. The file is watched for changes, // and will be re-applied when it changes. -// - TS_HEALTHCHECK_ADDR_PORT: if specified, an HTTP health endpoint will be -// served at /healthz at the provided address, which should be in form [
]:. -// If not set, no health check will be run. If set to :, addr will default to 0.0.0.0 -// The health endpoint will return 200 OK if this node has at least one tailnet IP address, -// otherwise returns 503. +// - TS_HEALTHCHECK_ADDR_PORT: deprecated, use TS_ENABLE_HEALTH_CHECK instead and optionally +// set TS_LOCAL_ADDR_PORT. Will be removed in 1.82.0. +// - TS_LOCAL_ADDR_PORT: the address and port to serve local metrics and health +// check endpoints if enabled via TS_ENABLE_METRICS and/or TS_ENABLE_HEALTH_CHECK. +// Defaults to [::]:9002, serving on all available interfaces. +// - TS_ENABLE_METRICS: if true, a metrics endpoint will be served at /metrics on +// the address specified by TS_LOCAL_ADDR_PORT. See https://tailscale.com/kb/1482/client-metrics +// for more information on the metrics exposed. +// - TS_ENABLE_HEALTH_CHECK: if true, a health check endpoint will be served at /healthz on +// the address specified by TS_LOCAL_ADDR_PORT. The health endpoint will return 200 +// OK if this node has at least one tailnet IP address, otherwise returns 503. // NB: the health criteria might change in the future. // - TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR: if specified, a path to a // directory that containers tailscaled config in file. The config file needs to be @@ -99,10 +105,10 @@ import ( "log" "math" "net" + "net/http" "net/netip" "os" "os/signal" - "path" "path/filepath" "slices" "strings" @@ -115,6 +121,10 @@ import ( "tailscale.com/client/tailscale" "tailscale.com/ipn" kubeutils "tailscale.com/k8s-operator" + healthz "tailscale.com/kube/health" + "tailscale.com/kube/kubetypes" + "tailscale.com/kube/metrics" + "tailscale.com/kube/services" "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/ptr" @@ -130,82 +140,134 @@ func newNetfilterRunner(logf logger.Logf) (linuxfw.NetfilterRunner, error) { } func main() { + if err := run(); err != nil && !errors.Is(err, context.Canceled) { + log.Fatal(err) + } +} + +func run() error { log.SetPrefix("boot: ") tailscale.I_Acknowledge_This_API_Is_Unstable = true - cfg := &settings{ - AuthKey: defaultEnvs([]string{"TS_AUTHKEY", "TS_AUTH_KEY"}, ""), - Hostname: defaultEnv("TS_HOSTNAME", ""), - Routes: defaultEnvStringPointer("TS_ROUTES"), - ServeConfigPath: defaultEnv("TS_SERVE_CONFIG", ""), - ProxyTargetIP: defaultEnv("TS_DEST_IP", ""), - ProxyTargetDNSName: defaultEnv("TS_EXPERIMENTAL_DEST_DNS_NAME", ""), - TailnetTargetIP: defaultEnv("TS_TAILNET_TARGET_IP", ""), - TailnetTargetFQDN: defaultEnv("TS_TAILNET_TARGET_FQDN", ""), - DaemonExtraArgs: defaultEnv("TS_TAILSCALED_EXTRA_ARGS", ""), - ExtraArgs: defaultEnv("TS_EXTRA_ARGS", ""), - InKubernetes: os.Getenv("KUBERNETES_SERVICE_HOST") != "", - UserspaceMode: defaultBool("TS_USERSPACE", true), - StateDir: defaultEnv("TS_STATE_DIR", ""), - AcceptDNS: defaultEnvBoolPointer("TS_ACCEPT_DNS"), - KubeSecret: defaultEnv("TS_KUBE_SECRET", "tailscale"), - SOCKSProxyAddr: defaultEnv("TS_SOCKS5_SERVER", ""), - HTTPProxyAddr: defaultEnv("TS_OUTBOUND_HTTP_PROXY_LISTEN", ""), - Socket: defaultEnv("TS_SOCKET", "/tmp/tailscaled.sock"), - AuthOnce: defaultBool("TS_AUTH_ONCE", false), - Root: defaultEnv("TS_TEST_ONLY_ROOT", "/"), - TailscaledConfigFilePath: tailscaledConfigFilePath(), - AllowProxyingClusterTrafficViaIngress: defaultBool("EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS", false), - PodIP: defaultEnv("POD_IP", ""), - EnableForwardingOptimizations: defaultBool("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS", false), - HealthCheckAddrPort: defaultEnv("TS_HEALTHCHECK_ADDR_PORT", ""), - EgressSvcsCfgPath: defaultEnv("TS_EGRESS_SERVICES_CONFIG_PATH", ""), - } - - if err := cfg.validate(); err != nil { - log.Fatalf("invalid configuration: %v", err) + + cfg, err := configFromEnv() + if err != nil { + return fmt.Errorf("invalid configuration: %w", err) } if !cfg.UserspaceMode { if err := ensureTunFile(cfg.Root); err != nil { - log.Fatalf("Unable to create tuntap device file: %v", err) + return fmt.Errorf("unable to create tuntap device file: %w", err) } if cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.Routes != nil || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" { if err := ensureIPForwarding(cfg.Root, cfg.ProxyTargetIP, cfg.TailnetTargetIP, cfg.TailnetTargetFQDN, cfg.Routes); err != nil { log.Printf("Failed to enable IP forwarding: %v", err) log.Printf("To run tailscale as a proxy or router container, IP forwarding must be enabled.") if cfg.InKubernetes { - log.Fatalf("You can either set the sysctls as a privileged initContainer, or run the tailscale container with privileged=true.") + return fmt.Errorf("you can either set the sysctls as a privileged initContainer, or run the tailscale container with privileged=true.") } else { - log.Fatalf("You can fix this by running the container with privileged=true, or the equivalent in your container runtime that permits access to sysctls.") + return fmt.Errorf("you can fix this by running the container with privileged=true, or the equivalent in your container runtime that permits access to sysctls.") } } } } - // Context is used for all setup stuff until we're in steady + // Root context for the whole containerboot process, used to make sure + // shutdown signals are promptly and cleanly handled. + ctx, cancel := contextWithExitSignalWatch() + defer cancel() + + // bootCtx is used for all setup stuff until we're in steady // state, so that if something is hanging we eventually time out // and crashloop the container. - bootCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + bootCtx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() + var kc *kubeClient if cfg.InKubernetes { - initKubeClient(cfg.Root) - if err := cfg.setupKube(bootCtx); err != nil { - log.Fatalf("error setting up for running on Kubernetes: %v", err) + kc, err = newKubeClient(cfg.Root, cfg.KubeSecret) + if err != nil { + return fmt.Errorf("error initializing kube client: %w", err) + } + if err := cfg.setupKube(bootCtx, kc); err != nil { + return fmt.Errorf("error setting up for running on Kubernetes: %w", err) + } + // Clear out any state from previous runs of containerboot. Check + // hasKubeStateStore because although we know we're in kube, that + // doesn't guarantee the state store is properly configured. + if hasKubeStateStore(cfg) { + if err := kc.resetContainerbootState(bootCtx, cfg.PodUID); err != nil { + return fmt.Errorf("error clearing previous state from Secret: %w", err) + } } } client, daemonProcess, err := startTailscaled(bootCtx, cfg) if err != nil { - log.Fatalf("failed to bring up tailscale: %v", err) + return fmt.Errorf("failed to bring up tailscale: %w", err) } killTailscaled := func() { + // The default termination grace period for a Pod is 30s. We wait 25s at + // most so that we still reserve some of that budget for tailscaled + // to receive and react to a SIGTERM before the SIGKILL that k8s + // will send at the end of the grace period. + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + defer cancel() + + if err := services.EnsureServicesNotAdvertised(ctx, client, log.Printf); err != nil { + log.Printf("Error ensuring services are not advertised: %v", err) + } + + if hasKubeStateStore(cfg) { + // Check we're not shutting tailscaled down while it's still writing + // state. If we authenticate and fail to write all the state, we'll + // never recover automatically. + log.Printf("Checking for consistent state") + err := kc.waitForConsistentState(ctx) + if err != nil { + log.Printf("Error waiting for consistent state on shutdown: %v", err) + } + } + log.Printf("Sending SIGTERM to tailscaled") if err := daemonProcess.Signal(unix.SIGTERM); err != nil { log.Fatalf("error shutting tailscaled down: %v", err) } } defer killTailscaled() + var healthCheck *healthz.Healthz + ep := &egressProxy{} + if cfg.HealthCheckAddrPort != "" { + mux := http.NewServeMux() + + log.Printf("Running healthcheck endpoint at %s/healthz", cfg.HealthCheckAddrPort) + healthCheck = healthz.RegisterHealthHandlers(mux, cfg.PodIPv4, log.Printf) + + close := runHTTPServer(mux, cfg.HealthCheckAddrPort) + defer close() + } + + if cfg.localMetricsEnabled() || cfg.localHealthEnabled() || cfg.egressSvcsTerminateEPEnabled() { + mux := http.NewServeMux() + + if cfg.localMetricsEnabled() { + log.Printf("Running metrics endpoint at %s/metrics", cfg.LocalAddrPort) + metrics.RegisterMetricsHandlers(mux, client, cfg.DebugAddrPort) + } + + if cfg.localHealthEnabled() { + log.Printf("Running healthcheck endpoint at %s/healthz", cfg.LocalAddrPort) + healthCheck = healthz.RegisterHealthHandlers(mux, cfg.PodIPv4, log.Printf) + } + + if cfg.egressSvcsTerminateEPEnabled() { + log.Printf("Running egress preshutdown hook at %s%s", cfg.LocalAddrPort, kubetypes.EgessServicesPreshutdownEP) + ep.registerHandlers(mux) + } + + close := runHTTPServer(mux, cfg.LocalAddrPort) + defer close() + } + if cfg.EnableForwardingOptimizations { if err := client.SetUDPGROForwarding(bootCtx); err != nil { log.Printf("[unexpected] error enabling UDP GRO forwarding: %v", err) @@ -214,7 +276,7 @@ func main() { w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState) if err != nil { - log.Fatalf("failed to watch tailscaled for updates: %v", err) + return fmt.Errorf("failed to watch tailscaled for updates: %w", err) } // Now that we've started tailscaled, we can symlink the socket to the @@ -250,18 +312,18 @@ func main() { didLogin = true w.Close() if err := tailscaleUp(bootCtx, cfg); err != nil { - return fmt.Errorf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } w, err = client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState) if err != nil { - return fmt.Errorf("rewatching tailscaled for updates after auth: %v", err) + return fmt.Errorf("rewatching tailscaled for updates after auth: %w", err) } return nil } if isTwoStepConfigAlwaysAuth(cfg) { if err := authTailscale(); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } } @@ -269,7 +331,7 @@ authLoop: for { n, err := w.Next() if err != nil { - log.Fatalf("failed to read from tailscaled: %v", err) + return fmt.Errorf("failed to read from tailscaled: %w", err) } if n.State != nil { @@ -278,10 +340,10 @@ authLoop: if isOneStepConfig(cfg) { // This could happen if this is the first time tailscaled was run for this // device and the auth key was not passed via the configfile. - log.Fatalf("invalid state: tailscaled daemon started with a config file, but tailscale is not logged in: ensure you pass a valid auth key in the config file.") + return fmt.Errorf("invalid state: tailscaled daemon started with a config file, but tailscale is not logged in: ensure you pass a valid auth key in the config file.") } if err := authTailscale(); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } case ipn.NeedsMachineAuth: log.Printf("machine authorization required, please visit the admin panel") @@ -301,22 +363,20 @@ authLoop: w.Close() - ctx, cancel := contextWithExitSignalWatch() - defer cancel() - if isTwoStepConfigAuthOnce(cfg) { // Now that we are authenticated, we can set/reset any of the // settings that we need to. if err := tailscaleSet(ctx, cfg); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } } + // Remove any serve config and advertised HTTPS endpoint that may have been set by a previous run of + // containerboot, but only if we're providing a new one. if cfg.ServeConfigPath != "" { - // Remove any serve config that may have been set by a previous run of - // containerboot, but only if we're providing a new one. + log.Printf("serve proxy: unsetting previous config") if err := client.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil { - log.Fatalf("failed to unset serve config: %v", err) + return fmt.Errorf("failed to unset serve config: %w", err) } } @@ -325,14 +385,20 @@ authLoop: // authkey is no longer needed. We don't strictly need to // wipe it, but it's good hygiene. log.Printf("Deleting authkey from kube secret") - if err := deleteAuthKey(ctx, cfg.KubeSecret); err != nil { - log.Fatalf("deleting authkey from kube secret: %v", err) + if err := kc.deleteAuthKey(ctx); err != nil { + return fmt.Errorf("deleting authkey from kube secret: %w", err) } } w, err = client.WatchIPNBus(ctx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState) if err != nil { - log.Fatalf("rewatching tailscaled for updates after auth: %v", err) + return fmt.Errorf("rewatching tailscaled for updates after auth: %w", err) + } + + // If tailscaled config was read from a mounted file, watch the file for updates and reload. + cfgWatchErrChan := make(chan error) + if cfg.TailscaledConfigFilePath != "" { + go watchTailscaledConfigChanges(ctx, cfg.TailscaledConfigFilePath, client, cfgWatchErrChan) } var ( @@ -349,17 +415,14 @@ authLoop: certDomain = new(atomic.Pointer[string]) certDomainChanged = make(chan bool, 1) - h = &healthz{} // http server for the healthz endpoint - healthzRunner = sync.OnceFunc(func() { runHealthz(cfg.HealthCheckAddrPort, h) }) + triggerWatchServeConfigChanges sync.Once ) - if cfg.ServeConfigPath != "" { - go watchServeConfigChanges(ctx, cfg.ServeConfigPath, certDomainChanged, certDomain, client) - } + var nfr linuxfw.NetfilterRunner if isL3Proxy(cfg) { nfr, err = newNetfilterRunner(log.Printf) if err != nil { - log.Fatalf("error creating new netfilter runner: %v", err) + return fmt.Errorf("error creating new netfilter runner: %w", err) } } @@ -377,7 +440,8 @@ authLoop: ) // egressSvcsErrorChan will get an error sent to it if this containerboot instance is configured to expose 1+ // egress services in HA mode and errored. - var egressSvcsErrorChan = make(chan error) + egressSvcsErrorChan := make(chan error) + ingressSvcsErrorChan := make(chan error) defer t.Stop() // resetTimer resets timer for when to next attempt to resolve the DNS // name for the proxy configured with TS_EXPERIMENTAL_DEST_DNS_NAME. The @@ -430,7 +494,9 @@ runLoop: killTailscaled() break runLoop case err := <-errChan: - log.Fatalf("failed to read from tailscaled: %v", err) + return fmt.Errorf("failed to read from tailscaled: %w", err) + case err := <-cfgWatchErrChan: + return fmt.Errorf("failed to watch tailscaled config: %w", err) case n := <-notifyChan: if n.State != nil && *n.State != ipn.Running { // Something's gone wrong and we've left the authenticated state. @@ -438,7 +504,7 @@ runLoop: // control flow required to make it work now is hard. So, just crash // the container and rely on the container runtime to restart us, // whereupon we'll go through initial auth again. - log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State) + return fmt.Errorf("tailscaled left running state (now in state %q), exiting", *n.State) } if n.NetMap != nil { addrs = n.NetMap.SelfNode.Addresses().AsSlice() @@ -455,8 +521,8 @@ runLoop: // fails. deviceID := n.NetMap.SelfNode.StableID() if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceID, &deviceID) { - if err := storeDeviceID(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID()); err != nil { - log.Fatalf("storing device ID in Kubernetes Secret: %v", err) + if err := kc.storeDeviceID(ctx, n.NetMap.SelfNode.StableID()); err != nil { + return fmt.Errorf("storing device ID in Kubernetes Secret: %w", err) } } if cfg.TailnetTargetFQDN != "" { @@ -493,12 +559,12 @@ runLoop: rulesInstalled = true log.Printf("Installing forwarding rules for destination %v", ea.String()) if err := installEgressForwardingRule(ctx, ea.String(), addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules for destination %s: %v", ea.String(), err) + return fmt.Errorf("installing egress proxy rules for destination %s: %v", ea.String(), err) } } } if !rulesInstalled { - log.Fatalf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) + return fmt.Errorf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) } } currentEgressIPs = newCurentEgressIPs @@ -506,7 +572,7 @@ runLoop: if cfg.ProxyTargetIP != "" && len(addrs) != 0 && ipsHaveChanged { log.Printf("Installing proxy rules") if err := installIngressForwardingRule(ctx, cfg.ProxyTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing ingress proxy rules: %v", err) + return fmt.Errorf("installing ingress proxy rules: %w", err) } } if cfg.ProxyTargetDNSName != "" && len(addrs) != 0 && ipsHaveChanged { @@ -522,14 +588,17 @@ runLoop: if backendsHaveChanged { log.Printf("installing ingress proxy rules for backends %v", newBackendAddrs) if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { - log.Fatalf("error installing ingress proxy rules: %v", err) + return fmt.Errorf("error installing ingress proxy rules: %w", err) } } resetTimer(false) backendAddrs = newBackendAddrs } - if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) != 0 { - cd := n.NetMap.DNS.CertDomains[0] + if cfg.ServeConfigPath != "" { + cd := certDomainFromNetmap(n.NetMap) + if cd == "" { + cd = kubetypes.ValueNoHTTPS + } prev := certDomain.Swap(ptr.To(cd)) if prev == nil || *prev != cd { select { @@ -541,7 +610,7 @@ runLoop: if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) != 0 { log.Printf("Installing forwarding rules for destination %v", cfg.TailnetTargetIP) if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules: %v", err) + return fmt.Errorf("installing egress proxy rules: %w", err) } } // If this is a L7 cluster ingress proxy (set up @@ -553,7 +622,7 @@ runLoop: if cfg.AllowProxyingClusterTrafficViaIngress && cfg.ServeConfigPath != "" && ipsHaveChanged && len(addrs) != 0 { log.Printf("installing rules to forward traffic for %s to node's tailnet IP", cfg.PodIP) if err := installTSForwardingRuleForDestination(ctx, cfg.PodIP, addrs, nfr); err != nil { - log.Fatalf("installing rules to forward traffic to node's tailnet IP: %v", err) + return fmt.Errorf("installing rules to forward traffic to node's tailnet IP: %w", err) } } currentIPs = newCurrentIPs @@ -571,17 +640,21 @@ runLoop: // TODO (irbekrm): instead of using the IP and FQDN, have some other mechanism for the proxy signal that it is 'Ready'. deviceEndpoints := []any{n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses()} if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceEndpoints, &deviceEndpoints) { - if err := storeDeviceEndpoints(ctx, cfg.KubeSecret, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { - log.Fatalf("storing device IPs and FQDN in Kubernetes Secret: %v", err) + if err := kc.storeDeviceEndpoints(ctx, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { + return fmt.Errorf("storing device IPs and FQDN in Kubernetes Secret: %w", err) } } - if cfg.HealthCheckAddrPort != "" { - h.Lock() - h.hasAddrs = len(addrs) != 0 - h.Unlock() - healthzRunner() + if healthCheck != nil { + healthCheck.Update(len(addrs) != 0) } + + if cfg.ServeConfigPath != "" { + triggerWatchServeConfigChanges.Do(func() { + go watchServeConfigChanges(ctx, certDomainChanged, certDomain, client, kc, cfg) + }) + } + if egressSvcsNotify != nil { egressSvcsNotify <- n } @@ -603,24 +676,42 @@ runLoop: // will then continuously monitor the config file and netmap updates and // reconfigure the firewall rules as needed. If any of its operations fail, it // will crash this node. - if cfg.EgressSvcsCfgPath != "" { - log.Printf("configuring egress proxy using configuration file at %s", cfg.EgressSvcsCfgPath) + if cfg.EgressProxiesCfgPath != "" { + log.Printf("configuring egress proxy using configuration file at %s", cfg.EgressProxiesCfgPath) egressSvcsNotify = make(chan ipn.Notify) - ep := egressProxy{ - cfgPath: cfg.EgressSvcsCfgPath, + opts := egressProxyRunOpts{ + cfgPath: cfg.EgressProxiesCfgPath, nfr: nfr, kc: kc, + tsClient: client, stateSecret: cfg.KubeSecret, netmapChan: egressSvcsNotify, - podIP: cfg.PodIP, + podIPv4: cfg.PodIPv4, tailnetAddrs: addrs, } go func() { - if err := ep.run(ctx, n); err != nil { + if err := ep.run(ctx, n, opts); err != nil { egressSvcsErrorChan <- err } }() } + ip := ingressProxy{} + if cfg.IngressProxiesCfgPath != "" { + log.Printf("configuring ingress proxy using configuration file at %s", cfg.IngressProxiesCfgPath) + opts := ingressProxyOpts{ + cfgPath: cfg.IngressProxiesCfgPath, + nfr: nfr, + kc: kc, + stateSecret: cfg.KubeSecret, + podIPv4: cfg.PodIPv4, + podIPv6: cfg.PodIPv6, + } + go func() { + if err := ip.run(ctx, opts); err != nil { + ingressSvcsErrorChan <- err + } + }() + } // Wait on tailscaled process. It won't be cleaned up by default when the // container exits as it is not PID1. TODO (irbekrm): perhaps we can replace the @@ -658,16 +749,20 @@ runLoop: if backendsHaveChanged && len(addrs) != 0 { log.Printf("Backend address change detected, installing proxy rules for backends %v", newBackendAddrs) if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { - log.Fatalf("installing ingress proxy rules for DNS target %s: %v", cfg.ProxyTargetDNSName, err) + return fmt.Errorf("installing ingress proxy rules for DNS target %s: %v", cfg.ProxyTargetDNSName, err) } } backendAddrs = newBackendAddrs resetTimer(false) case e := <-egressSvcsErrorChan: - log.Fatalf("egress proxy failed: %v", e) + return fmt.Errorf("egress proxy failed: %v", e) + case e := <-ingressSvcsErrorChan: + return fmt.Errorf("ingress proxy failed: %v", e) } } wg.Wait() + + return nil } // ensureTunFile checks that /dev/net/tun exists, creating it if @@ -696,13 +791,13 @@ func resolveDNS(ctx context.Context, name string) ([]net.IP, error) { ip4s, err := net.DefaultResolver.LookupIP(ctx, "ip4", name) if err != nil { if e, ok := err.(*net.DNSError); !(ok && e.IsNotFound) { - return nil, fmt.Errorf("error looking up IPv4 addresses: %v", err) + return nil, fmt.Errorf("error looking up IPv4 addresses: %w", err) } } ip6s, err := net.DefaultResolver.LookupIP(ctx, "ip6", name) if err != nil { if e, ok := err.(*net.DNSError); !(ok && e.IsNotFound) { - return nil, fmt.Errorf("error looking up IPv6 addresses: %v", err) + return nil, fmt.Errorf("error looking up IPv6 addresses: %w", err) } } if len(ip4s) == 0 && len(ip6s) == 0 { @@ -715,7 +810,7 @@ func resolveDNS(ctx context.Context, name string) ([]net.IP, error) { // context that gets cancelled when a signal is received and a cancel function // that can be called to free the resources when the watch should be stopped. func contextWithExitSignalWatch() (context.Context, func()) { - closeChan := make(chan string) + closeChan := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) @@ -727,8 +822,11 @@ func contextWithExitSignalWatch() (context.Context, func()) { return } }() + closeOnce := sync.Once{} f := func() { - closeChan <- "goodbye" + closeOnce.Do(func() { + close(closeChan) + }) } return ctx, f } @@ -758,7 +856,6 @@ func tailscaledConfigFilePath() string { } cv, err := kubeutils.CapVerFromFileName(e.Name()) if err != nil { - log.Printf("skipping file %q in tailscaled config directory %q: %v", e.Name(), dir, err) continue } if cv > maxCompatVer && cv <= tailcfg.CurrentCapabilityVersion { @@ -766,8 +863,32 @@ func tailscaledConfigFilePath() string { } } if maxCompatVer == -1 { - log.Fatalf("no tailscaled config file found in %q for current capability version %q", dir, tailcfg.CurrentCapabilityVersion) + log.Fatalf("no tailscaled config file found in %q for current capability version %d", dir, tailcfg.CurrentCapabilityVersion) + } + filePath := filepath.Join(dir, kubeutils.TailscaledConfigFileName(maxCompatVer)) + log.Printf("Using tailscaled config file %q to match current capability version %d", filePath, tailcfg.CurrentCapabilityVersion) + return filePath +} + +func runHTTPServer(mux *http.ServeMux, addr string) (close func() error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("failed to listen on addr %q: %v", addr, err) + } + srv := &http.Server{Handler: mux} + + go func() { + if err := srv.Serve(ln); err != nil { + if err != http.ErrServerClosed { + log.Fatalf("failed running server: %v", err) + } else { + log.Printf("HTTP server at %s closed", addr) + } + } + }() + + return func() error { + err := srv.Shutdown(context.Background()) + return errors.Join(err, ln.Close()) } - log.Printf("Using tailscaled config file %q for capability version %q", maxCompatVer, tailcfg.CurrentCapabilityVersion) - return path.Join(dir, kubeutils.TailscaledConfigFileNameForCap(maxCompatVer)) } diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 5c92787ce..f92f35333 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -25,12 +25,16 @@ import ( "strconv" "strings" "sync" + "syscall" "testing" "time" "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "tailscale.com/ipn" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/netmap" @@ -38,69 +42,23 @@ import ( ) func TestContainerBoot(t *testing.T) { - d := t.TempDir() - - lapi := localAPI{FSRoot: d} - if err := lapi.Start(); err != nil { - t.Fatal(err) - } - defer lapi.Close() - - kube := kubeServer{FSRoot: d} - if err := kube.Start(); err != nil { - t.Fatal(err) - } - defer kube.Close() - - tailscaledConf := &ipn.ConfigVAlpha{AuthKey: ptr.To("foo"), Version: "alpha0"} - tailscaledConfBytes, err := json.Marshal(tailscaledConf) - if err != nil { - t.Fatalf("error unmarshaling tailscaled config: %v", err) + boot := filepath.Join(t.TempDir(), "containerboot") + if err := exec.Command("go", "build", "-ldflags", "-X main.testSleepDuration=1ms", "-o", boot, "tailscale.com/cmd/containerboot").Run(); err != nil { + t.Fatalf("Building containerboot: %v", err) } + egressStatus := egressSvcStatus("foo", "foo.tailnetxyz.ts.net") - dirs := []string{ - "var/lib", - "usr/bin", - "tmp", - "dev/net", - "proc/sys/net/ipv4", - "proc/sys/net/ipv6/conf/all", - "etc/tailscaled", - } - for _, path := range dirs { - if err := os.MkdirAll(filepath.Join(d, path), 0700); err != nil { - t.Fatal(err) - } + metricsURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d/metrics", port) } - files := map[string][]byte{ - "usr/bin/tailscaled": fakeTailscaled, - "usr/bin/tailscale": fakeTailscale, - "usr/bin/iptables": fakeTailscale, - "usr/bin/ip6tables": fakeTailscale, - "dev/net/tun": []byte(""), - "proc/sys/net/ipv4/ip_forward": []byte("0"), - "proc/sys/net/ipv6/conf/all/forwarding": []byte("0"), - "etc/tailscaled/cap-95.hujson": tailscaledConfBytes, + healthURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d/healthz", port) } - resetFiles := func() { - for path, content := range files { - // Making everything executable is a little weird, but the - // stuff that doesn't need to be executable doesn't care if we - // do make it executable. - if err := os.WriteFile(filepath.Join(d, path), content, 0700); err != nil { - t.Fatal(err) - } - } + egressSvcTerminateURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d%s", port, kubetypes.EgessServicesPreshutdownEP) } - resetFiles() - boot := filepath.Join(d, "containerboot") - if err := exec.Command("go", "build", "-o", boot, "tailscale.com/cmd/containerboot").Run(); err != nil { - t.Fatalf("Building containerboot: %v", err) - } - - argFile := filepath.Join(d, "args") - runningSockPath := filepath.Join(d, "tmp/tailscaled.sock") + capver := fmt.Sprintf("%d", tailcfg.CurrentCapabilityVersion) type phase struct { // If non-nil, send this IPN bus notification (and remember it as the @@ -110,15 +68,31 @@ func TestContainerBoot(t *testing.T) { // WantCmds is the commands that containerboot should run in this phase. WantCmds []string + // WantKubeSecret is the secret keys/values that should exist in the // kube secret. WantKubeSecret map[string]string + + // Update the kube secret with these keys/values at the beginning of the + // phase (simulates our fake tailscaled doing it). + UpdateKubeSecret map[string]string + // WantFiles files that should exist in the container and their // contents. WantFiles map[string]string - // WantFatalLog is the fatal log message we expect from containerboot. - // If set for a phase, the test will finish on that phase. - WantFatalLog string + + // WantLog is a log message we expect from containerboot. + WantLog string + + // If set for a phase, the test will expect containerboot to exit with + // this error code, and the test will finish on that phase without + // waiting for the successful startup log message. + WantExitCode *int + + // The signal to send to containerboot at the start of the phase. + Signal *syscall.Signal + + EndpointStatuses map[string]int } runningNotify := &ipn.Notify{ State: ptr.To(ipn.Running), @@ -130,601 +104,978 @@ func TestContainerBoot(t *testing.T) { }).View(), }, } - tests := []struct { - Name string + type testCase struct { Env map[string]string KubeSecret map[string]string KubeDenyPatch bool Phases []phase - }{ - { - // Out of the box default: runs in userspace mode, ephemeral storage, interactive login. - Name: "no_args", - Env: nil, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + } + tests := map[string]func(env *testEnv) testCase{ + "no_args": func(env *testEnv) testCase { + return testCase{ + // Out of the box default: runs in userspace mode, ephemeral storage, interactive login. + Env: nil, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + // No metrics or health by default. + EndpointStatuses: map[string]int{ + metricsURL(9002): -1, + healthURL(9002): -1, + }, + }, + { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - }, - }, + } }, - { - // Userspace mode, ephemeral storage, authkey provided on every run. - Name: "authkey", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "authkey": func(env *testEnv) testCase { + return testCase{ + // Userspace mode, ephemeral storage, authkey provided on every run. + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - // Userspace mode, ephemeral storage, authkey provided on every run. - Name: "authkey-old-flag", - Env: map[string]string{ - "TS_AUTH_KEY": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "authkey_old_flag": func(env *testEnv) testCase { + return testCase{ + // Userspace mode, ephemeral storage, authkey provided on every run. + Env: map[string]string{ + "TS_AUTH_KEY": "tskey-key", }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - Name: "authkey_disk_state", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_STATE_DIR": filepath.Join(d, "tmp"), - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "authkey_disk_state": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_STATE_DIR": filepath.Join(env.d, "tmp"), }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - Name: "routes", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", - }, + "routes": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "0", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "0", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, }, - }, + } }, - { - Name: "empty routes", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=", - }, + "empty_routes": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "0", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "0", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, }, - }, + } }, - { - Name: "routes_kernel_ipv4", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", - }, + "routes_kernel_ipv4": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, }, - }, + } }, - { - Name: "routes_kernel_ipv6", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "::/64,1::/64", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1::/64", - }, + "routes_kernel_ipv6": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "::/64,1::/64", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "0", - "proc/sys/net/ipv6/conf/all/forwarding": "1", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1::/64", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "0", + "proc/sys/net/ipv6/conf/all/forwarding": "1", + }, }, }, - }, + } }, - { - Name: "routes_kernel_all_families", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "::/64,1.2.3.0/24", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1.2.3.0/24", - }, + "routes_kernel_all_families": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "::/64,1.2.3.0/24", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "1", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1.2.3.0/24", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "1", + }, }, }, - }, + } }, - { - Name: "ingress proxy", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_DEST_IP": "1.2.3.4", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "ingress_proxy": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_DEST_IP": "1.2.3.4", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - Name: "egress proxy", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_TAILNET_TARGET_IP": "100.99.99.99", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "egress_proxy": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_TAILNET_TARGET_IP": "100.99.99.99", + "TS_USERSPACE": "false", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - }, - }, + } }, - { - Name: "egress_proxy_fqdn_ipv6_target_on_ipv4_host", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_TAILNET_TARGET_FQDN": "ipv6-node.test.ts.net", // resolves to IPv6 address - "TS_USERSPACE": "false", - "TS_TEST_FAKE_NETFILTER_6": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "0", - }, - }, - { - Notify: &ipn.Notify{ - State: ptr.To(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("myID"), - Name: "test-node.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("ipv6ID"), - Name: "ipv6-node.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("::1/128")}, + "egress_proxy_fqdn_ipv6_target_on_ipv4_host": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_TAILNET_TARGET_FQDN": "ipv6-node.test.ts.net", // resolves to IPv6 address + "TS_USERSPACE": "false", + "TS_TEST_FAKE_NETFILTER_6": "false", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, + }, + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.Running), + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("ipv6ID"), + Name: "ipv6-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("::1/128")}, + }).View(), + }, }, }, + WantLog: "no forwarding rules for egress addresses [::1/128], host supports IPv6: false", + WantExitCode: ptr.To(1), }, - WantFatalLog: "no forwarding rules for egress addresses [::1/128], host supports IPv6: false", }, - }, + } }, - { - Name: "authkey_once", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_AUTH_ONCE": "true", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - }, + "authkey_once": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_AUTH_ONCE": "true", }, - { - Notify: &ipn.Notify{ - State: ptr.To(ipn.NeedsLogin), + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + }, }, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.NeedsLogin), + }, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, }, - }, - { - Notify: runningNotify, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + { + Notify: runningNotify, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + }, }, }, - }, + } }, - { - Name: "kube_storage", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - }, - KubeSecret: map[string]string{ - "authkey": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "auth_key_once_extra_args_override_dns": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_AUTH_ONCE": "true", + "TS_ACCEPT_DNS": "false", + "TS_EXTRA_ARGS": "--accept-dns", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + }, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.NeedsLogin), + }, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true --authkey=tskey-key", + }, }, - }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + { + Notify: runningNotify, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=true", + }, }, }, - }, + } }, - { - Name: "kube_disk_storage", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - // Explicitly set to an empty value, to override the default of "tailscale". - "TS_KUBE_SECRET": "", - "TS_STATE_DIR": filepath.Join(d, "tmp"), - "TS_AUTHKEY": "tskey-key", - }, - KubeSecret: map[string]string{}, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "kube_storage": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "POD_UID": "some-pod-uid", + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + kubetypes.KeyPodUID: "some-pod-uid", + }, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + kubetypes.KeyCapVer: capver, + kubetypes.KeyPodUID: "some-pod-uid", + }, }, - WantKubeSecret: map[string]string{}, }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{}, + } + }, + "kube_disk_storage": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + // Explicitly set to an empty value, to override the default of "tailscale". + "TS_KUBE_SECRET": "", + "TS_STATE_DIR": filepath.Join(env.d, "tmp"), + "TS_AUTHKEY": "tskey-key", }, - }, + KubeSecret: map[string]string{}, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{}, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{}, + }, + }, + } }, - { - Name: "kube_storage_no_patch", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - "TS_AUTHKEY": "tskey-key", - }, - KubeSecret: map[string]string{}, - KubeDenyPatch: true, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "kube_storage_no_patch": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_AUTHKEY": "tskey-key", + }, + KubeSecret: map[string]string{}, + KubeDenyPatch: true, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{}, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{}, }, - WantKubeSecret: map[string]string{}, }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{}, + } + }, + "kube_storage_auth_once": func(env *testEnv) testCase { + return testCase{ + // Same as previous, but deletes the authkey from the kube secret. + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_AUTH_ONCE": "true", }, - }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + }, + }, + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.NeedsLogin), + }, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + }, + }, + { + Notify: runningNotify, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + }, + WantKubeSecret: map[string]string{ + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + kubetypes.KeyCapVer: capver, + }, + }, + }, + } }, - { - // Same as previous, but deletes the authkey from the kube secret. - Name: "kube_storage_auth_once", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - "TS_AUTH_ONCE": "true", - }, - KubeSecret: map[string]string{ - "authkey": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "kube_storage_updates": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + }, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + kubetypes.KeyCapVer: capver, + }, }, + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.Running), + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("newID"), + Name: "new-name.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }).View(), + }, + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "new-name.test.ts.net", + "device_id": "newID", + "device_ips": `["100.64.0.1"]`, + kubetypes.KeyCapVer: capver, + }, + }, + }, + } + }, + "proxies": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_SOCKS5_SERVER": "localhost:1080", + "TS_OUTBOUND_HTTP_PROXY_LISTEN": "localhost:8080", }, - { - Notify: &ipn.Notify{ - State: ptr.To(ipn.NeedsLogin), + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --socks5-server=localhost:1080 --outbound-http-proxy-listen=localhost:8080", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, }, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + }, + } + }, + "dns": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_ACCEPT_DNS": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true", + }, }, + { + Notify: runningNotify, + }, + }, + } + }, + "extra_args": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXTRA_ARGS": "--widget=rotated", + "TS_TAILSCALED_EXTRA_ARGS": "--experiments=widgets", }, - { - Notify: runningNotify, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --experiments=widgets", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --widget=rotated", + }, + }, { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + }, + } + }, + "extra_args_accept_routes": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXTRA_ARGS": "--accept-routes", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --accept-routes", + }, + }, { + Notify: runningNotify, }, }, - }, + } }, - { - Name: "kube_storage_updates", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - }, - KubeSecret: map[string]string{ - "authkey": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "extra_args_accept_dns": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXTRA_ARGS": "--accept-dns", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true", + }, + }, { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + }, + } + }, + "extra_args_accept_dns_overrides_env_var": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_ACCEPT_DNS": "true", // Overridden by TS_EXTRA_ARGS. + "TS_EXTRA_ARGS": "--accept-dns=false", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + }, { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + } + }, + "hostname": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_HOSTNAME": "my-server", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --hostname=my-server", + }, + }, { + Notify: runningNotify, }, }, - { - Notify: &ipn.Notify{ - State: ptr.To(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("newID"), - Name: "new-name.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - }).View(), + } + }, + "experimental_tailscaled_config_path": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR": filepath.Join(env.d, "etc/tailscaled/"), + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --config=/etc/tailscaled/cap-95.hujson", }, + }, { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "new-name.test.ts.net", - "device_id": "newID", - "device_ips": `["100.64.0.1"]`, + }, + } + }, + "metrics_enabled": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_METRICS": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.localAddrPort): -1, + }, + }, { + Notify: runningNotify, }, }, - }, + } }, - { - Name: "proxies", - Env: map[string]string{ - "TS_SOCKS5_SERVER": "localhost:1080", - "TS_OUTBOUND_HTTP_PROXY_LISTEN": "localhost:8080", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --socks5-server=localhost:1080 --outbound-http-proxy-listen=localhost:8080", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + "health_enabled": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_HEALTH_CHECK": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): -1, + healthURL(env.localAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): -1, + healthURL(env.localAddrPort): 200, + }, }, }, - { - Notify: runningNotify, + } + }, + "metrics_and_health_on_same_port": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_METRICS": "true", + "TS_ENABLE_HEALTH_CHECK": "true", }, - }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.localAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.localAddrPort): 200, + }, + }, + }, + } }, - { - Name: "dns", - Env: map[string]string{ - "TS_ACCEPT_DNS": "true", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true", + "local_metrics_and_deprecated_health": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_METRICS": "true", + "TS_HEALTHCHECK_ADDR_PORT": fmt.Sprintf("[::]:%d", env.healthAddrPort), + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.healthAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.healthAddrPort): 200, + }, }, }, - { - Notify: runningNotify, + } + }, + "serve_config_no_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_SERVE_CONFIG": filepath.Join(env.d, "etc/tailscaled/serve-config.json"), + "TS_AUTHKEY": "tskey-key", }, - }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, + }, + } }, - { - Name: "extra_args", - Env: map[string]string{ - "TS_EXTRA_ARGS": "--widget=rotated", - "TS_TAILSCALED_EXTRA_ARGS": "--experiments=widgets", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --experiments=widgets", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --widget=rotated", + "serve_config_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_SERVE_CONFIG": filepath.Join(env.d, "etc/tailscaled/serve-config.json"), + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + }, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "https_endpoint": "no-https", + kubetypes.KeyCapVer: capver, + }, }, - }, { - Notify: runningNotify, }, - }, + } }, - { - Name: "extra_args_accept_routes", - Env: map[string]string{ - "TS_EXTRA_ARGS": "--accept-routes", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --accept-routes", + "egress_svcs_config_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_EGRESS_PROXIES_CONFIG_PATH": filepath.Join(env.d, "etc/tailscaled"), + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + }, + EndpointStatuses: map[string]int{ + egressSvcTerminateURL(env.localAddrPort): 200, + }, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "egress-services": string(mustJSON(t, egressStatus)), + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + kubetypes.KeyCapVer: capver, + }, + EndpointStatuses: map[string]int{ + egressSvcTerminateURL(env.localAddrPort): 200, + }, }, - }, { - Notify: runningNotify, }, - }, + } }, - { - Name: "hostname", - Env: map[string]string{ - "TS_HOSTNAME": "my-server", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --hostname=my-server", + "egress_svcs_config_no_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EGRESS_PROXIES_CONFIG_PATH": filepath.Join(env.d, "etc/tailscaled"), + "TS_AUTHKEY": "tskey-key", + }, + Phases: []phase{ + { + WantLog: "TS_EGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes", + WantExitCode: ptr.To(1), }, - }, { - Notify: runningNotify, }, - }, + } }, - { - Name: "experimental tailscaled config path", - Env: map[string]string{ - "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR": filepath.Join(d, "etc/tailscaled/"), - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --config=/etc/tailscaled/cap-95.hujson", + "kube_shutdown_during_state_write": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_ENABLE_HEALTH_CHECK": "true", + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + // Normal startup. + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + kubetypes.KeyCapVer: capver, + }, + }, + { + // SIGTERM before state is finished writing, should wait for + // consistent state before propagating SIGTERM to tailscaled. + Signal: ptr.To(unix.SIGTERM), + UpdateKubeSecret: map[string]string{ + "_machinekey": "foo", + "_profiles": "foo", + "profile-baff": "foo", + // Missing "_current-profile" key. + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "_machinekey": "foo", + "_profiles": "foo", + "profile-baff": "foo", + kubetypes.KeyCapVer: capver, + }, + WantLog: "Waiting for tailscaled to finish writing state to Secret \"tailscale\"", + }, + { + // tailscaled has finished writing state, should propagate SIGTERM. + UpdateKubeSecret: map[string]string{ + "_current-profile": "foo", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "_machinekey": "foo", + "_profiles": "foo", + "profile-baff": "foo", + "_current-profile": "foo", + kubetypes.KeyCapVer: capver, + }, + WantLog: "HTTP server at [::]:9002 closed", + WantExitCode: ptr.To(0), }, - }, { - Notify: runningNotify, }, - }, + } }, } - for _, test := range tests { - t.Run(test.Name, func(t *testing.T) { - lapi.Reset() - kube.Reset() - os.Remove(argFile) - os.Remove(runningSockPath) - resetFiles() + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + env := newTestEnv(t) + tc := test(&env) - for k, v := range test.KubeSecret { - kube.SetSecret(k, v) + for k, v := range tc.KubeSecret { + env.kube.SetSecret(k, v) } - kube.SetPatching(!test.KubeDenyPatch) + env.kube.SetPatching(!tc.KubeDenyPatch) cmd := exec.Command(boot) cmd.Env = []string{ - fmt.Sprintf("PATH=%s/usr/bin:%s", d, os.Getenv("PATH")), - fmt.Sprintf("TS_TEST_RECORD_ARGS=%s", argFile), - fmt.Sprintf("TS_TEST_SOCKET=%s", lapi.Path), - fmt.Sprintf("TS_SOCKET=%s", runningSockPath), - fmt.Sprintf("TS_TEST_ONLY_ROOT=%s", d), - fmt.Sprint("TS_TEST_FAKE_NETFILTER=true"), + fmt.Sprintf("PATH=%s/usr/bin:%s", env.d, os.Getenv("PATH")), + fmt.Sprintf("TS_TEST_RECORD_ARGS=%s", env.argFile), + fmt.Sprintf("TS_TEST_SOCKET=%s", env.lapi.Path), + fmt.Sprintf("TS_SOCKET=%s", env.runningSockPath), + fmt.Sprintf("TS_TEST_ONLY_ROOT=%s", env.d), + "TS_TEST_FAKE_NETFILTER=true", } - for k, v := range test.Env { + for k, v := range tc.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } cbOut := &lockingBuffer{} @@ -734,6 +1085,7 @@ func TestContainerBoot(t *testing.T) { } }() cmd.Stderr = cbOut + cmd.Stdout = cbOut if err := cmd.Start(); err != nil { t.Fatalf("starting containerboot: %v", err) } @@ -743,37 +1095,47 @@ func TestContainerBoot(t *testing.T) { }() var wantCmds []string - for i, p := range test.Phases { - lapi.Notify(p.Notify) - if p.WantFatalLog != "" { + for i, p := range tc.Phases { + for k, v := range p.UpdateKubeSecret { + env.kube.SetSecret(k, v) + } + env.lapi.Notify(p.Notify) + if p.Signal != nil { + cmd.Process.Signal(*p.Signal) + } + if p.WantLog != "" { err := tstest.WaitFor(2*time.Second, func() error { - state, err := cmd.Process.Wait() - if err != nil { - return err - } - if state.ExitCode() != 1 { - return fmt.Errorf("process exited with code %d but wanted %d", state.ExitCode(), 1) - } - waitLogLine(t, time.Second, cbOut, p.WantFatalLog) + waitLogLine(t, time.Second, cbOut, p.WantLog) return nil }) if err != nil { t.Fatal(err) } + } + + if p.WantExitCode != nil { + state, err := cmd.Process.Wait() + if err != nil { + t.Fatal(err) + } + if state.ExitCode() != *p.WantExitCode { + t.Fatalf("phase %d: want exit code %d, got %d", i, *p.WantExitCode, state.ExitCode()) + } // Early test return, we don't expect the successful startup log message. return } + wantCmds = append(wantCmds, p.WantCmds...) - waitArgs(t, 2*time.Second, d, argFile, strings.Join(wantCmds, "\n")) + waitArgs(t, 2*time.Second, env.d, env.argFile, strings.Join(wantCmds, "\n")) err := tstest.WaitFor(2*time.Second, func() error { if p.WantKubeSecret != nil { - got := kube.Secret() + got := env.kube.Secret() if diff := cmp.Diff(got, p.WantKubeSecret); diff != "" { return fmt.Errorf("unexpected kube secret data (-got+want):\n%s", diff) } } else { - got := kube.Secret() + got := env.kube.Secret() if len(got) > 0 { return fmt.Errorf("kube secret unexpectedly not empty, got %#v", got) } @@ -785,7 +1147,7 @@ func TestContainerBoot(t *testing.T) { } err = tstest.WaitFor(2*time.Second, func() error { for path, want := range p.WantFiles { - gotBs, err := os.ReadFile(filepath.Join(d, path)) + gotBs, err := os.ReadFile(filepath.Join(env.d, path)) if err != nil { return fmt.Errorf("reading wanted file %q: %v", path, err) } @@ -796,10 +1158,32 @@ func TestContainerBoot(t *testing.T) { return nil }) if err != nil { - t.Fatal(err) + t.Fatalf("phase %d: %v", i, err) + } + + for url, want := range p.EndpointStatuses { + err := tstest.WaitFor(2*time.Second, func() error { + resp, err := http.Get(url) + if err != nil && want != -1 { + return fmt.Errorf("GET %s: %v", url, err) + } + if want > 0 && resp.StatusCode != want { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("GET %s, want %d, got %d\n%s", url, want, resp.StatusCode, string(body)) + } + + return nil + }) + if err != nil { + t.Fatalf("phase %d: %v", i, err) + } } } waitLogLine(t, 2*time.Second, cbOut, "Startup complete, waiting for shutdown signal") + if cmd.ProcessState != nil { + t.Fatalf("containerboot should be running but exited with exit code %d", cmd.ProcessState.ExitCode()) + } }) } } @@ -903,8 +1287,8 @@ type localAPI struct { notify *ipn.Notify } -func (l *localAPI) Start() error { - path := filepath.Join(l.FSRoot, "tmp/tailscaled.sock.fake") +func (lc *localAPI) Start() error { + path := filepath.Join(lc.FSRoot, "tmp/tailscaled.sock.fake") if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil { return err } @@ -914,37 +1298,30 @@ func (l *localAPI) Start() error { return err } - l.srv = &http.Server{ - Handler: l, + lc.srv = &http.Server{ + Handler: lc, } - l.Path = path - l.cond = sync.NewCond(&l.Mutex) - go l.srv.Serve(ln) + lc.Path = path + lc.cond = sync.NewCond(&lc.Mutex) + go lc.srv.Serve(ln) return nil } -func (l *localAPI) Close() { - l.srv.Close() +func (lc *localAPI) Close() { + lc.srv.Close() } -func (l *localAPI) Reset() { - l.Lock() - defer l.Unlock() - l.notify = nil - l.cond.Broadcast() -} - -func (l *localAPI) Notify(n *ipn.Notify) { +func (lc *localAPI) Notify(n *ipn.Notify) { if n == nil { return } - l.Lock() - defer l.Unlock() - l.notify = n - l.cond.Broadcast() + lc.Lock() + defer lc.Unlock() + lc.notify = n + lc.cond.Broadcast() } -func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (lc *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/localapi/v0/serve-config": if r.Method != "POST" { @@ -955,6 +1332,12 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { panic(fmt.Sprintf("unsupported method %q", r.Method)) } + case "/localapi/v0/usermetrics": + if r.Method != "GET" { + panic(fmt.Sprintf("unsupported method %q", r.Method)) + } + w.Write([]byte("fake metrics")) + return default: panic(fmt.Sprintf("unsupported path %q", r.URL.Path)) } @@ -965,11 +1348,11 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { f.Flush() } enc := json.NewEncoder(w) - l.Lock() - defer l.Unlock() + lc.Lock() + defer lc.Unlock() for { - if l.notify != nil { - if err := enc.Encode(l.notify); err != nil { + if lc.notify != nil { + if err := enc.Encode(lc.notify); err != nil { // Usually broken pipe as the test client disconnects. return } @@ -977,7 +1360,7 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { f.Flush() } } - l.cond.Wait() + lc.cond.Wait() } } @@ -1019,24 +1402,19 @@ func (k *kubeServer) SetPatching(canPatch bool) { k.canPatch = canPatch } -func (k *kubeServer) Reset() { - k.Lock() - defer k.Unlock() +func (k *kubeServer) Start(t *testing.T) { k.secret = map[string]string{} -} - -func (k *kubeServer) Start() error { root := filepath.Join(k.FSRoot, "var/run/secrets/kubernetes.io/serviceaccount") if err := os.MkdirAll(root, 0700); err != nil { - return err + t.Fatal(err) } if err := os.WriteFile(filepath.Join(root, "namespace"), []byte("default"), 0600); err != nil { - return err + t.Fatal(err) } if err := os.WriteFile(filepath.Join(root, "token"), []byte("bearer_token"), 0600); err != nil { - return err + t.Fatal(err) } k.srv = httptest.NewTLSServer(k) @@ -1045,13 +1423,11 @@ func (k *kubeServer) Start() error { var cert bytes.Buffer if err := pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: k.srv.Certificate().Raw}); err != nil { - return err + t.Fatal(err) } if err := os.WriteFile(filepath.Join(root, "ca.crt"), cert.Bytes(), 0600); err != nil { - return err + t.Fatal(err) } - - return nil } func (k *kubeServer) Close() { @@ -1100,6 +1476,7 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("reading request body: %v", err), http.StatusInternalServerError) return } + defer r.Body.Close() switch r.Method { case "GET": @@ -1124,21 +1501,34 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) { } switch r.Header.Get("Content-Type") { case "application/json-patch+json": - req := []struct { - Op string `json:"op"` - Path string `json:"path"` - }{} + req := []kubeclient.JSONPatch{} if err := json.Unmarshal(bs, &req); err != nil { panic(fmt.Sprintf("json decode failed: %v. Body:\n\n%s", err, string(bs))) } for _, op := range req { - if op.Op != "remove" { + switch op.Op { + case "remove": + if !strings.HasPrefix(op.Path, "/data/") { + panic(fmt.Sprintf("unsupported json-patch path %q", op.Path)) + } + delete(k.secret, strings.TrimPrefix(op.Path, "/data/")) + case "add", "replace": + path, ok := strings.CutPrefix(op.Path, "/data/") + if !ok { + panic(fmt.Sprintf("unsupported json-patch path %q", op.Path)) + } + val, ok := op.Value.(string) + if !ok { + panic(fmt.Sprintf("unsupported json patch value %v: cannot be converted to string", op.Value)) + } + v, err := base64.StdEncoding.DecodeString(val) + if err != nil { + panic(fmt.Sprintf("json patch value %q is not base64 encoded: %v", val, err)) + } + k.secret[path] = string(v) + default: panic(fmt.Sprintf("unsupported json-patch op %q", op.Op)) } - if !strings.HasPrefix(op.Path, "/data/") { - panic(fmt.Sprintf("unsupported json-patch path %q", op.Path)) - } - delete(k.secret, strings.TrimPrefix(op.Path, "/data/")) } case "application/strategic-merge-patch+json": req := struct { @@ -1154,6 +1544,135 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) { panic(fmt.Sprintf("unknown content type %q", r.Header.Get("Content-Type"))) } default: - panic(fmt.Sprintf("unhandled HTTP method %q", r.Method)) + panic(fmt.Sprintf("unhandled HTTP request %s %s", r.Method, r.URL)) + } +} + +func mustBase64(t *testing.T, v any) string { + b := mustJSON(t, v) + s := base64.StdEncoding.WithPadding('=').EncodeToString(b) + return s +} + +func mustJSON(t *testing.T, v any) []byte { + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("error converting %v to json: %v", v, err) + } + return b +} + +// egress services status given one named tailnet target specified by FQDN. As written by the proxy to its state Secret. +func egressSvcStatus(name, fqdn string) egressservices.Status { + return egressservices.Status{ + Services: map[string]*egressservices.ServiceStatus{ + name: { + TailnetTarget: egressservices.TailnetTarget{ + FQDN: fqdn, + }, + }, + }, + } +} + +// egress config given one named tailnet target specified by FQDN. +func egressSvcConfig(name, fqdn string) egressservices.Configs { + return egressservices.Configs{ + name: egressservices.Config{ + TailnetTarget: egressservices.TailnetTarget{ + FQDN: fqdn, + }, + }, + } +} + +// testEnv represents the environment needed for a single sub-test so that tests +// can run in parallel. +type testEnv struct { + kube *kubeServer // Fake kube server. + lapi *localAPI // Local TS API server. + d string // Temp dir for the specific test. + argFile string // File with commands test_tailscale{,d}.sh were invoked with. + runningSockPath string // Path to the running tailscaled socket. + localAddrPort int // Port for the containerboot HTTP server. + healthAddrPort int // Port for the (deprecated) containerboot health server. +} + +func newTestEnv(t *testing.T) testEnv { + d := t.TempDir() + + lapi := localAPI{FSRoot: d} + if err := lapi.Start(); err != nil { + t.Fatal(err) + } + t.Cleanup(lapi.Close) + + kube := kubeServer{FSRoot: d} + kube.Start(t) + t.Cleanup(kube.Close) + + tailscaledConf := &ipn.ConfigVAlpha{AuthKey: ptr.To("foo"), Version: "alpha0"} + serveConf := ipn.ServeConfig{TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}} + egressCfg := egressSvcConfig("foo", "foo.tailnetxyz.ts.net") + + dirs := []string{ + "var/lib", + "usr/bin", + "tmp", + "dev/net", + "proc/sys/net/ipv4", + "proc/sys/net/ipv6/conf/all", + "etc/tailscaled", + } + for _, path := range dirs { + if err := os.MkdirAll(filepath.Join(d, path), 0700); err != nil { + t.Fatal(err) + } + } + files := map[string][]byte{ + "usr/bin/tailscaled": fakeTailscaled, + "usr/bin/tailscale": fakeTailscale, + "usr/bin/iptables": fakeTailscale, + "usr/bin/ip6tables": fakeTailscale, + "dev/net/tun": []byte(""), + "proc/sys/net/ipv4/ip_forward": []byte("0"), + "proc/sys/net/ipv6/conf/all/forwarding": []byte("0"), + "etc/tailscaled/cap-95.hujson": mustJSON(t, tailscaledConf), + "etc/tailscaled/serve-config.json": mustJSON(t, serveConf), + filepath.Join("etc/tailscaled/", egressservices.KeyEgressServices): mustJSON(t, egressCfg), + filepath.Join("etc/tailscaled/", egressservices.KeyHEPPings): []byte("4"), + } + for path, content := range files { + // Making everything executable is a little weird, but the + // stuff that doesn't need to be executable doesn't care if we + // do make it executable. + if err := os.WriteFile(filepath.Join(d, path), content, 0700); err != nil { + t.Fatal(err) + } + } + + argFile := filepath.Join(d, "args") + runningSockPath := filepath.Join(d, "tmp/tailscaled.sock") + var localAddrPort, healthAddrPort int + for _, p := range []*int{&localAddrPort, &healthAddrPort} { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to open listener: %v", err) + } + if err := ln.Close(); err != nil { + t.Fatalf("Failed to close listener: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + *p = port + } + + return testEnv{ + kube: &kube, + lapi: &lapi, + d: d, + argFile: argFile, + runningSockPath: runningSockPath, + localAddrPort: localAddrPort, + healthAddrPort: healthAddrPort, } } diff --git a/cmd/containerboot/serve.go b/cmd/containerboot/serve.go index 6c22b3eeb..5fa8e580d 100644 --- a/cmd/containerboot/serve.go +++ b/cmd/containerboot/serve.go @@ -17,8 +17,12 @@ import ( "time" "github.com/fsnotify/fsnotify" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn" + "tailscale.com/kube/certs" + "tailscale.com/kube/kubetypes" + klc "tailscale.com/kube/localclient" + "tailscale.com/types/netmap" ) // watchServeConfigChanges watches path for changes, and when it sees one, reads @@ -26,27 +30,34 @@ import ( // applies it to lc. It exits when ctx is canceled. cdChanged is a channel that // is written to when the certDomain changes, causing the serve config to be // re-read and applied. -func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *tailscale.LocalClient) { +func watchServeConfigChanges(ctx context.Context, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *local.Client, kc *kubeClient, cfg *settings) { if certDomainAtomic == nil { - panic("cd must not be nil") + panic("certDomainAtomic must not be nil") } + var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event if w, err := fsnotify.NewWatcher(); err != nil { - log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err) + // Creating a new fsnotify watcher would fail for example if inotify was not able to create a new file descriptor. + // See https://github.com/tailscale/tailscale/issues/15081 + log.Printf("serve proxy: failed to create fsnotify watcher, timer-only mode: %v", err) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() tickChan = ticker.C } else { defer w.Close() - if err := w.Add(filepath.Dir(path)); err != nil { - log.Fatalf("failed to add fsnotify watch: %v", err) + if err := w.Add(filepath.Dir(cfg.ServeConfigPath)); err != nil { + log.Fatalf("serve proxy: failed to add fsnotify watch: %v", err) } eventChan = w.Events } var certDomain string var prevServeConfig *ipn.ServeConfig + var cm *certs.CertManager + if cfg.CertShareMode == "rw" { + cm = certs.NewCertManager(klc.New(lc), log.Printf) + } for { select { case <-ctx.Done(): @@ -59,24 +70,79 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan // k8s handles these mounts. So just re-read the file and apply it // if it's changed. } - if certDomain == "" { - continue - } - sc, err := readServeConfig(path, certDomain) + sc, err := readServeConfig(cfg.ServeConfigPath, certDomain) if err != nil { - log.Fatalf("failed to read serve config: %v", err) + log.Fatalf("serve proxy: failed to read serve config: %v", err) + } + if sc == nil { + log.Printf("serve proxy: no serve config at %q, skipping", cfg.ServeConfigPath) + continue } if prevServeConfig != nil && reflect.DeepEqual(sc, prevServeConfig) { continue } - log.Printf("Applying serve config") - if err := lc.SetServeConfig(ctx, sc); err != nil { - log.Fatalf("failed to set serve config: %v", err) + if err := updateServeConfig(ctx, sc, certDomain, lc); err != nil { + log.Fatalf("serve proxy: error updating serve config: %v", err) + } + if kc != nil && kc.canPatch { + if err := kc.storeHTTPSEndpoint(ctx, certDomain); err != nil { + log.Fatalf("serve proxy: error storing HTTPS endpoint: %v", err) + } } prevServeConfig = sc + if cfg.CertShareMode != "rw" { + continue + } + if err := cm.EnsureCertLoops(ctx, sc); err != nil { + log.Fatalf("serve proxy: error ensuring cert loops: %v", err) + } } } +func certDomainFromNetmap(nm *netmap.NetworkMap) string { + if len(nm.DNS.CertDomains) == 0 { + return "" + } + return nm.DNS.CertDomains[0] +} + +// localClient is a subset of [local.Client] that can be mocked for testing. +type localClient interface { + SetServeConfig(context.Context, *ipn.ServeConfig) error + CertPair(context.Context, string) ([]byte, []byte, error) +} + +func updateServeConfig(ctx context.Context, sc *ipn.ServeConfig, certDomain string, lc localClient) error { + if !isValidHTTPSConfig(certDomain, sc) { + return nil + } + log.Printf("serve proxy: applying serve config") + return lc.SetServeConfig(ctx, sc) +} + +func isValidHTTPSConfig(certDomain string, sc *ipn.ServeConfig) bool { + if certDomain == kubetypes.ValueNoHTTPS && hasHTTPSEndpoint(sc) { + log.Printf( + `serve proxy: this node is configured as a proxy that exposes an HTTPS endpoint to tailnet, + (perhaps a Kubernetes operator Ingress proxy) but it is not able to issue TLS certs, so this will likely not work. + To make it work, ensure that HTTPS is enabled for your tailnet, see https://tailscale.com/kb/1153/enabling-https for more details.`) + return false + } + return true +} + +func hasHTTPSEndpoint(cfg *ipn.ServeConfig) bool { + if cfg == nil { + return false + } + for _, tcpCfg := range cfg.TCP { + if tcpCfg.HTTPS { + return true + } + } + return false +} + // readServeConfig reads the ipn.ServeConfig from path, replacing // ${TS_CERT_DOMAIN} with certDomain. func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) { @@ -85,8 +151,17 @@ func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) { } j, err := os.ReadFile(path) if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, err } + // Serve config can be provided by users as well as the Kubernetes Operator (for its proxies). User-provided + // config could be empty for reasons. + if len(j) == 0 { + log.Printf("serve proxy: serve config file is empty, skipping") + return nil, nil + } j = bytes.ReplaceAll(j, []byte("${TS_CERT_DOMAIN}"), []byte(certDomain)) var sc ipn.ServeConfig if err := json.Unmarshal(j, &sc); err != nil { diff --git a/cmd/containerboot/serve_test.go b/cmd/containerboot/serve_test.go new file mode 100644 index 000000000..fc18f254d --- /dev/null +++ b/cmd/containerboot/serve_test.go @@ -0,0 +1,271 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/kube/kubetypes" +) + +func TestUpdateServeConfig(t *testing.T) { + tests := []struct { + name string + sc *ipn.ServeConfig + certDomain string + wantCall bool + }{ + { + name: "no_https_no_cert_domain", + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + }, + certDomain: kubetypes.ValueNoHTTPS, // tailnet has HTTPS disabled + wantCall: true, // should set serve config as it doesn't have HTTPS endpoints + }, + { + name: "https_with_cert_domain", + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "${TS_CERT_DOMAIN}:443": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://10.0.1.100:8080"}, + }, + }, + }, + }, + certDomain: "test-node.tailnet.ts.net", + wantCall: true, + }, + { + name: "https_without_cert_domain", + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {HTTPS: true}, + }, + }, + certDomain: kubetypes.ValueNoHTTPS, + wantCall: false, // incorrect configuration- should not set serve config + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeLC := &fakeLocalClient{} + err := updateServeConfig(context.Background(), tt.sc, tt.certDomain, fakeLC) + if err != nil { + t.Errorf("updateServeConfig() error = %v", err) + } + if fakeLC.setServeCalled != tt.wantCall { + t.Errorf("SetServeConfig() called = %v, want %v", fakeLC.setServeCalled, tt.wantCall) + } + }) + } +} + +func TestReadServeConfig(t *testing.T) { + tests := []struct { + name string + gotSC string + certDomain string + wantSC *ipn.ServeConfig + wantErr bool + }{ + { + name: "empty_file", + }, + { + name: "valid_config_with_cert_domain_placeholder", + gotSC: `{ + "TCP": { + "443": { + "HTTPS": true + } + }, + "Web": { + "${TS_CERT_DOMAIN}:443": { + "Handlers": { + "/api": { + "Proxy": "https://10.2.3.4/api" + }}}}}`, + certDomain: "example.com", + wantSC: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ipn.HostPort("example.com:443"): { + Handlers: map[string]*ipn.HTTPHandler{ + "/api": { + Proxy: "https://10.2.3.4/api", + }, + }, + }, + }, + }, + }, + { + name: "valid_config_for_http_proxy", + gotSC: `{ + "TCP": { + "80": { + "HTTP": true + } + }}`, + wantSC: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: { + HTTP: true, + }, + }, + }, + }, + { + name: "config_without_cert_domain", + gotSC: `{ + "TCP": { + "443": { + "HTTPS": true + } + }, + "Web": { + "localhost:443": { + "Handlers": { + "/api": { + "Proxy": "https://10.2.3.4/api" + }}}}}`, + certDomain: "", + wantErr: false, + wantSC: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ipn.HostPort("localhost:443"): { + Handlers: map[string]*ipn.HTTPHandler{ + "/api": { + Proxy: "https://10.2.3.4/api", + }, + }, + }, + }, + }, + }, + { + name: "invalid_json", + gotSC: "invalid json", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "serve-config.json") + if err := os.WriteFile(path, []byte(tt.gotSC), 0644); err != nil { + t.Fatal(err) + } + + got, err := readServeConfig(path, tt.certDomain) + if (err != nil) != tt.wantErr { + t.Errorf("readServeConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !cmp.Equal(got, tt.wantSC) { + t.Errorf("readServeConfig() diff (-got +want):\n%s", cmp.Diff(got, tt.wantSC)) + } + }) + } +} + +type fakeLocalClient struct { + *local.Client + setServeCalled bool +} + +func (m *fakeLocalClient) SetServeConfig(ctx context.Context, cfg *ipn.ServeConfig) error { + m.setServeCalled = true + return nil +} + +func (m *fakeLocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return nil, nil, nil +} + +func TestHasHTTPSEndpoint(t *testing.T) { + tests := []struct { + name string + cfg *ipn.ServeConfig + want bool + }{ + { + name: "nil_config", + cfg: nil, + want: false, + }, + { + name: "empty_config", + cfg: &ipn.ServeConfig{}, + want: false, + }, + { + name: "no_https_endpoints", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: { + HTTPS: false, + }, + }, + }, + want: false, + }, + { + name: "has_https_endpoint", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + }, + want: true, + }, + { + name: "mixed_endpoints", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: false}, + 443: {HTTPS: true}, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hasHTTPSEndpoint(tt.cfg) + if got != tt.want { + t.Errorf("hasHTTPSEndpoint() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index d72aefbdf..5a8be9036 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -14,6 +14,7 @@ import ( "os" "path" "strconv" + "strings" "tailscale.com/ipn/conffile" "tailscale.com/kube/kubeclient" @@ -62,9 +63,151 @@ type settings struct { // PodIP is the IP of the Pod if running in Kubernetes. This is used // when setting up rules to proxy cluster traffic to cluster ingress // target. - PodIP string - HealthCheckAddrPort string - EgressSvcsCfgPath string + // Deprecated: use PodIPv4, PodIPv6 instead to support dual stack clusters + PodIP string + PodIPv4 string + PodIPv6 string + PodUID string + HealthCheckAddrPort string + LocalAddrPort string + MetricsEnabled bool + HealthCheckEnabled bool + DebugAddrPort string + EgressProxiesCfgPath string + IngressProxiesCfgPath string + // CertShareMode is set for Kubernetes Pods running cert share mode. + // Possible values are empty (containerboot doesn't run any certs + // logic), 'ro' (for Pods that shold never attempt to issue/renew + // certs) and 'rw' for Pods that should manage the TLS certs shared + // amongst the replicas. + CertShareMode string +} + +func configFromEnv() (*settings, error) { + cfg := &settings{ + AuthKey: defaultEnvs([]string{"TS_AUTHKEY", "TS_AUTH_KEY"}, ""), + Hostname: defaultEnv("TS_HOSTNAME", ""), + Routes: defaultEnvStringPointer("TS_ROUTES"), + ServeConfigPath: defaultEnv("TS_SERVE_CONFIG", ""), + ProxyTargetIP: defaultEnv("TS_DEST_IP", ""), + ProxyTargetDNSName: defaultEnv("TS_EXPERIMENTAL_DEST_DNS_NAME", ""), + TailnetTargetIP: defaultEnv("TS_TAILNET_TARGET_IP", ""), + TailnetTargetFQDN: defaultEnv("TS_TAILNET_TARGET_FQDN", ""), + DaemonExtraArgs: defaultEnv("TS_TAILSCALED_EXTRA_ARGS", ""), + ExtraArgs: defaultEnv("TS_EXTRA_ARGS", ""), + InKubernetes: os.Getenv("KUBERNETES_SERVICE_HOST") != "", + UserspaceMode: defaultBool("TS_USERSPACE", true), + StateDir: defaultEnv("TS_STATE_DIR", ""), + AcceptDNS: defaultEnvBoolPointer("TS_ACCEPT_DNS"), + KubeSecret: defaultEnv("TS_KUBE_SECRET", "tailscale"), + SOCKSProxyAddr: defaultEnv("TS_SOCKS5_SERVER", ""), + HTTPProxyAddr: defaultEnv("TS_OUTBOUND_HTTP_PROXY_LISTEN", ""), + Socket: defaultEnv("TS_SOCKET", "/tmp/tailscaled.sock"), + AuthOnce: defaultBool("TS_AUTH_ONCE", false), + Root: defaultEnv("TS_TEST_ONLY_ROOT", "/"), + TailscaledConfigFilePath: tailscaledConfigFilePath(), + AllowProxyingClusterTrafficViaIngress: defaultBool("EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS", false), + PodIP: defaultEnv("POD_IP", ""), + EnableForwardingOptimizations: defaultBool("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS", false), + HealthCheckAddrPort: defaultEnv("TS_HEALTHCHECK_ADDR_PORT", ""), + LocalAddrPort: defaultEnv("TS_LOCAL_ADDR_PORT", "[::]:9002"), + MetricsEnabled: defaultBool("TS_ENABLE_METRICS", false), + HealthCheckEnabled: defaultBool("TS_ENABLE_HEALTH_CHECK", false), + DebugAddrPort: defaultEnv("TS_DEBUG_ADDR_PORT", ""), + EgressProxiesCfgPath: defaultEnv("TS_EGRESS_PROXIES_CONFIG_PATH", ""), + IngressProxiesCfgPath: defaultEnv("TS_INGRESS_PROXIES_CONFIG_PATH", ""), + PodUID: defaultEnv("POD_UID", ""), + } + podIPs, ok := os.LookupEnv("POD_IPS") + if ok { + ips := strings.Split(podIPs, ",") + if len(ips) > 2 { + return nil, fmt.Errorf("POD_IPs can contain at most 2 IPs, got %d (%v)", len(ips), ips) + } + for _, ip := range ips { + parsed, err := netip.ParseAddr(ip) + if err != nil { + return nil, fmt.Errorf("error parsing IP address %s: %w", ip, err) + } + if parsed.Is4() { + cfg.PodIPv4 = parsed.String() + continue + } + cfg.PodIPv6 = parsed.String() + } + } + // If cert share is enabled, set the replica as read or write. Only 0th + // replica should be able to write. + isInCertShareMode := defaultBool("TS_EXPERIMENTAL_CERT_SHARE", false) + if isInCertShareMode { + cfg.CertShareMode = "ro" + podName := os.Getenv("POD_NAME") + if strings.HasSuffix(podName, "-0") { + cfg.CertShareMode = "rw" + } + } + + // See https://github.com/tailscale/tailscale/issues/16108 for context- we + // do this to preserve the previous behaviour where --accept-dns could be + // set either via TS_ACCEPT_DNS or TS_EXTRA_ARGS. + acceptDNS := cfg.AcceptDNS != nil && *cfg.AcceptDNS + tsExtraArgs, acceptDNSNew := parseAcceptDNS(cfg.ExtraArgs, acceptDNS) + cfg.ExtraArgs = tsExtraArgs + if acceptDNS != acceptDNSNew { + cfg.AcceptDNS = &acceptDNSNew + } + + if err := cfg.validate(); err != nil { + return nil, fmt.Errorf("invalid configuration: %v", err) + } + return cfg, nil +} + +// parseAcceptDNS parses any values for Tailscale --accept-dns flag set via +// TS_ACCEPT_DNS and TS_EXTRA_ARGS env vars. If TS_EXTRA_ARGS contains +// --accept-dns flag, override the acceptDNS value with the one from +// TS_EXTRA_ARGS. +// The value of extraArgs can be empty string or one or more whitespace-separate +// key value pairs for 'tailscale up' command. The value for boolean flags can +// be omitted (default to true). +func parseAcceptDNS(extraArgs string, acceptDNS bool) (string, bool) { + if !strings.Contains(extraArgs, "--accept-dns") { + return extraArgs, acceptDNS + } + // TODO(irbekrm): we should validate that TS_EXTRA_ARGS contains legit + // 'tailscale up' flag values separated by whitespace. + argsArr := strings.Fields(extraArgs) + i := -1 + for key, val := range argsArr { + if strings.HasPrefix(val, "--accept-dns") { + i = key + break + } + } + if i == -1 { + return extraArgs, acceptDNS + } + a := strings.TrimSpace(argsArr[i]) + var acceptDNSFromExtraArgsS string + keyval := strings.Split(a, "=") + if len(keyval) == 2 { + acceptDNSFromExtraArgsS = keyval[1] + } else if len(keyval) == 1 && keyval[0] == "--accept-dns" { + // If the arg is just --accept-dns, we assume it means true. + acceptDNSFromExtraArgsS = "true" + } else { + log.Printf("TS_EXTRA_ARGS contains --accept-dns, but it is not in the expected format --accept-dns=, ignoring it") + return extraArgs, acceptDNS + } + acceptDNSFromExtraArgs, err := strconv.ParseBool(acceptDNSFromExtraArgsS) + if err != nil { + log.Printf("TS_EXTRA_ARGS contains --accept-dns=%q, which is not a valid boolean value, ignoring it", acceptDNSFromExtraArgsS) + return extraArgs, acceptDNS + } + if acceptDNSFromExtraArgs != acceptDNS { + log.Printf("TS_EXTRA_ARGS contains --accept-dns=%v, which overrides TS_ACCEPT_DNS=%v", acceptDNSFromExtraArgs, acceptDNS) + } + return strings.Join(append(argsArr[:i], argsArr[i+1:]...), " "), acceptDNSFromExtraArgs } func (s *settings) validate() error { @@ -114,60 +257,88 @@ func (s *settings) validate() error { return errors.New("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS is not supported in userspace mode") } if s.HealthCheckAddrPort != "" { + log.Printf("[warning] TS_HEALTHCHECK_ADDR_PORT is deprecated and will be removed in 1.82.0. Please use TS_ENABLE_HEALTH_CHECK and optionally TS_LOCAL_ADDR_PORT instead.") if _, err := netip.ParseAddrPort(s.HealthCheckAddrPort); err != nil { - return fmt.Errorf("error parsing TS_HEALTH_CHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) + return fmt.Errorf("error parsing TS_HEALTHCHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) + } + } + if s.localMetricsEnabled() || s.localHealthEnabled() || s.EgressProxiesCfgPath != "" { + if _, err := netip.ParseAddrPort(s.LocalAddrPort); err != nil { + return fmt.Errorf("error parsing TS_LOCAL_ADDR_PORT value %q: %w", s.LocalAddrPort, err) } } + if s.DebugAddrPort != "" { + if _, err := netip.ParseAddrPort(s.DebugAddrPort); err != nil { + return fmt.Errorf("error parsing TS_DEBUG_ADDR_PORT value %q: %w", s.DebugAddrPort, err) + } + } + if s.HealthCheckEnabled && s.HealthCheckAddrPort != "" { + return errors.New("TS_HEALTHCHECK_ADDR_PORT is deprecated and will be removed in 1.82.0, use TS_ENABLE_HEALTH_CHECK and optionally TS_LOCAL_ADDR_PORT") + } + if s.EgressProxiesCfgPath != "" && !(s.InKubernetes && s.KubeSecret != "") { + return errors.New("TS_EGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes") + } + if s.IngressProxiesCfgPath != "" && !(s.InKubernetes && s.KubeSecret != "") { + return errors.New("TS_INGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes") + } return nil } // setupKube is responsible for doing any necessary configuration and checks to // ensure that tailscale state storage and authentication mechanism will work on // Kubernetes. -func (cfg *settings) setupKube(ctx context.Context) error { +func (cfg *settings) setupKube(ctx context.Context, kc *kubeClient) error { if cfg.KubeSecret == "" { return nil } canPatch, canCreate, err := kc.CheckSecretPermissions(ctx, cfg.KubeSecret) if err != nil { - return fmt.Errorf("Some Kubernetes permissions are missing, please check your RBAC configuration: %v", err) + return fmt.Errorf("some Kubernetes permissions are missing, please check your RBAC configuration: %v", err) } cfg.KubernetesCanPatch = canPatch + kc.canPatch = canPatch s, err := kc.GetSecret(ctx, cfg.KubeSecret) - if err != nil && kubeclient.IsNotFoundErr(err) && !canCreate { - return fmt.Errorf("Tailscale state Secret %s does not exist and we don't have permissions to create it. "+ - "If you intend to store tailscale state elsewhere than a Kubernetes Secret, "+ - "you can explicitly set TS_KUBE_SECRET env var to an empty string. "+ - "Else ensure that RBAC is set up that allows the service account associated with this installation to create Secrets.", cfg.KubeSecret) - } else if err != nil && !kubeclient.IsNotFoundErr(err) { - return fmt.Errorf("Getting Tailscale state Secret %s: %v", cfg.KubeSecret, err) - } - - if cfg.AuthKey == "" && !isOneStepConfig(cfg) { - if s == nil { - log.Print("TS_AUTHKEY not provided and kube secret does not exist, login will be interactive if needed.") - return nil + if err != nil { + if !kubeclient.IsNotFoundErr(err) { + return fmt.Errorf("getting Tailscale state Secret %s: %v", cfg.KubeSecret, err) } - keyBytes, _ := s.Data["authkey"] - key := string(keyBytes) - - if key != "" { - // This behavior of pulling authkeys from kube secrets was added - // at the same time as the patch permission, so we can enforce - // that we must be able to patch out the authkey after - // authenticating if you want to use this feature. This avoids - // us having to deal with the case where we might leave behind - // an unnecessary reusable authkey in a secret, like a rake in - // the grass. - if !cfg.KubernetesCanPatch { - return errors.New("authkey found in TS_KUBE_SECRET, but the pod doesn't have patch permissions on the secret to manage the authkey.") - } - cfg.AuthKey = key - } else { - log.Print("No authkey found in kube secret and TS_AUTHKEY not provided, login will be interactive if needed.") + + if !canCreate { + return fmt.Errorf("tailscale state Secret %s does not exist and we don't have permissions to create it. "+ + "If you intend to store tailscale state elsewhere than a Kubernetes Secret, "+ + "you can explicitly set TS_KUBE_SECRET env var to an empty string. "+ + "Else ensure that RBAC is set up that allows the service account associated with this installation to create Secrets.", cfg.KubeSecret) + } + } + + // Return early if we already have an auth key. + if cfg.AuthKey != "" || isOneStepConfig(cfg) { + return nil + } + + if s == nil { + log.Print("TS_AUTHKEY not provided and state Secret does not exist, login will be interactive if needed.") + return nil + } + + keyBytes, _ := s.Data["authkey"] + key := string(keyBytes) + + if key != "" { + // Enforce that we must be able to patch out the authkey after + // authenticating if you want to use this feature. This avoids + // us having to deal with the case where we might leave behind + // an unnecessary reusable authkey in a secret, like a rake in + // the grass. + if !cfg.KubernetesCanPatch { + return errors.New("authkey found in TS_KUBE_SECRET, but the pod doesn't have patch permissions on the Secret to manage the authkey.") } + cfg.AuthKey = key } + + log.Print("No authkey found in state Secret and TS_AUTHKEY not provided, login will be interactive if needed.") + return nil } @@ -199,7 +370,7 @@ func isOneStepConfig(cfg *settings) bool { // as an L3 proxy, proxying to an endpoint provided via one of the config env // vars. func isL3Proxy(cfg *settings) bool { - return cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" || cfg.AllowProxyingClusterTrafficViaIngress || cfg.EgressSvcsCfgPath != "" + return cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" || cfg.AllowProxyingClusterTrafficViaIngress || cfg.EgressProxiesCfgPath != "" || cfg.IngressProxiesCfgPath != "" } // hasKubeStateStore returns true if the state must be stored in a Kubernetes @@ -208,6 +379,18 @@ func hasKubeStateStore(cfg *settings) bool { return cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" } +func (cfg *settings) localMetricsEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.MetricsEnabled +} + +func (cfg *settings) localHealthEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.HealthCheckEnabled +} + +func (cfg *settings) egressSvcsTerminateEPEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.EgressProxiesCfgPath != "" +} + // defaultEnv returns the value of the given envvar name, or defVal if // unset. func defaultEnv(name, defVal string) string { diff --git a/cmd/containerboot/settings_test.go b/cmd/containerboot/settings_test.go new file mode 100644 index 000000000..dbec066c9 --- /dev/null +++ b/cmd/containerboot/settings_test.go @@ -0,0 +1,108 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import "testing" + +func Test_parseAcceptDNS(t *testing.T) { + tests := []struct { + name string + extraArgs string + acceptDNS bool + wantExtraArgs string + wantAcceptDNS bool + }{ + { + name: "false_extra_args_unset", + extraArgs: "", + wantExtraArgs: "", + wantAcceptDNS: false, + }, + { + name: "false_unrelated_args_set", + extraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantAcceptDNS: false, + }, + { + name: "true_extra_args_unset", + extraArgs: "", + acceptDNS: true, + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "true_unrelated_args_set", + acceptDNS: true, + extraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_false", + extraArgs: "--accept-dns=false", + wantExtraArgs: "", + wantAcceptDNS: false, + }, + { + name: "false_extra_args_set_to_true", + extraArgs: "--accept-dns=true", + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "true_extra_args_set_to_false", + extraArgs: "--accept-dns=false", + acceptDNS: true, + wantExtraArgs: "", + wantAcceptDNS: false, + }, + { + name: "true_extra_args_set_to_true", + extraArgs: "--accept-dns=true", + acceptDNS: true, + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_true_implicitly", + extraArgs: "--accept-dns", + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_true_implicitly_with_unrelated_args", + extraArgs: "--accept-dns --accept-routes --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes --advertise-routes=10.0.0.1/32", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_true_implicitly_surrounded_with_unrelated_args", + extraArgs: "--accept-routes --accept-dns --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes --advertise-routes=10.0.0.1/32", + wantAcceptDNS: true, + }, + { + name: "true_extra_args_set_to_false_with_unrelated_args", + extraArgs: "--accept-routes --accept-dns=false", + acceptDNS: true, + wantExtraArgs: "--accept-routes", + wantAcceptDNS: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotExtraArgs, gotAcceptDNS := parseAcceptDNS(tt.extraArgs, tt.acceptDNS) + if gotExtraArgs != tt.wantExtraArgs { + t.Errorf("parseAcceptDNS() gotExtraArgs = %v, want %v", gotExtraArgs, tt.wantExtraArgs) + } + if gotAcceptDNS != tt.wantAcceptDNS { + t.Errorf("parseAcceptDNS() gotAcceptDNS = %v, want %v", gotAcceptDNS, tt.wantAcceptDNS) + } + }) + } +} diff --git a/cmd/containerboot/tailscaled.go b/cmd/containerboot/tailscaled.go index 53fb7e703..f828c5257 100644 --- a/cmd/containerboot/tailscaled.go +++ b/cmd/containerboot/tailscaled.go @@ -13,14 +13,17 @@ import ( "log" "os" "os/exec" + "path/filepath" + "reflect" "strings" "syscall" "time" - "tailscale.com/client/tailscale" + "github.com/fsnotify/fsnotify" + "tailscale.com/client/local" ) -func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, *os.Process, error) { +func startTailscaled(ctx context.Context, cfg *settings) (*local.Client, *os.Process, error) { args := tailscaledArgs(cfg) // tailscaled runs without context, since it needs to persist // beyond the startup timeout in ctx. @@ -30,28 +33,31 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient cmd.SysProcAttr = &syscall.SysProcAttr{ Setpgid: true, } + if cfg.CertShareMode != "" { + cmd.Env = append(os.Environ(), "TS_CERT_SHARE_MODE="+cfg.CertShareMode) + } log.Printf("Starting tailscaled") if err := cmd.Start(); err != nil { - return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err) + return nil, nil, fmt.Errorf("starting tailscaled failed: %w", err) } // Wait for the socket file to appear, otherwise API ops will racily fail. - log.Printf("Waiting for tailscaled socket") + log.Printf("Waiting for tailscaled socket at %s", cfg.Socket) for { if ctx.Err() != nil { - log.Fatalf("Timed out waiting for tailscaled socket") + return nil, nil, errors.New("timed out waiting for tailscaled socket") } _, err := os.Stat(cfg.Socket) if errors.Is(err, fs.ErrNotExist) { time.Sleep(100 * time.Millisecond) continue } else if err != nil { - log.Fatalf("Waiting for tailscaled socket: %v", err) + return nil, nil, fmt.Errorf("error waiting for tailscaled socket: %w", err) } break } - tsClient := &tailscale.LocalClient{ + tsClient := &local.Client{ Socket: cfg.Socket, UseSocketOnly: true, } @@ -90,6 +96,12 @@ func tailscaledArgs(cfg *settings) []string { if cfg.TailscaledConfigFilePath != "" { args = append(args, "--config="+cfg.TailscaledConfigFilePath) } + // Once enough proxy versions have been released for all the supported + // versions to understand this cfg setting, the operator can stop + // setting TS_TAILSCALED_EXTRA_ARGS for the debug flag. + if cfg.DebugAddrPort != "" && !strings.Contains(cfg.DaemonExtraArgs, cfg.DebugAddrPort) { + args = append(args, "--debug="+cfg.DebugAddrPort) + } if cfg.DaemonExtraArgs != "" { args = append(args, strings.Fields(cfg.DaemonExtraArgs)...) } @@ -160,3 +172,75 @@ func tailscaleSet(ctx context.Context, cfg *settings) error { } return nil } + +func watchTailscaledConfigChanges(ctx context.Context, path string, lc *local.Client, errCh chan<- error) { + var ( + tickChan <-chan time.Time + eventChan <-chan fsnotify.Event + errChan <-chan error + tailscaledCfgDir = filepath.Dir(path) + prevTailscaledCfg []byte + ) + if w, err := fsnotify.NewWatcher(); err != nil { + // Creating a new fsnotify watcher would fail for example if inotify was not able to create a new file descriptor. + // See https://github.com/tailscale/tailscale/issues/15081 + log.Printf("tailscaled config watch: failed to create fsnotify watcher, timer-only mode: %v", err) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + tickChan = ticker.C + } else { + defer w.Close() + if err := w.Add(tailscaledCfgDir); err != nil { + errCh <- fmt.Errorf("failed to add fsnotify watch: %w", err) + return + } + eventChan = w.Events + errChan = w.Errors + } + b, err := os.ReadFile(path) + if err != nil { + errCh <- fmt.Errorf("error reading configfile: %w", err) + return + } + prevTailscaledCfg = b + // kubelet mounts Secrets to Pods using a series of symlinks, one of + // which is /..data that Kubernetes recommends consumers to + // use if they need to monitor changes + // https://github.com/kubernetes/kubernetes/blob/v1.28.1/pkg/volume/util/atomic_writer.go#L39-L61 + const kubeletMountedCfg = "..data" + toWatch := filepath.Join(tailscaledCfgDir, kubeletMountedCfg) + for { + select { + case <-ctx.Done(): + return + case err := <-errChan: + errCh <- fmt.Errorf("watcher error: %w", err) + return + case <-tickChan: + case event := <-eventChan: + if event.Name != toWatch { + continue + } + } + b, err := os.ReadFile(path) + if err != nil { + errCh <- fmt.Errorf("error reading configfile: %w", err) + return + } + // For some proxy types the mounted volume also contains tailscaled state and other files. We + // don't want to reload config unnecessarily on unrelated changes to these files. + if reflect.DeepEqual(b, prevTailscaledCfg) { + continue + } + prevTailscaledCfg = b + log.Printf("tailscaled config watch: ensuring that config is up to date") + ok, err := lc.ReloadConfig(ctx) + if err != nil { + errCh <- fmt.Errorf("error reloading tailscaled config: %w", err) + return + } + if ok { + log.Printf("tailscaled config watch: config was reloaded") + } + } +} diff --git a/cmd/derper/ace.go b/cmd/derper/ace.go new file mode 100644 index 000000000..56fb68c33 --- /dev/null +++ b/cmd/derper/ace.go @@ -0,0 +1,77 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO: docs about all this + +package main + +import ( + "errors" + "fmt" + "net" + "net/http" + "strings" + + "tailscale.com/derp/derpserver" + "tailscale.com/net/connectproxy" +) + +// serveConnect handles a CONNECT request for ACE support. +func serveConnect(s *derpserver.Server, w http.ResponseWriter, r *http.Request) { + if !*flagACEEnabled { + http.Error(w, "CONNECT not enabled", http.StatusForbidden) + return + } + if r.TLS == nil { + // This should already be enforced by the caller of serveConnect, but + // double check. + http.Error(w, "CONNECT requires TLS", http.StatusForbidden) + return + } + + ch := &connectproxy.Handler{ + Check: func(hostPort string) error { + host, port, err := net.SplitHostPort(hostPort) + if err != nil { + return err + } + if port != "443" && port != "80" { + // There are only two types of CONNECT requests the client makes + // via ACE: requests for /key (port 443) and requests to upgrade + // to the bidirectional ts2021 Noise protocol. + // + // The ts2021 layer can bootstrap over port 80 (http) or port + // 443 (https). + // + // Without ACE, we prefer port 80 to avoid unnecessary double + // encryption. But enough places require TLS+port 443 that we do + // support that double encryption path as a fallback. + // + // But ACE adds its own TLS layer (ACE is always CONNECT over + // https). If we don't permit port 80 here as a target, we'd + // have three layers of encryption (TLS + TLS + Noise) which is + // even more silly than two. + // + // So we permit port 80 such that we can only have two layers of + // encryption, varying by the request type: + // + // 1. TLS from client to ACE proxy (CONNECT) + // 2a. TLS from ACE proxy to https://controlplane.tailscale.com/key (port 443) + // 2b. ts2021 Noise from ACE proxy to http://controlplane.tailscale.com/ts2021 (port 80) + // + // But nothing's stopping the client from doing its ts2021 + // upgrade over https anyway and having three layers of + // encryption. But we can at least permit the client to do a + // "CONNECT controlplane.tailscale.com:80 HTTP/1.1" if it wants. + return fmt.Errorf("only ports 443 and 80 are allowed") + } + // TODO(bradfitz): make policy configurable from flags and/or come + // from local tailscaled nodeAttrs + if !strings.HasSuffix(host, ".tailscale.com") || strings.Contains(host, "derp") { + return errors.New("bad host") + } + return nil + }, + } + ch.ServeHTTP(w, r) +} diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index d151bc2b0..9b99103ab 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -20,10 +20,10 @@ import ( ) func BenchmarkHandleBootstrapDNS(b *testing.B) { - tstest.Replace(b, bootstrapDNS, "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com") + tstest.Replace(b, bootstrapDNS, "log.tailscale.com,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com") refreshBootstrapDNS() w := new(bitbucketResponseWriter) - req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil) + req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.com"), nil) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(b *testing.PB) { @@ -63,7 +63,7 @@ func TestUnpublishedDNS(t *testing.T) { nettest.SkipIfNoNetwork(t) const published = "login.tailscale.com" - const unpublished = "log.tailscale.io" + const unpublished = "log.tailscale.com" prev1, prev2 := *bootstrapDNS, *unpublishedDNS *bootstrapDNS = published @@ -119,18 +119,18 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { unpublishedDNSCache.Store(&dnsEntryMap{ IPs: map[string][]net.IP{ - "log.tailscale.io": {}, + "log.tailscale.com": {}, "controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}, }, Percent: map[string]float64{ - "log.tailscale.io": 1.0, + "log.tailscale.com": 1.0, "controlplane.tailscale.com": 1.0, }, }) t.Run("CacheMiss", func(t *testing.T) { // One domain in map but empty, one not in map at all - for _, q := range []string{"log.tailscale.io", "login.tailscale.com"} { + for _, q := range []string{"log.tailscale.com", "login.tailscale.com"} { resetMetrics() ips := getBootstrapDNS(t, q) diff --git a/cmd/derper/cert.go b/cmd/derper/cert.go index db84aa515..b95755c64 100644 --- a/cmd/derper/cert.go +++ b/cmd/derper/cert.go @@ -4,15 +4,28 @@ package main import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" "errors" "fmt" + "log" + "math/big" + "net" "net/http" + "os" "path/filepath" "regexp" + "time" "golang.org/x/crypto/acme/autocert" + "tailscale.com/tailcfg" ) var unsafeHostnameCharacters = regexp.MustCompile(`[^a-zA-Z0-9-\.]`) @@ -53,8 +66,9 @@ func certProviderByCertMode(mode, dir, hostname string) (certProvider, error) { } type manualCertManager struct { - cert *tls.Certificate - hostname string + cert *tls.Certificate + hostname string // hostname or IP address of server + noHostname bool // whether hostname is an IP address } // NewManualCertManager returns a cert provider which read certificate by given hostname on create. @@ -63,8 +77,18 @@ func NewManualCertManager(certdir, hostname string) (certProvider, error) { crtPath := filepath.Join(certdir, keyname+".crt") keyPath := filepath.Join(certdir, keyname+".key") cert, err := tls.LoadX509KeyPair(crtPath, keyPath) + hostnameIP := net.ParseIP(hostname) // or nil if hostname isn't an IP address if err != nil { - return nil, fmt.Errorf("can not load x509 key pair for hostname %q: %w", keyname, err) + // If the hostname is an IP address, automatically create a + // self-signed certificate for it. + var certp *tls.Certificate + if os.IsNotExist(err) && hostnameIP != nil { + certp, err = createSelfSignedIPCert(crtPath, keyPath, hostname) + } + if err != nil { + return nil, fmt.Errorf("can not load x509 key pair for hostname %q: %w", keyname, err) + } + cert = *certp } // ensure hostname matches with the certificate x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) @@ -74,7 +98,23 @@ func NewManualCertManager(certdir, hostname string) (certProvider, error) { if err := x509Cert.VerifyHostname(hostname); err != nil { return nil, fmt.Errorf("cert invalid for hostname %q: %w", hostname, err) } - return &manualCertManager{cert: &cert, hostname: hostname}, nil + if hostnameIP != nil { + // If the hostname is an IP address, print out information on how to + // confgure this in the derpmap. + dn := &tailcfg.DERPNode{ + Name: "custom", + RegionID: 900, + HostName: hostname, + CertName: fmt.Sprintf("sha256-raw:%-02x", sha256.Sum256(x509Cert.Raw)), + } + dnJSON, _ := json.Marshal(dn) + log.Printf("Using self-signed certificate for IP address %q. Configure it in DERPMap using: (https://tailscale.com/s/custom-derp)\n %s", hostname, dnJSON) + } + return &manualCertManager{ + cert: &cert, + hostname: hostname, + noHostname: net.ParseIP(hostname) != nil, + }, nil } func (m *manualCertManager) TLSConfig() *tls.Config { @@ -88,7 +128,7 @@ func (m *manualCertManager) TLSConfig() *tls.Config { } func (m *manualCertManager) getCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - if hi.ServerName != m.hostname { + if hi.ServerName != m.hostname && !m.noHostname { return nil, fmt.Errorf("cert mismatch with hostname: %q", hi.ServerName) } @@ -103,3 +143,69 @@ func (m *manualCertManager) getCertificate(hi *tls.ClientHelloInfo) (*tls.Certif func (m *manualCertManager) HTTPHandler(fallback http.Handler) http.Handler { return fallback } + +func createSelfSignedIPCert(crtPath, keyPath, ipStr string) (*tls.Certificate, error) { + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("invalid IP address: %s", ipStr) + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate EC private key: %v", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %v", err) + } + + now := time.Now() + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: ipStr, + }, + NotBefore: now, + NotAfter: now.AddDate(1, 0, 0), // expires in 1 year; a bit over that is rejected by macOS etc + + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Set the IP as a SAN. + template.IPAddresses = []net.IP{ip} + + // Create the self-signed certificate. + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + keyBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, fmt.Errorf("unable to marshal EC private key: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}) + + if err := os.MkdirAll(filepath.Dir(crtPath), 0700); err != nil { + return nil, fmt.Errorf("failed to create directory for certificate: %v", err) + } + if err := os.WriteFile(crtPath, certPEM, 0644); err != nil { + return nil, fmt.Errorf("failed to write certificate to %s: %v", crtPath, err) + } + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return nil, fmt.Errorf("failed to write key to %s: %v", keyPath, err) + } + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("failed to create tls.Certificate: %v", err) + } + return &tlsCert, nil +} diff --git a/cmd/derper/cert_test.go b/cmd/derper/cert_test.go new file mode 100644 index 000000000..c8a3229e9 --- /dev/null +++ b/cmd/derper/cert_test.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" + "tailscale.com/net/netmon" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// Verify that in --certmode=manual mode, we can use a bare IP address +// as the --hostname and that GetCertificate will return it. +func TestCertIP(t *testing.T) { + dir := t.TempDir() + const hostname = "1.2.3.4" + + priv, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatal(err) + } + ip := net.ParseIP(hostname) + if ip == nil { + t.Fatalf("invalid IP address %q", hostname) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Tailscale Test Corp"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * 24 * time.Hour), + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, + } + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + certOut, err := os.Create(filepath.Join(dir, hostname+".crt")) + if err != nil { + t.Fatal(err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + t.Fatalf("Failed to write data to cert.pem: %v", err) + } + if err := certOut.Close(); err != nil { + t.Fatalf("Error closing cert.pem: %v", err) + } + + keyOut, err := os.OpenFile(filepath.Join(dir, hostname+".key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + t.Fatal(err) + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + t.Fatalf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + t.Fatalf("Failed to write data to key.pem: %v", err) + } + if err := keyOut.Close(); err != nil { + t.Fatalf("Error closing key.pem: %v", err) + } + + cp, err := certProviderByCertMode("manual", dir, hostname) + if err != nil { + t.Fatal(err) + } + back, err := cp.TLSConfig().GetCertificate(&tls.ClientHelloInfo{ + ServerName: "", // no SNI + }) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if back == nil { + t.Fatalf("GetCertificate returned nil") + } +} + +// Test that we can dial a raw IP without using a hostname and without a WebPKI +// cert, validating the cert against the signature of the cert in the DERP map's +// DERPNode. +// +// See https://github.com/tailscale/tailscale/issues/11776. +func TestPinnedCertRawIP(t *testing.T) { + td := t.TempDir() + cp, err := NewManualCertManager(td, "127.0.0.1") + if err != nil { + t.Fatalf("NewManualCertManager: %v", err) + } + + cert, err := cp.TLSConfig().GetCertificate(&tls.ClientHelloInfo{ + ServerName: "127.0.0.1", + }) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + ds := derpserver.New(key.NewNode(), t.Logf) + + derpHandler := derpserver.Handler(ds) + mux := http.NewServeMux() + mux.Handle("/derp", derpHandler) + + var hs http.Server + hs.Handler = mux + hs.TLSConfig = cp.TLSConfig() + ds.ModifyTLSConfigToAddMetaCert(hs.TLSConfig) + go hs.ServeTLS(ln, "", "") + + lnPort := ln.Addr().(*net.TCPAddr).Port + + reg := &tailcfg.DERPRegion{ + RegionID: 900, + Nodes: []*tailcfg.DERPNode{ + { + RegionID: 900, + HostName: "127.0.0.1", + CertName: fmt.Sprintf("sha256-raw:%-02x", sha256.Sum256(cert.Leaf.Raw)), + DERPPort: lnPort, + }, + }, + } + + netMon := netmon.NewStatic() + dc := derphttp.NewRegionClient(key.NewNode(), t.Logf, netMon, func() *tailcfg.DERPRegion { + return reg + }) + defer dc.Close() + + _, connClose, _, err := dc.DialRegionTLS(context.Background(), reg) + if err != nil { + t.Fatalf("DialRegionTLS: %v", err) + } + defer connClose.Close() +} diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index eb9ba1619..0a75ac43e 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -2,16 +2,12 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa filippo.io/edwards25519 from github.com/hdevalence/ed25519consensus filippo.io/edwards25519/field from filippo.io/edwards25519 - W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ - W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate - W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus đŸ’Ŗ github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus github.com/coder/websocket from tailscale.com/cmd/derper+ github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket - L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw W đŸ’Ŗ github.com/dblohm7/wingoes from tailscale.com/util/winutil github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/go-json-experiment/json from tailscale.com/types/opt+ @@ -21,27 +17,18 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ github.com/golang/groupcache/lru from tailscale.com/net/dnscache - L github.com/google/nftables from tailscale.com/util/linuxfw - L đŸ’Ŗ github.com/google/nftables/alignedbuff from github.com/google/nftables/xt - L đŸ’Ŗ github.com/google/nftables/binaryutil from github.com/google/nftables+ - L github.com/google/nftables/expr from github.com/google/nftables+ - L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ - L github.com/google/nftables/xt from github.com/google/nftables/expr+ - github.com/google/uuid from tailscale.com/util/fastuuid github.com/hdevalence/ed25519consensus from tailscale.com/tka - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink - L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ + L đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ - L github.com/mdlayher/netlink/nltest from github.com/google/nftables L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket + github.com/munnerz/goautoneg from github.com/prometheus/common/expfmt đŸ’Ŗ github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ - github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs @@ -51,11 +38,10 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ - L đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/util/linuxfw - L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink - L github.com/vishvananda/netns from github.com/tailscale/netlink+ + github.com/tailscale/setec/client/setec from tailscale.com/cmd/derper + github.com/tailscale/setec/types/api from github.com/tailscale/setec/client/setec github.com/x448/float16 from github.com/fxamacker/cbor/v2 - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from tailscale.com/net/tsaddr W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/netmon+ google.golang.org/protobuf/encoding/protodelim from github.com/prometheus/common/expfmt @@ -77,6 +63,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa đŸ’Ŗ google.golang.org/protobuf/internal/impl from google.golang.org/protobuf/internal/filetype+ google.golang.org/protobuf/internal/order from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/pragma from google.golang.org/protobuf/encoding/prototext+ + đŸ’Ŗ google.golang.org/protobuf/internal/protolazy from google.golang.org/protobuf/internal/impl+ google.golang.org/protobuf/internal/set from google.golang.org/protobuf/encoding/prototext đŸ’Ŗ google.golang.org/protobuf/internal/strs from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/version from google.golang.org/protobuf/runtime/protoimpl @@ -87,88 +74,98 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa google.golang.org/protobuf/runtime/protoimpl from github.com/prometheus/client_model/go+ google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ tailscale.com from tailscale.com/version - tailscale.com/atomicfile from tailscale.com/cmd/derper+ - tailscale.com/client/tailscale from tailscale.com/derp - tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/cmd/derper+ + tailscale.com/client/local from tailscale.com/derp/derpserver + tailscale.com/client/tailscale/apitype from tailscale.com/client/local tailscale.com/derp from tailscale.com/cmd/derper+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/cmd/derper - tailscale.com/disco from tailscale.com/derp - tailscale.com/drive from tailscale.com/client/tailscale+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/derp/derpserver from tailscale.com/cmd/derper + tailscale.com/disco from tailscale.com/derp/derpserver + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/feature from tailscale.com/tsweb+ + tailscale.com/feature/buildfeatures from tailscale.com/feature+ tailscale.com/health from tailscale.com/net/tlsdial+ tailscale.com/hostinfo from tailscale.com/net/netmon+ - tailscale.com/ipn from tailscale.com/client/tailscale - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ + tailscale.com/ipn from tailscale.com/client/local + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/metrics from tailscale.com/cmd/derper+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial + tailscale.com/net/connectproxy from tailscale.com/cmd/derper tailscale.com/net/dnscache from tailscale.com/derp/derphttp tailscale.com/net/ktimeout from tailscale.com/cmd/derper tailscale.com/net/netaddr from tailscale.com/ipn+ tailscale.com/net/netknob from tailscale.com/net/netns đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/derp/derphttp+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp - tailscale.com/net/netutil from tailscale.com/client/tailscale + tailscale.com/net/netutil from tailscale.com/client/local + tailscale.com/net/netx from tailscale.com/net/dnscache+ tailscale.com/net/sockstats from tailscale.com/derp/derphttp tailscale.com/net/stun from tailscale.com/net/stunserver tailscale.com/net/stunserver from tailscale.com/cmd/derper - L tailscale.com/net/tcpinfo from tailscale.com/derp + L tailscale.com/net/tcpinfo from tailscale.com/derp/derpserver tailscale.com/net/tlsdial from tailscale.com/derp/derphttp + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/ipn+ - đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ - tailscale.com/net/wsconn from tailscale.com/cmd/derper+ - tailscale.com/paths from tailscale.com/client/tailscale - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale + tailscale.com/net/udprelay/status from tailscale.com/client/local + tailscale.com/net/wsconn from tailscale.com/cmd/derper + tailscale.com/paths from tailscale.com/client/local + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local tailscale.com/syncs from tailscale.com/cmd/derper+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ - tailscale.com/tka from tailscale.com/client/tailscale+ - W tailscale.com/tsconst from tailscale.com/net/netmon+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tka from tailscale.com/client/local+ + tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tstime from tailscale.com/derp+ tailscale.com/tstime/mono from tailscale.com/tstime/rate - tailscale.com/tstime/rate from tailscale.com/derp - tailscale.com/tsweb from tailscale.com/cmd/derper - tailscale.com/tsweb/promvarz from tailscale.com/tsweb + tailscale.com/tstime/rate from tailscale.com/derp/derpserver + tailscale.com/tsweb from tailscale.com/cmd/derper+ + tailscale.com/tsweb/promvarz from tailscale.com/cmd/derper tailscale.com/tsweb/varz from tailscale.com/tsweb+ + tailscale.com/types/appctype from tailscale.com/client/local tailscale.com/types/dnstype from tailscale.com/tailcfg+ tailscale.com/types/empty from tailscale.com/ipn tailscale.com/types/ipproto from tailscale.com/tailcfg+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/version+ tailscale.com/types/logger from tailscale.com/cmd/derper+ tailscale.com/types/netmap from tailscale.com/ipn - tailscale.com/types/opt from tailscale.com/client/tailscale+ - tailscale.com/types/persist from tailscale.com/ipn + tailscale.com/types/opt from tailscale.com/envknob+ + tailscale.com/types/persist from tailscale.com/ipn+ tailscale.com/types/preftype from tailscale.com/ipn tailscale.com/types/ptr from tailscale.com/hostinfo+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ - tailscale.com/types/tkatype from tailscale.com/client/tailscale+ + tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/ipn+ - tailscale.com/util/cibuild from tailscale.com/health - tailscale.com/util/clientmetric from tailscale.com/net/netmon+ + tailscale.com/util/cibuild from tailscale.com/health+ + tailscale.com/util/clientmetric from tailscale.com/net/netmon tailscale.com/util/cloudenv from tailscale.com/hostinfo+ - W tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy tailscale.com/util/ctxkey from tailscale.com/tsweb+ đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/hostinfo+ - tailscale.com/util/fastuuid from tailscale.com/tsweb + tailscale.com/util/eventbus from tailscale.com/net/netmon+ đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash - tailscale.com/util/httpm from tailscale.com/client/tailscale - tailscale.com/util/lineread from tailscale.com/hostinfo+ - L tailscale.com/util/linuxfw from tailscale.com/net/netns + tailscale.com/util/lineiter from tailscale.com/hostinfo+ tailscale.com/util/mak from tailscale.com/health+ - tailscale.com/util/multierr from tailscale.com/health+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto - tailscale.com/util/set from tailscale.com/derp+ + tailscale.com/util/rands from tailscale.com/tsweb + tailscale.com/util/set from tailscale.com/derp/derpserver+ tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/cmd/derper+ - tailscale.com/util/syspolicy from tailscale.com/ipn tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/pkey from tailscale.com/ipn+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/ipn + tailscale.com/util/syspolicy/ptype from tailscale.com/util/syspolicy/policyclient+ + tailscale.com/util/syspolicy/setting from tailscale.com/client/local + tailscale.com/util/testenv from tailscale.com/net/bakedroots+ tailscale.com/util/usermetric from tailscale.com/health tailscale.com/util/vizerror from tailscale.com/tailcfg+ W đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/hostinfo+ W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ - tailscale.com/version from tailscale.com/derp+ + tailscale.com/version from tailscale.com/cmd/derper+ tailscale.com/version/distro from tailscale.com/envknob+ tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap golang.org/x/crypto/acme from golang.org/x/crypto/acme/autocert @@ -176,29 +173,24 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from tailscale.com/tka - golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 - golang.org/x/crypto/chacha20poly1305 from crypto/tls+ - golang.org/x/crypto/cryptobyte from crypto/ecdsa+ - golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/internal/alias from golang.org/x/crypto/nacl/secretbox + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/nacl/secretbox golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ - W golang.org/x/exp/constraints from tailscale.com/util/winutil + golang.org/x/exp/constraints from tailscale.com/util/winutil+ golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting L golang.org/x/net/bpf from github.com/mdlayher/netlink+ - golang.org/x/net/dns/dnsmessage from net+ - golang.org/x/net/http/httpguts from net/http - golang.org/x/net/http/httpproxy from net/http+ - golang.org/x/net/http2/hpack from net/http - golang.org/x/net/idna from golang.org/x/crypto/acme/autocert+ + golang.org/x/net/dns/dnsmessage from tailscale.com/net/dnscache + golang.org/x/net/idna from golang.org/x/crypto/acme/autocert + golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/proxy from tailscale.com/net/netns - D golang.org/x/net/route from net+ + D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ - golang.org/x/sys/cpu from github.com/josharian/native+ - LD golang.org/x/sys/unix from github.com/google/nftables+ + golang.org/x/sync/singleflight from github.com/tailscale/setec/client/setec + golang.org/x/sys/cpu from golang.org/x/crypto/argon2+ + LD golang.org/x/sys/unix from github.com/jsimonetti/rtnetlink/internal/unix+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/svc from golang.org/x/sys/windows/svc/mgr+ @@ -208,6 +200,22 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ golang.org/x/text/unicode/norm from golang.org/x/net/idna golang.org/x/time/rate from tailscale.com/cmd/derper+ + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna bufio from compress/flate+ bytes from bufio+ cmp from slices+ @@ -216,7 +224,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509 @@ -224,20 +232,62 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls + crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from golang.org/x/crypto/acme+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - database/sql/driver from github.com/google/uuid - embed from crypto/internal/nistec+ + embed from google.golang.org/protobuf/internal/editiondefaults+ encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ @@ -248,7 +298,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa encoding/pem from crypto/tls+ errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ - flag from tailscale.com/cmd/derper + flag from tailscale.com/cmd/derper+ fmt from compress/flate+ go/token from google.golang.org/protobuf/internal/strs hash from crypto+ @@ -256,9 +306,57 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from net/http/pprof+ + html/template from tailscale.com/cmd/derper+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + D internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + L internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from encoding/asn1 + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ io from bufio+ io/fs from crypto/x509+ - io/ioutil from github.com/mitchellh/go-ps+ + L io/ioutil from github.com/mitchellh/go-ps iter from maps+ log from expvar+ log/internal from log @@ -267,7 +365,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/mdlayher/netlink+ - math/rand/v2 from tailscale.com/util/fastuuid+ + math/rand/v2 from crypto/ecdsa+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart @@ -275,19 +373,22 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa net/http from expvar+ net/http/httptrace from net/http+ net/http/internal from net/http + net/http/internal/ascii from net/http + net/http/internal/httpcommon from net/http net/http/pprof from tailscale.com/tsweb net/netip from go4.org/netipx+ - net/textproto from golang.org/x/net/http/httpguts+ + net/textproto from github.com/coder/websocket+ net/url from crypto/x509+ - os from crypto/rand+ - os/exec from github.com/coreos/go-iptables/iptables+ + os from crypto/internal/sysrand+ + os/exec from golang.zx2c4.com/wireguard/windows/tunnel/winipcfg+ os/signal from tailscale.com/cmd/derper W os/user from tailscale.com/util/winutil path from github.com/prometheus/client_golang/prometheus/internal+ path/filepath from crypto/x509+ reflect from crypto/x509+ - regexp from github.com/coreos/go-iptables/iptables+ + regexp from github.com/prometheus/client_golang/prometheus/internal+ regexp/syntax from regexp + runtime from crypto/internal/fips140+ runtime/debug from github.com/prometheus/client_golang/prometheus+ runtime/metrics from github.com/prometheus/client_golang/prometheus+ runtime/pprof from net/http/pprof @@ -296,12 +397,17 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa sort from compress/flate+ strconv from compress/flate+ strings from bufio+ + W structs from internal/syscall/windows sync from compress/flate+ sync/atomic from context+ - syscall from crypto/rand+ + syscall from crypto/internal/sysrand+ text/tabwriter from runtime/pprof + text/template from html/template + text/template/parse from html/template+ time from compress/gzip+ unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 80c9dc44f..f177986a5 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -19,6 +19,7 @@ import ( "expvar" "flag" "fmt" + "html/template" "io" "log" "math" @@ -26,6 +27,7 @@ import ( "net/http" "os" "os/signal" + "path" "path/filepath" "regexp" "runtime" @@ -35,10 +37,10 @@ import ( "syscall" "time" + "github.com/tailscale/setec/client/setec" "golang.org/x/time/rate" "tailscale.com/atomicfile" - "tailscale.com/derp" - "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" "tailscale.com/metrics" "tailscale.com/net/ktimeout" "tailscale.com/net/stunserver" @@ -46,6 +48,9 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/version" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( @@ -57,18 +62,25 @@ var ( configPath = flag.String("c", "", "config file path") certMode = flag.String("certmode", "letsencrypt", "mode for getting a cert. possible options: manual, letsencrypt") certDir = flag.String("certdir", tsweb.DefaultCertDir("derper-certs"), "directory to store LetsEncrypt certs, if addr's port is :443") - hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443") + hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443. When --certmode=manual, this can be an IP address to avoid SNI checks") runSTUN = flag.Bool("stun", true, "whether to run a STUN server. It will bind to the same IP (if any) as the --addr flag value.") runDERP = flag.Bool("derp", true, "whether to run a DERP server. The only reason to set this false is if you're decommissioning a server but want to keep its bootstrap DNS functionality still running.") + flagHome = flag.String("home", "", "what to serve at the root path. It may be left empty (the default, for a default homepage), \"blank\" for a blank page, or a URL to redirect to") - meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It should contain some hex string; whitespace is trimmed.") - meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") + meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. It must be 64 lowercase hexadecimal characters; whitespace is trimmed.") + meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list. If an entry contains a slash, the second part names a hostname to be used when dialing the target.") + secretsURL = flag.String("secrets-url", "", "SETEC server URL for secrets retrieval of mesh key") + secretPrefix = flag.String("secrets-path-prefix", "prod/derp", "setec path prefix for \""+setecMeshKeyName+"\" secret for DERP mesh key") + secretsCacheDir = flag.String("secrets-cache-dir", defaultSetecCacheDir(), "directory to cache setec secrets in (required if --secrets-url is set)") bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list. If an entry contains a slash, the second part names a DNS record to poll for its TXT record with a `0` to `100` value for rollout percentage.") + verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") verifyClientURL = flag.String("verify-client-url", "", "if non-empty, an admission controller URL for permitting client connections; see tailcfg.DERPAdmitClientRequest") verifyFailOpen = flag.Bool("verify-client-url-fail-open", true, "whether we fail open if --verify-client-url is unreachable") + socket = flag.String("socket", "", "optional alternate path to tailscaled socket (only relevant when using --verify-clients)") + acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") @@ -76,6 +88,11 @@ var ( tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time") // tcpUserTimeout is intentionally short, so that hung connections are cleaned up promptly. DERPs should be nearby users. tcpUserTimeout = flag.Duration("tcp-user-timeout", 15*time.Second, "TCP user timeout") + // tcpWriteTimeout is the timeout for writing to client TCP connections. It does not apply to mesh connections. + tcpWriteTimeout = flag.Duration("tcp-write-timeout", derpserver.DefaultTCPWiteTimeout, "TCP write timeout; 0 results in no timeout being set on writes") + + // ACE + flagACEEnabled = flag.Bool("ace", false, "whether to enable embedded ACE server [experimental + in-development as of 2025-09-12; not yet documented]") ) var ( @@ -83,6 +100,9 @@ var ( tlsActiveVersion = &metrics.LabelMap{Label: "version"} ) +const setecMeshKeyName = "meshkey" +const meshKeyEnvVar = "TAILSCALE_DERPER_MESH_KEY" + func init() { expvar.Publish("derper_tls_request_version", tlsRequestVersion) expvar.Publish("gauge_derper_tls_active_version", tlsActiveVersion) @@ -168,31 +188,74 @@ func main() { serveTLS := tsweb.IsProd443(*addr) || *certMode == "manual" - s := derp.NewServer(cfg.PrivateKey, log.Printf) + s := derpserver.New(cfg.PrivateKey, log.Printf) s.SetVerifyClient(*verifyClients) + s.SetTailscaledSocketPath(*socket) s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) + s.SetTCPWriteTimeout(*tcpWriteTimeout) - if *meshPSKFile != "" { - b, err := os.ReadFile(*meshPSKFile) + var meshKey string + if *dev { + meshKey = os.Getenv(meshKeyEnvVar) + if meshKey == "" { + log.Printf("No mesh key specified for dev via %s\n", meshKeyEnvVar) + } else { + log.Printf("Set mesh key from %s\n", meshKeyEnvVar) + } + } else if *secretsURL != "" { + meshKeySecret := path.Join(*secretPrefix, setecMeshKeyName) + fc, err := setec.NewFileCache(*secretsCacheDir) if err != nil { - log.Fatal(err) + log.Fatalf("NewFileCache: %v", err) } - key := strings.TrimSpace(string(b)) - if matched, _ := regexp.MatchString(`(?i)^[0-9a-f]{64,}$`, key); !matched { - log.Fatalf("key in %s must contain 64+ hex digits", *meshPSKFile) + log.Printf("Setting up setec store from %q", *secretsURL) + st, err := setec.NewStore(ctx, + setec.StoreConfig{ + Client: setec.Client{Server: *secretsURL}, + Secrets: []string{ + meshKeySecret, + }, + Cache: fc, + }) + if err != nil { + log.Fatalf("NewStore: %v", err) + } + meshKey = st.Secret(meshKeySecret).GetString() + log.Println("Got mesh key from setec store") + st.Close() + } else if *meshPSKFile != "" { + b, err := setec.StaticFile(*meshPSKFile) + if err != nil { + log.Fatalf("StaticFile failed to get key: %v", err) } - s.SetMeshKey(key) - log.Printf("DERP mesh key configured") + log.Println("Got mesh key from static file") + meshKey = b.GetString() + } + + if meshKey == "" && *dev { + log.Printf("No mesh key configured for --dev mode") + } else if meshKey == "" { + log.Printf("No mesh key configured") + } else if err := s.SetMeshKey(meshKey); err != nil { + log.Fatalf("invalid mesh key: %v", err) + } else { + log.Println("DERP mesh key configured") } + if err := startMesh(s); err != nil { log.Fatalf("startMesh: %v", err) } expvar.Publish("derp", s.ExpVar()) + handleHome, ok := getHomeHandler(*flagHome) + if !ok { + log.Fatalf("unknown --home value %q", *flagHome) + } + mux := http.NewServeMux() if *runDERP { - derpHandler := derphttp.Handler(s) + derpHandler := derpserver.Handler(s) derpHandler = addWebSocketSupport(s, derpHandler) mux.Handle("/derp", derpHandler) } else { @@ -203,41 +266,20 @@ func main() { // These two endpoints are the same. Different versions of the clients // have assumes different paths over time so we support both. - mux.HandleFunc("/derp/probe", derphttp.ProbeHandler) - mux.HandleFunc("/derp/latency-check", derphttp.ProbeHandler) + mux.HandleFunc("/derp/probe", derpserver.ProbeHandler) + mux.HandleFunc("/derp/latency-check", derpserver.ProbeHandler) go refreshBootstrapDNSLoop() mux.HandleFunc("/bootstrap-dns", tsweb.BrowserHeaderHandlerFunc(handleBootstrapDNS)) mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tsweb.AddBrowserHeaders(w) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) - io.WriteString(w, ` -

DERP

-

- This is a Tailscale DERP server. -

-

- Documentation: -

- -`) - if !*runDERP { - io.WriteString(w, `

Status: disabled

`) - } - if tsweb.AllowDebugAccess(r) { - io.WriteString(w, "

Debug info at /debug/.

\n") - } + handleHome.ServeHTTP(w, r) })) mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tsweb.AddBrowserHeaders(w) io.WriteString(w, "User-agent: *\nDisallow: /\n") })) - mux.Handle("/generate_204", http.HandlerFunc(derphttp.ServeNoContent)) + mux.Handle("/generate_204", http.HandlerFunc(derpserver.ServeNoContent)) debug := tsweb.Debugger(mux) debug.KV("TLS hostname", *hostname) debug.KV("Mesh key", s.HasMeshKey()) @@ -273,6 +315,9 @@ func main() { Control: ktimeout.UserTimeout(*tcpUserTimeout), KeepAlive: *tcpKeepAlive, } + // As of 2025-02-19, MPTCP does not support TCP_USER_TIMEOUT socket option + // set in ktimeout.UserTimeout above. + lc.SetMultipathTCP(false) quietLogger := log.New(logger.HTTPServerLogFilter{Inner: log.Printf}, "", 0) httpsrv := &http.Server{ @@ -330,6 +375,11 @@ func main() { tlsRequestVersion.Add(label, 1) tlsActiveVersion.Add(label, 1) defer tlsActiveVersion.Add(label, -1) + + if r.Method == "CONNECT" { + serveConnect(s, w, r) + return + } } mux.ServeHTTP(w, r) @@ -337,7 +387,7 @@ func main() { if *httpPort > -1 { go func() { port80mux := http.NewServeMux() - port80mux.HandleFunc("/generate_204", derphttp.ServeNoContent) + port80mux.HandleFunc("/generate_204", derpserver.ServeNoContent) port80mux.Handle("/", certManager.HTTPHandler(tsweb.Port80Handler{Main: mux})) port80srv := &http.Server{ Addr: net.JoinHostPort(listenHost, fmt.Sprintf("%d", *httpPort)), @@ -387,6 +437,10 @@ func prodAutocertHostPolicy(_ context.Context, host string) error { return errors.New("invalid hostname") } +func defaultSetecCacheDir() string { + return filepath.Join(os.Getenv("HOME"), ".cache", "derper-secrets") +} + func defaultMeshPSKFile() string { try := []string{ "/home/derp/keys/derp-mesh.key", @@ -427,32 +481,32 @@ func newRateLimitedListener(ln net.Listener, limit rate.Limit, burst int) *rateL return &rateLimitedListener{Listener: ln, lim: rate.NewLimiter(limit, burst)} } -func (l *rateLimitedListener) ExpVar() expvar.Var { +func (ln *rateLimitedListener) ExpVar() expvar.Var { m := new(metrics.Set) - m.Set("counter_accepted_connections", &l.numAccepts) - m.Set("counter_rejected_connections", &l.numRejects) + m.Set("counter_accepted_connections", &ln.numAccepts) + m.Set("counter_rejected_connections", &ln.numRejects) return m } var errLimitedConn = errors.New("cannot accept connection; rate limited") -func (l *rateLimitedListener) Accept() (net.Conn, error) { +func (ln *rateLimitedListener) Accept() (net.Conn, error) { // Even under a rate limited situation, we accept the connection immediately // and close it, rather than being slow at accepting new connections. // This provides two benefits: 1) it signals to the client that something // is going on on the server, and 2) it prevents new connections from // piling up and occupying resources in the OS kernel. // The client will retry as needing (with backoffs in place). - cn, err := l.Listener.Accept() + cn, err := ln.Listener.Accept() if err != nil { return nil, err } - if !l.lim.Allow() { - l.numRejects.Add(1) + if !ln.lim.Allow() { + ln.numRejects.Add(1) cn.Close() return nil, errLimitedConn } - l.numAccepts.Add(1) + ln.numAccepts.Add(1) return cn, nil } @@ -468,3 +522,84 @@ func init() { return 0 })) } + +type templateData struct { + ShowAbuseInfo bool + Disabled bool + AllowDebug bool +} + +// homePageTemplate renders the home page using [templateData]. +var homePageTemplate = template.Must(template.New("home").Parse(` +

DERP

+

+ This is a Tailscale DERP server. +

+ +

+ It provides STUN, interactive connectivity establishment, and relaying of end-to-end encrypted traffic + for Tailscale clients. +

+ +{{if .ShowAbuseInfo }} +

+ If you suspect abuse, please contact security@tailscale.com. +

+{{end}} + +

+ Documentation: +

+ + + +{{if .Disabled}} +

Status: disabled

+{{end}} + +{{if .AllowDebug}} +

Debug info at /debug/.

+{{end}} + + +`)) + +// getHomeHandler returns a handler for the home page based on a flag string +// as documented on the --home flag. +func getHomeHandler(val string) (_ http.Handler, ok bool) { + if val == "" { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(200) + err := homePageTemplate.Execute(w, templateData{ + ShowAbuseInfo: validProdHostname.MatchString(*hostname), + Disabled: !*runDERP, + AllowDebug: tsweb.AllowDebugAccess(r), + }) + if err != nil { + if r.Context().Err() == nil { + log.Printf("homePageTemplate.Execute: %v", err) + } + return + } + }), true + } + if val == "blank" { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(200) + }), true + } + if strings.HasPrefix(val, "http://") || strings.HasPrefix(val, "https://") { + return http.RedirectHandler(val, http.StatusFound), true + } + return nil, false +} diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index 553a78f9f..d27f8cb20 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -4,13 +4,14 @@ package main import ( + "bytes" "context" "net/http" "net/http/httptest" "strings" "testing" - "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" "tailscale.com/tstest/deptest" ) @@ -77,20 +78,20 @@ func TestNoContent(t *testing.T) { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest("GET", "https://localhost/generate_204", nil) if tt.input != "" { - req.Header.Set(derphttp.NoContentChallengeHeader, tt.input) + req.Header.Set(derpserver.NoContentChallengeHeader, tt.input) } w := httptest.NewRecorder() - derphttp.ServeNoContent(w, req) + derpserver.ServeNoContent(w, req) resp := w.Result() if tt.want == "" { - if h, found := resp.Header[derphttp.NoContentResponseHeader]; found { + if h, found := resp.Header[derpserver.NoContentResponseHeader]; found { t.Errorf("got %+v; expected no response header", h) } return } - if got := resp.Header.Get(derphttp.NoContentResponseHeader); got != tt.want { + if got := resp.Header.Get(derpserver.NoContentResponseHeader); got != tt.want { t.Errorf("got %q; want %q", got, tt.want) } }) @@ -107,6 +108,33 @@ func TestDeps(t *testing.T) { "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", "tailscale.com/net/packet": "not needed in derper", "github.com/gaissmai/bart": "not needed in derper", + "database/sql/driver": "not needed in derper", // previously came in via github.com/google/uuid }, }.Check(t) } + +func TestTemplate(t *testing.T) { + buf := &bytes.Buffer{} + err := homePageTemplate.Execute(buf, templateData{ + ShowAbuseInfo: true, + Disabled: true, + AllowDebug: true, + }) + if err != nil { + t.Fatal(err) + } + + str := buf.String() + if !strings.Contains(str, "If you suspect abuse") { + t.Error("Output is missing abuse mailto") + } + if !strings.Contains(str, "Tailscale Security Policies") { + t.Error("Output is missing Tailscale Security Policies link") + } + if !strings.Contains(str, "Status:") { + t.Error("Output is missing disabled status") + } + if !strings.Contains(str, "Debug info") { + t.Error("Output is missing debug info") + } +} diff --git a/cmd/derper/mesh.go b/cmd/derper/mesh.go index ee1807f00..909b5f2ca 100644 --- a/cmd/derper/mesh.go +++ b/cmd/derper/mesh.go @@ -10,30 +10,43 @@ import ( "log" "net" "strings" - "time" "tailscale.com/derp" "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" "tailscale.com/net/netmon" "tailscale.com/types/logger" ) -func startMesh(s *derp.Server) error { +func startMesh(s *derpserver.Server) error { if *meshWith == "" { return nil } if !s.HasMeshKey() { return errors.New("--mesh-with requires --mesh-psk-file") } - for _, host := range strings.Split(*meshWith, ",") { - if err := startMeshWithHost(s, host); err != nil { + for _, hostTuple := range strings.Split(*meshWith, ",") { + if err := startMeshWithHost(s, hostTuple); err != nil { return err } } return nil } -func startMeshWithHost(s *derp.Server, host string) error { +func startMeshWithHost(s *derpserver.Server, hostTuple string) error { + var host string + var dialHost string + hostParts := strings.Split(hostTuple, "/") + if len(hostParts) > 2 { + return fmt.Errorf("too many components in host tuple %q", hostTuple) + } + host = hostParts[0] + if len(hostParts) == 2 { + dialHost = hostParts[1] + } else { + dialHost = hostParts[0] + } + logf := logger.WithPrefix(log.Printf, fmt.Sprintf("mesh(%q): ", host)) netMon := netmon.NewStatic() // good enough for cmd/derper; no need for netns fanciness c, err := derphttp.NewClient(s.PrivateKey(), "https://"+host+"/derp", logf, netMon) @@ -43,34 +56,24 @@ func startMeshWithHost(s *derp.Server, host string) error { c.MeshKey = s.MeshKey() c.WatchConnectionChanges = true - // For meshed peers within a region, connect via VPC addresses. - c.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } + logf("will dial %q for %q", dialHost, host) + if dialHost != host { var d net.Dialer - var r net.Resolver - if base, ok := strings.CutSuffix(host, ".tailscale.com"); ok && port == "443" { - subCtx, cancel := context.WithTimeout(ctx, 2*time.Second) - defer cancel() - vpcHost := base + "-vpc.tailscale.com" - ips, _ := r.LookupIP(subCtx, "ip", vpcHost) - if len(ips) > 0 { - vpcAddr := net.JoinHostPort(ips[0].String(), port) - c, err := d.DialContext(subCtx, network, vpcAddr) - if err == nil { - log.Printf("connected to %v (%v) instead of %v", vpcHost, ips[0], base) - return c, nil - } - log.Printf("failed to connect to %v (%v): %v; trying non-VPC route", vpcHost, ips[0], err) + c.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + logf("failed to split %q: %v", addr, err) + return nil, err } - } - return d.DialContext(ctx, network, addr) - }) + dialAddr := net.JoinHostPort(dialHost, port) + logf("dialing %q instead of %q", dialAddr, addr) + return d.DialContext(ctx, network, dialAddr) + }) + } add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key, c) } remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer, c) } - go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove) + notifyError := func(err error) {} + go c.RunWatchConnectionLoop(context.Background(), s.PublicKey(), logf, add, remove, notifyError) return nil } diff --git a/cmd/derper/websocket.go b/cmd/derper/websocket.go index 05f40deb8..82fd30bed 100644 --- a/cmd/derper/websocket.go +++ b/cmd/derper/websocket.go @@ -11,14 +11,14 @@ import ( "strings" "github.com/coder/websocket" - "tailscale.com/derp" + "tailscale.com/derp/derpserver" "tailscale.com/net/wsconn" ) var counterWebSocketAccepts = expvar.NewInt("derp_websocket_accepts") // addWebSocketSupport returns a Handle wrapping base that adds WebSocket server support. -func addWebSocketSupport(s *derp.Server, base http.Handler) http.Handler { +func addWebSocketSupport(s *derpserver.Server, base http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { up := strings.ToLower(r.Header.Get("Upgrade")) diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go index 1d0ec32c3..5d2179b51 100644 --- a/cmd/derpprobe/derpprobe.go +++ b/cmd/derpprobe/derpprobe.go @@ -5,30 +5,55 @@ package main import ( + "context" "flag" "fmt" "log" "net/http" + "os" + "path" + "path/filepath" "sort" "time" + "github.com/tailscale/setec/client/setec" "tailscale.com/prober" "tailscale.com/tsweb" + "tailscale.com/types/key" "tailscale.com/version" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) +const meshKeyEnvVar = "TAILSCALE_DERPER_MESH_KEY" +const setecMeshKeyName = "meshkey" + +func defaultSetecCacheDir() string { + return filepath.Join(os.Getenv("HOME"), ".cache", "derper-secrets") +} + var ( - derpMapURL = flag.String("derp-map", "https://login.tailscale.com/derpmap/default", "URL to DERP map (https:// or file://) or 'local' to use the local tailscaled's DERP map") - versionFlag = flag.Bool("version", false, "print version and exit") - listen = flag.String("listen", ":8030", "HTTP listen address") - probeOnce = flag.Bool("once", false, "probe once and print results, then exit; ignores the listen flag") - spread = flag.Bool("spread", true, "whether to spread probing over time") - interval = flag.Duration("interval", 15*time.Second, "probe interval") - meshInterval = flag.Duration("mesh-interval", 15*time.Second, "mesh probe interval") - stunInterval = flag.Duration("stun-interval", 15*time.Second, "STUN probe interval") - tlsInterval = flag.Duration("tls-interval", 15*time.Second, "TLS probe interval") - bwInterval = flag.Duration("bw-interval", 0, "bandwidth probe interval (0 = no bandwidth probing)") - bwSize = flag.Int64("bw-probe-size-bytes", 1_000_000, "bandwidth probe size") + dev = flag.Bool("dev", false, "run in localhost development mode") + derpMapURL = flag.String("derp-map", "https://login.tailscale.com/derpmap/default", "URL to DERP map (https:// or file://) or 'local' to use the local tailscaled's DERP map") + versionFlag = flag.Bool("version", false, "print version and exit") + listen = flag.String("listen", ":8030", "HTTP listen address") + probeOnce = flag.Bool("once", false, "probe once and print results, then exit; ignores the listen flag") + spread = flag.Bool("spread", true, "whether to spread probing over time") + interval = flag.Duration("interval", 15*time.Second, "probe interval") + meshInterval = flag.Duration("mesh-interval", 15*time.Second, "mesh probe interval") + stunInterval = flag.Duration("stun-interval", 15*time.Second, "STUN probe interval") + tlsInterval = flag.Duration("tls-interval", 15*time.Second, "TLS probe interval") + bwInterval = flag.Duration("bw-interval", 0, "bandwidth probe interval (0 = no bandwidth probing)") + bwSize = flag.Int64("bw-probe-size-bytes", 1_000_000, "bandwidth probe size") + bwTUNIPv4Address = flag.String("bw-tun-ipv4-addr", "", "if specified, bandwidth probes will be performed over a TUN device at this address in order to exercise TCP-in-TCP in similar fashion to TCP over Tailscale via DERP; we will use a /30 subnet including this IP address") + qdPacketsPerSecond = flag.Int("qd-packets-per-second", 0, "if greater than 0, queuing delay will be measured continuously using 260 byte packets (approximate size of a CallMeMaybe packet) sent at this rate per second") + qdPacketTimeout = flag.Duration("qd-packet-timeout", 5*time.Second, "queuing delay packets arriving after this period of time from being sent are treated like dropped packets and don't count toward queuing delay timings") + regionCodeOrID = flag.String("region-code", "", "probe only this region (e.g. 'lax' or '17'); if left blank, all regions will be probed") + meshPSKFile = flag.String("mesh-psk-file", "", "if non-empty, path to file containing the mesh pre-shared key file. It must be 64 lowercase hexadecimal characters; whitespace is trimmed.") + secretsURL = flag.String("secrets-url", "", "SETEC server URL for secrets retrieval of mesh key") + secretPrefix = flag.String("secrets-path-prefix", "prod/derp", fmt.Sprintf("setec path prefix for \"%s\" secret for DERP mesh key", setecMeshKeyName)) + secretsCacheDir = flag.String("secrets-cache-dir", defaultSetecCacheDir(), "directory to cache setec secrets in (required if --secrets-url is set)") ) func main() { @@ -39,13 +64,22 @@ func main() { } p := prober.New().WithSpread(*spread).WithOnce(*probeOnce).WithMetricNamespace("derpprobe") + meshKey, err := getMeshKey() + if err != nil { + log.Fatalf("failed to get mesh key: %v", err) + } opts := []prober.DERPOpt{ prober.WithMeshProbing(*meshInterval), prober.WithSTUNProbing(*stunInterval), prober.WithTLSProbing(*tlsInterval), + prober.WithQueuingDelayProbing(*qdPacketsPerSecond, *qdPacketTimeout), + prober.WithMeshKey(meshKey), } if *bwInterval > 0 { - opts = append(opts, prober.WithBandwidthProbing(*bwInterval, *bwSize)) + opts = append(opts, prober.WithBandwidthProbing(*bwInterval, *bwSize, *bwTUNIPv4Address)) + } + if *regionCodeOrID != "" { + opts = append(opts, prober.WithRegionCodeOrID(*regionCodeOrID)) } dp, err := prober.DERP(p, *derpMapURL, opts...) if err != nil { @@ -64,21 +98,77 @@ func main() { for _, s := range st.bad { log.Printf("bad: %s", s) } + if len(st.bad) > 0 { + os.Exit(1) + } return } mux := http.NewServeMux() d := tsweb.Debugger(mux) d.Handle("probe-run", "Run a probe", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{Logf: log.Printf})) + d.Handle("probe-all", "Run all configured probes", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunAllHandler), tsweb.HandlerOptions{Logf: log.Printf})) mux.Handle("/", tsweb.StdHandler(p.StatusHandler( prober.WithTitle("DERP Prober"), prober.WithPageLink("Prober metrics", "/debug/varz"), prober.WithProbeLink("Run Probe", "/debug/probe-run?name={{.Name}}"), ), tsweb.HandlerOptions{Logf: log.Printf})) + mux.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok\n")) + })) log.Printf("Listening on %s", *listen) log.Fatal(http.ListenAndServe(*listen, mux)) } +func getMeshKey() (key.DERPMesh, error) { + var meshKey string + + if *dev { + meshKey = os.Getenv(meshKeyEnvVar) + if meshKey == "" { + log.Printf("No mesh key specified for dev via %s\n", meshKeyEnvVar) + } else { + log.Printf("Set mesh key from %s\n", meshKeyEnvVar) + } + } else if *secretsURL != "" { + meshKeySecret := path.Join(*secretPrefix, setecMeshKeyName) + fc, err := setec.NewFileCache(*secretsCacheDir) + if err != nil { + log.Fatalf("NewFileCache: %v", err) + } + log.Printf("Setting up setec store from %q", *secretsURL) + st, err := setec.NewStore(context.Background(), + setec.StoreConfig{ + Client: setec.Client{Server: *secretsURL}, + Secrets: []string{ + meshKeySecret, + }, + Cache: fc, + }) + if err != nil { + log.Fatalf("NewStore: %v", err) + } + meshKey = st.Secret(meshKeySecret).GetString() + log.Println("Got mesh key from setec store") + st.Close() + } else if *meshPSKFile != "" { + b, err := setec.StaticFile(*meshPSKFile) + if err != nil { + log.Fatalf("StaticFile failed to get key: %v", err) + } + log.Println("Got mesh key from static file") + meshKey = b.GetString() + } + if meshKey == "" { + log.Printf("No mesh key found, mesh key is empty") + return key.DERPMesh{}, nil + } + + return key.ParseDERPMesh(meshKey) +} + type overallStatus struct { good, bad []string } @@ -97,7 +187,7 @@ func getOverallStatus(p *prober.Prober) (o overallStatus) { // Do not show probes that have not finished yet. continue } - if i.Result { + if i.Status == prober.ProbeStatusSucceeded { o.addGoodf("%s: %s", p, i.Latency) } else { o.addBadf("%s: %s", p, i.Error) diff --git a/cmd/dist/dist.go b/cmd/dist/dist.go index 05f5bbfb2..c7406298d 100644 --- a/cmd/dist/dist.go +++ b/cmd/dist/dist.go @@ -5,11 +5,13 @@ package main import ( + "cmp" "context" "errors" "flag" "log" "os" + "slices" "tailscale.com/release/dist" "tailscale.com/release/dist/cli" @@ -19,9 +21,13 @@ import ( ) var ( - synologyPackageCenter bool - qnapPrivateKeyPath string - qnapCertificatePath string + synologyPackageCenter bool + gcloudCredentialsBase64 string + gcloudProject string + gcloudKeyring string + qnapKeyName string + qnapCertificateBase64 string + qnapCertificateIntermediariesBase64 string ) func getTargets() ([]dist.Target, error) { @@ -42,10 +48,11 @@ func getTargets() ([]dist.Target, error) { // To build for package center, run // ./tool/go run ./cmd/dist build --synology-package-center synology ret = append(ret, synology.Targets(synologyPackageCenter, nil)...) - if (qnapPrivateKeyPath == "") != (qnapCertificatePath == "") { - return nil, errors.New("both --qnap-private-key-path and --qnap-certificate-path must be set") + qnapSigningArgs := []string{gcloudCredentialsBase64, gcloudProject, gcloudKeyring, qnapKeyName, qnapCertificateBase64, qnapCertificateIntermediariesBase64} + if cmp.Or(qnapSigningArgs...) != "" && slices.Contains(qnapSigningArgs, "") { + return nil, errors.New("all of --gcloud-credentials, --gcloud-project, --gcloud-keyring, --qnap-key-name, --qnap-certificate and --qnap-certificate-intermediaries must be set") } - ret = append(ret, qnap.Targets(qnapPrivateKeyPath, qnapCertificatePath)...) + ret = append(ret, qnap.Targets(gcloudCredentialsBase64, gcloudProject, gcloudKeyring, qnapKeyName, qnapCertificateBase64, qnapCertificateIntermediariesBase64)...) return ret, nil } @@ -54,8 +61,12 @@ func main() { for _, subcmd := range cmd.Subcommands { if subcmd.Name == "build" { subcmd.FlagSet.BoolVar(&synologyPackageCenter, "synology-package-center", false, "build synology packages with extra metadata for the official package center") - subcmd.FlagSet.StringVar(&qnapPrivateKeyPath, "qnap-private-key-path", "", "sign qnap packages with given key (must also provide --qnap-certificate-path)") - subcmd.FlagSet.StringVar(&qnapCertificatePath, "qnap-certificate-path", "", "sign qnap packages with given certificate (must also provide --qnap-private-key-path)") + subcmd.FlagSet.StringVar(&gcloudCredentialsBase64, "gcloud-credentials", "", "base64 encoded GCP credentials (used when signing QNAP builds)") + subcmd.FlagSet.StringVar(&gcloudProject, "gcloud-project", "", "name of project in GCP KMS (used when signing QNAP builds)") + subcmd.FlagSet.StringVar(&gcloudKeyring, "gcloud-keyring", "", "path to keyring in GCP KMS (used when signing QNAP builds)") + subcmd.FlagSet.StringVar(&qnapKeyName, "qnap-key-name", "", "name of GCP key to use when signing QNAP builds") + subcmd.FlagSet.StringVar(&qnapCertificateBase64, "qnap-certificate", "", "base64 encoded certificate to use when signing QNAP builds") + subcmd.FlagSet.StringVar(&qnapCertificateIntermediariesBase64, "qnap-certificate-intermediaries", "", "base64 encoded intermediary certificate to use when signing QNAP builds") } } diff --git a/cmd/featuretags/featuretags.go b/cmd/featuretags/featuretags.go new file mode 100644 index 000000000..8c8a2ceaf --- /dev/null +++ b/cmd/featuretags/featuretags.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The featuretags command helps other build tools select Tailscale's Go build +// tags to use. +package main + +import ( + "flag" + "fmt" + "log" + "maps" + "slices" + "strings" + + "tailscale.com/feature/featuretags" + "tailscale.com/util/set" +) + +var ( + min = flag.Bool("min", false, "remove all features not mentioned in --add") + remove = flag.String("remove", "", "a comma-separated list of features to remove from the build. (without the 'ts_omit_' prefix)") + add = flag.String("add", "", "a comma-separated list of features or tags to add, if --min is used.") + list = flag.Bool("list", false, "if true, list all known features and what they do") +) + +func main() { + flag.Parse() + + features := featuretags.Features + + if *list { + for _, f := range slices.Sorted(maps.Keys(features)) { + fmt.Printf("%20s: %s\n", f, features[f].Desc) + } + return + } + + var keep = map[featuretags.FeatureTag]bool{} + for t := range strings.SplitSeq(*add, ",") { + if t != "" { + for ft := range featuretags.Requires(featuretags.FeatureTag(t)) { + keep[ft] = true + } + } + } + var tags []string + if keep[featuretags.CLI] { + tags = append(tags, "ts_include_cli") + } + if *min { + for _, f := range slices.Sorted(maps.Keys(features)) { + if f == "" { + continue + } + if !keep[f] && f.IsOmittable() { + tags = append(tags, f.OmitTag()) + } + } + } + removeSet := set.Set[featuretags.FeatureTag]{} + for v := range strings.SplitSeq(*remove, ",") { + if v == "" { + continue + } + f := featuretags.FeatureTag(v) + if _, ok := features[f]; !ok { + log.Fatalf("unknown feature %q in --remove", f) + } + removeSet.Add(f) + } + for ft := range removeSet { + set := featuretags.RequiredBy(ft) + for dependent := range set { + if !removeSet.Contains(dependent) { + log.Fatalf("cannot remove %q without also removing %q, which depends on it", ft, dependent) + } + } + tags = append(tags, ft.OmitTag()) + } + slices.Sort(tags) + tags = slices.Compact(tags) + if len(tags) != 0 { + fmt.Println(strings.Join(tags, ",")) + } +} diff --git a/cmd/get-authkey/main.go b/cmd/get-authkey/main.go index d8030252c..ec7ab5d2c 100644 --- a/cmd/get-authkey/main.go +++ b/cmd/get-authkey/main.go @@ -16,14 +16,10 @@ import ( "strings" "golang.org/x/oauth2/clientcredentials" - "tailscale.com/client/tailscale" + "tailscale.com/internal/client/tailscale" ) func main() { - // Required to use our client API. We're fine with the instability since the - // client lives in the same repo as this code. - tailscale.I_Acknowledge_This_API_Is_Unstable = true - reusable := flag.Bool("reusable", false, "allocate a reusable authkey") ephemeral := flag.Bool("ephemeral", false, "allocate an ephemeral authkey") preauth := flag.Bool("preauth", true, "set the authkey as pre-authorized") @@ -46,11 +42,11 @@ func main() { ClientID: clientID, ClientSecret: clientSecret, TokenURL: baseURL + "/api/v2/oauth/token", - Scopes: []string{"device"}, } ctx := context.Background() tsClient := tailscale.NewClient("-", nil) + tsClient.UserAgent = "tailscale-get-authkey" tsClient.HTTPClient = credentials.Client(ctx) tsClient.BaseURL = baseURL diff --git a/cmd/gitops-pusher/gitops-pusher.go b/cmd/gitops-pusher/gitops-pusher.go index c33937ef2..690ca2870 100644 --- a/cmd/gitops-pusher/gitops-pusher.go +++ b/cmd/gitops-pusher/gitops-pusher.go @@ -13,6 +13,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "log" "net/http" "os" @@ -58,8 +59,8 @@ func apply(cache *Cache, client *http.Client, tailnet, apiKey string) func(conte } if cache.PrevETag == "" { - log.Println("no previous etag found, assuming local file is correct and recording that") - cache.PrevETag = localEtag + log.Println("no previous etag found, assuming the latest control etag") + cache.PrevETag = controlEtag } log.Printf("control: %s", controlEtag) @@ -105,8 +106,8 @@ func test(cache *Cache, client *http.Client, tailnet, apiKey string) func(contex } if cache.PrevETag == "" { - log.Println("no previous etag found, assuming local file is correct and recording that") - cache.PrevETag = localEtag + log.Println("no previous etag found, assuming the latest control etag") + cache.PrevETag = controlEtag } log.Printf("control: %s", controlEtag) @@ -148,8 +149,8 @@ func getChecksums(cache *Cache, client *http.Client, tailnet, apiKey string) fun } if cache.PrevETag == "" { - log.Println("no previous etag found, assuming local file is correct and recording that") - cache.PrevETag = Shuck(localEtag) + log.Println("no previous etag found, assuming control etag") + cache.PrevETag = Shuck(controlEtag) } log.Printf("control: %s", controlEtag) @@ -405,7 +406,8 @@ func getACLETag(ctx context.Context, client *http.Client, tailnet, apiKey string got := resp.StatusCode want := http.StatusOK if got != want { - return "", fmt.Errorf("wanted HTTP status code %d but got %d", want, got) + errorDetails, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("wanted HTTP status code %d but got %d: %#q", want, got, string(errorDetails)) } return Shuck(resp.Header.Get("ETag")), nil diff --git a/cmd/gitops-pusher/gitops-pusher_test.go b/cmd/gitops-pusher/gitops-pusher_test.go index b050761d9..e08b06c9c 100644 --- a/cmd/gitops-pusher/gitops-pusher_test.go +++ b/cmd/gitops-pusher/gitops-pusher_test.go @@ -1,5 +1,6 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause + package main import ( diff --git a/cmd/hello/hello.go b/cmd/hello/hello.go index e4b0ca827..fa116b28b 100644 --- a/cmd/hello/hello.go +++ b/cmd/hello/hello.go @@ -18,8 +18,9 @@ import ( "strings" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" ) var ( @@ -31,7 +32,7 @@ var ( //go:embed hello.tmpl.html var embeddedTemplate string -var localClient tailscale.LocalClient +var localClient local.Client func main() { flag.Parse() @@ -134,6 +135,10 @@ func tailscaleIP(who *apitype.WhoIsResponse) string { if who == nil { return "" } + vals, err := tailcfg.UnmarshalNodeCapJSON[string](who.Node.CapMap, tailcfg.NodeAttrNativeIPV4) + if err == nil && len(vals) > 0 { + return vals[0] + } for _, nodeIP := range who.Node.Addresses { if nodeIP.Addr().Is4() && nodeIP.IsSingleIP() { return nodeIP.Addr().String() diff --git a/cmd/jsonimports/format.go b/cmd/jsonimports/format.go new file mode 100644 index 000000000..6dbd17558 --- /dev/null +++ b/cmd/jsonimports/format.go @@ -0,0 +1,175 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "go/ast" + "go/format" + "go/parser" + "go/token" + "go/types" + "path" + "slices" + "strconv" + "strings" + + "tailscale.com/util/must" +) + +// mustFormatFile formats a Go source file and adjust "json" imports. +// It panics if there are any parsing errors. +// +// - "encoding/json" is imported under the name "jsonv1" or "jsonv1std" +// - "encoding/json/v2" is rewritten to import "github.com/go-json-experiment/json" instead +// - "encoding/json/jsontext" is rewritten to import "github.com/go-json-experiment/json/jsontext" instead +// - "github.com/go-json-experiment/json" is imported under the name "jsonv2" +// - "github.com/go-json-experiment/json/v1" is imported under the name "jsonv1" +// +// If no changes to the file is made, it returns input. +func mustFormatFile(in []byte) (out []byte) { + fset := token.NewFileSet() + f := must.Get(parser.ParseFile(fset, "", in, parser.ParseComments)) + + // Check for the existence of "json" imports. + jsonImports := make(map[string][]*ast.ImportSpec) + for _, imp := range f.Imports { + switch pkgPath := must.Get(strconv.Unquote(imp.Path.Value)); pkgPath { + case + "encoding/json", + "encoding/json/v2", + "encoding/json/jsontext", + "github.com/go-json-experiment/json", + "github.com/go-json-experiment/json/v1", + "github.com/go-json-experiment/json/jsontext": + jsonImports[pkgPath] = append(jsonImports[pkgPath], imp) + } + } + if len(jsonImports) == 0 { + return in + } + + // Best-effort local type-check of the file + // to resolve local declarations to detect shadowed variables. + typeInfo := &types.Info{Uses: make(map[*ast.Ident]types.Object)} + (&types.Config{ + Error: func(err error) {}, + }).Check("", fset, []*ast.File{f}, typeInfo) + + // Rewrite imports to instead use "github.com/go-json-experiment/json". + // This ensures that code continues to build even if + // goexperiment.jsonv2 is *not* specified. + // As of https://github.com/go-json-experiment/json/pull/186, + // imports to "github.com/go-json-experiment/json" are identical + // to the standard library if built with goexperiment.jsonv2. + for fromPath, toPath := range map[string]string{ + "encoding/json/v2": "github.com/go-json-experiment/json", + "encoding/json/jsontext": "github.com/go-json-experiment/json/jsontext", + } { + for _, imp := range jsonImports[fromPath] { + imp.Path.Value = strconv.Quote(toPath) + jsonImports[toPath] = append(jsonImports[toPath], imp) + } + delete(jsonImports, fromPath) + } + + // While in a transitory state, where both v1 and v2 json imports + // may exist in our codebase, always explicitly import with + // either jsonv1 or jsonv2 in the package name to avoid ambiguities + // when looking at a particular Marshal or Unmarshal call site. + renames := make(map[string]string) // mapping of old names to new names + deletes := make(map[*ast.ImportSpec]bool) // set of imports to delete + for pkgPath, imps := range jsonImports { + var newName string + switch pkgPath { + case "encoding/json": + newName = "jsonv1" + // If "github.com/go-json-experiment/json/v1" is also imported, + // then use jsonv1std for "encoding/json" to avoid a conflict. + if len(jsonImports["github.com/go-json-experiment/json/v1"]) > 0 { + newName += "std" + } + case "github.com/go-json-experiment/json": + newName = "jsonv2" + case "github.com/go-json-experiment/json/v1": + newName = "jsonv1" + } + + // Rename the import if different than expected. + if oldName := importName(imps[0]); oldName != newName && newName != "" { + renames[oldName] = newName + pos := imps[0].Pos() // preserve original positioning + imps[0].Name = ast.NewIdent(newName) + imps[0].Name.NamePos = pos + } + + // For all redundant imports, use the first imported name. + for _, imp := range imps[1:] { + renames[importName(imp)] = importName(imps[0]) + deletes[imp] = true + } + } + if len(deletes) > 0 { + f.Imports = slices.DeleteFunc(f.Imports, func(imp *ast.ImportSpec) bool { + return deletes[imp] + }) + for _, decl := range f.Decls { + if genDecl, ok := decl.(*ast.GenDecl); ok && genDecl.Tok == token.IMPORT { + genDecl.Specs = slices.DeleteFunc(genDecl.Specs, func(spec ast.Spec) bool { + return deletes[spec.(*ast.ImportSpec)] + }) + } + } + } + if len(renames) > 0 { + ast.Walk(astVisitor(func(n ast.Node) bool { + if sel, ok := n.(*ast.SelectorExpr); ok { + if id, ok := sel.X.(*ast.Ident); ok { + // Just because the selector looks like "json.Marshal" + // does not mean that it is referencing the "json" package. + // There could be a local "json" declaration that shadows + // the package import. Check partial type information + // to see if there was a local declaration. + if obj, ok := typeInfo.Uses[id]; ok { + if _, ok := obj.(*types.PkgName); !ok { + return true + } + } + + if newName, ok := renames[id.String()]; ok { + id.Name = newName + } + } + } + return true + }), f) + } + + bb := new(bytes.Buffer) + must.Do(format.Node(bb, fset, f)) + return must.Get(format.Source(bb.Bytes())) +} + +// importName is the local package name used for an import. +// If no explicit local name is used, then it uses string parsing +// to derive the package name from the path, relying on the convention +// that the package name is the base name of the package path. +func importName(imp *ast.ImportSpec) string { + if imp.Name != nil { + return imp.Name.String() + } + pkgPath, _ := strconv.Unquote(imp.Path.Value) + pkgPath = strings.TrimRight(pkgPath, "/v0123456789") // exclude version directories + return path.Base(pkgPath) +} + +// astVisitor is a function that implements [ast.Visitor]. +type astVisitor func(ast.Node) bool + +func (f astVisitor) Visit(node ast.Node) ast.Visitor { + if !f(node) { + return nil + } + return f +} diff --git a/cmd/jsonimports/format_test.go b/cmd/jsonimports/format_test.go new file mode 100644 index 000000000..28654eb45 --- /dev/null +++ b/cmd/jsonimports/format_test.go @@ -0,0 +1,162 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "go/format" + "testing" + + "tailscale.com/util/must" + "tailscale.com/util/safediff" +) + +func TestFormatFile(t *testing.T) { + tests := []struct{ in, want string }{{ + in: `package foobar + + import ( + "encoding/json" + jsonv2exp "github.com/go-json-experiment/json" + ) + + func main() { + json.Marshal() + jsonv2exp.Marshal() + { + var json T // deliberately shadow "json" package name + json.Marshal() // should not be re-written + } + } + `, + want: `package foobar + + import ( + jsonv1 "encoding/json" + jsonv2 "github.com/go-json-experiment/json" + ) + + func main() { + jsonv1.Marshal() + jsonv2.Marshal() + { + var json T // deliberately shadow "json" package name + json.Marshal() // should not be re-written + } + } + `, + }, { + in: `package foobar + + import ( + "github.com/go-json-experiment/json" + jsonv2exp "github.com/go-json-experiment/json" + ) + + func main() { + json.Marshal() + jsonv2exp.Marshal() + } + `, + want: `package foobar + import ( + jsonv2 "github.com/go-json-experiment/json" + ) + func main() { + jsonv2.Marshal() + jsonv2.Marshal() + } + `, + }, { + in: `package foobar + import "github.com/go-json-experiment/json/v1" + func main() { + json.Marshal() + } + `, + want: `package foobar + import jsonv1 "github.com/go-json-experiment/json/v1" + func main() { + jsonv1.Marshal() + } + `, + }, { + in: `package foobar + import ( + "encoding/json" + jsonv1in2 "github.com/go-json-experiment/json/v1" + ) + func main() { + json.Marshal() + jsonv1in2.Marshal() + } + `, + want: `package foobar + import ( + jsonv1std "encoding/json" + jsonv1 "github.com/go-json-experiment/json/v1" + ) + func main() { + jsonv1std.Marshal() + jsonv1.Marshal() + } + `, + }, { + in: `package foobar + import ( + "encoding/json" + jsonv1in2 "github.com/go-json-experiment/json/v1" + ) + func main() { + json.Marshal() + jsonv1in2.Marshal() + } + `, + want: `package foobar + import ( + jsonv1std "encoding/json" + jsonv1 "github.com/go-json-experiment/json/v1" + ) + func main() { + jsonv1std.Marshal() + jsonv1.Marshal() + } + `, + }, { + in: `package foobar + import ( + "encoding/json" + j2 "encoding/json/v2" + "encoding/json/jsontext" + ) + func main() { + json.Marshal() + j2.Marshal() + jsontext.NewEncoder + } + `, + want: `package foobar + import ( + jsonv1 "encoding/json" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + ) + func main() { + jsonv1.Marshal() + jsonv2.Marshal() + jsontext.NewEncoder + } + `, + }} + for _, tt := range tests { + got := string(must.Get(format.Source([]byte(tt.in)))) + got = string(mustFormatFile([]byte(got))) + want := string(must.Get(format.Source([]byte(tt.want)))) + if got != want { + diff, _ := safediff.Lines(got, want, -1) + t.Errorf("mismatch (-got +want)\n%s", diff) + t.Error(got) + t.Error(want) + } + } +} diff --git a/cmd/jsonimports/jsonimports.go b/cmd/jsonimports/jsonimports.go new file mode 100644 index 000000000..4be2e10cb --- /dev/null +++ b/cmd/jsonimports/jsonimports.go @@ -0,0 +1,124 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The jsonimports tool formats all Go source files in the repository +// to enforce that "json" imports are consistent. +// +// With Go 1.25, the "encoding/json/v2" and "encoding/json/jsontext" +// packages are now available under goexperiment.jsonv2. +// This leads to possible confusion over the following: +// +// - "encoding/json" +// - "encoding/json/v2" +// - "encoding/json/jsontext" +// - "github.com/go-json-experiment/json/v1" +// - "github.com/go-json-experiment/json" +// - "github.com/go-json-experiment/json/jsontext" +// +// In order to enforce consistent usage, we apply the following rules: +// +// - Until the Go standard library formally accepts "encoding/json/v2" +// and "encoding/json/jsontext" into the standard library +// (i.e., they are no longer considered experimental), +// we forbid any code from directly importing those packages. +// Go code should instead import "github.com/go-json-experiment/json" +// and "github.com/go-json-experiment/json/jsontext". +// The latter packages contain aliases to the standard library +// if built on Go 1.25 with the goexperiment.jsonv2 tag specified. +// +// - Imports of "encoding/json" or "github.com/go-json-experiment/json/v1" +// must be explicitly imported under the package name "jsonv1". +// If both packages need to be imported, then the former should +// be imported under the package name "jsonv1std". +// +// - Imports of "github.com/go-json-experiment/json" +// must be explicitly imported under the package name "jsonv2". +// +// The latter two rules exist to provide clarity when reading code. +// Without them, it is unclear whether "json.Marshal" refers to v1 or v2. +// With them, however, it is clear that "jsonv1.Marshal" is calling v1 and +// that "jsonv2.Marshal" is calling v2. +// +// TODO(@joetsai): At this present moment, there is no guidance given on +// whether to use v1 or v2 for newly written Go source code. +// I will write a document in the near future providing more guidance. +// Feel free to continue using v1 "encoding/json" as you are accustomed to. +package main + +import ( + "bytes" + "flag" + "fmt" + "os" + "os/exec" + "runtime" + "strings" + "sync" + + "tailscale.com/syncs" + "tailscale.com/util/must" + "tailscale.com/util/safediff" +) + +func main() { + update := flag.Bool("update", false, "update all Go source files") + flag.Parse() + + // Change working directory to Git repository root. + repoRoot := strings.TrimSuffix(string(must.Get(exec.Command( + "git", "rev-parse", "--show-toplevel", + ).Output())), "\n") + must.Do(os.Chdir(repoRoot)) + + // Iterate over all indexed files in the Git repository. + var printMu sync.Mutex + var group sync.WaitGroup + sema := syncs.NewSemaphore(runtime.NumCPU()) + var numDiffs int + files := string(must.Get(exec.Command("git", "ls-files").Output())) + for file := range strings.Lines(files) { + sema.Acquire() + group.Go(func() { + defer sema.Release() + + // Ignore non-Go source files. + file = strings.TrimSuffix(file, "\n") + if !strings.HasSuffix(file, ".go") { + return + } + + // Format all "json" imports in the Go source file. + srcIn := must.Get(os.ReadFile(file)) + srcOut := mustFormatFile(srcIn) + + // Print differences with each formatted file. + if !bytes.Equal(srcIn, srcOut) { + numDiffs++ + + printMu.Lock() + fmt.Println(file) + lines, _ := safediff.Lines(string(srcIn), string(srcOut), -1) + for line := range strings.Lines(lines) { + fmt.Print("\t", line) + } + fmt.Println() + printMu.Unlock() + + // If -update is specified, write out the changes. + if *update { + mode := must.Get(os.Stat(file)).Mode() + must.Do(os.WriteFile(file, srcOut, mode)) + } + } + }) + } + group.Wait() + + // Report whether any differences were detected. + if numDiffs > 0 && !*update { + fmt.Printf(`%d files with "json" imports that need formatting`+"\n", numDiffs) + fmt.Println("Please run:") + fmt.Println("\t./tool/go run tailscale.com/cmd/jsonimports -update") + os.Exit(1) + } +} diff --git a/cmd/k8s-nameserver/main.go b/cmd/k8s-nameserver/main.go index ca4b44935..84e65452d 100644 --- a/cmd/k8s-nameserver/main.go +++ b/cmd/k8s-nameserver/main.go @@ -31,6 +31,9 @@ const ( tsNetDomain = "ts.net" // addr is the the address that the UDP and TCP listeners will listen on. addr = ":1053" + // defaultTTL is the default TTL for DNS records in seconds. + // Set to 0 to disable caching. Can be increased when usage patterns are better understood. + defaultTTL = 0 // The following constants are specific to the nameserver configuration // provided by a mounted Kubernetes Configmap. The Configmap mounted at @@ -39,9 +42,9 @@ const ( kubeletMountedConfigLn = "..data" ) -// nameserver is a simple nameserver that responds to DNS queries for A records +// nameserver is a simple nameserver that responds to DNS queries for A and AAAA records // for ts.net domain names over UDP or TCP. It serves DNS responses from -// in-memory IPv4 host records. It is intended to be deployed on Kubernetes with +// in-memory IPv4 and IPv6 host records. It is intended to be deployed on Kubernetes with // a ConfigMap mounted at /config that should contain the host records. It // dynamically reconfigures its in-memory mappings as the contents of the // mounted ConfigMap changes. @@ -56,10 +59,13 @@ type nameserver struct { // in-memory records. configWatcher <-chan string - mu sync.Mutex // protects following + mu sync.RWMutex // protects following // ip4 are the in-memory hostname -> IP4 mappings that the nameserver // uses to respond to A record queries. ip4 map[dnsname.FQDN][]net.IP + // ip6 are the in-memory hostname -> IP6 mappings that the nameserver + // uses to respond to AAAA record queries. + ip6 map[dnsname.FQDN][]net.IP } func main() { @@ -98,16 +104,13 @@ func main() { tcpSig <- s // stop the TCP listener } -// handleFunc is a DNS query handler that can respond to A record queries from +// handleFunc is a DNS query handler that can respond to A and AAAA record queries from // the nameserver's in-memory records. -// - If an A record query is received and the -// nameserver's in-memory records contain records for the queried domain name, -// return a success response. -// - If an A record query is received, but the -// nameserver's in-memory records do not contain records for the queried domain name, -// return NXDOMAIN. -// - If an A record query is received, but the queried domain name is not valid, return Format Error. -// - If a query is received for any other record type than A, return Not Implemented. +// - For A queries: returns IPv4 addresses if available, NXDOMAIN if the name doesn't exist +// - For AAAA queries: returns IPv6 addresses if available, NOERROR with no data if only +// IPv4 exists (per RFC 4074), or NXDOMAIN if the name doesn't exist at all +// - For invalid domain names: returns Format Error +// - For other record types: returns Not Implemented func (n *nameserver) handleFunc() func(w dns.ResponseWriter, r *dns.Msg) { h := func(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) @@ -135,35 +138,19 @@ func (n *nameserver) handleFunc() func(w dns.ResponseWriter, r *dns.Msg) { m.RecursionAvailable = false ips := n.lookupIP4(fqdn) - if ips == nil || len(ips) == 0 { + if len(ips) == 0 { // As we are the authoritative nameserver for MagicDNS // names, if we do not have a record for this MagicDNS // name, it does not exist. m = m.SetRcode(r, dns.RcodeNameError) return } - // TODO (irbekrm): TTL is currently set to 0, meaning - // that cluster workloads will not cache the DNS - // records. Revisit this in future when we understand - // the usage patterns better- is it putting too much - // load on kube DNS server or is this fine? for _, ip := range ips { - rr := &dns.A{Hdr: dns.RR_Header{Name: q, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 0}, A: ip} + rr := &dns.A{Hdr: dns.RR_Header{Name: q, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: defaultTTL}, A: ip} m.SetRcode(r, dns.RcodeSuccess) m.Answer = append(m.Answer, rr) } case dns.TypeAAAA: - // TODO (irbekrm): add IPv6 support. - // The nameserver currently does not support IPv6 - // (records are not being created for IPv6 Pod addresses). - // However, we can expect that some callers will - // nevertheless send AAAA queries. - // We have to return NOERROR if a query is received for - // an AAAA record for a DNS name that we have an A - // record for- else the caller might not follow with an - // A record query. - // https://github.com/tailscale/tailscale/issues/12321 - // https://datatracker.ietf.org/doc/html/rfc4074 q := r.Question[0].Name fqdn, err := dnsname.ToFQDN(q) if err != nil { @@ -174,14 +161,27 @@ func (n *nameserver) handleFunc() func(w dns.ResponseWriter, r *dns.Msg) { // single source of truth for MagicDNS names by // non-tailnet Kubernetes workloads. m.Authoritative = true - ips := n.lookupIP4(fqdn) - if len(ips) == 0 { + m.RecursionAvailable = false + + ips := n.lookupIP6(fqdn) + // Also check if we have IPv4 records to determine correct response code. + // If the name exists (has A records) but no AAAA records, we return NOERROR + // per RFC 4074. If the name doesn't exist at all, we return NXDOMAIN. + ip4s := n.lookupIP4(fqdn) + + if len(ips) == 0 && len(ip4s) == 0 { // As we are the authoritative nameserver for MagicDNS - // names, if we do not have a record for this MagicDNS + // names, if we do not have any record for this MagicDNS // name, it does not exist. m = m.SetRcode(r, dns.RcodeNameError) return } + + // Return IPv6 addresses if available + for _, ip := range ips { + rr := &dns.AAAA{Hdr: dns.RR_Header{Name: q, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: defaultTTL}, AAAA: ip} + m.Answer = append(m.Answer, rr) + } m.SetRcode(r, dns.RcodeSuccess) default: log.Printf("[unexpected] nameserver received a query for an unsupported record type: %s", r.Question[0].String()) @@ -231,10 +231,11 @@ func (n *nameserver) resetRecords() error { log.Printf("error reading nameserver's configuration: %v", err) return err } - if dnsCfgBytes == nil || len(dnsCfgBytes) < 1 { + if len(dnsCfgBytes) == 0 { log.Print("nameserver's configuration is empty, any in-memory records will be unset") n.mu.Lock() n.ip4 = make(map[dnsname.FQDN][]net.IP) + n.ip6 = make(map[dnsname.FQDN][]net.IP) n.mu.Unlock() return nil } @@ -249,30 +250,63 @@ func (n *nameserver) resetRecords() error { } ip4 := make(map[dnsname.FQDN][]net.IP) + ip6 := make(map[dnsname.FQDN][]net.IP) defer func() { n.mu.Lock() defer n.mu.Unlock() n.ip4 = ip4 + n.ip6 = ip6 }() - if len(dnsCfg.IP4) == 0 { + if len(dnsCfg.IP4) == 0 && len(dnsCfg.IP6) == 0 { log.Print("nameserver's configuration contains no records, any in-memory records will be unset") return nil } + // Process IPv4 records for fqdn, ips := range dnsCfg.IP4 { fqdn, err := dnsname.ToFQDN(fqdn) if err != nil { log.Printf("invalid nameserver's configuration: %s is not a valid FQDN: %v; skipping this record", fqdn, err) continue // one invalid hostname should not break the whole nameserver } + var validIPs []net.IP for _, ipS := range ips { ip := net.ParseIP(ipS).To4() if ip == nil { // To4 returns nil if IP is not a IPv4 address log.Printf("invalid nameserver's configuration: %v does not appear to be an IPv4 address; skipping this record", ipS) continue // one invalid IP address should not break the whole nameserver } - ip4[fqdn] = []net.IP{ip} + validIPs = append(validIPs, ip) + } + if len(validIPs) > 0 { + ip4[fqdn] = validIPs + } + } + + // Process IPv6 records + for fqdn, ips := range dnsCfg.IP6 { + fqdn, err := dnsname.ToFQDN(fqdn) + if err != nil { + log.Printf("invalid nameserver's configuration: %s is not a valid FQDN: %v; skipping this record", fqdn, err) + continue // one invalid hostname should not break the whole nameserver + } + var validIPs []net.IP + for _, ipS := range ips { + ip := net.ParseIP(ipS) + if ip == nil { + log.Printf("invalid nameserver's configuration: %v does not appear to be a valid IP address; skipping this record", ipS) + continue + } + // Check if it's a valid IPv6 address + if ip.To4() != nil { + log.Printf("invalid nameserver's configuration: %v appears to be IPv4 but was in IPv6 records; skipping this record", ipS) + continue + } + validIPs = append(validIPs, ip.To16()) + } + if len(validIPs) > 0 { + ip6[fqdn] = validIPs } } return nil @@ -372,8 +406,20 @@ func (n *nameserver) lookupIP4(fqdn dnsname.FQDN) []net.IP { if n.ip4 == nil { return nil } - n.mu.Lock() - defer n.mu.Unlock() + n.mu.RLock() + defer n.mu.RUnlock() f := n.ip4[fqdn] return f } + +// lookupIP6 returns any IPv6 addresses for the given FQDN from nameserver's +// in-memory records. +func (n *nameserver) lookupIP6(fqdn dnsname.FQDN) []net.IP { + if n.ip6 == nil { + return nil + } + n.mu.RLock() + defer n.mu.RUnlock() + f := n.ip6[fqdn] + return f +} diff --git a/cmd/k8s-nameserver/main_test.go b/cmd/k8s-nameserver/main_test.go index d9a33c4fa..bca010048 100644 --- a/cmd/k8s-nameserver/main_test.go +++ b/cmd/k8s-nameserver/main_test.go @@ -19,6 +19,7 @@ func TestNameserver(t *testing.T) { tests := []struct { name string ip4 map[dnsname.FQDN][]net.IP + ip6 map[dnsname.FQDN][]net.IP query *dns.Msg wantResp *dns.Msg }{ @@ -112,6 +113,49 @@ func TestNameserver(t *testing.T) { Authoritative: true, }}, }, + { + name: "AAAA record query with IPv6 record", + ip6: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {net.ParseIP("2001:db8::1")}}, + query: &dns.Msg{ + Question: []dns.Question{{Name: "foo.bar.com", Qtype: dns.TypeAAAA}}, + MsgHdr: dns.MsgHdr{Id: 1, RecursionDesired: true}, + }, + wantResp: &dns.Msg{ + Answer: []dns.RR{&dns.AAAA{Hdr: dns.RR_Header{ + Name: "foo.bar.com", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}, + AAAA: net.ParseIP("2001:db8::1")}}, + Question: []dns.Question{{Name: "foo.bar.com", Qtype: dns.TypeAAAA}}, + MsgHdr: dns.MsgHdr{ + Id: 1, + Rcode: dns.RcodeSuccess, + RecursionAvailable: false, + RecursionDesired: true, + Response: true, + Opcode: dns.OpcodeQuery, + Authoritative: true, + }}, + }, + { + name: "Dual-stack: both A and AAAA records exist", + ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("dual.bar.com."): {{10, 0, 0, 1}}}, + ip6: map[dnsname.FQDN][]net.IP{dnsname.FQDN("dual.bar.com."): {net.ParseIP("2001:db8::1")}}, + query: &dns.Msg{ + Question: []dns.Question{{Name: "dual.bar.com", Qtype: dns.TypeAAAA}}, + MsgHdr: dns.MsgHdr{Id: 1}, + }, + wantResp: &dns.Msg{ + Answer: []dns.RR{&dns.AAAA{Hdr: dns.RR_Header{ + Name: "dual.bar.com", Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: 0}, + AAAA: net.ParseIP("2001:db8::1")}}, + Question: []dns.Question{{Name: "dual.bar.com", Qtype: dns.TypeAAAA}}, + MsgHdr: dns.MsgHdr{ + Id: 1, + Rcode: dns.RcodeSuccess, + Response: true, + Opcode: dns.OpcodeQuery, + Authoritative: true, + }}, + }, { name: "CNAME record query", ip4: map[dnsname.FQDN][]net.IP{dnsname.FQDN("foo.bar.com."): {{1, 2, 3, 4}}}, @@ -133,6 +177,7 @@ func TestNameserver(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ns := &nameserver{ ip4: tt.ip4, + ip6: tt.ip6, } handler := ns.handleFunc() fakeRespW := &fakeResponseWriter{} @@ -149,43 +194,63 @@ func TestResetRecords(t *testing.T) { name string config []byte hasIp4 map[dnsname.FQDN][]net.IP + hasIp6 map[dnsname.FQDN][]net.IP wantsIp4 map[dnsname.FQDN][]net.IP + wantsIp6 map[dnsname.FQDN][]net.IP wantsErr bool }{ { name: "previously empty nameserver.ip4 gets set", config: []byte(`{"version": "v1alpha1", "ip4": {"foo.bar.com": ["1.2.3.4"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"foo.bar.com.": {{1, 2, 3, 4}}}, + wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { name: "nameserver.ip4 gets reset", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, config: []byte(`{"version": "v1alpha1", "ip4": {"foo.bar.com": ["1.2.3.4"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"foo.bar.com.": {{1, 2, 3, 4}}}, + wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { name: "configuration with incompatible version", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, config: []byte(`{"version": "v1beta1", "ip4": {"foo.bar.com": ["1.2.3.4"]}}`), wantsIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, + wantsIp6: nil, wantsErr: true, }, { name: "nameserver.ip4 gets reset to empty config when no configuration is provided", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, wantsIp4: make(map[dnsname.FQDN][]net.IP), + wantsIp6: make(map[dnsname.FQDN][]net.IP), }, { name: "nameserver.ip4 gets reset to empty config when the provided configuration is empty", hasIp4: map[dnsname.FQDN][]net.IP{"baz.bar.com.": {{1, 1, 3, 3}}}, config: []byte(`{"version": "v1alpha1", "ip4": {}}`), wantsIp4: make(map[dnsname.FQDN][]net.IP), + wantsIp6: make(map[dnsname.FQDN][]net.IP), + }, + { + name: "nameserver.ip6 gets set", + config: []byte(`{"version": "v1alpha1", "ip6": {"foo.bar.com": ["2001:db8::1"]}}`), + wantsIp4: make(map[dnsname.FQDN][]net.IP), + wantsIp6: map[dnsname.FQDN][]net.IP{"foo.bar.com.": {net.ParseIP("2001:db8::1")}}, + }, + { + name: "dual-stack configuration", + config: []byte(`{"version": "v1alpha1", "ip4": {"dual.bar.com": ["10.0.0.1"]}, "ip6": {"dual.bar.com": ["2001:db8::1"]}}`), + wantsIp4: map[dnsname.FQDN][]net.IP{"dual.bar.com.": {{10, 0, 0, 1}}}, + wantsIp6: map[dnsname.FQDN][]net.IP{"dual.bar.com.": {net.ParseIP("2001:db8::1")}}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ns := &nameserver{ ip4: tt.hasIp4, + ip6: tt.hasIp6, configReader: func() ([]byte, error) { return tt.config, nil }, } if err := ns.resetRecords(); err == nil == tt.wantsErr { @@ -194,6 +259,9 @@ func TestResetRecords(t *testing.T) { if diff := cmp.Diff(ns.ip4, tt.wantsIp4); diff != "" { t.Fatalf("unexpected nameserver.ip4 contents (-got +want): \n%s", diff) } + if diff := cmp.Diff(ns.ip6, tt.wantsIp6); diff != "" { + t.Fatalf("unexpected nameserver.ip6 contents (-got +want): \n%s", diff) + } }) } } diff --git a/cmd/k8s-operator/api-server-proxy-pg.go b/cmd/k8s-operator/api-server-proxy-pg.go new file mode 100644 index 000000000..1a81e4967 --- /dev/null +++ b/cmd/k8s-operator/api-server-proxy-pg.go @@ -0,0 +1,473 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "maps" + "slices" + "strings" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/internal/client/tailscale" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/tstime" +) + +const ( + proxyPGFinalizerName = "tailscale.com/kube-apiserver-finalizer" + + // Reasons for KubeAPIServerProxyValid condition. + reasonKubeAPIServerProxyInvalid = "KubeAPIServerProxyInvalid" + reasonKubeAPIServerProxyValid = "KubeAPIServerProxyValid" + + // Reasons for KubeAPIServerProxyConfigured condition. + reasonKubeAPIServerProxyConfigured = "KubeAPIServerProxyConfigured" + reasonKubeAPIServerProxyNoBackends = "KubeAPIServerProxyNoBackends" +) + +// KubeAPIServerTSServiceReconciler reconciles the Tailscale Services required for an +// HA deployment of the API Server Proxy. +type KubeAPIServerTSServiceReconciler struct { + client.Client + recorder record.EventRecorder + logger *zap.SugaredLogger + tsClient tsClient + tsNamespace string + lc localClient + defaultTags []string + operatorID string // stableID of the operator's Tailscale device + + clock tstime.Clock +} + +// Reconcile is the entry point for the controller. +func (r *KubeAPIServerTSServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + logger := r.logger.With("ProxyGroup", req.Name) + logger.Debugf("starting reconcile") + defer logger.Debugf("reconcile finished") + + pg := new(tsapi.ProxyGroup) + err = r.Get(ctx, req.NamespacedName, pg) + if apierrors.IsNotFound(err) { + // Request object not found, could have been deleted after reconcile request. + logger.Debugf("ProxyGroup not found, assuming it was deleted") + return res, nil + } else if err != nil { + return res, fmt.Errorf("failed to get ProxyGroup: %w", err) + } + + serviceName := serviceNameForAPIServerProxy(pg) + logger = logger.With("Tailscale Service", serviceName) + + if markedForDeletion(pg) { + logger.Debugf("ProxyGroup is being deleted, ensuring any created resources are cleaned up") + if err = r.maybeCleanup(ctx, serviceName, pg, logger); err != nil && strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + return res, nil + } + + return res, err + } + + err = r.maybeProvision(ctx, serviceName, pg, logger) + if err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + return reconcile.Result{}, nil + } + return reconcile.Result{}, err + } + + return reconcile.Result{}, nil +} + +// maybeProvision ensures that a Tailscale Service for this ProxyGroup exists +// and is up to date. +// +// Returns true if the operation resulted in a Tailscale Service update. +func (r *KubeAPIServerTSServiceReconciler) maybeProvision(ctx context.Context, serviceName tailcfg.ServiceName, pg *tsapi.ProxyGroup, logger *zap.SugaredLogger) (err error) { + var dnsName string + oldPGStatus := pg.Status.DeepCopy() + defer func() { + podsAdvertising, podsErr := numberPodsAdvertising(ctx, r.Client, r.tsNamespace, pg.Name, serviceName) + if podsErr != nil { + err = errors.Join(err, fmt.Errorf("failed to get number of advertised Pods: %w", podsErr)) + // Continue, updating the status with the best available information. + } + + // Update the ProxyGroup status with the Tailscale Service information + // Update the condition based on how many pods are advertising the service + conditionStatus := metav1.ConditionFalse + conditionReason := reasonKubeAPIServerProxyNoBackends + conditionMessage := fmt.Sprintf("%d/%d proxy backends ready and advertising", podsAdvertising, pgReplicas(pg)) + + pg.Status.URL = "" + if podsAdvertising > 0 { + // At least one pod is advertising the service, consider it configured + conditionStatus = metav1.ConditionTrue + conditionReason = reasonKubeAPIServerProxyConfigured + if dnsName != "" { + pg.Status.URL = "https://" + dnsName + } + } + + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, conditionStatus, conditionReason, conditionMessage, pg.Generation, r.clock, logger) + + if !apiequality.Semantic.DeepEqual(oldPGStatus, &pg.Status) { + // An error encountered here should get returned by the Reconcile function. + err = errors.Join(err, r.Client.Status().Update(ctx, pg)) + } + }() + + if !tsoperator.ProxyGroupAvailable(pg) { + return nil + } + + if !slices.Contains(pg.Finalizers, proxyPGFinalizerName) { + // This log line is printed exactly once during initial provisioning, + // because once the finalizer is in place this block gets skipped. So, + // this is a nice place to tell the operator that the high level, + // multi-reconcile operation is underway. + logger.Info("provisioning Tailscale Service for ProxyGroup") + pg.Finalizers = append(pg.Finalizers, proxyPGFinalizerName) + if err := r.Update(ctx, pg); err != nil { + return fmt.Errorf("failed to add finalizer: %w", err) + } + } + + // 1. Check there isn't a Tailscale Service with the same hostname + // already created and not owned by this ProxyGroup. + existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) + if err != nil && !isErrorTailscaleServiceNotFound(err) { + return fmt.Errorf("error getting Tailscale Service %q: %w", serviceName, err) + } + + updatedAnnotations, err := exclusiveOwnerAnnotations(pg, r.operatorID, existingTSSvc) + if err != nil { + const instr = "To proceed, you can either manually delete the existing Tailscale Service or choose a different Service name in the ProxyGroup's spec.kubeAPIServer.serviceName field" + msg := fmt.Sprintf("error ensuring exclusive ownership of Tailscale Service %s: %v. %s", serviceName, err, instr) + logger.Warn(msg) + r.recorder.Event(pg, corev1.EventTypeWarning, "InvalidTailscaleService", msg) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionFalse, reasonKubeAPIServerProxyInvalid, msg, pg.Generation, r.clock, logger) + return nil + } + + // After getting this far, we know the Tailscale Service is valid. + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, reasonKubeAPIServerProxyValid, pg.Generation, r.clock, logger) + + // Service tags are limited to matching the ProxyGroup's tags until we have + // support for querying peer caps for a Service-bound request. + serviceTags := r.defaultTags + if len(pg.Spec.Tags) > 0 { + serviceTags = pg.Spec.Tags.Stringify() + } + + tsSvc := &tailscale.VIPService{ + Name: serviceName, + Tags: serviceTags, + Ports: []string{"tcp:443"}, + Comment: managedTSServiceComment, + Annotations: updatedAnnotations, + } + if existingTSSvc != nil { + tsSvc.Addrs = existingTSSvc.Addrs + } + + // 2. Ensure the Tailscale Service exists and is up to date. + if existingTSSvc == nil || + !slices.Equal(tsSvc.Tags, existingTSSvc.Tags) || + !ownersAreSetAndEqual(tsSvc, existingTSSvc) || + !slices.Equal(tsSvc.Ports, existingTSSvc.Ports) { + logger.Infof("Ensuring Tailscale Service exists and is up to date") + if err := r.tsClient.CreateOrUpdateVIPService(ctx, tsSvc); err != nil { + return fmt.Errorf("error creating Tailscale Service: %w", err) + } + } + + // 3. Ensure that TLS Secret and RBAC exists. + tcd, err := tailnetCertDomain(ctx, r.lc) + if err != nil { + return fmt.Errorf("error determining DNS name base: %w", err) + } + dnsName = serviceName.WithoutPrefix() + "." + tcd + if err = r.ensureCertResources(ctx, pg, dnsName); err != nil { + return fmt.Errorf("error ensuring cert resources: %w", err) + } + + // 4. Configure the Pods to advertise the Tailscale Service. + if err = r.maybeAdvertiseServices(ctx, pg, serviceName, logger); err != nil { + return fmt.Errorf("error updating advertised Tailscale Services: %w", err) + } + + // 5. Clean up any stale Tailscale Services from previous resource versions. + if err = r.maybeDeleteStaleServices(ctx, pg, logger); err != nil { + return fmt.Errorf("failed to delete stale Tailscale Services: %w", err) + } + + return nil +} + +// maybeCleanup ensures that any resources, such as a Tailscale Service created for this Service, are cleaned up when the +// Service is being deleted or is unexposed. The cleanup is safe for a multi-cluster setup- the Tailscale Service is only +// deleted if it does not contain any other owner references. If it does, the cleanup only removes the owner reference +// corresponding to this Service. +func (r *KubeAPIServerTSServiceReconciler) maybeCleanup(ctx context.Context, serviceName tailcfg.ServiceName, pg *tsapi.ProxyGroup, logger *zap.SugaredLogger) (err error) { + ix := slices.Index(pg.Finalizers, proxyPGFinalizerName) + if ix < 0 { + logger.Debugf("no finalizer, nothing to do") + return nil + } + logger.Infof("Ensuring that Service %q is cleaned up", serviceName) + + defer func() { + if err == nil { + err = r.deleteFinalizer(ctx, pg, logger) + } + }() + + if _, err = cleanupTailscaleService(ctx, r.tsClient, serviceName, r.operatorID, logger); err != nil { + return fmt.Errorf("error deleting Tailscale Service: %w", err) + } + + if err = cleanupCertResources(ctx, r.Client, r.lc, r.tsNamespace, pg.Name, serviceName); err != nil { + return fmt.Errorf("failed to clean up cert resources: %w", err) + } + + return nil +} + +// maybeDeleteStaleServices deletes Services that have previously been created for +// this ProxyGroup but are no longer needed. +func (r *KubeAPIServerTSServiceReconciler) maybeDeleteStaleServices(ctx context.Context, pg *tsapi.ProxyGroup, logger *zap.SugaredLogger) error { + serviceName := serviceNameForAPIServerProxy(pg) + + svcs, err := r.tsClient.ListVIPServices(ctx) + if err != nil { + return fmt.Errorf("error listing Tailscale Services: %w", err) + } + + for _, svc := range svcs.VIPServices { + if svc.Name == serviceName { + continue + } + + owners, err := parseOwnerAnnotation(&svc) + if err != nil { + logger.Warnf("error parsing owner annotation for Tailscale Service %s: %v", svc.Name, err) + continue + } + if owners == nil || len(owners.OwnerRefs) != 1 || owners.OwnerRefs[0].OperatorID != r.operatorID { + continue + } + + owner := owners.OwnerRefs[0] + if owner.Resource == nil || owner.Resource.Kind != "ProxyGroup" || owner.Resource.UID != string(pg.UID) { + continue + } + + logger.Infof("Deleting Tailscale Service %s", svc.Name) + if err := r.tsClient.DeleteVIPService(ctx, svc.Name); err != nil && !isErrorTailscaleServiceNotFound(err) { + return fmt.Errorf("error deleting Tailscale Service %s: %w", svc.Name, err) + } + + if err = cleanupCertResources(ctx, r.Client, r.lc, r.tsNamespace, pg.Name, svc.Name); err != nil { + return fmt.Errorf("failed to clean up cert resources: %w", err) + } + } + + return nil +} + +func (r *KubeAPIServerTSServiceReconciler) deleteFinalizer(ctx context.Context, pg *tsapi.ProxyGroup, logger *zap.SugaredLogger) error { + pg.Finalizers = slices.DeleteFunc(pg.Finalizers, func(f string) bool { + return f == proxyPGFinalizerName + }) + logger.Debugf("ensure %q finalizer is removed", proxyPGFinalizerName) + + if err := r.Update(ctx, pg); err != nil { + return fmt.Errorf("failed to remove finalizer %q: %w", proxyPGFinalizerName, err) + } + return nil +} + +func (r *KubeAPIServerTSServiceReconciler) ensureCertResources(ctx context.Context, pg *tsapi.ProxyGroup, domain string) error { + secret := certSecret(pg.Name, r.tsNamespace, domain, pg) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, secret, func(s *corev1.Secret) { + s.Labels = secret.Labels + }); err != nil { + return fmt.Errorf("failed to create or update Secret %s: %w", secret.Name, err) + } + role := certSecretRole(pg.Name, r.tsNamespace, domain) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { + r.Labels = role.Labels + r.Rules = role.Rules + }); err != nil { + return fmt.Errorf("failed to create or update Role %s: %w", role.Name, err) + } + rolebinding := certSecretRoleBinding(pg, r.tsNamespace, domain) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, rolebinding, func(rb *rbacv1.RoleBinding) { + rb.Labels = rolebinding.Labels + rb.Subjects = rolebinding.Subjects + rb.RoleRef = rolebinding.RoleRef + }); err != nil { + return fmt.Errorf("failed to create or update RoleBinding %s: %w", rolebinding.Name, err) + } + return nil +} + +func (r *KubeAPIServerTSServiceReconciler) maybeAdvertiseServices(ctx context.Context, pg *tsapi.ProxyGroup, serviceName tailcfg.ServiceName, logger *zap.SugaredLogger) error { + // Get all config Secrets for this ProxyGroup + cfgSecrets := &corev1.SecretList{} + if err := r.List(ctx, cfgSecrets, client.InNamespace(r.tsNamespace), client.MatchingLabels(pgSecretLabels(pg.Name, kubetypes.LabelSecretTypeConfig))); err != nil { + return fmt.Errorf("failed to list config Secrets: %w", err) + } + + // Only advertise a Tailscale Service once the TLS certs required for + // serving it are available. + shouldBeAdvertised, err := hasCerts(ctx, r.Client, r.lc, r.tsNamespace, serviceName) + if err != nil { + return fmt.Errorf("error checking TLS credentials provisioned for Tailscale Service %q: %w", serviceName, err) + } + var advertiseServices []string + if shouldBeAdvertised { + advertiseServices = []string{serviceName.String()} + } + + for _, s := range cfgSecrets.Items { + if len(s.Data[kubetypes.KubeAPIServerConfigFile]) == 0 { + continue + } + + // Parse the existing config. + cfg, err := conf.Load(s.Data[kubetypes.KubeAPIServerConfigFile]) + if err != nil { + return fmt.Errorf("error loading config from Secret %q: %w", s.Name, err) + } + + if cfg.Parsed.APIServerProxy == nil { + return fmt.Errorf("config Secret %q does not contain APIServerProxy config", s.Name) + } + + existingCfgSecret := s.DeepCopy() + + var updated bool + if cfg.Parsed.APIServerProxy.ServiceName == nil || *cfg.Parsed.APIServerProxy.ServiceName != serviceName { + cfg.Parsed.APIServerProxy.ServiceName = &serviceName + updated = true + } + + // Update the services to advertise if required. + if !slices.Equal(cfg.Parsed.AdvertiseServices, advertiseServices) { + cfg.Parsed.AdvertiseServices = advertiseServices + updated = true + } + + if !updated { + continue + } + + // Update the config Secret. + cfgB, err := json.Marshal(conf.VersionedConfig{ + Version: "v1alpha1", + ConfigV1Alpha1: &cfg.Parsed, + }) + if err != nil { + return err + } + + s.Data[kubetypes.KubeAPIServerConfigFile] = cfgB + if !apiequality.Semantic.DeepEqual(existingCfgSecret, s) { + logger.Debugf("Updating the Tailscale Services in ProxyGroup config Secret %s", s.Name) + if err := r.Update(ctx, &s); err != nil { + return err + } + } + } + + return nil +} + +func serviceNameForAPIServerProxy(pg *tsapi.ProxyGroup) tailcfg.ServiceName { + if pg.Spec.KubeAPIServer != nil && pg.Spec.KubeAPIServer.Hostname != "" { + return tailcfg.ServiceName("svc:" + pg.Spec.KubeAPIServer.Hostname) + } + + return tailcfg.ServiceName("svc:" + pg.Name) +} + +// exclusiveOwnerAnnotations returns the updated annotations required to ensure this +// instance of the operator is the exclusive owner. If the Tailscale Service is not +// nil, but does not contain an owner reference we return an error as this likely means +// that the Service was created by something other than a Tailscale Kubernetes operator. +// We also error if it is already owned by another operator instance, as we do not +// want to load balance a kube-apiserver ProxyGroup across multiple clusters. +func exclusiveOwnerAnnotations(pg *tsapi.ProxyGroup, operatorID string, svc *tailscale.VIPService) (map[string]string, error) { + ref := OwnerRef{ + OperatorID: operatorID, + Resource: &Resource{ + Kind: "ProxyGroup", + Name: pg.Name, + UID: string(pg.UID), + }, + } + if svc == nil { + c := ownerAnnotationValue{OwnerRefs: []OwnerRef{ref}} + json, err := json.Marshal(c) + if err != nil { + return nil, fmt.Errorf("[unexpected] unable to marshal Tailscale Service's owner annotation contents: %w, please report this", err) + } + return map[string]string{ + ownerAnnotation: string(json), + }, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return nil, err + } + if o == nil || len(o.OwnerRefs) == 0 { + return nil, fmt.Errorf("Tailscale Service %s exists, but does not contain owner annotation with owner references; not proceeding as this is likely a resource created by something other than the Tailscale Kubernetes operator", svc.Name) + } + if len(o.OwnerRefs) > 1 || o.OwnerRefs[0].OperatorID != operatorID { + return nil, fmt.Errorf("Tailscale Service %s is already owned by other operator(s) and cannot be shared across multiple clusters; configure a difference Service name to continue", svc.Name) + } + if o.OwnerRefs[0].Resource == nil { + return nil, fmt.Errorf("Tailscale Service %s exists, but does not reference an owning resource; not proceeding as this is likely a Service already owned by an Ingress", svc.Name) + } + if o.OwnerRefs[0].Resource.Kind != "ProxyGroup" || o.OwnerRefs[0].Resource.UID != string(pg.UID) { + return nil, fmt.Errorf("Tailscale Service %s is already owned by another resource: %#v; configure a difference Service name to continue", svc.Name, o.OwnerRefs[0].Resource) + } + if o.OwnerRefs[0].Resource.Name != pg.Name { + // ProxyGroup name can be updated in place. + o.OwnerRefs[0].Resource.Name = pg.Name + } + + oBytes, err := json.Marshal(o) + if err != nil { + return nil, err + } + + newAnnots := make(map[string]string, len(svc.Annotations)+1) + maps.Copy(newAnnots, svc.Annotations) + newAnnots[ownerAnnotation] = string(oBytes) + + return newAnnots, nil +} diff --git a/cmd/k8s-operator/api-server-proxy-pg_test.go b/cmd/k8s-operator/api-server-proxy-pg_test.go new file mode 100644 index 000000000..dfef63f22 --- /dev/null +++ b/cmd/k8s-operator/api-server-proxy-pg_test.go @@ -0,0 +1,384 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "encoding/json" + "reflect" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/opt" + "tailscale.com/types/ptr" +) + +func TestAPIServerProxyReconciler(t *testing.T) { + const ( + pgName = "test-pg" + pgUID = "test-pg-uid" + ns = "operator-ns" + defaultDomain = "test-pg.ts.net" + ) + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgName, + Generation: 1, + UID: pgUID, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeKubernetesAPIServer, + }, + Status: tsapi.ProxyGroupStatus{ + Conditions: []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupAvailable), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + }, + }, + } + initialCfg := &conf.VersionedConfig{ + Version: "v1alpha1", + ConfigV1Alpha1: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("test-key"), + APIServerProxy: &conf.APIServerProxyConfig{ + Enabled: opt.NewBool(true), + }, + }, + } + expectedCfg := *initialCfg + initialCfgB, err := json.Marshal(initialCfg) + if err != nil { + t.Fatalf("marshaling initial config: %v", err) + } + pgCfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: ns, + Labels: pgSecretLabels(pgName, kubetypes.LabelSecretTypeConfig), + }, + Data: map[string][]byte{ + // Existing config should be preserved. + kubetypes.KubeAPIServerConfigFile: initialCfgB, + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pg, pgCfgSecret). + WithStatusSubresource(pg). + Build() + expectCfg := func(c *conf.VersionedConfig) { + t.Helper() + cBytes, err := json.Marshal(c) + if err != nil { + t.Fatalf("marshaling expected config: %v", err) + } + pgCfgSecret.Data[kubetypes.KubeAPIServerConfigFile] = cBytes + expectEqual(t, fc, pgCfgSecret) + } + + ft := &fakeTSClient{} + ingressTSSvc := &tailscale.VIPService{ + Name: "svc:some-ingress-hostname", + Comment: managedTSServiceComment, + Annotations: map[string]string{ + // No resource field. + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id"}]}`, + }, + Ports: []string{"tcp:443"}, + Tags: []string{"tag:k8s"}, + Addrs: []string{"5.6.7.8"}, + } + ft.CreateOrUpdateVIPService(t.Context(), ingressTSSvc) + + lc := &fakeLocalClient{ + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{ + MagicDNSSuffix: "ts.net", + }, + }, + } + + r := &KubeAPIServerTSServiceReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + tsNamespace: ns, + logger: zap.Must(zap.NewDevelopment()).Sugar(), + recorder: record.NewFakeRecorder(10), + lc: lc, + clock: tstest.NewClock(tstest.ClockOpts{}), + operatorID: "self-id", + } + + // Create a Tailscale Service that will conflict with the initial config. + if err := ft.CreateOrUpdateVIPService(t.Context(), &tailscale.VIPService{ + Name: "svc:" + pgName, + }); err != nil { + t.Fatalf("creating initial Tailscale Service: %v", err) + } + expectReconciled(t, r, "", pgName) + pg.ObjectMeta.Finalizers = []string{proxyPGFinalizerName} + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionFalse, reasonKubeAPIServerProxyInvalid, "", 1, r.clock, r.logger) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionFalse, reasonKubeAPIServerProxyNoBackends, "", 1, r.clock, r.logger) + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + expectMissing[corev1.Secret](t, fc, ns, defaultDomain) + expectMissing[rbacv1.Role](t, fc, ns, defaultDomain) + expectMissing[rbacv1.RoleBinding](t, fc, ns, defaultDomain) + expectEqual(t, fc, pgCfgSecret) // Unchanged. + + // Delete Tailscale Service; should see Service created and valid condition updated to true. + if err := ft.DeleteVIPService(t.Context(), "svc:"+pgName); err != nil { + t.Fatalf("deleting initial Tailscale Service: %v", err) + } + expectReconciled(t, r, "", pgName) + + tsSvc, err := ft.GetVIPService(t.Context(), "svc:"+pgName) + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + if tsSvc == nil { + t.Fatalf("expected Tailscale Service to be created, but got nil") + } + expectedTSSvc := &tailscale.VIPService{ + Name: "svc:" + pgName, + Comment: managedTSServiceComment, + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id","resource":{"kind":"ProxyGroup","name":"test-pg","uid":"test-pg-uid"}}]}`, + }, + Ports: []string{"tcp:443"}, + Tags: []string{"tag:k8s"}, + Addrs: []string{"5.6.7.8"}, + } + if !reflect.DeepEqual(tsSvc, expectedTSSvc) { + t.Fatalf("expected Tailscale Service to be %+v, got %+v", expectedTSSvc, tsSvc) + } + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.logger) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionFalse, reasonKubeAPIServerProxyNoBackends, "", 1, r.clock, r.logger) + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + + expectedCfg.APIServerProxy.ServiceName = ptr.To(tailcfg.ServiceName("svc:" + pgName)) + expectCfg(&expectedCfg) + + expectEqual(t, fc, certSecret(pgName, ns, defaultDomain, pg)) + expectEqual(t, fc, certSecretRole(pgName, ns, defaultDomain)) + expectEqual(t, fc, certSecretRoleBinding(pg, ns, defaultDomain)) + + // Simulate certs being issued; should observe AdvertiseServices config change. + if err := populateTLSSecret(t.Context(), fc, pgName, defaultDomain); err != nil { + t.Fatalf("populating TLS Secret: %v", err) + } + expectReconciled(t, r, "", pgName) + + expectedCfg.AdvertiseServices = []string{"svc:" + pgName} + expectCfg(&expectedCfg) + + expectEqual(t, fc, pg, omitPGStatusConditionMessages) // Unchanged status. + + // Simulate Pod prefs updated with advertised services; should see Configured condition updated to true. + mustCreate(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: ns, + Labels: pgSecretLabels(pgName, kubetypes.LabelSecretTypeState), + }, + Data: map[string][]byte{ + "_current-profile": []byte("profile-foo"), + "profile-foo": []byte(`{"AdvertiseServices":["svc:test-pg"],"Config":{"NodeID":"node-foo"}}`), + }, + }) + expectReconciled(t, r, "", pgName) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.logger) + pg.Status.URL = "https://" + defaultDomain + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + + // Rename the Tailscale Service - old one + cert resources should be cleaned up. + updatedServiceName := tailcfg.ServiceName("svc:test-pg-renamed") + updatedDomain := "test-pg-renamed.ts.net" + pg.Spec.KubeAPIServer = &tsapi.KubeAPIServerConfig{ + Hostname: updatedServiceName.WithoutPrefix(), + } + mustUpdate(t, fc, "", pgName, func(p *tsapi.ProxyGroup) { + p.Spec.KubeAPIServer = pg.Spec.KubeAPIServer + }) + expectReconciled(t, r, "", pgName) + _, err = ft.GetVIPService(t.Context(), "svc:"+pgName) + if !isErrorTailscaleServiceNotFound(err) { + t.Fatalf("Expected 404, got: %v", err) + } + tsSvc, err = ft.GetVIPService(t.Context(), updatedServiceName) + if err != nil { + t.Fatalf("Expected renamed svc, got error: %v", err) + } + expectedTSSvc.Name = updatedServiceName + if !reflect.DeepEqual(tsSvc, expectedTSSvc) { + t.Fatalf("expected Tailscale Service to be %+v, got %+v", expectedTSSvc, tsSvc) + } + // Check cfg and status reset until TLS certs are available again. + expectedCfg.APIServerProxy.ServiceName = ptr.To(updatedServiceName) + expectedCfg.AdvertiseServices = nil + expectCfg(&expectedCfg) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionFalse, reasonKubeAPIServerProxyNoBackends, "", 1, r.clock, r.logger) + pg.Status.URL = "" + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + + expectEqual(t, fc, certSecret(pgName, ns, updatedDomain, pg)) + expectEqual(t, fc, certSecretRole(pgName, ns, updatedDomain)) + expectEqual(t, fc, certSecretRoleBinding(pg, ns, updatedDomain)) + expectMissing[corev1.Secret](t, fc, ns, defaultDomain) + expectMissing[rbacv1.Role](t, fc, ns, defaultDomain) + expectMissing[rbacv1.RoleBinding](t, fc, ns, defaultDomain) + + // Check we get the new hostname in the status once ready. + if err := populateTLSSecret(t.Context(), fc, pgName, updatedDomain); err != nil { + t.Fatalf("populating TLS Secret: %v", err) + } + mustUpdate(t, fc, "operator-ns", "test-pg-0", func(s *corev1.Secret) { + s.Data["profile-foo"] = []byte(`{"AdvertiseServices":["svc:test-pg"],"Config":{"NodeID":"node-foo"}}`) + }) + expectReconciled(t, r, "", pgName) + expectedCfg.AdvertiseServices = []string{updatedServiceName.String()} + expectCfg(&expectedCfg) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.logger) + pg.Status.URL = "https://" + updatedDomain + + // Delete the ProxyGroup and verify Tailscale Service and cert resources are cleaned up. + if err := fc.Delete(t.Context(), pg); err != nil { + t.Fatalf("deleting ProxyGroup: %v", err) + } + expectReconciled(t, r, "", pgName) + expectMissing[corev1.Secret](t, fc, ns, updatedDomain) + expectMissing[rbacv1.Role](t, fc, ns, updatedDomain) + expectMissing[rbacv1.RoleBinding](t, fc, ns, updatedDomain) + _, err = ft.GetVIPService(t.Context(), updatedServiceName) + if !isErrorTailscaleServiceNotFound(err) { + t.Fatalf("Expected 404, got: %v", err) + } + + // Ingress Tailscale Service should not be affected. + svc, err := ft.GetVIPService(t.Context(), ingressTSSvc.Name) + if err != nil { + t.Fatalf("getting ingress Tailscale Service: %v", err) + } + if !reflect.DeepEqual(svc, ingressTSSvc) { + t.Fatalf("expected ingress Tailscale Service to be unmodified %+v, got %+v", ingressTSSvc, svc) + } +} + +func TestExclusiveOwnerAnnotations(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pg1", + UID: "pg1-uid", + }, + } + const ( + selfOperatorID = "self-id" + pg1Owner = `{"ownerRefs":[{"operatorID":"self-id","resource":{"kind":"ProxyGroup","name":"pg1","uid":"pg1-uid"}}]}` + ) + + for name, tc := range map[string]struct { + svc *tailscale.VIPService + wantErr string + }{ + "no_svc": { + svc: nil, + }, + "empty_svc": { + svc: &tailscale.VIPService{}, + wantErr: "likely a resource created by something other than the Tailscale Kubernetes operator", + }, + "already_owner": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: pg1Owner, + }, + }, + }, + "already_owner_name_updated": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id","resource":{"kind":"ProxyGroup","name":"old-pg1-name","uid":"pg1-uid"}}]}`, + }, + }, + }, + "preserves_existing_annotations": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + "existing": "annotation", + ownerAnnotation: pg1Owner, + }, + }, + }, + "owned_by_another_operator": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"operator-2"}]}`, + }, + }, + wantErr: "already owned by other operator(s)", + }, + "owned_by_an_ingress": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id"}]}`, // Ingress doesn't set Resource field (yet). + }, + }, + wantErr: "does not reference an owning resource", + }, + "owned_by_another_pg": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id","resource":{"kind":"ProxyGroup","name":"pg2","uid":"pg2-uid"}}]}`, + }, + }, + wantErr: "already owned by another resource", + }, + } { + t.Run(name, func(t *testing.T) { + got, err := exclusiveOwnerAnnotations(pg, "self-id", tc.svc) + if tc.wantErr != "" { + if !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("exclusiveOwnerAnnotations() error = %v, wantErr %v", err, tc.wantErr) + } + } else if diff := cmp.Diff(pg1Owner, got[ownerAnnotation]); diff != "" { + t.Errorf("exclusiveOwnerAnnotations() mismatch (-want +got):\n%s", diff) + } + if tc.svc == nil { + return // Don't check annotations being preserved. + } + for k, v := range tc.svc.Annotations { + if k == ownerAnnotation { + continue + } + if got[k] != v { + t.Errorf("exclusiveOwnerAnnotations() did not preserve annotation %q: got %q, want %q", k, got[k], v) + } + } + }) + } +} + +func omitPGStatusConditionMessages(p *tsapi.ProxyGroup) { + for i := range p.Status.Conditions { + // Don't bother validating the message. + p.Status.Conditions[i].Message = "" + } +} diff --git a/cmd/k8s-operator/api-server-proxy.go b/cmd/k8s-operator/api-server-proxy.go new file mode 100644 index 000000000..70333d2c4 --- /dev/null +++ b/cmd/k8s-operator/api-server-proxy.go @@ -0,0 +1,43 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "fmt" + "log" + "os" + + "tailscale.com/kube/kubetypes" + "tailscale.com/types/ptr" +) + +func parseAPIProxyMode() *kubetypes.APIServerProxyMode { + haveAuthProxyEnv := os.Getenv("AUTH_PROXY") != "" + haveAPIProxyEnv := os.Getenv("APISERVER_PROXY") != "" + switch { + case haveAPIProxyEnv && haveAuthProxyEnv: + log.Fatal("AUTH_PROXY (deprecated) and APISERVER_PROXY are mutually exclusive, please unset AUTH_PROXY") + case haveAuthProxyEnv: + var authProxyEnv = defaultBool("AUTH_PROXY", false) // deprecated + if authProxyEnv { + return ptr.To(kubetypes.APIServerProxyModeAuth) + } + return nil + case haveAPIProxyEnv: + var apiProxyEnv = defaultEnv("APISERVER_PROXY", "") // true, false or "noauth" + switch apiProxyEnv { + case "true": + return ptr.To(kubetypes.APIServerProxyModeAuth) + case "false", "": + return nil + case "noauth": + return ptr.To(kubetypes.APIServerProxyModeNoAuth) + default: + panic(fmt.Sprintf("unknown APISERVER_PROXY value %q", apiProxyEnv)) + } + } + return nil +} diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index 016166b4c..7fa311532 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -7,13 +7,14 @@ package main import ( "context" + "errors" "fmt" "net/netip" "slices" + "strings" "sync" "time" - "github.com/pkg/errors" "go.uber.org/zap" xslices "golang.org/x/exp/slices" corev1 "k8s.io/api/core/v1" @@ -34,6 +35,7 @@ import ( const ( reasonConnectorCreationFailed = "ConnectorCreationFailed" + reasonConnectorCreating = "ConnectorCreating" reasonConnectorCreated = "ConnectorCreated" reasonConnectorInvalid = "ConnectorInvalid" @@ -58,6 +60,7 @@ type ConnectorReconciler struct { subnetRouters set.Slice[types.UID] // for subnet routers gauge exitNodes set.Slice[types.UID] // for exit nodes gauge + appConnectors set.Slice[types.UID] // for app connectors gauge } var ( @@ -67,6 +70,8 @@ var ( gaugeConnectorSubnetRouterResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithSubnetRouterCount) // gaugeConnectorExitNodeResources tracks the number of Connectors currently managed by this operator instance that are exit nodes. gaugeConnectorExitNodeResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithExitNodeCount) + // gaugeConnectorAppConnectorResources tracks the number of Connectors currently managed by this operator instance that are app connectors. + gaugeConnectorAppConnectorResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithAppConnectorCount) ) func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { @@ -108,13 +113,12 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque oldCnStatus := cn.Status.DeepCopy() setStatus := func(cn *tsapi.Connector, _ tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetConnectorCondition(cn, tsapi.ConnectorReady, status, reason, message, cn.Generation, a.clock, logger) - if !apiequality.Semantic.DeepEqual(oldCnStatus, cn.Status) { + var updateErr error + if !apiequality.Semantic.DeepEqual(oldCnStatus, &cn.Status) { // An error encountered here should get returned by the Reconcile function. - if updateErr := a.Client.Status().Update(ctx, cn); updateErr != nil { - err = errors.Wrap(err, updateErr.Error()) - } + updateErr = a.Client.Status().Update(ctx, cn) } - return res, err + return res, errors.Join(err, updateErr) } if !slices.Contains(cn.Finalizers, FinalizerName) { @@ -131,17 +135,24 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque } if err := a.validate(cn); err != nil { - logger.Errorf("error validating Connector spec: %w", err) message := fmt.Sprintf(messageConnectorInvalid, err) a.recorder.Eventf(cn, corev1.EventTypeWarning, reasonConnectorInvalid, message) return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reasonConnectorInvalid, message) } if err = a.maybeProvisionConnector(ctx, logger, cn); err != nil { - logger.Errorf("error creating Connector resources: %w", err) + reason := reasonConnectorCreationFailed message := fmt.Sprintf(messageConnectorCreationFailed, err) - a.recorder.Eventf(cn, corev1.EventTypeWarning, reasonConnectorCreationFailed, message) - return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reasonConnectorCreationFailed, message) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonConnectorCreating + message = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(message) + } else { + a.recorder.Eventf(cn, corev1.EventTypeWarning, reason, message) + } + + return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reason, message) } logger.Info("Connector resources synced") @@ -150,6 +161,9 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque cn.Status.SubnetRoutes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionTrue, reasonConnectorCreated, reasonConnectorCreated) } + if cn.Spec.AppConnector != nil { + cn.Status.IsAppConnector = true + } cn.Status.SubnetRoutes = "" return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionTrue, reasonConnectorCreated, reasonConnectorCreated) } @@ -161,6 +175,7 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge if cn.Spec.Hostname != "" { hostname = string(cn.Spec.Hostname) } + crl := childResourceLabels(cn.Name, a.tsnamespace, "connector") proxyClass := cn.Spec.ProxyClass @@ -173,39 +188,65 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge } } + var replicas int32 = 1 + if cn.Spec.Replicas != nil { + replicas = *cn.Spec.Replicas + } + sts := &tailscaleSTSConfig{ + Replicas: replicas, ParentResourceName: cn.Name, ParentResourceUID: string(cn.UID), Hostname: hostname, + HostnamePrefix: string(cn.Spec.HostnamePrefix), ChildResourceLabels: crl, Tags: cn.Spec.Tags.Stringify(), Connector: &connector{ isExitNode: cn.Spec.ExitNode, }, ProxyClassName: proxyClass, + proxyType: proxyTypeConnector, + LoginServer: a.ssr.loginServer, } if cn.Spec.SubnetRouter != nil && len(cn.Spec.SubnetRouter.AdvertiseRoutes) > 0 { sts.Connector.routes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() } + if cn.Spec.AppConnector != nil { + sts.Connector.isAppConnector = true + if len(cn.Spec.AppConnector.Routes) != 0 { + sts.Connector.routes = cn.Spec.AppConnector.Routes.Stringify() + } + } + a.mu.Lock() - if sts.Connector.isExitNode { + if cn.Spec.ExitNode { a.exitNodes.Add(cn.UID) } else { a.exitNodes.Remove(cn.UID) } - if sts.Connector.routes != "" { + + if cn.Spec.SubnetRouter != nil { a.subnetRouters.Add(cn.GetUID()) } else { a.subnetRouters.Remove(cn.GetUID()) } + + if cn.Spec.AppConnector != nil { + a.appConnectors.Add(cn.GetUID()) + } else { + a.appConnectors.Remove(cn.GetUID()) + } + a.mu.Unlock() gaugeConnectorSubnetRouterResources.Set(int64(a.subnetRouters.Len())) gaugeConnectorExitNodeResources.Set(int64(a.exitNodes.Len())) + gaugeConnectorAppConnectorResources.Set(int64(a.appConnectors.Len())) var connectors set.Slice[types.UID] connectors.AddSlice(a.exitNodes.Slice()) connectors.AddSlice(a.subnetRouters.Slice()) + connectors.AddSlice(a.appConnectors.Slice()) gaugeConnectorResources.Set(int64(connectors.Len())) _, err := a.ssr.Provision(ctx, logger, sts) @@ -213,27 +254,29 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge return err } - _, tsHost, ips, err := a.ssr.DeviceInfo(ctx, crl) + devices, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { return err } - if tsHost == "" { - logger.Debugf("no Tailscale hostname known yet, waiting for connector pod to finish auth") - // No hostname yet. Wait for the connector pod to auth. - cn.Status.TailnetIPs = nil - cn.Status.Hostname = "" - return nil + cn.Status.Devices = make([]tsapi.ConnectorDevice, len(devices)) + for i, dev := range devices { + cn.Status.Devices[i] = tsapi.ConnectorDevice{ + Hostname: dev.hostname, + TailnetIPs: dev.ips, + } } - cn.Status.TailnetIPs = ips - cn.Status.Hostname = tsHost + if len(cn.Status.Devices) > 0 { + cn.Status.Hostname = cn.Status.Devices[0].Hostname + cn.Status.TailnetIPs = cn.Status.Devices[0].TailnetIPs + } return nil } func (a *ConnectorReconciler) maybeCleanupConnector(ctx context.Context, logger *zap.SugaredLogger, cn *tsapi.Connector) (bool, error) { - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(cn.Name, a.tsnamespace, "connector")); err != nil { + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(cn.Name, a.tsnamespace, "connector"), proxyTypeConnector); err != nil { return false, fmt.Errorf("failed to cleanup Connector resources: %w", err) } else if !done { logger.Debugf("Connector cleanup not done yet, waiting for next reconcile") @@ -248,12 +291,15 @@ func (a *ConnectorReconciler) maybeCleanupConnector(ctx context.Context, logger a.mu.Lock() a.subnetRouters.Remove(cn.UID) a.exitNodes.Remove(cn.UID) + a.appConnectors.Remove(cn.UID) a.mu.Unlock() gaugeConnectorExitNodeResources.Set(int64(a.exitNodes.Len())) gaugeConnectorSubnetRouterResources.Set(int64(a.subnetRouters.Len())) + gaugeConnectorAppConnectorResources.Set(int64(a.appConnectors.Len())) var connectors set.Slice[types.UID] connectors.AddSlice(a.exitNodes.Slice()) connectors.AddSlice(a.subnetRouters.Slice()) + connectors.AddSlice(a.appConnectors.Slice()) gaugeConnectorResources.Set(int64(connectors.Len())) return true, nil } @@ -262,8 +308,23 @@ func (a *ConnectorReconciler) validate(cn *tsapi.Connector) error { // Connector fields are already validated at apply time with CEL validation // on custom resource fields. The checks here are a backup in case the // CEL validation breaks without us noticing. - if !(cn.Spec.SubnetRouter != nil || cn.Spec.ExitNode) { - return errors.New("invalid spec: a Connector must expose subnet routes or act as an exit node (or both)") + if cn.Spec.SubnetRouter == nil && !cn.Spec.ExitNode && cn.Spec.AppConnector == nil { + return errors.New("invalid spec: a Connector must be configured as at least one of subnet router, exit node or app connector") + } + if (cn.Spec.SubnetRouter != nil || cn.Spec.ExitNode) && cn.Spec.AppConnector != nil { + return errors.New("invalid spec: a Connector that is configured as an app connector must not be also configured as a subnet router or exit node") + } + + // These two checks should be caught by the Connector schema validation. + if cn.Spec.Replicas != nil && *cn.Spec.Replicas > 1 && cn.Spec.Hostname != "" { + return errors.New("invalid spec: a Connector that is configured with multiple replicas cannot specify a hostname. Instead, use a hostnamePrefix") + } + if cn.Spec.HostnamePrefix != "" && cn.Spec.Hostname != "" { + return errors.New("invalid spec: a Connect cannot use both a hostname and hostname prefix") + } + + if cn.Spec.AppConnector != nil { + return validateAppConnector(cn.Spec.AppConnector) } if cn.Spec.SubnetRouter == nil { return nil @@ -272,19 +333,27 @@ func (a *ConnectorReconciler) validate(cn *tsapi.Connector) error { } func validateSubnetRouter(sb *tsapi.SubnetRouter) error { - if len(sb.AdvertiseRoutes) < 1 { + if len(sb.AdvertiseRoutes) == 0 { return errors.New("invalid subnet router spec: no routes defined") } - var err error - for _, route := range sb.AdvertiseRoutes { + return validateRoutes(sb.AdvertiseRoutes) +} + +func validateAppConnector(ac *tsapi.AppConnector) error { + return validateRoutes(ac.Routes) +} + +func validateRoutes(routes tsapi.Routes) error { + var errs []error + for _, route := range routes { pfx, e := netip.ParsePrefix(string(route)) if e != nil { - err = errors.Wrap(err, fmt.Sprintf("route %s is invalid: %v", route, err)) + errs = append(errs, fmt.Errorf("route %v is invalid: %v", route, e)) continue } if pfx.Masked() != pfx { - err = errors.Wrap(err, fmt.Sprintf("route %s has non-address bits set; expected %s", pfx, pfx.Masked())) + errs = append(errs, fmt.Errorf("route %s has non-address bits set; expected %s", pfx, pfx.Masked())) } } - return err + return errors.Join(errs...) } diff --git a/cmd/k8s-operator/connector_test.go b/cmd/k8s-operator/connector_test.go index 01c60bc9e..afc7d2d6e 100644 --- a/cmd/k8s-operator/connector_test.go +++ b/cmd/k8s-operator/connector_test.go @@ -7,17 +7,22 @@ package main import ( "context" + "strconv" + "strings" "testing" + "time" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client/fake" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tstest" + "tailscale.com/types/ptr" "tailscale.com/util/mak" ) @@ -34,6 +39,7 @@ func TestConnector(t *testing.T) { APIVersion: "tailscale.com/v1alpha1", }, Spec: tsapi.ConnectorSpec{ + Replicas: ptr.To[int32](1), SubnetRouter: &tsapi.SubnetRouter{ AdvertiseRoutes: []tsapi.Route{"10.40.0.0/14"}, }, @@ -53,7 +59,8 @@ func TestConnector(t *testing.T) { cl := tstest.NewClock(tstest.ClockOpts{}) cr := &ConnectorReconciler{ - Client: fc, + Client: fc, + recorder: record.NewFakeRecorder(10), ssr: &tailscaleSTSReconciler{ Client: fc, tsClient: ft, @@ -76,9 +83,10 @@ func TestConnector(t *testing.T) { isExitNode: true, subnetRoutes: "10.40.0.0/14", app: kubetypes.AppConnector, + replicas: cn.Spec.Replicas, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Connector status should get updated with the IP/hostname info when available. const hostname = "foo.tailnetxyz.ts.net" @@ -92,6 +100,10 @@ func TestConnector(t *testing.T) { cn.Status.IsExitNode = cn.Spec.ExitNode cn.Status.SubnetRoutes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() cn.Status.Hostname = hostname + cn.Status.Devices = []tsapi.ConnectorDevice{{ + Hostname: hostname, + TailnetIPs: []string{"127.0.0.1", "::1"}, + }} cn.Status.TailnetIPs = []string{"127.0.0.1", "::1"} expectEqual(t, fc, cn, func(o *tsapi.Connector) { o.Status.Conditions = nil @@ -104,7 +116,7 @@ func TestConnector(t *testing.T) { opts.subnetRoutes = "10.40.0.0/14,10.44.0.0/20" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Remove a route. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -112,7 +124,7 @@ func TestConnector(t *testing.T) { }) opts.subnetRoutes = "10.44.0.0/20" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Remove the subnet router. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -120,7 +132,7 @@ func TestConnector(t *testing.T) { }) opts.subnetRoutes = "" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Re-add the subnet router. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -130,7 +142,7 @@ func TestConnector(t *testing.T) { }) opts.subnetRoutes = "10.44.0.0/20" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Delete the Connector. if err = fc.Delete(context.Background(), cn); err != nil { @@ -154,6 +166,7 @@ func TestConnector(t *testing.T) { APIVersion: "tailscale.io/v1alpha1", }, Spec: tsapi.ConnectorSpec{ + Replicas: ptr.To[int32](1), SubnetRouter: &tsapi.SubnetRouter{ AdvertiseRoutes: []tsapi.Route{"10.40.0.0/14"}, }, @@ -172,9 +185,10 @@ func TestConnector(t *testing.T) { subnetRoutes: "10.40.0.0/14", hostname: "test-connector", app: kubetypes.AppConnector, + replicas: cn.Spec.Replicas, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Add an exit node. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -182,7 +196,7 @@ func TestConnector(t *testing.T) { }) opts.isExitNode = true expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // Delete the Connector. if err = fc.Delete(context.Background(), cn); err != nil { @@ -201,7 +215,7 @@ func TestConnectorWithProxyClass(t *testing.T) { pc := &tsapi.ProxyClass{ ObjectMeta: metav1.ObjectMeta{Name: "custom-metadata"}, Spec: tsapi.ProxyClassSpec{StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"bar.io/foo": "some-val"}, Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}}}, } @@ -215,9 +229,11 @@ func TestConnectorWithProxyClass(t *testing.T) { APIVersion: "tailscale.io/v1alpha1", }, Spec: tsapi.ConnectorSpec{ + Replicas: ptr.To[int32](1), SubnetRouter: &tsapi.SubnetRouter{ AdvertiseRoutes: []tsapi.Route{"10.40.0.0/14"}, }, + ExitNode: true, }, } @@ -258,9 +274,10 @@ func TestConnectorWithProxyClass(t *testing.T) { isExitNode: true, subnetRoutes: "10.40.0.0/14", app: kubetypes.AppConnector, + replicas: cn.Spec.Replicas, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 2. Update Connector to specify a ProxyClass. ProxyClass is not yet // ready, so its configuration is NOT applied to the Connector @@ -269,7 +286,7 @@ func TestConnectorWithProxyClass(t *testing.T) { conn.Spec.ProxyClass = "custom-metadata" }) expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 3. ProxyClass is set to Ready by proxy-class reconciler. Connector // get reconciled and configuration from the ProxyClass is applied to @@ -278,13 +295,13 @@ func TestConnectorWithProxyClass(t *testing.T) { pc.Status = tsapi.ProxyClassStatus{ Conditions: []metav1.Condition{{ Status: metav1.ConditionTrue, - Type: string(tsapi.ProxyClassready), + Type: string(tsapi.ProxyClassReady), ObservedGeneration: pc.Generation, }}} }) opts.proxyClass = pc.Name expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 4. Connector.spec.proxyClass field is unset, Connector gets // reconciled and configuration from the ProxyClass is removed from the @@ -294,5 +311,196 @@ func TestConnectorWithProxyClass(t *testing.T) { }) opts.proxyClass = "" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) +} + +func TestConnectorWithAppConnector(t *testing.T) { + // Setup + cn := &tsapi.Connector{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + UID: types.UID("1234-UID"), + }, + TypeMeta: metav1.TypeMeta{ + Kind: tsapi.ConnectorKind, + APIVersion: "tailscale.io/v1alpha1", + }, + Spec: tsapi.ConnectorSpec{ + Replicas: ptr.To[int32](1), + AppConnector: &tsapi.AppConnector{}, + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(cn). + WithStatusSubresource(cn). + Build() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + fr := record.NewFakeRecorder(1) + cr := &ConnectorReconciler{ + Client: fc, + clock: cl, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + recorder: fr, + } + + // 1. Connector with app connector is created and becomes ready + expectReconciled(t, cr, "", "test") + fullName, shortName := findGenName(t, fc, "", "test", "connector") + opts := configOpts{ + stsName: shortName, + secretName: fullName, + parentType: "connector", + hostname: "test-connector", + app: kubetypes.AppConnector, + isAppConnector: true, + replicas: cn.Spec.Replicas, + } + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) + // Connector's ready condition should be set to true + + cn.ObjectMeta.Finalizers = append(cn.ObjectMeta.Finalizers, "tailscale.com/finalizer") + cn.Status.IsAppConnector = true + cn.Status.Devices = []tsapi.ConnectorDevice{} + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorCreated, + Message: reasonConnectorCreated, + }} + expectEqual(t, fc, cn) + + // 2. Connector with invalid app connector routes has status set to invalid + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.AppConnector.Routes = tsapi.Routes{"1.2.3.4/5"} + }) + cn.Spec.AppConnector.Routes = tsapi.Routes{"1.2.3.4/5"} + expectReconciled(t, cr, "", "test") + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionFalse, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorInvalid, + Message: "Connector is invalid: route 1.2.3.4/5 has non-address bits set; expected 0.0.0.0/5", + }} + expectEqual(t, fc, cn) + + // 3. Connector with valid app connnector routes becomes ready + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.AppConnector.Routes = tsapi.Routes{"10.88.2.21/32"} + }) + cn.Spec.AppConnector.Routes = tsapi.Routes{"10.88.2.21/32"} + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorCreated, + Message: reasonConnectorCreated, + }} + expectReconciled(t, cr, "", "test") +} + +func TestConnectorWithMultipleReplicas(t *testing.T) { + cn := &tsapi.Connector{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + UID: types.UID("1234-UID"), + }, + TypeMeta: metav1.TypeMeta{ + Kind: tsapi.ConnectorKind, + APIVersion: "tailscale.io/v1alpha1", + }, + Spec: tsapi.ConnectorSpec{ + Replicas: ptr.To[int32](3), + AppConnector: &tsapi.AppConnector{}, + HostnamePrefix: "test-connector", + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(cn). + WithStatusSubresource(cn). + Build() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + fr := record.NewFakeRecorder(1) + cr := &ConnectorReconciler{ + Client: fc, + clock: cl, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + recorder: fr, + } + + // 1. Ensure that our connector resource is reconciled. + expectReconciled(t, cr, "", "test") + + // 2. Ensure we have a number of secrets matching the number of replicas. + names := findGenNames(t, fc, "", "test", "connector") + if int32(len(names)) != *cn.Spec.Replicas { + t.Fatalf("expected %d secrets, got %d", *cn.Spec.Replicas, len(names)) + } + + // 3. Ensure each device has the correct hostname prefix and ordinal suffix. + for i, name := range names { + expected := expectedSecret(t, fc, configOpts{ + secretName: name, + hostname: string(cn.Spec.HostnamePrefix) + "-" + strconv.Itoa(i), + isAppConnector: true, + parentType: "connector", + namespace: cr.tsnamespace, + }) + + expectEqual(t, fc, expected) + } + + // 4. Ensure the generated stateful set has the matching number of replicas + shortName := strings.TrimSuffix(names[0], "-0") + + var sts appsv1.StatefulSet + if err = fc.Get(t.Context(), types.NamespacedName{Namespace: "operator-ns", Name: shortName}, &sts); err != nil { + t.Fatalf("failed to get StatefulSet %q: %v", shortName, err) + } + + if sts.Spec.Replicas == nil { + t.Fatalf("actual StatefulSet %q does not have replicas set", shortName) + } + + if *sts.Spec.Replicas != *cn.Spec.Replicas { + t.Fatalf("expected %d replicas, got %d", *cn.Spec.Replicas, *sts.Spec.Replicas) + } + + // 5. We'll scale the connector down by 1 replica and make sure its secret is cleaned up + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.Replicas = ptr.To[int32](2) + }) + expectReconciled(t, cr, "", "test") + names = findGenNames(t, fc, "", "test", "connector") + if len(names) != 2 { + t.Fatalf("expected 2 secrets, got %d", len(names)) + } } diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 649296b59..16ad089f3 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -5,101 +5,30 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - L github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - L github.com/aws/aws-sdk-go-v2/aws/arn from tailscale.com/ipn/store/awsstore - L github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - L github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - L github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - L github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - L github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - L github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - L github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - L github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - L github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - L github.com/aws/aws-sdk-go-v2/config from tailscale.com/ipn/store/awsstore - L github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - L github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - L github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - L github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - L github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints+ - L github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - L github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - L github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - L github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - L github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - L github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - L github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - L github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - L github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - L github.com/aws/aws-sdk-go-v2/service/ssm from tailscale.com/ipn/store/awsstore - L github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm - L github.com/aws/aws-sdk-go-v2/service/ssm/types from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - L github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - L github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - L github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - L github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - L github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - L github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - L github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - L github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - L github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - L github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - L github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - L github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - L github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - L github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - L github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - L github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - L github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - L github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - L github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - L github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - L github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - L github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware+ - L github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/ssm+ - L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ - L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http - L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus - github.com/bits-and-blooms/bitset from github.com/gaissmai/bart + github.com/blang/semver/v4 from k8s.io/component-base/metrics đŸ’Ŗ github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus - github.com/coder/websocket from tailscale.com/control/controlhttp+ + github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket - L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw + github.com/creachadair/msync/trigger from tailscale.com/logtail đŸ’Ŗ github.com/davecgh/go-spew/spew from k8s.io/apimachinery/pkg/util/dump - W đŸ’Ŗ github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ + W đŸ’Ŗ github.com/dblohm7/wingoes from tailscale.com/net/tshttpproxy+ W đŸ’Ŗ github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ W đŸ’Ŗ github.com/dblohm7/wingoes/com/automation from tailscale.com/util/osdiag/internal/wsc W github.com/dblohm7/wingoes/internal from github.com/dblohm7/wingoes/com W đŸ’Ŗ github.com/dblohm7/wingoes/pe from tailscale.com/util/osdiag+ - LW đŸ’Ŗ github.com/digitalocean/go-smbios/smbios from tailscale.com/posture github.com/distribution/reference from tailscale.com/cmd/k8s-operator github.com/emicklei/go-restful/v3 from k8s.io/kube-openapi/pkg/common github.com/emicklei/go-restful/v3/log from github.com/emicklei/go-restful/v3 github.com/evanphx/json-patch/v5 from sigs.k8s.io/controller-runtime/pkg/client github.com/evanphx/json-patch/v5/internal/json from github.com/evanphx/json-patch/v5 đŸ’Ŗ github.com/fsnotify/fsnotify from sigs.k8s.io/controller-runtime/pkg/certwatcher - github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/fxamacker/cbor/v2 from tailscale.com/tka+ github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+ @@ -109,16 +38,14 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/go-logr/logr from github.com/go-logr/logr/slogr+ github.com/go-logr/logr/slogr from github.com/go-logr/zapr github.com/go-logr/zapr from sigs.k8s.io/controller-runtime/pkg/log/zap+ - W đŸ’Ŗ github.com/go-ole/go-ole from github.com/go-ole/go-ole/oleutil+ - W đŸ’Ŗ github.com/go-ole/go-ole/oleutil from tailscale.com/wgengine/winnet github.com/go-openapi/jsonpointer from github.com/go-openapi/jsonreference github.com/go-openapi/jsonreference from k8s.io/kube-openapi/pkg/internal+ github.com/go-openapi/jsonreference/internal from github.com/go-openapi/jsonreference - github.com/go-openapi/swag from github.com/go-openapi/jsonpointer+ + đŸ’Ŗ github.com/go-openapi/swag from github.com/go-openapi/jsonpointer+ L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns đŸ’Ŗ github.com/gogo/protobuf/proto from k8s.io/api/admission/v1+ github.com/gogo/protobuf/sortkeys from k8s.io/api/admission/v1+ - github.com/golang/groupcache/lru from k8s.io/client-go/tools/record+ + github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/golang/protobuf/proto from k8s.io/client-go/discovery+ github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ github.com/google/gnostic-models/compiler from github.com/google/gnostic-models/openapiv2+ @@ -133,25 +60,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ đŸ’Ŗ github.com/google/go-cmp/cmp/internal/value from github.com/google/go-cmp/cmp github.com/google/gofuzz from k8s.io/apimachinery/pkg/apis/meta/v1+ github.com/google/gofuzz/bytesource from github.com/google/gofuzz - L github.com/google/nftables from tailscale.com/util/linuxfw - L đŸ’Ŗ github.com/google/nftables/alignedbuff from github.com/google/nftables/xt - L đŸ’Ŗ github.com/google/nftables/binaryutil from github.com/google/nftables+ - L github.com/google/nftables/expr from github.com/google/nftables+ - L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ - L github.com/google/nftables/xt from github.com/google/nftables/expr+ github.com/google/uuid from github.com/prometheus-community/pro-bing+ - github.com/gorilla/csrf from tailscale.com/client/web - github.com/gorilla/securecookie from github.com/gorilla/csrf - github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ - L đŸ’Ŗ github.com/illarion/gonotify/v2 from tailscale.com/net/dns - github.com/imdario/mergo from k8s.io/client-go/tools/clientcmd - L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/net/tstun - L github.com/insomniacslk/dhcp/iana from github.com/insomniacslk/dhcp/dhcpv4 - L github.com/insomniacslk/dhcp/interfaces from github.com/insomniacslk/dhcp/dhcpv4 - L github.com/insomniacslk/dhcp/rfc1035label from github.com/insomniacslk/dhcp/dhcpv4 - L github.com/jmespath/go-jmespath from github.com/aws/aws-sdk-go-v2/service/ssm + github.com/hdevalence/ed25519consensus from tailscale.com/tka + W đŸ’Ŗ github.com/inconshreveable/mousetrap from github.com/spf13/cobra github.com/josharian/intern from github.com/mailru/easyjson/jlexer - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink đŸ’Ŗ github.com/json-iterator/go from sigs.k8s.io/structured-merge-diff/v4/fieldpath+ @@ -160,64 +72,51 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd - github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe+ github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd - github.com/kortschak/wol from tailscale.com/ipn/ipnlocal github.com/mailru/easyjson/buffer from github.com/mailru/easyjson/jwriter đŸ’Ŗ github.com/mailru/easyjson/jlexer from github.com/go-openapi/swag github.com/mailru/easyjson/jwriter from github.com/go-openapi/swag - L github.com/mdlayher/genetlink from tailscale.com/net/tstun - L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ + L đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ - L github.com/mdlayher/netlink/nltest from github.com/google/nftables - L github.com/mdlayher/sdnotify from tailscale.com/util/systemd L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ - github.com/miekg/dns from tailscale.com/net/dns/recursive đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/modern-go/concurrent from github.com/json-iterator/go đŸ’Ŗ github.com/modern-go/reflect2 from github.com/json-iterator/go - github.com/munnerz/goautoneg from k8s.io/kube-openapi/pkg/handler3 + github.com/munnerz/goautoneg from k8s.io/kube-openapi/pkg/handler3+ github.com/opencontainers/go-digest from github.com/distribution/reference - L github.com/pierrec/lz4/v4 from github.com/u-root/uio/uio - L github.com/pierrec/lz4/v4/internal/lz4block from github.com/pierrec/lz4/v4+ - L github.com/pierrec/lz4/v4/internal/lz4errors from github.com/pierrec/lz4/v4+ - L github.com/pierrec/lz4/v4/internal/lz4stream from github.com/pierrec/lz4/v4 - L github.com/pierrec/lz4/v4/internal/xxh32 from github.com/pierrec/lz4/v4/internal/lz4stream + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal github.com/pkg/errors from github.com/evanphx/json-patch/v5+ D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack + github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil from github.com/prometheus/client_golang/prometheus/promhttp + github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil/header from github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil đŸ’Ŗ github.com/prometheus/client_golang/prometheus from github.com/prometheus/client_golang/prometheus/collectors+ - github.com/prometheus/client_golang/prometheus/collectors from sigs.k8s.io/controller-runtime/pkg/internal/controller/metrics + github.com/prometheus/client_golang/prometheus/collectors from sigs.k8s.io/controller-runtime/pkg/internal/controller/metrics+ github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/client_golang/prometheus/promhttp from sigs.k8s.io/controller-runtime/pkg/metrics/server+ github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ - github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ - LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus + LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus+ LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs - L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/doctor/ethtool+ - github.com/spf13/pflag from k8s.io/client-go/tools/clientcmd + L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf + github.com/spf13/cobra from k8s.io/component-base/cli/flag + github.com/spf13/pflag from k8s.io/client-go/tools/clientcmd+ W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ - github.com/tailscale/golang-x-crypto/acme from tailscale.com/ipn/ipnlocal - LD github.com/tailscale/golang-x-crypto/internal/poly1305 from github.com/tailscale/golang-x-crypto/ssh - LD github.com/tailscale/golang-x-crypto/ssh from tailscale.com/ipn/ipnlocal - LD github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf from github.com/tailscale/golang-x-crypto/ssh github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ github.com/tailscale/goupnp/scpd from github.com/tailscale/goupnp github.com/tailscale/goupnp/soap from github.com/tailscale/goupnp+ github.com/tailscale/goupnp/ssdp from github.com/tailscale/goupnp - github.com/tailscale/hujson from tailscale.com/ipn/conffile - L đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/net/routetable+ - L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink - github.com/tailscale/peercred from tailscale.com/ipn/ipnauth + github.com/tailscale/hujson from tailscale.com/ipn/conffile+ + LD github.com/tailscale/peercred from tailscale.com/ipn/ipnauth github.com/tailscale/web-client-prebuilt from tailscale.com/client/web đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ W đŸ’Ŗ github.com/tailscale/wireguard-go/conn/winrio from github.com/tailscale/wireguard-go/conn @@ -229,11 +128,13 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ - github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck - L github.com/u-root/uio/rand from github.com/insomniacslk/dhcp/dhcpv4 - L github.com/u-root/uio/uio from github.com/insomniacslk/dhcp/dhcpv4+ - L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 + go.opentelemetry.io/otel/attribute from go.opentelemetry.io/otel/trace + go.opentelemetry.io/otel/codes from go.opentelemetry.io/otel/trace + đŸ’Ŗ go.opentelemetry.io/otel/internal from go.opentelemetry.io/otel/attribute + go.opentelemetry.io/otel/internal/attribute from go.opentelemetry.io/otel/attribute + go.opentelemetry.io/otel/trace from k8s.io/component-base/metrics + go.opentelemetry.io/otel/trace/embedded from go.opentelemetry.io/otel/trace go.uber.org/multierr from go.uber.org/zap+ go.uber.org/zap from github.com/go-logr/zapr+ go.uber.org/zap/buffer from go.uber.org/zap/internal/bufferpool+ @@ -244,7 +145,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ go.uber.org/zap/internal/pool from go.uber.org/zap+ go.uber.org/zap/internal/stacktrace from go.uber.org/zap go.uber.org/zap/zapcore from github.com/go-logr/zapr+ - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from tailscale.com/ipn/ipnlocal+ W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/dns+ @@ -256,6 +157,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ google.golang.org/protobuf/internal/descopts from google.golang.org/protobuf/internal/filedesc+ google.golang.org/protobuf/internal/detrand from google.golang.org/protobuf/internal/descfmt+ google.golang.org/protobuf/internal/editiondefaults from google.golang.org/protobuf/internal/filedesc+ + google.golang.org/protobuf/internal/editionssupport from google.golang.org/protobuf/reflect/protodesc google.golang.org/protobuf/internal/encoding/defval from google.golang.org/protobuf/internal/encoding/tag+ google.golang.org/protobuf/internal/encoding/messageset from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/encoding/tag from google.golang.org/protobuf/internal/impl @@ -268,6 +170,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ đŸ’Ŗ google.golang.org/protobuf/internal/impl from google.golang.org/protobuf/internal/filetype+ google.golang.org/protobuf/internal/order from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/pragma from google.golang.org/protobuf/encoding/prototext+ + đŸ’Ŗ google.golang.org/protobuf/internal/protolazy from google.golang.org/protobuf/internal/impl+ google.golang.org/protobuf/internal/set from google.golang.org/protobuf/encoding/prototext đŸ’Ŗ google.golang.org/protobuf/internal/strs from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/version from google.golang.org/protobuf/runtime/protoimpl @@ -281,8 +184,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ google.golang.org/protobuf/types/gofeaturespb from google.golang.org/protobuf/reflect/protodesc google.golang.org/protobuf/types/known/anypb from github.com/google/gnostic-models/compiler+ google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + gopkg.in/evanphx/json-patch.v4 from k8s.io/client-go/testing gopkg.in/inf.v0 from k8s.io/apimachinery/pkg/api/resource - gopkg.in/yaml.v2 from k8s.io/kube-openapi/pkg/util/proto+ gopkg.in/yaml.v3 from github.com/go-openapi/swag+ gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer @@ -304,12 +207,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/net/tstun+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/wgengine/netstack gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ @@ -351,6 +254,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/certificates/v1alpha1 from k8s.io/client-go/applyconfigurations/certificates/v1alpha1+ k8s.io/api/certificates/v1beta1 from k8s.io/client-go/applyconfigurations/certificates/v1beta1+ k8s.io/api/coordination/v1 from k8s.io/client-go/applyconfigurations/coordination/v1+ + k8s.io/api/coordination/v1alpha2 from k8s.io/client-go/applyconfigurations/coordination/v1alpha2+ k8s.io/api/coordination/v1beta1 from k8s.io/client-go/applyconfigurations/coordination/v1beta1+ k8s.io/api/core/v1 from k8s.io/api/apps/v1+ k8s.io/api/discovery/v1 from k8s.io/client-go/applyconfigurations/discovery/v1+ @@ -373,7 +277,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/rbac/v1 from k8s.io/client-go/applyconfigurations/rbac/v1+ k8s.io/api/rbac/v1alpha1 from k8s.io/client-go/applyconfigurations/rbac/v1alpha1+ k8s.io/api/rbac/v1beta1 from k8s.io/client-go/applyconfigurations/rbac/v1beta1+ - k8s.io/api/resource/v1alpha2 from k8s.io/client-go/applyconfigurations/resource/v1alpha2+ + k8s.io/api/resource/v1alpha3 from k8s.io/client-go/applyconfigurations/resource/v1alpha3+ + k8s.io/api/resource/v1beta1 from k8s.io/client-go/applyconfigurations/resource/v1beta1+ k8s.io/api/scheduling/v1 from k8s.io/client-go/applyconfigurations/scheduling/v1+ k8s.io/api/scheduling/v1alpha1 from k8s.io/client-go/applyconfigurations/scheduling/v1alpha1+ k8s.io/api/scheduling/v1beta1 from k8s.io/client-go/applyconfigurations/scheduling/v1beta1+ @@ -382,14 +287,17 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/storage/v1beta1 from k8s.io/client-go/applyconfigurations/storage/v1beta1+ k8s.io/api/storagemigration/v1alpha1 from k8s.io/client-go/applyconfigurations/storagemigration/v1alpha1+ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 - đŸ’Ŗ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 from sigs.k8s.io/controller-runtime/pkg/webhook/conversion + đŸ’Ŗ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 from sigs.k8s.io/controller-runtime/pkg/webhook/conversion+ k8s.io/apimachinery/pkg/api/equality from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ k8s.io/apimachinery/pkg/api/errors from k8s.io/apimachinery/pkg/util/managedfields/internal+ k8s.io/apimachinery/pkg/api/meta from k8s.io/apimachinery/pkg/api/validation+ + k8s.io/apimachinery/pkg/api/meta/testrestmapper from k8s.io/client-go/testing k8s.io/apimachinery/pkg/api/resource from k8s.io/api/autoscaling/v1+ k8s.io/apimachinery/pkg/api/validation from k8s.io/apimachinery/pkg/util/managedfields/internal+ + k8s.io/apimachinery/pkg/api/validation/path from k8s.io/apiserver/pkg/endpoints/request đŸ’Ŗ k8s.io/apimachinery/pkg/apis/meta/internalversion from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ - k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme from k8s.io/client-go/metadata + k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme from k8s.io/client-go/metadata+ + k8s.io/apimachinery/pkg/apis/meta/internalversion/validation from k8s.io/client-go/util/watchlist đŸ’Ŗ k8s.io/apimachinery/pkg/apis/meta/v1 from k8s.io/api/admission/v1+ k8s.io/apimachinery/pkg/apis/meta/v1/unstructured from k8s.io/apimachinery/pkg/runtime/serializer/versioning+ k8s.io/apimachinery/pkg/apis/meta/v1/validation from k8s.io/apimachinery/pkg/api/validation+ @@ -401,6 +309,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/apimachinery/pkg/runtime from k8s.io/api/admission/v1+ k8s.io/apimachinery/pkg/runtime/schema from k8s.io/api/admission/v1+ k8s.io/apimachinery/pkg/runtime/serializer from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ + k8s.io/apimachinery/pkg/runtime/serializer/cbor from k8s.io/client-go/dynamic+ + k8s.io/apimachinery/pkg/runtime/serializer/cbor/direct from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ + k8s.io/apimachinery/pkg/runtime/serializer/cbor/internal/modes from k8s.io/apimachinery/pkg/runtime/serializer/cbor+ k8s.io/apimachinery/pkg/runtime/serializer/json from k8s.io/apimachinery/pkg/runtime/serializer+ k8s.io/apimachinery/pkg/runtime/serializer/protobuf from k8s.io/apimachinery/pkg/runtime/serializer k8s.io/apimachinery/pkg/runtime/serializer/recognizer from k8s.io/apimachinery/pkg/runtime/serializer+ @@ -428,13 +339,18 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/apimachinery/pkg/util/uuid from sigs.k8s.io/controller-runtime/pkg/internal/controller+ k8s.io/apimachinery/pkg/util/validation from k8s.io/apimachinery/pkg/api/validation+ k8s.io/apimachinery/pkg/util/validation/field from k8s.io/apimachinery/pkg/api/errors+ + k8s.io/apimachinery/pkg/util/version from k8s.io/apiserver/pkg/features+ k8s.io/apimachinery/pkg/util/wait from k8s.io/client-go/tools/cache+ k8s.io/apimachinery/pkg/util/yaml from k8s.io/apimachinery/pkg/runtime/serializer/json k8s.io/apimachinery/pkg/version from k8s.io/client-go/discovery+ k8s.io/apimachinery/pkg/watch from k8s.io/apimachinery/pkg/apis/meta/v1+ k8s.io/apimachinery/third_party/forked/golang/json from k8s.io/apimachinery/pkg/util/strategicpatch k8s.io/apimachinery/third_party/forked/golang/reflect from k8s.io/apimachinery/pkg/conversion + k8s.io/apiserver/pkg/authentication/user from k8s.io/apiserver/pkg/endpoints/request + k8s.io/apiserver/pkg/endpoints/request from tailscale.com/k8s-operator/api-proxy + k8s.io/apiserver/pkg/features from k8s.io/apiserver/pkg/endpoints/request k8s.io/apiserver/pkg/storage/names from tailscale.com/cmd/k8s-operator + k8s.io/apiserver/pkg/util/feature from k8s.io/apiserver/pkg/endpoints/request+ k8s.io/client-go/applyconfigurations/admissionregistration/v1 from k8s.io/client-go/applyconfigurations/admissionregistration/v1alpha1+ k8s.io/client-go/applyconfigurations/admissionregistration/v1alpha1 from k8s.io/client-go/kubernetes/typed/admissionregistration/v1alpha1 k8s.io/client-go/applyconfigurations/admissionregistration/v1beta1 from k8s.io/client-go/kubernetes/typed/admissionregistration/v1beta1 @@ -452,6 +368,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/applyconfigurations/certificates/v1alpha1 from k8s.io/client-go/kubernetes/typed/certificates/v1alpha1 k8s.io/client-go/applyconfigurations/certificates/v1beta1 from k8s.io/client-go/kubernetes/typed/certificates/v1beta1 k8s.io/client-go/applyconfigurations/coordination/v1 from k8s.io/client-go/kubernetes/typed/coordination/v1 + k8s.io/client-go/applyconfigurations/coordination/v1alpha2 from k8s.io/client-go/kubernetes/typed/coordination/v1alpha2 k8s.io/client-go/applyconfigurations/coordination/v1beta1 from k8s.io/client-go/kubernetes/typed/coordination/v1beta1 k8s.io/client-go/applyconfigurations/core/v1 from k8s.io/client-go/applyconfigurations/apps/v1+ k8s.io/client-go/applyconfigurations/discovery/v1 from k8s.io/client-go/kubernetes/typed/discovery/v1 @@ -476,7 +393,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/applyconfigurations/rbac/v1 from k8s.io/client-go/kubernetes/typed/rbac/v1 k8s.io/client-go/applyconfigurations/rbac/v1alpha1 from k8s.io/client-go/kubernetes/typed/rbac/v1alpha1 k8s.io/client-go/applyconfigurations/rbac/v1beta1 from k8s.io/client-go/kubernetes/typed/rbac/v1beta1 - k8s.io/client-go/applyconfigurations/resource/v1alpha2 from k8s.io/client-go/kubernetes/typed/resource/v1alpha2 + k8s.io/client-go/applyconfigurations/resource/v1alpha3 from k8s.io/client-go/kubernetes/typed/resource/v1alpha3 + k8s.io/client-go/applyconfigurations/resource/v1beta1 from k8s.io/client-go/kubernetes/typed/resource/v1beta1 k8s.io/client-go/applyconfigurations/scheduling/v1 from k8s.io/client-go/kubernetes/typed/scheduling/v1 k8s.io/client-go/applyconfigurations/scheduling/v1alpha1 from k8s.io/client-go/kubernetes/typed/scheduling/v1alpha1 k8s.io/client-go/applyconfigurations/scheduling/v1beta1 from k8s.io/client-go/kubernetes/typed/scheduling/v1beta1 @@ -486,8 +404,80 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/applyconfigurations/storagemigration/v1alpha1 from k8s.io/client-go/kubernetes/typed/storagemigration/v1alpha1 k8s.io/client-go/discovery from k8s.io/client-go/applyconfigurations/meta/v1+ k8s.io/client-go/dynamic from sigs.k8s.io/controller-runtime/pkg/cache/internal+ - k8s.io/client-go/features from k8s.io/client-go/tools/cache - k8s.io/client-go/kubernetes from k8s.io/client-go/tools/leaderelection/resourcelock + k8s.io/client-go/features from k8s.io/client-go/tools/cache+ + k8s.io/client-go/gentype from k8s.io/client-go/kubernetes/typed/admissionregistration/v1+ + k8s.io/client-go/informers from k8s.io/client-go/tools/leaderelection + k8s.io/client-go/informers/admissionregistration from k8s.io/client-go/informers + k8s.io/client-go/informers/admissionregistration/v1 from k8s.io/client-go/informers/admissionregistration + k8s.io/client-go/informers/admissionregistration/v1alpha1 from k8s.io/client-go/informers/admissionregistration + k8s.io/client-go/informers/admissionregistration/v1beta1 from k8s.io/client-go/informers/admissionregistration + k8s.io/client-go/informers/apiserverinternal from k8s.io/client-go/informers + k8s.io/client-go/informers/apiserverinternal/v1alpha1 from k8s.io/client-go/informers/apiserverinternal + k8s.io/client-go/informers/apps from k8s.io/client-go/informers + k8s.io/client-go/informers/apps/v1 from k8s.io/client-go/informers/apps + k8s.io/client-go/informers/apps/v1beta1 from k8s.io/client-go/informers/apps + k8s.io/client-go/informers/apps/v1beta2 from k8s.io/client-go/informers/apps + k8s.io/client-go/informers/autoscaling from k8s.io/client-go/informers + k8s.io/client-go/informers/autoscaling/v1 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/autoscaling/v2 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/autoscaling/v2beta1 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/autoscaling/v2beta2 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/batch from k8s.io/client-go/informers + k8s.io/client-go/informers/batch/v1 from k8s.io/client-go/informers/batch + k8s.io/client-go/informers/batch/v1beta1 from k8s.io/client-go/informers/batch + k8s.io/client-go/informers/certificates from k8s.io/client-go/informers + k8s.io/client-go/informers/certificates/v1 from k8s.io/client-go/informers/certificates + k8s.io/client-go/informers/certificates/v1alpha1 from k8s.io/client-go/informers/certificates + k8s.io/client-go/informers/certificates/v1beta1 from k8s.io/client-go/informers/certificates + k8s.io/client-go/informers/coordination from k8s.io/client-go/informers + k8s.io/client-go/informers/coordination/v1 from k8s.io/client-go/informers/coordination + k8s.io/client-go/informers/coordination/v1alpha2 from k8s.io/client-go/informers/coordination + k8s.io/client-go/informers/coordination/v1beta1 from k8s.io/client-go/informers/coordination + k8s.io/client-go/informers/core from k8s.io/client-go/informers + k8s.io/client-go/informers/core/v1 from k8s.io/client-go/informers/core + k8s.io/client-go/informers/discovery from k8s.io/client-go/informers + k8s.io/client-go/informers/discovery/v1 from k8s.io/client-go/informers/discovery + k8s.io/client-go/informers/discovery/v1beta1 from k8s.io/client-go/informers/discovery + k8s.io/client-go/informers/events from k8s.io/client-go/informers + k8s.io/client-go/informers/events/v1 from k8s.io/client-go/informers/events + k8s.io/client-go/informers/events/v1beta1 from k8s.io/client-go/informers/events + k8s.io/client-go/informers/extensions from k8s.io/client-go/informers + k8s.io/client-go/informers/extensions/v1beta1 from k8s.io/client-go/informers/extensions + k8s.io/client-go/informers/flowcontrol from k8s.io/client-go/informers + k8s.io/client-go/informers/flowcontrol/v1 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/flowcontrol/v1beta1 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/flowcontrol/v1beta2 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/flowcontrol/v1beta3 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/internalinterfaces from k8s.io/client-go/informers+ + k8s.io/client-go/informers/networking from k8s.io/client-go/informers + k8s.io/client-go/informers/networking/v1 from k8s.io/client-go/informers/networking + k8s.io/client-go/informers/networking/v1alpha1 from k8s.io/client-go/informers/networking + k8s.io/client-go/informers/networking/v1beta1 from k8s.io/client-go/informers/networking + k8s.io/client-go/informers/node from k8s.io/client-go/informers + k8s.io/client-go/informers/node/v1 from k8s.io/client-go/informers/node + k8s.io/client-go/informers/node/v1alpha1 from k8s.io/client-go/informers/node + k8s.io/client-go/informers/node/v1beta1 from k8s.io/client-go/informers/node + k8s.io/client-go/informers/policy from k8s.io/client-go/informers + k8s.io/client-go/informers/policy/v1 from k8s.io/client-go/informers/policy + k8s.io/client-go/informers/policy/v1beta1 from k8s.io/client-go/informers/policy + k8s.io/client-go/informers/rbac from k8s.io/client-go/informers + k8s.io/client-go/informers/rbac/v1 from k8s.io/client-go/informers/rbac + k8s.io/client-go/informers/rbac/v1alpha1 from k8s.io/client-go/informers/rbac + k8s.io/client-go/informers/rbac/v1beta1 from k8s.io/client-go/informers/rbac + k8s.io/client-go/informers/resource from k8s.io/client-go/informers + k8s.io/client-go/informers/resource/v1alpha3 from k8s.io/client-go/informers/resource + k8s.io/client-go/informers/resource/v1beta1 from k8s.io/client-go/informers/resource + k8s.io/client-go/informers/scheduling from k8s.io/client-go/informers + k8s.io/client-go/informers/scheduling/v1 from k8s.io/client-go/informers/scheduling + k8s.io/client-go/informers/scheduling/v1alpha1 from k8s.io/client-go/informers/scheduling + k8s.io/client-go/informers/scheduling/v1beta1 from k8s.io/client-go/informers/scheduling + k8s.io/client-go/informers/storage from k8s.io/client-go/informers + k8s.io/client-go/informers/storage/v1 from k8s.io/client-go/informers/storage + k8s.io/client-go/informers/storage/v1alpha1 from k8s.io/client-go/informers/storage + k8s.io/client-go/informers/storage/v1beta1 from k8s.io/client-go/informers/storage + k8s.io/client-go/informers/storagemigration from k8s.io/client-go/informers + k8s.io/client-go/informers/storagemigration/v1alpha1 from k8s.io/client-go/informers/storagemigration + k8s.io/client-go/kubernetes from k8s.io/client-go/tools/leaderelection/resourcelock+ k8s.io/client-go/kubernetes/scheme from k8s.io/client-go/discovery+ k8s.io/client-go/kubernetes/typed/admissionregistration/v1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/admissionregistration/v1alpha1 from k8s.io/client-go/kubernetes @@ -511,6 +501,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/kubernetes/typed/certificates/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/certificates/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/coordination/v1 from k8s.io/client-go/kubernetes+ + k8s.io/client-go/kubernetes/typed/coordination/v1alpha2 from k8s.io/client-go/kubernetes+ k8s.io/client-go/kubernetes/typed/coordination/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/core/v1 from k8s.io/client-go/kubernetes+ k8s.io/client-go/kubernetes/typed/discovery/v1 from k8s.io/client-go/kubernetes @@ -533,7 +524,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/kubernetes/typed/rbac/v1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/rbac/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/rbac/v1beta1 from k8s.io/client-go/kubernetes - k8s.io/client-go/kubernetes/typed/resource/v1alpha2 from k8s.io/client-go/kubernetes + k8s.io/client-go/kubernetes/typed/resource/v1alpha3 from k8s.io/client-go/kubernetes + k8s.io/client-go/kubernetes/typed/resource/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/scheduling/v1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/scheduling/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/scheduling/v1beta1 from k8s.io/client-go/kubernetes @@ -541,6 +533,56 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/kubernetes/typed/storage/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/storage/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/storagemigration/v1alpha1 from k8s.io/client-go/kubernetes + k8s.io/client-go/listers from k8s.io/client-go/listers/admissionregistration/v1+ + k8s.io/client-go/listers/admissionregistration/v1 from k8s.io/client-go/informers/admissionregistration/v1 + k8s.io/client-go/listers/admissionregistration/v1alpha1 from k8s.io/client-go/informers/admissionregistration/v1alpha1 + k8s.io/client-go/listers/admissionregistration/v1beta1 from k8s.io/client-go/informers/admissionregistration/v1beta1 + k8s.io/client-go/listers/apiserverinternal/v1alpha1 from k8s.io/client-go/informers/apiserverinternal/v1alpha1 + k8s.io/client-go/listers/apps/v1 from k8s.io/client-go/informers/apps/v1 + k8s.io/client-go/listers/apps/v1beta1 from k8s.io/client-go/informers/apps/v1beta1 + k8s.io/client-go/listers/apps/v1beta2 from k8s.io/client-go/informers/apps/v1beta2 + k8s.io/client-go/listers/autoscaling/v1 from k8s.io/client-go/informers/autoscaling/v1 + k8s.io/client-go/listers/autoscaling/v2 from k8s.io/client-go/informers/autoscaling/v2 + k8s.io/client-go/listers/autoscaling/v2beta1 from k8s.io/client-go/informers/autoscaling/v2beta1 + k8s.io/client-go/listers/autoscaling/v2beta2 from k8s.io/client-go/informers/autoscaling/v2beta2 + k8s.io/client-go/listers/batch/v1 from k8s.io/client-go/informers/batch/v1 + k8s.io/client-go/listers/batch/v1beta1 from k8s.io/client-go/informers/batch/v1beta1 + k8s.io/client-go/listers/certificates/v1 from k8s.io/client-go/informers/certificates/v1 + k8s.io/client-go/listers/certificates/v1alpha1 from k8s.io/client-go/informers/certificates/v1alpha1 + k8s.io/client-go/listers/certificates/v1beta1 from k8s.io/client-go/informers/certificates/v1beta1 + k8s.io/client-go/listers/coordination/v1 from k8s.io/client-go/informers/coordination/v1 + k8s.io/client-go/listers/coordination/v1alpha2 from k8s.io/client-go/informers/coordination/v1alpha2 + k8s.io/client-go/listers/coordination/v1beta1 from k8s.io/client-go/informers/coordination/v1beta1 + k8s.io/client-go/listers/core/v1 from k8s.io/client-go/informers/core/v1 + k8s.io/client-go/listers/discovery/v1 from k8s.io/client-go/informers/discovery/v1 + k8s.io/client-go/listers/discovery/v1beta1 from k8s.io/client-go/informers/discovery/v1beta1 + k8s.io/client-go/listers/events/v1 from k8s.io/client-go/informers/events/v1 + k8s.io/client-go/listers/events/v1beta1 from k8s.io/client-go/informers/events/v1beta1 + k8s.io/client-go/listers/extensions/v1beta1 from k8s.io/client-go/informers/extensions/v1beta1 + k8s.io/client-go/listers/flowcontrol/v1 from k8s.io/client-go/informers/flowcontrol/v1 + k8s.io/client-go/listers/flowcontrol/v1beta1 from k8s.io/client-go/informers/flowcontrol/v1beta1 + k8s.io/client-go/listers/flowcontrol/v1beta2 from k8s.io/client-go/informers/flowcontrol/v1beta2 + k8s.io/client-go/listers/flowcontrol/v1beta3 from k8s.io/client-go/informers/flowcontrol/v1beta3 + k8s.io/client-go/listers/networking/v1 from k8s.io/client-go/informers/networking/v1 + k8s.io/client-go/listers/networking/v1alpha1 from k8s.io/client-go/informers/networking/v1alpha1 + k8s.io/client-go/listers/networking/v1beta1 from k8s.io/client-go/informers/networking/v1beta1 + k8s.io/client-go/listers/node/v1 from k8s.io/client-go/informers/node/v1 + k8s.io/client-go/listers/node/v1alpha1 from k8s.io/client-go/informers/node/v1alpha1 + k8s.io/client-go/listers/node/v1beta1 from k8s.io/client-go/informers/node/v1beta1 + k8s.io/client-go/listers/policy/v1 from k8s.io/client-go/informers/policy/v1 + k8s.io/client-go/listers/policy/v1beta1 from k8s.io/client-go/informers/policy/v1beta1 + k8s.io/client-go/listers/rbac/v1 from k8s.io/client-go/informers/rbac/v1 + k8s.io/client-go/listers/rbac/v1alpha1 from k8s.io/client-go/informers/rbac/v1alpha1 + k8s.io/client-go/listers/rbac/v1beta1 from k8s.io/client-go/informers/rbac/v1beta1 + k8s.io/client-go/listers/resource/v1alpha3 from k8s.io/client-go/informers/resource/v1alpha3 + k8s.io/client-go/listers/resource/v1beta1 from k8s.io/client-go/informers/resource/v1beta1 + k8s.io/client-go/listers/scheduling/v1 from k8s.io/client-go/informers/scheduling/v1 + k8s.io/client-go/listers/scheduling/v1alpha1 from k8s.io/client-go/informers/scheduling/v1alpha1 + k8s.io/client-go/listers/scheduling/v1beta1 from k8s.io/client-go/informers/scheduling/v1beta1 + k8s.io/client-go/listers/storage/v1 from k8s.io/client-go/informers/storage/v1 + k8s.io/client-go/listers/storage/v1alpha1 from k8s.io/client-go/informers/storage/v1alpha1 + k8s.io/client-go/listers/storage/v1beta1 from k8s.io/client-go/informers/storage/v1beta1 + k8s.io/client-go/listers/storagemigration/v1alpha1 from k8s.io/client-go/informers/storagemigration/v1alpha1 k8s.io/client-go/metadata from sigs.k8s.io/controller-runtime/pkg/cache/internal+ k8s.io/client-go/openapi from k8s.io/client-go/discovery k8s.io/client-go/pkg/apis/clientauthentication from k8s.io/client-go/pkg/apis/clientauthentication/install+ @@ -552,6 +594,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/rest from k8s.io/client-go/discovery+ k8s.io/client-go/rest/watch from k8s.io/client-go/rest k8s.io/client-go/restmapper from sigs.k8s.io/controller-runtime/pkg/client/apiutil + k8s.io/client-go/testing from k8s.io/client-go/gentype k8s.io/client-go/tools/auth from k8s.io/client-go/tools/clientcmd k8s.io/client-go/tools/cache from sigs.k8s.io/controller-runtime/pkg/cache+ k8s.io/client-go/tools/cache/synctrack from k8s.io/client-go/tools/cache @@ -568,12 +611,22 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/tools/record/util from k8s.io/client-go/tools/record k8s.io/client-go/tools/reference from k8s.io/client-go/kubernetes/typed/core/v1+ k8s.io/client-go/transport from k8s.io/client-go/plugin/pkg/client/auth/exec+ + k8s.io/client-go/util/apply from k8s.io/client-go/dynamic+ k8s.io/client-go/util/cert from k8s.io/client-go/rest+ k8s.io/client-go/util/connrotation from k8s.io/client-go/plugin/pkg/client/auth/exec+ + k8s.io/client-go/util/consistencydetector from k8s.io/client-go/dynamic+ k8s.io/client-go/util/flowcontrol from k8s.io/client-go/kubernetes+ k8s.io/client-go/util/homedir from k8s.io/client-go/tools/clientcmd k8s.io/client-go/util/keyutil from k8s.io/client-go/util/cert + k8s.io/client-go/util/watchlist from k8s.io/client-go/dynamic+ k8s.io/client-go/util/workqueue from k8s.io/client-go/transport+ + k8s.io/component-base/cli/flag from k8s.io/component-base/featuregate + k8s.io/component-base/featuregate from k8s.io/apiserver/pkg/features+ + k8s.io/component-base/metrics from k8s.io/component-base/metrics/legacyregistry+ + k8s.io/component-base/metrics/legacyregistry from k8s.io/component-base/metrics/prometheus/feature + k8s.io/component-base/metrics/prometheus/feature from k8s.io/component-base/featuregate + k8s.io/component-base/metrics/prometheusextension from k8s.io/component-base/metrics + k8s.io/component-base/version from k8s.io/component-base/featuregate+ k8s.io/klog/v2 from k8s.io/apimachinery/pkg/api/meta+ k8s.io/klog/v2/internal/buffer from k8s.io/klog/v2 k8s.io/klog/v2/internal/clock from k8s.io/klog/v2 @@ -593,11 +646,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/utils/buffer from k8s.io/client-go/tools/cache k8s.io/utils/clock from k8s.io/apimachinery/pkg/util/cache+ k8s.io/utils/clock/testing from k8s.io/client-go/util/flowcontrol + k8s.io/utils/internal/third_party/forked/golang/golang-lru from k8s.io/utils/lru k8s.io/utils/internal/third_party/forked/golang/net from k8s.io/utils/net + k8s.io/utils/lru from k8s.io/client-go/tools/record k8s.io/utils/net from k8s.io/apimachinery/pkg/util/net+ k8s.io/utils/pointer from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ k8s.io/utils/ptr from k8s.io/client-go/tools/cache+ - k8s.io/utils/strings/slices from k8s.io/apimachinery/pkg/labels k8s.io/utils/trace from k8s.io/client-go/tools/cache sigs.k8s.io/controller-runtime/pkg/builder from tailscale.com/cmd/k8s-operator sigs.k8s.io/controller-runtime/pkg/cache from sigs.k8s.io/controller-runtime/pkg/cluster+ @@ -630,12 +684,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ sigs.k8s.io/controller-runtime/pkg/metrics from sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics+ sigs.k8s.io/controller-runtime/pkg/metrics/server from sigs.k8s.io/controller-runtime/pkg/manager sigs.k8s.io/controller-runtime/pkg/predicate from sigs.k8s.io/controller-runtime/pkg/builder+ - sigs.k8s.io/controller-runtime/pkg/ratelimiter from sigs.k8s.io/controller-runtime/pkg/controller+ sigs.k8s.io/controller-runtime/pkg/reconcile from sigs.k8s.io/controller-runtime/pkg/builder+ sigs.k8s.io/controller-runtime/pkg/recorder from sigs.k8s.io/controller-runtime/pkg/leaderelection+ sigs.k8s.io/controller-runtime/pkg/source from sigs.k8s.io/controller-runtime/pkg/builder+ sigs.k8s.io/controller-runtime/pkg/webhook from sigs.k8s.io/controller-runtime/pkg/manager sigs.k8s.io/controller-runtime/pkg/webhook/admission from sigs.k8s.io/controller-runtime/pkg/builder+ + sigs.k8s.io/controller-runtime/pkg/webhook/admission/metrics from sigs.k8s.io/controller-runtime/pkg/webhook/admission sigs.k8s.io/controller-runtime/pkg/webhook/conversion from sigs.k8s.io/controller-runtime/pkg/builder sigs.k8s.io/controller-runtime/pkg/webhook/internal/metrics from sigs.k8s.io/controller-runtime/pkg/webhook+ sigs.k8s.io/json from k8s.io/apimachinery/pkg/runtime/serializer/json+ @@ -646,51 +700,63 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ sigs.k8s.io/structured-merge-diff/v4/typed from k8s.io/apimachinery/pkg/util/managedfields+ sigs.k8s.io/structured-merge-diff/v4/value from k8s.io/apimachinery/pkg/runtime+ sigs.k8s.io/yaml from k8s.io/apimachinery/pkg/runtime/serializer/json+ - sigs.k8s.io/yaml/goyaml.v2 from sigs.k8s.io/yaml + sigs.k8s.io/yaml/goyaml.v2 from sigs.k8s.io/yaml+ tailscale.com from tailscale.com/version tailscale.com/appc from tailscale.com/ipn/ipnlocal - tailscale.com/atomicfile from tailscale.com/ipn+ - tailscale.com/client/tailscale from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale from tailscale.com/cmd/k8s-operator+ tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ tailscale.com/client/web from tailscale.com/ipn/ipnlocal - tailscale.com/clientupdate from tailscale.com/client/web+ - tailscale.com/clientupdate/distsign from tailscale.com/clientupdate tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+ - tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp from tailscale.com/control/ts2021 + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/control/ts2021 from tailscale.com/control/controlclient tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+ - tailscale.com/disco from tailscale.com/derp+ - tailscale.com/doctor from tailscale.com/ipn/ipnlocal - tailscale.com/doctor/ethtool from tailscale.com/ipn/ipnlocal - đŸ’Ŗ tailscale.com/doctor/permissions from tailscale.com/ipn/ipnlocal - tailscale.com/doctor/routetable from tailscale.com/ipn/ipnlocal - tailscale.com/drive from tailscale.com/client/tailscale+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/disco from tailscale.com/net/tstun+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/ipn/ipnext+ + tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ + tailscale.com/feature/c2n from tailscale.com/tsnet + tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock + tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet + tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet + tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet + tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey + tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper + tailscale.com/feature/syspolicy from tailscale.com/logpolicy + tailscale.com/feature/useproxy from tailscale.com/feature/condregister/useproxy tailscale.com/health from tailscale.com/control/controlclient+ - tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal+ tailscale.com/hostinfo from tailscale.com/client/web+ - tailscale.com/internal/noiseconn from tailscale.com/control/controlclient - tailscale.com/ipn from tailscale.com/client/tailscale+ + tailscale.com/internal/client/tailscale from tailscale.com/cmd/k8s-operator+ + tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/ipn/localapi from tailscale.com/tsnet - tailscale.com/ipn/policy from tailscale.com/ipn/ipnlocal tailscale.com/ipn/store from tailscale.com/ipn/ipnlocal+ - L tailscale.com/ipn/store/awsstore from tailscale.com/ipn/store - tailscale.com/ipn/store/kubestore from tailscale.com/cmd/k8s-operator+ + tailscale.com/ipn/store/kubestore from tailscale.com/cmd/k8s-operator tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ tailscale.com/k8s-operator from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/api-proxy from tailscale.com/cmd/k8s-operator tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1 tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+ - tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/sessionrecording from tailscale.com/k8s-operator/api-proxy tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording tailscale.com/kube/egressservices from tailscale.com/cmd/k8s-operator + tailscale.com/kube/ingressservices from tailscale.com/cmd/k8s-operator + tailscale.com/kube/k8s-proxy/conf from tailscale.com/cmd/k8s-operator tailscale.com/kube/kubeapi from tailscale.com/ipn/store/kubestore+ tailscale.com/kube/kubeclient from tailscale.com/ipn/store/kubestore tailscale.com/kube/kubetypes from tailscale.com/cmd/k8s-operator+ @@ -699,19 +765,18 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal tailscale.com/logpolicy from tailscale.com/ipn/ipnlocal+ tailscale.com/logtail from tailscale.com/control/controlclient+ - tailscale.com/logtail/backoff from tailscale.com/control/controlclient+ tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ - tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/metrics from tailscale.com/tsweb+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+ + đŸ’Ŗ tailscale.com/net/batching from tailscale.com/wgengine/magicsock tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ - tailscale.com/net/connstats from tailscale.com/net/tstun+ tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ - tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback tailscale.com/net/dns/resolvconffile from tailscale.com/cmd/k8s-operator+ - tailscale.com/net/dns/resolver from tailscale.com/net/dns + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ tailscale.com/net/dnscache from tailscale.com/control/controlclient+ tailscale.com/net/dnsfallback from tailscale.com/control/controlclient+ - tailscale.com/net/flowtrack from tailscale.com/net/packet+ + tailscale.com/net/flowtrack from tailscale.com/wgengine+ tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ tailscale.com/net/memnet from tailscale.com/tsnet tailscale.com/net/netaddr from tailscale.com/ipn+ @@ -721,113 +786,120 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/net/netknob from tailscale.com/logpolicy+ đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/control/controlclient+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ - W đŸ’Ŗ tailscale.com/net/netstat from tailscale.com/portlist - tailscale.com/net/netutil from tailscale.com/client/tailscale+ - tailscale.com/net/packet from tailscale.com/net/connstats+ + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ - tailscale.com/net/portmapper from tailscale.com/ipn/localapi+ + tailscale.com/net/portmapper from tailscale.com/feature/portmapper + tailscale.com/net/portmapper/portmappertype from tailscale.com/net/netcheck+ tailscale.com/net/proxymux from tailscale.com/tsnet - tailscale.com/net/routetable from tailscale.com/doctor/routetable + đŸ’Ŗ tailscale.com/net/sockopts from tailscale.com/wgengine/magicsock tailscale.com/net/socks5 from tailscale.com/tsnet tailscale.com/net/sockstats from tailscale.com/control/controlclient+ tailscale.com/net/stun from tailscale.com/ipn/localapi+ - L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ - đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ + đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy tailscale.com/net/tstun from tailscale.com/tsd+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock + tailscale.com/net/udprelay/status from tailscale.com/client/local tailscale.com/omit from tailscale.com/ipn/conffile - tailscale.com/paths from tailscale.com/client/tailscale+ - đŸ’Ŗ tailscale.com/portlist from tailscale.com/ipn/ipnlocal - tailscale.com/posture from tailscale.com/ipn/ipnlocal + tailscale.com/paths from tailscale.com/client/local+ tailscale.com/proxymap from tailscale.com/tsd+ - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ tailscale.com/sessionrecording from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/syncs from tailscale.com/control/controlknobs+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ - tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock - tailscale.com/tka from tailscale.com/client/tailscale+ + tailscale.com/tempfork/httprec from tailscale.com/feature/c2n + tailscale.com/tka from tailscale.com/client/local+ tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tsd from tailscale.com/ipn/ipnlocal+ tailscale.com/tsnet from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ - tailscale.com/tstime/rate from tailscale.com/derp+ - tailscale.com/tsweb/varz from tailscale.com/util/usermetric - tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal + tailscale.com/tstime/rate from tailscale.com/wgengine/filter + tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/bools from tailscale.com/tsnet+ tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ tailscale.com/types/logger from tailscale.com/appc+ tailscale.com/types/logid from tailscale.com/ipn/ipnlocal+ - tailscale.com/types/netlogtype from tailscale.com/net/connstats+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogfunc from tailscale.com/net/tstun+ + tailscale.com/types/netlogtype from tailscale.com/wgengine/netlog tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/nettype from tailscale.com/ipn/localapi+ tailscale.com/types/opt from tailscale.com/client/tailscale+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/ptr from tailscale.com/cmd/k8s-operator+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ - tailscale.com/types/tkatype from tailscale.com/client/tailscale+ + tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/appc+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/backoff from tailscale.com/cmd/k8s-operator+ + tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/cmd/k8s-operator+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ - tailscale.com/util/cmpver from tailscale.com/clientupdate+ - tailscale.com/util/ctxkey from tailscale.com/cmd/k8s-operator+ - đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ - L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics+ + LW tailscale.com/util/cmpver from tailscale.com/net/dns+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ + đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting + L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/tsd+ tailscale.com/util/execqueue from tailscale.com/appc+ tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal tailscale.com/util/groupmember from tailscale.com/client/web+ đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash - tailscale.com/util/httphdr from tailscale.com/ipn/ipnlocal+ tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ - L tailscale.com/util/linuxfw from tailscale.com/net/netns+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ tailscale.com/util/mak from tailscale.com/appc+ - tailscale.com/util/multierr from tailscale.com/control/controlclient+ - tailscale.com/util/must from tailscale.com/clientupdate/distsign+ + tailscale.com/util/must from tailscale.com/logpolicy+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto đŸ’Ŗ tailscale.com/util/osdiag from tailscale.com/ipn/localapi W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag - tailscale.com/util/osshare from tailscale.com/ipn/ipnlocal tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal - tailscale.com/util/progresstracking from tailscale.com/ipn/localapi tailscale.com/util/race from tailscale.com/net/dns/resolver tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/ringbuffer from tailscale.com/wgengine/magicsock + tailscale.com/util/ringlog from tailscale.com/wgengine/magicsock tailscale.com/util/set from tailscale.com/cmd/k8s-operator+ tailscale.com/util/singleflight from tailscale.com/control/controlclient+ tailscale.com/util/slicesx from tailscale.com/appc+ - tailscale.com/util/syspolicy from tailscale.com/control/controlclient+ - tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy - tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock - tailscale.com/util/systemd from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy from tailscale.com/feature/syspolicy + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/pkey from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/ptype from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/testenv from tailscale.com/control/controlclient+ tailscale.com/util/truncate from tailscale.com/logtail - tailscale.com/util/uniq from tailscale.com/ipn/ipnlocal+ tailscale.com/util/usermetric from tailscale.com/health+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ - đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ - W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns + đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/hostinfo+ + W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/util/osdiag + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ tailscale.com/version from tailscale.com/client/web+ tailscale.com/version/distro from tailscale.com/client/web+ tailscale.com/wgengine from tailscale.com/ipn/ipnlocal+ - tailscale.com/wgengine/capture from tailscale.com/ipn/ipnlocal+ tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ @@ -839,44 +911,47 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - W đŸ’Ŗ tailscale.com/wgengine/winnet from tailscale.com/wgengine/router golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf - golang.org/x/crypto/chacha20 from github.com/tailscale/golang-x-crypto/ssh+ - golang.org/x/crypto/chacha20poly1305 from crypto/tls+ - golang.org/x/crypto/cryptobyte from crypto/ecdsa+ - golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ - golang.org/x/crypto/curve25519 from github.com/tailscale/golang-x-crypto/ssh+ - golang.org/x/crypto/hkdf from crypto/tls+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf + golang.org/x/crypto/chacha20 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/curve25519 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ - golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ + LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal + LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh + golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ golang.org/x/exp/maps from sigs.k8s.io/controller-runtime/pkg/cache+ golang.org/x/exp/slices from tailscale.com/cmd/k8s-operator+ - golang.org/x/net/bpf from github.com/mdlayher/genetlink+ - golang.org/x/net/dns/dnsmessage from net+ + golang.org/x/net/bpf from github.com/mdlayher/netlink+ + golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from golang.org/x/net/http2+ - golang.org/x/net/http/httpproxy from net/http+ - golang.org/x/net/http2 from golang.org/x/net/http2/h2c+ - golang.org/x/net/http2/h2c from tailscale.com/ipn/ipnlocal + golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy + golang.org/x/net/http2 from k8s.io/apimachinery/pkg/util/net+ golang.org/x/net/http2/hpack from golang.org/x/net/http2+ golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ golang.org/x/net/idna from golang.org/x/net/http/httpguts+ - golang.org/x/net/ipv4 from github.com/miekg/dns+ - golang.org/x/net/ipv6 from github.com/miekg/dns+ + golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy + golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ golang.org/x/net/proxy from tailscale.com/net/netns - D golang.org/x/net/route from net+ + D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/net/websocket from tailscale.com/k8s-operator/sessionrecording/ws golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ - golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator + golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator+ golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ - golang.org/x/sys/cpu from github.com/josharian/native+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ LD golang.org/x/sys/unix from github.com/fsnotify/fsnotify+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ @@ -888,18 +963,33 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ golang.org/x/text/unicode/norm from golang.org/x/net/idna golang.org/x/time/rate from gvisor.dev/gvisor/pkg/log+ - archive/tar from tailscale.com/clientupdate + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna bufio from compress/flate+ - bytes from archive/tar+ + bytes from bufio+ cmp from github.com/gaissmai/bart+ compress/flate from compress/gzip+ - compress/gzip from github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding+ - compress/zlib from debug/pe+ + compress/gzip from github.com/emicklei/go-restful/v3+ + compress/zlib from github.com/emicklei/go-restful/v3+ container/heap from gvisor.dev/gvisor/pkg/tcpip/transport/tcp+ container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509+ @@ -907,38 +997,81 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls+ + crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls+ + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ + LD crypto/mlkem from golang.org/x/crypto/ssh crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls+ crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ - crypto/tls from github.com/aws/aws-sdk-go-v2/aws/transport/http+ + crypto/subtle from crypto/cipher+ + crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ database/sql from github.com/prometheus/client_golang/prometheus/collectors database/sql/driver from database/sql+ W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe - embed from crypto/internal/nistec+ - encoding from encoding/gob+ + embed from github.com/tailscale/web-client-prebuilt+ + encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ encoding/base64 from encoding/json+ encoding/binary from compress/gzip+ encoding/csv from github.com/spf13/pflag - encoding/gob from github.com/gorilla/securecookie encoding/hex from crypto/x509+ encoding/json from expvar+ encoding/pem from crypto/tls+ - encoding/xml from github.com/aws/aws-sdk-go-v2/aws/protocol/xml+ - errors from archive/tar+ + encoding/xml from github.com/emicklei/go-restful/v3+ + errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ flag from github.com/spf13/pflag+ - fmt from archive/tar+ + fmt from compress/flate+ go/ast from go/doc+ go/build/constraint from go/parser go/doc from k8s.io/apimachinery/pkg/runtime @@ -947,64 +1080,118 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ go/scanner from go/ast+ go/token from go/ast+ hash from compress/zlib+ - hash/adler32 from compress/zlib+ + hash/adler32 from compress/zlib hash/crc32 from compress/gzip+ hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from html/template+ - html/template from github.com/gorilla/csrf - io from archive/tar+ - io/fs from archive/tar+ - io/ioutil from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ + html/template from tailscale.com/util/eventbus + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/lazyregexp from go/doc + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + D internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + L internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ + io from bufio+ + io/fs from crypto/x509+ + io/ioutil from github.com/godbus/dbus/v5+ iter from go/ast+ log from expvar+ log/internal from log+ log/slog from github.com/go-logr/logr+ log/slog/internal from log/slog + log/slog/internal/buffer from log/slog maps from sigs.k8s.io/controller-runtime/pkg/predicate+ - math from archive/tar+ + math from compress/flate+ math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/google/go-cmp/cmp+ - math/rand/v2 from tailscale.com/derp+ + math/rand/v2 from crypto/ecdsa+ mime from github.com/prometheus/common/expfmt+ mime/multipart from github.com/go-openapi/swag+ mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptest from tailscale.com/control/controlclient net/http/httptrace from github.com/prometheus-community/pro-bing+ - net/http/httputil from github.com/aws/smithy-go/transport/http+ + net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/internal/httpcommon from net/http net/http/pprof from sigs.k8s.io/controller-runtime/pkg/manager+ net/netip from github.com/gaissmai/bart+ - net/textproto from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ + net/textproto from github.com/coder/websocket+ net/url from crypto/x509+ - os from crypto/rand+ - os/exec from github.com/aws/aws-sdk-go-v2/credentials/processcreds+ + os from crypto/internal/sysrand+ + os/exec from github.com/godbus/dbus/v5+ os/signal from sigs.k8s.io/controller-runtime/pkg/manager/signals - os/user from archive/tar+ - path from archive/tar+ - path/filepath from archive/tar+ - reflect from archive/tar+ - regexp from github.com/aws/aws-sdk-go-v2/internal/endpoints+ + os/user from github.com/godbus/dbus/v5+ + path from debug/dwarf+ + path/filepath from crypto/x509+ + reflect from crypto/x509+ + regexp from github.com/davecgh/go-spew/spew+ regexp/syntax from regexp - runtime/debug from github.com/aws/aws-sdk-go-v2/internal/sync/singleflight+ + runtime from crypto/internal/fips140+ + runtime/debug from github.com/coder/websocket/internal/xsync+ runtime/metrics from github.com/prometheus/client_golang/prometheus+ runtime/pprof from net/http/pprof+ runtime/trace from net/http/pprof slices from encoding/base32+ sort from compress/flate+ - strconv from archive/tar+ - strings from archive/tar+ - sync from archive/tar+ + strconv from compress/flate+ + strings from bufio+ + W structs from internal/syscall/windows + sync from compress/flate+ sync/atomic from context+ - syscall from archive/tar+ + syscall from crypto/internal/sysrand+ text/tabwriter from k8s.io/apimachinery/pkg/util/diff+ - text/template from html/template + text/template from html/template+ text/template/parse from html/template+ - time from archive/tar+ + time from compress/gzip+ unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/cmd/k8s-operator/deploy/chart/Chart.yaml b/cmd/k8s-operator/deploy/chart/Chart.yaml index 363d87d15..9db6389d1 100644 --- a/cmd/k8s-operator/deploy/chart/Chart.yaml +++ b/cmd/k8s-operator/deploy/chart/Chart.yaml @@ -26,4 +26,4 @@ maintainers: version: 0.1.0 # appVersion will be set to Tailscale repo tag at release time. -appVersion: "unstable" +appVersion: "stable" diff --git a/cmd/k8s-operator/deploy/chart/templates/.gitignore b/cmd/k8s-operator/deploy/chart/templates/.gitignore new file mode 100644 index 000000000..ae7c682d9 --- /dev/null +++ b/cmd/k8s-operator/deploy/chart/templates/.gitignore @@ -0,0 +1,10 @@ +# Don't add helm chart CRDs to git. Canonical CRD files live in +# cmd/k8s-operator/deploy/crds. +# +# Generate for local usage with: +# go run tailscale.com/cmd/k8s-operator/generate helmcrd +/connector.yaml +/dnsconfig.yaml +/proxyclass.yaml +/proxygroup.yaml +/recorder.yaml diff --git a/cmd/k8s-operator/deploy/chart/templates/NOTES.txt b/cmd/k8s-operator/deploy/chart/templates/NOTES.txt new file mode 100644 index 000000000..1bee67046 --- /dev/null +++ b/cmd/k8s-operator/deploy/chart/templates/NOTES.txt @@ -0,0 +1,27 @@ +You have successfully installed the Tailscale Kubernetes Operator! + +Once connected, the operator should appear as a device within the Tailscale admin console: +https://login.tailscale.com/admin/machines + +If you have not used the Tailscale operator before, here are some examples to try out: + +* Private Kubernetes API access and authorization using the API server proxy + https://tailscale.com/kb/1437/kubernetes-operator-api-server-proxy + +* Private access to cluster Services using an ingress proxy + https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress + +* Private access to the cluster's available subnets using a subnet router + https://tailscale.com/kb/1441/kubernetes-operator-connector + +You can also explore the CRDs, operator, and associated resources within the {{ .Release.Namespace }} namespace: + +$ kubectl explain connector +$ kubectl explain proxygroup +$ kubectl explain proxyclass +$ kubectl explain recorder +$ kubectl explain dnsconfig + +If you're interested to explore what resources were created: + +$ kubectl --namespace={{ .Release.Namespace }} get all -l app.kubernetes.io/managed-by=Helm diff --git a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml index 072ecf6d2..d6e9d1bf4 100644 --- a/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/apiserverproxy-rbac.yaml @@ -1,7 +1,16 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -{{ if eq .Values.apiServerProxyConfig.mode "true" }} +# If old setting used, enable both old (operator) and new (ProxyGroup) workflows. +# If new setting used, enable only new workflow. +{{ if or (eq (toString .Values.apiServerProxyConfig.mode) "true") + (eq (toString .Values.apiServerProxyConfig.allowImpersonation) "true") }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: kube-apiserver-auth-proxy + namespace: {{ .Release.Namespace }} +--- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: @@ -16,9 +25,14 @@ kind: ClusterRoleBinding metadata: name: tailscale-auth-proxy subjects: +{{- if eq (toString .Values.apiServerProxyConfig.mode) "true" }} - kind: ServiceAccount name: operator namespace: {{ .Release.Namespace }} +{{- end }} +- kind: ServiceAccount + name: kube-apiserver-auth-proxy + namespace: {{ .Release.Namespace }} roleRef: kind: ClusterRole name: tailscale-auth-proxy diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index c428d5d1e..0f2dc42fc 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -34,10 +34,27 @@ spec: securityContext: {{- toYaml . | nindent 8 }} {{- end }} + {{- if or .Values.oauth.clientSecret .Values.oauth.audience }} volumes: - - name: oauth - secret: - secretName: operator-oauth + {{- if .Values.oauth.clientSecret }} + - name: oauth + {{- with .Values.oauthSecretVolume }} + {{- toYaml . | nindent 10 }} + {{- else }} + secret: + secretName: operator-oauth + {{- end }} + {{- else }} + - name: oidc-jwt + projected: + defaultMode: 420 + sources: + - serviceAccountToken: + audience: {{ .Values.oauth.audience }} + expirationSeconds: 3600 + path: token + {{- end }} + {{- end }} containers: - name: operator {{- with .Values.operatorConfig.securityContext }} @@ -64,10 +81,19 @@ spec: valueFrom: fieldRef: fieldPath: metadata.namespace + - name: OPERATOR_LOGIN_SERVER + value: {{ .Values.loginServer }} + - name: OPERATOR_INGRESS_CLASS_NAME + value: {{ .Values.ingressClass.name }} + {{- if .Values.oauth.clientSecret }} - name: CLIENT_ID_FILE value: /oauth/client_id - name: CLIENT_SECRET_FILE value: /oauth/client_secret + {{- else if .Values.oauth.audience }} + - name: CLIENT_ID + value: {{ .Values.oauth.clientId }} + {{- end }} {{- $proxyTag := printf ":%s" ( .Values.proxyConfig.image.tag | default .Chart.AppVersion )}} - name: PROXY_IMAGE value: {{ coalesce .Values.proxyConfig.image.repo .Values.proxyConfig.image.repository }}{{- if .Values.proxyConfig.image.digest -}}{{ printf "@%s" .Values.proxyConfig.image.digest}}{{- else -}}{{ printf "%s" $proxyTag }}{{- end }} @@ -81,13 +107,29 @@ spec: - name: PROXY_DEFAULT_CLASS value: {{ .Values.proxyConfig.defaultProxyClass }} {{- end }} + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid {{- with .Values.operatorConfig.extraEnv }} {{- toYaml . | nindent 12 }} {{- end }} + {{- if or .Values.oauth.clientSecret .Values.oauth.audience }} volumeMounts: + {{- if .Values.oauth.clientSecret }} - name: oauth mountPath: /oauth readOnly: true + {{- else }} + - name: oidc-jwt + mountPath: /var/run/secrets/tailscale/serviceaccount + readOnly: true + {{- end }} + {{- end }} {{- with .Values.operatorConfig.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml b/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml index 2a1fa81b4..54851955d 100644 --- a/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml @@ -1,8 +1,10 @@ +{{- if .Values.ingressClass.enabled }} apiVersion: networking.k8s.io/v1 kind: IngressClass metadata: - name: tailscale # class name currently can not be changed + name: {{ .Values.ingressClass.name }} annotations: {} # we do not support default IngressClass annotation https://kubernetes.io/docs/concepts/services-networking/ingress/#default-ingress-class spec: controller: tailscale.com/ts-ingress # controller name currently can not be changed # parameters: {} # currently no parameters are supported +{{- end }} diff --git a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml index b44fde0a1..b85c78915 100644 --- a/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/oauth-secret.yaml @@ -1,7 +1,7 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -{{ if and .Values.oauth .Values.oauth.clientId -}} +{{ if and .Values.oauth .Values.oauth.clientId .Values.oauth.clientSecret -}} apiVersion: v1 kind: Secret metadata: diff --git a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml index ede61070b..5eb920a6f 100644 --- a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml @@ -6,12 +6,19 @@ kind: ServiceAccount metadata: name: operator namespace: {{ .Release.Namespace }} + {{- with .Values.operatorConfig.serviceAccountAnnotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: name: tailscale-operator rules: +- apiGroups: [""] + resources: ["nodes"] + verbs: ["get", "list", "watch"] - apiGroups: [""] resources: ["events", "services", "services/status"] verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] @@ -21,6 +28,9 @@ rules: - apiGroups: ["networking.k8s.io"] resources: ["ingressclasses"] verbs: ["get", "list", "watch"] +- apiGroups: ["discovery.k8s.io"] + resources: ["endpointslices"] + verbs: ["get", "list", "watch"] - apiGroups: ["tailscale.com"] resources: ["connectors", "connectors/status", "proxyclasses", "proxyclasses/status", "proxygroups", "proxygroups/status"] verbs: ["get", "list", "watch", "update"] @@ -30,6 +40,10 @@ rules: - apiGroups: ["tailscale.com"] resources: ["recorders", "recorders/status"] verbs: ["get", "list", "watch", "update"] +- apiGroups: ["apiextensions.k8s.io"] + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + resourceNames: ["servicemonitors.monitoring.coreos.com"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -55,7 +69,10 @@ rules: verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] - apiGroups: [""] resources: ["pods"] - verbs: ["get","list","watch"] + verbs: ["get","list","watch", "update"] +- apiGroups: [""] + resources: ["pods/status"] + verbs: ["update"] - apiGroups: ["apps"] resources: ["statefulsets", "deployments"] verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] @@ -64,7 +81,10 @@ rules: verbs: ["get", "list", "watch", "create", "update", "deletecollection"] - apiGroups: ["rbac.authorization.k8s.io"] resources: ["roles", "rolebindings"] - verbs: ["get", "create", "patch", "update", "list", "watch"] + verbs: ["get", "create", "patch", "update", "list", "watch", "deletecollection"] +- apiGroups: ["monitoring.coreos.com"] + resources: ["servicemonitors"] + verbs: ["get", "list", "update", "create", "delete"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml index 1c15c9119..fa552a7c7 100644 --- a/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml @@ -16,6 +16,9 @@ rules: - apiGroups: [""] resources: ["secrets"] verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] +- apiGroups: [""] + resources: ["events"] + verbs: ["create", "patch", "get"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index 43ed382c6..eb11fc7f2 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -1,12 +1,37 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -# Operator oauth credentials. If set a Kubernetes Secret with the provided -# values will be created in the operator namespace. If unset a Secret named -# operator-oauth must be precreated. -oauth: {} - # clientId: "" - # clientSecret: "" +# Operator oauth credentials. If unset a Secret named operator-oauth must be +# precreated or oauthSecretVolume needs to be adjusted. This block will be +# overridden by oauthSecretVolume, if set. +oauth: + # The Client ID the operator will authenticate with. + clientId: "" + # If set a Kubernetes Secret with the provided value will be created in + # the operator namespace, and mounted into the operator Pod. Takes precedence + # over oauth.audience. + clientSecret: "" + # The audience for oauth.clientId if using a workload identity federation + # OAuth client. Mutually exclusive with oauth.clientSecret. + # See https://tailscale.com/kb/1581/workload-identity-federation. + audience: "" + +# URL of the control plane to be used by all resources managed by the operator. +loginServer: "" + +# Secret volume. +# If set it defines the volume the oauth secrets will be mounted from. +# The volume needs to contain two files named `client_id` and `client_secret`. +# If unset the volume will reference the Secret named operator-oauth. +# This block will override the oauth block. +oauthSecretVolume: {} + # csi: + # driver: secrets-store.csi.k8s.io + # readOnly: true + # volumeAttributes: + # secretProviderClass: tailscale-oauth + # + ## NAME is pre-defined! # installCRDs determines whether tailscale.com CRDs should be installed as part # of chart installation. We do not use Helm's CRD installation mechanism as that @@ -40,6 +65,9 @@ operatorConfig: podAnnotations: {} podLabels: {} + serviceAccountAnnotations: {} + # eks.amazonaws.com/role-arn: arn:aws:iam::123456789012:role/tailscale-operator-role + tolerations: [] affinity: {} @@ -54,6 +82,13 @@ operatorConfig: # - name: EXTRA_VAR2 # value: "value2" +# In the case that you already have a tailscale ingressclass in your cluster (or vcluster), you can disable the creation here +ingressClass: + # Allows for customization of the ingress class name used by the operator to identify ingresses to reconcile. This does + # not allow multiple operator instances to manage different ingresses, but provides an onboarding route for users that + # may have previously set up ingress classes named "tailscale" prior to using the operator. + name: "tailscale" + enabled: true # proxyConfig contains configuraton that will be applied to any ingress/egress # proxies created by the operator. @@ -64,6 +99,13 @@ operatorConfig: # If you need more configuration options, take a look at ProxyClass: # https://tailscale.com/kb/1445/kubernetes-operator-customization#cluster-resource-customization-using-proxyclass-custom-resource proxyConfig: + # Configure the proxy image to use instead of the default tailscale/tailscale:latest. + # Applying a ProxyClass with `spec.statefulSet.pod.tailscaleContainer.image` + # set will override any defaults here. + # + # Note that ProxyGroups of type "kube-apiserver" use a different default image, + # tailscale/k8s-proxy:latest, and it is currently only possible to override + # that image via the same ProxyClass field. image: # Repository defaults to DockerHub, but images are also synced to ghcr.io/tailscale/tailscale. repository: tailscale/tailscale @@ -79,13 +121,23 @@ proxyConfig: defaultTags: "tag:k8s" firewallMode: auto # If defined, this proxy class will be used as the default proxy class for - # service and ingress resources that do not have a proxy class defined. + # service and ingress resources that do not have a proxy class defined. It + # does not apply to Connector resources. defaultProxyClass: "" # apiServerProxyConfig allows to configure whether the operator should expose # Kubernetes API server. # https://tailscale.com/kb/1437/kubernetes-operator-api-server-proxy apiServerProxyConfig: + # Set to "true" to create the ClusterRole permissions required for the API + # server proxy's auth mode. In auth mode, the API server proxy impersonates + # groups and users based on tailnet ACL grants. Required for ProxyGroups of + # type "kube-apiserver" running in auth mode. + allowImpersonation: "false" # "true", "false" + + # If true or noauth, the operator will run an in-process API server proxy. + # You can deploy a ProxyGroup of type "kube-apiserver" to run a high + # availability set of API server proxies instead. mode: "false" # "true", "false", "noauth" imagePullSecrets: [] diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml index 9614f74e6..74d32d53d 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: connectors.tailscale.com spec: group: tailscale.com @@ -24,10 +24,17 @@ spec: jsonPath: .status.isExitNode name: IsExitNode type: string + - description: Whether this Connector instance is an app connector. + jsonPath: .status.isAppConnector + name: IsAppConnector + type: string - description: Status of the deployed Connector resources. jsonPath: .status.conditions[?(@.type == "ConnectorReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -66,10 +73,40 @@ spec: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status type: object properties: + appConnector: + description: |- + AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + Connector does not act as an app connector. + Note that you will need to manually configure the permissions and the domains for the app connector via the + Admin panel. + Note also that the main tested and supported use case of this config option is to deploy an app connector on + Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + tested or optimised for. + If you are using the app connector to access SaaS applications because you need a predictable egress IP that + can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + device with a static IP address. + https://tailscale.com/kb/1281/app-connectors + type: object + properties: + routes: + description: |- + Routes are optional preconfigured routes for the domains routed via the app connector. + If not set, routes for the domains will be discovered dynamically. + If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + also dynamically discover other routes. + https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + type: array + minItems: 1 + items: + type: string + format: cidr exitNode: description: |- - ExitNode defines whether the Connector node should act as a - Tailscale exit node. Defaults to false. + ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + This field is mutually exclusive with the appConnector field. https://tailscale.com/kb/1103/exit-nodes type: boolean hostname: @@ -78,9 +115,19 @@ spec: Connector node. If unset, hostname defaults to -connector. Hostname can contain lower case letters, numbers and dashes, it must not start or end with a dash and must be between 2 - and 63 characters long. + and 63 characters long. This field should only be used when creating a connector + with an unspecified number of replicas, or a single replica. type: string pattern: ^[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$ + hostnamePrefix: + description: |- + HostnamePrefix specifies the hostname prefix for each + replica. Each device will have the integer number + from its StatefulSet pod appended to this prefix to form the full hostname. + HostnamePrefix can contain lower case letters, numbers and dashes, it + must not start with a dash and must be between 1 and 62 characters long. + type: string + pattern: ^[a-z0-9][a-z0-9-]{0,61}$ proxyClass: description: |- ProxyClass is the name of the ProxyClass custom resource that @@ -88,11 +135,21 @@ spec: resources created for this Connector. If unset, the operator will create resources with the default configuration. type: string + replicas: + description: |- + Replicas specifies how many devices to create. Set this to enable + high availability for app connectors, subnet routers, or exit nodes. + https://tailscale.com/kb/1115/high-availability. Defaults to 1. + type: integer + format: int32 + minimum: 0 subnetRouter: description: |- - SubnetRouter defines subnet routes that the Connector node should - expose to tailnet. If unset, none are exposed. + SubnetRouter defines subnet routes that the Connector device should + expose to tailnet as a Tailscale subnet router. https://tailscale.com/kb/1019/subnets/ + If this field is unset, the device does not get configured as a Tailscale subnet router. + This field is mutually exclusive with the appConnector field. type: object required: - advertiseRoutes @@ -125,8 +182,14 @@ spec: type: string pattern: ^tag:[a-zA-Z][a-zA-Z0-9-]*$ x-kubernetes-validations: - - rule: has(self.subnetRouter) || self.exitNode == true - message: A Connector needs to be either an exit node or a subnet router, or both. + - rule: has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector) + message: A Connector needs to have at least one of exit node, subnet router or app connector configured. + - rule: '!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))' + message: The appConnector field is mutually exclusive with exitNode and subnetRouter fields. + - rule: '!(has(self.hostname) && has(self.replicas) && self.replicas > 1)' + message: The hostname field cannot be specified when replicas is greater than 1. + - rule: '!(has(self.hostname) && has(self.hostnamePrefix))' + message: The hostname and hostnamePrefix fields are mutually exclusive. status: description: |- ConnectorStatus describes the status of the Connector. This is set @@ -194,12 +257,36 @@ spec: x-kubernetes-list-map-keys: - type x-kubernetes-list-type: map + devices: + description: Devices contains information on each device managed by the Connector resource. + type: array + items: + type: object + properties: + hostname: + description: |- + Hostname is the fully qualified domain name of the Connector replica. + If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the + node. + type: string + tailnetIPs: + description: |- + TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) + assigned to the Connector replica. + type: array + items: + type: string hostname: description: |- Hostname is the fully qualified domain name of the Connector node. If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the - node. + node. When using multiple replicas, this field will be populated with the + first replica's hostname. Use the Hostnames field for the full list + of hostnames. type: string + isAppConnector: + description: IsAppConnector is set to true if the Connector acts as an app connector. + type: boolean isExitNode: description: IsExitNode is set to true if the Connector acts as an exit node. type: boolean diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml index 13aee9b9e..a819aa651 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: dnsconfigs.tailscale.com spec: group: tailscale.com @@ -20,6 +20,9 @@ spec: jsonPath: .status.nameserver.ip name: NameserverIP type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -49,7 +52,6 @@ spec: using its MagicDNS name, you must also annotate the Ingress resource with tailscale.com/experimental-forward-cluster-traffic-via-ingress annotation to ensure that the proxy created for the Ingress listens on its Pod IP address. - NB: Clusters where Pods get assigned IPv6 addresses only are currently not supported. type: object required: - spec @@ -98,6 +100,61 @@ spec: tag: description: Tag defaults to unstable. type: string + pod: + description: Pod configuration. + type: object + properties: + tolerations: + description: If specified, applies tolerations to the pods deployed by the DNSConfig resource. + type: array + items: + description: |- + The pod this Toleration is attached to tolerates any taint that matches + the triple using the matching operator . + type: object + properties: + effect: + description: |- + Effect indicates the taint effect to match. Empty means match all taint effects. + When specified, allowed values are NoSchedule, PreferNoSchedule and NoExecute. + type: string + key: + description: |- + Key is the taint key that the toleration applies to. Empty means match all taint keys. + If the key is empty, operator must be Exists; this combination means to match all values and all keys. + type: string + operator: + description: |- + Operator represents a key's relationship to the value. + Valid operators are Exists and Equal. Defaults to Equal. + Exists is equivalent to wildcard for value, so that a pod can + tolerate all taints of a particular category. + type: string + tolerationSeconds: + description: |- + TolerationSeconds represents the period of time the toleration (which must be + of effect NoExecute, otherwise this field is ignored) tolerates the taint. By default, + it is not set, which means tolerate the taint forever (do not evict). Zero and + negative values will be treated as 0 (evict immediately) by the system. + type: integer + format: int64 + value: + description: |- + Value is the taint value the toleration matches to. + If the operator is Exists, the value should be empty, otherwise just a regular string. + type: string + replicas: + description: Replicas specifies how many Pods to create. Defaults to 1. + type: integer + format: int32 + minimum: 0 + service: + description: Service configuration. + type: object + properties: + clusterIP: + description: ClusterIP sets the static IP of the service used by the nameserver. + type: string status: description: |- Status describes the status of the DNSConfig. This is set @@ -169,7 +226,7 @@ spec: ip: description: |- IP is the ClusterIP of the Service fronting the deployed ts.net nameserver. - Currently you must manually update your cluster DNS config to add + Currently, you must manually update your cluster DNS config to add this address as a stub nameserver for ts.net for cluster workloads to be able to resolve MagicDNS names associated with egress or Ingress proxies. diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml index 0fff30516..516e75f48 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxyclasses.tailscale.com spec: group: tailscale.com @@ -18,6 +18,9 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyClassReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -73,9 +76,45 @@ spec: enable: description: |- Setting enable to true will make the proxy serve Tailscale metrics - at :9001/debug/metrics. + at :9002/metrics. + A metrics Service named -metrics will also be created in the operator's namespace and will + serve the metrics at :9002/metrics. + + In 1.78.x and 1.80.x, this field also serves as the default value for + .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + fields will independently default to false. + Defaults to false. type: boolean + serviceMonitor: + description: |- + Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + The ingested metrics for each Service monitor will have labels to identify the proxy: + ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + job: ts__[]_ + type: object + required: + - enable + properties: + enable: + description: If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + type: boolean + labels: + description: |- + Labels to add to the ServiceMonitor. + Labels must be valid Kubernetes labels. + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + type: object + additionalProperties: + type: string + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ + x-kubernetes-validations: + - rule: '!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)' + message: ServiceMonitor can only be enabled if metrics are enabled statefulSet: description: |- Configuration parameters for the proxy's StatefulSet. Tailscale @@ -107,6 +146,8 @@ spec: type: object additionalProperties: type: string + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ pod: description: Configuration for the proxy Pod. type: object @@ -390,7 +431,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -405,7 +446,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -562,7 +603,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -577,7 +618,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -735,7 +776,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -750,7 +791,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -907,7 +948,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -922,7 +963,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -1005,6 +1046,62 @@ spec: type: object additionalProperties: type: string + dnsConfig: + description: |- + DNSConfig defines DNS parameters for the proxy Pod in addition to those generated from DNSPolicy. + When DNSPolicy is set to "None", DNSConfig must be specified. + https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-dns-config + type: object + properties: + nameservers: + description: |- + A list of DNS name server IP addresses. + This will be appended to the base nameservers generated from DNSPolicy. + Duplicated nameservers will be removed. + type: array + items: + type: string + x-kubernetes-list-type: atomic + options: + description: |- + A list of DNS resolver options. + This will be merged with the base options generated from DNSPolicy. + Duplicated entries will be removed. Resolution options given in Options + will override those that appear in the base DNSPolicy. + type: array + items: + description: PodDNSConfigOption defines DNS resolver options of a pod. + type: object + properties: + name: + description: |- + Name is this DNS resolver option's name. + Required. + type: string + value: + description: Value is this DNS resolver option's value. + type: string + x-kubernetes-list-type: atomic + searches: + description: |- + A list of DNS search domains for host-name lookup. + This will be appended to the base search paths generated from DNSPolicy. + Duplicated search paths will be removed. + type: array + items: + type: string + x-kubernetes-list-type: atomic + dnsPolicy: + description: |- + DNSPolicy defines how DNS will be configured for the proxy Pod. + By default the Tailscale Kubernetes Operator does not set a DNS policy (uses cluster default). + https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-s-dns-policy + type: string + enum: + - ClusterFirstWithHostNet + - ClusterFirst + - Default + - None imagePullSecrets: description: |- Proxy Pod's image pull Secrets. @@ -1036,6 +1133,8 @@ spec: type: object additionalProperties: type: string + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ nodeName: description: |- Proxy Pod's node name. @@ -1050,6 +1149,12 @@ spec: type: object additionalProperties: type: string + priorityClassName: + description: |- + PriorityClassName for the proxy Pod. + By default Tailscale Kubernetes operator does not apply any priority class. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling + type: string securityContext: description: |- Proxy Pod's security context. @@ -1134,6 +1239,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. type: integer format: int64 + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -1182,18 +1313,28 @@ spec: type: string supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. type: array items: type: integer format: int64 x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -1249,6 +1390,25 @@ spec: description: Configuration for the proxy container running tailscale. type: object properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + type: object + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean env: description: |- List of environment variables to set in the container. @@ -1281,12 +1441,21 @@ spec: type: string image: description: |- - Container image name. By default images are pulled from - docker.io/tailscale/tailscale, but the official images are also - available at ghcr.io/tailscale/tailscale. Specifying image name here - will override any proxy image values specified via the Kubernetes - operator's Helm chart values or PROXY_IMAGE env var in the operator - Deployment. + Container image name. By default images are pulled from docker.io/tailscale, + but the official images are also available at ghcr.io/tailscale. + + For all uses except on ProxyGroups of type "kube-apiserver", this image must + be either tailscale/tailscale, or an equivalent mirror of that image. + To apply to ProxyGroups of type "kube-apiserver", this image must be + tailscale/k8s-proxy or a mirror of that image. + + For "tailscale/tailscale"-based proxies, specifying image name here will + override any proxy image values specified via the Kubernetes operator's + Helm chart values or PROXY_IMAGE env var in the operator Deployment. + For "tailscale/k8s-proxy"-based proxies, there is currently no way to + configure your own default, and this field is the only way to use a + custom image. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image type: string imagePullPolicy: @@ -1330,6 +1499,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map @@ -1360,11 +1535,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context type: object properties: @@ -1433,7 +1609,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1550,9 +1726,30 @@ spec: PodSecurityContext, the value specified in SecurityContext takes precedence. type: string tailscaleInitContainer: - description: Configuration for the proxy init container that enables forwarding. + description: |- + Configuration for the proxy init container that enables forwarding. + Not valid to apply to ProxyGroups of type "kube-apiserver". type: object properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + type: object + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean env: description: |- List of environment variables to set in the container. @@ -1585,12 +1782,21 @@ spec: type: string image: description: |- - Container image name. By default images are pulled from - docker.io/tailscale/tailscale, but the official images are also - available at ghcr.io/tailscale/tailscale. Specifying image name here - will override any proxy image values specified via the Kubernetes - operator's Helm chart values or PROXY_IMAGE env var in the operator - Deployment. + Container image name. By default images are pulled from docker.io/tailscale, + but the official images are also available at ghcr.io/tailscale. + + For all uses except on ProxyGroups of type "kube-apiserver", this image must + be either tailscale/tailscale, or an equivalent mirror of that image. + To apply to ProxyGroups of type "kube-apiserver", this image must be + tailscale/k8s-proxy or a mirror of that image. + + For "tailscale/tailscale"-based proxies, specifying image name here will + override any proxy image values specified via the Kubernetes operator's + Helm chart values or PROXY_IMAGE env var in the operator Deployment. + For "tailscale/k8s-proxy"-based proxies, there is currently no way to + configure your own default, and this field is the only way to use a + custom image. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image type: string imagePullPolicy: @@ -1634,6 +1840,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map @@ -1664,11 +1876,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context type: object properties: @@ -1737,7 +1950,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1896,6 +2109,227 @@ spec: Value is the taint value the toleration matches to. If the operator is Exists, the value should be empty, otherwise just a regular string. type: string + topologySpreadConstraints: + description: |- + Proxy Pod's topology spread constraints. + By default Tailscale Kubernetes operator does not apply any topology spread constraints. + https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + type: array + items: + description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. + type: object + required: + - maxSkew + - topologyKey + - whenUnsatisfiable + properties: + labelSelector: + description: |- + LabelSelector is used to find matching pods. + Pods that match this label selector are counted to determine the number of pods + in their corresponding topology domain. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select the pods over which + spreading will be calculated. The keys are used to lookup values from the + incoming pod labels, those key-value labels are ANDed with labelSelector + to select the group of existing pods over which spreading will be calculated + for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector. + MatchLabelKeys cannot be set when LabelSelector isn't set. + Keys that don't exist in the incoming pod labels will + be ignored. A null or empty list means only match against labelSelector. + + This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default). + type: array + items: + type: string + x-kubernetes-list-type: atomic + maxSkew: + description: |- + MaxSkew describes the degree to which pods may be unevenly distributed. + When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference + between the number of matching pods in the target topology and the global minimum. + The global minimum is the minimum number of matching pods in an eligible domain + or zero if the number of eligible domains is less than MinDomains. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 2/2/1: + In this case, the global minimum is 1. + | zone1 | zone2 | zone3 | + | P P | P P | P | + - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2; + scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2) + violate MaxSkew(1). + - if MaxSkew is 2, incoming pod can be scheduled onto any zone. + When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence + to topologies that satisfy it. + It's a required field. Default value is 1 and 0 is not allowed. + type: integer + format: int32 + minDomains: + description: |- + MinDomains indicates a minimum number of eligible domains. + When the number of eligible domains with matching topology keys is less than minDomains, + Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed. + And when the number of eligible domains with matching topology keys equals or greater than minDomains, + this value has no effect on scheduling. + As a result, when the number of eligible domains is less than minDomains, + scheduler won't schedule more than maxSkew Pods to those domains. + If value is nil, the constraint behaves as if MinDomains is equal to 1. + Valid values are integers greater than 0. + When value is not nil, WhenUnsatisfiable must be DoNotSchedule. + + For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same + labelSelector spread as 2/2/2: + | zone1 | zone2 | zone3 | + | P P | P P | P P | + The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0. + In this situation, new pod with the same labelSelector cannot be scheduled, + because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones, + it will violate MaxSkew. + type: integer + format: int32 + nodeAffinityPolicy: + description: |- + NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector + when calculating pod topology spread skew. Options are: + - Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations. + - Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations. + + If this value is nil, the behavior is equivalent to the Honor policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + nodeTaintsPolicy: + description: |- + NodeTaintsPolicy indicates how we will treat node taints when calculating + pod topology spread skew. Options are: + - Honor: nodes without taints, along with tainted nodes for which the incoming pod + has a toleration, are included. + - Ignore: node taints are ignored. All nodes are included. + + If this value is nil, the behavior is equivalent to the Ignore policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + topologyKey: + description: |- + TopologyKey is the key of node labels. Nodes that have a label with this key + and identical values are considered to be in the same topology. + We consider each as a "bucket", and try to put balanced number + of pods into each bucket. + We define a domain as a particular instance of a topology. + Also, we define an eligible domain as a domain whose nodes meet the requirements of + nodeAffinityPolicy and nodeTaintsPolicy. + e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology. + And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology. + It's a required field. + type: string + whenUnsatisfiable: + description: |- + WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy + the spread constraint. + - DoNotSchedule (default) tells the scheduler not to schedule it. + - ScheduleAnyway tells the scheduler to schedule the pod in any location, + but giving higher precedence to topologies that would help reduce the + skew. + A constraint is considered "Unsatisfiable" for an incoming pod + if and only if every possible node assignment for that pod would violate + "MaxSkew" on some topology. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 3/1/1: + | zone1 | zone2 | zone3 | + | P P P | P | P | + If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled + to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies + MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler + won't make it *more* imbalanced. + It's a required field. + type: string + staticEndpoints: + description: |- + Configuration for 'static endpoints' on proxies in order to facilitate + direct connections from other devices on the tailnet. + See https://tailscale.com/kb/1445/kubernetes-operator-customization#static-endpoints. + type: object + required: + - nodePort + properties: + nodePort: + description: The configuration for static endpoints using NodePort Services. + type: object + required: + - ports + properties: + ports: + description: |- + The port ranges from which the operator will select NodePorts for the Services. + You must ensure that firewall rules allow UDP ingress traffic for these ports + to the node's external IPs. + The ports must be in the range of service node ports for the cluster (default `30000-32767`). + See https://kubernetes.io/docs/concepts/services-networking/service/#type-nodeport. + type: array + minItems: 1 + items: + type: object + required: + - port + properties: + endPort: + description: |- + endPort indicates that the range of ports from port to endPort if set, inclusive, + should be used. This field cannot be defined if the port field is not defined. + The endPort must be either unset, or equal or greater than port. + type: integer + port: + description: port represents a port selected to be used. This is a required field. + type: integer + selector: + description: |- + A selector which will be used to select the node's that will have their `ExternalIP`'s advertised + by the ProxyGroup as Static Endpoints. + type: object + additionalProperties: + type: string tailscale: description: |- TailscaleConfig contains options to configure the tailscale-specific @@ -1911,6 +2345,22 @@ spec: https://tailscale.com/kb/1019/subnets#use-your-subnet-routes-from-other-devices Defaults to false. type: boolean + useLetsEncryptStagingEnvironment: + description: |- + Set UseLetsEncryptStagingEnvironment to true to issue TLS + certificates for any HTTPS endpoints exposed to the tailnet from + LetsEncrypt's staging environment. + https://letsencrypt.org/docs/staging-environment/ + This setting only affects Tailscale Ingress resources. + By default Ingress TLS certificates are issued from LetsEncrypt's + production environment. + Changing this setting true -> false, will result in any + existing certs being re-issued from the production environment. + Changing this setting false (default) -> true, when certs have already + been provisioned from production environment will NOT result in certs + being re-issued from the staging environment before they need to be + renewed. + type: boolean status: description: |- Status of the ProxyClass. This is set and managed automatically. diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml index 5f3520d26..98ca1c378 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxygroups.tailscale.com spec: group: tailscale.com @@ -20,9 +20,38 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyGroupReady")].reason name: Status type: string + - description: URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if any. Only applies to ProxyGroups of type kube-apiserver. + jsonPath: .status.url + name: URL + type: string + - description: ProxyGroup type. + jsonPath: .spec.type + name: Type + type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + ProxyGroup defines a set of Tailscale devices that will act as proxies. + Depending on spec.Type, it can be a group of egress, ingress, or kube-apiserver + proxies. In addition to running a highly available set of proxies, ingress + and egress ProxyGroups also allow for serving many annotated Services from a + single set of proxies to minimise resource consumption. + + For ingress and egress, use the tailscale.com/proxy-group annotation on a + Service to specify that the proxy should be implemented by a ProxyGroup + instead of a single dedicated proxy. + + More info: + * https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress + * https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress + + For kube-apiserver, the ProxyGroup is a standalone resource. Use the + spec.kubeAPIServer field to configure options specific to the kube-apiserver + ProxyGroup type. type: object required: - spec @@ -59,18 +88,45 @@ spec: must not start with a dash and must be between 1 and 62 characters long. type: string pattern: ^[a-z0-9][a-z0-9-]{0,61}$ + kubeAPIServer: + description: |- + KubeAPIServer contains configuration specific to the kube-apiserver + ProxyGroup type. This field is only used when Type is set to "kube-apiserver". + type: object + properties: + hostname: + description: |- + Hostname is the hostname with which to expose the Kubernetes API server + proxies. Must be a valid DNS label no longer than 63 characters. If not + specified, the name of the ProxyGroup is used as the hostname. Must be + unique across the whole tailnet. + type: string + pattern: ^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$ + mode: + description: |- + Mode to run the API server proxy in. Supported modes are auth and noauth. + In auth mode, requests from the tailnet proxied over to the Kubernetes + API server are additionally impersonated using the sender's tailnet identity. + If not specified, defaults to auth mode. + type: string + enum: + - auth + - noauth proxyClass: description: |- ProxyClass is the name of the ProxyClass custom resource that contains configuration options that should be applied to the resources created - for this ProxyGroup. If unset, and no default ProxyClass is set, the - operator will create resources with the default configuration. + for this ProxyGroup. If unset, and there is no default ProxyClass + configured, the operator will create resources with the default + configuration. type: string replicas: description: |- Replicas specifies how many replicas to create the StatefulSet with. Defaults to 2. type: integer + format: int32 + minimum: 0 tags: description: |- Tags that the Tailscale devices will be tagged with. Defaults to [tag:k8s]. @@ -85,12 +141,16 @@ spec: pattern: ^tag:[a-zA-Z][a-zA-Z0-9-]*$ type: description: |- - Type of the ProxyGroup, either ingress or egress. Each set of proxies - managed by a single ProxyGroup definition operate as only ingress or - only egress proxies. + Type of the ProxyGroup proxies. Supported types are egress, ingress, and kube-apiserver. + Type is immutable once a ProxyGroup is created. type: string enum: - egress + - ingress + - kube-apiserver + x-kubernetes-validations: + - rule: self == oldSelf + message: ProxyGroup type is immutable status: description: |- ProxyGroupStatus describes the status of the ProxyGroup resources. This is @@ -100,7 +160,20 @@ spec: conditions: description: |- List of status conditions to indicate the status of the ProxyGroup - resources. Known condition types are `ProxyGroupReady`. + resources. Known condition types include `ProxyGroupReady` and + `ProxyGroupAvailable`. + + * `ProxyGroupReady` indicates all ProxyGroup resources are reconciled and + all expected conditions are true. + * `ProxyGroupAvailable` indicates that at least one proxy is ready to + serve traffic. + + For ProxyGroups of type kube-apiserver, there are two additional conditions: + + * `KubeAPIServerProxyConfigured` indicates that at least one API server + proxy is configured and ready to serve traffic. + * `KubeAPIServerProxyValid` indicates that spec.kubeAPIServer config is + valid. type: array items: description: Condition contains details for one aspect of the current state of this API Resource. @@ -172,6 +245,11 @@ spec: If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the node. type: string + staticEndpoints: + description: StaticEndpoints are user configured, 'static' endpoints by which tailnet peers can reach this device. + type: array + items: + type: string tailnetIPs: description: |- TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) @@ -182,6 +260,11 @@ spec: x-kubernetes-list-map-keys: - hostname x-kubernetes-list-type: map + url: + description: |- + URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if + any. Only applies to ProxyGroups of type kube-apiserver. + type: string served: true storage: true subresources: diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml index fda8bcebd..0f3dcfcca 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: recorders.tailscale.com spec: group: tailscale.com @@ -24,9 +24,18 @@ spec: jsonPath: .status.devices[?(@.url != "")].url name: URL type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + Recorder defines a tsrecorder device for recording SSH sessions. By default, + it will store recordings in a local ephemeral volume. If you want to persist + recordings, you can configure an S3-compatible API for storage. + + More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder type: object required: - spec @@ -366,7 +375,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -381,7 +390,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -538,7 +547,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -553,7 +562,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -711,7 +720,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -726,7 +735,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -883,7 +892,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -898,7 +907,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -1060,6 +1069,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map @@ -1159,7 +1174,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1395,6 +1410,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. type: integer format: int64 + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -1443,18 +1484,28 @@ spec: type: string supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. type: array items: type: integer format: int64 x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -1506,6 +1557,36 @@ spec: May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: string + serviceAccount: + description: |- + Config for the ServiceAccount to create for the Recorder's StatefulSet. + By default, the operator will create a ServiceAccount with the same + name as the Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + type: object + properties: + annotations: + description: |- + Annotations to add to the ServiceAccount. + https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set + + You can use this to add IAM roles to the ServiceAccount (IRSA) instead of + providing static S3 credentials in a Secret. + https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html + + For example: + eks.amazonaws.com/role-arn: arn:aws:iam:::role/ + type: object + additionalProperties: + type: string + name: + description: |- + Name of the ServiceAccount to create. Defaults to the name of the + Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + type: string + maxLength: 253 + pattern: ^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$ tolerations: description: |- Tolerations for Recorder Pods. By default, the operator does not apply diff --git a/cmd/k8s-operator/deploy/examples/connector.yaml b/cmd/k8s-operator/deploy/examples/connector.yaml index d29f27cf5..f5447400e 100644 --- a/cmd/k8s-operator/deploy/examples/connector.yaml +++ b/cmd/k8s-operator/deploy/examples/connector.yaml @@ -11,7 +11,8 @@ metadata: spec: tags: - "tag:prod" - hostname: ts-prod + hostnamePrefix: ts-prod + replicas: 2 subnetRouter: advertiseRoutes: - "10.40.0.0/14" diff --git a/cmd/k8s-operator/deploy/examples/proxygroup.yaml b/cmd/k8s-operator/deploy/examples/proxygroup.yaml new file mode 100644 index 000000000..337d87f0b --- /dev/null +++ b/cmd/k8s-operator/deploy/examples/proxygroup.yaml @@ -0,0 +1,7 @@ +apiVersion: tailscale.com/v1alpha1 +kind: ProxyGroup +metadata: + name: egress-proxies +spec: + type: egress + replicas: 3 diff --git a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml index ddbdda32e..5818fa69f 100644 --- a/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/manifests/authproxy-rbac.yaml @@ -1,6 +1,12 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause +apiVersion: v1 +kind: ServiceAccount +metadata: + name: kube-apiserver-auth-proxy + namespace: tailscale +--- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole metadata: @@ -18,6 +24,9 @@ subjects: - kind: ServiceAccount name: operator namespace: tailscale +- kind: ServiceAccount + name: kube-apiserver-auth-proxy + namespace: tailscale roleRef: kind: ClusterRole name: tailscale-auth-proxy diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 25f3b4d1c..c5da367e0 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -31,7 +31,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: connectors.tailscale.com spec: group: tailscale.com @@ -53,10 +53,17 @@ spec: jsonPath: .status.isExitNode name: IsExitNode type: string + - description: Whether this Connector instance is an app connector. + jsonPath: .status.isAppConnector + name: IsAppConnector + type: string - description: Status of the deployed Connector resources. jsonPath: .status.conditions[?(@.type == "ConnectorReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -91,10 +98,40 @@ spec: More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status properties: + appConnector: + description: |- + AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + Connector does not act as an app connector. + Note that you will need to manually configure the permissions and the domains for the app connector via the + Admin panel. + Note also that the main tested and supported use case of this config option is to deploy an app connector on + Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + tested or optimised for. + If you are using the app connector to access SaaS applications because you need a predictable egress IP that + can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + device with a static IP address. + https://tailscale.com/kb/1281/app-connectors + properties: + routes: + description: |- + Routes are optional preconfigured routes for the domains routed via the app connector. + If not set, routes for the domains will be discovered dynamically. + If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + also dynamically discover other routes. + https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + items: + format: cidr + type: string + minItems: 1 + type: array + type: object exitNode: description: |- - ExitNode defines whether the Connector node should act as a - Tailscale exit node. Defaults to false. + ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + This field is mutually exclusive with the appConnector field. https://tailscale.com/kb/1103/exit-nodes type: boolean hostname: @@ -103,9 +140,19 @@ spec: Connector node. If unset, hostname defaults to -connector. Hostname can contain lower case letters, numbers and dashes, it must not start or end with a dash and must be between 2 - and 63 characters long. + and 63 characters long. This field should only be used when creating a connector + with an unspecified number of replicas, or a single replica. pattern: ^[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$ type: string + hostnamePrefix: + description: |- + HostnamePrefix specifies the hostname prefix for each + replica. Each device will have the integer number + from its StatefulSet pod appended to this prefix to form the full hostname. + HostnamePrefix can contain lower case letters, numbers and dashes, it + must not start with a dash and must be between 1 and 62 characters long. + pattern: ^[a-z0-9][a-z0-9-]{0,61}$ + type: string proxyClass: description: |- ProxyClass is the name of the ProxyClass custom resource that @@ -113,11 +160,21 @@ spec: resources created for this Connector. If unset, the operator will create resources with the default configuration. type: string + replicas: + description: |- + Replicas specifies how many devices to create. Set this to enable + high availability for app connectors, subnet routers, or exit nodes. + https://tailscale.com/kb/1115/high-availability. Defaults to 1. + format: int32 + minimum: 0 + type: integer subnetRouter: description: |- - SubnetRouter defines subnet routes that the Connector node should - expose to tailnet. If unset, none are exposed. + SubnetRouter defines subnet routes that the Connector device should + expose to tailnet as a Tailscale subnet router. https://tailscale.com/kb/1019/subnets/ + If this field is unset, the device does not get configured as a Tailscale subnet router. + This field is mutually exclusive with the appConnector field. properties: advertiseRoutes: description: |- @@ -151,8 +208,14 @@ spec: type: array type: object x-kubernetes-validations: - - message: A Connector needs to be either an exit node or a subnet router, or both. - rule: has(self.subnetRouter) || self.exitNode == true + - message: A Connector needs to have at least one of exit node, subnet router or app connector configured. + rule: has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector) + - message: The appConnector field is mutually exclusive with exitNode and subnetRouter fields. + rule: '!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))' + - message: The hostname field cannot be specified when replicas is greater than 1. + rule: '!(has(self.hostname) && has(self.replicas) && self.replicas > 1)' + - message: The hostname and hostnamePrefix fields are mutually exclusive. + rule: '!(has(self.hostname) && has(self.hostnamePrefix))' status: description: |- ConnectorStatus describes the status of the Connector. This is set @@ -219,12 +282,36 @@ spec: x-kubernetes-list-map-keys: - type x-kubernetes-list-type: map + devices: + description: Devices contains information on each device managed by the Connector resource. + items: + properties: + hostname: + description: |- + Hostname is the fully qualified domain name of the Connector replica. + If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the + node. + type: string + tailnetIPs: + description: |- + TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) + assigned to the Connector replica. + items: + type: string + type: array + type: object + type: array hostname: description: |- Hostname is the fully qualified domain name of the Connector node. If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the - node. + node. When using multiple replicas, this field will be populated with the + first replica's hostname. Use the Hostnames field for the full list + of hostnames. type: string + isAppConnector: + description: IsAppConnector is set to true if the Connector acts as an app connector. + type: boolean isExitNode: description: IsExitNode is set to true if the Connector acts as an exit node. type: boolean @@ -253,7 +340,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: dnsconfigs.tailscale.com spec: group: tailscale.com @@ -271,6 +358,9 @@ spec: jsonPath: .status.nameserver.ip name: NameserverIP type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -300,7 +390,6 @@ spec: using its MagicDNS name, you must also annotate the Ingress resource with tailscale.com/experimental-forward-cluster-traffic-via-ingress annotation to ensure that the proxy created for the Ingress listens on its Pod IP address. - NB: Clusters where Pods get assigned IPv6 addresses only are currently not supported. properties: apiVersion: description: |- @@ -342,6 +431,61 @@ spec: description: Tag defaults to unstable. type: string type: object + pod: + description: Pod configuration. + properties: + tolerations: + description: If specified, applies tolerations to the pods deployed by the DNSConfig resource. + items: + description: |- + The pod this Toleration is attached to tolerates any taint that matches + the triple using the matching operator . + properties: + effect: + description: |- + Effect indicates the taint effect to match. Empty means match all taint effects. + When specified, allowed values are NoSchedule, PreferNoSchedule and NoExecute. + type: string + key: + description: |- + Key is the taint key that the toleration applies to. Empty means match all taint keys. + If the key is empty, operator must be Exists; this combination means to match all values and all keys. + type: string + operator: + description: |- + Operator represents a key's relationship to the value. + Valid operators are Exists and Equal. Defaults to Equal. + Exists is equivalent to wildcard for value, so that a pod can + tolerate all taints of a particular category. + type: string + tolerationSeconds: + description: |- + TolerationSeconds represents the period of time the toleration (which must be + of effect NoExecute, otherwise this field is ignored) tolerates the taint. By default, + it is not set, which means tolerate the taint forever (do not evict). Zero and + negative values will be treated as 0 (evict immediately) by the system. + format: int64 + type: integer + value: + description: |- + Value is the taint value the toleration matches to. + If the operator is Exists, the value should be empty, otherwise just a regular string. + type: string + type: object + type: array + type: object + replicas: + description: Replicas specifies how many Pods to create. Defaults to 1. + format: int32 + minimum: 0 + type: integer + service: + description: Service configuration. + properties: + clusterIP: + description: ClusterIP sets the static IP of the service used by the nameserver. + type: string + type: object type: object required: - nameserver @@ -415,7 +559,7 @@ spec: ip: description: |- IP is the ClusterIP of the Service fronting the deployed ts.net nameserver. - Currently you must manually update your cluster DNS config to add + Currently, you must manually update your cluster DNS config to add this address as a stub nameserver for ts.net for cluster workloads to be able to resolve MagicDNS names associated with egress or Ingress proxies. @@ -435,7 +579,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxyclasses.tailscale.com spec: group: tailscale.com @@ -451,6 +595,9 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyClassReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -499,12 +646,48 @@ spec: enable: description: |- Setting enable to true will make the proxy serve Tailscale metrics - at :9001/debug/metrics. + at :9002/metrics. + A metrics Service named -metrics will also be created in the operator's namespace and will + serve the metrics at :9002/metrics. + + In 1.78.x and 1.80.x, this field also serves as the default value for + .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + fields will independently default to false. + Defaults to false. type: boolean + serviceMonitor: + description: |- + Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + The ingested metrics for each Service monitor will have labels to identify the proxy: + ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + job: ts__[]_ + properties: + enable: + description: If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + type: boolean + labels: + additionalProperties: + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ + type: string + description: |- + Labels to add to the ServiceMonitor. + Labels must be valid Kubernetes labels. + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + type: object + required: + - enable + type: object required: - enable type: object + x-kubernetes-validations: + - message: ServiceMonitor can only be enabled if metrics are enabled + rule: '!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)' statefulSet: description: |- Configuration parameters for the proxy's StatefulSet. Tailscale @@ -525,6 +708,8 @@ spec: type: object labels: additionalProperties: + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ type: string description: |- Labels that will be added to the StatefulSet created for the proxy. @@ -807,7 +992,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -822,7 +1007,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -983,7 +1168,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -998,7 +1183,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1152,7 +1337,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1167,7 +1352,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1328,7 +1513,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1343,7 +1528,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1432,6 +1617,62 @@ spec: Annotations must be valid Kubernetes annotations. https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set type: object + dnsConfig: + description: |- + DNSConfig defines DNS parameters for the proxy Pod in addition to those generated from DNSPolicy. + When DNSPolicy is set to "None", DNSConfig must be specified. + https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-dns-config + properties: + nameservers: + description: |- + A list of DNS name server IP addresses. + This will be appended to the base nameservers generated from DNSPolicy. + Duplicated nameservers will be removed. + items: + type: string + type: array + x-kubernetes-list-type: atomic + options: + description: |- + A list of DNS resolver options. + This will be merged with the base options generated from DNSPolicy. + Duplicated entries will be removed. Resolution options given in Options + will override those that appear in the base DNSPolicy. + items: + description: PodDNSConfigOption defines DNS resolver options of a pod. + properties: + name: + description: |- + Name is this DNS resolver option's name. + Required. + type: string + value: + description: Value is this DNS resolver option's value. + type: string + type: object + type: array + x-kubernetes-list-type: atomic + searches: + description: |- + A list of DNS search domains for host-name lookup. + This will be appended to the base search paths generated from DNSPolicy. + Duplicated search paths will be removed. + items: + type: string + type: array + x-kubernetes-list-type: atomic + type: object + dnsPolicy: + description: |- + DNSPolicy defines how DNS will be configured for the proxy Pod. + By default the Tailscale Kubernetes Operator does not set a DNS policy (uses cluster default). + https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-s-dns-policy + enum: + - ClusterFirstWithHostNet + - ClusterFirst + - Default + - None + type: string imagePullSecrets: description: |- Proxy Pod's image pull Secrets. @@ -1455,6 +1696,8 @@ spec: type: array labels: additionalProperties: + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ type: string description: |- Labels that will be added to the proxy Pod. @@ -1477,6 +1720,12 @@ spec: selector. https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling type: object + priorityClassName: + description: |- + PriorityClassName for the proxy Pod. + By default Tailscale Kubernetes operator does not apply any priority class. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling + type: string securityContext: description: |- Proxy Pod's security context. @@ -1560,6 +1809,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. format: int64 type: integer + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -1608,18 +1883,28 @@ spec: type: object supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. items: format: int64 type: integer type: array x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -1675,6 +1960,25 @@ spec: tailscaleContainer: description: Configuration for the proxy container running tailscale. properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean + type: object env: description: |- List of environment variables to set in the container. @@ -1707,12 +2011,21 @@ spec: type: array image: description: |- - Container image name. By default images are pulled from - docker.io/tailscale/tailscale, but the official images are also - available at ghcr.io/tailscale/tailscale. Specifying image name here - will override any proxy image values specified via the Kubernetes - operator's Helm chart values or PROXY_IMAGE env var in the operator - Deployment. + Container image name. By default images are pulled from docker.io/tailscale, + but the official images are also available at ghcr.io/tailscale. + + For all uses except on ProxyGroups of type "kube-apiserver", this image must + be either tailscale/tailscale, or an equivalent mirror of that image. + To apply to ProxyGroups of type "kube-apiserver", this image must be + tailscale/k8s-proxy or a mirror of that image. + + For "tailscale/tailscale"-based proxies, specifying image name here will + override any proxy image values specified via the Kubernetes operator's + Helm chart values or PROXY_IMAGE env var in the operator Deployment. + For "tailscale/k8s-proxy"-based proxies, there is currently no way to + configure your own default, and this field is the only way to use a + custom image. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image type: string imagePullPolicy: @@ -1751,6 +2064,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string required: - name type: object @@ -1786,11 +2105,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context properties: allowPrivilegeEscalation: @@ -1858,7 +2178,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1977,8 +2297,29 @@ spec: type: object type: object tailscaleInitContainer: - description: Configuration for the proxy init container that enables forwarding. + description: |- + Configuration for the proxy init container that enables forwarding. + Not valid to apply to ProxyGroups of type "kube-apiserver". properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean + type: object env: description: |- List of environment variables to set in the container. @@ -2011,12 +2352,21 @@ spec: type: array image: description: |- - Container image name. By default images are pulled from - docker.io/tailscale/tailscale, but the official images are also - available at ghcr.io/tailscale/tailscale. Specifying image name here - will override any proxy image values specified via the Kubernetes - operator's Helm chart values or PROXY_IMAGE env var in the operator - Deployment. + Container image name. By default images are pulled from docker.io/tailscale, + but the official images are also available at ghcr.io/tailscale. + + For all uses except on ProxyGroups of type "kube-apiserver", this image must + be either tailscale/tailscale, or an equivalent mirror of that image. + To apply to ProxyGroups of type "kube-apiserver", this image must be + tailscale/k8s-proxy or a mirror of that image. + + For "tailscale/tailscale"-based proxies, specifying image name here will + override any proxy image values specified via the Kubernetes operator's + Helm chart values or PROXY_IMAGE env var in the operator Deployment. + For "tailscale/k8s-proxy"-based proxies, there is currently no way to + configure your own default, and this field is the only way to use a + custom image. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image type: string imagePullPolicy: @@ -2055,6 +2405,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string required: - name type: object @@ -2090,11 +2446,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context properties: allowPrivilegeEscalation: @@ -2162,7 +2519,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -2323,7 +2680,228 @@ spec: type: string type: object type: array + topologySpreadConstraints: + description: |- + Proxy Pod's topology spread constraints. + By default Tailscale Kubernetes operator does not apply any topology spread constraints. + https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + items: + description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. + properties: + labelSelector: + description: |- + LabelSelector is used to find matching pods. + Pods that match this label selector are counted to determine the number of pods + in their corresponding topology domain. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select the pods over which + spreading will be calculated. The keys are used to lookup values from the + incoming pod labels, those key-value labels are ANDed with labelSelector + to select the group of existing pods over which spreading will be calculated + for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector. + MatchLabelKeys cannot be set when LabelSelector isn't set. + Keys that don't exist in the incoming pod labels will + be ignored. A null or empty list means only match against labelSelector. + + This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default). + items: + type: string + type: array + x-kubernetes-list-type: atomic + maxSkew: + description: |- + MaxSkew describes the degree to which pods may be unevenly distributed. + When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference + between the number of matching pods in the target topology and the global minimum. + The global minimum is the minimum number of matching pods in an eligible domain + or zero if the number of eligible domains is less than MinDomains. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 2/2/1: + In this case, the global minimum is 1. + | zone1 | zone2 | zone3 | + | P P | P P | P | + - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2; + scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2) + violate MaxSkew(1). + - if MaxSkew is 2, incoming pod can be scheduled onto any zone. + When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence + to topologies that satisfy it. + It's a required field. Default value is 1 and 0 is not allowed. + format: int32 + type: integer + minDomains: + description: |- + MinDomains indicates a minimum number of eligible domains. + When the number of eligible domains with matching topology keys is less than minDomains, + Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed. + And when the number of eligible domains with matching topology keys equals or greater than minDomains, + this value has no effect on scheduling. + As a result, when the number of eligible domains is less than minDomains, + scheduler won't schedule more than maxSkew Pods to those domains. + If value is nil, the constraint behaves as if MinDomains is equal to 1. + Valid values are integers greater than 0. + When value is not nil, WhenUnsatisfiable must be DoNotSchedule. + + For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same + labelSelector spread as 2/2/2: + | zone1 | zone2 | zone3 | + | P P | P P | P P | + The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0. + In this situation, new pod with the same labelSelector cannot be scheduled, + because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones, + it will violate MaxSkew. + format: int32 + type: integer + nodeAffinityPolicy: + description: |- + NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector + when calculating pod topology spread skew. Options are: + - Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations. + - Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations. + + If this value is nil, the behavior is equivalent to the Honor policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + nodeTaintsPolicy: + description: |- + NodeTaintsPolicy indicates how we will treat node taints when calculating + pod topology spread skew. Options are: + - Honor: nodes without taints, along with tainted nodes for which the incoming pod + has a toleration, are included. + - Ignore: node taints are ignored. All nodes are included. + + If this value is nil, the behavior is equivalent to the Ignore policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + topologyKey: + description: |- + TopologyKey is the key of node labels. Nodes that have a label with this key + and identical values are considered to be in the same topology. + We consider each as a "bucket", and try to put balanced number + of pods into each bucket. + We define a domain as a particular instance of a topology. + Also, we define an eligible domain as a domain whose nodes meet the requirements of + nodeAffinityPolicy and nodeTaintsPolicy. + e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology. + And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology. + It's a required field. + type: string + whenUnsatisfiable: + description: |- + WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy + the spread constraint. + - DoNotSchedule (default) tells the scheduler not to schedule it. + - ScheduleAnyway tells the scheduler to schedule the pod in any location, + but giving higher precedence to topologies that would help reduce the + skew. + A constraint is considered "Unsatisfiable" for an incoming pod + if and only if every possible node assignment for that pod would violate + "MaxSkew" on some topology. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 3/1/1: + | zone1 | zone2 | zone3 | + | P P P | P | P | + If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled + to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies + MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler + won't make it *more* imbalanced. + It's a required field. + type: string + required: + - maxSkew + - topologyKey + - whenUnsatisfiable + type: object + type: array + type: object + type: object + staticEndpoints: + description: |- + Configuration for 'static endpoints' on proxies in order to facilitate + direct connections from other devices on the tailnet. + See https://tailscale.com/kb/1445/kubernetes-operator-customization#static-endpoints. + properties: + nodePort: + description: The configuration for static endpoints using NodePort Services. + properties: + ports: + description: |- + The port ranges from which the operator will select NodePorts for the Services. + You must ensure that firewall rules allow UDP ingress traffic for these ports + to the node's external IPs. + The ports must be in the range of service node ports for the cluster (default `30000-32767`). + See https://kubernetes.io/docs/concepts/services-networking/service/#type-nodeport. + items: + properties: + endPort: + description: |- + endPort indicates that the range of ports from port to endPort if set, inclusive, + should be used. This field cannot be defined if the port field is not defined. + The endPort must be either unset, or equal or greater than port. + type: integer + port: + description: port represents a port selected to be used. This is a required field. + type: integer + required: + - port + type: object + minItems: 1 + type: array + selector: + additionalProperties: + type: string + description: |- + A selector which will be used to select the node's that will have their `ExternalIP`'s advertised + by the ProxyGroup as Static Endpoints. + type: object + required: + - ports type: object + required: + - nodePort type: object tailscale: description: |- @@ -2340,6 +2918,22 @@ spec: Defaults to false. type: boolean type: object + useLetsEncryptStagingEnvironment: + description: |- + Set UseLetsEncryptStagingEnvironment to true to issue TLS + certificates for any HTTPS endpoints exposed to the tailnet from + LetsEncrypt's staging environment. + https://letsencrypt.org/docs/staging-environment/ + This setting only affects Tailscale Ingress resources. + By default Ingress TLS certificates are issued from LetsEncrypt's + production environment. + Changing this setting true -> false, will result in any + existing certs being re-issued from the production environment. + Changing this setting false (default) -> true, when certs have already + been provisioned from production environment will NOT result in certs + being re-issued from the staging environment before they need to be + renewed. + type: boolean type: object status: description: |- @@ -2420,7 +3014,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxygroups.tailscale.com spec: group: tailscale.com @@ -2438,9 +3032,38 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyGroupReady")].reason name: Status type: string + - description: URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if any. Only applies to ProxyGroups of type kube-apiserver. + jsonPath: .status.url + name: URL + type: string + - description: ProxyGroup type. + jsonPath: .spec.type + name: Type + type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + ProxyGroup defines a set of Tailscale devices that will act as proxies. + Depending on spec.Type, it can be a group of egress, ingress, or kube-apiserver + proxies. In addition to running a highly available set of proxies, ingress + and egress ProxyGroups also allow for serving many annotated Services from a + single set of proxies to minimise resource consumption. + + For ingress and egress, use the tailscale.com/proxy-group annotation on a + Service to specify that the proxy should be implemented by a ProxyGroup + instead of a single dedicated proxy. + + More info: + * https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress + * https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress + + For kube-apiserver, the ProxyGroup is a standalone resource. Use the + spec.kubeAPIServer field to configure options specific to the kube-apiserver + ProxyGroup type. properties: apiVersion: description: |- @@ -2471,17 +3094,44 @@ spec: must not start with a dash and must be between 1 and 62 characters long. pattern: ^[a-z0-9][a-z0-9-]{0,61}$ type: string + kubeAPIServer: + description: |- + KubeAPIServer contains configuration specific to the kube-apiserver + ProxyGroup type. This field is only used when Type is set to "kube-apiserver". + properties: + hostname: + description: |- + Hostname is the hostname with which to expose the Kubernetes API server + proxies. Must be a valid DNS label no longer than 63 characters. If not + specified, the name of the ProxyGroup is used as the hostname. Must be + unique across the whole tailnet. + pattern: ^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$ + type: string + mode: + description: |- + Mode to run the API server proxy in. Supported modes are auth and noauth. + In auth mode, requests from the tailnet proxied over to the Kubernetes + API server are additionally impersonated using the sender's tailnet identity. + If not specified, defaults to auth mode. + enum: + - auth + - noauth + type: string + type: object proxyClass: description: |- ProxyClass is the name of the ProxyClass custom resource that contains configuration options that should be applied to the resources created - for this ProxyGroup. If unset, and no default ProxyClass is set, the - operator will create resources with the default configuration. + for this ProxyGroup. If unset, and there is no default ProxyClass + configured, the operator will create resources with the default + configuration. type: string replicas: description: |- Replicas specifies how many replicas to create the StatefulSet with. Defaults to 2. + format: int32 + minimum: 0 type: integer tags: description: |- @@ -2497,12 +3147,16 @@ spec: type: array type: description: |- - Type of the ProxyGroup, either ingress or egress. Each set of proxies - managed by a single ProxyGroup definition operate as only ingress or - only egress proxies. + Type of the ProxyGroup proxies. Supported types are egress, ingress, and kube-apiserver. + Type is immutable once a ProxyGroup is created. enum: - egress + - ingress + - kube-apiserver type: string + x-kubernetes-validations: + - message: ProxyGroup type is immutable + rule: self == oldSelf required: - type type: object @@ -2514,7 +3168,20 @@ spec: conditions: description: |- List of status conditions to indicate the status of the ProxyGroup - resources. Known condition types are `ProxyGroupReady`. + resources. Known condition types include `ProxyGroupReady` and + `ProxyGroupAvailable`. + + * `ProxyGroupReady` indicates all ProxyGroup resources are reconciled and + all expected conditions are true. + * `ProxyGroupAvailable` indicates that at least one proxy is ready to + serve traffic. + + For ProxyGroups of type kube-apiserver, there are two additional conditions: + + * `KubeAPIServerProxyConfigured` indicates that at least one API server + proxy is configured and ready to serve traffic. + * `KubeAPIServerProxyValid` indicates that spec.kubeAPIServer config is + valid. items: description: Condition contains details for one aspect of the current state of this API Resource. properties: @@ -2582,6 +3249,11 @@ spec: If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the node. type: string + staticEndpoints: + description: StaticEndpoints are user configured, 'static' endpoints by which tailnet peers can reach this device. + items: + type: string + type: array tailnetIPs: description: |- TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) @@ -2596,6 +3268,11 @@ spec: x-kubernetes-list-map-keys: - hostname x-kubernetes-list-type: map + url: + description: |- + URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if + any. Only applies to ProxyGroups of type kube-apiserver. + type: string type: object required: - spec @@ -2609,7 +3286,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: recorders.tailscale.com spec: group: tailscale.com @@ -2631,9 +3308,18 @@ spec: jsonPath: .status.devices[?(@.url != "")].url name: URL type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + Recorder defines a tsrecorder device for recording SSH sessions. By default, + it will store recordings in a local ephemeral volume. If you want to persist + recordings, you can configure an S3-compatible API for storage. + + More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder properties: apiVersion: description: |- @@ -2957,7 +3643,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -2972,7 +3658,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3133,7 +3819,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3148,7 +3834,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3302,7 +3988,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3317,7 +4003,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3478,7 +4164,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3493,7 +4179,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3655,6 +4341,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string required: - name type: object @@ -3758,7 +4450,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -3995,6 +4687,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. format: int64 type: integer + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -4043,18 +4761,28 @@ spec: type: object supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. items: format: int64 type: integer type: array x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -4107,6 +4835,36 @@ spec: type: string type: object type: object + serviceAccount: + description: |- + Config for the ServiceAccount to create for the Recorder's StatefulSet. + By default, the operator will create a ServiceAccount with the same + name as the Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + properties: + annotations: + additionalProperties: + type: string + description: |- + Annotations to add to the ServiceAccount. + https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set + + You can use this to add IAM roles to the ServiceAccount (IRSA) instead of + providing static S3 credentials in a Secret. + https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html + + For example: + eks.amazonaws.com/role-arn: arn:aws:iam:::role/ + type: object + name: + description: |- + Name of the ServiceAccount to create. Defaults to the name of the + Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + maxLength: 253 + pattern: ^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$ + type: string + type: object tolerations: description: |- Tolerations for Recorder Pods. By default, the operator does not apply @@ -4316,6 +5074,14 @@ kind: ClusterRole metadata: name: tailscale-operator rules: + - apiGroups: + - "" + resources: + - nodes + verbs: + - get + - list + - watch - apiGroups: - "" resources: @@ -4353,6 +5119,14 @@ rules: - get - list - watch + - apiGroups: + - discovery.k8s.io + resources: + - endpointslices + verbs: + - get + - list + - watch - apiGroups: - tailscale.com resources: @@ -4387,6 +5161,16 @@ rules: - list - watch - update + - apiGroups: + - apiextensions.k8s.io + resourceNames: + - servicemonitors.monitoring.coreos.com + resources: + - customresourcedefinitions + verbs: + - get + - list + - watch --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -4430,6 +5214,13 @@ rules: - get - list - watch + - update + - apiGroups: + - "" + resources: + - pods/status + verbs: + - update - apiGroups: - apps resources: @@ -4467,6 +5258,17 @@ rules: - update - list - watch + - deletecollection + - apiGroups: + - monitoring.coreos.com + resources: + - servicemonitors + verbs: + - get + - list + - update + - create + - delete --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role @@ -4487,6 +5289,14 @@ rules: - patch - update - watch + - apiGroups: + - "" + resources: + - events + verbs: + - create + - patch + - get --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding @@ -4547,19 +5357,31 @@ spec: valueFrom: fieldRef: fieldPath: metadata.namespace + - name: OPERATOR_LOGIN_SERVER + value: null + - name: OPERATOR_INGRESS_CLASS_NAME + value: tailscale - name: CLIENT_ID_FILE value: /oauth/client_id - name: CLIENT_SECRET_FILE value: /oauth/client_secret - name: PROXY_IMAGE - value: tailscale/tailscale:unstable + value: tailscale/tailscale:stable - name: PROXY_TAGS value: tag:k8s - name: APISERVER_PROXY value: "false" - name: PROXY_FIREWALL_MODE value: auto - image: tailscale/k8s-operator:unstable + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid + image: tailscale/k8s-operator:stable imagePullPolicy: Always name: operator volumeMounts: diff --git a/cmd/k8s-operator/deploy/manifests/proxy.yaml b/cmd/k8s-operator/deploy/manifests/proxy.yaml index a79d48d73..3c9a3eaa3 100644 --- a/cmd/k8s-operator/deploy/manifests/proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/proxy.yaml @@ -30,7 +30,13 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml b/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml index 46b49a57b..6617f6d4b 100644 --- a/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml @@ -24,3 +24,11 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid diff --git a/cmd/k8s-operator/dnsrecords.go b/cmd/k8s-operator/dnsrecords.go index bba87bf25..1a9395aa0 100644 --- a/cmd/k8s-operator/dnsrecords.go +++ b/cmd/k8s-operator/dnsrecords.go @@ -10,6 +10,7 @@ import ( "encoding/json" "fmt" "slices" + "strings" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" @@ -30,15 +31,19 @@ import ( const ( dnsRecordsRecocilerFinalizer = "tailscale.com/dns-records-reconciler" annotationTSMagicDNSName = "tailscale.com/magic-dnsname" + + // Service types for consistent string usage + serviceTypeIngress = "ingress" + serviceTypeSvc = "svc" ) // dnsRecordsReconciler knows how to update dnsrecords ConfigMap with DNS // records. // The records that it creates are: -// - For tailscale Ingress, a mapping of the Ingress's MagicDNSName to the IP address of -// the ingress proxy Pod. +// - For tailscale Ingress, a mapping of the Ingress's MagicDNSName to the IP addresses +// (both IPv4 and IPv6) of the ingress proxy Pod. // - For egress proxies configured via tailscale.com/tailnet-fqdn annotation, a -// mapping of the tailnet FQDN to the IP address of the egress proxy Pod. +// mapping of the tailnet FQDN to the IP addresses (both IPv4 and IPv6) of the egress proxy Pod. // // Records will only be created if there is exactly one ready // tailscale.com/v1alpha1.DNSConfig instance in the cluster (so that we know @@ -50,7 +55,7 @@ type dnsRecordsReconciler struct { isDefaultLoadBalancer bool // true if operator is the default ingress controller in this cluster } -// Reconcile takes a reconcile.Request for a headless Service fronting a +// Reconcile takes a reconcile.Request for a Service fronting a // tailscale proxy and updates DNS Records in dnsrecords ConfigMap for the // in-cluster ts.net nameserver if required. func (dnsRR *dnsRecordsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { @@ -58,8 +63,8 @@ func (dnsRR *dnsRecordsReconciler) Reconcile(ctx context.Context, req reconcile. logger.Debugf("starting reconcile") defer logger.Debugf("reconcile finished") - headlessSvc := new(corev1.Service) - err = dnsRR.Client.Get(ctx, req.NamespacedName, headlessSvc) + proxySvc := new(corev1.Service) + err = dnsRR.Client.Get(ctx, req.NamespacedName, proxySvc) if apierrors.IsNotFound(err) { logger.Debugf("Service not found") return reconcile.Result{}, nil @@ -67,14 +72,14 @@ func (dnsRR *dnsRecordsReconciler) Reconcile(ctx context.Context, req reconcile. if err != nil { return reconcile.Result{}, fmt.Errorf("failed to get Service: %w", err) } - if !(isManagedByType(headlessSvc, "svc") || isManagedByType(headlessSvc, "ingress")) { - logger.Debugf("Service is not a headless Service for a tailscale ingress or egress proxy; do nothing") + if !(isManagedByType(proxySvc, serviceTypeSvc) || isManagedByType(proxySvc, serviceTypeIngress)) { + logger.Debugf("Service is not a proxy Service for a tailscale ingress or egress proxy; do nothing") return reconcile.Result{}, nil } - if !headlessSvc.DeletionTimestamp.IsZero() { + if !proxySvc.DeletionTimestamp.IsZero() { logger.Debug("Service is being deleted, clean up resources") - return reconcile.Result{}, dnsRR.maybeCleanup(ctx, headlessSvc, logger) + return reconcile.Result{}, dnsRR.maybeCleanup(ctx, proxySvc, logger) } // Check that there is a ts.net nameserver deployed to the cluster by @@ -98,41 +103,45 @@ func (dnsRR *dnsRecordsReconciler) Reconcile(ctx context.Context, req reconcile. return reconcile.Result{}, nil } - return reconcile.Result{}, dnsRR.maybeProvision(ctx, headlessSvc, logger) + if err := dnsRR.maybeProvision(ctx, proxySvc, logger); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } // maybeProvision ensures that dnsrecords ConfigMap contains a record for the -// proxy associated with the headless Service. +// proxy associated with the Service. // The record is only provisioned if the proxy is for a tailscale Ingress or // egress configured via tailscale.com/tailnet-fqdn annotation. // // For Ingress, the record is a mapping between the MagicDNSName of the Ingress, retrieved from // ingress.status.loadBalancer.ingress.hostname field and the proxy Pod IP addresses -// retrieved from the EndpoinSlice associated with this headless Service, i.e -// Records{IP4: : <[IPs of the ingress proxy Pods]>} +// retrieved from the EndpointSlice associated with this Service, i.e +// Records{IP4: {: <[IPv4 addresses]>}, IP6: {: <[IPv6 addresses]>}} // // For egress, the record is a mapping between tailscale.com/tailnet-fqdn // annotation and the proxy Pod IP addresses, retrieved from the EndpointSlice -// associated with this headless Service, i.e -// Records{IP4: {: <[IPs of the egress proxy Pods]>} +// associated with this Service, i.e +// Records{IP4: {: <[IPv4 addresses]>}, IP6: {: <[IPv6 addresses]>}} +// +// For ProxyGroup egress, the record is a mapping between tailscale.com/magic-dnsname +// annotation and the ClusterIP Service IPs (which provides portmapping), i.e +// Records{IP4: {: <[IPv4 ClusterIPs]>}, IP6: {: <[IPv6 ClusterIPs]>}} // // If records need to be created for this proxy, maybeProvision will also: -// - update the headless Service with a tailscale.com/magic-dnsname annotation -// - update the headless Service with a finalizer -func (dnsRR *dnsRecordsReconciler) maybeProvision(ctx context.Context, headlessSvc *corev1.Service, logger *zap.SugaredLogger) error { - if headlessSvc == nil { - logger.Info("[unexpected] maybeProvision called with a nil Service") - return nil - } - isEgressFQDNSvc, err := dnsRR.isSvcForFQDNEgressProxy(ctx, headlessSvc) - if err != nil { - return fmt.Errorf("error checking whether the Service is for an egress proxy: %w", err) - } - if !(isEgressFQDNSvc || isManagedByType(headlessSvc, "ingress")) { +// - update the Service with a tailscale.com/magic-dnsname annotation +// - update the Service with a finalizer +func (dnsRR *dnsRecordsReconciler) maybeProvision(ctx context.Context, proxySvc *corev1.Service, logger *zap.SugaredLogger) error { + if !dnsRR.isInterestingService(ctx, proxySvc) { logger.Debug("Service is not fronting a proxy that we create DNS records for; do nothing") return nil } - fqdn, err := dnsRR.fqdnForDNSRecord(ctx, headlessSvc, logger) + fqdn, err := dnsRR.fqdnForDNSRecord(ctx, proxySvc, logger) if err != nil { return fmt.Errorf("error determining DNS name for record: %w", err) } @@ -141,18 +150,18 @@ func (dnsRR *dnsRecordsReconciler) maybeProvision(ctx context.Context, headlessS return nil // a new reconcile will be triggered once it's added } - oldHeadlessSvc := headlessSvc.DeepCopy() - // Ensure that headless Service is annotated with a finalizer to help + oldProxySvc := proxySvc.DeepCopy() + // Ensure that proxy Service is annotated with a finalizer to help // with records cleanup when proxy resources are deleted. - if !slices.Contains(headlessSvc.Finalizers, dnsRecordsRecocilerFinalizer) { - headlessSvc.Finalizers = append(headlessSvc.Finalizers, dnsRecordsRecocilerFinalizer) + if !slices.Contains(proxySvc.Finalizers, dnsRecordsRecocilerFinalizer) { + proxySvc.Finalizers = append(proxySvc.Finalizers, dnsRecordsRecocilerFinalizer) } - // Ensure that headless Service is annotated with the current MagicDNS + // Ensure that proxy Service is annotated with the current MagicDNS // name to help with records cleanup when proxy resources are deleted or // MagicDNS name changes. - oldFqdn := headlessSvc.Annotations[annotationTSMagicDNSName] + oldFqdn := proxySvc.Annotations[annotationTSMagicDNSName] if oldFqdn != "" && oldFqdn != fqdn { // i.e user has changed the value of tailscale.com/tailnet-fqdn annotation - logger.Debugf("MagicDNS name has changed, remvoving record for %s", oldFqdn) + logger.Debugf("MagicDNS name has changed, removing record for %s", oldFqdn) updateFunc := func(rec *operatorutils.Records) { delete(rec.IP4, oldFqdn) } @@ -160,58 +169,32 @@ func (dnsRR *dnsRecordsReconciler) maybeProvision(ctx context.Context, headlessS return fmt.Errorf("error removing record for %s: %w", oldFqdn, err) } } - mak.Set(&headlessSvc.Annotations, annotationTSMagicDNSName, fqdn) - if !apiequality.Semantic.DeepEqual(oldHeadlessSvc, headlessSvc) { + mak.Set(&proxySvc.Annotations, annotationTSMagicDNSName, fqdn) + if !apiequality.Semantic.DeepEqual(oldProxySvc, proxySvc) { logger.Infof("provisioning DNS record for MagicDNS name: %s", fqdn) // this will be printed exactly once - if err := dnsRR.Update(ctx, headlessSvc); err != nil { - return fmt.Errorf("error updating proxy headless Service metadata: %w", err) + if err := dnsRR.Update(ctx, proxySvc); err != nil { + return fmt.Errorf("error updating proxy Service metadata: %w", err) } } - // Get the Pod IP addresses for the proxy from the EndpointSlices for - // the headless Service. The Service can have multiple EndpointSlices - // associated with it, for example in dual-stack clusters. - labels := map[string]string{discoveryv1.LabelServiceName: headlessSvc.Name} // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ownership - var eps = new(discoveryv1.EndpointSliceList) - if err := dnsRR.List(ctx, eps, client.InNamespace(dnsRR.tsNamespace), client.MatchingLabels(labels)); err != nil { - return fmt.Errorf("error listing EndpointSlices for the proxy's headless Service: %w", err) + // Get the IP addresses for the DNS record + ip4s, ip6s, err := dnsRR.getTargetIPs(ctx, proxySvc, logger) + if err != nil { + return fmt.Errorf("error getting target IPs: %w", err) } - if len(eps.Items) == 0 { - logger.Debugf("proxy's headless Service EndpointSlice does not yet exist. We will reconcile again once it's created") + if len(ip4s) == 0 && len(ip6s) == 0 { + logger.Debugf("No target IP addresses available yet. We will reconcile again once they are available.") return nil } - // Each EndpointSlice for a Service can have a list of endpoints that each - // can have multiple addresses - these are the IP addresses of any Pods - // selected by that Service. Pick all the IPv4 addresses. - // It is also possible that multiple EndpointSlices have overlapping addresses. - // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#duplicate-endpoints - ips := make(set.Set[string], 0) - for _, slice := range eps.Items { - if slice.AddressType != discoveryv1.AddressTypeIPv4 { - logger.Infof("EndpointSlice is for AddressType %s, currently only IPv4 address type is supported", slice.AddressType) - continue + + updateFunc := func(rec *operatorutils.Records) { + if len(ip4s) > 0 { + mak.Set(&rec.IP4, fqdn, ip4s) } - for _, ep := range slice.Endpoints { - if !epIsReady(&ep) { - logger.Debugf("Endpoint with addresses %v appears not ready to receive traffic %v", ep.Addresses, ep.Conditions.String()) - continue - } - for _, ip := range ep.Addresses { - if !net.IsIPv4String(ip) { - logger.Infof("EndpointSlice contains IP address %q that is not IPv4, ignoring. Currently only IPv4 is supported", ip) - } else { - ips.Add(ip) - } - } + if len(ip6s) > 0 { + mak.Set(&rec.IP6, fqdn, ip6s) } } - if ips.Len() == 0 { - logger.Debugf("EndpointSlice for the Service contains no IPv4 addresses. We will reconcile again once they are created.") - return nil - } - updateFunc := func(rec *operatorutils.Records) { - mak.Set(&rec.IP4, fqdn, ips.Slice()) - } if err = dnsRR.updateDNSConfig(ctx, updateFunc); err != nil { return fmt.Errorf("error updating DNS records: %w", err) } @@ -234,62 +217,66 @@ func epIsReady(ep *discoveryv1.Endpoint) bool { // has been removed from the Service. If the record is not found in the // ConfigMap, the ConfigMap does not exist, or the Service does not have // tailscale.com/magic-dnsname annotation, just remove the finalizer. -func (h *dnsRecordsReconciler) maybeCleanup(ctx context.Context, headlessSvc *corev1.Service, logger *zap.SugaredLogger) error { - ix := slices.Index(headlessSvc.Finalizers, dnsRecordsRecocilerFinalizer) +func (dnsRR *dnsRecordsReconciler) maybeCleanup(ctx context.Context, proxySvc *corev1.Service, logger *zap.SugaredLogger) error { + ix := slices.Index(proxySvc.Finalizers, dnsRecordsRecocilerFinalizer) if ix == -1 { logger.Debugf("no finalizer, nothing to do") return nil } cm := &corev1.ConfigMap{} - err := h.Client.Get(ctx, types.NamespacedName{Name: operatorutils.DNSRecordsCMName, Namespace: h.tsNamespace}, cm) + err := dnsRR.Client.Get(ctx, types.NamespacedName{Name: operatorutils.DNSRecordsCMName, Namespace: dnsRR.tsNamespace}, cm) if apierrors.IsNotFound(err) { - logger.Debug("'dsnrecords' ConfigMap not found") - return h.removeHeadlessSvcFinalizer(ctx, headlessSvc) + logger.Debug("'dnsrecords' ConfigMap not found") + return dnsRR.removeProxySvcFinalizer(ctx, proxySvc) } if err != nil { return fmt.Errorf("error retrieving 'dnsrecords' ConfigMap: %w", err) } if cm.Data == nil { logger.Debug("'dnsrecords' ConfigMap contains no records") - return h.removeHeadlessSvcFinalizer(ctx, headlessSvc) + return dnsRR.removeProxySvcFinalizer(ctx, proxySvc) } _, ok := cm.Data[operatorutils.DNSRecordsCMKey] if !ok { logger.Debug("'dnsrecords' ConfigMap contains no records") - return h.removeHeadlessSvcFinalizer(ctx, headlessSvc) + return dnsRR.removeProxySvcFinalizer(ctx, proxySvc) } - fqdn, _ := headlessSvc.GetAnnotations()[annotationTSMagicDNSName] + fqdn := proxySvc.GetAnnotations()[annotationTSMagicDNSName] if fqdn == "" { - return h.removeHeadlessSvcFinalizer(ctx, headlessSvc) + return dnsRR.removeProxySvcFinalizer(ctx, proxySvc) } logger.Infof("removing DNS record for MagicDNS name %s", fqdn) updateFunc := func(rec *operatorutils.Records) { delete(rec.IP4, fqdn) + if rec.IP6 != nil { + delete(rec.IP6, fqdn) + } } - if err = h.updateDNSConfig(ctx, updateFunc); err != nil { + if err = dnsRR.updateDNSConfig(ctx, updateFunc); err != nil { return fmt.Errorf("error updating DNS config: %w", err) } - return h.removeHeadlessSvcFinalizer(ctx, headlessSvc) + return dnsRR.removeProxySvcFinalizer(ctx, proxySvc) } -func (dnsRR *dnsRecordsReconciler) removeHeadlessSvcFinalizer(ctx context.Context, headlessSvc *corev1.Service) error { - idx := slices.Index(headlessSvc.Finalizers, dnsRecordsRecocilerFinalizer) +func (dnsRR *dnsRecordsReconciler) removeProxySvcFinalizer(ctx context.Context, proxySvc *corev1.Service) error { + idx := slices.Index(proxySvc.Finalizers, dnsRecordsRecocilerFinalizer) if idx == -1 { return nil } - headlessSvc.Finalizers = append(headlessSvc.Finalizers[:idx], headlessSvc.Finalizers[idx+1:]...) - return dnsRR.Update(ctx, headlessSvc) + proxySvc.Finalizers = slices.Delete(proxySvc.Finalizers, idx, idx+1) + return dnsRR.Update(ctx, proxySvc) } -// fqdnForDNSRecord returns MagicDNS name associated with a given headless Service. -// If the headless Service is for a tailscale Ingress proxy, returns ingress.status.loadBalancer.ingress.hostname. -// If the headless Service is for an tailscale egress proxy configured via tailscale.com/tailnet-fqdn annotation, returns the annotation value. -// This function is not expected to be called with headless Services for other +// fqdnForDNSRecord returns MagicDNS name associated with a given proxy Service. +// If the proxy Service is for a tailscale Ingress proxy, returns ingress.status.loadBalancer.ingress.hostname. +// If the proxy Service is for an tailscale egress proxy configured via tailscale.com/tailnet-fqdn annotation, returns the annotation value. +// For ProxyGroup egress Services, returns the tailnet-fqdn annotation from the parent Service. +// This function is not expected to be called with proxy Services for other // proxy types, or any other Services, but it just returns an empty string if // that happens. -func (dnsRR *dnsRecordsReconciler) fqdnForDNSRecord(ctx context.Context, headlessSvc *corev1.Service, logger *zap.SugaredLogger) (string, error) { - parentName := parentFromObjectLabels(headlessSvc) - if isManagedByType(headlessSvc, "ingress") { +func (dnsRR *dnsRecordsReconciler) fqdnForDNSRecord(ctx context.Context, proxySvc *corev1.Service, logger *zap.SugaredLogger) (string, error) { + parentName := parentFromObjectLabels(proxySvc) + if isManagedByType(proxySvc, serviceTypeIngress) { ing := new(networkingv1.Ingress) if err := dnsRR.Get(ctx, parentName, ing); err != nil { return "", err @@ -299,10 +286,10 @@ func (dnsRR *dnsRecordsReconciler) fqdnForDNSRecord(ctx context.Context, headles } return ing.Status.LoadBalancer.Ingress[0].Hostname, nil } - if isManagedByType(headlessSvc, "svc") { + if isManagedByType(proxySvc, serviceTypeSvc) { svc := new(corev1.Service) if err := dnsRR.Get(ctx, parentName, svc); apierrors.IsNotFound(err) { - logger.Info("[unexpected] parent Service for egress proxy %s not found", headlessSvc.Name) + logger.Infof("[unexpected] parent Service for egress proxy %s not found", proxySvc.Name) return "", nil } else if err != nil { return "", err @@ -319,7 +306,7 @@ func (dnsRR *dnsRecordsReconciler) updateDNSConfig(ctx context.Context, update f cm := &corev1.ConfigMap{} err := dnsRR.Get(ctx, types.NamespacedName{Name: operatorutils.DNSRecordsCMName, Namespace: dnsRR.tsNamespace}, cm) if apierrors.IsNotFound(err) { - dnsRR.logger.Info("[unexpected] dnsrecords ConfigMap not found in cluster. Not updating DNS records. Please open an isue and attach operator logs.") + dnsRR.logger.Info("[unexpected] dnsrecords ConfigMap not found in cluster. Not updating DNS records. Please open an issue and attach operator logs.") return nil } if err != nil { @@ -357,3 +344,153 @@ func (dnsRR *dnsRecordsReconciler) isSvcForFQDNEgressProxy(ctx context.Context, annots := parentSvc.Annotations return annots != nil && annots[AnnotationTailnetTargetFQDN] != "", nil } + +// isProxyGroupEgressService reports whether the Service is a ClusterIP Service +// created for ProxyGroup egress. For ProxyGroup egress, there are no headless +// services. Instead, the DNS reconciler processes the ClusterIP Service +// directly, which has portmapping and should use its own IP for DNS records. +func (dnsRR *dnsRecordsReconciler) isProxyGroupEgressService(svc *corev1.Service) bool { + return svc.GetLabels()[labelProxyGroup] != "" && + svc.GetLabels()[labelSvcType] == typeEgress && + svc.Spec.Type == corev1.ServiceTypeClusterIP && + isManagedByType(svc, serviceTypeSvc) +} + +// isInterestingService reports whether the Service is one that we should create +// DNS records for. +func (dnsRR *dnsRecordsReconciler) isInterestingService(ctx context.Context, svc *corev1.Service) bool { + if isManagedByType(svc, serviceTypeIngress) { + return true + } + + isEgressFQDNSvc, err := dnsRR.isSvcForFQDNEgressProxy(ctx, svc) + if err != nil { + return false + } + if isEgressFQDNSvc { + return true + } + + if dnsRR.isProxyGroupEgressService(svc) { + return dnsRR.parentSvcTargetsFQDN(ctx, svc) + } + + return false +} + +// parentSvcTargetsFQDN reports whether the parent Service of a ProxyGroup +// egress Service has an FQDN target (not an IP target). +func (dnsRR *dnsRecordsReconciler) parentSvcTargetsFQDN(ctx context.Context, svc *corev1.Service) bool { + + parentName := parentFromObjectLabels(svc) + parentSvc := new(corev1.Service) + if err := dnsRR.Get(ctx, parentName, parentSvc); err != nil { + return false + } + + return parentSvc.Annotations[AnnotationTailnetTargetFQDN] != "" +} + +// getTargetIPs returns the IPv4 and IPv6 addresses that should be used for DNS records +// for the given proxy Service. +func (dnsRR *dnsRecordsReconciler) getTargetIPs(ctx context.Context, proxySvc *corev1.Service, logger *zap.SugaredLogger) ([]string, []string, error) { + if dnsRR.isProxyGroupEgressService(proxySvc) { + return dnsRR.getClusterIPServiceIPs(proxySvc, logger) + } + return dnsRR.getPodIPs(ctx, proxySvc, logger) +} + +// getClusterIPServiceIPs returns the ClusterIPs of a ProxyGroup egress Service. +// It separates IPv4 and IPv6 addresses for dual-stack services. +func (dnsRR *dnsRecordsReconciler) getClusterIPServiceIPs(proxySvc *corev1.Service, logger *zap.SugaredLogger) ([]string, []string, error) { + // Handle services with no ClusterIP + if proxySvc.Spec.ClusterIP == "" || proxySvc.Spec.ClusterIP == "None" { + logger.Debugf("ProxyGroup egress ClusterIP Service does not have a ClusterIP yet.") + return nil, nil, nil + } + + var ip4s, ip6s []string + + // Check all ClusterIPs for dual-stack support + clusterIPs := proxySvc.Spec.ClusterIPs + if len(clusterIPs) == 0 && proxySvc.Spec.ClusterIP != "" { + // Fallback to single ClusterIP for backward compatibility + clusterIPs = []string{proxySvc.Spec.ClusterIP} + } + + for _, ip := range clusterIPs { + if net.IsIPv4String(ip) { + ip4s = append(ip4s, ip) + logger.Debugf("Using IPv4 ClusterIP %s for ProxyGroup egress DNS record", ip) + } else if net.IsIPv6String(ip) { + ip6s = append(ip6s, ip) + logger.Debugf("Using IPv6 ClusterIP %s for ProxyGroup egress DNS record", ip) + } else { + logger.Debugf("ClusterIP %s is not a valid IP address", ip) + } + } + + if len(ip4s) == 0 && len(ip6s) == 0 { + return nil, nil, fmt.Errorf("no valid ClusterIPs found") + } + + return ip4s, ip6s, nil +} + +// getPodIPs returns Pod IPv4 and IPv6 addresses from EndpointSlices for non-ProxyGroup Services. +func (dnsRR *dnsRecordsReconciler) getPodIPs(ctx context.Context, proxySvc *corev1.Service, logger *zap.SugaredLogger) ([]string, []string, error) { + // Get the Pod IP addresses for the proxy from the EndpointSlices for + // the headless Service. The Service can have multiple EndpointSlices + // associated with it, for example in dual-stack clusters. + labels := map[string]string{discoveryv1.LabelServiceName: proxySvc.Name} // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ownership + var eps = new(discoveryv1.EndpointSliceList) + if err := dnsRR.List(ctx, eps, client.InNamespace(dnsRR.tsNamespace), client.MatchingLabels(labels)); err != nil { + return nil, nil, fmt.Errorf("error listing EndpointSlices for the proxy's Service: %w", err) + } + if len(eps.Items) == 0 { + logger.Debugf("proxy's Service EndpointSlice does not yet exist.") + return nil, nil, nil + } + // Each EndpointSlice for a Service can have a list of endpoints that each + // can have multiple addresses - these are the IP addresses of any Pods + // selected by that Service. Separate IPv4 and IPv6 addresses. + // It is also possible that multiple EndpointSlices have overlapping addresses. + // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#duplicate-endpoints + ip4s := make(set.Set[string], 0) + ip6s := make(set.Set[string], 0) + for _, slice := range eps.Items { + for _, ep := range slice.Endpoints { + if !epIsReady(&ep) { + logger.Debugf("Endpoint with addresses %v appears not ready to receive traffic %v", ep.Addresses, ep.Conditions.String()) + continue + } + for _, ip := range ep.Addresses { + switch slice.AddressType { + case discoveryv1.AddressTypeIPv4: + if net.IsIPv4String(ip) { + ip4s.Add(ip) + } else { + logger.Debugf("EndpointSlice with AddressType IPv4 contains non-IPv4 address %q, ignoring", ip) + } + case discoveryv1.AddressTypeIPv6: + if net.IsIPv6String(ip) { + // Strip zone ID if present (e.g., fe80::1%eth0 -> fe80::1) + if idx := strings.IndexByte(ip, '%'); idx != -1 { + ip = ip[:idx] + } + ip6s.Add(ip) + } else { + logger.Debugf("EndpointSlice with AddressType IPv6 contains non-IPv6 address %q, ignoring", ip) + } + default: + logger.Debugf("EndpointSlice is for unsupported AddressType %s, skipping", slice.AddressType) + } + } + } + } + if ip4s.Len() == 0 && ip6s.Len() == 0 { + logger.Debugf("EndpointSlice for the Service contains no IP addresses.") + return nil, nil, nil + } + return ip4s.Slice(), ip6s.Slice(), nil +} diff --git a/cmd/k8s-operator/dnsrecords_test.go b/cmd/k8s-operator/dnsrecords_test.go index 389461b85..13898078f 100644 --- a/cmd/k8s-operator/dnsrecords_test.go +++ b/cmd/k8s-operator/dnsrecords_test.go @@ -18,10 +18,12 @@ import ( networkingv1 "k8s.io/api/networking/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" operatorutils "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" "tailscale.com/tstest" "tailscale.com/types/ptr" ) @@ -65,7 +67,7 @@ func TestDNSRecordsReconciler(t *testing.T) { } cl := tstest.NewClock(tstest.ClockOpts{}) // Set the ready condition of the DNSConfig - mustUpdateStatus[tsapi.DNSConfig](t, fc, "", "test", func(c *tsapi.DNSConfig) { + mustUpdateStatus(t, fc, "", "test", func(c *tsapi.DNSConfig) { operatorutils.SetDNSConfigCondition(c, tsapi.NameserverReady, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated, 0, cl, zl.Sugar()) }) dnsRR := &dnsRecordsReconciler{ @@ -97,8 +99,9 @@ func TestDNSRecordsReconciler(t *testing.T) { mustCreate(t, fc, epv6) expectReconciled(t, dnsRR, "tailscale", "egress-fqdn") // dns-records-reconciler reconcile the headless Service // ConfigMap should now have a record for foo.bar.ts.net -> 10.8.8.7 - wantHosts := map[string][]string{"foo.bar.ts.net": {"10.9.8.7"}} // IPv6 endpoint is currently ignored - expectHostsRecords(t, fc, wantHosts) + wantHosts := map[string][]string{"foo.bar.ts.net": {"10.9.8.7"}} + wantHostsIPv6 := map[string][]string{"foo.bar.ts.net": {"2600:1900:4011:161:0:d:0:d"}} + expectHostsRecordsWithIPv6(t, fc, wantHosts, wantHostsIPv6) // 2. DNS record is updated if tailscale.com/tailnet-fqdn annotation's // value changes @@ -155,6 +158,262 @@ func TestDNSRecordsReconciler(t *testing.T) { expectReconciled(t, dnsRR, "tailscale", "ts-ingress") wantHosts["another.ingress.ts.net"] = []string{"1.2.3.4"} expectHostsRecords(t, fc, wantHosts) + + // 8. DNS record is created for ProxyGroup egress using ClusterIP Service IP instead of Pod IPs + t.Log("test case 8: ProxyGroup egress") + + // Create the parent ExternalName service with tailnet-fqdn annotation + parentEgressSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "external-service", + Namespace: "default", + Annotations: map[string]string{ + AnnotationTailnetTargetFQDN: "external-service.example.ts.net", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeExternalName, + ExternalName: "unused", + }, + } + mustCreate(t, fc, parentEgressSvc) + + proxyGroupEgressSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ts-proxygroup-egress-abcd1", + Namespace: "tailscale", + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentName: "external-service", + LabelParentNamespace: "default", + LabelParentType: "svc", + labelProxyGroup: "test-proxy-group", + labelSvcType: typeEgress, + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeClusterIP, + ClusterIP: "10.0.100.50", // This IP should be used in DNS, not Pod IPs + Ports: []corev1.ServicePort{{ + Port: 443, + TargetPort: intstr.FromInt(10443), // Port mapping + }}, + }, + } + + // Create EndpointSlice with Pod IPs (these should NOT be used in DNS records) + proxyGroupEps := &discoveryv1.EndpointSlice{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ts-proxygroup-egress-abcd1-ipv4", + Namespace: "tailscale", + Labels: map[string]string{ + discoveryv1.LabelServiceName: "ts-proxygroup-egress-abcd1", + kubetypes.LabelManaged: "true", + LabelParentName: "external-service", + LabelParentNamespace: "default", + LabelParentType: "svc", + labelProxyGroup: "test-proxy-group", + labelSvcType: typeEgress, + }, + }, + AddressType: discoveryv1.AddressTypeIPv4, + Endpoints: []discoveryv1.Endpoint{{ + Addresses: []string{"10.1.0.100", "10.1.0.101", "10.1.0.102"}, // Pod IPs that should NOT be used + Conditions: discoveryv1.EndpointConditions{ + Ready: ptr.To(true), + Serving: ptr.To(true), + Terminating: ptr.To(false), + }, + }}, + Ports: []discoveryv1.EndpointPort{{ + Port: ptr.To(int32(10443)), + }}, + } + + mustCreate(t, fc, proxyGroupEgressSvc) + mustCreate(t, fc, proxyGroupEps) + expectReconciled(t, dnsRR, "tailscale", "ts-proxygroup-egress-abcd1") + + // Verify DNS record uses ClusterIP Service IP, not Pod IPs + wantHosts["external-service.example.ts.net"] = []string{"10.0.100.50"} + expectHostsRecords(t, fc, wantHosts) + + // 9. ProxyGroup egress DNS record updates when ClusterIP changes + t.Log("test case 9: ProxyGroup egress ClusterIP change") + mustUpdate(t, fc, "tailscale", "ts-proxygroup-egress-abcd1", func(svc *corev1.Service) { + svc.Spec.ClusterIP = "10.0.100.51" + }) + expectReconciled(t, dnsRR, "tailscale", "ts-proxygroup-egress-abcd1") + wantHosts["external-service.example.ts.net"] = []string{"10.0.100.51"} + expectHostsRecords(t, fc, wantHosts) + + // 10. Test ProxyGroup service deletion and DNS cleanup + t.Log("test case 10: ProxyGroup egress service deletion") + mustDeleteAll(t, fc, proxyGroupEgressSvc) + expectReconciled(t, dnsRR, "tailscale", "ts-proxygroup-egress-abcd1") + delete(wantHosts, "external-service.example.ts.net") + expectHostsRecords(t, fc, wantHosts) +} + +func TestDNSRecordsReconcilerErrorCases(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + dnsRR := &dnsRecordsReconciler{ + logger: zl.Sugar(), + } + + testSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + Spec: corev1.ServiceSpec{Type: corev1.ServiceTypeClusterIP}, + } + + // Test invalid IP format + testSvc.Spec.ClusterIP = "invalid-ip" + _, _, err = dnsRR.getClusterIPServiceIPs(testSvc, zl.Sugar()) + if err == nil { + t.Error("expected error for invalid IP format") + } + + // Test valid IP + testSvc.Spec.ClusterIP = "10.0.100.50" + ip4s, ip6s, err := dnsRR.getClusterIPServiceIPs(testSvc, zl.Sugar()) + if err != nil { + t.Errorf("unexpected error for valid IP: %v", err) + } + if len(ip4s) != 1 || ip4s[0] != "10.0.100.50" { + t.Errorf("expected IPv4 address 10.0.100.50, got %v", ip4s) + } + if len(ip6s) != 0 { + t.Errorf("expected no IPv6 addresses, got %v", ip6s) + } +} + +func TestDNSRecordsReconcilerDualStack(t *testing.T) { + // Test dual-stack (IPv4 and IPv6) scenarios + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + // Preconfigure cluster with DNSConfig + dnsCfg := &tsapi.DNSConfig{ + ObjectMeta: metav1.ObjectMeta{Name: "test"}, + TypeMeta: metav1.TypeMeta{Kind: "DNSConfig"}, + Spec: tsapi.DNSConfigSpec{Nameserver: &tsapi.Nameserver{}}, + } + dnsCfg.Status.Conditions = append(dnsCfg.Status.Conditions, metav1.Condition{ + Type: string(tsapi.NameserverReady), + Status: metav1.ConditionTrue, + }) + + // Create dual-stack ingress + ing := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "dual-stack-ingress", + Namespace: "test", + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + }, + Status: networkingv1.IngressStatus{ + LoadBalancer: networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "dual-stack.example.ts.net"}, + }, + }, + }, + } + + headlessSvc := headlessSvcForParent(ing, "ingress") + headlessSvc.Name = "ts-dual-stack-ingress" + headlessSvc.SetLabels(map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentName: "dual-stack-ingress", + LabelParentNamespace: "test", + LabelParentType: "ingress", + }) + + // Create both IPv4 and IPv6 endpoints + epv4 := endpointSliceForService(headlessSvc, "10.1.2.3", discoveryv1.AddressTypeIPv4) + epv6 := endpointSliceForService(headlessSvc, "2001:db8::1", discoveryv1.AddressTypeIPv6) + + dnsRRDualStack := &dnsRecordsReconciler{ + tsNamespace: "tailscale", + logger: zl.Sugar(), + } + + // Create the dnsrecords ConfigMap + cm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: operatorutils.DNSRecordsCMName, + Namespace: "tailscale", + }, + } + + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(dnsCfg, ing, headlessSvc, epv4, epv6, cm). + WithStatusSubresource(dnsCfg). + Build() + + dnsRRDualStack.Client = fc + + // Test dual-stack service records + expectReconciled(t, dnsRRDualStack, "tailscale", "ts-dual-stack-ingress") + + wantIPv4 := map[string][]string{"dual-stack.example.ts.net": {"10.1.2.3"}} + wantIPv6 := map[string][]string{"dual-stack.example.ts.net": {"2001:db8::1"}} + expectHostsRecordsWithIPv6(t, fc, wantIPv4, wantIPv6) + + // Test ProxyGroup with dual-stack ClusterIPs + // First create parent service + parentEgressSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pg-service", + Namespace: "tailscale", + Annotations: map[string]string{ + AnnotationTailnetTargetFQDN: "pg-service.example.ts.net", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeExternalName, + ExternalName: "unused", + }, + } + + proxyGroupSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "ts-proxygroup-dualstack", + Namespace: "tailscale", + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: "test-pg", + labelSvcType: typeEgress, + LabelParentName: "pg-service", + LabelParentNamespace: "tailscale", + LabelParentType: "svc", + }, + Annotations: map[string]string{ + annotationTSMagicDNSName: "pg-service.example.ts.net", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeClusterIP, + ClusterIP: "10.96.0.100", + ClusterIPs: []string{"10.96.0.100", "2001:db8::100"}, + }, + } + + mustCreate(t, fc, parentEgressSvc) + mustCreate(t, fc, proxyGroupSvc) + expectReconciled(t, dnsRRDualStack, "tailscale", "ts-proxygroup-dualstack") + + wantIPv4["pg-service.example.ts.net"] = []string{"10.96.0.100"} + wantIPv6["pg-service.example.ts.net"] = []string{"2001:db8::100"} + expectHostsRecordsWithIPv6(t, fc, wantIPv4, wantIPv6) } func headlessSvcForParent(o client.Object, typ string) *corev1.Service { @@ -163,10 +422,10 @@ func headlessSvcForParent(o client.Object, typ string) *corev1.Service { Name: o.GetName(), Namespace: "tailscale", Labels: map[string]string{ - LabelManaged: "true", - LabelParentName: o.GetName(), - LabelParentNamespace: o.GetNamespace(), - LabelParentType: typ, + kubetypes.LabelManaged: "true", + LabelParentName: o.GetName(), + LabelParentNamespace: o.GetNamespace(), + LabelParentType: typ, }, }, Spec: corev1.ServiceSpec{ @@ -217,3 +476,28 @@ func expectHostsRecords(t *testing.T, cl client.Client, wantsHosts map[string][] t.Fatalf("unexpected dns config (-got +want):\n%s", diff) } } + +func expectHostsRecordsWithIPv6(t *testing.T, cl client.Client, wantsHostsIPv4, wantsHostsIPv6 map[string][]string) { + t.Helper() + cm := new(corev1.ConfigMap) + if err := cl.Get(context.Background(), types.NamespacedName{Name: "dnsrecords", Namespace: "tailscale"}, cm); err != nil { + t.Fatalf("getting dnsconfig ConfigMap: %v", err) + } + if cm.Data == nil { + t.Fatal("dnsconfig ConfigMap has no data") + } + dnsConfigString, ok := cm.Data[operatorutils.DNSRecordsCMKey] + if !ok { + t.Fatal("dnsconfig ConfigMap does not contain dnsconfig") + } + dnsConfig := &operatorutils.Records{} + if err := json.Unmarshal([]byte(dnsConfigString), dnsConfig); err != nil { + t.Fatalf("unmarshaling dnsconfig: %v", err) + } + if diff := cmp.Diff(dnsConfig.IP4, wantsHostsIPv4); diff != "" { + t.Fatalf("unexpected IPv4 dns config (-got +want):\n%s", diff) + } + if diff := cmp.Diff(dnsConfig.IP6, wantsHostsIPv6); diff != "" { + t.Fatalf("unexpected IPv6 dns config (-got +want):\n%s", diff) + } +} diff --git a/cmd/k8s-operator/e2e/acl.hujson b/cmd/k8s-operator/e2e/acl.hujson new file mode 100644 index 000000000..1a7b61767 --- /dev/null +++ b/cmd/k8s-operator/e2e/acl.hujson @@ -0,0 +1,33 @@ +// To run the e2e tests against a tailnet, ensure its access controls are a +// superset of the following: +{ + "tagOwners": { + "tag:k8s-operator": [], + "tag:k8s": ["tag:k8s-operator"], + "tag:k8s-recorder": ["tag:k8s-operator"], + }, + "autoApprovers": { + // Could be relaxed if we coordinated with the cluster config, but this + // wide subnet maximises compatibility for most clusters. + "routes": { + "10.0.0.0/8": ["tag:k8s"], + }, + "services": { + "tag:k8s": ["tag:k8s"], + }, + }, + "grants": [ + { + "src": ["tag:k8s"], + "dst": ["tag:k8s", "tag:k8s-operator"], + "ip": ["tcp:80", "tcp:443"], + "app": { + "tailscale.com/cap/kubernetes": [{ + "impersonate": { + "groups": ["ts:e2e-test-proxy"], + }, + }], + }, + }, + ], +} \ No newline at end of file diff --git a/cmd/k8s-operator/e2e/ingress_test.go b/cmd/k8s-operator/e2e/ingress_test.go new file mode 100644 index 000000000..23f0711ec --- /dev/null +++ b/cmd/k8s-operator/e2e/ingress_test.go @@ -0,0 +1,130 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/config" + kube "tailscale.com/k8s-operator" + "tailscale.com/tstest" + "tailscale.com/types/ptr" + "tailscale.com/util/httpm" +) + +// See [TestMain] for test requirements. +func TestIngress(t *testing.T) { + if apiClient == nil { + t.Skip("TestIngress requires TS_API_CLIENT_SECRET set") + } + + cfg := config.GetConfigOrDie() + cl, err := client.New(cfg, client.Options{}) + if err != nil { + t.Fatal(err) + } + // Apply nginx + createAndCleanup(t, cl, + &appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "nginx", + Namespace: "default", + Labels: map[string]string{ + "app.kubernetes.io/name": "nginx", + }, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: ptr.To[int32](1), + Selector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "app.kubernetes.io/name": "nginx", + }, + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "app.kubernetes.io/name": "nginx", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "nginx", + Image: "nginx", + }, + }, + }, + }, + }, + }) + // Apply service to expose it as ingress + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + Annotations: map[string]string{ + "tailscale.com/expose": "true", + "tailscale.com/proxy-class": "prod", + }, + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{ + "app.kubernetes.io/name": "nginx", + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: "TCP", + Port: 80, + }, + }, + }, + } + createAndCleanup(t, cl, svc) + + // TODO: instead of timing out only when test times out, cancel context after 60s or so. + if err := wait.PollUntilContextCancel(t.Context(), time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) { + maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta("default", "test-ingress")} + if err := get(ctx, cl, maybeReadySvc); err != nil { + return false, err + } + isReady := kube.SvcIsReady(maybeReadySvc) + if isReady { + t.Log("Service is ready") + } + return isReady, nil + }); err != nil { + t.Fatalf("error waiting for the Service to become Ready: %v", err) + } + + var resp *http.Response + if err := tstest.WaitFor(time.Minute, func() error { + // TODO(tomhjp): Get the tailnet DNS name from the associated secret instead. + // If we are not the first tailnet node with the requested name, we'll get + // a -N suffix. + req, err := http.NewRequest(httpm.GET, fmt.Sprintf("http://%s-%s:80", svc.Namespace, svc.Name), nil) + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(t.Context(), time.Second) + defer cancel() + resp, err = tailnetClient.HTTPClient().Do(req.WithContext(ctx)) + return err + }); err != nil { + t.Fatalf("error trying to reach Service: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v; response body s", resp.StatusCode) + } +} diff --git a/cmd/k8s-operator/e2e/main_test.go b/cmd/k8s-operator/e2e/main_test.go new file mode 100644 index 000000000..fb5e5c859 --- /dev/null +++ b/cmd/k8s-operator/e2e/main_test.go @@ -0,0 +1,127 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "context" + "errors" + "log" + "os" + "strings" + "testing" + "time" + + "golang.org/x/oauth2/clientcredentials" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn/store/mem" + "tailscale.com/tsnet" +) + +// This test suite is currently not run in CI. +// It requires some setup not handled by this code: +// - Kubernetes cluster with local kubeconfig for it (direct connection, no API server proxy) +// - Tailscale operator installed with --set apiServerProxyConfig.mode="true" +// - ACLs from acl.hujson +// - OAuth client secret in TS_API_CLIENT_SECRET env, with at least auth_keys write scope and tag:k8s tag +var ( + apiClient *tailscale.Client // For API calls to control. + tailnetClient *tsnet.Server // For testing real tailnet traffic. +) + +func TestMain(m *testing.M) { + code, err := runTests(m) + if err != nil { + log.Printf("Error: %v", err) + os.Exit(1) + } + os.Exit(code) +} + +func runTests(m *testing.M) (int, error) { + secret := os.Getenv("TS_API_CLIENT_SECRET") + if secret != "" { + secretParts := strings.Split(secret, "-") + if len(secretParts) != 4 { + return 0, errors.New("TS_API_CLIENT_SECRET is not valid") + } + ctx := context.Background() + credentials := clientcredentials.Config{ + ClientID: secretParts[2], + ClientSecret: secret, + TokenURL: "https://login.tailscale.com/api/v2/oauth/token", + Scopes: []string{"auth_keys"}, + } + apiClient = tailscale.NewClient("-", nil) + apiClient.HTTPClient = credentials.Client(ctx) + + caps := tailscale.KeyCapabilities{ + Devices: tailscale.KeyDeviceCapabilities{ + Create: tailscale.KeyDeviceCreateCapabilities{ + Reusable: false, + Preauthorized: true, + Ephemeral: true, + Tags: []string{"tag:k8s"}, + }, + }, + } + + authKey, authKeyMeta, err := apiClient.CreateKeyWithExpiry(ctx, caps, 10*time.Minute) + if err != nil { + return 0, err + } + defer apiClient.DeleteKey(context.Background(), authKeyMeta.ID) + + tailnetClient = &tsnet.Server{ + Hostname: "test-proxy", + Ephemeral: true, + Store: &mem.Store{}, + AuthKey: authKey, + } + _, err = tailnetClient.Up(ctx) + if err != nil { + return 0, err + } + defer tailnetClient.Close() + } + + return m.Run(), nil +} + +func objectMeta(namespace, name string) metav1.ObjectMeta { + return metav1.ObjectMeta{ + Namespace: namespace, + Name: name, + } +} + +func createAndCleanup(t *testing.T, cl client.Client, obj client.Object) { + t.Helper() + + // Try to create the object first + err := cl.Create(t.Context(), obj) + if err != nil { + if apierrors.IsAlreadyExists(err) { + if updateErr := cl.Update(t.Context(), obj); updateErr != nil { + t.Fatal(updateErr) + } + } else { + t.Fatal(err) + } + } + + t.Cleanup(func() { + // Use context.Background() for cleanup, as t.Context() is cancelled + // just before cleanup functions are called. + if err := cl.Delete(context.Background(), obj); err != nil { + t.Errorf("error cleaning up %s %s/%s: %s", obj.GetObjectKind().GroupVersionKind(), obj.GetNamespace(), obj.GetName(), err) + } + }) +} + +func get(ctx context.Context, cl client.Client, obj client.Object) error { + return cl.Get(ctx, client.ObjectKeyFromObject(obj), obj) +} diff --git a/cmd/k8s-operator/e2e/proxy_test.go b/cmd/k8s-operator/e2e/proxy_test.go new file mode 100644 index 000000000..b3010f97e --- /dev/null +++ b/cmd/k8s-operator/e2e/proxy_test.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "encoding/json" + "fmt" + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/client-go/rest" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/config" + "tailscale.com/ipn" + "tailscale.com/tstest" +) + +// See [TestMain] for test requirements. +func TestProxy(t *testing.T) { + if apiClient == nil { + t.Skip("TestIngress requires TS_API_CLIENT_SECRET set") + } + + cfg := config.GetConfigOrDie() + cl, err := client.New(cfg, client.Options{}) + if err != nil { + t.Fatal(err) + } + + // Create role and role binding to allow a group we'll impersonate to do stuff. + createAndCleanup(t, cl, &rbacv1.Role{ + ObjectMeta: objectMeta("tailscale", "read-secrets"), + Rules: []rbacv1.PolicyRule{{ + APIGroups: []string{""}, + Verbs: []string{"get"}, + Resources: []string{"secrets"}, + }}, + }) + createAndCleanup(t, cl, &rbacv1.RoleBinding{ + ObjectMeta: objectMeta("tailscale", "read-secrets"), + Subjects: []rbacv1.Subject{{ + Kind: "Group", + Name: "ts:e2e-test-proxy", + }}, + RoleRef: rbacv1.RoleRef{ + Kind: "Role", + Name: "read-secrets", + }, + }) + + // Get operator host name from kube secret. + operatorSecret := corev1.Secret{ + ObjectMeta: objectMeta("tailscale", "operator"), + } + if err := get(t.Context(), cl, &operatorSecret); err != nil { + t.Fatal(err) + } + + // Join tailnet as a client of the API server proxy. + proxyCfg := &rest.Config{ + Host: fmt.Sprintf("https://%s:443", hostNameFromOperatorSecret(t, operatorSecret)), + Dial: tailnetClient.Dial, + } + proxyCl, err := client.New(proxyCfg, client.Options{}) + if err != nil { + t.Fatal(err) + } + + // Expect success. + allowedSecret := corev1.Secret{ + ObjectMeta: objectMeta("tailscale", "operator"), + } + // Wait for up to a minute the first time we use the proxy, to give it time + // to provision the TLS certs. + if err := tstest.WaitFor(time.Minute, func() error { + return get(t.Context(), proxyCl, &allowedSecret) + }); err != nil { + t.Fatal(err) + } + + // Expect forbidden. + forbiddenSecret := corev1.Secret{ + ObjectMeta: objectMeta("default", "operator"), + } + if err := get(t.Context(), proxyCl, &forbiddenSecret); err == nil || !apierrors.IsForbidden(err) { + t.Fatalf("expected forbidden error fetching secret from default namespace: %s", err) + } +} + +func hostNameFromOperatorSecret(t *testing.T, s corev1.Secret) string { + t.Helper() + prefsBytes, ok := s.Data[string(s.Data["_current-profile"])] + if !ok { + t.Fatalf("no state in operator Secret data: %#v", s.Data) + } + + prefs := ipn.Prefs{} + if err := json.Unmarshal(prefsBytes, &prefs); err != nil { + t.Fatal(err) + } + + if prefs.Persist == nil { + t.Fatalf("no hostname in operator Secret data: %#v", s.Data) + } + return prefs.Persist.UserProfile.LoginName +} diff --git a/cmd/k8s-operator/egress-eps.go b/cmd/k8s-operator/egress-eps.go index 510d58783..88da99353 100644 --- a/cmd/k8s-operator/egress-eps.go +++ b/cmd/k8s-operator/egress-eps.go @@ -9,6 +9,7 @@ import ( "context" "encoding/json" "fmt" + "net/netip" "reflect" "strings" @@ -19,7 +20,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" - tsoperator "tailscale.com/k8s-operator" "tailscale.com/kube/egressservices" "tailscale.com/types/ptr" ) @@ -36,21 +36,21 @@ type egressEpsReconciler struct { // It compares tailnet service state stored in egress proxy state Secrets by containerboot with the desired // configuration stored in proxy-cfg ConfigMap to determine if the endpoint is ready. func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { - l := er.logger.With("Service", req.NamespacedName) - l.Debugf("starting reconcile") - defer l.Debugf("reconcile finished") + lg := er.logger.With("Service", req.NamespacedName) + lg.Debugf("starting reconcile") + defer lg.Debugf("reconcile finished") eps := new(discoveryv1.EndpointSlice) err = er.Get(ctx, req.NamespacedName, eps) if apierrors.IsNotFound(err) { - l.Debugf("EndpointSlice not found") + lg.Debugf("EndpointSlice not found") return reconcile.Result{}, nil } if err != nil { return reconcile.Result{}, fmt.Errorf("failed to get EndpointSlice: %w", err) } if !eps.DeletionTimestamp.IsZero() { - l.Debugf("EnpointSlice is being deleted") + lg.Debugf("EnpointSlice is being deleted") return res, nil } @@ -58,61 +58,67 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ // resources are set up for this tailnet service. svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: eps.Labels[labelExternalSvcName], - Namespace: eps.Labels[labelExternalSvcNamespace], + Name: eps.Labels[LabelParentName], + Namespace: eps.Labels[LabelParentNamespace], }, } err = er.Get(ctx, client.ObjectKeyFromObject(svc), svc) if apierrors.IsNotFound(err) { - l.Infof("ExternalName Service %s/%s not found, perhaps it was deleted", svc.Namespace, svc.Name) + lg.Infof("ExternalName Service %s/%s not found, perhaps it was deleted", svc.Namespace, svc.Name) return res, nil } if err != nil { return res, fmt.Errorf("error retrieving ExternalName Service: %w", err) } - if !tsoperator.EgressServiceIsValidAndConfigured(svc) { - l.Infof("Cluster resources for ExternalName Service %s/%s are not yet configured", svc.Namespace, svc.Name) - return res, nil - } // TODO(irbekrm): currently this reconcile loop runs all the checks every time it's triggered, which is // wasteful. Once we have a Ready condition for ExternalName Services for ProxyGroup, use the condition to // determine if a reconcile is needed. oldEps := eps.DeepCopy() - proxyGroupName := eps.Labels[labelProxyGroup] tailnetSvc := tailnetSvcName(svc) - l = l.With("tailnet-service-name", tailnetSvc) + lg = lg.With("tailnet-service-name", tailnetSvc) // Retrieve the desired tailnet service configuration from the ConfigMap. + proxyGroupName := eps.Labels[labelProxyGroup] _, cfgs, err := egressSvcsConfigs(ctx, er.Client, proxyGroupName, er.tsNamespace) if err != nil { return res, fmt.Errorf("error retrieving tailnet services configuration: %w", err) } + if cfgs == nil { + // TODO(irbekrm): this path would be hit if egress service was once exposed on a ProxyGroup that later + // got deleted. Probably the EndpointSlices then need to be deleted too- need to rethink this flow. + lg.Debugf("No egress config found, likely because ProxyGroup has not been created") + return res, nil + } cfg, ok := (*cfgs)[tailnetSvc] if !ok { - l.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc) + lg.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc) return res, nil } // Check which Pods in ProxyGroup are ready to route traffic to this // egress service. podList := &corev1.PodList{} - if err := er.List(ctx, podList, client.MatchingLabels(map[string]string{labelProxyGroup: proxyGroupName})); err != nil { + if err := er.List(ctx, podList, client.MatchingLabels(pgLabels(proxyGroupName, nil))); err != nil { return res, fmt.Errorf("error listing Pods for ProxyGroup %s: %w", proxyGroupName, err) } newEndpoints := make([]discoveryv1.Endpoint, 0) for _, pod := range podList.Items { - ready, err := er.podIsReadyToRouteTraffic(ctx, pod, &cfg, tailnetSvc, l) + ready, err := er.podIsReadyToRouteTraffic(ctx, pod, &cfg, tailnetSvc, lg) if err != nil { return res, fmt.Errorf("error verifying if Pod is ready to route traffic: %w", err) } if !ready { continue // maybe next time } + podIP, err := podIPv4(&pod) // we currently only support IPv4 + if err != nil { + return res, fmt.Errorf("error determining IPv4 address for Pod: %w", err) + } newEndpoints = append(newEndpoints, discoveryv1.Endpoint{ Hostname: (*string)(&pod.UID), - Addresses: []string{pod.Status.PodIP}, + Addresses: []string{podIP}, Conditions: discoveryv1.EndpointConditions{ Ready: ptr.To(true), Serving: ptr.To(true), @@ -124,7 +130,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ // run a cleanup for deleted Pods etc. eps.Endpoints = newEndpoints if !reflect.DeepEqual(eps, oldEps) { - l.Infof("Updating EndpointSlice to ensure traffic is routed to ready proxy Pods") + lg.Infof("Updating EndpointSlice to ensure traffic is routed to ready proxy Pods") if err := er.Update(ctx, eps); err != nil { return res, fmt.Errorf("error updating EndpointSlice: %w", err) } @@ -132,26 +138,46 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ return res, nil } +func podIPv4(pod *corev1.Pod) (string, error) { + for _, ip := range pod.Status.PodIPs { + parsed, err := netip.ParseAddr(ip.IP) + if err != nil { + return "", fmt.Errorf("error parsing IP address %s: %w", ip, err) + } + if parsed.Is4() { + return parsed.String(), nil + } + } + return "", nil +} + // podIsReadyToRouteTraffic returns true if it appears that the proxy Pod has configured firewall rules to be able to // route traffic to the given tailnet service. It retrieves the proxy's state Secret and compares the tailnet service // status written there to the desired service configuration. -func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod corev1.Pod, cfg *egressservices.Config, tailnetSvcName string, l *zap.SugaredLogger) (bool, error) { - l = l.With("proxy_pod", pod.Name) - l.Debugf("checking whether proxy is ready to route to egress service") +func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod corev1.Pod, cfg *egressservices.Config, tailnetSvcName string, lg *zap.SugaredLogger) (bool, error) { + lg = lg.With("proxy_pod", pod.Name) + lg.Debugf("checking whether proxy is ready to route to egress service") if !pod.DeletionTimestamp.IsZero() { - l.Debugf("proxy Pod is being deleted, ignore") + lg.Debugf("proxy Pod is being deleted, ignore") + return false, nil + } + podIP, err := podIPv4(&pod) + if err != nil { + return false, fmt.Errorf("error determining Pod IP address: %v", err) + } + if podIP == "" { + lg.Infof("[unexpected] Pod does not have an IPv4 address, and IPv6 is not currently supported") return false, nil } - podIP := pod.Status.PodIP stateS := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: pod.Name, Namespace: pod.Namespace, }, } - err := er.Get(ctx, client.ObjectKeyFromObject(stateS), stateS) + err = er.Get(ctx, client.ObjectKeyFromObject(stateS), stateS) if apierrors.IsNotFound(err) { - l.Debugf("proxy does not have a state Secret, waiting...") + lg.Debugf("proxy does not have a state Secret, waiting...") return false, nil } if err != nil { @@ -159,30 +185,30 @@ func (er *egressEpsReconciler) podIsReadyToRouteTraffic(ctx context.Context, pod } svcStatusBS := stateS.Data[egressservices.KeyEgressServices] if len(svcStatusBS) == 0 { - l.Debugf("proxy's state Secret does not contain egress services status, waiting...") + lg.Debugf("proxy's state Secret does not contain egress services status, waiting...") return false, nil } svcStatus := &egressservices.Status{} if err := json.Unmarshal(svcStatusBS, svcStatus); err != nil { return false, fmt.Errorf("error unmarshalling egress service status: %w", err) } - if !strings.EqualFold(podIP, svcStatus.PodIP) { - l.Infof("proxy's egress service status is for Pod IP %s, current proxy's Pod IP %s, waiting for the proxy to reconfigure...", svcStatus.PodIP, podIP) + if !strings.EqualFold(podIP, svcStatus.PodIPv4) { + lg.Infof("proxy's egress service status is for Pod IP %s, current proxy's Pod IP %s, waiting for the proxy to reconfigure...", svcStatus.PodIPv4, podIP) return false, nil } st, ok := (*svcStatus).Services[tailnetSvcName] if !ok { - l.Infof("proxy's state Secret does not have egress service status, waiting...") + lg.Infof("proxy's state Secret does not have egress service status, waiting...") return false, nil } if !reflect.DeepEqual(cfg.TailnetTarget, st.TailnetTarget) { - l.Infof("proxy has configured egress service for tailnet target %v, current target is %v, waiting for proxy to reconfigure...", st.TailnetTarget, cfg.TailnetTarget) + lg.Infof("proxy has configured egress service for tailnet target %v, current target is %v, waiting for proxy to reconfigure...", st.TailnetTarget, cfg.TailnetTarget) return false, nil } if !reflect.DeepEqual(cfg.Ports, st.Ports) { - l.Debugf("proxy has configured egress service for ports %#+v, wants ports %#+v, waiting for proxy to reconfigure", st.Ports, cfg.Ports) + lg.Debugf("proxy has configured egress service for ports %#+v, wants ports %#+v, waiting for proxy to reconfigure", st.Ports, cfg.Ports) return false, nil } - l.Debugf("proxy is ready to route traffic to egress service") + lg.Debugf("proxy is ready to route traffic to egress service") return true, nil } diff --git a/cmd/k8s-operator/egress-eps_test.go b/cmd/k8s-operator/egress-eps_test.go index a2e95e5d3..bd80112ae 100644 --- a/cmd/k8s-operator/egress-eps_test.go +++ b/cmd/k8s-operator/egress-eps_test.go @@ -20,6 +20,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubetypes" "tailscale.com/tstest" "tailscale.com/util/mak" ) @@ -75,7 +76,11 @@ func TestTailscaleEgressEndpointSlices(t *testing.T) { ObjectMeta: metav1.ObjectMeta{ Name: "foo", Namespace: "operator-ns", - Labels: map[string]string{labelExternalSvcName: "test", labelExternalSvcNamespace: "default", labelProxyGroup: "foo"}, + Labels: map[string]string{ + LabelParentName: "test", + LabelParentNamespace: "default", + labelSvcType: typeEgress, + labelProxyGroup: "foo"}, }, AddressType: discoveryv1.AddressTypeIPv4, } @@ -94,13 +99,13 @@ func TestTailscaleEgressEndpointSlices(t *testing.T) { t.Run("pods_are_ready_to_route_traffic", func(t *testing.T) { pod, stateS := podAndSecretForProxyGroup("foo") - stBs := serviceStatusForPodIP(t, svc, pod.Status.PodIP, port) + stBs := serviceStatusForPodIP(t, svc, pod.Status.PodIPs[0].IP, port) mustUpdate(t, fc, "operator-ns", stateS.Name, func(s *corev1.Secret) { mak.Set(&s.Data, egressservices.KeyEgressServices, stBs) }) expectReconciled(t, er, "operator-ns", "foo") eps.Endpoints = append(eps.Endpoints, discoveryv1.Endpoint{ - Addresses: []string{pod.Status.PodIP}, + Addresses: []string{"10.0.0.1"}, Hostname: pointer.To("foo"), Conditions: discoveryv1.EndpointConditions{ Serving: pointer.ToBool(true), @@ -108,7 +113,17 @@ func TestTailscaleEgressEndpointSlices(t *testing.T) { Terminating: pointer.ToBool(false), }, }) - expectEqual(t, fc, eps, nil) + expectEqual(t, fc, eps) + }) + t.Run("status_does_not_match_pod_ip", func(t *testing.T) { + _, stateS := podAndSecretForProxyGroup("foo") // replica Pod has IP 10.0.0.1 + stBs := serviceStatusForPodIP(t, svc, "10.0.0.2", port) // status is for a Pod with IP 10.0.0.2 + mustUpdate(t, fc, "operator-ns", stateS.Name, func(s *corev1.Secret) { + mak.Set(&s.Data, egressservices.KeyEgressServices, stBs) + }) + expectReconciled(t, er, "operator-ns", "foo") + eps.Endpoints = []discoveryv1.Endpoint{} + expectEqual(t, fc, eps) }) } @@ -135,7 +150,7 @@ func configMapForSvc(t *testing.T, svc *corev1.Service, p uint16) *corev1.Config } cm := &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ - Name: fmt.Sprintf(egressSvcsCMNameTemplate, svc.Annotations[AnnotationProxyGroup]), + Name: pgEgressCMName(svc.Annotations[AnnotationProxyGroup]), Namespace: "operator-ns", }, BinaryData: map[string][]byte{egressservices.KeyEgressServices: bs}, @@ -158,7 +173,7 @@ func serviceStatusForPodIP(t *testing.T, svc *corev1.Service, ip string, p uint1 } svcName := tailnetSvcName(svc) st := egressservices.Status{ - PodIP: ip, + PodIPv4: ip, Services: map[string]*egressservices.ServiceStatus{svcName: &svcSt}, } bs, err := json.Marshal(st) @@ -173,18 +188,20 @@ func podAndSecretForProxyGroup(pg string) (*corev1.Pod, *corev1.Secret) { ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("%s-0", pg), Namespace: "operator-ns", - Labels: map[string]string{labelProxyGroup: pg}, + Labels: pgLabels(pg, nil), UID: "foo", }, Status: corev1.PodStatus{ - PodIP: "10.0.0.1", + PodIPs: []corev1.PodIP{ + {IP: "10.0.0.1"}, + }, }, } s := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("%s-0", pg), Namespace: "operator-ns", - Labels: map[string]string{labelProxyGroup: pg}, + Labels: pgSecretLabels(pg, kubetypes.LabelSecretTypeState), }, } return p, s diff --git a/cmd/k8s-operator/egress-pod-readiness.go b/cmd/k8s-operator/egress-pod-readiness.go new file mode 100644 index 000000000..a732e0861 --- /dev/null +++ b/cmd/k8s-operator/egress-pod-readiness.go @@ -0,0 +1,274 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strings" + "sync/atomic" + "time" + + "go.uber.org/zap" + xslices "golang.org/x/exp/slices" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tstime" + "tailscale.com/util/backoff" + "tailscale.com/util/httpm" +) + +const tsEgressReadinessGate = "tailscale.com/egress-services" + +// egressPodsReconciler is responsible for setting tailscale.com/egress-services condition on egress ProxyGroup Pods. +// The condition is used as a readiness gate for the Pod, meaning that kubelet will not mark the Pod as ready before the +// condition is set. The ProxyGroup StatefulSet updates are rolled out in such a way that no Pod is restarted, before +// the previous Pod is marked as ready, so ensuring that the Pod does not get marked as ready when it is not yet able to +// route traffic for egress service prevents downtime during restarts caused by no available endpoints left because +// every Pod has been recreated and is not yet added to endpoints. +// https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-readiness-gate +type egressPodsReconciler struct { + client.Client + logger *zap.SugaredLogger + tsNamespace string + clock tstime.Clock + httpClient doer // http client that can be set to a mock client in tests + maxBackoff time.Duration // max backoff period between health check calls +} + +// Reconcile reconciles an egress ProxyGroup Pods on changes to those Pods and ProxyGroup EndpointSlices. It ensures +// that for each Pod who is ready to route traffic to all egress services for the ProxyGroup, the Pod has a +// tailscale.com/egress-services condition to set, so that kubelet will mark the Pod as ready. +// +// For the Pod to be ready +// to route traffic to the egress service, the kube proxy needs to have set up the Pod's IP as an endpoint for the +// ClusterIP Service corresponding to the egress service. +// +// Note that the endpoints for the ClusterIP Service are configured by the operator itself using custom +// EndpointSlices(egress-eps-reconciler), so the routing is not blocked on Pod's readiness. +// +// Each egress service has a corresponding ClusterIP Service, that exposes all user configured +// tailnet ports, as well as a health check port for the proxy. +// +// The reconciler calls the health check endpoint of each Service up to N number of times, where N is the number of +// replicas for the ProxyGroup x 3, and checks if the received response is healthy response from the Pod being reconciled. +// +// The health check response contains a header with the +// Pod's IP address- this is used to determine whether the response is received from this Pod. +// +// If the Pod does not appear to be serving the health check endpoint (pre-v1.80 proxies), the reconciler just sets the +// readiness condition for backwards compatibility reasons. +func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + lg := er.logger.With("Pod", req.NamespacedName) + lg.Debugf("starting reconcile") + defer lg.Debugf("reconcile finished") + + pod := new(corev1.Pod) + err = er.Get(ctx, req.NamespacedName, pod) + if apierrors.IsNotFound(err) { + return reconcile.Result{}, nil + } + if err != nil { + return reconcile.Result{}, fmt.Errorf("failed to get Pod: %w", err) + } + if !pod.DeletionTimestamp.IsZero() { + lg.Debugf("Pod is being deleted, do nothing") + return res, nil + } + if pod.Labels[LabelParentType] != proxyTypeProxyGroup { + lg.Infof("[unexpected] reconciler called for a Pod that is not a ProxyGroup Pod") + return res, nil + } + + // If the Pod does not have the readiness gate set, there is no need to add the readiness condition. In practice + // this will happen if the user has configured custom TS_LOCAL_ADDR_PORT, thus disabling the graceful failover. + if !slices.ContainsFunc(pod.Spec.ReadinessGates, func(r corev1.PodReadinessGate) bool { + return r.ConditionType == tsEgressReadinessGate + }) { + lg.Debug("Pod does not have egress readiness gate set, skipping") + return res, nil + } + + proxyGroupName := pod.Labels[LabelParentName] + pg := new(tsapi.ProxyGroup) + if err := er.Get(ctx, types.NamespacedName{Name: proxyGroupName}, pg); err != nil { + return res, fmt.Errorf("error getting ProxyGroup %q: %w", proxyGroupName, err) + } + if pg.Spec.Type != typeEgress { + lg.Infof("[unexpected] reconciler called for %q ProxyGroup Pod", pg.Spec.Type) + return res, nil + } + // Get all ClusterIP Services for all egress targets exposed to cluster via this ProxyGroup. + lbls := map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: proxyGroupName, + labelSvcType: typeEgress, + } + svcs := &corev1.ServiceList{} + if err := er.List(ctx, svcs, client.InNamespace(er.tsNamespace), client.MatchingLabels(lbls)); err != nil { + return res, fmt.Errorf("error listing ClusterIP Services") + } + + idx := xslices.IndexFunc(pod.Status.Conditions, func(c corev1.PodCondition) bool { + return c.Type == tsEgressReadinessGate + }) + if idx != -1 { + lg.Debugf("Pod is already ready, do nothing") + return res, nil + } + + var routesMissing atomic.Bool + errChan := make(chan error, len(svcs.Items)) + for _, svc := range svcs.Items { + s := svc + go func() { + ll := lg.With("service_name", s.Name) + d := retrieveClusterDomain(er.tsNamespace, ll) + healthCheckAddr := healthCheckForSvc(&s, d) + if healthCheckAddr == "" { + ll.Debugf("ClusterIP Service does not expose a health check endpoint, unable to verify if routing is set up") + errChan <- nil + return + } + + var routesSetup bool + bo := backoff.NewBackoff(s.Name, ll.Infof, er.maxBackoff) + for range numCalls(pgReplicas(pg)) { + if ctx.Err() != nil { + errChan <- nil + return + } + state, err := er.lookupPodRouteViaSvc(ctx, pod, healthCheckAddr, ll) + if err != nil { + errChan <- fmt.Errorf("error validating if routing has been set up for Pod: %w", err) + return + } + if state == healthy || state == cannotVerify { + routesSetup = true + break + } + if state == unreachable || state == unhealthy || state == podNotReady { + bo.BackOff(ctx, errors.New("backoff")) + } + } + if !routesSetup { + ll.Debugf("Pod is not yet configured as Service endpoint") + routesMissing.Store(true) + } + errChan <- nil + }() + } + for range len(svcs.Items) { + e := <-errChan + err = errors.Join(err, e) + } + if err != nil { + return res, fmt.Errorf("error verifying conectivity: %w", err) + } + if rm := routesMissing.Load(); rm { + lg.Info("Pod is not yet added as an endpoint for all egress targets, waiting...") + return reconcile.Result{RequeueAfter: shortRequeue}, nil + } + if err := er.setPodReady(ctx, pod, lg); err != nil { + return res, fmt.Errorf("error setting Pod as ready: %w", err) + } + return res, nil +} + +func (er *egressPodsReconciler) setPodReady(ctx context.Context, pod *corev1.Pod, lg *zap.SugaredLogger) error { + if slices.ContainsFunc(pod.Status.Conditions, func(c corev1.PodCondition) bool { + return c.Type == tsEgressReadinessGate + }) { + return nil + } + lg.Infof("Pod is ready to route traffic to all egress targets") + pod.Status.Conditions = append(pod.Status.Conditions, corev1.PodCondition{ + Type: tsEgressReadinessGate, + Status: corev1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: er.clock.Now()}, + }) + return er.Status().Update(ctx, pod) +} + +// healthCheckState is the result of a single request to an egress Service health check endpoint with a goal to hit a +// specific backend Pod. +type healthCheckState int8 + +const ( + cannotVerify healthCheckState = iota // not verifiable for this setup (i.e earlier proxy version) + unreachable // no backends or another network error + notFound // hit another backend + unhealthy // not 200 + podNotReady // Pod is not ready, i.e does not have an IP address yet + healthy // 200 +) + +// lookupPodRouteViaSvc attempts to reach a Pod using a health check endpoint served by a Service and returns the state of the health check. +func (er *egressPodsReconciler) lookupPodRouteViaSvc(ctx context.Context, pod *corev1.Pod, healthCheckAddr string, lg *zap.SugaredLogger) (healthCheckState, error) { + if !slices.ContainsFunc(pod.Spec.Containers[0].Env, func(e corev1.EnvVar) bool { + return e.Name == "TS_ENABLE_HEALTH_CHECK" && e.Value == "true" + }) { + lg.Debugf("Pod does not have health check enabled, unable to verify if it is currently routable via Service") + return cannotVerify, nil + } + wantsIP, err := podIPv4(pod) + if err != nil { + return -1, fmt.Errorf("error determining Pod's IP address: %w", err) + } + if wantsIP == "" { + return podNotReady, nil + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*3) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.GET, healthCheckAddr, nil) + if err != nil { + return -1, fmt.Errorf("error creating new HTTP request: %w", err) + } + // Do not re-use the same connection for the next request so to maximize the chance of hitting all backends equally. + req.Close = true + resp, err := er.httpClient.Do(req) + if err != nil { + // This is most likely because this is the first Pod and is not yet added to Service endoints. Other + // error types are possible, but checking for those would likely make the system too fragile. + return unreachable, nil + } + defer resp.Body.Close() + gotIP := resp.Header.Get(kubetypes.PodIPv4Header) + if gotIP == "" { + lg.Debugf("Health check does not return Pod's IP header, unable to verify if Pod is currently routable via Service") + return cannotVerify, nil + } + if !strings.EqualFold(wantsIP, gotIP) { + return notFound, nil + } + if resp.StatusCode != http.StatusOK { + return unhealthy, nil + } + return healthy, nil +} + +// numCalls return the number of times an endpoint on a ProxyGroup Service should be called till it can be safely +// assumed that, if none of the responses came back from a specific Pod then traffic for the Service is currently not +// being routed to that Pod. This assumes that traffic for the Service is routed via round robin, so +// InternalTrafficPolicy must be 'Cluster' and session affinity must be None. +func numCalls(replicas int32) int32 { + return replicas * 3 +} + +// doer is an interface for HTTP client that can be set to a mock client in tests. +type doer interface { + Do(*http.Request) (*http.Response, error) +} diff --git a/cmd/k8s-operator/egress-pod-readiness_test.go b/cmd/k8s-operator/egress-pod-readiness_test.go new file mode 100644 index 000000000..3c35d9043 --- /dev/null +++ b/cmd/k8s-operator/egress-pod-readiness_test.go @@ -0,0 +1,525 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "bytes" + "errors" + "fmt" + "io" + "log" + "net/http" + "sync" + "testing" + "time" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tstest" + "tailscale.com/types/ptr" +) + +func TestEgressPodReadiness(t *testing.T) { + // We need to pass a Pod object to WithStatusSubresource because of some quirks in how the fake client + // works. Without this code we would not be able to update Pod's status further down. + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithStatusSubresource(&corev1.Pod{}). + Build() + zl, _ := zap.NewDevelopment() + cl := tstest.NewClock(tstest.ClockOpts{}) + rec := &egressPodsReconciler{ + tsNamespace: "operator-ns", + Client: fc, + logger: zl.Sugar(), + clock: cl, + } + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: "egress", + Replicas: ptr.To(int32(3)), + }, + } + mustCreate(t, fc, pg) + podIP := "10.0.0.2" + podTemplate := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "operator-ns", + Name: "pod", + Labels: map[string]string{ + LabelParentType: "proxygroup", + LabelParentName: "dev", + }, + }, + Spec: corev1.PodSpec{ + ReadinessGates: []corev1.PodReadinessGate{{ + ConditionType: tsEgressReadinessGate, + }}, + Containers: []corev1.Container{{ + Name: "tailscale", + Env: []corev1.EnvVar{{ + Name: "TS_ENABLE_HEALTH_CHECK", + Value: "true", + }}, + }}, + }, + Status: corev1.PodStatus{ + PodIPs: []corev1.PodIP{{IP: podIP}}, + }, + } + + t.Run("no_egress_services", func(t *testing.T) { + pod := podTemplate.DeepCopy() + mustCreate(t, fc, pod) + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod) + }) + t.Run("one_svc_already_routed_to", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + resp := readyResps(podIP, 1) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resp}, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + + // A subsequent reconcile should not change the Pod. + expectReconciled(t, rec, "operator-ns", pod.Name) + expectEqual(t, fc, pod) + + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_many_backends_eventually_routed_to", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // For a 3 replica ProxyGroup the healthcheck endpoint should be called 9 times, make the 9th time only + // return with the right Pod IP. + resps := append(readyResps("10.0.0.3", 4), append(readyResps("10.0.0.4", 4), readyResps(podIP, 1)...)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resps}, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_eventually_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // For a 3 replica ProxyGroup the healthcheck endpoint should be called 9 times, make the 9th time only + // return with 200 status code. + resps := append(unreadyResps(podIP, 8), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resps}, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_never_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // For a 3 replica ProxyGroup the healthcheck endpoint should be called 9 times and Pod should be + // requeued if neither of those succeed. + resps := readyResps("10.0.0.3", 9) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resps}, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_many_backends_already_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + resps := readyResps(podIP, 1) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps, + hep3: resps, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_eventually_routable_and_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + resps := append(readyResps("10.0.0.3", 7), readyResps(podIP, 1)...) + resps2 := append(readyResps("10.0.0.3", 5), readyResps(podIP, 1)...) + resps3 := append(unreadyResps(podIP, 4), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_never_routable_and_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried 9 times and the Pod + // will be requeued if neither succeeds. + resps := readyResps("10.0.0.3", 9) + resps2 := append(readyResps("10.0.0.3", 5), readyResps("10.0.0.4", 4)...) + resps3 := unreadyResps(podIP, 9) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_one_never_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried 9 times and the Pod + // will be requeued if any one never succeeds. + resps := readyResps(podIP, 9) + resps2 := readyResps(podIP, 9) + resps3 := append(readyResps("10.0.0.3", 5), readyResps("10.0.0.4", 4)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_one_never_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried 9 times and the Pod + // will be requeued if any one never succeeds. + resps := readyResps(podIP, 9) + resps2 := unreadyResps(podIP, 9) + resps3 := readyResps(podIP, 9) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_different_ports_eventually_healthy_and_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9003) + svc2, hep2 := newSvc("svc-2", 9004) + svc3, hep3 := newSvc("svc-3", 9010) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried up to 9 times and + // marked as success as soon as one try succeeds. + resps := append(readyResps("10.0.0.3", 7), readyResps(podIP, 1)...) + resps2 := append(readyResps("10.0.0.3", 5), readyResps(podIP, 1)...) + resps3 := append(unreadyResps(podIP, 4), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + // Proxies of 1.78 and earlier did not set the Pod IP header. + t.Run("pod_does_not_return_ip_header", func(t *testing.T) { + pod := podTemplate.DeepCopy() + pod.Name = "foo-bar" + + svc, hep := newSvc("foo-bar", 9002) + mustCreateAll(t, fc, svc, pod) + // If a response does not contain Pod IP header, we assume that this is an earlier proxy version, + // readiness cannot be verified so the readiness gate is just set to true. + resps := unreadyResps("", 1) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_eventually_healthy_and_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // If a response errors, it is probably because the Pod is not yet properly running, so retry. + resps := append(erroredResps(8), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_svc_does_not_have_health_port", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + // If a Service does not have health port set, we assume that it is not possible to determine Pod's + // readiness and set it to ready. + svc, _ := newSvc("svc", -1) + mustCreateAll(t, fc, svc, pod) + rec.httpClient = nil + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("error_setting_up_healthcheck", func(t *testing.T) { + pod := podTemplate.DeepCopy() + // This is not a realistic reason for error, but we are just testing the behaviour of a healthcheck + // lookup failing. + pod.Status.PodIPs = []corev1.PodIP{{IP: "not-an-ip"}} + + svc, _ := newSvc("svc", 9002) + svc2, _ := newSvc("svc-2", 9002) + svc3, _ := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + rec.httpClient = nil + expectError(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("pod_does_not_have_an_ip_address", func(t *testing.T) { + pod := podTemplate.DeepCopy() + pod.Status.PodIPs = nil + + svc, _ := newSvc("svc", 9002) + svc2, _ := newSvc("svc-2", 9002) + svc3, _ := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + rec.httpClient = nil + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) +} + +func readyResps(ip string, num int) (resps []fakeResponse) { + for range num { + resps = append(resps, fakeResponse{statusCode: 200, podIP: ip}) + } + return resps +} + +func unreadyResps(ip string, num int) (resps []fakeResponse) { + for range num { + resps = append(resps, fakeResponse{statusCode: 503, podIP: ip}) + } + return resps +} + +func erroredResps(num int) (resps []fakeResponse) { + for range num { + resps = append(resps, fakeResponse{err: errors.New("timeout")}) + } + return resps +} + +func newSvc(name string, port int32) (*corev1.Service, string) { + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "operator-ns", + Name: name, + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: "dev", + labelSvcType: typeEgress, + }, + }, + Spec: corev1.ServiceSpec{}, + } + if port != -1 { + svc.Spec.Ports = []corev1.ServicePort{ + { + Name: tsHealthCheckPortName, + Port: port, + TargetPort: intstr.FromInt(9002), + Protocol: "TCP", + }, + } + } + return svc, fmt.Sprintf("http://%s.operator-ns.svc.cluster.local:%d/healthz", name, port) +} + +func podSetReady(pod *corev1.Pod, cl *tstest.Clock) { + pod.Status.Conditions = append(pod.Status.Conditions, corev1.PodCondition{ + Type: tsEgressReadinessGate, + Status: corev1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + }) +} + +// fakeHTTPClient is a mock HTTP client with a preset map of request URLs to list of responses. When it receives a +// request for a specific URL, it returns the preset response for that URL. It errors if an unexpected request is +// received. +type fakeHTTPClient struct { + t *testing.T + mu sync.Mutex // protects following + state map[string][]fakeResponse +} + +func (f *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + f.mu.Lock() + resps := f.state[req.URL.String()] + if len(resps) == 0 { + f.mu.Unlock() + log.Printf("\n\n\nURL %q\n\n\n", req.URL) + f.t.Fatalf("fakeHTTPClient received an unexpected request for %q", req.URL) + } + defer func() { + if len(resps) == 1 { + delete(f.state, req.URL.String()) + f.mu.Unlock() + return + } + f.state[req.URL.String()] = f.state[req.URL.String()][1:] + f.mu.Unlock() + }() + + resp := resps[0] + if resp.err != nil { + return nil, resp.err + } + r := http.Response{ + StatusCode: resp.statusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte{})), + } + r.Header.Add(kubetypes.PodIPv4Header, resp.podIP) + return &r, nil +} + +type fakeResponse struct { + err error + statusCode int + podIP string // for the Pod IP header +} diff --git a/cmd/k8s-operator/egress-services-readiness.go b/cmd/k8s-operator/egress-services-readiness.go new file mode 100644 index 000000000..80f3c7d28 --- /dev/null +++ b/cmd/k8s-operator/egress-services-readiness.go @@ -0,0 +1,180 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "errors" + "fmt" + "strings" + + "go.uber.org/zap" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/tstime" +) + +const ( + reasonReadinessCheckFailed = "ReadinessCheckFailed" + reasonClusterResourcesNotReady = "ClusterResourcesNotReady" + reasonNoProxies = "NoProxiesConfigured" + reasonNotReady = "NotReadyToRouteTraffic" + reasonReady = "ReadyToRouteTraffic" + reasonPartiallyReady = "PartiallyReadyToRouteTraffic" + msgReadyToRouteTemplate = "%d out of %d replicas are ready to route traffic" +) + +type egressSvcsReadinessReconciler struct { + client.Client + logger *zap.SugaredLogger + clock tstime.Clock + tsNamespace string +} + +// Reconcile reconciles an ExternalName Service that defines a tailnet target to be exposed on a ProxyGroup and sets the +// EgressSvcReady condition on it. The condition gets set to true if at least one of the proxies is currently ready to +// route traffic to the target. It compares proxy Pod IPs with the endpoints set on the EndpointSlice for the egress +// service to determine how many replicas are currently able to route traffic. +func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + lg := esrr.logger.With("Service", req.NamespacedName) + lg.Debugf("starting reconcile") + defer lg.Debugf("reconcile finished") + + svc := new(corev1.Service) + if err = esrr.Get(ctx, req.NamespacedName, svc); apierrors.IsNotFound(err) { + lg.Debugf("Service not found") + return res, nil + } else if err != nil { + return res, fmt.Errorf("failed to get Service: %w", err) + } + var ( + reason, msg string + st metav1.ConditionStatus = metav1.ConditionUnknown + ) + oldStatus := svc.Status.DeepCopy() + defer func() { + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, st, reason, msg, esrr.clock, lg) + if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { + err = errors.Join(err, esrr.Status().Update(ctx, svc)) + } + }() + + crl := egressSvcChildResourceLabels(svc) + eps, err := getSingleObject[discoveryv1.EndpointSlice](ctx, esrr.Client, esrr.tsNamespace, crl) + if err != nil { + err = fmt.Errorf("error getting EndpointSlice: %w", err) + reason = reasonReadinessCheckFailed + msg = err.Error() + return res, err + } + if eps == nil { + lg.Infof("EndpointSlice for Service does not yet exist, waiting...") + reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady + st = metav1.ConditionFalse + return res, nil + } + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: svc.Annotations[AnnotationProxyGroup], + }, + } + err = esrr.Get(ctx, client.ObjectKeyFromObject(pg), pg) + if apierrors.IsNotFound(err) { + lg.Infof("ProxyGroup for Service does not exist, waiting...") + reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady + st = metav1.ConditionFalse + return res, nil + } + if err != nil { + err = fmt.Errorf("error retrieving ProxyGroup: %w", err) + reason = reasonReadinessCheckFailed + msg = err.Error() + return res, err + } + if !tsoperator.ProxyGroupAvailable(pg) { + lg.Infof("ProxyGroup for Service is not ready, waiting...") + reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady + st = metav1.ConditionFalse + return res, nil + } + + replicas := pgReplicas(pg) + if replicas == 0 { + lg.Infof("ProxyGroup replicas set to 0") + reason, msg = reasonNoProxies, reasonNoProxies + st = metav1.ConditionFalse + return res, nil + } + podLabels := pgLabels(pg.Name, nil) + var readyReplicas int32 + for i := range replicas { + podLabels[appsv1.PodIndexLabel] = fmt.Sprintf("%d", i) + pod, err := getSingleObject[corev1.Pod](ctx, esrr.Client, esrr.tsNamespace, podLabels) + if err != nil { + err = fmt.Errorf("error retrieving ProxyGroup Pod: %w", err) + reason = reasonReadinessCheckFailed + msg = err.Error() + return res, err + } + if pod == nil { + lg.Warnf("[unexpected] ProxyGroup is ready, but replica %d was not found", i) + reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady + return res, nil + } + lg.Debugf("looking at Pod with IPs %v", pod.Status.PodIPs) + ready := false + for _, ep := range eps.Endpoints { + lg.Debugf("looking at endpoint with addresses %v", ep.Addresses) + if endpointReadyForPod(&ep, pod, lg) { + lg.Debugf("endpoint is ready for Pod") + ready = true + break + } + } + if ready { + readyReplicas++ + } + } + msg = fmt.Sprintf(msgReadyToRouteTemplate, readyReplicas, replicas) + if readyReplicas == 0 { + reason = reasonNotReady + st = metav1.ConditionFalse + return res, nil + } + st = metav1.ConditionTrue + if readyReplicas < replicas { + reason = reasonPartiallyReady + } else { + reason = reasonReady + } + return res, nil +} + +// endpointReadyForPod returns true if the endpoint is for the Pod's IPv4 address and is ready to serve traffic. +// Endpoint must not be nil. +func endpointReadyForPod(ep *discoveryv1.Endpoint, pod *corev1.Pod, lg *zap.SugaredLogger) bool { + podIP, err := podIPv4(pod) + if err != nil { + lg.Warnf("[unexpected] error retrieving Pod's IPv4 address: %v", err) + return false + } + // Currently we only ever set a single address on and Endpoint and nothing else is meant to modify this. + if len(ep.Addresses) != 1 { + return false + } + return strings.EqualFold(ep.Addresses[0], podIP) && + *ep.Conditions.Ready && + *ep.Conditions.Serving && + !*ep.Conditions.Terminating +} diff --git a/cmd/k8s-operator/egress-services-readiness_test.go b/cmd/k8s-operator/egress-services-readiness_test.go new file mode 100644 index 000000000..fdff4fafa --- /dev/null +++ b/cmd/k8s-operator/egress-services-readiness_test.go @@ -0,0 +1,169 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "fmt" + "testing" + + "github.com/AlekSi/pointer" + "go.uber.org/zap" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/tstest" + "tailscale.com/tstime" +) + +func TestEgressServiceReadiness(t *testing.T) { + // We need to pass a ProxyGroup object to WithStatusSubresource because of some quirks in how the fake client + // works. Without this code further down would not be able to update ProxyGroup status. + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithStatusSubresource(&tsapi.ProxyGroup{}). + Build() + zl, _ := zap.NewDevelopment() + cl := tstest.NewClock(tstest.ClockOpts{}) + rec := &egressSvcsReadinessReconciler{ + tsNamespace: "operator-ns", + Client: fc, + logger: zl.Sugar(), + clock: cl, + } + tailnetFQDN := "my-app.tailnetxyz.ts.net" + egressSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-app", + Namespace: "dev", + Annotations: map[string]string{ + AnnotationProxyGroup: "dev", + AnnotationTailnetTargetFQDN: tailnetFQDN, + }, + }, + } + fakeClusterIPSvc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "my-app", Namespace: "operator-ns"}} + labels := egressSvcEpsLabels(egressSvc, fakeClusterIPSvc) + eps := &discoveryv1.EndpointSlice{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-app", + Namespace: "operator-ns", + Labels: labels, + }, + AddressType: discoveryv1.AddressTypeIPv4, + } + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + } + mustCreate(t, fc, egressSvc) + setClusterNotReady(egressSvc, cl, zl.Sugar()) + t.Run("endpointslice_does_not_exist", func(t *testing.T) { + expectReconciled(t, rec, "dev", "my-app") + expectEqual(t, fc, egressSvc) // not ready + }) + t.Run("proxy_group_does_not_exist", func(t *testing.T) { + mustCreate(t, fc, eps) + expectReconciled(t, rec, "dev", "my-app") + expectEqual(t, fc, egressSvc) // still not ready + }) + t.Run("proxy_group_not_ready", func(t *testing.T) { + mustCreate(t, fc, pg) + expectReconciled(t, rec, "dev", "my-app") + expectEqual(t, fc, egressSvc) // still not ready + }) + t.Run("no_ready_replicas", func(t *testing.T) { + setPGReady(pg, cl, zl.Sugar()) + mustUpdateStatus(t, fc, pg.Namespace, pg.Name, func(p *tsapi.ProxyGroup) { + p.Status = pg.Status + }) + expectEqual(t, fc, pg) + for i := range pgReplicas(pg) { + p := pod(pg, i) + mustCreate(t, fc, p) + mustUpdateStatus(t, fc, p.Namespace, p.Name, func(existing *corev1.Pod) { + existing.Status.PodIPs = p.Status.PodIPs + }) + } + expectReconciled(t, rec, "dev", "my-app") + setNotReady(egressSvc, cl, zl.Sugar(), pgReplicas(pg)) + expectEqual(t, fc, egressSvc) // still not ready + }) + t.Run("one_ready_replica", func(t *testing.T) { + setEndpointForReplica(pg, 0, eps) + mustUpdate(t, fc, eps.Namespace, eps.Name, func(e *discoveryv1.EndpointSlice) { + e.Endpoints = eps.Endpoints + }) + setReady(egressSvc, cl, zl.Sugar(), pgReplicas(pg), 1) + expectReconciled(t, rec, "dev", "my-app") + expectEqual(t, fc, egressSvc) // partially ready + }) + t.Run("all_replicas_ready", func(t *testing.T) { + for i := range pgReplicas(pg) { + setEndpointForReplica(pg, i, eps) + } + mustUpdate(t, fc, eps.Namespace, eps.Name, func(e *discoveryv1.EndpointSlice) { + e.Endpoints = eps.Endpoints + }) + setReady(egressSvc, cl, zl.Sugar(), pgReplicas(pg), pgReplicas(pg)) + expectReconciled(t, rec, "dev", "my-app") + expectEqual(t, fc, egressSvc) // ready + }) +} + +func setClusterNotReady(svc *corev1.Service, cl tstime.Clock, lg *zap.SugaredLogger) { + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionFalse, reasonClusterResourcesNotReady, reasonClusterResourcesNotReady, cl, lg) +} + +func setNotReady(svc *corev1.Service, cl tstime.Clock, lg *zap.SugaredLogger, replicas int32) { + msg := fmt.Sprintf(msgReadyToRouteTemplate, 0, replicas) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionFalse, reasonNotReady, msg, cl, lg) +} + +func setReady(svc *corev1.Service, cl tstime.Clock, lg *zap.SugaredLogger, replicas, readyReplicas int32) { + reason := reasonPartiallyReady + if readyReplicas == replicas { + reason = reasonReady + } + msg := fmt.Sprintf(msgReadyToRouteTemplate, readyReplicas, replicas) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, metav1.ConditionTrue, reason, msg, cl, lg) +} + +func setPGReady(pg *tsapi.ProxyGroup, cl tstime.Clock, lg *zap.SugaredLogger) { + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, "foo", "foo", pg.Generation, cl, lg) +} + +func setEndpointForReplica(pg *tsapi.ProxyGroup, ordinal int32, eps *discoveryv1.EndpointSlice) { + p := pod(pg, ordinal) + eps.Endpoints = append(eps.Endpoints, discoveryv1.Endpoint{ + Addresses: []string{p.Status.PodIPs[0].IP}, + Conditions: discoveryv1.EndpointConditions{ + Ready: pointer.ToBool(true), + Serving: pointer.ToBool(true), + Terminating: pointer.ToBool(false), + }, + }) +} + +func pod(pg *tsapi.ProxyGroup, ordinal int32) *corev1.Pod { + labels := pgLabels(pg.Name, nil) + labels[appsv1.PodIndexLabel] = fmt.Sprintf("%d", ordinal) + ip := fmt.Sprintf("10.0.0.%d", ordinal) + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-%d", pg.Name, ordinal), + Namespace: "operator-ns", + Labels: labels, + }, + Status: corev1.PodStatus{ + PodIPs: []corev1.PodIP{{IP: ip}}, + }, + } +} diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index 1c4f70a96..05be8efed 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -46,24 +46,21 @@ const ( reasonEgressSvcCreationFailed = "EgressSvcCreationFailed" reasonProxyGroupNotReady = "ProxyGroupNotReady" - labelProxyGroup = "tailscale.com/proxy-group" - labelProxyGroupType = "tailscale.com/proxy-group-type" - labelExternalSvcName = "tailscale.com/external-service-name" - labelExternalSvcNamespace = "tailscale.com/external-service-namespace" + labelProxyGroup = "tailscale.com/proxy-group" labelSvcType = "tailscale.com/svc-type" // ingress or egress typeEgress = "egress" // maxPorts is the maximum number of ports that can be exposed on a - // container. In practice this will be ports in range [3000 - 4000). The + // container. In practice this will be ports in range [10000 - 11000). The // high range should make it easier to distinguish container ports from // the tailnet target ports for debugging purposes (i.e when reading - // netfilter rules). The limit of 10000 is somewhat arbitrary, the + // netfilter rules). The limit of 1000 is somewhat arbitrary, the // assumption is that this would not be hit in practice. - maxPorts = 10000 + maxPorts = 1000 indexEgressProxyGroup = ".metadata.annotations.egress-proxy-group" - egressSvcsCMNameTemplate = "proxy-cfg-%s" + tsHealthCheckPortName = "tailscale-health-check" ) var gaugeEgressServices = clientmetric.NewGauge(kubetypes.MetricEgressServiceCount) @@ -101,12 +98,12 @@ type egressSvcsReconciler struct { // - updates the egress service config in a ConfigMap mounted to the ProxyGroup proxies with the tailnet target and the // portmappings. func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { - l := esr.logger.With("Service", req.NamespacedName) - defer l.Info("reconcile finished") + lg := esr.logger.With("Service", req.NamespacedName) + defer lg.Info("reconcile finished") svc := new(corev1.Service) if err = esr.Get(ctx, req.NamespacedName, svc); apierrors.IsNotFound(err) { - l.Info("Service not found") + lg.Info("Service not found") return res, nil } else if err != nil { return res, fmt.Errorf("failed to get Service: %w", err) @@ -114,7 +111,7 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re // Name of the 'egress service', meaning the tailnet target. tailnetSvc := tailnetSvcName(svc) - l = l.With("tailnet-service", tailnetSvc) + lg = lg.With("tailnet-service", tailnetSvc) // Note that resources for egress Services are only cleaned up when the // Service is actually deleted (and not if, for example, user decides to @@ -122,31 +119,30 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re // assume that the egress ExternalName Services are always created for // Tailscale operator specifically. if !svc.DeletionTimestamp.IsZero() { - l.Info("Service is being deleted, ensuring resource cleanup") - return res, esr.maybeCleanup(ctx, svc, l) + lg.Info("Service is being deleted, ensuring resource cleanup") + return res, esr.maybeCleanup(ctx, svc, lg) } oldStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { err = errors.Join(err, esr.Status().Update(ctx, svc)) } }() // Validate the user-created ExternalName Service and the associated ProxyGroup. - if ok, err := esr.validateClusterResources(ctx, svc, l); err != nil { + if ok, err := esr.validateClusterResources(ctx, svc, lg); err != nil { return res, fmt.Errorf("error validating cluster resources: %w", err) } else if !ok { return res, nil } if !slices.Contains(svc.Finalizers, FinalizerName) { - l.Infof("configuring tailnet service") // logged exactly once svc.Finalizers = append(svc.Finalizers, FinalizerName) - if err := esr.Update(ctx, svc); err != nil { + if err := esr.updateSvcSpec(ctx, svc); err != nil { err := fmt.Errorf("failed to add finalizer: %w", err) - r := svcConfiguredReason(svc, false, l) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, l) + r := svcConfiguredReason(svc, false, lg) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, lg) return res, err } esr.mu.Lock() @@ -155,26 +151,33 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re esr.mu.Unlock() } - if err := esr.maybeCleanupProxyGroupConfig(ctx, svc, l); err != nil { + if err := esr.maybeCleanupProxyGroupConfig(ctx, svc, lg); err != nil { err = fmt.Errorf("cleaning up resources for previous ProxyGroup failed: %w", err) - r := svcConfiguredReason(svc, false, l) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, l) + r := svcConfiguredReason(svc, false, lg) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, lg) return res, err } - return res, esr.maybeProvision(ctx, svc, l) + if err := esr.maybeProvision(ctx, svc, lg); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + lg.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return res, nil } -func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) (err error) { - l.Debug("maybe provision") - r := svcConfiguredReason(svc, false, l) +func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1.Service, lg *zap.SugaredLogger) (err error) { + r := svcConfiguredReason(svc, false, lg) st := metav1.ConditionFalse defer func() { msg := r if st != metav1.ConditionTrue && err != nil { msg = err.Error() } - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, st, r, msg, esr.clock, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, st, r, msg, esr.clock, lg) }() crl := egressSvcChildResourceLabels(svc) @@ -186,36 +189,36 @@ func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1 if clusterIPSvc == nil { clusterIPSvc = esr.clusterIPSvcForEgress(crl) } - upToDate := svcConfigurationUpToDate(svc, l) + upToDate := svcConfigurationUpToDate(svc, lg) provisioned := true if !upToDate { - if clusterIPSvc, provisioned, err = esr.provision(ctx, svc.Annotations[AnnotationProxyGroup], svc, clusterIPSvc, l); err != nil { + if clusterIPSvc, provisioned, err = esr.provision(ctx, svc.Annotations[AnnotationProxyGroup], svc, clusterIPSvc, lg); err != nil { return err } } if !provisioned { - l.Infof("unable to provision cluster resources") + lg.Infof("unable to provision cluster resources") return nil } // Update ExternalName Service to point at the ClusterIP Service. - clusterDomain := retrieveClusterDomain(esr.tsNamespace, l) + clusterDomain := retrieveClusterDomain(esr.tsNamespace, lg) clusterIPSvcFQDN := fmt.Sprintf("%s.%s.svc.%s", clusterIPSvc.Name, clusterIPSvc.Namespace, clusterDomain) if svc.Spec.ExternalName != clusterIPSvcFQDN { - l.Infof("Configuring ExternalName Service to point to ClusterIP Service %s", clusterIPSvcFQDN) + lg.Infof("Configuring ExternalName Service to point to ClusterIP Service %s", clusterIPSvcFQDN) svc.Spec.ExternalName = clusterIPSvcFQDN - if err = esr.Update(ctx, svc); err != nil { + if err = esr.updateSvcSpec(ctx, svc); err != nil { err = fmt.Errorf("error updating ExternalName Service: %w", err) return err } } - r = svcConfiguredReason(svc, true, l) + r = svcConfiguredReason(svc, true, lg) st = metav1.ConditionTrue return nil } -func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName string, svc, clusterIPSvc *corev1.Service, l *zap.SugaredLogger) (*corev1.Service, bool, error) { - l.Infof("updating configuration...") +func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName string, svc, clusterIPSvc *corev1.Service, lg *zap.SugaredLogger) (*corev1.Service, bool, error) { + lg.Infof("updating configuration...") usedPorts, err := esr.usedPortsForPG(ctx, proxyGroupName) if err != nil { return nil, false, fmt.Errorf("error calculating used ports for ProxyGroup %s: %w", proxyGroupName, err) @@ -228,12 +231,22 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s found := false for _, wantsPM := range svc.Spec.Ports { if wantsPM.Port == pm.Port && strings.EqualFold(string(wantsPM.Protocol), string(pm.Protocol)) { + // We want to both preserve the user set port names for ease of debugging, but also + // ensure that we name all unnamed ports as the ClusterIP Service that we create will + // always have at least two ports. + // https://kubernetes.io/docs/concepts/services-networking/service/#multi-port-services + // See also https://github.com/tailscale/tailscale/issues/13406#issuecomment-2507230388 + if wantsPM.Name != "" { + clusterIPSvc.Spec.Ports[i].Name = wantsPM.Name + } else { + clusterIPSvc.Spec.Ports[i].Name = "tailscale-unnamed" + } found = true break } } if !found { - l.Debugf("portmapping %s:%d -> %s:%d is no longer required, removing", pm.Protocol, pm.TargetPort.IntVal, pm.Protocol, pm.Port) + lg.Debugf("portmapping %s:%d -> %s:%d is no longer required, removing", pm.Protocol, pm.TargetPort.IntVal, pm.Protocol, pm.Port) clusterIPSvc.Spec.Ports = slices.Delete(clusterIPSvc.Spec.Ports, i, i+1) } } @@ -242,6 +255,12 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s // ClusterIP Service produce new target port and add a portmapping to // the ClusterIP Service. for _, wantsPM := range svc.Spec.Ports { + // Because we add a healthcheck port of our own, we will always have at least two ports. That + // means that we cannot have ports with name not set. + // https://kubernetes.io/docs/concepts/services-networking/service/#multi-port-services + if wantsPM.Name == "" { + wantsPM.Name = "tailscale-unnamed" + } found := false for _, gotPM := range clusterIPSvc.Spec.Ports { if wantsPM.Port == gotPM.Port && strings.EqualFold(string(wantsPM.Protocol), string(gotPM.Protocol)) { @@ -252,13 +271,13 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s if !found { // Calculate a free port to expose on container and add // a new PortMap to the ClusterIP Service. - if usedPorts.Len() == maxPorts { + if usedPorts.Len() >= maxPorts { // TODO(irbekrm): refactor to avoid extra reconciles here. Low priority as in practice, // the limit should not be hit. return nil, false, fmt.Errorf("unable to allocate additional ports on ProxyGroup %s, %d ports already used. Create another ProxyGroup or open an issue if you believe this is unexpected.", proxyGroupName, maxPorts) } p := unusedPort(usedPorts) - l.Debugf("mapping tailnet target port %d to container port %d", wantsPM.Port, p) + lg.Debugf("mapping tailnet target port %d to container port %d", wantsPM.Port, p) usedPorts.Insert(p) clusterIPSvc.Spec.Ports = append(clusterIPSvc.Spec.Ports, corev1.ServicePort{ Name: wantsPM.Name, @@ -268,6 +287,25 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s }) } } + var healthCheckPort int32 = defaultLocalAddrPort + + for { + if !slices.ContainsFunc(svc.Spec.Ports, func(p corev1.ServicePort) bool { + return p.Port == healthCheckPort + }) { + break + } + healthCheckPort++ + if healthCheckPort > 10002 { + return nil, false, fmt.Errorf("unable to find a free port for internal health check in range [9002, 10002]") + } + } + clusterIPSvc.Spec.Ports = append(clusterIPSvc.Spec.Ports, corev1.ServicePort{ + Name: tsHealthCheckPortName, + Port: healthCheckPort, + TargetPort: intstr.FromInt(defaultLocalAddrPort), + Protocol: "TCP", + }) if !reflect.DeepEqual(clusterIPSvc, oldClusterIPSvc) { if clusterIPSvc, err = createOrUpdate(ctx, esr.Client, esr.tsNamespace, clusterIPSvc, func(svc *corev1.Service) { svc.Labels = clusterIPSvc.Labels @@ -277,11 +315,9 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s } } - crl := egressSvcChildResourceLabels(svc) + crl := egressSvcEpsLabels(svc, clusterIPSvc) // TODO(irbekrm): support IPv6, but need to investigate how kube proxy // sets up Service -> Pod routing when IPv6 is involved. - crl[discoveryv1.LabelServiceName] = clusterIPSvc.Name - crl[discoveryv1.LabelManagedBy] = "tailscale.com" eps := &discoveryv1.EndpointSlice{ ObjectMeta: metav1.ObjectMeta{ Name: fmt.Sprintf("%s-ipv4", clusterIPSvc.Name), @@ -307,14 +343,14 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, fmt.Errorf("error retrieving egress services configuration: %w", err) } if cm == nil { - l.Info("ConfigMap not yet created, waiting..") + lg.Info("ConfigMap not yet created, waiting..") return nil, false, nil } tailnetSvc := tailnetSvcName(svc) gotCfg := (*cfgs)[tailnetSvc] - wantsCfg := egressSvcCfg(svc, clusterIPSvc) + wantsCfg := egressSvcCfg(svc, clusterIPSvc, esr.tsNamespace, lg) if !reflect.DeepEqual(gotCfg, wantsCfg) { - l.Debugf("updating egress services ConfigMap %s", cm.Name) + lg.Debugf("updating egress services ConfigMap %s", cm.Name) mak.Set(cfgs, tailnetSvc, wantsCfg) bs, err := json.Marshal(cfgs) if err != nil { @@ -325,7 +361,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, fmt.Errorf("error updating egress services ConfigMap: %w", err) } } - l.Infof("egress service configuration has been updated") + lg.Infof("egress service configuration has been updated") return clusterIPSvc, true, nil } @@ -366,7 +402,7 @@ func (esr *egressSvcsReconciler) maybeCleanup(ctx context.Context, svc *corev1.S return nil } -func (esr *egressSvcsReconciler) maybeCleanupProxyGroupConfig(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) error { +func (esr *egressSvcsReconciler) maybeCleanupProxyGroupConfig(ctx context.Context, svc *corev1.Service, lg *zap.SugaredLogger) error { wantsProxyGroup := svc.Annotations[AnnotationProxyGroup] cond := tsoperator.GetServiceCondition(svc, tsapi.EgressSvcConfigured) if cond == nil { @@ -380,7 +416,7 @@ func (esr *egressSvcsReconciler) maybeCleanupProxyGroupConfig(ctx context.Contex return nil } esr.logger.Infof("egress Service configured on ProxyGroup %s, wants ProxyGroup %s, cleaning up...", ss[2], wantsProxyGroup) - if err := esr.ensureEgressSvcCfgDeleted(ctx, svc, l); err != nil { + if err := esr.ensureEgressSvcCfgDeleted(ctx, svc, lg); err != nil { return fmt.Errorf("error deleting egress service config: %w", err) } return nil @@ -416,7 +452,7 @@ func (esr *egressSvcsReconciler) usedPortsForPG(ctx context.Context, pg string) func (esr *egressSvcsReconciler) clusterIPSvcForEgress(crl map[string]string) *corev1.Service { return &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - GenerateName: svcNameBase(crl[labelExternalSvcName]), + GenerateName: svcNameBase(crl[LabelParentName]), Namespace: esr.tsNamespace, Labels: crl, }, @@ -428,24 +464,24 @@ func (esr *egressSvcsReconciler) clusterIPSvcForEgress(crl map[string]string) *c func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) error { crl := egressSvcChildResourceLabels(svc) - cmName := fmt.Sprintf(egressSvcsCMNameTemplate, crl[labelProxyGroup]) + cmName := pgEgressCMName(crl[labelProxyGroup]) cm := &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ Name: cmName, Namespace: esr.tsNamespace, }, } - l := logger.With("ConfigMap", client.ObjectKeyFromObject(cm)) - l.Debug("ensuring that egress service configuration is removed from proxy config") + lggr := logger.With("ConfigMap", client.ObjectKeyFromObject(cm)) + lggr.Debug("ensuring that egress service configuration is removed from proxy config") if err := esr.Get(ctx, client.ObjectKeyFromObject(cm), cm); apierrors.IsNotFound(err) { - l.Debugf("ConfigMap not found") + lggr.Debugf("ConfigMap not found") return nil } else if err != nil { return fmt.Errorf("error retrieving ConfigMap: %w", err) } bs := cm.BinaryData[egressservices.KeyEgressServices] if len(bs) == 0 { - l.Debugf("ConfigMap does not contain egress service configs") + lggr.Debugf("ConfigMap does not contain egress service configs") return nil } cfgs := &egressservices.Configs{} @@ -455,12 +491,12 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, tailnetSvc := tailnetSvcName(svc) _, ok := (*cfgs)[tailnetSvc] if !ok { - l.Debugf("ConfigMap does not contain egress service config, likely because it was already deleted") + lggr.Debugf("ConfigMap does not contain egress service config, likely because it was already deleted") return nil } - l.Infof("before deleting config %+#v", *cfgs) + lggr.Infof("before deleting config %+#v", *cfgs) delete(*cfgs, tailnetSvc) - l.Infof("after deleting config %+#v", *cfgs) + lggr.Infof("after deleting config %+#v", *cfgs) bs, err := json.Marshal(cfgs) if err != nil { return fmt.Errorf("error marshalling egress services configs: %w", err) @@ -469,7 +505,7 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, return esr.Update(ctx, cm) } -func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) (bool, error) { +func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, svc *corev1.Service, lg *zap.SugaredLogger) (bool, error) { proxyGroupName := svc.Annotations[AnnotationProxyGroup] pg := &tsapi.ProxyGroup{ ObjectMeta: metav1.ObjectMeta{ @@ -477,32 +513,52 @@ func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, s }, } if err := esr.Get(ctx, client.ObjectKeyFromObject(pg), pg); apierrors.IsNotFound(err) { - l.Infof("ProxyGroup %q not found, waiting...", proxyGroupName) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, l) + lg.Infof("ProxyGroup %q not found, waiting...", proxyGroupName) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, lg) + tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, nil } else if err != nil { err := fmt.Errorf("unable to retrieve ProxyGroup %s: %w", proxyGroupName, err) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, err.Error(), esr.clock, l) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, err.Error(), esr.clock, lg) + tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, err } - if !tsoperator.ProxyGroupIsReady(pg) { - l.Infof("ProxyGroup %s is not ready, waiting...", proxyGroupName) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, l) - return false, nil - } - if violations := validateEgressService(svc, pg); len(violations) > 0 { msg := fmt.Sprintf("invalid egress Service: %s", strings.Join(violations, ", ")) esr.recorder.Event(svc, corev1.EventTypeWarning, "INVALIDSERVICE", msg) - l.Info(msg) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionFalse, reasonEgressSvcInvalid, msg, esr.clock, l) + lg.Info(msg) + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionFalse, reasonEgressSvcInvalid, msg, esr.clock, lg) + tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, nil } - l.Debugf("egress service is valid") - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionTrue, reasonEgressSvcValid, reasonEgressSvcValid, esr.clock, l) + if !tsoperator.ProxyGroupAvailable(pg) { + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, lg) + tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) + } + + lg.Debugf("egress service is valid") + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionTrue, reasonEgressSvcValid, reasonEgressSvcValid, esr.clock, lg) return true, nil } +func egressSvcCfg(externalNameSvc, clusterIPSvc *corev1.Service, ns string, lg *zap.SugaredLogger) egressservices.Config { + d := retrieveClusterDomain(ns, lg) + tt := tailnetTargetFromSvc(externalNameSvc) + hep := healthCheckForSvc(clusterIPSvc, d) + cfg := egressservices.Config{ + TailnetTarget: tt, + HealthCheckEndpoint: hep, + } + for _, svcPort := range clusterIPSvc.Spec.Ports { + if svcPort.Name == tsHealthCheckPortName { + continue // exclude healthcheck from egress svcs configs + } + pm := portMap(svcPort) + mak.Set(&cfg.Ports, pm, struct{}{}) + } + return cfg +} + func validateEgressService(svc *corev1.Service, pg *tsapi.ProxyGroup) []string { violations := validateService(svc) @@ -544,13 +600,13 @@ func svcNameBase(s string) string { } } -// unusedPort returns a port in range [3000 - 4000). The caller must ensure that -// usedPorts does not contain all ports in range [3000 - 4000). +// unusedPort returns a port in range [10000 - 11000). The caller must ensure that +// usedPorts does not contain all ports in range [10000 - 11000). func unusedPort(usedPorts sets.Set[int32]) int32 { foundFreePort := false var suggestPort int32 for !foundFreePort { - suggestPort = rand.Int32N(maxPorts) + 3000 + suggestPort = rand.Int32N(maxPorts) + 10000 if !usedPorts.Has(suggestPort) { foundFreePort = true } @@ -572,19 +628,13 @@ func tailnetTargetFromSvc(svc *corev1.Service) egressservices.TailnetTarget { } } -func egressSvcCfg(externalNameSvc, clusterIPSvc *corev1.Service) egressservices.Config { - tt := tailnetTargetFromSvc(externalNameSvc) - cfg := egressservices.Config{TailnetTarget: tt} - for _, svcPort := range clusterIPSvc.Spec.Ports { - pm := portMap(svcPort) - mak.Set(&cfg.Ports, pm, struct{}{}) - } - return cfg -} - func portMap(p corev1.ServicePort) egressservices.PortMap { // TODO (irbekrm): out of bounds check? - return egressservices.PortMap{Protocol: string(p.Protocol), MatchPort: uint16(p.TargetPort.IntVal), TargetPort: uint16(p.Port)} + return egressservices.PortMap{ + Protocol: string(p.Protocol), + MatchPort: uint16(p.TargetPort.IntVal), + TargetPort: uint16(p.Port), + } } func isEgressSvcForProxyGroup(obj client.Object) bool { @@ -599,15 +649,19 @@ func isEgressSvcForProxyGroup(obj client.Object) bool { // egressSvcConfig returns a ConfigMap that contains egress services configuration for the provided ProxyGroup as well // as unmarshalled configuration from the ConfigMap. func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs *egressservices.Configs, err error) { - cmName := fmt.Sprintf(egressSvcsCMNameTemplate, proxyGroupName) + name := pgEgressCMName(proxyGroupName) cm = &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ - Name: cmName, + Name: name, Namespace: tsNamespace, }, } - if err := cl.Get(ctx, client.ObjectKeyFromObject(cm), cm); err != nil { - return nil, nil, fmt.Errorf("error retrieving egress services ConfigMap %s: %v", cmName, err) + err = cl.Get(ctx, client.ObjectKeyFromObject(cm), cm) + if apierrors.IsNotFound(err) { // ProxyGroup resources have not been created (yet) + return nil, nil, nil + } + if err != nil { + return nil, nil, fmt.Errorf("error retrieving egress services ConfigMap %s: %v", name, err) } cfgs = &egressservices.Configs{} if len(cm.BinaryData[egressservices.KeyEgressServices]) != 0 { @@ -626,15 +680,29 @@ func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, ts // should probably validate and truncate (?) the names is they are too long. func egressSvcChildResourceLabels(svc *corev1.Service) map[string]string { return map[string]string{ - LabelManaged: "true", - labelProxyGroup: svc.Annotations[AnnotationProxyGroup], - labelExternalSvcName: svc.Name, - labelExternalSvcNamespace: svc.Namespace, - labelSvcType: typeEgress, + kubetypes.LabelManaged: "true", + LabelParentType: "svc", + LabelParentName: svc.Name, + LabelParentNamespace: svc.Namespace, + labelProxyGroup: svc.Annotations[AnnotationProxyGroup], + labelSvcType: typeEgress, } } -func svcConfigurationUpToDate(svc *corev1.Service, l *zap.SugaredLogger) bool { +// egressEpsLabels returns labels to be added to an EndpointSlice created for an egress service. +func egressSvcEpsLabels(extNSvc, clusterIPSvc *corev1.Service) map[string]string { + lbels := egressSvcChildResourceLabels(extNSvc) + // Adding this label is what makes kube proxy set up rules to route traffic sent to the clusterIP Service to the + // endpoints defined on this EndpointSlice. + // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ownership + lbels[discoveryv1.LabelServiceName] = clusterIPSvc.Name + // Kubernetes recommends setting this label. + // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#management + lbels[discoveryv1.LabelManagedBy] = "tailscale.com" + return lbels +} + +func svcConfigurationUpToDate(svc *corev1.Service, lg *zap.SugaredLogger) bool { cond := tsoperator.GetServiceCondition(svc, tsapi.EgressSvcConfigured) if cond == nil { return false @@ -642,21 +710,21 @@ func svcConfigurationUpToDate(svc *corev1.Service, l *zap.SugaredLogger) bool { if cond.Status != metav1.ConditionTrue { return false } - wantsReadyReason := svcConfiguredReason(svc, true, l) + wantsReadyReason := svcConfiguredReason(svc, true, lg) return strings.EqualFold(wantsReadyReason, cond.Reason) } -func cfgHash(c cfg, l *zap.SugaredLogger) string { +func cfgHash(c cfg, lg *zap.SugaredLogger) string { bs, err := json.Marshal(c) if err != nil { // Don't use l.Error as that messes up component logs with, in this case, unnecessary stack trace. - l.Infof("error marhsalling Config: %v", err) + lg.Infof("error marhsalling Config: %v", err) return "" } h := sha256.New() if _, err := h.Write(bs); err != nil { // Don't use l.Error as that messes up component logs with, in this case, unnecessary stack trace. - l.Infof("error producing Config hash: %v", err) + lg.Infof("error producing Config hash: %v", err) return "" } return fmt.Sprintf("%x", h.Sum(nil)) @@ -668,7 +736,7 @@ type cfg struct { ProxyGroup string `json:"proxyGroup"` } -func svcConfiguredReason(svc *corev1.Service, configured bool, l *zap.SugaredLogger) string { +func svcConfiguredReason(svc *corev1.Service, configured bool, lg *zap.SugaredLogger) string { var r string if configured { r = "ConfiguredFor:" @@ -682,7 +750,7 @@ func svcConfiguredReason(svc *corev1.Service, configured bool, l *zap.SugaredLog TailnetTarget: tt, ProxyGroup: svc.Annotations[AnnotationProxyGroup], } - r += fmt.Sprintf(":Config:%s", cfgHash(s, l)) + r += fmt.Sprintf(":Config:%s", cfgHash(s, lg)) return r } @@ -704,3 +772,27 @@ func epsPortsFromSvc(svc *corev1.Service) (ep []discoveryv1.EndpointPort) { } return ep } + +// updateSvcSpec ensures that the given Service's spec is updated in cluster, but the local Service object still retains +// the not-yet-applied status. +// TODO(irbekrm): once we do SSA for these patch updates, this will no longer be needed. +func (esr *egressSvcsReconciler) updateSvcSpec(ctx context.Context, svc *corev1.Service) error { + st := svc.Status.DeepCopy() + err := esr.Update(ctx, svc) + svc.Status = *st + return err +} + +// healthCheckForSvc return the URL of the containerboot's health check endpoint served by this Service or empty string. +func healthCheckForSvc(svc *corev1.Service, clusterDomain string) string { + // This version of the operator always sets health check port on the egress Services. However, it is possible + // that this reconcile loops runs during a proxy upgrade from a version that did not set the health check port + // and parses a Service that does not have the port set yet. + i := slices.IndexFunc(svc.Spec.Ports, func(port corev1.ServicePort) bool { + return port.Name == tsHealthCheckPortName + }) + if i == -1 { + return "" + } + return fmt.Sprintf("http://%s.%s.svc.%s:%d/healthz", svc.Name, svc.Namespace, clusterDomain, svc.Spec.Ports[i].Port) +} diff --git a/cmd/k8s-operator/egress-services_test.go b/cmd/k8s-operator/egress-services_test.go index 13fa31784..202804d30 100644 --- a/cmd/k8s-operator/egress-services_test.go +++ b/cmd/k8s-operator/egress-services_test.go @@ -18,6 +18,7 @@ import ( discoveryv1 "k8s.io/api/discovery/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" @@ -34,13 +35,13 @@ func TestTailscaleEgressServices(t *testing.T) { UID: types.UID("1234-UID"), }, Spec: tsapi.ProxyGroupSpec{ - Replicas: pointer.To(3), + Replicas: pointer.To[int32](3), Type: tsapi.ProxyGroupTypeEgress, }, } cm := &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ - Name: fmt.Sprintf(egressSvcsCMNameTemplate, "foo"), + Name: pgEgressCMName("foo"), Namespace: "operator-ns", }, } @@ -78,55 +79,41 @@ func TestTailscaleEgressServices(t *testing.T) { Selector: nil, Ports: []corev1.ServicePort{ { - Name: "http", Protocol: "TCP", Port: 80, }, - { - Name: "https", - Protocol: "TCP", - Port: 443, - }, }, }, } - t.Run("proxy_group_not_ready", func(t *testing.T) { + t.Run("service_one_unnamed_port", func(t *testing.T) { mustCreate(t, fc, svc) expectReconciled(t, esr, "default", "test") - // Service should have EgressSvcValid condition set to Unknown. - svc.Status.Conditions = []metav1.Condition{condition(tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, clock)} - expectEqual(t, fc, svc, nil) + validateReadyService(t, fc, esr, svc, clock, zl, cm) }) - - t.Run("proxy_group_ready", func(t *testing.T) { - mustUpdateStatus(t, fc, "", "foo", func(pg *tsapi.ProxyGroup) { - pg.Status.Conditions = []metav1.Condition{ - condition(tsapi.ProxyGroupReady, metav1.ConditionTrue, "", "", clock), - } + t.Run("service_add_two_named_ports", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80, Name: "http"}, {Protocol: "TCP", Port: 443, Name: "https"}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports }) - // Quirks of the fake client. - mustUpdateStatus(t, fc, "default", "test", func(svc *corev1.Service) { - svc.Status.Conditions = []metav1.Condition{} + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_add_udp_port", func(t *testing.T) { + svc.Spec.Ports = append(svc.Spec.Ports, corev1.ServicePort{Port: 53, Protocol: "UDP", Name: "dns"}) + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports }) expectReconciled(t, esr, "default", "test") - // Verify that a ClusterIP Service has been created. - name := findGenNameForEgressSvcResources(t, fc, svc) - expectEqual(t, fc, clusterIPSvc(name, svc), removeTargetPortsFromSvc) - clusterSvc := mustGetClusterIPSvc(t, fc, name) - // Verify that an EndpointSlice has been created. - expectEqual(t, fc, endpointSlice(name, svc, clusterSvc), nil) - // Verify that ConfigMap contains configuration for the new egress service. - mustHaveConfigForSvc(t, fc, svc, clusterSvc, cm) - r := svcConfiguredReason(svc, true, zl.Sugar()) - // Verify that the user-created ExternalName Service has Configured set to true and ExternalName pointing to the - // CluterIP Service. - svc.Status.Conditions = []metav1.Condition{ - condition(tsapi.EgressSvcConfigured, metav1.ConditionTrue, r, r, clock), - } - svc.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} - svc.Spec.ExternalName = fmt.Sprintf("%s.operator-ns.svc.cluster.local", name) - expectEqual(t, fc, svc, nil) + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_change_protocol", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80, Name: "http"}, {Protocol: "TCP", Port: 443, Name: "https"}, {Port: 53, Protocol: "TCP", Name: "tcp_dns"}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports + }) + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) }) t.Run("delete_external_name_service", func(t *testing.T) { @@ -143,6 +130,29 @@ func TestTailscaleEgressServices(t *testing.T) { }) } +func validateReadyService(t *testing.T, fc client.WithWatch, esr *egressSvcsReconciler, svc *corev1.Service, clock *tstest.Clock, zl *zap.Logger, cm *corev1.ConfigMap) { + expectReconciled(t, esr, "default", "test") + // Verify that a ClusterIP Service has been created. + name := findGenNameForEgressSvcResources(t, fc, svc) + expectEqual(t, fc, clusterIPSvc(name, svc), removeTargetPortsFromSvc) + clusterSvc := mustGetClusterIPSvc(t, fc, name) + // Verify that an EndpointSlice has been created. + expectEqual(t, fc, endpointSlice(name, svc, clusterSvc)) + // Verify that ConfigMap contains configuration for the new egress service. + mustHaveConfigForSvc(t, fc, svc, clusterSvc, cm, zl) + r := svcConfiguredReason(svc, true, zl.Sugar()) + // Verify that the user-created ExternalName Service has Configured set to true and ExternalName pointing to the + // CluterIP Service. + svc.Status.Conditions = []metav1.Condition{ + condition(tsapi.EgressSvcValid, metav1.ConditionTrue, "EgressSvcValid", "EgressSvcValid", clock), + condition(tsapi.EgressSvcConfigured, metav1.ConditionTrue, r, r, clock), + } + svc.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} + svc.Spec.ExternalName = fmt.Sprintf("%s.operator-ns.svc.cluster.local", name) + expectEqual(t, fc, svc) + +} + func condition(typ tsapi.ConditionType, st metav1.ConditionStatus, r, msg string, clock tstime.Clock) metav1.Condition { return metav1.Condition{ Type: string(typ), @@ -168,6 +178,23 @@ func findGenNameForEgressSvcResources(t *testing.T, client client.Client, svc *c func clusterIPSvc(name string, extNSvc *corev1.Service) *corev1.Service { labels := egressSvcChildResourceLabels(extNSvc) + ports := make([]corev1.ServicePort, len(extNSvc.Spec.Ports)) + for i, port := range extNSvc.Spec.Ports { + ports[i] = corev1.ServicePort{ // Copy the port to avoid modifying the original. + Name: port.Name, + Port: port.Port, + Protocol: port.Protocol, + } + if port.Name == "" { + ports[i].Name = "tailscale-unnamed" + } + } + ports = append(ports, corev1.ServicePort{ + Name: "tailscale-health-check", + Port: 9002, + TargetPort: intstr.FromInt(9002), + Protocol: "TCP", + }) return &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -177,7 +204,7 @@ func clusterIPSvc(name string, extNSvc *corev1.Service) *corev1.Service { }, Spec: corev1.ServiceSpec{ Type: corev1.ServiceTypeClusterIP, - Ports: extNSvc.Spec.Ports, + Ports: ports, }, } } @@ -222,9 +249,9 @@ func portsForEndpointSlice(svc *corev1.Service) []discoveryv1.EndpointPort { return ports } -func mustHaveConfigForSvc(t *testing.T, cl client.Client, extNSvc, clusterIPSvc *corev1.Service, cm *corev1.ConfigMap) { +func mustHaveConfigForSvc(t *testing.T, cl client.Client, extNSvc, clusterIPSvc *corev1.Service, cm *corev1.ConfigMap, lg *zap.Logger) { t.Helper() - wantsCfg := egressSvcCfg(extNSvc, clusterIPSvc) + wantsCfg := egressSvcCfg(extNSvc, clusterIPSvc, clusterIPSvc.Namespace, lg.Sugar()) if err := cl.Get(context.Background(), client.ObjectKeyFromObject(cm), cm); err != nil { t.Fatalf("Error retrieving ConfigMap: %v", err) } diff --git a/cmd/k8s-operator/generate/main.go b/cmd/k8s-operator/generate/main.go index 25435a47c..08bdc350d 100644 --- a/cmd/k8s-operator/generate/main.go +++ b/cmd/k8s-operator/generate/main.go @@ -41,11 +41,16 @@ func main() { if len(os.Args) < 2 { log.Fatalf("usage ./generate [staticmanifests|helmcrd]") } - repoRoot := "../../" + gitOut, err := exec.Command("git", "rev-parse", "--show-toplevel").CombinedOutput() + if err != nil { + log.Fatalf("error determining git root: %v: %s", err, gitOut) + } + + repoRoot := strings.TrimSpace(string(gitOut)) switch os.Args[1] { case "helmcrd": // insert CRDs to Helm templates behind a installCRDs=true conditional check log.Print("Adding CRDs to Helm templates") - if err := generate("./"); err != nil { + if err := generate(repoRoot); err != nil { log.Fatalf("error adding CRDs to Helm templates: %v", err) } return @@ -64,7 +69,7 @@ func main() { }() log.Print("Templating Helm chart contents") helmTmplCmd := exec.Command("./tool/helm", "template", "operator", "./cmd/k8s-operator/deploy/chart", - "--namespace=tailscale") + "--namespace=tailscale", "--set=oauth.clientSecret=''") helmTmplCmd.Dir = repoRoot var out bytes.Buffer helmTmplCmd.Stdout = &out @@ -139,7 +144,7 @@ func generate(baseDir string) error { if _, err := file.Write([]byte(helmConditionalEnd)); err != nil { return fmt.Errorf("error writing helm if-statement end: %w", err) } - return nil + return file.Close() } for _, crd := range []struct { crdPath, templatePath string diff --git a/cmd/k8s-operator/generate/main_test.go b/cmd/k8s-operator/generate/main_test.go index c7956dcdb..5ea7fec80 100644 --- a/cmd/k8s-operator/generate/main_test.go +++ b/cmd/k8s-operator/generate/main_test.go @@ -7,26 +7,50 @@ package main import ( "bytes" + "context" + "net" "os" "os/exec" "path/filepath" "strings" "testing" + "time" + + "tailscale.com/tstest/nettest" + "tailscale.com/util/cibuild" ) func Test_generate(t *testing.T) { + nettest.SkipIfNoNetwork(t) + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + if _, err := net.DefaultResolver.LookupIPAddr(ctx, "get.helm.sh"); err != nil { + // https://github.com/helm/helm/issues/31434 + t.Skipf("get.helm.sh seems down or unreachable; skipping test") + } + base, err := os.Getwd() base = filepath.Join(base, "../../../") if err != nil { t.Fatalf("error getting current working directory: %v", err) } defer cleanup(base) + + helmCLIPath := filepath.Join(base, "tool/helm") + if out, err := exec.Command(helmCLIPath, "version").CombinedOutput(); err != nil && cibuild.On() { + // It's not just DNS. Azure is generating bogus certs within GitHub Actions at least for + // helm. So try to run it and see if we can even fetch it. + // + // https://github.com/helm/helm/issues/31434 + t.Skipf("error fetching helm; skipping test in CI: %v, %s", err, out) + } + if err := generate(base); err != nil { t.Fatalf("CRD template generation: %v", err) } tempDir := t.TempDir() - helmCLIPath := filepath.Join(base, "tool/helm") helmChartTemplatesPath := filepath.Join(base, "cmd/k8s-operator/deploy/chart") helmPackageCmd := exec.Command(helmCLIPath, "package", helmChartTemplatesPath, "--destination", tempDir, "--version", "0.0.1") helmPackageCmd.Stderr = os.Stderr diff --git a/cmd/k8s-operator/ingress-for-pg.go b/cmd/k8s-operator/ingress-for-pg.go new file mode 100644 index 000000000..4d8311805 --- /dev/null +++ b/cmd/k8s-operator/ingress-for-pg.go @@ -0,0 +1,1132 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand/v2" + "net/http" + "reflect" + "slices" + "strings" + "sync" + "time" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + rbacv1 "k8s.io/api/rbac/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + serveConfigKey = "serve-config.json" + TailscaleSvcOwnerRef = "tailscale.com/k8s-operator:owned-by:%s" + // FinalizerNamePG is the finalizer used by the IngressPGReconciler + FinalizerNamePG = "tailscale.com/ingress-pg-finalizer" + + indexIngressProxyGroup = ".metadata.annotations.ingress-proxy-group" + // annotationHTTPEndpoint can be used to configure the Ingress to expose an HTTP endpoint to tailnet (as + // well as the default HTTPS endpoint). + annotationHTTPEndpoint = "tailscale.com/http-endpoint" + + labelDomain = "tailscale.com/domain" + msgFeatureFlagNotEnabled = "Tailscale Service feature flag is not enabled for this tailnet, skipping provisioning. " + + "Please contact Tailscale support through https://tailscale.com/contact/support to enable the feature flag, then recreate the operator's Pod." + + warningTailscaleServiceFeatureFlagNotEnabled = "TailscaleServiceFeatureFlagNotEnabled" + managedTSServiceComment = "This Tailscale Service is managed by the Tailscale Kubernetes Operator, do not modify" +) + +var gaugePGIngressResources = clientmetric.NewGauge(kubetypes.MetricIngressPGResourceCount) + +// HAIngressReconciler is a controller that reconciles Tailscale Ingresses +// should be exposed on an ingress ProxyGroup (in HA mode). +type HAIngressReconciler struct { + client.Client + + recorder record.EventRecorder + logger *zap.SugaredLogger + tsClient tsClient + tsnetServer tsnetServer + tsNamespace string + lc localClient + defaultTags []string + operatorID string // stableID of the operator's Tailscale device + ingressClassName string + + mu sync.Mutex // protects following + // managedIngresses is a set of all ingress resources that we're currently + // managing. This is only used for metrics. + managedIngresses set.Slice[types.UID] +} + +// Reconcile reconciles Ingresses that should be exposed over Tailscale in HA +// mode (on a ProxyGroup). It looks at all Ingresses with +// tailscale.com/proxy-group annotation. For each such Ingress, it ensures that +// a TailscaleService named after the hostname of the Ingress exists and is up to +// date. It also ensures that the serve config for the ingress ProxyGroup is +// updated to route traffic for the Tailscale Service to the Ingress's backend +// Services. Ingress hostname change also results in the Tailscale Service for the +// previous hostname being cleaned up and a new Tailscale Service being created for the +// new hostname. +// HA Ingresses support multi-cluster Ingress setup. +// Each Tailscale Service contains a list of owner references that uniquely identify +// the Ingress resource and the operator. When an Ingress that acts as a +// backend is being deleted, the corresponding Tailscale Service is only deleted if the +// only owner reference that it contains is for this Ingress. If other owner +// references are found, then cleanup operation only removes this Ingress' owner +// reference. +func (r *HAIngressReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + logger := r.logger.With("Ingress", req.NamespacedName) + logger.Debugf("starting reconcile") + defer logger.Debugf("reconcile finished") + + ing := new(networkingv1.Ingress) + err = r.Get(ctx, req.NamespacedName, ing) + if apierrors.IsNotFound(err) { + // Request object not found, could have been deleted after reconcile request. + logger.Debugf("Ingress not found, assuming it was deleted") + return res, nil + } else if err != nil { + return res, fmt.Errorf("failed to get Ingress: %w", err) + } + + // hostname is the name of the Tailscale Service that will be created + // for this Ingress as well as the first label in the MagicDNS name of + // the Ingress. + hostname := hostnameForIngress(ing) + logger = logger.With("hostname", hostname) + + // needsRequeue is set to true if the underlying Tailscale Service has + // changed as a result of this reconcile. If that is the case, we + // reconcile the Ingress one more time to ensure that concurrent updates + // to the Tailscale Service in a multi-cluster Ingress setup have not + // resulted in another actor overwriting our Tailscale Service update. + needsRequeue := false + if !ing.DeletionTimestamp.IsZero() || !r.shouldExpose(ing) { + needsRequeue, err = r.maybeCleanup(ctx, hostname, ing, logger) + } else { + needsRequeue, err = r.maybeProvision(ctx, hostname, ing, logger) + } + if err != nil { + return res, err + } + if needsRequeue { + res = reconcile.Result{RequeueAfter: requeueInterval()} + } + return res, nil +} + +// maybeProvision ensures that a Tailscale Service for this Ingress exists and is up to date and that the serve config for the +// corresponding ProxyGroup contains the Ingress backend's definition. +// If a Tailscale Service does not exist, it will be created. +// If a Tailscale Service exists, but only with owner references from other operator instances, an owner reference for this +// operator instance is added. +// If a Tailscale Service exists, but does not have an owner reference from any operator, we error +// out assuming that this is an owner reference created by an unknown actor. +// Returns true if the operation resulted in a Tailscale Service update. +func (r *HAIngressReconciler) maybeProvision(ctx context.Context, hostname string, ing *networkingv1.Ingress, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + // Currently (2025-05) Tailscale Services are behind an alpha feature flag that + // needs to be explicitly enabled for a tailnet to be able to use them. + serviceName := tailcfg.ServiceName("svc:" + hostname) + existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) + if err != nil && !isErrorTailscaleServiceNotFound(err) { + return false, fmt.Errorf("error getting Tailscale Service %q: %w", hostname, err) + } + + if err := validateIngressClass(ctx, r.Client, r.ingressClassName); err != nil { + logger.Infof("error validating tailscale IngressClass: %v.", err) + return false, nil + } + // Get and validate ProxyGroup readiness + pgName := ing.Annotations[AnnotationProxyGroup] + if pgName == "" { + logger.Infof("[unexpected] no ProxyGroup annotation, skipping Tailscale Service provisioning") + return false, nil + } + logger = logger.With("ProxyGroup", pgName) + + pg := &tsapi.ProxyGroup{} + if err := r.Get(ctx, client.ObjectKey{Name: pgName}, pg); err != nil { + if apierrors.IsNotFound(err) { + logger.Infof("ProxyGroup does not exist") + return false, nil + } + return false, fmt.Errorf("getting ProxyGroup %q: %w", pgName, err) + } + if !tsoperator.ProxyGroupAvailable(pg) { + logger.Infof("ProxyGroup is not (yet) ready") + return false, nil + } + + // Validate Ingress configuration + if err := r.validateIngress(ctx, ing, pg); err != nil { + logger.Infof("invalid Ingress configuration: %v", err) + r.recorder.Event(ing, corev1.EventTypeWarning, "InvalidIngressConfiguration", err.Error()) + return false, nil + } + + if !IsHTTPSEnabledOnTailnet(r.tsnetServer) { + r.recorder.Event(ing, corev1.EventTypeWarning, "HTTPSNotEnabled", "HTTPS is not enabled on the tailnet; ingress may not work") + } + + if !slices.Contains(ing.Finalizers, FinalizerNamePG) { + // This log line is printed exactly once during initial provisioning, + // because once the finalizer is in place this block gets skipped. So, + // this is a nice place to tell the operator that the high level, + // multi-reconcile operation is underway. + logger.Infof("exposing Ingress over tailscale") + ing.Finalizers = append(ing.Finalizers, FinalizerNamePG) + if err := r.Update(ctx, ing); err != nil { + return false, fmt.Errorf("failed to add finalizer: %w", err) + } + r.mu.Lock() + r.managedIngresses.Add(ing.UID) + gaugePGIngressResources.Set(int64(r.managedIngresses.Len())) + r.mu.Unlock() + } + + // 1. Ensure that if Ingress' hostname has changed, any Tailscale Service + // resources corresponding to the old hostname are cleaned up. + // In practice, this function will ensure that any Tailscale Services that are + // associated with the provided ProxyGroup and no longer owned by an + // Ingress are cleaned up. This is fine- it is not expensive and ensures + // that in edge cases (a single update changed both hostname and removed + // ProxyGroup annotation) the Tailscale Service is more likely to be + // (eventually) removed. + svcsChanged, err = r.maybeCleanupProxyGroup(ctx, pgName, logger) + if err != nil { + return false, fmt.Errorf("failed to cleanup Tailscale Service resources for ProxyGroup: %w", err) + } + + // 2. Ensure that there isn't a Tailscale Service with the same hostname + // already created and not owned by this Ingress. + // TODO(irbekrm): perhaps in future we could have record names being + // stored on Tailscale Services. I am not certain if there might not be edge + // cases (custom domains, etc?) where attempting to determine the DNS + // name of the Tailscale Service in this way won't be incorrect. + + // Generate the Tailscale Service owner annotation for a new or existing Tailscale Service. + // This checks and ensures that Tailscale Service's owner references are updated + // for this Ingress and errors if that is not possible (i.e. because it + // appears that the Tailscale Service has been created by a non-operator actor). + updatedAnnotations, err := ownerAnnotations(r.operatorID, existingTSSvc) + if err != nil { + const instr = "To proceed, you can either manually delete the existing Tailscale Service or choose a different MagicDNS name at `.spec.tls.hosts[0] in the Ingress definition" + msg := fmt.Sprintf("error ensuring ownership of Tailscale Service %s: %v. %s", hostname, err, instr) + logger.Warn(msg) + r.recorder.Event(ing, corev1.EventTypeWarning, "InvalidTailscaleService", msg) + return false, nil + } + // 3. Ensure that TLS Secret and RBAC exists + tcd, err := tailnetCertDomain(ctx, r.lc) + if err != nil { + return false, fmt.Errorf("error determining DNS name base: %w", err) + } + dnsName := hostname + "." + tcd + if err := r.ensureCertResources(ctx, pg, dnsName, ing); err != nil { + return false, fmt.Errorf("error ensuring cert resources: %w", err) + } + + // 4. Ensure that the serve config for the ProxyGroup contains the Tailscale Service. + cm, cfg, err := r.proxyGroupServeConfig(ctx, pgName) + if err != nil { + return false, fmt.Errorf("error getting Ingress serve config: %w", err) + } + if cm == nil { + logger.Infof("no Ingress serve config ConfigMap found, unable to update serve config. Ensure that ProxyGroup is healthy.") + return svcsChanged, nil + } + ep := ipn.HostPort(fmt.Sprintf("%s:443", dnsName)) + handlers, err := handlersForIngress(ctx, ing, r.Client, r.recorder, dnsName, logger) + if err != nil { + return false, fmt.Errorf("failed to get handlers for Ingress: %w", err) + } + ingCfg := &ipn.ServiceConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ep: { + Handlers: handlers, + }, + }, + } + + // Add HTTP endpoint if configured. + if isHTTPEndpointEnabled(ing) { + logger.Infof("exposing Ingress over HTTP") + epHTTP := ipn.HostPort(fmt.Sprintf("%s:80", dnsName)) + ingCfg.TCP[80] = &ipn.TCPPortHandler{ + HTTP: true, + } + ingCfg.Web[epHTTP] = &ipn.WebServerConfig{ + Handlers: handlers, + } + } + + var gotCfg *ipn.ServiceConfig + if cfg != nil && cfg.Services != nil { + gotCfg = cfg.Services[serviceName] + } + if !reflect.DeepEqual(gotCfg, ingCfg) { + logger.Infof("Updating serve config") + mak.Set(&cfg.Services, serviceName, ingCfg) + cfgBytes, err := json.Marshal(cfg) + if err != nil { + return false, fmt.Errorf("error marshaling serve config: %w", err) + } + mak.Set(&cm.BinaryData, serveConfigKey, cfgBytes) + if err := r.Update(ctx, cm); err != nil { + return false, fmt.Errorf("error updating serve config: %w", err) + } + } + + // 4. Ensure that the Tailscale Service exists and is up to date. + tags := r.defaultTags + if tstr, ok := ing.Annotations[AnnotationTags]; ok { + tags = strings.Split(tstr, ",") + } + + tsSvcPorts := []string{"tcp:443"} // always 443 for Ingress + if isHTTPEndpointEnabled(ing) { + tsSvcPorts = append(tsSvcPorts, "tcp:80") + } + + tsSvc := &tailscale.VIPService{ + Name: serviceName, + Tags: tags, + Ports: tsSvcPorts, + Comment: managedTSServiceComment, + Annotations: updatedAnnotations, + } + if existingTSSvc != nil { + tsSvc.Addrs = existingTSSvc.Addrs + } + // TODO(irbekrm): right now if two Ingress resources attempt to apply different Tailscale Service configs (different + // tags, or HTTP endpoint settings) we can end up reconciling those in a loop. We should detect when an Ingress + // with the same generation number has been reconciled ~more than N times and stop attempting to apply updates. + if existingTSSvc == nil || + !reflect.DeepEqual(tsSvc.Tags, existingTSSvc.Tags) || + !reflect.DeepEqual(tsSvc.Ports, existingTSSvc.Ports) || + !ownersAreSetAndEqual(tsSvc, existingTSSvc) { + logger.Infof("Ensuring Tailscale Service exists and is up to date") + if err := r.tsClient.CreateOrUpdateVIPService(ctx, tsSvc); err != nil { + return false, fmt.Errorf("error creating Tailscale Service: %w", err) + } + } + + // 5. Update tailscaled's AdvertiseServices config, which should add the Tailscale Service + // IPs to the ProxyGroup Pods' AllowedIPs in the next netmap update if approved. + mode := serviceAdvertisementHTTPS + if isHTTPEndpointEnabled(ing) { + mode = serviceAdvertisementHTTPAndHTTPS + } + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, pg.Name, serviceName, mode, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config: %w", err) + } + + // 6. Update Ingress status if ProxyGroup Pods are ready. + count, err := numberPodsAdvertising(ctx, r.Client, r.tsNamespace, pg.Name, serviceName) + if err != nil { + return false, fmt.Errorf("failed to check if any Pods are configured: %w", err) + } + + oldStatus := ing.Status.DeepCopy() + + switch count { + case 0: + ing.Status.LoadBalancer.Ingress = nil + default: + var ports []networkingv1.IngressPortStatus + hasCerts, err := hasCerts(ctx, r.Client, r.lc, r.tsNamespace, serviceName) + if err != nil { + return false, fmt.Errorf("error checking TLS credentials provisioned for Ingress: %w", err) + } + // If TLS certs have not been issued (yet), do not set port 443. + if hasCerts { + ports = append(ports, networkingv1.IngressPortStatus{ + Protocol: "TCP", + Port: 443, + }) + } + if isHTTPEndpointEnabled(ing) { + ports = append(ports, networkingv1.IngressPortStatus{ + Protocol: "TCP", + Port: 80, + }) + } + // Set Ingress status hostname only if either port 443 or 80 is advertised. + var hostname string + if len(ports) != 0 { + hostname = dnsName + } + ing.Status.LoadBalancer.Ingress = []networkingv1.IngressLoadBalancerIngress{ + { + Hostname: hostname, + Ports: ports, + }, + } + } + if apiequality.Semantic.DeepEqual(oldStatus, &ing.Status) { + return svcsChanged, nil + } + + const prefix = "Updating Ingress status" + if count == 0 { + logger.Infof("%s. No Pods are advertising Tailscale Service yet", prefix) + } else { + logger.Infof("%s. %d Pod(s) advertising Tailscale Service", prefix, count) + } + + if err := r.Status().Update(ctx, ing); err != nil { + return false, fmt.Errorf("failed to update Ingress status: %w", err) + } + return svcsChanged, nil +} + +// maybeCleanupProxyGroup ensures that any Tailscale Services that are +// associated with the provided ProxyGroup and no longer needed for any +// Ingresses exposed on this ProxyGroup are deleted, if not owned by other +// operator instances, else the owner reference is cleaned up. Returns true if +// the operation resulted in an existing Tailscale Service updates (owner +// reference removal). +func (r *HAIngressReconciler) maybeCleanupProxyGroup(ctx context.Context, proxyGroupName string, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + // Get serve config for the ProxyGroup + cm, cfg, err := r.proxyGroupServeConfig(ctx, proxyGroupName) + if err != nil { + return false, fmt.Errorf("getting serve config: %w", err) + } + if cfg == nil { + // ProxyGroup does not have any Tailscale Services associated with it. + return false, nil + } + + ingList := &networkingv1.IngressList{} + if err := r.List(ctx, ingList); err != nil { + return false, fmt.Errorf("listing Ingresses: %w", err) + } + serveConfigChanged := false + // For each Tailscale Service in serve config... + for tsSvcName := range cfg.Services { + // ...check if there is currently an Ingress with this hostname + found := false + for _, i := range ingList.Items { + ingressHostname := hostnameForIngress(&i) + if ingressHostname == tsSvcName.WithoutPrefix() { + found = true + break + } + } + + if !found { + logger.Infof("Tailscale Service %q is not owned by any Ingress, cleaning up", tsSvcName) + tsService, err := r.tsClient.GetVIPService(ctx, tsSvcName) + if isErrorTailscaleServiceNotFound(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("getting Tailscale Service %q: %w", tsSvcName, err) + } + + // Delete the Tailscale Service from control if necessary. + svcsChanged, err = r.cleanupTailscaleService(ctx, tsService, logger) + if err != nil { + return false, fmt.Errorf("deleting Tailscale Service %q: %w", tsSvcName, err) + } + + // Make sure the Tailscale Service is not advertised in tailscaled or serve config. + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, proxyGroupName, tsSvcName, serviceAdvertisementOff, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + _, ok := cfg.Services[tsSvcName] + if ok { + logger.Infof("Removing Tailscale Service %q from serve config", tsSvcName) + delete(cfg.Services, tsSvcName) + serveConfigChanged = true + } + if err := cleanupCertResources(ctx, r.Client, r.lc, r.tsNamespace, proxyGroupName, tsSvcName); err != nil { + return false, fmt.Errorf("failed to clean up cert resources: %w", err) + } + } + } + + if serveConfigChanged { + cfgBytes, err := json.Marshal(cfg) + if err != nil { + return false, fmt.Errorf("marshaling serve config: %w", err) + } + mak.Set(&cm.BinaryData, serveConfigKey, cfgBytes) + if err := r.Update(ctx, cm); err != nil { + return false, fmt.Errorf("updating serve config: %w", err) + } + } + return svcsChanged, nil +} + +// maybeCleanup ensures that any resources, such as a Tailscale Service created for this Ingress, are cleaned up when the +// Ingress is being deleted or is unexposed. The cleanup is safe for a multi-cluster setup- the Tailscale Service is only +// deleted if it does not contain any other owner references. If it does the cleanup only removes the owner reference +// corresponding to this Ingress. +func (r *HAIngressReconciler) maybeCleanup(ctx context.Context, hostname string, ing *networkingv1.Ingress, logger *zap.SugaredLogger) (svcChanged bool, err error) { + logger.Debugf("Ensuring any resources for Ingress are cleaned up") + ix := slices.Index(ing.Finalizers, FinalizerNamePG) + if ix < 0 { + logger.Debugf("no finalizer, nothing to do") + return false, nil + } + logger.Infof("Ensuring that Tailscale Service %q configuration is cleaned up", hostname) + serviceName := tailcfg.ServiceName("svc:" + hostname) + svc, err := r.tsClient.GetVIPService(ctx, serviceName) + if err != nil { + if isErrorTailscaleServiceNotFound(err) { + return false, nil + } + return false, fmt.Errorf("error getting Tailscale Service: %w", err) + } + + // Ensure that if cleanup succeeded Ingress finalizers are removed. + defer func() { + if err != nil { + return + } + err = r.deleteFinalizer(ctx, ing, logger) + }() + + // 1. Check if there is a Tailscale Service associated with this Ingress. + pg := ing.Annotations[AnnotationProxyGroup] + cm, cfg, err := r.proxyGroupServeConfig(ctx, pg) + if err != nil { + return false, fmt.Errorf("error getting ProxyGroup serve config: %w", err) + } + + // Tailscale Service is always first added to serve config and only then created in the Tailscale API, so if it is not + // found in the serve config, we can assume that there is no Tailscale Service. (If the serve config does not exist at + // all, it is possible that the ProxyGroup has been deleted before cleaning up the Ingress, so carry on with + // cleanup). + if cfg != nil && cfg.Services != nil && cfg.Services[serviceName] == nil { + return false, nil + } + + // 2. Clean up the Tailscale Service resources. + svcChanged, err = r.cleanupTailscaleService(ctx, svc, logger) + if err != nil { + return false, fmt.Errorf("error deleting Tailscale Service: %w", err) + } + + // 3. Clean up any cluster resources + if err := cleanupCertResources(ctx, r.Client, r.lc, r.tsNamespace, pg, serviceName); err != nil { + return false, fmt.Errorf("failed to clean up cert resources: %w", err) + } + + if cfg == nil || cfg.Services == nil { // user probably deleted the ProxyGroup + return svcChanged, nil + } + + // 4. Unadvertise the Tailscale Service in tailscaled config. + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, pg, serviceName, serviceAdvertisementOff, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + + // 5. Remove the Tailscale Service from the serve config for the ProxyGroup. + logger.Infof("Removing TailscaleService %q from serve config for ProxyGroup %q", hostname, pg) + delete(cfg.Services, serviceName) + cfgBytes, err := json.Marshal(cfg) + if err != nil { + return false, fmt.Errorf("error marshaling serve config: %w", err) + } + mak.Set(&cm.BinaryData, serveConfigKey, cfgBytes) + return svcChanged, r.Update(ctx, cm) +} + +func (r *HAIngressReconciler) deleteFinalizer(ctx context.Context, ing *networkingv1.Ingress, logger *zap.SugaredLogger) error { + found := false + ing.Finalizers = slices.DeleteFunc(ing.Finalizers, func(f string) bool { + found = true + return f == FinalizerNamePG + }) + if !found { + return nil + } + logger.Debug("ensure %q finalizer is removed", FinalizerNamePG) + + if err := r.Update(ctx, ing); err != nil { + return fmt.Errorf("failed to remove finalizer %q: %w", FinalizerNamePG, err) + } + r.mu.Lock() + defer r.mu.Unlock() + r.managedIngresses.Remove(ing.UID) + gaugePGIngressResources.Set(int64(r.managedIngresses.Len())) + return nil +} + +func pgIngressCMName(pg string) string { + return fmt.Sprintf("%s-ingress-config", pg) +} + +func (r *HAIngressReconciler) proxyGroupServeConfig(ctx context.Context, pg string) (cm *corev1.ConfigMap, cfg *ipn.ServeConfig, err error) { + name := pgIngressCMName(pg) + cm = &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: r.tsNamespace, + }, + } + if err := r.Get(ctx, client.ObjectKeyFromObject(cm), cm); err != nil && !apierrors.IsNotFound(err) { + return nil, nil, fmt.Errorf("error retrieving ingress serve config ConfigMap %s: %v", name, err) + } + if apierrors.IsNotFound(err) { + return nil, nil, nil + } + cfg = &ipn.ServeConfig{} + if len(cm.BinaryData[serveConfigKey]) != 0 { + if err := json.Unmarshal(cm.BinaryData[serveConfigKey], cfg); err != nil { + return nil, nil, fmt.Errorf("error unmarshaling ingress serve config %v: %w", cm.BinaryData[serveConfigKey], err) + } + } + return cm, cfg, nil +} + +type localClient interface { + StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) +} + +// tailnetCertDomain returns the base domain (TCD) of the current tailnet. +func tailnetCertDomain(ctx context.Context, lc localClient) (string, error) { + st, err := lc.StatusWithoutPeers(ctx) + if err != nil { + return "", fmt.Errorf("error getting tailscale status: %w", err) + } + return st.CurrentTailnet.MagicDNSSuffix, nil +} + +// shouldExpose returns true if the Ingress should be exposed over Tailscale in HA mode (on a ProxyGroup). +func (r *HAIngressReconciler) shouldExpose(ing *networkingv1.Ingress) bool { + isTSIngress := ing != nil && + ing.Spec.IngressClassName != nil && + *ing.Spec.IngressClassName == r.ingressClassName + pgAnnot := ing.Annotations[AnnotationProxyGroup] + return isTSIngress && pgAnnot != "" +} + +// validateIngress validates that the Ingress is properly configured. +// Currently validates: +// - Any tags provided via tailscale.com/tags annotation are valid Tailscale ACL tags +// - The derived hostname is a valid DNS label +// - The referenced ProxyGroup exists and is of type 'ingress' +// - Ingress' TLS block is invalid +func (r *HAIngressReconciler) validateIngress(ctx context.Context, ing *networkingv1.Ingress, pg *tsapi.ProxyGroup) error { + var errs []error + + // Validate tags if present + violations := tagViolations(ing) + if len(violations) > 0 { + errs = append(errs, fmt.Errorf("Ingress contains invalid tags: %v", strings.Join(violations, ","))) + } + + // Validate TLS configuration + if len(ing.Spec.TLS) > 0 && (len(ing.Spec.TLS) > 1 || len(ing.Spec.TLS[0].Hosts) > 1) { + errs = append(errs, fmt.Errorf("Ingress contains invalid TLS block %v: only a single TLS entry with a single host is allowed", ing.Spec.TLS)) + } + + // Validate that the hostname will be a valid DNS label + hostname := hostnameForIngress(ing) + if err := dnsname.ValidLabel(hostname); err != nil { + errs = append(errs, fmt.Errorf("invalid hostname %q: %w. Ensure that the hostname is a valid DNS label", hostname, err)) + } + + // Validate ProxyGroup type + if pg.Spec.Type != tsapi.ProxyGroupTypeIngress { + errs = append(errs, fmt.Errorf("ProxyGroup %q is of type %q but must be of type %q", + pg.Name, pg.Spec.Type, tsapi.ProxyGroupTypeIngress)) + } + + // Validate ProxyGroup readiness + if !tsoperator.ProxyGroupAvailable(pg) { + errs = append(errs, fmt.Errorf("ProxyGroup %q is not ready", pg.Name)) + } + + // It is invalid to have multiple Ingress resources for the same Tailscale Service in one cluster. + ingList := &networkingv1.IngressList{} + if err := r.List(ctx, ingList); err != nil { + errs = append(errs, fmt.Errorf("[unexpected] error listing Ingresses: %w", err)) + return errors.Join(errs...) + } + for _, i := range ingList.Items { + if r.shouldExpose(&i) && hostnameForIngress(&i) == hostname && i.UID != ing.UID { + errs = append(errs, fmt.Errorf("found duplicate Ingress %q for hostname %q - multiple Ingresses for the same hostname in the same cluster are not allowed", client.ObjectKeyFromObject(&i), hostname)) + } + } + return errors.Join(errs...) +} + +// cleanupTailscaleService deletes any Tailscale Service by the provided name if it is not owned by operator instances other than this one. +// If a Tailscale Service is found, but contains other owner references, only removes this operator's owner reference. +// If a Tailscale Service by the given name is not found or does not contain this operator's owner reference, do nothing. +// It returns true if an existing Tailscale Service was updated to remove owner reference, as well as any error that occurred. +func (r *HAIngressReconciler) cleanupTailscaleService(ctx context.Context, svc *tailscale.VIPService, logger *zap.SugaredLogger) (updated bool, _ error) { + if svc == nil { + return false, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return false, fmt.Errorf("error parsing Tailscale Service's owner annotation") + } + if o == nil || len(o.OwnerRefs) == 0 { + return false, nil + } + // Comparing with the operatorID only means that we will not be able to + // clean up Tailscale Service in cases where the operator was deleted from the + // cluster before deleting the Ingress. Perhaps the comparison could be + // 'if or.OperatorID === r.operatorID || or.ingressUID == r.ingressUID'. + ix := slices.IndexFunc(o.OwnerRefs, func(or OwnerRef) bool { + return or.OperatorID == r.operatorID + }) + if ix == -1 { + return false, nil + } + if len(o.OwnerRefs) == 1 { + logger.Infof("Deleting Tailscale Service %q", svc.Name) + return false, r.tsClient.DeleteVIPService(ctx, svc.Name) + } + o.OwnerRefs = slices.Delete(o.OwnerRefs, ix, ix+1) + logger.Infof("Deleting Tailscale Service %q", svc.Name) + json, err := json.Marshal(o) + if err != nil { + return false, fmt.Errorf("error marshalling updated Tailscale Service owner reference: %w", err) + } + svc.Annotations[ownerAnnotation] = string(json) + return true, r.tsClient.CreateOrUpdateVIPService(ctx, svc) +} + +// isHTTPEndpointEnabled returns true if the Ingress has been configured to expose an HTTP endpoint to tailnet. +func isHTTPEndpointEnabled(ing *networkingv1.Ingress) bool { + if ing == nil { + return false + } + return ing.Annotations[annotationHTTPEndpoint] == "enabled" +} + +// serviceAdvertisementMode describes the desired state of a Tailscale Service. +type serviceAdvertisementMode int + +const ( + serviceAdvertisementOff serviceAdvertisementMode = iota // Should not be advertised + serviceAdvertisementHTTPS // Port 443 should be advertised + serviceAdvertisementHTTPAndHTTPS // Both ports 80 and 443 should be advertised +) + +func (a *HAIngressReconciler) maybeUpdateAdvertiseServicesConfig(ctx context.Context, pgName string, serviceName tailcfg.ServiceName, mode serviceAdvertisementMode, logger *zap.SugaredLogger) (err error) { + // Get all config Secrets for this ProxyGroup. + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, kubetypes.LabelSecretTypeConfig))); err != nil { + return fmt.Errorf("failed to list config Secrets: %w", err) + } + + // Verify that TLS cert for the Tailscale Service has been successfully issued + // before attempting to advertise the service. + // This is so that in multi-cluster setups where some Ingresses succeed + // to issue certs and some do not (rate limits), clients are not pinned + // to a backend that is not able to serve HTTPS. + // The only exception is Ingresses with an HTTP endpoint enabled - if an + // Ingress has an HTTP endpoint enabled, it will be advertised even if the + // TLS cert is not yet provisioned. + hasCert, err := hasCerts(ctx, a.Client, a.lc, a.tsNamespace, serviceName) + if err != nil { + return fmt.Errorf("error checking TLS credentials provisioned for service %q: %w", serviceName, err) + } + shouldBeAdvertised := (mode == serviceAdvertisementHTTPAndHTTPS) || + (mode == serviceAdvertisementHTTPS && hasCert) // if we only expose port 443 and don't have certs (yet), do not advertise + + for _, secret := range secrets.Items { + var updated bool + for fileName, confB := range secret.Data { + var conf ipn.ConfigVAlpha + if err := json.Unmarshal(confB, &conf); err != nil { + return fmt.Errorf("error unmarshalling ProxyGroup config: %w", err) + } + + // Update the services to advertise if required. + idx := slices.Index(conf.AdvertiseServices, serviceName.String()) + isAdvertised := idx >= 0 + switch { + case isAdvertised == shouldBeAdvertised: + // Already up to date. + continue + case isAdvertised: + // Needs to be removed. + conf.AdvertiseServices = slices.Delete(conf.AdvertiseServices, idx, idx+1) + case shouldBeAdvertised: + // Needs to be added. + conf.AdvertiseServices = append(conf.AdvertiseServices, serviceName.String()) + } + + // Update the Secret. + confB, err := json.Marshal(conf) + if err != nil { + return fmt.Errorf("error marshalling ProxyGroup config: %w", err) + } + mak.Set(&secret.Data, fileName, confB) + updated = true + } + + if updated { + if err := a.Update(ctx, &secret); err != nil { + return fmt.Errorf("error updating ProxyGroup config Secret: %w", err) + } + } + } + + return nil +} + +func numberPodsAdvertising(ctx context.Context, cl client.Client, tsNamespace, pgName string, serviceName tailcfg.ServiceName) (int, error) { + // Get all state Secrets for this ProxyGroup. + secrets := &corev1.SecretList{} + if err := cl.List(ctx, secrets, client.InNamespace(tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, kubetypes.LabelSecretTypeState))); err != nil { + return 0, fmt.Errorf("failed to list ProxyGroup %q state Secrets: %w", pgName, err) + } + + var count int + for _, secret := range secrets.Items { + prefs, ok, err := getDevicePrefs(&secret) + if err != nil { + return 0, fmt.Errorf("error getting node metadata: %w", err) + } + if !ok { + continue + } + if slices.Contains(prefs.AdvertiseServices, serviceName.String()) { + count++ + } + } + + return count, nil +} + +const ownerAnnotation = "tailscale.com/owner-references" + +// ownerAnnotationValue is the content of the TailscaleService.Annotation[ownerAnnotation] field. +type ownerAnnotationValue struct { + // OwnerRefs is a list of owner references that identify all operator + // instances that manage this Tailscale Services. + OwnerRefs []OwnerRef `json:"ownerRefs,omitempty"` +} + +// OwnerRef is an owner reference that uniquely identifies a Tailscale +// Kubernetes operator instance. +type OwnerRef struct { + // OperatorID is the stable ID of the operator's Tailscale device. + OperatorID string `json:"operatorID,omitempty"` + Resource *Resource `json:"resource,omitempty"` // optional, used to identify the ProxyGroup that owns this Tailscale Service. +} + +type Resource struct { + Kind string `json:"kind,omitempty"` // "ProxyGroup" + Name string `json:"name,omitempty"` // Name of the ProxyGroup that owns this Tailscale Service. Informational only. + UID string `json:"uid,omitempty"` // UID of the ProxyGroup that owns this Tailscale Service. +} + +// ownerAnnotations returns the updated annotations required to ensure this +// instance of the operator is included as an owner. If the Tailscale Service is not +// nil, but does not contain an owner reference we return an error as this likely means +// that the Service was created by somthing other than a Tailscale +// Kubernetes operator. +func ownerAnnotations(operatorID string, svc *tailscale.VIPService) (map[string]string, error) { + ref := OwnerRef{ + OperatorID: operatorID, + } + if svc == nil { + c := ownerAnnotationValue{OwnerRefs: []OwnerRef{ref}} + json, err := json.Marshal(c) + if err != nil { + return nil, fmt.Errorf("[unexpected] unable to marshal Tailscale Service's owner annotation contents: %w, please report this", err) + } + return map[string]string{ + ownerAnnotation: string(json), + }, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return nil, err + } + if o == nil || len(o.OwnerRefs) == 0 { + return nil, fmt.Errorf("Tailscale Service %s exists, but does not contain owner annotation with owner references; not proceeding as this is likely a resource created by something other than the Tailscale Kubernetes operator", svc.Name) + } + if slices.Contains(o.OwnerRefs, ref) { // up to date + return svc.Annotations, nil + } + if o.OwnerRefs[0].Resource != nil { + return nil, fmt.Errorf("Tailscale Service %s is owned by another resource: %#v; cannot be reused for an Ingress", svc.Name, o.OwnerRefs[0].Resource) + } + o.OwnerRefs = append(o.OwnerRefs, ref) + json, err := json.Marshal(o) + if err != nil { + return nil, fmt.Errorf("error marshalling updated owner references: %w", err) + } + + newAnnots := make(map[string]string, len(svc.Annotations)+1) + for k, v := range svc.Annotations { + newAnnots[k] = v + } + newAnnots[ownerAnnotation] = string(json) + return newAnnots, nil +} + +// parseOwnerAnnotation returns nil if no valid owner found. +func parseOwnerAnnotation(tsSvc *tailscale.VIPService) (*ownerAnnotationValue, error) { + if tsSvc.Annotations == nil || tsSvc.Annotations[ownerAnnotation] == "" { + return nil, nil + } + o := &ownerAnnotationValue{} + if err := json.Unmarshal([]byte(tsSvc.Annotations[ownerAnnotation]), o); err != nil { + return nil, fmt.Errorf("error parsing Tailscale Service's %s annotation %q: %w", ownerAnnotation, tsSvc.Annotations[ownerAnnotation], err) + } + return o, nil +} + +func ownersAreSetAndEqual(a, b *tailscale.VIPService) bool { + return a != nil && b != nil && + a.Annotations != nil && b.Annotations != nil && + a.Annotations[ownerAnnotation] != "" && + b.Annotations[ownerAnnotation] != "" && + strings.EqualFold(a.Annotations[ownerAnnotation], b.Annotations[ownerAnnotation]) +} + +// ensureCertResources ensures that the TLS Secret for an HA Ingress and RBAC +// resources that allow proxies to manage the Secret are created. +// Note that Tailscale Service's name validation matches Kubernetes +// resource name validation, so we can be certain that the Tailscale Service name +// (domain) is a valid Kubernetes resource name. +// https://github.com/tailscale/tailscale/blob/8b1e7f646ee4730ad06c9b70c13e7861b964949b/util/dnsname/dnsname.go#L99 +// https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names +func (r *HAIngressReconciler) ensureCertResources(ctx context.Context, pg *tsapi.ProxyGroup, domain string, ing *networkingv1.Ingress) error { + secret := certSecret(pg.Name, r.tsNamespace, domain, ing) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, secret, func(s *corev1.Secret) { + // Labels might have changed if the Ingress has been updated to use a + // different ProxyGroup. + s.Labels = secret.Labels + }); err != nil { + return fmt.Errorf("failed to create or update Secret %s: %w", secret.Name, err) + } + role := certSecretRole(pg.Name, r.tsNamespace, domain) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { + // Labels might have changed if the Ingress has been updated to use a + // different ProxyGroup. + r.Labels = role.Labels + }); err != nil { + return fmt.Errorf("failed to create or update Role %s: %w", role.Name, err) + } + rolebinding := certSecretRoleBinding(pg, r.tsNamespace, domain) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, rolebinding, func(rb *rbacv1.RoleBinding) { + // Labels and subjects might have changed if the Ingress has been updated to use a + // different ProxyGroup. + rb.Labels = rolebinding.Labels + rb.Subjects = rolebinding.Subjects + }); err != nil { + return fmt.Errorf("failed to create or update RoleBinding %s: %w", rolebinding.Name, err) + } + return nil +} + +// cleanupCertResources ensures that the TLS Secret and associated RBAC +// resources that allow proxies to read/write to the Secret are deleted. +func cleanupCertResources(ctx context.Context, cl client.Client, lc localClient, tsNamespace, pgName string, serviceName tailcfg.ServiceName) error { + domainName, err := dnsNameForService(ctx, lc, serviceName) + if err != nil { + return fmt.Errorf("error getting DNS name for Tailscale Service %s: %w", serviceName, err) + } + labels := certResourceLabels(pgName, domainName) + if err := cl.DeleteAllOf(ctx, &rbacv1.RoleBinding{}, client.InNamespace(tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting RoleBinding for domain name %s: %w", domainName, err) + } + if err := cl.DeleteAllOf(ctx, &rbacv1.Role{}, client.InNamespace(tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting Role for domain name %s: %w", domainName, err) + } + if err := cl.DeleteAllOf(ctx, &corev1.Secret{}, client.InNamespace(tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting Secret for domain name %s: %w", domainName, err) + } + return nil +} + +// requeueInterval returns a time duration between 5 and 10 minutes, which is +// the period of time after which an HA Ingress, whose Tailscale Service has been newly +// created or changed, needs to be requeued. This is to protect against +// Tailscale Service's owner references being overwritten as a result of concurrent +// updates during multi-clutster Ingress create/update operations. +func requeueInterval() time.Duration { + return time.Duration(rand.N(5)+5) * time.Minute +} + +// certSecretRole creates a Role that will allow proxies to manage the TLS +// Secret for the given domain. Domain must be a valid Kubernetes resource name. +func certSecretRole(pgName, namespace, domain string) *rbacv1.Role { + return &rbacv1.Role{ + ObjectMeta: metav1.ObjectMeta{ + Name: domain, + Namespace: namespace, + Labels: certResourceLabels(pgName, domain), + }, + Rules: []rbacv1.PolicyRule{ + { + APIGroups: []string{""}, + Resources: []string{"secrets"}, + ResourceNames: []string{domain}, + Verbs: []string{ + "get", + "list", + "patch", + "update", + }, + }, + }, + } +} + +// certSecretRoleBinding creates a RoleBinding for Role that will allow proxies +// to manage the TLS Secret for the given domain. Domain must be a valid +// Kubernetes resource name. +func certSecretRoleBinding(pg *tsapi.ProxyGroup, namespace, domain string) *rbacv1.RoleBinding { + return &rbacv1.RoleBinding{ + ObjectMeta: metav1.ObjectMeta{ + Name: domain, + Namespace: namespace, + Labels: certResourceLabels(pg.Name, domain), + }, + Subjects: []rbacv1.Subject{ + { + Kind: "ServiceAccount", + Name: pgServiceAccountName(pg), + Namespace: namespace, + }, + }, + RoleRef: rbacv1.RoleRef{ + Kind: "Role", + Name: domain, + }, + } +} + +// certSecret creates a Secret that will store the TLS certificate and private +// key for the given domain. Domain must be a valid Kubernetes resource name. +func certSecret(pgName, namespace, domain string, parent client.Object) *corev1.Secret { + labels := certResourceLabels(pgName, domain) + labels[kubetypes.LabelSecretType] = kubetypes.LabelSecretTypeCerts + // Labels that let us identify the Ingress resource lets us reconcile + // the Ingress when the TLS Secret is updated (for example, when TLS + // certs have been provisioned). + labels[LabelParentType] = strings.ToLower(parent.GetObjectKind().GroupVersionKind().Kind) + labels[LabelParentName] = parent.GetName() + if ns := parent.GetNamespace(); ns != "" { + labels[LabelParentNamespace] = ns + } + return &corev1.Secret{ + TypeMeta: metav1.TypeMeta{ + APIVersion: "v1", + Kind: "Secret", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: domain, + Namespace: namespace, + Labels: labels, + }, + Data: map[string][]byte{ + corev1.TLSCertKey: nil, + corev1.TLSPrivateKeyKey: nil, + }, + Type: corev1.SecretTypeTLS, + } +} + +func certResourceLabels(pgName, domain string) map[string]string { + return map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: pgName, + labelDomain: domain, + } +} + +// dnsNameForService returns the DNS name for the given Tailscale Service's name. +func dnsNameForService(ctx context.Context, lc localClient, svc tailcfg.ServiceName) (string, error) { + s := svc.WithoutPrefix() + tcd, err := tailnetCertDomain(ctx, lc) + if err != nil { + return "", fmt.Errorf("error determining DNS name base: %w", err) + } + return s + "." + tcd, nil +} + +// hasCerts checks if the TLS Secret for the given service has non-zero cert and key data. +func hasCerts(ctx context.Context, cl client.Client, lc localClient, ns string, svc tailcfg.ServiceName) (bool, error) { + domain, err := dnsNameForService(ctx, lc, svc) + if err != nil { + return false, fmt.Errorf("failed to get DNS name for service: %w", err) + } + secret := &corev1.Secret{} + err = cl.Get(ctx, client.ObjectKey{ + Namespace: ns, + Name: domain, + }, secret) + if err != nil { + if apierrors.IsNotFound(err) { + return false, nil + } + return false, fmt.Errorf("failed to get TLS Secret: %w", err) + } + + cert := secret.Data[corev1.TLSCertKey] + key := secret.Data[corev1.TLSPrivateKeyKey] + + return len(cert) > 0 && len(key) > 0, nil +} + +func isErrorTailscaleServiceNotFound(err error) bool { + var errResp tailscale.ErrResponse + ok := errors.As(err, &errResp) + return ok && errResp.Status == http.StatusNotFound +} + +func tagViolations(obj client.Object) []string { + var violations []string + if obj == nil { + return nil + } + tags, ok := obj.GetAnnotations()[AnnotationTags] + if !ok { + return nil + } + + for _, tag := range strings.Split(tags, ",") { + tag = strings.TrimSpace(tag) + if err := tailcfg.CheckTag(tag); err != nil { + violations = append(violations, fmt.Sprintf("invalid tag %q: %v", tag, err)) + } + } + return violations +} diff --git a/cmd/k8s-operator/ingress-for-pg_test.go b/cmd/k8s-operator/ingress-for-pg_test.go new file mode 100644 index 000000000..77e5ecb37 --- /dev/null +++ b/cmd/k8s-operator/ingress-for-pg_test.go @@ -0,0 +1,928 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "reflect" + "slices" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + rbacv1 "k8s.io/api/rbac/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +func TestIngressPGReconciler(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + mustCreate(t, fc, ing) + + // Verify initial reconciliation + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyServeConfig(t, fc, "svc:my-svc", false) + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:my-svc"}) + + // Verify that Role and RoleBinding have been created for the first Ingress. + // Do not verify the cert Secret as that was already verified implicitly above. + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg", + }, + } + expectEqual(t, fc, certSecretRole("test-pg", "operator-ns", "my-svc.ts.net")) + expectEqual(t, fc, certSecretRoleBinding(pg, "operator-ns", "my-svc.ts.net")) + + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + ing.Annotations["tailscale.com/tags"] = "tag:custom,tag:test" + }) + expectReconciled(t, ingPGR, "default", "test-ingress") + + // Verify Tailscale Service uses custom tags + tsSvc, err := ft.GetVIPService(context.Background(), "svc:my-svc") + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + if tsSvc == nil { + t.Fatal("Tailscale Service not created") + } + wantTags := []string{"tag:custom", "tag:test"} // custom tags only + gotTags := slices.Clone(tsSvc.Tags) + slices.Sort(gotTags) + slices.Sort(wantTags) + if !slices.Equal(gotTags, wantTags) { + t.Errorf("incorrect Tailscale Service tags: got %v, want %v", gotTags, wantTags) + } + + // Create second Ingress + ing2 := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "my-other-ingress", + Namespace: "default", + UID: types.UID("5678-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-other-svc.tailnetxyz.ts.net"}}, + }, + }, + } + mustCreate(t, fc, ing2) + + // Verify second Ingress reconciliation + expectReconciled(t, ingPGR, "default", "my-other-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-other-svc.ts.net") + expectReconciled(t, ingPGR, "default", "my-other-ingress") + verifyServeConfig(t, fc, "svc:my-other-svc", false) + verifyTailscaleService(t, ft, "svc:my-other-svc", []string{"tcp:443"}) + + // Verify that Role and RoleBinding have been created for the second Ingress. + // Do not verify the cert Secret as that was already verified implicitly above. + expectEqual(t, fc, certSecretRole("test-pg", "operator-ns", "my-other-svc.ts.net")) + expectEqual(t, fc, certSecretRoleBinding(pg, "operator-ns", "my-other-svc.ts.net")) + + // Verify first Ingress is still working + verifyServeConfig(t, fc, "svc:my-svc", false) + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + + verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:my-svc", "svc:my-other-svc"}) + + // Delete second Ingress + if err := fc.Delete(context.Background(), ing2); err != nil { + t.Fatalf("deleting second Ingress: %v", err) + } + expectReconciled(t, ingPGR, "default", "my-other-ingress") + + // Verify second Ingress cleanup + cm := &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfg := &ipn.ServeConfig{} + if err := json.Unmarshal(cm.BinaryData[serveConfigKey], cfg); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + // Verify first Ingress is still configured + if cfg.Services["svc:my-svc"] == nil { + t.Error("first Ingress service config was incorrectly removed") + } + // Verify second Ingress was cleaned up + if cfg.Services["svc:my-other-svc"] != nil { + t.Error("second Ingress service config was not cleaned up") + } + + verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:my-svc"}) + expectMissing[corev1.Secret](t, fc, "operator-ns", "my-other-svc.ts.net") + expectMissing[rbacv1.Role](t, fc, "operator-ns", "my-other-svc.ts.net") + expectMissing[rbacv1.RoleBinding](t, fc, "operator-ns", "my-other-svc.ts.net") + + // Test Ingress ProxyGroup change + createPGResources(t, fc, "test-pg-second") + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + ing.Annotations["tailscale.com/proxy-group"] = "test-pg-second" + }) + expectReconciled(t, ingPGR, "default", "test-ingress") + expectEqual(t, fc, certSecretRole("test-pg-second", "operator-ns", "my-svc.ts.net")) + pg = &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-second", + }, + } + expectEqual(t, fc, certSecretRoleBinding(pg, "operator-ns", "my-svc.ts.net")) + + // Delete the first Ingress and verify cleanup + if err := fc.Delete(context.Background(), ing); err != nil { + t.Fatalf("deleting Ingress: %v", err) + } + + expectReconciled(t, ingPGR, "default", "test-ingress") + + // Verify the ConfigMap was cleaned up + cm = &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-second-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfg = &ipn.ServeConfig{} + if err := json.Unmarshal(cm.BinaryData[serveConfigKey], cfg); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + if len(cfg.Services) > 0 { + t.Error("serve config not cleaned up") + } + verifyTailscaledConfig(t, fc, "test-pg-second", nil) + + // Add verification that cert resources were cleaned up + expectMissing[corev1.Secret](t, fc, "operator-ns", "my-svc.ts.net") + expectMissing[rbacv1.Role](t, fc, "operator-ns", "my-svc.ts.net") + expectMissing[rbacv1.RoleBinding](t, fc, "operator-ns", "my-svc.ts.net") +} + +func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + mustCreate(t, fc, ing) + + // Verify initial reconciliation + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyServeConfig(t, fc, "svc:my-svc", false) + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:my-svc"}) + + // Update the Ingress hostname and make sure the original Tailscale Service is deleted. + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + ing.Spec.TLS[0].Hosts[0] = "updated-svc" + }) + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "updated-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyServeConfig(t, fc, "svc:updated-svc", false) + verifyTailscaleService(t, ft, "svc:updated-svc", []string{"tcp:443"}) + verifyTailscaledConfig(t, fc, "test-pg", []string{"svc:updated-svc"}) + + _, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName("svc:my-svc")) + if err == nil { + t.Fatalf("svc:my-svc not cleaned up") + } + if !isErrorTailscaleServiceNotFound(err) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateIngress(t *testing.T) { + baseIngress := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationProxyGroup: "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test"}}, + }, + }, + } + + readyProxyGroup := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg", + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + Status: tsapi.ProxyGroupStatus{ + Conditions: []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupAvailable), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + }, + }, + } + + tests := []struct { + name string + ing *networkingv1.Ingress + pg *tsapi.ProxyGroup + existingIngs []networkingv1.Ingress + wantErr string + }{ + { + name: "valid_ingress_with_hostname", + ing: &networkingv1.Ingress{ + ObjectMeta: baseIngress.ObjectMeta, + Spec: networkingv1.IngressSpec{ + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test.example.com"}}, + }, + }, + }, + pg: readyProxyGroup, + }, + { + name: "valid_ingress_with_default_hostname", + ing: baseIngress, + pg: readyProxyGroup, + }, + { + name: "invalid_tags", + ing: &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: baseIngress.Name, + Namespace: baseIngress.Namespace, + Annotations: map[string]string{ + AnnotationTags: "tag:invalid!", + }, + }, + }, + pg: readyProxyGroup, + wantErr: "Ingress contains invalid tags: invalid tag \"tag:invalid!\": tag names can only contain numbers, letters, or dashes", + }, + { + name: "multiple_TLS_entries", + ing: &networkingv1.Ingress{ + ObjectMeta: baseIngress.ObjectMeta, + Spec: networkingv1.IngressSpec{ + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test1.example.com"}}, + {Hosts: []string{"test2.example.com"}}, + }, + }, + }, + pg: readyProxyGroup, + wantErr: "Ingress contains invalid TLS block [{[test1.example.com] } {[test2.example.com] }]: only a single TLS entry with a single host is allowed", + }, + { + name: "multiple_hosts_in_TLS_entry", + ing: &networkingv1.Ingress{ + ObjectMeta: baseIngress.ObjectMeta, + Spec: networkingv1.IngressSpec{ + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test1.example.com", "test2.example.com"}}, + }, + }, + }, + pg: readyProxyGroup, + wantErr: "Ingress contains invalid TLS block [{[test1.example.com test2.example.com] }]: only a single TLS entry with a single host is allowed", + }, + { + name: "wrong_proxy_group_type", + ing: baseIngress, + pg: &tsapi.ProxyGroup{ + ObjectMeta: readyProxyGroup.ObjectMeta, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupType("foo"), + }, + Status: readyProxyGroup.Status, + }, + wantErr: "ProxyGroup \"test-pg\" is of type \"foo\" but must be of type \"ingress\"", + }, + { + name: "proxy_group_not_ready", + ing: baseIngress, + pg: &tsapi.ProxyGroup{ + ObjectMeta: readyProxyGroup.ObjectMeta, + Spec: readyProxyGroup.Spec, + Status: tsapi.ProxyGroupStatus{ + Conditions: []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupAvailable), + Status: metav1.ConditionFalse, + ObservedGeneration: 1, + }, + }, + }, + }, + wantErr: "ProxyGroup \"test-pg\" is not ready", + }, + { + name: "duplicate_hostname", + ing: baseIngress, + pg: readyProxyGroup, + existingIngs: []networkingv1.Ingress{{ + ObjectMeta: metav1.ObjectMeta{ + Name: "existing-ingress", + Namespace: "default", + Annotations: map[string]string{ + AnnotationProxyGroup: "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test"}}, + }, + }, + }}, + wantErr: `found duplicate Ingress "default/existing-ingress" for hostname "test" - multiple Ingresses for the same hostname in the same cluster are not allowed`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(tt.ing). + WithLists(&networkingv1.IngressList{Items: tt.existingIngs}). + Build() + + r := &HAIngressReconciler{Client: fc} + if tt.ing.Spec.IngressClassName != nil { + r.ingressClassName = *tt.ing.Spec.IngressClassName + } + + err := r.validateIngress(context.Background(), tt.ing, tt.pg) + if (err == nil && tt.wantErr != "") || (err != nil && err.Error() != tt.wantErr) { + t.Errorf("validateIngress() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestIngressPGReconciler_HTTPEndpoint(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + + // Create test Ingress with HTTP endpoint enabled + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + "tailscale.com/http-endpoint": "enabled", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + if err := fc.Create(context.Background(), ing); err != nil { + t.Fatal(err) + } + + // Verify initial reconciliation with HTTP enabled + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:80", "tcp:443"}) + verifyServeConfig(t, fc, "svc:my-svc", true) + + // Verify Ingress status + ing = &networkingv1.Ingress{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-ingress", + Namespace: "default", + }, ing); err != nil { + t.Fatal(err) + } + + // Status will be empty until the Tailscale Service shows up in prefs. + if !reflect.DeepEqual(ing.Status.LoadBalancer.Ingress, []networkingv1.IngressLoadBalancerIngress(nil)) { + t.Errorf("incorrect Ingress status: got %v, want empty", + ing.Status.LoadBalancer.Ingress) + } + + // Add the Tailscale Service to prefs to have the Ingress recognised as ready. + mustCreate(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: "operator-ns", + Labels: pgSecretLabels("test-pg", kubetypes.LabelSecretTypeState), + }, + Data: map[string][]byte{ + "_current-profile": []byte("profile-foo"), + "profile-foo": []byte(`{"AdvertiseServices":["svc:my-svc"],"Config":{"NodeID":"node-foo"}}`), + }, + }) + + // Reconcile and re-fetch Ingress. + expectReconciled(t, ingPGR, "default", "test-ingress") + if err := fc.Get(context.Background(), client.ObjectKeyFromObject(ing), ing); err != nil { + t.Fatal(err) + } + + wantStatus := []networkingv1.IngressPortStatus{ + {Port: 443, Protocol: "TCP"}, + {Port: 80, Protocol: "TCP"}, + } + if !reflect.DeepEqual(ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) { + t.Errorf("incorrect status ports: got %v, want %v", + ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) + } + + // Remove HTTP endpoint annotation + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + delete(ing.Annotations, "tailscale.com/http-endpoint") + }) + + // Verify reconciliation after removing HTTP + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + verifyServeConfig(t, fc, "svc:my-svc", false) + + // Verify Ingress status + ing = &networkingv1.Ingress{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-ingress", + Namespace: "default", + }, ing); err != nil { + t.Fatal(err) + } + + wantStatus = []networkingv1.IngressPortStatus{ + {Port: 443, Protocol: "TCP"}, + } + if !reflect.DeepEqual(ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) { + t.Errorf("incorrect status ports: got %v, want %v", + ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) + } +} + +func TestIngressPGReconciler_MultiCluster(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + ingPGR.operatorID = "operator-1" + + // Create initial Ingress + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + mustCreate(t, fc, ing) + + // Simulate existing Tailscale Service from another cluster + existingVIPSvc := &tailscale.VIPService{ + Name: "svc:my-svc", + Annotations: map[string]string{ + ownerAnnotation: `{"ownerrefs":[{"operatorID":"operator-2"}]}`, + }, + } + ft.vipServices = map[tailcfg.ServiceName]*tailscale.VIPService{ + "svc:my-svc": existingVIPSvc, + } + + // Verify reconciliation adds our operator reference + expectReconciled(t, ingPGR, "default", "test-ingress") + + tsSvc, err := ft.GetVIPService(context.Background(), "svc:my-svc") + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + if tsSvc == nil { + t.Fatal("Tailscale Service not found") + } + + o, err := parseOwnerAnnotation(tsSvc) + if err != nil { + t.Fatalf("parsing owner annotation: %v", err) + } + + wantOwnerRefs := []OwnerRef{ + {OperatorID: "operator-2"}, + {OperatorID: "operator-1"}, + } + if !reflect.DeepEqual(o.OwnerRefs, wantOwnerRefs) { + t.Errorf("incorrect owner refs\ngot: %+v\nwant: %+v", o.OwnerRefs, wantOwnerRefs) + } + + // Delete the Ingress and verify Tailscale Service still exists with one owner ref + if err := fc.Delete(context.Background(), ing); err != nil { + t.Fatalf("deleting Ingress: %v", err) + } + expectRequeue(t, ingPGR, "default", "test-ingress") + + tsSvc, err = ft.GetVIPService(context.Background(), "svc:my-svc") + if err != nil { + t.Fatalf("getting Tailscale Service after deletion: %v", err) + } + if tsSvc == nil { + t.Fatal("Tailscale Service was incorrectly deleted") + } + + o, err = parseOwnerAnnotation(tsSvc) + if err != nil { + t.Fatalf("parsing owner annotation: %v", err) + } + + wantOwnerRefs = []OwnerRef{ + {OperatorID: "operator-2"}, + } + if !reflect.DeepEqual(o.OwnerRefs, wantOwnerRefs) { + t.Errorf("incorrect owner refs after deletion\ngot: %+v\nwant: %+v", o.OwnerRefs, wantOwnerRefs) + } +} + +func TestOwnerAnnotations(t *testing.T) { + singleSelfOwner := map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id"}]}`, + } + + for name, tc := range map[string]struct { + svc *tailscale.VIPService + wantAnnotations map[string]string + wantErr string + }{ + "no_svc": { + svc: nil, + wantAnnotations: singleSelfOwner, + }, + "empty_svc": { + svc: &tailscale.VIPService{}, + wantErr: "likely a resource created by something other than the Tailscale Kubernetes operator", + }, + "already_owner": { + svc: &tailscale.VIPService{ + Annotations: singleSelfOwner, + }, + wantAnnotations: singleSelfOwner, + }, + "add_owner": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"operator-2"}]}`, + }, + }, + wantAnnotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"operator-2"},{"operatorID":"self-id"}]}`, + }, + }, + "owned_by_proxygroup": { + svc: &tailscale.VIPService{ + Annotations: map[string]string{ + ownerAnnotation: `{"ownerRefs":[{"operatorID":"self-id","resource":{"kind":"ProxyGroup","name":"test-pg","uid":"1234-UID"}}]}`, + }, + }, + wantErr: "owned by another resource", + }, + } { + t.Run(name, func(t *testing.T) { + got, err := ownerAnnotations("self-id", tc.svc) + if tc.wantErr != "" && !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("ownerAnnotations() error = %v, wantErr %v", err, tc.wantErr) + } + if diff := cmp.Diff(tc.wantAnnotations, got); diff != "" { + t.Errorf("ownerAnnotations() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func populateTLSSecret(ctx context.Context, c client.Client, pgName, domain string) error { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: domain, + Namespace: "operator-ns", + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: pgName, + labelDomain: domain, + kubetypes.LabelSecretType: kubetypes.LabelSecretTypeCerts, + }, + }, + Type: corev1.SecretTypeTLS, + Data: map[string][]byte{ + corev1.TLSCertKey: []byte("fake-cert"), + corev1.TLSPrivateKeyKey: []byte("fake-key"), + }, + } + + _, err := createOrUpdate(ctx, c, "operator-ns", secret, func(s *corev1.Secret) { + s.Data = secret.Data + }) + return err +} + +func verifyTailscaleService(t *testing.T, ft *fakeTSClient, serviceName string, wantPorts []string) { + t.Helper() + tsSvc, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName(serviceName)) + if err != nil { + t.Fatalf("getting Tailscale Service %q: %v", serviceName, err) + } + if tsSvc == nil { + t.Fatalf("Tailscale Service %q not created", serviceName) + } + gotPorts := slices.Clone(tsSvc.Ports) + slices.Sort(gotPorts) + slices.Sort(wantPorts) + if !slices.Equal(gotPorts, wantPorts) { + t.Errorf("incorrect ports for Tailscale Service %q: got %v, want %v", serviceName, gotPorts, wantPorts) + } +} + +func verifyServeConfig(t *testing.T, fc client.Client, serviceName string, wantHTTP bool) { + t.Helper() + + cm := &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfg := &ipn.ServeConfig{} + if err := json.Unmarshal(cm.BinaryData["serve-config.json"], cfg); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + t.Logf("Looking for service %q in config: %+v", serviceName, cfg) + + svc := cfg.Services[tailcfg.ServiceName(serviceName)] + if svc == nil { + t.Fatalf("service %q not found in serve config, services: %+v", serviceName, maps.Keys(cfg.Services)) + } + + wantHandlers := 1 + if wantHTTP { + wantHandlers = 2 + } + + // Check TCP handlers + if len(svc.TCP) != wantHandlers { + t.Errorf("incorrect number of TCP handlers for service %q: got %d, want %d", serviceName, len(svc.TCP), wantHandlers) + } + if wantHTTP { + if h, ok := svc.TCP[uint16(80)]; !ok { + t.Errorf("HTTP (port 80) handler not found for service %q", serviceName) + } else if !h.HTTP { + t.Errorf("HTTP not enabled for port 80 handler for service %q", serviceName) + } + } + if h, ok := svc.TCP[uint16(443)]; !ok { + t.Errorf("HTTPS (port 443) handler not found for service %q", serviceName) + } else if !h.HTTPS { + t.Errorf("HTTPS not enabled for port 443 handler for service %q", serviceName) + } + + // Check Web handlers + if len(svc.Web) != wantHandlers { + t.Errorf("incorrect number of Web handlers for service %q: got %d, want %d", serviceName, len(svc.Web), wantHandlers) + } +} + +func verifyTailscaledConfig(t *testing.T, fc client.Client, pgName string, expectedServices []string) { + t.Helper() + var expected string + if expectedServices != nil && len(expectedServices) > 0 { + expectedServicesJSON, err := json.Marshal(expectedServices) + if err != nil { + t.Fatalf("marshaling expected services: %v", err) + } + expected = fmt.Sprintf(`,"AdvertiseServices":%s`, expectedServicesJSON) + } + expectEqual(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: "operator-ns", + Labels: pgSecretLabels(pgName, kubetypes.LabelSecretTypeConfig), + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): []byte(fmt.Sprintf(`{"Version":""%s}`, expected)), + }, + }) +} + +func createPGResources(t *testing.T, fc client.Client, pgName string) { + t.Helper() + // Pre-create the ProxyGroup + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgName, + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + mustCreate(t, fc, pg) + + // Pre-create the ConfigMap for the ProxyGroup + pgConfigMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-ingress-config", pgName), + Namespace: "operator-ns", + }, + BinaryData: map[string][]byte{ + "serve-config.json": []byte(`{"Services":{}}`), + }, + } + mustCreate(t, fc, pgConfigMap) + + // Pre-create a config Secret for the ProxyGroup + pgCfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: "operator-ns", + Labels: pgSecretLabels(pgName, kubetypes.LabelSecretTypeConfig), + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): []byte("{}"), + }, + } + mustCreate(t, fc, pgCfgSecret) + pg.Status.Conditions = []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupAvailable), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + } + if err := fc.Status().Update(context.Background(), pg); err != nil { + t.Fatal(err) + } +} + +func setupIngressTest(t *testing.T) (*HAIngressReconciler, client.Client, *fakeTSClient) { + tsIngressClass := &networkingv1.IngressClass{ + ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, + Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}, + } + + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(tsIngressClass). + WithStatusSubresource(&tsapi.ProxyGroup{}). + Build() + + createPGResources(t, fc, "test-pg") + + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + lc := &fakeLocalClient{ + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{ + MagicDNSSuffix: "ts.net", + }, + }, + } + + ingPGR := &HAIngressReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + tsNamespace: "operator-ns", + tsnetServer: fakeTsnetServer, + logger: zl.Sugar(), + recorder: record.NewFakeRecorder(10), + lc: lc, + ingressClassName: tsIngressClass.Name, + } + + return ingPGR, fc, ft +} diff --git a/cmd/k8s-operator/ingress.go b/cmd/k8s-operator/ingress.go index 700cf4be8..fb11f717d 100644 --- a/cmd/k8s-operator/ingress.go +++ b/cmd/k8s-operator/ingress.go @@ -22,17 +22,19 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/ipn" "tailscale.com/kube/kubetypes" "tailscale.com/types/opt" "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" "tailscale.com/util/set" ) const ( - tailscaleIngressClassName = "tailscale" // ingressClass.metadata.name for tailscale IngressClass resource tailscaleIngressControllerName = "tailscale.com/ts-ingress" // ingressClass.spec.controllerName for tailscale IngressClass resource ingressClassDefaultAnnotation = "ingressclass.kubernetes.io/is-default-class" // we do not support this https://kubernetes.io/docs/concepts/services-networking/ingress/#default-ingress-class + indexIngressProxyClass = ".metadata.annotations.ingress-proxy-class" ) type IngressReconciler struct { @@ -48,7 +50,8 @@ type IngressReconciler struct { // managing. This is only used for metrics. managedIngresses set.Slice[types.UID] - proxyDefaultClass string + defaultProxyClass string + ingressClassName string } var ( @@ -58,7 +61,7 @@ var ( ) func (a *IngressReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { - logger := a.logger.With("ingress-ns", req.Namespace, "ingress-name", req.Name) + logger := a.logger.With("Ingress", req.NamespacedName) logger.Debugf("starting reconcile") defer logger.Debugf("reconcile finished") @@ -72,11 +75,20 @@ func (a *IngressReconciler) Reconcile(ctx context.Context, req reconcile.Request return reconcile.Result{}, fmt.Errorf("failed to get ing: %w", err) } if !ing.DeletionTimestamp.IsZero() || !a.shouldExpose(ing) { + // TODO(irbekrm): this message is confusing if the Ingress is an HA Ingress logger.Debugf("ingress is being deleted or should not be exposed, cleaning up") return reconcile.Result{}, a.maybeCleanup(ctx, logger, ing) } - return reconcile.Result{}, a.maybeProvision(ctx, logger, ing) + if err := a.maybeProvision(ctx, logger, ing); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, ing *networkingv1.Ingress) error { @@ -90,7 +102,7 @@ func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare return nil } - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(ing.Name, ing.Namespace, "ingress")); err != nil { + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(ing.Name, ing.Namespace, "ingress"), proxyTypeIngressResource); err != nil { return fmt.Errorf("failed to cleanup: %w", err) } else if !done { logger.Debugf("cleanup not done yet, waiting for next reconcile") @@ -120,9 +132,8 @@ func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare // This function adds a finalizer to ing, ensuring that we can handle orderly // deprovisioning later. func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.SugaredLogger, ing *networkingv1.Ingress) error { - if err := a.validateIngressClass(ctx); err != nil { + if err := validateIngressClass(ctx, a.Client, a.ingressClassName); err != nil { logger.Warnf("error validating tailscale IngressClass: %v. In future this might be a terminal error.", err) - } if !slices.Contains(ing.Finalizers, FinalizerName) { // This log line is printed exactly once during initial provisioning, @@ -136,7 +147,7 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga } } - proxyClass := proxyClassForObject(ing, a.proxyDefaultClass) + proxyClass := proxyClassForObject(ing, a.defaultProxyClass) if proxyClass != "" { if ready, err := proxyClassIsReady(ctx, proxyClass, a.Client); err != nil { return fmt.Errorf("error verifying ProxyClass for Ingress: %w", err) @@ -151,7 +162,7 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga gaugeIngressResources.Set(int64(a.managedIngresses.Len())) a.mu.Unlock() - if !a.ssr.IsHTTPSEnabledOnTailnet() { + if !IsHTTPSEnabledOnTailnet(a.ssr.tsnetServer) { a.recorder.Event(ing, corev1.EventTypeWarning, "HTTPSNotEnabled", "HTTPS is not enabled on the tailnet; ingress may not work") } @@ -177,73 +188,16 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga } web := sc.Web[magic443] - addIngressBackend := func(b *networkingv1.IngressBackend, path string) { - if b == nil { - return - } - if b.Service == nil { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q is missing service", path) - return - } - var svc corev1.Service - if err := a.Get(ctx, types.NamespacedName{Namespace: ing.Namespace, Name: b.Service.Name}, &svc); err != nil { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "failed to get service %q for path %q: %v", b.Service.Name, path, err) - return - } - if svc.Spec.ClusterIP == "" || svc.Spec.ClusterIP == "None" { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid ClusterIP", path) - return - } - var port int32 - if b.Service.Port.Name != "" { - for _, p := range svc.Spec.Ports { - if p.Name == b.Service.Port.Name { - port = p.Port - break - } - } - } else { - port = b.Service.Port.Number - } - if port == 0 { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid port", path) - return - } - proto := "http://" - if port == 443 || b.Service.Port.Name == "https" { - proto = "https+insecure://" - } - web.Handlers[path] = &ipn.HTTPHandler{ - Proxy: proto + svc.Spec.ClusterIP + ":" + fmt.Sprint(port) + path, - } - } - addIngressBackend(ing.Spec.DefaultBackend, "/") var tlsHost string // hostname or FQDN or empty if ing.Spec.TLS != nil && len(ing.Spec.TLS) > 0 && len(ing.Spec.TLS[0].Hosts) > 0 { tlsHost = ing.Spec.TLS[0].Hosts[0] } - for _, rule := range ing.Spec.Rules { - // Host is optional, but if it's present it must match the TLS host - // otherwise we ignore the rule. - if rule.Host != "" && rule.Host != tlsHost { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "rule with host %q ignored, unsupported", rule.Host) - continue - } - for _, p := range rule.HTTP.Paths { - // Send a warning if folks use Exact path type - to make - // it easier for us to support Exact path type matching - // in the future if needed. - // https://kubernetes.io/docs/concepts/services-networking/ingress/#path-types - if *p.PathType == networkingv1.PathTypeExact { - msg := "Exact path type strict matching is currently not supported and requests will be routed as for Prefix path type. This behaviour might change in the future." - logger.Warnf(fmt.Sprintf("Unsupported Path type exact for path %s. %s", p.Path, msg)) - a.recorder.Eventf(ing, corev1.EventTypeWarning, "UnsupportedPathTypeExact", msg) - } - addIngressBackend(&p.Backend, p.Path) - } + handlers, err := handlersForIngress(ctx, ing, a.Client, a.recorder, tlsHost, logger) + if err != nil { + return fmt.Errorf("failed to get handlers for ingress: %w", err) } - + web.Handlers = handlers if len(web.Handlers) == 0 { logger.Warn("Ingress contains no valid backends") a.recorder.Eventf(ing, corev1.EventTypeWarning, "NoValidBackends", "no valid backends") @@ -255,12 +209,10 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga if tstr, ok := ing.Annotations[AnnotationTags]; ok { tags = strings.Split(tstr, ",") } - hostname := ing.Namespace + "-" + ing.Name + "-ingress" - if tlsHost != "" { - hostname, _, _ = strings.Cut(tlsHost, ".") - } + hostname := hostnameForIngress(ing) sts := &tailscaleSTSConfig{ + Replicas: 1, Hostname: hostname, ParentResourceName: ing.Name, ParentResourceUID: string(ing.UID), @@ -268,73 +220,157 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga Tags: tags, ChildResourceLabels: crl, ProxyClassName: proxyClass, + proxyType: proxyTypeIngressResource, + LoginServer: a.ssr.loginServer, } if val := ing.GetAnnotations()[AnnotationExperimentalForwardClusterTrafficViaL7IngresProxy]; val == "true" { sts.ForwardClusterTrafficViaL7IngressProxy = true } - if _, err := a.ssr.Provision(ctx, logger, sts); err != nil { + if _, err = a.ssr.Provision(ctx, logger, sts); err != nil { return fmt.Errorf("failed to provision: %w", err) } - _, tsHost, _, err := a.ssr.DeviceInfo(ctx, crl) + devices, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { - return fmt.Errorf("failed to get device ID: %w", err) + return fmt.Errorf("failed to retrieve Ingress HTTPS endpoint status: %w", err) } - if tsHost == "" { - logger.Debugf("no Tailscale hostname known yet, waiting for proxy pod to finish auth") - // No hostname yet. Wait for the proxy pod to auth. - ing.Status.LoadBalancer.Ingress = nil - if err := a.Status().Update(ctx, ing); err != nil { - return fmt.Errorf("failed to update ingress status: %w", err) + + ing.Status.LoadBalancer.Ingress = nil + for _, dev := range devices { + if dev.ingressDNSName == "" { + continue } - return nil - } - logger.Debugf("setting ingress hostname to %q", tsHost) - ing.Status.LoadBalancer.Ingress = []networkingv1.IngressLoadBalancerIngress{ - { - Hostname: tsHost, + logger.Debugf("setting Ingress hostname to %q", dev.ingressDNSName) + ing.Status.LoadBalancer.Ingress = append(ing.Status.LoadBalancer.Ingress, networkingv1.IngressLoadBalancerIngress{ + Hostname: dev.ingressDNSName, Ports: []networkingv1.IngressPortStatus{ { Protocol: "TCP", Port: 443, }, }, - }, + }) } - if err := a.Status().Update(ctx, ing); err != nil { + + if err = a.Status().Update(ctx, ing); err != nil { return fmt.Errorf("failed to update ingress status: %w", err) } + return nil } func (a *IngressReconciler) shouldExpose(ing *networkingv1.Ingress) bool { return ing != nil && ing.Spec.IngressClassName != nil && - *ing.Spec.IngressClassName == tailscaleIngressClassName + *ing.Spec.IngressClassName == a.ingressClassName && + ing.Annotations[AnnotationProxyGroup] == "" } // validateIngressClass attempts to validate that 'tailscale' IngressClass // included in Tailscale installation manifests exists and has not been modified // to attempt to enable features that we do not support. -func (a *IngressReconciler) validateIngressClass(ctx context.Context) error { +func validateIngressClass(ctx context.Context, cl client.Client, ingressClassName string) error { ic := &networkingv1.IngressClass{ ObjectMeta: metav1.ObjectMeta{ - Name: tailscaleIngressClassName, + Name: ingressClassName, }, } - if err := a.Get(ctx, client.ObjectKeyFromObject(ic), ic); apierrors.IsNotFound(err) { - return errors.New("Tailscale IngressClass not found in cluster. Latest installation manifests include a tailscale IngressClass - please update") + if err := cl.Get(ctx, client.ObjectKeyFromObject(ic), ic); apierrors.IsNotFound(err) { + return errors.New("'tailscale' IngressClass not found in cluster.") } else if err != nil { return fmt.Errorf("error retrieving 'tailscale' IngressClass: %w", err) } if ic.Spec.Controller != tailscaleIngressControllerName { - return fmt.Errorf("Tailscale Ingress class controller name %s does not match tailscale Ingress controller name %s. Ensure that you are using 'tailscale' IngressClass from latest Tailscale installation manifests", ic.Spec.Controller, tailscaleIngressControllerName) + return fmt.Errorf("'tailscale' Ingress class controller name %s does not match tailscale Ingress controller name %s. Ensure that you are using 'tailscale' IngressClass from latest Tailscale installation manifests", ic.Spec.Controller, tailscaleIngressControllerName) } if ic.GetAnnotations()[ingressClassDefaultAnnotation] != "" { return fmt.Errorf("%s annotation is set on 'tailscale' IngressClass, but Tailscale Ingress controller does not support default Ingress class. Ensure that you are using 'tailscale' IngressClass from latest Tailscale installation manifests", ingressClassDefaultAnnotation) } return nil } + +func handlersForIngress(ctx context.Context, ing *networkingv1.Ingress, cl client.Client, rec record.EventRecorder, tlsHost string, logger *zap.SugaredLogger) (handlers map[string]*ipn.HTTPHandler, err error) { + addIngressBackend := func(b *networkingv1.IngressBackend, path string) { + if path == "" { + path = "/" + rec.Eventf(ing, corev1.EventTypeNormal, "PathUndefined", "configured backend is missing a path, defaulting to '/'") + } + + if b == nil { + return + } + + if b.Service == nil { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q is missing service", path) + return + } + var svc corev1.Service + if err := cl.Get(ctx, types.NamespacedName{Namespace: ing.Namespace, Name: b.Service.Name}, &svc); err != nil { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "failed to get service %q for path %q: %v", b.Service.Name, path, err) + return + } + if svc.Spec.ClusterIP == "" || svc.Spec.ClusterIP == "None" { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid ClusterIP", path) + return + } + var port int32 + if b.Service.Port.Name != "" { + for _, p := range svc.Spec.Ports { + if p.Name == b.Service.Port.Name { + port = p.Port + break + } + } + } else { + port = b.Service.Port.Number + } + if port == 0 { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid port", path) + return + } + proto := "http://" + if port == 443 || b.Service.Port.Name == "https" { + proto = "https+insecure://" + } + mak.Set(&handlers, path, &ipn.HTTPHandler{ + Proxy: proto + svc.Spec.ClusterIP + ":" + fmt.Sprint(port) + path, + }) + } + addIngressBackend(ing.Spec.DefaultBackend, "/") + for _, rule := range ing.Spec.Rules { + // Host is optional, but if it's present it must match the TLS host + // otherwise we ignore the rule. + if rule.Host != "" && rule.Host != tlsHost { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "rule with host %q ignored, unsupported", rule.Host) + continue + } + for _, p := range rule.HTTP.Paths { + // Send a warning if folks use Exact path type - to make + // it easier for us to support Exact path type matching + // in the future if needed. + // https://kubernetes.io/docs/concepts/services-networking/ingress/#path-types + if *p.PathType == networkingv1.PathTypeExact { + msg := "Exact path type strict matching is currently not supported and requests will be routed as for Prefix path type. This behaviour might change in the future." + logger.Warnf(fmt.Sprintf("Unsupported Path type exact for path %s. %s", p.Path, msg)) + rec.Eventf(ing, corev1.EventTypeWarning, "UnsupportedPathTypeExact", msg) + } + addIngressBackend(&p.Backend, p.Path) + } + } + return handlers, nil +} + +// hostnameForIngress returns the hostname for an Ingress resource. +// If the Ingress has TLS configured with a host, it returns the first component of that host. +// Otherwise, it returns a hostname derived from the Ingress name and namespace. +func hostnameForIngress(ing *networkingv1.Ingress) string { + if ing.Spec.TLS != nil && len(ing.Spec.TLS) > 0 && len(ing.Spec.TLS[0].Hosts) > 0 { + h := ing.Spec.TLS[0].Hosts[0] + hostname, _, _ := strings.Cut(h, ".") + return hostname + } + return ing.Namespace + "-" + ing.Name + "-ingress" +} diff --git a/cmd/k8s-operator/ingress_test.go b/cmd/k8s-operator/ingress_test.go index 8b18776b4..f5e23cfe9 100644 --- a/cmd/k8s-operator/ingress_test.go +++ b/cmd/k8s-operator/ingress_test.go @@ -6,25 +6,29 @@ package main import ( + "context" "testing" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "tailscale.com/ipn" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" + "tailscale.com/tstest" "tailscale.com/types/ptr" "tailscale.com/util/mak" ) func TestTailscaleIngress(t *testing.T) { - tsIngressClass := &networkingv1.IngressClass{ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}} - fc := fake.NewFakeClient(tsIngressClass) + fc := fake.NewFakeClient(ingressClass()) ft := &fakeTSClient{} fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} zl, err := zap.NewDevelopment() @@ -32,7 +36,8 @@ func TestTailscaleIngress(t *testing.T) { t.Fatal(err) } ingR := &IngressReconciler{ - Client: fc, + Client: fc, + ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, tsClient: ft, @@ -45,50 +50,14 @@ func TestTailscaleIngress(t *testing.T) { } // 1. Resources get created for regular Ingress - ing := &networkingv1.Ingress{ - TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - // The apiserver is supposed to set the UID, but the fake client - // doesn't. So, set it explicitly because other code later depends - // on it being set. - UID: types.UID("1234-UID"), - }, - Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), - DefaultBackend: &networkingv1.IngressBackend{ - Service: &networkingv1.IngressServiceBackend{ - Name: "test", - Port: networkingv1.ServiceBackendPort{ - Number: 8080, - }, - }, - }, - TLS: []networkingv1.IngressTLS{ - {Hosts: []string{"default-test"}}, - }, - }, - } - mustCreate(t, fc, ing) - mustCreate(t, fc, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - }, - Spec: corev1.ServiceSpec{ - ClusterIP: "1.2.3.4", - Ports: []corev1.ServicePort{{ - Port: 8080, - Name: "http"}, - }, - }, - }) + mustCreate(t, fc, ingress()) + mustCreate(t, fc, service()) expectReconciled(t, ingR, "default", "test") fullName, shortName := findGenName(t, fc, "default", "test", "ingress") opts := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -102,9 +71,9 @@ func TestTailscaleIngress(t *testing.T) { } opts.serveConfig = serveConfig - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "ingress"), nil) - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) // 2. Ingress status gets updated with ingress proxy's MagicDNS name // once that becomes available. @@ -113,13 +82,16 @@ func TestTailscaleIngress(t *testing.T) { mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) }) expectReconciled(t, ingR, "default", "test") + + // Get the ingress and update it with expected changes + ing := ingress() ing.Finalizers = append(ing.Finalizers, "tailscale.com/finalizer") ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ Ingress: []networkingv1.IngressLoadBalancerIngress{ {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, }, } - expectEqual(t, fc, ing, nil) + expectEqual(t, fc, ing) // 3. Resources get created for Ingress that should allow forwarding // cluster traffic @@ -128,7 +100,7 @@ func TestTailscaleIngress(t *testing.T) { }) opts.shouldEnableForwardingClusterTrafficViaIngress = true expectReconciled(t, ingR, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 4. Resources get cleaned up when Ingress class is unset mustUpdate(t, fc, "default", "test", func(ing *networkingv1.Ingress) { @@ -141,19 +113,132 @@ func TestTailscaleIngress(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", fullName) } +func TestTailscaleIngressHostname(t *testing.T) { + fc := fake.NewFakeClient(ingressClass()) + ft := &fakeTSClient{} + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + Client: fc, + ingressClassName: "tailscale", + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + + // 1. Resources get created for regular Ingress + mustCreate(t, fc, ingress()) + mustCreate(t, fc, service()) + + expectReconciled(t, ingR, "default", "test") + + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + mustCreate(t, fc, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fullName, + Namespace: "operator-ns", + UID: "test-uid", + }, + }) + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "ingress", + hostname: "default-test", + app: kubetypes.AppIngressResource, + } + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts.serveConfig = serveConfig + + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) + + // 2. Ingress proxy with capability version >= 110 does not have an HTTPS endpoint set + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + }) + expectReconciled(t, ingR, "default", "test") + + // Get the ingress and update it with expected changes + ing := ingress() + ing.Finalizers = append(ing.Finalizers, "tailscale.com/finalizer") + expectEqual(t, fc, ing) + + // 3. Ingress proxy with capability version >= 110 advertises HTTPS endpoint + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("foo.tailnetxyz.ts.net")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, + }, + } + expectEqual(t, fc, ing) + + // 4. Ingress proxy with capability version >= 110 does not have an HTTPS endpoint ready + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("no-https")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Status.LoadBalancer.Ingress = nil + expectEqual(t, fc, ing) + + // 5. Ingress proxy's state has https_endpoints set, but its capver is not matching Pod UID (downgrade) + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("not-the-right-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("bar.tailnetxyz.ts.net")) + }) + ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, + }, + } + expectReconciled(t, ingR, "default", "test") + expectEqual(t, fc, ing) +} + func TestTailscaleIngressWithProxyClass(t *testing.T) { // Setup pc := &tsapi.ProxyClass{ ObjectMeta: metav1.ObjectMeta{Name: "custom-metadata"}, Spec: tsapi.ProxyClassSpec{StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"bar.io/foo": "some-val"}, - Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}}}, + Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}, + }}, } - tsIngressClass := &networkingv1.IngressClass{ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}} fc := fake.NewClientBuilder(). WithScheme(tsapi.GlobalScheme). - WithObjects(pc, tsIngressClass). + WithObjects(pc, ingressClass()). WithStatusSubresource(pc). Build() ft := &fakeTSClient{} @@ -163,7 +248,8 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { t.Fatal(err) } ingR := &IngressReconciler{ - Client: fc, + Client: fc, + ingressClassName: "tailscale", ssr: &tailscaleSTSReconciler{ Client: fc, tsClient: ft, @@ -177,45 +263,8 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { // 1. Ingress is created with no ProxyClass specified, default proxy // resources get configured. - ing := &networkingv1.Ingress{ - TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - // The apiserver is supposed to set the UID, but the fake client - // doesn't. So, set it explicitly because other code later depends - // on it being set. - UID: types.UID("1234-UID"), - }, - Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), - DefaultBackend: &networkingv1.IngressBackend{ - Service: &networkingv1.IngressServiceBackend{ - Name: "test", - Port: networkingv1.ServiceBackendPort{ - Number: 8080, - }, - }, - }, - TLS: []networkingv1.IngressTLS{ - {Hosts: []string{"default-test"}}, - }, - }, - } - mustCreate(t, fc, ing) - mustCreate(t, fc, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - }, - Spec: corev1.ServiceSpec{ - ClusterIP: "1.2.3.4", - Ports: []corev1.ServicePort{{ - Port: 8080, - Name: "http"}, - }, - }, - }) + mustCreate(t, fc, ingress()) + mustCreate(t, fc, service()) expectReconciled(t, ingR, "default", "test") @@ -234,17 +283,17 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { } opts.serveConfig = serveConfig - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "ingress"), nil) - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) // 2. Ingress is updated to specify a ProxyClass, ProxyClass is not yet // ready, so proxy resource configuration does not change. mustUpdate(t, fc, "default", "test", func(ing *networkingv1.Ingress) { - mak.Set(&ing.ObjectMeta.Labels, LabelProxyClass, "custom-metadata") + mak.Set(&ing.ObjectMeta.Labels, LabelAnnotationProxyClass, "custom-metadata") }) expectReconciled(t, ingR, "default", "test") - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) // 3. ProxyClass is set to Ready by proxy-class reconciler. Ingress get // reconciled and configuration from the ProxyClass is applied to the @@ -253,21 +302,517 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { pc.Status = tsapi.ProxyClassStatus{ Conditions: []metav1.Condition{{ Status: metav1.ConditionTrue, - Type: string(tsapi.ProxyClassready), + Type: string(tsapi.ProxyClassReady), ObservedGeneration: pc.Generation, - }}} + }}, + } }) expectReconciled(t, ingR, "default", "test") opts.proxyClass = pc.Name - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) // 4. tailscale.com/proxy-class label is removed from the Ingress, the // Ingress gets reconciled and the custom ProxyClass configuration is // removed from the proxy resources. mustUpdate(t, fc, "default", "test", func(ing *networkingv1.Ingress) { - delete(ing.ObjectMeta.Labels, LabelProxyClass) + delete(ing.ObjectMeta.Labels, LabelAnnotationProxyClass) }) expectReconciled(t, ingR, "default", "test") opts.proxyClass = "" - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) +} + +func TestTailscaleIngressWithServiceMonitor(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{Name: "metrics", Generation: 1}, + Spec: tsapi.ProxyClassSpec{}, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Status: metav1.ConditionTrue, + Type: string(tsapi.ProxyClassReady), + ObservedGeneration: 1, + }}, + }, + } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + + // Create fake client with ProxyClass, IngressClass, Ingress with metrics ProxyClass, and Service + ing := ingress() + ing.Labels = map[string]string{ + LabelAnnotationProxyClass: "metrics", + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc, ingressClass(), ing, service()). + WithStatusSubresource(pc). + Build() + + ft := &fakeTSClient{} + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + Client: fc, + ingressClassName: "tailscale", + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + expectReconciled(t, ingR, "default", "test") + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + tailscaleNamespace: "operator-ns", + parentType: "ingress", + hostname: "default-test", + app: kubetypes.AppIngressResource, + namespaced: true, + proxyType: proxyTypeIngressResource, + serveConfig: serveConfig, + resourceVersion: "1", + } + + // 1. Enable metrics- expect metrics Service to be created + mustUpdate(t, fc, "", "metrics", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics = &tsapi.Metrics{Enable: true} + }) + opts.enableMetrics = true + + expectReconciled(t, ingR, "default", "test") + + expectEqual(t, fc, expectedMetricsService(opts)) + + // 2. Enable ServiceMonitor - should not error when there is no ServiceMonitor CRD in cluster + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true, Labels: tsapi.Labels{"foo": "bar"}} + }) + expectReconciled(t, ingR, "default", "test") + expectEqual(t, fc, expectedMetricsService(opts)) + + // 3. Create ServiceMonitor CRD and reconcile- ServiceMonitor should get created + mustCreate(t, fc, crd) + expectReconciled(t, ingR, "default", "test") + opts.serviceMonitorLabels = tsapi.Labels{"foo": "bar"} + expectEqual(t, fc, expectedMetricsService(opts)) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 4. Update ServiceMonitor CRD and reconcile- ServiceMonitor should get updated + mustUpdate(t, fc, pc.Namespace, pc.Name, func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics.ServiceMonitor.Labels = nil + }) + expectReconciled(t, ingR, "default", "test") + opts.serviceMonitorLabels = nil + opts.resourceVersion = "2" + expectEqual(t, fc, expectedMetricsService(opts)) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 5. Disable metrics - metrics resources should get deleted. + mustUpdate(t, fc, pc.Namespace, pc.Name, func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics = nil + }) + expectReconciled(t, ingR, "default", "test") + expectMissing[corev1.Service](t, fc, "operator-ns", metricsResourceName(shortName)) + // ServiceMonitor gets garbage collected when the Service is deleted - we cannot test that here. +} + +func TestIngressProxyClassAnnotation(t *testing.T) { + cl := tstest.NewClock(tstest.ClockOpts{}) + zl := zap.Must(zap.NewDevelopment()) + + pcLEStaging, pcLEStagingFalse, _ := proxyClassesForLEStagingTest() + + testCases := []struct { + name string + proxyClassAnnotation string + proxyClassLabel string + proxyClassDefault string + expectedProxyClass string + expectEvents []string + }{ + { + name: "via_label", + proxyClassLabel: pcLEStaging.Name, + expectedProxyClass: pcLEStaging.Name, + }, + { + name: "via_annotation", + proxyClassAnnotation: pcLEStaging.Name, + expectedProxyClass: pcLEStaging.Name, + }, + { + name: "via_default", + proxyClassDefault: pcLEStaging.Name, + expectedProxyClass: pcLEStaging.Name, + }, + { + name: "via_label_override_annotation", + proxyClassLabel: pcLEStaging.Name, + proxyClassAnnotation: pcLEStagingFalse.Name, + expectedProxyClass: pcLEStaging.Name, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + builder := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme) + + builder = builder.WithObjects(pcLEStaging, pcLEStagingFalse). + WithStatusSubresource(pcLEStaging, pcLEStagingFalse) + + fc := builder.Build() + + if tt.proxyClassAnnotation != "" || tt.proxyClassLabel != "" || tt.proxyClassDefault != "" { + name := tt.proxyClassDefault + if name == "" { + name = tt.proxyClassLabel + if name == "" { + name = tt.proxyClassAnnotation + } + } + setProxyClassReady(t, fc, cl, name) + } + + mustCreate(t, fc, ingressClass()) + mustCreate(t, fc, service()) + ing := ingress() + if tt.proxyClassLabel != "" { + ing.Labels = map[string]string{ + LabelAnnotationProxyClass: tt.proxyClassLabel, + } + } + if tt.proxyClassAnnotation != "" { + ing.Annotations = map[string]string{ + LabelAnnotationProxyClass: tt.proxyClassAnnotation, + } + } + mustCreate(t, fc, ing) + + ingR := &IngressReconciler{ + Client: fc, + ingressClassName: "tailscale", + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: &fakeTSClient{}, + tsnetServer: &fakeTSNetServer{certDomains: []string{"test-host"}}, + defaultTags: []string{"tag:test"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale:test", + }, + logger: zl.Sugar(), + defaultProxyClass: tt.proxyClassDefault, + } + + expectReconciled(t, ingR, "default", "test") + + _, shortName := findGenName(t, fc, "default", "test", "ingress") + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: "operator-ns", Name: shortName}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + switch tt.expectedProxyClass { + case pcLEStaging.Name: + verifyEnvVar(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL", letsEncryptStagingEndpoint) + case pcLEStagingFalse.Name: + verifyEnvVarNotPresent(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL") + default: + t.Fatalf("unexpected expected ProxyClass %q", tt.expectedProxyClass) + } + }) + } +} + +func TestIngressLetsEncryptStaging(t *testing.T) { + cl := tstest.NewClock(tstest.ClockOpts{}) + zl := zap.Must(zap.NewDevelopment()) + + pcLEStaging, pcLEStagingFalse, pcOther := proxyClassesForLEStagingTest() + + testCases := testCasesForLEStagingTests() + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + builder := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme) + + builder = builder.WithObjects(pcLEStaging, pcLEStagingFalse, pcOther). + WithStatusSubresource(pcLEStaging, pcLEStagingFalse, pcOther) + + fc := builder.Build() + + if tt.proxyClassPerResource != "" || tt.defaultProxyClass != "" { + name := tt.proxyClassPerResource + if name == "" { + name = tt.defaultProxyClass + } + setProxyClassReady(t, fc, cl, name) + } + + mustCreate(t, fc, ingressClass()) + mustCreate(t, fc, service()) + ing := ingress() + if tt.proxyClassPerResource != "" { + ing.Labels = map[string]string{ + LabelAnnotationProxyClass: tt.proxyClassPerResource, + } + } + mustCreate(t, fc, ing) + + ingR := &IngressReconciler{ + Client: fc, + ingressClassName: "tailscale", + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: &fakeTSClient{}, + tsnetServer: &fakeTSNetServer{certDomains: []string{"test-host"}}, + defaultTags: []string{"tag:test"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale:test", + }, + logger: zl.Sugar(), + defaultProxyClass: tt.defaultProxyClass, + } + + expectReconciled(t, ingR, "default", "test") + + _, shortName := findGenName(t, fc, "default", "test", "ingress") + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: "operator-ns", Name: shortName}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + if tt.useLEStagingEndpoint { + verifyEnvVar(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL", letsEncryptStagingEndpoint) + } else { + verifyEnvVarNotPresent(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL") + } + }) + } +} + +func TestEmptyPath(t *testing.T) { + testCases := []struct { + name string + paths []networkingv1.HTTPIngressPath + expectedEvents []string + }{ + { + name: "empty_path_with_prefix_type", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypePrefix), + Path: "", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + { + name: "empty_path_with_implementation_specific_type", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + Path: "", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + { + name: "empty_path_with_exact_type", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypeExact), + Path: "", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Warning UnsupportedPathTypeExact Exact path type strict matching is currently not supported and requests will be routed as for Prefix path type. This behaviour might change in the future.", + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + { + name: "two_competing_but_not_identical_paths_including_one_empty", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + Path: "", + Backend: *backend(), + }, + { + PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + Path: "/", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + fc := fake.NewFakeClient(ingressClass()) + ft := &fakeTSClient{} + fr := record.NewFakeRecorder(3) // bump this if you expect a test case to throw more events + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + recorder: fr, + Client: fc, + ingressClassName: "tailscale", + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + + // 1. Resources get created for regular Ingress + mustCreate(t, fc, ingressWithPaths(tt.paths)) + mustCreate(t, fc, service()) + + expectReconciled(t, ingR, "default", "test") + + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + mustCreate(t, fc, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fullName, + Namespace: "operator-ns", + UID: "test-uid", + }, + }) + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "ingress", + hostname: "foo", + app: kubetypes.AppIngressResource, + } + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts.serveConfig = serveConfig + + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeResourceReqs) + + expectEvents(t, fr, tt.expectedEvents) + }) + } +} + +// ptrPathType is a helper function to return a pointer to the pathtype string (required for TestEmptyPath) +func ptrPathType(p networkingv1.PathType) *networkingv1.PathType { + return &p +} + +func ingressClass() *networkingv1.IngressClass { + return &networkingv1.IngressClass{ + ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, + Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}, + } +} + +func service() *corev1.Service { + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.4", + Ports: []corev1.ServicePort{{ + Port: 8080, + Name: "http"}, + }, + }, + } +} + +func ingress() *networkingv1.Ingress { + return &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: "1234-UID", + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: backend(), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"default-test"}}, + }, + }, + } +} + +func ingressWithPaths(paths []networkingv1.HTTPIngressPath) *networkingv1.Ingress { + return &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + Rules: []networkingv1.IngressRule{ + { + Host: "foo.tailnetxyz.ts.net", + IngressRuleValue: networkingv1.IngressRuleValue{ + HTTP: &networkingv1.HTTPIngressRuleValue{ + Paths: paths, + }, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"foo.tailnetxyz.ts.net"}}, + }, + }, + } +} + +func backend() *networkingv1.IngressBackend { + return &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + } } diff --git a/cmd/k8s-operator/logger.go b/cmd/k8s-operator/logger.go new file mode 100644 index 000000000..46b1fc0c8 --- /dev/null +++ b/cmd/k8s-operator/logger.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "io" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + kzap "sigs.k8s.io/controller-runtime/pkg/log/zap" +) + +// wrapZapCore returns a zapcore.Core implementation that splits the core chain using zapcore.NewTee. This causes +// logs to be simultaneously written to both the original core and the provided io.Writer implementation. +func wrapZapCore(core zapcore.Core, writer io.Writer) zapcore.Core { + encoder := &kzap.KubeAwareEncoder{ + Encoder: zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()), + } + + // We use a tee logger here so that logs are written to stdout/stderr normally while at the same time being + // sent upstream. + return zapcore.NewTee(core, zapcore.NewCore(encoder, zapcore.AddSync(writer), zap.DebugLevel)) +} diff --git a/cmd/k8s-operator/metrics_resources.go b/cmd/k8s-operator/metrics_resources.go new file mode 100644 index 000000000..0579e3466 --- /dev/null +++ b/cmd/k8s-operator/metrics_resources.go @@ -0,0 +1,296 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "reflect" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" +) + +const ( + labelMetricsTarget = "tailscale.com/metrics-target" + + // These labels get transferred from the metrics Service to the ingested Prometheus metrics. + labelPromProxyType = "ts_proxy_type" + labelPromProxyParentName = "ts_proxy_parent_name" + labelPromProxyParentNamespace = "ts_proxy_parent_namespace" + labelPromJob = "ts_prom_job" + + serviceMonitorCRD = "servicemonitors.monitoring.coreos.com" +) + +// ServiceMonitor contains a subset of fields of servicemonitors.monitoring.coreos.com Custom Resource Definition. +// Duplicating it here allows us to avoid importing prometheus-operator library. +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L40 +type ServiceMonitor struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata"` + Spec ServiceMonitorSpec `json:"spec"` +} + +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L55 +type ServiceMonitorSpec struct { + // Endpoints defines the endpoints to be scraped on the selected Service(s). + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L82 + Endpoints []ServiceMonitorEndpoint `json:"endpoints"` + // JobLabel is the label on the Service whose value will become the value of the Prometheus job label for the metrics ingested via this ServiceMonitor. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L66 + JobLabel string `json:"jobLabel"` + // NamespaceSelector selects the namespace of Service(s) that this ServiceMonitor allows to scrape. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L88 + NamespaceSelector ServiceMonitorNamespaceSelector `json:"namespaceSelector,omitempty"` + // Selector is the label selector for Service(s) that this ServiceMonitor allows to scrape. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L85 + Selector metav1.LabelSelector `json:"selector"` + // TargetLabels are labels on the selected Service that should be applied as Prometheus labels to the ingested metrics. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L72 + TargetLabels []string `json:"targetLabels"` +} + +// ServiceMonitorNamespaceSelector selects namespaces in which Prometheus operator will attempt to find Services for +// this ServiceMonitor. +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L88 +type ServiceMonitorNamespaceSelector struct { + MatchNames []string `json:"matchNames,omitempty"` +} + +// ServiceMonitorEndpoint defines an endpoint of Service to scrape. We only define port here. Prometheus by default +// scrapes /metrics path, which is what we want. +type ServiceMonitorEndpoint struct { + // Port is the name of the Service port that Prometheus will scrape. + Port string `json:"port,omitempty"` +} + +func reconcileMetricsResources(ctx context.Context, logger *zap.SugaredLogger, opts *metricsOpts, pc *tsapi.ProxyClass, cl client.Client) error { + if opts.proxyType == proxyTypeEgress { + // Metrics are currently not being enabled for standalone egress proxies. + return nil + } + if pc == nil || pc.Spec.Metrics == nil || !pc.Spec.Metrics.Enable { + return maybeCleanupMetricsResources(ctx, opts, cl) + } + metricsSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: metricsResourceName(opts.proxyStsName), + Namespace: opts.tsNamespace, + Labels: metricsResourceLabels(opts), + }, + Spec: corev1.ServiceSpec{ + Selector: opts.proxyLabels, + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{Protocol: "TCP", Port: 9002, Name: "metrics"}}, + }, + } + var err error + metricsSvc, err = createOrUpdate(ctx, cl, opts.tsNamespace, metricsSvc, func(svc *corev1.Service) { + svc.Spec.Ports = metricsSvc.Spec.Ports + svc.Spec.Selector = metricsSvc.Spec.Selector + }) + if err != nil { + return fmt.Errorf("error ensuring metrics Service: %w", err) + } + + crdExists, err := hasServiceMonitorCRD(ctx, cl) + if err != nil { + return fmt.Errorf("error verifying that %q CRD exists: %w", serviceMonitorCRD, err) + } + if !crdExists { + return nil + } + + if pc.Spec.Metrics.ServiceMonitor == nil || !pc.Spec.Metrics.ServiceMonitor.Enable { + return maybeCleanupServiceMonitor(ctx, cl, opts.proxyStsName, opts.tsNamespace) + } + + logger.Infof("ensuring ServiceMonitor for metrics Service %s/%s", metricsSvc.Namespace, metricsSvc.Name) + svcMonitor, err := newServiceMonitor(metricsSvc, pc.Spec.Metrics.ServiceMonitor) + if err != nil { + return fmt.Errorf("error creating ServiceMonitor: %w", err) + } + + // We don't use createOrUpdate here because that does not work with unstructured types. + existing := svcMonitor.DeepCopy() + err = cl.Get(ctx, client.ObjectKeyFromObject(metricsSvc), existing) + if apierrors.IsNotFound(err) { + if err := cl.Create(ctx, svcMonitor); err != nil { + return fmt.Errorf("error creating ServiceMonitor: %w", err) + } + return nil + } + if err != nil { + return fmt.Errorf("error getting ServiceMonitor: %w", err) + } + // Currently, we only update labels on the ServiceMonitor as those are the only values that can change. + if !reflect.DeepEqual(existing.GetLabels(), svcMonitor.GetLabels()) { + existing.SetLabels(svcMonitor.GetLabels()) + if err := cl.Update(ctx, existing); err != nil { + return fmt.Errorf("error updating ServiceMonitor: %w", err) + } + } + return nil +} + +// maybeCleanupMetricsResources ensures that any metrics resources created for a proxy are deleted. Only metrics Service +// gets deleted explicitly because the ServiceMonitor has Service's owner reference, so gets garbage collected +// automatically. +func maybeCleanupMetricsResources(ctx context.Context, opts *metricsOpts, cl client.Client) error { + sel := metricsSvcSelector(opts.proxyLabels, opts.proxyType) + return cl.DeleteAllOf(ctx, &corev1.Service{}, client.InNamespace(opts.tsNamespace), client.MatchingLabels(sel)) +} + +// maybeCleanupServiceMonitor cleans up any ServiceMonitor created for the named proxy StatefulSet. +func maybeCleanupServiceMonitor(ctx context.Context, cl client.Client, stsName, ns string) error { + smName := metricsResourceName(stsName) + sm := serviceMonitorTemplate(smName, ns) + u, err := serviceMonitorToUnstructured(sm) + if err != nil { + return fmt.Errorf("error building ServiceMonitor: %w", err) + } + err = cl.Get(ctx, types.NamespacedName{Name: smName, Namespace: ns}, u) + if apierrors.IsNotFound(err) { + return nil // nothing to do + } + if err != nil { + return fmt.Errorf("error verifying if ServiceMonitor %s/%s exists: %w", ns, stsName, err) + } + return cl.Delete(ctx, u) +} + +// newServiceMonitor takes a metrics Service created for a proxy and constructs and returns a ServiceMonitor for that +// proxy that can be applied to the kube API server. +// The ServiceMonitor is returned as Unstructured type - this allows us to avoid importing prometheus-operator API server client/schema. +func newServiceMonitor(metricsSvc *corev1.Service, spec *tsapi.ServiceMonitor) (*unstructured.Unstructured, error) { + sm := serviceMonitorTemplate(metricsSvc.Name, metricsSvc.Namespace) + sm.ObjectMeta.Labels = metricsSvc.Labels + if spec != nil && len(spec.Labels) > 0 { + sm.ObjectMeta.Labels = mergeMapKeys(sm.ObjectMeta.Labels, spec.Labels.Parse()) + } + + sm.ObjectMeta.OwnerReferences = []metav1.OwnerReference{*metav1.NewControllerRef(metricsSvc, corev1.SchemeGroupVersion.WithKind("Service"))} + sm.Spec = ServiceMonitorSpec{ + Selector: metav1.LabelSelector{MatchLabels: metricsSvc.Labels}, + Endpoints: []ServiceMonitorEndpoint{{ + Port: "metrics", + }}, + NamespaceSelector: ServiceMonitorNamespaceSelector{ + MatchNames: []string{metricsSvc.Namespace}, + }, + JobLabel: labelPromJob, + TargetLabels: []string{ + labelPromProxyParentName, + labelPromProxyParentNamespace, + labelPromProxyType, + }, + } + return serviceMonitorToUnstructured(sm) +} + +// serviceMonitorToUnstructured takes a ServiceMonitor and converts it to Unstructured type that can be used by the c/r +// client in Kubernetes API server calls. +func serviceMonitorToUnstructured(sm *ServiceMonitor) (*unstructured.Unstructured, error) { + contents, err := runtime.DefaultUnstructuredConverter.ToUnstructured(sm) + if err != nil { + return nil, fmt.Errorf("error converting ServiceMonitor to Unstructured: %w", err) + } + u := &unstructured.Unstructured{} + u.SetUnstructuredContent(contents) + u.SetGroupVersionKind(sm.GroupVersionKind()) + return u, nil +} + +// metricsResourceName returns name for metrics Service and ServiceMonitor for a proxy StatefulSet. +func metricsResourceName(stsName string) string { + // Maximum length of StatefulSet name if 52 chars, so this is fine. + return fmt.Sprintf("%s-metrics", stsName) +} + +// metricsResourceLabels constructs labels that will be applied to metrics Service and metrics ServiceMonitor for a +// proxy. +func metricsResourceLabels(opts *metricsOpts) map[string]string { + lbls := map[string]string{ + kubetypes.LabelManaged: "true", + labelMetricsTarget: opts.proxyStsName, + labelPromProxyType: opts.proxyType, + labelPromProxyParentName: opts.proxyLabels[LabelParentName], + } + // Include namespace label for proxies created for a namespaced type. + if isNamespacedProxyType(opts.proxyType) { + lbls[labelPromProxyParentNamespace] = opts.proxyLabels[LabelParentNamespace] + } + lbls[labelPromJob] = promJobName(opts) + return lbls +} + +// promJobName constructs the value of the Prometheus job label that will apply to all metrics for a ServiceMonitor. +func promJobName(opts *metricsOpts) string { + // Include parent resource namespace for proxies created for namespaced types. + if opts.proxyType == proxyTypeIngressResource || opts.proxyType == proxyTypeIngressService { + return fmt.Sprintf("ts_%s_%s_%s", opts.proxyType, opts.proxyLabels[LabelParentNamespace], opts.proxyLabels[LabelParentName]) + } + return fmt.Sprintf("ts_%s_%s", opts.proxyType, opts.proxyLabels[LabelParentName]) +} + +// metricsSvcSelector returns the minimum label set to uniquely identify a metrics Service for a proxy. +func metricsSvcSelector(proxyLabels map[string]string, proxyType string) map[string]string { + sel := map[string]string{ + labelPromProxyType: proxyType, + labelPromProxyParentName: proxyLabels[LabelParentName], + } + // Include namespace label for proxies created for a namespaced type. + if isNamespacedProxyType(proxyType) { + sel[labelPromProxyParentNamespace] = proxyLabels[LabelParentNamespace] + } + return sel +} + +// serviceMonitorTemplate returns a base ServiceMonitor type that, when converted to Unstructured, is a valid type that +// can be used in kube API server calls via the c/r client. +func serviceMonitorTemplate(name, ns string) *ServiceMonitor { + return &ServiceMonitor{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceMonitor", + APIVersion: "monitoring.coreos.com/v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: ns, + }, + } +} + +type metricsOpts struct { + proxyStsName string // name of StatefulSet for proxy + tsNamespace string // namespace in which Tailscale is installed + proxyLabels map[string]string // labels of the proxy StatefulSet + proxyType string +} + +func isNamespacedProxyType(typ string) bool { + return typ == proxyTypeIngressResource || typ == proxyTypeIngressService +} + +func mergeMapKeys(a, b map[string]string) map[string]string { + m := make(map[string]string, len(a)+len(b)) + for key, val := range b { + m[key] = val + } + for key, val := range a { + m[key] = val + } + return m +} diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 52577c929..39db5f0f9 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -7,13 +7,13 @@ package main import ( "context" + _ "embed" + "errors" "fmt" "slices" + "strings" "sync" - _ "embed" - - "github.com/pkg/errors" "go.uber.org/zap" xslices "golang.org/x/exp/slices" appsv1 "k8s.io/api/apps/v1" @@ -26,10 +26,12 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" "sigs.k8s.io/yaml" + tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tstime" + "tailscale.com/types/ptr" "tailscale.com/util/clientmetric" "tailscale.com/util/set" ) @@ -44,10 +46,7 @@ const ( messageMultipleDNSConfigsPresent = "Multiple DNSConfig resources found in cluster. Please ensure no more than one is present." defaultNameserverImageRepo = "tailscale/k8s-nameserver" - // TODO (irbekrm): once we start publishing nameserver images for stable - // track, replace 'unstable' here with the version of this operator - // instance. - defaultNameserverImageTag = "unstable" + defaultNameserverImageTag = "stable" ) // NameserverReconciler knows how to create nameserver resources in cluster in @@ -86,7 +85,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ return reconcile.Result{}, nil } logger.Info("Cleaning up DNSConfig resources") - if err := a.maybeCleanup(ctx, &dnsCfg, logger); err != nil { + if err := a.maybeCleanup(&dnsCfg); err != nil { logger.Errorf("error cleaning up reconciler resource: %v", err) return res, err } @@ -100,12 +99,12 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ } oldCnStatus := dnsCfg.Status.DeepCopy() - setStatus := func(dnsCfg *tsapi.DNSConfig, conditionType tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { + setStatus := func(dnsCfg *tsapi.DNSConfig, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetDNSConfigCondition(dnsCfg, tsapi.NameserverReady, status, reason, message, dnsCfg.Generation, a.clock, logger) - if !apiequality.Semantic.DeepEqual(oldCnStatus, dnsCfg.Status) { + if !apiequality.Semantic.DeepEqual(oldCnStatus, &dnsCfg.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := a.Client.Status().Update(ctx, dnsCfg); updateErr != nil { - err = errors.Wrap(err, updateErr.Error()) + err = errors.Join(err, updateErr) } } return res, err @@ -118,7 +117,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ msg := "invalid cluster configuration: more than one tailscale.com/dnsconfigs found. Please ensure that no more than one is created." logger.Error(msg) a.recorder.Event(&dnsCfg, corev1.EventTypeWarning, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) - setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionFalse, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) + setStatus(&dnsCfg, metav1.ConditionFalse, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) } if !slices.Contains(dnsCfg.Finalizers, FinalizerName) { @@ -127,11 +126,16 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ if err := a.Update(ctx, &dnsCfg); err != nil { msg := fmt.Sprintf(messageNameserverCreationFailed, err) logger.Error(msg) - return setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionFalse, reasonNameserverCreationFailed, msg) + return setStatus(&dnsCfg, metav1.ConditionFalse, reasonNameserverCreationFailed, msg) } } - if err := a.maybeProvision(ctx, &dnsCfg, logger); err != nil { - return reconcile.Result{}, fmt.Errorf("error provisioning nameserver resources: %w", err) + if err = a.maybeProvision(ctx, &dnsCfg); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + return reconcile.Result{}, nil + } else { + return reconcile.Result{}, fmt.Errorf("error provisioning nameserver resources: %w", err) + } } a.mu.Lock() @@ -149,7 +153,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ dnsCfg.Status.Nameserver = &tsapi.NameserverStatus{ IP: ip, } - return setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated) + return setStatus(&dnsCfg, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated) } logger.Info("nameserver Service does not have an IP address allocated, waiting...") return reconcile.Result{}, nil @@ -162,7 +166,7 @@ func nameserverResourceLabels(name, namespace string) map[string]string { return labels } -func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsapi.DNSConfig, logger *zap.SugaredLogger) error { +func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsapi.DNSConfig) error { labels := nameserverResourceLabels(tsDNSCfg.Name, a.tsNamespace) dCfg := &deployConfig{ ownerRefs: []metav1.OwnerReference{*metav1.NewControllerRef(tsDNSCfg, tsapi.SchemeGroupVersion.WithKind("DNSConfig"))}, @@ -170,6 +174,11 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa labels: labels, imageRepo: defaultNameserverImageRepo, imageTag: defaultNameserverImageTag, + replicas: 1, + } + + if tsDNSCfg.Spec.Nameserver.Replicas != nil { + dCfg.replicas = *tsDNSCfg.Spec.Nameserver.Replicas } if tsDNSCfg.Spec.Nameserver.Image != nil && tsDNSCfg.Spec.Nameserver.Image.Repo != "" { dCfg.imageRepo = tsDNSCfg.Spec.Nameserver.Image.Repo @@ -177,6 +186,13 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa if tsDNSCfg.Spec.Nameserver.Image != nil && tsDNSCfg.Spec.Nameserver.Image.Tag != "" { dCfg.imageTag = tsDNSCfg.Spec.Nameserver.Image.Tag } + if tsDNSCfg.Spec.Nameserver.Service != nil { + dCfg.clusterIP = tsDNSCfg.Spec.Nameserver.Service.ClusterIP + } + if tsDNSCfg.Spec.Nameserver.Pod != nil { + dCfg.tolerations = tsDNSCfg.Spec.Nameserver.Pod.Tolerations + } + for _, deployable := range []deployable{saDeployable, deployDeployable, svcDeployable, cmDeployable} { if err := deployable.updateObj(ctx, dCfg, a.Client); err != nil { return fmt.Errorf("error reconciling %s: %w", deployable.kind, err) @@ -188,7 +204,7 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa // maybeCleanup removes DNSConfig from being tracked. The cluster resources // created, will be automatically garbage collected as they are owned by the // DNSConfig. -func (a *NameserverReconciler) maybeCleanup(ctx context.Context, dnsCfg *tsapi.DNSConfig, logger *zap.SugaredLogger) error { +func (a *NameserverReconciler) maybeCleanup(dnsCfg *tsapi.DNSConfig) error { a.mu.Lock() a.managedNameservers.Remove(dnsCfg.UID) a.mu.Unlock() @@ -202,11 +218,14 @@ type deployable struct { } type deployConfig struct { - imageRepo string - imageTag string - labels map[string]string - ownerRefs []metav1.OwnerReference - namespace string + replicas int32 + imageRepo string + imageTag string + labels map[string]string + ownerRefs []metav1.OwnerReference + namespace string + clusterIP string + tolerations []corev1.Toleration } var ( @@ -226,10 +245,12 @@ var ( if err := yaml.Unmarshal(deployYaml, &d); err != nil { return fmt.Errorf("error unmarshalling Deployment yaml: %w", err) } + d.Spec.Replicas = ptr.To(cfg.replicas) d.Spec.Template.Spec.Containers[0].Image = fmt.Sprintf("%s:%s", cfg.imageRepo, cfg.imageTag) d.ObjectMeta.Namespace = cfg.namespace d.ObjectMeta.Labels = cfg.labels d.ObjectMeta.OwnerReferences = cfg.ownerRefs + d.Spec.Template.Spec.Tolerations = cfg.tolerations updateF := func(oldD *appsv1.Deployment) { oldD.Spec = d.Spec } @@ -261,6 +282,7 @@ var ( svc.ObjectMeta.Labels = cfg.labels svc.ObjectMeta.OwnerReferences = cfg.ownerRefs svc.ObjectMeta.Namespace = cfg.namespace + svc.Spec.ClusterIP = cfg.clusterIP _, err := createOrUpdate[corev1.Service](ctx, kubeClient, cfg.namespace, svc, func(*corev1.Service) {}) return err }, diff --git a/cmd/k8s-operator/nameserver_test.go b/cmd/k8s-operator/nameserver_test.go index 695710212..858cd973d 100644 --- a/cmd/k8s-operator/nameserver_test.go +++ b/cmd/k8s-operator/nameserver_test.go @@ -19,109 +19,171 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/yaml" + operatorutils "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/tstest" + "tailscale.com/types/ptr" "tailscale.com/util/mak" ) func TestNameserverReconciler(t *testing.T) { - dnsCfg := &tsapi.DNSConfig{ + dnsConfig := &tsapi.DNSConfig{ TypeMeta: metav1.TypeMeta{Kind: "DNSConfig", APIVersion: "tailscale.com/v1alpha1"}, ObjectMeta: metav1.ObjectMeta{ Name: "test", }, Spec: tsapi.DNSConfigSpec{ Nameserver: &tsapi.Nameserver{ + Replicas: ptr.To[int32](3), Image: &tsapi.NameserverImage{ Repo: "test", Tag: "v0.0.1", }, + Service: &tsapi.NameserverService{ + ClusterIP: "5.4.3.2", + }, + Pod: &tsapi.NameserverPod{ + Tolerations: []corev1.Toleration{ + { + Key: "some-key", + Operator: corev1.TolerationOpEqual, + Value: "some-value", + Effect: corev1.TaintEffectNoSchedule, + }, + }, + }, }, }, } fc := fake.NewClientBuilder(). WithScheme(tsapi.GlobalScheme). - WithObjects(dnsCfg). - WithStatusSubresource(dnsCfg). + WithObjects(dnsConfig). + WithStatusSubresource(dnsConfig). Build() - zl, err := zap.NewDevelopment() + + logger, err := zap.NewDevelopment() if err != nil { t.Fatal(err) } - cl := tstest.NewClock(tstest.ClockOpts{}) - nr := &NameserverReconciler{ + + clock := tstest.NewClock(tstest.ClockOpts{}) + reconciler := &NameserverReconciler{ Client: fc, - clock: cl, - logger: zl.Sugar(), - tsNamespace: "tailscale", - } - expectReconciled(t, nr, "", "test") - // Verify that nameserver Deployment has been created and has the expected fields. - wantsDeploy := &appsv1.Deployment{ObjectMeta: metav1.ObjectMeta{Name: "nameserver", Namespace: "tailscale"}, TypeMeta: metav1.TypeMeta{Kind: "Deployment", APIVersion: appsv1.SchemeGroupVersion.Identifier()}} - if err := yaml.Unmarshal(deployYaml, wantsDeploy); err != nil { - t.Fatalf("unmarshalling yaml: %v", err) + clock: clock, + logger: logger.Sugar(), + tsNamespace: tsNamespace, } - dnsCfgOwnerRef := metav1.NewControllerRef(dnsCfg, tsapi.SchemeGroupVersion.WithKind("DNSConfig")) - wantsDeploy.OwnerReferences = []metav1.OwnerReference{*dnsCfgOwnerRef} - wantsDeploy.Spec.Template.Spec.Containers[0].Image = "test:v0.0.1" - wantsDeploy.Namespace = "tailscale" - labels := nameserverResourceLabels("test", "tailscale") - wantsDeploy.ObjectMeta.Labels = labels - expectEqual(t, fc, wantsDeploy, nil) - - // Verify that DNSConfig advertizes the nameserver's Service IP address, - // has the ready status condition and tailscale finalizer. - mustUpdate(t, fc, "tailscale", "nameserver", func(svc *corev1.Service) { - svc.Spec.ClusterIP = "1.2.3.4" + expectReconciled(t, reconciler, "", "test") + + ownerReference := metav1.NewControllerRef(dnsConfig, tsapi.SchemeGroupVersion.WithKind("DNSConfig")) + nameserverLabels := nameserverResourceLabels(dnsConfig.Name, tsNamespace) + + wantsDeploy := &appsv1.Deployment{ObjectMeta: metav1.ObjectMeta{Name: "nameserver", Namespace: tsNamespace}, TypeMeta: metav1.TypeMeta{Kind: "Deployment", APIVersion: appsv1.SchemeGroupVersion.Identifier()}} + t.Run("deployment has expected fields", func(t *testing.T) { + if err = yaml.Unmarshal(deployYaml, wantsDeploy); err != nil { + t.Fatalf("unmarshalling yaml: %v", err) + } + wantsDeploy.OwnerReferences = []metav1.OwnerReference{*ownerReference} + wantsDeploy.Spec.Template.Spec.Containers[0].Image = "test:v0.0.1" + wantsDeploy.Spec.Replicas = ptr.To[int32](3) + wantsDeploy.Namespace = tsNamespace + wantsDeploy.ObjectMeta.Labels = nameserverLabels + wantsDeploy.Spec.Template.Spec.Tolerations = []corev1.Toleration{ + { + Key: "some-key", + Operator: corev1.TolerationOpEqual, + Value: "some-value", + Effect: corev1.TaintEffectNoSchedule, + }, + } + + expectEqual(t, fc, wantsDeploy) }) - expectReconciled(t, nr, "", "test") - dnsCfg.Status.Nameserver = &tsapi.NameserverStatus{ - IP: "1.2.3.4", - } - dnsCfg.Finalizers = []string{FinalizerName} - dnsCfg.Status.Conditions = append(dnsCfg.Status.Conditions, metav1.Condition{ - Type: string(tsapi.NameserverReady), - Status: metav1.ConditionTrue, - Reason: reasonNameserverCreated, - Message: reasonNameserverCreated, - LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + + wantsSvc := &corev1.Service{ObjectMeta: metav1.ObjectMeta{Name: "nameserver", Namespace: tsNamespace}, TypeMeta: metav1.TypeMeta{Kind: "Service", APIVersion: corev1.SchemeGroupVersion.Identifier()}} + t.Run("service has expected fields", func(t *testing.T) { + if err = yaml.Unmarshal(svcYaml, wantsSvc); err != nil { + t.Fatalf("unmarshalling yaml: %v", err) + } + wantsSvc.Spec.ClusterIP = dnsConfig.Spec.Nameserver.Service.ClusterIP + wantsSvc.OwnerReferences = []metav1.OwnerReference{*ownerReference} + wantsSvc.Namespace = tsNamespace + wantsSvc.ObjectMeta.Labels = nameserverLabels + expectEqual(t, fc, wantsSvc) }) - expectEqual(t, fc, dnsCfg, nil) - // // Verify that nameserver image gets updated to match DNSConfig spec. - mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { - dnsCfg.Spec.Nameserver.Image.Tag = "v0.0.2" + t.Run("dns config status is set", func(t *testing.T) { + // Verify that DNSConfig advertizes the nameserver's Service IP address, + // has the ready status condition and tailscale finalizer. + mustUpdate(t, fc, "tailscale", "nameserver", func(svc *corev1.Service) { + svc.Spec.ClusterIP = "1.2.3.4" + }) + expectReconciled(t, reconciler, "", "test") + + dnsConfig.Finalizers = []string{FinalizerName} + dnsConfig.Status.Nameserver = &tsapi.NameserverStatus{ + IP: "1.2.3.4", + } + dnsConfig.Status.Conditions = append(dnsConfig.Status.Conditions, metav1.Condition{ + Type: string(tsapi.NameserverReady), + Status: metav1.ConditionTrue, + Reason: reasonNameserverCreated, + Message: reasonNameserverCreated, + LastTransitionTime: metav1.Time{Time: clock.Now().Truncate(time.Second)}, + }) + + expectEqual(t, fc, dnsConfig) }) - expectReconciled(t, nr, "", "test") - wantsDeploy.Spec.Template.Spec.Containers[0].Image = "test:v0.0.2" - expectEqual(t, fc, wantsDeploy, nil) - - // Verify that when another actor sets ConfigMap data, it does not get - // overwritten by nameserver reconciler. - dnsRecords := &operatorutils.Records{Version: "v1alpha1", IP4: map[string][]string{"foo.ts.net": {"1.2.3.4"}}} - bs, err := json.Marshal(dnsRecords) - if err != nil { - t.Fatalf("error marshalling ConfigMap contents: %v", err) - } - mustUpdate(t, fc, "tailscale", "dnsrecords", func(cm *corev1.ConfigMap) { - mak.Set(&cm.Data, "records.json", string(bs)) + + t.Run("nameserver image can be updated", func(t *testing.T) { + // Verify that nameserver image gets updated to match DNSConfig spec. + mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { + dnsCfg.Spec.Nameserver.Image.Tag = "v0.0.2" + }) + expectReconciled(t, reconciler, "", "test") + wantsDeploy.Spec.Template.Spec.Containers[0].Image = "test:v0.0.2" + expectEqual(t, fc, wantsDeploy) + }) + + t.Run("reconciler does not overwrite custom configuration", func(t *testing.T) { + // Verify that when another actor sets ConfigMap data, it does not get + // overwritten by nameserver reconciler. + dnsRecords := &operatorutils.Records{Version: "v1alpha1", IP4: map[string][]string{"foo.ts.net": {"1.2.3.4"}}} + bs, err := json.Marshal(dnsRecords) + if err != nil { + t.Fatalf("error marshalling ConfigMap contents: %v", err) + } + + mustUpdate(t, fc, "tailscale", "dnsrecords", func(cm *corev1.ConfigMap) { + mak.Set(&cm.Data, "records.json", string(bs)) + }) + + expectReconciled(t, reconciler, "", "test") + + wantCm := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "dnsrecords", + Namespace: "tailscale", + Labels: nameserverLabels, + OwnerReferences: []metav1.OwnerReference{*ownerReference}, + }, + TypeMeta: metav1.TypeMeta{Kind: "ConfigMap", APIVersion: "v1"}, + Data: map[string]string{"records.json": string(bs)}, + } + + expectEqual(t, fc, wantCm) }) - expectReconciled(t, nr, "", "test") - wantCm := &corev1.ConfigMap{ObjectMeta: metav1.ObjectMeta{Name: "dnsrecords", - Namespace: "tailscale", Labels: labels, OwnerReferences: []metav1.OwnerReference{*dnsCfgOwnerRef}}, - TypeMeta: metav1.TypeMeta{Kind: "ConfigMap", APIVersion: "v1"}, - Data: map[string]string{"records.json": string(bs)}, - } - expectEqual(t, fc, wantCm, nil) - // Verify that if dnsconfig.spec.nameserver.image.{repo,tag} are unset, - // the nameserver image defaults to tailscale/k8s-nameserver:unstable. - mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { - dnsCfg.Spec.Nameserver.Image = nil + t.Run("uses default nameserver image", func(t *testing.T) { + // Verify that if dnsconfig.spec.nameserver.image.{repo,tag} are unset, + // the nameserver image defaults to tailscale/k8s-nameserver:unstable. + mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { + dnsCfg.Spec.Nameserver.Image = nil + }) + expectReconciled(t, reconciler, "", "test") + wantsDeploy.Spec.Template.Spec.Containers[0].Image = "tailscale/k8s-nameserver:stable" + expectEqual(t, fc, wantsDeploy) }) - expectReconciled(t, nr, "", "test") - wantsDeploy.Spec.Template.Spec.Containers[0].Image = "tailscale/k8s-nameserver:unstable" - expectEqual(t, fc, wantsDeploy, nil) } diff --git a/cmd/k8s-operator/nodeport-service-ports.go b/cmd/k8s-operator/nodeport-service-ports.go new file mode 100644 index 000000000..a9504e3e9 --- /dev/null +++ b/cmd/k8s-operator/nodeport-service-ports.go @@ -0,0 +1,203 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "fmt" + "math/rand/v2" + "regexp" + "sort" + "strconv" + "strings" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "sigs.k8s.io/controller-runtime/pkg/client" + k8soperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" +) + +const ( + tailscaledPortMax = 65535 + tailscaledPortMin = 1024 + testSvcName = "test-node-port-range" + + invalidSvcNodePort = 777777 +) + +// getServicesNodePortRange is a hacky function that attempts to determine Service NodePort range by +// creating a deliberately invalid Service with a NodePort that is too large and parsing the returned +// validation error. Returns nil if unable to determine port range. +// https://kubernetes.io/docs/concepts/services-networking/service/#type-nodeport +func getServicesNodePortRange(ctx context.Context, c client.Client, tsNamespace string, logger *zap.SugaredLogger) *tsapi.PortRange { + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: testSvcName, + Namespace: tsNamespace, + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeNodePort, + Ports: []corev1.ServicePort{ + { + Name: testSvcName, + Port: 8080, + TargetPort: intstr.FromInt32(8080), + Protocol: corev1.ProtocolUDP, + NodePort: invalidSvcNodePort, + }, + }, + }, + } + + // NOTE(ChaosInTheCRD): ideally this would be a server side dry-run but could not get it working + err := c.Create(ctx, svc) + if err == nil { + return nil + } + + if validPorts := getServicesNodePortRangeFromErr(err.Error()); validPorts != "" { + pr, err := parseServicesNodePortRange(validPorts) + if err != nil { + logger.Debugf("failed to parse NodePort range set for Kubernetes Cluster: %w", err) + return nil + } + + return pr + } + + return nil +} + +func getServicesNodePortRangeFromErr(err string) string { + reg := regexp.MustCompile(`\d{1,5}-\d{1,5}`) + matches := reg.FindAllString(err, -1) + if len(matches) != 1 { + return "" + } + + return matches[0] +} + +// parseServicesNodePortRange converts the `ValidPorts` string field in the Kubernetes PortAllocator error and converts it to +// PortRange +func parseServicesNodePortRange(p string) (*tsapi.PortRange, error) { + parts := strings.Split(p, "-") + s, err := strconv.ParseUint(parts[0], 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse string as uint16: %w", err) + } + + var e uint64 + switch len(parts) { + case 1: + e = uint64(s) + case 2: + e, err = strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return nil, fmt.Errorf("failed to parse string as uint16: %w", err) + } + default: + return nil, fmt.Errorf("failed to parse port range %q", p) + } + + portRange := &tsapi.PortRange{Port: uint16(s), EndPort: uint16(e)} + if !portRange.IsValid() { + return nil, fmt.Errorf("port range %q is not valid", portRange.String()) + } + + return portRange, nil +} + +// validateNodePortRanges checks that the port range specified is valid. It also ensures that the specified ranges +// lie within the NodePort Service port range specified for the Kubernetes API Server. +func validateNodePortRanges(ctx context.Context, c client.Client, kubeRange *tsapi.PortRange, pc *tsapi.ProxyClass) error { + if pc.Spec.StaticEndpoints == nil { + return nil + } + + portRanges := pc.Spec.StaticEndpoints.NodePort.Ports + + if kubeRange != nil { + for _, pr := range portRanges { + if !kubeRange.Contains(pr.Port) || (pr.EndPort != 0 && !kubeRange.Contains(pr.EndPort)) { + return fmt.Errorf("range %q is not within Cluster configured range %q", pr.String(), kubeRange.String()) + } + } + } + + for _, r := range portRanges { + if !r.IsValid() { + return fmt.Errorf("port range %q is invalid", r.String()) + } + } + + // TODO(ChaosInTheCRD): if a ProxyClass that made another invalid (due to port range clash) is deleted, + // the invalid ProxyClass doesn't get reconciled on, and therefore will not go valid. We should fix this. + proxyClassRanges, err := getPortsForProxyClasses(ctx, c) + if err != nil { + return fmt.Errorf("failed to get port ranges for ProxyClasses: %w", err) + } + + for _, r := range portRanges { + for pcName, pcr := range proxyClassRanges { + if pcName == pc.Name { + continue + } + if pcr.ClashesWith(r) { + return fmt.Errorf("port ranges for ProxyClass %q clash with existing ProxyClass %q", pc.Name, pcName) + } + } + } + + if len(portRanges) == 1 { + return nil + } + + sort.Slice(portRanges, func(i, j int) bool { + return portRanges[i].Port < portRanges[j].Port + }) + + for i := 1; i < len(portRanges); i++ { + prev := portRanges[i-1] + curr := portRanges[i] + if curr.Port <= prev.Port || curr.Port <= prev.EndPort { + return fmt.Errorf("overlapping ranges: %q and %q", prev.String(), curr.String()) + } + } + + return nil +} + +// getPortsForProxyClasses gets the port ranges for all the other existing ProxyClasses +func getPortsForProxyClasses(ctx context.Context, c client.Client) (map[string]tsapi.PortRanges, error) { + pcs := new(tsapi.ProxyClassList) + + err := c.List(ctx, pcs) + if err != nil { + return nil, fmt.Errorf("failed to list ProxyClasses: %w", err) + } + + portRanges := make(map[string]tsapi.PortRanges) + for _, i := range pcs.Items { + if !k8soperator.ProxyClassIsReady(&i) { + continue + } + if se := i.Spec.StaticEndpoints; se != nil && se.NodePort != nil { + portRanges[i.Name] = se.NodePort.Ports + } + } + + return portRanges, nil +} + +func getRandomPort() uint16 { + return uint16(rand.IntN(tailscaledPortMax-tailscaledPortMin+1) + tailscaledPortMin) +} diff --git a/cmd/k8s-operator/nodeport-services-ports_test.go b/cmd/k8s-operator/nodeport-services-ports_test.go new file mode 100644 index 000000000..9418bb844 --- /dev/null +++ b/cmd/k8s-operator/nodeport-services-ports_test.go @@ -0,0 +1,277 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "testing" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/tstest" +) + +func TestGetServicesNodePortRangeFromErr(t *testing.T) { + tests := []struct { + name string + errStr string + want string + }{ + { + name: "valid_error_string", + errStr: "NodePort 777777 is not in the allowed range 30000-32767", + want: "30000-32767", + }, + { + name: "error_string_with_different_message", + errStr: "some other error without a port range", + want: "", + }, + { + name: "error_string_with_multiple_port_ranges", + errStr: "range 1000-2000 and another range 3000-4000", + want: "", + }, + { + name: "empty_error_string", + errStr: "", + want: "", + }, + { + name: "error_string_with_range_at_start", + errStr: "30000-32767 is the range", + want: "30000-32767", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getServicesNodePortRangeFromErr(tt.errStr); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseServicesNodePortRange(t *testing.T) { + tests := []struct { + name string + p string + want *tsapi.PortRange + wantErr bool + }{ + { + name: "valid_range", + p: "30000-32767", + want: &tsapi.PortRange{Port: 30000, EndPort: 32767}, + wantErr: false, + }, + { + name: "single_port_range", + p: "30000", + want: &tsapi.PortRange{Port: 30000, EndPort: 30000}, + wantErr: false, + }, + { + name: "invalid_format_non_numeric_end", + p: "30000-abc", + want: nil, + wantErr: true, + }, + { + name: "invalid_format_non_numeric_start", + p: "abc-32767", + want: nil, + wantErr: true, + }, + { + name: "empty_string", + p: "", + want: nil, + wantErr: true, + }, + { + name: "too_many_parts", + p: "1-2-3", + want: nil, + wantErr: true, + }, + { + name: "port_too_large_start", + p: "65536-65537", + want: nil, + wantErr: true, + }, + { + name: "port_too_large_end", + p: "30000-65536", + want: nil, + wantErr: true, + }, + { + name: "inverted_range", + p: "32767-30000", + want: nil, + wantErr: true, // IsValid() will fail + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + portRange, err := parseServicesNodePortRange(tt.p) + if (err != nil) != tt.wantErr { + t.Errorf("error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } + + if portRange == nil { + t.Fatalf("got nil port range, expected %v", tt.want) + } + + if portRange.Port != tt.want.Port || portRange.EndPort != tt.want.EndPort { + t.Errorf("got = %v, want %v", portRange, tt.want) + } + }) + } +} + +func TestValidateNodePortRanges(t *testing.T) { + tests := []struct { + name string + portRanges []tsapi.PortRange + wantErr bool + }{ + { + name: "valid_ranges_with_unknown_kube_range", + portRanges: []tsapi.PortRange{ + {Port: 30003, EndPort: 30005}, + {Port: 30006, EndPort: 30007}, + }, + wantErr: false, + }, + { + name: "overlapping_ranges", + portRanges: []tsapi.PortRange{ + {Port: 30000, EndPort: 30010}, + {Port: 30005, EndPort: 30015}, + }, + wantErr: true, + }, + { + name: "adjacent_ranges_no_overlap", + portRanges: []tsapi.PortRange{ + {Port: 30010, EndPort: 30020}, + {Port: 30021, EndPort: 30022}, + }, + wantErr: false, + }, + { + name: "identical_ranges_are_overlapping", + portRanges: []tsapi.PortRange{ + {Port: 30005, EndPort: 30010}, + {Port: 30005, EndPort: 30010}, + }, + wantErr: true, + }, + { + name: "range_clashes_with_existing_proxyclass", + portRanges: []tsapi.PortRange{ + {Port: 31005, EndPort: 32070}, + }, + wantErr: true, + }, + } + + // as part of this test, we want to create an adjacent ProxyClass in order to ensure that if it clashes with the one created in this test + // that we get an error + cl := tstest.NewClock(tstest.ClockOpts{}) + opc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "other-pc", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Annotations: defaultProxyClassAnnotations, + }, + StaticEndpoints: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 31000}, {Port: 32000}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + }, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + Reason: reasonProxyClassValid, + Message: reasonProxyClassValid, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + }}, + }, + } + + fc := fake.NewClientBuilder(). + WithObjects(opc). + WithStatusSubresource(opc). + WithScheme(tsapi.GlobalScheme). + Build() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pc", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Annotations: defaultProxyClassAnnotations, + }, + StaticEndpoints: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: tt.portRanges, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + }, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + Reason: reasonProxyClassValid, + Message: reasonProxyClassValid, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + }}, + }, + } + err := validateNodePortRanges(context.Background(), fc, &tsapi.PortRange{Port: 30000, EndPort: 32767}, pc) + if (err != nil) != tt.wantErr { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestGetRandomPort(t *testing.T) { + for range 100 { + port := getRandomPort() + if port < tailscaledPortMin || port > tailscaledPortMax { + t.Errorf("generated port %d which is out of range [%d, %d]", port, tailscaledPortMin, tailscaledPortMax) + } + } +} diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index 5255d4f29..6b545a827 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -9,22 +9,30 @@ package main import ( "context" + "fmt" + "net/http" "os" "regexp" + "strconv" "strings" "time" "github.com/go-logr/zapr" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "golang.org/x/oauth2/clientcredentials" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" discoveryv1 "k8s.io/api/discovery/v1" networkingv1 "k8s.io/api/networking/v1" rbacv1 "k8s.io/api/rbac/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + klabels "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/rest" + toolscache "k8s.io/client-go/tools/cache" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" @@ -34,16 +42,22 @@ import ( kzap "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/manager/signals" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/envknob" + + "tailscale.com/client/local" "tailscale.com/client/tailscale" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/store/kubestore" + apiproxy "tailscale.com/k8s-operator/api-proxy" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tsnet" "tailscale.com/tstime" "tailscale.com/types/logger" + "tailscale.com/util/set" "tailscale.com/version" ) @@ -53,6 +67,9 @@ import ( // Generate static manifests for deploying Tailscale operator on Kubernetes from the operator's Helm chart. //go:generate go run tailscale.com/cmd/k8s-operator/generate staticmanifests +// Generate the helm chart's CRDs (which are ignored from git). +//go:generate go run tailscale.com/cmd/k8s-operator/generate helmcrd + // Generate CRD API docs. //go:generate go run github.com/elastic/crd-ref-docs --renderer=markdown --source-path=../../k8s-operator/apis/ --config=../../k8s-operator/api-docs-config.yaml --output-path=../../k8s-operator/api.md @@ -65,11 +82,14 @@ func main() { tsNamespace = defaultEnv("OPERATOR_NAMESPACE", "") tslogging = defaultEnv("OPERATOR_LOGGING", "info") image = defaultEnv("PROXY_IMAGE", "tailscale/tailscale:latest") + k8sProxyImage = defaultEnv("K8S_PROXY_IMAGE", "tailscale/k8s-proxy:latest") priorityClassName = defaultEnv("PROXY_PRIORITY_CLASS_NAME", "") tags = defaultEnv("PROXY_TAGS", "tag:k8s") tsFirewallMode = defaultEnv("PROXY_FIREWALL_MODE", "") defaultProxyClass = defaultEnv("PROXY_DEFAULT_CLASS", "") isDefaultLoadBalancer = defaultBool("OPERATOR_DEFAULT_LOAD_BALANCER", false) + loginServer = strings.TrimSuffix(defaultEnv("OPERATOR_LOGIN_SERVER", ""), "/") + ingressClassName = defaultEnv("OPERATOR_INGRESS_CLASS_NAME", "tailscale") ) var opts []kzap.Opts @@ -84,70 +104,98 @@ func main() { zlog := kzap.NewRaw(opts...).Sugar() logf.SetLogger(zapr.NewLogger(zlog.Desugar())) + if tsNamespace == "" { + const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + b, err := os.ReadFile(namespaceFile) + if err != nil { + zlog.Fatalf("Could not get operator namespace from OPERATOR_NAMESPACE environment variable or default projected volume: %v", err) + } + tsNamespace = strings.TrimSpace(string(b)) + } + // The operator can run either as a plain operator or it can // additionally act as api-server proxy // https://tailscale.com/kb/1236/kubernetes-operator/?q=kubernetes#accessing-the-kubernetes-control-plane-using-an-api-server-proxy. mode := parseAPIProxyMode() - if mode == apiserverProxyModeDisabled { + if mode == nil { hostinfo.SetApp(kubetypes.AppOperator) } else { - hostinfo.SetApp(kubetypes.AppAPIServerProxy) + hostinfo.SetApp(kubetypes.AppInProcessAPIServerProxy) } - s, tsClient := initTSNet(zlog) + s, tsc := initTSNet(zlog, loginServer) defer s.Close() restConfig := config.GetConfigOrDie() - maybeLaunchAPIServerProxy(zlog, restConfig, s, mode) + if mode != nil { + ap, err := apiproxy.NewAPIServerProxy(zlog, restConfig, s, *mode, true) + if err != nil { + zlog.Fatalf("error creating API server proxy: %v", err) + } + go func() { + if err := ap.Run(context.Background()); err != nil { + zlog.Fatalf("error running API server proxy: %v", err) + } + }() + } + + // Operator log uploads can be opted-out using the "TS_NO_LOGS_NO_SUPPORT" environment variable. + if !envknob.NoLogsNoSupport() { + zlog = zlog.WithOptions(zap.WrapCore(func(core zapcore.Core) zapcore.Core { + return wrapZapCore(core, s.LogtailWriter()) + })) + } + rOpts := reconcilerOpts{ log: zlog, tsServer: s, - tsClient: tsClient, + tsClient: tsc, tailscaleNamespace: tsNamespace, restConfig: restConfig, proxyImage: image, + k8sProxyImage: k8sProxyImage, proxyPriorityClassName: priorityClassName, proxyActAsDefaultLoadBalancer: isDefaultLoadBalancer, proxyTags: tags, proxyFirewallMode: tsFirewallMode, - proxyDefaultClass: defaultProxyClass, + defaultProxyClass: defaultProxyClass, + loginServer: loginServer, + ingressClassName: ingressClassName, } runReconcilers(rOpts) } -// initTSNet initializes the tsnet.Server and logs in to Tailscale. It uses the -// CLIENT_ID_FILE and CLIENT_SECRET_FILE environment variables to authenticate -// with Tailscale. -func initTSNet(zlog *zap.SugaredLogger) (*tsnet.Server, *tailscale.Client) { +// initTSNet initializes the tsnet.Server and logs in to Tailscale. If CLIENT_ID +// is set, it authenticates to the Tailscale API using the federated OIDC workload +// identity flow. Otherwise, it uses the CLIENT_ID_FILE and CLIENT_SECRET_FILE +// environment variables to authenticate with static credentials. +func initTSNet(zlog *zap.SugaredLogger, loginServer string) (*tsnet.Server, tsClient) { var ( - clientIDPath = defaultEnv("CLIENT_ID_FILE", "") - clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") + clientID = defaultEnv("CLIENT_ID", "") // Used for workload identity federation. + clientIDPath = defaultEnv("CLIENT_ID_FILE", "") // Used for static client credentials. + clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") // Used for static client credentials. hostname = defaultEnv("OPERATOR_HOSTNAME", "tailscale-operator") kubeSecret = defaultEnv("OPERATOR_SECRET", "") operatorTags = defaultEnv("OPERATOR_INITIAL_TAGS", "tag:k8s-operator") ) startlog := zlog.Named("startup") - if clientIDPath == "" || clientSecretPath == "" { - startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") + if clientID == "" && (clientIDPath == "" || clientSecretPath == "") { + startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") // TODO(tomhjp): error message can mention WIF once it's publicly available. } - clientID, err := os.ReadFile(clientIDPath) + tsc, err := newTSClient(zlog.Named("ts-api-client"), clientID, clientIDPath, clientSecretPath, loginServer) if err != nil { - startlog.Fatalf("reading client ID %q: %v", clientIDPath, err) + startlog.Fatalf("error creating Tailscale client: %v", err) } - clientSecret, err := os.ReadFile(clientSecretPath) - if err != nil { - startlog.Fatalf("reading client secret %q: %v", clientSecretPath, err) - } - credentials := clientcredentials.Config{ - ClientID: string(clientID), - ClientSecret: string(clientSecret), - TokenURL: "https://login.tailscale.com/api/v2/oauth/token", - } - tsClient := tailscale.NewClient("-", nil) - tsClient.HTTPClient = credentials.Client(context.Background()) - s := &tsnet.Server{ - Hostname: hostname, - Logf: zlog.Named("tailscaled").Debugf, + Hostname: hostname, + Logf: zlog.Named("tailscaled").Debugf, + ControlURL: loginServer, + } + if p := os.Getenv("TS_PORT"); p != "" { + port, err := strconv.ParseUint(p, 10, 16) + if err != nil { + startlog.Fatalf("TS_PORT %q cannot be parsed as uint16: %v", p, err) + } + s.Port = uint16(port) } if kubeSecret != "" { st, err := kubestore.New(logger.Discard, kubeSecret) @@ -190,7 +238,7 @@ waitOnline: }, }, } - authkey, _, err := tsClient.CreateKey(ctx, caps) + authkey, _, err := tsc.CreateKey(ctx, caps) if err != nil { startlog.Fatalf("creating operator authkey: %v", err) } @@ -214,7 +262,18 @@ waitOnline: } time.Sleep(time.Second) } - return s, tsClient + return s, tsc +} + +// predicate function for filtering to ensure we *don't* reconcile on tailscale managed Kubernetes Services +func serviceManagedResourceFilterPredicate() predicate.Predicate { + return predicate.NewPredicateFuncs(func(object client.Object) bool { + if svc, ok := object.(*corev1.Service); !ok { + return false + } else { + return !isManagedResource(svc) + } + }) } // runReconcilers starts the controller-runtime manager and registers the @@ -230,21 +289,32 @@ func runReconcilers(opts reconcilerOpts) { nsFilter := cache.ByObject{ Field: client.InNamespace(opts.tailscaleNamespace).AsSelector(), } + + // We watch the ServiceMonitor CRD to ensure that reconcilers are re-triggered if user's workflows result in the + // ServiceMonitor CRD applied after some of our resources that define ServiceMonitor creation. This selector + // ensures that we only watch the ServiceMonitor CRD and that we don't cache full contents of it. + serviceMonitorSelector := cache.ByObject{ + Field: fields.SelectorFromSet(fields.Set{"metadata.name": serviceMonitorCRD}), + Transform: crdTransformer(startlog), + } + + // TODO (irbekrm): stricter filtering what we watch/cache/call + // reconcilers on. c/r by default starts a watch on any + // resources that we GET via the controller manager's client. mgrOpts := manager.Options{ - // TODO (irbekrm): stricter filtering what we watch/cache/call - // reconcilers on. c/r by default starts a watch on any - // resources that we GET via the controller manager's client. + // The cache will apply the specified filters only to the object types listed below via ByObject. + // Other object types (e.g., EndpointSlices) can still be fetched or watched using the cached client, but they will not have any filtering applied. Cache: cache.Options{ ByObject: map[client.Object]cache.ByObject{ - &corev1.Secret{}: nsFilter, - &corev1.ServiceAccount{}: nsFilter, - &corev1.Pod{}: nsFilter, - &corev1.ConfigMap{}: nsFilter, - &appsv1.StatefulSet{}: nsFilter, - &appsv1.Deployment{}: nsFilter, - &discoveryv1.EndpointSlice{}: nsFilter, - &rbacv1.Role{}: nsFilter, - &rbacv1.RoleBinding{}: nsFilter, + &corev1.Secret{}: nsFilter, + &corev1.ServiceAccount{}: nsFilter, + &corev1.Pod{}: nsFilter, + &corev1.ConfigMap{}: nsFilter, + &appsv1.StatefulSet{}: nsFilter, + &appsv1.Deployment{}: nsFilter, + &rbacv1.Role{}: nsFilter, + &rbacv1.RoleBinding{}: nsFilter, + &apiextensionsv1.CustomResourceDefinition{}: serviceMonitorSelector, }, }, Scheme: tsapi.GlobalScheme, @@ -270,7 +340,9 @@ func runReconcilers(opts reconcilerOpts) { proxyImage: opts.proxyImage, proxyPriorityClassName: opts.proxyPriorityClassName, tsFirewallMode: opts.proxyFirewallMode, + loginServer: opts.tsServer.ControlURL, } + err = builder. ControllerManagedBy(mgr). Named("service-reconciler"). @@ -286,20 +358,25 @@ func runReconcilers(opts reconcilerOpts) { recorder: eventRecorder, tsNamespace: opts.tailscaleNamespace, clock: tstime.DefaultClock{}, - proxyDefaultClass: opts.proxyDefaultClass, + defaultProxyClass: opts.defaultProxyClass, }) if err != nil { startlog.Fatalf("could not create service reconciler: %v", err) } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(corev1.Service), indexServiceProxyClass, indexProxyClass); err != nil { + startlog.Fatalf("failed setting up ProxyClass indexer for Services: %v", err) + } + ingressChildFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("ingress")) // If a ProxyClassChanges, enqueue all Ingresses labeled with that // ProxyClass's name. proxyClassFilterForIngress := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForIngress(mgr.GetClient(), startlog)) // Enque Ingress if a managed Service or backend Service associated with a tailscale Ingress changes. - svcHandlerForIngress := handler.EnqueueRequestsFromMapFunc(serviceHandlerForIngress(mgr.GetClient(), startlog)) + svcHandlerForIngress := handler.EnqueueRequestsFromMapFunc(serviceHandlerForIngress(mgr.GetClient(), startlog, opts.ingressClassName)) err = builder. ControllerManagedBy(mgr). For(&networkingv1.Ingress{}). + Named("ingress-reconciler"). Watches(&appsv1.StatefulSet{}, ingressChildFilter). Watches(&corev1.Secret{}, ingressChildFilter). Watches(&corev1.Service{}, svcHandlerForIngress). @@ -309,11 +386,76 @@ func runReconcilers(opts reconcilerOpts) { recorder: eventRecorder, Client: mgr.GetClient(), logger: opts.log.Named("ingress-reconciler"), - proxyDefaultClass: opts.proxyDefaultClass, + defaultProxyClass: opts.defaultProxyClass, + ingressClassName: opts.ingressClassName, }) if err != nil { startlog.Fatalf("could not create ingress reconciler: %v", err) } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(networkingv1.Ingress), indexIngressProxyClass, indexProxyClass); err != nil { + startlog.Fatalf("failed setting up ProxyClass indexer for Ingresses: %v", err) + } + + lc, err := opts.tsServer.LocalClient() + if err != nil { + startlog.Fatalf("could not get local client: %v", err) + } + id, err := id(context.Background(), lc) + if err != nil { + startlog.Fatalf("error determining stable ID of the operator's Tailscale device: %v", err) + } + ingressProxyGroupFilter := handler.EnqueueRequestsFromMapFunc(ingressesFromIngressProxyGroup(mgr.GetClient(), opts.log)) + err = builder. + ControllerManagedBy(mgr). + For(&networkingv1.Ingress{}). + Named("ingress-pg-reconciler"). + Watches(&corev1.Service{}, handler.EnqueueRequestsFromMapFunc(serviceHandlerForIngressPG(mgr.GetClient(), startlog, opts.ingressClassName))). + Watches(&corev1.Secret{}, handler.EnqueueRequestsFromMapFunc(HAIngressesFromSecret(mgr.GetClient(), startlog))). + Watches(&tsapi.ProxyGroup{}, ingressProxyGroupFilter). + Complete(&HAIngressReconciler{ + recorder: eventRecorder, + tsClient: opts.tsClient, + tsnetServer: opts.tsServer, + defaultTags: strings.Split(opts.proxyTags, ","), + Client: mgr.GetClient(), + logger: opts.log.Named("ingress-pg-reconciler"), + lc: lc, + operatorID: id, + tsNamespace: opts.tailscaleNamespace, + ingressClassName: opts.ingressClassName, + }) + if err != nil { + startlog.Fatalf("could not create ingress-pg-reconciler: %v", err) + } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(networkingv1.Ingress), indexIngressProxyGroup, indexPGIngresses); err != nil { + startlog.Fatalf("failed setting up indexer for HA Ingresses: %v", err) + } + + ingressSvcFromEpsFilter := handler.EnqueueRequestsFromMapFunc(ingressSvcFromEps(mgr.GetClient(), opts.log.Named("service-pg-reconciler"))) + err = builder. + ControllerManagedBy(mgr). + For(&corev1.Service{}, builder.WithPredicates(serviceManagedResourceFilterPredicate())). + Named("service-pg-reconciler"). + Watches(&corev1.Secret{}, handler.EnqueueRequestsFromMapFunc(HAServicesFromSecret(mgr.GetClient(), startlog))). + Watches(&tsapi.ProxyGroup{}, ingressProxyGroupFilter). + Watches(&discoveryv1.EndpointSlice{}, ingressSvcFromEpsFilter). + Complete(&HAServiceReconciler{ + recorder: eventRecorder, + tsClient: opts.tsClient, + defaultTags: strings.Split(opts.proxyTags, ","), + Client: mgr.GetClient(), + logger: opts.log.Named("service-pg-reconciler"), + lc: lc, + clock: tstime.DefaultClock{}, + operatorID: id, + tsNamespace: opts.tailscaleNamespace, + }) + if err != nil { + startlog.Fatalf("could not create service-pg-reconciler: %v", err) + } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(corev1.Service), indexIngressProxyGroup, indexPGIngresses); err != nil { + startlog.Fatalf("failed setting up indexer for HA Services: %v", err) + } connectorFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("connector")) // If a ProxyClassChanges, enqueue all Connectors that have @@ -321,6 +463,7 @@ func runReconcilers(opts reconcilerOpts) { proxyClassFilterForConnector := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForConnector(mgr.GetClient(), startlog)) err = builder.ControllerManagedBy(mgr). For(&tsapi.Connector{}). + Named("connector-reconciler"). Watches(&appsv1.StatefulSet{}, connectorFilter). Watches(&corev1.Secret{}, connectorFilter). Watches(&tsapi.ProxyClass{}, proxyClassFilterForConnector). @@ -340,6 +483,7 @@ func runReconcilers(opts reconcilerOpts) { nameserverFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("nameserver")) err = builder.ControllerManagedBy(mgr). For(&tsapi.DNSConfig{}). + Named("nameserver-reconciler"). Watches(&appsv1.Deployment{}, nameserverFilter). Watches(&corev1.ConfigMap{}, nameserverFilter). Watches(&corev1.Service{}, nameserverFilter). @@ -356,12 +500,12 @@ func runReconcilers(opts reconcilerOpts) { } egressSvcFilter := handler.EnqueueRequestsFromMapFunc(egressSvcsHandler) - proxyGroupFilter := handler.EnqueueRequestsFromMapFunc(egressSvcsFromEgressProxyGroup(mgr.GetClient(), opts.log)) + egressProxyGroupFilter := handler.EnqueueRequestsFromMapFunc(egressSvcsFromEgressProxyGroup(mgr.GetClient(), opts.log)) err = builder. ControllerManagedBy(mgr). Named("egress-svcs-reconciler"). Watches(&corev1.Service{}, egressSvcFilter). - Watches(&tsapi.ProxyGroup{}, proxyGroupFilter). + Watches(&tsapi.ProxyGroup{}, egressProxyGroupFilter). Complete(&egressSvcsReconciler{ Client: mgr.GetClient(), tsNamespace: opts.tailscaleNamespace, @@ -376,16 +520,33 @@ func runReconcilers(opts reconcilerOpts) { startlog.Fatalf("failed setting up indexer for egress Services: %v", err) } + egressSvcFromEpsFilter := handler.EnqueueRequestsFromMapFunc(egressSvcFromEps) + err = builder. + ControllerManagedBy(mgr). + Named("egress-svcs-readiness-reconciler"). + Watches(&corev1.Service{}, egressSvcFilter). + Watches(&discoveryv1.EndpointSlice{}, egressSvcFromEpsFilter). + Complete(&egressSvcsReadinessReconciler{ + Client: mgr.GetClient(), + tsNamespace: opts.tailscaleNamespace, + clock: tstime.DefaultClock{}, + logger: opts.log.Named("egress-svcs-readiness-reconciler"), + }) + if err != nil { + startlog.Fatalf("could not create egress Services readiness reconciler: %v", err) + } + epsFilter := handler.EnqueueRequestsFromMapFunc(egressEpsHandler) - podsSecretsFilter := handler.EnqueueRequestsFromMapFunc(egressEpsFromEgressPGChildResources(mgr.GetClient(), opts.log, opts.tailscaleNamespace)) - epsFromExtNSvcFilter := handler.EnqueueRequestsFromMapFunc(epsFromExternalNameService(mgr.GetClient(), opts.log)) + podsFilter := handler.EnqueueRequestsFromMapFunc(egressEpsFromPGPods(mgr.GetClient(), opts.tailscaleNamespace)) + secretsFilter := handler.EnqueueRequestsFromMapFunc(egressEpsFromPGStateSecrets(mgr.GetClient(), opts.tailscaleNamespace)) + epsFromExtNSvcFilter := handler.EnqueueRequestsFromMapFunc(epsFromExternalNameService(mgr.GetClient(), opts.log, opts.tailscaleNamespace)) err = builder. ControllerManagedBy(mgr). Named("egress-eps-reconciler"). Watches(&discoveryv1.EndpointSlice{}, epsFilter). - Watches(&corev1.Pod{}, podsSecretsFilter). - Watches(&corev1.Secret{}, podsSecretsFilter). + Watches(&corev1.Pod{}, podsFilter). + Watches(&corev1.Secret{}, secretsFilter). Watches(&corev1.Service{}, epsFromExtNSvcFilter). Complete(&egressEpsReconciler{ Client: mgr.GetClient(), @@ -396,13 +557,40 @@ func runReconcilers(opts reconcilerOpts) { startlog.Fatalf("could not create egress EndpointSlices reconciler: %v", err) } + podsForEps := handler.EnqueueRequestsFromMapFunc(podsFromEgressEps(mgr.GetClient(), opts.log, opts.tailscaleNamespace)) + podsER := handler.EnqueueRequestsFromMapFunc(egressPodsHandler) + err = builder. + ControllerManagedBy(mgr). + Named("egress-pods-readiness-reconciler"). + Watches(&discoveryv1.EndpointSlice{}, podsForEps). + Watches(&corev1.Pod{}, podsER). + Complete(&egressPodsReconciler{ + Client: mgr.GetClient(), + tsNamespace: opts.tailscaleNamespace, + clock: tstime.DefaultClock{}, + logger: opts.log.Named("egress-pods-readiness-reconciler"), + httpClient: http.DefaultClient, + }) + if err != nil { + startlog.Fatalf("could not create egress Pods readiness reconciler: %v", err) + } + + // ProxyClass reconciler gets triggered on ServiceMonitor CRD changes to ensure that any ProxyClasses, that + // define that a ServiceMonitor should be created, were set to invalid because the CRD did not exist get + // reconciled if the CRD is applied at a later point. + kPortRange := getServicesNodePortRange(context.Background(), mgr.GetClient(), opts.tailscaleNamespace, startlog) + serviceMonitorFilter := handler.EnqueueRequestsFromMapFunc(proxyClassesWithServiceMonitor(mgr.GetClient(), opts.log)) err = builder.ControllerManagedBy(mgr). For(&tsapi.ProxyClass{}). + Named("proxyclass-reconciler"). + Watches(&apiextensionsv1.CustomResourceDefinition{}, serviceMonitorFilter). Complete(&ProxyClassReconciler{ - Client: mgr.GetClient(), - recorder: eventRecorder, - logger: opts.log.Named("proxyclass-reconciler"), - clock: tstime.DefaultClock{}, + Client: mgr.GetClient(), + nodePortRange: kPortRange, + recorder: eventRecorder, + tsNamespace: opts.tailscaleNamespace, + logger: opts.log.Named("proxyclass-reconciler"), + clock: tstime.DefaultClock{}, }) if err != nil { startlog.Fatal("could not create proxyclass reconciler: %v", err) @@ -440,6 +628,7 @@ func runReconcilers(opts reconcilerOpts) { recorderFilter := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &tsapi.Recorder{}) err = builder.ControllerManagedBy(mgr). For(&tsapi.Recorder{}). + Named("recorder-reconciler"). Watches(&appsv1.StatefulSet{}, recorderFilter). Watches(&corev1.ServiceAccount{}, recorderFilter). Watches(&corev1.Secret{}, recorderFilter). @@ -449,14 +638,77 @@ func runReconcilers(opts reconcilerOpts) { recorder: eventRecorder, tsNamespace: opts.tailscaleNamespace, Client: mgr.GetClient(), - l: opts.log.Named("recorder-reconciler"), + log: opts.log.Named("recorder-reconciler"), clock: tstime.DefaultClock{}, tsClient: opts.tsClient, + loginServer: opts.loginServer, }) if err != nil { startlog.Fatalf("could not create Recorder reconciler: %v", err) } + // kube-apiserver's Tailscale Service reconciler. + err = builder. + ControllerManagedBy(mgr). + For(&tsapi.ProxyGroup{}, builder.WithPredicates( + predicate.NewPredicateFuncs(func(obj client.Object) bool { + pg, ok := obj.(*tsapi.ProxyGroup) + return ok && pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer + }), + )). + Named("kube-apiserver-ts-service-reconciler"). + Watches(&corev1.Secret{}, handler.EnqueueRequestsFromMapFunc(kubeAPIServerPGsFromSecret(mgr.GetClient(), startlog))). + Complete(&KubeAPIServerTSServiceReconciler{ + Client: mgr.GetClient(), + recorder: eventRecorder, + logger: opts.log.Named("kube-apiserver-ts-service-reconciler"), + tsClient: opts.tsClient, + tsNamespace: opts.tailscaleNamespace, + lc: lc, + defaultTags: strings.Split(opts.proxyTags, ","), + operatorID: id, + clock: tstime.DefaultClock{}, + }) + if err != nil { + startlog.Fatalf("could not create Kubernetes API server Tailscale Service reconciler: %v", err) + } + + // ProxyGroup reconciler. + ownedByProxyGroupFilter := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &tsapi.ProxyGroup{}) + proxyClassFilterForProxyGroup := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForProxyGroup(mgr.GetClient(), startlog)) + nodeFilterForProxyGroup := handler.EnqueueRequestsFromMapFunc(nodeHandlerForProxyGroup(mgr.GetClient(), opts.defaultProxyClass, startlog)) + saFilterForProxyGroup := handler.EnqueueRequestsFromMapFunc(serviceAccountHandlerForProxyGroup(mgr.GetClient(), startlog)) + err = builder.ControllerManagedBy(mgr). + For(&tsapi.ProxyGroup{}). + Named("proxygroup-reconciler"). + Watches(&corev1.Service{}, ownedByProxyGroupFilter). + Watches(&appsv1.StatefulSet{}, ownedByProxyGroupFilter). + Watches(&corev1.ConfigMap{}, ownedByProxyGroupFilter). + Watches(&corev1.ServiceAccount{}, saFilterForProxyGroup). + Watches(&corev1.Secret{}, ownedByProxyGroupFilter). + Watches(&rbacv1.Role{}, ownedByProxyGroupFilter). + Watches(&rbacv1.RoleBinding{}, ownedByProxyGroupFilter). + Watches(&tsapi.ProxyClass{}, proxyClassFilterForProxyGroup). + Watches(&corev1.Node{}, nodeFilterForProxyGroup). + Complete(&ProxyGroupReconciler{ + recorder: eventRecorder, + Client: mgr.GetClient(), + log: opts.log.Named("proxygroup-reconciler"), + clock: tstime.DefaultClock{}, + tsClient: opts.tsClient, + + tsNamespace: opts.tailscaleNamespace, + tsProxyImage: opts.proxyImage, + k8sProxyImage: opts.k8sProxyImage, + defaultTags: strings.Split(opts.proxyTags, ","), + tsFirewallMode: opts.proxyFirewallMode, + defaultProxyClass: opts.defaultProxyClass, + loginServer: opts.tsServer.ControlURL, + }) + if err != nil { + startlog.Fatalf("could not create ProxyGroup reconciler: %v", err) + } + startlog.Infof("Startup complete, operator running, version: %s", version.Long()) if err := mgr.Start(signals.SetupSignalHandler()); err != nil { startlog.Fatalf("could not start manager: %v", err) @@ -466,10 +718,11 @@ func runReconcilers(opts reconcilerOpts) { type reconcilerOpts struct { log *zap.SugaredLogger tsServer *tsnet.Server - tsClient *tailscale.Client + tsClient tsClient tailscaleNamespace string // namespace in which operator resources will be deployed restConfig *rest.Config // config for connecting to the kube API server proxyImage string // : + k8sProxyImage string // : // proxyPriorityClassName isPriorityClass to be set for proxy Pods. This // is a legacy mechanism for cluster resource configuration options - // going forward use ProxyClass. @@ -497,10 +750,15 @@ type reconcilerOpts struct { // Auto is usually the best choice, unless you want to explicitly set // specific mode for debugging purposes. proxyFirewallMode string - // proxyDefaultClass is the name of the ProxyClass to use as the default + // defaultProxyClass is the name of the ProxyClass to use as the default // class for proxies that do not have a ProxyClass set. // this is defined by an operator env variable. - proxyDefaultClass string + defaultProxyClass string + // loginServer is the coordination server URL that should be used by managed resources. + loginServer string + // ingressClassName is the name of the ingress class used by reconcilers of Ingress resources. This defaults + // to "tailscale" but can be customised. + ingressClassName string } // enqueueAllIngressEgressProxySvcsinNS returns a reconcile request for each @@ -511,8 +769,8 @@ func enqueueAllIngressEgressProxySvcsInNS(ns string, cl client.Client, logger *z // Get all headless Services for proxies configured using Service. svcProxyLabels := map[string]string{ - LabelManaged: "true", - LabelParentType: "svc", + kubetypes.LabelManaged: "true", + LabelParentType: "svc", } svcHeadlessSvcList := &corev1.ServiceList{} if err := cl.List(ctx, svcHeadlessSvcList, client.InNamespace(ns), client.MatchingLabels(svcProxyLabels)); err != nil { @@ -525,8 +783,8 @@ func enqueueAllIngressEgressProxySvcsInNS(ns string, cl client.Client, logger *z // Get all headless Services for proxies configured using Ingress. ingProxyLabels := map[string]string{ - LabelManaged: "true", - LabelParentType: "ingress", + kubetypes.LabelManaged: "true", + LabelParentType: "ingress", } ingHeadlessSvcList := &corev1.ServiceList{} if err := cl.List(ctx, ingHeadlessSvcList, client.InNamespace(ns), client.MatchingLabels(ingProxyLabels)); err != nil { @@ -591,15 +849,9 @@ func dnsRecordsReconcilerIngressHandler(ns string, isDefaultLoadBalancer bool, c } } -type tsClient interface { - CreateKey(ctx context.Context, caps tailscale.KeyCapabilities) (string, *tailscale.Key, error) - Device(ctx context.Context, deviceID string, fields *tailscale.DeviceFieldsOpts) (*tailscale.Device, error) - DeleteDevice(ctx context.Context, nodeStableID string) error -} - func isManagedResource(o client.Object) bool { ls := o.GetLabels() - return ls[LabelManaged] == "true" + return ls[kubetypes.LabelManaged] == "true" } func isManagedByType(o client.Object, typ string) bool { @@ -626,6 +878,16 @@ func managedResourceHandlerForType(typ string) handler.MapFunc { } } +// indexProxyClass is used to select ProxyClass-backed objects which are +// locally indexed in the cache for efficient listing without requiring labels. +func indexProxyClass(o client.Object) []string { + if !hasProxyClassAnnotation(o) { + return nil + } + + return []string{o.GetAnnotations()[LabelAnnotationProxyClass]} +} + // proxyClassHandlerForSvc returns a handler that, for a given ProxyClass, // returns a list of reconcile requests for all Services labeled with // tailscale.com/proxy-class: . @@ -633,16 +895,37 @@ func proxyClassHandlerForSvc(cl client.Client, logger *zap.SugaredLogger) handle return func(ctx context.Context, o client.Object) []reconcile.Request { svcList := new(corev1.ServiceList) labels := map[string]string{ - LabelProxyClass: o.GetName(), + LabelAnnotationProxyClass: o.GetName(), } + if err := cl.List(ctx, svcList, client.MatchingLabels(labels)); err != nil { logger.Debugf("error listing Services for ProxyClass: %v", err) return nil } + reqs := make([]reconcile.Request, 0) + seenSvcs := make(set.Set[string]) for _, svc := range svcList.Items { reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&svc)}) + seenSvcs.Add(fmt.Sprintf("%s/%s", svc.Namespace, svc.Name)) + } + + svcAnnotationList := new(corev1.ServiceList) + if err := cl.List(ctx, svcAnnotationList, client.MatchingFields{indexServiceProxyClass: o.GetName()}); err != nil { + logger.Debugf("error listing Services for ProxyClass: %v", err) + return nil + } + + for _, svc := range svcAnnotationList.Items { + nsname := fmt.Sprintf("%s/%s", svc.Namespace, svc.Name) + if seenSvcs.Contains(nsname) { + continue + } + + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&svc)}) + seenSvcs.Add(nsname) } + return reqs } } @@ -654,16 +937,36 @@ func proxyClassHandlerForIngress(cl client.Client, logger *zap.SugaredLogger) ha return func(ctx context.Context, o client.Object) []reconcile.Request { ingList := new(networkingv1.IngressList) labels := map[string]string{ - LabelProxyClass: o.GetName(), + LabelAnnotationProxyClass: o.GetName(), } if err := cl.List(ctx, ingList, client.MatchingLabels(labels)); err != nil { logger.Debugf("error listing Ingresses for ProxyClass: %v", err) return nil } + reqs := make([]reconcile.Request, 0) + seenIngs := make(set.Set[string]) for _, ing := range ingList.Items { reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) + seenIngs.Add(fmt.Sprintf("%s/%s", ing.Namespace, ing.Name)) } + + ingAnnotationList := new(networkingv1.IngressList) + if err := cl.List(ctx, ingAnnotationList, client.MatchingFields{indexIngressProxyClass: o.GetName()}); err != nil { + logger.Debugf("error listing Ingreses for ProxyClass: %v", err) + return nil + } + + for _, ing := range ingAnnotationList.Items { + nsname := fmt.Sprintf("%s/%s", ing.Namespace, ing.Name) + if seenIngs.Contains(nsname) { + continue + } + + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) + seenIngs.Add(nsname) + } + return reqs } } @@ -689,13 +992,123 @@ func proxyClassHandlerForConnector(cl client.Client, logger *zap.SugaredLogger) } } +// nodeHandlerForProxyGroup returns a handler that, for a given Node, returns a +// list of reconcile requests for ProxyGroups that should be reconciled for the +// Node event. ProxyGroups need to be reconciled for Node events if they are +// configured to expose tailscaled static endpoints to tailnet using NodePort +// Services. +func nodeHandlerForProxyGroup(cl client.Client, defaultProxyClass string, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + pgList := new(tsapi.ProxyGroupList) + if err := cl.List(ctx, pgList); err != nil { + logger.Debugf("error listing ProxyGroups for ProxyClass: %v", err) + return nil + } + + reqs := make([]reconcile.Request, 0) + for _, pg := range pgList.Items { + if pg.Spec.ProxyClass == "" && defaultProxyClass == "" { + continue + } + + pc := defaultProxyClass + if pc == "" { + pc = pg.Spec.ProxyClass + } + + proxyClass := &tsapi.ProxyClass{} + if err := cl.Get(ctx, types.NamespacedName{Name: pc}, proxyClass); err != nil { + logger.Debugf("error getting ProxyClass %q: %v", pg.Spec.ProxyClass, err) + return nil + } + + stat := proxyClass.Spec.StaticEndpoints + if stat == nil { + continue + } + + // If the selector is empty, all nodes match. + // TODO(ChaosInTheCRD): think about how this must be handled if we want to limit the number of nodes used + if len(stat.NodePort.Selector) == 0 { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&pg)}) + continue + } + + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: stat.NodePort.Selector, + }) + if err != nil { + logger.Debugf("error converting `spec.staticEndpoints.nodePort.selector` to Selector: %v", err) + return nil + } + + if selector.Matches(klabels.Set(o.GetLabels())) { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&pg)}) + } + } + return reqs + } +} + +// proxyClassHandlerForProxyGroup returns a handler that, for a given ProxyClass, +// returns a list of reconcile requests for all ProxyGroups that have +// .spec.proxyClass set to that ProxyClass. +func proxyClassHandlerForProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + pgList := new(tsapi.ProxyGroupList) + if err := cl.List(ctx, pgList); err != nil { + logger.Debugf("error listing ProxyGroups for ProxyClass: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + proxyClassName := o.GetName() + for _, pg := range pgList.Items { + if pg.Spec.ProxyClass == proxyClassName { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&pg)}) + } + } + return reqs + } +} + +// serviceAccountHandlerForProxyGroup returns a handler that, for a given ServiceAccount, +// returns a list of reconcile requests for all ProxyGroups that use that ServiceAccount. +// For most ProxyGroups, this will be a dedicated ServiceAccount owned by a specific +// ProxyGroup. But for kube-apiserver ProxyGroups running in auth mode, they use a shared +// static ServiceAccount named "kube-apiserver-auth-proxy". +func serviceAccountHandlerForProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + pgList := new(tsapi.ProxyGroupList) + if err := cl.List(ctx, pgList); err != nil { + logger.Debugf("error listing ProxyGroups for ServiceAccount: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + saName := o.GetName() + for _, pg := range pgList.Items { + if saName == authAPIServerProxySAName && isAuthAPIServerProxy(&pg) { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&pg)}) + } + expectedOwner := pgOwnerReference(&pg)[0] + saOwnerRefs := o.GetOwnerReferences() + for _, ref := range saOwnerRefs { + if apiequality.Semantic.DeepEqual(ref, expectedOwner) { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&pg)}) + break + } + } + } + return reqs + } +} + // serviceHandlerForIngress returns a handler for Service events for ingress // reconciler that ensures that if the Service associated with an event is of // interest to the reconciler, the associated Ingress(es) gets be reconciled. // The Services of interest are backend Services for tailscale Ingress and // managed Services for an StatefulSet for a proxy configured for tailscale // Ingress -func serviceHandlerForIngress(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { +func serviceHandlerForIngress(cl client.Client, logger *zap.SugaredLogger, ingressClassName string) handler.MapFunc { return func(ctx context.Context, o client.Object) []reconcile.Request { if isManagedByType(o, "ingress") { ingName := parentFromObjectLabels(o) @@ -708,8 +1121,12 @@ func serviceHandlerForIngress(cl client.Client, logger *zap.SugaredLogger) handl } reqs := make([]reconcile.Request, 0) for _, ing := range ingList.Items { - if ing.Spec.IngressClassName == nil || *ing.Spec.IngressClassName != tailscaleIngressClassName { - return nil + if ing.Spec.IngressClassName == nil || *ing.Spec.IngressClassName != ingressClassName { + continue + } + if hasProxyGroupAnnotation(&ing) { + // We don't want to reconcile backend Services for Ingresses for ProxyGroups. + continue } if ing.Spec.DefaultBackend != nil && ing.Spec.DefaultBackend.Service != nil && ing.Spec.DefaultBackend.Service.Name == o.GetName() { reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) @@ -793,33 +1210,189 @@ func egressEpsHandler(_ context.Context, o client.Object) []reconcile.Request { } } -// egressEpsFromEgressPGChildResources returns a handler that checks if an -// object is a child resource for an egress ProxyGroup (a Pod or a state Secret) -// and if it is, returns reconciler requests for all egress EndpointSlices for -// that ProxyGroup. -func egressEpsFromEgressPGChildResources(cl client.Client, logger *zap.SugaredLogger, ns string) handler.MapFunc { +func egressPodsHandler(_ context.Context, o client.Object) []reconcile.Request { + if typ := o.GetLabels()[LabelParentType]; typ != proxyTypeProxyGroup { + return nil + } + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: o.GetNamespace(), + Name: o.GetName(), + }, + }, + } +} + +// egressEpsFromEgressPods returns a Pod event handler that checks if Pod is a replica for a ProxyGroup and if it is, +// returns reconciler requests for all egress EndpointSlices for that ProxyGroup. +func egressEpsFromPGPods(cl client.Client, ns string) handler.MapFunc { return func(_ context.Context, o client.Object) []reconcile.Request { - pg, ok := o.GetLabels()[labelProxyGroup] + if v, ok := o.GetLabels()[kubetypes.LabelManaged]; !ok || v != "true" { + return nil + } + // TODO(irbekrm): for now this is good enough as all ProxyGroups are egress. Add a type check once we + // have ingress ProxyGroups. + if typ := o.GetLabels()[LabelParentType]; typ != "proxygroup" { + return nil + } + pg, ok := o.GetLabels()[LabelParentName] + if !ok { + return nil + } + return reconcileRequestsForPG(pg, cl, ns) + } +} + +// egressEpsFromPGStateSecrets returns a Secret event handler that checks if Secret is a state Secret for a ProxyGroup and if it is, +// returns reconciler requests for all egress EndpointSlices for that ProxyGroup. +func egressEpsFromPGStateSecrets(cl client.Client, ns string) handler.MapFunc { + return func(_ context.Context, o client.Object) []reconcile.Request { + if v, ok := o.GetLabels()[kubetypes.LabelManaged]; !ok || v != "true" { + return nil + } + if parentType := o.GetLabels()[LabelParentType]; parentType != "proxygroup" { + return nil + } + if secretType := o.GetLabels()[kubetypes.LabelSecretType]; secretType != kubetypes.LabelSecretTypeState { + return nil + } + pg, ok := o.GetLabels()[LabelParentName] if !ok { return nil } - // TODO(irbekrm): depending on what labels we add to ProxyGroup - // resources and which resources, this might need some extra - // checks. - if typ, ok := o.GetLabels()[labelProxyGroupType]; !ok || typ != typeEgress { + return reconcileRequestsForPG(pg, cl, ns) + } +} + +func ingressSvcFromEps(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + svcName := o.GetLabels()[discoveryv1.LabelServiceName] + if svcName == "" { + return nil + } + + svc := &corev1.Service{} + ns := o.GetNamespace() + if err := cl.Get(ctx, types.NamespacedName{Name: svcName, Namespace: ns}, svc); err != nil { + logger.Errorf("failed to get service: %v", err) return nil } - epsList := discoveryv1.EndpointSliceList{} - if err := cl.List(context.Background(), &epsList, client.InNamespace(ns), client.MatchingLabels(map[string]string{labelProxyGroup: pg})); err != nil { - logger.Infof("error listing EndpointSlices: %v, skipping a reconcile for event on %s %s", err, o.GetName(), o.GetObjectKind().GroupVersionKind().Kind) + + pgName := svc.Annotations[AnnotationProxyGroup] + if pgName == "" { + return nil + } + + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: ns, + Name: svcName, + }, + }, + } + } +} + +// egressSvcFromEps is an event handler for EndpointSlices. If an EndpointSlice is for an egress ExternalName Service +// meant to be exposed on a ProxyGroup, returns a reconcile request for the Service. +func egressSvcFromEps(_ context.Context, o client.Object) []reconcile.Request { + if typ := o.GetLabels()[labelSvcType]; typ != typeEgress { + return nil + } + if v, ok := o.GetLabels()[kubetypes.LabelManaged]; !ok || v != "true" { + return nil + } + svcName, ok := o.GetLabels()[LabelParentName] + if !ok { + return nil + } + svcNs, ok := o.GetLabels()[LabelParentNamespace] + if !ok { + return nil + } + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: svcNs, + Name: svcName, + }, + }, + } +} + +func reconcileRequestsForPG(pg string, cl client.Client, ns string) []reconcile.Request { + epsList := discoveryv1.EndpointSliceList{} + if err := cl.List(context.Background(), &epsList, + client.InNamespace(ns), + client.MatchingLabels(map[string]string{labelProxyGroup: pg})); err != nil { + return nil + } + reqs := make([]reconcile.Request, 0) + for _, ep := range epsList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: ep.Namespace, + Name: ep.Name, + }, + }) + } + return reqs +} + +func isTLSSecret(secret *corev1.Secret) bool { + return secret.Type == corev1.SecretTypeTLS && + secret.ObjectMeta.Labels[kubetypes.LabelManaged] == "true" && + secret.ObjectMeta.Labels[kubetypes.LabelSecretType] == kubetypes.LabelSecretTypeCerts && + secret.ObjectMeta.Labels[labelDomain] != "" && + secret.ObjectMeta.Labels[labelProxyGroup] != "" +} + +func isPGStateSecret(secret *corev1.Secret) bool { + return secret.ObjectMeta.Labels[kubetypes.LabelManaged] == "true" && + secret.ObjectMeta.Labels[LabelParentType] == "proxygroup" && + secret.ObjectMeta.Labels[kubetypes.LabelSecretType] == kubetypes.LabelSecretTypeState +} + +// HAIngressesFromSecret returns a handler that returns reconcile requests for +// all HA Ingresses that should be reconciled in response to a Secret event. +func HAIngressesFromSecret(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + secret, ok := o.(*corev1.Secret) + if !ok { + logger.Infof("[unexpected] Secret handler triggered for an object that is not a Secret") + return nil + } + if isTLSSecret(secret) { + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: secret.ObjectMeta.Labels[LabelParentNamespace], + Name: secret.ObjectMeta.Labels[LabelParentName], + }, + }, + } + } + if !isPGStateSecret(secret) { + return nil + } + pgName, ok := secret.ObjectMeta.Labels[LabelParentName] + if !ok { + return nil + } + + ingList := &networkingv1.IngressList{} + if err := cl.List(ctx, ingList, client.MatchingFields{indexIngressProxyGroup: pgName}); err != nil { + logger.Infof("error listing Ingresses, skipping a reconcile for event on Secret %s: %v", secret.Name, err) return nil } reqs := make([]reconcile.Request, 0) - for _, ep := range epsList.Items { + for _, ing := range ingList.Items { reqs = append(reqs, reconcile.Request{ NamespacedName: types.NamespacedName{ - Namespace: ep.Namespace, - Name: ep.Name, + Namespace: ing.Namespace, + Name: ing.Name, }, }) } @@ -827,8 +1400,80 @@ func egressEpsFromEgressPGChildResources(cl client.Client, logger *zap.SugaredLo } } +// HAServiceFromSecret returns a handler that returns reconcile requests for +// all HA Services that should be reconciled in response to a Secret event. +func HAServicesFromSecret(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + secret, ok := o.(*corev1.Secret) + if !ok { + logger.Infof("[unexpected] Secret handler triggered for an object that is not a Secret") + return nil + } + if !isPGStateSecret(secret) { + return nil + } + pgName, ok := secret.ObjectMeta.Labels[LabelParentName] + if !ok { + return nil + } + svcList := &corev1.ServiceList{} + if err := cl.List(ctx, svcList, client.MatchingFields{indexIngressProxyGroup: pgName}); err != nil { + logger.Infof("error listing Services, skipping a reconcile for event on Secret %s: %v", secret.Name, err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, svc := range svcList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: svc.Namespace, + Name: svc.Name, + }, + }) + } + return reqs + } +} + +// kubeAPIServerPGsFromSecret finds ProxyGroups of type "kube-apiserver" that +// need to be reconciled after a ProxyGroup-owned Secret is updated. +func kubeAPIServerPGsFromSecret(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + secret, ok := o.(*corev1.Secret) + if !ok { + logger.Infof("[unexpected] Secret handler triggered for an object that is not a Secret") + return nil + } + if secret.ObjectMeta.Labels[kubetypes.LabelManaged] != "true" || + secret.ObjectMeta.Labels[LabelParentType] != "proxygroup" { + return nil + } + + var pg tsapi.ProxyGroup + if err := cl.Get(ctx, types.NamespacedName{Name: secret.ObjectMeta.Labels[LabelParentName]}, &pg); err != nil { + logger.Infof("error getting ProxyGroup %s: %v", secret.ObjectMeta.Labels[LabelParentName], err) + return nil + } + + if pg.Spec.Type != tsapi.ProxyGroupTypeKubernetesAPIServer { + return nil + } + + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: secret.ObjectMeta.Labels[LabelParentNamespace], + Name: secret.ObjectMeta.Labels[LabelParentName], + }, + }, + } + + } +} + +// egressSvcsFromEgressProxyGroup is an event handler for egress ProxyGroups. It returns reconcile requests for all +// user-created ExternalName Services that should be exposed on this ProxyGroup. func egressSvcsFromEgressProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { - return func(_ context.Context, o client.Object) []reconcile.Request { + return func(ctx context.Context, o client.Object) []reconcile.Request { pg, ok := o.(*tsapi.ProxyGroup) if !ok { logger.Infof("[unexpected] ProxyGroup handler triggered for an object that is not a ProxyGroup") @@ -838,7 +1483,7 @@ func egressSvcsFromEgressProxyGroup(cl client.Client, logger *zap.SugaredLogger) return nil } svcList := &corev1.ServiceList{} - if err := cl.List(context.Background(), svcList, client.MatchingFields{indexEgressProxyGroup: pg.Name}); err != nil { + if err := cl.List(ctx, svcList, client.MatchingFields{indexEgressProxyGroup: pg.Name}); err != nil { logger.Infof("error listing Services: %v, skipping a reconcile for event on ProxyGroup %s", err, pg.Name) return nil } @@ -855,8 +1500,40 @@ func egressSvcsFromEgressProxyGroup(cl client.Client, logger *zap.SugaredLogger) } } -func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { - return func(_ context.Context, o client.Object) []reconcile.Request { +// ingressesFromIngressProxyGroup is an event handler for ingress ProxyGroups. It returns reconcile requests for all +// user-created Ingresses that should be exposed on this ProxyGroup. +func ingressesFromIngressProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + pg, ok := o.(*tsapi.ProxyGroup) + if !ok { + logger.Infof("[unexpected] ProxyGroup handler triggered for an object that is not a ProxyGroup") + return nil + } + if pg.Spec.Type != tsapi.ProxyGroupTypeIngress { + return nil + } + ingList := &networkingv1.IngressList{} + if err := cl.List(ctx, ingList, client.MatchingFields{indexIngressProxyGroup: pg.Name}); err != nil { + logger.Infof("error listing Ingresses: %v, skipping a reconcile for event on ProxyGroup %s", err, pg.Name) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, svc := range ingList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: svc.Namespace, + Name: svc.Name, + }, + }) + } + return reqs + } +} + +// epsFromExternalNameService is an event handler for ExternalName Services that define a Tailscale egress service that +// should be exposed on a ProxyGroup. It returns reconcile requests for EndpointSlices created for this Service. +func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger, ns string) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { svc, ok := o.(*corev1.Service) if !ok { logger.Infof("[unexpected] Service handler triggered for an object that is not a Service") @@ -866,10 +1543,8 @@ func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger) han return nil } epsList := &discoveryv1.EndpointSliceList{} - if err := cl.List(context.Background(), epsList, client.MatchingLabels(map[string]string{ - labelExternalSvcName: svc.Name, - labelExternalSvcNamespace: svc.Namespace, - })); err != nil { + if err := cl.List(ctx, epsList, client.InNamespace(ns), + client.MatchingLabels(egressSvcChildResourceLabels(svc))); err != nil { logger.Infof("error listing EndpointSlices: %v, skipping a reconcile for event on Service %s", err, svc.Name) return nil } @@ -886,9 +1561,155 @@ func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger) han } } +func podsFromEgressEps(cl client.Client, logger *zap.SugaredLogger, ns string) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + eps, ok := o.(*discoveryv1.EndpointSlice) + if !ok { + logger.Infof("[unexpected] EndpointSlice handler triggered for an object that is not a EndpointSlice") + return nil + } + if eps.Labels[labelProxyGroup] == "" { + return nil + } + if eps.Labels[labelSvcType] != "egress" { + return nil + } + podLabels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentType: "proxygroup", + LabelParentName: eps.Labels[labelProxyGroup], + } + podList := &corev1.PodList{} + if err := cl.List(ctx, podList, client.InNamespace(ns), + client.MatchingLabels(podLabels)); err != nil { + logger.Infof("error listing EndpointSlices: %v, skipping a reconcile for event on EndpointSlice %s", err, eps.Name) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, pod := range podList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: pod.Namespace, + Name: pod.Name, + }, + }) + } + return reqs + } +} + +// proxyClassesWithServiceMonitor returns an event handler that, given that the event is for the Prometheus +// ServiceMonitor CRD, returns all ProxyClasses that define that a ServiceMonitor should be created. +func proxyClassesWithServiceMonitor(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + crd, ok := o.(*apiextensionsv1.CustomResourceDefinition) + if !ok { + logger.Debugf("[unexpected] ServiceMonitor CRD handler received an object that is not a CustomResourceDefinition") + return nil + } + if crd.Name != serviceMonitorCRD { + logger.Debugf("[unexpected] ServiceMonitor CRD handler received an unexpected CRD %q", crd.Name) + return nil + } + pcl := &tsapi.ProxyClassList{} + if err := cl.List(ctx, pcl); err != nil { + logger.Debugf("[unexpected] error listing ProxyClasses: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, pc := range pcl.Items { + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && pc.Spec.Metrics.ServiceMonitor.Enable { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{Namespace: pc.Namespace, Name: pc.Name}, + }) + } + } + return reqs + } +} + +// crdTransformer gets called before a CRD is stored to c/r cache, it removes the CRD spec to reduce memory consumption. +func crdTransformer(log *zap.SugaredLogger) toolscache.TransformFunc { + return func(o any) (any, error) { + crd, ok := o.(*apiextensionsv1.CustomResourceDefinition) + if !ok { + log.Infof("[unexpected] CRD transformer called for a non-CRD type") + return crd, nil + } + crd.Spec = apiextensionsv1.CustomResourceDefinitionSpec{} + return crd, nil + } +} + +// indexEgressServices adds a local index to cached Tailscale egress Services meant to be exposed on a ProxyGroup. The +// index is used a list filter. func indexEgressServices(o client.Object) []string { if !isEgressSvcForProxyGroup(o) { return nil } return []string{o.GetAnnotations()[AnnotationProxyGroup]} } + +// indexPGIngresses is used to select ProxyGroup-backed Services which are +// locally indexed in the cache for efficient listing without requiring labels. +func indexPGIngresses(o client.Object) []string { + if !hasProxyGroupAnnotation(o) { + return nil + } + return []string{o.GetAnnotations()[AnnotationProxyGroup]} +} + +// serviceHandlerForIngressPG returns a handler for Service events that ensures that if the Service +// associated with an event is a backend Service for a tailscale Ingress with ProxyGroup annotation, +// the associated Ingress gets reconciled. +func serviceHandlerForIngressPG(cl client.Client, logger *zap.SugaredLogger, ingressClassName string) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + ingList := networkingv1.IngressList{} + if err := cl.List(ctx, &ingList, client.InNamespace(o.GetNamespace())); err != nil { + logger.Debugf("error listing Ingresses: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, ing := range ingList.Items { + if ing.Spec.IngressClassName == nil || *ing.Spec.IngressClassName != ingressClassName { + continue + } + if !hasProxyGroupAnnotation(&ing) { + continue + } + if ing.Spec.DefaultBackend != nil && ing.Spec.DefaultBackend.Service != nil && ing.Spec.DefaultBackend.Service.Name == o.GetName() { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) + } + for _, rule := range ing.Spec.Rules { + if rule.HTTP == nil { + continue + } + for _, path := range rule.HTTP.Paths { + if path.Backend.Service != nil && path.Backend.Service.Name == o.GetName() { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) + } + } + } + } + return reqs + } +} + +func hasProxyGroupAnnotation(obj client.Object) bool { + return obj.GetAnnotations()[AnnotationProxyGroup] != "" +} + +func hasProxyClassAnnotation(obj client.Object) bool { + return obj.GetAnnotations()[LabelAnnotationProxyClass] != "" +} + +func id(ctx context.Context, lc *local.Client) (string, error) { + st, err := lc.StatusWithoutPeers(ctx) + if err != nil { + return "", fmt.Errorf("error getting tailscale status: %w", err) + } + if st.Self == nil { + return "", fmt.Errorf("unexpected: device's status does not contain self status") + } + return string(st.Self.ID), nil +} diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index 7ea8c09e1..e11235768 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -7,6 +7,7 @@ package main import ( "context" + "encoding/json" "fmt" "testing" "time" @@ -16,11 +17,14 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/k8s-operator/apis/v1alpha1" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/net/dns/resolvconffile" @@ -105,7 +109,7 @@ func TestLoadBalancerClass(t *testing.T) { }}, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Delete the misconfiguration so the proxy starts getting created on the // next reconcile. @@ -118,6 +122,7 @@ func TestLoadBalancerClass(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") opts := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -127,9 +132,9 @@ func TestLoadBalancerClass(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) want.Annotations = nil want.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} @@ -142,7 +147,7 @@ func TestLoadBalancerClass(t *testing.T) { Message: "no Tailscale hostname known yet, waiting for proxy pod to finish auth", }}, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Normally the Tailscale proxy pod would come up here and write its info // into the secret. Simulate that, then verify reconcile again and verify @@ -168,7 +173,11 @@ func TestLoadBalancerClass(t *testing.T) { }, }, } - expectEqual(t, fc, want, nil) + + // Perform an additional reconciliation loop here to ensure resources don't change through side effects. Mainly + // to prevent infinite reconciliation + expectReconciled(t, sr, "default", "test") + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, which should make the // operator clean up. @@ -205,7 +214,7 @@ func TestLoadBalancerClass(t *testing.T) { Type: corev1.ServiceTypeClusterIP, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestTailnetTargetFQDNAnnotation(t *testing.T) { @@ -256,6 +265,7 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -265,9 +275,9 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { app: kubetypes.AppEgressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -287,10 +297,10 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, want) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) // Change the tailscale-target-fqdn annotation which should update the // StatefulSet @@ -368,6 +378,7 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -377,9 +388,9 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { app: kubetypes.AppEgressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -399,10 +410,10 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, want) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) // Change the tailscale-target-ip annotation which should update the // StatefulSet @@ -432,6 +443,148 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", fullName) } +func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { + fc := fake.NewFakeClient() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + recorder: record.NewFakeRecorder(100), + } + tailnetTargetIP := "invalid-ip" + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + }) + + expectReconciled(t, sr, "default", "test") + + t0 := conditionTime(clock) + + want := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyReady), + Status: metav1.ConditionFalse, + LastTransitionTime: t0, + Reason: reasonProxyInvalid, + Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "invalid-ip" could not be parsed as a valid IP Address, error: ParseAddr("invalid-ip"): unable to parse IP`, + }}, + }, + } + + expectEqual(t, fc, want) +} + +func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { + fc := fake.NewFakeClient() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + recorder: record.NewFakeRecorder(100), + } + tailnetTargetIP := "999.999.999.999" + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + }) + + expectReconciled(t, sr, "default", "test") + + t0 := conditionTime(clock) + + want := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyReady), + Status: metav1.ConditionFalse, + LastTransitionTime: t0, + Reason: reasonProxyInvalid, + Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "999.999.999.999" could not be parsed as a valid IP Address, error: ParseAddr("999.999.999.999"): IPv4 field has value >255`, + }}, + }, + } + + expectEqual(t, fc, want) +} + func TestAnnotations(t *testing.T) { fc := fake.NewFakeClient() ft := &fakeTSClient{} @@ -477,6 +630,7 @@ func TestAnnotations(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -486,9 +640,9 @@ func TestAnnotations(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -507,7 +661,7 @@ func TestAnnotations(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, which should make the // operator clean up. @@ -535,7 +689,7 @@ func TestAnnotations(t *testing.T) { Type: corev1.ServiceTypeClusterIP, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestAnnotationIntoLB(t *testing.T) { @@ -583,6 +737,7 @@ func TestAnnotationIntoLB(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -592,9 +747,9 @@ func TestAnnotationIntoLB(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) // Normally the Tailscale proxy pod would come up here and write its info // into the secret. Simulate that, since it would have normally happened at @@ -626,7 +781,7 @@ func TestAnnotationIntoLB(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Remove Tailscale's annotation, and at the same time convert the service // into a tailscale LoadBalancer. @@ -637,8 +792,8 @@ func TestAnnotationIntoLB(t *testing.T) { }) expectReconciled(t, sr, "default", "test") // None of the proxy machinery should have changed... - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) // ... but the service should have a LoadBalancer status. want = &corev1.Service{ @@ -667,7 +822,7 @@ func TestAnnotationIntoLB(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestLBIntoAnnotation(t *testing.T) { @@ -713,6 +868,7 @@ func TestLBIntoAnnotation(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -722,9 +878,9 @@ func TestLBIntoAnnotation(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) // Normally the Tailscale proxy pod would come up here and write its info // into the secret. Simulate that, then verify reconcile again and verify @@ -764,7 +920,7 @@ func TestLBIntoAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, but also add the // tailscale annotation. @@ -783,8 +939,8 @@ func TestLBIntoAnnotation(t *testing.T) { }) expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) want = &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ @@ -804,7 +960,7 @@ func TestLBIntoAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestCustomHostname(t *testing.T) { @@ -853,6 +1009,7 @@ func TestCustomHostname(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -862,9 +1019,9 @@ func TestCustomHostname(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -884,7 +1041,7 @@ func TestCustomHostname(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, which should make the // operator clean up. @@ -915,7 +1072,7 @@ func TestCustomHostname(t *testing.T) { Type: corev1.ServiceTypeClusterIP, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestCustomPriorityClassName(t *testing.T) { @@ -965,6 +1122,7 @@ func TestCustomPriorityClassName(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -975,7 +1133,183 @@ func TestCustomPriorityClassName(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) +} + +func TestServiceProxyClassAnnotation(t *testing.T) { + cl := tstest.NewClock(tstest.ClockOpts{}) + zl := zap.Must(zap.NewDevelopment()) + + pcIfNotPresent := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "if-not-present", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &v1alpha1.Container{ + ImagePullPolicy: corev1.PullIfNotPresent, + }, + }, + }, + }, + } + + pcAlways := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "always", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &v1alpha1.Container{ + ImagePullPolicy: corev1.PullAlways, + }, + }, + }, + }, + } + + builder := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme) + builder = builder.WithObjects(pcIfNotPresent, pcAlways). + WithStatusSubresource(pcIfNotPresent, pcAlways) + fc := builder.Build() + + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + }, + } + + mustCreate(t, fc, svc) + + testCases := []struct { + name string + proxyClassAnnotation string + proxyClassLabel string + proxyClassDefault string + expectedProxyClass string + expectEvents []string + }{ + { + name: "via_label", + proxyClassLabel: pcIfNotPresent.Name, + expectedProxyClass: pcIfNotPresent.Name, + }, + { + name: "via_annotation", + proxyClassAnnotation: pcIfNotPresent.Name, + expectedProxyClass: pcIfNotPresent.Name, + }, + { + name: "via_default", + proxyClassDefault: pcIfNotPresent.Name, + expectedProxyClass: pcIfNotPresent.Name, + }, + { + name: "via_label_override_annotation", + proxyClassLabel: pcIfNotPresent.Name, + proxyClassAnnotation: pcAlways.Name, + expectedProxyClass: pcIfNotPresent.Name, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + ft := &fakeTSClient{} + + if tt.proxyClassAnnotation != "" || tt.proxyClassLabel != "" || tt.proxyClassDefault != "" { + name := tt.proxyClassDefault + if name == "" { + name = tt.proxyClassLabel + if name == "" { + name = tt.proxyClassAnnotation + } + } + setProxyClassReady(t, fc, cl, name) + } + + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + defaultProxyClass: tt.proxyClassDefault, + logger: zl.Sugar(), + clock: cl, + isDefaultLoadBalancer: true, + } + + if tt.proxyClassLabel != "" { + svc.Labels = map[string]string{ + LabelAnnotationProxyClass: tt.proxyClassLabel, + } + } + if tt.proxyClassAnnotation != "" { + svc.Annotations = map[string]string{ + LabelAnnotationProxyClass: tt.proxyClassAnnotation, + } + } + + mustUpdate(t, fc, svc.Namespace, svc.Name, func(s *corev1.Service) { + s.Labels = svc.Labels + s.Annotations = svc.Annotations + }) + + expectReconciled(t, sr, "default", "test") + + list := &corev1.ServiceList{} + fc.List(context.Background(), list, client.InNamespace("default")) + + for _, i := range list.Items { + t.Logf("found service %s", i.Name) + } + + slist := &corev1.SecretList{} + fc.List(context.Background(), slist, client.InNamespace("operator-ns")) + for _, i := range slist.Items { + labels, _ := json.Marshal(i.Labels) + t.Logf("found secret %q with labels %q ", i.Name, string(labels)) + } + + _, shortName := findGenName(t, fc, "default", "test", "svc") + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: "operator-ns", Name: shortName}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + switch tt.expectedProxyClass { + case pcIfNotPresent.Name: + for _, cont := range sts.Spec.Template.Spec.Containers { + if cont.Name == "tailscale" && cont.ImagePullPolicy != corev1.PullIfNotPresent { + t.Fatalf("ImagePullPolicy %q does not match ProxyClass %q with value %q", cont.ImagePullPolicy, pcIfNotPresent.Name, pcIfNotPresent.Spec.StatefulSet.Pod.TailscaleContainer.ImagePullPolicy) + } + } + case pcAlways.Name: + for _, cont := range sts.Spec.Template.Spec.Containers { + if cont.Name == "tailscale" && cont.ImagePullPolicy != corev1.PullAlways { + t.Fatalf("ImagePullPolicy %q does not match ProxyClass %q with value %q", cont.ImagePullPolicy, pcAlways.Name, pcAlways.Spec.StatefulSet.Pod.TailscaleContainer.ImagePullPolicy) + } + } + default: + t.Fatalf("unexpected expected ProxyClass %q", tt.expectedProxyClass) + } + }) + } } func TestProxyClassForService(t *testing.T) { @@ -987,9 +1321,11 @@ func TestProxyClassForService(t *testing.T) { AcceptRoutes: true, }, StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"bar.io/foo": "some-val"}, - Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}}}, + Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}, + }, + }, } fc := fake.NewClientBuilder(). WithScheme(tsapi.GlobalScheme). @@ -1035,6 +1371,7 @@ func TestProxyClassForService(t *testing.T) { expectReconciled(t, sr, "default", "test") fullName, shortName := findGenName(t, fc, "default", "test", "svc") opts := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -1043,19 +1380,19 @@ func TestProxyClassForService(t *testing.T) { clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 2. The Service gets updated with tailscale.com/proxy-class label // pointing at the 'custom-metadata' ProxyClass. The ProxyClass is not // yet ready, so no changes are actually applied to the proxy resources. mustUpdate(t, fc, "default", "test", func(svc *corev1.Service) { - mak.Set(&svc.Labels, LabelProxyClass, "custom-metadata") + mak.Set(&svc.Labels, LabelAnnotationProxyClass, "custom-metadata") }) expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) + expectEqual(t, fc, expectedSecret(t, fc, opts)) // 3. ProxyClass is set to Ready, the Service gets reconciled by the // services-reconciler and the customization from the ProxyClass is @@ -1064,24 +1401,25 @@ func TestProxyClassForService(t *testing.T) { pc.Status = tsapi.ProxyClassStatus{ Conditions: []metav1.Condition{{ Status: metav1.ConditionTrue, - Type: string(tsapi.ProxyClassready), + Type: string(tsapi.ProxyClassReady), ObservedGeneration: pc.Generation, - }}} + }}, + } }) opts.proxyClass = pc.Name expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) expectEqual(t, fc, expectedSecret(t, fc, opts), removeAuthKeyIfExistsModifier(t)) // 4. tailscale.com/proxy-class label is removed from the Service, the // configuration from the ProxyClass is removed from the cluster // resources. mustUpdate(t, fc, "default", "test", func(svc *corev1.Service) { - delete(svc.Labels, LabelProxyClass) + delete(svc.Labels, LabelAnnotationProxyClass) }) opts.proxyClass = "" expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) } func TestDefaultLoadBalancer(t *testing.T) { @@ -1127,8 +1465,9 @@ func TestDefaultLoadBalancer(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -1137,8 +1476,7 @@ func TestDefaultLoadBalancer(t *testing.T) { clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) - + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) } func TestProxyFirewallMode(t *testing.T) { @@ -1185,6 +1523,7 @@ func TestProxyFirewallMode(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") o := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -1194,73 +1533,9 @@ func TestProxyFirewallMode(t *testing.T) { clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, o), removeResourceReqs) } -func TestTailscaledConfigfileHash(t *testing.T) { - fc := fake.NewFakeClient() - ft := &fakeTSClient{} - zl, err := zap.NewDevelopment() - if err != nil { - t.Fatal(err) - } - clock := tstest.NewClock(tstest.ClockOpts{}) - sr := &ServiceReconciler{ - Client: fc, - ssr: &tailscaleSTSReconciler{ - Client: fc, - tsClient: ft, - defaultTags: []string{"tag:k8s"}, - operatorNamespace: "operator-ns", - proxyImage: "tailscale/tailscale", - }, - logger: zl.Sugar(), - clock: clock, - isDefaultLoadBalancer: true, - } - - // Create a service that we should manage, and check that the initial round - // of objects looks right. - mustCreate(t, fc, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - // The apiserver is supposed to set the UID, but the fake client - // doesn't. So, set it explicitly because other code later depends - // on it being set. - UID: types.UID("1234-UID"), - }, - Spec: corev1.ServiceSpec{ - ClusterIP: "10.20.30.40", - Type: corev1.ServiceTypeLoadBalancer, - }, - }) - - expectReconciled(t, sr, "default", "test") - - fullName, shortName := findGenName(t, fc, "default", "test", "svc") - o := configOpts{ - stsName: shortName, - secretName: fullName, - namespace: "default", - parentType: "svc", - hostname: "default-test", - clusterTargetIP: "10.20.30.40", - confFileHash: "e09bededa0379920141cbd0b0dbdf9b8b66545877f9e8397423f5ce3e1ba439e", - app: kubetypes.AppIngressProxy, - } - expectEqual(t, fc, expectedSTS(t, fc, o), nil) - - // 2. Hostname gets changed, configfile is updated and a new hash value - // is produced. - mustUpdate(t, fc, "default", "test", func(svc *corev1.Service) { - mak.Set(&svc.Annotations, AnnotationHostname, "another-test") - }) - o.hostname = "another-test" - o.confFileHash = "5d754cf55463135ee34aa9821f2fd8483b53eb0570c3740c84a086304f427684" - expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, o), nil) -} func Test_isMagicDNSName(t *testing.T) { tests := []struct { in string @@ -1289,6 +1564,8 @@ func Test_isMagicDNSName(t *testing.T) { } func Test_serviceHandlerForIngress(t *testing.T) { + const tailscaleIngressClassName = "tailscale" + fc := fake.NewFakeClient() zl, err := zap.NewDevelopment() if err != nil { @@ -1309,16 +1586,16 @@ func Test_serviceHandlerForIngress(t *testing.T) { Name: "headless-1", Namespace: "tailscale", Labels: map[string]string{ - LabelManaged: "true", - LabelParentName: "ing-1", - LabelParentNamespace: "ns-1", - LabelParentType: "ingress", + kubetypes.LabelManaged: "true", + LabelParentName: "ing-1", + LabelParentNamespace: "ns-1", + LabelParentType: "ingress", }, }, } mustCreate(t, fc, svc1) wantReqs := []reconcile.Request{{NamespacedName: types.NamespacedName{Namespace: "ns-1", Name: "ing-1"}}} - gotReqs := serviceHandlerForIngress(fc, zl.Sugar())(context.Background(), svc1) + gotReqs := serviceHandlerForIngress(fc, zl.Sugar(), tailscaleIngressClassName)(context.Background(), svc1) if diff := cmp.Diff(gotReqs, wantReqs); diff != "" { t.Fatalf("unexpected reconcile requests (-got +want):\n%s", diff) } @@ -1345,7 +1622,7 @@ func Test_serviceHandlerForIngress(t *testing.T) { } mustCreate(t, fc, backendSvc) wantReqs = []reconcile.Request{{NamespacedName: types.NamespacedName{Namespace: "ns-2", Name: "ing-2"}}} - gotReqs = serviceHandlerForIngress(fc, zl.Sugar())(context.Background(), backendSvc) + gotReqs = serviceHandlerForIngress(fc, zl.Sugar(), tailscaleIngressClassName)(context.Background(), backendSvc) if diff := cmp.Diff(gotReqs, wantReqs); diff != "" { t.Fatalf("unexpected reconcile requests (-got +want):\n%s", diff) } @@ -1361,7 +1638,8 @@ func Test_serviceHandlerForIngress(t *testing.T) { IngressClassName: ptr.To(tailscaleIngressClassName), Rules: []networkingv1.IngressRule{{IngressRuleValue: networkingv1.IngressRuleValue{HTTP: &networkingv1.HTTPIngressRuleValue{ Paths: []networkingv1.HTTPIngressPath{ - {Backend: networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}}}, + {Backend: networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}}, + }, }}}}, }, }) @@ -1373,7 +1651,7 @@ func Test_serviceHandlerForIngress(t *testing.T) { } mustCreate(t, fc, backendSvc2) wantReqs = []reconcile.Request{{NamespacedName: types.NamespacedName{Namespace: "ns-3", Name: "ing-3"}}} - gotReqs = serviceHandlerForIngress(fc, zl.Sugar())(context.Background(), backendSvc2) + gotReqs = serviceHandlerForIngress(fc, zl.Sugar(), tailscaleIngressClassName)(context.Background(), backendSvc2) if diff := cmp.Diff(gotReqs, wantReqs); diff != "" { t.Fatalf("unexpected reconcile requests (-got +want):\n%s", diff) } @@ -1388,7 +1666,8 @@ func Test_serviceHandlerForIngress(t *testing.T) { Spec: networkingv1.IngressSpec{ Rules: []networkingv1.IngressRule{{IngressRuleValue: networkingv1.IngressRuleValue{HTTP: &networkingv1.HTTPIngressRuleValue{ Paths: []networkingv1.HTTPIngressPath{ - {Backend: networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "non-ts-backend"}}}}, + {Backend: networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "non-ts-backend"}}}, + }, }}}}, }, }) @@ -1399,7 +1678,7 @@ func Test_serviceHandlerForIngress(t *testing.T) { }, } mustCreate(t, fc, nonTSBackend) - gotReqs = serviceHandlerForIngress(fc, zl.Sugar())(context.Background(), nonTSBackend) + gotReqs = serviceHandlerForIngress(fc, zl.Sugar(), tailscaleIngressClassName)(context.Background(), nonTSBackend) if len(gotReqs) > 0 { t.Errorf("unexpected reconcile request for a Service that does not belong to a Tailscale Ingress: %#+v\n", gotReqs) } @@ -1413,12 +1692,48 @@ func Test_serviceHandlerForIngress(t *testing.T) { }, } mustCreate(t, fc, someSvc) - gotReqs = serviceHandlerForIngress(fc, zl.Sugar())(context.Background(), someSvc) + gotReqs = serviceHandlerForIngress(fc, zl.Sugar(), tailscaleIngressClassName)(context.Background(), someSvc) if len(gotReqs) > 0 { t.Errorf("unexpected reconcile request for a Service that does not belong to any Ingress: %#+v\n", gotReqs) } } +func Test_serviceHandlerForIngress_multipleIngressClasses(t *testing.T) { + fc := fake.NewFakeClient() + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{Name: "backend", Namespace: "default"}, + } + mustCreate(t, fc, svc) + + mustCreate(t, fc, &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{Name: "nginx-ing", Namespace: "default"}, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("nginx"), + DefaultBackend: &networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}, + }, + }) + + mustCreate(t, fc, &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{Name: "ts-ing", Namespace: "default"}, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{Service: &networkingv1.IngressServiceBackend{Name: "backend"}}, + }, + }) + + got := serviceHandlerForIngress(fc, zl.Sugar(), "tailscale")(context.Background(), svc) + want := []reconcile.Request{{NamespacedName: types.NamespacedName{Namespace: "default", Name: "ts-ing"}}} + + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("unexpected reconcile requests (-got +want):\n%s", diff) + } +} + func Test_clusterDomainFromResolverConf(t *testing.T) { zl, err := zap.NewDevelopment() if err != nil { @@ -1487,6 +1802,7 @@ func Test_clusterDomainFromResolverConf(t *testing.T) { }) } } + func Test_authKeyRemoval(t *testing.T) { fc := fake.NewFakeClient() ft := &fakeTSClient{} @@ -1535,11 +1851,12 @@ func Test_authKeyRemoval(t *testing.T) { hostname: "default-test", clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, + replicas: ptr.To[int32](1), } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 2. Apply update to the Secret that imitates the proxy setting device_id. s := expectedSecret(t, fc, opts) @@ -1551,7 +1868,7 @@ func Test_authKeyRemoval(t *testing.T) { expectReconciled(t, sr, "default", "test") opts.shouldRemoveAuthKey = true opts.secretExtraData = map[string][]byte{"device_id": []byte("dkkdi4CNTRL")} - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedSecret(t, fc, opts)) } func Test_externalNameService(t *testing.T) { @@ -1602,6 +1919,7 @@ func Test_externalNameService(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") opts := configOpts{ + replicas: ptr.To[int32](1), stsName: shortName, secretName: fullName, namespace: "default", @@ -1611,9 +1929,9 @@ func Test_externalNameService(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) // 2. Change the ExternalName and verify that changes get propagated. mustUpdate(t, sr, "default", "test", func(s *corev1.Service) { @@ -1621,7 +1939,156 @@ func Test_externalNameService(t *testing.T) { }) expectReconciled(t, sr, "default", "test") opts.clusterTargetDNS = "bar.com" - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeResourceReqs) +} + +func Test_metricsResourceCreation(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{Name: "metrics", Generation: 1}, + Spec: tsapi.ProxyClassSpec{}, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Status: metav1.ConditionTrue, + Type: string(tsapi.ProxyClassReady), + ObservedGeneration: 1, + }}, + }, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Labels: map[string]string{LabelAnnotationProxyClass: "metrics"}, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc, svc). + WithStatusSubresource(pc). + Build() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + operatorNamespace: "operator-ns", + }, + logger: zl.Sugar(), + clock: clock, + } + expectReconciled(t, sr, "default", "test") + fullName, shortName := findGenName(t, fc, "default", "test", "svc") + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "svc", + tailscaleNamespace: "operator-ns", + hostname: "default-test", + namespaced: true, + proxyType: proxyTypeIngressService, + app: kubetypes.AppIngressProxy, + resourceVersion: "1", + } + + // 1. Enable metrics- expect metrics Service to be created + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec = tsapi.ProxyClassSpec{Metrics: &tsapi.Metrics{Enable: true}} + }) + expectReconciled(t, sr, "default", "test") + opts.enableMetrics = true + expectEqual(t, fc, expectedMetricsService(opts)) + + // 2. Enable ServiceMonitor - should not error when there is no ServiceMonitor CRD in cluster + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true} + }) + expectReconciled(t, sr, "default", "test") + + // 3. Create ServiceMonitor CRD and reconcile- ServiceMonitor should get created + mustCreate(t, fc, crd) + expectReconciled(t, sr, "default", "test") + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 4. A change to ServiceMonitor config gets reflected in the ServiceMonitor resource + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor.Labels = tsapi.Labels{"foo": "bar"} + }) + expectReconciled(t, sr, "default", "test") + opts.serviceMonitorLabels = tsapi.Labels{"foo": "bar"} + opts.resourceVersion = "2" + expectEqual(t, fc, expectedMetricsService(opts)) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 5. Disable metrics- expect metrics Service to be deleted + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics = nil + }) + expectReconciled(t, sr, "default", "test") + expectMissing[corev1.Service](t, fc, "operator-ns", metricsResourceName(opts.stsName)) + // ServiceMonitor gets garbage collected when Service gets deleted (it has OwnerReference of the Service + // object). We cannot test this using the fake client. +} + +func TestIgnorePGService(t *testing.T) { + // NOTE: creating proxygroup stuff just to be sure that it's all ignored + _, _, fc, _, _ := setupServiceTest(t) + + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + } + + // Create a service that we should manage, and check that the initial round + // of objects looks right. + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxygroup": "test-pg", + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeClusterIP, + }, + }) + + expectReconciled(t, sr, "default", "test") + + findNoGenName(t, fc, "default", "test", "svc") } func toFQDN(t *testing.T, s string) dnsname.FQDN { diff --git a/cmd/k8s-operator/proxy.go b/cmd/k8s-operator/proxy.go deleted file mode 100644 index 672f07b1f..000000000 --- a/cmd/k8s-operator/proxy.go +++ /dev/null @@ -1,421 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !plan9 - -package main - -import ( - "crypto/tls" - "fmt" - "log" - "net/http" - "net/http/httputil" - "net/netip" - "net/url" - "os" - "strings" - - "github.com/pkg/errors" - "go.uber.org/zap" - "k8s.io/client-go/rest" - "k8s.io/client-go/transport" - "tailscale.com/client/tailscale" - "tailscale.com/client/tailscale/apitype" - ksr "tailscale.com/k8s-operator/sessionrecording" - "tailscale.com/kube/kubetypes" - "tailscale.com/tailcfg" - "tailscale.com/tsnet" - "tailscale.com/util/clientmetric" - "tailscale.com/util/ctxkey" - "tailscale.com/util/set" -) - -var ( - // counterNumRequestsproxies counts the number of API server requests proxied via this proxy. - counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied") - whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil)) -) - -type apiServerProxyMode int - -func (a apiServerProxyMode) String() string { - switch a { - case apiserverProxyModeDisabled: - return "disabled" - case apiserverProxyModeEnabled: - return "auth" - case apiserverProxyModeNoAuth: - return "noauth" - default: - return "unknown" - } -} - -const ( - apiserverProxyModeDisabled apiServerProxyMode = iota - apiserverProxyModeEnabled - apiserverProxyModeNoAuth -) - -func parseAPIProxyMode() apiServerProxyMode { - haveAuthProxyEnv := os.Getenv("AUTH_PROXY") != "" - haveAPIProxyEnv := os.Getenv("APISERVER_PROXY") != "" - switch { - case haveAPIProxyEnv && haveAuthProxyEnv: - log.Fatal("AUTH_PROXY and APISERVER_PROXY are mutually exclusive") - case haveAuthProxyEnv: - var authProxyEnv = defaultBool("AUTH_PROXY", false) // deprecated - if authProxyEnv { - return apiserverProxyModeEnabled - } - return apiserverProxyModeDisabled - case haveAPIProxyEnv: - var apiProxyEnv = defaultEnv("APISERVER_PROXY", "") // true, false or "noauth" - switch apiProxyEnv { - case "true": - return apiserverProxyModeEnabled - case "false", "": - return apiserverProxyModeDisabled - case "noauth": - return apiserverProxyModeNoAuth - default: - panic(fmt.Sprintf("unknown APISERVER_PROXY value %q", apiProxyEnv)) - } - } - return apiserverProxyModeDisabled -} - -// maybeLaunchAPIServerProxy launches the auth proxy, which is a small HTTP server -// that authenticates requests using the Tailscale LocalAPI and then proxies -// them to the kube-apiserver. -func maybeLaunchAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, s *tsnet.Server, mode apiServerProxyMode) { - if mode == apiserverProxyModeDisabled { - return - } - startlog := zlog.Named("launchAPIProxy") - if mode == apiserverProxyModeNoAuth { - restConfig = rest.AnonymousClientConfig(restConfig) - } - cfg, err := restConfig.TransportConfig() - if err != nil { - startlog.Fatalf("could not get rest.TransportConfig(): %v", err) - } - - // Kubernetes uses SPDY for exec and port-forward, however SPDY is - // incompatible with HTTP/2; so disable HTTP/2 in the proxy. - tr := http.DefaultTransport.(*http.Transport).Clone() - tr.TLSClientConfig, err = transport.TLSConfigFor(cfg) - if err != nil { - startlog.Fatalf("could not get transport.TLSConfigFor(): %v", err) - } - tr.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper) - - rt, err := transport.HTTPWrappersForConfig(cfg, tr) - if err != nil { - startlog.Fatalf("could not get rest.TransportConfig(): %v", err) - } - go runAPIServerProxy(s, rt, zlog.Named("apiserver-proxy"), mode, restConfig.Host) -} - -// runAPIServerProxy runs an HTTP server that authenticates requests using the -// Tailscale LocalAPI and then proxies them to the Kubernetes API. -// It listens on :443 and uses the Tailscale HTTPS certificate. -// s will be started if it is not already running. -// rt is used to proxy requests to the Kubernetes API. -// -// mode controls how the proxy behaves: -// - apiserverProxyModeDisabled: the proxy is not started. -// - apiserverProxyModeEnabled: the proxy is started and requests are impersonated using the -// caller's identity from the Tailscale LocalAPI. -// - apiserverProxyModeNoAuth: the proxy is started and requests are not impersonated and -// are passed through to the Kubernetes API. -// -// It never returns. -func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredLogger, mode apiServerProxyMode, host string) { - if mode == apiserverProxyModeDisabled { - return - } - ln, err := ts.Listen("tcp", ":443") - if err != nil { - log.Fatalf("could not listen on :443: %v", err) - } - u, err := url.Parse(host) - if err != nil { - log.Fatalf("runAPIServerProxy: failed to parse URL %v", err) - } - - lc, err := ts.LocalClient() - if err != nil { - log.Fatalf("could not get local client: %v", err) - } - - ap := &apiserverProxy{ - log: log, - lc: lc, - mode: mode, - upstreamURL: u, - ts: ts, - } - ap.rp = &httputil.ReverseProxy{ - Rewrite: func(pr *httputil.ProxyRequest) { - ap.addImpersonationHeadersAsRequired(pr.Out) - }, - Transport: rt, - } - - mux := http.NewServeMux() - mux.HandleFunc("/", ap.serveDefault) - mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY) - mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS) - - hs := &http.Server{ - // Kubernetes uses SPDY for exec and port-forward, however SPDY is - // incompatible with HTTP/2; so disable HTTP/2 in the proxy. - TLSConfig: &tls.Config{ - GetCertificate: lc.GetCertificate, - NextProtos: []string{"http/1.1"}, - }, - TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), - Handler: mux, - } - log.Infof("API server proxy in %q mode is listening on %s", mode, ln.Addr()) - if err := hs.ServeTLS(ln, "", ""); err != nil { - log.Fatalf("runAPIServerProxy: failed to serve %v", err) - } -} - -// apiserverProxy is an [net/http.Handler] that authenticates requests using the Tailscale -// LocalAPI and then proxies them to the Kubernetes API. -type apiserverProxy struct { - log *zap.SugaredLogger - lc *tailscale.LocalClient - rp *httputil.ReverseProxy - - mode apiServerProxyMode - ts *tsnet.Server - upstreamURL *url.URL -} - -// serveDefault is the default handler for Kubernetes API server requests. -func (ap *apiserverProxy) serveDefault(w http.ResponseWriter, r *http.Request) { - who, err := ap.whoIs(r) - if err != nil { - ap.authError(w, err) - return - } - counterNumRequestsProxied.Add(1) - ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) -} - -// serveExecSPDY serves 'kubectl exec' requests for sessions streamed over SPDY, -// optionally configuring the kubectl exec sessions to be recorded. -func (ap *apiserverProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) { - ap.execForProto(w, r, ksr.SPDYProtocol) -} - -// serveExecWS serves 'kubectl exec' requests for sessions streamed over WebSocket, -// optionally configuring the kubectl exec sessions to be recorded. -func (ap *apiserverProxy) serveExecWS(w http.ResponseWriter, r *http.Request) { - ap.execForProto(w, r, ksr.WSProtocol) -} - -func (ap *apiserverProxy) execForProto(w http.ResponseWriter, r *http.Request, proto ksr.Protocol) { - const ( - podNameKey = "pod" - namespaceNameKey = "namespace" - upgradeHeaderKey = "Upgrade" - ) - - who, err := ap.whoIs(r) - if err != nil { - ap.authError(w, err) - return - } - counterNumRequestsProxied.Add(1) - failOpen, addrs, err := determineRecorderConfig(who) - if err != nil { - ap.log.Errorf("error trying to determine whether the 'kubectl exec' session needs to be recorded: %v", err) - return - } - if failOpen && len(addrs) == 0 { // will not record - ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) - return - } - ksr.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded - if !failOpen && len(addrs) == 0 { - msg := "forbidden: 'kubectl exec' session must be recorded, but no recorders are available." - ap.log.Error(msg) - http.Error(w, msg, http.StatusForbidden) - return - } - - wantsHeader := upgradeHeaderForProto[proto] - if h := r.Header.Get(upgradeHeaderKey); h != wantsHeader { - msg := fmt.Sprintf("[unexpected] unable to verify that streaming protocol is %s, wants Upgrade header %q, got: %q", proto, wantsHeader, h) - if failOpen { - msg = msg + "; failure mode is 'fail open'; continuing session without recording." - ap.log.Warn(msg) - ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) - return - } - ap.log.Error(msg) - msg += "; failure mode is 'fail closed'; closing connection." - http.Error(w, msg, http.StatusForbidden) - return - } - - opts := ksr.HijackerOpts{ - Req: r, - W: w, - Proto: proto, - TS: ap.ts, - Who: who, - Addrs: addrs, - FailOpen: failOpen, - Pod: r.PathValue(podNameKey), - Namespace: r.PathValue(namespaceNameKey), - Log: ap.log, - } - h := ksr.New(opts) - - ap.rp.ServeHTTP(h, r.WithContext(whoIsKey.WithValue(r.Context(), who))) -} - -func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) { - r.URL.Scheme = h.upstreamURL.Scheme - r.URL.Host = h.upstreamURL.Host - if h.mode == apiserverProxyModeNoAuth { - // If we are not providing authentication, then we are just - // proxying to the Kubernetes API, so we don't need to do - // anything else. - return - } - - // We want to proxy to the Kubernetes API, but we want to use - // the caller's identity to do so. We do this by impersonating - // the caller using the Kubernetes User Impersonation feature: - // https://kubernetes.io/docs/reference/access-authn-authz/authentication/#user-impersonation - - // Out of paranoia, remove all authentication headers that might - // have been set by the client. - r.Header.Del("Authorization") - r.Header.Del("Impersonate-Group") - r.Header.Del("Impersonate-User") - r.Header.Del("Impersonate-Uid") - for k := range r.Header { - if strings.HasPrefix(k, "Impersonate-Extra-") { - r.Header.Del(k) - } - } - - // Now add the impersonation headers that we want. - if err := addImpersonationHeaders(r, h.log); err != nil { - log.Printf("failed to add impersonation headers: " + err.Error()) - } -} - -func (ap *apiserverProxy) whoIs(r *http.Request) (*apitype.WhoIsResponse, error) { - return ap.lc.WhoIs(r.Context(), r.RemoteAddr) -} - -func (ap *apiserverProxy) authError(w http.ResponseWriter, err error) { - ap.log.Errorf("failed to authenticate caller: %v", err) - http.Error(w, "failed to authenticate caller", http.StatusInternalServerError) -} - -const ( - // oldCapabilityName is a legacy form of - // tailfcg.PeerCapabilityKubernetes capability. The only capability rule - // that is respected for this form is group impersonation - for - // backwards compatibility reasons. - // TODO (irbekrm): determine if anyone uses this and remove if possible. - oldCapabilityName = "https://" + tailcfg.PeerCapabilityKubernetes -) - -// addImpersonationHeaders adds the appropriate headers to r to impersonate the -// caller when proxying to the Kubernetes API. It uses the WhoIsResponse stashed -// in the context by the apiserverProxy. -func addImpersonationHeaders(r *http.Request, log *zap.SugaredLogger) error { - log = log.With("remote", r.RemoteAddr) - who := whoIsKey.Value(r.Context()) - rules, err := tailcfg.UnmarshalCapJSON[kubetypes.KubernetesCapRule](who.CapMap, tailcfg.PeerCapabilityKubernetes) - if len(rules) == 0 && err == nil { - // Try the old capability name for backwards compatibility. - rules, err = tailcfg.UnmarshalCapJSON[kubetypes.KubernetesCapRule](who.CapMap, oldCapabilityName) - } - if err != nil { - return fmt.Errorf("failed to unmarshal capability: %v", err) - } - - var groupsAdded set.Slice[string] - for _, rule := range rules { - if rule.Impersonate == nil { - continue - } - for _, group := range rule.Impersonate.Groups { - if groupsAdded.Contains(group) { - continue - } - r.Header.Add("Impersonate-Group", group) - groupsAdded.Add(group) - log.Debugf("adding group impersonation header for user group %s", group) - } - } - - if !who.Node.IsTagged() { - r.Header.Set("Impersonate-User", who.UserProfile.LoginName) - log.Debugf("adding user impersonation header for user %s", who.UserProfile.LoginName) - return nil - } - // "Impersonate-Group" requires "Impersonate-User" to be set, so we set it - // to the node FQDN for tagged nodes. - nodeName := strings.TrimSuffix(who.Node.Name, ".") - r.Header.Set("Impersonate-User", nodeName) - log.Debugf("adding user impersonation header for node name %s", nodeName) - - // For legacy behavior (before caps), set the groups to the nodes tags. - if groupsAdded.Slice().Len() == 0 { - for _, tag := range who.Node.Tags { - r.Header.Add("Impersonate-Group", tag) - log.Debugf("adding group impersonation header for node tag %s", tag) - } - } - return nil -} - -// determineRecorderConfig determines recorder config from requester's peer -// capabilities. Determines whether a 'kubectl exec' session from this requester -// needs to be recorded and what recorders the recording should be sent to. -func determineRecorderConfig(who *apitype.WhoIsResponse) (failOpen bool, recorderAddresses []netip.AddrPort, _ error) { - if who == nil { - return false, nil, errors.New("[unexpected] cannot determine caller") - } - failOpen = true - rules, err := tailcfg.UnmarshalCapJSON[kubetypes.KubernetesCapRule](who.CapMap, tailcfg.PeerCapabilityKubernetes) - if err != nil { - return failOpen, nil, fmt.Errorf("failed to unmarshal Kubernetes capability: %w", err) - } - if len(rules) == 0 { - return failOpen, nil, nil - } - - for _, rule := range rules { - if len(rule.RecorderAddrs) != 0 { - // TODO (irbekrm): here or later determine if the - // recorders behind those addrs are online - else we - // spend 30s trying to reach a recorder whose tailscale - // status is offline. - recorderAddresses = append(recorderAddresses, rule.RecorderAddrs...) - } - if rule.EnforceRecorder { - failOpen = false - } - } - return failOpen, recorderAddresses, nil -} - -var upgradeHeaderForProto = map[ksr.Protocol]string{ - ksr.SPDYProtocol: "SPDY/3.1", - ksr.WSProtocol: "websocket", -} diff --git a/cmd/k8s-operator/proxyclass.go b/cmd/k8s-operator/proxyclass.go index b5d213746..2d51b351d 100644 --- a/cmd/k8s-operator/proxyclass.go +++ b/cmd/k8s-operator/proxyclass.go @@ -15,6 +15,7 @@ import ( dockerref "github.com/distribution/reference" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" apivalidation "k8s.io/apimachinery/pkg/api/validation" @@ -43,22 +44,24 @@ const ( type ProxyClassReconciler struct { client.Client - recorder record.EventRecorder - logger *zap.SugaredLogger - clock tstime.Clock + recorder record.EventRecorder + logger *zap.SugaredLogger + clock tstime.Clock + tsNamespace string mu sync.Mutex // protects following // managedProxyClasses is a set of all ProxyClass resources that we're currently // managing. This is only used for metrics. managedProxyClasses set.Slice[types.UID] + // nodePortRange is the NodePort range set for the Kubernetes Cluster. This is used + // when validating port ranges configured by users for spec.StaticEndpoints + nodePortRange *tsapi.PortRange } -var ( - // gaugeProxyClassResources tracks the number of ProxyClass resources - // that we're currently managing. - gaugeProxyClassResources = clientmetric.NewGauge("k8s_proxyclass_resources") -) +// gaugeProxyClassResources tracks the number of ProxyClass resources +// that we're currently managing. +var gaugeProxyClassResources = clientmetric.NewGauge("k8s_proxyclass_resources") func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { logger := pcr.logger.With("ProxyClass", req.Name) @@ -95,14 +98,14 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Re pcr.mu.Unlock() oldPCStatus := pc.Status.DeepCopy() - if errs := pcr.validate(pc); errs != nil { + if errs := pcr.validate(ctx, pc, logger); errs != nil { msg := fmt.Sprintf(messageProxyClassInvalid, errs.ToAggregate().Error()) pcr.recorder.Event(pc, corev1.EventTypeWarning, reasonProxyClassInvalid, msg) - tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassready, metav1.ConditionFalse, reasonProxyClassInvalid, msg, pc.Generation, pcr.clock, logger) + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, pc.Generation, pcr.clock, logger) } else { - tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassready, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, pc.Generation, pcr.clock, logger) + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, pc.Generation, pcr.clock, logger) } - if !apiequality.Semantic.DeepEqual(oldPCStatus, pc.Status) { + if !apiequality.Semantic.DeepEqual(oldPCStatus, &pc.Status) { if err := pcr.Client.Status().Update(ctx, pc); err != nil { logger.Errorf("error updating ProxyClass status: %v", err) return reconcile.Result{}, err @@ -111,10 +114,10 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Re return reconcile.Result{}, nil } -func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations field.ErrorList) { +func (pcr *ProxyClassReconciler) validate(ctx context.Context, pc *tsapi.ProxyClass, logger *zap.SugaredLogger) (violations field.ErrorList) { if sts := pc.Spec.StatefulSet; sts != nil { if len(sts.Labels) > 0 { - if errs := metavalidation.ValidateLabels(sts.Labels, field.NewPath(".spec.statefulSet.labels")); errs != nil { + if errs := metavalidation.ValidateLabels(sts.Labels.Parse(), field.NewPath(".spec.statefulSet.labels")); errs != nil { violations = append(violations, errs...) } } @@ -125,7 +128,7 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel } if pod := sts.Pod; pod != nil { if len(pod.Labels) > 0 { - if errs := metavalidation.ValidateLabels(pod.Labels, field.NewPath(".spec.statefulSet.pod.labels")); errs != nil { + if errs := metavalidation.ValidateLabels(pod.Labels.Parse(), field.NewPath(".spec.statefulSet.pod.labels")); errs != nil { violations = append(violations, errs...) } } @@ -160,9 +163,39 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel violations = append(violations, field.TypeInvalid(field.NewPath("spec", "statefulSet", "pod", "tailscaleInitContainer", "image"), tc.Image, err.Error())) } } + + if tc.Debug != nil { + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "statefulSet", "pod", "tailscaleInitContainer", "debug"), tc.Debug, "debug settings cannot be configured on the init container")) + } } } } + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && pc.Spec.Metrics.ServiceMonitor.Enable { + found, err := hasServiceMonitorCRD(ctx, pcr.Client) + if err != nil { + pcr.logger.Infof("[unexpected]: error retrieving %q CRD: %v", serviceMonitorCRD, err) + // best effort validation - don't error out here + } else if !found { + msg := fmt.Sprintf("ProxyClass defines that a ServiceMonitor custom resource should be created, but %q CRD was not found", serviceMonitorCRD) + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "metrics", "serviceMonitor"), "enable", msg)) + } + } + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && len(pc.Spec.Metrics.ServiceMonitor.Labels) > 0 { + if errs := metavalidation.ValidateLabels(pc.Spec.Metrics.ServiceMonitor.Labels.Parse(), field.NewPath(".spec.metrics.serviceMonitor.labels")); errs != nil { + violations = append(violations, errs...) + } + } + + if stat := pc.Spec.StaticEndpoints; stat != nil { + if err := validateNodePortRanges(ctx, pcr.Client, pcr.nodePortRange, pc); err != nil { + var prs tsapi.PortRanges = stat.NodePort.Ports + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "staticEndpoints", "nodePort", "ports"), prs.String(), err.Error())) + } + + if len(stat.NodePort.Selector) < 1 { + logger.Debug("no Selectors specified on `spec.staticEndpoints.nodePort.selectors` field") + } + } // We do not validate embedded fields (security context, resource // requirements etc) as we inherit upstream validation for those fields. // Invalid values would get rejected by upstream validations at apply @@ -170,6 +203,16 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel return violations } +func hasServiceMonitorCRD(ctx context.Context, cl client.Client) (bool, error) { + sm := &apiextensionsv1.CustomResourceDefinition{} + if err := cl.Get(ctx, types.NamespacedName{Name: serviceMonitorCRD}, sm); apierrors.IsNotFound(err) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + // maybeCleanup removes tailscale.com finalizer and ensures that the ProxyClass // is no longer counted towards k8s_proxyclass_resources. func (pcr *ProxyClassReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, pc *tsapi.ProxyClass) error { diff --git a/cmd/k8s-operator/proxyclass_test.go b/cmd/k8s-operator/proxyclass_test.go index c52fbb187..ae0f63d99 100644 --- a/cmd/k8s-operator/proxyclass_test.go +++ b/cmd/k8s-operator/proxyclass_test.go @@ -8,10 +8,12 @@ package main import ( + "context" "testing" "time" "go.uber.org/zap" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" @@ -34,10 +36,10 @@ func TestProxyClass(t *testing.T) { }, Spec: tsapi.ProxyClassSpec{ StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar", "xyz1234": "abc567"}, + Labels: tsapi.Labels{"foo": "bar", "xyz1234": "abc567"}, Annotations: map[string]string{"foo.io/bar": "{'key': 'val1232'}"}, Pod: &tsapi.Pod{ - Labels: map[string]string{"foo": "bar", "xyz1234": "abc567"}, + Labels: tsapi.Labels{"foo": "bar", "xyz1234": "abc567"}, Annotations: map[string]string{"foo.io/bar": "{'key': 'val1232'}"}, TailscaleContainer: &tsapi.Container{ Env: []tsapi.Env{{Name: "FOO", Value: "BAR"}}, @@ -69,14 +71,14 @@ func TestProxyClass(t *testing.T) { // 1. A valid ProxyClass resource gets its status updated to Ready. expectReconciled(t, pcr, "", "test") pc.Status.Conditions = append(pc.Status.Conditions, metav1.Condition{ - Type: string(tsapi.ProxyClassready), + Type: string(tsapi.ProxyClassReady), Status: metav1.ConditionTrue, Reason: reasonProxyClassValid, Message: reasonProxyClassValid, LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, }) - expectEqual(t, fc, pc, nil) + expectEqual(t, fc, pc) // 2. A ProxyClass resource with invalid labels gets its status updated to Invalid with an error message. pc.Spec.StatefulSet.Labels["foo"] = "?!someVal" @@ -85,8 +87,8 @@ func TestProxyClass(t *testing.T) { }) expectReconciled(t, pcr, "", "test") msg := `ProxyClass is not valid: .spec.statefulSet.labels: Invalid value: "?!someVal": a valid label must be an empty string or consist of alphanumeric characters, '-', '_' or '.', and must start and end with an alphanumeric character (e.g. 'MyValue', or 'my_value', or '12345', regex used for validation is '(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?')` - tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassready, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, pc, nil) + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) expectedEvent := "Warning ProxyClassInvalid ProxyClass is not valid: .spec.statefulSet.labels: Invalid value: \"?!someVal\": a valid label must be an empty string or consist of alphanumeric characters, '-', '_' or '.', and must start and end with an alphanumeric character (e.g. 'MyValue', or 'my_value', or '12345', regex used for validation is '(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?')" expectEvents(t, fr, []string{expectedEvent}) @@ -99,8 +101,8 @@ func TestProxyClass(t *testing.T) { }) expectReconciled(t, pcr, "", "test") msg = `ProxyClass is not valid: spec.statefulSet.pod.tailscaleContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` - tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassready, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, pc, nil) + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) expectedEvent = `Warning ProxyClassInvalid ProxyClass is not valid: spec.statefulSet.pod.tailscaleContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` expectEvents(t, fr, []string{expectedEvent}) @@ -118,8 +120,8 @@ func TestProxyClass(t *testing.T) { }) expectReconciled(t, pcr, "", "test") msg = `ProxyClass is not valid: spec.statefulSet.pod.tailscaleInitContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` - tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassready, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, pc, nil) + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) expectedEvent = `Warning ProxyClassInvalid ProxyClass is not valid: spec.statefulSet.pod.tailscaleInitContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` expectEvents(t, fr, []string{expectedEvent}) @@ -129,9 +131,210 @@ func TestProxyClass(t *testing.T) { proxyClass.Spec.StatefulSet.Pod.TailscaleInitContainer.Image = pc.Spec.StatefulSet.Pod.TailscaleInitContainer.Image proxyClass.Spec.StatefulSet.Pod.TailscaleContainer.Env = []tsapi.Env{{Name: "TS_USERSPACE", Value: "true"}, {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH"}, {Name: "EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS"}} }) - expectedEvents := []string{"Warning CustomTSEnvVar ProxyClass overrides the default value for TS_USERSPACE env var for tailscale container. Running with custom values for Tailscale env vars is not recommended and might break in the future.", + expectedEvents := []string{ + "Warning CustomTSEnvVar ProxyClass overrides the default value for TS_USERSPACE env var for tailscale container. Running with custom values for Tailscale env vars is not recommended and might break in the future.", "Warning CustomTSEnvVar ProxyClass overrides the default value for EXPERIMENTAL_TS_CONFIGFILE_PATH env var for tailscale container. Running with custom values for Tailscale env vars is not recommended and might break in the future.", - "Warning CustomTSEnvVar ProxyClass overrides the default value for EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS env var for tailscale container. Running with custom values for Tailscale env vars is not recommended and might break in the future."} + "Warning CustomTSEnvVar ProxyClass overrides the default value for EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS env var for tailscale container. Running with custom values for Tailscale env vars is not recommended and might break in the future.", + } expectReconciled(t, pcr, "", "test") expectEvents(t, fr, expectedEvents) + + // 6. A ProxyClass with ServiceMonitor enabled and in a cluster that has not ServiceMonitor CRD is invalid + pc.Spec.Metrics = &tsapi.Metrics{Enable: true, ServiceMonitor: &tsapi.ServiceMonitor{Enable: true}} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec = pc.Spec + }) + expectReconciled(t, pcr, "", "test") + msg = `ProxyClass is not valid: spec.metrics.serviceMonitor: Invalid value: "enable": ProxyClass defines that a ServiceMonitor custom resource should be created, but "servicemonitors.monitoring.coreos.com" CRD was not found` + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) + expectedEvent = "Warning ProxyClassInvalid " + msg + expectEvents(t, fr, []string{expectedEvent}) + + // 7. A ProxyClass with ServiceMonitor enabled and in a cluster that does have the ServiceMonitor CRD is valid + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + mustCreate(t, fc, crd) + expectReconciled(t, pcr, "", "test") + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) + + // 7. A ProxyClass with invalid ServiceMonitor labels gets its status updated to Invalid with an error message. + pc.Spec.Metrics.ServiceMonitor.Labels = tsapi.Labels{"foo": "bar!"} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics.ServiceMonitor.Labels = pc.Spec.Metrics.ServiceMonitor.Labels + }) + expectReconciled(t, pcr, "", "test") + msg = `ProxyClass is not valid: .spec.metrics.serviceMonitor.labels: Invalid value: "bar!": a valid label must be an empty string or consist of alphanumeric characters, '-', '_' or '.', and must start and end with an alphanumeric character (e.g. 'MyValue', or 'my_value', or '12345', regex used for validation is '(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?')` + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) + + // 8. A ProxyClass with valid ServiceMonitor labels gets its status updated to Valid. + pc.Spec.Metrics.ServiceMonitor.Labels = tsapi.Labels{"foo": "bar", "xyz1234": "abc567", "empty": "", "onechar": "a"} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics.ServiceMonitor.Labels = pc.Spec.Metrics.ServiceMonitor.Labels + }) + expectReconciled(t, pcr, "", "test") + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) +} + +func TestValidateProxyClassStaticEndpoints(t *testing.T) { + for name, tc := range map[string]struct { + staticEndpointConfig *tsapi.StaticEndpointsConfig + valid bool + }{ + "no_static_endpoints": { + staticEndpointConfig: nil, + valid: true, + }, + "valid_specific_ports": { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3001}, + {Port: 3005}, + }, + Selector: map[string]string{"kubernetes.io/hostname": "foobar"}, + }, + }, + valid: true, + }, + "valid_port_ranges": { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3000, EndPort: 3002}, + {Port: 3005, EndPort: 3007}, + }, + Selector: map[string]string{"kubernetes.io/hostname": "foobar"}, + }, + }, + valid: true, + }, + "overlapping_port_ranges": { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 1000, EndPort: 2000}, + {Port: 1500, EndPort: 1800}, + }, + Selector: map[string]string{"kubernetes.io/hostname": "foobar"}, + }, + }, + valid: false, + }, + "clashing_port_and_range": { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3005}, + {Port: 3001, EndPort: 3010}, + }, + Selector: map[string]string{"kubernetes.io/hostname": "foobar"}, + }, + }, + valid: false, + }, + "malformed_port_range": { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3001, EndPort: 3000}, + }, + Selector: map[string]string{"kubernetes.io/hostname": "foobar"}, + }, + }, + valid: false, + }, + "empty_selector": { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{{Port: 3000}}, + Selector: map[string]string{}, + }, + }, + valid: true, + }, + } { + t.Run(name, func(t *testing.T) { + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + Build() + zl, _ := zap.NewDevelopment() + pcr := &ProxyClassReconciler{ + logger: zl.Sugar(), + Client: fc, + } + + pc := &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StaticEndpoints: tc.staticEndpointConfig, + }, + } + + logger := pcr.logger.With("ProxyClass", pc) + err := pcr.validate(context.Background(), pc, logger) + valid := err == nil + if valid != tc.valid { + t.Errorf("expected valid=%v, got valid=%v, err=%v", tc.valid, valid, err) + } + }) + } +} + +func TestValidateProxyClass(t *testing.T) { + for name, tc := range map[string]struct { + pc *tsapi.ProxyClass + valid bool + }{ + "empty": { + valid: true, + pc: &tsapi.ProxyClass{}, + }, + "debug_enabled_for_main_container": { + valid: true, + pc: &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Debug: &tsapi.Debug{ + Enable: true, + }, + }, + }, + }, + }, + }, + }, + "debug_enabled_for_init_container": { + valid: false, + pc: &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleInitContainer: &tsapi.Container{ + Debug: &tsapi.Debug{ + Enable: true, + }, + }, + }, + }, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + zl, _ := zap.NewDevelopment() + pcr := &ProxyClassReconciler{ + logger: zl.Sugar(), + } + logger := pcr.logger.With("ProxyClass", tc.pc) + err := pcr.validate(context.Background(), tc.pc, logger) + valid := err == nil + if valid != tc.valid { + t.Errorf("expected valid=%v, got valid=%v, err=%v", tc.valid, valid, err) + } + }) + } } diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go new file mode 100644 index 000000000..946e017a2 --- /dev/null +++ b/cmd/k8s-operator/proxygroup.go @@ -0,0 +1,1209 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + "slices" + "sort" + "strings" + "sync" + + dockerref "github.com/distribution/reference" + "go.uber.org/zap" + xslices "golang.org/x/exp/slices" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + + "tailscale.com/client/tailscale" + "tailscale.com/ipn" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/types/opt" + "tailscale.com/types/ptr" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + reasonProxyGroupCreationFailed = "ProxyGroupCreationFailed" + reasonProxyGroupReady = "ProxyGroupReady" + reasonProxyGroupAvailable = "ProxyGroupAvailable" + reasonProxyGroupCreating = "ProxyGroupCreating" + reasonProxyGroupInvalid = "ProxyGroupInvalid" + + // Copied from k8s.io/apiserver/pkg/registry/generic/registry/store.go@cccad306d649184bf2a0e319ba830c53f65c445c + optimisticLockErrorMsg = "the object has been modified; please apply your changes to the latest version and try again" + staticEndpointsMaxAddrs = 2 + + // The minimum tailcfg.CapabilityVersion that deployed clients are expected + // to support to be compatible with the current ProxyGroup controller. + // If the controller needs to depend on newer client behaviour, it should + // maintain backwards compatible logic for older capability versions for 3 + // stable releases, as per documentation on supported version drift: + // https://tailscale.com/kb/1236/kubernetes-operator#supported-versions + // + // tailcfg.CurrentCapabilityVersion was 106 when the ProxyGroup controller was + // first introduced. + pgMinCapabilityVersion = 106 +) + +var ( + gaugeEgressProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupEgressCount) + gaugeIngressProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupIngressCount) + gaugeAPIServerProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupAPIServerCount) +) + +// ProxyGroupReconciler ensures cluster resources for a ProxyGroup definition. +type ProxyGroupReconciler struct { + client.Client + log *zap.SugaredLogger + recorder record.EventRecorder + clock tstime.Clock + tsClient tsClient + + // User-specified defaults from the helm installation. + tsNamespace string + tsProxyImage string + k8sProxyImage string + defaultTags []string + tsFirewallMode string + defaultProxyClass string + loginServer string + + mu sync.Mutex // protects following + egressProxyGroups set.Slice[types.UID] // for egress proxygroups gauge + ingressProxyGroups set.Slice[types.UID] // for ingress proxygroups gauge + apiServerProxyGroups set.Slice[types.UID] // for kube-apiserver proxygroups gauge +} + +func (r *ProxyGroupReconciler) logger(name string) *zap.SugaredLogger { + return r.log.With("ProxyGroup", name) +} + +func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { + logger := r.logger(req.Name) + logger.Debugf("starting reconcile") + defer logger.Debugf("reconcile finished") + + pg := new(tsapi.ProxyGroup) + err = r.Get(ctx, req.NamespacedName, pg) + if apierrors.IsNotFound(err) { + logger.Debugf("ProxyGroup not found, assuming it was deleted") + return reconcile.Result{}, nil + } else if err != nil { + return reconcile.Result{}, fmt.Errorf("failed to get tailscale.com ProxyGroup: %w", err) + } + if markedForDeletion(pg) { + logger.Debugf("ProxyGroup is being deleted, cleaning up resources") + ix := xslices.Index(pg.Finalizers, FinalizerName) + if ix < 0 { + logger.Debugf("no finalizer, nothing to do") + return reconcile.Result{}, nil + } + + if done, err := r.maybeCleanup(ctx, pg); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + return reconcile.Result{}, nil + } + return reconcile.Result{}, err + } else if !done { + logger.Debugf("ProxyGroup resource cleanup not yet finished, will retry...") + return reconcile.Result{RequeueAfter: shortRequeue}, nil + } + + pg.Finalizers = slices.Delete(pg.Finalizers, ix, ix+1) + if err := r.Update(ctx, pg); err != nil { + return reconcile.Result{}, err + } + return reconcile.Result{}, nil + } + + oldPGStatus := pg.Status.DeepCopy() + staticEndpoints, nrr, err := r.reconcilePG(ctx, pg, logger) + return reconcile.Result{}, errors.Join(err, r.maybeUpdateStatus(ctx, logger, pg, oldPGStatus, nrr, staticEndpoints)) +} + +// reconcilePG handles all reconciliation of a ProxyGroup that is not marked +// for deletion. It is separated out from Reconcile to make a clear separation +// between reconciling the ProxyGroup, and posting the status of its created +// resources onto the ProxyGroup status field. +func (r *ProxyGroupReconciler) reconcilePG(ctx context.Context, pg *tsapi.ProxyGroup, logger *zap.SugaredLogger) (map[string][]netip.AddrPort, *notReadyReason, error) { + if !slices.Contains(pg.Finalizers, FinalizerName) { + // This log line is printed exactly once during initial provisioning, + // because once the finalizer is in place this block gets skipped. So, + // this is a nice place to log that the high level, multi-reconcile + // operation is underway. + logger.Infof("ensuring ProxyGroup is set up") + pg.Finalizers = append(pg.Finalizers, FinalizerName) + if err := r.Update(ctx, pg); err != nil { + return r.notReadyErrf(pg, logger, "error adding finalizer: %w", err) + } + } + + proxyClassName := r.defaultProxyClass + if pg.Spec.ProxyClass != "" { + proxyClassName = pg.Spec.ProxyClass + } + + var proxyClass *tsapi.ProxyClass + if proxyClassName != "" { + proxyClass = new(tsapi.ProxyClass) + err := r.Get(ctx, types.NamespacedName{Name: proxyClassName}, proxyClass) + if apierrors.IsNotFound(err) { + msg := fmt.Sprintf("the ProxyGroup's ProxyClass %q does not (yet) exist", proxyClassName) + logger.Info(msg) + return notReady(reasonProxyGroupCreating, msg) + } + if err != nil { + return r.notReadyErrf(pg, logger, "error getting ProxyGroup's ProxyClass %q: %w", proxyClassName, err) + } + if !tsoperator.ProxyClassIsReady(proxyClass) { + msg := fmt.Sprintf("the ProxyGroup's ProxyClass %q is not yet in a ready state, waiting...", proxyClassName) + logger.Info(msg) + return notReady(reasonProxyGroupCreating, msg) + } + } + + if err := r.validate(ctx, pg, proxyClass, logger); err != nil { + return notReady(reasonProxyGroupInvalid, fmt.Sprintf("invalid ProxyGroup spec: %v", err)) + } + + staticEndpoints, nrr, err := r.maybeProvision(ctx, pg, proxyClass) + if err != nil { + return nil, nrr, err + } + + return staticEndpoints, nrr, nil +} + +func (r *ProxyGroupReconciler) validate(ctx context.Context, pg *tsapi.ProxyGroup, pc *tsapi.ProxyClass, logger *zap.SugaredLogger) error { + // Our custom logic for ensuring minimum downtime ProxyGroup update rollouts relies on the local health check + // beig accessible on the replica Pod IP:9002. This address can also be modified by users, via + // TS_LOCAL_ADDR_PORT env var. + // + // Currently TS_LOCAL_ADDR_PORT controls Pod's health check and metrics address. _Probably_ there is no need for + // users to set this to a custom value. Users who want to consume metrics, should integrate with the metrics + // Service and/or ServiceMonitor, rather than Pods directly. The health check is likely not useful to integrate + // directly with for operator proxies (and we should aim for unified lifecycle logic in the operator, users + // shouldn't need to set their own). + // + // TODO(irbekrm): maybe disallow configuring this env var in future (in Tailscale 1.84 or later). + if pg.Spec.Type == tsapi.ProxyGroupTypeEgress && hasLocalAddrPortSet(pc) { + msg := fmt.Sprintf("ProxyClass %s applied to an egress ProxyGroup has TS_LOCAL_ADDR_PORT env var set to a custom value."+ + "This will disable the ProxyGroup graceful failover mechanism, so you might experience downtime when ProxyGroup pods are restarted."+ + "In future we will remove the ability to set custom TS_LOCAL_ADDR_PORT for egress ProxyGroups."+ + "Please raise an issue if you expect that this will cause issues for your workflow.", pc.Name) + logger.Warn(msg) + } + + // image is the value of pc.Spec.StatefulSet.Pod.TailscaleContainer.Image or "" + // imagePath is a slash-delimited path ending with the image name, e.g. + // "tailscale/tailscale" or maybe "k8s-proxy" if hosted at example.com/k8s-proxy. + var image, imagePath string + if pc != nil && + pc.Spec.StatefulSet != nil && + pc.Spec.StatefulSet.Pod != nil && + pc.Spec.StatefulSet.Pod.TailscaleContainer != nil && + pc.Spec.StatefulSet.Pod.TailscaleContainer.Image != "" { + image, err := dockerref.ParseNormalizedNamed(pc.Spec.StatefulSet.Pod.TailscaleContainer.Image) + if err != nil { + // Shouldn't be possible as the ProxyClass won't be marked ready + // without successfully parsing the image. + return fmt.Errorf("error parsing %q as a container image reference: %w", pc.Spec.StatefulSet.Pod.TailscaleContainer.Image, err) + } + imagePath = dockerref.Path(image) + } + + var errs []error + if isAuthAPIServerProxy(pg) { + // Validate that the static ServiceAccount already exists. + sa := &corev1.ServiceAccount{} + if err := r.Get(ctx, types.NamespacedName{Namespace: r.tsNamespace, Name: authAPIServerProxySAName}, sa); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("error validating that ServiceAccount %q exists: %w", authAPIServerProxySAName, err) + } + + errs = append(errs, fmt.Errorf("the ServiceAccount %q used for the API server proxy in auth mode does not exist but "+ + "should have been created during operator installation; use apiServerProxyConfig.allowImpersonation=true "+ + "in the helm chart, or authproxy-rbac.yaml from the static manifests", authAPIServerProxySAName)) + } + } else { + // Validate that the ServiceAccount we create won't overwrite the static one. + // TODO(tomhjp): This doesn't cover other controllers that could create a + // ServiceAccount. Perhaps should have some guards to ensure that an update + // would never change the ownership of a resource we expect to already be owned. + if pgServiceAccountName(pg) == authAPIServerProxySAName { + errs = append(errs, fmt.Errorf("the name of the ProxyGroup %q conflicts with the static ServiceAccount used for the API server proxy in auth mode", pg.Name)) + } + } + + if pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer { + if strings.HasSuffix(imagePath, "tailscale") { + errs = append(errs, fmt.Errorf("the configured ProxyClass %q specifies to use image %q but expected a %q image for ProxyGroup of type %q", pc.Name, image, "k8s-proxy", pg.Spec.Type)) + } + + if pc != nil && pc.Spec.StatefulSet != nil && pc.Spec.StatefulSet.Pod != nil && pc.Spec.StatefulSet.Pod.TailscaleInitContainer != nil { + errs = append(errs, fmt.Errorf("the configured ProxyClass %q specifies Tailscale init container config, but ProxyGroups of type %q do not use init containers", pc.Name, pg.Spec.Type)) + } + } else { + if strings.HasSuffix(imagePath, "k8s-proxy") { + errs = append(errs, fmt.Errorf("the configured ProxyClass %q specifies to use image %q but expected a %q image for ProxyGroup of type %q", pc.Name, image, "tailscale", pg.Spec.Type)) + } + } + + return errors.Join(errs...) +} + +func (r *ProxyGroupReconciler) maybeProvision(ctx context.Context, pg *tsapi.ProxyGroup, proxyClass *tsapi.ProxyClass) (map[string][]netip.AddrPort, *notReadyReason, error) { + logger := r.logger(pg.Name) + r.mu.Lock() + r.ensureAddedToGaugeForProxyGroup(pg) + r.mu.Unlock() + + svcToNodePorts := make(map[string]uint16) + var tailscaledPort *uint16 + if proxyClass != nil && proxyClass.Spec.StaticEndpoints != nil { + var err error + svcToNodePorts, tailscaledPort, err = r.ensureNodePortServiceCreated(ctx, pg, proxyClass) + if err != nil { + var allocatePortErr *allocatePortsErr + if errors.As(err, &allocatePortErr) { + reason := reasonProxyGroupCreationFailed + msg := fmt.Sprintf("error provisioning NodePort Services for static endpoints: %v", err) + r.recorder.Event(pg, corev1.EventTypeWarning, reason, msg) + return notReady(reason, msg) + } + return r.notReadyErrf(pg, logger, "error provisioning NodePort Services for static endpoints: %w", err) + } + } + + staticEndpoints, err := r.ensureConfigSecretsCreated(ctx, pg, proxyClass, svcToNodePorts) + if err != nil { + var selectorErr *FindStaticEndpointErr + if errors.As(err, &selectorErr) { + reason := reasonProxyGroupCreationFailed + msg := fmt.Sprintf("error provisioning config Secrets: %v", err) + r.recorder.Event(pg, corev1.EventTypeWarning, reason, msg) + return notReady(reason, msg) + } + return r.notReadyErrf(pg, logger, "error provisioning config Secrets: %w", err) + } + + // State secrets are precreated so we can use the ProxyGroup CR as their owner ref. + stateSecrets := pgStateSecrets(pg, r.tsNamespace) + for _, sec := range stateSecrets { + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sec, func(s *corev1.Secret) { + s.ObjectMeta.Labels = sec.ObjectMeta.Labels + s.ObjectMeta.Annotations = sec.ObjectMeta.Annotations + s.ObjectMeta.OwnerReferences = sec.ObjectMeta.OwnerReferences + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning state Secrets: %w", err) + } + } + + // auth mode kube-apiserver ProxyGroups use a statically created + // ServiceAccount to keep ClusterRole creation permissions limited to the + // helm chart installer. + if !isAuthAPIServerProxy(pg) { + sa := pgServiceAccount(pg, r.tsNamespace) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) { + s.ObjectMeta.Labels = sa.ObjectMeta.Labels + s.ObjectMeta.Annotations = sa.ObjectMeta.Annotations + s.ObjectMeta.OwnerReferences = sa.ObjectMeta.OwnerReferences + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning ServiceAccount: %w", err) + } + } + + role := pgRole(pg, r.tsNamespace) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { + r.ObjectMeta.Labels = role.ObjectMeta.Labels + r.ObjectMeta.Annotations = role.ObjectMeta.Annotations + r.ObjectMeta.OwnerReferences = role.ObjectMeta.OwnerReferences + r.Rules = role.Rules + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning Role: %w", err) + } + + roleBinding := pgRoleBinding(pg, r.tsNamespace) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, roleBinding, func(r *rbacv1.RoleBinding) { + r.ObjectMeta.Labels = roleBinding.ObjectMeta.Labels + r.ObjectMeta.Annotations = roleBinding.ObjectMeta.Annotations + r.ObjectMeta.OwnerReferences = roleBinding.ObjectMeta.OwnerReferences + r.RoleRef = roleBinding.RoleRef + r.Subjects = roleBinding.Subjects + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning RoleBinding: %w", err) + } + + if pg.Spec.Type == tsapi.ProxyGroupTypeEgress { + cm, hp := pgEgressCM(pg, r.tsNamespace) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, cm, func(existing *corev1.ConfigMap) { + existing.ObjectMeta.Labels = cm.ObjectMeta.Labels + existing.ObjectMeta.OwnerReferences = cm.ObjectMeta.OwnerReferences + mak.Set(&existing.BinaryData, egressservices.KeyHEPPings, hp) + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning egress ConfigMap %q: %w", cm.Name, err) + } + } + + if pg.Spec.Type == tsapi.ProxyGroupTypeIngress { + cm := pgIngressCM(pg, r.tsNamespace) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, cm, func(existing *corev1.ConfigMap) { + existing.ObjectMeta.Labels = cm.ObjectMeta.Labels + existing.ObjectMeta.OwnerReferences = cm.ObjectMeta.OwnerReferences + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning ingress ConfigMap %q: %w", cm.Name, err) + } + } + + defaultImage := r.tsProxyImage + if pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer { + defaultImage = r.k8sProxyImage + } + ss, err := pgStatefulSet(pg, r.tsNamespace, defaultImage, r.tsFirewallMode, tailscaledPort, proxyClass) + if err != nil { + return r.notReadyErrf(pg, logger, "error generating StatefulSet spec: %w", err) + } + cfg := &tailscaleSTSConfig{ + proxyType: string(pg.Spec.Type), + } + ss = applyProxyClassToStatefulSet(proxyClass, ss, cfg, logger) + + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, ss, func(s *appsv1.StatefulSet) { + s.Spec = ss.Spec + s.ObjectMeta.Labels = ss.ObjectMeta.Labels + s.ObjectMeta.Annotations = ss.ObjectMeta.Annotations + s.ObjectMeta.OwnerReferences = ss.ObjectMeta.OwnerReferences + }); err != nil { + return r.notReadyErrf(pg, logger, "error provisioning StatefulSet: %w", err) + } + + mo := &metricsOpts{ + tsNamespace: r.tsNamespace, + proxyStsName: pg.Name, + proxyLabels: pgLabels(pg.Name, nil), + proxyType: "proxygroup", + } + if err := reconcileMetricsResources(ctx, logger, mo, proxyClass, r.Client); err != nil { + return r.notReadyErrf(pg, logger, "error reconciling metrics resources: %w", err) + } + + if err := r.cleanupDanglingResources(ctx, pg, proxyClass); err != nil { + return r.notReadyErrf(pg, logger, "error cleaning up dangling resources: %w", err) + } + + logger.Info("ProxyGroup resources synced") + + return staticEndpoints, nil, nil +} + +func (r *ProxyGroupReconciler) maybeUpdateStatus(ctx context.Context, logger *zap.SugaredLogger, pg *tsapi.ProxyGroup, oldPGStatus *tsapi.ProxyGroupStatus, nrr *notReadyReason, endpoints map[string][]netip.AddrPort) (err error) { + defer func() { + if !apiequality.Semantic.DeepEqual(*oldPGStatus, pg.Status) { + if updateErr := r.Client.Status().Update(ctx, pg); updateErr != nil { + if strings.Contains(updateErr.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error updating status, retrying: %s", updateErr) + updateErr = nil + } + err = errors.Join(err, updateErr) + } + } + }() + + devices, err := r.getRunningProxies(ctx, pg, endpoints) + if err != nil { + return fmt.Errorf("failed to list running proxies: %w", err) + } + + pg.Status.Devices = devices + + desiredReplicas := int(pgReplicas(pg)) + + // Set ProxyGroupAvailable condition. + status := metav1.ConditionFalse + reason := reasonProxyGroupCreating + message := fmt.Sprintf("%d/%d ProxyGroup pods running", len(devices), desiredReplicas) + if len(devices) > 0 { + status = metav1.ConditionTrue + if len(devices) == desiredReplicas { + reason = reasonProxyGroupAvailable + } + } + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, status, reason, message, 0, r.clock, logger) + + // Set ProxyGroupReady condition. + tsSvcValid, tsSvcSet := tsoperator.KubeAPIServerProxyValid(pg) + status = metav1.ConditionFalse + reason = reasonProxyGroupCreating + switch { + case nrr != nil: + // If we failed earlier, that reason takes precedence. + reason = nrr.reason + message = nrr.message + case pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer && tsSvcSet && !tsSvcValid: + reason = reasonProxyGroupInvalid + message = "waiting for config in spec.kubeAPIServer to be marked valid" + case len(devices) < desiredReplicas: + case len(devices) > desiredReplicas: + message = fmt.Sprintf("waiting for %d ProxyGroup pods to shut down", len(devices)-desiredReplicas) + case pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer && !tsoperator.KubeAPIServerProxyConfigured(pg): + reason = reasonProxyGroupCreating + message = "waiting for proxies to start advertising the kube-apiserver proxy's hostname" + default: + status = metav1.ConditionTrue + reason = reasonProxyGroupReady + message = reasonProxyGroupReady + } + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, status, reason, message, pg.Generation, r.clock, logger) + + return nil +} + +// getServicePortsForProxyGroups returns a map of ProxyGroup Service names to their NodePorts, +// and a set of all allocated NodePorts for quick occupancy checking. +func getServicePortsForProxyGroups(ctx context.Context, c client.Client, namespace string, portRanges tsapi.PortRanges) (map[string]uint16, set.Set[uint16], error) { + svcs := new(corev1.ServiceList) + matchingLabels := client.MatchingLabels(map[string]string{ + LabelParentType: "proxygroup", + }) + + err := c.List(ctx, svcs, matchingLabels, client.InNamespace(namespace)) + if err != nil { + return nil, nil, fmt.Errorf("failed to list ProxyGroup Services: %w", err) + } + + svcToNodePorts := map[string]uint16{} + usedPorts := set.Set[uint16]{} + for _, svc := range svcs.Items { + if len(svc.Spec.Ports) == 1 && svc.Spec.Ports[0].NodePort != 0 { + p := uint16(svc.Spec.Ports[0].NodePort) + if portRanges.Contains(p) { + svcToNodePorts[svc.Name] = p + usedPorts.Add(p) + } + } + } + + return svcToNodePorts, usedPorts, nil +} + +type allocatePortsErr struct { + msg string +} + +func (e *allocatePortsErr) Error() string { + return e.msg +} + +func (r *ProxyGroupReconciler) allocatePorts(ctx context.Context, pg *tsapi.ProxyGroup, proxyClassName string, portRanges tsapi.PortRanges) (map[string]uint16, error) { + replicaCount := int(pgReplicas(pg)) + svcToNodePorts, usedPorts, err := getServicePortsForProxyGroups(ctx, r.Client, r.tsNamespace, portRanges) + if err != nil { + return nil, &allocatePortsErr{msg: fmt.Sprintf("failed to find ports for existing ProxyGroup NodePort Services: %s", err.Error())} + } + + replicasAllocated := 0 + for i := range pgReplicas(pg) { + if _, ok := svcToNodePorts[pgNodePortServiceName(pg.Name, i)]; !ok { + svcToNodePorts[pgNodePortServiceName(pg.Name, i)] = 0 + } else { + replicasAllocated++ + } + } + + for replica, port := range svcToNodePorts { + if port == 0 { + for p := range portRanges.All() { + if !usedPorts.Contains(p) { + svcToNodePorts[replica] = p + usedPorts.Add(p) + replicasAllocated++ + break + } + } + } + } + + if replicasAllocated < replicaCount { + return nil, &allocatePortsErr{msg: fmt.Sprintf("not enough available ports to allocate all replicas (needed %d, got %d). Field 'spec.staticEndpoints.nodePort.ports' on ProxyClass %q must have bigger range allocated", replicaCount, usedPorts.Len(), proxyClassName)} + } + + return svcToNodePorts, nil +} + +func (r *ProxyGroupReconciler) ensureNodePortServiceCreated(ctx context.Context, pg *tsapi.ProxyGroup, pc *tsapi.ProxyClass) (map[string]uint16, *uint16, error) { + // NOTE: (ChaosInTheCRD) we want the same TargetPort for every static endpoint NodePort Service for the ProxyGroup + tailscaledPort := getRandomPort() + svcs := []*corev1.Service{} + for i := range pgReplicas(pg) { + nodePortSvcName := pgNodePortServiceName(pg.Name, i) + + svc := &corev1.Service{} + err := r.Get(ctx, types.NamespacedName{Name: nodePortSvcName, Namespace: r.tsNamespace}, svc) + if err != nil && !apierrors.IsNotFound(err) { + return nil, nil, fmt.Errorf("error getting Kubernetes Service %q: %w", nodePortSvcName, err) + } + if apierrors.IsNotFound(err) { + svcs = append(svcs, pgNodePortService(pg, nodePortSvcName, r.tsNamespace)) + } else { + // NOTE: if we can we want to recover the random port used for tailscaled, + // as well as the NodePort previously used for that Service + if len(svc.Spec.Ports) == 1 { + if svc.Spec.Ports[0].Port != 0 { + tailscaledPort = uint16(svc.Spec.Ports[0].Port) + } + } + svcs = append(svcs, svc) + } + } + + svcToNodePorts, err := r.allocatePorts(ctx, pg, pc.Name, pc.Spec.StaticEndpoints.NodePort.Ports) + if err != nil { + return nil, nil, fmt.Errorf("failed to allocate NodePorts to ProxyGroup Services: %w", err) + } + + for _, svc := range svcs { + // NOTE: we know that every service is going to have 1 port here + svc.Spec.Ports[0].Port = int32(tailscaledPort) + svc.Spec.Ports[0].TargetPort = intstr.FromInt(int(tailscaledPort)) + svc.Spec.Ports[0].NodePort = int32(svcToNodePorts[svc.Name]) + + _, err = createOrUpdate(ctx, r.Client, r.tsNamespace, svc, func(s *corev1.Service) { + s.ObjectMeta.Labels = svc.ObjectMeta.Labels + s.ObjectMeta.Annotations = svc.ObjectMeta.Annotations + s.ObjectMeta.OwnerReferences = svc.ObjectMeta.OwnerReferences + s.Spec.Selector = svc.Spec.Selector + s.Spec.Ports = svc.Spec.Ports + }) + if err != nil { + return nil, nil, fmt.Errorf("error creating/updating Kubernetes NodePort Service %q: %w", svc.Name, err) + } + } + + return svcToNodePorts, ptr.To(tailscaledPort), nil +} + +// cleanupDanglingResources ensures we don't leak config secrets, state secrets, and +// tailnet devices when the number of replicas specified is reduced. +func (r *ProxyGroupReconciler) cleanupDanglingResources(ctx context.Context, pg *tsapi.ProxyGroup, pc *tsapi.ProxyClass) error { + logger := r.logger(pg.Name) + metadata, err := r.getNodeMetadata(ctx, pg) + if err != nil { + return err + } + + for _, m := range metadata { + if m.ordinal+1 <= int(pgReplicas(pg)) { + continue + } + + // Dangling resource, delete the config + state Secrets, as well as + // deleting the device from the tailnet. + if err := r.deleteTailnetDevice(ctx, m.tsID, logger); err != nil { + return err + } + if err := r.Delete(ctx, m.stateSecret); err != nil && !apierrors.IsNotFound(err) { + return fmt.Errorf("error deleting state Secret %q: %w", m.stateSecret.Name, err) + } + configSecret := m.stateSecret.DeepCopy() + configSecret.Name += "-config" + if err := r.Delete(ctx, configSecret); err != nil && !apierrors.IsNotFound(err) { + return fmt.Errorf("error deleting config Secret %q: %w", configSecret.Name, err) + } + // NOTE(ChaosInTheCRD): we shouldn't need to get the service first, checking for a not found error should be enough + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-nodeport", m.stateSecret.Name), + Namespace: m.stateSecret.Namespace, + }, + } + if err := r.Delete(ctx, svc); err != nil { + if !apierrors.IsNotFound(err) { + return fmt.Errorf("error deleting static endpoints Kubernetes Service %q: %w", svc.Name, err) + } + } + } + + // If the ProxyClass has its StaticEndpoints config removed, we want to remove all of the NodePort Services + if pc != nil && pc.Spec.StaticEndpoints == nil { + labels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentType: proxyTypeProxyGroup, + LabelParentName: pg.Name, + } + if err := r.DeleteAllOf(ctx, &corev1.Service{}, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting Kubernetes Services for static endpoints: %w", err) + } + } + + return nil +} + +// maybeCleanup just deletes the device from the tailnet. All the kubernetes +// resources linked to a ProxyGroup will get cleaned up via owner references +// (which we can use because they are all in the same namespace). +func (r *ProxyGroupReconciler) maybeCleanup(ctx context.Context, pg *tsapi.ProxyGroup) (bool, error) { + logger := r.logger(pg.Name) + + metadata, err := r.getNodeMetadata(ctx, pg) + if err != nil { + return false, err + } + + for _, m := range metadata { + if err := r.deleteTailnetDevice(ctx, m.tsID, logger); err != nil { + return false, err + } + } + + mo := &metricsOpts{ + proxyLabels: pgLabels(pg.Name, nil), + tsNamespace: r.tsNamespace, + proxyType: "proxygroup", + } + if err := maybeCleanupMetricsResources(ctx, mo, r.Client); err != nil { + return false, fmt.Errorf("error cleaning up metrics resources: %w", err) + } + + logger.Infof("cleaned up ProxyGroup resources") + r.mu.Lock() + r.ensureRemovedFromGaugeForProxyGroup(pg) + r.mu.Unlock() + return true, nil +} + +func (r *ProxyGroupReconciler) deleteTailnetDevice(ctx context.Context, id tailcfg.StableNodeID, logger *zap.SugaredLogger) error { + logger.Debugf("deleting device %s from control", string(id)) + if err := r.tsClient.DeleteDevice(ctx, string(id)); err != nil { + errResp := &tailscale.ErrResponse{} + if ok := errors.As(err, errResp); ok && errResp.Status == http.StatusNotFound { + logger.Debugf("device %s not found, likely because it has already been deleted from control", string(id)) + } else { + return fmt.Errorf("error deleting device: %w", err) + } + } else { + logger.Debugf("device %s deleted from control", string(id)) + } + + return nil +} + +func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, pg *tsapi.ProxyGroup, proxyClass *tsapi.ProxyClass, svcToNodePorts map[string]uint16) (endpoints map[string][]netip.AddrPort, err error) { + logger := r.logger(pg.Name) + endpoints = make(map[string][]netip.AddrPort, pgReplicas(pg)) // keyed by Service name. + for i := range pgReplicas(pg) { + cfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pg.Name, i), + Namespace: r.tsNamespace, + Labels: pgSecretLabels(pg.Name, kubetypes.LabelSecretTypeConfig), + OwnerReferences: pgOwnerReference(pg), + }, + } + + var existingCfgSecret *corev1.Secret // unmodified copy of secret + if err := r.Get(ctx, client.ObjectKeyFromObject(cfgSecret), cfgSecret); err == nil { + logger.Debugf("Secret %s/%s already exists", cfgSecret.GetNamespace(), cfgSecret.GetName()) + existingCfgSecret = cfgSecret.DeepCopy() + } else if !apierrors.IsNotFound(err) { + return nil, err + } + + var authKey *string + if existingCfgSecret == nil { + logger.Debugf("Creating authkey for new ProxyGroup proxy") + tags := pg.Spec.Tags.Stringify() + if len(tags) == 0 { + tags = r.defaultTags + } + key, err := newAuthKey(ctx, r.tsClient, tags) + if err != nil { + return nil, err + } + authKey = &key + } + + if authKey == nil { + // Get state Secret to check if it's already authed. + stateSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgStateSecretName(pg.Name, i), + Namespace: r.tsNamespace, + }, + } + if err := r.Get(ctx, client.ObjectKeyFromObject(stateSecret), stateSecret); err != nil && !apierrors.IsNotFound(err) { + return nil, err + } + + if shouldRetainAuthKey(stateSecret) && existingCfgSecret != nil { + authKey, err = authKeyFromSecret(existingCfgSecret) + if err != nil { + return nil, fmt.Errorf("error retrieving auth key from existing config Secret: %w", err) + } + } + } + + nodePortSvcName := pgNodePortServiceName(pg.Name, i) + if len(svcToNodePorts) > 0 { + replicaName := fmt.Sprintf("%s-%d", pg.Name, i) + port, ok := svcToNodePorts[nodePortSvcName] + if !ok { + return nil, fmt.Errorf("could not find configured NodePort for ProxyGroup replica %q", replicaName) + } + + endpoints[nodePortSvcName], err = r.findStaticEndpoints(ctx, existingCfgSecret, proxyClass, port, logger) + if err != nil { + return nil, fmt.Errorf("could not find static endpoints for replica %q: %w", replicaName, err) + } + } + + if pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer { + hostname := pgHostname(pg, i) + + if authKey == nil && existingCfgSecret != nil { + deviceAuthed := false + for _, d := range pg.Status.Devices { + if d.Hostname == hostname { + deviceAuthed = true + break + } + } + if !deviceAuthed { + existingCfg := conf.ConfigV1Alpha1{} + if err := json.Unmarshal(existingCfgSecret.Data[kubetypes.KubeAPIServerConfigFile], &existingCfg); err != nil { + return nil, fmt.Errorf("error unmarshalling existing config: %w", err) + } + if existingCfg.AuthKey != nil { + authKey = existingCfg.AuthKey + } + } + } + + mode := kubetypes.APIServerProxyModeAuth + if !isAuthAPIServerProxy(pg) { + mode = kubetypes.APIServerProxyModeNoAuth + } + cfg := conf.VersionedConfig{ + Version: "v1alpha1", + ConfigV1Alpha1: &conf.ConfigV1Alpha1{ + AuthKey: authKey, + State: ptr.To(fmt.Sprintf("kube:%s", pgPodName(pg.Name, i))), + App: ptr.To(kubetypes.AppProxyGroupKubeAPIServer), + LogLevel: ptr.To(logger.Level().String()), + + // Reloadable fields. + Hostname: &hostname, + APIServerProxy: &conf.APIServerProxyConfig{ + Enabled: opt.NewBool(true), + Mode: &mode, + // The first replica is elected as the cert issuer, same + // as containerboot does for ingress-pg-reconciler. + IssueCerts: opt.NewBool(i == 0), + }, + LocalPort: ptr.To(uint16(9002)), + HealthCheckEnabled: opt.NewBool(true), + }, + } + + // Copy over config that the apiserver-proxy-service-reconciler sets. + if existingCfgSecret != nil { + if k8sProxyCfg, ok := cfgSecret.Data[kubetypes.KubeAPIServerConfigFile]; ok { + k8sCfg := &conf.ConfigV1Alpha1{} + if err := json.Unmarshal(k8sProxyCfg, k8sCfg); err != nil { + return nil, fmt.Errorf("failed to unmarshal kube-apiserver config: %w", err) + } + + cfg.AdvertiseServices = k8sCfg.AdvertiseServices + if k8sCfg.APIServerProxy != nil { + cfg.APIServerProxy.ServiceName = k8sCfg.APIServerProxy.ServiceName + } + } + } + + if r.loginServer != "" { + cfg.ServerURL = &r.loginServer + } + + if proxyClass != nil && proxyClass.Spec.TailscaleConfig != nil { + cfg.AcceptRoutes = opt.NewBool(proxyClass.Spec.TailscaleConfig.AcceptRoutes) + } + + if proxyClass != nil && proxyClass.Spec.Metrics != nil { + cfg.MetricsEnabled = opt.NewBool(proxyClass.Spec.Metrics.Enable) + } + + if len(endpoints[nodePortSvcName]) > 0 { + cfg.StaticEndpoints = endpoints[nodePortSvcName] + } + + cfgB, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("error marshalling k8s-proxy config: %w", err) + } + mak.Set(&cfgSecret.Data, kubetypes.KubeAPIServerConfigFile, cfgB) + } else { + // AdvertiseServices config is set by ingress-pg-reconciler, so make sure we + // don't overwrite it if already set. + existingAdvertiseServices, err := extractAdvertiseServicesConfig(existingCfgSecret) + if err != nil { + return nil, err + } + + configs, err := pgTailscaledConfig(pg, proxyClass, i, authKey, endpoints[nodePortSvcName], existingAdvertiseServices, r.loginServer) + if err != nil { + return nil, fmt.Errorf("error creating tailscaled config: %w", err) + } + + for cap, cfg := range configs { + cfgJSON, err := json.Marshal(cfg) + if err != nil { + return nil, fmt.Errorf("error marshalling tailscaled config: %w", err) + } + mak.Set(&cfgSecret.Data, tsoperator.TailscaledConfigFileName(cap), cfgJSON) + } + } + + if existingCfgSecret != nil { + if !apiequality.Semantic.DeepEqual(existingCfgSecret, cfgSecret) { + logger.Debugf("Updating the existing ProxyGroup config Secret %s", cfgSecret.Name) + if err := r.Update(ctx, cfgSecret); err != nil { + return nil, err + } + } + } else { + logger.Debugf("Creating a new config Secret %s for the ProxyGroup", cfgSecret.Name) + if err := r.Create(ctx, cfgSecret); err != nil { + return nil, err + } + } + } + + return endpoints, nil +} + +type FindStaticEndpointErr struct { + msg string +} + +func (e *FindStaticEndpointErr) Error() string { + return e.msg +} + +// findStaticEndpoints returns up to two `netip.AddrPort` entries, derived from the ExternalIPs of Nodes that +// match the `proxyClass`'s selector within the StaticEndpoints configuration. The port is set to the replica's NodePort Service Port. +func (r *ProxyGroupReconciler) findStaticEndpoints(ctx context.Context, existingCfgSecret *corev1.Secret, proxyClass *tsapi.ProxyClass, port uint16, logger *zap.SugaredLogger) ([]netip.AddrPort, error) { + var currAddrs []netip.AddrPort + if existingCfgSecret != nil { + oldConfB := existingCfgSecret.Data[tsoperator.TailscaledConfigFileName(106)] + if len(oldConfB) > 0 { + var oldConf ipn.ConfigVAlpha + if err := json.Unmarshal(oldConfB, &oldConf); err == nil { + currAddrs = oldConf.StaticEndpoints + } else { + logger.Debugf("failed to unmarshal tailscaled config from secret %q: %v", existingCfgSecret.Name, err) + } + } else { + logger.Debugf("failed to get tailscaled config from secret %q: empty data", existingCfgSecret.Name) + } + } + + nodes := new(corev1.NodeList) + selectors := client.MatchingLabels(proxyClass.Spec.StaticEndpoints.NodePort.Selector) + + err := r.List(ctx, nodes, selectors) + if err != nil { + return nil, fmt.Errorf("failed to list nodes: %w", err) + } + + if len(nodes.Items) == 0 { + return nil, &FindStaticEndpointErr{msg: fmt.Sprintf("failed to match nodes to configured Selectors on `spec.staticEndpoints.nodePort.selectors` field for ProxyClass %q", proxyClass.Name)} + } + + endpoints := []netip.AddrPort{} + + // NOTE(ChaosInTheCRD): Setting a hard limit of two static endpoints. + newAddrs := []netip.AddrPort{} + for _, n := range nodes.Items { + for _, a := range n.Status.Addresses { + if a.Type == corev1.NodeExternalIP { + addr := getStaticEndpointAddress(&a, port) + if addr == nil { + logger.Debugf("failed to parse %q address on node %q: %q", corev1.NodeExternalIP, n.Name, a.Address) + continue + } + + // we want to add the currently used IPs first before + // adding new ones. + if currAddrs != nil && slices.Contains(currAddrs, *addr) { + endpoints = append(endpoints, *addr) + } else { + newAddrs = append(newAddrs, *addr) + } + } + + if len(endpoints) == 2 { + break + } + } + } + + // if the 2 endpoints limit hasn't been reached, we + // can start adding newIPs. + if len(endpoints) < 2 { + for _, a := range newAddrs { + endpoints = append(endpoints, a) + if len(endpoints) == 2 { + break + } + } + } + + if len(endpoints) == 0 { + return nil, &FindStaticEndpointErr{msg: fmt.Sprintf("failed to find any `status.addresses` of type %q on nodes using configured Selectors on `spec.staticEndpoints.nodePort.selectors` for ProxyClass %q", corev1.NodeExternalIP, proxyClass.Name)} + } + + return endpoints, nil +} + +func getStaticEndpointAddress(a *corev1.NodeAddress, port uint16) *netip.AddrPort { + addr, err := netip.ParseAddr(a.Address) + if err != nil { + return nil + } + + return ptr.To(netip.AddrPortFrom(addr, port)) +} + +// ensureAddedToGaugeForProxyGroup ensures the gauge metric for the ProxyGroup resource is updated when the ProxyGroup +// is created. r.mu must be held. +func (r *ProxyGroupReconciler) ensureAddedToGaugeForProxyGroup(pg *tsapi.ProxyGroup) { + switch pg.Spec.Type { + case tsapi.ProxyGroupTypeEgress: + r.egressProxyGroups.Add(pg.UID) + case tsapi.ProxyGroupTypeIngress: + r.ingressProxyGroups.Add(pg.UID) + case tsapi.ProxyGroupTypeKubernetesAPIServer: + r.apiServerProxyGroups.Add(pg.UID) + } + gaugeEgressProxyGroupResources.Set(int64(r.egressProxyGroups.Len())) + gaugeIngressProxyGroupResources.Set(int64(r.ingressProxyGroups.Len())) + gaugeAPIServerProxyGroupResources.Set(int64(r.apiServerProxyGroups.Len())) +} + +// ensureRemovedFromGaugeForProxyGroup ensures the gauge metric for the ProxyGroup resource type is updated when the +// ProxyGroup is deleted. r.mu must be held. +func (r *ProxyGroupReconciler) ensureRemovedFromGaugeForProxyGroup(pg *tsapi.ProxyGroup) { + switch pg.Spec.Type { + case tsapi.ProxyGroupTypeEgress: + r.egressProxyGroups.Remove(pg.UID) + case tsapi.ProxyGroupTypeIngress: + r.ingressProxyGroups.Remove(pg.UID) + case tsapi.ProxyGroupTypeKubernetesAPIServer: + r.apiServerProxyGroups.Remove(pg.UID) + } + gaugeEgressProxyGroupResources.Set(int64(r.egressProxyGroups.Len())) + gaugeIngressProxyGroupResources.Set(int64(r.ingressProxyGroups.Len())) + gaugeAPIServerProxyGroupResources.Set(int64(r.apiServerProxyGroups.Len())) +} + +func pgTailscaledConfig(pg *tsapi.ProxyGroup, pc *tsapi.ProxyClass, idx int32, authKey *string, staticEndpoints []netip.AddrPort, oldAdvertiseServices []string, loginServer string) (tailscaledConfigs, error) { + conf := &ipn.ConfigVAlpha{ + Version: "alpha0", + AcceptDNS: "false", + AcceptRoutes: "false", // AcceptRoutes defaults to true + Locked: "false", + Hostname: ptr.To(pgHostname(pg, idx)), + AdvertiseServices: oldAdvertiseServices, + AuthKey: authKey, + } + + if loginServer != "" { + conf.ServerURL = &loginServer + } + + if shouldAcceptRoutes(pc) { + conf.AcceptRoutes = "true" + } + + if len(staticEndpoints) > 0 { + conf.StaticEndpoints = staticEndpoints + } + + return map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha{ + pgMinCapabilityVersion: *conf, + }, nil +} + +func extractAdvertiseServicesConfig(cfgSecret *corev1.Secret) ([]string, error) { + if cfgSecret == nil { + return nil, nil + } + + cfg, err := latestConfigFromSecret(cfgSecret) + if err != nil { + return nil, err + } + + if cfg == nil { + return nil, nil + } + + return cfg.AdvertiseServices, nil +} + +// getNodeMetadata gets metadata for all the pods owned by this ProxyGroup by +// querying their state Secrets. It may not return the same number of items as +// specified in the ProxyGroup spec if e.g. it is getting scaled up or down, or +// some pods have failed to write state. +// +// The returned metadata will contain an entry for each state Secret that exists. +func (r *ProxyGroupReconciler) getNodeMetadata(ctx context.Context, pg *tsapi.ProxyGroup) (metadata []nodeMetadata, _ error) { + // List all state Secrets owned by this ProxyGroup. + secrets := &corev1.SecretList{} + if err := r.List(ctx, secrets, client.InNamespace(r.tsNamespace), client.MatchingLabels(pgSecretLabels(pg.Name, kubetypes.LabelSecretTypeState))); err != nil { + return nil, fmt.Errorf("failed to list state Secrets: %w", err) + } + for _, secret := range secrets.Items { + var ordinal int + if _, err := fmt.Sscanf(secret.Name, pg.Name+"-%d", &ordinal); err != nil { + return nil, fmt.Errorf("unexpected secret %s was labelled as owned by the ProxyGroup %s: %w", secret.Name, pg.Name, err) + } + + nm := nodeMetadata{ + ordinal: ordinal, + stateSecret: &secret, + } + + prefs, ok, err := getDevicePrefs(&secret) + if err != nil { + return nil, err + } + if ok { + nm.tsID = prefs.Config.NodeID + nm.dnsName = prefs.Config.UserProfile.LoginName + } + + pod := &corev1.Pod{} + if err := r.Get(ctx, client.ObjectKey{Namespace: r.tsNamespace, Name: fmt.Sprintf("%s-%d", pg.Name, ordinal)}, pod); err != nil && !apierrors.IsNotFound(err) { + return nil, err + } else if err == nil { + nm.podUID = string(pod.UID) + } + metadata = append(metadata, nm) + } + + // Sort for predictable ordering and status. + sort.Slice(metadata, func(i, j int) bool { + return metadata[i].ordinal < metadata[j].ordinal + }) + + return metadata, nil +} + +// getRunningProxies will return status for all proxy Pods whose state Secret +// has an up to date Pod UID and at least a hostname. +func (r *ProxyGroupReconciler) getRunningProxies(ctx context.Context, pg *tsapi.ProxyGroup, staticEndpoints map[string][]netip.AddrPort) (devices []tsapi.TailnetDevice, _ error) { + metadata, err := r.getNodeMetadata(ctx, pg) + if err != nil { + return nil, err + } + + for i, m := range metadata { + if m.podUID == "" || !strings.EqualFold(string(m.stateSecret.Data[kubetypes.KeyPodUID]), m.podUID) { + // Current Pod has not yet written its UID to the state Secret, data may + // be stale. + continue + } + + device := tsapi.TailnetDevice{} + if hostname, _, ok := strings.Cut(string(m.stateSecret.Data[kubetypes.KeyDeviceFQDN]), "."); ok { + device.Hostname = hostname + } else { + continue + } + + if ipsB := m.stateSecret.Data[kubetypes.KeyDeviceIPs]; len(ipsB) > 0 { + ips := []string{} + if err := json.Unmarshal(ipsB, &ips); err != nil { + return nil, fmt.Errorf("failed to extract device IPs from state Secret %q: %w", m.stateSecret.Name, err) + } + device.TailnetIPs = ips + } + + // TODO(tomhjp): This is our input to the proxy, but we should instead + // read this back from the proxy's state in some way to more accurately + // reflect its status. + if ep, ok := staticEndpoints[pgNodePortServiceName(pg.Name, int32(i))]; ok && len(ep) > 0 { + eps := make([]string, 0, len(ep)) + for _, e := range ep { + eps = append(eps, e.String()) + } + device.StaticEndpoints = eps + } + + devices = append(devices, device) + } + + return devices, nil +} + +type nodeMetadata struct { + ordinal int + stateSecret *corev1.Secret + podUID string // or empty if the Pod no longer exists. + tsID tailcfg.StableNodeID + dnsName string +} + +func notReady(reason, msg string) (map[string][]netip.AddrPort, *notReadyReason, error) { + return nil, ¬ReadyReason{ + reason: reason, + message: msg, + }, nil +} + +func (r *ProxyGroupReconciler) notReadyErrf(pg *tsapi.ProxyGroup, logger *zap.SugaredLogger, format string, a ...any) (map[string][]netip.AddrPort, *notReadyReason, error) { + err := fmt.Errorf(format, a...) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + msg := fmt.Sprintf("optimistic lock error, retrying: %s", err.Error()) + logger.Info(msg) + return notReady(reasonProxyGroupCreating, msg) + } + + r.recorder.Event(pg, corev1.EventTypeWarning, reasonProxyGroupCreationFailed, err.Error()) + return nil, ¬ReadyReason{ + reason: reasonProxyGroupCreationFailed, + message: err.Error(), + }, err +} + +type notReadyReason struct { + reason string + message string +} diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go new file mode 100644 index 000000000..34db86db2 --- /dev/null +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -0,0 +1,593 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "fmt" + "slices" + "strconv" + "strings" + + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + "sigs.k8s.io/yaml" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubetypes" + "tailscale.com/types/ptr" +) + +const ( + // deletionGracePeriodSeconds is set to 6 minutes to ensure that the pre-stop hook of these proxies have enough chance to terminate gracefully. + deletionGracePeriodSeconds int64 = 360 + staticEndpointPortName = "static-endpoint-port" + // authAPIServerProxySAName is the ServiceAccount deployed by the helm chart + // if apiServerProxy.authEnabled is true. + authAPIServerProxySAName = "kube-apiserver-auth-proxy" +) + +func pgNodePortServiceName(proxyGroupName string, replica int32) string { + return fmt.Sprintf("%s-%d-nodeport", proxyGroupName, replica) +} + +func pgNodePortService(pg *tsapi.ProxyGroup, name string, namespace string) *corev1.Service { + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeNodePort, + Ports: []corev1.ServicePort{ + // NOTE(ChaosInTheCRD): we set the ports once we've iterated over every svc and found any old configuration we want to persist. + { + Name: staticEndpointPortName, + Protocol: corev1.ProtocolUDP, + }, + }, + Selector: map[string]string{ + appsv1.StatefulSetPodNameLabel: strings.TrimSuffix(name, "-nodeport"), + }, + }, + } +} + +// Returns the base StatefulSet definition for a ProxyGroup. A ProxyClass may be +// applied over the top after. +func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode string, port *uint16, proxyClass *tsapi.ProxyClass) (*appsv1.StatefulSet, error) { + if pg.Spec.Type == tsapi.ProxyGroupTypeKubernetesAPIServer { + return kubeAPIServerStatefulSet(pg, namespace, image, port) + } + ss := new(appsv1.StatefulSet) + if err := yaml.Unmarshal(proxyYaml, &ss); err != nil { + return nil, fmt.Errorf("failed to unmarshal proxy spec: %w", err) + } + // Validate some base assumptions. + if len(ss.Spec.Template.Spec.InitContainers) != 1 { + return nil, fmt.Errorf("[unexpected] base proxy config had %d init containers instead of 1", len(ss.Spec.Template.Spec.InitContainers)) + } + if len(ss.Spec.Template.Spec.Containers) != 1 { + return nil, fmt.Errorf("[unexpected] base proxy config had %d containers instead of 1", len(ss.Spec.Template.Spec.Containers)) + } + + // StatefulSet config. + ss.ObjectMeta = metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + } + ss.Spec.Replicas = ptr.To(pgReplicas(pg)) + ss.Spec.Selector = &metav1.LabelSelector{ + MatchLabels: pgLabels(pg.Name, nil), + } + + // Template config. + tmpl := &ss.Spec.Template + tmpl.ObjectMeta = metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + DeletionGracePeriodSeconds: ptr.To[int64](10), + } + tmpl.Spec.ServiceAccountName = pg.Name + tmpl.Spec.InitContainers[0].Image = image + proxyConfigVolName := pgEgressCMName(pg.Name) + if pg.Spec.Type == tsapi.ProxyGroupTypeIngress { + proxyConfigVolName = pgIngressCMName(pg.Name) + } + tmpl.Spec.Volumes = func() []corev1.Volume { + var volumes []corev1.Volume + for i := range pgReplicas(pg) { + volumes = append(volumes, corev1.Volume{ + Name: fmt.Sprintf("tailscaledconfig-%d", i), + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: pgConfigSecretName(pg.Name, i), + }, + }, + }) + } + + volumes = append(volumes, corev1.Volume{ + Name: proxyConfigVolName, + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: proxyConfigVolName, + }, + }, + }, + }) + + return volumes + }() + + // Main container config. + c := &ss.Spec.Template.Spec.Containers[0] + c.Image = image + c.VolumeMounts = func() []corev1.VolumeMount { + var mounts []corev1.VolumeMount + + // TODO(tomhjp): Read config directly from the secret instead. The + // mounts change on scaling up/down which causes unnecessary restarts + // for pods that haven't meaningfully changed. + for i := range pgReplicas(pg) { + mounts = append(mounts, corev1.VolumeMount{ + Name: fmt.Sprintf("tailscaledconfig-%d", i), + ReadOnly: true, + MountPath: fmt.Sprintf("/etc/tsconfig/%s-%d", pg.Name, i), + }) + } + + mounts = append(mounts, corev1.VolumeMount{ + Name: proxyConfigVolName, + MountPath: "/etc/proxies", + ReadOnly: true, + }) + + return mounts + }() + c.Env = func() []corev1.EnvVar { + envs := []corev1.EnvVar{ + { + // TODO(irbekrm): verify that .status.podIPs are always set, else read in .status.podIP as well. + Name: "POD_IPS", // this will be a comma separate list i.e 10.136.0.6,2600:1900:4011:161:0:e:0:6 + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "status.podIPs", + }, + }, + }, + { + Name: "TS_KUBE_SECRET", + Value: "$(POD_NAME)", + }, + { + // TODO(tomhjp): This is tsrecorder-specific and does nothing. Delete. + Name: "TS_STATE", + Value: "kube:$(POD_NAME)", + }, + { + Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", + Value: "/etc/tsconfig/$(POD_NAME)", + }, + } + + if port != nil { + envs = append(envs, corev1.EnvVar{ + Name: "PORT", + Value: strconv.Itoa(int(*port)), + }) + } + + if tsFirewallMode != "" { + envs = append(envs, corev1.EnvVar{ + Name: "TS_DEBUG_FIREWALL_MODE", + Value: tsFirewallMode, + }) + } + + if pg.Spec.Type == tsapi.ProxyGroupTypeEgress { + envs = append(envs, + // TODO(irbekrm): in 1.80 we deprecated TS_EGRESS_SERVICES_CONFIG_PATH in favour of + // TS_EGRESS_PROXIES_CONFIG_PATH. Remove it in 1.84. + corev1.EnvVar{ + Name: "TS_EGRESS_SERVICES_CONFIG_PATH", + Value: fmt.Sprintf("/etc/proxies/%s", egressservices.KeyEgressServices), + }, + corev1.EnvVar{ + Name: "TS_EGRESS_PROXIES_CONFIG_PATH", + Value: "/etc/proxies", + }, + corev1.EnvVar{ + Name: "TS_INTERNAL_APP", + Value: kubetypes.AppProxyGroupEgress, + }, + corev1.EnvVar{ + Name: "TS_ENABLE_HEALTH_CHECK", + Value: "true", + }) + } else { // ingress + envs = append(envs, corev1.EnvVar{ + Name: "TS_INTERNAL_APP", + Value: kubetypes.AppProxyGroupIngress, + }, + corev1.EnvVar{ + Name: "TS_INGRESS_PROXIES_CONFIG_PATH", + Value: fmt.Sprintf("/etc/proxies/%s", ingressservices.IngressConfigKey), + }, + corev1.EnvVar{ + Name: "TS_SERVE_CONFIG", + Value: fmt.Sprintf("/etc/proxies/%s", serveConfigKey), + }, + corev1.EnvVar{ + // Run proxies in cert share mode to + // ensure that only one TLS cert is + // issued for an HA Ingress. + Name: "TS_EXPERIMENTAL_CERT_SHARE", + Value: "true", + }, + ) + } + return append(c.Env, envs...) + }() + + // The pre-stop hook is used to ensure that a replica does not get terminated while cluster traffic for egress + // services is still being routed to it. + // + // This mechanism currently (2025-01-26) rely on the local health check being accessible on the Pod's + // IP, so they are not supported for ProxyGroups where users have configured TS_LOCAL_ADDR_PORT to a custom + // value. + // + // NB: For _Ingress_ ProxyGroups, we run shutdown logic within containerboot + // in reaction to a SIGTERM signal instead of using a pre-stop hook. This is + // because Ingress pods need to unadvertise services, and it's preferable to + // avoid triggering those side-effects from a GET request that would be + // accessible to the whole cluster network (in the absence of NetworkPolicy + // rules). + // + // TODO(tomhjp): add a readiness probe or gate to Ingress Pods. There is a + // small window where the Pod is marked ready but routing can still fail. + if pg.Spec.Type == tsapi.ProxyGroupTypeEgress && !hasLocalAddrPortSet(proxyClass) { + c.Lifecycle = &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: kubetypes.EgessServicesPreshutdownEP, + Port: intstr.FromInt(defaultLocalAddrPort), + }, + }, + } + // Set the deletion grace period to 6 minutes to ensure that the pre-stop hook has enough time to terminate + // gracefully. + ss.Spec.Template.DeletionGracePeriodSeconds = ptr.To(deletionGracePeriodSeconds) + } + + return ss, nil +} + +func kubeAPIServerStatefulSet(pg *tsapi.ProxyGroup, namespace, image string, port *uint16) (*appsv1.StatefulSet, error) { + sts := &appsv1.StatefulSet{ + ObjectMeta: metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + Spec: appsv1.StatefulSetSpec{ + Replicas: ptr.To(pgReplicas(pg)), + Selector: &metav1.LabelSelector{ + MatchLabels: pgLabels(pg.Name, nil), + }, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + DeletionGracePeriodSeconds: ptr.To[int64](10), + }, + Spec: corev1.PodSpec{ + ServiceAccountName: pgServiceAccountName(pg), + Containers: []corev1.Container{ + { + Name: mainContainerName, + Image: image, + Env: func() []corev1.EnvVar { + envs := []corev1.EnvVar{ + { + // Used as default hostname and in Secret names. + Name: "POD_NAME", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.name", + }, + }, + }, + { + // Used by kubeclient to post Events about the Pod's lifecycle. + Name: "POD_UID", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.uid", + }, + }, + }, + { + // Used in an interpolated env var if metrics enabled. + Name: "POD_IP", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "status.podIP", + }, + }, + }, + { + // Included for completeness with POD_IP and easier backwards compatibility in future. + Name: "POD_IPS", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "status.podIPs", + }, + }, + }, + { + Name: "TS_K8S_PROXY_CONFIG", + Value: "kube:" + types.NamespacedName{ + Namespace: namespace, + Name: "$(POD_NAME)-config", + }.String(), + }, + } + + if port != nil { + envs = append(envs, corev1.EnvVar{ + Name: "PORT", + Value: strconv.Itoa(int(*port)), + }) + } + + return envs + }(), + Ports: []corev1.ContainerPort{ + { + Name: "k8s-proxy", + ContainerPort: 443, + Protocol: corev1.ProtocolTCP, + }, + }, + }, + }, + }, + }, + }, + } + + return sts, nil +} + +func pgServiceAccount(pg *tsapi.ProxyGroup, namespace string) *corev1.ServiceAccount { + return &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + } +} + +func pgRole(pg *tsapi.ProxyGroup, namespace string) *rbacv1.Role { + return &rbacv1.Role{ + ObjectMeta: metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + Rules: []rbacv1.PolicyRule{ + { + APIGroups: []string{""}, + Resources: []string{"secrets"}, + Verbs: []string{ + "list", + "watch", // For k8s-proxy. + }, + }, + { + APIGroups: []string{""}, + Resources: []string{"secrets"}, + Verbs: []string{ + "get", + "patch", + "update", + }, + ResourceNames: func() (secrets []string) { + for i := range pgReplicas(pg) { + secrets = append(secrets, + pgConfigSecretName(pg.Name, i), // Config with auth key. + pgPodName(pg.Name, i), // State. + ) + } + return secrets + }(), + }, + { + APIGroups: []string{""}, + Resources: []string{"events"}, + Verbs: []string{ + "create", + "patch", + "get", + }, + }, + }, + } +} + +func pgRoleBinding(pg *tsapi.ProxyGroup, namespace string) *rbacv1.RoleBinding { + return &rbacv1.RoleBinding{ + ObjectMeta: metav1.ObjectMeta{ + Name: pg.Name, + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + Subjects: []rbacv1.Subject{ + { + Kind: "ServiceAccount", + Name: pgServiceAccountName(pg), + Namespace: namespace, + }, + }, + RoleRef: rbacv1.RoleRef{ + Kind: "Role", + Name: pg.Name, + }, + } +} + +// kube-apiserver proxies in auth mode use a static ServiceAccount. Everything +// else uses a per-ProxyGroup ServiceAccount. +func pgServiceAccountName(pg *tsapi.ProxyGroup) string { + if isAuthAPIServerProxy(pg) { + return authAPIServerProxySAName + } + + return pg.Name +} + +func isAuthAPIServerProxy(pg *tsapi.ProxyGroup) bool { + if pg.Spec.Type != tsapi.ProxyGroupTypeKubernetesAPIServer { + return false + } + + // The default is auth mode. + return pg.Spec.KubeAPIServer == nil || + pg.Spec.KubeAPIServer.Mode == nil || + *pg.Spec.KubeAPIServer.Mode == tsapi.APIServerProxyModeAuth +} + +func pgStateSecrets(pg *tsapi.ProxyGroup, namespace string) (secrets []*corev1.Secret) { + for i := range pgReplicas(pg) { + secrets = append(secrets, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgStateSecretName(pg.Name, i), + Namespace: namespace, + Labels: pgSecretLabels(pg.Name, kubetypes.LabelSecretTypeState), + OwnerReferences: pgOwnerReference(pg), + }, + }) + } + + return secrets +} + +func pgEgressCM(pg *tsapi.ProxyGroup, namespace string) (*corev1.ConfigMap, []byte) { + hp := hepPings(pg) + hpBs := []byte(strconv.Itoa(hp)) + return &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgEgressCMName(pg.Name), + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + BinaryData: map[string][]byte{egressservices.KeyHEPPings: hpBs}, + }, hpBs +} + +func pgIngressCM(pg *tsapi.ProxyGroup, namespace string) *corev1.ConfigMap { + return &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgIngressCMName(pg.Name), + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, + } +} + +func pgSecretLabels(pgName, secretType string) map[string]string { + return pgLabels(pgName, map[string]string{ + kubetypes.LabelSecretType: secretType, // "config" or "state". + }) +} + +func pgLabels(pgName string, customLabels map[string]string) map[string]string { + labels := make(map[string]string, len(customLabels)+3) + for k, v := range customLabels { + labels[k] = v + } + + labels[kubetypes.LabelManaged] = "true" + labels[LabelParentType] = "proxygroup" + labels[LabelParentName] = pgName + + return labels +} + +func pgOwnerReference(owner *tsapi.ProxyGroup) []metav1.OwnerReference { + return []metav1.OwnerReference{*metav1.NewControllerRef(owner, tsapi.SchemeGroupVersion.WithKind("ProxyGroup"))} +} + +func pgReplicas(pg *tsapi.ProxyGroup) int32 { + if pg.Spec.Replicas != nil { + return *pg.Spec.Replicas + } + + return 2 +} + +func pgPodName(pgName string, i int32) string { + return fmt.Sprintf("%s-%d", pgName, i) +} + +func pgHostname(pg *tsapi.ProxyGroup, i int32) string { + if pg.Spec.HostnamePrefix != "" { + return fmt.Sprintf("%s-%d", pg.Spec.HostnamePrefix, i) + } + + return fmt.Sprintf("%s-%d", pg.Name, i) +} + +func pgConfigSecretName(pgName string, i int32) string { + return fmt.Sprintf("%s-%d-config", pgName, i) +} + +func pgStateSecretName(pgName string, i int32) string { + return fmt.Sprintf("%s-%d", pgName, i) +} + +func pgEgressCMName(pg string) string { + return fmt.Sprintf("%s-egress-config", pg) +} + +// hasLocalAddrPortSet returns true if the proxyclass has the TS_LOCAL_ADDR_PORT env var set. For egress ProxyGroups, +// currently (2025-01-26) this means that the ProxyGroup does not support graceful failover. +func hasLocalAddrPortSet(proxyClass *tsapi.ProxyClass) bool { + if proxyClass == nil || proxyClass.Spec.StatefulSet == nil || proxyClass.Spec.StatefulSet.Pod == nil || proxyClass.Spec.StatefulSet.Pod.TailscaleContainer == nil { + return false + } + return slices.ContainsFunc(proxyClass.Spec.StatefulSet.Pod.TailscaleContainer.Env, func(env tsapi.Env) bool { + return env.Name == envVarTSLocalAddrPort + }) +} + +// hepPings returns the number of times a health check endpoint exposed by a Service fronting ProxyGroup replicas should +// be pinged to ensure that all currently configured backend replicas are hit. +func hepPings(pg *tsapi.ProxyGroup) int { + rc := pgReplicas(pg) + // Assuming a Service implemented using round robin load balancing, number-of-replica-times should be enough, but in + // practice, we cannot assume that the requests will be load balanced perfectly. + return int(rc) * 3 +} diff --git a/cmd/k8s-operator/proxygroup_test.go b/cmd/k8s-operator/proxygroup_test.go new file mode 100644 index 000000000..2bcc9fb7a --- /dev/null +++ b/cmd/k8s-operator/proxygroup_test.go @@ -0,0 +1,1979 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "net/netip" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" + appsv1 "k8s.io/api/apps/v1" + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/client/tailscale" + "tailscale.com/ipn" + kube "tailscale.com/k8s-operator" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/opt" + "tailscale.com/types/ptr" +) + +const ( + testProxyImage = "tailscale/tailscale:test" + initialCfgHash = "6632726be70cf224049580deb4d317bba065915b5fd415461d60ed621c91b196" +) + +var ( + defaultProxyClassAnnotations = map[string]string{ + "some-annotation": "from-the-proxy-class", + } + + defaultReplicas = ptr.To(int32(2)) + defaultStaticEndpointConfig = &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 30001}, {Port: 30002}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + } +) + +func TestProxyGroupWithStaticEndpoints(t *testing.T) { + type testNodeAddr struct { + ip string + addrType corev1.NodeAddressType + } + + type testNode struct { + name string + addresses []testNodeAddr + labels map[string]string + } + + type reconcile struct { + staticEndpointConfig *tsapi.StaticEndpointsConfig + replicas *int32 + nodes []testNode + expectedIPs []netip.Addr + expectedEvents []string + expectedErr string + expectStatefulSet bool + } + + testCases := []struct { + name string + description string + reconciles []reconcile + }{ + { + // the reconciler should manage to create static endpoints when Nodes have IPv6 addresses. + name: "IPv6", + reconciles: []reconcile{ + { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3001}, + {Port: 3005}, + {Port: 3007}, + {Port: 3009}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + replicas: ptr.To(int32(4)), + nodes: []testNode{ + { + name: "foobar", + addresses: []testNodeAddr{{ip: "2001:0db8::1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbaz", + addresses: []testNodeAddr{{ip: "2001:0db8::2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbazz", + addresses: []testNodeAddr{{ip: "2001:0db8::3", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("2001:0db8::1"), netip.MustParseAddr("2001:0db8::2"), netip.MustParseAddr("2001:0db8::3")}, + expectedEvents: []string{}, + expectedErr: "", + expectStatefulSet: true, + }, + }, + }, + { + // declaring specific ports (with no `endPort`s) in the `spec.staticEndpoints.nodePort` should work. + name: "SpecificPorts", + reconciles: []reconcile{ + { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3001}, + {Port: 3005}, + {Port: 3007}, + {Port: 3009}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + replicas: ptr.To(int32(4)), + nodes: []testNode{ + { + name: "foobar", + addresses: []testNodeAddr{{ip: "192.168.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbaz", + addresses: []testNodeAddr{{ip: "192.168.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbazz", + addresses: []testNodeAddr{{ip: "192.168.0.3", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("192.168.0.1"), netip.MustParseAddr("192.168.0.2"), netip.MustParseAddr("192.168.0.3")}, + expectedEvents: []string{}, + expectedErr: "", + expectStatefulSet: true, + }, + }, + }, + { + // if too narrow a range of `spec.staticEndpoints.nodePort.Ports` on the proxyClass should result in no StatefulSet being created. + name: "NotEnoughPorts", + reconciles: []reconcile{ + { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3001}, + {Port: 3005}, + {Port: 3007}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + replicas: ptr.To(int32(4)), + nodes: []testNode{ + { + name: "foobar", + addresses: []testNodeAddr{{ip: "192.168.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbaz", + addresses: []testNodeAddr{{ip: "192.168.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbazz", + addresses: []testNodeAddr{{ip: "192.168.0.3", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{}, + expectedEvents: []string{"Warning ProxyGroupCreationFailed error provisioning NodePort Services for static endpoints: failed to allocate NodePorts to ProxyGroup Services: not enough available ports to allocate all replicas (needed 4, got 3). Field 'spec.staticEndpoints.nodePort.ports' on ProxyClass \"default-pc\" must have bigger range allocated"}, + expectedErr: "", + expectStatefulSet: false, + }, + }, + }, + { + // when supplying a variety of ranges that are not clashing, the reconciler should manage to create a StatefulSet. + name: "NonClashingRanges", + reconciles: []reconcile{ + { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3000, EndPort: 3002}, + {Port: 3003, EndPort: 3005}, + {Port: 3006}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + replicas: ptr.To(int32(3)), + nodes: []testNode{ + {name: "node1", addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, labels: map[string]string{"foo/bar": "baz"}}, + {name: "node2", addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, labels: map[string]string{"foo/bar": "baz"}}, + {name: "node3", addresses: []testNodeAddr{{ip: "10.0.0.3", addrType: corev1.NodeExternalIP}}, labels: map[string]string{"foo/bar": "baz"}}, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.3")}, + expectedEvents: []string{}, + expectedErr: "", + expectStatefulSet: true, + }, + }, + }, + { + // when there isn't a node that matches the selector, the ProxyGroup enters a failed state as there are no valid Static Endpoints. + // while it does create an event on the resource, It does not return an error + name: "NoMatchingNodes", + reconciles: []reconcile{ + { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3000, EndPort: 3005}, + }, + Selector: map[string]string{ + "zone": "us-west", + }, + }, + }, + replicas: defaultReplicas, + nodes: []testNode{ + {name: "node1", addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, labels: map[string]string{"zone": "eu-central"}}, + {name: "node2", addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeInternalIP}}, labels: map[string]string{"zone": "eu-central"}}, + }, + expectedIPs: []netip.Addr{}, + expectedEvents: []string{"Warning ProxyGroupCreationFailed error provisioning config Secrets: could not find static endpoints for replica \"test-0\": failed to match nodes to configured Selectors on `spec.staticEndpoints.nodePort.selectors` field for ProxyClass \"default-pc\""}, + expectedErr: "", + expectStatefulSet: false, + }, + }, + }, + { + // when all the nodes have only have addresses of type InternalIP populated in their status, the ProxyGroup enters a failed state as there are no valid Static Endpoints. + // while it does create an event on the resource, It does not return an error + name: "AllInternalIPAddresses", + reconciles: []reconcile{ + { + staticEndpointConfig: &tsapi.StaticEndpointsConfig{ + NodePort: &tsapi.NodePortConfig{ + Ports: []tsapi.PortRange{ + {Port: 3001}, + {Port: 3005}, + {Port: 3007}, + {Port: 3009}, + }, + Selector: map[string]string{ + "foo/bar": "baz", + }, + }, + }, + replicas: ptr.To(int32(4)), + nodes: []testNode{ + { + name: "foobar", + addresses: []testNodeAddr{{ip: "192.168.0.1", addrType: corev1.NodeInternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbaz", + addresses: []testNodeAddr{{ip: "192.168.0.2", addrType: corev1.NodeInternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "foobarbazz", + addresses: []testNodeAddr{{ip: "192.168.0.3", addrType: corev1.NodeInternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{}, + expectedEvents: []string{"Warning ProxyGroupCreationFailed error provisioning config Secrets: could not find static endpoints for replica \"test-0\": failed to find any `status.addresses` of type \"ExternalIP\" on nodes using configured Selectors on `spec.staticEndpoints.nodePort.selectors` for ProxyClass \"default-pc\""}, + expectedErr: "", + expectStatefulSet: false, + }, + }, + }, + { + // When the node's (and some of their addresses) change between reconciles, the reconciler should first pick addresses that + // have been used previously (provided that they are still populated on a node that matches the selector) + name: "NodeIPChangesAndPersists", + reconciles: []reconcile{ + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node3", + addresses: []testNodeAddr{{ip: "10.0.0.3", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.10", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node3", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectStatefulSet: true, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + }, + }, + }, + { + // given a new node being created with a new IP, and a node previously used for Static Endpoints being removed, the Static Endpoints should be updated + // correctly + name: "NodeIPChangesWithNewNode", + reconciles: []reconcile{ + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node3", + addresses: []testNodeAddr{{ip: "10.0.0.3", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.3")}, + expectStatefulSet: true, + }, + }, + }, + { + // when all the node IPs change, they should all update + name: "AllNodeIPsChange", + reconciles: []reconcile{ + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.100", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.200", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.100"), netip.MustParseAddr("10.0.0.200")}, + expectStatefulSet: true, + }, + }, + }, + { + // if there are less ExternalIPs after changes to the nodes between reconciles, the reconciler should complete without issues + name: "LessExternalIPsAfterChange", + reconciles: []reconcile{ + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeInternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + expectStatefulSet: true, + }, + }, + }, + { + // if node address parsing fails (given an invalid address), the reconciler should continue without failure and find other + // valid addresses + name: "NodeAddressParsingFails", + reconciles: []reconcile{ + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "invalid-ip", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "invalid-ip", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + }, + }, + { + // if the node's become unlabeled, the ProxyGroup should enter a ProxyGroupInvalid state, but the reconciler should not fail + name: "NodesBecomeUnlabeled", + reconciles: []reconcile{ + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node1", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + { + name: "node2", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{"foo/bar": "baz"}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + expectStatefulSet: true, + }, + { + staticEndpointConfig: defaultStaticEndpointConfig, + replicas: defaultReplicas, + nodes: []testNode{ + { + name: "node3", + addresses: []testNodeAddr{{ip: "10.0.0.1", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{}, + }, + { + name: "node4", + addresses: []testNodeAddr{{ip: "10.0.0.2", addrType: corev1.NodeExternalIP}}, + labels: map[string]string{}, + }, + }, + expectedIPs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("10.0.0.2")}, + expectedEvents: []string{"Warning ProxyGroupCreationFailed error provisioning config Secrets: could not find static endpoints for replica \"test-0\": failed to match nodes to configured Selectors on `spec.staticEndpoints.nodePort.selectors` field for ProxyClass \"default-pc\""}, + expectStatefulSet: true, + }, + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + tsClient := &fakeTSClient{} + zl, _ := zap.NewDevelopment() + fr := record.NewFakeRecorder(10) + cl := tstest.NewClock(tstest.ClockOpts{}) + + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "default-pc", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Annotations: defaultProxyClassAnnotations, + }, + }, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + Reason: reasonProxyClassValid, + Message: reasonProxyClassValid, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + }}, + }, + } + + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Finalizers: []string{"tailscale.com/finalizer"}, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + ProxyClass: pc.Name, + }, + } + + fc := fake.NewClientBuilder(). + WithObjects(pc, pg). + WithStatusSubresource(pc, pg). + WithScheme(tsapi.GlobalScheme). + Build() + + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + defaultTags: []string{"tag:test-tag"}, + tsFirewallMode: "auto", + defaultProxyClass: "default-pc", + + Client: fc, + tsClient: tsClient, + recorder: fr, + clock: cl, + } + + for i, r := range tt.reconciles { + createdNodes := []corev1.Node{} + t.Run(tt.name, func(t *testing.T) { + for _, n := range r.nodes { + no := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: n.name, + Labels: n.labels, + }, + Status: corev1.NodeStatus{ + Addresses: []corev1.NodeAddress{}, + }, + } + for _, addr := range n.addresses { + no.Status.Addresses = append(no.Status.Addresses, corev1.NodeAddress{ + Type: addr.addrType, + Address: addr.ip, + }) + } + if err := fc.Create(t.Context(), no); err != nil { + t.Fatalf("failed to create node %q: %v", n.name, err) + } + createdNodes = append(createdNodes, *no) + t.Logf("created node %q with data", n.name) + } + + reconciler.log = zl.Sugar().With("TestName", tt.name).With("Reconcile", i) + pg.Spec.Replicas = r.replicas + pc.Spec.StaticEndpoints = r.staticEndpointConfig + + createOrUpdate(t.Context(), fc, "", pg, func(o *tsapi.ProxyGroup) { + o.Spec.Replicas = pg.Spec.Replicas + }) + + createOrUpdate(t.Context(), fc, "", pc, func(o *tsapi.ProxyClass) { + o.Spec.StaticEndpoints = pc.Spec.StaticEndpoints + }) + + if r.expectedErr != "" { + expectError(t, reconciler, "", pg.Name) + } else { + expectReconciled(t, reconciler, "", pg.Name) + } + expectEvents(t, fr, r.expectedEvents) + + sts := &appsv1.StatefulSet{} + err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts) + if r.expectStatefulSet { + if err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + for j := range 2 { + sec := &corev1.Secret{} + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: fmt.Sprintf("%s-%d-config", pg.Name, j)}, sec); err != nil { + t.Fatalf("failed to get state Secret for replica %d: %v", j, err) + } + + config := &ipn.ConfigVAlpha{} + foundConfig := false + for _, d := range sec.Data { + if err := json.Unmarshal(d, config); err == nil { + foundConfig = true + break + } + } + if !foundConfig { + t.Fatalf("could not unmarshal config from secret data for replica %d", j) + } + + if len(config.StaticEndpoints) > staticEndpointsMaxAddrs { + t.Fatalf("expected %d StaticEndpoints in config Secret, but got %d for replica %d. Found Static Endpoints: %v", staticEndpointsMaxAddrs, len(config.StaticEndpoints), j, config.StaticEndpoints) + } + + for _, e := range config.StaticEndpoints { + if !slices.Contains(r.expectedIPs, e.Addr()) { + t.Fatalf("found unexpected static endpoint IP %q for replica %d. Expected one of %v", e.Addr().String(), j, r.expectedIPs) + } + if c := r.staticEndpointConfig; c != nil && c.NodePort.Ports != nil { + var ports tsapi.PortRanges = c.NodePort.Ports + found := false + for port := range ports.All() { + if port == e.Port() { + found = true + break + } + } + + if !found { + t.Fatalf("found unexpected static endpoint port %d for replica %d. Expected one of %v .", e.Port(), j, ports.All()) + } + } else { + if e.Port() != 3001 && e.Port() != 3002 { + t.Fatalf("found unexpected static endpoint port %d for replica %d. Expected 3001 or 3002.", e.Port(), j) + } + } + } + } + + pgroup := &tsapi.ProxyGroup{} + err = fc.Get(t.Context(), client.ObjectKey{Name: pg.Name}, pgroup) + if err != nil { + t.Fatalf("failed to get ProxyGroup %q: %v", pg.Name, err) + } + + t.Logf("getting proxygroup after reconcile") + for _, d := range pgroup.Status.Devices { + t.Logf("found device %q", d.Hostname) + for _, e := range d.StaticEndpoints { + t.Logf("found static endpoint %q", e) + } + } + } else { + if err == nil { + t.Fatal("expected error when getting Statefulset") + } + } + }) + + // node cleanup between reconciles + // we created a new set of nodes for each + for _, n := range createdNodes { + err := fc.Delete(t.Context(), &n) + if err != nil && !apierrors.IsNotFound(err) { + t.Fatalf("failed to delete node: %v", err) + } + } + } + + t.Run("delete_and_cleanup", func(t *testing.T) { + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + defaultTags: []string{"tag:test-tag"}, + tsFirewallMode: "auto", + defaultProxyClass: "default-pc", + + Client: fc, + tsClient: tsClient, + recorder: fr, + log: zl.Sugar().With("TestName", tt.name).With("Reconcile", "cleanup"), + clock: cl, + } + + if err := fc.Delete(t.Context(), pg); err != nil { + t.Fatalf("error deleting ProxyGroup: %v", err) + } + + expectReconciled(t, reconciler, "", pg.Name) + expectMissing[tsapi.ProxyGroup](t, fc, "", pg.Name) + + if err := fc.Delete(t.Context(), pc); err != nil { + t.Fatalf("error deleting ProxyClass: %v", err) + } + expectMissing[tsapi.ProxyClass](t, fc, "", pc.Name) + }) + }) + } +} + +func TestProxyGroup(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "default-pc", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Annotations: defaultProxyClassAnnotations, + }, + }, + } + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Finalizers: []string{"tailscale.com/finalizer"}, + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + }, + } + + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pg, pc). + WithStatusSubresource(pg, pc). + Build() + tsClient := &fakeTSClient{} + zl, _ := zap.NewDevelopment() + fr := record.NewFakeRecorder(1) + cl := tstest.NewClock(tstest.ClockOpts{}) + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + defaultTags: []string{"tag:test-tag"}, + tsFirewallMode: "auto", + defaultProxyClass: "default-pc", + + Client: fc, + tsClient: tsClient, + recorder: fr, + log: zl.Sugar(), + clock: cl, + } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + opts := configOpts{ + proxyType: "proxygroup", + stsName: pg.Name, + parentType: "proxygroup", + tailscaleNamespace: "tailscale", + resourceVersion: "1", + } + + t.Run("proxyclass_not_ready", func(t *testing.T) { + expectReconciled(t, reconciler, "", pg.Name) + + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionFalse, reasonProxyGroupCreating, "0/2 ProxyGroup pods running", 0, cl, zl.Sugar()) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "the ProxyGroup's ProxyClass \"default-pc\" is not yet in a ready state, waiting...", 1, cl, zl.Sugar()) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, false, pc) + if kube.ProxyGroupAvailable(pg) { + t.Fatal("expected ProxyGroup to not be available") + } + }) + + t.Run("observe_ProxyGroupCreating_status_reason", func(t *testing.T) { + pc.Status = tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + Reason: reasonProxyClassValid, + Message: reasonProxyClassValid, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + }}, + } + if err := fc.Status().Update(t.Context(), pc); err != nil { + t.Fatal(err) + } + pg.ObjectMeta.Generation = 2 + mustUpdate(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { + p.ObjectMeta.Generation = pg.ObjectMeta.Generation + }) + expectReconciled(t, reconciler, "", pg.Name) + + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "0/2 ProxyGroup pods running", 2, cl, zl.Sugar()) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionFalse, reasonProxyGroupCreating, "0/2 ProxyGroup pods running", 0, cl, zl.Sugar()) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, pc) + if kube.ProxyGroupAvailable(pg) { + t.Fatal("expected ProxyGroup to not be available") + } + if expected := 1; reconciler.egressProxyGroups.Len() != expected { + t.Fatalf("expected %d egress ProxyGroups, got %d", expected, reconciler.egressProxyGroups.Len()) + } + expectProxyGroupResources(t, fc, pg, true, pc) + keyReq := tailscale.KeyCapabilities{ + Devices: tailscale.KeyDeviceCapabilities{ + Create: tailscale.KeyDeviceCreateCapabilities{ + Reusable: false, + Ephemeral: false, + Preauthorized: true, + Tags: []string{"tag:test-tag"}, + }, + }, + } + if diff := cmp.Diff(tsClient.KeyRequests(), []tailscale.KeyCapabilities{keyReq, keyReq}); diff != "" { + t.Fatalf("unexpected secrets (-got +want):\n%s", diff) + } + }) + + t.Run("simulate_successful_device_auth", func(t *testing.T) { + addNodeIDToStateSecrets(t, fc, pg) + pg.ObjectMeta.Generation = 3 + mustUpdate(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { + p.ObjectMeta.Generation = pg.ObjectMeta.Generation + }) + expectReconciled(t, reconciler, "", pg.Name) + + pg.Status.Devices = []tsapi.TailnetDevice{ + { + Hostname: "hostname-nodeid-0", + TailnetIPs: []string{"1.2.3.4", "::1"}, + }, + { + Hostname: "hostname-nodeid-1", + TailnetIPs: []string{"1.2.3.4", "::1"}, + }, + } + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, reasonProxyGroupReady, 3, cl, zl.Sugar()) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupAvailable, "2/2 ProxyGroup pods running", 0, cl, zl.Sugar()) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, pc) + if !kube.ProxyGroupAvailable(pg) { + t.Fatal("expected ProxyGroup to be available") + } + }) + + t.Run("scale_up_to_3", func(t *testing.T) { + pg.Spec.Replicas = ptr.To[int32](3) + mustUpdate(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { + p.Spec = pg.Spec + }) + expectReconciled(t, reconciler, "", pg.Name) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "2/3 ProxyGroup pods running", 3, cl, zl.Sugar()) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupCreating, "2/3 ProxyGroup pods running", 0, cl, zl.Sugar()) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, pc) + + addNodeIDToStateSecrets(t, fc, pg) + expectReconciled(t, reconciler, "", pg.Name) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, reasonProxyGroupReady, 3, cl, zl.Sugar()) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupAvailable, "3/3 ProxyGroup pods running", 0, cl, zl.Sugar()) + pg.Status.Devices = append(pg.Status.Devices, tsapi.TailnetDevice{ + Hostname: "hostname-nodeid-2", + TailnetIPs: []string{"1.2.3.4", "::1"}, + }) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, pc) + }) + + t.Run("scale_down_to_1", func(t *testing.T) { + pg.Spec.Replicas = ptr.To[int32](1) + mustUpdate(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { + p.Spec = pg.Spec + }) + + expectReconciled(t, reconciler, "", pg.Name) + + pg.Status.Devices = pg.Status.Devices[:1] // truncate to only the first device. + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupAvailable, "1/1 ProxyGroup pods running", 0, cl, zl.Sugar()) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, pc) + }) + + t.Run("enable_metrics", func(t *testing.T) { + pc.Spec.Metrics = &tsapi.Metrics{Enable: true} + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec = pc.Spec + }) + expectReconciled(t, reconciler, "", pg.Name) + expectEqual(t, fc, expectedMetricsService(opts)) + }) + t.Run("enable_service_monitor_no_crd", func(t *testing.T) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true} + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec.Metrics = pc.Spec.Metrics + }) + expectReconciled(t, reconciler, "", pg.Name) + }) + t.Run("create_crd_expect_service_monitor", func(t *testing.T) { + mustCreate(t, fc, crd) + expectReconciled(t, reconciler, "", pg.Name) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + }) + + t.Run("delete_and_cleanup", func(t *testing.T) { + if err := fc.Delete(t.Context(), pg); err != nil { + t.Fatal(err) + } + + expectReconciled(t, reconciler, "", pg.Name) + + expectMissing[tsapi.ProxyGroup](t, fc, "", pg.Name) + if expected := 0; reconciler.egressProxyGroups.Len() != expected { + t.Fatalf("expected %d ProxyGroups, got %d", expected, reconciler.egressProxyGroups.Len()) + } + // 2 nodes should get deleted as part of the scale down, and then finally + // the first node gets deleted with the ProxyGroup cleanup. + if diff := cmp.Diff(tsClient.deleted, []string{"nodeid-1", "nodeid-2", "nodeid-0"}); diff != "" { + t.Fatalf("unexpected deleted devices (-got +want):\n%s", diff) + } + expectMissing[corev1.Service](t, reconciler, "tailscale", metricsResourceName(pg.Name)) + // The fake client does not clean up objects whose owner has been + // deleted, so we can't test for the owned resources getting deleted. + }) +} + +func TestProxyGroupTypes(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{}, + } + // Passing ProxyGroup as status subresource is a way to get around fake + // client's limitations for updating resource statuses. + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc). + WithStatusSubresource(pc, &tsapi.ProxyGroup{}). + Build() + mustUpdateStatus(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }} + }) + + zl, _ := zap.NewDevelopment() + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + Client: fc, + log: zl.Sugar(), + tsClient: &fakeTSClient{}, + clock: tstest.NewClock(tstest.ClockOpts{}), + } + + t.Run("egress_type", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-egress", + UID: "test-egress-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + Replicas: ptr.To[int32](0), + }, + } + mustCreate(t, fc, pg) + + expectReconciled(t, reconciler, "", pg.Name) + verifyProxyGroupCounts(t, reconciler, 0, 1, 0) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + verifyEnvVar(t, sts, "TS_INTERNAL_APP", kubetypes.AppProxyGroupEgress) + verifyEnvVar(t, sts, "TS_EGRESS_PROXIES_CONFIG_PATH", "/etc/proxies") + verifyEnvVar(t, sts, "TS_ENABLE_HEALTH_CHECK", "true") + + // Verify that egress configuration has been set up. + cm := &corev1.ConfigMap{} + cmName := fmt.Sprintf("%s-egress-config", pg.Name) + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: cmName}, cm); err != nil { + t.Fatalf("failed to get ConfigMap: %v", err) + } + + expectedVolumes := []corev1.Volume{ + { + Name: cmName, + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: cmName, + }, + }, + }, + }, + } + + expectedVolumeMounts := []corev1.VolumeMount{ + { + Name: cmName, + MountPath: "/etc/proxies", + ReadOnly: true, + }, + } + + if diff := cmp.Diff(expectedVolumes, sts.Spec.Template.Spec.Volumes); diff != "" { + t.Errorf("unexpected volumes (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(expectedVolumeMounts, sts.Spec.Template.Spec.Containers[0].VolumeMounts); diff != "" { + t.Errorf("unexpected volume mounts (-want +got):\n%s", diff) + } + + expectedLifecycle := corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: kubetypes.EgessServicesPreshutdownEP, + Port: intstr.FromInt(defaultLocalAddrPort), + }, + }, + } + if diff := cmp.Diff(expectedLifecycle, *sts.Spec.Template.Spec.Containers[0].Lifecycle); diff != "" { + t.Errorf("unexpected lifecycle (-want +got):\n%s", diff) + } + if *sts.Spec.Template.DeletionGracePeriodSeconds != deletionGracePeriodSeconds { + t.Errorf("unexpected deletion grace period seconds %d, want %d", *sts.Spec.Template.DeletionGracePeriodSeconds, deletionGracePeriodSeconds) + } + }) + t.Run("egress_type_no_lifecycle_hook_when_local_addr_port_set", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-egress-no-lifecycle", + UID: "test-egress-no-lifecycle-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + Replicas: ptr.To[int32](0), + ProxyClass: "test", + }, + } + mustCreate(t, fc, pg) + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec.StatefulSet = &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Env: []tsapi.Env{{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "127.0.0.1:8080", + }}, + }, + }, + } + }) + expectReconciled(t, reconciler, "", pg.Name) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + if sts.Spec.Template.Spec.Containers[0].Lifecycle != nil { + t.Error("lifecycle hook was set when TS_LOCAL_ADDR_PORT was configured via ProxyClass") + } + }) + + t.Run("ingress_type", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + UID: "test-ingress-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + Replicas: ptr.To[int32](0), + }, + } + if err := fc.Create(t.Context(), pg); err != nil { + t.Fatal(err) + } + + expectReconciled(t, reconciler, "", pg.Name) + verifyProxyGroupCounts(t, reconciler, 1, 2, 0) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + verifyEnvVar(t, sts, "TS_INTERNAL_APP", kubetypes.AppProxyGroupIngress) + verifyEnvVar(t, sts, "TS_SERVE_CONFIG", "/etc/proxies/serve-config.json") + verifyEnvVar(t, sts, "TS_EXPERIMENTAL_CERT_SHARE", "true") + + // Verify ConfigMap volume mount + cmName := fmt.Sprintf("%s-ingress-config", pg.Name) + expectedVolume := corev1.Volume{ + Name: cmName, + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: cmName, + }, + }, + }, + } + + expectedVolumeMount := corev1.VolumeMount{ + Name: cmName, + MountPath: "/etc/proxies", + ReadOnly: true, + } + + if diff := cmp.Diff([]corev1.Volume{expectedVolume}, sts.Spec.Template.Spec.Volumes); diff != "" { + t.Errorf("unexpected volumes (-want +got):\n%s", diff) + } + + if diff := cmp.Diff([]corev1.VolumeMount{expectedVolumeMount}, sts.Spec.Template.Spec.Containers[0].VolumeMounts); diff != "" { + t.Errorf("unexpected volume mounts (-want +got):\n%s", diff) + } + }) + + t.Run("kubernetes_api_server_type", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-k8s-apiserver", + UID: "test-k8s-apiserver-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeKubernetesAPIServer, + Replicas: ptr.To[int32](2), + KubeAPIServer: &tsapi.KubeAPIServerConfig{ + Mode: ptr.To(tsapi.APIServerProxyModeNoAuth), + }, + }, + } + if err := fc.Create(t.Context(), pg); err != nil { + t.Fatal(err) + } + + expectReconciled(t, reconciler, "", pg.Name) + verifyProxyGroupCounts(t, reconciler, 1, 2, 1) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + // Verify the StatefulSet configuration for KubernetesAPIServer type. + if sts.Spec.Template.Spec.Containers[0].Name != mainContainerName { + t.Errorf("unexpected container name %s, want %s", sts.Spec.Template.Spec.Containers[0].Name, mainContainerName) + } + if sts.Spec.Template.Spec.Containers[0].Ports[0].ContainerPort != 443 { + t.Errorf("unexpected container port %d, want 443", sts.Spec.Template.Spec.Containers[0].Ports[0].ContainerPort) + } + if sts.Spec.Template.Spec.Containers[0].Ports[0].Name != "k8s-proxy" { + t.Errorf("unexpected port name %s, want k8s-proxy", sts.Spec.Template.Spec.Containers[0].Ports[0].Name) + } + }) +} + +func TestKubeAPIServerStatusConditionFlow(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-k8s-apiserver", + UID: "test-k8s-apiserver-uid", + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeKubernetesAPIServer, + Replicas: ptr.To[int32](1), + KubeAPIServer: &tsapi.KubeAPIServerConfig{ + Mode: ptr.To(tsapi.APIServerProxyModeNoAuth), + }, + }, + } + stateSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgStateSecretName(pg.Name, 0), + Namespace: tsNamespace, + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pg, stateSecret). + WithStatusSubresource(pg). + Build() + r := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + Client: fc, + log: zap.Must(zap.NewDevelopment()).Sugar(), + tsClient: &fakeTSClient{}, + clock: tstest.NewClock(tstest.ClockOpts{}), + } + + expectReconciled(t, r, "", pg.Name) + pg.ObjectMeta.Finalizers = append(pg.ObjectMeta.Finalizers, FinalizerName) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionFalse, reasonProxyGroupCreating, "", 0, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.log) + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + + // Set kube-apiserver valid. + mustUpdateStatus(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { + tsoperator.SetProxyGroupCondition(p, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.log) + }) + expectReconciled(t, r, "", pg.Name) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid, metav1.ConditionTrue, reasonKubeAPIServerProxyValid, "", 1, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.log) + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + + // Set available. + addNodeIDToStateSecrets(t, fc, pg) + expectReconciled(t, r, "", pg.Name) + pg.Status.Devices = []tsapi.TailnetDevice{ + { + Hostname: "hostname-nodeid-0", + TailnetIPs: []string{"1.2.3.4", "::1"}, + }, + } + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupAvailable, metav1.ConditionTrue, reasonProxyGroupAvailable, "", 0, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "", 1, r.clock, r.log) + expectEqual(t, fc, pg, omitPGStatusConditionMessages) + + // Set kube-apiserver configured. + mustUpdateStatus(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { + tsoperator.SetProxyGroupCondition(p, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.log) + }) + expectReconciled(t, r, "", pg.Name) + tsoperator.SetProxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured, metav1.ConditionTrue, reasonKubeAPIServerProxyConfigured, "", 1, r.clock, r.log) + tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, "", 1, r.clock, r.log) + expectEqual(t, fc, pg, omitPGStatusConditionMessages) +} + +func TestKubeAPIServerType_DoesNotOverwriteServicesConfig(t *testing.T) { + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithStatusSubresource(&tsapi.ProxyGroup{}). + Build() + + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + Client: fc, + log: zap.Must(zap.NewDevelopment()).Sugar(), + tsClient: &fakeTSClient{}, + clock: tstest.NewClock(tstest.ClockOpts{}), + } + + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-k8s-apiserver", + UID: "test-k8s-apiserver-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeKubernetesAPIServer, + Replicas: ptr.To[int32](1), + KubeAPIServer: &tsapi.KubeAPIServerConfig{ + Mode: ptr.To(tsapi.APIServerProxyModeNoAuth), // Avoid needing to pre-create the static ServiceAccount. + }, + }, + } + if err := fc.Create(t.Context(), pg); err != nil { + t.Fatal(err) + } + expectReconciled(t, reconciler, "", pg.Name) + + cfg := conf.VersionedConfig{ + Version: "v1alpha1", + ConfigV1Alpha1: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("secret-authkey"), + State: ptr.To(fmt.Sprintf("kube:%s", pgPodName(pg.Name, 0))), + App: ptr.To(kubetypes.AppProxyGroupKubeAPIServer), + LogLevel: ptr.To("debug"), + + Hostname: ptr.To("test-k8s-apiserver-0"), + APIServerProxy: &conf.APIServerProxyConfig{ + Enabled: opt.NewBool(true), + Mode: ptr.To(kubetypes.APIServerProxyModeNoAuth), + IssueCerts: opt.NewBool(true), + }, + LocalPort: ptr.To(uint16(9002)), + HealthCheckEnabled: opt.NewBool(true), + }, + } + cfgB, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + + cfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pg.Name, 0), + Namespace: tsNamespace, + Labels: pgSecretLabels(pg.Name, kubetypes.LabelSecretTypeConfig), + OwnerReferences: pgOwnerReference(pg), + }, + Data: map[string][]byte{ + kubetypes.KubeAPIServerConfigFile: cfgB, + }, + } + expectEqual(t, fc, cfgSecret) + + // Now simulate the kube-apiserver services reconciler updating config, + // then check the proxygroup reconciler doesn't overwrite it. + cfg.APIServerProxy.ServiceName = ptr.To(tailcfg.ServiceName("svc:some-svc-name")) + cfg.AdvertiseServices = []string{"svc:should-not-be-overwritten"} + cfgB, err = json.Marshal(cfg) + if err != nil { + t.Fatalf("failed to marshal config: %v", err) + } + mustUpdate(t, fc, tsNamespace, cfgSecret.Name, func(s *corev1.Secret) { + s.Data[kubetypes.KubeAPIServerConfigFile] = cfgB + }) + expectReconciled(t, reconciler, "", pg.Name) + + cfgSecret.Data[kubetypes.KubeAPIServerConfigFile] = cfgB + expectEqual(t, fc, cfgSecret) +} + +func TestIngressAdvertiseServicesConfigPreserved(t *testing.T) { + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithStatusSubresource(&tsapi.ProxyGroup{}). + Build() + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + Client: fc, + log: zap.Must(zap.NewDevelopment()).Sugar(), + tsClient: &fakeTSClient{}, + clock: tstest.NewClock(tstest.ClockOpts{}), + } + + existingServices := []string{"svc1", "svc2"} + existingConfigBytes, err := json.Marshal(ipn.ConfigVAlpha{ + AdvertiseServices: existingServices, + Version: "should-get-overwritten", + }) + if err != nil { + t.Fatal(err) + } + + const pgName = "test-ingress" + mustCreate(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: tsNamespace, + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): existingConfigBytes, + }, + }) + + mustCreate(t, fc, &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgName, + UID: "test-ingress-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + Replicas: ptr.To[int32](1), + }, + }) + expectReconciled(t, reconciler, "", pgName) + + expectedConfigBytes, err := json.Marshal(ipn.ConfigVAlpha{ + // Preserved. + AdvertiseServices: existingServices, + + // Everything else got updated in the reconcile: + Version: "alpha0", + AcceptDNS: "false", + AcceptRoutes: "false", + Locked: "false", + Hostname: ptr.To(fmt.Sprintf("%s-%d", pgName, 0)), + }) + if err != nil { + t.Fatal(err) + } + expectEqual(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: tsNamespace, + ResourceVersion: "2", + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): expectedConfigBytes, + }, + }) +} + +func TestValidateProxyGroup(t *testing.T) { + type testCase struct { + typ tsapi.ProxyGroupType + pgName string + image string + noauth bool + initContainer bool + staticSAExists bool + expectedErrs int + } + + for name, tc := range map[string]testCase{ + "default_ingress": { + typ: tsapi.ProxyGroupTypeIngress, + }, + "default_kube": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + }, + "default_kube_noauth": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + noauth: true, + // Does not require the static ServiceAccount to exist. + }, + "kube_static_sa_missing": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: false, + expectedErrs: 1, + }, + "kube_noauth_would_overwrite_static_sa": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + noauth: true, + pgName: authAPIServerProxySAName, + expectedErrs: 1, + }, + "ingress_would_overwrite_static_sa": { + typ: tsapi.ProxyGroupTypeIngress, + staticSAExists: true, + pgName: authAPIServerProxySAName, + expectedErrs: 1, + }, + "tailscale_image_for_kube_pg_1": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + image: "example.com/tailscale/tailscale", + expectedErrs: 1, + }, + "tailscale_image_for_kube_pg_2": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + image: "example.com/tailscale", + expectedErrs: 1, + }, + "tailscale_image_for_kube_pg_3": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + image: "example.com/tailscale/tailscale:latest", + expectedErrs: 1, + }, + "tailscale_image_for_kube_pg_4": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + image: "tailscale/tailscale", + expectedErrs: 1, + }, + "k8s_proxy_image_for_ingress_pg": { + typ: tsapi.ProxyGroupTypeIngress, + image: "example.com/k8s-proxy", + expectedErrs: 1, + }, + "init_container_for_kube_pg": { + typ: tsapi.ProxyGroupTypeKubernetesAPIServer, + staticSAExists: true, + initContainer: true, + expectedErrs: 1, + }, + "init_container_for_ingress_pg": { + typ: tsapi.ProxyGroupTypeIngress, + initContainer: true, + }, + "init_container_for_egress_pg": { + typ: tsapi.ProxyGroupTypeEgress, + initContainer: true, + }, + } { + t.Run(name, func(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "some-pc", + }, + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{}, + }, + }, + } + if tc.image != "" { + pc.Spec.StatefulSet.Pod.TailscaleContainer = &tsapi.Container{ + Image: tc.image, + } + } + if tc.initContainer { + pc.Spec.StatefulSet.Pod.TailscaleInitContainer = &tsapi.Container{} + } + pgName := "some-pg" + if tc.pgName != "" { + pgName = tc.pgName + } + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgName, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tc.typ, + }, + } + if tc.noauth { + pg.Spec.KubeAPIServer = &tsapi.KubeAPIServerConfig{ + Mode: ptr.To(tsapi.APIServerProxyModeNoAuth), + } + } + + var objs []client.Object + if tc.staticSAExists { + objs = append(objs, &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: authAPIServerProxySAName, + Namespace: tsNamespace, + }, + }) + } + r := ProxyGroupReconciler{ + tsNamespace: tsNamespace, + Client: fake.NewClientBuilder(). + WithObjects(objs...). + Build(), + } + + logger, _ := zap.NewDevelopment() + err := r.validate(t.Context(), pg, pc, logger.Sugar()) + if tc.expectedErrs == 0 { + if err != nil { + t.Fatalf("expected no errors, got: %v", err) + } + // Test finished. + return + } + + if err == nil { + t.Fatalf("expected %d errors, got none", tc.expectedErrs) + } + + type unwrapper interface { + Unwrap() []error + } + errs := err.(unwrapper) + if len(errs.Unwrap()) != tc.expectedErrs { + t.Fatalf("expected %d errors, got %d: %v", tc.expectedErrs, len(errs.Unwrap()), err) + } + }) + } +} + +func proxyClassesForLEStagingTest() (*tsapi.ProxyClass, *tsapi.ProxyClass, *tsapi.ProxyClass) { + pcLEStaging := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "le-staging", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{ + UseLetsEncryptStagingEnvironment: true, + }, + } + + pcLEStagingFalse := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "le-staging-false", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{ + UseLetsEncryptStagingEnvironment: false, + }, + } + + pcOther := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "other", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{}, + } + + return pcLEStaging, pcLEStagingFalse, pcOther +} + +func setProxyClassReady(t *testing.T, fc client.Client, cl *tstest.Clock, name string) *tsapi.ProxyClass { + t.Helper() + pc := &tsapi.ProxyClass{} + if err := fc.Get(t.Context(), client.ObjectKey{Name: name}, pc); err != nil { + t.Fatal(err) + } + pc.Status = tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + Reason: reasonProxyClassValid, + Message: reasonProxyClassValid, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + ObservedGeneration: pc.Generation, + }}, + } + if err := fc.Status().Update(t.Context(), pc); err != nil { + t.Fatal(err) + } + return pc +} + +func verifyProxyGroupCounts(t *testing.T, r *ProxyGroupReconciler, wantIngress, wantEgress, wantAPIServer int) { + t.Helper() + if r.ingressProxyGroups.Len() != wantIngress { + t.Errorf("expected %d ingress proxy groups, got %d", wantIngress, r.ingressProxyGroups.Len()) + } + if r.egressProxyGroups.Len() != wantEgress { + t.Errorf("expected %d egress proxy groups, got %d", wantEgress, r.egressProxyGroups.Len()) + } + if r.apiServerProxyGroups.Len() != wantAPIServer { + t.Errorf("expected %d kube-apiserver proxy groups, got %d", wantAPIServer, r.apiServerProxyGroups.Len()) + } +} + +func verifyEnvVar(t *testing.T, sts *appsv1.StatefulSet, name, expectedValue string) { + t.Helper() + for _, env := range sts.Spec.Template.Spec.Containers[0].Env { + if env.Name == name { + if env.Value != expectedValue { + t.Errorf("expected %s=%s, got %s", name, expectedValue, env.Value) + } + return + } + } + t.Errorf("%s environment variable not found", name) +} + +func verifyEnvVarNotPresent(t *testing.T, sts *appsv1.StatefulSet, name string) { + t.Helper() + for _, env := range sts.Spec.Template.Spec.Containers[0].Env { + if env.Name == name { + t.Errorf("environment variable %s should not be present", name) + return + } + } +} + +func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyGroup, shouldExist bool, proxyClass *tsapi.ProxyClass) { + t.Helper() + + role := pgRole(pg, tsNamespace) + roleBinding := pgRoleBinding(pg, tsNamespace) + serviceAccount := pgServiceAccount(pg, tsNamespace) + statefulSet, err := pgStatefulSet(pg, tsNamespace, testProxyImage, "auto", nil, proxyClass) + if err != nil { + t.Fatal(err) + } + statefulSet.Annotations = defaultProxyClassAnnotations + + if shouldExist { + expectEqual(t, fc, role) + expectEqual(t, fc, roleBinding) + expectEqual(t, fc, serviceAccount) + expectEqual(t, fc, statefulSet, removeResourceReqs) + } else { + expectMissing[rbacv1.Role](t, fc, role.Namespace, role.Name) + expectMissing[rbacv1.RoleBinding](t, fc, roleBinding.Namespace, roleBinding.Name) + expectMissing[corev1.ServiceAccount](t, fc, serviceAccount.Namespace, serviceAccount.Name) + expectMissing[appsv1.StatefulSet](t, fc, statefulSet.Namespace, statefulSet.Name) + } + + var expectedSecrets []string + if shouldExist { + for i := range pgReplicas(pg) { + expectedSecrets = append(expectedSecrets, + fmt.Sprintf("%s-%d", pg.Name, i), + pgConfigSecretName(pg.Name, i), + ) + } + } + expectSecrets(t, fc, expectedSecrets) +} + +func expectSecrets(t *testing.T, fc client.WithWatch, expected []string) { + t.Helper() + + secrets := &corev1.SecretList{} + if err := fc.List(t.Context(), secrets); err != nil { + t.Fatal(err) + } + + var actual []string + for _, secret := range secrets.Items { + actual = append(actual, secret.Name) + } + + if diff := cmp.Diff(actual, expected); diff != "" { + t.Fatalf("unexpected secrets (-got +want):\n%s", diff) + } +} + +func addNodeIDToStateSecrets(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyGroup) { + t.Helper() + const key = "profile-abc" + for i := range pgReplicas(pg) { + bytes, err := json.Marshal(map[string]any{ + "Config": map[string]any{ + "NodeID": fmt.Sprintf("nodeid-%d", i), + }, + }) + if err != nil { + t.Fatal(err) + } + + podUID := fmt.Sprintf("pod-uid-%d", i) + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-%d", pg.Name, i), + Namespace: "tailscale", + UID: types.UID(podUID), + }, + } + if _, err := createOrUpdate(t.Context(), fc, "tailscale", pod, nil); err != nil { + t.Fatalf("failed to create or update Pod %s: %v", pod.Name, err) + } + mustUpdate(t, fc, tsNamespace, pgStateSecretName(pg.Name, i), func(s *corev1.Secret) { + s.Data = map[string][]byte{ + currentProfileKey: []byte(key), + key: bytes, + kubetypes.KeyDeviceIPs: []byte(`["1.2.3.4", "::1"]`), + kubetypes.KeyDeviceFQDN: []byte(fmt.Sprintf("hostname-nodeid-%d.tails-scales.ts.net", i)), + // TODO(tomhjp): We have two different mechanisms to retrieve device IDs. + // Consolidate on this one. + kubetypes.KeyDeviceID: []byte(fmt.Sprintf("nodeid-%d", i)), + kubetypes.KeyPodUID: []byte(podUID), + } + }) + } +} + +func TestProxyGroupLetsEncryptStaging(t *testing.T) { + cl := tstest.NewClock(tstest.ClockOpts{}) + zl := zap.Must(zap.NewDevelopment()) + + // Set up test cases- most are shared with non-HA Ingress. + type proxyGroupLETestCase struct { + leStagingTestCase + pgType tsapi.ProxyGroupType + } + pcLEStaging, pcLEStagingFalse, pcOther := proxyClassesForLEStagingTest() + sharedTestCases := testCasesForLEStagingTests() + var tests []proxyGroupLETestCase + for _, tt := range sharedTestCases { + tests = append(tests, proxyGroupLETestCase{ + leStagingTestCase: tt, + pgType: tsapi.ProxyGroupTypeIngress, + }) + } + tests = append(tests, proxyGroupLETestCase{ + leStagingTestCase: leStagingTestCase{ + name: "egress_pg_with_staging_proxyclass", + proxyClassPerResource: "le-staging", + useLEStagingEndpoint: false, + }, + pgType: tsapi.ProxyGroupTypeEgress, + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme) + + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tt.pgType, + Replicas: ptr.To[int32](1), + ProxyClass: tt.proxyClassPerResource, + }, + } + + // Pre-populate the fake client with ProxyClasses. + builder = builder.WithObjects(pcLEStaging, pcLEStagingFalse, pcOther, pg). + WithStatusSubresource(pcLEStaging, pcLEStagingFalse, pcOther, pg) + + fc := builder.Build() + + // If the test case needs a ProxyClass to exist, ensure it is set to Ready. + if tt.proxyClassPerResource != "" || tt.defaultProxyClass != "" { + name := tt.proxyClassPerResource + if name == "" { + name = tt.defaultProxyClass + } + setProxyClassReady(t, fc, cl, name) + } + + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + tsProxyImage: testProxyImage, + defaultTags: []string{"tag:test"}, + defaultProxyClass: tt.defaultProxyClass, + Client: fc, + tsClient: &fakeTSClient{}, + log: zl.Sugar(), + clock: cl, + } + + expectReconciled(t, reconciler, "", pg.Name) + + // Verify that the StatefulSet created for ProxyGrup has + // the expected setting for the staging endpoint. + sts := &appsv1.StatefulSet{} + if err := fc.Get(t.Context(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + if tt.useLEStagingEndpoint { + verifyEnvVar(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL", letsEncryptStagingEndpoint) + } else { + verifyEnvVarNotPresent(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL") + } + }) + } +} + +type leStagingTestCase struct { + name string + // ProxyClass set on ProxyGroup or Ingress resource. + proxyClassPerResource string + // Default ProxyClass. + defaultProxyClass string + useLEStagingEndpoint bool +} + +// Shared test cases for LE staging endpoint configuration for ProxyGroup and +// non-HA Ingress. +func testCasesForLEStagingTests() []leStagingTestCase { + return []leStagingTestCase{ + { + name: "with_staging_proxyclass", + proxyClassPerResource: "le-staging", + useLEStagingEndpoint: true, + }, + { + name: "with_staging_proxyclass_false", + proxyClassPerResource: "le-staging-false", + useLEStagingEndpoint: false, + }, + { + name: "with_other_proxyclass", + proxyClassPerResource: "other", + useLEStagingEndpoint: false, + }, + { + name: "no_proxyclass", + proxyClassPerResource: "", + useLEStagingEndpoint: false, + }, + { + name: "with_default_staging_proxyclass", + proxyClassPerResource: "", + defaultProxyClass: "le-staging", + useLEStagingEndpoint: true, + }, + { + name: "with_default_other_proxyclass", + proxyClassPerResource: "", + defaultProxyClass: "other", + useLEStagingEndpoint: false, + }, + { + name: "with_default_staging_proxyclass_false", + proxyClassPerResource: "", + defaultProxyClass: "le-staging-false", + useLEStagingEndpoint: false, + }, + } +} diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index cc6bdb8fe..c52ffce85 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -7,19 +7,21 @@ package main import ( "context" - "crypto/sha256" _ "embed" "encoding/json" "errors" "fmt" "net/http" "os" + "path" "slices" + "strconv" "strings" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" @@ -27,6 +29,7 @@ import ( "k8s.io/apiserver/pkg/storage/names" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/yaml" + "tailscale.com/client/tailscale" "tailscale.com/ipn" tsoperator "tailscale.com/k8s-operator" @@ -43,16 +46,14 @@ const ( // Labels that the operator sets on StatefulSets and Pods. If you add a // new label here, do also add it to tailscaleManagedLabels var to // ensure that it does not get overwritten by ProxyClass configuration. - LabelManaged = "tailscale.com/managed" LabelParentType = "tailscale.com/parent-resource-type" LabelParentName = "tailscale.com/parent-resource" LabelParentNamespace = "tailscale.com/parent-resource-ns" - // LabelProxyClass can be set by users on Connectors, tailscale - // Ingresses and Services that define cluster ingress or cluster egress, - // to specify that configuration in this ProxyClass should be applied to - // resources created for the Connector, Ingress or Service. - LabelProxyClass = "tailscale.com/proxy-class" + // LabelProxyClass can be set by users on tailscale Ingresses and Services that define cluster ingress or + // cluster egress, to specify that configuration in this ProxyClass should be applied to resources created for + // the Ingress or Service. + LabelAnnotationProxyClass = "tailscale.com/proxy-class" FinalizerName = "tailscale.com/finalizer" @@ -62,7 +63,7 @@ const ( AnnotationHostname = "tailscale.com/hostname" annotationTailnetTargetIPOld = "tailscale.com/ts-tailnet-target-ip" AnnotationTailnetTargetIP = "tailscale.com/tailnet-ip" - //MagicDNS name of tailnet node. + // MagicDNS name of tailnet node. AnnotationTailnetTargetFQDN = "tailscale.com/tailnet-fqdn" AnnotationProxyGroup = "tailscale.com/proxy-group" @@ -92,18 +93,30 @@ const ( podAnnotationLastSetClusterDNSName = "tailscale.com/operator-last-set-cluster-dns-name" podAnnotationLastSetTailnetTargetIP = "tailscale.com/operator-last-set-ts-tailnet-target-ip" podAnnotationLastSetTailnetTargetFQDN = "tailscale.com/operator-last-set-ts-tailnet-target-fqdn" - // podAnnotationLastSetConfigFileHash is sha256 hash of the current tailscaled configuration contents. - podAnnotationLastSetConfigFileHash = "tailscale.com/operator-last-set-config-file-hash" + + proxyTypeEgress = "egress_service" + proxyTypeIngressService = "ingress_service" + proxyTypeIngressResource = "ingress_resource" + proxyTypeConnector = "connector" + proxyTypeProxyGroup = "proxygroup" + + envVarTSLocalAddrPort = "TS_LOCAL_ADDR_PORT" + defaultLocalAddrPort = 9002 // metrics and health check port + + letsEncryptStagingEndpoint = "https://acme-staging-v02.api.letsencrypt.org/directory" + + mainContainerName = "tailscale" ) var ( // tailscaleManagedLabels are label keys that tailscale operator sets on StatefulSets and Pods. - tailscaleManagedLabels = []string{LabelManaged, LabelParentType, LabelParentName, LabelParentNamespace, "app"} + tailscaleManagedLabels = []string{kubetypes.LabelManaged, LabelParentType, LabelParentName, LabelParentNamespace, "app"} // tailscaleManagedAnnotations are annotation keys that tailscale operator sets on StatefulSets and Pods. - tailscaleManagedAnnotations = []string{podAnnotationLastSetClusterIP, podAnnotationLastSetTailnetTargetIP, podAnnotationLastSetTailnetTargetFQDN, podAnnotationLastSetConfigFileHash} + tailscaleManagedAnnotations = []string{podAnnotationLastSetClusterIP, podAnnotationLastSetTailnetTargetIP, podAnnotationLastSetTailnetTargetFQDN} ) type tailscaleSTSConfig struct { + Replicas int32 ParentResourceName string ParentResourceUID string ChildResourceLabels map[string]string @@ -122,6 +135,8 @@ type tailscaleSTSConfig struct { Hostname string Tags []string // if empty, use defaultTags + proxyType string + // Connector specifies a configuration of a Connector instance if that's // what this StatefulSet should be created for. Connector *connector @@ -129,13 +144,23 @@ type tailscaleSTSConfig struct { ProxyClassName string // name of ProxyClass if one needs to be applied to the proxy ProxyClass *tsapi.ProxyClass // ProxyClass that needs to be applied to the proxy (if there is one) + + // LoginServer denotes the URL of the control plane that should be used by the proxy. + LoginServer string + + // HostnamePrefix specifies the desired prefix for the device's hostname. The hostname will be suffixed with the + // ordinal number generated by the StatefulSet. + HostnamePrefix string } type connector struct { - // routes is a list of subnet routes that this Connector should expose. + // routes is a list of routes that this Connector should advertise either as a subnet router or as an app + // connector. routes string // isExitNode defines whether this Connector should act as an exit node. isExitNode bool + // isAppConnector defines whether this Connector should act as an app connector. + isAppConnector bool } type tsnetServer interface { CertDomains() []string @@ -150,6 +175,7 @@ type tailscaleSTSReconciler struct { proxyImage string proxyPriorityClassName string tsFirewallMode string + loginServer string } func (sts tailscaleSTSReconciler) validate() error { @@ -160,8 +186,8 @@ func (sts tailscaleSTSReconciler) validate() error { } // IsHTTPSEnabledOnTailnet reports whether HTTPS is enabled on the tailnet. -func (a *tailscaleSTSReconciler) IsHTTPSEnabledOnTailnet() bool { - return len(a.tsnetServer.CertDomains()) > 0 +func IsHTTPSEnabledOnTailnet(tsnetServer tsnetServer) bool { + return len(tsnetServer.CertDomains()) > 0 } // Provision ensures that the StatefulSet for the given service is running and @@ -186,22 +212,31 @@ func (a *tailscaleSTSReconciler) Provision(ctx context.Context, logger *zap.Suga } sts.ProxyClass = proxyClass - secretName, tsConfigHash, configs, err := a.createOrGetSecret(ctx, logger, sts, hsvc) + secretNames, err := a.provisionSecrets(ctx, logger, sts, hsvc) if err != nil { return nil, fmt.Errorf("failed to create or get API key secret: %w", err) } - _, err = a.reconcileSTS(ctx, logger, sts, hsvc, secretName, tsConfigHash, configs) + + _, err = a.reconcileSTS(ctx, logger, sts, hsvc, secretNames) if err != nil { return nil, fmt.Errorf("failed to reconcile statefulset: %w", err) } - + mo := &metricsOpts{ + proxyStsName: hsvc.Name, + tsNamespace: hsvc.Namespace, + proxyLabels: hsvc.Labels, + proxyType: sts.proxyType, + } + if err = reconcileMetricsResources(ctx, logger, mo, sts.ProxyClass, a.Client); err != nil { + return nil, fmt.Errorf("failed to ensure metrics resources: %w", err) + } return hsvc, nil } // Cleanup removes all resources associated that were created by Provision with // the given labels. It returns true when all resources have been removed, // otherwise it returns false and the caller should retry later. -func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.SugaredLogger, labels map[string]string) (done bool, _ error) { +func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.SugaredLogger, labels map[string]string, typ string) (done bool, _ error) { // Need to delete the StatefulSet first, and delete it with foreground // cascading deletion. That way, the pod that's writing to the Secret will // stop running before we start looking at the Secret's contents, and @@ -212,6 +247,7 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare if err != nil { return false, fmt.Errorf("getting statefulset: %w", err) } + if sts != nil { if !sts.GetDeletionTimestamp().IsZero() { // Deletion in progress, check again later. We'll get another @@ -219,29 +255,39 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare logger.Debugf("waiting for statefulset %s/%s deletion", sts.GetNamespace(), sts.GetName()) return false, nil } - err := a.DeleteAllOf(ctx, &appsv1.StatefulSet{}, client.InNamespace(a.operatorNamespace), client.MatchingLabels(labels), client.PropagationPolicy(metav1.DeletePropagationForeground)) - if err != nil { + + options := []client.DeleteAllOfOption{ + client.InNamespace(a.operatorNamespace), + client.MatchingLabels(labels), + client.PropagationPolicy(metav1.DeletePropagationForeground), + } + + if err = a.DeleteAllOf(ctx, &appsv1.StatefulSet{}, options...); err != nil { return false, fmt.Errorf("deleting statefulset: %w", err) } + logger.Debugf("started deletion of statefulset %s/%s", sts.GetNamespace(), sts.GetName()) return false, nil } - id, _, _, err := a.DeviceInfo(ctx, labels) + devices, err := a.DeviceInfo(ctx, labels, logger) if err != nil { return false, fmt.Errorf("getting device info: %w", err) } - if id != "" { - logger.Debugf("deleting device %s from control", string(id)) - if err := a.tsClient.DeleteDevice(ctx, string(id)); err != nil { - errResp := &tailscale.ErrResponse{} - if ok := errors.As(err, errResp); ok && errResp.Status == http.StatusNotFound { - logger.Debugf("device %s not found, likely because it has already been deleted from control", string(id)) + + for _, dev := range devices { + if dev.id != "" { + logger.Debugf("deleting device %s from control", string(dev.id)) + if err = a.tsClient.DeleteDevice(ctx, string(dev.id)); err != nil { + errResp := &tailscale.ErrResponse{} + if ok := errors.As(err, errResp); ok && errResp.Status == http.StatusNotFound { + logger.Debugf("device %s not found, likely because it has already been deleted from control", string(dev.id)) + } else { + return false, fmt.Errorf("deleting device: %w", err) + } } else { - return false, fmt.Errorf("deleting device: %w", err) + logger.Debugf("device %s deleted from control", string(dev.id)) } - } else { - logger.Debugf("device %s deleted from control", string(id)) } } @@ -254,6 +300,15 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare return false, err } } + mo := &metricsOpts{ + proxyLabels: labels, + tsNamespace: a.operatorNamespace, + proxyType: typ, + } + if err = maybeCleanupMetricsResources(ctx, mo, a.Client); err != nil { + return false, fmt.Errorf("error cleaning up metrics resources: %w", err) + } + return true, nil } @@ -304,149 +359,238 @@ func (a *tailscaleSTSReconciler) reconcileHeadlessService(ctx context.Context, l return createOrUpdate(ctx, a.Client, a.operatorNamespace, hsvc, func(svc *corev1.Service) { svc.Spec = hsvc.Spec }) } -func (a *tailscaleSTSReconciler) createOrGetSecret(ctx context.Context, logger *zap.SugaredLogger, stsC *tailscaleSTSConfig, hsvc *corev1.Service) (secretName, hash string, configs tailscaleConfigs, _ error) { - secret := &corev1.Secret{ - ObjectMeta: metav1.ObjectMeta{ - // Hardcode a -0 suffix so that in future, if we support - // multiple StatefulSet replicas, we can provision -N for - // those. - Name: hsvc.Name + "-0", - Namespace: a.operatorNamespace, - Labels: stsC.ChildResourceLabels, - }, - } - var orig *corev1.Secret // unmodified copy of secret - if err := a.Get(ctx, client.ObjectKeyFromObject(secret), secret); err == nil { - logger.Debugf("secret %s/%s already exists", secret.GetNamespace(), secret.GetName()) - orig = secret.DeepCopy() - } else if !apierrors.IsNotFound(err) { - return "", "", nil, err - } +func (a *tailscaleSTSReconciler) provisionSecrets(ctx context.Context, logger *zap.SugaredLogger, stsC *tailscaleSTSConfig, hsvc *corev1.Service) ([]string, error) { + secretNames := make([]string, stsC.Replicas) + + // Start by ensuring we have Secrets for the desired number of replicas. This will handle both creating and scaling + // up a StatefulSet. + for i := range stsC.Replicas { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-%d", hsvc.Name, i), + Namespace: a.operatorNamespace, + Labels: stsC.ChildResourceLabels, + }, + } - var authKey string - if orig == nil { - // Initially it contains only tailscaled config, but when the - // proxy starts, it will also store there the state, certs and - // ACME account key. - sts, err := getSingleObject[appsv1.StatefulSet](ctx, a.Client, a.operatorNamespace, stsC.ChildResourceLabels) - if err != nil { - return "", "", nil, err - } - if sts != nil { - // StatefulSet exists, so we have already created the secret. - // If the secret is missing, they should delete the StatefulSet. - logger.Errorf("Tailscale proxy secret doesn't exist, but the corresponding StatefulSet %s/%s already does. Something is wrong, please delete the StatefulSet.", sts.GetNamespace(), sts.GetName()) - return "", "", nil, nil - } - // Create API Key secret which is going to be used by the statefulset - // to authenticate with Tailscale. - logger.Debugf("creating authkey for new tailscale proxy") - tags := stsC.Tags - if len(tags) == 0 { - tags = a.defaultTags - } - authKey, err = newAuthKey(ctx, a.tsClient, tags) + // If we only have a single replica, use the hostname verbatim. Otherwise, use the hostname prefix and add + // an ordinal suffix. + hostname := stsC.Hostname + if stsC.HostnamePrefix != "" { + hostname = fmt.Sprintf("%s-%d", stsC.HostnamePrefix, i) + } + + secretNames[i] = secret.Name + + var orig *corev1.Secret // unmodified copy of secret + if err := a.Get(ctx, client.ObjectKeyFromObject(secret), secret); err == nil { + logger.Debugf("secret %s/%s already exists", secret.GetNamespace(), secret.GetName()) + orig = secret.DeepCopy() + } else if !apierrors.IsNotFound(err) { + return nil, err + } + + var ( + authKey string + err error + ) + if orig == nil { + // Create API Key secret which is going to be used by the statefulset + // to authenticate with Tailscale. + logger.Debugf("creating authkey for new tailscale proxy") + tags := stsC.Tags + if len(tags) == 0 { + tags = a.defaultTags + } + authKey, err = newAuthKey(ctx, a.tsClient, tags) + if err != nil { + return nil, err + } + } + + configs, err := tailscaledConfig(stsC, authKey, orig, hostname) if err != nil { - return "", "", nil, err + return nil, fmt.Errorf("error creating tailscaled config: %w", err) + } + + latest := tailcfg.CapabilityVersion(-1) + var latestConfig ipn.ConfigVAlpha + for key, val := range configs { + fn := tsoperator.TailscaledConfigFileName(key) + b, err := json.Marshal(val) + if err != nil { + return nil, fmt.Errorf("error marshalling tailscaled config: %w", err) + } + + mak.Set(&secret.StringData, fn, string(b)) + if key > latest { + latest = key + latestConfig = val + } + } + + if stsC.ServeConfig != nil { + j, err := json.Marshal(stsC.ServeConfig) + if err != nil { + return nil, err + } + + mak.Set(&secret.StringData, "serve-config", string(j)) + } + + if orig != nil && !apiequality.Semantic.DeepEqual(latest, orig) { + logger.With("config", sanitizeConfig(latestConfig)).Debugf("patching the existing proxy Secret") + if err = a.Patch(ctx, secret, client.MergeFrom(orig)); err != nil { + return nil, err + } + } else { + logger.With("config", sanitizeConfig(latestConfig)).Debugf("creating a new Secret for the proxy") + if err = a.Create(ctx, secret); err != nil { + return nil, err + } } } - configs, err := tailscaledConfig(stsC, authKey, orig) - if err != nil { - return "", "", nil, fmt.Errorf("error creating tailscaled config: %w", err) - } - hash, err = tailscaledConfigHash(configs) - if err != nil { - return "", "", nil, fmt.Errorf("error calculating hash of tailscaled configs: %w", err) + + // Next, we check if we have additional secrets and remove them and their associated device. This happens when we + // scale an StatefulSet down. + var secrets corev1.SecretList + if err := a.List(ctx, &secrets, client.InNamespace(a.operatorNamespace), client.MatchingLabels(stsC.ChildResourceLabels)); err != nil { + return nil, err } - latest := tailcfg.CapabilityVersion(-1) - var latestConfig ipn.ConfigVAlpha - for key, val := range configs { - fn := tsoperator.TailscaledConfigFileNameForCap(key) - b, err := json.Marshal(val) - if err != nil { - return "", "", nil, fmt.Errorf("error marshalling tailscaled config: %w", err) + for _, secret := range secrets.Items { + var ordinal int32 + if _, err := fmt.Sscanf(secret.Name, hsvc.Name+"-%d", &ordinal); err != nil { + return nil, err } - mak.Set(&secret.StringData, fn, string(b)) - if key > latest { - latest = key - latestConfig = val + + if ordinal < stsC.Replicas { + continue } - } - if stsC.ServeConfig != nil { - j, err := json.Marshal(stsC.ServeConfig) + dev, err := deviceInfo(&secret, "", logger) if err != nil { - return "", "", nil, err + return nil, err } - mak.Set(&secret.StringData, "serve-config", string(j)) - } - if orig != nil { - logger.Debugf("patching the existing proxy Secret with tailscaled config %s", sanitizeConfigBytes(latestConfig)) - if err := a.Patch(ctx, secret, client.MergeFrom(orig)); err != nil { - return "", "", nil, err + if dev != nil && dev.id != "" { + var errResp *tailscale.ErrResponse + + err = a.tsClient.DeleteDevice(ctx, string(dev.id)) + switch { + case errors.As(err, &errResp) && errResp.Status == http.StatusNotFound: + // This device has possibly already been deleted in the admin console. So we can ignore this + // and move on to removing the secret. + case err != nil: + return nil, err + } } - } else { - logger.Debugf("creating a new Secret for the proxy with tailscaled config %s", sanitizeConfigBytes(latestConfig)) - if err := a.Create(ctx, secret); err != nil { - return "", "", nil, err + + if err = a.Delete(ctx, &secret); err != nil { + return nil, err } } - return secret.Name, hash, configs, nil + + return secretNames, nil } -// sanitizeConfigBytes returns ipn.ConfigVAlpha in string form with redacted -// auth key. -func sanitizeConfigBytes(c ipn.ConfigVAlpha) string { +// sanitizeConfig returns an ipn.ConfigVAlpha with sensitive fields redacted. Since we pump everything +// into JSON-encoded logs it's easier to read this with a .With method than converting it to a string. +func sanitizeConfig(c ipn.ConfigVAlpha) ipn.ConfigVAlpha { + // Explicitly redact AuthKey because we never want it appearing in logs. Never populate this with the + // actual auth key. if c.AuthKey != nil { c.AuthKey = ptr.To("**redacted**") } - sanitizedBytes, err := json.Marshal(c) - if err != nil { - return "invalid config" - } - return string(sanitizedBytes) + + return c } -// DeviceInfo returns the device ID, hostname and IPs for the Tailscale device -// that acts as an operator proxy. It retrieves info from a Kubernetes Secret -// labeled with the provided labels. -// Either of device ID, hostname and IPs can be empty string if not found in the Secret. -func (a *tailscaleSTSReconciler) DeviceInfo(ctx context.Context, childLabels map[string]string) (id tailcfg.StableNodeID, hostname string, ips []string, err error) { - sec, err := getSingleObject[corev1.Secret](ctx, a.Client, a.operatorNamespace, childLabels) - if err != nil { - return "", "", nil, err +// DeviceInfo returns the device ID, hostname, IPs and capver for the Tailscale device that acts as an operator proxy. +// It retrieves info from a Kubernetes Secret labeled with the provided labels. Capver is cross-validated against the +// Pod to ensure that it is the currently running Pod that set the capver. If the Pod or the Secret does not exist, the +// returned capver is -1. Either of device ID, hostname and IPs can be empty string if not found in the Secret. +func (a *tailscaleSTSReconciler) DeviceInfo(ctx context.Context, childLabels map[string]string, logger *zap.SugaredLogger) ([]*device, error) { + var secrets corev1.SecretList + if err := a.List(ctx, &secrets, client.InNamespace(a.operatorNamespace), client.MatchingLabels(childLabels)); err != nil { + return nil, err } - if sec == nil { - return "", "", nil, nil + + devices := make([]*device, 0) + for _, sec := range secrets.Items { + podUID := "" + pod := new(corev1.Pod) + err := a.Get(ctx, types.NamespacedName{Namespace: sec.Namespace, Name: sec.Name}, pod) + switch { + case apierrors.IsNotFound(err): + // If the Pod is not found, we won't have its UID. We can still get the device information but the + // capability version will be unknown. + case err != nil: + return nil, err + default: + podUID = string(pod.ObjectMeta.UID) + } + + info, err := deviceInfo(&sec, podUID, logger) + if err != nil { + return nil, err + } + + if info != nil { + devices = append(devices, info) + } } - return deviceInfo(sec) + return devices, nil +} + +// device contains tailscale state of a proxy device as gathered from its tailscale state Secret. +type device struct { + id tailcfg.StableNodeID // device's stable ID + hostname string // MagicDNS name of the device + ips []string // Tailscale IPs of the device + // ingressDNSName is the L7 Ingress DNS name. In practice this will be the same value as hostname, but only set + // when the device has been configured to serve traffic on it via 'tailscale serve'. + ingressDNSName string + capver tailcfg.CapabilityVersion } -func deviceInfo(sec *corev1.Secret) (id tailcfg.StableNodeID, hostname string, ips []string, err error) { - id = tailcfg.StableNodeID(sec.Data["device_id"]) +func deviceInfo(sec *corev1.Secret, podUID string, log *zap.SugaredLogger) (dev *device, err error) { + id := tailcfg.StableNodeID(sec.Data[kubetypes.KeyDeviceID]) if id == "" { - return "", "", nil, nil + return dev, nil } + dev = &device{id: id} // Kubernetes chokes on well-formed FQDNs with the trailing dot, so we have // to remove it. - hostname = strings.TrimSuffix(string(sec.Data["device_fqdn"]), ".") - if hostname == "" { + dev.hostname = strings.TrimSuffix(string(sec.Data[kubetypes.KeyDeviceFQDN]), ".") + if dev.hostname == "" { // Device ID gets stored and retrieved in a different flow than // FQDN and IPs. A device that acts as Kubernetes operator - // proxy, but whose route setup has failed might have an device + // proxy, but whose route setup has failed might have a device // ID, but no FQDN/IPs. If so, return the ID, to allow the // operator to clean up such devices. - return id, "", nil, nil + return dev, nil + } + dev.ingressDNSName = dev.hostname + pcv := proxyCapVer(sec, podUID, log) + dev.capver = pcv + // TODO(irbekrm): we fall back to using the hostname field to determine Ingress's hostname to ensure backwards + // compatibility. In 1.82 we can remove this fallback mechanism. + if pcv >= 109 { + dev.ingressDNSName = strings.TrimSuffix(string(sec.Data[kubetypes.KeyHTTPSEndpoint]), ".") + if strings.EqualFold(dev.ingressDNSName, kubetypes.ValueNoHTTPS) { + dev.ingressDNSName = "" + } } - if rawDeviceIPs, ok := sec.Data["device_ips"]; ok { + if rawDeviceIPs, ok := sec.Data[kubetypes.KeyDeviceIPs]; ok { + ips := make([]string, 0) if err := json.Unmarshal(rawDeviceIPs, &ips); err != nil { - return "", "", nil, err + return nil, err } + dev.ips = ips } - return id, hostname, ips, nil + return dev, nil } func newAuthKey(ctx context.Context, tsClient tsClient, tags []string) (string, error) { @@ -473,7 +617,7 @@ var proxyYaml []byte //go:embed deploy/manifests/userspace-proxy.yaml var userspaceProxyYaml []byte -func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string, configs map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) (*appsv1.StatefulSet, error) { +func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecrets []string) (*appsv1.StatefulSet, error) { ss := new(appsv1.StatefulSet) if sts.ServeConfig != nil && sts.ForwardClusterTrafficViaL7IngressProxy != true { // If forwarding cluster traffic via is required we need non-userspace + NET_ADMIN + forwarding if err := yaml.Unmarshal(userspaceProxyYaml, &ss); err != nil { @@ -512,46 +656,46 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S pod.Labels[key] = val // sync StatefulSet labels to Pod to make it easier for users to select the Pod } + if sts.Replicas > 0 { + ss.Spec.Replicas = ptr.To(sts.Replicas) + } + // Generic containerboot configuration options. container.Env = append(container.Env, corev1.EnvVar{ Name: "TS_KUBE_SECRET", - Value: proxySecret, - }, - corev1.EnvVar{ - // Old tailscaled config key is still used for backwards compatibility. - Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", - Value: "/etc/tsconfig/tailscaled", + Value: "$(POD_NAME)", }, corev1.EnvVar{ - // New style is in the form of cap-.hujson. Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", - Value: "/etc/tsconfig", + Value: "/etc/tsconfig/$(POD_NAME)", }, ) + if sts.ForwardClusterTrafficViaL7IngressProxy { container.Env = append(container.Env, corev1.EnvVar{ Name: "EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS", Value: "true", }) } - // Configure containeboot to run tailscaled with a configfile read from the state Secret. - mak.Set(&ss.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash, tsConfigHash) - configVolume := corev1.Volume{ - Name: "tailscaledconfig", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: proxySecret, + for i, secret := range proxySecrets { + configVolume := corev1.Volume{ + Name: "tailscaledconfig-" + strconv.Itoa(i), + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: secret, + }, }, - }, + } + + pod.Spec.Volumes = append(ss.Spec.Template.Spec.Volumes, configVolume) + container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{ + Name: fmt.Sprintf("tailscaledconfig-%d", i), + ReadOnly: true, + MountPath: path.Join("/etc/tsconfig/", secret), + }) } - pod.Spec.Volumes = append(ss.Spec.Template.Spec.Volumes, configVolume) - container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{ - Name: "tailscaledconfig", - ReadOnly: true, - MountPath: "/etc/tsconfig", - }) if a.tsFirewallMode != "" { container.Env = append(container.Env, corev1.EnvVar{ @@ -589,23 +733,29 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S } else if sts.ServeConfig != nil { container.Env = append(container.Env, corev1.EnvVar{ Name: "TS_SERVE_CONFIG", - Value: "/etc/tailscaled/serve-config", - }) - container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{ - Name: "serve-config", - ReadOnly: true, - MountPath: "/etc/tailscaled", + Value: "/etc/tailscaled/$(POD_NAME)/serve-config", }) - pod.Spec.Volumes = append(ss.Spec.Template.Spec.Volumes, corev1.Volume{ - Name: "serve-config", - VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{ - SecretName: proxySecret, - Items: []corev1.KeyToPath{{Key: "serve-config", Path: "serve-config"}}, + + for i, secret := range proxySecrets { + container.VolumeMounts = append(container.VolumeMounts, corev1.VolumeMount{ + Name: "serve-config-" + strconv.Itoa(i), + ReadOnly: true, + MountPath: path.Join("/etc/tailscaled", secret), + }) + + pod.Spec.Volumes = append(ss.Spec.Template.Spec.Volumes, corev1.Volume{ + Name: "serve-config-" + strconv.Itoa(i), + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: secret, + Items: []corev1.KeyToPath{{Key: "serve-config", Path: "serve-config"}}, + }, }, - }, - }) + }) + } + } + app, err := appInfoForProxy(sts) if err != nil { // No need to error out if now or in future we end up in a @@ -668,24 +818,60 @@ func mergeStatefulSetLabelsOrAnnots(current, custom map[string]string, managed [ return custom } +func debugSetting(pc *tsapi.ProxyClass) bool { + if pc == nil || + pc.Spec.StatefulSet == nil || + pc.Spec.StatefulSet.Pod == nil || + pc.Spec.StatefulSet.Pod.TailscaleContainer == nil || + pc.Spec.StatefulSet.Pod.TailscaleContainer.Debug == nil { + // This default will change to false in 1.82.0. + return pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable + } + + return pc.Spec.StatefulSet.Pod.TailscaleContainer.Debug.Enable +} + func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, stsCfg *tailscaleSTSConfig, logger *zap.SugaredLogger) *appsv1.StatefulSet { if pc == nil || ss == nil { return ss } - if pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable { - if stsCfg.TailnetTargetFQDN == "" && stsCfg.TailnetTargetIP == "" && !stsCfg.ForwardClusterTrafficViaL7IngressProxy { - enableMetrics(ss, pc) - } else if stsCfg.ForwardClusterTrafficViaL7IngressProxy { + + metricsEnabled := pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable + debugEnabled := debugSetting(pc) + if metricsEnabled || debugEnabled { + isEgress := stsCfg != nil && (stsCfg.TailnetTargetFQDN != "" || stsCfg.TailnetTargetIP != "") + isForwardingL7Ingress := stsCfg != nil && stsCfg.ForwardClusterTrafficViaL7IngressProxy + if isEgress { // TODO (irbekrm): fix this // For Ingress proxies that have been configured with // tailscale.com/experimental-forward-cluster-traffic-via-ingress // annotation, all cluster traffic is forwarded to the // Ingress backend(s). - logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for Ingress proxies that accept cluster traffic.") - } else { + logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for egress proxies.") + } else if isForwardingL7Ingress { // TODO (irbekrm): fix this // For egress proxies, currently all cluster traffic is forwarded to the tailnet target. logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for Ingress proxies that accept cluster traffic.") + } else { + enableEndpoints(ss, metricsEnabled, debugEnabled) + } + } + + if stsCfg != nil { + usesLetsEncrypt := stsCfg.proxyType == proxyTypeIngressResource || + stsCfg.proxyType == string(tsapi.ProxyGroupTypeIngress) || + stsCfg.proxyType == string(tsapi.ProxyGroupTypeKubernetesAPIServer) + + if pc.Spec.UseLetsEncryptStagingEnvironment && usesLetsEncrypt { + for i, c := range ss.Spec.Template.Spec.Containers { + if isMainContainer(&c) { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "TS_DEBUG_ACME_DIRECTORY_URL", + Value: letsEncryptStagingEndpoint, + }) + break + } + } } } @@ -694,7 +880,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, } // Update StatefulSet metadata. - if wantsSSLabels := pc.Spec.StatefulSet.Labels; len(wantsSSLabels) > 0 { + if wantsSSLabels := pc.Spec.StatefulSet.Labels.Parse(); len(wantsSSLabels) > 0 { ss.ObjectMeta.Labels = mergeStatefulSetLabelsOrAnnots(ss.ObjectMeta.Labels, wantsSSLabels, tailscaleManagedLabels) } if wantsSSAnnots := pc.Spec.StatefulSet.Annotations; len(wantsSSAnnots) > 0 { @@ -706,7 +892,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return ss } wantsPod := pc.Spec.StatefulSet.Pod - if wantsPodLabels := wantsPod.Labels; len(wantsPodLabels) > 0 { + if wantsPodLabels := wantsPod.Labels.Parse(); len(wantsPodLabels) > 0 { ss.Spec.Template.ObjectMeta.Labels = mergeStatefulSetLabelsOrAnnots(ss.Spec.Template.ObjectMeta.Labels, wantsPodLabels, tailscaleManagedLabels) } if wantsPodAnnots := wantsPod.Annotations; len(wantsPodAnnots) > 0 { @@ -718,6 +904,14 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, ss.Spec.Template.Spec.NodeSelector = wantsPod.NodeSelector ss.Spec.Template.Spec.Affinity = wantsPod.Affinity ss.Spec.Template.Spec.Tolerations = wantsPod.Tolerations + ss.Spec.Template.Spec.PriorityClassName = wantsPod.PriorityClassName + ss.Spec.Template.Spec.TopologySpreadConstraints = wantsPod.TopologySpreadConstraints + if wantsPod.DNSPolicy != nil { + ss.Spec.Template.Spec.DNSPolicy = *wantsPod.DNSPolicy + } + if wantsPod.DNSConfig != nil { + ss.Spec.Template.Spec.DNSConfig = wantsPod.DNSConfig + } // Update containers. updateContainer := func(overlay *tsapi.Container, base corev1.Container) corev1.Container { @@ -746,7 +940,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return base } for i, c := range ss.Spec.Template.Spec.Containers { - if c.Name == "tailscale" { + if isMainContainer(&c) { ss.Spec.Template.Spec.Containers[i] = updateContainer(wantsPod.TailscaleContainer, ss.Spec.Template.Spec.Containers[i]) break } @@ -762,60 +956,93 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return ss } -func enableMetrics(ss *appsv1.StatefulSet, pc *tsapi.ProxyClass) { +func enableEndpoints(ss *appsv1.StatefulSet, metrics, debug bool) { for i, c := range ss.Spec.Template.Spec.Containers { - if c.Name == "tailscale" { - // Serve metrics on on :9001/debug/metrics. If - // we didn't specify Pod IP here, the proxy would, in - // some cases, also listen to its Tailscale IP- we don't - // want folks to start relying on this side-effect as a - // feature. - ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(POD_IP):9001"}) - ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, corev1.ContainerPort{Name: "metrics", Protocol: "TCP", HostPort: 9001, ContainerPort: 9001}) + if isMainContainer(&c) { + if debug { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, + // Serve tailscaled's debug metrics on on + // :9001/debug/metrics. If we didn't specify Pod IP + // here, the proxy would, in some cases, also listen to its + // Tailscale IP- we don't want folks to start relying on this + // side-effect as a feature. + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001", + }, + // TODO(tomhjp): Can remove this env var once 1.76.x is no + // longer supported. + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + ) + + ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, + corev1.ContainerPort{ + Name: "debug", + Protocol: "TCP", + ContainerPort: 9001, + }, + ) + } + + if metrics { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, + // Serve client metrics on :9002/metrics. + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, + corev1.ContainerPort{ + Name: "metrics", + Protocol: "TCP", + ContainerPort: 9002, + }, + ) + } + break } } } -func readAuthKey(secret *corev1.Secret, key string) (*string, error) { - origConf := &ipn.ConfigVAlpha{} - if err := json.Unmarshal([]byte(secret.Data[key]), origConf); err != nil { - return nil, fmt.Errorf("error unmarshaling previous tailscaled config in %q: %w", key, err) - } - return origConf.AuthKey, nil +func isMainContainer(c *corev1.Container) bool { + return c.Name == mainContainerName } -// tailscaledConfig takes a proxy config, a newly generated auth key if -// generated and a Secret with the previous proxy state and auth key and -// returns tailscaled configuration and a hash of that configuration. -// -// As of 2024-05-09 it also returns legacy tailscaled config without the -// later added NoStatefulFilter field to support proxies older than cap95. -// TODO (irbekrm): remove the legacy config once we no longer need to support -// versions older than cap94, -// https://tailscale.com/kb/1236/kubernetes-operator#operator-and-proxies -func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *corev1.Secret) (tailscaleConfigs, error) { +// tailscaledConfig takes a proxy config, a newly generated auth key if generated and a Secret with the previous proxy +// state and auth key and returns tailscaled config files for currently supported proxy versions. +func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *corev1.Secret, hostname string) (tailscaledConfigs, error) { conf := &ipn.ConfigVAlpha{ Version: "alpha0", AcceptDNS: "false", AcceptRoutes: "false", // AcceptRoutes defaults to true Locked: "false", - Hostname: &stsC.Hostname, - NoStatefulFiltering: "false", + Hostname: &hostname, + NoStatefulFiltering: "true", // Explicitly enforce default value, see #14216 + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, } - // For egress proxies only, we need to ensure that stateful filtering is - // not in place so that traffic from cluster can be forwarded via - // Tailscale IPs. - if stsC.TailnetTargetFQDN != "" || stsC.TailnetTargetIP != "" { - conf.NoStatefulFiltering = "true" + if stsC.LoginServer != "" { + conf.ServerURL = &stsC.LoginServer } + if stsC.Connector != nil { routes, err := netutil.CalcAdvertiseRoutes(stsC.Connector.routes, stsC.Connector.isExitNode) if err != nil { return nil, fmt.Errorf("error calculating routes: %w", err) } conf.AdvertiseRoutes = routes + if stsC.Connector.isAppConnector { + conf.AppConnector.Advertise = true + } } if shouldAcceptRoutes(stsC.ProxyClass) { conf.AcceptRoutes = "true" @@ -830,15 +1057,20 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co } conf.AuthKey = key } + capVerConfigs := make(map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) + capVerConfigs[107] = *conf + + // AppConnector config option is only understood by clients of capver 107 and newer. + conf.AppConnector = nil capVerConfigs[95] = *conf - // legacy config should not contain NoStatefulFiltering field. - conf.NoStatefulFiltering.Clear() - capVerConfigs[94] = *conf return capVerConfigs, nil } -func authKeyFromSecret(s *corev1.Secret) (key *string, err error) { +// latestConfigFromSecret returns the ipn.ConfigVAlpha with the highest capver +// as found in the Secret's key names, e.g. "cap-107.hujson" has capver 107. +// If no config is found, it returns nil. +func latestConfigFromSecret(s *corev1.Secret) (*ipn.ConfigVAlpha, error) { latest := tailcfg.CapabilityVersion(-1) latestStr := "" for k, data := range s.Data { @@ -855,12 +1087,31 @@ func authKeyFromSecret(s *corev1.Secret) (key *string, err error) { latest = v } } + + var conf *ipn.ConfigVAlpha + if latestStr != "" { + conf = &ipn.ConfigVAlpha{} + if err := json.Unmarshal([]byte(s.Data[latestStr]), conf); err != nil { + return nil, fmt.Errorf("error unmarshaling tailscaled config from Secret %q in field %q: %w", s.Name, latestStr, err) + } + } + + return conf, nil +} + +func authKeyFromSecret(s *corev1.Secret) (key *string, err error) { + conf, err := latestConfigFromSecret(s) + if err != nil { + return nil, err + } + // Allow for configs that don't contain an auth key. Perhaps // users have some mechanisms to delete them. Auth key is // normally not needed after the initial login. - if latestStr != "" { - return readAuthKey(s, latestStr) + if conf != nil { + key = conf.AuthKey } + return key, nil } @@ -884,36 +1135,15 @@ type ptrObject[T any] interface { *T } -type tailscaleConfigs map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha - -// hashBytes produces a hash for the provided tailscaled config that is the same across -// different invocations of this code. We do not use the -// tailscale.com/deephash.Hash here because that produces a different hash for -// the same value in different tailscale builds. The hash we are producing here -// is used to determine if the container running the Connector Tailscale node -// needs to be restarted. The container does not need restarting when the only -// thing that changed is operator version (the hash is also exposed to users via -// an annotation and might be confusing if it changes without the config having -// changed). -func tailscaledConfigHash(c tailscaleConfigs) (string, error) { - b, err := json.Marshal(c) - if err != nil { - return "", fmt.Errorf("error marshalling tailscaled configs: %w", err) - } - h := sha256.New() - if _, err = h.Write(b); err != nil { - return "", fmt.Errorf("error calculating hash: %w", err) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} +type tailscaledConfigs map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha -// createOrUpdate adds obj to the k8s cluster, unless the object already exists, -// in which case update is called to make changes to it. If update is nil, the -// existing object is returned unmodified. +// createOrMaybeUpdate adds obj to the k8s cluster, unless the object already exists, +// in which case update is called to make changes to it. If update is nil or returns +// an error, the object is returned unmodified. // // obj is looked up by its Name and Namespace if Name is set, otherwise it's // looked up by labels. -func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, ns string, obj O, update func(O)) (O, error) { +func createOrMaybeUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, ns string, obj O, update func(O) error) (O, error) { var ( existing O err error @@ -928,7 +1158,9 @@ func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, } if err == nil && existing != nil { if update != nil { - update(existing) + if err := update(existing); err != nil { + return nil, err + } if err := c.Update(ctx, existing); err != nil { return nil, err } @@ -944,6 +1176,21 @@ func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, return obj, nil } +// createOrUpdate adds obj to the k8s cluster, unless the object already exists, +// in which case update is called to make changes to it. If update is nil, the +// existing object is returned unmodified. +// +// obj is looked up by its Name and Namespace if Name is set, otherwise it's +// looked up by labels. +func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, ns string, obj O, update func(O)) (O, error) { + return createOrMaybeUpdate(ctx, c, ns, obj, func(o O) error { + if update != nil { + update(o) + } + return nil + }) +} + // getSingleObject searches for k8s objects of type T // (e.g. corev1.Service) with the given labels, and returns // it. Returns nil if no objects match the labels, and an error if @@ -1004,6 +1251,43 @@ func nameForService(svc *corev1.Service) string { return svc.Namespace + "-" + svc.Name } +// proxyClassForObject returns the proxy class for the given object. If the +// object does not have a proxy class label, it returns the default proxy class +func proxyClassForObject(o client.Object, proxyDefaultClass string) string { + proxyClass, exists := o.GetLabels()[LabelAnnotationProxyClass] + if exists { + return proxyClass + } + + proxyClass, exists = o.GetAnnotations()[LabelAnnotationProxyClass] + if exists { + return proxyClass + } + + return proxyDefaultClass +} + func isValidFirewallMode(m string) bool { return m == "auto" || m == "nftables" || m == "iptables" } + +// proxyCapVer accepts a proxy state Secret and UID of the current proxy Pod returns the capability version of the +// tailscale running in that Pod. This is best effort - if the capability version can not (currently) be determined, it +// returns -1. +func proxyCapVer(sec *corev1.Secret, podUID string, log *zap.SugaredLogger) tailcfg.CapabilityVersion { + if sec == nil || podUID == "" { + return tailcfg.CapabilityVersion(-1) + } + if len(sec.Data[kubetypes.KeyCapVer]) == 0 || len(sec.Data[kubetypes.KeyPodUID]) == 0 { + return tailcfg.CapabilityVersion(-1) + } + capVer, err := strconv.Atoi(string(sec.Data[kubetypes.KeyCapVer])) + if err != nil { + log.Infof("[unexpected]: unexpected capability version in proxy's state Secret, expected an integer, got %q", string(sec.Data[kubetypes.KeyCapVer])) + return tailcfg.CapabilityVersion(-1) + } + if !strings.EqualFold(podUID, string(sec.Data[kubetypes.KeyPodUID])) { + return tailcfg.CapabilityVersion(-1) + } + return tailcfg.CapabilityVersion(capVer) +} diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index b2b2c8b93..afe54ed98 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -18,8 +18,10 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/yaml" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" "tailscale.com/types/ptr" ) @@ -59,20 +61,41 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { // Setup proxyClassAllOpts := &tsapi.ProxyClass{ Spec: tsapi.ProxyClassSpec{ + UseLetsEncryptStagingEnvironment: true, StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"foo.io/bar": "foo"}, Pod: &tsapi.Pod{ - Labels: map[string]string{"bar": "foo"}, + Labels: tsapi.Labels{"bar": "foo"}, Annotations: map[string]string{"bar.io/foo": "foo"}, SecurityContext: &corev1.PodSecurityContext{ RunAsUser: ptr.To(int64(0)), }, - ImagePullSecrets: []corev1.LocalObjectReference{{Name: "docker-creds"}}, - NodeName: "some-node", - NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"}, - Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}}, - Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}}, + ImagePullSecrets: []corev1.LocalObjectReference{{Name: "docker-creds"}}, + NodeName: "some-node", + NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"}, + Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}}, + Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}}, + PriorityClassName: "high-priority", + TopologySpreadConstraints: []corev1.TopologySpreadConstraint{ + { + WhenUnsatisfiable: "DoNotSchedule", + TopologyKey: "kubernetes.io/hostname", + MaxSkew: 3, + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{"foo": "bar"}, + }, + }, + }, + DNSPolicy: ptr.To(corev1.DNSClusterFirstWithHostNet), + DNSConfig: &corev1.PodDNSConfig{ + Nameservers: []string{"1.1.1.1", "8.8.8.8"}, + Searches: []string{"example.com", "test.local"}, + Options: []corev1.PodDNSConfigOption{ + {Name: "ndots", Value: ptr.To("2")}, + {Name: "edns0"}, + }, + }, TailscaleContainer: &tsapi.Container{ SecurityContext: &corev1.SecurityContext{ Privileged: ptr.To(true), @@ -105,21 +128,36 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { proxyClassJustLabels := &tsapi.ProxyClass{ Spec: tsapi.ProxyClassSpec{ StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"foo.io/bar": "foo"}, Pod: &tsapi.Pod{ - Labels: map[string]string{"bar": "foo"}, + Labels: tsapi.Labels{"bar": "foo"}, Annotations: map[string]string{"bar.io/foo": "foo"}, }, }, }, } - proxyClassMetrics := &tsapi.ProxyClass{ - Spec: tsapi.ProxyClassSpec{ - Metrics: &tsapi.Metrics{Enable: true}, - }, - } + proxyClassWithMetricsDebug := func(metrics bool, debug *bool) *tsapi.ProxyClass { + return &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + Metrics: &tsapi.Metrics{Enable: metrics}, + StatefulSet: func() *tsapi.StatefulSet { + if debug == nil { + return nil + } + + return &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Debug: &tsapi.Debug{Enable: *debug}, + }, + }, + } + }(), + }, + } + } var userspaceProxySS, nonUserspaceProxySS appsv1.StatefulSet if err := yaml.Unmarshal(userspaceProxyYaml, &userspaceProxySS); err != nil { t.Fatalf("unmarshaling userspace proxy template: %v", err) @@ -130,8 +168,8 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { // Set a couple additional fields so we can test that we don't // mistakenly override those. labels := map[string]string{ - LabelManaged: "true", - LabelParentName: "foo", + kubetypes.LabelManaged: "true", + LabelParentName: "foo", } annots := map[string]string{ podAnnotationLastSetClusterIP: "1.2.3.4", @@ -149,9 +187,9 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { // 1. Test that a ProxyClass with all fields set gets correctly applied // to a Statefulset built from non-userspace proxy template. wantSS := nonUserspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassAllOpts.Spec.StatefulSet.Pod.Annotations wantSS.Spec.Template.Spec.SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.SecurityContext wantSS.Spec.Template.Spec.ImagePullSecrets = proxyClassAllOpts.Spec.StatefulSet.Pod.ImagePullSecrets @@ -159,6 +197,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations + wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext wantSS.Spec.Template.Spec.InitContainers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleInitContainer.SecurityContext wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources @@ -169,31 +208,34 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.Containers[0].ImagePullPolicy = "IfNotPresent" wantSS.Spec.Template.Spec.InitContainers[0].Image = "ghcr.io/my-repo/tailscale:v0.01testsomething" wantSS.Spec.Template.Spec.InitContainers[0].ImagePullPolicy = "IfNotPresent" + wantSS.Spec.Template.Spec.PriorityClassName = proxyClassAllOpts.Spec.StatefulSet.Pod.PriorityClassName + wantSS.Spec.Template.Spec.DNSPolicy = corev1.DNSClusterFirstWithHostNet + wantSS.Spec.Template.Spec.DNSConfig = proxyClassAllOpts.Spec.StatefulSet.Pod.DNSConfig gotSS := applyProxyClassToStatefulSet(proxyClassAllOpts, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with all fields set to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with all fields set to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) } // 2. Test that a ProxyClass with custom labels and annotations for // StatefulSet and Pod set gets correctly applied to a Statefulset built // from non-userspace proxy template. wantSS = nonUserspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassJustLabels.Spec.StatefulSet.Pod.Annotations gotSS = applyProxyClassToStatefulSet(proxyClassJustLabels, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) } // 3. Test that a ProxyClass with all fields set gets correctly applied // to a Statefulset built from a userspace proxy template. wantSS = userspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassAllOpts.Spec.StatefulSet.Pod.Annotations wantSS.Spec.Template.Spec.SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.SecurityContext wantSS.Spec.Template.Spec.ImagePullSecrets = proxyClassAllOpts.Spec.StatefulSet.Pod.ImagePullSecrets @@ -201,43 +243,76 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations + wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, []corev1.EnvVar{{Name: "foo", Value: "bar"}, {Name: "TS_USERSPACE", Value: "true"}, {Name: "bar"}}...) wantSS.Spec.Template.Spec.Containers[0].ImagePullPolicy = "IfNotPresent" wantSS.Spec.Template.Spec.Containers[0].Image = "ghcr.io/my-repo/tailscale:v0.01testsomething" + wantSS.Spec.Template.Spec.PriorityClassName = proxyClassAllOpts.Spec.StatefulSet.Pod.PriorityClassName + wantSS.Spec.Template.Spec.DNSPolicy = corev1.DNSClusterFirstWithHostNet + wantSS.Spec.Template.Spec.DNSConfig = proxyClassAllOpts.Spec.StatefulSet.Pod.DNSConfig gotSS = applyProxyClassToStatefulSet(proxyClassAllOpts, userspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with all options to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with all options to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) } // 4. Test that a ProxyClass with custom labels and annotations gets correctly applied // to a Statefulset built from a userspace proxy template. wantSS = userspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassJustLabels.Spec.StatefulSet.Pod.Annotations gotSS = applyProxyClassToStatefulSet(proxyClassJustLabels, userspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) } - // 5. Test that a ProxyClass with metrics enabled gets correctly applied to a StatefulSet. + // 5. Metrics enabled defaults to enabling both metrics and debug. wantSS = nonUserspaceProxySS.DeepCopy() - wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(POD_IP):9001"}) - wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9001, HostPort: 9001}} - gotSS = applyProxyClassToStatefulSet(proxyClassMetrics, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, + corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, + corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, + corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{ + {Name: "debug", Protocol: "TCP", ContainerPort: 9001}, + {Name: "metrics", Protocol: "TCP", ContainerPort: 9002}, + } + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, nil), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } -} -func mergeMapKeys(a, b map[string]string) map[string]string { - for key, val := range b { - a[key] = val + // 6. Enable _just_ metrics by explicitly disabling debug. + wantSS = nonUserspaceProxySS.DeepCopy() + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, + corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9002}} + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, ptr.To(false)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + if diff := cmp.Diff(gotSS, wantSS); diff != "" { + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) + } + + // 7. Enable _just_ debug without metrics. + wantSS = nonUserspaceProxySS.DeepCopy() + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, + corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "debug", Protocol: "TCP", ContainerPort: 9001}} + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(false, ptr.To(true)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + if diff := cmp.Diff(gotSS, wantSS); diff != "" { + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } - return a + + // 8. A Kubernetes API proxy with letsencrypt staging enabled + gotSS = applyProxyClassToStatefulSet(proxyClassAllOpts, nonUserspaceProxySS.DeepCopy(), &tailscaleSTSConfig{proxyType: string(tsapi.ProxyGroupTypeKubernetesAPIServer)}, zl.Sugar()) + verifyEnvVar(t, gotSS, "TS_DEBUG_ACME_DIRECTORY_URL", letsEncryptStagingEndpoint) } func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { @@ -250,28 +325,28 @@ func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { }{ { name: "no custom labels specified and none present in current labels, return current labels", - current: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, - want: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { name: "no custom labels specified, but some present in current labels, return tailscale managed labels only from the current labels", - current: map[string]string{"foo": "bar", "something.io/foo": "bar", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, - want: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{"foo": "bar", "something.io/foo": "bar", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { name: "custom labels specified, current labels only contain tailscale managed labels, return a union of both", - current: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, - want: map[string]string{"foo": "bar", "something.io/foo": "bar", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{"foo": "bar", "something.io/foo": "bar", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { name: "custom labels specified, current labels contain tailscale managed labels and custom labels, some of which re not present in the new custom labels, return a union of managed labels and the desired custom labels", - current: map[string]string{"foo": "bar", "bar": "baz", "app": "1234", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{"foo": "bar", "bar": "baz", "app": "1234", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, - want: map[string]string{"foo": "bar", "something.io/foo": "bar", "app": "1234", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{"foo": "bar", "something.io/foo": "bar", "app": "1234", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { @@ -331,3 +406,10 @@ func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { }) } } + +// updateMap updates map a with the values from map b. +func updateMap(a, b map[string]string) { + for key, val := range b { + a[key] = val + } +} diff --git a/cmd/k8s-operator/svc-for-pg.go b/cmd/k8s-operator/svc-for-pg.go new file mode 100644 index 000000000..144d37558 --- /dev/null +++ b/cmd/k8s-operator/svc-for-pg.go @@ -0,0 +1,825 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + "reflect" + "slices" + "strings" + "sync" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + svcPGFinalizerName = "tailscale.com/service-pg-finalizer" + + reasonIngressSvcInvalid = "IngressSvcInvalid" + reasonIngressSvcValid = "IngressSvcValid" + reasonIngressSvcConfigured = "IngressSvcConfigured" + reasonIngressSvcNoBackendsConfigured = "IngressSvcNoBackendsConfigured" + reasonIngressSvcCreationFailed = "IngressSvcCreationFailed" +) + +var gaugePGServiceResources = clientmetric.NewGauge(kubetypes.MetricServicePGResourceCount) + +// HAServiceReconciler is a controller that reconciles Tailscale Kubernetes +// Services that should be exposed on an ingress ProxyGroup (in HA mode). +type HAServiceReconciler struct { + client.Client + isDefaultLoadBalancer bool + recorder record.EventRecorder + logger *zap.SugaredLogger + tsClient tsClient + tsNamespace string + lc localClient + defaultTags []string + operatorID string // stableID of the operator's Tailscale device + + clock tstime.Clock + + mu sync.Mutex // protects following + // managedServices is a set of all Service resources that we're currently + // managing. This is only used for metrics. + managedServices set.Slice[types.UID] +} + +// Reconcile reconciles Services that should be exposed over Tailscale in HA +// mode (on a ProxyGroup). It looks at all Services with +// tailscale.com/proxy-group annotation. For each such Service, it ensures that +// a Tailscale Service named after the hostname of the Service exists and is up to +// date. +// HA Servicees support multi-cluster Service setup. +// Each Tailscale Service contains a list of owner references that uniquely identify +// the operator. When an Service that acts as a +// backend is being deleted, the corresponding Tailscale Service is only deleted if the +// only owner reference that it contains is for this operator. If other owner +// references are found, then cleanup operation only removes this operator's owner +// reference. +func (r *HAServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + logger := r.logger.With("Service", req.NamespacedName) + logger.Debugf("starting reconcile") + defer logger.Debugf("reconcile finished") + + svc := new(corev1.Service) + err = r.Get(ctx, req.NamespacedName, svc) + if apierrors.IsNotFound(err) { + // Request object not found, could have been deleted after reconcile request. + logger.Debugf("Service not found, assuming it was deleted") + return res, nil + } else if err != nil { + return res, fmt.Errorf("failed to get Service: %w", err) + } + + hostname := nameForService(svc) + logger = logger.With("hostname", hostname) + + if !svc.DeletionTimestamp.IsZero() || !r.isTailscaleService(svc) { + logger.Debugf("Service is being deleted or is (no longer) referring to Tailscale ingress/egress, ensuring any created resources are cleaned up") + _, err = r.maybeCleanup(ctx, hostname, svc, logger) + return res, err + } + + // needsRequeue is set to true if the underlying Tailscale Service has changed as a result of this reconcile. If that + // is the case, we reconcile the Ingress one more time to ensure that concurrent updates to the Tailscale Service in a + // multi-cluster Ingress setup have not resulted in another actor overwriting our Tailscale Service update. + needsRequeue := false + needsRequeue, err = r.maybeProvision(ctx, hostname, svc, logger) + if err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + if needsRequeue { + res = reconcile.Result{RequeueAfter: requeueInterval()} + } + + return reconcile.Result{}, nil +} + +// maybeProvision ensures that a Tailscale Service for this Ingress exists and is up to date and that the serve config for the +// corresponding ProxyGroup contains the Ingress backend's definition. +// If a Tailscale Service does not exist, it will be created. +// If a Tailscale Service exists, but only with owner references from other operator instances, an owner reference for this +// operator instance is added. +// If a Tailscale Service exists, but does not have an owner reference from any operator, we error +// out assuming that this is an owner reference created by an unknown actor. +// Returns true if the operation resulted in a Tailscale Service update. +func (r *HAServiceReconciler) maybeProvision(ctx context.Context, hostname string, svc *corev1.Service, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + oldSvcStatus := svc.Status.DeepCopy() + defer func() { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { + // An error encountered here should get returned by the Reconcile function. + err = errors.Join(err, r.Client.Status().Update(ctx, svc)) + } + }() + + pgName := svc.Annotations[AnnotationProxyGroup] + if pgName == "" { + logger.Infof("[unexpected] no ProxyGroup annotation, skipping Tailscale Service provisioning") + return false, nil + } + + logger = logger.With("ProxyGroup", pgName) + + pg := &tsapi.ProxyGroup{} + if err := r.Get(ctx, client.ObjectKey{Name: pgName}, pg); err != nil { + if apierrors.IsNotFound(err) { + msg := fmt.Sprintf("ProxyGroup %q does not exist", pgName) + logger.Warnf(msg) + r.recorder.Event(svc, corev1.EventTypeWarning, "ProxyGroupNotFound", msg) + return false, nil + } + return false, fmt.Errorf("getting ProxyGroup %q: %w", pgName, err) + } + if !tsoperator.ProxyGroupAvailable(pg) { + logger.Infof("ProxyGroup is not (yet) ready") + return false, nil + } + + if err := r.validateService(ctx, svc, pg); err != nil { + r.recorder.Event(svc, corev1.EventTypeWarning, reasonIngressSvcInvalid, err.Error()) + tsoperator.SetServiceCondition(svc, tsapi.IngressSvcValid, metav1.ConditionFalse, reasonIngressSvcInvalid, err.Error(), r.clock, logger) + return false, nil + } + + if !slices.Contains(svc.Finalizers, svcPGFinalizerName) { + // This log line is printed exactly once during initial provisioning, + // because once the finalizer is in place this block gets skipped. So, + // this is a nice place to tell the operator that the high level, + // multi-reconcile operation is underway. + logger.Infof("exposing Service over tailscale") + svc.Finalizers = append(svc.Finalizers, svcPGFinalizerName) + if err := r.Update(ctx, svc); err != nil { + return false, fmt.Errorf("failed to add finalizer: %w", err) + } + r.mu.Lock() + r.managedServices.Add(svc.UID) + gaugePGServiceResources.Set(int64(r.managedServices.Len())) + r.mu.Unlock() + } + + // 1. Ensure that if Service's hostname/name has changed, any Tailscale Service + // resources corresponding to the old hostname are cleaned up. + // In practice, this function will ensure that any Tailscale Services that are + // associated with the provided ProxyGroup and no longer owned by a + // Service are cleaned up. This is fine- it is not expensive and ensures + // that in edge cases (a single update changed both hostname and removed + // ProxyGroup annotation) the Tailscale Service is more likely to be + // (eventually) removed. + svcsChanged, err = r.maybeCleanupProxyGroup(ctx, pgName, logger) + if err != nil { + return false, fmt.Errorf("failed to cleanup Tailscale Service resources for ProxyGroup: %w", err) + } + + // 2. Ensure that there isn't a Tailscale Service with the same hostname + // already created and not owned by this Service. + serviceName := tailcfg.ServiceName("svc:" + hostname) + existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) + if err != nil && !isErrorTailscaleServiceNotFound(err) { + return false, fmt.Errorf("error getting Tailscale Service %q: %w", hostname, err) + } + + // 3. Generate the Tailscale Service owner annotation for new or existing Tailscale Service. + // This checks and ensures that Tailscale Service's owner references are updated + // for this Service and errors if that is not possible (i.e. because it + // appears that the Tailscale Service has been created by a non-operator actor). + updatedAnnotations, err := ownerAnnotations(r.operatorID, existingTSSvc) + if err != nil { + instr := fmt.Sprintf("To proceed, you can either manually delete the existing Tailscale Service or choose a different hostname with the '%s' annotaion", AnnotationHostname) + msg := fmt.Sprintf("error ensuring ownership of Tailscale Service %s: %v. %s", hostname, err, instr) + logger.Warn(msg) + r.recorder.Event(svc, corev1.EventTypeWarning, "InvalidTailscaleService", msg) + tsoperator.SetServiceCondition(svc, tsapi.IngressSvcValid, metav1.ConditionFalse, reasonIngressSvcInvalid, msg, r.clock, logger) + return false, nil + } + + tags := r.defaultTags + if tstr, ok := svc.Annotations[AnnotationTags]; ok && tstr != "" { + tags = strings.Split(tstr, ",") + } + + tsSvc := &tailscale.VIPService{ + Name: serviceName, + Tags: tags, + Ports: []string{"do-not-validate"}, // we don't want to validate ports + Comment: managedTSServiceComment, + Annotations: updatedAnnotations, + } + if existingTSSvc != nil { + tsSvc.Addrs = existingTSSvc.Addrs + } + + // TODO(irbekrm): right now if two Service resources attempt to apply different Tailscale Service configs (different + // tags) we can end up reconciling those in a loop. We should detect when a Service + // with the same generation number has been reconciled ~more than N times and stop attempting to apply updates. + if existingTSSvc == nil || + !reflect.DeepEqual(tsSvc.Tags, existingTSSvc.Tags) || + !ownersAreSetAndEqual(tsSvc, existingTSSvc) { + logger.Infof("Ensuring Tailscale Service exists and is up to date") + if err := r.tsClient.CreateOrUpdateVIPService(ctx, tsSvc); err != nil { + return false, fmt.Errorf("error creating Tailscale Service: %w", err) + } + existingTSSvc = tsSvc + } + + cm, cfgs, err := ingressSvcsConfigs(ctx, r.Client, pgName, r.tsNamespace) + if err != nil { + return false, fmt.Errorf("error retrieving ingress services configuration: %w", err) + } + if cm == nil { + logger.Info("ConfigMap not yet created, waiting..") + return false, nil + } + + if existingTSSvc.Addrs == nil { + existingTSSvc, err = r.tsClient.GetVIPService(ctx, tsSvc.Name) + if err != nil { + return false, fmt.Errorf("error getting Tailscale Service: %w", err) + } + if existingTSSvc.Addrs == nil { + // TODO(irbekrm): this should be a retry + return false, fmt.Errorf("unexpected: Tailscale Service addresses not populated") + } + } + + var tsSvcIPv4 netip.Addr + var tsSvcIPv6 netip.Addr + for _, tsip := range existingTSSvc.Addrs { + ip, err := netip.ParseAddr(tsip) + if err != nil { + return false, fmt.Errorf("error parsing Tailscale Service address: %w", err) + } + + if ip.Is4() { + tsSvcIPv4 = ip + } else if ip.Is6() { + tsSvcIPv6 = ip + } + } + + cfg := ingressservices.Config{} + for _, cip := range svc.Spec.ClusterIPs { + ip, err := netip.ParseAddr(cip) + if err != nil { + return false, fmt.Errorf("error parsing Kubernetes Service address: %w", err) + } + + if ip.Is4() { + cfg.IPv4Mapping = &ingressservices.Mapping{ + ClusterIP: ip, + TailscaleServiceIP: tsSvcIPv4, + } + } else if ip.Is6() { + cfg.IPv6Mapping = &ingressservices.Mapping{ + ClusterIP: ip, + TailscaleServiceIP: tsSvcIPv6, + } + } + } + + existingCfg := cfgs[serviceName.String()] + if !reflect.DeepEqual(existingCfg, cfg) { + mak.Set(&cfgs, serviceName.String(), cfg) + cfgBytes, err := json.Marshal(cfgs) + if err != nil { + return false, fmt.Errorf("error marshaling ingress config: %w", err) + } + mak.Set(&cm.BinaryData, ingressservices.IngressConfigKey, cfgBytes) + if err := r.Update(ctx, cm); err != nil { + return false, fmt.Errorf("error updating ingress config: %w", err) + } + } + + logger.Infof("updating AdvertiseServices config") + // 4. Update tailscaled's AdvertiseServices config, which should add the Tailscale Service + // IPs to the ProxyGroup Pods' AllowedIPs in the next netmap update if approved. + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, svc, pg.Name, serviceName, &cfg, true, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config: %w", err) + } + + count, err := r.numberPodsAdvertising(ctx, pgName, serviceName) + if err != nil { + return false, fmt.Errorf("failed to get number of advertised Pods: %w", err) + } + + // TODO(irbekrm): here and when creating the Tailscale Service, verify if the + // error is not terminal (and therefore should not be reconciled). For + // example, if the hostname is already a hostname of a Tailscale node, + // the GET here will fail. + // If there are no Pods advertising the Tailscale Service (yet), we want to set 'svc.Status.LoadBalancer.Ingress' to nil" + var lbs []corev1.LoadBalancerIngress + conditionStatus := metav1.ConditionFalse + conditionType := tsapi.IngressSvcConfigured + conditionReason := reasonIngressSvcNoBackendsConfigured + conditionMessage := fmt.Sprintf("%d/%d proxy backends ready and advertising", count, pgReplicas(pg)) + if count != 0 { + dnsName, err := r.dnsNameForService(ctx, serviceName) + if err != nil { + return false, fmt.Errorf("error getting DNS name for Service: %w", err) + } + + lbs = []corev1.LoadBalancerIngress{ + { + Hostname: dnsName, + IP: tsSvcIPv4.String(), + }, + } + + conditionStatus = metav1.ConditionTrue + conditionReason = reasonIngressSvcConfigured + } + + tsoperator.SetServiceCondition(svc, conditionType, conditionStatus, conditionReason, conditionMessage, r.clock, logger) + svc.Status.LoadBalancer.Ingress = lbs + + return svcsChanged, nil +} + +// maybeCleanup ensures that any resources, such as a Tailscale Service created for this Service, are cleaned up when the +// Service is being deleted or is unexposed. The cleanup is safe for a multi-cluster setup- the Tailscale Service is only +// deleted if it does not contain any other owner references. If it does the cleanup only removes the owner reference +// corresponding to this Service. +func (r *HAServiceReconciler) maybeCleanup(ctx context.Context, hostname string, svc *corev1.Service, logger *zap.SugaredLogger) (svcChanged bool, err error) { + logger.Debugf("Ensuring any resources for Service are cleaned up") + ix := slices.Index(svc.Finalizers, svcPGFinalizerName) + if ix < 0 { + logger.Debugf("no finalizer, nothing to do") + return false, nil + } + logger.Infof("Ensuring that Tailscale Service %q configuration is cleaned up", hostname) + + defer func() { + if err != nil { + return + } + err = r.deleteFinalizer(ctx, svc, logger) + }() + + serviceName := tailcfg.ServiceName("svc:" + hostname) + // 1. Clean up the Tailscale Service. + svcChanged, err = cleanupTailscaleService(ctx, r.tsClient, serviceName, r.operatorID, logger) + if err != nil { + return false, fmt.Errorf("error deleting Tailscale Service: %w", err) + } + + // 2. Unadvertise the Tailscale Service. + pgName := svc.Annotations[AnnotationProxyGroup] + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, svc, pgName, serviceName, nil, false, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + + // TODO: maybe wait for the service to be unadvertised, only then remove the backend routing + + // 3. Clean up ingress config (routing rules). + cm, cfgs, err := ingressSvcsConfigs(ctx, r.Client, pgName, r.tsNamespace) + if err != nil { + return false, fmt.Errorf("error retrieving ingress services configuration: %w", err) + } + if cm == nil || cfgs == nil { + return true, nil + } + logger.Infof("Removing Tailscale Service %q from ingress config for ProxyGroup %q", hostname, pgName) + delete(cfgs, serviceName.String()) + cfgBytes, err := json.Marshal(cfgs) + if err != nil { + return false, fmt.Errorf("error marshaling ingress config: %w", err) + } + mak.Set(&cm.BinaryData, ingressservices.IngressConfigKey, cfgBytes) + return true, r.Update(ctx, cm) +} + +// Tailscale Services that are associated with the provided ProxyGroup and no longer managed this operator's instance are deleted, if not owned by other operator instances, else the owner reference is cleaned up. +// Returns true if the operation resulted in existing Tailscale Service updates (owner reference removal). +func (r *HAServiceReconciler) maybeCleanupProxyGroup(ctx context.Context, proxyGroupName string, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + cm, config, err := ingressSvcsConfigs(ctx, r.Client, proxyGroupName, r.tsNamespace) + if err != nil { + return false, fmt.Errorf("failed to get ingress service config: %s", err) + } + + svcList := &corev1.ServiceList{} + if err := r.Client.List(ctx, svcList, client.MatchingFields{indexIngressProxyGroup: proxyGroupName}); err != nil { + return false, fmt.Errorf("failed to find Services for ProxyGroup %q: %w", proxyGroupName, err) + } + + ingressConfigChanged := false + for tsSvcName, cfg := range config { + found := false + for _, svc := range svcList.Items { + if strings.EqualFold(fmt.Sprintf("svc:%s", nameForService(&svc)), tsSvcName) { + found = true + break + } + } + if !found { + logger.Infof("Tailscale Service %q is not owned by any Service, cleaning up", tsSvcName) + + // Make sure the Tailscale Service is not advertised in tailscaled or serve config. + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, nil, proxyGroupName, tailcfg.ServiceName(tsSvcName), &cfg, false, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + + svcsChanged, err = cleanupTailscaleService(ctx, r.tsClient, tailcfg.ServiceName(tsSvcName), r.operatorID, logger) + if err != nil { + return false, fmt.Errorf("deleting Tailscale Service %q: %w", tsSvcName, err) + } + + _, ok := config[tsSvcName] + if ok { + logger.Infof("Removing Tailscale Service %q from serve config", tsSvcName) + delete(config, tsSvcName) + ingressConfigChanged = true + } + } + } + + if ingressConfigChanged { + configBytes, err := json.Marshal(config) + if err != nil { + return false, fmt.Errorf("marshaling serve config: %w", err) + } + mak.Set(&cm.BinaryData, ingressservices.IngressConfigKey, configBytes) + if err := r.Update(ctx, cm); err != nil { + return false, fmt.Errorf("updating serve config: %w", err) + } + } + + return svcsChanged, nil +} + +func (r *HAServiceReconciler) deleteFinalizer(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) error { + svc.Finalizers = slices.DeleteFunc(svc.Finalizers, func(f string) bool { + return f == svcPGFinalizerName + }) + logger.Debugf("ensure %q finalizer is removed", svcPGFinalizerName) + + if err := r.Update(ctx, svc); err != nil { + return fmt.Errorf("failed to remove finalizer %q: %w", svcPGFinalizerName, err) + } + r.mu.Lock() + defer r.mu.Unlock() + r.managedServices.Remove(svc.UID) + gaugePGServiceResources.Set(int64(r.managedServices.Len())) + return nil +} + +func (r *HAServiceReconciler) isTailscaleService(svc *corev1.Service) bool { + proxyGroup := svc.Annotations[AnnotationProxyGroup] + return r.shouldExpose(svc) && proxyGroup != "" +} + +func (r *HAServiceReconciler) shouldExpose(svc *corev1.Service) bool { + return r.shouldExposeClusterIP(svc) +} + +func (r *HAServiceReconciler) shouldExposeClusterIP(svc *corev1.Service) bool { + if svc.Spec.ClusterIP == "" || svc.Spec.ClusterIP == "None" { + return false + } + return isTailscaleLoadBalancerService(svc, r.isDefaultLoadBalancer) || hasExposeAnnotation(svc) +} + +// tailnetCertDomain returns the base domain (TCD) of the current tailnet. +func (r *HAServiceReconciler) tailnetCertDomain(ctx context.Context) (string, error) { + st, err := r.lc.StatusWithoutPeers(ctx) + if err != nil { + return "", fmt.Errorf("error getting tailscale status: %w", err) + } + return st.CurrentTailnet.MagicDNSSuffix, nil +} + +// cleanupTailscaleService deletes any Tailscale Service by the provided name if it is not owned by operator instances other than this one. +// If a Tailscale Service is found, but contains other owner references, only removes this operator's owner reference. +// If a Tailscale Service by the given name is not found or does not contain this operator's owner reference, do nothing. +// It returns true if an existing Tailscale Service was updated to remove owner reference, as well as any error that occurred. +func cleanupTailscaleService(ctx context.Context, tsClient tsClient, name tailcfg.ServiceName, operatorID string, logger *zap.SugaredLogger) (updated bool, err error) { + svc, err := tsClient.GetVIPService(ctx, name) + if err != nil { + errResp := &tailscale.ErrResponse{} + ok := errors.As(err, errResp) + if ok && errResp.Status == http.StatusNotFound { + return false, nil + } + if !ok { + return false, fmt.Errorf("unexpected error getting Tailscale Service %q: %w", name.String(), err) + } + + return false, fmt.Errorf("error getting Tailscale Service: %w", err) + } + if svc == nil { + return false, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return false, fmt.Errorf("error parsing Tailscale Service owner annotation: %w", err) + } + if o == nil || len(o.OwnerRefs) == 0 { + return false, nil + } + // Comparing with the operatorID only means that we will not be able to + // clean up Tailscale Services in cases where the operator was deleted from the + // cluster before deleting the Ingress. Perhaps the comparison could be + // 'if or.OperatorID == r.operatorID || or.ingressUID == r.ingressUID'. + ix := slices.IndexFunc(o.OwnerRefs, func(or OwnerRef) bool { + return or.OperatorID == operatorID + }) + if ix == -1 { + return false, nil + } + if len(o.OwnerRefs) == 1 { + logger.Infof("Deleting Tailscale Service %q", name) + return false, tsClient.DeleteVIPService(ctx, name) + } + o.OwnerRefs = slices.Delete(o.OwnerRefs, ix, ix+1) + logger.Infof("Updating Tailscale Service %q", name) + json, err := json.Marshal(o) + if err != nil { + return false, fmt.Errorf("error marshalling updated Tailscale Service owner reference: %w", err) + } + svc.Annotations[ownerAnnotation] = string(json) + return true, tsClient.CreateOrUpdateVIPService(ctx, svc) +} + +func (a *HAServiceReconciler) backendRoutesSetup(ctx context.Context, serviceName, replicaName, pgName string, wantsCfg *ingressservices.Config, logger *zap.SugaredLogger) (bool, error) { + logger.Debugf("checking backend routes for service '%s'", serviceName) + pod := &corev1.Pod{} + err := a.Get(ctx, client.ObjectKey{Namespace: a.tsNamespace, Name: replicaName}, pod) + if apierrors.IsNotFound(err) { + logger.Debugf("Pod %q not found", replicaName) + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to get Pod: %w", err) + } + secret := &corev1.Secret{} + err = a.Get(ctx, client.ObjectKey{Namespace: a.tsNamespace, Name: replicaName}, secret) + if apierrors.IsNotFound(err) { + logger.Debugf("Secret %q not found", replicaName) + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to get Secret: %w", err) + } + if len(secret.Data) == 0 || secret.Data[ingressservices.IngressConfigKey] == nil { + return false, nil + } + gotCfgB := secret.Data[ingressservices.IngressConfigKey] + var gotCfgs ingressservices.Status + if err := json.Unmarshal(gotCfgB, &gotCfgs); err != nil { + return false, fmt.Errorf("error unmarshalling ingress config: %w", err) + } + statusUpToDate, err := isCurrentStatus(gotCfgs, pod, logger) + if err != nil { + return false, fmt.Errorf("error checking ingress config status: %w", err) + } + if !statusUpToDate || !reflect.DeepEqual(gotCfgs.Configs.GetConfig(serviceName), wantsCfg) { + logger.Debugf("Pod %q is not ready to advertise Tailscale Service", pod.Name) + return false, nil + } + return true, nil +} + +func isCurrentStatus(gotCfgs ingressservices.Status, pod *corev1.Pod, logger *zap.SugaredLogger) (bool, error) { + ips := pod.Status.PodIPs + if len(ips) == 0 { + logger.Debugf("Pod %q does not yet have IPs, unable to determine if status is up to date", pod.Name) + return false, nil + } + + if len(ips) > 2 { + return false, fmt.Errorf("pod 'status.PodIPs' can contain at most 2 IPs, got %d (%v)", len(ips), ips) + } + var podIPv4, podIPv6 string + for _, ip := range ips { + parsed, err := netip.ParseAddr(ip.IP) + if err != nil { + return false, fmt.Errorf("error parsing IP address %s: %w", ip.IP, err) + } + if parsed.Is4() { + podIPv4 = parsed.String() + continue + } + podIPv6 = parsed.String() + } + if podIPv4 != gotCfgs.PodIPv4 || podIPv6 != gotCfgs.PodIPv6 { + return false, nil + } + return true, nil +} + +func (a *HAServiceReconciler) maybeUpdateAdvertiseServicesConfig(ctx context.Context, svc *corev1.Service, pgName string, serviceName tailcfg.ServiceName, cfg *ingressservices.Config, shouldBeAdvertised bool, logger *zap.SugaredLogger) (err error) { + logger.Debugf("checking advertisement for service '%s'", serviceName) + // Get all config Secrets for this ProxyGroup. + // Get all Pods + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, kubetypes.LabelSecretTypeConfig))); err != nil { + return fmt.Errorf("failed to list config Secrets: %w", err) + } + + if svc != nil && shouldBeAdvertised { + shouldBeAdvertised, err = a.checkEndpointsReady(ctx, svc, logger) + if err != nil { + return fmt.Errorf("failed to check readiness of Service '%s' endpoints: %w", svc.Name, err) + } + } + + for _, secret := range secrets.Items { + var updated bool + for fileName, confB := range secret.Data { + var conf ipn.ConfigVAlpha + if err := json.Unmarshal(confB, &conf); err != nil { + return fmt.Errorf("error unmarshalling ProxyGroup config: %w", err) + } + + idx := slices.Index(conf.AdvertiseServices, serviceName.String()) + isAdvertised := idx >= 0 + switch { + case !isAdvertised && !shouldBeAdvertised: + logger.Debugf("service %q shouldn't be advertised", serviceName) + continue + case isAdvertised && shouldBeAdvertised: + logger.Debugf("service %q is already advertised", serviceName) + continue + case isAdvertised && !shouldBeAdvertised: + logger.Debugf("deleting advertisement for service %q", serviceName) + conf.AdvertiseServices = slices.Delete(conf.AdvertiseServices, idx, idx+1) + case shouldBeAdvertised: + replicaName, ok := strings.CutSuffix(secret.Name, "-config") + if !ok { + logger.Infof("[unexpected] unable to determine replica name from config Secret name %q, unable to determine if backend routing has been configured", secret.Name) + return nil + } + ready, err := a.backendRoutesSetup(ctx, serviceName.String(), replicaName, pgName, cfg, logger) + if err != nil { + return fmt.Errorf("error checking backend routes: %w", err) + } + if !ready { + logger.Debugf("service %q is not ready to be advertised", serviceName) + continue + } + + conf.AdvertiseServices = append(conf.AdvertiseServices, serviceName.String()) + } + confB, err := json.Marshal(conf) + if err != nil { + return fmt.Errorf("error marshalling ProxyGroup config: %w", err) + } + mak.Set(&secret.Data, fileName, confB) + updated = true + } + if updated { + if err := a.Update(ctx, &secret); err != nil { + return fmt.Errorf("error updating ProxyGroup config Secret: %w", err) + } + } + } + return nil +} + +func (a *HAServiceReconciler) numberPodsAdvertising(ctx context.Context, pgName string, serviceName tailcfg.ServiceName) (int, error) { + // Get all state Secrets for this ProxyGroup. + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, kubetypes.LabelSecretTypeState))); err != nil { + return 0, fmt.Errorf("failed to list ProxyGroup %q state Secrets: %w", pgName, err) + } + + var count int + for _, secret := range secrets.Items { + prefs, ok, err := getDevicePrefs(&secret) + if err != nil { + return 0, fmt.Errorf("error getting node metadata: %w", err) + } + if !ok { + continue + } + if slices.Contains(prefs.AdvertiseServices, serviceName.String()) { + count++ + } + } + + return count, nil +} + +// dnsNameForService returns the DNS name for the given Tailscale Service name. +func (r *HAServiceReconciler) dnsNameForService(ctx context.Context, svc tailcfg.ServiceName) (string, error) { + s := svc.WithoutPrefix() + tcd, err := r.tailnetCertDomain(ctx) + if err != nil { + return "", fmt.Errorf("error determining DNS name base: %w", err) + } + return s + "." + tcd, nil +} + +// ingressSvcsConfig returns a ConfigMap that contains ingress services configuration for the provided ProxyGroup as well +// as unmarshalled configuration from the ConfigMap. +func ingressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs ingressservices.Configs, err error) { + name := pgIngressCMName(proxyGroupName) + cm = &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: tsNamespace, + }, + } + err = cl.Get(ctx, client.ObjectKeyFromObject(cm), cm) + if apierrors.IsNotFound(err) { // ProxyGroup resources have not been created (yet) + return nil, nil, nil + } + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ingress services ConfigMap %s: %v", name, err) + } + cfgs = ingressservices.Configs{} + if len(cm.BinaryData[ingressservices.IngressConfigKey]) != 0 { + if err := json.Unmarshal(cm.BinaryData[ingressservices.IngressConfigKey], &cfgs); err != nil { + return nil, nil, fmt.Errorf("error unmarshaling ingress services config %v: %w", cm.BinaryData[ingressservices.IngressConfigKey], err) + } + } + return cm, cfgs, nil +} + +func (r *HAServiceReconciler) getEndpointSlicesForService(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) ([]discoveryv1.EndpointSlice, error) { + logger.Debugf("looking for endpoint slices for svc with name '%s' in namespace '%s' matching label '%s=%s'", svc.Name, svc.Namespace, discoveryv1.LabelServiceName, svc.Name) + // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ownership + labels := map[string]string{discoveryv1.LabelServiceName: svc.Name} + eps := new(discoveryv1.EndpointSliceList) + if err := r.List(ctx, eps, client.InNamespace(svc.Namespace), client.MatchingLabels(labels)); err != nil { + return nil, fmt.Errorf("error listing EndpointSlices: %w", err) + } + + if len(eps.Items) == 0 { + logger.Debugf("Service '%s' EndpointSlice does not yet exist. We will reconcile again once it's created", svc.Name) + return nil, nil + } + + return eps.Items, nil +} + +func (r *HAServiceReconciler) checkEndpointsReady(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) (bool, error) { + epss, err := r.getEndpointSlicesForService(ctx, svc, logger) + if err != nil { + return false, fmt.Errorf("failed to list EndpointSlices for Service %q: %w", svc.Name, err) + } + for _, eps := range epss { + for _, ep := range eps.Endpoints { + if *ep.Conditions.Ready { + return true, nil + } + } + } + + logger.Debugf("could not find any ready Endpoints in EndpointSlice") + return false, nil +} + +func (r *HAServiceReconciler) validateService(ctx context.Context, svc *corev1.Service, pg *tsapi.ProxyGroup) error { + var errs []error + if pg.Spec.Type != tsapi.ProxyGroupTypeIngress { + errs = append(errs, fmt.Errorf("ProxyGroup %q is of type %q but must be of type %q", + pg.Name, pg.Spec.Type, tsapi.ProxyGroupTypeIngress)) + } + if violations := validateService(svc); len(violations) > 0 { + errs = append(errs, fmt.Errorf("invalid Service: %s", strings.Join(violations, ", "))) + } + svcList := &corev1.ServiceList{} + if err := r.List(ctx, svcList); err != nil { + errs = append(errs, fmt.Errorf("[unexpected] error listing Services: %w", err)) + return errors.Join(errs...) + } + svcName := nameForService(svc) + for _, s := range svcList.Items { + if r.shouldExpose(&s) && nameForService(&s) == svcName && s.UID != svc.UID { + errs = append(errs, fmt.Errorf("found duplicate Service %q for hostname %q - multiple HA Services for the same hostname in the same cluster are not allowed", client.ObjectKeyFromObject(&s), svcName)) + } + } + return errors.Join(errs...) +} diff --git a/cmd/k8s-operator/svc-for-pg_test.go b/cmd/k8s-operator/svc-for-pg_test.go new file mode 100644 index 000000000..baaa07727 --- /dev/null +++ b/cmd/k8s-operator/svc-for-pg_test.go @@ -0,0 +1,427 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/rand/v2" + "net/netip" + "testing" + "time" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubetypes" + "tailscale.com/tstest" + "tailscale.com/types/ptr" + "tailscale.com/util/mak" + + "tailscale.com/tailcfg" +) + +func TestServicePGReconciler(t *testing.T) { + svcPGR, stateSecret, fc, ft, _ := setupServiceTest(t) + svcs := []*corev1.Service{} + config := []string{} + for i := range 4 { + svc, _ := setupTestService(t, fmt.Sprintf("test-svc-%d", i), "", fmt.Sprintf("1.2.3.%d", i), fc, stateSecret) + svcs = append(svcs, svc) + + // Verify initial reconciliation + expectReconciled(t, svcPGR, "default", svc.Name) + + config = append(config, fmt.Sprintf("svc:default-%s", svc.Name)) + verifyTailscaleService(t, ft, fmt.Sprintf("svc:default-%s", svc.Name), []string{"do-not-validate"}) + verifyTailscaledConfig(t, fc, "test-pg", config) + } + + for i, svc := range svcs { + if err := fc.Delete(context.Background(), svc); err != nil { + t.Fatalf("deleting Service: %v", err) + } + + expectReconciled(t, svcPGR, "default", svc.Name) + + // Verify the ConfigMap was cleaned up + cm := &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfgs := ingressservices.Configs{} + if err := json.Unmarshal(cm.BinaryData[ingressservices.IngressConfigKey], &cfgs); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + if len(cfgs) > len(svcs)-(i+1) { + t.Error("serve config not cleaned up") + } + + config = removeEl(config, fmt.Sprintf("svc:default-%s", svc.Name)) + verifyTailscaledConfig(t, fc, "test-pg", config) + } +} + +func TestServicePGReconciler_UpdateHostname(t *testing.T) { + svcPGR, stateSecret, fc, ft, _ := setupServiceTest(t) + + cip := "4.1.6.7" + svc, _ := setupTestService(t, "test-service", "", cip, fc, stateSecret) + + expectReconciled(t, svcPGR, "default", svc.Name) + + verifyTailscaleService(t, ft, fmt.Sprintf("svc:default-%s", svc.Name), []string{"do-not-validate"}) + verifyTailscaledConfig(t, fc, "test-pg", []string{fmt.Sprintf("svc:default-%s", svc.Name)}) + + hostname := "foobarbaz" + mustUpdate(t, fc, svc.Namespace, svc.Name, func(s *corev1.Service) { + mak.Set(&s.Annotations, AnnotationHostname, hostname) + }) + + // NOTE: we need to update the ingress config Secret because there is no containerboot in the fake proxy Pod + updateIngressConfigSecret(t, fc, stateSecret, hostname, cip) + expectReconciled(t, svcPGR, "default", svc.Name) + + verifyTailscaleService(t, ft, fmt.Sprintf("svc:%s", hostname), []string{"do-not-validate"}) + verifyTailscaledConfig(t, fc, "test-pg", []string{fmt.Sprintf("svc:%s", hostname)}) + + _, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName(fmt.Sprintf("svc:default-%s", svc.Name))) + if err == nil { + t.Fatalf("svc:default-%s not cleaned up", svc.Name) + } + if !isErrorTailscaleServiceNotFound(err) { + t.Fatalf("unexpected error: %v", err) + } +} + +func setupServiceTest(t *testing.T) (*HAServiceReconciler, *corev1.Secret, client.Client, *fakeTSClient, *tstest.Clock) { + // Pre-create the ProxyGroup + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg", + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + + // Pre-create the ConfigMap for the ProxyGroup + pgConfigMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, + BinaryData: map[string][]byte{ + "serve-config.json": []byte(`{"Services":{}}`), + }, + } + + // Pre-create a config Secret for the ProxyGroup + pgCfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName("test-pg", 0), + Namespace: "operator-ns", + Labels: pgSecretLabels("test-pg", kubetypes.LabelSecretTypeConfig), + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(pgMinCapabilityVersion): []byte(`{"Version":""}`), + }, + } + + pgStateSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: "operator-ns", + }, + Data: map[string][]byte{}, + } + + pgPod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{Kind: "Pod", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: "operator-ns", + }, + Status: corev1.PodStatus{ + PodIPs: []corev1.PodIP{ + { + IP: "4.3.2.1", + }, + }, + }, + } + + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pg, pgCfgSecret, pgConfigMap, pgPod, pgStateSecret). + WithStatusSubresource(pg). + WithIndex(new(corev1.Service), indexIngressProxyGroup, indexPGIngresses). + Build() + + // Set ProxyGroup status to ready + pg.Status.Conditions = []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupAvailable), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + } + if err := fc.Status().Update(context.Background(), pg); err != nil { + t.Fatal(err) + } + + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + lc := &fakeLocalClient{ + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{ + MagicDNSSuffix: "ts.net", + }, + }, + } + + cl := tstest.NewClock(tstest.ClockOpts{}) + svcPGR := &HAServiceReconciler{ + Client: fc, + tsClient: ft, + clock: cl, + defaultTags: []string{"tag:k8s"}, + tsNamespace: "operator-ns", + logger: zl.Sugar(), + recorder: record.NewFakeRecorder(10), + lc: lc, + } + + return svcPGR, pgStateSecret, fc, ft, cl +} + +func TestValidateService(t *testing.T) { + // Test that no more than one Kubernetes Service in a cluster refers to the same Tailscale Service. + pgr, _, lc, _, cl := setupServiceTest(t) + svc := &corev1.Service{ + TypeMeta: metav1.TypeMeta{Kind: "Service", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "my-app", + Namespace: "ns-1", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + "tailscale.com/hostname": "my-app", + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.4", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + } + svc2 := &corev1.Service{ + TypeMeta: metav1.TypeMeta{Kind: "Service", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "my-app2", + Namespace: "ns-2", + UID: types.UID("1235-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + "tailscale.com/hostname": "my-app", + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.5", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + } + wantSvc := &corev1.Service{ + ObjectMeta: svc.ObjectMeta, + TypeMeta: svc.TypeMeta, + Spec: svc.Spec, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{ + { + Type: string(tsapi.IngressSvcValid), + Status: metav1.ConditionFalse, + Reason: reasonIngressSvcInvalid, + LastTransitionTime: metav1.NewTime(cl.Now().Truncate(time.Second)), + Message: `found duplicate Service "ns-2/my-app2" for hostname "my-app" - multiple HA Services for the same hostname in the same cluster are not allowed`, + }, + }, + }, + } + + mustCreate(t, lc, svc) + mustCreate(t, lc, svc2) + expectReconciled(t, pgr, svc.Namespace, svc.Name) + expectEqual(t, lc, wantSvc) +} + +func TestServicePGReconciler_MultiCluster(t *testing.T) { + var ft *fakeTSClient + var lc localClient + for i := 0; i <= 10; i++ { + pgr, stateSecret, fc, fti, _ := setupServiceTest(t) + if i == 0 { + ft = fti + lc = pgr.lc + } else { + pgr.tsClient = ft + pgr.lc = lc + } + + svc, _ := setupTestService(t, "test-multi-cluster", "", "4.3.2.1", fc, stateSecret) + expectReconciled(t, pgr, "default", svc.Name) + + tsSvcs, err := ft.ListVIPServices(context.Background()) + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + + if len(tsSvcs.VIPServices) != 1 { + t.Fatalf("unexpected number of Tailscale Services (%d)", len(tsSvcs.VIPServices)) + } + + for _, svc := range tsSvcs.VIPServices { + t.Logf("found Tailscale Service with name %q", svc.Name) + } + } +} + +func TestIgnoreRegularService(t *testing.T) { + pgr, _, fc, ft, _ := setupServiceTest(t) + + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/expose": "true", + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeClusterIP, + }, + } + + mustCreate(t, fc, svc) + expectReconciled(t, pgr, "default", "test") + + verifyTailscaledConfig(t, fc, "test-pg", nil) + + tsSvcs, err := ft.ListVIPServices(context.Background()) + if err == nil { + if len(tsSvcs.VIPServices) > 0 { + t.Fatal("unexpected Tailscale Services found") + } + } +} + +func removeEl(s []string, value string) []string { + result := s[:0] + for _, v := range s { + if v != value { + result = append(result, v) + } + } + return result +} + +func updateIngressConfigSecret(t *testing.T, fc client.Client, stateSecret *corev1.Secret, serviceName string, clusterIP string) { + ingressConfig := ingressservices.Configs{ + fmt.Sprintf("svc:%s", serviceName): ingressservices.Config{ + IPv4Mapping: &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr(vipTestIP), + ClusterIP: netip.MustParseAddr(clusterIP), + }, + }, + } + + ingressStatus := ingressservices.Status{ + Configs: ingressConfig, + PodIPv4: "4.3.2.1", + } + + icJson, err := json.Marshal(ingressStatus) + if err != nil { + t.Fatalf("failed to json marshal ingress config: %s", err.Error()) + } + + mustUpdate(t, fc, stateSecret.Namespace, stateSecret.Name, func(sec *corev1.Secret) { + mak.Set(&sec.Data, ingressservices.IngressConfigKey, icJson) + }) +} + +func setupTestService(t *testing.T, svcName string, hostname string, clusterIP string, fc client.Client, stateSecret *corev1.Secret) (svc *corev1.Service, eps *discoveryv1.EndpointSlice) { + uid := rand.IntN(100) + svc = &corev1.Service{ + TypeMeta: metav1.TypeMeta{Kind: "Service", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: svcName, + Namespace: "default", + UID: types.UID(fmt.Sprintf("%d-UID", uid)), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + ClusterIP: clusterIP, + ClusterIPs: []string{clusterIP}, + }, + } + + eps = &discoveryv1.EndpointSlice{ + TypeMeta: metav1.TypeMeta{Kind: "EndpointSlice", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: svcName, + Namespace: "default", + Labels: map[string]string{ + discoveryv1.LabelServiceName: svcName, + }, + }, + AddressType: discoveryv1.AddressTypeIPv4, + Endpoints: []discoveryv1.Endpoint{ + { + Addresses: []string{"4.3.2.1"}, + Conditions: discoveryv1.EndpointConditions{ + Ready: ptr.To(true), + }, + }, + }, + } + + updateIngressConfigSecret(t, fc, stateSecret, fmt.Sprintf("default-%s", svcName), clusterIP) + + mustCreate(t, fc, svc) + mustCreate(t, fc, eps) + + return svc, eps +} diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index e47fcae7f..eec1924e7 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -41,6 +41,8 @@ const ( reasonProxyInvalid = "ProxyInvalid" reasonProxyFailed = "ProxyFailed" reasonProxyPending = "ProxyPending" + + indexServiceProxyClass = ".metadata.annotations.service-proxy-class" ) type ServiceReconciler struct { @@ -64,7 +66,7 @@ type ServiceReconciler struct { clock tstime.Clock - proxyDefaultClass string + defaultProxyClass string } var ( @@ -84,10 +86,10 @@ func childResourceLabels(name, ns, typ string) map[string]string { // proxying. Instead, we have to do our own filtering and tracking with // labels. return map[string]string{ - LabelManaged: "true", - LabelParentName: name, - LabelParentNamespace: ns, - LabelParentType: typ, + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, } } @@ -112,12 +114,24 @@ func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request return reconcile.Result{}, fmt.Errorf("failed to get svc: %w", err) } + if _, ok := svc.Annotations[AnnotationProxyGroup]; ok { + return reconcile.Result{}, nil // this reconciler should not look at Services for ProxyGroup + } + if !svc.DeletionTimestamp.IsZero() || !a.isTailscaleService(svc) { logger.Debugf("service is being deleted or is (no longer) referring to Tailscale ingress/egress, ensuring any created resources are cleaned up") return reconcile.Result{}, a.maybeCleanup(ctx, logger, svc) } - return reconcile.Result{}, a.maybeProvision(ctx, logger, svc) + if err := a.maybeProvision(ctx, logger, svc); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } // maybeCleanup removes any existing resources related to serving svc over tailscale. @@ -127,7 +141,7 @@ func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, svc *corev1.Service) (err error) { oldSvcStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldSvcStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { // An error encountered here should get returned by the Reconcile function. err = errors.Join(err, a.Client.Status().Update(ctx, svc)) } @@ -148,7 +162,12 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare return nil } - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(svc.Name, svc.Namespace, "svc")); err != nil { + proxyTyp := proxyTypeEgress + if a.shouldExpose(svc) { + proxyTyp = proxyTypeIngressService + } + + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(svc.Name, svc.Namespace, "svc"), proxyTyp); err != nil { return fmt.Errorf("failed to cleanup: %w", err) } else if !done { logger.Debugf("cleanup not done yet, waiting for next reconcile") @@ -187,7 +206,7 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.SugaredLogger, svc *corev1.Service) (err error) { oldSvcStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldSvcStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { // An error encountered here should get returned by the Reconcile function. err = errors.Join(err, a.Client.Status().Update(ctx, svc)) } @@ -211,7 +230,7 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - proxyClass := proxyClassForObject(svc, a.proxyDefaultClass) + proxyClass := proxyClassForObject(svc, a.defaultProxyClass) if proxyClass != "" { if ready, err := proxyClassIsReady(ctx, proxyClass, a.Client); err != nil { errMsg := fmt.Errorf("error verifying ProxyClass for Service: %w", err) @@ -245,12 +264,18 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga } sts := &tailscaleSTSConfig{ + Replicas: 1, ParentResourceName: svc.Name, ParentResourceUID: string(svc.UID), Hostname: nameForService(svc), Tags: tags, ChildResourceLabels: crl, ProxyClassName: proxyClass, + LoginServer: a.ssr.loginServer, + } + sts.proxyType = proxyTypeEgress + if a.shouldExpose(svc) { + sts.proxyType = proxyTypeIngressService } a.mu.Lock() @@ -307,11 +332,12 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - _, tsHost, tsIPs, err := a.ssr.DeviceInfo(ctx, crl) + devices, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { return fmt.Errorf("failed to get device ID: %w", err) } - if tsHost == "" { + + if len(devices) == 0 || devices[0].hostname == "" { msg := "no Tailscale hostname known yet, waiting for proxy pod to finish auth" logger.Debug(msg) // No hostname yet. Wait for the proxy pod to auth. @@ -320,17 +346,21 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - logger.Debugf("setting Service LoadBalancer status to %q, %s", tsHost, strings.Join(tsIPs, ", ")) + dev := devices[0] + logger.Debugf("setting Service LoadBalancer status to %q, %s", dev.hostname, strings.Join(dev.ips, ", ")) + ingress := []corev1.LoadBalancerIngress{ - {Hostname: tsHost}, + {Hostname: dev.hostname}, } + clusterIPAddr, err := netip.ParseAddr(svc.Spec.ClusterIP) if err != nil { msg := fmt.Sprintf("failed to parse cluster IP: %v", err) tsoperator.SetServiceCondition(svc, tsapi.ProxyReady, metav1.ConditionFalse, reasonProxyFailed, msg, a.clock, logger) return errors.New(msg) } - for _, ip := range tsIPs { + + for _, ip := range dev.ips { addr, err := netip.ParseAddr(ip) if err != nil { continue @@ -339,6 +369,7 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga ingress = append(ingress, corev1.LoadBalancerIngress{IP: ip}) } } + svc.Status.LoadBalancer.Ingress = ingress tsoperator.SetServiceCondition(svc, tsapi.ProxyReady, metav1.ConditionTrue, reasonProxyCreated, reasonProxyCreated, a.clock, logger) return nil @@ -354,9 +385,14 @@ func validateService(svc *corev1.Service) []string { violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q does not appear to be a valid MagicDNS name", AnnotationTailnetTargetFQDN, fqdn)) } } - - // TODO(irbekrm): validate that tailscale.com/tailnet-ip annotation is a - // valid IP address (tailscale/tailscale#13671). + if ipStr := svc.Annotations[AnnotationTailnetTargetIP]; ipStr != "" { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q could not be parsed as a valid IP Address, error: %s", AnnotationTailnetTargetIP, ipStr, err)) + } else if !ip.IsValid() { + violations = append(violations, fmt.Sprintf("parsed IP address in annotation %s: %q is not valid", AnnotationTailnetTargetIP, ipStr)) + } + } svcName := nameForService(svc) if err := dnsname.ValidLabel(svcName); err != nil { @@ -366,6 +402,7 @@ func validateService(svc *corev1.Service) []string { violations = append(violations, fmt.Sprintf("invalid Tailscale hostname %q, use %q annotation to override: %s", svcName, AnnotationHostname, err)) } } + violations = append(violations, tagViolations(svc)...) return violations } @@ -411,16 +448,6 @@ func tailnetTargetAnnotation(svc *corev1.Service) string { return svc.Annotations[annotationTailnetTargetIPOld] } -// proxyClassForObject returns the proxy class for the given object. If the -// object does not have a proxy class label, it returns the default proxy class -func proxyClassForObject(o client.Object, proxyDefaultClass string) string { - proxyClass, exists := o.GetLabels()[LabelProxyClass] - if !exists { - proxyClass = proxyDefaultClass - } - return proxyClass -} - func proxyClassIsReady(ctx context.Context, name string, cl client.Client) (bool, error) { proxyClass := new(tsapi.ProxyClass) if err := cl.Get(ctx, types.NamespacedName{Name: name}, proxyClass); err != nil { diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 457248d57..b4c468c8e 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -8,7 +8,10 @@ package main import ( "context" "encoding/json" + "fmt" + "net/http" "net/netip" + "path" "reflect" "strings" "sync" @@ -21,17 +24,25 @@ import ( corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" - "tailscale.com/client/tailscale" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" "tailscale.com/types/ptr" "tailscale.com/util/mak" ) +const ( + vipTestIP = "5.6.7.8" +) + // confgOpts contains configuration options for creating cluster resources for // Tailscale proxies. type configOpts struct { @@ -39,7 +50,10 @@ type configOpts struct { secretName string hostname string namespace string + tailscaleNamespace string + namespaced bool parentType string + proxyType string priorityClassName string firewallMode string tailnetTargetIP string @@ -48,13 +62,17 @@ type configOpts struct { clusterTargetDNS string subnetRoutes string isExitNode bool - confFileHash string + isAppConnector bool serveConfig *ipn.ServeConfig shouldEnableForwardingClusterTrafficViaIngress bool proxyClass string // configuration from the named ProxyClass should be applied to proxy resources app string shouldRemoveAuthKey bool secretExtraData map[string][]byte + resourceVersion string + replicas *int32 + enableMetrics bool + serviceMonitorLabels tsapi.Labels } func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.StatefulSet { @@ -69,14 +87,13 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Env: []corev1.EnvVar{ {Name: "TS_USERSPACE", Value: "false"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, - {Name: "TS_KUBE_SECRET", Value: opts.secretName}, - {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", Value: "/etc/tsconfig/tailscaled"}, - {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, + {Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.name"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "TS_KUBE_SECRET", Value: "$(POD_NAME)"}, + {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)"}, }, SecurityContext: &corev1.SecurityContext{ - Capabilities: &corev1.Capabilities{ - Add: []corev1.Capability{"NET_ADMIN"}, - }, + Privileged: ptr.To(true), }, ImagePullPolicy: "Always", } @@ -86,11 +103,11 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Value: "true", }) } - annots := make(map[string]string) + var annots map[string]string var volumes []corev1.Volume volumes = []corev1.Volume{ { - Name: "tailscaledconfig", + Name: "tailscaledconfig-0", VolumeSource: corev1.VolumeSource{ Secret: &corev1.SecretVolumeSource{ SecretName: opts.secretName, @@ -99,13 +116,10 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef }, } tsContainer.VolumeMounts = []corev1.VolumeMount{{ - Name: "tailscaledconfig", + Name: "tailscaledconfig-0", ReadOnly: true, - MountPath: "/etc/tsconfig", + MountPath: "/etc/tsconfig/" + opts.secretName, }} - if opts.confFileHash != "" { - annots["tailscale.com/operator-last-set-config-file-hash"] = opts.confFileHash - } if opts.firewallMode != "" { tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_DEBUG_FIREWALL_MODE", @@ -113,13 +127,13 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef }) } if opts.tailnetTargetIP != "" { - annots["tailscale.com/operator-last-set-ts-tailnet-target-ip"] = opts.tailnetTargetIP + mak.Set(&annots, "tailscale.com/operator-last-set-ts-tailnet-target-ip", opts.tailnetTargetIP) tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_TAILNET_TARGET_IP", Value: opts.tailnetTargetIP, }) } else if opts.tailnetTargetFQDN != "" { - annots["tailscale.com/operator-last-set-ts-tailnet-target-fqdn"] = opts.tailnetTargetFQDN + mak.Set(&annots, "tailscale.com/operator-last-set-ts-tailnet-target-fqdn", opts.tailnetTargetFQDN) tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_TAILNET_TARGET_FQDN", Value: opts.tailnetTargetFQDN, @@ -130,26 +144,60 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Name: "TS_DEST_IP", Value: opts.clusterTargetIP, }) - annots["tailscale.com/operator-last-set-cluster-ip"] = opts.clusterTargetIP + mak.Set(&annots, "tailscale.com/operator-last-set-cluster-ip", opts.clusterTargetIP) } else if opts.clusterTargetDNS != "" { tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_EXPERIMENTAL_DEST_DNS_NAME", Value: opts.clusterTargetDNS, }) - annots["tailscale.com/operator-last-set-cluster-dns-name"] = opts.clusterTargetDNS + mak.Set(&annots, "tailscale.com/operator-last-set-cluster-dns-name", opts.clusterTargetDNS) } if opts.serveConfig != nil { tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_SERVE_CONFIG", - Value: "/etc/tailscaled/serve-config", + Value: "/etc/tailscaled/$(POD_NAME)/serve-config", + }) + volumes = append(volumes, corev1.Volume{ + Name: "serve-config-0", + VolumeSource: corev1.VolumeSource{ + Secret: &corev1.SecretVolumeSource{ + SecretName: opts.secretName, + Items: []corev1.KeyToPath{{ + Key: "serve-config", + Path: "serve-config", + }}, + }, + }, }) - volumes = append(volumes, corev1.Volume{Name: "serve-config", VolumeSource: corev1.VolumeSource{Secret: &corev1.SecretVolumeSource{SecretName: opts.secretName, Items: []corev1.KeyToPath{{Key: "serve-config", Path: "serve-config"}}}}}) - tsContainer.VolumeMounts = append(tsContainer.VolumeMounts, corev1.VolumeMount{Name: "serve-config", ReadOnly: true, MountPath: "/etc/tailscaled"}) + tsContainer.VolumeMounts = append(tsContainer.VolumeMounts, corev1.VolumeMount{Name: "serve-config-0", ReadOnly: true, MountPath: path.Join("/etc/tailscaled", opts.secretName)}) } tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_INTERNAL_APP", Value: opts.app, }) + if opts.enableMetrics { + tsContainer.Env = append(tsContainer.Env, + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001"}, + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + tsContainer.Ports = append(tsContainer.Ports, + corev1.ContainerPort{Name: "debug", ContainerPort: 9001, Protocol: "TCP"}, + corev1.ContainerPort{Name: "metrics", ContainerPort: 9002, Protocol: "TCP"}, + ) + } ss := &appsv1.StatefulSet{ TypeMeta: metav1.TypeMeta{ Kind: "StatefulSet", @@ -166,7 +214,7 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef }, }, Spec: appsv1.StatefulSetSpec{ - Replicas: ptr.To[int32](1), + Replicas: opts.replicas, Selector: &metav1.LabelSelector{ MatchLabels: map[string]string{"app": "1234-UID"}, }, @@ -228,30 +276,60 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps Env: []corev1.EnvVar{ {Name: "TS_USERSPACE", Value: "true"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, - {Name: "TS_KUBE_SECRET", Value: opts.secretName}, - {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", Value: "/etc/tsconfig/tailscaled"}, - {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, - {Name: "TS_SERVE_CONFIG", Value: "/etc/tailscaled/serve-config"}, + {Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.name"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "TS_KUBE_SECRET", Value: "$(POD_NAME)"}, + {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)"}, + {Name: "TS_SERVE_CONFIG", Value: "/etc/tailscaled/$(POD_NAME)/serve-config"}, {Name: "TS_INTERNAL_APP", Value: opts.app}, }, ImagePullPolicy: "Always", VolumeMounts: []corev1.VolumeMount{ - {Name: "tailscaledconfig", ReadOnly: true, MountPath: "/etc/tsconfig"}, - {Name: "serve-config", ReadOnly: true, MountPath: "/etc/tailscaled"}, + {Name: "tailscaledconfig-0", ReadOnly: true, MountPath: path.Join("/etc/tsconfig", opts.secretName)}, + {Name: "serve-config-0", ReadOnly: true, MountPath: path.Join("/etc/tailscaled", opts.secretName)}, }, } + if opts.enableMetrics { + tsContainer.Env = append(tsContainer.Env, + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001"}, + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + tsContainer.Ports = append(tsContainer.Ports, corev1.ContainerPort{ + Name: "debug", ContainerPort: 9001, Protocol: "TCP"}, + corev1.ContainerPort{Name: "metrics", ContainerPort: 9002, Protocol: "TCP"}, + ) + } volumes := []corev1.Volume{ { - Name: "tailscaledconfig", + Name: "tailscaledconfig-0", VolumeSource: corev1.VolumeSource{ Secret: &corev1.SecretVolumeSource{ SecretName: opts.secretName, }, }, }, - {Name: "serve-config", + { + Name: "serve-config-0", VolumeSource: corev1.VolumeSource{ - Secret: &corev1.SecretVolumeSource{SecretName: opts.secretName, Items: []corev1.KeyToPath{{Key: "serve-config", Path: "serve-config"}}}}}, + Secret: &corev1.SecretVolumeSource{ + SecretName: opts.secretName, + Items: []corev1.KeyToPath{{Key: "serve-config", Path: "serve-config"}}, + }, + }, + }, } ss := &appsv1.StatefulSet{ TypeMeta: metav1.TypeMeta{ @@ -294,10 +372,6 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps }, }, } - ss.Spec.Template.Annotations = map[string]string{} - if opts.confFileHash != "" { - ss.Spec.Template.Annotations["tailscale.com/operator-last-set-config-file-hash"] = opts.confFileHash - } // If opts.proxyClass is set, retrieve the ProxyClass and apply // configuration from that to the StatefulSet. if opts.proxyClass != "" { @@ -334,6 +408,90 @@ func expectedHeadlessService(name string, parentType string) *corev1.Service { } } +func expectedMetricsService(opts configOpts) *corev1.Service { + labels := metricsLabels(opts) + selector := map[string]string{ + "tailscale.com/managed": "true", + "tailscale.com/parent-resource": "test", + "tailscale.com/parent-resource-type": opts.parentType, + } + if opts.namespaced { + selector["tailscale.com/parent-resource-ns"] = opts.namespace + } + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: metricsResourceName(opts.stsName), + Namespace: opts.tailscaleNamespace, + Labels: labels, + }, + Spec: corev1.ServiceSpec{ + Selector: selector, + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{Protocol: "TCP", Port: 9002, Name: "metrics"}}, + }, + } +} + +func metricsLabels(opts configOpts) map[string]string { + promJob := fmt.Sprintf("ts_%s_default_test", opts.proxyType) + if !opts.namespaced { + promJob = fmt.Sprintf("ts_%s_test", opts.proxyType) + } + labels := map[string]string{ + "tailscale.com/managed": "true", + "tailscale.com/metrics-target": opts.stsName, + "ts_prom_job": promJob, + "ts_proxy_type": opts.proxyType, + "ts_proxy_parent_name": "test", + } + if opts.namespaced { + labels["ts_proxy_parent_namespace"] = "default" + } + return labels +} + +func expectedServiceMonitor(t *testing.T, opts configOpts) *unstructured.Unstructured { + t.Helper() + smLabels := metricsLabels(opts) + if len(opts.serviceMonitorLabels) != 0 { + smLabels = mergeMapKeys(smLabels, opts.serviceMonitorLabels.Parse()) + } + name := metricsResourceName(opts.stsName) + sm := &ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: opts.tailscaleNamespace, + Labels: smLabels, + ResourceVersion: opts.resourceVersion, + OwnerReferences: []metav1.OwnerReference{{APIVersion: "v1", Kind: "Service", Name: name, BlockOwnerDeletion: ptr.To(true), Controller: ptr.To(true)}}, + }, + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceMonitor", + APIVersion: "monitoring.coreos.com/v1", + }, + Spec: ServiceMonitorSpec{ + Selector: metav1.LabelSelector{MatchLabels: metricsLabels(opts)}, + Endpoints: []ServiceMonitorEndpoint{{ + Port: "metrics", + }}, + NamespaceSelector: ServiceMonitorNamespaceSelector{ + MatchNames: []string{opts.tailscaleNamespace}, + }, + JobLabel: "ts_prom_job", + TargetLabels: []string{ + "ts_proxy_parent_name", + "ts_proxy_parent_namespace", + "ts_proxy_type", + }, + }, + } + u, err := serviceMonitorToUnstructured(sm) + if err != nil { + t.Fatalf("error converting ServiceMonitor to unstructured: %v", err) + } + return u +} + func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Secret { t.Helper() s := &corev1.Secret{ @@ -350,12 +508,14 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec mak.Set(&s.StringData, "serve-config", string(serveConfigBs)) } conf := &ipn.ConfigVAlpha{ - Version: "alpha0", - AcceptDNS: "false", - Hostname: &opts.hostname, - Locked: "false", - AuthKey: ptr.To("secret-authkey"), - AcceptRoutes: "false", + Version: "alpha0", + AcceptDNS: "false", + Hostname: &opts.hostname, + Locked: "false", + AuthKey: ptr.To("secret-authkey"), + AcceptRoutes: "false", + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, + NoStatefulFiltering: "true", } if opts.proxyClass != "" { t.Logf("applying configuration from ProxyClass %s", opts.proxyClass) @@ -370,6 +530,9 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec if opts.shouldRemoveAuthKey { conf.AuthKey = nil } + if opts.isAppConnector { + conf.AppConnector = &ipn.AppConnectorPrefs{Advertise: true} + } var routes []netip.Prefix if opts.subnetRoutes != "" || opts.isExitNode { r := opts.subnetRoutes @@ -385,21 +548,17 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec } } conf.AdvertiseRoutes = routes - b, err := json.Marshal(conf) + bnn, err := json.Marshal(conf) if err != nil { t.Fatalf("error marshalling tailscaled config") } - if opts.tailnetTargetFQDN != "" || opts.tailnetTargetIP != "" { - conf.NoStatefulFiltering = "true" - } else { - conf.NoStatefulFiltering = "false" - } + conf.AppConnector = nil bn, err := json.Marshal(conf) if err != nil { t.Fatalf("error marshalling tailscaled config") } - mak.Set(&s.StringData, "tailscaled", string(b)) mak.Set(&s.StringData, "cap-95.hujson", string(bn)) + mak.Set(&s.StringData, "cap-107.hujson", string(bnn)) labels := map[string]string{ "tailscale.com/managed": "true", "tailscale.com/parent-resource": "test", @@ -416,13 +575,30 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec return s } +func findNoGenName(t *testing.T, client client.Client, ns, name, typ string) { + t.Helper() + labels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, + } + s, err := getSingleObject[corev1.Secret](context.Background(), client, "operator-ns", labels) + if err != nil { + t.Fatalf("finding secrets for %q: %v", name, err) + } + if s != nil { + t.Fatalf("found unexpected secret with name %q", s.GetName()) + } +} + func findGenName(t *testing.T, client client.Client, ns, name, typ string) (full, noSuffix string) { t.Helper() labels := map[string]string{ - LabelManaged: "true", - LabelParentName: name, - LabelParentNamespace: ns, - LabelParentType: typ, + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, } s, err := getSingleObject[corev1.Secret](context.Background(), client, "operator-ns", labels) if err != nil { @@ -434,12 +610,53 @@ func findGenName(t *testing.T, client client.Client, ns, name, typ string) (full return s.GetName(), strings.TrimSuffix(s.GetName(), "-0") } +func findGenNames(t *testing.T, cl client.Client, ns, name, typ string) []string { + t.Helper() + labels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, + } + + var list corev1.SecretList + if err := cl.List(t.Context(), &list, client.InNamespace(ns), client.MatchingLabels(labels)); err != nil { + t.Fatalf("finding secrets for %q: %v", name, err) + } + + if len(list.Items) == 0 { + t.Fatalf("no secrets found for %q %s %+#v", name, ns, labels) + } + + names := make([]string, len(list.Items)) + for i, secret := range list.Items { + names[i] = secret.GetName() + } + + return names +} + func mustCreate(t *testing.T, client client.Client, obj client.Object) { t.Helper() if err := client.Create(context.Background(), obj); err != nil { t.Fatalf("creating %q: %v", obj.GetName(), err) } } +func mustCreateAll(t *testing.T, client client.Client, objs ...client.Object) { + t.Helper() + for _, obj := range objs { + mustCreate(t, client, obj) + } +} + +func mustDeleteAll(t *testing.T, client client.Client, objs ...client.Object) { + t.Helper() + for _, obj := range objs { + if err := client.Delete(context.Background(), obj); err != nil { + t.Fatalf("deleting %q: %v", obj.GetName(), err) + } + } +} func mustUpdate[T any, O ptrObject[T]](t *testing.T, client client.Client, ns, name string, update func(O)) { t.Helper() @@ -477,7 +694,7 @@ func mustUpdateStatus[T any, O ptrObject[T]](t *testing.T, client client.Client, // modify func to ensure that they are removed from the cluster object and the // object passed as 'want'. If no such modifications are needed, you can pass // nil in place of the modify function. -func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want O, modifier func(O)) { +func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want O, modifiers ...func(O)) { t.Helper() got := O(new(T)) if err := client.Get(context.Background(), types.NamespacedName{ @@ -491,7 +708,7 @@ func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want // so just remove it from both got and want. got.SetResourceVersion("") want.SetResourceVersion("") - if modifier != nil { + for _, modifier := range modifiers { modifier(want) modifier(got) } @@ -500,13 +717,29 @@ func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want } } +func expectEqualUnstructured(t *testing.T, client client.Client, want *unstructured.Unstructured) { + t.Helper() + got := &unstructured.Unstructured{} + got.SetGroupVersionKind(want.GroupVersionKind()) + if err := client.Get(context.Background(), types.NamespacedName{ + Name: want.GetName(), + Namespace: want.GetNamespace(), + }, got); err != nil { + t.Fatalf("getting %q: %v", want.GetName(), err) + } + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("unexpected contents of Unstructured (-got +want):\n%s", diff) + } +} + func expectMissing[T any, O ptrObject[T]](t *testing.T, client client.Client, ns, name string) { t.Helper() obj := O(new(T)) - if err := client.Get(context.Background(), types.NamespacedName{ + err := client.Get(context.Background(), types.NamespacedName{ Name: name, Namespace: ns, - }, obj); !apierrors.IsNotFound(err) { + }, obj) + if !apierrors.IsNotFound(err) { t.Fatalf("%s %s/%s unexpectedly present, wanted missing", reflect.TypeOf(obj).Elem().Name(), ns, name) } } @@ -547,6 +780,19 @@ func expectRequeue(t *testing.T, sr reconcile.Reconciler, ns, name string) { t.Fatalf("expected timed requeue, got success") } } +func expectError(t *testing.T, sr reconcile.Reconciler, ns, name string) { + t.Helper() + req := reconcile.Request{ + NamespacedName: types.NamespacedName{ + Name: name, + Namespace: ns, + }, + } + _, err := sr.Reconcile(context.Background(), req) + if err == nil { + t.Error("Reconcile: expected error but did not get one") + } +} // expectEvents accepts a test recorder and a list of events, tests that expected // events are sent down the recorder's channel. Waits for 5s for each event. @@ -580,6 +826,7 @@ type fakeTSClient struct { sync.Mutex keyRequests []tailscale.KeyCapabilities deleted []string + vipServices map[tailcfg.ServiceName]*tailscale.VIPService } type fakeTSNetServer struct { certDomains []string @@ -604,7 +851,7 @@ func (c *fakeTSClient) CreateKey(ctx context.Context, caps tailscale.KeyCapabili func (c *fakeTSClient) Device(ctx context.Context, deviceID string, fields *tailscale.DeviceFieldsOpts) (*tailscale.Device, error) { return &tailscale.Device{ DeviceID: deviceID, - Hostname: "test-device", + Hostname: "hostname-" + deviceID, Addresses: []string{ "1.2.3.4", "::1", @@ -631,18 +878,16 @@ func (c *fakeTSClient) Deleted() []string { return c.deleted } -// removeHashAnnotation can be used to remove declarative tailscaled config hash -// annotation from proxy StatefulSets to make the tests more maintainable (so -// that we don't have to change the annotation in each test case after any -// change to the configfile contents). -func removeHashAnnotation(sts *appsv1.StatefulSet) { - delete(sts.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash) +func removeResourceReqs(sts *appsv1.StatefulSet) { + if sts != nil { + sts.Spec.Template.Spec.Resources = nil + } } func removeTargetPortsFromSvc(svc *corev1.Service) { newPorts := make([]corev1.ServicePort, 0) for _, p := range svc.Spec.Ports { - newPorts = append(newPorts, corev1.ServicePort{Protocol: p.Protocol, Port: p.Port}) + newPorts = append(newPorts, corev1.ServicePort{Protocol: p.Protocol, Port: p.Port, Name: p.Name}) } svc.Spec.Ports = newPorts } @@ -650,29 +895,94 @@ func removeTargetPortsFromSvc(svc *corev1.Service) { func removeAuthKeyIfExistsModifier(t *testing.T) func(s *corev1.Secret) { return func(secret *corev1.Secret) { t.Helper() - if len(secret.StringData["tailscaled"]) != 0 { + if len(secret.StringData["cap-95.hujson"]) != 0 { conf := &ipn.ConfigVAlpha{} - if err := json.Unmarshal([]byte(secret.StringData["tailscaled"]), conf); err != nil { - t.Fatalf("error unmarshalling 'tailscaled' contents: %v", err) + if err := json.Unmarshal([]byte(secret.StringData["cap-95.hujson"]), conf); err != nil { + t.Fatalf("error umarshalling 'cap-95.hujson' contents: %v", err) } conf.AuthKey = nil b, err := json.Marshal(conf) if err != nil { - t.Fatalf("error marshalling updated 'tailscaled' config: %v", err) + t.Fatalf("error marshalling 'cap-95.huson' contents: %v", err) } - mak.Set(&secret.StringData, "tailscaled", string(b)) + mak.Set(&secret.StringData, "cap-95.hujson", string(b)) } - if len(secret.StringData["cap-95.hujson"]) != 0 { + if len(secret.StringData["cap-107.hujson"]) != 0 { conf := &ipn.ConfigVAlpha{} - if err := json.Unmarshal([]byte(secret.StringData["cap-95.hujson"]), conf); err != nil { - t.Fatalf("error umarshalling 'cap-95.hujson' contents: %v", err) + if err := json.Unmarshal([]byte(secret.StringData["cap-107.hujson"]), conf); err != nil { + t.Fatalf("error umarshalling 'cap-107.hujson' contents: %v", err) } conf.AuthKey = nil b, err := json.Marshal(conf) if err != nil { - t.Fatalf("error marshalling 'cap-95.huson' contents: %v", err) + t.Fatalf("error marshalling 'cap-107.huson' contents: %v", err) } - mak.Set(&secret.StringData, "cap-95.hujson", string(b)) + mak.Set(&secret.StringData, "cap-107.hujson", string(b)) } } } + +func (c *fakeTSClient) GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*tailscale.VIPService, error) { + c.Lock() + defer c.Unlock() + if c.vipServices == nil { + return nil, tailscale.ErrResponse{Status: http.StatusNotFound} + } + svc, ok := c.vipServices[name] + if !ok { + return nil, tailscale.ErrResponse{Status: http.StatusNotFound} + } + return svc, nil +} + +func (c *fakeTSClient) ListVIPServices(ctx context.Context) (*tailscale.VIPServiceList, error) { + c.Lock() + defer c.Unlock() + if c.vipServices == nil { + return nil, &tailscale.ErrResponse{Status: http.StatusNotFound} + } + result := &tailscale.VIPServiceList{} + for _, svc := range c.vipServices { + result.VIPServices = append(result.VIPServices, *svc) + } + return result, nil +} + +func (c *fakeTSClient) CreateOrUpdateVIPService(ctx context.Context, svc *tailscale.VIPService) error { + c.Lock() + defer c.Unlock() + if c.vipServices == nil { + c.vipServices = make(map[tailcfg.ServiceName]*tailscale.VIPService) + } + + if svc.Addrs == nil { + svc.Addrs = []string{vipTestIP} + } + + c.vipServices[svc.Name] = svc + return nil +} + +func (c *fakeTSClient) DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error { + c.Lock() + defer c.Unlock() + if c.vipServices != nil { + delete(c.vipServices, name) + } + return nil +} + +type fakeLocalClient struct { + status *ipnstate.Status +} + +func (f *fakeLocalClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { + if f.status == nil { + return &ipnstate.Status{ + Self: &ipnstate.PeerStatus{ + DNSName: "test-node.test.ts.net.", + }, + }, nil + } + return f.status, nil +} diff --git a/cmd/k8s-operator/tsclient.go b/cmd/k8s-operator/tsclient.go new file mode 100644 index 000000000..d22fa1797 --- /dev/null +++ b/cmd/k8s-operator/tsclient.go @@ -0,0 +1,133 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "sync" + "time" + + "go.uber.org/zap" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + "tailscale.com/tailcfg" +) + +// defaultTailnet is a value that can be used in Tailscale API calls instead of tailnet name to indicate that the API +// call should be performed on the default tailnet for the provided credentials. +const ( + defaultTailnet = "-" + oidcJWTPath = "/var/run/secrets/tailscale/serviceaccount/token" +) + +func newTSClient(logger *zap.SugaredLogger, clientID, clientIDPath, clientSecretPath, loginServer string) (*tailscale.Client, error) { + baseURL := ipn.DefaultControlURL + if loginServer != "" { + baseURL = loginServer + } + + var httpClient *http.Client + if clientID == "" { + // Use static client credentials mounted to disk. + id, err := os.ReadFile(clientIDPath) + if err != nil { + return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) + } + secret, err := os.ReadFile(clientSecretPath) + if err != nil { + return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) + } + credentials := clientcredentials.Config{ + ClientID: string(id), + ClientSecret: string(secret), + TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token"), + } + tokenSrc := credentials.TokenSource(context.Background()) + httpClient = oauth2.NewClient(context.Background(), tokenSrc) + } else { + // Use workload identity federation. + tokenSrc := &jwtTokenSource{ + logger: logger, + jwtPath: oidcJWTPath, + baseCfg: clientcredentials.Config{ + ClientID: clientID, + TokenURL: fmt.Sprintf("%s%s", baseURL, "/api/v2/oauth/token-exchange"), + }, + } + httpClient = &http.Client{ + Transport: &oauth2.Transport{ + Source: tokenSrc, + }, + } + } + + c := tailscale.NewClient(defaultTailnet, nil) + c.UserAgent = "tailscale-k8s-operator" + c.HTTPClient = httpClient + if loginServer != "" { + c.BaseURL = loginServer + } + return c, nil +} + +type tsClient interface { + CreateKey(ctx context.Context, caps tailscale.KeyCapabilities) (string, *tailscale.Key, error) + Device(ctx context.Context, deviceID string, fields *tailscale.DeviceFieldsOpts) (*tailscale.Device, error) + DeleteDevice(ctx context.Context, nodeStableID string) error + // GetVIPService is a method for getting a Tailscale Service. VIPService is the original name for Tailscale Service. + GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*tailscale.VIPService, error) + // ListVIPServices is a method for listing all Tailscale Services. VIPService is the original name for Tailscale Service. + ListVIPServices(ctx context.Context) (*tailscale.VIPServiceList, error) + // CreateOrUpdateVIPService is a method for creating or updating a Tailscale Service. + CreateOrUpdateVIPService(ctx context.Context, svc *tailscale.VIPService) error + // DeleteVIPService is a method for deleting a Tailscale Service. + DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error +} + +// jwtTokenSource implements the [oauth2.TokenSource] interface, but with the +// ability to regenerate a fresh underlying token source each time a new value +// of the JWT parameter is needed due to expiration. +type jwtTokenSource struct { + logger *zap.SugaredLogger + jwtPath string // Path to the file containing an automatically refreshed JWT. + baseCfg clientcredentials.Config // Holds config that doesn't change for the lifetime of the process. + + mu sync.Mutex // Guards underlying. + underlying oauth2.TokenSource // The oauth2 client implementation. Does its own separate caching of the access token. +} + +func (s *jwtTokenSource) Token() (*oauth2.Token, error) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.underlying != nil { + t, err := s.underlying.Token() + if err == nil && t != nil && t.Valid() { + return t, nil + } + } + + s.logger.Debugf("Refreshing JWT from %s", s.jwtPath) + tk, err := os.ReadFile(s.jwtPath) + if err != nil { + return nil, fmt.Errorf("error reading JWT from %q: %w", s.jwtPath, err) + } + + // Shallow copy of the base config. + credentials := s.baseCfg + credentials.EndpointParams = map[string][]string{ + "jwt": {string(tk)}, + } + + src := credentials.TokenSource(context.Background()) + s.underlying = oauth2.ReuseTokenSourceWithExpiry(nil, src, time.Minute) + return s.underlying.Token() +} diff --git a/cmd/k8s-operator/tsclient_test.go b/cmd/k8s-operator/tsclient_test.go new file mode 100644 index 000000000..16de512d5 --- /dev/null +++ b/cmd/k8s-operator/tsclient_test.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "golang.org/x/oauth2" +) + +func TestNewStaticClient(t *testing.T) { + const ( + clientIDFile = "client-id" + clientSecretFile = "client-secret" + ) + + tmp := t.TempDir() + clientIDPath := filepath.Join(tmp, clientIDFile) + if err := os.WriteFile(clientIDPath, []byte("test-client-id"), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", clientIDPath, err) + } + clientSecretPath := filepath.Join(tmp, clientSecretFile) + if err := os.WriteFile(clientSecretPath, []byte("test-client-secret"), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", clientSecretPath, err) + } + + srv := testAPI(t, 3600) + cl, err := newTSClient(zap.NewNop().Sugar(), "", clientIDPath, clientSecretPath, srv.URL) + if err != nil { + t.Fatalf("error creating Tailscale client: %v", err) + } + + resp, err := cl.HTTPClient.Get(srv.URL) + if err != nil { + t.Fatalf("error making test API call: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + want := "Bearer " + testToken("/api/v2/oauth/token", "test-client-id", "test-client-secret", "") + if string(got) != want { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestNewWorkloadIdentityClient(t *testing.T) { + // 5 seconds is within expiryDelta leeway, so the access token will + // immediately be considered expired and get refreshed on each access. + srv := testAPI(t, 5) + cl, err := newTSClient(zap.NewNop().Sugar(), "test-client-id", "", "", srv.URL) + if err != nil { + t.Fatalf("error creating Tailscale client: %v", err) + } + + // Modify the path where the JWT will be read from. + oauth2Transport, ok := cl.HTTPClient.Transport.(*oauth2.Transport) + if !ok { + t.Fatalf("expected oauth2.Transport, got %T", cl.HTTPClient.Transport) + } + jwtTokenSource, ok := oauth2Transport.Source.(*jwtTokenSource) + if !ok { + t.Fatalf("expected jwtTokenSource, got %T", oauth2Transport.Source) + } + tmp := t.TempDir() + jwtPath := filepath.Join(tmp, "token") + jwtTokenSource.jwtPath = jwtPath + + for _, jwt := range []string{"test-jwt", "updated-test-jwt"} { + if err := os.WriteFile(jwtPath, []byte(jwt), 0600); err != nil { + t.Fatalf("error writing test file %q: %v", jwtPath, err) + } + resp, err := cl.HTTPClient.Get(srv.URL) + if err != nil { + t.Fatalf("error making test API call: %v", err) + } + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("error reading response body: %v", err) + } + if want := "Bearer " + testToken("/api/v2/oauth/token-exchange", "test-client-id", "", jwt); string(got) != want { + t.Errorf("got %q; want %q", got, want) + } + } +} + +func testAPI(t *testing.T, expirationSeconds int) *httptest.Server { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("test server got request: %s %s", r.Method, r.URL.Path) + switch r.URL.Path { + case "/api/v2/oauth/token", "/api/v2/oauth/token-exchange": + id, secret, ok := r.BasicAuth() + if !ok { + t.Fatal("missing or invalid basic auth") + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": testToken(r.URL.Path, id, secret, r.FormValue("jwt")), + "token_type": "Bearer", + "expires_in": expirationSeconds, + }); err != nil { + t.Fatalf("error writing response: %v", err) + } + case "/": + // Echo back the authz header for test assertions. + _, err := w.Write([]byte(r.Header.Get("Authorization"))) + if err != nil { + t.Fatalf("error writing response: %v", err) + } + default: + w.WriteHeader(http.StatusNotFound) + } + })) + t.Cleanup(srv.Close) + return srv +} + +func testToken(path, id, secret, jwt string) string { + return fmt.Sprintf("%s|%s|%s|%s", path, id, secret, jwt) +} diff --git a/cmd/k8s-operator/tsrecorder.go b/cmd/k8s-operator/tsrecorder.go index dfbf96b0b..c922f78fe 100644 --- a/cmd/k8s-operator/tsrecorder.go +++ b/cmd/k8s-operator/tsrecorder.go @@ -8,12 +8,13 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "net/http" "slices" + "strings" "sync" - "github.com/pkg/errors" "go.uber.org/zap" xslices "golang.org/x/exp/slices" appsv1 "k8s.io/api/apps/v1" @@ -21,8 +22,10 @@ import ( rbacv1 "k8s.io/api/rbac/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" + apivalidation "k8s.io/apimachinery/pkg/api/validation" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -38,6 +41,7 @@ import ( const ( reasonRecorderCreationFailed = "RecorderCreationFailed" + reasonRecorderCreating = "RecorderCreating" reasonRecorderCreated = "RecorderCreated" reasonRecorderInvalid = "RecorderInvalid" @@ -50,18 +54,19 @@ var gaugeRecorderResources = clientmetric.NewGauge(kubetypes.MetricRecorderCount // Recorder CRs. type RecorderReconciler struct { client.Client - l *zap.SugaredLogger + log *zap.SugaredLogger recorder record.EventRecorder clock tstime.Clock tsNamespace string tsClient tsClient + loginServer string mu sync.Mutex // protects following recorders set.Slice[types.UID] // for recorders gauge } func (r *RecorderReconciler) logger(name string) *zap.SugaredLogger { - return r.l.With("Recorder", name) + return r.log.With("Recorder", name) } func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { @@ -102,10 +107,10 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques oldTSRStatus := tsr.Status.DeepCopy() setStatusReady := func(tsr *tsapi.Recorder, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, status, reason, message, tsr.Generation, r.clock, logger) - if !apiequality.Semantic.DeepEqual(oldTSRStatus, tsr.Status) { + if !apiequality.Semantic.DeepEqual(oldTSRStatus, &tsr.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := r.Client.Status().Update(ctx, tsr); updateErr != nil { - err = errors.Wrap(err, updateErr.Error()) + err = errors.Join(err, updateErr) } } return reconcile.Result{}, err @@ -119,23 +124,28 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques logger.Infof("ensuring Recorder is set up") tsr.Finalizers = append(tsr.Finalizers, FinalizerName) if err := r.Update(ctx, tsr); err != nil { - logger.Errorf("error adding finalizer: %w", err) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, reasonRecorderCreationFailed) } } - if err := r.validate(tsr); err != nil { - logger.Errorf("error validating Recorder spec: %w", err) + if err := r.validate(ctx, tsr); err != nil { message := fmt.Sprintf("Recorder is invalid: %s", err) r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderInvalid, message) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderInvalid, message) } if err = r.maybeProvision(ctx, tsr); err != nil { - logger.Errorf("error creating Recorder resources: %w", err) + reason := reasonRecorderCreationFailed message := fmt.Sprintf("failed creating Recorder: %s", err) - r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderCreationFailed, message) - return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, message) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonRecorderCreating + message = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(message) + } else { + r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderCreationFailed, message) + } + return setStatusReady(tsr, metav1.ConditionFalse, reason, message) } logger.Info("Recorder resources synced") @@ -153,20 +163,26 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if err := r.ensureAuthSecretCreated(ctx, tsr); err != nil { return fmt.Errorf("error creating secrets: %w", err) } - // State secret is precreated so we can use the Recorder CR as its owner ref. + // State Secret is precreated so we can use the Recorder CR as its owner ref. sec := tsrStateSecret(tsr, r.tsNamespace) if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sec, func(s *corev1.Secret) { s.ObjectMeta.Labels = sec.ObjectMeta.Labels s.ObjectMeta.Annotations = sec.ObjectMeta.Annotations - s.ObjectMeta.OwnerReferences = sec.ObjectMeta.OwnerReferences }); err != nil { return fmt.Errorf("error creating state Secret: %w", err) } sa := tsrServiceAccount(tsr, r.tsNamespace) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) { + if _, err := createOrMaybeUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) error { + // Perform this check within the update function to make sure we don't + // have a race condition between the previous check and the update. + if err := saOwnedByRecorder(s, tsr); err != nil { + return err + } + s.ObjectMeta.Labels = sa.ObjectMeta.Labels s.ObjectMeta.Annotations = sa.ObjectMeta.Annotations - s.ObjectMeta.OwnerReferences = sa.ObjectMeta.OwnerReferences + + return nil }); err != nil { return fmt.Errorf("error creating ServiceAccount: %w", err) } @@ -174,7 +190,6 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { r.ObjectMeta.Labels = role.ObjectMeta.Labels r.ObjectMeta.Annotations = role.ObjectMeta.Annotations - r.ObjectMeta.OwnerReferences = role.ObjectMeta.OwnerReferences r.Rules = role.Rules }); err != nil { return fmt.Errorf("error creating Role: %w", err) @@ -183,22 +198,27 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, roleBinding, func(r *rbacv1.RoleBinding) { r.ObjectMeta.Labels = roleBinding.ObjectMeta.Labels r.ObjectMeta.Annotations = roleBinding.ObjectMeta.Annotations - r.ObjectMeta.OwnerReferences = roleBinding.ObjectMeta.OwnerReferences r.RoleRef = roleBinding.RoleRef r.Subjects = roleBinding.Subjects }); err != nil { return fmt.Errorf("error creating RoleBinding: %w", err) } - ss := tsrStatefulSet(tsr, r.tsNamespace) + ss := tsrStatefulSet(tsr, r.tsNamespace, r.loginServer) if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, ss, func(s *appsv1.StatefulSet) { s.ObjectMeta.Labels = ss.ObjectMeta.Labels s.ObjectMeta.Annotations = ss.ObjectMeta.Annotations - s.ObjectMeta.OwnerReferences = ss.ObjectMeta.OwnerReferences s.Spec = ss.Spec }); err != nil { return fmt.Errorf("error creating StatefulSet: %w", err) } + // ServiceAccount name may have changed, in which case we need to clean up + // the previous ServiceAccount. RoleBinding will already be updated to point + // to the new ServiceAccount. + if err := r.maybeCleanupServiceAccounts(ctx, tsr, sa.Name); err != nil { + return fmt.Errorf("error cleaning up ServiceAccounts: %w", err) + } + var devices []tsapi.RecorderTailnetDevice device, ok, err := r.getDeviceInfo(ctx, tsr.Name) @@ -217,13 +237,54 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco return nil } +func saOwnedByRecorder(sa *corev1.ServiceAccount, tsr *tsapi.Recorder) error { + // If ServiceAccount name has been configured, check that we don't clobber + // a pre-existing SA not owned by this Recorder. + if sa.Name != tsr.Name && !apiequality.Semantic.DeepEqual(sa.OwnerReferences, tsrOwnerReference(tsr)) { + return fmt.Errorf("custom ServiceAccount name %q specified but conflicts with a pre-existing ServiceAccount in the %s namespace", sa.Name, sa.Namespace) + } + + return nil +} + +// maybeCleanupServiceAccounts deletes any dangling ServiceAccounts +// owned by the Recorder if the ServiceAccount name has been changed. +// They would eventually be cleaned up by owner reference deletion, but +// this avoids a long-lived Recorder with many ServiceAccount name changes +// accumulating a large amount of garbage. +// +// This is a no-op if the ServiceAccount name has not changed. +func (r *RecorderReconciler) maybeCleanupServiceAccounts(ctx context.Context, tsr *tsapi.Recorder, currentName string) error { + logger := r.logger(tsr.Name) + + // List all ServiceAccounts owned by this Recorder. + sas := &corev1.ServiceAccountList{} + if err := r.List(ctx, sas, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels("recorder", tsr.Name, nil))); err != nil { + return fmt.Errorf("error listing ServiceAccounts for cleanup: %w", err) + } + for _, sa := range sas.Items { + if sa.Name == currentName { + continue + } + if err := r.Delete(ctx, &sa); err != nil { + if apierrors.IsNotFound(err) { + logger.Debugf("ServiceAccount %s not found, likely already deleted", sa.Name) + } else { + return fmt.Errorf("error deleting ServiceAccount %s: %w", sa.Name, err) + } + } + } + + return nil +} + // maybeCleanup just deletes the device from the tailnet. All the kubernetes // resources linked to a Recorder will get cleaned up via owner references // (which we can use because they are all in the same namespace). func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Recorder) (bool, error) { logger := r.logger(tsr.Name) - id, _, ok, err := r.getNodeMetadata(ctx, tsr.Name) + prefs, ok, err := r.getDevicePrefs(ctx, tsr.Name) if err != nil { return false, err } @@ -236,6 +297,7 @@ func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Record return true, nil } + id := string(prefs.Config.NodeID) logger.Debugf("deleting device %s from control", string(id)) if err := r.tsClient.DeleteDevice(ctx, string(id)); err != nil { errResp := &tailscale.ErrResponse{} @@ -294,17 +356,45 @@ func (r *RecorderReconciler) ensureAuthSecretCreated(ctx context.Context, tsr *t return nil } -func (r *RecorderReconciler) validate(tsr *tsapi.Recorder) error { +func (r *RecorderReconciler) validate(ctx context.Context, tsr *tsapi.Recorder) error { if !tsr.Spec.EnableUI && tsr.Spec.Storage.S3 == nil { return errors.New("must either enable UI or use S3 storage to ensure recordings are accessible") } + // Check any custom ServiceAccount config doesn't conflict with pre-existing + // ServiceAccounts. This check is performed once during validation to ensure + // errors are raised early, but also again during any Updates to prevent a race. + specSA := tsr.Spec.StatefulSet.Pod.ServiceAccount + if specSA.Name != "" && specSA.Name != tsr.Name { + sa := &corev1.ServiceAccount{} + key := client.ObjectKey{ + Name: specSA.Name, + Namespace: r.tsNamespace, + } + + err := r.Get(ctx, key, sa) + switch { + case apierrors.IsNotFound(err): + // ServiceAccount doesn't exist, so no conflict. + case err != nil: + return fmt.Errorf("error getting ServiceAccount %q for validation: %w", tsr.Spec.StatefulSet.Pod.ServiceAccount.Name, err) + default: + // ServiceAccount exists, check if it's owned by the Recorder. + if err := saOwnedByRecorder(sa, tsr); err != nil { + return err + } + } + } + if len(specSA.Annotations) > 0 { + if violations := apivalidation.ValidateAnnotations(specSA.Annotations, field.NewPath(".spec.statefulSet.pod.serviceAccount.annotations")); len(violations) > 0 { + return violations.ToAggregate() + } + } + return nil } -// getNodeMetadata returns 'ok == true' iff the node ID is found. The dnsName -// is expected to always be non-empty if the node ID is, but not required. -func (r *RecorderReconciler) getNodeMetadata(ctx context.Context, tsrName string) (id tailcfg.StableNodeID, dnsName string, ok bool, err error) { +func (r *RecorderReconciler) getStateSecret(ctx context.Context, tsrName string) (*corev1.Secret, error) { secret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ Namespace: r.tsNamespace, @@ -313,39 +403,59 @@ func (r *RecorderReconciler) getNodeMetadata(ctx context.Context, tsrName string } if err := r.Get(ctx, client.ObjectKeyFromObject(secret), secret); err != nil { if apierrors.IsNotFound(err) { - return "", "", false, nil + return nil, nil } - return "", "", false, err + return nil, fmt.Errorf("error getting state Secret: %w", err) } + return secret, nil +} + +func (r *RecorderReconciler) getDevicePrefs(ctx context.Context, tsrName string) (prefs prefs, ok bool, err error) { + secret, err := r.getStateSecret(ctx, tsrName) + if err != nil || secret == nil { + return prefs, false, err + } + + return getDevicePrefs(secret) +} + +// getDevicePrefs returns 'ok == true' iff the node ID is found. The dnsName +// is expected to always be non-empty if the node ID is, but not required. +func getDevicePrefs(secret *corev1.Secret) (prefs prefs, ok bool, err error) { // TODO(tomhjp): Should maybe use ipn to parse the following info instead. currentProfile, ok := secret.Data[currentProfileKey] if !ok { - return "", "", false, nil + return prefs, false, nil } profileBytes, ok := secret.Data[string(currentProfile)] if !ok { - return "", "", false, nil + return prefs, false, nil } - var profile profile - if err := json.Unmarshal(profileBytes, &profile); err != nil { - return "", "", false, fmt.Errorf("failed to extract node profile info from state Secret %s: %w", secret.Name, err) + if err := json.Unmarshal(profileBytes, &prefs); err != nil { + return prefs, false, fmt.Errorf("failed to extract node profile info from state Secret %s: %w", secret.Name, err) } - ok = profile.Config.NodeID != "" - return tailcfg.StableNodeID(profile.Config.NodeID), profile.Config.UserProfile.LoginName, ok, nil + ok = prefs.Config.NodeID != "" + return prefs, ok, nil } func (r *RecorderReconciler) getDeviceInfo(ctx context.Context, tsrName string) (d tsapi.RecorderTailnetDevice, ok bool, err error) { - nodeID, dnsName, ok, err := r.getNodeMetadata(ctx, tsrName) + secret, err := r.getStateSecret(ctx, tsrName) + if err != nil || secret == nil { + return tsapi.RecorderTailnetDevice{}, false, err + } + + prefs, ok, err := getDevicePrefs(secret) if !ok || err != nil { return tsapi.RecorderTailnetDevice{}, false, err } // TODO(tomhjp): The profile info doesn't include addresses, which is why we - // need the API. Should we instead update the profile to include addresses? - device, err := r.tsClient.Device(ctx, string(nodeID), nil) + // need the API. Should maybe update tsrecorder to write IPs to the state + // Secret like containerboot does. + device, err := r.tsClient.Device(ctx, string(prefs.Config.NodeID), nil) if err != nil { return tsapi.RecorderTailnetDevice{}, false, fmt.Errorf("failed to get device info from API: %w", err) } @@ -354,22 +464,27 @@ func (r *RecorderReconciler) getDeviceInfo(ctx context.Context, tsrName string) Hostname: device.Hostname, TailnetIPs: device.Addresses, } - if dnsName != "" { + if dnsName := prefs.Config.UserProfile.LoginName; dnsName != "" { d.URL = fmt.Sprintf("https://%s", dnsName) } return d, true, nil } -type profile struct { +// [prefs] is a subset of the ipn.Prefs struct used for extracting information +// from the state Secret of Tailscale devices. +type prefs struct { Config struct { - NodeID string `json:"NodeID"` + NodeID tailcfg.StableNodeID `json:"NodeID"` UserProfile struct { + // LoginName is the MagicDNS name of the device, e.g. foo.tail-scale.ts.net. LoginName string `json:"LoginName"` } `json:"UserProfile"` } `json:"Config"` + + AdvertiseServices []string `json:"AdvertiseServices"` } -func markedForDeletion(tsr *tsapi.Recorder) bool { - return !tsr.DeletionTimestamp.IsZero() +func markedForDeletion(obj metav1.Object) bool { + return !obj.GetDeletionTimestamp().IsZero() } diff --git a/cmd/k8s-operator/tsrecorder_specs.go b/cmd/k8s-operator/tsrecorder_specs.go index 4a74fb7e0..83d7439db 100644 --- a/cmd/k8s-operator/tsrecorder_specs.go +++ b/cmd/k8s-operator/tsrecorder_specs.go @@ -17,7 +17,7 @@ import ( "tailscale.com/version" ) -func tsrStatefulSet(tsr *tsapi.Recorder, namespace string) *appsv1.StatefulSet { +func tsrStatefulSet(tsr *tsapi.Recorder, namespace string, loginServer string) *appsv1.StatefulSet { return &appsv1.StatefulSet{ ObjectMeta: metav1.ObjectMeta{ Name: tsr.Name, @@ -39,7 +39,7 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string) *appsv1.StatefulSet { Annotations: tsr.Spec.StatefulSet.Pod.Annotations, }, Spec: corev1.PodSpec{ - ServiceAccountName: tsr.Name, + ServiceAccountName: tsrServiceAccountName(tsr), Affinity: tsr.Spec.StatefulSet.Pod.Affinity, SecurityContext: tsr.Spec.StatefulSet.Pod.SecurityContext, ImagePullSecrets: tsr.Spec.StatefulSet.Pod.ImagePullSecrets, @@ -59,7 +59,7 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string) *appsv1.StatefulSet { ImagePullPolicy: tsr.Spec.StatefulSet.Pod.Container.ImagePullPolicy, Resources: tsr.Spec.StatefulSet.Pod.Container.Resources, SecurityContext: tsr.Spec.StatefulSet.Pod.Container.SecurityContext, - Env: env(tsr), + Env: env(tsr, loginServer), EnvFrom: func() []corev1.EnvFromSource { if tsr.Spec.Storage.S3 == nil || tsr.Spec.Storage.S3.Credentials.Secret.Name == "" { return nil @@ -100,14 +100,25 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string) *appsv1.StatefulSet { func tsrServiceAccount(tsr *tsapi.Recorder, namespace string) *corev1.ServiceAccount { return &corev1.ServiceAccount{ ObjectMeta: metav1.ObjectMeta{ - Name: tsr.Name, + Name: tsrServiceAccountName(tsr), Namespace: namespace, Labels: labels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), + Annotations: tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations, }, } } +func tsrServiceAccountName(tsr *tsapi.Recorder) string { + sa := tsr.Spec.StatefulSet.Pod.ServiceAccount + name := tsr.Name + if sa.Name != "" { + name = sa.Name + } + + return name +} + func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { return &rbacv1.Role{ ObjectMeta: metav1.ObjectMeta{ @@ -130,6 +141,15 @@ func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { fmt.Sprintf("%s-0", tsr.Name), // Contains the node state. }, }, + { + APIGroups: []string{""}, + Resources: []string{"events"}, + Verbs: []string{ + "get", + "create", + "patch", + }, + }, }, } } @@ -145,7 +165,7 @@ func tsrRoleBinding(tsr *tsapi.Recorder, namespace string) *rbacv1.RoleBinding { Subjects: []rbacv1.Subject{ { Kind: "ServiceAccount", - Name: tsr.Name, + Name: tsrServiceAccountName(tsr), Namespace: namespace, }, }, @@ -181,7 +201,7 @@ func tsrStateSecret(tsr *tsapi.Recorder, namespace string) *corev1.Secret { } } -func env(tsr *tsapi.Recorder) []corev1.EnvVar { +func env(tsr *tsapi.Recorder, loginServer string) []corev1.EnvVar { envs := []corev1.EnvVar{ { Name: "TS_AUTHKEY", @@ -203,6 +223,14 @@ func env(tsr *tsapi.Recorder) []corev1.EnvVar { }, }, }, + { + Name: "POD_UID", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.uid", + }, + }, + }, { Name: "TS_STATE", Value: "kube:$(POD_NAME)", @@ -211,6 +239,10 @@ func env(tsr *tsapi.Recorder) []corev1.EnvVar { Name: "TSRECORDER_HOSTNAME", Value: "$(POD_NAME)", }, + { + Name: "TSRECORDER_LOGIN_SERVER", + Value: loginServer, + }, } for _, env := range tsr.Spec.StatefulSet.Pod.Container.Env { @@ -249,17 +281,17 @@ func env(tsr *tsapi.Recorder) []corev1.EnvVar { } func labels(app, instance string, customLabels map[string]string) map[string]string { - l := make(map[string]string, len(customLabels)+3) + labels := make(map[string]string, len(customLabels)+3) for k, v := range customLabels { - l[k] = v + labels[k] = v } // ref: https://kubernetes.io/docs/concepts/overview/working-with-objects/common-labels/ - l["app.kubernetes.io/name"] = app - l["app.kubernetes.io/instance"] = instance - l["app.kubernetes.io/managed-by"] = "tailscale-operator" + labels["app.kubernetes.io/name"] = app + labels["app.kubernetes.io/instance"] = instance + labels["app.kubernetes.io/managed-by"] = "tailscale-operator" - return l + return labels } func tsrOwnerReference(owner metav1.Object) []metav1.OwnerReference { diff --git a/cmd/k8s-operator/tsrecorder_specs_test.go b/cmd/k8s-operator/tsrecorder_specs_test.go index 94a8a816c..49332d09b 100644 --- a/cmd/k8s-operator/tsrecorder_specs_test.go +++ b/cmd/k8s-operator/tsrecorder_specs_test.go @@ -90,7 +90,7 @@ func TestRecorderSpecs(t *testing.T) { }, } - ss := tsrStatefulSet(tsr, tsNamespace) + ss := tsrStatefulSet(tsr, tsNamespace, tsLoginServer) // StatefulSet-level. if diff := cmp.Diff(ss.Annotations, tsr.Spec.StatefulSet.Annotations); diff != "" { @@ -124,7 +124,7 @@ func TestRecorderSpecs(t *testing.T) { } // Container-level. - if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Env, env(tsr)); diff != "" { + if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Env, env(tsr, tsLoginServer)); diff != "" { t.Errorf("(-got +want):\n%s", diff) } if diff := cmp.Diff(ss.Spec.Template.Spec.Containers[0].Image, tsr.Spec.StatefulSet.Pod.Container.Image); diff != "" { diff --git a/cmd/k8s-operator/tsrecorder_test.go b/cmd/k8s-operator/tsrecorder_test.go index a3500f191..184af2344 100644 --- a/cmd/k8s-operator/tsrecorder_test.go +++ b/cmd/k8s-operator/tsrecorder_test.go @@ -8,6 +8,7 @@ package main import ( "context" "encoding/json" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -24,7 +25,10 @@ import ( "tailscale.com/tstest" ) -const tsNamespace = "tailscale" +const ( + tsNamespace = "tailscale" + tsLoginServer = "example.tailscale.com" +) func TestRecorder(t *testing.T) { tsr := &tsapi.Recorder{ @@ -41,23 +45,24 @@ func TestRecorder(t *testing.T) { Build() tsClient := &fakeTSClient{} zl, _ := zap.NewDevelopment() - fr := record.NewFakeRecorder(1) + fr := record.NewFakeRecorder(2) cl := tstest.NewClock(tstest.ClockOpts{}) reconciler := &RecorderReconciler{ tsNamespace: tsNamespace, Client: fc, tsClient: tsClient, recorder: fr, - l: zl.Sugar(), + log: zl.Sugar(), clock: cl, + loginServer: tsLoginServer, } - t.Run("invalid spec gives an error condition", func(t *testing.T) { + t.Run("invalid_spec_gives_an_error_condition", func(t *testing.T) { expectReconciled(t, reconciler, "", tsr.Name) msg := "Recorder is invalid: must either enable UI or use S3 storage to ensure recordings are accessible" tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, metav1.ConditionFalse, reasonRecorderInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, tsr, nil) + expectEqual(t, fc, tsr) if expected := 0; reconciler.recorders.Len() != expected { t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) } @@ -65,10 +70,66 @@ func TestRecorder(t *testing.T) { expectedEvent := "Warning RecorderInvalid Recorder is invalid: must either enable UI or use S3 storage to ensure recordings are accessible" expectEvents(t, fr, []string{expectedEvent}) - }) - t.Run("observe Ready=true status condition for a valid spec", func(t *testing.T) { tsr.Spec.EnableUI = true + tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations = map[string]string{ + "invalid space characters": "test", + } + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + expectReconciled(t, reconciler, "", tsr.Name) + + // Only check part of this error message, because it's defined in an + // external package and may change. + if err := fc.Get(context.Background(), client.ObjectKey{ + Name: tsr.Name, + }, tsr); err != nil { + t.Fatal(err) + } + if len(tsr.Status.Conditions) != 1 { + t.Fatalf("expected 1 condition, got %d", len(tsr.Status.Conditions)) + } + cond := tsr.Status.Conditions[0] + if cond.Type != string(tsapi.RecorderReady) || cond.Status != metav1.ConditionFalse || cond.Reason != reasonRecorderInvalid { + t.Fatalf("expected condition RecorderReady false due to RecorderInvalid, got %v", cond) + } + for _, msg := range []string{cond.Message, <-fr.Events} { + if !strings.Contains(msg, `"invalid space characters"`) { + t.Fatalf("expected invalid annotation key in error message, got %q", cond.Message) + } + } + }) + + t.Run("conflicting_service_account_config_marked_as_invalid", func(t *testing.T) { + mustCreate(t, fc, &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pre-existing-sa", + Namespace: tsNamespace, + }, + }) + + tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations = nil + tsr.Spec.StatefulSet.Pod.ServiceAccount.Name = "pre-existing-sa" + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + + expectReconciled(t, reconciler, "", tsr.Name) + + msg := `Recorder is invalid: custom ServiceAccount name "pre-existing-sa" specified but conflicts with a pre-existing ServiceAccount in the tailscale namespace` + tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, metav1.ConditionFalse, reasonRecorderInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, tsr) + if expected := 0; reconciler.recorders.Len() != expected { + t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) + } + + expectedEvent := "Warning RecorderInvalid " + msg + expectEvents(t, fr, []string{expectedEvent}) + }) + + t.Run("observe_Ready_true_status_condition_for_a_valid_spec", func(t *testing.T) { + tsr.Spec.StatefulSet.Pod.ServiceAccount.Name = "" mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { t.Spec = tsr.Spec }) @@ -76,14 +137,49 @@ func TestRecorder(t *testing.T) { expectReconciled(t, reconciler, "", tsr.Name) tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, metav1.ConditionTrue, reasonRecorderCreated, reasonRecorderCreated, 0, cl, zl.Sugar()) - expectEqual(t, fc, tsr, nil) + expectEqual(t, fc, tsr) + if expected := 1; reconciler.recorders.Len() != expected { + t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) + } + expectRecorderResources(t, fc, tsr, true) + }) + + t.Run("valid_service_account_config", func(t *testing.T) { + tsr.Spec.StatefulSet.Pod.ServiceAccount.Name = "test-sa" + tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations = map[string]string{ + "test": "test", + } + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + + expectReconciled(t, reconciler, "", tsr.Name) + + expectEqual(t, fc, tsr) if expected := 1; reconciler.recorders.Len() != expected { t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) } expectRecorderResources(t, fc, tsr, true) + + // Get the service account and check the annotations. + sa := &corev1.ServiceAccount{} + if err := fc.Get(context.Background(), client.ObjectKey{ + Name: tsr.Spec.StatefulSet.Pod.ServiceAccount.Name, + Namespace: tsNamespace, + }, sa); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(sa.Annotations, tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations); diff != "" { + t.Fatalf("unexpected service account annotations (-got +want):\n%s", diff) + } + if sa.Name != tsr.Spec.StatefulSet.Pod.ServiceAccount.Name { + t.Fatalf("unexpected service account name: got %q, want %q", sa.Name, tsr.Spec.StatefulSet.Pod.ServiceAccount.Name) + } + + expectMissing[corev1.ServiceAccount](t, fc, tsNamespace, tsr.Name) }) - t.Run("populate node info in state secret, and see it appear in status", func(t *testing.T) { + t.Run("populate_node_info_in_state_secret_and_see_it_appear_in_status", func(t *testing.T) { bytes, err := json.Marshal(map[string]any{ "Config": map[string]any{ "NodeID": "nodeid-123", @@ -107,15 +203,15 @@ func TestRecorder(t *testing.T) { expectReconciled(t, reconciler, "", tsr.Name) tsr.Status.Devices = []tsapi.RecorderTailnetDevice{ { - Hostname: "test-device", + Hostname: "hostname-nodeid-123", TailnetIPs: []string{"1.2.3.4", "::1"}, URL: "https://test-0.example.ts.net", }, } - expectEqual(t, fc, tsr, nil) + expectEqual(t, fc, tsr) }) - t.Run("delete the Recorder and observe cleanup", func(t *testing.T) { + t.Run("delete_the_Recorder_and_observe_cleanup", func(t *testing.T) { if err := fc.Delete(context.Background(), tsr); err != nil { t.Fatal(err) } @@ -142,15 +238,15 @@ func expectRecorderResources(t *testing.T, fc client.WithWatch, tsr *tsapi.Recor role := tsrRole(tsr, tsNamespace) roleBinding := tsrRoleBinding(tsr, tsNamespace) serviceAccount := tsrServiceAccount(tsr, tsNamespace) - statefulSet := tsrStatefulSet(tsr, tsNamespace) + statefulSet := tsrStatefulSet(tsr, tsNamespace, tsLoginServer) if shouldExist { - expectEqual(t, fc, auth, nil) - expectEqual(t, fc, state, nil) - expectEqual(t, fc, role, nil) - expectEqual(t, fc, roleBinding, nil) - expectEqual(t, fc, serviceAccount, nil) - expectEqual(t, fc, statefulSet, nil) + expectEqual(t, fc, auth) + expectEqual(t, fc, state) + expectEqual(t, fc, role) + expectEqual(t, fc, roleBinding) + expectEqual(t, fc, serviceAccount) + expectEqual(t, fc, statefulSet, removeResourceReqs) } else { expectMissing[corev1.Secret](t, fc, auth.Namespace, auth.Name) expectMissing[corev1.Secret](t, fc, state.Namespace, state.Name) diff --git a/cmd/k8s-proxy/internal/config/config.go b/cmd/k8s-proxy/internal/config/config.go new file mode 100644 index 000000000..0f0bd1bfc --- /dev/null +++ b/cmd/k8s-proxy/internal/config/config.go @@ -0,0 +1,264 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package config provides watchers for the various supported ways to load a +// config file for k8s-proxy; currently file or Kubernetes Secret. +package config + +import ( + "bytes" + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/fsnotify/fsnotify" + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/watch" + clientcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + "tailscale.com/types/ptr" + "tailscale.com/util/testenv" +) + +type configLoader struct { + logger *zap.SugaredLogger + client clientcorev1.CoreV1Interface + + cfgChan chan<- *conf.Config + previous []byte + + once sync.Once // For use in tests. To close cfgIgnored. + cfgIgnored chan struct{} // For use in tests. +} + +func NewConfigLoader(logger *zap.SugaredLogger, client clientcorev1.CoreV1Interface, cfgChan chan<- *conf.Config) *configLoader { + return &configLoader{ + logger: logger, + client: client, + cfgChan: cfgChan, + } +} + +func (ld *configLoader) WatchConfig(ctx context.Context, path string) error { + secretNamespacedName, isKubeSecret := strings.CutPrefix(path, "kube:") + if isKubeSecret { + secretNamespace, secretName, ok := strings.Cut(secretNamespacedName, string(types.Separator)) + if !ok { + return fmt.Errorf("invalid Kubernetes Secret reference %q, expected format /", path) + } + if err := ld.watchConfigSecretChanges(ctx, secretNamespace, secretName); err != nil && !errors.Is(err, context.Canceled) { + return fmt.Errorf("error watching config Secret %q: %w", secretNamespacedName, err) + } + + return nil + } + + if err := ld.watchConfigFileChanges(ctx, path); err != nil && !errors.Is(err, context.Canceled) { + return fmt.Errorf("error watching config file %q: %w", path, err) + } + + return nil +} + +func (ld *configLoader) reloadConfig(ctx context.Context, raw []byte) error { + if bytes.Equal(raw, ld.previous) { + if ld.cfgIgnored != nil && testenv.InTest() { + ld.once.Do(func() { + close(ld.cfgIgnored) + }) + } + return nil + } + + cfg, err := conf.Load(raw) + if err != nil { + return fmt.Errorf("error loading config: %w", err) + } + + select { + case <-ctx.Done(): + return ctx.Err() + case ld.cfgChan <- &cfg: + } + + ld.previous = raw + return nil +} + +func (ld *configLoader) watchConfigFileChanges(ctx context.Context, path string) error { + var ( + tickChan <-chan time.Time + eventChan <-chan fsnotify.Event + errChan <-chan error + ) + + if w, err := fsnotify.NewWatcher(); err != nil { + // Creating a new fsnotify watcher would fail for example if inotify was not able to create a new file descriptor. + // See https://github.com/tailscale/tailscale/issues/15081 + ld.logger.Infof("Failed to create fsnotify watcher on config file %q; watching for changes on 5s timer: %v", path, err) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + tickChan = ticker.C + } else { + dir := filepath.Dir(path) + file := filepath.Base(path) + ld.logger.Infof("Watching directory %q for changes to config file %q", dir, file) + defer w.Close() + if err := w.Add(dir); err != nil { + return fmt.Errorf("failed to add fsnotify watch: %w", err) + } + eventChan = w.Events + errChan = w.Errors + } + + // Read the initial config file, but after the watcher is already set up to + // avoid an unlucky race condition if the config file is edited in between. + b, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("error reading config file %q: %w", path, err) + } + if err := ld.reloadConfig(ctx, b); err != nil { + return fmt.Errorf("error loading initial config file %q: %w", path, err) + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err, ok := <-errChan: + if !ok { + // Watcher was closed. + return nil + } + return fmt.Errorf("watcher error: %w", err) + case <-tickChan: + case ev, ok := <-eventChan: + if !ok { + // Watcher was closed. + return nil + } + if ev.Name != path || ev.Op&fsnotify.Write == 0 { + // Ignore irrelevant events. + continue + } + } + b, err := os.ReadFile(path) + if err != nil { + return fmt.Errorf("error reading config file: %w", err) + } + // Writers such as os.WriteFile may truncate the file before writing + // new contents, so it's possible to read an empty file if we read before + // the write has completed. + if len(b) == 0 { + continue + } + if err := ld.reloadConfig(ctx, b); err != nil { + return fmt.Errorf("error reloading config file %q: %v", path, err) + } + } +} + +func (ld *configLoader) watchConfigSecretChanges(ctx context.Context, secretNamespace, secretName string) error { + secrets := ld.client.Secrets(secretNamespace) + w, err := secrets.Watch(ctx, metav1.ListOptions{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + // Re-watch regularly to avoid relying on long-lived connections. + // See https://github.com/kubernetes-client/javascript/issues/596#issuecomment-786419380 + TimeoutSeconds: ptr.To(int64(600)), + FieldSelector: fmt.Sprintf("metadata.name=%s", secretName), + Watch: true, + }) + if err != nil { + return fmt.Errorf("failed to watch config Secret %q: %w", secretName, err) + } + defer func() { + // May not be the original watcher by the time we exit. + if w != nil { + w.Stop() + } + }() + + // Get the initial config Secret now we've got the watcher set up. + secret, err := secrets.Get(ctx, secretName, metav1.GetOptions{}) + if err != nil { + return fmt.Errorf("failed to get config Secret %q: %w", secretName, err) + } + + if err := ld.configFromSecret(ctx, secret); err != nil { + return fmt.Errorf("error loading initial config: %w", err) + } + + ld.logger.Infof("Watching config Secret %q for changes", secretName) + for { + var secret *corev1.Secret + select { + case <-ctx.Done(): + return ctx.Err() + case ev, ok := <-w.ResultChan(): + if !ok { + w.Stop() + w, err = secrets.Watch(ctx, metav1.ListOptions{ + TypeMeta: metav1.TypeMeta{ + Kind: "Secret", + APIVersion: "v1", + }, + TimeoutSeconds: ptr.To(int64(600)), + FieldSelector: fmt.Sprintf("metadata.name=%s", secretName), + Watch: true, + }) + if err != nil { + return fmt.Errorf("failed to re-watch config Secret %q: %w", secretName, err) + } + continue + } + + switch ev.Type { + case watch.Added, watch.Modified: + // New config available to load. + var ok bool + secret, ok = ev.Object.(*corev1.Secret) + if !ok { + return fmt.Errorf("unexpected object type %T in watch event for config Secret %q", ev.Object, secretName) + } + if secret == nil || secret.Data == nil { + continue + } + if err := ld.configFromSecret(ctx, secret); err != nil { + return fmt.Errorf("error reloading config Secret %q: %v", secret.Name, err) + } + case watch.Error: + return fmt.Errorf("error watching config Secret %q: %v", secretName, ev.Object) + default: + // Ignore, no action required. + continue + } + } + } +} + +func (ld *configLoader) configFromSecret(ctx context.Context, s *corev1.Secret) error { + b := s.Data[kubetypes.KubeAPIServerConfigFile] + if len(b) == 0 { + return fmt.Errorf("config Secret %q does not contain expected config in key %q", s.Name, kubetypes.KubeAPIServerConfigFile) + } + + if err := ld.reloadConfig(ctx, b); err != nil { + return err + } + + return nil +} diff --git a/cmd/k8s-proxy/internal/config/config_test.go b/cmd/k8s-proxy/internal/config/config_test.go new file mode 100644 index 000000000..bcb1b9ebd --- /dev/null +++ b/cmd/k8s-proxy/internal/config/config_test.go @@ -0,0 +1,245 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package config + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/watch" + "k8s.io/client-go/kubernetes/fake" + ktesting "k8s.io/client-go/testing" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + "tailscale.com/types/ptr" +) + +func TestWatchConfig(t *testing.T) { + type phase struct { + config string + cancel bool + expectedConf *conf.ConfigV1Alpha1 + expectedErr string + } + + // Same set of behaviour tests for each config source. + for _, env := range []string{"file", "kube"} { + t.Run(env, func(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + initialConfig string + phases []phase + }{ + { + name: "no_config", + phases: []phase{{ + expectedErr: "error loading initial config", + }}, + }, + { + name: "valid_config", + initialConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, + phases: []phase{{ + expectedConf: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("abc123"), + }, + }}, + }, + { + name: "can_cancel", + initialConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, + phases: []phase{ + { + expectedConf: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("abc123"), + }, + }, + { + cancel: true, + }, + }, + }, + { + name: "can_reload", + initialConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, + phases: []phase{ + { + expectedConf: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("abc123"), + }, + }, + { + config: `{"version": "v1alpha1", "authKey": "def456"}`, + expectedConf: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("def456"), + }, + }, + }, + }, + { + name: "ignores_events_with_no_changes", + initialConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, + phases: []phase{ + { + expectedConf: &conf.ConfigV1Alpha1{ + AuthKey: ptr.To("abc123"), + }, + }, + { + config: `{"version": "v1alpha1", "authKey": "abc123"}`, + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + root := t.TempDir() + cl := fake.NewClientset() + + var cfgPath string + var writeFile func(*testing.T, string) + if env == "file" { + cfgPath = filepath.Join(root, kubetypes.KubeAPIServerConfigFile) + writeFile = func(t *testing.T, content string) { + if err := os.WriteFile(cfgPath, []byte(content), 0o644); err != nil { + t.Fatalf("error writing config file %q: %v", cfgPath, err) + } + } + } else { + cfgPath = "kube:default/config-secret" + writeFile = func(t *testing.T, content string) { + s := secretFrom(content) + mustCreateOrUpdate(t, cl, s) + } + } + configChan := make(chan *conf.Config) + loader := NewConfigLoader(zap.Must(zap.NewDevelopment()).Sugar(), cl.CoreV1(), configChan) + loader.cfgIgnored = make(chan struct{}) + errs := make(chan error) + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + writeFile(t, tc.initialConfig) + go func() { + errs <- loader.WatchConfig(ctx, cfgPath) + }() + + for i, p := range tc.phases { + if p.config != "" { + writeFile(t, p.config) + } + if p.cancel { + cancel() + } + + select { + case cfg := <-configChan: + if diff := cmp.Diff(*p.expectedConf, cfg.Parsed); diff != "" { + t.Errorf("unexpected config (-want +got):\n%s", diff) + } + case err := <-errs: + if p.cancel { + if err != nil { + t.Fatalf("unexpected error after cancel: %v", err) + } + } else if p.expectedErr == "" { + t.Fatalf("unexpected error: %v", err) + } else if !strings.Contains(err.Error(), p.expectedErr) { + t.Fatalf("expected error to contain %q, got %q", p.expectedErr, err.Error()) + } + case <-loader.cfgIgnored: + if p.expectedConf != nil { + t.Fatalf("expected config to be reloaded, but got ignored signal") + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for expected event in phase: %d", i) + } + } + }) + } + }) + } +} + +func TestWatchConfigSecret_Rewatches(t *testing.T) { + cl := fake.NewClientset() + var watchCount int + var watcher *watch.RaceFreeFakeWatcher + expected := []string{ + `{"version": "v1alpha1", "authKey": "abc123"}`, + `{"version": "v1alpha1", "authKey": "def456"}`, + `{"version": "v1alpha1", "authKey": "ghi789"}`, + } + cl.PrependWatchReactor("secrets", func(action ktesting.Action) (handled bool, ret watch.Interface, err error) { + watcher = watch.NewRaceFreeFake() + watcher.Add(secretFrom(expected[watchCount])) + if action.GetVerb() == "watch" && action.GetResource().Resource == "secrets" { + watchCount++ + } + return true, watcher, nil + }) + + configChan := make(chan *conf.Config) + loader := NewConfigLoader(zap.Must(zap.NewDevelopment()).Sugar(), cl.CoreV1(), configChan) + + mustCreateOrUpdate(t, cl, secretFrom(expected[0])) + + errs := make(chan error) + go func() { + errs <- loader.watchConfigSecretChanges(t.Context(), "default", "config-secret") + }() + + for i := range 2 { + select { + case cfg := <-configChan: + if exp := expected[i]; cfg.Parsed.AuthKey == nil || !strings.Contains(exp, *cfg.Parsed.AuthKey) { + t.Fatalf("expected config to have authKey %q, got: %v", exp, cfg.Parsed.AuthKey) + } + if i == 0 { + watcher.Stop() + } + case err := <-errs: + t.Fatalf("unexpected error: %v", err) + case <-loader.cfgIgnored: + t.Fatalf("expected config to be reloaded, but got ignored signal") + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for expected event") + } + } + + if watchCount != 2 { + t.Fatalf("expected 2 watch API calls, got %d", watchCount) + } +} + +func secretFrom(content string) *corev1.Secret { + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "config-secret", + }, + Data: map[string][]byte{ + kubetypes.KubeAPIServerConfigFile: []byte(content), + }, + } +} + +func mustCreateOrUpdate(t *testing.T, cl *fake.Clientset, s *corev1.Secret) { + t.Helper() + if _, err := cl.CoreV1().Secrets("default").Create(t.Context(), s, metav1.CreateOptions{}); err != nil { + if _, updateErr := cl.CoreV1().Secrets("default").Update(t.Context(), s, metav1.UpdateOptions{}); updateErr != nil { + t.Fatalf("error writing config Secret %q: %v", s.Name, updateErr) + } + } +} diff --git a/cmd/k8s-proxy/k8s-proxy.go b/cmd/k8s-proxy/k8s-proxy.go new file mode 100644 index 000000000..9b2bb6749 --- /dev/null +++ b/cmd/k8s-proxy/k8s-proxy.go @@ -0,0 +1,477 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// k8s-proxy proxies between tailnet and Kubernetes cluster traffic. +// Currently, it only supports proxying tailnet clients to the Kubernetes API +// server. +package main + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "os" + "os/signal" + "reflect" + "strconv" + "strings" + "syscall" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/sync/errgroup" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/utils/strings/slices" + "tailscale.com/client/local" + "tailscale.com/cmd/k8s-proxy/internal/config" + "tailscale.com/hostinfo" + "tailscale.com/ipn" + "tailscale.com/ipn/store" + + // we need to import this package so that the `kube:` ipn store gets registered + _ "tailscale.com/ipn/store/kubestore" + apiproxy "tailscale.com/k8s-operator/api-proxy" + "tailscale.com/kube/certs" + healthz "tailscale.com/kube/health" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubetypes" + klc "tailscale.com/kube/localclient" + "tailscale.com/kube/metrics" + "tailscale.com/kube/services" + "tailscale.com/kube/state" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" +) + +func main() { + encoderCfg := zap.NewProductionEncoderConfig() + encoderCfg.EncodeTime = zapcore.RFC3339TimeEncoder + logger := zap.Must(zap.Config{ + Level: zap.NewAtomicLevelAt(zap.DebugLevel), + Encoding: "json", + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + EncoderConfig: encoderCfg, + }.Build()).Sugar() + defer logger.Sync() + + if err := run(logger); err != nil { + logger.Fatal(err.Error()) + } +} + +func run(logger *zap.SugaredLogger) error { + var ( + configPath = os.Getenv("TS_K8S_PROXY_CONFIG") + podUID = os.Getenv("POD_UID") + podIP = os.Getenv("POD_IP") + ) + if configPath == "" { + return errors.New("TS_K8S_PROXY_CONFIG unset") + } + + // serveCtx to live for the lifetime of the process, only gets cancelled + // once the Tailscale Service has been drained + serveCtx, serveCancel := context.WithCancel(context.Background()) + defer serveCancel() + + // ctx to cancel to start the shutdown process. + ctx, cancel := context.WithCancel(serveCtx) + defer cancel() + + sigsChan := make(chan os.Signal, 1) + signal.Notify(sigsChan, syscall.SIGINT, syscall.SIGTERM) + go func() { + select { + case <-ctx.Done(): + case s := <-sigsChan: + logger.Infof("Received shutdown signal %s, exiting", s) + cancel() + } + }() + + var group *errgroup.Group + group, ctx = errgroup.WithContext(ctx) + + restConfig, err := getRestConfig(logger) + if err != nil { + return fmt.Errorf("error getting rest config: %w", err) + } + clientset, err := kubernetes.NewForConfig(restConfig) + if err != nil { + return fmt.Errorf("error creating Kubernetes clientset: %w", err) + } + + // Load and watch config. + cfgChan := make(chan *conf.Config) + cfgLoader := config.NewConfigLoader(logger, clientset.CoreV1(), cfgChan) + group.Go(func() error { + return cfgLoader.WatchConfig(ctx, configPath) + }) + + // Get initial config. + var cfg *conf.Config + select { + case <-ctx.Done(): + return group.Wait() + case cfg = <-cfgChan: + } + + if cfg.Parsed.LogLevel != nil { + level, err := zapcore.ParseLevel(*cfg.Parsed.LogLevel) + if err != nil { + return fmt.Errorf("error parsing log level %q: %w", *cfg.Parsed.LogLevel, err) + } + logger = logger.WithOptions(zap.IncreaseLevel(level)) + } + + // TODO:(ChaosInTheCRD) This is a temporary workaround until we can set static endpoints using prefs + if se := cfg.Parsed.StaticEndpoints; len(se) > 0 { + logger.Debugf("setting static endpoints '%v' via TS_DEBUG_PRETENDPOINT environment variable", cfg.Parsed.StaticEndpoints) + ses := make([]string, len(se)) + for i, e := range se { + ses[i] = e.String() + } + + err := os.Setenv("TS_DEBUG_PRETENDPOINT", strings.Join(ses, ",")) + if err != nil { + return err + } + } + + if cfg.Parsed.App != nil { + hostinfo.SetApp(*cfg.Parsed.App) + } + + // TODO(tomhjp): Pass this setting directly into the store instead of using + // environment variables. + if cfg.Parsed.APIServerProxy != nil && cfg.Parsed.APIServerProxy.IssueCerts.EqualBool(true) { + os.Setenv("TS_CERT_SHARE_MODE", "rw") + } else { + os.Setenv("TS_CERT_SHARE_MODE", "ro") + } + + st, err := getStateStore(cfg.Parsed.State, logger) + if err != nil { + return err + } + + // If Pod UID unset, assume we're running outside of a cluster/not managed + // by the operator, so no need to set additional state keys. + if podUID != "" { + if err := state.SetInitialKeys(st, podUID); err != nil { + return fmt.Errorf("error setting initial state: %w", err) + } + } + + var authKey string + if cfg.Parsed.AuthKey != nil { + authKey = *cfg.Parsed.AuthKey + } + + ts := &tsnet.Server{ + Logf: logger.Named("tsnet").Debugf, + UserLogf: logger.Named("tsnet").Infof, + Store: st, + AuthKey: authKey, + } + + if cfg.Parsed.ServerURL != nil { + ts.ControlURL = *cfg.Parsed.ServerURL + } + + if cfg.Parsed.Hostname != nil { + ts.Hostname = *cfg.Parsed.Hostname + } + + // Make sure we crash loop if Up doesn't complete in reasonable time. + upCtx, upCancel := context.WithTimeout(ctx, time.Minute) + defer upCancel() + if _, err := ts.Up(upCtx); err != nil { + return fmt.Errorf("error starting tailscale server: %w", err) + } + defer ts.Close() + lc, err := ts.LocalClient() + if err != nil { + return fmt.Errorf("error getting local client: %w", err) + } + + // Setup for updating state keys. + if podUID != "" { + group.Go(func() error { + return state.KeepKeysUpdated(ctx, st, klc.New(lc)) + }) + } + + if cfg.Parsed.HealthCheckEnabled.EqualBool(true) || cfg.Parsed.MetricsEnabled.EqualBool(true) { + addr := podIP + if addr == "" { + addr = cfg.GetLocalAddr() + } + + addrPort := getLocalAddrPort(addr, cfg.GetLocalPort()) + mux := http.NewServeMux() + localSrv := &http.Server{Addr: addrPort, Handler: mux} + + if cfg.Parsed.MetricsEnabled.EqualBool(true) { + logger.Infof("Running metrics endpoint at %s/metrics", addrPort) + metrics.RegisterMetricsHandlers(mux, lc, "") + } + + if cfg.Parsed.HealthCheckEnabled.EqualBool(true) { + ipV4, _ := ts.TailscaleIPs() + hz := healthz.RegisterHealthHandlers(mux, ipV4.String(), logger.Infof) + group.Go(func() error { + err := hz.MonitorHealth(ctx, lc) + if err == nil || errors.Is(err, context.Canceled) { + return nil + } + return err + }) + } + + group.Go(func() error { + errChan := make(chan error) + go func() { + if err := localSrv.ListenAndServe(); err != nil { + errChan <- err + } + close(errChan) + }() + + select { + case <-ctx.Done(): + sCtx, scancel := context.WithTimeout(serveCtx, 10*time.Second) + defer scancel() + return localSrv.Shutdown(sCtx) + case err := <-errChan: + return err + } + }) + } + + if v, ok := cfg.Parsed.AcceptRoutes.Get(); ok { + _, err = lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + RouteAllSet: true, + Prefs: ipn.Prefs{RouteAll: v}, + }) + if err != nil { + return fmt.Errorf("error editing prefs: %w", err) + } + } + + // TODO(tomhjp): There seems to be a bug that on restart the device does + // not get reassigned it's already working Service IPs unless we clear and + // reset the serve config. + if err := lc.SetServeConfig(ctx, &ipn.ServeConfig{}); err != nil { + return fmt.Errorf("error clearing existing ServeConfig: %w", err) + } + + var cm *certs.CertManager + if shouldIssueCerts(cfg) { + logger.Infof("Will issue TLS certs for Tailscale Service") + cm = certs.NewCertManager(klc.New(lc), logger.Infof) + } + if err := setServeConfig(ctx, lc, cm, apiServerProxyService(cfg)); err != nil { + return err + } + + if cfg.Parsed.AdvertiseServices != nil { + if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: cfg.Parsed.AdvertiseServices, + }, + }); err != nil { + return fmt.Errorf("error setting prefs AdvertiseServices: %w", err) + } + } + + // Setup for the API server proxy. + mode := kubetypes.APIServerProxyModeAuth + if cfg.Parsed.APIServerProxy != nil && cfg.Parsed.APIServerProxy.Mode != nil { + mode = *cfg.Parsed.APIServerProxy.Mode + } + ap, err := apiproxy.NewAPIServerProxy(logger.Named("apiserver-proxy"), restConfig, ts, mode, false) + if err != nil { + return fmt.Errorf("error creating api server proxy: %w", err) + } + + group.Go(func() error { + if err := ap.Run(serveCtx); err != nil { + return fmt.Errorf("error running API server proxy: %w", err) + } + + return nil + }) + + for { + select { + case <-ctx.Done(): + // Context cancelled, exit. + logger.Info("Context cancelled, exiting") + shutdownCtx, shutdownCancel := context.WithTimeout(serveCtx, 20*time.Second) + unadvertiseErr := services.EnsureServicesNotAdvertised(shutdownCtx, lc, logger.Infof) + shutdownCancel() + serveCancel() + return errors.Join(unadvertiseErr, group.Wait()) + case cfg = <-cfgChan: + // Handle config reload. + // TODO(tomhjp): Make auth mode reloadable. + var prefs ipn.MaskedPrefs + cfgLogger := logger + currentPrefs, err := lc.GetPrefs(ctx) + if err != nil { + return fmt.Errorf("error getting current prefs: %w", err) + } + if !slices.Equal(currentPrefs.AdvertiseServices, cfg.Parsed.AdvertiseServices) { + cfgLogger = cfgLogger.With("AdvertiseServices", fmt.Sprintf("%v -> %v", currentPrefs.AdvertiseServices, cfg.Parsed.AdvertiseServices)) + prefs.AdvertiseServicesSet = true + prefs.Prefs.AdvertiseServices = cfg.Parsed.AdvertiseServices + } + if cfg.Parsed.Hostname != nil && *cfg.Parsed.Hostname != currentPrefs.Hostname { + cfgLogger = cfgLogger.With("Hostname", fmt.Sprintf("%s -> %s", currentPrefs.Hostname, *cfg.Parsed.Hostname)) + prefs.HostnameSet = true + prefs.Hostname = *cfg.Parsed.Hostname + } + if v, ok := cfg.Parsed.AcceptRoutes.Get(); ok && v != currentPrefs.RouteAll { + cfgLogger = cfgLogger.With("AcceptRoutes", fmt.Sprintf("%v -> %v", currentPrefs.RouteAll, v)) + prefs.RouteAllSet = true + prefs.Prefs.RouteAll = v + } + if !prefs.IsEmpty() { + if _, err := lc.EditPrefs(ctx, &prefs); err != nil { + return fmt.Errorf("error editing prefs: %w", err) + } + } + if err := setServeConfig(ctx, lc, cm, apiServerProxyService(cfg)); err != nil { + return fmt.Errorf("error setting serve config: %w", err) + } + + cfgLogger.Infof("Config reloaded") + } + } +} + +func getLocalAddrPort(addr string, port uint16) string { + return net.JoinHostPort(addr, strconv.FormatUint(uint64(port), 10)) +} + +func getStateStore(path *string, logger *zap.SugaredLogger) (ipn.StateStore, error) { + p := "mem:" + if path != nil { + p = *path + } else { + logger.Warn("No state Secret provided; using in-memory store, which will lose state on restart") + } + st, err := store.New(logger.Errorf, p) + if err != nil { + return nil, fmt.Errorf("error creating state store: %w", err) + } + + return st, nil +} + +func getRestConfig(logger *zap.SugaredLogger) (*rest.Config, error) { + restConfig, err := rest.InClusterConfig() + switch err { + case nil: + return restConfig, nil + case rest.ErrNotInCluster: + logger.Info("Not running in-cluster, falling back to kubeconfig") + default: + return nil, fmt.Errorf("error getting in-cluster config: %w", err) + } + + loadingRules := clientcmd.NewDefaultClientConfigLoadingRules() + clientConfig := clientcmd.NewNonInteractiveDeferredLoadingClientConfig(loadingRules, nil) + restConfig, err = clientConfig.ClientConfig() + if err != nil { + return nil, fmt.Errorf("error loading kubeconfig: %w", err) + } + + return restConfig, nil +} + +func apiServerProxyService(cfg *conf.Config) tailcfg.ServiceName { + if cfg.Parsed.APIServerProxy != nil && + cfg.Parsed.APIServerProxy.Enabled.EqualBool(true) && + cfg.Parsed.APIServerProxy.ServiceName != nil && + *cfg.Parsed.APIServerProxy.ServiceName != "" { + return tailcfg.ServiceName(*cfg.Parsed.APIServerProxy.ServiceName) + } + + return "" +} + +func shouldIssueCerts(cfg *conf.Config) bool { + return cfg.Parsed.APIServerProxy != nil && + cfg.Parsed.APIServerProxy.IssueCerts.EqualBool(true) +} + +// setServeConfig sets up serve config such that it's serving for the passed in +// Tailscale Service, and does nothing if it's already up to date. +func setServeConfig(ctx context.Context, lc *local.Client, cm *certs.CertManager, name tailcfg.ServiceName) error { + existingServeConfig, err := lc.GetServeConfig(ctx) + if err != nil { + return fmt.Errorf("error getting existing serve config: %w", err) + } + + // Ensure serve config is cleared if no Tailscale Service. + if name == "" { + if reflect.DeepEqual(*existingServeConfig, ipn.ServeConfig{}) { + // Already up to date. + return nil + } + + if cm != nil { + cm.EnsureCertLoops(ctx, &ipn.ServeConfig{}) + } + return lc.SetServeConfig(ctx, &ipn.ServeConfig{}) + } + + status, err := lc.StatusWithoutPeers(ctx) + if err != nil { + return fmt.Errorf("error getting local client status: %w", err) + } + serviceHostPort := ipn.HostPort(fmt.Sprintf("%s.%s:443", name.WithoutPrefix(), status.CurrentTailnet.MagicDNSSuffix)) + + serveConfig := ipn.ServeConfig{ + // Configure for the Service hostname. + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + name: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + serviceHostPort: { + Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://localhost:80", + }, + }, + }, + }, + }, + }, + } + + if reflect.DeepEqual(*existingServeConfig, serveConfig) { + // Already up to date. + return nil + } + + if cm != nil { + cm.EnsureCertLoops(ctx, &serveConfig) + } + return lc.SetServeConfig(ctx, &serveConfig) +} diff --git a/cmd/nardump/nardump.go b/cmd/nardump/nardump.go index 05be7b65a..f8947b02b 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -100,14 +100,13 @@ func (nw *narWriter) writeDir(dirPath string) error { sub := path.Join(dirPath, ent.Name()) var err error switch { - case mode.IsRegular(): - err = nw.writeRegular(sub) case mode.IsDir(): err = nw.writeDir(sub) + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode&os.ModeSymlink != 0: + err = nw.writeSymlink(sub) default: - // TODO(bradfitz): symlink, but requires fighting io/fs a bit - // to get at Readlink or the osFS via fs. But for now - // we don't need symlinks because they're not in Go's archive. return fmt.Errorf("unsupported file type %v at %q", sub, mode) } if err != nil { @@ -143,6 +142,23 @@ func (nw *narWriter) writeRegular(path string) error { return nil } +func (nw *narWriter) writeSymlink(path string) error { + nw.str("(") + nw.str("type") + nw.str("symlink") + nw.str("target") + // broken symlinks are valid in a nar + // given we do os.chdir(dir) and os.dirfs(".") above + // readlink now resolves relative links even if they are broken + link, err := os.Readlink(path) + if err != nil { + return err + } + nw.str(link) + nw.str(")") + return nil +} + func (nw *narWriter) str(s string) { if err := writeString(nw.w, s); err != nil { panic(writeNARError{err}) diff --git a/cmd/nardump/nardump_test.go b/cmd/nardump/nardump_test.go new file mode 100644 index 000000000..3b87e7962 --- /dev/null +++ b/cmd/nardump/nardump_test.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "crypto/sha256" + "fmt" + "os" + "runtime" + "testing" +) + +// setupTmpdir sets up a known golden layout, covering all allowed file/folder types in a nar +func setupTmpdir(t *testing.T) string { + tmpdir := t.TempDir() + pwd, _ := os.Getwd() + os.Chdir(tmpdir) + defer os.Chdir(pwd) + os.MkdirAll("sub/dir", 0755) + os.Symlink("brokenfile", "brokenlink") + os.Symlink("sub/dir", "dirl") + os.Symlink("/abs/nonexistentdir", "dirb") + os.Create("sub/dir/file1") + f, _ := os.Create("file2m") + _ = f.Truncate(2 * 1024 * 1024) + f.Close() + os.Symlink("../file2m", "sub/goodlink") + return tmpdir +} + +func TestWriteNar(t *testing.T) { + if runtime.GOOS == "windows" { + // Skip test on Windows as the Nix package manager is not supported on this platform + t.Skip("nix package manager is not available on Windows") + } + dir := setupTmpdir(t) + t.Run("nar", func(t *testing.T) { + // obtained via `nix-store --dump /tmp/... | sha256sum` of the above test dir + expected := "727613a36f41030e93a4abf2649c3ec64a2757ccff364e3f6f7d544eb976e442" + h := sha256.New() + os.Chdir(dir) + err := writeNAR(h, os.DirFS(".")) + if err != nil { + t.Fatal(err) + } + hash := fmt.Sprintf("%x", h.Sum(nil)) + if expected != hash { + t.Fatal("sha256sum of nar not matched", hash, expected) + } + }) +} diff --git a/cmd/natc/ippool/consensusippool.go b/cmd/natc/ippool/consensusippool.go new file mode 100644 index 000000000..bfa909b69 --- /dev/null +++ b/cmd/natc/ippool/consensusippool.go @@ -0,0 +1,461 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/netip" + "time" + + "github.com/hashicorp/raft" + "go4.org/netipx" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/tsconsensus" + "tailscale.com/tsnet" + "tailscale.com/util/mak" +) + +// ConsensusIPPool implements an [IPPool] that is distributed among members of a cluster for high availability. +// Writes are directed to a leader among the cluster and are slower than reads, reads are performed locally +// using information replicated from the leader. +// The cluster maintains consistency, reads can be stale and writes can be unavailable if sufficient cluster +// peers are unavailable. +type ConsensusIPPool struct { + IPSet *netipx.IPSet + perPeerMap *syncs.Map[tailcfg.NodeID, *consensusPerPeerState] + consensus commandExecutor + clusterController clusterController + unusedAddressLifetime time.Duration +} + +func NewConsensusIPPool(ipSet *netipx.IPSet) *ConsensusIPPool { + return &ConsensusIPPool{ + unusedAddressLifetime: 48 * time.Hour, // TODO (fran) is this appropriate? should it be configurable? + IPSet: ipSet, + perPeerMap: &syncs.Map[tailcfg.NodeID, *consensusPerPeerState]{}, + } +} + +// IPForDomain looks up or creates an IP address allocation for the tailcfg.NodeID and domain pair. +// If no address association is found, one is allocated from the range of free addresses for this tailcfg.NodeID. +// If no more address are available, an error is returned. +func (ipp *ConsensusIPPool) IPForDomain(nid tailcfg.NodeID, domain string) (netip.Addr, error) { + now := time.Now() + // Check local state; local state may be stale. If we have an IP for this domain, and we are not + // close to the expiry time for the domain, it's safe to return what we have. + ps, psFound := ipp.perPeerMap.Load(nid) + if psFound { + if addr, addrFound := ps.domainToAddr[domain]; addrFound { + if ww, wwFound := ps.addrToDomain.Load(addr); wwFound { + if !isCloseToExpiry(ww.LastUsed, now, ipp.unusedAddressLifetime) { + ipp.fireAndForgetMarkLastUsed(nid, addr, ww, now) + return addr, nil + } + } + } + } + + // go via consensus + args := checkoutAddrArgs{ + NodeID: nid, + Domain: domain, + ReuseDeadline: now.Add(-1 * ipp.unusedAddressLifetime), + UpdatedAt: now, + } + bs, err := json.Marshal(args) + if err != nil { + return netip.Addr{}, err + } + c := tsconsensus.Command{ + Name: "checkoutAddr", + Args: bs, + } + result, err := ipp.consensus.ExecuteCommand(c) + if err != nil { + log.Printf("IPForDomain: raft error executing command: %v", err) + return netip.Addr{}, err + } + if result.Err != nil { + log.Printf("IPForDomain: error returned from state machine: %v", err) + return netip.Addr{}, result.Err + } + var addr netip.Addr + err = json.Unmarshal(result.Result, &addr) + return addr, err +} + +// DomainForIP looks up the domain associated with a tailcfg.NodeID and netip.Addr pair. +// If there is no association, the result is empty and ok is false. +func (ipp *ConsensusIPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr, updatedAt time.Time) (string, bool) { + // Look in local state, to save a consensus round trip; local state may be stale. + // + // The only time we expect ordering of commands to matter to clients is on first + // connection to a domain. In that case it may be that although we don't find the + // domain in our local state, it is in fact in the state of the state machine (ie + // the client did a DNS lookup, and we responded with an IP and _should_ know that + // domain when the TCP connection for that IP arrives.) + // + // So it's ok to return local state, unless local state doesn't recognize the domain, + // in which case we should check the consensus state machine to know for sure. + var domain string + ww, ok := ipp.domainLookup(from, addr) + if ok { + domain = ww.Domain + } else { + d, err := ipp.readDomainForIP(from, addr) + if err != nil { + log.Printf("error reading domain from consensus: %v", err) + return "", false + } + domain = d + } + if domain == "" { + log.Printf("did not find domain for node: %v, addr: %s", from, addr) + return "", false + } + ipp.fireAndForgetMarkLastUsed(from, addr, ww, updatedAt) + return domain, true +} + +func (ipp *ConsensusIPPool) fireAndForgetMarkLastUsed(from tailcfg.NodeID, addr netip.Addr, ww whereWhen, updatedAt time.Time) { + window := 5 * time.Minute + if updatedAt.Sub(ww.LastUsed).Abs() < window { + return + } + go func() { + err := ipp.markLastUsed(from, addr, ww.Domain, updatedAt) + if err != nil { + log.Printf("error marking last used: %v", err) + } + }() +} + +func (ipp *ConsensusIPPool) domainLookup(from tailcfg.NodeID, addr netip.Addr) (whereWhen, bool) { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + log.Printf("domainLookup: peer state absent for: %d", from) + return whereWhen{}, false + } + ww, ok := ps.addrToDomain.Load(addr) + if !ok { + log.Printf("domainLookup: peer state doesn't recognize addr: %s", addr) + return whereWhen{}, false + } + return ww, true +} + +type ClusterOpts struct { + Tag string + StateDir string + FollowOnly bool +} + +// StartConsensus is part of the IPPool interface. It starts the raft background routines that handle consensus. +func (ipp *ConsensusIPPool) StartConsensus(ctx context.Context, ts *tsnet.Server, opts ClusterOpts) error { + cfg := tsconsensus.DefaultConfig() + cfg.ServeDebugMonitor = true + cfg.StateDirPath = opts.StateDir + cns, err := tsconsensus.Start(ctx, ts, ipp, tsconsensus.BootstrapOpts{ + Tag: opts.Tag, + FollowOnly: opts.FollowOnly, + }, cfg) + if err != nil { + return err + } + ipp.consensus = cns + ipp.clusterController = cns + return nil +} + +type whereWhen struct { + Domain string + LastUsed time.Time +} + +type consensusPerPeerState struct { + domainToAddr map[string]netip.Addr + addrToDomain *syncs.Map[netip.Addr, whereWhen] +} + +// StopConsensus is part of the IPPool interface. It stops the raft background routines that handle consensus. +func (ipp *ConsensusIPPool) StopConsensus(ctx context.Context) error { + return (ipp.consensus).(*tsconsensus.Consensus).Stop(ctx) +} + +// unusedIPV4 finds the next unused or expired IP address in the pool. +// IP addresses in the pool should be reused if they haven't been used for some period of time. +// reuseDeadline is the time before which addresses are considered to be expired. +// So if addresses are being reused after they haven't been used for 24 hours say, reuseDeadline +// would be 24 hours ago. +func (ps *consensusPerPeerState) unusedIPV4(ipset *netipx.IPSet, reuseDeadline time.Time) (netip.Addr, bool, string, error) { + // If we want to have a random IP choice behavior we could make that work with the state machine by doing something like + // passing the randomly chosen IP into the state machine call (so replaying logs would still be deterministic). + for _, r := range ipset.Ranges() { + ip := r.From() + toIP := r.To() + if !ip.IsValid() || !toIP.IsValid() { + continue + } + for toIP.Compare(ip) != -1 { + ww, ok := ps.addrToDomain.Load(ip) + if !ok { + return ip, false, "", nil + } + if ww.LastUsed.Before(reuseDeadline) { + return ip, true, ww.Domain, nil + } + ip = ip.Next() + } + } + return netip.Addr{}, false, "", errors.New("ip pool exhausted") +} + +// isCloseToExpiry returns true if the lastUsed and now times are more than +// half the lifetime apart +func isCloseToExpiry(lastUsed, now time.Time, lifetime time.Duration) bool { + return now.Sub(lastUsed).Abs() > (lifetime / 2) +} + +type readDomainForIPArgs struct { + NodeID tailcfg.NodeID + Addr netip.Addr +} + +// executeReadDomainForIP parses a readDomainForIP log entry and applies it. +func (ipp *ConsensusIPPool) executeReadDomainForIP(bs []byte) tsconsensus.CommandResult { + var args readDomainForIPArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + return ipp.applyReadDomainForIP(args.NodeID, args.Addr) +} + +func (ipp *ConsensusIPPool) applyReadDomainForIP(from tailcfg.NodeID, addr netip.Addr) tsconsensus.CommandResult { + domain := func() string { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + return "" + } + ww, ok := ps.addrToDomain.Load(addr) + if !ok { + return "" + } + return ww.Domain + }() + resultBs, err := json.Marshal(domain) + return tsconsensus.CommandResult{Result: resultBs, Err: err} +} + +// readDomainForIP executes a readDomainForIP command on the leader with raft. +func (ipp *ConsensusIPPool) readDomainForIP(nid tailcfg.NodeID, addr netip.Addr) (string, error) { + args := readDomainForIPArgs{ + NodeID: nid, + Addr: addr, + } + bs, err := json.Marshal(args) + if err != nil { + return "", err + } + c := tsconsensus.Command{ + Name: "readDomainForIP", + Args: bs, + } + result, err := ipp.consensus.ExecuteCommand(c) + if err != nil { + log.Printf("readDomainForIP: raft error executing command: %v", err) + return "", err + } + if result.Err != nil { + log.Printf("readDomainForIP: error returned from state machine: %v", err) + return "", result.Err + } + var domain string + err = json.Unmarshal(result.Result, &domain) + return domain, err +} + +type markLastUsedArgs struct { + NodeID tailcfg.NodeID + Addr netip.Addr + Domain string + UpdatedAt time.Time +} + +// executeMarkLastUsed parses a markLastUsed log entry and applies it. +func (ipp *ConsensusIPPool) executeMarkLastUsed(bs []byte) tsconsensus.CommandResult { + var args markLastUsedArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + err = ipp.applyMarkLastUsed(args.NodeID, args.Addr, args.Domain, args.UpdatedAt) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + return tsconsensus.CommandResult{} +} + +// applyMarkLastUsed applies the arguments from the log entry to the state. It updates an entry in the AddrToDomain +// map with a new LastUsed timestamp. +// applyMarkLastUsed is not safe for concurrent access. It's only called from raft which will +// not call it concurrently. +func (ipp *ConsensusIPPool) applyMarkLastUsed(from tailcfg.NodeID, addr netip.Addr, domain string, updatedAt time.Time) error { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + // There's nothing to mark. But this is unexpected, because we mark last used after we do things with peer state. + log.Printf("applyMarkLastUsed: could not find peer state, nodeID: %s", from) + return nil + } + ww, ok := ps.addrToDomain.Load(addr) + if !ok { + // The peer state didn't have an entry for the IP address (possibly it expired), so there's nothing to mark. + return nil + } + if ww.Domain != domain { + // The IP address expired and was reused for a new domain. Don't mark. + return nil + } + if ww.LastUsed.After(updatedAt) { + // This has been marked more recently. Don't mark. + return nil + } + ww.LastUsed = updatedAt + ps.addrToDomain.Store(addr, ww) + return nil +} + +// markLastUsed executes a markLastUsed command on the leader with raft. +func (ipp *ConsensusIPPool) markLastUsed(nid tailcfg.NodeID, addr netip.Addr, domain string, lastUsed time.Time) error { + args := markLastUsedArgs{ + NodeID: nid, + Addr: addr, + Domain: domain, + UpdatedAt: lastUsed, + } + bs, err := json.Marshal(args) + if err != nil { + return err + } + c := tsconsensus.Command{ + Name: "markLastUsed", + Args: bs, + } + result, err := ipp.consensus.ExecuteCommand(c) + if err != nil { + log.Printf("markLastUsed: raft error executing command: %v", err) + return err + } + if result.Err != nil { + log.Printf("markLastUsed: error returned from state machine: %v", err) + return result.Err + } + return nil +} + +type checkoutAddrArgs struct { + NodeID tailcfg.NodeID + Domain string + ReuseDeadline time.Time + UpdatedAt time.Time +} + +// executeCheckoutAddr parses a checkoutAddr raft log entry and applies it. +func (ipp *ConsensusIPPool) executeCheckoutAddr(bs []byte) tsconsensus.CommandResult { + var args checkoutAddrArgs + err := json.Unmarshal(bs, &args) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + addr, err := ipp.applyCheckoutAddr(args.NodeID, args.Domain, args.ReuseDeadline, args.UpdatedAt) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + resultBs, err := json.Marshal(addr) + if err != nil { + return tsconsensus.CommandResult{Err: err} + } + return tsconsensus.CommandResult{Result: resultBs} +} + +// applyCheckoutAddr finds the IP address for a nid+domain +// Each nid can use all of the addresses in the pool. +// updatedAt is the current time, the time at which we are wanting to get a new IP address. +// reuseDeadline is the time before which addresses are considered to be expired. +// So if addresses are being reused after they haven't been used for 24 hours say updatedAt would be now +// and reuseDeadline would be 24 hours ago. +// It is not safe for concurrent access (it's only called from raft, which will not call concurrently +// so that's fine). +func (ipp *ConsensusIPPool) applyCheckoutAddr(nid tailcfg.NodeID, domain string, reuseDeadline, updatedAt time.Time) (netip.Addr, error) { + ps, ok := ipp.perPeerMap.Load(nid) + if !ok { + ps = &consensusPerPeerState{ + addrToDomain: &syncs.Map[netip.Addr, whereWhen]{}, + } + ipp.perPeerMap.Store(nid, ps) + } + if existing, ok := ps.domainToAddr[domain]; ok { + ww, ok := ps.addrToDomain.Load(existing) + if ok { + ww.LastUsed = updatedAt + ps.addrToDomain.Store(existing, ww) + return existing, nil + } + log.Printf("applyCheckoutAddr: data out of sync, allocating new IP") + } + addr, wasInUse, previousDomain, err := ps.unusedIPV4(ipp.IPSet, reuseDeadline) + if err != nil { + return netip.Addr{}, err + } + mak.Set(&ps.domainToAddr, domain, addr) + if wasInUse { + delete(ps.domainToAddr, previousDomain) + } + ps.addrToDomain.Store(addr, whereWhen{Domain: domain, LastUsed: updatedAt}) + return addr, nil +} + +// Apply is part of the raft.FSM interface. It takes an incoming log entry and applies it to the state. +func (ipp *ConsensusIPPool) Apply(lg *raft.Log) any { + var c tsconsensus.Command + if err := json.Unmarshal(lg.Data, &c); err != nil { + panic(fmt.Sprintf("failed to unmarshal command: %s", err.Error())) + } + switch c.Name { + case "checkoutAddr": + return ipp.executeCheckoutAddr(c.Args) + case "markLastUsed": + return ipp.executeMarkLastUsed(c.Args) + case "readDomainForIP": + return ipp.executeReadDomainForIP(c.Args) + default: + panic(fmt.Sprintf("unrecognized command: %s", c.Name)) + } +} + +// commandExecutor is an interface covering the routing parts of consensus +// used to allow a fake in the tests +type commandExecutor interface { + ExecuteCommand(tsconsensus.Command) (tsconsensus.CommandResult, error) +} + +type clusterController interface { + GetClusterConfiguration() (raft.Configuration, error) + DeleteClusterServer(id raft.ServerID) (uint64, error) +} + +// GetClusterConfiguration gets the consensus implementation's cluster configuration +func (ipp *ConsensusIPPool) GetClusterConfiguration() (raft.Configuration, error) { + return ipp.clusterController.GetClusterConfiguration() +} + +// DeleteClusterServer removes a server from the consensus implementation's cluster configuration +func (ipp *ConsensusIPPool) DeleteClusterServer(id raft.ServerID) (uint64, error) { + return ipp.clusterController.DeleteClusterServer(id) +} diff --git a/cmd/natc/ippool/consensusippool_test.go b/cmd/natc/ippool/consensusippool_test.go new file mode 100644 index 000000000..242cdffaf --- /dev/null +++ b/cmd/natc/ippool/consensusippool_test.go @@ -0,0 +1,383 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/netip" + "testing" + "time" + + "github.com/hashicorp/raft" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/tsconsensus" + "tailscale.com/util/must" +) + +func makeSetFromPrefix(pfx netip.Prefix) *netipx.IPSet { + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(pfx) + return must.Get(ipsb.IPSet()) +} + +type FakeConsensus struct { + ipp *ConsensusIPPool +} + +func (c *FakeConsensus) ExecuteCommand(cmd tsconsensus.Command) (tsconsensus.CommandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return tsconsensus.CommandResult{}, err + } + result := c.ipp.Apply(&raft.Log{Data: b}) + return result.(tsconsensus.CommandResult), nil +} + +func makePool(pfx netip.Prefix) *ConsensusIPPool { + ipp := NewConsensusIPPool(makeSetFromPrefix(pfx)) + ipp.consensus = &FakeConsensus{ipp: ipp} + return ipp +} + +func TestConsensusIPForDomain(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := makePool(pfx) + from := tailcfg.NodeID(1) + + a, err := ipp.IPForDomain(from, "example.com") + if err != nil { + t.Fatal(err) + } + if !pfx.Contains(a) { + t.Fatalf("expected %v to be in the prefix %v", a, pfx) + } + + b, err := ipp.IPForDomain(from, "a.example.com") + if err != nil { + t.Fatal(err) + } + if !pfx.Contains(b) { + t.Fatalf("expected %v to be in the prefix %v", b, pfx) + } + if b == a { + t.Fatalf("same address issued twice %v, %v", a, b) + } + + c, err := ipp.IPForDomain(from, "example.com") + if err != nil { + t.Fatal(err) + } + if c != a { + t.Fatalf("expected %v to be remembered as the addr for example.com, but got %v", a, c) + } +} + +func TestConsensusPoolExhaustion(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/31")) + from := tailcfg.NodeID(1) + + subdomains := []string{"a", "b", "c"} + for i, sd := range subdomains { + _, err := ipp.IPForDomain(from, fmt.Sprintf("%s.example.com", sd)) + if i < 2 && err != nil { + t.Fatal(err) + } + expected := "ip pool exhausted" + if i == 2 && err.Error() != expected { + t.Fatalf("expected error to be '%s', got '%s'", expected, err.Error()) + } + } +} + +func TestConsensusPoolExpiry(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/31")) + firstIP := netip.MustParseAddr("100.64.0.0") + secondIP := netip.MustParseAddr("100.64.0.1") + timeOfUse := time.Now() + beforeTimeOfUse := timeOfUse.Add(-1 * time.Hour) + afterTimeOfUse := timeOfUse.Add(1 * time.Hour) + from := tailcfg.NodeID(1) + + // the pool is unused, we get an address, and it's marked as being used at timeOfUse + aAddr, err := ipp.applyCheckoutAddr(from, "a.example.com", time.Time{}, timeOfUse) + if err != nil { + t.Fatal(err) + } + if aAddr.Compare(firstIP) != 0 { + t.Fatalf("expected %s, got %s", firstIP, aAddr) + } + ww, ok := ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.Domain != "a.example.com" { + t.Fatalf("expected aAddr to look up to a.example.com, got: %s", ww.Domain) + } + + // the time before which we will reuse addresses is prior to timeOfUse, so no reuse + bAddr, err := ipp.applyCheckoutAddr(from, "b.example.com", beforeTimeOfUse, timeOfUse) + if err != nil { + t.Fatal(err) + } + if bAddr.Compare(secondIP) != 0 { + t.Fatalf("expected %s, got %s", secondIP, bAddr) + } + + // the time before which we will reuse addresses is after timeOfUse, so reuse addresses that were marked as used at timeOfUse. + cAddr, err := ipp.applyCheckoutAddr(from, "c.example.com", afterTimeOfUse, timeOfUse) + if err != nil { + t.Fatal(err) + } + if cAddr.Compare(firstIP) != 0 { + t.Fatalf("expected %s, got %s", firstIP, cAddr) + } + ww, ok = ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.Domain != "c.example.com" { + t.Fatalf("expected firstIP to look up to c.example.com, got: %s", ww.Domain) + } + + // the addr remains associated with c.example.com + cAddrAgain, err := ipp.applyCheckoutAddr(from, "c.example.com", afterTimeOfUse, timeOfUse) + if err != nil { + t.Fatal(err) + } + if cAddrAgain.Compare(cAddr) != 0 { + t.Fatalf("expected cAddrAgain to be cAddr, but they are different. cAddrAgain=%s cAddr=%s", cAddrAgain, cAddr) + } + ww, ok = ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.Domain != "c.example.com" { + t.Fatalf("expected firstIP to look up to c.example.com, got: %s", ww.Domain) + } +} + +func TestConsensusPoolApplyMarkLastUsed(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/31")) + firstIP := netip.MustParseAddr("100.64.0.0") + time1 := time.Now() + time2 := time1.Add(1 * time.Hour) + from := tailcfg.NodeID(1) + domain := "example.com" + + aAddr, err := ipp.applyCheckoutAddr(from, domain, time.Time{}, time1) + if err != nil { + t.Fatal(err) + } + if aAddr.Compare(firstIP) != 0 { + t.Fatalf("expected %s, got %s", firstIP, aAddr) + } + // example.com LastUsed is now time1 + ww, ok := ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.LastUsed != time1 { + t.Fatalf("expected %s, got %s", time1, ww.LastUsed) + } + if ww.Domain != domain { + t.Fatalf("expected %s, got %s", domain, ww.Domain) + } + + err = ipp.applyMarkLastUsed(from, firstIP, domain, time2) + if err != nil { + t.Fatal(err) + } + + // example.com LastUsed is now time2 + ww, ok = ipp.domainLookup(from, firstIP) + if !ok { + t.Fatal("expected wherewhen to be found") + } + if ww.LastUsed != time2 { + t.Fatalf("expected %s, got %s", time2, ww.LastUsed) + } + if ww.Domain != domain { + t.Fatalf("expected %s, got %s", domain, ww.Domain) + } +} + +func TestConsensusDomainForIP(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/16")) + from := tailcfg.NodeID(1) + domain := "example.com" + now := time.Now() + + d, ok := ipp.DomainForIP(from, netip.MustParseAddr("100.64.0.1"), now) + if d != "" { + t.Fatalf("expected an empty string if the addr is not found but got %s", d) + } + if ok { + t.Fatalf("expected domain to not be found for IP, as it has never been looked up") + } + a, err := ipp.IPForDomain(from, domain) + if err != nil { + t.Fatal(err) + } + d2, ok := ipp.DomainForIP(from, a, now) + if d2 != domain { + t.Fatalf("expected %s but got %s", domain, d2) + } + if !ok { + t.Fatalf("expected domain to be found for IP that was handed out for it") + } +} + +func TestConsensusReadDomainForIP(t *testing.T) { + ipp := makePool(netip.MustParsePrefix("100.64.0.0/16")) + from := tailcfg.NodeID(1) + domain := "example.com" + + d, err := ipp.readDomainForIP(from, netip.MustParseAddr("100.64.0.1")) + if err != nil { + t.Fatal(err) + } + if d != "" { + t.Fatalf("expected an empty string if the addr is not found but got %s", d) + } + a, err := ipp.IPForDomain(from, domain) + if err != nil { + t.Fatal(err) + } + d2, err := ipp.readDomainForIP(from, a) + if err != nil { + t.Fatal(err) + } + if d2 != domain { + t.Fatalf("expected %s but got %s", domain, d2) + } +} + +func TestConsensusSnapshot(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := makePool(pfx) + domain := "example.com" + expectedAddr := netip.MustParseAddr("100.64.0.0") + expectedFrom := expectedAddr + expectedTo := netip.MustParseAddr("100.64.255.255") + from := tailcfg.NodeID(1) + + // pool allocates first addr for from + if _, err := ipp.IPForDomain(from, domain); err != nil { + t.Fatal(err) + } + // take a snapshot + fsmSnap, err := ipp.Snapshot() + if err != nil { + t.Fatal(err) + } + snap := fsmSnap.(fsmSnapshot) + + // verify snapshot state matches the state we know ipp will have + // ipset matches ipp.IPSet + if len(snap.IPSet.Ranges) != 1 { + t.Fatalf("expected 1, got %d", len(snap.IPSet.Ranges)) + } + if snap.IPSet.Ranges[0].From != expectedFrom { + t.Fatalf("want %s, got %s", expectedFrom, snap.IPSet.Ranges[0].From) + } + if snap.IPSet.Ranges[0].To != expectedTo { + t.Fatalf("want %s, got %s", expectedTo, snap.IPSet.Ranges[0].To) + } + + // perPeerMap has one entry, for from + if len(snap.PerPeerMap) != 1 { + t.Fatalf("expected 1, got %d", len(snap.PerPeerMap)) + } + ps := snap.PerPeerMap[from] + + // the one peer state has allocated one address, the first in the prefix + if len(ps.DomainToAddr) != 1 { + t.Fatalf("expected 1, got %d", len(ps.DomainToAddr)) + } + addr := ps.DomainToAddr[domain] + if addr != expectedAddr { + t.Fatalf("want %s, got %s", expectedAddr.String(), addr.String()) + } + if len(ps.AddrToDomain) != 1 { + t.Fatalf("expected 1, got %d", len(ps.AddrToDomain)) + } + ww := ps.AddrToDomain[addr] + if ww.Domain != domain { + t.Fatalf("want %s, got %s", domain, ww.Domain) + } +} + +func TestConsensusRestore(t *testing.T) { + pfx := netip.MustParsePrefix("100.64.0.0/16") + ipp := makePool(pfx) + domain := "example.com" + expectedAddr := netip.MustParseAddr("100.64.0.0") + from := tailcfg.NodeID(1) + + if _, err := ipp.IPForDomain(from, domain); err != nil { + t.Fatal(err) + } + // take the snapshot after only 1 addr allocated + fsmSnap, err := ipp.Snapshot() + if err != nil { + t.Fatal(err) + } + snap := fsmSnap.(fsmSnapshot) + + if _, err := ipp.IPForDomain(from, "b.example.com"); err != nil { + t.Fatal(err) + } + if _, err := ipp.IPForDomain(from, "c.example.com"); err != nil { + t.Fatal(err) + } + if _, err := ipp.IPForDomain(from, "d.example.com"); err != nil { + t.Fatal(err) + } + // ipp now has 4 entries in domainToAddr + ps, _ := ipp.perPeerMap.Load(from) + if len(ps.domainToAddr) != 4 { + t.Fatalf("want 4, got %d", len(ps.domainToAddr)) + } + + // restore the snapshot + bs, err := json.Marshal(snap) + if err != nil { + t.Fatal(err) + } + err = ipp.Restore(io.NopCloser(bytes.NewBuffer(bs))) + if err != nil { + t.Fatal(err) + } + + // everything should be as it was when the snapshot was taken + if ipp.perPeerMap.Len() != 1 { + t.Fatalf("want 1, got %d", ipp.perPeerMap.Len()) + } + psAfter, _ := ipp.perPeerMap.Load(from) + if len(psAfter.domainToAddr) != 1 { + t.Fatalf("want 1, got %d", len(psAfter.domainToAddr)) + } + if psAfter.domainToAddr[domain] != expectedAddr { + t.Fatalf("want %s, got %s", expectedAddr, psAfter.domainToAddr[domain]) + } + ww, _ := psAfter.addrToDomain.Load(expectedAddr) + if ww.Domain != domain { + t.Fatalf("want %s, got %s", domain, ww.Domain) + } +} + +func TestConsensusIsCloseToExpiry(t *testing.T) { + a := time.Now() + b := a.Add(5 * time.Second) + if !isCloseToExpiry(a, b, 8*time.Second) { + t.Fatal("times are not within half the lifetime, expected true") + } + if isCloseToExpiry(a, b, 12*time.Second) { + t.Fatal("times are within half the lifetime, expected false") + } +} diff --git a/cmd/natc/ippool/consensusippoolserialize.go b/cmd/natc/ippool/consensusippoolserialize.go new file mode 100644 index 000000000..97dc02f2c --- /dev/null +++ b/cmd/natc/ippool/consensusippoolserialize.go @@ -0,0 +1,164 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "encoding/json" + "io" + "log" + "maps" + "net/netip" + + "github.com/hashicorp/raft" + "go4.org/netipx" + "tailscale.com/syncs" + "tailscale.com/tailcfg" +) + +// Snapshot and Restore enable the raft lib to do log compaction. +// https://pkg.go.dev/github.com/hashicorp/raft#FSM + +// Snapshot is part of the raft.FSM interface. +// According to the docs it: +// - should return quickly +// - will not be called concurrently with Apply +// - the snapshot returned will have Persist called on it concurrently with Apply +// (so it should not contain pointers to the original data that's being mutated) +func (ipp *ConsensusIPPool) Snapshot() (raft.FSMSnapshot, error) { + // everything is safe for concurrent reads and this is not called concurrently with Apply which is + // the only thing that writes, so we do not need to lock + return ipp.getPersistable(), nil +} + +type persistableIPSet struct { + Ranges []persistableIPRange +} + +func getPersistableIPSet(i *netipx.IPSet) persistableIPSet { + rs := []persistableIPRange{} + for _, r := range i.Ranges() { + rs = append(rs, getPersistableIPRange(r)) + } + return persistableIPSet{Ranges: rs} +} + +func (mips *persistableIPSet) toIPSet() (*netipx.IPSet, error) { + b := netipx.IPSetBuilder{} + for _, r := range mips.Ranges { + b.AddRange(r.toIPRange()) + } + return b.IPSet() +} + +type persistableIPRange struct { + From netip.Addr + To netip.Addr +} + +func getPersistableIPRange(r netipx.IPRange) persistableIPRange { + return persistableIPRange{ + From: r.From(), + To: r.To(), + } +} + +func (mipr *persistableIPRange) toIPRange() netipx.IPRange { + return netipx.IPRangeFrom(mipr.From, mipr.To) +} + +// Restore is part of the raft.FSM interface. +// According to the docs it: +// - will not be called concurrently with any other command +// - the FSM must discard all previous state before restoring +func (ipp *ConsensusIPPool) Restore(rc io.ReadCloser) error { + var snap fsmSnapshot + if err := json.NewDecoder(rc).Decode(&snap); err != nil { + return err + } + ipset, ppm, err := snap.getData() + if err != nil { + return err + } + ipp.IPSet = ipset + ipp.perPeerMap = ppm + return nil +} + +type fsmSnapshot struct { + IPSet persistableIPSet + PerPeerMap map[tailcfg.NodeID]persistablePPS +} + +// Persist is part of the raft.FSMSnapshot interface +// According to the docs Persist may be called concurrently with Apply +func (f fsmSnapshot) Persist(sink raft.SnapshotSink) error { + if err := json.NewEncoder(sink).Encode(f); err != nil { + log.Printf("Error encoding snapshot as JSON: %v", err) + return sink.Cancel() + } + return sink.Close() +} + +// Release is part of the raft.FSMSnapshot interface +func (f fsmSnapshot) Release() {} + +// getPersistable returns an object that: +// - contains all the data in ConsensusIPPool +// - doesn't share any pointers with it +// - can be marshalled to JSON +// +// part of the raft snapshotting, getPersistable will be called during Snapshot +// and the results used during persist (concurrently with Apply) +func (ipp *ConsensusIPPool) getPersistable() fsmSnapshot { + ppm := map[tailcfg.NodeID]persistablePPS{} + for k, v := range ipp.perPeerMap.All() { + ppm[k] = v.getPersistable() + } + return fsmSnapshot{ + IPSet: getPersistableIPSet(ipp.IPSet), + PerPeerMap: ppm, + } +} + +func (f fsmSnapshot) getData() (*netipx.IPSet, *syncs.Map[tailcfg.NodeID, *consensusPerPeerState], error) { + ppm := syncs.Map[tailcfg.NodeID, *consensusPerPeerState]{} + for k, v := range f.PerPeerMap { + ppm.Store(k, v.toPerPeerState()) + } + ipset, err := f.IPSet.toIPSet() + if err != nil { + return nil, nil, err + } + return ipset, &ppm, nil +} + +// getPersistable returns an object that: +// - contains all the data in consensusPerPeerState +// - doesn't share any pointers with it +// - can be marshalled to JSON +// +// part of the raft snapshotting, getPersistable will be called during Snapshot +// and the results used during persist (concurrently with Apply) +func (ps *consensusPerPeerState) getPersistable() persistablePPS { + return persistablePPS{ + AddrToDomain: maps.Collect(ps.addrToDomain.All()), + DomainToAddr: maps.Clone(ps.domainToAddr), + } +} + +type persistablePPS struct { + DomainToAddr map[string]netip.Addr + AddrToDomain map[netip.Addr]whereWhen +} + +func (p persistablePPS) toPerPeerState() *consensusPerPeerState { + atd := &syncs.Map[netip.Addr, whereWhen]{} + for k, v := range p.AddrToDomain { + atd.Store(k, v) + } + return &consensusPerPeerState{ + domainToAddr: p.DomainToAddr, + addrToDomain: atd, + } +} diff --git a/cmd/natc/ippool/ippool.go b/cmd/natc/ippool/ippool.go new file mode 100644 index 000000000..5a2dcbec9 --- /dev/null +++ b/cmd/natc/ippool/ippool.go @@ -0,0 +1,133 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ippool implements IP address storage, creation, and retrieval for cmd/natc +package ippool + +import ( + "errors" + "log" + "math/big" + "net/netip" + "sync" + "time" + + "github.com/gaissmai/bart" + "go4.org/netipx" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" +) + +var ErrNoIPsAvailable = errors.New("no IPs available") + +// IPPool allocates IPv4 addresses from a pool to DNS domains, on a per tailcfg.NodeID basis. +// For each tailcfg.NodeID, IPv4 addresses are associated with at most one DNS domain. +// Addresses may be reused across other tailcfg.NodeID's for the same or other domains. +type IPPool interface { + // DomainForIP looks up the domain associated with a tailcfg.NodeID and netip.Addr pair. + // If there is no association, the result is empty and ok is false. + DomainForIP(tailcfg.NodeID, netip.Addr, time.Time) (string, bool) + + // IPForDomain looks up or creates an IP address allocation for the tailcfg.NodeID and domain pair. + // If no address association is found, one is allocated from the range of free addresses for this tailcfg.NodeID. + // If no more address are available, an error is returned. + IPForDomain(tailcfg.NodeID, string) (netip.Addr, error) +} + +type SingleMachineIPPool struct { + perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] + IPSet *netipx.IPSet +} + +func (ipp *SingleMachineIPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr, _ time.Time) (string, bool) { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + log.Printf("handleTCPFlow: no perPeerState for %v", from) + return "", false + } + domain, ok := ps.domainForIP(addr) + if !ok { + log.Printf("handleTCPFlow: no domain for IP %v\n", addr) + return "", false + } + return domain, ok +} + +func (ipp *SingleMachineIPPool) IPForDomain(from tailcfg.NodeID, domain string) (netip.Addr, error) { + npps := &perPeerState{ + ipset: ipp.IPSet, + } + ps, _ := ipp.perPeerMap.LoadOrStore(from, npps) + return ps.ipForDomain(domain) +} + +// perPeerState holds the state for a single peer. +type perPeerState struct { + ipset *netipx.IPSet + + mu sync.Mutex + addrInUse *big.Int + domainToAddr map[string]netip.Addr + addrToDomain *bart.Table[string] +} + +// domainForIP returns the domain name assigned to the given IP address and +// whether it was found. +func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.addrToDomain == nil { + return "", false + } + return ps.addrToDomain.Lookup(ip) +} + +// ipForDomain assigns a pair of unique IP addresses for the given domain and +// returns them. The first address is an IPv4 address and the second is an IPv6 +// address. If the domain already has assigned addresses, it returns them. +func (ps *perPeerState) ipForDomain(domain string) (netip.Addr, error) { + fqdn, err := dnsname.ToFQDN(domain) + if err != nil { + return netip.Addr{}, err + } + domain = fqdn.WithoutTrailingDot() + + ps.mu.Lock() + defer ps.mu.Unlock() + if addr, ok := ps.domainToAddr[domain]; ok { + return addr, nil + } + addr := ps.assignAddrsLocked(domain) + if !addr.IsValid() { + return netip.Addr{}, ErrNoIPsAvailable + } + return addr, nil +} + +// unusedIPv4Locked returns an unused IPv4 address from the available ranges. +func (ps *perPeerState) unusedIPv4Locked() netip.Addr { + if ps.addrInUse == nil { + ps.addrInUse = big.NewInt(0) + } + return allocAddr(ps.ipset, ps.addrInUse) +} + +// assignAddrsLocked assigns a pair of unique IP addresses for the given domain +// and returns them. The first address is an IPv4 address and the second is an +// IPv6 address. It does not check if the domain already has assigned addresses. +// ps.mu must be held. +func (ps *perPeerState) assignAddrsLocked(domain string) netip.Addr { + if ps.addrToDomain == nil { + ps.addrToDomain = &bart.Table[string]{} + } + v4 := ps.unusedIPv4Locked() + if !v4.IsValid() { + return netip.Addr{} + } + addr := v4 + mak.Set(&ps.domainToAddr, domain, addr) + ps.addrToDomain.Insert(netip.PrefixFrom(addr, addr.BitLen()), domain) + return addr +} diff --git a/cmd/natc/ippool/ippool_test.go b/cmd/natc/ippool/ippool_test.go new file mode 100644 index 000000000..8d474f86a --- /dev/null +++ b/cmd/natc/ippool/ippool_test.go @@ -0,0 +1,108 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "errors" + "fmt" + "net/netip" + "testing" + "time" + + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +func TestIPPoolExhaustion(t *testing.T) { + smallPrefix := netip.MustParsePrefix("100.64.1.0/30") // Only 4 IPs: .0, .1, .2, .3 + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(smallPrefix) + addrPool := must.Get(ipsb.IPSet()) + pool := SingleMachineIPPool{IPSet: addrPool} + + assignedIPs := make(map[netip.Addr]string) + + domains := []string{"a.example.com", "b.example.com", "c.example.com", "d.example.com", "e.example.com"} + + var errs []error + + from := tailcfg.NodeID(12345) + + for i := 0; i < 5; i++ { + for _, domain := range domains { + addr, err := pool.IPForDomain(from, domain) + if err != nil { + errs = append(errs, fmt.Errorf("failed to get IP for domain %q: %w", domain, err)) + continue + } + + if d, ok := assignedIPs[addr]; ok { + if d != domain { + t.Errorf("IP %s reused for domain %q, previously assigned to %q", addr, domain, d) + } + } else { + assignedIPs[addr] = domain + } + } + } + + for addr, domain := range assignedIPs { + if addr.Is4() && !smallPrefix.Contains(addr) { + t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, smallPrefix) + } + } + + // expect one error for each iteration with the 5th domain + if len(errs) != 5 { + t.Errorf("Expected 5 errors, got %d: %v", len(errs), errs) + } + for _, err := range errs { + if !errors.Is(err, ErrNoIPsAvailable) { + t.Errorf("generateDNSResponse() error = %v, want ErrNoIPsAvailable", err) + } + } +} + +func TestIPPool(t *testing.T) { + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(netip.MustParsePrefix("100.64.1.0/24")) + addrPool := must.Get(ipsb.IPSet()) + pool := SingleMachineIPPool{ + IPSet: addrPool, + } + from := tailcfg.NodeID(12345) + addr, err := pool.IPForDomain(from, "example.com") + if err != nil { + t.Fatalf("ipForDomain() error = %v", err) + } + + if !addr.IsValid() { + t.Fatal("ipForDomain() returned an invalid address") + } + + if !addr.Is4() { + t.Errorf("Address is not IPv4: %s", addr) + } + + if !addrPool.Contains(addr) { + t.Errorf("IPv4 address %s not in range %s", addr, addrPool) + } + + domain, ok := pool.DomainForIP(from, addr, time.Now()) + if !ok { + t.Errorf("domainForIP(%s) not found", addr) + } else if domain != "example.com" { + t.Errorf("domainForIP(%s) = %s, want %s", addr, domain, "example.com") + } + + addr2, err := pool.IPForDomain(from, "example.com") + if err != nil { + t.Fatalf("ipForDomain() second call error = %v", err) + } + + if addr.Compare(addr2) != 0 { + t.Errorf("ipForDomain() second call = %v, want %v", addr2, addr) + } +} diff --git a/cmd/natc/ippool/ipx.go b/cmd/natc/ippool/ipx.go new file mode 100644 index 000000000..8259a56db --- /dev/null +++ b/cmd/natc/ippool/ipx.go @@ -0,0 +1,130 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "math/big" + "math/bits" + "math/rand/v2" + "net/netip" + + "go4.org/netipx" +) + +func addrLessOrEqual(a, b netip.Addr) bool { + if a.Less(b) { + return true + } + if a == b { + return true + } + return false +} + +// indexOfAddr returns the index of addr in ipset, or -1 if not found. +func indexOfAddr(addr netip.Addr, ipset *netipx.IPSet) int { + var base int // offset of the current range + for _, r := range ipset.Ranges() { + if addr.Less(r.From()) { + return -1 + } + numFrom := v4ToNum(r.From()) + if addrLessOrEqual(addr, r.To()) { + numInRange := int(v4ToNum(addr) - numFrom) + return base + numInRange + } + numTo := v4ToNum(r.To()) + base += int(numTo-numFrom) + 1 + } + return -1 +} + +// addrAtIndex returns the address at the given index in ipset, or an empty +// address if index is out of range. +func addrAtIndex(index int, ipset *netipx.IPSet) netip.Addr { + if index < 0 { + return netip.Addr{} + } + var base int // offset of the current range + for _, r := range ipset.Ranges() { + numFrom := v4ToNum(r.From()) + numTo := v4ToNum(r.To()) + if index <= base+int(numTo-numFrom) { + return numToV4(uint32(int(numFrom) + index - base)) + } + base += int(numTo-numFrom) + 1 + } + return netip.Addr{} +} + +// TODO(golang/go#9455): once we have uint128 we can easily implement for all addrs. + +// v4ToNum returns a uint32 representation of the IPv4 address. If addr is not +// an IPv4 address, this function will panic. +func v4ToNum(addr netip.Addr) uint32 { + addr = addr.Unmap() + if !addr.Is4() { + panic("only IPv4 addresses are supported by v4ToNum") + } + b := addr.As4() + var o uint32 + o = o<<8 | uint32(b[0]) + o = o<<8 | uint32(b[1]) + o = o<<8 | uint32(b[2]) + o = o<<8 | uint32(b[3]) + return o +} + +func numToV4(i uint32) netip.Addr { + var addr [4]byte + addr[0] = byte((i >> 24) & 0xff) + addr[1] = byte((i >> 16) & 0xff) + addr[2] = byte((i >> 8) & 0xff) + addr[3] = byte(i & 0xff) + return netip.AddrFrom4(addr) +} + +// allocAddr returns an address in ipset that is not already marked allocated in allocated. +func allocAddr(ipset *netipx.IPSet, allocated *big.Int) netip.Addr { + // first try to allocate a random IP from each range, if we land on one. + var base uint32 // index offset of the current range + for _, r := range ipset.Ranges() { + numFrom := v4ToNum(r.From()) + numTo := v4ToNum(r.To()) + randInRange := rand.N(numTo - numFrom) + randIndex := base + randInRange + if allocated.Bit(int(randIndex)) == 0 { + allocated.SetBit(allocated, int(randIndex), 1) + return numToV4(numFrom + randInRange) + } + base += numTo - numFrom + 1 + } + + // fall back to seeking a free bit in the allocated set + index := -1 + for i, word := range allocated.Bits() { + zbi := leastZeroBit(uint(word)) + if zbi == -1 { + continue + } + index = i*bits.UintSize + zbi + allocated.SetBit(allocated, index, 1) + break + } + if index == -1 { + return netip.Addr{} + } + return addrAtIndex(index, ipset) +} + +// leastZeroBit returns the index of the least significant zero bit in the given uint, or -1 +// if all bits are set. +func leastZeroBit(n uint) int { + notN := ^n + rightmostBit := notN & -notN + if rightmostBit == 0 { + return -1 + } + return bits.TrailingZeros(rightmostBit) +} diff --git a/cmd/natc/ippool/ipx_test.go b/cmd/natc/ippool/ipx_test.go new file mode 100644 index 000000000..2e2b9d3d4 --- /dev/null +++ b/cmd/natc/ippool/ipx_test.go @@ -0,0 +1,150 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "math" + "math/big" + "net/netip" + "testing" + + "go4.org/netipx" + "tailscale.com/util/must" +) + +func TestV4ToNum(t *testing.T) { + cases := []struct { + addr netip.Addr + num uint32 + }{ + {netip.MustParseAddr("0.0.0.0"), 0}, + {netip.MustParseAddr("255.255.255.255"), 0xffffffff}, + {netip.MustParseAddr("8.8.8.8"), 0x08080808}, + {netip.MustParseAddr("192.168.0.1"), 0xc0a80001}, + {netip.MustParseAddr("10.0.0.1"), 0x0a000001}, + {netip.MustParseAddr("172.16.0.1"), 0xac100001}, + {netip.MustParseAddr("100.64.0.1"), 0x64400001}, + } + + for _, tc := range cases { + num := v4ToNum(tc.addr) + if num != tc.num { + t.Errorf("addrNum(%v) = %d, want %d", tc.addr, num, tc.num) + } + if numToV4(num) != tc.addr { + t.Errorf("numToV4(%d) = %v, want %v", num, numToV4(num), tc.addr) + } + } + + func() { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic") + } + }() + + v4ToNum(netip.MustParseAddr("::1")) + }() +} + +func TestAddrIndex(t *testing.T) { + builder := netipx.IPSetBuilder{} + builder.AddRange(netipx.MustParseIPRange("10.0.0.1-10.0.0.5")) + builder.AddRange(netipx.MustParseIPRange("192.168.0.1-192.168.0.10")) + ipset := must.Get(builder.IPSet()) + + indexCases := []struct { + addr netip.Addr + index int + }{ + {netip.MustParseAddr("10.0.0.1"), 0}, + {netip.MustParseAddr("10.0.0.2"), 1}, + {netip.MustParseAddr("10.0.0.3"), 2}, + {netip.MustParseAddr("10.0.0.4"), 3}, + {netip.MustParseAddr("10.0.0.5"), 4}, + {netip.MustParseAddr("192.168.0.1"), 5}, + {netip.MustParseAddr("192.168.0.5"), 9}, + {netip.MustParseAddr("192.168.0.10"), 14}, + {netip.MustParseAddr("172.16.0.1"), -1}, // Not in set + } + + for _, tc := range indexCases { + index := indexOfAddr(tc.addr, ipset) + if index != tc.index { + t.Errorf("indexOfAddr(%v) = %d, want %d", tc.addr, index, tc.index) + } + if tc.index == -1 { + continue + } + addr := addrAtIndex(tc.index, ipset) + if addr != tc.addr { + t.Errorf("addrAtIndex(%d) = %v, want %v", tc.index, addr, tc.addr) + } + } +} + +func TestAllocAddr(t *testing.T) { + builder := netipx.IPSetBuilder{} + builder.AddRange(netipx.MustParseIPRange("10.0.0.1-10.0.0.5")) + builder.AddRange(netipx.MustParseIPRange("192.168.0.1-192.168.0.10")) + ipset := must.Get(builder.IPSet()) + + allocated := new(big.Int) + for range 15 { + addr := allocAddr(ipset, allocated) + if !addr.IsValid() { + t.Errorf("allocAddr() = invalid, want valid") + } + if !ipset.Contains(addr) { + t.Errorf("allocAddr() = %v, not in set", addr) + } + } + addr := allocAddr(ipset, allocated) + if addr.IsValid() { + t.Errorf("allocAddr() = %v, want invalid", addr) + } + wantAddr := netip.MustParseAddr("10.0.0.2") + allocated.SetBit(allocated, indexOfAddr(wantAddr, ipset), 0) + addr = allocAddr(ipset, allocated) + if addr != wantAddr { + t.Errorf("allocAddr() = %v, want %v", addr, wantAddr) + } +} + +func TestLeastZeroBit(t *testing.T) { + cases := []struct { + num uint + want int + }{ + {math.MaxUint, -1}, + {0, 0}, + {0b01, 1}, + {0b11, 2}, + {0b111, 3}, + {math.MaxUint, -1}, + {math.MaxUint - 1, 0}, + } + if math.MaxUint == math.MaxUint64 { + cases = append(cases, []struct { + num uint + want int + }{ + {math.MaxUint >> 1, 63}, + }...) + } else { + cases = append(cases, []struct { + num uint + want int + }{ + {math.MaxUint >> 1, 31}, + }...) + } + + for _, tc := range cases { + got := leastZeroBit(tc.num) + if got != tc.want { + t.Errorf("leastZeroBit(%b) = %d, want %d", tc.num, got, tc.want) + } + } +} diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index d94523c6e..a4f53d657 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -8,8 +8,9 @@ package main import ( "context" - "encoding/binary" + "encoding/json" "errors" + "expvar" "flag" "fmt" "log" @@ -18,25 +19,28 @@ import ( "net/http" "net/netip" "os" + "path/filepath" "strings" - "sync" "time" "github.com/gaissmai/bart" + "github.com/hashicorp/raft" "github.com/inetaf/tcpproxy" "github.com/peterbourgon/ff/v3" + "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/cmd/natc/ippool" "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/net/netutil" - "tailscale.com/syncs" - "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" - "tailscale.com/util/dnsname" "tailscale.com/util/mak" + "tailscale.com/util/must" + "tailscale.com/wgengine/netstack" ) func main() { @@ -48,14 +52,20 @@ func main() { // Parse flags fs := flag.NewFlagSet("natc", flag.ExitOnError) var ( - debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") - hostname = fs.String("hostname", "", "Hostname to register the service under") - siteID = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration") - v4PfxStr = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise") - verboseTSNet = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet") - printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit") - ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore") - wgPort = fs.Uint("wg-port", 0, "udp port for wireguard and peer to peer traffic") + debugPort = fs.Int("debug-port", 8893, "Listening port for debug/metrics endpoint") + hostname = fs.String("hostname", "", "Hostname to register the service under") + siteID = fs.Uint("site-id", 1, "an integer site ID to use for the ULA prefix which allows for multiple proxies to act in a HA configuration") + v4PfxStr = fs.String("v4-pfx", "100.64.1.0/24", "comma-separated list of IPv4 prefixes to advertise") + dnsServers = fs.String("dns-servers", "", "comma separated list of upstream DNS to use, including host and port (use system if empty)") + verboseTSNet = fs.Bool("verbose-tsnet", false, "enable verbose logging in tsnet") + printULA = fs.Bool("print-ula", false, "print the ULA prefix and exit") + ignoreDstPfxStr = fs.String("ignore-destinations", "", "comma-separated list of prefixes to ignore") + wgPort = fs.Uint("wg-port", 0, "udp port for wireguard and peer to peer traffic") + clusterTag = fs.String("cluster-tag", "", "optionally run in a consensus cluster with other nodes with this tag") + server = fs.String("login-server", ipn.DefaultControlURL, "the base URL of control server") + stateDir = fs.String("state-dir", "", "path to directory in which to store app state") + clusterFollowOnly = fs.Bool("follow-only", false, "Try to find a leader with the cluster tag or exit.") + clusterAdminPort = fs.Int("cluster-admin-port", 8081, "Port on localhost for the cluster admin HTTP API") ) ff.Parse(fs, os.Args[1:], ff.WithEnvVarPrefix("TS_NATC")) @@ -73,7 +83,7 @@ func main() { } var ignoreDstTable *bart.Table[bool] - for _, s := range strings.Split(*ignoreDstPfxStr, ",") { + for s := range strings.SplitSeq(*ignoreDstPfxStr, ",") { s := strings.TrimSpace(s) if s == "" { continue @@ -90,21 +100,11 @@ func main() { } ignoreDstTable.Insert(pfx, true) } - var v4Prefixes []netip.Prefix - for _, s := range strings.Split(*v4PfxStr, ",") { - p := netip.MustParsePrefix(strings.TrimSpace(s)) - if p.Masked() != p { - log.Fatalf("v4 prefix %v is not a masked prefix", p) - } - v4Prefixes = append(v4Prefixes, p) - } - if len(v4Prefixes) == 0 { - log.Fatalf("no v4 prefixes specified") - } - dnsAddr := v4Prefixes[0].Addr() ts := &tsnet.Server{ Hostname: *hostname, + Dir: *stateDir, } + ts.ControlURL = *server if *wgPort != 0 { if *wgPort >= 1<<16 { log.Fatalf("wg-port must be in the range [0, 65535]") @@ -112,6 +112,7 @@ func main() { ts.Port = uint16(*wgPort) } defer ts.Close() + if *verboseTSNet { ts.Logf = log.Printf } @@ -129,6 +130,16 @@ func main() { log.Fatalf("debug serve: %v", http.Serve(dln, mux)) }() } + + if err := ts.Start(); err != nil { + log.Fatalf("ts.Start: %v", err) + } + // TODO(raggi): this is not a public interface or guarantee. + ns := ts.Sys().Netstack.Get().(*netstack.Impl) + if *debugPort != 0 { + expvar.Publish("netstack", ns.ExpVar()) + } + lc, err := ts.LocalClient() if err != nil { log.Fatalf("LocalClient() failed: %v", err) @@ -137,36 +148,127 @@ func main() { log.Fatalf("ts.Up: %v", err) } + var prefixes []netip.Prefix + for _, s := range strings.Split(*v4PfxStr, ",") { + p := netip.MustParsePrefix(strings.TrimSpace(s)) + if p.Masked() != p { + log.Fatalf("v4 prefix %v is not a masked prefix", p) + } + prefixes = append(prefixes, p) + } + routes, dnsAddr, addrPool := calculateAddresses(prefixes) + + v6ULA := ula(uint16(*siteID)) + + var ipp ippool.IPPool + if *clusterTag != "" { + cipp := ippool.NewConsensusIPPool(addrPool) + clusterStateDir, err := getClusterStatePath(*stateDir) + if err != nil { + log.Fatalf("Creating cluster state dir failed: %v", err) + } + err = cipp.StartConsensus(ctx, ts, ippool.ClusterOpts{ + Tag: *clusterTag, + StateDir: clusterStateDir, + FollowOnly: *clusterFollowOnly, + }) + if err != nil { + log.Fatalf("StartConsensus: %v", err) + } + defer func() { + err := cipp.StopConsensus(ctx) + if err != nil { + log.Printf("Error stopping consensus: %v", err) + } + }() + ipp = cipp + + go func() { + // This listens on localhost only, so that only those with access to the host machine + // can remove servers from the cluster config. + log.Print(http.ListenAndServe(fmt.Sprintf("127.0.0.1:%d", *clusterAdminPort), httpClusterAdmin(cipp))) + }() + } else { + ipp = &ippool.SingleMachineIPPool{IPSet: addrPool} + } + c := &connector{ ts: ts, - lc: lc, - dnsAddr: dnsAddr, - v4Ranges: v4Prefixes, - v6ULA: ula(uint16(*siteID)), + whois: lc, + v6ULA: v6ULA, ignoreDsts: ignoreDstTable, + ipPool: ipp, + routes: routes, + dnsAddr: dnsAddr, + resolver: getResolver(*dnsServers), } - c.run(ctx) + c.run(ctx, lc) +} + +// getResolver parses serverFlag and returns either the default resolver, or a +// resolver that uses the provided comma-separated DNS server AddrPort's, or +// panics. +func getResolver(serverFlag string) lookupNetIPer { + if serverFlag == "" { + return net.DefaultResolver + } + var addrs []string + for s := range strings.SplitSeq(serverFlag, ",") { + s = strings.TrimSpace(s) + addr, err := netip.ParseAddrPort(s) + if err != nil { + log.Fatalf("dns server provided: %q does not parse: %v", s, err) + } + addrs = append(addrs, addr.String()) + } + return &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network string, address string) (net.Conn, error) { + var dialer net.Dialer + // TODO(raggi): perhaps something other than random? + return dialer.DialContext(ctx, network, addrs[rand.N(len(addrs))]) + }, + } +} + +func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { + var ipsb netipx.IPSetBuilder + for _, p := range prefixes { + ipsb.AddPrefix(p) + } + routesToAdvertise := must.Get(ipsb.IPSet()) + dnsAddr := routesToAdvertise.Ranges()[0].From() + ipsb.Remove(dnsAddr) + addrPool := must.Get(ipsb.IPSet()) + return routesToAdvertise, dnsAddr, addrPool +} + +type lookupNetIPer interface { + LookupNetIP(ctx context.Context, net, host string) ([]netip.Addr, error) +} + +type whoiser interface { + WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) } type connector struct { // ts is the tsnet.Server used to host the connector. ts *tsnet.Server - // lc is the LocalClient used to interact with the tsnet.Server hosting this + // whois is the local.Client used to interact with the tsnet.Server hosting this // connector. - lc *tailscale.LocalClient + whois whoiser // dnsAddr is the IPv4 address to listen on for DNS requests. It is used to // prevent the app connector from assigning it to a domain. dnsAddr netip.Addr - // v4Ranges is the list of IPv4 ranges to advertise and assign addresses from. - // These are masked prefixes. - v4Ranges []netip.Prefix + // routes is the set of IPv4 ranges advertised to the tailnet, or ipset with + // the dnsAddr removed. + routes *netipx.IPSet + // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. v6ULA netip.Prefix - perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] - // ignoreDsts is initialized at start up with the contents of --ignore-destinations (if none it is nil) // It is never mutated, only used for lookups. // Users who want to natc a DNS wildcard but not every address record in that domain can supply the @@ -175,6 +277,12 @@ type connector struct { // return a dns response that contains the ip addresses we discovered with the lookup (ie not the // natc behavior, which would return a dummy ip address pointing at natc). ignoreDsts *bart.Table[bool] + + // ipPool contains the per-peer IPv4 address assignments. + ipPool ippool.IPPool + + // resolver is used to lookup IP addresses for DNS queries. + resolver lookupNetIPer } // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. @@ -194,11 +302,11 @@ func ula(siteID uint16) netip.Prefix { // // The passed in context is only used for the initial setup. The connector runs // forever. -func (c *connector) run(ctx context.Context) { - if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ +func (c *connector) run(ctx context.Context, lc *local.Client) { + if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{ AdvertiseRoutesSet: true, Prefs: ipn.Prefs{ - AdvertiseRoutes: append(c.v4Ranges, c.v6ULA), + AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA), }, }); err != nil { log.Fatalf("failed to advertise routes: %v", err) @@ -228,26 +336,6 @@ func (c *connector) serveDNS() { } } -func lookupDestinationIP(domain string) ([]netip.Addr, error) { - netIPs, err := net.LookupIP(domain) - if err != nil { - var dnsError *net.DNSError - if errors.As(err, &dnsError) && dnsError.IsNotFound { - return nil, nil - } else { - return nil, err - } - } - var addrs []netip.Addr - for _, ip := range netIPs { - a, ok := netip.AddrFromSlice(ip) - if ok { - addrs = append(addrs, a) - } - } - return addrs, nil -} - // handleDNS handles a DNS request to the app connector. // It generates a response based on the request and the node that sent it. // @@ -262,157 +350,161 @@ func lookupDestinationIP(domain string) ([]netip.Addr, error) { func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - who, err := c.lc.WhoIs(ctx, remoteAddr.String()) + who, err := c.whois.WhoIs(ctx, remoteAddr.String()) if err != nil { - log.Printf("HandleDNS: WhoIs failed: %v\n", err) + log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err) return } var msg dnsmessage.Message err = msg.Unpack(buf) if err != nil { - log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) + log.Printf("HandleDNS(remote=%s): dnsmessage unpack failed: %v\n", remoteAddr.String(), err) return } - // If there are destination ips that we don't want to route, we - // have to do a dns lookup here to find the destination ip. - if c.ignoreDsts != nil { - if len(msg.Questions) > 0 { - q := msg.Questions[0] - switch q.Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - dstAddrs, err := lookupDestinationIP(q.Name.String()) + var resolves map[string][]netip.Addr + var addrQCount int + for _, q := range msg.Questions { + if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { + continue + } + addrQCount++ + if _, ok := resolves[q.Name.String()]; !ok { + addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String()) + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) && dnsErr.IsNotFound { + continue + } + if err != nil { + log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) + return + } + // Note: If _any_ destination is ignored, pass through all of the resolved + // addresses as-is. + // + // This could result in some odd split-routing if there was a mix of + // ignored and non-ignored addresses, but it's currently the user + // preferred behavior. + if !c.ignoreDestination(addrs) { + addr, err := c.ipPool.IPForDomain(who.Node.ID, q.Name.String()) if err != nil { - log.Printf("HandleDNS: lookup destination failed: %v\n ", err) - return - } - if c.ignoreDestination(dstAddrs) { - bs, err := dnsResponse(&msg, dstAddrs) - // TODO (fran): treat as SERVFAIL - if err != nil { - log.Printf("HandleDNS: generate ignore response failed: %v\n", err) - return - } - _, err = pc.WriteTo(bs, remoteAddr) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - } + log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) return } + addrs = []netip.Addr{addr, v6ForV4(c.v6ULA.Addr(), addr)} } + mak.Set(&resolves, q.Name.String(), addrs) } } - // None of the destination IP addresses match an ignore destination prefix, do - // the natc thing. - resp, err := c.generateDNSResponse(&msg, who.Node.ID) - // TODO (fran): treat as SERVFAIL - if err != nil { - log.Printf("HandleDNS: connector handling failed: %v\n", err) - return - } - // TODO (fran): treat as NXDOMAIN - if len(resp) == 0 { - return - } - // This connector handled the DNS request - _, err = pc.WriteTo(resp, remoteAddr) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - } -} - -// tsMBox is the mailbox used in SOA records. -// The convention is to replace the @ symbol with a dot. -// So in this case, the mailbox is support.tailscale.com. with the trailing dot -// to indicate that it is a fully qualified domain name. -var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") - -// generateDNSResponse generates a DNS response for the given request. The from -// argument is the NodeID of the node that sent the request. -func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) { - pm, _ := c.perPeerMap.LoadOrStore(from, &perPeerState{c: c}) - var addrs []netip.Addr - if len(req.Questions) > 0 { - switch req.Questions[0].Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - var err error - addrs, err = pm.ipForDomain(req.Questions[0].Name.String()) - if err != nil { - return nil, err - } - } + rcode := dnsmessage.RCodeSuccess + if addrQCount > 0 && len(resolves) == 0 { + rcode = dnsmessage.RCodeNameError } - return dnsResponse(req, addrs) -} -// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA -// or TypeA the provided addrs of the requested type will be used. -func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) { b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - ID: req.Header.ID, + ID: msg.Header.ID, Response: true, Authoritative: true, + RCode: rcode, }) b.EnableCompression() - if len(req.Questions) == 0 { - return b.Finish() - } - q := req.Questions[0] if err := b.StartQuestions(); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage start questions failed: %v\n", remoteAddr.String(), err) + return } - if err := b.Question(q); err != nil { - return nil, err + + for _, q := range msg.Questions { + b.Question(q) } + if err := b.StartAnswers(); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage start answers failed: %v\n", remoteAddr.String(), err) + return } - switch q.Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - want6 := q.Type == dnsmessage.TypeAAAA - for _, ip := range addrs { - if want6 != ip.Is6() { - continue + + for _, q := range msg.Questions { + switch q.Type { + case dnsmessage.TypeSOA: + if err := b.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage SOA resource failed: %v\n", remoteAddr.String(), err) + return + } + case dnsmessage.TypeNS: + if err := b.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage NS resource failed: %v\n", remoteAddr.String(), err) + return } - if want6 { + case dnsmessage.TypeAAAA: + for _, addr := range resolves[q.Name.String()] { + if !addr.Is6() { + continue + } if err := b.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: addr.As16()}, ); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage AAAA resource failed: %v\n", remoteAddr.String(), err) + return + } + } + case dnsmessage.TypeA: + for _, addr := range resolves[q.Name.String()] { + if !addr.Is4() { + continue } - } else { if err := b.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5}, - dnsmessage.AResource{A: ip.As4()}, + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: addr.As4()}, ); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage A resource failed: %v\n", remoteAddr.String(), err) + return } } } - case dnsmessage.TypeSOA: - if err := b.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ); err != nil { - return nil, err - } - case dnsmessage.TypeNS: - if err := b.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ); err != nil { - return nil, err - } } - return b.Finish() + + out, err := b.Finish() + if err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage finish failed: %v\n", remoteAddr.String(), err) + return + } + _, err = pc.WriteTo(out, remoteAddr) + if err != nil { + log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err) + } +} + +func v6ForV4(ula netip.Addr, v4 netip.Addr) netip.Addr { + as16 := ula.As16() + as4 := v4.As4() + copy(as16[12:], as4[:]) + return netip.AddrFrom16(as16) } +func v4ForV6(v6 netip.Addr) netip.Addr { + as16 := v6.As16() + var as4 [4]byte + copy(as4[:], as16[12:]) + return netip.AddrFrom4(as4) +} + +// tsMBox is the mailbox used in SOA records. +// The convention is to replace the @ symbol with a dot. +// So in this case, the mailbox is support.tailscale.com. with the trailing dot +// to indicate that it is a fully qualified domain name. +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + // handleTCPFlow handles a TCP flow from the given source to the given // destination. It uses the source address to determine the node that sent the // request and the destination address to determine the domain that the request @@ -421,32 +513,31 @@ func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) { func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - who, err := c.lc.WhoIs(ctx, src.Addr().String()) + who, err := c.whois.WhoIs(ctx, src.Addr().String()) cancel() if err != nil { log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err) return nil, false } - - from := who.Node.ID - ps, ok := c.perPeerMap.Load(from) - if !ok { - log.Printf("handleTCPFlow: no perPeerState for %v", from) - return nil, false + dstAddr := dst.Addr() + if dstAddr.Is6() { + dstAddr = v4ForV6(dstAddr) } - domain, ok := ps.domainForIP(dst.Addr()) + domain, ok := c.ipPool.DomainForIP(who.Node.ID, dstAddr, time.Now()) if !ok { - log.Printf("handleTCPFlow: no domain for IP %v\n", dst.Addr()) return nil, false } return func(conn net.Conn) { - proxyTCPConn(conn, domain) + proxyTCPConn(conn, domain, c) }, true } // ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured // in --ignore-destinations func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { + if c.ignoreDsts == nil { + return false + } for _, a := range dstAddrs { if _, ok := c.ignoreDsts.Lookup(a); ok { return true @@ -455,16 +546,34 @@ func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { return false } -func proxyTCPConn(c net.Conn, dest string) { +func proxyTCPConn(c net.Conn, dest string, ctor *connector) { if c.RemoteAddr() == nil { log.Printf("proxyTCPConn: nil RemoteAddr") c.Close() return } - addrPortStr := c.LocalAddr().String() - _, port, err := net.SplitHostPort(addrPortStr) + laddr, err := netip.ParseAddrPort(c.LocalAddr().String()) if err != nil { - log.Printf("tcpRoundRobinHandler.Handle: bogus addrPort %q", addrPortStr) + log.Printf("proxyTCPConn: ParseAddrPort failed: %v", err) + c.Close() + return + } + + daddrs, err := ctor.resolver.LookupNetIP(context.TODO(), "ip", dest) + if err != nil { + log.Printf("proxyTCPConn: LookupNetIP failed: %v", err) + c.Close() + return + } + + if len(daddrs) == 0 { + log.Printf("proxyTCPConn: no IP addresses found for %s", dest) + c.Close() + return + } + + if ctor.ignoreDestination(daddrs) { + log.Printf("proxyTCPConn: closing connection to ignored destination %s (%v)", dest, daddrs) c.Close() return } @@ -474,102 +583,91 @@ func proxyTCPConn(c net.Conn, dest string) { return netutil.NewOneConnListener(c, nil), nil }, } - p.AddRoute(addrPortStr, &tcpproxy.DialProxy{ - Addr: fmt.Sprintf("%s:%s", dest, port), + + // TODO(raggi): more code could avoid this shuffle, but avoiding allocations + // for now most of the time daddrs will be short. + rand.Shuffle(len(daddrs), func(i, j int) { + daddrs[i], daddrs[j] = daddrs[j], daddrs[i] }) - p.Start() -} + daddr := daddrs[0] + + // Try to match the upstream and downstream protocols (v4/v6) + if laddr.Addr().Is6() { + for _, addr := range daddrs { + if addr.Is6() { + daddr = addr + break + } + } + } else { + for _, addr := range daddrs { + if addr.Is4() { + daddr = addr + break + } + } + } -// perPeerState holds the state for a single peer. -type perPeerState struct { - c *connector + // TODO(raggi): drop this library, it ends up being allocation and + // indirection heavy and really doesn't help us here. + dsockaddrs := netip.AddrPortFrom(daddr, laddr.Port()).String() + p.AddRoute(dsockaddrs, &tcpproxy.DialProxy{ + Addr: dsockaddrs, + }) - mu sync.Mutex - domainToAddr map[string][]netip.Addr - addrToDomain *bart.Table[string] + p.Start() } -// domainForIP returns the domain name assigned to the given IP address and -// whether it was found. -func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { - ps.mu.Lock() - defer ps.mu.Unlock() - if ps.addrToDomain == nil { - return "", false +func getClusterStatePath(stateDirFlag string) (string, error) { + var dirPath string + if stateDirFlag != "" { + dirPath = stateDirFlag + } else { + confDir, err := os.UserConfigDir() + if err != nil { + return "", err + } + dirPath = filepath.Join(confDir, "nat-connector-state") } - return ps.addrToDomain.Lookup(ip) -} + dirPath = filepath.Join(dirPath, "cluster") -// ipForDomain assigns a pair of unique IP addresses for the given domain and -// returns them. The first address is an IPv4 address and the second is an IPv6 -// address. If the domain already has assigned addresses, it returns them. -func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) { - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, err + if err := os.MkdirAll(dirPath, 0700); err != nil { + return "", err } - domain = fqdn.WithoutTrailingDot() - - ps.mu.Lock() - defer ps.mu.Unlock() - if addrs, ok := ps.domainToAddr[domain]; ok { - return addrs, nil + if fi, err := os.Stat(dirPath); err != nil { + return "", err + } else if !fi.IsDir() { + return "", fmt.Errorf("%v is not a directory", dirPath) } - addrs := ps.assignAddrsLocked(domain) - return addrs, nil -} -// isIPUsedLocked reports whether the given IP address is already assigned to a -// domain. -// ps.mu must be held. -func (ps *perPeerState) isIPUsedLocked(ip netip.Addr) bool { - _, ok := ps.addrToDomain.Lookup(ip) - return ok + return dirPath, nil } -// unusedIPv4Locked returns an unused IPv4 address from the available ranges. -func (ps *perPeerState) unusedIPv4Locked() netip.Addr { - // TODO: skip ranges that have been exhausted - for _, r := range ps.c.v4Ranges { - ip := randV4(r) - for r.Contains(ip) { - if !ps.isIPUsedLocked(ip) && ip != ps.c.dnsAddr { - return ip - } - ip = ip.Next() +func httpClusterAdmin(ipp *ippool.ConsensusIPPool) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("GET /{$}", func(w http.ResponseWriter, r *http.Request) { + c, err := ipp.GetClusterConfiguration() + if err != nil { + log.Printf("cluster admin http: error getClusterConfig: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return } - } - return netip.Addr{} -} - -// randV4 returns a random IPv4 address within the given prefix. -func randV4(maskedPfx netip.Prefix) netip.Addr { - bits := 32 - maskedPfx.Bits() - randBits := rand.Uint32N(1 << uint(bits)) - - ip4 := maskedPfx.Addr().As4() - pn := binary.BigEndian.Uint32(ip4[:]) - binary.BigEndian.PutUint32(ip4[:], randBits|pn) - return netip.AddrFrom4(ip4) -} - -// assignAddrsLocked assigns a pair of unique IP addresses for the given domain -// and returns them. The first address is an IPv4 address and the second is an -// IPv6 address. It does not check if the domain already has assigned addresses. -// ps.mu must be held. -func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { - if ps.addrToDomain == nil { - ps.addrToDomain = &bart.Table[string]{} - } - v4 := ps.unusedIPv4Locked() - as16 := ps.c.v6ULA.Addr().As16() - as4 := v4.As4() - copy(as16[12:], as4[:]) - v6 := netip.AddrFrom16(as16) - addrs := []netip.Addr{v4, v6} - mak.Set(&ps.domainToAddr, domain, addrs) - for _, a := range addrs { - ps.addrToDomain.Insert(netip.PrefixFrom(a, a.BitLen()), domain) - } - return addrs + if err := json.NewEncoder(w).Encode(c); err != nil { + log.Printf("cluster admin http: error encoding raft configuration: %v", err) + } + }) + mux.HandleFunc("DELETE /{id}", func(w http.ResponseWriter, r *http.Request) { + idString := r.PathValue("id") + id := raft.ServerID(idString) + idx, err := ipp.DeleteClusterServer(id) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(idx); err != nil { + log.Printf("cluster admin http: error encoding delete index: %v", err) + return + } + }) + return mux } diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go new file mode 100644 index 000000000..c0a66deb8 --- /dev/null +++ b/cmd/natc/natc_test.go @@ -0,0 +1,678 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "fmt" + "io" + "net" + "net/netip" + "sync" + "testing" + "time" + + "github.com/gaissmai/bart" + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/cmd/natc/ippool" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +func prefixEqual(a, b netip.Prefix) bool { + return a.Bits() == b.Bits() && a.Addr() == b.Addr() +} + +func TestULA(t *testing.T) { + tests := []struct { + name string + siteID uint16 + expected string + }{ + {"zero", 0, "fd7a:115c:a1e0:a99c:0000::/80"}, + {"one", 1, "fd7a:115c:a1e0:a99c:0001::/80"}, + {"max", 65535, "fd7a:115c:a1e0:a99c:ffff::/80"}, + {"random", 12345, "fd7a:115c:a1e0:a99c:3039::/80"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ula(tc.siteID) + expected := netip.MustParsePrefix(tc.expected) + if !prefixEqual(got, expected) { + t.Errorf("ula(%d) = %s; want %s", tc.siteID, got, expected) + } + }) + } +} + +type recordingPacketConn struct { + writes [][]byte +} + +func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + w.writes = append(w.writes, b) + return len(b), nil +} + +func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + return 0, nil, io.EOF +} + +func (w *recordingPacketConn) Close() error { + return nil +} + +func (w *recordingPacketConn) LocalAddr() net.Addr { + return nil +} + +func (w *recordingPacketConn) RemoteAddr() net.Addr { + return nil +} + +func (w *recordingPacketConn) SetDeadline(t time.Time) error { + return nil +} + +func (w *recordingPacketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type resolver struct { + resolves map[string][]netip.Addr + fails map[string]bool +} + +func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) { + if addrs, ok := r.resolves[host]; ok { + return addrs, nil + } + if _, ok := r.fails[host]; ok { + return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true} + } + return nil, &net.DNSError{IsNotFound: true, Name: host} +} + +type whois struct { + peers map[string]*apitype.WhoIsResponse +} + +func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + addr := netip.MustParseAddrPort(remoteAddr).Addr().String() + if peer, ok := w.peers[addr]; ok { + return peer, nil + } + return nil, fmt.Errorf("peer not found") +} + +func TestDNSResponse(t *testing.T) { + tests := []struct { + name string + questions []dnsmessage.Question + wantEmpty bool + wantAnswers []struct { + name string + qType dnsmessage.Type + addr netip.Addr + } + wantNXDOMAIN bool + wantIgnored bool + }{ + { + name: "empty_request", + questions: []dnsmessage.Question{}, + wantEmpty: false, + wantAnswers: nil, + }, + { + name: "a_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "example.com.", + qType: dnsmessage.TypeA, + addr: netip.MustParseAddr("100.64.0.0"), + }, + }, + }, + { + name: "aaaa_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "example.com.", + qType: dnsmessage.TypeAAAA, + addr: netip.MustParseAddr("fd7a:115c:a1e0::"), + }, + }, + }, + { + name: "soa_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeSOA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: nil, + }, + { + name: "ns_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeNS, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: nil, + }, + { + name: "nxdomain", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("noexist.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantNXDOMAIN: true, + }, + { + name: "servfail", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("fail.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantEmpty: true, // TODO: pass through instead? + }, + { + name: "ignored", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("ignore.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "ignore.example.com.", + qType: dnsmessage.TypeA, + addr: netip.MustParseAddr("8.8.4.4"), + }, + }, + wantIgnored: true, + }, + } + + var rpc recordingPacketConn + remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345")) + + routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")}) + v6ULA := ula(1) + c := connector{ + resolver: &resolver{ + resolves: map[string][]netip.Addr{ + "example.com.": { + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("2001:4860:4860::8888"), + }, + "ignore.example.com.": { + netip.MustParseAddr("8.8.4.4"), + }, + }, + fails: map[string]bool{ + "fail.example.com.": true, + }, + }, + whois: &whois{ + peers: map[string]*apitype.WhoIsResponse{ + "100.64.254.1": { + Node: &tailcfg.Node{ID: 123}, + }, + }, + }, + ignoreDsts: &bart.Table[bool]{}, + routes: routes, + v6ULA: v6ULA, + ipPool: &ippool.SingleMachineIPPool{IPSet: addrPool}, + dnsAddr: dnsAddr, + } + c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rb := dnsmessage.NewBuilder(nil, + dnsmessage.Header{ + ID: 1234, + }, + ) + must.Do(rb.StartQuestions()) + for _, q := range tc.questions { + rb.Question(q) + } + + c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr) + + writes := rpc.writes + rpc.writes = rpc.writes[:0] + + if tc.wantEmpty { + if len(writes) != 0 { + t.Errorf("handleDNS() returned non-empty response when expected empty") + } + return + } + + if !tc.wantEmpty && len(writes) != 1 { + t.Fatalf("handleDNS() returned an unexpected number of responses: %d, want 1", len(writes)) + } + + resp := writes[0] + var msg dnsmessage.Message + err := msg.Unpack(resp) + if err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if !msg.Header.Response { + t.Errorf("Response header is not set") + } + + if msg.Header.ID != 1234 { + t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234) + } + + if len(tc.wantAnswers) > 0 { + if len(msg.Answers) != len(tc.wantAnswers) { + t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString()) + } else { + for i, want := range tc.wantAnswers { + ans := msg.Answers[i] + + gotName := ans.Header.Name.String() + if gotName != want.name { + t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name) + } + + if ans.Header.Type != want.qType { + t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType) + } + + switch want.qType { + case dnsmessage.TypeA: + if ans.Body.(*dnsmessage.AResource) == nil { + t.Errorf("answer[%d] not an A record", i) + continue + } + case dnsmessage.TypeAAAA: + if ans.Body.(*dnsmessage.AAAAResource) == nil { + t.Errorf("answer[%d] not an AAAA record", i) + continue + } + } + + var gotIP netip.Addr + switch want.qType { + case dnsmessage.TypeA: + resource := ans.Body.(*dnsmessage.AResource) + gotIP = netip.AddrFrom4([4]byte(resource.A)) + case dnsmessage.TypeAAAA: + resource := ans.Body.(*dnsmessage.AAAAResource) + gotIP = netip.AddrFrom16([16]byte(resource.AAAA)) + } + + var wantIP netip.Addr + if tc.wantIgnored { + var net string + var fxSelectIP func(netip.Addr) bool + switch want.qType { + case dnsmessage.TypeA: + net = "ip4" + fxSelectIP = func(a netip.Addr) bool { + return a.Is4() + } + case dnsmessage.TypeAAAA: + //TODO(fran) is this branch exercised? + net = "ip6" + fxSelectIP = func(a netip.Addr) bool { + return a.Is6() + } + } + ips := must.Get(c.resolver.LookupNetIP(t.Context(), net, want.name)) + for _, ip := range ips { + if fxSelectIP(ip) { + wantIP = ip + break + } + } + } else { + addr := must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name)) + switch want.qType { + case dnsmessage.TypeA: + wantIP = addr + case dnsmessage.TypeAAAA: + wantIP = v6ForV4(v6ULA.Addr(), addr) + } + } + if gotIP != wantIP { + t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP) + } + } + } + } + + if tc.wantNXDOMAIN { + if msg.RCode != dnsmessage.RCodeNameError { + t.Errorf("expected NXDOMAIN, got %v", msg.RCode) + } + if len(msg.Answers) != 0 { + t.Errorf("expected no answers, got %d", len(msg.Answers)) + } + } + }) + } +} + +func TestIgnoreDestination(t *testing.T) { + ignoreDstTable := &bart.Table[bool]{} + ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24"), true) + ignoreDstTable.Insert(netip.MustParsePrefix("10.0.0.0/8"), true) + + c := &connector{ + ignoreDsts: ignoreDstTable, + } + + tests := []struct { + name string + addrs []netip.Addr + expected bool + }{ + { + name: "no_match", + addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")}, + expected: false, + }, + { + name: "one_match", + addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("192.168.1.5")}, + expected: true, + }, + { + name: "all_match", + addrs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("192.168.1.5")}, + expected: true, + }, + { + name: "empty_addrs", + addrs: []netip.Addr{}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := c.ignoreDestination(tc.addrs) + if got != tc.expected { + t.Errorf("ignoreDestination(%v) = %v, want %v", tc.addrs, got, tc.expected) + } + }) + } +} + +func TestV6V4(t *testing.T) { + v6ULA := ula(1) + + tests := [][]string{ + {"100.64.0.0", "fd7a:115c:a1e0:a99c:1:0:6440:0"}, + {"0.0.0.0", "fd7a:115c:a1e0:a99c:1::"}, + {"255.255.255.255", "fd7a:115c:a1e0:a99c:1:0:ffff:ffff"}, + } + + for i, test := range tests { + // to v6 + v6 := v6ForV4(v6ULA.Addr(), netip.MustParseAddr(test[0])) + want := netip.MustParseAddr(test[1]) + if v6 != want { + t.Fatalf("test %d: want: %v, got: %v", i, want, v6) + } + + // to v4 + v4 := v4ForV6(netip.MustParseAddr(test[1])) + want = netip.MustParseAddr(test[0]) + if v4 != want { + t.Fatalf("test %d: want: %v, got: %v", i, want, v4) + } + } +} + +// echoServer is a simple server that just echos back data set to it. +type echoServer struct { + listener net.Listener + addr string + wg sync.WaitGroup + done chan struct{} +} + +// newEchoServer creates a new test DNS server on the specified network and address +func newEchoServer(t *testing.T, network, addr string) *echoServer { + listener, err := net.Listen(network, addr) + if err != nil { + t.Fatalf("Failed to create test DNS server: %v", err) + } + + server := &echoServer{ + listener: listener, + addr: listener.Addr().String(), + done: make(chan struct{}), + } + + server.wg.Add(1) + go server.serve() + + return server +} + +func (s *echoServer) serve() { + defer s.wg.Done() + + for { + select { + case <-s.done: + return + default: + conn, err := s.listener.Accept() + if err != nil { + select { + case <-s.done: + return + default: + continue + } + } + go s.handleConnection(conn) + } + } +} + +func (s *echoServer) handleConnection(conn net.Conn) { + defer conn.Close() + // Simple response - just echo back some data to confirm connectivity + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + return + } + conn.Write(buf[:n]) +} + +func (s *echoServer) close() { + close(s.done) + s.listener.Close() + s.wg.Wait() +} + +func TestGetResolver(t *testing.T) { + tests := []struct { + name string + network string + addr string + }{ + { + name: "ipv4_loopback", + network: "tcp4", + addr: "127.0.0.1:0", + }, + { + name: "ipv6_loopback", + network: "tcp6", + addr: "[::1]:0", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + server := newEchoServer(t, tc.network, tc.addr) + defer server.close() + serverAddr := server.addr + resolver := getResolver(serverAddr) + if resolver == nil { + t.Fatal("getResolver returned nil") + } + + netResolver, ok := resolver.(*net.Resolver) + if !ok { + t.Fatal("getResolver did not return a *net.Resolver") + } + if netResolver.Dial == nil { + t.Fatal("resolver.Dial is nil") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53") + if err != nil { + t.Fatalf("Failed to dial test DNS server: %v", err) + } + defer conn.Close() + + testData := []byte("test") + _, err = conn.Write(testData) + if err != nil { + t.Fatalf("Failed to write to connection: %v", err) + } + + response := make([]byte, len(testData)) + _, err = conn.Read(response) + if err != nil { + t.Fatalf("Failed to read from connection: %v", err) + } + + if string(response) != string(testData) { + t.Fatalf("Expected echo response %q, got %q", testData, response) + } + }) + } +} + +func TestGetResolverMultipleServers(t *testing.T) { + server1 := newEchoServer(t, "tcp4", "127.0.0.1:0") + defer server1.close() + server2 := newEchoServer(t, "tcp4", "127.0.0.1:0") + defer server2.close() + serverFlag := server1.addr + ", " + server2.addr + + resolver := getResolver(serverFlag) + netResolver, ok := resolver.(*net.Resolver) + if !ok { + t.Fatal("getResolver did not return a *net.Resolver") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + servers := map[string]bool{ + server1.addr: false, + server2.addr: false, + } + + // Try up to 1000 times to hit all servers, this should be very quick, and + // if this fails randomness has regressed beyond reason. + for range 1000 { + conn, err := netResolver.Dial(ctx, "tcp", "dummy.address:53") + if err != nil { + t.Fatalf("Failed to dial test DNS server: %v", err) + } + + remoteAddr := conn.RemoteAddr().String() + + conn.Close() + + servers[remoteAddr] = true + + var allDone = true + for _, done := range servers { + if !done { + allDone = false + break + } + } + if allDone { + break + } + } + + var allDone = true + for _, done := range servers { + if !done { + allDone = false + break + } + } + if !allDone { + t.Errorf("after 1000 queries, not all servers were hit, significant lack of randomness: %#v", servers) + } +} + +func TestGetResolverEmpty(t *testing.T) { + resolver := getResolver("") + if resolver != net.DefaultResolver { + t.Fatal(`getResolver("") should return net.DefaultResolver`) + } +} diff --git a/cmd/netlogfmt/main.go b/cmd/netlogfmt/main.go index 65e87098f..b8aba4aaa 100644 --- a/cmd/netlogfmt/main.go +++ b/cmd/netlogfmt/main.go @@ -44,25 +44,51 @@ import ( "github.com/dsnet/try" jsonv2 "github.com/go-json-experiment/json" "github.com/go-json-experiment/json/jsontext" + "tailscale.com/tailcfg" + "tailscale.com/types/bools" "tailscale.com/types/logid" "tailscale.com/types/netlogtype" "tailscale.com/util/must" ) var ( - resolveNames = flag.Bool("resolve-names", false, "convert tailscale IP addresses to hostnames; must also specify --api-key and --tailnet-id") - apiKey = flag.String("api-key", "", "API key to query the Tailscale API with; see https://login.tailscale.com/admin/settings/keys") - tailnetName = flag.String("tailnet-name", "", "tailnet domain name to lookup devices in; see https://login.tailscale.com/admin/settings/general") + resolveNames = flag.Bool("resolve-names", false, "This is equivalent to specifying \"--resolve-addrs=name\".") + resolveAddrs = flag.String("resolve-addrs", "", "Resolve each tailscale IP address as a node ID, name, or user.\n"+ + "If network flow logs do not support embedded node information,\n"+ + "then --api-key and --tailnet-name must also be provided.\n"+ + "Valid values include \"nodeId\", \"name\", or \"user\".") + apiKey = flag.String("api-key", "", "The API key to query the Tailscale API with.\nSee https://login.tailscale.com/admin/settings/keys") + tailnetName = flag.String("tailnet-name", "", "The Tailnet name to lookup nodes within.\nSee https://login.tailscale.com/admin/settings/general") ) -var namesByAddr map[netip.Addr]string +var ( + tailnetNodesByAddr map[netip.Addr]netlogtype.Node + tailnetNodesByID map[tailcfg.StableNodeID]netlogtype.Node +) func main() { flag.Parse() if *resolveNames { - namesByAddr = mustMakeNamesByAddr() + *resolveAddrs = "name" + } + *resolveAddrs = strings.ToLower(*resolveAddrs) // make case-insensitive + *resolveAddrs = strings.TrimSuffix(*resolveAddrs, "s") // allow plural form + *resolveAddrs = strings.ReplaceAll(*resolveAddrs, " ", "") // ignore spaces + *resolveAddrs = strings.ReplaceAll(*resolveAddrs, "-", "") // ignore dashes + *resolveAddrs = strings.ReplaceAll(*resolveAddrs, "_", "") // ignore underscores + switch *resolveAddrs { + case "id", "nodeid": + *resolveAddrs = "nodeid" + case "name", "hostname": + *resolveAddrs = "name" + case "user", "tag", "usertag", "taguser": + *resolveAddrs = "user" // tag resolution is implied + default: + log.Fatalf("--resolve-addrs must be \"nodeId\", \"name\", or \"user\"") } + mustLoadTailnetNodes() + // The logic handles a stream of arbitrary JSON. // So long as a JSON object seems like a network log message, // then this will unmarshal and print it. @@ -103,7 +129,7 @@ func processArray(dec *jsontext.Decoder) { func processObject(dec *jsontext.Decoder) { var hasTraffic bool - var rawMsg []byte + var rawMsg jsontext.Value try.E1(dec.ReadToken()) // parse '{' for dec.PeekKind() != '}' { // Capture any members that could belong to a network log message. @@ -111,13 +137,13 @@ func processObject(dec *jsontext.Decoder) { case "virtualTraffic", "subnetTraffic", "exitTraffic", "physicalTraffic": hasTraffic = true fallthrough - case "logtail", "nodeId", "logged", "start", "end": + case "logtail", "nodeId", "logged", "srcNode", "dstNodes", "start", "end": if len(rawMsg) == 0 { rawMsg = append(rawMsg, '{') } else { rawMsg = append(rawMsg[:len(rawMsg)-1], ',') } - rawMsg = append(append(append(rawMsg, '"'), name.String()...), '"') + rawMsg, _ = jsontext.AppendQuote(rawMsg, name.String()) rawMsg = append(rawMsg, ':') rawMsg = append(rawMsg, try.E1(dec.ReadValue())...) rawMsg = append(rawMsg, '}') @@ -145,6 +171,32 @@ type message struct { } func printMessage(msg message) { + var nodesByAddr map[netip.Addr]netlogtype.Node + var tailnetDNS string // e.g., ".acme-corp.ts.net" + if *resolveAddrs != "" { + nodesByAddr = make(map[netip.Addr]netlogtype.Node) + insertNode := func(node netlogtype.Node) { + for _, addr := range node.Addresses { + nodesByAddr[addr] = node + } + } + for _, node := range msg.DstNodes { + insertNode(node) + } + insertNode(msg.SrcNode) + + // Derive the Tailnet DNS of the self node. + detectTailnetDNS := func(nodeName string) { + if prefix, ok := strings.CutSuffix(nodeName, ".ts.net"); ok { + if i := strings.LastIndexByte(prefix, '.'); i > 0 { + tailnetDNS = nodeName[i:] + } + } + } + detectTailnetDNS(msg.SrcNode.Name) + detectTailnetDNS(tailnetNodesByID[msg.NodeID].Name) + } + // Construct a table of network traffic per connection. rows := [][7]string{{3: "Tx[P/s]", 4: "Tx[B/s]", 5: "Rx[P/s]", 6: "Rx[B/s]"}} duration := msg.End.Sub(msg.Start) @@ -175,16 +227,25 @@ func printMessage(msg message) { if !a.IsValid() { return "" } - if name, ok := namesByAddr[a.Addr()]; ok { - if a.Port() == 0 { - return name + name := a.Addr().String() + node, ok := tailnetNodesByAddr[a.Addr()] + if !ok { + node, ok = nodesByAddr[a.Addr()] + } + if ok { + switch *resolveAddrs { + case "nodeid": + name = cmp.Or(string(node.NodeID), name) + case "name": + name = cmp.Or(strings.TrimSuffix(string(node.Name), tailnetDNS), name) + case "user": + name = cmp.Or(bools.IfElse(len(node.Tags) > 0, fmt.Sprint(node.Tags), node.User), name) } - return name + ":" + strconv.Itoa(int(a.Port())) } - if a.Port() == 0 { - return a.Addr().String() + if a.Port() != 0 { + return name + ":" + strconv.Itoa(int(a.Port())) } - return a.String() + return name } for _, cc := range traffic { row := [7]string{ @@ -279,8 +340,10 @@ func printMessage(msg message) { } } -func mustMakeNamesByAddr() map[netip.Addr]string { +func mustLoadTailnetNodes() { switch { + case *apiKey == "" && *tailnetName == "": + return // rely on embedded node information in the logs themselves case *apiKey == "": log.Fatalf("--api-key must be specified with --resolve-names") case *tailnetName == "": @@ -300,57 +363,19 @@ func mustMakeNamesByAddr() map[netip.Addr]string { // Unmarshal the API response. var m struct { - Devices []struct { - Name string `json:"name"` - Addrs []netip.Addr `json:"addresses"` - } `json:"devices"` + Devices []netlogtype.Node `json:"devices"` } must.Do(json.Unmarshal(b, &m)) - // Construct a unique mapping of Tailscale IP addresses to hostnames. - // For brevity, we start with the first segment of the name and - // use more segments until we find the shortest prefix that is unique - // for all names in the tailnet. - seen := make(map[string]bool) - namesByAddr := make(map[netip.Addr]string) -retry: - for i := range 10 { - clear(seen) - clear(namesByAddr) - for _, d := range m.Devices { - name := fieldPrefix(d.Name, i) - if seen[name] { - continue retry - } - seen[name] = true - for _, a := range d.Addrs { - namesByAddr[a] = name - } - } - return namesByAddr - } - panic("unable to produce unique mapping of address to names") -} - -// fieldPrefix returns the first n number of dot-separated segments. -// -// Example: -// -// fieldPrefix("foo.bar.baz", 0) returns "" -// fieldPrefix("foo.bar.baz", 1) returns "foo" -// fieldPrefix("foo.bar.baz", 2) returns "foo.bar" -// fieldPrefix("foo.bar.baz", 3) returns "foo.bar.baz" -// fieldPrefix("foo.bar.baz", 4) returns "foo.bar.baz" -func fieldPrefix(s string, n int) string { - s0 := s - for i := 0; i < n && len(s) > 0; i++ { - if j := strings.IndexByte(s, '.'); j >= 0 { - s = s[j+1:] - } else { - s = "" + // Construct a mapping of Tailscale IP addresses to node information. + tailnetNodesByAddr = make(map[netip.Addr]netlogtype.Node) + tailnetNodesByID = make(map[tailcfg.StableNodeID]netlogtype.Node) + for _, node := range m.Devices { + for _, addr := range node.Addresses { + tailnetNodesByAddr[addr] = node } + tailnetNodesByID[node.NodeID] = node } - return strings.TrimSuffix(s0[:len(s0)-len(s)], ".") } func appendRepeatByte(b []byte, c byte, n int) []byte { diff --git a/cmd/omitsize/omitsize.go b/cmd/omitsize/omitsize.go new file mode 100644 index 000000000..35e03d268 --- /dev/null +++ b/cmd/omitsize/omitsize.go @@ -0,0 +1,229 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The omitsize tool prints out how large the Tailscale binaries are with +// different build tags. +package main + +import ( + "crypto/sha256" + "flag" + "fmt" + "log" + "maps" + "os" + "os/exec" + "path/filepath" + "slices" + "strconv" + "strings" + "sync" + + "tailscale.com/feature/featuretags" + "tailscale.com/util/set" +) + +var ( + cacheDir = flag.String("cachedir", "", "if non-empty, use this directory to store cached size results to speed up subsequent runs. The tool does not consider the git status when deciding whether to use the cache. It's on you to nuke it between runs if the tree changed.") + features = flag.String("features", "", "comma-separated list of features to list in the table, without the ts_omit_ prefix. It may also contain a '+' sign(s) for ANDing features together. If empty, all omittable features are considered one at a time.") + + showRemovals = flag.Bool("show-removals", false, "if true, show a table of sizes removing one feature at a time from the full set.") +) + +// allOmittable returns the list of all build tags that remove features. +var allOmittable = sync.OnceValue(func() []string { + var ret []string // all build tags that can be omitted + for k := range featuretags.Features { + if k.IsOmittable() { + ret = append(ret, k.OmitTag()) + } + } + slices.Sort(ret) + return ret +}) + +func main() { + flag.Parse() + + // rows is a set (usually of size 1) of feature(s) to add/remove, without deps + // included at this point (as dep direction depends on whether we're adding or removing, + // so it's expanded later) + var rows []set.Set[featuretags.FeatureTag] + + if *features == "" { + for _, k := range slices.Sorted(maps.Keys(featuretags.Features)) { + if k.IsOmittable() { + rows = append(rows, set.Of(k)) + } + } + } else { + for v := range strings.SplitSeq(*features, ",") { + s := set.Set[featuretags.FeatureTag]{} + for fts := range strings.SplitSeq(v, "+") { + ft := featuretags.FeatureTag(fts) + if _, ok := featuretags.Features[ft]; !ok { + log.Fatalf("unknown feature %q", v) + } + s.Add(ft) + } + rows = append(rows, s) + } + } + + minD := measure("tailscaled", allOmittable()...) + minC := measure("tailscale", allOmittable()...) + minBoth := measure("tailscaled", append(slices.Clone(allOmittable()), "ts_include_cli")...) + + if *showRemovals { + baseD := measure("tailscaled") + baseC := measure("tailscale") + baseBoth := measure("tailscaled", "ts_include_cli") + + fmt.Printf("Starting with everything and removing a feature...\n\n") + + fmt.Printf("%9s %9s %9s\n", "tailscaled", "tailscale", "combined (linux/amd64)") + fmt.Printf("%9d %9d %9d\n", baseD, baseC, baseBoth) + + fmt.Printf("-%8d -%8d -%8d .. remove *\n", baseD-minD, baseC-minC, baseBoth-minBoth) + + for _, s := range rows { + title, tags := computeRemove(s) + sizeD := measure("tailscaled", tags...) + sizeC := measure("tailscale", tags...) + sizeBoth := measure("tailscaled", append(slices.Clone(tags), "ts_include_cli")...) + saveD := max(baseD-sizeD, 0) + saveC := max(baseC-sizeC, 0) + saveBoth := max(baseBoth-sizeBoth, 0) + fmt.Printf("-%8d -%8d -%8d .. remove %s\n", saveD, saveC, saveBoth, title) + + } + } + + fmt.Printf("\nStarting at a minimal binary and adding one feature back...\n\n") + fmt.Printf("%9s %9s %9s\n", "tailscaled", "tailscale", "combined (linux/amd64)") + fmt.Printf("%9d %9d %9d omitting everything\n", minD, minC, minBoth) + for _, s := range rows { + title, tags := computeAdd(s) + sizeD := measure("tailscaled", tags...) + sizeC := measure("tailscale", tags...) + sizeBoth := measure("tailscaled", append(tags, "ts_include_cli")...) + + fmt.Printf("+%8d +%8d +%8d .. add %s\n", max(sizeD-minD, 0), max(sizeC-minC, 0), max(sizeBoth-minBoth, 0), title) + } + +} + +// computeAdd returns a human-readable title of a set of features and the build +// tags to use to add that set of features to a minimal binary, including their +// feature dependencies. +func computeAdd(s set.Set[featuretags.FeatureTag]) (title string, tags []string) { + allSet := set.Set[featuretags.FeatureTag]{} // s + all their outbound dependencies + var explicitSorted []string // string versions of s, sorted + for ft := range s { + allSet.AddSet(featuretags.Requires(ft)) + if ft.IsOmittable() { + explicitSorted = append(explicitSorted, string(ft)) + } + } + slices.Sort(explicitSorted) + + var removeTags []string + for ft := range allSet { + if ft.IsOmittable() { + removeTags = append(removeTags, ft.OmitTag()) + } + } + + var titleBuf strings.Builder + titleBuf.WriteString(strings.Join(explicitSorted, "+")) + var and []string + for ft := range allSet { + if !s.Contains(ft) { + and = append(and, string(ft)) + } + } + if len(and) > 0 { + slices.Sort(and) + fmt.Fprintf(&titleBuf, " (and %s)", strings.Join(and, "+")) + } + tags = allExcept(allOmittable(), removeTags) + return titleBuf.String(), tags +} + +// computeRemove returns a human-readable title of a set of features and the build +// tags to use to remove that set of features from a full binary, including removing +// any features that depend on features in the provided set. +func computeRemove(s set.Set[featuretags.FeatureTag]) (title string, tags []string) { + allSet := set.Set[featuretags.FeatureTag]{} // s + all their inbound dependencies + var explicitSorted []string // string versions of s, sorted + for ft := range s { + allSet.AddSet(featuretags.RequiredBy(ft)) + if ft.IsOmittable() { + explicitSorted = append(explicitSorted, string(ft)) + } + } + slices.Sort(explicitSorted) + + var removeTags []string + for ft := range allSet { + if ft.IsOmittable() { + removeTags = append(removeTags, ft.OmitTag()) + } + } + + var titleBuf strings.Builder + titleBuf.WriteString(strings.Join(explicitSorted, "+")) + + var and []string + for ft := range allSet { + if !s.Contains(ft) { + and = append(and, string(ft)) + } + } + if len(and) > 0 { + slices.Sort(and) + fmt.Fprintf(&titleBuf, " (and %s)", strings.Join(and, "+")) + } + + return titleBuf.String(), removeTags +} + +func allExcept(all, omit []string) []string { + return slices.DeleteFunc(slices.Clone(all), func(s string) bool { return slices.Contains(omit, s) }) +} + +func measure(bin string, tags ...string) int64 { + tags = slices.Clone(tags) + slices.Sort(tags) + tags = slices.Compact(tags) + comma := strings.Join(tags, ",") + + var cacheFile string + if *cacheDir != "" { + cacheFile = filepath.Join(*cacheDir, fmt.Sprintf("%02x", sha256.Sum256(fmt.Appendf(nil, "%s-%s.size", bin, comma)))) + if v, err := os.ReadFile(cacheFile); err == nil { + if size, err := strconv.ParseInt(strings.TrimSpace(string(v)), 10, 64); err == nil { + return size + } + } + } + + cmd := exec.Command("go", "build", "-trimpath", "-ldflags=-w -s", "-tags", strings.Join(tags, ","), "-o", "tmpbin", "./cmd/"+bin) + log.Printf("# Measuring %v", cmd.Args) + cmd.Env = append(os.Environ(), "CGO_ENABLED=0", "GOOS=linux", "GOARCH=amd64") + out, err := cmd.CombinedOutput() + if err != nil { + log.Fatalf("error measuring %q: %v, %s\n", bin, err, out) + } + fi, err := os.Stat("tmpbin") + if err != nil { + log.Fatal(err) + } + n := fi.Size() + if cacheFile != "" { + if err := os.WriteFile(cacheFile, fmt.Appendf(nil, "%d", n), 0644); err != nil { + log.Fatalf("error writing size to cache: %v\n", err) + } + } + return n +} diff --git a/cmd/pgproxy/pgproxy.go b/cmd/pgproxy/pgproxy.go index 468649ee2..e102c8ae4 100644 --- a/cmd/pgproxy/pgproxy.go +++ b/cmd/pgproxy/pgproxy.go @@ -24,7 +24,7 @@ import ( "strings" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/metrics" "tailscale.com/tsnet" "tailscale.com/tsweb" @@ -105,7 +105,7 @@ type proxy struct { upstreamHost string // "my.database.com" upstreamCertPool *x509.CertPool downstreamCert []tls.Certificate - client *tailscale.LocalClient + client *local.Client activeSessions expvar.Int startedSessions expvar.Int @@ -115,7 +115,7 @@ type proxy struct { // newProxy returns a proxy that forwards connections to // upstreamAddr. The upstream's TLS session is verified using the CA // cert(s) in upstreamCAPath. -func newProxy(upstreamAddr, upstreamCAPath string, client *tailscale.LocalClient) (*proxy, error) { +func newProxy(upstreamAddr, upstreamCAPath string, client *local.Client) (*proxy, error) { bs, err := os.ReadFile(upstreamCAPath) if err != nil { return nil, err diff --git a/cmd/proxy-test-server/proxy-test-server.go b/cmd/proxy-test-server/proxy-test-server.go new file mode 100644 index 000000000..9f8c94a38 --- /dev/null +++ b/cmd/proxy-test-server/proxy-test-server.go @@ -0,0 +1,81 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The proxy-test-server command is a simple HTTP proxy server for testing +// Tailscale's client proxy functionality. +package main + +import ( + "crypto/tls" + "flag" + "fmt" + "log" + "net" + "net/http" + "os" + "strings" + + "golang.org/x/crypto/acme/autocert" + "tailscale.com/net/connectproxy" + "tailscale.com/tempfork/acme" +) + +var ( + listen = flag.String("listen", ":8080", "Address to listen on for HTTPS proxy requests") + hostname = flag.String("hostname", "localhost", "Hostname for the proxy server") + tailscaleOnly = flag.Bool("tailscale-only", true, "Restrict proxy to Tailscale targets only") + extraAllowedHosts = flag.String("allow-hosts", "", "Comma-separated list of allowed target hosts to additionally allow if --tailscale-only is true") +) + +func main() { + flag.Parse() + + am := &autocert.Manager{ + HostPolicy: autocert.HostWhitelist(*hostname), + Prompt: autocert.AcceptTOS, + Cache: autocert.DirCache(os.ExpandEnv("$HOME/.cache/autocert/proxy-test-server")), + } + var allowTarget func(hostPort string) error + if *tailscaleOnly { + allowTarget = func(hostPort string) error { + host, port, err := net.SplitHostPort(hostPort) + if err != nil { + return fmt.Errorf("invalid target %q: %v", hostPort, err) + } + if port != "443" { + return fmt.Errorf("target %q must use port 443", hostPort) + } + for allowed := range strings.SplitSeq(*extraAllowedHosts, ",") { + if host == allowed { + return nil // explicitly allowed target + } + } + if !strings.HasSuffix(host, ".tailscale.com") { + return fmt.Errorf("target %q is not a Tailscale host", hostPort) + } + return nil // valid Tailscale target + } + } + + go func() { + if err := http.ListenAndServe(":http", am.HTTPHandler(nil)); err != nil { + log.Fatalf("autocert HTTP server failed: %v", err) + } + }() + hs := &http.Server{ + Addr: *listen, + Handler: &connectproxy.Handler{ + Check: allowTarget, + Logf: log.Printf, + }, + TLSConfig: &tls.Config{ + GetCertificate: am.GetCertificate, + NextProtos: []string{ + "http/1.1", // enable HTTP/2 + acme.ALPNProto, // enable tls-alpn ACME challenges + }, + }, + } + log.Printf("Starting proxy-test-server on %s (hostname: %q)\n", *listen, *hostname) + log.Fatal(hs.ListenAndServeTLS("", "")) // cert and key are provided by autocert +} diff --git a/cmd/proxy-to-grafana/proxy-to-grafana.go b/cmd/proxy-to-grafana/proxy-to-grafana.go index f1c67bad5..27f5e338c 100644 --- a/cmd/proxy-to-grafana/proxy-to-grafana.go +++ b/cmd/proxy-to-grafana/proxy-to-grafana.go @@ -19,8 +19,25 @@ // header_property = username // auto_sign_up = true // whitelist = 127.0.0.1 -// headers = Name:X-WEBAUTH-NAME +// headers = Email:X-Webauth-User, Name:X-Webauth-Name, Role:X-Webauth-Role // enable_login_token = true +// +// You can use grants in Tailscale ACL to give users different roles in Grafana. +// For example, to give group:eng the Editor role, add the following to your ACLs: +// +// "grants": [ +// { +// "src": ["group:eng"], +// "dst": ["tag:grafana"], +// "app": { +// "tailscale.com/cap/proxy-to-grafana": [{ +// "role": "editor", +// }], +// }, +// }, +// ], +// +// If multiple roles are specified, the most permissive role is used. package main import ( @@ -36,7 +53,7 @@ import ( "strings" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" "tailscale.com/tsnet" ) @@ -49,6 +66,57 @@ var ( loginServer = flag.String("login-server", "", "URL to alternative control server. If empty, the default Tailscale control is used.") ) +// aclCap is the Tailscale ACL capability used to configure proxy-to-grafana. +const aclCap tailcfg.PeerCapability = "tailscale.com/cap/proxy-to-grafana" + +// aclGrant is an access control rule that assigns Grafana permissions +// while provisioning a user. +type aclGrant struct { + // Role is one of: "viewer", "editor", "admin". + Role string `json:"role"` +} + +// grafanaRole defines possible Grafana roles. +type grafanaRole int + +const ( + // Roles are ordered by their permissions, with the least permissive role first. + // If a user has multiple roles, the most permissive role is used. + ViewerRole grafanaRole = iota + EditorRole + AdminRole +) + +// String returns the string representation of a grafanaRole. +// It is used as a header value in the HTTP request to Grafana. +func (r grafanaRole) String() string { + switch r { + case ViewerRole: + return "Viewer" + case EditorRole: + return "Editor" + case AdminRole: + return "Admin" + default: + // A safe default. + return "Viewer" + } +} + +// roleFromString converts a string to a grafanaRole. +// It is used to parse the role from the ACL grant. +func roleFromString(s string) (grafanaRole, error) { + switch strings.ToLower(s) { + case "viewer": + return ViewerRole, nil + case "editor": + return EditorRole, nil + case "admin": + return AdminRole, nil + } + return ViewerRole, fmt.Errorf("unknown role: %q", s) +} + func main() { flag.Parse() if *hostname == "" || strings.Contains(*hostname, ".") { @@ -127,14 +195,23 @@ func main() { log.Fatal(http.Serve(ln, proxy)) } -func modifyRequest(req *http.Request, localClient *tailscale.LocalClient) { - // with enable_login_token set to true, we get a cookie that handles +func modifyRequest(req *http.Request, localClient whoisIdentitySource) { + // Delete any existing X-Webauth-* headers to prevent possible spoofing + // if getting Tailnet identity fails. + for h := range req.Header { + if strings.HasPrefix(h, "X-Webauth-") { + req.Header.Del(h) + } + } + + // Set the X-Webauth-* headers only for the /login path + // With enable_login_token set to true, we get a cookie that handles // auth for paths that are not /login if req.URL.Path != "/login" { return } - user, err := getTailscaleUser(req.Context(), localClient, req.RemoteAddr) + user, role, err := getTailscaleIdentity(req.Context(), localClient, req.RemoteAddr) if err != nil { log.Printf("error getting Tailscale user: %v", err) return @@ -142,19 +219,37 @@ func modifyRequest(req *http.Request, localClient *tailscale.LocalClient) { req.Header.Set("X-Webauth-User", user.LoginName) req.Header.Set("X-Webauth-Name", user.DisplayName) + req.Header.Set("X-Webauth-Role", role.String()) } -func getTailscaleUser(ctx context.Context, localClient *tailscale.LocalClient, ipPort string) (*tailcfg.UserProfile, error) { +func getTailscaleIdentity(ctx context.Context, localClient whoisIdentitySource, ipPort string) (*tailcfg.UserProfile, grafanaRole, error) { whois, err := localClient.WhoIs(ctx, ipPort) if err != nil { - return nil, fmt.Errorf("failed to identify remote host: %w", err) + return nil, ViewerRole, fmt.Errorf("failed to identify remote host: %w", err) } if whois.Node.IsTagged() { - return nil, fmt.Errorf("tagged nodes are not users") + return nil, ViewerRole, fmt.Errorf("tagged nodes are not users") } if whois.UserProfile == nil || whois.UserProfile.LoginName == "" { - return nil, fmt.Errorf("failed to identify remote user") + return nil, ViewerRole, fmt.Errorf("failed to identify remote user") } - return whois.UserProfile, nil + role := ViewerRole + grants, err := tailcfg.UnmarshalCapJSON[aclGrant](whois.CapMap, aclCap) + if err != nil { + return nil, ViewerRole, fmt.Errorf("failed to unmarshal ACL grants: %w", err) + } + for _, g := range grants { + r, err := roleFromString(g.Role) + if err != nil { + return nil, ViewerRole, fmt.Errorf("failed to parse role: %w", err) + } + role = max(role, r) + } + + return whois.UserProfile, role, nil +} + +type whoisIdentitySource interface { + WhoIs(ctx context.Context, ipPort string) (*apitype.WhoIsResponse, error) } diff --git a/cmd/proxy-to-grafana/proxy-to-grafana_test.go b/cmd/proxy-to-grafana/proxy-to-grafana_test.go new file mode 100644 index 000000000..4831d5436 --- /dev/null +++ b/cmd/proxy-to-grafana/proxy-to-grafana_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" +) + +type mockWhoisSource struct { + id *apitype.WhoIsResponse +} + +func (m *mockWhoisSource) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + if m.id == nil { + return nil, fmt.Errorf("missing mock identity") + } + return m.id, nil +} + +var whois = &apitype.WhoIsResponse{ + UserProfile: &tailcfg.UserProfile{ + LoginName: "foobar@example.com", + DisplayName: "Foobar", + }, + Node: &tailcfg.Node{ + ID: 1, + }, +} + +func TestModifyRequest_Login(t *testing.T) { + req := httptest.NewRequest("GET", "/login", nil) + modifyRequest(req, &mockWhoisSource{id: whois}) + + if got := req.Header.Get("X-Webauth-User"); got != "foobar@example.com" { + t.Errorf("X-Webauth-User = %q; want %q", got, "foobar@example.com") + } + + if got := req.Header.Get("X-Webauth-Role"); got != "Viewer" { + t.Errorf("X-Webauth-Role = %q; want %q", got, "Viewer") + } +} + +func TestModifyRequest_RemoveHeaders_Login(t *testing.T) { + req := httptest.NewRequest("GET", "/login", nil) + req.Header.Set("X-Webauth-User", "malicious@example.com") + req.Header.Set("X-Webauth-Role", "Admin") + + modifyRequest(req, &mockWhoisSource{id: whois}) + + if got := req.Header.Get("X-Webauth-User"); got != "foobar@example.com" { + t.Errorf("X-Webauth-User = %q; want %q", got, "foobar@example.com") + } + if got := req.Header.Get("X-Webauth-Role"); got != "Viewer" { + t.Errorf("X-Webauth-Role = %q; want %q", got, "Viewer") + } +} + +func TestModifyRequest_RemoveHeaders_API(t *testing.T) { + req := httptest.NewRequest("DELETE", "/api/org/users/1", nil) + req.Header.Set("X-Webauth-User", "malicious@example.com") + req.Header.Set("X-Webauth-Role", "Admin") + + modifyRequest(req, &mockWhoisSource{id: whois}) + + if got := req.Header.Get("X-Webauth-User"); got != "" { + t.Errorf("X-Webauth-User = %q; want %q", got, "") + } + if got := req.Header.Get("X-Webauth-Role"); got != "" { + t.Errorf("X-Webauth-Role = %q; want %q", got, "") + } +} diff --git a/cmd/sniproxy/handlers.go b/cmd/sniproxy/handlers.go index 102110fe3..1973eecc0 100644 --- a/cmd/sniproxy/handlers.go +++ b/cmd/sniproxy/handlers.go @@ -14,6 +14,7 @@ import ( "github.com/inetaf/tcpproxy" "tailscale.com/net/netutil" + "tailscale.com/net/netx" ) type tcpRoundRobinHandler struct { @@ -22,7 +23,7 @@ type tcpRoundRobinHandler struct { To []string // DialContext is used to make the outgoing TCP connection. - DialContext func(ctx context.Context, network, address string) (net.Conn, error) + DialContext netx.DialFunc // ReachableIPs enumerates the IP addresses this handler is reachable on. ReachableIPs []netip.Addr diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index fa83aaf4a..2115c8095 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -22,7 +22,7 @@ import ( "strings" "github.com/peterbourgon/ff/v3" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/tailcfg" @@ -141,7 +141,7 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro // in the netmap. // We set the NotifyInitialNetMap flag so we will always get woken with the // current netmap, before only being woken on changes. - bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap|ipn.NotifyNoPrivateKeys) + bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap) if err != nil { log.Fatalf("watching IPN bus: %v", err) } @@ -157,10 +157,8 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro // NetMap contains app-connector configuration if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - sn := nm.SelfNode.AsStruct() - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) + nmConf, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorConfig](nm.SelfNode.CapMap(), configCapKey) if err != nil { log.Printf("failed to read app connector configuration from coordination server: %v", err) } else if len(nmConf) > 0 { @@ -185,7 +183,7 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro type sniproxy struct { srv Server ts *tsnet.Server - lc *tailscale.LocalClient + lc *local.Client } func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { diff --git a/cmd/sniproxy/sniproxy_test.go b/cmd/sniproxy/sniproxy_test.go index cd2e070bd..65e059efa 100644 --- a/cmd/sniproxy/sniproxy_test.go +++ b/cmd/sniproxy/sniproxy_test.go @@ -152,17 +152,17 @@ func TestSNIProxyWithNetmapConfig(t *testing.T) { configCapKey: []tailcfg.RawMessage{tailcfg.RawMessage(b)}, }) - // Lets spin up a second node (to represent the client). + // Let's spin up a second node (to represent the client). client, _, _ := startNode(t, ctx, controlURL, "client") // Make sure that the sni node has received its config. - l, err := sni.LocalClient() + lc, err := sni.LocalClient() if err != nil { t.Fatal(err) } gotConfigured := false for range 100 { - s, err := l.StatusWithoutPeers(ctx) + s, err := lc.StatusWithoutPeers(ctx) if err != nil { t.Fatal(err) } @@ -176,7 +176,7 @@ func TestSNIProxyWithNetmapConfig(t *testing.T) { t.Error("sni node never received its configuration from the coordination server!") } - // Lets make the client open a connection to the sniproxy node, and + // Let's make the client open a connection to the sniproxy node, and // make sure it results in a connection to our test listener. w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port)) if err != nil { @@ -208,10 +208,10 @@ func TestSNIProxyWithFlagConfig(t *testing.T) { sni, _, ip := startNode(t, ctx, controlURL, "snitest") go run(ctx, sni, 0, sni.Hostname, false, 0, "", fmt.Sprintf("tcp/%d/localhost", ln.Addr().(*net.TCPAddr).Port)) - // Lets spin up a second node (to represent the client). + // Let's spin up a second node (to represent the client). client, _, _ := startNode(t, ctx, controlURL, "client") - // Lets make the client open a connection to the sniproxy node, and + // Let's make the client open a connection to the sniproxy node, and // make sure it results in a connection to our test listener. w, err := client.Dial(ctx, "tcp", fmt.Sprintf("%s:%d", ip, ln.Addr().(*net.TCPAddr).Port)) if err != nil { diff --git a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go index ee929299a..39af584ec 100644 --- a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go +++ b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go @@ -6,6 +6,9 @@ // highlight the unique parts of the Tailscale SSH server so SSH // client authors can hit it easily and fix their SSH clients without // needing to set up Tailscale and Tailscale SSH. +// +// Connections are allowed using any username except for "denyme". Connecting as +// "denyme" will result in an authentication failure with error message. package main import ( @@ -16,6 +19,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "flag" "fmt" "io" @@ -24,7 +28,7 @@ import ( "path/filepath" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" "tailscale.com/tempfork/gliderlabs/ssh" ) @@ -62,13 +66,21 @@ func main() { Handler: handleSessionPostSSHAuth, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { start := time.Now() + var spac gossh.ServerPreAuthConn return &gossh.ServerConfig{ - NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { - return []string{"tailscale"} + PreAuthConnCallback: func(conn gossh.ServerPreAuthConn) { + spac = conn }, NoClientAuth: true, // required for the NoClientAuthCallback to run NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { - cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + spac.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + + if cm.User() == "denyme" { + return nil, &gossh.BannerError{ + Err: errors.New("denying access"), + Message: "denyme is not allowed to access this machine\n", + } + } totalBanners := 2 if cm.User() == "banners" { @@ -77,9 +89,9 @@ func main() { for banner := 2; banner <= totalBanners; banner++ { time.Sleep(time.Second) if banner == totalBanners { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) + spac.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) } else { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) + spac.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) } } return nil, nil diff --git a/cmd/stunc/stunc.go b/cmd/stunc/stunc.go index 9743a3300..c4b2eedd3 100644 --- a/cmd/stunc/stunc.go +++ b/cmd/stunc/stunc.go @@ -5,24 +5,40 @@ package main import ( + "flag" "log" "net" "os" "strconv" + "time" "tailscale.com/net/stun" ) func main() { log.SetFlags(0) - - if len(os.Args) < 2 || len(os.Args) > 3 { - log.Fatalf("usage: %s [port]", os.Args[0]) - } - host := os.Args[1] + var host string port := "3478" - if len(os.Args) == 3 { - port = os.Args[2] + + var readTimeout time.Duration + flag.DurationVar(&readTimeout, "timeout", 3*time.Second, "response wait timeout") + + flag.Parse() + + values := flag.Args() + if len(values) < 1 || len(values) > 2 { + log.Printf("usage: %s [port]", os.Args[0]) + flag.PrintDefaults() + os.Exit(1) + } else { + for i, value := range values { + switch i { + case 0: + host = value + case 1: + port = value + } + } } _, err := strconv.ParseUint(port, 10, 16) if err != nil { @@ -46,6 +62,10 @@ func main() { log.Fatal(err) } + err = c.SetReadDeadline(time.Now().Add(readTimeout)) + if err != nil { + log.Fatal(err) + } var buf [1024]byte n, raddr, err := c.ReadFromUDPAddrPort(buf[:]) if err != nil { diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index a35f59516..7b3d05f94 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -2,18 +2,17 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus đŸ’Ŗ github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus - github.com/go-json-experiment/json from tailscale.com/types/opt + github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ - github.com/google/uuid from tailscale.com/util/fastuuid + github.com/munnerz/goautoneg from github.com/prometheus/common/expfmt đŸ’Ŗ github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ - github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs @@ -39,6 +38,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar đŸ’Ŗ google.golang.org/protobuf/internal/impl from google.golang.org/protobuf/internal/filetype+ google.golang.org/protobuf/internal/order from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/pragma from google.golang.org/protobuf/encoding/prototext+ + đŸ’Ŗ google.golang.org/protobuf/internal/protolazy from google.golang.org/protobuf/internal/impl+ google.golang.org/protobuf/internal/set from google.golang.org/protobuf/encoding/prototext đŸ’Ŗ google.golang.org/protobuf/internal/strs from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/version from google.golang.org/protobuf/runtime/protoimpl @@ -50,60 +50,71 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ tailscale.com from tailscale.com/version tailscale.com/envknob from tailscale.com/tsweb+ + tailscale.com/feature from tailscale.com/tsweb + tailscale.com/feature/buildfeatures from tailscale.com/feature+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/metrics from tailscale.com/net/stunserver+ tailscale.com/net/netaddr from tailscale.com/net/tsaddr tailscale.com/net/stun from tailscale.com/net/stunserver tailscale.com/net/stunserver from tailscale.com/cmd/stund tailscale.com/net/tsaddr from tailscale.com/tsweb - tailscale.com/tailcfg from tailscale.com/version - tailscale.com/tsweb from tailscale.com/cmd/stund - tailscale.com/tsweb/promvarz from tailscale.com/tsweb + tailscale.com/syncs from tailscale.com/metrics+ + tailscale.com/tailcfg from tailscale.com/version+ + tailscale.com/tsweb from tailscale.com/cmd/stund+ + tailscale.com/tsweb/promvarz from tailscale.com/cmd/stund tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/dnstype from tailscale.com/tailcfg tailscale.com/types/ipproto from tailscale.com/tailcfg - tailscale.com/types/key from tailscale.com/tailcfg + tailscale.com/types/key from tailscale.com/tailcfg+ tailscale.com/types/lazy from tailscale.com/version+ - tailscale.com/types/logger from tailscale.com/tsweb + tailscale.com/types/logger from tailscale.com/tsweb+ tailscale.com/types/opt from tailscale.com/envknob+ + tailscale.com/types/persist from tailscale.com/feature tailscale.com/types/ptr from tailscale.com/tailcfg+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/tailcfg+ tailscale.com/types/tkatype from tailscale.com/tailcfg+ tailscale.com/types/views from tailscale.com/net/tsaddr+ tailscale.com/util/ctxkey from tailscale.com/tsweb+ L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/tailcfg - tailscale.com/util/fastuuid from tailscale.com/tsweb - tailscale.com/util/lineread from tailscale.com/version/distro + tailscale.com/util/lineiter from tailscale.com/version/distro + tailscale.com/util/mak from tailscale.com/syncs+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/rands from tailscale.com/tsweb + tailscale.com/util/set from tailscale.com/types/key tailscale.com/util/slicesx from tailscale.com/tailcfg + tailscale.com/util/testenv from tailscale.com/types/logger+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ tailscale.com/version from tailscale.com/envknob+ tailscale.com/version/distro from tailscale.com/envknob golang.org/x/crypto/blake2b from golang.org/x/crypto/nacl/box - golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 - golang.org/x/crypto/chacha20poly1305 from crypto/tls+ - golang.org/x/crypto/cryptobyte from crypto/ecdsa+ - golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/internal/alias from golang.org/x/crypto/nacl/secretbox + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/nacl/secretbox golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ - golang.org/x/net/dns/dnsmessage from net+ - golang.org/x/net/http/httpguts from net/http - golang.org/x/net/http/httpproxy from net/http - golang.org/x/net/http2/hpack from net/http - golang.org/x/net/idna from golang.org/x/net/http/httpguts+ - D golang.org/x/net/route from net + golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ golang.org/x/sys/cpu from golang.org/x/crypto/blake2b+ LD golang.org/x/sys/unix from github.com/prometheus/procfs+ W golang.org/x/sys/windows from github.com/prometheus/client_golang/prometheus - golang.org/x/text/secure/bidirule from golang.org/x/net/idna - golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ - golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ - golang.org/x/text/unicode/norm from golang.org/x/net/idna + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna bufio from compress/flate+ bytes from bufio+ cmp from slices+ @@ -112,7 +123,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509 @@ -120,20 +131,62 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ - crypto/hmac from crypto/tls+ + crypto/fips140 from crypto/tls/internal/fips140tls + crypto/hkdf from crypto/internal/hpke+ + crypto/hmac from crypto/tls + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from net/http+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509 - database/sql/driver from github.com/google/uuid - embed from crypto/internal/nistec+ + embed from google.golang.org/protobuf/internal/editiondefaults+ encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/go-json-experiment/json @@ -144,7 +197,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar encoding/pem from crypto/tls+ errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ - flag from tailscale.com/cmd/stund + flag from tailscale.com/cmd/stund+ fmt from compress/flate+ go/token from google.golang.org/protobuf/internal/strs hash from crypto+ @@ -152,9 +205,55 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from net/http/pprof+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + D internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + L internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from encoding/asn1 + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ io from bufio+ io/fs from crypto/x509+ - io/ioutil from google.golang.org/protobuf/internal/impl iter from maps+ log from expvar+ log/internal from log @@ -163,25 +262,28 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from math/big+ - math/rand/v2 from tailscale.com/util/fastuuid+ + math/rand/v2 from crypto/ecdsa+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from net/http + net/http/httptrace from net/http+ net/http/internal from net/http + net/http/internal/ascii from net/http + net/http/internal/httpcommon from net/http net/http/pprof from tailscale.com/tsweb net/netip from go4.org/netipx+ - net/textproto from golang.org/x/net/http/httpguts+ + net/textproto from mime/multipart+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/signal from tailscale.com/cmd/stund path from github.com/prometheus/client_golang/prometheus/internal+ path/filepath from crypto/x509+ reflect from crypto/x509+ regexp from github.com/prometheus/client_golang/prometheus/internal+ regexp/syntax from regexp + runtime from crypto/internal/fips140+ runtime/debug from github.com/prometheus/client_golang/prometheus+ runtime/metrics from github.com/prometheus/client_golang/prometheus+ runtime/pprof from net/http/pprof @@ -190,12 +292,15 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar sort from compress/flate+ strconv from compress/flate+ strings from bufio+ + W structs from internal/syscall/windows sync from compress/flate+ sync/atomic from context+ - syscall from crypto/rand+ + syscall from crypto/internal/sysrand+ text/tabwriter from runtime/pprof time from compress/gzip+ unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/cmd/stund/stund.go b/cmd/stund/stund.go index c38429169..1055d966f 100644 --- a/cmd/stund/stund.go +++ b/cmd/stund/stund.go @@ -15,6 +15,9 @@ import ( "tailscale.com/net/stunserver" "tailscale.com/tsweb" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( diff --git a/cmd/stunstamp/stunstamp.go b/cmd/stunstamp/stunstamp.go index c3842e2e8..153dc9303 100644 --- a/cmd/stunstamp/stunstamp.go +++ b/cmd/stunstamp/stunstamp.go @@ -34,10 +34,10 @@ import ( "github.com/golang/snappy" "github.com/prometheus/prometheus/prompb" "github.com/tcnksm/go-httpstat" - "tailscale.com/logtail/backoff" "tailscale.com/net/stun" "tailscale.com/net/tcpinfo" "tailscale.com/tailcfg" + "tailscale.com/util/backoff" ) var ( @@ -135,18 +135,18 @@ type lportsPool struct { ports []int } -func (l *lportsPool) get() int { - l.Lock() - defer l.Unlock() - ret := l.ports[0] - l.ports = append(l.ports[:0], l.ports[1:]...) +func (pl *lportsPool) get() int { + pl.Lock() + defer pl.Unlock() + ret := pl.ports[0] + pl.ports = append(pl.ports[:0], pl.ports[1:]...) return ret } -func (l *lportsPool) put(i int) { - l.Lock() - defer l.Unlock() - l.ports = append(l.ports, int(i)) +func (pl *lportsPool) put(i int) { + pl.Lock() + defer pl.Unlock() + pl.ports = append(pl.ports, int(i)) } var ( @@ -173,19 +173,19 @@ func init() { // measure dial time. type lportForTCPConn int -func (l *lportForTCPConn) Close() error { - if *l == 0 { +func (lp *lportForTCPConn) Close() error { + if *lp == 0 { return nil } - lports.put(int(*l)) + lports.put(int(*lp)) return nil } -func (l *lportForTCPConn) Write([]byte) (int, error) { +func (lp *lportForTCPConn) Write([]byte) (int, error) { return 0, errors.New("unimplemented") } -func (l *lportForTCPConn) Read([]byte) (int, error) { +func (lp *lportForTCPConn) Read([]byte) (int, error) { return 0, errors.New("unimplemented") } diff --git a/cmd/sync-containers/main.go b/cmd/sync-containers/main.go index 6317b4943..63efa5453 100644 --- a/cmd/sync-containers/main.go +++ b/cmd/sync-containers/main.go @@ -65,9 +65,9 @@ func main() { } add, remove := diffTags(stags, dtags) - if l := len(add); l > 0 { + if ln := len(add); ln > 0 { log.Printf("%d tags to push: %s", len(add), strings.Join(add, ", ")) - if *max > 0 && l > *max { + if *max > 0 && ln > *max { log.Printf("Limiting sync to %d tags", *max) add = add[:*max] } diff --git a/cmd/systray/logo.go b/cmd/systray/logo.go deleted file mode 100644 index cd79c94a0..000000000 --- a/cmd/systray/logo.go +++ /dev/null @@ -1,220 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo || !darwin - -package main - -import ( - "bytes" - "context" - "image/color" - "image/png" - "sync" - "time" - - "fyne.io/systray" - "github.com/fogleman/gg" -) - -// tsLogo represents the state of the 3x3 dot grid in the Tailscale logo. -// A 0 represents a gray dot, any other value is a white dot. -type tsLogo [9]byte - -var ( - // disconnected is all gray dots - disconnected = tsLogo{ - 0, 0, 0, - 0, 0, 0, - 0, 0, 0, - } - - // connected is the normal Tailscale logo - connected = tsLogo{ - 0, 0, 0, - 1, 1, 1, - 0, 1, 0, - } - - // loading is a special tsLogo value that is not meant to be rendered directly, - // but indicates that the loading animation should be shown. - loading = tsLogo{'l', 'o', 'a', 'd', 'i', 'n', 'g'} - - // loadingIcons are shown in sequence as an animated loading icon. - loadingLogos = []tsLogo{ - { - 0, 1, 1, - 1, 0, 1, - 0, 0, 1, - }, - { - 0, 1, 1, - 0, 0, 1, - 0, 1, 0, - }, - { - 0, 1, 1, - 0, 0, 0, - 0, 0, 1, - }, - { - 0, 0, 1, - 0, 1, 0, - 0, 0, 0, - }, - { - 0, 1, 0, - 0, 0, 0, - 0, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 1, - 0, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 0, - 0, 0, 0, - }, - { - 0, 0, 1, - 0, 0, 0, - 0, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 0, - 1, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 0, - 1, 1, 0, - }, - { - 0, 0, 0, - 1, 0, 0, - 1, 1, 0, - }, - { - 0, 0, 0, - 1, 1, 0, - 0, 1, 0, - }, - { - 0, 0, 0, - 1, 1, 0, - 0, 1, 1, - }, - { - 0, 0, 0, - 1, 1, 1, - 0, 0, 1, - }, - { - 0, 1, 0, - 0, 1, 1, - 1, 0, 1, - }, - } -) - -var ( - black = color.NRGBA{0, 0, 0, 255} - white = color.NRGBA{255, 255, 255, 255} - gray = color.NRGBA{255, 255, 255, 102} -) - -// render returns a PNG image of the logo. -func (logo tsLogo) render() *bytes.Buffer { - const radius = 25 - const borderUnits = 1 - dim := radius * (8 + borderUnits*2) - - dc := gg.NewContext(dim, dim) - dc.DrawRectangle(0, 0, float64(dim), float64(dim)) - dc.SetColor(black) - dc.Fill() - - for y := 0; y < 3; y++ { - for x := 0; x < 3; x++ { - px := (borderUnits + 1 + 3*x) * radius - py := (borderUnits + 1 + 3*y) * radius - col := white - if logo[y*3+x] == 0 { - col = gray - } - dc.DrawCircle(float64(px), float64(py), radius) - dc.SetColor(col) - dc.Fill() - } - } - - b := bytes.NewBuffer(nil) - png.Encode(b, dc.Image()) - return b -} - -// setAppIcon renders logo and sets it as the systray icon. -func setAppIcon(icon tsLogo) { - if icon == loading { - startLoadingAnimation() - } else { - stopLoadingAnimation() - systray.SetIcon(icon.render().Bytes()) - } -} - -var ( - loadingMu sync.Mutex // protects loadingCancel - - // loadingCancel stops the loading animation in the systray icon. - // This is nil if the animation is not currently active. - loadingCancel func() -) - -// startLoadingAnimation starts the animated loading icon in the system tray. -// The animation continues until [stopLoadingAnimation] is called. -// If the loading animation is already active, this func does nothing. -func startLoadingAnimation() { - loadingMu.Lock() - defer loadingMu.Unlock() - - if loadingCancel != nil { - // loading icon already displayed - return - } - - ctx := context.Background() - ctx, loadingCancel = context.WithCancel(ctx) - - go func() { - t := time.NewTicker(500 * time.Millisecond) - var i int - for { - select { - case <-ctx.Done(): - return - case <-t.C: - systray.SetIcon(loadingLogos[i].render().Bytes()) - i++ - if i >= len(loadingLogos) { - i = 0 - } - } - } - }() -} - -// stopLoadingAnimation stops the animated loading icon in the system tray. -// If the loading animation is not currently active, this func does nothing. -func stopLoadingAnimation() { - loadingMu.Lock() - defer loadingMu.Unlock() - - if loadingCancel != nil { - loadingCancel() - loadingCancel = nil - } -} diff --git a/cmd/systray/systray.go b/cmd/systray/systray.go index aca38f627..d35595e25 100644 --- a/cmd/systray/systray.go +++ b/cmd/systray/systray.go @@ -3,256 +3,21 @@ //go:build cgo || !darwin -// The systray command is a minimal Tailscale systray application for Linux. +// systray is a minimal Tailscale systray application. package main import ( - "context" - "errors" - "fmt" - "io" - "log" - "os" - "strings" - "sync" - "time" + "flag" - "fyne.io/systray" - "github.com/atotto/clipboard" - dbus "github.com/godbus/dbus/v5" - "github.com/toqueteos/webbrowser" - "tailscale.com/client/tailscale" - "tailscale.com/ipn" - "tailscale.com/ipn/ipnstate" + "tailscale.com/client/local" + "tailscale.com/client/systray" + "tailscale.com/paths" ) -var ( - localClient tailscale.LocalClient - chState chan ipn.State // tailscale state changes - - appIcon *os.File -) +var socket = flag.String("socket", paths.DefaultTailscaledSocket(), "path to tailscaled socket") func main() { - systray.Run(onReady, onExit) -} - -// Menu represents the systray menu, its items, and the current Tailscale state. -type Menu struct { - mu sync.Mutex // protects the entire Menu - status *ipnstate.Status - - connect *systray.MenuItem - disconnect *systray.MenuItem - - self *systray.MenuItem - more *systray.MenuItem - quit *systray.MenuItem - - eventCancel func() // cancel eventLoop -} - -func onReady() { - log.Printf("starting") - ctx := context.Background() - - setAppIcon(disconnected) - - // dbus wants a file path for notification icons, so copy to a temp file. - appIcon, _ = os.CreateTemp("", "tailscale-systray.png") - io.Copy(appIcon, connected.render()) - - chState = make(chan ipn.State, 1) - - status, err := localClient.Status(ctx) - if err != nil { - log.Print(err) - } - - menu := new(Menu) - menu.rebuild(status) - - go watchIPNBus(ctx) -} - -// rebuild the systray menu based on the current Tailscale state. -// -// We currently rebuild the entire menu because it is not easy to update the existing menu. -// You cannot iterate over the items in a menu, nor can you remove some items like separators. -// So for now we rebuild the whole thing, and can optimize this later if needed. -func (menu *Menu) rebuild(status *ipnstate.Status) { - menu.mu.Lock() - defer menu.mu.Unlock() - - if menu.eventCancel != nil { - menu.eventCancel() - } - menu.status = status - systray.ResetMenu() - - menu.connect = systray.AddMenuItem("Connect", "") - menu.disconnect = systray.AddMenuItem("Disconnect", "") - menu.disconnect.Hide() - systray.AddSeparator() - - if status != nil && status.Self != nil { - title := fmt.Sprintf("This Device: %s (%s)", status.Self.HostName, status.Self.TailscaleIPs[0]) - menu.self = systray.AddMenuItem(title, "") - } - systray.AddSeparator() - - menu.more = systray.AddMenuItem("More settings", "") - menu.more.Enable() - - menu.quit = systray.AddMenuItem("Quit", "Quit the app") - menu.quit.Enable() - - ctx := context.Background() - ctx, menu.eventCancel = context.WithCancel(ctx) - go menu.eventLoop(ctx) -} - -// eventLoop is the main event loop for handling click events on menu items -// and responding to Tailscale state changes. -// This method does not return until ctx.Done is closed. -func (menu *Menu) eventLoop(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case state := <-chState: - switch state { - case ipn.Running: - setAppIcon(loading) - status, err := localClient.Status(ctx) - if err != nil { - log.Printf("error getting tailscale status: %v", err) - } - menu.rebuild(status) - setAppIcon(connected) - menu.connect.SetTitle("Connected") - menu.connect.Disable() - menu.disconnect.Show() - menu.disconnect.Enable() - case ipn.NoState, ipn.Stopped: - menu.connect.SetTitle("Connect") - menu.connect.Enable() - menu.disconnect.Hide() - setAppIcon(disconnected) - case ipn.Starting: - setAppIcon(loading) - } - case <-menu.connect.ClickedCh: - _, err := localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - WantRunning: true, - }, - WantRunningSet: true, - }) - if err != nil { - log.Print(err) - continue - } - - case <-menu.disconnect.ClickedCh: - _, err := localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - WantRunning: false, - }, - WantRunningSet: true, - }) - if err != nil { - log.Printf("disconnecting: %v", err) - continue - } - - case <-menu.self.ClickedCh: - copyTailscaleIP(menu.status.Self) - - case <-menu.more.ClickedCh: - webbrowser.Open("http://100.100.100.100/") - - case <-menu.quit.ClickedCh: - systray.Quit() - } - } -} - -// watchIPNBus subscribes to the tailscale event bus and sends state updates to chState. -// This method does not return. -func watchIPNBus(ctx context.Context) { - for { - if err := watchIPNBusInner(ctx); err != nil { - log.Println(err) - if errors.Is(err, context.Canceled) { - // If the context got canceled, we will never be able to - // reconnect to IPN bus, so exit the process. - log.Fatalf("watchIPNBus: %v", err) - } - } - // If our watch connection breaks, wait a bit before reconnecting. No - // reason to spam the logs if e.g. tailscaled is restarting or goes - // down. - time.Sleep(3 * time.Second) - } -} - -func watchIPNBusInner(ctx context.Context) error { - watcher, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) - if err != nil { - return fmt.Errorf("watching ipn bus: %w", err) - } - defer watcher.Close() - for { - select { - case <-ctx.Done(): - return nil - default: - n, err := watcher.Next() - if err != nil { - return fmt.Errorf("ipnbus error: %w", err) - } - if n.State != nil { - chState <- *n.State - log.Printf("new state: %v", n.State) - } - } - } -} - -// copyTailscaleIP copies the first Tailscale IP of the given device to the clipboard -// and sends a notification with the copied value. -func copyTailscaleIP(device *ipnstate.PeerStatus) { - if device == nil || len(device.TailscaleIPs) == 0 { - return - } - name := strings.Split(device.DNSName, ".")[0] - ip := device.TailscaleIPs[0].String() - err := clipboard.WriteAll(ip) - if err != nil { - log.Printf("clipboard error: %v", err) - } - - sendNotification(fmt.Sprintf("Copied Address for %v", name), ip) -} - -// sendNotification sends a desktop notification with the given title and content. -func sendNotification(title, content string) { - conn, err := dbus.SessionBus() - if err != nil { - log.Printf("dbus: %v", err) - return - } - timeout := 3 * time.Second - obj := conn.Object("org.freedesktop.Notifications", "/org/freedesktop/Notifications") - call := obj.Call("org.freedesktop.Notifications.Notify", 0, "Tailscale", uint32(0), - appIcon.Name(), title, content, []string{}, map[string]dbus.Variant{}, int32(timeout.Milliseconds())) - if call.Err != nil { - log.Printf("dbus: %v", call.Err) - } -} - -func onExit() { - log.Printf("exiting") - os.Remove(appIcon.Name()) + flag.Parse() + lc := &local.Client{Socket: *socket} + new(systray.Menu).Run(lc) } diff --git a/cmd/tailscale/cli/appcroutes.go b/cmd/tailscale/cli/appcroutes.go new file mode 100644 index 000000000..4a1ba87e3 --- /dev/null +++ b/cmd/tailscale/cli/appcroutes.go @@ -0,0 +1,153 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "slices" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/types/appctype" +) + +var appcRoutesArgs struct { + all bool + domainMap bool + n bool +} + +var appcRoutesCmd = &ffcli.Command{ + Name: "appc-routes", + ShortUsage: "tailscale appc-routes", + Exec: runAppcRoutesInfo, + ShortHelp: "Print the current app connector routes", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("appc-routes") + fs.BoolVar(&appcRoutesArgs.all, "all", false, "Print learned domains and routes and extra policy configured routes.") + fs.BoolVar(&appcRoutesArgs.domainMap, "map", false, "Print the map of learned domains: [routes].") + fs.BoolVar(&appcRoutesArgs.n, "n", false, "Print the total number of routes this node advertises.") + return fs + })(), + LongHelp: strings.TrimSpace(` +The 'tailscale appc-routes' command prints the current App Connector route status. + +By default this command prints the domains configured in the app connector configuration and how many routes have been +learned for each domain. + +--all prints the routes learned from the domains configured in the app connector configuration; and any extra routes provided +in the the policy app connector 'routes' field. + +--map prints the routes learned from the domains configured in the app connector configuration. + +-n prints the total number of routes advertised by this device, whether learned, set in the policy, or set locally. + +For more information about App Connectors, refer to +https://tailscale.com/kb/1281/app-connectors +`), +} + +func getAllOutput(ri *appctype.RouteInfo) (string, error) { + domains, err := json.MarshalIndent(ri.Domains, " ", " ") + if err != nil { + return "", err + } + control, err := json.MarshalIndent(ri.Control, " ", " ") + if err != nil { + return "", err + } + s := fmt.Sprintf(`Learned Routes +============== +%s + +Routes from Policy +================== +%s +`, domains, control) + return s, nil +} + +type domainCount struct { + domain string + count int +} + +func getSummarizeLearnedOutput(ri *appctype.RouteInfo) string { + x := make([]domainCount, len(ri.Domains)) + i := 0 + maxDomainWidth := 0 + for k, v := range ri.Domains { + if len(k) > maxDomainWidth { + maxDomainWidth = len(k) + } + x[i] = domainCount{domain: k, count: len(v)} + i++ + } + slices.SortFunc(x, func(i, j domainCount) int { + if i.count > j.count { + return -1 + } + if i.count < j.count { + return 1 + } + if i.domain > j.domain { + return 1 + } + if i.domain < j.domain { + return -1 + } + return 0 + }) + s := "" + fmtString := fmt.Sprintf("%%-%ds %%d\n", maxDomainWidth) // eg "%-10s %d\n" + for _, dc := range x { + s += fmt.Sprintf(fmtString, dc.domain, dc.count) + } + return s +} + +func runAppcRoutesInfo(ctx context.Context, args []string) error { + prefs, err := localClient.GetPrefs(ctx) + if err != nil { + return err + } + if !prefs.AppConnector.Advertise { + fmt.Println("not a connector") + return nil + } + + if appcRoutesArgs.n { + fmt.Println(len(prefs.AdvertiseRoutes)) + return nil + } + + routeInfo, err := localClient.GetAppConnectorRouteInfo(ctx) + if err != nil { + return err + } + + if appcRoutesArgs.domainMap { + domains, err := json.Marshal(routeInfo.Domains) + if err != nil { + return err + } + fmt.Println(string(domains)) + return nil + } + + if appcRoutesArgs.all { + s, err := getAllOutput(&routeInfo) + if err != nil { + return err + } + fmt.Println(s) + return nil + } + + fmt.Print(getSummarizeLearnedOutput(&routeInfo)) + return nil +} diff --git a/cmd/tailscale/cli/bugreport.go b/cmd/tailscale/cli/bugreport.go index d671f3df6..50e6ffd82 100644 --- a/cmd/tailscale/cli/bugreport.go +++ b/cmd/tailscale/cli/bugreport.go @@ -10,7 +10,7 @@ import ( "fmt" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" ) var bugReportCmd = &ffcli.Command{ @@ -40,7 +40,7 @@ func runBugReport(ctx context.Context, args []string) error { default: return errors.New("unknown arguments") } - opts := tailscale.BugReportOpts{ + opts := local.BugReportOpts{ Note: note, Diagnose: bugReportArgs.diagnose, } diff --git a/cmd/tailscale/cli/cert.go b/cmd/tailscale/cli/cert.go index 9c8eca5b7..171eebe1e 100644 --- a/cmd/tailscale/cli/cert.go +++ b/cmd/tailscale/cli/cert.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !js && !ts_omit_acme + package cli import ( @@ -25,19 +27,23 @@ import ( "tailscale.com/version" ) -var certCmd = &ffcli.Command{ - Name: "cert", - Exec: runCert, - ShortHelp: "Get TLS certs", - ShortUsage: "tailscale cert [flags] ", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("cert") - fs.StringVar(&certArgs.certFile, "cert-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.crt if --cert-file and --key-file are both unset") - fs.StringVar(&certArgs.keyFile, "key-file", "", "output key file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset") - fs.BoolVar(&certArgs.serve, "serve-demo", false, "if true, serve on port :443 using the cert as a demo, instead of writing out the files to disk") - fs.DurationVar(&certArgs.minValidity, "min-validity", 0, "ensure the certificate is valid for at least this duration; the output certificate is never expired if this flag is unset or 0, but the lifetime may vary; the maximum allowed min-validity depends on the CA") - return fs - })(), +func init() { + maybeCertCmd = func() *ffcli.Command { + return &ffcli.Command{ + Name: "cert", + Exec: runCert, + ShortHelp: "Get TLS certs", + ShortUsage: "tailscale cert [flags] ", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("cert") + fs.StringVar(&certArgs.certFile, "cert-file", "", "output cert file or \"-\" for stdout; defaults to DOMAIN.crt if --cert-file and --key-file are both unset") + fs.StringVar(&certArgs.keyFile, "key-file", "", "output key file or \"-\" for stdout; defaults to DOMAIN.key if --cert-file and --key-file are both unset") + fs.BoolVar(&certArgs.serve, "serve-demo", false, "if true, serve on port :443 using the cert as a demo, instead of writing out the files to disk") + fs.DurationVar(&certArgs.minValidity, "min-validity", 0, "ensure the certificate is valid for at least this duration; the output certificate is never expired if this flag is unset or 0, but the lifetime may vary; the maximum allowed min-validity depends on the CA") + return fs + })(), + } + } } var certArgs struct { diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index 864cf6903..5ebc23a5b 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -7,6 +7,7 @@ package cli import ( "context" + "encoding/json" "errors" "flag" "fmt" @@ -17,14 +18,17 @@ import ( "strings" "sync" "text/tabwriter" + "time" "github.com/mattn/go-colorable" "github.com/mattn/go-isatty" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/envknob" + "tailscale.com/feature" "tailscale.com/paths" + "tailscale.com/util/slicesx" "tailscale.com/version/distro" ) @@ -63,42 +67,54 @@ func newFlagSet(name string) *flag.FlagSet { func CleanUpArgs(args []string) []string { out := make([]string, 0, len(args)) for _, arg := range args { + switch { // Rewrite --authkey to --auth-key, and --authkey=x to --auth-key=x, // and the same for the -authkey variant. - switch { case arg == "--authkey", arg == "-authkey": arg = "--auth-key" case strings.HasPrefix(arg, "--authkey="), strings.HasPrefix(arg, "-authkey="): - arg = strings.TrimLeft(arg, "-") - arg = strings.TrimPrefix(arg, "authkey=") - arg = "--auth-key=" + arg + _, val, _ := strings.Cut(arg, "=") + arg = "--auth-key=" + val + + // And the same, for posture-checking => report-posture + case arg == "--posture-checking", arg == "-posture-checking": + arg = "--report-posture" + case strings.HasPrefix(arg, "--posture-checking="), strings.HasPrefix(arg, "-posture-checking="): + _, val, _ := strings.Cut(arg, "=") + arg = "--report-posture=" + val + } out = append(out, arg) } return out } -var localClient = tailscale.LocalClient{ +var localClient = local.Client{ Socket: paths.DefaultTailscaledSocket(), } // Run runs the CLI. The args do not include the binary name. func Run(args []string) (err error) { - if runtime.GOOS == "linux" && os.Getenv("GOKRAZY_FIRST_START") == "1" && distro.Get() == distro.Gokrazy && os.Getppid() == 1 { - // We're running on gokrazy and it's the first start. - // Don't run the tailscale CLI as a service; just exit. + if runtime.GOOS == "linux" && os.Getenv("GOKRAZY_FIRST_START") == "1" && distro.Get() == distro.Gokrazy && os.Getppid() == 1 && len(args) == 0 { + // We're running on gokrazy and the user did not specify 'up'. + // Don't run the tailscale CLI and spam logs with usage; just exit. // See https://gokrazy.org/development/process-interface/ os.Exit(0) } args = CleanUpArgs(args) - if len(args) == 1 && (args[0] == "-V" || args[0] == "--version") { - args = []string{"version"} + if len(args) == 1 { + switch args[0] { + case "-V", "--version": + args = []string{"version"} + case "help": + args = []string{"--help"} + } } var warnOnce sync.Once - tailscale.SetVersionMismatchHandler(func(clientVer, serverVer string) { + local.SetVersionMismatchHandler(func(clientVer, serverVer string) { warnOnce.Do(func() { fmt.Fprintf(Stderr, "Warning: client version %q != tailscaled server version %q\n", clientVer, serverVer) }) @@ -149,8 +165,8 @@ func Run(args []string) (err error) { } err = rootCmd.Run(context.Background()) - if tailscale.IsAccessDeniedError(err) && os.Getuid() != 0 && runtime.GOOS != "windows" { - return fmt.Errorf("%v\n\nUse 'sudo tailscale %s' or 'tailscale up --operator=$USER' to not require root.", err, strings.Join(args, " ")) + if local.IsAccessDeniedError(err) && os.Getuid() != 0 && runtime.GOOS != "windows" { + return fmt.Errorf("%v\n\nUse 'sudo tailscale %s'.\nTo not require root, use 'sudo tailscale set --operator=$USER' once.", err, strings.Join(args, " ")) } if errors.Is(err, flag.ErrHelp) { return nil @@ -158,6 +174,53 @@ func Run(args []string) (err error) { return err } +type onceFlagValue struct { + flag.Value + set bool +} + +func (v *onceFlagValue) Set(s string) error { + if v.set { + return fmt.Errorf("flag provided multiple times") + } + v.set = true + return v.Value.Set(s) +} + +func (v *onceFlagValue) IsBoolFlag() bool { + type boolFlag interface { + IsBoolFlag() bool + } + bf, ok := v.Value.(boolFlag) + return ok && bf.IsBoolFlag() +} + +// noDupFlagify modifies c recursively to make all the +// flag values be wrappers that permit setting the value +// at most once. +func noDupFlagify(c *ffcli.Command) { + if c.FlagSet != nil { + c.FlagSet.VisitAll(func(f *flag.Flag) { + f.Value = &onceFlagValue{Value: f.Value} + }) + } + for _, sub := range c.Subcommands { + noDupFlagify(sub) + } +} + +var ( + fileCmd, + sysPolicyCmd, + maybeWebCmd, + maybeDriveCmd, + maybeNetlockCmd, + maybeFunnelCmd, + maybeServeCmd, + maybeCertCmd, + _ func() *ffcli.Command +) + func newRootCmd() *ffcli.Command { rootfs := newFlagSet("tailscale") rootfs.Func("socket", "path to tailscaled socket", func(s string) error { @@ -166,8 +229,10 @@ func newRootCmd() *ffcli.Command { return nil }) rootfs.Lookup("socket").DefValue = localClient.Socket + jsonDocs := rootfs.Bool("json-docs", false, hidden+"print JSON-encoded docs for all subcommands and flags") - rootCmd := &ffcli.Command{ + var rootCmd *ffcli.Command + rootCmd = &ffcli.Command{ Name: "tailscale", ShortUsage: "tailscale [flags] [command flags]", ShortHelp: "The easiest, most secure way to use WireGuard.", @@ -177,39 +242,47 @@ For help on subcommands, add --help after: "tailscale status --help". This CLI is still under active development. Commands and flags will change in the future. `), - Subcommands: []*ffcli.Command{ + Subcommands: nonNilCmds( upCmd, downCmd, setCmd, loginCmd, logoutCmd, switchCmd, - configureCmd, + configureCmd(), + nilOrCall(sysPolicyCmd), netcheckCmd, ipCmd, dnsCmd, statusCmd, + metricsCmd, pingCmd, ncCmd, sshCmd, - funnelCmd(), - serveCmd(), + nilOrCall(maybeFunnelCmd), + nilOrCall(maybeServeCmd), versionCmd, - webCmd, - fileCmd, + nilOrCall(maybeWebCmd), + nilOrCall(fileCmd), bugReportCmd, - certCmd, - netlockCmd, + nilOrCall(maybeCertCmd), + nilOrCall(maybeNetlockCmd), licensesCmd, exitNodeCmd(), updateCmd, whoisCmd, - debugCmd, - driveCmd, + debugCmd(), + nilOrCall(maybeDriveCmd), idTokenCmd, - }, + configureHostCmd(), + systrayCmd, + appcRoutesCmd, + ), FlagSet: rootfs, Exec: func(ctx context.Context, args []string) error { + if *jsonDocs { + return printJSONDocs(rootCmd) + } if len(args) > 0 { return fmt.Errorf("tailscale: unknown subcommand: %s", args[0]) } @@ -217,10 +290,6 @@ change in the future. }, } - if runtime.GOOS == "linux" && distro.Get() == distro.Synology { - rootCmd.Subcommands = append(rootCmd.Subcommands, configureHostCmd) - } - walkCommands(rootCmd, func(w cmdWalk) bool { if w.UsageFunc == nil { w.UsageFunc = usageFunc @@ -229,9 +298,21 @@ change in the future. }) ffcomplete.Inject(rootCmd, func(c *ffcli.Command) { c.LongHelp = hidden + c.LongHelp }, usageFunc) + noDupFlagify(rootCmd) return rootCmd } +func nonNilCmds(cmds ...*ffcli.Command) []*ffcli.Command { + return slicesx.AppendNonzero(cmds[:0], cmds) +} + +func nilOrCall(f func() *ffcli.Command) *ffcli.Command { + if f == nil { + return nil + } + return f() +} + func fatalf(format string, a ...any) { if Fatalf != nil { Fatalf(format, a...) @@ -409,3 +490,79 @@ func colorableOutput() (w io.Writer, ok bool) { } return colorable.NewColorableStdout(), true } + +type commandDoc struct { + Name string + Desc string + Subcommands []commandDoc `json:",omitempty"` + Flags []flagDoc `json:",omitempty"` +} + +type flagDoc struct { + Name string + Desc string +} + +func printJSONDocs(root *ffcli.Command) error { + docs := jsonDocsWalk(root) + return json.NewEncoder(os.Stdout).Encode(docs) +} + +func jsonDocsWalk(cmd *ffcli.Command) *commandDoc { + res := &commandDoc{ + Name: cmd.Name, + } + if cmd.LongHelp != "" { + res.Desc = cmd.LongHelp + } else if cmd.ShortHelp != "" { + res.Desc = cmd.ShortHelp + } else { + res.Desc = cmd.ShortUsage + } + if strings.HasPrefix(res.Desc, hidden) { + return nil + } + if cmd.FlagSet != nil { + cmd.FlagSet.VisitAll(func(f *flag.Flag) { + if strings.HasPrefix(f.Usage, hidden) { + return + } + res.Flags = append(res.Flags, flagDoc{ + Name: f.Name, + Desc: f.Usage, + }) + }) + } + for _, sub := range cmd.Subcommands { + subj := jsonDocsWalk(sub) + if subj != nil { + res.Subcommands = append(res.Subcommands, *subj) + } + } + return res +} + +func lastSeenFmt(t time.Time) string { + if t.IsZero() { + return "" + } + d := max(time.Since(t), time.Minute) // at least 1 minute + + switch { + case d < time.Hour: + return fmt.Sprintf(", last seen %dm ago", int(d.Minutes())) + case d < 24*time.Hour: + return fmt.Sprintf(", last seen %dh ago", int(d.Hours())) + default: + return fmt.Sprintf(", last seen %dd ago", int(d.Hours()/24)) + } +} + +var hookFixTailscaledConnectError feature.Hook[func(error) error] // for cliconndiag + +func fixTailscaledConnectError(origErr error) error { + if f, ok := hookFixTailscaledConnectError.GetOk(); ok { + return f(origErr) + } + return origErr +} diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index b0658fd95..8762b7aae 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "net/netip" "reflect" "strings" @@ -16,6 +17,7 @@ import ( qt "github.com/frankban/quicktest" "github.com/google/go-cmp/cmp" + "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/envknob" "tailscale.com/health/healthmsg" "tailscale.com/ipn" @@ -23,10 +25,12 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tstest" + "tailscale.com/tstest/deptest" "tailscale.com/types/logger" "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" + "tailscale.com/util/set" "tailscale.com/version/distro" ) @@ -170,6 +174,7 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { curUser string // os.Getenv("USER") on the client side goos string // empty means "linux" distro distro.Distro + backendState string // empty means "Running" want string }{ @@ -184,6 +189,28 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { }, want: "", }, + { + name: "bare_up_needs_login_default_prefs", + flags: []string{}, + curPrefs: ipn.NewPrefs(), + backendState: ipn.NeedsLogin.String(), + want: "", + }, + { + name: "bare_up_needs_login_losing_prefs", + flags: []string{}, + curPrefs: &ipn.Prefs{ + // defaults: + ControlURL: ipn.DefaultControlURL, + WantRunning: false, + NetfilterMode: preftype.NetfilterOn, + NoStatefulFiltering: opt.NewBool(true), + // non-default: + CorpDNS: false, + }, + backendState: ipn.NeedsLogin.String(), + want: accidentalUpPrefix + " --accept-dns=false", + }, { name: "losing_hostname", flags: []string{"--accept-dns"}, @@ -600,12 +627,29 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { goos: "linux", want: "", }, + { + name: "losing_report_posture", + flags: []string{"--accept-dns"}, + curPrefs: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + WantRunning: false, + CorpDNS: true, + PostureChecking: true, + NetfilterMode: preftype.NetfilterOn, + NoStatefulFiltering: opt.NewBool(true), + }, + want: accidentalUpPrefix + " --accept-dns --report-posture", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - goos := "linux" - if tt.goos != "" { - goos = tt.goos + goos := stdcmp.Or(tt.goos, "linux") + backendState := stdcmp.Or(tt.backendState, ipn.Running.String()) + // Needs to match the other conditions in checkForAccidentalSettingReverts + tt.curPrefs.Persist = &persist.Persist{ + UserProfile: tailcfg.UserProfile{ + LoginName: "janet", + }, } var upArgs upArgsT flagSet := newUpFlagSet(goos, &upArgs, "up") @@ -621,10 +665,11 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { curExitNodeIP: tt.curExitNodeIP, distro: tt.distro, user: tt.curUser, + backendState: backendState, } applyImplicitPrefs(newPrefs, tt.curPrefs, upEnv) var got string - if err := checkForAccidentalSettingReverts(newPrefs, tt.curPrefs, upEnv); err != nil { + if _, err := checkForAccidentalSettingReverts(newPrefs, tt.curPrefs, upEnv); err != nil { got = err.Error() } if strings.TrimSpace(got) != tt.want { @@ -946,10 +991,17 @@ func TestPrefFlagMapping(t *testing.T) { // Handled by the tailscale share subcommand, we don't want a CLI // flag for this. continue + case "AdvertiseServices": + // Handled by the tailscale serve subcommand, we don't want a + // CLI flag for this. + continue case "InternalExitNodePrior": // Used internally by LocalBackend as part of exit node usage toggling. // No CLI flag for this. continue + case "AutoExitNode": + // Handled by tailscale {set,up} --exit-node=auto:any. + continue } t.Errorf("unexpected new ipn.Pref field %q is not handled by up.go (see addPrefFlagMapping and checkForAccidentalSettingReverts)", prefName) } @@ -987,13 +1039,10 @@ func TestUpdatePrefs(t *testing.T) { wantErrSubtr string }{ { - name: "bare_up_means_up", - flags: []string{}, - curPrefs: &ipn.Prefs{ - ControlURL: ipn.DefaultControlURL, - WantRunning: false, - Hostname: "foo", - }, + name: "bare_up_means_up", + flags: []string{}, + curPrefs: ipn.NewPrefs(), + wantSimpleUp: false, // user profile not set, so no simple up }, { name: "just_up", @@ -1007,6 +1056,32 @@ func TestUpdatePrefs(t *testing.T) { }, wantSimpleUp: true, }, + { + name: "just_up_needs_login_default_prefs", + flags: []string{}, + curPrefs: ipn.NewPrefs(), + env: upCheckEnv{ + backendState: "NeedsLogin", + }, + wantSimpleUp: false, + }, + { + name: "just_up_needs_login_losing_prefs", + flags: []string{}, + curPrefs: &ipn.Prefs{ + // defaults: + ControlURL: ipn.DefaultControlURL, + WantRunning: false, + NetfilterMode: preftype.NetfilterOn, + // non-default: + CorpDNS: false, + }, + env: upCheckEnv{ + backendState: "NeedsLogin", + }, + wantSimpleUp: false, + wantErrSubtr: "tailscale up --accept-dns=false", + }, { name: "just_edit", flags: []string{}, @@ -1040,6 +1115,7 @@ func TestUpdatePrefs(t *testing.T) { NoSNATSet: true, NoStatefulFilteringSet: true, OperatorUserSet: true, + PostureCheckingSet: true, RouteAllSet: true, RunSSHSet: true, ShieldsUpSet: true, @@ -1312,6 +1388,27 @@ func TestUpdatePrefs(t *testing.T) { } }, }, + { + name: "auto_exit_node", + flags: []string{"--exit-node=auto:any"}, + curPrefs: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + CorpDNS: true, // enabled by [ipn.NewPrefs] by default + NetfilterMode: preftype.NetfilterOn, // enabled by [ipn.NewPrefs] by default + }, + wantJustEditMP: &ipn.MaskedPrefs{ + WantRunningSet: true, // enabled by default for tailscale up + AutoExitNodeSet: true, + ExitNodeIDSet: true, // we want ExitNodeID cleared + ExitNodeIPSet: true, // same for ExitNodeIP + }, + env: upCheckEnv{backendState: "Running"}, + checkUpdatePrefsMutations: func(t *testing.T, newPrefs *ipn.Prefs) { + if newPrefs.AutoExitNode != ipn.AnyExitNode { + t.Errorf("AutoExitNode: got %q; want %q", newPrefs.AutoExitNode, ipn.AnyExitNode) + } + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1372,23 +1469,28 @@ var cmpIP = cmp.Comparer(func(a, b netip.Addr) bool { }) func TestCleanUpArgs(t *testing.T) { + type S = []string c := qt.New(t) tests := []struct { in []string want []string }{ - {in: []string{"something"}, want: []string{"something"}}, - {in: []string{}, want: []string{}}, - {in: []string{"--authkey=0"}, want: []string{"--auth-key=0"}}, - {in: []string{"a", "--authkey=1", "b"}, want: []string{"a", "--auth-key=1", "b"}}, - {in: []string{"a", "--auth-key=2", "b"}, want: []string{"a", "--auth-key=2", "b"}}, - {in: []string{"a", "-authkey=3", "b"}, want: []string{"a", "--auth-key=3", "b"}}, - {in: []string{"a", "-auth-key=4", "b"}, want: []string{"a", "-auth-key=4", "b"}}, - {in: []string{"a", "--authkey", "5", "b"}, want: []string{"a", "--auth-key", "5", "b"}}, - {in: []string{"a", "-authkey", "6", "b"}, want: []string{"a", "--auth-key", "6", "b"}}, - {in: []string{"a", "authkey", "7", "b"}, want: []string{"a", "authkey", "7", "b"}}, - {in: []string{"--authkeyexpiry", "8"}, want: []string{"--authkeyexpiry", "8"}}, - {in: []string{"--auth-key-expiry", "9"}, want: []string{"--auth-key-expiry", "9"}}, + {in: S{"something"}, want: S{"something"}}, + {in: S{}, want: S{}}, + {in: S{"--authkey=0"}, want: S{"--auth-key=0"}}, + {in: S{"a", "--authkey=1", "b"}, want: S{"a", "--auth-key=1", "b"}}, + {in: S{"a", "--auth-key=2", "b"}, want: S{"a", "--auth-key=2", "b"}}, + {in: S{"a", "-authkey=3", "b"}, want: S{"a", "--auth-key=3", "b"}}, + {in: S{"a", "-auth-key=4", "b"}, want: S{"a", "-auth-key=4", "b"}}, + {in: S{"a", "--authkey", "5", "b"}, want: S{"a", "--auth-key", "5", "b"}}, + {in: S{"a", "-authkey", "6", "b"}, want: S{"a", "--auth-key", "6", "b"}}, + {in: S{"a", "authkey", "7", "b"}, want: S{"a", "authkey", "7", "b"}}, + {in: S{"--authkeyexpiry", "8"}, want: S{"--authkeyexpiry", "8"}}, + {in: S{"--auth-key-expiry", "9"}, want: S{"--auth-key-expiry", "9"}}, + + {in: S{"--posture-checking"}, want: S{"--report-posture"}}, + {in: S{"-posture-checking"}, want: S{"--report-posture"}}, + {in: S{"--posture-checking=nein"}, want: S{"--report-posture=nein"}}, } for _, tt := range tests { @@ -1448,7 +1550,7 @@ func TestParseNLArgs(t *testing.T) { name: "disablements not allowed", input: []string{"disablement:" + strings.Repeat("02", 32)}, parseKeys: true, - wantErr: fmt.Errorf("parsing key 1: key hex string doesn't have expected type prefix nlpub:"), + wantErr: fmt.Errorf("parsing key 1: key hex string doesn't have expected type prefix tlpub:"), }, { name: "keys not allowed", @@ -1476,3 +1578,148 @@ func TestParseNLArgs(t *testing.T) { }) } } + +// makeQuietContinueOnError modifies c recursively to make all the +// flagsets have error mode flag.ContinueOnError and not +// spew all over stderr. +func makeQuietContinueOnError(c *ffcli.Command) { + if c.FlagSet != nil { + c.FlagSet.Init(c.Name, flag.ContinueOnError) + c.FlagSet.Usage = func() {} + c.FlagSet.SetOutput(io.Discard) + } + c.UsageFunc = func(*ffcli.Command) string { return "" } + for _, sub := range c.Subcommands { + makeQuietContinueOnError(sub) + } +} + +// see tailscale/tailscale#6813 +func TestNoDups(t *testing.T) { + tests := []struct { + name string + args []string + want string + }{ + { + name: "dup-boolean", + args: []string{"up", "--json", "--json"}, + want: "error parsing commandline arguments: invalid boolean flag json: flag provided multiple times", + }, + { + name: "dup-string", + args: []string{"up", "--hostname=foo", "--hostname=bar"}, + want: "error parsing commandline arguments: invalid value \"bar\" for flag -hostname: flag provided multiple times", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newRootCmd() + makeQuietContinueOnError(cmd) + err := cmd.Parse(tt.args) + if got := fmt.Sprint(err); got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestHelpAlias(t *testing.T) { + var stdout, stderr bytes.Buffer + tstest.Replace[io.Writer](t, &Stdout, &stdout) + tstest.Replace[io.Writer](t, &Stderr, &stderr) + + gotExit0 := false + defer func() { + if !gotExit0 { + t.Error("expected os.Exit(0) to be called") + return + } + if !strings.Contains(stderr.String(), "SUBCOMMANDS") { + t.Errorf("expected help output to contain SUBCOMMANDS; got stderr=%q; stdout=%q", stderr.String(), stdout.String()) + } + }() + defer func() { + if e := recover(); e != nil { + if strings.Contains(fmt.Sprint(e), "unexpected call to os.Exit(0)") { + gotExit0 = true + } else { + t.Errorf("unexpected panic: %v", e) + } + } + }() + err := Run([]string{"help"}) + if err != nil { + t.Fatalf("Run: %v", err) + } +} + +func TestDocs(t *testing.T) { + root := newRootCmd() + check := func(t *testing.T, c *ffcli.Command) { + shortVerb, _, ok := strings.Cut(c.ShortHelp, " ") + if !ok || shortVerb == "" { + t.Errorf("couldn't find verb+space in ShortHelp") + } else { + if strings.HasSuffix(shortVerb, ".") { + t.Errorf("ShortHelp shouldn't end in period; got %q", c.ShortHelp) + } + if b := shortVerb[0]; b >= 'a' && b <= 'z' { + t.Errorf("ShortHelp should start with upper-case letter; got %q", c.ShortHelp) + } + if strings.HasSuffix(shortVerb, "s") && shortVerb != "Does" { + t.Errorf("verb %q ending in 's' is unexpected, from %q", shortVerb, c.ShortHelp) + } + } + + name := t.Name() + wantPfx := strings.ReplaceAll(strings.TrimPrefix(name, "TestDocs/"), "/", " ") + switch name { + case "TestDocs/tailscale/completion/bash", + "TestDocs/tailscale/completion/zsh": + wantPfx = "" // special-case exceptions + } + if !strings.HasPrefix(c.ShortUsage, wantPfx) { + t.Errorf("ShortUsage should start with %q; got %q", wantPfx, c.ShortUsage) + } + } + + var walk func(t *testing.T, c *ffcli.Command) + walk = func(t *testing.T, c *ffcli.Command) { + t.Run(c.Name, func(t *testing.T) { + check(t, c) + for _, sub := range c.Subcommands { + walk(t, sub) + } + }) + } + walk(t, root) +} + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "arm64", + WantDeps: set.Of( + "tailscale.com/feature/capture/dissector", // want the Lua by default + ), + BadDeps: map[string]string{ + "tailscale.com/feature/capture": "don't link capture code", + "tailscale.com/net/packet": "why we passing packets in the CLI?", + "tailscale.com/net/flowtrack": "why we tracking flows in the CLI?", + }, + }.Check(t) +} + +func TestDepsNoCapture(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "arm64", + Tags: "ts_omit_capture", + BadDeps: map[string]string{ + "tailscale.com/feature/capture": "don't link capture code", + "tailscale.com/feature/capture/dissector": "don't like the Lua", + }, + }.Check(t) + +} diff --git a/cmd/tailscale/cli/configure-jetkvm.go b/cmd/tailscale/cli/configure-jetkvm.go new file mode 100644 index 000000000..c80bf6736 --- /dev/null +++ b/cmd/tailscale/cli/configure-jetkvm.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android && arm + +package cli + +import ( + "bytes" + "context" + "errors" + "flag" + "os" + "runtime" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/version/distro" +) + +func init() { + maybeJetKVMConfigureCmd = jetKVMConfigureCmd +} + +func jetKVMConfigureCmd() *ffcli.Command { + if runtime.GOOS != "linux" || distro.Get() != distro.JetKVM { + return nil + } + return &ffcli.Command{ + Name: "jetkvm", + Exec: runConfigureJetKVM, + ShortUsage: "tailscale configure jetkvm", + ShortHelp: "Configure JetKVM to run tailscaled at boot", + LongHelp: strings.TrimSpace(` +This command configures the JetKVM host to run tailscaled at boot. +`), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("jetkvm") + return fs + })(), + } +} + +func runConfigureJetKVM(ctx context.Context, args []string) error { + if len(args) > 0 { + return errors.New("unknown arguments") + } + if runtime.GOOS != "linux" || distro.Get() != distro.JetKVM { + return errors.New("only implemented on JetKVM") + } + if err := os.MkdirAll("/userdata/init.d", 0755); err != nil { + return errors.New("unable to create /userdata/init.d") + } + err := os.WriteFile("/userdata/init.d/S22tailscale", bytes.TrimLeft([]byte(` +#!/bin/sh +# /userdata/init.d/S22tailscale +# Start/stop tailscaled + +case "$1" in + start) + /userdata/tailscale/tailscaled > /dev/null 2>&1 & + ;; + stop) + killall tailscaled + ;; + *) + echo "Usage: $0 {start|stop}" + exit 1 + ;; +esac +`), "\n"), 0755) + if err != nil { + return err + } + + if err := os.Symlink("/userdata/tailscale/tailscale", "/bin/tailscale"); err != nil { + if !os.IsExist(err) { + return err + } + } + + printf("Done. Now restart your JetKVM.\n") + return nil +} diff --git a/cmd/tailscale/cli/configure-kube.go b/cmd/tailscale/cli/configure-kube.go index 6af15e3d9..e74e88779 100644 --- a/cmd/tailscale/cli/configure-kube.go +++ b/cmd/tailscale/cli/configure-kube.go @@ -9,44 +9,55 @@ import ( "errors" "flag" "fmt" + "net/netip" + "net/url" "os" "path/filepath" "slices" "strings" + "time" "github.com/peterbourgon/ff/v3/ffcli" "k8s.io/client-go/util/homedir" "sigs.k8s.io/yaml" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" + "tailscale.com/util/dnsname" "tailscale.com/version" ) -func init() { - configureCmd.Subcommands = append(configureCmd.Subcommands, configureKubeconfigCmd) +var configureKubeconfigArgs struct { + http bool // Use HTTP instead of HTTPS (default) for the auth proxy. } -var configureKubeconfigCmd = &ffcli.Command{ - Name: "kubeconfig", - ShortHelp: "[ALPHA] Connect to a Kubernetes cluster using a Tailscale Auth Proxy", - ShortUsage: "tailscale configure kubeconfig ", - LongHelp: strings.TrimSpace(` +func configureKubeconfigCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "kubeconfig", + ShortHelp: "[ALPHA] Connect to a Kubernetes cluster using a Tailscale Auth Proxy", + ShortUsage: "tailscale configure kubeconfig ", + LongHelp: strings.TrimSpace(` Run this command to configure kubectl to connect to a Kubernetes cluster over Tailscale. The hostname argument should be set to the Tailscale hostname of the peer running as an auth proxy in the cluster. See: https://tailscale.com/s/k8s-auth-proxy `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("kubeconfig") - return fs - })(), - Exec: runConfigureKubeconfig, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("kubeconfig") + fs.BoolVar(&configureKubeconfigArgs.http, "http", false, "Use HTTP instead of HTTPS to connect to the auth proxy. Ignored if you include a scheme in the hostname argument.") + return fs + })(), + Exec: runConfigureKubeconfig, + } } // kubeconfigPath returns the path to the kubeconfig file for the current user. func kubeconfigPath() (string, error) { if kubeconfig := os.Getenv("KUBECONFIG"); kubeconfig != "" { if version.IsSandboxedMacOS() { - return "", errors.New("$KUBECONFIG is incompatible with the App Store version") + return "", errors.New("cannot read $KUBECONFIG on GUI builds of the macOS client: this requires the open-source tailscaled distribution") } var out string for _, out = range filepath.SplitList(kubeconfig) { @@ -72,10 +83,13 @@ func kubeconfigPath() (string, error) { } func runConfigureKubeconfig(ctx context.Context, args []string) error { - if len(args) != 1 { - return errors.New("unknown arguments") + if len(args) != 1 || args[0] == "" { + return flag.ErrHelp + } + hostOrFQDNOrIP, http, err := getInputs(args[0], configureKubeconfigArgs.http) + if err != nil { + return fmt.Errorf("error parsing inputs: %w", err) } - hostOrFQDN := args[0] st, err := localClient.Status(ctx) if err != nil { @@ -84,22 +98,45 @@ func runConfigureKubeconfig(ctx context.Context, args []string) error { if st.BackendState != "Running" { return errors.New("Tailscale is not running") } - targetFQDN, ok := nodeDNSNameFromArg(st, hostOrFQDN) - if !ok { - return fmt.Errorf("no peer found with hostname %q", hostOrFQDN) + nm, err := getNetMap(ctx) + if err != nil { + return err + } + + targetFQDN, err := nodeOrServiceDNSNameFromArg(st, nm, hostOrFQDNOrIP) + if err != nil { + return err } targetFQDN = strings.TrimSuffix(targetFQDN, ".") var kubeconfig string if kubeconfig, err = kubeconfigPath(); err != nil { return err } - if err = setKubeconfigForPeer(targetFQDN, kubeconfig); err != nil { + scheme := "https://" + if http { + scheme = "http://" + } + if err = setKubeconfigForPeer(scheme, targetFQDN, kubeconfig); err != nil { return err } - printf("kubeconfig configured for %q\n", hostOrFQDN) + printf("kubeconfig configured for %q at URL %q\n", targetFQDN, scheme+targetFQDN) return nil } +func getInputs(arg string, httpArg bool) (string, bool, error) { + u, err := url.Parse(arg) + if err != nil { + return "", false, err + } + + switch u.Scheme { + case "http", "https": + return u.Host, u.Scheme == "http", nil + default: + return arg, httpArg, nil + } +} + // appendOrSetNamed finds a map with a "name" key matching name in dst, and // replaces it with val. If no such map is found, val is appended to dst. func appendOrSetNamed(dst []any, name string, val map[string]any) []any { @@ -118,7 +155,7 @@ func appendOrSetNamed(dst []any, name string, val map[string]any) []any { var errInvalidKubeconfig = errors.New("invalid kubeconfig") -func updateKubeconfig(cfgYaml []byte, fqdn string) ([]byte, error) { +func updateKubeconfig(cfgYaml []byte, scheme, fqdn string) ([]byte, error) { var cfg map[string]any if len(cfgYaml) > 0 { if err := yaml.Unmarshal(cfgYaml, &cfg); err != nil { @@ -141,7 +178,7 @@ func updateKubeconfig(cfgYaml []byte, fqdn string) ([]byte, error) { cfg["clusters"] = appendOrSetNamed(clusters, fqdn, map[string]any{ "name": fqdn, "cluster": map[string]string{ - "server": "https://" + fqdn, + "server": scheme + fqdn, }, }) @@ -174,7 +211,7 @@ func updateKubeconfig(cfgYaml []byte, fqdn string) ([]byte, error) { return yaml.Marshal(cfg) } -func setKubeconfigForPeer(fqdn, filePath string) error { +func setKubeconfigForPeer(scheme, fqdn, filePath string) error { dir := filepath.Dir(filePath) if _, err := os.Stat(dir); err != nil { if !os.IsNotExist(err) { @@ -193,9 +230,97 @@ func setKubeconfigForPeer(fqdn, filePath string) error { if err != nil && !os.IsNotExist(err) { return fmt.Errorf("reading kubeconfig: %w", err) } - b, err = updateKubeconfig(b, fqdn) + b, err = updateKubeconfig(b, scheme, fqdn) if err != nil { return err } return os.WriteFile(filePath, b, 0600) } + +// nodeOrServiceDNSNameFromArg returns the PeerStatus.DNSName value from a peer +// in st that matches the input arg which can be a base name, full DNS name, or +// an IP. If none is found, it looks for a Tailscale Service +func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, nm *netmap.NetworkMap, arg string) (string, error) { + // First check for a node DNS name. + if dnsName, ok := nodeDNSNameFromArg(st, arg); ok { + return dnsName, nil + } + + // If not found, check for a Tailscale Service DNS name. + rec, ok := serviceDNSRecordFromNetMap(nm, st.CurrentTailnet.MagicDNSSuffix, arg) + if !ok { + return "", fmt.Errorf("no peer found for %q", arg) + } + + // Validate we can see a peer advertising the Tailscale Service. + ip, err := netip.ParseAddr(rec.Value) + if err != nil { + return "", fmt.Errorf("error parsing ExtraRecord IP address %q: %w", rec.Value, err) + } + ipPrefix := netip.PrefixFrom(ip, ip.BitLen()) + for _, ps := range st.Peer { + for _, allowedIP := range ps.AllowedIPs.All() { + if allowedIP == ipPrefix { + return rec.Name, nil + } + } + } + + return "", fmt.Errorf("%q is in MagicDNS, but is not currently reachable on any known peer", arg) +} + +func getNetMap(ctx context.Context) (*netmap.NetworkMap, error) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + watcher, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) + if err != nil { + return nil, err + } + defer watcher.Close() + + n, err := watcher.Next() + if err != nil { + return nil, err + } + + return n.NetMap, nil +} + +func serviceDNSRecordFromNetMap(nm *netmap.NetworkMap, tcd, arg string) (rec tailcfg.DNSRecord, ok bool) { + argIP, _ := netip.ParseAddr(arg) + argFQDN, err := dnsname.ToFQDN(arg) + argFQDNValid := err == nil + if !argIP.IsValid() && !argFQDNValid { + return rec, false + } + + for _, rec := range nm.DNS.ExtraRecords { + if argIP.IsValid() { + recIP, _ := netip.ParseAddr(rec.Value) + if recIP == argIP { + return rec, true + } + continue + } + + if !argFQDNValid { + continue + } + + recFirstLabel := dnsname.FirstLabel(rec.Name) + if strings.EqualFold(arg, recFirstLabel) { + return rec, true + } + + recFQDN, err := dnsname.ToFQDN(rec.Name) + if err != nil { + continue + } + if strings.EqualFold(argFQDN.WithTrailingDot(), recFQDN.WithTrailingDot()) { + return rec, true + } + } + + return tailcfg.DNSRecord{}, false +} diff --git a/cmd/tailscale/cli/configure-kube_omit.go b/cmd/tailscale/cli/configure-kube_omit.go new file mode 100644 index 000000000..130f2870f --- /dev/null +++ b/cmd/tailscale/cli/configure-kube_omit.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_kube + +package cli + +import "github.com/peterbourgon/ff/v3/ffcli" + +func configureKubeconfigCmd() *ffcli.Command { + // omitted from the build when the ts_omit_kube build tag is set + return nil +} diff --git a/cmd/tailscale/cli/configure-kube_test.go b/cmd/tailscale/cli/configure-kube_test.go index d71a9b627..0c8b6b2b6 100644 --- a/cmd/tailscale/cli/configure-kube_test.go +++ b/cmd/tailscale/cli/configure-kube_test.go @@ -6,6 +6,7 @@ package cli import ( "bytes" + "fmt" "strings" "testing" @@ -16,6 +17,7 @@ func TestKubeconfig(t *testing.T) { const fqdn = "foo.tail-scale.ts.net" tests := []struct { name string + http bool in string want string wantErr error @@ -48,6 +50,27 @@ contexts: current-context: foo.tail-scale.ts.net kind: Config users: +- name: tailscale-auth + user: + token: unused`, + }, + { + name: "empty_http", + http: true, + in: "", + want: `apiVersion: v1 +clusters: +- cluster: + server: http://foo.tail-scale.ts.net + name: foo.tail-scale.ts.net +contexts: +- context: + cluster: foo.tail-scale.ts.net + user: tailscale-auth + name: foo.tail-scale.ts.net +current-context: foo.tail-scale.ts.net +kind: Config +users: - name: tailscale-auth user: token: unused`, @@ -202,7 +225,11 @@ users: } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := updateKubeconfig([]byte(tt.in), fqdn) + scheme := "https://" + if tt.http { + scheme = "http://" + } + got, err := updateKubeconfig([]byte(tt.in), scheme, fqdn) if err != nil { if err != tt.wantErr { t.Fatalf("updateKubeconfig() error = %v, wantErr %v", err, tt.wantErr) @@ -219,3 +246,30 @@ users: }) } } + +func TestGetInputs(t *testing.T) { + for _, arg := range []string{ + "foo.tail-scale.ts.net", + "foo", + "127.0.0.1", + } { + for _, prefix := range []string{"", "https://", "http://"} { + for _, httpFlag := range []bool{false, true} { + expectedHost := arg + expectedHTTP := (httpFlag && !strings.HasPrefix(prefix, "https://")) || strings.HasPrefix(prefix, "http://") + t.Run(fmt.Sprintf("%s%s_http=%v", prefix, arg, httpFlag), func(t *testing.T) { + host, http, err := getInputs(prefix+arg, httpFlag) + if err != nil { + t.Fatal(err) + } + if host != expectedHost { + t.Errorf("host = %v, want %v", host, expectedHost) + } + if http != expectedHTTP { + t.Errorf("http = %v, want %v", http, expectedHTTP) + } + }) + } + } + } +} diff --git a/cmd/tailscale/cli/configure-synology-cert.go b/cmd/tailscale/cli/configure-synology-cert.go index aabcb8dfa..b5168ef92 100644 --- a/cmd/tailscale/cli/configure-synology-cert.go +++ b/cmd/tailscale/cli/configure-synology-cert.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !ts_omit_acme && !ts_omit_synology + package cli import ( @@ -22,22 +24,31 @@ import ( "tailscale.com/version/distro" ) -var synologyConfigureCertCmd = &ffcli.Command{ - Name: "synology-cert", - Exec: runConfigureSynologyCert, - ShortHelp: "Configure Synology with a TLS certificate for your tailnet", - ShortUsage: "synology-cert [--domain ]", - LongHelp: strings.TrimSpace(` +func init() { + maybeConfigSynologyCertCmd = synologyConfigureCertCmd +} + +func synologyConfigureCertCmd() *ffcli.Command { + if runtime.GOOS != "linux" || distro.Get() != distro.Synology { + return nil + } + return &ffcli.Command{ + Name: "synology-cert", + Exec: runConfigureSynologyCert, + ShortHelp: "Configure Synology with a TLS certificate for your tailnet", + ShortUsage: "synology-cert [--domain ]", + LongHelp: strings.TrimSpace(` This command is intended to run periodically as root on a Synology device to create or refresh the TLS certificate for the tailnet domain. See: https://tailscale.com/kb/1153/enabling-https `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("synology-cert") - fs.StringVar(&synologyConfigureCertArgs.domain, "domain", "", "Tailnet domain to create or refresh certificates for. Ignored if only one domain exists.") - return fs - })(), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("synology-cert") + fs.StringVar(&synologyConfigureCertArgs.domain, "domain", "", "Tailnet domain to create or refresh certificates for. Ignored if only one domain exists.") + return fs + })(), + } } var synologyConfigureCertArgs struct { diff --git a/cmd/tailscale/cli/configure-synology-cert_test.go b/cmd/tailscale/cli/configure-synology-cert_test.go index 801285e55..c7da5622f 100644 --- a/cmd/tailscale/cli/configure-synology-cert_test.go +++ b/cmd/tailscale/cli/configure-synology-cert_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !ts_omit_acme + package cli import ( diff --git a/cmd/tailscale/cli/configure-synology.go b/cmd/tailscale/cli/configure-synology.go index 9d674e56d..f0f05f757 100644 --- a/cmd/tailscale/cli/configure-synology.go +++ b/cmd/tailscale/cli/configure-synology.go @@ -21,34 +21,49 @@ import ( // configureHostCmd is the "tailscale configure-host" command which was once // used to configure Synology devices, but is now a compatibility alias to // "tailscale configure synology". -var configureHostCmd = &ffcli.Command{ - Name: "configure-host", - Exec: runConfigureSynology, - ShortUsage: "tailscale configure-host\n" + synologyConfigureCmd.ShortUsage, - ShortHelp: synologyConfigureCmd.ShortHelp, - LongHelp: hidden + synologyConfigureCmd.LongHelp, - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("configure-host") - return fs - })(), +// +// It returns nil if the actual "tailscale configure synology" command is not +// available. +func configureHostCmd() *ffcli.Command { + synologyConfigureCmd := synologyConfigureCmd() + if synologyConfigureCmd == nil { + // No need to offer this compatibility alias if the actual command is not available. + return nil + } + return &ffcli.Command{ + Name: "configure-host", + Exec: runConfigureSynology, + ShortUsage: "tailscale configure-host\n" + synologyConfigureCmd.ShortUsage, + ShortHelp: synologyConfigureCmd.ShortHelp, + LongHelp: hidden + synologyConfigureCmd.LongHelp, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("configure-host") + return fs + })(), + } } -var synologyConfigureCmd = &ffcli.Command{ - Name: "synology", - Exec: runConfigureSynology, - ShortUsage: "tailscale configure synology", - ShortHelp: "Configure Synology to enable outbound connections", - LongHelp: strings.TrimSpace(` +func synologyConfigureCmd() *ffcli.Command { + if runtime.GOOS != "linux" || distro.Get() != distro.Synology { + return nil + } + return &ffcli.Command{ + Name: "synology", + Exec: runConfigureSynology, + ShortUsage: "tailscale configure synology", + ShortHelp: "Configure Synology to enable outbound connections", + LongHelp: strings.TrimSpace(` This command is intended to run at boot as root on a Synology device to create the /dev/net/tun device and give the tailscaled binary permission to use it. See: https://tailscale.com/s/synology-outbound `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("synology") - return fs - })(), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("synology") + return fs + })(), + } } func runConfigureSynology(ctx context.Context, args []string) error { diff --git a/cmd/tailscale/cli/configure.go b/cmd/tailscale/cli/configure.go index fd136d766..20236eb28 100644 --- a/cmd/tailscale/cli/configure.go +++ b/cmd/tailscale/cli/configure.go @@ -5,32 +5,49 @@ package cli import ( "flag" - "runtime" "strings" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/version/distro" ) -var configureCmd = &ffcli.Command{ - Name: "configure", - ShortUsage: "tailscale configure ", - ShortHelp: "[ALPHA] Configure the host to enable more Tailscale features", - LongHelp: strings.TrimSpace(` +var ( + maybeJetKVMConfigureCmd, + maybeConfigSynologyCertCmd, + _ func() *ffcli.Command // non-nil only on Linux/arm for JetKVM +) + +func configureCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "configure", + ShortUsage: "tailscale configure ", + ShortHelp: "Configure the host to enable more Tailscale features", + LongHelp: strings.TrimSpace(` The 'configure' set of commands are intended to provide a way to enable different services on the host to use Tailscale in more ways. `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("configure") - return fs - })(), - Subcommands: configureSubcommands(), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("configure") + return fs + })(), + Subcommands: nonNilCmds( + configureKubeconfigCmd(), + synologyConfigureCmd(), + ccall(maybeConfigSynologyCertCmd), + ccall(maybeSysExtCmd), + ccall(maybeVPNConfigCmd), + ccall(maybeJetKVMConfigureCmd), + ccall(maybeSystrayCmd), + ), + } } -func configureSubcommands() (out []*ffcli.Command) { - if runtime.GOOS == "linux" && distro.Get() == distro.Synology { - out = append(out, synologyConfigureCmd) - out = append(out, synologyConfigureCertCmd) +// ccall calls the function f if it is non-nil, and returns its result. +// +// It returns the zero value of the type T if f is nil. +func ccall[T any](f func() T) T { + var zero T + if f == nil { + return zero } - return out + return f() } diff --git a/cmd/tailscale/cli/configure_apple-all.go b/cmd/tailscale/cli/configure_apple-all.go new file mode 100644 index 000000000..5f0da9b95 --- /dev/null +++ b/cmd/tailscale/cli/configure_apple-all.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import "github.com/peterbourgon/ff/v3/ffcli" + +var ( + maybeSysExtCmd func() *ffcli.Command // non-nil only on macOS, see configure_apple.go + maybeVPNConfigCmd func() *ffcli.Command // non-nil only on macOS, see configure_apple.go +) diff --git a/cmd/tailscale/cli/configure_apple.go b/cmd/tailscale/cli/configure_apple.go new file mode 100644 index 000000000..c0d99b90a --- /dev/null +++ b/cmd/tailscale/cli/configure_apple.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package cli + +import ( + "context" + "errors" + + "github.com/peterbourgon/ff/v3/ffcli" +) + +func init() { + maybeSysExtCmd = sysExtCmd + maybeVPNConfigCmd = vpnConfigCmd +} + +// Functions in this file provide a dummy Exec function that only prints an error message for users of the open-source +// tailscaled distribution. On GUI builds, the Swift code in the macOS client handles these commands by not passing the +// flow of execution to the CLI. + +// sysExtCmd returns a command for managing the Tailscale system extension on macOS +// (for the Standalone variant of the client only). +func sysExtCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "sysext", + ShortUsage: "tailscale configure sysext [activate|deactivate|status]", + ShortHelp: "Manage the system extension for macOS (Standalone variant)", + LongHelp: "The sysext set of commands provides a way to activate, deactivate, or manage the state of the Tailscale system extension on macOS. " + + "This is only relevant if you are running the Standalone variant of the Tailscale client for macOS. " + + "To access more detailed information about system extensions installed on this Mac, run 'systemextensionsctl list'.", + Subcommands: []*ffcli.Command{ + { + Name: "activate", + ShortUsage: "tailscale configure sysext activate", + ShortHelp: "Register the Tailscale system extension with macOS.", + LongHelp: "This command registers the Tailscale system extension with macOS. To run Tailscale, you'll also need to install the VPN configuration separately (run `tailscale configure vpn-config install`). After running this command, you need to approve the extension in System Settings > Login Items and Extensions > Network Extensions.", + Exec: requiresStandalone, + }, + { + Name: "deactivate", + ShortUsage: "tailscale configure sysext deactivate", + ShortHelp: "Deactivate the Tailscale system extension on macOS", + LongHelp: "This command deactivates the Tailscale system extension on macOS. To completely remove Tailscale, you'll also need to delete the VPN configuration separately (use `tailscale configure vpn-config uninstall`).", + Exec: requiresStandalone, + }, + { + Name: "status", + ShortUsage: "tailscale configure sysext status", + ShortHelp: "Print the enablement status of the Tailscale system extension", + LongHelp: "This command prints the enablement status of the Tailscale system extension. If the extension is not enabled, run `tailscale sysext activate` to enable it.", + Exec: requiresStandalone, + }, + }, + Exec: requiresStandalone, + } +} + +// vpnConfigCmd returns a command for managing the Tailscale VPN configuration on macOS +// (the entry that appears in System Settings > VPN). +func vpnConfigCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "mac-vpn", + ShortUsage: "tailscale configure mac-vpn [install|uninstall]", + ShortHelp: "Manage the VPN configuration on macOS (App Store and Standalone variants)", + LongHelp: "The vpn-config set of commands provides a way to add or remove the Tailscale VPN configuration from the macOS settings. This is the entry that appears in System Settings > VPN.", + Subcommands: []*ffcli.Command{ + { + Name: "install", + ShortUsage: "tailscale configure mac-vpn install", + ShortHelp: "Write the Tailscale VPN configuration to the macOS settings", + LongHelp: "This command writes the Tailscale VPN configuration to the macOS settings. This is the entry that appears in System Settings > VPN. If you are running the Standalone variant of the client, you'll also need to install the system extension separately (run `tailscale configure sysext activate`).", + Exec: requiresGUI, + }, + { + Name: "uninstall", + ShortUsage: "tailscale configure mac-vpn uninstall", + ShortHelp: "Delete the Tailscale VPN configuration from the macOS settings", + LongHelp: "This command removes the Tailscale VPN configuration from the macOS settings. This is the entry that appears in System Settings > VPN. If you are running the Standalone variant of the client, you'll also need to deactivate the system extension separately (run `tailscale configure sysext deactivate`).", + Exec: requiresGUI, + }, + }, + Exec: func(ctx context.Context, args []string) error { + return errors.New("unsupported command: requires a GUI build of the macOS client") + }, + } +} + +func requiresStandalone(ctx context.Context, args []string) error { + return errors.New("unsupported command: requires the Standalone (.pkg installer) GUI build of the client") +} + +func requiresGUI(ctx context.Context, args []string) error { + return errors.New("unsupported command: requires a GUI build of the macOS client") +} diff --git a/cmd/tailscale/cli/configure_linux-all.go b/cmd/tailscale/cli/configure_linux-all.go new file mode 100644 index 000000000..e645e9654 --- /dev/null +++ b/cmd/tailscale/cli/configure_linux-all.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import "github.com/peterbourgon/ff/v3/ffcli" + +var maybeSystrayCmd func() *ffcli.Command // non-nil only on Linux, see configure_linux.go diff --git a/cmd/tailscale/cli/configure_linux.go b/cmd/tailscale/cli/configure_linux.go new file mode 100644 index 000000000..4bbde8721 --- /dev/null +++ b/cmd/tailscale/cli/configure_linux.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_systray + +package cli + +import ( + "context" + "flag" + "fmt" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/systray" +) + +func init() { + maybeSystrayCmd = systrayConfigCmd +} + +var systrayArgs struct { + initSystem string + installStartup bool +} + +func systrayConfigCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "systray", + ShortUsage: "tailscale configure systray [options]", + ShortHelp: "[ALPHA] Manage the systray client for Linux", + LongHelp: "[ALPHA] The systray set of commands provides a way to configure the systray application on Linux.", + Exec: configureSystray, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("systray") + fs.StringVar(&systrayArgs.initSystem, "enable-startup", "", + "Install startup script for init system. Currently supported systems are [systemd].") + return fs + })(), + } +} + +func configureSystray(_ context.Context, _ []string) error { + if systrayArgs.initSystem != "" { + if err := systray.InstallStartupScript(systrayArgs.initSystem); err != nil { + fmt.Printf("%s\n\n", err.Error()) + return flag.ErrHelp + } + return nil + } + return flag.ErrHelp +} diff --git a/cmd/tailscale/cli/debug-capture.go b/cmd/tailscale/cli/debug-capture.go new file mode 100644 index 000000000..a54066fa6 --- /dev/null +++ b/cmd/tailscale/cli/debug-capture.go @@ -0,0 +1,80 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_capture + +package cli + +import ( + "context" + "flag" + "fmt" + "io" + "os" + "os/exec" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/feature/capture/dissector" +) + +func init() { + debugCaptureCmd = mkDebugCaptureCmd +} + +func mkDebugCaptureCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "capture", + ShortUsage: "tailscale debug capture", + Exec: runCapture, + ShortHelp: "Stream pcaps for debugging", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("capture") + fs.StringVar(&captureArgs.outFile, "o", "", "path to stream the pcap (or - for stdout), leave empty to start wireshark") + return fs + })(), + } +} + +var captureArgs struct { + outFile string +} + +func runCapture(ctx context.Context, args []string) error { + stream, err := localClient.StreamDebugCapture(ctx) + if err != nil { + return err + } + defer stream.Close() + + switch captureArgs.outFile { + case "-": + fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") + _, err = io.Copy(os.Stdout, stream) + return err + case "": + lua, err := os.CreateTemp("", "ts-dissector") + if err != nil { + return err + } + defer os.Remove(lua.Name()) + io.WriteString(lua, dissector.Lua) + if err := lua.Close(); err != nil { + return err + } + + wireshark := exec.CommandContext(ctx, "wireshark", "-X", "lua_script:"+lua.Name(), "-k", "-i", "-") + wireshark.Stdin = stream + wireshark.Stdout = os.Stdout + wireshark.Stderr = os.Stderr + return wireshark.Run() + } + + f, err := os.OpenFile(captureArgs.outFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") + _, err = io.Copy(f, stream) + return err +} diff --git a/cmd/tailscale/cli/debug-peer-relay.go b/cmd/tailscale/cli/debug-peer-relay.go new file mode 100644 index 000000000..bef8b8369 --- /dev/null +++ b/cmd/tailscale/cli/debug-peer-relay.go @@ -0,0 +1,77 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_relayserver + +package cli + +import ( + "bytes" + "cmp" + "context" + "fmt" + "net/netip" + "slices" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/net/udprelay/status" +) + +func init() { + debugPeerRelayCmd = mkDebugPeerRelaySessionsCmd +} + +func mkDebugPeerRelaySessionsCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "peer-relay-sessions", + ShortUsage: "tailscale debug peer-relay-sessions", + Exec: runPeerRelaySessions, + ShortHelp: "Print the current set of active peer relay sessions relayed through this node", + } +} + +func runPeerRelaySessions(ctx context.Context, args []string) error { + srv, err := localClient.DebugPeerRelaySessions(ctx) + if err != nil { + return err + } + + var buf bytes.Buffer + f := func(format string, a ...any) { fmt.Fprintf(&buf, format, a...) } + + f("Server port: ") + if srv.UDPPort == nil { + f("not configured (you can configure the port with 'tailscale set --relay-server-port=')") + } else { + f("%d", *srv.UDPPort) + } + f("\n") + f("Sessions count: %d\n", len(srv.Sessions)) + if len(srv.Sessions) == 0 { + Stdout.Write(buf.Bytes()) + return nil + } + + fmtSessionDirection := func(a, z status.ClientInfo) string { + fmtEndpoint := func(ap netip.AddrPort) string { + if ap.IsValid() { + return ap.String() + } + return "" + } + return fmt.Sprintf("%s(%s) --> %s(%s), Packets: %d Bytes: %d", + fmtEndpoint(a.Endpoint), a.ShortDisco, + fmtEndpoint(z.Endpoint), z.ShortDisco, + a.PacketsTx, a.BytesTx) + } + + f("\n") + slices.SortFunc(srv.Sessions, func(s1, s2 status.ServerSession) int { return cmp.Compare(s1.VNI, s2.VNI) }) + for _, s := range srv.Sessions { + f("VNI: %d\n", s.VNI) + f(" %s\n", fmtSessionDirection(s.Client1, s.Client2)) + f(" %s\n", fmtSessionDirection(s.Client2, s.Client1)) + } + Stdout.Write(buf.Bytes()) + return nil +} diff --git a/cmd/tailscale/cli/debug-portmap.go b/cmd/tailscale/cli/debug-portmap.go new file mode 100644 index 000000000..d8db1442c --- /dev/null +++ b/cmd/tailscale/cli/debug-portmap.go @@ -0,0 +1,79 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_debugportmapper + +package cli + +import ( + "context" + "flag" + "fmt" + "io" + "net/netip" + "os" + "time" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/local" +) + +func init() { + debugPortmapCmd = mkDebugPortmapCmd +} + +func mkDebugPortmapCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "portmap", + ShortUsage: "tailscale debug portmap", + Exec: debugPortmap, + ShortHelp: "Run portmap debugging", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("portmap") + fs.DurationVar(&debugPortmapArgs.duration, "duration", 5*time.Second, "timeout for port mapping") + fs.StringVar(&debugPortmapArgs.ty, "type", "", `portmap debug type (one of "", "pmp", "pcp", or "upnp")`) + fs.StringVar(&debugPortmapArgs.gatewayAddr, "gateway-addr", "", `override gateway IP (must also pass --self-addr)`) + fs.StringVar(&debugPortmapArgs.selfAddr, "self-addr", "", `override self IP (must also pass --gateway-addr)`) + fs.BoolVar(&debugPortmapArgs.logHTTP, "log-http", false, `print all HTTP requests and responses to the log`) + return fs + })(), + } +} + +var debugPortmapArgs struct { + duration time.Duration + gatewayAddr string + selfAddr string + ty string + logHTTP bool +} + +func debugPortmap(ctx context.Context, args []string) error { + opts := &local.DebugPortmapOpts{ + Duration: debugPortmapArgs.duration, + Type: debugPortmapArgs.ty, + LogHTTP: debugPortmapArgs.logHTTP, + } + if (debugPortmapArgs.gatewayAddr != "") != (debugPortmapArgs.selfAddr != "") { + return fmt.Errorf("if one of --gateway-addr and --self-addr is provided, the other must be as well") + } + if debugPortmapArgs.gatewayAddr != "" { + var err error + opts.GatewayAddr, err = netip.ParseAddr(debugPortmapArgs.gatewayAddr) + if err != nil { + return fmt.Errorf("invalid --gateway-addr: %w", err) + } + opts.SelfAddr, err = netip.ParseAddr(debugPortmapArgs.selfAddr) + if err != nil { + return fmt.Errorf("invalid --self-addr: %w", err) + } + } + rc, err := localClient.DebugPortmap(ctx, opts) + if err != nil { + return err + } + defer rc.Close() + + _, err = io.Copy(os.Stdout, rc) + return err +} diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index fdde9ef09..2facd66ae 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -6,6 +6,7 @@ package cli import ( "bufio" "bytes" + "cmp" "context" "encoding/binary" "encoding/json" @@ -16,11 +17,11 @@ import ( "log" "net" "net/http" + "net/http/httptrace" "net/http/httputil" "net/netip" "net/url" "os" - "os/exec" "runtime" "runtime/debug" "strconv" @@ -28,316 +29,361 @@ import ( "time" "github.com/peterbourgon/ff/v3/ffcli" - "golang.org/x/net/http/httpproxy" - "golang.org/x/net/http2" - "tailscale.com/client/tailscale" "tailscale.com/client/tailscale/apitype" - "tailscale.com/control/controlhttp" + "tailscale.com/control/ts2021" + "tailscale.com/feature" + _ "tailscale.com/feature/condregister/useproxy" + "tailscale.com/health" "tailscale.com/hostinfo" - "tailscale.com/internal/noiseconn" "tailscale.com/ipn" + "tailscale.com/net/ace" + "tailscale.com/net/dnscache" + "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" - "tailscale.com/net/tshttpproxy" + "tailscale.com/net/tsdial" "tailscale.com/paths" "tailscale.com/safesocket" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" "tailscale.com/util/must" - "tailscale.com/wgengine/capture" ) -var debugCmd = &ffcli.Command{ - Name: "debug", - Exec: runDebug, - ShortUsage: "tailscale debug ", - ShortHelp: "Debug commands", - LongHelp: hidden + `"tailscale debug" contains misc debug facilities; it is not a stable interface.`, - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("debug") - fs.StringVar(&debugArgs.file, "file", "", "get, delete:NAME, or NAME") - fs.StringVar(&debugArgs.cpuFile, "cpu-profile", "", "if non-empty, grab a CPU profile for --profile-seconds seconds and write it to this file; - for stdout") - fs.StringVar(&debugArgs.memFile, "mem-profile", "", "if non-empty, grab a memory profile and write it to this file; - for stdout") - fs.IntVar(&debugArgs.cpuSec, "profile-seconds", 15, "number of seconds to run a CPU profile for, when --cpu-profile is non-empty") - return fs - })(), - Subcommands: []*ffcli.Command{ - { - Name: "derp-map", - ShortUsage: "tailscale debug derp-map", - Exec: runDERPMap, - ShortHelp: "Print DERP map", - }, - { - Name: "component-logs", - ShortUsage: "tailscale debug component-logs [" + strings.Join(ipn.DebuggableComponents, "|") + "]", - Exec: runDebugComponentLogs, - ShortHelp: "Enable/disable debug logs for a component", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("component-logs") - fs.DurationVar(&debugComponentLogsArgs.forDur, "for", time.Hour, "how long to enable debug logs for; zero or negative means to disable") - return fs - })(), - }, - { - Name: "daemon-goroutines", - ShortUsage: "tailscale debug daemon-goroutines", - Exec: runDaemonGoroutines, - ShortHelp: "Print tailscaled's goroutines", - }, - { - Name: "daemon-logs", - ShortUsage: "tailscale debug daemon-logs", - Exec: runDaemonLogs, - ShortHelp: "Watch tailscaled's server logs", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("daemon-logs") - fs.IntVar(&daemonLogsArgs.verbose, "verbose", 0, "verbosity level") - fs.BoolVar(&daemonLogsArgs.time, "time", false, "include client time") - return fs - })(), - }, - { - Name: "metrics", - ShortUsage: "tailscale debug metrics", - Exec: runDaemonMetrics, - ShortHelp: "Print tailscaled's metrics", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("metrics") - fs.BoolVar(&metricsArgs.watch, "watch", false, "print JSON dump of delta values") - return fs - })(), - }, - { - Name: "env", - ShortUsage: "tailscale debug env", - Exec: runEnv, - ShortHelp: "Print cmd/tailscale environment", - }, - { - Name: "stat", - ShortUsage: "tailscale debug stat ", - Exec: runStat, - ShortHelp: "Stat a file", - }, - { - Name: "hostinfo", - ShortUsage: "tailscale debug hostinfo", - Exec: runHostinfo, - ShortHelp: "Print hostinfo", - }, - { - Name: "local-creds", - ShortUsage: "tailscale debug local-creds", - Exec: runLocalCreds, - ShortHelp: "Print how to access Tailscale LocalAPI", - }, - { - Name: "restun", - ShortUsage: "tailscale debug restun", - Exec: localAPIAction("restun"), - ShortHelp: "Force a magicsock restun", - }, - { - Name: "rebind", - ShortUsage: "tailscale debug rebind", - Exec: localAPIAction("rebind"), - ShortHelp: "Force a magicsock rebind", - }, - { - Name: "derp-set-on-demand", - ShortUsage: "tailscale debug derp-set-on-demand", - Exec: localAPIAction("derp-set-homeless"), - ShortHelp: "Enable DERP on-demand mode (breaks reachability)", - }, - { - Name: "derp-unset-on-demand", - ShortUsage: "tailscale debug derp-unset-on-demand", - Exec: localAPIAction("derp-unset-homeless"), - ShortHelp: "Disable DERP on-demand mode", - }, - { - Name: "break-tcp-conns", - ShortUsage: "tailscale debug break-tcp-conns", - Exec: localAPIAction("break-tcp-conns"), - ShortHelp: "Break any open TCP connections from the daemon", - }, - { - Name: "break-derp-conns", - ShortUsage: "tailscale debug break-derp-conns", - Exec: localAPIAction("break-derp-conns"), - ShortHelp: "Break any open DERP connections from the daemon", - }, - { - Name: "pick-new-derp", - ShortUsage: "tailscale debug pick-new-derp", - Exec: localAPIAction("pick-new-derp"), - ShortHelp: "Switch to some other random DERP home region for a short time", - }, - { - Name: "force-netmap-update", - ShortUsage: "tailscale debug force-netmap-update", - Exec: localAPIAction("force-netmap-update"), - ShortHelp: "Force a full no-op netmap update (for load testing)", - }, - { - // TODO(bradfitz,maisem): eventually promote this out of debug - Name: "reload-config", - ShortUsage: "tailscale debug reload-config", - Exec: reloadConfig, - ShortHelp: "Reload config", - }, - { - Name: "control-knobs", - ShortUsage: "tailscale debug control-knobs", - Exec: debugControlKnobs, - ShortHelp: "See current control knobs", - }, - { - Name: "prefs", - ShortUsage: "tailscale debug prefs", - Exec: runPrefs, - ShortHelp: "Print prefs", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("prefs") - fs.BoolVar(&prefsArgs.pretty, "pretty", false, "If true, pretty-print output") - return fs - })(), - }, - { - Name: "watch-ipn", - ShortUsage: "tailscale debug watch-ipn", - Exec: runWatchIPN, - ShortHelp: "Subscribe to IPN message bus", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("watch-ipn") - fs.BoolVar(&watchIPNArgs.netmap, "netmap", true, "include netmap in messages") - fs.BoolVar(&watchIPNArgs.initial, "initial", false, "include initial status") - fs.BoolVar(&watchIPNArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") - fs.IntVar(&watchIPNArgs.count, "count", 0, "exit after printing this many statuses, or 0 to keep going forever") - return fs - })(), - }, - { - Name: "netmap", - ShortUsage: "tailscale debug netmap", - Exec: runNetmap, - ShortHelp: "Print the current network map", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("netmap") - fs.BoolVar(&netmapArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") - return fs - })(), - }, - { - Name: "via", - ShortUsage: "tailscale debug via \n" + - "tailscale debug via ", - Exec: runVia, - ShortHelp: "Convert between site-specific IPv4 CIDRs and IPv6 'via' routes", - }, - { - Name: "ts2021", - ShortUsage: "tailscale debug ts2021", - Exec: runTS2021, - ShortHelp: "Debug ts2021 protocol connectivity", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("ts2021") - fs.StringVar(&ts2021Args.host, "host", "controlplane.tailscale.com", "hostname of control plane") - fs.IntVar(&ts2021Args.version, "version", int(tailcfg.CurrentCapabilityVersion), "protocol version") - fs.BoolVar(&ts2021Args.verbose, "verbose", false, "be extra verbose") - return fs - })(), - }, - { - Name: "set-expire", - ShortUsage: "tailscale debug set-expire --in=1m", - Exec: runSetExpire, - ShortHelp: "Manipulate node key expiry for testing", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("set-expire") - fs.DurationVar(&setExpireArgs.in, "in", 0, "if non-zero, set node key to expire this duration from now") - return fs - })(), - }, - { - Name: "dev-store-set", - ShortUsage: "tailscale debug dev-store-set", - Exec: runDevStoreSet, - ShortHelp: "Set a key/value pair during development", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("store-set") - fs.BoolVar(&devStoreSetArgs.danger, "danger", false, "accept danger") - return fs - })(), - }, - { - Name: "derp", - ShortUsage: "tailscale debug derp", - Exec: runDebugDERP, - ShortHelp: "Test a DERP configuration", - }, - { - Name: "capture", - ShortUsage: "tailscale debug capture", - Exec: runCapture, - ShortHelp: "Streams pcaps for debugging", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("capture") - fs.StringVar(&captureArgs.outFile, "o", "", "path to stream the pcap (or - for stdout), leave empty to start wireshark") - return fs - })(), - }, - { - Name: "portmap", - ShortUsage: "tailscale debug portmap", - Exec: debugPortmap, - ShortHelp: "Run portmap debugging", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("portmap") - fs.DurationVar(&debugPortmapArgs.duration, "duration", 5*time.Second, "timeout for port mapping") - fs.StringVar(&debugPortmapArgs.ty, "type", "", `portmap debug type (one of "", "pmp", "pcp", or "upnp")`) - fs.StringVar(&debugPortmapArgs.gatewayAddr, "gateway-addr", "", `override gateway IP (must also pass --self-addr)`) - fs.StringVar(&debugPortmapArgs.selfAddr, "self-addr", "", `override self IP (must also pass --gateway-addr)`) - fs.BoolVar(&debugPortmapArgs.logHTTP, "log-http", false, `print all HTTP requests and responses to the log`) - return fs - })(), - }, - { - Name: "peer-endpoint-changes", - ShortUsage: "tailscale debug peer-endpoint-changes ", - Exec: runPeerEndpointChanges, - ShortHelp: "Prints debug information about a peer's endpoint changes", - }, - { - Name: "dial-types", - ShortUsage: "tailscale debug dial-types ", - Exec: runDebugDialTypes, - ShortHelp: "Prints debug information about connecting to a given host or IP", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("dial-types") - fs.StringVar(&debugDialTypesArgs.network, "network", "tcp", `network type to dial ("tcp", "udp", etc.)`) - return fs - })(), - }, - { - Name: "resolve", - ShortUsage: "tailscale debug resolve ", - Exec: runDebugResolve, - ShortHelp: "Does a DNS lookup", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("resolve") - fs.StringVar(&resolveArgs.net, "net", "ip", "network type to resolve (ip, ip4, ip6)") - return fs - })(), - }, - { - Name: "go-buildinfo", - ShortUsage: "tailscale debug go-buildinfo", - ShortHelp: "Prints Go's runtime/debug.BuildInfo", - Exec: runGoBuildInfo, - }, - }, +var ( + debugCaptureCmd func() *ffcli.Command // or nil + debugPortmapCmd func() *ffcli.Command // or nil + debugPeerRelayCmd func() *ffcli.Command // or nil +) + +func debugCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "debug", + Exec: runDebug, + ShortUsage: "tailscale debug ", + ShortHelp: "Debug commands", + LongHelp: hidden + `"tailscale debug" contains misc debug facilities; it is not a stable interface.`, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("debug") + fs.StringVar(&debugArgs.file, "file", "", "get, delete:NAME, or NAME") + fs.StringVar(&debugArgs.cpuFile, "cpu-profile", "", "if non-empty, grab a CPU profile for --profile-seconds seconds and write it to this file; - for stdout") + fs.StringVar(&debugArgs.memFile, "mem-profile", "", "if non-empty, grab a memory profile and write it to this file; - for stdout") + fs.IntVar(&debugArgs.cpuSec, "profile-seconds", 15, "number of seconds to run a CPU profile for, when --cpu-profile is non-empty") + return fs + })(), + Subcommands: nonNilCmds([]*ffcli.Command{ + { + Name: "derp-map", + ShortUsage: "tailscale debug derp-map", + Exec: runDERPMap, + ShortHelp: "Print DERP map", + }, + { + Name: "component-logs", + ShortUsage: "tailscale debug component-logs [" + strings.Join(ipn.DebuggableComponents, "|") + "]", + Exec: runDebugComponentLogs, + ShortHelp: "Enable/disable debug logs for a component", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("component-logs") + fs.DurationVar(&debugComponentLogsArgs.forDur, "for", time.Hour, "how long to enable debug logs for; zero or negative means to disable") + return fs + })(), + }, + { + Name: "daemon-goroutines", + ShortUsage: "tailscale debug daemon-goroutines", + Exec: runDaemonGoroutines, + ShortHelp: "Print tailscaled's goroutines", + }, + { + Name: "daemon-logs", + ShortUsage: "tailscale debug daemon-logs", + Exec: runDaemonLogs, + ShortHelp: "Watch tailscaled's server logs", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("daemon-logs") + fs.IntVar(&daemonLogsArgs.verbose, "verbose", 0, "verbosity level") + fs.BoolVar(&daemonLogsArgs.time, "time", false, "include client time") + return fs + })(), + }, + { + Name: "daemon-bus-events", + ShortUsage: "tailscale debug daemon-bus-events", + Exec: runDaemonBusEvents, + ShortHelp: "Watch events on the tailscaled bus", + }, + { + Name: "daemon-bus-graph", + ShortUsage: "tailscale debug daemon-bus-graph", + Exec: runDaemonBusGraph, + ShortHelp: "Print graph for the tailscaled bus", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("debug-bus-graph") + fs.StringVar(&daemonBusGraphArgs.format, "format", "json", "output format [json/dot]") + return fs + })(), + }, + { + Name: "metrics", + ShortUsage: "tailscale debug metrics", + Exec: runDaemonMetrics, + ShortHelp: "Print tailscaled's metrics", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("metrics") + fs.BoolVar(&metricsArgs.watch, "watch", false, "print JSON dump of delta values") + return fs + })(), + }, + { + Name: "env", + ShortUsage: "tailscale debug env", + Exec: runEnv, + ShortHelp: "Print cmd/tailscale environment", + }, + { + Name: "stat", + ShortUsage: "tailscale debug stat ", + Exec: runStat, + ShortHelp: "Stat a file", + }, + { + Name: "hostinfo", + ShortUsage: "tailscale debug hostinfo", + Exec: runHostinfo, + ShortHelp: "Print hostinfo", + }, + { + Name: "local-creds", + ShortUsage: "tailscale debug local-creds", + Exec: runLocalCreds, + ShortHelp: "Print how to access Tailscale LocalAPI", + }, + { + Name: "localapi", + ShortUsage: "tailscale debug localapi [] []", + Exec: runLocalAPI, + ShortHelp: "Call a LocalAPI method directly", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("localapi") + fs.BoolVar(&localAPIFlags.verbose, "v", false, "verbose; dump HTTP headers") + return fs + })(), + }, + { + Name: "restun", + ShortUsage: "tailscale debug restun", + Exec: localAPIAction("restun"), + ShortHelp: "Force a magicsock restun", + }, + { + Name: "rebind", + ShortUsage: "tailscale debug rebind", + Exec: localAPIAction("rebind"), + ShortHelp: "Force a magicsock rebind", + }, + { + Name: "rotate-disco-key", + ShortUsage: "tailscale debug rotate-disco-key", + Exec: localAPIAction("rotate-disco-key"), + ShortHelp: "Rotate the discovery key", + }, + { + Name: "derp-set-on-demand", + ShortUsage: "tailscale debug derp-set-on-demand", + Exec: localAPIAction("derp-set-homeless"), + ShortHelp: "Enable DERP on-demand mode (breaks reachability)", + }, + { + Name: "derp-unset-on-demand", + ShortUsage: "tailscale debug derp-unset-on-demand", + Exec: localAPIAction("derp-unset-homeless"), + ShortHelp: "Disable DERP on-demand mode", + }, + { + Name: "break-tcp-conns", + ShortUsage: "tailscale debug break-tcp-conns", + Exec: localAPIAction("break-tcp-conns"), + ShortHelp: "Break any open TCP connections from the daemon", + }, + { + Name: "break-derp-conns", + ShortUsage: "tailscale debug break-derp-conns", + Exec: localAPIAction("break-derp-conns"), + ShortHelp: "Break any open DERP connections from the daemon", + }, + { + Name: "pick-new-derp", + ShortUsage: "tailscale debug pick-new-derp", + Exec: localAPIAction("pick-new-derp"), + ShortHelp: "Switch to some other random DERP home region for a short time", + }, + { + Name: "force-prefer-derp", + ShortUsage: "tailscale debug force-prefer-derp", + Exec: forcePreferDERP, + ShortHelp: "Prefer the given region ID if reachable (until restart, or 0 to clear)", + }, + { + Name: "force-netmap-update", + ShortUsage: "tailscale debug force-netmap-update", + Exec: localAPIAction("force-netmap-update"), + ShortHelp: "Force a full no-op netmap update (for load testing)", + }, + { + // TODO(bradfitz,maisem): eventually promote this out of debug + Name: "reload-config", + ShortUsage: "tailscale debug reload-config", + Exec: reloadConfig, + ShortHelp: "Reload config", + }, + { + Name: "control-knobs", + ShortUsage: "tailscale debug control-knobs", + Exec: debugControlKnobs, + ShortHelp: "See current control knobs", + }, + { + Name: "prefs", + ShortUsage: "tailscale debug prefs", + Exec: runPrefs, + ShortHelp: "Print prefs", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("prefs") + fs.BoolVar(&prefsArgs.pretty, "pretty", false, "If true, pretty-print output") + return fs + })(), + }, + { + Name: "watch-ipn", + ShortUsage: "tailscale debug watch-ipn", + Exec: runWatchIPN, + ShortHelp: "Subscribe to IPN message bus", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("watch-ipn") + fs.BoolVar(&watchIPNArgs.netmap, "netmap", true, "include netmap in messages") + fs.BoolVar(&watchIPNArgs.initial, "initial", false, "include initial status") + fs.BoolVar(&watchIPNArgs.rateLimit, "rate-limit", true, "rate limit messags") + fs.IntVar(&watchIPNArgs.count, "count", 0, "exit after printing this many statuses, or 0 to keep going forever") + return fs + })(), + }, + { + Name: "netmap", + ShortUsage: "tailscale debug netmap", + Exec: runNetmap, + ShortHelp: "Print the current network map", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("netmap") + return fs + })(), + }, + { + Name: "via", + ShortUsage: "tailscale debug via \n" + + "tailscale debug via ", + Exec: runVia, + ShortHelp: "Convert between site-specific IPv4 CIDRs and IPv6 'via' routes", + }, + { + Name: "ts2021", + ShortUsage: "tailscale debug ts2021", + Exec: runTS2021, + ShortHelp: "Debug ts2021 protocol connectivity", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("ts2021") + fs.StringVar(&ts2021Args.host, "host", "controlplane.tailscale.com", "hostname of control plane") + fs.IntVar(&ts2021Args.version, "version", int(tailcfg.CurrentCapabilityVersion), "protocol version") + fs.BoolVar(&ts2021Args.verbose, "verbose", false, "be extra verbose") + fs.StringVar(&ts2021Args.aceHost, "ace", "", "if non-empty, use this ACE server IP/hostname as a candidate path") + fs.StringVar(&ts2021Args.dialPlanJSONFile, "dial-plan", "", "if non-empty, use this JSON file to configure the dial plan") + return fs + })(), + }, + { + Name: "set-expire", + ShortUsage: "tailscale debug set-expire --in=1m", + Exec: runSetExpire, + ShortHelp: "Manipulate node key expiry for testing", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("set-expire") + fs.DurationVar(&setExpireArgs.in, "in", 0, "if non-zero, set node key to expire this duration from now") + return fs + })(), + }, + { + Name: "dev-store-set", + ShortUsage: "tailscale debug dev-store-set", + Exec: runDevStoreSet, + ShortHelp: "Set a key/value pair during development", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("store-set") + fs.BoolVar(&devStoreSetArgs.danger, "danger", false, "accept danger") + return fs + })(), + }, + { + Name: "derp", + ShortUsage: "tailscale debug derp", + Exec: runDebugDERP, + ShortHelp: "Test a DERP configuration", + }, + ccall(debugCaptureCmd), + ccall(debugPortmapCmd), + { + Name: "peer-endpoint-changes", + ShortUsage: "tailscale debug peer-endpoint-changes ", + Exec: runPeerEndpointChanges, + ShortHelp: "Print debug information about a peer's endpoint changes", + }, + { + Name: "dial-types", + ShortUsage: "tailscale debug dial-types ", + Exec: runDebugDialTypes, + ShortHelp: "Print debug information about connecting to a given host or IP", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("dial-types") + fs.StringVar(&debugDialTypesArgs.network, "network", "tcp", `network type to dial ("tcp", "udp", etc.)`) + return fs + })(), + }, + { + Name: "resolve", + ShortUsage: "tailscale debug resolve ", + Exec: runDebugResolve, + ShortHelp: "Does a DNS lookup", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("resolve") + fs.StringVar(&resolveArgs.net, "net", "ip", "network type to resolve (ip, ip4, ip6)") + return fs + })(), + }, + { + Name: "go-buildinfo", + ShortUsage: "tailscale debug go-buildinfo", + ShortHelp: "Print Go's runtime/debug.BuildInfo", + Exec: runGoBuildInfo, + }, + { + Name: "peer-relay-servers", + ShortUsage: "tailscale debug peer-relay-servers", + ShortHelp: "Print the current set of candidate peer relay servers", + Exec: runPeerRelayServers, + }, + { + Name: "test-risk", + ShortUsage: "tailscale debug test-risk", + ShortHelp: "Do a fake risky action", + Exec: runTestRisk, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("test-risk") + fs.StringVar(&testRiskArgs.acceptedRisk, "accept-risk", "", "comma-separated list of accepted risks") + return fs + })(), + }, + ccall(debugPeerRelayCmd), + }...), + } } func runGoBuildInfo(ctx context.Context, args []string) error { @@ -449,6 +495,81 @@ func runLocalCreds(ctx context.Context, args []string) error { return nil } +func looksLikeHTTPMethod(s string) bool { + if len(s) > len("OPTIONS") { + return false + } + for _, r := range s { + if r < 'A' || r > 'Z' { + return false + } + } + return true +} + +var localAPIFlags struct { + verbose bool +} + +func runLocalAPI(ctx context.Context, args []string) error { + if len(args) == 0 { + return errors.New("expected at least one argument") + } + method := "GET" + if looksLikeHTTPMethod(args[0]) { + method = args[0] + args = args[1:] + if len(args) == 0 { + return errors.New("expected at least one argument after method") + } + } + path := args[0] + if !strings.HasPrefix(path, "/localapi/") { + if !strings.Contains(path, "/") { + path = "/localapi/v0/" + path + } else { + path = "/localapi/" + path + } + } + + var body io.Reader + if len(args) > 1 { + if args[1] == "-" { + fmt.Fprintf(Stderr, "# reading request body from stdin...\n") + all, err := io.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("reading Stdin: %q", err) + } + body = bytes.NewReader(all) + } else { + body = strings.NewReader(args[1]) + } + } + req, err := http.NewRequest(method, "http://local-tailscaled.sock"+path, body) + if err != nil { + return err + } + fmt.Fprintf(Stderr, "# doing request %s %s\n", method, path) + + res, err := localClient.DoLocalRequest(req) + if err != nil { + return err + } + is2xx := res.StatusCode >= 200 && res.StatusCode <= 299 + if localAPIFlags.verbose { + res.Write(Stdout) + } else { + if !is2xx { + fmt.Fprintf(Stderr, "# Response status %s\n", res.Status) + } + io.Copy(Stdout, res.Body) + } + if is2xx { + return nil + } + return errors.New(res.Status) +} + type localClientRoundTripper struct{} func (localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -497,10 +618,10 @@ func runPrefs(ctx context.Context, args []string) error { } var watchIPNArgs struct { - netmap bool - initial bool - showPrivateKey bool - count int + netmap bool + initial bool + rateLimit bool + count int } func runWatchIPN(ctx context.Context, args []string) error { @@ -508,8 +629,8 @@ func runWatchIPN(ctx context.Context, args []string) error { if watchIPNArgs.initial { mask = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap } - if !watchIPNArgs.showPrivateKey { - mask |= ipn.NotifyNoPrivateKeys + if watchIPNArgs.rateLimit { + mask |= ipn.NotifyRateLimit } watcher, err := localClient.WatchIPNBus(ctx, mask) if err != nil { @@ -531,18 +652,11 @@ func runWatchIPN(ctx context.Context, args []string) error { return nil } -var netmapArgs struct { - showPrivateKey bool -} - func runNetmap(ctx context.Context, args []string) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap - if !netmapArgs.showPrivateKey { - mask |= ipn.NotifyNoPrivateKeys - } watcher, err := localClient.WatchIPNBus(ctx, mask) if err != nil { return err @@ -571,6 +685,25 @@ func runDERPMap(ctx context.Context, args []string) error { return nil } +func forcePreferDERP(ctx context.Context, args []string) error { + var n int + if len(args) != 1 { + return errors.New("expected exactly one integer argument") + } + n, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("expected exactly one integer argument: %w", err) + } + b, err := json.Marshal(n) + if err != nil { + return fmt.Errorf("failed to marshal DERP region: %w", err) + } + if err := localClient.DebugActionBody(ctx, "force-prefer-derp", bytes.NewReader(b)); err != nil { + return fmt.Errorf("failed to force preferred DERP: %w", err) + } + return nil +} + func localAPIAction(action string) func(context.Context, []string) error { return func(ctx context.Context, args []string) error { if len(args) > 0 { @@ -672,6 +805,61 @@ func runDaemonLogs(ctx context.Context, args []string) error { } } +func runDaemonBusEvents(ctx context.Context, args []string) error { + for line, err := range localClient.StreamBusEvents(ctx) { + if err != nil { + return err + } + fmt.Printf("[%d][%q][from: %q][to: %q] %s\n", line.Count, line.Type, + line.From, line.To, line.Event) + } + return nil +} + +var daemonBusGraphArgs struct { + format string +} + +func runDaemonBusGraph(ctx context.Context, args []string) error { + graph, err := localClient.EventBusGraph(ctx) + if err != nil { + return err + } + if format := daemonBusGraphArgs.format; format != "json" && format != "dot" { + return fmt.Errorf("unrecognized output format %q", format) + } + if daemonBusGraphArgs.format == "dot" { + var topics eventbus.DebugTopics + if err := json.Unmarshal(graph, &topics); err != nil { + return fmt.Errorf("unable to parse json: %w", err) + } + fmt.Print(generateDOTGraph(topics.Topics)) + } else { + fmt.Print(string(graph)) + } + return nil +} + +// generateDOTGraph generates the DOT graph format based on the events +func generateDOTGraph(topics []eventbus.DebugTopic) string { + var sb strings.Builder + sb.WriteString("digraph event_bus {\n") + + for _, topic := range topics { + // If no subscribers, still ensure the topic is drawn + if len(topic.Subscribers) == 0 { + topic.Subscribers = append(topic.Subscribers, "no-subscribers") + } + for _, subscriber := range topic.Subscribers { + fmt.Fprintf(&sb, "\t%q -> %q [label=%q];\n", + topic.Publisher, subscriber, cmp.Or(topic.Name, "???")) + } + } + + sb.WriteString("}\n") + return sb.String() +} + var metricsArgs struct { watch bool } @@ -776,6 +964,9 @@ var ts2021Args struct { host string // "controlplane.tailscale.com" version int // 27 or whatever verbose bool + aceHost string // if non-empty, FQDN of https ACE server to use ("ace.example.com") + + dialPlanJSONFile string // if non-empty, path to JSON file [tailcfg.ControlDialPlan] JSON } func runTS2021(ctx context.Context, args []string) error { @@ -784,19 +975,22 @@ func runTS2021(ctx context.Context, args []string) error { keysURL := "https://" + ts2021Args.host + "/key?v=" + strconv.Itoa(ts2021Args.version) + keyTransport := http.DefaultTransport.(*http.Transport).Clone() + if ts2021Args.aceHost != "" { + log.Printf("using ACE server %q", ts2021Args.aceHost) + keyTransport.Proxy = nil + keyTransport.DialContext = (&ace.Dialer{ACEHost: ts2021Args.aceHost}).Dial + } + if ts2021Args.verbose { u, err := url.Parse(keysURL) if err != nil { return err } - envConf := httpproxy.FromEnvironment() - if *envConf == (httpproxy.Config{}) { - log.Printf("HTTP proxy env: (none)") - } else { - log.Printf("HTTP proxy env: %+v", envConf) + if proxyFromEnv, ok := feature.HookProxyFromEnvironment.GetOk(); ok { + proxy, err := proxyFromEnv(&http.Request{URL: u}) + log.Printf("tshttpproxy.ProxyFromEnvironment = (%v, %v)", proxy, err) } - proxy, err := tshttpproxy.ProxyFromEnvironment(&http.Request{URL: u}) - log.Printf("tshttpproxy.ProxyFromEnvironment = (%v, %v)", proxy, err) } machinePrivate := key.NewMachine() var dialer net.Dialer @@ -809,7 +1003,7 @@ func runTS2021(ctx context.Context, args []string) error { if err != nil { return err } - res, err := http.DefaultClient.Do(req) + res, err := keyTransport.RoundTrip(req) if err != nil { log.Printf("Do: %v", err) return err @@ -845,19 +1039,53 @@ func runTS2021(ctx context.Context, args []string) error { logf = log.Printf } - noiseDialer := &controlhttp.Dialer{ - Hostname: ts2021Args.host, - HTTPPort: "80", - HTTPSPort: "443", - MachineKey: machinePrivate, - ControlKey: keys.PublicKey, - ProtocolVersion: uint16(ts2021Args.version), - Dialer: dialFunc, - Logf: logf, + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logger.WithPrefix(logf, "netmon: ")) + if err != nil { + return fmt.Errorf("creating netmon: %w", err) + } + + var dialPlan *tailcfg.ControlDialPlan + if ts2021Args.dialPlanJSONFile != "" { + b, err := os.ReadFile(ts2021Args.dialPlanJSONFile) + if err != nil { + return fmt.Errorf("reading dial plan JSON file: %w", err) + } + dialPlan = new(tailcfg.ControlDialPlan) + if err := json.Unmarshal(b, dialPlan); err != nil { + return fmt.Errorf("unmarshaling dial plan JSON file: %w", err) + } + } else if ts2021Args.aceHost != "" { + dialPlan = &tailcfg.ControlDialPlan{ + Candidates: []tailcfg.ControlIPCandidate{ + { + ACEHost: ts2021Args.aceHost, + DialTimeoutSec: 10, + }, + }, + } + } + + opts := ts2021.ClientOpts{ + ServerURL: "https://" + ts2021Args.host, + DialPlan: func() *tailcfg.ControlDialPlan { + return dialPlan + }, + Logf: logf, + NetMon: netMon, + PrivKey: machinePrivate, + ServerPubKey: keys.PublicKey, + Dialer: tsdial.NewFromFuncForDebug(logf, dialFunc), + DNSCache: &dnscache.Resolver{}, + HealthTracker: &health.Tracker{}, } + + // TODO: ProtocolVersion: uint16(ts2021Args.version), const tries = 2 for i := range tries { - err := tryConnect(ctx, keys.PublicKey, noiseDialer) + err := tryConnect(ctx, keys.PublicKey, opts) if err != nil { log.Printf("error on attempt %d/%d: %v", i+1, tries, err) continue @@ -867,53 +1095,37 @@ func runTS2021(ctx context.Context, args []string) error { return nil } -func tryConnect(ctx context.Context, controlPublic key.MachinePublic, noiseDialer *controlhttp.Dialer) error { - conn, err := noiseDialer.Dial(ctx) - log.Printf("controlhttp.Dial = %p, %v", conn, err) - if err != nil { - return err - } - log.Printf("did noise handshake") +func tryConnect(ctx context.Context, controlPublic key.MachinePublic, opts ts2021.ClientOpts) error { - gotPeer := conn.Peer() - if gotPeer != controlPublic { - log.Printf("peer = %v, want %v", gotPeer, controlPublic) - return errors.New("key mismatch") - } - - log.Printf("final underlying conn: %v / %v", conn.LocalAddr(), conn.RemoteAddr()) - - h2Transport, err := http2.ConfigureTransports(&http.Transport{ - IdleConnTimeout: time.Second, + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + GotConn: func(ci httptrace.GotConnInfo) { + log.Printf("GotConn: %T", ci.Conn) + ncc, ok := ci.Conn.(*ts2021.Conn) + if !ok { + return + } + log.Printf("did noise handshake") + log.Printf("final underlying conn: %v / %v", ncc.LocalAddr(), ncc.RemoteAddr()) + gotPeer := ncc.Peer() + if gotPeer != controlPublic { + log.Fatalf("peer = %v, want %v", gotPeer, controlPublic) + } + }, }) - if err != nil { - return fmt.Errorf("http2.ConfigureTransports: %w", err) - } - // Now, create a Noise conn over the existing conn. - nc, err := noiseconn.New(conn.Conn, h2Transport, 0, nil) + nc, err := ts2021.NewClient(opts) if err != nil { - return fmt.Errorf("noiseconn.New: %w", err) - } - defer nc.Close() - - // Reserve a RoundTrip for the whoami request. - ok, _, err := nc.ReserveNewRequest(ctx) - if err != nil { - return fmt.Errorf("ReserveNewRequest: %w", err) - } - if !ok { - return errors.New("ReserveNewRequest failed") + return fmt.Errorf("NewNoiseClient: %w", err) } // Make a /whoami request to the server to verify that we can actually // communicate over the newly-established connection. - whoamiURL := "http://" + ts2021Args.host + "/machine/whoami" + whoamiURL := "https://" + ts2021Args.host + "/machine/whoami" req, err := http.NewRequestWithContext(ctx, "GET", whoamiURL, nil) if err != nil { return err } - resp, err := nc.RoundTrip(req) + resp, err := nc.Do(req) if err != nil { return fmt.Errorf("RoundTrip whoami request: %w", err) } @@ -999,88 +1211,6 @@ func runSetExpire(ctx context.Context, args []string) error { return localClient.DebugSetExpireIn(ctx, setExpireArgs.in) } -var captureArgs struct { - outFile string -} - -func runCapture(ctx context.Context, args []string) error { - stream, err := localClient.StreamDebugCapture(ctx) - if err != nil { - return err - } - defer stream.Close() - - switch captureArgs.outFile { - case "-": - fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") - _, err = io.Copy(os.Stdout, stream) - return err - case "": - lua, err := os.CreateTemp("", "ts-dissector") - if err != nil { - return err - } - defer os.Remove(lua.Name()) - lua.Write([]byte(capture.DissectorLua)) - if err := lua.Close(); err != nil { - return err - } - - wireshark := exec.CommandContext(ctx, "wireshark", "-X", "lua_script:"+lua.Name(), "-k", "-i", "-") - wireshark.Stdin = stream - wireshark.Stdout = os.Stdout - wireshark.Stderr = os.Stderr - return wireshark.Run() - } - - f, err := os.OpenFile(captureArgs.outFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err - } - defer f.Close() - fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") - _, err = io.Copy(f, stream) - return err -} - -var debugPortmapArgs struct { - duration time.Duration - gatewayAddr string - selfAddr string - ty string - logHTTP bool -} - -func debugPortmap(ctx context.Context, args []string) error { - opts := &tailscale.DebugPortmapOpts{ - Duration: debugPortmapArgs.duration, - Type: debugPortmapArgs.ty, - LogHTTP: debugPortmapArgs.logHTTP, - } - if (debugPortmapArgs.gatewayAddr != "") != (debugPortmapArgs.selfAddr != "") { - return fmt.Errorf("if one of --gateway-addr and --self-addr is provided, the other must be as well") - } - if debugPortmapArgs.gatewayAddr != "" { - var err error - opts.GatewayAddr, err = netip.ParseAddr(debugPortmapArgs.gatewayAddr) - if err != nil { - return fmt.Errorf("invalid --gateway-addr: %w", err) - } - opts.SelfAddr, err = netip.ParseAddr(debugPortmapArgs.selfAddr) - if err != nil { - return fmt.Errorf("invalid --self-addr: %w", err) - } - } - rc, err := localClient.DebugPortmap(ctx, opts) - if err != nil { - return err - } - defer rc.Close() - - _, err = io.Copy(os.Stdout, rc) - return err -} - func runPeerEndpointChanges(ctx context.Context, args []string) error { st, err := localClient.Status(ctx) if err != nil { @@ -1233,3 +1363,32 @@ func runDebugResolve(ctx context.Context, args []string) error { } return nil } + +func runPeerRelayServers(ctx context.Context, args []string) error { + if len(args) > 0 { + return errors.New("unexpected arguments") + } + v, err := localClient.DebugResultJSON(ctx, "peer-relay-servers") + if err != nil { + return err + } + e := json.NewEncoder(os.Stdout) + e.SetIndent("", " ") + e.Encode(v) + return nil +} + +var testRiskArgs struct { + acceptedRisk string +} + +func runTestRisk(ctx context.Context, args []string) error { + if len(args) > 0 { + return errors.New("unexpected arguments") + } + if err := presentRiskToUser("test-risk", "This is a test risky action.", testRiskArgs.acceptedRisk); err != nil { + return err + } + fmt.Println("did-test-risky-action") + return nil +} diff --git a/cmd/tailscale/cli/diag.go b/cmd/tailscale/cli/diag.go index ebf26985f..3b2aa504b 100644 --- a/cmd/tailscale/cli/diag.go +++ b/cmd/tailscale/cli/diag.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || windows || darwin +//go:build (linux || windows || darwin) && !ts_omit_cliconndiag package cli @@ -16,11 +16,15 @@ import ( "tailscale.com/version/distro" ) -// fixTailscaledConnectError is called when the local tailscaled has +func init() { + hookFixTailscaledConnectError.Set(fixTailscaledConnectErrorImpl) +} + +// fixTailscaledConnectErrorImpl is called when the local tailscaled has // been determined unreachable due to the provided origErr value. It // returns either the same error or a better one to help the user // understand why tailscaled isn't running for their platform. -func fixTailscaledConnectError(origErr error) error { +func fixTailscaledConnectErrorImpl(origErr error) error { procs, err := ps.Processes() if err != nil { return fmt.Errorf("failed to connect to local Tailscaled process and failed to enumerate processes while looking for it") diff --git a/cmd/tailscale/cli/diag_other.go b/cmd/tailscale/cli/diag_other.go deleted file mode 100644 index ece10cc79..000000000 --- a/cmd/tailscale/cli/diag_other.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux && !windows && !darwin - -package cli - -import "fmt" - -// The github.com/mitchellh/go-ps package doesn't work on all platforms, -// so just don't diagnose connect failures. - -func fixTailscaledConnectError(origErr error) error { - return fmt.Errorf("failed to connect to local tailscaled process (is it running?); got: %w", origErr) -} diff --git a/cmd/tailscale/cli/dns-query.go b/cmd/tailscale/cli/dns-query.go index da2d9d2a5..11f644537 100644 --- a/cmd/tailscale/cli/dns-query.go +++ b/cmd/tailscale/cli/dns-query.go @@ -9,12 +9,31 @@ import ( "fmt" "net/netip" "os" + "strings" "text/tabwriter" + "github.com/peterbourgon/ff/v3/ffcli" "golang.org/x/net/dns/dnsmessage" "tailscale.com/types/dnstype" ) +var dnsQueryCmd = &ffcli.Command{ + Name: "query", + ShortUsage: "tailscale dns query [a|aaaa|cname|mx|ns|opt|ptr|srv|txt]", + Exec: runDNSQuery, + ShortHelp: "Perform a DNS query", + LongHelp: strings.TrimSpace(` +The 'tailscale dns query' subcommand performs a DNS query for the specified name +using the internal DNS forwarder (100.100.100.100). + +By default, the DNS query will request an A record. Another DNS record type can +be specified as the second parameter. + +The output also provides information about the resolver(s) used to resolve the +query. +`), +} + func runDNSQuery(ctx context.Context, args []string) error { if len(args) < 1 { return flag.ErrHelp diff --git a/cmd/tailscale/cli/dns-status.go b/cmd/tailscale/cli/dns-status.go index e487c66bc..8c18622ce 100644 --- a/cmd/tailscale/cli/dns-status.go +++ b/cmd/tailscale/cli/dns-status.go @@ -5,15 +5,77 @@ package cli import ( "context" + "flag" "fmt" "maps" "slices" "strings" + "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/ipn" "tailscale.com/types/netmap" ) +var dnsStatusCmd = &ffcli.Command{ + Name: "status", + ShortUsage: "tailscale dns status [--all]", + Exec: runDNSStatus, + ShortHelp: "Print the current DNS status and configuration", + LongHelp: strings.TrimSpace(` +The 'tailscale dns status' subcommand prints the current DNS status and +configuration, including: + +- Whether the built-in DNS forwarder is enabled. + +- The MagicDNS configuration provided by the coordination server. + +- Details on which resolver(s) Tailscale believes the system is using by + default. + +The --all flag can be used to output advanced debugging information, including +fallback resolvers, nameservers, certificate domains, extra records, and the +exit node filtered set. + +=== Contents of the MagicDNS configuration === + +The MagicDNS configuration is provided by the coordination server to the client +and includes the following components: + +- MagicDNS enablement status: Indicates whether MagicDNS is enabled across the + entire tailnet. + +- MagicDNS Suffix: The DNS suffix used for devices within your tailnet. + +- DNS Name: The DNS name that other devices in the tailnet can use to reach this + device. + +- Resolvers: The preferred DNS resolver(s) to be used for resolving queries, in + order of preference. If no resolvers are listed here, the system defaults are + used. + +- Split DNS Routes: Custom DNS resolvers may be used to resolve hostnames in + specific domains, this is also known as a 'Split DNS' configuration. The + mapping of domains to their respective resolvers is provided here. + +- Certificate Domains: The DNS names for which the coordination server will + assist in provisioning TLS certificates. + +- Extra Records: Additional DNS records that the coordination server might + provide to the internal DNS resolver. + +- Exit Node Filtered Set: DNS suffixes that the node, when acting as an exit + node DNS proxy, will not answer. + +For more information about the DNS functionality built into Tailscale, refer to +https://tailscale.com/kb/1054/dns. +`), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("status") + fs.BoolVar(&dnsStatusArgs.all, "all", false, "outputs advanced debugging information") + return fs + })(), +} + // dnsStatusArgs are the arguments for the "dns status" subcommand. var dnsStatusArgs struct { all bool @@ -208,35 +270,3 @@ func fetchNetMap() (netMap *netmap.NetworkMap, err error) { } return notify.NetMap, nil } - -func dnsStatusLongHelp() string { - return `The 'tailscale dns status' subcommand prints the current DNS status and configuration, including: - -- Whether the built-in DNS forwarder is enabled. -- The MagicDNS configuration provided by the coordination server. -- Details on which resolver(s) Tailscale believes the system is using by default. - -The --all flag can be used to output advanced debugging information, including fallback resolvers, nameservers, certificate domains, extra records, and the exit node filtered set. - -=== Contents of the MagicDNS configuration === - -The MagicDNS configuration is provided by the coordination server to the client and includes the following components: - -- MagicDNS enablement status: Indicates whether MagicDNS is enabled across the entire tailnet. - -- MagicDNS Suffix: The DNS suffix used for devices within your tailnet. - -- DNS Name: The DNS name that other devices in the tailnet can use to reach this device. - -- Resolvers: The preferred DNS resolver(s) to be used for resolving queries, in order of preference. If no resolvers are listed here, the system defaults are used. - -- Split DNS Routes: Custom DNS resolvers may be used to resolve hostnames in specific domains, this is also known as a 'Split DNS' configuration. The mapping of domains to their respective resolvers is provided here. - -- Certificate Domains: The DNS names for which the coordination server will assist in provisioning TLS certificates. - -- Extra Records: Additional DNS records that the coordination server might provide to the internal DNS resolver. - -- Exit Node Filtered Set: DNS suffixes that the node, when acting as an exit node DNS proxy, will not answer. - -For more information about the DNS functionality built into Tailscale, refer to https://tailscale.com/kb/1054/dns.` -} diff --git a/cmd/tailscale/cli/dns.go b/cmd/tailscale/cli/dns.go index 042ce1a94..086abefd6 100644 --- a/cmd/tailscale/cli/dns.go +++ b/cmd/tailscale/cli/dns.go @@ -4,46 +4,32 @@ package cli import ( - "flag" + "strings" "github.com/peterbourgon/ff/v3/ffcli" ) var dnsCmd = &ffcli.Command{ - Name: "dns", - ShortHelp: "Diagnose the internal DNS forwarder", - LongHelp: dnsCmdLongHelp(), - ShortUsage: "tailscale dns [flags]", - UsageFunc: usageFuncNoDefaultValues, + Name: "dns", + ShortHelp: "Diagnose the internal DNS forwarder", + LongHelp: strings.TrimSpace(` +The 'tailscale dns' subcommand provides tools for diagnosing the internal DNS +forwarder (100.100.100.100). + +For more information about the DNS functionality built into Tailscale, refer to +https://tailscale.com/kb/1054/dns. +`), + ShortUsage: strings.Join([]string{ + dnsStatusCmd.ShortUsage, + dnsQueryCmd.ShortUsage, + }, "\n"), + UsageFunc: usageFuncNoDefaultValues, Subcommands: []*ffcli.Command{ - { - Name: "status", - ShortUsage: "tailscale dns status [--all]", - Exec: runDNSStatus, - ShortHelp: "Prints the current DNS status and configuration", - LongHelp: dnsStatusLongHelp(), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("status") - fs.BoolVar(&dnsStatusArgs.all, "all", false, "outputs advanced debugging information (fallback resolvers, nameservers, cert domains, extra records, and exit node filtered set)") - return fs - })(), - }, - { - Name: "query", - ShortUsage: "tailscale dns query [a|aaaa|cname|mx|ns|opt|ptr|srv|txt]", - Exec: runDNSQuery, - ShortHelp: "Perform a DNS query", - LongHelp: "The 'tailscale dns query' subcommand performs a DNS query for the specified name using the internal DNS forwarder (100.100.100.100).\n\nIt also provides information about the resolver(s) used to resolve the query.", - }, + dnsStatusCmd, + dnsQueryCmd, // TODO: implement `tailscale log` here // The above work is tracked in https://github.com/tailscale/tailscale/issues/13326 }, } - -func dnsCmdLongHelp() string { - return `The 'tailscale dns' subcommand provides tools for diagnosing the internal DNS forwarder (100.100.100.100). - -For more information about the DNS functionality built into Tailscale, refer to https://tailscale.com/kb/1054/dns.` -} diff --git a/cmd/tailscale/cli/down.go b/cmd/tailscale/cli/down.go index 1eb85a13e..224198a98 100644 --- a/cmd/tailscale/cli/down.go +++ b/cmd/tailscale/cli/down.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" ) @@ -23,10 +24,12 @@ var downCmd = &ffcli.Command{ var downArgs struct { acceptedRisks string + reason string } func newDownFlagSet() *flag.FlagSet { downf := newFlagSet("down") + downf.StringVar(&downArgs.reason, "reason", "", "reason for the disconnect, if required by a policy") registerAcceptRiskFlag(downf, &downArgs.acceptedRisks) return downf } @@ -50,6 +53,7 @@ func runDown(ctx context.Context, args []string) error { fmt.Fprintf(Stderr, "Tailscale was already stopped.\n") return nil } + ctx = apitype.RequestReasonKey.WithValue(ctx, downArgs.reason) _, err = localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ WantRunning: false, diff --git a/cmd/tailscale/cli/drive.go b/cmd/tailscale/cli/drive.go index 929852b4c..131f46847 100644 --- a/cmd/tailscale/cli/drive.go +++ b/cmd/tailscale/cli/drive.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_drive && !ts_mac_gui + package cli import ( @@ -20,43 +22,49 @@ const ( driveListUsage = "tailscale drive list" ) -var driveCmd = &ffcli.Command{ - Name: "drive", - ShortHelp: "Share a directory with your tailnet", - ShortUsage: strings.Join([]string{ - driveShareUsage, - driveRenameUsage, - driveUnshareUsage, - driveListUsage, - }, "\n"), - LongHelp: buildShareLongHelp(), - UsageFunc: usageFuncNoDefaultValues, - Subcommands: []*ffcli.Command{ - { - Name: "share", - ShortUsage: driveShareUsage, - Exec: runDriveShare, - ShortHelp: "[ALPHA] Create or modify a share", - }, - { - Name: "rename", - ShortUsage: driveRenameUsage, - ShortHelp: "[ALPHA] Rename a share", - Exec: runDriveRename, - }, - { - Name: "unshare", - ShortUsage: driveUnshareUsage, - ShortHelp: "[ALPHA] Remove a share", - Exec: runDriveUnshare, - }, - { - Name: "list", - ShortUsage: driveListUsage, - ShortHelp: "[ALPHA] List current shares", - Exec: runDriveList, +func init() { + maybeDriveCmd = driveCmd +} + +func driveCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "drive", + ShortHelp: "Share a directory with your tailnet", + ShortUsage: strings.Join([]string{ + driveShareUsage, + driveRenameUsage, + driveUnshareUsage, + driveListUsage, + }, "\n"), + LongHelp: buildShareLongHelp(), + UsageFunc: usageFuncNoDefaultValues, + Subcommands: []*ffcli.Command{ + { + Name: "share", + ShortUsage: driveShareUsage, + Exec: runDriveShare, + ShortHelp: "[ALPHA] Create or modify a share", + }, + { + Name: "rename", + ShortUsage: driveRenameUsage, + ShortHelp: "[ALPHA] Rename a share", + Exec: runDriveRename, + }, + { + Name: "unshare", + ShortUsage: driveUnshareUsage, + ShortHelp: "[ALPHA] Remove a share", + Exec: runDriveUnshare, + }, + { + Name: "list", + ShortUsage: driveListUsage, + ShortHelp: "[ALPHA] List current shares", + Exec: runDriveList, + }, }, - }, + } } // runDriveShare is the entry point for the "tailscale drive share" command. diff --git a/cmd/tailscale/cli/exitnode.go b/cmd/tailscale/cli/exitnode.go index 6b9247a7b..b47b9f0bd 100644 --- a/cmd/tailscale/cli/exitnode.go +++ b/cmd/tailscale/cli/exitnode.go @@ -15,10 +15,10 @@ import ( "github.com/kballard/go-shellquote" "github.com/peterbourgon/ff/v3/ffcli" - xmaps "golang.org/x/exp/maps" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/util/slicesx" ) func exitNodeCmd() *ffcli.Command { @@ -41,7 +41,7 @@ func exitNodeCmd() *ffcli.Command { { Name: "suggest", ShortUsage: "tailscale exit-node suggest", - ShortHelp: "Suggests the best available exit node", + ShortHelp: "Suggest the best available exit node", Exec: runExitNodeSuggest, }}, (func() []*ffcli.Command { @@ -131,7 +131,7 @@ func runExitNodeList(ctx context.Context, args []string) error { for _, country := range filteredPeers.Countries { for _, city := range country.Cities { for _, peer := range city.Peers { - fmt.Fprintf(w, "\n %s\t%s\t%s\t%s\t%s\t", peer.TailscaleIPs[0], strings.Trim(peer.DNSName, "."), country.Name, city.Name, peerStatus(peer)) + fmt.Fprintf(w, "\n %s\t%s\t%s\t%s\t%s\t", peer.TailscaleIPs[0], strings.Trim(peer.DNSName, "."), cmp.Or(country.Name, "-"), cmp.Or(city.Name, "-"), peerStatus(peer)) } } } @@ -173,11 +173,13 @@ func hasAnyExitNodeSuggestions(peers []*ipnstate.PeerStatus) bool { // a peer. If there is no notable state, a - is returned. func peerStatus(peer *ipnstate.PeerStatus) string { if !peer.Active { + lastseen := lastSeenFmt(peer.LastSeen) + if peer.ExitNode { - return "selected but offline" + return "selected but offline" + lastseen } if !peer.Online { - return "offline" + return "offline" + lastseen } } @@ -202,23 +204,16 @@ type filteredCity struct { Peers []*ipnstate.PeerStatus } -const noLocationData = "-" - -var noLocation = &tailcfg.Location{ - Country: noLocationData, - CountryCode: noLocationData, - City: noLocationData, - CityCode: noLocationData, -} - // filterFormatAndSortExitNodes filters and sorts exit nodes into // alphabetical order, by country, city and then by priority if // present. +// // If an exit node has location data, and the country has more than // one city, an `Any` city is added to the country that contains the // highest priority exit node within that country. +// // For exit nodes without location data, their country fields are -// defined as '-' to indicate that the data is not available. +// defined as the empty string to indicate that the data is not available. func filterFormatAndSortExitNodes(peers []*ipnstate.PeerStatus, filterBy string) filteredExitNodes { // first get peers into some fixed order, as code below doesn't break ties // and our input comes from a random range-over-map. @@ -229,7 +224,10 @@ func filterFormatAndSortExitNodes(peers []*ipnstate.PeerStatus, filterBy string) countries := make(map[string]*filteredCountry) cities := make(map[string]*filteredCity) for _, ps := range peers { - loc := cmp.Or(ps.Location, noLocation) + loc := ps.Location + if loc == nil { + loc = &tailcfg.Location{} + } if filterBy != "" && !strings.EqualFold(loc.Country, filterBy) { continue @@ -255,11 +253,11 @@ func filterFormatAndSortExitNodes(peers []*ipnstate.PeerStatus, filterBy string) } filteredExitNodes := filteredExitNodes{ - Countries: xmaps.Values(countries), + Countries: slicesx.MapValues(countries), } for _, country := range filteredExitNodes.Countries { - if country.Name == noLocationData { + if country.Name == "" { // Countries without location data should not // be filtered further. continue diff --git a/cmd/tailscale/cli/exitnode_test.go b/cmd/tailscale/cli/exitnode_test.go index 9d569a45a..cc38fd3a4 100644 --- a/cmd/tailscale/cli/exitnode_test.go +++ b/cmd/tailscale/cli/exitnode_test.go @@ -74,10 +74,10 @@ func TestFilterFormatAndSortExitNodes(t *testing.T) { want := filteredExitNodes{ Countries: []*filteredCountry{ { - Name: noLocationData, + Name: "", Cities: []*filteredCity{ { - Name: noLocationData, + Name: "", Peers: []*ipnstate.PeerStatus{ ps[5], }, @@ -273,14 +273,20 @@ func TestSortByCountryName(t *testing.T) { Name: "Zimbabwe", }, { - Name: noLocationData, + Name: "", }, } sortByCountryName(fc) - if fc[0].Name != noLocationData { - t.Fatalf("sortByCountryName did not order countries by alphabetical order, got %v, want %v", fc[0].Name, noLocationData) + want := []string{"", "Albania", "Sweden", "Zimbabwe"} + var got []string + for _, c := range fc { + got = append(got, c.Name) + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("sortByCountryName did not order countries by alphabetical order (-want +got):\n%s", diff) } } @@ -296,13 +302,19 @@ func TestSortByCityName(t *testing.T) { Name: "Squamish", }, { - Name: noLocationData, + Name: "", }, } sortByCityName(fc) - if fc[0].Name != noLocationData { - t.Fatalf("sortByCityName did not order cities by alphabetical order, got %v, want %v", fc[0].Name, noLocationData) + want := []string{"", "Goteborg", "Kingston", "Squamish"} + var got []string + for _, c := range fc { + got = append(got, c.Name) + } + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("sortByCityName did not order countries by alphabetical order (-want +got):\n%s", diff) } } diff --git a/cmd/tailscale/cli/ffcomplete/internal/complete_test.go b/cmd/tailscale/cli/ffcomplete/internal/complete_test.go index 7e36b1bcd..c216bdeec 100644 --- a/cmd/tailscale/cli/ffcomplete/internal/complete_test.go +++ b/cmd/tailscale/cli/ffcomplete/internal/complete_test.go @@ -196,7 +196,6 @@ func TestComplete(t *testing.T) { // Run the tests. for _, test := range tests { - test := test name := strings.Join(test.args, "âŖ") if test.showFlags { name += "+flags" diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index cd7762446..e0879197e 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_taildrop + package cli import ( @@ -18,6 +20,7 @@ import ( "path" "path/filepath" "strings" + "sync" "sync/atomic" "time" "unicode/utf8" @@ -28,8 +31,8 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/envknob" + "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" - "tailscale.com/syncs" "tailscale.com/tailcfg" tsrate "tailscale.com/tstime/rate" "tailscale.com/util/quarantine" @@ -37,14 +40,20 @@ import ( "tailscale.com/version" ) -var fileCmd = &ffcli.Command{ - Name: "file", - ShortUsage: "tailscale file ...", - ShortHelp: "Send or receive files", - Subcommands: []*ffcli.Command{ - fileCpCmd, - fileGetCmd, - }, +func init() { + fileCmd = getFileCmd +} + +func getFileCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "file", + ShortUsage: "tailscale file ...", + ShortHelp: "Send or receive files", + Subcommands: []*ffcli.Command{ + fileCpCmd, + fileGetCmd, + }, + } } type countingReader struct { @@ -167,7 +176,7 @@ func runCp(ctx context.Context, args []string) error { log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID) } - var group syncs.WaitGroup + var group sync.WaitGroup ctxProgress, cancelProgress := context.WithCancel(ctx) defer cancelProgress() if isatty.IsTerminal(os.Stderr.Fd()) { @@ -268,46 +277,77 @@ func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNode if err != nil { return "", false, err } - fts, err := localClient.FileTargets(ctx) + + st, err := localClient.Status(ctx) if err != nil { - return "", false, err - } - for _, ft := range fts { - n := ft.Node - for _, a := range n.Addresses { - if a.Addr() != ip { - continue + // This likely means tailscaled is unreachable or returned an error on /localapi/v0/status. + return "", false, fmt.Errorf("failed to get local status: %w", err) + } + if st == nil { + // Handle the case if the daemon returns nil with no error. + return "", false, errors.New("no status available") + } + if st.Self == nil { + // We have a status structure, but it doesn’t include Self info. Probably not connected. + return "", false, errors.New("local node is not configured or missing Self information") + } + + // Find the PeerStatus that corresponds to ip. + var foundPeer *ipnstate.PeerStatus +peerLoop: + for _, ps := range st.Peer { + for _, pip := range ps.TailscaleIPs { + if pip == ip { + foundPeer = ps + break peerLoop } - isOffline = n.Online != nil && !*n.Online - return n.StableID, isOffline, nil } } - return "", false, fileTargetErrorDetail(ctx, ip) -} -// fileTargetErrorDetail returns a non-nil error saying why ip is an -// invalid file sharing target. -func fileTargetErrorDetail(ctx context.Context, ip netip.Addr) error { - found := false - if st, err := localClient.Status(ctx); err == nil && st.Self != nil { - for _, peer := range st.Peer { - for _, pip := range peer.TailscaleIPs { - if pip == ip { - found = true - if peer.UserID != st.Self.UserID { - return errors.New("owned by different user; can only send files to your own devices") - } - } - } + // If we didn’t find a matching peer at all: + if foundPeer == nil { + if !tsaddr.IsTailscaleIP(ip) { + return "", false, fmt.Errorf("unknown target; %v is not a Tailscale IP address", ip) } + return "", false, errors.New("unknown target; not in your Tailnet") } - if found { - return errors.New("target seems to be running an old Tailscale version") - } - if !tsaddr.IsTailscaleIP(ip) { - return fmt.Errorf("unknown target; %v is not a Tailscale IP address", ip) + + // We found a peer. Decide whether we can send files to it: + isOffline = !foundPeer.Online + + switch foundPeer.TaildropTarget { + case ipnstate.TaildropTargetAvailable: + return foundPeer.ID, isOffline, nil + + case ipnstate.TaildropTargetNoNetmapAvailable: + return "", isOffline, errors.New("cannot send files: no netmap available on this node") + + case ipnstate.TaildropTargetIpnStateNotRunning: + return "", isOffline, errors.New("cannot send files: local Tailscale is not connected to the tailnet") + + case ipnstate.TaildropTargetMissingCap: + return "", isOffline, errors.New("cannot send files: missing required Taildrop capability") + + case ipnstate.TaildropTargetOffline: + return "", isOffline, errors.New("cannot send files: peer is offline") + + case ipnstate.TaildropTargetNoPeerInfo: + return "", isOffline, errors.New("cannot send files: invalid or unrecognized peer") + + case ipnstate.TaildropTargetUnsupportedOS: + return "", isOffline, errors.New("cannot send files: target's OS does not support Taildrop") + + case ipnstate.TaildropTargetNoPeerAPI: + return "", isOffline, errors.New("cannot send files: target is not advertising a file sharing API") + + case ipnstate.TaildropTargetOwnedByOtherUser: + return "", isOffline, errors.New("cannot send files: peer is owned by a different user") + + case ipnstate.TaildropTargetUnknown: + fallthrough + default: + return "", isOffline, fmt.Errorf("cannot send files: unknown or indeterminate reason") } - return errors.New("unknown target; not in your Tailnet") } const maxSniff = 4 << 20 diff --git a/cmd/tailscale/cli/funnel.go b/cmd/tailscale/cli/funnel.go index a95f9e270..34b0c74c2 100644 --- a/cmd/tailscale/cli/funnel.go +++ b/cmd/tailscale/cli/funnel.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_serve + package cli import ( @@ -16,10 +18,14 @@ import ( "tailscale.com/tailcfg" ) +func init() { + maybeFunnelCmd = funnelCmd +} + var funnelCmd = func() *ffcli.Command { se := &serveEnv{lc: &localClient} // previously used to serve legacy newFunnelCommand unless useWIPCode is true - // change is limited to make a revert easier and full cleanup to come after the relase. + // change is limited to make a revert easier and full cleanup to come after the release. // TODO(tylersmalley): cleanup and removal of newFunnelCommand as of 2023-10-16 return newServeV2Command(se, funnel) } @@ -174,3 +180,42 @@ func printFunnelWarning(sc *ipn.ServeConfig) { fmt.Fprintf(Stderr, " run: `tailscale serve --help` to see how to configure handlers\n") } } + +func init() { + hookPrintFunnelStatus.Set(printFunnelStatus) +} + +// printFunnelStatus prints the status of the funnel, if it's running. +// It prints nothing if the funnel is not running. +func printFunnelStatus(ctx context.Context) { + sc, err := localClient.GetServeConfig(ctx) + if err != nil { + outln() + printf("# Funnel:\n") + printf("# - Unable to get Funnel status: %v\n", err) + return + } + if !sc.IsFunnelOn() { + return + } + outln() + printf("# Funnel on:\n") + for hp, on := range sc.AllowFunnel { + if !on { // if present, should be on + continue + } + sni, portStr, _ := net.SplitHostPort(string(hp)) + p, _ := strconv.ParseUint(portStr, 10, 16) + isTCP := sc.IsTCPForwardingOnPort(uint16(p), noService) + url := "https://" + if isTCP { + url = "tcp://" + } + url += sni + if isTCP || p != 443 { + url += ":" + portStr + } + printf("# - %s\n", url) + } + outln() +} diff --git a/cmd/tailscale/cli/jsonoutput/jsonoutput.go b/cmd/tailscale/cli/jsonoutput/jsonoutput.go new file mode 100644 index 000000000..aa49acc28 --- /dev/null +++ b/cmd/tailscale/cli/jsonoutput/jsonoutput.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonoutput provides stable and versioned JSON serialisation for CLI output. +// This allows us to provide stable output to scripts/clients, but also make +// breaking changes to the output when it's useful. +// +// Historically we only used `--json` as a boolean flag, so changing the output +// could break scripts that rely on the existing format. +// +// This package allows callers to pass a version number to `--json` and get +// a consistent output. We'll bump the version when we make a breaking change +// that's likely to break scripts that rely on the existing output, e.g. if +// we remove a field or change the type/format. +// +// Passing just the boolean flag `--json` will always return v1, to preserve +// compatibility with scripts written before we versioned our output. +package jsonoutput + +import ( + "errors" + "fmt" + "strconv" +) + +// JSONSchemaVersion implements flag.Value, and tracks whether the CLI has +// been called with `--json`, and if so, with what value. +type JSONSchemaVersion struct { + // IsSet tracks if the flag was provided at all. + IsSet bool + + // Value tracks the desired schema version, which defaults to 1 if + // the user passes `--json` without an argument. + Value int +} + +// String returns the default value which is printed in the CLI help text. +func (v *JSONSchemaVersion) String() string { + if v.IsSet { + return strconv.Itoa(v.Value) + } else { + return "(not set)" + } +} + +// Set is called when the user passes the flag as a command-line argument. +func (v *JSONSchemaVersion) Set(s string) error { + if v.IsSet { + return errors.New("received multiple instances of --json; only pass it once") + } + + v.IsSet = true + + // If the user doesn't supply a schema version, default to 1. + // This ensures that any existing scripts will continue to get their + // current output. + if s == "true" { + v.Value = 1 + return nil + } + + version, err := strconv.Atoi(s) + if err != nil { + return fmt.Errorf("invalid integer value passed to --json: %q", s) + } + v.Value = version + return nil +} + +// IsBoolFlag tells the flag package that JSONSchemaVersion can be set +// without an argument. +func (v *JSONSchemaVersion) IsBoolFlag() bool { + return true +} + +// ResponseEnvelope is a set of fields common to all versioned JSON output. +type ResponseEnvelope struct { + // SchemaVersion is the version of the JSON output, e.g. "1", "2", "3" + SchemaVersion string + + // ResponseWarning tells a user if a newer version of the JSON output + // is available. + ResponseWarning string `json:"_WARNING,omitzero"` +} diff --git a/cmd/tailscale/cli/jsonoutput/network-lock-v1.go b/cmd/tailscale/cli/jsonoutput/network-lock-v1.go new file mode 100644 index 000000000..8a2d2de33 --- /dev/null +++ b/cmd/tailscale/cli/jsonoutput/network-lock-v1.go @@ -0,0 +1,203 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonoutput + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tka" +) + +// PrintNetworkLockJSONV1 prints the stored TKA state as a JSON object to the CLI, +// in a stable "v1" format. +// +// This format includes: +// +// - the AUM hash as a base32-encoded string +// - the raw AUM as base64-encoded bytes +// - the expanded AUM, which prints named fields for consumption by other tools +func PrintNetworkLockJSONV1(out io.Writer, updates []ipnstate.NetworkLockUpdate) error { + messages := make([]logMessageV1, len(updates)) + + for i, update := range updates { + var aum tka.AUM + if err := aum.Unserialize(update.Raw); err != nil { + return fmt.Errorf("decoding: %w", err) + } + + h := aum.Hash() + + if !bytes.Equal(h[:], update.Hash[:]) { + return fmt.Errorf("incorrect AUM hash: got %v, want %v", h, update) + } + + messages[i] = toLogMessageV1(aum, update) + } + + result := struct { + ResponseEnvelope + Messages []logMessageV1 + }{ + ResponseEnvelope: ResponseEnvelope{ + SchemaVersion: "1", + }, + Messages: messages, + } + + enc := json.NewEncoder(out) + enc.SetIndent("", " ") + return enc.Encode(result) +} + +// toLogMessageV1 converts a [tka.AUM] and [ipnstate.NetworkLockUpdate] to the +// JSON output returned by the CLI. +func toLogMessageV1(aum tka.AUM, update ipnstate.NetworkLockUpdate) logMessageV1 { + expandedAUM := expandedAUMV1{} + expandedAUM.MessageKind = aum.MessageKind.String() + if len(aum.PrevAUMHash) > 0 { + expandedAUM.PrevAUMHash = aum.PrevAUMHash.String() + } + if key := aum.Key; key != nil { + expandedAUM.Key = toExpandedKeyV1(key) + } + if keyID := aum.KeyID; keyID != nil { + expandedAUM.KeyID = fmt.Sprintf("tlpub:%x", keyID) + } + if state := aum.State; state != nil { + expandedState := expandedStateV1{} + if h := state.LastAUMHash; h != nil { + expandedState.LastAUMHash = h.String() + } + for _, secret := range state.DisablementSecrets { + expandedState.DisablementSecrets = append(expandedState.DisablementSecrets, fmt.Sprintf("%x", secret)) + } + for _, key := range state.Keys { + expandedState.Keys = append(expandedState.Keys, toExpandedKeyV1(&key)) + } + expandedState.StateID1 = state.StateID1 + expandedState.StateID2 = state.StateID2 + expandedAUM.State = expandedState + } + if votes := aum.Votes; votes != nil { + expandedAUM.Votes = *votes + } + expandedAUM.Meta = aum.Meta + for _, signature := range aum.Signatures { + expandedAUM.Signatures = append(expandedAUM.Signatures, expandedSignatureV1{ + KeyID: fmt.Sprintf("tlpub:%x", signature.KeyID), + Signature: base64.URLEncoding.EncodeToString(signature.Signature), + }) + } + + return logMessageV1{ + Hash: aum.Hash().String(), + AUM: expandedAUM, + Raw: base64.URLEncoding.EncodeToString(update.Raw), + } +} + +// toExpandedKeyV1 converts a [tka.Key] to the JSON output returned +// by the CLI. +func toExpandedKeyV1(key *tka.Key) expandedKeyV1 { + return expandedKeyV1{ + Kind: key.Kind.String(), + Votes: key.Votes, + Public: fmt.Sprintf("tlpub:%x", key.Public), + Meta: key.Meta, + } +} + +// logMessageV1 is the JSON representation of an AUM as both raw bytes and +// in its expanded form, and the CLI output is a list of these entries. +type logMessageV1 struct { + // The BLAKE2s digest of the CBOR-encoded AUM. This is printed as a + // base32-encoded string, e.g. KCEâ€ĻXZQ + Hash string + + // The expanded form of the AUM, which presents the fields in a more + // accessible format than doing a CBOR decoding. + AUM expandedAUMV1 + + // The raw bytes of the CBOR-encoded AUM, encoded as base64. + // This is useful for verifying the AUM hash. + Raw string +} + +// expandedAUMV1 is the expanded version of a [tka.AUM], designed so external tools +// can read the AUM without knowing our CBOR definitions. +type expandedAUMV1 struct { + MessageKind string + PrevAUMHash string `json:"PrevAUMHash,omitzero"` + + // Key encodes a public key to be added to the key authority. + // This field is used for AddKey AUMs. + Key expandedKeyV1 `json:"Key,omitzero"` + + // KeyID references a public key which is part of the key authority. + // This field is used for RemoveKey and UpdateKey AUMs. + KeyID string `json:"KeyID,omitzero"` + + // State describes the full state of the key authority. + // This field is used for Checkpoint AUMs. + State expandedStateV1 `json:"State,omitzero"` + + // Votes and Meta describe properties of a key in the key authority. + // These fields are used for UpdateKey AUMs. + Votes uint `json:"Votes,omitzero"` + Meta map[string]string `json:"Meta,omitzero"` + + // Signatures lists the signatures over this AUM. + Signatures []expandedSignatureV1 `json:"Signatures,omitzero"` +} + +// expandedAUMV1 is the expanded version of a [tka.Key], which describes +// the public components of a key known to network-lock. +type expandedKeyV1 struct { + Kind string + + // Votes describes the weight applied to signatures using this key. + Votes uint + + // Public encodes the public key of the key as a hex string. + Public string + + // Meta describes arbitrary metadata about the key. This could be + // used to store the name of the key, for instance. + Meta map[string]string `json:"Meta,omitzero"` +} + +// expandedStateV1 is the expanded version of a [tka.State], which describes +// Tailnet Key Authority state at an instant in time. +type expandedStateV1 struct { + // LastAUMHash is the blake2s digest of the last-applied AUM. + LastAUMHash string `json:"LastAUMHash,omitzero"` + + // DisablementSecrets are KDF-derived values which can be used + // to turn off the TKA in the event of a consensus-breaking bug. + DisablementSecrets []string + + // Keys are the public keys of either: + // + // 1. The signing nodes currently trusted by the TKA. + // 2. Ephemeral keys that were used to generate pre-signed auth keys. + Keys []expandedKeyV1 + + // StateID's are nonce's, generated on enablement and fixed for + // the lifetime of the Tailnet Key Authority. + StateID1 uint64 + StateID2 uint64 +} + +// expandedSignatureV1 is the expanded form of a [tka.Signature], which +// describes a signature over an AUM. This signature can be verified +// using the key referenced by KeyID. +type expandedSignatureV1 struct { + KeyID string + Signature string +} diff --git a/cmd/tailscale/cli/logout.go b/cmd/tailscale/cli/logout.go index 0c2007a66..fbc394730 100644 --- a/cmd/tailscale/cli/logout.go +++ b/cmd/tailscale/cli/logout.go @@ -5,12 +5,18 @@ package cli import ( "context" + "flag" "fmt" "strings" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/tailscale/apitype" ) +var logoutArgs struct { + reason string +} + var logoutCmd = &ffcli.Command{ Name: "logout", ShortUsage: "tailscale logout", @@ -22,11 +28,17 @@ the current node key, forcing a future use of it to cause a reauthentication. `), Exec: runLogout, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("logout") + fs.StringVar(&logoutArgs.reason, "reason", "", "reason for the logout, if required by a policy") + return fs + })(), } func runLogout(ctx context.Context, args []string) error { if len(args) > 0 { return fmt.Errorf("too many non-flag arguments: %q", args) } + ctx = apitype.RequestReasonKey.WithValue(ctx, logoutArgs.reason) return localClient.Logout(ctx) } diff --git a/cmd/tailscale/cli/maybe_syspolicy.go b/cmd/tailscale/cli/maybe_syspolicy.go new file mode 100644 index 000000000..937a27833 --- /dev/null +++ b/cmd/tailscale/cli/maybe_syspolicy.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_syspolicy + +package cli + +import _ "tailscale.com/feature/syspolicy" diff --git a/cmd/tailscale/cli/metrics.go b/cmd/tailscale/cli/metrics.go new file mode 100644 index 000000000..dbdedd5a6 --- /dev/null +++ b/cmd/tailscale/cli/metrics.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/atomicfile" +) + +var metricsCmd = &ffcli.Command{ + Name: "metrics", + ShortHelp: "Show Tailscale metrics", + LongHelp: strings.TrimSpace(` + +The 'tailscale metrics' command shows Tailscale user-facing metrics (as opposed +to internal metrics printed by 'tailscale debug metrics'). + +For more information about Tailscale metrics, refer to +https://tailscale.com/s/client-metrics + +`), + ShortUsage: "tailscale metrics [flags]", + UsageFunc: usageFuncNoDefaultValues, + Exec: runMetricsNoSubcommand, + Subcommands: []*ffcli.Command{ + { + Name: "print", + ShortUsage: "tailscale metrics print", + Exec: runMetricsPrint, + ShortHelp: "Print current metric values in Prometheus text format", + }, + { + Name: "write", + ShortUsage: "tailscale metrics write ", + Exec: runMetricsWrite, + ShortHelp: "Write metric values to a file", + LongHelp: strings.TrimSpace(` + +The 'tailscale metrics write' command writes metric values to a text file provided as its +only argument. It's meant to be used alongside Prometheus node exporter, allowing Tailscale +metrics to be consumed and exported by the textfile collector. + +As an example, to export Tailscale metrics on an Ubuntu system running node exporter, you +can regularly run 'tailscale metrics write /var/lib/prometheus/node-exporter/tailscaled.prom' +using cron or a systemd timer. + + `), + }, + }, +} + +// runMetricsNoSubcommand prints metric values if no subcommand is specified. +func runMetricsNoSubcommand(ctx context.Context, args []string) error { + if len(args) > 0 { + return fmt.Errorf("tailscale metrics: unknown subcommand: %s", args[0]) + } + + return runMetricsPrint(ctx, args) +} + +// runMetricsPrint prints metric values to stdout. +func runMetricsPrint(ctx context.Context, args []string) error { + out, err := localClient.UserMetrics(ctx) + if err != nil { + return err + } + Stdout.Write(out) + return nil +} + +// runMetricsWrite writes metric values to a file. +func runMetricsWrite(ctx context.Context, args []string) error { + if len(args) != 1 { + return errors.New("usage: tailscale metrics write ") + } + path := args[0] + out, err := localClient.UserMetrics(ctx) + if err != nil { + return err + } + return atomicfile.WriteFile(path, out, 0644) +} diff --git a/cmd/tailscale/cli/netcheck.go b/cmd/tailscale/cli/netcheck.go index 682cd99a3..a8a8992f5 100644 --- a/cmd/tailscale/cli/netcheck.go +++ b/cmd/tailscale/cli/netcheck.go @@ -17,13 +17,23 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn" "tailscale.com/net/netcheck" "tailscale.com/net/netmon" - "tailscale.com/net/portmapper" + "tailscale.com/net/portmapper/portmappertype" "tailscale.com/net/tlsdial" "tailscale.com/tailcfg" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + + // The "netcheck" command also wants the portmapper linked. + // + // TODO: make that subcommand either hit LocalAPI for that info, or use a + // tailscaled subcommand, to avoid making the CLI also link in the portmapper. + // For now (2025-09-15), keep doing what we've done for the past five years and + // keep linking it here. + _ "tailscale.com/feature/condregister/portmapper" ) var netcheckCmd = &ffcli.Command{ @@ -48,15 +58,20 @@ var netcheckArgs struct { func runNetcheck(ctx context.Context, args []string) error { logf := logger.WithPrefix(log.Printf, "portmap: ") - netMon, err := netmon.New(logf) + bus := eventbus.New() + defer bus.Close() + netMon, err := netmon.New(bus, logf) if err != nil { return err } - // Ensure that we close the portmapper after running a netcheck; this - // will release any port mappings created. - pm := portmapper.NewClient(logf, netMon, nil, nil, nil) - defer pm.Close() + var pm portmappertype.Client + if buildfeatures.HasPortMapper { + // Ensure that we close the portmapper after running a netcheck; this + // will release any port mappings created. + pm = portmappertype.HookNewPortMapper.Get()(logf, bus, netMon, nil, nil) + defer pm.Close() + } c := &netcheck.Client{ NetMon: netMon, @@ -136,6 +151,7 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error { } printf("\nReport:\n") + printf("\t* Time: %v\n", report.Now.Format(time.RFC3339Nano)) printf("\t* UDP: %v\n", report.UDP) if report.GlobalV4.IsValid() { printf("\t* IPv4: yes, %s\n", report.GlobalV4) @@ -164,7 +180,11 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error { printf("\t* Nearest DERP: unknown (no response to latency probes)\n") } else { if report.PreferredDERP != 0 { - printf("\t* Nearest DERP: %v\n", dm.Regions[report.PreferredDERP].RegionName) + if region, ok := dm.Regions[report.PreferredDERP]; ok { + printf("\t* Nearest DERP: %v\n", region.RegionName) + } else { + printf("\t* Nearest DERP: %v (region not found in map)\n", report.PreferredDERP) + } } else { printf("\t* Nearest DERP: [none]\n") } @@ -202,6 +222,9 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error { } func portMapping(r *netcheck.Report) string { + if !buildfeatures.HasPortMapper { + return "binary built without portmapper support" + } if !r.AnyPortMappingChecked() { return "not checked" } diff --git a/cmd/tailscale/cli/network-lock.go b/cmd/tailscale/cli/network-lock.go index 45f989f10..73b1d6201 100644 --- a/cmd/tailscale/cli/network-lock.go +++ b/cmd/tailscale/cli/network-lock.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package cli import ( @@ -8,23 +10,31 @@ import ( "context" "crypto/rand" "encoding/hex" - "encoding/json" + jsonv1 "encoding/json" "errors" "flag" "fmt" + "io" "os" "strconv" "strings" "time" + "github.com/mattn/go-isatty" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/cmd/tailscale/cli/jsonoutput" "tailscale.com/ipn/ipnstate" "tailscale.com/tka" "tailscale.com/tsconst" "tailscale.com/types/key" "tailscale.com/types/tkatype" + "tailscale.com/util/prompt" ) +func init() { + maybeNetlockCmd = func() *ffcli.Command { return netlockCmd } +} + var netlockCmd = &ffcli.Command{ Name: "lock", ShortUsage: "tailscale lock [arguments...]", @@ -191,8 +201,7 @@ var nlStatusArgs struct { var nlStatusCmd = &ffcli.Command{ Name: "status", ShortUsage: "tailscale lock status", - ShortHelp: "Outputs the state of tailnet lock", - LongHelp: "Outputs the state of tailnet lock", + ShortHelp: "Output the state of tailnet lock", Exec: runNetworkLockStatus, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("lock status") @@ -212,24 +221,24 @@ func runNetworkLockStatus(ctx context.Context, args []string) error { } if nlStatusArgs.json { - enc := json.NewEncoder(os.Stdout) + enc := jsonv1.NewEncoder(os.Stdout) enc.SetIndent("", " ") return enc.Encode(st) } if st.Enabled { - fmt.Println("Tailnet lock is ENABLED.") + fmt.Println("Tailnet Lock is ENABLED.") } else { - fmt.Println("Tailnet lock is NOT enabled.") + fmt.Println("Tailnet Lock is NOT enabled.") } fmt.Println() if st.Enabled && st.NodeKey != nil && !st.PublicKey.IsZero() { if st.NodeKeySigned { - fmt.Println("This node is accessible under tailnet lock. Node signature:") + fmt.Println("This node is accessible under Tailnet Lock. Node signature:") fmt.Println(st.NodeKeySignature.String()) } else { - fmt.Println("This node is LOCKED OUT by tailnet-lock, and action is required to establish connectivity.") + fmt.Println("This node is LOCKED OUT by Tailnet Lock, and action is required to establish connectivity.") fmt.Printf("Run the following command on a node with a trusted key:\n\ttailscale lock sign %v %s\n", st.NodeKey, st.PublicKey.CLIString()) } fmt.Println() @@ -293,8 +302,7 @@ func runNetworkLockStatus(ctx context.Context, args []string) error { var nlAddCmd = &ffcli.Command{ Name: "add", ShortUsage: "tailscale lock add ...", - ShortHelp: "Adds one or more trusted signing keys to tailnet lock", - LongHelp: "Adds one or more trusted signing keys to tailnet lock", + ShortHelp: "Add one or more trusted signing keys to tailnet lock", Exec: func(ctx context.Context, args []string) error { return runNetworkLockModify(ctx, args, nil) }, @@ -307,8 +315,7 @@ var nlRemoveArgs struct { var nlRemoveCmd = &ffcli.Command{ Name: "remove", ShortUsage: "tailscale lock remove [--re-sign=false] ...", - ShortHelp: "Removes one or more trusted signing keys from tailnet lock", - LongHelp: "Removes one or more trusted signing keys from tailnet lock", + ShortHelp: "Remove one or more trusted signing keys from tailnet lock", Exec: runNetworkLockRemove, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("lock remove") @@ -329,6 +336,9 @@ func runNetworkLockRemove(ctx context.Context, args []string) error { if !st.Enabled { return errors.New("tailnet lock is not enabled") } + if len(st.TrustedKeys) == 1 { + return errors.New("cannot remove the last trusted signing key; use 'tailscale lock disable' to disable tailnet lock instead, or add another signing key before removing one") + } if nlRemoveArgs.resign { // Validate we are not removing trust in ourselves while resigning. This is because @@ -369,6 +379,18 @@ func runNetworkLockRemove(ctx context.Context, args []string) error { } } } + } else { + if isatty.IsTerminal(os.Stdout.Fd()) { + fmt.Printf(`Warning +Removal of a signing key(s) without resigning nodes (--re-sign=false) +will cause any nodes signed by the the given key(s) to be locked out +of the Tailscale network. Proceed with caution. +`) + if !prompt.YesNo("Are you sure you want to remove the signing key(s)?", true) { + fmt.Printf("aborting removal of signing key(s)\n") + os.Exit(0) + } + } } return localClient.NetworkLockModify(ctx, nil, removeKeys) @@ -448,7 +470,7 @@ func runNetworkLockModify(ctx context.Context, addArgs, removeArgs []string) err var nlSignCmd = &ffcli.Command{ Name: "sign", ShortUsage: "tailscale lock sign []\ntailscale lock sign ", - ShortHelp: "Signs a node or pre-approved auth key", + ShortHelp: "Sign a node or pre-approved auth key", LongHelp: `Either: - signs a node key and transmits the signature to the coordination server, or @@ -510,7 +532,7 @@ func runNetworkLockSign(ctx context.Context, args []string) error { var nlDisableCmd = &ffcli.Command{ Name: "disable", ShortUsage: "tailscale lock disable ", - ShortHelp: "Consumes a disablement secret to shut down tailnet lock for the tailnet", + ShortHelp: "Consume a disablement secret to shut down tailnet lock for the tailnet", LongHelp: strings.TrimSpace(` The 'tailscale lock disable' command uses the specified disablement @@ -539,7 +561,7 @@ func runNetworkLockDisable(ctx context.Context, args []string) error { var nlLocalDisableCmd = &ffcli.Command{ Name: "local-disable", ShortUsage: "tailscale lock local-disable", - ShortHelp: "Disables tailnet lock for this node only", + ShortHelp: "Disable tailnet lock for this node only", LongHelp: strings.TrimSpace(` The 'tailscale lock local-disable' command disables tailnet lock for only @@ -561,8 +583,8 @@ func runNetworkLockLocalDisable(ctx context.Context, args []string) error { var nlDisablementKDFCmd = &ffcli.Command{ Name: "disablement-kdf", ShortUsage: "tailscale lock disablement-kdf ", - ShortHelp: "Computes a disablement value from a disablement secret (advanced users only)", - LongHelp: "Computes a disablement value from a disablement secret (advanced users only)", + ShortHelp: "Compute a disablement value from a disablement secret (advanced users only)", + LongHelp: "Compute a disablement value from a disablement secret (advanced users only)", Exec: runNetworkLockDisablementKDF, } @@ -580,7 +602,7 @@ func runNetworkLockDisablementKDF(ctx context.Context, args []string) error { var nlLogArgs struct { limit int - json bool + json jsonoutput.JSONSchemaVersion } var nlLogCmd = &ffcli.Command{ @@ -592,7 +614,7 @@ var nlLogCmd = &ffcli.Command{ FlagSet: (func() *flag.FlagSet { fs := newFlagSet("lock log") fs.IntVar(&nlLogArgs.limit, "limit", 50, "max number of updates to list") - fs.BoolVar(&nlLogArgs.json, "json", false, "output in JSON format (WARNING: format subject to change)") + fs.Var(&nlLogArgs.json, "json", "output in JSON format") return fs })(), } @@ -609,7 +631,7 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er printKey := func(key *tka.Key, prefix string) { fmt.Fprintf(&stanza, "%sType: %s\n", prefix, key.Kind.String()) if keyID, err := key.ID(); err == nil { - fmt.Fprintf(&stanza, "%sKeyID: %x\n", prefix, keyID) + fmt.Fprintf(&stanza, "%sKeyID: tlpub:%x\n", prefix, keyID) } else { // Older versions of the client shouldn't explode when they encounter an // unknown key type. @@ -625,16 +647,20 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er return "", fmt.Errorf("decoding: %w", err) } - fmt.Fprintf(&stanza, "%supdate %x (%s)%s\n", terminalYellow, update.Hash, update.Change, terminalClear) + tkaHead, err := aum.Hash().MarshalText() + if err != nil { + return "", fmt.Errorf("decoding AUM hash: %w", err) + } + fmt.Fprintf(&stanza, "%supdate %s (%s)%s\n", terminalYellow, string(tkaHead), update.Change, terminalClear) switch update.Change { case tka.AUMAddKey.String(): printKey(aum.Key, "") case tka.AUMRemoveKey.String(): - fmt.Fprintf(&stanza, "KeyID: %x\n", aum.KeyID) + fmt.Fprintf(&stanza, "KeyID: tlpub:%x\n", aum.KeyID) case tka.AUMUpdateKey.String(): - fmt.Fprintf(&stanza, "KeyID: %x\n", aum.KeyID) + fmt.Fprintf(&stanza, "KeyID: tlpub:%x\n", aum.KeyID) if aum.Votes != nil { fmt.Fprintf(&stanza, "Votes: %d\n", aum.Votes) } @@ -654,7 +680,7 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er default: // Print a JSON encoding of the AUM as a fallback. - e := json.NewEncoder(&stanza) + e := jsonv1.NewEncoder(&stanza) e.SetIndent("", "\t") if err := e.Encode(aum); err != nil { return "", err @@ -666,18 +692,33 @@ func nlDescribeUpdate(update ipnstate.NetworkLockUpdate, color bool) (string, er } func runNetworkLockLog(ctx context.Context, args []string) error { - updates, err := localClient.NetworkLockLog(ctx, nlLogArgs.limit) + st, err := localClient.NetworkLockStatus(ctx) if err != nil { return fixTailscaledConnectError(err) } - if nlLogArgs.json { - enc := json.NewEncoder(Stdout) - enc.SetIndent("", " ") - return enc.Encode(updates) + if !st.Enabled { + return errors.New("Tailnet Lock is not enabled") + } + + updates, err := localClient.NetworkLockLog(ctx, nlLogArgs.limit) + if err != nil { + return fixTailscaledConnectError(err) } out, useColor := colorableOutput() + return printNetworkLockLog(updates, out, nlLogArgs.json, useColor) +} + +func printNetworkLockLog(updates []ipnstate.NetworkLockUpdate, out io.Writer, jsonSchema jsonoutput.JSONSchemaVersion, useColor bool) error { + if jsonSchema.IsSet { + if jsonSchema.Value == 1 { + return jsonoutput.PrintNetworkLockJSONV1(out, updates) + } else { + return fmt.Errorf("unrecognised version: %q", jsonSchema.Value) + } + } + for _, update := range updates { stanza, err := nlDescribeUpdate(update, useColor) if err != nil { diff --git a/cmd/tailscale/cli/network-lock_test.go b/cmd/tailscale/cli/network-lock_test.go new file mode 100644 index 000000000..ccd2957ab --- /dev/null +++ b/cmd/tailscale/cli/network-lock_test.go @@ -0,0 +1,204 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "bytes" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/cmd/tailscale/cli/jsonoutput" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tka" + "tailscale.com/types/tkatype" +) + +func TestNetworkLockLogOutput(t *testing.T) { + votes := uint(1) + aum1 := tka.AUM{ + MessageKind: tka.AUMAddKey, + Key: &tka.Key{ + Kind: tka.Key25519, + Votes: 1, + Public: []byte{2, 2}, + }, + } + h1 := aum1.Hash() + aum2 := tka.AUM{ + MessageKind: tka.AUMRemoveKey, + KeyID: []byte{3, 3}, + PrevAUMHash: h1[:], + Signatures: []tkatype.Signature{ + { + KeyID: []byte{3, 4}, + Signature: []byte{4, 5}, + }, + }, + Meta: map[string]string{"en": "three", "de": "drei", "es": "tres"}, + } + h2 := aum2.Hash() + aum3 := tka.AUM{ + MessageKind: tka.AUMCheckpoint, + PrevAUMHash: h2[:], + State: &tka.State{ + Keys: []tka.Key{ + { + Kind: tka.Key25519, + Votes: 1, + Public: []byte{1, 1}, + Meta: map[string]string{"en": "one", "de": "eins", "es": "uno"}, + }, + }, + DisablementSecrets: [][]byte{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }, + }, + Votes: &votes, + } + + updates := []ipnstate.NetworkLockUpdate{ + { + Hash: aum3.Hash(), + Change: aum3.MessageKind.String(), + Raw: aum3.Serialize(), + }, + { + Hash: aum2.Hash(), + Change: aum2.MessageKind.String(), + Raw: aum2.Serialize(), + }, + { + Hash: aum1.Hash(), + Change: aum1.MessageKind.String(), + Raw: aum1.Serialize(), + }, + } + + t.Run("human-readable", func(t *testing.T) { + t.Parallel() + + var outBuf bytes.Buffer + json := jsonoutput.JSONSchemaVersion{} + useColor := false + + printNetworkLockLog(updates, &outBuf, json, useColor) + + t.Logf("%s", outBuf.String()) + + want := `update 4M4Q3IXBARPQMFVXHJBDCYQMWU5H5FBKD7MFF75HE4O5JMIWR2UA (checkpoint) +Disablement values: + - 010203 + - 040506 + - 070809 +Keys: + Type: 25519 + KeyID: tlpub:0101 + Metadata: map[de:eins en:one es:uno] + +update BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ (remove-key) +KeyID: tlpub:0303 + +update UKJIKFHILQ62AEN7MQIFHXJ6SFVDGQCQA3OHVI3LWVPM736EMSAA (add-key) +Type: 25519 +KeyID: tlpub:0202 + +` + + if diff := cmp.Diff(outBuf.String(), want); diff != "" { + t.Fatalf("wrong output (-got, +want):\n%s", diff) + } + }) + + jsonV1 := `{ + "SchemaVersion": "1", + "Messages": [ + { + "Hash": "4M4Q3IXBARPQMFVXHJBDCYQMWU5H5FBKD7MFF75HE4O5JMIWR2UA", + "AUM": { + "MessageKind": "checkpoint", + "PrevAUMHash": "BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ", + "State": { + "DisablementSecrets": [ + "010203", + "040506", + "070809" + ], + "Keys": [ + { + "Kind": "25519", + "Votes": 1, + "Public": "tlpub:0101", + "Meta": { + "de": "eins", + "en": "one", + "es": "uno" + } + } + ], + "StateID1": 0, + "StateID2": 0 + }, + "Votes": 1 + }, + "Raw": "pAEFAlggCqtbndUNv4_i-JrrVbGywbw5dNWNZYysEm02CCgf3q8FowH2AoNDAQIDQwQFBkMHCAkDgaQBAQIBA0IBAQyjYmRlZGVpbnNiZW5jb25lYmVzY3VubwYB" + }, + { + "Hash": "BKVVXHOVBW7Y7YXYTLVVLMNSYG6DS5GVRVSYZLASNU3AQKA732XQ", + "AUM": { + "MessageKind": "remove-key", + "PrevAUMHash": "UKJIKFHILQ62AEN7MQIFHXJ6SFVDGQCQA3OHVI3LWVPM736EMSAA", + "KeyID": "tlpub:0303", + "Meta": { + "de": "drei", + "en": "three", + "es": "tres" + }, + "Signatures": [ + { + "KeyID": "tlpub:0304", + "Signature": "BAU=" + } + ] + }, + "Raw": "pQECAlggopKFFOhcPaARv2QQU90-kWozQFAG3Hqja7Vez-_EZIAEQgMDB6NiZGVkZHJlaWJlbmV0aHJlZWJlc2R0cmVzF4GiAUIDBAJCBAU=" + }, + { + "Hash": "UKJIKFHILQ62AEN7MQIFHXJ6SFVDGQCQA3OHVI3LWVPM736EMSAA", + "AUM": { + "MessageKind": "add-key", + "Key": { + "Kind": "25519", + "Votes": 1, + "Public": "tlpub:0202" + } + }, + "Raw": "owEBAvYDowEBAgEDQgIC" + } + ] +} +` + + t.Run("json-1", func(t *testing.T) { + t.Parallel() + t.Logf("BOOM") + + var outBuf bytes.Buffer + json := jsonoutput.JSONSchemaVersion{ + IsSet: true, + Value: 1, + } + useColor := false + + printNetworkLockLog(updates, &outBuf, json, useColor) + + want := jsonV1 + t.Logf("%s", outBuf.String()) + + if diff := cmp.Diff(outBuf.String(), want); diff != "" { + t.Fatalf("wrong output (-got, +want):\n%s", diff) + } + }) +} diff --git a/cmd/tailscale/cli/ping.go b/cmd/tailscale/cli/ping.go index 3a909f30d..8ece7c93d 100644 --- a/cmd/tailscale/cli/ping.go +++ b/cmd/tailscale/cli/ping.go @@ -16,7 +16,7 @@ import ( "time" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" @@ -128,7 +128,7 @@ func runPing(ctx context.Context, args []string) error { for { n++ ctx, cancel := context.WithTimeout(ctx, pingArgs.timeout) - pr, err := localClient.PingWithOpts(ctx, netip.MustParseAddr(ip), pingType(), tailscale.PingOpts{Size: pingArgs.size}) + pr, err := localClient.PingWithOpts(ctx, netip.MustParseAddr(ip), pingType(), local.PingOpts{Size: pingArgs.size}) cancel() if err != nil { if errors.Is(err, context.DeadlineExceeded) { @@ -152,7 +152,9 @@ func runPing(ctx context.Context, args []string) error { } latency := time.Duration(pr.LatencySeconds * float64(time.Second)).Round(time.Millisecond) via := pr.Endpoint - if pr.DERPRegionID != 0 { + if pr.PeerRelay != "" { + via = fmt.Sprintf("peer-relay(%s)", pr.PeerRelay) + } else if pr.DERPRegionID != 0 { via = fmt.Sprintf("DERP(%s)", pr.DERPRegionCode) } if via == "" { diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go index 4cfa50d58..d4572842b 100644 --- a/cmd/tailscale/cli/risks.go +++ b/cmd/tailscale/cli/risks.go @@ -4,24 +4,31 @@ package cli import ( + "context" "errors" "flag" - "fmt" - "os" - "os/signal" + "runtime" "strings" - "syscall" - "time" + "tailscale.com/ipn" + "tailscale.com/util/prompt" "tailscale.com/util/testenv" ) var ( - riskTypes []string - riskLoseSSH = registerRiskType("lose-ssh") - riskAll = registerRiskType("all") + riskTypes []string + riskLoseSSH = registerRiskType("lose-ssh") + riskMacAppConnector = registerRiskType("mac-app-connector") + riskStrictRPFilter = registerRiskType("linux-strict-rp-filter") + riskAll = registerRiskType("all") ) +const riskMacAppConnectorMessage = ` +You are trying to configure an app connector on macOS, which is not officially supported due to system limitations. This may result in performance and reliability issues. + +Do not use a macOS app connector for any mission-critical purposes. For the best experience, Linux is the only recommended platform for app connectors. +` + func registerRiskType(riskType string) string { riskTypes = append(riskTypes, riskType) return riskType @@ -46,11 +53,6 @@ func isRiskAccepted(riskType, acceptedRisks string) bool { var errAborted = errors.New("aborted, no changes made") -// riskAbortTimeSeconds is the number of seconds to wait after displaying the -// risk message before continuing with the operation. -// It is used by the presentRiskToUser function below. -const riskAbortTimeSeconds = 5 - // presentRiskToUser displays the risk message and waits for the user to cancel. // It returns errorAborted if the user aborts. In tests it returns errAborted // immediately unless the risk has been explicitly accepted. @@ -64,21 +66,24 @@ func presentRiskToUser(riskType, riskMessage, acceptedRisks string) error { outln(riskMessage) printf("To skip this warning, use --accept-risk=%s\n", riskType) - interrupt := make(chan os.Signal, 1) - signal.Notify(interrupt, syscall.SIGINT) - var msgLen int - for left := riskAbortTimeSeconds; left > 0; left-- { - msg := fmt.Sprintf("\rContinuing in %d seconds...", left) - msgLen = len(msg) - printf(msg) - select { - case <-interrupt: - printf("\r%s\r", strings.Repeat("x", msgLen+1)) - return errAborted - case <-time.After(time.Second): - continue - } + if prompt.YesNo("Continue?", false) { + return nil } - printf("\r%s\r", strings.Repeat(" ", msgLen)) + return errAborted } + +// checkExitNodeRisk checks if the user is using an exit node on Linux and +// whether reverse path filtering is enabled. If so, it presents a risk message. +func checkExitNodeRisk(ctx context.Context, prefs *ipn.Prefs, acceptedRisks string) error { + if runtime.GOOS != "linux" { + return nil + } + if !prefs.ExitNodeIP.IsValid() && prefs.ExitNodeID == "" { + return nil + } + if err := localClient.CheckReversePathFiltering(ctx); err != nil { + return presentRiskToUser(riskStrictRPFilter, err.Error(), acceptedRisks) + } + return nil +} diff --git a/cmd/tailscale/cli/serve_legacy.go b/cmd/tailscale/cli/serve_legacy.go index 443a404ab..580393ce4 100644 --- a/cmd/tailscale/cli/serve_legacy.go +++ b/cmd/tailscale/cli/serve_legacy.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_serve + package cli import ( @@ -23,13 +25,18 @@ import ( "strings" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/util/slicesx" "tailscale.com/version" ) +func init() { + maybeServeCmd = serveCmd +} + var serveCmd = func() *ffcli.Command { se := &serveEnv{lc: &localClient} // previously used to serve legacy newFunnelCommand unless useWIPCode is true @@ -129,7 +136,7 @@ func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.Fla } // localServeClient is an interface conforming to the subset of -// tailscale.LocalClient. It includes only the methods used by the +// local.Client. It includes only the methods used by the // serve command. // // The purpose of this interface is to allow tests to provide a mock. @@ -138,8 +145,11 @@ type localServeClient interface { GetServeConfig(context.Context) (*ipn.ServeConfig, error) SetServeConfig(context.Context, *ipn.ServeConfig) error QueryFeature(ctx context.Context, feature string) (*tailcfg.QueryFeatureResponse, error) - WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, error) + WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*local.IPNBusWatcher, error) IncrementCounter(ctx context.Context, name string, delta int) error + GetPrefs(ctx context.Context) (*ipn.Prefs, error) + EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) + CheckSOMarkInUse(ctx context.Context) (bool, error) } // serveEnv is the environment the serve command runs within. All I/O should be @@ -153,17 +163,21 @@ type serveEnv struct { json bool // output JSON (status only for now) // v2 specific flags - bg bool // background mode - setPath string // serve path - https uint // HTTP port - http uint // HTTP port - tcp uint // TCP port - tlsTerminatedTCP uint // a TLS terminated TCP port - subcmd serveMode // subcommand - yes bool // update without prompt + bg bgBoolFlag // background mode + setPath string // serve path + https uint // HTTP port + http uint // HTTP port + tcp uint // TCP port + tlsTerminatedTCP uint // a TLS terminated TCP port + proxyProtocol uint // PROXY protocol version (1 or 2) + subcmd serveMode // subcommand + yes bool // update without prompt + service tailcfg.ServiceName // service name + tun bool // redirect traffic to OS for service + allServices bool // apply config file to all services + acceptAppCaps []tailcfg.PeerCapability // app capabilities to forward lc localServeClient // localClient interface, specific to serve - // optional stuff for tests: testFlagOut io.Writer testStdout io.Writer @@ -353,12 +367,12 @@ func (e *serveEnv) handleWebServe(ctx context.Context, srvPort uint16, useTLS bo if err != nil { return err } - if sc.IsTCPForwardingOnPort(srvPort) { + if sc.IsTCPForwardingOnPort(srvPort, noService) { fmt.Fprintf(Stderr, "error: cannot serve web; already serving TCP\n") return errHelp } - sc.SetWebHandler(h, dnsName, srvPort, mount, useTLS) + sc.SetWebHandler(h, dnsName, srvPort, mount, useTLS, noService.String()) if !reflect.DeepEqual(cursc, sc) { if err := e.lc.SetServeConfig(ctx, sc); err != nil { @@ -410,11 +424,11 @@ func (e *serveEnv) handleWebServeRemove(ctx context.Context, srvPort uint16, mou if err != nil { return err } - if sc.IsTCPForwardingOnPort(srvPort) { + if sc.IsTCPForwardingOnPort(srvPort, noService) { return errors.New("cannot remove web handler; currently serving TCP") } hp := ipn.HostPort(net.JoinHostPort(dnsName, strconv.Itoa(int(srvPort)))) - if !sc.WebHandlerExists(hp, mount) { + if !sc.WebHandlerExists(noService, hp, mount) { return errors.New("error: handler does not exist") } sc.RemoveWebHandler(dnsName, srvPort, []string{mount}, false) @@ -549,16 +563,16 @@ func (e *serveEnv) handleTCPServe(ctx context.Context, srcType string, srcPort u fwdAddr := "127.0.0.1:" + dstPortStr - if sc.IsServingWeb(srcPort) { - return fmt.Errorf("cannot serve TCP; already serving web on %d", srcPort) - } - dnsName, err := e.getSelfDNSName(ctx) if err != nil { return err } - sc.SetTCPForwarding(srcPort, fwdAddr, terminateTLS, dnsName) + if sc.IsServingWeb(srcPort, noService) { + return fmt.Errorf("cannot serve TCP; already serving web on %d", srcPort) + } + + sc.SetTCPForwarding(srcPort, fwdAddr, terminateTLS, 0 /* proxy proto */, dnsName) if !reflect.DeepEqual(cursc, sc) { if err := e.lc.SetServeConfig(ctx, sc); err != nil { @@ -580,11 +594,11 @@ func (e *serveEnv) handleTCPServeRemove(ctx context.Context, src uint16) error { if sc == nil { sc = new(ipn.ServeConfig) } - if sc.IsServingWeb(src) { + if sc.IsServingWeb(src, noService) { return fmt.Errorf("unable to remove; serving web, not TCP forwarding on serve port %d", src) } - if ph := sc.GetTCPPortHandler(src); ph != nil { - sc.RemoveTCPForwarding(src) + if ph := sc.GetTCPPortHandler(src, noService); ph != nil { + sc.RemoveTCPForwarding(noService, src) return e.lc.SetServeConfig(ctx, sc) } return errors.New("error: serve config does not exist") @@ -681,7 +695,7 @@ func (e *serveEnv) printWebStatusTree(sc *ipn.ServeConfig, hp ipn.HostPort) erro } scheme := "https" - if sc.IsServingHTTP(port) { + if sc.IsServingHTTP(port, noService) { scheme = "http" } @@ -707,10 +721,7 @@ func (e *serveEnv) printWebStatusTree(sc *ipn.ServeConfig, hp ipn.HostPort) erro return "", "" } - var mounts []string - for k := range sc.Web[hp].Handlers { - mounts = append(mounts, k) - } + mounts := slicesx.MapKeys(sc.Web[hp].Handlers) sort.Slice(mounts, func(i, j int) bool { return len(mounts[i]) < len(mounts[j]) }) diff --git a/cmd/tailscale/cli/serve_legacy_test.go b/cmd/tailscale/cli/serve_legacy_test.go index 2eb982ca0..819017ad8 100644 --- a/cmd/tailscale/cli/serve_legacy_test.go +++ b/cmd/tailscale/cli/serve_legacy_test.go @@ -18,7 +18,7 @@ import ( "testing" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" @@ -850,7 +850,7 @@ func TestVerifyFunnelEnabled(t *testing.T) { } } -// fakeLocalServeClient is a fake tailscale.LocalClient for tests. +// fakeLocalServeClient is a fake local.Client for tests. // It's not a full implementation, just enough to test the serve command. // // The fake client is stateful, and is used to test manipulating @@ -859,6 +859,9 @@ type fakeLocalServeClient struct { config *ipn.ServeConfig setCount int // counts calls to SetServeConfig queryFeatureResponse *mockQueryFeatureResponse // mock response to QueryFeature calls + prefs *ipn.Prefs // fake preferences, used to test GetPrefs and SetPrefs + SOMarkInUse bool // fake SO mark in use status + statusWithoutPeers *ipnstate.Status // nil for fakeStatus } // fakeStatus is a fake ipnstate.Status value for tests. @@ -875,10 +878,14 @@ var fakeStatus = &ipnstate.Status{ tailcfg.CapabilityFunnelPorts + "?ports=443,8443": nil, }, }, + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, } func (lc *fakeLocalServeClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { - return fakeStatus, nil + if lc.statusWithoutPeers == nil { + return fakeStatus, nil + } + return lc.statusWithoutPeers, nil } func (lc *fakeLocalServeClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { @@ -891,6 +898,21 @@ func (lc *fakeLocalServeClient) SetServeConfig(ctx context.Context, config *ipn. return nil } +func (lc *fakeLocalServeClient) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { + if lc.prefs == nil { + lc.prefs = ipn.NewPrefs() + } + return lc.prefs, nil +} + +func (lc *fakeLocalServeClient) EditPrefs(ctx context.Context, prefs *ipn.MaskedPrefs) (*ipn.Prefs, error) { + if lc.prefs == nil { + lc.prefs = ipn.NewPrefs() + } + lc.prefs.ApplyEdits(prefs) + return lc.prefs, nil +} + type mockQueryFeatureResponse struct { resp *tailcfg.QueryFeatureResponse err error @@ -908,7 +930,7 @@ func (lc *fakeLocalServeClient) QueryFeature(ctx context.Context, feature string return &tailcfg.QueryFeatureResponse{Complete: true}, nil // fallback to already enabled } -func (lc *fakeLocalServeClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*tailscale.IPNBusWatcher, error) { +func (lc *fakeLocalServeClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*local.IPNBusWatcher, error) { return nil, nil // unused in tests } @@ -916,6 +938,10 @@ func (lc *fakeLocalServeClient) IncrementCounter(ctx context.Context, name strin return nil // unused in tests } +func (lc *fakeLocalServeClient) CheckSOMarkInUse(ctx context.Context) (bool, error) { + return lc.SOMarkInUse, nil +} + // exactError returns an error checker that wants exactly the provided want error. // If optName is non-empty, it's used in the error message. func exactErr(want error, optName ...string) func(error) string { diff --git a/cmd/tailscale/cli/serve_v2.go b/cmd/tailscale/cli/serve_v2.go index 009a61198..b60e645f3 100644 --- a/cmd/tailscale/cli/serve_v2.go +++ b/cmd/tailscale/cli/serve_v2.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_serve + package cli import ( @@ -18,16 +20,25 @@ import ( "os/signal" "path" "path/filepath" + "regexp" + "runtime" + "slices" "sort" "strconv" "strings" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn" + "tailscale.com/ipn/conffile" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/util/dnsname" "tailscale.com/util/mak" + "tailscale.com/util/prompt" + "tailscale.com/util/set" + "tailscale.com/util/slicesx" "tailscale.com/version" ) @@ -39,6 +50,90 @@ type commandInfo struct { LongHelp string } +type serviceNameFlag struct { + Value *tailcfg.ServiceName +} + +func (s *serviceNameFlag) Set(sv string) error { + if sv == "" { + s.Value = new(tailcfg.ServiceName) + return nil + } + v := tailcfg.ServiceName(sv) + if err := v.Validate(); err != nil { + return fmt.Errorf("invalid service name: %q", sv) + } + *s.Value = v + return nil +} + +// String returns the string representation of service name. +func (s *serviceNameFlag) String() string { + return s.Value.String() +} + +type bgBoolFlag struct { + Value bool + IsSet bool // tracks if the flag was set by the user +} + +// Set sets the boolean flag and whether it's explicitly set by user based on the string value. +func (b *bgBoolFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + b.Value = v + b.IsSet = true + return nil +} + +// This is a hack to make the flag package recognize that this is a boolean flag. +func (b *bgBoolFlag) IsBoolFlag() bool { return true } + +// String returns the string representation of the boolean flag. +func (b *bgBoolFlag) String() string { + if !b.IsSet { + return "default" + } + return strconv.FormatBool(b.Value) +} + +type acceptAppCapsFlag struct { + Value *[]tailcfg.PeerCapability +} + +// An application capability name has the form {domain}/{name}. +// Both parts must use the (simplified) FQDN label character set. +// The "name" can contain forward slashes. +// \pL = Unicode Letter, \pN = Unicode Number, - = Hyphen +var validAppCap = regexp.MustCompile(`^([\pL\pN-]+\.)+[\pL\pN-]+\/[\pL\pN-/]+$`) + +// Set appends s to the list of appCaps to accept. +func (u *acceptAppCapsFlag) Set(s string) error { + if s == "" { + return nil + } + appCaps := strings.Split(s, ",") + for _, appCap := range appCaps { + appCap = strings.TrimSpace(appCap) + if !validAppCap.MatchString(appCap) { + return fmt.Errorf("%q does not match the form {domain}/{name}, where domain must be a fully qualified domain name", appCap) + } + *u.Value = append(*u.Value, tailcfg.PeerCapability(appCap)) + } + return nil +} + +// String returns the string representation of the slice of appCaps to accept. +func (u *acceptAppCapsFlag) String() string { + s := make([]string, len(*u.Value)) + for i, v := range *u.Value { + s[i] = string(v) + } + return strings.Join(s, ",") +} + var serveHelpCommon = strings.TrimSpace(` can be a file, directory, text, or most commonly the location to a service running on the local machine. The location to the location service can be expressed as a port number (e.g., 3000), @@ -71,8 +166,27 @@ const ( serveTypeHTTP serveTypeTCP serveTypeTLSTerminatedTCP + serveTypeTUN ) +func serveTypeFromConfString(sp conffile.ServiceProtocol) (st serveType, ok bool) { + switch sp { + case conffile.ProtoHTTP: + return serveTypeHTTP, true + case conffile.ProtoHTTPS, conffile.ProtoHTTPSInsecure, conffile.ProtoFile: + return serveTypeHTTPS, true + case conffile.ProtoTCP: + return serveTypeTCP, true + case conffile.ProtoTLSTerminatedTCP: + return serveTypeTLSTerminatedTCP, true + case conffile.ProtoTUN: + return serveTypeTUN, true + } + return -1, false +} + +const noService tailcfg.ServiceName = "" + var infoMap = map[serveMode]commandInfo{ serve: { Name: "serve", @@ -118,15 +232,19 @@ func newServeV2Command(e *serveEnv, subcmd serveMode) *ffcli.Command { Exec: e.runServeCombined(subcmd), FlagSet: e.newFlags("serve-set", func(fs *flag.FlagSet) { - fs.BoolVar(&e.bg, "bg", false, "Run the command as a background process (default false)") + fs.Var(&e.bg, "bg", "Run the command as a background process (default false, when --service is set defaults to true).") fs.StringVar(&e.setPath, "set-path", "", "Appends the specified path to the base URL for accessing the underlying service") fs.UintVar(&e.https, "https", 0, "Expose an HTTPS server at the specified port (default mode)") if subcmd == serve { fs.UintVar(&e.http, "http", 0, "Expose an HTTP server at the specified port") + fs.Var(&acceptAppCapsFlag{Value: &e.acceptAppCaps}, "accept-app-caps", "App capabilities to forward to the server (specify multiple capabilities with a comma-separated list)") + fs.Var(&serviceNameFlag{Value: &e.service}, "service", "Serve for a service with distinct virtual IP instead on node itself.") } fs.UintVar(&e.tcp, "tcp", 0, "Expose a TCP forwarder to forward raw TCP packets at the specified port") fs.UintVar(&e.tlsTerminatedTCP, "tls-terminated-tcp", 0, "Expose a TCP forwarder to forward TLS-terminated TCP packets at the specified port") + fs.UintVar(&e.proxyProtocol, "proxy-protocol", 0, "PROXY protocol version (1 or 2) for TCP forwarding") fs.BoolVar(&e.yes, "yes", false, "Update without interactive prompts (default false)") + fs.BoolVar(&e.tun, "tun", false, "Forward all traffic to the local machine (default false), only supported for services. Refer to docs for more information.") }), UsageFunc: usageFuncNoDefaultValues, Subcommands: []*ffcli.Command{ @@ -146,6 +264,61 @@ func newServeV2Command(e *serveEnv, subcmd serveMode) *ffcli.Command { Exec: e.runServeReset, FlagSet: e.newFlags("serve-reset", nil), }, + { + Name: "drain", + ShortUsage: fmt.Sprintf("tailscale %s drain ", info.Name), + ShortHelp: "Drain a service from the current node", + LongHelp: "Make the current node no longer accept new connections for the specified service.\n" + + "Existing connections will continue to work until they are closed, but no new connections will be accepted.\n" + + "Use this command to gracefully remove a service from the current node without disrupting existing connections.\n" + + " should be a service name (e.g., svc:my-service).", + Exec: e.runServeDrain, + }, + { + Name: "clear", + ShortUsage: fmt.Sprintf("tailscale %s clear ", info.Name), + ShortHelp: "Remove all config for a service", + LongHelp: "Remove all handlers configured for the specified service.", + Exec: e.runServeClear, + }, + { + Name: "advertise", + ShortUsage: fmt.Sprintf("tailscale %s advertise ", info.Name), + ShortHelp: "Advertise this node as a service proxy to the tailnet", + LongHelp: "Advertise this node as a service proxy to the tailnet. This command is used\n" + + "to make the current node be considered as a service host for a service. This is\n" + + "useful to bring a service back after it has been drained. (i.e. after running \n" + + "`tailscale serve drain `). This is not needed if you are using `tailscale serve` to initialize a service.", + Exec: e.runServeAdvertise, + }, + { + Name: "get-config", + ShortUsage: fmt.Sprintf("tailscale %s get-config [--service=] [--all]", info.Name), + ShortHelp: "Get service configuration to save to a file", + LongHelp: "Get the configuration for services that this node is currently hosting in a\n" + + "format that can later be provided to set-config. This can be used to declaratively set\n" + + "configuration for a service host.", + Exec: e.runServeGetConfig, + FlagSet: e.newFlags("serve-get-config", func(fs *flag.FlagSet) { + fs.BoolVar(&e.allServices, "all", false, "read config from all services") + fs.Var(&serviceNameFlag{Value: &e.service}, "service", "read config from a particular service") + }), + }, + { + Name: "set-config", + ShortUsage: fmt.Sprintf("tailscale %s set-config [--service=] [--all]", info.Name), + ShortHelp: "Define service configuration from a file", + LongHelp: "Read the provided configuration file and use it to declaratively set the configuration\n" + + "for either a single service, or for all services that this node is hosting. If --service is specified,\n" + + "all endpoint handlers for that service are overwritten. If --all is specified, all endpoint handlers for\n" + + "all services are overwritten.\n\n" + + "For information on the file format, see tailscale.com/kb/1589/tailscale-services-configuration-file", + Exec: e.runServeSetConfig, + FlagSet: e.newFlags("serve-set-config", func(fs *flag.FlagSet) { + fs.BoolVar(&e.allServices, "all", false, "apply config to all services") + fs.Var(&serviceNameFlag{Value: &e.service}, "service", "apply config to a particular service") + }), + }, }, } } @@ -160,9 +333,16 @@ func (e *serveEnv) validateArgs(subcmd serveMode, args []string) error { fmt.Fprint(e.stderr(), "\nPlease see https://tailscale.com/kb/1242/tailscale-serve for more information.\n") return errHelpFunc(subcmd) } + if len(args) == 0 && e.tun { + return nil + } if len(args) == 0 { return flag.ErrHelp } + if e.tun && len(args) > 1 { + fmt.Fprintln(e.stderr(), "Error: invalid argument format") + return errHelpFunc(subcmd) + } if len(args) > 2 { fmt.Fprintf(e.stderr(), "Error: invalid number of arguments (%d)\n", len(args)) return errHelpFunc(subcmd) @@ -204,7 +384,16 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() + forService := e.service != "" + if !e.bg.IsSet { + e.bg.Value = forService + } + funnel := subcmd == funnel + if forService && funnel { + return errors.New("Error: --service flag is not supported with funnel") + } + if funnel { // verify node has funnel capabilities if err := e.verifyFunnelEnabled(ctx, 443); err != nil { @@ -212,6 +401,10 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { } } + if forService && !e.bg.Value { + return errors.New("Error: --service flag is only compatible with background mode") + } + mount, err := cleanURLPath(e.setPath) if err != nil { return fmt.Errorf("failed to clean the mount point: %w", err) @@ -223,6 +416,14 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { return errHelpFunc(subcmd) } + if (srvType == serveTypeHTTP || srvType == serveTypeHTTPS) && e.proxyProtocol != 0 { + return fmt.Errorf("PROXY protocol is only supported for TCP forwarding, not HTTP/HTTPS") + } + // Validate PROXY protocol version + if e.proxyProtocol != 0 && e.proxyProtocol != 1 && e.proxyProtocol != 2 { + return fmt.Errorf("invalid PROXY protocol version %d; must be 1 or 2", e.proxyProtocol) + } + sc, err := e.lc.GetServeConfig(ctx) if err != nil { return fmt.Errorf("error getting serve config: %w", err) @@ -237,6 +438,7 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { return fmt.Errorf("getting client status: %w", err) } dnsName := strings.TrimSuffix(st.Self.DNSName, ".") + magicDNSSuffix := st.CurrentTailnet.MagicDNSSuffix // set parent serve config to always be persisted // at the top level, but a nested config might be @@ -244,7 +446,7 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { // foreground or background. parentSC := sc - turnOff := "off" == args[len(args)-1] + turnOff := len(args) > 0 && "off" == args[len(args)-1] if !turnOff && srvType == serveTypeHTTPS { // Running serve with https requires that the tailnet has enabled // https cert provisioning. Send users through an interactive flow @@ -260,18 +462,31 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { } } - var watcher *tailscale.IPNBusWatcher - wantFg := !e.bg && !turnOff + var watcher *local.IPNBusWatcher + svcName := noService + + if forService { + svcName = e.service + dnsName = e.service.String() + } + tagged := st.Self.Tags != nil && st.Self.Tags.Len() > 0 + if forService && !tagged && !turnOff { + return errors.New("service hosts must be tagged nodes") + } + if !forService && srvType == serveTypeTUN { + return errors.New("tun mode is only supported for services") + } + wantFg := !e.bg.Value && !turnOff if wantFg { // validate the config before creating a WatchIPNBus session - if err := e.validateConfig(parentSC, srvPort, srvType); err != nil { + if err := e.validateConfig(parentSC, srvPort, srvType, svcName); err != nil { return err } // if foreground mode, create a WatchIPNBus session // and use the nested config for all following operations // TODO(marwan-at-work): nested-config validations should happen here or previous to this point. - watcher, err = e.lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + watcher, err = e.lc.WatchIPNBus(ctx, ipn.NotifyInitialState) if err != nil { return err } @@ -290,12 +505,23 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { var msg string if turnOff { - err = e.unsetServe(sc, dnsName, srvType, srvPort, mount) + // only unset serve when trying to unset with type and port flags. + err = e.unsetServe(sc, dnsName, srvType, srvPort, mount, magicDNSSuffix) } else { - if err := e.validateConfig(parentSC, srvPort, srvType); err != nil { + if err := e.validateConfig(parentSC, srvPort, srvType, svcName); err != nil { return err } - err = e.setServe(sc, st, dnsName, srvType, srvPort, mount, args[0], funnel) + if forService { + e.addServiceToPrefs(ctx, svcName) + } + target := "" + if len(args) > 0 { + target = args[0] + } + if err := e.shouldWarnRemoteDestCompatibility(ctx, target); err != nil { + return err + } + err = e.setServe(sc, dnsName, srvType, srvPort, mount, target, funnel, magicDNSSuffix, e.acceptAppCaps, int(e.proxyProtocol)) msg = e.messageForPort(sc, st, dnsName, srvType, srvPort) } if err != nil { @@ -304,7 +530,7 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { } if err := e.lc.SetServeConfig(ctx, parentSC); err != nil { - if tailscale.IsPreconditionsFailedError(err) { + if local.IsPreconditionsFailedError(err) { fmt.Fprintln(e.stderr(), "Another client is changing the serve config; please try again.") } return err @@ -330,22 +556,398 @@ func (e *serveEnv) runServeCombined(subcmd serveMode) execFunc { } } -const backgroundExistsMsg = "background configuration already exists, use `tailscale %s --%s=%d off` to remove the existing configuration" +func (e *serveEnv) addServiceToPrefs(ctx context.Context, serviceName tailcfg.ServiceName) error { + prefs, err := e.lc.GetPrefs(ctx) + if err != nil { + return fmt.Errorf("error getting prefs: %w", err) + } + advertisedServices := prefs.AdvertiseServices + if slices.Contains(advertisedServices, serviceName.String()) { + return nil // already advertised + } + advertisedServices = append(advertisedServices, serviceName.String()) + _, err = e.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: advertisedServices, + }, + }) + return err +} -func (e *serveEnv) validateConfig(sc *ipn.ServeConfig, port uint16, wantServe serveType) error { - sc, isFg := sc.FindConfig(port) - if sc == nil { +func (e *serveEnv) removeServiceFromPrefs(ctx context.Context, serviceName tailcfg.ServiceName) error { + prefs, err := e.lc.GetPrefs(ctx) + if err != nil { + return fmt.Errorf("error getting prefs: %w", err) + } + if len(prefs.AdvertiseServices) == 0 { + return nil // nothing to remove + } + initialLen := len(prefs.AdvertiseServices) + prefs.AdvertiseServices = slices.DeleteFunc(prefs.AdvertiseServices, func(s string) bool { return s == serviceName.String() }) + if initialLen == len(prefs.AdvertiseServices) { + return nil // serviceName not advertised + } + _, err = e.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: prefs.AdvertiseServices, + }, + }) + return err +} + +func (e *serveEnv) runServeDrain(ctx context.Context, args []string) error { + if len(args) == 0 { + return errHelp + } + if len(args) != 1 { + fmt.Fprintf(Stderr, "error: invalid number of arguments\n\n") + return errHelp + } + svc := args[0] + svcName := tailcfg.ServiceName(svc) + if err := svcName.Validate(); err != nil { + return fmt.Errorf("invalid service name: %w", err) + } + return e.removeServiceFromPrefs(ctx, svcName) +} + +func (e *serveEnv) runServeClear(ctx context.Context, args []string) error { + if len(args) == 0 { + return errHelp + } + if len(args) != 1 { + fmt.Fprintf(Stderr, "error: invalid number of arguments\n\n") + return errHelp + } + svc := tailcfg.ServiceName(args[0]) + if err := svc.Validate(); err != nil { + return fmt.Errorf("invalid service name: %w", err) + } + sc, err := e.lc.GetServeConfig(ctx) + if err != nil { + return fmt.Errorf("error getting serve config: %w", err) + } + if _, ok := sc.Services[svc]; !ok { + log.Printf("service %s not found in serve config, nothing to clear", svc) return nil } - if isFg { - return errors.New("foreground already exists under this port") + delete(sc.Services, svc) + if err := e.removeServiceFromPrefs(ctx, svc); err != nil { + return fmt.Errorf("error removing service %s from prefs: %w", svc, err) + } + return e.lc.SetServeConfig(ctx, sc) +} + +func (e *serveEnv) runServeAdvertise(ctx context.Context, args []string) error { + if len(args) == 0 { + return errors.New("error: missing service name argument") + } + if len(args) != 1 { + fmt.Fprintf(Stderr, "error: invalid number of arguments\n\n") + return errHelp + } + svc := tailcfg.ServiceName(args[0]) + if err := svc.Validate(); err != nil { + return fmt.Errorf("invalid service name: %w", err) + } + return e.addServiceToPrefs(ctx, svc) +} + +func (e *serveEnv) runServeGetConfig(ctx context.Context, args []string) (err error) { + forSingleService := e.service.Validate() == nil + sc, err := e.lc.GetServeConfig(ctx) + if err != nil { + return err + } + + prefs, err := e.lc.GetPrefs(ctx) + if err != nil { + return err + } + advertised := set.SetOf(prefs.AdvertiseServices) + + st, err := e.getLocalClientStatusWithoutPeers(ctx) + if err != nil { + return err } - if !e.bg { - return fmt.Errorf(backgroundExistsMsg, infoMap[e.subcmd].Name, wantServe.String(), port) + magicDNSSuffix := st.CurrentTailnet.MagicDNSSuffix + + handleService := func(svcName tailcfg.ServiceName, serviceConfig *ipn.ServiceConfig) (*conffile.ServiceDetailsFile, error) { + var sdf conffile.ServiceDetailsFile + // Leave unset for true case since that's the default. + if !advertised.Contains(svcName.String()) { + sdf.Advertised.Set(false) + } + + if serviceConfig.Tun { + mak.Set(&sdf.Endpoints, &tailcfg.ProtoPortRange{Ports: tailcfg.PortRangeAny}, &conffile.Target{ + Protocol: conffile.ProtoTUN, + Destination: "", + DestinationPorts: tailcfg.PortRange{}, + }) + } + + for port, config := range serviceConfig.TCP { + sniName := fmt.Sprintf("%s.%s", svcName.WithoutPrefix(), magicDNSSuffix) + ppr := tailcfg.ProtoPortRange{Proto: int(ipproto.TCP), Ports: tailcfg.PortRange{First: port, Last: port}} + if config.TCPForward != "" { + var proto conffile.ServiceProtocol + if config.TerminateTLS != "" { + proto = conffile.ProtoTLSTerminatedTCP + } else { + proto = conffile.ProtoTCP + } + destHost, destPortStr, err := net.SplitHostPort(config.TCPForward) + if err != nil { + return nil, fmt.Errorf("parse TCPForward=%q: %w", config.TCPForward, err) + } + destPort, err := strconv.ParseUint(destPortStr, 10, 16) + if err != nil { + return nil, fmt.Errorf("parse port %q: %w", destPortStr, err) + } + mak.Set(&sdf.Endpoints, &ppr, &conffile.Target{ + Protocol: proto, + Destination: destHost, + DestinationPorts: tailcfg.PortRange{First: uint16(destPort), Last: uint16(destPort)}, + }) + } else if config.HTTP || config.HTTPS { + webKey := ipn.HostPort(net.JoinHostPort(sniName, strconv.FormatUint(uint64(port), 10))) + handlers, ok := serviceConfig.Web[webKey] + if !ok { + return nil, fmt.Errorf("service %q: HTTP/HTTPS is set but no handlers in config", svcName) + } + defaultHandler, ok := handlers.Handlers["/"] + if !ok { + return nil, fmt.Errorf("service %q: root handler not set", svcName) + } + if defaultHandler.Path != "" { + mak.Set(&sdf.Endpoints, &ppr, &conffile.Target{ + Protocol: conffile.ProtoFile, + Destination: defaultHandler.Path, + DestinationPorts: tailcfg.PortRange{}, + }) + } else if defaultHandler.Proxy != "" { + proto, rest, ok := strings.Cut(defaultHandler.Proxy, "://") + if !ok { + return nil, fmt.Errorf("service %q: invalid proxy handler %q", svcName, defaultHandler.Proxy) + } + host, portStr, err := net.SplitHostPort(rest) + if err != nil { + return nil, fmt.Errorf("service %q: invalid proxy handler %q: %w", svcName, defaultHandler.Proxy, err) + } + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return nil, fmt.Errorf("service %q: parse port %q: %w", svcName, portStr, err) + } + + mak.Set(&sdf.Endpoints, &ppr, &conffile.Target{ + Protocol: conffile.ServiceProtocol(proto), + Destination: host, + DestinationPorts: tailcfg.PortRange{First: uint16(port), Last: uint16(port)}, + }) + } + } + } + + return &sdf, nil } - existingServe := serveFromPortHandler(sc.TCP[port]) + + var j []byte + + if e.allServices && forSingleService { + return errors.New("cannot specify both --all and --service") + } else if e.allServices { + var scf conffile.ServicesConfigFile + scf.Version = "0.0.1" + for svcName, serviceConfig := range sc.Services { + sdf, err := handleService(svcName, serviceConfig) + if err != nil { + return err + } + mak.Set(&scf.Services, svcName, sdf) + } + j, err = json.MarshalIndent(scf, "", " ") + if err != nil { + return err + } + } else if forSingleService { + serviceConfig, ok := sc.Services[e.service] + if !ok { + j = []byte("{}") + } else { + sdf, err := handleService(e.service, serviceConfig) + if err != nil { + return err + } + sdf.Version = "0.0.1" + j, err = json.MarshalIndent(sdf, "", " ") + if err != nil { + return err + } + } + } else { + return errors.New("must specify either --service=svc: or --all") + } + + j = append(j, '\n') + _, err = e.stdout().Write(j) + return err +} + +func (e *serveEnv) runServeSetConfig(ctx context.Context, args []string) (err error) { + if len(args) != 1 { + return errors.New("must specify filename") + } + forSingleService := e.service.Validate() == nil + + var scf *conffile.ServicesConfigFile + if e.allServices && forSingleService { + return errors.New("cannot specify both --all and --service") + } else if e.allServices { + scf, err = conffile.LoadServicesConfig(args[0], "") + } else if forSingleService { + scf, err = conffile.LoadServicesConfig(args[0], e.service.String()) + } else { + return errors.New("must specify either --service=svc: or --all") + } + if err != nil { + return fmt.Errorf("could not read config from file %q: %w", args[0], err) + } + + st, err := e.getLocalClientStatusWithoutPeers(ctx) + if err != nil { + return fmt.Errorf("getting client status: %w", err) + } + magicDNSSuffix := st.CurrentTailnet.MagicDNSSuffix + sc, err := e.lc.GetServeConfig(ctx) + if err != nil { + return fmt.Errorf("getting current serve config: %w", err) + } + + // Clear all existing config. + if forSingleService { + if sc.Services != nil { + if sc.Services[e.service] != nil { + delete(sc.Services, e.service) + } + } + } else { + sc.Services = map[tailcfg.ServiceName]*ipn.ServiceConfig{} + } + advertisedServices := set.Set[string]{} + + for name, details := range scf.Services { + for ppr, ep := range details.Endpoints { + if ep.Protocol == conffile.ProtoTUN { + err := e.setServe(sc, name.String(), serveTypeTUN, 0, "", "", false, magicDNSSuffix, nil, 0 /* proxy protocol */) + if err != nil { + return err + } + // TUN mode is exclusive. + break + } + + if ppr.Proto != int(ipproto.TCP) { + return fmt.Errorf("service %q: source ports must be TCP", name) + } + serveType, _ := serveTypeFromConfString(ep.Protocol) + for port := ppr.Ports.First; port <= ppr.Ports.Last; port++ { + var target string + if ep.Protocol == conffile.ProtoFile { + target = ep.Destination + } else { + // map source port range 1-1 to destination port range + destPort := ep.DestinationPorts.First + (port - ppr.Ports.First) + portStr := fmt.Sprint(destPort) + target = fmt.Sprintf("%s://%s", ep.Protocol, net.JoinHostPort(ep.Destination, portStr)) + } + err := e.setServe(sc, name.String(), serveType, port, "/", target, false, magicDNSSuffix, nil, 0 /* proxy protocol */) + if err != nil { + return fmt.Errorf("service %q: %w", name, err) + } + } + } + if v, set := details.Advertised.Get(); !set || v { + advertisedServices.Add(name.String()) + } + } + + var changed bool + var servicesList []string + if e.allServices { + servicesList = advertisedServices.Slice() + changed = true + } else if advertisedServices.Contains(e.service.String()) { + // If allServices wasn't set, the only service that could have been + // advertised is the one that was provided as a flag. + prefs, err := e.lc.GetPrefs(ctx) + if err != nil { + return err + } + if !slices.Contains(prefs.AdvertiseServices, e.service.String()) { + servicesList = append(prefs.AdvertiseServices, e.service.String()) + changed = true + } + } + if changed { + _, err = e.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: servicesList, + }, + }) + if err != nil { + return err + } + } + + return e.lc.SetServeConfig(ctx, sc) +} + +const backgroundExistsMsg = "background configuration already exists, use `tailscale %s --%s=%d off` to remove the existing configuration" + +// validateConfig checks if the serve config is valid to serve the type wanted on the port. +// dnsName is a FQDN or a serviceName (with `svc:` prefix). +func (e *serveEnv) validateConfig(sc *ipn.ServeConfig, port uint16, wantServe serveType, svcName tailcfg.ServiceName) error { + var tcpHandlerForPort *ipn.TCPPortHandler + if svcName != noService { + svc := sc.Services[svcName] + if svc == nil { + return nil + } + if wantServe == serveTypeTUN && (svc.TCP != nil || svc.Web != nil) { + return errors.New("service already has a TCP or Web handler, cannot serve in TUN mode") + } + if svc.Tun && wantServe != serveTypeTUN { + return errors.New("service is already being served in TUN mode") + } + if svc.TCP[port] == nil { + return nil + } + tcpHandlerForPort = svc.TCP[port] + } else { + sc, isFg := sc.FindConfig(port) + if sc == nil { + return nil + } + if isFg { + return errors.New("foreground already exists under this port") + } + if !e.bg.Value { + return fmt.Errorf(backgroundExistsMsg, infoMap[e.subcmd].Name, wantServe.String(), port) + } + tcpHandlerForPort = sc.TCP[port] + } + existingServe := serveFromPortHandler(tcpHandlerForPort) if wantServe != existingServe { - return fmt.Errorf("want %q but port is already serving %q", wantServe, existingServe) + target := svcName + if target == noService { + target = "machine" + } + return fmt.Errorf("want to serve %q but port is already serving %q for %q", wantServe, existingServe, target) } return nil } @@ -365,12 +967,12 @@ func serveFromPortHandler(tcp *ipn.TCPPortHandler) serveType { } } -func (e *serveEnv) setServe(sc *ipn.ServeConfig, st *ipnstate.Status, dnsName string, srvType serveType, srvPort uint16, mount string, target string, allowFunnel bool) error { +func (e *serveEnv) setServe(sc *ipn.ServeConfig, dnsName string, srvType serveType, srvPort uint16, mount string, target string, allowFunnel bool, mds string, caps []tailcfg.PeerCapability, proxyProtocol int) error { // update serve config based on the type switch srvType { case serveTypeHTTPS, serveTypeHTTP: useTLS := srvType == serveTypeHTTPS - err := e.applyWebServe(sc, dnsName, srvPort, useTLS, mount, target) + err := e.applyWebServe(sc, dnsName, srvPort, useTLS, mount, target, mds, caps) if err != nil { return fmt.Errorf("failed apply web serve: %w", err) } @@ -378,45 +980,61 @@ func (e *serveEnv) setServe(sc *ipn.ServeConfig, st *ipnstate.Status, dnsName st if e.setPath != "" { return fmt.Errorf("cannot mount a path for TCP serve") } - - err := e.applyTCPServe(sc, dnsName, srvType, srvPort, target) + err := e.applyTCPServe(sc, dnsName, srvType, srvPort, target, proxyProtocol) if err != nil { return fmt.Errorf("failed to apply TCP serve: %w", err) } + case serveTypeTUN: + // Caller checks that TUN mode is only supported for services. + svcName := tailcfg.ServiceName(dnsName) + if _, ok := sc.Services[svcName]; !ok { + mak.Set(&sc.Services, svcName, new(ipn.ServiceConfig)) + } + sc.Services[svcName].Tun = true default: return fmt.Errorf("invalid type %q", srvType) } // update the serve config based on if funnel is enabled - e.applyFunnel(sc, dnsName, srvPort, allowFunnel) - + // Since funnel is not supported for services, we only apply it for node's serve. + if svcName := tailcfg.AsServiceName(dnsName); svcName == noService { + e.applyFunnel(sc, dnsName, srvPort, allowFunnel) + } return nil } var ( - msgFunnelAvailable = "Available on the internet:" - msgServeAvailable = "Available within your tailnet:" - msgRunningInBackground = "%s started and running in the background." - msgDisableProxy = "To disable the proxy, run: tailscale %s --%s=%d off" - msgToExit = "Press Ctrl+C to exit." + msgFunnelAvailable = "Available on the internet:" + msgServeAvailable = "Available within your tailnet:" + msgServiceWaitingApproval = "This machine is configured as a service proxy for %s, but approval from an admin is required. Once approved, it will be available in your Tailnet as:" + msgRunningInBackground = "%s started and running in the background." + msgRunningTunService = "IPv4 and IPv6 traffic to %s is being routed to your operating system." + msgDisableProxy = "To disable the proxy, run: tailscale %s --%s=%d off" + msgDisableServiceProxy = "To disable the proxy, run: tailscale serve --service=%s --%s=%d off" + msgDisableServiceTun = "To disable the service in TUN mode, run: tailscale serve --service=%s --tun off" + msgDisableService = "To remove config for the service, run: tailscale serve clear %s" + msgWarnRemoteDestCompatibility = "Warning: %s doesn't support connecting to remote destinations from non-default route, see tailscale.com/kb/1552/tailscale-services for detail." + msgToExit = "Press Ctrl+C to exit." ) // messageForPort returns a message for the given port based on the // serve config and status. func (e *serveEnv) messageForPort(sc *ipn.ServeConfig, st *ipnstate.Status, dnsName string, srvType serveType, srvPort uint16) string { var output strings.Builder - - hp := ipn.HostPort(net.JoinHostPort(dnsName, strconv.Itoa(int(srvPort)))) - - if sc.AllowFunnel[hp] == true { - output.WriteString(msgFunnelAvailable) - } else { - output.WriteString(msgServeAvailable) + svcName := tailcfg.AsServiceName(dnsName) + forService := svcName != noService + var webConfig *ipn.WebServerConfig + var tcpHandler *ipn.TCPPortHandler + ips := st.TailscaleIPs + magicDNSSuffix := st.CurrentTailnet.MagicDNSSuffix + host := dnsName + if forService { + host = strings.Join([]string{svcName.WithoutPrefix(), magicDNSSuffix}, ".") } - output.WriteString("\n\n") + hp := ipn.HostPort(net.JoinHostPort(host, strconv.Itoa(int(srvPort)))) scheme := "https" - if sc.IsServingHTTP(srvPort) { + if sc.IsServingHTTP(srvPort, svcName) { scheme = "http" } @@ -437,41 +1055,71 @@ func (e *serveEnv) messageForPort(sc *ipn.ServeConfig, st *ipnstate.Status, dnsN } return "", "" } - - if sc.Web[hp] != nil { - var mounts []string - - for k := range sc.Web[hp].Handlers { - mounts = append(mounts, k) + if forService { + serviceIPMaps, err := tailcfg.UnmarshalNodeCapJSON[tailcfg.ServiceIPMappings](st.Self.CapMap, tailcfg.NodeAttrServiceHost) + if err != nil || len(serviceIPMaps) == 0 || serviceIPMaps[0][svcName] == nil { + // The capmap does not contain IPs for this service yet. Usually this means + // the service hasn't been added to prefs and sent to control yet. + output.WriteString(fmt.Sprintf(msgServiceWaitingApproval, svcName.String())) + ips = nil + } else { + output.WriteString(msgServeAvailable) + ips = serviceIPMaps[0][svcName] + } + output.WriteString("\n\n") + svc := sc.Services[svcName] + if srvType == serveTypeTUN && svc.Tun { + output.WriteString(fmt.Sprintf(msgRunningTunService, host)) + output.WriteString("\n") + output.WriteString(fmt.Sprintf(msgDisableServiceTun, dnsName)) + output.WriteString("\n") + output.WriteString(fmt.Sprintf(msgDisableService, dnsName)) + return output.String() + } + if svc != nil { + webConfig = svc.Web[hp] + tcpHandler = svc.TCP[srvPort] } + } else { + if sc.AllowFunnel[hp] == true { + output.WriteString(msgFunnelAvailable) + } else { + output.WriteString(msgServeAvailable) + } + output.WriteString("\n\n") + webConfig = sc.Web[hp] + tcpHandler = sc.TCP[srvPort] + } + + if webConfig != nil { + mounts := slicesx.MapKeys(webConfig.Handlers) sort.Slice(mounts, func(i, j int) bool { return len(mounts[i]) < len(mounts[j]) }) - for _, m := range mounts { - h := sc.Web[hp].Handlers[m] - t, d := srvTypeAndDesc(h) - output.WriteString(fmt.Sprintf("%s://%s%s%s\n", scheme, dnsName, portPart, m)) + t, d := srvTypeAndDesc(webConfig.Handlers[m]) + output.WriteString(fmt.Sprintf("%s://%s%s%s\n", scheme, host, portPart, m)) output.WriteString(fmt.Sprintf("%s %-5s %s\n\n", "|--", t, d)) } - } else if sc.TCP[srvPort] != nil { - h := sc.TCP[srvPort] + } else if tcpHandler != nil { tlsStatus := "TLS over TCP" - if h.TerminateTLS != "" { + if tcpHandler.TerminateTLS != "" { tlsStatus = "TLS terminated" } + if ver := tcpHandler.ProxyProtocol; ver != 0 { + tlsStatus = fmt.Sprintf("%s, PROXY protocol v%d", tlsStatus, ver) + } - output.WriteString(fmt.Sprintf("%s://%s%s\n", scheme, dnsName, portPart)) - output.WriteString(fmt.Sprintf("|-- tcp://%s (%s)\n", hp, tlsStatus)) - for _, a := range st.TailscaleIPs { + output.WriteString(fmt.Sprintf("|-- tcp://%s:%d (%s)\n", host, srvPort, tlsStatus)) + for _, a := range ips { ipp := net.JoinHostPort(a.String(), strconv.Itoa(int(srvPort))) output.WriteString(fmt.Sprintf("|-- tcp://%s\n", ipp)) } - output.WriteString(fmt.Sprintf("|--> tcp://%s\n", h.TCPForward)) + output.WriteString(fmt.Sprintf("|--> tcp://%s\n\n", tcpHandler.TCPForward)) } - if !e.bg { + if !forService && !e.bg.Value { output.WriteString(msgToExit) return output.String() } @@ -481,14 +1129,90 @@ func (e *serveEnv) messageForPort(sc *ipn.ServeConfig, st *ipnstate.Status, dnsN output.WriteString(fmt.Sprintf(msgRunningInBackground, subCmdUpper)) output.WriteString("\n") - output.WriteString(fmt.Sprintf(msgDisableProxy, subCmd, srvType.String(), srvPort)) + if forService { + output.WriteString(fmt.Sprintf(msgDisableServiceProxy, dnsName, srvType.String(), srvPort)) + output.WriteString("\n") + output.WriteString(fmt.Sprintf(msgDisableService, dnsName)) + } else { + output.WriteString(fmt.Sprintf(msgDisableProxy, subCmd, srvType.String(), srvPort)) + } return output.String() } -func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort uint16, useTLS bool, mount, target string) error { - h := new(ipn.HTTPHandler) +// isRemote reports whether the given destination from serve config +// is a remote destination. +func isRemote(target string) bool { + // target being a port number means it's localhost + if _, err := strconv.ParseUint(target, 10, 16); err == nil { + return false + } + + // prepend tmp:// if no scheme is present just to help parsing + if !strings.Contains(target, "://") { + target = "tmp://" + target + } + + // make sure we can parse the target, wether it's a full URL or just a host:port + u, err := url.ParseRequestURI(target) + if err != nil { + // If we can't parse the target, it doesn't matter if it's remote or not + return false + } + validHN := dnsname.ValidHostname(u.Hostname()) == nil + validIP := net.ParseIP(u.Hostname()) != nil + if !validHN && !validIP { + return false + } + if u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" || u.Hostname() == "::1" { + return false + } + return true +} + +// shouldWarnRemoteDestCompatibility reports whether we should warn the user +// that their current OS/environment may not be compatible with +// service's proxy destination. +func (e *serveEnv) shouldWarnRemoteDestCompatibility(ctx context.Context, target string) error { + // no target means nothing to check + if target == "" { + return nil + } + + if filepath.IsAbs(target) || strings.HasPrefix(target, "text:") { + // local path or text target, nothing to check + return nil + } + + // only check for remote destinations + if !isRemote(target) { + return nil + } + + // Check if running as Mac extension and warn + if version.IsMacAppStore() || version.IsMacSysExt() { + return fmt.Errorf(msgWarnRemoteDestCompatibility, "the MacOS extension") + } + + // Check for linux, if it's running with TS_FORCE_LINUX_BIND_TO_DEVICE=true + // and tailscale bypass mark is not working. If any of these conditions are true, and the dest is + // a remote destination, return true. + if runtime.GOOS == "linux" { + SOMarkInUse, err := e.lc.CheckSOMarkInUse(ctx) + if err != nil { + log.Printf("error checking SO mark in use: %v", err) + return nil + } + if !SOMarkInUse { + return fmt.Errorf(msgWarnRemoteDestCompatibility, "the Linux tailscaled without SO_MARK") + } + } + + return nil +} +func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort uint16, useTLS bool, mount, target, mds string, caps []tailcfg.PeerCapability) error { + h := new(ipn.HTTPHandler) switch { case strings.HasPrefix(target, "text:"): text := strings.TrimPrefix(target, "text:") @@ -521,19 +1245,21 @@ func (e *serveEnv) applyWebServe(sc *ipn.ServeConfig, dnsName string, srvPort ui return err } h.Proxy = t + h.AcceptAppCaps = caps } // TODO: validation needs to check nested foreground configs - if sc.IsTCPForwardingOnPort(srvPort) { + svcName := tailcfg.AsServiceName(dnsName) + if sc.IsTCPForwardingOnPort(srvPort, svcName) { return errors.New("cannot serve web; already serving TCP") } - sc.SetWebHandler(h, dnsName, srvPort, mount, useTLS) + sc.SetWebHandler(h, dnsName, srvPort, mount, useTLS, mds) return nil } -func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType serveType, srcPort uint16, target string) error { +func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType serveType, srcPort uint16, target string, proxyProtocol int) error { var terminateTLS bool switch srcType { case serveTypeTCP: @@ -544,6 +1270,8 @@ func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType se return fmt.Errorf("invalid TCP target %q", target) } + svcName := tailcfg.AsServiceName(dnsName) + targetURL, err := ipn.ExpandProxyTargetValue(target, []string{"tcp"}, "tcp") if err != nil { return fmt.Errorf("unable to expand target: %v", err) @@ -555,12 +1283,11 @@ func (e *serveEnv) applyTCPServe(sc *ipn.ServeConfig, dnsName string, srcType se } // TODO: needs to account for multiple configs from foreground mode - if sc.IsServingWeb(srcPort) { - return fmt.Errorf("cannot serve TCP; already serving web on %d", srcPort) + if sc.IsServingWeb(srcPort, svcName) { + return fmt.Errorf("cannot serve TCP; already serving web on %d for %s", srcPort, dnsName) } - sc.SetTCPForwarding(srcPort, dstURL.Host, terminateTLS, dnsName) - + sc.SetTCPForwarding(srcPort, dstURL.Host, terminateTLS, proxyProtocol, dnsName) return nil } @@ -580,18 +1307,25 @@ func (e *serveEnv) applyFunnel(sc *ipn.ServeConfig, dnsName string, srvPort uint } // unsetServe removes the serve config for the given serve port. -func (e *serveEnv) unsetServe(sc *ipn.ServeConfig, dnsName string, srvType serveType, srvPort uint16, mount string) error { +// dnsName is a FQDN or a serviceName (with `svc:` prefix). mds +// is the Magic DNS suffix, which is used to recreate serve's host. +func (e *serveEnv) unsetServe(sc *ipn.ServeConfig, dnsName string, srvType serveType, srvPort uint16, mount string, mds string) error { switch srvType { case serveTypeHTTPS, serveTypeHTTP: - err := e.removeWebServe(sc, dnsName, srvPort, mount) + err := e.removeWebServe(sc, dnsName, srvPort, mount, mds) if err != nil { return fmt.Errorf("failed to remove web serve: %w", err) } case serveTypeTCP, serveTypeTLSTerminatedTCP: - err := e.removeTCPServe(sc, srvPort) + err := e.removeTCPServe(sc, dnsName, srvPort) if err != nil { return fmt.Errorf("failed to remove TCP serve: %w", err) } + case serveTypeTUN: + err := e.removeTunServe(sc, dnsName) + if err != nil { + return fmt.Errorf("failed to remove TUN serve: %w", err) + } default: return fmt.Errorf("invalid type %q", srvType) } @@ -622,11 +1356,16 @@ func srvTypeAndPortFromFlags(e *serveEnv) (srvType serveType, srvPort uint16, er } } + if e.tun { + srcTypeCount++ + srvType = serveTypeTUN + } + if srcTypeCount > 1 { return 0, 0, fmt.Errorf("cannot serve multiple types for a single mount point") - } else if srcTypeCount == 0 { - srvType = serveTypeHTTPS - srvPort = 443 + } + if srcTypeCount == 0 { + return serveTypeHTTPS, 443, nil } return srvType, srvPort, nil @@ -729,59 +1468,100 @@ func isLegacyInvocation(subcmd serveMode, args []string) (string, bool) { // removeWebServe removes a web handler from the serve config // and removes funnel if no remaining mounts exist for the serve port. // The srvPort argument is the serving port and the mount argument is -// the mount point or registered path to remove. -func (e *serveEnv) removeWebServe(sc *ipn.ServeConfig, dnsName string, srvPort uint16, mount string) error { - if sc.IsTCPForwardingOnPort(srvPort) { - return errors.New("cannot remove web handler; currently serving TCP") +// the mount point or registered path to remove. mds is the Magic DNS suffix, +// which is used to recreate serve's host. +func (e *serveEnv) removeWebServe(sc *ipn.ServeConfig, dnsName string, srvPort uint16, mount string, mds string) error { + if sc == nil { + return nil } portStr := strconv.Itoa(int(srvPort)) - hp := ipn.HostPort(net.JoinHostPort(dnsName, portStr)) + hostName := dnsName + webServeMap := sc.Web + svcName := tailcfg.AsServiceName(dnsName) + forService := svcName != noService + if forService { + svc := sc.Services[svcName] + if svc == nil { + return errors.New("service does not exist") + } + hostName = strings.Join([]string{svcName.WithoutPrefix(), mds}, ".") + webServeMap = svc.Web + } + + hp := ipn.HostPort(net.JoinHostPort(hostName, portStr)) + if sc.IsTCPForwardingOnPort(srvPort, svcName) { + return errors.New("cannot remove web handler; currently serving TCP") + } var targetExists bool var mounts []string // mount is deduced from e.setPath but it is ambiguous as // to whether the user explicitly passed "/" or it was defaulted to. if e.setPath == "" { - targetExists = sc.Web[hp] != nil && len(sc.Web[hp].Handlers) > 0 + targetExists = webServeMap[hp] != nil && len(webServeMap[hp].Handlers) > 0 if targetExists { - for mount := range sc.Web[hp].Handlers { + for mount := range webServeMap[hp].Handlers { mounts = append(mounts, mount) } } } else { - targetExists = sc.WebHandlerExists(hp, mount) + targetExists = sc.WebHandlerExists(svcName, hp, mount) mounts = []string{mount} } if !targetExists { - return errors.New("error: handler does not exist") + return errors.New("handler does not exist") } if len(mounts) > 1 { msg := fmt.Sprintf("Are you sure you want to delete %d handlers under port %s?", len(mounts), portStr) - if !e.yes && !promptYesNo(msg) { + if !e.yes && !prompt.YesNo(msg, true) { return nil } } - sc.RemoveWebHandler(dnsName, srvPort, mounts, true) + if forService { + sc.RemoveServiceWebHandler(svcName, hostName, srvPort, mounts) + } else { + sc.RemoveWebHandler(dnsName, srvPort, mounts, true) + } return nil } // removeTCPServe removes the TCP forwarding configuration for the -// given srvPort, or serving port. -func (e *serveEnv) removeTCPServe(sc *ipn.ServeConfig, src uint16) error { +// given srvPort, or serving port for the given dnsName. +func (e *serveEnv) removeTCPServe(sc *ipn.ServeConfig, dnsName string, src uint16) error { if sc == nil { return nil } - if sc.GetTCPPortHandler(src) == nil { - return errors.New("error: serve config does not exist") + svcName := tailcfg.AsServiceName(dnsName) + if sc.GetTCPPortHandler(src, svcName) == nil { + return errors.New("serve config does not exist") } - if sc.IsServingWeb(src) { + if sc.IsServingWeb(src, svcName) { return fmt.Errorf("unable to remove; serving web, not TCP forwarding on serve port %d", src) } - sc.RemoveTCPForwarding(src) + sc.RemoveTCPForwarding(svcName, src) + return nil +} + +func (e *serveEnv) removeTunServe(sc *ipn.ServeConfig, dnsName string) error { + if sc == nil { + return nil + } + svcName := tailcfg.ServiceName(dnsName) + svc, ok := sc.Services[svcName] + if !ok || svc == nil { + return errors.New("service does not exist") + } + if !svc.Tun { + return errors.New("service is not being served in TUN mode") + } + delete(sc.Services, svcName) + if len(sc.Services) == 0 { + sc.Services = nil // clean up empty map + } return nil } diff --git a/cmd/tailscale/cli/serve_v2_test.go b/cmd/tailscale/cli/serve_v2_test.go index 5768127ad..491baf9dd 100644 --- a/cmd/tailscale/cli/serve_v2_test.go +++ b/cmd/tailscale/cli/serve_v2_test.go @@ -8,9 +8,12 @@ import ( "context" "encoding/json" "fmt" + "net/netip" "os" "path/filepath" "reflect" + "regexp" + "slices" "strconv" "strings" "testing" @@ -19,6 +22,8 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/types/views" ) func TestServeDevConfigMutations(t *testing.T) { @@ -30,10 +35,11 @@ func TestServeDevConfigMutations(t *testing.T) { } // group is a group of steps that share the same - // config mutation, but always starts from an empty config + // config mutation type group struct { - name string - steps []step + name string + steps []step + initialState fakeLocalServeClient // use the zero value for empty config } // creaet a temporary directory for path-based destinations @@ -214,10 +220,20 @@ func TestServeDevConfigMutations(t *testing.T) { }}, }, { - name: "invalid_host", + name: "ip_host", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ - command: cmd("serve --https=443 --bg http://somehost:3000"), // invalid host - wantErr: anyErr(), + command: cmd("serve --https=443 --bg http://192.168.1.1:3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://192.168.1.1:3000"}, + }}, + }, + }, }}, }, { @@ -227,6 +243,16 @@ func TestServeDevConfigMutations(t *testing.T) { wantErr: anyErr(), }}, }, + { + name: "no_scheme_remote_host_tcp", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, + steps: []step{{ + command: cmd("serve --https=443 --bg 192.168.1.1:3000"), + wantErr: exactErrMsg(errHelp), + }}, + }, { name: "turn_off_https", steps: []step{ @@ -396,15 +422,11 @@ func TestServeDevConfigMutations(t *testing.T) { }, }}, }, - { - name: "unknown_host_tcp", - steps: []step{{ - command: cmd("serve --tls-terminated-tcp=443 --bg tcp://somehost:5432"), - wantErr: exactErrMsg(errHelp), - }}, - }, { name: "tcp_port_too_low", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ command: cmd("serve --tls-terminated-tcp=443 --bg tcp://somehost:0"), wantErr: exactErrMsg(errHelp), @@ -412,6 +434,9 @@ func TestServeDevConfigMutations(t *testing.T) { }, { name: "tcp_port_too_high", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ command: cmd("serve --tls-terminated-tcp=443 --bg tcp://somehost:65536"), wantErr: exactErrMsg(errHelp), @@ -526,6 +551,9 @@ func TestServeDevConfigMutations(t *testing.T) { }, { name: "bad_path", + initialState: fakeLocalServeClient{ + SOMarkInUse: true, + }, steps: []step{{ command: cmd("serve --bg --https=443 bad/path"), wantErr: exactErrMsg(errHelp), @@ -811,17 +839,187 @@ func TestServeDevConfigMutations(t *testing.T) { }, }, }, + { + name: "advertise_service", + initialState: fakeLocalServeClient{ + statusWithoutPeers: &ipnstate.Status{ + BackendState: ipn.Running.String(), + Self: &ipnstate.PeerStatus{ + DNSName: "foo.test.ts.net", + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrFunnel: nil, + tailcfg.CapabilityFunnelPorts + "?ports=443,8443": nil, + }, + Tags: ptrToReadOnlySlice([]string{"some-tag"}), + }, + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + SOMarkInUse: true, + }, + steps: []step{{ + command: cmd("serve --service=svc:foo --http=80 text:foo"), + want: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": {Handlers: map[string]*ipn.HTTPHandler{ + "/": {Text: "foo"}, + }}, + }, + }, + }, + }, + }}, + }, + { + name: "advertise_service_from_untagged_node", + steps: []step{{ + command: cmd("serve --service=svc:foo --http=80 text:foo"), + wantErr: anyErr(), + }}, + }, + { + name: "forward_grant_header", + steps: []step{ + { + command: cmd("serve --bg --accept-app-caps=example.com/cap/foo 3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://127.0.0.1:3000", + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/foo"}, + }, + }}, + }, + }, + }, + { + command: cmd("serve --bg --accept-app-caps=example.com/cap/foo,example.com/cap/bar 3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://127.0.0.1:3000", + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/foo", "example.com/cap/bar"}, + }, + }}, + }, + }, + }, + { + command: cmd("serve --bg --accept-app-caps=example.com/cap/bar 3000"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: "http://127.0.0.1:3000", + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/bar"}, + }, + }}, + }, + }, + }, + }, + }, + { + name: "invalid_accept_caps_invalid_app_cap", + steps: []step{ + { + command: cmd("serve --bg --accept-app-caps=example.com/cap/fine,NOTFINE 3000"), // should be {domain.tld}/{name} + wantErr: func(err error) (badErrMsg string) { + if err == nil || !strings.Contains(err.Error(), fmt.Sprintf("%q does not match", "NOTFINE")) { + return fmt.Sprintf("wanted validation error that quotes the non-matching capability (and nothing more) but got %q", err.Error()) + } + return "" + }, + }, + }, + }, + { + name: "tcp_with_proxy_protocol_v1", + steps: []step{{ + command: cmd("serve --tcp=8000 --proxy-protocol=1 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8000: { + TCPForward: "localhost:5432", + ProxyProtocol: 1, + }, + }, + }, + }}, + }, + { + name: "tls_terminated_tcp_with_proxy_protocol_v2", + steps: []step{{ + command: cmd("serve --tls-terminated-tcp=443 --proxy-protocol=2 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + TCPForward: "localhost:5432", + TerminateTLS: "foo.test.ts.net", + ProxyProtocol: 2, + }, + }, + }, + }}, + }, + { + name: "tcp_update_to_add_proxy_protocol", + steps: []step{ + { + command: cmd("serve --tcp=8000 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8000: {TCPForward: "localhost:5432"}, + }, + }, + }, + { + command: cmd("serve --tcp=8000 --proxy-protocol=1 --bg tcp://localhost:5432"), + want: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 8000: { + TCPForward: "localhost:5432", + ProxyProtocol: 1, + }, + }, + }, + }, + }, + }, + { + name: "tcp_proxy_protocol_invalid_version", + steps: []step{{ + command: cmd("serve --tcp=8000 --proxy-protocol=3 --bg tcp://localhost:5432"), + wantErr: anyErr(), + }}, + }, + { + name: "proxy_protocol_without_tcp", + steps: []step{{ + command: cmd("serve --https=443 --proxy-protocol=1 --bg http://localhost:3000"), + wantErr: anyErr(), + }}, + }, } for _, group := range groups { t.Run(group.name, func(t *testing.T) { - lc := &fakeLocalServeClient{} + lc := group.initialState for i, st := range group.steps { var stderr bytes.Buffer var stdout bytes.Buffer var flagOut bytes.Buffer e := &serveEnv{ - lc: lc, + lc: &lc, testFlagOut: &flagOut, testStdout: &stdout, testStderr: &stderr, @@ -874,9 +1072,10 @@ func TestValidateConfig(t *testing.T) { name string desc string cfg *ipn.ServeConfig + svc tailcfg.ServiceName servePort uint16 serveType serveType - bg bool + bg bgBoolFlag wantErr bool }{ { @@ -894,7 +1093,7 @@ func TestValidateConfig(t *testing.T) { 443: {HTTPS: true}, }, }, - bg: true, + bg: bgBoolFlag{true, false}, servePort: 10000, serveType: serveTypeHTTPS, }, @@ -906,7 +1105,7 @@ func TestValidateConfig(t *testing.T) { 443: {TCPForward: "http://localhost:4545"}, }, }, - bg: true, + bg: bgBoolFlag{true, false}, servePort: 443, serveType: serveTypeTCP, }, @@ -918,7 +1117,7 @@ func TestValidateConfig(t *testing.T) { 443: {HTTPS: true}, }, }, - bg: true, + bg: bgBoolFlag{true, false}, servePort: 443, serveType: serveTypeHTTP, wantErr: true, @@ -957,12 +1156,90 @@ func TestValidateConfig(t *testing.T) { serveType: serveTypeTCP, wantErr: true, }, + { + name: "new_service_tcp", + desc: "no error when adding a new service port", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + }, + }, + }, + svc: "svc:foo", + servePort: 8080, + serveType: serveTypeTCP, + }, + { + name: "override_service_tcp", + desc: "no error when overwriting a previous service port", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {TCPForward: "http://localhost:4545"}, + }, + }, + }, + }, + svc: "svc:foo", + servePort: 443, + serveType: serveTypeTCP, + }, + { + name: "override_service_tcp", + desc: "error when overwriting a previous service port with a different serve type", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {HTTPS: true}, + }, + }, + }, + }, + svc: "svc:foo", + servePort: 443, + serveType: serveTypeHTTP, + wantErr: true, + }, + { + name: "override_service_tcp", + desc: "error when setting previous tcp service to tun mode", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {TCPForward: "http://localhost:4545"}, + }, + }, + }, + }, + svc: "svc:foo", + serveType: serveTypeTUN, + wantErr: true, + }, + { + name: "override_service_tun", + desc: "error when setting previous tun service to tcp forwarder", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + Tun: true, + }, + }, + }, + svc: "svc:foo", + serveType: serveTypeTCP, + servePort: 443, + wantErr: true, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { se := serveEnv{bg: tc.bg} - err := se.validateConfig(tc.cfg, tc.servePort, tc.serveType) + err := se.validateConfig(tc.cfg, tc.servePort, tc.serveType, tc.svc) if err == nil && tc.wantErr { t.Fatal("expected an error but got nil") } @@ -1017,6 +1294,13 @@ func TestSrcTypeFromFlags(t *testing.T) { expectedPort: 443, expectedErr: false, }, + { + name: "defaults to https, port 443 for service", + env: &serveEnv{service: "svc:foo"}, + expectedType: serveTypeHTTPS, + expectedPort: 443, + expectedErr: false, + }, { name: "multiple types set", env: &serveEnv{http: 80, https: 443}, @@ -1041,6 +1325,118 @@ func TestSrcTypeFromFlags(t *testing.T) { } } +func TestAcceptSetAppCapsFlag(t *testing.T) { + testCases := []struct { + name string + inputs []string + expectErr bool + expectErrToMatch *regexp.Regexp + expectedValue []tailcfg.PeerCapability + }{ + { + name: "valid_simple", + inputs: []string{"example.com/name"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"example.com/name"}, + }, + { + name: "valid_unicode", + inputs: []string{"bÃŧcher.de/something"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"bÃŧcher.de/something"}, + }, + { + name: "more_valid_unicode", + inputs: []string{"example.tw/某某某"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"example.tw/某某某"}, + }, + { + name: "valid_path_slashes", + inputs: []string{"domain.com/path/to/name"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"domain.com/path/to/name"}, + }, + { + name: "valid_multiple_sets", + inputs: []string{"one.com/foo,two.com/bar"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"one.com/foo", "two.com/bar"}, + }, + { + name: "valid_empty_string", + inputs: []string{""}, + expectErr: false, + expectedValue: nil, // Empty string should be a no-op and not append anything. + }, + { + name: "invalid_path_chars", + inputs: []string{"domain.com/path_with_underscore"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"domain.com/path_with_underscore"`), + expectedValue: nil, // Slice should remain empty. + }, + { + name: "valid_subdomain", + inputs: []string{"sub.domain.com/name"}, + expectErr: false, + expectedValue: []tailcfg.PeerCapability{"sub.domain.com/name"}, + }, + { + name: "invalid_no_path", + inputs: []string{"domain.com/"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"domain.com/"`), + expectedValue: nil, + }, + { + name: "invalid_no_domain", + inputs: []string{"/path/only"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"/path/only"`), + expectedValue: nil, + }, + { + name: "some_invalid_some_valid", + inputs: []string{"one.com/foo,bad/bar,two.com/baz"}, + expectErr: true, + expectErrToMatch: regexp.MustCompile(`"bad/bar"`), + expectedValue: []tailcfg.PeerCapability{"one.com/foo"}, // Parsing will stop after first error + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var v []tailcfg.PeerCapability + flag := &acceptAppCapsFlag{Value: &v} + + var err error + for _, s := range tc.inputs { + err = flag.Set(s) + if err != nil { + break + } + } + + if tc.expectErr && err == nil { + t.Errorf("expected an error, but got none") + } + if tc.expectErrToMatch != nil { + if !tc.expectErrToMatch.MatchString(err.Error()) { + t.Errorf("expected error to match %q, but was %q", tc.expectErrToMatch, err) + } + } + if !tc.expectErr && err != nil { + t.Errorf("did not expect an error, but got: %v", err) + } + + if !reflect.DeepEqual(tc.expectedValue, v) { + t.Errorf("unexpected value, got: %q, want: %q", v, tc.expectedValue) + } + }) + } +} + func TestCleanURLPath(t *testing.T) { tests := []struct { input string @@ -1075,45 +1471,151 @@ func TestCleanURLPath(t *testing.T) { } } -func TestMessageForPort(t *testing.T) { +func TestAddServiceToPrefs(t *testing.T) { tests := []struct { - name string - subcmd serveMode - serveConfig *ipn.ServeConfig - status *ipnstate.Status - dnsName string - srvType serveType - srvPort uint16 - expected string + name string + svcName tailcfg.ServiceName + startServices []string + expected []string }{ { - name: "funnel-https", - subcmd: funnel, - serveConfig: &ipn.ServeConfig{ - TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {HTTPS: true}, - }, - Web: map[ipn.HostPort]*ipn.WebServerConfig{ - "foo.test.ts.net:443": { - Handlers: map[string]*ipn.HTTPHandler{ - "/": {Proxy: "http://127.0.0.1:3000"}, - }, - }, - }, - AllowFunnel: map[ipn.HostPort]bool{ - "foo.test.ts.net:443": true, - }, - }, - status: &ipnstate.Status{}, - dnsName: "foo.test.ts.net", - srvType: serveTypeHTTPS, - srvPort: 443, - expected: strings.Join([]string{ - msgFunnelAvailable, - "", - "https://foo.test.ts.net/", - "|-- proxy http://127.0.0.1:3000", - "", + name: "add service to empty prefs", + svcName: "svc:foo", + expected: []string{"svc:foo"}, + }, + { + name: "add service to existing prefs", + svcName: "svc:bar", + startServices: []string{"svc:foo"}, + expected: []string{"svc:foo", "svc:bar"}, + }, + { + name: "add existing service to prefs", + svcName: "svc:foo", + startServices: []string{"svc:foo"}, + expected: []string{"svc:foo"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lc := &fakeLocalServeClient{} + ctx := t.Context() + lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: tt.startServices, + }, + }) + e := &serveEnv{lc: lc, bg: bgBoolFlag{true, false}} + err := e.addServiceToPrefs(ctx, tt.svcName) + if err != nil { + t.Fatalf("addServiceToPrefs(%q) returned unexpected error: %v", tt.svcName, err) + } + if !slices.Equal(lc.prefs.AdvertiseServices, tt.expected) { + t.Errorf("addServiceToPrefs(%q) = %v, want %v", tt.svcName, lc.prefs.AdvertiseServices, tt.expected) + } + }) + } + +} + +func TestRemoveServiceFromPrefs(t *testing.T) { + tests := []struct { + name string + svcName tailcfg.ServiceName + startServices []string + expected []string + }{ + { + name: "remove service from empty prefs", + svcName: "svc:foo", + expected: []string{}, + }, + { + name: "remove existing service from prefs", + svcName: "svc:foo", + startServices: []string{"svc:foo"}, + expected: []string{}, + }, + { + name: "remove service not in prefs", + svcName: "svc:bar", + startServices: []string{"svc:foo"}, + expected: []string{"svc:foo"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lc := &fakeLocalServeClient{} + ctx := t.Context() + lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: tt.startServices, + }, + }) + e := &serveEnv{lc: lc, bg: bgBoolFlag{true, false}} + err := e.removeServiceFromPrefs(ctx, tt.svcName) + if err != nil { + t.Fatalf("removeServiceFromPrefs(%q) returned unexpected error: %v", tt.svcName, err) + } + if !slices.Equal(lc.prefs.AdvertiseServices, tt.expected) { + t.Errorf("removeServiceFromPrefs(%q) = %v, want %v", tt.svcName, lc.prefs.AdvertiseServices, tt.expected) + } + }) + } +} + +func TestMessageForPort(t *testing.T) { + svcIPMap := tailcfg.ServiceIPMappings{ + "svc:foo": []netip.Addr{ + netip.MustParseAddr("100.101.101.101"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:6565:6565"), + }, + } + svcIPMapJSON, _ := json.Marshal(svcIPMap) + svcIPMapJSONRawMSG := tailcfg.RawMessage(svcIPMapJSON) + + tests := []struct { + name string + subcmd serveMode + serveConfig *ipn.ServeConfig + status *ipnstate.Status + prefs *ipn.Prefs + dnsName string + srvType serveType + srvPort uint16 + expected string + }{ + { + name: "funnel-https", + subcmd: funnel, + serveConfig: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:443": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://127.0.0.1:3000"}, + }, + }, + }, + AllowFunnel: map[ipn.HostPort]bool{ + "foo.test.ts.net:443": true, + }, + }, + status: &ipnstate.Status{CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}}, + dnsName: "foo.test.ts.net", + srvType: serveTypeHTTPS, + srvPort: 443, + expected: strings.Join([]string{ + msgFunnelAvailable, + "", + "https://foo.test.ts.net/", + "|-- proxy http://127.0.0.1:3000", + "", fmt.Sprintf(msgRunningInBackground, "Funnel"), fmt.Sprintf(msgDisableProxy, "funnel", "https", 443), }, "\n"), @@ -1133,7 +1635,7 @@ func TestMessageForPort(t *testing.T) { }, }, }, - status: &ipnstate.Status{}, + status: &ipnstate.Status{CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}}, dnsName: "foo.test.ts.net", srvType: serveTypeHTTP, srvPort: 80, @@ -1147,10 +1649,206 @@ func TestMessageForPort(t *testing.T) { fmt.Sprintf(msgDisableProxy, "serve", "http", 80), }, "\n"), }, + { + name: "serve service http", + subcmd: serve, + serveConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + Self: &ipnstate.PeerStatus{ + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{svcIPMapJSONRawMSG}, + }, + }, + }, + prefs: &ipn.Prefs{ + AdvertiseServices: []string{"svc:foo"}, + }, + dnsName: "svc:foo", + srvType: serveTypeHTTP, + srvPort: 80, + expected: strings.Join([]string{ + msgServeAvailable, + "", + "http://foo.test.ts.net/", + "|-- proxy http://localhost:3000", + "", + fmt.Sprintf(msgRunningInBackground, "Serve"), + fmt.Sprintf(msgDisableServiceProxy, "svc:foo", "http", 80), + fmt.Sprintf(msgDisableService, "svc:foo"), + }, "\n"), + }, + { + name: "serve service no capmap", + subcmd: serve, + serveConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + Self: &ipnstate.PeerStatus{ + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{svcIPMapJSONRawMSG}, + }, + }, + }, + prefs: &ipn.Prefs{ + AdvertiseServices: []string{"svc:bar"}, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + expected: strings.Join([]string{ + fmt.Sprintf(msgServiceWaitingApproval, "svc:bar"), + "", + "http://bar.test.ts.net/", + "|-- proxy http://localhost:3000", + "", + fmt.Sprintf(msgRunningInBackground, "Serve"), + fmt.Sprintf(msgDisableServiceProxy, "svc:bar", "http", 80), + fmt.Sprintf(msgDisableService, "svc:bar"), + }, "\n"), + }, + { + name: "serve service https non-default port", + subcmd: serve, + serveConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 2200: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:2200": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + Self: &ipnstate.PeerStatus{ + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{svcIPMapJSONRawMSG}, + }, + }, + }, + prefs: &ipn.Prefs{AdvertiseServices: []string{"svc:foo"}}, + dnsName: "svc:foo", + srvType: serveTypeHTTPS, + srvPort: 2200, + expected: strings.Join([]string{ + msgServeAvailable, + "", + "https://foo.test.ts.net:2200/", + "|-- proxy http://localhost:3000", + "", + fmt.Sprintf(msgRunningInBackground, "Serve"), + fmt.Sprintf(msgDisableServiceProxy, "svc:foo", "https", 2200), + fmt.Sprintf(msgDisableService, "svc:foo"), + }, "\n"), + }, + { + name: "serve service TCPForward", + subcmd: serve, + serveConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 2200: {TCPForward: "localhost:3000"}, + }, + }, + }, + }, + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + Self: &ipnstate.PeerStatus{ + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{svcIPMapJSONRawMSG}, + }, + }, + }, + prefs: &ipn.Prefs{AdvertiseServices: []string{"svc:foo"}}, + dnsName: "svc:foo", + srvType: serveTypeTCP, + srvPort: 2200, + expected: strings.Join([]string{ + msgServeAvailable, + "", + "|-- tcp://foo.test.ts.net:2200 (TLS over TCP)", + "|-- tcp://100.101.101.101:2200", + "|-- tcp://[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:2200", + "|--> tcp://localhost:3000", + "", + fmt.Sprintf(msgRunningInBackground, "Serve"), + fmt.Sprintf(msgDisableServiceProxy, "svc:foo", "tcp", 2200), + fmt.Sprintf(msgDisableService, "svc:foo"), + }, "\n"), + }, + { + name: "serve service Tun", + subcmd: serve, + serveConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + Tun: true, + }, + }, + }, + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + Self: &ipnstate.PeerStatus{ + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{svcIPMapJSONRawMSG}, + }, + }, + }, + prefs: &ipn.Prefs{AdvertiseServices: []string{"svc:foo"}}, + dnsName: "svc:foo", + srvType: serveTypeTUN, + expected: strings.Join([]string{ + msgServeAvailable, + "", + fmt.Sprintf(msgRunningTunService, "foo.test.ts.net"), + fmt.Sprintf(msgDisableServiceTun, "svc:foo"), + fmt.Sprintf(msgDisableService, "svc:foo"), + }, "\n"), + }, } for _, tt := range tests { - e := &serveEnv{bg: true, subcmd: tt.subcmd} + e := &serveEnv{bg: bgBoolFlag{true, false}, subcmd: tt.subcmd} t.Run(tt.name, func(t *testing.T) { actual := e.messageForPort(tt.serveConfig, tt.status, tt.dnsName, tt.srvType, tt.srvPort) @@ -1277,6 +1975,578 @@ func TestIsLegacyInvocation(t *testing.T) { } } +func TestSetServe(t *testing.T) { + e := &serveEnv{} + magicDNSSuffix := "test.ts.net" + tests := []struct { + name string + desc string + cfg *ipn.ServeConfig + st *ipnstate.Status + dnsName string + srvType serveType + srvPort uint16 + mountPath string + target string + allowFunnel bool + proxyProtocol int + expected *ipn.ServeConfig + expectErr bool + }{ + { + name: "add new handler", + desc: "add a new http handler to empty config", + cfg: &ipn.ServeConfig{}, + dnsName: "foo.test.ts.net", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/", + target: "http://localhost:3000", + expected: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + { + name: "update http handler", + desc: "update an existing http handler on the same port to same type", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + dnsName: "foo.test.ts.net", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/", + target: "http://localhost:3001", + expected: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3001"}, + }, + }, + }, + }, + }, + { + name: "update TCP handler", + desc: "update an existing TCP handler on the same port to a http handler", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{80: {TCPForward: "http://localhost:3000"}}, + }, + dnsName: "foo.test.ts.net", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/", + target: "http://localhost:3001", + expectErr: true, + }, + { + name: "add new service handler", + desc: "add a new service TCP handler to empty config", + cfg: &ipn.ServeConfig{}, + + dnsName: "svc:bar", + srvType: serveTypeTCP, + srvPort: 80, + target: "3000", + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {TCPForward: "127.0.0.1:3000"}}, + }, + }, + }, + }, + { + name: "update service handler", + desc: "update an existing service TCP handler on the same port to same type", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {TCPForward: "127.0.0.1:3000"}}, + }, + }, + }, + dnsName: "svc:bar", + srvType: serveTypeTCP, + srvPort: 80, + target: "3001", + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {TCPForward: "127.0.0.1:3001"}}, + }, + }, + }, + }, + { + name: "update service handler", + desc: "update an existing service TCP handler on the same port to a http handler", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {TCPForward: "127.0.0.1:3000"}}, + }, + }, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/", + target: "http://localhost:3001", + expectErr: true, + }, + { + name: "add new service handler", + desc: "add a new service HTTP handler to empty config", + cfg: &ipn.ServeConfig{}, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/", + target: "http://localhost:3000", + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + }, + { + name: "update existing service handler", + desc: "update an existing service HTTP handler", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/", + target: "http://localhost:3001", + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3001"}, + }, + }, + }, + }, + }, + }, + }, + { + name: "add new service handler", + desc: "add a new service HTTP handler to existing service config", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 88, + mountPath: "/", + target: "http://localhost:3001", + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 88: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + "bar.test.ts.net:88": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3001"}, + }, + }, + }, + }, + }, + }, + }, + { + name: "add new service mount", + desc: "add a new service mount to existing service config", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + mountPath: "/added", + target: "http://localhost:3001", + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + "/added": {Proxy: "http://localhost:3001"}, + }, + }, + }, + }, + }, + }, + }, + { + name: "add new service handler", + desc: "add a new service handler in tun mode to empty config", + cfg: &ipn.ServeConfig{}, + dnsName: "svc:bar", + srvType: serveTypeTUN, + expected: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + Tun: true, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := e.setServe(tt.cfg, tt.dnsName, tt.srvType, tt.srvPort, tt.mountPath, tt.target, tt.allowFunnel, magicDNSSuffix, nil, tt.proxyProtocol) + if err != nil && !tt.expectErr { + t.Fatalf("got error: %v; did not expect error.", err) + } + if err == nil && tt.expectErr { + t.Fatalf("got no error; expected error.") + } + if !tt.expectErr && !reflect.DeepEqual(tt.cfg, tt.expected) { + svcName := tailcfg.ServiceName(tt.dnsName) + t.Fatalf("got: %v; expected: %v", tt.cfg.Services[svcName], tt.expected.Services[svcName]) + } + }) + } +} + +func TestUnsetServe(t *testing.T) { + tests := []struct { + name string + desc string + cfg *ipn.ServeConfig + st *ipnstate.Status + dnsName string + srvType serveType + srvPort uint16 + mount string + setServeEnv bool + serveEnv *serveEnv // if set, use this instead of the default serveEnv + expected *ipn.ServeConfig + expectErr bool + }{ + { + name: "unset http handler", + desc: "remove an existing http handler", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "foo.test.ts.net", + srvType: serveTypeHTTP, + srvPort: 80, + mount: "/", + expected: &ipn.ServeConfig{}, + expectErr: false, + }, + { + name: "unset service handler", + desc: "remove an existing service TCP handler", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + mount: "/", + expected: &ipn.ServeConfig{}, + expectErr: false, + }, + { + name: "unset service handler tun", + desc: "remove an existing service handler in tun mode", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + Tun: true, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "svc:bar", + srvType: serveTypeTUN, + expected: &ipn.ServeConfig{}, + expectErr: false, + }, + { + name: "unset service handler tcp", + desc: "remove an existing service TCP handler", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {TCPForward: "11.11.11.11:3000"}, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "svc:bar", + srvType: serveTypeTCP, + srvPort: 80, + expected: &ipn.ServeConfig{}, + expectErr: false, + }, + { + name: "unset http handler not found", + desc: "try to remove a non-existing http handler", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "bar.test.ts.net", + srvType: serveTypeHTTP, + srvPort: 80, + mount: "/abc", + expected: &ipn.ServeConfig{}, + expectErr: true, + }, + { + name: "unset service handler not found", + desc: "try to remove a non-existing service TCP handler", + + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "svc:bar", + srvType: serveTypeHTTP, + srvPort: 80, + mount: "/abc", + setServeEnv: true, + serveEnv: &serveEnv{setPath: "/abc"}, + expected: &ipn.ServeConfig{}, + expectErr: true, + }, + { + name: "unset service doesn't exist", + desc: "try to remove a non-existing service's handler", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {TCPForward: "11.11.11.11:3000"}, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "svc:foo", + srvType: serveTypeTCP, + srvPort: 80, + expectErr: true, + }, + { + name: "unset tcp while port is in use", + desc: "try to remove a TCP handler while the port is used for web", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "foo.test.ts.net", + srvType: serveTypeTCP, + srvPort: 80, + mount: "/", + expectErr: true, + }, + { + name: "unset service tcp while port is in use", + desc: "try to remove a service TCP handler while the port is used for web", + cfg: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "bar.test.ts.net:80": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://localhost:3000"}, + }, + }, + }, + }, + }, + }, + st: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{MagicDNSSuffix: "test.ts.net"}, + }, + dnsName: "svc:bar", + srvType: serveTypeTCP, + srvPort: 80, + mount: "/", + expectErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &serveEnv{} + if tt.setServeEnv { + e = tt.serveEnv + } + err := e.unsetServe(tt.cfg, tt.dnsName, tt.srvType, tt.srvPort, tt.mount, tt.st.CurrentTailnet.MagicDNSSuffix) + if err != nil && !tt.expectErr { + t.Fatalf("got error: %v; did not expect error.", err) + } + if err == nil && tt.expectErr { + t.Fatalf("got no error; expected error.") + } + if !tt.expectErr && !reflect.DeepEqual(tt.cfg, tt.expected) { + t.Fatalf("got: %v; expected: %v", tt.cfg, tt.expected) + } + }) + } +} + // exactErrMsg returns an error checker that wants exactly the provided want error. // If optName is non-empty, it's used in the error message. func exactErrMsg(want error) func(error) string { @@ -1287,3 +2557,8 @@ func exactErrMsg(want error) func(error) string { return fmt.Sprintf("\ngot: %v\nwant: %v\n", got, want) } } + +func ptrToReadOnlySlice[T any](s []T) *views.Slice[T] { + vs := views.SliceOf(s) + return &vs +} diff --git a/cmd/tailscale/cli/set.go b/cmd/tailscale/cli/set.go index 2e1251f04..cb3a07a6f 100644 --- a/cmd/tailscale/cli/set.go +++ b/cmd/tailscale/cli/set.go @@ -10,17 +10,20 @@ import ( "fmt" "net/netip" "os/exec" + "runtime" + "strconv" "strings" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/client/web" - "tailscale.com/clientupdate" "tailscale.com/cmd/tailscale/cli/ffcomplete" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn" "tailscale.com/net/netutil" "tailscale.com/net/tsaddr" "tailscale.com/safesocket" + "tailscale.com/tsconst" "tailscale.com/types/opt" + "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/version" ) @@ -57,19 +60,21 @@ type setArgsT struct { forceDaemon bool updateCheck bool updateApply bool - postureChecking bool + reportPosture bool snat bool statefulFiltering bool + sync bool netfilterMode string + relayServerPort string } func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { setf := newFlagSet("set") setf.StringVar(&setArgs.profileName, "nickname", "", "nickname for the current account") - setf.BoolVar(&setArgs.acceptRoutes, "accept-routes", false, "accept routes advertised by other Tailscale nodes") - setf.BoolVar(&setArgs.acceptDNS, "accept-dns", false, "accept DNS configuration from the admin panel") - setf.StringVar(&setArgs.exitNodeIP, "exit-node", "", "Tailscale exit node (IP or base name) for internet traffic, or empty string to not use an exit node") + setf.BoolVar(&setArgs.acceptRoutes, "accept-routes", acceptRouteDefault(goos), "accept routes advertised by other Tailscale nodes") + setf.BoolVar(&setArgs.acceptDNS, "accept-dns", true, "accept DNS configuration from the admin panel") + setf.StringVar(&setArgs.exitNodeIP, "exit-node", "", "Tailscale exit node (IP, base name, or auto:any) for internet traffic, or empty string to not use an exit node") setf.BoolVar(&setArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node") setf.BoolVar(&setArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") setf.BoolVar(&setArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy") @@ -79,8 +84,10 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { setf.BoolVar(&setArgs.advertiseConnector, "advertise-connector", false, "offer to be an app connector for domain specific internet traffic for the tailnet") setf.BoolVar(&setArgs.updateCheck, "update-check", true, "notify about available Tailscale updates") setf.BoolVar(&setArgs.updateApply, "auto-update", false, "automatically update to the latest available version") - setf.BoolVar(&setArgs.postureChecking, "posture-checking", false, hidden+"allow management plane to gather device posture information") + setf.BoolVar(&setArgs.reportPosture, "report-posture", false, "allow management plane to gather device posture information") setf.BoolVar(&setArgs.runWebClient, "webclient", false, "expose the web interface for managing this node over Tailscale at port 5252") + setf.BoolVar(&setArgs.sync, "sync", false, hidden+"actively sync configuration from the control plane (set to false only for network failure testing)") + setf.StringVar(&setArgs.relayServerPort, "relay-server-port", "", "UDP port number (0 will pick a random unused port) for the relay server to bind to, on all interfaces, or empty string to disable relay server functionality") ffcomplete.Flag(setf, "exit-node", func(args []string) ([]string, ffcomplete.ShellCompDirective, error) { st, err := localClient.Status(context.Background()) @@ -103,7 +110,7 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { switch goos { case "linux": setf.BoolVar(&setArgs.snat, "snat-subnet-routes", true, "source NAT traffic to local routes advertised with --advertise-routes") - setf.BoolVar(&setArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, etc.)") + setf.BoolVar(&setArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, and so on)") setf.StringVar(&setArgs.netfilterMode, "netfilter-mode", defaultNetfilterMode(), "netfilter mode (one of on, nodivert, off)") case "windows": setf.BoolVar(&setArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") @@ -144,6 +151,7 @@ func runSet(ctx context.Context, args []string) (retErr error) { OperatorUser: setArgs.opUser, NoSNAT: !setArgs.snat, ForceDaemon: setArgs.forceDaemon, + Sync: opt.NewBool(setArgs.sync), AutoUpdate: ipn.AutoUpdatePrefs{ Check: setArgs.updateCheck, Apply: opt.NewBool(setArgs.updateApply), @@ -151,7 +159,7 @@ func runSet(ctx context.Context, args []string) (retErr error) { AppConnector: ipn.AppConnectorPrefs{ Advertise: setArgs.advertiseConnector, }, - PostureChecking: setArgs.postureChecking, + PostureChecking: setArgs.reportPosture, NoStatefulFiltering: opt.NewBool(!setArgs.statefulFiltering), }, } @@ -168,7 +176,10 @@ func runSet(ctx context.Context, args []string) (retErr error) { } if setArgs.exitNodeIP != "" { - if err := maskedPrefs.Prefs.SetExitNodeIP(setArgs.exitNodeIP, st); err != nil { + if expr, useAutoExitNode := ipn.ParseAutoExitNodeString(setArgs.exitNodeIP); useAutoExitNode { + maskedPrefs.AutoExitNode = expr + maskedPrefs.AutoExitNodeSet = true + } else if err := maskedPrefs.Prefs.SetExitNodeIP(setArgs.exitNodeIP, st); err != nil { var e ipn.ExitNodeLocalIPError if errors.As(err, &e) { return fmt.Errorf("%w; did you mean --advertise-exit-node?", err) @@ -177,7 +188,10 @@ func runSet(ctx context.Context, args []string) (retErr error) { } } - warnOnAdvertiseRouts(ctx, &maskedPrefs.Prefs) + warnOnAdvertiseRoutes(ctx, &maskedPrefs.Prefs) + if err := checkExitNodeRisk(ctx, &maskedPrefs.Prefs, setArgs.acceptedRisks); err != nil { + return err + } var advertiseExitNodeSet, advertiseRoutesSet bool setFlagSet.Visit(func(f *flag.Flag) { updateMaskedPrefsFromUpOrSetFlag(maskedPrefs, f.Name) @@ -203,29 +217,37 @@ func runSet(ctx context.Context, args []string) (retErr error) { } } + if runtime.GOOS == "darwin" && maskedPrefs.AppConnector.Advertise { + if err := presentRiskToUser(riskMacAppConnector, riskMacAppConnectorMessage, setArgs.acceptedRisks); err != nil { + return err + } + } + if maskedPrefs.RunSSHSet { wantSSH, haveSSH := maskedPrefs.RunSSH, curPrefs.RunSSH if err := presentSSHToggleRisk(wantSSH, haveSSH, setArgs.acceptedRisks); err != nil { return err } } - if maskedPrefs.AutoUpdateSet.ApplySet { - if !clientupdate.CanAutoUpdate() { - return errors.New("automatic updates are not supported on this platform") + if maskedPrefs.AutoUpdateSet.ApplySet && buildfeatures.HasClientUpdate && version.IsMacSysExt() { + apply := "0" + if maskedPrefs.AutoUpdate.Apply.EqualBool(true) { + apply = "1" } - // On macsys, tailscaled will set the Sparkle auto-update setting. It - // does not use clientupdate. - if version.IsMacSysExt() { - apply := "0" - if maskedPrefs.AutoUpdate.Apply.EqualBool(true) { - apply = "1" - } - out, err := exec.Command("defaults", "write", "io.tailscale.ipn.macsys", "SUAutomaticallyUpdate", apply).CombinedOutput() - if err != nil { - return fmt.Errorf("failed to enable automatic updates: %v, %q", err, out) - } + out, err := exec.Command("defaults", "write", "io.tailscale.ipn.macsys", "SUAutomaticallyUpdate", apply).CombinedOutput() + if err != nil { + return fmt.Errorf("failed to enable automatic updates: %v, %q", err, out) } } + + if setArgs.relayServerPort != "" { + uport, err := strconv.ParseUint(setArgs.relayServerPort, 10, 16) + if err != nil { + return fmt.Errorf("failed to set relay server port: %v", err) + } + maskedPrefs.Prefs.RelayServerPort = ptr.To(int(uport)) + } + checkPrefs := curPrefs.Clone() checkPrefs.ApplyEdits(maskedPrefs) if err := localClient.CheckPrefs(ctx, checkPrefs); err != nil { @@ -238,7 +260,7 @@ func runSet(ctx context.Context, args []string) (retErr error) { } if setArgs.runWebClient && len(st.TailscaleIPs) > 0 { - printf("\nWeb interface now running at %s:%d", st.TailscaleIPs[0], web.ListenPort) + printf("\nWeb interface now running at %s:%d\n", st.TailscaleIPs[0], tsconst.WebListenPort) } return nil diff --git a/cmd/tailscale/cli/set_test.go b/cmd/tailscale/cli/set_test.go index 15305c3ce..a2f211f8c 100644 --- a/cmd/tailscale/cli/set_test.go +++ b/cmd/tailscale/cli/set_test.go @@ -4,6 +4,7 @@ package cli import ( + "flag" "net/netip" "reflect" "testing" @@ -129,3 +130,24 @@ func TestCalcAdvertiseRoutesForSet(t *testing.T) { }) } } + +// TestSetDefaultsMatchUpDefaults is meant to ensure that the default values +// for `tailscale set` and `tailscale up` are the same. +// Since `tailscale set` only sets preferences that are explicitly mentioned, +// the default values for its flags are only used for `--help` documentation. +func TestSetDefaultsMatchUpDefaults(t *testing.T) { + upFlagSet.VisitAll(func(up *flag.Flag) { + if preflessFlag(up.Name) { + return + } + + set := setFlagSet.Lookup(up.Name) + if set == nil { + return + } + + if set.DefValue != up.DefValue { + t.Errorf("--%s: set defaults to %q, but up defaults to %q", up.Name, set.DefValue, up.DefValue) + } + }) +} diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index 68a6193af..9275c9a1c 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -70,12 +70,28 @@ func runSSH(ctx context.Context, args []string) error { return err } + prefs, err := localClient.GetPrefs(ctx) + if err != nil { + return err + } + // hostForSSH is the hostname we'll tell OpenSSH we're // connecting to, so we have to maintain fewer entries in the // known_hosts files. hostForSSH := host - if v, ok := nodeDNSNameFromArg(st, host); ok { - hostForSSH = v + ps, ok := peerStatusFromArg(st, host) + if ok { + hostForSSH = ps.DNSName + + // If MagicDNS isn't enabled on the client, + // we will use the first IPv4 we know about + // or fallback to the first IPv6 address + if !prefs.CorpDNS { + ipHost, found := ipFromPeerStatus(ps) + if found { + hostForSSH = ipHost + } + } } ssh, err := findSSH() @@ -84,10 +100,6 @@ func runSSH(ctx context.Context, args []string) error { // of failing. But for now: return fmt.Errorf("no system 'ssh' command found: %w", err) } - tailscaleBin, err := os.Executable() - if err != nil { - return err - } knownHostsFile, err := writeKnownHosts(st) if err != nil { return err @@ -116,7 +128,9 @@ func runSSH(ctx context.Context, args []string) error { argv = append(argv, "-o", fmt.Sprintf("ProxyCommand %q %s nc %%h %%p", - tailscaleBin, + // os.Executable() would return the real running binary but in case tailscale is built with the ts_include_cli tag, + // we need to return the started symlink instead + os.Args[0], socketArg, )) } @@ -171,11 +185,40 @@ func genKnownHosts(st *ipnstate.Status) []byte { continue } fmt.Fprintf(&buf, "%s %s\n", ps.DNSName, hostKey) + for _, ip := range ps.TailscaleIPs { + fmt.Fprintf(&buf, "%s %s\n", ip.String(), hostKey) + } } } return buf.Bytes() } +// peerStatusFromArg returns the PeerStatus that matches +// the input arg which can be a base name, full DNS name, or an IP. +func peerStatusFromArg(st *ipnstate.Status, arg string) (*ipnstate.PeerStatus, bool) { + if arg == "" { + return nil, false + } + argIP, _ := netip.ParseAddr(arg) + for _, ps := range st.Peer { + if argIP.IsValid() { + for _, ip := range ps.TailscaleIPs { + if ip == argIP { + return ps, true + } + } + continue + } + if strings.EqualFold(strings.TrimSuffix(arg, "."), strings.TrimSuffix(ps.DNSName, ".")) { + return ps, true + } + if base, _, ok := strings.Cut(ps.DNSName, "."); ok && strings.EqualFold(base, arg) { + return ps, true + } + } + return nil, false +} + // nodeDNSNameFromArg returns the PeerStatus.DNSName value from a peer // in st that matches the input arg which can be a base name, full // DNS name, or an IP. @@ -204,6 +247,20 @@ func nodeDNSNameFromArg(st *ipnstate.Status, arg string) (dnsName string, ok boo return "", false } +func ipFromPeerStatus(ps *ipnstate.PeerStatus) (string, bool) { + if len(ps.TailscaleIPs) < 1 { + return "", false + } + + // Look for a IPv4 address or default to the first IP of the list + for _, ip := range ps.TailscaleIPs { + if ip.Is4() { + return ip.String(), true + } + } + return ps.TailscaleIPs[0].String(), true +} + // getSSHClientEnvVar returns the "SSH_CLIENT" environment variable // for the current process group, if any. var getSSHClientEnvVar = func() string { diff --git a/cmd/tailscale/cli/status.go b/cmd/tailscale/cli/status.go index e4dccc247..89b18335b 100644 --- a/cmd/tailscale/cli/status.go +++ b/cmd/tailscale/cli/status.go @@ -4,7 +4,6 @@ package cli import ( - "bytes" "cmp" "context" "encoding/json" @@ -15,12 +14,13 @@ import ( "net/http" "net/netip" "os" - "strconv" "strings" + "text/tabwriter" "github.com/peterbourgon/ff/v3/ffcli" "github.com/toqueteos/webbrowser" "golang.org/x/net/idna" + "tailscale.com/feature" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/netmon" @@ -56,6 +56,7 @@ https://github.com/tailscale/tailscale/blob/main/ipn/ipnstate/ipnstate.go fs.BoolVar(&statusArgs.peers, "peers", true, "show status of peers") fs.StringVar(&statusArgs.listen, "listen", "127.0.0.1:8384", "listen address for web mode; use port 0 for automatic") fs.BoolVar(&statusArgs.browser, "browser", true, "Open a browser in web mode") + fs.BoolVar(&statusArgs.header, "header", false, "show column headers in table format") return fs })(), } @@ -68,8 +69,11 @@ var statusArgs struct { active bool // in CLI mode, filter output to only peers with active sessions self bool // in CLI mode, show status of local machine peers bool // in CLI mode, show status of peer machines + header bool // in CLI mode, show column headers in table format } +const mullvadTCD = "mullvad.ts.net." + func runStatus(ctx context.Context, args []string) error { if len(args) > 0 { return errors.New("unexpected non-flag arguments to 'tailscale status'") @@ -149,10 +153,15 @@ func runStatus(ctx context.Context, args []string) error { os.Exit(1) } - var buf bytes.Buffer - f := func(format string, a ...any) { fmt.Fprintf(&buf, format, a...) } + w := tabwriter.NewWriter(Stdout, 0, 0, 2, ' ', 0) + f := func(format string, a ...any) { fmt.Fprintf(w, format, a...) } + if statusArgs.header { + fmt.Fprintln(w, "IP\tHostname\tOwner\tOS\tStatus\t") + fmt.Fprintln(w, "--\t--------\t-----\t--\t------\t") + } + printPS := func(ps *ipnstate.PeerStatus) { - f("%-15s %-20s %-12s %-7s ", + f("%s\t%s\t%s\t%s\t", firstIPString(ps.TailscaleIPs), dnsOrQuoteHostname(st, ps), ownerLogin(st, ps), @@ -162,7 +171,7 @@ func runStatus(ctx context.Context, args []string) error { anyTraffic := ps.TxBytes != 0 || ps.RxBytes != 0 var offline string if !ps.Online { - offline = "; offline" + offline = "; offline" + lastSeenFmt(ps.LastSeen) } if !ps.Active { if ps.ExitNode { @@ -172,7 +181,7 @@ func runStatus(ctx context.Context, args []string) error { } else if anyTraffic { f("idle" + offline) } else if !ps.Online { - f("offline") + f("offline" + lastSeenFmt(ps.LastSeen)) } else { f("-") } @@ -183,19 +192,21 @@ func runStatus(ctx context.Context, args []string) error { } else if ps.ExitNodeOption { f("offers exit node; ") } - if relay != "" && ps.CurAddr == "" { + if relay != "" && ps.CurAddr == "" && ps.PeerRelay == "" { f("relay %q", relay) } else if ps.CurAddr != "" { f("direct %s", ps.CurAddr) + } else if ps.PeerRelay != "" { + f("peer-relay %s", ps.PeerRelay) } if !ps.Online { - f("; offline") + f(offline) } } if anyTraffic { f(", tx %d rx %d", ps.TxBytes, ps.RxBytes) } - f("\n") + f("\t\n") } if statusArgs.self && st.Self != nil { @@ -210,9 +221,8 @@ func runStatus(ctx context.Context, args []string) error { if ps.ShareeNode { continue } - if ps.Location != nil && ps.ExitNodeOption && !ps.ExitNode { - // Location based exit nodes are only shown with the - // `exit-node list` command. + if ps.ExitNodeOption && !ps.ExitNode && strings.HasSuffix(ps.DNSName, mullvadTCD) { + // Mullvad exit nodes are only shown with the `exit-node list` command. locBasedExitNode = true continue } @@ -226,7 +236,8 @@ func runStatus(ctx context.Context, args []string) error { printPS(ps) } } - Stdout.Write(buf.Bytes()) + w.Flush() + if locBasedExitNode { outln() printf("# To see the full list of exit nodes, including location-based exit nodes, run `tailscale exit-node list` \n") @@ -235,44 +246,13 @@ func runStatus(ctx context.Context, args []string) error { outln() printHealth() } - printFunnelStatus(ctx) + if f, ok := hookPrintFunnelStatus.GetOk(); ok { + f(ctx) + } return nil } -// printFunnelStatus prints the status of the funnel, if it's running. -// It prints nothing if the funnel is not running. -func printFunnelStatus(ctx context.Context) { - sc, err := localClient.GetServeConfig(ctx) - if err != nil { - outln() - printf("# Funnel:\n") - printf("# - Unable to get Funnel status: %v\n", err) - return - } - if !sc.IsFunnelOn() { - return - } - outln() - printf("# Funnel on:\n") - for hp, on := range sc.AllowFunnel { - if !on { // if present, should be on - continue - } - sni, portStr, _ := net.SplitHostPort(string(hp)) - p, _ := strconv.ParseUint(portStr, 10, 16) - isTCP := sc.IsTCPForwardingOnPort(uint16(p)) - url := "https://" - if isTCP { - url = "tcp://" - } - url += sni - if isTCP || p != 443 { - url += ":" + portStr - } - printf("# - %s\n", url) - } - outln() -} +var hookPrintFunnelStatus feature.Hook[func(context.Context)] // isRunningOrStarting reports whether st is in state Running or Starting. // It also returns a description of the status suitable to display to a user. diff --git a/cmd/tailscale/cli/switch.go b/cmd/tailscale/cli/switch.go index 731492daa..b315a21e7 100644 --- a/cmd/tailscale/cli/switch.go +++ b/cmd/tailscale/cli/switch.go @@ -20,11 +20,11 @@ import ( var switchCmd = &ffcli.Command{ Name: "switch", ShortUsage: "tailscale switch ", - ShortHelp: "Switches to a different Tailscale account", + ShortHelp: "Switch to a different Tailscale account", LongHelp: `"tailscale switch" switches between logged in accounts. You can use the ID that's returned from 'tailnet switch -list' to pick which profile you want to switch to. Alternatively, you -can use the Tailnet or the account names to switch as well. +can use the Tailnet, account names, or display names to switch as well. This command is currently in alpha and may change in the future.`, @@ -34,6 +34,22 @@ This command is currently in alpha and may change in the future.`, return fs }(), Exec: switchProfile, + + // Add remove subcommand + Subcommands: []*ffcli.Command{ + { + Name: "remove", + ShortUsage: "tailscale switch remove ", + ShortHelp: "Remove a Tailscale account", + LongHelp: `"tailscale switch remove" removes a Tailscale account from the +local machine. This does not delete the account itself, but +it will no longer be available for switching to. You can +add it back by logging in again. + +This command is currently in alpha and may change in the future.`, + Exec: removeProfile, + }, + }, } func init() { @@ -46,7 +62,7 @@ func init() { seen := make(map[string]bool, 3*len(all)) wordfns := []func(prof ipn.LoginProfile) string{ func(prof ipn.LoginProfile) string { return string(prof.ID) }, - func(prof ipn.LoginProfile) string { return prof.NetworkProfile.DomainName }, + func(prof ipn.LoginProfile) string { return prof.NetworkProfile.DisplayNameOrDefault() }, func(prof ipn.LoginProfile) string { return prof.Name }, } @@ -57,7 +73,7 @@ func init() { continue } seen[word] = true - words = append(words, fmt.Sprintf("%s\tid: %s, tailnet: %s, account: %s", word, prof.ID, prof.NetworkProfile.DomainName, prof.Name)) + words = append(words, fmt.Sprintf("%s\tid: %s, tailnet: %s, account: %s", word, prof.ID, prof.NetworkProfile.DisplayNameOrDefault(), prof.Name)) } } return words, ffcomplete.ShellCompDirectiveNoFileComp, nil @@ -86,7 +102,7 @@ func listProfiles(ctx context.Context) error { } printRow( string(prof.ID), - prof.NetworkProfile.DomainName, + prof.NetworkProfile.DisplayNameOrDefault(), name, ) } @@ -106,32 +122,8 @@ func switchProfile(ctx context.Context, args []string) error { errf("Failed to switch to account: %v\n", err) os.Exit(1) } - var profID ipn.ProfileID - // Allow matching by ID, Tailnet, or Account - // in that order. - for _, p := range all { - if p.ID == ipn.ProfileID(args[0]) { - profID = p.ID - break - } - } - if profID == "" { - for _, p := range all { - if p.NetworkProfile.DomainName == args[0] { - profID = p.ID - break - } - } - } - if profID == "" { - for _, p := range all { - if p.Name == args[0] { - profID = p.ID - break - } - } - } - if profID == "" { + profID, ok := matchProfile(args[0], all) + if !ok { errf("No profile named %q\n", args[0]) os.Exit(1) } @@ -178,3 +170,54 @@ func switchProfile(ctx context.Context, args []string) error { } } } + +func removeProfile(ctx context.Context, args []string) error { + if len(args) != 1 { + outln("usage: tailscale switch remove NAME") + os.Exit(1) + } + cp, all, err := localClient.ProfileStatus(ctx) + if err != nil { + errf("Failed to remove account: %v\n", err) + os.Exit(1) + } + + profID, ok := matchProfile(args[0], all) + if !ok { + errf("No profile named %q\n", args[0]) + os.Exit(1) + } + + if profID == cp.ID { + printf("Already on account %q\n", args[0]) + os.Exit(0) + } + + return localClient.DeleteProfile(ctx, profID) +} + +func matchProfile(arg string, all []ipn.LoginProfile) (ipn.ProfileID, bool) { + // Allow matching by ID, Tailnet, Account, or Display Name + // in that order. + for _, p := range all { + if p.ID == ipn.ProfileID(arg) { + return p.ID, true + } + } + for _, p := range all { + if p.NetworkProfile.DomainName == arg { + return p.ID, true + } + } + for _, p := range all { + if p.Name == arg { + return p.ID, true + } + } + for _, p := range all { + if p.NetworkProfile.DisplayName == arg { + return p.ID, true + } + } + return "", false +} diff --git a/cmd/tailscale/cli/syspolicy.go b/cmd/tailscale/cli/syspolicy.go new file mode 100644 index 000000000..97f3f2122 --- /dev/null +++ b/cmd/tailscale/cli/syspolicy.go @@ -0,0 +1,115 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_syspolicy + +package cli + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "os" + "slices" + "text/tabwriter" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/util/syspolicy/setting" +) + +var syspolicyArgs struct { + json bool // JSON output mode +} + +func init() { + sysPolicyCmd = func() *ffcli.Command { + return &ffcli.Command{ + Name: "syspolicy", + ShortHelp: "Diagnose the MDM and system policy configuration", + LongHelp: "The 'tailscale syspolicy' command provides tools for diagnosing the MDM and system policy configuration.", + ShortUsage: "tailscale syspolicy ", + UsageFunc: usageFuncNoDefaultValues, + Subcommands: []*ffcli.Command{ + { + Name: "list", + ShortUsage: "tailscale syspolicy list", + Exec: runSysPolicyList, + ShortHelp: "Print effective policy settings", + LongHelp: "The 'tailscale syspolicy list' subcommand displays the effective policy settings and their sources (e.g., MDM or environment variables).", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("syspolicy list") + fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format") + return fs + })(), + }, + { + Name: "reload", + ShortUsage: "tailscale syspolicy reload", + Exec: runSysPolicyReload, + ShortHelp: "Force a reload of policy settings, even if no changes are detected, and prints the result", + LongHelp: "The 'tailscale syspolicy reload' subcommand forces a reload of policy settings, even if no changes are detected, and prints the result.", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("syspolicy reload") + fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format") + return fs + })(), + }, + }, + } + } +} + +func runSysPolicyList(ctx context.Context, args []string) error { + policy, err := localClient.GetEffectivePolicy(ctx, setting.DefaultScope()) + if err != nil { + return err + } + printPolicySettings(policy) + return nil +} + +func runSysPolicyReload(ctx context.Context, args []string) error { + policy, err := localClient.ReloadEffectivePolicy(ctx, setting.DefaultScope()) + if err != nil { + return err + } + printPolicySettings(policy) + return nil +} + +func printPolicySettings(policy *setting.Snapshot) { + if syspolicyArgs.json { + json, err := json.MarshalIndent(policy, "", "\t") + if err != nil { + errf("syspolicy marshalling error: %v", err) + } else { + outln(string(json)) + } + return + } + if policy.Len() == 0 { + outln("No policy settings") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "Name\tOrigin\tValue\tError") + fmt.Fprintln(w, "----\t------\t-----\t-----") + for _, k := range slices.Sorted(policy.Keys()) { + setting, _ := policy.GetSetting(k) + var origin string + if o := setting.Origin(); o != nil { + origin = o.String() + } + if err := setting.Error(); err != nil { + fmt.Fprintf(w, "%s\t%s\t\t{%v}\n", k, origin, err) + } else { + fmt.Fprintf(w, "%s\t%s\t%v\t\n", k, origin, setting.Value()) + } + } + w.Flush() + + fmt.Println() + return +} diff --git a/cmd/tailscale/cli/systray.go b/cmd/tailscale/cli/systray.go new file mode 100644 index 000000000..827e8a9a4 --- /dev/null +++ b/cmd/tailscale/cli/systray.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_systray + +package cli + +import ( + "context" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/systray" +) + +var systrayCmd = &ffcli.Command{ + Name: "systray", + ShortUsage: "tailscale systray", + ShortHelp: "Run a systray application to manage Tailscale", + LongHelp: "Run a systray application to manage Tailscale.", + Exec: runSystray, +} + +func runSystray(ctx context.Context, _ []string) error { + new(systray.Menu).Run(&localClient) + return nil +} diff --git a/cmd/tailscale/cli/systray_omit.go b/cmd/tailscale/cli/systray_omit.go new file mode 100644 index 000000000..8d93fd84b --- /dev/null +++ b/cmd/tailscale/cli/systray_omit.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux || ts_omit_systray + +package cli + +import ( + "context" + "fmt" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" +) + +// TODO(will): update URL to KB article when available +var systrayHelp = strings.TrimSpace(` +The Tailscale systray app is not included in this client build. +To run it manually, see https://github.com/tailscale/tailscale/tree/main/cmd/systray +`) + +var systrayCmd = &ffcli.Command{ + Name: "systray", + ShortUsage: "tailscale systray", + ShortHelp: "Not available in this client build", + LongHelp: hidden + systrayHelp, + Exec: func(_ context.Context, _ []string) error { + fmt.Println(systrayHelp) + return nil + }, +} diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index e1b828105..7f5b2e6b4 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -12,13 +12,11 @@ import ( "fmt" "log" "net/netip" - "net/url" "os" "os/signal" "reflect" "runtime" "sort" - "strconv" "strings" "syscall" "time" @@ -26,9 +24,11 @@ import ( shellquote "github.com/kballard/go-shellquote" "github.com/peterbourgon/ff/v3/ffcli" qrcode "github.com/skip2/go-qrcode" - "golang.org/x/oauth2/clientcredentials" - "tailscale.com/client/tailscale" + "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister/identityfederation" + _ "tailscale.com/feature/condregister/oauthkey" "tailscale.com/health/healthmsg" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/netutil" @@ -39,7 +39,7 @@ import ( "tailscale.com/types/preftype" "tailscale.com/types/views" "tailscale.com/util/dnsname" - "tailscale.com/version" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/version/distro" ) @@ -79,14 +79,8 @@ func effectiveGOOS() string { // acceptRouteDefault returns the CLI's default value of --accept-routes as // a function of the platform it's running on. func acceptRouteDefault(goos string) bool { - switch goos { - case "windows": - return true - case "darwin": - return version.IsSandboxedMacOS() - default: - return false - } + var p *ipn.Prefs + return p.DefaultRouteAll(goos) } var upFlagSet = newUpFlagSet(effectiveGOOS(), &upArgsGlobal, "up") @@ -101,13 +95,17 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { // When adding new flags, prefer to put them under "tailscale set" instead // of here. Setting preferences via "tailscale up" is deprecated. upf.BoolVar(&upArgs.qr, "qr", false, "show QR code for login URLs") + upf.StringVar(&upArgs.qrFormat, "qr-format", "small", "QR code formatting (small or large)") upf.StringVar(&upArgs.authKeyOrFile, "auth-key", "", `node authorization key; if it begins with "file:", then it's a path to a file containing the authkey`) + upf.StringVar(&upArgs.clientID, "client-id", "", "Client ID used to generate authkeys via workload identity federation") + upf.StringVar(&upArgs.clientSecretOrFile, "client-secret", "", `Client Secret used to generate authkeys via OAuth; if it begins with "file:", then it's a path to a file containing the secret`) + upf.StringVar(&upArgs.idTokenOrFile, "id-token", "", `ID token from the identity provider to exchange with the control server for workload identity federation; if it begins with "file:", then it's a path to a file containing the token`) upf.StringVar(&upArgs.server, "login-server", ipn.DefaultControlURL, "base URL of control server") upf.BoolVar(&upArgs.acceptRoutes, "accept-routes", acceptRouteDefault(goos), "accept routes advertised by other Tailscale nodes") upf.BoolVar(&upArgs.acceptDNS, "accept-dns", true, "accept DNS configuration from the admin panel") upf.Var(notFalseVar{}, "host-routes", hidden+"install host routes to other Tailscale nodes (must be true as of Tailscale 1.67+)") - upf.StringVar(&upArgs.exitNodeIP, "exit-node", "", "Tailscale exit node (IP or base name) for internet traffic, or empty string to not use an exit node") + upf.StringVar(&upArgs.exitNodeIP, "exit-node", "", "Tailscale exit node (IP, base name, or auto:any) for internet traffic, or empty string to not use an exit node") upf.BoolVar(&upArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node") upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") upf.BoolVar(&upArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy") @@ -116,6 +114,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes") upf.BoolVar(&upArgs.advertiseConnector, "advertise-connector", false, "advertise this node as an app connector") upf.BoolVar(&upArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet") + upf.BoolVar(&upArgs.postureChecking, "report-posture", false, hidden+"allow management plane to gather device posture information") if safesocket.GOOSUsesPeerCreds(goos) { upf.StringVar(&upArgs.opUser, "operator", "", "Unix username to allow to operate on tailscaled without sudo") @@ -123,7 +122,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { switch goos { case "linux": upf.BoolVar(&upArgs.snat, "snat-subnet-routes", true, "source NAT traffic to local routes advertised with --advertise-routes") - upf.BoolVar(&upArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, etc.)") + upf.BoolVar(&upArgs.statefulFiltering, "stateful-filtering", false, "apply stateful filtering to forwarded packets (subnet routers, exit nodes, and so on)") upf.StringVar(&upArgs.netfilterMode, "netfilter-mode", defaultNetfilterMode(), "netfilter mode (one of on, nodivert, off)") case "windows": upf.BoolVar(&upArgs.forceDaemon, "unattended", false, "run in \"Unattended Mode\" where Tailscale keeps running even after the current GUI user logs out (Windows-only)") @@ -138,7 +137,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { // Some flags are only for "up", not "login". upf.BoolVar(&upArgs.json, "json", false, "output in JSON format (WARNING: format subject to change)") upf.BoolVar(&upArgs.reset, "reset", false, "reset unspecified settings to their default values") - upf.BoolVar(&upArgs.forceReauth, "force-reauth", false, "force reauthentication") + upf.BoolVar(&upArgs.forceReauth, "force-reauth", false, "force reauthentication (WARNING: this will bring down the Tailscale connection and thus should not be done remotely over SSH or RDP)") registerAcceptRiskFlag(upf, &upArgs.acceptedRisks) } @@ -164,8 +163,12 @@ func defaultNetfilterMode() string { return "on" } +// upArgsT is the type of upArgs, the argument struct for `tailscale up`. +// As of 2024-10-08, upArgsT is frozen and no new arguments should be +// added to it. Add new arguments to setArgsT instead. type upArgsT struct { qr bool + qrFormat string reset bool server string acceptRoutes bool @@ -185,16 +188,21 @@ type upArgsT struct { statefulFiltering bool netfilterMode string authKeyOrFile string // "secret" or "file:/path/to/secret" + clientID string + clientSecretOrFile string // "secret" or "file:/path/to/secret" + idTokenOrFile string // "secret" or "file:/path/to/secret" hostname string opUser string json bool timeout time.Duration acceptedRisks string profileName string + postureChecking bool } -func (a upArgsT) getAuthKey() (string, error) { - v := a.authKeyOrFile +// resolveValueFromFile returns the value as-is, or if it starts with "file:", +// reads and returns the trimmed contents of the file. +func resolveValueFromFile(v string) (string, error) { if file, ok := strings.CutPrefix(v, "file:"); ok { b, err := os.ReadFile(file) if err != nil { @@ -205,6 +213,18 @@ func (a upArgsT) getAuthKey() (string, error) { return v, nil } +func (a upArgsT) getAuthKey() (string, error) { + return resolveValueFromFile(a.authKeyOrFile) +} + +func (a upArgsT) getClientSecret() (string, error) { + return resolveValueFromFile(a.clientSecretOrFile) +} + +func (a upArgsT) getIDToken() (string, error) { + return resolveValueFromFile(a.idTokenOrFile) +} + var upArgsGlobal upArgsT // Fields output when `tailscale up --json` is used. Two JSON blocks will be output. @@ -280,7 +300,9 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo prefs.NetfilterMode = preftype.NetfilterOff } if upArgs.exitNodeIP != "" { - if err := prefs.SetExitNodeIP(upArgs.exitNodeIP, st); err != nil { + if expr, useAutoExitNode := ipn.ParseAutoExitNodeString(upArgs.exitNodeIP); useAutoExitNode { + prefs.AutoExitNode = expr + } else if err := prefs.SetExitNodeIP(upArgs.exitNodeIP, st); err != nil { var e ipn.ExitNodeLocalIPError if errors.As(err, &e) { return nil, fmt.Errorf("%w; did you mean --advertise-exit-node?", err) @@ -301,6 +323,7 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo prefs.OperatorUser = upArgs.opUser prefs.ProfileName = upArgs.profileName prefs.AppConnector.Advertise = upArgs.advertiseConnector + prefs.PostureChecking = upArgs.postureChecking if goos == "linux" { prefs.NoSNAT = !upArgs.snat @@ -354,11 +377,19 @@ func netfilterModeFromFlag(v string) (_ preftype.NetfilterMode, warning string, // It returns simpleUp if we're running a simple "tailscale up" to // transition to running from a previously-logged-in but down state, // without changing any settings. +// +// Note this can also mutate prefs to add implicit preferences for the +// user operator. +// +// TODO(alexc): the name of this function is confusing, and perhaps a +// sign that it's doing too much. Consider refactoring this so it's just +// telling the caller what to do next, but not changing anything itself. func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, justEditMP *ipn.MaskedPrefs, err error) { if !env.upArgs.reset { applyImplicitPrefs(prefs, curPrefs, env) - if err := checkForAccidentalSettingReverts(prefs, curPrefs, env); err != nil { + simpleUp, err = checkForAccidentalSettingReverts(prefs, curPrefs, env) + if err != nil { return false, nil, err } } @@ -376,19 +407,20 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus return false, nil, err } + if env.goos == "darwin" && env.upArgs.advertiseConnector { + if err := presentRiskToUser(riskMacAppConnector, riskMacAppConnectorMessage, env.upArgs.acceptedRisks); err != nil { + return false, nil, err + } + } + if env.upArgs.forceReauth && isSSHOverTailscale() { - if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will result in your SSH session disconnecting.`, env.upArgs.acceptedRisks); err != nil { + if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action may result in your SSH session disconnecting.`, env.upArgs.acceptedRisks); err != nil { return false, nil, err } } tagsChanged := !reflect.DeepEqual(curPrefs.AdvertiseTags, prefs.AdvertiseTags) - simpleUp = env.flagSet.NFlag() == 0 && - curPrefs.Persist != nil && - curPrefs.Persist.UserProfile.LoginName != "" && - env.backendState != ipn.NeedsLogin.String() - justEdit := env.backendState == ipn.Running.String() && !env.upArgs.forceReauth && env.upArgs.authKeyOrFile == "" && @@ -403,6 +435,9 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus if env.upArgs.reset { visitFlags = env.flagSet.VisitAll } + if prefs.AutoExitNode.IsSet() { + justEditMP.AutoExitNodeSet = true + } visitFlags(func(f *flag.Flag) { updateMaskedPrefsFromUpOrSetFlag(justEditMP, f.Name) }) @@ -435,6 +470,7 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE return fixTailscaledConnectError(err) } origAuthURL := st.AuthURL + origNodeKey := st.Self.PublicKey // printAuthURL reports whether we should print out the // provided auth URL from an IPN notify. @@ -475,12 +511,17 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE fatalf("%s", err) } - warnOnAdvertiseRouts(ctx, prefs) + warnOnAdvertiseRoutes(ctx, prefs) + if err := checkExitNodeRisk(ctx, prefs, upArgs.acceptedRisks); err != nil { + return err + } curPrefs, err := localClient.GetPrefs(ctx) if err != nil { return err } + effectivePrefs := curPrefs + if cmd == "up" { // "tailscale up" should not be able to change the // profile name. @@ -526,8 +567,16 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE } }() - running := make(chan bool, 1) // gets value once in state ipn.Running - watchErr := make(chan error, 1) + // Start watching the IPN bus before we call Start() or StartLoginInteractive(), + // or we could miss IPN notifications. + // + // In particular, if we're doing a force-reauth, we could miss the + // notification with the auth URL we should print for the user. + watcher, err := localClient.WatchIPNBus(watchCtx, 0) + if err != nil { + return err + } + defer watcher.Close() // Special case: bare "tailscale up" means to just start // running, if there's ever been a login. @@ -550,10 +599,36 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE if err != nil { return err } - authKey, err = resolveAuthKey(ctx, authKey, upArgs.advertiseTags) - if err != nil { - return err + // Try to use an OAuth secret to generate an auth key if that functionality + // is available. + if f, ok := tailscale.HookResolveAuthKey.GetOk(); ok { + clientSecret := authKey // the authkey argument accepts client secrets, if both arguments are provided authkey has precedence + if clientSecret == "" { + clientSecret, err = upArgs.getClientSecret() + if err != nil { + return err + } + } + + authKey, err = f(ctx, clientSecret, strings.Split(upArgs.advertiseTags, ",")) + if err != nil { + return err + } + } + // Try to resolve the auth key via workload identity federation if that functionality + // is available and no auth key is yet determined. + if f, ok := tailscale.HookResolveAuthKeyViaWIF.GetOk(); ok && authKey == "" { + idToken, err := upArgs.getIDToken() + if err != nil { + return err + } + + authKey, err = f(ctx, prefs.ControlURL, upArgs.clientID, idToken, strings.Split(upArgs.advertiseTags, ",")) + if err != nil { + return err + } } + err = localClient.Start(ctx, ipn.Options{ AuthKey: authKey, UpdatePrefs: prefs, @@ -561,6 +636,7 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE if err != nil { return err } + effectivePrefs = prefs if upArgs.forceReauth || !st.HaveNodeKey { err := localClient.StartLoginInteractive(ctx) if err != nil { @@ -569,15 +645,32 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE } } - watcher, err := localClient.WatchIPNBus(watchCtx, ipn.NotifyInitialState) - if err != nil { - return err - } - defer watcher.Close() + upComplete := make(chan bool, 1) + watchErr := make(chan error, 1) go func() { var printed bool // whether we've yet printed anything to stdout or stderr - var lastURLPrinted string + lastURLPrinted := "" + + // If we're doing a force-reauth, we need to get two notifications: + // + // 1. IPN is running + // 2. The node key has changed + // + // These two notifications arrive separately, and trying to combine them + // has caused unexpected issues elsewhere in `tailscale up`. For now, we + // track them separately. + ipnIsRunning := false + waitingForKeyChange := upArgs.forceReauth + + // If we're doing a simple up (i.e. `tailscale up`, no flags) and + // the initial state is NeedsMachineAuth, then we never receive a + // state notification from ipn, so we print the device approval URL + // immediately. + if simpleUp && st.BackendState == ipn.NeedsMachineAuth.String() { + printed = true + printDeviceApprovalInfo(env.upArgs.json, effectivePrefs, &lastURLPrinted) + } for { n, err := watcher.Next() @@ -589,29 +682,30 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE msg := *n.ErrMessage fatalf("backend error: %v\n", msg) } + if s := n.State; s != nil && *s == ipn.NeedsMachineAuth { + printed = true + printDeviceApprovalInfo(env.upArgs.json, effectivePrefs, &lastURLPrinted) + } if s := n.State; s != nil { - switch *s { - case ipn.NeedsMachineAuth: - printed = true - if env.upArgs.json { - printUpDoneJSON(ipn.NeedsMachineAuth, "") - } else { - fmt.Fprintf(Stderr, "\nTo approve your machine, visit (as admin):\n\n\t%s\n\n", prefs.AdminPageURL()) - } - case ipn.Running: - // Done full authentication process - if env.upArgs.json { - printUpDoneJSON(ipn.Running, "") - } else if printed { - // Only need to print an update if we printed the "please click" message earlier. - fmt.Fprintf(Stderr, "Success.\n") - } - select { - case running <- true: - default: - } - cancelWatch() + ipnIsRunning = *s == ipn.Running + } + if n.NetMap != nil && n.NetMap.NodeKey != origNodeKey { + waitingForKeyChange = false + } + if ipnIsRunning && !waitingForKeyChange { + // Done full authentication process + if env.upArgs.json { + printUpDoneJSON(ipn.Running, "") + } else if printed { + // Only need to print an update if we printed the "please click" message earlier. + fmt.Fprintf(Stderr, "Success.\n") } + select { + case upComplete <- true: + default: + } + cancelWatch() + return } if url := n.BrowseToURL; url != nil { authURL := *url @@ -644,7 +738,14 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE if err != nil { log.Printf("QR code error: %v", err) } else { - fmt.Fprintf(Stderr, "%s\n", q.ToString(false)) + switch upArgs.qrFormat { + case "large": + fmt.Fprintf(Stderr, "%s\n", q.ToString(false)) + case "small": + fmt.Fprintf(Stderr, "%s\n", q.ToSmallString(false)) + default: + log.Printf("unknown QR code format: %q", upArgs.qrFormat) + } } } } @@ -666,18 +767,18 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE timeoutCh = timeoutTimer.C } select { - case <-running: + case <-upComplete: return nil case <-watchCtx.Done(): select { - case <-running: + case <-upComplete: return nil default: } return watchCtx.Err() case err := <-watchErr: select { - case <-running: + case <-upComplete: return nil default: } @@ -687,6 +788,21 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE } } +func printDeviceApprovalInfo(printJson bool, prefs *ipn.Prefs, lastURLPrinted *string) { + if printJson { + printUpDoneJSON(ipn.NeedsMachineAuth, "") + } else { + deviceApprovalURL := prefs.AdminPageURL(policyclient.Get()) + + if lastURLPrinted != nil && deviceApprovalURL == *lastURLPrinted { + return + } + + *lastURLPrinted = deviceApprovalURL + errf("\nTo approve your machine, visit (as admin):\n\n\t%s\n\n", deviceApprovalURL) + } +} + // upWorthWarning reports whether the health check message s is worth warning // about during "tailscale up". Many of the health checks are noisy or confusing // or very ephemeral and happen especially briefly at startup. @@ -698,6 +814,7 @@ func upWorthyWarning(s string) bool { strings.Contains(s, healthmsg.WarnAcceptRoutesOff) || strings.Contains(s, healthmsg.LockedOut) || strings.Contains(s, healthmsg.WarnExitNodeUsage) || + strings.Contains(s, healthmsg.InMemoryTailnetLockState) || strings.Contains(strings.ToLower(s), "update available: ") } @@ -767,7 +884,9 @@ func init() { addPrefFlagMapping("update-check", "AutoUpdate.Check") addPrefFlagMapping("auto-update", "AutoUpdate.Apply") addPrefFlagMapping("advertise-connector", "AppConnector") - addPrefFlagMapping("posture-checking", "PostureChecking") + addPrefFlagMapping("report-posture", "PostureChecking") + addPrefFlagMapping("relay-server-port", "RelayServerPort") + addPrefFlagMapping("sync", "Sync") } func addPrefFlagMapping(flagName string, prefNames ...string) { @@ -790,7 +909,7 @@ func addPrefFlagMapping(flagName string, prefNames ...string) { // correspond to an ipn.Pref. func preflessFlag(flagName string) bool { switch flagName { - case "auth-key", "force-reauth", "reset", "qr", "json", "timeout", "accept-risk", "host-routes": + case "auth-key", "force-reauth", "reset", "qr", "qr-format", "json", "timeout", "accept-risk", "host-routes", "client-id", "client-secret", "id-token": return true } return false @@ -803,7 +922,7 @@ func updateMaskedPrefsFromUpOrSetFlag(mp *ipn.MaskedPrefs, flagName string) { if prefs, ok := prefsOfFlag[flagName]; ok { for _, pref := range prefs { f := reflect.ValueOf(mp).Elem() - for _, name := range strings.Split(pref, ".") { + for name := range strings.SplitSeq(pref, ".") { f = f.FieldByName(name + "Set") } f.SetBool(true) @@ -845,10 +964,10 @@ type upCheckEnv struct { // // mp is the mask of settings actually set, where mp.Prefs is the new // preferences to set, including any values set from implicit flags. -func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheckEnv) error { +func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, err error) { if curPrefs.ControlURL == "" { // Don't validate things on initial "up" before a control URL has been set. - return nil + return false, nil } flagIsSet := map[string]bool{} @@ -856,10 +975,13 @@ func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheck flagIsSet[f.Name] = true }) - if len(flagIsSet) == 0 { + if len(flagIsSet) == 0 && + curPrefs.Persist != nil && + curPrefs.Persist.UserProfile.LoginName != "" && + env.backendState != ipn.NeedsLogin.String() { // A bare "tailscale up" is a special case to just // mean bringing the network up without any changes. - return nil + return true, nil } // flagsCur is what flags we'd need to use to keep the exact @@ -901,7 +1023,7 @@ func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheck missing = append(missing, fmtFlagValueArg(flagName, valCur)) } if len(missing) == 0 { - return nil + return false, nil } // Some previously provided flags are missing. This run of 'tailscale @@ -934,7 +1056,7 @@ func checkForAccidentalSettingReverts(newPrefs, curPrefs *ipn.Prefs, env upCheck fmt.Fprintf(&sb, " %s", a) } sb.WriteString("\n\n") - return errors.New(sb.String()) + return false, errors.New(sb.String()) } // applyImplicitPrefs mutates prefs to add implicit preferences for the user operator. @@ -1044,6 +1166,8 @@ func prefsToFlags(env upCheckEnv, prefs *ipn.Prefs) (flagVal map[string]any) { set(prefs.NetfilterMode.String()) case "unattended": set(prefs.ForceDaemon) + case "report-posture": + set(prefs.PostureChecking) } }) return ret @@ -1083,98 +1207,9 @@ func exitNodeIP(p *ipn.Prefs, st *ipnstate.Status) (ip netip.Addr) { return } -func init() { - // Required to use our client API. We're fine with the instability since the - // client lives in the same repo as this code. - tailscale.I_Acknowledge_This_API_Is_Unstable = true -} - -// resolveAuthKey either returns v unchanged (in the common case) or, if it -// starts with "tskey-client-" (as Tailscale OAuth secrets do) parses it like -// -// tskey-client-xxxx[?ephemeral=false&bar&preauthorized=BOOL&baseURL=...] -// -// and does the OAuth2 dance to get and return an authkey. The "ephemeral" -// property defaults to true if unspecified. The "preauthorized" defaults to -// false. The "baseURL" defaults to https://api.tailscale.com. -// The passed in tags are required, and must be non-empty. These will be -// set on the authkey generated by the OAuth2 dance. -func resolveAuthKey(ctx context.Context, v, tags string) (string, error) { - if !strings.HasPrefix(v, "tskey-client-") { - return v, nil - } - if tags == "" { - return "", errors.New("oauth authkeys require --advertise-tags") - } - - clientSecret, named, _ := strings.Cut(v, "?") - attrs, err := url.ParseQuery(named) - if err != nil { - return "", err - } - for k := range attrs { - switch k { - case "ephemeral", "preauthorized", "baseURL": - default: - return "", fmt.Errorf("unknown attribute %q", k) - } - } - getBool := func(name string, def bool) (bool, error) { - v := attrs.Get(name) - if v == "" { - return def, nil - } - ret, err := strconv.ParseBool(v) - if err != nil { - return false, fmt.Errorf("invalid attribute boolean attribute %s value %q", name, v) - } - return ret, nil - } - ephemeral, err := getBool("ephemeral", true) - if err != nil { - return "", err - } - preauth, err := getBool("preauthorized", false) - if err != nil { - return "", err - } - - baseURL := "https://api.tailscale.com" - if v := attrs.Get("baseURL"); v != "" { - baseURL = v - } - - credentials := clientcredentials.Config{ - ClientID: "some-client-id", // ignored - ClientSecret: clientSecret, - TokenURL: baseURL + "/api/v2/oauth/token", - Scopes: []string{"device"}, - } - - tsClient := tailscale.NewClient("-", nil) - tsClient.HTTPClient = credentials.Client(ctx) - tsClient.BaseURL = baseURL - - caps := tailscale.KeyCapabilities{ - Devices: tailscale.KeyDeviceCapabilities{ - Create: tailscale.KeyDeviceCreateCapabilities{ - Reusable: false, - Ephemeral: ephemeral, - Preauthorized: preauth, - Tags: strings.Split(tags, ","), - }, - }, - } - - authkey, _, err := tsClient.CreateKey(ctx, caps) - if err != nil { - return "", err - } - return authkey, nil -} - -func warnOnAdvertiseRouts(ctx context.Context, prefs *ipn.Prefs) { - if len(prefs.AdvertiseRoutes) > 0 || prefs.AppConnector.Advertise { +func warnOnAdvertiseRoutes(ctx context.Context, prefs *ipn.Prefs) { + if buildfeatures.HasAdvertiseRoutes && len(prefs.AdvertiseRoutes) > 0 || + buildfeatures.HasAppConnectors && prefs.AppConnector.Advertise { // TODO(jwhited): compress CheckIPForwarding and CheckUDPGROForwarding // into a single HTTP req. if err := localClient.CheckIPForwarding(ctx); err != nil { diff --git a/cmd/tailscale/cli/up_test.go b/cmd/tailscale/cli/up_test.go new file mode 100644 index 000000000..fe2f1b555 --- /dev/null +++ b/cmd/tailscale/cli/up_test.go @@ -0,0 +1,59 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "flag" + "testing" + + "tailscale.com/util/set" +) + +// validUpFlags are the only flags that are valid for tailscale up. The up +// command is frozen: no new preferences can be added. Instead, add them to +// tailscale set. +// See tailscale/tailscale#15460. +var validUpFlags = set.Of( + "accept-dns", + "accept-risk", + "accept-routes", + "advertise-connector", + "advertise-exit-node", + "advertise-routes", + "advertise-tags", + "auth-key", + "exit-node", + "exit-node-allow-lan-access", + "force-reauth", + "host-routes", + "hostname", + "json", + "login-server", + "netfilter-mode", + "nickname", + "operator", + "report-posture", + "qr", + "qr-format", + "reset", + "shields-up", + "snat-subnet-routes", + "ssh", + "stateful-filtering", + "timeout", + "unattended", + "client-id", + "client-secret", + "id-token", +) + +// TestUpFlagSetIsFrozen complains when new flags are added to tailscale up. +func TestUpFlagSetIsFrozen(t *testing.T) { + upFlagSet.VisitAll(func(f *flag.Flag) { + name := f.Name + if !validUpFlags.Contains(name) { + t.Errorf("--%s flag added to tailscale up, new prefs go in tailscale set: see tailscale/tailscale#15460", name) + } + }) +} diff --git a/cmd/tailscale/cli/update.go b/cmd/tailscale/cli/update.go index 69d1aa97b..7eb0dccac 100644 --- a/cmd/tailscale/cli/update.go +++ b/cmd/tailscale/cli/update.go @@ -9,10 +9,10 @@ import ( "flag" "fmt" "runtime" - "strings" "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/clientupdate" + "tailscale.com/util/prompt" "tailscale.com/version" "tailscale.com/version/distro" ) @@ -87,19 +87,5 @@ func confirmUpdate(ver string) bool { } msg := fmt.Sprintf("This will update Tailscale from %v to %v. Continue?", version.Short(), ver) - return promptYesNo(msg) -} - -// PromptYesNo takes a question and prompts the user to answer the -// question with a yes or no. It appends a [y/n] to the message. -func promptYesNo(msg string) bool { - fmt.Print(msg + " [y/n] ") - var resp string - fmt.Scanln(&resp) - resp = strings.ToLower(resp) - switch resp { - case "y", "yes", "sure": - return true - } - return false + return prompt.YesNo(msg, true) } diff --git a/cmd/tailscale/cli/web.go b/cmd/tailscale/cli/web.go index e209d388e..2713f730b 100644 --- a/cmd/tailscale/cli/web.go +++ b/cmd/tailscale/cli/web.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_webclient + package cli import ( @@ -22,14 +24,20 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/client/web" "tailscale.com/ipn" + "tailscale.com/tsconst" ) -var webCmd = &ffcli.Command{ - Name: "web", - ShortUsage: "tailscale web [flags]", - ShortHelp: "Run a web server for controlling Tailscale", +func init() { + maybeWebCmd = webCmd +} + +func webCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "web", + ShortUsage: "tailscale web [flags]", + ShortHelp: "Run a web server for controlling Tailscale", - LongHelp: strings.TrimSpace(` + LongHelp: strings.TrimSpace(` "tailscale web" runs a webserver for controlling the Tailscale daemon. It's primarily intended for use on Synology, QNAP, and other @@ -37,15 +45,17 @@ NAS devices where a web interface is the natural place to control Tailscale, as opposed to a CLI or a native app. `), - FlagSet: (func() *flag.FlagSet { - webf := newFlagSet("web") - webf.StringVar(&webArgs.listen, "listen", "localhost:8088", "listen address; use port 0 for automatic") - webf.BoolVar(&webArgs.cgi, "cgi", false, "run as CGI script") - webf.StringVar(&webArgs.prefix, "prefix", "", "URL prefix added to requests (for cgi or reverse proxies)") - webf.BoolVar(&webArgs.readonly, "readonly", false, "run web UI in read-only mode") - return webf - })(), - Exec: runWeb, + FlagSet: (func() *flag.FlagSet { + webf := newFlagSet("web") + webf.StringVar(&webArgs.listen, "listen", "localhost:8088", "listen address; use port 0 for automatic") + webf.BoolVar(&webArgs.cgi, "cgi", false, "run as CGI script") + webf.StringVar(&webArgs.prefix, "prefix", "", "URL prefix added to requests (for cgi or reverse proxies)") + webf.BoolVar(&webArgs.readonly, "readonly", false, "run web UI in read-only mode") + webf.StringVar(&webArgs.origin, "origin", "", "origin at which the web UI is served (if behind a reverse proxy or used with cgi)") + return webf + })(), + Exec: runWeb, + } } var webArgs struct { @@ -53,6 +63,7 @@ var webArgs struct { cgi bool prefix string readonly bool + origin string } func tlsConfigFromEnvironment() *tls.Config { @@ -99,7 +110,7 @@ func runWeb(ctx context.Context, args []string) error { var startedManagementClient bool // we started the management client if !existingWebClient && !webArgs.readonly { // Also start full client in tailscaled. - log.Printf("starting tailscaled web client at http://%s\n", netip.AddrPortFrom(selfIP, web.ListenPort)) + log.Printf("starting tailscaled web client at http://%s\n", netip.AddrPortFrom(selfIP, tsconst.WebListenPort)) if err := setRunWebClient(ctx, true); err != nil { return fmt.Errorf("starting web client in tailscaled: %w", err) } @@ -115,6 +126,9 @@ func runWeb(ctx context.Context, args []string) error { if webArgs.readonly { opts.Mode = web.ReadOnlyServerMode } + if webArgs.origin != "" { + opts.OriginOverride = webArgs.origin + } webServer, err := web.NewServer(opts) if err != nil { log.Printf("tailscale.web: %v", err) diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index be6f42946..8b576ffc3 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -2,45 +2,47 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep filippo.io/edwards25519 from github.com/hdevalence/ed25519consensus filippo.io/edwards25519/field from filippo.io/edwards25519 + L fyne.io/systray from tailscale.com/client/systray + L fyne.io/systray/internal/generated/menu from fyne.io/systray + L fyne.io/systray/internal/generated/notifier from fyne.io/systray + L github.com/Kodeworks/golang-image-ico from tailscale.com/client/systray W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/coder/websocket from tailscale.com/control/controlhttp+ + L github.com/atotto/clipboard from tailscale.com/client/systray + github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket - L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw W đŸ’Ŗ github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/pe+ W đŸ’Ŗ github.com/dblohm7/wingoes/pe from tailscale.com/util/winutil/authenticode + L github.com/fogleman/gg from tailscale.com/client/systray github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/gaissmai/bart from tailscale.com/net/tsdial + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + L đŸ’Ŗ github.com/godbus/dbus/v5 from fyne.io/systray+ + L github.com/godbus/dbus/v5/introspect from fyne.io/systray+ + L github.com/godbus/dbus/v5/prop from fyne.io/systray + L github.com/golang/freetype/raster from github.com/fogleman/gg+ + L github.com/golang/freetype/truetype from github.com/fogleman/gg github.com/golang/groupcache/lru from tailscale.com/net/dnscache - L github.com/google/nftables from tailscale.com/util/linuxfw - L đŸ’Ŗ github.com/google/nftables/alignedbuff from github.com/google/nftables/xt - L đŸ’Ŗ github.com/google/nftables/binaryutil from github.com/google/nftables+ - L github.com/google/nftables/expr from github.com/google/nftables+ - L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ - L github.com/google/nftables/xt from github.com/google/nftables/expr+ - github.com/google/uuid from tailscale.com/clientupdate+ - github.com/gorilla/csrf from tailscale.com/client/web - github.com/gorilla/securecookie from github.com/gorilla/csrf + DW github.com/google/uuid from tailscale.com/clientupdate+ github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink github.com/kballard/go-shellquote from tailscale.com/cmd/tailscale/cli đŸ’Ŗ github.com/mattn/go-colorable from tailscale.com/cmd/tailscale/cli đŸ’Ŗ github.com/mattn/go-isatty from tailscale.com/cmd/tailscale/cli+ - L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ + L đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ - L github.com/mdlayher/netlink/nltest from github.com/google/nftables L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink - github.com/miekg/dns from tailscale.com/net/dns/recursive đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/cmd/tailscale/cli+ github.com/peterbourgon/ff/v3 from github.com/peterbourgon/ff/v3/ffcli+ github.com/peterbourgon/ff/v3/ffcli from tailscale.com/cmd/tailscale/cli+ @@ -59,14 +61,11 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep github.com/tailscale/goupnp/scpd from github.com/tailscale/goupnp github.com/tailscale/goupnp/soap from github.com/tailscale/goupnp+ github.com/tailscale/goupnp/ssdp from github.com/tailscale/goupnp - L đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/util/linuxfw - L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink + github.com/tailscale/hujson from tailscale.com/ipn/conffile github.com/tailscale/web-client-prebuilt from tailscale.com/client/web - github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck - github.com/toqueteos/webbrowser from tailscale.com/cmd/tailscale/cli - L github.com/vishvananda/netns from github.com/tailscale/netlink+ + github.com/toqueteos/webbrowser from tailscale.com/cmd/tailscale/cli+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from tailscale.com/net/tsaddr W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/netmon+ k8s.io/client-go/util/homedir from tailscale.com/cmd/tailscale/cli @@ -75,149 +74,186 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep software.sslmate.com/src/go-pkcs12 from tailscale.com/cmd/tailscale/cli software.sslmate.com/src/go-pkcs12/internal/rc2 from software.sslmate.com/src/go-pkcs12 tailscale.com from tailscale.com/version - tailscale.com/atomicfile from tailscale.com/cmd/tailscale/cli+ - tailscale.com/client/tailscale from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/cmd/tailscale/cli+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + L tailscale.com/client/systray from tailscale.com/cmd/tailscale/cli + tailscale.com/client/tailscale from tailscale.com/internal/client/tailscale tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ tailscale.com/client/web from tailscale.com/cmd/tailscale/cli - tailscale.com/clientupdate from tailscale.com/client/web+ - tailscale.com/clientupdate/distsign from tailscale.com/clientupdate + tailscale.com/clientupdate from tailscale.com/cmd/tailscale/cli + LW tailscale.com/clientupdate/distsign from tailscale.com/clientupdate tailscale.com/cmd/tailscale/cli from tailscale.com/cmd/tailscale tailscale.com/cmd/tailscale/cli/ffcomplete from tailscale.com/cmd/tailscale/cli tailscale.com/cmd/tailscale/cli/ffcomplete/internal from tailscale.com/cmd/tailscale/cli/ffcomplete + tailscale.com/cmd/tailscale/cli/jsonoutput from tailscale.com/cmd/tailscale/cli tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ - tailscale.com/control/controlhttp from tailscale.com/cmd/tailscale/cli - tailscale.com/control/controlknobs from tailscale.com/net/portmapper - tailscale.com/derp from tailscale.com/derp/derphttp + tailscale.com/control/controlhttp from tailscale.com/control/ts2021 + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp + tailscale.com/control/ts2021 from tailscale.com/cmd/tailscale/cli + tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/net/netcheck - tailscale.com/disco from tailscale.com/derp - tailscale.com/drive from tailscale.com/client/tailscale+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web + tailscale.com/feature from tailscale.com/tsweb+ + tailscale.com/feature/buildfeatures from tailscale.com/cmd/tailscale/cli+ + tailscale.com/feature/capture/dissector from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/condregister/identityfederation from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/condregister/oauthkey from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/condregister/portmapper from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/condregister/useproxy from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation + tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey + tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper + tailscale.com/feature/syspolicy from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/useproxy from tailscale.com/feature/condregister/useproxy tailscale.com/health from tailscale.com/net/tlsdial+ tailscale.com/health/healthmsg from tailscale.com/cmd/tailscale/cli tailscale.com/hostinfo from tailscale.com/client/web+ - tailscale.com/internal/noiseconn from tailscale.com/cmd/tailscale/cli - tailscale.com/ipn from tailscale.com/client/tailscale+ - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ + tailscale.com/internal/client/tailscale from tailscale.com/cmd/tailscale/cli+ + tailscale.com/ipn from tailscale.com/client/local+ + tailscale.com/ipn/conffile from tailscale.com/cmd/tailscale/cli + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/licenses from tailscale.com/client/web+ - tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/metrics from tailscale.com/tsweb+ + tailscale.com/net/ace from tailscale.com/cmd/tailscale/cli + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial tailscale.com/net/captivedetection from tailscale.com/net/netcheck - tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback tailscale.com/net/dnscache from tailscale.com/control/controlhttp+ tailscale.com/net/dnsfallback from tailscale.com/control/controlhttp+ - tailscale.com/net/flowtrack from tailscale.com/net/packet tailscale.com/net/netaddr from tailscale.com/ipn+ tailscale.com/net/netcheck from tailscale.com/cmd/tailscale/cli tailscale.com/net/neterror from tailscale.com/net/netcheck+ - tailscale.com/net/netknob from tailscale.com/net/netns + tailscale.com/net/netknob from tailscale.com/net/netns+ đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/cmd/tailscale/cli+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ - tailscale.com/net/netutil from tailscale.com/client/tailscale+ - tailscale.com/net/packet from tailscale.com/wgengine/capture + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlhttp+ tailscale.com/net/ping from tailscale.com/net/netcheck - tailscale.com/net/portmapper from tailscale.com/cmd/tailscale/cli+ + tailscale.com/net/portmapper from tailscale.com/feature/portmapper + tailscale.com/net/portmapper/portmappertype from tailscale.com/net/netcheck+ tailscale.com/net/sockstats from tailscale.com/control/controlhttp+ tailscale.com/net/stun from tailscale.com/net/netcheck - L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/cmd/tailscale/cli+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ - đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ - tailscale.com/paths from tailscale.com/client/tailscale+ - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale+ - tailscale.com/syncs from tailscale.com/cmd/tailscale/cli+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ + tailscale.com/net/tsdial from tailscale.com/cmd/tailscale/cli+ + đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy + tailscale.com/net/udprelay/status from tailscale.com/client/local+ + tailscale.com/omit from tailscale.com/ipn/conffile + tailscale.com/paths from tailscale.com/client/local+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ + tailscale.com/syncs from tailscale.com/control/controlhttp+ + tailscale.com/tailcfg from tailscale.com/client/local+ tailscale.com/tempfork/spf13/cobra from tailscale.com/cmd/tailscale/cli/ffcomplete+ - tailscale.com/tka from tailscale.com/client/tailscale+ + tailscale.com/tka from tailscale.com/client/local+ tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tstime from tailscale.com/control/controlhttp+ tailscale.com/tstime/mono from tailscale.com/tstime/rate - tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli+ - tailscale.com/tsweb/varz from tailscale.com/util/usermetric + tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli + tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ + tailscale.com/types/appctype from tailscale.com/client/local+ tailscale.com/types/dnstype from tailscale.com/tailcfg+ tailscale.com/types/empty from tailscale.com/ipn - tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/util/testenv+ tailscale.com/types/logger from tailscale.com/client/web+ tailscale.com/types/netmap from tailscale.com/ipn+ tailscale.com/types/nettype from tailscale.com/net/netcheck+ tailscale.com/types/opt from tailscale.com/client/tailscale+ - tailscale.com/types/persist from tailscale.com/ipn + tailscale.com/types/persist from tailscale.com/ipn+ tailscale.com/types/preftype from tailscale.com/cmd/tailscale/cli+ tailscale.com/types/ptr from tailscale.com/hostinfo+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/types/key+ tailscale.com/types/views from tailscale.com/tailcfg+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/net/netcheck+ tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+ - tailscale.com/util/ctxkey from tailscale.com/types/logger + tailscale.com/util/ctxkey from tailscale.com/types/logger+ đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+ + tailscale.com/util/eventbus from tailscale.com/client/local+ tailscale.com/util/groupmember from tailscale.com/client/web đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ - L tailscale.com/util/linuxfw from tailscale.com/net/netns + tailscale.com/util/lineiter from tailscale.com/hostinfo+ tailscale.com/util/mak from tailscale.com/cmd/tailscale/cli+ - tailscale.com/util/multierr from tailscale.com/control/controlhttp+ tailscale.com/util/must from tailscale.com/clientupdate/distsign+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/prompt from tailscale.com/cmd/tailscale/cli tailscale.com/util/quarantine from tailscale.com/cmd/tailscale/cli - tailscale.com/util/set from tailscale.com/derp+ - tailscale.com/util/singleflight from tailscale.com/net/dnscache+ - tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ - tailscale.com/util/syspolicy from tailscale.com/ipn - tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy - tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli + tailscale.com/util/rands from tailscale.com/tsweb + tailscale.com/util/set from tailscale.com/ipn+ + tailscale.com/util/singleflight from tailscale.com/net/dnscache + tailscale.com/util/slicesx from tailscale.com/client/systray+ + L tailscale.com/util/stringsx from tailscale.com/client/systray + tailscale.com/util/syspolicy from tailscale.com/feature/syspolicy + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/pkey from tailscale.com/ipn+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/client/web+ + tailscale.com/util/syspolicy/ptype from tailscale.com/util/syspolicy/policyclient+ + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/client/local+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli tailscale.com/util/usermetric from tailscale.com/health tailscale.com/util/vizerror from tailscale.com/tailcfg+ - đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ + W đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/client/web+ tailscale.com/version/distro from tailscale.com/client/web+ - tailscale.com/wgengine/capture from tailscale.com/cmd/tailscale/cli tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from tailscale.com/clientupdate/distsign+ golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 - golang.org/x/crypto/chacha20poly1305 from crypto/tls+ - golang.org/x/crypto/cryptobyte from crypto/ecdsa+ - golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + golang.org/x/crypto/chacha20poly1305 from tailscale.com/control/controlbase golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ - W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ - golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli+ + golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ + golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting+ + L golang.org/x/image/draw from github.com/fogleman/gg + L golang.org/x/image/font from github.com/fogleman/gg+ + L golang.org/x/image/font/basicfont from github.com/fogleman/gg + L golang.org/x/image/math/f64 from github.com/fogleman/gg+ + L golang.org/x/image/math/fixed from github.com/fogleman/gg+ golang.org/x/net/bpf from github.com/mdlayher/netlink+ - golang.org/x/net/dns/dnsmessage from net+ - golang.org/x/net/http/httpguts from net/http+ - golang.org/x/net/http/httpproxy from net/http+ - golang.org/x/net/http2 from tailscale.com/cmd/tailscale/cli+ - golang.org/x/net/http2/hpack from net/http+ + golang.org/x/net/dns/dnsmessage from tailscale.com/cmd/tailscale/cli+ + golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy golang.org/x/net/icmp from tailscale.com/net/ping - golang.org/x/net/idna from golang.org/x/net/http/httpguts+ - golang.org/x/net/ipv4 from github.com/miekg/dns+ - golang.org/x/net/ipv6 from github.com/miekg/dns+ + golang.org/x/net/idna from golang.org/x/net/http/httpproxy+ + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy + golang.org/x/net/ipv4 from golang.org/x/net/icmp+ + golang.org/x/net/ipv6 from golang.org/x/net/icmp+ golang.org/x/net/proxy from tailscale.com/net/netns - D golang.org/x/net/route from net+ - golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials - golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/tailscale/cli + D golang.org/x/net/route from tailscale.com/net/netmon+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ + golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ - golang.org/x/sys/cpu from github.com/josharian/native+ - LD golang.org/x/sys/unix from github.com/google/nftables+ + golang.org/x/sys/cpu from golang.org/x/crypto/argon2+ + LD golang.org/x/sys/unix from github.com/jsimonetti/rtnetlink/internal/unix+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/svc from golang.org/x/sys/windows/svc/mgr+ @@ -227,6 +263,22 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ golang.org/x/text/unicode/norm from golang.org/x/net/idna golang.org/x/time/rate from tailscale.com/cmd/tailscale/cli+ + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna archive/tar from tailscale.com/clientupdate bufio from compress/flate+ bytes from archive/tar+ @@ -237,7 +289,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509 @@ -245,34 +297,76 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls + crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ - crypto/tls from github.com/miekg/dns+ + crypto/subtle from crypto/cipher+ + crypto/tls from net/http+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - database/sql/driver from github.com/google/uuid + DW database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe - embed from crypto/internal/nistec+ - encoding from encoding/gob+ + embed from github.com/peterbourgon/ff/v3+ + encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ encoding/base64 from encoding/json+ encoding/binary from compress/gzip+ - encoding/gob from github.com/gorilla/securecookie encoding/hex from crypto/x509+ encoding/json from expvar+ encoding/pem from crypto/tls+ - encoding/xml from github.com/tailscale/goupnp+ + encoding/xml from github.com/godbus/dbus/v5/introspect+ errors from archive/tar+ - expvar from tailscale.com/derp+ + expvar from tailscale.com/health+ flag from github.com/peterbourgon/ff/v3+ fmt from archive/tar+ hash from compress/zlib+ @@ -280,10 +374,60 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep hash/crc32 from compress/gzip+ hash/maphash from go4.org/mem html from html/template+ - html/template from github.com/gorilla/csrf + html/template from tailscale.com/util/eventbus image from github.com/skip2/go-qrcode+ image/color from github.com/skip2/go-qrcode+ - image/png from github.com/skip2/go-qrcode + L image/draw from github.com/Kodeworks/golang-image-ico+ + L image/internal/imageutil from image/draw+ + L image/jpeg from github.com/fogleman/gg + image/png from github.com/skip2/go-qrcode+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from archive/tar+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + D internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + L internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ io from archive/tar+ io/fs from archive/tar+ io/ioutil from github.com/mitchellh/go-ps+ @@ -295,33 +439,40 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/mdlayher/netlink+ - math/rand/v2 from tailscale.com/derp+ + math/rand/v2 from crypto/ecdsa+ mime from golang.org/x/oauth2/internal+ mime/multipart from net/http mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ net/http/cgi from tailscale.com/cmd/tailscale/cli - net/http/httptrace from github.com/tcnksm/go-httpstat+ + net/http/httptrace from net/http+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/internal/httpcommon from net/http + net/http/pprof from tailscale.com/tsweb net/netip from go4.org/netipx+ - net/textproto from golang.org/x/net/http/httpguts+ + net/textproto from github.com/coder/websocket+ net/url from crypto/x509+ - os from crypto/rand+ - os/exec from github.com/coreos/go-iptables/iptables+ - os/signal from tailscale.com/cmd/tailscale/cli + os from crypto/internal/sysrand+ + os/exec from github.com/atotto/clipboard+ + os/signal from tailscale.com/cmd/tailscale/cli+ os/user from archive/tar+ path from archive/tar+ path/filepath from archive/tar+ reflect from archive/tar+ - regexp from github.com/coreos/go-iptables/iptables+ + regexp from github.com/tailscale/goupnp/httpu+ regexp/syntax from regexp - runtime/debug from github.com/coder/websocket/internal/xsync+ + runtime from archive/tar+ + runtime/debug from tailscale.com+ + runtime/pprof from net/http/pprof + runtime/trace from net/http/pprof slices from tailscale.com/client/web+ sort from compress/flate+ strconv from archive/tar+ strings from archive/tar+ + W structs from internal/syscall/windows sync from archive/tar+ sync/atomic from context+ syscall from archive/tar+ @@ -333,3 +484,5 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/cmd/tailscale/tailscale.rc b/cmd/tailscale/tailscale.rc new file mode 100755 index 000000000..2cac53efb --- /dev/null +++ b/cmd/tailscale/tailscale.rc @@ -0,0 +1,3 @@ +#!/bin/rc +# Plan 9 cmd/tailscale wrapper script to run cmd/tailscaled's embedded CLI. +TS_BE_CLI=1 tailscaled $* diff --git a/cmd/tailscale/tailscale_test.go b/cmd/tailscale/tailscale_test.go index dc477fb6e..a7a3c2323 100644 --- a/cmd/tailscale/tailscale_test.go +++ b/cmd/tailscale/tailscale_test.go @@ -19,7 +19,6 @@ func TestDeps(t *testing.T) { "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", "tailscale.com/wgengine/filter": "brings in bart, etc", "github.com/bits-and-blooms/bitset": "unneeded in CLI", - "github.com/gaissmai/bart": "unneeded in CLI", "tailscale.com/net/ipset": "unneeded in CLI", }, }.Check(t) diff --git a/cmd/tailscaled/debug.go b/cmd/tailscaled/debug.go index b41604d29..b16cb28e0 100644 --- a/cmd/tailscaled/debug.go +++ b/cmd/tailscaled/debug.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build go1.19 +//go:build !ts_omit_debug package main @@ -16,17 +16,22 @@ import ( "log" "net/http" "net/http/httptrace" + "net/http/pprof" "net/url" "os" "time" "tailscale.com/derp/derphttp" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/net/netmon" - "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" + "tailscale.com/tsweb/varz" "tailscale.com/types/key" + "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" ) var debugArgs struct { @@ -37,7 +42,29 @@ var debugArgs struct { portmap bool } -var debugModeFunc = debugMode // so it can be addressable +func init() { + debugModeFunc := debugMode // to be addressable + subCommands["debug"] = &debugModeFunc + + hookNewDebugMux.Set(newDebugMux) +} + +func newDebugMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("/debug/metrics", servePrometheusMetrics) + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + return mux +} + +func servePrometheusMetrics(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + varz.Handler(w, r) + clientmetric.WritePrometheusExpositionFormat(w) +} func debugMode(args []string) error { fs := flag.NewFlagSet("debug", flag.ExitOnError) @@ -72,24 +99,23 @@ func debugMode(args []string) error { } func runMonitor(ctx context.Context, loop bool) error { + b := eventbus.New() + defer b.Close() + dump := func(st *netmon.State) { j, _ := json.MarshalIndent(st, "", " ") os.Stderr.Write(j) } - mon, err := netmon.New(log.Printf) + mon, err := netmon.New(b, log.Printf) if err != nil { return err } defer mon.Close() - mon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { - if !delta.Major { - log.Printf("Network monitor fired; not a major change") - return - } - log.Printf("Network monitor fired. New state:") - dump(delta.New) - }) + eventClient := b.Client("debug.runMonitor") + m := eventClient.Monitor(changeDeltaWatcher(eventClient, ctx, dump)) + defer m.Close() + if loop { log.Printf("Starting link change monitor; initial state:") } @@ -102,6 +128,27 @@ func runMonitor(ctx context.Context, loop bool) error { select {} } +func changeDeltaWatcher(ec *eventbus.Client, ctx context.Context, dump func(st *netmon.State)) func(*eventbus.Client) { + changeSub := eventbus.Subscribe[netmon.ChangeDelta](ec) + return func(ec *eventbus.Client) { + for { + select { + case <-ctx.Done(): + return + case <-ec.Done(): + return + case delta := <-changeSub.Events(): + if !delta.Major { + log.Printf("Network monitor fired; not a major change") + return + } + log.Printf("Network monitor fired. New state:") + dump(delta.New) + } + } + } +} + func getURL(ctx context.Context, urlStr string) error { if urlStr == "login" { urlStr = "https://login.tailscale.com" @@ -120,9 +167,14 @@ func getURL(ctx context.Context, urlStr string) error { if err != nil { return fmt.Errorf("http.NewRequestWithContext: %v", err) } - proxyURL, err := tshttpproxy.ProxyFromEnvironment(req) - if err != nil { - return fmt.Errorf("tshttpproxy.ProxyFromEnvironment: %v", err) + var proxyURL *url.URL + if buildfeatures.HasUseProxy { + if proxyFromEnv, ok := feature.HookProxyFromEnvironment.GetOk(); ok { + proxyURL, err = proxyFromEnv(req) + if err != nil { + return fmt.Errorf("tshttpproxy.ProxyFromEnvironment: %v", err) + } + } } log.Printf("proxy: %v", proxyURL) tr := &http.Transport{ @@ -131,7 +183,10 @@ func getURL(ctx context.Context, urlStr string) error { DisableKeepAlives: true, } if proxyURL != nil { - auth, err := tshttpproxy.GetAuthHeader(proxyURL) + var auth string + if f, ok := feature.HookProxyGetAuthHeader.GetOk(); ok { + auth, err = f(proxyURL) + } if err == nil && auth != "" { tr.ProxyConnectHeader.Set("Proxy-Authorization", auth) } @@ -157,7 +212,9 @@ func getURL(ctx context.Context, urlStr string) error { } func checkDerp(ctx context.Context, derpRegion string) (err error) { - ht := new(health.Tracker) + bus := eventbus.New() + defer bus.Close() + ht := health.NewTracker(bus) req, err := http.NewRequestWithContext(ctx, "GET", ipn.DefaultControlURL+"/derpmap/default", nil) if err != nil { return fmt.Errorf("create derp map request: %w", err) diff --git a/cmd/tailscaled/debug_forcereflect.go b/cmd/tailscaled/debug_forcereflect.go new file mode 100644 index 000000000..7378753ce --- /dev/null +++ b/cmd/tailscaled/debug_forcereflect.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_debug_forcereflect + +// This file exists for benchmarking binary sizes. When the build tag is +// enabled, it forces use of part of the reflect package that makes the Go +// linker go into conservative retention mode where its deadcode pass can't +// eliminate exported method. + +package main + +import ( + "reflect" + "time" +) + +func init() { + // See Go's src/cmd/compile/internal/walk/expr.go:usemethod for + // why this is isn't a const. + name := []byte("Bar") + if time.Now().Unix()&1 == 0 { + name[0] = 'X' + } + _, _ = reflect.TypeOf(12).MethodByName(string(name)) +} diff --git a/cmd/tailscaled/depaware-min.txt b/cmd/tailscaled/depaware-min.txt new file mode 100644 index 000000000..e750f86e6 --- /dev/null +++ b/cmd/tailscaled/depaware-min.txt @@ -0,0 +1,413 @@ +tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/depaware) + + github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart + github.com/go-json-experiment/json from tailscale.com/drive+ + github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + github.com/golang/groupcache/lru from tailscale.com/net/dnscache + đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon + github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink + github.com/klauspost/compress from github.com/klauspost/compress/zstd + github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 + github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd + github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe + github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd + đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ + đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink + đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf + đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ + đŸ’Ŗ github.com/tailscale/wireguard-go/device from tailscale.com/net/tstun+ + github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/ratelimiter from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/replay from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ + github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device + đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ + đŸ’Ŗ go4.org/mem from tailscale.com/control/controlbase+ + go4.org/netipx from tailscale.com/ipn/ipnlocal+ + tailscale.com from tailscale.com/version + tailscale.com/appc from tailscale.com/ipn/ipnlocal + tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/tailscale/apitype from tailscale.com/ipn/ipnauth+ + tailscale.com/cmd/tailscaled/childproc from tailscale.com/cmd/tailscaled + tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ + tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+ + tailscale.com/control/controlhttp from tailscale.com/control/ts2021 + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp + tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/control/ts2021 from tailscale.com/control/controlclient + tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ + tailscale.com/derp/derphttp from tailscale.com/net/netcheck+ + tailscale.com/disco from tailscale.com/net/tstun+ + tailscale.com/drive from tailscale.com/ipn+ + tailscale.com/envknob from tailscale.com/cmd/tailscaled+ + tailscale.com/envknob/featureknob from tailscale.com/ipn/ipnlocal + tailscale.com/feature from tailscale.com/cmd/tailscaled+ + tailscale.com/feature/buildfeatures from tailscale.com/cmd/tailscaled+ + tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock + tailscale.com/feature/condregister from tailscale.com/cmd/tailscaled + tailscale.com/feature/condregister/portmapper from tailscale.com/feature/condregister + tailscale.com/feature/condregister/useproxy from tailscale.com/feature/condregister + tailscale.com/health from tailscale.com/control/controlclient+ + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal+ + tailscale.com/hostinfo from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled + tailscale.com/ipn/ipnstate from tailscale.com/control/controlclient+ + tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver + tailscale.com/ipn/store from tailscale.com/cmd/tailscaled + tailscale.com/ipn/store/mem from tailscale.com/ipn/store + tailscale.com/kube/kubetypes from tailscale.com/envknob + tailscale.com/log/filelogger from tailscale.com/logpolicy + tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal + tailscale.com/logpolicy from tailscale.com/cmd/tailscaled+ + tailscale.com/logtail from tailscale.com/cmd/tailscaled+ + tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial + đŸ’Ŗ tailscale.com/net/batching from tailscale.com/wgengine/magicsock + tailscale.com/net/dns from tailscale.com/cmd/tailscaled+ + tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ + tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ + tailscale.com/net/dnsfallback from tailscale.com/cmd/tailscaled+ + tailscale.com/net/flowtrack from tailscale.com/wgengine/filter + tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/netaddr from tailscale.com/ipn+ + tailscale.com/net/netcheck from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/neterror from tailscale.com/net/batching+ + tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal + tailscale.com/net/netknob from tailscale.com/logpolicy+ + tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ + tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ + tailscale.com/net/netutil from tailscale.com/control/controlclient+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/ping from tailscale.com/net/netcheck+ + tailscale.com/net/portmapper/portmappertype from tailscale.com/net/netcheck+ + tailscale.com/net/sockopts from tailscale.com/wgengine/magicsock + tailscale.com/net/sockstats from tailscale.com/control/controlclient+ + tailscale.com/net/stun from tailscale.com/net/netcheck+ + tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/tsaddr from tailscale.com/ipn+ + tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ + tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock + tailscale.com/omit from tailscale.com/ipn/conffile + tailscale.com/paths from tailscale.com/cmd/tailscaled+ + tailscale.com/proxymap from tailscale.com/tsd + tailscale.com/safesocket from tailscale.com/cmd/tailscaled+ + tailscale.com/syncs from tailscale.com/cmd/tailscaled+ + tailscale.com/tailcfg from tailscale.com/client/tailscale/apitype+ + tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock + tailscale.com/tka from tailscale.com/control/controlclient+ + tailscale.com/tsconst from tailscale.com/net/netns+ + tailscale.com/tsd from tailscale.com/cmd/tailscaled+ + tailscale.com/tstime from tailscale.com/control/controlclient+ + tailscale.com/tstime/mono from tailscale.com/net/tstun+ + tailscale.com/tstime/rate from tailscale.com/wgengine/filter + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/dnstype from tailscale.com/client/tailscale/apitype+ + tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/control/controlbase+ + tailscale.com/types/lazy from tailscale.com/hostinfo+ + tailscale.com/types/logger from tailscale.com/appc+ + tailscale.com/types/logid from tailscale.com/cmd/tailscaled+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogfunc from tailscale.com/net/tstun+ + tailscale.com/types/netmap from tailscale.com/control/controlclient+ + tailscale.com/types/nettype from tailscale.com/net/batching+ + tailscale.com/types/opt from tailscale.com/control/controlknobs+ + tailscale.com/types/persist from tailscale.com/control/controlclient+ + tailscale.com/types/preftype from tailscale.com/ipn+ + tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter + tailscale.com/types/structs from tailscale.com/control/controlclient+ + tailscale.com/types/tkatype from tailscale.com/control/controlclient+ + tailscale.com/types/views from tailscale.com/appc+ + tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cibuild from tailscale.com/health+ + tailscale.com/util/clientmetric from tailscale.com/appc+ + tailscale.com/util/cloudenv from tailscale.com/hostinfo+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ + tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/control/controlclient+ + tailscale.com/util/execqueue from tailscale.com/appc+ + tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal + tailscale.com/util/groupmember from tailscale.com/ipn/ipnauth + tailscale.com/util/httpm from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ + tailscale.com/util/mak from tailscale.com/control/controlclient+ + tailscale.com/util/must from tailscale.com/logpolicy+ + tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/osdiag from tailscale.com/ipn/localapi + tailscale.com/util/osshare from tailscale.com/cmd/tailscaled + tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/race from tailscale.com/net/dns/resolver + tailscale.com/util/racebuild from tailscale.com/logpolicy + tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/ringlog from tailscale.com/wgengine/magicsock + tailscale.com/util/set from tailscale.com/control/controlclient+ + tailscale.com/util/singleflight from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/appc+ + tailscale.com/util/syspolicy/pkey from tailscale.com/cmd/tailscaled+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/cmd/tailscaled+ + tailscale.com/util/syspolicy/ptype from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/testenv from tailscale.com/control/controlclient+ + tailscale.com/util/usermetric from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/vizerror from tailscale.com/tailcfg+ + tailscale.com/util/winutil from tailscale.com/ipn/ipnauth + tailscale.com/util/zstdframe from tailscale.com/control/controlclient + tailscale.com/version from tailscale.com/cmd/tailscaled+ + tailscale.com/version/distro from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ + tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ + đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/netlog from tailscale.com/wgengine + tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+ + tailscale.com/wgengine/router from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ + tailscale.com/wgengine/wglog from tailscale.com/wgengine + golang.org/x/crypto/blake2b from golang.org/x/crypto/nacl/box + golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 + golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key + golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box + golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device + golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + golang.org/x/exp/constraints from tailscale.com/util/set + golang.org/x/exp/maps from tailscale.com/ipn/store/mem + golang.org/x/net/bpf from github.com/mdlayher/netlink+ + golang.org/x/net/dns/dnsmessage from tailscale.com/ipn/ipnlocal+ + golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal + golang.org/x/net/icmp from tailscale.com/net/ping + golang.org/x/net/idna from golang.org/x/net/http/httpguts + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/sync/errgroup from github.com/mdlayher/socket + golang.org/x/sys/cpu from github.com/tailscale/wireguard-go/tun+ + golang.org/x/sys/unix from github.com/jsimonetti/rtnetlink/internal/unix+ + golang.org/x/term from tailscale.com/logpolicy + golang.org/x/text/secure/bidirule from golang.org/x/net/idna + golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ + golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ + golang.org/x/text/unicode/norm from golang.org/x/net/idna + golang.org/x/time/rate from tailscale.com/derp + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna + bufio from compress/flate+ + bytes from bufio+ + cmp from encoding/json+ + compress/flate from compress/gzip + compress/gzip from net/http + container/list from crypto/tls+ + context from crypto/tls+ + crypto from crypto/ecdh+ + crypto/aes from crypto/internal/hpke+ + crypto/cipher from crypto/aes+ + crypto/des from crypto/tls+ + crypto/dsa from crypto/x509 + crypto/ecdh from crypto/ecdsa+ + crypto/ecdsa from crypto/tls+ + crypto/ed25519 from crypto/tls+ + crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls + crypto/hkdf from crypto/internal/hpke+ + crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/fips140+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/fips140+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/hkdf+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/ecdsa+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ + crypto/md5 from crypto/tls+ + crypto/rand from crypto/ed25519+ + crypto/rc4 from crypto/tls + crypto/rsa from crypto/tls+ + crypto/sha1 from crypto/tls+ + crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash + crypto/sha512 from crypto/ecdsa+ + crypto/subtle from crypto/cipher+ + crypto/tls from net/http+ + crypto/tls/internal/fips140tls from crypto/tls + crypto/x509 from crypto/tls+ + crypto/x509/pkix from crypto/x509 + embed from tailscale.com+ + encoding from encoding/json+ + encoding/asn1 from crypto/x509+ + encoding/base32 from github.com/go-json-experiment/json + encoding/base64 from encoding/json+ + encoding/binary from compress/gzip+ + encoding/hex from crypto/x509+ + encoding/json from github.com/gaissmai/bart+ + encoding/pem from crypto/tls+ + errors from bufio+ + flag from tailscale.com/cmd/tailscaled+ + fmt from compress/flate+ + hash from crypto+ + hash/crc32 from compress/gzip+ + hash/maphash from go4.org/mem + html from tailscale.com/ipn/ipnlocal+ + internal/abi from hash/maphash+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profilerecord from runtime + internal/race from internal/runtime/maps+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + internal/runtime/syscall from internal/runtime/cgroup+ + internal/saferio from encoding/asn1 + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + internal/syscall/unix from crypto/internal/sysrand+ + internal/testlog from os + internal/trace/tracev2 from runtime + internal/unsafeheader from internal/reflectlite+ + io from bufio+ + io/fs from crypto/x509+ + iter from bytes+ + log from github.com/klauspost/compress/zstd+ + log/internal from log + maps from crypto/x509+ + math from compress/flate+ + math/big from crypto/dsa+ + math/bits from bytes+ + math/rand from github.com/mdlayher/netlink+ + math/rand/v2 from crypto/ecdsa+ + mime from mime/multipart+ + mime/multipart from net/http + mime/quotedprintable from mime/multipart + net from crypto/tls+ + net/http from tailscale.com/cmd/tailscaled+ + net/http/httptrace from net/http+ + net/http/internal from net/http + net/http/internal/ascii from net/http + net/http/internal/httpcommon from net/http + net/netip from crypto/x509+ + net/textproto from golang.org/x/net/http/httpguts+ + net/url from crypto/x509+ + os from crypto/internal/sysrand+ + os/exec from tailscale.com/hostinfo+ + os/signal from tailscale.com/cmd/tailscaled + os/user from tailscale.com/ipn/ipnauth+ + path from io/fs+ + path/filepath from crypto/x509+ + reflect from crypto/x509+ + runtime from crypto/internal/fips140+ + runtime/debug from github.com/klauspost/compress/zstd+ + slices from crypto/tls+ + sort from compress/flate+ + strconv from compress/flate+ + strings from bufio+ + sync from compress/flate+ + sync/atomic from context+ + syscall from crypto/internal/sysrand+ + time from compress/gzip+ + unicode from bytes+ + unicode/utf16 from crypto/x509+ + unicode/utf8 from bufio+ + unique from net/netip + unsafe from bytes+ + weak from crypto/internal/fips140cache+ diff --git a/cmd/tailscaled/depaware-minbox.txt b/cmd/tailscaled/depaware-minbox.txt new file mode 100644 index 000000000..17f1a22b2 --- /dev/null +++ b/cmd/tailscaled/depaware-minbox.txt @@ -0,0 +1,453 @@ +tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/depaware) + + filippo.io/edwards25519 from github.com/hdevalence/ed25519consensus + filippo.io/edwards25519/field from filippo.io/edwards25519 + github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart + github.com/go-json-experiment/json from tailscale.com/drive+ + github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + github.com/golang/groupcache/lru from tailscale.com/net/dnscache + github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign + đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon + github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink + github.com/kballard/go-shellquote from tailscale.com/cmd/tailscale/cli + github.com/klauspost/compress from github.com/klauspost/compress/zstd + github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 + github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd + github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe + github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd + github.com/mattn/go-colorable from tailscale.com/cmd/tailscale/cli + github.com/mattn/go-isatty from github.com/mattn/go-colorable+ + đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ + đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink + github.com/peterbourgon/ff/v3 from github.com/peterbourgon/ff/v3/ffcli+ + github.com/peterbourgon/ff/v3/ffcli from tailscale.com/cmd/tailscale/cli+ + github.com/peterbourgon/ff/v3/internal from github.com/peterbourgon/ff/v3 + đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf + github.com/skip2/go-qrcode from tailscale.com/cmd/tailscale/cli + github.com/skip2/go-qrcode/bitset from github.com/skip2/go-qrcode+ + github.com/skip2/go-qrcode/reedsolomon from github.com/skip2/go-qrcode + đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ + đŸ’Ŗ github.com/tailscale/wireguard-go/device from tailscale.com/net/tstun+ + github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/ratelimiter from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/replay from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ + github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device + đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ + github.com/toqueteos/webbrowser from tailscale.com/cmd/tailscale/cli + đŸ’Ŗ go4.org/mem from tailscale.com/control/controlbase+ + go4.org/netipx from tailscale.com/ipn/ipnlocal+ + tailscale.com from tailscale.com/version + tailscale.com/appc from tailscale.com/ipn/ipnlocal + tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale from tailscale.com/internal/client/tailscale + tailscale.com/client/tailscale/apitype from tailscale.com/ipn/ipnauth+ + tailscale.com/clientupdate from tailscale.com/cmd/tailscale/cli + tailscale.com/clientupdate/distsign from tailscale.com/clientupdate + tailscale.com/cmd/tailscale/cli from tailscale.com/cmd/tailscaled + tailscale.com/cmd/tailscale/cli/ffcomplete from tailscale.com/cmd/tailscale/cli + tailscale.com/cmd/tailscale/cli/ffcomplete/internal from tailscale.com/cmd/tailscale/cli/ffcomplete + tailscale.com/cmd/tailscaled/childproc from tailscale.com/cmd/tailscaled + tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ + tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+ + tailscale.com/control/controlhttp from tailscale.com/control/ts2021 + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp + tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/control/ts2021 from tailscale.com/control/controlclient+ + tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ + tailscale.com/derp/derphttp from tailscale.com/net/netcheck+ + tailscale.com/disco from tailscale.com/net/tstun+ + tailscale.com/drive from tailscale.com/ipn+ + tailscale.com/envknob from tailscale.com/cmd/tailscaled+ + tailscale.com/envknob/featureknob from tailscale.com/ipn/ipnlocal + tailscale.com/feature from tailscale.com/cmd/tailscaled+ + tailscale.com/feature/buildfeatures from tailscale.com/ipn/ipnlocal+ + tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock + tailscale.com/feature/condregister from tailscale.com/cmd/tailscaled + tailscale.com/feature/condregister/identityfederation from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/condregister/oauthkey from tailscale.com/cmd/tailscale/cli + tailscale.com/feature/condregister/portmapper from tailscale.com/feature/condregister+ + tailscale.com/feature/condregister/useproxy from tailscale.com/cmd/tailscale/cli+ + tailscale.com/health from tailscale.com/control/controlclient+ + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal+ + tailscale.com/hostinfo from tailscale.com/cmd/tailscaled+ + tailscale.com/internal/client/tailscale from tailscale.com/cmd/tailscale/cli + tailscale.com/ipn from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ + tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled + tailscale.com/ipn/ipnstate from tailscale.com/control/controlclient+ + tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver + tailscale.com/ipn/store from tailscale.com/cmd/tailscaled + tailscale.com/ipn/store/mem from tailscale.com/ipn/store + tailscale.com/kube/kubetypes from tailscale.com/envknob + tailscale.com/licenses from tailscale.com/cmd/tailscale/cli + tailscale.com/log/filelogger from tailscale.com/logpolicy + tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal + tailscale.com/logpolicy from tailscale.com/cmd/tailscaled+ + tailscale.com/logtail from tailscale.com/cmd/tailscaled+ + tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ + tailscale.com/net/ace from tailscale.com/cmd/tailscale/cli + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial + đŸ’Ŗ tailscale.com/net/batching from tailscale.com/wgengine/magicsock + tailscale.com/net/dns from tailscale.com/cmd/tailscaled+ + tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ + tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ + tailscale.com/net/dnsfallback from tailscale.com/cmd/tailscaled+ + tailscale.com/net/flowtrack from tailscale.com/wgengine/filter + tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/netaddr from tailscale.com/ipn+ + tailscale.com/net/netcheck from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/neterror from tailscale.com/net/batching+ + tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal + tailscale.com/net/netknob from tailscale.com/logpolicy+ + tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ + tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ + tailscale.com/net/netutil from tailscale.com/control/controlclient+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/ping from tailscale.com/net/netcheck+ + tailscale.com/net/portmapper/portmappertype from tailscale.com/net/netcheck+ + tailscale.com/net/sockopts from tailscale.com/wgengine/magicsock + tailscale.com/net/sockstats from tailscale.com/control/controlclient+ + tailscale.com/net/stun from tailscale.com/net/netcheck+ + tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/tsaddr from tailscale.com/ipn+ + tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ + tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock + tailscale.com/net/udprelay/status from tailscale.com/client/local + tailscale.com/omit from tailscale.com/ipn/conffile + tailscale.com/paths from tailscale.com/cmd/tailscaled+ + tailscale.com/proxymap from tailscale.com/tsd + tailscale.com/safesocket from tailscale.com/cmd/tailscaled+ + tailscale.com/syncs from tailscale.com/cmd/tailscaled+ + tailscale.com/tailcfg from tailscale.com/client/tailscale/apitype+ + tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock + tailscale.com/tempfork/spf13/cobra from tailscale.com/cmd/tailscale/cli/ffcomplete+ + tailscale.com/tka from tailscale.com/control/controlclient+ + tailscale.com/tsconst from tailscale.com/net/netns+ + tailscale.com/tsd from tailscale.com/cmd/tailscaled+ + tailscale.com/tstime from tailscale.com/control/controlclient+ + tailscale.com/tstime/mono from tailscale.com/net/tstun+ + tailscale.com/tstime/rate from tailscale.com/wgengine/filter + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/dnstype from tailscale.com/client/tailscale/apitype+ + tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/client/local+ + tailscale.com/types/lazy from tailscale.com/hostinfo+ + tailscale.com/types/logger from tailscale.com/appc+ + tailscale.com/types/logid from tailscale.com/cmd/tailscaled+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogfunc from tailscale.com/net/tstun+ + tailscale.com/types/netmap from tailscale.com/control/controlclient+ + tailscale.com/types/nettype from tailscale.com/net/batching+ + tailscale.com/types/opt from tailscale.com/control/controlknobs+ + tailscale.com/types/persist from tailscale.com/control/controlclient+ + tailscale.com/types/preftype from tailscale.com/ipn+ + tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter + tailscale.com/types/structs from tailscale.com/control/controlclient+ + tailscale.com/types/tkatype from tailscale.com/control/controlclient+ + tailscale.com/types/views from tailscale.com/appc+ + tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cibuild from tailscale.com/health+ + tailscale.com/util/clientmetric from tailscale.com/appc+ + tailscale.com/util/cloudenv from tailscale.com/hostinfo+ + tailscale.com/util/cmpver from tailscale.com/clientupdate + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ + tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/client/local+ + tailscale.com/util/execqueue from tailscale.com/appc+ + tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal + tailscale.com/util/groupmember from tailscale.com/ipn/ipnauth + tailscale.com/util/httpm from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ + tailscale.com/util/mak from tailscale.com/control/controlclient+ + tailscale.com/util/must from tailscale.com/logpolicy+ + tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/osdiag from tailscale.com/ipn/localapi + tailscale.com/util/osshare from tailscale.com/cmd/tailscaled + tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/prompt from tailscale.com/cmd/tailscale/cli + tailscale.com/util/race from tailscale.com/net/dns/resolver + tailscale.com/util/racebuild from tailscale.com/logpolicy + tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/ringlog from tailscale.com/wgengine/magicsock + tailscale.com/util/set from tailscale.com/control/controlclient+ + tailscale.com/util/singleflight from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/appc+ + tailscale.com/util/syspolicy/pkey from tailscale.com/cmd/tailscaled+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/cmd/tailscaled+ + tailscale.com/util/syspolicy/ptype from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/testenv from tailscale.com/control/controlclient+ + tailscale.com/util/usermetric from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/vizerror from tailscale.com/tailcfg+ + tailscale.com/util/winutil from tailscale.com/ipn/ipnauth + tailscale.com/util/zstdframe from tailscale.com/control/controlclient + tailscale.com/version from tailscale.com/cmd/tailscaled+ + tailscale.com/version/distro from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ + tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ + đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/netlog from tailscale.com/wgengine + tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+ + tailscale.com/wgengine/router from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ + tailscale.com/wgengine/wglog from tailscale.com/wgengine + golang.org/x/crypto/blake2b from golang.org/x/crypto/nacl/box + golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305 + golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key + golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box + golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device + golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + golang.org/x/exp/constraints from tailscale.com/util/set + golang.org/x/exp/maps from tailscale.com/ipn/store/mem + golang.org/x/net/bpf from github.com/mdlayher/netlink+ + golang.org/x/net/dns/dnsmessage from tailscale.com/cmd/tailscale/cli+ + golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal + golang.org/x/net/icmp from tailscale.com/net/ping + golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/sync/errgroup from github.com/mdlayher/socket + golang.org/x/sys/cpu from github.com/tailscale/wireguard-go/tun+ + golang.org/x/sys/unix from github.com/jsimonetti/rtnetlink/internal/unix+ + golang.org/x/term from tailscale.com/logpolicy + golang.org/x/text/secure/bidirule from golang.org/x/net/idna + golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ + golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ + golang.org/x/text/unicode/norm from golang.org/x/net/idna + golang.org/x/time/rate from tailscale.com/derp + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna + archive/tar from tailscale.com/clientupdate + bufio from compress/flate+ + bytes from bufio+ + cmp from encoding/json+ + compress/flate from compress/gzip+ + compress/gzip from net/http+ + compress/zlib from image/png + container/list from crypto/tls+ + context from crypto/tls+ + crypto from crypto/ecdh+ + crypto/aes from crypto/internal/hpke+ + crypto/cipher from crypto/aes+ + crypto/des from crypto/tls+ + crypto/dsa from crypto/x509 + crypto/ecdh from crypto/ecdsa+ + crypto/ecdsa from crypto/tls+ + crypto/ed25519 from crypto/tls+ + crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls + crypto/hkdf from crypto/internal/hpke+ + crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/fips140+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/fips140+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/hkdf+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/ecdsa+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ + crypto/md5 from crypto/tls+ + crypto/rand from crypto/ed25519+ + crypto/rc4 from crypto/tls + crypto/rsa from crypto/tls+ + crypto/sha1 from crypto/tls+ + crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash + crypto/sha512 from crypto/ecdsa+ + crypto/subtle from crypto/cipher+ + crypto/tls from net/http+ + crypto/tls/internal/fips140tls from crypto/tls + crypto/x509 from crypto/tls+ + crypto/x509/pkix from crypto/x509 + embed from tailscale.com+ + encoding from encoding/json+ + encoding/asn1 from crypto/x509+ + encoding/base32 from github.com/go-json-experiment/json + encoding/base64 from encoding/json+ + encoding/binary from compress/gzip+ + encoding/hex from crypto/x509+ + encoding/json from github.com/gaissmai/bart+ + encoding/pem from crypto/tls+ + errors from bufio+ + flag from tailscale.com/cmd/tailscaled+ + fmt from compress/flate+ + hash from crypto+ + hash/adler32 from compress/zlib + hash/crc32 from compress/gzip+ + hash/maphash from go4.org/mem + html from tailscale.com/ipn/ipnlocal+ + image from github.com/skip2/go-qrcode+ + image/color from github.com/skip2/go-qrcode+ + image/png from github.com/skip2/go-qrcode + internal/abi from hash/maphash+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profilerecord from runtime + internal/race from internal/runtime/maps+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + internal/runtime/syscall from internal/runtime/cgroup+ + internal/saferio from encoding/asn1 + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + internal/syscall/unix from crypto/internal/sysrand+ + internal/testlog from os + internal/trace/tracev2 from runtime + internal/unsafeheader from internal/reflectlite+ + io from bufio+ + io/fs from crypto/x509+ + io/ioutil from github.com/skip2/go-qrcode + iter from bytes+ + log from github.com/klauspost/compress/zstd+ + log/internal from log + maps from crypto/x509+ + math from compress/flate+ + math/big from crypto/dsa+ + math/bits from bytes+ + math/rand from github.com/mdlayher/netlink+ + math/rand/v2 from crypto/ecdsa+ + mime from mime/multipart+ + mime/multipart from net/http + mime/quotedprintable from mime/multipart + net from crypto/tls+ + net/http from net/http/httputil+ + net/http/httptrace from net/http+ + net/http/httputil from tailscale.com/cmd/tailscale/cli + net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/internal/httpcommon from net/http + net/netip from crypto/x509+ + net/textproto from golang.org/x/net/http/httpguts+ + net/url from crypto/x509+ + os from crypto/internal/sysrand+ + os/exec from tailscale.com/hostinfo+ + os/signal from tailscale.com/cmd/tailscaled+ + os/user from tailscale.com/ipn/ipnauth+ + path from io/fs+ + path/filepath from crypto/x509+ + reflect from crypto/x509+ + regexp from tailscale.com/clientupdate + regexp/syntax from regexp + runtime from crypto/internal/fips140+ + runtime/debug from github.com/klauspost/compress/zstd+ + slices from crypto/tls+ + sort from compress/flate+ + strconv from compress/flate+ + strings from bufio+ + sync from compress/flate+ + sync/atomic from context+ + syscall from crypto/internal/sysrand+ + text/tabwriter from github.com/peterbourgon/ff/v3/ffcli+ + time from compress/gzip+ + unicode from bytes+ + unicode/utf16 from crypto/x509+ + unicode/utf8 from bufio+ + unique from net/netip + unsafe from bytes+ + weak from crypto/internal/fips140cache+ diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 018e74fac..d15402092 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -10,7 +10,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/aws-sdk-go-v2/aws/arn from tailscale.com/ipn/store/awsstore L github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - L github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts L github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts @@ -32,10 +31,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ L github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints+ L github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ L github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds @@ -70,20 +71,22 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer L github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ L github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config L github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm - github.com/bits-and-blooms/bitset from github.com/gaissmai/bart - github.com/coder/websocket from tailscale.com/control/controlhttp+ + github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw + github.com/creachadair/msync/trigger from tailscale.com/logtail LD đŸ’Ŗ github.com/creack/pty from tailscale.com/ssh/tailssh W đŸ’Ŗ github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ W đŸ’Ŗ github.com/dblohm7/wingoes/com from tailscale.com/cmd/tailscaled+ @@ -94,6 +97,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ github.com/djherbis/times from tailscale.com/drive/driveimpl github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/gaissmai/bart from tailscale.com/net/tstun+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+ @@ -105,24 +110,29 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns+ github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + github.com/google/go-tpm/legacy/tpm2 from github.com/google/go-tpm/tpm2/transport+ + github.com/google/go-tpm/tpm2 from tailscale.com/feature/tpm + github.com/google/go-tpm/tpm2/transport from github.com/google/go-tpm/tpm2/transport/linuxtpm+ + L github.com/google/go-tpm/tpm2/transport/linuxtpm from tailscale.com/feature/tpm + W github.com/google/go-tpm/tpm2/transport/windowstpm from tailscale.com/feature/tpm + github.com/google/go-tpm/tpmutil from github.com/google/go-tpm/legacy/tpm2+ + W đŸ’Ŗ github.com/google/go-tpm/tpmutil/tbs from github.com/google/go-tpm/legacy/tpm2+ L github.com/google/nftables from tailscale.com/util/linuxfw L đŸ’Ŗ github.com/google/nftables/alignedbuff from github.com/google/nftables/xt L đŸ’Ŗ github.com/google/nftables/binaryutil from github.com/google/nftables+ L github.com/google/nftables/expr from github.com/google/nftables+ L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ L github.com/google/nftables/xt from github.com/google/nftables/expr+ - github.com/google/uuid from tailscale.com/clientupdate+ - github.com/gorilla/csrf from tailscale.com/client/web - github.com/gorilla/securecookie from github.com/gorilla/csrf + DW github.com/google/uuid from tailscale.com/clientupdate+ github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ - L đŸ’Ŗ github.com/illarion/gonotify/v2 from tailscale.com/net/dns - L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/net/tstun + L đŸ’Ŗ github.com/illarion/gonotify/v3 from tailscale.com/feature/linuxdnsfight + L github.com/illarion/gonotify/v3/syscallf from github.com/illarion/gonotify/v3 + L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/feature/tap L github.com/insomniacslk/dhcp/iana from github.com/insomniacslk/dhcp/dhcpv4 L github.com/insomniacslk/dhcp/interfaces from github.com/insomniacslk/dhcp/dhcpv4 L github.com/insomniacslk/dhcp/rfc1035label from github.com/insomniacslk/dhcp/dhcpv4 github.com/jellydator/ttlcache/v3 from tailscale.com/drive/driveimpl/compositedav L github.com/jmespath/go-jmespath from github.com/aws/aws-sdk-go-v2/service/ssm - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink github.com/klauspost/compress from github.com/klauspost/compress/zstd @@ -132,21 +142,21 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd - github.com/kortschak/wol from tailscale.com/ipn/ipnlocal + github.com/kortschak/wol from tailscale.com/feature/wakeonlan LD github.com/kr/fs from github.com/pkg/sftp - L github.com/mdlayher/genetlink from tailscale.com/net/tstun + L github.com/mdlayher/genetlink from tailscale.com/feature/linkspeed L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ L github.com/mdlayher/netlink/nltest from github.com/google/nftables - L github.com/mdlayher/sdnotify from tailscale.com/util/systemd + L github.com/mdlayher/sdnotify from tailscale.com/feature/sdnotify L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ - github.com/miekg/dns from tailscale.com/net/dns/recursive đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket L github.com/pierrec/lz4/v4 from github.com/u-root/uio/uio L github.com/pierrec/lz4/v4/internal/lz4block from github.com/pierrec/lz4/v4+ L github.com/pierrec/lz4/v4/internal/lz4errors from github.com/pierrec/lz4/v4+ L github.com/pierrec/lz4/v4/internal/lz4stream from github.com/pierrec/lz4/v4 L github.com/pierrec/lz4/v4/internal/xxh32 from github.com/pierrec/lz4/v4/internal/lz4stream + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal LD github.com/pkg/sftp from tailscale.com/ssh/tailssh LD github.com/pkg/sftp/internal/encoding/ssh/filexfer from github.com/pkg/sftp D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack @@ -157,10 +167,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ - github.com/tailscale/golang-x-crypto/acme from tailscale.com/ipn/ipnlocal - LD github.com/tailscale/golang-x-crypto/internal/poly1305 from github.com/tailscale/golang-x-crypto/ssh - LD github.com/tailscale/golang-x-crypto/ssh from tailscale.com/ipn/ipnlocal+ - LD github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf from github.com/tailscale/golang-x-crypto/ssh github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ @@ -170,7 +176,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/tailscale/hujson from tailscale.com/ipn/conffile L đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/net/routetable+ L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink - github.com/tailscale/peercred from tailscale.com/ipn/ipnauth + LD github.com/tailscale/peercred from tailscale.com/ipn/ipnauth github.com/tailscale/web-client-prebuilt from tailscale.com/client/web W đŸ’Ŗ github.com/tailscale/wf from tailscale.com/wf đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ @@ -185,13 +191,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ github.com/tailscale/xnet/webdav from tailscale.com/drive/driveimpl+ github.com/tailscale/xnet/webdav/internal/xml from github.com/tailscale/xnet/webdav - github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck LD github.com/u-root/u-root/pkg/termios from tailscale.com/ssh/tailssh L github.com/u-root/uio/rand from github.com/insomniacslk/dhcp/dhcpv4 L github.com/u-root/uio/uio from github.com/insomniacslk/dhcp/dhcpv4+ L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from github.com/tailscale/wf+ W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun+ W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/cmd/tailscaled+ @@ -215,13 +220,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/net/tstun+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/feature/tap+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack+ gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ @@ -238,68 +243,102 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ tailscale.com from tailscale.com/version tailscale.com/appc from tailscale.com/ipn/ipnlocal - tailscale.com/atomicfile from tailscale.com/ipn+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ LD tailscale.com/chirp from tailscale.com/cmd/tailscaled - tailscale.com/client/tailscale from tailscale.com/client/web+ - tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ + tailscale.com/client/local from tailscale.com/client/web+ + tailscale.com/client/tailscale/apitype from tailscale.com/client/local+ tailscale.com/client/web from tailscale.com/ipn/ipnlocal - tailscale.com/clientupdate from tailscale.com/client/web+ - tailscale.com/clientupdate/distsign from tailscale.com/clientupdate + tailscale.com/clientupdate from tailscale.com/feature/clientupdate + LW tailscale.com/clientupdate/distsign from tailscale.com/clientupdate tailscale.com/cmd/tailscaled/childproc from tailscale.com/cmd/tailscaled+ + tailscale.com/cmd/tailscaled/tailscaledhooks from tailscale.com/cmd/tailscaled+ tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+ - tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp from tailscale.com/control/ts2021+ + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/control/ts2021 from tailscale.com/control/controlclient tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ tailscale.com/derp/derphttp from tailscale.com/cmd/tailscaled+ - tailscale.com/disco from tailscale.com/derp+ - tailscale.com/doctor from tailscale.com/ipn/ipnlocal - tailscale.com/doctor/ethtool from tailscale.com/ipn/ipnlocal - đŸ’Ŗ tailscale.com/doctor/permissions from tailscale.com/ipn/ipnlocal - tailscale.com/doctor/routetable from tailscale.com/ipn/ipnlocal - tailscale.com/drive from tailscale.com/client/tailscale+ + tailscale.com/disco from tailscale.com/feature/relayserver+ + tailscale.com/doctor from tailscale.com/feature/doctor + tailscale.com/doctor/ethtool from tailscale.com/feature/doctor + đŸ’Ŗ tailscale.com/doctor/permissions from tailscale.com/feature/doctor + tailscale.com/doctor/routetable from tailscale.com/feature/doctor + tailscale.com/drive from tailscale.com/client/local+ tailscale.com/drive/driveimpl from tailscale.com/cmd/tailscaled tailscale.com/drive/driveimpl/compositedav from tailscale.com/drive/driveimpl tailscale.com/drive/driveimpl/dirfs from tailscale.com/drive/driveimpl+ tailscale.com/drive/driveimpl/shared from tailscale.com/drive/driveimpl+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/feature/wakeonlan+ + tailscale.com/feature/ace from tailscale.com/feature/condregister + tailscale.com/feature/appconnectors from tailscale.com/feature/condregister + tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ + tailscale.com/feature/c2n from tailscale.com/feature/condregister + tailscale.com/feature/capture from tailscale.com/feature/condregister + tailscale.com/feature/clientupdate from tailscale.com/feature/condregister + tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock + tailscale.com/feature/condregister from tailscale.com/cmd/tailscaled + tailscale.com/feature/condregister/portmapper from tailscale.com/feature/condregister + tailscale.com/feature/condregister/useproxy from tailscale.com/feature/condregister + tailscale.com/feature/debugportmapper from tailscale.com/feature/condregister + tailscale.com/feature/doctor from tailscale.com/feature/condregister + tailscale.com/feature/drive from tailscale.com/feature/condregister + L tailscale.com/feature/linkspeed from tailscale.com/feature/condregister + L tailscale.com/feature/linuxdnsfight from tailscale.com/feature/condregister + tailscale.com/feature/portlist from tailscale.com/feature/condregister + tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper + tailscale.com/feature/posture from tailscale.com/feature/condregister + tailscale.com/feature/relayserver from tailscale.com/feature/condregister + L tailscale.com/feature/sdnotify from tailscale.com/feature/condregister + tailscale.com/feature/syspolicy from tailscale.com/feature/condregister+ + tailscale.com/feature/taildrop from tailscale.com/feature/condregister + L tailscale.com/feature/tap from tailscale.com/feature/condregister + tailscale.com/feature/tpm from tailscale.com/feature/condregister + tailscale.com/feature/useproxy from tailscale.com/feature/condregister/useproxy + tailscale.com/feature/wakeonlan from tailscale.com/feature/condregister tailscale.com/health from tailscale.com/control/controlclient+ - tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal+ tailscale.com/hostinfo from tailscale.com/client/web+ - tailscale.com/internal/noiseconn from tailscale.com/control/controlclient - tailscale.com/ipn from tailscale.com/client/tailscale+ + tailscale.com/ipn from tailscale.com/client/local+ + W tailscale.com/ipn/auditlog from tailscale.com/cmd/tailscaled tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ + W đŸ’Ŗ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/auditlog+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ - tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver - tailscale.com/ipn/policy from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ + tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver+ + tailscale.com/ipn/policy from tailscale.com/feature/portlist tailscale.com/ipn/store from tailscale.com/cmd/tailscaled+ - L tailscale.com/ipn/store/awsstore from tailscale.com/ipn/store - L tailscale.com/ipn/store/kubestore from tailscale.com/ipn/store + L tailscale.com/ipn/store/awsstore from tailscale.com/feature/condregister + L tailscale.com/ipn/store/kubestore from tailscale.com/feature/condregister tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ L tailscale.com/kube/kubeapi from tailscale.com/ipn/store/kubestore+ L tailscale.com/kube/kubeclient from tailscale.com/ipn/store/kubestore - tailscale.com/kube/kubetypes from tailscale.com/envknob + tailscale.com/kube/kubetypes from tailscale.com/envknob+ tailscale.com/licenses from tailscale.com/client/web tailscale.com/log/filelogger from tailscale.com/logpolicy tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal tailscale.com/logpolicy from tailscale.com/cmd/tailscaled+ tailscale.com/logtail from tailscale.com/cmd/tailscaled+ - tailscale.com/logtail/backoff from tailscale.com/cmd/tailscaled+ tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ - tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/metrics from tailscale.com/tsweb+ + tailscale.com/net/ace from tailscale.com/feature/ace + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+ + đŸ’Ŗ tailscale.com/net/batching from tailscale.com/wgengine/magicsock+ tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ - tailscale.com/net/connstats from tailscale.com/net/tstun+ tailscale.com/net/dns from tailscale.com/cmd/tailscaled+ tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ - tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ - tailscale.com/net/dns/resolver from tailscale.com/net/dns + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ tailscale.com/net/dnscache from tailscale.com/control/controlclient+ tailscale.com/net/dnsfallback from tailscale.com/cmd/tailscaled+ - tailscale.com/net/flowtrack from tailscale.com/net/packet+ + tailscale.com/net/flowtrack from tailscale.com/wgengine+ tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ tailscale.com/net/netaddr from tailscale.com/ipn+ tailscale.com/net/netcheck from tailscale.com/wgengine/magicsock+ @@ -309,107 +348,125 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ W đŸ’Ŗ tailscale.com/net/netstat from tailscale.com/portlist - tailscale.com/net/netutil from tailscale.com/client/tailscale+ - tailscale.com/net/packet from tailscale.com/net/connstats+ + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/feature/capture+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ - tailscale.com/net/portmapper from tailscale.com/ipn/localapi+ + tailscale.com/net/portmapper from tailscale.com/feature/portmapper+ + tailscale.com/net/portmapper/portmappertype from tailscale.com/feature/portmapper+ tailscale.com/net/proxymux from tailscale.com/cmd/tailscaled tailscale.com/net/routetable from tailscale.com/doctor/routetable + đŸ’Ŗ tailscale.com/net/sockopts from tailscale.com/wgengine/magicsock+ tailscale.com/net/socks5 from tailscale.com/cmd/tailscaled tailscale.com/net/sockstats from tailscale.com/control/controlclient+ tailscale.com/net/stun from tailscale.com/ipn/localapi+ - L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ - đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ + đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + tailscale.com/net/udprelay from tailscale.com/feature/relayserver + tailscale.com/net/udprelay/endpoint from tailscale.com/net/udprelay+ + tailscale.com/net/udprelay/status from tailscale.com/client/local+ tailscale.com/omit from tailscale.com/ipn/conffile - tailscale.com/paths from tailscale.com/client/tailscale+ - đŸ’Ŗ tailscale.com/portlist from tailscale.com/ipn/ipnlocal - tailscale.com/posture from tailscale.com/ipn/ipnlocal + tailscale.com/paths from tailscale.com/client/local+ + đŸ’Ŗ tailscale.com/portlist from tailscale.com/feature/portlist + tailscale.com/posture from tailscale.com/feature/posture tailscale.com/proxymap from tailscale.com/tsd+ - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ LD tailscale.com/sessionrecording from tailscale.com/ssh/tailssh LD đŸ’Ŗ tailscale.com/ssh/tailssh from tailscale.com/cmd/tailscaled tailscale.com/syncs from tailscale.com/cmd/tailscaled+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ - tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock - tailscale.com/tka from tailscale.com/client/tailscale+ + tailscale.com/tempfork/httprec from tailscale.com/feature/c2n + tailscale.com/tka from tailscale.com/client/local+ tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tsd from tailscale.com/cmd/tailscaled+ tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ - tailscale.com/tstime/rate from tailscale.com/derp+ + tailscale.com/tstime/rate from tailscale.com/wgengine/filter + tailscale.com/tsweb from tailscale.com/util/eventbus tailscale.com/tsweb/varz from tailscale.com/cmd/tailscaled+ - tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/bools from tailscale.com/wgengine/netlog tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ tailscale.com/types/logger from tailscale.com/appc+ tailscale.com/types/logid from tailscale.com/cmd/tailscaled+ - tailscale.com/types/netlogtype from tailscale.com/net/connstats+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogfunc from tailscale.com/net/tstun+ + tailscale.com/types/netlogtype from tailscale.com/wgengine/netlog tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/nettype from tailscale.com/ipn/localapi+ - tailscale.com/types/opt from tailscale.com/client/tailscale+ + tailscale.com/types/opt from tailscale.com/control/controlknobs+ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/tka+ tailscale.com/types/views from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/backoff from tailscale.com/cmd/tailscaled+ + tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cibuild from tailscale.com/health+ tailscale.com/util/clientmetric from tailscale.com/control/controlclient+ tailscale.com/util/cloudenv from tailscale.com/net/dns/resolver+ tailscale.com/util/cmpver from tailscale.com/net/dns+ tailscale.com/util/ctxkey from tailscale.com/ipn/ipnlocal+ - đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ + đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics+ tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/tsd+ tailscale.com/util/execqueue from tailscale.com/control/controlclient+ tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal tailscale.com/util/groupmember from tailscale.com/client/web+ đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash - tailscale.com/util/httphdr from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ - L tailscale.com/util/linuxfw from tailscale.com/net/netns+ + tailscale.com/util/httphdr from tailscale.com/feature/taildrop + tailscale.com/util/httpm from tailscale.com/client/web+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ + L tailscale.com/util/linuxfw from tailscale.com/wgengine/router/osrouter tailscale.com/util/mak from tailscale.com/control/controlclient+ - tailscale.com/util/multierr from tailscale.com/cmd/tailscaled+ + tailscale.com/util/multierr from tailscale.com/feature/taildrop tailscale.com/util/must from tailscale.com/clientupdate/distsign+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto đŸ’Ŗ tailscale.com/util/osdiag from tailscale.com/cmd/tailscaled+ W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag tailscale.com/util/osshare from tailscale.com/cmd/tailscaled+ tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/progresstracking from tailscale.com/ipn/localapi + tailscale.com/util/progresstracking from tailscale.com/feature/taildrop tailscale.com/util/race from tailscale.com/net/dns/resolver tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/ringbuffer from tailscale.com/wgengine/magicsock - tailscale.com/util/set from tailscale.com/derp+ + tailscale.com/util/ringlog from tailscale.com/wgengine/magicsock + tailscale.com/util/set from tailscale.com/control/controlclient+ tailscale.com/util/singleflight from tailscale.com/control/controlclient+ - tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ - tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+ - tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy - tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock - tailscale.com/util/systemd from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/appc+ + tailscale.com/util/syspolicy from tailscale.com/feature/syspolicy + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/pkey from tailscale.com/cmd/tailscaled+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/ptype from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+ tailscale.com/util/truncate from tailscale.com/logtail - tailscale.com/util/uniq from tailscale.com/ipn/ipnlocal+ tailscale.com/util/usermetric from tailscale.com/health+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ @@ -417,7 +474,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/version/distro from tailscale.com/client/web+ W tailscale.com/wf from tailscale.com/cmd/tailscaled tailscale.com/wgengine from tailscale.com/cmd/tailscaled+ - tailscale.com/wgengine/capture from tailscale.com/ipn/ipnlocal+ tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ @@ -425,45 +481,48 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/wgengine/netstack from tailscale.com/cmd/tailscaled tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+ tailscale.com/wgengine/router from tailscale.com/cmd/tailscaled+ + tailscale.com/wgengine/router/osrouter from tailscale.com/feature/condregister tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+ tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - W đŸ’Ŗ tailscale.com/wgengine/winnet from tailscale.com/wgengine/router + W đŸ’Ŗ tailscale.com/wgengine/winnet from tailscale.com/wgengine/router/osrouter golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ - golang.org/x/crypto/chacha20poly1305 from crypto/tls+ - golang.org/x/crypto/cryptobyte from crypto/ecdsa+ - golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ - golang.org/x/crypto/curve25519 from github.com/tailscale/golang-x-crypto/ssh+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/cryptobyte from tailscale.com/feature/tpm + golang.org/x/crypto/cryptobyte/asn1 from golang.org/x/crypto/cryptobyte+ + golang.org/x/crypto/curve25519 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key - golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box + golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box+ golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ LD golang.org/x/crypto/ssh from github.com/pkg/sftp+ + LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ - golang.org/x/exp/maps from tailscale.com/appc+ + golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ golang.org/x/net/bpf from github.com/mdlayher/genetlink+ - golang.org/x/net/dns/dnsmessage from net+ - golang.org/x/net/http/httpguts from golang.org/x/net/http2+ - golang.org/x/net/http/httpproxy from net/http+ - golang.org/x/net/http2 from golang.org/x/net/http2/h2c+ - golang.org/x/net/http2/h2c from tailscale.com/ipn/ipnlocal - golang.org/x/net/http2/hpack from golang.org/x/net/http2+ + golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ + golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal + golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy golang.org/x/net/icmp from tailscale.com/net/ping+ golang.org/x/net/idna from golang.org/x/net/http/httpguts+ - golang.org/x/net/ipv4 from github.com/miekg/dns+ - golang.org/x/net/ipv6 from github.com/miekg/dns+ + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy + golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ golang.org/x/net/proxy from tailscale.com/net/netns - D golang.org/x/net/route from net+ + D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ golang.org/x/sync/singleflight from github.com/jellydator/ttlcache/v3 - golang.org/x/sys/cpu from github.com/josharian/native+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ LD golang.org/x/sys/unix from github.com/google/nftables+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ @@ -476,18 +535,34 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ golang.org/x/text/unicode/norm from golang.org/x/net/idna golang.org/x/time/rate from gvisor.dev/gvisor/pkg/log+ + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna archive/tar from tailscale.com/clientupdate bufio from compress/flate+ bytes from archive/tar+ cmp from slices+ compress/flate from compress/gzip+ - compress/gzip from golang.org/x/net/http2+ + compress/gzip from github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding+ W compress/zlib from debug/pe container/heap from github.com/jellydator/ttlcache/v3+ container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509+ @@ -495,42 +570,132 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls+ + crypto/hkdf from crypto/internal/hpke+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls+ + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ + LD crypto/mlkem from golang.org/x/crypto/ssh crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls+ crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from github.com/aws/aws-sdk-go-v2/aws/transport/http+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - database/sql/driver from github.com/google/uuid + DW database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe - embed from crypto/internal/nistec+ - encoding from encoding/gob+ + embed from github.com/tailscale/web-client-prebuilt+ + encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ encoding/base64 from encoding/json+ encoding/binary from compress/gzip+ - encoding/gob from github.com/gorilla/securecookie encoding/hex from crypto/x509+ encoding/json from expvar+ encoding/pem from crypto/tls+ encoding/xml from github.com/aws/aws-sdk-go-v2/aws/protocol/xml+ errors from archive/tar+ - expvar from tailscale.com/derp+ - flag from net/http/httptest+ + expvar from tailscale.com/cmd/tailscaled+ + flag from tailscale.com/cmd/tailscaled+ fmt from archive/tar+ hash from compress/zlib+ hash/adler32 from compress/zlib+ hash/crc32 from compress/gzip+ hash/maphash from go4.org/mem html from html/template+ - html/template from github.com/gorilla/csrf + html/template from tailscale.com/util/eventbus + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from archive/tar+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + D internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + L internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ io from archive/tar+ io/fs from archive/tar+ io/ioutil from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ @@ -549,15 +714,16 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptest from tailscale.com/control/controlclient - net/http/httptrace from github.com/tcnksm/go-httpstat+ + net/http/httptrace from github.com/prometheus-community/pro-bing+ net/http/httputil from github.com/aws/smithy-go/transport/http+ net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/internal/httpcommon from net/http net/http/pprof from tailscale.com/cmd/tailscaled+ net/netip from github.com/tailscale/wireguard-go/conn+ net/textproto from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/exec from github.com/aws/aws-sdk-go-v2/credentials/processcreds+ os/signal from tailscale.com/cmd/tailscaled os/user from archive/tar+ @@ -566,6 +732,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de reflect from archive/tar+ regexp from github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn+ regexp/syntax from regexp + runtime from archive/tar+ runtime/debug from github.com/aws/aws-sdk-go-v2/internal/sync/singleflight+ runtime/pprof from net/http/pprof+ runtime/trace from net/http/pprof @@ -573,6 +740,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de sort from compress/flate+ strconv from archive/tar+ strings from archive/tar+ + W structs from internal/syscall/windows sync from archive/tar+ sync/atomic from context+ syscall from archive/tar+ @@ -584,3 +752,5 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/cmd/tailscaled/deps_test.go b/cmd/tailscaled/deps_test.go new file mode 100644 index 000000000..64d1beca7 --- /dev/null +++ b/cmd/tailscaled/deps_test.go @@ -0,0 +1,293 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "maps" + "slices" + "strings" + "testing" + + "tailscale.com/feature/featuretags" + "tailscale.com/tstest/deptest" +) + +func TestOmitSSH(t *testing.T) { + const msg = "unexpected with ts_omit_ssh" + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_ssh,ts_include_cli", + BadDeps: map[string]string{ + "golang.org/x/crypto/ssh": msg, + "tailscale.com/ssh/tailssh": msg, + "tailscale.com/sessionrecording": msg, + "github.com/anmitsu/go-shlex": msg, + "github.com/creack/pty": msg, + "github.com/kr/fs": msg, + "github.com/pkg/sftp": msg, + "github.com/u-root/u-root/pkg/termios": msg, + "tempfork/gliderlabs/ssh": msg, + }, + }.Check(t) +} + +func TestOmitSyspolicy(t *testing.T) { + const msg = "unexpected syspolicy usage with ts_omit_syspolicy" + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_syspolicy,ts_include_cli", + BadDeps: map[string]string{ + "tailscale.com/util/syspolicy": msg, + "tailscale.com/util/syspolicy/setting": msg, + "tailscale.com/util/syspolicy/rsop": msg, + }, + }.Check(t) +} + +func TestOmitLocalClient(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_webclient,ts_omit_relayserver,ts_omit_oauthkey,ts_omit_acme", + BadDeps: map[string]string{ + "tailscale.com/client/local": "unexpected", + }, + }.Check(t) +} + +// Test that we can build a binary without reflect.MethodByName. +// See https://github.com/tailscale/tailscale/issues/17063 +func TestOmitReflectThings(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_include_cli,ts_omit_systray,ts_omit_debugeventbus,ts_omit_webclient", + BadDeps: map[string]string{ + "text/template": "unexpected text/template usage", + "html/template": "unexpected text/template usage", + }, + OnDep: func(dep string) { + if strings.Contains(dep, "systray") { + t.Errorf("unexpected systray dep %q", dep) + } + }, + }.Check(t) +} + +func TestOmitDrive(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_drive,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "driveimpl") { + t.Errorf("unexpected dep with ts_omit_drive: %q", dep) + } + if strings.Contains(dep, "webdav") { + t.Errorf("unexpected dep with ts_omit_drive: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitPortmapper(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_portmapper,ts_include_cli,ts_omit_debugportmapper", + OnDep: func(dep string) { + if dep == "tailscale.com/net/portmapper" { + t.Errorf("unexpected dep with ts_omit_portmapper: %q", dep) + return + } + if strings.Contains(dep, "goupnp") || strings.Contains(dep, "/soap") || + strings.Contains(dep, "internetgateway2") { + t.Errorf("unexpected dep with ts_omit_portmapper: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitACME(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_acme,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "/acme") { + t.Errorf("unexpected dep with ts_omit_acme: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitCaptivePortal(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_captiveportal,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "captive") { + t.Errorf("unexpected dep with ts_omit_captiveportal: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitAuth(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_oauthkey,ts_omit_identityfederation,ts_include_cli", + OnDep: func(dep string) { + if strings.HasPrefix(dep, "golang.org/x/oauth2") { + t.Errorf("unexpected oauth2 dep: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitOutboundProxy(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_outboundproxy,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "socks5") || strings.Contains(dep, "proxymux") { + t.Errorf("unexpected dep with ts_omit_outboundproxy: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitDBus(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_networkmanager,ts_omit_dbus,ts_omit_resolved,ts_omit_systray,ts_omit_ssh,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "dbus") { + t.Errorf("unexpected DBus dep: %q", dep) + } + }, + }.Check(t) +} + +func TestNetstack(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_gro,ts_omit_netstack,ts_omit_outboundproxy,ts_omit_serve,ts_omit_ssh,ts_omit_webclient,ts_omit_tap", + OnDep: func(dep string) { + if strings.Contains(dep, "gvisor") { + t.Errorf("unexpected gvisor dep: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitPortlist(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_portlist,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "portlist") { + t.Errorf("unexpected dep: %q", dep) + } + }, + }.Check(t) +} + +func TestOmitGRO(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_gro,ts_include_cli", + BadDeps: map[string]string{ + "gvisor.dev/gvisor/pkg/tcpip/stack/gro": "unexpected dep with ts_omit_gro", + }, + }.Check(t) +} + +func TestOmitUseProxy(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_useproxy,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "tshttproxy") { + t.Errorf("unexpected dep: %q", dep) + } + }, + }.Check(t) +} + +func minTags() string { + var tags []string + for _, f := range slices.Sorted(maps.Keys(featuretags.Features)) { + if f.IsOmittable() { + tags = append(tags, f.OmitTag()) + } + } + return strings.Join(tags, ",") +} + +func TestMinTailscaledNoCLI(t *testing.T) { + badSubstrs := []string{ + "cbor", + "regexp", + "golang.org/x/net/proxy", + "internal/socks", + "github.com/tailscale/peercred", + "tailscale.com/types/netlogtype", + "deephash", + "util/hashx", + } + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: minTags(), + OnDep: func(dep string) { + for _, bad := range badSubstrs { + if strings.Contains(dep, bad) { + t.Errorf("unexpected dep: %q", dep) + } + } + }, + }.Check(t) +} + +func TestMinTailscaledWithCLI(t *testing.T) { + badSubstrs := []string{ + "cbor", + "hujson", + "pprof", + "multierr", // https://github.com/tailscale/tailscale/pull/17379 + "tailscale.com/metrics", + "tailscale.com/tsweb/varz", + "dirwalk", + "deephash", + "util/hashx", + } + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: minTags() + ",ts_include_cli", + OnDep: func(dep string) { + for _, bad := range badSubstrs { + if strings.Contains(dep, bad) { + t.Errorf("unexpected dep: %q", dep) + } + } + }, + BadDeps: map[string]string{ + "golang.org/x/net/http2": "unexpected x/net/http2 dep; tailscale/tailscale#17305", + "expvar": "unexpected expvar dep", + "github.com/mdlayher/genetlink": "unexpected genetlink dep", + }, + }.Check(t) +} diff --git a/cmd/tailscaled/flag.go b/cmd/tailscaled/flag.go new file mode 100644 index 000000000..f640aceed --- /dev/null +++ b/cmd/tailscaled/flag.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import "strconv" + +// boolFlag is a flag.Value that tracks whether it was ever set. +type boolFlag struct { + set bool + v bool +} + +func (b *boolFlag) String() string { + if b == nil || !b.set { + return "unset" + } + return strconv.FormatBool(b.v) +} + +func (b *boolFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + b.v = v + b.set = true + return nil +} + +func (b *boolFlag) IsBoolFlag() bool { return true } diff --git a/cmd/tailscaled/install_windows.go b/cmd/tailscaled/install_windows.go index c36418642..6013660f5 100644 --- a/cmd/tailscaled/install_windows.go +++ b/cmd/tailscaled/install_windows.go @@ -15,9 +15,9 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" - "tailscale.com/logtail/backoff" + "tailscale.com/cmd/tailscaled/tailscaledhooks" "tailscale.com/types/logger" - "tailscale.com/util/osshare" + "tailscale.com/util/backoff" ) func init() { @@ -25,6 +25,16 @@ func init() { uninstallSystemDaemon = uninstallSystemDaemonWindows } +// serviceDependencies lists all system services that tailscaled depends on. +// This list must be kept in sync with the TailscaledDependencies preprocessor +// variable in the installer. +var serviceDependencies = []string{ + "Dnscache", + "iphlpsvc", + "netprofm", + "WinHttpAutoProxySvc", +} + func installSystemDaemonWindows(args []string) (err error) { m, err := mgr.Connect() if err != nil { @@ -48,6 +58,7 @@ func installSystemDaemonWindows(args []string) (err error) { ServiceType: windows.SERVICE_WIN32_OWN_PROCESS, StartType: mgr.StartAutomatic, ErrorControl: mgr.ErrorNormal, + Dependencies: serviceDependencies, DisplayName: serviceName, Description: "Connects this computer to others on the Tailscale network.", } @@ -81,8 +92,9 @@ func installSystemDaemonWindows(args []string) (err error) { } func uninstallSystemDaemonWindows(args []string) (ret error) { - // Remove file sharing from Windows shell (noop in non-windows) - osshare.SetFileSharingEnabled(false, logger.Discard) + for _, f := range tailscaledhooks.UninstallSystemDaemonWindows { + f() + } m, err := mgr.Connect() if err != nil { diff --git a/cmd/tailscaled/netstack.go b/cmd/tailscaled/netstack.go new file mode 100644 index 000000000..c0b34ed41 --- /dev/null +++ b/cmd/tailscaled/netstack.go @@ -0,0 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netstack + +package main + +import ( + "context" + "expvar" + "net" + "net/netip" + + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/wgengine/netstack" +) + +func init() { + hookNewNetstack.Set(newNetstack) +} + +func newNetstack(logf logger.Logf, sys *tsd.System, onlyNetstack bool) (tsd.NetstackImpl, error) { + ns, err := netstack.Create(logf, + sys.Tun.Get(), + sys.Engine.Get(), + sys.MagicSock.Get(), + sys.Dialer.Get(), + sys.DNSManager.Get(), + sys.ProxyMapper(), + ) + if err != nil { + return nil, err + } + // Only register debug info if we have a debug mux + if debugMux != nil { + expvar.Publish("netstack", ns.ExpVar()) + } + + sys.Set(ns) + ns.ProcessLocalIPs = onlyNetstack + ns.ProcessSubnets = onlyNetstack || handleSubnetsInNetstack() + + dialer := sys.Dialer.Get() // must be set by caller already + + if onlyNetstack { + e := sys.Engine.Get() + dialer.UseNetstackForIP = func(ip netip.Addr) bool { + _, ok := e.PeerForIP(ip) + return ok + } + dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + // Note: don't just return ns.DialContextTCP or we'll return + // *gonet.TCPConn(nil) instead of a nil interface which trips up + // callers. + tcpConn, err := ns.DialContextTCP(ctx, dst) + if err != nil { + return nil, err + } + return tcpConn, nil + } + dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + // Note: don't just return ns.DialContextUDP or we'll return + // *gonet.UDPConn(nil) instead of a nil interface which trips up + // callers. + udpConn, err := ns.DialContextUDP(ctx, dst) + if err != nil { + return nil, err + } + return udpConn, nil + } + } + + return ns, nil +} diff --git a/cmd/tailscaled/proxy.go b/cmd/tailscaled/proxy.go index a91c62bfa..85c3d91f9 100644 --- a/cmd/tailscaled/proxy.go +++ b/cmd/tailscaled/proxy.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build go1.19 +//go:build !ts_omit_outboundproxy // HTTP proxy code @@ -9,13 +9,107 @@ package main import ( "context" + "flag" "io" + "log" "net" "net/http" "net/http/httputil" "strings" + + "tailscale.com/feature" + "tailscale.com/net/proxymux" + "tailscale.com/net/socks5" + "tailscale.com/net/tsdial" + "tailscale.com/types/logger" ) +func init() { + hookRegisterOutboundProxyFlags.Set(registerOutboundProxyFlags) + hookOutboundProxyListen.Set(outboundProxyListen) +} + +func registerOutboundProxyFlags() { + flag.StringVar(&args.socksAddr, "socks5-server", "", `optional [ip]:port to run a SOCK5 server (e.g. "localhost:1080")`) + flag.StringVar(&args.httpProxyAddr, "outbound-http-proxy-listen", "", `optional [ip]:port to run an outbound HTTP proxy (e.g. "localhost:8080")`) +} + +// outboundProxyListen creates listeners for local SOCKS and HTTP proxies, if +// the respective addresses are not empty. args.socksAddr and args.httpProxyAddr +// can be the same, in which case the SOCKS5 Listener will receive connections +// that look like they're speaking SOCKS and httpListener will receive +// everything else. +// +// socksListener and httpListener can be nil, if their respective addrs are +// empty. +// +// The returned func closes over those two (possibly nil) listeners and +// starts the respective servers on the listener when called. +func outboundProxyListen() proxyStartFunc { + socksAddr, httpAddr := args.socksAddr, args.httpProxyAddr + + if socksAddr == httpAddr && socksAddr != "" && !strings.HasSuffix(socksAddr, ":0") { + ln, err := net.Listen("tcp", socksAddr) + if err != nil { + log.Fatalf("proxy listener: %v", err) + } + return mkProxyStartFunc(proxymux.SplitSOCKSAndHTTP(ln)) + } + + var socksListener, httpListener net.Listener + var err error + if socksAddr != "" { + socksListener, err = net.Listen("tcp", socksAddr) + if err != nil { + log.Fatalf("SOCKS5 listener: %v", err) + } + if strings.HasSuffix(socksAddr, ":0") { + // Log kernel-selected port number so integration tests + // can find it portably. + log.Printf("SOCKS5 listening on %v", socksListener.Addr()) + } + } + if httpAddr != "" { + httpListener, err = net.Listen("tcp", httpAddr) + if err != nil { + log.Fatalf("HTTP proxy listener: %v", err) + } + if strings.HasSuffix(httpAddr, ":0") { + // Log kernel-selected port number so integration tests + // can find it portably. + log.Printf("HTTP proxy listening on %v", httpListener.Addr()) + } + } + + return mkProxyStartFunc(socksListener, httpListener) +} + +func mkProxyStartFunc(socksListener, httpListener net.Listener) proxyStartFunc { + return func(logf logger.Logf, dialer *tsdial.Dialer) { + var addrs []string + if httpListener != nil { + hs := &http.Server{Handler: httpProxyHandler(dialer.UserDial)} + go func() { + log.Fatalf("HTTP proxy exited: %v", hs.Serve(httpListener)) + }() + addrs = append(addrs, httpListener.Addr().String()) + } + if socksListener != nil { + ss := &socks5.Server{ + Logf: logger.WithPrefix(logf, "socks5: "), + Dialer: dialer.UserDial, + } + go func() { + log.Fatalf("SOCKS5 server exited: %v", ss.Serve(socksListener)) + }() + addrs = append(addrs, socksListener.Addr().String()) + } + if set, ok := feature.HookProxySetSelfProxy.GetOk(); ok { + set(addrs...) + } + } +} + // httpProxyHandler returns an HTTP proxy http.Handler using the // provided backend dialer. func httpProxyHandler(dialer func(ctx context.Context, netw, addr string) (net.Conn, error)) http.Handler { diff --git a/cmd/tailscaled/ssh.go b/cmd/tailscaled/ssh.go index f7b0b367e..59a1ddd0d 100644 --- a/cmd/tailscaled/ssh.go +++ b/cmd/tailscaled/ssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || darwin || freebsd || openbsd +//go:build (linux || darwin || freebsd || openbsd || plan9) && !ts_omit_ssh package main diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 2831b4061..f14cdcff0 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -13,14 +13,11 @@ package main // import "tailscale.com/cmd/tailscaled" import ( "context" "errors" - "expvar" "flag" "fmt" "log" "net" "net/http" - "net/http/pprof" - "net/netip" "os" "os/signal" "path/filepath" @@ -30,11 +27,12 @@ import ( "syscall" "time" - "tailscale.com/client/tailscale" "tailscale.com/cmd/tailscaled/childproc" "tailscale.com/control/controlclient" - "tailscale.com/drive/driveimpl" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/conffile" @@ -47,26 +45,22 @@ import ( "tailscale.com/net/dnsfallback" "tailscale.com/net/netmon" "tailscale.com/net/netns" - "tailscale.com/net/proxymux" - "tailscale.com/net/socks5" "tailscale.com/net/tsdial" - "tailscale.com/net/tshttpproxy" "tailscale.com/net/tstun" "tailscale.com/paths" "tailscale.com/safesocket" "tailscale.com/syncs" "tailscale.com/tsd" - "tailscale.com/tsweb/varz" "tailscale.com/types/flagtype" + "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" - "tailscale.com/util/clientmetric" - "tailscale.com/util/multierr" "tailscale.com/util/osshare" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/version" "tailscale.com/version/distro" "tailscale.com/wgengine" - "tailscale.com/wgengine/netstack" "tailscale.com/wgengine/router" ) @@ -81,16 +75,16 @@ func defaultTunName() string { // "utun" is recognized by wireguard-go/tun/tun_darwin.go // as a magic value that uses/creates any free number. return "utun" - case "plan9", "aix": + case "plan9": + return "auto" + case "aix", "solaris", "illumos": return "userspace-networking" case "linux": - switch distro.Get() { - case distro.Synology: + if buildfeatures.HasSynology && buildfeatures.HasNetstack && distro.Get() == distro.Synology { // Try TUN, but fall back to userspace networking if needed. // See https://github.com/tailscale/tailscale-synology/issues/35 return "tailscale0,userspace-networking" } - } return "tailscale0" } @@ -118,18 +112,20 @@ var args struct { // or comma-separated list thereof. tunname string - cleanUp bool - confFile string // empty, file path, or "vm:user-data" - debug string - port uint16 - statepath string - statedir string - socketpath string - birdSocketPath string - verbose int - socksAddr string // listen address for SOCKS5 server - httpProxyAddr string // listen address for HTTP proxy server - disableLogs bool + cleanUp bool + confFile string // empty, file path, or "vm:user-data" + debug string + port uint16 + statepath string + encryptState boolFlag + statedir string + socketpath string + birdSocketPath string + verbose int + socksAddr string // listen address for SOCKS5 server + httpProxyAddr string // listen address for HTTP proxy server + disableLogs bool + hardwareAttestation boolFlag } var ( @@ -145,15 +141,47 @@ var ( var subCommands = map[string]*func([]string) error{ "install-system-daemon": &installSystemDaemon, "uninstall-system-daemon": &uninstallSystemDaemon, - "debug": &debugModeFunc, "be-child": &beChildFunc, - "serve-taildrive": &serveDriveFunc, } -var beCLI func() // non-nil if CLI is linked in +var beCLI func() // non-nil if CLI is linked in with the "ts_include_cli" build tag + +// shouldRunCLI reports whether we should run the Tailscale CLI (cmd/tailscale) +// instead of the daemon (cmd/tailscaled) in the case when the two are linked +// together into one binary for space savings reasons. +func shouldRunCLI() bool { + if beCLI == nil { + // Not linked in with the "ts_include_cli" build tag. + return false + } + if len(os.Args) > 0 && filepath.Base(os.Args[0]) == "tailscale" { + // The binary was named (or hardlinked) as "tailscale". + return true + } + if envknob.Bool("TS_BE_CLI") { + // The environment variable was set to force it. + return true + } + return false +} + +// Outbound Proxy hooks +var ( + hookRegisterOutboundProxyFlags feature.Hook[func()] + hookOutboundProxyListen feature.Hook[func() proxyStartFunc] +) + +// proxyStartFunc is the type of the function returned by +// outboundProxyListen, to start the servers on the Listeners +// started by hookOutboundProxyListen. +type proxyStartFunc = func(logf logger.Logf, dialer *tsdial.Dialer) func main() { envknob.PanicIfAnyEnvCheckedInInit() + if shouldRunCLI() { + beCLI() + return + } envknob.ApplyDiskConfig() applyIntegrationTestEnvKnob() @@ -161,22 +189,32 @@ func main() { printVersion := false flag.IntVar(&args.verbose, "verbose", defaultVerbosity(), "log verbosity level; 0 is default, 1 or higher are increasingly verbose") flag.BoolVar(&args.cleanUp, "cleanup", false, "clean up system state and exit") - flag.StringVar(&args.debug, "debug", "", "listen address ([ip]:port) of optional debug server") - flag.StringVar(&args.socksAddr, "socks5-server", "", `optional [ip]:port to run a SOCK5 server (e.g. "localhost:1080")`) - flag.StringVar(&args.httpProxyAddr, "outbound-http-proxy-listen", "", `optional [ip]:port to run an outbound HTTP proxy (e.g. "localhost:8080")`) + if buildfeatures.HasDebug { + flag.StringVar(&args.debug, "debug", "", "listen address ([ip]:port) of optional debug server") + } flag.StringVar(&args.tunname, "tun", defaultTunName(), `tunnel interface name; use "userspace-networking" (beta) to not use TUN`) flag.Var(flagtype.PortValue(&args.port, defaultPort()), "port", "UDP port to listen on for WireGuard and peer-to-peer traffic; 0 means automatically select") flag.StringVar(&args.statepath, "state", "", "absolute path of state file; use 'kube:' to use Kubernetes secrets or 'arn:aws:ssm:...' to store in AWS SSM; use 'mem:' to not store state and register as an ephemeral node. If empty and --statedir is provided, the default is /tailscaled.state. Default: "+paths.DefaultTailscaledStateFile()) + if buildfeatures.HasTPM { + flag.Var(&args.encryptState, "encrypt-state", `encrypt the state file on disk; when not set encryption will be enabled if supported on this platform; uses TPM on Linux and Windows, on all other platforms this flag is not supported`) + } flag.StringVar(&args.statedir, "statedir", "", "path to directory for storage of config state, TLS certs, temporary incoming Taildrop files, etc. If empty, it's derived from --state when possible.") flag.StringVar(&args.socketpath, "socket", paths.DefaultTailscaledSocket(), "path of the service unix socket") - flag.StringVar(&args.birdSocketPath, "bird-socket", "", "path of the bird unix socket") + if buildfeatures.HasBird { + flag.StringVar(&args.birdSocketPath, "bird-socket", "", "path of the bird unix socket") + } flag.BoolVar(&printVersion, "version", false, "print version information and exit") flag.BoolVar(&args.disableLogs, "no-logs-no-support", false, "disable log uploads; this also disables any technical support") flag.StringVar(&args.confFile, "config", "", "path to config file, or 'vm:user-data' to use the VM's user-data (EC2)") + if buildfeatures.HasTPM { + flag.Var(&args.hardwareAttestation, "hardware-attestation", "use hardware-backed keys to bind node identity to this device when supported by the OS and hardware. Uses TPM 2.0 on Linux and Windows; SecureEnclave on macOS and iOS; and Keystore on Android") + } + if f, ok := hookRegisterOutboundProxyFlags.GetOk(); ok { + f() + } - if len(os.Args) > 0 && filepath.Base(os.Args[0]) == "tailscale" && beCLI != nil { - beCLI() - return + if runtime.GOOS == "plan9" && os.Getenv("_NETSHELL_CHILD_") != "" { + os.Args = []string{"tailscaled", "be-child", "plan9-netshell"} } if len(os.Args) > 1 { @@ -221,7 +259,7 @@ func main() { log.Fatalf("--socket is required") } - if args.birdSocketPath != "" && createBIRDClient == nil { + if buildfeatures.HasBird && args.birdSocketPath != "" && createBIRDClient == nil { log.SetFlags(0) log.Fatalf("--bird-socket is not supported on %s", runtime.GOOS) } @@ -229,7 +267,21 @@ func main() { // Only apply a default statepath when neither have been provided, so that a // user may specify only --statedir if they wish. if args.statepath == "" && args.statedir == "" { - args.statepath = paths.DefaultTailscaledStateFile() + if paths.MakeAutomaticStateDir() { + d := paths.DefaultTailscaledStateDir() + if d != "" { + args.statedir = d + if err := os.MkdirAll(d, 0700); err != nil { + log.Fatalf("failed to create state directory: %v", err) + } + } + } else { + args.statepath = paths.DefaultTailscaledStateFile() + } + } + + if buildfeatures.HasTPM { + handleTPMFlags() } if args.disableLogs { @@ -242,8 +294,10 @@ func main() { err := run() - // Remove file sharing from Windows shell (noop in non-windows) - osshare.SetFileSharingEnabled(false, logger.Discard) + if buildfeatures.HasTaildrop { + // Remove file sharing from Windows shell (noop in non-windows) + osshare.SetFileSharingEnabled(false, logger.Discard) + } if err != nil { log.Fatal(err) @@ -279,13 +333,17 @@ func trySynologyMigration(p string) error { } func statePathOrDefault() string { + var path string if args.statepath != "" { - return args.statepath + path = args.statepath + } + if path == "" && args.statedir != "" { + path = filepath.Join(args.statedir, "tailscaled.state") } - if args.statedir != "" { - return filepath.Join(args.statedir, "tailscaled.state") + if path != "" && !store.HasKnownProviderPrefix(path) && args.encryptState.v { + path = store.TPMPrefix + path } - return "" + return path } // serverOptions is the configuration of the Tailscale node agent. @@ -332,13 +390,15 @@ func ipnServerOpts() (o serverOptions) { return o } -var logPol *logpolicy.Policy +var logPol *logpolicy.Policy // or nil if not used var debugMux *http.ServeMux func run() (err error) { var logf logger.Logf = log.Printf - sys := new(tsd.System) + // Install an event bus as early as possible, so that it's + // available universally when setting up everything else. + sys := tsd.NewSystem() // Parse config, if specified, to fail early if it's invalid. var conf *conffile.Config @@ -353,24 +413,32 @@ func run() (err error) { var netMon *netmon.Monitor isWinSvc := isWindowsService() if !isWinSvc { - netMon, err = netmon.New(func(format string, args ...any) { - logf(format, args...) - }) + netMon, err = netmon.New(sys.Bus.Get(), logf) if err != nil { return fmt.Errorf("netmon.New: %w", err) } sys.Set(netMon) } - pol := logpolicy.New(logtail.CollectionNode, netMon, sys.HealthTracker(), nil /* use log.Printf */) - pol.SetVerbosityLevel(args.verbose) - logPol = pol - defer func() { - // Finish uploading logs after closing everything else. - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - pol.Shutdown(ctx) - }() + var publicLogID logid.PublicID + if buildfeatures.HasLogTail { + + pol := logpolicy.Options{ + Collection: logtail.CollectionNode, + NetMon: netMon, + Health: sys.HealthTracker.Get(), + Bus: sys.Bus.Get(), + }.New() + pol.SetVerbosityLevel(args.verbose) + publicLogID = pol.PublicID + logPol = pol + defer func() { + // Finish uploading logs after closing everything else. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + pol.Shutdown(ctx) + }() + } if err := envknob.ApplyDiskConfigError(); err != nil { log.Printf("Error reading environment config: %v", err) @@ -379,7 +447,7 @@ func run() (err error) { if isWinSvc { // Run the IPN server from the Windows service manager. log.Printf("Running service...") - if err := runWindowsService(pol); err != nil { + if err := runWindowsService(logPol); err != nil { log.Printf("runservice: %v", err) } log.Printf("Service ended.") @@ -397,7 +465,7 @@ func run() (err error) { // Always clean up, even if we're going to run the server. This covers cases // such as when a system was rebooted without shutting down, or tailscaled // crashed, and would for example restore system DNS configuration. - dns.CleanUp(logf, netMon, sys.HealthTracker(), args.tunname) + dns.CleanUp(logf, netMon, sys.Bus.Get(), sys.HealthTracker.Get(), args.tunname) router.CleanUp(logf, netMon, args.tunname) // If the cleanUp flag was passed, then exit. if args.cleanUp { @@ -411,21 +479,29 @@ func run() (err error) { log.Printf("error in synology migration: %v", err) } - if args.debug != "" { - debugMux = newDebugMux() + if buildfeatures.HasDebug && args.debug != "" { + debugMux = hookNewDebugMux.Get()() } - sys.Set(driveimpl.NewFileSystemForRemote(logf)) + if f, ok := hookSetSysDrive.GetOk(); ok { + f(sys, logf) + } if app := envknob.App(); app != "" { hostinfo.SetApp(app) } - return startIPNServer(context.Background(), logf, pol.PublicID, sys) + return startIPNServer(context.Background(), logf, publicLogID, sys) } +var ( + hookSetSysDrive feature.Hook[func(*tsd.System, logger.Logf)] + hookSetWgEnginConfigDrive feature.Hook[func(*wgengine.Config, logger.Logf)] +) + var sigPipe os.Signal // set by sigpipe.go +// logID may be the zero value if logging is not in use. func startIPNServer(ctx context.Context, logf logger.Logf, logID logid.PublicID, sys *tsd.System) error { ln, err := safesocket.Listen(args.socketpath) if err != nil { @@ -467,8 +543,8 @@ func startIPNServer(ctx context.Context, logf logger.Logf, logID logid.PublicID, } }() - srv := ipnserver.New(logf, logID, sys.NetMon.Get()) - if debugMux != nil { + srv := ipnserver.New(logf, logID, sys.Bus.Get(), sys.NetMon.Get()) + if buildfeatures.HasDebug && debugMux != nil { debugMux.HandleFunc("/debug/ipn", srv.ServeHTMLStatus) } var lbErr syncs.AtomicValue[error] @@ -519,82 +595,49 @@ func startIPNServer(ctx context.Context, logf logger.Logf, logID logid.PublicID, return nil } +var ( + hookNewNetstack feature.Hook[func(_ logger.Logf, _ *tsd.System, onlyNetstack bool) (tsd.NetstackImpl, error)] +) + +// logID may be the zero value if logging is not in use. func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID, sys *tsd.System) (_ *ipnlocal.LocalBackend, retErr error) { if logPol != nil { logPol.Logtail.SetNetMon(sys.NetMon.Get()) } - socksListener, httpProxyListener := mustStartProxyListeners(args.socksAddr, args.httpProxyAddr) + var startProxy proxyStartFunc + if listen, ok := hookOutboundProxyListen.GetOk(); ok { + startProxy = listen() + } dialer := &tsdial.Dialer{Logf: logf} // mutated below (before used) + dialer.SetBus(sys.Bus.Get()) sys.Set(dialer) onlyNetstack, err := createEngine(logf, sys) if err != nil { return nil, fmt.Errorf("createEngine: %w", err) } - if debugMux != nil { + if onlyNetstack && !buildfeatures.HasNetstack { + return nil, errors.New("userspace-networking support is not compiled in to this binary") + } + if buildfeatures.HasDebug && debugMux != nil { if ms, ok := sys.MagicSock.GetOK(); ok { debugMux.HandleFunc("/debug/magicsock", ms.ServeHTTPDebug) } - go runDebugServer(debugMux, args.debug) + go runDebugServer(logf, debugMux, args.debug) } - ns, err := newNetstack(logf, sys) - if err != nil { - return nil, fmt.Errorf("newNetstack: %w", err) + var ns tsd.NetstackImpl // or nil if not linked in + if newNetstack, ok := hookNewNetstack.GetOk(); ok { + ns, err = newNetstack(logf, sys, onlyNetstack) + if err != nil { + return nil, fmt.Errorf("newNetstack: %w", err) + } } - sys.Set(ns) - ns.ProcessLocalIPs = onlyNetstack - ns.ProcessSubnets = onlyNetstack || handleSubnetsInNetstack() - if onlyNetstack { - e := sys.Engine.Get() - dialer.UseNetstackForIP = func(ip netip.Addr) bool { - _, ok := e.PeerForIP(ip) - return ok - } - dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { - // Note: don't just return ns.DialContextTCP or we'll return - // *gonet.TCPConn(nil) instead of a nil interface which trips up - // callers. - tcpConn, err := ns.DialContextTCP(ctx, dst) - if err != nil { - return nil, err - } - return tcpConn, nil - } - dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { - // Note: don't just return ns.DialContextUDP or we'll return - // *gonet.UDPConn(nil) instead of a nil interface which trips up - // callers. - udpConn, err := ns.DialContextUDP(ctx, dst) - if err != nil { - return nil, err - } - return udpConn, nil - } - } - if socksListener != nil || httpProxyListener != nil { - var addrs []string - if httpProxyListener != nil { - hs := &http.Server{Handler: httpProxyHandler(dialer.UserDial)} - go func() { - log.Fatalf("HTTP proxy exited: %v", hs.Serve(httpProxyListener)) - }() - addrs = append(addrs, httpProxyListener.Addr().String()) - } - if socksListener != nil { - ss := &socks5.Server{ - Logf: logger.WithPrefix(logf, "socks5: "), - Dialer: dialer.UserDial, - } - go func() { - log.Fatalf("SOCKS5 server exited: %v", ss.Serve(socksListener)) - }() - addrs = append(addrs, socksListener.Addr().String()) - } - tshttpproxy.SetSelfProxy(addrs...) + if startProxy != nil { + go startProxy(logf, dialer) } opts := ipnServerOpts() @@ -620,17 +663,23 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID if root := lb.TailscaleVarRoot(); root != "" { dnsfallback.SetCachePath(filepath.Join(root, "derpmap.cached.json"), logf) } - lb.ConfigureWebClient(&tailscale.LocalClient{ - Socket: args.socketpath, - UseSocketOnly: args.socketpath != paths.DefaultTailscaledSocket(), - }) - configureTaildrop(logf, lb) - if err := ns.Start(lb); err != nil { - log.Fatalf("failed to start netstack: %v", err) + if f, ok := hookConfigureWebClient.GetOk(); ok { + f(lb) + } + + if ns != nil { + if err := ns.Start(lb); err != nil { + log.Fatalf("failed to start netstack: %v", err) + } + } + if buildfeatures.HasTPM && args.hardwareAttestation.v { + lb.SetHardwareAttested() } return lb, nil } +var hookConfigureWebClient feature.Hook[func(*ipnlocal.LocalBackend)] + // createEngine tries to the wgengine.Engine based on the order of tunnels // specified in the command line flags. // @@ -650,7 +699,7 @@ func createEngine(logf logger.Logf, sys *tsd.System) (onlyNetstack bool, err err logf("wgengine.NewUserspaceEngine(tun %q) error: %v", name, err) errs = append(errs, err) } - return false, multierr.New(errs...) + return false, errors.Join(errs...) } // handleSubnetsInNetstack reports whether netstack should handle subnet routers @@ -665,7 +714,7 @@ func handleSubnetsInNetstack() bool { return true } switch runtime.GOOS { - case "windows", "darwin", "freebsd", "openbsd": + case "windows", "darwin", "freebsd", "openbsd", "solaris", "illumos": // Enable on Windows and tailscaled-on-macOS (this doesn't // affect the GUI clients), and on FreeBSD. return true @@ -679,15 +728,18 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo conf := wgengine.Config{ ListenPort: args.port, NetMon: sys.NetMon.Get(), - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), Metrics: sys.UserMetricsRegistry(), Dialer: sys.Dialer.Get(), SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), - DriveForLocal: driveimpl.NewFileSystemForLocal(logf), + EventBus: sys.Bus.Get(), + } + if f, ok := hookSetWgEnginConfigDrive.GetOk(); ok { + f(&conf, logf) } - sys.HealthTracker().SetMetricsRegistry(sys.UserMetricsRegistry()) + sys.HealthTracker.Get().SetMetricsRegistry(sys.UserMetricsRegistry()) onlyNetstack = name == "userspace-networking" netstackSubnetRouter := onlyNetstack // but mutated later on some platforms @@ -708,7 +760,7 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo // configuration being unavailable (from the noop // manager). More in Issue 4017. // TODO(bradfitz): add a Synology-specific DNS manager. - conf.DNS, err = dns.NewOSConfigurator(logf, sys.HealthTracker(), sys.ControlKnobs(), "") // empty interface name + conf.DNS, err = dns.NewOSConfigurator(logf, sys.HealthTracker.Get(), sys.PolicyClientOrDefault(), sys.ControlKnobs(), "") // empty interface name if err != nil { return false, fmt.Errorf("dns.NewOSConfigurator: %w", err) } @@ -730,13 +782,19 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo return false, err } - r, err := router.New(logf, dev, sys.NetMon.Get(), sys.HealthTracker()) + if runtime.GOOS == "plan9" { + // TODO(bradfitz): why don't we do this on all platforms? + // We should. Doing it just on plan9 for now conservatively. + sys.NetMon.Get().SetTailscaleInterfaceName(devName) + } + + r, err := router.New(logf, dev, sys.NetMon.Get(), sys.HealthTracker.Get(), sys.Bus.Get()) if err != nil { dev.Close() return false, fmt.Errorf("creating router: %w", err) } - d, err := dns.NewOSConfigurator(logf, sys.HealthTracker(), sys.ControlKnobs(), devName) + d, err := dns.NewOSConfigurator(logf, sys.HealthTracker.Get(), sys.PolicyClientOrDefault(), sys.ControlKnobs(), devName) if err != nil { dev.Close() r.Close() @@ -760,96 +818,27 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo return onlyNetstack, nil } -func newDebugMux() *http.ServeMux { - mux := http.NewServeMux() - mux.HandleFunc("/debug/metrics", servePrometheusMetrics) - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - return mux -} - -func servePrometheusMetrics(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - varz.Handler(w, r) - clientmetric.WritePrometheusExpositionFormat(w) -} +var hookNewDebugMux feature.Hook[func() *http.ServeMux] -func runDebugServer(mux *http.ServeMux, addr string) { - srv := &http.Server{ - Addr: addr, - Handler: mux, - } - if err := srv.ListenAndServe(); err != nil { - log.Fatal(err) +func runDebugServer(logf logger.Logf, mux *http.ServeMux, addr string) { + if !buildfeatures.HasDebug { + return } -} - -func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) { - tfs, _ := sys.DriveForLocal.GetOK() - ret, err := netstack.Create(logf, - sys.Tun.Get(), - sys.Engine.Get(), - sys.MagicSock.Get(), - sys.Dialer.Get(), - sys.DNSManager.Get(), - sys.ProxyMapper(), - tfs, - ) + ln, err := net.Listen("tcp", addr) if err != nil { - return nil, err - } - // Only register debug info if we have a debug mux - if debugMux != nil { - expvar.Publish("netstack", ret.ExpVar()) + log.Fatalf("debug server: %v", err) } - return ret, nil -} - -// mustStartProxyListeners creates listeners for local SOCKS and HTTP -// proxies, if the respective addresses are not empty. socksAddr and -// httpAddr can be the same, in which case socksListener will receive -// connections that look like they're speaking SOCKS and httpListener -// will receive everything else. -// -// socksListener and httpListener can be nil, if their respective -// addrs are empty. -func mustStartProxyListeners(socksAddr, httpAddr string) (socksListener, httpListener net.Listener) { - if socksAddr == httpAddr && socksAddr != "" && !strings.HasSuffix(socksAddr, ":0") { - ln, err := net.Listen("tcp", socksAddr) - if err != nil { - log.Fatalf("proxy listener: %v", err) - } - return proxymux.SplitSOCKSAndHTTP(ln) + if strings.HasSuffix(addr, ":0") { + // Log kernel-selected port number so integration tests + // can find it portably. + logf("DEBUG-ADDR=%v", ln.Addr()) } - - var err error - if socksAddr != "" { - socksListener, err = net.Listen("tcp", socksAddr) - if err != nil { - log.Fatalf("SOCKS5 listener: %v", err) - } - if strings.HasSuffix(socksAddr, ":0") { - // Log kernel-selected port number so integration tests - // can find it portably. - log.Printf("SOCKS5 listening on %v", socksListener.Addr()) - } + srv := &http.Server{ + Handler: mux, } - if httpAddr != "" { - httpListener, err = net.Listen("tcp", httpAddr) - if err != nil { - log.Fatalf("HTTP proxy listener: %v", err) - } - if strings.HasSuffix(httpAddr, ":0") { - // Log kernel-selected port number so integration tests - // can find it portably. - log.Printf("HTTP proxy listening on %v", httpListener.Addr()) - } + if err := srv.Serve(ln); err != nil { + log.Fatal(err) } - - return socksListener, httpListener } var beChildFunc = beChild @@ -866,35 +855,6 @@ func beChild(args []string) error { return f(args[1:]) } -var serveDriveFunc = serveDrive - -// serveDrive serves one or more Taildrives on localhost using the WebDAV -// protocol. On UNIX and MacOS tailscaled environment, Taildrive spawns child -// tailscaled processes in serve-taildrive mode in order to access the fliesystem -// as specific (usually unprivileged) users. -// -// serveDrive prints the address on which it's listening to stdout so that the -// parent process knows where to connect to. -func serveDrive(args []string) error { - if len(args) == 0 { - return errors.New("missing shares") - } - if len(args)%2 != 0 { - return errors.New("need pairs") - } - s, err := driveimpl.NewFileServer() - if err != nil { - return fmt.Errorf("unable to start Taildrive file server: %v", err) - } - shares := make(map[string]string) - for i := 0; i < len(args); i += 2 { - shares[args[i]] = args[i+1] - } - s.SetShares(shares) - fmt.Printf("%v\n", s.Addr()) - return s.Serve() -} - // dieOnPipeReadErrorOfFD reads from the pipe named by fd and exit the process // when the pipe becomes readable. We use this in tests as a somewhat more // portable mechanism for the Linux PR_SET_PDEATHSIG, which we wish existed on @@ -926,3 +886,65 @@ func applyIntegrationTestEnvKnob() { } } } + +// handleTPMFlags validates the --encrypt-state and --hardware-attestation flags +// if set, and defaults both to on if supported and compatible with other +// settings. +func handleTPMFlags() { + switch { + case args.hardwareAttestation.v: + if _, err := key.NewEmptyHardwareAttestationKey(); err == key.ErrUnsupported { + log.SetFlags(0) + log.Fatalf("--hardware-attestation is not supported on this platform or in this build of tailscaled") + } + case !args.hardwareAttestation.set: + policyHWAttestation, _ := policyclient.Get().GetBoolean(pkey.HardwareAttestation, feature.HardwareAttestationAvailable()) + if !policyHWAttestation { + break + } + if feature.TPMAvailable() { + args.hardwareAttestation.v = true + } + } + + switch { + case args.encryptState.v: + // Explicitly enabled, validate. + if err := canEncryptState(); err != nil { + log.SetFlags(0) + log.Fatal(err) + } + case !args.encryptState.set: + policyEncrypt, _ := policyclient.Get().GetBoolean(pkey.EncryptState, feature.TPMAvailable()) + if !policyEncrypt { + // Default disabled, no need to validate. + return + } + // Default enabled if available. + if err := canEncryptState(); err == nil { + args.encryptState.v = true + } + } +} + +// canEncryptState returns an error if state encryption can't be enabled, +// either due to availability or compatibility with other settings. +func canEncryptState() error { + if runtime.GOOS != "windows" && runtime.GOOS != "linux" { + // TPM encryption is only configurable on Windows and Linux. Other + // platforms either use system APIs and are not configurable + // (Android/Apple), or don't support any form of encryption yet + // (plan9/FreeBSD/etc). + return fmt.Errorf("--encrypt-state is not supported on %s", runtime.GOOS) + } + // Check if we have TPM access. + if !feature.TPMAvailable() { + return errors.New("--encrypt-state is not supported on this device or a TPM is not accessible") + } + // Check for conflicting prefix in --state, like arn: or kube:. + if args.statepath != "" && store.HasKnownProviderPrefix(args.statepath) { + return errors.New("--encrypt-state can only be used with --state set to a local file path") + } + + return nil +} diff --git a/cmd/tailscaled/tailscaled_drive.go b/cmd/tailscaled/tailscaled_drive.go new file mode 100644 index 000000000..49f35a381 --- /dev/null +++ b/cmd/tailscaled/tailscaled_drive.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive + +package main + +import ( + "errors" + "fmt" + + "tailscale.com/drive/driveimpl" + "tailscale.com/tsd" + "tailscale.com/types/logger" + "tailscale.com/wgengine" +) + +func init() { + subCommands["serve-taildrive"] = &serveDriveFunc + + hookSetSysDrive.Set(func(sys *tsd.System, logf logger.Logf) { + sys.Set(driveimpl.NewFileSystemForRemote(logf)) + }) + hookSetWgEnginConfigDrive.Set(func(conf *wgengine.Config, logf logger.Logf) { + conf.DriveForLocal = driveimpl.NewFileSystemForLocal(logf) + }) +} + +var serveDriveFunc = serveDrive + +// serveDrive serves one or more Taildrives on localhost using the WebDAV +// protocol. On UNIX and MacOS tailscaled environment, Taildrive spawns child +// tailscaled processes in serve-taildrive mode in order to access the fliesystem +// as specific (usually unprivileged) users. +// +// serveDrive prints the address on which it's listening to stdout so that the +// parent process knows where to connect to. +func serveDrive(args []string) error { + if len(args) == 0 { + return errors.New("missing shares") + } + if len(args)%2 != 0 { + return errors.New("need pairs") + } + s, err := driveimpl.NewFileServer() + if err != nil { + return fmt.Errorf("unable to start Taildrive file server: %v", err) + } + shares := make(map[string]string) + for i := 0; i < len(args); i += 2 { + shares[args[i]] = args[i+1] + } + s.SetShares(shares) + fmt.Printf("%v\n", s.Addr()) + return s.Serve() +} diff --git a/cmd/tailscaled/tailscaled_test.go b/cmd/tailscaled/tailscaled_test.go index 5045468d6..c50c23759 100644 --- a/cmd/tailscaled/tailscaled_test.go +++ b/cmd/tailscaled/tailscaled_test.go @@ -22,6 +22,8 @@ func TestDeps(t *testing.T) { BadDeps: map[string]string{ "testing": "do not use testing package in production code", "gvisor.dev/gvisor/pkg/hostarch": "will crash on non-4K page sizes; see https://github.com/tailscale/tailscale/issues/8658", + "net/http/httptest": "do not use httptest in production code", + "net/http/internal/testcert": "do not use httptest in production code", }, }.Check(t) @@ -29,8 +31,10 @@ func TestDeps(t *testing.T) { GOOS: "linux", GOARCH: "arm64", BadDeps: map[string]string{ - "testing": "do not use testing package in production code", - "gvisor.dev/gvisor/pkg/hostarch": "will crash on non-4K page sizes; see https://github.com/tailscale/tailscale/issues/8658", + "testing": "do not use testing package in production code", + "gvisor.dev/gvisor/pkg/hostarch": "will crash on non-4K page sizes; see https://github.com/tailscale/tailscale/issues/8658", + "google.golang.org/protobuf/proto": "unexpected", + "github.com/prometheus/client_golang/prometheus": "use tailscale.com/metrics in tailscaled", }, }.Check(t) } diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 35c878f38..3019bbaf9 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -44,17 +44,21 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/drive/driveimpl" "tailscale.com/envknob" + _ "tailscale.com/ipn/auditlog" + _ "tailscale.com/ipn/desktop" "tailscale.com/logpolicy" - "tailscale.com/logtail/backoff" "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/net/tstun" "tailscale.com/tsd" "tailscale.com/types/logger" "tailscale.com/types/logid" + "tailscale.com/util/backoff" "tailscale.com/util/osdiag" - "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/util/winutil" + "tailscale.com/util/winutil/gp" "tailscale.com/version" "tailscale.com/wf" ) @@ -70,6 +74,22 @@ func init() { } } +// permitPolicyLocks is a function to be called to lift the restriction on acquiring +// [gp.PolicyLock]s once the service is running. +// It is safe to be called multiple times. +var permitPolicyLocks = func() {} + +func init() { + if isWindowsService() { + // We prevent [gp.PolicyLock]s from being acquired until the service enters the running state. + // Otherwise, if tailscaled starts due to a GPSI policy installing Tailscale, it may deadlock + // while waiting for the write counterpart of the GP lock to be released by Group Policy, + // which is itself waiting for the installation to complete and tailscaled to start. + // See tailscale/tailscale#14416 for more information. + permitPolicyLocks = gp.RestrictPolicyLocks() + } +} + const serviceName = "Tailscale" // Application-defined command codes between 128 and 255 @@ -109,13 +129,13 @@ func tstunNewWithWindowsRetries(logf logger.Logf, tunName string) (_ tun.Device, } } -func isWindowsService() bool { +var isWindowsService = sync.OnceValue(func() bool { v, err := svc.IsWindowsService() if err != nil { log.Fatalf("svc.IsWindowsService failed: %v", err) } return v -} +}) // syslogf is a logger function that writes to the Windows event log (ie, the // one that you see in the Windows Event Viewer). tailscaled may optionally @@ -129,19 +149,20 @@ var syslogf logger.Logf = logger.Discard // // At this point we're still the parent process that // Windows started. +// +// pol may be nil. func runWindowsService(pol *logpolicy.Policy) error { go func() { logger.Logf(log.Printf).JSON(1, "SupportInfo", osdiag.SupportInfo(osdiag.LogSupportInfoReasonStartup)) }() - if logSCMInteractions, _ := syspolicy.GetBoolean(syspolicy.LogSCMInteractions, false); logSCMInteractions { - syslog, err := eventlog.Open(serviceName) - if err == nil { - syslogf = func(format string, args ...any) { + if syslog, err := eventlog.Open(serviceName); err == nil { + syslogf = func(format string, args ...any) { + if logSCMInteractions, _ := policyclient.Get().GetBoolean(pkey.LogSCMInteractions, false); logSCMInteractions { syslog.Info(0, fmt.Sprintf(format, args...)) } - defer syslog.Close() } + defer syslog.Close() } syslogf("Service entering svc.Run") @@ -150,7 +171,7 @@ func runWindowsService(pol *logpolicy.Policy) error { } type ipnService struct { - Policy *logpolicy.Policy + Policy *logpolicy.Policy // or nil if logging not in use } // Called by Windows to execute the windows service. @@ -160,17 +181,18 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch changes <- svc.Status{State: svc.StartPending} syslogf("Service start pending") - svcAccepts := svc.AcceptStop - if flushDNSOnSessionUnlock, _ := syspolicy.GetBoolean(syspolicy.FlushDNSOnSessionUnlock, false); flushDNSOnSessionUnlock { - svcAccepts |= svc.AcceptSessionChange - } + svcAccepts := svc.AcceptStop | svc.AcceptSessionChange ctx, cancel := context.WithCancel(context.Background()) defer cancel() doneCh := make(chan struct{}) go func() { defer close(doneCh) - args := []string{"/subproc", service.Policy.PublicID.String()} + publicID := "none" + if service.Policy != nil { + publicID = service.Policy.PublicID.String() + } + args := []string{"/subproc", publicID} // Make a logger without a date prefix, as filelogger // and logtail both already add their own. All we really want // from the log package is the automatic newline. @@ -184,6 +206,10 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch changes <- svc.Status{State: svc.Running, Accepts: svcAccepts} syslogf("Service running") + // It is safe to allow GP locks to be acquired now that the service + // is running. + permitPolicyLocks() + for { select { case <-doneCh: @@ -309,8 +335,8 @@ func beWindowsSubprocess() bool { log.Printf("Error pre-loading \"%s\": %v", fqWintunPath, err) } - sys := new(tsd.System) - netMon, err := netmon.New(log.Printf) + sys := tsd.NewSystem() + netMon, err := netmon.New(sys.Bus.Get(), log.Printf) if err != nil { log.Fatalf("Could not create netMon: %v", err) } @@ -370,14 +396,15 @@ func handleSessionChange(chgRequest svc.ChangeRequest) { if chgRequest.Cmd != svc.SessionChange || chgRequest.EventType != windows.WTS_SESSION_UNLOCK { return } - - log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.") - go func() { - err := dns.Flush() - if err != nil { - log.Printf("Error flushing DNS on session unlock: %v", err) - } - }() + if flushDNSOnSessionUnlock, _ := policyclient.Get().GetBoolean(pkey.FlushDNSOnSessionUnlock, false); flushDNSOnSessionUnlock { + log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.") + go func() { + err := dns.Flush() + if err != nil { + log.Printf("Error flushing DNS on session unlock: %v", err) + } + }() + } } var ( diff --git a/cmd/tailscaled/tailscaledhooks/tailscaledhooks.go b/cmd/tailscaled/tailscaledhooks/tailscaledhooks.go new file mode 100644 index 000000000..6ea662d39 --- /dev/null +++ b/cmd/tailscaled/tailscaledhooks/tailscaledhooks.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tailscaledhooks provides hooks for optional features +// to add to during init that tailscaled calls at runtime. +package tailscaledhooks + +import "tailscale.com/feature" + +// UninstallSystemDaemonWindows is called when the Windows +// system daemon is uninstalled. +var UninstallSystemDaemonWindows feature.Hooks[func()] diff --git a/cmd/tailscaled/webclient.go b/cmd/tailscaled/webclient.go new file mode 100644 index 000000000..672ba7126 --- /dev/null +++ b/cmd/tailscaled/webclient.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_webclient + +package main + +import ( + "tailscale.com/client/local" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/paths" +) + +func init() { + hookConfigureWebClient.Set(func(lb *ipnlocal.LocalBackend) { + lb.ConfigureWebClient(&local.Client{ + Socket: args.socketpath, + UseSocketOnly: args.socketpath != paths.DefaultTailscaledSocket(), + }) + }) +} diff --git a/cmd/testwrapper/flakytest/flakytest.go b/cmd/testwrapper/flakytest/flakytest.go index 494ed080b..856cb28ef 100644 --- a/cmd/testwrapper/flakytest/flakytest.go +++ b/cmd/testwrapper/flakytest/flakytest.go @@ -9,8 +9,12 @@ package flakytest import ( "fmt" "os" + "path" "regexp" + "sync" "testing" + + "tailscale.com/util/mak" ) // FlakyTestLogMessage is a sentinel value that is printed to stderr when a @@ -23,7 +27,12 @@ const FlakyTestLogMessage = "flakytest: this is a known flaky test" // starting at 1. const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" -var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) +var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+/issues/\d+\z`) + +var ( + rootFlakesMu sync.Mutex + rootFlakes map[string]bool +) // Mark sets the current test as a flaky test, such that if it fails, it will // be retried a few times on failure. issue must be a GitHub issue that tracks @@ -40,5 +49,34 @@ func Mark(t testing.TB, issue string) { // spamming people running tests without the wrapper) fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) } + t.Attr("flaky-test-issue-url", issue) + + // The Attr method above also emits human-readable output, so this t.Logf + // is somewhat redundant, but we keep it for compatibility with + // old test runs, so cmd/testwrapper doesn't need to be modified. + // TODO(bradfitz): switch testwrapper to look for Action "attr" + // instead: + // "Action":"attr","Package":"tailscale.com/cmd/testwrapper/flakytest","Test":"TestMarked_Root","Key":"flaky-test-issue-url","Value":"https://github.com/tailscale/tailscale/issues/0"} + // And then remove this Logf a month or so after that. t.Logf("flakytest: issue tracking this flaky test: %s", issue) + + // Record the root test name as flakey. + rootFlakesMu.Lock() + defer rootFlakesMu.Unlock() + mak.Set(&rootFlakes, t.Name(), true) +} + +// Marked reports whether the current test or one of its parents was marked flaky. +func Marked(t testing.TB) bool { + n := t.Name() + for { + if rootFlakes[n] { + return true + } + n = path.Dir(n) + if n == "." || n == "/" { + break + } + } + return false } diff --git a/cmd/testwrapper/flakytest/flakytest_test.go b/cmd/testwrapper/flakytest/flakytest_test.go index 85e77a939..9b744de13 100644 --- a/cmd/testwrapper/flakytest/flakytest_test.go +++ b/cmd/testwrapper/flakytest/flakytest_test.go @@ -14,7 +14,8 @@ func TestIssueFormat(t *testing.T) { want bool }{ {"https://github.com/tailscale/cOrp/issues/1234", true}, - {"https://github.com/otherproject/corp/issues/1234", false}, + {"https://github.com/otherproject/corp/issues/1234", true}, + {"https://not.huyb/tailscale/corp/issues/1234", false}, {"https://github.com/tailscale/corp/issues/", false}, } for _, testCase := range testCases { @@ -41,3 +42,49 @@ func TestFlakeRun(t *testing.T) { t.Fatal("First run in testwrapper, failing so that test is retried. This is expected.") } } + +func TestMarked_Root(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") + + t.Run("child", func(t *testing.T) { + t.Run("grandchild", func(t *testing.T) { + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } +} + +func TestMarked_Subtest(t *testing.T) { + t.Run("flaky", func(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") + + t.Run("child", func(t *testing.T) { + t.Run("grandchild", func(t *testing.T) { + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), false; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } +} diff --git a/cmd/testwrapper/testwrapper.go b/cmd/testwrapper/testwrapper.go index 9b8d7a7c1..173edee73 100644 --- a/cmd/testwrapper/testwrapper.go +++ b/cmd/testwrapper/testwrapper.go @@ -10,6 +10,7 @@ package main import ( "bufio" "bytes" + "cmp" "context" "encoding/json" "errors" @@ -22,15 +23,9 @@ import ( "sort" "strings" "time" - "unicode" - "github.com/dave/courtney/scanner" - "github.com/dave/courtney/shared" - "github.com/dave/courtney/tester" - "github.com/dave/patsy" - "github.com/dave/patsy/vos" - xmaps "golang.org/x/exp/maps" "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/util/slicesx" ) const ( @@ -42,6 +37,7 @@ type testAttempt struct { testName string // "TestFoo" outcome string // "pass", "fail", "skip" logs bytes.Buffer + start, end time.Time isMarkedFlaky bool // set if the test is marked as flaky issueURL string // set if the test is marked as flaky @@ -64,11 +60,12 @@ type packageTests struct { } type goTestOutput struct { - Time time.Time - Action string - Package string - Test string - Output string + Time time.Time + Action string + ImportPath string + Package string + Test string + Output string } var debug = os.Getenv("TS_TESTWRAPPER_DEBUG") != "" @@ -116,43 +113,56 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te for s.Scan() { var goOutput goTestOutput if err := json.Unmarshal(s.Bytes(), &goOutput); err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) { - break - } - - // `go test -json` outputs invalid JSON when a build fails. - // In that case, discard the the output and start reading again. - // The build error will be printed to stderr. - // See: https://github.com/golang/go/issues/35169 - if _, ok := err.(*json.SyntaxError); ok { - fmt.Println(s.Text()) - continue - } - panic(err) + return fmt.Errorf("failed to parse go test output %q: %w", s.Bytes(), err) } - pkg := goOutput.Package + pkg := cmp.Or( + goOutput.Package, + "build:"+goOutput.ImportPath, // can be "./cmd" while Package is "tailscale.com/cmd" so use separate namespace + ) pkgTests := resultMap[pkg] + if pkgTests == nil { + pkgTests = map[string]*testAttempt{ + "": {}, // Used for start time and build logs. + } + resultMap[pkg] = pkgTests + } if goOutput.Test == "" { switch goOutput.Action { - case "fail", "pass", "skip": + case "start": + pkgTests[""].start = goOutput.Time + case "build-output": + pkgTests[""].logs.WriteString(goOutput.Output) + case "build-fail", "fail", "pass", "skip": for _, test := range pkgTests { - if test.outcome == "" { + if test.testName != "" && test.outcome == "" { test.outcome = "fail" ch <- test } } + outcome := goOutput.Action + if outcome == "build-fail" { + outcome = "fail" + } + pkgTests[""].logs.WriteString(goOutput.Output) ch <- &testAttempt{ pkg: goOutput.Package, - outcome: goOutput.Action, + outcome: outcome, + start: pkgTests[""].start, + end: goOutput.Time, + logs: pkgTests[""].logs, pkgFinished: true, } + case "output": + // Capture all output from the package except for the final + // "FAIL tailscale.io/control 0.684s" line, as + // printPkgOutcome will output a similar line + if !strings.HasPrefix(goOutput.Output, fmt.Sprintf("FAIL\t%s\t", goOutput.Package)) { + pkgTests[""].logs.WriteString(goOutput.Output) + } } + continue } - if pkgTests == nil { - pkgTests = make(map[string]*testAttempt) - resultMap[pkg] = pkgTests - } testName := goOutput.Test if test, _, isSubtest := strings.Cut(goOutput.Test, "/"); isSubtest { testName = test @@ -168,8 +178,10 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te pkgTests[testName] = &testAttempt{ pkg: pkg, testName: testName, + start: goOutput.Time, } case "skip", "pass", "fail": + pkgTests[testName].end = goOutput.Time pkgTests[testName].outcome = goOutput.Action ch <- pkgTests[testName] case "output": @@ -201,6 +213,16 @@ func main() { return } + // As a special case, if the packages looks like "sharded:1/2" then shell out to + // ./tool/listpkgs to cut up the package list pieces for each sharded builder. + if nOfM, ok := strings.CutPrefix(packages[0], "sharded:"); ok && len(packages) == 1 { + out, err := exec.Command("go", "run", "tailscale.com/tool/listpkgs", "-shard", nOfM, "./...").Output() + if err != nil { + log.Fatalf("failed to list packages for sharded test: %v", err) + } + packages = strings.Split(strings.TrimSpace(string(out)), "\n") + } + ctx := context.Background() type nextRun struct { tests []*packageTests @@ -213,7 +235,10 @@ func main() { firstRun.tests = append(firstRun.tests, &packageTests{Pattern: pkg}) } toRun := []*nextRun{firstRun} - printPkgOutcome := func(pkg, outcome string, attempt int) { + printPkgOutcome := func(pkg, outcome string, attempt int, runtime time.Duration) { + if pkg == "" { + return // We reach this path on a build error. + } if outcome == "skip" { fmt.Printf("?\t%s [skipped/no tests] \n", pkg) return @@ -225,36 +250,12 @@ func main() { outcome = "FAIL" } if attempt > 1 { - fmt.Printf("%s\t%s [attempt=%d]\n", outcome, pkg, attempt) + fmt.Printf("%s\t%s\t%.3fs\t[attempt=%d]\n", outcome, pkg, runtime.Seconds(), attempt) return } - fmt.Printf("%s\t%s\n", outcome, pkg) + fmt.Printf("%s\t%s\t%.3fs\n", outcome, pkg, runtime.Seconds()) } - // Check for -coverprofile argument and filter it out - combinedCoverageFilename := "" - filteredGoTestArgs := make([]string, 0, len(goTestArgs)) - preceededByCoverProfile := false - for _, arg := range goTestArgs { - if arg == "-coverprofile" { - preceededByCoverProfile = true - } else if preceededByCoverProfile { - combinedCoverageFilename = strings.TrimSpace(arg) - preceededByCoverProfile = false - } else { - filteredGoTestArgs = append(filteredGoTestArgs, arg) - } - } - goTestArgs = filteredGoTestArgs - - runningWithCoverage := combinedCoverageFilename != "" - if runningWithCoverage { - fmt.Printf("Will log coverage to %v\n", combinedCoverageFilename) - } - - // Keep track of all test coverage files. With each retry, we'll end up - // with additional coverage files that will be combined when we finish. - coverageFiles := make([]string, 0) for len(toRun) > 0 { var thisRun *nextRun thisRun, toRun = toRun[0], toRun[1:] @@ -268,27 +269,14 @@ func main() { fmt.Printf("\n\nAttempt #%d: Retrying flaky tests:\n\nflakytest failures JSON: %s\n\n", thisRun.attempt, j) } - goTestArgsWithCoverage := testArgs - if runningWithCoverage { - coverageFile := fmt.Sprintf("/tmp/coverage_%d.out", thisRun.attempt) - coverageFiles = append(coverageFiles, coverageFile) - goTestArgsWithCoverage = make([]string, len(goTestArgs), len(goTestArgs)+2) - copy(goTestArgsWithCoverage, goTestArgs) - goTestArgsWithCoverage = append( - goTestArgsWithCoverage, - fmt.Sprintf("-coverprofile=%v", coverageFile), - "-covermode=set", - "-coverpkg=./...", - ) - } - + fatalFailures := make(map[string]struct{}) // pkg.Test key toRetry := make(map[string][]*testAttempt) // pkg -> tests to retry for _, pt := range thisRun.tests { ch := make(chan *testAttempt) runErr := make(chan error, 1) go func() { defer close(runErr) - runErr <- runTests(ctx, thisRun.attempt, pt, goTestArgsWithCoverage, testArgs, ch) + runErr <- runTests(ctx, thisRun.attempt, pt, goTestArgs, testArgs, ch) }() var failed bool @@ -307,7 +295,12 @@ func main() { // when a package times out. failed = true } - printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt) + if testingVerbose || tr.outcome == "fail" { + // Output package-level output which is where e.g. + // panics outside tests will be printed + io.Copy(os.Stdout, &tr.logs) + } + printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt, tr.end.Sub(tr.start)) continue } if testingVerbose || tr.outcome == "fail" { @@ -319,11 +312,24 @@ func main() { if tr.isMarkedFlaky { toRetry[tr.pkg] = append(toRetry[tr.pkg], tr) } else { + fatalFailures[tr.pkg+"."+tr.testName] = struct{}{} failed = true } } if failed { fmt.Println("\n\nNot retrying flaky tests because non-flaky tests failed.") + + // Print the list of non-flakytest failures. + // We will later analyze the retried GitHub Action runs to see + // if non-flakytest failures succeeded upon retry. This will + // highlight tests which are flaky but not yet flagged as such. + if len(fatalFailures) > 0 { + tests := slicesx.MapKeys(fatalFailures) + sort.Strings(tests) + j, _ := json.Marshal(tests) + fmt.Printf("non-flakytest failures: %s\n", j) + } + fmt.Println() os.Exit(1) } @@ -343,7 +349,7 @@ func main() { if len(toRetry) == 0 { continue } - pkgs := xmaps.Keys(toRetry) + pkgs := slicesx.MapKeys(toRetry) sort.Strings(pkgs) nextRun := &nextRun{ attempt: thisRun.attempt + 1, @@ -365,107 +371,4 @@ func main() { } toRun = append(toRun, nextRun) } - - if runningWithCoverage { - intermediateCoverageFilename := "/tmp/coverage.out_intermediate" - if err := combineCoverageFiles(intermediateCoverageFilename, coverageFiles); err != nil { - fmt.Printf("error combining coverage files: %v\n", err) - os.Exit(2) - } - - if err := processCoverageWithCourtney(intermediateCoverageFilename, combinedCoverageFilename, testArgs); err != nil { - fmt.Printf("error processing coverage with courtney: %v\n", err) - os.Exit(3) - } - - fmt.Printf("Wrote combined coverage to %v\n", combinedCoverageFilename) - } -} - -func combineCoverageFiles(intermediateCoverageFilename string, coverageFiles []string) error { - combinedCoverageFile, err := os.OpenFile(intermediateCoverageFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return fmt.Errorf("create /tmp/coverage.out: %w", err) - } - defer combinedCoverageFile.Close() - w := bufio.NewWriter(combinedCoverageFile) - defer w.Flush() - - for fileNumber, coverageFile := range coverageFiles { - f, err := os.Open(coverageFile) - if err != nil { - return fmt.Errorf("open %v: %w", coverageFile, err) - } - defer f.Close() - in := bufio.NewReader(f) - line := 0 - for { - r, _, err := in.ReadRune() - if err != nil { - if err != io.EOF { - return fmt.Errorf("read %v: %w", coverageFile, err) - } - break - } - - // On all but the first coverage file, skip the coverage file header - if fileNumber > 0 && line == 0 { - continue - } - if r == '\n' { - line++ - } - - // filter for only printable characters because coverage file sometimes includes junk on 2nd line - if unicode.IsPrint(r) || r == '\n' { - if _, err := w.WriteRune(r); err != nil { - return fmt.Errorf("write %v: %w", combinedCoverageFile.Name(), err) - } - } - } - } - - return nil -} - -// processCoverageWithCourtney post-processes code coverage to exclude less -// meaningful sections like 'if err != nil { return err}', as well as -// anything marked with a '// notest' comment. -// -// instead of running the courtney as a separate program, this embeds -// courtney for easier integration. -func processCoverageWithCourtney(intermediateCoverageFilename, combinedCoverageFilename string, testArgs []string) error { - env := vos.Os() - - setup := &shared.Setup{ - Env: vos.Os(), - Paths: patsy.NewCache(env), - TestArgs: testArgs, - Load: intermediateCoverageFilename, - Output: combinedCoverageFilename, - } - if err := setup.Parse(testArgs); err != nil { - return fmt.Errorf("parse args: %w", err) - } - - s := scanner.New(setup) - if err := s.LoadProgram(); err != nil { - return fmt.Errorf("load program: %w", err) - } - if err := s.ScanPackages(); err != nil { - return fmt.Errorf("scan packages: %w", err) - } - - t := tester.New(setup) - if err := t.Load(); err != nil { - return fmt.Errorf("load: %w", err) - } - if err := t.ProcessExcludes(s.Excludes); err != nil { - return fmt.Errorf("process excludes: %w", err) - } - if err := t.Save(); err != nil { - return fmt.Errorf("save: %w", err) - } - - return nil } diff --git a/cmd/testwrapper/testwrapper_test.go b/cmd/testwrapper/testwrapper_test.go index d7dbccd09..ace53ccd0 100644 --- a/cmd/testwrapper/testwrapper_test.go +++ b/cmd/testwrapper/testwrapper_test.go @@ -10,6 +10,8 @@ import ( "os" "os/exec" "path/filepath" + "regexp" + "strings" "sync" "testing" ) @@ -76,7 +78,10 @@ func TestFlakeRun(t *testing.T) { t.Fatalf("go run . %s: %s with output:\n%s", testfile, err, out) } - want := []byte("ok\t" + testfile + " [attempt=2]") + // Replace the unpredictable timestamp with "0.00s". + out = regexp.MustCompile(`\t\d+\.\d\d\ds\t`).ReplaceAll(out, []byte("\t0.00s\t")) + + want := []byte("ok\t" + testfile + "\t0.00s\t[attempt=2]") if !bytes.Contains(out, want) { t.Fatalf("wanted output containing %q but got:\n%s", want, out) } @@ -150,24 +155,24 @@ func TestBuildError(t *testing.T) { t.Fatalf("writing package: %s", err) } - buildErr := []byte("builderror_test.go:3:1: expected declaration, found derp\nFAIL command-line-arguments [setup failed]") + wantErr := "builderror_test.go:3:1: expected declaration, found derp\nFAIL" // Confirm `go test` exits with code 1. goOut, err := exec.Command("go", "test", testfile).CombinedOutput() if code, ok := errExitCode(err); !ok || code != 1 { - t.Fatalf("go test %s: expected error with exit code 0 but got: %v", testfile, err) + t.Fatalf("go test %s: got exit code %d, want 1 (err: %v)", testfile, code, err) } - if !bytes.Contains(goOut, buildErr) { - t.Fatalf("go test %s: expected build error containing %q but got:\n%s", testfile, buildErr, goOut) + if !strings.Contains(string(goOut), wantErr) { + t.Fatalf("go test %s: got output %q, want output containing %q", testfile, goOut, wantErr) } // Confirm `testwrapper` exits with code 1. twOut, err := cmdTestwrapper(t, testfile).CombinedOutput() if code, ok := errExitCode(err); !ok || code != 1 { - t.Fatalf("testwrapper %s: expected error with exit code 0 but got: %v", testfile, err) + t.Fatalf("testwrapper %s: got exit code %d, want 1 (err: %v)", testfile, code, err) } - if !bytes.Contains(twOut, buildErr) { - t.Fatalf("testwrapper %s: expected build error containing %q but got:\n%s", testfile, buildErr, twOut) + if !strings.Contains(string(twOut), wantErr) { + t.Fatalf("testwrapper %s: got output %q, want output containing %q", testfile, twOut, wantErr) } if testing.Verbose() { diff --git a/cmd/tl-longchain/tl-longchain.go b/cmd/tl-longchain/tl-longchain.go index c92714505..384d24222 100644 --- a/cmd/tl-longchain/tl-longchain.go +++ b/cmd/tl-longchain/tl-longchain.go @@ -22,7 +22,7 @@ import ( "log" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn/ipnstate" "tailscale.com/tka" "tailscale.com/types/key" @@ -37,7 +37,7 @@ var ( func main() { flag.Parse() - lc := tailscale.LocalClient{Socket: *flagSocket} + lc := local.Client{Socket: *flagSocket} if lc.Socket != "" { lc.UseSocketOnly = true } @@ -75,8 +75,8 @@ func peerInfo(peer *ipnstate.TKAPeer) string { // print prints a message about a node key signature and a re-signing command if needed. func print(info string, nodeKey key.NodePublic, sig tka.NodeKeySignature) { - if l := chainLength(sig); l > *maxRotations { - log.Printf("%s: chain length %d, printing command to re-sign", info, l) + if ln := chainLength(sig); ln > *maxRotations { + log.Printf("%s: chain length %d, printing command to re-sign", info, ln) wrapping, _ := sig.UnverifiedWrappingPublic() fmt.Printf("tailscale lock sign %s %s\n", nodeKey, key.NLPublicFromEd25519Unsafe(wrapping).CLIString()) } else { diff --git a/cmd/tsconnect/common.go b/cmd/tsconnect/common.go index a387c00c9..ff10e4efb 100644 --- a/cmd/tsconnect/common.go +++ b/cmd/tsconnect/common.go @@ -150,6 +150,7 @@ func runEsbuildServe(buildOptions esbuild.BuildOptions) { log.Fatalf("Cannot start esbuild server: %v", err) } log.Printf("Listening on http://%s:%d\n", result.Host, result.Port) + select {} } func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult { @@ -175,6 +176,10 @@ func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult { // wasm_exec.js runtime helper library from the Go toolchain. func setupEsbuildWasmExecJS(build esbuild.PluginBuild) { wasmExecSrcPath := filepath.Join(runtime.GOROOT(), "misc", "wasm", "wasm_exec.js") + if _, err := os.Stat(wasmExecSrcPath); os.IsNotExist(err) { + // Go 1.24+ location: + wasmExecSrcPath = filepath.Join(runtime.GOROOT(), "lib", "wasm", "wasm_exec.js") + } build.OnResolve(esbuild.OnResolveOptions{ Filter: "./wasm_exec$", }, func(args esbuild.OnResolveArgs) (esbuild.OnResolveResult, error) { diff --git a/cmd/tsconnect/tsconnect.go b/cmd/tsconnect/tsconnect.go index 4c8a0a52e..ef55593b4 100644 --- a/cmd/tsconnect/tsconnect.go +++ b/cmd/tsconnect/tsconnect.go @@ -53,12 +53,12 @@ func main() { } func usage() { - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` usage: tsconnect {dev|build|serve} `[1:]) flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` tsconnect implements development/build/serving workflows for Tailscale Connect. It can be invoked with one of three subcommands: diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 8291ac9b4..c7aa00d1d 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -27,6 +27,7 @@ import ( "golang.org/x/crypto/ssh" "tailscale.com/control/controlclient" "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnserver" "tailscale.com/ipn/store/mem" @@ -100,21 +101,24 @@ func newIPN(jsConfig js.Value) map[string]any { logtail := logtail.NewLogger(c, log.Printf) logf := logtail.Logf - sys := new(tsd.System) + sys := tsd.NewSystem() sys.Set(store) dialer := &tsdial.Dialer{Logf: logf} + dialer.SetBus(sys.Bus.Get()) eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ Dialer: dialer, SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), + Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { log.Fatal(err) } sys.Set(eng) - ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { log.Fatalf("netstack.Create: %v", err) } @@ -128,11 +132,14 @@ func newIPN(jsConfig js.Value) map[string]any { dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return ns.DialContextTCP(ctx, dst) } + dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + return ns.DialContextUDP(ctx, dst) + } sys.NetstackRouter.Set(true) sys.Tun.Get().Start() logid := lpc.PublicID - srv := ipnserver.New(logf, logid, sys.NetMon.Get()) + srv := ipnserver.New(logf, logid, sys.Bus.Get(), sys.NetMon.Get()) lb, err := ipnlocal.NewLocalBackend(logf, logid, sys, controlclient.LoginEphemeral) if err != nil { log.Fatalf("ipnlocal.NewLocalBackend: %v", err) @@ -254,7 +261,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) { jsNetMap := jsNetMap{ Self: jsNetMapSelfNode{ jsNetMapNode: jsNetMapNode{ - Name: nm.Name, + Name: nm.SelfName(), Addresses: mapSliceView(nm.GetAddresses(), func(a netip.Prefix) string { return a.Addr().String() }), NodeKey: nm.NodeKey.String(), MachineKey: nm.MachineKey.String(), @@ -268,8 +275,8 @@ func (i *jsIPN) run(jsCallbacks js.Value) { name = p.Hostinfo().Hostname() } addrs := make([]string, p.Addresses().Len()) - for i := range p.Addresses().Len() { - addrs[i] = p.Addresses().At(i).Addr().String() + for i, ap := range p.Addresses().All() { + addrs[i] = ap.Addr().String() } return jsNetMapPeerNode{ jsNetMapNode: jsNetMapNode{ @@ -278,7 +285,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) { MachineKey: p.Machine().String(), NodeKey: p.Key().String(), }, - Online: p.Online(), + Online: p.Online().Clone(), TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), } }), @@ -332,7 +339,7 @@ func (i *jsIPN) logout() { go func() { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - i.lb.Logout(ctx) + i.lb.Logout(ctx, ipnauth.Self) }() } @@ -457,7 +464,6 @@ func (s *jsSSHSession) Run() { cols = s.pendingResizeCols } err = session.RequestPty("xterm", rows, cols, ssh.TerminalModes{}) - if err != nil { writeError("Pseudo Terminal", err) return @@ -585,8 +591,8 @@ func mapSlice[T any, M any](a []T, f func(T) M) []M { func mapSliceView[T any, M any](a views.Slice[T], f func(T) M) []M { n := make([]M, a.Len()) - for i := range a.Len() { - n[i] = f(a.At(i)) + for i, v := range a.All() { + n[i] = f(v) } return n } diff --git a/cmd/tsconnect/yarn.lock b/cmd/tsconnect/yarn.lock index 663a1244e..d9d9db32f 100644 --- a/cmd/tsconnect/yarn.lock +++ b/cmd/tsconnect/yarn.lock @@ -90,11 +90,11 @@ binary-extensions@^2.0.0: integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== braces@^3.0.2, braces@~3.0.2: - version "3.0.2" - resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" - integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + version "3.0.3" + resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.3.tgz#490332f40919452272d55a8480adc0c441358789" + integrity sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA== dependencies: - fill-range "^7.0.1" + fill-range "^7.1.1" camelcase-css@^2.0.1: version "2.0.1" @@ -231,10 +231,10 @@ fastq@^1.6.0: dependencies: reusify "^1.0.4" -fill-range@^7.0.1: - version "7.0.1" - resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" - integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== +fill-range@^7.1.1: + version "7.1.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.1.1.tgz#44265d3cac07e3ea7dc247516380643754a05292" + integrity sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg== dependencies: to-regex-range "^5.0.1" @@ -349,9 +349,9 @@ minimist@^1.2.6: integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== nanoid@^3.3.4: - version "3.3.4" - resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" - integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== + version "3.3.8" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.8.tgz#b1be3030bee36aaff18bacb375e5cce521684baf" + integrity sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w== normalize-path@^3.0.0, normalize-path@~3.0.0: version "3.0.0" diff --git a/cmd/tsidp/README.md b/cmd/tsidp/README.md new file mode 100644 index 000000000..1635feabf --- /dev/null +++ b/cmd/tsidp/README.md @@ -0,0 +1,103 @@ +> [!CAUTION] +> Development of tsidp has been moved to [https://github.com/tailscale/tsidp](https://github.com/tailscale/tsidp) and it is no longer maintained here. Please visit the new repository to see the latest updates, file an issue, or contribute. + +# `tsidp` - Tailscale OpenID Connect (OIDC) Identity Provider + +[![status: community project](https://img.shields.io/badge/status-community_project-blue)](https://tailscale.com/kb/1531/community-projects) + +`tsidp` is an OIDC Identity Provider (IdP) server that integrates with your Tailscale network. It allows you to use Tailscale identities for authentication in applications that support OpenID Connect, enabling single sign-on (SSO) capabilities within your tailnet. + +## Prerequisites + +- A Tailscale network (tailnet) with magicDNS and HTTPS enabled +- A Tailscale authentication key from your tailnet +- Docker installed on your system + +## Installation using Docker + +### Pre-built image + +A pre-built tsidp image exists at `tailscale/tsidp:unstable`. + +### Building from Source + +```bash +# Clone the Tailscale repository +git clone https://github.com/tailscale/tailscale.git +cd tailscale + +# Build and publish to your own registry +make publishdevtsidp REPO=ghcr.io/yourusername/tsidp TAGS=v0.0.1 PUSH=true +``` + +### Running the Container + +Replace `YOUR_TAILSCALE_AUTHKEY` with your Tailscale authentication key: + +```bash +docker run -d \ + --name tsidp \ + -p 443:443 \ + -e TS_AUTHKEY=YOUR_TAILSCALE_AUTHKEY \ + -e TAILSCALE_USE_WIP_CODE=1 \ + -v tsidp-data:/var/lib/tsidp \ + ghcr.io/yourusername/tsidp:v0.0.1 \ + tsidp --hostname=idp --dir=/var/lib/tsidp +``` + +### Verify Installation +```bash +docker logs tsidp +``` + +Visit `https://idp.tailnet.ts.net` to confirm the service is running. + +## Usage Example: Proxmox Integration + +Here's how to configure Proxmox to use `tsidp` for authentication: + +1. In Proxmox, navigate to Datacenter > Realms > Add OpenID Connect Server + +2. Configure the following settings: + - Issuer URL: `https://idp.velociraptor.ts.net` + - Realm: `tailscale` (or your preferred name) + - Client ID: `unused` + - Client Key: `unused` + - Default: `true` + - Autocreate users: `true` + - Username claim: `email` + +3. Set up user permissions: + - Go to Datacenter > Permissions > Groups + - Create a new group (e.g., "tsadmins") + - Click Permissions in the sidebar + - Add Group Permission + - Set Path to `/` for full admin access or scope as needed + - Set the group and role + - Add Tailscale-authenticated users to the group + +## Configuration Options + +The `tsidp` server supports several command-line flags: + +- `--verbose`: Enable verbose logging +- `--port`: Port to listen on (default: 443) +- `--local-port`: Allow requests from localhost +- `--use-local-tailscaled`: Use local tailscaled instead of tsnet +- `--hostname`: tsnet hostname +- `--dir`: tsnet state directory + +## Environment Variables + +- `TS_AUTHKEY`: Your Tailscale authentication key (required) +- `TS_HOSTNAME`: Hostname for the `tsidp` server (default: "idp", Docker only) +- `TS_STATE_DIR`: State directory (default: "/var/lib/tsidp", Docker only) +- `TAILSCALE_USE_WIP_CODE`: Enable work-in-progress code (default: "1") + +## Support + +This is an experimental, work in progress, [community project](https://tailscale.com/kb/1531/community-projects). For issues or questions, file issues on the [GitHub repository](https://github.com/tailscale/tailscale). + +## License + +BSD-3-Clause License. See [LICENSE](../../LICENSE) for details. diff --git a/cmd/tsidp/depaware.txt b/cmd/tsidp/depaware.txt new file mode 100644 index 000000000..14db7414a --- /dev/null +++ b/cmd/tsidp/depaware.txt @@ -0,0 +1,584 @@ +tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depaware) + + filippo.io/edwards25519 from github.com/hdevalence/ed25519consensus + filippo.io/edwards25519/field from filippo.io/edwards25519 + W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ + W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate + W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy + github.com/coder/websocket from tailscale.com/util/eventbus + github.com/coder/websocket/internal/errd from github.com/coder/websocket + github.com/coder/websocket/internal/util from github.com/coder/websocket + github.com/coder/websocket/internal/xsync from github.com/coder/websocket + github.com/creachadair/msync/trigger from tailscale.com/logtail + W đŸ’Ŗ github.com/dblohm7/wingoes from tailscale.com/net/tshttpproxy+ + W đŸ’Ŗ github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ + W đŸ’Ŗ github.com/dblohm7/wingoes/com/automation from tailscale.com/util/osdiag/internal/wsc + W github.com/dblohm7/wingoes/internal from github.com/dblohm7/wingoes/com + W đŸ’Ŗ github.com/dblohm7/wingoes/pe from tailscale.com/util/osdiag+ + github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart + github.com/go-json-experiment/json from tailscale.com/types/opt+ + github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns + github.com/golang/groupcache/lru from tailscale.com/net/dnscache + github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + D github.com/google/uuid from github.com/prometheus-community/pro-bing + github.com/hdevalence/ed25519consensus from tailscale.com/tka + L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon + L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink + github.com/klauspost/compress from github.com/klauspost/compress/zstd + github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 + github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd + github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe + github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd + L đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ + L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ + đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal + D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack + L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf + W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket + W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio + W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio + W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs + W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ + github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ + github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper + github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ + github.com/tailscale/goupnp/scpd from github.com/tailscale/goupnp + github.com/tailscale/goupnp/soap from github.com/tailscale/goupnp+ + github.com/tailscale/goupnp/ssdp from github.com/tailscale/goupnp + github.com/tailscale/hujson from tailscale.com/ipn/conffile + LD github.com/tailscale/peercred from tailscale.com/ipn/ipnauth + github.com/tailscale/web-client-prebuilt from tailscale.com/client/web + đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ + W đŸ’Ŗ github.com/tailscale/wireguard-go/conn/winrio from github.com/tailscale/wireguard-go/conn + đŸ’Ŗ github.com/tailscale/wireguard-go/device from tailscale.com/net/tstun+ + đŸ’Ŗ github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device + W đŸ’Ŗ github.com/tailscale/wireguard-go/ipc/namedpipe from github.com/tailscale/wireguard-go/ipc + github.com/tailscale/wireguard-go/ratelimiter from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/replay from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ + github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device + đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ + github.com/x448/float16 from github.com/fxamacker/cbor/v2 + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ + go4.org/netipx from tailscale.com/ipn/ipnlocal+ + W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun + W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/dns+ + gopkg.in/square/go-jose.v2 from gopkg.in/square/go-jose.v2/jwt+ + gopkg.in/square/go-jose.v2/cipher from gopkg.in/square/go-jose.v2 + gopkg.in/square/go-jose.v2/json from gopkg.in/square/go-jose.v2+ + gopkg.in/square/go-jose.v2/jwt from tailscale.com/cmd/tsidp + gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ + gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer + đŸ’Ŗ gvisor.dev/gvisor/pkg/buffer from gvisor.dev/gvisor/pkg/tcpip+ + gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs + đŸ’Ŗ gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/state/wire+ + gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log + gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context+ + gvisor.dev/gvisor/pkg/rand from gvisor.dev/gvisor/pkg/tcpip+ + gvisor.dev/gvisor/pkg/refs from gvisor.dev/gvisor/pkg/buffer+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/sleep from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + đŸ’Ŗ gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/atomicbitops+ + gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state + đŸ’Ŗ gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/atomicbitops+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/sync/locking from gvisor.dev/gvisor/pkg/tcpip/stack + gvisor.dev/gvisor/pkg/tcpip from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/adapters/gonet from tailscale.com/wgengine/netstack + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/checksum from gvisor.dev/gvisor/pkg/buffer+ + gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ + gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 + gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack/gro + gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop from gvisor.dev/gvisor/pkg/tcpip/transport/raw + gvisor.dev/gvisor/pkg/tcpip/transport/packet from gvisor.dev/gvisor/pkg/tcpip/transport/raw + gvisor.dev/gvisor/pkg/tcpip/transport/raw from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/transport/tcp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack from gvisor.dev/gvisor/pkg/tcpip/stack + gvisor.dev/gvisor/pkg/tcpip/transport/udp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ + tailscale.com from tailscale.com/version + tailscale.com/appc from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/local from tailscale.com/client/web+ + tailscale.com/client/tailscale from tailscale.com/internal/client/tailscale + tailscale.com/client/tailscale/apitype from tailscale.com/client/local+ + tailscale.com/client/web from tailscale.com/ipn/ipnlocal + tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ + tailscale.com/control/controlclient from tailscale.com/ipn/ipnext+ + tailscale.com/control/controlhttp from tailscale.com/control/ts2021 + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp + tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/control/ts2021 from tailscale.com/control/controlclient + tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ + tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+ + tailscale.com/disco from tailscale.com/net/tstun+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/ipn/ipnext+ + tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ + tailscale.com/feature/c2n from tailscale.com/tsnet + tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock + tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet + tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet + tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet + tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey + tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper + tailscale.com/feature/syspolicy from tailscale.com/logpolicy + tailscale.com/feature/useproxy from tailscale.com/feature/condregister/useproxy + tailscale.com/health from tailscale.com/control/controlclient+ + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal+ + tailscale.com/hostinfo from tailscale.com/client/web+ + tailscale.com/internal/client/tailscale from tailscale.com/tsnet+ + tailscale.com/ipn from tailscale.com/client/local+ + tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ + đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ + tailscale.com/ipn/localapi from tailscale.com/tsnet + tailscale.com/ipn/store from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ + tailscale.com/kube/kubetypes from tailscale.com/envknob + tailscale.com/licenses from tailscale.com/client/web + tailscale.com/log/filelogger from tailscale.com/logpolicy + tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal + tailscale.com/logpolicy from tailscale.com/ipn/ipnlocal+ + tailscale.com/logtail from tailscale.com/control/controlclient+ + tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ + tailscale.com/metrics from tailscale.com/tsweb+ + tailscale.com/net/bakedroots from tailscale.com/ipn/ipnlocal+ + đŸ’Ŗ tailscale.com/net/batching from tailscale.com/wgengine/magicsock + tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ + tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ + tailscale.com/net/dnsfallback from tailscale.com/control/controlclient+ + tailscale.com/net/flowtrack from tailscale.com/wgengine+ + tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/memnet from tailscale.com/tsnet + tailscale.com/net/netaddr from tailscale.com/ipn+ + tailscale.com/net/netcheck from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/neterror from tailscale.com/net/dns/resolver+ + tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal + tailscale.com/net/netknob from tailscale.com/logpolicy+ + đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/control/controlclient+ + đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/ping from tailscale.com/net/netcheck+ + tailscale.com/net/portmapper from tailscale.com/feature/portmapper + tailscale.com/net/portmapper/portmappertype from tailscale.com/net/netcheck+ + tailscale.com/net/proxymux from tailscale.com/tsnet + đŸ’Ŗ tailscale.com/net/sockopts from tailscale.com/wgengine/magicsock + tailscale.com/net/socks5 from tailscale.com/tsnet + tailscale.com/net/sockstats from tailscale.com/control/controlclient+ + tailscale.com/net/stun from tailscale.com/ipn/localapi+ + tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/tsaddr from tailscale.com/client/web+ + tailscale.com/net/tsdial from tailscale.com/control/controlclient+ + đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy + tailscale.com/net/tstun from tailscale.com/tsd+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock + tailscale.com/net/udprelay/status from tailscale.com/client/local + tailscale.com/omit from tailscale.com/ipn/conffile + tailscale.com/paths from tailscale.com/client/local+ + tailscale.com/proxymap from tailscale.com/tsd+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ + tailscale.com/syncs from tailscale.com/control/controlhttp+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal + tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock + tailscale.com/tempfork/httprec from tailscale.com/feature/c2n + tailscale.com/tka from tailscale.com/client/local+ + tailscale.com/tsconst from tailscale.com/ipn/ipnlocal+ + tailscale.com/tsd from tailscale.com/ipn/ipnext+ + tailscale.com/tsnet from tailscale.com/cmd/tsidp + tailscale.com/tstime from tailscale.com/control/controlclient+ + tailscale.com/tstime/mono from tailscale.com/net/tstun+ + tailscale.com/tstime/rate from tailscale.com/wgengine/filter + tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/tsweb+ + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/bools from tailscale.com/tsnet+ + tailscale.com/types/dnstype from tailscale.com/client/local+ + tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/client/local+ + tailscale.com/types/lazy from tailscale.com/cmd/tsidp+ + tailscale.com/types/logger from tailscale.com/appc+ + tailscale.com/types/logid from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogfunc from tailscale.com/net/tstun+ + tailscale.com/types/netlogtype from tailscale.com/wgengine/netlog + tailscale.com/types/netmap from tailscale.com/control/controlclient+ + tailscale.com/types/nettype from tailscale.com/ipn/localapi+ + tailscale.com/types/opt from tailscale.com/cmd/tsidp+ + tailscale.com/types/persist from tailscale.com/control/controlclient+ + tailscale.com/types/preftype from tailscale.com/ipn+ + tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter + tailscale.com/types/structs from tailscale.com/control/controlclient+ + tailscale.com/types/tkatype from tailscale.com/client/local+ + tailscale.com/types/views from tailscale.com/appc+ + tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cibuild from tailscale.com/health+ + tailscale.com/util/clientmetric from tailscale.com/appc+ + tailscale.com/util/cloudenv from tailscale.com/hostinfo+ + LW tailscale.com/util/cmpver from tailscale.com/net/dns+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ + đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting + L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics + tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/client/local+ + tailscale.com/util/execqueue from tailscale.com/appc+ + tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal + tailscale.com/util/groupmember from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash + tailscale.com/util/httpm from tailscale.com/client/web+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ + tailscale.com/util/mak from tailscale.com/appc+ + tailscale.com/util/must from tailscale.com/cmd/tsidp+ + tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + đŸ’Ŗ tailscale.com/util/osdiag from tailscale.com/ipn/localapi + W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag + tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal + tailscale.com/util/race from tailscale.com/net/dns/resolver + tailscale.com/util/racebuild from tailscale.com/logpolicy + tailscale.com/util/rands from tailscale.com/cmd/tsidp+ + tailscale.com/util/ringlog from tailscale.com/wgengine/magicsock + tailscale.com/util/set from tailscale.com/control/controlclient+ + tailscale.com/util/singleflight from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/appc+ + tailscale.com/util/syspolicy from tailscale.com/feature/syspolicy + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/pkey from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/ptype from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/rsop from tailscale.com/ipn/localapi+ + tailscale.com/util/syspolicy/setting from tailscale.com/client/local+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/control/controlclient+ + tailscale.com/util/truncate from tailscale.com/logtail + tailscale.com/util/usermetric from tailscale.com/health+ + tailscale.com/util/vizerror from tailscale.com/tailcfg+ + đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/hostinfo+ + W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/util/osdiag + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ + W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal + W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ + tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ + tailscale.com/version from tailscale.com/client/web+ + tailscale.com/version/distro from tailscale.com/client/web+ + tailscale.com/wgengine from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ + tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ + đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/netlog from tailscale.com/wgengine + tailscale.com/wgengine/netstack from tailscale.com/tsnet + tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+ + tailscale.com/wgengine/router from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ + tailscale.com/wgengine/wglog from tailscale.com/wgengine + golang.org/x/crypto/argon2 from tailscale.com/tka + golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ + golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/ed25519 from gopkg.in/square/go-jose.v2 + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key + golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box + golang.org/x/crypto/pbkdf2 from gopkg.in/square/go-jose.v2 + golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device + golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal + LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh + golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ + golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ + golang.org/x/net/bpf from github.com/mdlayher/netlink+ + golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ + golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal + golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy + golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy + golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/proxy from tailscale.com/net/netns + D golang.org/x/net/route from tailscale.com/net/netmon+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials + golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey + golang.org/x/oauth2/internal from golang.org/x/oauth2+ + golang.org/x/sync/errgroup from github.com/mdlayher/socket+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ + LD golang.org/x/sys/unix from github.com/jsimonetti/rtnetlink/internal/unix+ + W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ + W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ + W golang.org/x/sys/windows/svc from golang.org/x/sys/windows/svc/mgr+ + W golang.org/x/sys/windows/svc/mgr from tailscale.com/util/winutil + golang.org/x/term from tailscale.com/logpolicy + golang.org/x/text/secure/bidirule from golang.org/x/net/idna + golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ + golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ + golang.org/x/text/unicode/norm from golang.org/x/net/idna + golang.org/x/time/rate from gvisor.dev/gvisor/pkg/log+ + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna + bufio from compress/flate+ + bytes from bufio+ + cmp from encoding/json+ + compress/flate from compress/gzip+ + compress/gzip from internal/profile+ + W compress/zlib from debug/pe + container/heap from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + container/list from crypto/tls+ + context from crypto/tls+ + crypto from crypto/ecdh+ + crypto/aes from crypto/internal/hpke+ + crypto/cipher from crypto/aes+ + crypto/des from crypto/tls+ + crypto/dsa from crypto/x509+ + crypto/ecdh from crypto/ecdsa+ + crypto/ecdsa from crypto/tls+ + crypto/ed25519 from crypto/tls+ + crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls+ + crypto/hkdf from crypto/internal/hpke+ + crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls+ + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ + crypto/md5 from crypto/tls+ + LD crypto/mlkem from golang.org/x/crypto/ssh + crypto/rand from crypto/ed25519+ + crypto/rc4 from crypto/tls+ + crypto/rsa from crypto/tls+ + crypto/sha1 from crypto/tls+ + crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash + crypto/sha512 from crypto/ecdsa+ + crypto/subtle from crypto/cipher+ + crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls/internal/fips140tls from crypto/tls + crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 + crypto/x509/pkix from crypto/x509+ + D database/sql/driver from github.com/google/uuid + W debug/dwarf from debug/pe + W debug/pe from github.com/dblohm7/wingoes/pe + embed from github.com/tailscale/web-client-prebuilt+ + encoding from encoding/json+ + encoding/asn1 from crypto/x509+ + encoding/base32 from github.com/fxamacker/cbor/v2+ + encoding/base64 from encoding/json+ + encoding/binary from compress/gzip+ + encoding/hex from crypto/x509+ + encoding/json from expvar+ + encoding/pem from crypto/tls+ + encoding/xml from github.com/tailscale/goupnp+ + errors from bufio+ + expvar from tailscale.com/health+ + flag from tailscale.com/cmd/tsidp+ + fmt from compress/flate+ + hash from crypto+ + W hash/adler32 from compress/zlib + hash/crc32 from compress/gzip+ + hash/maphash from go4.org/mem + html from html/template+ + html/template from tailscale.com/util/eventbus+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + D internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + L internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ + io from bufio+ + io/fs from crypto/x509+ + io/ioutil from github.com/godbus/dbus/v5+ + iter from bytes+ + log from expvar+ + log/internal from log + maps from crypto/x509+ + math from compress/flate+ + math/big from crypto/dsa+ + math/bits from bytes+ + math/rand from github.com/fxamacker/cbor/v2+ + math/rand/v2 from crypto/ecdsa+ + mime from mime/multipart+ + mime/multipart from net/http + mime/quotedprintable from mime/multipart + net from crypto/tls+ + net/http from expvar+ + net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httputil from tailscale.com/client/web+ + net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/internal/httpcommon from net/http + net/http/pprof from tailscale.com/ipn/localapi+ + net/netip from crypto/x509+ + net/textproto from github.com/coder/websocket+ + net/url from crypto/x509+ + os from crypto/internal/sysrand+ + os/exec from github.com/godbus/dbus/v5+ + os/signal from tailscale.com/cmd/tsidp + os/user from github.com/godbus/dbus/v5+ + path from debug/dwarf+ + path/filepath from crypto/x509+ + reflect from crypto/x509+ + regexp from github.com/tailscale/goupnp/httpu+ + regexp/syntax from regexp + runtime from crypto/internal/fips140+ + runtime/debug from github.com/coder/websocket/internal/xsync+ + runtime/pprof from net/http/pprof+ + runtime/trace from net/http/pprof + slices from crypto/tls+ + sort from compress/flate+ + strconv from compress/flate+ + strings from bufio+ + W structs from internal/syscall/windows + sync from compress/flate+ + sync/atomic from context+ + syscall from crypto/internal/sysrand+ + text/tabwriter from runtime/pprof + text/template from html/template + text/template/parse from html/template+ + time from compress/gzip+ + unicode from bytes+ + unicode/utf16 from crypto/x509+ + unicode/utf8 from bufio+ + unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 1bdca8919..7093ab9ee 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -11,6 +11,7 @@ import ( "context" crand "crypto/rand" "crypto/rsa" + "crypto/subtle" "crypto/tls" "crypto/x509" "encoding/base64" @@ -28,6 +29,7 @@ import ( "net/url" "os" "os/signal" + "path/filepath" "strconv" "strings" "sync" @@ -35,15 +37,17 @@ import ( "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" + "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/types/key" "tailscale.com/types/lazy" + "tailscale.com/types/opt" "tailscale.com/types/views" "tailscale.com/util/mak" "tailscale.com/util/must" @@ -58,16 +62,40 @@ type ctxConn struct{} // accessing the IDP over Funnel are persisted. const funnelClientsFile = "oidc-funnel-clients.json" +// oauthClientsFile is the new file name for OAuth clients when running in secure mode. +const oauthClientsFile = "oauth-clients.json" + +// deprecatedFunnelClientsFile is the name used when renaming the old file. +const deprecatedFunnelClientsFile = "deprecated-oidc-funnel-clients.json" + +// oidcKeyFile is where the OIDC private key is persisted. +const oidcKeyFile = "oidc-key.json" + var ( - flagVerbose = flag.Bool("verbose", false, "be verbose") - flagPort = flag.Int("port", 443, "port to listen on") - flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") - flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") - flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet") - flagDir = flag.String("dir", "", "tsnet state directory; a default one will be created if not provided") + flagVerbose = flag.Bool("verbose", false, "be verbose") + flagPort = flag.Int("port", 443, "port to listen on") + flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") + flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") + flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet") + flagHostname = flag.String("hostname", "idp", "tsnet hostname to use instead of idp") + flagDir = flag.String("dir", "", "tsnet state directory; a default one will be created if not provided") + flagAllowInsecureRegistrationBool opt.Bool + flagAllowInsecureRegistration = opt.BoolFlag{Bool: &flagAllowInsecureRegistrationBool} ) +// getAllowInsecureRegistration returns whether to allow OAuth flows without pre-registered clients. +// Default is true for backward compatibility; explicitly set to false for strict OAuth compliance. +func getAllowInsecureRegistration() bool { + v, ok := flagAllowInsecureRegistration.Get() + if !ok { + // Flag not set, default to true (allow insecure for backward compatibility) + return true + } + return v +} + func main() { + flag.Var(&flagAllowInsecureRegistration, "allow-insecure-registration", "allow OAuth flows without pre-registered client credentials (default: true for backward compatibility; set to false for strict OAuth compliance)") flag.Parse() ctx := context.Background() if !envknob.UseWIPCode() { @@ -75,16 +103,18 @@ func main() { } var ( - lc *tailscale.LocalClient + lc *local.Client st *ipnstate.Status + rootPath string err error watcherChan chan error cleanup func() lns []net.Listener ) + if *flagUseLocalTailscaled { - lc = &tailscale.LocalClient{} + lc = &local.Client{} st, err = lc.StatusWithoutPeers(ctx) if err != nil { log.Fatalf("getting status: %v", err) @@ -107,6 +137,15 @@ func main() { log.Fatalf("failed to listen on any of %v", st.TailscaleIPs) } + if flagDir == nil || *flagDir == "" { + // use user config directory as storage for tsidp oidc key + configDir, err := os.UserConfigDir() + if err != nil { + log.Fatalf("getting user config directory: %v", err) + } + rootPath = filepath.Join(configDir, "tsidp") + } + // tailscaled needs to be setting an HTTP header for funneled requests // that older versions don't provide. // TODO(naman): is this the correct check? @@ -119,8 +158,9 @@ func main() { } defer cleanup() } else { + hostinfo.SetApp("tsidp") ts := &tsnet.Server{ - Hostname: "idp", + Hostname: *flagHostname, Dir: *flagDir, } if *flagVerbose { @@ -147,28 +187,48 @@ func main() { log.Fatal(err) } lns = append(lns, ln) + + rootPath = ts.GetRootPath() + log.Printf("tsidp root path: %s", rootPath) } srv := &idpServer{ - lc: lc, - funnel: *flagFunnel, - localTSMode: *flagUseLocalTailscaled, + lc: lc, + funnel: *flagFunnel, + localTSMode: *flagUseLocalTailscaled, + rootPath: rootPath, + allowInsecureRegistration: getAllowInsecureRegistration(), } + if *flagPort != 443 { srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort) } else { srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, ".")) } - if *flagFunnel { - f, err := os.Open(funnelClientsFile) - if err == nil { - srv.funnelClients = make(map[string]*funnelClient) - if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil { - log.Fatalf("could not parse %s: %v", funnelClientsFile, err) - } - } else if !errors.Is(err, os.ErrNotExist) { - log.Fatalf("could not open %s: %v", funnelClientsFile, err) + + // If allowInsecureRegistration is enabled, the old oidc-funnel-clients.json path is used. + // If allowInsecureRegistration is disabled, attempt to migrate the old path to oidc-clients.json and use this new path. + var clientsFilePath string + if !srv.allowInsecureRegistration { + clientsFilePath, err = migrateOAuthClients(rootPath) + if err != nil { + log.Fatalf("could not migrate OAuth clients: %v", err) + } + } else { + clientsFilePath, err = getConfigFilePath(rootPath, funnelClientsFile) + if err != nil { + log.Fatalf("could not get funnel clients file path: %v", err) + } + } + + f, err := os.Open(clientsFilePath) + if err == nil { + if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil { + log.Fatalf("could not parse %s: %v", clientsFilePath, err) } + f.Close() + } else if !errors.Is(err, os.ErrNotExist) { + log.Fatalf("could not open %s: %v", clientsFilePath, err) } log.Printf("Running tsidp at %s ...", srv.serverURL) @@ -212,7 +272,7 @@ func main() { // serveOnLocalTailscaled starts a serve session using an already-running // tailscaled instead of starting a fresh tsnet server, making something // listening on clientDNSName:dstPort accessible over serve/funnel. -func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st *ipnstate.Status, dstPort uint16, shouldFunnel bool) (cleanup func(), watcherChan chan error, err error) { +func serveOnLocalTailscaled(ctx context.Context, lc *local.Client, st *ipnstate.Status, dstPort uint16, shouldFunnel bool) (cleanup func(), watcherChan chan error, err error) { // In order to support funneling out in local tailscaled mode, we need // to add a serve config to forward the listeners we bound above and // allow those forwarders to be funneled out. @@ -227,7 +287,7 @@ func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st * // We watch the IPN bus just to get a session ID. The session expires // when we stop watching the bus, and that auto-deletes the foreground // serve/funnel configs we are creating below. - watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState) if err != nil { return nil, nil, fmt.Errorf("could not set up ipn bus watcher: %v", err) } @@ -265,7 +325,7 @@ func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st * foregroundSc.SetFunnel(serverURL, dstPort, shouldFunnel) foregroundSc.SetWebHandler(&ipn.HTTPHandler{ Proxy: fmt.Sprintf("https://%s", net.JoinHostPort(serverURL, strconv.Itoa(int(dstPort)))), - }, serverURL, uint16(*flagPort), "/", true) + }, serverURL, uint16(*flagPort), "/", true, st.CurrentTailnet.MagicDNSSuffix) err = lc.SetServeConfig(ctx, sc) if err != nil { return nil, watcherChan, fmt.Errorf("could not set serve config: %v", err) @@ -275,11 +335,13 @@ func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st * } type idpServer struct { - lc *tailscale.LocalClient - loopbackURL string - serverURL string // "https://foo.bar.ts.net" - funnel bool - localTSMode bool + lc *local.Client + loopbackURL string + serverURL string // "https://foo.bar.ts.net" + funnel bool + localTSMode bool + rootPath string // root path, used for storing state files + allowInsecureRegistration bool // If true, allow OAuth without pre-registered clients lazyMux lazy.SyncValue[*http.ServeMux] lazySigningKey lazy.SyncValue[*signingKey] @@ -328,7 +390,7 @@ type authRequest struct { // allowRelyingParty validates that a relying party identified either by a // known remoteAddr or a valid client ID/secret pair is allowed to proceed // with the authorization flow associated with this authRequest. -func (ar *authRequest) allowRelyingParty(r *http.Request, lc *tailscale.LocalClient) error { +func (ar *authRequest) allowRelyingParty(r *http.Request, lc *local.Client) error { if ar.localRP { ra, err := netip.ParseAddrPort(r.RemoteAddr) if err != nil { @@ -345,7 +407,9 @@ func (ar *authRequest) allowRelyingParty(r *http.Request, lc *tailscale.LocalCli clientID = r.FormValue("client_id") clientSecret = r.FormValue("client_secret") } - if ar.funnelRP.ID != clientID || ar.funnelRP.Secret != clientSecret { + clientIDcmp := subtle.ConstantTimeCompare([]byte(clientID), []byte(ar.funnelRP.ID)) + clientSecretcmp := subtle.ConstantTimeCompare([]byte(clientSecret), []byte(ar.funnelRP.Secret)) + if clientIDcmp != 1 || clientSecretcmp != 1 { return fmt.Errorf("tsidp: invalid client credentials") } return nil @@ -361,14 +425,15 @@ func (ar *authRequest) allowRelyingParty(r *http.Request, lc *tailscale.LocalCli } func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { + // This URL is visited by the user who is being authenticated. If they are // visiting the URL over Funnel, that means they are not part of the // tailnet that they are trying to be authenticated for. + // NOTE: Funnel request behavior is the same regardless of secure or insecure mode. if isFunnelRequest(r) { http.Error(w, "tsidp: unauthorized", http.StatusUnauthorized) return } - uq := r.URL.Query() redirectURI := uq.Get("redirect_uri") @@ -377,6 +442,86 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { return } + clientID := uq.Get("client_id") + if clientID == "" { + http.Error(w, "tsidp: must specify client_id", http.StatusBadRequest) + return + } + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, validate client_id exists but defer client_secret validation to token endpoint + // This follows RFC 6749 which specifies client authentication should occur at token endpoint, not authorization endpoint + + s.mu.Lock() + c, ok := s.funnelClients[clientID] + s.mu.Unlock() + if !ok { + http.Error(w, "tsidp: invalid client ID", http.StatusBadRequest) + return + } + + // Validate client_id matches (public identifier validation) + clientIDcmp := subtle.ConstantTimeCompare([]byte(clientID), []byte(c.ID)) + if clientIDcmp != 1 { + http.Error(w, "tsidp: invalid client ID", http.StatusBadRequest) + return + } + + // Validate redirect URI + if redirectURI != c.RedirectURI { + http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) + return + } + + // Get user information + var remoteAddr string + if s.localTSMode { + remoteAddr = r.Header.Get("X-Forwarded-For") + } else { + remoteAddr = r.RemoteAddr + } + + // Check who is visiting the authorize endpoint. + var who *apitype.WhoIsResponse + var err error + who, err = s.lc.WhoIs(r.Context(), remoteAddr) + if err != nil { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + code := rands.HexString(32) + ar := &authRequest{ + nonce: uq.Get("nonce"), + remoteUser: who, + redirectURI: redirectURI, + clientID: clientID, + funnelRP: c, // Store the validated client + } + + s.mu.Lock() + mak.Set(&s.code, code, ar) + s.mu.Unlock() + + q := make(url.Values) + q.Set("code", code) + if state := uq.Get("state"); state != "" { + q.Set("state", state) + } + parsedURL, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, "invalid redirect URI", http.StatusInternalServerError) + return + } + parsedURL.RawQuery = q.Encode() + u := parsedURL.String() + log.Printf("Redirecting to %q", u) + + http.Redirect(w, r, u, http.StatusFound) + return + } + var remoteAddr string if s.localTSMode { // in local tailscaled mode, the local tailscaled is forwarding us @@ -398,7 +543,7 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { nonce: uq.Get("nonce"), remoteUser: who, redirectURI: redirectURI, - clientID: uq.Get("client_id"), + clientID: clientID, } if r.URL.Path == "/authorize/funnel" { @@ -434,7 +579,13 @@ func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { if state := uq.Get("state"); state != "" { q.Set("state", state) } - u := redirectURI + "?" + q.Encode() + parsedURL, err := url.Parse(redirectURI) + if err != nil { + http.Error(w, "invalid redirect URI", http.StatusInternalServerError) + return + } + parsedURL.RawQuery = q.Encode() + u := parsedURL.String() log.Printf("Redirecting to %q", u) http.Redirect(w, r, u, http.StatusFound) @@ -444,17 +595,17 @@ func (s *idpServer) newMux() *http.ServeMux { mux := http.NewServeMux() mux.HandleFunc(oidcJWKSPath, s.serveJWKS) mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig) - mux.HandleFunc("/authorize/", s.authorize) + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, use a single /authorize endpoint + mux.HandleFunc("/authorize", s.authorize) + } else { + // When insecure registration is allowed, preserve original behavior with path-based routing + mux.HandleFunc("/authorize/", s.authorize) + } mux.HandleFunc("/userinfo", s.serveUserInfo) mux.HandleFunc("/token", s.serveToken) mux.HandleFunc("/clients/", s.serveClients) - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/" { - io.WriteString(w, "

Tailscale OIDC IdP

") - return - } - http.Error(w, "tsidp: not found", http.StatusNotFound) - }) + mux.HandleFunc("/", s.handleUI) return mux } @@ -487,6 +638,24 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { s.mu.Lock() delete(s.accessToken, tk) s.mu.Unlock() + return + } + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, validate that the token was issued to a valid client. + if ar.clientID == "" { + http.Error(w, "tsidp: no client associated with token", http.StatusBadRequest) + return + } + + // Validate client still exists + s.mu.Lock() + _, clientExists := s.funnelClients[ar.clientID] + s.mu.Unlock() + if !clientExists { + http.Error(w, "tsidp: client no longer exists", http.StatusUnauthorized) + return + } } ui := userInfo{} @@ -494,6 +663,7 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) return } + ui.Sub = ar.remoteUser.Node.User.String() ui.Name = ar.remoteUser.UserProfile.DisplayName ui.Email = ar.remoteUser.UserProfile.LoginName @@ -502,8 +672,29 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { // TODO(maisem): not sure if this is the right thing to do ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") + rules, err := tailcfg.UnmarshalCapJSON[capRule](ar.remoteUser.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) + return + } + + // Only keep rules where IncludeInUserInfo is true + var filtered []capRule + for _, r := range rules { + if r.IncludeInUserInfo { + filtered = append(filtered, r) + } + } + + userInfo, err := withExtraClaims(ui, filtered) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Write the final result w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(ui); err != nil { + if err := json.NewEncoder(w).Encode(userInfo); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -516,6 +707,140 @@ type userInfo struct { UserName string `json:"username"` } +type capRule struct { + IncludeInUserInfo bool `json:"includeInUserInfo"` + ExtraClaims map[string]any `json:"extraClaims,omitempty"` // list of features peer is allowed to edit +} + +// flattenExtraClaims merges all ExtraClaims from a slice of capRule into a single map. +// It deduplicates values for each claim and preserves the original input type: +// scalar values remain scalars, and slices are returned as deduplicated []any slices. +func flattenExtraClaims(rules []capRule) map[string]any { + // sets stores deduplicated stringified values for each claim key. + sets := make(map[string]map[string]struct{}) + + // isSlice tracks whether each claim was originally provided as a slice. + isSlice := make(map[string]bool) + + for _, rule := range rules { + for claim, raw := range rule.ExtraClaims { + // Track whether the claim was provided as a slice + switch raw.(type) { + case []string, []any: + isSlice[claim] = true + default: + // Only mark as scalar if this is the first time we've seen this claim + if _, seen := isSlice[claim]; !seen { + isSlice[claim] = false + } + } + + // Add the claim value(s) into the deduplication set + addClaimValue(sets, claim, raw) + } + } + + // Build final result: either scalar or slice depending on original type + result := make(map[string]any) + for claim, valSet := range sets { + if isSlice[claim] { + // Claim was provided as a slice: output as []any + var vals []any + for val := range valSet { + vals = append(vals, val) + } + result[claim] = vals + } else { + // Claim was a scalar: return a single value + for val := range valSet { + result[claim] = val + break // only one value is expected + } + } + } + + return result +} + +// addClaimValue adds a claim value to the deduplication set for a given claim key. +// It accepts scalars (string, int, float64), slices of strings or interfaces, +// and recursively handles nested slices. Unsupported types are ignored with a log message. +func addClaimValue(sets map[string]map[string]struct{}, claim string, val any) { + switch v := val.(type) { + case string, float64, int, int64: + // Ensure the claim set is initialized + if sets[claim] == nil { + sets[claim] = make(map[string]struct{}) + } + // Add the stringified scalar to the set + sets[claim][fmt.Sprintf("%v", v)] = struct{}{} + + case []string: + // Ensure the claim set is initialized + if sets[claim] == nil { + sets[claim] = make(map[string]struct{}) + } + // Add each string value to the set + for _, s := range v { + sets[claim][s] = struct{}{} + } + + case []any: + // Recursively handle each item in the slice + for _, item := range v { + addClaimValue(sets, claim, item) + } + + default: + // Log unsupported types for visibility and debugging + log.Printf("Unsupported claim type for %q: %#v (type %T)", claim, val, val) + } +} + +// withExtraClaims merges flattened extra claims from a list of capRule into the provided struct v, +// returning a map[string]any that combines both sources. +// +// v is any struct whose fields represent static claims; it is first marshaled to JSON, then unmarshalled into a generic map. +// rules is a slice of capRule objects that may define additional (extra) claims to merge. +// +// These extra claims are flattened and merged into the base map unless they conflict with protected claims. +// Claims defined in openIDSupportedClaims are considered protected and cannot be overwritten. +// If an extra claim attempts to overwrite a protected claim, an error is returned. +// +// Returns the merged claims map or an error if any protected claim is violated or JSON (un)marshaling fails. +func withExtraClaims(v any, rules []capRule) (map[string]any, error) { + // Marshal the static struct + data, err := json.Marshal(v) + if err != nil { + return nil, err + } + + // Unmarshal into a generic map + var claimMap map[string]any + if err := json.Unmarshal(data, &claimMap); err != nil { + return nil, err + } + + // Convert views.Slice to a map[string]struct{} for efficient lookup + protected := make(map[string]struct{}, len(openIDSupportedClaims.AsSlice())) + for _, claim := range openIDSupportedClaims.AsSlice() { + protected[claim] = struct{}{} + } + + // Merge extra claims + extra := flattenExtraClaims(rules) + for k, v := range extra { + if _, isProtected := protected[k]; isProtected { + log.Printf("Skip overwriting of existing claim %q", k) + return nil, fmt.Errorf("extra claim %q overwriting existing claim", k) + } + + claimMap[k] = v + } + + return claimMap, nil +} + func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) @@ -540,11 +865,58 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: code not found", http.StatusBadRequest) return } - if err := ar.allowRelyingParty(r, s.lc); err != nil { - log.Printf("Error allowing relying party: %v", err) - http.Error(w, err.Error(), http.StatusForbidden) - return + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, always validate client credentials regardless of request source + clientID := r.FormValue("client_id") + clientSecret := r.FormValue("client_secret") + + // Try basic auth if form values are empty + if clientID == "" || clientSecret == "" { + if basicClientID, basicClientSecret, ok := r.BasicAuth(); ok { + if clientID == "" { + clientID = basicClientID + } + if clientSecret == "" { + clientSecret = basicClientSecret + } + } + } + + if clientID == "" || clientSecret == "" { + http.Error(w, "tsidp: client credentials required in when insecure registration is not allowed", http.StatusUnauthorized) + return + } + + // Validate against the stored auth request + if ar.clientID != clientID { + http.Error(w, "tsidp: client_id mismatch", http.StatusBadRequest) + return + } + + // Validate client credentials against stored clients + if ar.funnelRP == nil { + http.Error(w, "tsidp: no client information found", http.StatusBadRequest) + return + } + + clientIDcmp := subtle.ConstantTimeCompare([]byte(clientID), []byte(ar.funnelRP.ID)) + clientSecretcmp := subtle.ConstantTimeCompare([]byte(clientSecret), []byte(ar.funnelRP.Secret)) + if clientIDcmp != 1 || clientSecretcmp != 1 { + http.Error(w, "tsidp: invalid client credentials", http.StatusUnauthorized) + return + } + } else { + // Original behavior when insecure registration is allowed + // Only checks ClientID and Client Secret when over funnel. + // Local connections are allowed and tailnet connections only check matching nodeIDs. + if err := ar.allowRelyingParty(r, s.lc); err != nil { + log.Printf("Error allowing relying party: %v", err) + http.Error(w, err.Error(), http.StatusForbidden) + return + } } + if ar.redirectURI != r.FormValue("redirect_uri") { http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) return @@ -592,8 +964,22 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { tsClaims.Issuer = s.loopbackURL } + rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + log.Printf("tsidp: failed to unmarshal capability: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + tsClaimsWithExtra, err := withExtraClaims(tsClaims, rules) + if err != nil { + log.Printf("tsidp: failed to merge extra claims: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // Create an OIDC token using this issuer's signer. - token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() + token, err := jwt.Signed(signer).Claims(tsClaimsWithExtra).CompactSerialize() if err != nil { log.Printf("Error getting token: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -621,7 +1007,7 @@ type oidcTokenResponse struct { IDToken string `json:"id_token"` TokenType string `json:"token_type"` AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` + RefreshToken string `json:"refresh_token,omitempty"` ExpiresIn int `json:"expires_in"` } @@ -648,8 +1034,12 @@ func (s *idpServer) oidcSigner() (jose.Signer, error) { func (s *idpServer) oidcPrivateKey() (*signingKey, error) { return s.lazySigningKey.GetErr(func() (*signingKey, error) { + keyPath, err := getConfigFilePath(s.rootPath, oidcKeyFile) + if err != nil { + return nil, fmt.Errorf("could not get OIDC key file path: %w", err) + } var sk signingKey - b, err := os.ReadFile("oidc-key.json") + b, err := os.ReadFile(keyPath) if err == nil { if err := sk.UnmarshalJSON(b); err == nil { return &sk, nil @@ -664,7 +1054,7 @@ func (s *idpServer) oidcPrivateKey() (*signingKey, error) { if err != nil { log.Fatalf("Error marshaling key: %v", err) } - if err := os.WriteFile("oidc-key.json", b, 0600); err != nil { + if err := os.WriteFile(keyPath, b, 0600); err != nil { log.Fatalf("Error writing key: %v", err) } return &sk, nil @@ -698,7 +1088,6 @@ func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) { }); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } - return } // openIDProviderMetadata is a partial representation of @@ -762,28 +1151,54 @@ var ( ) func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != oidcConfigPath { - http.Error(w, "tsidp: not found", http.StatusNotFound) + h := w.Header() + h.Set("Access-Control-Allow-Origin", "*") + h.Set("Access-Control-Allow-Method", "GET, OPTIONS") + // allow all to prevent errors from client sending their own bespoke headers + // and having the server reject the request. + h.Set("Access-Control-Allow-Headers", "*") + + // early return for pre-flight OPTIONS requests. + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) return } - ap, err := netip.ParseAddrPort(r.RemoteAddr) - if err != nil { - log.Printf("Error parsing remote addr: %v", err) + if r.URL.Path != oidcConfigPath { + http.Error(w, "tsidp: not found", http.StatusNotFound) return } + var authorizeEndpoint string rpEndpoint := s.serverURL - if isFunnelRequest(r) { - authorizeEndpoint = fmt.Sprintf("%s/authorize/funnel", s.serverURL) - } else if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil { - authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) - } else if ap.Addr().IsLoopback() { - rpEndpoint = s.loopbackURL - authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL) + + if !s.allowInsecureRegistration { + // When insecure registration is NOT allowed, use a single authorization endpoint for all request types + // This will be the same regardless of if the user is on localhost, tailscale, or funnel. + authorizeEndpoint = fmt.Sprintf("%s/authorize", s.serverURL) + rpEndpoint = s.serverURL } else { - log.Printf("Error getting WhoIs: %v", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return + // When insecure registration is allowed TSIDP uses the requestors nodeID + // (typically that of the resource server during auto discovery) when on the tailnet + // and adds it to the authorize URL as a replacement clientID for when the user authorizes. + // The behavior over funnel drops the nodeID & clientID replacement behvaior and does require a + // previously created clientID and client secret. + ap, err := netip.ParseAddrPort(r.RemoteAddr) + if err != nil { + log.Printf("Error parsing remote addr: %v", err) + return + } + if isFunnelRequest(r) { + authorizeEndpoint = fmt.Sprintf("%s/authorize/funnel", s.serverURL) + } else if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil { + authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) + } else if ap.Addr().IsLoopback() { + rpEndpoint = s.loopbackURL + authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL) + } else { + log.Printf("Error getting WhoIs: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } w.Header().Set("Content-Type", "application/json") @@ -937,14 +1352,27 @@ func (s *idpServer) serveDeleteClient(w http.ResponseWriter, r *http.Request, cl } // storeFunnelClientsLocked writes the current mapping of OIDC client ID/secret -// pairs for RPs that access the IDP over funnel. s.mu must be held while -// calling this. +// pairs for RPs that access the IDP. When insecure registration is NOT allowed, uses oauth-clients.json; +// otherwise uses oidc-funnel-clients.json. s.mu must be held while calling this. func (s *idpServer) storeFunnelClientsLocked() error { var buf bytes.Buffer if err := json.NewEncoder(&buf).Encode(s.funnelClients); err != nil { return err } - return os.WriteFile(funnelClientsFile, buf.Bytes(), 0600) + + var clientsFilePath string + var err error + if !s.allowInsecureRegistration { + clientsFilePath, err = getConfigFilePath(s.rootPath, oauthClientsFile) + } else { + clientsFilePath, err = getConfigFilePath(s.rootPath, funnelClientsFile) + } + + if err != nil { + return fmt.Errorf("storeFunnelClientsLocked: %v", err) + } + + return os.WriteFile(clientsFilePath, buf.Bytes(), 0600) } const ( @@ -1057,3 +1485,76 @@ func isFunnelRequest(r *http.Request) bool { } return false } + +// migrateOAuthClients migrates from oidc-funnel-clients.json to oauth-clients.json. +// If oauth-clients.json already exists, no migration is performed. +// If both files are missing a new configuration is created. +// The path to the new configuration file is returned. +func migrateOAuthClients(rootPath string) (string, error) { + // First, check for oauth-clients.json (new file) + oauthPath, err := getConfigFilePath(rootPath, oauthClientsFile) + if err != nil { + return "", fmt.Errorf("could not get oauth clients file path: %w", err) + } + if _, err := os.Stat(oauthPath); err == nil { + // oauth-clients.json already exists, use it + return oauthPath, nil + } + + // Check for old oidc-funnel-clients.json + oldPath, err := getConfigFilePath(rootPath, funnelClientsFile) + if err != nil { + return "", fmt.Errorf("could not get funnel clients file path: %w", err) + } + if _, err := os.Stat(oldPath); err == nil { + // Old file exists, migrate it + log.Printf("Migrating OAuth clients from %s to %s", oldPath, oauthPath) + + // Read the old file + data, err := os.ReadFile(oldPath) + if err != nil { + return "", fmt.Errorf("could not read old funnel clients file: %w", err) + } + + // Write to new location + if err := os.WriteFile(oauthPath, data, 0600); err != nil { + return "", fmt.Errorf("could not write new oauth clients file: %w", err) + } + + // Rename old file to deprecated name + deprecatedPath, err := getConfigFilePath(rootPath, deprecatedFunnelClientsFile) + if err != nil { + return "", fmt.Errorf("could not get deprecated file path: %w", err) + } + if err := os.Rename(oldPath, deprecatedPath); err != nil { + log.Printf("Warning: could not rename old file to deprecated name: %v", err) + } else { + log.Printf("Renamed old file to %s", deprecatedPath) + } + + return oauthPath, nil + } + + // Neither file exists, create empty oauth-clients.json + log.Printf("Creating empty OAuth clients file at %s", oauthPath) + if err := os.WriteFile(oauthPath, []byte("{}"), 0600); err != nil { + return "", fmt.Errorf("could not create empty oauth clients file: %w", err) + } + + return oauthPath, nil +} + +// getConfigFilePath returns the path to the config file for the given file name. +// The oidc-key.json and funnel-clients.json files were originally opened and written +// to without paths, and ended up in /root or home directory of the user running +// the process. To maintain backward compatibility, we return the naked file name if that +// file exists already, otherwise we return the full path in the rootPath. +func getConfigFilePath(rootPath string, fileName string) (string, error) { + if _, err := os.Stat(fileName); err == nil { + return fileName, nil + } else if errors.Is(err, os.ErrNotExist) { + return filepath.Join(rootPath, fileName), nil + } else { + return "", err + } +} diff --git a/cmd/tsidp/tsidp_test.go b/cmd/tsidp/tsidp_test.go new file mode 100644 index 000000000..4f5af9e59 --- /dev/null +++ b/cmd/tsidp/tsidp_test.go @@ -0,0 +1,2063 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package main tests for tsidp focus on OAuth security boundaries and +// correct implementation of the OpenID Connect identity provider. +// +// Test Strategy: +// - Tests are intentionally granular to provide clear failure signals when +// security-critical logic breaks +// - OAuth flow tests cover both strict mode (registered clients only) and +// legacy mode (local funnel clients) to ensure proper access controls +// - Helper functions like normalizeMap ensure deterministic comparisons +// despite JSON marshaling order variations +// - The privateKey global is reused across tests for performance (RSA key +// generation is expensive) + +package main + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "os" + "path/filepath" + "reflect" + "sort" + "strings" + "sync" + "testing" + "time" + + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/opt" + "tailscale.com/types/views" +) + +// normalizeMap recursively sorts []any values in a map[string]any to ensure +// deterministic test comparisons. This is necessary because JSON marshaling +// doesn't guarantee array order, and we need stable comparisons when testing +// claim merging and flattening logic. +func normalizeMap(t *testing.T, m map[string]any) map[string]any { + t.Helper() + normalized := make(map[string]any, len(m)) + for k, v := range m { + switch val := v.(type) { + case []any: + sorted := make([]string, len(val)) + for i, item := range val { + sorted[i] = fmt.Sprintf("%v", item) // convert everything to string for sorting + } + sort.Strings(sorted) + + // convert back to []any + sortedIface := make([]any, len(sorted)) + for i, s := range sorted { + sortedIface[i] = s + } + normalized[k] = sortedIface + + default: + normalized[k] = v + } + } + return normalized +} + +func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return tailcfg.RawMessage(b) +} + +// privateKey is a shared RSA private key used across tests. It's lazily +// initialized on first use to avoid the expensive key generation cost +// for every test. Protected by privateKeyMu for thread safety. +var ( + privateKey *rsa.PrivateKey + privateKeyMu sync.Mutex +) + +func oidcTestingSigner(t *testing.T) jose.Signer { + t.Helper() + privKey := mustGeneratePrivateKey(t) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, nil) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + return sig +} + +func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey { + t.Helper() + privKey := mustGeneratePrivateKey(t) + return &privKey.PublicKey +} + +func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + privateKeyMu.Lock() + defer privateKeyMu.Unlock() + + if privateKey != nil { + return privateKey + } + + var err error + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + return privateKey +} + +func TestFlattenExtraClaims(t *testing.T) { + log.SetOutput(io.Discard) // suppress log output during tests + + tests := []struct { + name string + input []capRule + expected map[string]any + }{ + { + name: "empty extra claims", + input: []capRule{ + {ExtraClaims: map[string]any{}}, + }, + expected: map[string]any{}, + }, + { + name: "string and number values", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "featureA": "read", + "featureB": 42, + }, + }, + }, + expected: map[string]any{ + "featureA": "read", + "featureB": "42", + }, + }, + { + name: "slice of strings and ints", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "roles": []any{"admin", "user", 1}, + }, + }, + }, + expected: map[string]any{ + "roles": []any{"admin", "user", "1"}, + }, + }, + { + name: "duplicate values deduplicated (slice input)", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar", "baz"}, + }, + }, + { + ExtraClaims: map[string]any{ + "foo": []any{"bar", "qux"}, + }, + }, + }, + expected: map[string]any{ + "foo": []any{"bar", "baz", "qux"}, + }, + }, + { + name: "ignore unsupported map type, keep valid scalar", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "invalid": map[string]any{"bad": "yes"}, + "valid": "ok", + }, + }, + }, + expected: map[string]any{ + "valid": "ok", + }, + }, + { + name: "scalar first, slice second", + input: []capRule{ + {ExtraClaims: map[string]any{"foo": "bar"}}, + {ExtraClaims: map[string]any{"foo": []any{"baz"}}}, + }, + expected: map[string]any{ + "foo": []any{"bar", "baz"}, // converts to slice when any rule provides a slice + }, + }, + { + name: "conflicting scalar and unsupported map", + input: []capRule{ + {ExtraClaims: map[string]any{"foo": "bar"}}, + {ExtraClaims: map[string]any{"foo": map[string]any{"bad": "entry"}}}, + }, + expected: map[string]any{ + "foo": "bar", // map should be ignored + }, + }, + { + name: "multiple slices with overlap", + input: []capRule{ + {ExtraClaims: map[string]any{"roles": []any{"admin", "user"}}}, + {ExtraClaims: map[string]any{"roles": []any{"admin", "guest"}}}, + }, + expected: map[string]any{ + "roles": []any{"admin", "user", "guest"}, + }, + }, + { + name: "slice with unsupported values", + input: []capRule{ + {ExtraClaims: map[string]any{ + "mixed": []any{"ok", 42, map[string]string{"oops": "fail"}}, + }}, + }, + expected: map[string]any{ + "mixed": []any{"ok", "42"}, // map is ignored + }, + }, + { + name: "duplicate scalar value", + input: []capRule{ + {ExtraClaims: map[string]any{"env": "prod"}}, + {ExtraClaims: map[string]any{"env": "prod"}}, + }, + expected: map[string]any{ + "env": "prod", // not converted to slice + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := flattenExtraClaims(tt.input) + + gotNormalized := normalizeMap(t, got) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("mismatch\nGot:\n%s\nWant:\n%s", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestExtraClaims(t *testing.T) { + tests := []struct { + name string + claim tailscaleClaims + extraClaims []capRule + expected map[string]any + expectError bool + }{ + { + name: "extra claim", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []any{"bar"}, + }, + }, + { + name: "duplicate claim distinct values", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }, + { + ExtraClaims: map[string]any{ + "foo": []string{"foobar"}, + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []any{"foobar", "bar"}, + }, + }, + { + name: "multiple extra claims", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }, + { + ExtraClaims: map[string]any{ + "bar": []string{"foo"}, + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []any{"bar"}, + "bar": []any{"foo"}, + }, + }, + { + name: "overwrite claim", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "username": "foobar", + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "foobar", + }, + expectError: true, + }, + { + name: "empty extra claims", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{{ExtraClaims: map[string]any{}}}, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims, err := withExtraClaims(tt.claim, tt.extraClaims) + if err != nil && !tt.expectError { + t.Fatalf("claim.withExtraClaims() unexpected error = %v", err) + } else if err == nil && tt.expectError { + t.Fatalf("expected error, got nil") + } else if err != nil && tt.expectError { + return // just as expected + } + + // Marshal to JSON then unmarshal back to map[string]any + gotClaims, err := json.Marshal(claims) + if err != nil { + t.Errorf("json.Marshal(claims) error = %v", err) + } + + var gotClaimsMap map[string]any + if err := json.Unmarshal(gotClaims, &gotClaimsMap); err != nil { + t.Fatalf("json.Unmarshal(gotClaims) error = %v", err) + } + + gotNormalized := normalizeMap(t, gotClaimsMap) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("claims mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestServeToken(t *testing.T) { + tests := []struct { + name string + caps tailcfg.PeerCapMap + method string + grantType string + code string + omitCode bool + redirectURI string + remoteAddr string + strictMode bool + expectError bool + expected map[string]any + }{ + { + name: "GET not allowed", + method: "GET", + grantType: "authorization_code", + strictMode: false, + expectError: true, + }, + { + name: "unsupported grant type", + method: "POST", + grantType: "pkcs", + strictMode: false, + expectError: true, + }, + { + name: "invalid code", + method: "POST", + grantType: "authorization_code", + code: "invalid-code", + strictMode: false, + expectError: true, + }, + { + name: "omit code from form", + method: "POST", + grantType: "authorization_code", + omitCode: true, + strictMode: false, + expectError: true, + }, + { + name: "invalid redirect uri", + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://invalid.example.com/callback", + remoteAddr: "127.0.0.1:12345", + strictMode: false, + expectError: true, + }, + { + name: "invalid remoteAddr", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "192.168.0.1:12345", + strictMode: false, + expectError: true, + }, + { + name: "extra claim included (non-strict)", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "127.0.0.1:12345", + strictMode: false, + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": "bar", + }, + }), + }, + }, + expected: map[string]any{ + "foo": "bar", + }, + }, + { + name: "attempt to overwrite protected claim (non-strict)", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + strictMode: false, + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "sub": "should-not-overwrite", + }, + }), + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := time.Now() + + // Use setupTestServer helper + s := setupTestServer(t, tt.strictMode) + + // Fake user/node + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tt.caps, + } + + // Setup auth request with appropriate configuration for strict mode + var funnelClientPtr *funnelClient + if tt.strictMode { + funnelClientPtr = &funnelClient{ + ID: "client-id", + Secret: "test-secret", + Name: "Test Client", + RedirectURI: "https://rp.example.com/callback", + } + s.funnelClients["client-id"] = funnelClientPtr + } + + s.code["valid-code"] = &authRequest{ + clientID: "client-id", + nonce: "nonce123", + redirectURI: "https://rp.example.com/callback", + validTill: now.Add(5 * time.Minute), + remoteUser: remoteUser, + localRP: !tt.strictMode, + funnelRP: funnelClientPtr, + } + + form := url.Values{} + form.Set("grant_type", tt.grantType) + form.Set("redirect_uri", tt.redirectURI) + if !tt.omitCode { + form.Set("code", tt.code) + } + // Add client credentials for strict mode + if tt.strictMode { + form.Set("client_id", "client-id") + form.Set("client_secret", "test-secret") + } + + req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode())) + req.RemoteAddr = tt.remoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + s.serveToken(rr, req) + + if tt.expectError { + if rr.Code == http.StatusOK { + t.Fatalf("expected error, got 200 OK: %s", rr.Body.String()) + } + return + } + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp struct { + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + tok, err := jwt.ParseSigned(resp.IDToken) + if err != nil { + t.Fatalf("failed to parse ID token: %v", err) + } + + out := make(map[string]any) + if err := tok.Claims(oidcTestingPublicKey(t), &out); err != nil { + t.Fatalf("failed to extract claims: %v", err) + } + + for k, want := range tt.expected { + got, ok := out[k] + if !ok { + t.Errorf("missing expected claim %q", k) + continue + } + if !reflect.DeepEqual(got, want) { + t.Errorf("claim %q: got %v, want %v", k, got, want) + } + } + }) + } +} + +func TestExtraUserInfo(t *testing.T) { + tests := []struct { + name string + caps tailcfg.PeerCapMap + tokenValidTill time.Time + expected map[string]any + expectError bool + }{ + { + name: "extra claim", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }), + }, + }, + expected: map[string]any{ + "foo": []any{"bar"}, + }, + }, + { + name: "duplicate claim distinct values", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": []string{"bar", "foobar"}, + }, + }), + }, + }, + expected: map[string]any{ + "foo": []any{"bar", "foobar"}, + }, + }, + { + name: "multiple extra claims", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": "bar", + "bar": "foo", + }, + }), + }, + }, + expected: map[string]any{ + "foo": "bar", + "bar": "foo", + }, + }, + { + name: "empty extra claims", + caps: tailcfg.PeerCapMap{}, + tokenValidTill: time.Now().Add(1 * time.Minute), + expected: map[string]any{}, + }, + { + name: "attempt to overwrite protected claim", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "sub": "should-not-overwrite", + "foo": "ok", + }, + }), + }, + }, + expectError: true, + }, + { + name: "extra claim omitted", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: false, + ExtraClaims: map[string]any{ + "foo": "ok", + }, + }), + }, + }, + expected: map[string]any{}, + }, + { + name: "expired token", + caps: tailcfg.PeerCapMap{}, + tokenValidTill: time.Now().Add(-1 * time.Minute), + expected: map[string]any{}, + expectError: true, + }, + } + token := "valid-token" + + // Create a fake tailscale Node + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Construct the remote user + profile := tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: &profile, + CapMap: tt.caps, + } + + // Insert a valid token into the idpServer + s := &idpServer{ + allowInsecureRegistration: true, // Default to allowing insecure registration for backward compatibility + accessToken: map[string]*authRequest{ + token: { + validTill: tt.tokenValidTill, + remoteUser: remoteUser, + }, + }, + } + + // Construct request + req := httptest.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + // Call the method under test + s.serveUserInfo(rr, req) + + if tt.expectError { + if rr.Code == http.StatusOK { + t.Fatalf("expected error, got %d: %s", rr.Code, rr.Body.String()) + } + return + } + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse JSON response: %v", err) + } + + // Construct expected + tt.expected["sub"] = remoteUser.Node.User.String() + tt.expected["name"] = profile.DisplayName + tt.expected["email"] = profile.LoginName + tt.expected["picture"] = profile.ProfilePicURL + tt.expected["username"], _, _ = strings.Cut(profile.LoginName, "@") + + gotNormalized := normalizeMap(t, resp) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("UserInfo mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestFunnelClientsPersistence(t *testing.T) { + testClients := map[string]*funnelClient{ + "test-client-1": { + ID: "test-client-1", + Secret: "test-secret-1", + Name: "Test Client 1", + RedirectURI: "https://example.com/callback", + }, + "test-client-2": { + ID: "test-client-2", + Secret: "test-secret-2", + Name: "Test Client 2", + RedirectURI: "https://example2.com/callback", + }, + } + + testData, err := json.Marshal(testClients) + if err != nil { + t.Fatalf("failed to marshal test data: %v", err) + } + + tmpFile := t.TempDir() + "/oidc-funnel-clients.json" + if err := os.WriteFile(tmpFile, testData, 0600); err != nil { + t.Fatalf("failed to write test file: %v", err) + } + + t.Run("load_from_existing_file", func(t *testing.T) { + srv := &idpServer{} + + // Simulate the funnel clients loading logic from main() + srv.funnelClients = make(map[string]*funnelClient) + f, err := os.Open(tmpFile) + if err == nil { + if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil { + t.Fatalf("could not parse %s: %v", tmpFile, err) + } + f.Close() + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("could not open %s: %v", tmpFile, err) + } + + // Verify clients were loaded correctly + if len(srv.funnelClients) != 2 { + t.Errorf("expected 2 clients, got %d", len(srv.funnelClients)) + } + + client1, ok := srv.funnelClients["test-client-1"] + if !ok { + t.Error("expected test-client-1 to be loaded") + } else { + if client1.Name != "Test Client 1" { + t.Errorf("expected client name 'Test Client 1', got '%s'", client1.Name) + } + if client1.Secret != "test-secret-1" { + t.Errorf("expected client secret 'test-secret-1', got '%s'", client1.Secret) + } + } + }) + + t.Run("initialize_empty_when_no_file", func(t *testing.T) { + nonExistentFile := t.TempDir() + "/non-existent.json" + + srv := &idpServer{} + + // Simulate the funnel clients loading logic from main() + srv.funnelClients = make(map[string]*funnelClient) + f, err := os.Open(nonExistentFile) + if err == nil { + if err := json.NewDecoder(f).Decode(&srv.funnelClients); err != nil { + t.Fatalf("could not parse %s: %v", nonExistentFile, err) + } + f.Close() + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("could not open %s: %v", nonExistentFile, err) + } + + // Verify map is initialized but empty + if srv.funnelClients == nil { + t.Error("expected funnelClients map to be initialized") + } + if len(srv.funnelClients) != 0 { + t.Errorf("expected empty map, got %d clients", len(srv.funnelClients)) + } + }) + + t.Run("persist_and_reload_clients", func(t *testing.T) { + tmpFile2 := t.TempDir() + "/test-persistence.json" + + // Create initial server with one client + srv1 := &idpServer{ + funnelClients: make(map[string]*funnelClient), + } + srv1.funnelClients["new-client"] = &funnelClient{ + ID: "new-client", + Secret: "new-secret", + Name: "New Client", + RedirectURI: "https://new.example.com/callback", + } + + // Save clients to file (simulating saveFunnelClients) + data, err := json.Marshal(srv1.funnelClients) + if err != nil { + t.Fatalf("failed to marshal clients: %v", err) + } + if err := os.WriteFile(tmpFile2, data, 0600); err != nil { + t.Fatalf("failed to write clients file: %v", err) + } + + // Create new server instance and load clients + srv2 := &idpServer{} + srv2.funnelClients = make(map[string]*funnelClient) + f, err := os.Open(tmpFile2) + if err == nil { + if err := json.NewDecoder(f).Decode(&srv2.funnelClients); err != nil { + t.Fatalf("could not parse %s: %v", tmpFile2, err) + } + f.Close() + } else if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("could not open %s: %v", tmpFile2, err) + } + + // Verify the client was persisted correctly + loadedClient, ok := srv2.funnelClients["new-client"] + if !ok { + t.Error("expected new-client to be loaded after persistence") + } else { + if loadedClient.Name != "New Client" { + t.Errorf("expected client name 'New Client', got '%s'", loadedClient.Name) + } + if loadedClient.Secret != "new-secret" { + t.Errorf("expected client secret 'new-secret', got '%s'", loadedClient.Secret) + } + } + }) + + t.Run("strict_mode_file_handling", func(t *testing.T) { + tmpDir := t.TempDir() + + // Test strict mode uses oauth-clients.json + srv1 := setupTestServer(t, true) + srv1.rootPath = tmpDir + srv1.funnelClients["oauth-client"] = &funnelClient{ + ID: "oauth-client", + Secret: "oauth-secret", + Name: "OAuth Client", + RedirectURI: "https://oauth.example.com/callback", + } + + // Test storeFunnelClientsLocked in strict mode + srv1.mu.Lock() + err := srv1.storeFunnelClientsLocked() + srv1.mu.Unlock() + + if err != nil { + t.Fatalf("failed to store clients in strict mode: %v", err) + } + + // Verify oauth-clients.json was created + oauthPath := tmpDir + "/" + oauthClientsFile + if _, err := os.Stat(oauthPath); err != nil { + t.Errorf("expected oauth-clients.json to be created: %v", err) + } + + // Verify oidc-funnel-clients.json was NOT created + funnelPath := tmpDir + "/" + funnelClientsFile + if _, err := os.Stat(funnelPath); !os.IsNotExist(err) { + t.Error("expected oidc-funnel-clients.json NOT to be created in strict mode") + } + }) + + t.Run("non_strict_mode_file_handling", func(t *testing.T) { + tmpDir := t.TempDir() + + // Test non-strict mode uses oidc-funnel-clients.json + srv1 := setupTestServer(t, false) + srv1.rootPath = tmpDir + srv1.funnelClients["funnel-client"] = &funnelClient{ + ID: "funnel-client", + Secret: "funnel-secret", + Name: "Funnel Client", + RedirectURI: "https://funnel.example.com/callback", + } + + // Test storeFunnelClientsLocked in non-strict mode + srv1.mu.Lock() + err := srv1.storeFunnelClientsLocked() + srv1.mu.Unlock() + + if err != nil { + t.Fatalf("failed to store clients in non-strict mode: %v", err) + } + + // Verify oidc-funnel-clients.json was created + funnelPath := tmpDir + "/" + funnelClientsFile + if _, err := os.Stat(funnelPath); err != nil { + t.Errorf("expected oidc-funnel-clients.json to be created: %v", err) + } + + // Verify oauth-clients.json was NOT created + oauthPath := tmpDir + "/" + oauthClientsFile + if _, err := os.Stat(oauthPath); !os.IsNotExist(err) { + t.Error("expected oauth-clients.json NOT to be created in non-strict mode") + } + }) +} + +// Test helper functions for strict OAuth mode testing +func setupTestServer(t *testing.T, strictMode bool) *idpServer { + return setupTestServerWithClient(t, strictMode, nil) +} + +// setupTestServerWithClient creates a test server with an optional LocalClient. +// If lc is nil, the server will have no LocalClient (original behavior). +// If lc is provided, it will be used for WhoIs calls during testing. +func setupTestServerWithClient(t *testing.T, strictMode bool, lc *local.Client) *idpServer { + t.Helper() + + srv := &idpServer{ + allowInsecureRegistration: !strictMode, + code: make(map[string]*authRequest), + accessToken: make(map[string]*authRequest), + funnelClients: make(map[string]*funnelClient), + serverURL: "https://test.ts.net", + rootPath: t.TempDir(), + lc: lc, + } + + // Add a test client for funnel/strict mode testing + srv.funnelClients["test-client"] = &funnelClient{ + ID: "test-client", + Secret: "test-secret", + Name: "Test Client", + RedirectURI: "https://rp.example.com/callback", + } + + // Inject a working signer for token tests + srv.lazySigner.Set(oidcTestingSigner(t)) + + return srv +} + +func TestGetAllowInsecureRegistration(t *testing.T) { + tests := []struct { + name string + flagSet bool + flagValue bool + expectAllowInsecureRegistration bool + }{ + { + name: "flag explicitly set to false - insecure registration disabled (strict mode)", + flagSet: true, + flagValue: false, + expectAllowInsecureRegistration: false, + }, + { + name: "flag explicitly set to true - insecure registration enabled", + flagSet: true, + flagValue: true, + expectAllowInsecureRegistration: true, + }, + { + name: "flag unset - insecure registration enabled (default for backward compatibility)", + flagSet: false, + flagValue: false, // not used when unset + expectAllowInsecureRegistration: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original state + originalFlag := flagAllowInsecureRegistration + defer func() { + flagAllowInsecureRegistration = originalFlag + }() + + // Set up test state by creating a new BoolFlag and setting values + var b opt.Bool + flagAllowInsecureRegistration = opt.BoolFlag{Bool: &b} + if tt.flagSet { + flagAllowInsecureRegistration.Bool.Set(tt.flagValue) + } + // Note: when tt.flagSet is false, the Bool remains unset (which is what we want) + + got := getAllowInsecureRegistration() + if got != tt.expectAllowInsecureRegistration { + t.Errorf("getAllowInsecureRegistration() = %v, want %v", got, tt.expectAllowInsecureRegistration) + } + }) + } +} + +// TestMigrateOAuthClients verifies the migration from legacy funnel clients +// to OAuth clients. This migration is necessary when transitioning from +// non-strict to strict OAuth mode. The migration logic should: +// - Copy clients from oidc-funnel-clients.json to oauth-clients.json +// - Rename the old file to mark it as deprecated +// - Handle cases where files already exist or are missing +func TestMigrateOAuthClients(t *testing.T) { + tests := []struct { + name string + setupOldFile bool + setupNewFile bool + oldFileContent map[string]*funnelClient + newFileContent map[string]*funnelClient + expectError bool + expectNewFileExists bool + expectOldRenamed bool + }{ + { + name: "migrate from old file to new file", + setupOldFile: true, + oldFileContent: map[string]*funnelClient{ + "old-client": { + ID: "old-client", + Secret: "old-secret", + Name: "Old Client", + RedirectURI: "https://old.example.com/callback", + }, + }, + expectNewFileExists: true, + expectOldRenamed: true, + }, + { + name: "new file already exists - no migration", + setupNewFile: true, + newFileContent: map[string]*funnelClient{ + "existing-client": { + ID: "existing-client", + Secret: "existing-secret", + Name: "Existing Client", + RedirectURI: "https://existing.example.com/callback", + }, + }, + expectNewFileExists: true, + expectOldRenamed: false, + }, + { + name: "neither file exists - create empty new file", + expectNewFileExists: true, + expectOldRenamed: false, + }, + { + name: "both files exist - prefer new file", + setupOldFile: true, + setupNewFile: true, + oldFileContent: map[string]*funnelClient{ + "old-client": { + ID: "old-client", + Secret: "old-secret", + Name: "Old Client", + RedirectURI: "https://old.example.com/callback", + }, + }, + newFileContent: map[string]*funnelClient{ + "new-client": { + ID: "new-client", + Secret: "new-secret", + Name: "New Client", + RedirectURI: "https://new.example.com/callback", + }, + }, + expectNewFileExists: true, + expectOldRenamed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rootPath := t.TempDir() + + // Setup old file if needed + if tt.setupOldFile { + oldData, err := json.Marshal(tt.oldFileContent) + if err != nil { + t.Fatalf("failed to marshal old file content: %v", err) + } + oldPath := rootPath + "/" + funnelClientsFile + if err := os.WriteFile(oldPath, oldData, 0600); err != nil { + t.Fatalf("failed to create old file: %v", err) + } + } + + // Setup new file if needed + if tt.setupNewFile { + newData, err := json.Marshal(tt.newFileContent) + if err != nil { + t.Fatalf("failed to marshal new file content: %v", err) + } + newPath := rootPath + "/" + oauthClientsFile + if err := os.WriteFile(newPath, newData, 0600); err != nil { + t.Fatalf("failed to create new file: %v", err) + } + } + + // Call migrateOAuthClients + resultPath, err := migrateOAuthClients(rootPath) + + if tt.expectError && err == nil { + t.Fatalf("expected error but got none") + } + if !tt.expectError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.expectError { + return + } + + // Verify result path points to oauth-clients.json + expectedPath := filepath.Join(rootPath, oauthClientsFile) + if resultPath != expectedPath { + t.Errorf("expected result path %s, got %s", expectedPath, resultPath) + } + + // Verify new file exists if expected + if tt.expectNewFileExists { + if _, err := os.Stat(resultPath); err != nil { + t.Errorf("expected new file to exist at %s: %v", resultPath, err) + } + + // Verify content + data, err := os.ReadFile(resultPath) + if err != nil { + t.Fatalf("failed to read new file: %v", err) + } + + var clients map[string]*funnelClient + if err := json.Unmarshal(data, &clients); err != nil { + t.Fatalf("failed to unmarshal new file: %v", err) + } + + // Determine expected content + var expectedContent map[string]*funnelClient + if tt.setupNewFile { + expectedContent = tt.newFileContent + } else if tt.setupOldFile { + expectedContent = tt.oldFileContent + } else { + expectedContent = make(map[string]*funnelClient) + } + + if len(clients) != len(expectedContent) { + t.Errorf("expected %d clients, got %d", len(expectedContent), len(clients)) + } + + for id, expectedClient := range expectedContent { + actualClient, ok := clients[id] + if !ok { + t.Errorf("expected client %s not found", id) + continue + } + if actualClient.ID != expectedClient.ID || + actualClient.Secret != expectedClient.Secret || + actualClient.Name != expectedClient.Name || + actualClient.RedirectURI != expectedClient.RedirectURI { + t.Errorf("client %s mismatch: got %+v, want %+v", id, actualClient, expectedClient) + } + } + } + + // Verify old file renamed if expected + if tt.expectOldRenamed { + deprecatedPath := rootPath + "/" + deprecatedFunnelClientsFile + if _, err := os.Stat(deprecatedPath); err != nil { + t.Errorf("expected old file to be renamed to %s: %v", deprecatedPath, err) + } + + // Verify original old file is gone + oldPath := rootPath + "/" + funnelClientsFile + if _, err := os.Stat(oldPath); !os.IsNotExist(err) { + t.Errorf("expected old file %s to be removed", oldPath) + } + } + }) + } +} + +// TestGetConfigFilePath verifies backward compatibility for config file location. +// The function must check current directory first (legacy deployments) before +// falling back to rootPath (new installations) to prevent breaking existing +// tsidp deployments that have config files in unexpected locations. +func TestGetConfigFilePath(t *testing.T) { + tests := []struct { + name string + fileName string + createInCwd bool + createInRoot bool + expectInCwd bool + expectError bool + }{ + { + name: "file exists in current directory - use current directory", + fileName: "test-config.json", + createInCwd: true, + expectInCwd: true, + }, + { + name: "file does not exist - use root path", + fileName: "test-config.json", + createInCwd: false, + expectInCwd: false, + }, + { + name: "file exists in both - prefer current directory", + fileName: "test-config.json", + createInCwd: true, + createInRoot: true, + expectInCwd: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary directories + rootPath := t.TempDir() + originalWd, err := os.Getwd() + if err != nil { + t.Fatalf("failed to get working directory: %v", err) + } + + // Create a temporary working directory + tmpWd := t.TempDir() + if err := os.Chdir(tmpWd); err != nil { + t.Fatalf("failed to change to temp directory: %v", err) + } + defer func() { + os.Chdir(originalWd) + }() + + // Setup files as needed + if tt.createInCwd { + if err := os.WriteFile(tt.fileName, []byte("{}"), 0600); err != nil { + t.Fatalf("failed to create file in cwd: %v", err) + } + } + if tt.createInRoot { + rootFilePath := filepath.Join(rootPath, tt.fileName) + if err := os.WriteFile(rootFilePath, []byte("{}"), 0600); err != nil { + t.Fatalf("failed to create file in root: %v", err) + } + } + + // Call getConfigFilePath + resultPath, err := getConfigFilePath(rootPath, tt.fileName) + + if tt.expectError && err == nil { + t.Fatalf("expected error but got none") + } + if !tt.expectError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.expectError { + return + } + + // Verify result + if tt.expectInCwd { + if resultPath != tt.fileName { + t.Errorf("expected path %s, got %s", tt.fileName, resultPath) + } + } else { + expectedPath := filepath.Join(rootPath, tt.fileName) + if resultPath != expectedPath { + t.Errorf("expected path %s, got %s", expectedPath, resultPath) + } + } + }) + } +} + +// TestAuthorizeStrictMode verifies OAuth authorization endpoint security and validation logic. +// Tests both the security boundary (funnel rejection) and the business logic (strict mode validation). +func TestAuthorizeStrictMode(t *testing.T) { + tests := []struct { + name string + strictMode bool + clientID string + redirectURI string + state string + nonce string + setupClient bool + clientRedirect string + useFunnel bool // whether to simulate funnel request + mockWhoIsError bool // whether to make WhoIs return an error + expectError bool + expectCode int + expectRedirect bool + }{ + // Security boundary test: funnel rejection + { + name: "funnel requests are always rejected for security", + strictMode: true, + clientID: "test-client", + redirectURI: "https://rp.example.com/callback", + state: "random-state", + nonce: "random-nonce", + setupClient: true, + clientRedirect: "https://rp.example.com/callback", + useFunnel: true, + expectError: true, + expectCode: http.StatusUnauthorized, + }, + + // Strict mode parameter validation tests (non-funnel) + { + name: "strict mode - missing client_id", + strictMode: true, + clientID: "", + redirectURI: "https://rp.example.com/callback", + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - missing redirect_uri", + strictMode: true, + clientID: "test-client", + redirectURI: "", + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + + // Strict mode client validation tests (non-funnel) + { + name: "strict mode - invalid client_id", + strictMode: true, + clientID: "invalid-client", + redirectURI: "https://rp.example.com/callback", + setupClient: false, + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - redirect_uri mismatch", + strictMode: true, + clientID: "test-client", + redirectURI: "https://wrong.example.com/callback", + setupClient: true, + clientRedirect: "https://rp.example.com/callback", + useFunnel: false, + expectError: true, + expectCode: http.StatusBadRequest, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := setupTestServer(t, tt.strictMode) + + // For non-funnel tests, we'll test the parameter validation logic + // without needing to mock WhoIs, since the validation happens before WhoIs calls + + // Setup client if needed + if tt.setupClient { + srv.funnelClients["test-client"] = &funnelClient{ + ID: "test-client", + Secret: "test-secret", + Name: "Test Client", + RedirectURI: tt.clientRedirect, + } + } else if !tt.strictMode { + // For non-strict mode tests that don't need a specific client setup + // but might reference one, clear the default client + delete(srv.funnelClients, "test-client") + } + + // Create request + reqURL := "/authorize" + if !tt.strictMode { + // In non-strict mode, use the node-specific endpoint + reqURL = "/authorize/123" + } + + query := url.Values{} + if tt.clientID != "" { + query.Set("client_id", tt.clientID) + } + if tt.redirectURI != "" { + query.Set("redirect_uri", tt.redirectURI) + } + if tt.state != "" { + query.Set("state", tt.state) + } + if tt.nonce != "" { + query.Set("nonce", tt.nonce) + } + + reqURL += "?" + query.Encode() + req := httptest.NewRequest("GET", reqURL, nil) + req.RemoteAddr = "127.0.0.1:12345" + + // Set funnel header only when explicitly testing funnel behavior + if tt.useFunnel { + req.Header.Set("Tailscale-Funnel-Request", "true") + } + + rr := httptest.NewRecorder() + srv.authorize(rr, req) + + if tt.expectError { + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + } else if tt.expectRedirect { + if rr.Code != http.StatusFound { + t.Errorf("expected redirect (302), got %d: %s", rr.Code, rr.Body.String()) + } + + location := rr.Header().Get("Location") + if location == "" { + t.Error("expected Location header in redirect response") + } else { + // Parse the redirect URL to verify it contains a code + redirectURL, err := url.Parse(location) + if err != nil { + t.Errorf("failed to parse redirect URL: %v", err) + } else { + code := redirectURL.Query().Get("code") + if code == "" { + t.Error("expected 'code' parameter in redirect URL") + } + + // Verify state is preserved if provided + if tt.state != "" { + returnedState := redirectURL.Query().Get("state") + if returnedState != tt.state { + t.Errorf("expected state '%s', got '%s'", tt.state, returnedState) + } + } + + // Verify the auth request was stored + srv.mu.Lock() + ar, ok := srv.code[code] + srv.mu.Unlock() + + if !ok { + t.Error("expected authorization request to be stored") + } else { + if ar.clientID != tt.clientID { + t.Errorf("expected clientID '%s', got '%s'", tt.clientID, ar.clientID) + } + if ar.redirectURI != tt.redirectURI { + t.Errorf("expected redirectURI '%s', got '%s'", tt.redirectURI, ar.redirectURI) + } + if ar.nonce != tt.nonce { + t.Errorf("expected nonce '%s', got '%s'", tt.nonce, ar.nonce) + } + } + } + } + } else { + t.Errorf("unexpected test case: not expecting error or redirect") + } + }) + } +} + +// TestServeTokenWithClientValidation verifies OAuth token endpoint security in both strict and non-strict modes. +// In strict mode, the token endpoint must: +// - Require and validate client credentials (client_id + client_secret) +// - Only accept tokens from registered funnel clients +// - Validate that redirect_uri matches the registered client +// - Support both form-based and HTTP Basic authentication for client credentials +func TestServeTokenWithClientValidation(t *testing.T) { + tests := []struct { + name string + strictMode bool + method string + grantType string + code string + clientID string + clientSecret string + redirectURI string + useBasicAuth bool + setupAuthRequest bool + authRequestClient string + authRequestRedirect string + expectError bool + expectCode int + expectIDToken bool + }{ + { + name: "strict mode - valid token exchange with form credentials", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "test-client", + clientSecret: "test-secret", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectIDToken: true, + }, + { + name: "strict mode - valid token exchange with basic auth", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://rp.example.com/callback", + useBasicAuth: true, + clientID: "test-client", + clientSecret: "test-secret", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectIDToken: true, + }, + { + name: "strict mode - missing client credentials", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectError: true, + expectCode: http.StatusUnauthorized, + }, + { + name: "strict mode - client_id mismatch", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "wrong-client", + clientSecret: "test-secret", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - invalid client secret", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "test-client", + clientSecret: "wrong-secret", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectError: true, + expectCode: http.StatusUnauthorized, + }, + { + name: "strict mode - redirect_uri mismatch", + strictMode: true, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + clientID: "test-client", + clientSecret: "test-secret", + redirectURI: "https://wrong.example.com/callback", + setupAuthRequest: true, + authRequestClient: "test-client", + authRequestRedirect: "https://rp.example.com/callback", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "non-strict mode - no client validation required", + strictMode: false, + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: "https://rp.example.com/callback", + setupAuthRequest: true, + authRequestRedirect: "https://rp.example.com/callback", + expectIDToken: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := setupTestServer(t, tt.strictMode) + + // Setup authorization request if needed + if tt.setupAuthRequest { + now := time.Now() + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tailcfg.PeerCapMap{}, + } + + var funnelClientPtr *funnelClient + if tt.strictMode && tt.authRequestClient != "" { + funnelClientPtr = &funnelClient{ + ID: tt.authRequestClient, + Secret: "test-secret", + Name: "Test Client", + RedirectURI: tt.authRequestRedirect, + } + srv.funnelClients[tt.authRequestClient] = funnelClientPtr + } + + srv.code["valid-code"] = &authRequest{ + clientID: tt.authRequestClient, + nonce: "nonce123", + redirectURI: tt.authRequestRedirect, + validTill: now.Add(5 * time.Minute), + remoteUser: remoteUser, + localRP: !tt.strictMode, + funnelRP: funnelClientPtr, + } + } + + // Create form data + form := url.Values{} + form.Set("grant_type", tt.grantType) + form.Set("code", tt.code) + form.Set("redirect_uri", tt.redirectURI) + + if !tt.useBasicAuth { + if tt.clientID != "" { + form.Set("client_id", tt.clientID) + } + if tt.clientSecret != "" { + form.Set("client_secret", tt.clientSecret) + } + } + + req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.RemoteAddr = "127.0.0.1:12345" + + if tt.useBasicAuth && tt.clientID != "" && tt.clientSecret != "" { + req.SetBasicAuth(tt.clientID, tt.clientSecret) + } + + rr := httptest.NewRecorder() + srv.serveToken(rr, req) + + if tt.expectError { + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + } else if tt.expectIDToken { + if rr.Code != http.StatusOK { + t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp struct { + IDToken string `json:"id_token"` + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + } + + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + if resp.IDToken == "" { + t.Error("expected id_token in response") + } + if resp.AccessToken == "" { + t.Error("expected access_token in response") + } + if resp.TokenType != "Bearer" { + t.Errorf("expected token_type 'Bearer', got '%s'", resp.TokenType) + } + if resp.ExpiresIn != 300 { + t.Errorf("expected expires_in 300, got %d", resp.ExpiresIn) + } + + // Verify access token was stored + srv.mu.Lock() + _, ok := srv.accessToken[resp.AccessToken] + srv.mu.Unlock() + + if !ok { + t.Error("expected access token to be stored") + } + + // Verify authorization code was consumed + srv.mu.Lock() + _, ok = srv.code[tt.code] + srv.mu.Unlock() + + if ok { + t.Error("expected authorization code to be consumed") + } + } + }) + } +} + +// TestServeUserInfoWithClientValidation verifies UserInfo endpoint security in both strict and non-strict modes. +// In strict mode, the UserInfo endpoint must: +// - Validate that access tokens are associated with registered clients +// - Reject tokens for clients that have been deleted/unregistered +// - Enforce token expiration properly +// - Return appropriate user claims based on client capabilities +func TestServeUserInfoWithClientValidation(t *testing.T) { + tests := []struct { + name string + strictMode bool + setupToken bool + setupClient bool + clientID string + token string + tokenValidTill time.Time + expectError bool + expectCode int + expectUserInfo bool + }{ + { + name: "strict mode - valid token with existing client", + strictMode: true, + setupToken: true, + setupClient: true, + clientID: "test-client", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectUserInfo: true, + }, + { + name: "strict mode - valid token but client no longer exists", + strictMode: true, + setupToken: true, + setupClient: false, + clientID: "deleted-client", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectError: true, + expectCode: http.StatusUnauthorized, + }, + { + name: "strict mode - expired token", + strictMode: true, + setupToken: true, + setupClient: true, + clientID: "test-client", + token: "expired-token", + tokenValidTill: time.Now().Add(-5 * time.Minute), + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - invalid token", + strictMode: true, + setupToken: false, + token: "invalid-token", + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "strict mode - token without client association", + strictMode: true, + setupToken: true, + setupClient: false, + clientID: "", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectError: true, + expectCode: http.StatusBadRequest, + }, + { + name: "non-strict mode - no client validation required", + strictMode: false, + setupToken: true, + setupClient: false, + clientID: "", + token: "valid-token", + tokenValidTill: time.Now().Add(5 * time.Minute), + expectUserInfo: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := setupTestServer(t, tt.strictMode) + + // Setup client if needed + if tt.setupClient { + srv.funnelClients[tt.clientID] = &funnelClient{ + ID: tt.clientID, + Secret: "test-secret", + Name: "Test Client", + RedirectURI: "https://rp.example.com/callback", + } + } + + // Setup token if needed + if tt.setupToken { + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tailcfg.PeerCapMap{}, + } + + srv.accessToken[tt.token] = &authRequest{ + clientID: tt.clientID, + validTill: tt.tokenValidTill, + remoteUser: remoteUser, + } + } + + // Create request + req := httptest.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+tt.token) + req.RemoteAddr = "127.0.0.1:12345" + + rr := httptest.NewRecorder() + srv.serveUserInfo(rr, req) + + if tt.expectError { + if rr.Code != tt.expectCode { + t.Errorf("expected status code %d, got %d: %s", tt.expectCode, rr.Code, rr.Body.String()) + } + } else if tt.expectUserInfo { + if rr.Code != http.StatusOK { + t.Errorf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse JSON response: %v", err) + } + + // Check required fields + expectedFields := []string{"sub", "name", "email", "picture", "username"} + for _, field := range expectedFields { + if _, ok := resp[field]; !ok { + t.Errorf("expected field '%s' in user info response", field) + } + } + + // Verify specific values + if resp["name"] != "Alice Example" { + t.Errorf("expected name 'Alice Example', got '%v'", resp["name"]) + } + if resp["email"] != "alice@example.com" { + t.Errorf("expected email 'alice@example.com', got '%v'", resp["email"]) + } + if resp["username"] != "alice" { + t.Errorf("expected username 'alice', got '%v'", resp["username"]) + } + } + }) + } +} diff --git a/cmd/tsidp/ui-edit.html b/cmd/tsidp/ui-edit.html new file mode 100644 index 000000000..d463981aa --- /dev/null +++ b/cmd/tsidp/ui-edit.html @@ -0,0 +1,199 @@ + + + + + {{if .IsNew}}Add New Client{{else}}Edit Client{{end}} - Tailscale OIDC Identity Provider + + + + + + + {{template "header"}} + +
+
+
+

+ {{if .IsNew}}Add New OIDC Client{{else}}Edit OIDC Client{{end}} +

+ ← Back to Clients +
+ + {{if .Success}} +
+ {{.Success}} +
+ {{end}} + + {{if .Error}} +
+ {{.Error}} +
+ {{end}} + + {{if and .Secret .IsNew}} +
+

Client Created Successfully!

+

âš ī¸ Save both the Client ID and Secret now! The secret will not be shown again.

+ +
+ +
+ + +
+
+ +
+ +
+ + +
+
+
+ {{end}} + + {{if and .Secret .IsEdit}} +
+

New Client Secret

+

âš ī¸ Save this secret now! It will not be shown again.

+
+ + +
+
+ {{end}} + +
+
+ + +
+ A descriptive name for this OIDC client (optional). +
+
+ +
+ + +
+ The URL where users will be redirected after authentication. +
+
+ + {{if .IsEdit}} +
+ + +
+ The client ID cannot be changed. +
+
+ {{end}} + +
+ + + {{if .IsEdit}} + + + + {{end}} +
+
+ + {{if .IsEdit}} +
+

Client Information

+
+
Client ID
+
{{.ID}}
+
Secret Status
+
+ {{if .HasSecret}} + Secret configured + {{else}} + No secret + {{end}} +
+
+
+ {{end}} +
+
+ + + + \ No newline at end of file diff --git a/cmd/tsidp/ui-header.html b/cmd/tsidp/ui-header.html new file mode 100644 index 000000000..68e9bc0df --- /dev/null +++ b/cmd/tsidp/ui-header.html @@ -0,0 +1,53 @@ +
+ +
\ No newline at end of file diff --git a/cmd/tsidp/ui-list.html b/cmd/tsidp/ui-list.html new file mode 100644 index 000000000..d45b88349 --- /dev/null +++ b/cmd/tsidp/ui-list.html @@ -0,0 +1,73 @@ + + + + Tailscale OIDC Identity Provider + + + + + + {{template "header"}} + +
+
+
+

OIDC Clients

+ {{if .}} +

{{len .}} client{{if ne (len .) 1}}s{{end}} configured

+ {{end}} +
+ Add New Client +
+ + {{if .}} + + + + + + + + + + + + {{range .}} + + + + + + + + {{end}} + +
NameClient IDRedirect URIStatusActions
+ {{if .Name}} + {{.Name}} + {{else}} + Unnamed Client + {{end}} + + {{.ID}} + + {{.RedirectURI}} + + {{if .HasSecret}} + Active + {{else}} + No Secret + {{end}} + + Edit +
+ {{else}} +
+

No OIDC clients configured

+

Create your first OIDC client to get started with authentication.

+ Add New Client +
+ {{end}} +
+ + \ No newline at end of file diff --git a/cmd/tsidp/ui-style.css b/cmd/tsidp/ui-style.css new file mode 100644 index 000000000..148ec3030 --- /dev/null +++ b/cmd/tsidp/ui-style.css @@ -0,0 +1,446 @@ +:root { + --tw-text-opacity: 1; + --color-gray-100: 247 245 244; + --color-gray-200: 238 235 234; + --color-gray-500: 112 110 109; + --color-gray-700: 46 45 45; + --color-gray-800: 35 34 34; + --color-gray-900: 31 30 30; + --color-bg-app: rgb(var(--color-gray-900) / 1); + --color-border-base: rgb(var(--color-gray-200) / 1); + --color-primary: 59 130 246; + --color-primary-hover: 37 99 235; + --color-secondary: 107 114 128; + --color-secondary-hover: 75 85 99; + --color-success: 34 197 94; + --color-warning: 245 158 11; + --color-danger: 239 68 68; + --color-danger-hover: 220 38 38; +} + +* { + box-sizing: border-box; + padding: 0; + margin: 0; +} + +body { + font-family: Inter, -apple-system, BlinkMacSystemFont, Helvetica, Arial, + sans-serif; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + font-size: 16px; + line-height: 1.4; + margin: 0; + background-color: var(--color-bg-app); + color: rgb(var(--color-gray-200)); +} + +a { + text-decoration: none; + color: inherit; +} + +header { + margin-top: 40px; +} +header nav { + margin: 0 auto; + max-width: 1120px; + display: flex; + align-items: center; + justify-content: center; +} +header nav h1 { + display: inline; + font-weight: 600; + font-size: 1.125rem; + line-height: 1.75rem; + margin-left: 0.75rem; +} + +main { + margin: 40px auto 60px auto; + max-width: 1120px; + padding: 0 20px; +} + +/* Header actions */ +.header-actions { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 2rem; +} + +.header-actions h2 { + font-size: 1.5rem; + font-weight: 600; + margin: 0 0 0.25rem 0; +} + +.client-count { + font-size: 0.875rem; + color: rgb(var(--color-gray-500)); + margin: 0; +} + +/* Buttons */ +.btn { + display: inline-flex; + align-items: center; + padding: 8px 16px; + border-radius: 6px; + font-size: 14px; + font-weight: 500; + text-decoration: none; + border: none; + cursor: pointer; + transition: all 0.2s ease; +} + +.btn-small { + padding: 4px 8px; + font-size: 12px; +} + +.btn-primary { + background-color: rgb(var(--color-primary)); + color: white; +} + +.btn-primary:hover { + background-color: rgb(var(--color-primary-hover)); +} + +.btn-secondary { + background-color: rgb(var(--color-secondary)); + color: white; +} + +.btn-secondary:hover { + background-color: rgb(var(--color-secondary-hover)); +} + +.btn-success { + background-color: rgb(var(--color-success)); + color: white; +} + +.btn-warning { + background-color: rgb(var(--color-warning)); + color: white; +} + +.btn-danger { + background-color: rgb(var(--color-danger)); + color: white; +} + +.btn-danger:hover { + background-color: rgb(var(--color-danger-hover)); +} + +/* Tables */ +table { + width: 100%; + border-spacing: 0; + border: 1px solid rgb(var(--color-gray-700)); + border-bottom-width: 0; + border-radius: 8px; + overflow: hidden; +} + +td { + border: 0 solid rgb(var(--color-gray-700)); + border-bottom-width: 1px; + padding: 12px 16px; +} + +thead td { + text-transform: uppercase; + color: rgb(var(--color-gray-500) / var(--tw-text-opacity)); + font-size: 12px; + letter-spacing: 0.08em; + font-weight: 600; + background-color: rgb(var(--color-gray-800)); +} + +tbody tr:hover { + background-color: rgb(var(--color-gray-800)); +} + +/* Client display elements */ +.client-id { + font-family: "SF Mono", SFMono-Regular, ui-monospace, "DejaVu Sans Mono", + Menlo, Consolas, monospace; + font-size: 12px; + background-color: rgb(var(--color-gray-800)); + padding: 2px 6px; + border-radius: 4px; + color: rgb(var(--color-gray-200)); +} + +.redirect-uri { + font-size: 14px; + color: rgb(var(--color-gray-200)); + word-break: break-all; +} + +.status-active { + color: rgb(var(--color-success)); + font-weight: 500; +} + +.status-inactive { + color: rgb(var(--color-gray-500)); + font-weight: 500; +} + +.text-muted { + color: rgb(var(--color-gray-500)); +} + +/* Empty state */ +.empty-state { + text-align: center; + padding: 60px 20px; + border: 1px solid rgb(var(--color-gray-700)); + border-radius: 8px; + background-color: rgb(var(--color-gray-800) / 0.5); +} + +.empty-state h3 { + font-size: 1.25rem; + font-weight: 600; + margin-bottom: 0.5rem; + color: rgb(var(--color-gray-200)); +} + +.empty-state p { + color: rgb(var(--color-gray-500)); + margin-bottom: 1.5rem; +} + +/* Forms */ +.form-container { + max-width: 600px; + margin: 0 auto; +} + +.form-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 2rem; +} + +.form-header h2 { + font-size: 1.5rem; + font-weight: 600; + margin: 0; +} + +.client-form { + background-color: rgb(var(--color-gray-800) / 0.5); + border: 1px solid rgb(var(--color-gray-700)); + border-radius: 8px; + padding: 24px; + margin-bottom: 2rem; +} + +.form-group { + margin-bottom: 1.5rem; +} + +.form-group:last-child { + margin-bottom: 0; +} + +.form-group label { + display: block; + font-weight: 500; + margin-bottom: 0.5rem; + color: rgb(var(--color-gray-200)); +} + +.required { + color: rgb(var(--color-danger)); +} + +.form-input { + width: 100%; + padding: 10px 12px; + border: 1px solid rgb(var(--color-gray-700)); + border-radius: 6px; + background-color: rgb(var(--color-gray-900)); + color: rgb(var(--color-gray-200)); + font-size: 14px; +} + +.form-input:focus { + outline: none; + border-color: rgb(var(--color-primary)); + box-shadow: 0 0 0 3px rgb(var(--color-primary) / 0.1); +} + +.form-input-readonly { + background-color: rgb(var(--color-gray-800)); + color: rgb(var(--color-gray-500)); +} + +.form-help { + font-size: 12px; + color: rgb(var(--color-gray-500)); + margin-top: 0.25rem; +} + +.form-actions { + display: flex; + gap: 1rem; + margin-top: 2rem; + padding-top: 1.5rem; + border-top: 1px solid rgb(var(--color-gray-700)); +} + +/* Alerts */ +.alert { + padding: 12px 16px; + border-radius: 6px; + margin-bottom: 1.5rem; + font-size: 14px; +} + +.alert-success { + background-color: rgb(var(--color-success) / 0.1); + border: 1px solid rgb(var(--color-success) / 0.3); + color: rgb(var(--color-success)); +} + +.alert-error { + background-color: rgb(var(--color-danger) / 0.1); + border: 1px solid rgb(var(--color-danger) / 0.3); + color: rgb(var(--color-danger)); +} + +/* Secret display */ +.secret-display { + background-color: rgb(var(--color-gray-800) / 0.5); + border: 1px solid rgb(var(--color-gray-700)); + border-radius: 8px; + padding: 20px; + margin-bottom: 2rem; +} + +.secret-display h3 { + font-size: 1.125rem; + font-weight: 600; + margin-bottom: 0.5rem; + color: rgb(var(--color-gray-200)); +} + +.warning { + color: rgb(var(--color-warning)); + font-weight: 500; + margin-bottom: 1rem; +} + +.secret-field { + display: flex; + gap: 0.5rem; +} + +.secret-input { + flex: 1; + padding: 10px 12px; + border: 1px solid rgb(var(--color-gray-700)); + border-radius: 6px; + background-color: rgb(var(--color-gray-900)); + color: rgb(var(--color-gray-200)); + font-family: "SF Mono", SFMono-Regular, ui-monospace, "DejaVu Sans Mono", + Menlo, Consolas, monospace; + font-size: 12px; +} + +/* Client info */ +.client-info { + background-color: rgb(var(--color-gray-800) / 0.5); + border: 1px solid rgb(var(--color-gray-700)); + border-radius: 8px; + padding: 20px; +} + +.client-info h3 { + font-size: 1.125rem; + font-weight: 600; + margin-bottom: 1rem; + color: rgb(var(--color-gray-200)); +} + +.client-info dl { + display: grid; + grid-template-columns: auto 1fr; + gap: 0.5rem 1rem; + border: none; + border-radius: 0; + padding: 0; +} + +.client-info dt { + font-weight: 600; + color: rgb(var(--color-gray-400)); + border: none; + padding: 0; +} + +.client-info dd { + color: rgb(var(--color-gray-200)); + border: none; + padding: 0; +} + +.client-info code { + font-family: "SF Mono", SFMono-Regular, ui-monospace, "DejaVu Sans Mono", + Menlo, Consolas, monospace; + font-size: 12px; + background-color: rgb(var(--color-gray-800)); + padding: 2px 6px; + border-radius: 4px; + color: rgb(var(--color-gray-200)); +} + +/* Responsive design */ +@media (max-width: 768px) { + .header-actions { + flex-direction: column; + align-items: stretch; + gap: 1rem; + } + + .form-header { + flex-direction: column; + align-items: stretch; + gap: 1rem; + } + + .form-actions { + flex-direction: column; + } + + .secret-field { + flex-direction: column; + } + + table { + font-size: 14px; + } + + td { + padding: 8px 12px; + } + + .client-id { + font-size: 10px; + } +} \ No newline at end of file diff --git a/cmd/tsidp/ui.go b/cmd/tsidp/ui.go new file mode 100644 index 000000000..d37b64990 --- /dev/null +++ b/cmd/tsidp/ui.go @@ -0,0 +1,325 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + _ "embed" + "html/template" + "log" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "tailscale.com/util/rands" +) + +//go:embed ui-header.html +var headerHTML string + +//go:embed ui-list.html +var listHTML string + +//go:embed ui-edit.html +var editHTML string + +//go:embed ui-style.css +var styleCSS string + +var headerTmpl = template.Must(template.New("header").Parse(headerHTML)) +var listTmpl = template.Must(headerTmpl.New("list").Parse(listHTML)) +var editTmpl = template.Must(headerTmpl.New("edit").Parse(editHTML)) + +var processStart = time.Now() + +func (s *idpServer) handleUI(w http.ResponseWriter, r *http.Request) { + if isFunnelRequest(r) { + http.Error(w, "tsidp: UI not available over Funnel", http.StatusNotFound) + return + } + + switch r.URL.Path { + case "/": + s.handleClientsList(w, r) + return + case "/new": + s.handleNewClient(w, r) + return + case "/style.css": + http.ServeContent(w, r, "ui-style.css", processStart, strings.NewReader(styleCSS)) + return + } + + if strings.HasPrefix(r.URL.Path, "/edit/") { + s.handleEditClient(w, r) + return + } + + http.Error(w, "tsidp: not found", http.StatusNotFound) +} + +func (s *idpServer) handleClientsList(w http.ResponseWriter, r *http.Request) { + s.mu.Lock() + clients := make([]clientDisplayData, 0, len(s.funnelClients)) + for _, c := range s.funnelClients { + clients = append(clients, clientDisplayData{ + ID: c.ID, + Name: c.Name, + RedirectURI: c.RedirectURI, + HasSecret: c.Secret != "", + }) + } + s.mu.Unlock() + + sort.Slice(clients, func(i, j int) bool { + if clients[i].Name != clients[j].Name { + return clients[i].Name < clients[j].Name + } + return clients[i].ID < clients[j].ID + }) + + var buf bytes.Buffer + if err := listTmpl.Execute(&buf, clients); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + buf.WriteTo(w) +} + +func (s *idpServer) handleNewClient(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + if err := s.renderClientForm(w, clientDisplayData{IsNew: true}); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + + if r.Method == "POST" { + if err := r.ParseForm(); err != nil { + http.Error(w, "Failed to parse form", http.StatusBadRequest) + return + } + + name := strings.TrimSpace(r.FormValue("name")) + redirectURI := strings.TrimSpace(r.FormValue("redirect_uri")) + + baseData := clientDisplayData{ + IsNew: true, + Name: name, + RedirectURI: redirectURI, + } + + if errMsg := validateRedirectURI(redirectURI); errMsg != "" { + s.renderFormError(w, baseData, errMsg) + return + } + + clientID := rands.HexString(32) + clientSecret := rands.HexString(64) + newClient := funnelClient{ + ID: clientID, + Secret: clientSecret, + Name: name, + RedirectURI: redirectURI, + } + + s.mu.Lock() + if s.funnelClients == nil { + s.funnelClients = make(map[string]*funnelClient) + } + s.funnelClients[clientID] = &newClient + err := s.storeFunnelClientsLocked() + s.mu.Unlock() + + if err != nil { + log.Printf("could not write funnel clients db: %v", err) + s.renderFormError(w, baseData, "Failed to save client") + return + } + + successData := clientDisplayData{ + ID: clientID, + Name: name, + RedirectURI: redirectURI, + Secret: clientSecret, + IsNew: true, + } + s.renderFormSuccess(w, successData, "Client created successfully! Save the client secret - it won't be shown again.") + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +func (s *idpServer) handleEditClient(w http.ResponseWriter, r *http.Request) { + clientID := strings.TrimPrefix(r.URL.Path, "/edit/") + if clientID == "" { + http.Error(w, "Client ID required", http.StatusBadRequest) + return + } + + s.mu.Lock() + client, exists := s.funnelClients[clientID] + s.mu.Unlock() + + if !exists { + http.Error(w, "Client not found", http.StatusNotFound) + return + } + + if r.Method == "GET" { + data := createEditBaseData(client, client.Name, client.RedirectURI) + if err := s.renderClientForm(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + + if r.Method == "POST" { + action := r.FormValue("action") + + if action == "delete" { + s.mu.Lock() + delete(s.funnelClients, clientID) + err := s.storeFunnelClientsLocked() + s.mu.Unlock() + + if err != nil { + log.Printf("could not write funnel clients db: %v", err) + s.mu.Lock() + s.funnelClients[clientID] = client + s.mu.Unlock() + + baseData := createEditBaseData(client, client.Name, client.RedirectURI) + s.renderFormError(w, baseData, "Failed to delete client. Please try again.") + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + if action == "regenerate_secret" { + newSecret := rands.HexString(64) + s.mu.Lock() + s.funnelClients[clientID].Secret = newSecret + err := s.storeFunnelClientsLocked() + s.mu.Unlock() + + baseData := createEditBaseData(client, client.Name, client.RedirectURI) + baseData.HasSecret = true + + if err != nil { + log.Printf("could not write funnel clients db: %v", err) + s.renderFormError(w, baseData, "Failed to regenerate secret") + return + } + + baseData.Secret = newSecret + s.renderFormSuccess(w, baseData, "New client secret generated! Save it - it won't be shown again.") + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "Failed to parse form", http.StatusBadRequest) + return + } + + name := strings.TrimSpace(r.FormValue("name")) + redirectURI := strings.TrimSpace(r.FormValue("redirect_uri")) + baseData := createEditBaseData(client, name, redirectURI) + + if errMsg := validateRedirectURI(redirectURI); errMsg != "" { + s.renderFormError(w, baseData, errMsg) + return + } + + s.mu.Lock() + s.funnelClients[clientID].Name = name + s.funnelClients[clientID].RedirectURI = redirectURI + err := s.storeFunnelClientsLocked() + s.mu.Unlock() + + if err != nil { + log.Printf("could not write funnel clients db: %v", err) + s.renderFormError(w, baseData, "Failed to update client") + return + } + + s.renderFormSuccess(w, baseData, "Client updated successfully!") + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +type clientDisplayData struct { + ID string + Name string + RedirectURI string + Secret string + HasSecret bool + IsNew bool + IsEdit bool + Success string + Error string +} + +func (s *idpServer) renderClientForm(w http.ResponseWriter, data clientDisplayData) error { + var buf bytes.Buffer + if err := editTmpl.Execute(&buf, data); err != nil { + return err + } + if _, err := buf.WriteTo(w); err != nil { + return err + } + return nil +} + +func (s *idpServer) renderFormError(w http.ResponseWriter, data clientDisplayData, errorMsg string) { + data.Error = errorMsg + if err := s.renderClientForm(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (s *idpServer) renderFormSuccess(w http.ResponseWriter, data clientDisplayData, successMsg string) { + data.Success = successMsg + if err := s.renderClientForm(w, data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func createEditBaseData(client *funnelClient, name, redirectURI string) clientDisplayData { + return clientDisplayData{ + ID: client.ID, + Name: name, + RedirectURI: redirectURI, + HasSecret: client.Secret != "", + IsEdit: true, + } +} + +func validateRedirectURI(redirectURI string) string { + if redirectURI == "" { + return "Redirect URI is required" + } + + u, err := url.Parse(redirectURI) + if err != nil { + return "Invalid URL format" + } + + if u.Scheme != "http" && u.Scheme != "https" { + return "Redirect URI must be a valid HTTP or HTTPS URL" + } + + if u.Host == "" { + return "Redirect URI must include a valid host" + } + + return "" +} diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index 4a4c4a6be..9f8f00295 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -30,7 +30,7 @@ import ( "time" "tailscale.com/atomicfile" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/hostinfo" "tailscale.com/util/mak" "tailscale.com/util/must" @@ -64,7 +64,7 @@ func serveCmd(w http.ResponseWriter, cmd string, args ...string) { } type localClientRoundTripper struct { - lc tailscale.LocalClient + lc local.Client } func (rt *localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/cmd/vet/jsontags/analyzer.go b/cmd/vet/jsontags/analyzer.go new file mode 100644 index 000000000..d799b66cb --- /dev/null +++ b/cmd/vet/jsontags/analyzer.go @@ -0,0 +1,201 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsontags checks for incompatible usage of JSON struct tags. +package jsontags + +import ( + "go/ast" + "go/types" + "reflect" + "strings" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +var Analyzer = &analysis.Analyzer{ + Name: "jsonvet", + Doc: "check for incompatible usages of JSON struct tags", + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +func run(pass *analysis.Pass) (any, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + // TODO: Report byte arrays fields without an explicit `format` tag option. + + inspect.Preorder([]ast.Node{(*ast.StructType)(nil)}, func(n ast.Node) { + structType, ok := pass.TypesInfo.Types[n.(*ast.StructType)].Type.(*types.Struct) + if !ok { + return // type information may be incomplete + } + for i := range structType.NumFields() { + fieldVar := structType.Field(i) + tag := reflect.StructTag(structType.Tag(i)).Get("json") + if tag == "" { + continue + } + var seenName, hasFormat bool + for opt := range strings.SplitSeq(tag, ",") { + if !seenName { + seenName = true + continue + } + switch opt { + case "omitempty": + // For bools, ints, uints, floats, strings, and interfaces, + // it is always safe to migrate from `omitempty` to `omitzero` + // so long as the type does not have an IsZero method or + // the IsZero method is identical to reflect.Value.IsZero. + // + // For pointers, it is only safe to migrate from `omitempty` to `omitzero` + // so long as the type does not have an IsZero method, regardless of + // whether the IsZero method is identical to reflect.Value.IsZero. + // + // For pointers, `omitempty` behaves identically on both v1 and v2 + // so long as the type does not implement a Marshal method that + // might serialize as an empty JSON value (i.e., null, "", [], or {}). + hasIsZero := hasIsZeroMethod(fieldVar.Type()) && !hasPureIsZeroMethod(fieldVar.Type()) + underType := fieldVar.Type().Underlying() + basic, isBasic := underType.(*types.Basic) + array, isArrayKind := underType.(*types.Array) + _, isMapKind := underType.(*types.Map) + _, isSliceKind := underType.(*types.Slice) + _, isPointerKind := underType.(*types.Pointer) + _, isInterfaceKind := underType.(*types.Interface) + supportedInV1 := isNumericKind(underType) || + isBasic && basic.Kind() == types.Bool || + isBasic && basic.Kind() == types.String || + isArrayKind && array.Len() == 0 || + isMapKind || isSliceKind || isPointerKind || isInterfaceKind + notSupportedInV2 := isNumericKind(underType) || + isBasic && basic.Kind() == types.Bool + switch { + case isMapKind, isSliceKind: + // This operates the same under both v1 and v2 so long as + // the map or slice type does not implement Marshal + // that could emit an empty JSON value for cases + // other than when the map or slice are empty. + // This is very rare. + case isString(fieldVar.Type()): + // This operates the same under both v1 and v2. + // These are safe to migrate to `omitzero`, + // but doing so is probably unnecessary churn. + // Note that this is only for a unnamed string type. + case !supportedInV1: + // This never worked in v1. Switching to `omitzero` + // may lead to unexpected behavior changes. + report(pass, structType, fieldVar, OmitEmptyUnsupportedInV1) + case notSupportedInV2: + // This does not work in v2. Switching to `omitzero` + // may lead to unexpected behavior changes. + report(pass, structType, fieldVar, OmitEmptyUnsupportedInV2) + case !hasIsZero: + // These are safe to migrate to `omitzero` such that + // it behaves identically under v1 and v2. + report(pass, structType, fieldVar, OmitEmptyShouldBeOmitZero) + case isPointerKind: + // This operates the same under both v1 and v2 so long as + // the pointer type does not implement Marshal that + // could emit an empty JSON value. + // For example, time.Time is safe since the zero value + // never marshals as an empty JSON string. + default: + // This is a non-pointer type with an IsZero method. + // If IsZero is not identical to reflect.Value.IsZero, + // omission may behave slightly differently when using + // `omitzero` instead of `omitempty`. + // Thus the finding uses the word "should". + report(pass, structType, fieldVar, OmitEmptyShouldBeOmitZeroButHasIsZero) + } + case "string": + if !isNumericKind(fieldVar.Type()) { + report(pass, structType, fieldVar, StringOnNonNumericKind) + } + default: + key, _, ok := strings.Cut(opt, ":") + hasFormat = key == "format" && ok + } + } + if !hasFormat && isTimeDuration(mayPointerElem(fieldVar.Type())) { + report(pass, structType, fieldVar, FormatMissingOnTimeDuration) + } + } + }) + return nil, nil +} + +// hasIsZeroMethod reports whether t has an IsZero method. +func hasIsZeroMethod(t types.Type) bool { + for method := range types.NewMethodSet(t).Methods() { + if fn, ok := method.Type().(*types.Signature); ok && method.Obj().Name() == "IsZero" { + if fn.Params().Len() == 0 && fn.Results().Len() == 1 && isBool(fn.Results().At(0).Type()) { + return true + } + } + } + return false +} + +// isBool reports whether t is a bool type. +func isBool(t types.Type) bool { + basic, ok := t.(*types.Basic) + return ok && basic.Kind() == types.Bool +} + +// isString reports whether t is a string type. +func isString(t types.Type) bool { + basic, ok := t.(*types.Basic) + return ok && basic.Kind() == types.String +} + +// isTimeDuration reports whether t is a time.Duration type. +func isTimeDuration(t types.Type) bool { + return isNamed(t, "time", "Duration") +} + +// mayPointerElem returns the pointed-at type if t is a pointer, +// otherwise it returns t as-is. +func mayPointerElem(t types.Type) types.Type { + if pointer, ok := t.(*types.Pointer); ok { + return pointer.Elem() + } + return t +} + +// isNamed reports t is a named typed of the given path and name. +func isNamed(t types.Type, path, name string) bool { + gotPath, gotName := typeName(t) + return gotPath == path && gotName == name +} + +// typeName reports the pkgPath and name of the type. +// It recursively follows type aliases to get the underlying named type. +func typeName(t types.Type) (pkgPath, name string) { + if named, ok := types.Unalias(t).(*types.Named); ok { + obj := named.Obj() + if pkg := obj.Pkg(); pkg != nil { + return pkg.Path(), obj.Name() + } + return "", obj.Name() + } + return "", "" +} + +// isNumericKind reports whether t is a numeric kind. +func isNumericKind(t types.Type) bool { + if basic, ok := t.Underlying().(*types.Basic); ok { + switch basic.Kind() { + case types.Int, types.Int8, types.Int16, types.Int32, types.Int64: + case types.Uint, types.Uint8, types.Uint16, types.Uint32, types.Uint64, types.Uintptr: + case types.Float32, types.Float64: + default: + return false + } + return true + } + return false +} diff --git a/cmd/vet/jsontags/iszero.go b/cmd/vet/jsontags/iszero.go new file mode 100644 index 000000000..77520d72c --- /dev/null +++ b/cmd/vet/jsontags/iszero.go @@ -0,0 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsontags + +import ( + "go/types" + "reflect" + + "tailscale.com/util/set" +) + +var _ = reflect.Value.IsZero // refer for hot-linking purposes + +var pureIsZeroMethods map[string]set.Set[string] + +// hasPureIsZeroMethod reports whether the IsZero method is truly +// identical to [reflect.Value.IsZero]. +func hasPureIsZeroMethod(t types.Type) bool { + // TODO: Detect this automatically by checking the method AST? + path, name := typeName(t) + return pureIsZeroMethods[path].Contains(name) +} + +// PureIsZeroMethodsInTailscaleModule is a list of known IsZero methods +// in the "tailscale.com" module that are pure. +var PureIsZeroMethodsInTailscaleModule = map[string]set.Set[string]{ + "tailscale.com/net/packet": set.Of( + "TailscaleRejectReason", + ), + "tailscale.com/tailcfg": set.Of( + "UserID", + "LoginID", + "NodeID", + "StableNodeID", + ), + "tailscale.com/tka": set.Of( + "AUMHash", + ), + "tailscale.com/types/geo": set.Of( + "Point", + ), + "tailscale.com/tstime/mono": set.Of( + "Time", + ), + "tailscale.com/types/key": set.Of( + "NLPrivate", + "NLPublic", + "DERPMesh", + "MachinePrivate", + "MachinePublic", + "ControlPrivate", + "DiscoPrivate", + "DiscoPublic", + "DiscoShared", + "HardwareAttestationPublic", + "ChallengePublic", + "NodePrivate", + "NodePublic", + ), + "tailscale.com/types/netlogtype": set.Of( + "Connection", + "Counts", + ), +} + +// RegisterPureIsZeroMethods specifies a list of pure IsZero methods +// where it is identical to calling [reflect.Value.IsZero] on the receiver. +// This is not strictly necessary, but allows for more accurate +// detection of improper use of `json` tags. +// +// This must be called at init and the input must not be mutated. +func RegisterPureIsZeroMethods(methods map[string]set.Set[string]) { + pureIsZeroMethods = methods +} diff --git a/cmd/vet/jsontags/report.go b/cmd/vet/jsontags/report.go new file mode 100644 index 000000000..8e5869060 --- /dev/null +++ b/cmd/vet/jsontags/report.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsontags + +import ( + "fmt" + "go/types" + "os" + "strings" + + _ "embed" + + "golang.org/x/tools/go/analysis" + "tailscale.com/util/set" +) + +var jsontagsAllowlist map[ReportKind]set.Set[string] + +// ParseAllowlist parses an allowlist of reports to ignore, +// which is a newline-delimited list of tuples separated by a tab, +// where each tuple is a [ReportKind] and a fully-qualified field name. +// +// For example: +// +// OmitEmptyUnsupportedInV1 tailscale.com/path/to/package.StructType.FieldName +// OmitEmptyUnsupportedInV1 tailscale.com/path/to/package.*.FieldName +// +// The struct type name may be "*" for anonymous struct types such +// as those declared within a function or as a type literal in a variable. +func ParseAllowlist(s string) map[ReportKind]set.Set[string] { + var allowlist map[ReportKind]set.Set[string] + for line := range strings.SplitSeq(s, "\n") { + kind, field, _ := strings.Cut(strings.TrimSpace(line), "\t") + if allowlist == nil { + allowlist = make(map[ReportKind]set.Set[string]) + } + fields := allowlist[ReportKind(kind)] + if fields == nil { + fields = make(set.Set[string]) + } + fields.Add(field) + allowlist[ReportKind(kind)] = fields + } + return allowlist +} + +// RegisterAllowlist registers an allowlist of reports to ignore, +// which is represented by a set of fully-qualified field names +// for each [ReportKind]. +// +// For example: +// +// { +// "OmitEmptyUnsupportedInV1": set.Of( +// "tailscale.com/path/to/package.StructType.FieldName", +// "tailscale.com/path/to/package.*.FieldName", +// ), +// } +// +// The struct type name may be "*" for anonymous struct types such +// as those declared within a function or as a type literal in a variable. +// +// This must be called at init and the input must not be mutated. +func RegisterAllowlist(allowlist map[ReportKind]set.Set[string]) { + jsontagsAllowlist = allowlist +} + +type ReportKind string + +const ( + OmitEmptyUnsupportedInV1 ReportKind = "OmitEmptyUnsupportedInV1" + OmitEmptyUnsupportedInV2 ReportKind = "OmitEmptyUnsupportedInV2" + OmitEmptyShouldBeOmitZero ReportKind = "OmitEmptyShouldBeOmitZero" + OmitEmptyShouldBeOmitZeroButHasIsZero ReportKind = "OmitEmptyShouldBeOmitZeroButHasIsZero" + StringOnNonNumericKind ReportKind = "StringOnNonNumericKind" + FormatMissingOnTimeDuration ReportKind = "FormatMissingOnTimeDuration" +) + +func (k ReportKind) message() string { + switch k { + case OmitEmptyUnsupportedInV1: + return "uses `omitempty` on an unsupported type in json/v1; should probably use `omitzero` instead" + case OmitEmptyUnsupportedInV2: + return "uses `omitempty` on an unsupported type in json/v2; should probably use `omitzero` instead" + case OmitEmptyShouldBeOmitZero: + return "should use `omitzero` instead of `omitempty`" + case OmitEmptyShouldBeOmitZeroButHasIsZero: + return "should probably use `omitzero` instead of `omitempty`" + case StringOnNonNumericKind: + return "must not use `string` on non-numeric types" + case FormatMissingOnTimeDuration: + return "must use an explicit `format` tag (e.g., `format:nano`) on a time.Duration type; see https://go.dev/issue/71631" + default: + return string(k) + } +} + +func report(pass *analysis.Pass, structType *types.Struct, fieldVar *types.Var, k ReportKind) { + // Lookup the full name of the struct type. + var fullName string + for _, name := range pass.Pkg.Scope().Names() { + if obj := pass.Pkg.Scope().Lookup(name); obj != nil { + if named, ok := obj.(*types.TypeName); ok { + if types.Identical(named.Type().Underlying(), structType) { + fullName = fmt.Sprintf("%v.%v.%v", named.Pkg().Path(), named.Name(), fieldVar.Name()) + break + } + } + } + } + if fullName == "" { + // Full name could not be found since this is probably an anonymous type + // or locally declared within a function scope. + // Use just the package path and field name instead. + // This is imprecise, but better than nothing. + fullName = fmt.Sprintf("%s.*.%s", fieldVar.Pkg().Path(), fieldVar.Name()) + } + if jsontagsAllowlist[k].Contains(fullName) { + return + } + + const appendAllowlist = "" + if appendAllowlist != "" { + if f, err := os.OpenFile(appendAllowlist, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0664); err == nil { + fmt.Fprintf(f, "%v\t%v\n", k, fullName) + f.Close() + } + } + + pass.Report(analysis.Diagnostic{ + Pos: fieldVar.Pos(), + Message: fmt.Sprintf("field %q %s", fieldVar.Name(), k.message()), + }) +} diff --git a/cmd/vet/jsontags_allowlist b/cmd/vet/jsontags_allowlist new file mode 100644 index 000000000..060a81b05 --- /dev/null +++ b/cmd/vet/jsontags_allowlist @@ -0,0 +1,315 @@ +OmitEmptyShouldBeOmitZero tailscale.com/client/web.authResponse.ViewerIdentity +OmitEmptyShouldBeOmitZero tailscale.com/cmd/k8s-operator.OwnerRef.Resource +OmitEmptyShouldBeOmitZero tailscale.com/cmd/tailscale/cli.apiResponse.Error +OmitEmptyShouldBeOmitZero tailscale.com/health.UnhealthyState.PrimaryAction +OmitEmptyShouldBeOmitZero tailscale.com/internal/client/tailscale.VIPService.Name +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AcceptDNS +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AcceptRoutes +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AllowLANWhileUsingExitNode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AppConnector +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AuthKey +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.AutoUpdate +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.DisableSNAT +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.Enabled +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ExitNode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.Hostname +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.Locked +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.NetfilterMode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.NoStatefulFiltering +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.OperatorUser +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.PostureChecking +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.RunSSHServer +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.RunWebClient +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ServeConfigTemp +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ServerURL +OmitEmptyShouldBeOmitZero tailscale.com/ipn.ConfigVAlpha.ShieldsUp +OmitEmptyShouldBeOmitZero tailscale.com/ipn.OutgoingFile.PeerID +OmitEmptyShouldBeOmitZero tailscale.com/ipn.Prefs.AutoExitNode +OmitEmptyShouldBeOmitZero tailscale.com/ipn.Prefs.NoStatefulFiltering +OmitEmptyShouldBeOmitZero tailscale.com/ipn.Prefs.RelayServerPort +OmitEmptyShouldBeOmitZero tailscale.com/ipn/auditlog.transaction.Action +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.AllowedIPs +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.Location +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.PrimaryRoutes +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.PeerStatus.Tags +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.Status.ExitNodeStatus +OmitEmptyShouldBeOmitZero tailscale.com/ipn/ipnstate.UpdateProgress.Status +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.AppConnector +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.Hostname +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.HostnamePrefix +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.Replicas +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ConnectorSpec.SubnetRouter +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Container.Debug +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Container.ImagePullPolicy +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Container.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.KubeAPIServerConfig.Mode +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Image +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Pod +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Replicas +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Nameserver.Service +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.Affinity +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.DNSConfig +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.DNSPolicy +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.TailscaleContainer +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Pod.TailscaleInitContainer +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.Metrics +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.StaticEndpoints +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.TailscaleConfig +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroupSpec.HostnamePrefix +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroupSpec.KubeAPIServer +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroupSpec.Replicas +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderContainer.ImagePullPolicy +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderContainer.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.Affinity +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.SecurityContext +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.StatefulSet.Pod +OmitEmptyShouldBeOmitZero tailscale.com/k8s-operator/apis/v1alpha1.Storage.S3 +OmitEmptyShouldBeOmitZero tailscale.com/kube/ingressservices.Config.IPv4Mapping +OmitEmptyShouldBeOmitZero tailscale.com/kube/ingressservices.Config.IPv6Mapping +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.Enabled +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.IssueCerts +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.Mode +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.APIServerProxyConfig.ServiceName +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.AcceptRoutes +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.APIServerProxy +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.App +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.AuthKey +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.HealthCheckEnabled +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.Hostname +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.LocalAddr +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.LocalPort +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.LogLevel +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.MetricsEnabled +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.ServerURL +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.ConfigV1Alpha1.State +OmitEmptyShouldBeOmitZero tailscale.com/kube/k8s-proxy/conf.VersionedConfig.V1Alpha1 +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubeapi.ObjectMeta.DeletionGracePeriodSeconds +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubeapi.Status.Details +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubeclient.JSONPatch.Value +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubetypes.*.Mode +OmitEmptyShouldBeOmitZero tailscale.com/kube/kubetypes.KubernetesCapRule.Impersonate +OmitEmptyShouldBeOmitZero tailscale.com/sessionrecording.CastHeader.Kubernetes +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.AuditLogRequest.Action +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Debug.Exit +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.DERPMap.HomeParams +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.DisplayMessage.PrimaryAction +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.AppConnector +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Container +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Desktop +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Location +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.NetInfo +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.StateEncrypted +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.TPM +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.Userspace +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Hostinfo.UserspaceRouter +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.ClientVersion +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.CollectServices +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.ControlDialPlan +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.Debug +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.DefaultAutoUpdate +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.DERPMap +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.DNSConfig +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.Node +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.PingRequest +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.SSHPolicy +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.MapResponse.TKAInfo +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.NetPortRange.Bits +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Node.Online +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Node.SelfNodeV4MasqAddrForThisPeer +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.Node.SelfNodeV6MasqAddrForThisPeer +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.PeerChange.Online +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.RegisterRequest.Auth +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.RegisterResponseAuth.Oauth2Token +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.SSHAction.OnRecordingFailure +OmitEmptyShouldBeOmitZero tailscale.com/tailcfg.SSHPrincipal.Node +OmitEmptyShouldBeOmitZero tailscale.com/tempfork/acme.*.ExternalAccountBinding +OmitEmptyShouldBeOmitZero tailscale.com/tsweb.AccessLogRecord.RequestID +OmitEmptyShouldBeOmitZero tailscale.com/types/opt.*.Unset +OmitEmptyShouldBeOmitZero tailscale.com/types/views.viewStruct.AddrsPtr +OmitEmptyShouldBeOmitZero tailscale.com/types/views.viewStruct.StringsPtr +OmitEmptyShouldBeOmitZero tailscale.com/wgengine/magicsock.EndpointChange.From +OmitEmptyShouldBeOmitZero tailscale.com/wgengine/magicsock.EndpointChange.To +OmitEmptyShouldBeOmitZeroButHasIsZero tailscale.com/types/persist.Persist.AttestationKey +OmitEmptyUnsupportedInV1 tailscale.com/client/tailscale.KeyCapabilities.Devices +OmitEmptyUnsupportedInV1 tailscale.com/client/tailscale/apitype.ExitNodeSuggestionResponse.Location +OmitEmptyUnsupportedInV1 tailscale.com/cmd/k8s-operator.ServiceMonitorSpec.NamespaceSelector +OmitEmptyUnsupportedInV1 tailscale.com/derp.ClientInfo.MeshKey +OmitEmptyUnsupportedInV1 tailscale.com/ipn.MaskedPrefs.AutoUpdateSet +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.Connector.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.Container.Resources +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.DNSConfig.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.ProxyClass.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.ProxyGroup.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.Recorder.ObjectMeta +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderContainer.Resources +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.Container +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderPod.ServiceAccount +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderSpec.Storage +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.RecorderStatefulSet.Pod +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.S3.Credentials +OmitEmptyUnsupportedInV1 tailscale.com/k8s-operator/apis/v1alpha1.S3Credentials.Secret +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.Event.FirstTimestamp +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.Event.LastTimestamp +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.Event.Source +OmitEmptyUnsupportedInV1 tailscale.com/kube/kubeapi.ObjectMeta.CreationTimestamp +OmitEmptyUnsupportedInV1 tailscale.com/tailcfg_test.*.Groups +OmitEmptyUnsupportedInV1 tailscale.com/tailcfg.Oauth2Token.Expiry +OmitEmptyUnsupportedInV1 tailscale.com/tailcfg.QueryFeatureRequest.NodeKey +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale.*.ExpirySeconds +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale.DerpRegion.Preferred +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale.DevicePostureIdentity.Disabled +OmitEmptyUnsupportedInV2 tailscale.com/client/tailscale/apitype.DNSResolver.UseWithExitNode +OmitEmptyUnsupportedInV2 tailscale.com/client/web.authResponse.NeedsSynoAuth +OmitEmptyUnsupportedInV2 tailscale.com/cmd/tsidp.tailscaleClaims.UserID +OmitEmptyUnsupportedInV2 tailscale.com/derp.ClientInfo.IsProber +OmitEmptyUnsupportedInV2 tailscale.com/derp.ClientInfo.Version +OmitEmptyUnsupportedInV2 tailscale.com/derp.ServerInfo.TokenBucketBytesBurst +OmitEmptyUnsupportedInV2 tailscale.com/derp.ServerInfo.TokenBucketBytesPerSecond +OmitEmptyUnsupportedInV2 tailscale.com/derp.ServerInfo.Version +OmitEmptyUnsupportedInV2 tailscale.com/health.UnhealthyState.ImpactsConnectivity +OmitEmptyUnsupportedInV2 tailscale.com/ipn.AutoUpdatePrefsMask.ApplySet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.AutoUpdatePrefsMask.CheckSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AdvertiseRoutesSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AdvertiseServicesSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AdvertiseTagsSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AppConnectorSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.AutoExitNodeSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ControlURLSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.CorpDNSSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.DriveSharesSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.EggSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ExitNodeAllowLANAccessSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ExitNodeIDSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ExitNodeIPSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ForceDaemonSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.HostnameSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.InternalExitNodePriorSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.LoggedOutSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NetfilterKindSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NetfilterModeSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NoSNATSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NoStatefulFilteringSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.NotepadURLsSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.OperatorUserSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.PostureCheckingSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ProfileNameSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RelayServerPortSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RouteAllSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RunSSHSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.RunWebClientSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.ShieldsUpSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.MaskedPrefs.WantRunningSet +OmitEmptyUnsupportedInV2 tailscale.com/ipn.PartialFile.Done +OmitEmptyUnsupportedInV2 tailscale.com/ipn.Prefs.Egg +OmitEmptyUnsupportedInV2 tailscale.com/ipn.Prefs.ForceDaemon +OmitEmptyUnsupportedInV2 tailscale.com/ipn.ServiceConfig.Tun +OmitEmptyUnsupportedInV2 tailscale.com/ipn.TCPPortHandler.HTTP +OmitEmptyUnsupportedInV2 tailscale.com/ipn.TCPPortHandler.HTTPS +OmitEmptyUnsupportedInV2 tailscale.com/ipn/auditlog.transaction.Retries +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PeerStatus.AltSharerUserID +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PeerStatus.Expired +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PeerStatus.ShareeNode +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PingResult.IsLocalIP +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.PingResult.PeerAPIPort +OmitEmptyUnsupportedInV2 tailscale.com/ipn/ipnstate.Status.HaveNodeKey +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.PortRange.EndPort +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.ProxyClassSpec.UseLetsEncryptStagingEnvironment +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.RecorderSpec.EnableUI +OmitEmptyUnsupportedInV2 tailscale.com/k8s-operator/apis/v1alpha1.TailscaleConfig.AcceptRoutes +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubeapi.Event.Count +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubeapi.ObjectMeta.Generation +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubeapi.Status.Code +OmitEmptyUnsupportedInV2 tailscale.com/kube/kubetypes.KubernetesCapRule.EnforceRecorder +OmitEmptyUnsupportedInV2 tailscale.com/log/sockstatlog.event.IsCellularInterface +OmitEmptyUnsupportedInV2 tailscale.com/sessionrecording.CastHeader.SrcNodeUserID +OmitEmptyUnsupportedInV2 tailscale.com/sessionrecording.Source.NodeUserID +OmitEmptyUnsupportedInV2 tailscale.com/sessionrecording.v2ResponseFrame.Ack +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg_test.*.ToggleOn +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.AuditLogRequest.Version +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NPostureIdentityResponse.PostureDisabled +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NSSHUsernamesRequest.Max +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NTLSCertInfo.Expired +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NTLSCertInfo.Missing +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.C2NTLSCertInfo.Valid +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ClientVersion.Notify +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ClientVersion.RunningLatest +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ClientVersion.UrgentSecurityUpdate +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ControlIPCandidate.DialStartDelaySec +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ControlIPCandidate.DialTimeoutSec +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.ControlIPCandidate.Priority +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Debug.DisableLogTail +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Debug.SleepSeconds +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPMap.OmitDefaultRegions +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.CanPort80 +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.DERPPort +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.InsecureForTests +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.STUNOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPNode.STUNPort +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.Avoid +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.Latitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.Longitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DERPRegion.NoMeasureNoHome +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DisplayMessage.ImpactsConnectivity +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.DNSConfig.Proxied +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.AllowsUpdate +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.IngressEnabled +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.NoLogsNoSupport +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.ShareeNode +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.ShieldsUp +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Hostinfo.WireIngress +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Location.Latitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Location.Longitude +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Location.Priority +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapRequest.MapSessionSeq +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapRequest.OmitPeers +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapRequest.ReadOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapResponse.KeepAlive +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.MapResponse.Seq +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.NetInfo.HavePortMap +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.Cap +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.Expired +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.HomeDERP +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.IsJailed +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.IsWireGuardOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.MachineAuthorized +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.Sharer +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.Node.UnsignedPeerAPIOnly +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PeerChange.Cap +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PeerChange.DERPRegion +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingRequest.Log +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingRequest.URLIsNoise +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.DERPRegionID +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.IsLocalIP +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.LatencySeconds +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.PingResponse.PeerAPIPort +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.QueryFeatureResponse.Complete +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.QueryFeatureResponse.ShouldWait +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.RegisterRequest.Ephemeral +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.RegisterRequest.SignatureType +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.Accept +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.AllowAgentForwarding +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.AllowLocalPortForwarding +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.AllowRemotePortForwarding +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.Reject +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHAction.SessionDuration +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.SSHPrincipal.Any +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TKAInfo.Disabled +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TPMInfo.FirmwareVersion +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TPMInfo.Model +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.TPMInfo.SpecRevision +OmitEmptyUnsupportedInV2 tailscale.com/tailcfg.WebClientAuthResponse.Complete +OmitEmptyUnsupportedInV2 tailscale.com/tempfork/acme.*.TermsAgreed +OmitEmptyUnsupportedInV2 tailscale.com/tstime/rate.jsonValue.Updated +OmitEmptyUnsupportedInV2 tailscale.com/tstime/rate.jsonValue.Value +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.Bytes +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.Code +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.Seconds +OmitEmptyUnsupportedInV2 tailscale.com/tsweb.AccessLogRecord.TLS +OmitEmptyUnsupportedInV2 tailscale.com/tsweb/varz.SomeStats.TotalY +OmitEmptyUnsupportedInV2 tailscale.com/types/appctype.AppConnectorConfig.AdvertiseRoutes +OmitEmptyUnsupportedInV2 tailscale.com/types/dnstype.Resolver.UseWithExitNode +OmitEmptyUnsupportedInV2 tailscale.com/types/opt.testStruct.Int +OmitEmptyUnsupportedInV2 tailscale.com/version.Meta.GitDirty +OmitEmptyUnsupportedInV2 tailscale.com/version.Meta.IsDev +OmitEmptyUnsupportedInV2 tailscale.com/version.Meta.UnstableBranch diff --git a/cmd/vet/vet.go b/cmd/vet/vet.go new file mode 100644 index 000000000..45473af48 --- /dev/null +++ b/cmd/vet/vet.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package vet is a tool to statically check Go source code. +package main + +import ( + _ "embed" + + "golang.org/x/tools/go/analysis/unitchecker" + "tailscale.com/cmd/vet/jsontags" +) + +//go:embed jsontags_allowlist +var jsontagsAllowlistSource string + +func init() { + jsontags.RegisterAllowlist(jsontags.ParseAllowlist(jsontagsAllowlistSource)) + jsontags.RegisterPureIsZeroMethods(jsontags.PureIsZeroMethodsInTailscaleModule) +} + +func main() { + unitchecker.Main(jsontags.Analyzer) +} diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 14a488861..d1c753db7 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -13,7 +13,7 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct --clone-only-type=OnlyGetClone +//go:generate go run tailscale.com/cmd/viewer --type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews --clone-only-type=OnlyGetClone type StructWithoutPtrs struct { Int int @@ -37,9 +37,14 @@ type Map struct { StructWithPtrKey map[StructWithPtrs]int `json:"-"` } +type StructWithNoView struct { + Value int +} + type StructWithPtrs struct { - Value *StructWithoutPtrs - Int *int + Value *StructWithoutPtrs + Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs `codegen:"noclone"` } @@ -135,7 +140,7 @@ func (c *Container[T]) Clone() *Container[T] { panic(fmt.Errorf("%T contains pointers, but is not cloneable", c.Item)) } -// ContainerView is a pre-defined readonly view of a Container[T]. +// ContainerView is a pre-defined read-only view of a Container[T]. type ContainerView[T views.ViewCloner[T, V], V views.StructView[T]] struct { // Đļ is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. @@ -173,7 +178,7 @@ func (c *MapContainer[K, V]) Clone() *MapContainer[K, V] { return &MapContainer[K, V]{m} } -// MapContainerView is a pre-defined readonly view of a [MapContainer][K, T]. +// MapContainerView is a pre-defined read-only view of a [MapContainer][K, T]. type MapContainerView[K comparable, T views.ViewCloner[T, V], V views.StructView[T]] struct { // Đļ is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. @@ -233,3 +238,7 @@ type GenericTypeAliasStruct[T integer, T2 views.ViewCloner[T2, V2], V2 views.Str NonCloneable T Cloneable T2 } + +type StructWithMapOfViews struct { + MapOfViews map[string]StructWithoutPtrsView +} diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 9131f5040..4602b9d88 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -28,6 +28,9 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { if dst.Int != nil { dst.Int = ptr.To(*src.Int) } + if dst.NoView != nil { + dst.NoView = ptr.To(*src.NoView) + } return dst } @@ -35,6 +38,7 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { var _StructWithPtrsCloneNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) @@ -543,3 +547,20 @@ func _GenericTypeAliasStructCloneNeedsRegeneration[T integer, T2 views.ViewClone Cloneable T2 }{}) } + +// Clone makes a deep copy of StructWithMapOfViews. +// The result aliases no memory with the original. +func (src *StructWithMapOfViews) Clone() *StructWithMapOfViews { + if src == nil { + return nil + } + dst := new(StructWithMapOfViews) + *dst = *src + dst.MapOfViews = maps.Clone(src.MapOfViews) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithMapOfViewsCloneNeedsRegeneration = StructWithMapOfViews(struct { + MapOfViews map[string]StructWithoutPtrsView +}{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index 9c74c9426..495281c23 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -6,17 +6,19 @@ package tests import ( - "encoding/json" + jsonv1 "encoding/json" "errors" "net/netip" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "golang.org/x/exp/constraints" "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct,StructWithMapOfViews -// View returns a readonly view of StructWithPtrs. +// View returns a read-only view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { return StructWithPtrsView{Đļ: p} } @@ -32,7 +34,7 @@ type StructWithPtrsView struct { Đļ *StructWithPtrs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithPtrsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -44,8 +46,17 @@ func (v StructWithPtrsView) AsStruct() *StructWithPtrs { return v.Đļ.Clone() } -func (v StructWithPtrsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithPtrsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithPtrsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *StructWithPtrsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -54,27 +65,31 @@ func (v *StructWithPtrsView) UnmarshalJSON(b []byte) error { return nil } var x StructWithPtrs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v StructWithPtrsView) Value() *StructWithoutPtrs { - if v.Đļ.Value == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithPtrsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithPtrs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err } - x := *v.Đļ.Value - return &x + v.Đļ = &x + return nil } -func (v StructWithPtrsView) Int() *int { - if v.Đļ.Int == nil { - return nil - } - x := *v.Đļ.Int - return &x +func (v StructWithPtrsView) Value() StructWithoutPtrsView { return v.Đļ.Value.View() } +func (v StructWithPtrsView) Int() views.ValuePointer[int] { return views.ValuePointerOf(v.Đļ.Int) } + +func (v StructWithPtrsView) NoView() views.ValuePointer[StructWithNoView] { + return views.ValuePointerOf(v.Đļ.NoView) } func (v StructWithPtrsView) NoCloneValue() *StructWithoutPtrs { return v.Đļ.NoCloneValue } @@ -85,10 +100,11 @@ func (v StructWithPtrsView) Equal(v2 StructWithPtrsView) bool { return v.Đļ.Equa var _StructWithPtrsViewNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) -// View returns a readonly view of StructWithoutPtrs. +// View returns a read-only view of StructWithoutPtrs. func (p *StructWithoutPtrs) View() StructWithoutPtrsView { return StructWithoutPtrsView{Đļ: p} } @@ -104,7 +120,7 @@ type StructWithoutPtrsView struct { Đļ *StructWithoutPtrs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithoutPtrsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -116,8 +132,17 @@ func (v StructWithoutPtrsView) AsStruct() *StructWithoutPtrs { return v.Đļ.Clone() } -func (v StructWithoutPtrsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithoutPtrsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithoutPtrsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *StructWithoutPtrsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -126,7 +151,20 @@ func (v *StructWithoutPtrsView) UnmarshalJSON(b []byte) error { return nil } var x StructWithoutPtrs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithoutPtrsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithoutPtrs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -142,7 +180,7 @@ var _StructWithoutPtrsViewNeedsRegeneration = StructWithoutPtrs(struct { Pfx netip.Prefix }{}) -// View returns a readonly view of Map. +// View returns a read-only view of Map. func (p *Map) View() MapView { return MapView{Đļ: p} } @@ -158,7 +196,7 @@ type MapView struct { Đļ *Map } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v MapView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -170,8 +208,17 @@ func (v MapView) AsStruct() *Map { return v.Đļ.Clone() } -func (v MapView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v MapView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v MapView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *MapView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -180,54 +227,61 @@ func (v *MapView) UnmarshalJSON(b []byte) error { return nil } var x Map - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v MapView) Int() views.Map[string, int] { return views.MapOf(v.Đļ.Int) } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *MapView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Map + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} +func (v MapView) Int() views.Map[string, int] { return views.MapOf(v.Đļ.Int) } func (v MapView) SliceInt() views.MapSlice[string, int] { return views.MapSliceOf(v.Đļ.SliceInt) } - func (v MapView) StructPtrWithPtr() views.MapFn[string, *StructWithPtrs, StructWithPtrsView] { return views.MapFnOf(v.Đļ.StructPtrWithPtr, func(t *StructWithPtrs) StructWithPtrsView { return t.View() }) } - func (v MapView) StructPtrWithoutPtr() views.MapFn[string, *StructWithoutPtrs, StructWithoutPtrsView] { return views.MapFnOf(v.Đļ.StructPtrWithoutPtr, func(t *StructWithoutPtrs) StructWithoutPtrsView { return t.View() }) } - func (v MapView) StructWithoutPtr() views.Map[string, StructWithoutPtrs] { return views.MapOf(v.Đļ.StructWithoutPtr) } - func (v MapView) SlicesWithPtrs() views.MapFn[string, []*StructWithPtrs, views.SliceView[*StructWithPtrs, StructWithPtrsView]] { return views.MapFnOf(v.Đļ.SlicesWithPtrs, func(t []*StructWithPtrs) views.SliceView[*StructWithPtrs, StructWithPtrsView] { return views.SliceOfViews[*StructWithPtrs, StructWithPtrsView](t) }) } - func (v MapView) SlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrs, views.SliceView[*StructWithoutPtrs, StructWithoutPtrsView]] { return views.MapFnOf(v.Đļ.SlicesWithoutPtrs, func(t []*StructWithoutPtrs) views.SliceView[*StructWithoutPtrs, StructWithoutPtrsView] { return views.SliceOfViews[*StructWithoutPtrs, StructWithoutPtrsView](t) }) } - func (v MapView) StructWithoutPtrKey() views.Map[StructWithoutPtrs, int] { return views.MapOf(v.Đļ.StructWithoutPtrKey) } - func (v MapView) StructWithPtr() views.MapFn[string, StructWithPtrs, StructWithPtrsView] { return views.MapFnOf(v.Đļ.StructWithPtr, func(t StructWithPtrs) StructWithPtrsView { return t.View() }) } + +// Unsupported views. func (v MapView) SliceIntPtr() map[string][]*int { panic("unsupported") } func (v MapView) PointerKey() map[*string]int { panic("unsupported") } func (v MapView) StructWithPtrKey() map[StructWithPtrs]int { panic("unsupported") } @@ -248,7 +302,7 @@ var _MapViewNeedsRegeneration = Map(struct { StructWithPtrKey map[StructWithPtrs]int }{}) -// View returns a readonly view of StructWithSlices. +// View returns a read-only view of StructWithSlices. func (p *StructWithSlices) View() StructWithSlicesView { return StructWithSlicesView{Đļ: p} } @@ -264,7 +318,7 @@ type StructWithSlicesView struct { Đļ *StructWithSlices } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithSlicesView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -276,8 +330,17 @@ func (v StructWithSlicesView) AsStruct() *StructWithSlices { return v.Đļ.Clone() } -func (v StructWithSlicesView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithSlicesView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithSlicesView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *StructWithSlicesView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -286,7 +349,20 @@ func (v *StructWithSlicesView) UnmarshalJSON(b []byte) error { return nil } var x StructWithSlices - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithSlicesView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithSlices + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -307,8 +383,10 @@ func (v StructWithSlicesView) Prefixes() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.Prefixes) } func (v StructWithSlicesView) Data() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.Đļ.Data) } -func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") } -func (v StructWithSlicesView) Ints() *int { panic("unsupported") } + +// Unsupported views. +func (v StructWithSlicesView) Structs() StructWithPtrs { panic("unsupported") } +func (v StructWithSlicesView) Ints() *int { panic("unsupported") } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct { @@ -322,7 +400,7 @@ var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct { Ints []*int }{}) -// View returns a readonly view of StructWithEmbedded. +// View returns a read-only view of StructWithEmbedded. func (p *StructWithEmbedded) View() StructWithEmbeddedView { return StructWithEmbeddedView{Đļ: p} } @@ -338,7 +416,7 @@ type StructWithEmbeddedView struct { Đļ *StructWithEmbedded } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithEmbeddedView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -350,8 +428,17 @@ func (v StructWithEmbeddedView) AsStruct() *StructWithEmbedded { return v.Đļ.Clone() } -func (v StructWithEmbeddedView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithEmbeddedView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithEmbeddedView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *StructWithEmbeddedView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -360,7 +447,20 @@ func (v *StructWithEmbeddedView) UnmarshalJSON(b []byte) error { return nil } var x StructWithEmbedded - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithEmbeddedView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithEmbedded + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -378,7 +478,7 @@ var _StructWithEmbeddedViewNeedsRegeneration = StructWithEmbedded(struct { StructWithSlices }{}) -// View returns a readonly view of GenericIntStruct. +// View returns a read-only view of GenericIntStruct. func (p *GenericIntStruct[T]) View() GenericIntStructView[T] { return GenericIntStructView[T]{Đļ: p} } @@ -394,7 +494,7 @@ type GenericIntStructView[T constraints.Integer] struct { Đļ *GenericIntStruct[T] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericIntStructView[T]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -406,8 +506,17 @@ func (v GenericIntStructView[T]) AsStruct() *GenericIntStruct[T] { return v.Đļ.Clone() } -func (v GenericIntStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v GenericIntStructView[T]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v GenericIntStructView[T]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -416,25 +525,35 @@ func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { return nil } var x GenericIntStruct[T] - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v GenericIntStructView[T]) Value() T { return v.Đļ.Value } -func (v GenericIntStructView[T]) Pointer() *T { - if v.Đļ.Pointer == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *GenericIntStructView[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") } - x := *v.Đļ.Pointer - return &x + var x GenericIntStruct[T] + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil } -func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.Đļ.Slice) } +func (v GenericIntStructView[T]) Value() T { return v.Đļ.Value } +func (v GenericIntStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.Đļ.Pointer) +} -func (v GenericIntStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.Đļ.Map) } +func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.Đļ.Slice) } +func (v GenericIntStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.Đļ.Map) } + +// Unsupported views. func (v GenericIntStructView[T]) PtrSlice() *T { panic("unsupported") } func (v GenericIntStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") } func (v GenericIntStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") } @@ -454,7 +573,7 @@ func _GenericIntStructViewNeedsRegeneration[T constraints.Integer](GenericIntStr }{}) } -// View returns a readonly view of GenericNoPtrsStruct. +// View returns a read-only view of GenericNoPtrsStruct. func (p *GenericNoPtrsStruct[T]) View() GenericNoPtrsStructView[T] { return GenericNoPtrsStructView[T]{Đļ: p} } @@ -470,7 +589,7 @@ type GenericNoPtrsStructView[T StructWithoutPtrs | netip.Prefix | BasicType] str Đļ *GenericNoPtrsStruct[T] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericNoPtrsStructView[T]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -482,8 +601,17 @@ func (v GenericNoPtrsStructView[T]) AsStruct() *GenericNoPtrsStruct[T] { return v.Đļ.Clone() } -func (v GenericNoPtrsStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v GenericNoPtrsStructView[T]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v GenericNoPtrsStructView[T]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -492,25 +620,35 @@ func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { return nil } var x GenericNoPtrsStruct[T] - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v GenericNoPtrsStructView[T]) Value() T { return v.Đļ.Value } -func (v GenericNoPtrsStructView[T]) Pointer() *T { - if v.Đļ.Pointer == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *GenericNoPtrsStructView[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") } - x := *v.Đļ.Pointer - return &x + var x GenericNoPtrsStruct[T] + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil } -func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.Đļ.Slice) } +func (v GenericNoPtrsStructView[T]) Value() T { return v.Đļ.Value } +func (v GenericNoPtrsStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.Đļ.Pointer) +} + +func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.Đļ.Slice) } +func (v GenericNoPtrsStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.Đļ.Map) } -func (v GenericNoPtrsStructView[T]) Map() views.Map[string, T] { return views.MapOf(v.Đļ.Map) } +// Unsupported views. func (v GenericNoPtrsStructView[T]) PtrSlice() *T { panic("unsupported") } func (v GenericNoPtrsStructView[T]) PtrKeyMap() map[*T]string { panic("unsupported") } func (v GenericNoPtrsStructView[T]) PtrValueMap() map[string]*T { panic("unsupported") } @@ -530,7 +668,7 @@ func _GenericNoPtrsStructViewNeedsRegeneration[T StructWithoutPtrs | netip.Prefi }{}) } -// View returns a readonly view of GenericCloneableStruct. +// View returns a read-only view of GenericCloneableStruct. func (p *GenericCloneableStruct[T, V]) View() GenericCloneableStructView[T, V] { return GenericCloneableStructView[T, V]{Đļ: p} } @@ -546,7 +684,7 @@ type GenericCloneableStructView[T views.ViewCloner[T, V], V views.StructView[T]] Đļ *GenericCloneableStruct[T, V] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericCloneableStructView[T, V]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -558,8 +696,17 @@ func (v GenericCloneableStructView[T, V]) AsStruct() *GenericCloneableStruct[T, return v.Đļ.Clone() } -func (v GenericCloneableStructView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v GenericCloneableStructView[T, V]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v GenericCloneableStructView[T, V]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *GenericCloneableStructView[T, V]) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -568,7 +715,20 @@ func (v *GenericCloneableStructView[T, V]) UnmarshalJSON(b []byte) error { return nil } var x GenericCloneableStruct[T, V] - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *GenericCloneableStructView[T, V]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x GenericCloneableStruct[T, V] + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -579,12 +739,13 @@ func (v GenericCloneableStructView[T, V]) Value() V { return v.Đļ.Value.View() } func (v GenericCloneableStructView[T, V]) Slice() views.SliceView[T, V] { return views.SliceOfViews[T, V](v.Đļ.Slice) } - func (v GenericCloneableStructView[T, V]) Map() views.MapFn[string, T, V] { return views.MapFnOf(v.Đļ.Map, func(t T) V { return t.View() }) } + +// Unsupported views. func (v GenericCloneableStructView[T, V]) Pointer() map[string]T { panic("unsupported") } func (v GenericCloneableStructView[T, V]) PtrSlice() *T { panic("unsupported") } func (v GenericCloneableStructView[T, V]) PtrKeyMap() map[*T]string { panic("unsupported") } @@ -605,7 +766,7 @@ func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V vi }{}) } -// View returns a readonly view of StructWithContainers. +// View returns a read-only view of StructWithContainers. func (p *StructWithContainers) View() StructWithContainersView { return StructWithContainersView{Đļ: p} } @@ -621,7 +782,7 @@ type StructWithContainersView struct { Đļ *StructWithContainers } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithContainersView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -633,8 +794,17 @@ func (v StructWithContainersView) AsStruct() *StructWithContainers { return v.Đļ.Clone() } -func (v StructWithContainersView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithContainersView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithContainersView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *StructWithContainersView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -643,7 +813,20 @@ func (v *StructWithContainersView) UnmarshalJSON(b []byte) error { return nil } var x StructWithContainers - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithContainersView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithContainers + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -677,7 +860,7 @@ var _StructWithContainersViewNeedsRegeneration = StructWithContainers(struct { CloneableGenericMap MapContainer[int, *GenericNoPtrsStruct[int]] }{}) -// View returns a readonly view of StructWithTypeAliasFields. +// View returns a read-only view of StructWithTypeAliasFields. func (p *StructWithTypeAliasFields) View() StructWithTypeAliasFieldsView { return StructWithTypeAliasFieldsView{Đļ: p} } @@ -693,7 +876,7 @@ type StructWithTypeAliasFieldsView struct { Đļ *StructWithTypeAliasFields } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithTypeAliasFieldsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -705,8 +888,17 @@ func (v StructWithTypeAliasFieldsView) AsStruct() *StructWithTypeAliasFields { return v.Đļ.Clone() } -func (v StructWithTypeAliasFieldsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithTypeAliasFieldsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithTypeAliasFieldsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -715,51 +907,55 @@ func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { return nil } var x StructWithTypeAliasFields - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.Đļ.WithPtr.View() } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithTypeAliasFieldsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithTypeAliasFields + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsAliasView { return v.Đļ.WithPtr.View() } func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.Đļ.WithoutPtr } func (v StructWithTypeAliasFieldsView) WithPtrByPtr() StructWithPtrsAliasView { return v.Đļ.WithPtrByPtr.View() } -func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() *StructWithoutPtrsAlias { - if v.Đļ.WithoutPtrByPtr == nil { - return nil - } - x := *v.Đļ.WithoutPtrByPtr - return &x +func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() StructWithoutPtrsAliasView { + return v.Đļ.WithoutPtrByPtr.View() } - func (v StructWithTypeAliasFieldsView) SliceWithPtrs() views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](v.Đļ.SliceWithPtrs) } func (v StructWithTypeAliasFieldsView) SliceWithoutPtrs() views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView] { return views.SliceOfViews[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView](v.Đļ.SliceWithoutPtrs) } - func (v StructWithTypeAliasFieldsView) MapWithPtrs() views.MapFn[string, *StructWithPtrsAlias, StructWithPtrsAliasView] { return views.MapFnOf(v.Đļ.MapWithPtrs, func(t *StructWithPtrsAlias) StructWithPtrsAliasView { return t.View() }) } - func (v StructWithTypeAliasFieldsView) MapWithoutPtrs() views.MapFn[string, *StructWithoutPtrsAlias, StructWithoutPtrsAliasView] { return views.MapFnOf(v.Đļ.MapWithoutPtrs, func(t *StructWithoutPtrsAlias) StructWithoutPtrsAliasView { return t.View() }) } - func (v StructWithTypeAliasFieldsView) MapOfSlicesWithPtrs() views.MapFn[string, []*StructWithPtrsAlias, views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView]] { return views.MapFnOf(v.Đļ.MapOfSlicesWithPtrs, func(t []*StructWithPtrsAlias) views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](t) }) } - func (v StructWithTypeAliasFieldsView) MapOfSlicesWithoutPtrs() views.MapFn[string, []*StructWithoutPtrsAlias, views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView]] { return views.MapFnOf(v.Đļ.MapOfSlicesWithoutPtrs, func(t []*StructWithoutPtrsAlias) views.SliceView[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView] { return views.SliceOfViews[*StructWithoutPtrsAlias, StructWithoutPtrsAliasView](t) @@ -780,7 +976,7 @@ var _StructWithTypeAliasFieldsViewNeedsRegeneration = StructWithTypeAliasFields( MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias }{}) -// View returns a readonly view of GenericTypeAliasStruct. +// View returns a read-only view of GenericTypeAliasStruct. func (p *GenericTypeAliasStruct[T, T2, V2]) View() GenericTypeAliasStructView[T, T2, V2] { return GenericTypeAliasStructView[T, T2, V2]{Đļ: p} } @@ -796,7 +992,7 @@ type GenericTypeAliasStructView[T integer, T2 views.ViewCloner[T2, V2], V2 views Đļ *GenericTypeAliasStruct[T, T2, V2] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericTypeAliasStructView[T, T2, V2]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -808,10 +1004,17 @@ func (v GenericTypeAliasStructView[T, T2, V2]) AsStruct() *GenericTypeAliasStruc return v.Đļ.Clone() } +// MarshalJSON implements [jsonv1.Marshaler]. func (v GenericTypeAliasStructView[T, T2, V2]) MarshalJSON() ([]byte, error) { - return json.Marshal(v.Đļ) + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v GenericTypeAliasStructView[T, T2, V2]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) } +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *GenericTypeAliasStructView[T, T2, V2]) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -820,7 +1023,20 @@ func (v *GenericTypeAliasStructView[T, T2, V2]) UnmarshalJSON(b []byte) error { return nil } var x GenericTypeAliasStruct[T, T2, V2] - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *GenericTypeAliasStructView[T, T2, V2]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x GenericTypeAliasStruct[T, T2, V2] + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -837,3 +1053,79 @@ func _GenericTypeAliasStructViewNeedsRegeneration[T integer, T2 views.ViewCloner Cloneable T2 }{}) } + +// View returns a read-only view of StructWithMapOfViews. +func (p *StructWithMapOfViews) View() StructWithMapOfViewsView { + return StructWithMapOfViewsView{Đļ: p} +} + +// StructWithMapOfViewsView provides a read-only view over StructWithMapOfViews. +// +// Its methods should only be called if `Valid()` returns true. +type StructWithMapOfViewsView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *StructWithMapOfViews +} + +// Valid reports whether v's underlying value is non-nil. +func (v StructWithMapOfViewsView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v StructWithMapOfViewsView) AsStruct() *StructWithMapOfViews { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v StructWithMapOfViewsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v StructWithMapOfViewsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *StructWithMapOfViewsView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x StructWithMapOfViews + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *StructWithMapOfViewsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x StructWithMapOfViews + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v StructWithMapOfViewsView) MapOfViews() views.Map[string, StructWithoutPtrsView] { + return views.MapOf(v.Đļ.MapOfViews) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _StructWithMapOfViewsViewNeedsRegeneration = StructWithMapOfViews(struct { + MapOfViews map[string]StructWithoutPtrsView +}{}) diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 96223297b..3fae737cd 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -9,6 +9,8 @@ import ( "bytes" "flag" "fmt" + "go/ast" + "go/token" "go/types" "html/template" "log" @@ -17,11 +19,12 @@ import ( "strings" "tailscale.com/util/codegen" + "tailscale.com/util/mak" "tailscale.com/util/must" ) const viewTemplateStr = `{{define "common"}} -// View returns a readonly view of {{.StructName}}. +// View returns a read-only view of {{.StructName}}. func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} { return {{.ViewName}}{{.TypeParamNames}}{Đļ: p} } @@ -37,7 +40,7 @@ type {{.ViewName}}{{.TypeParams}} struct { Đļ *{{.StructName}}{{.TypeParamNames}} } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -49,8 +52,17 @@ func (v {{.ViewName}}{{.TypeParamNames}}) AsStruct() *{{.StructName}}{{.TypePara return v.Đļ.Clone() } -func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v {{.ViewName}}{{.TypeParamNames}}) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -59,10 +71,23 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { return nil } var x {{.StructName}}{{.TypeParamNames}} - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x {{.StructName}}{{.TypeParamNames}} + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } - v.Đļ=&x + v.Đļ = &x return nil } @@ -79,25 +104,16 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { {{end}} {{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.Đļ.{{.FieldName}}) } {{end}} -{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { - if v.Đļ.{{.FieldName}} == nil { - return nil - } - x := *v.Đļ.{{.FieldName}} - return &x -} +{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ValuePointer[{{.FieldType}}] { return views.ValuePointerOf(v.Đļ.{{.FieldName}}) } {{end}} -{{define "mapField"}} -func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.Đļ.{{.FieldName}})} +{{define "mapField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.Map[{{.MapKeyType}},{{.MapValueType}}] { return views.MapOf(v.Đļ.{{.FieldName}})} {{end}} -{{define "mapFnField"}} -func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.Đļ.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { +{{define "mapFnField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapFn[{{.MapKeyType}},{{.MapValueType}},{{.MapValueView}}] { return views.MapFnOf(v.Đļ.{{.FieldName}}, func (t {{.MapValueType}}) {{.MapValueView}} { return {{.MapFn}} })} {{end}} -{{define "mapSliceField"}} -func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.Đļ.{{.FieldName}}) } +{{define "mapSliceField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.MapSlice[{{.MapKeyType}},{{.MapValueType}}] { return views.MapSliceOf(v.Đļ.{{.FieldName}}) } {{end}} {{define "unsupportedField"}}func(v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} {panic("unsupported")} {{end}} @@ -126,13 +142,89 @@ func requiresCloning(t types.Type) (shallow, deep bool, base types.Type) { return p, p, t } -func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thisPkg *types.Package) { +type fieldNameKey struct { + typeName string + fieldName string +} + +// getFieldComments extracts field comments from the AST for a given struct type. +func getFieldComments(syntax []*ast.File) map[fieldNameKey]string { + if len(syntax) == 0 { + return nil + } + var fieldComments map[fieldNameKey]string + + // Search through all AST files in the package + for _, file := range syntax { + // Look for the type declaration + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + typeName := typeSpec.Name.Name + + // Check if it's a struct type + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + continue + } + + // Extract field comments + for _, field := range structType.Fields.List { + if len(field.Names) == 0 { + // Anonymous field or no names + continue + } + + // Get the field name + fieldName := field.Names[0].Name + key := fieldNameKey{typeName, fieldName} + + // Get the comment + var comment string + if field.Doc != nil && field.Doc.Text() != "" { + // Format the comment for Go code generation + comment = strings.TrimSpace(field.Doc.Text()) + // Convert multi-line comments to proper Go comment format + var sb strings.Builder + for line := range strings.Lines(comment) { + sb.WriteString("// ") + sb.WriteString(line) + } + if sb.Len() > 0 { + comment = sb.String() + } + } else if field.Comment != nil && field.Comment.Text() != "" { + // Handle inline comments + comment = "// " + strings.TrimSpace(field.Comment.Text()) + } + if comment != "" { + mak.Set(&fieldComments, key, comment) + } + } + } + } + } + + return fieldComments +} + +func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, fieldComments map[fieldNameKey]string) { t, ok := typ.Underlying().(*types.Struct) if !ok || codegen.IsViewType(t) { return } - it.Import("encoding/json") - it.Import("errors") + it.Import("jsonv1", "encoding/json") + it.Import("jsonv2", "github.com/go-json-experiment/json") + it.Import("", "github.com/go-json-experiment/json/jsontext") + it.Import("", "errors") args := struct { StructName string @@ -149,7 +241,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi MapValueView string MapFn string - // MakeViewFnName is the name of the function that accepts a value and returns a readonly view of it. + // MakeViewFnName is the name of the function that accepts a value and returns a read-only view of it. MakeViewFnName string }{ StructName: typ.Obj().Name(), @@ -164,6 +256,15 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi log.Fatal(err) } } + writeTemplateWithComment := func(name, fieldName string) { + // Write the field comment if it exists + key := fieldNameKey{args.StructName, fieldName} + if comment, ok := fieldComments[key]; ok && comment != "" { + fmt.Fprintln(buf, comment) + } + writeTemplate(name) + } + writeTemplate("common") for i := range t.NumFields() { f := t.Field(i) @@ -178,7 +279,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi } if !codegen.ContainsPointers(fieldType) || codegen.IsViewType(fieldType) || codegen.HasNoClone(t.Tag(i)) { args.FieldType = it.QualifiedName(fieldType) - writeTemplate("valueField") + writeTemplateWithComment("valueField", fname) continue } switch underlying := fieldType.Underlying().(type) { @@ -188,46 +289,46 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi switch elem.String() { case "byte": args.FieldType = it.QualifiedName(fieldType) - it.Import("tailscale.com/types/views") - writeTemplate("byteSliceField") + it.Import("", "tailscale.com/types/views") + writeTemplateWithComment("byteSliceField", fname) default: args.FieldType = it.QualifiedName(elem) - it.Import("tailscale.com/types/views") + it.Import("", "tailscale.com/types/views") shallow, deep, base := requiresCloning(elem) if deep { switch elem.Underlying().(type) { case *types.Pointer: if _, isIface := base.Underlying().(*types.Interface); !isIface { args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View") - writeTemplate("viewSliceField") + writeTemplateWithComment("viewSliceField", fname) } else { - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) } continue case *types.Interface: if viewType := viewTypeForValueType(elem); viewType != nil { args.FieldViewName = it.QualifiedName(viewType) - writeTemplate("viewSliceField") + writeTemplateWithComment("viewSliceField", fname) continue } } - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) continue } else if shallow { switch base.Underlying().(type) { case *types.Basic, *types.Interface: - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) default: if _, isIface := base.Underlying().(*types.Interface); !isIface { args.FieldViewName = appendNameSuffix(it.QualifiedName(base), "View") - writeTemplate("viewSliceField") + writeTemplateWithComment("viewSliceField", fname) } else { - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) } } continue } - writeTemplate("sliceField") + writeTemplateWithComment("sliceField", fname) } continue case *types.Struct: @@ -236,28 +337,29 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi if codegen.ContainsPointers(strucT) { if viewType := viewTypeForValueType(fieldType); viewType != nil { args.FieldViewName = it.QualifiedName(viewType) - writeTemplate("viewField") + writeTemplateWithComment("viewField", fname) continue } if viewType, makeViewFn := viewTypeForContainerType(fieldType); viewType != nil { args.FieldViewName = it.QualifiedName(viewType) args.MakeViewFnName = it.PackagePrefix(makeViewFn.Pkg()) + makeViewFn.Name() - writeTemplate("makeViewField") + writeTemplateWithComment("makeViewField", fname) continue } - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) continue } - writeTemplate("valueField") + writeTemplateWithComment("valueField", fname) continue case *types.Map: m := underlying args.FieldType = it.QualifiedName(fieldType) shallow, deep, key := requiresCloning(m.Key()) if shallow || deep { - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) continue } + it.Import("", "tailscale.com/types/views") args.MapKeyType = it.QualifiedName(key) mElem := m.Elem() var template string @@ -265,14 +367,21 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi case *types.Struct, *types.Named, *types.Alias: strucT := u args.FieldType = it.QualifiedName(fieldType) - if codegen.ContainsPointers(strucT) { + + // We need to call View() unless the type is + // either a View itself or does not contain + // pointers (and can thus be shallow-copied). + // + // Otherwise, we need to create a View of the + // map value. + if codegen.IsViewType(strucT) || !codegen.ContainsPointers(strucT) { + template = "mapField" + args.MapValueType = it.QualifiedName(mElem) + } else { args.MapFn = "t.View()" template = "mapFnField" args.MapValueType = it.QualifiedName(mElem) args.MapValueView = appendNameSuffix(args.MapValueType, "View") - } else { - template = "mapField" - args.MapValueType = it.QualifiedName(mElem) } case *types.Basic: template = "mapField" @@ -339,7 +448,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi default: template = "unsupportedField" } - writeTemplate(template) + writeTemplateWithComment(template, fname) continue case *types.Pointer: ptr := underlying @@ -349,25 +458,47 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi if _, isIface := base.Underlying().(*types.Interface); !isIface { args.FieldType = it.QualifiedName(base) args.FieldViewName = appendNameSuffix(args.FieldType, "View") - writeTemplate("viewField") + writeTemplateWithComment("viewField", fname) } else { - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) } - } else { - args.FieldType = it.QualifiedName(ptr) - writeTemplate("valuePointerField") + continue } + + // If a view type is already defined for the base type, use it as the field's view type. + if viewType := viewTypeForValueType(base); viewType != nil { + args.FieldType = it.QualifiedName(base) + args.FieldViewName = it.QualifiedName(viewType) + writeTemplateWithComment("viewField", fname) + continue + } + + // Otherwise, if the unaliased base type is a named type whose view type will be generated by this viewer invocation, + // append the "View" suffix to the unaliased base type name and use it as the field's view type. + if base, ok := types.Unalias(base).(*types.Named); ok && slices.Contains(typeNames, it.QualifiedName(base)) { + baseTypeName := it.QualifiedName(base) + args.FieldType = baseTypeName + args.FieldViewName = appendNameSuffix(args.FieldType, "View") + writeTemplateWithComment("viewField", fname) + continue + } + + // Otherwise, if the base type does not require deep cloning, has no existing view type, + // and will not have a generated view type, use views.ValuePointer[T] as the field's view type. + // Its Get/GetOk methods return stack-allocated shallow copies of the field's value. + args.FieldType = it.QualifiedName(base) + writeTemplateWithComment("valuePointerField", fname) continue case *types.Interface: // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field. // This includes scenarios where fieldType is a constrained type parameter. if viewType := viewTypeForValueType(underlying); viewType != nil { args.FieldViewName = it.QualifiedName(viewType) - writeTemplate("viewField") + writeTemplateWithComment("viewField", fname) continue } } - writeTemplate("unsupportedField") + writeTemplateWithComment("unsupportedField", fname) } for i := range typ.NumMethods() { f := typ.Method(i) @@ -404,6 +535,33 @@ func appendNameSuffix(name, suffix string) string { return name + suffix } +func typeNameOf(typ types.Type) (name *types.TypeName, ok bool) { + switch t := typ.(type) { + case *types.Alias: + return t.Obj(), true + case *types.Named: + return t.Obj(), true + default: + return nil, false + } +} + +func lookupViewType(typ types.Type) types.Type { + for { + if typeName, ok := typeNameOf(typ); ok && typeName.Pkg() != nil { + if viewTypeObj := typeName.Pkg().Scope().Lookup(typeName.Name() + "View"); viewTypeObj != nil { + return viewTypeObj.Type() + } + } + switch alias := typ.(type) { + case *types.Alias: + typ = alias.Rhs() + default: + return nil + } + } +} + func viewTypeForValueType(typ types.Type) types.Type { if ptr, ok := typ.(*types.Pointer); ok { return viewTypeForValueType(ptr.Elem()) @@ -416,7 +574,12 @@ func viewTypeForValueType(typ types.Type) types.Type { if !ok || sig.Results().Len() != 1 { return nil } - return sig.Results().At(0).Type() + viewType := sig.Results().At(0).Type() + // Check if the typ's package defines an alias for the view type, and use it if so. + if viewTypeAlias, ok := lookupViewType(typ).(*types.Alias); ok && types.AssignableTo(viewType, viewTypeAlias) { + viewType = viewTypeAlias + } + return viewType } func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { @@ -554,6 +717,7 @@ func main() { log.Fatal(err) } it := codegen.NewImportTracker(pkg.Types) + fieldComments := getFieldComments(pkg.Syntax) cloneOnlyType := map[string]bool{} for _, t := range strings.Split(*flagCloneOnlyTypes, ",") { @@ -581,7 +745,7 @@ func main() { if !hasClone { runCloner = true } - genView(buf, it, typ, pkg.Types) + genView(buf, it, typ, fieldComments) } out := pkg.Name + "_view" if *flagBuildTags == "test" { diff --git a/cmd/viewer/viewer_test.go b/cmd/viewer/viewer_test.go new file mode 100644 index 000000000..1e24b7050 --- /dev/null +++ b/cmd/viewer/viewer_test.go @@ -0,0 +1,79 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "tailscale.com/util/codegen" +) + +func TestViewerImports(t *testing.T) { + tests := []struct { + name string + content string + typeNames []string + wantImports [][2]string + }{ + { + name: "Map", + content: `type Test struct { Map map[string]int }`, + typeNames: []string{"Test"}, + wantImports: [][2]string{{"", "tailscale.com/types/views"}}, + }, + { + name: "Slice", + content: `type Test struct { Slice []int }`, + typeNames: []string{"Test"}, + wantImports: [][2]string{{"", "tailscale.com/types/views"}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", "package test\n\n"+tt.content, 0) + if err != nil { + fmt.Println("Error parsing:", err) + return + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + } + + conf := types.Config{} + pkg, err := conf.Check("", fset, []*ast.File{f}, info) + if err != nil { + t.Fatal(err) + } + var fieldComments map[fieldNameKey]string // don't need it for this test. + + var output bytes.Buffer + tracker := codegen.NewImportTracker(pkg) + for i := range tt.typeNames { + typeName, ok := pkg.Scope().Lookup(tt.typeNames[i]).(*types.TypeName) + if !ok { + t.Fatalf("type %q does not exist", tt.typeNames[i]) + } + namedType, ok := typeName.Type().(*types.Named) + if !ok { + t.Fatalf("%q is not a named type", tt.typeNames[i]) + } + genView(&output, tracker, namedType, fieldComments) + } + + for _, pkg := range tt.wantImports { + if !tracker.Has(pkg[0], pkg[1]) { + t.Errorf("missing import %q", pkg) + } + } + }) + } +} diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go index 1eb4f65ef..9dd4d8cfa 100644 --- a/cmd/vnet/vnet-main.go +++ b/cmd/vnet/vnet-main.go @@ -7,15 +7,21 @@ package main import ( "context" + "encoding/binary" "flag" + "fmt" + "io" "log" "net" "net/http" "net/http/httputil" "net/url" "os" + "path/filepath" + "slices" "time" + "github.com/coder/websocket" "tailscale.com/tstest/natlab/vnet" "tailscale.com/types/logger" "tailscale.com/util/must" @@ -31,10 +37,18 @@ var ( pcapFile = flag.String("pcap", "", "if non-empty, filename to write pcap") v4 = flag.Bool("v4", true, "enable IPv4") v6 = flag.Bool("v6", true, "enable IPv6") + + wsproxyListen = flag.String("wsproxy", "", "if non-empty, TCP address to run websocket server on. See https://github.com/copy/v86/blob/master/docs/networking.md#backend-url-schemes") ) func main() { flag.Parse() + if *wsproxyListen != "" { + if err := runWSProxy(); err != nil { + log.Fatalf("runWSProxy: %v", err) + } + return + } if _, err := os.Stat(*listen); err == nil { os.Remove(*listen) @@ -137,3 +151,168 @@ func main() { go s.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) } } + +func runWSProxy() error { + ln, err := net.Listen("tcp", *wsproxyListen) + if err != nil { + return err + } + defer ln.Close() + + log.Printf("Running wsproxy mode on %v ...", *wsproxyListen) + + var hs http.Server + hs.Handler = http.HandlerFunc(handleWebSocket) + + return hs.Serve(ln) +} + +func handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + log.Printf("Upgrade error: %v", err) + return + } + defer conn.Close(websocket.StatusInternalError, "closing") + log.Printf("WebSocket client connected: %s", r.RemoteAddr) + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + messageType, firstData, err := conn.Read(ctx) + if err != nil { + log.Printf("ReadMessage first: %v", err) + return + } + if messageType != websocket.MessageBinary { + log.Printf("Ignoring non-binary message") + return + } + if len(firstData) < 12 { + log.Printf("Ignoring short message") + return + } + clientMAC := vnet.MAC(firstData[6:12]) + + // Set up a qemu-protocol Unix socket pair. We'll fake the qemu protocol here + // to avoid changing the vnet package. + td, err := os.MkdirTemp("", "vnet") + if err != nil { + panic(fmt.Errorf("MkdirTemp: %v", err)) + } + defer os.RemoveAll(td) + + unixSrv := filepath.Join(td, "vnet.sock") + + srv, err := net.Listen("unix", unixSrv) + if err != nil { + panic(fmt.Errorf("Listen: %v", err)) + } + defer srv.Close() + + var c vnet.Config + c.SetBlendReality(true) + + var net1opt = []any{vnet.NAT("easy")} + net1opt = append(net1opt, "2.1.1.1", "192.168.1.1/24") + net1opt = append(net1opt, "2000:52::1/64") + + c.AddNode(c.AddNetwork(net1opt...), clientMAC) + + vs, err := vnet.New(&c) + if err != nil { + panic(fmt.Errorf("newServer: %v", err)) + } + if err := vs.PopulateDERPMapIPs(); err != nil { + log.Printf("warning: ignoring failure to populate DERP map: %v", err) + return + } + + errc := make(chan error, 1) + fail := func(err error) { + select { + case errc <- err: + log.Printf("failed: %v", err) + case <-ctx.Done(): + } + } + + go func() { + c, err := srv.Accept() + if err != nil { + fail(err) + return + } + vs.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + }() + + uc, err := net.Dial("unix", unixSrv) + if err != nil { + panic(fmt.Errorf("Dial: %v", err)) + } + defer uc.Close() + + var frameBuf []byte + writeDataToUnixConn := func(data []byte) error { + frameBuf = slices.Grow(frameBuf[:0], len(data)+4)[:len(data)+4] + binary.BigEndian.PutUint32(frameBuf[:4], uint32(len(data))) + copy(frameBuf[4:], data) + + _, err = uc.Write(frameBuf) + return err + } + if err := writeDataToUnixConn(firstData); err != nil { + fail(err) + return + } + + go func() { + for { + messageType, data, err := conn.Read(ctx) + if err != nil { + fail(fmt.Errorf("ReadMessage: %v", err)) + break + } + + if messageType != websocket.MessageBinary { + log.Printf("Ignoring non-binary message") + continue + } + + if err := writeDataToUnixConn(data); err != nil { + fail(err) + return + } + } + }() + + go func() { + const maxBuf = 4096 + frameBuf := make([]byte, maxBuf) + for { + _, err := io.ReadFull(uc, frameBuf[:4]) + if err != nil { + fail(err) + return + } + frameLen := binary.BigEndian.Uint32(frameBuf[:4]) + if frameLen > maxBuf { + fail(fmt.Errorf("frame too large: %d", frameLen)) + return + } + if _, err := io.ReadFull(uc, frameBuf[:frameLen]); err != nil { + fail(err) + return + } + + if err := conn.Write(ctx, websocket.MessageBinary, frameBuf[:frameLen]); err != nil { + fail(err) + return + } + } + }() + + <-ctx.Done() +} diff --git a/cmd/xdpderper/xdpderper.go b/cmd/xdpderper/xdpderper.go index 599034ae7..c127baf54 100644 --- a/cmd/xdpderper/xdpderper.go +++ b/cmd/xdpderper/xdpderper.go @@ -18,6 +18,9 @@ import ( "tailscale.com/derp/xdp" "tailscale.com/net/netutil" "tailscale.com/tsweb" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go index dc22212e8..78ef73f71 100644 --- a/control/controlbase/conn.go +++ b/control/controlbase/conn.go @@ -18,6 +18,7 @@ import ( "golang.org/x/crypto/blake2s" chp "golang.org/x/crypto/chacha20poly1305" + "tailscale.com/syncs" "tailscale.com/types/key" ) @@ -48,7 +49,7 @@ type Conn struct { // rxState is all the Conn state that Read uses. type rxState struct { - sync.Mutex + syncs.Mutex cipher cipher.AEAD nonce nonce buf *maxMsgBuffer // or nil when reads exhausted diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index 8a0f46967..ed4642d3b 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -280,7 +280,7 @@ func TestConnMemoryOverhead(t *testing.T) { growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc) growthEach := float64(growthTotal) / float64(num) t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach) - const max = 2000 + const max = 2048 if growthEach > max { t.Errorf("allocated more than expected; want max %v bytes/each", max) } diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index edd0ae29c..336a8d491 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -12,7 +12,6 @@ import ( "sync/atomic" "time" - "tailscale.com/logtail/backoff" "tailscale.com/net/sockstats" "tailscale.com/tailcfg" "tailscale.com/tstime" @@ -21,7 +20,10 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/structs" + "tailscale.com/util/backoff" + "tailscale.com/util/clientmetric" "tailscale.com/util/execqueue" + "tailscale.com/util/testenv" ) type LoginGoal struct { @@ -116,13 +118,13 @@ type Auto struct { logf logger.Logf closed bool updateCh chan struct{} // readable when we should inform the server of a change - observer Observer // called to update Client status; always non-nil + observer Observer // if non-nil, called to update Client status observerQueue execqueue.ExecQueue - - unregisterHealthWatch func() + shutdownFn func() // to be called prior to shutdown or nil mu sync.Mutex // mutex guards the following fields + started bool // whether [Auto.Start] has been called wantLoggedIn bool // whether the user wants to be logged in per last method call urlToVisit string // the last url we were told to visit expiry time.Time @@ -131,12 +133,13 @@ type Auto struct { // the server. lastUpdateGen updateGen + lastStatus atomic.Pointer[Status] + paused bool // whether we should stop making HTTP requests unpauseWaiters []chan bool // chans that gets sent true (once) on wake, or false on Shutdown loggedIn bool // true if currently logged in loginGoal *LoginGoal // non-nil if some login activity is desired inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends - state State // TODO(bradfitz): delete this, make it computed by method from other state authCtx context.Context // context used for auth requests mapCtx context.Context // context used for netmap and update requests @@ -149,15 +152,21 @@ type Auto struct { // New creates and starts a new Auto. func New(opts Options) (*Auto, error) { - c, err := NewNoStart(opts) - if c != nil { - c.Start() + c, err := newNoStart(opts) + if err != nil { + return nil, err + } + if opts.StartPaused { + c.SetPaused(true) + } + if !opts.SkipStartForTests { + c.start() } return c, err } -// NewNoStart creates a new Auto, but without calling Start on it. -func NewNoStart(opts Options) (_ *Auto, err error) { +// newNoStart creates a new Auto, but without calling Start on it. +func newNoStart(opts Options) (_ *Auto, err error) { direct, err := NewDirect(opts) if err != nil { return nil, err @@ -168,9 +177,6 @@ func NewNoStart(opts Options) (_ *Auto, err error) { } }() - if opts.Observer == nil { - return nil, errors.New("missing required Options.Observer") - } if opts.Logf == nil { opts.Logf = func(fmt string, args ...any) {} } @@ -186,16 +192,16 @@ func NewNoStart(opts Options) (_ *Auto, err error) { mapDone: make(chan struct{}), updateDone: make(chan struct{}), observer: opts.Observer, + shutdownFn: opts.Shutdown, } + c.authCtx, c.authCancel = context.WithCancel(context.Background()) c.authCtx = sockstats.WithSockStats(c.authCtx, sockstats.LabelControlClientAuto, opts.Logf) c.mapCtx, c.mapCancel = context.WithCancel(context.Background()) c.mapCtx = sockstats.WithSockStats(c.mapCtx, sockstats.LabelControlClientAuto, opts.Logf) - c.unregisterHealthWatch = opts.HealthTracker.RegisterWatcher(direct.ReportHealthChange) return c, nil - } // SetPaused controls whether HTTP activity should be paused. @@ -220,10 +226,21 @@ func (c *Auto) SetPaused(paused bool) { c.unpauseWaiters = nil } -// Start starts the client's goroutines. +// StartForTest starts the client's goroutines. // -// It should only be called for clients created by NewNoStart. -func (c *Auto) Start() { +// It should only be called for clients created with [Options.SkipStartForTests]. +func (c *Auto) StartForTest() { + testenv.AssertInTest() + c.start() +} + +func (c *Auto) start() { + c.mu.Lock() + defer c.mu.Unlock() + if c.started { + return + } + c.started = true go c.authRoutine() go c.mapRoutine() go c.updateRoutine() @@ -297,10 +314,11 @@ func (c *Auto) authRoutine() { c.mu.Lock() goal := c.loginGoal ctx := c.authCtx + loggedIn := c.loggedIn if goal != nil { - c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true) + c.logf("[v1] authRoutine: loggedIn=%v; wantLoggedIn=%v", loggedIn, true) } else { - c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused) + c.logf("[v1] authRoutine: loggedIn=%v; goal=nil paused=%v", loggedIn, c.paused) } c.mu.Unlock() @@ -323,11 +341,6 @@ func (c *Auto) authRoutine() { c.mu.Lock() c.urlToVisit = goal.url - if goal.url != "" { - c.state = StateURLVisitRequired - } else { - c.state = StateAuthenticating - } c.mu.Unlock() var url string @@ -361,7 +374,6 @@ func (c *Auto) authRoutine() { flags: LoginDefault, url: url, } - c.state = StateURLVisitRequired c.mu.Unlock() c.sendStatus("authRoutine-url", err, url, nil) @@ -381,7 +393,6 @@ func (c *Auto) authRoutine() { c.urlToVisit = "" c.loggedIn = true c.loginGoal = nil - c.state = StateAuthenticated c.mu.Unlock() c.sendStatus("authRoutine-success", nil, "", nil) @@ -414,6 +425,11 @@ func (c *Auto) unpausedChanLocked() <-chan bool { return unpaused } +// ClientID returns the ClientID of the direct controlClient +func (c *Auto) ClientID() int64 { + return c.direct.ClientID() +} + // mapRoutineState is the state of Auto.mapRoutine while it's running. type mapRoutineState struct { c *Auto @@ -426,21 +442,17 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) { c := mrs.c c.mu.Lock() - ctx := c.mapCtx c.inMapPoll = true - if c.loggedIn { - c.state = StateSynchronized - } - c.expiry = nm.Expiry + c.expiry = nm.SelfKeyExpiry() stillAuthed := c.loggedIn - c.logf("[v1] mapRoutine: netmap received: %s", c.state) + c.logf("[v1] mapRoutine: netmap received: loggedIn=%v inMapPoll=true", stillAuthed) c.mu.Unlock() if stillAuthed { c.sendStatus("mapRoutine-got-netmap", nil, "", nm) } // Reset the backoff timer if we got a netmap. - mrs.bo.BackOff(ctx, nil) + mrs.bo.Reset() } func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { @@ -481,8 +493,8 @@ func (c *Auto) mapRoutine() { } c.mu.Lock() - c.logf("[v1] mapRoutine: %s", c.state) loggedIn := c.loggedIn + c.logf("[v1] mapRoutine: loggedIn=%v", loggedIn) ctx := c.mapCtx c.mu.Unlock() @@ -513,9 +525,6 @@ func (c *Auto) mapRoutine() { c.direct.health.SetOutOfPollNetMap() c.mu.Lock() c.inMapPoll = false - if c.state == StateSynchronized { - c.state = StateAuthenticated - } paused := c.paused c.mu.Unlock() @@ -581,12 +590,12 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM c.mu.Unlock() return } - state := c.state loggedIn := c.loggedIn inMapPoll := c.inMapPoll + loginGoal := c.loginGoal c.mu.Unlock() - c.logf("[v1] sendStatus: %s: %v", who, state) + c.logf("[v1] sendStatus: %s: loggedIn=%v inMapPoll=%v", who, loggedIn, inMapPoll) var p persist.PersistView if nm != nil && loggedIn && inMapPoll { @@ -596,21 +605,104 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM // not logged in. nm = nil } - new := Status{ - URL: url, - Persist: p, - NetMap: nm, - Err: err, - state: state, + newSt := &Status{ + URL: url, + Persist: p, + NetMap: nm, + Err: err, + LoggedIn: loggedIn && loginGoal == nil, + InMapPoll: inMapPoll, } + if c.observer == nil { + return + } + + c.lastStatus.Store(newSt) + // Launch a new goroutine to avoid blocking the caller while the observer // does its thing, which may result in a call back into the client. + metricQueued.Add(1) c.observerQueue.Add(func() { - c.observer.SetControlClientStatus(c, new) + c.mu.Lock() + closed := c.closed + c.mu.Unlock() + if closed { + return + } + + if canSkipStatus(newSt, c.lastStatus.Load()) { + metricSkippable.Add(1) + if !c.direct.controlKnobs.DisableSkipStatusQueue.Load() { + metricSkipped.Add(1) + return + } + } + c.observer.SetControlClientStatus(c, *newSt) + + // Best effort stop retaining the memory now that we've sent it to the + // observer (LocalBackend). We CAS here because the caller goroutine is + // doing a Store which we want to win a race. This is only a memory + // optimization and is not for correctness. + // + // If the CAS fails, that means somebody else's Store replaced our + // pointer (so mission accomplished: our netmap is no longer retained in + // any case) and that Store caller will be responsible for removing + // their own netmap (or losing their race too, down the chain). + // Eventually the last caller will win this CAS and zero lastStatus. + c.lastStatus.CompareAndSwap(newSt, nil) }) } +var ( + metricQueued = clientmetric.NewCounter("controlclient_auto_status_queued") + metricSkippable = clientmetric.NewCounter("controlclient_auto_status_queue_skippable") + metricSkipped = clientmetric.NewCounter("controlclient_auto_status_queue_skipped") +) + +// canSkipStatus reports whether we can skip sending s1, knowing +// that s2 is enqueued sometime in the future after s1. +// +// s1 must be non-nil. s2 may be nil. +func canSkipStatus(s1, s2 *Status) bool { + if s2 == nil { + // Nothing in the future. + return false + } + if s1 == s2 { + // If the last item in the queue is the same as s1, + // we can't skip it. + return false + } + if s1.Err != nil || s1.URL != "" || s1.LoggedIn { + // If s1 has an error, a URL, or LoginFinished set, we shouldn't skip it, + // lest the error go away in s2 or in-between. We want to make sure all + // the subsystems see it. Plus there aren't many of these, so not worth + // skipping. + return false + } + if !s1.Persist.Equals(s2.Persist) || s1.LoggedIn != s2.LoggedIn || s1.InMapPoll != s2.InMapPoll || s1.URL != s2.URL { + // If s1 has a different Persist, LoginFinished, Synced, or URL than s2, + // don't skip it. We only care about skipping the typical + // entries where the only difference is the NetMap. + return false + } + // If nothing above precludes it, and both s1 and s2 have NetMaps, then + // we can skip it, because s2's NetMap is a newer version and we can + // jump straight from whatever state we had before to s2's state, + // without passing through s1's state first. A NetMap is regrettably a + // full snapshot of the state, not an incremental delta. We're slowly + // moving towards passing around only deltas around internally at all + // layers, but this is explicitly the case where we didn't have a delta + // path for the message we received over the wire and had to resort + // to the legacy full NetMap path. And then we can get behind processing + // these full NetMap snapshots in LocalBackend/wgengine/magicsock/netstack + // and this path (when it returns true) lets us skip over useless work + // and not get behind in the queue. This matters in particular for tailnets + // that are both very large + very churny. + return s1.NetMap != nil && s2.NetMap != nil +} + func (c *Auto) Login(flags LoginFlags) { c.logf("client.Login(%v)", flags) @@ -652,7 +744,6 @@ func (c *Auto) Logout(ctx context.Context) error { } c.mu.Lock() c.loggedIn = false - c.state = StateNotAuthenticated c.cancelAuthCtxLocked() c.cancelMapCtxLocked() c.mu.Unlock() @@ -676,6 +767,13 @@ func (c *Auto) UpdateEndpoints(endpoints []tailcfg.Endpoint) { } } +// SetDiscoPublicKey sets the client's Disco public to key and sends the change +// to the control server. +func (c *Auto) SetDiscoPublicKey(key key.DiscoPublic) { + c.direct.SetDiscoPublicKey(key) + c.updateControl() +} + func (c *Auto) Shutdown() { c.mu.Lock() if c.closed { @@ -683,6 +781,7 @@ func (c *Auto) Shutdown() { return } c.logf("client.Shutdown ...") + shutdownFn := c.shutdownFn direct := c.direct c.closed = true @@ -695,7 +794,10 @@ func (c *Auto) Shutdown() { c.unpauseWaiters = nil c.mu.Unlock() - c.unregisterHealthWatch() + if shutdownFn != nil { + shutdownFn() + } + <-c.authDone <-c.mapDone <-c.updateDone @@ -734,13 +836,3 @@ func (c *Auto) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) error { func (c *Auto) DoNoiseRequest(req *http.Request) (*http.Response, error) { return c.direct.DoNoiseRequest(req) } - -// GetSingleUseNoiseRoundTripper returns a RoundTripper that can be only be used -// once (and must be used once) to make a single HTTP request over the noise -// channel to the coordination server. -// -// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise -// payload, if any. -func (c *Auto) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) { - return c.direct.GetSingleUseNoiseRoundTripper(ctx) -} diff --git a/control/controlclient/client.go b/control/controlclient/client.go index 8df64f9e8..41b39622b 100644 --- a/control/controlclient/client.go +++ b/control/controlclient/client.go @@ -12,6 +12,7 @@ import ( "context" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) // LoginFlags is a bitmask of options to change the behavior of Client.Login @@ -80,7 +81,15 @@ type Client interface { // TODO: a server-side change would let us simply upload this // in a separate http request. It has nothing to do with the rest of // the state machine. + // Note: the auto client uploads the new endpoints to control immediately. UpdateEndpoints(endpoints []tailcfg.Endpoint) + // SetDiscoPublicKey updates the disco public key that will be sent in + // future map requests. This should be called after rotating the discovery key. + // Note: the auto client uploads the new key to control immediately. + SetDiscoPublicKey(key.DiscoPublic) + // ClientID returns the ClientID of a client. This ID is meant to + // distinguish one client from another. + ClientID() int64 } // UserVisibleError is an error that should be shown to users. diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index b37623451..bc3011226 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -4,8 +4,37 @@ package controlclient import ( + "context" + "crypto/tls" + "errors" + "flag" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "net/url" "reflect" + "sync/atomic" "testing" + "time" + + "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/net/bakedroots" + "tailscale.com/net/connectproxy" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/tlstest" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/persist" + "tailscale.com/util/eventbus/eventbustest" ) func fieldsOf(t reflect.Type) (fields []string) { @@ -19,7 +48,7 @@ func fieldsOf(t reflect.Type) (fields []string) { func TestStatusEqual(t *testing.T) { // Verify that the Equal method stays in sync with reality - equalHandles := []string{"Err", "URL", "NetMap", "Persist", "state"} + equalHandles := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"} if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) { t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, equalHandles) @@ -51,7 +80,7 @@ func TestStatusEqual(t *testing.T) { }, { &Status{}, - &Status{state: StateAuthenticated}, + &Status{LoggedIn: true, Persist: new(persist.Persist).View()}, false, }, } @@ -62,3 +91,346 @@ func TestStatusEqual(t *testing.T) { } } } + +// tests [canSkipStatus]. +func TestCanSkipStatus(t *testing.T) { + st := new(Status) + nm1 := &netmap.NetworkMap{} + nm2 := &netmap.NetworkMap{} + + tests := []struct { + name string + s1, s2 *Status + want bool + }{ + { + name: "nil-s2", + s1: st, + s2: nil, + want: false, + }, + { + name: "equal", + s1: st, + s2: st, + want: false, + }, + { + name: "s1-error", + s1: &Status{Err: io.EOF, NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-url", + s1: &Status{URL: "foo", NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-persist-diff", + s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-login-finished-diff", + s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-login-finished", + s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-synced-diff", + s1: &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-no-netmap1", + s1: &Status{NetMap: nil}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-no-netmap2", + s1: &Status{NetMap: nm1}, + s2: &Status{NetMap: nil}, + want: false, + }, + { + name: "skip", + s1: &Status{NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := canSkipStatus(tt.s1, tt.s2); got != tt.want { + t.Errorf("canSkipStatus = %v, want %v", got, tt.want) + } + }) + } + + coveredFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"} + if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, coveredFields) { + t.Errorf("Status fields = %q; this code was only written to handle fields %q", have, coveredFields) + } + +} + +func TestRetryableErrors(t *testing.T) { + errorTests := []struct { + err error + want bool + }{ + {errNoNoiseClient, true}, + {errNoNodeKey, true}, + {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true}, + {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true}, + {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true}, + {errBadHTTPResponse(429, "too may requests"), true}, + {errBadHTTPResponse(500, "internal server eror"), true}, + {errBadHTTPResponse(502, "bad gateway"), true}, + {errBadHTTPResponse(503, "service unavailable"), true}, + {errBadHTTPResponse(504, "gateway timeout"), true}, + {errBadHTTPResponse(1234, "random error"), false}, + } + + for _, tt := range errorTests { + t.Run(tt.err.Error(), func(t *testing.T) { + if isRetryableErrorForTest(tt.err) != tt.want { + t.Fatalf("retriable: got %v, want %v", tt.err, tt.want) + } + }) + } +} + +type retryableForTest interface { + Retryable() bool +} + +func isRetryableErrorForTest(err error) bool { + var ae retryableForTest + if errors.As(err, &ae) { + return ae.Retryable() + } + return false +} + +var liveNetworkTest = flag.Bool("live-network-test", false, "run live network tests") + +func TestDirectProxyManual(t *testing.T) { + if !*liveNetworkTest { + t.Skip("skipping without --live-network-test") + } + + bus := eventbustest.NewBus(t) + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + dialer.SetBus(bus) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + ServerURL: "https://controlplane.tailscale.com", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: health.NewTracker(bus), + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + ControlKnobs: &controlknobs.Knobs{}, + Bus: bus, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin: %v", err) + } + t.Logf("URL: %q", url) +} + +func TestHTTPSNoProxy(t *testing.T) { testHTTPS(t, false) } + +// TestTLSWithProxy verifies we can connect to the control plane via +// an HTTPS proxy. +func TestHTTPSWithProxy(t *testing.T) { testHTTPS(t, true) } + +func testHTTPS(t *testing.T, withProxy bool) { + bakedroots.ResetForTest(t, tlstest.TestRootCA()) + + bus := eventbustest.NewBus(t) + + controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer controlLn.Close() + + proxyLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ProxyServer.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer proxyLn.Close() + + const requiredAuthKey = "hunter2" + const someUsername = "testuser" + const somePassword = "testpass" + + testControl := &testcontrol.Server{ + Logf: tstest.WhileTestRunningLogger(t), + RequireAuthKey: requiredAuthKey, + } + controlSrv := &http.Server{ + Handler: testControl, + ErrorLog: logger.StdLogger(t.Logf), + } + go controlSrv.Serve(controlLn) + + const fakeControlIP = "1.2.3.4" + const fakeProxyIP = "5.6.7.8" + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + dialer.SetBus(bus) + dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err) + } + var d net.Dialer + if host == fakeControlIP { + return d.DialContext(ctx, network, controlLn.Addr().String()) + } + if host == fakeProxyIP { + return d.DialContext(ctx, network, proxyLn.Addr().String()) + } + return nil, fmt.Errorf("unexpected dial to %q", addr) + }) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + AuthKey: requiredAuthKey, + ServerURL: "https://controlplane.tstest", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: health.NewTracker(bus), + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + Bus: bus, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + + d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) { + switch host { + case "controlplane.tstest": + return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil + case "proxy.tstest": + if !withProxy { + t.Errorf("unexpected DNS lookup for %q with proxy disabled", host) + return nil, fmt.Errorf("unexpected DNS lookup for %q", host) + } + return []netip.Addr{netip.MustParseAddr(fakeProxyIP)}, nil + } + t.Errorf("unexpected DNS query for %q", host) + return []netip.Addr{}, nil + } + + var proxyReqs atomic.Int64 + if withProxy { + d.httpc.Transport.(*http.Transport).Proxy = func(req *http.Request) (*url.URL, error) { + t.Logf("using proxy for %q", req.URL) + u := &url.URL{ + Scheme: "https", + Host: "proxy.tstest:443", + User: url.UserPassword(someUsername, somePassword), + } + return u, nil + } + + connectProxy := &http.Server{ + Handler: connectProxyTo(t, "controlplane.tstest:443", controlLn.Addr().String(), &proxyReqs), + } + go connectProxy.Serve(proxyLn) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin: %v", err) + } + if url != "" { + t.Errorf("got URL %q, want empty", url) + } + + if withProxy { + if got, want := proxyReqs.Load(), int64(1); got != want { + t.Errorf("proxy CONNECT requests = %d; want %d", got, want) + } + } +} + +func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.RequestURI != target { + t.Errorf("invalid CONNECT request to %q; want %q", r.RequestURI, target) + http.Error(w, "bad target", http.StatusBadRequest) + return + } + + r.Header.Set("Authorization", r.Header.Get("Proxy-Authorization")) // for the BasicAuth method. kinda trashy. + user, pass, ok := r.BasicAuth() + if !ok || user != "testuser" || pass != "testpass" { + t.Errorf("invalid CONNECT auth %q:%q; want %q:%q", user, pass, "testuser", "testpass") + http.Error(w, "bad auth", http.StatusUnauthorized) + return + } + + (&connectproxy.Handler{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + var d net.Dialer + c, err := d.DialContext(ctx, network, backendAddrPort) + if err == nil { + reqs.Add(1) + } + return c, err + }, + Logf: t.Logf, + }).ServeHTTP(w, r) + }) +} diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 9cbd0e14e..006a801ef 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -4,9 +4,11 @@ package controlclient import ( - "bufio" "bytes" + "cmp" "context" + "crypto" + "crypto/sha256" "encoding/binary" "encoding/json" "errors" @@ -15,21 +17,21 @@ import ( "log" "net" "net/http" - "net/http/httptest" "net/netip" - "net/url" "os" "reflect" "runtime" "slices" "strings" - "sync" "sync/atomic" "time" "go4.org/mem" "tailscale.com/control/controlknobs" + "tailscale.com/control/ts2021" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnstate" @@ -38,9 +40,10 @@ import ( "tailscale.com/net/dnsfallback" "tailscale.com/net/netmon" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" - "tailscale.com/net/tshttpproxy" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tstime" @@ -51,62 +54,70 @@ import ( "tailscale.com/types/ptr" "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" - "tailscale.com/util/multierr" + "tailscale.com/util/eventbus" "tailscale.com/util/singleflight" - "tailscale.com/util/syspolicy" - "tailscale.com/util/systemd" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/util/testenv" "tailscale.com/util/zstdframe" ) // Direct is the client that connects to a tailcontrol server for a node. type Direct struct { - httpc *http.Client // HTTP client used to talk to tailcontrol - interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial - dialer *tsdial.Dialer - dnsCache *dnscache.Resolver - controlKnobs *controlknobs.Knobs // always non-nil - serverURL string // URL of the tailcontrol server - clock tstime.Clock - logf logger.Logf - netMon *netmon.Monitor // non-nil - health *health.Tracker - discoPubKey key.DiscoPublic - getMachinePrivKey func() (key.MachinePrivate, error) - debugFlags []string - skipIPForwardingCheck bool - pinger Pinger - popBrowser func(url string) // or nil - c2nHandler http.Handler // or nil - onClientVersion func(*tailcfg.ClientVersion) // or nil - onControlTime func(time.Time) // or nil - onTailnetDefaultAutoUpdate func(bool) // or nil - panicOnUse bool // if true, panic if client is used (for testing) - closedCtx context.Context // alive until Direct.Close is called - closeCtx context.CancelFunc // cancels closedCtx + httpc *http.Client // HTTP client used to do TLS requests to control (just https://controlplane.tailscale.com/key?v=123) + interceptedDial *atomic.Bool // if non-nil, pointer to bool whether ScreenTime intercepted our dial + dialer *tsdial.Dialer + dnsCache *dnscache.Resolver + controlKnobs *controlknobs.Knobs // always non-nil + serverURL string // URL of the tailcontrol server + clock tstime.Clock + logf logger.Logf + netMon *netmon.Monitor // non-nil + health *health.Tracker + busClient *eventbus.Client + clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion] + autoUpdatePub *eventbus.Publisher[AutoUpdate] + controlTimePub *eventbus.Publisher[ControlTime] + getMachinePrivKey func() (key.MachinePrivate, error) + debugFlags []string + skipIPForwardingCheck bool + pinger Pinger + popBrowser func(url string) // or nil + polc policyclient.Client // always non-nil + c2nHandler http.Handler // or nil + panicOnUse bool // if true, panic if client is used (for testing) + closedCtx context.Context // alive until Direct.Close is called + closeCtx context.CancelFunc // cancels closedCtx dialPlan ControlDialPlanner // can be nil - mu sync.Mutex // mutex guards the following fields + mu syncs.Mutex // mutex guards the following fields serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now serverNoiseKey key.MachinePublic - - sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation. - noiseClient *NoiseClient - - persist persist.PersistView - authKey string - tryingNewKey key.NodePrivate - expiry time.Time // or zero value if none/unknown - hostinfo *tailcfg.Hostinfo // always non-nil - netinfo *tailcfg.NetInfo - endpoints []tailcfg.Endpoint - tkaHead string - lastPingURL string // last PingRequest.URL received, for dup suppression + discoPubKey key.DiscoPublic // protected by mu; can be updated via [SetDiscoPublicKey] + + sfGroup singleflight.Group[struct{}, *ts2021.Client] // protects noiseClient creation. + noiseClient *ts2021.Client // also protected by mu + + persist persist.PersistView + authKey string + tryingNewKey key.NodePrivate + expiry time.Time // or zero value if none/unknown + hostinfo *tailcfg.Hostinfo // always non-nil + netinfo *tailcfg.NetInfo + endpoints []tailcfg.Endpoint + tkaHead string + lastPingURL string // last PingRequest.URL received, for dup suppression + connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest + + controlClientID int64 // Random ID used to differentiate clients for consumers of messages. } // Observer is implemented by users of the control client (such as LocalBackend) // to get notified of changes in the control client's status. +// +// If an implementation of Observer also implements [NetmapDeltaUpdater], they get +// delta updates as well as full netmap updates. type Observer interface { // SetControlClientStatus is called when the client has a new status to // report. The Client is provided to allow the Observer to track which @@ -116,28 +127,36 @@ type Observer interface { } type Options struct { - Persist persist.Persist // initial persistent data - GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use - ServerURL string // URL of the tailcontrol server - AuthKey string // optional node auth key for auto registration - Clock tstime.Clock - Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc - DiscoPublicKey key.DiscoPublic - Logf logger.Logf - HTTPTestClient *http.Client // optional HTTP client to use (for tests only) - NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) - DebugFlags []string // debug settings to send to control - HealthTracker *health.Tracker - PopBrowserURL func(url string) // optional func to open browser - OnClientVersion func(*tailcfg.ClientVersion) // optional func to inform GUI of client version status - OnControlTime func(time.Time) // optional func to notify callers of new time from control - OnTailnetDefaultAutoUpdate func(bool) // optional func to inform GUI of default auto-update setting for the tailnet - Dialer *tsdial.Dialer // non-nil - C2NHandler http.Handler // or nil - ControlKnobs *controlknobs.Knobs // or nil to ignore + Persist persist.Persist // initial persistent data + GetMachinePrivateKey func() (key.MachinePrivate, error) // returns the machine key to use + ServerURL string // URL of the tailcontrol server + AuthKey string // optional node auth key for auto registration + Clock tstime.Clock + Hostinfo *tailcfg.Hostinfo // non-nil passes ownership, nil means to use default using os.Hostname, etc + DiscoPublicKey key.DiscoPublic + PolicyClient policyclient.Client // or nil for none + Logf logger.Logf + HTTPTestClient *http.Client // optional HTTP client to use (for tests only) + NoiseTestClient *http.Client // optional HTTP client to use for noise RPCs (tests only) + DebugFlags []string // debug settings to send to control + HealthTracker *health.Tracker + PopBrowserURL func(url string) // optional func to open browser + Dialer *tsdial.Dialer // non-nil + C2NHandler http.Handler // or nil + ControlKnobs *controlknobs.Knobs // or nil to ignore + Bus *eventbus.Bus // non-nil, for setting up publishers + + SkipStartForTests bool // if true, don't call [Auto.Start] to avoid any background goroutines (for tests only) + + // StartPaused indicates whether the client should start in a paused state + // where it doesn't do network requests. This primarily exists for testing + // but not necessarily "go test" tests, so it isn't restricted to only + // being used in tests. + StartPaused bool // Observer is called when there's a change in status to report // from the control client. + // If nil, no status updates are reported. Observer Observer // SkipIPForwardingCheck declares that the host's IP @@ -156,6 +175,11 @@ type Options struct { // If we receive a new DialPlan from the server, this value will be // updated. DialPlan ControlDialPlanner + + // Shutdown is an optional function that will be called before client shutdown is + // attempted. It is used to allow the client to clean up any resources or complete any + // tasks that are dependent on a live client. + Shutdown func() } // ControlDialPlanner is the interface optionally supplied when creating a @@ -208,6 +232,8 @@ type NetmapDeltaUpdater interface { UpdateNetmapDelta([]netmap.NodeMutation) (ok bool) } +var nextControlClientID atomic.Int64 + // NewDirect returns a new Direct client. func NewDirect(opts Options) (*Direct, error) { if opts.ServerURL == "" { @@ -233,10 +259,6 @@ func NewDirect(opts Options) (*Direct, error) { opts.ControlKnobs = &controlknobs.Knobs{} } opts.ServerURL = strings.TrimRight(opts.ServerURL, "/") - serverURL, err := url.Parse(opts.ServerURL) - if err != nil { - return nil, err - } if opts.Clock == nil { opts.Clock = tstime.StdClock{} } @@ -264,10 +286,14 @@ func NewDirect(opts Options) (*Direct, error) { var interceptedDial *atomic.Bool if httpc == nil { tr := http.DefaultTransport.(*http.Transport).Clone() - tr.Proxy = tshttpproxy.ProxyFromEnvironment - tshttpproxy.SetTransportGetProxyConnectHeader(tr) - tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig) - var dialFunc dialFunc + if buildfeatures.HasUseProxy { + tr.Proxy = feature.HookProxyFromEnvironment.GetOrNil() + if f, ok := feature.HookProxySetTransportGetProxyConnectHeader.GetOk(); ok { + f(tr) + } + } + tr.TLSClientConfig = tlsdial.Config(opts.HealthTracker, tr.TLSClientConfig) + var dialFunc netx.DialFunc dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial) tr.DialContext = dnscache.Dialer(dialFunc, dnsCache) tr.DialTLSContext = dnscache.TLSDialer(dialFunc, dnsCache, tr.TLSClientConfig) @@ -281,32 +307,32 @@ func NewDirect(opts Options) (*Direct, error) { } c := &Direct{ - httpc: httpc, - interceptedDial: interceptedDial, - controlKnobs: opts.ControlKnobs, - getMachinePrivKey: opts.GetMachinePrivateKey, - serverURL: opts.ServerURL, - clock: opts.Clock, - logf: opts.Logf, - persist: opts.Persist.View(), - authKey: opts.AuthKey, - discoPubKey: opts.DiscoPublicKey, - debugFlags: opts.DebugFlags, - netMon: netMon, - health: opts.HealthTracker, - skipIPForwardingCheck: opts.SkipIPForwardingCheck, - pinger: opts.Pinger, - popBrowser: opts.PopBrowserURL, - onClientVersion: opts.OnClientVersion, - onTailnetDefaultAutoUpdate: opts.OnTailnetDefaultAutoUpdate, - onControlTime: opts.OnControlTime, - c2nHandler: opts.C2NHandler, - dialer: opts.Dialer, - dnsCache: dnsCache, - dialPlan: opts.DialPlan, - } + httpc: httpc, + interceptedDial: interceptedDial, + controlKnobs: opts.ControlKnobs, + getMachinePrivKey: opts.GetMachinePrivateKey, + serverURL: opts.ServerURL, + clock: opts.Clock, + logf: opts.Logf, + persist: opts.Persist.View(), + authKey: opts.AuthKey, + debugFlags: opts.DebugFlags, + netMon: netMon, + health: opts.HealthTracker, + skipIPForwardingCheck: opts.SkipIPForwardingCheck, + pinger: opts.Pinger, + polc: cmp.Or(opts.PolicyClient, policyclient.Client(policyclient.NoPolicyClient{})), + popBrowser: opts.PopBrowserURL, + c2nHandler: opts.C2NHandler, + dialer: opts.Dialer, + dnsCache: dnsCache, + dialPlan: opts.DialPlan, + } + c.discoPubKey = opts.DiscoPublicKey c.closedCtx, c.closeCtx = context.WithCancel(context.Background()) + c.controlClientID = nextControlClientID.Add(1) + if opts.Hostinfo == nil { c.SetHostinfo(hostinfo.New()) } else { @@ -316,7 +342,7 @@ func NewDirect(opts Options) (*Direct, error) { } } if opts.NoiseTestClient != nil { - c.noiseClient = &NoiseClient{ + c.noiseClient = &ts2021.Client{ Client: opts.NoiseTestClient, } c.serverNoiseKey = key.NewMachine().Public() // prevent early error before hitting test client @@ -324,6 +350,12 @@ func NewDirect(opts Options) (*Direct, error) { if strings.Contains(opts.ServerURL, "controlplane.tailscale.com") && envknob.Bool("TS_PANIC_IF_HIT_MAIN_CONTROL") { c.panicOnUse = true } + + c.busClient = opts.Bus.Client("controlClient.direct") + c.clientVersionPub = eventbus.Publish[tailcfg.ClientVersion](c.busClient) + c.autoUpdatePub = eventbus.Publish[AutoUpdate](c.busClient) + c.controlTimePub = eventbus.Publish[ControlTime](c.busClient) + return c, nil } @@ -333,15 +365,14 @@ func (c *Direct) Close() error { c.mu.Lock() defer c.mu.Unlock() + c.busClient.Close() if c.noiseClient != nil { if err := c.noiseClient.Close(); err != nil { return err } } c.noiseClient = nil - if tr, ok := c.httpc.Transport.(*http.Transport); ok { - tr.CloseIdleConnections() - } + c.httpc.CloseIdleConnections() return nil } @@ -382,7 +413,7 @@ func (c *Direct) SetNetInfo(ni *tailcfg.NetInfo) bool { return true } -// SetNetInfo stores a new TKA head value for next update. +// SetTKAHead stores a new TKA head value for next update. // It reports whether the TKA head changed. func (c *Direct) SetTKAHead(tkaHead string) bool { c.mu.Lock() @@ -397,6 +428,14 @@ func (c *Direct) SetTKAHead(tkaHead string) bool { return true } +// SetConnectionHandleForTest stores a new MapRequest.ConnectionHandleForTest +// value for the next update. +func (c *Direct) SetConnectionHandleForTest(handle string) { + c.mu.Lock() + defer c.mu.Unlock() + c.connectionHandleForTest = handle +} + func (c *Direct) GetPersist() persist.PersistView { c.mu.Lock() defer c.mu.Unlock() @@ -518,7 +557,9 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new } else { if expired { c.logf("Old key expired -> regen=true") - systemd.Status("key expired; run 'tailscale up' to authenticate") + if f, ok := feature.HookSystemdStatus.GetOk(); ok { + f("key expired; run 'tailscale up' to authenticate") + } regen = true } if (opt.Flags & LoginInteractive) != 0 { @@ -577,6 +618,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new if persist.NetworkLockKey.IsZero() { persist.NetworkLockKey = key.NewNLPrivate() } + nlPub := persist.NetworkLockKey.Public() if tryingNewKey.IsZero() { @@ -606,7 +648,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new return regen, opt.URL, nil, err } - tailnet, err := syspolicy.GetString(syspolicy.Tailnet, "") + tailnet, err := c.polc.GetString(pkey.Tailnet, "") if err != nil { c.logf("unable to provide Tailnet field in register request. err: %v", err) } @@ -636,7 +678,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new AuthKey: authKey, } } - err = signRegisterRequest(&request, c.serverURL, c.serverLegacyKey, machinePrivKey.Public()) + err = signRegisterRequest(c.polc, &request, c.serverURL, c.serverLegacyKey, machinePrivKey.Public()) if err != nil { // If signing failed, clear all related fields request.SignatureType = tailcfg.SignatureNone @@ -650,7 +692,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new c.logf("RegisterReq sign error: %v", err) } } - if debugRegister() { + if DevKnob.DumpRegister() { j, _ := json.MarshalIndent(request, "", "\t") c.logf("RegisterRequest: %s", j) } @@ -673,8 +715,8 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new if err != nil { return regen, opt.URL, nil, err } - addLBHeader(req, request.OldNodeKey) - addLBHeader(req, request.NodeKey) + ts2021.AddLBHeader(req, request.OldNodeKey) + ts2021.AddLBHeader(req, request.NodeKey) res, err := httpc.Do(req) if err != nil { @@ -691,7 +733,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err) return regen, opt.URL, nil, fmt.Errorf("register request: %v", err) } - if debugRegister() { + if DevKnob.DumpRegister() { j, _ := json.MarshalIndent(resp, "", "\t") c.logf("RegisterResponse: %s", j) } @@ -811,6 +853,31 @@ func (c *Direct) SendUpdate(ctx context.Context) error { return c.sendMapRequest(ctx, false, nil) } +// SetDiscoPublicKey updates the disco public key in local state. +// It does not implicitly trigger [SendUpdate]; callers should arrange for that. +func (c *Direct) SetDiscoPublicKey(key key.DiscoPublic) { + c.mu.Lock() + defer c.mu.Unlock() + c.discoPubKey = key +} + +// ClientID returns the controlClientID of the controlClient. +func (c *Direct) ClientID() int64 { + return c.controlClientID +} + +// AutoUpdate is an eventbus value, reporting the value of tailcfg.MapResponse.DefaultAutoUpdate. +type AutoUpdate struct { + ClientID int64 // The ID field is used for consumers to differentiate instances of Direct. + Value bool // The Value represents DefaultAutoUpdate from [tailcfg.MapResponse]. +} + +// ControlTime is an eventbus value, reporting the value of tailcfg.MapResponse.ControlTime. +type ControlTime struct { + ClientID int64 // The ID field is used for consumers to differentiate instances of Direct. + Value time.Time // The Value represents ControlTime from [tailcfg.MapResponse]. +} + // If we go more than watchdogTimeout without hearing from the server, // end the long poll. We should be receiving a keep alive ping // every minute. @@ -843,8 +910,11 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap persist := c.persist serverURL := c.serverURL serverNoiseKey := c.serverNoiseKey + discoKey := c.discoPubKey hi := c.hostInfoLocked() backendLogID := hi.BackendLogID + connectionHandleForTest := c.connectionHandleForTest + tkaHead := c.tkaHead var epStrs []string var eps []netip.AddrPort var epTypes []tailcfg.EndpointType @@ -877,28 +947,51 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap c.logf("[v1] PollNetMap: stream=%v ep=%v", isStreaming, epStrs) vlogf := logger.Discard - if DevKnob.DumpNetMaps() { + if DevKnob.DumpNetMapsVerbose() { // TODO(bradfitz): update this to use "[v2]" prefix perhaps? but we don't // want to upload it always. vlogf = c.logf } nodeKey := persist.PublicNodeKey() + request := &tailcfg.MapRequest{ - Version: tailcfg.CurrentCapabilityVersion, - KeepAlive: true, - NodeKey: nodeKey, - DiscoKey: c.discoPubKey, - Endpoints: eps, - EndpointTypes: epTypes, - Stream: isStreaming, - Hostinfo: hi, - DebugFlags: c.debugFlags, - OmitPeers: nu == nil, - TKAHead: c.tkaHead, + Version: tailcfg.CurrentCapabilityVersion, + KeepAlive: true, + NodeKey: nodeKey, + DiscoKey: discoKey, + Endpoints: eps, + EndpointTypes: epTypes, + Stream: isStreaming, + Hostinfo: hi, + DebugFlags: c.debugFlags, + OmitPeers: nu == nil, + TKAHead: tkaHead, + ConnectionHandleForTest: connectionHandleForTest, + } + + // If we have a hardware attestation key, sign the node key with it and send + // the key & signature in the map request. + if buildfeatures.HasTPM { + if k := persist.AsStruct().AttestationKey; k != nil && !k.IsZero() { + hwPub := key.HardwareAttestationPublicFromPlatformKey(k) + request.HardwareAttestationKey = hwPub + + t := c.clock.Now() + msg := fmt.Sprintf("%d|%s", t.Unix(), nodeKey.String()) + digest := sha256.Sum256([]byte(msg)) + sig, err := k.Sign(nil, digest[:], crypto.SHA256) + if err != nil { + c.logf("failed to sign node key with hardware attestation key: %v", err) + } else { + request.HardwareAttestationKeySignature = sig + request.HardwareAttestationKeySignatureTimestamp = t + } + } } + var extraDebugFlags []string - if hi != nil && c.netMon != nil && !c.skipIPForwardingCheck && + if buildfeatures.HasAdvertiseRoutes && hi != nil && c.netMon != nil && !c.skipIPForwardingCheck && ipForwardingBroken(hi.RoutableIPs, c.netMon.InterfaceState()) { extraDebugFlags = append(extraDebugFlags, "warn-ip-forwarding-off") } @@ -962,7 +1055,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap if err != nil { return err } - addLBHeader(req, nodeKey) + ts2021.AddLBHeader(req, nodeKey) res, err := httpc.Do(req) if err != nil { @@ -1003,12 +1096,14 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap if persist == c.persist { newPersist := persist.AsStruct() newPersist.NodeID = nm.SelfNode.StableID() - newPersist.UserProfile = nm.UserProfiles[nm.User()] + if up, ok := nm.UserProfiles[nm.User()]; ok { + newPersist.UserProfile = *up.AsStruct() + } c.persist = newPersist.View() persist = c.persist } - c.expiry = nm.Expiry + c.expiry = nm.SelfKeyExpiry() } // gotNonKeepAliveMessage is whether we've yet received a MapResponse message without @@ -1040,7 +1135,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap vlogf("netmap: read body after %v", time.Since(t0).Round(time.Millisecond)) var resp tailcfg.MapResponse - if err := c.decodeMsg(msg, &resp); err != nil { + if err := sess.decodeMsg(msg, &resp); err != nil { vlogf("netmap: decode error: %v", err) return err } @@ -1065,21 +1160,19 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap c.logf("netmap: control says to open URL %v; no popBrowser func", u) } } - if resp.ClientVersion != nil && c.onClientVersion != nil { - c.onClientVersion(resp.ClientVersion) + if resp.ClientVersion != nil { + c.clientVersionPub.Publish(*resp.ClientVersion) } if resp.ControlTime != nil && !resp.ControlTime.IsZero() { c.logf.JSON(1, "controltime", resp.ControlTime.UTC()) - if c.onControlTime != nil { - c.onControlTime(*resp.ControlTime) - } + c.controlTimePub.Publish(ControlTime{c.controlClientID, *resp.ControlTime}) } if resp.KeepAlive { vlogf("netmap: got keep-alive") } else { vlogf("netmap: got new map") } - if resp.ControlDialPlan != nil { + if resp.ControlDialPlan != nil && !ignoreDialPlan() { if c.dialPlan != nil { c.logf("netmap: got new dial plan from control") c.dialPlan.Store(resp.ControlDialPlan) @@ -1092,9 +1185,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap continue } if au, ok := resp.DefaultAutoUpdate.Get(); ok { - if c.onTailnetDefaultAutoUpdate != nil { - c.onTailnetDefaultAutoUpdate(au) - } + c.autoUpdatePub.Publish(AutoUpdate{c.controlClientID, au}) } metricMapResponseMap.Add(1) @@ -1118,12 +1209,33 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap return nil } +// NetmapFromMapResponseForDebug returns a NetworkMap from the given MapResponse. +// It is intended for debugging only. +func NetmapFromMapResponseForDebug(ctx context.Context, pr persist.PersistView, resp *tailcfg.MapResponse) (*netmap.NetworkMap, error) { + if resp == nil { + return nil, errors.New("nil MapResponse") + } + if resp.Node == nil { + return nil, errors.New("MapResponse lacks Node") + } + + nu := &rememberLastNetmapUpdater{} + sess := newMapSession(pr.PrivateNodeKey(), nu, nil) + defer sess.Close() + + if err := sess.HandleNonKeepAliveMapResponse(ctx, resp); err != nil { + return nil, fmt.Errorf("HandleNonKeepAliveMapResponse: %w", err) + } + + return sess.netmap(), nil +} + func (c *Direct) handleDebugMessage(ctx context.Context, debug *tailcfg.Debug) error { if code := debug.Exit; code != nil { c.logf("exiting process with status %v per controlplane", *code) os.Exit(*code) } - if debug.DisableLogTail { + if buildfeatures.HasLogTail && debug.DisableLogTail { logtail.Disable() envknob.SetNoLogsNoSupport() } @@ -1170,20 +1282,26 @@ func decode(res *http.Response, v any) error { return json.Unmarshal(msg, v) } -var ( - debugMap = envknob.RegisterBool("TS_DEBUG_MAP") - debugRegister = envknob.RegisterBool("TS_DEBUG_REGISTER") -) - var jsonEscapedZero = []byte(`\u0000`) +const justKeepAliveStr = `{"KeepAlive":true}` + // decodeMsg is responsible for uncompressing msg and unmarshaling into v. -func (c *Direct) decodeMsg(compressedMsg []byte, v any) error { +func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error { + // Fast path for common case of keep-alive message. + // See tailscale/tailscale#17343. + if sess.keepAliveZ != nil && bytes.Equal(compressedMsg, sess.keepAliveZ) { + v.KeepAlive = true + return nil + } + b, err := zstdframe.AppendDecode(nil, compressedMsg) if err != nil { return err } - if debugMap() { + sess.ztdDecodesForTest++ + + if DevKnob.DumpNetMaps() { var buf bytes.Buffer json.Indent(&buf, b, "", " ") log.Printf("MapResponse: %s", buf.Bytes()) @@ -1195,6 +1313,9 @@ func (c *Direct) decodeMsg(compressedMsg []byte, v any) error { if err := json.Unmarshal(b, v); err != nil { return fmt.Errorf("response: %v", err) } + if v.KeepAlive && string(b) == justKeepAliveStr { + sess.keepAliveZ = compressedMsg + } return nil } @@ -1205,7 +1326,7 @@ func encode(v any) ([]byte, error) { if err != nil { return nil, err } - if debugMap() { + if DevKnob.DumpNetMaps() { if _, ok := v.(*tailcfg.MapRequest); ok { log.Printf("MapRequest: %s", b) } @@ -1242,7 +1363,7 @@ func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string out = tailcfg.OverTLSPublicKeyResponse{} k, err := key.ParseMachinePublicUntyped(mem.B(b)) if err != nil { - return nil, multierr.New(jsonErr, err) + return nil, errors.Join(jsonErr, err) } out.LegacyPublicKey = k return &out, nil @@ -1253,18 +1374,25 @@ func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string var DevKnob = initDevKnob() type devKnobs struct { - DumpNetMaps func() bool - ForceProxyDNS func() bool - StripEndpoints func() bool // strip endpoints from control (only use disco messages) - StripCaps func() bool // strip all local node's control-provided capabilities + DumpRegister func() bool + DumpNetMaps func() bool + DumpNetMapsVerbose func() bool + ForceProxyDNS func() bool + StripEndpoints func() bool // strip endpoints from control (only use disco messages) + StripHomeDERP func() bool // strip Home DERP from control + StripCaps func() bool // strip all local node's control-provided capabilities } func initDevKnob() devKnobs { + nm := envknob.RegisterInt("TS_DEBUG_MAP") return devKnobs{ - DumpNetMaps: envknob.RegisterBool("TS_DEBUG_NETMAP"), - ForceProxyDNS: envknob.RegisterBool("TS_DEBUG_PROXY_DNS"), - StripEndpoints: envknob.RegisterBool("TS_DEBUG_STRIP_ENDPOINTS"), - StripCaps: envknob.RegisterBool("TS_DEBUG_STRIP_CAPS"), + DumpNetMaps: func() bool { return nm() > 0 }, + DumpNetMapsVerbose: func() bool { return nm() > 1 }, + DumpRegister: envknob.RegisterBool("TS_DEBUG_REGISTER"), + ForceProxyDNS: envknob.RegisterBool("TS_DEBUG_PROXY_DNS"), + StripEndpoints: envknob.RegisterBool("TS_DEBUG_STRIP_ENDPOINTS"), + StripHomeDERP: envknob.RegisterBool("TS_DEBUG_STRIP_HOME_DERP"), + StripCaps: envknob.RegisterBool("TS_DEBUG_STRIP_CAPS"), } } @@ -1305,6 +1433,10 @@ func (c *Direct) isUniquePingRequest(pr *tailcfg.PingRequest) bool { return true } +// HookAnswerC2NPing is where feature/c2n conditionally registers support +// for handling C2N (control-to-node) HTTP requests. +var HookAnswerC2NPing feature.Hook[func(logger.Logf, http.Handler, *http.Client, *tailcfg.PingRequest)] + func (c *Direct) answerPing(pr *tailcfg.PingRequest) { httpc := c.httpc useNoise := pr.URLIsNoise || pr.Types == "c2n" @@ -1325,11 +1457,16 @@ func (c *Direct) answerPing(pr *tailcfg.PingRequest) { answerHeadPing(c.logf, httpc, pr) return case "c2n": + if !buildfeatures.HasC2N { + return + } if !useNoise && !envknob.Bool("TS_DEBUG_PERMIT_HTTP_C2N") { c.logf("refusing to answer c2n ping without noise") return } - answerC2NPing(c.logf, c.c2nHandler, httpc, pr) + if f, ok := HookAnswerC2NPing.GetOk(); ok { + f(c.logf, c.c2nHandler, httpc, pr) + } return } for _, t := range strings.Split(pr.Types, ",") { @@ -1364,54 +1501,6 @@ func answerHeadPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest) { } } -func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr *tailcfg.PingRequest) { - if c2nHandler == nil { - logf("answerC2NPing: c2nHandler not defined") - return - } - hreq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload))) - if err != nil { - logf("answerC2NPing: ReadRequest: %v", err) - return - } - if pr.Log { - logf("answerC2NPing: got c2n request for %v ...", hreq.RequestURI) - } - handlerTimeout := time.Minute - if v := hreq.Header.Get("C2n-Handler-Timeout"); v != "" { - handlerTimeout, _ = time.ParseDuration(v) - } - handlerCtx, cancel := context.WithTimeout(context.Background(), handlerTimeout) - defer cancel() - hreq = hreq.WithContext(handlerCtx) - rec := httptest.NewRecorder() - c2nHandler.ServeHTTP(rec, hreq) - cancel() - - c2nResBuf := new(bytes.Buffer) - rec.Result().Write(c2nResBuf) - - replyCtx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - req, err := http.NewRequestWithContext(replyCtx, "POST", pr.URL, c2nResBuf) - if err != nil { - logf("answerC2NPing: NewRequestWithContext: %v", err) - return - } - if pr.Log { - logf("answerC2NPing: sending POST ping to %v ...", pr.URL) - } - t0 := clock.Now() - _, err = c.Do(req) - d := time.Since(t0).Round(time.Millisecond) - if err != nil { - logf("answerC2NPing error: %v to %v (after %v)", err, pr.URL, d) - } else if pr.Log { - logf("answerC2NPing complete to %v (after %v)", pr.URL, d) - } -} - // sleepAsRequest implements the sleep for a tailcfg.Debug message requesting // that the client sleep. The complication is that while we're sleeping (if for // a long time), we need to periodically reset the watchdog timer before it @@ -1436,7 +1525,7 @@ func sleepAsRequested(ctx context.Context, logf logger.Logf, d time.Duration, cl } // getNoiseClient returns the noise client, creating one if one doesn't exist. -func (c *Direct) getNoiseClient() (*NoiseClient, error) { +func (c *Direct) getNoiseClient() (*ts2021.Client, error) { c.mu.Lock() serverNoiseKey := c.serverNoiseKey nc := c.noiseClient @@ -1451,13 +1540,13 @@ func (c *Direct) getNoiseClient() (*NoiseClient, error) { if c.dialPlan != nil { dp = c.dialPlan.Load } - nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*NoiseClient, error) { + nc, err, _ := c.sfGroup.Do(struct{}{}, func() (*ts2021.Client, error) { k, err := c.getMachinePrivKey() if err != nil { return nil, err } c.logf("[v1] creating new noise client") - nc, err := NewNoiseClient(NoiseOpts{ + nc, err := ts2021.NewClient(ts2021.ClientOpts{ PrivKey: k, ServerPubKey: serverNoiseKey, ServerURL: c.serverURL, @@ -1491,7 +1580,7 @@ func (c *Direct) setDNSNoise(ctx context.Context, req *tailcfg.SetDNSRequest) er if err != nil { return err } - res, err := nc.post(ctx, "/machine/set-dns", newReq.NodeKey, &newReq) + res, err := nc.Post(ctx, "/machine/set-dns", newReq.NodeKey, &newReq) if err != nil { return err } @@ -1512,6 +1601,9 @@ func (c *Direct) setDNSNoise(ctx context.Context, req *tailcfg.SetDNSRequest) er // SetDNS sends the SetDNSRequest request to the control plane server, // requesting a DNS record be created or updated. func (c *Direct) SetDNS(ctx context.Context, req *tailcfg.SetDNSRequest) (err error) { + if !buildfeatures.HasACME { + return feature.ErrUnavailable + } metricSetDNS.Add(1) defer func() { if err != nil { @@ -1532,20 +1624,6 @@ func (c *Direct) DoNoiseRequest(req *http.Request) (*http.Response, error) { return nc.Do(req) } -// GetSingleUseNoiseRoundTripper returns a RoundTripper that can be only be used -// once (and must be used once) to make a single HTTP request over the noise -// channel to the coordination server. -// -// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise -// payload, if any. -func (c *Direct) GetSingleUseNoiseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) { - nc, err := c.getNoiseClient() - if err != nil { - return nil, nil, err - } - return nc.GetSingleUseRoundTripper(ctx) -} - // doPingerPing sends a Ping to pr.IP using pinger, and sends an http request back to // pr.URL with ping response data. func doPingerPing(logf logger.Logf, c *http.Client, pr *tailcfg.PingRequest, pinger Pinger, pingType tailcfg.PingType) { @@ -1602,61 +1680,103 @@ func postPingResult(start time.Time, logf logger.Logf, c *http.Client, pr *tailc return nil } -// ReportHealthChange reports to the control plane a change to this node's -// health. w must be non-nil. us can be nil to indicate a healthy state for w. -func (c *Direct) ReportHealthChange(w *health.Warnable, us *health.UnhealthyState) { - if w == health.NetworkStatusWarnable || w == health.IPNStateWarnable || w == health.LoginStateWarnable { - // We don't report these. These include things like the network is down - // (in which case we can't report anyway) or the user wanted things - // stopped, as opposed to the more unexpected failure types in the other - // subsystems. - return - } - np, err := c.getNoiseClient() +// SetDeviceAttrs does a synchronous call to the control plane to update +// the node's attributes. +// +// See docs on [tailcfg.SetDeviceAttributesRequest] for background. +func (c *Auto) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) error { + return c.direct.SetDeviceAttrs(ctx, attrs) +} + +// SetDeviceAttrs does a synchronous call to the control plane to update +// the node's attributes. +// +// See docs on [tailcfg.SetDeviceAttributesRequest] for background. +func (c *Direct) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) error { + nc, err := c.getNoiseClient() if err != nil { - // Don't report errors to control if the server doesn't support noise. - return + return fmt.Errorf("%w: %w", errNoNoiseClient, err) } nodeKey, ok := c.GetPersist().PublicNodeKeyOK() if !ok { - return + return errNoNodeKey } if c.panicOnUse { panic("tainted client") } - // TODO(angott): at some point, update `Subsys` in the request to be `Warnable` - req := &tailcfg.HealthChangeRequest{ - Subsys: string(w.Code), + req := &tailcfg.SetDeviceAttributesRequest{ NodeKey: nodeKey, + Version: tailcfg.CurrentCapabilityVersion, + Update: attrs, } - if us != nil { - req.Error = us.Text - } - // Best effort, no logging: - ctx, cancel := context.WithTimeout(c.closedCtx, 5*time.Second) + // TODO(bradfitz): unify the callers using doWithBody vs those using + // DoNoiseRequest. There seems to be a ~50/50 split and they're very close, + // but doWithBody sets the load balancing header and auto-JSON-encodes the + // body, but DoNoiseRequest is exported. Clean it up so they're consistent + // one way or another. + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - res, err := np.post(ctx, "/machine/update-health", nodeKey, req) + res, err := nc.DoWithBody(ctx, "PATCH", "/machine/set-device-attr", nodeKey, req) if err != nil { - return + return err } - res.Body.Close() + defer res.Body.Close() + all, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + return fmt.Errorf("HTTP error from control plane: %v: %s", res.Status, all) + } + return nil } -func addLBHeader(req *http.Request, nodeKey key.NodePublic) { - if !nodeKey.IsZero() { - req.Header.Add(tailcfg.LBHeader, nodeKey.String()) - } +// SendAuditLog implements [auditlog.Transport] by sending an audit log synchronously to the control plane. +// +// See docs on [tailcfg.AuditLogRequest] and [auditlog.Logger] for background. +func (c *Auto) SendAuditLog(ctx context.Context, auditLog tailcfg.AuditLogRequest) (err error) { + return c.direct.sendAuditLog(ctx, auditLog) } -type dialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) +func (c *Direct) sendAuditLog(ctx context.Context, auditLog tailcfg.AuditLogRequest) (err error) { + nc, err := c.getNoiseClient() + if err != nil { + return fmt.Errorf("%w: %w", errNoNoiseClient, err) + } + + nodeKey, ok := c.GetPersist().PublicNodeKeyOK() + if !ok { + return errNoNodeKey + } + + req := &tailcfg.AuditLogRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: nodeKey, + Action: auditLog.Action, + Details: auditLog.Details, + } + + if c.panicOnUse { + panic("tainted client") + } + + res, err := nc.Post(ctx, "/machine/audit-log", nodeKey, req) + if err != nil { + return fmt.Errorf("%w: %w", errHTTPPostFailure, err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + all, _ := io.ReadAll(res.Body) + return errBadHTTPResponse(res.StatusCode, string(all)) + } + return nil +} // makeScreenTimeDetectingDialFunc returns dialFunc, optionally wrapped (on // Apple systems) with a func that sets the returned atomic.Bool for whether // Screen Time seemed to intercept the connection. // // The returned *atomic.Bool is nil on non-Apple systems. -func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) { +func makeScreenTimeDetectingDialFunc(dial netx.DialFunc) (netx.DialFunc, *atomic.Bool) { switch runtime.GOOS { case "darwin", "ios": // Continue below. @@ -1674,6 +1794,13 @@ func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) { }, ab } +func ignoreDialPlan() bool { + // If we're running in v86 (a JavaScript-based emulation of a 32-bit x86) + // our networking is very limited. Let's ignore the dial plan since it's too + // complicated to race that many IPs anyway. + return hostinfo.IsInVM86() +} + func isTCPLoopback(a net.Addr) bool { if ta, ok := a.(*net.TCPAddr); ok { return ta.IP.IsLoopback() diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index e2a6f9fa4..4329fc878 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -17,21 +17,52 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/util/eventbus/eventbustest" ) +func TestSetDiscoPublicKey(t *testing.T) { + initialKey := key.NewDisco().Public() + + c := &Direct{ + discoPubKey: initialKey, + } + + c.mu.Lock() + if c.discoPubKey != initialKey { + t.Fatalf("initial disco key mismatch: got %v, want %v", c.discoPubKey, initialKey) + } + c.mu.Unlock() + + newKey := key.NewDisco().Public() + c.SetDiscoPublicKey(newKey) + + c.mu.Lock() + if c.discoPubKey != newKey { + t.Fatalf("disco key not updated: got %v, want %v", c.discoPubKey, newKey) + } + if c.discoPubKey == initialKey { + t.Fatal("disco key should have changed") + } + c.mu.Unlock() +} + func TestNewDirect(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni + bus := eventbustest.NewBus(t) k := key.NewMachine() + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) opts := Options{ ServerURL: "https://example.com", Hostinfo: hi, GetMachinePrivateKey: func() (key.MachinePrivate, error) { return k, nil }, - Dialer: tsdial.NewDialer(netmon.NewStatic()), + Dialer: dialer, + Bus: bus, } c, err := NewDirect(opts) if err != nil { @@ -99,15 +130,19 @@ func TestTsmpPing(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni + bus := eventbustest.NewBus(t) k := key.NewMachine() + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) opts := Options{ ServerURL: "https://example.com", Hostinfo: hi, GetMachinePrivateKey: func() (key.MachinePrivate, error) { return k, nil }, - Dialer: tsdial.NewDialer(netmon.NewStatic()), + Dialer: dialer, + Bus: bus, } c, err := NewDirect(opts) diff --git a/control/controlclient/errors.go b/control/controlclient/errors.go new file mode 100644 index 000000000..9b4dab844 --- /dev/null +++ b/control/controlclient/errors.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "errors" + "fmt" + "net/http" +) + +// apiResponseError is an error type that can be returned by controlclient +// api requests. +// +// It wraps an underlying error and a flag for clients to query if the +// error is retryable via the Retryable() method. +type apiResponseError struct { + err error + retryable bool +} + +// Error implements [error]. +func (e *apiResponseError) Error() string { + return e.err.Error() +} + +// Retryable reports whether the error is retryable. +func (e *apiResponseError) Retryable() bool { + return e.retryable +} + +func (e *apiResponseError) Unwrap() error { return e.err } + +var ( + errNoNodeKey = &apiResponseError{errors.New("no node key"), true} + errNoNoiseClient = &apiResponseError{errors.New("no noise client"), true} + errHTTPPostFailure = &apiResponseError{errors.New("http failure"), true} +) + +func errBadHTTPResponse(code int, msg string) error { + retryable := false + switch code { + case http.StatusTooManyRequests, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + retryable = true + } + return &apiResponseError{fmt.Errorf("http error %d: %s", code, msg), retryable} +} diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 787912222..9aa8e3710 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -6,21 +6,23 @@ package controlclient import ( "cmp" "context" + "crypto/sha256" + "encoding/hex" "encoding/json" - "fmt" + "io" "maps" "net" "reflect" "runtime" "runtime/debug" "slices" - "sort" "strconv" "sync" "time" "tailscale.com/control/controlknobs" "tailscale.com/envknob" + "tailscale.com/hostinfo" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -31,6 +33,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/util/slicesx" "tailscale.com/wgengine/filter" ) @@ -54,6 +57,9 @@ type mapSession struct { altClock tstime.Clock // if nil, regular time is used cancel context.CancelFunc // always non-nil, shuts down caller's base long poll context + keepAliveZ []byte // if non-nil, the learned zstd encoding of the just-KeepAlive message for this session + ztdDecodesForTest int // for testing + // sessionAliveCtx is a Background-based context that's alive for the // duration of the mapSession that we own the lifetime of. It's closed by // sessionAliveCtxClose. @@ -75,11 +81,10 @@ type mapSession struct { lastPrintMap time.Time lastNode tailcfg.NodeView lastCapSet set.Set[tailcfg.NodeCapability] - peers map[tailcfg.NodeID]*tailcfg.NodeView // pointer to view (oddly). same pointers as sortedPeers. - sortedPeers []*tailcfg.NodeView // same pointers as peers, but sorted by Node.ID + peers map[tailcfg.NodeID]tailcfg.NodeView lastDNSConfig *tailcfg.DNSConfig lastDERPMap *tailcfg.DERPMap - lastUserProfile map[tailcfg.UserID]tailcfg.UserProfile + lastUserProfile map[tailcfg.UserID]tailcfg.UserProfileView lastPacketFilterRules views.Slice[tailcfg.FilterRule] // concatenation of all namedPacketFilters namedPacketFilters map[string]views.Slice[tailcfg.FilterRule] lastParsedPacketFilter []filter.Match @@ -88,10 +93,10 @@ type mapSession struct { lastDomain string lastDomainAuditLogID string lastHealth []string + lastDisplayMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage lastPopBrowserURL string lastTKAInfo *tailcfg.TKAInfo lastNetmapSummary string // from NetworkMap.VeryConcise - lastMaxExpiry time.Duration } // newMapSession returns a mostly unconfigured new mapSession. @@ -106,7 +111,7 @@ func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnob privateNodeKey: privateNodeKey, publicNodeKey: privateNodeKey.Public(), lastDNSConfig: new(tailcfg.DNSConfig), - lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfile{}, + lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfileView{}, // Non-nil no-op defaults, to be optionally overridden by the caller. logf: logger.Discard, @@ -167,6 +172,7 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t // For responses that mutate the self node, check for updated nodeAttrs. if resp.Node != nil { + upgradeNode(resp.Node) if DevKnob.StripCaps() { resp.Node.Capabilities = nil resp.Node.CapMap = nil @@ -182,6 +188,13 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t ms.controlKnobs.UpdateFromNodeAttributes(resp.Node.CapMap) } + for _, p := range resp.Peers { + upgradeNode(p) + } + for _, p := range resp.PeersChanged { + upgradeNode(p) + } + // Call Node.InitDisplayNames on any changed nodes. initDisplayNames(cmp.Or(resp.Node.View(), ms.lastNode), resp) @@ -217,6 +230,33 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t return nil } +// upgradeNode upgrades Node fields from the server into the modern forms +// not using deprecated fields. +func upgradeNode(n *tailcfg.Node) { + if n == nil { + return + } + if n.LegacyDERPString != "" { + if n.HomeDERP == 0 { + ip, portStr, err := net.SplitHostPort(n.LegacyDERPString) + if ip == tailcfg.DerpMagicIP && err == nil { + port, err := strconv.Atoi(portStr) + if err == nil { + n.HomeDERP = port + } + } + } + n.LegacyDERPString = "" + } + if DevKnob.StripHomeDERP() { + n.HomeDERP = 0 + } + + if n.AllowedIPs == nil { + n.AllowedIPs = slices.Clone(n.Addresses) + } +} + func (ms *mapSession) tryHandleIncrementally(res *tailcfg.MapResponse) bool { if ms.controlKnobs != nil && ms.controlKnobs.DisableDeltaUpdates.Load() { return false @@ -260,13 +300,47 @@ func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) { } for _, up := range resp.UserProfiles { - ms.lastUserProfile[up.ID] = up + ms.lastUserProfile[up.ID] = up.View() } // TODO(bradfitz): clean up old user profiles? maybe not worth it. if dm := resp.DERPMap; dm != nil { ms.vlogf("netmap: new map contains DERP map") + // Guard against the control server accidentally sending + // a nil region definition, which at least Headscale was + // observed to send. + for rid, r := range dm.Regions { + if r == nil { + delete(dm.Regions, rid) + } + } + + // In the copy/v86 wasm environment with limited networking, if the + // control plane didn't pick our DERP home for us, do it ourselves and + // mark all but the lowest region as NoMeasureNoHome. For prod, this + // will be Region 1, NYC, a compromise between the US and Europe. But + // really the control plane should pick this. This is only a fallback. + if hostinfo.IsInVM86() { + numCanMeasure := 0 + lowest := 0 + for rid, r := range dm.Regions { + if !r.NoMeasureNoHome { + numCanMeasure++ + if lowest == 0 || rid < lowest { + lowest = rid + } + } + } + if numCanMeasure > 1 { + for rid, r := range dm.Regions { + if rid != lowest { + r.NoMeasureNoHome = true + } + } + } + } + // Zero-valued fields in a DERPMap mean that we're not changing // anything and are using the previous value(s). if ldm := ms.lastDERPMap; ldm != nil { @@ -342,12 +416,24 @@ func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) { if resp.Health != nil { ms.lastHealth = resp.Health } + if resp.DisplayMessages != nil { + if v, ok := resp.DisplayMessages["*"]; ok && v == nil { + ms.lastDisplayMessages = nil + } + for k, v := range resp.DisplayMessages { + if k == "*" { + continue + } + if v != nil { + mak.Set(&ms.lastDisplayMessages, k, *v) + } else { + delete(ms.lastDisplayMessages, k) + } + } + } if resp.TKAInfo != nil { ms.lastTKAInfo = resp.TKAInfo } - if resp.MaxKeyDuration > 0 { - ms.lastMaxExpiry = resp.MaxKeyDuration - } } var ( @@ -366,16 +452,11 @@ var ( patchifiedPeerEqual = clientmetric.NewCounter("controlclient_patchified_peer_equal") ) -// updatePeersStateFromResponseres updates ms.peers and ms.sortedPeers from res. It takes ownership of res. +// updatePeersStateFromResponseres updates ms.peers from resp. +// It takes ownership of resp. func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (stats updateStats) { - defer func() { - if stats.removed > 0 || stats.added > 0 { - ms.rebuildSorted() - } - }() - if ms.peers == nil { - ms.peers = make(map[tailcfg.NodeID]*tailcfg.NodeView) + ms.peers = make(map[tailcfg.NodeID]tailcfg.NodeView) } if len(resp.Peers) > 0 { @@ -384,12 +465,12 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s keep := make(map[tailcfg.NodeID]bool, len(resp.Peers)) for _, n := range resp.Peers { keep[n.ID] = true - if vp, ok := ms.peers[n.ID]; ok { + lenBefore := len(ms.peers) + ms.peers[n.ID] = n.View() + if len(ms.peers) == lenBefore { stats.changed++ - *vp = n.View() } else { stats.added++ - ms.peers[n.ID] = ptr.To(n.View()) } } for id := range ms.peers { @@ -410,12 +491,12 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s } for _, n := range resp.PeersChanged { - if vp, ok := ms.peers[n.ID]; ok { + lenBefore := len(ms.peers) + ms.peers[n.ID] = n.View() + if len(ms.peers) == lenBefore { stats.changed++ - *vp = n.View() } else { stats.added++ - ms.peers[n.ID] = ptr.To(n.View()) } } @@ -427,7 +508,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s } else { mut.LastSeen = nil } - *vp = mut.View() + ms.peers[nodeID] = mut.View() stats.changed++ } } @@ -436,7 +517,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s if vp, ok := ms.peers[nodeID]; ok { mut := vp.AsStruct() mut.Online = ptr.To(online) - *vp = mut.View() + ms.peers[nodeID] = mut.View() stats.changed++ } } @@ -449,7 +530,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s stats.changed++ mut := vp.AsStruct() if pc.DERPRegion != 0 { - mut.DERP = fmt.Sprintf("%s:%v", tailcfg.DerpMagicIP, pc.DERPRegion) + mut.HomeDERP = pc.DERPRegion patchDERPRegion.Add(1) } if pc.Cap != 0 { @@ -488,31 +569,12 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s mut.CapMap = v patchCapMap.Add(1) } - *vp = mut.View() + ms.peers[pc.NodeID] = mut.View() } return } -// rebuildSorted rebuilds ms.sortedPeers from ms.peers. It should be called -// after any additions or removals from peers. -func (ms *mapSession) rebuildSorted() { - if ms.sortedPeers == nil { - ms.sortedPeers = make([]*tailcfg.NodeView, 0, len(ms.peers)) - } else { - if len(ms.sortedPeers) > len(ms.peers) { - clear(ms.sortedPeers[len(ms.peers):]) - } - ms.sortedPeers = ms.sortedPeers[:0] - } - for _, p := range ms.peers { - ms.sortedPeers = append(ms.sortedPeers, p) - } - sort.Slice(ms.sortedPeers, func(i, j int) bool { - return ms.sortedPeers[i].ID() < ms.sortedPeers[j].ID() - }) -} - func (ms *mapSession) addUserProfile(nm *netmap.NetworkMap, userID tailcfg.UserID) { if userID == 0 { return @@ -576,7 +638,7 @@ func (ms *mapSession) patchifyPeer(n *tailcfg.Node) (_ *tailcfg.PeerChange, ok b if !ok { return nil, false } - return peerChangeDiff(*was, n) + return peerChangeDiff(was, n) } // peerChangeDiff returns the difference from 'was' to 'n', if possible. @@ -656,17 +718,13 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang if !views.SliceEqual(was.Endpoints(), views.SliceOf(n.Endpoints)) { pc().Endpoints = slices.Clone(n.Endpoints) } - case "DERP": - if was.DERP() != n.DERP { - ip, portStr, err := net.SplitHostPort(n.DERP) - if err != nil || ip != "127.3.3.40" { - return nil, false - } - port, err := strconv.Atoi(portStr) - if err != nil || port < 1 || port > 65535 { - return nil, false - } - pc().DERPRegion = port + case "LegacyDERPString": + if was.LegacyDERPString() != "" || n.LegacyDERPString != "" { + panic("unexpected; caller should've already called upgradeNode") + } + case "HomeDERP": + if was.HomeDERP() != n.HomeDERP { + pc().DERPRegion = n.HomeDERP } case "Hostinfo": if !was.Hostinfo().Valid() && !n.Hostinfo.Valid() { @@ -688,21 +746,23 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "CapMap": if len(n.CapMap) != was.CapMap().Len() { + // If they have different lengths, they're different. if n.CapMap == nil { pc().CapMap = make(tailcfg.NodeCapMap) } else { pc().CapMap = maps.Clone(n.CapMap) } - break - } - was.CapMap().Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { - nv, ok := n.CapMap[k] - if !ok || !views.SliceEqual(v, views.SliceOf(nv)) { - pc().CapMap = maps.Clone(n.CapMap) - return false + } else { + // If they have the same length, check that all their keys + // have the same values. + for k, v := range was.CapMap().All() { + nv, ok := n.CapMap[k] + if !ok || !views.SliceEqual(v, views.SliceOf(nv)) { + pc().CapMap = maps.Clone(n.CapMap) + break + } } - return true - }) + } case "Tags": if !views.SliceEqual(was.Tags(), views.SliceOf(n.Tags)) { return nil, false @@ -712,13 +772,11 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return nil, false } case "Online": - wasOnline := was.Online() - if n.Online != nil && wasOnline != nil && *n.Online != *wasOnline { + if wasOnline, ok := was.Online().GetOk(); ok && n.Online != nil && *n.Online != wasOnline { pc().Online = ptr.To(*n.Online) } case "LastSeen": - wasSeen := was.LastSeen() - if n.LastSeen != nil && wasSeen != nil && !wasSeen.Equal(*n.LastSeen) { + if wasSeen, ok := was.LastSeen().GetOk(); ok && n.LastSeen != nil && !wasSeen.Equal(*n.LastSeen) { pc().LastSeen = ptr.To(*n.LastSeen) } case "MachineAuthorized": @@ -743,18 +801,18 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "SelfNodeV4MasqAddrForThisPeer": va, vb := was.SelfNodeV4MasqAddrForThisPeer(), n.SelfNodeV4MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "SelfNodeV6MasqAddrForThisPeer": va, vb := was.SelfNodeV6MasqAddrForThisPeer(), n.SelfNodeV6MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "ExitNodeDNSResolvers": @@ -778,21 +836,40 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return ret, true } +func (ms *mapSession) sortedPeers() []tailcfg.NodeView { + ret := slicesx.MapValues(ms.peers) + slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), b.ID()) + }) + return ret +} + // netmap returns a fully populated NetworkMap from the last state seen from // a call to updateStateFromResponse, filling in omitted // information from prior MapResponse values. func (ms *mapSession) netmap() *netmap.NetworkMap { - peerViews := make([]tailcfg.NodeView, len(ms.sortedPeers)) - for i, vp := range ms.sortedPeers { - peerViews[i] = *vp + peerViews := ms.sortedPeers() + + var msgs map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage + if len(ms.lastDisplayMessages) != 0 { + msgs = ms.lastDisplayMessages + } else if len(ms.lastHealth) > 0 { + // Convert all ms.lastHealth to the new [netmap.NetworkMap.DisplayMessages] + for _, h := range ms.lastHealth { + id := "health-" + strhash(h) // Unique ID in case there is more than one health message + mak.Set(&msgs, tailcfg.DisplayMessageID(id), tailcfg.DisplayMessage{ + Title: "Coordination server reports an issue", + Severity: tailcfg.SeverityMedium, + Text: "The coordination server is reporting a health issue: " + h, + }) + } } nm := &netmap.NetworkMap{ NodeKey: ms.publicNodeKey, - PrivateKey: ms.privateNodeKey, MachineKey: ms.machinePubKey, Peers: peerViews, - UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), + UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfileView), Domain: ms.lastDomain, DomainAuditLogID: ms.lastDomainAuditLogID, DNS: *ms.lastDNSConfig, @@ -801,9 +878,8 @@ func (ms *mapSession) netmap() *netmap.NetworkMap { SSHPolicy: ms.lastSSHPolicy, CollectServices: ms.collectServices, DERPMap: ms.lastDERPMap, - ControlHealth: ms.lastHealth, + DisplayMessages: msgs, TKAEnabled: ms.lastTKAInfo != nil && !ms.lastTKAInfo.Disabled, - MaxKeyDuration: ms.lastMaxExpiry, } if ms.lastTKAInfo != nil && ms.lastTKAInfo.Head != "" { @@ -815,8 +891,6 @@ func (ms *mapSession) netmap() *netmap.NetworkMap { if node := ms.lastNode; node.Valid() { nm.SelfNode = node - nm.Expiry = node.KeyExpiry() - nm.Name = node.Name() nm.AllCaps = ms.lastCapSet } @@ -828,5 +902,12 @@ func (ms *mapSession) netmap() *netmap.NetworkMap { if DevKnob.ForceProxyDNS() { nm.DNS.Proxied = true } + return nm } + +func strhash(h string) string { + s := sha256.New() + io.WriteString(s, h) + return hex.EncodeToString(s.Sum(nil)) +} diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 897036a94..2be4b6ad7 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -4,9 +4,11 @@ package controlclient import ( + "bytes" "context" "encoding/json" "fmt" + "maps" "net/netip" "reflect" "strings" @@ -15,8 +17,11 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "go4.org/mem" "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/ipn" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstime" @@ -24,9 +29,12 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/types/persist" "tailscale.com/types/ptr" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/zstdframe" ) func eps(s ...string) []netip.AddrPort { @@ -50,9 +58,9 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { n.LastSeen = &t } } - withDERP := func(d string) func(*tailcfg.Node) { + withDERP := func(regionID int) func(*tailcfg.Node) { return func(n *tailcfg.Node) { - n.DERP = d + n.HomeDERP = regionID } } withEP := func(ep string) func(*tailcfg.Node) { @@ -189,14 +197,14 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { }, { name: "ep_change_derp", - prev: peers(n(1, "foo", withDERP("127.3.3.40:3"))), + prev: peers(n(1, "foo", withDERP(3))), mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, DERPRegion: 4, }}, }, - want: peers(n(1, "foo", withDERP("127.3.3.40:4"))), + want: peers(n(1, "foo", withDERP(4))), wantStats: updateStats{changed: 1}, }, { @@ -213,19 +221,19 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { }, { name: "ep_change_udp_2", - prev: peers(n(1, "foo", withDERP("127.3.3.40:3"), withEP("1.2.3.4:111"))), + prev: peers(n(1, "foo", withDERP(3), withEP("1.2.3.4:111"))), mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, Endpoints: eps("1.2.3.4:56"), }}, }, - want: peers(n(1, "foo", withDERP("127.3.3.40:3"), withEP("1.2.3.4:56"))), + want: peers(n(1, "foo", withDERP(3), withEP("1.2.3.4:56"))), wantStats: updateStats{changed: 1}, }, { name: "ep_change_both", - prev: peers(n(1, "foo", withDERP("127.3.3.40:3"), withEP("1.2.3.4:111"))), + prev: peers(n(1, "foo", withDERP(3), withEP("1.2.3.4:111"))), mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, @@ -233,7 +241,7 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { Endpoints: eps("1.2.3.4:56"), }}, }, - want: peers(n(1, "foo", withDERP("127.3.3.40:2"), withEP("1.2.3.4:56"))), + want: peers(n(1, "foo", withDERP(2), withEP("1.2.3.4:56"))), wantStats: updateStats{changed: 1}, }, { @@ -340,19 +348,18 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { } ms := newTestMapSession(t, nil) for _, n := range tt.prev { - mak.Set(&ms.peers, n.ID, ptr.To(n.View())) + mak.Set(&ms.peers, n.ID, n.View()) } - ms.rebuildSorted() gotStats := ms.updatePeersStateFromResponse(tt.mapRes) - - got := make([]*tailcfg.Node, len(ms.sortedPeers)) - for i, vp := range ms.sortedPeers { - got[i] = vp.AsStruct() - } if gotStats != tt.wantStats { t.Errorf("got stats = %+v; want %+v", gotStats, tt.wantStats) } + + var got []*tailcfg.Node + for _, vp := range ms.sortedPeers() { + got = append(got, vp.AsStruct()) + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("wrong results\n got: %s\nwant: %s", formatNodes(got), formatNodes(tt.want)) } @@ -745,8 +752,8 @@ func TestPeerChangeDiff(t *testing.T) { }, { name: "patch-derp", - a: &tailcfg.Node{ID: 1, DERP: "127.3.3.40:1"}, - b: &tailcfg.Node{ID: 1, DERP: "127.3.3.40:2"}, + a: &tailcfg.Node{ID: 1, HomeDERP: 1}, + b: &tailcfg.Node{ID: 1, HomeDERP: 2}, want: &tailcfg.PeerChange{NodeID: 1, DERPRegion: 2}, }, { @@ -930,23 +937,23 @@ func TestPatchifyPeersChanged(t *testing.T) { mr0: &tailcfg.MapResponse{ Node: &tailcfg.Node{Name: "foo.bar.ts.net."}, Peers: []*tailcfg.Node{ - {ID: 1, DERP: "127.3.3.40:1", Hostinfo: hi}, - {ID: 2, DERP: "127.3.3.40:2", Hostinfo: hi}, - {ID: 3, DERP: "127.3.3.40:3", Hostinfo: hi}, + {ID: 1, HomeDERP: 1, Hostinfo: hi}, + {ID: 2, HomeDERP: 2, Hostinfo: hi}, + {ID: 3, HomeDERP: 3, Hostinfo: hi}, }, }, mr1: &tailcfg.MapResponse{ PeersChanged: []*tailcfg.Node{ - {ID: 1, DERP: "127.3.3.40:11", Hostinfo: hi}, + {ID: 1, HomeDERP: 11, Hostinfo: hi}, {ID: 2, StableID: "other-change", Hostinfo: hi}, - {ID: 3, DERP: "127.3.3.40:33", Hostinfo: hi}, - {ID: 4, DERP: "127.3.3.40:4", Hostinfo: hi}, + {ID: 3, HomeDERP: 33, Hostinfo: hi}, + {ID: 4, HomeDERP: 4, Hostinfo: hi}, }, }, want: &tailcfg.MapResponse{ PeersChanged: []*tailcfg.Node{ {ID: 2, StableID: "other-change", Hostinfo: hi}, - {ID: 4, DERP: "127.3.3.40:4", Hostinfo: hi}, + {ID: 4, HomeDERP: 4, Hostinfo: hi}, }, PeersChangedPatch: []*tailcfg.PeerChange{ {NodeID: 1, DERPRegion: 11}, @@ -1007,6 +1014,85 @@ func TestPatchifyPeersChanged(t *testing.T) { } } +func TestUpgradeNode(t *testing.T) { + a1 := netip.MustParsePrefix("0.0.0.1/32") + a2 := netip.MustParsePrefix("0.0.0.2/32") + a3 := netip.MustParsePrefix("0.0.0.3/32") + a4 := netip.MustParsePrefix("0.0.0.4/32") + + tests := []struct { + name string + in *tailcfg.Node + want *tailcfg.Node + also func(t *testing.T, got *tailcfg.Node) // optional + }{ + { + name: "nil", + in: nil, + want: nil, + }, + { + name: "empty", + in: new(tailcfg.Node), + want: new(tailcfg.Node), + }, + { + name: "derp-both", + in: &tailcfg.Node{HomeDERP: 1, LegacyDERPString: tailcfg.DerpMagicIP + ":2"}, + want: &tailcfg.Node{HomeDERP: 1}, + }, + { + name: "derp-str-only", + in: &tailcfg.Node{LegacyDERPString: tailcfg.DerpMagicIP + ":2"}, + want: &tailcfg.Node{HomeDERP: 2}, + }, + { + name: "derp-int-only", + in: &tailcfg.Node{HomeDERP: 2}, + want: &tailcfg.Node{HomeDERP: 2}, + }, + { + name: "implicit-allowed-ips-all-set", + in: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{a3, a4}}, + want: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{a3, a4}}, + }, + { + name: "implicit-allowed-ips-only-address-set", + in: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}}, + want: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{a1, a2}}, + also: func(t *testing.T, got *tailcfg.Node) { + if t.Failed() { + return + } + if &got.Addresses[0] == &got.AllowedIPs[0] { + t.Error("Addresses and AllowIPs alias the same memory") + } + }, + }, + { + name: "implicit-allowed-ips-set-empty-slice", + in: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{}}, + want: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got *tailcfg.Node + if tt.in != nil { + got = ptr.To(*tt.in) // shallow clone + } + upgradeNode(got) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("wrong result (-want +got):\n%s", diff) + } + if tt.also != nil { + tt.also(t, got) + } + }) + } + +} + func BenchmarkMapSessionDelta(b *testing.B) { for _, size := range []int{10, 100, 1_000, 10_000} { b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { @@ -1023,7 +1109,7 @@ func BenchmarkMapSessionDelta(b *testing.B) { res.Peers = append(res.Peers, &tailcfg.Node{ ID: tailcfg.NodeID(i + 2), Name: fmt.Sprintf("peer%d.bar.ts.net.", i), - DERP: "127.3.3.40:10", + HomeDERP: 10, Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.2.3/32"), netip.MustParsePrefix("fd7a:115c:a1e0::123/128")}, AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.100.2.3/32"), netip.MustParsePrefix("fd7a:115c:a1e0::123/128")}, Endpoints: eps("192.168.1.2:345", "192.168.1.3:678"), @@ -1058,3 +1144,342 @@ func BenchmarkMapSessionDelta(b *testing.B) { }) } } + +// TestNetmapDisplayMessage checks that the various diff operations +// (add/update/delete/clear) for [tailcfg.DisplayMessage] in a +// [tailcfg.MapResponse] work as expected. +func TestNetmapDisplayMessage(t *testing.T) { + type test struct { + name string + initialState *tailcfg.MapResponse + mapResponse tailcfg.MapResponse + wantMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage + } + + tests := []test{ + { + name: "basic-set", + mapResponse: tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "test-message": { + Title: "Testing", + Text: "This is a test message", + Severity: tailcfg.SeverityHigh, + ImpactsConnectivity: true, + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "https://www.example.com", + Label: "Learn more", + }, + }, + }, + }, + wantMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test-message": { + Title: "Testing", + Text: "This is a test message", + Severity: tailcfg.SeverityHigh, + ImpactsConnectivity: true, + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "https://www.example.com", + Label: "Learn more", + }, + }, + }, + }, + { + name: "delete-one", + initialState: &tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A", + }, + "message-b": { + Title: "Message B", + }, + }, + }, + mapResponse: tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": nil, + }, + }, + wantMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "message-b": { + Title: "Message B", + }, + }, + }, + { + name: "update-one", + initialState: &tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A", + }, + "message-b": { + Title: "Message B", + }, + }, + }, + mapResponse: tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A updated", + }, + }, + }, + wantMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A updated", + }, + "message-b": { + Title: "Message B", + }, + }, + }, + { + name: "add-one", + initialState: &tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A", + }, + }, + }, + mapResponse: tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-b": { + Title: "Message B", + }, + }, + }, + wantMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A", + }, + "message-b": { + Title: "Message B", + }, + }, + }, + { + name: "delete-all", + initialState: &tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A", + }, + "message-b": { + Title: "Message B", + }, + }, + }, + mapResponse: tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "*": nil, + }, + }, + wantMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{}, + }, + { + name: "delete-all-and-add", + initialState: &tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "message-a": { + Title: "Message A", + }, + "message-b": { + Title: "Message B", + }, + }, + }, + mapResponse: tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "*": nil, + "message-c": { + Title: "Message C", + }, + }, + }, + wantMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "message-c": { + Title: "Message C", + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ms := newTestMapSession(t, nil) + + if test.initialState != nil { + ms.netmapForResponse(test.initialState) + } + + nm := ms.netmapForResponse(&test.mapResponse) + + if diff := cmp.Diff(test.wantMessages, nm.DisplayMessages, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("unexpected warnings (-want +got):\n%s", diff) + } + }) + } +} + +// TestNetmapHealthIntegration checks that we get the expected health warnings +// from processing a [tailcfg.MapResponse] containing health messages and passing the +// [netmap.NetworkMap] to a [health.Tracker]. +func TestNetmapHealthIntegration(t *testing.T) { + ms := newTestMapSession(t, nil) + ht := health.NewTracker(eventbustest.NewBus(t)) + + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Health: []string{ + "Test message", + "Another message", + }, + }) + ht.SetControlHealth(nm.DisplayMessages) + + want := map[health.WarnableCode]health.UnhealthyState{ + "control-health.health-c0719e9a8d5d838d861dc6f675c899d2b309a3a65bb9fe6b11e5afcbf9a2c0b1": { + WarnableCode: "control-health.health-c0719e9a8d5d838d861dc6f675c899d2b309a3a65bb9fe6b11e5afcbf9a2c0b1", + Title: "Coordination server reports an issue", + Severity: health.SeverityMedium, + Text: "The coordination server is reporting a health issue: Test message", + }, + "control-health.health-1dc7017a73a3c55c0d6a8423e3813c7ab6562d9d3064c2ec6ac7822f61b1db9c": { + WarnableCode: "control-health.health-1dc7017a73a3c55c0d6a8423e3813c7ab6562d9d3064c2ec6ac7822f61b1db9c", + Title: "Coordination server reports an issue", + Severity: health.SeverityMedium, + Text: "The coordination server is reporting a health issue: Another message", + }, + } + + got := maps.Clone(ht.CurrentState().Warnings) + for k := range got { + if !strings.HasPrefix(string(k), "control-health") { + delete(got, k) + } + } + + if d := cmp.Diff(want, got, cmpopts.IgnoreFields(health.UnhealthyState{}, "ETag")); d != "" { + t.Fatalf("CurrentStatus().Warnings[\"control-health*\"] different than expected (-want +got)\n%s", d) + } +} + +// TestNetmapDisplayMessageIntegration checks that we get the expected health +// warnings from processing a [tailcfg.MapResponse] that contains DisplayMessages and +// passing the [netmap.NetworkMap] to a [health.Tracker]. +func TestNetmapDisplayMessageIntegration(t *testing.T) { + ms := newTestMapSession(t, nil) + ht := health.NewTracker(eventbustest.NewBus(t)) + + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + baseWarnings := ht.CurrentState().Warnings + + nm := ms.netmapForResponse(&tailcfg.MapResponse{ + DisplayMessages: map[tailcfg.DisplayMessageID]*tailcfg.DisplayMessage{ + "test-message": { + Title: "Testing", + Text: "This is a test message", + Severity: tailcfg.SeverityHigh, + ImpactsConnectivity: true, + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "https://www.example.com", + Label: "Learn more", + }, + }, + }, + }) + ht.SetControlHealth(nm.DisplayMessages) + + state := ht.CurrentState() + + // Ignore warnings that aren't from the netmap + for k := range baseWarnings { + delete(state.Warnings, k) + } + + want := map[health.WarnableCode]health.UnhealthyState{ + "control-health.test-message": { + WarnableCode: "control-health.test-message", + Title: "Testing", + Text: "This is a test message", + Severity: health.SeverityHigh, + ImpactsConnectivity: true, + PrimaryAction: &health.UnhealthyStateAction{ + URL: "https://www.example.com", + Label: "Learn more", + }, + }, + } + + if diff := cmp.Diff(want, state.Warnings, cmpopts.IgnoreFields(health.UnhealthyState{}, "ETag")); diff != "" { + t.Errorf("unexpected message contents (-want +got):\n%s", diff) + } +} + +func TestNetmapForMapResponseForDebug(t *testing.T) { + mr := &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + ID: 1, + Name: "foo.bar.ts.net.", + }, + Peers: []*tailcfg.Node{ + {ID: 2, Name: "peer1.bar.ts.net.", HomeDERP: 1}, + {ID: 3, Name: "peer2.bar.ts.net.", HomeDERP: 1}, + }, + } + ms := newTestMapSession(t, nil) + nm1 := ms.netmapForResponse(mr) + + prefs := &ipn.Prefs{Persist: &persist.Persist{PrivateNodeKey: ms.privateNodeKey}} + nm2, err := NetmapFromMapResponseForDebug(t.Context(), prefs.View().Persist(), mr) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(nm1, nm2) { + t.Errorf("mismatch\nnm1: %s\nnm2: %s\n", logger.AsJSON(nm1), logger.AsJSON(nm2)) + } +} + +func TestLearnZstdOfKeepAlive(t *testing.T) { + keepAliveMsgZstd := (func() []byte { + msg := must.Get(json.Marshal(tailcfg.MapResponse{ + KeepAlive: true, + })) + return zstdframe.AppendEncode(nil, msg, zstdframe.FastestCompression) + })() + + sess := newTestMapSession(t, nil) + + // The first time we see a zstd keep-alive message, we learn how + // the server encodes that. + var mr tailcfg.MapResponse + must.Do(sess.decodeMsg(keepAliveMsgZstd, &mr)) + if !mr.KeepAlive { + t.Fatal("mr.KeepAlive false; want true") + } + if !bytes.Equal(sess.keepAliveZ, keepAliveMsgZstd) { + t.Fatalf("sess.keepAlive = %q; want %q", sess.keepAliveZ, keepAliveMsgZstd) + } + if got, want := sess.ztdDecodesForTest, 1; got != want { + t.Fatalf("got %d zstd decodes; want %d", got, want) + } + + // The second time on the session where we see that message, we + // decode it without needing to decompress. + var mr2 tailcfg.MapResponse + must.Do(sess.decodeMsg(keepAliveMsgZstd, &mr2)) + if !mr2.KeepAlive { + t.Fatal("mr2.KeepAlive false; want true") + } + if got, want := sess.ztdDecodesForTest, 1; got != want { + t.Fatalf("got %d zstd decodes; want %d", got, want) + } +} diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go deleted file mode 100644 index 3994af056..000000000 --- a/control/controlclient/noise.go +++ /dev/null @@ -1,406 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package controlclient - -import ( - "bytes" - "cmp" - "context" - "encoding/json" - "errors" - "math" - "net/http" - "net/url" - "sync" - "time" - - "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" - "tailscale.com/envknob" - "tailscale.com/health" - "tailscale.com/internal/noiseconn" - "tailscale.com/net/dnscache" - "tailscale.com/net/netmon" - "tailscale.com/net/tsdial" - "tailscale.com/tailcfg" - "tailscale.com/tstime" - "tailscale.com/types/key" - "tailscale.com/types/logger" - "tailscale.com/util/mak" - "tailscale.com/util/multierr" - "tailscale.com/util/singleflight" - "tailscale.com/util/testenv" -) - -// NoiseClient provides a http.Client to connect to tailcontrol over -// the ts2021 protocol. -type NoiseClient struct { - // Client is an HTTP client to talk to the coordination server. - // It automatically makes a new Noise connection as needed. - // It does not support node key proofs. To do that, call - // noiseClient.getConn instead to make a connection. - *http.Client - - // h2t is the HTTP/2 transport we use a bit to create new - // *http2.ClientConns. We don't use its connection pool and we don't use its - // dialing. We use it for exactly one reason: its idle timeout that can only - // be configured via the HTTP/1 config. And then we call NewClientConn (with - // an existing Noise connection) on the http2.Transport which sets up an - // http2.ClientConn using that idle timeout from an http1.Transport. - h2t *http2.Transport - - // sfDial ensures that two concurrent requests for a noise connection only - // produce one shared one between the two callers. - sfDial singleflight.Group[struct{}, *noiseconn.Conn] - - dialer *tsdial.Dialer - dnsCache *dnscache.Resolver - privKey key.MachinePrivate - serverPubKey key.MachinePublic - host string // the host part of serverURL - httpPort string // the default port to dial - httpsPort string // the fallback Noise-over-https port or empty if none - - // dialPlan optionally returns a ControlDialPlan previously received - // from the control server; either the function or the return value can - // be nil. - dialPlan func() *tailcfg.ControlDialPlan - - logf logger.Logf - netMon *netmon.Monitor - health *health.Tracker - - // mu only protects the following variables. - mu sync.Mutex - closed bool - last *noiseconn.Conn // or nil - nextID int - connPool map[int]*noiseconn.Conn // active connections not yet closed; see noiseconn.Conn.Close -} - -// NoiseOpts contains options for the NewNoiseClient function. All fields are -// required unless otherwise specified. -type NoiseOpts struct { - // PrivKey is this node's private key. - PrivKey key.MachinePrivate - // ServerPubKey is the public key of the server. - ServerPubKey key.MachinePublic - // ServerURL is the URL of the server to connect to. - ServerURL string - // Dialer's SystemDial function is used to connect to the server. - Dialer *tsdial.Dialer - // DNSCache is the caching Resolver to use to connect to the server. - // - // This field can be nil. - DNSCache *dnscache.Resolver - // Logf is the log function to use. This field can be nil. - Logf logger.Logf - // NetMon is the network monitor that, if set, will be used to get the - // network interface state. This field can be nil; if so, the current - // state will be looked up dynamically. - NetMon *netmon.Monitor - // HealthTracker, if non-nil, is the health tracker to use. - HealthTracker *health.Tracker - // DialPlan, if set, is a function that should return an explicit plan - // on how to connect to the server. - DialPlan func() *tailcfg.ControlDialPlan -} - -// controlIsPlaintext is whether we should assume that the controlplane is only accessible -// over plaintext HTTP (as the first hop, before the ts2021 encryption begins). -// This is used by some tests which don't have a real TLS certificate. -var controlIsPlaintext = envknob.RegisterBool("TS_CONTROL_IS_PLAINTEXT_HTTP") - -// NewNoiseClient returns a new noiseClient for the provided server and machine key. -// serverURL is of the form https://: (no trailing slash). -// -// netMon may be nil, if non-nil it's used to do faster interface lookups. -// dialPlan may be nil -func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) { - u, err := url.Parse(opts.ServerURL) - if err != nil { - return nil, err - } - var httpPort string - var httpsPort string - if port := u.Port(); port != "" { - // If there is an explicit port specified, trust the scheme and hope for the best - if u.Scheme == "http" { - httpPort = port - httpsPort = "443" - if (testenv.InTest() || controlIsPlaintext()) && (u.Hostname() == "127.0.0.1" || u.Hostname() == "localhost") { - httpsPort = "" - } - } else { - httpPort = "80" - httpsPort = port - } - } else { - // Otherwise, use the standard ports - httpPort = "80" - httpsPort = "443" - } - - np := &NoiseClient{ - serverPubKey: opts.ServerPubKey, - privKey: opts.PrivKey, - host: u.Hostname(), - httpPort: httpPort, - httpsPort: httpsPort, - dialer: opts.Dialer, - dnsCache: opts.DNSCache, - dialPlan: opts.DialPlan, - logf: opts.Logf, - netMon: opts.NetMon, - health: opts.HealthTracker, - } - - // Create the HTTP/2 Transport using a net/http.Transport - // (which only does HTTP/1) because it's the only way to - // configure certain properties on the http2.Transport. But we - // never actually use the net/http.Transport for any HTTP/1 - // requests. - h2Transport, err := http2.ConfigureTransports(&http.Transport{ - IdleConnTimeout: time.Minute, - }) - if err != nil { - return nil, err - } - np.h2t = h2Transport - - np.Client = &http.Client{Transport: np} - return np, nil -} - -// GetSingleUseRoundTripper returns a RoundTripper that can be only be used once -// (and must be used once) to make a single HTTP request over the noise channel -// to the coordination server. -// -// In addition to the RoundTripper, it returns the HTTP/2 channel's early noise -// payload, if any. -func (nc *NoiseClient) GetSingleUseRoundTripper(ctx context.Context) (http.RoundTripper, *tailcfg.EarlyNoise, error) { - for tries := 0; tries < 3; tries++ { - conn, err := nc.getConn(ctx) - if err != nil { - return nil, nil, err - } - ok, earlyPayloadMaybeNil, err := conn.ReserveNewRequest(ctx) - if err != nil { - return nil, nil, err - } - if ok { - return conn, earlyPayloadMaybeNil, nil - } - } - return nil, nil, errors.New("[unexpected] failed to reserve a request on a connection") -} - -// contextErr is an error that wraps another error and is used to indicate that -// the error was because a context expired. -type contextErr struct { - err error -} - -func (e contextErr) Error() string { - return e.err.Error() -} - -func (e contextErr) Unwrap() error { - return e.err -} - -// getConn returns a noiseconn.Conn that can be used to make requests to the -// coordination server. It may return a cached connection or create a new one. -// Dials are singleflighted, so concurrent calls to getConn may only dial once. -// As such, context values may not be respected as there are no guarantees that -// the context passed to getConn is the same as the context passed to dial. -func (nc *NoiseClient) getConn(ctx context.Context) (*noiseconn.Conn, error) { - nc.mu.Lock() - if last := nc.last; last != nil && last.CanTakeNewRequest() { - nc.mu.Unlock() - return last, nil - } - nc.mu.Unlock() - - for { - // We singeflight the dial to avoid making multiple connections, however - // that means that we can't simply cancel the dial if the context is - // canceled. Instead, we have to additionally check that the context - // which was canceled is our context and retry if our context is still - // valid. - conn, err, _ := nc.sfDial.Do(struct{}{}, func() (*noiseconn.Conn, error) { - c, err := nc.dial(ctx) - if err != nil { - if ctx.Err() != nil { - return nil, contextErr{ctx.Err()} - } - return nil, err - } - return c, nil - }) - var ce contextErr - if err == nil || !errors.As(err, &ce) { - return conn, err - } - if ctx.Err() == nil { - // The dial failed because of a context error, but our context - // is still valid. Retry. - continue - } - // The dial failed because our context was canceled. Return the - // underlying error. - return nil, ce.Unwrap() - } -} - -func (nc *NoiseClient) RoundTrip(req *http.Request) (*http.Response, error) { - ctx := req.Context() - conn, err := nc.getConn(ctx) - if err != nil { - return nil, err - } - return conn.RoundTrip(req) -} - -// connClosed removes the connection with the provided ID from the pool -// of active connections. -func (nc *NoiseClient) connClosed(id int) { - nc.mu.Lock() - defer nc.mu.Unlock() - conn := nc.connPool[id] - if conn != nil { - delete(nc.connPool, id) - if nc.last == conn { - nc.last = nil - } - } -} - -// Close closes all the underlying noise connections. -// It is a no-op and returns nil if the connection is already closed. -func (nc *NoiseClient) Close() error { - nc.mu.Lock() - nc.closed = true - conns := nc.connPool - nc.connPool = nil - nc.mu.Unlock() - - var errors []error - for _, c := range conns { - if err := c.Close(); err != nil { - errors = append(errors, err) - } - } - return multierr.New(errors...) -} - -// dial opens a new connection to tailcontrol, fetching the server noise key -// if not cached. -func (nc *NoiseClient) dial(ctx context.Context) (*noiseconn.Conn, error) { - nc.mu.Lock() - connID := nc.nextID - nc.nextID++ - nc.mu.Unlock() - - if tailcfg.CurrentCapabilityVersion > math.MaxUint16 { - // Panic, because a test should have started failing several - // thousand version numbers before getting to this point. - panic("capability version is too high to fit in the wire protocol") - } - - var dialPlan *tailcfg.ControlDialPlan - if nc.dialPlan != nil { - dialPlan = nc.dialPlan() - } - - // If we have a dial plan, then set our timeout as slightly longer than - // the maximum amount of time contained therein; we assume that - // explicit instructions on timeouts are more useful than a single - // hard-coded timeout. - // - // The default value of 5 is chosen so that, when there's no dial plan, - // we retain the previous behaviour of 10 seconds end-to-end timeout. - timeoutSec := 5.0 - if dialPlan != nil { - for _, c := range dialPlan.Candidates { - if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec { - timeoutSec = v - } - } - } - - // After we establish a connection, we need some time to actually - // upgrade it into a Noise connection. With a ballpark worst-case RTT - // of 1000ms, give ourselves an extra 5 seconds to complete the - // handshake. - timeoutSec += 5 - - // Be extremely defensive and ensure that the timeout is in the range - // [5, 60] seconds (e.g. if we accidentally get a negative number). - if timeoutSec > 60 { - timeoutSec = 60 - } else if timeoutSec < 5 { - timeoutSec = 5 - } - - timeout := time.Duration(timeoutSec * float64(time.Second)) - ctx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - clientConn, err := (&controlhttp.Dialer{ - Hostname: nc.host, - HTTPPort: nc.httpPort, - HTTPSPort: cmp.Or(nc.httpsPort, controlhttp.NoPort), - MachineKey: nc.privKey, - ControlKey: nc.serverPubKey, - ProtocolVersion: uint16(tailcfg.CurrentCapabilityVersion), - Dialer: nc.dialer.SystemDial, - DNSCache: nc.dnsCache, - DialPlan: dialPlan, - Logf: nc.logf, - NetMon: nc.netMon, - HealthTracker: nc.health, - Clock: tstime.StdClock{}, - }).Dial(ctx) - if err != nil { - return nil, err - } - - ncc, err := noiseconn.New(clientConn.Conn, nc.h2t, connID, nc.connClosed) - if err != nil { - return nil, err - } - - nc.mu.Lock() - if nc.closed { - nc.mu.Unlock() - ncc.Close() // Needs to be called without holding the lock. - return nil, errors.New("noise client closed") - } - defer nc.mu.Unlock() - mak.Set(&nc.connPool, connID, ncc) - nc.last = ncc - return ncc, nil -} - -// post does a POST to the control server at the given path, JSON-encoding body. -// The provided nodeKey is an optional load balancing hint. -func (nc *NoiseClient) post(ctx context.Context, path string, nodeKey key.NodePublic, body any) (*http.Response, error) { - jbody, err := json.Marshal(body) - if err != nil { - return nil, err - } - req, err := http.NewRequestWithContext(ctx, "POST", "https://"+nc.host+path, bytes.NewReader(jbody)) - if err != nil { - return nil, err - } - addLBHeader(req, nodeKey) - req.Header.Set("Content-Type", "application/json") - - conn, err := nc.getConn(ctx) - if err != nil { - return nil, err - } - return conn.RoundTrip(req) -} diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index 0e3dd038e..439e6d36b 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -13,20 +13,15 @@ import ( "crypto/x509" "errors" "fmt" - "sync" "time" "github.com/tailscale/certstore" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" ) -var getMachineCertificateSubjectOnce struct { - sync.Once - v string // Subject of machine certificate to search for -} - // getMachineCertificateSubject returns the exact name of a Subject that needs // to be present in an identity's certificate chain to sign a RegisterRequest, // formatted as per pkix.Name.String(). The Subject may be that of the identity @@ -36,12 +31,9 @@ var getMachineCertificateSubjectOnce struct { // each RegisterRequest will be unsigned. // // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" -func getMachineCertificateSubject() string { - getMachineCertificateSubjectOnce.Do(func() { - getMachineCertificateSubjectOnce.v, _ = syspolicy.GetString(syspolicy.MachineCertificateSubject, "") - }) - - return getMachineCertificateSubjectOnce.v +func getMachineCertificateSubject(polc policyclient.Client) string { + machineCertSubject, _ := polc.GetString(pkey.MachineCertificateSubject, "") + return machineCertSubject } var ( @@ -145,7 +137,7 @@ func findIdentity(subject string, st certstore.Store) (certstore.Identity, []*x5 // using that identity's public key. In addition to the signature, the full // certificate chain is included so that the control server can validate the // certificate from a copy of the root CA's certificate. -func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) (err error) { +func signRegisterRequest(polc policyclient.Client, req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) (err error) { defer func() { if err != nil { err = fmt.Errorf("signRegisterRequest: %w", err) @@ -156,7 +148,7 @@ func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverP return errBadRequest } - machineCertificateSubject := getMachineCertificateSubject() + machineCertificateSubject := getMachineCertificateSubject(polc) if machineCertificateSubject == "" { return errCertificateNotConfigured } diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index 5e161dcbc..f6c4ddc62 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -8,9 +8,10 @@ package controlclient import ( "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/util/syspolicy/policyclient" ) // signRegisterRequest on non-supported platforms always returns errNoCertStore. -func signRegisterRequest(req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { +func signRegisterRequest(polc policyclient.Client, req *tailcfg.RegisterRequest, serverURL string, serverPubKey, machinePubKey key.MachinePublic) error { return errNoCertStore } diff --git a/control/controlclient/status.go b/control/controlclient/status.go index d0fdf80d7..65afb7a50 100644 --- a/control/controlclient/status.go +++ b/control/controlclient/status.go @@ -4,8 +4,6 @@ package controlclient import ( - "encoding/json" - "fmt" "reflect" "tailscale.com/types/netmap" @@ -13,57 +11,6 @@ import ( "tailscale.com/types/structs" ) -// State is the high-level state of the client. It is used only in -// unit tests for proper sequencing, don't depend on it anywhere else. -// -// TODO(apenwarr): eliminate the state, as it's now obsolete. -// -// apenwarr: Historical note: controlclient.Auto was originally -// intended to be the state machine for the whole tailscale client, but that -// turned out to not be the right abstraction layer, and it moved to -// ipn.Backend. Since ipn.Backend now has a state machine, it would be -// much better if controlclient could be a simple stateless API. But the -// current server-side API (two interlocking polling https calls) makes that -// very hard to implement. A server side API change could untangle this and -// remove all the statefulness. -type State int - -const ( - StateNew = State(iota) - StateNotAuthenticated - StateAuthenticating - StateURLVisitRequired - StateAuthenticated - StateSynchronized // connected and received map update -) - -func (s State) AppendText(b []byte) ([]byte, error) { - return append(b, s.String()...), nil -} - -func (s State) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s State) String() string { - switch s { - case StateNew: - return "state:new" - case StateNotAuthenticated: - return "state:not-authenticated" - case StateAuthenticating: - return "state:authenticating" - case StateURLVisitRequired: - return "state:url-visit-required" - case StateAuthenticated: - return "state:authenticated" - case StateSynchronized: - return "state:synchronized" - default: - return fmt.Sprintf("state:unknown:%d", int(s)) - } -} - type Status struct { _ structs.Incomparable @@ -76,6 +23,14 @@ type Status struct { // URL, if non-empty, is the interactive URL to visit to finish logging in. URL string + // LoggedIn, if true, indicates that serveRegister has completed and no + // other login change is in progress. + LoggedIn bool + + // InMapPoll, if true, indicates that we've received at least one netmap + // and are connected to receive updates. + InMapPoll bool + // NetMap is the latest server-pushed state of the tailnet network. NetMap *netmap.NetworkMap @@ -83,26 +38,8 @@ type Status struct { // // TODO(bradfitz,maisem): clarify this. Persist persist.PersistView - - // state is the internal state. It should not be exposed outside this - // package, but we have some automated tests elsewhere that need to - // use it via the StateForTest accessor. - // TODO(apenwarr): Unexport or remove these. - state State } -// LoginFinished reports whether the controlclient is in its "StateAuthenticated" -// state where it's in a happy register state but not yet in a map poll. -// -// TODO(bradfitz): delete this and everything around Status.state. -func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated } - -// StateForTest returns the internal state of s for tests only. -func (s *Status) StateForTest() State { return s.state } - -// SetStateForTest sets the internal state of s for tests only. -func (s *Status) SetStateForTest(state State) { s.state = state } - // Equal reports whether s and s2 are equal. func (s *Status) Equal(s2 *Status) bool { if s == nil && s2 == nil { @@ -111,15 +48,8 @@ func (s *Status) Equal(s2 *Status) bool { return s != nil && s2 != nil && s.Err == s2.Err && s.URL == s2.URL && - s.state == s2.state && + s.LoggedIn == s2.LoggedIn && + s.InMapPoll == s2.InMapPoll && reflect.DeepEqual(s.Persist, s2.Persist) && reflect.DeepEqual(s.NetMap, s2.NetMap) } - -func (s Status) String() string { - b, err := json.MarshalIndent(s, "", "\t") - if err != nil { - panic(err) - } - return s.state.String() + " " + string(b) -} diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 7e5263e33..06a2131fd 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -20,36 +20,37 @@ package controlhttp import ( + "cmp" "context" "crypto/tls" "encoding/base64" "errors" "fmt" "io" - "math" "net" "net/http" "net/http/httptrace" "net/netip" "net/url" "runtime" - "sort" "sync/atomic" "time" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" - "tailscale.com/net/tshttpproxy" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" - "tailscale.com/util/multierr" ) var stdDialer net.Dialer @@ -80,7 +81,7 @@ func (a *Dialer) getProxyFunc() func(*http.Request) (*url.URL, error) { if a.proxyFunc != nil { return a.proxyFunc } - return tshttpproxy.ProxyFromEnvironment + return feature.HookProxyFromEnvironment.GetOrNil() } // httpsFallbackDelay is how long we'll wait for a.HTTPPort to work before @@ -95,161 +96,78 @@ func (a *Dialer) httpsFallbackDelay() time.Duration { var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { + + a.logPort80Failure.Store(true) + // If we don't have a dial plan, just fall back to dialing the single // host we know about. useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN") if !useDialPlan || a.DialPlan == nil || len(a.DialPlan.Candidates) == 0 { - return a.dialHost(ctx, netip.Addr{}) + return a.dialHost(ctx) } candidates := a.DialPlan.Candidates - // Otherwise, we try dialing per the plan. Store the highest priority - // in the list, so that if we get a connection to one of those - // candidates we can return quickly. - var highestPriority int = math.MinInt - for _, c := range candidates { - if c.Priority > highestPriority { - highestPriority = c.Priority - } - } - - // This context allows us to cancel in-flight connections if we get a - // highest-priority connection before we're all done. + // Create a context to be canceled as we return, so once we get a good connection, + // we can drop all the other ones. ctx, cancel := context.WithCancel(ctx) defer cancel() // Now, for each candidate, kick off a dial in parallel. type dialResult struct { - conn *ClientConn - err error - addr netip.Addr - priority int - } - resultsCh := make(chan dialResult, len(candidates)) - - var pending atomic.Int32 - pending.Store(int32(len(candidates))) - for _, c := range candidates { - go func(ctx context.Context, c tailcfg.ControlIPCandidate) { - var ( - conn *ClientConn - err error - ) - - // Always send results back to our channel. - defer func() { - resultsCh <- dialResult{conn, err, c.IP, c.Priority} - if pending.Add(-1) == 0 { - close(resultsCh) - } - }() - - // If non-zero, wait the configured start timeout - // before we do anything. - if c.DialStartDelaySec > 0 { - a.logf("[v2] controlhttp: waiting %.2f seconds before dialing %q @ %v", c.DialStartDelaySec, a.Hostname, c.IP) - tmr, tmrChannel := a.clock().NewTimer(time.Duration(c.DialStartDelaySec * float64(time.Second))) - defer tmr.Stop() - select { - case <-ctx.Done(): - err = ctx.Err() - return - case <-tmrChannel: - } - } + conn *ClientConn + err error + } + resultsCh := make(chan dialResult) // unbuffered, never closed - // Now, create a sub-context with the given timeout and - // try dialing the provided host. - ctx, cancel := context.WithTimeout(ctx, time.Duration(c.DialTimeoutSec*float64(time.Second))) - defer cancel() + dialCand := func(cand tailcfg.ControlIPCandidate) (*ClientConn, error) { + if cand.ACEHost != "" { + a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q via ACE %s (%s)", cand.DialStartDelaySec, a.Hostname, cand.ACEHost, cmp.Or(cand.IP.String(), "dns")) + } else { + a.logf("[v2] controlhttp: waited %.2f seconds, dialing %q @ %s", cand.DialStartDelaySec, a.Hostname, cand.IP.String()) + } - // This will dial, and the defer above sends it back to our parent. - a.logf("[v2] controlhttp: trying to dial %q @ %v", a.Hostname, c.IP) - conn, err = a.dialHost(ctx, c.IP) - }(ctx, c) + ctx, cancel := context.WithTimeout(ctx, time.Duration(cand.DialTimeoutSec*float64(time.Second))) + defer cancel() + return a.dialHostOpt(ctx, cand.IP, cand.ACEHost) } - var results []dialResult - for res := range resultsCh { - // If we get a response that has the highest priority, we don't - // need to wait for any of the other connections to finish; we - // can just return this connection. - // - // TODO(andrew): we could make this better by keeping track of - // the highest remaining priority dynamically, instead of just - // checking for the highest total - if res.priority == highestPriority && res.conn != nil { - a.logf("[v1] controlhttp: high-priority success dialing %q @ %v from dial plan", a.Hostname, res.addr) - - // Drain the channel and any existing connections in - // the background. + for _, cand := range candidates { + timer := time.AfterFunc(time.Duration(cand.DialStartDelaySec*float64(time.Second)), func() { go func() { - for _, res := range results { - if res.conn != nil { - res.conn.Close() + conn, err := dialCand(cand) + select { + case resultsCh <- dialResult{conn, err}: + if err == nil { + a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, cmp.Or(cand.ACEHost, cand.IP.String())) } - } - for res := range resultsCh { - if res.conn != nil { - res.conn.Close() + case <-ctx.Done(): + if conn != nil { + conn.Close() } } - if a.drainFinished != nil { - close(a.drainFinished) - } }() - return res.conn, nil - } - - // This isn't a highest-priority result, so just store it until - // we're done. - results = append(results, res) + }) + defer timer.Stop() } - // After we finish this function, close any remaining open connections. - defer func() { - for _, result := range results { - // Note: below, we nil out the returned connection (if - // any) in the slice so we don't close it. - if result.conn != nil { - result.conn.Close() + var errs []error + for { + select { + case res := <-resultsCh: + if res.err == nil { + return res.conn, nil } + errs = append(errs, res.err) + if len(errs) == len(candidates) { + // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. + a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", errors.Join(errs...)) + return a.dialHost(ctx) + } + case <-ctx.Done(): + a.logf("controlhttp: context aborted dialing") + return nil, ctx.Err() } - - // We don't drain asynchronously after this point, so notify our - // channel when we return. - if a.drainFinished != nil { - close(a.drainFinished) - } - }() - - // Sort by priority, then take the first non-error response. - sort.Slice(results, func(i, j int) bool { - // NOTE: intentionally inverted so that the highest priority - // item comes first - return results[i].priority > results[j].priority - }) - - var ( - conn *ClientConn - errs []error - ) - for i, result := range results { - if result.err != nil { - errs = append(errs, result.err) - continue - } - - a.logf("[v1] controlhttp: succeeded dialing %q @ %v from dial plan", a.Hostname, result.addr) - conn = result.conn - results[i].conn = nil // so we don't close it in the defer - return conn, nil } - merr := multierr.New(errs...) - - // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. - a.logf("controlhttp: failed dialing using DialPlan, falling back to DNS; errs=%s", merr.Error()) - return a.dialHost(ctx, netip.Addr{}) } // The TS_FORCE_NOISE_443 envknob forces the controlclient noise dialer to @@ -266,6 +184,15 @@ var forceNoise443 = envknob.RegisterBool("TS_FORCE_NOISE_443") // use HTTPS connections as its underlay connection (double crypto). This can // be necessary when networks or middle boxes are messing with port 80. func (d *Dialer) forceNoise443() bool { + if runtime.GOOS == "plan9" { + // For running demos of Plan 9 in a browser with network relays, + // we want to minimize the number of connections we're making. + // The main reason to use port 80 is to avoid double crypto + // costs server-side but the costs are tiny and number of Plan 9 + // users doesn't make it worth it. Just disable this and always use + // HTTPS for Plan 9. That also reduces some log spam. + return true + } if forceNoise443() { return true } @@ -277,7 +204,9 @@ func (d *Dialer) forceNoise443() bool { // This heuristic works around networks where port 80 is MITMed and // appears to work for a bit post-Upgrade but then gets closed, // such as seen in https://github.com/tailscale/tailscale/issues/13597. - d.logf("controlhttp: forcing port 443 dial due to recent noise dial") + if d.logPort80Failure.CompareAndSwap(true, false) { + d.logf("controlhttp: forcing port 443 dial due to recent noise dial") + } return true } @@ -295,10 +224,19 @@ var debugNoiseDial = envknob.RegisterBool("TS_DEBUG_NOISE_DIAL") // dialHost connects to the configured Dialer.Hostname and upgrades the // connection into a controlbase.Conn. +func (a *Dialer) dialHost(ctx context.Context) (*ClientConn, error) { + return a.dialHostOpt(ctx, + netip.Addr{}, // no pre-resolved IP + "", // don't use ACE + ) +} + +// dialHostOpt connects to the configured Dialer.Hostname and upgrades the +// connection into a controlbase.Conn. // // If optAddr is valid, then no DNS is used and the connection will be made to the // provided address. -func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, error) { +func (a *Dialer) dialHostOpt(ctx context.Context, optAddr netip.Addr, optACEHost string) (*ClientConn, error) { // Create one shared context used by both port 80 and port 443 dials. // If port 80 is still in flight when 443 returns, this deferred cancel // will stop the port 80 dial. @@ -320,7 +258,7 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, Host: net.JoinHostPort(a.Hostname, strDef(a.HTTPSPort, "443")), Path: serverUpgradePath, } - if a.HTTPSPort == NoPort { + if a.HTTPSPort == NoPort || optACEHost != "" { u443 = nil } @@ -332,11 +270,11 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, ch := make(chan tryURLRes) // must be unbuffered try := func(u *url.URL) { if debugNoiseDial() { - a.logf("trying noise dial (%v, %v) ...", u, optAddr) + a.logf("trying noise dial (%v, %v) ...", u, cmp.Or(optACEHost, optAddr.String())) } - cbConn, err := a.dialURL(ctx, u, optAddr) + cbConn, err := a.dialURL(ctx, u, optAddr, optACEHost) if debugNoiseDial() { - a.logf("noise dial (%v, %v) = (%v, %v)", u, optAddr, cbConn, err) + a.logf("noise dial (%v, %v) = (%v, %v)", u, cmp.Or(optACEHost, optAddr.String()), cbConn, err) } select { case ch <- tryURLRes{u, cbConn, err}: @@ -367,6 +305,9 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, } var err80, err443 error + if forceTLS { + err80 = errors.New("TLS forced: no port 80 dialed") + } for { select { case <-ctx.Done(): @@ -402,12 +343,12 @@ func (a *Dialer) dialHost(ctx context.Context, optAddr netip.Addr) (*ClientConn, // // If optAddr is valid, then no DNS is used and the connection will be made to the // provided address. -func (a *Dialer) dialURL(ctx context.Context, u *url.URL, optAddr netip.Addr) (*ClientConn, error) { +func (a *Dialer) dialURL(ctx context.Context, u *url.URL, optAddr netip.Addr, optACEHost string) (*ClientConn, error) { init, cont, err := controlbase.ClientDeferred(a.MachineKey, a.ControlKey, a.ProtocolVersion) if err != nil { return nil, err } - netConn, err := a.tryURLUpgrade(ctx, u, optAddr, init) + netConn, err := a.tryURLUpgrade(ctx, u, optAddr, optACEHost, init) if err != nil { return nil, err } @@ -453,13 +394,15 @@ var macOSScreenTime = health.Register(&health.Warnable{ ImpactsConnectivity: true, }) +var HookMakeACEDialer feature.Hook[func(dialer netx.DialFunc, aceHost string, optIP netip.Addr) netx.DialFunc] + // tryURLUpgrade connects to u, and tries to upgrade it to a net.Conn. // // If optAddr is valid, then no DNS is used and the connection will be made to // the provided address. // // Only the provided ctx is used, not a.ctx. -func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Addr, init []byte) (_ net.Conn, retErr error) { +func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Addr, optACEHost string, init []byte) (_ net.Conn, retErr error) { var dns *dnscache.Resolver // If we were provided an address to dial, then create a resolver that just @@ -474,13 +417,24 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad dns = a.resolver() } - var dialer dnscache.DialContextFunc + var dialer netx.DialFunc if a.Dialer != nil { dialer = a.Dialer } else { dialer = stdDialer.DialContext } + if optACEHost != "" { + if !buildfeatures.HasACE { + return nil, feature.ErrUnavailable + } + f, ok := HookMakeACEDialer.GetOk() + if !ok { + return nil, feature.ErrUnavailable + } + dialer = f(dialer, optACEHost, optAddr) + } + // On macOS, see if Screen Time is blocking things. if runtime.GOOS == "darwin" { var proxydIntercepted atomic.Bool // intercepted by macOS webfilterproxyd @@ -507,13 +461,25 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad tr := http.DefaultTransport.(*http.Transport).Clone() defer tr.CloseIdleConnections() - tr.Proxy = a.getProxyFunc() - tshttpproxy.SetTransportGetProxyConnectHeader(tr) - tr.DialContext = dnscache.Dialer(dialer, dns) + if optACEHost != "" { + // If using ACE, we don't want to use any HTTP proxy. + // ACE is already a tunnel+proxy. + // TODO(tailscale/corp#32483): use system proxy too? + tr.Proxy = nil + tr.DialContext = dialer + } else { + if buildfeatures.HasUseProxy { + tr.Proxy = a.getProxyFunc() + if set, ok := feature.HookProxySetTransportGetProxyConnectHeader.GetOk(); ok { + set(tr) + } + } + tr.DialContext = dnscache.Dialer(dialer, dns) + } // Disable HTTP2, since h2 can't do protocol switching. tr.TLSClientConfig.NextProtos = []string{} tr.TLSNextProto = map[string]func(string, *tls.Conn) http.RoundTripper{} - tr.TLSClientConfig = tlsdial.Config(a.Hostname, a.HealthTracker, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(a.HealthTracker, tr.TLSClientConfig) if !tr.TLSClientConfig.InsecureSkipVerify { panic("unexpected") // should be set by tlsdial.Config } @@ -571,9 +537,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad Method: "POST", URL: u, Header: http.Header{ - "Upgrade": []string{upgradeHeaderValue}, - "Connection": []string{"upgrade"}, - handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + "Upgrade": []string{controlhttpcommon.UpgradeHeaderValue}, + "Connection": []string{"upgrade"}, + controlhttpcommon.HandshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }, } req = req.WithContext(ctx) @@ -597,7 +563,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad return nil, fmt.Errorf("httptrace didn't provide a connection") } - if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue { + if next := resp.Header.Get("Upgrade"); next != controlhttpcommon.UpgradeHeaderValue { resp.Body.Close() return nil, fmt.Errorf("server switched to unexpected protocol %q", next) } diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index 4b7126b52..cc05b5b19 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -12,6 +12,7 @@ import ( "github.com/coder/websocket" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/net/wsconn" ) @@ -42,11 +43,11 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { // Can't set HTTP headers on the websocket request, so we have to to send // the handshake via an HTTP header. RawQuery: url.Values{ - handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + controlhttpcommon.HandshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }.Encode(), } wsConn, _, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{ - Subprotocols: []string{upgradeHeaderValue}, + Subprotocols: []string{controlhttpcommon.UpgradeHeaderValue}, }) if err != nil { return nil, err diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index ea1725e76..359410ae9 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -6,11 +6,13 @@ package controlhttp import ( "net/http" "net/url" + "sync/atomic" "time" "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -18,15 +20,6 @@ import ( ) const ( - // upgradeHeader is the value of the Upgrade HTTP header used to - // indicate the Tailscale control protocol. - upgradeHeaderValue = "tailscale-control-protocol" - - // handshakeHeaderName is the HTTP request header that can - // optionally contain base64-encoded initial handshake - // payload, to save an RTT. - handshakeHeaderName = "X-Tailscale-Handshake" - // serverUpgradePath is where the server-side HTTP handler to // to do the protocol switch is located. serverUpgradePath = "/ts2021" @@ -74,7 +67,7 @@ type Dialer struct { // Dialer is the dialer used to make outbound connections. // // If not specified, this defaults to net.Dialer.DialContext. - Dialer dnscache.DialContextFunc + Dialer netx.DialFunc // DNSCache is the caching Resolver used by this Dialer. // @@ -85,6 +78,8 @@ type Dialer struct { // dropped. Logf logger.Logf + // NetMon is the [netmon.Monitor] to use for this Dialer. + // It is optional. NetMon *netmon.Monitor // HealthTracker, if non-nil, is the health tracker to use. @@ -97,8 +92,12 @@ type Dialer struct { proxyFunc func(*http.Request) (*url.URL, error) // or nil + // logPort80Failure is whether we should log about port 80 interceptions + // and forcing a port 443 dial. We do this only once per "dial" method + // which can result in many concurrent racing dialHost calls. + logPort80Failure atomic.Bool + // For tests only - drainFinished chan struct{} omitCertErrorLogging bool testFallbackDelay time.Duration diff --git a/control/controlhttp/controlhttpcommon/controlhttpcommon.go b/control/controlhttp/controlhttpcommon/controlhttpcommon.go new file mode 100644 index 000000000..a86b7ca04 --- /dev/null +++ b/control/controlhttp/controlhttpcommon/controlhttpcommon.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package controlhttpcommon contains common constants for used +// by the controlhttp client and controlhttpserver packages. +package controlhttpcommon + +// UpgradeHeader is the value of the Upgrade HTTP header used to +// indicate the Tailscale control protocol. +const UpgradeHeaderValue = "tailscale-control-protocol" + +// handshakeHeaderName is the HTTP request header that can +// optionally contain base64-encoded initial handshake +// payload, to save an RTT. +const HandshakeHeaderName = "X-Tailscale-Handshake" diff --git a/control/controlhttp/server.go b/control/controlhttp/controlhttpserver/controlhttpserver.go similarity index 91% rename from control/controlhttp/server.go rename to control/controlhttp/controlhttpserver/controlhttpserver.go index 6a0d2bc56..af3207810 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/controlhttpserver/controlhttpserver.go @@ -1,7 +1,10 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package controlhttp +//go:build !ios + +// Package controlhttpserver contains the HTTP server side of the ts2021 control protocol. +package controlhttpserver import ( "context" @@ -16,6 +19,7 @@ import ( "github.com/coder/websocket" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/net/netutil" "tailscale.com/net/wsconn" "tailscale.com/types/key" @@ -43,12 +47,12 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri if next == "websocket" { return acceptWebsocket(ctx, w, r, private) } - if next != upgradeHeaderValue { + if next != controlhttpcommon.UpgradeHeaderValue { http.Error(w, "unknown next protocol", http.StatusBadRequest) return nil, fmt.Errorf("client requested unhandled next protocol %q", next) } - initB64 := r.Header.Get(handshakeHeaderName) + initB64 := r.Header.Get(controlhttpcommon.HandshakeHeaderName) if initB64 == "" { http.Error(w, "missing Tailscale handshake header", http.StatusBadRequest) return nil, errors.New("no tailscale handshake header in HTTP request") @@ -65,7 +69,7 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nil, errors.New("can't hijack client connection") } - w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) @@ -115,7 +119,7 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri // speak HTTP) to a Tailscale control protocol base transport connection. func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{upgradeHeaderValue}, + Subprotocols: []string{controlhttpcommon.UpgradeHeaderValue}, OriginPatterns: []string{"*"}, // Disable compression because we transmit Noise messages that are not // compressible. @@ -127,7 +131,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request if err != nil { return nil, fmt.Errorf("Could not accept WebSocket connection %v", err) } - if c.Subprotocol() != upgradeHeaderValue { + if c.Subprotocol() != controlhttpcommon.UpgradeHeaderValue { c.Close(websocket.StatusPolicyViolation, "client must speak the control subprotocol") return nil, fmt.Errorf("Unexpected subprotocol %q", c.Subprotocol()) } @@ -135,7 +139,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request c.Close(websocket.StatusPolicyViolation, "Could not parse parameters") return nil, fmt.Errorf("parse query parameters: %v", err) } - initB64 := r.Form.Get(handshakeHeaderName) + initB64 := r.Form.Get(controlhttpcommon.HandshakeHeaderName) if initB64 == "" { c.Close(websocket.StatusPolicyViolation, "missing Tailscale handshake parameter") return nil, errors.New("no tailscale handshake parameter in HTTP request") diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 8c8ed7f57..648b9e5ed 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -15,15 +15,19 @@ import ( "net/http/httputil" "net/netip" "net/url" - "runtime" "slices" "strconv" + "strings" "sync" "testing" + "testing/synctest" "time" "tailscale.com/control/controlbase" - "tailscale.com/net/dnscache" + "tailscale.com/control/controlhttp/controlhttpcommon" + "tailscale.com/control/controlhttp/controlhttpserver" + "tailscale.com/health" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" "tailscale.com/net/socks5" "tailscale.com/net/tsdial" @@ -32,6 +36,8 @@ import ( "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/must" ) type httpTestParam struct { @@ -143,6 +149,8 @@ func testControlHTTP(t *testing.T, param httpTestParam) { proxy := param.proxy client, server := key.NewMachine(), key.NewMachine() + bus := eventbustest.NewBus(t) + const testProtocolVersion = 1 const earlyWriteMsg = "Hello, world!" sch := make(chan serverResult, 1) @@ -158,7 +166,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { return err } } - conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) if err != nil { log.Print(err) } @@ -212,6 +220,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { netMon := netmon.NewStatic() dialer := tsdial.NewDialer(netMon) + dialer.SetBus(bus) a := &Dialer{ Hostname: "localhost", HTTPPort: strconv.Itoa(httpLn.Addr().(*net.TCPAddr).Port), @@ -225,6 +234,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { omitCertErrorLogging: true, testFallbackDelay: fallbackDelay, Clock: clock, + HealthTracker: health.NewTracker(eventbustest.NewBus(t)), } if param.httpInDial { @@ -527,9 +537,31 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== } } +// slowListener wraps a memnet listener to delay accept operations +type slowListener struct { + net.Listener + delay time.Duration +} + +func (sl *slowListener) Accept() (net.Conn, error) { + // Add delay before accepting connections + timer := time.NewTimer(sl.delay) + defer timer.Stop() + <-timer.C + + return sl.Listener.Accept() +} + +func newSlowListener(inner net.Listener, delay time.Duration) net.Listener { + return &slowListener{ + Listener: inner, + delay: delay, + } +} + func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) w.(http.Flusher).Flush() @@ -540,33 +572,102 @@ func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { } func TestDialPlan(t *testing.T) { - if runtime.GOOS != "linux" { - t.Skip("only works on Linux due to multiple localhost addresses") + testCases := []struct { + name string + plan *tailcfg.ControlDialPlan + want []netip.Addr + allowFallback bool + maxDuration time.Duration + }{ + { + name: "single", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + }, + { + name: "broken-then-good", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10, DialStartDelaySec: 1}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + }, + { + name: "multiple-candidates-with-broken", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + // Multiple good IPs plus a broken one + // Should succeed with any of the good ones + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.4"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.3"), DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2"), netip.MustParseAddr("10.0.0.4"), netip.MustParseAddr("10.0.0.3")}, + }, + { + name: "multiple-candidates-race", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.3"), DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.3"), netip.MustParseAddr("10.0.0.2")}, + }, + { + name: "fallback", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.10"), DialTimeoutSec: 1}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + allowFallback: true, + }, + { + // In tailscale/corp#32534 we discovered that a prior implementation + // of the dial race was waiting for all dials to complete when the + // top priority dial was failing. This delay was long enough that in + // real scenarios the server will close the connection due to + // inactivity, because the client does not send the first inside of + // noise request soon enough. This test is a regression guard + // against that behavior - proving that the dial returns promptly + // even if there is some cause of a slow race. + name: "slow-endpoint-doesnt-block", + plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ + {IP: netip.MustParseAddr("10.0.0.12"), Priority: 5, DialTimeoutSec: 10}, + {IP: netip.MustParseAddr("10.0.0.2"), Priority: 1, DialTimeoutSec: 10}, + }}, + want: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + maxDuration: 2 * time.Second, // Must complete quickly, not wait for slow endpoint + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + runDialPlanTest(t, tt.plan, tt.want, tt.allowFallback, tt.maxDuration) + }) + }) } +} +func runDialPlanTest(t *testing.T, plan *tailcfg.ControlDialPlan, want []netip.Addr, allowFallback bool, maxDuration time.Duration) { client, server := key.NewMachine(), key.NewMachine() const ( testProtocolVersion = 1 + httpPort = "80" + httpsPort = "443" ) - getRandomPort := func() string { - ln, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("net.Listen: %v", err) - } - defer ln.Close() - _, port, err := net.SplitHostPort(ln.Addr().String()) - if err != nil { - t.Fatal(err) - } - return port - } + memNetwork := &memnet.Network{} - // We need consistent ports for each address; these are chosen - // randomly and we hope that they won't conflict during this test. - httpPort := getRandomPort() - httpsPort := getRandomPort() + fallbackAddr := netip.MustParseAddr("10.0.0.1") + goodAddr := netip.MustParseAddr("10.0.0.2") + otherAddr := netip.MustParseAddr("10.0.0.3") + other2Addr := netip.MustParseAddr("10.0.0.4") + brokenAddr := netip.MustParseAddr("10.0.0.10") + slowAddr := netip.MustParseAddr("10.0.0.12") makeHandler := func(t *testing.T, name string, host netip.Addr, wrap func(http.Handler) http.Handler) { done := make(chan struct{}) @@ -574,7 +675,7 @@ func TestDialPlan(t *testing.T) { close(done) }) var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, nil) + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, nil) if err != nil { log.Print(err) } else { @@ -587,17 +688,66 @@ func TestDialPlan(t *testing.T) { handler = wrap(handler) } - httpLn, err := net.Listen("tcp", host.String()+":"+httpPort) + httpLn := must.Get(memNetwork.Listen("tcp", host.String()+":"+httpPort)) + httpsLn := must.Get(memNetwork.Listen("tcp", host.String()+":"+httpsPort)) + + httpServer := &http.Server{Handler: handler} + go httpServer.Serve(httpLn) + t.Cleanup(func() { + httpServer.Close() + }) + + httpsServer := &http.Server{ + Handler: handler, + TLSConfig: tlsConfig(t), + ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")), + } + go httpsServer.ServeTLS(httpsLn, "", "") + t.Cleanup(func() { + httpsServer.Close() + }) + } + + // Use synctest's controlled time + clock := tstime.StdClock{} + makeHandler(t, "fallback", fallbackAddr, nil) + makeHandler(t, "good", goodAddr, nil) + makeHandler(t, "other", otherAddr, nil) + makeHandler(t, "other2", other2Addr, nil) + makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { + return brokenMITMHandler(clock) + }) + // Create slow listener that delays accept by 5 seconds + makeSlowHandler := func(t *testing.T, name string, host netip.Addr, delay time.Duration) { + done := make(chan struct{}) + t.Cleanup(func() { + close(done) + }) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, nil) + if err != nil { + log.Print(err) + } else { + defer conn.Close() + } + w.Header().Set("X-Handler-Name", name) + <-done + }) + + httpLn, err := memNetwork.Listen("tcp", host.String()+":"+httpPort) if err != nil { t.Fatalf("HTTP listen: %v", err) } - httpsLn, err := net.Listen("tcp", host.String()+":"+httpsPort) + httpsLn, err := memNetwork.Listen("tcp", host.String()+":"+httpsPort) if err != nil { t.Fatalf("HTTPS listen: %v", err) } + slowHttpLn := newSlowListener(httpLn, delay) + slowHttpsLn := newSlowListener(httpsLn, delay) + httpServer := &http.Server{Handler: handler} - go httpServer.Serve(httpLn) + go httpServer.Serve(slowHttpLn) t.Cleanup(func() { httpServer.Close() }) @@ -607,212 +757,148 @@ func TestDialPlan(t *testing.T) { TLSConfig: tlsConfig(t), ErrorLog: logger.StdLogger(logger.WithPrefix(t.Logf, "http.Server.ErrorLog: ")), } - go httpsServer.ServeTLS(httpsLn, "", "") + go httpsServer.ServeTLS(slowHttpsLn, "", "") t.Cleanup(func() { httpsServer.Close() }) - return } + makeSlowHandler(t, "slow", slowAddr, 5*time.Second) - fallbackAddr := netip.MustParseAddr("127.0.0.1") - goodAddr := netip.MustParseAddr("127.0.0.2") - otherAddr := netip.MustParseAddr("127.0.0.3") - other2Addr := netip.MustParseAddr("127.0.0.4") - brokenAddr := netip.MustParseAddr("127.0.0.10") + // memnetDialer with connection tracking, so we can catch connection leaks. + dialer := &memnetDialer{ + inner: memNetwork.Dial, + t: t, + } + defer dialer.waitForAllClosedSynctest() - testCases := []struct { - name string - plan *tailcfg.ControlDialPlan - wrap func(http.Handler) http.Handler - want netip.Addr + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - allowFallback bool - }{ - { - name: "single", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - {IP: goodAddr, Priority: 1, DialTimeoutSec: 10}, - }}, - want: goodAddr, - }, - { - name: "broken-then-good", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - // Dials the broken one, which fails, and then - // eventually dials the good one and succeeds - {IP: brokenAddr, Priority: 2, DialTimeoutSec: 10}, - {IP: goodAddr, Priority: 1, DialTimeoutSec: 10, DialStartDelaySec: 1}, - }}, - want: goodAddr, - }, - // TODO(#8442): fix this test - // { - // name: "multiple-priority-fast-path", - // plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - // // Dials some good IPs and our bad one (which - // // hangs forever), which then hits the fast - // // path where we bail without waiting. - // {IP: brokenAddr, Priority: 1, DialTimeoutSec: 10}, - // {IP: goodAddr, Priority: 1, DialTimeoutSec: 10}, - // {IP: other2Addr, Priority: 1, DialTimeoutSec: 10}, - // {IP: otherAddr, Priority: 2, DialTimeoutSec: 10}, - // }}, - // want: otherAddr, - // }, - { - name: "multiple-priority-slow-path", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - // Our broken address is the highest priority, - // so we don't hit our fast path. - {IP: brokenAddr, Priority: 10, DialTimeoutSec: 10}, - {IP: otherAddr, Priority: 2, DialTimeoutSec: 10}, - {IP: goodAddr, Priority: 1, DialTimeoutSec: 10}, - }}, - want: otherAddr, - }, - { - name: "fallback", - plan: &tailcfg.ControlDialPlan{Candidates: []tailcfg.ControlIPCandidate{ - {IP: brokenAddr, Priority: 1, DialTimeoutSec: 1}, - }}, - want: fallbackAddr, - allowFallback: true, - }, + host := "example.com" + if allowFallback { + host = fallbackAddr.String() + } + bus := eventbustest.NewBus(t) + a := &Dialer{ + Hostname: host, + HTTPPort: httpPort, + HTTPSPort: httpsPort, + MachineKey: client, + ControlKey: server.Public(), + ProtocolVersion: testProtocolVersion, + Dialer: dialer.Dial, + Logf: t.Logf, + DialPlan: plan, + proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, + omitCertErrorLogging: true, + testFallbackDelay: 50 * time.Millisecond, + Clock: clock, + HealthTracker: health.NewTracker(bus), } - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - // TODO(awly): replace this with tstest.NewClock and update the - // test to advance the clock correctly. - clock := tstime.StdClock{} - makeHandler(t, "fallback", fallbackAddr, nil) - makeHandler(t, "good", goodAddr, nil) - makeHandler(t, "other", otherAddr, nil) - makeHandler(t, "other2", other2Addr, nil) - makeHandler(t, "broken", brokenAddr, func(h http.Handler) http.Handler { - return brokenMITMHandler(clock) - }) - - dialer := closeTrackDialer{ - t: t, - inner: tsdial.NewDialer(netmon.NewStatic()).SystemDial, - conns: make(map[*closeTrackConn]bool), - } - defer dialer.Done() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + start := time.Now() + conn, err := a.dial(ctx) + duration := time.Since(start) - // By default, we intentionally point to something that - // we know won't connect, since we want a fallback to - // DNS to be an error. - host := "example.com" - if tt.allowFallback { - host = "localhost" - } + if err != nil { + t.Fatalf("dialing controlhttp: %v", err) + } + defer conn.Close() - drained := make(chan struct{}) - a := &Dialer{ - Hostname: host, - HTTPPort: httpPort, - HTTPSPort: httpsPort, - MachineKey: client, - ControlKey: server.Public(), - ProtocolVersion: testProtocolVersion, - Dialer: dialer.Dial, - Logf: t.Logf, - DialPlan: tt.plan, - proxyFunc: func(*http.Request) (*url.URL, error) { return nil, nil }, - drainFinished: drained, - omitCertErrorLogging: true, - testFallbackDelay: 50 * time.Millisecond, - Clock: clock, - } + if maxDuration > 0 && duration > maxDuration { + t.Errorf("dial took %v, expected < %v (should not wait for slow endpoints)", duration, maxDuration) + } - conn, err := a.dial(ctx) - if err != nil { - t.Fatalf("dialing controlhttp: %v", err) - } - defer conn.Close() + raddr := conn.RemoteAddr() + raddrStr := raddr.String() - raddr := conn.RemoteAddr().(*net.TCPAddr) + // split on "|" first to remove memnet pipe suffix + addrPart := raddrStr + if idx := strings.Index(raddrStr, "|"); idx >= 0 { + addrPart = raddrStr[:idx] + } - got, ok := netip.AddrFromSlice(raddr.IP) - if !ok { - t.Errorf("invalid remote IP: %v", raddr.IP) - } else if got != tt.want { - t.Errorf("got connection from %q; want %q", got, tt.want) - } else { - t.Logf("successfully connected to %q", raddr.String()) - } + host, _, err2 := net.SplitHostPort(addrPart) + if err2 != nil { + t.Fatalf("failed to parse remote address %q: %v", addrPart, err2) + } - // Wait until our dialer drains so we can verify that - // all connections are closed. - <-drained - }) + got, err3 := netip.ParseAddr(host) + if err3 != nil { + t.Errorf("invalid remote IP: %v", host) + } else { + found := slices.Contains(want, got) + if !found { + t.Errorf("got connection from %q; want one of %v", got, want) + } else { + t.Logf("successfully connected to %q", raddr.String()) + } } } -type closeTrackDialer struct { - t testing.TB - inner dnscache.DialContextFunc +// memnetDialer wraps memnet.Network.Dial to track connections for testing +type memnetDialer struct { + inner func(ctx context.Context, network, addr string) (net.Conn, error) + t *testing.T mu sync.Mutex - conns map[*closeTrackConn]bool + conns map[net.Conn]string // conn -> remote address for debugging } -func (d *closeTrackDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := d.inner(ctx, network, addr) +func (d *memnetDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := d.inner(ctx, network, addr) if err != nil { return nil, err } - ct := &closeTrackConn{Conn: c, d: d} d.mu.Lock() - d.conns[ct] = true + if d.conns == nil { + d.conns = make(map[net.Conn]string) + } + d.conns[conn] = conn.RemoteAddr().String() + d.t.Logf("tracked connection opened to %s", conn.RemoteAddr()) d.mu.Unlock() - return ct, nil + + return &memnetTrackedConn{Conn: conn, dialer: d}, nil } -func (d *closeTrackDialer) Done() { - // Unfortunately, tsdial.Dialer.SystemDial closes connections - // asynchronously in a goroutine, so we can't assume that everything is - // closed by the time we get here. - // - // Sleep/wait a few times on the assumption that things will close - // "eventually". - const iters = 100 - for i := range iters { +func (d *memnetDialer) waitForAllClosedSynctest() { + const maxWait = 15 * time.Second + const checkInterval = 100 * time.Millisecond + + for range int(maxWait / checkInterval) { d.mu.Lock() - if len(d.conns) == 0 { + remaining := len(d.conns) + if remaining == 0 { d.mu.Unlock() return } + d.mu.Unlock() - // Only error on last iteration - if i != iters-1 { - d.mu.Unlock() - time.Sleep(100 * time.Millisecond) - continue - } + time.Sleep(checkInterval) + } - for conn := range d.conns { - d.t.Errorf("expected close of conn %p; RemoteAddr=%q", conn, conn.RemoteAddr().String()) - } - d.mu.Unlock() + d.mu.Lock() + defer d.mu.Unlock() + for _, addr := range d.conns { + d.t.Errorf("connection to %s was not closed after %v", addr, maxWait) } } -func (d *closeTrackDialer) noteClose(c *closeTrackConn) { +func (d *memnetDialer) noteClose(conn net.Conn) { d.mu.Lock() - delete(d.conns, c) // safe if already deleted + if addr, exists := d.conns[conn]; exists { + d.t.Logf("tracked connection closed to %s", addr) + delete(d.conns, conn) + } d.mu.Unlock() } -type closeTrackConn struct { +type memnetTrackedConn struct { net.Conn - d *closeTrackDialer + dialer *memnetDialer } -func (c *closeTrackConn) Close() error { - c.d.noteClose(c) +func (c *memnetTrackedConn) Close() error { + c.dialer.noteClose(c.Conn) return c.Conn.Close() } diff --git a/control/controlknobs/controlknobs.go b/control/controlknobs/controlknobs.go index dd76a3abd..09c16b8b1 100644 --- a/control/controlknobs/controlknobs.go +++ b/control/controlknobs/controlknobs.go @@ -6,6 +6,8 @@ package controlknobs import ( + "fmt" + "reflect" "sync/atomic" "tailscale.com/syncs" @@ -60,8 +62,9 @@ type Knobs struct { // netfiltering, unless overridden by the user. LinuxForceNfTables atomic.Bool - // SeamlessKeyRenewal is whether to enable the alpha functionality of - // renewing node keys without breaking connections. + // SeamlessKeyRenewal is whether to renew node keys without breaking connections. + // This is enabled by default in 1.90 and later, but we but we can remotely disable + // it from the control plane if there's a problem. // http://go/seamless-key-renewal SeamlessKeyRenewal atomic.Bool @@ -96,13 +99,14 @@ type Knobs struct { // allows us to disable the new behavior remotely if needed. DisableLocalDNSOverrideViaNRPT atomic.Bool - // DisableCryptorouting indicates that the node should not use the - // magicsock crypto routing feature. - DisableCryptorouting atomic.Bool - // DisableCaptivePortalDetection is whether the node should not perform captive portal detection // automatically when the network state changes. DisableCaptivePortalDetection atomic.Bool + + // DisableSkipStatusQueue is whether the node should disable skipping + // of queued netmap.NetworkMap between the controlclient and LocalBackend. + // See tailscale/tailscale#14768. + DisableSkipStatusQueue atomic.Bool } // UpdateFromNodeAttributes updates k (if non-nil) based on the provided self @@ -125,13 +129,14 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { forceIPTables = has(tailcfg.NodeAttrLinuxMustUseIPTables) forceNfTables = has(tailcfg.NodeAttrLinuxMustUseNfTables) seamlessKeyRenewal = has(tailcfg.NodeAttrSeamlessKeyRenewal) + disableSeamlessKeyRenewal = has(tailcfg.NodeAttrDisableSeamlessKeyRenewal) probeUDPLifetime = has(tailcfg.NodeAttrProbeUDPLifetime) appCStoreRoutes = has(tailcfg.NodeAttrStoreAppCRoutes) userDialUseRoutes = has(tailcfg.NodeAttrUserDialUseRoutes) disableSplitDNSWhenNoCustomResolvers = has(tailcfg.NodeAttrDisableSplitDNSWhenNoCustomResolvers) disableLocalDNSOverrideViaNRPT = has(tailcfg.NodeAttrDisableLocalDNSOverrideViaNRPT) - disableCryptorouting = has(tailcfg.NodeAttrDisableMagicSockCryptoRouting) disableCaptivePortalDetection = has(tailcfg.NodeAttrDisableCaptivePortalDetection) + disableSkipStatusQueue = has(tailcfg.NodeAttrDisableSkipStatusQueue) ) if has(tailcfg.NodeAttrOneCGNATEnable) { @@ -151,14 +156,28 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { k.SilentDisco.Store(silentDisco) k.LinuxForceIPTables.Store(forceIPTables) k.LinuxForceNfTables.Store(forceNfTables) - k.SeamlessKeyRenewal.Store(seamlessKeyRenewal) k.ProbeUDPLifetime.Store(probeUDPLifetime) k.AppCStoreRoutes.Store(appCStoreRoutes) k.UserDialUseRoutes.Store(userDialUseRoutes) k.DisableSplitDNSWhenNoCustomResolvers.Store(disableSplitDNSWhenNoCustomResolvers) k.DisableLocalDNSOverrideViaNRPT.Store(disableLocalDNSOverrideViaNRPT) - k.DisableCryptorouting.Store(disableCryptorouting) k.DisableCaptivePortalDetection.Store(disableCaptivePortalDetection) + k.DisableSkipStatusQueue.Store(disableSkipStatusQueue) + + // If both attributes are present, then "enable" should win. This reflects + // the history of seamless key renewal. + // + // Before 1.90, seamless was a private alpha, opt-in feature. Devices would + // only seamless do if customers opted in using the seamless renewal attr. + // + // In 1.90 and later, seamless is the default behaviour, and devices will use + // seamless unless explicitly told not to by control (e.g. if we discover + // a bug and want clients to use the prior behaviour). + // + // If a customer has opted in to the pre-1.90 seamless implementation, we + // don't want to switch it off for them -- we only want to switch it off for + // devices that haven't opted in. + k.SeamlessKeyRenewal.Store(seamlessKeyRenewal || !disableSeamlessKeyRenewal) } // AsDebugJSON returns k as something that can be marshalled with json.Marshal @@ -167,25 +186,19 @@ func (k *Knobs) AsDebugJSON() map[string]any { if k == nil { return nil } - return map[string]any{ - "DisableUPnP": k.DisableUPnP.Load(), - "KeepFullWGConfig": k.KeepFullWGConfig.Load(), - "RandomizeClientPort": k.RandomizeClientPort.Load(), - "OneCGNAT": k.OneCGNAT.Load(), - "ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(), - "DisableDeltaUpdates": k.DisableDeltaUpdates.Load(), - "PeerMTUEnable": k.PeerMTUEnable.Load(), - "DisableDNSForwarderTCPRetries": k.DisableDNSForwarderTCPRetries.Load(), - "SilentDisco": k.SilentDisco.Load(), - "LinuxForceIPTables": k.LinuxForceIPTables.Load(), - "LinuxForceNfTables": k.LinuxForceNfTables.Load(), - "SeamlessKeyRenewal": k.SeamlessKeyRenewal.Load(), - "ProbeUDPLifetime": k.ProbeUDPLifetime.Load(), - "AppCStoreRoutes": k.AppCStoreRoutes.Load(), - "UserDialUseRoutes": k.UserDialUseRoutes.Load(), - "DisableSplitDNSWhenNoCustomResolvers": k.DisableSplitDNSWhenNoCustomResolvers.Load(), - "DisableLocalDNSOverrideViaNRPT": k.DisableLocalDNSOverrideViaNRPT.Load(), - "DisableCryptorouting": k.DisableCryptorouting.Load(), - "DisableCaptivePortalDetection": k.DisableCaptivePortalDetection.Load(), + ret := map[string]any{} + rt := reflect.TypeFor[Knobs]() + rv := reflect.ValueOf(k).Elem() // of *k + for i := 0; i < rt.NumField(); i++ { + name := rt.Field(i).Name + switch v := rv.Field(i).Addr().Interface().(type) { + case *atomic.Bool: + ret[name] = v.Load() + case *syncs.AtomicValue[opt.Bool]: + ret[name] = v.Load() + default: + panic(fmt.Sprintf("unknown field type %T for %v", v, name)) + } } + return ret } diff --git a/control/controlknobs/controlknobs_test.go b/control/controlknobs/controlknobs_test.go index a78a486f3..7618b7121 100644 --- a/control/controlknobs/controlknobs_test.go +++ b/control/controlknobs/controlknobs_test.go @@ -6,6 +6,8 @@ package controlknobs import ( "reflect" "testing" + + "tailscale.com/types/logger" ) func TestAsDebugJSON(t *testing.T) { @@ -18,4 +20,5 @@ func TestAsDebugJSON(t *testing.T) { if want := reflect.TypeFor[Knobs]().NumField(); len(got) != want { t.Errorf("AsDebugJSON map has %d fields; want %v", len(got), want) } + t.Logf("Got: %v", logger.AsJSON(got)) } diff --git a/control/ts2021/client.go b/control/ts2021/client.go new file mode 100644 index 000000000..ca10b1d1b --- /dev/null +++ b/control/ts2021/client.go @@ -0,0 +1,312 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ts2021 + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "fmt" + "log" + "math" + "net" + "net/http" + "net/netip" + "net/url" + "sync" + "time" + + "tailscale.com/control/controlhttp" + "tailscale.com/health" + "tailscale.com/net/dnscache" + "tailscale.com/net/netmon" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +// Client provides a http.Client to connect to tailcontrol over +// the ts2021 protocol. +type Client struct { + // Client is an HTTP client to talk to the coordination server. + // It automatically makes a new Noise connection as needed. + *http.Client + + logf logger.Logf // non-nil + opts ClientOpts + host string // the host part of serverURL + httpPort string // the default port to dial + httpsPort string // the fallback Noise-over-https port or empty if none + + // mu protects the following + mu sync.Mutex + closed bool + connPool set.HandleSet[*Conn] // all live connections +} + +// ClientOpts contains options for the [NewClient] function. All fields are +// required unless otherwise specified. +type ClientOpts struct { + // ServerURL is the URL of the server to connect to. + ServerURL string + + // PrivKey is this node's private key. + PrivKey key.MachinePrivate + + // ServerPubKey is the public key of the server. + // It is of the form https://: (no trailing slash). + ServerPubKey key.MachinePublic + + // Dialer's SystemDial function is used to connect to the server. + Dialer *tsdial.Dialer + + // Optional fields follow + + // Logf is the log function to use. + // If nil, log.Printf is used. + Logf logger.Logf + + // NetMon is the network monitor that will be used to get the + // network interface state. This field can be nil; if so, the current + // state will be looked up dynamically. + NetMon *netmon.Monitor + + // DNSCache is the caching Resolver to use to connect to the server. + // + // This field can be nil. + DNSCache *dnscache.Resolver + + // HealthTracker, if non-nil, is the health tracker to use. + HealthTracker *health.Tracker + + // DialPlan, if set, is a function that should return an explicit plan + // on how to connect to the server. + DialPlan func() *tailcfg.ControlDialPlan + + // ProtocolVersion, if non-zero, specifies an alternate + // protocol version to use instead of the default, + // of [tailcfg.CurrentCapabilityVersion]. + ProtocolVersion uint16 +} + +// NewClient returns a new noiseClient for the provided server and machine key. +// +// netMon may be nil, if non-nil it's used to do faster interface lookups. +// dialPlan may be nil +func NewClient(opts ClientOpts) (*Client, error) { + logf := opts.Logf + if logf == nil { + logf = log.Printf + } + if opts.ServerURL == "" { + return nil, errors.New("ServerURL is required") + } + if opts.PrivKey.IsZero() { + return nil, errors.New("PrivKey is required") + } + if opts.ServerPubKey.IsZero() { + return nil, errors.New("ServerPubKey is required") + } + if opts.Dialer == nil { + return nil, errors.New("Dialer is required") + } + + u, err := url.Parse(opts.ServerURL) + if err != nil { + return nil, fmt.Errorf("invalid ClientOpts.ServerURL: %w", err) + } + if u.Scheme != "http" && u.Scheme != "https" { + return nil, errors.New("invalid ServerURL scheme, must be http or https") + } + + httpPort, httpsPort := "80", "443" + addr, _ := netip.ParseAddr(u.Hostname()) + isPrivateHost := addr.IsPrivate() || addr.IsLoopback() || u.Hostname() == "localhost" + if port := u.Port(); port != "" { + // If there is an explicit port specified, entirely rely on the scheme, + // unless it's http with a private host in which case we never try using HTTPS. + if u.Scheme == "https" { + httpPort = "" + httpsPort = port + } else if u.Scheme == "http" { + httpPort = port + httpsPort = "443" + if isPrivateHost { + logf("setting empty HTTPS port with http scheme and private host %s", u.Hostname()) + httpsPort = "" + } + } + } else if u.Scheme == "http" && isPrivateHost { + // Whenever the scheme is http and the hostname is an IP address, do not set the HTTPS port, + // as there cannot be a TLS certificate issued for an IP, unless it's a public IP. + httpPort = "80" + httpsPort = "" + } + + np := &Client{ + opts: opts, + host: u.Hostname(), + httpPort: httpPort, + httpsPort: httpsPort, + logf: logf, + } + + tr := &http.Transport{ + Protocols: new(http.Protocols), + MaxConnsPerHost: 1, + } + // We force only HTTP/2 for this transport, which is what the control server + // speaks inside the ts2021 Noise encryption. But Go doesn't know about that, + // so we use "SetUnencryptedHTTP2" even though it's actually encrypted. + tr.Protocols.SetUnencryptedHTTP2(true) + tr.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return np.dial(ctx) + } + + np.Client = &http.Client{Transport: tr} + return np, nil +} + +// Close closes all the underlying noise connections. +// It is a no-op and returns nil if the connection is already closed. +func (nc *Client) Close() error { + nc.mu.Lock() + live := nc.connPool + nc.closed = true + nc.connPool = nil // stop noteConnClosed from mutating it as we loop over it (in live) below + nc.mu.Unlock() + + for _, c := range live { + c.Close() + } + nc.Client.CloseIdleConnections() + + return nil +} + +// dial opens a new connection to tailcontrol, fetching the server noise key +// if not cached. +func (nc *Client) dial(ctx context.Context) (*Conn, error) { + if tailcfg.CurrentCapabilityVersion > math.MaxUint16 { + // Panic, because a test should have started failing several + // thousand version numbers before getting to this point. + panic("capability version is too high to fit in the wire protocol") + } + + var dialPlan *tailcfg.ControlDialPlan + if nc.opts.DialPlan != nil { + dialPlan = nc.opts.DialPlan() + } + + // If we have a dial plan, then set our timeout as slightly longer than + // the maximum amount of time contained therein; we assume that + // explicit instructions on timeouts are more useful than a single + // hard-coded timeout. + // + // The default value of 5 is chosen so that, when there's no dial plan, + // we retain the previous behaviour of 10 seconds end-to-end timeout. + timeoutSec := 5.0 + if dialPlan != nil { + for _, c := range dialPlan.Candidates { + if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec { + timeoutSec = v + } + } + } + + // After we establish a connection, we need some time to actually + // upgrade it into a Noise connection. With a ballpark worst-case RTT + // of 1000ms, give ourselves an extra 5 seconds to complete the + // handshake. + timeoutSec += 5 + + // Be extremely defensive and ensure that the timeout is in the range + // [5, 60] seconds (e.g. if we accidentally get a negative number). + if timeoutSec > 60 { + timeoutSec = 60 + } else if timeoutSec < 5 { + timeoutSec = 5 + } + + timeout := time.Duration(timeoutSec * float64(time.Second)) + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + chd := &controlhttp.Dialer{ + Hostname: nc.host, + HTTPPort: nc.httpPort, + HTTPSPort: cmp.Or(nc.httpsPort, controlhttp.NoPort), + MachineKey: nc.opts.PrivKey, + ControlKey: nc.opts.ServerPubKey, + ProtocolVersion: cmp.Or(nc.opts.ProtocolVersion, uint16(tailcfg.CurrentCapabilityVersion)), + Dialer: nc.opts.Dialer.SystemDial, + DNSCache: nc.opts.DNSCache, + DialPlan: dialPlan, + Logf: nc.logf, + NetMon: nc.opts.NetMon, + HealthTracker: nc.opts.HealthTracker, + Clock: tstime.StdClock{}, + } + clientConn, err := chd.Dial(ctx) + if err != nil { + return nil, err + } + + nc.mu.Lock() + + handle := set.NewHandle() + ncc := NewConn(clientConn.Conn, func() { nc.noteConnClosed(handle) }) + mak.Set(&nc.connPool, handle, ncc) + + if nc.closed { + nc.mu.Unlock() + ncc.Close() // Needs to be called without holding the lock. + return nil, errors.New("noise client closed") + } + + defer nc.mu.Unlock() + return ncc, nil +} + +// noteConnClosed notes that the *Conn with the given handle has closed and +// should be removed from the live connPool (which is usually of size 0 or 1, +// except perhaps briefly 2 during a network failure and reconnect). +func (nc *Client) noteConnClosed(handle set.Handle) { + nc.mu.Lock() + defer nc.mu.Unlock() + nc.connPool.Delete(handle) +} + +// post does a POST to the control server at the given path, JSON-encoding body. +// The provided nodeKey is an optional load balancing hint. +func (nc *Client) Post(ctx context.Context, path string, nodeKey key.NodePublic, body any) (*http.Response, error) { + return nc.DoWithBody(ctx, "POST", path, nodeKey, body) +} + +func (nc *Client) DoWithBody(ctx context.Context, method, path string, nodeKey key.NodePublic, body any) (*http.Response, error) { + jbody, err := json.Marshal(body) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, method, "https://"+nc.host+path, bytes.NewReader(jbody)) + if err != nil { + return nil, err + } + AddLBHeader(req, nodeKey) + req.Header.Set("Content-Type", "application/json") + return nc.Do(req) +} + +// AddLBHeader adds the load balancer header to req if nodeKey is non-zero. +func AddLBHeader(req *http.Request, nodeKey key.NodePublic) { + if !nodeKey.IsZero() { + req.Header.Add(tailcfg.LBHeader, nodeKey.String()) + } +} diff --git a/control/controlclient/noise_test.go b/control/ts2021/client_test.go similarity index 50% rename from control/controlclient/noise_test.go rename to control/ts2021/client_test.go index f2627bd0a..72fa1f442 100644 --- a/control/controlclient/noise_test.go +++ b/control/ts2021/client_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package controlclient +package ts2021 import ( "context" @@ -10,18 +10,20 @@ import ( "io" "math" "net/http" - "net/http/httptest" + "net/http/httptrace" + "sync/atomic" "testing" "time" "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" - "tailscale.com/internal/noiseconn" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" + "tailscale.com/tstest/nettest" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/must" ) // maxAllowedNoiseVersion is the highest we expect the Tailscale @@ -54,6 +56,132 @@ func TestNoiseClientHTTP2Upgrade_earlyPayload(t *testing.T) { }.run(t) } +var ( + testPrivKey = key.NewMachine() + testServerPub = key.NewMachine().Public() +) + +func makeClientWithURL(t *testing.T, url string) *Client { + nc, err := NewClient(ClientOpts{ + Logf: t.Logf, + PrivKey: testPrivKey, + ServerPubKey: testServerPub, + ServerURL: url, + Dialer: tsdial.NewDialer(netmon.NewStatic()), + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { nc.Close() }) + return nc +} + +func TestNoiseClientPortsAreSet(t *testing.T) { + tests := []struct { + name string + url string + wantHTTPS string + wantHTTP string + }{ + { + name: "https-url", + url: "https://example.com", + wantHTTPS: "443", + wantHTTP: "80", + }, + { + name: "http-url", + url: "http://example.com", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "80", + }, + { + name: "https-url-custom-port", + url: "https://example.com:123", + wantHTTPS: "123", + wantHTTP: "", + }, + { + name: "http-url-custom-port", + url: "http://example.com:123", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "123", + }, + { + name: "http-loopback-no-port", + url: "http://127.0.0.1", + wantHTTPS: "", + wantHTTP: "80", + }, + { + name: "http-loopback-custom-port", + url: "http://127.0.0.1:8080", + wantHTTPS: "", + wantHTTP: "8080", + }, + { + name: "http-localhost-no-port", + url: "http://localhost", + wantHTTPS: "", + wantHTTP: "80", + }, + { + name: "http-localhost-custom-port", + url: "http://localhost:8080", + wantHTTPS: "", + wantHTTP: "8080", + }, + { + name: "http-private-ip-no-port", + url: "http://192.168.2.3", + wantHTTPS: "", + wantHTTP: "80", + }, + { + name: "http-private-ip-custom-port", + url: "http://192.168.2.3:8080", + wantHTTPS: "", + wantHTTP: "8080", + }, + { + name: "http-public-ip", + url: "http://1.2.3.4", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "80", + }, + { + name: "http-public-ip-custom-port", + url: "http://1.2.3.4:8080", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "8080", + }, + { + name: "https-public-ip", + url: "https://1.2.3.4", + wantHTTPS: "443", + wantHTTP: "80", + }, + { + name: "https-public-ip-custom-port", + url: "https://1.2.3.4:8080", + wantHTTPS: "8080", + wantHTTP: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nc := makeClientWithURL(t, tt.url) + if nc.httpsPort != tt.wantHTTPS { + t.Errorf("nc.httpsPort = %q; want %q", nc.httpsPort, tt.wantHTTPS) + } + if nc.httpPort != tt.wantHTTP { + t.Errorf("nc.httpPort = %q; want %q", nc.httpPort, tt.wantHTTP) + } + }) + } +} + func (tt noiseClientTest) run(t *testing.T) { serverPrivate := key.NewMachine() clientPrivate := key.NewMachine() @@ -61,7 +189,8 @@ func (tt noiseClientTest) run(t *testing.T) { const msg = "Hello, client" h2 := &http2.Server{} - hs := httptest.NewServer(&Upgrader{ + nw := nettest.GetNetwork(t) + hs := nettest.NewHTTPServer(nw, &Upgrader{ h2srv: h2, noiseKeyPriv: serverPrivate, sendEarlyPayload: tt.sendEarlyPayload, @@ -76,38 +205,54 @@ func (tt noiseClientTest) run(t *testing.T) { defer hs.Close() dialer := tsdial.NewDialer(netmon.NewStatic()) - nc, err := NewNoiseClient(NoiseOpts{ + if nettest.PreferMemNetwork() { + dialer.SetSystemDialerForTest(nw.Dial) + } + + nc, err := NewClient(ClientOpts{ PrivKey: clientPrivate, ServerPubKey: serverPrivate.Public(), ServerURL: hs.URL, Dialer: dialer, + Logf: t.Logf, }) if err != nil { t.Fatal(err) } - // Get a conn and verify it read its early payload before the http/2 - // handshake. - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - c, err := nc.getConn(ctx) - if err != nil { - t.Fatal(err) - } - payload, err := c.GetEarlyPayload(ctx) - if err != nil { - t.Fatal("timed out waiting for didReadHeaderCh") - } + var sawConn atomic.Bool + trace := httptrace.WithClientTrace(t.Context(), &httptrace.ClientTrace{ + GotConn: func(ci httptrace.GotConnInfo) { + ncc, ok := ci.Conn.(*Conn) + if !ok { + // This trace hook sees two dials: the lower-level controlhttp upgrade's + // dial (a tsdial.sysConn), and then the *ts2021.Conn we want. + // Ignore the first one. + return + } + sawConn.Store(true) - gotNonNil := payload != nil - if gotNonNil != tt.sendEarlyPayload { - t.Errorf("sendEarlyPayload = %v but got earlyPayload = %T", tt.sendEarlyPayload, payload) - } - if payload != nil { - if payload.NodeKeyChallenge != chalPrivate.Public() { - t.Errorf("earlyPayload.NodeKeyChallenge = %v; want %v", payload.NodeKeyChallenge, chalPrivate.Public()) - } - } + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + payload, err := ncc.GetEarlyPayload(ctx) + if err != nil { + t.Errorf("GetEarlyPayload: %v", err) + return + } + + gotNonNil := payload != nil + if gotNonNil != tt.sendEarlyPayload { + t.Errorf("sendEarlyPayload = %v but got earlyPayload = %T", tt.sendEarlyPayload, payload) + } + if payload != nil { + if payload.NodeKeyChallenge != chalPrivate.Public() { + t.Errorf("earlyPayload.NodeKeyChallenge = %v; want %v", payload.NodeKeyChallenge, chalPrivate.Public()) + } + } + }, + }) + req := must.Get(http.NewRequestWithContext(trace, "GET", "https://unused.example/", nil)) checkRes := func(t *testing.T, res *http.Response) { t.Helper() @@ -121,15 +266,19 @@ func (tt noiseClientTest) run(t *testing.T) { } } - // And verify we can do HTTP/2 against that conn. - res, err := (&http.Client{Transport: c}).Get("https://unused.example/") + // Verify we can do HTTP/2 against that conn. + res, err := nc.Do(req) if err != nil { t.Fatal(err) } checkRes(t, res) + if !sawConn.Load() { + t.Error("ClientTrace.GotConn never saw the *ts2021.Conn") + } + // And try using the high-level nc.post API as well. - res, err = nc.post(context.Background(), "/", key.NodePublic{}, nil) + res, err = nc.Post(context.Background(), "/", key.NodePublic{}, nil) if err != nil { t.Fatal(err) } @@ -184,7 +333,7 @@ func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) { // https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not // an HTTP/2 settings frame, which isn't of type 'T') var notH2Frame [5]byte - copy(notH2Frame[:], noiseconn.EarlyPayloadMagic) + copy(notH2Frame[:], EarlyPayloadMagic) var lenBuf [4]byte binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON))) // These writes are all buffered by caller, so fine to do them @@ -201,7 +350,7 @@ func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) { return nil } - cbConn, err := controlhttp.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn) + cbConn, err := controlhttpserver.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn) if err != nil { up.logf("controlhttp: Accept: %v", err) return diff --git a/internal/noiseconn/conn.go b/control/ts2021/conn.go similarity index 69% rename from internal/noiseconn/conn.go rename to control/ts2021/conn.go index 7476b7ecc..52d663272 100644 --- a/internal/noiseconn/conn.go +++ b/control/ts2021/conn.go @@ -1,12 +1,10 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package noiseconn contains an internal-only wrapper around controlbase.Conn -// that properly handles the early payload sent by the server before the HTTP/2 -// session begins. -// -// See the documentation on the Conn type for more details. -package noiseconn +// Package ts2021 handles the details of the Tailscale 2021 control protocol +// that are after (above) the Noise layer. In particular, the +// "tailcfg.EarlyNoise" message and the subsequent HTTP/2 connection. +package ts2021 import ( "bytes" @@ -15,10 +13,8 @@ import ( "encoding/json" "errors" "io" - "net/http" "sync" - "golang.org/x/net/http2" "tailscale.com/control/controlbase" "tailscale.com/tailcfg" ) @@ -29,12 +25,13 @@ import ( // the pool when the connection is closed, properly handles an optional "early // payload" that's sent prior to beginning the HTTP/2 session, and provides a // way to return a connection to a pool when the connection is closed. +// +// Use [NewConn] to build a new Conn if you want [Conn.GetEarlyPayload] to work. +// Otherwise making a Conn directly, only setting Conn, is fine. type Conn struct { *controlbase.Conn - id int - onClose func(int) - h2cc *http2.ClientConn + onClose func() // or nil readHeaderOnce sync.Once // guards init of reader field reader io.Reader // (effectively Conn.Reader after header) earlyPayloadReady chan struct{} // closed after earlyPayload is set (including set to nil) @@ -42,31 +39,19 @@ type Conn struct { earlyPayloadErr error } -// New creates a new Conn that wraps the given controlbase.Conn. +// NewConn creates a new Conn that wraps the given controlbase.Conn. // // h2t is the HTTP/2 transport to use for the connection; a new // http2.ClientConn will be created that reads from the returned Conn. // // connID should be a unique ID for this connection. When the Conn is closed, -// the onClose function will be called with the connID if it is non-nil. -func New(conn *controlbase.Conn, h2t *http2.Transport, connID int, onClose func(int)) (*Conn, error) { - ncc := &Conn{ +// the onClose function will be called if it is non-nil. +func NewConn(conn *controlbase.Conn, onClose func()) *Conn { + return &Conn{ Conn: conn, - id: connID, - onClose: onClose, earlyPayloadReady: make(chan struct{}), + onClose: sync.OnceFunc(onClose), } - h2cc, err := h2t.NewClientConn(ncc) - if err != nil { - return nil, err - } - ncc.h2cc = h2cc - return ncc, nil -} - -// RoundTrip implements the http.RoundTripper interface. -func (c *Conn) RoundTrip(r *http.Request) (*http.Response, error) { - return c.h2cc.RoundTrip(r) } // GetEarlyPayload waits for the early Noise payload to arrive. @@ -76,6 +61,15 @@ func (c *Conn) RoundTrip(r *http.Request) (*http.Response, error) { // early Noise payload is ready (if any) and will return the same result for // the lifetime of the Conn. func (c *Conn) GetEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error) { + if c.earlyPayloadReady == nil { + return nil, errors.New("Conn was not created with NewConn; early payload not supported") + } + select { + case <-c.earlyPayloadReady: + return c.earlyPayload, c.earlyPayloadErr + default: + go c.readHeaderOnce.Do(c.readHeader) + } select { case <-c.earlyPayloadReady: return c.earlyPayload, c.earlyPayloadErr @@ -84,28 +78,6 @@ func (c *Conn) GetEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error) } } -// ReserveNewRequest will reserve a new concurrent request on the connection. -// -// It returns whether the reservation was successful, and any early Noise -// payload if present. If a reservation was not successful, it will return -// false and nil for the early payload. -func (c *Conn) ReserveNewRequest(ctx context.Context) (bool, *tailcfg.EarlyNoise, error) { - earlyPayloadMaybeNil, err := c.GetEarlyPayload(ctx) - if err != nil { - return false, nil, err - } - if c.h2cc.ReserveNewRequest() { - return true, earlyPayloadMaybeNil, nil - } - return false, nil, nil -} - -// CanTakeNewRequest reports whether the underlying HTTP/2 connection can take -// a new request, meaning it has not been closed or received or sent a GOAWAY. -func (c *Conn) CanTakeNewRequest() bool { - return c.h2cc.CanTakeNewRequest() -} - // The first 9 bytes from the server to client over Noise are either an HTTP/2 // settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload" // header that's also 9 bytes long: 5 bytes (EarlyPayloadMagic) followed by 4 bytes @@ -133,6 +105,14 @@ func (c *Conn) Read(p []byte) (n int, err error) { return c.reader.Read(p) } +// Close closes the connection. +func (c *Conn) Close() error { + if c.onClose != nil { + defer c.onClose() + } + return c.Conn.Close() +} + // readHeader reads the optional "early payload" from the server that arrives // after the Noise handshake but before the HTTP/2 session begins. // @@ -140,7 +120,9 @@ func (c *Conn) Read(p []byte) (n int, err error) { // c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for // future reads. func (c *Conn) readHeader() { - defer close(c.earlyPayloadReady) + if c.earlyPayloadReady != nil { + defer close(c.earlyPayloadReady) + } setErr := func(err error) { c.reader = returnErrReader{err} @@ -174,14 +156,3 @@ func (c *Conn) readHeader() { } c.reader = c.Conn } - -// Close closes the connection. -func (c *Conn) Close() error { - if err := c.Conn.Close(); err != nil { - return err - } - if c.onClose != nil { - c.onClose(c.id) - } - return nil -} diff --git a/derp/client_test.go b/derp/client_test.go new file mode 100644 index 000000000..a731ad197 --- /dev/null +++ b/derp/client_test.go @@ -0,0 +1,235 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package derp + +import ( + "bufio" + "bytes" + "io" + "net" + "reflect" + "sync" + "testing" + "time" + + "tailscale.com/tstest" + "tailscale.com/types/key" +) + +type dummyNetConn struct { + net.Conn +} + +func (dummyNetConn) SetReadDeadline(time.Time) error { return nil } + +func TestClientRecv(t *testing.T) { + tests := []struct { + name string + input []byte + want any + }{ + { + name: "ping", + input: []byte{ + byte(FramePing), 0, 0, 0, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + }, + want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8}, + }, + { + name: "pong", + input: []byte{ + byte(FramePong), 0, 0, 0, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + }, + want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8}, + }, + { + name: "health_bad", + input: []byte{ + byte(FrameHealth), 0, 0, 0, 3, + byte('B'), byte('A'), byte('D'), + }, + want: HealthMessage{Problem: "BAD"}, + }, + { + name: "health_ok", + input: []byte{ + byte(FrameHealth), 0, 0, 0, 0, + }, + want: HealthMessage{}, + }, + { + name: "server_restarting", + input: []byte{ + byte(FrameRestarting), 0, 0, 0, 8, + 0, 0, 0, 1, + 0, 0, 0, 2, + }, + want: ServerRestartingMessage{ + ReconnectIn: 1 * time.Millisecond, + TryFor: 2 * time.Millisecond, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Client{ + nc: dummyNetConn{}, + br: bufio.NewReader(bytes.NewReader(tt.input)), + logf: t.Logf, + clock: &tstest.Clock{}, + } + got, err := c.Recv() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %#v; want %#v", got, tt.want) + } + }) + } +} + +func TestClientSendPing(t *testing.T) { + var buf bytes.Buffer + c := &Client{ + bw: bufio.NewWriter(&buf), + } + if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { + t.Fatal(err) + } + want := []byte{ + byte(FramePing), 0, 0, 0, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + } + if !bytes.Equal(buf.Bytes(), want) { + t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) + } +} + +func TestClientSendPong(t *testing.T) { + var buf bytes.Buffer + c := &Client{ + bw: bufio.NewWriter(&buf), + } + if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { + t.Fatal(err) + } + want := []byte{ + byte(FramePong), 0, 0, 0, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + } + if !bytes.Equal(buf.Bytes(), want) { + t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) + } +} + +func BenchmarkWriteUint32(b *testing.B) { + w := bufio.NewWriter(io.Discard) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + writeUint32(w, 0x0ba3a) + } +} + +type nopRead struct{} + +func (r nopRead) Read(p []byte) (int, error) { + return len(p), nil +} + +var sinkU32 uint32 + +func BenchmarkReadUint32(b *testing.B) { + r := bufio.NewReader(nopRead{}) + var err error + b.ReportAllocs() + b.ResetTimer() + for range b.N { + sinkU32, err = readUint32(r) + if err != nil { + b.Fatal(err) + } + } +} + +type countWriter struct { + mu sync.Mutex + writes int + bytes int64 +} + +func (w *countWriter) Write(p []byte) (n int, err error) { + w.mu.Lock() + defer w.mu.Unlock() + w.writes++ + w.bytes += int64(len(p)) + return len(p), nil +} + +func (w *countWriter) Stats() (writes int, bytes int64) { + w.mu.Lock() + defer w.mu.Unlock() + return w.writes, w.bytes +} + +func (w *countWriter) ResetStats() { + w.mu.Lock() + defer w.mu.Unlock() + w.writes, w.bytes = 0, 0 +} + +func TestClientSendRateLimiting(t *testing.T) { + cw := new(countWriter) + c := &Client{ + bw: bufio.NewWriter(cw), + clock: &tstest.Clock{}, + } + c.setSendRateLimiter(ServerInfoMessage{}) + + pkt := make([]byte, 1000) + if err := c.send(key.NodePublic{}, pkt); err != nil { + t.Fatal(err) + } + writes1, bytes1 := cw.Stats() + if writes1 != 1 { + t.Errorf("writes = %v, want 1", writes1) + } + + // Flood should all succeed. + cw.ResetStats() + for range 1000 { + if err := c.send(key.NodePublic{}, pkt); err != nil { + t.Fatal(err) + } + } + writes1K, bytes1K := cw.Stats() + if writes1K != 1000 { + t.Logf("writes = %v; want 1000", writes1K) + } + if got, want := bytes1K, bytes1*1000; got != want { + t.Logf("bytes = %v; want %v", got, want) + } + + // Set a rate limiter + cw.ResetStats() + c.setSendRateLimiter(ServerInfoMessage{ + TokenBucketBytesPerSecond: 1, + TokenBucketBytesBurst: int(bytes1 * 2), + }) + for range 1000 { + if err := c.send(key.NodePublic{}, pkt); err != nil { + t.Fatal(err) + } + } + writesLimited, bytesLimited := cw.Stats() + if writesLimited == 0 || writesLimited == writes1K { + t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited) + } + if bytesLimited < bytes1*2 || bytesLimited >= bytes1K { + t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) + } +} diff --git a/derp/derp.go b/derp/derp.go index f9b070647..e19a99b00 100644 --- a/derp/derp.go +++ b/derp/derp.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "io" + "net" "time" ) @@ -26,27 +27,31 @@ import ( // including its on-wire framing overhead) const MaxPacketSize = 64 << 10 -// magic is the DERP magic number, sent in the frameServerKey frame +// Magic is the DERP Magic number, sent in the frameServerKey frame // upon initial connection. -const magic = "DERP🔑" // 8 bytes: 0x44 45 52 50 f0 9f 94 91 +const Magic = "DERP🔑" // 8 bytes: 0x44 45 52 50 f0 9f 94 91 const ( - nonceLen = 24 - frameHeaderLen = 1 + 4 // frameType byte + 4 byte length - keyLen = 32 - maxInfoLen = 1 << 20 - keepAlive = 60 * time.Second + NonceLen = 24 + FrameHeaderLen = 1 + 4 // frameType byte + 4 byte length + KeyLen = 32 + MaxInfoLen = 1 << 20 ) +// KeepAlive is the minimum frequency at which the DERP server sends +// keep alive frames. The server adds some jitter, so this timing is not +// exact, but 2x this value can be considered a missed keep alive. +const KeepAlive = 60 * time.Second + // ProtocolVersion is bumped whenever there's a wire-incompatible change. // - version 1 (zero on wire): consistent box headers, in use by employee dev nodes a bit // - version 2: received packets have src addrs in frameRecvPacket at beginning const ProtocolVersion = 2 -// frameType is the one byte frame type at the beginning of the frame +// FrameType is the one byte frame type at the beginning of the frame // header. The second field is a big-endian uint32 describing the // length of the remaining frame (not including the initial 5 bytes). -type frameType byte +type FrameType byte /* Protocol flow: @@ -64,14 +69,14 @@ Steady state: * server then sends frameRecvPacket to recipient */ const ( - frameServerKey = frameType(0x01) // 8B magic + 32B public key + (0+ bytes future use) - frameClientInfo = frameType(0x02) // 32B pub key + 24B nonce + naclbox(json) - frameServerInfo = frameType(0x03) // 24B nonce + naclbox(json) - frameSendPacket = frameType(0x04) // 32B dest pub key + packet bytes - frameForwardPacket = frameType(0x0a) // 32B src pub key + 32B dst pub key + packet bytes - frameRecvPacket = frameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes - frameKeepAlive = frameType(0x06) // no payload, no-op (to be replaced with ping/pong) - frameNotePreferred = frameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node + FrameServerKey = FrameType(0x01) // 8B magic + 32B public key + (0+ bytes future use) + FrameClientInfo = FrameType(0x02) // 32B pub key + 24B nonce + naclbox(json) + FrameServerInfo = FrameType(0x03) // 24B nonce + naclbox(json) + FrameSendPacket = FrameType(0x04) // 32B dest pub key + packet bytes + FrameForwardPacket = FrameType(0x0a) // 32B src pub key + 32B dst pub key + packet bytes + FrameRecvPacket = FrameType(0x05) // v0/1: packet bytes, v2: 32B src pub key + packet bytes + FrameKeepAlive = FrameType(0x06) // no payload, no-op (to be replaced with ping/pong) + FrameNotePreferred = FrameType(0x07) // 1 byte payload: 0x01 or 0x00 for whether this is client's home node // framePeerGone is sent from server to client to signal that // a previous sender is no longer connected. That is, if A @@ -79,9 +84,8 @@ const ( // framePeerGone to B so B can forget that a reverse path // exists on that connection to get back to A. It is also sent // if A tries to send a CallMeMaybe to B and the server has no - // record of B (which currently would only happen if there was - // a bug). - framePeerGone = frameType(0x08) // 32B pub key of peer that's gone + 1 byte reason + // record of B + FramePeerGone = FrameType(0x08) // 32B pub key of peer that's gone + 1 byte reason // framePeerPresent is like framePeerGone, but for other members of the DERP // region when they're meshed up together. @@ -92,7 +96,7 @@ const ( // remaining after that, it's a PeerPresentFlags byte. // While current servers send 41 bytes, old servers will send fewer, and newer // servers might send more. - framePeerPresent = frameType(0x09) + FramePeerPresent = FrameType(0x09) // frameWatchConns is how one DERP node in a regional mesh // subscribes to the others in the region. @@ -100,30 +104,30 @@ const ( // is closed. Otherwise, the client is initially flooded with // framePeerPresent for all connected nodes, and then a stream of // framePeerPresent & framePeerGone has peers connect and disconnect. - frameWatchConns = frameType(0x10) + FrameWatchConns = FrameType(0x10) // frameClosePeer is a privileged frame type (requires the // mesh key for now) that closes the provided peer's // connection. (To be used for cluster load balancing // purposes, when clients end up on a non-ideal node) - frameClosePeer = frameType(0x11) // 32B pub key of peer to close. + FrameClosePeer = FrameType(0x11) // 32B pub key of peer to close. - framePing = frameType(0x12) // 8 byte ping payload, to be echoed back in framePong - framePong = frameType(0x13) // 8 byte payload, the contents of the ping being replied to + FramePing = FrameType(0x12) // 8 byte ping payload, to be echoed back in framePong + FramePong = FrameType(0x13) // 8 byte payload, the contents of the ping being replied to // frameHealth is sent from server to client to tell the client // if their connection is unhealthy somehow. Currently the only unhealthy state // is whether the connection is detected as a duplicate. // The entire frame body is the text of the error message. An empty message // clears the error state. - frameHealth = frameType(0x14) + FrameHealth = FrameType(0x14) // frameRestarting is sent from server to client for the // server to declare that it's restarting. Payload is two big // endian uint32 durations in milliseconds: when to reconnect, // and how long to try total. See ServerRestartingMessage docs for // more details on how the client should interpret them. - frameRestarting = frameType(0x15) + FrameRestarting = FrameType(0x15) ) // PeerGoneReasonType is a one byte reason code explaining why a @@ -131,8 +135,8 @@ const ( type PeerGoneReasonType byte const ( - PeerGoneReasonDisconnected = PeerGoneReasonType(0x00) // peer disconnected from this server - PeerGoneReasonNotHere = PeerGoneReasonType(0x01) // server doesn't know about this peer, unexpected + PeerGoneReasonDisconnected = PeerGoneReasonType(0x00) // is only sent when a peer disconnects from this server + PeerGoneReasonNotHere = PeerGoneReasonType(0x01) // server doesn't know about this peer PeerGoneReasonMeshConnBroke = PeerGoneReasonType(0xf0) // invented by Client.RunWatchConnectionLoop on disconnect; not sent on the wire ) @@ -147,8 +151,21 @@ const ( PeerPresentIsRegular = 1 << 0 PeerPresentIsMeshPeer = 1 << 1 PeerPresentIsProber = 1 << 2 + PeerPresentNotIdeal = 1 << 3 // client said derp server is not its Region.Nodes[0] ideal node ) +// IdealNodeHeader is the HTTP request header sent on DERP HTTP client requests +// to indicate that they're connecting to their ideal (Region.Nodes[0]) node. +// The HTTP header value is the name of the node they wish they were connected +// to. This is an optional header. +const IdealNodeHeader = "Ideal-Node" + +// FastStartHeader is the header (with value "1") that signals to the HTTP +// server that the DERP HTTP client does not want the HTTP 101 response +// headers and it will begin writing & reading the DERP protocol immediately +// following its HTTP request. +const FastStartHeader = "Derp-Fast-Start" + var bin = binary.BigEndian func writeUint32(bw *bufio.Writer, v uint32) error { @@ -181,15 +198,24 @@ func readUint32(br *bufio.Reader) (uint32, error) { return bin.Uint32(b[:]), nil } -func readFrameTypeHeader(br *bufio.Reader, wantType frameType) (frameLen uint32, err error) { - gotType, frameLen, err := readFrameHeader(br) +// ReadFrameTypeHeader reads a frame header from br and +// verifies that the frame type matches wantType. +// +// If it does, it returns the frame length (not including +// the 5 byte header) and a nil error. +// +// If it doesn't, it returns an error and a zero length. +func ReadFrameTypeHeader(br *bufio.Reader, wantType FrameType) (frameLen uint32, err error) { + gotType, frameLen, err := ReadFrameHeader(br) if err == nil && wantType != gotType { err = fmt.Errorf("bad frame type 0x%X, want 0x%X", gotType, wantType) } return frameLen, err } -func readFrameHeader(br *bufio.Reader) (t frameType, frameLen uint32, err error) { +// ReadFrameHeader reads the header of a DERP frame, +// reading 5 bytes from br. +func ReadFrameHeader(br *bufio.Reader) (t FrameType, frameLen uint32, err error) { tb, err := br.ReadByte() if err != nil { return 0, 0, err @@ -198,7 +224,7 @@ func readFrameHeader(br *bufio.Reader) (t frameType, frameLen uint32, err error) if err != nil { return 0, 0, err } - return frameType(tb), frameLen, nil + return FrameType(tb), frameLen, nil } // readFrame reads a frame header and then reads its payload into @@ -211,8 +237,8 @@ func readFrameHeader(br *bufio.Reader) (t frameType, frameLen uint32, err error) // bytes are read, err will be io.ErrShortBuffer, and frameLen and t // will both be set. That is, callers need to explicitly handle when // they get more data than expected. -func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t frameType, frameLen uint32, err error) { - t, frameLen, err = readFrameHeader(br) +func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t FrameType, frameLen uint32, err error) { + t, frameLen, err = ReadFrameHeader(br) if err != nil { return 0, 0, err } @@ -234,19 +260,26 @@ func readFrame(br *bufio.Reader, maxSize uint32, b []byte) (t frameType, frameLe return t, frameLen, err } -func writeFrameHeader(bw *bufio.Writer, t frameType, frameLen uint32) error { +// WriteFrameHeader writes a frame header to bw. +// +// The frame header is 5 bytes: a one byte frame type +// followed by a big-endian uint32 length of the +// remaining frame (not including the 5 byte header). +// +// It does not flush bw. +func WriteFrameHeader(bw *bufio.Writer, t FrameType, frameLen uint32) error { if err := bw.WriteByte(byte(t)); err != nil { return err } return writeUint32(bw, frameLen) } -// writeFrame writes a complete frame & flushes it. -func writeFrame(bw *bufio.Writer, t frameType, b []byte) error { +// WriteFrame writes a complete frame & flushes it. +func WriteFrame(bw *bufio.Writer, t FrameType, b []byte) error { if len(b) > 10<<20 { return errors.New("unreasonably large frame write") } - if err := writeFrameHeader(bw, t, uint32(len(b))); err != nil { + if err := WriteFrameHeader(bw, t, uint32(len(b))); err != nil { return err } if _, err := bw.Write(b); err != nil { @@ -254,3 +287,23 @@ func writeFrame(bw *bufio.Writer, t frameType, b []byte) error { } return bw.Flush() } + +// Conn is the subset of the underlying net.Conn the DERP Server needs. +// It is a defined type so that non-net connections can be used. +type Conn interface { + io.WriteCloser + LocalAddr() net.Addr + // The *Deadline methods follow the semantics of net.Conn. + SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} + +// ServerInfo is the message sent from the server to clients during +// the connection setup. +type ServerInfo struct { + Version int `json:"version,omitempty"` + + TokenBucketBytesPerSecond int `json:",omitempty"` + TokenBucketBytesBurst int `json:",omitempty"` +} diff --git a/derp/derp_client.go b/derp/derp_client.go index 7a646fa51..d28905cd2 100644 --- a/derp/derp_client.go +++ b/derp/derp_client.go @@ -30,7 +30,7 @@ type Client struct { logf logger.Logf nc Conn br *bufio.Reader - meshKey string + meshKey key.DERPMesh canAckPings bool isProber bool @@ -56,7 +56,7 @@ func (f clientOptFunc) update(o *clientOpt) { f(o) } // clientOpt are the options passed to newClient. type clientOpt struct { - MeshKey string + MeshKey key.DERPMesh ServerPub key.NodePublic CanAckPings bool IsProber bool @@ -66,7 +66,7 @@ type clientOpt struct { // access to join the mesh. // // An empty key means to not use a mesh key. -func MeshKey(key string) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.MeshKey = key }) } +func MeshKey(k key.DERPMesh) ClientOpt { return clientOptFunc(func(o *clientOpt) { o.MeshKey = k }) } // IsProber returns a ClientOpt to pass to the DERP server during connect to // declare that this client is a a prober. @@ -133,17 +133,17 @@ func (c *Client) recvServerKey() error { if err != nil { return err } - if flen < uint32(len(buf)) || t != frameServerKey || string(buf[:len(magic)]) != magic { + if flen < uint32(len(buf)) || t != FrameServerKey || string(buf[:len(Magic)]) != Magic { return errors.New("invalid server greeting") } - c.serverKey = key.NodePublicFromRaw32(mem.B(buf[len(magic):])) + c.serverKey = key.NodePublicFromRaw32(mem.B(buf[len(Magic):])) return nil } -func (c *Client) parseServerInfo(b []byte) (*serverInfo, error) { - const maxLength = nonceLen + maxInfoLen +func (c *Client) parseServerInfo(b []byte) (*ServerInfo, error) { + const maxLength = NonceLen + MaxInfoLen fl := len(b) - if fl < nonceLen { + if fl < NonceLen { return nil, fmt.Errorf("short serverInfo frame") } if fl > maxLength { @@ -153,19 +153,21 @@ func (c *Client) parseServerInfo(b []byte) (*serverInfo, error) { if !ok { return nil, fmt.Errorf("failed to open naclbox from server key %s", c.serverKey) } - info := new(serverInfo) + info := new(ServerInfo) if err := json.Unmarshal(msg, info); err != nil { return nil, fmt.Errorf("invalid JSON: %v", err) } return info, nil } -type clientInfo struct { +// ClientInfo is the information a DERP client sends to the server +// about itself when it connects. +type ClientInfo struct { // MeshKey optionally specifies a pre-shared key used by // trusted clients. It's required to subscribe to the // connection list & forward packets. It's empty for regular // users. - MeshKey string `json:"meshKey,omitempty"` + MeshKey key.DERPMesh `json:"meshKey,omitempty,omitzero"` // Version is the DERP protocol version that the client was built with. // See the ProtocolVersion const. @@ -179,8 +181,19 @@ type clientInfo struct { IsProber bool `json:",omitempty"` } +// Equal reports if two clientInfo values are equal. +func (c *ClientInfo) Equal(other *ClientInfo) bool { + if c == nil || other == nil { + return c == other + } + if c.Version != other.Version || c.CanAckPings != other.CanAckPings || c.IsProber != other.IsProber { + return false + } + return c.MeshKey.Equal(other.MeshKey) +} + func (c *Client) sendClientKey() error { - msg, err := json.Marshal(clientInfo{ + msg, err := json.Marshal(ClientInfo{ Version: ProtocolVersion, MeshKey: c.meshKey, CanAckPings: c.canAckPings, @@ -191,10 +204,10 @@ func (c *Client) sendClientKey() error { } msgbox := c.privateKey.SealTo(c.serverKey, msg) - buf := make([]byte, 0, keyLen+len(msgbox)) + buf := make([]byte, 0, KeyLen+len(msgbox)) buf = c.publicKey.AppendTo(buf) buf = append(buf, msgbox...) - return writeFrame(c.bw, frameClientInfo, buf) + return WriteFrame(c.bw, FrameClientInfo, buf) } // ServerPublicKey returns the server's public key. @@ -219,12 +232,12 @@ func (c *Client) send(dstKey key.NodePublic, pkt []byte) (ret error) { c.wmu.Lock() defer c.wmu.Unlock() if c.rate != nil { - pktLen := frameHeaderLen + key.NodePublicRawLen + len(pkt) + pktLen := FrameHeaderLen + key.NodePublicRawLen + len(pkt) if !c.rate.AllowN(c.clock.Now(), pktLen) { return nil // drop } } - if err := writeFrameHeader(c.bw, frameSendPacket, uint32(key.NodePublicRawLen+len(pkt))); err != nil { + if err := WriteFrameHeader(c.bw, FrameSendPacket, uint32(key.NodePublicRawLen+len(pkt))); err != nil { return err } if _, err := c.bw.Write(dstKey.AppendTo(nil)); err != nil { @@ -253,7 +266,7 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e timer := c.clock.AfterFunc(5*time.Second, c.writeTimeoutFired) defer timer.Stop() - if err := writeFrameHeader(c.bw, frameForwardPacket, uint32(keyLen*2+len(pkt))); err != nil { + if err := WriteFrameHeader(c.bw, FrameForwardPacket, uint32(KeyLen*2+len(pkt))); err != nil { return err } if _, err := c.bw.Write(srcKey.AppendTo(nil)); err != nil { @@ -271,17 +284,17 @@ func (c *Client) ForwardPacket(srcKey, dstKey key.NodePublic, pkt []byte) (err e func (c *Client) writeTimeoutFired() { c.nc.Close() } func (c *Client) SendPing(data [8]byte) error { - return c.sendPingOrPong(framePing, data) + return c.sendPingOrPong(FramePing, data) } func (c *Client) SendPong(data [8]byte) error { - return c.sendPingOrPong(framePong, data) + return c.sendPingOrPong(FramePong, data) } -func (c *Client) sendPingOrPong(typ frameType, data [8]byte) error { +func (c *Client) sendPingOrPong(typ FrameType, data [8]byte) error { c.wmu.Lock() defer c.wmu.Unlock() - if err := writeFrameHeader(c.bw, typ, 8); err != nil { + if err := WriteFrameHeader(c.bw, typ, 8); err != nil { return err } if _, err := c.bw.Write(data[:]); err != nil { @@ -303,7 +316,7 @@ func (c *Client) NotePreferred(preferred bool) (err error) { c.wmu.Lock() defer c.wmu.Unlock() - if err := writeFrameHeader(c.bw, frameNotePreferred, 1); err != nil { + if err := WriteFrameHeader(c.bw, FrameNotePreferred, 1); err != nil { return err } var b byte = 0x00 @@ -321,7 +334,7 @@ func (c *Client) NotePreferred(preferred bool) (err error) { func (c *Client) WatchConnectionChanges() error { c.wmu.Lock() defer c.wmu.Unlock() - if err := writeFrameHeader(c.bw, frameWatchConns, 0); err != nil { + if err := WriteFrameHeader(c.bw, FrameWatchConns, 0); err != nil { return err } return c.bw.Flush() @@ -332,7 +345,7 @@ func (c *Client) WatchConnectionChanges() error { func (c *Client) ClosePeer(target key.NodePublic) error { c.wmu.Lock() defer c.wmu.Unlock() - return writeFrame(c.bw, frameClosePeer, target.AppendTo(nil)) + return WriteFrame(c.bw, FrameClosePeer, target.AppendTo(nil)) } // ReceivedMessage represents a type returned by Client.Recv. Unless @@ -491,7 +504,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro c.peeked = 0 } - t, n, err := readFrameHeader(c.br) + t, n, err := ReadFrameHeader(c.br) if err != nil { return nil, err } @@ -522,7 +535,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro switch t { default: continue - case frameServerInfo: + case FrameServerInfo: // Server sends this at start-up. Currently unused. // Just has a JSON message saying "version: 2", // but the protocol seems extensible enough as-is without @@ -539,29 +552,29 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro } c.setSendRateLimiter(sm) return sm, nil - case frameKeepAlive: + case FrameKeepAlive: // A one-way keep-alive message that doesn't require an acknowledgement. // This predated framePing/framePong. return KeepAliveMessage{}, nil - case framePeerGone: - if n < keyLen { + case FramePeerGone: + if n < KeyLen { c.logf("[unexpected] dropping short peerGone frame from DERP server") continue } // Backward compatibility for the older peerGone without reason byte reason := PeerGoneReasonDisconnected - if n > keyLen { - reason = PeerGoneReasonType(b[keyLen]) + if n > KeyLen { + reason = PeerGoneReasonType(b[KeyLen]) } pg := PeerGoneMessage{ - Peer: key.NodePublicFromRaw32(mem.B(b[:keyLen])), + Peer: key.NodePublicFromRaw32(mem.B(b[:KeyLen])), Reason: reason, } return pg, nil - case framePeerPresent: + case FramePeerPresent: remain := b - chunk, remain, ok := cutLeadingN(remain, keyLen) + chunk, remain, ok := cutLeadingN(remain, KeyLen) if !ok { c.logf("[unexpected] dropping short peerPresent frame from DERP server") continue @@ -589,17 +602,17 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro msg.Flags = PeerPresentFlags(chunk[0]) return msg, nil - case frameRecvPacket: + case FrameRecvPacket: var rp ReceivedPacket - if n < keyLen { + if n < KeyLen { c.logf("[unexpected] dropping short packet from DERP server") continue } - rp.Source = key.NodePublicFromRaw32(mem.B(b[:keyLen])) - rp.Data = b[keyLen:n] + rp.Source = key.NodePublicFromRaw32(mem.B(b[:KeyLen])) + rp.Data = b[KeyLen:n] return rp, nil - case framePing: + case FramePing: var pm PingMessage if n < 8 { c.logf("[unexpected] dropping short ping frame") @@ -608,7 +621,7 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro copy(pm[:], b[:]) return pm, nil - case framePong: + case FramePong: var pm PongMessage if n < 8 { c.logf("[unexpected] dropping short ping frame") @@ -617,10 +630,10 @@ func (c *Client) recvTimeout(timeout time.Duration) (m ReceivedMessage, err erro copy(pm[:], b[:]) return pm, nil - case frameHealth: + case FrameHealth: return HealthMessage{Problem: string(b[:])}, nil - case frameRestarting: + case FrameRestarting: var m ServerRestartingMessage if n < 8 { c.logf("[unexpected] dropping short server restarting frame") diff --git a/derp/derp_test.go b/derp/derp_test.go index 9185194dd..52793f90f 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -1,56 +1,89 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package derp +package derp_test import ( "bufio" "bytes" "context" - "crypto/x509" - "encoding/asn1" "encoding/json" "errors" "expvar" "fmt" "io" - "log" "net" - "os" - "reflect" - "strconv" + "strings" "sync" "testing" "time" - "go4.org/mem" - "golang.org/x/time/rate" + "tailscale.com/derp" + "tailscale.com/derp/derpserver" "tailscale.com/disco" + "tailscale.com/metrics" "tailscale.com/net/memnet" - "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/must" +) + +type ( + ClientInfo = derp.ClientInfo + Conn = derp.Conn + Client = derp.Client ) func TestClientInfoUnmarshal(t *testing.T) { - for i, in := range []string{ - `{"Version":5,"MeshKey":"abc"}`, - `{"version":5,"meshKey":"abc"}`, + for i, in := range map[string]struct { + json string + want *ClientInfo + wantErr string + }{ + "empty": { + json: `{}`, + want: &ClientInfo{}, + }, + "valid": { + json: `{"Version":5,"MeshKey":"6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8"}`, + want: &ClientInfo{MeshKey: must.Get(key.ParseDERPMesh("6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8")), Version: 5}, + }, + "validLowerMeshKey": { + json: `{"version":5,"meshKey":"6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8"}`, + want: &ClientInfo{MeshKey: must.Get(key.ParseDERPMesh("6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8")), Version: 5}, + }, + "invalidMeshKeyToShort": { + json: `{"version":5,"meshKey":"abcdefg"}`, + wantErr: "invalid mesh key", + }, + "invalidMeshKeyToLong": { + json: `{"version":5,"meshKey":"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"}`, + wantErr: "invalid mesh key", + }, } { - var got clientInfo - if err := json.Unmarshal([]byte(in), &got); err != nil { - t.Fatalf("[%d]: %v", i, err) - } - want := clientInfo{Version: 5, MeshKey: "abc"} - if got != want { - t.Errorf("[%d]: got %+v; want %+v", i, got, want) - } + t.Run(i, func(t *testing.T) { + t.Parallel() + var got ClientInfo + err := json.Unmarshal([]byte(in.json), &got) + if in.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), in.wantErr) { + t.Errorf("Unmarshal(%q) = %v, want error containing %q", in.json, err, in.wantErr) + } + return + } + if err != nil { + t.Fatalf("Unmarshal(%q) = %v, want no error", in.json, err) + } + if !got.Equal(in.want) { + t.Errorf("Unmarshal(%q) = %+v, want %+v", in.json, got, in.want) + } + }) } } func TestSendRecv(t *testing.T) { serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, t.Logf) + s := derpserver.New(serverPrivateKey, t.Logf) defer s.Close() const numClients = 3 @@ -96,7 +129,7 @@ func TestSendRecv(t *testing.T) { key := clientPrivateKeys[i] brw := bufio.NewReadWriter(bufio.NewReader(cout), bufio.NewWriter(cout)) - c, err := NewClient(key, cout, brw, t.Logf) + c, err := derp.NewClient(key, cout, brw, t.Logf) if err != nil { t.Fatalf("client %d: %v", i, err) } @@ -123,16 +156,16 @@ func TestSendRecv(t *testing.T) { default: t.Errorf("unexpected message type %T", m) continue - case PeerGoneMessage: + case derp.PeerGoneMessage: switch m.Reason { - case PeerGoneReasonDisconnected: + case derp.PeerGoneReasonDisconnected: peerGoneCountDisconnected.Add(1) - case PeerGoneReasonNotHere: + case derp.PeerGoneReasonNotHere: peerGoneCountNotHere.Add(1) default: t.Errorf("unexpected PeerGone reason %v", m.Reason) } - case ReceivedPacket: + case derp.ReceivedPacket: if m.Source.IsZero() { t.Errorf("zero Source address in ReceivedPacket") } @@ -162,12 +195,15 @@ func TestSendRecv(t *testing.T) { } } + serverMetrics := s.ExpVar().(*metrics.Set) + wantActive := func(total, home int64) { t.Helper() dl := time.Now().Add(5 * time.Second) var gotTotal, gotHome int64 for time.Now().Before(dl) { - gotTotal, gotHome = s.curClients.Value(), s.curHomeClients.Value() + gotTotal = serverMetrics.Get("gauge_current_connections").(*expvar.Int).Value() + gotHome = serverMetrics.Get("gauge_current_home_connections").(*expvar.Int).Value() if gotTotal == total && gotHome == home { return } @@ -269,7 +305,7 @@ func TestSendRecv(t *testing.T) { func TestSendFreeze(t *testing.T) { serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, t.Logf) + s := derpserver.New(serverPrivateKey, t.Logf) defer s.Close() s.WriteTimeout = 100 * time.Millisecond @@ -287,7 +323,7 @@ func TestSendFreeze(t *testing.T) { go s.Accept(ctx, c1, bufio.NewReadWriter(bufio.NewReader(c1), bufio.NewWriter(c1)), name) brw := bufio.NewReadWriter(bufio.NewReader(c2), bufio.NewWriter(c2)) - c, err := NewClient(k, c2, brw, t.Logf) + c, err := derp.NewClient(k, c2, brw, t.Logf) if err != nil { t.Fatal(err) } @@ -338,7 +374,7 @@ func TestSendFreeze(t *testing.T) { default: errCh <- fmt.Errorf("%s: unexpected message type %T", name, m) return - case ReceivedPacket: + case derp.ReceivedPacket: if m.Source.IsZero() { errCh <- fmt.Errorf("%s: zero Source address in ReceivedPacket", name) return @@ -468,7 +504,7 @@ func TestSendFreeze(t *testing.T) { } type testServer struct { - s *Server + s *derpserver.Server ln net.Listener logf logger.Logf @@ -508,11 +544,13 @@ func (ts *testServer) close(t *testing.T) error { return nil } +const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + func newTestServer(t *testing.T, ctx context.Context) *testServer { t.Helper() logf := logger.WithPrefix(t.Logf, "derp-server: ") - s := NewServer(key.NewNode(), logf) - s.SetMeshKey("mesh-key") + s := derpserver.New(key.NewNode(), logf) + s.SetMeshKey(testMeshKey) ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatal(err) @@ -576,7 +614,7 @@ func newTestClient(t *testing.T, ts *testServer, name string, newClient func(net func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) - c, err := NewClient(priv, nc, brw, logf) + c, err := derp.NewClient(priv, nc, brw, logf) if err != nil { return nil, err } @@ -588,8 +626,12 @@ func newRegularClient(t *testing.T, ts *testServer, name string) *testClient { func newTestWatcher(t *testing.T, ts *testServer, name string) *testClient { return newTestClient(t, ts, name, func(nc net.Conn, priv key.NodePrivate, logf logger.Logf) (*Client, error) { + mk, err := key.ParseDERPMesh(testMeshKey) + if err != nil { + return nil, err + } brw := bufio.NewReadWriter(bufio.NewReader(nc), bufio.NewWriter(nc)) - c, err := NewClient(priv, nc, brw, logf, MeshKey("mesh-key")) + c, err := derp.NewClient(priv, nc, brw, logf, derp.MeshKey(mk)) if err != nil { return nil, err } @@ -609,12 +651,12 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) { } for { - m, err := tc.c.recvTimeout(time.Second) + m, err := tc.c.RecvTimeoutForTest(time.Second) if err != nil { t.Fatal(err) } switch m := m.(type) { - case PeerPresentMessage: + case derp.PeerPresentMessage: got := m.Key if !want[got] { t.Fatalf("got peer present for %v; want present for %v", tc.ts.keyName(got), logger.ArgWriter(func(bw *bufio.Writer) { @@ -625,7 +667,7 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) { } t.Logf("got present with IP %v, flags=%v", m.IPPort, m.Flags) switch m.Flags { - case PeerPresentIsMeshPeer, PeerPresentIsRegular: + case derp.PeerPresentIsMeshPeer, derp.PeerPresentIsRegular: // Okay default: t.Errorf("unexpected PeerPresentIsMeshPeer flags %v", m.Flags) @@ -642,19 +684,19 @@ func (tc *testClient) wantPresent(t *testing.T, peers ...key.NodePublic) { func (tc *testClient) wantGone(t *testing.T, peer key.NodePublic) { t.Helper() - m, err := tc.c.recvTimeout(time.Second) + m, err := tc.c.RecvTimeoutForTest(time.Second) if err != nil { t.Fatal(err) } switch m := m.(type) { - case PeerGoneMessage: + case derp.PeerGoneMessage: got := key.NodePublic(m.Peer) if peer != got { t.Errorf("got gone message for %v; want gone for %v", tc.ts.keyName(got), tc.ts.keyName(peer)) } reason := m.Reason - if reason != PeerGoneReasonDisconnected { - t.Errorf("got gone message for reason %v; wanted %v", reason, PeerGoneReasonDisconnected) + if reason != derp.PeerGoneReasonDisconnected { + t.Errorf("got gone message for reason %v; wanted %v", reason, derp.PeerGoneReasonDisconnected) } default: t.Fatalf("unexpected message type %T", m) @@ -712,863 +754,15 @@ func TestWatch(t *testing.T) { w3.wantGone(t, c1.pub) } -type testFwd int - -func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { - panic("not called in tests") -} -func (testFwd) String() string { - panic("not called in tests") -} - -func pubAll(b byte) (ret key.NodePublic) { - var bs [32]byte - for i := range bs { - bs[i] = b - } - return key.NodePublicFromRaw32(mem.B(bs[:])) -} - -func TestForwarderRegistration(t *testing.T) { - s := &Server{ - clients: make(map[key.NodePublic]*clientSet), - clientsMesh: map[key.NodePublic]PacketForwarder{}, - } - want := func(want map[key.NodePublic]PacketForwarder) { - t.Helper() - if got := s.clientsMesh; !reflect.DeepEqual(got, want) { - t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) - } - } - wantCounter := func(c *expvar.Int, want int) { - t.Helper() - if got := c.Value(); got != int64(want) { - t.Errorf("counter = %v; want %v", got, want) - } - } - singleClient := func(c *sclient) *clientSet { - cs := &clientSet{} - cs.activeClient.Store(c) - return cs - } - - u1 := pubAll(1) - u2 := pubAll(2) - u3 := pubAll(3) - - s.AddPacketForwarder(u1, testFwd(1)) - s.AddPacketForwarder(u2, testFwd(2)) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(1), - u2: testFwd(2), - }) - - // Verify a remove of non-registered forwarder is no-op. - s.RemovePacketForwarder(u2, testFwd(999)) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(1), - u2: testFwd(2), - }) - - // Verify a remove of non-registered user is no-op. - s.RemovePacketForwarder(u3, testFwd(1)) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(1), - u2: testFwd(2), - }) - - // Actual removal. - s.RemovePacketForwarder(u2, testFwd(2)) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(1), - }) - - // Adding a dup for a user. - wantCounter(&s.multiForwarderCreated, 0) - s.AddPacketForwarder(u1, testFwd(100)) - s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path - want(map[key.NodePublic]PacketForwarder{ - u1: newMultiForwarder(testFwd(1), testFwd(100)), - }) - wantCounter(&s.multiForwarderCreated, 1) - - // Removing a forwarder in a multi set that doesn't exist; does nothing. - s.RemovePacketForwarder(u1, testFwd(55)) - want(map[key.NodePublic]PacketForwarder{ - u1: newMultiForwarder(testFwd(1), testFwd(100)), - }) - - // Removing a forwarder in a multi set that does exist should collapse it away - // from being a multiForwarder. - wantCounter(&s.multiForwarderDeleted, 0) - s.RemovePacketForwarder(u1, testFwd(1)) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(100), - }) - wantCounter(&s.multiForwarderDeleted, 1) - - // Removing an entry for a client that's still connected locally should result - // in a nil forwarder. - u1c := &sclient{ - key: u1, - logf: logger.Discard, - } - s.clients[u1] = singleClient(u1c) - s.RemovePacketForwarder(u1, testFwd(100)) - want(map[key.NodePublic]PacketForwarder{ - u1: nil, - }) - - // But once that client disconnects, it should go away. - s.unregisterClient(u1c) - want(map[key.NodePublic]PacketForwarder{}) - - // But if it already has a forwarder, it's not removed. - s.AddPacketForwarder(u1, testFwd(2)) - s.unregisterClient(u1c) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(2), - }) - - // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard - // that they're also connected to a peer of ours. That shouldn't transition the forwarder - // from nil to the new one, not a multiForwarder. - s.clients[u1] = singleClient(u1c) - s.clientsMesh[u1] = nil - want(map[key.NodePublic]PacketForwarder{ - u1: nil, - }) - s.AddPacketForwarder(u1, testFwd(3)) - want(map[key.NodePublic]PacketForwarder{ - u1: testFwd(3), - }) -} - -type channelFwd struct { - // id is to ensure that different instances that reference the - // same channel are not equal, as they are used as keys in the - // multiForwarder map. - id int - c chan []byte -} - -func (f channelFwd) String() string { return "" } -func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { - f.c <- packet - return nil -} - -func TestMultiForwarder(t *testing.T) { - received := 0 - var wg sync.WaitGroup - ch := make(chan []byte) - ctx, cancel := context.WithCancel(context.Background()) - - s := &Server{ - clients: make(map[key.NodePublic]*clientSet), - clientsMesh: map[key.NodePublic]PacketForwarder{}, - } - u := pubAll(1) - s.AddPacketForwarder(u, channelFwd{1, ch}) - - wg.Add(2) - go func() { - defer wg.Done() - for { - select { - case <-ch: - received += 1 - case <-ctx.Done(): - return - } - } - }() - go func() { - defer wg.Done() - for { - s.AddPacketForwarder(u, channelFwd{2, ch}) - s.AddPacketForwarder(u, channelFwd{3, ch}) - s.RemovePacketForwarder(u, channelFwd{2, ch}) - s.RemovePacketForwarder(u, channelFwd{1, ch}) - s.AddPacketForwarder(u, channelFwd{1, ch}) - s.RemovePacketForwarder(u, channelFwd{3, ch}) - if ctx.Err() != nil { - return - } - } - }() - - // Number of messages is chosen arbitrarily, just for this loop to - // run long enough concurrently with {Add,Remove}PacketForwarder loop above. - numMsgs := 5000 - var fwd PacketForwarder - for i := range numMsgs { - s.mu.Lock() - fwd = s.clientsMesh[u] - s.mu.Unlock() - fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) - } - - cancel() - wg.Wait() - if received != numMsgs { - t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) - } -} -func TestMetaCert(t *testing.T) { - priv := key.NewNode() - pub := priv.Public() - s := NewServer(priv, t.Logf) - - certBytes := s.MetaCert() - cert, err := x509.ParseCertificate(certBytes) - if err != nil { - log.Fatal(err) - } - if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(ProtocolVersion) { - t.Errorf("serial = %v; want %v", cert.SerialNumber, ProtocolVersion) - } - if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%s", pub.UntypedHexString()); g != w { - t.Errorf("CommonName = %q; want %q", g, w) - } - if n := len(cert.Extensions); n != 1 { - t.Fatalf("got %d extensions; want 1", n) - } - - // oidExtensionBasicConstraints is the Basic Constraints ID copied - // from the x509 package. - oidExtensionBasicConstraints := asn1.ObjectIdentifier{2, 5, 29, 19} - - if id := cert.Extensions[0].Id; !id.Equal(oidExtensionBasicConstraints) { - t.Errorf("extension ID = %v; want %v", id, oidExtensionBasicConstraints) - } -} - -type dummyNetConn struct { - net.Conn -} - -func (dummyNetConn) SetReadDeadline(time.Time) error { return nil } - -func TestClientRecv(t *testing.T) { - tests := []struct { - name string - input []byte - want any - }{ - { - name: "ping", - input: []byte{ - byte(framePing), 0, 0, 0, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8}, - }, - { - name: "pong", - input: []byte{ - byte(framePong), 0, 0, 0, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - }, - want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8}, - }, - { - name: "health_bad", - input: []byte{ - byte(frameHealth), 0, 0, 0, 3, - byte('B'), byte('A'), byte('D'), - }, - want: HealthMessage{Problem: "BAD"}, - }, - { - name: "health_ok", - input: []byte{ - byte(frameHealth), 0, 0, 0, 0, - }, - want: HealthMessage{}, - }, - { - name: "server_restarting", - input: []byte{ - byte(frameRestarting), 0, 0, 0, 8, - 0, 0, 0, 1, - 0, 0, 0, 2, - }, - want: ServerRestartingMessage{ - ReconnectIn: 1 * time.Millisecond, - TryFor: 2 * time.Millisecond, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &Client{ - nc: dummyNetConn{}, - br: bufio.NewReader(bytes.NewReader(tt.input)), - logf: t.Logf, - clock: &tstest.Clock{}, - } - got, err := c.Recv() - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("got %#v; want %#v", got, tt.want) - } - }) - } -} - -func TestClientSendPing(t *testing.T) { - var buf bytes.Buffer - c := &Client{ - bw: bufio.NewWriter(&buf), - } - if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { - t.Fatal(err) - } - want := []byte{ - byte(framePing), 0, 0, 0, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - } - if !bytes.Equal(buf.Bytes(), want) { - t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) - } -} - -func TestClientSendPong(t *testing.T) { - var buf bytes.Buffer - c := &Client{ - bw: bufio.NewWriter(&buf), - } - if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { - t.Fatal(err) - } - want := []byte{ - byte(framePong), 0, 0, 0, 8, - 1, 2, 3, 4, 5, 6, 7, 8, - } - if !bytes.Equal(buf.Bytes(), want) { - t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) - } -} - -func TestServerDupClients(t *testing.T) { - serverPriv := key.NewNode() - var s *Server - - clientPriv := key.NewNode() - clientPub := clientPriv.Public() - - var c1, c2, c3 *sclient - var clientName map[*sclient]string - - // run starts a new test case and resets clients back to their zero values. - run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { - s = NewServer(serverPriv, t.Logf) - s.dupPolicy = dupPolicy - c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} - c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} - c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")} - clientName = map[*sclient]string{ - c1: "c1", - c2: "c2", - c3: "c3", - } - t.Run(name, f) - } - runBothWays := func(name string, f func(t *testing.T)) { - run(name+"_disablefighters", disableFighters, f) - run(name+"_lastwriteractive", lastWriterIsActive, f) - } - wantSingleClient := func(t *testing.T, want *sclient) { - t.Helper() - got, ok := s.clients[want.key] - if !ok { - t.Error("no clients for key") - return - } - if got.dup != nil { - t.Errorf("unexpected dup set for single client") - } - cur := got.activeClient.Load() - if cur != want { - t.Errorf("active client = %q; want %q", clientName[cur], clientName[want]) - } - if cur != nil { - if cur.isDup.Load() { - t.Errorf("unexpected isDup on singleClient") - } - if cur.isDisabled.Load() { - t.Errorf("unexpected isDisabled on singleClient") - } - } - } - wantNoClient := func(t *testing.T) { - t.Helper() - _, ok := s.clients[clientPub] - if !ok { - // Good - return - } - t.Errorf("got client; want empty") - } - wantDupSet := func(t *testing.T) *dupClientSet { - t.Helper() - cs, ok := s.clients[clientPub] - if !ok { - t.Fatal("no set for key; want dup set") - return nil - } - if cs.dup != nil { - return cs.dup - } - t.Fatalf("no dup set for key; want dup set") - return nil - } - wantActive := func(t *testing.T, want *sclient) { - t.Helper() - set, ok := s.clients[clientPub] - if !ok { - t.Error("no set for key") - return - } - got := set.activeClient.Load() - if got != want { - t.Errorf("active client = %q; want %q", clientName[got], clientName[want]) - } - } - checkDup := func(t *testing.T, c *sclient, want bool) { - t.Helper() - if got := c.isDup.Load(); got != want { - t.Errorf("client %q isDup = %v; want %v", clientName[c], got, want) - } - } - checkDisabled := func(t *testing.T, c *sclient, want bool) { - t.Helper() - if got := c.isDisabled.Load(); got != want { - t.Errorf("client %q isDisabled = %v; want %v", clientName[c], got, want) - } - } - wantDupConns := func(t *testing.T, want int) { - t.Helper() - if got := s.dupClientConns.Value(); got != int64(want) { - t.Errorf("dupClientConns = %v; want %v", got, want) - } - } - wantDupKeys := func(t *testing.T, want int) { - t.Helper() - if got := s.dupClientKeys.Value(); got != int64(want) { - t.Errorf("dupClientKeys = %v; want %v", got, want) - } - } - - // Common case: a single client comes and goes, with no dups. - runBothWays("one_comes_and_goes", func(t *testing.T) { - wantNoClient(t) - s.registerClient(c1) - wantSingleClient(t, c1) - s.unregisterClient(c1) - wantNoClient(t) - }) - - // A still somewhat common case: a single client was - // connected and then their wifi dies or laptop closes - // or they switch networks and connect from a - // different network. They have two connections but - // it's not very bad. Only their new one is - // active. The last one, being dead, doesn't send and - // thus the new one doesn't get disabled. - runBothWays("small_overlap_replacement", func(t *testing.T) { - wantNoClient(t) - s.registerClient(c1) - wantSingleClient(t, c1) - wantActive(t, c1) - wantDupKeys(t, 0) - wantDupKeys(t, 0) - - s.registerClient(c2) // wifi dies; c2 replacement connects - wantDupSet(t) - wantDupConns(t, 2) - wantDupKeys(t, 1) - checkDup(t, c1, true) - checkDup(t, c2, true) - checkDisabled(t, c1, false) - checkDisabled(t, c2, false) - wantActive(t, c2) // sends go to the replacement - - s.unregisterClient(c1) // c1 finally times out - wantSingleClient(t, c2) - checkDup(t, c2, false) // c2 is longer a dup - wantActive(t, c2) - wantDupConns(t, 0) - wantDupKeys(t, 0) - }) - - // Key cloning situation with concurrent clients, both trying - // to write. - run("concurrent_dups_get_disabled", disableFighters, func(t *testing.T) { - wantNoClient(t) - s.registerClient(c1) - wantSingleClient(t, c1) - wantActive(t, c1) - s.registerClient(c2) - wantDupSet(t) - wantDupKeys(t, 1) - wantDupConns(t, 2) - wantActive(t, c2) - checkDup(t, c1, true) - checkDup(t, c2, true) - checkDisabled(t, c1, false) - checkDisabled(t, c2, false) - - s.noteClientActivity(c2) - checkDisabled(t, c1, false) - checkDisabled(t, c2, false) - s.noteClientActivity(c1) - checkDisabled(t, c1, true) - checkDisabled(t, c2, true) - wantActive(t, nil) - - s.registerClient(c3) - wantActive(t, c3) - checkDisabled(t, c3, false) - wantDupKeys(t, 1) - wantDupConns(t, 3) - - s.unregisterClient(c3) - wantActive(t, nil) - wantDupKeys(t, 1) - wantDupConns(t, 2) - - s.unregisterClient(c2) - wantSingleClient(t, c1) - wantDupKeys(t, 0) - wantDupConns(t, 0) - }) - - // Key cloning with an A->B->C->A series instead. - run("concurrent_dups_three_parties", disableFighters, func(t *testing.T) { - wantNoClient(t) - s.registerClient(c1) - s.registerClient(c2) - s.registerClient(c3) - s.noteClientActivity(c1) - checkDisabled(t, c1, true) - checkDisabled(t, c2, true) - checkDisabled(t, c3, true) - wantActive(t, nil) - }) - - run("activity_promotes_primary_when_nil", disableFighters, func(t *testing.T) { - wantNoClient(t) - - // Last registered client is the active one... - s.registerClient(c1) - wantActive(t, c1) - s.registerClient(c2) - wantActive(t, c2) - s.registerClient(c3) - s.noteClientActivity(c2) - wantActive(t, c3) - - // But if the last one goes away, the one with the - // most recent activity wins. - s.unregisterClient(c3) - wantActive(t, c2) - }) - - run("concurrent_dups_three_parties_last_writer", lastWriterIsActive, func(t *testing.T) { - wantNoClient(t) - - s.registerClient(c1) - wantActive(t, c1) - s.registerClient(c2) - wantActive(t, c2) - - s.noteClientActivity(c1) - checkDisabled(t, c1, false) - checkDisabled(t, c2, false) - wantActive(t, c1) - - s.noteClientActivity(c2) - checkDisabled(t, c1, false) - checkDisabled(t, c2, false) - wantActive(t, c2) - - s.unregisterClient(c2) - checkDisabled(t, c1, false) - wantActive(t, c1) - }) -} - -func TestLimiter(t *testing.T) { - rl := rate.NewLimiter(rate.Every(time.Minute), 100) - for i := range 200 { - r := rl.Reserve() - d := r.Delay() - t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d) - } -} - -// BenchmarkConcurrentStreams exercises mutex contention on a -// single Server instance with multiple concurrent client flows. -func BenchmarkConcurrentStreams(b *testing.B) { - serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, logger.Discard) - defer s.Close() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - b.Fatal(err) - } - defer ln.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - for ctx.Err() == nil { - connIn, err := ln.Accept() - if err != nil { - if ctx.Err() != nil { - return - } - b.Error(err) - return - } - - brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) - go s.Accept(ctx, connIn, brwServer, "test-client") - } - }() - - newClient := func(t testing.TB) *Client { - t.Helper() - connOut, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - b.Fatal(err) - } - t.Cleanup(func() { connOut.Close() }) - - k := key.NewNode() - - brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) - client, err := NewClient(k, connOut, brw, logger.Discard) - if err != nil { - b.Fatalf("client: %v", err) - } - return client - } - - b.RunParallel(func(pb *testing.PB) { - c1, c2 := newClient(b), newClient(b) - const packetSize = 100 - msg := make([]byte, packetSize) - for pb.Next() { - if err := c1.Send(c2.PublicKey(), msg); err != nil { - b.Fatal(err) - } - _, err := c2.Recv() - if err != nil { - return - } - } - }) -} - -func BenchmarkSendRecv(b *testing.B) { - for _, size := range []int{10, 100, 1000, 10000} { - b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) }) - } -} - -func benchmarkSendRecvSize(b *testing.B, packetSize int) { - serverPrivateKey := key.NewNode() - s := NewServer(serverPrivateKey, logger.Discard) - defer s.Close() - - k := key.NewNode() - clientKey := k.Public() - - ln, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - b.Fatal(err) - } - defer ln.Close() - - connOut, err := net.Dial("tcp", ln.Addr().String()) - if err != nil { - b.Fatal(err) - } - defer connOut.Close() - - connIn, err := ln.Accept() - if err != nil { - b.Fatal(err) - } - defer connIn.Close() - - brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - go s.Accept(ctx, connIn, brwServer, "test-client") - - brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) - client, err := NewClient(k, connOut, brw, logger.Discard) - if err != nil { - b.Fatalf("client: %v", err) - } - - go func() { - for { - _, err := client.Recv() - if err != nil { - return - } - } - }() - - msg := make([]byte, packetSize) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for range b.N { - if err := client.Send(clientKey, msg); err != nil { - b.Fatal(err) - } - } -} - -func BenchmarkWriteUint32(b *testing.B) { - w := bufio.NewWriter(io.Discard) - b.ReportAllocs() - b.ResetTimer() - for range b.N { - writeUint32(w, 0x0ba3a) - } -} - -type nopRead struct{} - -func (r nopRead) Read(p []byte) (int, error) { - return len(p), nil -} - -var sinkU32 uint32 - -func BenchmarkReadUint32(b *testing.B) { - r := bufio.NewReader(nopRead{}) - var err error - b.ReportAllocs() - b.ResetTimer() - for range b.N { - sinkU32, err = readUint32(r) - if err != nil { - b.Fatal(err) - } - } -} - func waitConnect(t testing.TB, c *Client) { t.Helper() if m, err := c.Recv(); err != nil { t.Fatalf("client first Recv: %v", err) - } else if v, ok := m.(ServerInfoMessage); !ok { + } else if v, ok := m.(derp.ServerInfoMessage); !ok { t.Fatalf("client first Recv was unexpected type %T", v) } } -func TestParseSSOutput(t *testing.T) { - contents, err := os.ReadFile("testdata/example_ss.txt") - if err != nil { - t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err) - } - seen := parseSSOutput(string(contents)) - if len(seen) == 0 { - t.Errorf("parseSSOutput expected non-empty map") - } -} - -type countWriter struct { - mu sync.Mutex - writes int - bytes int64 -} - -func (w *countWriter) Write(p []byte) (n int, err error) { - w.mu.Lock() - defer w.mu.Unlock() - w.writes++ - w.bytes += int64(len(p)) - return len(p), nil -} - -func (w *countWriter) Stats() (writes int, bytes int64) { - w.mu.Lock() - defer w.mu.Unlock() - return w.writes, w.bytes -} - -func (w *countWriter) ResetStats() { - w.mu.Lock() - defer w.mu.Unlock() - w.writes, w.bytes = 0, 0 -} - -func TestClientSendRateLimiting(t *testing.T) { - cw := new(countWriter) - c := &Client{ - bw: bufio.NewWriter(cw), - clock: &tstest.Clock{}, - } - c.setSendRateLimiter(ServerInfoMessage{}) - - pkt := make([]byte, 1000) - if err := c.send(key.NodePublic{}, pkt); err != nil { - t.Fatal(err) - } - writes1, bytes1 := cw.Stats() - if writes1 != 1 { - t.Errorf("writes = %v, want 1", writes1) - } - - // Flood should all succeed. - cw.ResetStats() - for range 1000 { - if err := c.send(key.NodePublic{}, pkt); err != nil { - t.Fatal(err) - } - } - writes1K, bytes1K := cw.Stats() - if writes1K != 1000 { - t.Logf("writes = %v; want 1000", writes1K) - } - if got, want := bytes1K, bytes1*1000; got != want { - t.Logf("bytes = %v; want %v", got, want) - } - - // Set a rate limiter - cw.ResetStats() - c.setSendRateLimiter(ServerInfoMessage{ - TokenBucketBytesPerSecond: 1, - TokenBucketBytesBurst: int(bytes1 * 2), - }) - for range 1000 { - if err := c.send(key.NodePublic{}, pkt); err != nil { - t.Fatal(err) - } - } - writesLimited, bytesLimited := cw.Stats() - if writesLimited == 0 || writesLimited == writes1K { - t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited) - } - if bytesLimited < bytes1*2 || bytesLimited >= bytes1K { - t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) - } -} - func TestServerRepliesToPing(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1585,12 +779,12 @@ func TestServerRepliesToPing(t *testing.T) { } for { - m, err := tc.c.recvTimeout(time.Second) + m, err := tc.c.RecvTimeoutForTest(time.Second) if err != nil { t.Fatal(err) } switch m := m.(type) { - case PongMessage: + case derp.PongMessage: if ([8]byte(m)) != data { t.Fatalf("got pong %2x; want %2x", [8]byte(m), data) } diff --git a/derp/derpconst/derpconst.go b/derp/derpconst/derpconst.go new file mode 100644 index 000000000..74ca09ccb --- /dev/null +++ b/derp/derpconst/derpconst.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package derpconst contains constants used by the DERP client and server. +package derpconst + +// MetaCertCommonNamePrefix is the prefix that the DERP server +// puts on for the common name of its "metacert". The suffix of +// the common name after "derpkey" is the hex key.NodePublic +// of the DERP server. +const MetaCertCommonNamePrefix = "derpkey" diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index b8cce8cdc..db56c4a44 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -30,14 +30,17 @@ import ( "go4.org/mem" "tailscale.com/derp" + "tailscale.com/derp/derpconst" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" - "tailscale.com/net/tshttpproxy" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" @@ -55,7 +58,7 @@ type Client struct { TLSConfig *tls.Config // optional; nil means default HealthTracker *health.Tracker // optional; used if non-nil only DNSCache *dnscache.Resolver // optional; nil means no caching - MeshKey string // optional; for trusted clients + MeshKey key.DERPMesh // optional; for trusted clients IsProber bool // optional; for probers to optional declare themselves as such // WatchConnectionChanges is whether the client wishes to subscribe to @@ -313,6 +316,9 @@ func (c *Client) preferIPv6() bool { var dialWebsocketFunc func(ctx context.Context, urlStr string) (net.Conn, error) func useWebsockets() bool { + if !canWebsockets { + return false + } if runtime.GOOS == "js" { return true } @@ -383,7 +389,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien var node *tailcfg.DERPNode // nil when using c.url to dial var idealNodeInRegion bool switch { - case useWebsockets(): + case canWebsockets && useWebsockets(): var urlStr string if c.url != nil { urlStr = c.url.String() @@ -498,7 +504,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien req.Header.Set("Connection", "Upgrade") if !idealNodeInRegion && reg != nil { // This is purely informative for now (2024-07-06) for stats: - req.Header.Set("Ideal-Node", reg.Nodes[0].Name) + req.Header.Set(derp.IdealNodeHeader, reg.Nodes[0].Name) // TODO(bradfitz,raggi): start a time.AfterFunc for 30m-1h or so to // dialNode(reg.Nodes[0]) and see if we can even TCP connect to it. If // so, TLS handshake it as well (which is mixed up in this massive @@ -517,7 +523,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien // just to get routed into the server's HTTP Handler so it // can Hijack the request, but we signal with a special header // that we don't want to deal with its HTTP response. - req.Header.Set(fastStartHeader, "1") // suppresses the server's HTTP response + req.Header.Set(derp.FastStartHeader, "1") // suppresses the server's HTTP response if err := req.Write(brw); err != nil { return nil, 0, err } @@ -584,7 +590,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien // // The primary use for this is the derper mesh mode to connect to each // other over a VPC network. -func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) { +func (c *Client) SetURLDialer(dialer netx.DialFunc) { c.dialer = dialer } @@ -642,14 +648,21 @@ func (c *Client) dialRegion(ctx context.Context, reg *tailcfg.DERPRegion) (net.C } func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { - tlsConf := tlsdial.Config(c.tlsServerName(node), c.HealthTracker, c.TLSConfig) + tlsConf := tlsdial.Config(c.HealthTracker, c.TLSConfig) + // node is allowed to be nil here, tlsServerName falls back to using the URL + // if node is nil. + tlsConf.ServerName = c.tlsServerName(node) if node != nil { if node.InsecureForTests { tlsConf.InsecureSkipVerify = true tlsConf.VerifyConnection = nil } if node.CertName != "" { - tlsdial.SetConfigExpectedCert(tlsConf, node.CertName) + if suf, ok := strings.CutPrefix(node.CertName, "sha256-raw:"); ok { + tlsdial.SetConfigExpectedCertHash(tlsConf, suf) + } else { + tlsdial.SetConfigExpectedCert(tlsConf, node.CertName) + } } } return tls.Client(nc, tlsConf) @@ -663,7 +676,7 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { func (c *Client) DialRegionTLS(ctx context.Context, reg *tailcfg.DERPRegion) (tlsConn *tls.Conn, connClose io.Closer, node *tailcfg.DERPNode, err error) { tcpConn, node, err := c.dialRegion(ctx, reg) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("dialRegion(%d): %w", reg.RegionID, err) } done := make(chan bool) // unbuffered defer close(done) @@ -722,8 +735,12 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e Path: "/", // unused }, } - if proxyURL, err := tshttpproxy.ProxyFromEnvironment(proxyReq); err == nil && proxyURL != nil { - return c.dialNodeUsingProxy(ctx, n, proxyURL) + if buildfeatures.HasUseProxy { + if proxyFromEnv, ok := feature.HookProxyFromEnvironment.GetOk(); ok { + if proxyURL, err := proxyFromEnv(proxyReq); err == nil && proxyURL != nil { + return c.dialNodeUsingProxy(ctx, n, proxyURL) + } + } } type res struct { @@ -738,6 +755,17 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e nwait := 0 startDial := func(dstPrimary, proto string) { + dst := cmp.Or(dstPrimary, n.HostName) + + // If dialing an IP address directly, check its address family + // and bail out before incrementing nwait. + if ip, err := netip.ParseAddr(dst); err == nil { + if proto == "tcp4" && ip.Is6() || + proto == "tcp6" && ip.Is4() { + return + } + } + nwait++ go func() { if proto == "tcp4" && c.preferIPv6() { @@ -752,8 +780,10 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e // Start v4 dial } } - dst := cmp.Or(dstPrimary, n.HostName) port := "443" + if !c.useHTTPS() { + port = "3340" + } if n.DERPPort != 0 { port = fmt.Sprint(n.DERPPort) } @@ -840,10 +870,14 @@ func (c *Client) dialNodeUsingProxy(ctx context.Context, n *tailcfg.DERPNode, pr target := net.JoinHostPort(n.HostName, "443") var authHeader string - if v, err := tshttpproxy.GetAuthHeader(pu); err != nil { - c.logf("derphttp: error getting proxy auth header for %v: %v", proxyURL, err) - } else if v != "" { - authHeader = fmt.Sprintf("Proxy-Authorization: %s\r\n", v) + if buildfeatures.HasUseProxy { + if getAuthHeader, ok := feature.HookProxyGetAuthHeader.GetOk(); ok { + if v, err := getAuthHeader(pu); err != nil { + c.logf("derphttp: error getting proxy auth header for %v: %v", proxyURL, err) + } else if v != "" { + authHeader = fmt.Sprintf("Proxy-Authorization: %s\r\n", v) + } + } } if _, err := fmt.Fprintf(proxyConn, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n%s\r\n", target, target, authHeader); err != nil { @@ -1131,7 +1165,7 @@ var ErrClientClosed = errors.New("derphttp.Client closed") func parseMetaCert(certs []*x509.Certificate) (serverPub key.NodePublic, serverProtoVersion int) { for _, cert := range certs { // Look for derpkey prefix added by initMetacert() on the server side. - if pubHex, ok := strings.CutPrefix(cert.Subject.CommonName, "derpkey"); ok { + if pubHex, ok := strings.CutPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix); ok { var err error serverPub, err = key.ParseNodePublicUntyped(mem.S(pubHex)) if err == nil && cert.SerialNumber.BitLen() <= 8 { // supports up to version 255 diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go index cfb3676cd..5208481ed 100644 --- a/derp/derphttp/derphttp_test.go +++ b/derp/derphttp/derphttp_test.go @@ -1,22 +1,35 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package derphttp +package derphttp_test import ( "bytes" "context" "crypto/tls" + "encoding/json" + "errors" + "flag" "fmt" + "maps" "net" "net/http" "net/http/httptest" + "slices" + "strings" "sync" "testing" + "testing/synctest" "time" "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" + "tailscale.com/net/netx" + "tailscale.com/tailcfg" + "tailscale.com/tstest" "tailscale.com/types/key" ) @@ -34,12 +47,12 @@ func TestSendRecv(t *testing.T) { clientKeys = append(clientKeys, priv.Public()) } - s := derp.NewServer(serverPrivateKey, t.Logf) + s := derpserver.New(serverPrivateKey, t.Logf) defer s.Close() httpsrv := &http.Server{ TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), - Handler: Handler(s), + Handler: derpserver.Handler(s), } ln, err := net.Listen("tcp4", "localhost:0") @@ -58,7 +71,7 @@ func TestSendRecv(t *testing.T) { } }() - var clients []*Client + var clients []*derphttp.Client var recvChs []chan []byte done := make(chan struct{}) var wg sync.WaitGroup @@ -71,7 +84,7 @@ func TestSendRecv(t *testing.T) { }() for i := range numClients { key := clientPrivateKeys[i] - c, err := NewClient(key, serverURL, t.Logf, netMon) + c, err := derphttp.NewClient(key, serverURL, t.Logf, netMon) if err != nil { t.Fatalf("client %d: %v", i, err) } @@ -151,7 +164,7 @@ func TestSendRecv(t *testing.T) { recvNothing(1) } -func waitConnect(t testing.TB, c *Client) { +func waitConnect(t testing.TB, c *derphttp.Client) { t.Helper() if m, err := c.Recv(); err != nil { t.Fatalf("client first Recv: %v", err) @@ -162,12 +175,12 @@ func waitConnect(t testing.TB, c *Client) { func TestPing(t *testing.T) { serverPrivateKey := key.NewNode() - s := derp.NewServer(serverPrivateKey, t.Logf) + s := derpserver.New(serverPrivateKey, t.Logf) defer s.Close() httpsrv := &http.Server{ TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), - Handler: Handler(s), + Handler: derpserver.Handler(s), } ln, err := net.Listen("tcp4", "localhost:0") @@ -186,7 +199,7 @@ func TestPing(t *testing.T) { } }() - c, err := NewClient(key.NewNode(), serverURL, t.Logf, netmon.NewStatic()) + c, err := derphttp.NewClient(key.NewNode(), serverURL, t.Logf, netmon.NewStatic()) if err != nil { t.Fatalf("NewClient: %v", err) } @@ -212,24 +225,23 @@ func TestPing(t *testing.T) { } } -func newTestServer(t *testing.T, k key.NodePrivate) (serverURL string, s *derp.Server) { - s = derp.NewServer(k, t.Logf) +const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + +func newTestServer(t *testing.T, k key.NodePrivate) (serverURL string, s *derpserver.Server, ln *memnet.Listener) { + s = derpserver.New(k, t.Logf) httpsrv := &http.Server{ TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), - Handler: Handler(s), + Handler: derpserver.Handler(s), } - ln, err := net.Listen("tcp4", "localhost:0") - if err != nil { - t.Fatal(err) - } + ln = memnet.Listen("localhost:0") + serverURL = "http://" + ln.Addr().String() - s.SetMeshKey("1234") + s.SetMeshKey(testMeshKey) go func() { if err := httpsrv.Serve(ln); err != nil { - if err == http.ErrServerClosed { - t.Logf("server closed") + if errors.Is(err, net.ErrClosed) { return } panic(err) @@ -238,194 +250,213 @@ func newTestServer(t *testing.T, k key.NodePrivate) (serverURL string, s *derp.S return } -func newWatcherClient(t *testing.T, watcherPrivateKey key.NodePrivate, serverToWatchURL string) (c *Client) { - c, err := NewClient(watcherPrivateKey, serverToWatchURL, t.Logf, netmon.NewStatic()) +func newWatcherClient(t *testing.T, watcherPrivateKey key.NodePrivate, serverToWatchURL string, ln *memnet.Listener) (c *derphttp.Client) { + c, err := derphttp.NewClient(watcherPrivateKey, serverToWatchURL, t.Logf, netmon.NewStatic()) if err != nil { t.Fatal(err) } - c.MeshKey = "1234" - return -} - -// breakConnection breaks the connection, which should trigger a reconnect. -func (c *Client) breakConnection(brokenClient *derp.Client) { - c.mu.Lock() - defer c.mu.Unlock() - if c.client != brokenClient { - return - } - if c.netConn != nil { - c.netConn.Close() - c.netConn = nil + k, err := key.ParseDERPMesh(testMeshKey) + if err != nil { + t.Fatal(err) } - c.client = nil + c.MeshKey = k + c.SetURLDialer(ln.Dial) + return } // Test that a watcher connection successfully reconnects and processes peer // updates after a different thread breaks and reconnects the connection, while // the watcher is waiting on recv(). func TestBreakWatcherConnRecv(t *testing.T) { - // Set the wait time before a retry after connection failure to be much lower. - // This needs to be early in the test, for defer to run right at the end after - // the DERP client has finished. - origRetryInterval := retryInterval - retryInterval = 50 * time.Millisecond - defer func() { retryInterval = origRetryInterval }() - - var wg sync.WaitGroup - defer wg.Wait() - // Make the watcher server - serverPrivateKey1 := key.NewNode() - _, s1 := newTestServer(t, serverPrivateKey1) - defer s1.Close() - - // Make the watched server - serverPrivateKey2 := key.NewNode() - serverURL2, s2 := newTestServer(t, serverPrivateKey2) - defer s2.Close() - - // Make the watcher (but it is not connected yet) - watcher1 := newWatcherClient(t, serverPrivateKey1, serverURL2) - defer watcher1.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + synctest.Test(t, func(t *testing.T) { + // Set the wait time before a retry after connection failure to be much lower. + // This needs to be early in the test, for defer to run right at the end after + // the DERP client has finished. + tstest.Replace(t, derphttp.RetryInterval, 50*time.Millisecond) + + var wg sync.WaitGroup + // Make the watcher server + serverPrivateKey1 := key.NewNode() + _, s1, ln1 := newTestServer(t, serverPrivateKey1) + defer s1.Close() + defer ln1.Close() + + // Make the watched server + serverPrivateKey2 := key.NewNode() + serverURL2, s2, ln2 := newTestServer(t, serverPrivateKey2) + defer s2.Close() + defer ln2.Close() + + // Make the watcher (but it is not connected yet) + watcher := newWatcherClient(t, serverPrivateKey1, serverURL2, ln2) + defer watcher.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + watcherChan := make(chan int, 1) + defer close(watcherChan) + errChan := make(chan error, 1) + + // Start the watcher thread (which connects to the watched server) + wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343 + go func() { + defer wg.Done() + var peers int + add := func(m derp.PeerPresentMessage) { + t.Logf("add: %v", m.Key.ShortString()) + peers++ + // Signal that the watcher has run + watcherChan <- peers + } + remove := func(m derp.PeerGoneMessage) { t.Logf("remove: %v", m.Peer.ShortString()); peers-- } + notifyErr := func(err error) { + select { + case errChan <- err: + case <-ctx.Done(): + } + } - watcherChan := make(chan int, 1) + watcher.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove, notifyErr) + }() - // Start the watcher thread (which connects to the watched server) - wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343 - go func() { - defer wg.Done() - var peers int - add := func(m derp.PeerPresentMessage) { - t.Logf("add: %v", m.Key.ShortString()) - peers++ - // Signal that the watcher has run - watcherChan <- peers - } - remove := func(m derp.PeerGoneMessage) { t.Logf("remove: %v", m.Peer.ShortString()); peers-- } + synctest.Wait() - watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove) - }() + // Wait for the watcher to run, then break the connection and check if it + // reconnected and received peer updates. + for range 10 { + select { + case peers := <-watcherChan: + if peers != 1 { + t.Fatalf("wrong number of peers added during watcher connection: have %d, want 1", peers) + } + case err := <-errChan: + if err.Error() != "derp.Recv: EOF" { + t.Fatalf("expected notifyError connection error to be EOF, got %v", err) + } + } - timer := time.NewTimer(5 * time.Second) - defer timer.Stop() + synctest.Wait() - // Wait for the watcher to run, then break the connection and check if it - // reconnected and received peer updates. - for range 10 { - select { - case peers := <-watcherChan: - if peers != 1 { - t.Fatal("wrong number of peers added during watcher connection") - } - case <-timer.C: - t.Fatalf("watcher did not process the peer update") + watcher.BreakConnection(watcher) + // re-establish connection by sending a packet + watcher.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus")) } - watcher1.breakConnection(watcher1.client) - // re-establish connection by sending a packet - watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus")) - - timer.Reset(5 * time.Second) - } + cancel() // Cancel the context to stop the watcher loop. + wg.Wait() + }) } // Test that a watcher connection successfully reconnects and processes peer // updates after a different thread breaks and reconnects the connection, while // the watcher is not waiting on recv(). func TestBreakWatcherConn(t *testing.T) { - // Set the wait time before a retry after connection failure to be much lower. - // This needs to be early in the test, for defer to run right at the end after - // the DERP client has finished. - origRetryInterval := retryInterval - retryInterval = 50 * time.Millisecond - defer func() { retryInterval = origRetryInterval }() - - var wg sync.WaitGroup - defer wg.Wait() - // Make the watcher server - serverPrivateKey1 := key.NewNode() - _, s1 := newTestServer(t, serverPrivateKey1) - defer s1.Close() - - // Make the watched server - serverPrivateKey2 := key.NewNode() - serverURL2, s2 := newTestServer(t, serverPrivateKey2) - defer s2.Close() - - // Make the watcher (but it is not connected yet) - watcher1 := newWatcherClient(t, serverPrivateKey1, serverURL2) - defer watcher1.Close() - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + synctest.Test(t, func(t *testing.T) { + // Set the wait time before a retry after connection failure to be much lower. + // This needs to be early in the test, for defer to run right at the end after + // the DERP client has finished. + tstest.Replace(t, derphttp.RetryInterval, 50*time.Millisecond) + + var wg sync.WaitGroup + // Make the watcher server + serverPrivateKey1 := key.NewNode() + _, s1, ln1 := newTestServer(t, serverPrivateKey1) + defer s1.Close() + defer ln1.Close() + + // Make the watched server + serverPrivateKey2 := key.NewNode() + serverURL2, s2, ln2 := newTestServer(t, serverPrivateKey2) + defer s2.Close() + defer ln2.Close() + + // Make the watcher (but it is not connected yet) + watcher1 := newWatcherClient(t, serverPrivateKey1, serverURL2, ln2) + defer watcher1.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + watcherChan := make(chan int, 1) + breakerChan := make(chan bool, 1) + errorChan := make(chan error, 1) + + // Start the watcher thread (which connects to the watched server) + wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343 + go func() { + defer wg.Done() + var peers int + add := func(m derp.PeerPresentMessage) { + t.Logf("add: %v", m.Key.ShortString()) + peers++ + // Signal that the watcher has run + watcherChan <- peers + select { + case <-ctx.Done(): + return + // Wait for breaker to run + case <-breakerChan: + } + } + remove := func(m derp.PeerGoneMessage) { t.Logf("remove: %v", m.Peer.ShortString()); peers-- } + notifyError := func(err error) { + errorChan <- err + } - watcherChan := make(chan int, 1) - breakerChan := make(chan bool, 1) + watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove, notifyError) + }() - // Start the watcher thread (which connects to the watched server) - wg.Add(1) // To avoid using t.Logf after the test ends. See https://golang.org/issue/40343 - go func() { - defer wg.Done() - var peers int - add := func(m derp.PeerPresentMessage) { - t.Logf("add: %v", m.Key.ShortString()) - peers++ - // Signal that the watcher has run - watcherChan <- peers - // Wait for breaker to run - <-breakerChan - } - remove := func(m derp.PeerGoneMessage) { t.Logf("remove: %v", m.Peer.ShortString()); peers-- } + synctest.Wait() - watcher1.RunWatchConnectionLoop(ctx, serverPrivateKey1.Public(), t.Logf, add, remove) - }() + // Wait for the watcher to run, then break the connection and check if it + // reconnected and received peer updates. + for range 10 { + select { + case peers := <-watcherChan: + if peers != 1 { + t.Fatalf("wrong number of peers added during watcher connection have %d, want 1", peers) + } + case err := <-errorChan: + if !errors.Is(err, net.ErrClosed) { + t.Fatalf("expected notifyError connection error to fail with ErrClosed, got %v", err) + } + } - timer := time.NewTimer(5 * time.Second) - defer timer.Stop() + synctest.Wait() - // Wait for the watcher to run, then break the connection and check if it - // reconnected and received peer updates. - for range 10 { - select { - case peers := <-watcherChan: - if peers != 1 { - t.Fatal("wrong number of peers added during watcher connection") - } - case <-timer.C: - t.Fatalf("watcher did not process the peer update") + watcher1.BreakConnection(watcher1) + // re-establish connection by sending a packet + watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus")) + // signal that the breaker is done + breakerChan <- true } - watcher1.breakConnection(watcher1.client) - // re-establish connection by sending a packet - watcher1.ForwardPacket(key.NodePublic{}, key.NodePublic{}, []byte("bogus")) - // signal that the breaker is done - breakerChan <- true - - timer.Reset(5 * time.Second) - } + watcher1.Close() + cancel() + wg.Wait() + }) } func noopAdd(derp.PeerPresentMessage) {} func noopRemove(derp.PeerGoneMessage) {} +func noopNotifyError(error) {} func TestRunWatchConnectionLoopServeConnect(t *testing.T) { - defer func() { testHookWatchLookConnectResult = nil }() + defer derphttp.SetTestHookWatchLookConnectResult(nil) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() priv := key.NewNode() - serverURL, s := newTestServer(t, priv) + serverURL, s, ln := newTestServer(t, priv) defer s.Close() + defer ln.Close() pub := priv.Public() - watcher := newWatcherClient(t, priv, serverURL) + watcher := newWatcherClient(t, priv, serverURL, ln) defer watcher.Close() // Test connecting to ourselves, and that we get hung up on. - testHookWatchLookConnectResult = func(err error, wasSelfConnect bool) bool { + derphttp.SetTestHookWatchLookConnectResult(func(err error, wasSelfConnect bool) bool { t.Helper() if err != nil { t.Fatalf("error connecting to server: %v", err) @@ -434,12 +465,12 @@ func TestRunWatchConnectionLoopServeConnect(t *testing.T) { t.Error("wanted self-connect; wasn't") } return false - } - watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove) + }) + watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove, noopNotifyError) // Test connecting to the server with a zero value for ignoreServerKey, // so we should always connect. - testHookWatchLookConnectResult = func(err error, wasSelfConnect bool) bool { + derphttp.SetTestHookWatchLookConnectResult(func(err error, wasSelfConnect bool) bool { t.Helper() if err != nil { t.Fatalf("error connecting to server: %v", err) @@ -448,16 +479,14 @@ func TestRunWatchConnectionLoopServeConnect(t *testing.T) { t.Error("wanted normal connect; got self connect") } return false - } - watcher.RunWatchConnectionLoop(ctx, key.NodePublic{}, t.Logf, noopAdd, noopRemove) + }) + watcher.RunWatchConnectionLoop(ctx, key.NodePublic{}, t.Logf, noopAdd, noopRemove, noopNotifyError) } // verify that the LocalAddr method doesn't acquire the mutex. // See https://github.com/tailscale/tailscale/issues/11519 func TestLocalAddrNoMutex(t *testing.T) { - var c Client - c.mu.Lock() - defer c.mu.Unlock() // not needed in test but for symmetry + var c derphttp.Client _, err := c.LocalAddr() if got, want := fmt.Sprint(err), "client not connected"; got != want { @@ -466,7 +495,7 @@ func TestLocalAddrNoMutex(t *testing.T) { } func TestProbe(t *testing.T) { - h := Handler(nil) + h := derpserver.Handler(nil) tests := []struct { path string @@ -485,3 +514,118 @@ func TestProbe(t *testing.T) { } } } + +func TestNotifyError(t *testing.T) { + defer derphttp.SetTestHookWatchLookConnectResult(nil) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + priv := key.NewNode() + serverURL, s, ln := newTestServer(t, priv) + defer s.Close() + defer ln.Close() + + pub := priv.Public() + + // Test early error notification when c.connect fails. + watcher := newWatcherClient(t, priv, serverURL, ln) + watcher.SetURLDialer(netx.DialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) { + t.Helper() + return nil, fmt.Errorf("test error: %s", addr) + })) + defer watcher.Close() + + derphttp.SetTestHookWatchLookConnectResult(func(err error, wasSelfConnect bool) bool { + t.Helper() + if err == nil { + t.Fatal("expected error connecting to server, got nil") + } + if wasSelfConnect { + t.Error("wanted normal connect; got self connect") + } + return false + }) + + errChan := make(chan error, 1) + notifyError := func(err error) { + errChan <- err + } + watcher.RunWatchConnectionLoop(ctx, pub, t.Logf, noopAdd, noopRemove, notifyError) + + select { + case err := <-errChan: + if !strings.Contains(err.Error(), "test") { + t.Errorf("expected test error, got %v", err) + } + case <-ctx.Done(): + t.Fatalf("context done before receiving error: %v", ctx.Err()) + } +} + +var liveNetworkTest = flag.Bool("live-net-tests", false, "run live network tests") + +func TestManualDial(t *testing.T) { + if !*liveNetworkTest { + t.Skip("skipping live network test without --live-net-tests") + } + dm := &tailcfg.DERPMap{} + res, err := http.Get("https://controlplane.tailscale.com/derpmap/default") + if err != nil { + t.Fatalf("fetching DERPMap: %v", err) + } + defer res.Body.Close() + if err := json.NewDecoder(res.Body).Decode(dm); err != nil { + t.Fatalf("decoding DERPMap: %v", err) + } + + region := slices.Sorted(maps.Keys(dm.Regions))[0] + + netMon := netmon.NewStatic() + rc := derphttp.NewRegionClient(key.NewNode(), t.Logf, netMon, func() *tailcfg.DERPRegion { + return dm.Regions[region] + }) + defer rc.Close() + + if err := rc.Connect(context.Background()); err != nil { + t.Fatalf("rc.Connect: %v", err) + } +} + +func TestURLDial(t *testing.T) { + if !*liveNetworkTest { + t.Skip("skipping live network test without --live-net-tests") + } + dm := &tailcfg.DERPMap{} + res, err := http.Get("https://controlplane.tailscale.com/derpmap/default") + if err != nil { + t.Fatalf("fetching DERPMap: %v", err) + } + defer res.Body.Close() + if err := json.NewDecoder(res.Body).Decode(dm); err != nil { + t.Fatalf("decoding DERPMap: %v", err) + } + + // find a valid target DERP host to test against + var hostname string + for _, reg := range dm.Regions { + for _, node := range reg.Nodes { + if !node.STUNOnly && node.CanPort80 && node.CertName == "" || node.CertName == node.HostName { + hostname = node.HostName + break + } + } + if hostname != "" { + break + } + } + netMon := netmon.NewStatic() + c, err := derphttp.NewClient(key.NewNode(), "https://"+hostname+"/", t.Logf, netMon) + if err != nil { + t.Errorf("NewClient: %v", err) + } + defer c.Close() + + if err := c.Connect(context.Background()); err != nil { + t.Fatalf("rc.Connect: %v", err) + } +} diff --git a/derp/derphttp/export_test.go b/derp/derphttp/export_test.go new file mode 100644 index 000000000..59d8324dc --- /dev/null +++ b/derp/derphttp/export_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package derphttp + +func SetTestHookWatchLookConnectResult(f func(connectError error, wasSelfConnect bool) (keepRunning bool)) { + testHookWatchLookConnectResult = f +} + +// breakConnection breaks the connection, which should trigger a reconnect. +func (c *Client) BreakConnection(brokenClient *Client) { + c.mu.Lock() + defer c.mu.Unlock() + if c.client != brokenClient.client { + return + } + if c.netConn != nil { + c.netConn.Close() + c.netConn = nil + } + c.client = nil +} + +var RetryInterval = &retryInterval diff --git a/derp/derphttp/mesh_client.go b/derp/derphttp/mesh_client.go index 66b8c166e..c14a9a7e1 100644 --- a/derp/derphttp/mesh_client.go +++ b/derp/derphttp/mesh_client.go @@ -31,6 +31,9 @@ var testHookWatchLookConnectResult func(connectError error, wasSelfConnect bool) // This behavior will likely change. Callers should do their own accounting // and dup suppression as needed. // +// If set the notifyError func is called with any error that occurs within the ctx +// main loop connection setup, or the inner loop receiving messages via RecvDetail. +// // infoLogf, if non-nil, is the logger to write periodic status updates about // how many peers are on the server. Error log output is set to the c's logger, // regardless of infoLogf's value. @@ -42,10 +45,11 @@ var testHookWatchLookConnectResult func(connectError error, wasSelfConnect bool) // initialized Client.WatchConnectionChanges to true. // // If the DERP connection breaks and reconnects, remove will be called for all -// previously seen peers, with Reason type PeerGoneReasonSynthetic. Those +// previously seen peers, with Reason type PeerGoneReasonMeshConnBroke. Those // clients are likely still connected and their add message will appear after // reconnect. -func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, add func(derp.PeerPresentMessage), remove func(derp.PeerGoneMessage)) { +func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key.NodePublic, infoLogf logger.Logf, + add func(derp.PeerPresentMessage), remove func(derp.PeerGoneMessage), notifyError func(error)) { if !c.WatchConnectionChanges { if c.isStarted() { panic("invalid use of RunWatchConnectionLoop on already-started Client without setting Client.RunWatchConnectionLoop") @@ -121,6 +125,10 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key // Make sure we're connected before calling s.ServerPublicKey. _, _, err := c.connect(ctx, "RunWatchConnectionLoop") if err != nil { + logf("mesh connect: %v", err) + if notifyError != nil { + notifyError(err) + } if f := testHookWatchLookConnectResult; f != nil && !f(err, false) { return } @@ -141,6 +149,9 @@ func (c *Client) RunWatchConnectionLoop(ctx context.Context, ignoreServerKey key if err != nil { clear() logf("Recv: %v", err) + if notifyError != nil { + notifyError(err) + } sleep(retryInterval) break } diff --git a/derp/derphttp/websocket.go b/derp/derphttp/websocket.go index 6ef47473a..9dd640ee3 100644 --- a/derp/derphttp/websocket.go +++ b/derp/derphttp/websocket.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || js +//go:build js || ((linux || darwin) && ts_debug_websockets) package derphttp @@ -14,6 +14,8 @@ import ( "tailscale.com/net/wsconn" ) +const canWebsockets = true + func init() { dialWebsocketFunc = dialWebsocket } diff --git a/derp/derphttp/websocket_stub.go b/derp/derphttp/websocket_stub.go new file mode 100644 index 000000000..d84bfba57 --- /dev/null +++ b/derp/derphttp/websocket_stub.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(js || ((linux || darwin) && ts_debug_websockets)) + +package derphttp + +const canWebsockets = false diff --git a/derp/derp_server.go b/derp/derpserver/derpserver.go similarity index 78% rename from derp/derp_server.go rename to derp/derpserver/derpserver.go index 2e17cbfe5..0bbc66780 100644 --- a/derp/derp_server.go +++ b/derp/derpserver/derpserver.go @@ -1,7 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package derp +// Package derpserver implements a DERP server. +package derpserver // TODO(crawshaw): with predefined serverKey in clients and HMAC on packets we could skip TLS @@ -11,6 +12,7 @@ import ( "context" "crypto/ed25519" crand "crypto/rand" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/binary" @@ -23,9 +25,9 @@ import ( "math" "math/big" "math/rand/v2" - "net" "net/http" "net/netip" + "os" "os/exec" "runtime" "strconv" @@ -36,7 +38,9 @@ import ( "go4.org/mem" "golang.org/x/sync/errgroup" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" + "tailscale.com/derp" + "tailscale.com/derp/derpconst" "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/metrics" @@ -46,6 +50,7 @@ import ( "tailscale.com/tstime/rate" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/ctxkey" "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/slicesx" @@ -56,6 +61,10 @@ import ( // verbosely log whenever DERP drops a packet. var verboseDropKeys = map[key.NodePublic]bool{} +// IdealNodeContextKey is the context key used to pass the IdealNodeHeader value +// from the HTTP handler to the DERP server's Accept method. +var IdealNodeContextKey = ctxkey.New("ideal-node", "") + func init() { keys := envknob.String("TS_DEBUG_VERBOSE_DROPS") if keys == "" { @@ -72,10 +81,19 @@ func init() { } const ( - perClientSendQueueDepth = 32 // packets buffered for sending - writeTimeout = 2 * time.Second + defaultPerClientSendQueueDepth = 32 // default packets buffered for sending + DefaultTCPWiteTimeout = 2 * time.Second + privilegedWriteTimeout = 30 * time.Second // for clients with the mesh key ) +func getPerClientSendQueueDepth() int { + if v, ok := envknob.LookupInt("TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH"); ok { + return v + } + + return defaultPerClientSendQueueDepth +} + // dupPolicy is a temporary (2021-08-30) mechanism to change the policy // of how duplicate connection for the same key are handled. type dupPolicy int8 @@ -91,6 +109,14 @@ const ( disableFighters ) +// packetKind is the kind of packet being sent through DERP +type packetKind string + +const ( + packetKindDisco packetKind = "disco" + packetKindOther packetKind = "other" +) + type align64 [0]atomic.Int64 // for side effect of its 64-bit alignment // Server is a DERP server. @@ -103,48 +129,45 @@ type Server struct { publicKey key.NodePublic logf logger.Logf memSys0 uint64 // runtime.MemStats.Sys at start (or early-ish) - meshKey string + meshKey key.DERPMesh limitedLogf logger.Logf metaCert []byte // the encoded x509 cert to send after LetsEncrypt cert+intermediate dupPolicy dupPolicy debug bool + localClient local.Client // Counters: - packetsSent, bytesSent expvar.Int - packetsRecv, bytesRecv expvar.Int - packetsRecvByKind metrics.LabelMap - packetsRecvDisco *expvar.Int - packetsRecvOther *expvar.Int - _ align64 - packetsDropped expvar.Int - packetsDroppedReason metrics.LabelMap - packetsDroppedReasonCounters []*expvar.Int // indexed by dropReason - packetsDroppedType metrics.LabelMap - packetsDroppedTypeDisco *expvar.Int - packetsDroppedTypeOther *expvar.Int - _ align64 - packetsForwardedOut expvar.Int - packetsForwardedIn expvar.Int - peerGoneDisconnectedFrames expvar.Int // number of peer disconnected frames sent - peerGoneNotHereFrames expvar.Int // number of peer not here frames sent - gotPing expvar.Int // number of ping frames from client - sentPong expvar.Int // number of pong frames enqueued to client - accepts expvar.Int - curClients expvar.Int - curHomeClients expvar.Int // ones with preferred - dupClientKeys expvar.Int // current number of public keys we have 2+ connections for - dupClientConns expvar.Int // current number of connections sharing a public key - dupClientConnTotal expvar.Int // total number of accepted connections when a dup key existed - unknownFrames expvar.Int - homeMovesIn expvar.Int // established clients announce home server moves in - homeMovesOut expvar.Int // established clients announce home server moves out - multiForwarderCreated expvar.Int - multiForwarderDeleted expvar.Int - removePktForwardOther expvar.Int - avgQueueDuration *uint64 // In milliseconds; accessed atomically - tcpRtt metrics.LabelMap // histogram - meshUpdateBatchSize *metrics.Histogram - meshUpdateLoopCount *metrics.Histogram + packetsSent, bytesSent expvar.Int + packetsRecv, bytesRecv expvar.Int + packetsRecvByKind metrics.LabelMap + packetsRecvDisco *expvar.Int + packetsRecvOther *expvar.Int + _ align64 + packetsForwardedOut expvar.Int + packetsForwardedIn expvar.Int + peerGoneDisconnectedFrames expvar.Int // number of peer disconnected frames sent + peerGoneNotHereFrames expvar.Int // number of peer not here frames sent + gotPing expvar.Int // number of ping frames from client + sentPong expvar.Int // number of pong frames enqueued to client + accepts expvar.Int + curClients expvar.Int + curClientsNotIdeal expvar.Int + curHomeClients expvar.Int // ones with preferred + dupClientKeys expvar.Int // current number of public keys we have 2+ connections for + dupClientConns expvar.Int // current number of connections sharing a public key + dupClientConnTotal expvar.Int // total number of accepted connections when a dup key existed + unknownFrames expvar.Int + homeMovesIn expvar.Int // established clients announce home server moves in + homeMovesOut expvar.Int // established clients announce home server moves out + multiForwarderCreated expvar.Int + multiForwarderDeleted expvar.Int + removePktForwardOther expvar.Int + sclientWriteTimeouts expvar.Int + avgQueueDuration *uint64 // In milliseconds; accessed atomically + tcpRtt metrics.LabelMap // histogram + meshUpdateBatchSize *metrics.Histogram + meshUpdateLoopCount *metrics.Histogram + bufferedWriteFrames *metrics.Histogram // how many sendLoop frames (or groups of related frames) get written per flush // verifyClientsLocalTailscaled only accepts client connections to the DERP // server if the clientKey is a known peer in the network, as specified by a @@ -154,9 +177,9 @@ type Server struct { verifyClientsURL string verifyClientsURLFailOpen bool - mu sync.Mutex + mu syncs.Mutex closed bool - netConns map[Conn]chan struct{} // chan is closed when conn closes + netConns map[derp.Conn]chan struct{} // chan is closed when conn closes clients map[key.NodePublic]*clientSet watchers set.Set[*sclient] // mesh peers // clientsMesh tracks all clients in the cluster, both locally @@ -174,6 +197,11 @@ type Server struct { // maps from netip.AddrPort to a client's public key keyOfAddr map[netip.AddrPort]key.NodePublic + // Sets the client send queue depth for the server. + perClientSendQueueDepth int + + tcpWriteTimeout time.Duration + clock tstime.Clock } @@ -313,84 +341,131 @@ type PacketForwarder interface { String() string } -// Conn is the subset of the underlying net.Conn the DERP Server needs. -// It is a defined type so that non-net connections can be used. -type Conn interface { - io.WriteCloser - LocalAddr() net.Addr - // The *Deadline methods follow the semantics of net.Conn. - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} +var packetsDropped = metrics.NewMultiLabelMap[dropReasonKindLabels]( + "derp_packets_dropped", + "counter", + "DERP packets dropped by reason and by kind") -// NewServer returns a new DERP server. It doesn't listen on its own. +var bytesDropped = metrics.NewMultiLabelMap[dropReasonKindLabels]( + "derp_bytes_dropped", + "counter", + "DERP bytes dropped by reason and by kind", +) + +// New returns a new DERP server. It doesn't listen on its own. // Connections are given to it via Server.Accept. -func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { +func New(privateKey key.NodePrivate, logf logger.Logf) *Server { var ms runtime.MemStats runtime.ReadMemStats(&ms) s := &Server{ - debug: envknob.Bool("DERP_DEBUG_LOGS"), - privateKey: privateKey, - publicKey: privateKey.Public(), - logf: logf, - limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), - packetsRecvByKind: metrics.LabelMap{Label: "kind"}, - packetsDroppedReason: metrics.LabelMap{Label: "reason"}, - packetsDroppedType: metrics.LabelMap{Label: "type"}, - clients: map[key.NodePublic]*clientSet{}, - clientsMesh: map[key.NodePublic]PacketForwarder{}, - netConns: map[Conn]chan struct{}{}, - memSys0: ms.Sys, - watchers: set.Set[*sclient]{}, - peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{}, - avgQueueDuration: new(uint64), - tcpRtt: metrics.LabelMap{Label: "le"}, - meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}), - meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}), - keyOfAddr: map[netip.AddrPort]key.NodePublic{}, - clock: tstime.StdClock{}, + debug: envknob.Bool("DERP_DEBUG_LOGS"), + privateKey: privateKey, + publicKey: privateKey.Public(), + logf: logf, + limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), + packetsRecvByKind: metrics.LabelMap{Label: "kind"}, + clients: map[key.NodePublic]*clientSet{}, + clientsMesh: map[key.NodePublic]PacketForwarder{}, + netConns: map[derp.Conn]chan struct{}{}, + memSys0: ms.Sys, + watchers: set.Set[*sclient]{}, + peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{}, + avgQueueDuration: new(uint64), + tcpRtt: metrics.LabelMap{Label: "le"}, + meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}), + meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}), + bufferedWriteFrames: metrics.NewHistogram([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 100}), + keyOfAddr: map[netip.AddrPort]key.NodePublic{}, + clock: tstime.StdClock{}, + tcpWriteTimeout: DefaultTCPWiteTimeout, } s.initMetacert() - s.packetsRecvDisco = s.packetsRecvByKind.Get("disco") - s.packetsRecvOther = s.packetsRecvByKind.Get("other") + s.packetsRecvDisco = s.packetsRecvByKind.Get(string(packetKindDisco)) + s.packetsRecvOther = s.packetsRecvByKind.Get(string(packetKindOther)) - s.packetsDroppedReasonCounters = s.genPacketsDroppedReasonCounters() + genDroppedCounters() - s.packetsDroppedTypeDisco = s.packetsDroppedType.Get("disco") - s.packetsDroppedTypeOther = s.packetsDroppedType.Get("other") + s.perClientSendQueueDepth = getPerClientSendQueueDepth() return s } -func (s *Server) genPacketsDroppedReasonCounters() []*expvar.Int { - getMetric := s.packetsDroppedReason.Get - ret := []*expvar.Int{ - dropReasonUnknownDest: getMetric("unknown_dest"), - dropReasonUnknownDestOnFwd: getMetric("unknown_dest_on_fwd"), - dropReasonGoneDisconnected: getMetric("gone_disconnected"), - dropReasonQueueHead: getMetric("queue_head"), - dropReasonQueueTail: getMetric("queue_tail"), - dropReasonWriteError: getMetric("write_error"), - dropReasonDupClient: getMetric("dup_client"), +func genDroppedCounters() { + initMetrics := func(reason dropReason) { + packetsDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }, 0) + packetsDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }, 0) + bytesDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }, 0) + bytesDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }, 0) + } + getMetrics := func(reason dropReason) []expvar.Var { + return []expvar.Var{ + packetsDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }), + packetsDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }), + bytesDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }), + bytesDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }), + } } - if len(ret) != int(numDropReasons) { - panic("dropReason metrics out of sync") + + dropReasons := []dropReason{ + dropReasonUnknownDest, + dropReasonUnknownDestOnFwd, + dropReasonGoneDisconnected, + dropReasonQueueHead, + dropReasonQueueTail, + dropReasonWriteError, + dropReasonDupClient, } - for i := range numDropReasons { - if ret[i] == nil { + + for _, dr := range dropReasons { + initMetrics(dr) + m := getMetrics(dr) + if len(m) != 4 { panic("dropReason metrics out of sync") } + + for _, v := range m { + if v == nil { + panic("dropReason metrics out of sync") + } + } } - return ret } // SetMesh sets the pre-shared key that regional DERP servers used to mesh // amongst themselves. // // It must be called before serving begins. -func (s *Server) SetMeshKey(v string) { - s.meshKey = v +func (s *Server) SetMeshKey(v string) error { + k, err := key.ParseDERPMesh(v) + if err != nil { + return err + } + s.meshKey = k + return nil } // SetVerifyClients sets whether this DERP server verifies clients through tailscaled. @@ -413,11 +488,28 @@ func (s *Server) SetVerifyClientURLFailOpen(v bool) { s.verifyClientsURLFailOpen = v } +// SetTailscaledSocketPath sets the unix socket path to use to talk to +// tailscaled if client verification is enabled. +// +// If unset or set to the empty string, the default path for the operating +// system is used. +func (s *Server) SetTailscaledSocketPath(path string) { + s.localClient.Socket = path + s.localClient.UseSocketOnly = path != "" +} + +// SetTCPWriteTimeout sets the timeout for writing to connected clients. +// This timeout does not apply to mesh connections. +// Defaults to 2 seconds. +func (s *Server) SetTCPWriteTimeout(d time.Duration) { + s.tcpWriteTimeout = d +} + // HasMeshKey reports whether the server is configured with a mesh key. -func (s *Server) HasMeshKey() bool { return s.meshKey != "" } +func (s *Server) HasMeshKey() bool { return !s.meshKey.IsZero() } // MeshKey returns the configured mesh key, if any. -func (s *Server) MeshKey() string { return s.meshKey } +func (s *Server) MeshKey() key.DERPMesh { return s.meshKey } // PrivateKey returns the server's private key. func (s *Server) PrivateKey() key.NodePrivate { return s.privateKey } @@ -476,7 +568,7 @@ func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool { // on its own. // // Accept closes nc. -func (s *Server) Accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, remoteAddr string) { +func (s *Server) Accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter, remoteAddr string) { closed := make(chan struct{}) s.mu.Lock() @@ -524,9 +616,9 @@ func (s *Server) initMetacert() { log.Fatal(err) } tmpl := &x509.Certificate{ - SerialNumber: big.NewInt(ProtocolVersion), + SerialNumber: big.NewInt(derp.ProtocolVersion), Subject: pkix.Name{ - CommonName: fmt.Sprintf("derpkey%s", s.publicKey.UntypedHexString()), + CommonName: derpconst.MetaCertCommonNamePrefix + s.publicKey.UntypedHexString(), }, // Windows requires NotAfter and NotBefore set: NotAfter: s.clock.Now().Add(30 * 24 * time.Hour), @@ -546,6 +638,25 @@ func (s *Server) initMetacert() { // TLS server to let the client skip a round trip during start-up. func (s *Server) MetaCert() []byte { return s.metaCert } +// ModifyTLSConfigToAddMetaCert modifies c.GetCertificate to make +// it append s.MetaCert to the returned certificates. +// +// It panics if c or c.GetCertificate is nil. +func (s *Server) ModifyTLSConfigToAddMetaCert(c *tls.Config) { + getCert := c.GetCertificate + if getCert == nil { + panic("c.GetCertificate is nil") + } + c.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := getCert(hi) + if err != nil { + return nil, err + } + cert.Certificate = append(cert.Certificate, s.MetaCert()) + return cert, nil + } +} + // registerClient notes that client c is now authenticated and ready for packets. // // If c.key is connected more than once, the earlier connection(s) are @@ -598,6 +709,9 @@ func (s *Server) registerClient(c *sclient) { } s.keyOfAddr[c.remoteIPPort] = c.key s.curClients.Add(1) + if c.isNotIdealConn { + s.curClientsNotIdeal.Add(1) + } s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, c.presentFlags(), true) } @@ -606,7 +720,7 @@ func (s *Server) registerClient(c *sclient) { // presence changed. // // s.mu must be held. -func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, flags PeerPresentFlags, present bool) { +func (s *Server) broadcastPeerStateChangeLocked(peer key.NodePublic, ipPort netip.AddrPort, flags derp.PeerPresentFlags, present bool) { for w := range s.watchers { w.peerStateChange = append(w.peerStateChange, peerConnState{ peer: peer, @@ -688,6 +802,9 @@ func (s *Server) unregisterClient(c *sclient) { if c.preferred { s.curHomeClients.Add(-1) } + if c.isNotIdealConn { + s.curClientsNotIdeal.Add(-1) + } } // addPeerGoneFromRegionWatcher adds a function to be called when peer is gone @@ -747,7 +864,7 @@ func (s *Server) notePeerGoneFromRegionLocked(key key.NodePublic) { // requestPeerGoneWriteLimited sends a request to write a "peer gone" // frame, but only in reply to a disco packet, and only if we haven't // sent one recently. -func (c *sclient) requestPeerGoneWriteLimited(peer key.NodePublic, contents []byte, reason PeerGoneReasonType) { +func (c *sclient) requestPeerGoneWriteLimited(peer key.NodePublic, contents []byte, reason derp.PeerGoneReasonType) { if disco.LooksLikeDiscoWrapper(contents) != true { return } @@ -791,7 +908,7 @@ func (s *Server) addWatcher(c *sclient) { go c.requestMeshUpdate() } -func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, remoteAddr string, connNum int64) error { +func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter, remoteAddr string, connNum int64) error { br := brw.Reader nc.SetDeadline(time.Now().Add(10 * time.Second)) bw := &lazyBufioWriter{w: nc, lbw: brw.Writer} @@ -804,8 +921,8 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem return fmt.Errorf("receive client key: %v", err) } - clientAP, _ := netip.ParseAddrPort(remoteAddr) - if err := s.verifyClient(ctx, clientKey, clientInfo, clientAP.Addr()); err != nil { + remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) + if err := s.verifyClient(ctx, clientKey, clientInfo, remoteIPPort.Addr()); err != nil { return fmt.Errorf("client %v rejected: %v", clientKey, err) } @@ -815,8 +932,6 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem ctx, cancel := context.WithCancel(ctx) defer cancel() - remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) - c := &sclient{ connNum: connNum, s: s, @@ -828,11 +943,12 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem done: ctx.Done(), remoteIPPort: remoteIPPort, connectedAt: s.clock.Now(), - sendQueue: make(chan pkt, perClientSendQueueDepth), - discoSendQueue: make(chan pkt, perClientSendQueueDepth), + sendQueue: make(chan pkt, s.perClientSendQueueDepth), + discoSendQueue: make(chan pkt, s.perClientSendQueueDepth), sendPongCh: make(chan [8]byte, 1), peerGone: make(chan peerGoneMsg), canMesh: s.isMeshPeer(clientInfo), + isNotIdealConn: IdealNodeContextKey.Value(ctx) != "", peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3), } @@ -879,6 +995,9 @@ func (c *sclient) run(ctx context.Context) error { if errors.Is(err, context.Canceled) { c.debugLogf("sender canceled by reader exiting") } else { + if errors.Is(err, os.ErrDeadlineExceeded) { + c.s.sclientWriteTimeouts.Add(1) + } c.logf("sender failed: %v", err) } } @@ -887,7 +1006,7 @@ func (c *sclient) run(ctx context.Context) error { c.startStatsLoop(sendCtx) for { - ft, fl, err := readFrameHeader(c.br) + ft, fl, err := derp.ReadFrameHeader(c.br) c.debugLogf("read frame type %d len %d err %v", ft, fl, err) if err != nil { if errors.Is(err, io.EOF) { @@ -902,17 +1021,17 @@ func (c *sclient) run(ctx context.Context) error { } c.s.noteClientActivity(c) switch ft { - case frameNotePreferred: + case derp.FrameNotePreferred: err = c.handleFrameNotePreferred(ft, fl) - case frameSendPacket: + case derp.FrameSendPacket: err = c.handleFrameSendPacket(ft, fl) - case frameForwardPacket: + case derp.FrameForwardPacket: err = c.handleFrameForwardPacket(ft, fl) - case frameWatchConns: + case derp.FrameWatchConns: err = c.handleFrameWatchConns(ft, fl) - case frameClosePeer: + case derp.FrameClosePeer: err = c.handleFrameClosePeer(ft, fl) - case framePing: + case derp.FramePing: err = c.handleFramePing(ft, fl) default: err = c.handleUnknownFrame(ft, fl) @@ -923,12 +1042,12 @@ func (c *sclient) run(ctx context.Context) error { } } -func (c *sclient) handleUnknownFrame(ft frameType, fl uint32) error { +func (c *sclient) handleUnknownFrame(ft derp.FrameType, fl uint32) error { _, err := io.CopyN(io.Discard, c.br, int64(fl)) return err } -func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error { +func (c *sclient) handleFrameNotePreferred(ft derp.FrameType, fl uint32) error { if fl != 1 { return fmt.Errorf("frameNotePreferred wrong size") } @@ -940,7 +1059,7 @@ func (c *sclient) handleFrameNotePreferred(ft frameType, fl uint32) error { return nil } -func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error { +func (c *sclient) handleFrameWatchConns(ft derp.FrameType, fl uint32) error { if fl != 0 { return fmt.Errorf("handleFrameWatchConns wrong size") } @@ -951,9 +1070,9 @@ func (c *sclient) handleFrameWatchConns(ft frameType, fl uint32) error { return nil } -func (c *sclient) handleFramePing(ft frameType, fl uint32) error { +func (c *sclient) handleFramePing(ft derp.FrameType, fl uint32) error { c.s.gotPing.Add(1) - var m PingMessage + var m derp.PingMessage if fl < uint32(len(m)) { return fmt.Errorf("short ping: %v", fl) } @@ -978,8 +1097,8 @@ func (c *sclient) handleFramePing(ft frameType, fl uint32) error { return err } -func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error { - if fl != keyLen { +func (c *sclient) handleFrameClosePeer(ft derp.FrameType, fl uint32) error { + if fl != derp.KeyLen { return fmt.Errorf("handleFrameClosePeer wrong size") } if !c.canMesh { @@ -1012,7 +1131,7 @@ func (c *sclient) handleFrameClosePeer(ft frameType, fl uint32) error { // handleFrameForwardPacket reads a "forward packet" frame from the client // (which must be a trusted client, a peer in our mesh). -func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { +func (c *sclient) handleFrameForwardPacket(ft derp.FrameType, fl uint32) error { if !c.canMesh { return fmt.Errorf("insufficient permissions") } @@ -1039,7 +1158,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { if dstLen > 1 { reason = dropReasonDupClient } else { - c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere) + c.requestPeerGoneWriteLimited(dstKey, contents, derp.PeerGoneReasonNotHere) } s.recordDrop(contents, srcKey, dstKey, reason) return nil @@ -1055,7 +1174,7 @@ func (c *sclient) handleFrameForwardPacket(ft frameType, fl uint32) error { } // handleFrameSendPacket reads a "send packet" frame from the client. -func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { +func (c *sclient) handleFrameSendPacket(ft derp.FrameType, fl uint32) error { s := c.s dstKey, contents, err := s.recvPacket(c.br, fl) @@ -1092,7 +1211,7 @@ func (c *sclient) handleFrameSendPacket(ft frameType, fl uint32) error { if dstLen > 1 { reason = dropReasonDupClient } else { - c.requestPeerGoneWriteLimited(dstKey, contents, PeerGoneReasonNotHere) + c.requestPeerGoneWriteLimited(dstKey, contents, derp.PeerGoneReasonNotHere) } s.recordDrop(contents, c.key, dstKey, reason) c.debugLogf("SendPacket for %s, dropping with reason=%s", dstKey.ShortString(), reason) @@ -1114,31 +1233,37 @@ func (c *sclient) debugLogf(format string, v ...any) { } } -// dropReason is why we dropped a DERP frame. -type dropReason int +type dropReasonKindLabels struct { + Reason string // metric label corresponding to a given dropReason + Kind string // either `disco` or `other` +} -//go:generate go run tailscale.com/cmd/addlicense -file dropreason_string.go go run golang.org/x/tools/cmd/stringer -type=dropReason -trimprefix=dropReason +// dropReason is why we dropped a DERP frame. +type dropReason string const ( - dropReasonUnknownDest dropReason = iota // unknown destination pubkey - dropReasonUnknownDestOnFwd // unknown destination pubkey on a derp-forwarded packet - dropReasonGoneDisconnected // destination tailscaled disconnected before we could send - dropReasonQueueHead // destination queue is full, dropped packet at queue head - dropReasonQueueTail // destination queue is full, dropped packet at queue tail - dropReasonWriteError // OS write() failed - dropReasonDupClient // the public key is connected 2+ times (active/active, fighting) - numDropReasons // unused; keep last + dropReasonUnknownDest dropReason = "unknown_dest" // unknown destination pubkey + dropReasonUnknownDestOnFwd dropReason = "unknown_dest_on_fwd" // unknown destination pubkey on a derp-forwarded packet + dropReasonGoneDisconnected dropReason = "gone_disconnected" // destination tailscaled disconnected before we could send + dropReasonQueueHead dropReason = "queue_head" // destination queue is full, dropped packet at queue head + dropReasonQueueTail dropReason = "queue_tail" // destination queue is full, dropped packet at queue tail + dropReasonWriteError dropReason = "write_error" // OS write() failed + dropReasonDupClient dropReason = "dup_client" // the public key is connected 2+ times (active/active, fighting) ) func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, reason dropReason) { - s.packetsDropped.Add(1) - s.packetsDroppedReasonCounters[reason].Add(1) + labels := dropReasonKindLabels{ + Reason: string(reason), + } looksDisco := disco.LooksLikeDiscoWrapper(packetBytes) if looksDisco { - s.packetsDroppedTypeDisco.Add(1) + labels.Kind = string(packetKindDisco) } else { - s.packetsDroppedTypeOther.Add(1) + labels.Kind = string(packetKindOther) } + packetsDropped.Add(labels, 1) + bytesDropped.Add(labels, int64(len(packetBytes))) + if verboseDropKeys[dstKey] { // Preformat the log string prior to calling limitedLogf. The // limiter acts based on the format string, and we want to @@ -1196,13 +1321,13 @@ func (c *sclient) sendPkt(dst *sclient, p pkt) error { // notified (in a new goroutine) whenever a peer has disconnected from all DERP // nodes in the current region. func (c *sclient) onPeerGoneFromRegion(peer key.NodePublic) { - c.requestPeerGoneWrite(peer, PeerGoneReasonDisconnected) + c.requestPeerGoneWrite(peer, derp.PeerGoneReasonDisconnected) } // requestPeerGoneWrite sends a request to write a "peer gone" frame // with an explanation of why it is gone. It blocks until either the // write request is scheduled, or the client has closed. -func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason PeerGoneReasonType) { +func (c *sclient) requestPeerGoneWrite(peer key.NodePublic, reason derp.PeerGoneReasonType) { select { case c.peerGone <- peerGoneMsg{ peer: peer, @@ -1227,17 +1352,23 @@ func (c *sclient) requestMeshUpdate() { } } -var localClient tailscale.LocalClient - // isMeshPeer reports whether the client is a trusted mesh peer // node in the DERP region. -func (s *Server) isMeshPeer(info *clientInfo) bool { - return info != nil && info.MeshKey != "" && info.MeshKey == s.meshKey +func (s *Server) isMeshPeer(info *derp.ClientInfo) bool { + // Compare mesh keys in constant time to prevent timing attacks. + // Since mesh keys are a fixed length, we don’t need to be concerned + // about timing attacks on client mesh keys that are the wrong length. + // See https://github.com/tailscale/corp/issues/28720 + if info == nil || info.MeshKey.IsZero() { + return false + } + + return s.meshKey.Equal(info.MeshKey) } // verifyClient checks whether the client is allowed to connect to the derper, // depending on how & whether the server's been configured to verify. -func (s *Server) verifyClient(ctx context.Context, clientKey key.NodePublic, info *clientInfo, clientIP netip.Addr) error { +func (s *Server) verifyClient(ctx context.Context, clientKey key.NodePublic, info *derp.ClientInfo, clientIP netip.Addr) error { if s.isMeshPeer(info) { // Trusted mesh peer. No need to verify further. In fact, verifying // further wouldn't work: it's not part of the tailnet so tailscaled and @@ -1247,8 +1378,8 @@ func (s *Server) verifyClient(ctx context.Context, clientKey key.NodePublic, inf // tailscaled-based verification: if s.verifyClientsLocalTailscaled { - _, err := localClient.WhoIsNodeKey(ctx, clientKey) - if err == tailscale.ErrPeerNotFound { + _, err := s.localClient.WhoIsNodeKey(ctx, clientKey) + if err == local.ErrPeerNotFound { return fmt.Errorf("peer %v not authorized (not found in local tailscaled)", clientKey) } if err != nil { @@ -1301,10 +1432,10 @@ func (s *Server) verifyClient(ctx context.Context, clientKey key.NodePublic, inf } func (s *Server) sendServerKey(lw *lazyBufioWriter) error { - buf := make([]byte, 0, len(magic)+key.NodePublicRawLen) - buf = append(buf, magic...) + buf := make([]byte, 0, len(derp.Magic)+key.NodePublicRawLen) + buf = append(buf, derp.Magic...) buf = s.publicKey.AppendTo(buf) - err := writeFrame(lw.bw(), frameServerKey, buf) + err := derp.WriteFrame(lw.bw(), derp.FrameServerKey, buf) lw.Flush() // redundant (no-op) flush to release bufio.Writer return err } @@ -1369,21 +1500,16 @@ func (s *Server) noteClientActivity(c *sclient) { dup.sendHistory = append(dup.sendHistory, c) } -type serverInfo struct { - Version int `json:"version,omitempty"` - - TokenBucketBytesPerSecond int `json:",omitempty"` - TokenBucketBytesBurst int `json:",omitempty"` -} +type ServerInfo = derp.ServerInfo func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) error { - msg, err := json.Marshal(serverInfo{Version: ProtocolVersion}) + msg, err := json.Marshal(ServerInfo{Version: derp.ProtocolVersion}) if err != nil { return err } msgbox := s.privateKey.SealTo(clientKey, msg) - if err := writeFrameHeader(bw.bw(), frameServerInfo, uint32(len(msgbox))); err != nil { + if err := derp.WriteFrameHeader(bw.bw(), derp.FrameServerInfo, uint32(len(msgbox))); err != nil { return err } if _, err := bw.Write(msgbox); err != nil { @@ -1395,12 +1521,12 @@ func (s *Server) sendServerInfo(bw *lazyBufioWriter, clientKey key.NodePublic) e // recvClientKey reads the frameClientInfo frame from the client (its // proof of identity) upon its initial connection. It should be // considered especially untrusted at this point. -func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info *clientInfo, err error) { - fl, err := readFrameTypeHeader(br, frameClientInfo) +func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info *derp.ClientInfo, err error) { + fl, err := derp.ReadFrameTypeHeader(br, derp.FrameClientInfo) if err != nil { return zpub, nil, err } - const minLen = keyLen + nonceLen + const minLen = derp.KeyLen + derp.NonceLen if fl < minLen { return zpub, nil, errors.New("short client info") } @@ -1412,7 +1538,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info if err := clientKey.ReadRawWithoutAllocating(br); err != nil { return zpub, nil, err } - msgLen := int(fl - keyLen) + msgLen := int(fl - derp.KeyLen) msgbox := make([]byte, msgLen) if _, err := io.ReadFull(br, msgbox); err != nil { return zpub, nil, fmt.Errorf("msgbox: %v", err) @@ -1421,7 +1547,7 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info if !ok { return zpub, nil, fmt.Errorf("msgbox: cannot open len=%d with client key %s", msgLen, clientKey) } - info = new(clientInfo) + info = new(derp.ClientInfo) if err := json.Unmarshal(msg, info); err != nil { return zpub, nil, fmt.Errorf("msg: %v", err) } @@ -1429,15 +1555,15 @@ func (s *Server) recvClientKey(br *bufio.Reader) (clientKey key.NodePublic, info } func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodePublic, contents []byte, err error) { - if frameLen < keyLen { + if frameLen < derp.KeyLen { return zpub, nil, errors.New("short send packet frame") } if err := dstKey.ReadRawWithoutAllocating(br); err != nil { return zpub, nil, err } - packetLen := frameLen - keyLen - if packetLen > MaxPacketSize { - return zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) + packetLen := frameLen - derp.KeyLen + if packetLen > derp.MaxPacketSize { + return zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, derp.MaxPacketSize) } contents = make([]byte, packetLen) if _, err := io.ReadFull(br, contents); err != nil { @@ -1457,7 +1583,7 @@ func (s *Server) recvPacket(br *bufio.Reader, frameLen uint32) (dstKey key.NodeP var zpub key.NodePublic func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, dstKey key.NodePublic, contents []byte, err error) { - if frameLen < keyLen*2 { + if frameLen < derp.KeyLen*2 { return zpub, zpub, nil, errors.New("short send packet frame") } if err := srcKey.ReadRawWithoutAllocating(br); err != nil { @@ -1466,9 +1592,9 @@ func (s *Server) recvForwardPacket(br *bufio.Reader, frameLen uint32) (srcKey, d if err := dstKey.ReadRawWithoutAllocating(br); err != nil { return zpub, zpub, nil, err } - packetLen := frameLen - keyLen*2 - if packetLen > MaxPacketSize { - return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, MaxPacketSize) + packetLen := frameLen - derp.KeyLen*2 + if packetLen > derp.MaxPacketSize { + return zpub, zpub, nil, fmt.Errorf("data packet longer (%d) than max of %v", packetLen, derp.MaxPacketSize) } contents = make([]byte, packetLen) if _, err := io.ReadFull(br, contents); err != nil { @@ -1491,9 +1617,9 @@ type sclient struct { // Static after construction. connNum int64 // process-wide unique counter, incremented each Accept s *Server - nc Conn + nc derp.Conn key key.NodePublic - info clientInfo + info derp.ClientInfo logf logger.Logf done <-chan struct{} // closed when connection closes remoteIPPort netip.AddrPort // zero if remoteAddr is not ip:port. @@ -1503,6 +1629,7 @@ type sclient struct { peerGone chan peerGoneMsg // write request that a peer is not at this server (not used by mesh peers) meshUpdate chan struct{} // write request to write peerStateChange canMesh bool // clientInfo had correct mesh token for inter-region routing + isNotIdealConn bool // client indicated it is not its ideal node in the region isDup atomic.Bool // whether more than 1 sclient for key is connected isDisabled atomic.Bool // whether sends to this peer are disabled due to active/active dups debug bool // turn on for verbose logging @@ -1530,16 +1657,19 @@ type sclient struct { peerGoneLim *rate.Limiter } -func (c *sclient) presentFlags() PeerPresentFlags { - var f PeerPresentFlags +func (c *sclient) presentFlags() derp.PeerPresentFlags { + var f derp.PeerPresentFlags if c.info.IsProber { - f |= PeerPresentIsProber + f |= derp.PeerPresentIsProber } if c.canMesh { - f |= PeerPresentIsMeshPeer + f |= derp.PeerPresentIsMeshPeer + } + if c.isNotIdealConn { + f |= derp.PeerPresentNotIdeal } if f == 0 { - return PeerPresentIsRegular + return derp.PeerPresentIsRegular } return f } @@ -1549,7 +1679,7 @@ func (c *sclient) presentFlags() PeerPresentFlags { type peerConnState struct { ipPort netip.AddrPort // if present, the peer's IP:port peer key.NodePublic - flags PeerPresentFlags + flags derp.PeerPresentFlags present bool } @@ -1570,7 +1700,7 @@ type pkt struct { // peerGoneMsg is a request to write a peerGone frame to an sclient type peerGoneMsg struct { peer key.NodePublic - reason PeerGoneReasonType + reason derp.PeerGoneReasonType } func (c *sclient) setPreferred(v bool) { @@ -1649,14 +1779,16 @@ func (c *sclient) sendLoop(ctx context.Context) error { defer c.onSendLoopDone() jitter := rand.N(5 * time.Second) - keepAliveTick, keepAliveTickChannel := c.s.clock.NewTicker(keepAlive + jitter) + keepAliveTick, keepAliveTickChannel := c.s.clock.NewTicker(derp.KeepAlive + jitter) defer keepAliveTick.Stop() var werr error // last write error + inBatch := -1 // for bufferedWriteFrames for { if werr != nil { return werr } + inBatch++ // First, a non-blocking select (with a default) that // does as many non-flushing writes as possible. select { @@ -1688,6 +1820,10 @@ func (c *sclient) sendLoop(ctx context.Context) error { if werr = c.bw.Flush(); werr != nil { return werr } + if inBatch != 0 { // the first loop will almost always hit default & be size zero + c.s.bufferedWriteFrames.Observe(float64(inBatch)) + inBatch = 0 + } } // Then a blocking select with same: @@ -1698,7 +1834,6 @@ func (c *sclient) sendLoop(ctx context.Context) error { werr = c.sendPeerGone(msg.peer, msg.reason) case <-c.meshUpdate: werr = c.sendMeshUpdates() - continue case msg := <-c.sendQueue: werr = c.sendPacket(msg.src, msg.bs) c.recordQueueTime(msg.enqueuedAt) @@ -1707,7 +1842,6 @@ func (c *sclient) sendLoop(ctx context.Context) error { c.recordQueueTime(msg.enqueuedAt) case msg := <-c.sendPongCh: werr = c.sendPong(msg) - continue case <-keepAliveTickChannel: werr = c.sendKeepAlive() } @@ -1715,20 +1849,43 @@ func (c *sclient) sendLoop(ctx context.Context) error { } func (c *sclient) setWriteDeadline() { - c.nc.SetWriteDeadline(time.Now().Add(writeTimeout)) + d := c.s.tcpWriteTimeout + if c.canMesh { + // Trusted peers get more tolerance. + // + // The "canMesh" is a bit of a misnomer; mesh peers typically run over a + // different interface for a per-region private VPC and are not + // throttled. But monitoring software elsewhere over the internet also + // use the private mesh key to subscribe to connect/disconnect events + // and might hit throttling and need more time to get the initial dump + // of connected peers. + d = privilegedWriteTimeout + } + if d == 0 { + // A zero value should disable the write deadline per + // --tcp-write-timeout docs. The flag should only be applicable for + // non-mesh connections, again per its docs. If mesh happened to use a + // zero value constant above it would be a bug, so we don't bother + // with a condition on c.canMesh. + return + } + // Ignore the error from setting the write deadline. In practice, + // setting the deadline will only fail if the connection is closed + // or closing, so the subsequent Write() will fail anyway. + _ = c.nc.SetWriteDeadline(time.Now().Add(d)) } // sendKeepAlive sends a keep-alive frame, without flushing. func (c *sclient) sendKeepAlive() error { c.setWriteDeadline() - return writeFrameHeader(c.bw.bw(), frameKeepAlive, 0) + return derp.WriteFrameHeader(c.bw.bw(), derp.FrameKeepAlive, 0) } // sendPong sends a pong reply, without flushing. func (c *sclient) sendPong(data [8]byte) error { c.s.sentPong.Add(1) c.setWriteDeadline() - if err := writeFrameHeader(c.bw.bw(), framePong, uint32(len(data))); err != nil { + if err := derp.WriteFrameHeader(c.bw.bw(), derp.FramePong, uint32(len(data))); err != nil { return err } _, err := c.bw.Write(data[:]) @@ -1736,23 +1893,23 @@ func (c *sclient) sendPong(data [8]byte) error { } const ( - peerGoneFrameLen = keyLen + 1 - peerPresentFrameLen = keyLen + 16 + 2 + 1 // 16 byte IP + 2 byte port + 1 byte flags + peerGoneFrameLen = derp.KeyLen + 1 + peerPresentFrameLen = derp.KeyLen + 16 + 2 + 1 // 16 byte IP + 2 byte port + 1 byte flags ) // sendPeerGone sends a peerGone frame, without flushing. -func (c *sclient) sendPeerGone(peer key.NodePublic, reason PeerGoneReasonType) error { +func (c *sclient) sendPeerGone(peer key.NodePublic, reason derp.PeerGoneReasonType) error { switch reason { - case PeerGoneReasonDisconnected: + case derp.PeerGoneReasonDisconnected: c.s.peerGoneDisconnectedFrames.Add(1) - case PeerGoneReasonNotHere: + case derp.PeerGoneReasonNotHere: c.s.peerGoneNotHereFrames.Add(1) } c.setWriteDeadline() data := make([]byte, 0, peerGoneFrameLen) data = peer.AppendTo(data) data = append(data, byte(reason)) - if err := writeFrameHeader(c.bw.bw(), framePeerGone, uint32(len(data))); err != nil { + if err := derp.WriteFrameHeader(c.bw.bw(), derp.FramePeerGone, uint32(len(data))); err != nil { return err } @@ -1761,17 +1918,17 @@ func (c *sclient) sendPeerGone(peer key.NodePublic, reason PeerGoneReasonType) e } // sendPeerPresent sends a peerPresent frame, without flushing. -func (c *sclient) sendPeerPresent(peer key.NodePublic, ipPort netip.AddrPort, flags PeerPresentFlags) error { +func (c *sclient) sendPeerPresent(peer key.NodePublic, ipPort netip.AddrPort, flags derp.PeerPresentFlags) error { c.setWriteDeadline() - if err := writeFrameHeader(c.bw.bw(), framePeerPresent, peerPresentFrameLen); err != nil { + if err := derp.WriteFrameHeader(c.bw.bw(), derp.FramePeerPresent, peerPresentFrameLen); err != nil { return err } payload := make([]byte, peerPresentFrameLen) _ = peer.AppendTo(payload[:0]) a16 := ipPort.Addr().As16() - copy(payload[keyLen:], a16[:]) - binary.BigEndian.PutUint16(payload[keyLen+16:], ipPort.Port()) - payload[keyLen+18] = byte(flags) + copy(payload[derp.KeyLen:], a16[:]) + binary.BigEndian.PutUint16(payload[derp.KeyLen+16:], ipPort.Port()) + payload[derp.KeyLen+18] = byte(flags) _, err := c.bw.Write(payload) return err } @@ -1809,7 +1966,7 @@ func (c *sclient) sendMeshUpdates() error { if pcs.present { err = c.sendPeerPresent(pcs.peer, pcs.ipPort, pcs.flags) } else { - err = c.sendPeerGone(pcs.peer, PeerGoneReasonDisconnected) + err = c.sendPeerGone(pcs.peer, derp.PeerGoneReasonDisconnected) } if err != nil { return err @@ -1844,7 +2001,7 @@ func (c *sclient) sendPacket(srcKey key.NodePublic, contents []byte) (err error) pktLen += key.NodePublicRawLen c.noteSendFromSrc(srcKey) } - if err = writeFrameHeader(c.bw.bw(), frameRecvPacket, uint32(pktLen)); err != nil { + if err = derp.WriteFrameHeader(c.bw.bw(), derp.FrameRecvPacket, uint32(pktLen)); err != nil { return err } if withKey { @@ -2027,6 +2184,7 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_current_file_descriptors", expvar.Func(func() any { return metrics.CurrentFDs() })) m.Set("gauge_current_connections", &s.curClients) m.Set("gauge_current_home_connections", &s.curHomeClients) + m.Set("gauge_current_notideal_connections", &s.curClientsNotIdeal) m.Set("gauge_clients_total", expvar.Func(func() any { return len(s.clientsMesh) })) m.Set("gauge_clients_local", expvar.Func(func() any { return len(s.clients) })) m.Set("gauge_clients_remote", expvar.Func(func() any { return len(s.clientsMesh) - len(s.clients) })) @@ -2036,9 +2194,6 @@ func (s *Server) ExpVar() expvar.Var { m.Set("accepts", &s.accepts) m.Set("bytes_received", &s.bytesRecv) m.Set("bytes_sent", &s.bytesSent) - m.Set("packets_dropped", &s.packetsDropped) - m.Set("counter_packets_dropped_reason", &s.packetsDroppedReason) - m.Set("counter_packets_dropped_type", &s.packetsDroppedType) m.Set("counter_packets_received_kind", &s.packetsRecvByKind) m.Set("packets_sent", &s.packetsSent) m.Set("packets_received", &s.packetsRecv) @@ -2054,12 +2209,14 @@ func (s *Server) ExpVar() expvar.Var { m.Set("multiforwarder_created", &s.multiForwarderCreated) m.Set("multiforwarder_deleted", &s.multiForwarderDeleted) m.Set("packet_forwarder_delete_other_value", &s.removePktForwardOther) + m.Set("sclient_write_timeouts", &s.sclientWriteTimeouts) m.Set("average_queue_duration_ms", expvar.Func(func() any { return math.Float64frombits(atomic.LoadUint64(s.avgQueueDuration)) })) m.Set("counter_tcp_rtt", &s.tcpRtt) m.Set("counter_mesh_update_batch_size", s.meshUpdateBatchSize) m.Set("counter_mesh_update_loop_count", s.meshUpdateLoopCount) + m.Set("counter_buffered_write_frames", s.bufferedWriteFrames) var expvarVersion expvar.String expvarVersion.Set(version.Long()) m.Set("version", &expvarVersion) @@ -2116,11 +2273,11 @@ func (s *Server) ConsistencyCheck() error { func (s *Server) checkVerifyClientsLocalTailscaled() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - status, err := localClient.StatusWithoutPeers(ctx) + status, err := s.localClient.StatusWithoutPeers(ctx) if err != nil { return fmt.Errorf("localClient.Status: %w", err) } - info := &clientInfo{ + info := &derp.ClientInfo{ IsProber: true, } clientIP := netip.IPv6Loopback() diff --git a/derp/derp_server_default.go b/derp/derpserver/derpserver_default.go similarity index 79% rename from derp/derp_server_default.go rename to derp/derpserver/derpserver_default.go index 3e0b5b5e9..874e590d3 100644 --- a/derp/derp_server_default.go +++ b/derp/derpserver/derpserver_default.go @@ -1,9 +1,9 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android -package derp +package derpserver import "context" diff --git a/derp/derp_server_linux.go b/derp/derpserver/derpserver_linux.go similarity index 97% rename from derp/derp_server_linux.go rename to derp/derpserver/derpserver_linux.go index bfc2aade6..768e6a2ab 100644 --- a/derp/derp_server_linux.go +++ b/derp/derpserver/derpserver_linux.go @@ -1,7 +1,9 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package derp +//go:build linux && !android + +package derpserver import ( "context" diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go new file mode 100644 index 000000000..2db5f25bc --- /dev/null +++ b/derp/derpserver/derpserver_test.go @@ -0,0 +1,782 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package derpserver + +import ( + "bufio" + "cmp" + "context" + "crypto/x509" + "encoding/asn1" + "expvar" + "fmt" + "log" + "net" + "os" + "reflect" + "strconv" + "sync" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "go4.org/mem" + "golang.org/x/time/rate" + "tailscale.com/derp" + "tailscale.com/derp/derpconst" + "tailscale.com/types/key" + "tailscale.com/types/logger" +) + +const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + +func TestSetMeshKey(t *testing.T) { + for name, tt := range map[string]struct { + key string + want key.DERPMesh + wantErr bool + }{ + "clobber": { + key: testMeshKey, + wantErr: false, + }, + "invalid": { + key: "badf00d", + wantErr: true, + }, + } { + t.Run(name, func(t *testing.T) { + s := &Server{} + + err := s.SetMeshKey(tt.key) + if tt.wantErr { + if err == nil { + t.Fatalf("expected err") + } + return + } + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + + want, err := key.ParseDERPMesh(tt.key) + if err != nil { + t.Fatal(err) + } + if !s.meshKey.Equal(want) { + t.Fatalf("got %v, want %v", s.meshKey, want) + } + }) + } +} + +func TestIsMeshPeer(t *testing.T) { + s := &Server{} + err := s.SetMeshKey(testMeshKey) + if err != nil { + t.Fatal(err) + } + for name, tt := range map[string]struct { + want bool + meshKey string + wantAllocs float64 + }{ + "nil": { + want: false, + wantAllocs: 0, + }, + "mismatch": { + meshKey: "6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8", + want: false, + wantAllocs: 1, + }, + "match": { + meshKey: testMeshKey, + want: true, + wantAllocs: 0, + }, + } { + t.Run(name, func(t *testing.T) { + var got bool + var mKey key.DERPMesh + if tt.meshKey != "" { + mKey, err = key.ParseDERPMesh(tt.meshKey) + if err != nil { + t.Fatalf("ParseDERPMesh(%q) failed: %v", tt.meshKey, err) + } + } + + info := derp.ClientInfo{ + MeshKey: mKey, + } + allocs := testing.AllocsPerRun(1, func() { + got = s.isMeshPeer(&info) + }) + if got != tt.want { + t.Fatalf("got %t, want %t: info = %#v", got, tt.want, info) + } + + if allocs != tt.wantAllocs && tt.want { + t.Errorf("%f allocations, want %f", allocs, tt.wantAllocs) + } + }) + } +} + +type testFwd int + +func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { + panic("not called in tests") +} +func (testFwd) String() string { + panic("not called in tests") +} + +func pubAll(b byte) (ret key.NodePublic) { + var bs [32]byte + for i := range bs { + bs[i] = b + } + return key.NodePublicFromRaw32(mem.B(bs[:])) +} + +func TestForwarderRegistration(t *testing.T) { + s := &Server{ + clients: make(map[key.NodePublic]*clientSet), + clientsMesh: map[key.NodePublic]PacketForwarder{}, + } + want := func(want map[key.NodePublic]PacketForwarder) { + t.Helper() + if got := s.clientsMesh; !reflect.DeepEqual(got, want) { + t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) + } + } + wantCounter := func(c *expvar.Int, want int) { + t.Helper() + if got := c.Value(); got != int64(want) { + t.Errorf("counter = %v; want %v", got, want) + } + } + singleClient := func(c *sclient) *clientSet { + cs := &clientSet{} + cs.activeClient.Store(c) + return cs + } + + u1 := pubAll(1) + u2 := pubAll(2) + u3 := pubAll(3) + + s.AddPacketForwarder(u1, testFwd(1)) + s.AddPacketForwarder(u2, testFwd(2)) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(1), + u2: testFwd(2), + }) + + // Verify a remove of non-registered forwarder is no-op. + s.RemovePacketForwarder(u2, testFwd(999)) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(1), + u2: testFwd(2), + }) + + // Verify a remove of non-registered user is no-op. + s.RemovePacketForwarder(u3, testFwd(1)) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(1), + u2: testFwd(2), + }) + + // Actual removal. + s.RemovePacketForwarder(u2, testFwd(2)) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(1), + }) + + // Adding a dup for a user. + wantCounter(&s.multiForwarderCreated, 0) + s.AddPacketForwarder(u1, testFwd(100)) + s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path + want(map[key.NodePublic]PacketForwarder{ + u1: newMultiForwarder(testFwd(1), testFwd(100)), + }) + wantCounter(&s.multiForwarderCreated, 1) + + // Removing a forwarder in a multi set that doesn't exist; does nothing. + s.RemovePacketForwarder(u1, testFwd(55)) + want(map[key.NodePublic]PacketForwarder{ + u1: newMultiForwarder(testFwd(1), testFwd(100)), + }) + + // Removing a forwarder in a multi set that does exist should collapse it away + // from being a multiForwarder. + wantCounter(&s.multiForwarderDeleted, 0) + s.RemovePacketForwarder(u1, testFwd(1)) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(100), + }) + wantCounter(&s.multiForwarderDeleted, 1) + + // Removing an entry for a client that's still connected locally should result + // in a nil forwarder. + u1c := &sclient{ + key: u1, + logf: logger.Discard, + } + s.clients[u1] = singleClient(u1c) + s.RemovePacketForwarder(u1, testFwd(100)) + want(map[key.NodePublic]PacketForwarder{ + u1: nil, + }) + + // But once that client disconnects, it should go away. + s.unregisterClient(u1c) + want(map[key.NodePublic]PacketForwarder{}) + + // But if it already has a forwarder, it's not removed. + s.AddPacketForwarder(u1, testFwd(2)) + s.unregisterClient(u1c) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(2), + }) + + // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard + // that they're also connected to a peer of ours. That shouldn't transition the forwarder + // from nil to the new one, not a multiForwarder. + s.clients[u1] = singleClient(u1c) + s.clientsMesh[u1] = nil + want(map[key.NodePublic]PacketForwarder{ + u1: nil, + }) + s.AddPacketForwarder(u1, testFwd(3)) + want(map[key.NodePublic]PacketForwarder{ + u1: testFwd(3), + }) +} + +type channelFwd struct { + // id is to ensure that different instances that reference the + // same channel are not equal, as they are used as keys in the + // multiForwarder map. + id int + c chan []byte +} + +func (f channelFwd) String() string { return "" } +func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { + f.c <- packet + return nil +} + +func TestMultiForwarder(t *testing.T) { + received := 0 + var wg sync.WaitGroup + ch := make(chan []byte) + ctx, cancel := context.WithCancel(context.Background()) + + s := &Server{ + clients: make(map[key.NodePublic]*clientSet), + clientsMesh: map[key.NodePublic]PacketForwarder{}, + } + u := pubAll(1) + s.AddPacketForwarder(u, channelFwd{1, ch}) + + wg.Add(2) + go func() { + defer wg.Done() + for { + select { + case <-ch: + received += 1 + case <-ctx.Done(): + return + } + } + }() + go func() { + defer wg.Done() + for { + s.AddPacketForwarder(u, channelFwd{2, ch}) + s.AddPacketForwarder(u, channelFwd{3, ch}) + s.RemovePacketForwarder(u, channelFwd{2, ch}) + s.RemovePacketForwarder(u, channelFwd{1, ch}) + s.AddPacketForwarder(u, channelFwd{1, ch}) + s.RemovePacketForwarder(u, channelFwd{3, ch}) + if ctx.Err() != nil { + return + } + } + }() + + // Number of messages is chosen arbitrarily, just for this loop to + // run long enough concurrently with {Add,Remove}PacketForwarder loop above. + numMsgs := 5000 + var fwd PacketForwarder + for i := range numMsgs { + s.mu.Lock() + fwd = s.clientsMesh[u] + s.mu.Unlock() + fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) + } + + cancel() + wg.Wait() + if received != numMsgs { + t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) + } +} +func TestMetaCert(t *testing.T) { + priv := key.NewNode() + pub := priv.Public() + s := New(priv, t.Logf) + + certBytes := s.MetaCert() + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + log.Fatal(err) + } + if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(derp.ProtocolVersion) { + t.Errorf("serial = %v; want %v", cert.SerialNumber, derp.ProtocolVersion) + } + if g, w := cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix+pub.UntypedHexString(); g != w { + t.Errorf("CommonName = %q; want %q", g, w) + } + if n := len(cert.Extensions); n != 1 { + t.Fatalf("got %d extensions; want 1", n) + } + + // oidExtensionBasicConstraints is the Basic Constraints ID copied + // from the x509 package. + oidExtensionBasicConstraints := asn1.ObjectIdentifier{2, 5, 29, 19} + + if id := cert.Extensions[0].Id; !id.Equal(oidExtensionBasicConstraints) { + t.Errorf("extension ID = %v; want %v", id, oidExtensionBasicConstraints) + } +} + +func TestServerDupClients(t *testing.T) { + serverPriv := key.NewNode() + var s *Server + + clientPriv := key.NewNode() + clientPub := clientPriv.Public() + + var c1, c2, c3 *sclient + var clientName map[*sclient]string + + // run starts a new test case and resets clients back to their zero values. + run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { + s = New(serverPriv, t.Logf) + s.dupPolicy = dupPolicy + c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} + c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} + c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")} + clientName = map[*sclient]string{ + c1: "c1", + c2: "c2", + c3: "c3", + } + t.Run(name, f) + } + runBothWays := func(name string, f func(t *testing.T)) { + run(name+"_disablefighters", disableFighters, f) + run(name+"_lastwriteractive", lastWriterIsActive, f) + } + wantSingleClient := func(t *testing.T, want *sclient) { + t.Helper() + got, ok := s.clients[want.key] + if !ok { + t.Error("no clients for key") + return + } + if got.dup != nil { + t.Errorf("unexpected dup set for single client") + } + cur := got.activeClient.Load() + if cur != want { + t.Errorf("active client = %q; want %q", clientName[cur], clientName[want]) + } + if cur != nil { + if cur.isDup.Load() { + t.Errorf("unexpected isDup on singleClient") + } + if cur.isDisabled.Load() { + t.Errorf("unexpected isDisabled on singleClient") + } + } + } + wantNoClient := func(t *testing.T) { + t.Helper() + _, ok := s.clients[clientPub] + if !ok { + // Good + return + } + t.Errorf("got client; want empty") + } + wantDupSet := func(t *testing.T) *dupClientSet { + t.Helper() + cs, ok := s.clients[clientPub] + if !ok { + t.Fatal("no set for key; want dup set") + return nil + } + if cs.dup != nil { + return cs.dup + } + t.Fatalf("no dup set for key; want dup set") + return nil + } + wantActive := func(t *testing.T, want *sclient) { + t.Helper() + set, ok := s.clients[clientPub] + if !ok { + t.Error("no set for key") + return + } + got := set.activeClient.Load() + if got != want { + t.Errorf("active client = %q; want %q", clientName[got], clientName[want]) + } + } + checkDup := func(t *testing.T, c *sclient, want bool) { + t.Helper() + if got := c.isDup.Load(); got != want { + t.Errorf("client %q isDup = %v; want %v", clientName[c], got, want) + } + } + checkDisabled := func(t *testing.T, c *sclient, want bool) { + t.Helper() + if got := c.isDisabled.Load(); got != want { + t.Errorf("client %q isDisabled = %v; want %v", clientName[c], got, want) + } + } + wantDupConns := func(t *testing.T, want int) { + t.Helper() + if got := s.dupClientConns.Value(); got != int64(want) { + t.Errorf("dupClientConns = %v; want %v", got, want) + } + } + wantDupKeys := func(t *testing.T, want int) { + t.Helper() + if got := s.dupClientKeys.Value(); got != int64(want) { + t.Errorf("dupClientKeys = %v; want %v", got, want) + } + } + + // Common case: a single client comes and goes, with no dups. + runBothWays("one_comes_and_goes", func(t *testing.T) { + wantNoClient(t) + s.registerClient(c1) + wantSingleClient(t, c1) + s.unregisterClient(c1) + wantNoClient(t) + }) + + // A still somewhat common case: a single client was + // connected and then their wifi dies or laptop closes + // or they switch networks and connect from a + // different network. They have two connections but + // it's not very bad. Only their new one is + // active. The last one, being dead, doesn't send and + // thus the new one doesn't get disabled. + runBothWays("small_overlap_replacement", func(t *testing.T) { + wantNoClient(t) + s.registerClient(c1) + wantSingleClient(t, c1) + wantActive(t, c1) + wantDupKeys(t, 0) + wantDupKeys(t, 0) + + s.registerClient(c2) // wifi dies; c2 replacement connects + wantDupSet(t) + wantDupConns(t, 2) + wantDupKeys(t, 1) + checkDup(t, c1, true) + checkDup(t, c2, true) + checkDisabled(t, c1, false) + checkDisabled(t, c2, false) + wantActive(t, c2) // sends go to the replacement + + s.unregisterClient(c1) // c1 finally times out + wantSingleClient(t, c2) + checkDup(t, c2, false) // c2 is longer a dup + wantActive(t, c2) + wantDupConns(t, 0) + wantDupKeys(t, 0) + }) + + // Key cloning situation with concurrent clients, both trying + // to write. + run("concurrent_dups_get_disabled", disableFighters, func(t *testing.T) { + wantNoClient(t) + s.registerClient(c1) + wantSingleClient(t, c1) + wantActive(t, c1) + s.registerClient(c2) + wantDupSet(t) + wantDupKeys(t, 1) + wantDupConns(t, 2) + wantActive(t, c2) + checkDup(t, c1, true) + checkDup(t, c2, true) + checkDisabled(t, c1, false) + checkDisabled(t, c2, false) + + s.noteClientActivity(c2) + checkDisabled(t, c1, false) + checkDisabled(t, c2, false) + s.noteClientActivity(c1) + checkDisabled(t, c1, true) + checkDisabled(t, c2, true) + wantActive(t, nil) + + s.registerClient(c3) + wantActive(t, c3) + checkDisabled(t, c3, false) + wantDupKeys(t, 1) + wantDupConns(t, 3) + + s.unregisterClient(c3) + wantActive(t, nil) + wantDupKeys(t, 1) + wantDupConns(t, 2) + + s.unregisterClient(c2) + wantSingleClient(t, c1) + wantDupKeys(t, 0) + wantDupConns(t, 0) + }) + + // Key cloning with an A->B->C->A series instead. + run("concurrent_dups_three_parties", disableFighters, func(t *testing.T) { + wantNoClient(t) + s.registerClient(c1) + s.registerClient(c2) + s.registerClient(c3) + s.noteClientActivity(c1) + checkDisabled(t, c1, true) + checkDisabled(t, c2, true) + checkDisabled(t, c3, true) + wantActive(t, nil) + }) + + run("activity_promotes_primary_when_nil", disableFighters, func(t *testing.T) { + wantNoClient(t) + + // Last registered client is the active one... + s.registerClient(c1) + wantActive(t, c1) + s.registerClient(c2) + wantActive(t, c2) + s.registerClient(c3) + s.noteClientActivity(c2) + wantActive(t, c3) + + // But if the last one goes away, the one with the + // most recent activity wins. + s.unregisterClient(c3) + wantActive(t, c2) + }) + + run("concurrent_dups_three_parties_last_writer", lastWriterIsActive, func(t *testing.T) { + wantNoClient(t) + + s.registerClient(c1) + wantActive(t, c1) + s.registerClient(c2) + wantActive(t, c2) + + s.noteClientActivity(c1) + checkDisabled(t, c1, false) + checkDisabled(t, c2, false) + wantActive(t, c1) + + s.noteClientActivity(c2) + checkDisabled(t, c1, false) + checkDisabled(t, c2, false) + wantActive(t, c2) + + s.unregisterClient(c2) + checkDisabled(t, c1, false) + wantActive(t, c1) + }) +} + +func TestLimiter(t *testing.T) { + rl := rate.NewLimiter(rate.Every(time.Minute), 100) + for i := range 200 { + r := rl.Reserve() + d := r.Delay() + t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d) + } +} + +// BenchmarkConcurrentStreams exercises mutex contention on a +// single Server instance with multiple concurrent client flows. +func BenchmarkConcurrentStreams(b *testing.B) { + serverPrivateKey := key.NewNode() + s := New(serverPrivateKey, logger.Discard) + defer s.Close() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatal(err) + } + defer ln.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + for ctx.Err() == nil { + connIn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + b.Error(err) + return + } + + brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) + go s.Accept(ctx, connIn, brwServer, "test-client") + } + }() + + newClient := func(t testing.TB) *derp.Client { + t.Helper() + connOut, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + b.Fatal(err) + } + t.Cleanup(func() { connOut.Close() }) + + k := key.NewNode() + + brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) + client, err := derp.NewClient(k, connOut, brw, logger.Discard) + if err != nil { + b.Fatalf("client: %v", err) + } + return client + } + + b.RunParallel(func(pb *testing.PB) { + c1, c2 := newClient(b), newClient(b) + const packetSize = 100 + msg := make([]byte, packetSize) + for pb.Next() { + if err := c1.Send(c2.PublicKey(), msg); err != nil { + b.Fatal(err) + } + _, err := c2.Recv() + if err != nil { + return + } + } + }) +} + +func BenchmarkSendRecv(b *testing.B) { + for _, size := range []int{10, 100, 1000, 10000} { + b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) }) + } +} + +func benchmarkSendRecvSize(b *testing.B, packetSize int) { + serverPrivateKey := key.NewNode() + s := New(serverPrivateKey, logger.Discard) + defer s.Close() + + k := key.NewNode() + clientKey := k.Public() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + b.Fatal(err) + } + defer ln.Close() + + connOut, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + b.Fatal(err) + } + defer connOut.Close() + + connIn, err := ln.Accept() + if err != nil { + b.Fatal(err) + } + defer connIn.Close() + + brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go s.Accept(ctx, connIn, brwServer, "test-client") + + brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) + client, err := derp.NewClient(k, connOut, brw, logger.Discard) + if err != nil { + b.Fatalf("client: %v", err) + } + + go func() { + for { + _, err := client.Recv() + if err != nil { + return + } + } + }() + + msg := make([]byte, packetSize) + b.SetBytes(int64(len(msg))) + b.ReportAllocs() + b.ResetTimer() + for range b.N { + if err := client.Send(clientKey, msg); err != nil { + b.Fatal(err) + } + } +} + +func TestParseSSOutput(t *testing.T) { + contents, err := os.ReadFile("testdata/example_ss.txt") + if err != nil { + t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err) + } + seen := parseSSOutput(string(contents)) + if len(seen) == 0 { + t.Errorf("parseSSOutput expected non-empty map") + } +} + +func TestGetPerClientSendQueueDepth(t *testing.T) { + c := qt.New(t) + envKey := "TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH" + + testCases := []struct { + envVal string + want int + }{ + // Empty case, envknob treats empty as missing also. + { + "", defaultPerClientSendQueueDepth, + }, + { + "64", 64, + }, + } + + for _, tc := range testCases { + t.Run(cmp.Or(tc.envVal, "empty"), func(t *testing.T) { + t.Setenv(envKey, tc.envVal) + val := getPerClientSendQueueDepth() + c.Assert(val, qt.Equals, tc.want) + }) + } +} diff --git a/derp/derphttp/derphttp_server.go b/derp/derpserver/handler.go similarity index 83% rename from derp/derphttp/derphttp_server.go rename to derp/derpserver/handler.go index 41ce86764..7cd6aa2fd 100644 --- a/derp/derphttp/derphttp_server.go +++ b/derp/derpserver/handler.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package derphttp +package derpserver import ( "fmt" @@ -12,15 +12,11 @@ import ( "tailscale.com/derp" ) -// fastStartHeader is the header (with value "1") that signals to the HTTP -// server that the DERP HTTP client does not want the HTTP 101 response -// headers and it will begin writing & reading the DERP protocol immediately -// following its HTTP request. -const fastStartHeader = "Derp-Fast-Start" - // Handler returns an http.Handler to be mounted at /derp, serving s. -func Handler(s *derp.Server) http.Handler { +func Handler(s *Server) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // These are installed both here and in cmd/derper. The check here // catches both cmd/derper run with DERP disabled (STUN only mode) as // well as DERP being run in tests with derphttp.Handler directly, @@ -40,7 +36,7 @@ func Handler(s *derp.Server) http.Handler { return } - fastStart := r.Header.Get(fastStartHeader) == "1" + fastStart := r.Header.Get(derp.FastStartHeader) == "1" h, ok := w.(http.Hijacker) if !ok { @@ -66,7 +62,11 @@ func Handler(s *derp.Server) http.Handler { pubKey.UntypedHexString()) } - s.Accept(r.Context(), netConn, conn, netConn.RemoteAddr().String()) + if v := r.Header.Get(derp.IdealNodeHeader); v != "" { + ctx = IdealNodeContextKey.WithValue(ctx, v) + } + + s.Accept(ctx, netConn, conn, netConn.RemoteAddr().String()) }) } @@ -92,6 +92,7 @@ func ServeNoContent(w http.ResponseWriter, r *http.Request) { w.Header().Set(NoContentResponseHeader, "response "+challenge) } } + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform, max-age=0") w.WriteHeader(http.StatusNoContent) } @@ -99,7 +100,7 @@ func isChallengeChar(c rune) bool { // Semi-randomly chosen as a limited set of valid characters return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9') || - c == '.' || c == '-' || c == '_' + c == '.' || c == '-' || c == '_' || c == ':' } const ( diff --git a/derp/testdata/example_ss.txt b/derp/derpserver/testdata/example_ss.txt similarity index 100% rename from derp/testdata/example_ss.txt rename to derp/derpserver/testdata/example_ss.txt diff --git a/derp/dropreason_string.go b/derp/dropreason_string.go deleted file mode 100644 index 3ad072819..000000000 --- a/derp/dropreason_string.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by "stringer -type=dropReason -trimprefix=dropReason"; DO NOT EDIT. - -package derp - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[dropReasonUnknownDest-0] - _ = x[dropReasonUnknownDestOnFwd-1] - _ = x[dropReasonGoneDisconnected-2] - _ = x[dropReasonQueueHead-3] - _ = x[dropReasonQueueTail-4] - _ = x[dropReasonWriteError-5] - _ = x[dropReasonDupClient-6] - _ = x[numDropReasons-7] -} - -const _dropReason_name = "UnknownDestUnknownDestOnFwdGoneDisconnectedQueueHeadQueueTailWriteErrorDupClientnumDropReasons" - -var _dropReason_index = [...]uint8{0, 11, 27, 43, 52, 61, 71, 80, 94} - -func (i dropReason) String() string { - if i < 0 || i >= dropReason(len(_dropReason_index)-1) { - return "dropReason(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _dropReason_name[_dropReason_index[i]:_dropReason_index[i+1]] -} diff --git a/derp/export_test.go b/derp/export_test.go new file mode 100644 index 000000000..677a4932d --- /dev/null +++ b/derp/export_test.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package derp + +import "time" + +func (c *Client) RecvTimeoutForTest(timeout time.Duration) (m ReceivedMessage, err error) { + return c.recvTimeout(timeout) +} diff --git a/derp/xdp/xdp_linux.go b/derp/xdp/xdp_linux.go index 3ebe0a052..309d9ee9a 100644 --- a/derp/xdp/xdp_linux.go +++ b/derp/xdp/xdp_linux.go @@ -14,7 +14,6 @@ import ( "github.com/cilium/ebpf" "github.com/cilium/ebpf/link" "github.com/prometheus/client_golang/prometheus" - "tailscale.com/util/multierr" ) //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -type config -type counters_key -type counter_key_af -type counter_key_packets_bytes_action -type counter_key_prog_end bpf xdp.c -- -I headers @@ -110,7 +109,7 @@ func (s *STUNServer) Close() error { errs = append(errs, s.link.Close()) } errs = append(errs, s.objs.Close()) - return multierr.New(errs...) + return errors.Join(errs...) } type stunServerMetrics struct { diff --git a/disco/disco.go b/disco/disco.go index b9a90029d..f58bc1b8c 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -25,6 +25,7 @@ import ( "fmt" "net" "net/netip" + "time" "go4.org/mem" "tailscale.com/types/key" @@ -41,9 +42,15 @@ const NonceLen = 24 type MessageType byte const ( - TypePing = MessageType(0x01) - TypePong = MessageType(0x02) - TypeCallMeMaybe = MessageType(0x03) + TypePing = MessageType(0x01) + TypePong = MessageType(0x02) + TypeCallMeMaybe = MessageType(0x03) + TypeBindUDPRelayEndpoint = MessageType(0x04) + TypeBindUDPRelayEndpointChallenge = MessageType(0x05) + TypeBindUDPRelayEndpointAnswer = MessageType(0x06) + TypeCallMeMaybeVia = MessageType(0x07) + TypeAllocateUDPRelayEndpointRequest = MessageType(0x08) + TypeAllocateUDPRelayEndpointResponse = MessageType(0x09) ) const v0 = byte(0) @@ -77,12 +84,25 @@ func Parse(p []byte) (Message, error) { } t, ver, p := MessageType(p[0]), p[1], p[2:] switch t { + // TODO(jwhited): consider using a signature matching encoding.BinaryUnmarshaler case TypePing: return parsePing(ver, p) case TypePong: return parsePong(ver, p) case TypeCallMeMaybe: return parseCallMeMaybe(ver, p) + case TypeBindUDPRelayEndpoint: + return parseBindUDPRelayEndpoint(ver, p) + case TypeBindUDPRelayEndpointChallenge: + return parseBindUDPRelayEndpointChallenge(ver, p) + case TypeBindUDPRelayEndpointAnswer: + return parseBindUDPRelayEndpointAnswer(ver, p) + case TypeCallMeMaybeVia: + return parseCallMeMaybeVia(ver, p) + case TypeAllocateUDPRelayEndpointRequest: + return parseAllocateUDPRelayEndpointRequest(ver, p) + case TypeAllocateUDPRelayEndpointResponse: + return parseAllocateUDPRelayEndpointResponse(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -91,6 +111,7 @@ func Parse(p []byte) (Message, error) { // Message a discovery message. type Message interface { // AppendMarshal appends the message's marshaled representation. + // TODO(jwhited): consider using a signature matching encoding.BinaryAppender AppendMarshal([]byte) []byte } @@ -266,7 +287,368 @@ func MessageSummary(m Message) string { return fmt.Sprintf("pong tx=%x", m.TxID[:6]) case *CallMeMaybe: return "call-me-maybe" + case *CallMeMaybeVia: + return "call-me-maybe-via" + case *BindUDPRelayEndpoint: + return "bind-udp-relay-endpoint" + case *BindUDPRelayEndpointChallenge: + return "bind-udp-relay-endpoint-challenge" + case *BindUDPRelayEndpointAnswer: + return "bind-udp-relay-endpoint-answer" + case *AllocateUDPRelayEndpointRequest: + return "allocate-udp-relay-endpoint-request" + case *AllocateUDPRelayEndpointResponse: + return "allocate-udp-relay-endpoint-response" default: return fmt.Sprintf("%#v", m) } } + +// BindUDPRelayHandshakeState represents the state of the 3-way bind handshake +// between UDP relay client and UDP relay server. Its potential values include +// those for both participants, UDP relay client and UDP relay server. A UDP +// relay server implementation can be found in net/udprelay. This is currently +// considered experimental. +type BindUDPRelayHandshakeState int + +const ( + // BindUDPRelayHandshakeStateInit represents the initial state prior to any + // message being transmitted. + BindUDPRelayHandshakeStateInit BindUDPRelayHandshakeState = iota + // BindUDPRelayHandshakeStateBindSent is the first client state after + // transmitting a BindUDPRelayEndpoint message to a UDP relay server. + BindUDPRelayHandshakeStateBindSent + // BindUDPRelayHandshakeStateChallengeSent is the first server state after + // receiving a BindUDPRelayEndpoint message from a UDP relay client and + // replying with a BindUDPRelayEndpointChallenge. + BindUDPRelayHandshakeStateChallengeSent + // BindUDPRelayHandshakeStateAnswerSent is a client state that is entered + // after transmitting a BindUDPRelayEndpointAnswer message towards a UDP + // relay server in response to a BindUDPRelayEndpointChallenge message. + BindUDPRelayHandshakeStateAnswerSent + // BindUDPRelayHandshakeStateAnswerReceived is a server state that is + // entered after it has received a correct BindUDPRelayEndpointAnswer + // message from a UDP relay client in response to a + // BindUDPRelayEndpointChallenge message. + BindUDPRelayHandshakeStateAnswerReceived +) + +// bindUDPRelayEndpointCommonLen is the length of a marshalled +// [BindUDPRelayEndpointCommon], without the message header. +const bindUDPRelayEndpointCommonLen = 72 + +// BindUDPRelayChallengeLen is the length of the Challenge field carried in +// [BindUDPRelayEndpointChallenge] & [BindUDPRelayEndpointAnswer] messages. +const BindUDPRelayChallengeLen = 32 + +// BindUDPRelayEndpointCommon contains fields that are common across all 3 +// UDP relay handshake message types. All 4 field values are expected to be +// consistent for the lifetime of a handshake besides Challenge, which is +// irrelevant in a [BindUDPRelayEndpoint] message. +type BindUDPRelayEndpointCommon struct { + // VNI is the Geneve header Virtual Network Identifier field value, which + // must match this disco-sealed value upon reception. If they are + // non-matching it indicates the cleartext Geneve header was tampered with + // and/or mangled. + VNI uint32 + // Generation represents the handshake generation. Clients must set a new, + // nonzero value at the start of every handshake. + Generation uint32 + // RemoteKey is the disco key of the remote peer participating over this + // relay endpoint. + RemoteKey key.DiscoPublic + // Challenge is set by the server in a [BindUDPRelayEndpointChallenge] + // message, and expected to be echoed back by the client in a + // [BindUDPRelayEndpointAnswer] message. Its value is irrelevant in a + // [BindUDPRelayEndpoint] message, where it simply serves a padding purpose + // ensuring all handshake messages are equal in size. + Challenge [BindUDPRelayChallengeLen]byte +} + +// encode encodes m in b. b must be at least bindUDPRelayEndpointCommonLen bytes +// long. +func (m *BindUDPRelayEndpointCommon) encode(b []byte) { + binary.BigEndian.PutUint32(b, m.VNI) + b = b[4:] + binary.BigEndian.PutUint32(b, m.Generation) + b = b[4:] + m.RemoteKey.AppendTo(b[:0]) + b = b[key.DiscoPublicRawLen:] + copy(b, m.Challenge[:]) +} + +// decode decodes m from b. +func (m *BindUDPRelayEndpointCommon) decode(b []byte) error { + if len(b) < bindUDPRelayEndpointCommonLen { + return errShort + } + m.VNI = binary.BigEndian.Uint32(b) + b = b[4:] + m.Generation = binary.BigEndian.Uint32(b) + b = b[4:] + m.RemoteKey = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) + b = b[key.DiscoPublicRawLen:] + copy(m.Challenge[:], b[:BindUDPRelayChallengeLen]) + return nil +} + +// BindUDPRelayEndpoint is the first messaged transmitted from UDP relay client +// towards UDP relay server as part of the 3-way bind handshake. +type BindUDPRelayEndpoint struct { + BindUDPRelayEndpointCommon +} + +func (m *BindUDPRelayEndpoint) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpoint, v0, bindUDPRelayEndpointCommonLen) + m.BindUDPRelayEndpointCommon.encode(d) + return ret +} + +func parseBindUDPRelayEndpoint(ver uint8, p []byte) (m *BindUDPRelayEndpoint, err error) { + m = new(BindUDPRelayEndpoint) + err = m.BindUDPRelayEndpointCommon.decode(p) + if err != nil { + return nil, err + } + return m, nil +} + +// BindUDPRelayEndpointChallenge is transmitted from UDP relay server towards +// UDP relay client in response to a BindUDPRelayEndpoint message as part of the +// 3-way bind handshake. +type BindUDPRelayEndpointChallenge struct { + BindUDPRelayEndpointCommon +} + +func (m *BindUDPRelayEndpointChallenge) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointChallenge, v0, bindUDPRelayEndpointCommonLen) + m.BindUDPRelayEndpointCommon.encode(d) + return ret +} + +func parseBindUDPRelayEndpointChallenge(ver uint8, p []byte) (m *BindUDPRelayEndpointChallenge, err error) { + m = new(BindUDPRelayEndpointChallenge) + err = m.BindUDPRelayEndpointCommon.decode(p) + if err != nil { + return nil, err + } + return m, nil +} + +// BindUDPRelayEndpointAnswer is transmitted from UDP relay client to UDP relay +// server in response to a BindUDPRelayEndpointChallenge message. +type BindUDPRelayEndpointAnswer struct { + BindUDPRelayEndpointCommon +} + +func (m *BindUDPRelayEndpointAnswer) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointAnswer, v0, bindUDPRelayEndpointCommonLen) + m.BindUDPRelayEndpointCommon.encode(d) + return ret +} + +func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpointAnswer, err error) { + m = new(BindUDPRelayEndpointAnswer) + err = m.BindUDPRelayEndpointCommon.decode(p) + if err != nil { + return nil, err + } + return m, nil +} + +// AllocateUDPRelayEndpointRequest is a message sent only over DERP to request +// allocation of a relay endpoint on a [tailscale.com/net/udprelay.Server] +type AllocateUDPRelayEndpointRequest struct { + // ClientDisco are the Disco public keys of the clients that should be + // permitted to handshake with the endpoint. + ClientDisco [2]key.DiscoPublic + // Generation represents the allocation request generation. The server must + // echo it back in the [AllocateUDPRelayEndpointResponse] to enable request + // and response alignment client-side. + Generation uint32 +} + +// allocateUDPRelayEndpointRequestLen is the length of a marshaled +// [AllocateUDPRelayEndpointRequest] message without the message header. +const allocateUDPRelayEndpointRequestLen = key.DiscoPublicRawLen*2 + // ClientDisco + 4 // Generation + +func (m *AllocateUDPRelayEndpointRequest) AppendMarshal(b []byte) []byte { + ret, p := appendMsgHeader(b, TypeAllocateUDPRelayEndpointRequest, v0, allocateUDPRelayEndpointRequestLen) + for i := 0; i < len(m.ClientDisco); i++ { + disco := m.ClientDisco[i].AppendTo(nil) + copy(p, disco) + p = p[key.DiscoPublicRawLen:] + } + binary.BigEndian.PutUint32(p, m.Generation) + return ret +} + +func parseAllocateUDPRelayEndpointRequest(ver uint8, p []byte) (m *AllocateUDPRelayEndpointRequest, err error) { + m = new(AllocateUDPRelayEndpointRequest) + if ver != 0 { + return + } + if len(p) < allocateUDPRelayEndpointRequestLen { + return m, errShort + } + for i := 0; i < len(m.ClientDisco); i++ { + m.ClientDisco[i] = key.DiscoPublicFromRaw32(mem.B(p[:key.DiscoPublicRawLen])) + p = p[key.DiscoPublicRawLen:] + } + m.Generation = binary.BigEndian.Uint32(p) + return m, nil +} + +// AllocateUDPRelayEndpointResponse is a message sent only over DERP in response +// to a [AllocateUDPRelayEndpointRequest]. +type AllocateUDPRelayEndpointResponse struct { + // Generation represents the allocation request generation. The server must + // echo back the [AllocateUDPRelayEndpointRequest.Generation] here to enable + // request and response alignment client-side. + Generation uint32 + UDPRelayEndpoint +} + +func (m *AllocateUDPRelayEndpointResponse) AppendMarshal(b []byte) []byte { + endpointsLen := epLength * len(m.AddrPorts) + generationLen := 4 + ret, d := appendMsgHeader(b, TypeAllocateUDPRelayEndpointResponse, v0, generationLen+udpRelayEndpointLenMinusAddrPorts+endpointsLen) + binary.BigEndian.PutUint32(d, m.Generation) + m.encode(d[4:]) + return ret +} + +func parseAllocateUDPRelayEndpointResponse(ver uint8, p []byte) (m *AllocateUDPRelayEndpointResponse, err error) { + m = new(AllocateUDPRelayEndpointResponse) + if ver != 0 { + return m, nil + } + if len(p) < 4 { + return m, errShort + } + m.Generation = binary.BigEndian.Uint32(p) + err = m.decode(p[4:]) + return m, err +} + +const udpRelayEndpointLenMinusAddrPorts = key.DiscoPublicRawLen + // ServerDisco + (key.DiscoPublicRawLen * 2) + // ClientDisco + 8 + // LamportID + 4 + // VNI + 8 + // BindLifetime + 8 // SteadyStateLifetime + +// UDPRelayEndpoint is a mirror of [tailscale.com/net/udprelay/endpoint.ServerEndpoint], +// refer to it for field documentation. [UDPRelayEndpoint] is carried in both +// [CallMeMaybeVia] and [AllocateUDPRelayEndpointResponse] messages. +type UDPRelayEndpoint struct { + // ServerDisco is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.ServerDisco] + ServerDisco key.DiscoPublic + // ClientDisco is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.ClientDisco] + ClientDisco [2]key.DiscoPublic + // LamportID is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.LamportID] + LamportID uint64 + // VNI is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.VNI] + VNI uint32 + // BindLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.BindLifetime] + BindLifetime time.Duration + // SteadyStateLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.SteadyStateLifetime] + SteadyStateLifetime time.Duration + // AddrPorts is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.AddrPorts] + AddrPorts []netip.AddrPort +} + +// encode encodes m in b. b must be at least [udpRelayEndpointLenMinusAddrPorts] +// + [epLength] * len(m.AddrPorts) bytes long. +func (m *UDPRelayEndpoint) encode(b []byte) { + disco := m.ServerDisco.AppendTo(nil) + copy(b, disco) + b = b[key.DiscoPublicRawLen:] + for i := 0; i < len(m.ClientDisco); i++ { + disco = m.ClientDisco[i].AppendTo(nil) + copy(b, disco) + b = b[key.DiscoPublicRawLen:] + } + binary.BigEndian.PutUint64(b[:8], m.LamportID) + b = b[8:] + binary.BigEndian.PutUint32(b[:4], m.VNI) + b = b[4:] + binary.BigEndian.PutUint64(b[:8], uint64(m.BindLifetime)) + b = b[8:] + binary.BigEndian.PutUint64(b[:8], uint64(m.SteadyStateLifetime)) + b = b[8:] + for _, ipp := range m.AddrPorts { + a := ipp.Addr().As16() + copy(b, a[:]) + binary.BigEndian.PutUint16(b[16:18], ipp.Port()) + b = b[epLength:] + } +} + +// decode decodes m from b. +func (m *UDPRelayEndpoint) decode(b []byte) error { + if len(b) < udpRelayEndpointLenMinusAddrPorts+epLength || + (len(b)-udpRelayEndpointLenMinusAddrPorts)%epLength != 0 { + return errShort + } + m.ServerDisco = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) + b = b[key.DiscoPublicRawLen:] + for i := 0; i < len(m.ClientDisco); i++ { + m.ClientDisco[i] = key.DiscoPublicFromRaw32(mem.B(b[:key.DiscoPublicRawLen])) + b = b[key.DiscoPublicRawLen:] + } + m.LamportID = binary.BigEndian.Uint64(b[:8]) + b = b[8:] + m.VNI = binary.BigEndian.Uint32(b[:4]) + b = b[4:] + m.BindLifetime = time.Duration(binary.BigEndian.Uint64(b[:8])) + b = b[8:] + m.SteadyStateLifetime = time.Duration(binary.BigEndian.Uint64(b[:8])) + b = b[8:] + m.AddrPorts = make([]netip.AddrPort, 0, len(b)-udpRelayEndpointLenMinusAddrPorts/epLength) + for len(b) > 0 { + var a [16]byte + copy(a[:], b) + m.AddrPorts = append(m.AddrPorts, netip.AddrPortFrom( + netip.AddrFrom16(a).Unmap(), + binary.BigEndian.Uint16(b[16:18]))) + b = b[epLength:] + } + return nil +} + +// CallMeMaybeVia is a message sent only over DERP to request that the recipient +// try to open up a magicsock path back to the sender. The 'Via' in +// CallMeMaybeVia highlights that candidate paths are served through an +// intermediate relay, likely a [tailscale.com/net/udprelay.Server]. +// +// Usage of the candidate paths in magicsock requires a 3-way handshake +// involving [BindUDPRelayEndpoint], [BindUDPRelayEndpointChallenge], and +// [BindUDPRelayEndpointAnswer]. +// +// CallMeMaybeVia mirrors [tailscale.com/net/udprelay/endpoint.ServerEndpoint], +// which contains field documentation. +// +// The recipient may choose to not open a path back if it's already happy with +// its path. Direct connections, e.g. [CallMeMaybe]-signaled, take priority over +// CallMeMaybeVia paths. +type CallMeMaybeVia struct { + UDPRelayEndpoint +} + +func (m *CallMeMaybeVia) AppendMarshal(b []byte) []byte { + endpointsLen := epLength * len(m.AddrPorts) + ret, p := appendMsgHeader(b, TypeCallMeMaybeVia, v0, udpRelayEndpointLenMinusAddrPorts+endpointsLen) + m.encode(p) + return ret +} + +func parseCallMeMaybeVia(ver uint8, p []byte) (m *CallMeMaybeVia, err error) { + m = new(CallMeMaybeVia) + if ver != 0 { + return m, nil + } + err = m.decode(p) + return m, err +} diff --git a/disco/disco_test.go b/disco/disco_test.go index 1a56324a5..71b68338a 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -9,12 +9,35 @@ import ( "reflect" "strings" "testing" + "time" "go4.org/mem" "tailscale.com/types/key" ) func TestMarshalAndParse(t *testing.T) { + relayHandshakeCommon := BindUDPRelayEndpointCommon{ + VNI: 1, + Generation: 2, + RemoteKey: key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + Challenge: [BindUDPRelayChallengeLen]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + } + + udpRelayEndpoint := UDPRelayEndpoint{ + ServerDisco: key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + ClientDisco: [2]key.DiscoPublic{key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 3: 3, 30: 30, 31: 31})), key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 4: 4, 30: 30, 31: 31}))}, + LamportID: 123, + VNI: 456, + BindLifetime: time.Second, + SteadyStateLifetime: time.Minute, + AddrPorts: []netip.AddrPort{ + netip.MustParseAddrPort("1.2.3.4:567"), + netip.MustParseAddrPort("[2001::3456]:789"), + }, + } + tests := []struct { name string want string @@ -83,6 +106,50 @@ func TestMarshalAndParse(t *testing.T) { }, want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", }, + { + name: "bind_udp_relay_endpoint", + m: &BindUDPRelayEndpoint{ + relayHandshakeCommon, + }, + want: "04 00 00 00 00 01 00 00 00 02 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, + { + name: "bind_udp_relay_endpoint_challenge", + m: &BindUDPRelayEndpointChallenge{ + relayHandshakeCommon, + }, + want: "05 00 00 00 00 01 00 00 00 02 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, + { + name: "bind_udp_relay_endpoint_answer", + m: &BindUDPRelayEndpointAnswer{ + relayHandshakeCommon, + }, + want: "06 00 00 00 00 01 00 00 00 02 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, + { + name: "call_me_maybe_via", + m: &CallMeMaybeVia{ + UDPRelayEndpoint: udpRelayEndpoint, + }, + want: "07 00 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 00 04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 00 00 00 00 7b 00 00 01 c8 00 00 00 00 3b 9a ca 00 00 00 00 0d f8 47 58 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, + { + name: "allocate_udp_relay_endpoint_request", + m: &AllocateUDPRelayEndpointRequest{ + ClientDisco: [2]key.DiscoPublic{key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 3: 3, 30: 30, 31: 31})), key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 4: 4, 30: 30, 31: 31}))}, + Generation: 1, + }, + want: "08 00 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 00 04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 01", + }, + { + name: "allocate_udp_relay_endpoint_response", + m: &AllocateUDPRelayEndpointResponse{ + Generation: 1, + UDPRelayEndpoint: udpRelayEndpoint, + }, + want: "09 00 00 00 00 01 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 01 02 00 04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 00 00 00 00 7b 00 00 01 c8 00 00 00 00 3b 9a ca 00 00 00 00 0d f8 47 58 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/docs/commit-messages.md b/docs/commit-messages.md new file mode 100644 index 000000000..79b16e4c6 --- /dev/null +++ b/docs/commit-messages.md @@ -0,0 +1,194 @@ +# Commit messages + +There are different styles of commit messages followed by different projects. +This is Tailscale's style guide for writing git commit messages. +As with all style guides, many things here are subjective and exist primarily to +codify existing conventions and promote uniformity and thus ease of reading by +others. Others have stronger reasons, such as interop with tooling or making +future git archaeology easier. + +Our commit message style is largely based on the Go language's style, which +shares much in common with the Linux kernel's git commit message style (for +which git was invented): + +* Go's high-level example: https://go.dev/doc/contribute#commit_messages +* Go's details: https://golang.org/wiki/CommitMessage +* Linux's style: https://www.kernel.org/doc/html/v4.10/process/submitting-patches.html#describe-your-changes + +(We do *not* use the [Conventional +Commits](https://www.conventionalcommits.org/en/v1.0.0/) style or [Semantic +Commits](https://gist.github.com/joshbuchea/6f47e86d2510bce28f8e7f42ae84c716) +styles. They're reasonable, but we have already been using the Go and Linux +style of commit messages and there is little justification for switching styles. +Consistency is valuable.) + +In a nutshell, our commit messages should look like: + +``` +net/http: handle foo when bar + +[longer description here in the body] + +Fixes #nnnn +``` + +Notably, for the subject (the first line of description): + +- the primary director(ies) from the root affected by the change goes before the colon, e.g. “derp/derphttp:” (if a lot of packages are involved, you can abbreviate to top-level names e.g. ”derp,magicsock:”, and/or remove less relevant packages) +- the part after the colon is a verb, ideally an imperative verb (Linux style, telling the code what to do) or alternatively an infinitive verb that completes the blank in, *"this change modifies Tailscale to ___________"*. e.g. say *“fix the foobar feature”*, not *“fixing”*, *“fixed”*, or *“fixes”*. Or, as Linux guidelines say: + > Describe your changes in imperative mood, e.g. “make xyzzy do frotz” instead of “[This patch] makes xyzzy do frotz” or “[I] changed xyzzy to do frotz”, as if you are giving orders to the codebase to change its behaviour." +- the verb after the colon is lowercase +- there is no trailing period +- it should be kept as short as possible (many git viewing tools prefer under ~76 characters, though we aren’t super strict about this) + + Examples: + + | Good Example | notes | + | ------- | --- | + | `foo/bar: fix memory leak` | | + | `foo/bar: bump deps` | | + | `foo/bar: temporarily restrict access` | adverbs are okay | + | `foo/bar: implement new UI design` | | + | `control/{foo,bar}: optimize bar` | feel free to use {foo,bar} for common subpackages| + + | Bad Example | notes | + | ------- | --- | + | `fixed memory leak` | BAD: missing package prefix | + | `foo/bar: fixed memory leak` | BAD: past tense | + | `foo/bar: fixing memory leak` | BAD: present continuous tense; no `-ing` verbs | + | `foo/bar: bumping deps` | BAD: present continuous tense; no `-ing` verbs | + | `foo/bar: new UI design` | BAD: that's a noun phrase; no verb | + | `foo/bar: made things larger` | BAD: that's past tense | + | `foo/bar: faster algorithm` | BAD: that's an adjective and a noun, not a verb | + | `foo/bar: Fix memory leak` | BAD: capitalized verb | + | `foo/bar: fix memory leak.` | BAD: trailing period | + | `foo/bar:fix memory leak` | BAD: no space after colon | + | `foo/bar : fix memory leak` | BAD: space before colon | + | `foo/bar: fix memory leak Fixes #123` | BAD: the "Fixes" shouldn't be part of the title | + | `!fixup reviewer feedback` | BAD: we don't check in fixup commits; the history should always bisect to a clean, working tree | + + +For the body (the rest of the description): + +- blank line after the subject (first) line +- the text should be wrapped to ~76 characters (to appease git viewing tools, mainly), unless you really need longer lines (e.g. for ASCII art, tables, or long links) +- there must be a `Fixes` or `Updates` line for all non-cleanup commits linking to a tracking bug. This goes after the body with a blank newline separating the two. [Cleanup commits](#is-it-a-cleanup) can use `Updates #cleanup` instead of an issue. +- `Change-Id` lines should ideally be included in commits in the `corp` repo and are more optional in `tailscale/tailscale`. You can configure Git to do this for you by running `./tool/go run misc/install-git-hooks.go` from the root of the corp repo. This was originally a Gerrit thing and we don't use Gerrit, but it lets us tooling track commits as they're cherry-picked between branches. Also, tools like [git-cleanup](https://github.com/bradfitz/gitutil) use it to clean up your old local branches once they're merged upstream. +- we don't use Markdown in commit messages. (Accidental Markdown like bulleted lists or even headings is fine, but not links) +- we require `Signed-off-by` lines in public repos (such as `tailscale/tailscale`). Add them using `git commit --signoff` or `git commit -s` for short. You can use them in private repos but do not have to. +- when moving code between repos, include the repository name, and git hash that it was moved from/to, so it is easier to trace history/blame. + +Please don't use [alternate GitHub-supported +aliases](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) +like `Close` or `Resolves`. Tailscale only uses the verbs `Fixes` and `Updates`. + +To link a commit to an issue without marking it fixed—for example, if the commit +is working toward a fix but not yet a complete fix—GitHub requires only that the +issue is mentioned by number in the commit message. By convention, our commits +mention this at the bottom of the message using `Updates`, where `Fixes` might +be expected, even if the number is also mentioned in the body of the commit +message. + +For example: + +``` +some/dir: refactor func Foo + +This will make the handling of +shorter and easier to test. + +Updates #nnnn +``` + +Please say `Updates` and not other common Github-recognized conventions (that is, don't use `For #nnnn`) + +## Public release notes + +For changes in `tailscale/tailscale` that fix a significant bug or add a new feature that should be included in the release notes for the next release, +add `RELNOTE: ` toward the end of the commit message. +This will aid the release engineer in writing the release notes for the next release. + +## Is it a #cleanup? + +Our issuebot permits writing `Updates #cleanup` instead of an actual GitHub issue number. + +But only do that if it’s actually a cleanup. Don’t use that as an excuse to avoid filing an issue. + +Shortcuts[^1] to file issues: +- [go/bugc](http://go/bugc) (corp, safe choice) +- [go/bugo](http://go/bugo) (open source, if you want it public to the world). + +[^1]: These shortcuts point to our Tailscale’s internal URL shortener service, which you too [can run in your own Tailnet](https://tailscale.com/blog/golink). + +The following guide can help you decide whether a tracking issue is warranted. + +| | | +| --- | --- | +| Was there a crash/panic? | Not a cleanup. Put the panic in a bug. Talk about when it was introduced, why, why a test didn’t catch it, note what followup work might need to be done. | +| Did a customer report it? | Not a cleanup. Make a corp bug with links to the customer ticket. | +| Is it from an incident, get paged? | Not a cleanup. Let’s track why we got paged. | +| Does it change behavior? | Not a cleanup. File a bug to track why. | +| Adding a test for a recently fixed bug? | Not a cleanup. Use the recently fixed bug’s bug number. | +| Does it tweak a constant/parameter? | Not a cleanup. File a bug to track the debugging/tuning effort and record past results and goals for the future state. | +| Fixing a regression from an earlier change? | Not a cleanup. At minimum, reference the PR that caused the regression, but if users noticed, it might warrant its own bug. | +| Is it part of an overall effort that’ll take a hundred small steps? | Not a cleanup. The overall effort should have a tracking bug to collect all the minor efforts. | +| Is it a security fix? Is it a security hardening? | Not a cleanup. There should be a bug about security incidents or security hardening efforts and backporting to previous releases, etc. | +| Is it a feature flag being removed? | Not a cleanup. File a task to coordinate with other teams and to track the work. | + +### Actual cleanup examples + +- Fixing typos in internal comments that users would’ve never seen +- Simple, mechanical replacement of a deprecated API to its equivalently behaving replacement + - [`errors.Wrapf`](https://pkg.go.dev/github.com/pkg/errors#Wrapf) → [`fmt.Errorf("%w")`](https://pkg.go.dev/fmt#Errorf) + - [math/rand](https://pkg.go.dev/math/rand) → [math/rand/v2](https://pkg.go.dev/math/rand/v2) +- Code movement +- Removing dead code that doesn’t change behavior (API changes, feature flags, etc) +- Refactoring in prep for another change (but maybe mention the upcoming change’s bug as motivation) +- Adding a test that you just noticed was missing, not as a result of any bug or report or new feature coming +- Formatting (gofmt / prettifier) that was missed earlier + +### What’s the point of an issue? + +- Let us capture information that is inappropriate for a commit message +- Let us have conversations on a change after the fact +- Let us track metadata on issues and decide what to backport +- Let us associate related changes to each other, including after the fact +- Lets you write the backstory once on an overall bug/effort and re-use that issue number for N future commits, without having to repeat yourself on each commit message +- Provides archaeological breadcrumbs to future debuggers, providing context on why things were changed + +# Reverts + +When you use `git revert` to revert a commit, the default commit message will identify the commit SHA and message that was reverted. You must expand this message to explain **why** it is being reverted, including a link to the associated issue. + +Don't revert reverts. That gets ugly. Send the change anew but reference +the original & earlier revert. + +# Other repos + +To reference an issue in one repo from a commit in another (for example, fixing an issue in corp with a commit in `tailscale/tailscale`), you need to fully-qualify the issue number with the GitHub org/repo syntax: + +``` +cipher/rot13: add new super secure cipher + +Fixes tailscale/corp#1234 +``` + +Referencing a full URL to the issue is also acceptable, but try to prefer the shorter way. + +It's okay to reference the `corp` repo in open source repo commit messages. + +# GitHub Pull Requests + +In the future we plan to make a bot rewrite all PR bodies programmatically from +the commit messages. But for now (2023-07-25).... + +By convention, GitHub Pull Requests follow similar rules to commits, especially +the title of the PR (which should be the first line of the commit). It is less +important to follow these conventions in the PR itself, as it’s the commits that +become a permanent part of the commit history. + +It's okay (but rare) for a PR to contain multiple commits. When a PR does +contain multiple commits, call that out in the PR body for reviewers so they can +review each separately. + +You don't need to include the `Change-Id` in the description of your PR. diff --git a/docs/k8s/operator-architecture.md b/docs/k8s/operator-architecture.md new file mode 100644 index 000000000..29672f6a3 --- /dev/null +++ b/docs/k8s/operator-architecture.md @@ -0,0 +1,602 @@ +# Operator architecture diagrams + +The Tailscale [Kubernetes operator][kb-operator] has a collection of use-cases +that can be mixed and matched as required. The following diagrams illustrate +how the operator implements each use-case. + +In each diagram, the "tailscale" namespace is entirely managed by the operator +once the operator itself has been deployed. + +Tailscale devices are highlighted as black nodes. The salient devices for each +use-case are marked as "src" or "dst" to denote which node is a source or a +destination in the context of ACL rules that will apply to network traffic. + +Note, in some cases, the config and the state Secret may be the same Kubernetes +Secret. + +## API server proxy + +[Documentation][kb-operator-proxy] + +The operator runs the API server proxy in-process. If the proxy is running in +"noauth" mode, it forwards HTTP requests unmodified. If the proxy is running in +"auth" mode, it deletes any existing auth headers and adds +[impersonation headers][k8s-impersonation] to the request before forwarding to +the API server. A request with impersonation headers will look something like: + +``` +GET /api/v1/namespaces/default/pods HTTP/1.1 +Host: k8s-api.example.com +Authorization: Bearer +Impersonate-Group: tailnet-readers +Accept: application/json +``` + +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart LR + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator(("operator (dst)")):::tsnode + end + + subgraph controlplane["Control plane"] + api[kube-apiserver] + end + end + + client["client (src)"]:::tsnode --> operator + operator -->|"proxy (maybe with impersonation headers)"| api + + linkStyle 0 stroke:red; + linkStyle 2 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 3 stroke:blue; + +``` + +## L3 ingress + +[Documentation][kb-operator-l3-ingress] + +The user deploys an app to the default namespace, and creates a normal Service +that selects the app's Pods. Either add the annotation +`tailscale.com/expose: "true"` or specify `.spec.type` as `Loadbalancer` and +`.spec.loadBalancerClass` as `tailscale`. The operator will create an ingress +proxy that allows devices anywhere on the tailnet to access the Service. + +The proxy Pod uses `iptables` or `nftables` rules to DNAT traffic bound for the +proxy's tailnet IP to the Service's internal Cluster IP instead. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + ingress-sts["StatefulSet"] + ingress(("ingress proxy (dst)")):::tsnode + config-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph defaultns[namespace=default] + svc[annotated Service] + svc --> pod1((pod1)) + svc --> pod2((pod2)) + end + end + + client["client (src)"]:::tsnode --> ingress + ingress -->|forwards traffic| svc + operator -.->|creates| ingress-sts + ingress-sts -.->|manages| ingress + operator -.->|reads| svc + operator -.->|creates| config-secret + config-secret -.->|mounted| ingress + ingress -.->|stores state| state-secret + + linkStyle 0 stroke:red; + linkStyle 4 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 5 stroke:blue; + +``` + +## L7 ingress + +[Documentation][kb-operator-l7-ingress] + +The L7 ingress architecture diagram is relatively similar to L3 ingress. It is +configured via an `Ingress` object instead of a `Service`, and uses +`tailscale serve` to accept traffic instead of configuring `iptables` or +`nftables` rules. Note that we use tailscaled's local API (`SetServeConfig`) to +set serve config, not the `tailscale serve` command. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + ingress-sts["StatefulSet"] + ingress-pod(("ingress proxy (dst)")):::tsnode + config-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph cluster-scope[Cluster scoped resources] + ingress-class[Tailscale IngressClass] + end + + subgraph defaultns[namespace=default] + ingress[tailscale Ingress] + svc["Service"] + svc --> pod1((pod1)) + svc --> pod2((pod2)) + end + end + + client["client (src)"]:::tsnode --> ingress-pod + ingress-pod -->|forwards /api prefix traffic| svc + operator -.->|creates| ingress-sts + ingress-sts -.->|manages| ingress-pod + operator -.->|reads| ingress + operator -.->|creates| config-secret + config-secret -.->|mounted| ingress-pod + ingress-pod -.->|stores state| state-secret + ingress -.->|/api prefix| svc + + linkStyle 0 stroke:red; + linkStyle 4 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 5 stroke:blue; + +``` + +## L3 egress + +[Documentation][kb-operator-l3-egress] + +1. The user deploys a Service with `type: ExternalName` and an annotation + `tailscale.com/tailnet-fqdn: db.tails-scales.ts.net`. +1. The operator creates a proxy Pod managed by a single replica StatefulSet, and a headless Service pointing at the proxy Pod. +1. The operator updates the `ExternalName` Service's `spec.externalName` field to point + at the headless Service it created in the previous step. + +(Optional) If the user also adds the `tailscale.com/proxy-group: egress-proxies` +annotation to their `ExternalName` Service, the operator will skip creating a +proxy Pod and instead point the headless Service at the existing ProxyGroup's +pods. In this case, ports are also required in the `ExternalName` Service spec. +See below for a more representative diagram. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + egress(("egress proxy (src)")):::tsnode + egress-sts["StatefulSet"] + headless-svc[headless Service] + cfg-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph defaultns[namespace=default] + svc[ExternalName Service] + pod1((pod1)) --> svc + pod2((pod2)) --> svc + end + end + + node["db.tails-scales.ts.net (dst)"]:::tsnode + + svc -->|DNS points to| headless-svc + headless-svc -->|selects egress Pod| egress + egress -->|forwards traffic| node + operator -.->|creates| egress-sts + egress-sts -.->|manages| egress + operator -.->|creates| headless-svc + operator -.->|creates| cfg-secret + operator -.->|watches & updates| svc + cfg-secret -.->|mounted| egress + egress -.->|stores state| state-secret + + linkStyle 0 stroke:red; + linkStyle 6 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 4 stroke:blue; + linkStyle 5 stroke:blue; + +``` + +## `ProxyGroup` + +### Egress + +[Documentation][kb-operator-l3-egress-proxygroup] + +The `ProxyGroup` custom resource manages a collection of proxy Pods that +can be configured to egress traffic out of the cluster via ExternalName +Services. A `ProxyGroup` is both a high availability (HA) version of L3 +egress, and a mechanism to serve multiple ExternalName Services on a single +set of Tailscale devices (coalescing). + +In this diagram, the `ProxyGroup` is named `pg`. The Secrets associated with +the `ProxyGroup` Pods are omitted for simplicity. They are similar to the L3 +egress case above, but there is a pair of config + state Secrets _per Pod_. + +Each ExternalName Service defines which ports should be mapped to their defined +egress target. The operator maps from these ports to randomly chosen ephemeral +ports via the ClusterIP Service and its EndpointSlice. The operator then +generates the egress ConfigMap that tells the `ProxyGroup` Pods which incoming +ports map to which egress targets. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart LR + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + pg-sts[StatefulSet] + pg-0(("pg-0 (src)")):::tsnode + pg-1(("pg-1 (src)")):::tsnode + db-cluster-ip[db ClusterIP Service] + api-cluster-ip[api ClusterIP Service] + egress-cm["egress ConfigMap"] + end + + subgraph cluster-scope["Cluster scoped resources"] + pg["ProxyGroup 'pg'"] + end + + subgraph defaultns[namespace=default] + db-svc[db ExternalName Service] + api-svc[api ExternalName Service] + pod1((pod1)) --> db-svc + pod2((pod2)) --> db-svc + pod1((pod1)) --> api-svc + pod2((pod2)) --> api-svc + end + end + + db["db.tails-scales.ts.net (dst)"]:::tsnode + api["api.tails-scales.ts.net (dst)"]:::tsnode + + db-svc -->|DNS points to| db-cluster-ip + api-svc -->|DNS points to| api-cluster-ip + db-cluster-ip -->|maps to ephemeral db ports| pg-0 + db-cluster-ip -->|maps to ephemeral db ports| pg-1 + api-cluster-ip -->|maps to ephemeral api ports| pg-0 + api-cluster-ip -->|maps to ephemeral api ports| pg-1 + pg-0 -->|forwards db port traffic| db + pg-0 -->|forwards api port traffic| api + pg-1 -->|forwards db port traffic| db + pg-1 -->|forwards api port traffic| api + operator -.->|creates & populates endpointslice| db-cluster-ip + operator -.->|creates & populates endpointslice| api-cluster-ip + operator -.->|stores port mapping| egress-cm + egress-cm -.->|mounted| pg-0 + egress-cm -.->|mounted| pg-1 + operator -.->|watches| pg + operator -.->|creates| pg-sts + pg-sts -.->|manages| pg-0 + pg-sts -.->|manages| pg-1 + operator -.->|watches| db-svc + operator -.->|watches| api-svc + + linkStyle 0 stroke:red; + linkStyle 12 stroke:red; + linkStyle 13 stroke:red; + linkStyle 14 stroke:red; + linkStyle 15 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 4 stroke:blue; + linkStyle 5 stroke:blue; + linkStyle 6 stroke:blue; + linkStyle 7 stroke:blue; + linkStyle 8 stroke:blue; + linkStyle 9 stroke:blue; + linkStyle 10 stroke:blue; + linkStyle 11 stroke:blue; + +``` + +### Ingress + +A ProxyGroup can also serve as a highly available set of proxies for an +Ingress resource. The `-0` Pod is always the replica that will issue a certificate +from Let's Encrypt. + +If the same Ingress config is applied in multiple clusters, ProxyGroup proxies +from each cluster will be valid targets for the ts.net DNS name, and the proxy +each client is routed to will depend on the same rules as for [high availability][kb-ha] +subnet routers, and is encoded in the client's netmap. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart LR + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + ingress-sts["StatefulSet"] + serve-cm[serve config ConfigMap] + ingress-0(("pg-0 (dst)")):::tsnode + ingress-1(("pg-1 (dst)")):::tsnode + tls-secret[myapp.tails.ts.net Secret] + end + + subgraph defaultns[namespace=default] + ingress[myapp.tails.ts.net Ingress] + svc["myapp Service"] + svc --> pod1((pod1)) + svc --> pod2((pod2)) + end + + subgraph cluster[Cluster scoped resources] + ingress-class[Tailscale IngressClass] + pg[ProxyGroup 'pg'] + end + end + + control["Tailscale control plane"] + ts-svc["myapp Tailscale Service"] + + client["client (src)"]:::tsnode -->|dials https\://myapp.tails.ts.net/api| ingress-1 + ingress-0 -->|forwards traffic| svc + ingress-1 -->|forwards traffic| svc + control -.->|creates| ts-svc + operator -.->|creates myapp Tailscale Service| control + control -.->|netmap points myapp Tailscale Service to pg-1| client + operator -.->|creates| ingress-sts + ingress-sts -.->|manages| ingress-0 + ingress-sts -.->|manages| ingress-1 + ingress-0 -.->|issues myapp.tails.ts.net cert| le[Let's Encrypt] + ingress-0 -.->|stores cert| tls-secret + ingress-1 -.->|reads cert| tls-secret + operator -.->|watches| ingress + operator -.->|watches| pg + operator -.->|creates| serve-cm + serve-cm -.->|mounted| ingress-0 + serve-cm -.->|mounted| ingress-1 + ingress -.->|/api prefix| svc + + linkStyle 0 stroke:red; + linkStyle 4 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 5 stroke:blue; + linkStyle 6 stroke:blue; + +``` + +## Connector + +[Subnet router and exit node documentation][kb-operator-connector] + +[App connector documentation][kb-operator-app-connector] + +The Connector Custom Resource can deploy either a subnet router, an exit node, +or an app connector. The following diagram shows all 3, but only one workflow +can be configured per Connector resource. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + classDef hidden display:none; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph grouping[" "] + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + cn-sts[StatefulSet] + cn-pod(("tailscale (dst)")):::tsnode + cfg-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph cluster-scope["Cluster scoped resources"] + cn["Connector"] + end + + subgraph defaultns["namespace=default"] + pod1 + end + end + + client["client (src)"]:::tsnode + Internet + end + + client --> cn-pod + cn-pod -->|app connector or exit node routes| Internet + cn-pod -->|subnet route| pod1 + operator -.->|watches| cn + operator -.->|creates| cn-sts + cn-sts -.->|manages| cn-pod + operator -.->|creates| cfg-secret + cfg-secret -.->|mounted| cn-pod + cn-pod -.->|stores state| state-secret + + class grouping hidden + + linkStyle 0 stroke:red; + linkStyle 2 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 4 stroke:blue; + +``` + +## Recorder nodes + +[Documentation][kb-operator-recorder] + +The `Recorder` custom resource makes it easier to deploy `tsrecorder` to a cluster. +It currently only supports a single replica. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + classDef hidden display:none; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph grouping[" "] + subgraph k8s[Kubernetes cluster] + api["kube-apiserver"] + + subgraph tailscale-ns[namespace=tailscale] + operator(("operator (dst)")):::tsnode + rec-sts[StatefulSet] + rec-0(("tsrecorder")):::tsnode + cfg-secret-0["config Secret"] + state-secret-0["state Secret"] + end + + subgraph cluster-scope["Cluster scoped resources"] + rec["Recorder"] + end + end + + client["client (src)"]:::tsnode + kubectl-exec["kubectl exec (src)"]:::tsnode + server["server (dst)"]:::tsnode + s3["S3-compatible storage"] + end + + kubectl-exec -->|exec session| operator + operator -->|exec session recording| rec-0 + operator -->|exec session| api + client -->|ssh session| server + server -->|ssh session recording| rec-0 + rec-0 -->|session recordings| s3 + operator -.->|watches| rec + operator -.->|creates| rec-sts + rec-sts -.->|manages| rec-0 + operator -.->|creates| cfg-secret-0 + cfg-secret-0 -.->|mounted| rec-0 + rec-0 -.->|stores state| state-secret-0 + + class grouping hidden + + linkStyle 0 stroke:red; + linkStyle 2 stroke:red; + linkStyle 3 stroke:red; + linkStyle 5 stroke:red; + linkStyle 6 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 4 stroke:blue; + linkStyle 7 stroke:blue; + +``` + +[kb-operator]: https://tailscale.com/kb/1236/kubernetes-operator +[kb-operator-proxy]: https://tailscale.com/kb/1437/kubernetes-operator-api-server-proxy +[kb-operator-l3-ingress]: https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress#exposing-a-cluster-workload-using-a-kubernetes-service +[kb-operator-l7-ingress]: https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress#exposing-cluster-workloads-using-a-kubernetes-ingress +[kb-operator-l3-egress]: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress +[kb-operator-l3-egress-proxygroup]: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress#configure-an-egress-service-using-proxygroup +[kb-operator-connector]: https://tailscale.com/kb/1441/kubernetes-operator-connector +[kb-operator-app-connector]: https://tailscale.com/kb/1517/kubernetes-operator-app-connector +[kb-operator-recorder]: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder +[kb-ha]: https://tailscale.com/kb/1115/high-availability +[k8s-impersonation]: https://kubernetes.io/docs/reference/access-authn-authz/authentication/#user-impersonation diff --git a/docs/k8s/proxy.yaml b/docs/k8s/proxy.yaml index 2ab7ed334..048fd7a5b 100644 --- a/docs/k8s/proxy.yaml +++ b/docs/k8s/proxy.yaml @@ -44,7 +44,13 @@ spec: value: "{{TS_DEST_IP}}" - name: TS_AUTH_ONCE value: "true" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/role.yaml b/docs/k8s/role.yaml index 6d6a8117d..d7d0846ab 100644 --- a/docs/k8s/role.yaml +++ b/docs/k8s/role.yaml @@ -13,3 +13,6 @@ rules: resourceNames: ["{{TS_KUBE_SECRET}}"] resources: ["secrets"] verbs: ["get", "update", "patch"] +- apiGroups: [""] # "" indicates the core API group + resources: ["events"] + verbs: ["get", "create", "patch"] diff --git a/docs/k8s/sidecar.yaml b/docs/k8s/sidecar.yaml index 7efd32a38..520e4379a 100644 --- a/docs/k8s/sidecar.yaml +++ b/docs/k8s/sidecar.yaml @@ -26,7 +26,13 @@ spec: name: tailscale-auth key: TS_AUTHKEY optional: true + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/subnet.yaml b/docs/k8s/subnet.yaml index 4b7066fb3..ef4e4748c 100644 --- a/docs/k8s/subnet.yaml +++ b/docs/k8s/subnet.yaml @@ -28,7 +28,13 @@ spec: optional: true - name: TS_ROUTES value: "{{TS_ROUTES}}" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/userspace-sidecar.yaml b/docs/k8s/userspace-sidecar.yaml index fc4ed6350..ee19b10a5 100644 --- a/docs/k8s/userspace-sidecar.yaml +++ b/docs/k8s/userspace-sidecar.yaml @@ -27,3 +27,11 @@ spec: name: tailscale-auth key: TS_AUTHKEY optional: true + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid diff --git a/docs/windows/policy/en-US/tailscale.adml b/docs/windows/policy/en-US/tailscale.adml index 7a658422c..a0be5e831 100644 --- a/docs/windows/policy/en-US/tailscale.adml +++ b/docs/windows/policy/en-US/tailscale.adml @@ -15,34 +15,45 @@ Tailscale version 1.58.0 and later Tailscale version 1.62.0 and later Tailscale version 1.74.0 and later + Tailscale version 1.78.0 and later + Tailscale version 1.80.0 and later + Tailscale version 1.82.0 and later + Tailscale version 1.84.0 and later + Tailscale version 1.86.0 and later + Tailscale version 1.90.0 and later Tailscale UI customization Settings + Allowed + Allowed (with audit) + Not Allowed Require using a specific Tailscale coordination server +If you disable or do not configure this policy, the Tailscale SaaS coordination server will be used by default, but a non-standard Tailscale coordination server can be configured using the CLI. + +See https://tailscale.com/kb/1315/mdm-keys#set-a-custom-control-server-url for more details.]]> Require using a specific Tailscale log server Specify which Tailnet should be used for Login +See https://tailscale.com/kb/1315/mdm-keys#set-a-suggested-or-required-tailnet for more details.]]> Specify the auth key to authenticate devices without user interaction Require using a specific Exit Node +If you do not configure this policy, no exit node will be used by default but an exit node (if one is available and permitted by ACLs) can be chosen by the user if desired. + +See https://tailscale.com/kb/1315/mdm-keys#force-an-exit-node-to-always-be-used and https://tailscale.com/kb/1103/exit-nodes for more details.]]> + Limit automated Exit Node suggestions to specific nodes + Allow incoming connections +If you do not configure this policy, then Allow Incoming Connections depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-to-allow-incoming-connections and https://tailscale.com/kb/1072/client-preferences#allow-incoming-connections for more details.]]> Run Tailscale in Unattended Mode +If you do not configure this policy, then Run Unattended depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-unattended-mode and https://tailscale.com/kb/1088/run-unattended for more details.]]> + Restrict users from disconnecting Tailscale (always-on mode) + + Configure automatic reconnect delay + + Allow users to restart tailscaled + Allow Local Network Access when an Exit Node is in use +If you do not configure this policy, then Allow Local Network Access depends on what is selected in the Exit Node submenu. + +See https://tailscale.com/kb/1315/mdm-keys#toggle-local-network-access-when-an-exit-node-is-in-use and https://tailscale.com/kb/1103/exit-nodes#step-4-use-the-exit-node for more details.]]> Use Tailscale DNS Settings +If you do not configure this policy, then Use Tailscale DNS depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-the-device-uses-tailscale-dns-settings for more details.]]> Use Tailscale Subnets +If you do not configure this policy, then Use Tailscale Subnets depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-the-device-accepts-tailscale-subnets or https://tailscale.com/kb/1019/subnets for more details.]]> + Always register + Use adapter properties + Register Tailscale IP addresses in DNS + Automatically install updates +If you do not configure this policy, then Automatically Install Updates depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1067/update#auto-updates for more details.]]> Run Tailscale as an Exit Node - Show the "Admin Panel" menu item - + Show the "Admin Console" menu item + Show the "Debug" submenu +If you disable this policy, the Debug submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-debug-menu for more details.]]> Show the "Update Available" menu item +If you disable this policy, the Update Available item will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-update-menu for more details.]]> Show the "Run Exit Node" menu item +If you disable this policy, the Run Exit Node item will be hidden from the Exit Node submenu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-run-as-exit-node-menu-item for more details.]]> Show the "Preferences" submenu +If you disable this policy, the Preferences submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-preferences-menu for more details.]]> Show the "Exit Node" submenu +If you disable this policy, the Exit Node submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-exit-node-picker for more details.]]> Specify a custom key expiration notification time +If you disable or don't configure this policy, the default time period will be used (as of Tailscale 1.56, this is 24 hours). + +See https://tailscale.com/kb/1315/mdm-keys#set-the-key-expiration-notice-period for more details.]]> Log extra details about service events Collect data for posture checking +If you do not configure this policy, then data collection depends on if it has been enabled from the CLI (as of Tailscale 1.56), it may be present in the GUI in later versions. + +See https://tailscale.com/kb/1315/mdm-keys#enable-gathering-device-posture-data and https://tailscale.com/kb/1326/device-identity for more details.]]> Show the "Managed By {Organization}" menu item + Show the onboarding flow + + Encrypt client state file stored on disk + @@ -239,10 +319,27 @@ See https://tailscale.com/kb/1315/mdm-keys#set-your-organization-name for more d + + The options below allow configuring exceptions where disconnecting Tailscale is permitted. + Disconnects with reason: + + + The delay must be a valid Go duration string, such as 30s, 5m, or 1h30m, all without spaces or any other symbols. + + + + + User override: + + + Registration mode: + + + Target IDs: diff --git a/docs/windows/policy/tailscale.admx b/docs/windows/policy/tailscale.admx index e70f124ed..7bd31ac9c 100644 --- a/docs/windows/policy/tailscale.admx +++ b/docs/windows/policy/tailscale.admx @@ -50,6 +50,30 @@ displayName="$(string.SINCE_V1_74)"> + + + + + + + + + + + + + + + + + + @@ -95,6 +119,25 @@ + + + + + + + + + + + + + + + + + + + @@ -117,6 +160,47 @@ never + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -147,6 +231,24 @@ never + + + + + + + + always + + + + + user-decides + + + + + @@ -197,7 +299,7 @@ - + @@ -207,7 +309,7 @@ hide - + @@ -217,7 +319,7 @@ hide - + @@ -227,7 +329,7 @@ hide - + @@ -237,7 +339,7 @@ hide - + @@ -247,7 +349,7 @@ hide - + @@ -257,7 +359,7 @@ hide - + @@ -267,7 +369,17 @@ hide - + + + + + show + + + hide + + + @@ -276,12 +388,22 @@ - + + + + + + + + + + + diff --git a/doctor/ethtool/ethtool_linux.go b/doctor/ethtool/ethtool_linux.go index b8cc08002..f6eaac1df 100644 --- a/doctor/ethtool/ethtool_linux.go +++ b/doctor/ethtool/ethtool_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package ethtool import ( diff --git a/doctor/ethtool/ethtool_other.go b/doctor/ethtool/ethtool_other.go index 9aaa9dda8..7af74eec8 100644 --- a/doctor/ethtool/ethtool_other.go +++ b/doctor/ethtool/ethtool_other.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android package ethtool diff --git a/drive/drive_view.go b/drive/drive_view.go index a6adfbc70..b481751bb 100644 --- a/drive/drive_view.go +++ b/drive/drive_view.go @@ -6,15 +6,17 @@ package drive import ( - "encoding/json" + jsonv1 "encoding/json" "errors" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/views" ) //go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=Share -// View returns a readonly view of Share. +// View returns a read-only view of Share. func (p *Share) View() ShareView { return ShareView{Đļ: p} } @@ -30,7 +32,7 @@ type ShareView struct { Đļ *Share } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ShareView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -42,8 +44,17 @@ func (v ShareView) AsStruct() *Share { return v.Đļ.Clone() } -func (v ShareView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ShareView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ShareView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *ShareView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -52,16 +63,44 @@ func (v *ShareView) UnmarshalJSON(b []byte) error { return nil } var x Share - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *ShareView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Share + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } +// Name is how this share appears on remote nodes. func (v ShareView) Name() string { return v.Đļ.Name } + +// Path is the path to the directory on this machine that's being shared. func (v ShareView) Path() string { return v.Đļ.Path } -func (v ShareView) As() string { return v.Đļ.As } + +// As is the UNIX or Windows username of the local account used for this +// share. File read/write permissions are enforced based on this username. +// Can be left blank to use the default value of "whoever is running the +// Tailscale GUI". +func (v ShareView) As() string { return v.Đļ.As } + +// BookmarkData contains security-scoped bookmark data for the Sandboxed +// Mac application. The Sandboxed Mac application gains permission to +// access the Share's folder as a result of a user selecting it in a file +// picker. In order to retain access to it across restarts, it needs to +// hold on to a security-scoped bookmark. That bookmark is stored here. See +// https://developer.apple.com/documentation/security/app_sandbox/accessing_files_from_the_macos_app_sandbox#4144043 func (v ShareView) BookmarkData() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.Đļ.BookmarkData) } diff --git a/drive/driveimpl/compositedav/stat_cache.go b/drive/driveimpl/compositedav/stat_cache.go index fc57ff064..36463fe7e 100644 --- a/drive/driveimpl/compositedav/stat_cache.go +++ b/drive/driveimpl/compositedav/stat_cache.go @@ -8,6 +8,7 @@ import ( "encoding/xml" "log" "net/http" + "net/url" "sync" "time" @@ -165,7 +166,12 @@ func (c *StatCache) set(name string, depth int, ce *cacheEntry) { children = make(map[string]*cacheEntry, len(ms.Responses)-1) for i := 0; i < len(ms.Responses); i++ { response := ms.Responses[i] - name := shared.Normalize(response.Href) + name, err := url.PathUnescape(response.Href) + if err != nil { + log.Printf("statcache.set child parse error: %s", err) + return + } + name = shared.Normalize(name) raw := marshalMultiStatus(response) entry := newCacheEntry(ce.Status, raw) if i == 0 { diff --git a/drive/driveimpl/compositedav/stat_cache_test.go b/drive/driveimpl/compositedav/stat_cache_test.go index fa63457a2..baa4fdda2 100644 --- a/drive/driveimpl/compositedav/stat_cache_test.go +++ b/drive/driveimpl/compositedav/stat_cache_test.go @@ -16,12 +16,12 @@ import ( "tailscale.com/tstest" ) -var parentPath = "/parent" +var parentPath = "/parent with spaces" -var childPath = "/parent/child.txt" +var childPath = "/parent with spaces/child.txt" var parentResponse = ` -/parent/ +/parent%20with%20spaces/ Mon, 29 Apr 2024 19:52:23 GMT @@ -36,7 +36,7 @@ var parentResponse = ` var childResponse = ` -/parent/child.txt +/parent%20with%20spaces/child.txt Mon, 29 Apr 2024 19:52:23 GMT diff --git a/drive/driveimpl/connlistener.go b/drive/driveimpl/connlistener.go index e1fcb3b67..ff60f7340 100644 --- a/drive/driveimpl/connlistener.go +++ b/drive/driveimpl/connlistener.go @@ -25,12 +25,12 @@ func newConnListener() *connListener { } } -func (l *connListener) Accept() (net.Conn, error) { +func (ln *connListener) Accept() (net.Conn, error) { select { - case <-l.closedCh: + case <-ln.closedCh: // TODO(oxtoacart): make this error match what a regular net.Listener does return nil, syscall.EINVAL - case conn := <-l.ch: + case conn := <-ln.ch: return conn, nil } } @@ -38,32 +38,32 @@ func (l *connListener) Accept() (net.Conn, error) { // Addr implements net.Listener. This always returns nil. It is assumed that // this method is currently unused, so it logs a warning if it ever does get // called. -func (l *connListener) Addr() net.Addr { +func (ln *connListener) Addr() net.Addr { log.Println("warning: unexpected call to connListener.Addr()") return nil } -func (l *connListener) Close() error { - l.closeMu.Lock() - defer l.closeMu.Unlock() +func (ln *connListener) Close() error { + ln.closeMu.Lock() + defer ln.closeMu.Unlock() select { - case <-l.closedCh: + case <-ln.closedCh: // Already closed. return syscall.EINVAL default: // We don't close l.ch because someone maybe trying to send to that, // which would cause a panic. - close(l.closedCh) + close(ln.closedCh) return nil } } -func (l *connListener) HandleConn(c net.Conn, remoteAddr net.Addr) error { +func (ln *connListener) HandleConn(c net.Conn, remoteAddr net.Addr) error { select { - case <-l.closedCh: + case <-ln.closedCh: return syscall.EINVAL - case l.ch <- &connWithRemoteAddr{Conn: c, remoteAddr: remoteAddr}: + case ln.ch <- &connWithRemoteAddr{Conn: c, remoteAddr: remoteAddr}: // Connection has been accepted. } return nil diff --git a/drive/driveimpl/connlistener_test.go b/drive/driveimpl/connlistener_test.go index d8666448a..6adf15acb 100644 --- a/drive/driveimpl/connlistener_test.go +++ b/drive/driveimpl/connlistener_test.go @@ -10,20 +10,20 @@ import ( ) func TestConnListener(t *testing.T) { - l, err := net.Listen("tcp", "127.0.0.1:") + ln, err := net.Listen("tcp", "127.0.0.1:") if err != nil { t.Fatalf("failed to Listen: %s", err) } cl := newConnListener() // Test that we can accept a connection - cc, err := net.Dial("tcp", l.Addr().String()) + cc, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("failed to Dial: %s", err) } defer cc.Close() - sc, err := l.Accept() + sc, err := ln.Accept() if err != nil { t.Fatalf("failed to Accept: %s", err) } diff --git a/drive/driveimpl/dirfs/dirfs.go b/drive/driveimpl/dirfs/dirfs.go index c1f28bb9d..50a3330a9 100644 --- a/drive/driveimpl/dirfs/dirfs.go +++ b/drive/driveimpl/dirfs/dirfs.go @@ -44,7 +44,7 @@ func (c *Child) isAvailable() bool { // Any attempts to perform operations on paths inside of children will result // in a panic, as these are not expected to be performed on this FS. // -// An FS an optionally have a StaticRoot, which will insert a folder with that +// An FS can optionally have a StaticRoot, which will insert a folder with that // StaticRoot into the tree, like this: // // -- diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index 20b179511..818e84990 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -133,6 +133,71 @@ func TestPermissions(t *testing.T) { } } +// TestMissingPaths verifies that the fileserver running at localhost +// correctly handles paths with missing required components. +// +// Expected path format: +// http://localhost:[PORT]//[/] +func TestMissingPaths(t *testing.T) { + s := newSystem(t) + + fileserverAddr := s.addRemote(remote1) + s.addShare(remote1, share11, drive.PermissionReadWrite) + + client := &http.Client{ + Transport: &http.Transport{DisableKeepAlives: true}, + } + addr := strings.Split(fileserverAddr, "|")[1] + secretToken := strings.Split(fileserverAddr, "|")[0] + + testCases := []struct { + name string + path string + wantStatus int + }{ + { + name: "empty path", + path: "", + wantStatus: http.StatusForbidden, + }, + { + name: "single slash", + path: "/", + wantStatus: http.StatusForbidden, + }, + { + name: "only token", + path: "/" + secretToken, + wantStatus: http.StatusBadRequest, + }, + { + name: "token with trailing slash", + path: "/" + secretToken + "/", + wantStatus: http.StatusBadRequest, + }, + { + name: "token and invalid share", + path: "/" + secretToken + "/nonexistentshare", + wantStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + u := fmt.Sprintf("http://%s%s", addr, tc.path) + resp, err := client.Get(u) + if err != nil { + t.Fatalf("unexpected error making request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantStatus { + t.Errorf("got status code %d, want %d", resp.StatusCode, tc.wantStatus) + } + }) + } +} + // TestSecretTokenAuth verifies that the fileserver running at localhost cannot // be accessed directly without the correct secret token. This matters because // if a victim can be induced to visit the localhost URL and access a malicious @@ -402,14 +467,14 @@ func newSystem(t *testing.T) *system { tstest.ResourceCheck(t) fs := newFileSystemForLocal(log.Printf, nil) - l, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("failed to Listen: %s", err) } - t.Logf("FileSystemForLocal listening at %s", l.Addr()) + t.Logf("FileSystemForLocal listening at %s", ln.Addr()) go func() { for { - conn, err := l.Accept() + conn, err := ln.Accept() if err != nil { t.Logf("Accept: %v", err) return @@ -418,11 +483,11 @@ func newSystem(t *testing.T) *system { } }() - client := gowebdav.NewAuthClient(fmt.Sprintf("http://%s", l.Addr()), &noopAuthorizer{}) + client := gowebdav.NewAuthClient(fmt.Sprintf("http://%s", ln.Addr()), &noopAuthorizer{}) client.SetTransport(&http.Transport{DisableKeepAlives: true}) s := &system{ t: t, - local: &local{l: l, fs: fs}, + local: &local{l: ln, fs: fs}, client: client, remotes: make(map[string]*remote), } @@ -431,11 +496,11 @@ func newSystem(t *testing.T) *system { } func (s *system) addRemote(name string) string { - l, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { s.t.Fatalf("failed to Listen: %s", err) } - s.t.Logf("Remote for %v listening at %s", name, l.Addr()) + s.t.Logf("Remote for %v listening at %s", name, ln.Addr()) fileServer, err := NewFileServer() if err != nil { @@ -445,21 +510,21 @@ func (s *system) addRemote(name string) string { s.t.Logf("FileServer for %v listening at %s", name, fileServer.Addr()) r := &remote{ - l: l, + l: ln, fileServer: fileServer, fs: NewFileSystemForRemote(log.Printf), shares: make(map[string]string), permissions: make(map[string]drive.Permission), } r.fs.SetFileServerAddr(fileServer.Addr()) - go http.Serve(l, r) + go http.Serve(ln, r) s.remotes[name] = r remotes := make([]*drive.Remote, 0, len(s.remotes)) for name, r := range s.remotes { remotes = append(remotes, &drive.Remote{ Name: name, - URL: fmt.Sprintf("http://%s", r.l.Addr()), + URL: func() string { return fmt.Sprintf("http://%s", r.l.Addr()) }, }) } s.local.fs.SetRemotes( @@ -704,8 +769,8 @@ func (a *noopAuthenticator) Close() error { return nil } -const lockBody = ` - - - +const lockBody = ` + + + ` diff --git a/drive/driveimpl/fileserver.go b/drive/driveimpl/fileserver.go index 0067c1cc7..d448d83af 100644 --- a/drive/driveimpl/fileserver.go +++ b/drive/driveimpl/fileserver.go @@ -20,7 +20,7 @@ import ( // It's typically used in a separate process from the actual Taildrive server to // serve up files as an unprivileged user. type FileServer struct { - l net.Listener + ln net.Listener secretToken string shareHandlers map[string]http.Handler sharesMu sync.RWMutex @@ -41,10 +41,10 @@ type FileServer struct { // called. func NewFileServer() (*FileServer, error) { // path := filepath.Join(os.TempDir(), fmt.Sprintf("%v.socket", uuid.New().String())) - // l, err := safesocket.Listen(path) + // ln, err := safesocket.Listen(path) // if err != nil { // TODO(oxtoacart): actually get safesocket working in more environments (MacOS Sandboxed, Windows, ???) - l, err := net.Listen("tcp", "127.0.0.1:0") + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, fmt.Errorf("listen: %w", err) } @@ -55,13 +55,13 @@ func NewFileServer() (*FileServer, error) { } return &FileServer{ - l: l, + ln: ln, secretToken: secretToken, shareHandlers: make(map[string]http.Handler), }, nil } -// generateSecretToken generates a hex-encoded 256 bit secet. +// generateSecretToken generates a hex-encoded 256 bit secret. func generateSecretToken() (string, error) { tokenBytes := make([]byte, 32) _, err := rand.Read(tokenBytes) @@ -74,12 +74,12 @@ func generateSecretToken() (string, error) { // Addr returns the address at which this FileServer is listening. This // includes the secret token in front of the address, delimited by a pipe |. func (s *FileServer) Addr() string { - return fmt.Sprintf("%s|%s", s.secretToken, s.l.Addr().String()) + return fmt.Sprintf("%s|%s", s.secretToken, s.ln.Addr().String()) } // Serve() starts serving files and blocks until it encounters a fatal error. func (s *FileServer) Serve() error { - return http.Serve(s.l, s) + return http.Serve(s.ln, s) } // LockShares locks the map of shares in preparation for manipulating it. @@ -142,6 +142,10 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if len(parts) < 2 { + w.WriteHeader(http.StatusBadRequest) + return + } r.URL.Path = shared.Join(parts[2:]...) share := parts[1] s.sharesMu.RLock() @@ -158,5 +162,5 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (s *FileServer) Close() error { - return s.l.Close() + return s.ln.Close() } diff --git a/drive/driveimpl/local_impl.go b/drive/driveimpl/local_impl.go index 8cdf60179..871d03343 100644 --- a/drive/driveimpl/local_impl.go +++ b/drive/driveimpl/local_impl.go @@ -81,7 +81,7 @@ func (s *FileSystemForLocal) SetRemotes(domain string, remotes []*drive.Remote, Name: remote.Name, Available: remote.Available, }, - BaseURL: func() (string, error) { return remote.URL, nil }, + BaseURL: func() (string, error) { return remote.URL(), nil }, Transport: transport, }) } diff --git a/drive/driveimpl/remote_impl.go b/drive/driveimpl/remote_impl.go index 7fd5d3325..2ff98075e 100644 --- a/drive/driveimpl/remote_impl.go +++ b/drive/driveimpl/remote_impl.go @@ -333,8 +333,14 @@ func (s *userServer) run() error { args = append(args, s.Name, s.Path) } var cmd *exec.Cmd - if su := s.canSU(); su != "" { - s.logf("starting taildrive file server as user %q", s.username) + + if s.canSudo() { + s.logf("starting taildrive file server with sudo as user %q", s.username) + allArgs := []string{"-n", "-u", s.username, s.executable} + allArgs = append(allArgs, args...) + cmd = exec.Command("sudo", allArgs...) + } else if su := s.canSU(); su != "" { + s.logf("starting taildrive file server with su as user %q", s.username) // Quote and escape arguments. Use single quotes to prevent shell substitutions. for i, arg := range args { args[i] = "'" + strings.ReplaceAll(arg, "'", "'\"'\"'") + "'" @@ -343,7 +349,7 @@ func (s *userServer) run() error { allArgs := []string{s.username, "-c", cmdString} cmd = exec.Command(su, allArgs...) } else { - // If we were root, we should have been able to sudo as a specific + // If we were root, we should have been able to sudo or su as a specific // user, but let's check just to make sure, since we never want to // access shared folders as root. err := s.assertNotRoot() @@ -409,6 +415,18 @@ var writeMethods = map[string]bool{ "DELETE": true, } +// canSudo checks wether we can sudo -u the configured executable as the +// configured user by attempting to call the executable with the '-h' flag to +// print help. +func (s *userServer) canSudo() bool { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := exec.CommandContext(ctx, "sudo", "-n", "-u", s.username, s.executable, "-h").Run(); err != nil { + return false + } + return true +} + // canSU checks whether the current process can run su with the right username. // If su can be run, this returns the path to the su command. // If not, this returns the empty string "". diff --git a/drive/driveimpl/shared/pathutil.go b/drive/driveimpl/shared/pathutil.go index efa9f5f32..fcadcdd5a 100644 --- a/drive/driveimpl/shared/pathutil.go +++ b/drive/driveimpl/shared/pathutil.go @@ -22,6 +22,9 @@ const ( // CleanAndSplit cleans the provided path p and splits it into its constituent // parts. This is different from path.Split which just splits a path into prefix // and suffix. +// +// If p is empty or contains only path separators, CleanAndSplit returns a slice +// of length 1 whose only element is "". func CleanAndSplit(p string) []string { return strings.Split(strings.Trim(path.Clean(p), sepStringAndDot), sepString) } @@ -38,6 +41,8 @@ func Parent(p string) string { } // Join behaves like path.Join() but also includes a leading slash. +// +// When parts are missing, the result is "/". func Join(parts ...string) string { fullParts := make([]string, 0, len(parts)) fullParts = append(fullParts, sepString) diff --git a/drive/driveimpl/shared/pathutil_test.go b/drive/driveimpl/shared/pathutil_test.go index 662adbd8b..daee69563 100644 --- a/drive/driveimpl/shared/pathutil_test.go +++ b/drive/driveimpl/shared/pathutil_test.go @@ -40,6 +40,7 @@ func TestJoin(t *testing.T) { parts []string want string }{ + {[]string{}, "/"}, {[]string{""}, "/"}, {[]string{"a"}, "/a"}, {[]string{"/a"}, "/a"}, diff --git a/drive/local.go b/drive/local.go index aff79a57b..052efb3f9 100644 --- a/drive/local.go +++ b/drive/local.go @@ -17,7 +17,7 @@ import ( // Remote represents a remote Taildrive node. type Remote struct { Name string - URL string + URL func() string Available func() bool } diff --git a/drive/remote.go b/drive/remote.go index 9aeead710..2c6fba894 100644 --- a/drive/remote.go +++ b/drive/remote.go @@ -9,7 +9,6 @@ import ( "bytes" "errors" "net/http" - "regexp" "strings" ) @@ -21,10 +20,6 @@ var ( ErrInvalidShareName = errors.New("Share names may only contain the letters a-z, underscore _, parentheses (), or spaces") ) -var ( - shareNameRegex = regexp.MustCompile(`^[a-z0-9_\(\) ]+$`) -) - // AllowShareAs reports whether sharing files as a specific user is allowed. func AllowShareAs() bool { return !DisallowShareAs && doAllowShareAs() @@ -125,9 +120,26 @@ func NormalizeShareName(name string) (string, error) { // Trim whitespace name = strings.TrimSpace(name) - if !shareNameRegex.MatchString(name) { + if !validShareName(name) { return "", ErrInvalidShareName } return name, nil } + +func validShareName(name string) bool { + if name == "" { + return false + } + for _, r := range name { + if 'a' <= r && r <= 'z' || '0' <= r && r <= '9' { + continue + } + switch r { + case '_', ' ', '(', ')': + continue + } + return false + } + return true +} diff --git a/drive/remote_permissions.go b/drive/remote_permissions.go index d3d41c6ec..420eff9a0 100644 --- a/drive/remote_permissions.go +++ b/drive/remote_permissions.go @@ -32,7 +32,7 @@ type grant struct { Access string } -// ParsePermissions builds a Permissions map from a lis of raw grants. +// ParsePermissions builds a Permissions map from a list of raw grants. func ParsePermissions(rawGrants [][]byte) (Permissions, error) { permissions := make(Permissions) for _, rawGrant := range rawGrants { diff --git a/envknob/envknob.go b/envknob/envknob.go index f1925ccf4..17a21387e 100644 --- a/envknob/envknob.go +++ b/envknob/envknob.go @@ -17,6 +17,7 @@ package envknob import ( "bufio" + "errors" "fmt" "io" "log" @@ -27,18 +28,19 @@ import ( "slices" "strconv" "strings" - "sync" "sync/atomic" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/kube/kubetypes" + "tailscale.com/syncs" "tailscale.com/types/opt" "tailscale.com/version" "tailscale.com/version/distro" ) var ( - mu sync.Mutex + mu syncs.Mutex // +checklocks:mu set = map[string]string{} // +checklocks:mu @@ -410,12 +412,35 @@ func TKASkipSignatureCheck() bool { return Bool("TS_UNSAFE_SKIP_NKS_VERIFICATION // Kubernetes Operator components. func App() string { a := os.Getenv("TS_INTERNAL_APP") - if a == kubetypes.AppConnector || a == kubetypes.AppEgressProxy || a == kubetypes.AppIngressProxy || a == kubetypes.AppIngressResource { + if a == kubetypes.AppConnector || a == kubetypes.AppEgressProxy || a == kubetypes.AppIngressProxy || a == kubetypes.AppIngressResource || a == kubetypes.AppProxyGroupEgress || a == kubetypes.AppProxyGroupIngress { return a } return "" } +// IsCertShareReadOnlyMode returns true if this replica should never attempt to +// issue or renew TLS credentials for any of the HTTPS endpoints that it is +// serving. It should only return certs found in its cert store. Currently, +// this is used by the Kubernetes Operator's HA Ingress via VIPServices, where +// multiple Ingress proxy instances serve the same HTTPS endpoint with a shared +// TLS credentials. The TLS credentials should only be issued by one of the +// replicas. +// For HTTPS Ingress the operator and containerboot ensure +// that read-only replicas will not be serving the HTTPS endpoints before there +// is a shared cert available. +func IsCertShareReadOnlyMode() bool { + m := String("TS_CERT_SHARE_MODE") + return m == "ro" +} + +// IsCertShareReadWriteMode returns true if this instance is the replica +// responsible for issuing and renewing TLS certs in an HA setup with certs +// shared between multiple replicas. +func IsCertShareReadWriteMode() bool { + m := String("TS_CERT_SHARE_MODE") + return m == "rw" +} + // CrashOnUnexpected reports whether the Tailscale client should panic // on unexpected conditions. If TS_DEBUG_CRASH_ON_UNEXPECTED is set, that's // used. Otherwise the default value is true for unstable builds. @@ -439,7 +464,12 @@ var allowRemoteUpdate = RegisterBool("TS_ALLOW_ADMIN_CONSOLE_REMOTE_UPDATE") // AllowsRemoteUpdate reports whether this node has opted-in to letting the // Tailscale control plane initiate a Tailscale update (e.g. on behalf of an // admin on the admin console). -func AllowsRemoteUpdate() bool { return allowRemoteUpdate() } +func AllowsRemoteUpdate() bool { + if !buildfeatures.HasClientUpdate { + return false + } + return allowRemoteUpdate() +} // SetNoLogsNoSupport enables no-logs-no-support mode. func SetNoLogsNoSupport() { @@ -450,6 +480,9 @@ func SetNoLogsNoSupport() { var notInInit atomic.Bool func assertNotInInit() { + if !buildfeatures.HasDebug { + return + } if notInInit.Load() { return } @@ -503,12 +536,17 @@ func ApplyDiskConfigError() error { return applyDiskConfigErr } // // On macOS, use one of: // -// - ~/Library/Containers/io.tailscale.ipn.macsys/Data/tailscaled-env.txt +// - /private/var/root/Library/Containers/io.tailscale.ipn.macsys.network-extension/Data/tailscaled-env.txt // for standalone macOS GUI builds // - ~/Library/Containers/io.tailscale.ipn.macos.network-extension/Data/tailscaled-env.txt // for App Store builds // - /etc/tailscale/tailscaled-env.txt for tailscaled-on-macOS (homebrew, etc) func ApplyDiskConfig() (err error) { + if runtime.GOOS == "linux" && !(buildfeatures.HasDebug || buildfeatures.HasSynology) { + // This function does nothing on Linux, unless you're + // using TS_DEBUG_ENV_FILE or are on Synology. + return nil + } var f *os.File defer func() { if err != nil { @@ -533,44 +571,73 @@ func ApplyDiskConfig() (err error) { return applyKeyValueEnv(f) } - name := getPlatformEnvFile() - if name == "" { - return nil - } - f, err = os.Open(name) - if os.IsNotExist(err) { + names := getPlatformEnvFiles() + if len(names) == 0 { return nil } - if err != nil { - return err + + var errs []error + for _, name := range names { + f, err = os.Open(name) + if os.IsNotExist(err) { + continue + } + if err != nil { + errs = append(errs, err) + continue + } + defer f.Close() + + return applyKeyValueEnv(f) } - defer f.Close() - return applyKeyValueEnv(f) + + // If we have any errors, return them; if all errors are such that + // os.IsNotExist(err) returns true, then errs is empty and we will + // return nil. + return errors.Join(errs...) } -// getPlatformEnvFile returns the current platform's path to an optional -// tailscaled-env.txt file. It returns an empty string if none is defined -// for the platform. -func getPlatformEnvFile() string { +// getPlatformEnvFiles returns a list of paths to the current platform's +// optional tailscaled-env.txt file. It returns an empty list if none is +// defined for the platform. +func getPlatformEnvFiles() []string { switch runtime.GOOS { case "windows": - return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "tailscaled-env.txt") + return []string{ + filepath.Join(os.Getenv("ProgramData"), "Tailscale", "tailscaled-env.txt"), + } case "linux": - if distro.Get() == distro.Synology { - return "/etc/tailscale/tailscaled-env.txt" + if buildfeatures.HasSynology && distro.Get() == distro.Synology { + return []string{"/etc/tailscale/tailscaled-env.txt"} } case "darwin": if version.IsSandboxedMacOS() { // the two GUI variants (App Store or separate download) - // This will be user-visible as ~/Library/Containers/$VARIANT/Data/tailscaled-env.txt - // where $VARIANT is "io.tailscale.ipn.macsys" for macsys (downloadable mac GUI builds) - // or "io.tailscale.ipn.macos.network-extension" for App Store builds. - return filepath.Join(os.Getenv("HOME"), "tailscaled-env.txt") + // On the App Store variant, the home directory is set + // to something like: + // ~/Library/Containers/io.tailscale.ipn.macos.network-extension/Data + // + // On the macsys (downloadable Mac GUI) variant, the + // home directory can be unset, but we have a working + // directory that looks like: + // /private/var/root/Library/Containers/io.tailscale.ipn.macsys.network-extension/Data + // + // Try both and see if we can find the file in either + // location. + var candidates []string + if home := os.Getenv("HOME"); home != "" { + candidates = append(candidates, filepath.Join(home, "tailscaled-env.txt")) + } + if wd, err := os.Getwd(); err == nil { + candidates = append(candidates, filepath.Join(wd, "tailscaled-env.txt")) + } + + return candidates } else { // Open source / homebrew variable, running tailscaled-on-macOS. - return "/etc/tailscale/tailscaled-env.txt" + return []string{"/etc/tailscale/tailscaled-env.txt"} } } - return "" + return nil } // applyKeyValueEnv reads key=value lines r and calls Setenv for each. diff --git a/envknob/features.go b/envknob/featureknob/featureknob.go similarity index 56% rename from envknob/features.go rename to envknob/featureknob/featureknob.go index 9e5909de3..5a54a1c42 100644 --- a/envknob/features.go +++ b/envknob/featureknob/featureknob.go @@ -1,12 +1,15 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package envknob +// Package featureknob provides a facility to control whether features +// can run based on either an envknob or running OS / distro. +package featureknob import ( "errors" "runtime" + "tailscale.com/envknob" "tailscale.com/version" "tailscale.com/version/distro" ) @@ -16,10 +19,10 @@ import ( func CanRunTailscaleSSH() error { switch runtime.GOOS { case "linux": - if distro.Get() == distro.Synology && !UseWIPCode() { + if distro.Get() == distro.Synology && !envknob.UseWIPCode() { return errors.New("The Tailscale SSH server does not run on Synology.") } - if distro.Get() == distro.QNAP && !UseWIPCode() { + if distro.Get() == distro.QNAP && !envknob.UseWIPCode() { return errors.New("The Tailscale SSH server does not run on QNAP.") } // otherwise okay @@ -28,12 +31,23 @@ func CanRunTailscaleSSH() error { if version.IsSandboxedMacOS() { return errors.New("The Tailscale SSH server does not run in sandboxed Tailscale GUI builds.") } - case "freebsd", "openbsd": + case "freebsd", "openbsd", "plan9": default: return errors.New("The Tailscale SSH server is not supported on " + runtime.GOOS) } - if !CanSSHD() { + if !envknob.CanSSHD() { return errors.New("The Tailscale SSH server has been administratively disabled.") } return nil } + +// CanUseExitNode reports whether using an exit node is supported for the +// current os/distro. +func CanUseExitNode() error { + switch dist := distro.Get(); dist { + case distro.Synology, // see https://github.com/tailscale/tailscale/issues/1995 + distro.QNAP: + return errors.New("Tailscale exit nodes cannot be used on " + string(dist)) + } + return nil +} diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go index 350384b86..93302d0d2 100644 --- a/envknob/logknob/logknob.go +++ b/envknob/logknob/logknob.go @@ -11,7 +11,6 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/logger" - "tailscale.com/types/views" ) // TODO(andrew-d): should we have a package-global registry of logknobs? It @@ -59,7 +58,7 @@ func (lk *LogKnob) Set(v bool) { // about; we use this rather than a concrete type to avoid a circular // dependency. type NetMap interface { - SelfCapabilities() views.Slice[tailcfg.NodeCapability] + HasSelfCapability(tailcfg.NodeCapability) bool } // UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap @@ -68,8 +67,7 @@ func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { if lk.capName == "" { return } - - lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) + lk.cap.Store(nm.HasSelfCapability(lk.capName)) } // Do will call log with the provided format and arguments if any of the diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go index b2a376a25..aa4fb4421 100644 --- a/envknob/logknob/logknob_test.go +++ b/envknob/logknob/logknob_test.go @@ -11,6 +11,7 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/netmap" + "tailscale.com/util/set" ) var testKnob = NewLogKnob( @@ -63,11 +64,7 @@ func TestLogKnob(t *testing.T) { } testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - Capabilities: []tailcfg.NodeCapability{ - "https://tailscale.com/cap/testing", - }, - }).View(), + AllCaps: set.Of(tailcfg.NodeCapability("https://tailscale.com/cap/testing")), }) if !testKnob.shouldLog() { t.Errorf("expected shouldLog()=true") diff --git a/feature/ace/ace.go b/feature/ace/ace.go new file mode 100644 index 000000000..b6d36543c --- /dev/null +++ b/feature/ace/ace.go @@ -0,0 +1,25 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ace registers support for Alternate Connectivity Endpoints (ACE). +package ace + +import ( + "net/netip" + + "tailscale.com/control/controlhttp" + "tailscale.com/net/ace" + "tailscale.com/net/netx" +) + +func init() { + controlhttp.HookMakeACEDialer.Set(mkDialer) +} + +func mkDialer(dialer netx.DialFunc, aceHost string, optIP netip.Addr) netx.DialFunc { + return (&ace.Dialer{ + ACEHost: aceHost, + ACEHostIP: optIP, // may be zero + NetDialer: dialer, + }).Dial +} diff --git a/feature/appconnectors/appconnectors.go b/feature/appconnectors/appconnectors.go new file mode 100644 index 000000000..28f5ccde3 --- /dev/null +++ b/feature/appconnectors/appconnectors.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package appconnectors registers support for Tailscale App Connectors. +package appconnectors + +import ( + "encoding/json" + "net/http" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" +) + +func init() { + ipnlocal.RegisterC2N("GET /appconnector/routes", handleC2NAppConnectorDomainRoutesGet) +} + +// handleC2NAppConnectorDomainRoutesGet handles returning the domains +// that the app connector is responsible for, as well as the resolved +// IP addresses for each domain. If the node is not configured as +// an app connector, an empty map is returned. +func handleC2NAppConnectorDomainRoutesGet(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + logf := b.Logger() + logf("c2n: GET /appconnector/routes received") + + var res tailcfg.C2NAppConnectorDomainRoutesResponse + appConnector := b.AppConnector() + if appConnector == nil { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) + return + } + + res.Domains = appConnector.DomainRoutes() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} diff --git a/feature/buildfeatures/buildfeatures.go b/feature/buildfeatures/buildfeatures.go new file mode 100644 index 000000000..cdb31dc01 --- /dev/null +++ b/feature/buildfeatures/buildfeatures.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:generate go run gen.go + +// The buildfeatures package contains boolean constants indicating which +// features were included in the binary (via build tags), for use in dead code +// elimination when using separate build tag protected files is impractical +// or undesirable. +package buildfeatures diff --git a/feature/buildfeatures/feature_ace_disabled.go b/feature/buildfeatures/feature_ace_disabled.go new file mode 100644 index 000000000..b4808d497 --- /dev/null +++ b/feature/buildfeatures/feature_ace_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_ace + +package buildfeatures + +// HasACE is whether the binary was built with support for modular feature "Alternate Connectivity Endpoints". +// Specifically, it's whether the binary was NOT built with the "ts_omit_ace" build tag. +// It's a const so it can be used for dead code elimination. +const HasACE = false diff --git a/feature/buildfeatures/feature_ace_enabled.go b/feature/buildfeatures/feature_ace_enabled.go new file mode 100644 index 000000000..4812f9a61 --- /dev/null +++ b/feature/buildfeatures/feature_ace_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_ace + +package buildfeatures + +// HasACE is whether the binary was built with support for modular feature "Alternate Connectivity Endpoints". +// Specifically, it's whether the binary was NOT built with the "ts_omit_ace" build tag. +// It's a const so it can be used for dead code elimination. +const HasACE = true diff --git a/feature/buildfeatures/feature_acme_disabled.go b/feature/buildfeatures/feature_acme_disabled.go new file mode 100644 index 000000000..0a7f25a82 --- /dev/null +++ b/feature/buildfeatures/feature_acme_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_acme + +package buildfeatures + +// HasACME is whether the binary was built with support for modular feature "ACME TLS certificate management". +// Specifically, it's whether the binary was NOT built with the "ts_omit_acme" build tag. +// It's a const so it can be used for dead code elimination. +const HasACME = false diff --git a/feature/buildfeatures/feature_acme_enabled.go b/feature/buildfeatures/feature_acme_enabled.go new file mode 100644 index 000000000..f074bfb4e --- /dev/null +++ b/feature/buildfeatures/feature_acme_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_acme + +package buildfeatures + +// HasACME is whether the binary was built with support for modular feature "ACME TLS certificate management". +// Specifically, it's whether the binary was NOT built with the "ts_omit_acme" build tag. +// It's a const so it can be used for dead code elimination. +const HasACME = true diff --git a/feature/buildfeatures/feature_advertiseexitnode_disabled.go b/feature/buildfeatures/feature_advertiseexitnode_disabled.go new file mode 100644 index 000000000..d4fdcec22 --- /dev/null +++ b/feature/buildfeatures/feature_advertiseexitnode_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_advertiseexitnode + +package buildfeatures + +// HasAdvertiseExitNode is whether the binary was built with support for modular feature "Run an exit node". +// Specifically, it's whether the binary was NOT built with the "ts_omit_advertiseexitnode" build tag. +// It's a const so it can be used for dead code elimination. +const HasAdvertiseExitNode = false diff --git a/feature/buildfeatures/feature_advertiseexitnode_enabled.go b/feature/buildfeatures/feature_advertiseexitnode_enabled.go new file mode 100644 index 000000000..28246143e --- /dev/null +++ b/feature/buildfeatures/feature_advertiseexitnode_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_advertiseexitnode + +package buildfeatures + +// HasAdvertiseExitNode is whether the binary was built with support for modular feature "Run an exit node". +// Specifically, it's whether the binary was NOT built with the "ts_omit_advertiseexitnode" build tag. +// It's a const so it can be used for dead code elimination. +const HasAdvertiseExitNode = true diff --git a/feature/buildfeatures/feature_advertiseroutes_disabled.go b/feature/buildfeatures/feature_advertiseroutes_disabled.go new file mode 100644 index 000000000..59042720f --- /dev/null +++ b/feature/buildfeatures/feature_advertiseroutes_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_advertiseroutes + +package buildfeatures + +// HasAdvertiseRoutes is whether the binary was built with support for modular feature "Advertise routes for other nodes to use". +// Specifically, it's whether the binary was NOT built with the "ts_omit_advertiseroutes" build tag. +// It's a const so it can be used for dead code elimination. +const HasAdvertiseRoutes = false diff --git a/feature/buildfeatures/feature_advertiseroutes_enabled.go b/feature/buildfeatures/feature_advertiseroutes_enabled.go new file mode 100644 index 000000000..118fcd55d --- /dev/null +++ b/feature/buildfeatures/feature_advertiseroutes_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_advertiseroutes + +package buildfeatures + +// HasAdvertiseRoutes is whether the binary was built with support for modular feature "Advertise routes for other nodes to use". +// Specifically, it's whether the binary was NOT built with the "ts_omit_advertiseroutes" build tag. +// It's a const so it can be used for dead code elimination. +const HasAdvertiseRoutes = true diff --git a/feature/buildfeatures/feature_appconnectors_disabled.go b/feature/buildfeatures/feature_appconnectors_disabled.go new file mode 100644 index 000000000..64ea8f86b --- /dev/null +++ b/feature/buildfeatures/feature_appconnectors_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_appconnectors + +package buildfeatures + +// HasAppConnectors is whether the binary was built with support for modular feature "App Connectors support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_appconnectors" build tag. +// It's a const so it can be used for dead code elimination. +const HasAppConnectors = false diff --git a/feature/buildfeatures/feature_appconnectors_enabled.go b/feature/buildfeatures/feature_appconnectors_enabled.go new file mode 100644 index 000000000..e00eaffa3 --- /dev/null +++ b/feature/buildfeatures/feature_appconnectors_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_appconnectors + +package buildfeatures + +// HasAppConnectors is whether the binary was built with support for modular feature "App Connectors support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_appconnectors" build tag. +// It's a const so it can be used for dead code elimination. +const HasAppConnectors = true diff --git a/feature/buildfeatures/feature_aws_disabled.go b/feature/buildfeatures/feature_aws_disabled.go new file mode 100644 index 000000000..66b670c1f --- /dev/null +++ b/feature/buildfeatures/feature_aws_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_aws + +package buildfeatures + +// HasAWS is whether the binary was built with support for modular feature "AWS integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_aws" build tag. +// It's a const so it can be used for dead code elimination. +const HasAWS = false diff --git a/feature/buildfeatures/feature_aws_enabled.go b/feature/buildfeatures/feature_aws_enabled.go new file mode 100644 index 000000000..30203b2aa --- /dev/null +++ b/feature/buildfeatures/feature_aws_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_aws + +package buildfeatures + +// HasAWS is whether the binary was built with support for modular feature "AWS integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_aws" build tag. +// It's a const so it can be used for dead code elimination. +const HasAWS = true diff --git a/feature/buildfeatures/feature_bakedroots_disabled.go b/feature/buildfeatures/feature_bakedroots_disabled.go new file mode 100644 index 000000000..f203bc1b0 --- /dev/null +++ b/feature/buildfeatures/feature_bakedroots_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_bakedroots + +package buildfeatures + +// HasBakedRoots is whether the binary was built with support for modular feature "Embed CA (LetsEncrypt) x509 roots to use as fallback". +// Specifically, it's whether the binary was NOT built with the "ts_omit_bakedroots" build tag. +// It's a const so it can be used for dead code elimination. +const HasBakedRoots = false diff --git a/feature/buildfeatures/feature_bakedroots_enabled.go b/feature/buildfeatures/feature_bakedroots_enabled.go new file mode 100644 index 000000000..69cf2c34c --- /dev/null +++ b/feature/buildfeatures/feature_bakedroots_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_bakedroots + +package buildfeatures + +// HasBakedRoots is whether the binary was built with support for modular feature "Embed CA (LetsEncrypt) x509 roots to use as fallback". +// Specifically, it's whether the binary was NOT built with the "ts_omit_bakedroots" build tag. +// It's a const so it can be used for dead code elimination. +const HasBakedRoots = true diff --git a/feature/buildfeatures/feature_bird_disabled.go b/feature/buildfeatures/feature_bird_disabled.go new file mode 100644 index 000000000..469aa41f9 --- /dev/null +++ b/feature/buildfeatures/feature_bird_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_bird + +package buildfeatures + +// HasBird is whether the binary was built with support for modular feature "Bird BGP integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_bird" build tag. +// It's a const so it can be used for dead code elimination. +const HasBird = false diff --git a/feature/buildfeatures/feature_bird_enabled.go b/feature/buildfeatures/feature_bird_enabled.go new file mode 100644 index 000000000..792129f64 --- /dev/null +++ b/feature/buildfeatures/feature_bird_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_bird + +package buildfeatures + +// HasBird is whether the binary was built with support for modular feature "Bird BGP integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_bird" build tag. +// It's a const so it can be used for dead code elimination. +const HasBird = true diff --git a/feature/buildfeatures/feature_c2n_disabled.go b/feature/buildfeatures/feature_c2n_disabled.go new file mode 100644 index 000000000..bc37e9e7b --- /dev/null +++ b/feature/buildfeatures/feature_c2n_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_c2n + +package buildfeatures + +// HasC2N is whether the binary was built with support for modular feature "Control-to-node (C2N) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_c2n" build tag. +// It's a const so it can be used for dead code elimination. +const HasC2N = false diff --git a/feature/buildfeatures/feature_c2n_enabled.go b/feature/buildfeatures/feature_c2n_enabled.go new file mode 100644 index 000000000..5950e7157 --- /dev/null +++ b/feature/buildfeatures/feature_c2n_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_c2n + +package buildfeatures + +// HasC2N is whether the binary was built with support for modular feature "Control-to-node (C2N) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_c2n" build tag. +// It's a const so it can be used for dead code elimination. +const HasC2N = true diff --git a/feature/buildfeatures/feature_cachenetmap_disabled.go b/feature/buildfeatures/feature_cachenetmap_disabled.go new file mode 100644 index 000000000..22407fe38 --- /dev/null +++ b/feature/buildfeatures/feature_cachenetmap_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_cachenetmap + +package buildfeatures + +// HasCacheNetMap is whether the binary was built with support for modular feature "Cache the netmap on disk between runs". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cachenetmap" build tag. +// It's a const so it can be used for dead code elimination. +const HasCacheNetMap = false diff --git a/feature/buildfeatures/feature_cachenetmap_enabled.go b/feature/buildfeatures/feature_cachenetmap_enabled.go new file mode 100644 index 000000000..02663c416 --- /dev/null +++ b/feature/buildfeatures/feature_cachenetmap_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_cachenetmap + +package buildfeatures + +// HasCacheNetMap is whether the binary was built with support for modular feature "Cache the netmap on disk between runs". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cachenetmap" build tag. +// It's a const so it can be used for dead code elimination. +const HasCacheNetMap = true diff --git a/feature/buildfeatures/feature_captiveportal_disabled.go b/feature/buildfeatures/feature_captiveportal_disabled.go new file mode 100644 index 000000000..367fef81b --- /dev/null +++ b/feature/buildfeatures/feature_captiveportal_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_captiveportal + +package buildfeatures + +// HasCaptivePortal is whether the binary was built with support for modular feature "Captive portal detection". +// Specifically, it's whether the binary was NOT built with the "ts_omit_captiveportal" build tag. +// It's a const so it can be used for dead code elimination. +const HasCaptivePortal = false diff --git a/feature/buildfeatures/feature_captiveportal_enabled.go b/feature/buildfeatures/feature_captiveportal_enabled.go new file mode 100644 index 000000000..bd8e1f6a8 --- /dev/null +++ b/feature/buildfeatures/feature_captiveportal_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_captiveportal + +package buildfeatures + +// HasCaptivePortal is whether the binary was built with support for modular feature "Captive portal detection". +// Specifically, it's whether the binary was NOT built with the "ts_omit_captiveportal" build tag. +// It's a const so it can be used for dead code elimination. +const HasCaptivePortal = true diff --git a/feature/buildfeatures/feature_capture_disabled.go b/feature/buildfeatures/feature_capture_disabled.go new file mode 100644 index 000000000..58535958f --- /dev/null +++ b/feature/buildfeatures/feature_capture_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_capture + +package buildfeatures + +// HasCapture is whether the binary was built with support for modular feature "Packet capture". +// Specifically, it's whether the binary was NOT built with the "ts_omit_capture" build tag. +// It's a const so it can be used for dead code elimination. +const HasCapture = false diff --git a/feature/buildfeatures/feature_capture_enabled.go b/feature/buildfeatures/feature_capture_enabled.go new file mode 100644 index 000000000..7120a3d06 --- /dev/null +++ b/feature/buildfeatures/feature_capture_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_capture + +package buildfeatures + +// HasCapture is whether the binary was built with support for modular feature "Packet capture". +// Specifically, it's whether the binary was NOT built with the "ts_omit_capture" build tag. +// It's a const so it can be used for dead code elimination. +const HasCapture = true diff --git a/feature/buildfeatures/feature_cliconndiag_disabled.go b/feature/buildfeatures/feature_cliconndiag_disabled.go new file mode 100644 index 000000000..06d8c7935 --- /dev/null +++ b/feature/buildfeatures/feature_cliconndiag_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_cliconndiag + +package buildfeatures + +// HasCLIConnDiag is whether the binary was built with support for modular feature "CLI connection error diagnostics". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cliconndiag" build tag. +// It's a const so it can be used for dead code elimination. +const HasCLIConnDiag = false diff --git a/feature/buildfeatures/feature_cliconndiag_enabled.go b/feature/buildfeatures/feature_cliconndiag_enabled.go new file mode 100644 index 000000000..d6125ef08 --- /dev/null +++ b/feature/buildfeatures/feature_cliconndiag_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_cliconndiag + +package buildfeatures + +// HasCLIConnDiag is whether the binary was built with support for modular feature "CLI connection error diagnostics". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cliconndiag" build tag. +// It's a const so it can be used for dead code elimination. +const HasCLIConnDiag = true diff --git a/feature/buildfeatures/feature_clientmetrics_disabled.go b/feature/buildfeatures/feature_clientmetrics_disabled.go new file mode 100644 index 000000000..721908bb0 --- /dev/null +++ b/feature/buildfeatures/feature_clientmetrics_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_clientmetrics + +package buildfeatures + +// HasClientMetrics is whether the binary was built with support for modular feature "Client metrics support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_clientmetrics" build tag. +// It's a const so it can be used for dead code elimination. +const HasClientMetrics = false diff --git a/feature/buildfeatures/feature_clientmetrics_enabled.go b/feature/buildfeatures/feature_clientmetrics_enabled.go new file mode 100644 index 000000000..deaeb6e69 --- /dev/null +++ b/feature/buildfeatures/feature_clientmetrics_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_clientmetrics + +package buildfeatures + +// HasClientMetrics is whether the binary was built with support for modular feature "Client metrics support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_clientmetrics" build tag. +// It's a const so it can be used for dead code elimination. +const HasClientMetrics = true diff --git a/feature/buildfeatures/feature_clientupdate_disabled.go b/feature/buildfeatures/feature_clientupdate_disabled.go new file mode 100644 index 000000000..165c9cc9a --- /dev/null +++ b/feature/buildfeatures/feature_clientupdate_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_clientupdate + +package buildfeatures + +// HasClientUpdate is whether the binary was built with support for modular feature "Client auto-update support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_clientupdate" build tag. +// It's a const so it can be used for dead code elimination. +const HasClientUpdate = false diff --git a/feature/buildfeatures/feature_clientupdate_enabled.go b/feature/buildfeatures/feature_clientupdate_enabled.go new file mode 100644 index 000000000..3c3c7878c --- /dev/null +++ b/feature/buildfeatures/feature_clientupdate_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_clientupdate + +package buildfeatures + +// HasClientUpdate is whether the binary was built with support for modular feature "Client auto-update support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_clientupdate" build tag. +// It's a const so it can be used for dead code elimination. +const HasClientUpdate = true diff --git a/feature/buildfeatures/feature_cloud_disabled.go b/feature/buildfeatures/feature_cloud_disabled.go new file mode 100644 index 000000000..3b877a9c6 --- /dev/null +++ b/feature/buildfeatures/feature_cloud_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_cloud + +package buildfeatures + +// HasCloud is whether the binary was built with support for modular feature "detect cloud environment to learn instances IPs and DNS servers". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cloud" build tag. +// It's a const so it can be used for dead code elimination. +const HasCloud = false diff --git a/feature/buildfeatures/feature_cloud_enabled.go b/feature/buildfeatures/feature_cloud_enabled.go new file mode 100644 index 000000000..8fd748de5 --- /dev/null +++ b/feature/buildfeatures/feature_cloud_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_cloud + +package buildfeatures + +// HasCloud is whether the binary was built with support for modular feature "detect cloud environment to learn instances IPs and DNS servers". +// Specifically, it's whether the binary was NOT built with the "ts_omit_cloud" build tag. +// It's a const so it can be used for dead code elimination. +const HasCloud = true diff --git a/feature/buildfeatures/feature_completion_disabled.go b/feature/buildfeatures/feature_completion_disabled.go new file mode 100644 index 000000000..ea319beb0 --- /dev/null +++ b/feature/buildfeatures/feature_completion_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_completion + +package buildfeatures + +// HasCompletion is whether the binary was built with support for modular feature "CLI shell completion". +// Specifically, it's whether the binary was NOT built with the "ts_omit_completion" build tag. +// It's a const so it can be used for dead code elimination. +const HasCompletion = false diff --git a/feature/buildfeatures/feature_completion_enabled.go b/feature/buildfeatures/feature_completion_enabled.go new file mode 100644 index 000000000..6db41c97b --- /dev/null +++ b/feature/buildfeatures/feature_completion_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_completion + +package buildfeatures + +// HasCompletion is whether the binary was built with support for modular feature "CLI shell completion". +// Specifically, it's whether the binary was NOT built with the "ts_omit_completion" build tag. +// It's a const so it can be used for dead code elimination. +const HasCompletion = true diff --git a/feature/buildfeatures/feature_dbus_disabled.go b/feature/buildfeatures/feature_dbus_disabled.go new file mode 100644 index 000000000..e6ab89677 --- /dev/null +++ b/feature/buildfeatures/feature_dbus_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_dbus + +package buildfeatures + +// HasDBus is whether the binary was built with support for modular feature "Linux DBus support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_dbus" build tag. +// It's a const so it can be used for dead code elimination. +const HasDBus = false diff --git a/feature/buildfeatures/feature_dbus_enabled.go b/feature/buildfeatures/feature_dbus_enabled.go new file mode 100644 index 000000000..374331cda --- /dev/null +++ b/feature/buildfeatures/feature_dbus_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_dbus + +package buildfeatures + +// HasDBus is whether the binary was built with support for modular feature "Linux DBus support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_dbus" build tag. +// It's a const so it can be used for dead code elimination. +const HasDBus = true diff --git a/feature/buildfeatures/feature_debug_disabled.go b/feature/buildfeatures/feature_debug_disabled.go new file mode 100644 index 000000000..eb048c082 --- /dev/null +++ b/feature/buildfeatures/feature_debug_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_debug + +package buildfeatures + +// HasDebug is whether the binary was built with support for modular feature "various debug support, for things that don't have or need their own more specific feature". +// Specifically, it's whether the binary was NOT built with the "ts_omit_debug" build tag. +// It's a const so it can be used for dead code elimination. +const HasDebug = false diff --git a/feature/buildfeatures/feature_debug_enabled.go b/feature/buildfeatures/feature_debug_enabled.go new file mode 100644 index 000000000..12a2700a4 --- /dev/null +++ b/feature/buildfeatures/feature_debug_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_debug + +package buildfeatures + +// HasDebug is whether the binary was built with support for modular feature "various debug support, for things that don't have or need their own more specific feature". +// Specifically, it's whether the binary was NOT built with the "ts_omit_debug" build tag. +// It's a const so it can be used for dead code elimination. +const HasDebug = true diff --git a/feature/buildfeatures/feature_debugeventbus_disabled.go b/feature/buildfeatures/feature_debugeventbus_disabled.go new file mode 100644 index 000000000..2eb599934 --- /dev/null +++ b/feature/buildfeatures/feature_debugeventbus_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_debugeventbus + +package buildfeatures + +// HasDebugEventBus is whether the binary was built with support for modular feature "eventbus debug support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_debugeventbus" build tag. +// It's a const so it can be used for dead code elimination. +const HasDebugEventBus = false diff --git a/feature/buildfeatures/feature_debugeventbus_enabled.go b/feature/buildfeatures/feature_debugeventbus_enabled.go new file mode 100644 index 000000000..df13b6fa2 --- /dev/null +++ b/feature/buildfeatures/feature_debugeventbus_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_debugeventbus + +package buildfeatures + +// HasDebugEventBus is whether the binary was built with support for modular feature "eventbus debug support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_debugeventbus" build tag. +// It's a const so it can be used for dead code elimination. +const HasDebugEventBus = true diff --git a/feature/buildfeatures/feature_debugportmapper_disabled.go b/feature/buildfeatures/feature_debugportmapper_disabled.go new file mode 100644 index 000000000..eff85b8ba --- /dev/null +++ b/feature/buildfeatures/feature_debugportmapper_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_debugportmapper + +package buildfeatures + +// HasDebugPortMapper is whether the binary was built with support for modular feature "portmapper debug support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_debugportmapper" build tag. +// It's a const so it can be used for dead code elimination. +const HasDebugPortMapper = false diff --git a/feature/buildfeatures/feature_debugportmapper_enabled.go b/feature/buildfeatures/feature_debugportmapper_enabled.go new file mode 100644 index 000000000..491aa5ed8 --- /dev/null +++ b/feature/buildfeatures/feature_debugportmapper_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_debugportmapper + +package buildfeatures + +// HasDebugPortMapper is whether the binary was built with support for modular feature "portmapper debug support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_debugportmapper" build tag. +// It's a const so it can be used for dead code elimination. +const HasDebugPortMapper = true diff --git a/feature/buildfeatures/feature_desktop_sessions_disabled.go b/feature/buildfeatures/feature_desktop_sessions_disabled.go new file mode 100644 index 000000000..1536c886f --- /dev/null +++ b/feature/buildfeatures/feature_desktop_sessions_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_desktop_sessions + +package buildfeatures + +// HasDesktopSessions is whether the binary was built with support for modular feature "Desktop sessions support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_desktop_sessions" build tag. +// It's a const so it can be used for dead code elimination. +const HasDesktopSessions = false diff --git a/feature/buildfeatures/feature_desktop_sessions_enabled.go b/feature/buildfeatures/feature_desktop_sessions_enabled.go new file mode 100644 index 000000000..84658de95 --- /dev/null +++ b/feature/buildfeatures/feature_desktop_sessions_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_desktop_sessions + +package buildfeatures + +// HasDesktopSessions is whether the binary was built with support for modular feature "Desktop sessions support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_desktop_sessions" build tag. +// It's a const so it can be used for dead code elimination. +const HasDesktopSessions = true diff --git a/feature/buildfeatures/feature_dns_disabled.go b/feature/buildfeatures/feature_dns_disabled.go new file mode 100644 index 000000000..30d7379cb --- /dev/null +++ b/feature/buildfeatures/feature_dns_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_dns + +package buildfeatures + +// HasDNS is whether the binary was built with support for modular feature "MagicDNS and system DNS configuration support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_dns" build tag. +// It's a const so it can be used for dead code elimination. +const HasDNS = false diff --git a/feature/buildfeatures/feature_dns_enabled.go b/feature/buildfeatures/feature_dns_enabled.go new file mode 100644 index 000000000..962f2596b --- /dev/null +++ b/feature/buildfeatures/feature_dns_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_dns + +package buildfeatures + +// HasDNS is whether the binary was built with support for modular feature "MagicDNS and system DNS configuration support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_dns" build tag. +// It's a const so it can be used for dead code elimination. +const HasDNS = true diff --git a/feature/buildfeatures/feature_doctor_disabled.go b/feature/buildfeatures/feature_doctor_disabled.go new file mode 100644 index 000000000..8c15e951e --- /dev/null +++ b/feature/buildfeatures/feature_doctor_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_doctor + +package buildfeatures + +// HasDoctor is whether the binary was built with support for modular feature "Diagnose possible issues with Tailscale and its host environment". +// Specifically, it's whether the binary was NOT built with the "ts_omit_doctor" build tag. +// It's a const so it can be used for dead code elimination. +const HasDoctor = false diff --git a/feature/buildfeatures/feature_doctor_enabled.go b/feature/buildfeatures/feature_doctor_enabled.go new file mode 100644 index 000000000..a8a0bb7d2 --- /dev/null +++ b/feature/buildfeatures/feature_doctor_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_doctor + +package buildfeatures + +// HasDoctor is whether the binary was built with support for modular feature "Diagnose possible issues with Tailscale and its host environment". +// Specifically, it's whether the binary was NOT built with the "ts_omit_doctor" build tag. +// It's a const so it can be used for dead code elimination. +const HasDoctor = true diff --git a/feature/buildfeatures/feature_drive_disabled.go b/feature/buildfeatures/feature_drive_disabled.go new file mode 100644 index 000000000..072026389 --- /dev/null +++ b/feature/buildfeatures/feature_drive_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_drive + +package buildfeatures + +// HasDrive is whether the binary was built with support for modular feature "Tailscale Drive (file server) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_drive" build tag. +// It's a const so it can be used for dead code elimination. +const HasDrive = false diff --git a/feature/buildfeatures/feature_drive_enabled.go b/feature/buildfeatures/feature_drive_enabled.go new file mode 100644 index 000000000..9f58836a4 --- /dev/null +++ b/feature/buildfeatures/feature_drive_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_drive + +package buildfeatures + +// HasDrive is whether the binary was built with support for modular feature "Tailscale Drive (file server) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_drive" build tag. +// It's a const so it can be used for dead code elimination. +const HasDrive = true diff --git a/feature/buildfeatures/feature_gro_disabled.go b/feature/buildfeatures/feature_gro_disabled.go new file mode 100644 index 000000000..ffbd0da2e --- /dev/null +++ b/feature/buildfeatures/feature_gro_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_gro + +package buildfeatures + +// HasGRO is whether the binary was built with support for modular feature "Generic Receive Offload support (performance)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_gro" build tag. +// It's a const so it can be used for dead code elimination. +const HasGRO = false diff --git a/feature/buildfeatures/feature_gro_enabled.go b/feature/buildfeatures/feature_gro_enabled.go new file mode 100644 index 000000000..e2c8024e0 --- /dev/null +++ b/feature/buildfeatures/feature_gro_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_gro + +package buildfeatures + +// HasGRO is whether the binary was built with support for modular feature "Generic Receive Offload support (performance)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_gro" build tag. +// It's a const so it can be used for dead code elimination. +const HasGRO = true diff --git a/feature/buildfeatures/feature_health_disabled.go b/feature/buildfeatures/feature_health_disabled.go new file mode 100644 index 000000000..2f2bcf240 --- /dev/null +++ b/feature/buildfeatures/feature_health_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_health + +package buildfeatures + +// HasHealth is whether the binary was built with support for modular feature "Health checking support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_health" build tag. +// It's a const so it can be used for dead code elimination. +const HasHealth = false diff --git a/feature/buildfeatures/feature_health_enabled.go b/feature/buildfeatures/feature_health_enabled.go new file mode 100644 index 000000000..00ce3684e --- /dev/null +++ b/feature/buildfeatures/feature_health_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_health + +package buildfeatures + +// HasHealth is whether the binary was built with support for modular feature "Health checking support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_health" build tag. +// It's a const so it can be used for dead code elimination. +const HasHealth = true diff --git a/feature/buildfeatures/feature_hujsonconf_disabled.go b/feature/buildfeatures/feature_hujsonconf_disabled.go new file mode 100644 index 000000000..cee076bc2 --- /dev/null +++ b/feature/buildfeatures/feature_hujsonconf_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_hujsonconf + +package buildfeatures + +// HasHuJSONConf is whether the binary was built with support for modular feature "HuJSON config file support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_hujsonconf" build tag. +// It's a const so it can be used for dead code elimination. +const HasHuJSONConf = false diff --git a/feature/buildfeatures/feature_hujsonconf_enabled.go b/feature/buildfeatures/feature_hujsonconf_enabled.go new file mode 100644 index 000000000..aefeeace5 --- /dev/null +++ b/feature/buildfeatures/feature_hujsonconf_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_hujsonconf + +package buildfeatures + +// HasHuJSONConf is whether the binary was built with support for modular feature "HuJSON config file support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_hujsonconf" build tag. +// It's a const so it can be used for dead code elimination. +const HasHuJSONConf = true diff --git a/feature/buildfeatures/feature_identityfederation_disabled.go b/feature/buildfeatures/feature_identityfederation_disabled.go new file mode 100644 index 000000000..94488adc8 --- /dev/null +++ b/feature/buildfeatures/feature_identityfederation_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_identityfederation + +package buildfeatures + +// HasIdentityFederation is whether the binary was built with support for modular feature "Auth key generation via identity federation support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_identityfederation" build tag. +// It's a const so it can be used for dead code elimination. +const HasIdentityFederation = false diff --git a/feature/buildfeatures/feature_identityfederation_enabled.go b/feature/buildfeatures/feature_identityfederation_enabled.go new file mode 100644 index 000000000..892d62d66 --- /dev/null +++ b/feature/buildfeatures/feature_identityfederation_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_identityfederation + +package buildfeatures + +// HasIdentityFederation is whether the binary was built with support for modular feature "Auth key generation via identity federation support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_identityfederation" build tag. +// It's a const so it can be used for dead code elimination. +const HasIdentityFederation = true diff --git a/feature/buildfeatures/feature_iptables_disabled.go b/feature/buildfeatures/feature_iptables_disabled.go new file mode 100644 index 000000000..8cda5be5d --- /dev/null +++ b/feature/buildfeatures/feature_iptables_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_iptables + +package buildfeatures + +// HasIPTables is whether the binary was built with support for modular feature "Linux iptables support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_iptables" build tag. +// It's a const so it can be used for dead code elimination. +const HasIPTables = false diff --git a/feature/buildfeatures/feature_iptables_enabled.go b/feature/buildfeatures/feature_iptables_enabled.go new file mode 100644 index 000000000..44d98473f --- /dev/null +++ b/feature/buildfeatures/feature_iptables_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_iptables + +package buildfeatures + +// HasIPTables is whether the binary was built with support for modular feature "Linux iptables support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_iptables" build tag. +// It's a const so it can be used for dead code elimination. +const HasIPTables = true diff --git a/feature/buildfeatures/feature_kube_disabled.go b/feature/buildfeatures/feature_kube_disabled.go new file mode 100644 index 000000000..2b76c57e7 --- /dev/null +++ b/feature/buildfeatures/feature_kube_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_kube + +package buildfeatures + +// HasKube is whether the binary was built with support for modular feature "Kubernetes integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_kube" build tag. +// It's a const so it can be used for dead code elimination. +const HasKube = false diff --git a/feature/buildfeatures/feature_kube_enabled.go b/feature/buildfeatures/feature_kube_enabled.go new file mode 100644 index 000000000..7abca1759 --- /dev/null +++ b/feature/buildfeatures/feature_kube_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_kube + +package buildfeatures + +// HasKube is whether the binary was built with support for modular feature "Kubernetes integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_kube" build tag. +// It's a const so it can be used for dead code elimination. +const HasKube = true diff --git a/feature/buildfeatures/feature_lazywg_disabled.go b/feature/buildfeatures/feature_lazywg_disabled.go new file mode 100644 index 000000000..ce81d80ba --- /dev/null +++ b/feature/buildfeatures/feature_lazywg_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_lazywg + +package buildfeatures + +// HasLazyWG is whether the binary was built with support for modular feature "Lazy WireGuard configuration for memory-constrained devices with large netmaps". +// Specifically, it's whether the binary was NOT built with the "ts_omit_lazywg" build tag. +// It's a const so it can be used for dead code elimination. +const HasLazyWG = false diff --git a/feature/buildfeatures/feature_lazywg_enabled.go b/feature/buildfeatures/feature_lazywg_enabled.go new file mode 100644 index 000000000..259357f7f --- /dev/null +++ b/feature/buildfeatures/feature_lazywg_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_lazywg + +package buildfeatures + +// HasLazyWG is whether the binary was built with support for modular feature "Lazy WireGuard configuration for memory-constrained devices with large netmaps". +// Specifically, it's whether the binary was NOT built with the "ts_omit_lazywg" build tag. +// It's a const so it can be used for dead code elimination. +const HasLazyWG = true diff --git a/feature/buildfeatures/feature_linkspeed_disabled.go b/feature/buildfeatures/feature_linkspeed_disabled.go new file mode 100644 index 000000000..19e254a74 --- /dev/null +++ b/feature/buildfeatures/feature_linkspeed_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_linkspeed + +package buildfeatures + +// HasLinkSpeed is whether the binary was built with support for modular feature "Set link speed on TUN device for better OS integration (Linux only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_linkspeed" build tag. +// It's a const so it can be used for dead code elimination. +const HasLinkSpeed = false diff --git a/feature/buildfeatures/feature_linkspeed_enabled.go b/feature/buildfeatures/feature_linkspeed_enabled.go new file mode 100644 index 000000000..939858a16 --- /dev/null +++ b/feature/buildfeatures/feature_linkspeed_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_linkspeed + +package buildfeatures + +// HasLinkSpeed is whether the binary was built with support for modular feature "Set link speed on TUN device for better OS integration (Linux only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_linkspeed" build tag. +// It's a const so it can be used for dead code elimination. +const HasLinkSpeed = true diff --git a/feature/buildfeatures/feature_linuxdnsfight_disabled.go b/feature/buildfeatures/feature_linuxdnsfight_disabled.go new file mode 100644 index 000000000..2e5b50ea0 --- /dev/null +++ b/feature/buildfeatures/feature_linuxdnsfight_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_linuxdnsfight + +package buildfeatures + +// HasLinuxDNSFight is whether the binary was built with support for modular feature "Linux support for detecting DNS fights (inotify watching of /etc/resolv.conf)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_linuxdnsfight" build tag. +// It's a const so it can be used for dead code elimination. +const HasLinuxDNSFight = false diff --git a/feature/buildfeatures/feature_linuxdnsfight_enabled.go b/feature/buildfeatures/feature_linuxdnsfight_enabled.go new file mode 100644 index 000000000..b9419fccb --- /dev/null +++ b/feature/buildfeatures/feature_linuxdnsfight_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_linuxdnsfight + +package buildfeatures + +// HasLinuxDNSFight is whether the binary was built with support for modular feature "Linux support for detecting DNS fights (inotify watching of /etc/resolv.conf)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_linuxdnsfight" build tag. +// It's a const so it can be used for dead code elimination. +const HasLinuxDNSFight = true diff --git a/feature/buildfeatures/feature_listenrawdisco_disabled.go b/feature/buildfeatures/feature_listenrawdisco_disabled.go new file mode 100644 index 000000000..291178063 --- /dev/null +++ b/feature/buildfeatures/feature_listenrawdisco_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_listenrawdisco + +package buildfeatures + +// HasListenRawDisco is whether the binary was built with support for modular feature "Use raw sockets for more robust disco (NAT traversal) message receiving (Linux only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_listenrawdisco" build tag. +// It's a const so it can be used for dead code elimination. +const HasListenRawDisco = false diff --git a/feature/buildfeatures/feature_listenrawdisco_enabled.go b/feature/buildfeatures/feature_listenrawdisco_enabled.go new file mode 100644 index 000000000..4a4f85ae3 --- /dev/null +++ b/feature/buildfeatures/feature_listenrawdisco_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_listenrawdisco + +package buildfeatures + +// HasListenRawDisco is whether the binary was built with support for modular feature "Use raw sockets for more robust disco (NAT traversal) message receiving (Linux only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_listenrawdisco" build tag. +// It's a const so it can be used for dead code elimination. +const HasListenRawDisco = true diff --git a/feature/buildfeatures/feature_logtail_disabled.go b/feature/buildfeatures/feature_logtail_disabled.go new file mode 100644 index 000000000..140092a2e --- /dev/null +++ b/feature/buildfeatures/feature_logtail_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_logtail + +package buildfeatures + +// HasLogTail is whether the binary was built with support for modular feature "upload logs to log.tailscale.com (debug logs for bug reports and also by network flow logs if enabled)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_logtail" build tag. +// It's a const so it can be used for dead code elimination. +const HasLogTail = false diff --git a/feature/buildfeatures/feature_logtail_enabled.go b/feature/buildfeatures/feature_logtail_enabled.go new file mode 100644 index 000000000..6e777216b --- /dev/null +++ b/feature/buildfeatures/feature_logtail_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_logtail + +package buildfeatures + +// HasLogTail is whether the binary was built with support for modular feature "upload logs to log.tailscale.com (debug logs for bug reports and also by network flow logs if enabled)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_logtail" build tag. +// It's a const so it can be used for dead code elimination. +const HasLogTail = true diff --git a/feature/buildfeatures/feature_netlog_disabled.go b/feature/buildfeatures/feature_netlog_disabled.go new file mode 100644 index 000000000..60367a126 --- /dev/null +++ b/feature/buildfeatures/feature_netlog_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_netlog + +package buildfeatures + +// HasNetLog is whether the binary was built with support for modular feature "Network flow logging support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_netlog" build tag. +// It's a const so it can be used for dead code elimination. +const HasNetLog = false diff --git a/feature/buildfeatures/feature_netlog_enabled.go b/feature/buildfeatures/feature_netlog_enabled.go new file mode 100644 index 000000000..f9d2abad3 --- /dev/null +++ b/feature/buildfeatures/feature_netlog_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_netlog + +package buildfeatures + +// HasNetLog is whether the binary was built with support for modular feature "Network flow logging support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_netlog" build tag. +// It's a const so it can be used for dead code elimination. +const HasNetLog = true diff --git a/feature/buildfeatures/feature_netstack_disabled.go b/feature/buildfeatures/feature_netstack_disabled.go new file mode 100644 index 000000000..acb6e8e76 --- /dev/null +++ b/feature/buildfeatures/feature_netstack_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_netstack + +package buildfeatures + +// HasNetstack is whether the binary was built with support for modular feature "gVisor netstack (userspace networking) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_netstack" build tag. +// It's a const so it can be used for dead code elimination. +const HasNetstack = false diff --git a/feature/buildfeatures/feature_netstack_enabled.go b/feature/buildfeatures/feature_netstack_enabled.go new file mode 100644 index 000000000..04f671185 --- /dev/null +++ b/feature/buildfeatures/feature_netstack_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_netstack + +package buildfeatures + +// HasNetstack is whether the binary was built with support for modular feature "gVisor netstack (userspace networking) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_netstack" build tag. +// It's a const so it can be used for dead code elimination. +const HasNetstack = true diff --git a/feature/buildfeatures/feature_networkmanager_disabled.go b/feature/buildfeatures/feature_networkmanager_disabled.go new file mode 100644 index 000000000..d0ec6f017 --- /dev/null +++ b/feature/buildfeatures/feature_networkmanager_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_networkmanager + +package buildfeatures + +// HasNetworkManager is whether the binary was built with support for modular feature "Linux NetworkManager integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_networkmanager" build tag. +// It's a const so it can be used for dead code elimination. +const HasNetworkManager = false diff --git a/feature/buildfeatures/feature_networkmanager_enabled.go b/feature/buildfeatures/feature_networkmanager_enabled.go new file mode 100644 index 000000000..ec284c310 --- /dev/null +++ b/feature/buildfeatures/feature_networkmanager_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_networkmanager + +package buildfeatures + +// HasNetworkManager is whether the binary was built with support for modular feature "Linux NetworkManager integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_networkmanager" build tag. +// It's a const so it can be used for dead code elimination. +const HasNetworkManager = true diff --git a/feature/buildfeatures/feature_oauthkey_disabled.go b/feature/buildfeatures/feature_oauthkey_disabled.go new file mode 100644 index 000000000..72ad1723b --- /dev/null +++ b/feature/buildfeatures/feature_oauthkey_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_oauthkey + +package buildfeatures + +// HasOAuthKey is whether the binary was built with support for modular feature "OAuth secret-to-authkey resolution support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_oauthkey" build tag. +// It's a const so it can be used for dead code elimination. +const HasOAuthKey = false diff --git a/feature/buildfeatures/feature_oauthkey_enabled.go b/feature/buildfeatures/feature_oauthkey_enabled.go new file mode 100644 index 000000000..39c52a2b0 --- /dev/null +++ b/feature/buildfeatures/feature_oauthkey_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_oauthkey + +package buildfeatures + +// HasOAuthKey is whether the binary was built with support for modular feature "OAuth secret-to-authkey resolution support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_oauthkey" build tag. +// It's a const so it can be used for dead code elimination. +const HasOAuthKey = true diff --git a/feature/buildfeatures/feature_osrouter_disabled.go b/feature/buildfeatures/feature_osrouter_disabled.go new file mode 100644 index 000000000..ccd7192bb --- /dev/null +++ b/feature/buildfeatures/feature_osrouter_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_osrouter + +package buildfeatures + +// HasOSRouter is whether the binary was built with support for modular feature "Configure the operating system's network stack, IPs, and routing tables". +// Specifically, it's whether the binary was NOT built with the "ts_omit_osrouter" build tag. +// It's a const so it can be used for dead code elimination. +const HasOSRouter = false diff --git a/feature/buildfeatures/feature_osrouter_enabled.go b/feature/buildfeatures/feature_osrouter_enabled.go new file mode 100644 index 000000000..a5dacc596 --- /dev/null +++ b/feature/buildfeatures/feature_osrouter_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_osrouter + +package buildfeatures + +// HasOSRouter is whether the binary was built with support for modular feature "Configure the operating system's network stack, IPs, and routing tables". +// Specifically, it's whether the binary was NOT built with the "ts_omit_osrouter" build tag. +// It's a const so it can be used for dead code elimination. +const HasOSRouter = true diff --git a/feature/buildfeatures/feature_outboundproxy_disabled.go b/feature/buildfeatures/feature_outboundproxy_disabled.go new file mode 100644 index 000000000..bf74db060 --- /dev/null +++ b/feature/buildfeatures/feature_outboundproxy_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_outboundproxy + +package buildfeatures + +// HasOutboundProxy is whether the binary was built with support for modular feature "Support running an outbound localhost HTTP/SOCK5 proxy support that sends traffic over Tailscale". +// Specifically, it's whether the binary was NOT built with the "ts_omit_outboundproxy" build tag. +// It's a const so it can be used for dead code elimination. +const HasOutboundProxy = false diff --git a/feature/buildfeatures/feature_outboundproxy_enabled.go b/feature/buildfeatures/feature_outboundproxy_enabled.go new file mode 100644 index 000000000..53bb99d5c --- /dev/null +++ b/feature/buildfeatures/feature_outboundproxy_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_outboundproxy + +package buildfeatures + +// HasOutboundProxy is whether the binary was built with support for modular feature "Support running an outbound localhost HTTP/SOCK5 proxy support that sends traffic over Tailscale". +// Specifically, it's whether the binary was NOT built with the "ts_omit_outboundproxy" build tag. +// It's a const so it can be used for dead code elimination. +const HasOutboundProxy = true diff --git a/feature/buildfeatures/feature_peerapiclient_disabled.go b/feature/buildfeatures/feature_peerapiclient_disabled.go new file mode 100644 index 000000000..83cc2bdfe --- /dev/null +++ b/feature/buildfeatures/feature_peerapiclient_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_peerapiclient + +package buildfeatures + +// HasPeerAPIClient is whether the binary was built with support for modular feature "PeerAPI client support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_peerapiclient" build tag. +// It's a const so it can be used for dead code elimination. +const HasPeerAPIClient = false diff --git a/feature/buildfeatures/feature_peerapiclient_enabled.go b/feature/buildfeatures/feature_peerapiclient_enabled.go new file mode 100644 index 000000000..0bd3f50a8 --- /dev/null +++ b/feature/buildfeatures/feature_peerapiclient_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_peerapiclient + +package buildfeatures + +// HasPeerAPIClient is whether the binary was built with support for modular feature "PeerAPI client support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_peerapiclient" build tag. +// It's a const so it can be used for dead code elimination. +const HasPeerAPIClient = true diff --git a/feature/buildfeatures/feature_peerapiserver_disabled.go b/feature/buildfeatures/feature_peerapiserver_disabled.go new file mode 100644 index 000000000..4a4f32b8a --- /dev/null +++ b/feature/buildfeatures/feature_peerapiserver_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_peerapiserver + +package buildfeatures + +// HasPeerAPIServer is whether the binary was built with support for modular feature "PeerAPI server support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_peerapiserver" build tag. +// It's a const so it can be used for dead code elimination. +const HasPeerAPIServer = false diff --git a/feature/buildfeatures/feature_peerapiserver_enabled.go b/feature/buildfeatures/feature_peerapiserver_enabled.go new file mode 100644 index 000000000..17d0547b8 --- /dev/null +++ b/feature/buildfeatures/feature_peerapiserver_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_peerapiserver + +package buildfeatures + +// HasPeerAPIServer is whether the binary was built with support for modular feature "PeerAPI server support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_peerapiserver" build tag. +// It's a const so it can be used for dead code elimination. +const HasPeerAPIServer = true diff --git a/feature/buildfeatures/feature_portlist_disabled.go b/feature/buildfeatures/feature_portlist_disabled.go new file mode 100644 index 000000000..934061fd8 --- /dev/null +++ b/feature/buildfeatures/feature_portlist_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_portlist + +package buildfeatures + +// HasPortList is whether the binary was built with support for modular feature "Optionally advertise listening service ports". +// Specifically, it's whether the binary was NOT built with the "ts_omit_portlist" build tag. +// It's a const so it can be used for dead code elimination. +const HasPortList = false diff --git a/feature/buildfeatures/feature_portlist_enabled.go b/feature/buildfeatures/feature_portlist_enabled.go new file mode 100644 index 000000000..c1dc1c163 --- /dev/null +++ b/feature/buildfeatures/feature_portlist_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_portlist + +package buildfeatures + +// HasPortList is whether the binary was built with support for modular feature "Optionally advertise listening service ports". +// Specifically, it's whether the binary was NOT built with the "ts_omit_portlist" build tag. +// It's a const so it can be used for dead code elimination. +const HasPortList = true diff --git a/feature/buildfeatures/feature_portmapper_disabled.go b/feature/buildfeatures/feature_portmapper_disabled.go new file mode 100644 index 000000000..212b22d40 --- /dev/null +++ b/feature/buildfeatures/feature_portmapper_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_portmapper + +package buildfeatures + +// HasPortMapper is whether the binary was built with support for modular feature "NAT-PMP/PCP/UPnP port mapping support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_portmapper" build tag. +// It's a const so it can be used for dead code elimination. +const HasPortMapper = false diff --git a/feature/buildfeatures/feature_portmapper_enabled.go b/feature/buildfeatures/feature_portmapper_enabled.go new file mode 100644 index 000000000..2f915d277 --- /dev/null +++ b/feature/buildfeatures/feature_portmapper_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_portmapper + +package buildfeatures + +// HasPortMapper is whether the binary was built with support for modular feature "NAT-PMP/PCP/UPnP port mapping support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_portmapper" build tag. +// It's a const so it can be used for dead code elimination. +const HasPortMapper = true diff --git a/feature/buildfeatures/feature_posture_disabled.go b/feature/buildfeatures/feature_posture_disabled.go new file mode 100644 index 000000000..a78b1a957 --- /dev/null +++ b/feature/buildfeatures/feature_posture_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_posture + +package buildfeatures + +// HasPosture is whether the binary was built with support for modular feature "Device posture checking support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_posture" build tag. +// It's a const so it can be used for dead code elimination. +const HasPosture = false diff --git a/feature/buildfeatures/feature_posture_enabled.go b/feature/buildfeatures/feature_posture_enabled.go new file mode 100644 index 000000000..dcd9595f9 --- /dev/null +++ b/feature/buildfeatures/feature_posture_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_posture + +package buildfeatures + +// HasPosture is whether the binary was built with support for modular feature "Device posture checking support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_posture" build tag. +// It's a const so it can be used for dead code elimination. +const HasPosture = true diff --git a/feature/buildfeatures/feature_relayserver_disabled.go b/feature/buildfeatures/feature_relayserver_disabled.go new file mode 100644 index 000000000..08ced8310 --- /dev/null +++ b/feature/buildfeatures/feature_relayserver_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_relayserver + +package buildfeatures + +// HasRelayServer is whether the binary was built with support for modular feature "Relay server". +// Specifically, it's whether the binary was NOT built with the "ts_omit_relayserver" build tag. +// It's a const so it can be used for dead code elimination. +const HasRelayServer = false diff --git a/feature/buildfeatures/feature_relayserver_enabled.go b/feature/buildfeatures/feature_relayserver_enabled.go new file mode 100644 index 000000000..6a35f8305 --- /dev/null +++ b/feature/buildfeatures/feature_relayserver_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_relayserver + +package buildfeatures + +// HasRelayServer is whether the binary was built with support for modular feature "Relay server". +// Specifically, it's whether the binary was NOT built with the "ts_omit_relayserver" build tag. +// It's a const so it can be used for dead code elimination. +const HasRelayServer = true diff --git a/feature/buildfeatures/feature_resolved_disabled.go b/feature/buildfeatures/feature_resolved_disabled.go new file mode 100644 index 000000000..283dd20c7 --- /dev/null +++ b/feature/buildfeatures/feature_resolved_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_resolved + +package buildfeatures + +// HasResolved is whether the binary was built with support for modular feature "Linux systemd-resolved integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_resolved" build tag. +// It's a const so it can be used for dead code elimination. +const HasResolved = false diff --git a/feature/buildfeatures/feature_resolved_enabled.go b/feature/buildfeatures/feature_resolved_enabled.go new file mode 100644 index 000000000..af1b3b41e --- /dev/null +++ b/feature/buildfeatures/feature_resolved_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_resolved + +package buildfeatures + +// HasResolved is whether the binary was built with support for modular feature "Linux systemd-resolved integration". +// Specifically, it's whether the binary was NOT built with the "ts_omit_resolved" build tag. +// It's a const so it can be used for dead code elimination. +const HasResolved = true diff --git a/feature/buildfeatures/feature_sdnotify_disabled.go b/feature/buildfeatures/feature_sdnotify_disabled.go new file mode 100644 index 000000000..7efa2d22f --- /dev/null +++ b/feature/buildfeatures/feature_sdnotify_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_sdnotify + +package buildfeatures + +// HasSDNotify is whether the binary was built with support for modular feature "systemd notification support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_sdnotify" build tag. +// It's a const so it can be used for dead code elimination. +const HasSDNotify = false diff --git a/feature/buildfeatures/feature_sdnotify_enabled.go b/feature/buildfeatures/feature_sdnotify_enabled.go new file mode 100644 index 000000000..40fec9755 --- /dev/null +++ b/feature/buildfeatures/feature_sdnotify_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_sdnotify + +package buildfeatures + +// HasSDNotify is whether the binary was built with support for modular feature "systemd notification support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_sdnotify" build tag. +// It's a const so it can be used for dead code elimination. +const HasSDNotify = true diff --git a/feature/buildfeatures/feature_serve_disabled.go b/feature/buildfeatures/feature_serve_disabled.go new file mode 100644 index 000000000..6d7971350 --- /dev/null +++ b/feature/buildfeatures/feature_serve_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_serve + +package buildfeatures + +// HasServe is whether the binary was built with support for modular feature "Serve and Funnel support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_serve" build tag. +// It's a const so it can be used for dead code elimination. +const HasServe = false diff --git a/feature/buildfeatures/feature_serve_enabled.go b/feature/buildfeatures/feature_serve_enabled.go new file mode 100644 index 000000000..57bf2c6b0 --- /dev/null +++ b/feature/buildfeatures/feature_serve_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_serve + +package buildfeatures + +// HasServe is whether the binary was built with support for modular feature "Serve and Funnel support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_serve" build tag. +// It's a const so it can be used for dead code elimination. +const HasServe = true diff --git a/feature/buildfeatures/feature_ssh_disabled.go b/feature/buildfeatures/feature_ssh_disabled.go new file mode 100644 index 000000000..754f50eb6 --- /dev/null +++ b/feature/buildfeatures/feature_ssh_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_ssh + +package buildfeatures + +// HasSSH is whether the binary was built with support for modular feature "Tailscale SSH support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_ssh" build tag. +// It's a const so it can be used for dead code elimination. +const HasSSH = false diff --git a/feature/buildfeatures/feature_ssh_enabled.go b/feature/buildfeatures/feature_ssh_enabled.go new file mode 100644 index 000000000..dbdc3a89f --- /dev/null +++ b/feature/buildfeatures/feature_ssh_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_ssh + +package buildfeatures + +// HasSSH is whether the binary was built with support for modular feature "Tailscale SSH support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_ssh" build tag. +// It's a const so it can be used for dead code elimination. +const HasSSH = true diff --git a/feature/buildfeatures/feature_synology_disabled.go b/feature/buildfeatures/feature_synology_disabled.go new file mode 100644 index 000000000..0cdf084c3 --- /dev/null +++ b/feature/buildfeatures/feature_synology_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_synology + +package buildfeatures + +// HasSynology is whether the binary was built with support for modular feature "Synology NAS integration (applies to Linux builds only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_synology" build tag. +// It's a const so it can be used for dead code elimination. +const HasSynology = false diff --git a/feature/buildfeatures/feature_synology_enabled.go b/feature/buildfeatures/feature_synology_enabled.go new file mode 100644 index 000000000..dde4123b6 --- /dev/null +++ b/feature/buildfeatures/feature_synology_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_synology + +package buildfeatures + +// HasSynology is whether the binary was built with support for modular feature "Synology NAS integration (applies to Linux builds only)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_synology" build tag. +// It's a const so it can be used for dead code elimination. +const HasSynology = true diff --git a/feature/buildfeatures/feature_syspolicy_disabled.go b/feature/buildfeatures/feature_syspolicy_disabled.go new file mode 100644 index 000000000..54d32e32e --- /dev/null +++ b/feature/buildfeatures/feature_syspolicy_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_syspolicy + +package buildfeatures + +// HasSystemPolicy is whether the binary was built with support for modular feature "System policy configuration (MDM) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_syspolicy" build tag. +// It's a const so it can be used for dead code elimination. +const HasSystemPolicy = false diff --git a/feature/buildfeatures/feature_syspolicy_enabled.go b/feature/buildfeatures/feature_syspolicy_enabled.go new file mode 100644 index 000000000..f7c403ae9 --- /dev/null +++ b/feature/buildfeatures/feature_syspolicy_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_syspolicy + +package buildfeatures + +// HasSystemPolicy is whether the binary was built with support for modular feature "System policy configuration (MDM) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_syspolicy" build tag. +// It's a const so it can be used for dead code elimination. +const HasSystemPolicy = true diff --git a/feature/buildfeatures/feature_systray_disabled.go b/feature/buildfeatures/feature_systray_disabled.go new file mode 100644 index 000000000..4ae1edb0a --- /dev/null +++ b/feature/buildfeatures/feature_systray_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_systray + +package buildfeatures + +// HasSysTray is whether the binary was built with support for modular feature "Linux system tray". +// Specifically, it's whether the binary was NOT built with the "ts_omit_systray" build tag. +// It's a const so it can be used for dead code elimination. +const HasSysTray = false diff --git a/feature/buildfeatures/feature_systray_enabled.go b/feature/buildfeatures/feature_systray_enabled.go new file mode 100644 index 000000000..5fd7fd220 --- /dev/null +++ b/feature/buildfeatures/feature_systray_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_systray + +package buildfeatures + +// HasSysTray is whether the binary was built with support for modular feature "Linux system tray". +// Specifically, it's whether the binary was NOT built with the "ts_omit_systray" build tag. +// It's a const so it can be used for dead code elimination. +const HasSysTray = true diff --git a/feature/buildfeatures/feature_taildrop_disabled.go b/feature/buildfeatures/feature_taildrop_disabled.go new file mode 100644 index 000000000..8ffe90617 --- /dev/null +++ b/feature/buildfeatures/feature_taildrop_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_taildrop + +package buildfeatures + +// HasTaildrop is whether the binary was built with support for modular feature "Taildrop (file sending) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_taildrop" build tag. +// It's a const so it can be used for dead code elimination. +const HasTaildrop = false diff --git a/feature/buildfeatures/feature_taildrop_enabled.go b/feature/buildfeatures/feature_taildrop_enabled.go new file mode 100644 index 000000000..4f55d2801 --- /dev/null +++ b/feature/buildfeatures/feature_taildrop_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_taildrop + +package buildfeatures + +// HasTaildrop is whether the binary was built with support for modular feature "Taildrop (file sending) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_taildrop" build tag. +// It's a const so it can be used for dead code elimination. +const HasTaildrop = true diff --git a/feature/buildfeatures/feature_tailnetlock_disabled.go b/feature/buildfeatures/feature_tailnetlock_disabled.go new file mode 100644 index 000000000..6b5a57f24 --- /dev/null +++ b/feature/buildfeatures/feature_tailnetlock_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_tailnetlock + +package buildfeatures + +// HasTailnetLock is whether the binary was built with support for modular feature "Tailnet Lock support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tailnetlock" build tag. +// It's a const so it can be used for dead code elimination. +const HasTailnetLock = false diff --git a/feature/buildfeatures/feature_tailnetlock_enabled.go b/feature/buildfeatures/feature_tailnetlock_enabled.go new file mode 100644 index 000000000..afedb7faa --- /dev/null +++ b/feature/buildfeatures/feature_tailnetlock_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_tailnetlock + +package buildfeatures + +// HasTailnetLock is whether the binary was built with support for modular feature "Tailnet Lock support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tailnetlock" build tag. +// It's a const so it can be used for dead code elimination. +const HasTailnetLock = true diff --git a/feature/buildfeatures/feature_tap_disabled.go b/feature/buildfeatures/feature_tap_disabled.go new file mode 100644 index 000000000..f0b3eec8d --- /dev/null +++ b/feature/buildfeatures/feature_tap_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_tap + +package buildfeatures + +// HasTap is whether the binary was built with support for modular feature "Experimental Layer 2 (ethernet) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tap" build tag. +// It's a const so it can be used for dead code elimination. +const HasTap = false diff --git a/feature/buildfeatures/feature_tap_enabled.go b/feature/buildfeatures/feature_tap_enabled.go new file mode 100644 index 000000000..1363c4b44 --- /dev/null +++ b/feature/buildfeatures/feature_tap_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_tap + +package buildfeatures + +// HasTap is whether the binary was built with support for modular feature "Experimental Layer 2 (ethernet) support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tap" build tag. +// It's a const so it can be used for dead code elimination. +const HasTap = true diff --git a/feature/buildfeatures/feature_tpm_disabled.go b/feature/buildfeatures/feature_tpm_disabled.go new file mode 100644 index 000000000..b9d55815e --- /dev/null +++ b/feature/buildfeatures/feature_tpm_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_tpm + +package buildfeatures + +// HasTPM is whether the binary was built with support for modular feature "TPM support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tpm" build tag. +// It's a const so it can be used for dead code elimination. +const HasTPM = false diff --git a/feature/buildfeatures/feature_tpm_enabled.go b/feature/buildfeatures/feature_tpm_enabled.go new file mode 100644 index 000000000..dcfc8a304 --- /dev/null +++ b/feature/buildfeatures/feature_tpm_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_tpm + +package buildfeatures + +// HasTPM is whether the binary was built with support for modular feature "TPM support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_tpm" build tag. +// It's a const so it can be used for dead code elimination. +const HasTPM = true diff --git a/feature/buildfeatures/feature_unixsocketidentity_disabled.go b/feature/buildfeatures/feature_unixsocketidentity_disabled.go new file mode 100644 index 000000000..d64e48b82 --- /dev/null +++ b/feature/buildfeatures/feature_unixsocketidentity_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_unixsocketidentity + +package buildfeatures + +// HasUnixSocketIdentity is whether the binary was built with support for modular feature "differentiate between users accessing the LocalAPI over unix sockets (if omitted, all users have full access)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_unixsocketidentity" build tag. +// It's a const so it can be used for dead code elimination. +const HasUnixSocketIdentity = false diff --git a/feature/buildfeatures/feature_unixsocketidentity_enabled.go b/feature/buildfeatures/feature_unixsocketidentity_enabled.go new file mode 100644 index 000000000..463ac2ced --- /dev/null +++ b/feature/buildfeatures/feature_unixsocketidentity_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_unixsocketidentity + +package buildfeatures + +// HasUnixSocketIdentity is whether the binary was built with support for modular feature "differentiate between users accessing the LocalAPI over unix sockets (if omitted, all users have full access)". +// Specifically, it's whether the binary was NOT built with the "ts_omit_unixsocketidentity" build tag. +// It's a const so it can be used for dead code elimination. +const HasUnixSocketIdentity = true diff --git a/feature/buildfeatures/feature_useexitnode_disabled.go b/feature/buildfeatures/feature_useexitnode_disabled.go new file mode 100644 index 000000000..51bec8046 --- /dev/null +++ b/feature/buildfeatures/feature_useexitnode_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_useexitnode + +package buildfeatures + +// HasUseExitNode is whether the binary was built with support for modular feature "Use exit nodes". +// Specifically, it's whether the binary was NOT built with the "ts_omit_useexitnode" build tag. +// It's a const so it can be used for dead code elimination. +const HasUseExitNode = false diff --git a/feature/buildfeatures/feature_useexitnode_enabled.go b/feature/buildfeatures/feature_useexitnode_enabled.go new file mode 100644 index 000000000..f7ab414de --- /dev/null +++ b/feature/buildfeatures/feature_useexitnode_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_useexitnode + +package buildfeatures + +// HasUseExitNode is whether the binary was built with support for modular feature "Use exit nodes". +// Specifically, it's whether the binary was NOT built with the "ts_omit_useexitnode" build tag. +// It's a const so it can be used for dead code elimination. +const HasUseExitNode = true diff --git a/feature/buildfeatures/feature_useproxy_disabled.go b/feature/buildfeatures/feature_useproxy_disabled.go new file mode 100644 index 000000000..9f29a9820 --- /dev/null +++ b/feature/buildfeatures/feature_useproxy_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_useproxy + +package buildfeatures + +// HasUseProxy is whether the binary was built with support for modular feature "Support using system proxies as specified by env vars or the system configuration to reach Tailscale servers.". +// Specifically, it's whether the binary was NOT built with the "ts_omit_useproxy" build tag. +// It's a const so it can be used for dead code elimination. +const HasUseProxy = false diff --git a/feature/buildfeatures/feature_useproxy_enabled.go b/feature/buildfeatures/feature_useproxy_enabled.go new file mode 100644 index 000000000..9195f2fdc --- /dev/null +++ b/feature/buildfeatures/feature_useproxy_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_useproxy + +package buildfeatures + +// HasUseProxy is whether the binary was built with support for modular feature "Support using system proxies as specified by env vars or the system configuration to reach Tailscale servers.". +// Specifically, it's whether the binary was NOT built with the "ts_omit_useproxy" build tag. +// It's a const so it can be used for dead code elimination. +const HasUseProxy = true diff --git a/feature/buildfeatures/feature_usermetrics_disabled.go b/feature/buildfeatures/feature_usermetrics_disabled.go new file mode 100644 index 000000000..092c89c3b --- /dev/null +++ b/feature/buildfeatures/feature_usermetrics_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_usermetrics + +package buildfeatures + +// HasUserMetrics is whether the binary was built with support for modular feature "Usermetrics (documented, stable) metrics support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_usermetrics" build tag. +// It's a const so it can be used for dead code elimination. +const HasUserMetrics = false diff --git a/feature/buildfeatures/feature_usermetrics_enabled.go b/feature/buildfeatures/feature_usermetrics_enabled.go new file mode 100644 index 000000000..813e3c347 --- /dev/null +++ b/feature/buildfeatures/feature_usermetrics_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_usermetrics + +package buildfeatures + +// HasUserMetrics is whether the binary was built with support for modular feature "Usermetrics (documented, stable) metrics support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_usermetrics" build tag. +// It's a const so it can be used for dead code elimination. +const HasUserMetrics = true diff --git a/feature/buildfeatures/feature_useroutes_disabled.go b/feature/buildfeatures/feature_useroutes_disabled.go new file mode 100644 index 000000000..ecf9d022b --- /dev/null +++ b/feature/buildfeatures/feature_useroutes_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_useroutes + +package buildfeatures + +// HasUseRoutes is whether the binary was built with support for modular feature "Use routes advertised by other nodes". +// Specifically, it's whether the binary was NOT built with the "ts_omit_useroutes" build tag. +// It's a const so it can be used for dead code elimination. +const HasUseRoutes = false diff --git a/feature/buildfeatures/feature_useroutes_enabled.go b/feature/buildfeatures/feature_useroutes_enabled.go new file mode 100644 index 000000000..c0a59322e --- /dev/null +++ b/feature/buildfeatures/feature_useroutes_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_useroutes + +package buildfeatures + +// HasUseRoutes is whether the binary was built with support for modular feature "Use routes advertised by other nodes". +// Specifically, it's whether the binary was NOT built with the "ts_omit_useroutes" build tag. +// It's a const so it can be used for dead code elimination. +const HasUseRoutes = true diff --git a/feature/buildfeatures/feature_wakeonlan_disabled.go b/feature/buildfeatures/feature_wakeonlan_disabled.go new file mode 100644 index 000000000..816ac661f --- /dev/null +++ b/feature/buildfeatures/feature_wakeonlan_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_wakeonlan + +package buildfeatures + +// HasWakeOnLAN is whether the binary was built with support for modular feature "Wake-on-LAN support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_wakeonlan" build tag. +// It's a const so it can be used for dead code elimination. +const HasWakeOnLAN = false diff --git a/feature/buildfeatures/feature_wakeonlan_enabled.go b/feature/buildfeatures/feature_wakeonlan_enabled.go new file mode 100644 index 000000000..34b3348a1 --- /dev/null +++ b/feature/buildfeatures/feature_wakeonlan_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_wakeonlan + +package buildfeatures + +// HasWakeOnLAN is whether the binary was built with support for modular feature "Wake-on-LAN support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_wakeonlan" build tag. +// It's a const so it can be used for dead code elimination. +const HasWakeOnLAN = true diff --git a/feature/buildfeatures/feature_webclient_disabled.go b/feature/buildfeatures/feature_webclient_disabled.go new file mode 100644 index 000000000..a7b24f4ac --- /dev/null +++ b/feature/buildfeatures/feature_webclient_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_webclient + +package buildfeatures + +// HasWebClient is whether the binary was built with support for modular feature "Web client support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_webclient" build tag. +// It's a const so it can be used for dead code elimination. +const HasWebClient = false diff --git a/feature/buildfeatures/feature_webclient_enabled.go b/feature/buildfeatures/feature_webclient_enabled.go new file mode 100644 index 000000000..e40dad33c --- /dev/null +++ b/feature/buildfeatures/feature_webclient_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_webclient + +package buildfeatures + +// HasWebClient is whether the binary was built with support for modular feature "Web client support". +// Specifically, it's whether the binary was NOT built with the "ts_omit_webclient" build tag. +// It's a const so it can be used for dead code elimination. +const HasWebClient = true diff --git a/feature/buildfeatures/gen.go b/feature/buildfeatures/gen.go new file mode 100644 index 000000000..e967cb8ff --- /dev/null +++ b/feature/buildfeatures/gen.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// The gens.go program generates the feature__enabled.go +// and feature__disabled.go files for each feature tag. +package main + +import ( + "cmp" + "fmt" + "os" + "strings" + + "tailscale.com/feature/featuretags" + "tailscale.com/util/must" +) + +const header = `// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code g|e|n|e|r|a|t|e|d by gen.go; D|O N|OT E|D|I|T. + +` + +func main() { + header := strings.ReplaceAll(header, "|", "") // to avoid this file being marked as generated + for k, m := range featuretags.Features { + if !k.IsOmittable() { + continue + } + sym := "Has" + cmp.Or(m.Sym, strings.ToUpper(string(k)[:1])+string(k)[1:]) + for _, suf := range []string{"enabled", "disabled"} { + bang := "" + if suf == "enabled" { + bang = "!" // !ts_omit_... + } + must.Do(os.WriteFile("feature_"+string(k)+"_"+suf+".go", + fmt.Appendf(nil, "%s//go:build %s%s\n\npackage buildfeatures\n\n"+ + "// %s is whether the binary was built with support for modular feature %q.\n"+ + "// Specifically, it's whether the binary was NOT built with the %q build tag.\n"+ + "// It's a const so it can be used for dead code elimination.\n"+ + "const %s = %t\n", + header, bang, k.OmitTag(), sym, m.Desc, k.OmitTag(), sym, suf == "enabled"), 0644)) + + } + } +} diff --git a/feature/c2n/c2n.go b/feature/c2n/c2n.go new file mode 100644 index 000000000..ae942e31d --- /dev/null +++ b/feature/c2n/c2n.go @@ -0,0 +1,70 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package c2n registers support for C2N (Control-to-Node) communications. +package c2n + +import ( + "bufio" + "bytes" + "context" + "net/http" + "time" + + "tailscale.com/control/controlclient" + "tailscale.com/tailcfg" + "tailscale.com/tempfork/httprec" + "tailscale.com/types/logger" +) + +func init() { + controlclient.HookAnswerC2NPing.Set(answerC2NPing) +} + +func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr *tailcfg.PingRequest) { + if c2nHandler == nil { + logf("answerC2NPing: c2nHandler not defined") + return + } + hreq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload))) + if err != nil { + logf("answerC2NPing: ReadRequest: %v", err) + return + } + if pr.Log { + logf("answerC2NPing: got c2n request for %v ...", hreq.RequestURI) + } + handlerTimeout := time.Minute + if v := hreq.Header.Get("C2n-Handler-Timeout"); v != "" { + handlerTimeout, _ = time.ParseDuration(v) + } + handlerCtx, cancel := context.WithTimeout(context.Background(), handlerTimeout) + defer cancel() + hreq = hreq.WithContext(handlerCtx) + rec := httprec.NewRecorder() + c2nHandler.ServeHTTP(rec, hreq) + cancel() + + c2nResBuf := new(bytes.Buffer) + rec.Result().Write(c2nResBuf) + + replyCtx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + req, err := http.NewRequestWithContext(replyCtx, "POST", pr.URL, c2nResBuf) + if err != nil { + logf("answerC2NPing: NewRequestWithContext: %v", err) + return + } + if pr.Log { + logf("answerC2NPing: sending POST ping to %v ...", pr.URL) + } + t0 := time.Now() + _, err = c.Do(req) + d := time.Since(t0).Round(time.Millisecond) + if err != nil { + logf("answerC2NPing error: %v to %v (after %v)", err, pr.URL, d) + } else if pr.Log { + logf("answerC2NPing complete to %v (after %v)", pr.URL, d) + } +} diff --git a/wgengine/capture/capture.go b/feature/capture/capture.go similarity index 79% rename from wgengine/capture/capture.go rename to feature/capture/capture.go index 6ea5a9549..e5e150de8 100644 --- a/wgengine/capture/capture.go +++ b/feature/capture/capture.go @@ -13,21 +13,44 @@ import ( "sync" "time" - _ "embed" - + "tailscale.com/feature" + "tailscale.com/ipn/localapi" "tailscale.com/net/packet" "tailscale.com/util/set" ) -//go:embed ts-dissector.lua -var DissectorLua string +func init() { + feature.Register("capture") + localapi.Register("debug-capture", serveLocalAPIDebugCapture) +} + +func serveLocalAPIDebugCapture(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != "POST" { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + b := h.LocalBackend() + s := b.GetOrSetCaptureSink(newSink) -// Callback describes a function which is called to -// record packets when debugging packet-capture. -// Such callbacks must not take ownership of the -// provided data slice: it may only copy out of it -// within the lifetime of the function. -type Callback func(Path, time.Time, []byte, packet.CaptureMeta) + unregister := s.RegisterOutput(w) + + select { + case <-ctx.Done(): + case <-s.WaitCh(): + } + unregister() + + b.ClearCaptureSink() +} var bufferPool = sync.Pool{ New: func() any { @@ -57,29 +80,8 @@ func writePktHeader(w *bytes.Buffer, when time.Time, length int) { binary.Write(w, binary.LittleEndian, uint32(length)) // total length } -// Path describes where in the data path the packet was captured. -type Path uint8 - -// Valid Path values. -const ( - // FromLocal indicates the packet was logged as it traversed the FromLocal path: - // i.e.: A packet from the local system into the TUN. - FromLocal Path = 0 - // FromPeer indicates the packet was logged upon reception from a remote peer. - FromPeer Path = 1 - // SynthesizedToLocal indicates the packet was generated from within tailscaled, - // and is being routed to the local machine's network stack. - SynthesizedToLocal Path = 2 - // SynthesizedToPeer indicates the packet was generated from within tailscaled, - // and is being routed to a remote Wireguard peer. - SynthesizedToPeer Path = 3 - - // PathDisco indicates the packet is information about a disco frame. - PathDisco Path = 254 -) - -// New creates a new capture sink. -func New() *Sink { +// newSink creates a new capture sink. +func newSink() packet.CaptureSink { ctx, c := context.WithCancel(context.Background()) return &Sink{ ctx: ctx, @@ -126,6 +128,10 @@ func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { } } +func (s *Sink) CaptureCallback() packet.CaptureCallback { + return s.LogPacket +} + // NumOutputs returns the number of outputs registered with the sink. func (s *Sink) NumOutputs() int { s.mu.Lock() @@ -174,7 +180,7 @@ func customDataLen(meta packet.CaptureMeta) int { // LogPacket is called to insert a packet into the capture. // // This function does not take ownership of the provided data slice. -func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { +func (s *Sink) LogPacket(path packet.CapturePath, when time.Time, data []byte, meta packet.CaptureMeta) { select { case <-s.ctx.Done(): return diff --git a/feature/capture/dissector/dissector.go b/feature/capture/dissector/dissector.go new file mode 100644 index 000000000..ab2f6c2ec --- /dev/null +++ b/feature/capture/dissector/dissector.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dissector contains the Lua dissector for Tailscale packets. +package dissector + +import ( + _ "embed" +) + +//go:embed ts-dissector.lua +var Lua string diff --git a/wgengine/capture/ts-dissector.lua b/feature/capture/dissector/ts-dissector.lua similarity index 93% rename from wgengine/capture/ts-dissector.lua rename to feature/capture/dissector/ts-dissector.lua index ad553d767..c2ee2b755 100644 --- a/wgengine/capture/ts-dissector.lua +++ b/feature/capture/dissector/ts-dissector.lua @@ -1,5 +1,5 @@ function hasbit(x, p) - return x % (p + p) >= p + return bit.band(x, p) ~= 0 end tsdebug_ll = Proto("tsdebug", "Tailscale debug") @@ -128,6 +128,10 @@ function tsdisco_frame.dissector(buffer, pinfo, tree) if message_type == 1 then subtree:add(DISCO_TYPE, "Ping") elseif message_type == 2 then subtree:add(DISCO_TYPE, "Pong") elseif message_type == 3 then subtree:add(DISCO_TYPE, "Call me maybe") + elseif message_type == 4 then subtree:add(DISCO_TYPE, "Bind UDP Relay Endpoint") + elseif message_type == 5 then subtree:add(DISCO_TYPE, "Bind UDP Relay Endpoint Challenge") + elseif message_type == 6 then subtree:add(DISCO_TYPE, "Bind UDP Relay Endpoint Answer") + elseif message_type == 7 then subtree:add(DISCO_TYPE, "Call me maybe via") end -- Message version diff --git a/feature/clientupdate/clientupdate.go b/feature/clientupdate/clientupdate.go new file mode 100644 index 000000000..45fd21129 --- /dev/null +++ b/feature/clientupdate/clientupdate.go @@ -0,0 +1,530 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package clientupdate enables the client update feature. +package clientupdate + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "tailscale.com/clientupdate" + "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/localapi" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/httpm" + "tailscale.com/version" + "tailscale.com/version/distro" +) + +func init() { + ipnext.RegisterExtension("clientupdate", newExt) + + // C2N + ipnlocal.RegisterC2N("GET /update", handleC2NUpdateGet) + ipnlocal.RegisterC2N("POST /update", handleC2NUpdatePost) + + // LocalAPI: + localapi.Register("update/install", serveUpdateInstall) + localapi.Register("update/progress", serveUpdateProgress) +} + +func newExt(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { + return &extension{ + logf: logf, + sb: sb, + + lastSelfUpdateState: ipnstate.UpdateFinished, + }, nil +} + +type extension struct { + logf logger.Logf + sb ipnext.SafeBackend + + mu sync.Mutex + + // c2nUpdateStatus is the status of c2n-triggered client update. + c2nUpdateStatus updateStatus + prefs ipn.PrefsView + state ipn.State + + lastSelfUpdateState ipnstate.SelfUpdateStatus + selfUpdateProgress []ipnstate.UpdateProgress + + // offlineAutoUpdateCancel stops offline auto-updates when called. It + // should be used via stopOfflineAutoUpdate and + // maybeStartOfflineAutoUpdate. It is nil when offline auto-updates are + // not running. + // + //lint:ignore U1000 only used in Linux and Windows builds in autoupdate.go + offlineAutoUpdateCancel func() +} + +func (e *extension) Name() string { return "clientupdate" } + +func (e *extension) Init(h ipnext.Host) error { + + h.Hooks().ProfileStateChange.Add(e.onChangeProfile) + h.Hooks().BackendStateChange.Add(e.onBackendStateChange) + + // TODO(nickkhyl): remove this after the profileManager refactoring. + // See tailscale/tailscale#15974. + // This same workaround appears in feature/portlist/portlist.go. + profile, prefs := h.Profiles().CurrentProfileState() + e.onChangeProfile(profile, prefs, false) + + return nil +} + +func (e *extension) Shutdown() error { + e.stopOfflineAutoUpdate() + return nil +} + +func (e *extension) onBackendStateChange(newState ipn.State) { + e.mu.Lock() + defer e.mu.Unlock() + e.state = newState + e.updateOfflineAutoUpdateLocked() +} + +func (e *extension) onChangeProfile(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + e.prefs = prefs + e.updateOfflineAutoUpdateLocked() +} + +func (e *extension) updateOfflineAutoUpdateLocked() { + want := e.prefs.Valid() && e.prefs.AutoUpdate().Apply.EqualBool(true) && + e.state != ipn.Running && e.state != ipn.Starting + + cur := e.offlineAutoUpdateCancel != nil + + if want && !cur { + e.maybeStartOfflineAutoUpdateLocked(e.prefs) + } else if !want && cur { + e.stopOfflineAutoUpdateLocked() + } +} + +type updateStatus struct { + started bool +} + +func (e *extension) clearSelfUpdateProgress() { + e.mu.Lock() + defer e.mu.Unlock() + e.selfUpdateProgress = make([]ipnstate.UpdateProgress, 0) + e.lastSelfUpdateState = ipnstate.UpdateFinished +} + +func (e *extension) GetSelfUpdateProgress() []ipnstate.UpdateProgress { + e.mu.Lock() + defer e.mu.Unlock() + res := make([]ipnstate.UpdateProgress, len(e.selfUpdateProgress)) + copy(res, e.selfUpdateProgress) + return res +} + +func (e *extension) DoSelfUpdate() { + e.mu.Lock() + updateState := e.lastSelfUpdateState + e.mu.Unlock() + // don't start an update if one is already in progress + if updateState == ipnstate.UpdateInProgress { + return + } + e.clearSelfUpdateProgress() + e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateInProgress, "")) + up, err := clientupdate.NewUpdater(clientupdate.Arguments{ + Logf: func(format string, args ...any) { + e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateInProgress, fmt.Sprintf(format, args...))) + }, + }) + if err != nil { + e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFailed, err.Error())) + } + err = up.Update() + if err != nil { + e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFailed, err.Error())) + } else { + e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFinished, "tailscaled did not restart; please restart Tailscale manually.")) + } +} + +// serveUpdateInstall sends a request to the LocalBackend to start a Tailscale +// self-update. A successful response does not indicate whether the update +// succeeded, only that the request was accepted. Clients should use +// serveUpdateProgress after pinging this endpoint to check how the update is +// going. +func serveUpdateInstall(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.POST { + http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) + return + } + + b := h.LocalBackend() + ext, ok := ipnlocal.GetExt[*extension](b) + if !ok { + http.Error(w, "clientupdate extension not found", http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusAccepted) + + go ext.DoSelfUpdate() +} + +// serveUpdateProgress returns the status of an in-progress Tailscale self-update. +// This is provided as a slice of ipnstate.UpdateProgress structs with various +// log messages in order from oldest to newest. If an update is not in progress, +// the returned slice will be empty. +func serveUpdateProgress(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.GET { + http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) + return + } + + b := h.LocalBackend() + ext, ok := ipnlocal.GetExt[*extension](b) + if !ok { + http.Error(w, "clientupdate extension not found", http.StatusInternalServerError) + return + } + + ups := ext.GetSelfUpdateProgress() + + json.NewEncoder(w).Encode(ups) +} + +func (e *extension) pushSelfUpdateProgress(up ipnstate.UpdateProgress) { + e.mu.Lock() + defer e.mu.Unlock() + e.selfUpdateProgress = append(e.selfUpdateProgress, up) + e.lastSelfUpdateState = up.Status +} + +func handleC2NUpdateGet(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + e, ok := ipnlocal.GetExt[*extension](b) + if !ok { + http.Error(w, "clientupdate extension not found", http.StatusInternalServerError) + return + } + + e.logf("c2n: GET /update received") + + res := e.newC2NUpdateResponse() + res.Started = e.c2nUpdateStarted() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} + +func handleC2NUpdatePost(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + e, ok := ipnlocal.GetExt[*extension](b) + if !ok { + http.Error(w, "clientupdate extension not found", http.StatusInternalServerError) + return + } + e.logf("c2n: POST /update received") + res := e.newC2NUpdateResponse() + defer func() { + if res.Err != "" { + e.logf("c2n: POST /update failed: %s", res.Err) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) + }() + + if !res.Enabled { + res.Err = "not enabled" + return + } + if !res.Supported { + res.Err = "not supported" + return + } + + // Do not update if we have active inbound SSH connections. Control can set + // force=true query parameter to override this. + if r.FormValue("force") != "true" && b.ActiveSSHConns() > 0 { + res.Err = "not updating due to active SSH connections" + return + } + + if err := e.startAutoUpdate("c2n"); err != nil { + res.Err = err.Error() + return + } + res.Started = true +} + +func (e *extension) newC2NUpdateResponse() tailcfg.C2NUpdateResponse { + e.mu.Lock() + defer e.mu.Unlock() + + // If NewUpdater does not return an error, we can update the installation. + // + // Note that we create the Updater solely to check for errors; we do not + // invoke it here. For this purpose, it is ok to pass it a zero Arguments. + var upPref ipn.AutoUpdatePrefs + if e.prefs.Valid() { + upPref = e.prefs.AutoUpdate() + } + return tailcfg.C2NUpdateResponse{ + Enabled: envknob.AllowsRemoteUpdate() || upPref.Apply.EqualBool(true), + Supported: feature.CanAutoUpdate() && !version.IsMacSysExt(), + } +} + +func (e *extension) c2nUpdateStarted() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.c2nUpdateStatus.started +} + +func (e *extension) setC2NUpdateStarted(v bool) { + e.mu.Lock() + defer e.mu.Unlock() + e.c2nUpdateStatus.started = v +} + +func (e *extension) trySetC2NUpdateStarted() bool { + e.mu.Lock() + defer e.mu.Unlock() + if e.c2nUpdateStatus.started { + return false + } + e.c2nUpdateStatus.started = true + return true +} + +// findCmdTailscale looks for the cmd/tailscale that corresponds to the +// currently running cmd/tailscaled. It's up to the caller to verify that the +// two match, but this function does its best to find the right one. Notably, it +// doesn't use $PATH for security reasons. +func findCmdTailscale() (string, error) { + self, err := os.Executable() + if err != nil { + return "", err + } + var ts string + switch runtime.GOOS { + case "linux": + if self == "/usr/sbin/tailscaled" || self == "/usr/bin/tailscaled" { + ts = "/usr/bin/tailscale" + } + if self == "/usr/local/sbin/tailscaled" || self == "/usr/local/bin/tailscaled" { + ts = "/usr/local/bin/tailscale" + } + switch distro.Get() { + case distro.QNAP: + // The volume under /share/ where qpkg are installed is not + // predictable. But the rest of the path is. + ok, err := filepath.Match("/share/*/.qpkg/Tailscale/tailscaled", self) + if err == nil && ok { + ts = filepath.Join(filepath.Dir(self), "tailscale") + } + case distro.Unraid: + if self == "/usr/local/emhttp/plugins/tailscale/bin/tailscaled" { + ts = "/usr/local/emhttp/plugins/tailscale/bin/tailscale" + } + } + case "windows": + ts = filepath.Join(filepath.Dir(self), "tailscale.exe") + case "freebsd", "openbsd": + if self == "/usr/local/bin/tailscaled" { + ts = "/usr/local/bin/tailscale" + } + default: + return "", fmt.Errorf("unsupported OS %v", runtime.GOOS) + } + if ts != "" && regularFileExists(ts) { + return ts, nil + } + return "", errors.New("tailscale executable not found in expected place") +} + +func tailscaleUpdateCmd(cmdTS string) *exec.Cmd { + defaultCmd := exec.Command(cmdTS, "update", "--yes") + if runtime.GOOS != "linux" { + return defaultCmd + } + if _, err := exec.LookPath("systemd-run"); err != nil { + return defaultCmd + } + + // When systemd-run is available, use it to run the update command. This + // creates a new temporary unit separate from the tailscaled unit. When + // tailscaled is restarted during the update, systemd won't kill this + // temporary update unit, which could cause unexpected breakage. + // + // We want to use a few optional flags: + // * --wait, to block the update command until completion (added in systemd 232) + // * --pipe, to collect stdout/stderr (added in systemd 235) + // * --collect, to clean up failed runs from memory (added in systemd 236) + // + // We need to check the version of systemd to figure out if those flags are + // available. + // + // The output will look like: + // + // systemd 255 (255.7-1-arch) + // +PAM +AUDIT ... other feature flags ... + systemdVerOut, err := exec.Command("systemd-run", "--version").Output() + if err != nil { + return defaultCmd + } + parts := strings.Fields(string(systemdVerOut)) + if len(parts) < 2 || parts[0] != "systemd" { + return defaultCmd + } + systemdVer, err := strconv.Atoi(parts[1]) + if err != nil { + return defaultCmd + } + if systemdVer >= 236 { + return exec.Command("systemd-run", "--wait", "--pipe", "--collect", cmdTS, "update", "--yes") + } else if systemdVer >= 235 { + return exec.Command("systemd-run", "--wait", "--pipe", cmdTS, "update", "--yes") + } else if systemdVer >= 232 { + return exec.Command("systemd-run", "--wait", cmdTS, "update", "--yes") + } else { + return exec.Command("systemd-run", cmdTS, "update", "--yes") + } +} + +func regularFileExists(path string) bool { + fi, err := os.Stat(path) + return err == nil && fi.Mode().IsRegular() +} + +// startAutoUpdate triggers an auto-update attempt. The actual update happens +// asynchronously. If another update is in progress, an error is returned. +func (e *extension) startAutoUpdate(logPrefix string) (retErr error) { + // Check if update was already started, and mark as started. + if !e.trySetC2NUpdateStarted() { + return errors.New("update already started") + } + defer func() { + // Clear the started flag if something failed. + if retErr != nil { + e.setC2NUpdateStarted(false) + } + }() + + cmdTS, err := findCmdTailscale() + if err != nil { + return fmt.Errorf("failed to find cmd/tailscale binary: %w", err) + } + var ver struct { + Long string `json:"long"` + } + out, err := exec.Command(cmdTS, "version", "--json").Output() + if err != nil { + return fmt.Errorf("failed to find cmd/tailscale binary: %w", err) + } + if err := json.Unmarshal(out, &ver); err != nil { + return fmt.Errorf("invalid JSON from cmd/tailscale version --json: %w", err) + } + if ver.Long != version.Long() { + return fmt.Errorf("cmd/tailscale version %q does not match tailscaled version %q", ver.Long, version.Long()) + } + + cmd := tailscaleUpdateCmd(cmdTS) + buf := new(bytes.Buffer) + cmd.Stdout = buf + cmd.Stderr = buf + e.logf("%s: running %q", logPrefix, strings.Join(cmd.Args, " ")) + if err := cmd.Start(); err != nil { + return fmt.Errorf("failed to start cmd/tailscale update: %w", err) + } + + go func() { + if err := cmd.Wait(); err != nil { + e.logf("%s: update command failed: %v, output: %s", logPrefix, err, buf) + } else { + e.logf("%s: update attempt complete", logPrefix) + } + e.setC2NUpdateStarted(false) + }() + return nil +} + +func (e *extension) stopOfflineAutoUpdate() { + e.mu.Lock() + defer e.mu.Unlock() + e.stopOfflineAutoUpdateLocked() +} + +func (e *extension) stopOfflineAutoUpdateLocked() { + if e.offlineAutoUpdateCancel == nil { + return + } + e.logf("offline auto-update: stopping update checks") + e.offlineAutoUpdateCancel() + e.offlineAutoUpdateCancel = nil +} + +// e.mu must be held +func (e *extension) maybeStartOfflineAutoUpdateLocked(prefs ipn.PrefsView) { + if !prefs.Valid() || !prefs.AutoUpdate().Apply.EqualBool(true) { + return + } + // AutoUpdate.Apply field in prefs can only be true for platforms that + // support auto-updates. But check it here again, just in case. + if !feature.CanAutoUpdate() { + return + } + // On macsys, auto-updates are managed by Sparkle. + if version.IsMacSysExt() { + return + } + + if e.offlineAutoUpdateCancel != nil { + // Already running. + return + } + ctx, cancel := context.WithCancel(context.Background()) + e.offlineAutoUpdateCancel = cancel + + e.logf("offline auto-update: starting update checks") + go e.offlineAutoUpdate(ctx) +} + +const offlineAutoUpdateCheckPeriod = time.Hour + +func (e *extension) offlineAutoUpdate(ctx context.Context) { + t := time.NewTicker(offlineAutoUpdateCheckPeriod) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + } + if err := e.startAutoUpdate("offline auto-update"); err != nil { + e.logf("offline auto-update: failed: %v", err) + } + } +} diff --git a/feature/condlite/expvar/expvar.go b/feature/condlite/expvar/expvar.go new file mode 100644 index 000000000..edc16ac77 --- /dev/null +++ b/feature/condlite/expvar/expvar.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(ts_omit_debug && ts_omit_clientmetrics && ts_omit_usermetrics) + +// Package expvar contains type aliases for expvar types, to allow conditionally +// excluding the package from builds. +package expvar + +import "expvar" + +type Int = expvar.Int diff --git a/feature/condlite/expvar/omit.go b/feature/condlite/expvar/omit.go new file mode 100644 index 000000000..a21d94deb --- /dev/null +++ b/feature/condlite/expvar/omit.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_debug && ts_omit_clientmetrics && ts_omit_usermetrics + +// excluding the package from builds. +package expvar + +type Int int64 + +func (*Int) Add(int64) {} diff --git a/feature/condregister/condregister.go b/feature/condregister/condregister.go new file mode 100644 index 000000000..654483d1d --- /dev/null +++ b/feature/condregister/condregister.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The condregister package registers all conditional features guarded +// by build tags. It is one central package that callers can empty import +// to ensure all conditional features are registered. +package condregister + +import ( + // Portmapper is special in that the CLI also needs to link it in, + // so it's pulled out into its own package, rather than using a maybe_*.go + // file in condregister. + _ "tailscale.com/feature/condregister/portmapper" + + // HTTP proxy support is also needed by the CLI, and tsnet, so it's its + // own package too. + _ "tailscale.com/feature/condregister/useproxy" +) diff --git a/feature/condregister/identityfederation/doc.go b/feature/condregister/identityfederation/doc.go new file mode 100644 index 000000000..503b2c8f1 --- /dev/null +++ b/feature/condregister/identityfederation/doc.go @@ -0,0 +1,7 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package identityfederation registers support for authkey resolution +// via identity federation if it's not disabled by the +// ts_omit_identityfederation build tag. +package identityfederation diff --git a/feature/condregister/identityfederation/maybe_identityfederation.go b/feature/condregister/identityfederation/maybe_identityfederation.go new file mode 100644 index 000000000..b1db42fc3 --- /dev/null +++ b/feature/condregister/identityfederation/maybe_identityfederation.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_identityfederation + +package identityfederation + +import _ "tailscale.com/feature/identityfederation" diff --git a/feature/condregister/maybe_ace.go b/feature/condregister/maybe_ace.go new file mode 100644 index 000000000..070231711 --- /dev/null +++ b/feature/condregister/maybe_ace.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_ace + +package condregister + +import _ "tailscale.com/feature/ace" diff --git a/feature/condregister/maybe_appconnectors.go b/feature/condregister/maybe_appconnectors.go new file mode 100644 index 000000000..70112d781 --- /dev/null +++ b/feature/condregister/maybe_appconnectors.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_appconnectors + +package condregister + +import _ "tailscale.com/feature/appconnectors" diff --git a/feature/condregister/maybe_c2n.go b/feature/condregister/maybe_c2n.go new file mode 100644 index 000000000..c222af533 --- /dev/null +++ b/feature/condregister/maybe_c2n.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_c2n + +package condregister + +import _ "tailscale.com/feature/c2n" diff --git a/feature/condregister/maybe_capture.go b/feature/condregister/maybe_capture.go new file mode 100644 index 000000000..0c68331f1 --- /dev/null +++ b/feature/condregister/maybe_capture.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_capture + +package condregister + +import _ "tailscale.com/feature/capture" diff --git a/feature/condregister/maybe_clientupdate.go b/feature/condregister/maybe_clientupdate.go new file mode 100644 index 000000000..bc694f970 --- /dev/null +++ b/feature/condregister/maybe_clientupdate.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_clientupdate + +package condregister + +import _ "tailscale.com/feature/clientupdate" diff --git a/feature/condregister/maybe_debugportmapper.go b/feature/condregister/maybe_debugportmapper.go new file mode 100644 index 000000000..4990d09ea --- /dev/null +++ b/feature/condregister/maybe_debugportmapper.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_debugportmapper + +package condregister + +import _ "tailscale.com/feature/debugportmapper" diff --git a/feature/condregister/maybe_doctor.go b/feature/condregister/maybe_doctor.go new file mode 100644 index 000000000..3dc9ffa53 --- /dev/null +++ b/feature/condregister/maybe_doctor.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_doctor + +package condregister + +import _ "tailscale.com/feature/doctor" diff --git a/feature/condregister/maybe_drive.go b/feature/condregister/maybe_drive.go new file mode 100644 index 000000000..cb447ff28 --- /dev/null +++ b/feature/condregister/maybe_drive.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive + +package condregister + +import _ "tailscale.com/feature/drive" diff --git a/feature/condregister/maybe_linkspeed.go b/feature/condregister/maybe_linkspeed.go new file mode 100644 index 000000000..46064b39a --- /dev/null +++ b/feature/condregister/maybe_linkspeed.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android && !ts_omit_linkspeed + +package condregister + +import _ "tailscale.com/feature/linkspeed" diff --git a/feature/condregister/maybe_linuxdnsfight.go b/feature/condregister/maybe_linuxdnsfight.go new file mode 100644 index 000000000..0dae62b00 --- /dev/null +++ b/feature/condregister/maybe_linuxdnsfight.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android && !ts_omit_linuxdnsfight + +package condregister + +import _ "tailscale.com/feature/linuxdnsfight" diff --git a/feature/condregister/maybe_osrouter.go b/feature/condregister/maybe_osrouter.go new file mode 100644 index 000000000..7ab85add2 --- /dev/null +++ b/feature/condregister/maybe_osrouter.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_osrouter + +package condregister + +import _ "tailscale.com/wgengine/router/osrouter" diff --git a/feature/condregister/maybe_portlist.go b/feature/condregister/maybe_portlist.go new file mode 100644 index 000000000..1be56f177 --- /dev/null +++ b/feature/condregister/maybe_portlist.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_portlist + +package condregister + +import _ "tailscale.com/feature/portlist" diff --git a/feature/condregister/maybe_posture.go b/feature/condregister/maybe_posture.go new file mode 100644 index 000000000..6f14c2713 --- /dev/null +++ b/feature/condregister/maybe_posture.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_posture + +package condregister + +import _ "tailscale.com/feature/posture" diff --git a/feature/condregister/maybe_relayserver.go b/feature/condregister/maybe_relayserver.go new file mode 100644 index 000000000..3360dd062 --- /dev/null +++ b/feature/condregister/maybe_relayserver.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_relayserver + +package condregister + +import _ "tailscale.com/feature/relayserver" diff --git a/feature/condregister/maybe_sdnotify.go b/feature/condregister/maybe_sdnotify.go new file mode 100644 index 000000000..647996f88 --- /dev/null +++ b/feature/condregister/maybe_sdnotify.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_sdnotify + +package condregister + +import _ "tailscale.com/feature/sdnotify" diff --git a/feature/condregister/maybe_store_aws.go b/feature/condregister/maybe_store_aws.go new file mode 100644 index 000000000..8358b49f0 --- /dev/null +++ b/feature/condregister/maybe_store_aws.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (ts_aws || (linux && (arm64 || amd64) && !android)) && !ts_omit_aws + +package condregister + +import _ "tailscale.com/ipn/store/awsstore" diff --git a/feature/condregister/maybe_store_kube.go b/feature/condregister/maybe_store_kube.go new file mode 100644 index 000000000..bb795b05e --- /dev/null +++ b/feature/condregister/maybe_store_kube.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build (ts_kube || (linux && (arm64 || amd64) && !android)) && !ts_omit_kube + +package condregister + +import _ "tailscale.com/ipn/store/kubestore" diff --git a/feature/condregister/maybe_syspolicy.go b/feature/condregister/maybe_syspolicy.go new file mode 100644 index 000000000..49ec5c02c --- /dev/null +++ b/feature/condregister/maybe_syspolicy.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_syspolicy + +package condregister + +import _ "tailscale.com/feature/syspolicy" diff --git a/feature/condregister/maybe_taildrop.go b/feature/condregister/maybe_taildrop.go new file mode 100644 index 000000000..5fd7b5f8c --- /dev/null +++ b/feature/condregister/maybe_taildrop.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_taildrop + +package condregister + +import _ "tailscale.com/feature/taildrop" diff --git a/feature/condregister/maybe_tap.go b/feature/condregister/maybe_tap.go new file mode 100644 index 000000000..eca4fc3ac --- /dev/null +++ b/feature/condregister/maybe_tap.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_tap + +package condregister + +import _ "tailscale.com/feature/tap" diff --git a/feature/condregister/maybe_tpm.go b/feature/condregister/maybe_tpm.go new file mode 100644 index 000000000..caa57fef1 --- /dev/null +++ b/feature/condregister/maybe_tpm.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_tpm + +package condregister + +import _ "tailscale.com/feature/tpm" diff --git a/feature/condregister/maybe_wakeonlan.go b/feature/condregister/maybe_wakeonlan.go new file mode 100644 index 000000000..14cae605d --- /dev/null +++ b/feature/condregister/maybe_wakeonlan.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_wakeonlan + +package condregister + +import _ "tailscale.com/feature/wakeonlan" diff --git a/feature/condregister/oauthkey/doc.go b/feature/condregister/oauthkey/doc.go new file mode 100644 index 000000000..4c4ea5e4e --- /dev/null +++ b/feature/condregister/oauthkey/doc.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package oauthkey registers support for OAuth key resolution +// if it's not disabled via the ts_omit_oauthkey build tag. +// Currently (2025-09-19), tailscaled does not need OAuth key +// resolution, only the CLI and tsnet do, so this package is +// pulled out separately to avoid linking OAuth packages into +// tailscaled. +package oauthkey diff --git a/feature/condregister/oauthkey/maybe_oauthkey.go b/feature/condregister/oauthkey/maybe_oauthkey.go new file mode 100644 index 000000000..be8d04b8e --- /dev/null +++ b/feature/condregister/oauthkey/maybe_oauthkey.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_oauthkey + +package oauthkey + +import _ "tailscale.com/feature/oauthkey" diff --git a/feature/condregister/portmapper/doc.go b/feature/condregister/portmapper/doc.go new file mode 100644 index 000000000..5c30538c4 --- /dev/null +++ b/feature/condregister/portmapper/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package portmapper registers support for portmapper +// if it's not disabled via the ts_omit_portmapper build tag. +package portmapper diff --git a/feature/condregister/portmapper/maybe_portmapper.go b/feature/condregister/portmapper/maybe_portmapper.go new file mode 100644 index 000000000..c306fd3d5 --- /dev/null +++ b/feature/condregister/portmapper/maybe_portmapper.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_portmapper + +package portmapper + +import _ "tailscale.com/feature/portmapper" diff --git a/feature/condregister/useproxy/doc.go b/feature/condregister/useproxy/doc.go new file mode 100644 index 000000000..1e8abb358 --- /dev/null +++ b/feature/condregister/useproxy/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package useproxy registers support for using proxies +// if it's not disabled via the ts_omit_useproxy build tag. +package useproxy diff --git a/feature/condregister/useproxy/useproxy.go b/feature/condregister/useproxy/useproxy.go new file mode 100644 index 000000000..bda6e49c0 --- /dev/null +++ b/feature/condregister/useproxy/useproxy.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_useproxy + +package useproxy + +import _ "tailscale.com/feature/useproxy" diff --git a/feature/debugportmapper/debugportmapper.go b/feature/debugportmapper/debugportmapper.go new file mode 100644 index 000000000..2625086c6 --- /dev/null +++ b/feature/debugportmapper/debugportmapper.go @@ -0,0 +1,204 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package debugportmapper registers support for debugging Tailscale's +// portmapping support. +package debugportmapper + +import ( + "context" + "fmt" + "net" + "net/http" + "net/netip" + "strconv" + "strings" + "sync" + "time" + + "tailscale.com/ipn/localapi" + "tailscale.com/net/netmon" + "tailscale.com/net/portmapper" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" +) + +func init() { + localapi.Register("debug-portmap", serveDebugPortmap) +} + +func serveDebugPortmap(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + w.Header().Set("Content-Type", "text/plain") + + dur, err := time.ParseDuration(r.FormValue("duration")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + gwSelf := r.FormValue("gateway_and_self") + + trueFunc := func() bool { return true } + // Update portmapper debug flags + debugKnobs := &portmapper.DebugKnobs{VerboseLogs: true} + switch r.FormValue("type") { + case "": + case "pmp": + debugKnobs.DisablePCPFunc = trueFunc + debugKnobs.DisableUPnPFunc = trueFunc + case "pcp": + debugKnobs.DisablePMPFunc = trueFunc + debugKnobs.DisableUPnPFunc = trueFunc + case "upnp": + debugKnobs.DisablePCPFunc = trueFunc + debugKnobs.DisablePMPFunc = trueFunc + default: + http.Error(w, "unknown portmap debug type", http.StatusBadRequest) + return + } + if k := h.LocalBackend().ControlKnobs(); k != nil { + if k.DisableUPnP.Load() { + debugKnobs.DisableUPnPFunc = trueFunc + } + } + + if defBool(r.FormValue("log_http"), false) { + debugKnobs.LogHTTP = true + } + + var ( + logLock sync.Mutex + handlerDone bool + ) + logf := func(format string, args ...any) { + if !strings.HasSuffix(format, "\n") { + format = format + "\n" + } + + logLock.Lock() + defer logLock.Unlock() + + // The portmapper can call this log function after the HTTP + // handler returns, which is not allowed and can cause a panic. + // If this happens, ignore the log lines since this typically + // occurs due to a client disconnect. + if handlerDone { + return + } + + // Write and flush each line to the client so that output is streamed + fmt.Fprintf(w, format, args...) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } + defer func() { + logLock.Lock() + handlerDone = true + logLock.Unlock() + }() + + ctx, cancel := context.WithTimeout(r.Context(), dur) + defer cancel() + + done := make(chan bool, 1) + + var c *portmapper.Client + c = portmapper.NewClient(portmapper.Config{ + Logf: logger.WithPrefix(logf, "portmapper: "), + NetMon: h.LocalBackend().NetMon(), + DebugKnobs: debugKnobs, + EventBus: h.LocalBackend().EventBus(), + OnChange: func() { + logf("portmapping changed.") + logf("have mapping: %v", c.HaveMapping()) + + if ext, ok := c.GetCachedMappingOrStartCreatingOne(); ok { + logf("cb: mapping: %v", ext) + select { + case done <- true: + default: + } + return + } + logf("cb: no mapping") + }, + }) + defer c.Close() + + bus := eventbus.New() + defer bus.Close() + netMon, err := netmon.New(bus, logger.WithPrefix(logf, "monitor: ")) + if err != nil { + logf("error creating monitor: %v", err) + return + } + + gatewayAndSelfIP := func() (gw, self netip.Addr, ok bool) { + if a, b, ok := strings.Cut(gwSelf, "/"); ok { + gw = netip.MustParseAddr(a) + self = netip.MustParseAddr(b) + return gw, self, true + } + return netMon.GatewayAndSelfIP() + } + + c.SetGatewayLookupFunc(gatewayAndSelfIP) + + gw, selfIP, ok := gatewayAndSelfIP() + if !ok { + logf("no gateway or self IP; %v", netMon.InterfaceState()) + return + } + logf("gw=%v; self=%v", gw, selfIP) + + uc, err := net.ListenPacket("udp", "0.0.0.0:0") + if err != nil { + return + } + defer uc.Close() + c.SetLocalPort(uint16(uc.LocalAddr().(*net.UDPAddr).Port)) + + res, err := c.Probe(ctx) + if err != nil { + logf("error in Probe: %v", err) + return + } + logf("Probe: %+v", res) + + if !res.PCP && !res.PMP && !res.UPnP { + logf("no portmapping services available") + return + } + + if ext, ok := c.GetCachedMappingOrStartCreatingOne(); ok { + logf("mapping: %v", ext) + } else { + logf("no mapping") + } + + select { + case <-done: + case <-ctx.Done(): + if r.Context().Err() == nil { + logf("serveDebugPortmap: context done: %v", ctx.Err()) + } else { + h.Logf("serveDebugPortmap: context done: %v", ctx.Err()) + } + } +} + +func defBool(a string, def bool) bool { + if a == "" { + return def + } + v, err := strconv.ParseBool(a) + if err != nil { + return def + } + return v +} diff --git a/feature/doctor/doctor.go b/feature/doctor/doctor.go new file mode 100644 index 000000000..875b57d14 --- /dev/null +++ b/feature/doctor/doctor.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The doctor package registers the "doctor" problem diagnosis support into the +// rest of Tailscale. +package doctor + +import ( + "context" + "fmt" + "html" + "net/http" + "time" + + "tailscale.com/doctor" + "tailscale.com/doctor/ethtool" + "tailscale.com/doctor/permissions" + "tailscale.com/doctor/routetable" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" +) + +func init() { + ipnlocal.HookDoctor.Set(visitDoctor) + ipnlocal.RegisterPeerAPIHandler("/v0/doctor", handleServeDoctor) +} + +func handleServeDoctor(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + if !h.CanDebug() { + http.Error(w, "denied; no debug access", http.StatusForbidden) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + fmt.Fprintln(w, "

Doctor Output

") + + fmt.Fprintln(w, "
")
+
+	b := h.LocalBackend()
+	visitDoctor(r.Context(), b, func(format string, args ...any) {
+		line := fmt.Sprintf(format, args...)
+		fmt.Fprintln(w, html.EscapeString(line))
+	})
+
+	fmt.Fprintln(w, "
") +} + +func visitDoctor(ctx context.Context, b *ipnlocal.LocalBackend, logf logger.Logf) { + // We can write logs too fast for logtail to handle, even when + // opting-out of rate limits. Limit ourselves to at most one message + // per 20ms and a burst of 60 log lines, which should be fast enough to + // not block for too long but slow enough that we can upload all lines. + logf = logger.SlowLoggerWithClock(ctx, logf, 20*time.Millisecond, 60, b.Clock().Now) + + var checks []doctor.Check + checks = append(checks, + permissions.Check{}, + routetable.Check{}, + ethtool.Check{}, + ) + + // Print a log message if any of the global DNS resolvers are Tailscale + // IPs; this can interfere with our ability to connect to the Tailscale + // controlplane. + checks = append(checks, doctor.CheckFunc("dns-resolvers", func(_ context.Context, logf logger.Logf) error { + nm := b.NetMap() + if nm == nil { + return nil + } + + for i, resolver := range nm.DNS.Resolvers { + ipp, ok := resolver.IPPort() + if ok && tsaddr.IsTailscaleIP(ipp.Addr()) { + logf("resolver %d is a Tailscale address: %v", i, resolver) + } + } + for i, resolver := range nm.DNS.FallbackResolvers { + ipp, ok := resolver.IPPort() + if ok && tsaddr.IsTailscaleIP(ipp.Addr()) { + logf("fallback resolver %d is a Tailscale address: %v", i, resolver) + } + } + return nil + })) + + // TODO(andrew): more + + numChecks := len(checks) + checks = append(checks, doctor.CheckFunc("numchecks", func(_ context.Context, log logger.Logf) error { + log("%d checks", numChecks) + return nil + })) + + doctor.RunChecks(ctx, logf, checks...) +} diff --git a/feature/drive/drive.go b/feature/drive/drive.go new file mode 100644 index 000000000..3660a2b95 --- /dev/null +++ b/feature/drive/drive.go @@ -0,0 +1,5 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package drive registers the Taildrive (file server) feature. +package drive diff --git a/feature/feature.go b/feature/feature.go new file mode 100644 index 000000000..110b104da --- /dev/null +++ b/feature/feature.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package feature tracks which features are linked into the binary. +package feature + +import ( + "errors" + "reflect" +) + +var ErrUnavailable = errors.New("feature not included in this build") + +var in = map[string]bool{} + +// Registered reports the set of registered features. +// +// The returned map should not be modified by the caller, +// not accessed concurrently with calls to Register. +func Registered() map[string]bool { return in } + +// Register notes that the named feature is linked into the binary. +func Register(name string) { + if _, ok := in[name]; ok { + panic("duplicate feature registration for " + name) + } + in[name] = true +} + +// Hook is a func that can only be set once. +// +// It is not safe for concurrent use. +type Hook[Func any] struct { + f Func + ok bool +} + +// IsSet reports whether the hook has been set. +func (h *Hook[Func]) IsSet() bool { + return h.ok +} + +// Set sets the hook function, panicking if it's already been set +// or f is the zero value. +// +// It's meant to be called in init. +func (h *Hook[Func]) Set(f Func) { + if h.ok { + panic("Set on already-set feature hook") + } + if reflect.ValueOf(f).IsZero() { + panic("Set with zero value") + } + h.f = f + h.ok = true +} + +// Get returns the hook function, or panics if it hasn't been set. +// Use IsSet to check if it's been set, or use GetOrNil if you're +// okay with a nil return value. +func (h *Hook[Func]) Get() Func { + if !h.ok { + panic("Get on unset feature hook, without IsSet") + } + return h.f +} + +// GetOk returns the hook function and true if it has been set, +// otherwise its zero value and false. +func (h *Hook[Func]) GetOk() (f Func, ok bool) { + return h.f, h.ok +} + +// GetOrNil returns the hook function or nil if it hasn't been set. +func (h *Hook[Func]) GetOrNil() Func { + return h.f +} + +// Hooks is a slice of funcs. +// +// As opposed to a single Hook, this is meant to be used when +// multiple parties are able to install the same hook. +type Hooks[Func any] []Func + +// Add adds a hook to the list of hooks. +// +// Add should only be called during early program +// startup before Tailscale has started. +// It is not safe for concurrent use. +func (h *Hooks[Func]) Add(f Func) { + if reflect.ValueOf(f).IsZero() { + panic("Add with zero value") + } + *h = append(*h, f) +} diff --git a/feature/featuretags/featuretags.go b/feature/featuretags/featuretags.go new file mode 100644 index 000000000..44b129576 --- /dev/null +++ b/feature/featuretags/featuretags.go @@ -0,0 +1,291 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The featuretags package is a registry of all the ts_omit-able build tags. +package featuretags + +import "tailscale.com/util/set" + +// CLI is a special feature in the [Features] map that works opposite +// from the others: it is opt-in, rather than opt-out, having a different +// build tag format. +const CLI FeatureTag = "cli" + +// FeatureTag names a Tailscale feature that can be selectively added or removed +// via build tags. +type FeatureTag string + +// IsOmittable reports whether this feature tag is one that can be +// omitted via a ts_omit_ build tag. +func (ft FeatureTag) IsOmittable() bool { + switch ft { + case CLI: + return false + } + return true +} + +// OmitTag returns the ts_omit_ build tag for this feature tag. +// It panics if the feature tag is not omitable. +func (ft FeatureTag) OmitTag() string { + if !ft.IsOmittable() { + panic("not omitable: " + string(ft)) + } + return "ts_omit_" + string(ft) +} + +// Requires returns the set of features that must be included to +// use the given feature, including the provided feature itself. +func Requires(ft FeatureTag) set.Set[FeatureTag] { + s := set.Set[FeatureTag]{} + var add func(FeatureTag) + add = func(ft FeatureTag) { + s.Add(ft) + for _, dep := range Features[ft].Deps { + add(dep) + } + } + add(ft) + return s +} + +// RequiredBy is the inverse of Requires: it returns the set of features that +// depend on the given feature (directly or indirectly), including the feature +// itself. +func RequiredBy(ft FeatureTag) set.Set[FeatureTag] { + s := set.Set[FeatureTag]{} + for f := range Features { + if featureDependsOn(f, ft) { + s.Add(f) + } + } + return s +} + +// featureDependsOn reports whether feature a (directly or indirectly) depends on b. +// It returns true if a == b. +func featureDependsOn(a, b FeatureTag) bool { + if a == b { + return true + } + for _, dep := range Features[a].Deps { + if featureDependsOn(dep, b) { + return true + } + } + return false +} + +// FeatureMeta describes a modular feature that can be conditionally linked into +// the binary. +type FeatureMeta struct { + Sym string // exported Go symbol for boolean const + Desc string // human-readable description + Deps []FeatureTag // other features this feature requires + + // ImplementationDetail is whether the feature is an internal implementation + // detail. That is, it's not something a user wuold care about having or not + // having, but we'd like to able to omit from builds if no other + // user-visible features depend on it. + ImplementationDetail bool +} + +// Features are the known Tailscale features that can be selectively included or +// excluded via build tags, and a description of each. +var Features = map[FeatureTag]FeatureMeta{ + "ace": {Sym: "ACE", Desc: "Alternate Connectivity Endpoints"}, + "acme": {Sym: "ACME", Desc: "ACME TLS certificate management"}, + "appconnectors": {Sym: "AppConnectors", Desc: "App Connectors support"}, + "aws": {Sym: "AWS", Desc: "AWS integration"}, + "advertiseexitnode": { + Sym: "AdvertiseExitNode", + Desc: "Run an exit node", + Deps: []FeatureTag{ + "peerapiserver", // to run the ExitDNS server + "advertiseroutes", + }, + }, + "advertiseroutes": { + Sym: "AdvertiseRoutes", + Desc: "Advertise routes for other nodes to use", + Deps: []FeatureTag{ + "c2n", // for control plane to probe health for HA subnet router leader election + }, + }, + "bakedroots": {Sym: "BakedRoots", Desc: "Embed CA (LetsEncrypt) x509 roots to use as fallback"}, + "bird": { + Sym: "Bird", + Desc: "Bird BGP integration", + Deps: []FeatureTag{"advertiseroutes"}, + }, + "c2n": { + Sym: "C2N", + Desc: "Control-to-node (C2N) support", + ImplementationDetail: true, + }, + "cachenetmap": { + Sym: "CacheNetMap", + Desc: "Cache the netmap on disk between runs", + }, + "captiveportal": {Sym: "CaptivePortal", Desc: "Captive portal detection"}, + "capture": {Sym: "Capture", Desc: "Packet capture"}, + "cli": {Sym: "CLI", Desc: "embed the CLI into the tailscaled binary"}, + "cliconndiag": {Sym: "CLIConnDiag", Desc: "CLI connection error diagnostics"}, + "clientmetrics": {Sym: "ClientMetrics", Desc: "Client metrics support"}, + "clientupdate": { + Sym: "ClientUpdate", + Desc: "Client auto-update support", + Deps: []FeatureTag{"c2n"}, + }, + "completion": {Sym: "Completion", Desc: "CLI shell completion"}, + "cloud": {Sym: "Cloud", Desc: "detect cloud environment to learn instances IPs and DNS servers"}, + "dbus": { + Sym: "DBus", + Desc: "Linux DBus support", + ImplementationDetail: true, + }, + "debug": {Sym: "Debug", Desc: "various debug support, for things that don't have or need their own more specific feature"}, + "debugeventbus": {Sym: "DebugEventBus", Desc: "eventbus debug support"}, + "debugportmapper": { + Sym: "DebugPortMapper", + Desc: "portmapper debug support", + Deps: []FeatureTag{"portmapper"}, + }, + "desktop_sessions": {Sym: "DesktopSessions", Desc: "Desktop sessions support"}, + "doctor": {Sym: "Doctor", Desc: "Diagnose possible issues with Tailscale and its host environment"}, + "drive": {Sym: "Drive", Desc: "Tailscale Drive (file server) support"}, + "gro": { + Sym: "GRO", + Desc: "Generic Receive Offload support (performance)", + Deps: []FeatureTag{"netstack"}, + }, + "health": {Sym: "Health", Desc: "Health checking support"}, + "hujsonconf": {Sym: "HuJSONConf", Desc: "HuJSON config file support"}, + "identityfederation": {Sym: "IdentityFederation", Desc: "Auth key generation via identity federation support"}, + "iptables": {Sym: "IPTables", Desc: "Linux iptables support"}, + "kube": {Sym: "Kube", Desc: "Kubernetes integration"}, + "lazywg": {Sym: "LazyWG", Desc: "Lazy WireGuard configuration for memory-constrained devices with large netmaps"}, + "linuxdnsfight": {Sym: "LinuxDNSFight", Desc: "Linux support for detecting DNS fights (inotify watching of /etc/resolv.conf)"}, + "linkspeed": { + Sym: "LinkSpeed", + Desc: "Set link speed on TUN device for better OS integration (Linux only)", + }, + "listenrawdisco": { + Sym: "ListenRawDisco", + Desc: "Use raw sockets for more robust disco (NAT traversal) message receiving (Linux only)", + }, + "logtail": { + Sym: "LogTail", + Desc: "upload logs to log.tailscale.com (debug logs for bug reports and also by network flow logs if enabled)", + }, + "oauthkey": {Sym: "OAuthKey", Desc: "OAuth secret-to-authkey resolution support"}, + "outboundproxy": { + Sym: "OutboundProxy", + Desc: "Support running an outbound localhost HTTP/SOCK5 proxy support that sends traffic over Tailscale", + Deps: []FeatureTag{"netstack"}, + }, + "osrouter": { + Sym: "OSRouter", + Desc: "Configure the operating system's network stack, IPs, and routing tables", + // TODO(bradfitz): if this is omitted, and netstack is too, then tailscaled needs + // external config to be useful. Some people may want that, and we should support it, + // but it's rare. Maybe there should be a way to declare here that this "Provides" + // another feature (and netstack can too), and then if those required features provided + // by some other feature are missing, then it's an error by default unless you accept + // that it's okay to proceed without that meta feature. + }, + "peerapiclient": { + Sym: "PeerAPIClient", + Desc: "PeerAPI client support", + ImplementationDetail: true, + }, + "peerapiserver": { + Sym: "PeerAPIServer", + Desc: "PeerAPI server support", + ImplementationDetail: true, + }, + "portlist": {Sym: "PortList", Desc: "Optionally advertise listening service ports"}, + "portmapper": {Sym: "PortMapper", Desc: "NAT-PMP/PCP/UPnP port mapping support"}, + "posture": {Sym: "Posture", Desc: "Device posture checking support"}, + "dns": { + Sym: "DNS", + Desc: "MagicDNS and system DNS configuration support", + }, + "netlog": { + Sym: "NetLog", + Desc: "Network flow logging support", + Deps: []FeatureTag{"logtail"}, + }, + "netstack": {Sym: "Netstack", Desc: "gVisor netstack (userspace networking) support"}, + "networkmanager": { + Sym: "NetworkManager", + Desc: "Linux NetworkManager integration", + Deps: []FeatureTag{"dbus"}, + }, + "relayserver": {Sym: "RelayServer", Desc: "Relay server"}, + "resolved": { + Sym: "Resolved", + Desc: "Linux systemd-resolved integration", + Deps: []FeatureTag{"dbus"}, + }, + "sdnotify": { + Sym: "SDNotify", + Desc: "systemd notification support", + }, + "serve": { + Sym: "Serve", + Desc: "Serve and Funnel support", + Deps: []FeatureTag{"netstack"}, + }, + "ssh": { + Sym: "SSH", + Desc: "Tailscale SSH support", + Deps: []FeatureTag{"c2n", "dbus", "netstack"}, + }, + "synology": { + Sym: "Synology", + Desc: "Synology NAS integration (applies to Linux builds only)", + }, + "syspolicy": {Sym: "SystemPolicy", Desc: "System policy configuration (MDM) support"}, + "systray": { + Sym: "SysTray", + Desc: "Linux system tray", + Deps: []FeatureTag{"dbus"}, + }, + "taildrop": { + Sym: "Taildrop", + Desc: "Taildrop (file sending) support", + Deps: []FeatureTag{ + "peerapiclient", "peerapiserver", // assume Taildrop is both sides for now + }, + }, + "tailnetlock": {Sym: "TailnetLock", Desc: "Tailnet Lock support"}, + "tap": {Sym: "Tap", Desc: "Experimental Layer 2 (ethernet) support"}, + "tpm": {Sym: "TPM", Desc: "TPM support"}, + "unixsocketidentity": { + Sym: "UnixSocketIdentity", + Desc: "differentiate between users accessing the LocalAPI over unix sockets (if omitted, all users have full access)", + }, + "useroutes": { + Sym: "UseRoutes", + Desc: "Use routes advertised by other nodes", + }, + "useexitnode": { + Sym: "UseExitNode", + Desc: "Use exit nodes", + Deps: []FeatureTag{"peerapiclient", "useroutes"}, + }, + "useproxy": { + Sym: "UseProxy", + Desc: "Support using system proxies as specified by env vars or the system configuration to reach Tailscale servers.", + }, + "usermetrics": { + Sym: "UserMetrics", + Desc: "Usermetrics (documented, stable) metrics support", + }, + "wakeonlan": {Sym: "WakeOnLAN", Desc: "Wake-on-LAN support"}, + "webclient": { + Sym: "WebClient", Desc: "Web client support", + Deps: []FeatureTag{"serve"}, + }, +} diff --git a/feature/featuretags/featuretags_test.go b/feature/featuretags/featuretags_test.go new file mode 100644 index 000000000..893ab0e6a --- /dev/null +++ b/feature/featuretags/featuretags_test.go @@ -0,0 +1,85 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package featuretags + +import ( + "maps" + "slices" + "testing" + + "tailscale.com/util/set" +) + +func TestKnownDeps(t *testing.T) { + for tag, meta := range Features { + for _, dep := range meta.Deps { + if _, ok := Features[dep]; !ok { + t.Errorf("feature %q has unknown dependency %q", tag, dep) + } + } + + // And indirectly check for cycles. If there were a cycle, + // this would infinitely loop. + deps := Requires(tag) + t.Logf("deps of %q: %v", tag, slices.Sorted(maps.Keys(deps))) + } +} + +func TestRequires(t *testing.T) { + var setOf = set.Of[FeatureTag] + tests := []struct { + in FeatureTag + want set.Set[FeatureTag] + }{ + { + in: "drive", + want: setOf("drive"), + }, + { + in: "cli", + want: setOf("cli"), + }, + { + in: "serve", + want: setOf("serve", "netstack"), + }, + { + in: "webclient", + want: setOf("webclient", "serve", "netstack"), + }, + } + for _, tt := range tests { + got := Requires(tt.in) + if !maps.Equal(got, tt.want) { + t.Errorf("DepSet(%q) = %v, want %v", tt.in, got, tt.want) + } + } +} + +func TestRequiredBy(t *testing.T) { + var setOf = set.Of[FeatureTag] + tests := []struct { + in FeatureTag + want set.Set[FeatureTag] + }{ + { + in: "drive", + want: setOf("drive"), + }, + { + in: "webclient", + want: setOf("webclient"), + }, + { + in: "serve", + want: setOf("webclient", "serve"), + }, + } + for _, tt := range tests { + got := RequiredBy(tt.in) + if !maps.Equal(got, tt.want) { + t.Errorf("FeaturesWhichDependOn(%q) = %v, want %v", tt.in, got, tt.want) + } + } +} diff --git a/feature/hooks.go b/feature/hooks.go new file mode 100644 index 000000000..a3c6c0395 --- /dev/null +++ b/feature/hooks.go @@ -0,0 +1,73 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package feature + +import ( + "net/http" + "net/url" + + "tailscale.com/types/logger" + "tailscale.com/types/persist" +) + +// HookCanAutoUpdate is a hook for the clientupdate package +// to conditionally initialize. +var HookCanAutoUpdate Hook[func() bool] + +// CanAutoUpdate reports whether the current binary is built with auto-update +// support and, if so, whether the current platform supports it. +func CanAutoUpdate() bool { + if f, ok := HookCanAutoUpdate.GetOk(); ok { + return f() + } + return false +} + +// HookProxyFromEnvironment is a hook for feature/useproxy to register +// a function to use as http.ProxyFromEnvironment. +var HookProxyFromEnvironment Hook[func(*http.Request) (*url.URL, error)] + +// HookProxyInvalidateCache is a hook for feature/useproxy to register +// [tshttpproxy.InvalidateCache]. +var HookProxyInvalidateCache Hook[func()] + +// HookProxyGetAuthHeader is a hook for feature/useproxy to register +// [tshttpproxy.GetAuthHeader]. +var HookProxyGetAuthHeader Hook[func(*url.URL) (string, error)] + +// HookProxySetSelfProxy is a hook for feature/useproxy to register +// [tshttpproxy.SetSelfProxy]. +var HookProxySetSelfProxy Hook[func(...string)] + +// HookProxySetTransportGetProxyConnectHeader is a hook for feature/useproxy to register +// [tshttpproxy.SetTransportGetProxyConnectHeader]. +var HookProxySetTransportGetProxyConnectHeader Hook[func(*http.Transport)] + +// HookTPMAvailable is a hook that reports whether a TPM device is supported +// and available. +var HookTPMAvailable Hook[func() bool] + +var HookGenerateAttestationKeyIfEmpty Hook[func(p *persist.Persist, logf logger.Logf) (bool, error)] + +// TPMAvailable reports whether a TPM device is supported and available. +func TPMAvailable() bool { + if f, ok := HookTPMAvailable.GetOk(); ok { + return f() + } + return false +} + +// HookHardwareAttestationAvailable is a hook that reports whether hardware +// attestation is supported and available. +var HookHardwareAttestationAvailable Hook[func() bool] + +// HardwareAttestationAvailable reports whether hardware attestation is +// supported and available (TPM on Windows/Linux, Secure Enclave on macOS|iOS, +// KeyStore on Android) +func HardwareAttestationAvailable() bool { + if f, ok := HookHardwareAttestationAvailable.GetOk(); ok { + return f() + } + return false +} diff --git a/feature/identityfederation/identityfederation.go b/feature/identityfederation/identityfederation.go new file mode 100644 index 000000000..ab1b65f12 --- /dev/null +++ b/feature/identityfederation/identityfederation.go @@ -0,0 +1,130 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package identityfederation registers support for using ID tokens to +// automatically request authkeys for logging in. +package identityfederation + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "golang.org/x/oauth2" + "tailscale.com/feature" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" +) + +func init() { + feature.Register("identityfederation") + tailscale.HookResolveAuthKeyViaWIF.Set(resolveAuthKey) +} + +// resolveAuthKey uses OIDC identity federation to exchange the provided ID token and client ID for an authkey. +func resolveAuthKey(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error) { + if clientID == "" { + return "", nil // Short-circuit, no client ID means not using identity federation + } + + if idToken == "" { + return "", errors.New("federated identity authkeys require --id-token") + } + if len(tags) == 0 { + return "", errors.New("federated identity authkeys require --advertise-tags") + } + if baseURL == "" { + baseURL = ipn.DefaultControlURL + } + + strippedID, ephemeral, preauth, err := parseOptionalAttributes(clientID) + if err != nil { + return "", fmt.Errorf("failed to parse optional config attributes: %w", err) + } + + accessToken, err := exchangeJWTForToken(ctx, baseURL, strippedID, idToken) + if err != nil { + return "", fmt.Errorf("failed to exchange JWT for access token: %w", err) + } + if accessToken == "" { + return "", errors.New("received empty access token from Tailscale") + } + + tsClient := tailscale.NewClient("-", tailscale.APIKey(accessToken)) + tsClient.UserAgent = "tailscale-cli-identity-federation" + tsClient.BaseURL = baseURL + + authkey, _, err := tsClient.CreateKey(ctx, tailscale.KeyCapabilities{ + Devices: tailscale.KeyDeviceCapabilities{ + Create: tailscale.KeyDeviceCreateCapabilities{ + Reusable: false, + Ephemeral: ephemeral, + Preauthorized: preauth, + Tags: tags, + }, + }, + }) + if err != nil { + return "", fmt.Errorf("unexpected error while creating authkey: %w", err) + } + if authkey == "" { + return "", errors.New("received empty authkey from control server") + } + + return authkey, nil +} + +func parseOptionalAttributes(clientID string) (strippedID string, ephemeral bool, preauthorized bool, err error) { + strippedID, attrs, found := strings.Cut(clientID, "?") + if !found { + return clientID, true, false, nil + } + + parsed, err := url.ParseQuery(attrs) + if err != nil { + return "", false, false, fmt.Errorf("failed to parse optional config attributes: %w", err) + } + + for k := range parsed { + switch k { + case "ephemeral": + ephemeral, err = strconv.ParseBool(parsed.Get(k)) + case "preauthorized": + preauthorized, err = strconv.ParseBool(parsed.Get(k)) + default: + return "", false, false, fmt.Errorf("unknown optional config attribute %q", k) + } + } + if err != nil { + return "", false, false, err + } + + return strippedID, ephemeral, preauthorized, nil +} + +// exchangeJWTForToken exchanges a JWT for a Tailscale access token. +func exchangeJWTForToken(ctx context.Context, baseURL, clientID, idToken string) (string, error) { + httpClient := &http.Client{Timeout: 10 * time.Second} + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + token, err := (&oauth2.Config{ + Endpoint: oauth2.Endpoint{ + TokenURL: fmt.Sprintf("%s/api/v2/oauth/token-exchange", baseURL), + }, + }).Exchange(ctx, "", oauth2.SetAuthURLParam("client_id", clientID), oauth2.SetAuthURLParam("jwt", idToken)) + if err != nil { + // Try to extract more detailed error message + var retrieveErr *oauth2.RetrieveError + if errors.As(err, &retrieveErr) { + return "", fmt.Errorf("token exchange failed with status %d: %s", retrieveErr.Response.StatusCode, string(retrieveErr.Body)) + } + return "", fmt.Errorf("unexpected token exchange request error: %w", err) + } + + return token.AccessToken, nil +} diff --git a/feature/identityfederation/identityfederation_test.go b/feature/identityfederation/identityfederation_test.go new file mode 100644 index 000000000..a673a4298 --- /dev/null +++ b/feature/identityfederation/identityfederation_test.go @@ -0,0 +1,175 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package identityfederation + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestResolveAuthKey(t *testing.T) { + tests := []struct { + name string + clientID string + idToken string + tags []string + wantAuthKey string + wantErr string + }{ + { + name: "success", + clientID: "client-123", + idToken: "token", + tags: []string{"tag:test"}, + wantAuthKey: "tskey-auth-xyz", + wantErr: "", + }, + { + name: "missing client id short-circuits without error", + clientID: "", + idToken: "token", + tags: []string{"tag:test"}, + wantAuthKey: "", + wantErr: "", + }, + { + name: "missing id token", + clientID: "client-123", + idToken: "", + tags: []string{"tag:test"}, + wantErr: "federated identity authkeys require --id-token", + }, + { + name: "missing tags", + clientID: "client-123", + idToken: "token", + tags: []string{}, + wantErr: "federated identity authkeys require --advertise-tags", + }, + { + name: "invalid client id attributes", + clientID: "client-123?invalid=value", + idToken: "token", + tags: []string{"tag:test"}, + wantErr: `failed to parse optional config attributes: unknown optional config attribute "invalid"`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := mockedControlServer(t) + defer srv.Close() + + authKey, err := resolveAuthKey(context.Background(), srv.URL, tt.clientID, tt.idToken, tt.tags) + if tt.wantErr != "" { + if err == nil { + t.Errorf("resolveAuthKey() error = nil, want %q", tt.wantErr) + return + } + if err.Error() != tt.wantErr { + t.Errorf("resolveAuthKey() error = %q, want %q", err.Error(), tt.wantErr) + } + } else if err != nil { + t.Fatalf("resolveAuthKey() unexpected error = %v", err) + } + if authKey != tt.wantAuthKey { + t.Errorf("resolveAuthKey() = %q, want %q", authKey, tt.wantAuthKey) + } + }) + } +} + +func TestParseOptionalAttributes(t *testing.T) { + tests := []struct { + name string + clientID string + wantClientID string + wantEphemeral bool + wantPreauth bool + wantErr string + }{ + { + name: "default values", + clientID: "client-123", + wantClientID: "client-123", + wantEphemeral: true, + wantPreauth: false, + wantErr: "", + }, + { + name: "custom values", + clientID: "client-123?ephemeral=false&preauthorized=true", + wantClientID: "client-123", + wantEphemeral: false, + wantPreauth: true, + wantErr: "", + }, + { + name: "unknown attribute", + clientID: "client-123?unknown=value", + wantClientID: "", + wantEphemeral: false, + wantPreauth: false, + wantErr: `unknown optional config attribute "unknown"`, + }, + { + name: "invalid value", + clientID: "client-123?ephemeral=invalid", + wantClientID: "", + wantEphemeral: false, + wantPreauth: false, + wantErr: `strconv.ParseBool: parsing "invalid": invalid syntax`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + strippedID, ephemeral, preauth, err := parseOptionalAttributes(tt.clientID) + if tt.wantErr != "" { + if err == nil { + t.Errorf("parseOptionalAttributes() error = nil, want %q", tt.wantErr) + return + } + if err.Error() != tt.wantErr { + t.Errorf("parseOptionalAttributes() error = %q, want %q", err.Error(), tt.wantErr) + } + } else { + if err != nil { + t.Errorf("parseOptionalAttributes() error = %v, want nil", err) + return + } + } + if strippedID != tt.wantClientID { + t.Errorf("parseOptionalAttributes() strippedID = %v, want %v", strippedID, tt.wantClientID) + } + if ephemeral != tt.wantEphemeral { + t.Errorf("parseOptionalAttributes() ephemeral = %v, want %v", ephemeral, tt.wantEphemeral) + } + if preauth != tt.wantPreauth { + t.Errorf("parseOptionalAttributes() preauth = %v, want %v", preauth, tt.wantPreauth) + } + }) + } +} + +func mockedControlServer(t *testing.T) *httptest.Server { + t.Helper() + + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/oauth/token-exchange"): + // OAuth2 library sends the token exchange request + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"access_token":"access-123","token_type":"Bearer","expires_in":3600}`)) + case strings.Contains(r.URL.Path, "/api/v2/tailnet") && strings.Contains(r.URL.Path, "/keys"): + // Tailscale client creates the authkey + w.Write([]byte(`{"key":"tskey-auth-xyz","created":"2024-01-01T00:00:00Z"}`)) + default: + w.WriteHeader(http.StatusNotFound) + } + })) +} diff --git a/feature/linkspeed/doc.go b/feature/linkspeed/doc.go new file mode 100644 index 000000000..2d2fcf092 --- /dev/null +++ b/feature/linkspeed/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package linkspeed registers support for setting the TUN link speed on Linux, +// to better integrate with system monitoring tools. +package linkspeed diff --git a/net/tstun/linkattrs_linux.go b/feature/linkspeed/linkspeed_linux.go similarity index 91% rename from net/tstun/linkattrs_linux.go rename to feature/linkspeed/linkspeed_linux.go index 681e79269..90e33d4c9 100644 --- a/net/tstun/linkattrs_linux.go +++ b/feature/linkspeed/linkspeed_linux.go @@ -1,15 +1,22 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package tstun +//go:build linux && !android + +package linkspeed import ( "github.com/mdlayher/genetlink" "github.com/mdlayher/netlink" "github.com/tailscale/wireguard-go/tun" "golang.org/x/sys/unix" + "tailscale.com/net/tstun" ) +func init() { + tstun.HookSetLinkAttrs.Set(setLinkAttrs) +} + // setLinkSpeed sets the advertised link speed of the TUN interface. func setLinkSpeed(iface tun.Device, mbps int) error { name, err := iface.Name() diff --git a/feature/linuxdnsfight/linuxdnsfight.go b/feature/linuxdnsfight/linuxdnsfight.go new file mode 100644 index 000000000..02d99a314 --- /dev/null +++ b/feature/linuxdnsfight/linuxdnsfight.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android + +// Package linuxdnsfight provides Linux support for detecting DNS fights +// (inotify watching of /etc/resolv.conf). +package linuxdnsfight + +import ( + "context" + "fmt" + + "github.com/illarion/gonotify/v3" + "tailscale.com/net/dns" +) + +func init() { + dns.HookWatchFile.Set(watchFile) +} + +// watchFile sets up an inotify watch for a given directory and +// calls the callback function every time a particular file is changed. +// The filename should be located in the provided directory. +func watchFile(ctx context.Context, dir, filename string, cb func()) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + const events = gonotify.IN_ATTRIB | + gonotify.IN_CLOSE_WRITE | + gonotify.IN_CREATE | + gonotify.IN_DELETE | + gonotify.IN_MODIFY | + gonotify.IN_MOVE + + watcher, err := gonotify.NewDirWatcher(ctx, events, dir) + if err != nil { + return fmt.Errorf("NewDirWatcher: %w", err) + } + + for { + select { + case event := <-watcher.C: + if event.Name == filename { + cb() + } + case <-ctx.Done(): + return ctx.Err() + } + } +} diff --git a/feature/linuxdnsfight/linuxdnsfight_test.go b/feature/linuxdnsfight/linuxdnsfight_test.go new file mode 100644 index 000000000..bd3463666 --- /dev/null +++ b/feature/linuxdnsfight/linuxdnsfight_test.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android + +package linuxdnsfight + +import ( + "context" + "errors" + "fmt" + "os" + "sync/atomic" + "testing" + "time" + + "golang.org/x/sync/errgroup" +) + +func TestWatchFile(t *testing.T) { + dir := t.TempDir() + filepath := dir + "/test.txt" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var callbackCalled atomic.Bool + callbackDone := make(chan bool) + callback := func() { + // We only send to the channel once to avoid blocking if the + // callback is called multiple times -- this happens occasionally + // if inotify sends multiple events before we cancel the context. + if !callbackCalled.Load() { + callbackDone <- true + callbackCalled.Store(true) + } + } + + var eg errgroup.Group + eg.Go(func() error { return watchFile(ctx, dir, filepath, callback) }) + + // Keep writing until we get a callback. + func() { + for i := range 10000 { + if err := os.WriteFile(filepath, []byte(fmt.Sprintf("write%d", i)), 0644); err != nil { + t.Fatal(err) + } + select { + case <-callbackDone: + return + case <-time.After(10 * time.Millisecond): + } + } + }() + + cancel() + if err := eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { + t.Error(err) + } + if !callbackCalled.Load() { + t.Error("callback was not called") + } +} diff --git a/feature/oauthkey/oauthkey.go b/feature/oauthkey/oauthkey.go new file mode 100644 index 000000000..5834c33be --- /dev/null +++ b/feature/oauthkey/oauthkey.go @@ -0,0 +1,108 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package oauthkey registers support for using OAuth client secrets to +// automatically request authkeys for logging in. +package oauthkey + +import ( + "context" + "errors" + "fmt" + "net/url" + "strconv" + "strings" + + "golang.org/x/oauth2/clientcredentials" + "tailscale.com/feature" + "tailscale.com/internal/client/tailscale" +) + +func init() { + feature.Register("oauthkey") + tailscale.HookResolveAuthKey.Set(resolveAuthKey) +} + +// resolveAuthKey either returns v unchanged (in the common case) or, if it +// starts with "tskey-client-" (as Tailscale OAuth secrets do) parses it like +// +// tskey-client-xxxx[?ephemeral=false&bar&preauthorized=BOOL&baseURL=...] +// +// and does the OAuth2 dance to get and return an authkey. The "ephemeral" +// property defaults to true if unspecified. The "preauthorized" defaults to +// false. The "baseURL" defaults to https://api.tailscale.com. +// The passed in tags are required, and must be non-empty. These will be +// set on the authkey generated by the OAuth2 dance. +func resolveAuthKey(ctx context.Context, v string, tags []string) (string, error) { + if !strings.HasPrefix(v, "tskey-client-") { + return v, nil + } + if len(tags) == 0 { + return "", errors.New("oauth authkeys require --advertise-tags") + } + + clientSecret, named, _ := strings.Cut(v, "?") + attrs, err := url.ParseQuery(named) + if err != nil { + return "", err + } + for k := range attrs { + switch k { + case "ephemeral", "preauthorized", "baseURL": + default: + return "", fmt.Errorf("unknown attribute %q", k) + } + } + getBool := func(name string, def bool) (bool, error) { + v := attrs.Get(name) + if v == "" { + return def, nil + } + ret, err := strconv.ParseBool(v) + if err != nil { + return false, fmt.Errorf("invalid attribute boolean attribute %s value %q", name, v) + } + return ret, nil + } + ephemeral, err := getBool("ephemeral", true) + if err != nil { + return "", err + } + preauth, err := getBool("preauthorized", false) + if err != nil { + return "", err + } + + baseURL := "https://api.tailscale.com" + if v := attrs.Get("baseURL"); v != "" { + baseURL = v + } + + credentials := clientcredentials.Config{ + ClientID: "some-client-id", // ignored + ClientSecret: clientSecret, + TokenURL: baseURL + "/api/v2/oauth/token", + } + + tsClient := tailscale.NewClient("-", nil) + tsClient.UserAgent = "tailscale-cli" + tsClient.HTTPClient = credentials.Client(ctx) + tsClient.BaseURL = baseURL + + caps := tailscale.KeyCapabilities{ + Devices: tailscale.KeyDeviceCapabilities{ + Create: tailscale.KeyDeviceCreateCapabilities{ + Reusable: false, + Ephemeral: ephemeral, + Preauthorized: preauth, + Tags: tags, + }, + }, + } + + authkey, _, err := tsClient.CreateKey(ctx, caps) + if err != nil { + return "", err + } + return authkey, nil +} diff --git a/feature/portlist/portlist.go b/feature/portlist/portlist.go new file mode 100644 index 000000000..7d69796ff --- /dev/null +++ b/feature/portlist/portlist.go @@ -0,0 +1,157 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package portlist contains code to poll the local system for open ports +// and report them to the control plane, if enabled on the tailnet. +package portlist + +import ( + "context" + "sync/atomic" + + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/policy" + "tailscale.com/portlist" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + "tailscale.com/version" +) + +func init() { + ipnext.RegisterExtension("portlist", newExtension) +} + +func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { + busClient := sb.Sys().Bus.Get().Client("portlist") + e := &Extension{ + sb: sb, + busClient: busClient, + logf: logger.WithPrefix(logf, "portlist: "), + pub: eventbus.Publish[ipnlocal.PortlistServices](busClient), + pollerDone: make(chan struct{}), + wakePoller: make(chan struct{}), + } + e.ctx, e.ctxCancel = context.WithCancel(context.Background()) + return e, nil +} + +// Extension implements the portlist extension. +type Extension struct { + ctx context.Context + ctxCancel context.CancelFunc + pollerDone chan struct{} // close-only chan when poller goroutine exits + wakePoller chan struct{} // best effort chan to wake poller from sleep + busClient *eventbus.Client + pub *eventbus.Publisher[ipnlocal.PortlistServices] + logf logger.Logf + sb ipnext.SafeBackend + host ipnext.Host // from Init + + shieldsUp atomic.Bool + shouldUploadServicesAtomic atomic.Bool +} + +func (e *Extension) Name() string { return "portlist" } +func (e *Extension) Shutdown() error { + e.ctxCancel() + e.busClient.Close() + <-e.pollerDone + return nil +} + +func (e *Extension) Init(h ipnext.Host) error { + if !envknob.BoolDefaultTrue("TS_PORTLIST") { + return ipnext.SkipExtension + } + + e.host = h + h.Hooks().ShouldUploadServices.Set(e.shouldUploadServicesAtomic.Load) + h.Hooks().ProfileStateChange.Add(e.onChangeProfile) + h.Hooks().OnSelfChange.Add(e.onSelfChange) + + // TODO(nickkhyl): remove this after the profileManager refactoring. + // See tailscale/tailscale#15974. + // This same workaround appears in feature/taildrop/ext.go. + profile, prefs := h.Profiles().CurrentProfileState() + e.onChangeProfile(profile, prefs, false) + + go e.runPollLoop() + return nil +} + +func (e *Extension) onSelfChange(tailcfg.NodeView) { + e.updateShouldUploadServices() +} + +func (e *Extension) onChangeProfile(_ ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + e.shieldsUp.Store(prefs.ShieldsUp()) + e.updateShouldUploadServices() +} + +func (e *Extension) updateShouldUploadServices() { + v := !e.shieldsUp.Load() && e.host.NodeBackend().CollectServices() + if e.shouldUploadServicesAtomic.CompareAndSwap(!v, v) && v { + // Upon transition from false to true (enabling service reporting), try + // to wake the poller to do an immediate poll if it's sleeping. + // It's not a big deal if we miss waking it. It'll get to it soon enough. + select { + case e.wakePoller <- struct{}{}: + default: + } + } +} + +// runPollLoop is a goroutine that periodically checks the open +// ports and publishes them if they've changed. +func (e *Extension) runPollLoop() { + defer close(e.pollerDone) + + var poller portlist.Poller + + ticker, tickerChannel := e.sb.Clock().NewTicker(portlist.PollInterval()) + defer ticker.Stop() + for { + select { + case <-tickerChannel: + case <-e.wakePoller: + case <-e.ctx.Done(): + return + } + + if !e.shouldUploadServicesAtomic.Load() { + continue + } + + ports, changed, err := poller.Poll() + if err != nil { + e.logf("Poll: %v", err) + // TODO: this is kinda weird that we just return here and never try + // again. Maybe that was because all errors are assumed to be + // permission errors and thus permanent? Audit varioys OS + // implementation and check error types, and then make this check + // for permanent vs temporary errors and keep looping with a backoff + // for temporary errors? But for now we just give up, like we always + // have. + return + } + if !changed { + continue + } + sl := []tailcfg.Service{} + for _, p := range ports { + s := tailcfg.Service{ + Proto: tailcfg.ServiceProto(p.Proto), + Port: p.Port, + Description: p.Process, + } + if policy.IsInterestingService(s, version.OS()) { + sl = append(sl, s) + } + } + e.pub.Publish(ipnlocal.PortlistServices(sl)) + } +} diff --git a/feature/portmapper/portmapper.go b/feature/portmapper/portmapper.go new file mode 100644 index 000000000..d1b903cb6 --- /dev/null +++ b/feature/portmapper/portmapper.go @@ -0,0 +1,40 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package portmapper registers support for NAT-PMP, PCP, and UPnP port +// mapping protocols to help get direction connections through NATs. +package portmapper + +import ( + "tailscale.com/feature" + "tailscale.com/net/netmon" + "tailscale.com/net/portmapper" + "tailscale.com/net/portmapper/portmappertype" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" +) + +func init() { + feature.Register("portmapper") + portmappertype.HookNewPortMapper.Set(newPortMapper) +} + +func newPortMapper( + logf logger.Logf, + bus *eventbus.Bus, + netMon *netmon.Monitor, + disableUPnPOrNil func() bool, + onlyTCP443OrNil func() bool) portmappertype.Client { + + pm := portmapper.NewClient(portmapper.Config{ + EventBus: bus, + Logf: logf, + NetMon: netMon, + DebugKnobs: &portmapper.DebugKnobs{ + DisableAll: onlyTCP443OrNil, + DisableUPnPFunc: disableUPnPOrNil, + }, + }) + pm.SetGatewayLookupFunc(netMon.GatewayAndSelfIP) + return pm +} diff --git a/feature/posture/posture.go b/feature/posture/posture.go new file mode 100644 index 000000000..8e1945d7d --- /dev/null +++ b/feature/posture/posture.go @@ -0,0 +1,114 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package posture registers support for device posture checking, +// reporting machine-specific information to the control plane +// when enabled by the user and tailnet. +package posture + +import ( + "encoding/json" + "net/http" + + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/posture" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" +) + +func init() { + ipnext.RegisterExtension("posture", newExtension) + ipnlocal.RegisterC2N("GET /posture/identity", handleC2NPostureIdentityGet) +} + +func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) { + e := &extension{ + logf: logger.WithPrefix(logf, "posture: "), + } + return e, nil +} + +type extension struct { + logf logger.Logf + + // lastKnownHardwareAddrs is a list of the previous known hardware addrs. + // Previously known hwaddrs are kept to work around an issue on Windows + // where all addresses might disappear. + // http://go/corp/25168 + lastKnownHardwareAddrs syncs.AtomicValue[[]string] +} + +func (e *extension) Name() string { return "posture" } +func (e *extension) Init(h ipnext.Host) error { return nil } +func (e *extension) Shutdown() error { return nil } + +func handleC2NPostureIdentityGet(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + e, ok := ipnlocal.GetExt[*extension](b) + if !ok { + http.Error(w, "posture extension not available", http.StatusInternalServerError) + return + } + e.logf("c2n: GET /posture/identity received") + + res := tailcfg.C2NPostureIdentityResponse{} + + // Only collect posture identity if enabled on the client, + // this will first check syspolicy, MDM settings like Registry + // on Windows or defaults on macOS. If they are not set, it falls + // back to the cli-flag, `--posture-checking`. + choice, err := b.PolicyClient().GetPreferenceOption(pkey.PostureChecking, ptype.ShowChoiceByPolicy) + if err != nil { + e.logf( + "c2n: failed to read PostureChecking from syspolicy, returning default from CLI: %s; got error: %s", + b.Prefs().PostureChecking(), + err, + ) + } + + if choice.ShouldEnable(b.Prefs().PostureChecking()) { + res.SerialNumbers, err = posture.GetSerialNumbers(b.PolicyClient(), e.logf) + if err != nil { + e.logf("c2n: GetSerialNumbers returned error: %v", err) + } + + // TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release + // and looks good in client metrics, remove this parameter and always report MAC + // addresses. + if r.FormValue("hwaddrs") == "true" { + res.IfaceHardwareAddrs, err = e.getHardwareAddrs() + if err != nil { + e.logf("c2n: GetHardwareAddrs returned error: %v", err) + } + } + } else { + res.PostureDisabled = true + } + + e.logf("c2n: posture identity disabled=%v reported %d serials %d hwaddrs", res.PostureDisabled, len(res.SerialNumbers), len(res.IfaceHardwareAddrs)) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} + +// getHardwareAddrs returns the hardware addresses for the machine. If the list +// of hardware addresses is empty, it will return the previously known hardware +// addresses. Both the current, and previously known hardware addresses might be +// empty. +func (e *extension) getHardwareAddrs() ([]string, error) { + addrs, err := posture.GetHardwareAddrs() + if err != nil { + return nil, err + } + + if len(addrs) == 0 { + e.logf("getHardwareAddrs: got empty list of hwaddrs, returning previous list") + return e.lastKnownHardwareAddrs.Load(), nil + } + + e.lastKnownHardwareAddrs.Store(addrs) + return addrs, nil +} diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go new file mode 100644 index 000000000..7d12d62e5 --- /dev/null +++ b/feature/relayserver/relayserver.go @@ -0,0 +1,252 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package relayserver registers the relay server feature and implements its +// associated ipnext.Extension. +package relayserver + +import ( + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/disco" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/localapi" + "tailscale.com/net/udprelay" + "tailscale.com/net/udprelay/endpoint" + "tailscale.com/net/udprelay/status" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/ptr" + "tailscale.com/util/eventbus" + "tailscale.com/wgengine/magicsock" +) + +// featureName is the name of the feature implemented by this package. +// It is also the [extension] name and the log prefix. +const featureName = "relayserver" + +func init() { + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newExtension) + localapi.Register("debug-peer-relay-sessions", servePeerRelayDebugSessions) +} + +// servePeerRelayDebugSessions is an HTTP handler for the Local API that +// returns debug/status information for peer relay sessions being relayed by +// this Tailscale node. It writes a JSON-encoded [status.ServerStatus] into the +// HTTP response, or returns an HTTP 405/500 with error text as the body. +func servePeerRelayDebugSessions(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + http.Error(w, "GET required", http.StatusMethodNotAllowed) + return + } + + var e *extension + if ok := h.LocalBackend().FindMatchingExtension(&e); !ok { + http.Error(w, "peer relay server extension unavailable", http.StatusInternalServerError) + return + } + + st := e.serverStatus() + j, err := json.Marshal(st) + if err != nil { + http.Error(w, fmt.Sprintf("failed to marshal json: %v", err), http.StatusInternalServerError) + return + } + w.Write(j) +} + +// newExtension is an [ipnext.NewExtensionFn] that creates a new relay server +// extension. It is registered with [ipnext.RegisterExtension] if the package is +// imported. +func newExtension(logf logger.Logf, sb ipnext.SafeBackend) (ipnext.Extension, error) { + e := &extension{ + newServerFn: func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) { + return udprelay.NewServer(logf, port, onlyStaticAddrPorts) + }, + logf: logger.WithPrefix(logf, featureName+": "), + } + e.ec = sb.Sys().Bus.Get().Client("relayserver.extension") + e.respPub = eventbus.Publish[magicsock.UDPRelayAllocResp](e.ec) + eventbus.SubscribeFunc(e.ec, e.onDERPMapView) + eventbus.SubscribeFunc(e.ec, e.onAllocReq) + return e, nil +} + +// relayServer is an interface for [udprelay.Server]. +type relayServer interface { + Close() error + AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.ServerEndpoint, error) + GetSessions() []status.ServerSession + SetDERPMapView(tailcfg.DERPMapView) +} + +// extension is an [ipnext.Extension] managing the relay server on platforms +// that import this package. +type extension struct { + newServerFn func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) // swappable for tests + logf logger.Logf + ec *eventbus.Client + respPub *eventbus.Publisher[magicsock.UDPRelayAllocResp] + + mu syncs.Mutex // guards the following fields + shutdown bool // true if Shutdown() has been called + rs relayServer // nil when disabled + port *int // ipn.Prefs.RelayServerPort, nil if disabled + derpMapView tailcfg.DERPMapView // latest seen over the eventbus + hasNodeAttrDisableRelayServer bool // [tailcfg.NodeAttrDisableRelayServer] +} + +// Name implements [ipnext.Extension]. +func (e *extension) Name() string { + return featureName +} + +// Init implements [ipnext.Extension] by registering callbacks and providers +// for the duration of the extension's lifetime. +func (e *extension) Init(host ipnext.Host) error { + profile, prefs := host.Profiles().CurrentProfileState() + e.profileStateChanged(profile, prefs, false) + host.Hooks().ProfileStateChange.Add(e.profileStateChanged) + host.Hooks().OnSelfChange.Add(e.selfNodeViewChanged) + return nil +} + +func (e *extension) onDERPMapView(view tailcfg.DERPMapView) { + e.mu.Lock() + defer e.mu.Unlock() + e.derpMapView = view + if e.rs != nil { + e.rs.SetDERPMapView(view) + } +} + +func (e *extension) onAllocReq(req magicsock.UDPRelayAllocReq) { + e.mu.Lock() + defer e.mu.Unlock() + if e.shutdown { + return + } + if e.rs == nil { + if !e.relayServerShouldBeRunningLocked() { + return + } + e.tryStartRelayServerLocked() + if e.rs == nil { + return + } + } + se, err := e.rs.AllocateEndpoint(req.Message.ClientDisco[0], req.Message.ClientDisco[1]) + if err != nil { + e.logf("error allocating endpoint: %v", err) + return + } + e.respPub.Publish(magicsock.UDPRelayAllocResp{ + ReqRxFromNodeKey: req.RxFromNodeKey, + ReqRxFromDiscoKey: req.RxFromDiscoKey, + Message: &disco.AllocateUDPRelayEndpointResponse{ + Generation: req.Message.Generation, + UDPRelayEndpoint: disco.UDPRelayEndpoint{ + ServerDisco: se.ServerDisco, + ClientDisco: se.ClientDisco, + LamportID: se.LamportID, + VNI: se.VNI, + BindLifetime: se.BindLifetime.Duration, + SteadyStateLifetime: se.SteadyStateLifetime.Duration, + AddrPorts: se.AddrPorts, + }, + }, + }) +} + +func (e *extension) tryStartRelayServerLocked() { + rs, err := e.newServerFn(e.logf, *e.port, false) + if err != nil { + e.logf("error initializing server: %v", err) + return + } + e.rs = rs + e.rs.SetDERPMapView(e.derpMapView) +} + +func (e *extension) relayServerShouldBeRunningLocked() bool { + return !e.shutdown && e.port != nil && !e.hasNodeAttrDisableRelayServer +} + +// handleRelayServerLifetimeLocked handles the lifetime of [e.rs]. +func (e *extension) handleRelayServerLifetimeLocked() { + if !e.relayServerShouldBeRunningLocked() { + e.stopRelayServerLocked() + return + } else if e.rs != nil { + return // already running + } + e.tryStartRelayServerLocked() +} + +func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { + e.mu.Lock() + defer e.mu.Unlock() + e.hasNodeAttrDisableRelayServer = nodeView.HasCap(tailcfg.NodeAttrDisableRelayServer) + e.handleRelayServerLifetimeLocked() +} + +func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + newPort, ok := prefs.RelayServerPort().GetOk() + enableOrDisableServer := ok != (e.port != nil) + portChanged := ok && e.port != nil && newPort != *e.port + if enableOrDisableServer || portChanged || !sameNode { + e.stopRelayServerLocked() + e.port = nil + if ok { + e.port = ptr.To(newPort) + } + } + e.handleRelayServerLifetimeLocked() +} + +func (e *extension) stopRelayServerLocked() { + if e.rs != nil { + e.rs.Close() + } + e.rs = nil +} + +// Shutdown implements [ipnlocal.Extension]. +func (e *extension) Shutdown() error { + // [extension.mu] must not be held when closing the [eventbus.Client]. Close + // blocks until all [eventbus.SubscribeFunc]'s have returned, and the ones + // used in this package also acquire [extension.mu]. See #17894. + e.ec.Close() + e.mu.Lock() + defer e.mu.Unlock() + e.shutdown = true + e.stopRelayServerLocked() + return nil +} + +// serverStatus gathers and returns current peer relay server status information +// for this Tailscale node, and status of each peer relay session this node is +// relaying (if any). +func (e *extension) serverStatus() status.ServerStatus { + e.mu.Lock() + defer e.mu.Unlock() + st := status.ServerStatus{ + UDPPort: nil, + Sessions: nil, + } + if e.rs == nil { + return st + } + st.UDPPort = ptr.To(*e.port) + st.Sessions = e.rs.GetSessions() + return st +} diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go new file mode 100644 index 000000000..3d71c55d7 --- /dev/null +++ b/feature/relayserver/relayserver_test.go @@ -0,0 +1,308 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package relayserver + +import ( + "errors" + "reflect" + "testing" + + "tailscale.com/ipn" + "tailscale.com/net/udprelay/endpoint" + "tailscale.com/net/udprelay/status" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/ptr" +) + +func Test_extension_profileStateChanged(t *testing.T) { + prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(1)} + prefsWithNilPort := ipn.Prefs{RelayServerPort: nil} + + type fields struct { + port *int + rs relayServer + } + type args struct { + prefs ipn.PrefsView + sameNode bool + } + tests := []struct { + name string + fields fields + args args + wantPort *int + wantRelayServerFieldNonNil bool + wantRelayServerFieldMutated bool + }{ + { + name: "no changes non-nil port previously running", + fields: fields{ + port: ptr.To(1), + rs: mockRelayServerNotZeroVal(), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: false, + }, + { + name: "prefs port nil", + fields: fields{ + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithNilPort.View(), + sameNode: true, + }, + wantPort: nil, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: false, + }, + { + name: "prefs port nil previously running", + fields: fields{ + port: ptr.To(1), + rs: mockRelayServerNotZeroVal(), + }, + args: args{ + prefs: prefsWithNilPort.View(), + sameNode: true, + }, + wantPort: nil, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: true, + }, + { + name: "prefs port changed", + fields: fields{ + port: ptr.To(2), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + }, + { + name: "prefs port changed previously running", + fields: fields{ + port: ptr.To(2), + rs: mockRelayServerNotZeroVal(), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + }, + { + name: "sameNode false", + fields: fields{ + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + }, + { + name: "sameNode false previously running", + fields: fields{ + port: ptr.To(1), + rs: mockRelayServerNotZeroVal(), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + }, + { + name: "prefs port non-nil extension port nil", + fields: fields{ + port: nil, + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sys := tsd.NewSystem() + ipne, err := newExtension(logger.Discard, mockSafeBackend{sys}) + if err != nil { + t.Fatal(err) + } + e := ipne.(*extension) + e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) { + return &mockRelayServer{}, nil + } + e.port = tt.fields.port + e.rs = tt.fields.rs + defer e.Shutdown() + e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) + if tt.wantRelayServerFieldNonNil != (e.rs != nil) { + t.Errorf("wantRelayServerFieldNonNil: %v != (e.rs != nil): %v", tt.wantRelayServerFieldNonNil, e.rs != nil) + } + if (tt.wantPort == nil) != (e.port == nil) { + t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil) + } else if tt.wantPort != nil && *tt.wantPort != *e.port { + t.Errorf("wantPort: %d != *e.port: %d", *tt.wantPort, *e.port) + } + if tt.wantRelayServerFieldMutated != !reflect.DeepEqual(tt.fields.rs, e.rs) { + t.Errorf("wantRelayServerFieldMutated: %v != !reflect.DeepEqual(tt.fields.rs, e.rs): %v", tt.wantRelayServerFieldMutated, !reflect.DeepEqual(tt.fields.rs, e.rs)) + } + }) + } +} + +func mockRelayServerNotZeroVal() *mockRelayServer { + return &mockRelayServer{true} +} + +type mockRelayServer struct { + set bool +} + +func (mockRelayServer) Close() error { return nil } +func (mockRelayServer) AllocateEndpoint(_, _ key.DiscoPublic) (endpoint.ServerEndpoint, error) { + return endpoint.ServerEndpoint{}, errors.New("not implemented") +} +func (mockRelayServer) GetSessions() []status.ServerSession { return nil } +func (mockRelayServer) SetDERPMapView(tailcfg.DERPMapView) { return } + +type mockSafeBackend struct { + sys *tsd.System +} + +func (m mockSafeBackend) Sys() *tsd.System { return m.sys } +func (mockSafeBackend) Clock() tstime.Clock { return nil } +func (mockSafeBackend) TailscaleVarRoot() string { return "" } + +func Test_extension_handleRelayServerLifetimeLocked(t *testing.T) { + tests := []struct { + name string + shutdown bool + port *int + rs relayServer + hasNodeAttrDisableRelayServer bool + wantRelayServerFieldNonNil bool + wantRelayServerFieldMutated bool + }{ + { + name: "want running", + shutdown: false, + port: ptr.To(1), + hasNodeAttrDisableRelayServer: false, + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: true, + }, + { + name: "want running previously running", + shutdown: false, + port: ptr.To(1), + rs: mockRelayServerNotZeroVal(), + hasNodeAttrDisableRelayServer: false, + wantRelayServerFieldNonNil: true, + wantRelayServerFieldMutated: false, + }, + { + name: "shutdown true", + shutdown: true, + port: ptr.To(1), + hasNodeAttrDisableRelayServer: false, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: false, + }, + { + name: "shutdown true previously running", + shutdown: true, + port: ptr.To(1), + rs: mockRelayServerNotZeroVal(), + hasNodeAttrDisableRelayServer: false, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: true, + }, + { + name: "port nil", + shutdown: false, + port: nil, + hasNodeAttrDisableRelayServer: false, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: false, + }, + { + name: "port nil previously running", + shutdown: false, + port: nil, + rs: mockRelayServerNotZeroVal(), + hasNodeAttrDisableRelayServer: false, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: true, + }, + { + name: "hasNodeAttrDisableRelayServer true", + shutdown: false, + port: nil, + hasNodeAttrDisableRelayServer: true, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: false, + }, + { + name: "hasNodeAttrDisableRelayServer true previously running", + shutdown: false, + port: nil, + rs: mockRelayServerNotZeroVal(), + hasNodeAttrDisableRelayServer: true, + wantRelayServerFieldNonNil: false, + wantRelayServerFieldMutated: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sys := tsd.NewSystem() + ipne, err := newExtension(logger.Discard, mockSafeBackend{sys}) + if err != nil { + t.Fatal(err) + } + e := ipne.(*extension) + e.newServerFn = func(logf logger.Logf, port int, onlyStaticAddrPorts bool) (relayServer, error) { + return &mockRelayServer{}, nil + } + e.shutdown = tt.shutdown + e.port = tt.port + e.rs = tt.rs + e.hasNodeAttrDisableRelayServer = tt.hasNodeAttrDisableRelayServer + e.handleRelayServerLifetimeLocked() + defer e.Shutdown() + if tt.wantRelayServerFieldNonNil != (e.rs != nil) { + t.Errorf("wantRelayServerFieldNonNil: %v != (e.rs != nil): %v", tt.wantRelayServerFieldNonNil, e.rs != nil) + } + if tt.wantRelayServerFieldMutated != !reflect.DeepEqual(tt.rs, e.rs) { + t.Errorf("wantRelayServerFieldMutated: %v != !reflect.DeepEqual(tt.rs, e.rs): %v", tt.wantRelayServerFieldMutated, !reflect.DeepEqual(tt.rs, e.rs)) + } + }) + } +} diff --git a/feature/sdnotify.go b/feature/sdnotify.go new file mode 100644 index 000000000..7a786dfab --- /dev/null +++ b/feature/sdnotify.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package feature + +import ( + "runtime" + + "tailscale.com/feature/buildfeatures" +) + +// HookSystemdReady sends a readiness to systemd. This will unblock service +// dependents from starting. +var HookSystemdReady Hook[func()] + +// HookSystemdStatus holds a func that will send a single line status update to +// systemd so that information shows up in systemctl output. +var HookSystemdStatus Hook[func(format string, args ...any)] + +// SystemdStatus sends a single line status update to systemd so that +// information shows up in systemctl output. +// +// It does nothing on non-Linux systems or if the binary was built without +// the sdnotify feature. +func SystemdStatus(format string, args ...any) { + if !CanSystemdStatus { // mid-stack inlining DCE + return + } + if f, ok := HookSystemdStatus.GetOk(); ok { + f(format, args...) + } +} + +// CanSystemdStatus reports whether the current build has systemd notifications +// linked in. +// +// It's effectively the same as HookSystemdStatus.IsSet(), but a constant for +// dead code elimination reasons. +const CanSystemdStatus = runtime.GOOS == "linux" && buildfeatures.HasSDNotify diff --git a/util/systemd/doc.go b/feature/sdnotify/sdnotify.go similarity index 81% rename from util/systemd/doc.go rename to feature/sdnotify/sdnotify.go index 0c28e1823..d13aa63f2 100644 --- a/util/systemd/doc.go +++ b/feature/sdnotify/sdnotify.go @@ -2,7 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause /* -Package systemd contains a minimal wrapper around systemd-notify to enable +Package sdnotify contains a minimal wrapper around systemd-notify to enable applications to signal readiness and status to systemd. This package will only have effect on Linux systems running Tailscale in a @@ -10,4 +10,4 @@ systemd unit with the Type=notify flag set. On other operating systems (or when running in a Linux distro without being run from inside systemd) this package will become a no-op. */ -package systemd +package sdnotify diff --git a/util/systemd/systemd_linux.go b/feature/sdnotify/sdnotify_linux.go similarity index 79% rename from util/systemd/systemd_linux.go rename to feature/sdnotify/sdnotify_linux.go index 909cfcb20..2b13e24bb 100644 --- a/util/systemd/systemd_linux.go +++ b/feature/sdnotify/sdnotify_linux.go @@ -1,9 +1,9 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android -package systemd +package sdnotify import ( "errors" @@ -12,8 +12,14 @@ import ( "sync" "github.com/mdlayher/sdnotify" + "tailscale.com/feature" ) +func init() { + feature.HookSystemdReady.Set(ready) + feature.HookSystemdStatus.Set(status) +} + var getNotifyOnce struct { sync.Once v *sdnotify.Notifier @@ -23,8 +29,8 @@ type logOnce struct { sync.Once } -func (l *logOnce) logf(format string, args ...any) { - l.Once.Do(func() { +func (lg *logOnce) logf(format string, args ...any) { + lg.Once.Do(func() { log.Printf(format, args...) }) } @@ -46,15 +52,15 @@ func notifier() *sdnotify.Notifier { return getNotifyOnce.v } -// Ready signals readiness to systemd. This will unblock service dependents from starting. -func Ready() { +// ready signals readiness to systemd. This will unblock service dependents from starting. +func ready() { err := notifier().Notify(sdnotify.Ready) if err != nil { readyOnce.logf("systemd: error notifying: %v", err) } } -// Status sends a single line status update to systemd so that information shows up +// status sends a single line status update to systemd so that information shows up // in systemctl output. For example: // // $ systemctl status tailscale @@ -69,7 +75,7 @@ func Ready() { // CPU: 2min 38.469s // CGroup: /system.slice/tailscale.service // └─26741 /nix/store/sv6cj4mw2jajm9xkbwj07k29dj30lh0n-tailscale-date.20200727/bin/tailscaled --port 41641 -func Status(format string, args ...any) { +func status(format string, args ...any) { err := notifier().Notify(sdnotify.Statusf(format, args...)) if err != nil { statusOnce.logf("systemd: error notifying: %v", err) diff --git a/feature/syspolicy/syspolicy.go b/feature/syspolicy/syspolicy.go new file mode 100644 index 000000000..08c3cf373 --- /dev/null +++ b/feature/syspolicy/syspolicy.go @@ -0,0 +1,7 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package syspolicy provides an interface for system-wide policy management. +package syspolicy + +import _ "tailscale.com/util/syspolicy" // for its registration side effects diff --git a/taildrop/delete.go b/feature/taildrop/delete.go similarity index 82% rename from taildrop/delete.go rename to feature/taildrop/delete.go index aaef34df1..8b03a125f 100644 --- a/taildrop/delete.go +++ b/feature/taildrop/delete.go @@ -6,15 +6,12 @@ package taildrop import ( "container/list" "context" - "io/fs" "os" - "path/filepath" "strings" "sync" "time" "tailscale.com/ipn" - "tailscale.com/syncs" "tailscale.com/tstime" "tailscale.com/types/logger" ) @@ -28,7 +25,6 @@ const deleteDelay = time.Hour type fileDeleter struct { logf logger.Logf clock tstime.DefaultClock - dir string event func(string) // called for certain events; for testing only mu sync.Mutex @@ -36,9 +32,10 @@ type fileDeleter struct { byName map[string]*list.Element emptySignal chan struct{} // signal that the queue is empty - group syncs.WaitGroup + group sync.WaitGroup shutdownCtx context.Context shutdown context.CancelFunc + fs FileOps // must be used for all filesystem operations } // deleteFile is a specific file to delete after deleteDelay. @@ -47,18 +44,17 @@ type deleteFile struct { inserted time.Time } -func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { +func (d *fileDeleter) Init(m *manager, eventHook func(string)) { d.logf = m.opts.Logf d.clock = m.opts.Clock - d.dir = m.opts.Dir d.event = eventHook + d.fs = m.opts.fileOps d.byName = make(map[string]*list.Element) d.emptySignal = make(chan struct{}) d.shutdownCtx, d.shutdown = context.WithCancel(context.Background()) // From a cold-start, load the list of partial and deleted files. - // // Only run this if we have ever received at least one file // to avoid ever touching the taildrop directory on systems (e.g., MacOS) // that pop up a security dialog window upon first access. @@ -71,38 +67,45 @@ func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { d.group.Go(func() { d.event("start full-scan") defer d.event("end full-scan") - rangeDir(d.dir, func(de fs.DirEntry) bool { + + if d.fs == nil { + d.logf("deleter: nil FileOps") + } + + files, err := d.fs.ListFiles() + if err != nil { + d.logf("deleter: ListDir error: %v", err) + return + } + for _, filename := range files { switch { case d.shutdownCtx.Err() != nil: - return false // terminate early - case !de.Type().IsRegular(): - return true - case strings.HasSuffix(de.Name(), partialSuffix): + return // terminate early + case strings.HasSuffix(filename, partialSuffix): // Only enqueue the file for deletion if there is no active put. - nameID := strings.TrimSuffix(de.Name(), partialSuffix) + nameID := strings.TrimSuffix(filename, partialSuffix) if i := strings.LastIndexByte(nameID, '.'); i > 0 { - key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} + key := incomingFileKey{clientID(nameID[i+len("."):]), nameID[:i]} m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { if !loaded { - d.Insert(de.Name()) + d.Insert(filename) } }) } else { - d.Insert(de.Name()) + d.Insert(filename) } - case strings.HasSuffix(de.Name(), deletedSuffix): + case strings.HasSuffix(filename, deletedSuffix): // Best-effort immediate deletion of deleted files. - name := strings.TrimSuffix(de.Name(), deletedSuffix) - if os.Remove(filepath.Join(d.dir, name)) == nil { - if os.Remove(filepath.Join(d.dir, de.Name())) == nil { - break + name := strings.TrimSuffix(filename, deletedSuffix) + if d.fs.Remove(name) == nil { + if d.fs.Remove(filename) == nil { + continue } } - // Otherwise, enqueue the file for later deletion. - d.Insert(de.Name()) + // Otherwise enqueue for later deletion. + d.Insert(filename) } - return true - }) + } }) } @@ -149,13 +152,13 @@ func (d *fileDeleter) waitAndDelete(wait time.Duration) { // Delete the expired file. if name, ok := strings.CutSuffix(file.name, deletedSuffix); ok { - if err := os.Remove(filepath.Join(d.dir, name)); err != nil && !os.IsNotExist(err) { + if err := d.fs.Remove(name); err != nil && !os.IsNotExist(err) { d.logf("could not delete: %v", redactError(err)) failed = append(failed, elem) continue } } - if err := os.Remove(filepath.Join(d.dir, file.name)); err != nil && !os.IsNotExist(err) { + if err := d.fs.Remove(file.name); err != nil && !os.IsNotExist(err) { d.logf("could not delete: %v", redactError(err)) failed = append(failed, elem) continue diff --git a/taildrop/delete_test.go b/feature/taildrop/delete_test.go similarity index 83% rename from taildrop/delete_test.go rename to feature/taildrop/delete_test.go index 5fa4b9c37..36950f582 100644 --- a/taildrop/delete_test.go +++ b/feature/taildrop/delete_test.go @@ -5,7 +5,6 @@ package taildrop import ( "os" - "path/filepath" "slices" "testing" "time" @@ -20,11 +19,20 @@ import ( func TestDeleter(t *testing.T) { dir := t.TempDir() - must.Do(touchFile(filepath.Join(dir, "foo.partial"))) - must.Do(touchFile(filepath.Join(dir, "bar.partial"))) - must.Do(touchFile(filepath.Join(dir, "fizz"))) - must.Do(touchFile(filepath.Join(dir, "fizz.deleted"))) - must.Do(touchFile(filepath.Join(dir, "buzz.deleted"))) // lacks a matching "buzz" file + var m manager + var fd fileDeleter + m.opts.Logf = t.Logf + m.opts.Clock = tstime.DefaultClock{Clock: tstest.NewClock(tstest.ClockOpts{ + Start: time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + })} + m.opts.State = must.Get(mem.New(nil, "")) + m.opts.fileOps, _ = newFileOps(dir) + + must.Do(m.touchFile("foo.partial")) + must.Do(m.touchFile("bar.partial")) + must.Do(m.touchFile("fizz")) + must.Do(m.touchFile("fizz.deleted")) + must.Do(m.touchFile("buzz.deleted")) // lacks a matching "buzz" file checkDirectory := func(want ...string) { t.Helper() @@ -69,12 +77,10 @@ func TestDeleter(t *testing.T) { } eventHook := func(event string) { eventsChan <- event } - var m Manager - var fd fileDeleter m.opts.Logf = t.Logf m.opts.Clock = tstime.DefaultClock{Clock: clock} - m.opts.Dir = dir m.opts.State = must.Get(mem.New(nil, "")) + m.opts.fileOps, _ = newFileOps(dir) must.Do(m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1})) fd.Init(&m, eventHook) defer fd.Shutdown() @@ -100,17 +106,17 @@ func TestDeleter(t *testing.T) { checkEvents("end waitAndDelete") checkDirectory() - must.Do(touchFile(filepath.Join(dir, "one.partial"))) + must.Do(m.touchFile("one.partial")) insert("one.partial") checkEvents("start waitAndDelete") advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "two.partial"))) + must.Do(m.touchFile("two.partial")) insert("two.partial") advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "three.partial"))) + must.Do(m.touchFile("three.partial")) insert("three.partial") advance(deleteDelay / 4) - must.Do(touchFile(filepath.Join(dir, "four.partial"))) + must.Do(m.touchFile("four.partial")) insert("four.partial") advance(deleteDelay / 4) @@ -142,11 +148,11 @@ func TestDeleter(t *testing.T) { // Test that the asynchronous full scan of the taildrop directory does not occur // on a cold start if taildrop has never received any files. func TestDeleterInitWithoutTaildrop(t *testing.T) { - var m Manager + var m manager var fd fileDeleter m.opts.Logf = t.Logf - m.opts.Dir = t.TempDir() m.opts.State = must.Get(mem.New(nil, "")) + m.opts.fileOps, _ = newFileOps(t.TempDir()) fd.Init(&m, func(event string) { t.Errorf("unexpected event: %v", event) }) fd.Shutdown() } diff --git a/feature/taildrop/doc.go b/feature/taildrop/doc.go new file mode 100644 index 000000000..8980a2170 --- /dev/null +++ b/feature/taildrop/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package taildrop registers the taildrop (file sending) feature. +package taildrop diff --git a/feature/taildrop/ext.go b/feature/taildrop/ext.go new file mode 100644 index 000000000..6bdb375cc --- /dev/null +++ b/feature/taildrop/ext.go @@ -0,0 +1,436 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "cmp" + "context" + "errors" + "fmt" + "io" + "maps" + "path/filepath" + "runtime" + "slices" + "strings" + "sync" + "sync/atomic" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/cmd/tailscaled/tailscaledhooks" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/types/empty" + "tailscale.com/types/logger" + "tailscale.com/util/osshare" + "tailscale.com/util/set" +) + +func init() { + ipnext.RegisterExtension("taildrop", newExtension) + + if runtime.GOOS == "windows" { + tailscaledhooks.UninstallSystemDaemonWindows.Add(func() { + // Remove file sharing from Windows shell. + osshare.SetFileSharingEnabled(false, logger.Discard) + }) + } +} + +func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) { + e := &Extension{ + sb: b, + stateStore: b.Sys().StateStore.Get(), + logf: logger.WithPrefix(logf, "taildrop: "), + } + e.setPlatformDefaultDirectFileRoot() + return e, nil +} + +// Extension implements Taildrop. +type Extension struct { + logf logger.Logf + sb ipnext.SafeBackend + stateStore ipn.StateStore + host ipnext.Host // from Init + + // directFileRoot, if non-empty, means to write received files + // directly to this directory, without staging them in an + // intermediate buffered directory for "pick-up" later. If + // empty, the files are received in a daemon-owned location + // and the localapi is used to enumerate, download, and delete + // them. This is used on macOS where the GUI lifetime is the + // same as the Network Extension lifetime and we can thus avoid + // double-copying files by writing them to the right location + // immediately. + // It's also used on several NAS platforms (Synology, TrueNAS, etc) + // but in that case DoFinalRename is also set true, which moves the + // *.partial file to its final name on completion. + directFileRoot string + + // FileOps abstracts platform-specific file operations needed for file transfers. + // This is currently being used for Android to use the Storage Access Framework. + fileOps FileOps + + nodeBackendForTest ipnext.NodeBackend // if non-nil, pretend we're this node state for tests + + mu sync.Mutex // Lock order: lb.mu > e.mu + backendState ipn.State + selfUID tailcfg.UserID + capFileSharing bool + fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs + mgr atomic.Pointer[manager] // mutex held to write; safe to read without lock; + // outgoingFiles keeps track of Taildrop outgoing files keyed to their OutgoingFile.ID + outgoingFiles map[string]*ipn.OutgoingFile +} + +func (e *Extension) Name() string { + return "taildrop" +} + +func (e *Extension) Init(h ipnext.Host) error { + e.host = h + + osshare.SetFileSharingEnabled(false, e.logf) + + h.Hooks().ProfileStateChange.Add(e.onChangeProfile) + h.Hooks().OnSelfChange.Add(e.onSelfChange) + h.Hooks().MutateNotifyLocked.Add(e.setNotifyFilesWaiting) + h.Hooks().SetPeerStatus.Add(e.setPeerStatus) + h.Hooks().BackendStateChange.Add(e.onBackendStateChange) + + // TODO(nickkhyl): remove this after the profileManager refactoring. + // See tailscale/tailscale#15974. + // This same workaround appears in feature/portlist/portlist.go. + profile, prefs := h.Profiles().CurrentProfileState() + e.onChangeProfile(profile, prefs, false) + return nil +} + +func (e *Extension) onBackendStateChange(st ipn.State) { + e.mu.Lock() + defer e.mu.Unlock() + e.backendState = st +} + +func (e *Extension) onSelfChange(self tailcfg.NodeView) { + e.mu.Lock() + defer e.mu.Unlock() + + e.selfUID = 0 + if self.Valid() { + e.selfUID = self.User() + } + e.capFileSharing = self.Valid() && self.CapMap().Contains(tailcfg.CapabilityFileSharing) + osshare.SetFileSharingEnabled(e.capFileSharing, e.logf) +} + +func (e *Extension) setMgrLocked(mgr *manager) { + if old := e.mgr.Swap(mgr); old != nil { + old.Shutdown() + } +} + +func (e *Extension) onChangeProfile(profile ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + + uid := profile.UserProfile().ID + activeLogin := profile.UserProfile().LoginName + + if uid == 0 { + e.setMgrLocked(nil) + e.outgoingFiles = nil + return + } + + if sameNode && e.manager() != nil { + return + } + + // Use the provided [FileOps] implementation (typically for SAF access on Android), + // or create an [fsFileOps] instance rooted at fileRoot. + // + // A non-nil [FileOps] also implies that we are in DirectFileMode. + fops := e.fileOps + isDirectFileMode := fops != nil + if fops == nil { + var fileRoot string + if fileRoot, isDirectFileMode = e.fileRoot(uid, activeLogin); fileRoot == "" { + e.logf("no Taildrop directory configured") + e.setMgrLocked(nil) + return + } + + var err error + if fops, err = newFileOps(fileRoot); err != nil { + e.logf("taildrop: cannot create FileOps: %v", err) + e.setMgrLocked(nil) + return + } + } + + e.setMgrLocked(managerOptions{ + Logf: e.logf, + Clock: tstime.DefaultClock{Clock: e.sb.Clock()}, + State: e.stateStore, + DirectFileMode: isDirectFileMode, + fileOps: fops, + SendFileNotify: e.sendFileNotify, + }.New()) +} + +// fileRoot returns where to store Taildrop files for the given user and whether +// to write received files directly to this directory, without staging them in +// an intermediate buffered directory for "pick-up" later. +// +// It is safe to call this with b.mu held but it does not require it or acquire +// it itself. +func (e *Extension) fileRoot(uid tailcfg.UserID, activeLogin string) (root string, isDirect bool) { + if v := e.directFileRoot; v != "" { + return v, true + } + varRoot := e.sb.TailscaleVarRoot() + if varRoot == "" { + e.logf("Taildrop disabled; no state directory") + return "", false + } + + if activeLogin == "" { + e.logf("taildrop: no active login; can't select a target directory") + return "", false + } + + baseDir := fmt.Sprintf("%s-uid-%d", + strings.ReplaceAll(activeLogin, "@", "-"), + uid) + return filepath.Join(varRoot, "files", baseDir), false +} + +// hasCapFileSharing reports whether the current node has the file sharing +// capability. +func (e *Extension) hasCapFileSharing() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.capFileSharing +} + +// manager returns the active Manager, or nil. +// +// Methods on a nil Manager are safe to call. +func (e *Extension) manager() *manager { + return e.mgr.Load() +} + +func (e *Extension) Clock() tstime.Clock { + return e.sb.Clock() +} + +func (e *Extension) Shutdown() error { + e.manager().Shutdown() // no-op on nil receiver + return nil +} + +func (e *Extension) sendFileNotify() { + mgr := e.manager() + if mgr == nil { + return + } + + var n ipn.Notify + + e.mu.Lock() + for _, wakeWaiter := range e.fileWaiters { + wakeWaiter() + } + n.IncomingFiles = mgr.IncomingFiles() + e.mu.Unlock() + + e.host.SendNotifyAsync(n) +} + +func (e *Extension) setNotifyFilesWaiting(n *ipn.Notify) { + if e.manager().HasFilesWaiting() { + n.FilesWaiting = &empty.Message{} + } +} + +func (e *Extension) setPeerStatus(ps *ipnstate.PeerStatus, p tailcfg.NodeView, nb ipnext.NodeBackend) { + ps.TaildropTarget = e.taildropTargetStatus(p, nb) +} + +func (e *Extension) removeFileWaiter(handle set.Handle) { + e.mu.Lock() + defer e.mu.Unlock() + delete(e.fileWaiters, handle) +} + +func (e *Extension) addFileWaiter(wakeWaiter context.CancelFunc) set.Handle { + e.mu.Lock() + defer e.mu.Unlock() + return e.fileWaiters.Add(wakeWaiter) +} + +func (e *Extension) WaitingFiles() ([]apitype.WaitingFile, error) { + return e.manager().WaitingFiles() +} + +// AwaitWaitingFiles is like WaitingFiles but blocks while ctx is not done, +// waiting for any files to be available. +// +// On return, exactly one of the results will be non-empty or non-nil, +// respectively. +func (e *Extension) AwaitWaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { + if ff, err := e.WaitingFiles(); err != nil || len(ff) > 0 { + return ff, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + for { + gotFile, gotFileCancel := context.WithCancel(context.Background()) + defer gotFileCancel() + + handle := e.addFileWaiter(gotFileCancel) + defer e.removeFileWaiter(handle) + + // Now that we've registered ourselves, check again, in case + // of race. Otherwise there's a small window where we could + // miss a file arrival and wait forever. + if ff, err := e.WaitingFiles(); err != nil || len(ff) > 0 { + return ff, err + } + + select { + case <-gotFile.Done(): + if ff, err := e.WaitingFiles(); err != nil || len(ff) > 0 { + return ff, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (e *Extension) DeleteFile(name string) error { + return e.manager().DeleteFile(name) +} + +func (e *Extension) OpenFile(name string) (rc io.ReadCloser, size int64, err error) { + return e.manager().OpenFile(name) +} + +func (e *Extension) nodeBackend() ipnext.NodeBackend { + if e.nodeBackendForTest != nil { + return e.nodeBackendForTest + } + return e.host.NodeBackend() +} + +// FileTargets lists nodes that the current node can send files to. +func (e *Extension) FileTargets() ([]*apitype.FileTarget, error) { + var ret []*apitype.FileTarget + + e.mu.Lock() + st := e.backendState + self := e.selfUID + e.mu.Unlock() + + if st != ipn.Running { + return nil, errors.New("not connected to the tailnet") + } + if !e.hasCapFileSharing() { + return nil, errors.New("file sharing not enabled by Tailscale admin") + } + nb := e.nodeBackend() + peers := nb.AppendMatchingPeers(nil, func(p tailcfg.NodeView) bool { + if !p.Valid() || p.Hostinfo().OS() == "tvOS" { + return false + } + if self == p.User() { + return true + } + if nb.PeerHasCap(p, tailcfg.PeerCapabilityFileSharingTarget) { + // Explicitly noted in the netmap ACL caps as a target. + return true + } + return false + }) + for _, p := range peers { + peerAPI := nb.PeerAPIBase(p) + if peerAPI == "" { + continue + } + ret = append(ret, &apitype.FileTarget{ + Node: p.AsStruct(), + PeerAPIURL: peerAPI, + }) + } + slices.SortFunc(ret, func(a, b *apitype.FileTarget) int { + return cmp.Compare(a.Node.Name, b.Node.Name) + }) + return ret, nil +} + +func (e *Extension) taildropTargetStatus(p tailcfg.NodeView, nb ipnext.NodeBackend) ipnstate.TaildropTargetStatus { + e.mu.Lock() + st := e.backendState + selfUID := e.selfUID + capFileSharing := e.capFileSharing + e.mu.Unlock() + + if st != ipn.Running { + return ipnstate.TaildropTargetIpnStateNotRunning + } + + if !capFileSharing { + return ipnstate.TaildropTargetMissingCap + } + if !p.Valid() { + return ipnstate.TaildropTargetNoPeerInfo + } + if !p.Online().Get() { + return ipnstate.TaildropTargetOffline + } + if p.Hostinfo().OS() == "tvOS" { + return ipnstate.TaildropTargetUnsupportedOS + } + if selfUID != p.User() { + // Different user must have the explicit file sharing target capability + if !nb.PeerHasCap(p, tailcfg.PeerCapabilityFileSharingTarget) { + return ipnstate.TaildropTargetOwnedByOtherUser + } + } + if !nb.PeerHasPeerAPI(p) { + return ipnstate.TaildropTargetNoPeerAPI + } + return ipnstate.TaildropTargetAvailable +} + +// updateOutgoingFiles updates b.outgoingFiles to reflect the given updates and +// sends an ipn.Notify with the full list of outgoingFiles. +func (e *Extension) updateOutgoingFiles(updates map[string]*ipn.OutgoingFile) { + e.mu.Lock() + if e.outgoingFiles == nil { + e.outgoingFiles = make(map[string]*ipn.OutgoingFile, len(updates)) + } + maps.Copy(e.outgoingFiles, updates) + outgoingFiles := make([]*ipn.OutgoingFile, 0, len(e.outgoingFiles)) + for _, file := range e.outgoingFiles { + outgoingFiles = append(outgoingFiles, file) + } + e.mu.Unlock() + slices.SortFunc(outgoingFiles, func(a, b *ipn.OutgoingFile) int { + t := a.Started.Compare(b.Started) + if t != 0 { + return t + } + return strings.Compare(a.Name, b.Name) + }) + + e.host.SendNotifyAsync(ipn.Notify{OutgoingFiles: outgoingFiles}) +} diff --git a/feature/taildrop/fileops.go b/feature/taildrop/fileops.go new file mode 100644 index 000000000..14f76067a --- /dev/null +++ b/feature/taildrop/fileops.go @@ -0,0 +1,41 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "io" + "io/fs" + "os" +) + +// FileOps abstracts over both local‐FS paths and Android SAF URIs. +type FileOps interface { + // OpenWriter creates or truncates a file named relative to the receiver's root, + // seeking to the specified offset. If the file does not exist, it is created with mode perm + // on platforms that support it. + // + // It returns an [io.WriteCloser] and the file's absolute path, or an error. + // This call may block. Callers should avoid holding locks when calling OpenWriter. + OpenWriter(name string, offset int64, perm os.FileMode) (wc io.WriteCloser, path string, err error) + + // Remove deletes a file or directory relative to the receiver's root. + // It returns [io.ErrNotExist] if the file or directory does not exist. + Remove(name string) error + + // Rename atomically renames oldPath to a new file named newName, + // returning the full new path or an error. + Rename(oldPath, newName string) (newPath string, err error) + + // ListFiles returns just the basenames of all regular files + // in the root directory. + ListFiles() ([]string, error) + + // Stat returns the FileInfo for the given name or an error. + Stat(name string) (fs.FileInfo, error) + + // OpenReader opens the given basename for the given name or an error. + OpenReader(name string) (io.ReadCloser, error) +} + +var newFileOps func(dir string) (FileOps, error) diff --git a/feature/taildrop/fileops_fs.go b/feature/taildrop/fileops_fs.go new file mode 100644 index 000000000..4fecbe4af --- /dev/null +++ b/feature/taildrop/fileops_fs.go @@ -0,0 +1,221 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +//go:build !android + +package taildrop + +import ( + "bytes" + "crypto/sha256" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path" + "path/filepath" + "strings" + "sync" + "unicode/utf8" +) + +var renameMu sync.Mutex + +// fsFileOps implements FileOps using the local filesystem rooted at a directory. +// It is used on non-Android platforms. +type fsFileOps struct{ rootDir string } + +func init() { + newFileOps = func(dir string) (FileOps, error) { + if dir == "" { + return nil, errors.New("rootDir cannot be empty") + } + if err := os.MkdirAll(dir, 0o700); err != nil { + return nil, fmt.Errorf("mkdir %q: %w", dir, err) + } + return fsFileOps{rootDir: dir}, nil + } +} + +func (f fsFileOps) OpenWriter(name string, offset int64, perm os.FileMode) (io.WriteCloser, string, error) { + path, err := joinDir(f.rootDir, name) + if err != nil { + return nil, "", err + } + if err = os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return nil, "", err + } + fi, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, perm) + if err != nil { + return nil, "", err + } + if offset != 0 { + curr, err := fi.Seek(0, io.SeekEnd) + if err != nil { + fi.Close() + return nil, "", err + } + if offset < 0 || offset > curr { + fi.Close() + return nil, "", fmt.Errorf("offset %d out of range", offset) + } + if _, err := fi.Seek(offset, io.SeekStart); err != nil { + fi.Close() + return nil, "", err + } + if err := fi.Truncate(offset); err != nil { + fi.Close() + return nil, "", err + } + } + return fi, path, nil +} + +func (f fsFileOps) Remove(name string) error { + path, err := joinDir(f.rootDir, name) + if err != nil { + return err + } + return os.Remove(path) +} + +// Rename moves the partial file into its final name. +// newName must be a base name (not absolute or containing path separators). +// It will retry up to 10 times, de-dup same-checksum files, etc. +func (f fsFileOps) Rename(oldPath, newName string) (newPath string, err error) { + var dst string + if filepath.IsAbs(newName) || strings.ContainsRune(newName, os.PathSeparator) { + return "", fmt.Errorf("invalid newName %q: must not be an absolute path or contain path separators", newName) + } + + dst = filepath.Join(f.rootDir, newName) + + if err := os.MkdirAll(filepath.Dir(dst), 0o700); err != nil { + return "", err + } + + st, err := os.Stat(oldPath) + if err != nil { + return "", err + } + wantSize := st.Size() + + const maxRetries = 10 + for i := 0; i < maxRetries; i++ { + renameMu.Lock() + fi, statErr := os.Stat(dst) + // Atomically rename the partial file as the destination file if it doesn't exist. + // Otherwise, it returns the length of the current destination file. + // The operation is atomic. + if os.IsNotExist(statErr) { + err = os.Rename(oldPath, dst) + renameMu.Unlock() + if err != nil { + return "", err + } + return dst, nil + } + if statErr != nil { + renameMu.Unlock() + return "", statErr + } + gotSize := fi.Size() + renameMu.Unlock() + + // Avoid the final rename if a destination file has the same contents. + // + // Note: this is best effort and copying files from iOS from the Media Library + // results in processing on the iOS side which means the size and shas of the + // same file can be different. + if gotSize == wantSize { + sumP, err := sha256File(oldPath) + if err != nil { + return "", err + } + sumD, err := sha256File(dst) + if err != nil { + return "", err + } + if bytes.Equal(sumP[:], sumD[:]) { + if err := os.Remove(oldPath); err != nil { + return "", err + } + return dst, nil + } + } + + // Choose a new destination filename and try again. + dst = filepath.Join(filepath.Dir(dst), nextFilename(filepath.Base(dst))) + } + + return "", fmt.Errorf("too many retries trying to rename %q to %q", oldPath, newName) +} + +// sha256File computes the SHA‑256 of a file. +func sha256File(path string) (sum [sha256.Size]byte, _ error) { + f, err := os.Open(path) + if err != nil { + return sum, err + } + defer f.Close() + h := sha256.New() + if _, err := io.Copy(h, f); err != nil { + return sum, err + } + copy(sum[:], h.Sum(nil)) + return sum, nil +} + +func (f fsFileOps) ListFiles() ([]string, error) { + entries, err := os.ReadDir(f.rootDir) + if err != nil { + return nil, err + } + var names []string + for _, e := range entries { + if e.Type().IsRegular() { + names = append(names, e.Name()) + } + } + return names, nil +} + +func (f fsFileOps) Stat(name string) (fs.FileInfo, error) { + path, err := joinDir(f.rootDir, name) + if err != nil { + return nil, err + } + return os.Stat(path) +} + +func (f fsFileOps) OpenReader(name string) (io.ReadCloser, error) { + path, err := joinDir(f.rootDir, name) + if err != nil { + return nil, err + } + return os.Open(path) +} + +// joinDir is like [filepath.Join] but returns an error if baseName is too long, +// is a relative path instead of a basename, or is otherwise invalid or unsafe for incoming files. +func joinDir(dir, baseName string) (string, error) { + if !utf8.ValidString(baseName) || + strings.TrimSpace(baseName) != baseName || + len(baseName) > 255 { + return "", ErrInvalidFileName + } + // TODO: validate unicode normalization form too? Varies by platform. + clean := path.Clean(baseName) + if clean != baseName || clean == "." || clean == ".." { + return "", ErrInvalidFileName + } + for _, r := range baseName { + if !validFilenameRune(r) { + return "", ErrInvalidFileName + } + } + if !filepath.IsLocal(baseName) { + return "", ErrInvalidFileName + } + return filepath.Join(dir, baseName), nil +} diff --git a/feature/taildrop/integration_test.go b/feature/taildrop/integration_test.go new file mode 100644 index 000000000..75896a95b --- /dev/null +++ b/feature/taildrop/integration_test.go @@ -0,0 +1,196 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "testing" + "time" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" +) + +// TODO(bradfitz): add test where control doesn't send tailcfg.CapabilityFileSharing +// and verify that we get the "file sharing not enabled by Tailscale admin" error. + +// TODO(bradfitz): add test between different users with the peercap to permit that? + +func TestTaildropIntegration(t *testing.T) { + testTaildropIntegration(t, false) +} + +func TestTaildropIntegration_Fresh(t *testing.T) { + testTaildropIntegration(t, true) +} + +// freshProfiles is whether to start the test right away +// with a fresh profile. If false, tailscaled is started, stopped, +// and restarted again to simulate a real-world scenario where +// the first profile already existed. +// +// This exercises an ipnext hook ordering issue we hit earlier. +func testTaildropIntegration(t *testing.T, freshProfiles bool) { + tstest.Parallel(t) + controlOpt := integration.ConfigureControl(func(s *testcontrol.Server) { + s.AllNodesSameUser = true // required for Taildrop + }) + env := integration.NewTestEnv(t, controlOpt) + + // Create two nodes: + n1 := integration.NewTestNode(t, env) + d1 := n1.StartDaemon() + + n2 := integration.NewTestNode(t, env) + d2 := n2.StartDaemon() + + awaitUp := func() { + t.Helper() + n1.AwaitListening() + t.Logf("n1 is listening") + n2.AwaitListening() + t.Logf("n2 is listening") + n1.MustUp() + t.Logf("n1 is up") + n2.MustUp() + t.Logf("n2 is up") + n1.AwaitRunning() + t.Logf("n1 is running") + n2.AwaitRunning() + t.Logf("n2 is running") + } + awaitUp() + + if !freshProfiles { + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) + d1 = n1.StartDaemon() + d2 = n2.StartDaemon() + awaitUp() + } + + var peerStableID tailcfg.StableNodeID + + if err := tstest.WaitFor(5*time.Second, func() error { + st := n1.MustStatus() + if len(st.Peer) == 0 { + return errors.New("no peers") + } + if len(st.Peer) > 1 { + return fmt.Errorf("got %d peers; want 1", len(st.Peer)) + } + peer := st.Peer[st.Peers()[0]] + peerStableID = peer.ID + if peer.ID == st.Self.ID { + return errors.New("peer is self") + } + + if len(st.TailscaleIPs) == 0 { + return errors.New("no Tailscale IPs") + } + + return nil + }); err != nil { + t.Fatal(err) + } + + const timeout = 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + c1 := n1.LocalClient() + c2 := n2.LocalClient() + + wantNoWaitingFiles := func(c *local.Client) { + t.Helper() + files, err := c.WaitingFiles(ctx) + if err != nil { + t.Fatalf("WaitingFiles: %v", err) + } + if len(files) != 0 { + t.Fatalf("WaitingFiles: got %d files; want 0", len(files)) + } + } + + // Verify c2 has no files. + wantNoWaitingFiles(c2) + + gotFile := make(chan bool, 1) + go func() { + v, err := c2.AwaitWaitingFiles(t.Context(), timeout) + if err != nil { + return + } + if len(v) != 0 { + gotFile <- true + } + }() + + fileContents := []byte("hello world this is a file") + + n2ID := n2.MustStatus().Self.ID + t.Logf("n2 self.ID = %q; n1's peer[0].ID = %q", n2ID, peerStableID) + t.Logf("Doing PushFile ...") + err := c1.PushFile(ctx, n2.MustStatus().Self.ID, int64(len(fileContents)), "test.txt", bytes.NewReader(fileContents)) + if err != nil { + t.Fatalf("PushFile from n1->n2: %v", err) + } + t.Logf("PushFile done") + + select { + case <-gotFile: + t.Logf("n2 saw AwaitWaitingFiles wake up") + case <-ctx.Done(): + t.Fatalf("n2 timeout waiting for AwaitWaitingFiles") + } + + files, err := c2.WaitingFiles(ctx) + if err != nil { + t.Fatalf("c2.WaitingFiles: %v", err) + } + if len(files) != 1 { + t.Fatalf("c2.WaitingFiles: got %d files; want 1", len(files)) + } + got := files[0] + want := apitype.WaitingFile{ + Name: "test.txt", + Size: int64(len(fileContents)), + } + if got != want { + t.Fatalf("c2.WaitingFiles: got %+v; want %+v", got, want) + } + + // Download the file. + rc, size, err := c2.GetWaitingFile(ctx, got.Name) + if err != nil { + t.Fatalf("c2.GetWaitingFile: %v", err) + } + if size != int64(len(fileContents)) { + t.Fatalf("c2.GetWaitingFile: got size %d; want %d", size, len(fileContents)) + } + gotBytes, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("c2.GetWaitingFile: %v", err) + } + if !bytes.Equal(gotBytes, fileContents) { + t.Fatalf("c2.GetWaitingFile: got %q; want %q", gotBytes, fileContents) + } + + // Now delete it. + if err := c2.DeleteWaitingFile(ctx, got.Name); err != nil { + t.Fatalf("c2.DeleteWaitingFile: %v", err) + } + wantNoWaitingFiles(c2) + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) +} diff --git a/feature/taildrop/localapi.go b/feature/taildrop/localapi.go new file mode 100644 index 000000000..8a3904f9f --- /dev/null +++ b/feature/taildrop/localapi.go @@ -0,0 +1,458 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "maps" + "mime" + "mime/multipart" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" + "time" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/localapi" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" + "tailscale.com/util/httphdr" + "tailscale.com/util/mak" + "tailscale.com/util/progresstracking" + "tailscale.com/util/rands" +) + +func init() { + localapi.Register("file-put/", serveFilePut) + localapi.Register("files/", serveFiles) + localapi.Register("file-targets", serveFileTargets) +} + +var ( + metricFilePutCalls = clientmetric.NewCounter("localapi_file_put") +) + +// serveFilePut sends a file to another node. +// +// It's sometimes possible for clients to do this themselves, without +// tailscaled, except in the case of tailscaled running in +// userspace-networking ("netstack") mode, in which case tailscaled +// needs to a do a netstack dial out. +// +// Instead, the CLI also goes through tailscaled so it doesn't need to be +// aware of the network mode in use. +// +// macOS/iOS have always used this localapi method to simplify the GUI +// clients. +// +// The Windows client currently (2021-11-30) uses the peerapi (/v0/put/) +// directly, as the Windows GUI always runs in tun mode anyway. +// +// In addition to single file PUTs, this endpoint accepts multipart file +// POSTS encoded as multipart/form-data.The first part should be an +// application/json file that contains a manifest consisting of a JSON array of +// OutgoingFiles which we can use for tracking progress even before reading the +// file parts. +// +// URL format: +// +// - PUT /localapi/v0/file-put/:stableID/:escaped-filename +// - POST /localapi/v0/file-put/:stableID +func serveFilePut(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + metricFilePutCalls.Add(1) + + if !h.PermitWrite { + http.Error(w, "file access denied", http.StatusForbidden) + return + } + + if r.Method != "PUT" && r.Method != "POST" { + http.Error(w, "want PUT to put file", http.StatusBadRequest) + return + } + + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "misconfigured taildrop extension", http.StatusInternalServerError) + return + } + + fts, err := ext.FileTargets() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + upath, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/file-put/") + if !ok { + http.Error(w, "misconfigured", http.StatusInternalServerError) + return + } + var peerIDStr, filenameEscaped string + if r.Method == "PUT" { + ok := false + peerIDStr, filenameEscaped, ok = strings.Cut(upath, "/") + if !ok { + http.Error(w, "bogus URL", http.StatusBadRequest) + return + } + } else { + peerIDStr = upath + } + peerID := tailcfg.StableNodeID(peerIDStr) + + var ft *apitype.FileTarget + for _, x := range fts { + if x.Node.StableID == peerID { + ft = x + break + } + } + if ft == nil { + http.Error(w, "node not found", http.StatusNotFound) + return + } + dstURL, err := url.Parse(ft.PeerAPIURL) + if err != nil { + http.Error(w, "bogus peer URL", http.StatusInternalServerError) + return + } + + // Periodically report progress of outgoing files. + outgoingFiles := make(map[string]*ipn.OutgoingFile) + t := time.NewTicker(1 * time.Second) + progressUpdates := make(chan ipn.OutgoingFile) + defer close(progressUpdates) + + go func() { + defer t.Stop() + defer ext.updateOutgoingFiles(outgoingFiles) + for { + select { + case u, ok := <-progressUpdates: + if !ok { + return + } + outgoingFiles[u.ID] = &u + case <-t.C: + ext.updateOutgoingFiles(outgoingFiles) + } + } + }() + + switch r.Method { + case "PUT": + file := ipn.OutgoingFile{ + ID: rands.HexString(30), + PeerID: peerID, + Name: filenameEscaped, + DeclaredSize: r.ContentLength, + } + singleFilePut(h, r.Context(), progressUpdates, w, r.Body, dstURL, file) + case "POST": + multiFilePost(h, progressUpdates, w, r, peerID, dstURL) + default: + http.Error(w, "want PUT to put file", http.StatusBadRequest) + return + } +} + +func multiFilePost(h *localapi.Handler, progressUpdates chan (ipn.OutgoingFile), w http.ResponseWriter, r *http.Request, peerID tailcfg.StableNodeID, dstURL *url.URL) { + _, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, fmt.Sprintf("invalid Content-Type for multipart POST: %s", err), http.StatusBadRequest) + return + } + + ww := &multiFilePostResponseWriter{} + defer func() { + if err := ww.Flush(w); err != nil { + h.Logf("error: multiFilePostResponseWriter.Flush(): %s", err) + } + }() + + outgoingFilesByName := make(map[string]ipn.OutgoingFile) + first := true + mr := multipart.NewReader(r.Body, params["boundary"]) + for { + part, err := mr.NextPart() + if err == io.EOF { + // No more parts. + return + } else if err != nil { + http.Error(ww, fmt.Sprintf("failed to decode multipart/form-data: %s", err), http.StatusBadRequest) + return + } + + if first { + first = false + if part.Header.Get("Content-Type") != "application/json" { + http.Error(ww, "first MIME part must be a JSON map of filename -> size", http.StatusBadRequest) + return + } + + var manifest []ipn.OutgoingFile + err := json.NewDecoder(part).Decode(&manifest) + if err != nil { + http.Error(ww, fmt.Sprintf("invalid manifest: %s", err), http.StatusBadRequest) + return + } + + for _, file := range manifest { + outgoingFilesByName[file.Name] = file + progressUpdates <- file + } + + continue + } + + if !singleFilePut(h, r.Context(), progressUpdates, ww, part, dstURL, outgoingFilesByName[part.FileName()]) { + return + } + + if ww.statusCode >= 400 { + // put failed, stop immediately + h.Logf("error: singleFilePut: failed with status %d", ww.statusCode) + return + } + } +} + +// multiFilePostResponseWriter is a buffering http.ResponseWriter that can be +// reused across multiple singleFilePut calls and then flushed to the client +// when all files have been PUT. +type multiFilePostResponseWriter struct { + header http.Header + statusCode int + body *bytes.Buffer +} + +func (ww *multiFilePostResponseWriter) Header() http.Header { + if ww.header == nil { + ww.header = make(http.Header) + } + return ww.header +} + +func (ww *multiFilePostResponseWriter) WriteHeader(statusCode int) { + ww.statusCode = statusCode +} + +func (ww *multiFilePostResponseWriter) Write(p []byte) (int, error) { + if ww.body == nil { + ww.body = bytes.NewBuffer(nil) + } + return ww.body.Write(p) +} + +func (ww *multiFilePostResponseWriter) Flush(w http.ResponseWriter) error { + if ww.header != nil { + maps.Copy(w.Header(), ww.header) + } + if ww.statusCode > 0 { + w.WriteHeader(ww.statusCode) + } + if ww.body != nil { + _, err := io.Copy(w, ww.body) + return err + } + return nil +} + +func singleFilePut( + h *localapi.Handler, + ctx context.Context, + progressUpdates chan (ipn.OutgoingFile), + w http.ResponseWriter, + body io.Reader, + dstURL *url.URL, + outgoingFile ipn.OutgoingFile, +) bool { + outgoingFile.Started = time.Now() + body = progresstracking.NewReader(body, 1*time.Second, func(n int, err error) { + outgoingFile.Sent = int64(n) + progressUpdates <- outgoingFile + }) + + fail := func() { + outgoingFile.Finished = true + outgoingFile.Succeeded = false + progressUpdates <- outgoingFile + } + + // Before we PUT a file we check to see if there are any existing partial file and if so, + // we resume the upload from where we left off by sending the remaining file instead of + // the full file. + var offset int64 + var resumeDuration time.Duration + remainingBody := io.Reader(body) + client := &http.Client{ + Transport: h.LocalBackend().Dialer().PeerAPITransport(), + Timeout: 10 * time.Second, + } + req, err := http.NewRequestWithContext(ctx, "GET", dstURL.String()+"/v0/put/"+outgoingFile.Name, nil) + if err != nil { + http.Error(w, "bogus peer URL", http.StatusInternalServerError) + fail() + return false + } + resp, err := client.Do(req) + if resp != nil { + defer resp.Body.Close() + } + switch { + case err != nil: + h.Logf("could not fetch remote hashes: %v", err) + case resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotFound: + // noop; implies older peerapi without resume support + case resp.StatusCode != http.StatusOK: + h.Logf("fetch remote hashes status code: %d", resp.StatusCode) + default: + resumeStart := time.Now() + dec := json.NewDecoder(resp.Body) + offset, remainingBody, err = resumeReader(body, func() (out blockChecksum, err error) { + err = dec.Decode(&out) + return out, err + }) + if err != nil { + h.Logf("reader could not be fully resumed: %v", err) + } + resumeDuration = time.Since(resumeStart).Round(time.Millisecond) + } + + outReq, err := http.NewRequestWithContext(ctx, "PUT", "http://peer/v0/put/"+outgoingFile.Name, remainingBody) + if err != nil { + http.Error(w, "bogus outreq", http.StatusInternalServerError) + fail() + return false + } + outReq.ContentLength = outgoingFile.DeclaredSize + if offset > 0 { + h.Logf("resuming put at offset %d after %v", offset, resumeDuration) + rangeHdr, _ := httphdr.FormatRange([]httphdr.Range{{Start: offset, Length: 0}}) + outReq.Header.Set("Range", rangeHdr) + if outReq.ContentLength >= 0 { + outReq.ContentLength -= offset + } + } + + rp := httputil.NewSingleHostReverseProxy(dstURL) + rp.Transport = h.LocalBackend().Dialer().PeerAPITransport() + rp.ServeHTTP(w, outReq) + + outgoingFile.Finished = true + outgoingFile.Succeeded = true + progressUpdates <- outgoingFile + + return true +} + +func serveFiles(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "file access denied", http.StatusForbidden) + return + } + + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "misconfigured taildrop extension", http.StatusInternalServerError) + return + } + + suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/files/") + if !ok { + http.Error(w, "misconfigured", http.StatusInternalServerError) + return + } + if suffix == "" { + if r.Method != "GET" { + http.Error(w, "want GET to list files", http.StatusBadRequest) + return + } + ctx := r.Context() + var wfs []apitype.WaitingFile + if s := r.FormValue("waitsec"); s != "" && s != "0" { + d, err := strconv.Atoi(s) + if err != nil { + http.Error(w, "invalid waitsec", http.StatusBadRequest) + return + } + deadline := time.Now().Add(time.Duration(d) * time.Second) + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + wfs, err = ext.AwaitWaitingFiles(ctx) + if err != nil && ctx.Err() == nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } else { + var err error + wfs, err = ext.WaitingFiles() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(wfs) + return + } + name, err := url.PathUnescape(suffix) + if err != nil { + http.Error(w, "bad filename", http.StatusBadRequest) + return + } + if r.Method == "DELETE" { + if err := ext.DeleteFile(name); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNoContent) + return + } + rc, size, err := ext.OpenFile(name) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rc.Close() + w.Header().Set("Content-Length", fmt.Sprint(size)) + w.Header().Set("Content-Type", "application/octet-stream") + io.Copy(w, rc) +} + +func serveFileTargets(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != "GET" { + http.Error(w, "want GET to list targets", http.StatusBadRequest) + return + } + + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "misconfigured taildrop extension", http.StatusInternalServerError) + return + } + + fts, err := ext.FileTargets() + if err != nil { + localapi.WriteErrorJSON(w, err) + return + } + mak.NonNilSliceForJSON(&fts) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(fts) +} diff --git a/cmd/tailscaled/taildrop.go b/feature/taildrop/paths.go similarity index 85% rename from cmd/tailscaled/taildrop.go rename to feature/taildrop/paths.go index 39fe54373..79dc37d8f 100644 --- a/cmd/tailscaled/taildrop.go +++ b/feature/taildrop/paths.go @@ -1,35 +1,44 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build go1.19 - -package main +package taildrop import ( "fmt" "os" "path/filepath" - "tailscale.com/ipn/ipnlocal" - "tailscale.com/types/logger" "tailscale.com/version/distro" ) -func configureTaildrop(logf logger.Logf, lb *ipnlocal.LocalBackend) { +// SetDirectFileRoot sets the directory where received files are written. +// +// This must be called before Tailscale is started. +func (e *Extension) SetDirectFileRoot(root string) { + e.directFileRoot = root +} + +// SetFileOps sets the platform specific file operations. This is used +// to call Android's Storage Access Framework APIs. +func (e *Extension) SetFileOps(fileOps FileOps) { + e.fileOps = fileOps +} + +func (e *Extension) setPlatformDefaultDirectFileRoot() { dg := distro.Get() + switch dg { case distro.Synology, distro.TrueNAS, distro.QNAP, distro.Unraid: // See if they have a "Taildrop" share. // See https://github.com/tailscale/tailscale/issues/2179#issuecomment-982821319 path, err := findTaildropDir(dg) if err != nil { - logf("%s Taildrop support: %v", dg, err) + e.logf("%s Taildrop support: %v", dg, err) } else { - logf("%s Taildrop: using %v", dg, path) - lb.SetDirectFileRoot(path) + e.logf("%s Taildrop: using %v", dg, path) + e.directFileRoot = path } } - } func findTaildropDir(dg distro.Distro) (string, error) { diff --git a/feature/taildrop/peerapi.go b/feature/taildrop/peerapi.go new file mode 100644 index 000000000..b75ce33b8 --- /dev/null +++ b/feature/taildrop/peerapi.go @@ -0,0 +1,169 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/util/clientmetric" + "tailscale.com/util/httphdr" +) + +func init() { + ipnlocal.RegisterPeerAPIHandler("/v0/put/", handlePeerPut) +} + +var ( + metricPutCalls = clientmetric.NewCounter("peerapi_put") +) + +// canPutFile reports whether h can put a file ("Taildrop") to this node. +func canPutFile(h ipnlocal.PeerAPIHandler) bool { + if h.Peer().UnsignedPeerAPIOnly() { + // Unsigned peers can't send files. + return false + } + return h.IsSelfUntagged() || h.PeerCaps().HasCapability(tailcfg.PeerCapabilityFileSharingSend) +} + +func handlePeerPut(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "miswired", http.StatusInternalServerError) + return + } + handlePeerPutWithBackend(h, ext, w, r) +} + +// extensionForPut is the subset of taildrop extension that taildrop +// file put needs. This is pulled out for testability. +type extensionForPut interface { + manager() *manager + hasCapFileSharing() bool + Clock() tstime.Clock +} + +func handlePeerPutWithBackend(h ipnlocal.PeerAPIHandler, ext extensionForPut, w http.ResponseWriter, r *http.Request) { + if r.Method == "PUT" { + metricPutCalls.Add(1) + } + + taildropMgr := ext.manager() + if taildropMgr == nil { + h.Logf("taildrop: no taildrop manager") + http.Error(w, "failed to get taildrop manager", http.StatusInternalServerError) + return + } + + if !canPutFile(h) { + http.Error(w, ErrNoTaildrop.Error(), http.StatusForbidden) + return + } + if !ext.hasCapFileSharing() { + http.Error(w, ErrNoTaildrop.Error(), http.StatusForbidden) + return + } + rawPath := r.URL.EscapedPath() + prefix, ok := strings.CutPrefix(rawPath, "/v0/put/") + if !ok { + http.Error(w, "misconfigured internals", http.StatusForbidden) + return + } + baseName, err := url.PathUnescape(prefix) + if err != nil { + http.Error(w, ErrInvalidFileName.Error(), http.StatusBadRequest) + return + } + enc := json.NewEncoder(w) + switch r.Method { + case "GET": + id := clientID(h.Peer().StableID()) + if prefix == "" { + // List all the partial files. + files, err := taildropMgr.PartialFiles(id) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := enc.Encode(files); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + h.Logf("json.Encoder.Encode error: %v", err) + return + } + } else { + // Stream all the block hashes for the specified file. + next, close, err := taildropMgr.HashPartialFile(id, baseName) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer close() + for { + switch cs, err := next(); { + case err == io.EOF: + return + case err != nil: + http.Error(w, err.Error(), http.StatusInternalServerError) + h.Logf("HashPartialFile.next error: %v", err) + return + default: + if err := enc.Encode(cs); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + h.Logf("json.Encoder.Encode error: %v", err) + return + } + } + } + } + case "PUT": + t0 := ext.Clock().Now() + id := clientID(h.Peer().StableID()) + + var offset int64 + if rangeHdr := r.Header.Get("Range"); rangeHdr != "" { + ranges, ok := httphdr.ParseRange(rangeHdr) + if !ok || len(ranges) != 1 || ranges[0].Length != 0 { + http.Error(w, "invalid Range header", http.StatusBadRequest) + return + } + offset = ranges[0].Start + } + n, err := taildropMgr.PutFile(clientID(fmt.Sprint(id)), baseName, r.Body, offset, r.ContentLength) + switch err { + case nil: + d := ext.Clock().Since(t0).Round(time.Second / 10) + h.Logf("got put of %s in %v from %v/%v", approxSize(n), d, h.RemoteAddr().Addr(), h.Peer().ComputedName) + io.WriteString(w, "{}\n") + case ErrNoTaildrop: + http.Error(w, err.Error(), http.StatusForbidden) + case ErrInvalidFileName: + http.Error(w, err.Error(), http.StatusBadRequest) + case ErrFileExists: + http.Error(w, err.Error(), http.StatusConflict) + default: + http.Error(w, err.Error(), http.StatusInternalServerError) + } + default: + http.Error(w, "expected method GET or PUT", http.StatusMethodNotAllowed) + } +} + +func approxSize(n int64) string { + if n <= 1<<10 { + return "<=1KB" + } + if n <= 1<<20 { + return "<=1MB" + } + return fmt.Sprintf("~%dMB", n>>20) +} diff --git a/feature/taildrop/peerapi_test.go b/feature/taildrop/peerapi_test.go new file mode 100644 index 000000000..254d8794e --- /dev/null +++ b/feature/taildrop/peerapi_test.go @@ -0,0 +1,589 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "fmt" + "io" + "io/fs" + "math/rand" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/logger" + "tailscale.com/util/must" +) + +// peerAPIHandler serves the PeerAPI for a source specific client. +type peerAPIHandler struct { + remoteAddr netip.AddrPort + isSelf bool // whether peerNode is owned by same user as this node + selfNode tailcfg.NodeView // this node; always non-nil + peerNode tailcfg.NodeView // peerNode is who's making the request + canDebug bool // whether peerNode can debug this node (goroutines, metrics, magicsock internal state, etc) +} + +func (h *peerAPIHandler) IsSelfUntagged() bool { + return !h.selfNode.IsTagged() && !h.peerNode.IsTagged() && h.isSelf +} +func (h *peerAPIHandler) CanDebug() bool { return h.canDebug } +func (h *peerAPIHandler) Peer() tailcfg.NodeView { return h.peerNode } +func (h *peerAPIHandler) Self() tailcfg.NodeView { return h.selfNode } +func (h *peerAPIHandler) RemoteAddr() netip.AddrPort { return h.remoteAddr } +func (h *peerAPIHandler) LocalBackend() *ipnlocal.LocalBackend { panic("unexpected") } +func (h *peerAPIHandler) Logf(format string, a ...any) { + //h.logf(format, a...) +} + +func (h *peerAPIHandler) PeerCaps() tailcfg.PeerCapMap { + return nil +} + +type fakeExtension struct { + logf logger.Logf + capFileSharing bool + clock tstime.Clock + taildrop *manager +} + +func (lb *fakeExtension) manager() *manager { + return lb.taildrop +} +func (lb *fakeExtension) Clock() tstime.Clock { return lb.clock } +func (lb *fakeExtension) hasCapFileSharing() bool { + return lb.capFileSharing +} + +type peerAPITestEnv struct { + taildrop *manager + ph *peerAPIHandler + rr *httptest.ResponseRecorder + logBuf tstest.MemLogger +} + +type check func(*testing.T, *peerAPITestEnv) + +func checks(vv ...check) []check { return vv } + +func httpStatus(wantStatus int) check { + return func(t *testing.T, e *peerAPITestEnv) { + if res := e.rr.Result(); res.StatusCode != wantStatus { + t.Errorf("HTTP response code = %v; want %v", res.Status, wantStatus) + } + } +} + +func bodyContains(sub string) check { + return func(t *testing.T, e *peerAPITestEnv) { + if body := e.rr.Body.String(); !strings.Contains(body, sub) { + t.Errorf("HTTP response body does not contain %q; got: %s", sub, body) + } + } +} + +func fileHasSize(name string, size int) check { + return func(t *testing.T, e *peerAPITestEnv) { + fsImpl, ok := e.taildrop.opts.fileOps.(*fsFileOps) + if !ok { + t.Skip("fileHasSize only supported on fsFileOps backend") + return + } + root := fsImpl.rootDir + if root == "" { + t.Errorf("no rootdir; can't check whether %q has size %v", name, size) + return + } + if root == "" { + t.Errorf("no rootdir; can't check whether %q has size %v", name, size) + return + } + path := filepath.Join(root, name) + if fi, err := os.Stat(path); err != nil { + t.Errorf("fileHasSize(%q, %v): %v", name, size, err) + } else if fi.Size() != int64(size) { + t.Errorf("file %q has size %v; want %v", name, fi.Size(), size) + } + } +} + +func fileHasContents(name string, want string) check { + return func(t *testing.T, e *peerAPITestEnv) { + fsImpl, ok := e.taildrop.opts.fileOps.(*fsFileOps) + if !ok { + t.Skip("fileHasContents only supported on fsFileOps backend") + return + } + path := filepath.Join(fsImpl.rootDir, name) + got, err := os.ReadFile(path) + if err != nil { + t.Errorf("fileHasContents: %v", err) + return + } + if string(got) != want { + t.Errorf("file contents = %q; want %q", got, want) + } + } +} + +func hexAll(v string) string { + var sb strings.Builder + for i := range len(v) { + fmt.Fprintf(&sb, "%%%02x", v[i]) + } + return sb.String() +} + +func TestHandlePeerAPI(t *testing.T) { + tests := []struct { + name string + isSelf bool // the peer sending the request is owned by us + capSharing bool // self node has file sharing capability + debugCap bool // self node has debug capability + omitRoot bool // don't configure + reqs []*http.Request + checks []check + }{ + { + name: "reject_non_owner_put", + isSelf: false, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(http.StatusForbidden), + bodyContains("Taildrop disabled"), + ), + }, + { + name: "owner_without_cap", + isSelf: true, + capSharing: false, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(http.StatusForbidden), + bodyContains("Taildrop disabled"), + ), + }, + { + name: "owner_with_cap_no_rootdir", + omitRoot: true, + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(http.StatusForbidden), + bodyContains("Taildrop disabled"), + ), + }, + + { + name: "bad_method", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("POST", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(405), + bodyContains("expected method GET or PUT"), + ), + }, + { + name: "put_zero_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasSize("foo", 0), + fileHasContents("foo", ""), + ), + }, + { + name: "put_non_zero_length_content_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents"))}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasSize("foo", len("contents")), + fileHasContents("foo", "contents"), + ), + }, + { + name: "put_non_zero_length_chunked", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", struct{ io.Reader }{strings.NewReader("contents")})}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasSize("foo", len("contents")), + fileHasContents("foo", "contents"), + ), + }, + { + name: "bad_filename_partial", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.partial", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_deleted", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.deleted", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_dot", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/.", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_empty", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_slash", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo/bar", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_dot", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("."), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_slash", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("/"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_backslash", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("\\"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_dotdot", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(".."), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_dotdot_out", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("foo/../../../../../etc/passwd"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_spaces_and_caps", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("Foo Bar.dat"), strings.NewReader("baz"))}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasContents("Foo Bar.dat", "baz"), + ), + }, + { + name: "put_unicode", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3"), strings.NewReader("ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"))}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasContents("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3", "ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"), + ), + }, + { + name: "put_invalid_utf8", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+(hexAll("😜")[:3]), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_null", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%00", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_non_printable", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%01", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_colon", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("nul:"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_surrounding_whitespace", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(" foo "), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "duplicate_zero_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", nil), + httptest.NewRequest("PUT", "/v0/put/foo", nil), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 0}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + }, + { + name: "duplicate_non_zero_length_content_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 8}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + }, + { + name: "duplicate_different_files", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("fizz")), + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("buzz")), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 4}, {Name: "foo (1)", Size: 4}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := &tailcfg.Node{ + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.100.100.101/32"), + }, + } + if tt.debugCap { + selfNode.CapMap = tailcfg.NodeCapMap{tailcfg.CapabilityDebug: nil} + } + var rootDir string + var fo FileOps + if !tt.omitRoot { + var err error + if fo, err = newFileOps(t.TempDir()); err != nil { + t.Fatalf("newFileOps: %v", err) + } + } + + var e peerAPITestEnv + e.taildrop = managerOptions{ + Logf: e.logBuf.Logf, + fileOps: fo, + }.New() + + ext := &fakeExtension{ + logf: e.logBuf.Logf, + capFileSharing: tt.capSharing, + clock: &tstest.Clock{}, + taildrop: e.taildrop, + } + e.ph = &peerAPIHandler{ + isSelf: tt.isSelf, + selfNode: selfNode.View(), + peerNode: (&tailcfg.Node{ComputedName: "some-peer-name"}).View(), + } + for _, req := range tt.reqs { + e.rr = httptest.NewRecorder() + if req.Host == "example.com" { + req.Host = "100.100.100.101:12345" + } + handlePeerPutWithBackend(e.ph, ext, e.rr, req) + } + for _, f := range tt.checks { + f(t, &e) + } + if t.Failed() && rootDir != "" { + t.Logf("Contents of %s:", rootDir) + des, _ := fs.ReadDir(os.DirFS(rootDir), ".") + for _, de := range des { + fi, err := de.Info() + if err != nil { + t.Log(err) + } else { + t.Logf(" %v %5d %s", fi.Mode(), fi.Size(), de.Name()) + } + } + } + }) + } +} + +// Windows likes to hold on to file descriptors for some indeterminate +// amount of time after you close them and not let you delete them for +// a bit. So test that we work around that sufficiently. +func TestFileDeleteRace(t *testing.T) { + dir := t.TempDir() + taildropMgr := managerOptions{ + Logf: t.Logf, + fileOps: must.Get(newFileOps(dir)), + }.New() + + ph := &peerAPIHandler{ + isSelf: true, + peerNode: (&tailcfg.Node{ + ComputedName: "some-peer-name", + }).View(), + selfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.100.101/32")}, + }).View(), + } + fakeLB := &fakeExtension{ + logf: t.Logf, + capFileSharing: true, + clock: &tstest.Clock{}, + taildrop: taildropMgr, + } + buf := make([]byte, 2<<20) + for range 30 { + rr := httptest.NewRecorder() + handlePeerPutWithBackend(ph, fakeLB, rr, httptest.NewRequest("PUT", "http://100.100.100.101:123/v0/put/foo.txt", bytes.NewReader(buf[:rand.Intn(len(buf))]))) + if res := rr.Result(); res.StatusCode != 200 { + t.Fatal(res.Status) + } + wfs, err := taildropMgr.WaitingFiles() + if err != nil { + t.Fatal(err) + } + if len(wfs) != 1 { + t.Fatalf("waiting files = %d; want 1", len(wfs)) + } + + if err := taildropMgr.DeleteFile("foo.txt"); err != nil { + t.Fatal(err) + } + wfs, err = taildropMgr.WaitingFiles() + if err != nil { + t.Fatal(err) + } + if len(wfs) != 0 { + t.Fatalf("waiting files = %d; want 0", len(wfs)) + } + } +} diff --git a/taildrop/resume.go b/feature/taildrop/resume.go similarity index 67% rename from taildrop/resume.go rename to feature/taildrop/resume.go index f7bee3d95..20ef527a6 100644 --- a/taildrop/resume.go +++ b/feature/taildrop/resume.go @@ -9,7 +9,6 @@ import ( "encoding/hex" "fmt" "io" - "io/fs" "os" "strings" ) @@ -19,29 +18,29 @@ var ( hashAlgorithm = "sha256" ) -// BlockChecksum represents the checksum for a single block. -type BlockChecksum struct { - Checksum Checksum `json:"checksum"` +// blockChecksum represents the checksum for a single block. +type blockChecksum struct { + Checksum checksum `json:"checksum"` Algorithm string `json:"algo"` // always "sha256" for now Size int64 `json:"size"` // always (64<<10) for now } -// Checksum is an opaque checksum that is comparable. -type Checksum struct{ cs [sha256.Size]byte } +// checksum is an opaque checksum that is comparable. +type checksum struct{ cs [sha256.Size]byte } -func hash(b []byte) Checksum { - return Checksum{sha256.Sum256(b)} +func hash(b []byte) checksum { + return checksum{sha256.Sum256(b)} } -func (cs Checksum) String() string { +func (cs checksum) String() string { return hex.EncodeToString(cs.cs[:]) } -func (cs Checksum) AppendText(b []byte) ([]byte, error) { +func (cs checksum) AppendText(b []byte) ([]byte, error) { return hex.AppendEncode(b, cs.cs[:]), nil } -func (cs Checksum) MarshalText() ([]byte, error) { +func (cs checksum) MarshalText() ([]byte, error) { return hex.AppendEncode(nil, cs.cs[:]), nil } -func (cs *Checksum) UnmarshalText(b []byte) error { +func (cs *checksum) UnmarshalText(b []byte) error { if len(b) != 2*len(cs.cs) { return fmt.Errorf("invalid hex length: %d", len(b)) } @@ -51,19 +50,20 @@ func (cs *Checksum) UnmarshalText(b []byte) error { // PartialFiles returns a list of partial files in [Handler.Dir] // that were sent (or is actively being sent) by the provided id. -func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) { - if m == nil || m.opts.Dir == "" { +func (m *manager) PartialFiles(id clientID) ([]string, error) { + if m == nil || m.opts.fileOps == nil { return nil, ErrNoTaildrop } - suffix := id.partialSuffix() - if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - if name := de.Name(); strings.HasSuffix(name, suffix) { - ret = append(ret, name) + files, err := m.opts.fileOps.ListFiles() + if err != nil { + return nil, redactError(err) + } + var ret []string + for _, filename := range files { + if strings.HasSuffix(filename, suffix) { + ret = append(ret, filename) } - return true - }); err != nil { - return ret, redactError(err) } return ret, nil } @@ -72,18 +72,14 @@ func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) { // starting from the beginning of the file. // It returns (BlockChecksum{}, io.EOF) when the stream is complete. // It is the caller's responsibility to call close. -func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (BlockChecksum, error), close func() error, err error) { - if m == nil || m.opts.Dir == "" { +func (m *manager) HashPartialFile(id clientID, baseName string) (next func() (blockChecksum, error), close func() error, err error) { + if m == nil || m.opts.fileOps == nil { return nil, nil, ErrNoTaildrop } - noopNext := func() (BlockChecksum, error) { return BlockChecksum{}, io.EOF } + noopNext := func() (blockChecksum, error) { return blockChecksum{}, io.EOF } noopClose := func() error { return nil } - dstFile, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return nil, nil, err - } - f, err := os.Open(dstFile + id.partialSuffix()) + f, err := m.opts.fileOps.OpenReader(baseName + id.partialSuffix()) if err != nil { if os.IsNotExist(err) { return noopNext, noopClose, nil @@ -92,25 +88,25 @@ func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (Bl } b := make([]byte, blockSize) // TODO: Pool this? - next = func() (BlockChecksum, error) { + next = func() (blockChecksum, error) { switch n, err := io.ReadFull(f, b); { case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF: - return BlockChecksum{}, redactError(err) + return blockChecksum{}, redactError(err) case n == 0: - return BlockChecksum{}, io.EOF + return blockChecksum{}, io.EOF default: - return BlockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil + return blockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil } } close = f.Close return next, close, nil } -// ResumeReader reads and discards the leading content of r +// resumeReader reads and discards the leading content of r // that matches the content based on the checksums that exist. // It returns the number of bytes consumed, // and returns an [io.Reader] representing the remaining content. -func ResumeReader(r io.Reader, hashNext func() (BlockChecksum, error)) (int64, io.Reader, error) { +func resumeReader(r io.Reader, hashNext func() (blockChecksum, error)) (int64, io.Reader, error) { if hashNext == nil { return 0, r, nil } diff --git a/taildrop/resume_test.go b/feature/taildrop/resume_test.go similarity index 83% rename from taildrop/resume_test.go rename to feature/taildrop/resume_test.go index d366340eb..4e59d401d 100644 --- a/taildrop/resume_test.go +++ b/feature/taildrop/resume_test.go @@ -8,6 +8,7 @@ import ( "io" "math/rand" "os" + "path/filepath" "testing" "testing/iotest" @@ -19,7 +20,9 @@ func TestResume(t *testing.T) { defer func() { blockSize = oldBlockSize }() blockSize = 256 - m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() + dir := t.TempDir() + + m := managerOptions{Logf: t.Logf, fileOps: must.Get(newFileOps(dir))}.New() defer m.Shutdown() rn := rand.New(rand.NewSource(0)) @@ -32,12 +35,12 @@ func TestResume(t *testing.T) { next, close, err := m.HashPartialFile("", "foo") must.Do(err) defer close() - offset, r, err := ResumeReader(r, next) + offset, r, err := resumeReader(r, next) must.Do(err) must.Do(close()) // Windows wants the file handle to be closed to rename it. must.Get(m.PutFile("", "foo", r, offset, -1)) - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "foo")))) + got := must.Get(os.ReadFile(filepath.Join(dir, "foo"))) if !bytes.Equal(got, want) { t.Errorf("content mismatches") } @@ -51,7 +54,7 @@ func TestResume(t *testing.T) { next, close, err := m.HashPartialFile("", "bar") must.Do(err) defer close() - offset, r, err := ResumeReader(r, next) + offset, r, err := resumeReader(r, next) must.Do(err) must.Do(close()) // Windows wants the file handle to be closed to rename it. @@ -66,7 +69,7 @@ func TestResume(t *testing.T) { t.Fatalf("too many iterations to complete the test") } } - got := must.Get(os.ReadFile(must.Get(joinDir(m.opts.Dir, "bar")))) + got := must.Get(os.ReadFile(filepath.Join(dir, "bar"))) if !bytes.Equal(got, want) { t.Errorf("content mismatches") } diff --git a/taildrop/retrieve.go b/feature/taildrop/retrieve.go similarity index 58% rename from taildrop/retrieve.go rename to feature/taildrop/retrieve.go index 3e37b492a..e767bac32 100644 --- a/taildrop/retrieve.go +++ b/feature/taildrop/retrieve.go @@ -9,19 +9,19 @@ import ( "io" "io/fs" "os" - "path/filepath" "runtime" "sort" "time" "tailscale.com/client/tailscale/apitype" - "tailscale.com/logtail/backoff" + "tailscale.com/util/backoff" + "tailscale.com/util/set" ) // HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. // This always returns false when [Handler.DirectFileMode] is false. -func (m *Manager) HasFilesWaiting() (has bool) { - if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { +func (m *manager) HasFilesWaiting() bool { + if m == nil || m.opts.fileOps == nil || m.opts.DirectFileMode { return false } @@ -30,63 +30,66 @@ func (m *Manager) HasFilesWaiting() (has bool) { // has-files-or-not values as the macOS/iOS client might // in the future use+delete the files directly. So only // keep this negative cache. - totalReceived := m.totalReceived.Load() - if totalReceived == m.emptySince.Load() { + total := m.totalReceived.Load() + if total == m.emptySince.Load() { return false } - // Check whether there is at least one one waiting file. - err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true + files, err := m.opts.fileOps.ListFiles() + if err != nil { + return false + } + + // Build a set of filenames present in Dir + fileSet := set.Of(files...) + + for _, filename := range files { + if isPartialOrDeleted(filename) { + continue } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - has = true - return false + if fileSet.Contains(filename + deletedSuffix) { + continue // already handled } + // Found at least one downloadable file return true - }) - - // If there are no more waiting files, record totalReceived as emptySince - // so that we can short-circuit the expensive directory traversal - // if no files have been received after the start of this call. - if err == nil && !has { - m.emptySince.Store(totalReceived) } - return has + + // No waiting files → update negative‑result cache + m.emptySince.Store(total) + return false } // WaitingFiles returns the list of files that have been sent by a // peer that are waiting in [Handler.Dir]. // This always returns nil when [Handler.DirectFileMode] is false. -func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { - if m == nil || m.opts.Dir == "" { +func (m *manager) WaitingFiles() ([]apitype.WaitingFile, error) { + if m == nil || m.opts.fileOps == nil { return nil, ErrNoTaildrop } if m.opts.DirectFileMode { return nil, nil } - if err := rangeDir(m.opts.Dir, func(de fs.DirEntry) bool { - name := de.Name() - if isPartialOrDeleted(name) || !de.Type().IsRegular() { - return true + names, err := m.opts.fileOps.ListFiles() + if err != nil { + return nil, redactError(err) + } + var ret []apitype.WaitingFile + for _, name := range names { + if isPartialOrDeleted(name) { + continue } - _, err := os.Stat(filepath.Join(m.opts.Dir, name+deletedSuffix)) - if os.IsNotExist(err) { - fi, err := de.Info() - if err != nil { - return true - } - ret = append(ret, apitype.WaitingFile{ - Name: filepath.Base(name), - Size: fi.Size(), - }) + // A corresponding .deleted marker means the file was already handled. + if _, err := m.opts.fileOps.Stat(name + deletedSuffix); err == nil { + continue } - return true - }); err != nil { - return nil, redactError(err) + fi, err := m.opts.fileOps.Stat(name) + if err != nil { + continue + } + ret = append(ret, apitype.WaitingFile{ + Name: name, + Size: fi.Size(), + }) } sort.Slice(ret, func(i, j int) bool { return ret[i].Name < ret[j].Name }) return ret, nil @@ -94,22 +97,19 @@ func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { // DeleteFile deletes a file of the given baseName from [Handler.Dir]. // This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) DeleteFile(baseName string) error { - if m == nil || m.opts.Dir == "" { +func (m *manager) DeleteFile(baseName string) error { + if m == nil || m.opts.fileOps == nil { return ErrNoTaildrop } if m.opts.DirectFileMode { return errors.New("deletes not allowed in direct mode") } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return err - } + var bo *backoff.Backoff logf := m.opts.Logf t0 := m.opts.Clock.Now() for { - err := os.Remove(path) + err := m.opts.fileOps.Remove(baseName) if err != nil && !os.IsNotExist(err) { err = redactError(err) // Put a retry loop around deletes on Windows. @@ -129,7 +129,7 @@ func (m *Manager) DeleteFile(baseName string) error { bo.BackOff(context.Background(), err) continue } - if err := touchFile(path + deletedSuffix); err != nil { + if err := m.touchFile(baseName + deletedSuffix); err != nil { logf("peerapi: failed to leave deleted marker: %v", err) } m.deleter.Insert(baseName + deletedSuffix) @@ -141,35 +141,31 @@ func (m *Manager) DeleteFile(baseName string) error { } } -func touchFile(path string) error { - f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0666) +func (m *manager) touchFile(name string) error { + wc, _, err := m.opts.fileOps.OpenWriter(name /* offset= */, 0, 0666) if err != nil { return redactError(err) } - return f.Close() + return wc.Close() } // OpenFile opens a file of the given baseName from [Handler.Dir]. // This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { - if m == nil || m.opts.Dir == "" { +func (m *manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { + if m == nil || m.opts.fileOps == nil { return nil, 0, ErrNoTaildrop } if m.opts.DirectFileMode { return nil, 0, errors.New("opens not allowed in direct mode") } - path, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return nil, 0, err - } - if _, err := os.Stat(path + deletedSuffix); err == nil { - return nil, 0, redactError(&fs.PathError{Op: "open", Path: path, Err: fs.ErrNotExist}) + if _, err := m.opts.fileOps.Stat(baseName + deletedSuffix); err == nil { + return nil, 0, redactError(&fs.PathError{Op: "open", Path: baseName, Err: fs.ErrNotExist}) } - f, err := os.Open(path) + f, err := m.opts.fileOps.OpenReader(baseName) if err != nil { return nil, 0, redactError(err) } - fi, err := f.Stat() + fi, err := m.opts.fileOps.Stat(baseName) if err != nil { f.Close() return nil, 0, redactError(err) diff --git a/feature/taildrop/send.go b/feature/taildrop/send.go new file mode 100644 index 000000000..32ba5f6f0 --- /dev/null +++ b/feature/taildrop/send.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "fmt" + "io" + "sync" + "time" + + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/tstime" + "tailscale.com/version/distro" +) + +type incomingFileKey struct { + id clientID + name string // e.g., "foo.jpeg" +} + +type incomingFile struct { + clock tstime.DefaultClock + + started time.Time + size int64 // or -1 if unknown; never 0 + w io.Writer // underlying writer + sendFileNotify func() // called when done + partialPath string // non-empty in direct mode + finalPath string // not used in direct mode + + mu sync.Mutex + copied int64 + done bool + lastNotify time.Time +} + +func (f *incomingFile) Write(p []byte) (n int, err error) { + n, err = f.w.Write(p) + + var needNotify bool + defer func() { + if needNotify { + f.sendFileNotify() + } + }() + if n > 0 { + f.mu.Lock() + defer f.mu.Unlock() + f.copied += int64(n) + now := f.clock.Now() + if f.lastNotify.IsZero() || now.Sub(f.lastNotify) > time.Second { + f.lastNotify = now + needNotify = true + } + } + return n, err +} + +// PutFile stores a file into [manager.Dir] from a given client id. +// The baseName must be a base filename without any slashes. +// The length is the expected length of content to read from r, +// it may be negative to indicate that it is unknown. +// It returns the length of the entire file. +// +// If there is a failure reading from r, then the partial file is not deleted +// for some period of time. The [manager.PartialFiles] and [manager.HashPartialFile] +// methods may be used to list all partial files and to compute the hash for a +// specific partial file. This allows the client to determine whether to resume +// a partial file. While resuming, PutFile may be called again with a non-zero +// offset to specify where to resume receiving data at. +func (m *manager) PutFile(id clientID, baseName string, r io.Reader, offset, length int64) (fileLength int64, err error) { + + switch { + case m == nil || m.opts.fileOps == nil: + return 0, ErrNoTaildrop + case !envknob.CanTaildrop(): + return 0, ErrNoTaildrop + case distro.Get() == distro.Unraid && !m.opts.DirectFileMode: + return 0, ErrNotAccessible + } + + if err := validateBaseName(baseName); err != nil { + return 0, err + } + + // and make sure we don't delete it while uploading: + m.deleter.Remove(baseName) + + // Create (if not already) the partial file with read-write permissions. + partialName := baseName + id.partialSuffix() + wc, partialPath, err := m.opts.fileOps.OpenWriter(partialName, offset, 0o666) + if err != nil { + return 0, m.redactAndLogError("Create", err) + } + defer func() { + wc.Close() + if err != nil { + m.deleter.Insert(partialName) // mark partial file for eventual deletion + } + }() + + // Check whether there is an in-progress transfer for the file. + inFileKey := incomingFileKey{id, baseName} + inFile, loaded := m.incomingFiles.LoadOrInit(inFileKey, func() *incomingFile { + inFile := &incomingFile{ + clock: m.opts.Clock, + started: m.opts.Clock.Now(), + size: length, + sendFileNotify: m.opts.SendFileNotify, + } + if m.opts.DirectFileMode { + inFile.partialPath = partialPath + } + return inFile + }) + + inFile.w = wc + + if loaded { + return 0, ErrFileExists + } + defer m.incomingFiles.Delete(inFileKey) + + // Record that we have started to receive at least one file. + // This is used by the deleter upon a cold-start to scan the directory + // for any files that need to be deleted. + if st := m.opts.State; st != nil { + if b, _ := st.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { + if werr := st.WriteState(ipn.TaildropReceivedKey, []byte{1}); werr != nil { + m.opts.Logf("WriteState error: %v", werr) // non-fatal error + } + } + } + + // Copy the contents of the file to the writer. + copyLength, err := io.Copy(wc, r) + if err != nil { + return 0, m.redactAndLogError("Copy", err) + } + if length >= 0 && copyLength != length { + return 0, m.redactAndLogError("Copy", fmt.Errorf("copied %d bytes; expected %d", copyLength, length)) + } + if err := wc.Close(); err != nil { + return 0, m.redactAndLogError("Close", err) + } + + fileLength = offset + copyLength + + inFile.mu.Lock() + inFile.done = true + inFile.mu.Unlock() + + // 6) Finalize (rename/move) the partial into place via FileOps.Rename + finalPath, err := m.opts.fileOps.Rename(partialPath, baseName) + if err != nil { + return 0, m.redactAndLogError("Rename", err) + } + inFile.finalPath = finalPath + + m.totalReceived.Add(1) + m.opts.SendFileNotify() + return fileLength, nil +} + +func (m *manager) redactAndLogError(stage string, err error) error { + err = redactError(err) + m.opts.Logf("put %s error: %v", stage, err) + return err +} diff --git a/feature/taildrop/send_test.go b/feature/taildrop/send_test.go new file mode 100644 index 000000000..9ffa5fccc --- /dev/null +++ b/feature/taildrop/send_test.go @@ -0,0 +1,69 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "tailscale.com/tstime" + "tailscale.com/util/must" +) + +func TestPutFile(t *testing.T) { + const content = "hello, world" + + tests := []struct { + name string + directFileMode bool + }{ + {"DirectFileMode", true}, + {"NonDirectFileMode", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + mgr := managerOptions{ + Logf: t.Logf, + Clock: tstime.DefaultClock{}, + State: nil, + fileOps: must.Get(newFileOps(dir)), + DirectFileMode: tt.directFileMode, + SendFileNotify: func() {}, + }.New() + + id := clientID("0") + n, err := mgr.PutFile(id, "file.txt", strings.NewReader(content), 0, int64(len(content))) + if err != nil { + t.Fatalf("PutFile error: %v", err) + } + if n != int64(len(content)) { + t.Errorf("wrote %d bytes; want %d", n, len(content)) + } + + path := filepath.Join(dir, "file.txt") + + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile %q: %v", path, err) + } + if string(got) != content { + t.Errorf("file contents = %q; want %q", string(got), content) + } + + entries, err := os.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + for _, entry := range entries { + if strings.Contains(entry.Name(), ".partial") { + t.Errorf("unexpected partial file left behind: %s", entry.Name()) + } + } + }) + } +} diff --git a/taildrop/taildrop.go b/feature/taildrop/taildrop.go similarity index 77% rename from taildrop/taildrop.go rename to feature/taildrop/taildrop.go index 9ad0e1a7e..6c3deaed1 100644 --- a/taildrop/taildrop.go +++ b/feature/taildrop/taildrop.go @@ -12,15 +12,13 @@ package taildrop import ( "errors" "hash/adler32" - "io" - "io/fs" "os" "path" "path/filepath" "regexp" + "sort" "strconv" "strings" - "sync" "sync/atomic" "unicode" "unicode/utf8" @@ -53,29 +51,24 @@ const ( deletedSuffix = ".deleted" ) -// ClientID is an opaque identifier for file resumption. +// clientID is an opaque identifier for file resumption. // A client can only list and resume partial files for its own ID. // It must contain any filesystem specific characters (e.g., slashes). -type ClientID string // e.g., "n12345CNTRL" +type clientID string // e.g., "n12345CNTRL" -func (id ClientID) partialSuffix() string { +func (id clientID) partialSuffix() string { if id == "" { return partialSuffix } return "." + string(id) + partialSuffix // e.g., ".n12345CNTRL.partial" } -// ManagerOptions are options to configure the [Manager]. -type ManagerOptions struct { +// managerOptions are options to configure the [manager]. +type managerOptions struct { Logf logger.Logf // may be nil Clock tstime.DefaultClock // may be nil State ipn.StateStore // may be nil - // Dir is the directory to store received files. - // This main either be the final location for the files - // or just a temporary staging directory (see DirectFileMode). - Dir string - // DirectFileMode reports whether we are writing files // directly to a download directory, rather than writing them to // a temporary staging directory. @@ -90,6 +83,11 @@ type ManagerOptions struct { // copy them out, and then delete them. DirectFileMode bool + // FileOps abstracts platform-specific file operations needed for file transfers. + // Android's implementation uses the Storage Access Framework, and other platforms + // use fsFileOps. + fileOps FileOps + // SendFileNotify is called periodically while a file is actively // receiving the contents for the file. There is a final call // to the function when reception completes. @@ -97,18 +95,15 @@ type ManagerOptions struct { SendFileNotify func() } -// Manager manages the state for receiving and managing taildropped files. -type Manager struct { - opts ManagerOptions +// manager manages the state for receiving and managing taildropped files. +type manager struct { + opts managerOptions // incomingFiles is a map of files actively being received. incomingFiles syncs.Map[incomingFileKey, *incomingFile] // deleter managers asynchronous deletion of files. deleter fileDeleter - // renameMu is used to protect os.Rename calls so that they are atomic. - renameMu sync.Mutex - // totalReceived counts the cumulative total of received files. totalReceived atomic.Int64 // emptySince specifies that there were no waiting files @@ -119,27 +114,22 @@ type Manager struct { // New initializes a new taildrop manager. // It may spawn asynchronous goroutines to delete files, // so the Shutdown method must be called for resource cleanup. -func (opts ManagerOptions) New() *Manager { +func (opts managerOptions) New() *manager { if opts.Logf == nil { opts.Logf = logger.Discard } if opts.SendFileNotify == nil { opts.SendFileNotify = func() {} } - m := &Manager{opts: opts} + m := &manager{opts: opts} m.deleter.Init(m, func(string) {}) m.emptySince.Store(-1) // invalidate this cache return m } -// Dir returns the directory. -func (m *Manager) Dir() string { - return m.opts.Dir -} - // Shutdown shuts down the Manager. // It blocks until all spawned goroutines have stopped running. -func (m *Manager) Shutdown() { +func (m *manager) Shutdown() { if m != nil { m.deleter.shutdown() m.deleter.group.Wait() @@ -167,68 +157,39 @@ func isPartialOrDeleted(s string) bool { return strings.HasSuffix(s, deletedSuffix) || strings.HasSuffix(s, partialSuffix) } -func joinDir(dir, baseName string) (fullPath string, err error) { - if !utf8.ValidString(baseName) { - return "", ErrInvalidFileName - } - if strings.TrimSpace(baseName) != baseName { - return "", ErrInvalidFileName - } - if len(baseName) > 255 { - return "", ErrInvalidFileName +func validateBaseName(name string) error { + if !utf8.ValidString(name) || + strings.TrimSpace(name) != name || + len(name) > 255 { + return ErrInvalidFileName } // TODO: validate unicode normalization form too? Varies by platform. - clean := path.Clean(baseName) - if clean != baseName || - clean == "." || clean == ".." || - isPartialOrDeleted(clean) { - return "", ErrInvalidFileName + clean := path.Clean(name) + if clean != name || clean == "." || clean == ".." { + return ErrInvalidFileName } - for _, r := range baseName { + if isPartialOrDeleted(name) { + return ErrInvalidFileName + } + for _, r := range name { if !validFilenameRune(r) { - return "", ErrInvalidFileName + return ErrInvalidFileName } } - if !filepath.IsLocal(baseName) { - return "", ErrInvalidFileName - } - return filepath.Join(dir, baseName), nil -} - -// rangeDir iterates over the contents of a directory, calling fn for each entry. -// It continues iterating while fn returns true. -// It reports the number of entries seen. -func rangeDir(dir string, fn func(fs.DirEntry) bool) error { - f, err := os.Open(dir) - if err != nil { - return err - } - defer f.Close() - for { - des, err := f.ReadDir(10) - for _, de := range des { - if !fn(de) { - return nil - } - } - if err != nil { - if err == io.EOF { - return nil - } - return err - } + if !filepath.IsLocal(name) { + return ErrInvalidFileName } + return nil } // IncomingFiles returns a list of active incoming files. -func (m *Manager) IncomingFiles() []ipn.PartialFile { +func (m *manager) IncomingFiles() []ipn.PartialFile { // Make sure we always set n.IncomingFiles non-nil so it gets encoded // in JSON to clients. They distinguish between empty and non-nil // to know whether a Notify should be able about files. files := make([]ipn.PartialFile, 0) - m.incomingFiles.Range(func(k incomingFileKey, f *incomingFile) bool { + for k, f := range m.incomingFiles.All() { f.mu.Lock() - defer f.mu.Unlock() files = append(files, ipn.PartialFile{ Name: k.name, Started: f.started, @@ -238,8 +199,13 @@ func (m *Manager) IncomingFiles() []ipn.PartialFile { FinalPath: f.finalPath, Done: f.done, }) - return true + f.mu.Unlock() + } + + sort.Slice(files, func(i, j int) bool { + return files[i].Started.Before(files[j].Started) }) + return files } @@ -313,12 +279,12 @@ var ( rxNumberSuffix = regexp.MustCompile(` \([0-9]+\)`) ) -// NextFilename returns the next filename in a sequence. +// nextFilename returns the next filename in a sequence. // It is used for construction a new filename if there is a conflict. // // For example, "Foo.jpg" becomes "Foo (1).jpg" and // "Foo (1).jpg" becomes "Foo (2).jpg". -func NextFilename(name string) string { +func nextFilename(name string) string { ext := rxExtensionSuffix.FindString(strings.TrimPrefix(name, ".")) name = strings.TrimSuffix(name, ext) var n uint64 diff --git a/taildrop/taildrop_test.go b/feature/taildrop/taildrop_test.go similarity index 62% rename from taildrop/taildrop_test.go rename to feature/taildrop/taildrop_test.go index df4783c30..0d77273f0 100644 --- a/taildrop/taildrop_test.go +++ b/feature/taildrop/taildrop_test.go @@ -4,40 +4,10 @@ package taildrop import ( - "path/filepath" "strings" "testing" ) -func TestJoinDir(t *testing.T) { - dir := t.TempDir() - tests := []struct { - in string - want string // just relative to m.Dir - wantOk bool - }{ - {"", "", false}, - {"foo", "foo", true}, - {"./foo", "", false}, - {"../foo", "", false}, - {"foo/bar", "", false}, - {"😋", "😋", true}, - {"\xde\xad\xbe\xef", "", false}, - {"foo.partial", "", false}, - {"foo.deleted", "", false}, - {strings.Repeat("a", 1024), "", false}, - {"foo:bar", "", false}, - } - for _, tt := range tests { - got, gotErr := joinDir(dir, tt.in) - got, _ = filepath.Rel(dir, got) - gotOk := gotErr == nil - if got != tt.want || gotOk != tt.wantOk { - t.Errorf("joinDir(%q) = (%v, %v), want (%v, %v)", tt.in, got, gotOk, tt.want, tt.wantOk) - } - } -} - func TestNextFilename(t *testing.T) { tests := []struct { in string @@ -59,11 +29,37 @@ func TestNextFilename(t *testing.T) { } for _, tt := range tests { - if got := NextFilename(tt.in); got != tt.want { + if got := nextFilename(tt.in); got != tt.want { t.Errorf("NextFilename(%q) = %q, want %q", tt.in, got, tt.want) } - if got2 := NextFilename(tt.want); got2 != tt.want2 { + if got2 := nextFilename(tt.want); got2 != tt.want2 { t.Errorf("NextFilename(%q) = %q, want %q", tt.want, got2, tt.want2) } } } + +func TestValidateBaseName(t *testing.T) { + tests := []struct { + in string + wantOk bool + }{ + {"", false}, + {"foo", true}, + {"./foo", false}, + {"../foo", false}, + {"foo/bar", false}, + {"😋", true}, + {"\xde\xad\xbe\xef", false}, + {"foo.partial", false}, + {"foo.deleted", false}, + {strings.Repeat("a", 1024), false}, + {"foo:bar", false}, + } + for _, tt := range tests { + err := validateBaseName(tt.in) + gotOk := err == nil + if gotOk != tt.wantOk { + t.Errorf("validateBaseName(%q) = %v, wantOk = %v", tt.in, err, tt.wantOk) + } + } +} diff --git a/feature/taildrop/target_test.go b/feature/taildrop/target_test.go new file mode 100644 index 000000000..57c96a77a --- /dev/null +++ b/feature/taildrop/target_test.go @@ -0,0 +1,73 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "fmt" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" +) + +func TestFileTargets(t *testing.T) { + e := new(Extension) + + _, err := e.FileTargets() + if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { + t.Errorf("before connect: got %q; want %q", got, want) + } + + e.nodeBackendForTest = testNodeBackend{peers: nil} + + _, err = e.FileTargets() + if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { + t.Errorf("non-running netmap: got %q; want %q", got, want) + } + + e.backendState = ipn.Running + _, err = e.FileTargets() + if got, want := fmt.Sprint(err), "file sharing not enabled by Tailscale admin"; got != want { + t.Errorf("without cap: got %q; want %q", got, want) + } + + e.capFileSharing = true + got, err := e.FileTargets() + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Fatalf("unexpected %d peers", len(got)) + } + + var nodeID tailcfg.NodeID = 1234 + peer := &tailcfg.Node{ + ID: nodeID, + Hostinfo: (&tailcfg.Hostinfo{OS: "tvOS"}).View(), + } + e.nodeBackendForTest = testNodeBackend{peers: []tailcfg.NodeView{peer.View()}} + + got, err = e.FileTargets() + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Fatalf("unexpected %d peers", len(got)) + } +} + +type testNodeBackend struct { + ipnext.NodeBackend + peers []tailcfg.NodeView +} + +func (t testNodeBackend) AppendMatchingPeers(peers []tailcfg.NodeView, f func(tailcfg.NodeView) bool) []tailcfg.NodeView { + for _, p := range t.peers { + if f(p) { + peers = append(peers, p) + } + } + return peers +} diff --git a/net/tstun/tap_linux.go b/feature/tap/tap_linux.go similarity index 66% rename from net/tstun/tap_linux.go rename to feature/tap/tap_linux.go index c721e6e27..53dcabc36 100644 --- a/net/tstun/tap_linux.go +++ b/feature/tap/tap_linux.go @@ -1,11 +1,12 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ts_omit_tap - -package tstun +// Package tap registers Tailscale's experimental (demo) Linux TAP (Layer 2) support. +package tap import ( + "bytes" + "errors" "fmt" "net" "net/netip" @@ -20,11 +21,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" + "tailscale.com/syncs" "tailscale.com/types/ipproto" - "tailscale.com/util/multierr" + "tailscale.com/types/logger" ) // TODO: this was randomly generated once. Maybe do it per process start? But @@ -33,15 +38,19 @@ import ( // For now just hard code it. var ourMAC = net.HardwareAddr{0x30, 0x2D, 0x66, 0xEC, 0x7A, 0x93} -func init() { createTAP = createTAPLinux } +const tapDebug = tstun.TAPDebug + +func init() { + tstun.CreateTAP.Set(createTAPLinux) +} -func createTAPLinux(tapName, bridgeName string) (tun.Device, error) { +func createTAPLinux(logf logger.Logf, tapName, bridgeName string) (tun.Device, error) { fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0) if err != nil { return nil, err } - dev, err := openDevice(fd, tapName, bridgeName) + dev, err := openDevice(logf, fd, tapName, bridgeName) if err != nil { unix.Close(fd) return nil, err @@ -50,7 +59,7 @@ func createTAPLinux(tapName, bridgeName string) (tun.Device, error) { return dev, nil } -func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { +func openDevice(logf logger.Logf, fd int, tapName, bridgeName string) (tun.Device, error) { ifr, err := unix.NewIfreq(tapName) if err != nil { return nil, err @@ -71,7 +80,7 @@ func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { } } - return newTAPDevice(fd, tapName) + return newTAPDevice(logf, fd, tapName) } type etherType [2]byte @@ -82,7 +91,10 @@ var ( etherTypeIPv6 = etherType{0x86, 0xDD} ) -const ipv4HeaderLen = 20 +const ( + ipv4HeaderLen = 20 + ethernetFrameSize = 14 // 2 six byte MACs, 2 bytes ethertype +) const ( consumePacket = true @@ -91,7 +103,7 @@ const ( // handleTAPFrame handles receiving a raw TAP ethernet frame and reports whether // it's been handled (that is, whether it should NOT be passed to wireguard). -func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { +func (t *tapDevice) handleTAPFrame(ethBuf []byte) bool { if len(ethBuf) < ethernetFrameSize { // Corrupt. Ignore. @@ -154,7 +166,7 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { // If the client's asking about their own IP, tell them it's // their own MAC. TODO(bradfitz): remove String allocs. - if net.IP(req.ProtocolAddressTarget()).String() == theClientIP { + if net.IP(req.ProtocolAddressTarget()).String() == t.clientIPv4.Load() { copy(res.HardwareAddressSender(), ethSrcMAC) } else { copy(res.HardwareAddressSender(), ourMAC[:]) @@ -164,8 +176,7 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { copy(res.HardwareAddressTarget(), req.HardwareAddressSender()) copy(res.ProtocolAddressTarget(), req.ProtocolAddressSender()) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{buf}, 0) + n, err := t.WriteEthernet(buf) if tapDebug { t.logf("tap: wrote ARP reply %v, %v", n, err) } @@ -175,14 +186,22 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { } } -// TODO(bradfitz): remove these hard-coded values and move from a /24 to a /10 CGNAT as the range. -const theClientIP = "100.70.145.3" // TODO: make dynamic from netmap -const routerIP = "100.70.145.1" // must be in same netmask (currently hack at /24) as theClientIP +var ( + // routerIP is the IP address of the DHCP server. + routerIP = net.ParseIP(tsaddr.TailscaleServiceIPString) + // cgnatNetMask is the netmask of the 100.64.0.0/10 CGNAT range. + cgnatNetMask = net.IPMask(net.ParseIP("255.192.0.0").To4()) +) + +// parsedPacketPool holds a pool of Parsed structs for use in filtering. +// This is needed because escape analysis cannot see that parsed packets +// do not escape through {Pre,Post}Filter{In,Out}. +var parsedPacketPool = sync.Pool{New: func() any { return new(packet.Parsed) }} // handleDHCPRequest handles receiving a raw TAP ethernet frame and reports whether // it's been handled as a DHCP request. That is, it reports whether the frame should // be ignored by the caller and not passed on. -func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { +func (t *tapDevice) handleDHCPRequest(ethBuf []byte) bool { const udpHeader = 8 if len(ethBuf) < ethernetFrameSize+ipv4HeaderLen+udpHeader { if tapDebug { @@ -207,7 +226,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { if p.IPProto != ipproto.UDP || p.Src.Port() != 68 || p.Dst.Port() != 67 { // Not a DHCP request. if tapDebug { - t.logf("tap: DHCP wrong meta") + t.logf("tap: DHCP wrong meta: %+v", p) } return passOnPacket } @@ -225,17 +244,22 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { } switch dp.MessageType() { case dhcpv4.MessageTypeDiscover: + ips := t.clientIPv4.Load() + if ips == "" { + t.logf("tap: DHCP no client IP") + return consumePacket + } offer, err := dhcpv4.New( dhcpv4.WithReply(dp), dhcpv4.WithMessageType(dhcpv4.MessageTypeOffer), - dhcpv4.WithRouter(net.ParseIP(routerIP)), // the default route - dhcpv4.WithDNS(net.ParseIP("100.100.100.100")), - dhcpv4.WithServerIP(net.ParseIP("100.100.100.100")), // TODO: what is this? - dhcpv4.WithOption(dhcpv4.OptServerIdentifier(net.ParseIP("100.100.100.100"))), - dhcpv4.WithYourIP(net.ParseIP(theClientIP)), + dhcpv4.WithRouter(routerIP), // the default route + dhcpv4.WithDNS(routerIP), + dhcpv4.WithServerIP(routerIP), // TODO: what is this? + dhcpv4.WithOption(dhcpv4.OptServerIdentifier(routerIP)), + dhcpv4.WithYourIP(net.ParseIP(ips)), dhcpv4.WithLeaseTime(3600), // hour works //dhcpv4.WithHwAddr(ethSrcMAC), - dhcpv4.WithNetmask(net.IPMask(net.ParseIP("255.255.255.0").To4())), // TODO: wrong + dhcpv4.WithNetmask(cgnatNetMask), //dhcpv4.WithTransactionID(dp.TransactionID), ) if err != nil { @@ -250,22 +274,26 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{pkt}, 0) + n, err := t.WriteEthernet(pkt) if tapDebug { t.logf("tap: wrote DHCP OFFER %v, %v", n, err) } case dhcpv4.MessageTypeRequest: + ips := t.clientIPv4.Load() + if ips == "" { + t.logf("tap: DHCP no client IP") + return consumePacket + } ack, err := dhcpv4.New( dhcpv4.WithReply(dp), dhcpv4.WithMessageType(dhcpv4.MessageTypeAck), - dhcpv4.WithDNS(net.ParseIP("100.100.100.100")), - dhcpv4.WithRouter(net.ParseIP(routerIP)), // the default route - dhcpv4.WithServerIP(net.ParseIP("100.100.100.100")), // TODO: what is this? - dhcpv4.WithOption(dhcpv4.OptServerIdentifier(net.ParseIP("100.100.100.100"))), - dhcpv4.WithYourIP(net.ParseIP(theClientIP)), // Hello world - dhcpv4.WithLeaseTime(3600), // hour works - dhcpv4.WithNetmask(net.IPMask(net.ParseIP("255.255.255.0").To4())), + dhcpv4.WithDNS(routerIP), + dhcpv4.WithRouter(routerIP), // the default route + dhcpv4.WithServerIP(routerIP), // TODO: what is this? + dhcpv4.WithOption(dhcpv4.OptServerIdentifier(routerIP)), + dhcpv4.WithYourIP(net.ParseIP(ips)), // Hello world + dhcpv4.WithLeaseTime(3600), // hour works + dhcpv4.WithNetmask(cgnatNetMask), ) if err != nil { t.logf("error building DHCP ack: %v", err) @@ -278,8 +306,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(100, 100, 100, 100), 67), // src netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{pkt}, 0) + n, err := t.WriteEthernet(pkt) if tapDebug { t.logf("tap: wrote DHCP ACK %v, %v", n, err) } @@ -291,6 +318,16 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { return consumePacket } +func writeEthernetFrame(buf []byte, srcMAC, dstMAC net.HardwareAddr, proto tcpip.NetworkProtocolNumber) { + // Ethernet header + eth := header.Ethernet(buf) + eth.Encode(&header.EthernetFields{ + SrcAddr: tcpip.LinkAddress(srcMAC), + DstAddr: tcpip.LinkAddress(dstMAC), + Type: proto, + }) +} + func packLayer2UDP(payload []byte, srcMAC, dstMAC net.HardwareAddr, src, dst netip.AddrPort) []byte { buf := make([]byte, header.EthernetMinimumSize+header.UDPMinimumSize+header.IPv4MinimumSize+len(payload)) payloadStart := len(buf) - len(payload) @@ -300,12 +337,7 @@ func packLayer2UDP(payload []byte, srcMAC, dstMAC net.HardwareAddr, src, dst net dstB := dst.Addr().As4() dstIP := tcpip.AddrFromSlice(dstB[:]) // Ethernet header - eth := header.Ethernet(buf) - eth.Encode(&header.EthernetFields{ - SrcAddr: tcpip.LinkAddress(srcMAC), - DstAddr: tcpip.LinkAddress(dstMAC), - Type: ipv4.ProtocolNumber, - }) + writeEthernetFrame(buf, srcMAC, dstMAC, ipv4.ProtocolNumber) // IP header ipbuf := buf[header.EthernetMinimumSize:] ip := header.IPv4(ipbuf) @@ -342,17 +374,18 @@ func run(prog string, args ...string) error { return nil } -func (t *Wrapper) destMAC() [6]byte { +func (t *tapDevice) destMAC() [6]byte { return t.destMACAtomic.Load() } -func newTAPDevice(fd int, tapName string) (tun.Device, error) { +func newTAPDevice(logf logger.Logf, fd int, tapName string) (tun.Device, error) { err := unix.SetNonblock(fd, true) if err != nil { return nil, err } file := os.NewFile(uintptr(fd), "/dev/tap") d := &tapDevice{ + logf: logf, file: file, events: make(chan tun.Event), name: tapName, @@ -360,20 +393,22 @@ func newTAPDevice(fd int, tapName string) (tun.Device, error) { return d, nil } -var ( - _ setWrapperer = &tapDevice{} -) - type tapDevice struct { - file *os.File - events chan tun.Event - name string - wrapper *Wrapper - closeOnce sync.Once + file *os.File + logf func(format string, args ...any) + events chan tun.Event + name string + closeOnce sync.Once + clientIPv4 syncs.AtomicValue[string] + + destMACAtomic syncs.AtomicValue[[6]byte] } -func (t *tapDevice) setWrapper(wrapper *Wrapper) { - t.wrapper = wrapper +var _ tstun.SetIPer = (*tapDevice)(nil) + +func (t *tapDevice) SetIP(ipV4, ipV6TODO netip.Addr) error { + t.clientIPv4.Store(ipV4.String()) + return nil } func (t *tapDevice) File() *os.File { @@ -384,43 +419,70 @@ func (t *tapDevice) Name() (string, error) { return t.name, nil } +// Read reads an IP packet from the TAP device. It strips the ethernet frame header. func (t *tapDevice) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + n, err := t.ReadEthernet(buffs, sizes, offset) + if err != nil || n == 0 { + return n, err + } + // Strip the ethernet frame header. + copy(buffs[0][offset:], buffs[0][offset+ethernetFrameSize:offset+sizes[0]]) + sizes[0] -= ethernetFrameSize + return 1, nil +} + +// ReadEthernet reads a raw ethernet frame from the TAP device. +func (t *tapDevice) ReadEthernet(buffs [][]byte, sizes []int, offset int) (int, error) { n, err := t.file.Read(buffs[0][offset:]) if err != nil { return 0, err } + if t.handleTAPFrame(buffs[0][offset : offset+n]) { + return 0, nil + } sizes[0] = n return 1, nil } +// WriteEthernet writes a raw ethernet frame to the TAP device. +func (t *tapDevice) WriteEthernet(buf []byte) (int, error) { + return t.file.Write(buf) +} + +// ethBufPool holds a pool of bytes.Buffers for use in [tapDevice.Write]. +var ethBufPool = syncs.Pool[*bytes.Buffer]{New: func() *bytes.Buffer { return new(bytes.Buffer) }} + +// Write writes a raw IP packet to the TAP device. It adds the ethernet frame header. func (t *tapDevice) Write(buffs [][]byte, offset int) (int, error) { errs := make([]error, 0) wrote := 0 + m := t.destMAC() + dstMac := net.HardwareAddr(m[:]) + buf := ethBufPool.Get() + defer ethBufPool.Put(buf) for _, buff := range buffs { - if offset < ethernetFrameSize { - errs = append(errs, fmt.Errorf("[unexpected] weird offset %d for TAP write", offset)) - return 0, multierr.New(errs...) - } - eth := buff[offset-ethernetFrameSize:] - dst := t.wrapper.destMAC() - copy(eth[:6], dst[:]) - copy(eth[6:12], ourMAC[:]) - et := etherTypeIPv4 - if buff[offset]>>4 == 6 { - et = etherTypeIPv6 + buf.Reset() + buf.Grow(header.EthernetMinimumSize + len(buff) - offset) + + var ebuf [14]byte + switch buff[offset] >> 4 { + case 4: + writeEthernetFrame(ebuf[:], ourMAC, dstMac, ipv4.ProtocolNumber) + case 6: + writeEthernetFrame(ebuf[:], ourMAC, dstMac, ipv6.ProtocolNumber) + default: + continue } - eth[12], eth[13] = et[0], et[1] - if tapDebug { - t.wrapper.logf("tap: tapWrite off=%v % x", offset, buff) - } - _, err := t.file.Write(buff[offset-ethernetFrameSize:]) + buf.Write(ebuf[:]) + buf.Write(buff[offset:]) + _, err := t.WriteEthernet(buf.Bytes()) if err != nil { errs = append(errs, err) } else { wrote++ } } - return wrote, multierr.New(errs...) + return wrote, errors.Join(errs...) } func (t *tapDevice) MTU() (int, error) { @@ -428,8 +490,7 @@ func (t *tapDevice) MTU() (int, error) { if err != nil { return 0, err } - err = unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr) - if err != nil { + if err := unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr); err != nil { return 0, err } return int(ifr.Uint32()), nil diff --git a/feature/tpm/attestation.go b/feature/tpm/attestation.go new file mode 100644 index 000000000..197a8d6b8 --- /dev/null +++ b/feature/tpm/attestation.go @@ -0,0 +1,309 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "crypto" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "sync" + + "github.com/google/go-tpm/tpm2" + "github.com/google/go-tpm/tpm2/transport" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" + "tailscale.com/types/key" +) + +type attestationKey struct { + tpmMu sync.Mutex + tpm transport.TPMCloser + // private and public parts of the TPM key as returned from tpm2.Create. + // These are used for serialization. + tpmPrivate tpm2.TPM2BPrivate + tpmPublic tpm2.TPM2BPublic + // handle of the loaded TPM key. + handle *tpm2.NamedHandle + // pub is the parsed *ecdsa.PublicKey. + pub crypto.PublicKey +} + +func newAttestationKey() (ak *attestationKey, retErr error) { + tpm, err := open() + if err != nil { + return nil, key.ErrUnsupported + } + defer func() { + if retErr != nil { + tpm.Close() + } + }() + ak = &attestationKey{tpm: tpm} + + // Create a key under the storage hierarchy. + if err := withSRK(log.Printf, ak.tpm, func(srk tpm2.AuthHandle) error { + resp, err := tpm2.Create{ + ParentHandle: tpm2.NamedHandle{ + Handle: srk.Handle, + Name: srk.Name, + }, + InPublic: tpm2.New2B( + tpm2.TPMTPublic{ + Type: tpm2.TPMAlgECC, + NameAlg: tpm2.TPMAlgSHA256, + ObjectAttributes: tpm2.TPMAObject{ + SensitiveDataOrigin: true, + UserWithAuth: true, + AdminWithPolicy: true, + // We don't set an authorization policy on this key, so + // DA isn't helpful. + NoDA: true, + FixedTPM: true, + FixedParent: true, + SignEncrypt: true, + }, + Parameters: tpm2.NewTPMUPublicParms( + tpm2.TPMAlgECC, + &tpm2.TPMSECCParms{ + CurveID: tpm2.TPMECCNistP256, + Scheme: tpm2.TPMTECCScheme{ + Scheme: tpm2.TPMAlgECDSA, + Details: tpm2.NewTPMUAsymScheme( + tpm2.TPMAlgECDSA, + &tpm2.TPMSSigSchemeECDSA{ + // Unfortunately, TPMs don't let us use + // TPMAlgNull here to make the hash + // algorithm dynamic higher in the + // stack. We have to hardcode it here. + HashAlg: tpm2.TPMAlgSHA256, + }, + ), + }, + }, + ), + }, + ), + }.Execute(ak.tpm) + if err != nil { + return fmt.Errorf("tpm2.Create: %w", err) + } + ak.tpmPrivate = resp.OutPrivate + ak.tpmPublic = resp.OutPublic + return nil + }); err != nil { + return nil, err + } + return ak, ak.load() +} + +func (ak *attestationKey) loaded() bool { + return ak.tpm != nil && ak.handle != nil && ak.pub != nil +} + +// load the key into the TPM from its public/private components. Must be called +// before Sign or Public. +func (ak *attestationKey) load() error { + if ak.loaded() { + return nil + } + if len(ak.tpmPrivate.Buffer) == 0 || len(ak.tpmPublic.Bytes()) == 0 { + return fmt.Errorf("attestationKey.load called without tpmPrivate or tpmPublic") + } + return withSRK(log.Printf, ak.tpm, func(srk tpm2.AuthHandle) error { + resp, err := tpm2.Load{ + ParentHandle: tpm2.NamedHandle{ + Handle: srk.Handle, + Name: srk.Name, + }, + InPrivate: ak.tpmPrivate, + InPublic: ak.tpmPublic, + }.Execute(ak.tpm) + if err != nil { + return fmt.Errorf("tpm2.Load: %w", err) + } + + ak.handle = &tpm2.NamedHandle{ + Handle: resp.ObjectHandle, + Name: resp.Name, + } + pub, err := ak.tpmPublic.Contents() + if err != nil { + return err + } + ak.pub, err = tpm2.Pub(*pub) + return err + }) +} + +// attestationKeySerialized is the JSON-serialized representation of +// attestationKey. +type attestationKeySerialized struct { + TPMPrivate []byte `json:"tpmPrivate"` + TPMPublic []byte `json:"tpmPublic"` +} + +// MarshalJSON implements json.Marshaler. +func (ak *attestationKey) MarshalJSON() ([]byte, error) { + if ak == nil || len(ak.tpmPublic.Bytes()) == 0 || len(ak.tpmPrivate.Buffer) == 0 { + return []byte("null"), nil + } + return json.Marshal(attestationKeySerialized{ + TPMPublic: ak.tpmPublic.Bytes(), + TPMPrivate: ak.tpmPrivate.Buffer, + }) +} + +// UnmarshalJSON implements json.Unmarshaler. +func (ak *attestationKey) UnmarshalJSON(data []byte) (retErr error) { + var aks attestationKeySerialized + if err := json.Unmarshal(data, &aks); err != nil { + return err + } + + ak.tpmPrivate = tpm2.TPM2BPrivate{Buffer: aks.TPMPrivate} + ak.tpmPublic = tpm2.BytesAs2B[tpm2.TPMTPublic, *tpm2.TPMTPublic](aks.TPMPublic) + + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + if ak.tpm != nil { + ak.tpm.Close() + ak.tpm = nil + } + + tpm, err := open() + if err != nil { + return key.ErrUnsupported + } + defer func() { + if retErr != nil { + tpm.Close() + } + }() + ak.tpm = tpm + + return ak.load() +} + +func (ak *attestationKey) Public() crypto.PublicKey { + return ak.pub +} + +func (ak *attestationKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + + if !ak.loaded() { + return nil, errors.New("tpm2 attestation key is not loaded during Sign") + } + // Unfortunately, TPMs don't let us make keys with dynamic hash algorithms. + // The hash algorithm is fixed at key creation time (tpm2.Create). + if opts != crypto.SHA256 { + return nil, fmt.Errorf("tpm2 key is restricted to SHA256, have %q", opts) + } + resp, err := tpm2.Sign{ + KeyHandle: ak.handle, + Digest: tpm2.TPM2BDigest{ + Buffer: digest, + }, + InScheme: tpm2.TPMTSigScheme{ + Scheme: tpm2.TPMAlgECDSA, + Details: tpm2.NewTPMUSigScheme( + tpm2.TPMAlgECDSA, + &tpm2.TPMSSchemeHash{ + HashAlg: tpm2.TPMAlgSHA256, + }, + ), + }, + Validation: tpm2.TPMTTKHashCheck{ + Tag: tpm2.TPMSTHashCheck, + }, + }.Execute(ak.tpm) + if err != nil { + return nil, fmt.Errorf("tpm2.Sign: %w", err) + } + sig, err := resp.Signature.Signature.ECDSA() + if err != nil { + return nil, err + } + return encodeSignature(sig.SignatureR.Buffer, sig.SignatureS.Buffer) +} + +// Copied from crypto/ecdsa. +func encodeSignature(r, s []byte) ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + addASN1IntBytes(b, r) + addASN1IntBytes(b, s) + }) + return b.Bytes() +} + +// addASN1IntBytes encodes in ASN.1 a positive integer represented as +// a big-endian byte slice with zero or more leading zeroes. +func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { + for len(bytes) > 0 && bytes[0] == 0 { + bytes = bytes[1:] + } + if len(bytes) == 0 { + b.SetError(errors.New("invalid integer")) + return + } + b.AddASN1(asn1.INTEGER, func(c *cryptobyte.Builder) { + if bytes[0]&0x80 != 0 { + c.AddUint8(0) + } + c.AddBytes(bytes) + }) +} + +func (ak *attestationKey) Close() error { + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + + var errs []error + if ak.handle != nil && ak.tpm != nil { + _, err := tpm2.FlushContext{FlushHandle: ak.handle.Handle}.Execute(ak.tpm) + errs = append(errs, err) + } + if ak.tpm != nil { + errs = append(errs, ak.tpm.Close()) + } + return errors.Join(errs...) +} + +func (ak *attestationKey) Clone() key.HardwareAttestationKey { + if ak.IsZero() { + return nil + } + + tpm, err := open() + if err != nil { + log.Printf("[unexpected] failed to open a TPM connection in feature/tpm.attestationKey.Clone: %v", err) + return nil + } + akc := &attestationKey{ + tpm: tpm, + tpmPrivate: ak.tpmPrivate, + tpmPublic: ak.tpmPublic, + } + if err := akc.load(); err != nil { + log.Printf("[unexpected] failed to load TPM key in feature/tpm.attestationKey.Clone: %v", err) + tpm.Close() + return nil + } + return akc +} + +func (ak *attestationKey) IsZero() bool { + if ak == nil { + return true + } + + ak.tpmMu.Lock() + defer ak.tpmMu.Unlock() + return !ak.loaded() +} diff --git a/feature/tpm/attestation_test.go b/feature/tpm/attestation_test.go new file mode 100644 index 000000000..e7ff72987 --- /dev/null +++ b/feature/tpm/attestation_test.go @@ -0,0 +1,164 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "runtime" + "sync" + "testing" +) + +func TestAttestationKeySign(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak.Close(); err != nil { + t.Errorf("ak.Close: %v", err) + } + }) + + data := []byte("secrets") + digest := sha256.Sum256(data) + + // Check signature/validation round trip. + sig, err := ak.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } + + // Create a different key. + ak2, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak2.Close(); err != nil { + t.Errorf("ak2.Close: %v", err) + } + }) + + // Make sure that the keys are distinct via their public keys and the + // signatures they produce. + if ak.Public().(*ecdsa.PublicKey).Equal(ak2.Public()) { + t.Errorf("public keys of distinct attestation keys are the same") + } + sig2, err := ak2.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if bytes.Equal(sig, sig2) { + t.Errorf("signatures from distinct attestation keys are the same") + } +} + +func TestAttestationKeySignConcurrent(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak.Close(); err != nil { + t.Errorf("ak.Close: %v", err) + } + }) + + data := []byte("secrets") + digest := sha256.Sum256(data) + + wg := sync.WaitGroup{} + for range runtime.GOMAXPROCS(-1) { + wg.Go(func() { + // Check signature/validation round trip. + sig, err := ak.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } + }) + } + wg.Wait() +} + +func TestAttestationKeyUnmarshal(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak.Close(); err != nil { + t.Errorf("ak.Close: %v", err) + } + }) + + buf, err := ak.MarshalJSON() + if err != nil { + t.Fatal(err) + } + var ak2 attestationKey + if err := json.Unmarshal(buf, &ak2); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ak2.Close(); err != nil { + t.Errorf("ak2.Close: %v", err) + } + }) + + if !ak2.loaded() { + t.Error("unmarshalled key is not loaded") + } + + if !ak.Public().(*ecdsa.PublicKey).Equal(ak2.Public()) { + t.Error("unmarshalled public key is not the same as the original public key") + } +} + +func TestAttestationKeyClone(t *testing.T) { + skipWithoutTPM(t) + ak, err := newAttestationKey() + if err != nil { + t.Fatal(err) + } + + ak2 := ak.Clone() + if ak2 == nil { + t.Fatal("Clone failed") + } + t.Cleanup(func() { + if err := ak2.Close(); err != nil { + t.Errorf("ak2.Close: %v", err) + } + }) + // Close the original key, ak2 should remain open and usable. + if err := ak.Close(); err != nil { + t.Fatal(err) + } + + data := []byte("secrets") + digest := sha256.Sum256(data) + // Check signature/validation round trip using cloned key. + sig, err := ak2.Sign(rand.Reader, digest[:], crypto.SHA256) + if err != nil { + t.Fatal(err) + } + if !ecdsa.VerifyASN1(ak2.Public().(*ecdsa.PublicKey), digest[:], sig) { + t.Errorf("ecdsa.VerifyASN1 failed") + } +} diff --git a/feature/tpm/tpm.go b/feature/tpm/tpm.go new file mode 100644 index 000000000..8df269b95 --- /dev/null +++ b/feature/tpm/tpm.go @@ -0,0 +1,480 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tpm implements support for TPM 2.0 devices. +package tpm + +import ( + "bytes" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "iter" + "log" + "os" + "path/filepath" + "runtime" + "slices" + "strings" + "sync" + + "github.com/google/go-tpm/tpm2" + "github.com/google/go-tpm/tpm2/transport" + "golang.org/x/crypto/nacl/secretbox" + "tailscale.com/atomicfile" + "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/hostinfo" + "tailscale.com/ipn" + "tailscale.com/ipn/store" + "tailscale.com/paths" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/util/testenv" +) + +var ( + infoOnce = sync.OnceValue(info) + tpmSupportedOnce = sync.OnceValue(tpmSupported) +) + +func init() { + feature.Register("tpm") + feature.HookTPMAvailable.Set(tpmSupportedOnce) + feature.HookHardwareAttestationAvailable.Set(tpmSupportedOnce) + + hostinfo.RegisterHostinfoNewHook(func(hi *tailcfg.Hostinfo) { + hi.TPM = infoOnce() + }) + store.Register(store.TPMPrefix, newStore) + if runtime.GOOS == "linux" || runtime.GOOS == "windows" { + key.RegisterHardwareAttestationKeyFns( + func() key.HardwareAttestationKey { return &attestationKey{} }, + func() (key.HardwareAttestationKey, error) { return newAttestationKey() }, + ) + } +} + +func tpmSupported() bool { + hi := infoOnce() + if hi == nil { + return false + } + if hi.FamilyIndicator != "2.0" { + return false + } + + tpm, err := open() + if err != nil { + return false + } + defer tpm.Close() + + if err := withSRK(logger.Discard, tpm, func(srk tpm2.AuthHandle) error { + return nil + }); err != nil { + return false + } + return true +} + +var verboseTPM = envknob.RegisterBool("TS_DEBUG_TPM") + +func info() *tailcfg.TPMInfo { + logf := logger.Discard + if !testenv.InTest() || verboseTPM() { + logf = log.New(log.Default().Writer(), "TPM: ", 0).Printf + } + + tpm, err := open() + if err != nil { + if !os.IsNotExist(err) || verboseTPM() { + // Only log if it's an interesting error, not just "no TPM", + // as is very common, especially in VMs. + logf("error opening: %v", err) + } + return nil + } + if verboseTPM() { + logf("successfully opened") + } + defer tpm.Close() + + info := new(tailcfg.TPMInfo) + toStr := func(s *string) func(*tailcfg.TPMInfo, uint32) { + return func(info *tailcfg.TPMInfo, value uint32) { + *s += propToString(value) + } + } + for _, cap := range []struct { + prop tpm2.TPMPT + apply func(info *tailcfg.TPMInfo, value uint32) + }{ + {tpm2.TPMPTManufacturer, toStr(&info.Manufacturer)}, + {tpm2.TPMPTVendorString1, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString2, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString3, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString4, toStr(&info.Vendor)}, + {tpm2.TPMPTRevision, func(info *tailcfg.TPMInfo, value uint32) { info.SpecRevision = int(value) }}, + {tpm2.TPMPTVendorTPMType, func(info *tailcfg.TPMInfo, value uint32) { info.Model = int(value) }}, + {tpm2.TPMPTFirmwareVersion1, func(info *tailcfg.TPMInfo, value uint32) { info.FirmwareVersion += uint64(value) << 32 }}, + {tpm2.TPMPTFirmwareVersion2, func(info *tailcfg.TPMInfo, value uint32) { info.FirmwareVersion += uint64(value) }}, + {tpm2.TPMPTFamilyIndicator, toStr(&info.FamilyIndicator)}, + } { + resp, err := tpm2.GetCapability{ + Capability: tpm2.TPMCapTPMProperties, + Property: uint32(cap.prop), + PropertyCount: 1, + }.Execute(tpm) + if err != nil { + logf("GetCapability %v: %v", cap.prop, err) + continue + } + props, err := resp.CapabilityData.Data.TPMProperties() + if err != nil { + logf("GetCapability %v: %v", cap.prop, err) + continue + } + if len(props.TPMProperty) == 0 { + continue + } + cap.apply(info, props.TPMProperty[0].Value) + } + logf("successfully read all properties") + return info +} + +// propToString converts TPM_PT property value, which is a uint32, into a +// string of up to 4 ASCII characters. This encoding applies only to some +// properties, see +// https://trustedcomputinggroup.org/resource/tpm-library-specification/ Part +// 2, section 6.13. +func propToString(v uint32) string { + chars := []byte{ + byte(v >> 24), + byte(v >> 16), + byte(v >> 8), + byte(v), + } + // Delete any non-printable ASCII characters. + return string(slices.DeleteFunc(chars, func(b byte) bool { return b < ' ' || b > '~' })) +} + +func newStore(logf logger.Logf, path string) (ipn.StateStore, error) { + path = strings.TrimPrefix(path, store.TPMPrefix) + if err := paths.MkStateDir(filepath.Dir(path)); err != nil { + return nil, fmt.Errorf("creating state directory: %w", err) + } + var parsed map[ipn.StateKey][]byte + bs, err := os.ReadFile(path) + if err != nil { + if !os.IsNotExist(err) { + return nil, fmt.Errorf("failed to open %q: %w", path, err) + } + logf("tpm.newStore: initializing state file") + + var key [32]byte + // crypto/rand.Read never returns an error. + rand.Read(key[:]) + + store := &tpmStore{ + logf: logf, + path: path, + key: key, + cache: make(map[ipn.StateKey][]byte), + } + if err := store.writeSealed(); err != nil { + return nil, fmt.Errorf("failed to write initial state file: %w", err) + } + return store, nil + } + + // State file exists, unseal and parse it. + var sealed encryptedData + if err := json.Unmarshal(bs, &sealed); err != nil { + return nil, fmt.Errorf("failed to unmarshal state file: %w", err) + } + if len(sealed.Data) == 0 || sealed.Key == nil || len(sealed.Nonce) == 0 { + return nil, fmt.Errorf("state file %q has not been TPM-sealed or is corrupt", path) + } + data, err := unseal(logf, sealed) + if err != nil { + return nil, fmt.Errorf("failed to unseal state file: %w", err) + } + if err := json.Unmarshal(data.Data, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse state file: %w", err) + } + return &tpmStore{ + logf: logf, + path: path, + key: data.Key, + cache: parsed, + }, nil +} + +// tpmStore is an ipn.StateStore that stores the state in a secretbox-encrypted +// file using a TPM-sealed symmetric key. +type tpmStore struct { + ipn.EncryptedStateStore + + logf logger.Logf + path string + key [32]byte + + mu sync.RWMutex + cache map[ipn.StateKey][]byte +} + +func (s *tpmStore) ReadState(k ipn.StateKey) ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + v, ok := s.cache[k] + if !ok { + return nil, ipn.ErrStateNotExist + } + return bytes.Clone(v), nil +} + +func (s *tpmStore) WriteState(k ipn.StateKey, bs []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + if bytes.Equal(s.cache[k], bs) { + return nil + } + s.cache[k] = bytes.Clone(bs) + + return s.writeSealed() +} + +func (s *tpmStore) writeSealed() error { + bs, err := json.Marshal(s.cache) + if err != nil { + return err + } + sealed, err := seal(s.logf, decryptedData{Key: s.key, Data: bs}) + if err != nil { + return fmt.Errorf("failed to seal state file: %w", err) + } + buf, err := json.Marshal(sealed) + if err != nil { + return err + } + return atomicfile.WriteFile(s.path, buf, 0600) +} + +func (s *tpmStore) All() iter.Seq2[ipn.StateKey, []byte] { + return func(yield func(ipn.StateKey, []byte) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + for k, v := range s.cache { + if !yield(k, v) { + break + } + } + } +} + +// Ensure tpmStore implements store.ExportableStore for migration to/from +// store.FileStore. +var _ store.ExportableStore = (*tpmStore)(nil) + +// The nested levels of encoding and encryption are confusing, so here's what's +// going on in plain English. +// +// Not all TPM devices support symmetric encryption (TPM2_EncryptDecrypt2) +// natively, but they do support "sealing" small values (see +// tpmSeal/tpmUnseal). The size limit is too small for the actual state file, +// so we seal a symmetric key instead. This symmetric key is then used to seal +// the actual data using nacl/secretbox. +// Confusingly, both TPMs and secretbox use "seal" terminology. +// +// tpmSeal/tpmUnseal do the lower-level sealing of small []byte blobs, which we +// use to seal a 32-byte secretbox key. +// +// seal/unseal do the higher-level sealing of store data using secretbox, and +// also sealing of the symmetric key using TPM. + +// decryptedData contains the fully decrypted raw data along with the symmetric +// key used for secretbox. This struct should only live in memory and never get +// stored to disk! +type decryptedData struct { + Key [32]byte + Data []byte +} + +func (decryptedData) MarshalJSON() ([]byte, error) { + return nil, errors.New("[unexpected]: decryptedData should never get JSON-marshaled!") +} + +// encryptedData contains the secretbox-sealed data and nonce, along with a +// TPM-sealed key. All fields are required. +type encryptedData struct { + Key *tpmSealedData `json:"key"` + Nonce []byte `json:"nonce"` + Data []byte `json:"data"` +} + +func seal(logf logger.Logf, dec decryptedData) (*encryptedData, error) { + var nonce [24]byte + // crypto/rand.Read never returns an error. + rand.Read(nonce[:]) + + sealedData := secretbox.Seal(nil, dec.Data, &nonce, &dec.Key) + sealedKey, err := tpmSeal(logf, dec.Key[:]) + if err != nil { + return nil, fmt.Errorf("failed to seal encryption key to TPM: %w", err) + } + + return &encryptedData{ + Key: sealedKey, + Nonce: nonce[:], + Data: sealedData, + }, nil +} + +func unseal(logf logger.Logf, data encryptedData) (*decryptedData, error) { + if len(data.Nonce) != 24 { + return nil, fmt.Errorf("nonce should be 24 bytes long, got %d", len(data.Nonce)) + } + + unsealedKey, err := tpmUnseal(logf, data.Key) + if err != nil { + return nil, fmt.Errorf("failed to unseal encryption key with TPM: %w", err) + } + if len(unsealedKey) != 32 { + return nil, fmt.Errorf("unsealed key should be 32 bytes long, got %d", len(unsealedKey)) + } + unsealedData, ok := secretbox.Open(nil, data.Data, (*[24]byte)(data.Nonce), (*[32]byte)(unsealedKey)) + if !ok { + return nil, errors.New("failed to unseal data") + } + + return &decryptedData{ + Key: *(*[32]byte)(unsealedKey), + Data: unsealedData, + }, nil +} + +type tpmSealedData struct { + Private []byte + Public []byte +} + +// withSRK runs fn with the loaded Storage Root Key (SRK) handle. The SRK is +// flushed after fn returns. +func withSRK(logf logger.Logf, tpm transport.TPM, fn func(srk tpm2.AuthHandle) error) error { + srkCmd := tpm2.CreatePrimary{ + PrimaryHandle: tpm2.TPMRHOwner, + InPublic: tpm2.New2B(tpm2.ECCSRKTemplate), + } + srkRes, err := srkCmd.Execute(tpm) + if err != nil { + return fmt.Errorf("tpm2.CreatePrimary: %w", err) + } + defer func() { + cmd := tpm2.FlushContext{FlushHandle: srkRes.ObjectHandle} + if _, err := cmd.Execute(tpm); err != nil { + logf("tpm2.FlushContext: failed to flush SRK handle: %v", err) + } + }() + + return fn(tpm2.AuthHandle{ + Handle: srkRes.ObjectHandle, + Name: srkRes.Name, + Auth: tpm2.HMAC(tpm2.TPMAlgSHA256, 32), + }) +} + +// tpmSeal seals the data using SRK of the local TPM. +func tpmSeal(logf logger.Logf, data []byte) (*tpmSealedData, error) { + tpm, err := open() + if err != nil { + return nil, fmt.Errorf("opening TPM: %w", err) + } + defer tpm.Close() + + var res *tpmSealedData + err = withSRK(logf, tpm, func(srk tpm2.AuthHandle) error { + sealCmd := tpm2.Create{ + ParentHandle: srk, + InSensitive: tpm2.TPM2BSensitiveCreate{ + Sensitive: &tpm2.TPMSSensitiveCreate{ + Data: tpm2.NewTPMUSensitiveCreate(&tpm2.TPM2BSensitiveData{ + Buffer: data, + }), + }, + }, + InPublic: tpm2.New2B(tpm2.TPMTPublic{ + Type: tpm2.TPMAlgKeyedHash, + NameAlg: tpm2.TPMAlgSHA256, + ObjectAttributes: tpm2.TPMAObject{ + FixedTPM: true, + FixedParent: true, + UserWithAuth: true, + // We don't set an authorization policy on this key, so DA + // isn't helpful. + NoDA: true, + }, + }), + } + sealRes, err := sealCmd.Execute(tpm) + if err != nil { + return fmt.Errorf("tpm2.Create: %w", err) + } + + res = &tpmSealedData{ + Private: sealRes.OutPrivate.Buffer, + Public: sealRes.OutPublic.Bytes(), + } + return nil + }) + return res, err +} + +// tpmUnseal unseals the data using SRK of the local TPM. +func tpmUnseal(logf logger.Logf, data *tpmSealedData) ([]byte, error) { + tpm, err := open() + if err != nil { + return nil, fmt.Errorf("opening TPM: %w", err) + } + defer tpm.Close() + + var res []byte + err = withSRK(logf, tpm, func(srk tpm2.AuthHandle) error { + // Load the sealed object into the TPM first under SRK. + loadCmd := tpm2.Load{ + ParentHandle: srk, + InPrivate: tpm2.TPM2BPrivate{Buffer: data.Private}, + InPublic: tpm2.BytesAs2B[tpm2.TPMTPublic](data.Public), + } + loadRes, err := loadCmd.Execute(tpm) + if err != nil { + return fmt.Errorf("tpm2.Load: %w", err) + } + defer func() { + cmd := tpm2.FlushContext{FlushHandle: loadRes.ObjectHandle} + if _, err := cmd.Execute(tpm); err != nil { + log.Printf("tpm2.FlushContext: failed to flush loaded sealed blob handle: %v", err) + } + }() + + // Then unseal the object. + unsealCmd := tpm2.Unseal{ + ItemHandle: tpm2.NamedHandle{ + Handle: loadRes.ObjectHandle, + Name: loadRes.Name, + }, + } + unsealRes, err := unsealCmd.Execute(tpm) + if err != nil { + return fmt.Errorf("tpm2.Unseal: %w", err) + } + res = unsealRes.OutData.Buffer + + return nil + }) + return res, err +} diff --git a/feature/tpm/tpm_linux.go b/feature/tpm/tpm_linux.go new file mode 100644 index 000000000..6c8131e8d --- /dev/null +++ b/feature/tpm/tpm_linux.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "github.com/google/go-tpm/tpm2/transport" + "github.com/google/go-tpm/tpm2/transport/linuxtpm" +) + +func open() (transport.TPMCloser, error) { + tpm, err := linuxtpm.Open("/dev/tpmrm0") + if err == nil { + return tpm, nil + } + return linuxtpm.Open("/dev/tpm0") +} diff --git a/feature/tpm/tpm_other.go b/feature/tpm/tpm_other.go new file mode 100644 index 000000000..108b2c057 --- /dev/null +++ b/feature/tpm/tpm_other.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows + +package tpm + +import ( + "errors" + + "github.com/google/go-tpm/tpm2/transport" +) + +func open() (transport.TPMCloser, error) { + return nil, errors.New("TPM not supported on this platform") +} diff --git a/feature/tpm/tpm_test.go b/feature/tpm/tpm_test.go new file mode 100644 index 000000000..afce570fc --- /dev/null +++ b/feature/tpm/tpm_test.go @@ -0,0 +1,362 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "bytes" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "iter" + "maps" + "os" + "path/filepath" + "slices" + "strconv" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/ipn/store" + "tailscale.com/types/logger" + "tailscale.com/util/mak" +) + +func TestPropToString(t *testing.T) { + for prop, want := range map[uint32]string{ + 0: "", + 0x4D534654: "MSFT", + 0x414D4400: "AMD", + 0x414D440D: "AMD", + } { + if got := propToString(prop); got != want { + t.Errorf("propToString(0x%x): got %q, want %q", prop, got, want) + } + } +} + +func skipWithoutTPM(t testing.TB) { + if !tpmSupported() { + t.Skip("TPM not available") + } +} + +func TestSealUnseal(t *testing.T) { + skipWithoutTPM(t) + + data := make([]byte, 100*1024) + rand.Read(data) + var key [32]byte + rand.Read(key[:]) + + sealed, err := seal(t.Logf, decryptedData{Key: key, Data: data}) + if err != nil { + t.Fatalf("seal: %v", err) + } + if bytes.Contains(sealed.Data, data) { + t.Fatalf("sealed data %q contains original input %q", sealed.Data, data) + } + + unsealed, err := unseal(t.Logf, *sealed) + if err != nil { + t.Fatalf("unseal: %v", err) + } + if !bytes.Equal(data, unsealed.Data) { + t.Errorf("got unsealed data: %q, want: %q", unsealed, data) + } + if key != unsealed.Key { + t.Errorf("got unsealed key: %q, want: %q", unsealed.Key, key) + } +} + +func TestStore(t *testing.T) { + skipWithoutTPM(t) + + path := store.TPMPrefix + filepath.Join(t.TempDir(), "state") + store, err := newStore(t.Logf, path) + if err != nil { + t.Fatal(err) + } + + checkState := func(t *testing.T, store ipn.StateStore, k ipn.StateKey, want []byte) { + got, err := store.ReadState(k) + if err != nil { + t.Errorf("ReadState(%q): %v", k, err) + } + if !bytes.Equal(want, got) { + t.Errorf("ReadState(%q): got %q, want %q", k, got, want) + } + } + + k1, k2 := ipn.StateKey("k1"), ipn.StateKey("k2") + v1, v2 := []byte("v1"), []byte("v2") + + t.Run("read-non-existent-key", func(t *testing.T) { + _, err := store.ReadState(k1) + if !errors.Is(err, ipn.ErrStateNotExist) { + t.Errorf("ReadState succeeded, want %v", ipn.ErrStateNotExist) + } + }) + + t.Run("read-write-k1", func(t *testing.T) { + if err := store.WriteState(k1, v1); err != nil { + t.Errorf("WriteState(%q, %q): %v", k1, v1, err) + } + checkState(t, store, k1, v1) + }) + + t.Run("read-write-k2", func(t *testing.T) { + if err := store.WriteState(k2, v2); err != nil { + t.Errorf("WriteState(%q, %q): %v", k2, v2, err) + } + checkState(t, store, k2, v2) + }) + + t.Run("update-k2", func(t *testing.T) { + v2 = []byte("new v2") + if err := store.WriteState(k2, v2); err != nil { + t.Errorf("WriteState(%q, %q): %v", k2, v2, err) + } + checkState(t, store, k2, v2) + }) + + t.Run("reopen-store", func(t *testing.T) { + store, err := newStore(t.Logf, path) + if err != nil { + t.Fatal(err) + } + checkState(t, store, k1, v1) + checkState(t, store, k2, v2) + }) +} + +func BenchmarkInfo(b *testing.B) { + b.StopTimer() + skipWithoutTPM(b) + b.StartTimer() + for i := 0; i < b.N; i++ { + hi := info() + if hi == nil { + b.Fatalf("tpm info error") + } + } + b.StopTimer() +} + +func BenchmarkTPMSupported(b *testing.B) { + b.StopTimer() + skipWithoutTPM(b) + b.StartTimer() + for i := 0; i < b.N; i++ { + if !tpmSupported() { + b.Fatalf("tpmSupported returned false") + } + } + b.StopTimer() +} + +func BenchmarkStore(b *testing.B) { + skipWithoutTPM(b) + b.StopTimer() + + stores := make(map[string]ipn.StateStore) + key := ipn.StateKey(b.Name()) + + // Set up tpmStore + tpmStore, err := newStore(b.Logf, filepath.Join(b.TempDir(), "tpm.store")) + if err != nil { + b.Fatal(err) + } + if err := tpmStore.WriteState(key, []byte("-1")); err != nil { + b.Fatal(err) + } + stores["tpmStore"] = tpmStore + + // Set up FileStore + fileStore, err := store.NewFileStore(b.Logf, filepath.Join(b.TempDir(), "file.store")) + if err != nil { + b.Fatal(err) + } + if err := fileStore.WriteState(key, []byte("-1")); err != nil { + b.Fatal(err) + } + stores["fileStore"] = fileStore + + b.StartTimer() + + for name, store := range stores { + b.Run(name, func(b *testing.B) { + b.Run("write-noop", func(b *testing.B) { + for range b.N { + if err := store.WriteState(key, []byte("-1")); err != nil { + b.Fatal(err) + } + } + }) + b.Run("write", func(b *testing.B) { + for i := range b.N { + if err := store.WriteState(key, []byte(strconv.Itoa(i))); err != nil { + b.Fatal(err) + } + } + }) + b.Run("read", func(b *testing.B) { + for range b.N { + if _, err := store.ReadState(key); err != nil { + b.Fatal(err) + } + } + }) + }) + } +} + +func TestMigrateStateToTPM(t *testing.T) { + if !tpmSupported() { + t.Logf("using mock tpmseal provider") + store.RegisterForTest(t, store.TPMPrefix, newMockTPMSeal) + } + + storePath := filepath.Join(t.TempDir(), "store") + // Make sure migration doesn't cause a failure when no state file exists. + if _, err := store.New(t.Logf, store.TPMPrefix+storePath); err != nil { + t.Fatalf("store.New failed for new tpmseal store: %v", err) + } + os.Remove(storePath) + + initial, err := store.New(t.Logf, storePath) + if err != nil { + t.Fatalf("store.New failed for new file store: %v", err) + } + + // Populate initial state file. + content := map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "baz": []byte("qux"), + } + for k, v := range content { + if err := initial.WriteState(k, v); err != nil { + t.Fatal(err) + } + } + // Expected file keys for plaintext and sealed versions of state. + keysPlaintext := []string{"foo", "baz"} + keysTPMSeal := []string{"key", "nonce", "data"} + + for _, tt := range []struct { + desc string + path string + wantKeys []string + }{ + { + desc: "plaintext-to-plaintext", + path: storePath, + wantKeys: keysPlaintext, + }, + { + desc: "plaintext-to-tpmseal", + path: store.TPMPrefix + storePath, + wantKeys: keysTPMSeal, + }, + { + desc: "tpmseal-to-tpmseal", + path: store.TPMPrefix + storePath, + wantKeys: keysTPMSeal, + }, + { + desc: "tpmseal-to-plaintext", + path: storePath, + wantKeys: keysPlaintext, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + s, err := store.New(t.Logf, tt.path) + if err != nil { + t.Fatalf("migration failed: %v", err) + } + gotContent := maps.Collect(s.(interface { + All() iter.Seq2[ipn.StateKey, []byte] + }).All()) + if diff := cmp.Diff(content, gotContent); diff != "" { + t.Errorf("unexpected content after migration, diff:\n%s", diff) + } + + buf, err := os.ReadFile(storePath) + if err != nil { + t.Fatal(err) + } + var data map[string]any + if err := json.Unmarshal(buf, &data); err != nil { + t.Fatal(err) + } + gotKeys := slices.Collect(maps.Keys(data)) + slices.Sort(gotKeys) + slices.Sort(tt.wantKeys) + if diff := cmp.Diff(gotKeys, tt.wantKeys); diff != "" { + t.Errorf("unexpected content keys after migration, diff:\n%s", diff) + } + }) + } +} + +type mockTPMSealProvider struct { + path string + data map[ipn.StateKey][]byte +} + +func newMockTPMSeal(logf logger.Logf, path string) (ipn.StateStore, error) { + path, ok := strings.CutPrefix(path, store.TPMPrefix) + if !ok { + return nil, fmt.Errorf("%q missing tpmseal: prefix", path) + } + s := &mockTPMSealProvider{path: path} + buf, err := os.ReadFile(path) + if errors.Is(err, os.ErrNotExist) { + return s, s.flushState() + } + if err != nil { + return nil, err + } + var data struct { + Key string + Nonce string + Data map[ipn.StateKey][]byte + } + if err := json.Unmarshal(buf, &data); err != nil { + return nil, err + } + if data.Key == "" || data.Nonce == "" { + return nil, fmt.Errorf("%q missing key or nonce", path) + } + s.data = data.Data + return s, nil +} + +func (p *mockTPMSealProvider) ReadState(k ipn.StateKey) ([]byte, error) { + return p.data[k], nil +} + +func (p *mockTPMSealProvider) WriteState(k ipn.StateKey, v []byte) error { + mak.Set(&p.data, k, v) + return p.flushState() +} + +func (p *mockTPMSealProvider) All() iter.Seq2[ipn.StateKey, []byte] { + return maps.All(p.data) +} + +func (p *mockTPMSealProvider) flushState() error { + data := map[string]any{ + "key": "foo", + "nonce": "bar", + "data": p.data, + } + buf, err := json.Marshal(data) + if err != nil { + return err + } + return os.WriteFile(p.path, buf, 0600) +} diff --git a/feature/tpm/tpm_windows.go b/feature/tpm/tpm_windows.go new file mode 100644 index 000000000..429d20cb8 --- /dev/null +++ b/feature/tpm/tpm_windows.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "github.com/google/go-tpm/tpm2/transport" + "github.com/google/go-tpm/tpm2/transport/windowstpm" +) + +func open() (transport.TPMCloser, error) { + return windowstpm.Open() +} diff --git a/feature/useproxy/useproxy.go b/feature/useproxy/useproxy.go new file mode 100644 index 000000000..a18e60577 --- /dev/null +++ b/feature/useproxy/useproxy.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package useproxy registers support for using system proxies. +package useproxy + +import ( + "tailscale.com/feature" + "tailscale.com/net/tshttpproxy" +) + +func init() { + feature.HookProxyFromEnvironment.Set(tshttpproxy.ProxyFromEnvironment) + feature.HookProxyInvalidateCache.Set(tshttpproxy.InvalidateCache) + feature.HookProxyGetAuthHeader.Set(tshttpproxy.GetAuthHeader) + feature.HookProxySetSelfProxy.Set(tshttpproxy.SetSelfProxy) + feature.HookProxySetTransportGetProxyConnectHeader.Set(tshttpproxy.SetTransportGetProxyConnectHeader) +} diff --git a/feature/wakeonlan/wakeonlan.go b/feature/wakeonlan/wakeonlan.go new file mode 100644 index 000000000..96c424084 --- /dev/null +++ b/feature/wakeonlan/wakeonlan.go @@ -0,0 +1,243 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package wakeonlan registers the Wake-on-LAN feature. +package wakeonlan + +import ( + "encoding/json" + "log" + "net" + "net/http" + "runtime" + "sort" + "strings" + "unicode" + + "github.com/kortschak/wol" + "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/hostinfo" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" +) + +func init() { + feature.Register("wakeonlan") + ipnlocal.RegisterC2N("POST /wol", handleC2NWoL) + ipnlocal.RegisterPeerAPIHandler("/v0/wol", handlePeerAPIWakeOnLAN) + hostinfo.RegisterHostinfoNewHook(func(h *tailcfg.Hostinfo) { + h.WoLMACs = getWoLMACs() + }) +} + +func handleC2NWoL(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + var macs []net.HardwareAddr + for _, macStr := range r.Form["mac"] { + mac, err := net.ParseMAC(macStr) + if err != nil { + http.Error(w, "bad 'mac' param", http.StatusBadRequest) + return + } + macs = append(macs, mac) + } + var res struct { + SentTo []string + Errors []string + } + st := b.NetMon().InterfaceState() + if st == nil { + res.Errors = append(res.Errors, "no interface state") + writeJSON(w, &res) + return + } + var password []byte // TODO(bradfitz): support? does anything use WoL passwords? + for _, mac := range macs { + for ifName, ips := range st.InterfaceIPs { + for _, ip := range ips { + if ip.Addr().IsLoopback() || ip.Addr().Is6() { + continue + } + local := &net.UDPAddr{ + IP: ip.Addr().AsSlice(), + Port: 0, + } + remote := &net.UDPAddr{ + IP: net.IPv4bcast, + Port: 0, + } + if err := wol.Wake(mac, password, local, remote); err != nil { + res.Errors = append(res.Errors, err.Error()) + } else { + res.SentTo = append(res.SentTo, ifName) + } + break // one per interface is enough + } + } + } + sort.Strings(res.SentTo) + writeJSON(w, &res) +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func canWakeOnLAN(h ipnlocal.PeerAPIHandler) bool { + if h.Peer().UnsignedPeerAPIOnly() { + return false + } + return h.IsSelfUntagged() || h.PeerCaps().HasCapability(tailcfg.PeerCapabilityWakeOnLAN) +} + +var metricWakeOnLANCalls = clientmetric.NewCounter("peerapi_wol") + +func handlePeerAPIWakeOnLAN(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + metricWakeOnLANCalls.Add(1) + if !canWakeOnLAN(h) { + http.Error(w, "no WoL access", http.StatusForbidden) + return + } + if r.Method != "POST" { + http.Error(w, "bad method", http.StatusMethodNotAllowed) + return + } + macStr := r.FormValue("mac") + if macStr == "" { + http.Error(w, "missing 'mac' param", http.StatusBadRequest) + return + } + mac, err := net.ParseMAC(macStr) + if err != nil { + http.Error(w, "bad 'mac' param", http.StatusBadRequest) + return + } + var password []byte // TODO(bradfitz): support? does anything use WoL passwords? + st := h.LocalBackend().NetMon().InterfaceState() + if st == nil { + http.Error(w, "failed to get interfaces state", http.StatusInternalServerError) + return + } + var res struct { + SentTo []string + Errors []string + } + for ifName, ips := range st.InterfaceIPs { + for _, ip := range ips { + if ip.Addr().IsLoopback() || ip.Addr().Is6() { + continue + } + local := &net.UDPAddr{ + IP: ip.Addr().AsSlice(), + Port: 0, + } + remote := &net.UDPAddr{ + IP: net.IPv4bcast, + Port: 0, + } + if err := wol.Wake(mac, password, local, remote); err != nil { + res.Errors = append(res.Errors, err.Error()) + } else { + res.SentTo = append(res.SentTo, ifName) + } + break // one per interface is enough + } + } + sort.Strings(res.SentTo) + writeJSON(w, res) +} + +// TODO(bradfitz): this is all too simplistic and static. It needs to run +// continuously in response to netmon events (USB ethernet adapters might get +// plugged in) and look for the media type/status/etc. Right now on macOS it +// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't +// have any media. We should only report the one that's actually connected. +// But it works for now (2023-10-05) for fleshing out the rest. + +var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 + +// getWoLMACs returns up to 10 MAC address of the local machine to send +// wake-on-LAN packets to in order to wake it up. The returned MACs are in +// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). +// +// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS +// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is +// set to a MAC address, that sole MAC address is returned. +func getWoLMACs() (macs []string) { + switch runtime.GOOS { + case "ios", "android": + return nil + } + if s := wakeMAC(); s != "" { + switch s { + case "auto": + ifs, _ := net.Interfaces() + for _, iface := range ifs { + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagBroadcast == 0 || + iface.Flags&net.FlagRunning == 0 || + iface.Flags&net.FlagUp == 0 { + continue + } + if keepMAC(iface.Name, iface.HardwareAddr) { + macs = append(macs, iface.HardwareAddr.String()) + } + if len(macs) == 10 { + break + } + } + return macs + case "false", "off": // fast path before ParseMAC error + return nil + } + mac, err := net.ParseMAC(s) + if err != nil { + log.Printf("invalid MAC %q", s) + return nil + } + return []string{mac.String()} + } + return nil +} + +var ignoreWakeOUI = map[[3]byte]bool{ + {0x00, 0x15, 0x5d}: true, // Hyper-V + {0x00, 0x50, 0x56}: true, // VMware + {0x00, 0x1c, 0x14}: true, // VMware + {0x00, 0x05, 0x69}: true, // VMware + {0x00, 0x0c, 0x29}: true, // VMware + {0x00, 0x1c, 0x42}: true, // Parallels + {0x08, 0x00, 0x27}: true, // VirtualBox + {0x00, 0x21, 0xf6}: true, // VirtualBox + {0x00, 0x14, 0x4f}: true, // VirtualBox + {0x00, 0x0f, 0x4b}: true, // VirtualBox + {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant +} + +func keepMAC(ifName string, mac []byte) bool { + if len(mac) != 6 { + return false + } + base := strings.TrimRightFunc(ifName, unicode.IsNumber) + switch runtime.GOOS { + case "darwin": + switch base { + case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": + return false + } + } + if mac[0] == 0x02 && mac[1] == 0x42 { + // Docker container. + return false + } + oui := [3]byte{mac[0], mac[1], mac[2]} + if ignoreWakeOUI[oui] { + return false + } + return true +} diff --git a/flake.lock b/flake.lock index 8c4aa7dfc..1623342c6 100644 --- a/flake.lock +++ b/flake.lock @@ -16,31 +16,13 @@ "type": "github" } }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1710146030, - "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, "nixpkgs": { "locked": { - "lastModified": 1724748588, - "narHash": "sha256-NlpGA4+AIf1dKNq76ps90rxowlFXUsV9x7vK/mN37JM=", + "lastModified": 1753151930, + "narHash": "sha256-XSQy6wRKHhRe//iVY5lS/ZpI/Jn6crWI8fQzl647wCg=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "a6292e34000dc93d43bccf78338770c1c5ec8a99", + "rev": "83e677f31c84212343f4cc553bab85c2efcad60a", "type": "github" }, "original": { @@ -53,8 +35,8 @@ "root": { "inputs": { "flake-compat": "flake-compat", - "flake-utils": "flake-utils", - "nixpkgs": "nixpkgs" + "nixpkgs": "nixpkgs", + "systems": "systems" } }, "systems": { diff --git a/flake.nix b/flake.nix index 95d5c3035..fc3a466fc 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ { inputs = { nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; - flake-utils.url = "github:numtide/flake-utils"; + systems.url = "github:nix-systems/default"; # Used by shell.nix as a compat shim. flake-compat = { url = "github:edolstra/flake-compat"; @@ -43,13 +43,32 @@ outputs = { self, nixpkgs, - flake-utils, + systems, flake-compat, }: let - # tailscaleRev is the git commit at which this flake was imported, - # or the empty string when building from a local checkout of the - # tailscale repo. + goVersion = nixpkgs.lib.fileContents ./go.toolchain.version; + toolChainRev = nixpkgs.lib.fileContents ./go.toolchain.rev; + gitHash = nixpkgs.lib.fileContents ./go.toolchain.rev.sri; + eachSystem = f: + nixpkgs.lib.genAttrs (import systems) (system: + f (import nixpkgs { + system = system; + overlays = [ + (final: prev: { + go_1_25 = prev.go_1_25.overrideAttrs { + version = goVersion; + src = prev.fetchFromGitHub { + owner = "tailscale"; + repo = "go"; + rev = toolChainRev; + sha256 = gitHash; + }; + }; + }) + ]; + })); tailscaleRev = self.rev or ""; + in { # tailscale takes a nixpkgs package set, and builds Tailscale from # the same commit as this flake. IOW, it provides "tailscale built # from HEAD", where HEAD is "whatever commit you imported the @@ -67,16 +86,20 @@ # So really, this flake is for tailscale devs to dogfood with, if # you're an end user you should be prepared for this flake to not # build periodically. - tailscale = pkgs: - pkgs.buildGo123Module rec { + packages = eachSystem (pkgs: rec { + default = pkgs.buildGo125Module { name = "tailscale"; - + pname = "tailscale"; src = ./.; vendorHash = pkgs.lib.fileContents ./go.mod.sri; - nativeBuildInputs = pkgs.lib.optionals pkgs.stdenv.isLinux [pkgs.makeWrapper]; + nativeBuildInputs = [pkgs.makeWrapper pkgs.installShellFiles]; ldflags = ["-X tailscale.com/version.gitCommitStamp=${tailscaleRev}"]; - CGO_ENABLED = 0; - subPackages = ["cmd/tailscale" "cmd/tailscaled"]; + env.CGO_ENABLED = 0; + subPackages = [ + "cmd/tailscale" + "cmd/tailscaled" + "cmd/tsidp" + ]; doCheck = false; # NOTE: We strip the ${PORT} and $FLAGS because they are unset in the @@ -84,32 +107,31 @@ # point, there should be a NixOS module that allows configuration of these # things, but for now, we hardcode the default of port 41641 (taken from # ./cmd/tailscaled/tailscaled.defaults). - postInstall = pkgs.lib.optionalString pkgs.stdenv.isLinux '' - wrapProgram $out/bin/tailscaled --prefix PATH : ${pkgs.lib.makeBinPath [pkgs.iproute2 pkgs.iptables pkgs.getent pkgs.shadow]} - wrapProgram $out/bin/tailscale --suffix PATH : ${pkgs.lib.makeBinPath [pkgs.procps]} + postInstall = + pkgs.lib.optionalString pkgs.stdenv.isLinux '' + wrapProgram $out/bin/tailscaled --prefix PATH : ${pkgs.lib.makeBinPath [pkgs.iproute2 pkgs.iptables pkgs.getent pkgs.shadow]} + wrapProgram $out/bin/tailscale --suffix PATH : ${pkgs.lib.makeBinPath [pkgs.procps]} - sed -i \ - -e "s#/usr/sbin#$out/bin#" \ - -e "/^EnvironmentFile/d" \ - -e 's/''${PORT}/41641/' \ - -e 's/$FLAGS//' \ - ./cmd/tailscaled/tailscaled.service + sed -i \ + -e "s#/usr/sbin#$out/bin#" \ + -e "/^EnvironmentFile/d" \ + -e 's/''${PORT}/41641/' \ + -e 's/$FLAGS//' \ + ./cmd/tailscaled/tailscaled.service - install -D -m0444 -t $out/lib/systemd/system ./cmd/tailscaled/tailscaled.service - ''; + install -D -m0444 -t $out/lib/systemd/system ./cmd/tailscaled/tailscaled.service + '' + + pkgs.lib.optionalString (pkgs.stdenv.buildPlatform.canExecute pkgs.stdenv.hostPlatform) '' + installShellCompletion --cmd tailscale \ + --bash <($out/bin/tailscale completion bash) \ + --fish <($out/bin/tailscale completion fish) \ + --zsh <($out/bin/tailscale completion zsh) + ''; }; + tailscale = default; + }); - # This whole blob makes the tailscale package available for all - # OS/CPU combos that nix supports, as well as a dev shell so that - # "nix develop" and "nix-shell" give you a dev env. - flakeForSystem = nixpkgs: system: let - pkgs = nixpkgs.legacyPackages.${system}; - ts = tailscale pkgs; - in { - packages = { - default = ts; - tailscale = ts; - }; + devShells = eachSystem (pkgs: { devShell = pkgs.mkShell { packages = with pkgs; [ curl @@ -118,7 +140,7 @@ gotools graphviz perl - go_1_23 + go_1_25 yarn # qemu and e2fsprogs are needed for natlab @@ -126,8 +148,8 @@ e2fsprogs ]; }; - }; - in - flake-utils.lib.eachDefaultSystem (system: flakeForSystem nixpkgs system); + }); + }; } -# nix-direnv cache busting line: sha256-xO1DuLWi6/lpA9ubA2ZYVJM+CkVNA5IaVGZxX9my0j0= +# nix-direnv cache busting line: sha256-sGPgML2YM/XNWfsAdDZvzWHagcydwCmR6nKOHJj5COs= + diff --git a/go.mod b/go.mod index 464db8313..3b4f34b2d 100644 --- a/go.mod +++ b/go.mod @@ -1,28 +1,29 @@ module tailscale.com -go 1.23.1 +go 1.25.3 require ( filippo.io/mkcert v1.4.4 - fyne.io/systray v1.11.0 + fyne.io/systray v1.11.1-0.20250812065214-4856ac3adc3c + github.com/Kodeworks/golang-image-ico v0.0.0-20141118225523-73f0f4cfade9 github.com/akutz/memconn v0.1.0 github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa github.com/andybalholm/brotli v1.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/atotto/clipboard v0.1.4 - github.com/aws/aws-sdk-go-v2 v1.24.1 - github.com/aws/aws-sdk-go-v2/config v1.26.5 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.64 - github.com/aws/aws-sdk-go-v2/service/s3 v1.33.0 + github.com/aws/aws-sdk-go-v2 v1.36.0 + github.com/aws/aws-sdk-go-v2/config v1.29.5 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 + github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 github.com/bramvdbogaerde/go-scp v1.4.0 github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf + github.com/creachadair/msync v0.7.1 + github.com/creachadair/taskgroup v0.13.2 github.com/creack/pty v1.1.23 - github.com/dave/courtney v0.4.0 - github.com/dave/patsy v0.0.0-20210517141501-957256f50cba github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e github.com/distribution/reference v0.6.0 @@ -32,140 +33,166 @@ require ( github.com/evanw/esbuild v0.19.11 github.com/fogleman/gg v1.3.0 github.com/frankban/quicktest v1.14.6 - github.com/fxamacker/cbor/v2 v2.6.0 - github.com/gaissmai/bart v0.11.1 - github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 + github.com/fxamacker/cbor/v2 v2.7.0 + github.com/gaissmai/bart v0.18.0 + github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced github.com/go-logr/zapr v1.3.0 github.com/go-ole/go-ole v1.3.0 + github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/golang/snappy v0.0.4 github.com/golangci/golangci-lint v1.57.1 - github.com/google/go-cmp v0.6.0 - github.com/google/go-containerregistry v0.18.0 + github.com/google/go-cmp v0.7.0 + github.com/google/go-containerregistry v0.20.3 + github.com/google/go-tpm v0.9.4 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 github.com/google/uuid v1.6.0 github.com/goreleaser/nfpm/v2 v2.33.1 + github.com/hashicorp/go-hclog v1.6.2 + github.com/hashicorp/raft v1.7.2 + github.com/hashicorp/raft-boltdb/v2 v2.3.1 github.com/hdevalence/ed25519consensus v0.2.0 - github.com/illarion/gonotify/v2 v2.0.3 - github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c + github.com/illarion/gonotify/v3 v3.0.2 + github.com/inetaf/tcpproxy v0.0.0-20250203165043-ded522cbd03f github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 github.com/jellydator/ttlcache/v3 v3.1.0 - github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 github.com/jsimonetti/rtnetlink v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 - github.com/klauspost/compress v1.17.4 + github.com/klauspost/compress v1.17.11 github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 github.com/mdlayher/genetlink v1.3.2 - github.com/mdlayher/netlink v1.7.2 + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 github.com/mdlayher/sdnotify v1.0.0 github.com/miekg/dns v1.1.58 github.com/mitchellh/go-ps v1.0.0 github.com/peterbourgon/ff/v3 v3.4.0 + github.com/pires/go-proxyproto v0.8.1 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.6 github.com/prometheus-community/pro-bing v0.4.0 - github.com/prometheus/client_golang v1.19.1 - github.com/prometheus/common v0.48.0 + github.com/prometheus/client_golang v1.20.5 + github.com/prometheus/common v0.55.0 github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff github.com/safchain/ethtool v0.3.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/studio-b12/gowebdav v0.9.0 github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e - github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502 + github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41 - github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4 + github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869 github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a - github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba + github.com/tailscale/mkctr v0.0.0-20250228050937-c75ea1476830 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 - github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 - github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 + github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc + github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb + github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 - github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc + github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e github.com/tc-hib/winres v0.2.1 github.com/tcnksm/go-httpstat v0.2.0 github.com/toqueteos/webbrowser v1.2.0 - github.com/u-root/u-root v0.12.0 - github.com/vishvananda/netns v0.0.4 + github.com/u-root/u-root v0.14.0 + github.com/vishvananda/netns v0.0.5 go.uber.org/zap v1.27.0 - go4.org/mem v0.0.0-20220726221520-4f986261bf13 + go4.org/mem v0.0.0-20240501181205-ae6ca9944745 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.25.0 - golang.org/x/exp v0.0.0-20240119083558-1b970713d09a - golang.org/x/mod v0.19.0 - golang.org/x/net v0.27.0 - golang.org/x/oauth2 v0.16.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.22.0 - golang.org/x/term v0.22.0 - golang.org/x/time v0.5.0 - golang.org/x/tools v0.23.0 + golang.org/x/crypto v0.44.0 + golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac + golang.org/x/mod v0.30.0 + golang.org/x/net v0.47.0 + golang.org/x/oauth2 v0.30.0 + golang.org/x/sync v0.18.0 + golang.org/x/sys v0.38.0 + golang.org/x/term v0.37.0 + golang.org/x/time v0.11.0 + golang.org/x/tools v0.39.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/windows v0.5.3 gopkg.in/square/go-jose.v2 v2.6.0 - gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 - honnef.co/go/tools v0.5.1 - k8s.io/api v0.30.3 - k8s.io/apimachinery v0.30.3 - k8s.io/apiserver v0.30.3 - k8s.io/client-go v0.30.3 - sigs.k8s.io/controller-runtime v0.18.4 - sigs.k8s.io/controller-tools v0.15.1-0.20240618033008-7824932b0cab + gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 + honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 + k8s.io/api v0.32.0 + k8s.io/apimachinery v0.32.0 + k8s.io/apiserver v0.32.0 + k8s.io/client-go v0.32.0 + sigs.k8s.io/controller-runtime v0.19.4 + sigs.k8s.io/controller-tools v0.17.0 sigs.k8s.io/yaml v1.4.0 software.sslmate.com/src/go-pkcs12 v0.4.0 ) require ( + 9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f // indirect github.com/4meepo/tagalign v1.3.3 // indirect github.com/Antonboom/testifylint v1.2.0 // indirect github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible // indirect - github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/OpenPeeDeeP/depguard/v2 v2.2.0 // indirect github.com/alecthomas/go-check-sumtype v0.1.4 // indirect github.com/alexkohler/nakedret/v2 v2.0.4 // indirect - github.com/bits-and-blooms/bitset v1.13.0 // indirect + github.com/armon/go-metrics v0.4.1 // indirect + github.com/blang/semver/v4 v4.0.0 // indirect + github.com/boltdb/bolt v1.3.1 // indirect github.com/bombsimon/wsl/v4 v4.2.1 // indirect github.com/butuzov/mirror v1.1.0 // indirect github.com/catenacyber/perfsprint v0.7.1 // indirect github.com/ccojocar/zxcvbn-go v1.0.2 // indirect github.com/ckaznocha/intrange v0.1.0 // indirect - github.com/cyphar/filepath-securejoin v0.2.4 // indirect - github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 // indirect - github.com/dave/brenda v1.1.0 // indirect - github.com/docker/go-connections v0.4.0 // indirect + github.com/containerd/typeurl/v2 v2.2.3 // indirect + github.com/cyphar/filepath-securejoin v0.3.6 // indirect + github.com/deckarep/golang-set/v2 v2.8.0 // indirect + github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/ghostiam/protogetter v0.3.5 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect - github.com/gobuffalo/flect v1.0.2 // indirect + github.com/gobuffalo/flect v1.0.3 // indirect github.com/goccy/go-yaml v1.12.0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golangci/plugin-module-register v0.1.1 // indirect github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect - github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect + github.com/google/go-github/v66 v66.0.0 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/gorilla/securecookie v1.1.2 // indirect + github.com/hashicorp/go-immutable-radix v1.3.1 // indirect + github.com/hashicorp/go-metrics v0.5.4 // indirect + github.com/hashicorp/go-msgpack/v2 v2.1.2 // indirect + github.com/hashicorp/golang-lru v0.6.0 // indirect github.com/jjti/go-spancheck v0.5.3 // indirect github.com/karamaru-alpha/copyloopvar v1.0.8 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/macabu/inamedparam v0.1.3 // indirect + github.com/moby/buildkit v0.20.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect + github.com/puzpuzpuz/xsync v1.5.2 // indirect github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 // indirect + github.com/stacklok/frizbee v0.1.7 // indirect github.com/xen0n/gosmopolitan v1.2.2 // indirect github.com/ykadowak/zerologlint v0.1.5 // indirect go-simpler.org/musttag v0.9.0 // indirect go-simpler.org/sloglint v0.5.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 // indirect - go.opentelemetry.io/otel v1.22.0 // indirect - go.opentelemetry.io/otel/metric v1.22.0 // indirect - go.opentelemetry.io/otel/trace v1.22.0 // indirect + go.etcd.io/bbolt v1.3.11 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect + go.opentelemetry.io/otel v1.33.0 // indirect + go.opentelemetry.io/otel/metric v1.33.0 // indirect + go.opentelemetry.io/otel/trace v1.33.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect + golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 // indirect + golang.org/x/tools/go/expect v0.1.1-deprecated // indirect + golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect + gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect + k8s.io/component-base v0.32.0 // indirect ) require ( @@ -183,26 +210,26 @@ require ( github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/Masterminds/sprig/v3 v3.2.3 // indirect - github.com/ProtonMail/go-crypto v1.0.0 // indirect + github.com/ProtonMail/go-crypto v1.1.3 // indirect github.com/alexkohler/prealloc v1.0.0 // indirect github.com/alingse/asasalint v0.0.11 // indirect github.com/ashanbrown/forbidigo v1.6.0 // indirect github.com/ashanbrown/makezero v1.1.1 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.16.16 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 // indirect - github.com/aws/smithy-go v1.19.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.58 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 // indirect + github.com/aws/smithy-go v1.22.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bkielbasa/cyclop v1.2.1 // indirect github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb // indirect @@ -211,37 +238,36 @@ require ( github.com/breml/errchkjson v0.3.6 // indirect github.com/butuzov/ireturn v0.3.0 // indirect github.com/cavaliergopher/cpio v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charithe/durationcheck v0.0.10 // indirect github.com/chavacava/garif v0.1.0 // indirect - github.com/cloudflare/circl v1.3.7 // indirect - github.com/containerd/stargz-snapshotter/estargz v0.15.1 // indirect + github.com/cloudflare/circl v1.6.1 // indirect + github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect github.com/curioswitch/go-reassign v0.2.0 // indirect github.com/daixiang0/gci v0.12.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect - github.com/docker/cli v25.0.0+incompatible // indirect + github.com/docker/cli v27.5.1+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect - github.com/docker/docker v26.1.4+incompatible // indirect - github.com/docker/docker-credential-helpers v0.8.1 // indirect + github.com/docker/docker v27.5.1+incompatible // indirect + github.com/docker/docker-credential-helpers v0.8.2 // indirect github.com/emicklei/go-restful/v3 v3.11.2 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/ettle/strcase v0.2.0 // indirect - github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect - github.com/fatih/color v1.17.0 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/fatih/structtag v1.2.0 // indirect github.com/firefart/nonamedreturns v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 github.com/fzipp/gocyclo v0.6.0 // indirect github.com/go-critic/go-critic v0.11.2 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect - github.com/go-git/go-billy/v5 v5.5.0 // indirect - github.com/go-git/go-git/v5 v5.11.0 // indirect + github.com/go-git/go-billy/v5 v5.6.2 // indirect + github.com/go-git/go-git/v5 v5.13.1 // indirect github.com/go-logr/logr v1.4.2 // indirect - github.com/go-openapi/jsonpointer v0.20.2 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.4 // indirect - github.com/go-openapi/swag v0.22.7 // indirect + github.com/go-openapi/swag v0.23.0 // indirect github.com/go-toolsmith/astcast v1.1.0 // indirect github.com/go-toolsmith/astcopy v1.1.0 // indirect github.com/go-toolsmith/astequal v1.2.0 // indirect @@ -251,7 +277,7 @@ require ( github.com/go-toolsmith/typep v1.1.0 // indirect github.com/go-xmlfmt/xmlfmt v1.1.2 // indirect github.com/gobwas/glob v0.2.3 // indirect - github.com/gofrs/flock v0.8.1 // indirect + github.com/gofrs/flock v0.12.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golangci/dupl v0.0.0-20180902072040-3e9179ac440a // indirect @@ -266,7 +292,7 @@ require ( github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/goreleaser/chglog v0.5.0 // indirect github.com/goreleaser/fileglob v1.3.0 // indirect - github.com/gorilla/csrf v1.7.2 + github.com/gorilla/csrf v1.7.3 github.com/gostaticanalysis/analysisutil v0.7.1 // indirect github.com/gostaticanalysis/comment v1.4.2 // indirect github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect @@ -322,46 +348,46 @@ require ( github.com/nunnatsa/ginkgolinter v0.16.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0-rc6 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/polyfloyd/go-errorlint v1.4.8 // indirect - github.com/prometheus/client_model v0.5.0 - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/client_model v0.6.1 + github.com/prometheus/procfs v0.15.1 // indirect github.com/quasilyte/go-ruleguard v0.4.2 // indirect github.com/quasilyte/gogrep v0.5.0 // indirect github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect github.com/rivo/uniseg v0.4.4 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/ryancurrah/gomodguard v1.3.1 // indirect github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect github.com/sanposhiho/wastedassign/v2 v2.0.7 // indirect github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.25.0 // indirect github.com/securego/gosec/v2 v2.19.0 // indirect - github.com/sergi/go-diff v1.3.1 // indirect + github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c // indirect github.com/shopspring/decimal v1.3.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sivchari/tenv v1.7.1 // indirect - github.com/skeema/knownhosts v1.2.1 // indirect + github.com/skeema/knownhosts v1.3.0 // indirect github.com/sonatard/noctx v0.0.2 // indirect github.com/sourcegraph/go-diff v0.7.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/cobra v1.8.1 // indirect + github.com/spf13/cobra v1.9.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/pflag v1.0.6 // indirect github.com/spf13/viper v1.16.0 // indirect github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect github.com/stbenjam/no-sprintf-host-port v0.1.1 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/subosito/gotenv v1.4.2 // indirect github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c // indirect github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 @@ -371,12 +397,12 @@ require ( github.com/timonwong/loggercheck v0.9.4 // indirect github.com/tomarrell/wrapcheck/v2 v2.8.3 // indirect github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect - github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e // indirect - github.com/ulikunitz/xz v0.5.11 // indirect + github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect + github.com/ulikunitz/xz v0.5.15 // indirect github.com/ultraware/funlen v0.1.0 // indirect github.com/ultraware/whitespace v0.1.0 // indirect github.com/uudashr/gocognit v1.1.2 // indirect - github.com/vbatts/tar-split v0.11.5 // indirect + github.com/vbatts/tar-split v0.11.6 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/yagipy/maintidx v1.0.0 // indirect @@ -385,23 +411,24 @@ require ( gitlab.com/digitalxero/go-conventional-commit v1.0.7 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f // indirect - golang.org/x/image v0.18.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/image v0.27.0 // indirect + golang.org/x/text v0.31.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/appengine v1.6.8 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/protobuf v1.36.3 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 howett.net/plist v1.0.0 // indirect - k8s.io/apiextensions-apiserver v0.30.3 // indirect + k8s.io/apiextensions-apiserver v0.32.0 k8s.io/klog/v2 v2.130.1 // indirect - k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect - k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 + k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f // indirect + k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 mvdan.cc/gofumpt v0.6.0 // indirect mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 // indirect - sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect - sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect + sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect + sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect ) + +tool github.com/stacklok/frizbee diff --git a/go.mod.sri b/go.mod.sri index 4abb3c516..76c72f0c9 100644 --- a/go.mod.sri +++ b/go.mod.sri @@ -1 +1 @@ -sha256-xO1DuLWi6/lpA9ubA2ZYVJM+CkVNA5IaVGZxX9my0j0= +sha256-sGPgML2YM/XNWfsAdDZvzWHagcydwCmR6nKOHJj5COs= diff --git a/go.sum b/go.sum index 549f559d0..f0758f2d4 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ 4d63.com/gocheckcompilerdirectives v1.2.1/go.mod h1:yjDJSxmDTtIHHCqX0ufRYZDL6vQtMG7tJdKVeWwsqvs= 4d63.com/gochecknoglobals v0.2.1 h1:1eiorGsgHOFOuoOiJDy2psSrQbRdIHrlge0IJIkUgDc= 4d63.com/gochecknoglobals v0.2.1/go.mod h1:KRE8wtJB3CXCsb1xy421JfTHIIbmT3U5ruxw2Qu8fSU= +9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f h1:1C7nZuxUMNz7eiQALRfiqNOm04+m3edWlRff/BYHf0Q= +9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f/go.mod h1:hHyrZRryGqVdqrknjq5OWDLGCTJ2NeEvtrpR96mjraM= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -41,8 +43,8 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/mkcert v1.4.4 h1:8eVbbwfVlaqUM7OwuftKc2nuYOoTDQWqsoXmzoXZdbc= filippo.io/mkcert v1.4.4/go.mod h1:VyvOchVuAye3BoUsPUOOofKygVwLV2KQMVFJNRq+1dA= -fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= -fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +fyne.io/systray v1.11.1-0.20250812065214-4856ac3adc3c h1:km4PIleGtbbF1oxmFQuO93CyNCldwuRTPB8WlzNWNZs= +fyne.io/systray v1.11.1-0.20250812065214-4856ac3adc3c/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/4meepo/tagalign v1.3.3 h1:ZsOxcwGD/jP4U/aw7qeWu58i7dwYemfy5Y+IF1ACoNw= github.com/4meepo/tagalign v1.3.3/go.mod h1:Q9c1rYMZJc9dPRkbQPpcBNCLEmY2njbAsXhQOZFE2dE= github.com/Abirdcfly/dupword v0.0.14 h1:3U4ulkc8EUo+CaT105/GJ1BQwtgyj6+VaBVbAX11Ba8= @@ -61,12 +63,15 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs= github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= -github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= +github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/Djarvur/go-err113 v0.1.0 h1:uCRZZOdMQ0TZPHYTdYpoC0bLYJKPEHPUJ8MeAa51lNU= github.com/Djarvur/go-err113 v0.1.0/go.mod h1:4UJr5HIiMZrwgkSPdsjy2uOQExX/WEILpIrO9UPGuXs= github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0 h1:sATXp1x6/axKxz2Gjxv8MALP0bXaNRfQinEwyfMcx8c= github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0/go.mod h1:Nl76DrGNJTA1KJ0LePKBw/vznBX1EHbAZX8mwjR82nI= +github.com/Kodeworks/golang-image-ico v0.0.0-20141118225523-73f0f4cfade9 h1:1ltqoej5GtaWF8jaiA49HwsZD459jqm9YFz9ZtMFpQA= +github.com/Kodeworks/golang-image-ico v0.0.0-20141118225523-73f0f4cfade9/go.mod h1:7uhhqiBaR4CpN0k9rMjOtjpcfGd6DG2m04zQxKnWQ0I= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= @@ -79,12 +84,12 @@ github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuN github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= -github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= -github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/OpenPeeDeeP/depguard/v2 v2.2.0 h1:vDfG60vDtIuf0MEOhmLlLLSzqaRM8EMcgJPdp74zmpA= github.com/OpenPeeDeeP/depguard/v2 v2.2.0/go.mod h1:CIzddKRvLBC4Au5aYP/i3nyaWQ+ClszLIuVocRiCYFQ= -github.com/ProtonMail/go-crypto v1.0.0 h1:LRuvITjQWX+WIfr930YHG2HNfjR1uOfyf5vE0kC2U78= -github.com/ProtonMail/go-crypto v1.0.0/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= +github.com/ProtonMail/go-crypto v1.1.3 h1:nRBOetoydLeUb4nHajyO2bKqMLfWQ/ZPwkXqXxPxCFk= +github.com/ProtonMail/go-crypto v1.1.3/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f h1:tCbYj7/299ekTTXpdwKYF8eBlsYsDVoggDAuAjoK66k= github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f/go.mod h1:gcr0kNtGBqin9zDW9GOHcVntrwnjrK+qdJ06mWYBybw= github.com/ProtonMail/gopenpgp/v2 v2.7.1 h1:Awsg7MPc2gD3I7IFac2qE3Gdls0lZW8SzrFZ3k1oz0s= @@ -114,6 +119,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1 github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= +github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4= github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= @@ -123,71 +130,60 @@ github.com/ashanbrown/makezero v1.1.1 h1:iCQ87C0V0vSyO+M9E/FZYbu65auqH0lnsOkf5Fc github.com/ashanbrown/makezero v1.1.1/go.mod h1:i1bJLCRSCHOcOa9Y6MyF2FTfMZMFdHvxKHxgO5Z1axI= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= -github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= -github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10/go.mod h1:VeTZetY5KRJLuD/7fkQXMU6Mw7H5m/KP2J5Iy9osMno= -github.com/aws/aws-sdk-go-v2/config v1.18.22/go.mod h1:mN7Li1wxaPxSSy4Xkr6stFuinJGf3VZW3ZSNvO0q6sI= -github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= -github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= -github.com/aws/aws-sdk-go-v2/credentials v1.13.21/go.mod h1:90Dk1lJoMyspa/EDUrldTxsPns0wn6+KpRKpdAWc0uA= -github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= -github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.3/go.mod h1:4Q0UFP0YJf0NrsEuEYHpM9fTSEVnD16Z3uyEF7J9JGM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.64 h1:9QJQs36z61YB8nxGwRDfWXEDYbU6H7jdI6zFiAX1vag= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.64/go.mod h1:4Q7R9MFpXRdjO3YnAfUTdnuENs32WzBkASt6VxSYDYQ= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27/go.mod h1:UrHnn3QV/d0pBZ6QBAEQcqFLf8FAzLmoUfPVIueOvoM= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.34/go.mod h1:Etz2dj6UHYuw+Xw830KfzCfWGMzqvUTCjUj5b76GVDc= -github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= -github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 h1:AzwRi5OKKwo4QNqPf7TjeO+tK8AyOK3GVSwmRPo7/Cs= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25/go.mod h1:SUbB4wcbSEyCvqBxv/O/IBf93RbEze7U7OnoTlpPB+g= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28 h1:vGWm5vTpMr39tEZfQeDiDAMgk+5qsnvRny3FjLpnH5w= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28/go.mod h1:spfrICMD6wCAhjhzHuy6DOZZ+LAIY10UxhUmLzpJTTs= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.27/go.mod h1:EOwBD4J4S5qYszS5/3DpkejfuK+Z5/1uzICfPaZLtqw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2 h1:NbWkRxEEIRSCqxhsHQuMiTH7yo+JZW1gp8v3elSVMTQ= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2/go.mod h1:4tfW5l4IAB32VWCDEBxCRtR9T4BWy4I4kr1spr8NgZM= -github.com/aws/aws-sdk-go-v2/service/s3 v1.33.0 h1:L5h2fymEdVJYvn6hYO8Jx48YmC6xVmjmgHJV3oGKgmc= -github.com/aws/aws-sdk-go-v2/service/s3 v1.33.0/go.mod h1:J9kLNzEiHSeGMyN7238EjJmBpCniVzFda75Gxl/NqB8= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 h1:zAxi9p3wsZMIaVCdoiQp2uZ9k1LsZvmAnoTBeZPXom0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8/go.mod h1:3XkePX5dSaxveLAYY7nsbsZZrKxCyEuE5pM4ziFxyGg= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 h1:/BsEGAyMai+KdXS+CMHlLhB5miAO19wOqE6tj8azWPM= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58/go.mod h1:KHM3lfl/sAJBCoLI1Lsg5w4SD2VDYWwQi7vxbKhw7TI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 h1:8IwBjuLdqIO1dGB+dZ9zJEl8wzY3bVYxcs0Xyu/Lsc0= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31/go.mod h1:8tMBcuVjL4kP/ECEIWTCWtwV2kj6+ouEKl4cqR4iWLw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 h1:siiQ+jummya9OLPDEyHVb2dLW4aOMe22FGDd0sAfuSw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5/go.mod h1:iHVx2J9pWzITdP5MJY6qWfG34TfD9EA+Qi3eV6qQCXw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 h1:tkVNm99nkJnFo1H9IIQb5QkCiPcvCDn3Pos+IeTbGRA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12/go.mod h1:dIVlquSPUMqEJtx2/W17SM2SuESRaVEhEV9alcMqxjw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 h1:JBod0SnNqcWQ0+uAyzeRFG1zCHotW8DukumYYyNy0zo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3/go.mod h1:FHSHmyEUkzRbaFFqqm6bkLAOQHgqhsLmfCahvCBMiyA= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.12.9/go.mod h1:ouy2P4z6sJN70fR3ka3wD3Ro3KezSxU6eKGQI2+2fjI= -github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= -github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.9/go.mod h1:AFvkxc8xfBe8XA+5St5XIHHrQQtkxqrRincx4hmMHOk= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= -github.com/aws/aws-sdk-go-v2/service/sts v1.18.10/go.mod h1:BgQOMsg8av8jset59jelyPW7NoZcZXLVpDsXunGDrk8= -github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= -github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= -github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= -github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= -github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= -github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bkielbasa/cyclop v1.2.1 h1:AeF71HZDob1P2/pRm1so9cd1alZnrpyc4q2uP2l0gJY= github.com/bkielbasa/cyclop v1.2.1/go.mod h1:K/dT/M0FPAiYjBgQGau7tz+3TMh4FWAEqlMhzFWCrgM= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb h1:m935MPodAbYS46DG4pJSv7WO+VECIWUQ7OJYSoTrMh4= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb/go.mod h1:PkYb9DJNAwrSvRx5DYA+gUcOIgTGVMNkfSCbZM8cWpI= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= github.com/blizzy78/varnamelen v0.8.0 h1:oqSblyuQvFsW1hbBHh1zfwrKe3kcSj0rnXkKzsQ089M= github.com/blizzy78/varnamelen v0.8.0/go.mod h1:V9TzQZ4fLJ1DSrjVDfl89H7aMnTvKkApdHeyESmyR7k= +github.com/boltdb/bolt v1.3.1 h1:JQmyP4ZBrce+ZQu0dY660FMfatumYDLun9hBCUVIkF4= +github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx27Ps= github.com/bombsimon/wsl/v4 v4.2.1 h1:Cxg6u+XDWff75SIFFmNsqnIOgob+Q9hG6y/ioKbRFiM= github.com/bombsimon/wsl/v4 v4.2.1/go.mod h1:Xu/kDxGZTofQcDGCtQe9KCzhHphIe0fDuyWTxER9Feo= github.com/bramvdbogaerde/go-scp v1.4.0 h1:jKMwpwCbcX1KyvDbm/PDJuXcMuNVlLGi0Q0reuzjyKY= @@ -200,7 +196,6 @@ github.com/butuzov/ireturn v0.3.0 h1:hTjMqWw3y5JC3kpnC5vXmFJAWI/m31jaCYQqzkS6PL0 github.com/butuzov/ireturn v0.3.0/go.mod h1:A09nIiwiqzN/IoVo9ogpa0Hzi9fex1kd9PSD6edP5ZA= github.com/butuzov/mirror v1.1.0 h1:ZqX54gBVMXu78QLoiqdwpl2mgmoOJTk7s4p4o+0avZI= github.com/butuzov/mirror v1.1.0/go.mod h1:8Q0BdQU6rC6WILDiBM60DBfvV78OLJmMmixe7GF45AE= -github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/caarlos0/go-rpmutils v0.2.1-0.20211112020245-2cd62ff89b11 h1:IRrDwVlWQr6kS1U8/EtyA1+EHcc4yl8pndcqXWrEamg= github.com/caarlos0/go-rpmutils v0.2.1-0.20211112020245-2cd62ff89b11/go.mod h1:je2KZ+LxaCNvCoKg32jtOIULcFogJKcL1ZWUaIBjKj0= github.com/caarlos0/testfs v0.4.4 h1:3PHvzHi5Lt+g332CiShwS8ogTgS3HjrmzZxCm6JCDr8= @@ -212,13 +207,13 @@ github.com/cavaliergopher/cpio v1.0.1/go.mod h1:pBdaqQjnvXxdS/6CvNDwIANIFSP0xRKI github.com/ccojocar/zxcvbn-go v1.0.2 h1:na/czXU8RrhXO4EZme6eQJLR4PzcGsahsBOAwU6I3Vg= github.com/ccojocar/zxcvbn-go v1.0.2/go.mod h1:g1qkXtUSvHP8lhHp5GrSmTz6uWALGRMQdw6Qnz/hi60= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charithe/durationcheck v0.0.10 h1:wgw73BiocdBDQPik+zcEoBG/ob8uyBHf2iyoHGPf5w4= github.com/charithe/durationcheck v0.0.10/go.mod h1:bCWXb7gYRysD1CU3C+u4ceO49LoGOY1C1L6uouGNreQ= github.com/chavacava/garif v0.1.0 h1:2JHa3hbYf5D9dsgseMKAmc/MZ109otzgNFk5s87H9Pc= @@ -228,47 +223,50 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= +github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= +github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/ckaznocha/intrange v0.1.0 h1:ZiGBhvrdsKpoEfzh9CjBfDSZof6QB0ORY5tXasUtiew= github.com/ckaznocha/intrange v0.1.0/go.mod h1:Vwa9Ekex2BrEQMg6zlrWwbs/FtYw7eS5838Q7UjK7TQ= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= -github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= -github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= +github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0= +github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= -github.com/containerd/stargz-snapshotter/estargz v0.15.1 h1:eXJjw9RbkLFgioVaTG+G/ZW/0kEe2oEKCdS/ZxIyoCU= -github.com/containerd/stargz-snapshotter/estargz v0.15.1/go.mod h1:gr2RNwukQ/S9Nv33Lt6UC7xEx58C+LHRdoqbEKjz1Kk= +github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= +github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= +github.com/containerd/typeurl/v2 v2.2.3 h1:yNA/94zxWdvYACdYO8zofhrTVuQY73fFU1y++dYSw40= +github.com/containerd/typeurl/v2 v2.2.3/go.mod h1:95ljDnPfD3bAbDJRugOiShd/DlAAsxGtUBhJxIn7SCk= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= -github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/creachadair/mds v0.25.9 h1:080Hr8laN2h+l3NeVCGMBpXtIPnl9mz8e4HLraGPqtA= +github.com/creachadair/mds v0.25.9/go.mod h1:4hatI3hRM+qhzuAmqPRFvaBM8mONkS7nsLxkcuTYUIs= +github.com/creachadair/msync v0.7.1 h1:SeZmuEBXQPe5GqV/C94ER7QIZPwtvFbeQiykzt/7uho= +github.com/creachadair/msync v0.7.1/go.mod h1:8CcFlLsSujfHE5wWm19uUBLHIPDAUr6LXDwneVMO008= +github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc= +github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.23 h1:4M6+isWdcStXEf15G/RbrMPOQj1dZ7HPZCGwE4kOeP0= github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/curioswitch/go-reassign v0.2.0 h1:G9UZyOcpk/d7Gd6mqYgd8XYWFMw/znxwGDUstnC9DIo= github.com/curioswitch/go-reassign v0.2.0/go.mod h1:x6OpXuWvgfQaMGks2BZybTngWjT84hqJfKoO8Tt/Roc= -github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= -github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= +github.com/cyphar/filepath-securejoin v0.3.6 h1:4d9N5ykBnSp5Xn2JkhocYDkOpURL/18CYMpo6xB9uWM= +github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/daixiang0/gci v0.12.3 h1:yOZI7VAxAGPQmkb1eqt5g/11SUlwoat1fSblGLmdiQc= github.com/daixiang0/gci v0.12.3/go.mod h1:xtHP9N7AHdNvtRNfcx9gwTDfw7FRJx4bZUsiEfiNNAI= -github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 h1:YI1gOOdmMk3xodBao7fehcvoZsEeOyy/cfhlpCSPgM4= -github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14/go.mod h1:Sth2QfxfATb/nW4EsrSi2KyJmbcniZ8TgTaji17D6ms= -github.com/dave/brenda v1.1.0 h1:Sl1LlwXnbw7xMhq3y2x11McFu43AjDcwkllxxgZ3EZw= -github.com/dave/brenda v1.1.0/go.mod h1:4wCUr6gSlu5/1Tk7akE5X7UorwiQ8Rij0SKH3/BGMOM= -github.com/dave/courtney v0.4.0 h1:Vb8hi+k3O0h5++BR96FIcX0x3NovRbnhGd/dRr8inBk= -github.com/dave/courtney v0.4.0/go.mod h1:3WSU3yaloZXYAxRuWt8oRyVb9SaRiMBt5Kz/2J227tM= -github.com/dave/patsy v0.0.0-20210517141501-957256f50cba h1:1o36L4EKbZzazMk8iGC4kXpVnZ6TPxR2mZ9qVKjNNAs= -github.com/dave/patsy v0.0.0-20210517141501-957256f50cba/go.mod h1:qfR88CgEGLoiqDaE+xxDCi5QA5v4vUoW0UCX2Nd5Tlc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa h1:h8TfIT1xc8FWbwwpmHn1J5i43Y0uZP97GqasGCzSRJk= github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ= +github.com/deckarep/golang-set/v2 v2.8.0 h1:swm0rlPCmdWn9mESxKOjWk8hXSqoxOp+ZlfuyaAdFlQ= +github.com/deckarep/golang-set/v2 v2.8.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/denis-tingaikin/go-header v0.5.0 h1:SRdnP5ZKvcO9KKRP1KJrhFR3RrlGuD+42t4429eC9k8= github.com/denis-tingaikin/go-header v0.5.0/go.mod h1:mMenU5bWrok6Wl2UsZjy+1okegmwQ3UgWl4V1D8gjlY= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q= @@ -277,24 +275,24 @@ github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= -github.com/docker/cli v25.0.0+incompatible h1:zaimaQdnX7fYWFqzN88exE9LDEvRslexpFowZBX6GoQ= -github.com/docker/cli v25.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/cli v27.5.1+incompatible h1:JB9cieUT9YNiMITtIsguaN55PLOHhBSz3LKVc6cqWaY= +github.com/docker/cli v27.5.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU= -github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo= -github.com/docker/docker-credential-helpers v0.8.1/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= -github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= -github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/docker v27.5.1+incompatible h1:4PYU5dnBYqRQi0294d1FBECqT9ECWeQAIfE8q4YnPY8= +github.com/docker/docker v27.5.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo= +github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dsnet/try v0.0.3 h1:ptR59SsrcFUYbT/FhAbKTV6iLkeD6O18qfIWRml2fqI= github.com/dsnet/try v0.0.3/go.mod h1:WBM8tRpUmnXXhY1U6/S8dt6UWdHTQ7y8A5YSkRCkq40= github.com/elastic/crd-ref-docs v0.0.12 h1:F3seyncbzUz3rT3d+caeYWhumb5ojYQ6Bl0Z+zOp16M= github.com/elastic/crd-ref-docs v0.0.12/go.mod h1:X83mMBdJt05heJUYiS3T0yJ/JkCuliuhSUNav5Gjo/U= -github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= -github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= +github.com/elazarl/goproxy v1.2.3 h1:xwIyKHbaP5yfT6O9KIeYJR5549MXRQkoQMRXGztz8YQ= +github.com/elazarl/goproxy v1.2.3/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64= github.com/emicklei/go-restful/v3 v3.11.2 h1:1onLa9DcsMYO9P+CXaL0dStDqQ2EHHXLiz+BtnqkLAU= github.com/emicklei/go-restful/v3 v3.11.2/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= @@ -311,8 +309,9 @@ github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0 github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/evanw/esbuild v0.19.11 h1:mbPO1VJ/df//jjUd+p/nRLYCpizXxXb2w/zZMShxa2k= github.com/evanw/esbuild v0.19.11/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= -github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= -github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -321,37 +320,39 @@ github.com/firefart/nonamedreturns v1.0.4 h1:abzI1p7mAEPYuR4A+VLKn4eNDOycjYo2phm github.com/firefart/nonamedreturns v1.0.4/go.mod h1:TDhe/tjI1BXo48CmYbUduTV7BdIga8MAO/xbKdcVsGI= github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= -github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= -github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= -github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.18.0 h1:jQLBT/RduJu0pv/tLwXE+xKPgtWJejbxuXAR+wLJafo= +github.com/gaissmai/bart v0.18.0/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/ghostiam/protogetter v0.3.5 h1:+f7UiF8XNd4w3a//4DnusQ2SZjPkUjxkMEfjbxOK4Ug= github.com/ghostiam/protogetter v0.3.5/go.mod h1:7lpeDnEJ1ZjL/YtyoN99ljO4z0pd3H0d18/t2dPBxHw= github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I= github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo= -github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= -github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= github.com/go-critic/go-critic v0.11.2 h1:81xH/2muBphEgPtcwH1p6QD+KzXl2tMSi3hXjBSxDnM= github.com/go-critic/go-critic v0.11.2/go.mod h1:OePaicfjsf+KPy33yq4gzv6CO7TEQ9Rom6ns1KsJnl8= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= -github.com/go-git/go-billy/v5 v5.5.0 h1:yEY4yhzCDuMGSv83oGxiBotRzhwhNr8VZyphhiu+mTU= -github.com/go-git/go-billy/v5 v5.5.0/go.mod h1:hmexnoNsr2SJU1Ju67OaNz5ASJY3+sHgFRpCtpDCKow= +github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= +github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= -github.com/go-git/go-git/v5 v5.11.0 h1:XIZc1p+8YzypNr34itUfSvYJcv+eYdTnTvOZ2vD3cA4= -github.com/go-git/go-git/v5 v5.11.0/go.mod h1:6GFcX2P3NM7FPBfpePbpLd21XxsgdAt+lKqXmCUiUCY= +github.com/go-git/go-git/v5 v5.13.1 h1:DAQ9APonnlvSWpvolXWIuV6Q6zXy2wHbN4cVlNR5Q+M= +github.com/go-git/go-git/v5 v5.13.1/go.mod h1:qryJB4cSBoq3FRoBRf5A77joojuBcmPJ0qu3XXXVixc= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= -github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced h1:Q311OHjMh/u5E2TITc++WlTP5We0xNseRMkHDyvhW7I= +github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -367,12 +368,12 @@ github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q= -github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= github.com/go-openapi/jsonreference v0.20.4 h1:bKlDxQxQJgwpUSgOENiMPzCTBVuc7vTdXSSgNeAhojU= github.com/go-openapi/jsonreference v0.20.4/go.mod h1:5pZJyJP2MnYCpoeoMAql78cCHauHj0V9Lhc506VOpw4= -github.com/go-openapi/swag v0.22.7 h1:JWrc1uc/P9cSomxfnsFSVWoE1FW6bNbrVPmpQYpCcR8= -github.com/go-openapi/swag v0.22.7/go.mod h1:Gl91UqO+btAM0plGGxHqJcQZ1ZTy6jbmridBTsDy8A0= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= @@ -383,7 +384,8 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7 github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-toolsmith/astcast v1.1.0 h1:+JN9xZV1A+Re+95pgnMgDboWNVnIMMQXwfBwLRPgSC8= github.com/go-toolsmith/astcast v1.1.0/go.mod h1:qdcuFWeGGS2xX5bLM/c3U9lewg7+Zu4mr+xPwZIB4ZU= github.com/go-toolsmith/astcopy v1.1.0 h1:YGwBN0WM+ekI/6SS6+52zLDEf8Yvp3n2seZITCUBt5s= @@ -407,16 +409,18 @@ github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 h1:TQcrn6Wq+sKGkpyPvppOz99zsM github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-xmlfmt/xmlfmt v1.1.2 h1:Nea7b4icn8s57fTx1M5AI4qQT5HEM3rVUO8MuE6g80U= github.com/go-xmlfmt/xmlfmt v1.1.2/go.mod h1:aUCEOzzezBEjDBbFBoSiya/gduyIiWYRP6CnSFIV8AM= -github.com/gobuffalo/flect v1.0.2 h1:eqjPGSo2WmjgY2XlpGwo2NXgL3RucAKo4k4qQMNA5sA= -github.com/gobuffalo/flect v1.0.2/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= +github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 h1:cf60tHxREO3g1nroKr2osU3JWZsJzkfi7rEg+oAB0Lo= +github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737/go.mod h1:MIS0jDzbU/vuM9MC4YnBITCv+RYuTRq8dJzmCrFsK9g= +github.com/gobuffalo/flect v1.0.3 h1:xeWBM2nui+qnVvNM4S3foBhCAL2XgPU+a7FdpelbTq4= +github.com/gobuffalo/flect v1.0.3/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/goccy/go-yaml v1.12.0 h1:/1WHjnMsI1dlIBQutrvSMGZRQufVO3asrHfTwfACoPM= github.com/goccy/go-yaml v1.12.0/go.mod h1:wKnAMd44+9JAAnGQpWVEgBzGt3YuTaQ4uXoHvE4m7WU= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= -github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= -github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -488,10 +492,18 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-containerregistry v0.18.0 h1:ShE7erKNPqRh5ue6Z9DUOlk04WsnFWPO6YGr3OxnfoQ= -github.com/google/go-containerregistry v0.18.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/go-containerregistry v0.20.3 h1:oNx7IdTI936V8CQRveCjaxOiegWwvM7kqkbXTpyiovI= +github.com/google/go-containerregistry v0.20.3/go.mod h1:w00pIgBRDVUDFM6bq+Qx8lwNWK+cxgCuX1vd3PIBDNI= +github.com/google/go-github/v66 v66.0.0 h1:ADJsaXj9UotwdgK8/iFZtv7MLc8E8WBl62WLd/D/9+M= +github.com/google/go-github/v66 v66.0.0/go.mod h1:+4SO9Zkuyf8ytMj0csN1NR/5OTR+MfqPp8P8dVlcvY4= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/go-tpm v0.9.4 h1:awZRf9FwOeTunQmHoDYSHJps3ie6f1UlhS1fOdPEt1I= +github.com/google/go-tpm v0.9.4/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -510,8 +522,8 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= +github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/rpmpack v0.5.0 h1:L16KZ3QvkFGpYhmp23iQip+mx1X39foEsqszjMNBm8A= github.com/google/rpmpack v0.5.0/go.mod h1:uqVAUVQLq8UY2hCDfmJ/+rtO3aw7qyhc90rCVEabEfI= @@ -530,8 +542,8 @@ github.com/goreleaser/fileglob v1.3.0 h1:/X6J7U8lbDpQtBvGcwwPS6OpzkNVlVEsFUVRx9+ github.com/goreleaser/fileglob v1.3.0/go.mod h1:Jx6BoXv3mbYkEzwm9THo7xbr5egkAraxkGorbJb4RxU= github.com/goreleaser/nfpm/v2 v2.33.1 h1:EkdAzZyVhAI9JC1vjmjjbmnNzyH1J6Cu4JCsA7YcQuc= github.com/goreleaser/nfpm/v2 v2.33.1/go.mod h1:8wwWWvJWmn84xo/Sqiv0aMvEGTHlHZTXTEuVSgQpkIM= -github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= -github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gostaticanalysis/analysisutil v0.7.1 h1:ZMCjoue3DtDWQ5WyU16YbjbQEQ3VuzwxALrpYd+HeKk= @@ -547,15 +559,40 @@ github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY= github.com/gostaticanalysis/testutil v0.4.0/go.mod h1:bLIoPefWXrRi/ssLFWX1dx7Repi5x3CuviD3dgAZaBU= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I= +github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= +github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-metrics v0.5.4 h1:8mmPiIJkTPPEbAiV97IxdAGNdRdaWwVap1BU6elejKY= +github.com/hashicorp/go-metrics v0.5.4/go.mod h1:CG5yz4NZ/AI/aQt9Ucm/vdBnbh7fvmv4lxZ350i+QQI= +github.com/hashicorp/go-msgpack v0.5.5 h1:i9R9JSrqIz0QVLz3sz+i3YJdT7TTSLcfLLzJi9aZTuI= +github.com/hashicorp/go-msgpack v0.5.5/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= +github.com/hashicorp/go-msgpack/v2 v2.1.2 h1:4Ee8FTp834e+ewB71RDrQ0VKpyFdrKOjvYtnQ/ltVj0= +github.com/hashicorp/go-msgpack/v2 v2.1.2/go.mod h1:upybraOAblm4S7rx0+jeNy+CWWhzywQsSRV5033mMu4= +github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= +github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.1/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4= +github.com/hashicorp/golang-lru v0.6.0/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/raft v1.7.2 h1:pyvxhfJ4R8VIAlHKvLoKQWElZspsCVT6YWuxVxsPAgc= +github.com/hashicorp/raft v1.7.2/go.mod h1:DfvCGFxpAUPE0L4Uc8JLlTPtc3GzSbdH0MTJCLgnmJQ= +github.com/hashicorp/raft-boltdb v0.0.0-20230125174641-2a8082862702 h1:RLKEcCuKcZ+qp2VlaaZsYZfLOmIiuJNpEi48Rl8u9cQ= +github.com/hashicorp/raft-boltdb v0.0.0-20230125174641-2a8082862702/go.mod h1:nTakvJ4XYq45UXtn0DbwR4aU9ZdjlnIenpbs6Cd+FM0= +github.com/hashicorp/raft-boltdb/v2 v2.3.1 h1:ackhdCNPKblmOhjEU9+4lHSJYFkJd6Jqyvj6eW9pwkc= +github.com/hashicorp/raft-boltdb/v2 v2.3.1/go.mod h1:n4S+g43dXF1tqDT+yzcXHhXM6y7MrlUd3TTwGRcUvQE= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= github.com/hdevalence/ed25519consensus v0.2.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -563,18 +600,18 @@ github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/hugelgupf/vmtest v0.0.0-20240102225328-693afabdd27f h1:ov45/OzrJG8EKbGjn7jJZQJTN7Z1t73sFYNIRd64YlI= -github.com/hugelgupf/vmtest v0.0.0-20240102225328-693afabdd27f/go.mod h1:JoDrYMZpDPYo6uH9/f6Peqms3zNNWT2XiGgioMOIGuI= +github.com/hugelgupf/vmtest v0.0.0-20240216064925-0561770280a1 h1:jWoR2Yqg8tzM0v6LAiP7i1bikZJu3gxpgvu3g1Lw+a0= +github.com/hugelgupf/vmtest v0.0.0-20240216064925-0561770280a1/go.mod h1:B63hDJMhTupLWCHwopAyEo7wRFowx9kOc8m8j1sfOqE= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/illarion/gonotify/v2 v2.0.3 h1:B6+SKPo/0Sw8cRJh1aLzNEeNVFfzE3c6N+o+vyxM+9A= -github.com/illarion/gonotify/v2 v2.0.3/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= +github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeVNmJMk= +github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c h1:gYfYE403/nlrGNYj6BEOs9ucLCAGB9gstlSk92DttTg= -github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c/go.mod h1:Di7LXRyUcnvAcLicFhtM9/MlZl/TNgRSDHORM2c6CMI= +github.com/inetaf/tcpproxy v0.0.0-20250203165043-ded522cbd03f h1:hPcDyz0u+Zo14n0fpJggxL9JMAmZIK97TVLcLJLPMDI= +github.com/inetaf/tcpproxy v0.0.0-20250203165043-ded522cbd03f/go.mod h1:Di7LXRyUcnvAcLicFhtM9/MlZl/TNgRSDHORM2c6CMI= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= @@ -596,13 +633,11 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= -github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= -github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jsimonetti/rtnetlink v1.4.0 h1:Z1BF0fRgcETPEa0Kt0MRk3yV5+kF1FWTni6KUFKrq2I= github.com/jsimonetti/rtnetlink v1.4.0/go.mod h1:5W1jDvWdnthFJ7fxYX1GMK07BUpI4oskfOqvPteYS6E= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -627,8 +662,8 @@ github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.4 h1:B6zAaLhOEEcjvUgIYEqystmnFk1Oemn8bvJhbt0GMb8= github.com/kkHAIKE/contextcheck v1.1.4/go.mod h1:1+i/gWqokIa+dm31mqGLZhZJ7Uh44DJGZVmr6QRBNJg= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -649,6 +684,8 @@ github.com/kulti/thelper v0.6.3 h1:ElhKf+AlItIu+xGnI990no4cE2+XaSu1ULymV2Yulxs= github.com/kulti/thelper v0.6.3/go.mod h1:DsqKShOvP40epevkFrvIwkCMNYxMeTNjdWL4dqWHZ6I= github.com/kunwardeep/paralleltest v1.0.10 h1:wrodoaKYzS2mdNVnc4/w31YaXFtsc21PCTdvWJ/lDDs= github.com/kunwardeep/paralleltest v1.0.10/go.mod h1:2C7s65hONVqY7Q5Efj5aLzRCNLjw2h4eMc9EcypGjcY= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/kyoh86/exportloopref v0.1.11 h1:1Z0bcmTypkL3Q4k+IDHMWTcnCliEZcaPiIe0/ymEyhQ= github.com/kyoh86/exportloopref v0.1.11/go.mod h1:qkV4UF1zGl6EkF1ox8L5t9SwyeBAZ3qLMd6up458uqA= github.com/ldez/gomoddirectives v0.2.3 h1:y7MBaisZVDYmKvt9/l1mjNCiSA1BVn34U0ObUcJwlhA= @@ -675,8 +712,12 @@ github.com/matoous/godox v0.0.0-20230222163458-006bad1f9d26 h1:gWg6ZQ4JhDfJPqlo2 github.com/matoous/godox v0.0.0-20230222163458-006bad1f9d26/go.mod h1:1BELzlh859Sh1c6+90blK8lbYy0kwQf1bYlBhBysy1s= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= @@ -686,8 +727,8 @@ github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -708,10 +749,12 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/moby/buildkit v0.20.2 h1:qIeR47eQ1tzI1rwz0on3Xx2enRw/1CKjFhoONVcTlMA= +github.com/moby/buildkit v0.20.2/go.mod h1:DhaF82FjwOElTftl0JUAJpH/SUIUx4UvcFncLeOtlDI= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= -github.com/moby/term v0.0.0-20221205130635-1aeaba878587 h1:HfkjXDfhgVaN5rmueG8cL8KKeFNecRCXFhaJ2qZ5SKA= -github.com/moby/term v0.0.0-20221205130635-1aeaba878587/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= +github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -743,14 +786,14 @@ github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8= -github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= -github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= -github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= +github.com/onsi/ginkgo/v2 v2.21.0 h1:7rg/4f3rB88pb5obDgNZrNHrQ4e6WpjonchcpuBRnZM= +github.com/onsi/ginkgo/v2 v2.21.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc6 h1:XDqvyKsJEbRtATzkgItUqBA7QHk58yxX1Ov9HERHNqU= -github.com/opencontainers/image-spec v1.1.0-rc6/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= @@ -758,13 +801,16 @@ github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJ github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= +github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= +github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml/v2 v2.2.0 h1:QLgLl2yMN7N+ruc31VynXs1vhMZa7CeHHejIeBAsoHo= github.com/pelletier/go-toml/v2 v2.2.0/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/peterbourgon/ff/v3 v3.4.0 h1:QBvM/rizZM1cB0p0lGMdmR7HxZeI/ZrBWB4DqLkMUBc= github.com/peterbourgon/ff/v3 v3.4.0/go.mod h1:zjJVUhx+twciwfDl0zBcFzl4dW8axCRyXE/eKY9RztQ= -github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= +github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= github.com/pkg/diff v0.0.0-20200914180035-5b29258ca4f7/go.mod h1:zO8QMzTeZd5cpnIkz/Gn6iK0jDfGicM1nynOkkPIl28= @@ -776,6 +822,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo= github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -787,32 +835,38 @@ github.com/prometheus-community/pro-bing v0.4.0 h1:YMbv+i08gQz97OZZBwLyvmmQEEzyf github.com/prometheus-community/pro-bing v0.4.0/go.mod h1:b7wRYZtCcPmt4Sz319BykUU241rWLe1VFXyiyWK/dH4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= -github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= -github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= +github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= +github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= -github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff h1:X1Tly81aZ22DA1fxBdfvR3iw8+yFoUBUHMEd+AX/ZXI= github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU= +github.com/puzpuzpuz/xsync v1.5.2 h1:yRAP4wqSOZG+/4pxJ08fPTwrfL0IzE/LKQ/cw509qGY= +github.com/puzpuzpuz/xsync v1.5.2/go.mod h1:K98BYhX3k1dQ2M63t1YNVDanbwUPmBCAhNmVrrxfiGg= github.com/quasilyte/go-ruleguard v0.4.2 h1:htXcXDK6/rO12kiTHKfHuqR4kr3Y4M0J0rOL6CH/BYs= github.com/quasilyte/go-ruleguard v0.4.2/go.mod h1:GJLgqsLeo4qgavUoL8JeGFNS7qcisx3awV/w9eWTmNI= github.com/quasilyte/gogrep v0.5.0 h1:eTKODPXbI8ffJMN+W2aE0+oL0z/nh8/5eNdiO34SOAo= @@ -826,8 +880,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryancurrah/gomodguard v1.3.1 h1:fH+fUg+ngsQO0ruZXXHnA/2aNllWA1whly4a6UvyzGE= github.com/ryancurrah/gomodguard v1.3.1/go.mod h1:DGFHzEhi6iJ0oIDfMuo3TgrS+L9gZvrEfmjjuelnRU0= @@ -846,8 +900,8 @@ github.com/sashamelentyev/usestdlibvars v1.25.0/go.mod h1:9nl0jgOfHKWNFS43Ojw0i7 github.com/securego/gosec/v2 v2.19.0 h1:gl5xMkOI0/E6Hxx0XCY2XujA3V7SNSefA8sC+3f1gnk= github.com/securego/gosec/v2 v2.19.0/go.mod h1:hOkDcHz9J/XIgIlPDXalxjeVYsHxoWUc5zJSHxcB8YM= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= -github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c h1:W65qqJCIOVP4jpqPQ0YvHYKwcMEMVWIzWC5iNQQfBTU= github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c/go.mod h1:/PevMnwAxekIXwN8qQyfc5gl2NlkB3CQlkizAbOkeBs= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= @@ -865,8 +919,8 @@ github.com/sivchari/containedctx v1.0.3 h1:x+etemjbsh2fB5ewm5FeLNi5bUjK0V8n0RB+W github.com/sivchari/containedctx v1.0.3/go.mod h1:c1RDvCbnJLtH4lLcYD/GqwiBSSf4F5Qk0xld2rBqzJ4= github.com/sivchari/tenv v1.7.1 h1:PSpuD4bu6fSmtWMxSGWcvqUUgIn7k3yOJhOIzVWn8Ak= github.com/sivchari/tenv v1.7.1/go.mod h1:64yStXKSOxDfX47NlhVwND4dHwfZDdbp2Lyl018Icvg= -github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= -github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= +github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= +github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/smartystreets/assertions v1.13.1 h1:Ef7KhSmjZcK6AVf9YbJdvPYG9avaF0ZxudX+ThRdWfU= @@ -882,16 +936,19 @@ github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNo github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= -github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= +github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmqnqMYADFk= github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= +github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.16.0 h1:rGGH0XDZhdUOryiDWjmIvUSWpbNqisK8Wk0Vyefw8hc= github.com/spf13/viper v1.16.0/go.mod h1:yg78JgCJcbrQOvV9YLXgkLaZqUidkY9K+Dd1FofRzQg= github.com/ssgreg/nlreturn/v2 v2.2.1 h1:X4XDI7jstt3ySqGU86YGAURbxw3oTDPK9sPEi6YEwQ0= github.com/ssgreg/nlreturn/v2 v2.2.1/go.mod h1:E/iiPB78hV7Szg2YfRgyIrk1AD6JVMTRkkxBiELzh2I= +github.com/stacklok/frizbee v0.1.7 h1:IgrZy8dqKy+vBxNWrZTbDoctnV0doQKrFC6bNbWP5ho= +github.com/stacklok/frizbee v0.1.7/go.mod h1:eqMjHEgRYDSlpYpir3wXO6jyGpxr1dnFTvrTdrTIF7E= github.com/stbenjam/no-sprintf-host-port v0.1.1 h1:tYugd/yrm1O0dV+ThCbaKZh195Dfm07ysF0U6JQXczc= github.com/stbenjam/no-sprintf-host-port v0.1.1/go.mod h1:TLhvtIvONRzdmkFiio4O8LHsN9N74I+PhRquPsxpL0I= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -906,11 +963,13 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/studio-b12/gowebdav v0.9.0 h1:1j1sc9gQnNxbXXM4M/CebPOX4aXYtr7MojAVcN4dHjU= github.com/studio-b12/gowebdav v0.9.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= @@ -919,30 +978,32 @@ github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c h1:+aPplB github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c/go.mod h1:SbErYREK7xXdsRiigaQiQkI9McGRzYMvlKYaP3Nimdk= github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e h1:PtWT87weP5LWHEY//SWsYkSO3RWRZo4OSWagh3YD2vQ= github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= -github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502 h1:34icjjmqJ2HPjrSuJYEkdZ+0ItmGQAQ75cRHIiftIyE= -github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= +github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f h1:PDPGJtm9PFBLNudHGwkfUGp/FWvP+kXXJ0D1pB35F40= +github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 h1:Gzfnfk2TWrk8Jj4P4c1a3CtQyMaTVCznlkLZI++hok4= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55/go.mod h1:4k4QO+dQ3R5FofL+SanAUZe+/QfeK0+OIuwDIRu2vSg= github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41 h1:/V2rCMMWcsjYaYO2MeovLw+ClP63OtXgCF2Y1eb8+Ns= github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41/go.mod h1:/roCdA6gg6lQyw/Oz6gIIGu3ggJKYhF+WC/AQReE5XQ= -github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4 h1:rXZGgEa+k2vJM8xT0PoSKfVXwFGPQ3z3CJfmnHJkZZw= -github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4/go.mod h1:ikbF+YT089eInTp9f2vmvy4+ZVnW5hzX1q2WknxSprQ= +github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869 h1:SRL6irQkKGQKKLzvQP/ke/2ZuB7Py5+XuqtOgSj+iMM= +github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869/go.mod h1:ikbF+YT089eInTp9f2vmvy4+ZVnW5hzX1q2WknxSprQ= github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 h1:4chzWmimtJPxRs2O36yuGRW3f9SYV+bMTTvMBI0EKio= github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05/go.mod h1:PdCqy9JzfWMJf1H5UJW2ip33/d4YkoKN0r67yKH1mG8= github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= -github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba h1:uNo1VCm/xg4alMkIKo8RWTKNx5y1otfVOcKbp+irkL4= -github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba/go.mod h1:DxnqIXBplij66U2ZkL688xy07q97qQ83P+TVueLiHq4= +github.com/tailscale/mkctr v0.0.0-20250228050937-c75ea1476830 h1:SwZ72kr1oRzzSPA5PYB4hzPh22UI0nm0dapn3bHaUPs= +github.com/tailscale/mkctr v0.0.0-20250228050937-c75ea1476830/go.mod h1:qTslktI+Qh9hXo7ZP8xLkl5V8AxUMfxG0xLtkCFLxnw= github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= -github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= -github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= -github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= -github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= +github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb h1:Rtklwm6HUlCtf/MR2MB9iY4FoA16acWWlC5pLrTVa90= +github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb/go.mod h1:R8iCVJnbOB05pGexHK/bKHneIRHpZ3jLl7wMQ0OM/jw= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= -github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc h1:cezaQN9pvKVaw56Ma5qr/G646uKIYP0yQf+OyWN/okc= -github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da h1:jVRUZPRs9sqyKlYHHzHjAqKN+6e/Vog6NpHYeNPJqOw= +github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= @@ -961,31 +1022,34 @@ github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966 h1:quvGphlmUVU+n github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966/go.mod h1:27bSVNWSBOHm+qRp1T9qzaIpsWEP6TbUnei/43HK+PQ= github.com/timonwong/loggercheck v0.9.4 h1:HKKhqrjcVj8sxL7K77beXh0adEm6DLjV/QOGeMXEVi4= github.com/timonwong/loggercheck v0.9.4/go.mod h1:caz4zlPcgvpEkXgVnAJGowHAMW2NwHaNlpS8xDbVhTg= +github.com/tink-crypto/tink-go/v2 v2.1.0 h1:QXFBguwMwTIaU17EgZpEJWsUSc60b1BAGTzBIoMdmok= +github.com/tink-crypto/tink-go/v2 v2.1.0/go.mod h1:y1TnYFt1i2eZVfx4OGc+C+EMp4CoKWAw2VSEuoicHHI= github.com/tomarrell/wrapcheck/v2 v2.8.3 h1:5ov+Cbhlgi7s/a42BprYoxsr73CbdMUTzE3bRDFASUs= github.com/tomarrell/wrapcheck/v2 v2.8.3/go.mod h1:g9vNIyhb5/9TQgumxQyOEqDHsmGYcGsVMOx/xGkqdMo= github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw= github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ= github.com/toqueteos/webbrowser v1.2.0/go.mod h1:XWoZq4cyp9WeUeak7w7LXRUQf1F1ATJMir8RTqb4ayM= -github.com/u-root/gobusybox/src v0.0.0-20231228173702-b69f654846aa h1:unMPGGK/CRzfg923allsikmvk2l7beBeFPUNC4RVX/8= -github.com/u-root/gobusybox/src v0.0.0-20231228173702-b69f654846aa/go.mod h1:Zj4Tt22fJVn/nz/y6Ergm1SahR9dio1Zm/D2/S0TmXM= -github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= -github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= -github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= -github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= -github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8= -github.com/ulikunitz/xz v0.5.11/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= +github.com/u-root/gobusybox/src v0.0.0-20240225013946-a274a8d5d83a h1:eg5FkNoQp76ZsswyGZ+TjYqA/rhKefxK8BW7XOlQsxo= +github.com/u-root/gobusybox/src v0.0.0-20240225013946-a274a8d5d83a/go.mod h1:e/8TmrdreH0sZOw2DFKBaUV7bvDWRq6SeM9PzkuVM68= +github.com/u-root/u-root v0.14.0 h1:Ka4T10EEML7dQ5XDvO9c3MBN8z4nuSnGjcd1jmU2ivg= +github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1WMluqE= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= +github.com/ulikunitz/xz v0.5.15 h1:9DNdB5s+SgV3bQ2ApL10xRc35ck0DuIX/isZvIk+ubY= +github.com/ulikunitz/xz v0.5.15/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ultraware/funlen v0.1.0 h1:BuqclbkY6pO+cvxoq7OsktIXZpgBSkYTQtmwhAK81vI= github.com/ultraware/funlen v0.1.0/go.mod h1:XJqmOQja6DpxarLj6Jj1U7JuoS8PvL4nEqDaQhy22p4= github.com/ultraware/whitespace v0.1.0 h1:O1HKYoh0kIeqE8sFqZf1o0qbORXUCOQFrlaQyZsczZw= github.com/ultraware/whitespace v0.1.0/go.mod h1:/se4r3beMFNmewJ4Xmz0nMQ941GJt+qmSHGP9emHYe0= github.com/uudashr/gocognit v1.1.2 h1:l6BAEKJqQH2UpKAPKdMfZf5kE4W/2xk8pfU1OVLvniI= github.com/uudashr/gocognit v1.1.2/go.mod h1:aAVdLURqcanke8h3vg35BC++eseDm66Z7KmchI5et4k= -github.com/vbatts/tar-split v0.11.5 h1:3bHCTIheBm1qFTcgh9oPu+nNBtX+XJIupG/vacinCts= -github.com/vbatts/tar-split v0.11.5/go.mod h1:yZbwRsSeGjusneWgA781EKej9HF8vme8okylkAeNKLk= +github.com/vbatts/tar-split v0.11.6 h1:4SjTW5+PU11n6fZenf2IPoV8/tz3AaYHMWjf23envGs= +github.com/vbatts/tar-split v0.11.6/go.mod h1:dqKNtesIOr2j2Qv3W/cHjnvk9I8+G7oAkFDFN6TCBEI= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= @@ -1017,27 +1081,31 @@ go-simpler.org/musttag v0.9.0 h1:Dzt6/tyP9ONr5g9h9P3cnYWCxeBFRkd0uJL/w+1Mxos= go-simpler.org/musttag v0.9.0/go.mod h1:gA9nThnalvNSKpEoyp3Ko4/vCX2xTpqKoUtNqXOnVR4= go-simpler.org/sloglint v0.5.0 h1:2YCcd+YMuYpuqthCgubcF5lBSjb6berc5VMOYUHKrpY= go-simpler.org/sloglint v0.5.0/go.mod h1:EUknX5s8iXqf18KQxKnaBHUPVriiPnOrPjjJcsaTcSQ= +go.etcd.io/bbolt v1.3.11 h1:yGEzV1wPz2yVCLsD8ZAiGHhHVlczyC9d1rP43/VCRJ0= +go.etcd.io/bbolt v1.3.11/go.mod h1:dksAq7YMXoljX0xu6VF5DMZGbhYYoLUalEiSySYAS4I= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 h1:sv9kVfal0MK0wBMCOGr+HeJm9v803BkJxGrk2au7j08= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0/go.mod h1:SK2UL73Zy1quvRPonmOmRDiWk1KBV3LyIeeIxcEApWw= -go.opentelemetry.io/otel v1.22.0 h1:xS7Ku+7yTFvDfDraDIJVpw7XPyuHlB9MCiqqX5mcJ6Y= -go.opentelemetry.io/otel v1.22.0/go.mod h1:eoV4iAi3Ea8LkAEI9+GFT44O6T/D0GWAVFyZVCC6pMI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0 h1:9M3+rhx7kZCIQQhQRYaZCdNu1V73tm4TvXs2ntl98C4= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0/go.mod h1:noq80iT8rrHP1SfybmPiRGc9dc5M8RPmGvtwo7Oo7tc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0 h1:FyjCyI9jVEfqhUh2MoSkmolPjfh5fp2hnV0b0irxH4Q= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0/go.mod h1:hYwym2nDEeZfG/motx0p7L7J1N1vyzIThemQsb4g2qY= -go.opentelemetry.io/otel/metric v1.22.0 h1:lypMQnGyJYeuYPhOM/bgjbFM6WE44W1/T45er4d8Hhg= -go.opentelemetry.io/otel/metric v1.22.0/go.mod h1:evJGjVpZv0mQ5QBRJoBF64yMuOf4xCWdXjK8pzFvliY= -go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= -go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= -go.opentelemetry.io/otel/trace v1.22.0 h1:Hg6pPujv0XG9QaVbGOBVHunyuLcCC3jN7WEhPx83XD0= -go.opentelemetry.io/otel/trace v1.22.0/go.mod h1:RbbHXVqKES9QhzZq/fE5UnOSILqRt40a21sPw2He1xo= -go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= -go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q= +go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= +go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 h1:K0XaT3DwHAcV4nKLzcQvwAgSyisUghWoY20I7huthMk= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0/go.mod h1:B5Ki776z/MBnVha1Nzwp5arlzBbE3+1jk+pGmaP5HME= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0 h1:wpMfgF8E1rkrT1Z6meFh1NDtownE9Ii3n3X2GJYjsaU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0/go.mod h1:wAy0T/dUbs468uOlkT31xjvqQgEVXv58BRFWEgn5v/0= +go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ= +go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M= +go.opentelemetry.io/otel/sdk v1.33.0 h1:iax7M131HuAm9QkZotNHEfstof92xM+N8sr3uHXc2IM= +go.opentelemetry.io/otel/sdk v1.33.0/go.mod h1:A1Q5oi7/9XaMlIWzPSxLRWOI8nG3FnzHJNbiENQuihM= +go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= +go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= +go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= go.uber.org/automaxprocs v1.5.3 h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -1046,8 +1114,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= -go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -1060,10 +1128,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= +golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1074,16 +1140,16 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac h1:l5+whBCLH3iH2ZNHYLbAe58bo7yrN4mVcnkHDYz5vvs= +golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac/go.mod h1:hH+7mtFmImwwcMvScyxUhjuVHR3HGaDPMn9rMSUUbxo= golang.org/x/exp/typeparams v0.0.0-20220428152302-39d4317da171/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= -golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= +golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w= +golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1111,8 +1177,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= -golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1152,17 +1218,16 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= -golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1176,8 +1241,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= +golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1196,6 +1261,7 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1221,12 +1287,14 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1234,22 +1302,21 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54 h1:E2/AqCUMZGgd73TQkxUMcMla25GB9i/5HOdLr+uH7Vo= +golang.org/x/telemetry v0.0.0-20251111182119-bc8e575c7b54/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= -golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= -golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1257,18 +1324,16 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= +golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= +golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -1333,8 +1398,12 @@ golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= -golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= +golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= +golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= +golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1369,8 +1438,6 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1400,11 +1467,11 @@ google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7Fc google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20240102182953-50ed04b92917 h1:nz5NESFLZbJGPFxDT/HCn+V1mZ8JGNoY4nUpmW/Y2eg= -google.golang.org/genproto/googleapis/api v0.0.0-20240116215550-a9fa1716bcac h1:OZkkudMUu9LVQMCoRUbI/1p5VCo9BOrlvkqMvWtqa6s= -google.golang.org/genproto/googleapis/api v0.0.0-20240116215550-a9fa1716bcac/go.mod h1:B5xPO//w8qmBDjGReYLpR6UJPnkldGkCSMoH/2vxJeg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240116215550-a9fa1716bcac h1:nUQEQmH/csSvFECKYRv6HWEyypysidKl2I6Qpsglq/0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240116215550-a9fa1716bcac/go.mod h1:daQN87bsDqDoe316QbbvX60nMoJQa4r6Ds0ZuoAe5yA= +google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38 h1:2oV8dfuIkM1Ti7DwXc0BJfnwr9csz4TDXI9EmiI+Rbw= +google.golang.org/genproto/googleapis/api v0.0.0-20241021214115-324edc3d5d38/go.mod h1:vuAjtvlwkDKF6L1GQ0SokiRLCGFfeBUXWr/aFFkHACc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38 h1:zciRKQ4kBpFgpfC5QQCVtnnNAcLIqweL7plyZRQHVpI= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241021214115-324edc3d5d38/go.mod h1:GX3210XPVPUjJbTUbvwI8f2IpZDMZuPJWDzDuebbviI= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1417,8 +1484,8 @@ google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKa google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= -google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= +google.golang.org/grpc v1.69.4 h1:MF5TftSMkd8GLw/m0KM6V8CMOCY6NZ1NQDPGFgbTt4A= +google.golang.org/grpc v1.69.4/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1431,8 +1498,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= +google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1440,6 +1507,10 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4= +gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= +gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= +gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= @@ -1464,8 +1535,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= -gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= -gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -1473,26 +1544,28 @@ honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWh honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= -honnef.co/go/tools v0.5.1 h1:4bH5o3b5ZULQ4UrBmP+63W9r7qIkqJClEA9ko5YKx+I= -honnef.co/go/tools v0.5.1/go.mod h1:e9irvo83WDG9/irijV44wr3tbhcFeRnfpVlRqVwpzMs= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 h1:5SXjd4ET5dYijLaf0O3aOenC0Z4ZafIWSpjUzsQaNho= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0/go.mod h1:EPDDhEZqVHhWuPI5zPAsjU0U7v9xNIWjoOVyZ5ZcniQ= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= -k8s.io/api v0.30.3 h1:ImHwK9DCsPA9uoU3rVh4QHAHHK5dTSv1nxJUapx8hoQ= -k8s.io/api v0.30.3/go.mod h1:GPc8jlzoe5JG3pb0KJCSLX5oAFIW3/qNJITlDj8BH04= -k8s.io/apiextensions-apiserver v0.30.3 h1:oChu5li2vsZHx2IvnGP3ah8Nj3KyqG3kRSaKmijhB9U= -k8s.io/apiextensions-apiserver v0.30.3/go.mod h1:uhXxYDkMAvl6CJw4lrDN4CPbONkF3+XL9cacCT44kV4= -k8s.io/apimachinery v0.30.3 h1:q1laaWCmrszyQuSQCfNB8cFgCuDAoPszKY4ucAjDwHc= -k8s.io/apimachinery v0.30.3/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc= -k8s.io/apiserver v0.30.3 h1:QZJndA9k2MjFqpnyYv/PH+9PE0SHhx3hBho4X0vE65g= -k8s.io/apiserver v0.30.3/go.mod h1:6Oa88y1CZqnzetd2JdepO0UXzQX4ZnOekx2/PtEjrOg= -k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= -k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/api v0.32.0 h1:OL9JpbvAU5ny9ga2fb24X8H6xQlVp+aJMFlgtQjR9CE= +k8s.io/api v0.32.0/go.mod h1:4LEwHZEf6Q/cG96F3dqR965sYOfmPM7rq81BLgsE0p0= +k8s.io/apiextensions-apiserver v0.32.0 h1:S0Xlqt51qzzqjKPxfgX1xh4HBZE+p8KKBq+k2SWNOE0= +k8s.io/apiextensions-apiserver v0.32.0/go.mod h1:86hblMvN5yxMvZrZFX2OhIHAuFIMJIZ19bTvzkP+Fmw= +k8s.io/apimachinery v0.32.0 h1:cFSE7N3rmEEtv4ei5X6DaJPHHX0C+upp+v5lVPiEwpg= +k8s.io/apimachinery v0.32.0/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE= +k8s.io/apiserver v0.32.0 h1:VJ89ZvQZ8p1sLeiWdRJpRD6oLozNZD2+qVSLi+ft5Qs= +k8s.io/apiserver v0.32.0/go.mod h1:HFh+dM1/BE/Hm4bS4nTXHVfN6Z6tFIZPi649n83b4Ag= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= +k8s.io/component-base v0.32.0 h1:d6cWHZkCiiep41ObYQS6IcgzOUQUNpywm39KVYaUqzU= +k8s.io/component-base v0.32.0/go.mod h1:JLG2W5TUxUu5uDyKiH2R/7NnxJo1HlPoRIIbVLkK5eM= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= -k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 h1:BZqlfIlq5YbRMFko6/PM7FjZpUb45WallggurYhKGag= -k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340/go.mod h1:yD4MZYeKMBwQKVht279WycxKyM84kkAx2DPrTXaeb98= -k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1J8+AsQnQCKsi8A= -k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f h1:GA7//TjRY9yWGy1poLzYYJJ4JRdzg3+O6e8I+e+8T5Y= +k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f/go.mod h1:R/HEjbvWI0qdfb8viZUeVZm0X6IZnxAydC7YU42CMw4= +k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 h1:M3sRQVHv7vB20Xc2ybTt7ODCeFj6JSWYFzOFnYeS6Ro= +k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= mvdan.cc/gofumpt v0.6.0 h1:G3QvahNDmpD+Aek/bNOLrFR2XC6ZAdo62dZu65gmwGo= mvdan.cc/gofumpt v0.6.0/go.mod h1:4L0wf+kgIPZtcCWXynNS2e6bhmj73umwnuXSZarixzA= mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 h1:zCr3iRRgdk5eIikZNDphGcM6KGVTx3Yu+/Uu9Es254w= @@ -1500,14 +1573,14 @@ mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14/go.mod h1:ZzZjEpJDOmx8TdVU6u rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= -sigs.k8s.io/controller-runtime v0.18.4 h1:87+guW1zhvuPLh1PHybKdYFLU0YJp4FhJRmiHvm5BZw= -sigs.k8s.io/controller-runtime v0.18.4/go.mod h1:TVoGrfdpbA9VRFaRnKgk9P5/atA0pMwq+f+msb9M8Sg= -sigs.k8s.io/controller-tools v0.15.1-0.20240618033008-7824932b0cab h1:Fq4VD28nejtsijBNTeRRy9Tt3FVwq+o6NB7fIxja8uY= -sigs.k8s.io/controller-tools v0.15.1-0.20240618033008-7824932b0cab/go.mod h1:egedX5jq2KrZ3A2zaOz3e2DSsh5BhFyyjvNcBRIQel8= -sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= -sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= -sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= -sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= +sigs.k8s.io/controller-runtime v0.19.4 h1:SUmheabttt0nx8uJtoII4oIP27BVVvAKFvdvGFwV/Qo= +sigs.k8s.io/controller-runtime v0.19.4/go.mod h1:iRmWllt8IlaLjvTTDLhRBXIEtkCK6hwVBJJsYS9Ajf4= +sigs.k8s.io/controller-tools v0.17.0 h1:KaEQZbhrdY6J3zLBHplt+0aKUp8PeIttlhtF2UDo6bI= +sigs.k8s.io/controller-tools v0.17.0/go.mod h1:SKoWY8rwGWDzHtfnhmOwljn6fViG0JF7/xmnxpklgjo= +sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= +sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= +sigs.k8s.io/structured-merge-diff/v4 v4.4.2 h1:MdmvkGuXi/8io6ixD5wud3vOLwc1rj0aNqRlpuvjmwA= +sigs.k8s.io/structured-merge-diff/v4 v4.4.2/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= diff --git a/go.toolchain.branch b/go.toolchain.branch index 47469a20a..a2bebbeb7 100644 --- a/go.toolchain.branch +++ b/go.toolchain.branch @@ -1 +1 @@ -tailscale.go1.23 +tailscale.go1.25 diff --git a/go.toolchain.rev b/go.toolchain.rev index 5d87594c2..9ea6b37dc 100644 --- a/go.toolchain.rev +++ b/go.toolchain.rev @@ -1 +1 @@ -bf15628b759344c6fc7763795a405ba65b8be5d7 +5c01b77ad0d27a8bd4ef89ef7e713fd7043c5a91 diff --git a/go.toolchain.rev.sri b/go.toolchain.rev.sri new file mode 100644 index 000000000..a62a52599 --- /dev/null +++ b/go.toolchain.rev.sri @@ -0,0 +1 @@ +sha256-2TYziJLJrFOW2FehhahKficnDACJEwjuvVYyeQZbrcc= diff --git a/go.toolchain.version b/go.toolchain.version new file mode 100644 index 000000000..5bb76b575 --- /dev/null +++ b/go.toolchain.version @@ -0,0 +1 @@ +1.25.3 diff --git a/gokrazy/build.go b/gokrazy/build.go index 2392af0cb..c1ee1cbeb 100644 --- a/gokrazy/build.go +++ b/gokrazy/build.go @@ -11,7 +11,6 @@ package main import ( "bytes" - "cmp" "encoding/json" "errors" "flag" @@ -30,7 +29,6 @@ import ( var ( app = flag.String("app", "tsapp", "appliance name; one of the subdirectories of gokrazy/") bucket = flag.String("bucket", "tskrazy-import", "S3 bucket to upload disk image to while making AMI") - goArch = flag.String("arch", cmp.Or(os.Getenv("GOARCH"), "amd64"), "GOARCH architecture to build for: arm64 or amd64") build = flag.Bool("build", false, "if true, just build locally and stop, without uploading") ) @@ -54,6 +52,26 @@ func findMkfsExt4() (string, error) { return "", errors.New("No mkfs.ext4 found on system") } +var conf gokrazyConfig + +// gokrazyConfig is the subset of gokrazy/internal/config.Struct +// that we care about. +type gokrazyConfig struct { + // Environment is os.Environment pairs to use when + // building userspace. + // See https://gokrazy.org/userguide/instance-config/#environment + Environment []string +} + +func (c *gokrazyConfig) GOARCH() string { + for _, e := range c.Environment { + if v, ok := strings.CutPrefix(e, "GOARCH="); ok { + return v + } + } + return "" +} + func main() { flag.Parse() @@ -61,6 +79,19 @@ func main() { log.Fatalf("--app must be non-empty name such as 'tsapp' or 'natlabapp'") } + confJSON, err := os.ReadFile(filepath.Join(*app, "config.json")) + if err != nil { + log.Fatalf("reading config.json: %v", err) + } + if err := json.Unmarshal(confJSON, &conf); err != nil { + log.Fatalf("unmarshaling config.json: %v", err) + } + switch conf.GOARCH() { + case "amd64", "arm64": + default: + log.Fatalf("config.json GOARCH %q must be amd64 or arm64", conf.GOARCH()) + } + if err := buildImage(); err != nil { log.Fatalf("build image: %v", err) } @@ -106,7 +137,6 @@ func buildImage() error { // Build the tsapp.img var buf bytes.Buffer cmd := exec.Command("go", "run", - "-exec=env GOOS=linux GOARCH="+*goArch+" ", "github.com/gokrazy/tools/cmd/gok", "--parent_dir="+dir, "--instance="+*app, @@ -253,13 +283,13 @@ func waitForImportSnapshot(importTaskID string) (snapID string, err error) { func makeAMI(name, ebsSnapID string) (ami string, err error) { var arch string - switch *goArch { + switch conf.GOARCH() { case "arm64": arch = "arm64" case "amd64": arch = "x86_64" default: - return "", fmt.Errorf("unknown arch %q", *goArch) + return "", fmt.Errorf("unknown arch %q", conf.GOARCH()) } out, err := exec.Command("aws", "ec2", "register-image", "--name", name, diff --git a/gokrazy/go.mod b/gokrazy/go.mod index a9ba5a07d..f7483f41d 100644 --- a/gokrazy/go.mod +++ b/gokrazy/go.mod @@ -1,13 +1,13 @@ module tailscale.com/gokrazy -go 1.23.1 +go 1.23 -require github.com/gokrazy/tools v0.0.0-20240730192548-9f81add3a91e +require github.com/gokrazy/tools v0.0.0-20250128200151-63160424957c require ( github.com/breml/rootcerts v0.2.10 // indirect github.com/donovanhide/eventsource v0.0.0-20210830082556-c59027999da0 // indirect - github.com/gokrazy/internal v0.0.0-20240629150625-a0f1dee26ef5 // indirect + github.com/gokrazy/internal v0.0.0-20250126213949-423a5b587b57 // indirect github.com/gokrazy/updater v0.0.0-20230215172637-813ccc7f21e2 // indirect github.com/google/renameio/v2 v2.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -15,9 +15,5 @@ require ( github.com/spf13/pflag v1.0.5 // indirect golang.org/x/mod v0.11.0 // indirect golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/sys v0.28.0 // indirect ) - -replace github.com/gokrazy/gokrazy => github.com/tailscale/gokrazy v0.0.0-20240812224643-6b21ddf64678 - -replace github.com/gokrazy/tools => github.com/tailscale/gokrazy-tools v0.0.0-20240730192548-9f81add3a91e diff --git a/gokrazy/go.sum b/gokrazy/go.sum index dfac8ca37..170d15b3d 100644 --- a/gokrazy/go.sum +++ b/gokrazy/go.sum @@ -3,8 +3,10 @@ github.com/breml/rootcerts v0.2.10/go.mod h1:24FDtzYMpqIeYC7QzaE8VPRQaFZU5TIUDly github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/donovanhide/eventsource v0.0.0-20210830082556-c59027999da0 h1:C7t6eeMaEQVy6e8CarIhscYQlNmw5e3G36y7l7Y21Ao= github.com/donovanhide/eventsource v0.0.0-20210830082556-c59027999da0/go.mod h1:56wL82FO0bfMU5RvfXoIwSOP2ggqqxT+tAfNEIyxuHw= -github.com/gokrazy/internal v0.0.0-20240629150625-a0f1dee26ef5 h1:XDklMxV0pE5jWiNaoo5TzvWfqdoiRRScmr4ZtDzE4Uw= -github.com/gokrazy/internal v0.0.0-20240629150625-a0f1dee26ef5/go.mod h1:t3ZirVhcs9bH+fPAJuGh51rzT7sVCZ9yfXvszf0ZjF0= +github.com/gokrazy/internal v0.0.0-20250126213949-423a5b587b57 h1:f5bEvO4we3fbfiBkECrrUgWQ8OH6J3SdB2Dwxid/Yx4= +github.com/gokrazy/internal v0.0.0-20250126213949-423a5b587b57/go.mod h1:SJG1KwuJQXFEoBgryaNCkMbdISyovDgZd0xmXJRZmiw= +github.com/gokrazy/tools v0.0.0-20250128200151-63160424957c h1:iEbS8GrNOn671ze8J/AfrYFEVzf8qMx8aR5K0VxPK2w= +github.com/gokrazy/tools v0.0.0-20250128200151-63160424957c/go.mod h1:f2vZhnaPzy92+Bjpx1iuZHK7VuaJx6SNCWQWmu23HZA= github.com/gokrazy/updater v0.0.0-20230215172637-813ccc7f21e2 h1:kBY5R1tSf+EYZ+QaSrofLaVJtBqYsVNVBWkdMq3Smcg= github.com/gokrazy/updater v0.0.0-20230215172637-813ccc7f21e2/go.mod h1:PYOvzGOL4nlBmuxu7IyKQTFLaxr61+WPRNRzVtuYOHw= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -19,14 +21,12 @@ github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/tailscale/gokrazy-tools v0.0.0-20240730192548-9f81add3a91e h1:3/xIc1QCvnKL7BCLng9od98HEvxCadjvqiI/bN+Twso= -github.com/tailscale/gokrazy-tools v0.0.0-20240730192548-9f81add3a91e/go.mod h1:eTZ0QsugEPFU5UAQ/87bKMkPxQuTNa7+iFAIahOFwRg= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum b/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum index 9123439ed..ae814f316 100644 --- a/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum +++ b/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum @@ -4,32 +4,58 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= @@ -46,10 +72,14 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 h1:KbX3Z3CgiYlbaavUq3Cj9/MjpO+88S7/AGXzynVDv84= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288/go.mod h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -62,6 +92,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= @@ -70,6 +102,8 @@ github.com/illarion/gonotify v1.0.1 h1:F1d+0Fgbq/sDWjj/r66ekjDG+IDeecQKUFH4wNwso github.com/illarion/gonotify v1.0.1/go.mod h1:zt5pmDofZpU1f8aqlK0+95eQhoEAn/d4G4B/FjVW4jE= github.com/illarion/gonotify/v2 v2.0.2 h1:oDH5yvxq9oiQGWUeut42uShcWzOy/hsT9E7pvO95+kQ= github.com/illarion/gonotify/v2 v2.0.2/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= +github.com/illarion/gonotify/v2 v2.0.3 h1:B6+SKPo/0Sw8cRJh1aLzNEeNVFfzE3c6N+o+vyxM+9A= +github.com/illarion/gonotify/v2 v2.0.3/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jellydator/ttlcache/v3 v3.1.0 h1:0gPFG0IHHP6xyUyXq+JaD8fwkDCqgqwohXNJBcYE71g= @@ -84,6 +118,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -96,6 +132,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -126,12 +164,18 @@ github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 h1:BcEJP2ewTIK2ZCsqgl6YGpuO6+oKqqag5HHb7ehljKw= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= @@ -144,6 +188,8 @@ github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -152,42 +198,66 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07 h1:Z+Zg+aXJYq6f4TK2E4H+vZkQ4dJAWnInXDR6hM9znxo= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab h1:BMkEEWYOjkvOX7+YKOGbp6jCyQ5pR2j0Ah47p1Vdsx4= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= diff --git a/gokrazy/natlabapp.arm64/config.json b/gokrazy/natlabapp.arm64/config.json index 2577f61a5..2ba9a20f9 100644 --- a/gokrazy/natlabapp.arm64/config.json +++ b/gokrazy/natlabapp.arm64/config.json @@ -20,6 +20,10 @@ } } }, + "Environment": [ + "GOOS=linux", + "GOARCH=arm64" + ], "KernelPackage": "github.com/gokrazy/kernel.arm64", "FirmwarePackage": "github.com/gokrazy/kernel.arm64", "EEPROMPackage": "", diff --git a/gokrazy/natlabapp/builddir/tailscale.com/go.sum b/gokrazy/natlabapp/builddir/tailscale.com/go.sum index baa378c46..25f15059d 100644 --- a/gokrazy/natlabapp/builddir/tailscale.com/go.sum +++ b/gokrazy/natlabapp/builddir/tailscale.com/go.sum @@ -4,32 +4,58 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= @@ -46,10 +72,14 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 h1:KbX3Z3CgiYlbaavUq3Cj9/MjpO+88S7/AGXzynVDv84= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288/go.mod h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -62,6 +92,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= @@ -86,6 +118,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -98,6 +132,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -128,14 +164,20 @@ github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc h1:cezaQN9pvKVaw56Ma5qr/G646uKIYP0yQf+OyWN/okc= github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 h1:BcEJP2ewTIK2ZCsqgl6YGpuO6+oKqqag5HHb7ehljKw= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= @@ -148,6 +190,8 @@ github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -156,42 +200,66 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07 h1:Z+Zg+aXJYq6f4TK2E4H+vZkQ4dJAWnInXDR6hM9znxo= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab h1:BMkEEWYOjkvOX7+YKOGbp6jCyQ5pR2j0Ah47p1Vdsx4= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= diff --git a/gokrazy/natlabapp/config.json b/gokrazy/natlabapp/config.json index 902f14acd..1968b2aac 100644 --- a/gokrazy/natlabapp/config.json +++ b/gokrazy/natlabapp/config.json @@ -20,6 +20,10 @@ } } }, + "Environment": [ + "GOOS=linux", + "GOARCH=amd64" + ], "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "", "EEPROMPackage": "", diff --git a/gokrazy/tsapp/builddir/tailscale.com/go.sum b/gokrazy/tsapp/builddir/tailscale.com/go.sum index b3b73e2d0..2ffef7bf7 100644 --- a/gokrazy/tsapp/builddir/tailscale.com/go.sum +++ b/gokrazy/tsapp/builddir/tailscale.com/go.sum @@ -4,48 +4,80 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.23 h1:4M6+isWdcStXEf15G/RbrMPOQj1dZ7HPZCGwE4kOeP0= +github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 h1:KbX3Z3CgiYlbaavUq3Cj9/MjpO+88S7/AGXzynVDv84= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288/go.mod h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -58,12 +90,16 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= github.com/hdevalence/ed25519consensus v0.2.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= github.com/illarion/gonotify v1.0.1 h1:F1d+0Fgbq/sDWjj/r66ekjDG+IDeecQKUFH4wNwsoio= github.com/illarion/gonotify v1.0.1/go.mod h1:zt5pmDofZpU1f8aqlK0+95eQhoEAn/d4G4B/FjVW4jE= +github.com/illarion/gonotify/v2 v2.0.3 h1:B6+SKPo/0Sw8cRJh1aLzNEeNVFfzE3c6N+o+vyxM+9A= +github.com/illarion/gonotify/v2 v2.0.3/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jellydator/ttlcache/v3 v3.1.0 h1:0gPFG0IHHP6xyUyXq+JaD8fwkDCqgqwohXNJBcYE71g= @@ -78,6 +114,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -90,6 +128,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -116,14 +156,22 @@ github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29X github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQSPhaUPjUQwozcRdDdSxxqhNgNZ3drZFk= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= +github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= +github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 h1:BcEJP2ewTIK2ZCsqgl6YGpuO6+oKqqag5HHb7ehljKw= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= @@ -136,6 +184,8 @@ github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -144,42 +194,66 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07 h1:Z+Zg+aXJYq6f4TK2E4H+vZkQ4dJAWnInXDR6hM9znxo= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab h1:BMkEEWYOjkvOX7+YKOGbp6jCyQ5pR2j0Ah47p1Vdsx4= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= diff --git a/gokrazy/tsapp/config.json b/gokrazy/tsapp/config.json index 33dd98a96..b88be53a4 100644 --- a/gokrazy/tsapp/config.json +++ b/gokrazy/tsapp/config.json @@ -27,6 +27,10 @@ } } }, + "Environment": [ + "GOOS=linux", + "GOARCH=amd64" + ], "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "github.com/tailscale/gokrazy-kernel", "InternalCompatibilityFlags": {} diff --git a/health/health.go b/health/health.go index 216535d17..f0f6a6ffb 100644 --- a/health/health.go +++ b/health/health.go @@ -8,7 +8,6 @@ package health import ( "context" "errors" - "expvar" "fmt" "maps" "net/http" @@ -20,19 +19,19 @@ import ( "time" "tailscale.com/envknob" - "tailscale.com/metrics" + "tailscale.com/feature/buildfeatures" + "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/opt" "tailscale.com/util/cibuild" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" - "tailscale.com/util/multierr" - "tailscale.com/util/set" - "tailscale.com/util/usermetric" "tailscale.com/version" ) var ( - mu sync.Mutex + mu syncs.Mutex debugHandler map[string]http.Handler ) @@ -63,6 +62,21 @@ var receiveNames = []string{ // Tracker tracks the health of various Tailscale subsystems, // comparing each subsystems' state with each other to make sure // they're consistent based on the user's intended state. +// +// If a client [Warnable] becomes unhealthy or its unhealthy state is updated, +// an event will be emitted with WarnableChanged set to true and the Warnable +// and its UnhealthyState: +// +// Change{WarnableChanged: true, Warnable: w, UnhealthyState: us} +// +// If a Warnable becomes healthy, an event will be emitted with +// WarnableChanged set to true, the Warnable set, and UnhealthyState set to nil: +// +// Change{WarnableChanged: true, Warnable: w, UnhealthyState: nil} +// +// If the health messages from the control-plane change, an event will be +// emitted with ControlHealthChanged set to true. Recipients can fetch the set of +// control-plane health messages by calling [Tracker.CurrentState]: type Tracker struct { // MagicSockReceiveFuncs tracks the state of the three // magicsock receive functions: IPv4, IPv6, and DERP. @@ -73,6 +87,11 @@ type Tracker struct { // mu should not be held during init. initOnce sync.Once + testClock tstime.Clock // nil means use time.Now / tstime.StdClock{} + + eventClient *eventbus.Client + changePub *eventbus.Publisher[Change] + // mu guards everything that follows. mu sync.Mutex @@ -80,39 +99,87 @@ type Tracker struct { warnableVal map[*Warnable]*warningState // pendingVisibleTimers contains timers for Warnables that are unhealthy, but are // not visible to the user yet, because they haven't been unhealthy for TimeToVisible - pendingVisibleTimers map[*Warnable]*time.Timer + pendingVisibleTimers map[*Warnable]tstime.TimerController // sysErr maps subsystems to their current error (or nil if the subsystem is healthy) // Deprecated: using Warnables should be preferred - sysErr map[Subsystem]error - watchers set.HandleSet[func(*Warnable, *UnhealthyState)] // opt func to run if error state changes - timer *time.Timer + sysErr map[Subsystem]error + timer tstime.TimerController latestVersion *tailcfg.ClientVersion // or nil checkForUpdates bool applyUpdates opt.Bool - inMapPoll bool - inMapPollSince time.Time - lastMapPollEndedAt time.Time - lastStreamedMapResponse time.Time - lastNoiseDial time.Time - derpHomeRegion int - derpHomeless bool - derpRegionConnected map[int]bool - derpRegionHealthProblem map[int]string - derpRegionLastFrame map[int]time.Time - derpMap *tailcfg.DERPMap // last DERP map from control, could be nil if never received one - lastMapRequestHeard time.Time // time we got a 200 from control for a MapRequest - ipnState string - ipnWantRunning bool - ipnWantRunningLastTrue time.Time // when ipnWantRunning last changed false -> true - anyInterfaceUp opt.Bool // empty means unknown (assume true) - controlHealth []string - lastLoginErr error - localLogConfigErr error - tlsConnectionErrors map[string]error // map[ServerName]error - metricHealthMessage *metrics.MultiLabelMap[metricHealthMessageLabel] + inMapPoll bool + inMapPollSince time.Time + lastMapPollEndedAt time.Time + lastStreamedMapResponse time.Time + lastNoiseDial time.Time + derpHomeRegion int + derpHomeless bool + derpRegionConnected map[int]bool + derpRegionHealthProblem map[int]string + derpRegionLastFrame map[int]time.Time + derpMap *tailcfg.DERPMap // last DERP map from control, could be nil if never received one + lastMapRequestHeard time.Time // time we got a 200 from control for a MapRequest + ipnState string + ipnWantRunning bool + ipnWantRunningLastTrue time.Time // when ipnWantRunning last changed false -> true + anyInterfaceUp opt.Bool // empty means unknown (assume true) + lastNotifiedControlMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage // latest control messages processed, kept for change detection + controlMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage // latest control messages received + lastLoginErr error + localLogConfigErr error + tlsConnectionErrors map[string]error // map[ServerName]error + metricHealthMessage any // nil or *metrics.MultiLabelMap[metricHealthMessageLabel] +} + +// NewTracker contructs a new [Tracker] and attaches the given eventbus. +// NewTracker will panic is no eventbus is given. +func NewTracker(bus *eventbus.Bus) *Tracker { + if !buildfeatures.HasHealth { + return &Tracker{} + } + if bus == nil { + panic("no eventbus set") + } + + ec := bus.Client("health.Tracker") + t := &Tracker{ + eventClient: ec, + changePub: eventbus.Publish[Change](ec), + } + t.timer = t.clock().AfterFunc(time.Minute, t.timerSelfCheck) + + ec.Monitor(t.awaitEventClientDone) + + return t +} + +func (t *Tracker) awaitEventClientDone(ec *eventbus.Client) { + <-ec.Done() + t.mu.Lock() + defer t.mu.Unlock() + + for _, timer := range t.pendingVisibleTimers { + timer.Stop() + } + t.timer.Stop() + clear(t.pendingVisibleTimers) +} + +func (t *Tracker) now() time.Time { + if t.testClock != nil { + return t.testClock.Now() + } + return time.Now() +} + +func (t *Tracker) clock() tstime.Clock { + if t.testClock != nil { + return t.testClock + } + return tstime.StdClock{} } // Subsystem is the name of a subsystem whose health can be monitored. @@ -128,9 +195,6 @@ const ( // SysDNS is the name of the net/dns subsystem. SysDNS = Subsystem("dns") - // SysDNSOS is the name of the net/dns OSConfigurator subsystem. - SysDNSOS = Subsystem("dns-os") - // SysDNSManager is the name of the net/dns manager subsystem. SysDNSManager = Subsystem("dns-manager") @@ -141,7 +205,7 @@ const ( var subsystemsWarnables = map[Subsystem]*Warnable{} func init() { - for _, s := range []Subsystem{SysRouter, SysDNS, SysDNSOS, SysDNSManager, SysTKA} { + for _, s := range []Subsystem{SysRouter, SysDNS, SysDNSManager, SysTKA} { w := Register(&Warnable{ Code: WarnableCode(s), Severity: SeverityMedium, @@ -159,6 +223,9 @@ const legacyErrorArgKey = "LegacyError" // temporarily (2024-06-14) while we migrate the old health infrastructure based // on Subsystems to the new Warnables architecture. func (s Subsystem) Warnable() *Warnable { + if !buildfeatures.HasHealth { + return &noopWarnable + } w, ok := subsystemsWarnables[s] if !ok { panic(fmt.Sprintf("health: no Warnable for Subsystem %q", s)) @@ -168,10 +235,15 @@ func (s Subsystem) Warnable() *Warnable { var registeredWarnables = map[WarnableCode]*Warnable{} +var noopWarnable Warnable + // Register registers a new Warnable with the health package and returns it. // Register panics if the Warnable was already registered, because Warnables // should be unique across the program. func Register(w *Warnable) *Warnable { + if !buildfeatures.HasHealth { + return &noopWarnable + } if registeredWarnables[w.Code] != nil { panic(fmt.Sprintf("health: a Warnable with code %q was already registered", w.Code)) } @@ -183,6 +255,9 @@ func Register(w *Warnable) *Warnable { // unregister removes a Warnable from the health package. It should only be used // for testing purposes. func unregister(w *Warnable) { + if !buildfeatures.HasHealth { + return + } if registeredWarnables[w.Code] == nil { panic(fmt.Sprintf("health: attempting to unregister Warnable %q that was not registered", w.Code)) } @@ -193,13 +268,15 @@ func unregister(w *Warnable) { // the program. type WarnableCode string -// A Warnable is something that we might want to warn the user about, or not. A Warnable is either -// in an healthy or unhealth state. A Warnable is unhealthy if the Tracker knows about a WarningState -// affecting the Warnable. -// In most cases, Warnables are components of the backend (for instance, "DNS" or "Magicsock"). -// Warnables are similar to the Subsystem type previously used in this package, but they provide -// a unique identifying code for each Warnable, along with more metadata that makes it easier for -// a GUI to display the Warnable in a user-friendly way. +// A Warnable is something that we might want to warn the user about, or not. A +// Warnable is either in a healthy or unhealthy state. A Warnable is unhealthy if +// the Tracker knows about a WarningState affecting the Warnable. +// +// In most cases, Warnables are components of the backend (for instance, "DNS" +// or "Magicsock"). Warnables are similar to the Subsystem type previously used +// in this package, but they provide a unique identifying code for each +// Warnable, along with more metadata that makes it easier for a GUI to display +// the Warnable in a user-friendly way. type Warnable struct { // Code is a string that uniquely identifies this Warnable across the entire Tailscale backend, // and can be mapped to a user-displayable localized string. @@ -217,9 +294,11 @@ type Warnable struct { // TODO(angott): turn this into a SeverityFunc, which allows the Warnable to change its severity based on // the Args of the unhappy state, just like we do in the Text function. Severity Severity - // DependsOn is a set of Warnables that this Warnable depends, on and need to be healthy - // before this Warnable can also be healthy again. The GUI can use this information to ignore + // DependsOn is a set of Warnables that this Warnable depends on and need to be healthy + // before this Warnable is relevant. The GUI can use this information to ignore // this Warnable if one of its dependencies is unhealthy. + // That is, if any of these Warnables are unhealthy, then this Warnable is not relevant + // and should be considered healthy to bother the user about. DependsOn []*Warnable // MapDebugFlag is a MapRequest.DebugFlag that is sent to control when this Warnable is unhealthy @@ -251,6 +330,9 @@ func StaticMessage(s string) func(Args) string { // some lost Tracker plumbing, we want to capture stack trace // samples when it occurs. func (t *Tracker) nil() bool { + if !buildfeatures.HasHealth { + return true + } if t != nil { return false } @@ -312,38 +394,23 @@ func (ws *warningState) Equal(other *warningState) bool { // IsVisible returns whether the Warnable should be visible to the user, based on the TimeToVisible // field of the Warnable and the BrokenSince time when the Warnable became unhealthy. -func (w *Warnable) IsVisible(ws *warningState) bool { +func (w *Warnable) IsVisible(ws *warningState, clockNow func() time.Time) bool { if ws == nil || w.TimeToVisible == 0 { return true } - return time.Since(ws.BrokenSince) >= w.TimeToVisible + return clockNow().Sub(ws.BrokenSince) >= w.TimeToVisible } -// SetMetricsRegistry sets up the metrics for the Tracker. It takes -// a usermetric.Registry and registers the metrics there. -func (t *Tracker) SetMetricsRegistry(reg *usermetric.Registry) { - if reg == nil || t.metricHealthMessage != nil { - return +// IsUnhealthy reports whether the current state is unhealthy because the given +// warnable is set. +func (t *Tracker) IsUnhealthy(w *Warnable) bool { + if !buildfeatures.HasHealth || t.nil() { + return false } - - t.metricHealthMessage = usermetric.NewMultiLabelMapWithRegistry[metricHealthMessageLabel]( - reg, - "tailscaled_health_messages", - "gauge", - "Number of health messages broken down by type.", - ) - - t.metricHealthMessage.Set(metricHealthMessageLabel{ - Type: "warning", - }, expvar.Func(func() any { - if t.nil() { - return 0 - } - t.mu.Lock() - defer t.mu.Unlock() - t.updateBuiltinWarnablesLocked() - return int64(len(t.stringsLocked())) - })) + t.mu.Lock() + defer t.mu.Unlock() + _, exists := t.warnableVal[w] + return exists } // SetUnhealthy sets a warningState for the given Warnable with the provided Args, and should be @@ -351,7 +418,7 @@ func (t *Tracker) SetMetricsRegistry(reg *usermetric.Registry) { // SetUnhealthy takes ownership of args. The args can be nil if no additional information is // needed for the unhealthy state. func (t *Tracker) SetUnhealthy(w *Warnable, args Args) { - if t.nil() { + if !buildfeatures.HasHealth || t.nil() { return } t.mu.Lock() @@ -360,13 +427,13 @@ func (t *Tracker) SetUnhealthy(w *Warnable, args Args) { } func (t *Tracker) setUnhealthyLocked(w *Warnable, args Args) { - if w == nil { + if !buildfeatures.HasHealth || w == nil { return } // If we already have a warningState for this Warnable with an earlier BrokenSince time, keep that // BrokenSince time. - brokenSince := time.Now() + brokenSince := t.now() if existingWS := t.warnableVal[w]; existingWS != nil { brokenSince = existingWS.BrokenSince } @@ -381,35 +448,37 @@ func (t *Tracker) setUnhealthyLocked(w *Warnable, args Args) { prevWs := t.warnableVal[w] mak.Set(&t.warnableVal, w, ws) if !ws.Equal(prevWs) { - for _, cb := range t.watchers { - // If the Warnable has been unhealthy for more than its TimeToVisible, the callback should be - // executed immediately. Otherwise, the callback should be enqueued to run once the Warnable - // becomes visible. - if w.IsVisible(ws) { - go cb(w, w.unhealthyState(ws)) - continue - } - - // The time remaining until the Warnable will be visible to the user is the TimeToVisible - // minus the time that has already passed since the Warnable became unhealthy. - visibleIn := w.TimeToVisible - time.Since(brokenSince) - mak.Set(&t.pendingVisibleTimers, w, time.AfterFunc(visibleIn, func() { + + change := Change{ + WarnableChanged: true, + Warnable: w, + UnhealthyState: w.unhealthyState(ws), + } + // Publish the change to the event bus. If the change is already visible + // now, publish it immediately; otherwise queue a timer to publish it at + // a future time when it becomes visible. + if w.IsVisible(ws, t.now) { + t.changePub.Publish(change) + } else { + visibleIn := w.TimeToVisible - t.now().Sub(brokenSince) + tc := t.clock().AfterFunc(visibleIn, func() { t.mu.Lock() defer t.mu.Unlock() // Check if the Warnable is still unhealthy, as it could have become healthy between the time // the timer was set for and the time it was executed. if t.warnableVal[w] != nil { - go cb(w, w.unhealthyState(ws)) + t.changePub.Publish(change) delete(t.pendingVisibleTimers, w) } - })) + }) + mak.Set(&t.pendingVisibleTimers, w, tc) } } } // SetHealthy removes any warningState for the given Warnable. func (t *Tracker) SetHealthy(w *Warnable) { - if t.nil() { + if !buildfeatures.HasHealth || t.nil() { return } t.mu.Lock() @@ -418,7 +487,7 @@ func (t *Tracker) SetHealthy(w *Warnable) { } func (t *Tracker) setHealthyLocked(w *Warnable) { - if t.warnableVal[w] == nil { + if !buildfeatures.HasHealth || t.warnableVal[w] == nil { // Nothing to remove return } @@ -431,9 +500,20 @@ func (t *Tracker) setHealthyLocked(w *Warnable) { delete(t.pendingVisibleTimers, w) } - for _, cb := range t.watchers { - go cb(w, nil) + change := Change{ + WarnableChanged: true, + Warnable: w, } + t.changePub.Publish(change) +} + +// notifyWatchersControlChangedLocked calls each watcher to signal that control +// health messages have changed (and should be fetched via CurrentState). +func (t *Tracker) notifyWatchersControlChangedLocked() { + change := Change{ + ControlHealthChanged: true, + } + t.changePub.Publish(change) } // AppendWarnableDebugFlags appends to base any health items that are currently in failed @@ -459,35 +539,23 @@ func (t *Tracker) AppendWarnableDebugFlags(base []string) []string { return ret } -// RegisterWatcher adds a function that will be called whenever the health state of any Warnable changes. -// If a Warnable becomes unhealthy or its unhealthy state is updated, the callback will be called with its -// current Representation. -// If a Warnable becomes healthy, the callback will be called with ws set to nil. -// The provided callback function will be executed in its own goroutine. The returned function can be used -// to unregister the callback. -func (t *Tracker) RegisterWatcher(cb func(w *Warnable, r *UnhealthyState)) (unregister func()) { - if t.nil() { - return func() {} - } - t.initOnce.Do(t.doOnceInit) - t.mu.Lock() - defer t.mu.Unlock() - if t.watchers == nil { - t.watchers = set.HandleSet[func(*Warnable, *UnhealthyState)]{} - } - handle := t.watchers.Add(cb) - if t.timer == nil { - t.timer = time.AfterFunc(time.Minute, t.timerSelfCheck) - } - return func() { - t.mu.Lock() - defer t.mu.Unlock() - delete(t.watchers, handle) - if len(t.watchers) == 0 && t.timer != nil { - t.timer.Stop() - t.timer = nil - } - } +// Change is used to communicate a change to health. This could either be due to +// a Warnable changing from health to unhealthy (or vice-versa), or because the +// health messages received from the control-plane have changed. +// +// Exactly one *Changed field will be true. +type Change struct { + // ControlHealthChanged indicates it was health messages from the + // control-plane server that changed. + ControlHealthChanged bool + + // WarnableChanged indicates it was a client Warnable which changed state. + WarnableChanged bool + // Warnable is whose health changed, as indicated in UnhealthyState. + Warnable *Warnable + // UnhealthyState is set if the changed Warnable is now unhealthy, or nil + // if Warnable is now healthy. + UnhealthyState *UnhealthyState } // SetRouterHealth sets the state of the wgengine/router.Router. @@ -510,22 +578,12 @@ func (t *Tracker) SetDNSHealth(err error) { t.setErr(SysDNS, err) } // Deprecated: Warnables should be preferred over Subsystem errors. func (t *Tracker) DNSHealth() error { return t.get(SysDNS) } -// SetDNSOSHealth sets the state of the net/dns.OSConfigurator -// -// Deprecated: Warnables should be preferred over Subsystem errors. -func (t *Tracker) SetDNSOSHealth(err error) { t.setErr(SysDNSOS, err) } - // SetDNSManagerHealth sets the state of the Linux net/dns manager's // discovery of the /etc/resolv.conf situation. // // Deprecated: Warnables should be preferred over Subsystem errors. func (t *Tracker) SetDNSManagerHealth(err error) { t.setErr(SysDNSManager, err) } -// DNSOSHealth returns the net/dns.OSConfigurator error state. -// -// Deprecated: Warnables should be preferred over Subsystem errors. -func (t *Tracker) DNSOSHealth() error { return t.get(SysDNSOS) } - // SetTKAHealth sets the health of the tailnet key authority. // // Deprecated: Warnables should be preferred over Subsystem errors. @@ -630,13 +688,15 @@ func (t *Tracker) updateLegacyErrorWarnableLocked(key Subsystem, err error) { } } -func (t *Tracker) SetControlHealth(problems []string) { +func (t *Tracker) SetControlHealth(problems map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage) { if t.nil() { return } t.mu.Lock() defer t.mu.Unlock() - t.controlHealth = problems + + t.controlMessages = problems + t.selfCheckLocked() } @@ -651,10 +711,10 @@ func (t *Tracker) GotStreamedMapResponse() { } t.mu.Lock() defer t.mu.Unlock() - t.lastStreamedMapResponse = time.Now() + t.lastStreamedMapResponse = t.now() if !t.inMapPoll { t.inMapPoll = true - t.inMapPollSince = time.Now() + t.inMapPollSince = t.now() } t.selfCheckLocked() } @@ -671,7 +731,7 @@ func (t *Tracker) SetOutOfPollNetMap() { return } t.inMapPoll = false - t.lastMapPollEndedAt = time.Now() + t.lastMapPollEndedAt = t.now() t.selfCheckLocked() } @@ -713,7 +773,7 @@ func (t *Tracker) NoteMapRequestHeard(mr *tailcfg.MapRequest) { // against SetMagicSockDERPHome and // SetDERPRegionConnectedState - t.lastMapRequestHeard = time.Now() + t.lastMapRequestHeard = t.now() t.selfCheckLocked() } @@ -751,7 +811,7 @@ func (t *Tracker) NoteDERPRegionReceivedFrame(region int) { } t.mu.Lock() defer t.mu.Unlock() - mak.Set(&t.derpRegionLastFrame, region, time.Now()) + mak.Set(&t.derpRegionLastFrame, region, t.now()) t.selfCheckLocked() } @@ -810,9 +870,9 @@ func (t *Tracker) SetIPNState(state string, wantRunning bool) { // The first time we see wantRunning=true and it used to be false, it means the user requested // the backend to start. We store this timestamp and use it to silence some warnings that are // expected during startup. - t.ipnWantRunningLastTrue = time.Now() + t.ipnWantRunningLastTrue = t.now() t.setUnhealthyLocked(warmingUpWarnable, nil) - time.AfterFunc(warmingUpWarnableDuration, func() { + t.clock().AfterFunc(warmingUpWarnableDuration, func() { t.mu.Lock() t.updateWarmingUpWarnableLocked() t.mu.Unlock() @@ -920,8 +980,8 @@ func (t *Tracker) selfCheckLocked() { // OverallError returns a summary of the health state. // -// If there are multiple problems, the error will be of type -// multierr.Error. +// If there are multiple problems, the error will be joined using +// [errors.Join]. func (t *Tracker) OverallError() error { if t.nil() { return nil @@ -932,13 +992,13 @@ func (t *Tracker) OverallError() error { return t.multiErrLocked() } -// Strings() returns a string array containing the Text of all Warnings -// currently known to the Tracker. These strings can be presented to the -// user, although ideally you would use the Code property on each Warning -// to show a localized version of them instead. -// This function is here for legacy compatibility purposes and is deprecated. +// Strings() returns a string array containing the Text of all Warnings and +// ControlHealth messages currently known to the Tracker. These strings can be +// presented to the user, although ideally you would use the Code property on +// each Warning to show a localized version of them instead. This function is +// here for legacy compatibility purposes and is deprecated. func (t *Tracker) Strings() []string { - if t.nil() { + if !buildfeatures.HasHealth || t.nil() { return nil } t.mu.Lock() @@ -947,18 +1007,42 @@ func (t *Tracker) Strings() []string { } func (t *Tracker) stringsLocked() []string { + if !buildfeatures.HasHealth { + return nil + } result := []string{} for w, ws := range t.warnableVal { - if !w.IsVisible(ws) { + if !w.IsVisible(ws, t.now) { // Do not append invisible warnings. continue } + if t.isEffectivelyHealthyLocked(w) { + continue + } if ws.Args == nil { result = append(result, w.Text(Args{})) } else { result = append(result, w.Text(ws.Args)) } } + + warnLen := len(result) + for _, c := range t.controlMessages { + var msg string + if c.Title != "" && c.Text != "" { + msg = c.Title + ": " + c.Text + } else if c.Title != "" { + msg = c.Title + "." + } else if c.Text != "" { + msg = c.Text + } + if c.PrimaryAction != nil { + msg = msg + " " + c.PrimaryAction.Label + ": " + c.PrimaryAction.URL + } + result = append(result, msg) + } + sort.Strings(result[warnLen:]) + return result } @@ -978,7 +1062,7 @@ func (t *Tracker) errorsLocked() []error { // This function is here for legacy compatibility purposes and is deprecated. func (t *Tracker) multiErrLocked() error { errs := t.errorsLocked() - return multierr.New(errs...) + return errors.Join(errs...) } var fakeErrForTesting = envknob.RegisterString("TS_DEBUG_FAKE_HEALTH_ERROR") @@ -986,6 +1070,9 @@ var fakeErrForTesting = envknob.RegisterString("TS_DEBUG_FAKE_HEALTH_ERROR") // updateBuiltinWarnablesLocked performs a number of checks on the state of the backend, // and adds/removes Warnings from the Tracker as needed. func (t *Tracker) updateBuiltinWarnablesLocked() { + if !buildfeatures.HasHealth { + return + } t.updateWarmingUpWarnableLocked() if w, show := t.showUpdateWarnable(); show { @@ -1018,7 +1105,7 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { t.setHealthyLocked(localLogWarnable) } - now := time.Now() + now := t.now() // How long we assume we'll have heard a DERP frame or a MapResponse // KeepAlive by. @@ -1028,8 +1115,10 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { recentlyOn := now.Sub(t.ipnWantRunningLastTrue) < 5*time.Second homeDERP := t.derpHomeRegion - if recentlyOn { + if recentlyOn || !t.inMapPoll { // If user just turned Tailscale on, don't warn for a bit. + // Also, if we're not in a map poll, that means we don't yet + // have a DERPMap or aren't in a state where we even want t.setHealthyLocked(noDERPHomeWarnable) t.setHealthyLocked(noDERPConnectionWarnable) t.setHealthyLocked(derpTimeoutWarnable) @@ -1051,11 +1140,15 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { ArgDuration: d.Round(time.Second).String(), }) } - } else { + } else if homeDERP != 0 { t.setUnhealthyLocked(noDERPConnectionWarnable, Args{ ArgDERPRegionID: fmt.Sprint(homeDERP), ArgDERPRegionName: t.derpRegionNameLocked(homeDERP), }) + } else { + // No DERP home yet determined yet. There's probably some + // other problem or things are just starting up. + t.setHealthyLocked(noDERPConnectionWarnable) } if !t.ipnWantRunning { @@ -1133,14 +1226,10 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { t.setHealthyLocked(derpRegionErrorWarnable) } - if len(t.controlHealth) > 0 { - for _, s := range t.controlHealth { - t.setUnhealthyLocked(controlHealthWarnable, Args{ - ArgError: s, - }) - } - } else { - t.setHealthyLocked(controlHealthWarnable) + // Check if control health messages have changed + if !maps.EqualFunc(t.lastNotifiedControlMessages, t.controlMessages, tailcfg.DisplayMessage.Equal) { + t.lastNotifiedControlMessages = t.controlMessages + t.notifyWatchersControlChangedLocked() } if err := envknob.ApplyDiskConfigError(); err != nil { @@ -1174,7 +1263,7 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { // updateWarmingUpWarnableLocked ensures the warmingUpWarnable is healthy if wantRunning has been set to true // for more than warmingUpWarnableDuration. func (t *Tracker) updateWarmingUpWarnableLocked() { - if !t.ipnWantRunningLastTrue.IsZero() && time.Now().After(t.ipnWantRunningLastTrue.Add(warmingUpWarnableDuration)) { + if !t.ipnWantRunningLastTrue.IsZero() && t.now().After(t.ipnWantRunningLastTrue.Add(warmingUpWarnableDuration)) { t.setHealthyLocked(warmingUpWarnable) } } @@ -1222,11 +1311,17 @@ func (s *ReceiveFuncStats) Name() string { } func (s *ReceiveFuncStats) Enter() { + if !buildfeatures.HasHealth { + return + } s.numCalls.Add(1) s.inCall.Store(true) } func (s *ReceiveFuncStats) Exit() { + if !buildfeatures.HasHealth { + return + } s.inCall.Store(false) } @@ -1235,7 +1330,7 @@ func (s *ReceiveFuncStats) Exit() { // // If t is nil, it returns nil. func (t *Tracker) ReceiveFuncStats(which ReceiveFunc) *ReceiveFuncStats { - if t == nil { + if !buildfeatures.HasHealth || t == nil { return nil } t.initOnce.Do(t.doOnceInit) @@ -1243,6 +1338,9 @@ func (t *Tracker) ReceiveFuncStats(which ReceiveFunc) *ReceiveFuncStats { } func (t *Tracker) doOnceInit() { + if !buildfeatures.HasHealth { + return + } for i := range t.MagicSockReceiveFuncs { f := &t.MagicSockReceiveFuncs[i] f.name = (ReceiveFunc(i)).String() @@ -1286,13 +1384,8 @@ func (t *Tracker) LastNoiseDialWasRecent() bool { t.mu.Lock() defer t.mu.Unlock() - now := time.Now() + now := t.now() dur := now.Sub(t.lastNoiseDial) t.lastNoiseDial = now return dur < 2*time.Minute } - -type metricHealthMessageLabel struct { - // TODO: break down by warnable.severity as well? - Type string -} diff --git a/health/health_test.go b/health/health_test.go index 8107c1cf0..af7d06c8f 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -4,23 +4,61 @@ package health import ( + "errors" + "flag" "fmt" + "maps" "reflect" "slices" + "strconv" "testing" + "testing/synctest" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/metrics" "tailscale.com/tailcfg" + "tailscale.com/tsconst" + "tailscale.com/tstest" + "tailscale.com/tstime" "tailscale.com/types/opt" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/usermetric" + "tailscale.com/version" ) +var doDebug = flag.Bool("debug", false, "Enable debug logging") + +func wantChange(c Change) func(c Change) (bool, error) { + return func(cEv Change) (bool, error) { + if cEv.ControlHealthChanged != c.ControlHealthChanged { + return false, fmt.Errorf("expected ControlHealthChanged %t, got %t", c.ControlHealthChanged, cEv.ControlHealthChanged) + } + if cEv.WarnableChanged != c.WarnableChanged { + return false, fmt.Errorf("expected WarnableChanged %t, got %t", c.WarnableChanged, cEv.WarnableChanged) + } + if c.Warnable != nil && (cEv.Warnable == nil || cEv.Warnable != c.Warnable) { + return false, fmt.Errorf("expected Warnable %+v, got %+v", c.Warnable, cEv.Warnable) + } + + if c.UnhealthyState != nil { + panic("comparison of UnhealthyState is not yet supported") + } + + return true, nil + } +} + func TestAppendWarnableDebugFlags(t *testing.T) { - var tr Tracker + tr := NewTracker(eventbustest.NewBus(t)) for i := range 10 { w := Register(&Warnable{ Code: WarnableCode(fmt.Sprintf("warnable-code-%d", i)), MapDebugFlag: fmt.Sprint(i), + Text: StaticMessage(""), }) defer unregister(w) if i%2 == 0 { @@ -59,7 +97,9 @@ func TestNilMethodsDontCrash(t *testing.T) { } func TestSetUnhealthyWithDuplicateThenHealthyAgain(t *testing.T) { - ht := Tracker{} + bus := eventbustest.NewBus(t) + watcher := eventbustest.NewWatcher(t, bus) + ht := NewTracker(bus) if len(ht.Strings()) != 0 { t.Fatalf("before first insertion, len(newTracker.Strings) = %d; want = 0", len(ht.Strings())) } @@ -83,10 +123,20 @@ func TestSetUnhealthyWithDuplicateThenHealthyAgain(t *testing.T) { if !reflect.DeepEqual(ht.Strings(), want) { t.Fatalf("after setting the healthy, newTracker.Strings() = %v; want = %v", ht.Strings(), want) } + + if err := eventbustest.ExpectExactly(watcher, + wantChange(Change{WarnableChanged: true, Warnable: testWarnable}), + wantChange(Change{WarnableChanged: true, Warnable: testWarnable}), + wantChange(Change{WarnableChanged: true, Warnable: testWarnable}), + ); err != nil { + t.Fatalf("expected events, got %q", err) + } } func TestRemoveAllWarnings(t *testing.T) { - ht := Tracker{} + bus := eventbustest.NewBus(t) + watcher := eventbustest.NewWatcher(t, bus) + ht := NewTracker(bus) if len(ht.Strings()) != 0 { t.Fatalf("before first insertion, len(newTracker.Strings) = %d; want = 0", len(ht.Strings())) } @@ -100,65 +150,96 @@ func TestRemoveAllWarnings(t *testing.T) { if len(ht.Strings()) != 0 { t.Fatalf("after RemoveAll, len(newTracker.Strings) = %d; want = 0", len(ht.Strings())) } + if err := eventbustest.ExpectExactly(watcher, + wantChange(Change{WarnableChanged: true, Warnable: testWarnable}), + wantChange(Change{WarnableChanged: true, Warnable: testWarnable}), + ); err != nil { + t.Fatalf("expected events, got %q", err) + } } // TestWatcher tests that a registered watcher function gets called with the correct // Warnable and non-nil/nil UnhealthyState upon setting a Warnable to unhealthy/healthy. func TestWatcher(t *testing.T) { - ht := Tracker{} - wantText := "Hello world" - becameUnhealthy := make(chan struct{}) - becameHealthy := make(chan struct{}) - - watcherFunc := func(w *Warnable, us *UnhealthyState) { - if w != testWarnable { - t.Fatalf("watcherFunc was called, but with an unexpected Warnable: %v, want: %v", w, testWarnable) - } + tests := []struct { + name string + preFunc func(t *testing.T, ht *Tracker, bus *eventbus.Bus, fn func(Change)) + }{ + { + name: "with-eventbus", + preFunc: func(_ *testing.T, _ *Tracker, bus *eventbus.Bus, fn func(c Change)) { + client := bus.Client("healthwatchertestclient") + sub := eventbus.Subscribe[Change](client) + go func() { + for { + select { + case <-sub.Done(): + return + case change := <-sub.Events(): + fn(change) + } + } + }() + }, + }, + } - if us != nil { - if us.Text != wantText { - t.Fatalf("unexpected us.Text: %s, want: %s", us.Text, wantText) - } - if us.Args[ArgError] != wantText { - t.Fatalf("unexpected us.Args[ArgError]: %s, want: %s", us.Args[ArgError], wantText) + for _, tt := range tests { + t.Run(tt.name, func(*testing.T) { + bus := eventbustest.NewBus(t) + ht := NewTracker(bus) + wantText := "Hello world" + becameUnhealthy := make(chan struct{}) + becameHealthy := make(chan struct{}) + + watcherFunc := func(c Change) { + w := c.Warnable + us := c.UnhealthyState + if w != testWarnable { + t.Fatalf("watcherFunc was called, but with an unexpected Warnable: %v, want: %v", w, testWarnable) + } + + if us != nil { + if us.Text != wantText { + t.Fatalf("unexpected us.Text: %q, want: %s", us.Text, wantText) + } + if us.Args[ArgError] != wantText { + t.Fatalf("unexpected us.Args[ArgError]: %q, want: %s", us.Args[ArgError], wantText) + } + becameUnhealthy <- struct{}{} + } else { + becameHealthy <- struct{}{} + } } - becameUnhealthy <- struct{}{} - } else { - becameHealthy <- struct{}{} - } - } - unregisterFunc := ht.RegisterWatcher(watcherFunc) - if len(ht.watchers) != 1 { - t.Fatalf("after RegisterWatcher, len(newTracker.watchers) = %d; want = 1", len(ht.watchers)) - } - ht.SetUnhealthy(testWarnable, Args{ArgError: wantText}) + // Set up test + tt.preFunc(t, ht, bus, watcherFunc) - select { - case <-becameUnhealthy: - // Test passed because the watcher got notified of an unhealthy state - case <-becameHealthy: - // Test failed because the watcher got of a healthy state instead of an unhealthy one - t.Fatalf("watcherFunc was called with a healthy state") - case <-time.After(1 * time.Second): - t.Fatalf("watcherFunc didn't get called upon calling SetUnhealthy") - } + // Start running actual test + ht.SetUnhealthy(testWarnable, Args{ArgError: wantText}) - ht.SetHealthy(testWarnable) + select { + case <-becameUnhealthy: + // Test passed because the watcher got notified of an unhealthy state + case <-becameHealthy: + // Test failed because the watcher got of a healthy state instead of an unhealthy one + t.Fatalf("watcherFunc was called with a healthy state") + case <-time.After(5 * time.Second): + t.Fatalf("watcherFunc didn't get called upon calling SetUnhealthy") + } - select { - case <-becameUnhealthy: - // Test failed because the watcher got of an unhealthy state instead of a healthy one - t.Fatalf("watcherFunc was called with an unhealthy state") - case <-becameHealthy: - // Test passed because the watcher got notified of a healthy state - case <-time.After(1 * time.Second): - t.Fatalf("watcherFunc didn't get called upon calling SetUnhealthy") - } + ht.SetHealthy(testWarnable) - unregisterFunc() - if len(ht.watchers) != 0 { - t.Fatalf("after unregisterFunc, len(newTracker.watchers) = %d; want = 0", len(ht.watchers)) + select { + case <-becameUnhealthy: + // Test failed because the watcher got of an unhealthy state instead of a healthy one + t.Fatalf("watcherFunc was called with an unhealthy state") + case <-becameHealthy: + // Test passed because the watcher got notified of a healthy state + case <-time.After(5 * time.Second): + t.Fatalf("watcherFunc didn't get called upon calling SetUnhealthy") + } + }) } } @@ -167,43 +248,72 @@ func TestWatcher(t *testing.T) { // has a TimeToVisible set, which means that a watcher should only be notified of an unhealthy state after // the TimeToVisible duration has passed. func TestSetUnhealthyWithTimeToVisible(t *testing.T) { - ht := Tracker{} - mw := Register(&Warnable{ - Code: "test-warnable-3-secs-to-visible", - Title: "Test Warnable with 3 seconds to visible", - Text: StaticMessage("Hello world"), - TimeToVisible: 2 * time.Second, - ImpactsConnectivity: true, - }) - defer unregister(mw) - - becameUnhealthy := make(chan struct{}) - becameHealthy := make(chan struct{}) - - watchFunc := func(w *Warnable, us *UnhealthyState) { - if w != mw { - t.Fatalf("watcherFunc was called, but with an unexpected Warnable: %v, want: %v", w, w) - } - - if us != nil { - becameUnhealthy <- struct{}{} - } else { - becameHealthy <- struct{}{} - } + tests := []struct { + name string + preFunc func(t *testing.T, ht *Tracker, bus *eventbus.Bus, fn func(Change)) + }{ + { + name: "with-eventbus", + preFunc: func(_ *testing.T, _ *Tracker, bus *eventbus.Bus, fn func(c Change)) { + client := bus.Client("healthwatchertestclient") + sub := eventbus.Subscribe[Change](client) + go func() { + for { + select { + case <-sub.Done(): + return + case change := <-sub.Events(): + fn(change) + } + } + }() + }, + }, } + for _, tt := range tests { + t.Run(tt.name, func(*testing.T) { + bus := eventbustest.NewBus(t) + ht := NewTracker(bus) + mw := Register(&Warnable{ + Code: "test-warnable-3-secs-to-visible", + Title: "Test Warnable with 3 seconds to visible", + Text: StaticMessage("Hello world"), + TimeToVisible: 2 * time.Second, + ImpactsConnectivity: true, + }) + + becameUnhealthy := make(chan struct{}) + becameHealthy := make(chan struct{}) + + watchFunc := func(c Change) { + w := c.Warnable + us := c.UnhealthyState + if w != mw { + t.Fatalf("watcherFunc was called, but with an unexpected Warnable: %v, want: %v", w, w) + } + + if us != nil { + becameUnhealthy <- struct{}{} + } else { + becameHealthy <- struct{}{} + } + } - ht.RegisterWatcher(watchFunc) - ht.SetUnhealthy(mw, Args{ArgError: "Hello world"}) - - select { - case <-becameUnhealthy: - // Test failed because the watcher got notified of an unhealthy state - t.Fatalf("watcherFunc was called with an unhealthy state") - case <-becameHealthy: - // Test failed because the watcher got of a healthy state - t.Fatalf("watcherFunc was called with a healthy state") - case <-time.After(1 * time.Second): - // As expected, watcherFunc still had not been called after 1 second + tt.preFunc(t, ht, bus, watchFunc) + ht.SetUnhealthy(mw, Args{ArgError: "Hello world"}) + + select { + case <-becameUnhealthy: + // Test failed because the watcher got notified of an unhealthy state + t.Fatalf("watcherFunc was called with an unhealthy state") + case <-becameHealthy: + // Test failed because the watcher got of a healthy state + t.Fatalf("watcherFunc was called with a healthy state") + case <-time.After(1 * time.Second): + // As expected, watcherFunc still had not been called after 1 second + } + unregister(mw) + }) } } @@ -229,7 +339,7 @@ func TestRegisterWarnablePanicsWithDuplicate(t *testing.T) { // TestCheckDependsOnAppearsInUnhealthyState asserts that the DependsOn field in the UnhealthyState // is populated with the WarnableCode(s) of the Warnable(s) that a warning depends on. func TestCheckDependsOnAppearsInUnhealthyState(t *testing.T) { - ht := Tracker{} + ht := NewTracker(eventbustest.NewBus(t)) w1 := Register(&Warnable{ Code: "w1", Text: StaticMessage("W1 Text"), @@ -254,9 +364,15 @@ func TestCheckDependsOnAppearsInUnhealthyState(t *testing.T) { } ht.SetUnhealthy(w2, Args{ArgError: "w2 is also unhealthy now"}) us2, ok := ht.CurrentState().Warnings[w2.Code] + if ok { + t.Fatalf("Saw w2 being unhealthy but it shouldn't be, as it depends on unhealthy w1") + } + ht.SetHealthy(w1) + us2, ok = ht.CurrentState().Warnings[w2.Code] if !ok { - t.Fatalf("Expected an UnhealthyState for w2, got nothing") + t.Fatalf("w2 wasn't unhealthy; want it to be unhealthy now that w1 is back healthy") } + wantDependsOn = slices.Concat([]WarnableCode{w1.Code}, wantDependsOn) if !reflect.DeepEqual(us2.DependsOn, wantDependsOn) { t.Fatalf("Expected DependsOn = %v in the unhealthy state, got: %v", wantDependsOn, us2.DependsOn) @@ -273,7 +389,7 @@ func TestShowUpdateWarnable(t *testing.T) { wantShow bool }{ { - desc: "nil CientVersion", + desc: "nil ClientVersion", check: true, cv: nil, wantWarnable: nil, @@ -333,11 +449,11 @@ func TestShowUpdateWarnable(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - tr := &Tracker{ - checkForUpdates: tt.check, - applyUpdates: tt.apply, - latestVersion: tt.cv, - } + tr := NewTracker(eventbustest.NewBus(t)) + tr.checkForUpdates = tt.check + tr.applyUpdates = tt.apply + tr.latestVersion = tt.cv + gotWarnable, gotShow := tr.showUpdateWarnable() if gotWarnable != tt.wantWarnable { t.Errorf("got warnable: %v, want: %v", gotWarnable, tt.wantWarnable) @@ -348,3 +464,539 @@ func TestShowUpdateWarnable(t *testing.T) { }) } } + +func TestHealthMetric(t *testing.T) { + unstableBuildWarning := 0 + if version.IsUnstableBuild() { + unstableBuildWarning = 1 + } + + tests := []struct { + desc string + check bool + apply opt.Bool + cv *tailcfg.ClientVersion + wantMetricCount int + }{ + // When running in dev, and not initialising the client, there will be two warnings + // by default: + // - is-using-unstable-version (except on the release branch) + // - wantrunning-false + { + desc: "base-warnings", + check: true, + cv: nil, + wantMetricCount: unstableBuildWarning + 1, + }, + // with: update-available + { + desc: "update-warning", + check: true, + cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, + wantMetricCount: unstableBuildWarning + 2, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + tr := NewTracker(eventbustest.NewBus(t)) + tr.checkForUpdates = tt.check + tr.applyUpdates = tt.apply + tr.latestVersion = tt.cv + tr.SetMetricsRegistry(&usermetric.Registry{}) + m, ok := tr.metricHealthMessage.(*metrics.MultiLabelMap[metricHealthMessageLabel]) + if !ok { + t.Fatal("metricHealthMessage has wrong type or is nil") + } + if val := m.Get(metricHealthMessageLabel{Type: MetricLabelWarning}).String(); val != strconv.Itoa(tt.wantMetricCount) { + t.Fatalf("metric value: %q, want: %q", val, strconv.Itoa(tt.wantMetricCount)) + } + for _, w := range tr.CurrentState().Warnings { + t.Logf("warning: %v", w) + } + }) + } +} + +// TestNoDERPHomeWarnable checks that we don't +// complain about no DERP home if we're not in a +// map poll. +func TestNoDERPHomeWarnable(t *testing.T) { + t.Skip("TODO: fix https://github.com/tailscale/tailscale/issues/14798 to make this test not deadlock") + clock := tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(123, 0), + FollowRealTime: false, + }) + ht := NewTracker(eventbustest.NewBus(t)) + ht.testClock = clock + ht.SetIPNState("NeedsLogin", true) + + // Advance 30 seconds to get past the "recentlyLoggedIn" check. + clock.Advance(30 * time.Second) + ht.updateBuiltinWarnablesLocked() + + // Advance to get past the the TimeToVisible delay. + clock.Advance(noDERPHomeWarnable.TimeToVisible * 2) + + ht.updateBuiltinWarnablesLocked() + if ws, ok := ht.CurrentState().Warnings[noDERPHomeWarnable.Code]; ok { + t.Fatalf("got unexpected noDERPHomeWarnable warnable: %v", ws) + } +} + +// TestNoDERPHomeWarnableManual is like TestNoDERPHomeWarnable +// but doesn't use tstest.Clock so avoids the deadlock +// I hit: https://github.com/tailscale/tailscale/issues/14798 +func TestNoDERPHomeWarnableManual(t *testing.T) { + ht := NewTracker(eventbustest.NewBus(t)) + ht.SetIPNState("NeedsLogin", true) + + // Avoid wantRunning: + ht.ipnWantRunningLastTrue = ht.ipnWantRunningLastTrue.Add(-10 * time.Second) + ht.updateBuiltinWarnablesLocked() + + ws, ok := ht.warnableVal[noDERPHomeWarnable] + if ok { + t.Fatalf("got unexpected noDERPHomeWarnable warnable: %v", ws) + } +} + +func TestControlHealth(t *testing.T) { + ht := NewTracker(eventbustest.NewBus(t)) + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + baseWarns := ht.CurrentState().Warnings + baseStrs := ht.Strings() + + msgs := map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": { + Title: "Control health message", + Text: "Extra help.", + }, + "title": { + Title: "Control health title only", + }, + "with-action": { + Title: "Control health message", + Text: "Extra help.", + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "http://www.example.com", + Label: "Learn more", + }, + }, + } + ht.SetControlHealth(msgs) + + t.Run("Warnings", func(t *testing.T) { + wantWarns := map[WarnableCode]UnhealthyState{ + "control-health.test": { + WarnableCode: "control-health.test", + Severity: SeverityMedium, + Title: "Control health message", + Text: "Extra help.", + }, + "control-health.title": { + WarnableCode: "control-health.title", + Severity: SeverityMedium, + Title: "Control health title only", + }, + "control-health.with-action": { + WarnableCode: "control-health.with-action", + Severity: SeverityMedium, + Title: "Control health message", + Text: "Extra help.", + PrimaryAction: &UnhealthyStateAction{ + URL: "http://www.example.com", + Label: "Learn more", + }, + }, + } + state := ht.CurrentState() + gotWarns := maps.Clone(state.Warnings) + for k := range gotWarns { + if _, inBase := baseWarns[k]; inBase { + delete(gotWarns, k) + } + } + if diff := cmp.Diff(wantWarns, gotWarns, cmpopts.IgnoreFields(UnhealthyState{}, "ETag")); diff != "" { + t.Fatalf(`CurrentState().Warnings["control-health-*"] wrong (-want +got):\n%s`, diff) + } + }) + + t.Run("Strings()", func(t *testing.T) { + wantStrs := []string{ + "Control health message: Extra help.", + "Control health message: Extra help. Learn more: http://www.example.com", + "Control health title only.", + } + var gotStrs []string + for _, s := range ht.Strings() { + if !slices.Contains(baseStrs, s) { + gotStrs = append(gotStrs, s) + } + } + if diff := cmp.Diff(wantStrs, gotStrs); diff != "" { + t.Fatalf(`Strings() wrong (-want +got):\n%s`, diff) + } + }) + + t.Run("tailscaled_health_messages", func(t *testing.T) { + var r usermetric.Registry + ht.SetMetricsRegistry(&r) + + m, ok := ht.metricHealthMessage.(*metrics.MultiLabelMap[metricHealthMessageLabel]) + if !ok { + t.Fatal("metricHealthMessage has wrong type or is nil") + } + got := m.Get(metricHealthMessageLabel{ + Type: MetricLabelWarning, + }).String() + want := strconv.Itoa( + len(msgs) + len(baseStrs), + ) + if got != want { + t.Errorf("metricsHealthMessage.Get(warning) = %q, want %q", got, want) + } + }) +} + +func TestControlHealthNotifies(t *testing.T) { + type test struct { + name string + initialState map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage + newState map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage + wantEvents []any + } + tests := []test{ + { + name: "no-change", + initialState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": {}, + }, + newState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": {}, + }, + wantEvents: []any{}, + }, + { + name: "on-set", + initialState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{}, + newState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": {}, + }, + wantEvents: []any{ + eventbustest.Type[Change](), + }, + }, + { + name: "details-change", + initialState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": { + Title: "Title", + }, + }, + newState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": { + Title: "Updated title", + }, + }, + wantEvents: []any{ + eventbustest.Type[Change](), + }, + }, + { + name: "action-changes", + initialState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": { + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "http://www.example.com/a/123456", + Label: "Sign in", + }, + }, + }, + newState: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test": { + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "http://www.example.com/a/abcdefg", + Label: "Sign in", + }, + }, + }, + wantEvents: []any{ + eventbustest.Type[Change](), + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + bus := eventbustest.NewBus(t) + if *doDebug { + eventbustest.LogAllEvents(t, bus) + } + tw := eventbustest.NewWatcher(t, bus) + + ht := NewTracker(bus) + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + // Expect events at starup, before doing anything else, skip unstable + // event and no warning event as they show up at different times. + synctest.Wait() + if err := eventbustest.Expect(tw, + CompareWarnableCode(t, tsconst.HealthWarnableWarmingUp), + CompareWarnableCode(t, tsconst.HealthWarnableNotInMapPoll), + CompareWarnableCode(t, tsconst.HealthWarnableWarmingUp), + ); err != nil { + t.Errorf("startup error: %v", err) + } + + // Only set initial state if we need to + if len(test.initialState) != 0 { + t.Log("Setting initial state") + ht.SetControlHealth(test.initialState) + synctest.Wait() + if err := eventbustest.Expect(tw, + CompareWarnableCode(t, tsconst.HealthWarnableMagicsockReceiveFuncError), + // Skip event with no warnable + CompareWarnableCode(t, tsconst.HealthWarnableNoDERPHome), + ); err != nil { + t.Errorf("initial state error: %v", err) + } + } + + ht.SetControlHealth(test.newState) + // Close the bus early to avoid timers triggering more events. + bus.Close() + + synctest.Wait() + if err := eventbustest.ExpectExactly(tw, test.wantEvents...); err != nil { + t.Errorf("event error: %v", err) + } + }) + }) + } +} + +func CompareWarnableCode(t *testing.T, code string) func(Change) bool { + t.Helper() + return func(c Change) bool { + t.Helper() + if c.Warnable != nil { + t.Logf("Warnable code: %s", c.Warnable.Code) + if string(c.Warnable.Code) == code { + return true + } + } else { + t.Log("No Warnable") + } + return false + } +} + +func TestControlHealthIgnoredOutsideMapPoll(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + ht := NewTracker(bus) + ht.SetIPNState("NeedsLogin", true) + + ht.SetControlHealth(map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "control-health": {}, + }) + + state := ht.CurrentState() + _, ok := state.Warnings["control-health"] + + if ok { + t.Error("got a warning with code 'control-health', want none") + } + + // An event is emitted when SetIPNState is run above, + // so only fail on the second event. + eventCounter := 0 + expectOne := func(c *Change) error { + eventCounter++ + if eventCounter == 1 { + return nil + } + return errors.New("saw more than 1 event") + } + + synctest.Wait() + if err := eventbustest.Expect(tw, expectOne); err == nil { + t.Error("event got emitted, want it to not be called") + } + }) +} + +// TestCurrentStateETagControlHealth tests that the ETag on an [UnhealthyState] +// created from Control health & returned by [Tracker.CurrentState] is different +// when the details of the [tailcfg.DisplayMessage] are different. +func TestCurrentStateETagControlHealth(t *testing.T) { + ht := NewTracker(eventbustest.NewBus(t)) + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + msg := tailcfg.DisplayMessage{ + Title: "Test Warning", + Text: "This is a test warning.", + Severity: tailcfg.SeverityHigh, + ImpactsConnectivity: true, + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "https://example.com/", + Label: "open", + }, + } + + type test struct { + name string + change func(tailcfg.DisplayMessage) tailcfg.DisplayMessage + wantChangedETag bool + } + tests := []test{ + { + name: "same_value", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { return m }, + wantChangedETag: false, + }, + { + name: "different_severity", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { + m.Severity = tailcfg.SeverityLow + return m + }, + wantChangedETag: true, + }, + { + name: "different_title", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { + m.Title = "Different Title" + return m + }, + wantChangedETag: true, + }, + { + name: "different_text", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { + m.Text = "This is a different text." + return m + }, + wantChangedETag: true, + }, + { + name: "different_impacts_connectivity", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { + m.ImpactsConnectivity = false + return m + }, + wantChangedETag: true, + }, + { + name: "different_primary_action_label", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { + m.PrimaryAction.Label = "new_label" + return m + }, + wantChangedETag: true, + }, + { + name: "different_primary_action_url", + change: func(m tailcfg.DisplayMessage) tailcfg.DisplayMessage { + m.PrimaryAction.URL = "https://new.example.com/" + return m + }, + wantChangedETag: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ht.SetControlHealth(map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test-message": msg, + }) + state := ht.CurrentState().Warnings["control-health.test-message"] + + newMsg := test.change(msg) + ht.SetControlHealth(map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test-message": newMsg, + }) + newState := ht.CurrentState().Warnings["control-health.test-message"] + + if (state.ETag != newState.ETag) != test.wantChangedETag { + if test.wantChangedETag { + t.Errorf("got unchanged ETag, want changed (ETag was %q)", newState.ETag) + } else { + t.Errorf("got changed ETag, want unchanged") + } + } + }) + } +} + +// TestCurrentStateETagWarnable tests that the ETag on an [UnhealthyState] +// created from a Warnable & returned by [Tracker.CurrentState] is different +// when the details of the Warnable are different. +func TestCurrentStateETagWarnable(t *testing.T) { + newTracker := func(clock tstime.Clock) *Tracker { + ht := NewTracker(eventbustest.NewBus(t)) + ht.testClock = clock + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + return ht + } + + t.Run("new_args", func(t *testing.T) { + ht := newTracker(nil) + + ht.SetUnhealthy(testWarnable, Args{ArgError: "initial value"}) + state := ht.CurrentState().Warnings[testWarnable.Code] + + ht.SetUnhealthy(testWarnable, Args{ArgError: "new value"}) + newState := ht.CurrentState().Warnings[testWarnable.Code] + + if state.ETag == newState.ETag { + t.Errorf("got unchanged ETag, want changed (ETag was %q)", newState.ETag) + } + }) + + t.Run("new_broken_since", func(t *testing.T) { + clock1 := tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(123, 0), + }) + ht1 := newTracker(clock1) + + ht1.SetUnhealthy(testWarnable, Args{}) + state := ht1.CurrentState().Warnings[testWarnable.Code] + + // Use a second tracker to get a different broken since time + clock2 := tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(456, 0), + }) + ht2 := newTracker(clock2) + + ht2.SetUnhealthy(testWarnable, Args{}) + newState := ht2.CurrentState().Warnings[testWarnable.Code] + + if state.ETag == newState.ETag { + t.Errorf("got unchanged ETag, want changed (ETag was %q)", newState.ETag) + } + }) + + t.Run("no_change", func(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{}) + ht1 := newTracker(clock) + + ht1.SetUnhealthy(testWarnable, Args{}) + state := ht1.CurrentState().Warnings[testWarnable.Code] + + // Using a second tracker because SetUnhealthy with no changes is a no-op + ht2 := newTracker(clock) + ht2.SetUnhealthy(testWarnable, Args{}) + newState := ht2.CurrentState().Warnings[testWarnable.Code] + + if state.ETag != newState.ETag { + t.Errorf("got changed ETag, want unchanged") + } + }) +} diff --git a/health/healthmsg/healthmsg.go b/health/healthmsg/healthmsg.go index 6c237678e..5ea1c736d 100644 --- a/health/healthmsg/healthmsg.go +++ b/health/healthmsg/healthmsg.go @@ -8,8 +8,10 @@ package healthmsg const ( - WarnAcceptRoutesOff = "Some peers are advertising routes but --accept-routes is false" - TailscaleSSHOnBut = "Tailscale SSH enabled, but " // + ... something from caller - LockedOut = "this node is locked out; it will not have connectivity until it is signed. For more info, see https://tailscale.com/s/locked-out" - WarnExitNodeUsage = "The following issues on your machine will likely make usage of exit nodes impossible" + WarnAcceptRoutesOff = "Some peers are advertising routes but --accept-routes is false" + TailscaleSSHOnBut = "Tailscale SSH enabled, but " // + ... something from caller + LockedOut = "this node is locked out; it will not have connectivity until it is signed. For more info, see https://tailscale.com/s/locked-out" + WarnExitNodeUsage = "The following issues on your machine will likely make usage of exit nodes impossible" + DisableRPFilter = "Please set rp_filter=2 instead of rp_filter=1; see https://github.com/tailscale/tailscale/issues/3310" + InMemoryTailnetLockState = "Tailnet Lock state is only being stored in-memory. Set --statedir to store state on disk, which is more secure. See https://tailscale.com/kb/1226/tailnet-lock#tailnet-lock-state" ) diff --git a/health/state.go b/health/state.go index 17a646794..e6d937b6a 100644 --- a/health/state.go +++ b/health/state.go @@ -4,11 +4,20 @@ package health import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" "time" + + "tailscale.com/feature/buildfeatures" + "tailscale.com/tailcfg" ) // State contains the health status of the backend, and is // provided to the client UI via LocalAPI through ipn.Notify. +// +// It is also exposed via c2n for debugging purposes, so try +// not to change its structure too gratuitously. type State struct { // Each key-value pair in Warnings represents a Warnable that is currently // unhealthy. If a Warnable is healthy, it will not be present in this map. @@ -21,16 +30,56 @@ type State struct { } // UnhealthyState contains information to be shown to the user to inform them -// that a Warnable is currently unhealthy. +// that a [Warnable] is currently unhealthy or [tailcfg.DisplayMessage] is being +// sent from the control-plane. type UnhealthyState struct { WarnableCode WarnableCode Severity Severity Title string Text string - BrokenSince *time.Time `json:",omitempty"` - Args Args `json:",omitempty"` - DependsOn []WarnableCode `json:",omitempty"` - ImpactsConnectivity bool `json:",omitempty"` + BrokenSince *time.Time `json:",omitempty"` + Args Args `json:",omitempty"` + DependsOn []WarnableCode `json:",omitempty"` + ImpactsConnectivity bool `json:",omitempty"` + PrimaryAction *UnhealthyStateAction `json:",omitempty"` + + // ETag identifies a specific version of an UnhealthyState. If the contents + // of the other fields of two UnhealthyStates are the same, the ETags will + // be the same. If the contents differ, the ETags will also differ. The + // implementation is not defined and the value is opaque: it might be a + // hash, it might be a simple counter. Implementations should not rely on + // any specific implementation detail or format of the ETag string other + // than string (in)equality. + ETag string `json:",omitzero"` +} + +// hash computes a deep hash of UnhealthyState which will be stable across +// different runs of the same binary. +func (u UnhealthyState) hash() []byte { + hasher := sha256.New() + enc := json.NewEncoder(hasher) + + // hash.Hash.Write never returns an error, so this will only fail if u is + // not marshalable, in which case we have much bigger problems. + _ = enc.Encode(u) + return hasher.Sum(nil) +} + +// withETag returns a copy of UnhealthyState with an ETag set. The ETag will be +// the same for all UnhealthyState instances that are equal. If calculating the +// ETag errors, it returns a copy of the UnhealthyState with an empty ETag. +func (u UnhealthyState) withETag() UnhealthyState { + u.ETag = "" + u.ETag = hex.EncodeToString(u.hash()) + return u +} + +// UnhealthyStateAction represents an action (URL and link) to be presented to +// the user associated with an [UnhealthyState]. Analogous to +// [tailcfg.DisplayMessageAction]. +type UnhealthyStateAction struct { + URL string + Label string } // unhealthyState returns a unhealthyState of the Warnable given its current warningState. @@ -72,7 +121,7 @@ func (w *Warnable) unhealthyState(ws *warningState) *UnhealthyState { // The returned State is a snapshot of shared memory, and the caller should not // mutate the returned value. func (t *Tracker) CurrentState() *State { - if t.nil() { + if !buildfeatures.HasHealth || t.nil() { return &State{} } @@ -86,14 +135,71 @@ func (t *Tracker) CurrentState() *State { wm := map[WarnableCode]UnhealthyState{} for w, ws := range t.warnableVal { - if !w.IsVisible(ws) { + if !w.IsVisible(ws, t.now) { // Skip invisible Warnables. continue } - wm[w.Code] = *w.unhealthyState(ws) + if t.isEffectivelyHealthyLocked(w) { + // Skip Warnables that are unhealthy if they have dependencies + // that are unhealthy. + continue + } + state := w.unhealthyState(ws) + wm[w.Code] = state.withETag() + } + + for id, msg := range t.lastNotifiedControlMessages { + state := UnhealthyState{ + WarnableCode: WarnableCode("control-health." + id), + Severity: severityFromTailcfg(msg.Severity), + Title: msg.Title, + Text: msg.Text, + ImpactsConnectivity: msg.ImpactsConnectivity, + // TODO(tailscale/corp#27759): DependsOn? + } + + if msg.PrimaryAction != nil { + state.PrimaryAction = &UnhealthyStateAction{ + URL: msg.PrimaryAction.URL, + Label: msg.PrimaryAction.Label, + } + } + + wm[state.WarnableCode] = state.withETag() } return &State{ Warnings: wm, } } + +func severityFromTailcfg(s tailcfg.DisplayMessageSeverity) Severity { + switch s { + case tailcfg.SeverityHigh: + return SeverityHigh + case tailcfg.SeverityLow: + return SeverityLow + default: + return SeverityMedium + } +} + +// isEffectivelyHealthyLocked reports whether w is effectively healthy. +// That means it's either actually healthy or it has a dependency that +// that's unhealthy, so we should treat w as healthy to not spam users +// with multiple warnings when only the root cause is relevant. +func (t *Tracker) isEffectivelyHealthyLocked(w *Warnable) bool { + if _, ok := t.warnableVal[w]; !ok { + // Warnable not found in the tracker. So healthy. + return true + } + for _, d := range w.DependsOn { + if !t.isEffectivelyHealthyLocked(d) { + // If one of our deps is unhealthy, we're healthy. + return true + } + } + // If we have no unhealthy deps and had warnableVal set, + // we're unhealthy. + return false +} diff --git a/health/usermetrics.go b/health/usermetrics.go new file mode 100644 index 000000000..110c57b57 --- /dev/null +++ b/health/usermetrics.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_health && !ts_omit_usermetrics + +package health + +import ( + "expvar" + + "tailscale.com/feature/buildfeatures" + "tailscale.com/util/usermetric" +) + +const MetricLabelWarning = "warning" + +type metricHealthMessageLabel struct { + // TODO: break down by warnable.severity as well? + Type string +} + +// SetMetricsRegistry sets up the metrics for the Tracker. It takes +// a usermetric.Registry and registers the metrics there. +func (t *Tracker) SetMetricsRegistry(reg *usermetric.Registry) { + if !buildfeatures.HasHealth { + return + } + + if reg == nil || t.metricHealthMessage != nil { + return + } + + m := usermetric.NewMultiLabelMapWithRegistry[metricHealthMessageLabel]( + reg, + "tailscaled_health_messages", + "gauge", + "Number of health messages broken down by type.", + ) + + m.Set(metricHealthMessageLabel{ + Type: MetricLabelWarning, + }, expvar.Func(func() any { + if t.nil() { + return 0 + } + t.mu.Lock() + defer t.mu.Unlock() + t.updateBuiltinWarnablesLocked() + return int64(len(t.stringsLocked())) + })) + t.metricHealthMessage = m +} diff --git a/health/usermetrics_omit.go b/health/usermetrics_omit.go new file mode 100644 index 000000000..9d5e35b86 --- /dev/null +++ b/health/usermetrics_omit.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_health || ts_omit_usermetrics + +package health + +func (t *Tracker) SetMetricsRegistry(any) {} diff --git a/health/warnings.go b/health/warnings.go index 7a21f9695..a9c4b34a0 100644 --- a/health/warnings.go +++ b/health/warnings.go @@ -8,244 +8,279 @@ import ( "runtime" "time" + "tailscale.com/feature/buildfeatures" + "tailscale.com/tsconst" "tailscale.com/version" ) +func condRegister(f func() *Warnable) *Warnable { + if !buildfeatures.HasHealth { + return nil + } + return f() +} + /** This file contains definitions for the Warnables maintained within this `health` package. */ // updateAvailableWarnable is a Warnable that warns the user that an update is available. -var updateAvailableWarnable = Register(&Warnable{ - Code: "update-available", - Title: "Update available", - Severity: SeverityLow, - Text: func(args Args) string { - if version.IsMacAppStore() || version.IsAppleTV() || version.IsMacSys() || version.IsWindowsGUI() || runtime.GOOS == "android" { - return fmt.Sprintf("An update from version %s to %s is available.", args[ArgCurrentVersion], args[ArgAvailableVersion]) - } else { - return fmt.Sprintf("An update from version %s to %s is available. Run `tailscale update` or `tailscale set --auto-update` to update now.", args[ArgCurrentVersion], args[ArgAvailableVersion]) - } - }, +var updateAvailableWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableUpdateAvailable, + Title: "Update available", + Severity: SeverityLow, + Text: func(args Args) string { + if version.IsMacAppStore() || version.IsAppleTV() || version.IsMacSys() || version.IsWindowsGUI() || runtime.GOOS == "android" { + return fmt.Sprintf("An update from version %s to %s is available.", args[ArgCurrentVersion], args[ArgAvailableVersion]) + } else { + return fmt.Sprintf("An update from version %s to %s is available. Run `tailscale update` or `tailscale set --auto-update` to update now.", args[ArgCurrentVersion], args[ArgAvailableVersion]) + } + }, + } }) // securityUpdateAvailableWarnable is a Warnable that warns the user that an important security update is available. -var securityUpdateAvailableWarnable = Register(&Warnable{ - Code: "security-update-available", - Title: "Security update available", - Severity: SeverityMedium, - Text: func(args Args) string { - if version.IsMacAppStore() || version.IsAppleTV() || version.IsMacSys() || version.IsWindowsGUI() || runtime.GOOS == "android" { - return fmt.Sprintf("A security update from version %s to %s is available.", args[ArgCurrentVersion], args[ArgAvailableVersion]) - } else { - return fmt.Sprintf("A security update from version %s to %s is available. Run `tailscale update` or `tailscale set --auto-update` to update now.", args[ArgCurrentVersion], args[ArgAvailableVersion]) - } - }, +var securityUpdateAvailableWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableSecurityUpdateAvailable, + Title: "Security update available", + Severity: SeverityMedium, + Text: func(args Args) string { + if version.IsMacAppStore() || version.IsAppleTV() || version.IsMacSys() || version.IsWindowsGUI() || runtime.GOOS == "android" { + return fmt.Sprintf("A security update from version %s to %s is available.", args[ArgCurrentVersion], args[ArgAvailableVersion]) + } else { + return fmt.Sprintf("A security update from version %s to %s is available. Run `tailscale update` or `tailscale set --auto-update` to update now.", args[ArgCurrentVersion], args[ArgAvailableVersion]) + } + }, + } }) // unstableWarnable is a Warnable that warns the user that they are using an unstable version of Tailscale // so they won't be surprised by all the issues that may arise. -var unstableWarnable = Register(&Warnable{ - Code: "is-using-unstable-version", - Title: "Using an unstable version", - Severity: SeverityLow, - Text: StaticMessage("This is an unstable version of Tailscale meant for testing and development purposes. Please report any issues to Tailscale."), +var unstableWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableIsUsingUnstableVersion, + Title: "Using an unstable version", + Severity: SeverityLow, + Text: StaticMessage("This is an unstable version of Tailscale meant for testing and development purposes. Please report any issues to Tailscale."), + } }) // NetworkStatusWarnable is a Warnable that warns the user that the network is down. -var NetworkStatusWarnable = Register(&Warnable{ - Code: "network-status", - Title: "Network down", - Severity: SeverityMedium, - Text: StaticMessage("Tailscale cannot connect because the network is down. Check your Internet connection."), - ImpactsConnectivity: true, - TimeToVisible: 5 * time.Second, +var NetworkStatusWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableNetworkStatus, + Title: "Network down", + Severity: SeverityMedium, + Text: StaticMessage("Tailscale cannot connect because the network is down. Check your Internet connection."), + ImpactsConnectivity: true, + TimeToVisible: 5 * time.Second, + } }) // IPNStateWarnable is a Warnable that warns the user that Tailscale is stopped. -var IPNStateWarnable = Register(&Warnable{ - Code: "wantrunning-false", - Title: "Tailscale off", - Severity: SeverityLow, - Text: StaticMessage("Tailscale is stopped."), +var IPNStateWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableWantRunningFalse, + Title: "Tailscale off", + Severity: SeverityLow, + Text: StaticMessage("Tailscale is stopped."), + } }) // localLogWarnable is a Warnable that warns the user that the local log is misconfigured. -var localLogWarnable = Register(&Warnable{ - Code: "local-log-config-error", - Title: "Local log misconfiguration", - Severity: SeverityLow, - Text: func(args Args) string { - return fmt.Sprintf("The local log is misconfigured: %v", args[ArgError]) - }, +var localLogWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableLocalLogConfigError, + Title: "Local log misconfiguration", + Severity: SeverityLow, + Text: func(args Args) string { + return fmt.Sprintf("The local log is misconfigured: %v", args[ArgError]) + }, + } }) // LoginStateWarnable is a Warnable that warns the user that they are logged out, // and provides the last login error if available. -var LoginStateWarnable = Register(&Warnable{ - Code: "login-state", - Title: "Logged out", - Severity: SeverityMedium, - Text: func(args Args) string { - if args[ArgError] != "" { - return fmt.Sprintf("You are logged out. The last login error was: %v", args[ArgError]) - } else { - return "You are logged out." - } - }, - DependsOn: []*Warnable{IPNStateWarnable}, +var LoginStateWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableLoginState, + Title: "Logged out", + Severity: SeverityMedium, + Text: func(args Args) string { + if args[ArgError] != "" { + return fmt.Sprintf("You are logged out. The last login error was: %v", args[ArgError]) + } else { + return "You are logged out." + } + }, + DependsOn: []*Warnable{IPNStateWarnable}, + } }) // notInMapPollWarnable is a Warnable that warns the user that we are using a stale network map. -var notInMapPollWarnable = Register(&Warnable{ - Code: "not-in-map-poll", - Title: "Out of sync", - Severity: SeverityMedium, - DependsOn: []*Warnable{NetworkStatusWarnable, IPNStateWarnable}, - Text: StaticMessage("Unable to connect to the Tailscale coordination server to synchronize the state of your tailnet. Peer reachability might degrade over time."), - // 8 minutes reflects a maximum maintenance window for the coordination server. - TimeToVisible: 8 * time.Minute, +var notInMapPollWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableNotInMapPoll, + Title: "Out of sync", + Severity: SeverityMedium, + DependsOn: []*Warnable{NetworkStatusWarnable, IPNStateWarnable}, + Text: StaticMessage("Unable to connect to the Tailscale coordination server to synchronize the state of your tailnet. Peer reachability might degrade over time."), + // 8 minutes reflects a maximum maintenance window for the coordination server. + TimeToVisible: 8 * time.Minute, + } }) // noDERPHomeWarnable is a Warnable that warns the user that Tailscale doesn't have a home DERP. -var noDERPHomeWarnable = Register(&Warnable{ - Code: "no-derp-home", - Title: "No home relay server", - Severity: SeverityMedium, - DependsOn: []*Warnable{NetworkStatusWarnable}, - Text: StaticMessage("Tailscale could not connect to any relay server. Check your Internet connection."), - ImpactsConnectivity: true, - TimeToVisible: 10 * time.Second, +var noDERPHomeWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableNoDERPHome, + Title: "No home relay server", + Severity: SeverityMedium, + DependsOn: []*Warnable{NetworkStatusWarnable}, + Text: StaticMessage("Tailscale could not connect to any relay server. Check your Internet connection."), + ImpactsConnectivity: true, + TimeToVisible: 10 * time.Second, + } }) // noDERPConnectionWarnable is a Warnable that warns the user that Tailscale couldn't connect to a specific DERP server. -var noDERPConnectionWarnable = Register(&Warnable{ - Code: "no-derp-connection", - Title: "Relay server unavailable", - Severity: SeverityMedium, - DependsOn: []*Warnable{ - NetworkStatusWarnable, +var noDERPConnectionWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableNoDERPConnection, + Title: "Relay server unavailable", + Severity: SeverityMedium, + DependsOn: []*Warnable{ + NetworkStatusWarnable, - // Technically noDERPConnectionWarnable could be used to warn about - // failure to connect to a specific DERP server (e.g. your home is derp1 - // but you're trying to connect to a peer's derp4 and are unable) but as - // of 2024-09-25 we only use this for connecting to your home DERP, so - // we depend on noDERPHomeWarnable which is the ability to figure out - // what your DERP home even is. - noDERPHomeWarnable, - }, - Text: func(args Args) string { - if n := args[ArgDERPRegionName]; n != "" { - return fmt.Sprintf("Tailscale could not connect to the '%s' relay server. Your Internet connection might be down, or the server might be temporarily unavailable.", n) - } else { - return fmt.Sprintf("Tailscale could not connect to the relay server with ID '%s'. Your Internet connection might be down, or the server might be temporarily unavailable.", args[ArgDERPRegionID]) - } - }, - ImpactsConnectivity: true, - TimeToVisible: 10 * time.Second, + // Technically noDERPConnectionWarnable could be used to warn about + // failure to connect to a specific DERP server (e.g. your home is derp1 + // but you're trying to connect to a peer's derp4 and are unable) but as + // of 2024-09-25 we only use this for connecting to your home DERP, so + // we depend on noDERPHomeWarnable which is the ability to figure out + // what your DERP home even is. + noDERPHomeWarnable, + }, + Text: func(args Args) string { + if n := args[ArgDERPRegionName]; n != "" { + return fmt.Sprintf("Tailscale could not connect to the '%s' relay server. Your Internet connection might be down, or the server might be temporarily unavailable.", n) + } else { + return fmt.Sprintf("Tailscale could not connect to the relay server with ID '%s'. Your Internet connection might be down, or the server might be temporarily unavailable.", args[ArgDERPRegionID]) + } + }, + ImpactsConnectivity: true, + TimeToVisible: 10 * time.Second, + } }) // derpTimeoutWarnable is a Warnable that warns the user that Tailscale hasn't // heard from the home DERP region for a while. -var derpTimeoutWarnable = Register(&Warnable{ - Code: "derp-timed-out", - Title: "Relay server timed out", - Severity: SeverityMedium, - DependsOn: []*Warnable{ - NetworkStatusWarnable, - noDERPConnectionWarnable, // don't warn about it being stalled if we're not connected - noDERPHomeWarnable, // same reason as noDERPConnectionWarnable's dependency - }, - Text: func(args Args) string { - if n := args[ArgDERPRegionName]; n != "" { - return fmt.Sprintf("Tailscale hasn't heard from the '%s' relay server in %v. The server might be temporarily unavailable, or your Internet connection might be down.", n, args[ArgDuration]) - } else { - return fmt.Sprintf("Tailscale hasn't heard from the home relay server (region ID '%v') in %v. The server might be temporarily unavailable, or your Internet connection might be down.", args[ArgDERPRegionID], args[ArgDuration]) - } - }, +var derpTimeoutWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableDERPTimedOut, + Title: "Relay server timed out", + Severity: SeverityMedium, + DependsOn: []*Warnable{ + NetworkStatusWarnable, + noDERPConnectionWarnable, // don't warn about it being stalled if we're not connected + noDERPHomeWarnable, // same reason as noDERPConnectionWarnable's dependency + }, + Text: func(args Args) string { + if n := args[ArgDERPRegionName]; n != "" { + return fmt.Sprintf("Tailscale hasn't heard from the '%s' relay server in %v. The server might be temporarily unavailable, or your Internet connection might be down.", n, args[ArgDuration]) + } else { + return fmt.Sprintf("Tailscale hasn't heard from the home relay server (region ID '%v') in %v. The server might be temporarily unavailable, or your Internet connection might be down.", args[ArgDERPRegionID], args[ArgDuration]) + } + }, + } }) // derpRegionErrorWarnable is a Warnable that warns the user that a DERP region is reporting an issue. -var derpRegionErrorWarnable = Register(&Warnable{ - Code: "derp-region-error", - Title: "Relay server error", - Severity: SeverityLow, - DependsOn: []*Warnable{NetworkStatusWarnable}, - Text: func(args Args) string { - return fmt.Sprintf("The relay server #%v is reporting an issue: %v", args[ArgDERPRegionID], args[ArgError]) - }, +var derpRegionErrorWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableDERPRegionError, + Title: "Relay server error", + Severity: SeverityLow, + DependsOn: []*Warnable{NetworkStatusWarnable}, + Text: func(args Args) string { + return fmt.Sprintf("The relay server #%v is reporting an issue: %v", args[ArgDERPRegionID], args[ArgError]) + }, + } }) // noUDP4BindWarnable is a Warnable that warns the user that Tailscale couldn't listen for incoming UDP connections. -var noUDP4BindWarnable = Register(&Warnable{ - Code: "no-udp4-bind", - Title: "NAT traversal setup failure", - Severity: SeverityMedium, - DependsOn: []*Warnable{NetworkStatusWarnable, IPNStateWarnable}, - Text: StaticMessage("Tailscale couldn't listen for incoming UDP connections."), - ImpactsConnectivity: true, +var noUDP4BindWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableNoUDP4Bind, + Title: "NAT traversal setup failure", + Severity: SeverityMedium, + DependsOn: []*Warnable{NetworkStatusWarnable, IPNStateWarnable}, + Text: StaticMessage("Tailscale couldn't listen for incoming UDP connections."), + ImpactsConnectivity: true, + } }) // mapResponseTimeoutWarnable is a Warnable that warns the user that Tailscale hasn't received a network map from the coordination server in a while. -var mapResponseTimeoutWarnable = Register(&Warnable{ - Code: "mapresponse-timeout", - Title: "Network map response timeout", - Severity: SeverityMedium, - DependsOn: []*Warnable{NetworkStatusWarnable, IPNStateWarnable}, - Text: func(args Args) string { - return fmt.Sprintf("Tailscale hasn't received a network map from the coordination server in %s.", args[ArgDuration]) - }, +var mapResponseTimeoutWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableMapResponseTimeout, + Title: "Network map response timeout", + Severity: SeverityMedium, + DependsOn: []*Warnable{NetworkStatusWarnable, IPNStateWarnable}, + Text: func(args Args) string { + return fmt.Sprintf("Tailscale hasn't received a network map from the coordination server in %s.", args[ArgDuration]) + }, + } }) // tlsConnectionFailedWarnable is a Warnable that warns the user that Tailscale could not establish an encrypted connection with a server. -var tlsConnectionFailedWarnable = Register(&Warnable{ - Code: "tls-connection-failed", - Title: "Encrypted connection failed", - Severity: SeverityMedium, - DependsOn: []*Warnable{NetworkStatusWarnable}, - Text: func(args Args) string { - return fmt.Sprintf("Tailscale could not establish an encrypted connection with '%q': %v", args[ArgServerName], args[ArgError]) - }, +var tlsConnectionFailedWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableTLSConnectionFailed, + Title: "Encrypted connection failed", + Severity: SeverityMedium, + DependsOn: []*Warnable{NetworkStatusWarnable}, + Text: func(args Args) string { + return fmt.Sprintf("Tailscale could not establish an encrypted connection with '%q': %v", args[ArgServerName], args[ArgError]) + }, + } }) // magicsockReceiveFuncWarnable is a Warnable that warns the user that one of the Magicsock functions is not running. -var magicsockReceiveFuncWarnable = Register(&Warnable{ - Code: "magicsock-receive-func-error", - Title: "MagicSock function not running", - Severity: SeverityMedium, - Text: func(args Args) string { - return fmt.Sprintf("The MagicSock function %s is not running. You might experience connectivity issues.", args[ArgMagicsockFunctionName]) - }, +var magicsockReceiveFuncWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableMagicsockReceiveFuncError, + Title: "MagicSock function not running", + Severity: SeverityMedium, + Text: func(args Args) string { + return fmt.Sprintf("The MagicSock function %s is not running. You might experience connectivity issues.", args[ArgMagicsockFunctionName]) + }, + } }) // testWarnable is a Warnable that is used within this package for testing purposes only. -var testWarnable = Register(&Warnable{ - Code: "test-warnable", - Title: "Test warnable", - Severity: SeverityLow, - Text: func(args Args) string { - return args[ArgError] - }, +var testWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableTestWarnable, + Title: "Test warnable", + Severity: SeverityLow, + Text: func(args Args) string { + return args[ArgError] + }, + } }) // applyDiskConfigWarnable is a Warnable that warns the user that there was an error applying the envknob config stored on disk. -var applyDiskConfigWarnable = Register(&Warnable{ - Code: "apply-disk-config", - Title: "Could not apply configuration", - Severity: SeverityMedium, - Text: func(args Args) string { - return fmt.Sprintf("An error occurred applying the Tailscale envknob configuration stored on disk: %v", args[ArgError]) - }, -}) - -// controlHealthWarnable is a Warnable that warns the user that the coordination server is reporting an health issue. -var controlHealthWarnable = Register(&Warnable{ - Code: "control-health", - Title: "Coordination server reports an issue", - Severity: SeverityMedium, - Text: func(args Args) string { - return fmt.Sprintf("The coordination server is reporting an health issue: %v", args[ArgError]) - }, +var applyDiskConfigWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableApplyDiskConfig, + Title: "Could not apply configuration", + Severity: SeverityMedium, + Text: func(args Args) string { + return fmt.Sprintf("An error occurred applying the Tailscale envknob configuration stored on disk: %v", args[ArgError]) + }, + } }) // warmingUpWarnableDuration is the duration for which the warmingUpWarnable is reported by the backend after the user @@ -255,9 +290,11 @@ const warmingUpWarnableDuration = 5 * time.Second // warmingUpWarnable is a Warnable that is reported by the backend when it is starting up, for a maximum time of // warmingUpWarnableDuration. The GUIs use the presence of this Warnable to prevent showing any other warnings until // the backend is fully started. -var warmingUpWarnable = Register(&Warnable{ - Code: "warming-up", - Title: "Tailscale is starting", - Severity: SeverityLow, - Text: StaticMessage("Tailscale is starting. Please wait."), +var warmingUpWarnable = condRegister(func() *Warnable { + return &Warnable{ + Code: tsconst.HealthWarnableWarmingUp, + Title: "Tailscale is starting", + Severity: SeverityLow, + Text: StaticMessage("Tailscale is starting. Please wait."), + } }) diff --git a/hostinfo/hostinfo.go b/hostinfo/hostinfo.go index 1f9037829..3e8f2f994 100644 --- a/hostinfo/hostinfo.go +++ b/hostinfo/hostinfo.go @@ -21,22 +21,31 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/tailcfg" + "tailscale.com/types/lazy" "tailscale.com/types/opt" "tailscale.com/types/ptr" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version" "tailscale.com/version/distro" ) var started = time.Now() +var newHooks []func(*tailcfg.Hostinfo) + +// RegisterHostinfoNewHook registers a callback to be called on a non-nil +// [tailcfg.Hostinfo] before it is returned by [New]. +func RegisterHostinfoNewHook(f func(*tailcfg.Hostinfo)) { + newHooks = append(newHooks, f) +} + // New returns a partially populated Hostinfo for the current host. func New() *tailcfg.Hostinfo { - hostname, _ := os.Hostname() + hostname, _ := Hostname() hostname = dnsname.FirstLabel(hostname) - return &tailcfg.Hostinfo{ + hi := &tailcfg.Hostinfo{ IPNVersion: version.Long(), Hostname: hostname, App: appTypeCached(), @@ -57,8 +66,11 @@ func New() *tailcfg.Hostinfo { Cloud: string(cloudenv.Get()), NoLogsNoSupport: envknob.NoLogsNoSupport(), AllowsUpdate: envknob.AllowsRemoteUpdate(), - WoLMACs: getWoLMACs(), } + for _, f := range newHooks { + f(hi) + } + return hi } // non-nil on some platforms @@ -231,12 +243,11 @@ func desktop() (ret opt.Bool) { } seenDesktop := false - lineread.File("/proc/net/unix", func(line []byte) error { - seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S(" @/tmp/dbus-")) + for lr := range lineiter.File("/proc/net/unix") { + line, _ := lr.Value() seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S(".X11-unix")) seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S("/wayland-1")) - return nil - }) + } ret.Set(seenDesktop) // Only cache after a minute - compositors might not have started yet. @@ -280,13 +291,22 @@ func getEnvType() EnvType { return "" } -// inContainer reports whether we're running in a container. +// inContainer reports whether we're running in a container. Best-effort only, +// there's no foolproof way to detect this, but the build tag should catch all +// official builds from 1.78.0. func inContainer() opt.Bool { if runtime.GOOS != "linux" { return "" } var ret opt.Bool ret.Set(false) + if packageType != nil && packageType() == "container" { + // Go build tag ts_package_container was set during build. + ret.Set(true) + return ret + } + // Only set if using docker's container runtime. Not guaranteed by + // documentation, but it's been in place for a long time. if _, err := os.Stat("/.dockerenv"); err == nil { ret.Set(true) return ret @@ -296,21 +316,21 @@ func inContainer() opt.Bool { ret.Set(true) return ret } - lineread.File("/proc/1/cgroup", func(line []byte) error { + for lr := range lineiter.File("/proc/1/cgroup") { + line, _ := lr.Value() if mem.Contains(mem.B(line), mem.S("/docker/")) || mem.Contains(mem.B(line), mem.S("/lxc/")) { ret.Set(true) - return io.EOF // arbitrary non-nil error to stop loop + break } - return nil - }) - lineread.File("/proc/mounts", func(line []byte) error { + } + for lr := range lineiter.File("/proc/mounts") { + line, _ := lr.Value() if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) { ret.Set(true) - return io.EOF + break } - return nil - }) + } return ret } @@ -362,7 +382,7 @@ func inFlyDotIo() bool { } func inReplit() bool { - // https://docs.replit.com/programming-ide/getting-repl-metadata + // https://docs.replit.com/replit-workspace/configuring-repl#environment-variables if os.Getenv("REPL_OWNER") != "" && os.Getenv("REPL_SLUG") != "" { return true } @@ -478,5 +498,32 @@ func IsNATLabGuestVM() bool { return false } -// NAT Lab VMs have a unique MAC address prefix. -// See +const copyV86DeviceModel = "copy-v86" + +var isV86Cache lazy.SyncValue[bool] + +// IsInVM86 reports whether we're running in the copy/v86 wasm emulator, +// https://github.com/copy/v86/. +func IsInVM86() bool { + return isV86Cache.Get(func() bool { + return New().DeviceModel == copyV86DeviceModel + }) +} + +type hostnameQuery func() (string, error) + +var hostnameFn atomic.Value // of func() (string, error) + +// SetHostNameFn sets a custom function for querying the system hostname. +func SetHostnameFn(fn hostnameQuery) { + hostnameFn.Store(fn) +} + +// Hostname returns the system hostname using the function +// set by SetHostNameFn. We will fallback to os.Hostname. +func Hostname() (string, error) { + if fn, ok := hostnameFn.Load().(hostnameQuery); ok && fn != nil { + return fn() + } + return os.Hostname() +} diff --git a/hostinfo/hostinfo_container_linux_test.go b/hostinfo/hostinfo_container_linux_test.go new file mode 100644 index 000000000..594a5f512 --- /dev/null +++ b/hostinfo/hostinfo_container_linux_test.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android && ts_package_container + +package hostinfo + +import ( + "testing" +) + +func TestInContainer(t *testing.T) { + if got := inContainer(); !got.EqualBool(true) { + t.Errorf("inContainer = %v; want true due to ts_package_container build tag", got) + } +} diff --git a/hostinfo/hostinfo_linux.go b/hostinfo/hostinfo_linux.go index 53d4187bc..66484a358 100644 --- a/hostinfo/hostinfo_linux.go +++ b/hostinfo/hostinfo_linux.go @@ -12,7 +12,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/types/ptr" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version/distro" ) @@ -106,15 +106,18 @@ func linuxVersionMeta() (meta versionMeta) { } m := map[string]string{} - lineread.File(propFile, func(line []byte) error { + for lr := range lineiter.File(propFile) { + line, err := lr.Value() + if err != nil { + break + } eq := bytes.IndexByte(line, '=') if eq == -1 { - return nil + continue } k, v := string(line[:eq]), strings.Trim(string(line[eq+1:]), `"'`) m[k] = v - return nil - }) + } if v := m["VERSION_CODENAME"]; v != "" { meta.DistroCodeName = v diff --git a/hostinfo/hostinfo_linux_test.go b/hostinfo/hostinfo_linux_test.go index 4859167a2..0286fadf3 100644 --- a/hostinfo/hostinfo_linux_test.go +++ b/hostinfo/hostinfo_linux_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux && !android +//go:build linux && !android && !ts_package_container package hostinfo @@ -34,3 +34,13 @@ remotes/origin/QTSFW_5.0.0` t.Errorf("got %q; want %q", got, want) } } + +func TestPackageTypeNotContainer(t *testing.T) { + var got string + if packageType != nil { + got = packageType() + } + if got == "container" { + t.Fatal("packageType = container; should only happen if build tag ts_package_container is set") + } +} diff --git a/hostinfo/hostinfo_plan9.go b/hostinfo/hostinfo_plan9.go new file mode 100644 index 000000000..f9aa30e51 --- /dev/null +++ b/hostinfo/hostinfo_plan9.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "bytes" + "os" + "strings" + + "tailscale.com/tailcfg" + "tailscale.com/types/lazy" +) + +func init() { + RegisterHostinfoNewHook(func(hi *tailcfg.Hostinfo) { + if isPlan9V86() { + hi.DeviceModel = copyV86DeviceModel + } + }) +} + +var isPlan9V86Cache lazy.SyncValue[bool] + +// isPlan9V86 reports whether we're running in the wasm +// environment (https://github.com/copy/v86/). +func isPlan9V86() bool { + return isPlan9V86Cache.Get(func() bool { + v, _ := os.ReadFile("/dev/cputype") + s, _, _ := strings.Cut(string(v), " ") + if s != "PentiumIV/Xeon" { + return false + } + + v, _ = os.ReadFile("/dev/config") + v, _, _ = bytes.Cut(v, []byte{'\n'}) + return string(v) == "# pcvm - small kernel used to run in vm" + }) +} diff --git a/hostinfo/hostinfo_test.go b/hostinfo/hostinfo_test.go index 9fe32e044..15b6971b6 100644 --- a/hostinfo/hostinfo_test.go +++ b/hostinfo/hostinfo_test.go @@ -5,6 +5,7 @@ package hostinfo import ( "encoding/json" + "os" "strings" "testing" ) @@ -49,3 +50,31 @@ func TestEtcAptSourceFileIsDisabled(t *testing.T) { }) } } + +func TestCustomHostnameFunc(t *testing.T) { + want := "custom-hostname" + SetHostnameFn(func() (string, error) { + return want, nil + }) + + got, err := Hostname() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got != want { + t.Errorf("got %q, want %q", got, want) + } + + SetHostnameFn(os.Hostname) + got, err = Hostname() + want, _ = os.Hostname() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != want { + t.Errorf("got %q, want %q", got, want) + } + +} diff --git a/hostinfo/wol.go b/hostinfo/wol.go deleted file mode 100644 index 3a30af2fe..000000000 --- a/hostinfo/wol.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "log" - "net" - "runtime" - "strings" - "unicode" - - "tailscale.com/envknob" -) - -// TODO(bradfitz): this is all too simplistic and static. It needs to run -// continuously in response to netmon events (USB ethernet adapaters might get -// plugged in) and look for the media type/status/etc. Right now on macOS it -// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't -// have any media. We should only report the one that's actually connected. -// But it works for now (2023-10-05) for fleshing out the rest. - -var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 - -// getWoLMACs returns up to 10 MAC address of the local machine to send -// wake-on-LAN packets to in order to wake it up. The returned MACs are in -// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). -// -// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS -// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is -// set to a MAC address, that sole MAC address is returned. -func getWoLMACs() (macs []string) { - switch runtime.GOOS { - case "ios", "android": - return nil - } - if s := wakeMAC(); s != "" { - switch s { - case "auto": - ifs, _ := net.Interfaces() - for _, iface := range ifs { - if iface.Flags&net.FlagLoopback != 0 { - continue - } - if iface.Flags&net.FlagBroadcast == 0 || - iface.Flags&net.FlagRunning == 0 || - iface.Flags&net.FlagUp == 0 { - continue - } - if keepMAC(iface.Name, iface.HardwareAddr) { - macs = append(macs, iface.HardwareAddr.String()) - } - if len(macs) == 10 { - break - } - } - return macs - case "false", "off": // fast path before ParseMAC error - return nil - } - mac, err := net.ParseMAC(s) - if err != nil { - log.Printf("invalid MAC %q", s) - return nil - } - return []string{mac.String()} - } - return nil -} - -var ignoreWakeOUI = map[[3]byte]bool{ - {0x00, 0x15, 0x5d}: true, // Hyper-V - {0x00, 0x50, 0x56}: true, // VMware - {0x00, 0x1c, 0x14}: true, // VMware - {0x00, 0x05, 0x69}: true, // VMware - {0x00, 0x0c, 0x29}: true, // VMware - {0x00, 0x1c, 0x42}: true, // Parallels - {0x08, 0x00, 0x27}: true, // VirtualBox - {0x00, 0x21, 0xf6}: true, // VirtualBox - {0x00, 0x14, 0x4f}: true, // VirtualBox - {0x00, 0x0f, 0x4b}: true, // VirtualBox - {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant -} - -func keepMAC(ifName string, mac []byte) bool { - if len(mac) != 6 { - return false - } - base := strings.TrimRightFunc(ifName, unicode.IsNumber) - switch runtime.GOOS { - case "darwin": - switch base { - case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": - return false - } - } - if mac[0] == 0x02 && mac[1] == 0x42 { - // Docker container. - return false - } - oui := [3]byte{mac[0], mac[1], mac[2]} - if ignoreWakeOUI[oui] { - return false - } - return true -} diff --git a/internal/client/tailscale/identityfederation.go b/internal/client/tailscale/identityfederation.go new file mode 100644 index 000000000..e1fe3559c --- /dev/null +++ b/internal/client/tailscale/identityfederation.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "context" + + "tailscale.com/feature" +) + +// HookResolveAuthKeyViaWIF resolves to [identityfederation.ResolveAuthKey] when the +// corresponding feature tag is enabled in the build process. +// +// baseURL is the URL of the control server used for token exchange and authkey generation. +// clientID is the federated client ID used for token exchange, the format is / +// idToken is the Identity token from the identity provider +// tags is the list of tags to be associated with the auth key +var HookResolveAuthKeyViaWIF feature.Hook[func(ctx context.Context, baseURL, clientID, idToken string, tags []string) (string, error)] diff --git a/internal/client/tailscale/oauthkeys.go b/internal/client/tailscale/oauthkeys.go new file mode 100644 index 000000000..21102ce0b --- /dev/null +++ b/internal/client/tailscale/oauthkeys.go @@ -0,0 +1,20 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "context" + + "tailscale.com/feature" +) + +// HookResolveAuthKey resolves to [oauthkey.ResolveAuthKey] when the +// corresponding feature tag is enabled in the build process. +// +// authKey is a standard device auth key or an OAuth client secret to +// resolve into an auth key. +// tags is the list of tags being advertised by the client (required to be +// provided for the OAuth secret case, and required to be the same as the +// list of tags for which the OAuth secret is allowed to issue auth keys). +var HookResolveAuthKey feature.Hook[func(ctx context.Context, authKey string, tags []string) (string, error)] diff --git a/internal/client/tailscale/tailscale.go b/internal/client/tailscale/tailscale.go new file mode 100644 index 000000000..0e603bf79 --- /dev/null +++ b/internal/client/tailscale/tailscale.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tailscale provides a minimal control plane API client for internal +// use. A full client for 3rd party use is available at +// tailscale.com/client/tailscale/v2. The internal client is provided to avoid +// having to import that whole package. +package tailscale + +import ( + "errors" + "io" + "net/http" + + tsclient "tailscale.com/client/tailscale" +) + +// maxSize is the maximum read size (10MB) of responses from the server. +const maxReadSize = 10 << 20 + +func init() { + tsclient.I_Acknowledge_This_API_Is_Unstable = true +} + +// AuthMethod is an alias to tailscale.com/client/tailscale. +type AuthMethod = tsclient.AuthMethod + +// APIKey is an alias to tailscale.com/client/tailscale. +type APIKey = tsclient.APIKey + +// Device is an alias to tailscale.com/client/tailscale. +type Device = tsclient.Device + +// DeviceFieldsOpts is an alias to tailscale.com/client/tailscale. +type DeviceFieldsOpts = tsclient.DeviceFieldsOpts + +// Key is an alias to tailscale.com/client/tailscale. +type Key = tsclient.Key + +// KeyCapabilities is an alias to tailscale.com/client/tailscale. +type KeyCapabilities = tsclient.KeyCapabilities + +// KeyDeviceCapabilities is an alias to tailscale.com/client/tailscale. +type KeyDeviceCapabilities = tsclient.KeyDeviceCapabilities + +// KeyDeviceCreateCapabilities is an alias to tailscale.com/client/tailscale. +type KeyDeviceCreateCapabilities = tsclient.KeyDeviceCreateCapabilities + +// ErrResponse is an alias to tailscale.com/client/tailscale. +type ErrResponse = tsclient.ErrResponse + +// NewClient is an alias to tailscale.com/client/tailscale. +func NewClient(tailnet string, auth AuthMethod) *Client { + return &Client{ + Client: tsclient.NewClient(tailnet, auth), + } +} + +// Client is a wrapper of tailscale.com/client/tailscale. +type Client struct { + *tsclient.Client +} + +// HandleErrorResponse is an alias to tailscale.com/client/tailscale. +func HandleErrorResponse(b []byte, resp *http.Response) error { + return tsclient.HandleErrorResponse(b, resp) +} + +// SendRequest add the authentication key to the request and sends it. It +// receives the response and reads up to 10MB of it. +func SendRequest(c *Client, req *http.Request) ([]byte, *http.Response, error) { + resp, err := c.Do(req) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // Read response. Limit the response to 10MB. + // This limit is carried over from client/tailscale/tailscale.go. + body := io.LimitReader(resp.Body, maxReadSize+1) + b, err := io.ReadAll(body) + if len(b) > maxReadSize { + err = errors.New("API response too large") + } + return b, resp, err +} diff --git a/internal/client/tailscale/vip_service.go b/internal/client/tailscale/vip_service.go new file mode 100644 index 000000000..48c59ce45 --- /dev/null +++ b/internal/client/tailscale/vip_service.go @@ -0,0 +1,133 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/tailcfg" + "tailscale.com/util/httpm" +) + +// VIPService is a Tailscale VIPService with Tailscale API JSON representation. +type VIPService struct { + // Name is a VIPService name in form svc:. + Name tailcfg.ServiceName `json:"name,omitempty"` + // Addrs are the IP addresses of the VIP Service. There are two addresses: + // the first is IPv4 and the second is IPv6. + // When creating a new VIP Service, the IP addresses are optional: if no + // addresses are specified then they will be selected. If an IPv4 address is + // specified at index 0, then that address will attempt to be used. An IPv6 + // address can not be specified upon creation. + Addrs []string `json:"addrs,omitempty"` + // Comment is an optional text string for display in the admin panel. + Comment string `json:"comment,omitempty"` + // Annotations are optional key-value pairs that can be used to store arbitrary metadata. + Annotations map[string]string `json:"annotations,omitempty"` + // Ports are the ports of a VIPService that will be configured via Tailscale serve config. + // If set, any node wishing to advertise this VIPService must have this port configured via Tailscale serve. + Ports []string `json:"ports,omitempty"` + // Tags are optional ACL tags that will be applied to the VIPService. + Tags []string `json:"tags,omitempty"` +} + +// VIPServiceList represents the JSON response to the list VIP Services API. +type VIPServiceList struct { + VIPServices []VIPService `json:"vipServices"` +} + +// GetVIPService retrieves a VIPService by its name. It returns 404 if the VIPService is not found. +func (client *Client) GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*VIPService, error) { + path := client.BuildTailnetURL("vip-services", name.String()) + req, err := http.NewRequestWithContext(ctx, httpm.GET, path, nil) + if err != nil { + return nil, fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return nil, fmt.Errorf("error making Tailsale API request: %w", err) + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, HandleErrorResponse(b, resp) + } + svc := &VIPService{} + if err := json.Unmarshal(b, svc); err != nil { + return nil, err + } + return svc, nil +} + +// ListVIPServices retrieves all existing Services and returns them as a list. +func (client *Client) ListVIPServices(ctx context.Context) (*VIPServiceList, error) { + path := client.BuildTailnetURL("vip-services") + req, err := http.NewRequestWithContext(ctx, httpm.GET, path, nil) + if err != nil { + return nil, fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return nil, fmt.Errorf("error making Tailsale API request: %w", err) + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return nil, HandleErrorResponse(b, resp) + } + result := &VIPServiceList{} + if err := json.Unmarshal(b, result); err != nil { + return nil, err + } + return result, nil +} + +// CreateOrUpdateVIPService creates or updates a VIPService by its name. Caller must ensure that, if the +// VIPService already exists, the VIPService is fetched first to ensure that any auto-allocated IP addresses are not +// lost during the update. If the VIPService was created without any IP addresses explicitly set (so that they were +// auto-allocated by Tailscale) any subsequent request to this function that does not set any IP addresses will error. +func (client *Client) CreateOrUpdateVIPService(ctx context.Context, svc *VIPService) error { + data, err := json.Marshal(svc) + if err != nil { + return err + } + path := client.BuildTailnetURL("vip-services", svc.Name.String()) + req, err := http.NewRequestWithContext(ctx, httpm.PUT, path, bytes.NewBuffer(data)) + if err != nil { + return fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return fmt.Errorf("error making Tailscale API request: %w", err) + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return HandleErrorResponse(b, resp) + } + return nil +} + +// DeleteVIPService deletes a VIPService by its name. It returns an error if the VIPService +// does not exist or if the deletion fails. +func (client *Client) DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error { + path := client.BuildTailnetURL("vip-services", name.String()) + req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) + if err != nil { + return fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return fmt.Errorf("error making Tailscale API request: %w", err) + } + // If status code was not successful, return the error. + if resp.StatusCode != http.StatusOK { + return HandleErrorResponse(b, resp) + } + return nil +} diff --git a/ipn/auditlog/auditlog.go b/ipn/auditlog/auditlog.go new file mode 100644 index 000000000..0460bc4e2 --- /dev/null +++ b/ipn/auditlog/auditlog.go @@ -0,0 +1,468 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package auditlog provides a mechanism for logging audit events. +package auditlog + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/rands" + "tailscale.com/util/set" +) + +// transaction represents an audit log that has not yet been sent to the control plane. +type transaction struct { + // EventID is the unique identifier for the event being logged. + // This is used on the client side only and is not sent to control. + EventID string `json:",omitempty"` + // Retries is the number of times the logger has attempted to send this log. + // This is used on the client side only and is not sent to control. + Retries int `json:",omitempty"` + + // Action is the action to be logged. It must correspond to a known action in the control plane. + Action tailcfg.ClientAuditAction `json:",omitempty"` + // Details is an opaque string specific to the action being logged. Empty strings may not + // be valid depending on the action being logged. + Details string `json:",omitempty"` + // TimeStamp is the time at which the audit log was generated on the node. + TimeStamp time.Time `json:",omitzero"` +} + +// Transport provides a means for a client to send audit logs to a consumer (typically the control plane). +type Transport interface { + // SendAuditLog sends an audit log to a consumer of audit logs. + // Errors should be checked with [IsRetryableError] for retryability. + SendAuditLog(context.Context, tailcfg.AuditLogRequest) error +} + +// LogStore provides a means for a [Logger] to persist logs to disk or memory. +type LogStore interface { + // Save saves the given data to a persistent store. Save will overwrite existing data + // for the given key. + save(key ipn.ProfileID, txns []*transaction) error + + // Load retrieves the data from a persistent store. Returns a nil slice and + // no error if no data exists for the given key. + load(key ipn.ProfileID) ([]*transaction, error) +} + +// Opts contains the configuration options for a [Logger]. +type Opts struct { + // RetryLimit is the maximum number of attempts the logger will make to send a log before giving up. + RetryLimit int + // Store is the persistent store used to save logs to disk. Must be non-nil. + Store LogStore + // Logf is the logger used to log messages from the audit logger. Must be non-nil. + Logf logger.Logf +} + +// IsRetryableError returns true if the given error is retryable +// See [controlclient.apiResponseError]. Potentially retryable errors implement the Retryable() method. +func IsRetryableError(err error) bool { + var retryable interface{ Retryable() bool } + return errors.As(err, &retryable) && retryable.Retryable() +} + +type backoffOpts struct { + min, max time.Duration + multiplier float64 +} + +// .5, 1, 2, 4, 8, 10, 10, 10, 10, 10... +var defaultBackoffOpts = backoffOpts{ + min: time.Millisecond * 500, + max: 10 * time.Second, + multiplier: 2, +} + +// Logger provides a queue-based mechanism for submitting audit logs to the control plane - or +// another suitable consumer. Logs are stored to disk and retried until they are successfully sent, +// or until they permanently fail. +// +// Each individual profile/controlclient tuple should construct and manage a unique [Logger] instance. +type Logger struct { + logf logger.Logf + retryLimit int // the maximum number of attempts to send a log before giving up. + flusher chan struct{} // channel used to signal a flush operation. + done chan struct{} // closed when the flush worker exits. + ctx context.Context // canceled when the logger is stopped. + ctxCancel context.CancelFunc // cancels ctx. + backoffOpts // backoff settings for retry operations. + + // mu protects the fields below. + mu sync.Mutex + store LogStore // persistent storage for unsent logs. + profileID ipn.ProfileID // empty if [Logger.SetProfileID] has not been called. + transport Transport // nil until [Logger.Start] is called. +} + +// NewLogger creates a new [Logger] with the given options. +func NewLogger(opts Opts) *Logger { + ctx, cancel := context.WithCancel(context.Background()) + + al := &Logger{ + retryLimit: opts.RetryLimit, + logf: opts.Logf, + store: opts.Store, + flusher: make(chan struct{}, 1), + done: make(chan struct{}), + ctx: ctx, + ctxCancel: cancel, + backoffOpts: defaultBackoffOpts, + } + al.logf("created") + return al +} + +// FlushAndStop synchronously flushes all pending logs and stops the audit logger. +// This will block until a final flush operation completes or context is done. +// If the logger is already stopped, this will return immediately. All unsent +// logs will be persisted to the store. +func (al *Logger) FlushAndStop(ctx context.Context) { + al.stop() + al.flush(ctx) +} + +// SetProfileID sets the profileID for the logger. This must be called before any logs can be enqueued. +// The profileID of a logger cannot be changed once set. +func (al *Logger) SetProfileID(profileID ipn.ProfileID) error { + al.mu.Lock() + defer al.mu.Unlock() + // It's not an error to call SetProfileID more than once + // with the same [ipn.ProfileID]. + if al.profileID != "" && al.profileID != profileID { + return errors.New("profileID cannot be changed once set") + } + + al.profileID = profileID + return nil +} + +// Start starts the audit logger with the given transport. +// It returns an error if the logger is already started. +func (al *Logger) Start(t Transport) error { + al.mu.Lock() + defer al.mu.Unlock() + + if al.transport != nil { + return errors.New("already started") + } + + al.transport = t + pending, err := al.storedCountLocked() + if err != nil { + al.logf("[unexpected] failed to restore logs: %v", err) + } + go al.flushWorker() + if pending > 0 { + al.flushAsync() + } + return nil +} + +// ErrAuditLogStorageFailure is returned when the logger fails to persist logs to the store. +var ErrAuditLogStorageFailure = errors.New("audit log storage failure") + +// Enqueue queues an audit log to be sent to the control plane (or another suitable consumer/transport). +// This will return an error if the underlying store fails to save the log or we fail to generate a unique +// eventID for the log. +func (al *Logger) Enqueue(action tailcfg.ClientAuditAction, details string) error { + txn := &transaction{ + Action: action, + Details: details, + TimeStamp: time.Now(), + } + // Generate a suitably random eventID for the transaction. + txn.EventID = fmt.Sprint(txn.TimeStamp, rands.HexString(16)) + return al.enqueue(txn) +} + +// flushAsync requests an asynchronous flush. +// It is a no-op if a flush is already pending. +func (al *Logger) flushAsync() { + select { + case al.flusher <- struct{}{}: + default: + } +} + +func (al *Logger) flushWorker() { + defer close(al.done) + + var retryDelay time.Duration + retry := time.NewTimer(0) + retry.Stop() + + for { + select { + case <-al.ctx.Done(): + return + case <-al.flusher: + err := al.flush(al.ctx) + switch { + case errors.Is(err, context.Canceled): + // The logger was stopped, no need to retry. + return + case err != nil: + retryDelay = max(al.backoffOpts.min, min(retryDelay*time.Duration(al.backoffOpts.multiplier), al.backoffOpts.max)) + al.logf("retrying after %v, %v", retryDelay, err) + retry.Reset(retryDelay) + default: + retryDelay = 0 + retry.Stop() + } + case <-retry.C: + al.flushAsync() + } + } +} + +// flush attempts to send all pending logs to the control plane. +// l.mu must not be held. +func (al *Logger) flush(ctx context.Context) error { + al.mu.Lock() + pending, err := al.store.load(al.profileID) + t := al.transport + al.mu.Unlock() + + if err != nil { + // This will catch nil profileIDs + return fmt.Errorf("failed to restore pending logs: %w", err) + } + if len(pending) == 0 { + return nil + } + if t == nil { + return errors.New("no transport") + } + + complete, unsent := al.sendToTransport(ctx, pending, t) + al.markTransactionsDone(complete) + + al.mu.Lock() + defer al.mu.Unlock() + if err = al.appendToStoreLocked(unsent); err != nil { + al.logf("[unexpected] failed to persist logs: %v", err) + } + + if len(unsent) != 0 { + return fmt.Errorf("failed to send %d logs", len(unsent)) + } + + if len(complete) != 0 { + al.logf("complete %d audit log transactions", len(complete)) + } + return nil +} + +// sendToTransport sends all pending logs to the control plane. Returns a pair of slices +// containing the logs that were successfully sent (or failed permanently) and those that were not. +// +// This may require multiple round trips to the control plane and can be a long running transaction. +func (al *Logger) sendToTransport(ctx context.Context, pending []*transaction, t Transport) (complete []*transaction, unsent []*transaction) { + for i, txn := range pending { + req := tailcfg.AuditLogRequest{ + Action: tailcfg.ClientAuditAction(txn.Action), + Details: txn.Details, + Timestamp: txn.TimeStamp, + } + + if err := t.SendAuditLog(ctx, req); err != nil { + switch { + case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded): + // The contex is done. All further attempts will fail. + unsent = append(unsent, pending[i:]...) + return complete, unsent + case IsRetryableError(err) && txn.Retries+1 < al.retryLimit: + // We permit a maximum number of retries for each log. All retriable + // errors should be transient and we should be able to send the log eventually, but + // we don't want logs to be persisted indefinitely. + txn.Retries++ + unsent = append(unsent, txn) + default: + complete = append(complete, txn) + al.logf("failed permanently: %v", err) + } + } else { + // No error - we're done. + complete = append(complete, txn) + } + } + + return complete, unsent +} + +func (al *Logger) stop() { + al.mu.Lock() + t := al.transport + al.mu.Unlock() + + if t == nil { + // No transport means no worker goroutine and done will not be + // closed if we cancel the context. + return + } + + al.ctxCancel() + <-al.done + al.logf("stopped for profileID: %v", al.profileID) +} + +// appendToStoreLocked persists logs to the store. This will deduplicate +// logs so it is safe to call this with the same logs multiple time, to +// requeue failed transactions for example. +// +// l.mu must be held. +func (al *Logger) appendToStoreLocked(txns []*transaction) error { + if len(txns) == 0 { + return nil + } + + if al.profileID == "" { + return errors.New("no logId set") + } + + persisted, err := al.store.load(al.profileID) + if err != nil { + al.logf("[unexpected] append failed to restore logs: %v", err) + } + + // The order is important here. We want the latest transactions first, which will + // ensure when we dedup, the new transactions are seen and the older transactions + // are discarded. + txnsOut := append(txns, persisted...) + txnsOut = deduplicateAndSort(txnsOut) + + return al.store.save(al.profileID, txnsOut) +} + +// storedCountLocked returns the number of logs persisted to the store. +// al.mu must be held. +func (al *Logger) storedCountLocked() (int, error) { + persisted, err := al.store.load(al.profileID) + return len(persisted), err +} + +// markTransactionsDone removes logs from the store that are complete (sent or failed permanently). +// al.mu must not be held. +func (al *Logger) markTransactionsDone(sent []*transaction) { + al.mu.Lock() + defer al.mu.Unlock() + + ids := set.Set[string]{} + for _, txn := range sent { + ids.Add(txn.EventID) + } + + persisted, err := al.store.load(al.profileID) + if err != nil { + al.logf("[unexpected] markTransactionsDone failed to restore logs: %v", err) + } + var unsent []*transaction + for _, txn := range persisted { + if !ids.Contains(txn.EventID) { + unsent = append(unsent, txn) + } + } + al.store.save(al.profileID, unsent) +} + +// deduplicateAndSort removes duplicate logs from the given slice and sorts them by timestamp. +// The first log entry in the slice will be retained, subsequent logs with the same EventID will be discarded. +func deduplicateAndSort(txns []*transaction) []*transaction { + seen := set.Set[string]{} + deduped := make([]*transaction, 0, len(txns)) + for _, txn := range txns { + if !seen.Contains(txn.EventID) { + deduped = append(deduped, txn) + seen.Add(txn.EventID) + } + } + // Sort logs by timestamp - oldest to newest. This will put the oldest logs at + // the front of the queue. + sort.Slice(deduped, func(i, j int) bool { + return deduped[i].TimeStamp.Before(deduped[j].TimeStamp) + }) + return deduped +} + +func (al *Logger) enqueue(txn *transaction) error { + al.mu.Lock() + defer al.mu.Unlock() + + if err := al.appendToStoreLocked([]*transaction{txn}); err != nil { + return fmt.Errorf("%w: %w", ErrAuditLogStorageFailure, err) + } + + // If a.transport is nil if the logger is stopped. + if al.transport != nil { + al.flushAsync() + } + + return nil +} + +var _ LogStore = (*logStateStore)(nil) + +// logStateStore is a concrete implementation of [LogStore] +// using [ipn.StateStore] as the underlying storage. +type logStateStore struct { + store ipn.StateStore +} + +// NewLogStore creates a new LogStateStore with the given [ipn.StateStore]. +func NewLogStore(store ipn.StateStore) LogStore { + return &logStateStore{ + store: store, + } +} + +func (s *logStateStore) generateKey(key ipn.ProfileID) string { + return "auditlog-" + string(key) +} + +// Save saves the given logs to an [ipn.StateStore]. This overwrites +// any existing entries for the given key. +func (s *logStateStore) save(key ipn.ProfileID, txns []*transaction) error { + if key == "" { + return errors.New("empty key") + } + + data, err := json.Marshal(txns) + if err != nil { + return err + } + k := ipn.StateKey(s.generateKey(key)) + return s.store.WriteState(k, data) +} + +// Load retrieves the logs from an [ipn.StateStore]. +func (s *logStateStore) load(key ipn.ProfileID) ([]*transaction, error) { + if key == "" { + return nil, errors.New("empty key") + } + + k := ipn.StateKey(s.generateKey(key)) + data, err := s.store.ReadState(k) + + switch { + case errors.Is(err, ipn.ErrStateNotExist): + return nil, nil + case err != nil: + return nil, err + } + + var txns []*transaction + err = json.Unmarshal(data, &txns) + return txns, err +} diff --git a/ipn/auditlog/auditlog_test.go b/ipn/auditlog/auditlog_test.go new file mode 100644 index 000000000..041cab354 --- /dev/null +++ b/ipn/auditlog/auditlog_test.go @@ -0,0 +1,484 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tstest" +) + +// loggerForTest creates an auditLogger for you and cleans it up +// (and ensures no goroutines are leaked) when the test is done. +func loggerForTest(t *testing.T, opts Opts) *Logger { + t.Helper() + tstest.ResourceCheck(t) + + if opts.Logf == nil { + opts.Logf = t.Logf + } + + if opts.Store == nil { + t.Fatalf("opts.Store must be set") + } + + a := NewLogger(opts) + + t.Cleanup(func() { + a.FlushAndStop(context.Background()) + }) + return a +} + +func TestNonRetryableErrors(t *testing.T) { + errorTests := []struct { + desc string + err error + want bool + }{ + {"DeadlineExceeded", context.DeadlineExceeded, false}, + {"Canceled", context.Canceled, false}, + {"Canceled wrapped", fmt.Errorf("%w: %w", context.Canceled, errors.New("ctx cancelled")), false}, + {"Random error", errors.New("random error"), false}, + } + + for _, tt := range errorTests { + t.Run(tt.desc, func(t *testing.T) { + if IsRetryableError(tt.err) != tt.want { + t.Fatalf("retriable: got %v, want %v", !tt.want, tt.want) + } + }) + } +} + +// TestEnqueueAndFlush enqueues n logs and flushes them. +// We expect all logs to be flushed and for no +// logs to remain in the store once FlushAndStop returns. +func TestEnqueueAndFlush(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(nil) + al := loggerForTest(t, Opts{ + RetryLimit: 200, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + wantSent := 10 + + for i := range wantSent { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent := mockTransport.sentCount(); gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndFlushWithFlushCancel calls FlushAndCancel with a cancelled +// context. We expect nothing to be sent and all logs to be stored. +func TestEnqueueAndFlushWithFlushCancel(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + al := loggerForTest(t, Opts{ + RetryLimit: 200, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for i := range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + // Cancel the context before calling FlushAndStop - nothing should get sent. + // This mimics a timeout before flush() has a chance to execute. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + al.FlushAndStop(ctx) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 10; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestDeduplicateAndSort tests that the most recent log is kept when deduplicating logs +func TestDeduplicateAndSort(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + + logs := []*transaction{ + {EventID: "1", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 1), Retries: 1}, + } + + al.mu.Lock() + defer al.mu.Unlock() + al.appendToStoreLocked(logs) + + // Update the transaction and re-append it + logs[0].Retries = 2 + al.appendToStoreLocked(logs) + + fromStore, err := al.store.load("test") + c.Assert(err, qt.IsNil) + + // We should see only one transaction + if wantStored, gotStored := len(logs), len(fromStore); gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + // We should see the latest transaction + if wantRetryCount, gotRetryCount := 2, fromStore[0].Retries; gotRetryCount != wantRetryCount { + t.Fatalf("reties: got %d, want %d", gotRetryCount, wantRetryCount) + } +} + +func TestChangeProfileId(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + c.Assert(al.SetProfileID("test"), qt.IsNil) + + // Calling SetProfileID with the same profile ID must not fail. + c.Assert(al.SetProfileID("test"), qt.IsNil) + + // Changing a profile ID must fail. + c.Assert(al.SetProfileID("test2"), qt.IsNotNil) +} + +// TestSendOnRestore pushes a n logs to the persistent store, and ensures they +// are sent as soon as Start is called then checks to ensure the sent logs no +// longer exist in the store. +func TestSendOnRestore(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(nil) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + al.SetProfileID("test") + + wantTotal := 10 + + for range 10 { + al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + } + + c.Assert(al.Start(mockTransport), qt.IsNil) + + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), wantTotal; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestFailureExhaustion enqueues n logs, with the transport in a failable state. +// We then set it to a non-failing state, call FlushAndStop and expect all logs to be sent. +func TestFailureExhaustion(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 1, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndFailNoRetry enqueues a set of logs, all of which will fail and are not +// retriable. We then call FlushAndStop and expect all to be unsent. +func TestEnqueueAndFailNoRetry(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&nonRetriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for i := range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndRetry enqueues a set of logs, all of which will fail and are retriable. +// Mid-test, we set the transport to not-fail and expect the queue to flush properly +// We set the backoff parameters to 0 seconds so retries are immediate. +func TestEnqueueAndRetry(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + al.backoffOpts = backoffOpts{ + min: 1 * time.Millisecond, + max: 4 * time.Millisecond, + multiplier: 2.0, + } + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log 1")) + c.Assert(err, qt.IsNil) + + // This will wait for at least 2 retries + gotRetried, wantRetried := mockTransport.waitForSendAttemptsToReach(3), true + if gotRetried != wantRetried { + t.Fatalf("retried: got %v, want %v", gotRetried, wantRetried) + } + + mockTransport.setErrorCondition(nil) + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 1; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueBeforeSetProfileID tests that logs enqueued before SetProfileId are not sent +func TestEnqueueBeforeSetProfileID(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + c.Assert(err, qt.IsNotNil) + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNotNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } +} + +// TestLogStoring tests that audit logs are persisted sorted by timestamp, oldest to newest +func TestLogSorting(t *testing.T) { + c := qt.New(t) + mockStore := NewLogStore(&mem.Store{}) + + logs := []*transaction{ + {EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 1)}, + {EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 2)}, + {EventID: "2", Details: "log 2", TimeStamp: time.Now().Add(-time.Minute * 3)}, + {EventID: "3", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 4)}, + } + + wantLogs := []transaction{ + {Details: "log 1"}, + {Details: "log 2"}, + {Details: "log 3"}, + } + + mockStore.save("test", logs) + + gotLogs, err := mockStore.load("test") + c.Assert(err, qt.IsNil) + gotLogs = deduplicateAndSort(gotLogs) + + for i := range gotLogs { + if want, got := wantLogs[i].Details, gotLogs[i].Details; want != got { + t.Fatalf("Details: got %v, want %v", got, want) + } + } +} + +// mock implementations for testing + +// newMockTransport returns a mock transport for testing +// If err is no nil, SendAuditLog will return this error if the send is attempted +// before the context is cancelled. +func newMockTransport(err error) *mockAuditLogTransport { + return &mockAuditLogTransport{ + err: err, + attempts: make(chan int, 1), + } +} + +type mockAuditLogTransport struct { + attempts chan int // channel to notify of send attempts + + mu sync.Mutex + sendAttmpts int // number of attempts to send logs + sendCount int // number of logs sent by the transport + err error // error to return when sending logs +} + +// waitForSendAttemptsToReach blocks until the number of send attempts reaches n +// This should be use only in tests where the transport is expected to retry sending logs +func (t *mockAuditLogTransport) waitForSendAttemptsToReach(n int) bool { + for attempts := range t.attempts { + if attempts >= n { + return true + } + } + return false +} + +func (t *mockAuditLogTransport) setErrorCondition(err error) { + t.mu.Lock() + defer t.mu.Unlock() + t.err = err +} + +func (t *mockAuditLogTransport) sentCount() int { + t.mu.Lock() + defer t.mu.Unlock() + return t.sendCount +} + +func (t *mockAuditLogTransport) SendAuditLog(ctx context.Context, _ tailcfg.AuditLogRequest) (err error) { + t.mu.Lock() + t.sendAttmpts += 1 + defer func() { + a := t.sendAttmpts + t.mu.Unlock() + select { + case t.attempts <- a: + default: + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if t.err != nil { + return t.err + } + t.sendCount += 1 + return nil +} + +var ( + retriableError = mockError{errors.New("retriable error")} + nonRetriableError = mockError{errors.New("permanent failure error")} +) + +type mockError struct { + error +} + +func (e mockError) Retryable() bool { + return e == retriableError +} diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go new file mode 100644 index 000000000..ae2a296b2 --- /dev/null +++ b/ipn/auditlog/extension.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "context" + "errors" + "fmt" + "time" + + "tailscale.com/control/controlclient" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" +) + +// featureName is the name of the feature implemented by this package. +// It is also the [extension] name and the log prefix. +const featureName = "auditlog" + +func init() { + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newExtension) +} + +// extension is an [ipnext.Extension] managing audit logging +// on platforms that import this package. +// As of 2025-03-27, that's only Windows and macOS. +type extension struct { + logf logger.Logf + + // store is the log store shared by all loggers. + // It is created when the first logger is started. + store lazy.SyncValue[LogStore] + + // mu protects all following fields. + mu syncs.Mutex + // logger is the current audit logger, or nil if it is not set up, + // such as before the first control client is created, or after + // a profile change and before the new control client is created. + // + // It queues, persists, and sends audit logs to the control client. + logger *Logger +} + +// newExtension is an [ipnext.NewExtensionFn] that creates a new audit log extension. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { + return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil +} + +// Name implements [ipnext.Extension]. +func (e *extension) Name() string { + return featureName +} + +// Init implements [ipnext.Extension] by registering callbacks and providers +// for the duration of the extension's lifetime. +func (e *extension) Init(h ipnext.Host) error { + h.Hooks().NewControlClient.Add(e.controlClientChanged) + h.Hooks().ProfileStateChange.Add(e.profileChanged) + h.Hooks().AuditLoggers.Add(e.getCurrentLogger) + return nil +} + +// [controlclient.Auto] implements [Transport]. +var _ Transport = (*controlclient.Auto)(nil) + +// startNewLogger creates and starts a new logger for the specified profile +// using the specified [controlclient.Client] as the transport. +// The profileID may be "" if the profile has not been persisted yet. +func (e *extension) startNewLogger(cc controlclient.Client, profileID ipn.ProfileID) (*Logger, error) { + transport, ok := cc.(Transport) + if !ok { + return nil, fmt.Errorf("%T cannot be used as transport", cc) + } + + // Create a new log store if this is the first logger. + // Otherwise, get the existing log store. + store, err := e.store.GetErr(func() (LogStore, error) { + return newDefaultLogStore(e.logf) + }) + if err != nil { + return nil, fmt.Errorf("failed to create audit log store: %w", err) + } + + logger := NewLogger(Opts{ + Logf: e.logf, + RetryLimit: 32, + Store: store, + }) + if err := logger.SetProfileID(profileID); err != nil { + return nil, fmt.Errorf("set profile failed: %w", err) + } + if err := logger.Start(transport); err != nil { + return nil, fmt.Errorf("start failed: %w", err) + } + return logger, nil +} + +func (e *extension) controlClientChanged(cc controlclient.Client, profile ipn.LoginProfileView) (cleanup func()) { + logger, err := e.startNewLogger(cc, profile.ID()) + e.mu.Lock() + e.logger = logger // nil on error + e.mu.Unlock() + if err != nil { + // If we fail to create or start the logger, log the error + // and return a nil cleanup function. There's nothing more + // we can do here. + // + // But [extension.getCurrentLogger] returns [noCurrentLogger] + // when the logger is nil. Since [noCurrentLogger] always + // fails with [errNoLogger], operations that must be audited + // but cannot will fail on platforms where the audit logger + // is enabled (i.e., the auditlog package is imported). + e.logf("[unexpected] %v", err) + return nil + } + return func() { + // Stop the logger when the control client shuts down. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + logger.FlushAndStop(ctx) + } +} + +func (e *extension) profileChanged(profile ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + switch { + case e.logger == nil: + // No-op if we don't have an audit logger. + case sameNode: + // The profile info has changed, but it represents the same node. + // This includes the case where the login has just been completed + // and the profile's [ipn.ProfileID] has been set for the first time. + if err := e.logger.SetProfileID(profile.ID()); err != nil { + e.logf("[unexpected] failed to set profile ID: %v", err) + } + default: + // The profile info has changed, and it represents a different node. + // We won't have an audit logger for the new profile until the new + // control client is created. + // + // We don't expect any auditable actions to be attempted in this state. + // But if they are, they will fail with [errNoLogger]. + e.logger = nil + } +} + +// errNoLogger is an error returned by [noCurrentLogger]. It indicates that +// the logger was unavailable when [ipnlocal.LocalBackend] requested it, +// such as when an auditable action was attempted before [LocalBackend.Start] +// was called for the first time or immediately after a profile change +// and before the new control client was created. +// +// This error is unexpected and should not occur in normal operation. +var errNoLogger = errors.New("[unexpected] no audit logger") + +// noCurrentLogger is an [ipnauth.AuditLogFunc] returned by [extension.getCurrentLogger] +// when the logger is not available. It fails with [errNoLogger] on every call. +func noCurrentLogger(_ tailcfg.ClientAuditAction, _ string) error { + return errNoLogger +} + +// getCurrentLogger is an [ipnext.AuditLogProvider] registered with [ipnext.Host]. +// It is called when [ipnlocal.LocalBackend] or an extension needs to audit an action. +// +// It returns a function that enqueues the audit log for the current profile, +// or [noCurrentLogger] if the logger is unavailable. +func (e *extension) getCurrentLogger() ipnauth.AuditLogFunc { + e.mu.Lock() + defer e.mu.Unlock() + if e.logger == nil { + return noCurrentLogger + } + return e.logger.Enqueue +} + +// Shutdown implements [ipnlocal.Extension]. +func (e *extension) Shutdown() error { + return nil +} diff --git a/ipn/auditlog/store.go b/ipn/auditlog/store.go new file mode 100644 index 000000000..3b58ffa93 --- /dev/null +++ b/ipn/auditlog/store.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + + "tailscale.com/ipn/store" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" + "tailscale.com/util/must" +) + +var storeFilePath lazy.SyncValue[string] + +// SetStoreFilePath sets the audit log store file path. +// It is optional on platforms with a default store path, +// but required on platforms without one (e.g., macOS). +// It panics if called more than once or after the store has been created. +func SetStoreFilePath(path string) { + if !storeFilePath.Set(path) { + panic("store file path already set or used") + } +} + +// DefaultStoreFilePath returns the default audit log store file path +// for the current platform, or an error if the platform does not have one. +func DefaultStoreFilePath() (string, error) { + switch runtime.GOOS { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "audit-log.json"), nil + default: + // The auditlog package must either be omitted from the build, + // have the platform-specific store path set with [SetStoreFilePath] (e.g., on macOS), + // or have the default store path available on the current platform. + return "", fmt.Errorf("[unexpected] no default store path available on %s", runtime.GOOS) + } +} + +// newDefaultLogStore returns a new [LogStore] for the current platform. +func newDefaultLogStore(logf logger.Logf) (LogStore, error) { + path, err := storeFilePath.GetErr(DefaultStoreFilePath) + if err != nil { + // This indicates that the auditlog package was not omitted from the build + // on a platform without a default store path and that [SetStoreFilePath] + // was not called to set a platform-specific store path. + // + // This is not expected to happen, but if it does, let's log it + // and use an in-memory store as a fallback. + logf("[unexpected] failed to get audit log store path: %v", err) + return NewLogStore(must.Get(store.New(logf, "mem:auditlog"))), nil + } + fs, err := store.New(logf, path) + if err != nil { + return nil, fmt.Errorf("failed to create audit log store at %q: %w", path, err) + } + return NewLogStore(fs), nil +} diff --git a/ipn/backend.go b/ipn/backend.go index d6ba95408..b4ba958c5 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -58,21 +58,33 @@ type EngineStatus struct { // to subscribe to. type NotifyWatchOpt uint64 +// NotifyWatchOpt values. +// +// These aren't declared using Go's iota because they're not purely internal to +// the process and iota should not be used for values that are serialized to +// disk or network. In this case, these values come over the network via the +// LocalAPI, a mostly stable API. const ( // NotifyWatchEngineUpdates, if set, causes Engine updates to be sent to the // client either regularly or when they change, without having to ask for // each one via Engine.RequestStatus. - NotifyWatchEngineUpdates NotifyWatchOpt = 1 << iota + NotifyWatchEngineUpdates NotifyWatchOpt = 1 << 0 + + NotifyInitialState NotifyWatchOpt = 1 << 1 // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL + SessionID + NotifyInitialPrefs NotifyWatchOpt = 1 << 2 // if set, the first Notify message (sent immediately) will contain the current Prefs + NotifyInitialNetMap NotifyWatchOpt = 1 << 3 // if set, the first Notify message (sent immediately) will contain the current NetMap + + NotifyNoPrivateKeys NotifyWatchOpt = 1 << 4 // (no-op) it used to redact private keys; now they always are and this does nothing + NotifyInitialDriveShares NotifyWatchOpt = 1 << 5 // if set, the first Notify message (sent immediately) will contain the current Taildrive Shares + NotifyInitialOutgoingFiles NotifyWatchOpt = 1 << 6 // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles - NotifyInitialState // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL + SessionID - NotifyInitialPrefs // if set, the first Notify message (sent immediately) will contain the current Prefs - NotifyInitialNetMap // if set, the first Notify message (sent immediately) will contain the current NetMap + NotifyInitialHealthState NotifyWatchOpt = 1 << 7 // if set, the first Notify message (sent immediately) will contain the current health.State of the client - NotifyNoPrivateKeys // if set, private keys that would normally be sent in updates are zeroed out - NotifyInitialDriveShares // if set, the first Notify message (sent immediately) will contain the current Taildrive Shares - NotifyInitialOutgoingFiles // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles + NotifyRateLimit NotifyWatchOpt = 1 << 8 // if set, rate limit spammy netmap updates to every few seconds - NotifyInitialHealthState // if set, the first Notify message (sent immediately) will contain the current health.State of the client + NotifyHealthActions NotifyWatchOpt = 1 << 9 // if set, include PrimaryActions in health.State. Otherwise append the action URL to the text + + NotifyInitialSuggestedExitNode NotifyWatchOpt = 1 << 10 // if set, the first Notify message (sent immediately) will contain the current SuggestedExitNode if available ) // Notify is a communication from a backend (e.g. tailscaled) to a frontend @@ -88,7 +100,7 @@ type Notify struct { // This field is only set in the first message when requesting // NotifyInitialState. Clients must store it on their side as // following notifications will not include this field. - SessionID string `json:",omitempty"` + SessionID string `json:",omitzero"` // ErrMessage, if non-nil, contains a critical error message. // For State InUseOtherUser, ErrMessage is not critical and just contains the details. @@ -100,14 +112,13 @@ type Notify struct { NetMap *netmap.NetworkMap // if non-nil, the new or current netmap Engine *EngineStatus // if non-nil, the new or current wireguard stats BrowseToURL *string // if non-nil, UI should open a browser right now - BackendLogID *string // if non-nil, the public logtail ID used by backend // FilesWaiting if non-nil means that files are buffered in // the Tailscale daemon and ready for local transfer to the // user's preferred storage location. // // Deprecated: use LocalClient.AwaitWaitingFiles instead. - FilesWaiting *empty.Message `json:",omitempty"` + FilesWaiting *empty.Message `json:",omitzero"` // IncomingFiles, if non-nil, specifies which files are in the // process of being received. A nil IncomingFiles means this @@ -116,22 +127,22 @@ type Notify struct { // of being transferred. // // Deprecated: use LocalClient.AwaitWaitingFiles instead. - IncomingFiles []PartialFile `json:",omitempty"` + IncomingFiles []PartialFile `json:",omitzero"` // OutgoingFiles, if non-nil, tracks which files are in the process of // being sent via TailDrop, including files that finished, whether // successful or failed. This slice is sorted by Started time, then Name. - OutgoingFiles []*OutgoingFile `json:",omitempty"` + OutgoingFiles []*OutgoingFile `json:",omitzero"` // LocalTCPPort, if non-nil, informs the UI frontend which // (non-zero) localhost TCP port it's listening on. // This is currently only used by Tailscale when run in the // macOS Network Extension. - LocalTCPPort *uint16 `json:",omitempty"` + LocalTCPPort *uint16 `json:",omitzero"` // ClientVersion, if non-nil, describes whether a client version update // is available. - ClientVersion *tailcfg.ClientVersion `json:",omitempty"` + ClientVersion *tailcfg.ClientVersion `json:",omitzero"` // DriveShares tracks the full set of current DriveShares that we're // publishing. Some client applications, like the MacOS and Windows clients, @@ -144,9 +155,13 @@ type Notify struct { // Health is the last-known health state of the backend. When this field is // non-nil, a change in health verified, and the API client should surface // any changes to the user in the UI. - Health *health.State `json:",omitempty"` + Health *health.State `json:",omitzero"` + + // SuggestedExitNode, if non-nil, is the node that the backend has determined to + // be the best exit node for the current network conditions. + SuggestedExitNode *tailcfg.StableNodeID `json:",omitzero"` - // type is mirrored in xcode/Shared/IPN.swift + // type is mirrored in xcode/IPN/Core/LocalAPI/Model/LocalAPIModel.swift } func (n Notify) String() string { @@ -173,9 +188,6 @@ func (n Notify) String() string { if n.BrowseToURL != nil { sb.WriteString("URL=<...> ") } - if n.BackendLogID != nil { - sb.WriteString("BackendLogID ") - } if n.FilesWaiting != nil { sb.WriteString("FilesWaiting ") } @@ -188,8 +200,16 @@ func (n Notify) String() string { if n.Health != nil { sb.WriteString("Health{...} ") } + if n.SuggestedExitNode != nil { + fmt.Fprintf(&sb, "SuggestedExitNode=%v ", *n.SuggestedExitNode) + } + s := sb.String() - return s[0:len(s)-1] + "}" + if s == "Notify{" { + return "Notify{}" + } else { + return s[0:len(s)-1] + "}" + } } // PartialFile represents an in-progress incoming file transfer. @@ -238,6 +258,7 @@ type StateKey string var DebuggableComponents = []string{ "magicsock", "sockstats", + "syspolicy", } type Options struct { diff --git a/ipn/backend_test.go b/ipn/backend_test.go new file mode 100644 index 000000000..d72b96615 --- /dev/null +++ b/ipn/backend_test.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipn + +import ( + "testing" + + "tailscale.com/health" + "tailscale.com/types/empty" +) + +func TestNotifyString(t *testing.T) { + for _, tt := range []struct { + name string + value Notify + expected string + }{ + { + name: "notify-empty", + value: Notify{}, + expected: "Notify{}", + }, + { + name: "notify-with-login-finished", + value: Notify{LoginFinished: &empty.Message{}}, + expected: "Notify{LoginFinished}", + }, + { + name: "notify-with-multiple-fields", + value: Notify{LoginFinished: &empty.Message{}, Health: &health.State{}}, + expected: "Notify{LoginFinished Health{...}}", + }, + } { + t.Run(tt.name, func(t *testing.T) { + actual := tt.value.String() + if actual != tt.expected { + t.Fatalf("expected=%q, actual=%q", tt.expected, actual) + } + }) + } +} diff --git a/ipn/conf.go b/ipn/conf.go index 6a67f4004..2c9fb2fd1 100644 --- a/ipn/conf.go +++ b/ipn/conf.go @@ -32,6 +32,10 @@ type ConfigVAlpha struct { AdvertiseRoutes []netip.Prefix `json:",omitempty"` DisableSNAT opt.Bool `json:",omitempty"` + AdvertiseServices []string `json:",omitempty"` + + AppConnector *AppConnectorPrefs `json:",omitempty"` // advertise app connector; defaults to false (if nil or explicitly set to false) + NetfilterMode *string `json:",omitempty"` // "on", "off", "nodivert" NoStatefulFiltering opt.Bool `json:",omitempty"` @@ -137,5 +141,19 @@ func (c *ConfigVAlpha) ToPrefs() (MaskedPrefs, error) { mp.AutoUpdate = *c.AutoUpdate mp.AutoUpdateSet = AutoUpdatePrefsMask{ApplySet: true, CheckSet: true} } + if c.AppConnector != nil { + mp.AppConnector = *c.AppConnector + mp.AppConnectorSet = true + } + // Configfile should be the source of truth for whether this node + // advertises any services. We need to ensure that each reload updates + // currently advertised services as else the transition from 'some + // services are advertised' to 'advertised services are empty/unset in + // conffile' would have no effect (especially given that an empty + // service slice would be omitted from the JSON config). + mp.AdvertiseServicesSet = true + if c.AdvertiseServices != nil { + mp.AdvertiseServices = c.AdvertiseServices + } return mp, nil } diff --git a/ipn/conffile/cloudconf.go b/ipn/conffile/cloudconf.go index 650611cf1..4475a2d7b 100644 --- a/ipn/conffile/cloudconf.go +++ b/ipn/conffile/cloudconf.go @@ -10,6 +10,8 @@ import ( "net/http" "strings" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/omit" ) @@ -35,6 +37,9 @@ func getEC2MetadataToken() (string, error) { } func readVMUserData() ([]byte, error) { + if !buildfeatures.HasAWS { + return nil, feature.ErrUnavailable + } // TODO(bradfitz): support GCP, Azure, Proxmox/cloud-init // (NoCloud/ConfigDrive ISO), etc. diff --git a/ipn/conffile/conffile.go b/ipn/conffile/conffile.go index 0b4670c42..3a2aeffb3 100644 --- a/ipn/conffile/conffile.go +++ b/ipn/conffile/conffile.go @@ -10,8 +10,9 @@ import ( "encoding/json" "fmt" "os" + "runtime" - "github.com/tailscale/hujson" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn" ) @@ -39,8 +40,17 @@ func (c *Config) WantRunning() bool { // from the VM's metadata service's user-data field. const VMUserDataPath = "vm:user-data" +// hujsonStandardize is set to hujson.Standardize by conffile_hujson.go on +// platforms that support config files. +var hujsonStandardize func([]byte) ([]byte, error) + // Load reads and parses the config file at the provided path on disk. func Load(path string) (*Config, error) { + switch runtime.GOOS { + case "ios", "android": + // compile-time for deadcode elimination + return nil, fmt.Errorf("config file loading not supported on %q", runtime.GOOS) + } var c Config c.Path = path var err error @@ -54,14 +64,21 @@ func Load(path string) (*Config, error) { if err != nil { return nil, err } - c.Std, err = hujson.Standardize(c.Raw) - if err != nil { - return nil, fmt.Errorf("error parsing config file %s HuJSON/JSON: %w", path, err) + if buildfeatures.HasHuJSONConf && hujsonStandardize != nil { + c.Std, err = hujsonStandardize(c.Raw) + if err != nil { + return nil, fmt.Errorf("error parsing config file %s HuJSON/JSON: %w", path, err) + } + } else { + c.Std = c.Raw // config file must be valid JSON with ts_omit_hujsonconf } var ver struct { Version string `json:"version"` } if err := json.Unmarshal(c.Std, &ver); err != nil { + if !buildfeatures.HasHuJSONConf { + return nil, fmt.Errorf("error parsing config file %s, which must be valid standard JSON: %w", path, err) + } return nil, fmt.Errorf("error parsing config file %s: %w", path, err) } switch ver.Version { diff --git a/ipn/conffile/conffile_hujson.go b/ipn/conffile/conffile_hujson.go new file mode 100644 index 000000000..1e967f1bd --- /dev/null +++ b/ipn/conffile/conffile_hujson.go @@ -0,0 +1,20 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !ts_omit_hujsonconf + +package conffile + +import "github.com/tailscale/hujson" + +// Only link the hujson package on platforms that use it, to reduce binary size +// & memory a bit. +// +// (iOS and Android don't have config files) + +// While the linker's deadcode mostly handles the hujson package today, this +// keeps us honest for the future. + +func init() { + hujsonStandardize = hujson.Standardize +} diff --git a/ipn/conffile/serveconf.go b/ipn/conffile/serveconf.go new file mode 100644 index 000000000..bb63c1ac5 --- /dev/null +++ b/ipn/conffile/serveconf.go @@ -0,0 +1,239 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_serve + +package conffile + +import ( + "errors" + "fmt" + "net" + "os" + "path" + "strings" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/tailcfg" + "tailscale.com/types/opt" + "tailscale.com/util/mak" +) + +// ServicesConfigFile is the config file format for services configuration. +type ServicesConfigFile struct { + // Version is always "0.0.1" and always present. + Version string `json:"version"` + + Services map[tailcfg.ServiceName]*ServiceDetailsFile `json:"services,omitzero"` +} + +// ServiceDetailsFile is the config syntax for an individual Tailscale Service. +type ServiceDetailsFile struct { + // Version is always "0.0.1", set if and only if this is not inside a + // [ServiceConfigFile]. + Version string `json:"version,omitzero"` + + // Endpoints are sets of reverse proxy mappings from ProtoPortRanges on a + // Service to Targets (proto+destination+port) on remote destinations (or + // localhost). + // For example, "tcp:443" -> "tcp://localhost:8000" is an endpoint definition + // mapping traffic on the TCP port 443 of the Service to port 8080 on localhost. + // The Proto in the key must be populated. + // As a special case, if the only mapping provided is "*" -> "TUN", that + // enables TUN/L3 mode, where packets are delivered to the Tailscale network + // interface with the understanding that the user will deal with them manually. + Endpoints map[*tailcfg.ProtoPortRange]*Target `json:"endpoints"` + + // Advertised is a flag that tells control whether or not the client thinks + // it is ready to host a particular Tailscale Service. If unset, it is + // assumed to be true. + Advertised opt.Bool `json:"advertised,omitzero"` +} + +// ServiceProtocol is the protocol of a Target. +type ServiceProtocol string + +const ( + ProtoHTTP ServiceProtocol = "http" + ProtoHTTPS ServiceProtocol = "https" + ProtoHTTPSInsecure ServiceProtocol = "https+insecure" + ProtoTCP ServiceProtocol = "tcp" + ProtoTLSTerminatedTCP ServiceProtocol = "tls-terminated-tcp" + ProtoFile ServiceProtocol = "file" + ProtoTUN ServiceProtocol = "TUN" +) + +// Target is a destination for traffic to go to when it arrives at a Tailscale +// Service host. +type Target struct { + // The protocol over which to communicate with the Destination. + // Protocol == ProtoTUN is a special case, activating "TUN mode" where + // packets are delivered to the Tailscale TUN interface and then manually + // handled by the user. + Protocol ServiceProtocol + + // If Protocol is ProtoFile, then Destination is a file path. + // If Protocol is ProtoTUN, then Destination is empty. + // Otherwise, it is a host. + Destination string + + // If Protocol is not ProtoFile or ProtoTUN, then DestinationPorts is the + // set of ports on which to connect to the host referred to by Destination. + DestinationPorts tailcfg.PortRange +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (t *Target) UnmarshalJSON(buf []byte) error { + return jsonv2.Unmarshal(buf, t) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (t *Target) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + var str string + if err := jsonv2.UnmarshalDecode(dec, &str); err != nil { + return err + } + + // The TUN case does not look like a standard :// arrangement, + // so handled separately. + if str == "TUN" { + t.Protocol = ProtoTUN + t.Destination = "" + t.DestinationPorts = tailcfg.PortRangeAny + return nil + } + + proto, rest, found := strings.Cut(str, "://") + if !found { + return errors.New("handler not of form ://") + } + + switch ServiceProtocol(proto) { + case ProtoFile: + target := path.Clean(rest) + t.Protocol = ProtoFile + t.Destination = target + t.DestinationPorts = tailcfg.PortRange{} + case ProtoHTTP, ProtoHTTPS, ProtoHTTPSInsecure, ProtoTCP, ProtoTLSTerminatedTCP: + host, portRange, err := tailcfg.ParseHostPortRange(rest) + if err != nil { + return err + } + t.Protocol = ServiceProtocol(proto) + t.Destination = host + t.DestinationPorts = portRange + default: + return errors.New("unsupported protocol") + } + + return nil +} + +func (t *Target) MarshalText() ([]byte, error) { + var out string + switch t.Protocol { + case ProtoFile: + out = fmt.Sprintf("%s://%s", t.Protocol, t.Destination) + case ProtoTUN: + out = "TUN" + case ProtoHTTP, ProtoHTTPS, ProtoHTTPSInsecure, ProtoTCP, ProtoTLSTerminatedTCP: + out = fmt.Sprintf("%s://%s", t.Protocol, net.JoinHostPort(t.Destination, t.DestinationPorts.String())) + default: + return nil, errors.New("unsupported protocol") + } + return []byte(out), nil +} + +func LoadServicesConfig(filename string, forService string) (*ServicesConfigFile, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + var json []byte + if hujsonStandardize != nil { + json, err = hujsonStandardize(data) + if err != nil { + return nil, err + } + } else { + json = data + } + var ver struct { + Version string `json:"version"` + } + if err = jsonv2.Unmarshal(json, &ver); err != nil { + return nil, fmt.Errorf("could not parse config file version: %w", err) + } + switch ver.Version { + case "": + return nil, errors.New("config file must have \"version\" field") + case "0.0.1": + return loadConfigV0(json, forService) + } + return nil, fmt.Errorf("unsupported config file version %q", ver.Version) +} + +func loadConfigV0(json []byte, forService string) (*ServicesConfigFile, error) { + var scf ServicesConfigFile + if svcName := tailcfg.AsServiceName(forService); svcName != "" { + var sdf ServiceDetailsFile + err := jsonv2.Unmarshal(json, &sdf, jsonv2.RejectUnknownMembers(true)) + if err != nil { + return nil, err + } + mak.Set(&scf.Services, svcName, &sdf) + + } else { + err := jsonv2.Unmarshal(json, &scf, jsonv2.RejectUnknownMembers(true)) + if err != nil { + return nil, err + } + } + for svcName, svc := range scf.Services { + if forService == "" && svc.Version != "" { + return nil, errors.New("services cannot be versioned separately from config file") + } + if err := svcName.Validate(); err != nil { + return nil, err + } + if svc.Endpoints == nil { + return nil, fmt.Errorf("service %q: missing \"endpoints\" field", svcName) + } + var sourcePorts []tailcfg.PortRange + foundTUN := false + foundNonTUN := false + for ppr, target := range svc.Endpoints { + if target.Protocol == "TUN" { + if ppr.Proto != 0 || ppr.Ports != tailcfg.PortRangeAny { + return nil, fmt.Errorf("service %q: destination \"TUN\" can only be used with source \"*\"", svcName) + } + foundTUN = true + } else { + if ppr.Ports.Last-ppr.Ports.First != target.DestinationPorts.Last-target.DestinationPorts.First { + return nil, fmt.Errorf("service %q: source and destination port ranges must be of equal size", svcName.String()) + } + foundNonTUN = true + } + if foundTUN && foundNonTUN { + return nil, fmt.Errorf("service %q: cannot mix TUN mode with non-TUN mode", svcName) + } + if pr := findOverlappingRange(sourcePorts, ppr.Ports); pr != nil { + return nil, fmt.Errorf("service %q: source port ranges %q and %q overlap", svcName, pr.String(), ppr.Ports.String()) + } + sourcePorts = append(sourcePorts, ppr.Ports) + } + } + return &scf, nil +} + +// findOverlappingRange finds and returns a reference to a [tailcfg.PortRange] +// in haystack that overlaps with needle. It returns nil if it doesn't find one. +func findOverlappingRange(haystack []tailcfg.PortRange, needle tailcfg.PortRange) *tailcfg.PortRange { + for _, pr := range haystack { + if pr.Contains(needle.First) || pr.Contains(needle.Last) || needle.Contains(pr.First) || needle.Contains(pr.Last) { + return &pr + } + } + return nil +} diff --git a/ipn/desktop/doc.go b/ipn/desktop/doc.go new file mode 100644 index 000000000..64a332792 --- /dev/null +++ b/ipn/desktop/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package desktop facilitates interaction with the desktop environment +// and user sessions. As of 2025-02-06, it is only implemented for Windows. +package desktop diff --git a/ipn/desktop/extension.go b/ipn/desktop/extension.go new file mode 100644 index 000000000..027772671 --- /dev/null +++ b/ipn/desktop/extension.go @@ -0,0 +1,190 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Both the desktop session manager and multi-user support +// are currently available only on Windows. +// This file does not need to be built for other platforms. + +//go:build windows && !ts_omit_desktop_sessions + +package desktop + +import ( + "cmp" + "fmt" + "sync" + + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" +) + +// featureName is the name of the feature implemented by this package. +// It is also the the [desktopSessionsExt] name and the log prefix. +const featureName = "desktop-sessions" + +func init() { + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newDesktopSessionsExt) +} + +// [desktopSessionsExt] implements [ipnext.Extension]. +var _ ipnext.Extension = (*desktopSessionsExt)(nil) + +// desktopSessionsExt extends [LocalBackend] with desktop session management. +// It keeps Tailscale running in the background if Always-On mode is enabled, +// and switches to an appropriate profile when a user signs in or out, +// locks their screen, or disconnects a remote session. +type desktopSessionsExt struct { + logf logger.Logf + sm SessionManager + + host ipnext.Host // or nil, until Init is called + cleanup []func() // cleanup functions to call on shutdown + + // mu protects all following fields. + mu sync.Mutex + sessByID map[SessionID]*Session +} + +// newDesktopSessionsExt returns a new [desktopSessionsExt], +// or an error if a [SessionManager] cannot be created. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newDesktopSessionsExt(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { + logf = logger.WithPrefix(logf, featureName+": ") + sm, err := NewSessionManager(logf) + if err != nil { + return nil, fmt.Errorf("%w: session manager is not available: %w", ipnext.SkipExtension, err) + } + return &desktopSessionsExt{ + logf: logf, + sm: sm, + sessByID: make(map[SessionID]*Session), + }, nil +} + +// Name implements [ipnext.Extension]. +func (e *desktopSessionsExt) Name() string { + return featureName +} + +// Init implements [ipnext.Extension]. +func (e *desktopSessionsExt) Init(host ipnext.Host) (err error) { + e.host = host + unregisterSessionCb, err := e.sm.RegisterStateCallback(e.updateDesktopSessionState) + if err != nil { + return fmt.Errorf("session callback registration failed: %w", err) + } + host.Hooks().BackgroundProfileResolvers.Add(e.getBackgroundProfile) + e.cleanup = []func(){unregisterSessionCb} + return nil +} + +// updateDesktopSessionState is a [SessionStateCallback] +// invoked by [SessionManager] once for each existing session +// and whenever the session state changes. It updates the session map +// and switches to the best profile if necessary. +func (e *desktopSessionsExt) updateDesktopSessionState(session *Session) { + e.mu.Lock() + if session.Status != ClosedSession { + e.sessByID[session.ID] = session + } else { + delete(e.sessByID, session.ID) + } + e.mu.Unlock() + + var action string + switch session.Status { + case ForegroundSession: + // The user has either signed in or unlocked their session. + // For remote sessions, this may also mean the user has connected. + // The distinction isn't important for our purposes, + // so let's always say "signed in". + action = "signed in to" + case BackgroundSession: + action = "locked" + case ClosedSession: + action = "signed out from" + default: + panic("unreachable") + } + maybeUsername, _ := session.User.Username() + userIdentifier := cmp.Or(maybeUsername, string(session.User.UserID()), "user") + reason := fmt.Sprintf("%s %s session %v", userIdentifier, action, session.ID) + + e.host.Profiles().SwitchToBestProfileAsync(reason) +} + +// getBackgroundProfile is a [ipnext.ProfileResolver] that works as follows: +// +// If Always-On mode is disabled, it returns no profile. +// +// If AlwaysOn mode is enabled, it returns the current profile unless: +// - The current profile's owner has signed out. +// - Another user has a foreground (i.e. active/unlocked) session. +// +// If the current profile owner's session runs in the background and no other user +// has a foreground session, it returns the current profile. This applies +// when a locally signed-in user locks their screen or when a remote user +// disconnects without signing out. +// +// In all other cases, it returns no profile. +func (e *desktopSessionsExt) getBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { + e.mu.Lock() + defer e.mu.Unlock() + + if alwaysOn, _ := policyclient.Get().GetBoolean(pkey.AlwaysOn, false); !alwaysOn { + // If the Always-On mode is disabled, there's no background profile + // as far as the desktop session extension is concerned. + return ipn.LoginProfileView{} + } + + isCurrentProfileOwnerSignedIn := false + var foregroundUIDs []ipn.WindowsUserID + for _, s := range e.sessByID { + switch uid := s.User.UserID(); uid { + case profiles.CurrentProfile().LocalUserID(): + isCurrentProfileOwnerSignedIn = true + if s.Status == ForegroundSession { + // Keep the current profile if the user has a foreground session. + return profiles.CurrentProfile() + } + default: + if s.Status == ForegroundSession { + foregroundUIDs = append(foregroundUIDs, uid) + } + } + } + + // If the current profile is empty and not owned by anyone (e.g., tailscaled just started), + // or if the current profile's owner has no foreground session, switch to the default profile + // of the first user with a foreground session, if any. + for _, uid := range foregroundUIDs { + if profile := profiles.DefaultUserProfile(uid); profile.ID() != "" { + return profile + } + } + + // If no user has a foreground session but the current profile's owner is still signed in, + // keep the current profile even if the session is not in the foreground, + // such as when the screen is locked or a remote session is disconnected. + if len(foregroundUIDs) == 0 && isCurrentProfileOwnerSignedIn { + return profiles.CurrentProfile() + } + + // Otherwise, there's no background profile. + return ipn.LoginProfileView{} +} + +// Shutdown implements [ipnext.Extension]. +func (e *desktopSessionsExt) Shutdown() error { + for _, f := range e.cleanup { + f() + } + e.cleanup = nil + e.host = nil + return e.sm.Close() +} diff --git a/ipn/desktop/mksyscall.go b/ipn/desktop/mksyscall.go new file mode 100644 index 000000000..b7af12366 --- /dev/null +++ b/ipn/desktop/mksyscall.go @@ -0,0 +1,22 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys setLastError(dwErrorCode uint32) = kernel32.SetLastError + +//sys registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) [atom==0] = user32.RegisterClassExW +//sys createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) [hWnd==0] = user32.CreateWindowExW +//sys defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) = user32.DefWindowProcW +//sys sendMessage(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) = user32.SendMessageW +//sys getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) = user32.GetMessageW +//sys translateMessage(lpMsg *_MSG) (res bool) = user32.TranslateMessage +//sys dispatchMessage(lpMsg *_MSG) (res uintptr) = user32.DispatchMessageW +//sys destroyWindow(hwnd windows.HWND) (err error) [int32(failretval)==0] = user32.DestroyWindow +//sys postQuitMessage(exitCode int32) = user32.PostQuitMessage + +//sys registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) [int32(failretval)==0] = wtsapi32.WTSRegisterSessionNotificationEx +//sys unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) [int32(failretval)==0] = wtsapi32.WTSUnRegisterSessionNotificationEx diff --git a/ipn/desktop/session.go b/ipn/desktop/session.go new file mode 100644 index 000000000..c95378914 --- /dev/null +++ b/ipn/desktop/session.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "fmt" + + "tailscale.com/ipn/ipnauth" +) + +// SessionID is a unique identifier of a desktop session. +type SessionID uint + +// SessionStatus is the status of a desktop session. +type SessionStatus int + +const ( + // ClosedSession is a session that does not exist, is not yet initialized by the OS, + // or has been terminated. + ClosedSession SessionStatus = iota + // ForegroundSession is a session that a user can interact with, + // such as when attached to a physical console or an active, + // unlocked RDP connection. + ForegroundSession + // BackgroundSession indicates that the session is locked, disconnected, + // or otherwise running without user presence or interaction. + BackgroundSession +) + +// String implements [fmt.Stringer]. +func (s SessionStatus) String() string { + switch s { + case ClosedSession: + return "Closed" + case ForegroundSession: + return "Foreground" + case BackgroundSession: + return "Background" + default: + panic("unreachable") + } +} + +// Session is a state of a desktop session at a given point in time. +type Session struct { + ID SessionID // Identifier of the session; can be reused after the session is closed. + Status SessionStatus // The status of the session, such as foreground or background. + User ipnauth.Actor // User logged into the session. +} + +// Description returns a human-readable description of the session. +func (s *Session) Description() string { + if maybeUsername, _ := s.User.Username(); maybeUsername != "" { // best effort + return fmt.Sprintf("Session %d - %q (%s)", s.ID, maybeUsername, s.Status) + } + return fmt.Sprintf("Session %d (%s)", s.ID, s.Status) +} diff --git a/ipn/desktop/sessions.go b/ipn/desktop/sessions.go new file mode 100644 index 000000000..8bf7a75e2 --- /dev/null +++ b/ipn/desktop/sessions.go @@ -0,0 +1,60 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "errors" + "runtime" +) + +// ErrNotImplemented is returned by [NewSessionManager] when it is not +// implemented for the current GOOS. +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +// SessionInitCallback is a function that is called once per [Session]. +// It returns an optional cleanup function that is called when the session +// is about to be destroyed, or nil if no cleanup is needed. +// It is not safe to call SessionManager methods from within the callback. +type SessionInitCallback func(session *Session) (cleanup func()) + +// SessionStateCallback is a function that reports the initial or updated +// state of a [Session], such as when it transitions between foreground and background. +// It is guaranteed to be called after all registered [SessionInitCallback] functions +// have completed, and before any cleanup functions are called for the same session. +// It is not safe to call SessionManager methods from within the callback. +type SessionStateCallback func(session *Session) + +// SessionManager is an interface that provides access to desktop sessions on the current platform. +// It is safe for concurrent use. +type SessionManager interface { + // Init explicitly initializes the receiver. + // Unless the receiver is explicitly initialized, it will be lazily initialized + // on the first call to any other method. + // It is safe to call Init multiple times. + Init() error + + // Sessions returns a session snapshot taken at the time of the call. + // Since sessions can be created or destroyed at any time, it may become + // outdated as soon as it is returned. + // + // It is primarily intended for logging and debugging. + // Prefer registering a [SessionInitCallback] or [SessionStateCallback] + // in contexts requiring stronger guarantees. + Sessions() (map[SessionID]*Session, error) + + // RegisterInitCallback registers a [SessionInitCallback] that is called for each existing session + // and for each new session that is created, until the returned unregister function is called. + // If the specified [SessionInitCallback] returns a cleanup function, it is called when the session + // is about to be destroyed. The callback function is guaranteed to be called once and only once + // for each existing and new session. + RegisterInitCallback(cb SessionInitCallback) (unregister func(), err error) + + // RegisterStateCallback registers a [SessionStateCallback] that is called for each existing session + // and every time the state of a session changes, until the returned unregister function is called. + RegisterStateCallback(cb SessionStateCallback) (unregister func(), err error) + + // Close waits for all registered callbacks to complete + // and releases resources associated with the receiver. + Close() error +} diff --git a/ipn/desktop/sessions_notwindows.go b/ipn/desktop/sessions_notwindows.go new file mode 100644 index 000000000..da3230a45 --- /dev/null +++ b/ipn/desktop/sessions_notwindows.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package desktop + +import "tailscale.com/types/logger" + +// NewSessionManager returns a new [SessionManager] for the current platform, +// [ErrNotImplemented] if the platform is not supported, or an error if the +// session manager could not be created. +func NewSessionManager(logger.Logf) (SessionManager, error) { + return nil, ErrNotImplemented +} diff --git a/ipn/desktop/sessions_windows.go b/ipn/desktop/sessions_windows.go new file mode 100644 index 000000000..83b884228 --- /dev/null +++ b/ipn/desktop/sessions_windows.go @@ -0,0 +1,707 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "tailscale.com/ipn/ipnauth" + "tailscale.com/types/logger" + "tailscale.com/util/must" + "tailscale.com/util/set" +) + +// wtsManager is a [SessionManager] implementation for Windows. +type wtsManager struct { + logf logger.Logf + ctx context.Context // cancelled when the manager is closed + ctxCancel context.CancelFunc + + initOnce func() error + watcher *sessionWatcher + + mu sync.Mutex + sessions map[SessionID]*wtsSession + initCbs set.HandleSet[SessionInitCallback] + stateCbs set.HandleSet[SessionStateCallback] +} + +// NewSessionManager returns a new [SessionManager] for the current platform, +func NewSessionManager(logf logger.Logf) (SessionManager, error) { + ctx, ctxCancel := context.WithCancel(context.Background()) + m := &wtsManager{ + logf: logf, + ctx: ctx, + ctxCancel: ctxCancel, + sessions: make(map[SessionID]*wtsSession), + } + m.watcher = newSessionWatcher(m.ctx, m.logf, m.sessionEventHandler) + + m.initOnce = sync.OnceValue(func() error { + if err := waitUntilWTSReady(m.ctx); err != nil { + return fmt.Errorf("WTS is not ready: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + if err := m.watcher.Start(); err != nil { + return fmt.Errorf("failed to start session watcher: %w", err) + } + + var err error + m.sessions, err = enumerateSessions() + return err // may be nil or non-nil + }) + return m, nil +} + +// Init implements [SessionManager]. +func (m *wtsManager) Init() error { + return m.initOnce() +} + +// Sessions implements [SessionManager]. +func (m *wtsManager) Sessions() (map[SessionID]*Session, error) { + if err := m.initOnce(); err != nil { + return nil, err + } + + m.mu.Lock() + defer m.mu.Unlock() + sessions := make(map[SessionID]*Session, len(m.sessions)) + for _, s := range m.sessions { + sessions[s.id] = s.AsSession() + } + return sessions, nil +} + +// RegisterInitCallback implements [SessionManager]. +func (m *wtsManager) RegisterInitCallback(cb SessionInitCallback) (unregister func(), err error) { + if err := m.initOnce(); err != nil { + return nil, err + } + if cb == nil { + return nil, errors.New("nil callback") + } + + m.mu.Lock() + defer m.mu.Unlock() + handle := m.initCbs.Add(cb) + + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, s := range m.sessions { + if cleanup := cb(s.AsSession()); cleanup != nil { + s.cleanup = append(s.cleanup, cleanup) + } + } + + return func() { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.initCbs, handle) + }, nil +} + +// RegisterStateCallback implements [SessionManager]. +func (m *wtsManager) RegisterStateCallback(cb SessionStateCallback) (unregister func(), err error) { + if err := m.initOnce(); err != nil { + return nil, err + } + if cb == nil { + return nil, errors.New("nil callback") + } + + m.mu.Lock() + defer m.mu.Unlock() + handle := m.stateCbs.Add(cb) + + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, s := range m.sessions { + cb(s.AsSession()) + } + + return func() { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.stateCbs, handle) + }, nil +} + +func (m *wtsManager) sessionEventHandler(id SessionID, event uint32) { + m.mu.Lock() + defer m.mu.Unlock() + switch event { + case windows.WTS_SESSION_LOGON: + // The session may have been created after we started watching, + // but before the initial enumeration was performed. + // Do not create a new session if it already exists. + if _, _, err := m.getOrCreateSessionLocked(id); err != nil { + m.logf("[unexpected] getOrCreateSessionLocked(%d): %v", id, err) + } + case windows.WTS_SESSION_LOCK: + if err := m.setSessionStatusLocked(id, BackgroundSession); err != nil { + m.logf("[unexpected] setSessionStatusLocked(%d, BackgroundSession): %v", id, err) + } + case windows.WTS_SESSION_UNLOCK: + if err := m.setSessionStatusLocked(id, ForegroundSession); err != nil { + m.logf("[unexpected] setSessionStatusLocked(%d, ForegroundSession): %v", id, err) + } + case windows.WTS_SESSION_LOGOFF: + if err := m.deleteSessionLocked(id); err != nil { + m.logf("[unexpected] deleteSessionLocked(%d): %v", id, err) + } + } +} + +func (m *wtsManager) getOrCreateSessionLocked(id SessionID) (_ *wtsSession, created bool, err error) { + if s, ok := m.sessions[id]; ok { + return s, false, nil + } + + s, err := newWTSSession(id, ForegroundSession) + if err != nil { + return nil, false, err + } + m.sessions[id] = s + + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, cb := range m.initCbs { + if cleanup := cb(session); cleanup != nil { + s.cleanup = append(s.cleanup, cleanup) + } + } + for _, cb := range m.stateCbs { + cb(session) + } + + return s, true, err +} + +func (m *wtsManager) setSessionStatusLocked(id SessionID, status SessionStatus) error { + s, _, err := m.getOrCreateSessionLocked(id) + if err != nil { + return err + } + if s.status == status { + return nil + } + + s.status = status + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, cb := range m.stateCbs { + cb(session) + } + return nil +} + +func (m *wtsManager) deleteSessionLocked(id SessionID) error { + s, ok := m.sessions[id] + if !ok { + return nil + } + + s.status = ClosedSession + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks (and [wtsSession.close]!) in a separate goroutine? + for _, cb := range m.stateCbs { + cb(session) + } + + delete(m.sessions, id) + return s.close() +} + +func (m *wtsManager) Close() error { + m.ctxCancel() + + if m.watcher != nil { + err := m.watcher.Stop() + if err != nil { + return err + } + m.watcher = nil + } + + m.mu.Lock() + defer m.mu.Unlock() + m.initCbs = nil + m.stateCbs = nil + errs := make([]error, 0, len(m.sessions)) + for _, s := range m.sessions { + errs = append(errs, s.close()) + } + m.sessions = nil + return errors.Join(errs...) +} + +type wtsSession struct { + id SessionID + user *ipnauth.WindowsActor + + status SessionStatus + + cleanup []func() +} + +func newWTSSession(id SessionID, status SessionStatus) (*wtsSession, error) { + var token windows.Token + if err := windows.WTSQueryUserToken(uint32(id), &token); err != nil { + return nil, err + } + user, err := ipnauth.NewWindowsActorWithToken(token) + if err != nil { + return nil, err + } + return &wtsSession{id, user, status, nil}, nil +} + +// enumerateSessions returns a map of all active WTS sessions. +func enumerateSessions() (map[SessionID]*wtsSession, error) { + const reserved, version uint32 = 0, 1 + var numSessions uint32 + var sessionInfos *windows.WTS_SESSION_INFO + if err := windows.WTSEnumerateSessions(_WTS_CURRENT_SERVER_HANDLE, reserved, version, &sessionInfos, &numSessions); err != nil { + return nil, fmt.Errorf("WTSEnumerateSessions failed: %w", err) + } + defer windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionInfos))) + + sessions := make(map[SessionID]*wtsSession, numSessions) + for _, si := range unsafe.Slice(sessionInfos, numSessions) { + status := _WTS_CONNECTSTATE_CLASS(si.State).ToSessionStatus() + if status == ClosedSession { + // The session does not exist as far as we're concerned. + // It may be in the process of being created or destroyed, + // or be a special "listener" session, etc. + continue + } + id := SessionID(si.SessionID) + session, err := newWTSSession(id, status) + if err != nil { + continue + } + sessions[id] = session + } + return sessions, nil +} + +func (s *wtsSession) AsSession() *Session { + return &Session{ + ID: s.id, + Status: s.status, + // wtsSession owns the user; don't let the caller close it + User: ipnauth.WithoutClose(s.user), + } +} + +func (m *wtsSession) close() error { + for _, cleanup := range m.cleanup { + cleanup() + } + m.cleanup = nil + + if m.user != nil { + if err := m.user.Close(); err != nil { + return err + } + m.user = nil + } + return nil +} + +type sessionEventHandler func(id SessionID, event uint32) + +// TODO(nickkhyl): implement a sessionWatcher that does not use the message queue. +// One possible approach is to have the tailscaled service register a HandlerEx function +// and stream SERVICE_CONTROL_SESSIONCHANGE events to the tailscaled subprocess +// (the actual tailscaled backend), exposing these events via [sessionWatcher]/[wtsManager]. +// +// See tailscale/corp#26477 for details and tracking. +type sessionWatcher struct { + logf logger.Logf + ctx context.Context // canceled to stop the watcher + ctxCancel context.CancelFunc // cancels the watcher + hWnd windows.HWND // window handle for receiving session change notifications + handler sessionEventHandler // called on session events + + mu sync.Mutex + doneCh chan error // written to when the watcher exits; nil if not started +} + +func newSessionWatcher(ctx context.Context, logf logger.Logf, handler sessionEventHandler) *sessionWatcher { + ctx, cancel := context.WithCancel(ctx) + return &sessionWatcher{logf: logf, ctx: ctx, ctxCancel: cancel, handler: handler} +} + +func (sw *sessionWatcher) Start() error { + sw.mu.Lock() + defer sw.mu.Unlock() + + select { + case <-sw.ctx.Done(): + return fmt.Errorf("sessionWatcher already stopped: %w", sw.ctx.Err()) + default: + } + + if sw.doneCh != nil { + // Already started. + return nil + } + sw.doneCh = make(chan error, 1) + + startedCh := make(chan error, 1) + go sw.run(startedCh, sw.doneCh) + if err := <-startedCh; err != nil { + return err + } + + // Signal the window to unsubscribe from session notifications + // and shut down gracefully when the sessionWatcher is stopped. + context.AfterFunc(sw.ctx, func() { + sendMessage(sw.hWnd, _WM_CLOSE, 0, 0) + }) + return nil +} + +func (sw *sessionWatcher) run(started, done chan<- error) { + runtime.LockOSThread() + defer func() { + runtime.UnlockOSThread() + close(done) + }() + err := sw.createMessageWindow() + started <- err + if err != nil { + return + } + pumpThreadMessages() +} + +// Stop stops the session watcher and waits for it to exit. +func (sw *sessionWatcher) Stop() error { + sw.ctxCancel() + + sw.mu.Lock() + doneCh := sw.doneCh + sw.doneCh = nil + sw.mu.Unlock() + + if doneCh != nil { + return <-doneCh + } + return nil +} + +const watcherWindowClassName = "Tailscale-SessionManager" + +var watcherWindowClassName16 = sync.OnceValue(func() *uint16 { + return must.Get(syscall.UTF16PtrFromString(watcherWindowClassName)) +}) + +var registerSessionManagerWindowClass = sync.OnceValue(func() error { + var hInst windows.Handle + if err := windows.GetModuleHandleEx(0, nil, &hInst); err != nil { + return fmt.Errorf("GetModuleHandle: %w", err) + } + wc := _WNDCLASSEX{ + CbSize: uint32(unsafe.Sizeof(_WNDCLASSEX{})), + HInstance: hInst, + LpfnWndProc: syscall.NewCallback(sessionWatcherWndProc), + LpszClassName: watcherWindowClassName16(), + } + if _, err := registerClassEx(&wc); err != nil { + return fmt.Errorf("RegisterClassEx(%q): %w", watcherWindowClassName, err) + } + return nil +}) + +func (sw *sessionWatcher) createMessageWindow() error { + if err := registerSessionManagerWindowClass(); err != nil { + return err + } + _, err := createWindowEx( + 0, // dwExStyle + watcherWindowClassName16(), // lpClassName + nil, // lpWindowName + 0, // dwStyle + 0, // x + 0, // y + 0, // nWidth + 0, // nHeight + _HWND_MESSAGE, // hWndParent; message-only window + 0, // hMenu + 0, // hInstance + unsafe.Pointer(sw), // lpParam + ) + if err != nil { + return fmt.Errorf("CreateWindowEx: %w", err) + } + return nil +} + +func (sw *sessionWatcher) wndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr) (result uintptr) { + switch msg { + case _WM_CREATE: + err := registerSessionNotification(_WTS_CURRENT_SERVER_HANDLE, hWnd, _NOTIFY_FOR_ALL_SESSIONS) + if err != nil { + sw.logf("[unexpected] failed to register for session notifications: %v", err) + return ^uintptr(0) + } + sw.logf("registered for session notifications") + case _WM_WTSSESSION_CHANGE: + sw.handler(SessionID(lParam), uint32(wParam)) + return 0 + case _WM_CLOSE: + if err := destroyWindow(hWnd); err != nil { + sw.logf("[unexpected] failed to destroy window: %v", err) + } + return 0 + case _WM_DESTROY: + err := unregisterSessionNotification(_WTS_CURRENT_SERVER_HANDLE, hWnd) + if err != nil { + sw.logf("[unexpected] failed to unregister session notifications callback: %v", err) + } + sw.logf("unregistered from session notifications") + return 0 + case _WM_NCDESTROY: + sw.hWnd = 0 + postQuitMessage(0) // quit the message loop for this thread + } + return defWindowProc(hWnd, msg, wParam, lParam) +} + +func (sw *sessionWatcher) setHandle(hwnd windows.HWND) error { + sw.hWnd = hwnd + setLastError(0) + _, err := setWindowLongPtr(sw.hWnd, _GWLP_USERDATA, uintptr(unsafe.Pointer(sw))) + return err // may be nil or non-nil +} + +func sessionWatcherByHandle(hwnd windows.HWND) *sessionWatcher { + val, _ := getWindowLongPtr(hwnd, _GWLP_USERDATA) + return (*sessionWatcher)(unsafe.Pointer(val)) +} + +func sessionWatcherWndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr) (result uintptr) { + if msg == _WM_NCCREATE { + cs := (*_CREATESTRUCT)(unsafe.Pointer(lParam)) + sw := (*sessionWatcher)(unsafe.Pointer(cs.CreateParams)) + if sw == nil { + return 0 + } + if err := sw.setHandle(hWnd); err != nil { + return 0 + } + return defWindowProc(hWnd, msg, wParam, lParam) + } + if sw := sessionWatcherByHandle(hWnd); sw != nil { + return sw.wndProc(hWnd, msg, wParam, lParam) + } + return defWindowProc(hWnd, msg, wParam, lParam) +} + +func pumpThreadMessages() { + var msg _MSG + for getMessage(&msg, 0, 0, 0) != 0 { + translateMessage(&msg) + dispatchMessage(&msg) + } +} + +// waitUntilWTSReady waits until the Windows Terminal Services (WTS) is ready. +// This is necessary because the WTS API functions may fail if called before +// the WTS is ready. +// +// https://web.archive.org/web/20250207011738/https://learn.microsoft.com/en-us/windows/win32/api/wtsapi32/nf-wtsapi32-wtsregistersessionnotificationex +func waitUntilWTSReady(ctx context.Context) error { + eventName16, err := windows.UTF16PtrFromString(`Global\TermSrvReadyEvent`) + if err != nil { + return err + } + event, err := windows.OpenEvent(windows.SYNCHRONIZE, false, eventName16) + if err != nil { + return err + } + defer windows.CloseHandle(event) + return waitForContextOrHandle(ctx, event) +} + +// waitForContextOrHandle waits for either the context to be done or a handle to be signaled. +func waitForContextOrHandle(ctx context.Context, handle windows.Handle) error { + contextDoneEvent, cleanup, err := channelToEvent(ctx.Done()) + if err != nil { + return err + } + defer cleanup() + + handles := []windows.Handle{contextDoneEvent, handle} + waitCode, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE) + if err != nil { + return err + } + + waitCode -= windows.WAIT_OBJECT_0 + if waitCode == 0 { // contextDoneEvent + return ctx.Err() + } + return nil +} + +// channelToEvent returns an auto-reset event that is set when the channel +// becomes receivable, including when the channel is closed. +func channelToEvent[T any](c <-chan T) (evt windows.Handle, cleanup func(), err error) { + evt, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return 0, nil, err + } + + cancel := make(chan struct{}) + + go func() { + select { + case <-cancel: + return + case <-c: + } + windows.SetEvent(evt) + }() + + cleanup = func() { + close(cancel) + windows.CloseHandle(evt) + } + + return evt, cleanup, nil +} + +type _WNDCLASSEX struct { + CbSize uint32 + Style uint32 + LpfnWndProc uintptr + CbClsExtra int32 + CbWndExtra int32 + HInstance windows.Handle + HIcon windows.Handle + HCursor windows.Handle + HbrBackground windows.Handle + LpszMenuName *uint16 + LpszClassName *uint16 + HIconSm windows.Handle +} + +type _CREATESTRUCT struct { + CreateParams uintptr + Instance windows.Handle + Menu windows.Handle + Parent windows.HWND + Cy int32 + Cx int32 + Y int32 + X int32 + Style int32 + Name *uint16 + ClassName *uint16 + ExStyle uint32 +} + +type _POINT struct { + X, Y int32 +} + +type _MSG struct { + HWnd windows.HWND + Message uint32 + WParam uintptr + LParam uintptr + Time uint32 + Pt _POINT +} + +const ( + _WM_CREATE = 1 + _WM_DESTROY = 2 + _WM_CLOSE = 16 + _WM_NCCREATE = 129 + _WM_QUIT = 18 + _WM_NCDESTROY = 130 + + // _WM_WTSSESSION_CHANGE is a message sent to windows that have registered + // for session change notifications, informing them of changes in session state. + // + // https://web.archive.org/web/20250207012421/https://learn.microsoft.com/en-us/windows/win32/termserv/wm-wtssession-change + _WM_WTSSESSION_CHANGE = 0x02B1 +) + +const _GWLP_USERDATA = -21 + +const _HWND_MESSAGE = ^windows.HWND(2) + +// _NOTIFY_FOR_ALL_SESSIONS indicates that the window should receive +// session change notifications for all sessions on the specified server. +const _NOTIFY_FOR_ALL_SESSIONS = 1 + +// _WTS_CURRENT_SERVER_HANDLE indicates that the window should receive +// session change notifications for the host itself rather than a remote server. +const _WTS_CURRENT_SERVER_HANDLE = windows.Handle(0) + +// _WTS_CONNECTSTATE_CLASS represents the connection state of a session. +// +// https://web.archive.org/web/20250206082427/https://learn.microsoft.com/en-us/windows/win32/api/wtsapi32/ne-wtsapi32-wts_connectstate_class +type _WTS_CONNECTSTATE_CLASS int32 + +// ToSessionStatus converts cs to a [SessionStatus]. +func (cs _WTS_CONNECTSTATE_CLASS) ToSessionStatus() SessionStatus { + switch cs { + case windows.WTSActive: + return ForegroundSession + case windows.WTSDisconnected: + return BackgroundSession + default: + // The session does not exist as far as we're concerned. + return ClosedSession + } +} + +var ( + procGetWindowLongPtrW *windows.LazyProc + procSetWindowLongPtrW *windows.LazyProc +) + +func init() { + // GetWindowLongPtrW and SetWindowLongPtrW are only available on 64-bit platforms. + // https://web.archive.org/web/20250414195520/https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-getwindowlongptrw + if runtime.GOARCH == "386" || runtime.GOARCH == "arm" { + procGetWindowLongPtrW = moduser32.NewProc("GetWindowLongW") + procSetWindowLongPtrW = moduser32.NewProc("SetWindowLongW") + } else { + procGetWindowLongPtrW = moduser32.NewProc("GetWindowLongPtrW") + procSetWindowLongPtrW = moduser32.NewProc("SetWindowLongPtrW") + } +} + +func getWindowLongPtr(hwnd windows.HWND, index int32) (res uintptr, err error) { + r0, _, e1 := syscall.Syscall(procGetWindowLongPtrW.Addr(), 2, uintptr(hwnd), uintptr(index), 0) + res = uintptr(r0) + if res == 0 && e1 != 0 { + err = errnoErr(e1) + } + return +} + +func setWindowLongPtr(hwnd windows.HWND, index int32, newLong uintptr) (res uintptr, err error) { + r0, _, e1 := syscall.Syscall(procSetWindowLongPtrW.Addr(), 3, uintptr(hwnd), uintptr(index), uintptr(newLong)) + res = uintptr(r0) + if res == 0 && e1 != 0 { + err = errnoErr(e1) + } + return +} diff --git a/ipn/desktop/zsyscall_windows.go b/ipn/desktop/zsyscall_windows.go new file mode 100644 index 000000000..8d97c4d80 --- /dev/null +++ b/ipn/desktop/zsyscall_windows.go @@ -0,0 +1,139 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package desktop + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + moduser32 = windows.NewLazySystemDLL("user32.dll") + modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll") + + procSetLastError = modkernel32.NewProc("SetLastError") + procCreateWindowExW = moduser32.NewProc("CreateWindowExW") + procDefWindowProcW = moduser32.NewProc("DefWindowProcW") + procDestroyWindow = moduser32.NewProc("DestroyWindow") + procDispatchMessageW = moduser32.NewProc("DispatchMessageW") + procGetMessageW = moduser32.NewProc("GetMessageW") + procPostQuitMessage = moduser32.NewProc("PostQuitMessage") + procRegisterClassExW = moduser32.NewProc("RegisterClassExW") + procSendMessageW = moduser32.NewProc("SendMessageW") + procTranslateMessage = moduser32.NewProc("TranslateMessage") + procWTSRegisterSessionNotificationEx = modwtsapi32.NewProc("WTSRegisterSessionNotificationEx") + procWTSUnRegisterSessionNotificationEx = modwtsapi32.NewProc("WTSUnRegisterSessionNotificationEx") +) + +func setLastError(dwErrorCode uint32) { + syscall.SyscallN(procSetLastError.Addr(), uintptr(dwErrorCode)) + return +} + +func createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) { + r0, _, e1 := syscall.SyscallN(procCreateWindowExW.Addr(), uintptr(dwExStyle), uintptr(unsafe.Pointer(lpClassName)), uintptr(unsafe.Pointer(lpWindowName)), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) + hWnd = windows.HWND(r0) + if hWnd == 0 { + err = errnoErr(e1) + } + return +} + +func defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { + r0, _, _ := syscall.SyscallN(procDefWindowProcW.Addr(), uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam)) + res = uintptr(r0) + return +} + +func destroyWindow(hwnd windows.HWND) (err error) { + r1, _, e1 := syscall.SyscallN(procDestroyWindow.Addr(), uintptr(hwnd)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func dispatchMessage(lpMsg *_MSG) (res uintptr) { + r0, _, _ := syscall.SyscallN(procDispatchMessageW.Addr(), uintptr(unsafe.Pointer(lpMsg))) + res = uintptr(r0) + return +} + +func getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) { + r0, _, _ := syscall.SyscallN(procGetMessageW.Addr(), uintptr(unsafe.Pointer(lpMsg)), uintptr(hwnd), uintptr(msgMin), uintptr(msgMax)) + ret = int32(r0) + return +} + +func postQuitMessage(exitCode int32) { + syscall.SyscallN(procPostQuitMessage.Addr(), uintptr(exitCode)) + return +} + +func registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) { + r0, _, e1 := syscall.SyscallN(procRegisterClassExW.Addr(), uintptr(unsafe.Pointer(windowClass))) + atom = uint16(r0) + if atom == 0 { + err = errnoErr(e1) + } + return +} + +func sendMessage(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { + r0, _, _ := syscall.SyscallN(procSendMessageW.Addr(), uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam)) + res = uintptr(r0) + return +} + +func translateMessage(lpMsg *_MSG) (res bool) { + r0, _, _ := syscall.SyscallN(procTranslateMessage.Addr(), uintptr(unsafe.Pointer(lpMsg))) + res = r0 != 0 + return +} + +func registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) { + r1, _, e1 := syscall.SyscallN(procWTSRegisterSessionNotificationEx.Addr(), uintptr(hServer), uintptr(hwnd), uintptr(flags)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) { + r1, _, e1 := syscall.SyscallN(procWTSUnRegisterSessionNotificationEx.Addr(), uintptr(hServer), uintptr(hwnd)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} diff --git a/ipn/doc.go b/ipn/doc.go index 4b3810be1..c98c7e8b3 100644 --- a/ipn/doc.go +++ b/ipn/doc.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/viewer -type=Prefs,ServeConfig,TCPPortHandler,HTTPHandler,WebServerConfig +//go:generate go run tailscale.com/cmd/viewer -type=LoginProfile,Prefs,ServeConfig,ServiceConfig,TCPPortHandler,HTTPHandler,WebServerConfig // Package ipn implements the interactions between the Tailscale cloud // control plane and the local network stack. diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index de35b60a7..1be716197 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -17,6 +17,29 @@ import ( "tailscale.com/types/ptr" ) +// Clone makes a deep copy of LoginProfile. +// The result aliases no memory with the original. +func (src *LoginProfile) Clone() *LoginProfile { + if src == nil { + return nil + } + dst := new(LoginProfile) + *dst = *src + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _LoginProfileCloneNeedsRegeneration = LoginProfile(struct { + ID ProfileID + Name string + NetworkProfile NetworkProfile + Key StateKey + UserProfile tailcfg.UserProfile + NodeID tailcfg.StableNodeID + LocalUserID WindowsUserID + ControlURL string +}{}) + // Clone makes a deep copy of Prefs. // The result aliases no memory with the original. func (src *Prefs) Clone() *Prefs { @@ -27,6 +50,7 @@ func (src *Prefs) Clone() *Prefs { *dst = *src dst.AdvertiseTags = append(src.AdvertiseTags[:0:0], src.AdvertiseTags...) dst.AdvertiseRoutes = append(src.AdvertiseRoutes[:0:0], src.AdvertiseRoutes...) + dst.AdvertiseServices = append(src.AdvertiseServices[:0:0], src.AdvertiseServices...) if src.DriveShares != nil { dst.DriveShares = make([]*drive.Share, len(src.DriveShares)) for i := range dst.DriveShares { @@ -37,6 +61,9 @@ func (src *Prefs) Clone() *Prefs { } } } + if dst.RelayServerPort != nil { + dst.RelayServerPort = ptr.To(*src.RelayServerPort) + } dst.Persist = src.Persist.Clone() return dst } @@ -47,6 +74,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { RouteAll bool ExitNodeID tailcfg.StableNodeID ExitNodeIP netip.Addr + AutoExitNode ExitNodeExpression InternalExitNodePrior tailcfg.StableNodeID ExitNodeAllowLANAccess bool CorpDNS bool @@ -61,6 +89,8 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { ForceDaemon bool Egg bool AdvertiseRoutes []netip.Prefix + AdvertiseServices []string + Sync opt.Bool NoSNAT bool NoStatefulFiltering opt.Bool NetfilterMode preftype.NetfilterMode @@ -71,6 +101,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { PostureChecking bool NetfilterKind string DriveShares []*drive.Share + RelayServerPort *int AllowSingleHosts marshalAsTrueInJSON Persist *persist.Persist }{}) @@ -103,6 +134,16 @@ func (src *ServeConfig) Clone() *ServeConfig { } } } + if dst.Services != nil { + dst.Services = map[tailcfg.ServiceName]*ServiceConfig{} + for k, v := range src.Services { + if v == nil { + dst.Services[k] = nil + } else { + dst.Services[k] = v.Clone() + } + } + } dst.AllowFunnel = maps.Clone(src.AllowFunnel) if dst.Foreground != nil { dst.Foreground = map[string]*ServeConfig{} @@ -121,11 +162,50 @@ func (src *ServeConfig) Clone() *ServeConfig { var _ServeConfigCloneNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig + Services map[tailcfg.ServiceName]*ServiceConfig AllowFunnel map[HostPort]bool Foreground map[string]*ServeConfig ETag string }{}) +// Clone makes a deep copy of ServiceConfig. +// The result aliases no memory with the original. +func (src *ServiceConfig) Clone() *ServiceConfig { + if src == nil { + return nil + } + dst := new(ServiceConfig) + *dst = *src + if dst.TCP != nil { + dst.TCP = map[uint16]*TCPPortHandler{} + for k, v := range src.TCP { + if v == nil { + dst.TCP[k] = nil + } else { + dst.TCP[k] = ptr.To(*v) + } + } + } + if dst.Web != nil { + dst.Web = map[HostPort]*WebServerConfig{} + for k, v := range src.Web { + if v == nil { + dst.Web[k] = nil + } else { + dst.Web[k] = v.Clone() + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ServiceConfigCloneNeedsRegeneration = ServiceConfig(struct { + TCP map[uint16]*TCPPortHandler + Web map[HostPort]*WebServerConfig + Tun bool +}{}) + // Clone makes a deep copy of TCPPortHandler. // The result aliases no memory with the original. func (src *TCPPortHandler) Clone() *TCPPortHandler { @@ -139,10 +219,11 @@ func (src *TCPPortHandler) Clone() *TCPPortHandler { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _TCPPortHandlerCloneNeedsRegeneration = TCPPortHandler(struct { - HTTPS bool - HTTP bool - TCPForward string - TerminateTLS string + HTTPS bool + HTTP bool + TCPForward string + TerminateTLS string + ProxyProtocol int }{}) // Clone makes a deep copy of HTTPHandler. @@ -153,14 +234,17 @@ func (src *HTTPHandler) Clone() *HTTPHandler { } dst := new(HTTPHandler) *dst = *src + dst.AcceptAppCaps = append(src.AcceptAppCaps[:0:0], src.AcceptAppCaps...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HTTPHandlerCloneNeedsRegeneration = HTTPHandler(struct { - Path string - Proxy string - Text string + Path string + Proxy string + Text string + AcceptAppCaps []tailcfg.PeerCapability + Redirect string }{}) // Clone makes a deep copy of WebServerConfig. @@ -177,7 +261,7 @@ func (src *WebServerConfig) Clone() *WebServerConfig { if v == nil { dst.Handlers[k] = nil } else { - dst.Handlers[k] = ptr.To(*v) + dst.Handlers[k] = v.Clone() } } } diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index ff48b9c89..d3836416b 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -6,10 +6,12 @@ package ipn import ( - "encoding/json" + jsonv1 "encoding/json" "errors" "net/netip" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "tailscale.com/drive" "tailscale.com/tailcfg" "tailscale.com/types/opt" @@ -18,9 +20,130 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Prefs,ServeConfig,TCPPortHandler,HTTPHandler,WebServerConfig +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=LoginProfile,Prefs,ServeConfig,ServiceConfig,TCPPortHandler,HTTPHandler,WebServerConfig -// View returns a readonly view of Prefs. +// View returns a read-only view of LoginProfile. +func (p *LoginProfile) View() LoginProfileView { + return LoginProfileView{Đļ: p} +} + +// LoginProfileView provides a read-only view over LoginProfile. +// +// Its methods should only be called if `Valid()` returns true. +type LoginProfileView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *LoginProfile +} + +// Valid reports whether v's underlying value is non-nil. +func (v LoginProfileView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v LoginProfileView) AsStruct() *LoginProfile { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v LoginProfileView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v LoginProfileView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *LoginProfileView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x LoginProfile + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *LoginProfileView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x LoginProfile + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// ID is a unique identifier for this profile. +// It is assigned on creation and never changes. +// It may seem redundant to have both ID and UserProfile.ID +// but they are different things. UserProfile.ID may change +// over time (e.g. if a device is tagged). +func (v LoginProfileView) ID() ProfileID { return v.Đļ.ID } + +// Name is the user-visible name of this profile. +// It is filled in from the UserProfile.LoginName field. +func (v LoginProfileView) Name() string { return v.Đļ.Name } + +// NetworkProfile is a subset of netmap.NetworkMap that we +// store to remember information about the tailnet that this +// profile was logged in with. +// +// This field was added on 2023-11-17. +func (v LoginProfileView) NetworkProfile() NetworkProfile { return v.Đļ.NetworkProfile } + +// Key is the StateKey under which the profile is stored. +// It is assigned once at profile creation time and never changes. +func (v LoginProfileView) Key() StateKey { return v.Đļ.Key } + +// UserProfile is the server provided UserProfile for this profile. +// This is updated whenever the server provides a new UserProfile. +func (v LoginProfileView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } + +// NodeID is the NodeID of the node that this profile is logged into. +// This should be stable across tagging and untagging nodes. +// It may seem redundant to check against both the UserProfile.UserID +// and the NodeID. However the NodeID can change if the node is deleted +// from the admin panel. +func (v LoginProfileView) NodeID() tailcfg.StableNodeID { return v.Đļ.NodeID } + +// LocalUserID is the user ID of the user who created this profile. +// It is only relevant on Windows where we have a multi-user system. +// It is assigned once at profile creation time and never changes. +func (v LoginProfileView) LocalUserID() WindowsUserID { return v.Đļ.LocalUserID } + +// ControlURL is the URL of the control server that this profile is logged +// into. +func (v LoginProfileView) ControlURL() string { return v.Đļ.ControlURL } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _LoginProfileViewNeedsRegeneration = LoginProfile(struct { + ID ProfileID + Name string + NetworkProfile NetworkProfile + Key StateKey + UserProfile tailcfg.UserProfile + NodeID tailcfg.StableNodeID + LocalUserID WindowsUserID + ControlURL string +}{}) + +// View returns a read-only view of Prefs. func (p *Prefs) View() PrefsView { return PrefsView{Đļ: p} } @@ -36,7 +159,7 @@ type PrefsView struct { Đļ *Prefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v PrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -48,8 +171,17 @@ func (v PrefsView) AsStruct() *Prefs { return v.Đļ.Clone() } -func (v PrefsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v PrefsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v PrefsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *PrefsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -58,47 +190,281 @@ func (v *PrefsView) UnmarshalJSON(b []byte) error { return nil } var x Prefs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v PrefsView) ControlURL() string { return v.Đļ.ControlURL } -func (v PrefsView) RouteAll() bool { return v.Đļ.RouteAll } -func (v PrefsView) ExitNodeID() tailcfg.StableNodeID { return v.Đļ.ExitNodeID } -func (v PrefsView) ExitNodeIP() netip.Addr { return v.Đļ.ExitNodeIP } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *PrefsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Prefs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// ControlURL is the URL of the control server to use. +// +// If empty, the default for new installs, DefaultControlURL +// is used. It's set non-empty once the daemon has been started +// for the first time. +// +// TODO(apenwarr): Make it safe to update this with EditPrefs(). +// Right now, you have to pass it in the initial prefs in Start(), +// which is the only code that actually uses the ControlURL value. +// It would be more consistent to restart controlclient +// automatically whenever this variable changes. +// +// Meanwhile, you have to provide this as part of +// Options.LegacyMigrationPrefs or Options.UpdatePrefs when +// calling Backend.Start(). +func (v PrefsView) ControlURL() string { return v.Đļ.ControlURL } + +// RouteAll specifies whether to accept subnets advertised by +// other nodes on the Tailscale network. Note that this does not +// include default routes (0.0.0.0/0 and ::/0), those are +// controlled by ExitNodeID/IP below. +func (v PrefsView) RouteAll() bool { return v.Đļ.RouteAll } + +// ExitNodeID and ExitNodeIP specify the node that should be used +// as an exit node for internet traffic. At most one of these +// should be non-zero. +// +// The preferred way to express the chosen node is ExitNodeID, but +// in some cases it's not possible to use that ID (e.g. in the +// linux CLI, before tailscaled has a netmap). For those +// situations, we allow specifying the exit node by IP, and +// ipnlocal.LocalBackend will translate the IP into an ID when the +// node is found in the netmap. +// +// If the selected exit node doesn't exist (e.g. it's not part of +// the current tailnet), or it doesn't offer exit node services, a +// blackhole route will be installed on the local system to +// prevent any traffic escaping to the local network. +func (v PrefsView) ExitNodeID() tailcfg.StableNodeID { return v.Đļ.ExitNodeID } +func (v PrefsView) ExitNodeIP() netip.Addr { return v.Đļ.ExitNodeIP } + +// AutoExitNode is an optional expression that specifies whether and how +// tailscaled should pick an exit node automatically. +// +// If specified, tailscaled will use an exit node based on the expression, +// and will re-evaluate the selection periodically as network conditions, +// available exit nodes, or policy settings change. A blackhole route will +// be installed to prevent traffic from escaping to the local network until +// an exit node is selected. It takes precedence over ExitNodeID and ExitNodeIP. +// +// If empty, tailscaled will not automatically select an exit node. +// +// If the specified expression is invalid or unsupported by the client, +// it falls back to the behavior of [AnyExitNode]. +// +// As of 2025-07-02, the only supported value is [AnyExitNode]. +// It's a string rather than a boolean to allow future extensibility +// (e.g., AutoExitNode = "mullvad" or AutoExitNode = "geo:us"). +func (v PrefsView) AutoExitNode() ExitNodeExpression { return v.Đļ.AutoExitNode } + +// InternalExitNodePrior is the most recently used ExitNodeID in string form. It is set by +// the backend on transition from exit node on to off and used by the +// backend. +// +// As an Internal field, it can't be set by LocalAPI clients, rather it is set indirectly +// when the ExitNodeID value is zero'd and via the set-use-exit-node-enabled endpoint. func (v PrefsView) InternalExitNodePrior() tailcfg.StableNodeID { return v.Đļ.InternalExitNodePrior } -func (v PrefsView) ExitNodeAllowLANAccess() bool { return v.Đļ.ExitNodeAllowLANAccess } -func (v PrefsView) CorpDNS() bool { return v.Đļ.CorpDNS } -func (v PrefsView) RunSSH() bool { return v.Đļ.RunSSH } -func (v PrefsView) RunWebClient() bool { return v.Đļ.RunWebClient } -func (v PrefsView) WantRunning() bool { return v.Đļ.WantRunning } -func (v PrefsView) LoggedOut() bool { return v.Đļ.LoggedOut } -func (v PrefsView) ShieldsUp() bool { return v.Đļ.ShieldsUp } -func (v PrefsView) AdvertiseTags() views.Slice[string] { return views.SliceOf(v.Đļ.AdvertiseTags) } -func (v PrefsView) Hostname() string { return v.Đļ.Hostname } -func (v PrefsView) NotepadURLs() bool { return v.Đļ.NotepadURLs } -func (v PrefsView) ForceDaemon() bool { return v.Đļ.ForceDaemon } -func (v PrefsView) Egg() bool { return v.Đļ.Egg } + +// ExitNodeAllowLANAccess indicates whether locally accessible subnets should be +// routed directly or via the exit node. +func (v PrefsView) ExitNodeAllowLANAccess() bool { return v.Đļ.ExitNodeAllowLANAccess } + +// CorpDNS specifies whether to install the Tailscale network's +// DNS configuration, if it exists. +func (v PrefsView) CorpDNS() bool { return v.Đļ.CorpDNS } + +// RunSSH bool is whether this node should run an SSH +// server, permitting access to peers according to the +// policies as configured by the Tailnet's admin(s). +func (v PrefsView) RunSSH() bool { return v.Đļ.RunSSH } + +// RunWebClient bool is whether this node should expose +// its web client over Tailscale at port 5252, +// permitting access to peers according to the +// policies as configured by the Tailnet's admin(s). +func (v PrefsView) RunWebClient() bool { return v.Đļ.RunWebClient } + +// WantRunning indicates whether networking should be active on +// this node. +func (v PrefsView) WantRunning() bool { return v.Đļ.WantRunning } + +// LoggedOut indicates whether the user intends to be logged out. +// There are other reasons we may be logged out, including no valid +// keys. +// We need to remember this state so that, on next startup, we can +// generate the "Login" vs "Connect" buttons correctly, without having +// to contact the server to confirm our nodekey status first. +func (v PrefsView) LoggedOut() bool { return v.Đļ.LoggedOut } + +// ShieldsUp indicates whether to block all incoming connections, +// regardless of the control-provided packet filter. If false, we +// use the packet filter as provided. If true, we block incoming +// connections. This overrides tailcfg.Hostinfo's ShieldsUp. +func (v PrefsView) ShieldsUp() bool { return v.Đļ.ShieldsUp } + +// AdvertiseTags specifies tags that should be applied to this node, for +// purposes of ACL enforcement. These can be referenced from the ACL policy +// document. Note that advertising a tag on the client doesn't guarantee +// that the control server will allow the node to adopt that tag. +func (v PrefsView) AdvertiseTags() views.Slice[string] { return views.SliceOf(v.Đļ.AdvertiseTags) } + +// Hostname is the hostname to use for identifying the node. If +// not set, os.Hostname is used. +func (v PrefsView) Hostname() string { return v.Đļ.Hostname } + +// NotepadURLs is a debugging setting that opens OAuth URLs in +// notepad.exe on Windows, rather than loading them in a browser. +// +// apenwarr 2020-04-29: Unfortunately this is still needed sometimes. +// Windows' default browser setting is sometimes screwy and this helps +// users narrow it down a bit. +func (v PrefsView) NotepadURLs() bool { return v.Đļ.NotepadURLs } + +// ForceDaemon specifies whether a platform that normally +// operates in "client mode" (that is, requires an active user +// logged in with the GUI app running) should keep running after the +// GUI ends and/or the user logs out. +// +// The only current applicable platform is Windows. This +// forced Windows to go into "server mode" where Tailscale is +// running even with no users logged in. This might also be +// used for macOS in the future. This setting has no effect +// for Linux/etc, which always operate in daemon mode. +func (v PrefsView) ForceDaemon() bool { return v.Đļ.ForceDaemon } + +// Egg is a optional debug flag. +func (v PrefsView) Egg() bool { return v.Đļ.Egg } + +// AdvertiseRoutes specifies CIDR prefixes to advertise into the +// Tailscale network as reachable through the current +// node. func (v PrefsView) AdvertiseRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.AdvertiseRoutes) } -func (v PrefsView) NoSNAT() bool { return v.Đļ.NoSNAT } -func (v PrefsView) NoStatefulFiltering() opt.Bool { return v.Đļ.NoStatefulFiltering } + +// AdvertiseServices specifies the list of services that this +// node can serve as a destination for. Note that an advertised +// service must still go through the approval process from the +// control server. +func (v PrefsView) AdvertiseServices() views.Slice[string] { + return views.SliceOf(v.Đļ.AdvertiseServices) +} + +// Sync is whether this node should sync its configuration from +// the control plane. If unset, this defaults to true. +// This exists primarily for testing, to verify that netmap caching +// and offline operation work correctly. +func (v PrefsView) Sync() opt.Bool { return v.Đļ.Sync } + +// NoSNAT specifies whether to source NAT traffic going to +// destinations in AdvertiseRoutes. The default is to apply source +// NAT, which makes the traffic appear to come from the router +// machine rather than the peer's Tailscale IP. +// +// Disabling SNAT requires additional manual configuration in your +// network to route Tailscale traffic back to the subnet relay +// machine. +// +// Linux-only. +func (v PrefsView) NoSNAT() bool { return v.Đļ.NoSNAT } + +// NoStatefulFiltering specifies whether to apply stateful filtering when +// advertising routes in AdvertiseRoutes. The default is to not apply +// stateful filtering. +// +// To allow inbound connections from advertised routes, both NoSNAT and +// NoStatefulFiltering must be true. +// +// This is an opt.Bool because it was first added after NoSNAT, with a +// backfill based on the value of that parameter. The backfill has been +// removed since then, but the field remains an opt.Bool. +// +// Linux-only. +func (v PrefsView) NoStatefulFiltering() opt.Bool { return v.Đļ.NoStatefulFiltering } + +// NetfilterMode specifies how much to manage netfilter rules for +// Tailscale, if at all. func (v PrefsView) NetfilterMode() preftype.NetfilterMode { return v.Đļ.NetfilterMode } -func (v PrefsView) OperatorUser() string { return v.Đļ.OperatorUser } -func (v PrefsView) ProfileName() string { return v.Đļ.ProfileName } -func (v PrefsView) AutoUpdate() AutoUpdatePrefs { return v.Đļ.AutoUpdate } -func (v PrefsView) AppConnector() AppConnectorPrefs { return v.Đļ.AppConnector } -func (v PrefsView) PostureChecking() bool { return v.Đļ.PostureChecking } -func (v PrefsView) NetfilterKind() string { return v.Đļ.NetfilterKind } + +// OperatorUser is the local machine user name who is allowed to +// operate tailscaled without being root or using sudo. +func (v PrefsView) OperatorUser() string { return v.Đļ.OperatorUser } + +// ProfileName is the desired name of the profile. If empty, then the user's +// LoginName is used. It is only used for display purposes in the client UI +// and CLI. +func (v PrefsView) ProfileName() string { return v.Đļ.ProfileName } + +// AutoUpdate sets the auto-update preferences for the node agent. See +// AutoUpdatePrefs docs for more details. +func (v PrefsView) AutoUpdate() AutoUpdatePrefs { return v.Đļ.AutoUpdate } + +// AppConnector sets the app connector preferences for the node agent. See +// AppConnectorPrefs docs for more details. +func (v PrefsView) AppConnector() AppConnectorPrefs { return v.Đļ.AppConnector } + +// PostureChecking enables the collection of information used for device +// posture checks. +// +// Note: this should be named ReportPosture, but it was shipped as +// PostureChecking in some early releases and this JSON field is written to +// disk, so we just keep its old name. (akin to CorpDNS which is an internal +// pref name that doesn't match the public interface) +func (v PrefsView) PostureChecking() bool { return v.Đļ.PostureChecking } + +// NetfilterKind specifies what netfilter implementation to use. +// +// It can be "iptables", "nftables", or "" to auto-detect. +// +// Linux-only. +func (v PrefsView) NetfilterKind() string { return v.Đļ.NetfilterKind } + +// DriveShares are the configured DriveShares, stored in increasing order +// by name. func (v PrefsView) DriveShares() views.SliceView[*drive.Share, drive.ShareView] { return views.SliceOfViews[*drive.Share, drive.ShareView](v.Đļ.DriveShares) } + +// RelayServerPort is the UDP port number for the relay server to bind to, +// on all interfaces. A non-nil zero value signifies a random unused port +// should be used. A nil value signifies relay server functionality +// should be disabled. This field is currently experimental, and therefore +// no guarantees are made about its current naming and functionality when +// non-nil/enabled. +func (v PrefsView) RelayServerPort() views.ValuePointer[int] { + return views.ValuePointerOf(v.Đļ.RelayServerPort) +} + +// AllowSingleHosts was a legacy field that was always true +// for the past 4.5 years. It controlled whether Tailscale +// peers got /32 or /128 routes for each other. +// As of 2024-05-17 we're starting to ignore it, but to let +// people still downgrade Tailscale versions and not break +// all peer-to-peer networking we still write it to disk (as JSON) +// so it can be loaded back by old versions. +// TODO(bradfitz): delete this in 2025 sometime. See #12058. func (v PrefsView) AllowSingleHosts() marshalAsTrueInJSON { return v.Đļ.AllowSingleHosts } -func (v PrefsView) Persist() persist.PersistView { return v.Đļ.Persist.View() } + +// The Persist field is named 'Config' in the file for backward +// compatibility with earlier versions. +// TODO(apenwarr): We should move this out of here, it's not a pref. +// +// We can maybe do that once we're sure which module should persist +// it (backend or frontend?) +func (v PrefsView) Persist() persist.PersistView { return v.Đļ.Persist.View() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PrefsViewNeedsRegeneration = Prefs(struct { @@ -106,6 +472,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { RouteAll bool ExitNodeID tailcfg.StableNodeID ExitNodeIP netip.Addr + AutoExitNode ExitNodeExpression InternalExitNodePrior tailcfg.StableNodeID ExitNodeAllowLANAccess bool CorpDNS bool @@ -120,6 +487,8 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { ForceDaemon bool Egg bool AdvertiseRoutes []netip.Prefix + AdvertiseServices []string + Sync opt.Bool NoSNAT bool NoStatefulFiltering opt.Bool NetfilterMode preftype.NetfilterMode @@ -130,11 +499,12 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { PostureChecking bool NetfilterKind string DriveShares []*drive.Share + RelayServerPort *int AllowSingleHosts marshalAsTrueInJSON Persist *persist.Persist }{}) -// View returns a readonly view of ServeConfig. +// View returns a read-only view of ServeConfig. func (p *ServeConfig) View() ServeConfigView { return ServeConfigView{Đļ: p} } @@ -150,7 +520,7 @@ type ServeConfigView struct { Đļ *ServeConfig } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ServeConfigView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -162,8 +532,17 @@ func (v ServeConfigView) AsStruct() *ServeConfig { return v.Đļ.Clone() } -func (v ServeConfigView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ServeConfigView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ServeConfigView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *ServeConfigView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -172,46 +551,178 @@ func (v *ServeConfigView) UnmarshalJSON(b []byte) error { return nil } var x ServeConfig - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *ServeConfigView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x ServeConfig + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// TCP are the list of TCP port numbers that tailscaled should handle for +// the Tailscale IP addresses. (not subnet routers, etc) func (v ServeConfigView) TCP() views.MapFn[uint16, *TCPPortHandler, TCPPortHandlerView] { return views.MapFnOf(v.Đļ.TCP, func(t *TCPPortHandler) TCPPortHandlerView { return t.View() }) } +// Web maps from "$SNI_NAME:$PORT" to a set of HTTP handlers +// keyed by mount point ("/", "/foo", etc) func (v ServeConfigView) Web() views.MapFn[HostPort, *WebServerConfig, WebServerConfigView] { return views.MapFnOf(v.Đļ.Web, func(t *WebServerConfig) WebServerConfigView { return t.View() }) } +// Services maps from service name (in the form "svc:dns-label") to a ServiceConfig. +// Which describes the L3, L4, and L7 forwarding information for the service. +func (v ServeConfigView) Services() views.MapFn[tailcfg.ServiceName, *ServiceConfig, ServiceConfigView] { + return views.MapFnOf(v.Đļ.Services, func(t *ServiceConfig) ServiceConfigView { + return t.View() + }) +} + +// AllowFunnel is the set of SNI:port values for which funnel +// traffic is allowed, from trusted ingress peers. func (v ServeConfigView) AllowFunnel() views.Map[HostPort, bool] { return views.MapOf(v.Đļ.AllowFunnel) } +// Foreground is a map of an IPN Bus session ID to an alternate foreground serve config that's valid for the +// life of that WatchIPNBus session ID. This allows the config to specify ephemeral configs that are used +// in the CLI's foreground mode to ensure ungraceful shutdowns of either the client or the LocalBackend does not +// expose ports that users are not aware of. In practice this contains any serve config set via 'tailscale +// serve' command run without the '--bg' flag. ServeConfig contained by Foreground is not expected itself to contain +// another Foreground block. func (v ServeConfigView) Foreground() views.MapFn[string, *ServeConfig, ServeConfigView] { return views.MapFnOf(v.Đļ.Foreground, func(t *ServeConfig) ServeConfigView { return t.View() }) } + +// ETag is the checksum of the serve config that's populated +// by the LocalClient through the HTTP ETag header during a +// GetServeConfig request and is translated to an If-Match header +// during a SetServeConfig request. func (v ServeConfigView) ETag() string { return v.Đļ.ETag } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ServeConfigViewNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig + Services map[tailcfg.ServiceName]*ServiceConfig AllowFunnel map[HostPort]bool Foreground map[string]*ServeConfig ETag string }{}) -// View returns a readonly view of TCPPortHandler. +// View returns a read-only view of ServiceConfig. +func (p *ServiceConfig) View() ServiceConfigView { + return ServiceConfigView{Đļ: p} +} + +// ServiceConfigView provides a read-only view over ServiceConfig. +// +// Its methods should only be called if `Valid()` returns true. +type ServiceConfigView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *ServiceConfig +} + +// Valid reports whether v's underlying value is non-nil. +func (v ServiceConfigView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v ServiceConfigView) AsStruct() *ServiceConfig { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ServiceConfigView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ServiceConfigView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *ServiceConfigView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x ServiceConfig + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *ServiceConfigView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x ServiceConfig + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// TCP are the list of TCP port numbers that tailscaled should handle for +// the Tailscale IP addresses. (not subnet routers, etc) +func (v ServiceConfigView) TCP() views.MapFn[uint16, *TCPPortHandler, TCPPortHandlerView] { + return views.MapFnOf(v.Đļ.TCP, func(t *TCPPortHandler) TCPPortHandlerView { + return t.View() + }) +} + +// Web maps from "$SNI_NAME:$PORT" to a set of HTTP handlers +// keyed by mount point ("/", "/foo", etc) +func (v ServiceConfigView) Web() views.MapFn[HostPort, *WebServerConfig, WebServerConfigView] { + return views.MapFnOf(v.Đļ.Web, func(t *WebServerConfig) WebServerConfigView { + return t.View() + }) +} + +// Tun determines if the service should be using L3 forwarding (Tun mode). +func (v ServiceConfigView) Tun() bool { return v.Đļ.Tun } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ServiceConfigViewNeedsRegeneration = ServiceConfig(struct { + TCP map[uint16]*TCPPortHandler + Web map[HostPort]*WebServerConfig + Tun bool +}{}) + +// View returns a read-only view of TCPPortHandler. func (p *TCPPortHandler) View() TCPPortHandlerView { return TCPPortHandlerView{Đļ: p} } @@ -227,7 +738,7 @@ type TCPPortHandlerView struct { Đļ *TCPPortHandler } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TCPPortHandlerView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -239,8 +750,17 @@ func (v TCPPortHandlerView) AsStruct() *TCPPortHandler { return v.Đļ.Clone() } -func (v TCPPortHandlerView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v TCPPortHandlerView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v TCPPortHandlerView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *TCPPortHandlerView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -249,27 +769,67 @@ func (v *TCPPortHandlerView) UnmarshalJSON(b []byte) error { return nil } var x TCPPortHandler - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *TCPPortHandlerView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x TCPPortHandler + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } -func (v TCPPortHandlerView) HTTPS() bool { return v.Đļ.HTTPS } -func (v TCPPortHandlerView) HTTP() bool { return v.Đļ.HTTP } -func (v TCPPortHandlerView) TCPForward() string { return v.Đļ.TCPForward } +// HTTPS, if true, means that tailscaled should handle this connection as an +// HTTPS request as configured by ServeConfig.Web. +// +// It is mutually exclusive with TCPForward. +func (v TCPPortHandlerView) HTTPS() bool { return v.Đļ.HTTPS } + +// HTTP, if true, means that tailscaled should handle this connection as an +// HTTP request as configured by ServeConfig.Web. +// +// It is mutually exclusive with TCPForward. +func (v TCPPortHandlerView) HTTP() bool { return v.Đļ.HTTP } + +// TCPForward is the IP:port to forward TCP connections to. +// Whether or not TLS is terminated by tailscaled depends on +// TerminateTLS. +// +// It is mutually exclusive with HTTPS. +func (v TCPPortHandlerView) TCPForward() string { return v.Đļ.TCPForward } + +// TerminateTLS, if non-empty, means that tailscaled should terminate the +// TLS connections before forwarding them to TCPForward, permitting only the +// SNI name with this value. It is only used if TCPForward is non-empty. +// (the HTTPS mode uses ServeConfig.Web) func (v TCPPortHandlerView) TerminateTLS() string { return v.Đļ.TerminateTLS } +// ProxyProtocol indicates whether to send a PROXY protocol header +// before forwarding the connection to TCPForward. +// +// This is only valid if TCPForward is non-empty. +func (v TCPPortHandlerView) ProxyProtocol() int { return v.Đļ.ProxyProtocol } + // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _TCPPortHandlerViewNeedsRegeneration = TCPPortHandler(struct { - HTTPS bool - HTTP bool - TCPForward string - TerminateTLS string + HTTPS bool + HTTP bool + TCPForward string + TerminateTLS string + ProxyProtocol int }{}) -// View returns a readonly view of HTTPHandler. +// View returns a read-only view of HTTPHandler. func (p *HTTPHandler) View() HTTPHandlerView { return HTTPHandlerView{Đļ: p} } @@ -285,7 +845,7 @@ type HTTPHandlerView struct { Đļ *HTTPHandler } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v HTTPHandlerView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -297,8 +857,17 @@ func (v HTTPHandlerView) AsStruct() *HTTPHandler { return v.Đļ.Clone() } -func (v HTTPHandlerView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v HTTPHandlerView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v HTTPHandlerView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *HTTPHandlerView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -307,25 +876,59 @@ func (v *HTTPHandlerView) UnmarshalJSON(b []byte) error { return nil } var x HTTPHandler - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v HTTPHandlerView) Path() string { return v.Đļ.Path } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *HTTPHandlerView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x HTTPHandler + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// absolute path to directory or file to serve +func (v HTTPHandlerView) Path() string { return v.Đļ.Path } + +// http://localhost:3000/, localhost:3030, 3030 func (v HTTPHandlerView) Proxy() string { return v.Đļ.Proxy } -func (v HTTPHandlerView) Text() string { return v.Đļ.Text } + +// plaintext to serve (primarily for testing) +func (v HTTPHandlerView) Text() string { return v.Đļ.Text } + +// peer capabilities to forward in grant header, e.g. example.com/cap/mon +func (v HTTPHandlerView) AcceptAppCaps() views.Slice[tailcfg.PeerCapability] { + return views.SliceOf(v.Đļ.AcceptAppCaps) +} + +// Redirect, if not empty, is the target URL to redirect requests to. +// By default, we redirect with HTTP 302 (Found) status. +// If Redirect starts with ':', then we use that status instead. +// +// The target URL supports the following expansion variables: +// - ${HOST}: replaced with the request's Host header value +// - ${REQUEST_URI}: replaced with the request's full URI (path and query string) +func (v HTTPHandlerView) Redirect() string { return v.Đļ.Redirect } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _HTTPHandlerViewNeedsRegeneration = HTTPHandler(struct { - Path string - Proxy string - Text string + Path string + Proxy string + Text string + AcceptAppCaps []tailcfg.PeerCapability + Redirect string }{}) -// View returns a readonly view of WebServerConfig. +// View returns a read-only view of WebServerConfig. func (p *WebServerConfig) View() WebServerConfigView { return WebServerConfigView{Đļ: p} } @@ -341,7 +944,7 @@ type WebServerConfigView struct { Đļ *WebServerConfig } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v WebServerConfigView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -353,8 +956,17 @@ func (v WebServerConfigView) AsStruct() *WebServerConfig { return v.Đļ.Clone() } -func (v WebServerConfigView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v WebServerConfigView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v WebServerConfigView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *WebServerConfigView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -363,13 +975,27 @@ func (v *WebServerConfigView) UnmarshalJSON(b []byte) error { return nil } var x WebServerConfig - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *WebServerConfigView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x WebServerConfig + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } +// mountPoint => handler func (v WebServerConfigView) Handlers() views.MapFn[string, *HTTPHandler, HTTPHandlerView] { return views.MapFnOf(v.Đļ.Handlers, func(t *HTTPHandler) HTTPHandlerView { return t.View() diff --git a/ipn/ipnauth/access.go b/ipn/ipnauth/access.go new file mode 100644 index 000000000..74c663922 --- /dev/null +++ b/ipn/ipnauth/access.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +// ProfileAccess is a bitmask representing the requested, required, or granted +// access rights to an [ipn.LoginProfile]. +// +// It is not to be written to disk or transmitted over the network in its integer form, +// but rather serialized to a string or other format if ever needed. +type ProfileAccess uint + +// Define access rights that might be granted or denied on a per-profile basis. +const ( + // Disconnect is required to disconnect (or switch from) a Tailscale profile. + Disconnect = ProfileAccess(1 << iota) +) diff --git a/ipn/ipnauth/actor.go b/ipn/ipnauth/actor.go index db3192c91..108bdd341 100644 --- a/ipn/ipnauth/actor.go +++ b/ipn/ipnauth/actor.go @@ -4,9 +4,18 @@ package ipnauth import ( + "context" + "encoding/json" + "fmt" + + "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" + "tailscale.com/tailcfg" ) +// AuditLogFunc is any function that can be used to log audit actions performed by an [Actor]. +type AuditLogFunc func(action tailcfg.ClientAuditAction, details string) error + // Actor is any actor using the [ipnlocal.LocalBackend]. // // It typically represents a specific OS user, indicating that an operation @@ -20,6 +29,22 @@ type Actor interface { // Username returns the user name associated with the receiver, // or "" if the actor does not represent a specific user. Username() (string, error) + // ClientID returns a non-zero ClientID and true if the actor represents + // a connected LocalAPI client. Otherwise, it returns a zero value and false. + ClientID() (_ ClientID, ok bool) + + // Context returns the context associated with the actor. + // It carries additional information about the actor + // and is canceled when the actor is done. + Context() context.Context + + // CheckProfileAccess checks whether the actor has the necessary access rights + // to perform a given action on the specified Tailscale profile. + // It returns an error if access is denied. + // + // If the auditLogger is non-nil, it is used to write details about the action + // to the audit log when required by the policy. + CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogFn AuditLogFunc) error // IsLocalSystem reports whether the actor is the Windows' Local System account. // @@ -45,3 +70,65 @@ type ActorCloser interface { // Close releases resources associated with the receiver. Close() error } + +// ClientID is an opaque, comparable value used to identify a connected LocalAPI +// client, such as a connected Tailscale GUI or CLI. It does not necessarily +// correspond to the same [net.Conn] or any physical session. +// +// Its zero value is valid, but does not represent a specific connected client. +type ClientID struct { + v any +} + +// NoClientID is the zero value of [ClientID]. +var NoClientID ClientID + +// ClientIDFrom returns a new [ClientID] derived from the specified value. +// ClientIDs derived from equal values are equal. +func ClientIDFrom[T comparable](v T) ClientID { + return ClientID{v} +} + +// String implements [fmt.Stringer]. +func (id ClientID) String() string { + if id.v == nil { + return "(none)" + } + return fmt.Sprint(id.v) +} + +// MarshalJSON implements [json.Marshaler]. +// It is primarily used for testing. +func (id ClientID) MarshalJSON() ([]byte, error) { + return json.Marshal(id.v) +} + +// UnmarshalJSON implements [json.Unmarshaler]. +// It is primarily used for testing. +func (id *ClientID) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &id.v) +} + +type actorWithRequestReason struct { + Actor + ctx context.Context +} + +// WithRequestReason returns an [Actor] that wraps the given actor and +// carries the specified request reason in its context. +func WithRequestReason(actor Actor, requestReason string) Actor { + ctx := apitype.RequestReasonKey.WithValue(actor.Context(), requestReason) + return &actorWithRequestReason{Actor: actor, ctx: ctx} +} + +// Context implements [Actor]. +func (a *actorWithRequestReason) Context() context.Context { return a.ctx } + +type withoutCloseActor struct{ Actor } + +// WithoutClose returns an [Actor] that does not expose the [ActorCloser] interface. +// In other words, _, ok := WithoutClose(actor).(ActorCloser) will always be false, +// even if the original actor implements [ActorCloser]. +func WithoutClose(actor Actor) Actor { + return withoutCloseActor{actor} +} diff --git a/ipn/ipnauth/actor_windows.go b/ipn/ipnauth/actor_windows.go new file mode 100644 index 000000000..90d3bdd36 --- /dev/null +++ b/ipn/ipnauth/actor_windows.go @@ -0,0 +1,102 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "context" + "errors" + + "golang.org/x/sys/windows" + "tailscale.com/ipn" + "tailscale.com/types/lazy" +) + +// WindowsActor implements [Actor]. +var _ Actor = (*WindowsActor)(nil) + +// WindowsActor represents a logged in Windows user. +type WindowsActor struct { + ctx context.Context + cancelCtx context.CancelFunc + token WindowsToken + uid ipn.WindowsUserID + username lazy.SyncValue[string] +} + +// NewWindowsActorWithToken returns a new [WindowsActor] for the user +// represented by the given [windows.Token]. +// It takes ownership of the token. +func NewWindowsActorWithToken(t windows.Token) (_ *WindowsActor, err error) { + tok := newToken(t) + uid, err := tok.UID() + if err != nil { + t.Close() + return nil, err + } + ctx, cancelCtx := context.WithCancel(context.Background()) + return &WindowsActor{ctx: ctx, cancelCtx: cancelCtx, token: tok, uid: uid}, nil +} + +// UserID implements [Actor]. +func (a *WindowsActor) UserID() ipn.WindowsUserID { + return a.uid +} + +// Username implements [Actor]. +func (a *WindowsActor) Username() (string, error) { + return a.username.GetErr(a.token.Username) +} + +// ClientID implements [Actor]. +func (a *WindowsActor) ClientID() (_ ClientID, ok bool) { + // TODO(nickkhyl): assign and return a client ID when the actor + // represents a connected LocalAPI client. + return NoClientID, false +} + +// Context implements [Actor]. +func (a *WindowsActor) Context() context.Context { + return a.ctx +} + +// CheckProfileAccess implements [Actor]. +func (a *WindowsActor) CheckProfileAccess(profile ipn.LoginProfileView, _ ProfileAccess, _ AuditLogFunc) error { + if profile.LocalUserID() != a.UserID() { + // TODO(nickkhyl): return errors of more specific types and have them + // translated to the appropriate HTTP status codes in the API handler. + return errors.New("the target profile does not belong to the user") + } + return nil +} + +// IsLocalSystem implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-02-06) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (a *WindowsActor) IsLocalSystem() bool { + // https://web.archive.org/web/2024/https://learn.microsoft.com/en-us/windows-server/identity/ad-ds/manage/understand-security-identifiers + const systemUID = ipn.WindowsUserID("S-1-5-18") + return a.uid == systemUID +} + +// IsLocalAdmin implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-02-06) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (a *WindowsActor) IsLocalAdmin(operatorUID string) bool { + return a.token.IsElevated() +} + +// Close releases resources associated with the actor +// and cancels its context. +func (a *WindowsActor) Close() error { + if a.token != nil { + if err := a.token.Close(); err != nil { + return err + } + a.token = nil + } + a.cancelCtx() + return nil +} diff --git a/ipn/ipnauth/ipnauth.go b/ipn/ipnauth/ipnauth.go index e6560570c..497f30f8c 100644 --- a/ipn/ipnauth/ipnauth.go +++ b/ipn/ipnauth/ipnauth.go @@ -14,8 +14,8 @@ import ( "runtime" "strconv" - "github.com/tailscale/peercred" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn" "tailscale.com/safesocket" "tailscale.com/types/logger" @@ -63,8 +63,8 @@ type ConnIdentity struct { notWindows bool // runtime.GOOS != "windows" // Fields used when NotWindows: - isUnixSock bool // Conn is a *net.UnixConn - creds *peercred.Creds // or nil + isUnixSock bool // Conn is a *net.UnixConn + creds PeerCreds // or nil if peercred.Get was not implemented on this OS // Used on Windows: // TODO(bradfitz): merge these into the peercreds package and @@ -78,6 +78,13 @@ type ConnIdentity struct { // It's suitable for passing to LookupUserFromID (os/user.LookupId) on any // operating system. func (ci *ConnIdentity) WindowsUserID() ipn.WindowsUserID { + if !buildfeatures.HasDebug && runtime.GOOS != "windows" { + // This function is only implemented on non-Windows for simulating + // Windows in tests. But that test (per comments below) is broken + // anyway. So disable this testing path in non-debug builds + // and just do the thing that optimizes away. + return "" + } if envknob.GOOS() != "windows" { return "" } @@ -97,9 +104,18 @@ func (ci *ConnIdentity) WindowsUserID() ipn.WindowsUserID { return "" } -func (ci *ConnIdentity) Pid() int { return ci.pid } -func (ci *ConnIdentity) IsUnixSock() bool { return ci.isUnixSock } -func (ci *ConnIdentity) Creds() *peercred.Creds { return ci.creds } +func (ci *ConnIdentity) Pid() int { return ci.pid } +func (ci *ConnIdentity) IsUnixSock() bool { return ci.isUnixSock } +func (ci *ConnIdentity) Creds() PeerCreds { return ci.creds } + +// PeerCreds is the interface for a github.com/tailscale/peercred.Creds, +// if linked into the binary. +// +// (It's not used on some platforms, or if ts_omit_unixsocketidentity is set.) +type PeerCreds interface { + UserID() (uid string, ok bool) + PID() (pid int, ok bool) +} var metricIssue869Workaround = clientmetric.NewCounter("issue_869_workaround") diff --git a/ipn/ipnauth/ipnauth_notwindows.go b/ipn/ipnauth/ipnauth_omit_unixsocketidentity.go similarity index 77% rename from ipn/ipnauth/ipnauth_notwindows.go rename to ipn/ipnauth/ipnauth_omit_unixsocketidentity.go index 3dad8233a..defe7d89c 100644 --- a/ipn/ipnauth/ipnauth_notwindows.go +++ b/ipn/ipnauth/ipnauth_omit_unixsocketidentity.go @@ -1,14 +1,13 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows +//go:build !windows && ts_omit_unixsocketidentity package ipnauth import ( "net" - "github.com/tailscale/peercred" "tailscale.com/types/logger" ) @@ -16,10 +15,7 @@ import ( // based on the user who owns the other end of the connection. // and couldn't. The returned connIdentity has NotWindows set to true. func GetConnIdentity(_ logger.Logf, c net.Conn) (ci *ConnIdentity, err error) { - ci = &ConnIdentity{conn: c, notWindows: true} - _, ci.isUnixSock = c.(*net.UnixConn) - ci.creds, _ = peercred.Get(c) - return ci, nil + return &ConnIdentity{conn: c, notWindows: true}, nil } // WindowsToken is unsupported when GOOS != windows and always returns diff --git a/ipn/ipnauth/ipnauth_unix_creds.go b/ipn/ipnauth/ipnauth_unix_creds.go new file mode 100644 index 000000000..89a9ceaa9 --- /dev/null +++ b/ipn/ipnauth/ipnauth_unix_creds.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows && !ts_omit_unixsocketidentity + +package ipnauth + +import ( + "net" + + "github.com/tailscale/peercred" + "tailscale.com/types/logger" +) + +// GetConnIdentity extracts the identity information from the connection +// based on the user who owns the other end of the connection. +// and couldn't. The returned connIdentity has NotWindows set to true. +func GetConnIdentity(_ logger.Logf, c net.Conn) (ci *ConnIdentity, err error) { + ci = &ConnIdentity{conn: c, notWindows: true} + _, ci.isUnixSock = c.(*net.UnixConn) + if creds, err := peercred.Get(c); err == nil { + ci.creds = creds + ci.pid, _ = ci.creds.PID() + } else if err == peercred.ErrNotImplemented { + // peercred.Get is not implemented on this OS (such as OpenBSD) + // Just leave creds as nil, as documented. + } else { + return nil, err + } + return ci, nil +} + +// WindowsToken is unsupported when GOOS != windows and always returns +// ErrNotImplemented. +func (ci *ConnIdentity) WindowsToken() (WindowsToken, error) { + return nil, ErrNotImplemented +} diff --git a/ipn/ipnauth/ipnauth_windows.go b/ipn/ipnauth/ipnauth_windows.go index 9abd04cd1..1138bc23d 100644 --- a/ipn/ipnauth/ipnauth_windows.go +++ b/ipn/ipnauth/ipnauth_windows.go @@ -36,6 +36,12 @@ type token struct { t windows.Token } +func newToken(t windows.Token) *token { + tok := &token{t: t} + runtime.SetFinalizer(tok, func(t *token) { t.Close() }) + return tok +} + func (t *token) UID() (ipn.WindowsUserID, error) { sid, err := t.uid() if err != nil { @@ -184,7 +190,5 @@ func (ci *ConnIdentity) WindowsToken() (WindowsToken, error) { return nil, err } - result := &token{t: windows.Token(h)} - runtime.SetFinalizer(result, func(t *token) { t.Close() }) - return result, nil + return newToken(windows.Token(h)), nil } diff --git a/ipn/ipnauth/policy.go b/ipn/ipnauth/policy.go new file mode 100644 index 000000000..eeee32435 --- /dev/null +++ b/ipn/ipnauth/policy.go @@ -0,0 +1,79 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "errors" + "fmt" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" +) + +type actorWithPolicyChecks struct{ Actor } + +// WithPolicyChecks returns an [Actor] that wraps the given actor and +// performs additional policy checks on top of the access checks +// implemented by the wrapped actor. +func WithPolicyChecks(actor Actor) Actor { + // TODO(nickkhyl): We should probably exclude the Windows Local System + // account from policy checks as well. + switch actor.(type) { + case unrestricted: + return actor + default: + return &actorWithPolicyChecks{Actor: actor} + } +} + +// CheckProfileAccess implements [Actor]. +func (a actorWithPolicyChecks) CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogger AuditLogFunc) error { + if err := a.Actor.CheckProfileAccess(profile, requestedAccess, auditLogger); err != nil { + return err + } + requestReason := apitype.RequestReasonKey.Value(a.Context()) + return CheckDisconnectPolicy(a.Actor, profile, requestReason, auditLogger) +} + +// CheckDisconnectPolicy checks if the policy allows the specified actor to disconnect +// Tailscale with the given optional reason. It returns nil if the operation is allowed, +// or an error if it is not. If auditLogger is non-nil, it is called to log the action +// when required by the policy. +// +// Note: this function only checks the policy and does not check whether the actor has +// the necessary access rights to the device or profile. It is intended to be used by +// [Actor] implementations on platforms where [syspolicy] is supported. +// +// TODO(nickkhyl): unexport it when we move [ipn.Actor] implementations from [ipnserver] +// and corp to this package. +func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason string, auditFn AuditLogFunc) error { + if !buildfeatures.HasSystemPolicy { + return nil + } + if alwaysOn, _ := policyclient.Get().GetBoolean(pkey.AlwaysOn, false); !alwaysOn { + return nil + } + if allowWithReason, _ := policyclient.Get().GetBoolean(pkey.AlwaysOnOverrideWithReason, false); !allowWithReason { + return errors.New("disconnect not allowed: always-on mode is enabled") + } + if reason == "" { + return errors.New("disconnect not allowed: reason required") + } + if auditFn != nil { + var details string + if username, _ := actor.Username(); username != "" { // best-effort; we don't have it on all platforms + details = fmt.Sprintf("%q is being disconnected by %q: %v", profile.Name(), username, reason) + } else { + details = fmt.Sprintf("%q is being disconnected: %v", profile.Name(), reason) + } + if err := auditFn(tailcfg.AuditNodeDisconnect, details); err != nil { + return err + } + } + return nil +} diff --git a/ipn/ipnauth/self.go b/ipn/ipnauth/self.go new file mode 100644 index 000000000..adee06964 --- /dev/null +++ b/ipn/ipnauth/self.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "context" + + "tailscale.com/ipn" +) + +// Self is a caller identity that represents the tailscaled itself and therefore +// has unlimited access. +var Self Actor = unrestricted{} + +// TODO is a caller identity used when the operation is performed on behalf of a user, +// rather than by tailscaled itself, but the surrounding function is not yet extended +// to accept an [Actor] parameter. It grants the same unrestricted access as [Self]. +var TODO Actor = unrestricted{} + +// unrestricted is an [Actor] that has unlimited access to the currently running +// tailscaled instance. It's typically used for operations performed by tailscaled +// on its own, or upon a request from the control plane, rather on behalf of a user. +type unrestricted struct{} + +// UserID implements [Actor]. +func (unrestricted) UserID() ipn.WindowsUserID { return "" } + +// Username implements [Actor]. +func (unrestricted) Username() (string, error) { return "", nil } + +// Context implements [Actor]. +func (unrestricted) Context() context.Context { return context.Background() } + +// ClientID implements [Actor]. +// It always returns (NoClientID, false) because the tailscaled itself +// is not a connected LocalAPI client. +func (unrestricted) ClientID() (_ ClientID, ok bool) { return NoClientID, false } + +// CheckProfileAccess implements [Actor]. +func (unrestricted) CheckProfileAccess(_ ipn.LoginProfileView, _ ProfileAccess, _ AuditLogFunc) error { + // Unrestricted access to all profiles. + return nil +} + +// IsLocalSystem implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-01-28) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (unrestricted) IsLocalSystem() bool { return false } + +// IsLocalAdmin implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-01-28) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (unrestricted) IsLocalAdmin(operatorUID string) bool { return false } + +// IsTailscaled reports whether the given Actor represents Tailscaled itself, +// such as [Self] or a [TODO] placeholder actor. +func IsTailscaled(a Actor) bool { + _, ok := a.(unrestricted) + return ok +} diff --git a/ipn/ipnauth/test_actor.go b/ipn/ipnauth/test_actor.go new file mode 100644 index 000000000..80c5fcc8a --- /dev/null +++ b/ipn/ipnauth/test_actor.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "cmp" + "context" + "errors" + + "tailscale.com/ipn" +) + +var _ Actor = (*TestActor)(nil) + +// TestActor is an [Actor] used exclusively for testing purposes. +type TestActor struct { + UID ipn.WindowsUserID // OS-specific UID of the user, if the actor represents a local Windows user + Name string // username associated with the actor, or "" + NameErr error // error to be returned by [TestActor.Username] + CID ClientID // non-zero if the actor represents a connected LocalAPI client + Ctx context.Context // context associated with the actor + LocalSystem bool // whether the actor represents the special Local System account on Windows + LocalAdmin bool // whether the actor has local admin access +} + +// UserID implements [Actor]. +func (a *TestActor) UserID() ipn.WindowsUserID { return a.UID } + +// Username implements [Actor]. +func (a *TestActor) Username() (string, error) { return a.Name, a.NameErr } + +// ClientID implements [Actor]. +func (a *TestActor) ClientID() (_ ClientID, ok bool) { return a.CID, a.CID != NoClientID } + +// Context implements [Actor]. +func (a *TestActor) Context() context.Context { return cmp.Or(a.Ctx, context.Background()) } + +// CheckProfileAccess implements [Actor]. +func (a *TestActor) CheckProfileAccess(profile ipn.LoginProfileView, _ ProfileAccess, _ AuditLogFunc) error { + return errors.New("profile access denied") +} + +// IsLocalSystem implements [Actor]. +func (a *TestActor) IsLocalSystem() bool { return a.LocalSystem } + +// IsLocalAdmin implements [Actor]. +func (a *TestActor) IsLocalAdmin(operatorUID string) bool { return a.LocalAdmin } diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go new file mode 100644 index 000000000..fc93cc876 --- /dev/null +++ b/ipn/ipnext/ipnext.go @@ -0,0 +1,411 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipnext defines types and interfaces used for extending the core LocalBackend +// functionality with additional features and services. +package ipnext + +import ( + "errors" + "fmt" + "iter" + "net/netip" + + "tailscale.com/control/controlclient" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstime" + "tailscale.com/types/logger" + "tailscale.com/types/mapx" +) + +// Extension augments LocalBackend with additional functionality. +// +// An extension uses the provided [Host] to register callbacks +// and interact with the backend in a controlled, well-defined +// and thread-safe manner. +// +// Extensions are registered using [RegisterExtension]. +// +// They must be safe for concurrent use. +type Extension interface { + // Name is a unique name of the extension. + // It must be the same as the name used to register the extension. + Name() string + + // Init is called to initialize the extension when LocalBackend's + // Start method is called. Extensions are created but not initialized + // unless LocalBackend is started. + // + // If the extension cannot be initialized, it must return an error, + // and its Shutdown method will not be called on the host's shutdown. + // Returned errors are not fatal; they are used for logging. + // A [SkipExtension] error indicates an intentional decision rather than a failure. + Init(Host) error + + // Shutdown is called when LocalBackend is shutting down, + // provided the extension was initialized. For multiple extensions, + // Shutdown is called in the reverse order of Init. + // Returned errors are not fatal; they are used for logging. + // After a call to Shutdown, the extension will not be called again. + Shutdown() error +} + +// NewExtensionFn is a function that instantiates an [Extension]. +// If a registered extension cannot be instantiated, the function must return an error. +// If the extension should be skipped at runtime, it must return either [SkipExtension] +// or a wrapped [SkipExtension]. Any other error returned is fatal and will prevent +// the LocalBackend from starting. +type NewExtensionFn func(logger.Logf, SafeBackend) (Extension, error) + +// SkipExtension is an error returned by [NewExtensionFn] to indicate that the extension +// should be skipped rather than prevent the LocalBackend from starting. +// +// Skipping an extension should be reserved for cases where the extension is not supported +// on the current platform or configuration, or depends on a feature that is not available, +// or otherwise should be disabled permanently rather than temporarily. +// +// Specifically, it must not be returned if the extension is not required right now +// based on user preferences, policy settings, the current tailnet, or other factors +// that may change throughout the LocalBackend's lifetime. +var SkipExtension = errors.New("skipping extension") + +// Definition describes a registered [Extension]. +type Definition struct { + name string // name under which the extension is registered + newFn NewExtensionFn // function that creates a new instance of the extension +} + +// Name returns the name of the extension. +func (d *Definition) Name() string { + return d.name +} + +// MakeExtension instantiates the extension. +func (d *Definition) MakeExtension(logf logger.Logf, sb SafeBackend) (Extension, error) { + ext, err := d.newFn(logf, sb) + if err != nil { + return nil, err + } + if ext.Name() != d.name { + return nil, fmt.Errorf("extension name mismatch: registered %q; actual %q", d.name, ext.Name()) + } + return ext, nil +} + +// extensions is a map of registered extensions, +// where the key is the name of the extension. +var extensions mapx.OrderedMap[string, *Definition] + +// RegisterExtension registers a function that instantiates an [Extension]. +// The name must be the same as returned by the extension's [Extension.Name]. +// +// It must be called on the main goroutine before LocalBackend is created, +// such as from an init function of the package implementing the extension. +// +// It panics if newExt is nil or if an extension with the same name +// has already been registered. +func RegisterExtension(name string, newExt NewExtensionFn) { + if newExt == nil { + panic(fmt.Sprintf("ipnext: newExt is nil: %q", name)) + } + if extensions.Contains(name) { + panic(fmt.Sprintf("ipnext: duplicate extension name %q", name)) + } + extensions.Set(name, &Definition{name, newExt}) +} + +// Extensions iterates over the extensions in the order they were registered +// via [RegisterExtension]. +func Extensions() iter.Seq[*Definition] { + return extensions.Values() +} + +// DefinitionForTest returns a [Definition] for the specified [Extension]. +// It is primarily used for testing where the test code needs to instantiate +// and use an extension without registering it. +func DefinitionForTest(ext Extension) *Definition { + return &Definition{ + name: ext.Name(), + newFn: func(logger.Logf, SafeBackend) (Extension, error) { return ext, nil }, + } +} + +// DefinitionWithErrForTest returns a [Definition] with the specified extension name +// whose [Definition.MakeExtension] method returns the specified error. +// It is used for testing. +func DefinitionWithErrForTest(name string, err error) *Definition { + return &Definition{ + name: name, + newFn: func(logger.Logf, SafeBackend) (Extension, error) { return nil, err }, + } +} + +// Host is the API surface used by [Extension]s to interact with LocalBackend +// in a controlled manner. +// +// Extensions can register callbacks, request information, or perform actions +// via the [Host] interface. +// +// Typically, the host invokes registered callbacks when one of the following occurs: +// - LocalBackend notifies it of an event or state change that may be +// of interest to extensions, such as when switching [ipn.LoginProfile]. +// - LocalBackend needs to consult extensions for information, for example, +// determining the most appropriate profile for the current state of the system. +// - LocalBackend performs an extensible action, such as logging an auditable event, +// and delegates its execution to the extension. +// +// The callbacks are invoked synchronously, and the LocalBackend's state +// remains unchanged while callbacks execute. +// +// In contrast, actions initiated by extensions are generally asynchronous, +// as indicated by the "Async" suffix in their names. +// Performing actions may result in callbacks being invoked as described above. +// +// To prevent conflicts between extensions competing for shared state, +// such as the current profile or prefs, the host must not expose methods +// that directly modify that state. For example, instead of allowing extensions +// to switch profiles at-will, the host's [ProfileServices] provides a method +// to switch to the "best" profile. The host can then consult extensions +// to determine the appropriate profile to use and resolve any conflicts +// in a controlled manner. +// +// A host must be safe for concurrent use. +type Host interface { + // Extensions returns the host's [ExtensionServices]. + Extensions() ExtensionServices + + // Profiles returns the host's [ProfileServices]. + Profiles() ProfileServices + + // AuditLogger returns a function that calls all currently registered audit loggers. + // The function fails if any logger returns an error, indicating that the action + // cannot be logged and must not be performed. + // + // The returned function captures the current state (e.g., the current profile) at + // the time of the call and must not be persisted. + AuditLogger() ipnauth.AuditLogFunc + + // Hooks returns a non-nil pointer to a [Hooks] struct. + // Hooks must not be modified concurrently or after Tailscale has started. + Hooks() *Hooks + + // SendNotifyAsync sends a notification to the IPN bus, + // typically to the GUI client. + SendNotifyAsync(ipn.Notify) + + // NodeBackend returns the [NodeBackend] for the currently active node + // (which is approximately the same as the current profile). + NodeBackend() NodeBackend +} + +// SafeBackend is a subset of the [ipnlocal.LocalBackend] type's methods that +// are safe to call from extension hooks at any time (even hooks called while +// LocalBackend's internal mutex is held). +type SafeBackend interface { + Sys() *tsd.System + Clock() tstime.Clock + TailscaleVarRoot() string +} + +// ExtensionServices provides access to the [Host]'s extension management services, +// such as fetching active extensions. +type ExtensionServices interface { + // FindExtensionByName returns an active extension with the given name, + // or nil if no such extension exists. + FindExtensionByName(name string) any + + // FindMatchingExtension finds the first active extension that matches target, + // and if one is found, sets target to that extension and returns true. + // Otherwise, it returns false. + // + // It panics if target is not a non-nil pointer to either a type + // that implements [ipnext.Extension], or to any interface type. + FindMatchingExtension(target any) bool +} + +// ProfileServices provides access to the [Host]'s profile management services, +// such as switching profiles and registering profile change callbacks. +type ProfileServices interface { + // CurrentProfileState returns read-only views of the current profile + // and its preferences. The returned views are always valid, + // but the profile's [ipn.LoginProfileView.ID] returns "" + // if the profile is new and has not been persisted yet. + // + // The returned views are immutable snapshots of the current profile + // and prefs at the time of the call. The actual state is only guaranteed + // to remain unchanged and match these views for the duration + // of a callback invoked by the host, if used within that callback. + // + // Extensions that need the current profile or prefs at other times + // should typically subscribe to [ProfileStateChangeCallback] + // to be notified if the profile or prefs change after retrieval. + // CurrentProfileState returns both the profile and prefs + // to guarantee that they are consistent with each other. + CurrentProfileState() (ipn.LoginProfileView, ipn.PrefsView) + + // CurrentPrefs is like [CurrentProfileState] but only returns prefs. + CurrentPrefs() ipn.PrefsView + + // SwitchToBestProfileAsync asynchronously selects the best profile to use + // and switches to it, unless it is already the current profile. + // + // If an extension needs to know when a profile switch occurs, + // it must use [ProfileServices.RegisterProfileStateChangeCallback] + // to register a [ProfileStateChangeCallback]. + // + // The reason indicates why the profile is being switched, such as due + // to a client connecting or disconnecting or a change in the desktop + // session state. It is used for logging. + SwitchToBestProfileAsync(reason string) +} + +// ProfileStore provides read-only access to available login profiles and their preferences. +// It is not safe for concurrent use and can only be used from the callback it is passed to. +type ProfileStore interface { + // CurrentUserID returns the current user ID. It is only non-empty on + // Windows where we have a multi-user system. + // + // Deprecated: this method exists for compatibility with the current (as of 2024-08-27) + // permission model and will be removed as we progress on tailscale/corp#18342. + CurrentUserID() ipn.WindowsUserID + + // CurrentProfile returns a read-only [ipn.LoginProfileView] of the current profile. + // The returned view is always valid, but the profile's [ipn.LoginProfileView.ID] + // returns "" if the profile is new and has not been persisted yet. + CurrentProfile() ipn.LoginProfileView + + // CurrentPrefs returns a read-only view of the current prefs. + // The returned view is always valid. + CurrentPrefs() ipn.PrefsView + + // DefaultUserProfile returns a read-only view of the default (last used) profile for the specified user. + // It returns a read-only view of a new, non-persisted profile if the specified user does not have a default profile. + DefaultUserProfile(uid ipn.WindowsUserID) ipn.LoginProfileView +} + +// AuditLogProvider is a function that returns an [ipnauth.AuditLogFunc] for +// logging auditable actions. +type AuditLogProvider func() ipnauth.AuditLogFunc + +// ProfileResolver is a function that returns a read-only view of a login profile. +// An invalid view indicates no profile. A valid profile view with an empty [ipn.ProfileID] +// indicates that the profile is new and has not been persisted yet. +// The provided [ProfileStore] can only be used for the duration of the callback. +type ProfileResolver func(ProfileStore) ipn.LoginProfileView + +// ProfileStateChangeCallback is a function to be called when the current login profile +// or its preferences change. +// +// The sameNode parameter indicates whether the profile represents the same node as before, +// which is true when: +// - Only the profile's [ipn.Prefs] or metadata (e.g., [tailcfg.UserProfile]) have changed, +// but the node ID and [ipn.ProfileID] remain the same. +// - The profile has been persisted and assigned an [ipn.ProfileID] for the first time, +// so while its node ID and [ipn.ProfileID] have changed, it is still the same profile. +// +// It can be used to decide whether to reset state bound to the current profile or node identity. +// +// The profile and prefs are always valid, but the profile's [ipn.LoginProfileView.ID] +// returns "" if the profile is new and has not been persisted yet. +type ProfileStateChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) + +// NewControlClientCallback is a function to be called when a new [controlclient.Client] +// is created and before it is first used. The specified profile represents the node +// for which the cc is created and is always valid. Its [ipn.LoginProfileView.ID] +// returns "" if it is a new node whose profile has never been persisted. +// +// If the [controlclient.Client] is created due to a profile switch, any registered +// [ProfileStateChangeCallback]s are called first. +// +// It returns a function to be called when the cc is being shut down, +// or nil if no cleanup is needed. That cleanup function should not call +// back into LocalBackend, which may be locked during shutdown. +type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView) (cleanup func()) + +// Hooks is a collection of hooks that extensions can add to (non-concurrently) +// during program initialization and can be called by LocalBackend and others at +// runtime. +// +// Each hook has its own rules about when it's called and what environment it +// has access to and what it's allowed to do. +type Hooks struct { + // BackendStateChange is called when the backend state changes. + BackendStateChange feature.Hooks[func(ipn.State)] + + // ProfileStateChange contains callbacks that are invoked when the current login profile + // or its [ipn.Prefs] change, after those changes have been made. The current login profile + // may be changed either because of a profile switch, or because the profile information + // was updated by [LocalBackend.SetControlClientStatus], including when the profile + // is first populated and persisted. + ProfileStateChange feature.Hooks[ProfileStateChangeCallback] + + // BackgroundProfileResolvers are registered background profile resolvers. + // They're used to determine the profile to use when no GUI/CLI client is connected. + // + // TODO(nickkhyl): allow specifying some kind of priority/altitude for the resolver. + // TODO(nickkhyl): make it a "profile resolver" instead of a "background profile resolver". + // The concepts of the "current user", "foreground profile" and "background profile" + // only exist on Windows, and we're moving away from them anyway. + BackgroundProfileResolvers feature.Hooks[ProfileResolver] + + // AuditLoggers are registered [AuditLogProvider]s. + // Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action + // is about to be performed. If an audit logger returns an error, the action is denied. + AuditLoggers feature.Hooks[AuditLogProvider] + + // NewControlClient are the functions to be called when a new control client + // is created. It is called with the LocalBackend locked. + NewControlClient feature.Hooks[NewControlClientCallback] + + // OnSelfChange is called (with LocalBackend.mu held) when the self node + // changes, including changing to nothing (an invalid view). + OnSelfChange feature.Hooks[func(tailcfg.NodeView)] + + // MutateNotifyLocked is called to optionally mutate the provided Notify + // before sending it to the IPN bus. It is called with LocalBackend.mu held. + MutateNotifyLocked feature.Hooks[func(*ipn.Notify)] + + // SetPeerStatus is called to mutate PeerStatus. + // Callers must only use NodeBackend to read data. + SetPeerStatus feature.Hooks[func(*ipnstate.PeerStatus, tailcfg.NodeView, NodeBackend)] + + // ShouldUploadServices reports whether this node should include services + // in Hostinfo from the portlist extension. + ShouldUploadServices feature.Hook[func() bool] +} + +// NodeBackend is an interface to query the current node and its peers. +// +// It is not a snapshot in time but is locked to a particular node. +type NodeBackend interface { + // AppendMatchingPeers appends all peers that match the predicate + // to the base slice and returns it. + AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView + + // PeerCaps returns the capabilities that src has to this node. + PeerCaps(src netip.Addr) tailcfg.PeerCapMap + + // PeerHasCap reports whether the peer has the specified peer capability. + PeerHasCap(peer tailcfg.NodeView, cap tailcfg.PeerCapability) bool + + // PeerAPIBase returns the "http://ip:port" URL base to reach peer's + // PeerAPI, or the empty string if the peer is invalid or doesn't support + // PeerAPI. + PeerAPIBase(tailcfg.NodeView) string + + // PeerHasPeerAPI whether the provided peer supports PeerAPI. + // + // It effectively just reports whether PeerAPIBase(node) is non-empty, but + // potentially more efficiently. + PeerHasPeerAPI(tailcfg.NodeView) bool + + // CollectServices reports whether the control plane is telling this + // node that the portlist service collection is desirable, should it + // choose to report them. + CollectServices() bool +} diff --git a/ipn/ipnlocal/autoupdate.go b/ipn/ipnlocal/autoupdate.go deleted file mode 100644 index b7d217a10..000000000 --- a/ipn/ipnlocal/autoupdate.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux || windows - -package ipnlocal - -import ( - "context" - "time" - - "tailscale.com/clientupdate" - "tailscale.com/ipn" - "tailscale.com/version" -) - -func (b *LocalBackend) stopOfflineAutoUpdate() { - if b.offlineAutoUpdateCancel != nil { - b.logf("offline auto-update: stopping update checks") - b.offlineAutoUpdateCancel() - b.offlineAutoUpdateCancel = nil - } -} - -func (b *LocalBackend) maybeStartOfflineAutoUpdate(prefs ipn.PrefsView) { - if !prefs.AutoUpdate().Apply.EqualBool(true) { - return - } - // AutoUpdate.Apply field in prefs can only be true for platforms that - // support auto-updates. But check it here again, just in case. - if !clientupdate.CanAutoUpdate() { - return - } - // On macsys, auto-updates are managed by Sparkle. - if version.IsMacSysExt() { - return - } - - if b.offlineAutoUpdateCancel != nil { - // Already running. - return - } - ctx, cancel := context.WithCancel(context.Background()) - b.offlineAutoUpdateCancel = cancel - - b.logf("offline auto-update: starting update checks") - go b.offlineAutoUpdate(ctx) -} - -const offlineAutoUpdateCheckPeriod = time.Hour - -func (b *LocalBackend) offlineAutoUpdate(ctx context.Context) { - t := time.NewTicker(offlineAutoUpdateCheckPeriod) - defer t.Stop() - for { - select { - case <-ctx.Done(): - return - case <-t.C: - } - if err := b.startAutoUpdate("offline auto-update"); err != nil { - b.logf("offline auto-update: failed: %v", err) - } - } -} diff --git a/ipn/ipnlocal/autoupdate_disabled.go b/ipn/ipnlocal/autoupdate_disabled.go deleted file mode 100644 index 88ed68c95..000000000 --- a/ipn/ipnlocal/autoupdate_disabled.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || windows) - -package ipnlocal - -import ( - "tailscale.com/ipn" -) - -func (b *LocalBackend) stopOfflineAutoUpdate() { - // Not supported on this platform. -} - -func (b *LocalBackend) maybeStartOfflineAutoUpdate(prefs ipn.PrefsView) { - // Not supported on this platform. -} diff --git a/ipn/ipnlocal/bus.go b/ipn/ipnlocal/bus.go new file mode 100644 index 000000000..910e4e774 --- /dev/null +++ b/ipn/ipnlocal/bus.go @@ -0,0 +1,161 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "time" + + "tailscale.com/ipn" + "tailscale.com/tstime" +) + +type rateLimitingBusSender struct { + fn func(*ipn.Notify) (keepGoing bool) + lastFlush time.Time // last call to fn, or zero value if none + interval time.Duration // 0 to flush immediately; non-zero to rate limit sends + clock tstime.DefaultClock // non-nil for testing + didSendTestHook func() // non-nil for testing + + // pending, if non-nil, is the pending notification that we + // haven't sent yet. We own this memory to mutate. + pending *ipn.Notify + + // flushTimer is non-nil if the timer is armed. + flushTimer tstime.TimerController // effectively a *time.Timer + flushTimerC <-chan time.Time // ... said ~Timer's C chan +} + +func (s *rateLimitingBusSender) close() { + if s.flushTimer != nil { + s.flushTimer.Stop() + } +} + +func (s *rateLimitingBusSender) flushChan() <-chan time.Time { + return s.flushTimerC +} + +func (s *rateLimitingBusSender) flush() (keepGoing bool) { + if n := s.pending; n != nil { + s.pending = nil + return s.flushNotify(n) + } + return true +} + +func (s *rateLimitingBusSender) flushNotify(n *ipn.Notify) (keepGoing bool) { + s.lastFlush = s.clock.Now() + return s.fn(n) +} + +// send conditionally sends n to the underlying fn, possibly rate +// limiting it, depending on whether s.interval is set, and whether +// n is a notable notification that the client (typically a GUI) would +// want to act on (render) immediately. +// +// It returns whether the caller should keep looping. +// +// The passed-in memory 'n' is owned by the caller and should +// not be mutated. +func (s *rateLimitingBusSender) send(n *ipn.Notify) (keepGoing bool) { + if s.interval <= 0 { + // No rate limiting case. + return s.fn(n) + } + if isNotableNotify(n) { + // Notable notifications are always sent immediately. + // But first send any boring one that was pending. + // TODO(bradfitz): there might be a boring one pending + // with a NetMap or Engine field that is redundant + // with the new one (n) with NetMap or Engine populated. + // We should clear the pending one's NetMap/Engine in + // that case. Or really, merge the two, but mergeBoringNotifies + // only handles the case of both sides being boring. + // So for now, flush both. + if !s.flush() { + return false + } + return s.flushNotify(n) + } + s.pending = mergeBoringNotifies(s.pending, n) + d := s.clock.Now().Sub(s.lastFlush) + if d > s.interval { + return s.flush() + } + nextFlushIn := s.interval - d + if s.flushTimer == nil { + s.flushTimer, s.flushTimerC = s.clock.NewTimer(nextFlushIn) + } else { + s.flushTimer.Reset(nextFlushIn) + } + return true +} + +func (s *rateLimitingBusSender) Run(ctx context.Context, ch <-chan *ipn.Notify) { + for { + select { + case <-ctx.Done(): + return + case n, ok := <-ch: + if !ok { + return + } + if !s.send(n) { + return + } + if f := s.didSendTestHook; f != nil { + f() + } + case <-s.flushChan(): + if !s.flush() { + return + } + } + } +} + +// mergeBoringNotify merges new notify 'src' into possibly-nil 'dst', +// either mutating 'dst' or allocating a new one if 'dst' is nil, +// returning the merged result. +// +// dst and src must both be "boring" (i.e. not notable per isNotifiableNotify). +func mergeBoringNotifies(dst, src *ipn.Notify) *ipn.Notify { + if dst == nil { + dst = &ipn.Notify{Version: src.Version} + } + if src.NetMap != nil { + dst.NetMap = src.NetMap + } + if src.Engine != nil { + dst.Engine = src.Engine + } + return dst +} + +// isNotableNotify reports whether n is a "notable" notification that +// should be sent on the IPN bus immediately (e.g. to GUIs) without +// rate limiting it for a few seconds. +// +// It effectively reports whether n contains any field set that's +// not NetMap or Engine. +func isNotableNotify(n *ipn.Notify) bool { + if n == nil { + return false + } + return n.State != nil || + n.SessionID != "" || + n.BrowseToURL != nil || + n.LocalTCPPort != nil || + n.ClientVersion != nil || + n.Prefs != nil || + n.ErrMessage != nil || + n.LoginFinished != nil || + !n.DriveShares.IsNil() || + n.Health != nil || + len(n.IncomingFiles) > 0 || + len(n.OutgoingFiles) > 0 || + n.FilesWaiting != nil || + n.SuggestedExitNode != nil +} diff --git a/ipn/ipnlocal/bus_test.go b/ipn/ipnlocal/bus_test.go new file mode 100644 index 000000000..5c75ac54d --- /dev/null +++ b/ipn/ipnlocal/bus_test.go @@ -0,0 +1,220 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "reflect" + "slices" + "testing" + "time" + + "tailscale.com/drive" + "tailscale.com/ipn" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/views" +) + +func TestIsNotableNotify(t *testing.T) { + tests := []struct { + name string + notify *ipn.Notify + want bool + }{ + {"nil", nil, false}, + {"empty", &ipn.Notify{}, false}, + {"version", &ipn.Notify{Version: "foo"}, false}, + {"netmap", &ipn.Notify{NetMap: new(netmap.NetworkMap)}, false}, + {"engine", &ipn.Notify{Engine: new(ipn.EngineStatus)}, false}, + } + + // Then for all other fields, assume they're notable. + // We use reflect to catch fields that might be added in the future without + // remembering to update the [isNotableNotify] function. + rt := reflect.TypeFor[ipn.Notify]() + for i := range rt.NumField() { + n := &ipn.Notify{} + sf := rt.Field(i) + switch sf.Name { + case "_", "NetMap", "Engine", "Version": + // Already covered above or not applicable. + continue + case "DriveShares": + n.DriveShares = views.SliceOfViews[*drive.Share, drive.ShareView](make([]*drive.Share, 1)) + default: + rf := reflect.ValueOf(n).Elem().Field(i) + switch rf.Kind() { + case reflect.Pointer: + rf.Set(reflect.New(rf.Type().Elem())) + case reflect.String: + rf.SetString("foo") + case reflect.Slice: + rf.Set(reflect.MakeSlice(rf.Type(), 1, 1)) + default: + t.Errorf("unhandled field kind %v for %q", rf.Kind(), sf.Name) + } + } + + tests = append(tests, struct { + name string + notify *ipn.Notify + want bool + }{ + name: "field-" + rt.Field(i).Name, + notify: n, + want: true, + }) + } + + for _, tt := range tests { + if got := isNotableNotify(tt.notify); got != tt.want { + t.Errorf("%v: got %v; want %v", tt.name, got, tt.want) + } + } +} + +type rateLimitingBusSenderTester struct { + tb testing.TB + got []*ipn.Notify + clock *tstest.Clock + s *rateLimitingBusSender +} + +func (st *rateLimitingBusSenderTester) init() { + if st.s != nil { + return + } + st.clock = tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(1731777537, 0), // time I wrote this test :) + }) + st.s = &rateLimitingBusSender{ + clock: tstime.DefaultClock{Clock: st.clock}, + fn: func(n *ipn.Notify) bool { + st.got = append(st.got, n) + return true + }, + } +} + +func (st *rateLimitingBusSenderTester) send(n *ipn.Notify) { + st.tb.Helper() + st.init() + if !st.s.send(n) { + st.tb.Fatal("unexpected send failed") + } +} + +func (st *rateLimitingBusSenderTester) advance(d time.Duration) { + st.tb.Helper() + st.clock.Advance(d) + select { + case <-st.s.flushChan(): + if !st.s.flush() { + st.tb.Fatal("unexpected flush failed") + } + default: + } +} + +func TestRateLimitingBusSender(t *testing.T) { + nm1 := &ipn.Notify{NetMap: new(netmap.NetworkMap)} + nm2 := &ipn.Notify{NetMap: new(netmap.NetworkMap)} + eng1 := &ipn.Notify{Engine: new(ipn.EngineStatus)} + eng2 := &ipn.Notify{Engine: new(ipn.EngineStatus)} + + t.Run("unbuffered", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.send(nm1) + st.send(nm2) + st.send(eng1) + st.send(eng2) + if !slices.Equal(st.got, []*ipn.Notify{nm1, nm2, eng1, eng2}) { + t.Errorf("got %d items; want 4 specific ones, unmodified", len(st.got)) + } + }) + + t.Run("buffered", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.init() + st.s.interval = 1 * time.Second + st.send(&ipn.Notify{Version: "initial"}) + if len(st.got) != 1 { + t.Fatalf("got %d items; expected 1 (first to flush immediately)", len(st.got)) + } + st.send(nm1) + st.send(nm2) + st.send(eng1) + st.send(eng2) + if len(st.got) != 1 { + if len(st.got) != 1 { + t.Fatalf("got %d items; expected still just that first 1", len(st.got)) + } + } + + // But moving the clock should flush the rest, collasced into one new one. + st.advance(5 * time.Second) + if len(st.got) != 2 { + t.Fatalf("got %d items; want 2", len(st.got)) + } + gotn := st.got[1] + if gotn.NetMap != nm2.NetMap { + t.Errorf("got wrong NetMap; got %p", gotn.NetMap) + } + if gotn.Engine != eng2.Engine { + t.Errorf("got wrong Engine; got %p", gotn.Engine) + } + if t.Failed() { + t.Logf("failed Notify was: %v", logger.AsJSON(gotn)) + } + }) + + // Test the Run method + t.Run("run", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.init() + st.s.interval = 1 * time.Second + st.s.lastFlush = st.clock.Now() // pretend we just flushed + + flushc := make(chan *ipn.Notify, 1) + st.s.fn = func(n *ipn.Notify) bool { + flushc <- n + return true + } + didSend := make(chan bool, 2) + st.s.didSendTestHook = func() { didSend <- true } + waitSend := func() { + select { + case <-didSend: + case <-time.After(5 * time.Second): + t.Error("timeout waiting for call to send") + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + incoming := make(chan *ipn.Notify, 2) + go func() { + incoming <- nm1 + waitSend() + incoming <- nm2 + waitSend() + st.advance(5 * time.Second) + select { + case n := <-flushc: + if n.NetMap != nm2.NetMap { + t.Errorf("got wrong NetMap; got %p", n.NetMap) + } + case <-time.After(10 * time.Second): + t.Error("timeout") + } + cancel() + }() + + st.s.Run(ctx, incoming) + }) +} diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index de6ca2321..b5e722b97 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -4,79 +4,86 @@ package ipnlocal import ( - "crypto/x509" "encoding/json" - "encoding/pem" - "errors" "fmt" "io" - "net" "net/http" - "os" - "os/exec" "path" - "path/filepath" + "reflect" "runtime" - "sort" "strconv" "strings" "time" - "github.com/kortschak/wol" - "tailscale.com/clientupdate" - "tailscale.com/envknob" + "tailscale.com/control/controlclient" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/net/sockstats" - "tailscale.com/posture" "tailscale.com/tailcfg" + "tailscale.com/types/netmap" "tailscale.com/util/clientmetric" "tailscale.com/util/goroutines" + "tailscale.com/util/httpm" "tailscale.com/util/set" - "tailscale.com/util/syspolicy" "tailscale.com/version" - "tailscale.com/version/distro" ) // c2nHandlers maps an HTTP method and URI path (without query parameters) to // its handler. The exact method+path match is preferred, but if no entry // exists for that, a map entry with an empty method is used as a fallback. -var c2nHandlers = map[methodAndPath]c2nHandler{ - // Debug. - req("/echo"): handleC2NEcho, - req("/debug/goroutines"): handleC2NDebugGoroutines, - req("/debug/prefs"): handleC2NDebugPrefs, - req("/debug/metrics"): handleC2NDebugMetrics, - req("/debug/component-logging"): handleC2NDebugComponentLogging, - req("/debug/logheap"): handleC2NDebugLogHeap, +var c2nHandlers map[methodAndPath]c2nHandler - // PPROF - We only expose a subset of typical pprof endpoints for security. - req("/debug/pprof/heap"): handleC2NPprof, - req("/debug/pprof/allocs"): handleC2NPprof, - - req("POST /logtail/flush"): handleC2NLogtailFlush, - req("POST /sockstats"): handleC2NSockStats, - - // Check TLS certificate status. - req("GET /tls-cert-status"): handleC2NTLSCertStatus, - - // SSH - req("/ssh/usernames"): handleC2NSSHUsernames, - - // Auto-updates. - req("GET /update"): handleC2NUpdateGet, - req("POST /update"): handleC2NUpdatePost, - - // Wake-on-LAN. - req("POST /wol"): handleC2NWoL, - - // Device posture. - req("GET /posture/identity"): handleC2NPostureIdentityGet, - - // App Connectors. - req("GET /appconnector/routes"): handleC2NAppConnectorDomainRoutesGet, +func init() { + c2nHandlers = map[methodAndPath]c2nHandler{} + if buildfeatures.HasC2N { + // Echo is the basic "ping" handler as used by the control plane to probe + // whether a node is reachable. In particular, it's important for + // high-availability subnet routers for the control plane to probe which of + // several candidate nodes is reachable and actually alive. + RegisterC2N("/echo", handleC2NEcho) + } + if buildfeatures.HasSSH { + RegisterC2N("/ssh/usernames", handleC2NSSHUsernames) + } + if buildfeatures.HasLogTail { + RegisterC2N("POST /logtail/flush", handleC2NLogtailFlush) + } + if buildfeatures.HasDebug { + RegisterC2N("POST /sockstats", handleC2NSockStats) + + // pprof: + // we only expose a subset of typical pprof endpoints for security. + RegisterC2N("/debug/pprof/heap", handleC2NPprof) + RegisterC2N("/debug/pprof/allocs", handleC2NPprof) + + RegisterC2N("/debug/goroutines", handleC2NDebugGoroutines) + RegisterC2N("/debug/prefs", handleC2NDebugPrefs) + RegisterC2N("/debug/metrics", handleC2NDebugMetrics) + RegisterC2N("/debug/component-logging", handleC2NDebugComponentLogging) + RegisterC2N("/debug/logheap", handleC2NDebugLogHeap) + RegisterC2N("/debug/netmap", handleC2NDebugNetMap) + RegisterC2N("/debug/health", handleC2NDebugHealth) + } + if runtime.GOOS == "linux" && buildfeatures.HasOSRouter { + RegisterC2N("POST /netfilter-kind", handleC2NSetNetfilterKind) + } +} - // Linux netfilter. - req("POST /netfilter-kind"): handleC2NSetNetfilterKind, +// RegisterC2N registers a new c2n handler for the given pattern. +// +// A pattern is like "GET /foo" (specific to an HTTP method) or "/foo" (all +// methods). It panics if the pattern is already registered. +func RegisterC2N(pattern string, h func(*LocalBackend, http.ResponseWriter, *http.Request)) { + if !buildfeatures.HasC2N { + return + } + k := req(pattern) + if _, ok := c2nHandlers[k]; ok { + panic(fmt.Sprintf("c2n: duplicate handler for %q", pattern)) + } + c2nHandlers[k] = h } type c2nHandler func(*LocalBackend, http.ResponseWriter, *http.Request) @@ -140,21 +147,108 @@ func handleC2NLogtailFlush(b *LocalBackend, w http.ResponseWriter, r *http.Reque } } +func handleC2NDebugHealth(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + var st *health.State + if buildfeatures.HasDebug && b.health != nil { + st = b.health.CurrentState() + } + writeJSON(w, st) +} + +func handleC2NDebugNetMap(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + ctx := r.Context() + if r.Method != httpm.POST && r.Method != httpm.GET { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + b.logf("c2n: %s /debug/netmap received", r.Method) + + // redactAndMarshal redacts private keys from the given netmap, clears fields + // that should be omitted, and marshals it to JSON. + redactAndMarshal := func(nm *netmap.NetworkMap, omitFields []string) (json.RawMessage, error) { + for _, f := range omitFields { + field := reflect.ValueOf(nm).Elem().FieldByName(f) + if !field.IsValid() { + b.logf("c2n: /debug/netmap: unknown field %q in omitFields", f) + continue + } + field.SetZero() + } + return json.Marshal(nm) + } + + var omitFields []string + resp := &tailcfg.C2NDebugNetmapResponse{} + + if r.Method == httpm.POST { + var req tailcfg.C2NDebugNetmapRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("failed to decode request body: %v", err), http.StatusBadRequest) + return + } + omitFields = req.OmitFields + + if req.Candidate != nil { + cand, err := controlclient.NetmapFromMapResponseForDebug(ctx, b.unsanitizedPersist(), req.Candidate) + if err != nil { + http.Error(w, fmt.Sprintf("failed to convert candidate MapResponse: %v", err), http.StatusBadRequest) + return + } + candJSON, err := redactAndMarshal(cand, omitFields) + if err != nil { + http.Error(w, fmt.Sprintf("failed to marshal candidate netmap: %v", err), http.StatusInternalServerError) + return + } + resp.Candidate = candJSON + } + } + + var err error + resp.Current, err = redactAndMarshal(b.currentNode().netMapWithPeers(), omitFields) + if err != nil { + http.Error(w, fmt.Sprintf("failed to marshal current netmap: %v", err), http.StatusInternalServerError) + return + } + + writeJSON(w, resp) +} + func handleC2NDebugGoroutines(_ *LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } w.Header().Set("Content-Type", "text/plain") w.Write(goroutines.ScrubbedGoroutineDump(true)) } func handleC2NDebugPrefs(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } writeJSON(w, b.Prefs()) } func handleC2NDebugMetrics(_ *LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } w.Header().Set("Content-Type", "text/plain") clientmetric.WritePrometheusExpositionFormat(w) } func handleC2NDebugComponentLogging(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } component := r.FormValue("component") secs, _ := strconv.Atoi(r.FormValue("secs")) if secs == 0 { @@ -197,6 +291,10 @@ func handleC2NPprof(b *LocalBackend, w http.ResponseWriter, r *http.Request) { } func handleC2NSSHUsernames(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasSSH { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } var req tailcfg.C2NSSHUsernamesRequest if r.Method == "POST" { if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -223,26 +321,6 @@ func handleC2NSockStats(b *LocalBackend, w http.ResponseWriter, r *http.Request) fmt.Fprintf(w, "debug info: %v\n", sockstats.DebugInfo()) } -// handleC2NAppConnectorDomainRoutesGet handles returning the domains -// that the app connector is responsible for, as well as the resolved -// IP addresses for each domain. If the node is not configured as -// an app connector, an empty map is returned. -func handleC2NAppConnectorDomainRoutesGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - b.logf("c2n: GET /appconnector/routes received") - - var res tailcfg.C2NAppConnectorDomainRoutesResponse - if b.appConnector == nil { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) - return - } - - res.Domains = b.appConnector.DomainRoutes() - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} - func handleC2NSetNetfilterKind(b *LocalBackend, w http.ResponseWriter, r *http.Request) { b.logf("c2n: POST /netfilter-kind received") @@ -268,324 +346,3 @@ func handleC2NSetNetfilterKind(b *LocalBackend, w http.ResponseWriter, r *http.R w.WriteHeader(http.StatusNoContent) } - -func handleC2NUpdateGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - b.logf("c2n: GET /update received") - - res := b.newC2NUpdateResponse() - res.Started = b.c2nUpdateStarted() - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} - -func handleC2NUpdatePost(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - b.logf("c2n: POST /update received") - res := b.newC2NUpdateResponse() - defer func() { - if res.Err != "" { - b.logf("c2n: POST /update failed: %s", res.Err) - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) - }() - - if !res.Enabled { - res.Err = "not enabled" - return - } - if !res.Supported { - res.Err = "not supported" - return - } - - // Do not update if we have active inbound SSH connections. Control can set - // force=true query parameter to override this. - if r.FormValue("force") != "true" && b.sshServer != nil && b.sshServer.NumActiveConns() > 0 { - res.Err = "not updating due to active SSH connections" - return - } - - if err := b.startAutoUpdate("c2n"); err != nil { - res.Err = err.Error() - return - } - res.Started = true -} - -func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - b.logf("c2n: GET /posture/identity received") - - res := tailcfg.C2NPostureIdentityResponse{} - - // Only collect posture identity if enabled on the client, - // this will first check syspolicy, MDM settings like Registry - // on Windows or defaults on macOS. If they are not set, it falls - // back to the cli-flag, `--posture-checking`. - choice, err := syspolicy.GetPreferenceOption(syspolicy.PostureChecking) - if err != nil { - b.logf( - "c2n: failed to read PostureChecking from syspolicy, returning default from CLI: %s; got error: %s", - b.Prefs().PostureChecking(), - err, - ) - } - - if choice.ShouldEnable(b.Prefs().PostureChecking()) { - sns, err := posture.GetSerialNumbers(b.logf) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - res.SerialNumbers = sns - - // TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release - // and looks good in client metrics, remove this parameter and always report MAC - // addresses. - if r.FormValue("hwaddrs") == "true" { - res.IfaceHardwareAddrs, err = posture.GetHardwareAddrs() - if err != nil { - b.logf("c2n: GetHardwareAddrs returned error: %v", err) - } - } - } else { - res.PostureDisabled = true - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} - -func (b *LocalBackend) newC2NUpdateResponse() tailcfg.C2NUpdateResponse { - // If NewUpdater does not return an error, we can update the installation. - // - // Note that we create the Updater solely to check for errors; we do not - // invoke it here. For this purpose, it is ok to pass it a zero Arguments. - prefs := b.Prefs().AutoUpdate() - return tailcfg.C2NUpdateResponse{ - Enabled: envknob.AllowsRemoteUpdate() || prefs.Apply.EqualBool(true), - Supported: clientupdate.CanAutoUpdate() && !version.IsMacSysExt(), - } -} - -func (b *LocalBackend) c2nUpdateStarted() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.c2nUpdateStatus.started -} - -func (b *LocalBackend) setC2NUpdateStarted(v bool) { - b.mu.Lock() - defer b.mu.Unlock() - b.c2nUpdateStatus.started = v -} - -func (b *LocalBackend) trySetC2NUpdateStarted() bool { - b.mu.Lock() - defer b.mu.Unlock() - if b.c2nUpdateStatus.started { - return false - } - b.c2nUpdateStatus.started = true - return true -} - -// findCmdTailscale looks for the cmd/tailscale that corresponds to the -// currently running cmd/tailscaled. It's up to the caller to verify that the -// two match, but this function does its best to find the right one. Notably, it -// doesn't use $PATH for security reasons. -func findCmdTailscale() (string, error) { - self, err := os.Executable() - if err != nil { - return "", err - } - var ts string - switch runtime.GOOS { - case "linux": - if self == "/usr/sbin/tailscaled" || self == "/usr/bin/tailscaled" { - ts = "/usr/bin/tailscale" - } - if self == "/usr/local/sbin/tailscaled" || self == "/usr/local/bin/tailscaled" { - ts = "/usr/local/bin/tailscale" - } - switch distro.Get() { - case distro.QNAP: - // The volume under /share/ where qpkg are installed is not - // predictable. But the rest of the path is. - ok, err := filepath.Match("/share/*/.qpkg/Tailscale/tailscaled", self) - if err == nil && ok { - ts = filepath.Join(filepath.Dir(self), "tailscale") - } - case distro.Unraid: - if self == "/usr/local/emhttp/plugins/tailscale/bin/tailscaled" { - ts = "/usr/local/emhttp/plugins/tailscale/bin/tailscale" - } - } - case "windows": - ts = filepath.Join(filepath.Dir(self), "tailscale.exe") - case "freebsd": - if self == "/usr/local/bin/tailscaled" { - ts = "/usr/local/bin/tailscale" - } - default: - return "", fmt.Errorf("unsupported OS %v", runtime.GOOS) - } - if ts != "" && regularFileExists(ts) { - return ts, nil - } - return "", errors.New("tailscale executable not found in expected place") -} - -func tailscaleUpdateCmd(cmdTS string) *exec.Cmd { - defaultCmd := exec.Command(cmdTS, "update", "--yes") - if runtime.GOOS != "linux" { - return defaultCmd - } - if _, err := exec.LookPath("systemd-run"); err != nil { - return defaultCmd - } - - // When systemd-run is available, use it to run the update command. This - // creates a new temporary unit separate from the tailscaled unit. When - // tailscaled is restarted during the update, systemd won't kill this - // temporary update unit, which could cause unexpected breakage. - // - // We want to use a few optional flags: - // * --wait, to block the update command until completion (added in systemd 232) - // * --pipe, to collect stdout/stderr (added in systemd 235) - // * --collect, to clean up failed runs from memory (added in systemd 236) - // - // We need to check the version of systemd to figure out if those flags are - // available. - // - // The output will look like: - // - // systemd 255 (255.7-1-arch) - // +PAM +AUDIT ... other feature flags ... - systemdVerOut, err := exec.Command("systemd-run", "--version").Output() - if err != nil { - return defaultCmd - } - parts := strings.Fields(string(systemdVerOut)) - if len(parts) < 2 || parts[0] != "systemd" { - return defaultCmd - } - systemdVer, err := strconv.Atoi(parts[1]) - if err != nil { - return defaultCmd - } - if systemdVer >= 236 { - return exec.Command("systemd-run", "--wait", "--pipe", "--collect", cmdTS, "update", "--yes") - } else if systemdVer >= 235 { - return exec.Command("systemd-run", "--wait", "--pipe", cmdTS, "update", "--yes") - } else if systemdVer >= 232 { - return exec.Command("systemd-run", "--wait", cmdTS, "update", "--yes") - } else { - return exec.Command("systemd-run", cmdTS, "update", "--yes") - } -} - -func regularFileExists(path string) bool { - fi, err := os.Stat(path) - return err == nil && fi.Mode().IsRegular() -} - -func handleC2NWoL(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - r.ParseForm() - var macs []net.HardwareAddr - for _, macStr := range r.Form["mac"] { - mac, err := net.ParseMAC(macStr) - if err != nil { - http.Error(w, "bad 'mac' param", http.StatusBadRequest) - return - } - macs = append(macs, mac) - } - var res struct { - SentTo []string - Errors []string - } - st := b.sys.NetMon.Get().InterfaceState() - if st == nil { - res.Errors = append(res.Errors, "no interface state") - writeJSON(w, &res) - return - } - var password []byte // TODO(bradfitz): support? does anything use WoL passwords? - for _, mac := range macs { - for ifName, ips := range st.InterfaceIPs { - for _, ip := range ips { - if ip.Addr().IsLoopback() || ip.Addr().Is6() { - continue - } - local := &net.UDPAddr{ - IP: ip.Addr().AsSlice(), - Port: 0, - } - remote := &net.UDPAddr{ - IP: net.IPv4bcast, - Port: 0, - } - if err := wol.Wake(mac, password, local, remote); err != nil { - res.Errors = append(res.Errors, err.Error()) - } else { - res.SentTo = append(res.SentTo, ifName) - } - break // one per interface is enough - } - } - } - sort.Strings(res.SentTo) - writeJSON(w, &res) -} - -// handleC2NTLSCertStatus returns info about the last TLS certificate issued for the -// provided domain. This can be called by the controlplane to clean up DNS TXT -// records when they're no longer needed by LetsEncrypt. -// -// It does not kick off a cert fetch or async refresh. It only reports anything -// that's already sitting on disk, and only reports metadata about the public -// cert (stuff that'd be the in CT logs anyway). -func handleC2NTLSCertStatus(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - cs, err := b.getCertStore() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - domain := r.FormValue("domain") - if domain == "" { - http.Error(w, "no 'domain'", http.StatusBadRequest) - return - } - - ret := &tailcfg.C2NTLSCertInfo{} - pair, err := getCertPEMCached(cs, domain, b.clock.Now()) - ret.Valid = err == nil - if err != nil { - ret.Error = err.Error() - if errors.Is(err, errCertExpired) { - ret.Expired = true - } else if errors.Is(err, ipn.ErrStateNotExist) { - ret.Missing = true - ret.Error = "no certificate" - } - } else { - block, _ := pem.Decode(pair.CertPEM) - if block == nil { - ret.Error = "invalid PEM" - ret.Valid = false - } else { - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - ret.Error = fmt.Sprintf("invalid certificate: %v", err) - ret.Valid = false - } else { - ret.NotBefore = cert.NotBefore.UTC().Format(time.RFC3339) - ret.NotAfter = cert.NotAfter.UTC().Format(time.RFC3339) - } - } - } - - writeJSON(w, ret) -} diff --git a/ipn/ipnlocal/c2n_pprof.go b/ipn/ipnlocal/c2n_pprof.go index b4bc35790..13237cc4f 100644 --- a/ipn/ipnlocal/c2n_pprof.go +++ b/ipn/ipnlocal/c2n_pprof.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !js && !wasm +//go:build !js && !wasm && !ts_omit_debug package ipnlocal diff --git a/ipn/ipnlocal/c2n_test.go b/ipn/ipnlocal/c2n_test.go index cc31e284a..86cc6a549 100644 --- a/ipn/ipnlocal/c2n_test.go +++ b/ipn/ipnlocal/c2n_test.go @@ -4,6 +4,7 @@ package ipnlocal import ( + "bytes" "cmp" "crypto/x509" "encoding/json" @@ -18,8 +19,14 @@ import ( "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" "tailscale.com/tstest" + "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/views" "tailscale.com/util/must" + + gcmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) func TestHandleC2NTLSCertStatus(t *testing.T) { @@ -132,3 +139,88 @@ func TestHandleC2NTLSCertStatus(t *testing.T) { } } + +func TestHandleC2NDebugNetmap(t *testing.T) { + nm := &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + ID: 100, + Name: "myhost", + StableID: "deadbeef", + Key: key.NewNode().Public(), + Hostinfo: (&tailcfg.Hostinfo{Hostname: "myhost"}).View(), + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 101, + Name: "peer1", + StableID: "deadbeef", + Key: key.NewNode().Public(), + Hostinfo: (&tailcfg.Hostinfo{Hostname: "peer1"}).View(), + }).View(), + }, + } + + for _, tt := range []struct { + name string + req *tailcfg.C2NDebugNetmapRequest + want *netmap.NetworkMap + }{ + { + name: "simple_get", + want: nm, + }, + { + name: "post_no_omit", + req: &tailcfg.C2NDebugNetmapRequest{}, + want: nm, + }, + { + name: "post_omit_peers_and_name", + req: &tailcfg.C2NDebugNetmapRequest{OmitFields: []string{"Peers", "Name"}}, + want: &netmap.NetworkMap{ + SelfNode: nm.SelfNode, + }, + }, + { + name: "post_omit_nonexistent_field", + req: &tailcfg.C2NDebugNetmapRequest{OmitFields: []string{"ThisFieldDoesNotExist"}}, + want: nm, + }, + } { + t.Run(tt.name, func(t *testing.T) { + b := newTestLocalBackend(t) + b.currentNode().SetNetMap(nm) + + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/debug/netmap", nil) + if tt.req != nil { + b, err := json.Marshal(tt.req) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + req = httptest.NewRequest("POST", "/debug/netmap", bytes.NewReader(b)) + } + handleC2NDebugNetMap(b, rec, req) + res := rec.Result() + wantStatus := 200 + if res.StatusCode != wantStatus { + t.Fatalf("status code = %v; want %v. Body: %s", res.Status, wantStatus, rec.Body.Bytes()) + } + var resp tailcfg.C2NDebugNetmapResponse + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatalf("bad JSON: %v", err) + } + got := &netmap.NetworkMap{} + if err := json.Unmarshal(resp.Current, got); err != nil { + t.Fatalf("bad JSON: %v", err) + } + + if diff := gcmp.Diff(tt.want, got, + gcmp.AllowUnexported(netmap.NetworkMap{}, key.NodePublic{}, views.Slice[tailcfg.FilterRule]{}), + cmpopts.EquateComparable(key.MachinePublic{}), + ); diff != "" { + t.Errorf("netmap mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/ipn/ipnlocal/captiveportal.go b/ipn/ipnlocal/captiveportal.go new file mode 100644 index 000000000..14f8b799e --- /dev/null +++ b/ipn/ipnlocal/captiveportal.go @@ -0,0 +1,186 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_captiveportal + +package ipnlocal + +import ( + "context" + "time" + + "tailscale.com/health" + "tailscale.com/net/captivedetection" + "tailscale.com/util/clientmetric" +) + +func init() { + hookCaptivePortalHealthChange.Set(captivePortalHealthChange) + hookCheckCaptivePortalLoop.Set(checkCaptivePortalLoop) +} + +var metricCaptivePortalDetected = clientmetric.NewCounter("captiveportal_detected") + +// captivePortalDetectionInterval is the duration to wait in an unhealthy state with connectivity broken +// before running captive portal detection. +const captivePortalDetectionInterval = 2 * time.Second + +func captivePortalHealthChange(b *LocalBackend, state *health.State) { + isConnectivityImpacted := false + for _, w := range state.Warnings { + // Ignore the captive portal warnable itself. + if w.ImpactsConnectivity && w.WarnableCode != captivePortalWarnable.Code { + isConnectivityImpacted = true + break + } + } + + // captiveCtx can be changed, and is protected with 'mu'; grab that + // before we start our select, below. + // + // It is guaranteed to be non-nil. + b.mu.Lock() + ctx := b.captiveCtx + b.mu.Unlock() + + // If the context is canceled, we don't need to do anything. + if ctx.Err() != nil { + return + } + + if isConnectivityImpacted { + b.logf("health: connectivity impacted; triggering captive portal detection") + + // Ensure that we select on captiveCtx so that we can time out + // triggering captive portal detection if the backend is shutdown. + select { + case b.needsCaptiveDetection <- true: + case <-ctx.Done(): + } + } else { + // If connectivity is not impacted, we know for sure we're not behind a captive portal, + // so drop any warning, and signal that we don't need captive portal detection. + b.health.SetHealthy(captivePortalWarnable) + select { + case b.needsCaptiveDetection <- false: + case <-ctx.Done(): + } + } +} + +// captivePortalWarnable is a Warnable which is set to an unhealthy state when a captive portal is detected. +var captivePortalWarnable = health.Register(&health.Warnable{ + Code: "captive-portal-detected", + Title: "Captive portal detected", + // High severity, because captive portals block all traffic and require user intervention. + Severity: health.SeverityHigh, + Text: health.StaticMessage("This network requires you to log in using your web browser."), + ImpactsConnectivity: true, +}) + +func checkCaptivePortalLoop(b *LocalBackend, ctx context.Context) { + var tmr *time.Timer + + maybeStartTimer := func() { + // If there's an existing timer, nothing to do; just continue + // waiting for it to expire. Otherwise, create a new timer. + if tmr == nil { + tmr = time.NewTimer(captivePortalDetectionInterval) + } + } + maybeStopTimer := func() { + if tmr == nil { + return + } + if !tmr.Stop() { + <-tmr.C + } + tmr = nil + } + + for { + if ctx.Err() != nil { + maybeStopTimer() + return + } + + // First, see if we have a signal on our "healthy" channel, which + // takes priority over an existing timer. Because a select is + // nondeterministic, we explicitly check this channel before + // entering the main select below, so that we're guaranteed to + // stop the timer before starting captive portal detection. + select { + case needsCaptiveDetection := <-b.needsCaptiveDetection: + if needsCaptiveDetection { + maybeStartTimer() + } else { + maybeStopTimer() + } + default: + } + + var timerChan <-chan time.Time + if tmr != nil { + timerChan = tmr.C + } + select { + case <-ctx.Done(): + // All done; stop the timer and then exit. + maybeStopTimer() + return + case <-timerChan: + // Kick off captive portal check + b.performCaptiveDetection() + // nil out timer to force recreation + tmr = nil + case needsCaptiveDetection := <-b.needsCaptiveDetection: + if needsCaptiveDetection { + maybeStartTimer() + } else { + // Healthy; cancel any existing timer + maybeStopTimer() + } + } + } +} + +// shouldRunCaptivePortalDetection reports whether captive portal detection +// should be run. It is enabled by default, but can be disabled via a control +// knob. It is also only run when the user explicitly wants the backend to be +// running. +func (b *LocalBackend) shouldRunCaptivePortalDetection() bool { + b.mu.Lock() + defer b.mu.Unlock() + return !b.ControlKnobs().DisableCaptivePortalDetection.Load() && b.pm.prefs.WantRunning() +} + +// performCaptiveDetection checks if captive portal detection is enabled via controlknob. If so, it runs +// the detection and updates the Warnable accordingly. +func (b *LocalBackend) performCaptiveDetection() { + if !b.shouldRunCaptivePortalDetection() { + return + } + + d := captivedetection.NewDetector(b.logf) + b.mu.Lock() // for b.hostinfo + cn := b.currentNode() + dm := cn.DERPMap() + preferredDERP := 0 + if b.hostinfo != nil { + if b.hostinfo.NetInfo != nil { + preferredDERP = b.hostinfo.NetInfo.PreferredDERP + } + } + ctx := b.ctx + netMon := b.NetMon() + b.mu.Unlock() + found := d.Detect(ctx, netMon, dm, preferredDERP) + if found { + if !b.health.IsUnhealthy(captivePortalWarnable) { + metricCaptivePortalDetected.Add(1) + } + b.health.SetUnhealthy(captivePortalWarnable, health.Args{}) + } else { + b.health.SetHealthy(captivePortalWarnable) + } +} diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index d87374bbb..d7133d25e 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !js +//go:build !js && !ts_omit_acme package ipnlocal @@ -24,37 +24,45 @@ import ( "log" randv2 "math/rand/v2" "net" + "net/http" "os" "path/filepath" "runtime" "slices" "strings" - "sync" "time" - "github.com/tailscale/golang-x-crypto/acme" "tailscale.com/atomicfile" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store" "tailscale.com/ipn/store/mem" + "tailscale.com/net/bakedroots" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/tempfork/acme" "tailscale.com/types/logger" "tailscale.com/util/testenv" "tailscale.com/version" "tailscale.com/version/distro" ) +func init() { + RegisterC2N("GET /tls-cert-status", handleC2NTLSCertStatus) +} + // Process-wide cache. (A new *Handler is created per connection, // effectively per request) var ( // acmeMu guards all ACME operations, so concurrent requests // for certs don't slam ACME. The first will go through and // populate the on-disk cache and the rest should use that. - acmeMu sync.Mutex + acmeMu syncs.Mutex - renewMu sync.Mutex // lock order: acmeMu before renewMu + renewMu syncs.Mutex // lock order: acmeMu before renewMu renewCertAt = map[string]time.Time{} ) @@ -66,7 +74,7 @@ func (b *LocalBackend) certDir() (string, error) { // As a workaround for Synology DSM6 not having a "var" directory, use the // app's "etc" directory (on a small partition) to hold certs at least. // See https://github.com/tailscale/tailscale/issues/4060#issuecomment-1186592251 - if d == "" && runtime.GOOS == "linux" && distro.Get() == distro.Synology && distro.DSMVersion() == 6 { + if buildfeatures.HasSynology && d == "" && runtime.GOOS == "linux" && distro.Get() == distro.Synology && distro.DSMVersion() == 6 { d = "/var/packages/Tailscale/etc" // base; we append "certs" below } if d == "" { @@ -118,6 +126,9 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string } if pair, err := getCertPEMCached(cs, domain, now); err == nil { + if envknob.IsCertShareReadOnlyMode() { + return pair, nil + } // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) @@ -133,7 +144,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if minValidity == 0 { logf("starting async renewal") // Start renewal in the background, return current valid cert. - go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) + b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) return pair, nil } // If the caller requested a specific validity duration, fall through @@ -141,7 +152,11 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string logf("starting sync renewal") } - pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) + if envknob.IsCertShareReadOnlyMode() { + return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err) + } + + pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -249,15 +264,13 @@ type certStore interface { // for now. If they're expired, it returns errCertExpired. // If they don't exist, it returns ipn.ErrStateNotExist. Read(domain string, now time.Time) (*TLSCertKeyPair, error) - // WriteCert writes the cert for domain. - WriteCert(domain string, cert []byte) error - // WriteKey writes the key for domain. - WriteKey(domain string, key []byte) error // ACMEKey returns the value previously stored via WriteACMEKey. // It is a PEM encoded ECDSA key. ACMEKey() ([]byte, error) // WriteACMEKey stores the provided PEM encoded ECDSA key. WriteACMEKey([]byte) error + // WriteTLSCertAndKey writes the cert and key for domain. + WriteTLSCertAndKey(domain string, cert, key []byte) error } var errCertExpired = errors.New("cert expired") @@ -343,6 +356,13 @@ func (f certFileStore) WriteKey(domain string, key []byte) error { return atomicfile.WriteFile(keyFile(f.dir, domain), key, 0600) } +func (f certFileStore) WriteTLSCertAndKey(domain string, cert, key []byte) error { + if err := f.WriteKey(domain, key); err != nil { + return err + } + return f.WriteCert(domain, cert) +} + // certStateStore implements certStore by storing the cert & key files in an ipn.StateStore. type certStateStore struct { ipn.StateStore @@ -352,7 +372,29 @@ type certStateStore struct { testRoots *x509.CertPool } +// TLSCertKeyReader is an interface implemented by state stores where it makes +// sense to read the TLS cert and key in a single operation that can be +// distinguished from generic state value reads. Currently this is only implemented +// by the kubestore.Store, which, in some cases, need to read cert and key from a +// non-cached TLS Secret. +type TLSCertKeyReader interface { + ReadTLSCertAndKey(domain string) ([]byte, []byte, error) +} + func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { + // If we're using a store that supports atomic reads, use that + if kr, ok := s.StateStore.(TLSCertKeyReader); ok { + cert, key, err := kr.ReadTLSCertAndKey(domain) + if err != nil { + return nil, err + } + if !validCertPEM(domain, key, cert, s.testRoots, now) { + return nil, errCertExpired + } + return &TLSCertKeyPair{CertPEM: cert, KeyPEM: key, Cached: true}, nil + } + + // Otherwise fall back to separate reads certPEM, err := s.ReadState(ipn.StateKey(domain + ".crt")) if err != nil { return nil, err @@ -383,6 +425,27 @@ func (s certStateStore) WriteACMEKey(key []byte) error { return ipn.WriteState(s.StateStore, ipn.StateKey(acmePEMName), key) } +// TLSCertKeyWriter is an interface implemented by state stores that can write the TLS +// cert and key in a single atomic operation. Currently this is only implemented +// by the kubestore.StoreKube. +type TLSCertKeyWriter interface { + WriteTLSCertAndKey(domain string, cert, key []byte) error +} + +// WriteTLSCertAndKey writes the TLS cert and key for domain to the current +// LocalBackend's StateStore. +func (s certStateStore) WriteTLSCertAndKey(domain string, cert, key []byte) error { + // If we're using a store that supports atomic writes, use that. + if aw, ok := s.StateStore.(TLSCertKeyWriter); ok { + return aw.WriteTLSCertAndKey(domain, cert, key) + } + // Otherwise fall back to separate writes for cert and key. + if err := s.WriteKey(domain, key); err != nil { + return err + } + return s.WriteCert(domain, cert) +} + // TLSCertKeyPair is a TLS public and private key, and whether they were obtained // from cache or freshly obtained. type TLSCertKeyPair struct { @@ -419,21 +482,24 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { +// getCertPem checks if a cert needs to be renewed and if so, renews it. +// It can be overridden in tests. +var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() // In case this method was triggered multiple times in parallel (when // serving incoming requests), check whether one of the other goroutines // already renewed the cert before us. - if p, err := getCertPEMCached(cs, domain, now); err == nil { + previous, err := getCertPEMCached(cs, domain, now) + if err == nil { // shouldStartDomainRenewal caches its result so it's OK to call this // frequently. - shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p, minValidity) + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, previous, minValidity) if err != nil { logf("error checking for certificate renewal: %v", err) } else if !shouldRenew { - return p, nil + return previous, nil } } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { return nil, err @@ -444,6 +510,10 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } + if !isDefaultDirectoryURL(ac.DirectoryURL) { + logf("acme: using Directory URL %q", ac.DirectoryURL) + } + a, err := ac.GetReg(ctx, "" /* pre-RFC param */) switch { case err == nil: @@ -474,7 +544,17 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } - order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: domain}}) + // If we have a previous cert, include it in the order. Assuming we're + // within the ARI renewal window this should exclude us from LE rate + // limits. + var opts []acme.OrderOption + if previous != nil { + prevCrt, err := previous.parseCertificate() + if err == nil { + opts = append(opts, acme.WithOrderReplacesCert(prevCrt)) + } + } + order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: domain}}, opts...) if err != nil { return nil, err } @@ -545,9 +625,6 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil { return nil, err } - if err := cs.WriteKey(domain, privPEM.Bytes()); err != nil { - return nil, err - } csr, err := certRequest(certPrivKey, domain, nil) if err != nil { @@ -555,6 +632,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger } logf("requesting cert...") + traceACME(csr) der, _, err := ac.CreateOrderCert(ctx, order.FinalizeURL, csr, true) if err != nil { return nil, fmt.Errorf("CreateOrder: %v", err) @@ -568,7 +646,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } } - if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { + if err := cs.WriteTLSCertAndKey(domain, certPEM.Bytes(), privPEM.Bytes()); err != nil { return nil, err } b.domainRenewed(domain) @@ -577,10 +655,10 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger } // certRequest generates a CSR for the given common name cn and optional SANs. -func certRequest(key crypto.Signer, cn string, ext []pkix.Extension, san ...string) ([]byte, error) { +func certRequest(key crypto.Signer, name string, ext []pkix.Extension) ([]byte, error) { req := &x509.CertificateRequest{ - Subject: pkix.Name{CommonName: cn}, - DNSNames: san, + Subject: pkix.Name{CommonName: name}, + DNSNames: []string{name}, ExtraExtensions: ext, } return x509.CreateCertificateRequest(rand.Reader, req, key) @@ -657,15 +735,16 @@ func acmeClient(cs certStore) (*acme.Client, error) { // LetsEncrypt), we should make sure that they support ARI extension (see // shouldStartDomainRenewalARI). return &acme.Client{ - Key: key, - UserAgent: "tailscaled/" + version.Long(), + Key: key, + UserAgent: "tailscaled/" + version.Long(), + DirectoryURL: envknob.String("TS_DEBUG_ACME_DIRECTORY_URL"), }, nil } // validCertPEM reports whether the given certificate is valid for domain at now. // // If roots != nil, it is used instead of the system root pool. This is meant -// to support testing, and production code should pass roots == nil. +// to support testing; production code should pass roots == nil. func validCertPEM(domain string, keyPEM, certPEM []byte, roots *x509.CertPool, now time.Time) bool { if len(keyPEM) == 0 || len(certPEM) == 0 { return false @@ -688,16 +767,51 @@ func validCertPEM(domain string, keyPEM, certPEM []byte, roots *x509.CertPool, n intermediates.AddCert(cert) } } + return validateLeaf(leaf, intermediates, domain, now, roots) +} + +// validateLeaf is a helper for [validCertPEM]. +// +// If called with roots == nil, it will use the system root pool as well as the +// baked-in roots. If non-nil, only those roots are used. +func validateLeaf(leaf *x509.Certificate, intermediates *x509.CertPool, domain string, now time.Time, roots *x509.CertPool) bool { if leaf == nil { return false } - _, err = leaf.Verify(x509.VerifyOptions{ + _, err := leaf.Verify(x509.VerifyOptions{ DNSName: domain, CurrentTime: now, Roots: roots, Intermediates: intermediates, }) - return err == nil + if err != nil && roots == nil { + // If validation failed and they specified nil for roots (meaning to use + // the system roots), then give it another chance to validate using the + // binary's baked-in roots (LetsEncrypt). See tailscale/tailscale#14690. + return validateLeaf(leaf, intermediates, domain, now, bakedroots.Get()) + } + + if err == nil { + return true + } + + // When pointed at a non-prod ACME server, we don't expect to have the CA + // in our system or baked-in roots. Verify only throws UnknownAuthorityError + // after first checking the leaf cert's expiry, hostnames etc, so we know + // that the only reason for an error is to do with constructing a full chain. + // Allow this error so that cert caching still works in testing environments. + if errors.As(err, &x509.UnknownAuthorityError{}) { + acmeURL := envknob.String("TS_DEBUG_ACME_DIRECTORY_URL") + if !isDefaultDirectoryURL(acmeURL) { + return true + } + } + + return false +} + +func isDefaultDirectoryURL(u string) bool { + return u == "" || u == acme.LetsEncryptURL } // validLookingCertDomain reports whether name looks like a valid domain name that @@ -729,3 +843,54 @@ func checkCertDomain(st *ipnstate.Status, domain string) error { } return fmt.Errorf("invalid domain %q; must be one of %q", domain, st.CertDomains) } + +// handleC2NTLSCertStatus returns info about the last TLS certificate issued for the +// provided domain. This can be called by the controlplane to clean up DNS TXT +// records when they're no longer needed by LetsEncrypt. +// +// It does not kick off a cert fetch or async refresh. It only reports anything +// that's already sitting on disk, and only reports metadata about the public +// cert (stuff that'd be the in CT logs anyway). +func handleC2NTLSCertStatus(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + cs, err := b.getCertStore() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + domain := r.FormValue("domain") + if domain == "" { + http.Error(w, "no 'domain'", http.StatusBadRequest) + return + } + + ret := &tailcfg.C2NTLSCertInfo{} + pair, err := getCertPEMCached(cs, domain, b.clock.Now()) + ret.Valid = err == nil + if err != nil { + ret.Error = err.Error() + if errors.Is(err, errCertExpired) { + ret.Expired = true + } else if errors.Is(err, ipn.ErrStateNotExist) { + ret.Missing = true + ret.Error = "no certificate" + } + } else { + block, _ := pem.Decode(pair.CertPEM) + if block == nil { + ret.Error = "invalid PEM" + ret.Valid = false + } else { + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + ret.Error = fmt.Sprintf("invalid certificate: %v", err) + ret.Valid = false + } else { + ret.NotBefore = cert.NotBefore.UTC().Format(time.RFC3339) + ret.NotAfter = cert.NotAfter.UTC().Format(time.RFC3339) + } + } + } + + writeJSON(w, ret) +} diff --git a/ipn/ipnlocal/cert_js.go b/ipn/ipnlocal/cert_disabled.go similarity index 51% rename from ipn/ipnlocal/cert_js.go rename to ipn/ipnlocal/cert_disabled.go index 6acc57a60..17d446c11 100644 --- a/ipn/ipnlocal/cert_js.go +++ b/ipn/ipnlocal/cert_disabled.go @@ -1,20 +1,30 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build js || ts_omit_acme + package ipnlocal import ( "context" "errors" + "io" + "net/http" "time" ) +func init() { + RegisterC2N("GET /tls-cert-status", handleC2NTLSCertStatusDisabled) +} + +var errNoCerts = errors.New("cert support not compiled in this build") + type TLSCertKeyPair struct { CertPEM, KeyPEM []byte } func (b *LocalBackend) GetCertPEM(ctx context.Context, domain string) (*TLSCertKeyPair, error) { - return nil, errors.New("not implemented for js/wasm") + return nil, errNoCerts } var errCertExpired = errors.New("cert expired") @@ -22,9 +32,14 @@ var errCertExpired = errors.New("cert expired") type certStore interface{} func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKeyPair, err error) { - return nil, errors.New("not implemented for js/wasm") + return nil, errNoCerts } func (b *LocalBackend) getCertStore() (certStore, error) { - return nil, errors.New("not implemented for js/wasm") + return nil, errNoCerts +} + +func handleC2NTLSCertStatusDisabled(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, `{"Missing":true}`) // a minimal tailcfg.C2NTLSCertInfo } diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 3ae7870e3..e2398f670 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -6,6 +6,7 @@ package ipnlocal import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -14,11 +15,17 @@ import ( "embed" "encoding/pem" "math/big" + "os" + "path/filepath" "testing" "time" "github.com/google/go-cmp/cmp" + "tailscale.com/envknob" "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" + "tailscale.com/util/must" ) func TestValidLookingCertDomain(t *testing.T) { @@ -47,10 +54,10 @@ var certTestFS embed.FS func TestCertStoreRoundTrip(t *testing.T) { const testDomain = "example.com" - // Use a fixed verification timestamp so validity doesn't fall off when the - // cert expires. If you update the test data below, this may also need to be - // updated. + // Use fixed verification timestamps so validity doesn't change over time. + // If you update the test data below, these may also need to be updated. testNow := time.Date(2023, time.February, 10, 0, 0, 0, 0, time.UTC) + testExpired := time.Date(2026, time.February, 10, 0, 0, 0, 0, time.UTC) // To re-generate a root certificate and domain certificate for testing, // use: @@ -78,21 +85,23 @@ func TestCertStoreRoundTrip(t *testing.T) { } tests := []struct { - name string - store certStore + name string + store certStore + debugACMEURL bool }{ - {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}}, - {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}}, + {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}, false}, + {"FileStore_UnknownCA", certFileStore{dir: t.TempDir()}, true}, + {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}, false}, + {"StateStore_UnknownCA", certStateStore{StateStore: new(mem.Store)}, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if err := test.store.WriteCert(testDomain, testCert); err != nil { - t.Fatalf("WriteCert: unexpected error: %v", err) + if test.debugACMEURL { + t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", "https://acme-staging-v02.api.letsencrypt.org/directory") } - if err := test.store.WriteKey(testDomain, testKey); err != nil { - t.Fatalf("WriteKey: unexpected error: %v", err) + if err := test.store.WriteTLSCertAndKey(testDomain, testCert, testKey); err != nil { + t.Fatalf("WriteTLSCertAndKey: unexpected error: %v", err) } - kp, err := test.store.Read(testDomain, testNow) if err != nil { t.Fatalf("Read: unexpected error: %v", err) @@ -103,6 +112,10 @@ func TestCertStoreRoundTrip(t *testing.T) { if diff := cmp.Diff(kp.KeyPEM, testKey); diff != "" { t.Errorf("Key (-got, +want):\n%s", diff) } + unexpected, err := test.store.Read(testDomain, testExpired) + if err != errCertExpired { + t.Fatalf("Read: expected expiry error: %v", string(unexpected.CertPEM)) + } }) } } @@ -199,3 +212,167 @@ func TestShouldStartDomainRenewal(t *testing.T) { }) } } + +func TestDebugACMEDirectoryURL(t *testing.T) { + for _, tc := range []string{"", "https://acme-staging-v02.api.letsencrypt.org/directory"} { + const setting = "TS_DEBUG_ACME_DIRECTORY_URL" + t.Run(tc, func(t *testing.T) { + t.Setenv(setting, tc) + ac, err := acmeClient(certStateStore{StateStore: new(mem.Store)}) + if err != nil { + t.Fatalf("acmeClient creation err: %v", err) + } + if ac.DirectoryURL != tc { + t.Fatalf("acmeClient.DirectoryURL = %q, want %q", ac.DirectoryURL, tc) + } + }) + } +} + +func TestGetCertPEMWithValidity(t *testing.T) { + const testDomain = "example.com" + b := &LocalBackend{ + store: &mem.Store{}, + varRoot: t.TempDir(), + ctx: context.Background(), + logf: t.Logf, + } + certDir, err := b.certDir() + if err != nil { + t.Fatalf("certDir error: %v", err) + } + if _, err := b.getCertStore(); err != nil { + t.Fatalf("getCertStore error: %v", err) + } + testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem") + if err != nil { + t.Fatal(err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(testRoot) { + t.Fatal("Unable to add test CA to the cert pool") + } + testX509Roots = roots + defer func() { testX509Roots = nil }() + tests := []struct { + name string + now time.Time + // storeCerts is true if the test cert and key should be written to store. + storeCerts bool + readOnlyMode bool // TS_READ_ONLY_CERTS env var + wantAsyncRenewal bool // async issuance should be started + wantIssuance bool // sync issuance should be started + wantErr bool + }{ + { + name: "valid_no_renewal", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "issuance_needed", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: false, + wantAsyncRenewal: false, + wantIssuance: true, + wantErr: false, + }, + { + name: "renewal_needed", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: true, + wantIssuance: false, + wantErr: false, + }, + { + name: "renewal_needed_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "no_certs_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: false, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.readOnlyMode { + envknob.Setenv("TS_CERT_SHARE_MODE", "ro") + } + + os.RemoveAll(certDir) + if tt.storeCerts { + os.MkdirAll(certDir, 0755) + if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"), + must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(certDir, "example.com.key"), + must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil { + t.Fatal(err) + } + } + + b.clock = tstest.NewClock(tstest.ClockOpts{Start: tt.now}) + + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + // Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async + // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. + getCertPemWasCalled := false + getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { + getCertPemWasCalled = true + return nil, nil + } + prevGoRoutines := b.goTracker.StartedGoroutines() + _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0) + if (err != nil) != tt.wantErr { + t.Errorf("b.GetCertPemWithValidity got err %v, wants error: '%v'", err, tt.wantErr) + } + // GetCertPEMWithValidity calls getCertPEM in a goroutine if async renewal is needed. That's the + // only goroutine it starts, so this can be used to test if async renewal was started. + gotAsyncRenewal := b.goTracker.StartedGoroutines()-prevGoRoutines != 0 + if gotAsyncRenewal { + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutines to finish") + case <-allDone: + } + } + // Verify that async renewal was triggered if expected. + if tt.wantAsyncRenewal != gotAsyncRenewal { + t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal) + } + // Verify that (non-async) issuance was started if expected. + gotIssuance := getCertPemWasCalled && !gotAsyncRenewal + if tt.wantIssuance != gotIssuance { + t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance) + } + }) + } +} diff --git a/ipn/ipnlocal/dnsconfig_test.go b/ipn/ipnlocal/dnsconfig_test.go index 19d8e8b86..e23d8a057 100644 --- a/ipn/ipnlocal/dnsconfig_test.go +++ b/ipn/ipnlocal/dnsconfig_test.go @@ -70,8 +70,8 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "self_name_and_peers", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("100.101.101.101"), }).View(), }, @@ -109,15 +109,15 @@ func TestDNSConfigForNetmap(t *testing.T) { // even if they have IPv4. name: "v6_only_self", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("fe75::1"), }).View(), }, peers: nodeViews([]*tailcfg.Node{ { ID: 1, - Name: "peera.net", + Name: "peera.net.", Addresses: ipps("100.102.0.1", "100.102.0.2", "fe75::1001"), }, { @@ -146,8 +146,8 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "extra_records", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("100.101.101.101"), }).View(), DNS: tailcfg.DNSConfig{ @@ -171,7 +171,9 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "corp_dns_misc", nm: &netmap.NetworkMap{ - Name: "host.some.domain.net.", + SelfNode: (&tailcfg.Node{ + Name: "host.some.domain.net.", + }).View(), DNS: tailcfg.DNSConfig{ Proxied: true, Domains: []string{"foo.com", "bar.com"}, @@ -331,8 +333,8 @@ func TestDNSConfigForNetmap(t *testing.T) { { name: "self_expired", nm: &netmap.NetworkMap{ - Name: "myname.net", SelfNode: (&tailcfg.Node{ + Name: "myname.net.", Addresses: ipps("100.101.101.101"), }).View(), }, @@ -377,19 +379,19 @@ func peersMap(s []tailcfg.NodeView) map[tailcfg.NodeID]tailcfg.NodeView { } func TestAllowExitNodeDNSProxyToServeName(t *testing.T) { - b := &LocalBackend{} + b := newTestLocalBackend(t) if b.allowExitNodeDNSProxyToServeName("google.com") { t.Fatal("unexpected true on backend with nil NetMap") } - b.netMap = &netmap.NetworkMap{ + b.currentNode().SetNetMap(&netmap.NetworkMap{ DNS: tailcfg.DNSConfig{ ExitNodeFilteredSet: []string{ ".ts.net", "some.exact.bad", }, }, - } + }) tests := []struct { name string want bool diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 98d563d87..7d6dc2427 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -1,51 +1,35 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_drive + package ipnlocal import ( - "cmp" + "errors" "fmt" + "io" + "net/http" + "net/netip" "os" "slices" "tailscale.com/drive" "tailscale.com/ipn" "tailscale.com/tailcfg" + "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" + "tailscale.com/util/httpm" ) -const ( - // DriveLocalPort is the port on which the Taildrive listens for location - // connections on quad 100. - DriveLocalPort = 8080 -) - -// DriveSharingEnabled reports whether sharing to remote nodes via Taildrive is -// enabled. This is currently based on checking for the drive:share node -// attribute. -func (b *LocalBackend) DriveSharingEnabled() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.driveSharingEnabledLocked() -} - -func (b *LocalBackend) driveSharingEnabledLocked() bool { - return b.netMap != nil && b.netMap.SelfNode.HasCap(tailcfg.NodeAttrsTaildriveShare) +func init() { + hookSetNetMapLockedDrive.Set(setNetMapLockedDrive) } -// DriveAccessEnabled reports whether accessing Taildrive shares on remote nodes -// is enabled. This is currently based on checking for the drive:access node -// attribute. -func (b *LocalBackend) DriveAccessEnabled() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.driveAccessEnabledLocked() -} - -func (b *LocalBackend) driveAccessEnabledLocked() bool { - return b.netMap != nil && b.netMap.SelfNode.HasCap(tailcfg.NodeAttrsTaildriveAccess) +func setNetMapLockedDrive(b *LocalBackend, nm *netmap.NetworkMap) { + b.updateDrivePeersLocked(nm) + b.driveNotifyCurrentSharesLocked() } // DriveSetServerAddr tells Taildrive to use the given address for connecting @@ -266,7 +250,7 @@ func (b *LocalBackend) driveNotifyShares(shares views.SliceView[*drive.Share, dr // shares has changed since the last notification. func (b *LocalBackend) driveNotifyCurrentSharesLocked() { var shares views.SliceView[*drive.Share, drive.ShareView] - if b.driveSharingEnabledLocked() { + if b.DriveSharingEnabled() { // Only populate shares if sharing is enabled. shares = b.pm.prefs.DriveShares() } @@ -310,61 +294,203 @@ func (b *LocalBackend) updateDrivePeersLocked(nm *netmap.NetworkMap) { } var driveRemotes []*drive.Remote - if b.driveAccessEnabledLocked() { + if b.DriveAccessEnabled() { // Only populate peers if access is enabled, otherwise leave blank. driveRemotes = b.driveRemotesFromPeers(nm) } - fs.SetRemotes(b.netMap.Domain, driveRemotes, b.newDriveTransport()) + fs.SetRemotes(nm.Domain, driveRemotes, b.newDriveTransport()) } func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Remote { + b.logf("[v1] taildrive: setting up drive remotes from peers") driveRemotes := make([]*drive.Remote, 0, len(nm.Peers)) for _, p := range nm.Peers { - peerID := p.ID() - url := fmt.Sprintf("%s/%s", peerAPIBase(nm, p), taildrivePrefix[1:]) + peer := p + peerID := peer.ID() + peerKey := peer.Key().ShortString() + b.logf("[v1] taildrive: appending remote for peer %s", peerKey) driveRemotes = append(driveRemotes, &drive.Remote{ Name: p.DisplayName(false), - URL: url, + URL: func() string { + url := fmt.Sprintf("%s/%s", b.currentNode().PeerAPIBase(peer), taildrivePrefix[1:]) + b.logf("[v2] taildrive: url for peer %s: %s", peerKey, url) + return url + }, Available: func() bool { // Peers are available to Taildrive if: // - They are online + // - Their PeerAPI is reachable // - They are allowed to share at least one folder with us - b.mu.Lock() - latestNetMap := b.netMap - b.mu.Unlock() - - idx, found := slices.BinarySearchFunc(latestNetMap.Peers, peerID, func(candidate tailcfg.NodeView, id tailcfg.NodeID) int { - return cmp.Compare(candidate.ID(), id) - }) - if !found { + cn := b.currentNode() + peer, ok := cn.NodeByID(peerID) + if !ok { + b.logf("[v2] taildrive: Available(): peer %s not found", peerKey) return false } - peer := latestNetMap.Peers[idx] - // Exclude offline peers. // TODO(oxtoacart): for some reason, this correctly // catches when a node goes from offline to online, // but not the other way around... - online := peer.Online() - if online == nil || !*online { + // TODO(oxtoacart,nickkhyl): the reason was probably + // that we were using netmap.Peers instead of b.peers. + // The netmap.Peers slice is not updated in all cases. + // It should be fixed now that we use PeerByIDOk. + if !peer.Online().Get() { + b.logf("[v2] taildrive: Available(): peer %s offline", peerKey) + return false + } + + if b.currentNode().PeerAPIBase(peer) == "" { + b.logf("[v2] taildrive: Available(): peer %s PeerAPI unreachable", peerKey) return false } // Check that the peer is allowed to share with us. - addresses := peer.Addresses() - for i := range addresses.Len() { - addr := addresses.At(i) - capsMap := b.PeerCaps(addr.Addr()) - if capsMap.HasCapability(tailcfg.PeerCapabilityTaildriveSharer) { - return true - } + if cn.PeerHasCap(peer, tailcfg.PeerCapabilityTaildriveSharer) { + b.logf("[v2] taildrive: Available(): peer %s available", peerKey) + return true } + b.logf("[v2] taildrive: Available(): peer %s not allowed to share", peerKey) return false }, }) } return driveRemotes } + +// responseBodyWrapper wraps an io.ReadCloser and stores +// the number of bytesRead. +type responseBodyWrapper struct { + io.ReadCloser + logVerbose bool + bytesRx int64 + bytesTx int64 + log logger.Logf + method string + statusCode int + contentType string + fileExtension string + shareNodeKey string + selfNodeKey string + contentLength int64 +} + +// logAccess logs the taildrive: access: log line. If the logger is nil, +// the log will not be written. +func (rbw *responseBodyWrapper) logAccess(err string) { + if rbw.log == nil { + return + } + + // Some operating systems create and copy lots of 0 length hidden files for + // tracking various states. Omit these to keep logs from being too verbose. + if rbw.logVerbose || rbw.contentLength > 0 { + levelPrefix := "" + if rbw.logVerbose { + levelPrefix = "[v1] " + } + rbw.log( + "%staildrive: access: %s from %s to %s: status-code=%d ext=%q content-type=%q content-length=%.f tx=%.f rx=%.f err=%q", + levelPrefix, + rbw.method, + rbw.selfNodeKey, + rbw.shareNodeKey, + rbw.statusCode, + rbw.fileExtension, + rbw.contentType, + roundTraffic(rbw.contentLength), + roundTraffic(rbw.bytesTx), roundTraffic(rbw.bytesRx), err) + } +} + +// Read implements the io.Reader interface. +func (rbw *responseBodyWrapper) Read(b []byte) (int, error) { + n, err := rbw.ReadCloser.Read(b) + rbw.bytesRx += int64(n) + if err != nil && !errors.Is(err, io.EOF) { + rbw.logAccess(err.Error()) + } + + return n, err +} + +// Close implements the io.Close interface. +func (rbw *responseBodyWrapper) Close() error { + err := rbw.ReadCloser.Close() + var errStr string + if err != nil { + errStr = err.Error() + } + rbw.logAccess(errStr) + + return err +} + +// driveTransport is an http.RoundTripper that wraps +// b.Dialer().PeerAPITransport() with metrics tracking. +type driveTransport struct { + b *LocalBackend + tr *http.Transport +} + +func (b *LocalBackend) newDriveTransport() *driveTransport { + return &driveTransport{ + b: b, + tr: b.Dialer().PeerAPITransport(), + } +} + +func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { + // Some WebDAV clients include origin and refer headers, which peerapi does + // not like. Remove them. + req.Header.Del("origin") + req.Header.Del("referer") + + bw := &requestBodyWrapper{} + if req.Body != nil { + bw.ReadCloser = req.Body + req.Body = bw + } + + defer func() { + contentType := "unknown" + if ct := req.Header.Get("Content-Type"); ct != "" { + contentType = ct + } + + dt.b.mu.Lock() + selfNodeKey := dt.b.currentNode().Self().Key().ShortString() + dt.b.mu.Unlock() + n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) + shareNodeKey := "unknown" + if ok { + shareNodeKey = string(n.Key().ShortString()) + } + + rbw := responseBodyWrapper{ + log: dt.b.logf, + logVerbose: req.Method != httpm.GET && req.Method != httpm.PUT, // other requests like PROPFIND are quite chatty, so we log those at verbose level + method: req.Method, + bytesTx: int64(bw.bytesRead), + selfNodeKey: selfNodeKey, + shareNodeKey: shareNodeKey, + contentType: contentType, + contentLength: resp.ContentLength, + fileExtension: parseDriveFileExtensionForLog(req.URL.Path), + statusCode: resp.StatusCode, + ReadCloser: resp.Body, + } + + if resp.StatusCode >= 400 { + // in case of error response, just log immediately + rbw.logAccess("") + } else { + resp.Body = &rbw + } + }() + + return dt.tr.RoundTrip(req) +} diff --git a/ipn/ipnlocal/drive_tomove.go b/ipn/ipnlocal/drive_tomove.go new file mode 100644 index 000000000..290fe0970 --- /dev/null +++ b/ipn/ipnlocal/drive_tomove.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This is the Taildrive stuff that should ideally be registered in init only when +// the ts_omit_drive is not set, but for transition reasons is currently (2025-09-08) +// always defined, as we work to pull it out of LocalBackend. + +package ipnlocal + +import "tailscale.com/tailcfg" + +const ( + // DriveLocalPort is the port on which the Taildrive listens for location + // connections on quad 100. + DriveLocalPort = 8080 +) + +// DriveSharingEnabled reports whether sharing to remote nodes via Taildrive is +// enabled. This is currently based on checking for the drive:share node +// attribute. +func (b *LocalBackend) DriveSharingEnabled() bool { + return b.currentNode().SelfHasCap(tailcfg.NodeAttrsTaildriveShare) +} + +// DriveAccessEnabled reports whether accessing Taildrive shares on remote nodes +// is enabled. This is currently based on checking for the drive:access node +// attribute. +func (b *LocalBackend) DriveAccessEnabled() bool { + return b.currentNode().SelfHasCap(tailcfg.NodeAttrsTaildriveAccess) +} diff --git a/ipn/ipnlocal/expiry.go b/ipn/ipnlocal/expiry.go index 04c10226d..8ea63d21a 100644 --- a/ipn/ipnlocal/expiry.go +++ b/ipn/ipnlocal/expiry.go @@ -6,12 +6,14 @@ package ipnlocal import ( "time" + "tailscale.com/control/controlclient" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" ) // For extra defense-in-depth, when we're testing expired nodes we check @@ -40,14 +42,22 @@ type expiryManager struct { logf logger.Logf clock tstime.Clock + + eventClient *eventbus.Client } -func newExpiryManager(logf logger.Logf) *expiryManager { - return &expiryManager{ +func newExpiryManager(logf logger.Logf, bus *eventbus.Bus) *expiryManager { + em := &expiryManager{ previouslyExpired: map[tailcfg.StableNodeID]bool{}, logf: logf, clock: tstime.StdClock{}, } + + em.eventClient = bus.Client("ipnlocal.expiryManager") + eventbus.SubscribeFunc(em.eventClient, func(ct controlclient.ControlTime) { + em.onControlTime(ct.Value) + }) + return em } // onControlTime is called whenever we receive a new timestamp from the control @@ -116,7 +126,7 @@ func (em *expiryManager) flagExpiredPeers(netmap *netmap.NetworkMap, localNow ti // since we discover endpoints via DERP, and due to DERP return // path optimization. mut.Endpoints = nil - mut.DERP = "" + mut.HomeDERP = 0 // Defense-in-depth: break the node's public key as well, in // case something tries to communicate. @@ -218,6 +228,8 @@ func (em *expiryManager) nextPeerExpiry(nm *netmap.NetworkMap, localNow time.Tim return nextExpiry } +func (em *expiryManager) close() { em.eventClient.Close() } + // ControlNow estimates the current time on the control server, calculated as // localNow + the delta between local and control server clocks as recorded // when the LocalBackend last received a time message from the control server. diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index af1aa337b..2c646ca72 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -14,6 +14,7 @@ import ( "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus/eventbustest" ) func TestFlagExpiredPeers(t *testing.T) { @@ -110,7 +111,8 @@ func TestFlagExpiredPeers(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) + bus := eventbustest.NewBus(t) + em := newExpiryManager(t.Logf, bus) em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) if tt.controlTime != nil { em.onControlTime(*tt.controlTime) @@ -240,7 +242,8 @@ func TestNextPeerExpiry(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - em := newExpiryManager(t.Logf) + bus := eventbustest.NewBus(t) + em := newExpiryManager(t.Logf, bus) em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) got := em.nextPeerExpiry(tt.netmap, now) if !got.Equal(tt.want) { @@ -253,7 +256,8 @@ func TestNextPeerExpiry(t *testing.T) { t.Run("ClockSkew", func(t *testing.T) { t.Logf("local time: %q", now.Format(time.RFC3339)) - em := newExpiryManager(t.Logf) + bus := eventbustest.NewBus(t) + em := newExpiryManager(t.Logf, bus) em.clock = tstest.NewClock(tstest.ClockOpts{Start: now}) // The local clock is "running fast"; our clock skew is -2h @@ -283,11 +287,11 @@ func formatNodes(nodes []tailcfg.NodeView) string { } fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) + if online, ok := n.Online().GetOk(); ok { + fmt.Fprintf(&sb, ", online=%v", online) } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + if lastSeen, ok := n.LastSeen().GetOk(); ok { + fmt.Fprintf(&sb, ", lastSeen=%v", lastSeen.Unix()) } if n.Key() != (key.NodePublic{}) { fmt.Fprintf(&sb, ", key=%v", n.Key().String()) diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go new file mode 100644 index 000000000..ca802ab89 --- /dev/null +++ b/ipn/ipnlocal/extension_host.go @@ -0,0 +1,621 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "errors" + "fmt" + "maps" + "reflect" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/execqueue" + "tailscale.com/util/mak" + "tailscale.com/util/testenv" +) + +// ExtensionHost is a bridge between the [LocalBackend] and the registered [ipnext.Extension]s. +// It implements [ipnext.Host] and is safe for concurrent use. +// +// A nil pointer to [ExtensionHost] is a valid, no-op extension host which is primarily used in tests +// that instantiate [LocalBackend] directly without using [NewExtensionHost]. +// +// The [LocalBackend] is not required to hold its mutex when calling the host's methods, +// but it typically does so either to prevent changes to its state (for example, the current profile) +// while callbacks are executing, or because it calls the host's methods as part of a larger operation +// that requires the mutex to be held. +// +// Extensions might invoke the host's methods either from callbacks triggered by the [LocalBackend], +// or in a response to external events. Some methods can be called by both the extensions and the backend. +// +// As a general rule, the host cannot assume anything about the current state of the [LocalBackend]'s +// internal mutex on entry to its methods, and therefore cannot safely call [LocalBackend] methods directly. +// +// The following are typical and supported patterns: +// - LocalBackend notifies the host about an event, such as a change in the current profile. +// The host invokes callbacks registered by Extensions, forwarding the event arguments to them. +// If necessary, the host can also update its own state for future use. +// - LocalBackend requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the [ipn.LoginProfile] to use when no GUI/CLI client is connected. Typically, [LocalBackend] +// provides the required context to the host, and the host returns the result to [LocalBackend] +// after forwarding the request to the extensions. +// - Extension invokes the host's method to perform an action, such as switching to the "best" profile +// in response to a change in the device's state. Since the host does not know whether the [LocalBackend]'s +// internal mutex is held, it cannot invoke any methods on the [LocalBackend] directly and must instead +// do so asynchronously, such as by using [ExtensionHost.enqueueBackendOperation]. +// - Extension requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the current [ipn.LoginProfile]. Since the host cannot invoke any methods on the [LocalBackend] directly, +// it should maintain its own view of the current state, updating it when the [LocalBackend] notifies it +// about a change or event. +// +// To safeguard against adopting incorrect or risky patterns, the host does not store [LocalBackend] in its fields +// and instead provides [ExtensionHost.enqueueBackendOperation]. Additionally, to make it easier to test extensions +// and to further reduce the risk of accessing unexported methods or fields of [LocalBackend], the host interacts +// with it via the [Backend] interface. +type ExtensionHost struct { + b Backend + hooks ipnext.Hooks + logf logger.Logf // prefixed with "ipnext:" + + // allExtensions holds the extensions in the order they were registered, + // including those that have not yet attempted initialization or have failed to initialize. + allExtensions []ipnext.Extension + + // initOnce is used to ensure that the extensions are initialized only once, + // even if [extensionHost.Init] is called multiple times. + initOnce sync.Once + initDone atomic.Bool + // shutdownOnce is like initOnce, but for [ExtensionHost.Shutdown]. + shutdownOnce sync.Once + + // workQueue maintains execution order for asynchronous operations requested by extensions. + // It is always an [execqueue.ExecQueue] except in some tests. + workQueue execQueue + // doEnqueueBackendOperation adds an asynchronous [LocalBackend] operation to the workQueue. + doEnqueueBackendOperation func(func(Backend)) + + shuttingDown atomic.Bool + + extByType sync.Map // reflect.Type -> ipnext.Extension + + // mu protects the following fields. + // It must not be held when calling [LocalBackend] methods + // or when invoking callbacks registered by extensions. + mu sync.Mutex + // initialized is whether the host and extensions have been fully initialized. + initialized atomic.Bool + // activeExtensions is a subset of allExtensions that have been initialized and are ready to use. + activeExtensions []ipnext.Extension + // extensionsByName are the extensions indexed by their names. + // They are not necessarily initialized (in activeExtensions) yet. + extensionsByName map[string]ipnext.Extension + // postInitWorkQueue is a queue of functions to be executed + // by the workQueue after all extensions have been initialized. + postInitWorkQueue []func(Backend) + + // currentProfile is a read-only view of the currently used profile. + // The view is always Valid, but might be of an empty, non-persisted profile. + currentProfile ipn.LoginProfileView + // currentPrefs is a read-only view of the current profile's [ipn.Prefs] + // with any private keys stripped. It is always Valid. + currentPrefs ipn.PrefsView +} + +// Backend is a subset of [LocalBackend] methods that are used by [ExtensionHost]. +// It is primarily used for testing. +type Backend interface { + // SwitchToBestProfile switches to the best profile for the current state of the system. + // The reason indicates why the profile is being switched. + SwitchToBestProfile(reason string) + + SendNotify(ipn.Notify) + + NodeBackend() ipnext.NodeBackend + + ipnext.SafeBackend +} + +// NewExtensionHost returns a new [ExtensionHost] which manages registered extensions for the given backend. +// The extensions are instantiated, but are not initialized until [ExtensionHost.Init] is called. +// It returns an error if instantiating any extension fails. +func NewExtensionHost(logf logger.Logf, b Backend) (*ExtensionHost, error) { + return newExtensionHost(logf, b) +} + +func NewExtensionHostForTest(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (*ExtensionHost, error) { + if !testenv.InTest() { + panic("use outside of test") + } + return newExtensionHost(logf, b, overrideExts...) +} + +// newExtensionHost is the shared implementation of [NewExtensionHost] and +// [NewExtensionHostForTest]. +// +// If overrideExts is non-nil, the registered extensions are ignored and the +// provided extensions are used instead. Overriding extensions is primarily used +// for testing. +func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) { + host := &ExtensionHost{ + b: b, + logf: logger.WithPrefix(logf, "ipnext: "), + workQueue: &execqueue.ExecQueue{}, + // The host starts with an empty profile and default prefs. + // We'll update them once [profileManager] notifies us of the initial profile. + currentProfile: zeroProfile, + currentPrefs: defaultPrefs, + } + + // All operations on the backend must be executed asynchronously by the work queue. + // DO NOT retain a direct reference to the backend in the host. + // See the docstring for [ExtensionHost] for more details. + host.doEnqueueBackendOperation = func(f func(Backend)) { + if f == nil { + panic("nil backend operation") + } + host.workQueue.Add(func() { f(b) }) + } + + // Use registered extensions. + extDef := ipnext.Extensions() + if overrideExts != nil { + // Use the provided, potentially empty, overrideExts + // instead of the registered ones. + extDef = slices.Values(overrideExts) + } + + for d := range extDef { + ext, err := d.MakeExtension(logf, b) + if errors.Is(err, ipnext.SkipExtension) { + // The extension wants to be skipped. + host.logf("%q: %v", d.Name(), err) + continue + } else if err != nil { + return nil, fmt.Errorf("failed to create %q extension: %v", d.Name(), err) + } + host.allExtensions = append(host.allExtensions, ext) + + if d.Name() != ext.Name() { + return nil, fmt.Errorf("extension name %q does not match the registered name %q", ext.Name(), d.Name()) + } + + if _, ok := host.extensionsByName[ext.Name()]; ok { + return nil, fmt.Errorf("duplicate extension name %q", ext.Name()) + } else { + mak.Set(&host.extensionsByName, ext.Name(), ext) + } + + typ := reflect.TypeOf(ext) + if _, ok := host.extByType.Load(typ); ok { + if _, ok := ext.(interface{ PermitDoubleRegister() }); !ok { + return nil, fmt.Errorf("duplicate extension type %T", ext) + } + } + host.extByType.Store(typ, ext) + } + return host, nil +} + +func (h *ExtensionHost) NodeBackend() ipnext.NodeBackend { + if h == nil { + return nil + } + return h.b.NodeBackend() +} + +// Init initializes the host and the extensions it manages. +func (h *ExtensionHost) Init() { + if h != nil { + h.initOnce.Do(h.init) + } +} + +var zeroHooks ipnext.Hooks + +func (h *ExtensionHost) Hooks() *ipnext.Hooks { + if h == nil { + return &zeroHooks + } + return &h.hooks +} + +func (h *ExtensionHost) init() { + defer h.initDone.Store(true) + + // Initialize the extensions in the order they were registered. + for _, ext := range h.allExtensions { + // Do not hold the lock while calling [ipnext.Extension.Init]. + // Extensions call back into the host to register their callbacks, + // and that would cause a deadlock if the h.mu is already held. + if err := ext.Init(h); err != nil { + // As per the [ipnext.Extension] interface, failures to initialize + // an extension are never fatal. The extension is simply skipped. + // + // But we handle [ipnext.SkipExtension] differently for nicer logging + // if the extension wants to be skipped and not actually failing. + if errors.Is(err, ipnext.SkipExtension) { + h.logf("%q: %v", ext.Name(), err) + } else { + h.logf("%q init failed: %v", ext.Name(), err) + } + continue + } + // Update the initialized extensions lists as soon as the extension is initialized. + // We'd like to make them visible to other extensions that are initialized later. + h.mu.Lock() + h.activeExtensions = append(h.activeExtensions, ext) + h.mu.Unlock() + } + + // Report active extensions to the log. + // TODO(nickkhyl): update client metrics to include the active/failed/skipped extensions. + h.mu.Lock() + extensionNames := slices.Collect(maps.Keys(h.extensionsByName)) + h.mu.Unlock() + h.logf("active extensions: %v", strings.Join(extensionNames, ", ")) + + // Additional init steps that need to be performed after all extensions have been initialized. + h.mu.Lock() + wq := h.postInitWorkQueue + h.postInitWorkQueue = nil + h.initialized.Store(true) + h.mu.Unlock() + + // Enqueue work that was requested and deferred during initialization. + h.doEnqueueBackendOperation(func(b Backend) { + for _, f := range wq { + f(b) + } + }) +} + +// Extensions implements [ipnext.Host]. +func (h *ExtensionHost) Extensions() ipnext.ExtensionServices { + // Currently, [ExtensionHost] implements [ExtensionServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// FindExtensionByName implements [ipnext.ExtensionServices] +// and is also used by the [LocalBackend]. +// It returns nil if the extension is not found. +func (h *ExtensionHost) FindExtensionByName(name string) any { + if h == nil { + return nil + } + h.mu.Lock() + defer h.mu.Unlock() + return h.extensionsByName[name] +} + +// extensionIfaceType is the runtime type of the [ipnext.Extension] interface. +var extensionIfaceType = reflect.TypeFor[ipnext.Extension]() + +// GetExt returns the extension of type T registered with lb. +// If lb is nil or the extension is not found, it returns zero, false. +func GetExt[T ipnext.Extension](lb *LocalBackend) (_ T, ok bool) { + var zero T + if lb == nil { + return zero, false + } + if ext, ok := lb.extHost.extensionOfType(reflect.TypeFor[T]()); ok { + return ext.(T), true + } + return zero, false +} + +func (h *ExtensionHost) extensionOfType(t reflect.Type) (_ ipnext.Extension, ok bool) { + if h == nil { + return nil, false + } + if v, ok := h.extByType.Load(t); ok { + return v.(ipnext.Extension), true + } + return nil, false +} + +// FindMatchingExtension implements [ipnext.ExtensionServices] +// and is also used by the [LocalBackend]. +func (h *ExtensionHost) FindMatchingExtension(target any) bool { + if h == nil { + return false + } + + if target == nil { + panic("ipnext: target cannot be nil") + } + + val := reflect.ValueOf(target) + typ := val.Type() + if typ.Kind() != reflect.Ptr || val.IsNil() { + panic("ipnext: target must be a non-nil pointer") + } + targetType := typ.Elem() + if targetType.Kind() != reflect.Interface && !targetType.Implements(extensionIfaceType) { + panic("ipnext: *target must be interface or implement ipnext.Extension") + } + + h.mu.Lock() + defer h.mu.Unlock() + for _, ext := range h.activeExtensions { + if reflect.TypeOf(ext).AssignableTo(targetType) { + val.Elem().Set(reflect.ValueOf(ext)) + return true + } + } + return false +} + +// Profiles implements [ipnext.Host]. +func (h *ExtensionHost) Profiles() ipnext.ProfileServices { + // Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// CurrentProfileState implements [ipnext.ProfileServices]. +func (h *ExtensionHost) CurrentProfileState() (ipn.LoginProfileView, ipn.PrefsView) { + if h == nil { + return zeroProfile, defaultPrefs + } + h.mu.Lock() + defer h.mu.Unlock() + return h.currentProfile, h.currentPrefs +} + +// CurrentPrefs implements [ipnext.ProfileServices]. +func (h *ExtensionHost) CurrentPrefs() ipn.PrefsView { + _, prefs := h.CurrentProfileState() + return prefs +} + +// SwitchToBestProfileAsync implements [ipnext.ProfileServices]. +func (h *ExtensionHost) SwitchToBestProfileAsync(reason string) { + if h == nil { + return + } + h.enqueueBackendOperation(func(b Backend) { + b.SwitchToBestProfile(reason) + }) +} + +// SendNotifyAsync implements [ipnext.Host]. +func (h *ExtensionHost) SendNotifyAsync(n ipn.Notify) { + if h == nil { + return + } + h.enqueueBackendOperation(func(b Backend) { + b.SendNotify(n) + }) +} + +// NotifyProfileChange invokes registered profile state change callbacks +// and updates the current profile and prefs in the host. +// It strips private keys from the [ipn.Prefs] before preserving +// or passing them to the callbacks. +func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + if !h.active() { + return + } + h.mu.Lock() + // Strip private keys from the prefs before preserving or passing them to the callbacks. + // Extensions should not need them (unless proven otherwise in the future), + // and this is a good way to ensure that they won't accidentally leak them. + prefs = stripKeysFromPrefs(prefs) + // Update the current profile and prefs in the host, + // so we can provide them to the extensions later if they ask. + h.currentPrefs = prefs + h.currentProfile = profile + h.mu.Unlock() + + for _, cb := range h.hooks.ProfileStateChange { + cb(profile, prefs, sameNode) + } +} + +// NotifyProfilePrefsChanged invokes registered profile state change callbacks, +// and updates the current profile and prefs in the host. +// It strips private keys from the [ipn.Prefs] before preserving or using them. +func (h *ExtensionHost) NotifyProfilePrefsChanged(profile ipn.LoginProfileView, oldPrefs, newPrefs ipn.PrefsView) { + if !h.active() { + return + } + h.mu.Lock() + // Strip private keys from the prefs before preserving or passing them to the callbacks. + // Extensions should not need them (unless proven otherwise in the future), + // and this is a good way to ensure that they won't accidentally leak them. + newPrefs = stripKeysFromPrefs(newPrefs) + // Update the current profile and prefs in the host, + // so we can provide them to the extensions later if they ask. + h.currentPrefs = newPrefs + h.currentProfile = profile + // Get the callbacks to be invoked. + h.mu.Unlock() + + for _, cb := range h.hooks.ProfileStateChange { + cb(profile, newPrefs, true) + } +} + +func (h *ExtensionHost) active() bool { + return h != nil && !h.shuttingDown.Load() +} + +// DetermineBackgroundProfile returns a read-only view of the profile +// used when no GUI/CLI client is connected, using background profile +// resolvers registered by extensions. +// +// It returns an invalid view if Tailscale should not run in the background +// and instead disconnect until a GUI/CLI client connects. +// +// As of 2025-02-07, this is only used on Windows. +func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { + if !h.active() { + return ipn.LoginProfileView{} + } + // TODO(nickkhyl): check if the returned profile is allowed on the device, + // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + + // Attempt to resolve the background profile using the registered + // background profile resolvers (e.g., [ipn/desktop.desktopSessionsExt] on Windows). + for _, resolver := range h.hooks.BackgroundProfileResolvers { + if profile := resolver(profiles); profile.Valid() { + return profile + } + } + + // Otherwise, switch to an empty profile and disconnect Tailscale + // until a GUI or CLI client connects. + return ipn.LoginProfileView{} +} + +// NotifyNewControlClient invokes all registered control client callbacks. +// It returns callbacks to be executed when the control client shuts down. +func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile ipn.LoginProfileView) (ccShutdownCbs []func()) { + if !h.active() { + return nil + } + for _, cb := range h.hooks.NewControlClient { + if shutdown := cb(cc, profile); shutdown != nil { + ccShutdownCbs = append(ccShutdownCbs, shutdown) + } + } + return ccShutdownCbs +} + +// AuditLogger returns a function that reports an auditable action +// to all registered audit loggers. It fails if any of them returns an error, +// indicating that the action cannot be logged and must not be performed. +// +// It implements [ipnext.Host], but is also used by the [LocalBackend]. +// +// The returned function closes over the current state of the host and extensions, +// which typically includes the current profile and the audit loggers registered by extensions. +// It must not be persisted outside of the auditable action context. +func (h *ExtensionHost) AuditLogger() ipnauth.AuditLogFunc { + if !h.active() { + return func(tailcfg.ClientAuditAction, string) error { return nil } + } + loggers := make([]ipnauth.AuditLogFunc, 0, len(h.hooks.AuditLoggers)) + for _, provider := range h.hooks.AuditLoggers { + loggers = append(loggers, provider()) + } + return func(action tailcfg.ClientAuditAction, details string) error { + // Log auditable actions to the host's log regardless of whether + // the audit loggers are available or not. + h.logf("auditlog: %v: %v", action, details) + + // Invoke all registered audit loggers and collect errors. + // If any of them returns an error, the action is denied. + var errs []error + for _, logger := range loggers { + if err := logger(action, details); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) + } +} + +// Shutdown shuts down the extension host and all initialized extensions. +func (h *ExtensionHost) Shutdown() { + if h == nil { + return + } + // Ensure that the init function has completed before shutting down, + // or prevent any further init calls from happening. + h.initOnce.Do(func() {}) + h.shutdownOnce.Do(h.shutdown) +} + +func (h *ExtensionHost) shutdown() { + h.shuttingDown.Store(true) + // Prevent any queued but not yet started operations from running, + // block new operations from being enqueued, and wait for the + // currently executing operation (if any) to finish. + h.shutdownWorkQueue() + // Invoke shutdown callbacks registered by extensions. + h.shutdownExtensions() +} + +func (h *ExtensionHost) shutdownWorkQueue() { + h.workQueue.Shutdown() + var ctx context.Context + if testenv.InTest() { + // In tests, we'd like to wait indefinitely for the current operation to finish, + // mostly to help avoid flaky tests. Test runners can be pretty slow. + ctx = context.Background() + } else { + // In prod, however, we want to avoid blocking indefinitely. + // The 5s timeout is somewhat arbitrary; LocalBackend operations + // should not take that long. + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + // Since callbacks are invoked synchronously, this will also wait + // for in-flight callbacks associated with those operations to finish. + if err := h.workQueue.Wait(ctx); err != nil { + h.logf("work queue shutdown failed: %v", err) + } +} + +func (h *ExtensionHost) shutdownExtensions() { + h.mu.Lock() + extensions := h.activeExtensions + h.mu.Unlock() + + // h.mu must not be held while shutting down extensions. + // Extensions might call back into the host and that would cause + // a deadlock if the h.mu is already held. + // + // Shutdown is called in the reverse order of Init. + for _, ext := range slices.Backward(extensions) { + if err := ext.Shutdown(); err != nil { + // Extension shutdown errors are never fatal, but we log them for debugging purposes. + h.logf("%q: shutdown callback failed: %v", ext.Name(), err) + } + } +} + +// enqueueBackendOperation enqueues a function to perform an operation on the [Backend]. +// If the host has not yet been initialized (e.g., when called from an extension's Init method), +// the operation is deferred until after the host and all extensions have completed initialization. +// It panics if the f is nil. +func (h *ExtensionHost) enqueueBackendOperation(f func(Backend)) { + if h == nil { + return + } + if f == nil { + panic("nil backend operation") + } + h.mu.Lock() // protects h.initialized and h.postInitWorkQueue + defer h.mu.Unlock() + if h.initialized.Load() { + h.doEnqueueBackendOperation(f) + } else { + h.postInitWorkQueue = append(h.postInitWorkQueue, f) + } +} + +// execQueue is an ordered asynchronous queue for executing functions. +// It is implemented by [execqueue.ExecQueue]. The interface is used +// to allow testing with a mock implementation. +type execQueue interface { + Add(func()) + Shutdown() + Wait(context.Context) error +} diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go new file mode 100644 index 000000000..f5c081a5b --- /dev/null +++ b/ipn/ipnlocal/extension_host_test.go @@ -0,0 +1,1412 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "cmp" + "context" + "errors" + "net/netip" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + + deepcmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" + "tailscale.com/types/persist" + "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/must" +) + +// defaultCmpOpts are the default options used for deepcmp comparisons in tests. +var defaultCmpOpts = []deepcmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), +} + +// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes +// and shuts down extensions. +func TestExtensionInitShutdown(t *testing.T) { + t.Parallel() + + // As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors + // as extension initialization and shutdown errors are not fatal. + // If these methods are updated to return errors, this test should also be updated. + // The conversions below will fail to compile if their signatures change, reminding us to update the test. + _ = (func(*ExtensionHost))((*ExtensionHost).Init) + _ = (func(*ExtensionHost))((*ExtensionHost).Shutdown) + + tests := []struct { + name string + nilHost bool + exts []*testExtension + wantInit []string + wantShutdown []string + skipInit bool + }{ + { + name: "nil-host", + nilHost: true, + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "empty-extensions", + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "single-extension", + exts: []*testExtension{{name: "A"}}, + wantInit: []string{"A"}, + wantShutdown: []string{"A"}, + }, + { + name: "multiple-extensions/all-ok", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B", "A"}, + }, + { + name: "multiple-extensions/no-init-no-shutdown", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{}, + wantShutdown: []string{}, + skipInit: true, + }, + { + name: "multiple-extensions/init-failed/first", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B"}, + }, + { + name: "multiple-extensions/init-failed/second", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + { + name: "multiple-extensions/init-failed/third", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"B", "A"}, + }, + { + name: "multiple-extensions/init-failed/all", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{}, + }, + { + name: "multiple-extensions/init-skipped", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return ipnext.SkipExtension }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Configure all extensions to append their names + // to the gotInit and gotShutdown slices + // during initialization and shutdown, + // so we can check that they are called in the right order + // and that shutdown is not unless init succeeded. + var gotInit, gotShutdown []string + for _, ext := range tt.exts { + oldInitHook := ext.InitHook + ext.InitHook = func(e *testExtension) error { + gotInit = append(gotInit, e.name) + if oldInitHook == nil { + return nil + } + return oldInitHook(e) + } + ext.ShutdownHook = func(e *testExtension) error { + gotShutdown = append(gotShutdown, e.name) + return nil + } + } + + var h *ExtensionHost + if !tt.nilHost { + h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...) + } + + if !tt.skipInit { + h.Init() + } + + // Check that the extensions were initialized in the right order. + if !slices.Equal(gotInit, tt.wantInit) { + t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit) + } + + // Calling Init again on the host should be a no-op. + // The [testExtension.Init] method fails the test if called more than once, + // regardless of which test is running, so we don't need to check it here. + // Similarly, calling Shutdown again on the host should be a no-op as well. + // It is verified by the [testExtension.Shutdown] method itself. + if !tt.skipInit { + h.Init() + } + + // Extensions should not be shut down before the host is shut down, + // even if they are not initialized successfully. + for _, ext := range tt.exts { + if gotShutdown := ext.ShutdownCalled(); gotShutdown { + t.Errorf("%q: Extension shutdown called before host shutdown", ext.name) + } + } + + h.Shutdown() + // Check that the extensions were shut down in the right order, + // and that they were not shut down if they were not initialized successfully. + if !slices.Equal(gotShutdown, tt.wantShutdown) { + t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown) + } + + }) + } +} + +// TestNewExtensionHost tests that [NewExtensionHost] correctly creates +// an [ExtensionHost], instantiates the extensions and handles errors +// if an extension cannot be created. +func TestNewExtensionHost(t *testing.T) { + t.Parallel() + tests := []struct { + name string + defs []*ipnext.Definition + wantErr bool + wantExts []string + }{ + { + name: "no-exts", + defs: []*ipnext.Definition{}, + wantErr: false, + wantExts: []string{}, + }, + { + name: "exts-ok", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionForTest(&testExtension{name: "B"}), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, + wantExts: []string{"A", "B", "C"}, + }, + { + name: "exts-skipped", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, // extension B is skipped, that's ok + wantExts: []string{"A", "C"}, + }, + { + name: "exts-fail", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: true, // extension B failed to create, that's not ok + wantExts: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + logf := tstest.WhileTestRunningLogger(t) + h, err := NewExtensionHostForTest(logf, &testBackend{}, tt.defs...) + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr) + } + if err != nil { + return + } + + var gotExts []string + for _, ext := range h.allExtensions { + gotExts = append(gotExts, ext.Name()) + } + + if !slices.Equal(gotExts, tt.wantExts) { + t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts) + } + }) + } +} + +// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly +// finds extensions by their type or interface. +func TestFindMatchingExtension(t *testing.T) { + t.Parallel() + + // Define test extension types and a couple of interfaces + type ( + extensionA struct { + testExtension + } + extensionB struct { + testExtension + } + extensionC struct { + testExtension + } + supportedIface interface { + Name() string + } + unsupportedIface interface { + Unsupported() + } + ) + + // Register extensions A and B, but not C. + extA := &extensionA{testExtension: testExtension{name: "A"}} + extB := &extensionB{testExtension: testExtension{name: "B"}} + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB) + + var gotA *extensionA + if !h.FindMatchingExtension(&gotA) { + t.Errorf("LookupExtension(%T): not found", gotA) + } else if gotA != extA { + t.Errorf("LookupExtension(%T): got %v; want %v", gotA, gotA, extA) + } + + var gotB *extensionB + if !h.FindMatchingExtension(&gotB) { + t.Errorf("LookupExtension(%T): extension B not found", gotB) + } else if gotB != extB { + t.Errorf("LookupExtension(%T): got %v; want %v", gotB, gotB, extB) + } + + var gotC *extensionC + if h.FindMatchingExtension(&gotC) { + t.Errorf("LookupExtension(%T): found, but it should not exist", gotC) + } + + // All extensions implement the supportedIface interface, + // but LookupExtension should only return the first one found, + // which is extA. + var gotSupportedIface supportedIface + if !h.FindMatchingExtension(&gotSupportedIface) { + t.Errorf("LookupExtension(%T): not found", gotSupportedIface) + } else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName { + t.Errorf("LookupExtension(%T): name: got %v; want %v", gotSupportedIface, gotName, wantName) + } else if gotSupportedIface != extA { + t.Errorf("LookupExtension(%T): got %v; want %v", gotSupportedIface, gotSupportedIface, extA) + } + + var gotUnsupportedIface unsupportedIface + if h.FindMatchingExtension(&gotUnsupportedIface) { + t.Errorf("LookupExtension(%T): found, but it should not exist", gotUnsupportedIface) + } +} + +// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly +// finds extensions by their name. +func TestFindExtensionByName(t *testing.T) { + // Register extensions A and B, but not C. + extA := &testExtension{name: "A"} + extB := &testExtension{name: "B"} + h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB) + + gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension) + if !ok { + t.Errorf("FindExtensionByName(%q): not found", extA.Name()) + } else if gotA != extA { + t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA) + } + + gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension) + if !ok { + t.Errorf("FindExtensionByName(%q): not found", extB.Name()) + } else if gotB != extB { + t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB) + } + + gotC, ok := h.FindExtensionByName("C").(*testExtension) + if ok { + t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC) + } +} + +// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues +// backend operations and executes them asynchronously in the order they were received. +// It also checks that operations requested before the host and all extensions are initialized +// are not executed immediately but rather after the host and extensions are initialized. +func TestExtensionHostEnqueueBackendOperation(t *testing.T) { + t.Parallel() + tests := []struct { + name string + preInitCalls []string // before host init + extInitCalls []string // from [Extension.Init]; "" means no call + wantInitCalls []string // what we expect to be called after host init + postInitCalls []string // after host init + }{ + { + name: "no-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{}, + }, + { + name: "pre-init-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{}, + wantInitCalls: []string{"pre-init-1", "pre-init-2"}, + postInitCalls: []string{}, + }, + { + name: "init-calls", + preInitCalls: []string{}, + extInitCalls: []string{"init-1", "init-2"}, + wantInitCalls: []string{"init-1", "init-2"}, + postInitCalls: []string{}, + }, + { + name: "post-init-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + { + name: "mixed-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{"init-1", "", "init-2"}, + wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var gotCalls []string + var h *ExtensionHost + b := &testBackend{ + switchToBestProfileHook: func(reason string) { + gotCalls = append(gotCalls, reason) + }, + } + + exts := make([]*testExtension, len(tt.extInitCalls)) + for i, reason := range tt.extInitCalls { + exts[i] = &testExtension{} + if reason != "" { + exts[i].InitHook = func(e *testExtension) error { + e.host.Profiles().SwitchToBestProfileAsync(reason) + return nil + } + } + } + + h = newExtensionHostForTest(t, b, false, exts...) + wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue]. + + // Issue some pre-init calls. They should be deferred and not + // added to the queue until the host is initialized. + for _, call := range tt.preInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // The queue should be empty before the host is initialized. + wq.Drain() + if len(gotCalls) != 0 { + t.Errorf("Pre-init calls: got %v; want (none)", gotCalls) + } + gotCalls = nil + + // Initialize the host and all extensions. + // The extensions will make their calls during initialization. + h.Init() + + // Calls made before or during initialization should now be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + + // Let's make some more calls, as if extensions were making them in a response + // to external events. + for _, call := range tt.postInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // Any calls made after initialization should be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + }) + } +} + +// TestExtensionHostProfileStateChangeCallback verifies that [ExtensionHost] correctly handles the registration, +// invocation, and unregistration of profile state change callbacks. This includes callbacks triggered by profile changes +// and by changes to the profile's [ipn.Prefs]. It also checks that the callbacks are called with the correct arguments +// and that any private keys are stripped from [ipn.Prefs] before being passed to the callback. +func TestExtensionHostProfileStateChangeCallback(t *testing.T) { + t.Parallel() + + type stateChange struct { + Profile *ipn.LoginProfile + Prefs *ipn.Prefs + SameNode bool + } + type prefsChange struct { + Profile *ipn.LoginProfile + Old, New *ipn.Prefs + } + + // newStateChange creates a new [stateChange] with deep copies of the profile and prefs. + newStateChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) stateChange { + return stateChange{ + Profile: profile.AsStruct(), + Prefs: prefs.AsStruct(), + SameNode: sameNode, + } + } + // makeStateChangeAppender returns a callback that appends profile state changes to the extension's state. + makeStateChangeAppender := func(e *testExtension) ipnext.ProfileStateChangeCallback { + return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + UpdateExtState(e, "changes", func(changes []stateChange) []stateChange { + return append(changes, newStateChange(profile, prefs, sameNode)) + }) + } + } + // getStateChanges returns the profile state changes stored in the extension's state. + getStateChanges := func(e *testExtension) []stateChange { + changes, _ := GetExtStateOk[[]stateChange](e, "changes") + return changes + } + + tests := []struct { + name string + ext *testExtension + stateCalls []stateChange + prefsCalls []prefsChange + wantChanges []stateChange + }{ + { + // Register the callback for the lifetime of the extension. + name: "Register/Lifetime", + ext: &testExtension{}, + stateCalls: []stateChange{ + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + wantChanges: []stateChange{ // all calls are received by the callback + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + }, + { + // Ensure that ipn.Prefs are passed to the callback. + name: "CheckPrefs", + ext: &testExtension{}, + stateCalls: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + wantChanges: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + }, + { + // Ensure that private keys are stripped from persist.Persist shared with extensions. + name: "StripPrivateKeys", + ext: &testExtension{}, + stateCalls: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + wantChanges: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NodePrivate{}, // stripped + OldPrivateNodeKey: key.NodePrivate{}, // stripped + NetworkLockKey: key.NLPrivate{}, // stripped + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + }, + { + // Ensure that profile state callbacks are also invoked when prefs (rather than profile) change. + name: "PrefsChange", + ext: &testExtension{}, + prefsCalls: []prefsChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Old: &ipn.Prefs{WantRunning: false, LoggedOut: true}, + New: &ipn.Prefs{WantRunning: true, LoggedOut: false}, + }, + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Old: &ipn.Prefs{AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}}, + New: &ipn.Prefs{AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("10.10.10.0/24")}}, + }, + }, + wantChanges: []stateChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{WantRunning: true, LoggedOut: false}, + SameNode: true, // must be true for prefs changes + }, + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("10.10.10.0/24")}}, + SameNode: true, // must be true for prefs changes + }, + }, + }, + { + // Ensure that private keys are stripped from prefs when state change callback + // is invoked by prefs change. + name: "PrefsChange/StripPrivateKeys", + ext: &testExtension{}, + prefsCalls: []prefsChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Old: &ipn.Prefs{ + WantRunning: false, + LoggedOut: true, + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + New: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }, + }, + wantChanges: []stateChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NodePrivate{}, // stripped + OldPrivateNodeKey: key.NodePrivate{}, // stripped + NetworkLockKey: key.NLPrivate{}, // stripped + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + SameNode: true, // must be true for prefs changes + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Use the default InitHook if not provided by the test. + if tt.ext.InitHook == nil { + tt.ext.InitHook = func(e *testExtension) error { + // Create and register the callback on init. + handler := makeStateChangeAppender(e) + e.host.Hooks().ProfileStateChange.Add(handler) + return nil + } + } + + h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext) + for _, call := range tt.stateCalls { + h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode) + } + for _, call := range tt.prefsCalls { + h.NotifyProfilePrefsChanged(call.Profile.View(), call.Old.View(), call.New.View()) + } + if diff := deepcmp.Diff(tt.wantChanges, getStateChanges(tt.ext), defaultCmpOpts...); diff != "" { + t.Errorf("StateChange callbacks: (-want +got): %v", diff) + } + }) + } +} + +// TestCurrentProfileState tests that the current profile and prefs are correctly +// initialized and updated when the host is notified of changes. +func TestCurrentProfileState(t *testing.T) { + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, false) + + // The initial profile and prefs should be valid and set to the default values. + gotProfile, gotPrefs := h.Profiles().CurrentProfileState() + checkViewsEqual(t, "Initial profile (from state)", gotProfile, zeroProfile) + checkViewsEqual(t, "Initial prefs (from state)", gotPrefs, defaultPrefs) + gotPrefs = h.Profiles().CurrentPrefs() // same when we only ask for prefs + checkViewsEqual(t, "Initial prefs (direct)", gotPrefs, defaultPrefs) + + // Create a new profile and prefs, and notify the host of the change. + profile := &ipn.LoginProfile{ID: "profile-A"} + prefsV1 := &ipn.Prefs{ProfileName: "Prefs V1", WantRunning: true} + h.NotifyProfileChange(profile.View(), prefsV1.View(), false) + // The current profile and prefs should be updated. + gotProfile, gotPrefs = h.Profiles().CurrentProfileState() + checkViewsEqual(t, "Changed profile (from state)", gotProfile, profile.View()) + checkViewsEqual(t, "New prefs (from state)", gotPrefs, prefsV1.View()) + gotPrefs = h.Profiles().CurrentPrefs() + checkViewsEqual(t, "New prefs (direct)", gotPrefs, prefsV1.View()) + + // Notify the host of a change to the profile's prefs. + prefsV2 := &ipn.Prefs{ProfileName: "Prefs V2", WantRunning: false} + h.NotifyProfilePrefsChanged(profile.View(), prefsV1.View(), prefsV2.View()) + // The current prefs should be updated. + gotProfile, gotPrefs = h.Profiles().CurrentProfileState() + checkViewsEqual(t, "Unchanged profile (from state)", gotProfile, profile.View()) + checkViewsEqual(t, "Changed (from state)", gotPrefs, prefsV2.View()) + gotPrefs = h.Profiles().CurrentPrefs() + checkViewsEqual(t, "Changed prefs (direct)", gotPrefs, prefsV2.View()) +} + +// TestBackgroundProfileResolver tests that the background profile resolvers +// are correctly registered, unregistered and invoked by the [ExtensionHost]. +func TestBackgroundProfileResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + profiles []ipn.LoginProfile // the first one is the current profile + resolvers []ipnext.ProfileResolver + wantProfile *ipn.LoginProfile + }{ + { + name: "No-Profiles/No-Resolvers", + profiles: nil, + resolvers: nil, + wantProfile: nil, + }, + { + // TODO(nickkhyl): update this test as we change "background profile resolvers" + // to just "profile resolvers". The wantProfile should be the current profile by default. + name: "Has-Profiles/No-Resolvers", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: nil, + wantProfile: nil, + }, + { + name: "Has-Profiles/Single-Resolver", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: []ipnext.ProfileResolver{ + func(ps ipnext.ProfileStore) ipn.LoginProfileView { + return ps.CurrentProfile() + }, + }, + wantProfile: &ipn.LoginProfile{ID: "profile-1"}, + }, + // TODO(nickkhyl): add more tests for multiple resolvers and different profiles + // once we change "background profile resolvers" to just "profile resolvers" + // and add proper conflict resolution logic. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a new profile manager and add the profiles to it. + // We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface. + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) + for i, p := range tt.profiles { + // Generate a unique ID and key for each profile, + // unless the profile already has them set + // or is an empty, unnamed profile. + if p.Name != "" { + if p.ID == "" { + p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i)) + } + if p.Key == "" { + p.Key = "key-" + ipn.StateKey(p.ID) + } + } + pv := p.View() + pm.knownProfiles[p.ID] = pv + if i == 0 { + // Set the first profile as the current one. + // A profileManager starts with an empty profile, + // so it's okay if the list of profiles is empty. + pm.SwitchToProfile(pv) + } + } + + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, false) + + // Register the resolvers with the host. + // This is typically done by the extensions themselves, + // but we do it here for testing purposes. + for _, r := range tt.resolvers { + h.Hooks().BackgroundProfileResolvers.Add(r) + } + h.Init() + + // Call the resolver to get the profile. + gotProfile := h.DetermineBackgroundProfile(pm) + if !gotProfile.Equals(tt.wantProfile.View()) { + t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile) + } + }) + } +} + +// TestAuditLogProviders tests that the [ExtensionHost] correctly handles +// the registration and invocation of audit log providers. It verifies that +// the audit loggers are called with the correct actions and details, +// and that any errors returned by the providers are properly propagated. +func TestAuditLogProviders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + auditLoggers []ipnauth.AuditLogFunc // each represents an extension + actions []tailcfg.ClientAuditAction + wantErr bool + }{ + { + name: "No-Providers", + auditLoggers: nil, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, + }, + { + name: "Many-Providers/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Many-Providers/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + func(tailcfg.ClientAuditAction, string) error { + return nil // all good + }, + func(tailcfg.ClientAuditAction, string) error { + return errors.New("also failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, // some providers failed to log, so that's an error + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create extensions that register the audit log providers. + // Each extension/provider will append auditable actions to its state, + // then call the test's auditLogger function. + var exts []*testExtension + for _, auditLogger := range tt.auditLoggers { + ext := &testExtension{} + provider := func() ipnauth.AuditLogFunc { + return func(action tailcfg.ClientAuditAction, details string) error { + UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction { + return append(actions, action) + }) + return auditLogger(action, details) + } + } + ext.InitHook = func(e *testExtension) error { + e.host.Hooks().AuditLoggers.Add(provider) + return nil + } + exts = append(exts, ext) + } + + // Initialize the host and the extensions. + h := newExtensionHostForTest(t, &testBackend{}, true, exts...) + + // Use [ExtensionHost.AuditLogger] to log actions. + for _, action := range tt.actions { + err := h.AuditLogger()(action, "Test details") + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr) + } + } + + // Check that the actions were logged correctly by each provider. + for _, ext := range exts { + gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions") + if !slices.Equal(gotActions, tt.actions) { + t.Errorf("Actions: got %v; want %v", gotActions, tt.actions) + } + } + }) + } +} + +// TestNilExtensionHostMethodCall tests that calling exported methods +// on a nil [ExtensionHost] does not panic. We should treat it as a valid +// value since it's used in various tests that instantiate [LocalBackend] +// manually without calling [NewLocalBackend]. It also verifies that if +// a method returns a single func value (e.g., a cleanup function), +// it should not be nil. This is a basic sanity check to ensure that +// typical method calls on a nil receiver work as expected. +// It does not replace the need for more thorough testing of specific methods. +func TestNilExtensionHostMethodCall(t *testing.T) { + t.Parallel() + + var h *ExtensionHost + typ := reflect.TypeOf(h) + for i := range typ.NumMethod() { + m := typ.Method(i) + if strings.HasSuffix(m.Name, "ForTest") { + // Skip methods that are only for testing. + continue + } + + t.Run(m.Name, func(t *testing.T) { + t.Parallel() + // Calling the method on the nil receiver should not panic. + ret := checkMethodCallWithZeroArgs(t, m, h) + if len(ret) == 1 && ret[0].Kind() == reflect.Func { + // If the method returns a single func, such as a cleanup function, + // it should not be nil. + fn := ret[0] + if fn.IsNil() { + t.Fatalf("(%T).%s returned a nil func", h, m.Name) + } + // We expect it to be a no-op and calling it should not panic. + args := makeZeroArgsFor(fn) + func() { + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e) + } + }() + fn.Call(args) + }() + } + }) + } +} + +// extBeforeStartExtension is a test extension used by TestGetExtBeforeStart. +// It is registered with the [ipnext.RegisterExtension]. +type extBeforeStartExtension struct{} + +func init() { + ipnext.RegisterExtension("ext-before-start", mkExtBeforeStartExtension) +} + +func mkExtBeforeStartExtension(logger.Logf, ipnext.SafeBackend) (ipnext.Extension, error) { + return extBeforeStartExtension{}, nil +} + +func (extBeforeStartExtension) Name() string { return "ext-before-start" } +func (extBeforeStartExtension) Init(ipnext.Host) error { + return nil +} +func (extBeforeStartExtension) Shutdown() error { + return nil +} + +// TestGetExtBeforeStart verifies that an extension registered via +// RegisterExtension can be retrieved with GetExt before the host is started +// (via LocalBackend.Start) +func TestGetExtBeforeStart(t *testing.T) { + lb := newTestBackend(t) + // Now call GetExt without calling Start on the LocalBackend. + _, ok := GetExt[extBeforeStartExtension](lb) + if !ok { + t.Fatal("didn't find extension") + } +} + +// checkMethodCallWithZeroArgs calls the method m on the receiver r +// with zero values for all its arguments, except the receiver itself. +// It returns the result of the method call, or fails the test if the call panics. +func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value { + t.Helper() + args := makeZeroArgsFor(m.Func) + // The first arg is the receiver. + args[0] = reflect.ValueOf(r) + // Calling the method should not panic. + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e) + } + }() + return m.Func.Call(args) +} + +func makeZeroArgsFor(fn reflect.Value) []reflect.Value { + args := make([]reflect.Value, fn.Type().NumIn()) + for i := range args { + args[i] = reflect.Zero(fn.Type().In(i)) + } + return args +} + +// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions. +// It associates each extension that either is or embeds a [testExtension] with the test +// and assigns a name if one isn’t already set. +// +// If the host cannot be created, it fails the test. +// +// The host is initialized if the initialize parameter is true. +// It is shut down automatically when the test ends. +func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost { + t.Helper() + + // testExtensionIface is a subset of the methods implemented by [testExtension] that are used here. + // We use testExtensionIface in type assertions instead of using the [testExtension] type directly, + // which supports scenarios where an extension type embeds a [testExtension]. + type testExtensionIface interface { + Name() string + setName(string) + setT(*testing.T) + checkShutdown() + } + + logf := tstest.WhileTestRunningLogger(t) + defs := make([]*ipnext.Definition, len(exts)) + for i, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i))) + ext.setT(t) + } + defs[i] = ipnext.DefinitionForTest(ext) + } + h, err := NewExtensionHostForTest(logf, b, defs...) + if err != nil { + t.Fatalf("NewExtensionHost: %v", err) + } + // Replace doEnqueueBackendOperation with the one that's marked as a helper, + // so that we'll have better output if [testExecQueue.Add] fails a test. + h.doEnqueueBackendOperation = func(f func(Backend)) { + t.Helper() + h.workQueue.Add(func() { f(b) }) + } + for _, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + t.Cleanup(ext.checkShutdown) + } + } + t.Cleanup(h.Shutdown) + if initialize { + h.Init() + } + return h +} + +// testExtension is an [ipnext.Extension] that: +// - Calls the provided init and shutdown callbacks +// when [Init] and [Shutdown] are called. +// - Ensures that [Init] and [Shutdown] are called at most once, +// that [Shutdown] is called after [Init], but is not called if [Init] fails +// and is called before the test ends if [Init] succeeds. +// +// Typically, [testExtension]s are created and passed to [newExtensionHostForTest] +// when creating an [ExtensionHost] for testing. +type testExtension struct { + t *testing.T // test that created the extension + name string // name of the extension, used for logging + + host ipnext.Host // or nil if not initialized + + // InitHook and ShutdownHook are optional hooks that can be set by tests. + InitHook, ShutdownHook func(*testExtension) error + + // initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown + // are called at most once and in the correct order. + initCnt, initOkCnt, shutdownCnt atomic.Int32 + + // mu protects the following fields. + mu sync.Mutex + // state is the optional state used by tests. + // It can be accessed by tests using [setTestExtensionState], + // [getTestExtensionStateOk] and [getTestExtensionState]. + state map[string]any +} + +var _ ipnext.Extension = (*testExtension)(nil) + +// PermitDoubleRegister is a sentinel method whose existence tells the +// ExtensionHost to permit it to be registered multiple times. +func (*testExtension) PermitDoubleRegister() {} + +func (e *testExtension) setT(t *testing.T) { + e.t = t +} + +func (e *testExtension) setName(name string) { + e.name = name +} + +// Name implements [ipnext.Extension]. +func (e *testExtension) Name() string { + return e.name +} + +// Init implements [ipnext.Extension]. +func (e *testExtension) Init(host ipnext.Host) (err error) { + e.t.Helper() + e.host = host + if e.initCnt.Add(1) == 1 { + e.mu.Lock() + e.state = make(map[string]any) + e.mu.Unlock() + } else { + e.t.Errorf("%q: Init called more than once", e.name) + } + if e.InitHook != nil { + err = e.InitHook(e) + } + if err == nil { + e.initOkCnt.Add(1) + } + return err // may be nil or non-nil +} + +// InitCalled reports whether the Init method was called on the receiver. +func (e *testExtension) InitCalled() bool { + return e.initCnt.Load() != 0 +} + +// Shutdown implements [ipnext.Extension]. +func (e *testExtension) Shutdown() (err error) { + e.t.Helper() + e.mu.Lock() + defer e.mu.Unlock() + if e.ShutdownHook != nil { + err = e.ShutdownHook(e) + } + if e.shutdownCnt.Add(1) != 1 { + e.t.Errorf("%q: Shutdown called more than once", e.name) + } + if e.initCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called without Init", e.name) + } else if e.initOkCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called despite failed Init", e.name) + } + e.host = nil + return err // may be nil or non-nil +} + +func (e *testExtension) checkShutdown() { + e.t.Helper() + if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown has not been called before test end", e.name) + } +} + +// ShutdownCalled reports whether the Shutdown method was called on the receiver. +func (e *testExtension) ShutdownCalled() bool { + return e.shutdownCnt.Load() != 0 +} + +// SetExtState sets a keyed state on [testExtension] to the given value. +// Tests use it to propagate test-specific state throughout the extension lifecycle +// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks) +func SetExtState[T any](e *testExtension, key string, value T) { + e.mu.Lock() + defer e.mu.Unlock() + e.state[key] = value +} + +// UpdateExtState updates a keyed state of the extension using the provided update function. +func UpdateExtState[T any](e *testExtension, key string, update func(T) T) { + e.mu.Lock() + defer e.mu.Unlock() + old, _ := e.state[key].(T) + new := update(old) + e.state[key] = new +} + +// GetExtState returns the value of the keyed state of the extension. +// It returns a zero value of T if the state is not set or is of a different type. +func GetExtState[T any](e *testExtension, key string) T { + v, _ := GetExtStateOk[T](e, key) + return v +} + +// GetExtStateOk is like [getExtState], but also reports whether the state +// with the given key exists and is of the expected type. +func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) { + e.mu.Lock() + defer e.mu.Unlock() + v, ok := e.state[key].(T) + return v, ok +} + +// testExecQueue is a test implementation of [execQueue] +// that defers execution of the enqueued funcs until +// [testExecQueue.Drain] is called, and fails the test if +// if [execQueue.Add] is called before the host is initialized. +// +// It is typically used by calling [ExtensionHost.SetWorkQueueForTest]. +type testExecQueue struct { + t *testing.T // test that created the queue + h *ExtensionHost // host to own the queue + + mu sync.Mutex + queue []func() +} + +var _ execQueue = (*testExecQueue)(nil) + +// SetWorkQueueForTest is a helper function that creates a new [testExecQueue] +// and sets it as the work queue for the specified [ExtensionHost], +// returning the new queue. +// +// It fails the test if the host is already initialized. +func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue { + t.Helper() + if h.initialized.Load() { + t.Fatalf("UseTestWorkQueue: host is already initialized") + return nil + } + q := &testExecQueue{t: t, h: h} + h.workQueue = q + return q +} + +// Add implements [execQueue]. +func (q *testExecQueue) Add(f func()) { + q.t.Helper() + + if !q.h.initialized.Load() { + q.t.Fatal("ExecQueue.Add must not be called until the host is initialized") + return + } + + q.mu.Lock() + q.queue = append(q.queue, f) + q.mu.Unlock() +} + +// Drain executes all queued functions in the order they were added. +func (q *testExecQueue) Drain() { + q.mu.Lock() + queue := q.queue + q.queue = nil + q.mu.Unlock() + + for _, f := range queue { + f() + } +} + +// Shutdown implements [execQueue]. +func (q *testExecQueue) Shutdown() {} + +// Wait implements [execQueue]. +func (q *testExecQueue) Wait(context.Context) error { return nil } + +// testBackend implements [ipnext.Backend] for testing purposes +// by calling the provided hooks when its methods are called. +type testBackend struct { + lazySys lazy.SyncValue[*tsd.System] + switchToBestProfileHook func(reason string) + + // mu protects the backend state. + // It is acquired on entry to the exported methods of the backend + // and released on exit, mimicking the behavior of the [LocalBackend]. + mu sync.Mutex +} + +func (b *testBackend) Clock() tstime.Clock { return tstime.StdClock{} } +func (b *testBackend) Sys() *tsd.System { + return b.lazySys.Get(tsd.NewSystem) +} +func (b *testBackend) SendNotify(ipn.Notify) { panic("not implemented") } +func (b *testBackend) NodeBackend() ipnext.NodeBackend { panic("not implemented") } +func (b *testBackend) TailscaleVarRoot() string { panic("not implemented") } + +func (b *testBackend) SwitchToBestProfile(reason string) { + b.mu.Lock() + defer b.mu.Unlock() + if b.switchToBestProfileHook != nil { + b.switchToBestProfileHook(reason) + } +} + +// equatableView is an interface implemented by views +// that can be compared for equality. +type equatableView[T any] interface { + Valid() bool + Equals(other T) bool +} + +// checkViewsEqual checks that the two views are equal +// and fails the test if they are not. The prefix is used +// to format the error message. +func checkViewsEqual[T equatableView[T]](t *testing.T, prefix string, got, want T) { + t.Helper() + switch { + case got.Equals(want): + return + case got.Valid() && want.Valid(): + t.Errorf("%s: got %v; want %v", prefix, got, want) + case got.Valid() && !want.Valid(): + t.Errorf("%s: got %v; want invalid", prefix, got) + case !got.Valid() && want.Valid(): + t.Errorf("%s: got invalid; want %v", prefix, want) + default: + panic("unreachable") + } +} diff --git a/ipn/ipnlocal/hwattest.go b/ipn/ipnlocal/hwattest.go new file mode 100644 index 000000000..2c93cad4c --- /dev/null +++ b/ipn/ipnlocal/hwattest.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tpm + +package ipnlocal + +import ( + "errors" + + "tailscale.com/feature" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/persist" +) + +func init() { + feature.HookGenerateAttestationKeyIfEmpty.Set(generateAttestationKeyIfEmpty) +} + +// generateAttestationKeyIfEmpty generates a new hardware attestation key if +// none exists. It returns true if a new key was generated and stored in +// p.AttestationKey. +func generateAttestationKeyIfEmpty(p *persist.Persist, logf logger.Logf) (bool, error) { + // attempt to generate a new hardware attestation key if none exists + var ak key.HardwareAttestationKey + if p != nil { + ak = p.AttestationKey + } + + if ak == nil || ak.IsZero() { + var err error + ak, err = key.NewHardwareAttestationKey() + if err != nil { + if !errors.Is(err, key.ErrUnsupported) { + logf("failed to create hardware attestation key: %v", err) + } + } else if ak != nil { + logf("using new hardware attestation key: %v", ak.Public()) + if p == nil { + p = &persist.Persist{} + } + p.AttestationKey = ak + return true, nil + } + } + return false, nil +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 8fc78a36b..0ff299399 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -6,16 +6,15 @@ package ipnlocal import ( - "bytes" "cmp" "context" - "encoding/base64" + "crypto/sha256" + "encoding/binary" "encoding/json" "errors" "fmt" "io" "log" - "maps" "math" "math/rand/v2" "net" @@ -23,12 +22,9 @@ import ( "net/netip" "net/url" "os" - "os/exec" - "path/filepath" "reflect" "runtime" "slices" - "sort" "strconv" "strings" "sync" @@ -37,31 +33,26 @@ import ( "go4.org/mem" "go4.org/netipx" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/dns/dnsmessage" - "gvisor.dev/gvisor/pkg/tcpip" "tailscale.com/appc" "tailscale.com/client/tailscale/apitype" - "tailscale.com/clientupdate" "tailscale.com/control/controlclient" "tailscale.com/control/controlknobs" - "tailscale.com/doctor" - "tailscale.com/doctor/ethtool" - "tailscale.com/doctor/permissions" - "tailscale.com/doctor/routetable" "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/envknob/featureknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/health/healthmsg" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/conffile" "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnstate" - "tailscale.com/ipn/policy" "tailscale.com/log/sockstatlog" "tailscale.com/logpolicy" - "tailscale.com/net/captivedetection" "tailscale.com/net/dns" "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" @@ -71,21 +62,18 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/netutil" + "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/paths" - "tailscale.com/portlist" "tailscale.com/syncs" "tailscale.com/tailcfg" - "tailscale.com/taildrop" - "tailscale.com/tka" "tailscale.com/tsd" "tailscale.com/tstime" "tailscale.com/types/appctype" "tailscale.com/types/dnstype" "tailscale.com/types/empty" "tailscale.com/types/key" - "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -94,24 +82,25 @@ import ( "tailscale.com/types/preftype" "tailscale.com/types/ptr" "tailscale.com/types/views" - "tailscale.com/util/deephash" + "tailscale.com/util/checkchange" + "tailscale.com/util/clientmetric" "tailscale.com/util/dnsname" - "tailscale.com/util/httpm" + "tailscale.com/util/eventbus" + "tailscale.com/util/execqueue" + "tailscale.com/util/goroutines" "tailscale.com/util/mak" - "tailscale.com/util/multierr" - "tailscale.com/util/osshare" "tailscale.com/util/osuser" "tailscale.com/util/rands" "tailscale.com/util/set" - "tailscale.com/util/syspolicy" - "tailscale.com/util/systemd" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/ptype" "tailscale.com/util/testenv" - "tailscale.com/util/uniq" "tailscale.com/util/usermetric" "tailscale.com/version" "tailscale.com/version/distro" "tailscale.com/wgengine" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/router" @@ -154,14 +143,31 @@ func RegisterNewSSHServer(fn newSSHServerFunc) { newSSHServer = fn } -// watchSession represents a WatchNotifications channel +// watchSession represents a WatchNotifications channel, +// an [ipnauth.Actor] that owns it (e.g., a connected GUI/CLI), // and sessionID as required to close targeted buses. type watchSession struct { ch chan *ipn.Notify + owner ipnauth.Actor // or nil sessionID string - cancel func() // call to signal that the session must be terminated + cancel context.CancelFunc // to shut down the session } +var ( + // errShutdown indicates that the [LocalBackend.Shutdown] was called. + errShutdown = errors.New("shutting down") + + // errNodeContextChanged indicates that [LocalBackend] has switched + // to a different [localNodeContext], usually due to a profile change. + // It is used as a context cancellation cause for the old context + // and can be returned when an operation is performed on it. + errNodeContextChanged = errors.New("profile changed") + + // errManagedByPolicy indicates the operation is blocked + // because the target state is managed by a GP/MDM policy. + errManagedByPolicy = errors.New("managed by policy") +) + // LocalBackend is the glue between the major pieces of the Tailscale // network software: the cloud control plane (via controlclient), the // network data plane (via wgengine), and the user-facing UIs and CLIs @@ -174,35 +180,36 @@ type watchSession struct { // state machine generates events back out to zero or more components. type LocalBackend struct { // Elements that are thread-safe or constant after construction. - ctx context.Context // canceled by Close - ctxCancel context.CancelFunc // cancels ctx - logf logger.Logf // general logging - keyLogf logger.Logf // for printing list of peers on change - statsLogf logger.Logf // for printing peers stats on change - sys *tsd.System - health *health.Tracker // always non-nil - metrics metrics - e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys - store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys - dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys - pushDeviceToken syncs.AtomicValue[string] - backendLogID logid.PublicID - unregisterNetMon func() - unregisterHealthWatch func() - portpoll *portlist.Poller // may be nil - portpollOnce sync.Once // guards starting readPoller - varRoot string // or empty if SetVarRoot never called - logFlushFunc func() // or nil if SetLogFlusher wasn't called - em *expiryManager // non-nil - sshAtomicBool atomic.Bool + ctx context.Context // canceled by [LocalBackend.Shutdown] + ctxCancel context.CancelCauseFunc // cancels ctx + logf logger.Logf // general logging + keyLogf logger.Logf // for printing list of peers on change + statsLogf logger.Logf // for printing peers stats on change + sys *tsd.System + eventClient *eventbus.Client + appcTask execqueue.ExecQueue // handles updates from appc + + health *health.Tracker // always non-nil + polc policyclient.Client // always non-nil + metrics metrics + e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys + store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys + dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys + pushDeviceToken syncs.AtomicValue[string] + backendLogID logid.PublicID // or zero value if logging not in use + unregisterSysPolicyWatch func() + varRoot string // or empty if SetVarRoot never called + logFlushFunc func() // or nil if SetLogFlusher wasn't called + em *expiryManager // non-nil; TODO(nickkhyl): move to nodeBackend + sshAtomicBool atomic.Bool // TODO(nickkhyl): move to nodeBackend // webClientAtomicBool controls whether the web client is running. This should // be true unless the disable-web-client node attribute has been set. - webClientAtomicBool atomic.Bool + webClientAtomicBool atomic.Bool // TODO(nickkhyl): move to nodeBackend // exposeRemoteWebClientAtomicBool controls whether the web client is exposed over // Tailscale on port 5252. - exposeRemoteWebClientAtomicBool atomic.Bool - shutdownCalled bool // if Shutdown has been called - debugSink *capture.Sink + exposeRemoteWebClientAtomicBool atomic.Bool // TODO(nickkhyl): move to nodeBackend + shutdownCalled bool // if Shutdown has been called + debugSink packet.CaptureSink sockstatLogger *sockstatlog.Logger // getTCPHandlerForFunnelFlow returns a handler for an incoming TCP flow for @@ -221,93 +228,90 @@ type LocalBackend struct { // is never called. getTCPHandlerForFunnelFlow func(srcAddr netip.AddrPort, dstPort uint16) (handler func(net.Conn)) - filterAtomic atomic.Pointer[filter.Filter] - containsViaIPFuncAtomic syncs.AtomicValue[func(netip.Addr) bool] - shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool] - numClientStatusCalls atomic.Uint32 + containsViaIPFuncAtomic syncs.AtomicValue[func(netip.Addr) bool] // TODO(nickkhyl): move to nodeBackend + shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool] // TODO(nickkhyl): move to nodeBackend + shouldInterceptVIPServicesTCPPortAtomic syncs.AtomicValue[func(netip.AddrPort) bool] // TODO(nickkhyl): move to nodeBackend + numClientStatusCalls atomic.Uint32 // TODO(nickkhyl): move to nodeBackend + + // goTracker accounts for all goroutines started by LocalBacked, primarily + // for testing and graceful shutdown purposes. + goTracker goroutines.Tracker + + startOnce sync.Once // protects the one‑time initialization in [LocalBackend.Start] + + // extHost is the bridge between [LocalBackend] and the registered [ipnext.Extension]s. + // It may be nil in tests that use direct composite literal initialization of [LocalBackend] + // instead of calling [NewLocalBackend]. A nil pointer is a valid, no-op host. + // It can be used with or without b.mu held, but is typically used with it held + // to prevent state changes while invoking callbacks. + extHost *ExtensionHost // The mutex protects the following elements. - mu sync.Mutex - conf *conffile.Config // latest parsed config, or nil if not in declarative mode - pm *profileManager // mu guards access - filterHash deephash.Sum - httpTestClient *http.Client // for controlclient. nil by default, used by tests. - ccGen clientGen // function for producing controlclient; lazily populated - sshServer SSHServer // or nil, initialized lazily. - appConnector *appc.AppConnector // or nil, initialized when configured. + mu syncs.Mutex + + // currentNodeAtomic is the current node context. It is always non-nil. + // It must be re-created when [LocalBackend] switches to a different profile/node + // (see tailscale/corp#28014 for a bug), but can be mutated in place (via its methods) + // while [LocalBackend] represents the same node. + // + // It is safe for reading with or without holding b.mu, but mutating it in place + // or creating a new one must be done with b.mu held. If both mutexes must be held, + // the LocalBackend's mutex must be acquired first before acquiring the nodeBackend's mutex. + // + // We intend to relax this in the future and only require holding b.mu when replacing it, + // but that requires a better (strictly ordered?) state machine and better management + // of [LocalBackend]'s own state that is not tied to the node context. + currentNodeAtomic atomic.Pointer[nodeBackend] + + conf *conffile.Config // latest parsed config, or nil if not in declarative mode + pm *profileManager // mu guards access + lastFilterInputs *filterInputs + httpTestClient *http.Client // for controlclient. nil by default, used by tests. + ccGen clientGen // function for producing controlclient; lazily populated + sshServer SSHServer // or nil, initialized lazily. + appConnector *appc.AppConnector // or nil, initialized when configured. // notifyCancel cancels notifications to the current SetNotifyCallback. - notifyCancel context.CancelFunc - cc controlclient.Client - ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto + notifyCancel context.CancelFunc + cc controlclient.Client // TODO(nickkhyl): move to nodeBackend + ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto; TODO(nickkhyl): move to nodeBackend + + // ignoreControlClientUpdates indicates whether we want to ignore SetControlClientStatus updates + // before acquiring b.mu. This is used during shutdown to avoid deadlocks. + ignoreControlClientUpdates atomic.Bool + machinePrivKey key.MachinePrivate - tka *tkaState - state ipn.State - capFileSharing bool // whether netMap contains the file sharing capability - capTailnetLock bool // whether netMap contains the tailnet lock capability + tka *tkaState // TODO(nickkhyl): move to nodeBackend + state ipn.State // TODO(nickkhyl): move to nodeBackend + capTailnetLock bool // whether netMap contains the tailnet lock capability // hostinfo is mutated in-place while mu is held. - hostinfo *tailcfg.Hostinfo - // netMap is the most recently set full netmap from the controlclient. - // It can't be mutated in place once set. Because it can't be mutated in place, - // delta updates from the control server don't apply to it. Instead, use - // the peers map to get up-to-date information on the state of peers. - // In general, avoid using the netMap.Peers slice. We'd like it to go away - // as of 2023-09-17. - netMap *netmap.NetworkMap - // peers is the set of current peers and their current values after applying - // delta node mutations as they come in (with mu held). The map values can - // be given out to callers, but the map itself must not escape the LocalBackend. - peers map[tailcfg.NodeID]tailcfg.NodeView - nodeByAddr map[netip.Addr]tailcfg.NodeID // by Node.Addresses only (not subnet routes) - nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil - activeLogin string // last logged LoginName from netMap - engineStatus ipn.EngineStatus - endpoints []tailcfg.Endpoint - blocked bool - keyExpired bool - authURL string // non-empty if not Running - authURLTime time.Time // when the authURL was received from the control server - interact bool // indicates whether a user requested interactive login - egg bool - prevIfState *netmon.State - peerAPIServer *peerAPIServer // or nil - peerAPIListeners []*peerAPIListener - loginFlags controlclient.LoginFlags - fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs - notifyWatchers map[string]*watchSession // by session ID - lastStatusTime time.Time // status.AsOf value of the last processed status update - // directFileRoot, if non-empty, means to write received files - // directly to this directory, without staging them in an - // intermediate buffered directory for "pick-up" later. If - // empty, the files are received in a daemon-owned location - // and the localapi is used to enumerate, download, and delete - // them. This is used on macOS where the GUI lifetime is the - // same as the Network Extension lifetime and we can thus avoid - // double-copying files by writing them to the right location - // immediately. - // It's also used on several NAS platforms (Synology, TrueNAS, etc) - // but in that case DoFinalRename is also set true, which moves the - // *.partial file to its final name on completion. - directFileRoot string + hostinfo *tailcfg.Hostinfo // TODO(nickkhyl): move to nodeBackend + nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil; TODO(nickkhyl): move to nodeBackend + activeLogin string // last logged LoginName from netMap; TODO(nickkhyl): move to nodeBackend (or remove? it's in [ipn.LoginProfile]). + engineStatus ipn.EngineStatus + endpoints []tailcfg.Endpoint + blocked bool + keyExpired bool // TODO(nickkhyl): move to nodeBackend + authURL string // non-empty if not Running; TODO(nickkhyl): move to nodeBackend + authURLTime time.Time // when the authURL was received from the control server; TODO(nickkhyl): move to nodeBackend + authActor ipnauth.Actor // an actor who called [LocalBackend.StartLoginInteractive] last, or nil; TODO(nickkhyl): move to nodeBackend + egg bool + prevIfState *netmon.State + peerAPIServer *peerAPIServer // or nil + peerAPIListeners []*peerAPIListener + loginFlags controlclient.LoginFlags + notifyWatchers map[string]*watchSession // by session ID + lastStatusTime time.Time // status.AsOf value of the last processed status update componentLogUntil map[string]componentLogState - // c2nUpdateStatus is the status of c2n-triggered client update. - c2nUpdateStatus updateStatus - currentUser ipnauth.Actor - selfUpdateProgress []ipnstate.UpdateProgress - lastSelfUpdateState ipnstate.SelfUpdateStatus + currentUser ipnauth.Actor + // capForcedNetfilter is the netfilter that control instructs Linux clients // to use, unless overridden locally. - capForcedNetfilter string - // offlineAutoUpdateCancel stops offline auto-updates when called. It - // should be used via stopOfflineAutoUpdate and - // maybeStartOfflineAutoUpdate. It is nil when offline auto-updates are - // note running. - // - //lint:ignore U1000 only used in Linux and Windows builds in autoupdate.go - offlineAutoUpdateCancel func() + capForcedNetfilter string // TODO(nickkhyl): move to nodeBackend // ServeConfig fields. (also guarded by mu) - lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig - serveConfig ipn.ServeConfigView // or !Valid if none + lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig + serveConfig ipn.ServeConfigView // or !Valid if none + ipVIPServiceMap netmap.IPServiceMappings // map of VIPService IPs to their corresponding service names; TODO(nickkhyl): move to nodeBackend webClient webClient webClientListeners map[netip.AddrPort]*localListener // listeners for local web client traffic @@ -315,14 +319,9 @@ type LocalBackend struct { serveListeners map[netip.AddrPort]*localListener // listeners for local serve traffic serveProxyHandlers sync.Map // string (HTTPHandler.Proxy) => *reverseProxy - // statusLock must be held before calling statusChanged.Wait() or - // statusChanged.Broadcast(). - statusLock sync.Mutex - statusChanged *sync.Cond - // dialPlan is any dial plan that we've received from the control // server during a previous connection; it is cleared on logout. - dialPlan atomic.Pointer[tailcfg.ControlDialPlan] + dialPlan atomic.Pointer[tailcfg.ControlDialPlan] // TODO(nickkhyl): maybe move to nodeBackend? // tkaSyncLock is used to make tkaSyncIfNeeded an exclusive // section. This is needed to stop two map-responses in quick succession @@ -330,28 +329,31 @@ type LocalBackend struct { // // tkaSyncLock MUST be taken before mu (or inversely, mu must not be held // at the moment that tkaSyncLock is taken). - tkaSyncLock sync.Mutex + tkaSyncLock syncs.Mutex clock tstime.Clock // Last ClientVersion received in MapResponse, guarded by mu. lastClientVersion *tailcfg.ClientVersion // lastNotifiedDriveSharesMu guards lastNotifiedDriveShares - lastNotifiedDriveSharesMu sync.Mutex + lastNotifiedDriveSharesMu syncs.Mutex // lastNotifiedDriveShares keeps track of the last set of shares that we // notified about. lastNotifiedDriveShares *views.SliceView[*drive.Share, drive.ShareView] - // outgoingFiles keeps track of Taildrop outgoing files keyed to their OutgoingFile.ID - outgoingFiles map[string]*ipn.OutgoingFile - // lastSuggestedExitNode stores the last suggested exit node suggestion to // avoid unnecessary churn between multiple equally-good options. lastSuggestedExitNode tailcfg.StableNodeID + // allowedSuggestedExitNodes is a set of exit nodes permitted by the most recent + // [pkey.AllowedSuggestedExitNodes] value. The allowedSuggestedExitNodesMu + // mutex guards access to this set. + allowedSuggestedExitNodesMu sync.Mutex + allowedSuggestedExitNodes set.Set[tailcfg.StableNodeID] + // refreshAutoExitNode indicates if the exit node should be recomputed when the next netcheck report is available. - refreshAutoExitNode bool + refreshAutoExitNode bool // guarded by mu // captiveCtx and captiveCancel are used to control captive portal // detection. They are protected by 'mu' and can be changed during the @@ -367,13 +369,56 @@ type LocalBackend struct { // backend is healthy and captive portal detection is not required // (sending false). needsCaptiveDetection chan bool + + // overrideAlwaysOn is whether [pkey.AlwaysOn] is overridden by the user + // and should have no impact on the WantRunning state until the policy changes, + // or the user re-connects manually, switches to a different profile, etc. + // Notably, this is true when [pkey.AlwaysOnOverrideWithReason] is enabled, + // and the user has disconnected with a reason. + // See tailscale/corp#26146. + overrideAlwaysOn bool + + // reconnectTimer is used to schedule a reconnect by setting [ipn.Prefs.WantRunning] + // to true after a delay, or nil if no reconnect is scheduled. + reconnectTimer tstime.TimerController + + // overrideExitNodePolicy is whether the user has overridden the exit node policy + // by manually selecting an exit node, as allowed by [pkey.AllowExitNodeOverride]. + // + // If true, the [pkey.ExitNodeID] and [pkey.ExitNodeIP] policy settings are ignored, + // and the suggested exit node is not applied automatically. + // + // It is cleared when the user switches back to the state required by policy (typically, auto:any), + // or when switching profiles, connecting/disconnecting Tailscale, restarting the client, + // or on similar events. + // + // See tailscale/corp#29969. + overrideExitNodePolicy bool + + // hardwareAttested is whether backend should use a hardware-backed key to + // bind the node identity to this device. + hardwareAttested atomic.Bool } -// HealthTracker returns the health tracker for the backend. -func (b *LocalBackend) HealthTracker() *health.Tracker { - return b.health +// SetHardwareAttested enables hardware attestation key signatures in map +// requests, if supported on this platform. SetHardwareAttested should be called +// before Start. +func (b *LocalBackend) SetHardwareAttested() { + b.hardwareAttested.Store(true) } +// HardwareAttested reports whether hardware-backed attestation keys should be +// used to bind the node's identity to this device. +func (b *LocalBackend) HardwareAttested() bool { + return b.hardwareAttested.Load() +} + +// HealthTracker returns the health tracker for the backend. +func (b *LocalBackend) HealthTracker() *health.Tracker { return b.health } + +// Logger returns the logger for the backend. +func (b *LocalBackend) Logger() logger.Logf { return b.logf } + // UserMetricsRegistry returns the usermetrics registry for the backend func (b *LocalBackend) UserMetricsRegistry() *usermetric.Registry { return b.sys.UserMetricsRegistry() @@ -384,9 +429,8 @@ func (b *LocalBackend) NetMon() *netmon.Monitor { return b.sys.NetMon.Get() } -type updateStatus struct { - started bool -} +// PolicyClient returns the policy client for the backend. +func (b *LocalBackend) PolicyClient() policyclient.Client { return b.polc } type metrics struct { // advertisedRoutes is a metric that reports the number of network routes that are advertised by the local node. @@ -396,11 +440,6 @@ type metrics struct { // approvedRoutes is a metric that reports the number of network routes served by the local node and approved // by the control server. approvedRoutes *usermetric.Gauge - - // primaryRoutes is a metric that reports the number of primary network routes served by the local node. - // A route being a primary route implies that the route is currently served by this node, and not by another - // subnet router in a high availability configuration. - primaryRoutes *usermetric.Gauge } // clientGen is a func that creates a control plane client. @@ -411,7 +450,9 @@ type clientGen func(controlclient.Options) (controlclient.Client, error) // but is not actually running. // // If dialer is nil, a new one is made. -func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, loginFlags controlclient.LoginFlags) (*LocalBackend, error) { +// +// The logID may be the zero value if logging is not in use. +func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, loginFlags controlclient.LoginFlags) (_ *LocalBackend, err error) { e := sys.Engine.Get() store := sys.StateStore.Get() dialer := sys.Dialer.Get() @@ -427,7 +468,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo if loginFlags&controlclient.LocalBackendStartKeyOSNeutral != 0 { goos = "" } - pm, err := newProfileManagerWithGOOS(store, logf, sys.HealthTracker(), goos) + pm, err := newProfileManagerWithGOOS(store, logf, sys.HealthTracker.Get(), goos) if err != nil { return nil, err } @@ -436,9 +477,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } envknob.LogCurrent(logf) - osshare.SetFileSharingEnabled(false, logf) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancelCause(context.Background()) clock := tstime.StdClock{} // Until we transition to a Running state, use a canceled context for @@ -451,8 +491,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo "tailscaled_advertised_routes", "Number of advertised network routes (e.g. by a subnet router)"), approvedRoutes: sys.UserMetricsRegistry().NewGauge( "tailscaled_approved_routes", "Number of approved network routes (e.g. by a subnet router)"), - primaryRoutes: sys.UserMetricsRegistry().NewGauge( - "tailscaled_primary_routes", "Number of network routes for which this node is a primary router (in high availability configuration)"), } b := &LocalBackend{ @@ -462,7 +500,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo keyLogf: logger.LogOnChange(logf, 5*time.Minute, clock.Now), statsLogf: logger.LogOnChange(logf, 5*time.Minute, clock.Now), sys: sys, - health: sys.HealthTracker(), + polc: sys.PolicyClientOrDefault(), + health: sys.HealthTracker.Get(), metrics: m, e: e, dialer: dialer, @@ -470,26 +509,40 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo pm: pm, backendLogID: logID, state: ipn.NoState, - portpoll: new(portlist.Poller), - em: newExpiryManager(logf), + em: newExpiryManager(logf, sys.Bus.Get()), loginFlags: loginFlags, clock: clock, - selfUpdateProgress: make([]ipnstate.UpdateProgress, 0), - lastSelfUpdateState: ipnstate.UpdateFinished, captiveCtx: captiveCtx, captiveCancel: nil, // so that we start checkCaptivePortalLoop when Running needsCaptiveDetection: make(chan bool), } - mConn.SetNetInfoCallback(b.setNetInfo) + + nb := newNodeBackend(ctx, b.logf, b.sys.Bus.Get()) + b.currentNodeAtomic.Store(nb) + nb.ready() if sys.InitialConfig != nil { - if err := b.setConfigLocked(sys.InitialConfig); err != nil { + if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err } } + if b.extHost, err = NewExtensionHost(logf, b); err != nil { + return nil, fmt.Errorf("failed to create extension host: %w", err) + } + b.pm.SetExtensionHost(b.extHost) + + if b.unregisterSysPolicyWatch, err = b.registerSysPolicyWatch(); err != nil { + return nil, err + } + defer func() { + if err != nil { + b.unregisterSysPolicyWatch() + } + }() + netMon := sys.NetMon.Get() - b.sockstatLogger, err = sockstatlog.NewLogger(logpolicy.LogsDir(logf), logf, logID, netMon, sys.HealthTracker()) + b.sockstatLogger, err = sockstatlog.NewLogger(logpolicy.LogsDir(logf), logf, logID, netMon, sys.HealthTracker.Get(), sys.Bus.Get()) if err != nil { log.Printf("error setting up sockstat logger: %v", err) } @@ -505,49 +558,116 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.setTCPPortsIntercepted(nil) - b.statusChanged = sync.NewCond(&b.statusLock) b.e.SetStatusCallback(b.setWgengineStatus) b.prevIfState = netMon.InterfaceState() - // Call our linkChange code once with the current state, and - // then also whenever it changes: + // Call our linkChange code once with the current state. + // Following changes are triggered via the eventbus. b.linkChange(&netmon.ChangeDelta{New: netMon.InterfaceState()}) - b.unregisterNetMon = netMon.RegisterChangeCallback(b.linkChange) - b.unregisterHealthWatch = b.health.RegisterWatcher(b.onHealthChange) - - if tunWrap, ok := b.sys.Tun.GetOK(); ok { - tunWrap.PeerAPIPort = b.GetPeerAPIPort - } else { - b.logf("[unexpected] failed to wire up PeerAPI port for engine %T", e) + if buildfeatures.HasPeerAPIServer { + if tunWrap, ok := b.sys.Tun.GetOK(); ok { + tunWrap.PeerAPIPort = b.GetPeerAPIPort + } else { + b.logf("[unexpected] failed to wire up PeerAPI port for engine %T", e) + } } - for _, component := range ipn.DebuggableComponents { - key := componentStateKey(component) - if ut, err := ipn.ReadStoreInt(pm.Store(), key); err == nil { - if until := time.Unix(ut, 0); until.After(b.clock.Now()) { - // conditional to avoid log spam at start when off - b.SetComponentDebugLogging(component, until) + if buildfeatures.HasDebug { + for _, component := range ipn.DebuggableComponents { + key := componentStateKey(component) + if ut, err := ipn.ReadStoreInt(pm.Store(), key); err == nil { + if until := time.Unix(ut, 0); until.After(b.clock.Now()) { + // conditional to avoid log spam at start when off + b.SetComponentDebugLogging(component, until) + } } } } - // initialize Taildrive shares from saved state - fs, ok := b.sys.DriveForRemote.GetOK() - if ok { - currentShares := b.pm.prefs.DriveShares() - if currentShares.Len() > 0 { - var shares []*drive.Share - for _, share := range currentShares.All() { - shares = append(shares, share.AsStruct()) - } - fs.SetShares(shares) - } + // Start the event bus late, once all the assignments above are done. + // (See previous race in tailscale/tailscale#17252) + ec := b.Sys().Bus.Get().Client("ipnlocal.LocalBackend") + b.eventClient = ec + eventbus.SubscribeFunc(ec, b.onClientVersion) + eventbus.SubscribeFunc(ec, func(au controlclient.AutoUpdate) { + b.onTailnetDefaultAutoUpdate(au.Value) + }) + eventbus.SubscribeFunc(ec, func(cd netmon.ChangeDelta) { b.linkChange(&cd) }) + if buildfeatures.HasHealth { + eventbus.SubscribeFunc(ec, b.onHealthChange) } + if buildfeatures.HasPortList { + eventbus.SubscribeFunc(ec, b.setPortlistServices) + } + eventbus.SubscribeFunc(ec, b.onAppConnectorRouteUpdate) + eventbus.SubscribeFunc(ec, b.onAppConnectorStoreRoutes) + mConn.SetNetInfoCallback(b.setNetInfo) // TODO(tailscale/tailscale#17887): move to eventbus return b, nil } +func (b *LocalBackend) onAppConnectorRouteUpdate(ru appctype.RouteUpdate) { + // TODO(creachadair, 2025-10-02): It is currently possible for updates produced under + // one profile to arrive and be applied after a switch to another profile. + // We need to find a way to ensure that changes to the backend state are applied + // consistently in the presnce of profile changes, which currently may not happen in + // a single atomic step. See: https://github.com/tailscale/tailscale/issues/17414 + b.appcTask.Add(func() { + if err := b.AdvertiseRoute(ru.Advertise...); err != nil { + b.logf("appc: failed to advertise routes: %v: %v", ru.Advertise, err) + } + if err := b.UnadvertiseRoute(ru.Unadvertise...); err != nil { + b.logf("appc: failed to unadvertise routes: %v: %v", ru.Unadvertise, err) + } + }) +} + +func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) { + // Whether or not routes should be stored can change over time. + shouldStoreRoutes := b.ControlKnobs().AppCStoreRoutes.Load() + if shouldStoreRoutes { + if err := b.storeRouteInfo(ri); err != nil { + b.logf("appc: failed to store route info: %v", err) + } + } +} + +func (b *LocalBackend) Clock() tstime.Clock { return b.clock } +func (b *LocalBackend) Sys() *tsd.System { return b.sys } + +// NodeBackend returns the current node's NodeBackend interface. +func (b *LocalBackend) NodeBackend() ipnext.NodeBackend { + return b.currentNode() +} + +func (b *LocalBackend) currentNode() *nodeBackend { + if v := b.currentNodeAtomic.Load(); v != nil || !testenv.InTest() { + return v + } + v := newNodeBackend(cmp.Or(b.ctx, context.Background()), b.logf, b.sys.Bus.Get()) + if b.currentNodeAtomic.CompareAndSwap(nil, v) { + v.ready() + } + return b.currentNodeAtomic.Load() +} + +// FindExtensionByName returns an active extension with the given name, +// or nil if no such extension exists. +func (b *LocalBackend) FindExtensionByName(name string) any { + return b.extHost.Extensions().FindExtensionByName(name) +} + +// FindMatchingExtension finds the first active extension that matches target, +// and if one is found, sets target to that extension and returns true. +// Otherwise, it returns false. +// +// It panics if target is not a non-nil pointer to either a type +// that implements [ipnext.Extension], or to any interface type. +func (b *LocalBackend) FindMatchingExtension(target any) bool { + return b.extHost.Extensions().FindMatchingExtension(target) +} + type componentLogState struct { until time.Time timer tstime.TimerController // if non-nil, the AfterFunc to disable it @@ -565,6 +685,9 @@ func componentStateKey(component string) ipn.StateKey { // - magicsock // - sockstats func (b *LocalBackend) SetComponentDebugLogging(component string, until time.Time) error { + if !buildfeatures.HasDebug { + return feature.ErrUnavailable + } b.mu.Lock() defer b.mu.Unlock() @@ -583,6 +706,8 @@ func (b *LocalBackend) SetComponentDebugLogging(component string, until time.Tim } } } + case "syspolicy": + setEnabled = b.polc.SetDebugLoggingEnabled } if setEnabled == nil || !slices.Contains(ipn.DebuggableComponents, component) { return fmt.Errorf("unknown component %q", component) @@ -626,6 +751,9 @@ func (b *LocalBackend) SetComponentDebugLogging(component string, until time.Tim // GetDNSOSConfig returns the base OS DNS configuration, as seen by the DNS manager. func (b *LocalBackend) GetDNSOSConfig() (dns.OSConfig, error) { + if !buildfeatures.HasDNS { + panic("unreachable") + } manager, ok := b.sys.DNSManager.GetOK() if !ok { return dns.OSConfig{}, errors.New("DNS manager not available") @@ -637,6 +765,9 @@ func (b *LocalBackend) GetDNSOSConfig() (dns.OSConfig, error) { // the raw DNS response and the resolvers that are were able to handle the query (the internal forwarder // may race multiple resolvers). func (b *LocalBackend) QueryDNS(name string, queryType dnsmessage.Type) (res []byte, resolvers []*dnstype.Resolver, err error) { + if !buildfeatures.HasDNS { + return nil, nil, feature.ErrUnavailable + } manager, ok := b.sys.DNSManager.GetOK() if !ok { return nil, nil, errors.New("DNS manager not available") @@ -681,6 +812,9 @@ func (b *LocalBackend) QueryDNS(name string, queryType dnsmessage.Type) (res []b // enabled until, or the zero time if component's time is not currently // enabled. func (b *LocalBackend) GetComponentDebugLogging(component string) time.Time { + if !buildfeatures.HasDebug { + return time.Time{} + } b.mu.Lock() defer b.mu.Unlock() @@ -698,17 +832,6 @@ func (b *LocalBackend) Dialer() *tsdial.Dialer { return b.dialer } -// SetDirectFileRoot sets the directory to download files to directly, -// without buffering them through an intermediate daemon-owned -// tailcfg.UserID-specific directory. -// -// This must be called before the LocalBackend starts being used. -func (b *LocalBackend) SetDirectFileRoot(dir string) { - b.mu.Lock() - defer b.mu.Unlock() - b.directFileRoot = dir -} - // ReloadConfig reloads the backend's config from disk. // // It returns (false, nil) if not running in declarative mode, (true, nil) on @@ -730,11 +853,14 @@ func (b *LocalBackend) ReloadConfig() (ok bool, err error) { return true, nil } -func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { - - // TODO(irbekrm): notify the relevant components to consume any prefs - // updates. Currently only initial configfile settings are applied - // immediately. +// initPrefsFromConfig initializes the backend's prefs from the provided config. +// This should only be called once, at startup. For updates at runtime, use +// [LocalBackend.setConfigLocked]. +func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error { + // TODO(maisem,bradfitz): combine this with setConfigLocked. This is called + // before anything is running, so there's no need to lock and we don't + // update any subsystems. At runtime, we both need to lock and update + // subsystems with the new prefs. p := b.pm.CurrentPrefs().AsStruct() mp, err := conf.Parsed.ToPrefs() if err != nil { @@ -744,13 +870,15 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { if err := b.pm.SetPrefs(p.View(), ipn.NetworkProfile{}); err != nil { return err } + b.updateWarnSync(p.View()) + b.setStaticEndpointsFromConfigLocked(conf) + b.conf = conf + return nil +} - defer func() { - b.conf = conf - }() - +func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config) { if conf.Parsed.StaticEndpoints == nil && (b.conf == nil || b.conf.Parsed.StaticEndpoints == nil) { - return nil + return } // Ensure that magicsock conn has the up to date static wireguard @@ -764,6 +892,31 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { ms.SetStaticEndpoints(views.SliceOf(conf.Parsed.StaticEndpoints)) } } +} + +func (b *LocalBackend) setStateLocked(state ipn.State) { + if b.state == state { + return + } + b.state = state + for _, f := range b.extHost.Hooks().BackendStateChange { + f(state) + } +} + +// setConfigLocked uses the provided config to update the backend's prefs +// and other state. +func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { + p := b.pm.CurrentPrefs().AsStruct() + mp, err := conf.Parsed.ToPrefs() + if err != nil { + return fmt.Errorf("error parsing config to prefs: %w", err) + } + p.ApplyEdits(&mp) + b.setStaticEndpointsFromConfigLocked(conf) + b.setPrefsLocked(p) + + b.conf = conf return nil } @@ -779,12 +932,26 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { return } networkUp := b.prevIfState.AnyInterfaceUp() - b.cc.SetPaused((b.state == ipn.Stopped && b.netMap != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest())) + pauseForNetwork := (b.state == ipn.Stopped && b.NetMap() != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest()) + + prefs := b.pm.CurrentPrefs() + pauseForSyncPref := prefs.Valid() && prefs.Sync().EqualBool(false) + + b.cc.SetPaused(pauseForNetwork || pauseForSyncPref) } -// captivePortalDetectionInterval is the duration to wait in an unhealthy state with connectivity broken -// before running captive portal detection. -const captivePortalDetectionInterval = 2 * time.Second +// DisconnectControl shuts down control client. This can be run before node shutdown to force control to consider this ndoe +// inactive. This can be used to ensure that nodes that are HA subnet router or app connector replicas are shutting +// down, clients switch over to other replicas whilst the existing connections are kept alive for some period of time. +func (b *LocalBackend) DisconnectControl() { + b.mu.Lock() + defer b.mu.Unlock() + cc := b.resetControlClientLocked() + if cc == nil { + return + } + cc.Shutdown() +} // linkChange is our network monitor callback, called whenever the network changes. func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { @@ -795,13 +962,14 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { hadPAC := b.prevIfState.HasPAC() b.prevIfState = ifst b.pauseOrResumeControlClientLocked() - if delta.Major && shouldAutoExitNode() { + prefs := b.pm.CurrentPrefs() + if delta.Major && prefs.AutoExitNode().IsSet() { b.refreshAutoExitNode = true } var needReconfig bool // If the network changed and we're using an exit node and allowing LAN access, we may need to reconfigure. - if delta.Major && b.pm.CurrentPrefs().ExitNodeID() != "" && b.pm.CurrentPrefs().ExitNodeAllowLANAccess() { + if delta.Major && prefs.ExitNodeID() != "" && prefs.ExitNodeAllowLANAccess() { b.logf("linkChange: in state %v; updating LAN routes", b.state) needReconfig = true } @@ -818,29 +986,48 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { // TODO(raggi,tailscale/corp#22574): authReconfig should be refactored such that we can call the // necessary operations here and avoid the need for asynchronous behavior that is racy and hard // to test here, and do less extra work in these conditions. - go b.authReconfig() + b.goTracker.Go(b.authReconfig) } } // If the local network configuration has changed, our filter may // need updating to tweak default routes. - b.updateFilterLocked(b.netMap, b.pm.CurrentPrefs()) - updateExitNodeUsageWarning(b.pm.CurrentPrefs(), delta.New, b.health) - - if peerAPIListenAsync && b.netMap != nil && b.state == ipn.Running { - want := b.netMap.GetAddresses().Len() - if len(b.peerAPIListeners) < want { - b.logf("linkChange: peerAPIListeners too low; trying again") - go b.initPeerAPIListener() + b.updateFilterLocked(prefs) + updateExitNodeUsageWarning(prefs, delta.New, b.health) + + if buildfeatures.HasPeerAPIServer { + cn := b.currentNode() + nm := cn.NetMap() + if peerAPIListenAsync && nm != nil && b.state == ipn.Running { + want := nm.GetAddresses().Len() + have := len(b.peerAPIListeners) + b.logf("[v1] linkChange: have %d peerAPIListeners, want %d", have, want) + if have < want { + b.logf("linkChange: peerAPIListeners too low; trying again") + b.goTracker.Go(b.initPeerAPIListener) + } } } } -func (b *LocalBackend) onHealthChange(w *health.Warnable, us *health.UnhealthyState) { - if us == nil { - b.logf("health(warnable=%s): ok", w.Code) - } else { - b.logf("health(warnable=%s): error: %s", w.Code, us.Text) +// Captive portal detection hooks. +var ( + hookCaptivePortalHealthChange feature.Hook[func(*LocalBackend, *health.State)] + hookCheckCaptivePortalLoop feature.Hook[func(*LocalBackend, context.Context)] +) + +func (b *LocalBackend) onHealthChange(change health.Change) { + if !buildfeatures.HasHealth { + return + } + if change.WarnableChanged { + w := change.Warnable + us := change.UnhealthyState + if us == nil { + b.logf("health(warnable=%s): ok", w.Code) + } else { + b.logf("health(warnable=%s): error: %s", w.Code, us.Text) + } } // Whenever health changes, send the current health state to the frontend. @@ -849,51 +1036,64 @@ func (b *LocalBackend) onHealthChange(w *health.Warnable, us *health.UnhealthySt Health: state, }) - isConnectivityImpacted := false - for _, w := range state.Warnings { - // Ignore the captive portal warnable itself. - if w.ImpactsConnectivity && w.WarnableCode != captivePortalWarnable.Code { - isConnectivityImpacted = true - break - } + if f, ok := hookCaptivePortalHealthChange.GetOk(); ok { + f(b, state) } +} - // captiveCtx can be changed, and is protected with 'mu'; grab that - // before we start our select, below. - // - // It is guaranteed to be non-nil. +// GetOrSetCaptureSink returns the current packet capture sink, creating it +// with the provided newSink function if it does not already exist. +func (b *LocalBackend) GetOrSetCaptureSink(newSink func() packet.CaptureSink) packet.CaptureSink { + if !buildfeatures.HasCapture { + return nil + } b.mu.Lock() - ctx := b.captiveCtx - b.mu.Unlock() + defer b.mu.Unlock() - // If the context is canceled, we don't need to do anything. - if ctx.Err() != nil { - return + if b.debugSink != nil { + return b.debugSink } + s := newSink() + b.debugSink = s + b.e.InstallCaptureHook(s.CaptureCallback()) + return s +} - if isConnectivityImpacted { - b.logf("health: connectivity impacted; triggering captive portal detection") +func (b *LocalBackend) ClearCaptureSink() { + if !buildfeatures.HasCapture { + return + } + // Shut down & uninstall the sink if there are no longer + // any outputs on it. + b.mu.Lock() + defer b.mu.Unlock() - // Ensure that we select on captiveCtx so that we can time out - // triggering captive portal detection if the backend is shutdown. - select { - case b.needsCaptiveDetection <- true: - case <-ctx.Done(): - } - } else { - // If connectivity is not impacted, we know for sure we're not behind a captive portal, - // so drop any warning, and signal that we don't need captive portal detection. - b.health.SetHealthy(captivePortalWarnable) - select { - case b.needsCaptiveDetection <- false: - case <-ctx.Done(): - } + select { + case <-b.ctx.Done(): + return + default: + } + if b.debugSink != nil && b.debugSink.NumOutputs() == 0 { + s := b.debugSink + b.e.InstallCaptureHook(nil) + b.debugSink = nil + s.Close() } } // Shutdown halts the backend and all its sub-components. The backend // can no longer be used after Shutdown returns. func (b *LocalBackend) Shutdown() { + // Close the [eventbus.Client] to wait for subscribers to + // return before acquiring b.mu: + // 1. Event handlers also acquire b.mu, they can deadlock with c.Shutdown(). + // 2. Event handlers may not guard against undesirable post/in-progress + // LocalBackend.Shutdown() behaviors. + b.appcTask.Shutdown() + b.eventClient.Close() + + b.em.close() + b.mu.Lock() if b.shutdownCalled { b.mu.Unlock() @@ -901,17 +1101,19 @@ func (b *LocalBackend) Shutdown() { } b.shutdownCalled = true - if b.captiveCancel != nil { + if buildfeatures.HasCaptivePortal && b.captiveCancel != nil { b.logf("canceling captive portal context") b.captiveCancel() } + b.stopReconnectTimerLocked() + if b.loginFlags&controlclient.LoginEphemeral != 0 { b.mu.Unlock() ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) defer cancel() t0 := time.Now() - err := b.Logout(ctx) // best effort + err := b.Logout(ctx, ipnauth.Self) // best effort td := time.Since(t0).Round(time.Millisecond) if err != nil { b.logf("failed to log out ephemeral node on shutdown after %v: %v", td, err) @@ -934,6 +1136,7 @@ func (b *LocalBackend) Shutdown() { if b.notifyCancel != nil { b.notifyCancel() } + b.appConnector.Close() b.mu.Unlock() b.webClientShutdown() @@ -942,19 +1145,43 @@ func (b *LocalBackend) Shutdown() { defer cancel() b.sockstatLogger.Shutdown(ctx) } - if b.peerAPIServer != nil { - b.peerAPIServer.taildrop.Shutdown() - } - b.stopOfflineAutoUpdate() - b.unregisterNetMon() - b.unregisterHealthWatch() + b.unregisterSysPolicyWatch() if cc != nil { cc.Shutdown() } - b.ctxCancel() + b.ctxCancel(errShutdown) + b.currentNode().shutdown(errShutdown) + b.extHost.Shutdown() b.e.Close() <-b.e.Done() + b.awaitNoGoroutinesInTest() +} + +func (b *LocalBackend) awaitNoGoroutinesInTest() { + if !buildfeatures.HasDebug || !testenv.InTest() { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + ch := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { ch <- true })() + + for { + n := b.goTracker.RunningGoroutines() + if n == 0 { + return + } + select { + case <-ctx.Done(): + // TODO(bradfitz): pass down some TB-like failer interface from + // tests, without depending on testing from here? + // But this is fine in tests too: + panic(fmt.Sprintf("timeout waiting for %d goroutines to stop", n)) + case <-ch: + } + } } func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { @@ -963,10 +1190,10 @@ func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { } p2 := p.AsStruct() - p2.Persist.LegacyFrontendPrivateMachineKey = key.MachinePrivate{} p2.Persist.PrivateNodeKey = key.NodePrivate{} p2.Persist.OldPrivateNodeKey = key.NodePrivate{} p2.Persist.NetworkLockKey = key.NLPrivate{} + p2.Persist.AttestationKey = nil return p2.View() } @@ -981,6 +1208,13 @@ func (b *LocalBackend) sanitizedPrefsLocked() ipn.PrefsView { return stripKeysFromPrefs(b.pm.CurrentPrefs()) } +// unsanitizedPersist returns the current PersistView, including any private keys. +func (b *LocalBackend) unsanitizedPersist() persist.PersistView { + b.mu.Lock() + defer b.mu.Unlock() + return b.pm.CurrentPrefs().Persist() +} + // Status returns the latest status of the backend and its // sub-components. func (b *LocalBackend) Status() *ipnstate.Status { @@ -1004,6 +1238,8 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { b.mu.Lock() defer b.mu.Unlock() + cn := b.currentNode() + nm := cn.NetMap() sb.MutateStatus(func(s *ipnstate.Status) { s.Version = version.Long() s.TUN = !b.sys.IsNetstack() @@ -1020,28 +1256,24 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { if m := b.sshOnButUnusableHealthCheckMessageLocked(); m != "" { s.Health = append(s.Health, m) } - if b.netMap != nil { - s.CertDomains = append([]string(nil), b.netMap.DNS.CertDomains...) - s.MagicDNSSuffix = b.netMap.MagicDNSSuffix() + if nm != nil { + s.CertDomains = append([]string(nil), nm.DNS.CertDomains...) + s.MagicDNSSuffix = nm.MagicDNSSuffix() if s.CurrentTailnet == nil { s.CurrentTailnet = &ipnstate.TailnetStatus{} } - s.CurrentTailnet.MagicDNSSuffix = b.netMap.MagicDNSSuffix() - s.CurrentTailnet.MagicDNSEnabled = b.netMap.DNS.Proxied - s.CurrentTailnet.Name = b.netMap.Domain + s.CurrentTailnet.MagicDNSSuffix = nm.MagicDNSSuffix() + s.CurrentTailnet.MagicDNSEnabled = nm.DNS.Proxied + s.CurrentTailnet.Name = nm.Domain if prefs := b.pm.CurrentPrefs(); prefs.Valid() { - if !prefs.RouteAll() && b.netMap.AnyPeersAdvertiseRoutes() { + if !prefs.RouteAll() && nm.AnyPeersAdvertiseRoutes() { s.Health = append(s.Health, healthmsg.WarnAcceptRoutesOff) } if !prefs.ExitNodeID().IsZero() { - if exitPeer, ok := b.netMap.PeerWithStableID(prefs.ExitNodeID()); ok { - online := false - if v := exitPeer.Online(); v != nil { - online = *v - } + if exitPeer, ok := nm.PeerWithStableID(prefs.ExitNodeID()); ok { s.ExitNodeStatus = &ipnstate.ExitNodeStatus{ ID: prefs.ExitNodeID(), - Online: online, + Online: exitPeer.Online().Get(), TailscaleIPs: exitPeer.Addresses().AsSlice(), } } @@ -1051,8 +1283,8 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { }) var tailscaleIPs []netip.Addr - if b.netMap != nil { - addrs := b.netMap.GetAddresses() + if nm != nil { + addrs := nm.GetAddresses() for i := range addrs.Len() { if addr := addrs.At(i); addr.IsSingleIP() { sb.AddTailscaleIP(addr.Addr()) @@ -1064,24 +1296,23 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { sb.MutateSelfStatus(func(ss *ipnstate.PeerStatus) { ss.OS = version.OS() ss.Online = b.health.GetInPollNetMap() - if b.netMap != nil { + if nm != nil { ss.InNetworkMap = true - if hi := b.netMap.SelfNode.Hostinfo(); hi.Valid() { + if hi := nm.SelfNode.Hostinfo(); hi.Valid() { ss.HostName = hi.Hostname() } - ss.DNSName = b.netMap.Name - ss.UserID = b.netMap.User() - if sn := b.netMap.SelfNode; sn.Valid() { + ss.DNSName = nm.SelfName() + ss.UserID = nm.User() + if sn := nm.SelfNode; sn.Valid() { peerStatusFromNode(ss, sn) if cm := sn.CapMap(); cm.Len() > 0 { ss.Capabilities = make([]tailcfg.NodeCapability, 1, cm.Len()+1) ss.Capabilities[0] = "HTTPS://TAILSCALE.COM/s/DEPRECATED-NODE-CAPS#see-https://github.com/tailscale/tailscale/issues/11508" ss.CapMap = make(tailcfg.NodeCapMap, sn.CapMap().Len()) - cm.Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { + for k, v := range cm.All() { ss.CapMap[k] = v.AsSlice() ss.Capabilities = append(ss.Capabilities, k) - return true - }) + } slices.Sort(ss.Capabilities[1:]) } } @@ -1090,7 +1321,7 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } } else { - ss.HostName, _ = os.Hostname() + ss.HostName, _ = hostinfo.Hostname() } for _, pln := range b.peerAPIListeners { ss.PeerAPIURL = append(ss.PeerAPIURL, pln.urlStr) @@ -1105,18 +1336,16 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { - if b.netMap == nil { + cn := b.currentNode() + nm := cn.NetMap() + if nm == nil { return } - for id, up := range b.netMap.UserProfiles { + for id, up := range nm.UserProfiles { sb.AddUser(id, up) } exitNodeID := b.pm.CurrentPrefs().ExitNodeID() - for _, p := range b.peers { - var lastSeen time.Time - if p.LastSeen() != nil { - lastSeen = *p.LastSeen() - } + for _, p := range cn.Peers() { tailscaleIPs := make([]netip.Addr, 0, p.Addresses().Len()) for i := range p.Addresses().Len() { addr := p.Addresses().At(i) @@ -1124,7 +1353,6 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { tailscaleIPs = append(tailscaleIPs, addr.Addr()) } } - online := p.Online() ps := &ipnstate.PeerStatus{ InNetworkMap: true, UserID: p.User(), @@ -1133,20 +1361,22 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { HostName: p.Hostinfo().Hostname(), DNSName: p.Name(), OS: p.Hostinfo().OS(), - LastSeen: lastSeen, - Online: online != nil && *online, + LastSeen: p.LastSeen().Get(), + Online: p.Online().Get(), ShareeNode: p.Hostinfo().ShareeNode(), ExitNode: p.StableID() != "" && p.StableID() == exitNodeID, SSH_HostKeys: p.Hostinfo().SSH_HostKeys().AsSlice(), - Location: p.Hostinfo().Location(), + Location: p.Hostinfo().Location().AsStruct(), Capabilities: p.Capabilities().AsSlice(), } + for _, f := range b.extHost.Hooks().SetPeerStatus { + f(ps, p, cn) + } if cm := p.CapMap(); cm.Len() > 0 { ps.CapMap = make(tailcfg.NodeCapMap, cm.Len()) - cm.Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { + for k, v := range cm.All() { ps.CapMap[k] = v.AsSlice() - return true - }) + } } peerStatusFromNode(ps, p) @@ -1167,7 +1397,7 @@ func peerStatusFromNode(ps *ipnstate.PeerStatus, n tailcfg.NodeView) { ps.PublicKey = n.Key() ps.ID = n.StableID() ps.Created = n.Created() - ps.ExitNodeOption = tsaddr.ContainsExitRoutes(n.AllowedIPs()) + ps.ExitNodeOption = buildfeatures.HasUseExitNode && tsaddr.ContainsExitRoutes(n.AllowedIPs()) if n.Tags().Len() != 0 { v := n.Tags() ps.Tags = &v @@ -1190,20 +1420,25 @@ func peerStatusFromNode(ps *ipnstate.PeerStatus, n tailcfg.NodeView) { } } +func profileFromView(v tailcfg.UserProfileView) tailcfg.UserProfile { + if v.Valid() { + return tailcfg.UserProfile{ + ID: v.ID(), + LoginName: v.LoginName(), + DisplayName: v.DisplayName(), + ProfilePicURL: v.ProfilePicURL(), + } + } + return tailcfg.UserProfile{} +} + // WhoIsNodeKey returns the peer info of given public key, if it exists. func (b *LocalBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { - b.mu.Lock() - defer b.mu.Unlock() - // TODO(bradfitz): add nodeByKey like nodeByAddr instead of walking peers. - if b.netMap == nil { - return n, u, false - } - if self := b.netMap.SelfNode; self.Valid() && self.Key() == k { - return self, b.netMap.UserProfiles[self.User()], true - } - for _, n := range b.peers { - if n.Key() == k { - u, ok = b.netMap.UserProfiles[n.User()] + cn := b.currentNode() + if nid, ok := cn.NodeByKey(k); ok { + if n, ok := cn.NodeByID(nid); ok { + up, ok := cn.NetMap().UserProfiles[n.User()] + u = profileFromView(up) return n, u, ok } } @@ -1235,8 +1470,9 @@ func (b *LocalBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeVi return zero, u, false } - nid, ok := b.nodeByAddr[ipp.Addr()] - if !ok { + cn := b.currentNode() + nid, ok := cn.NodeByAddr(ipp.Addr()) + if !ok && buildfeatures.HasNetstack { var ip netip.Addr if ipp.Port() != 0 { var protos []string @@ -1257,72 +1493,69 @@ func (b *LocalBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeVi if !ok { return failf("no IP found in ProxyMapper for %v", ipp) } - nid, ok = b.nodeByAddr[ip] + nid, ok = cn.NodeByAddr(ip) if !ok { return failf("no node for proxymapped IP %v", ip) } } - if b.netMap == nil { + nm := cn.NetMap() + if nm == nil { return failf("no netmap") } - n, ok = b.peers[nid] + n, ok = cn.NodeByID(nid) if !ok { - // Check if this the self-node, which would not appear in peers. - if !b.netMap.SelfNode.Valid() || nid != b.netMap.SelfNode.ID() { - return zero, u, false - } - n = b.netMap.SelfNode + return zero, u, false } - u, ok = b.netMap.UserProfiles[n.User()] + up, ok := cn.UserByID(n.User()) if !ok { return failf("no userprofile for node %v", n.Key()) } - return n, u, true + return n, profileFromView(up), true } // PeerCaps returns the capabilities that remote src IP has to // ths current node. func (b *LocalBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { - b.mu.Lock() - defer b.mu.Unlock() - return b.peerCapsLocked(src) + return b.currentNode().PeerCaps(src) } -func (b *LocalBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { - if b.netMap == nil { - return nil - } - filt := b.filterAtomic.Load() - if filt == nil { - return nil - } - addrs := b.netMap.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) - if !a.IsSingleIP() { - continue - } - dst := a.Addr() - if dst.BitLen() == src.BitLen() { // match on family - return filt.CapsWithValues(src, dst) +func (b *LocalBackend) GetFilterForTest() *filter.Filter { + testenv.AssertInTest() + nb := b.currentNode() + return nb.filterAtomic.Load() +} + +func (b *LocalBackend) settleEventBus() { + // The move to eventbus made some things racy that + // weren't before so we have to wait for it to all be settled + // before we call certain things. + // See https://github.com/tailscale/tailscale/issues/16369 + // But we can't do this while holding b.mu without deadlocks, + // (https://github.com/tailscale/tailscale/pull/17804#issuecomment-3514426485) so + // now we just do it in lots of places before acquiring b.mu. + // Is this winning?? + if b.sys != nil { + if ms, ok := b.sys.MagicSock.GetOK(); ok { + ms.Synchronize() } } - return nil } // SetControlClientStatus is the callback invoked by the control client whenever it posts a new status. // Among other things, this is where we update the netmap, packet filters, DNS and DERP maps. func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st controlclient.Status) { - unlock := b.lockAndGetUnlock() - defer unlock() + if b.ignoreControlClientUpdates.Load() { + b.logf("ignoring SetControlClientStatus during controlclient shutdown") + return + } + b.mu.Lock() + defer b.mu.Unlock() if b.cc != c { b.logf("Ignoring SetControlClientStatus from old client") return } if st.Err != nil { - // The following do not depend on any data for which we need b locked. - unlock.UnlockEarly() if errors.Is(st.Err, io.EOF) { b.logf("[v1] Received error: EOF") return @@ -1331,7 +1564,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control var uerr controlclient.UserVisibleError if errors.As(st.Err, &uerr) { s := uerr.UserVisibleError() - b.send(ipn.Notify{ErrMessage: &s}) + b.sendLocked(ipn.Notify{ErrMessage: &s}) } return } @@ -1380,38 +1613,35 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control } wasBlocked := b.blocked + authWasInProgress := b.authURL != "" keyExpiryExtended := false if st.NetMap != nil { wasExpired := b.keyExpired - isExpired := !st.NetMap.Expiry.IsZero() && st.NetMap.Expiry.Before(b.clock.Now()) + isExpired := !st.NetMap.SelfKeyExpiry().IsZero() && st.NetMap.SelfKeyExpiry().Before(b.clock.Now()) if wasExpired && !isExpired { keyExpiryExtended = true } b.keyExpired = isExpired } - unlock.UnlockEarly() - if keyExpiryExtended && wasBlocked { // Key extended, unblock the engine - b.blockEngineUpdates(false) + b.blockEngineUpdatesLocked(false) } - if st.LoginFinished() && (wasBlocked || b.seamlessRenewalEnabled()) { + if st.LoggedIn && (wasBlocked || authWasInProgress) { if wasBlocked { // Auth completed, unblock the engine - b.blockEngineUpdates(false) + b.blockEngineUpdatesLocked(false) } - b.authReconfig() - b.send(ipn.Notify{LoginFinished: &empty.Message{}}) + b.authReconfigLocked() + b.sendLocked(ipn.Notify{LoginFinished: &empty.Message{}}) } - // Lock b again and do only the things that require locking. - b.mu.Lock() - prefsChanged := false + cn := b.currentNode() prefs := b.pm.CurrentPrefs().AsStruct() - oldNetMap := b.netMap + oldNetMap := cn.NetMap() curNetMap := st.NetMap if curNetMap == nil { // The status didn't include a netmap update, so the old one is still @@ -1425,7 +1655,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control // future "tailscale up" to start checking for // implicit setting reverts, which it doesn't do when // ControlURL is blank. - prefs.ControlURL = prefs.ControlURLOrDefault() + prefs.ControlURL = prefs.ControlURLOrDefault(b.polc) prefsChanged = true } if st.Persist.Valid() { @@ -1434,8 +1664,8 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control prefs.Persist = st.Persist.AsStruct() } } - if st.LoginFinished() { - if b.authURL != "" { + if st.LoggedIn { + if authWasInProgress { b.resetAuthURLLocked() // Interactive login finished successfully (URL visited). // After an interactive login, the user always wants @@ -1450,17 +1680,11 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control prefsChanged = true } } - if shouldAutoExitNode() { - // Re-evaluate exit node suggestion in case circumstances have changed. - _, err := b.suggestExitNodeLocked(curNetMap) - if err != nil && !errors.Is(err, ErrNoPreferredDERP) { - b.logf("SetControlClientStatus failed to select auto exit node: %v", err) - } - } - if setExitNodeID(prefs, curNetMap, b.lastSuggestedExitNode) { - prefsChanged = true - } - if applySysPolicy(prefs) { + // We primarily need this to apply syspolicy to the prefs if an implicit profile + // switch is about to happen. + // TODO(nickkhyl): remove this once we improve handling of implicit profile switching + // in tailscale/corp#28014 and we apply syspolicy when the switch actually happens. + if b.reconcilePrefsLocked(prefs) { prefsChanged = true } @@ -1470,16 +1694,25 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control prefsChanged = true } + // If the tailnet's display name has changed, update prefs. + if st.NetMap != nil && st.NetMap.TailnetDisplayName() != b.pm.CurrentProfile().NetworkProfile().DisplayName { + prefsChanged = true + } + // Perform all mutations of prefs based on the netmap here. if prefsChanged { // Prefs will be written out if stale; this is not safe unless locked or cloned. if err := b.pm.SetPrefs(prefs.View(), ipn.NetworkProfile{ MagicDNSName: curNetMap.MagicDNSSuffix(), DomainName: curNetMap.DomainName(), + DisplayName: curNetMap.TailnetDisplayName(), }); err != nil { b.logf("Failed to save new controlclient state: %v", err) } + + b.sendToLocked(ipn.Notify{Prefs: ptr.To(prefs.View())}, allClients) } + // initTKALocked is dependent on CurrentProfile.ID, which is initialized // (for new profiles) on the first call to b.pm.SetPrefs. if err := b.initTKALocked(); err != nil { @@ -1515,31 +1748,26 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.tkaFilterNetmapLocked(st.NetMap) } b.setNetMapLocked(st.NetMap) - b.updateFilterLocked(st.NetMap, prefs.View()) + b.updateFilterLocked(prefs.View()) } - b.mu.Unlock() // Now complete the lock-free parts of what we started while locked. - if prefsChanged { - b.send(ipn.Notify{Prefs: ptr.To(prefs.View())}) - } - if st.NetMap != nil { if envknob.NoLogsNoSupport() && st.NetMap.HasCap(tailcfg.CapabilityDataPlaneAuditLogs) { msg := "tailnet requires logging to be enabled. Remove --no-logs-no-support from tailscaled command line." b.health.SetLocalLogConfigHealth(errors.New(msg)) - // Connecting to this tailnet without logging is forbidden; boot us outta here. - b.mu.Lock() + // Get the current prefs again, since we unlocked above. + prefs := b.pm.CurrentPrefs().AsStruct() prefs.WantRunning = false p := prefs.View() if err := b.pm.SetPrefs(p, ipn.NetworkProfile{ MagicDNSName: st.NetMap.MagicDNSSuffix(), DomainName: st.NetMap.DomainName(), + DisplayName: st.NetMap.TailnetDisplayName(), }); err != nil { b.logf("Failed to save new controlclient state: %v", err) } - b.mu.Unlock() - b.send(ipn.Notify{ErrMessage: &msg, Prefs: &p}) + b.sendLocked(ipn.Notify{ErrMessage: &msg, Prefs: &p}) return } if oldNetMap != nil { @@ -1561,126 +1789,328 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control // Update the DERP map in the health package, which uses it for health notifications b.health.SetDERPMap(st.NetMap.DERPMap) - b.send(ipn.Notify{NetMap: st.NetMap}) + b.sendLocked(ipn.Notify{NetMap: st.NetMap}) + + // The error here is unimportant as is the result. This will recalculate the suggested exit node + // cache the value and push any changes to the IPN bus. + b.suggestExitNodeLocked() + + // Check and update the exit node if needed, now that we have a new netmap. + // + // This must happen after the netmap change is sent via [ipn.Notify], + // so the GUI can correctly display the exit node if it has changed + // since the last netmap was sent. + // + // Otherwise, it might briefly show the exit node as offline and display a warning, + // if the node wasn't online or wasn't advertising default routes in the previous netmap. + b.refreshExitNodeLocked() } if st.URL != "" { b.logf("Received auth URL: %.20v...", st.URL) - b.setAuthURL(st.URL) + b.setAuthURLLocked(st.URL) } - b.stateMachine() + b.stateMachineLocked() // This is currently (2020-07-28) necessary; conditionally disabling it is fragile! // This is where netmap information gets propagated to router and magicsock. - b.authReconfig() + b.authReconfigLocked() } type preferencePolicyInfo struct { - key syspolicy.Key + key pkey.Key get func(ipn.PrefsView) bool set func(*ipn.Prefs, bool) } var preferencePolicies = []preferencePolicyInfo{ { - key: syspolicy.EnableIncomingConnections, + key: pkey.EnableIncomingConnections, // Allow Incoming (used by the UI) is the negation of ShieldsUp (used by the // backend), so this has to convert between the two conventions. get: func(p ipn.PrefsView) bool { return !p.ShieldsUp() }, set: func(p *ipn.Prefs, v bool) { p.ShieldsUp = !v }, }, { - key: syspolicy.EnableServerMode, + key: pkey.EnableServerMode, get: func(p ipn.PrefsView) bool { return p.ForceDaemon() }, set: func(p *ipn.Prefs, v bool) { p.ForceDaemon = v }, }, { - key: syspolicy.ExitNodeAllowLANAccess, + key: pkey.ExitNodeAllowLANAccess, get: func(p ipn.PrefsView) bool { return p.ExitNodeAllowLANAccess() }, set: func(p *ipn.Prefs, v bool) { p.ExitNodeAllowLANAccess = v }, }, { - key: syspolicy.EnableTailscaleDNS, + key: pkey.EnableTailscaleDNS, get: func(p ipn.PrefsView) bool { return p.CorpDNS() }, set: func(p *ipn.Prefs, v bool) { p.CorpDNS = v }, }, { - key: syspolicy.EnableTailscaleSubnets, + key: pkey.EnableTailscaleSubnets, get: func(p ipn.PrefsView) bool { return p.RouteAll() }, set: func(p *ipn.Prefs, v bool) { p.RouteAll = v }, }, { - key: syspolicy.CheckUpdates, + key: pkey.CheckUpdates, get: func(p ipn.PrefsView) bool { return p.AutoUpdate().Check }, set: func(p *ipn.Prefs, v bool) { p.AutoUpdate.Check = v }, }, { - key: syspolicy.ApplyUpdates, + key: pkey.ApplyUpdates, get: func(p ipn.PrefsView) bool { v, _ := p.AutoUpdate().Apply.Get(); return v }, set: func(p *ipn.Prefs, v bool) { p.AutoUpdate.Apply.Set(v) }, }, { - key: syspolicy.EnableRunExitNode, + key: pkey.EnableRunExitNode, get: func(p ipn.PrefsView) bool { return p.AdvertisesExitNode() }, set: func(p *ipn.Prefs, v bool) { p.SetAdvertiseExitNode(v) }, }, } -// applySysPolicy overwrites configured preferences with policies that may be +// applySysPolicyLocked overwrites configured preferences with policies that may be // configured by the system administrator in an OS-specific way. -func applySysPolicy(prefs *ipn.Prefs) (anyChange bool) { - if controlURL, err := syspolicy.GetString(syspolicy.ControlURL, prefs.ControlURL); err == nil && prefs.ControlURL != controlURL { +// +// b.mu must be held. +func (b *LocalBackend) applySysPolicyLocked(prefs *ipn.Prefs) (anyChange bool) { + if !buildfeatures.HasSystemPolicy { + return false + } + if controlURL, err := b.polc.GetString(pkey.ControlURL, prefs.ControlURL); err == nil && prefs.ControlURL != controlURL { prefs.ControlURL = controlURL anyChange = true } - for _, opt := range preferencePolicies { - if po, err := syspolicy.GetPreferenceOption(opt.key); err == nil { - curVal := opt.get(prefs.View()) - newVal := po.ShouldEnable(curVal) - if curVal != newVal { - opt.set(prefs, newVal) - anyChange = true - } - } + const sentinel = "HostnameDefaultValue" + hostnameFromPolicy, _ := b.polc.GetString(pkey.Hostname, sentinel) + switch hostnameFromPolicy { + case sentinel: + // An empty string for this policy value means that the admin wants to delete + // the hostname stored in the ipn.Prefs. To make that work, we need to + // distinguish between an empty string and a policy that was not set. + // We cannot do that with the current implementation of syspolicy.GetString. + // It currently does not return an error if a policy was not configured. + // Instead, it returns the default value provided as the second argument. + // This behavior makes it impossible to distinguish between a policy that + // was not set and a policy that was set to an empty default value. + // Checking for sentinel here is a workaround to distinguish between + // the two cases. If we get it, we do nothing because the policy was not set. + // + // TODO(angott,nickkhyl): clean up this behavior once syspolicy.GetString starts + // properly returning errors. + case "": + // The policy was set to an empty string, which means the admin intends + // to clear the hostname stored in preferences. + prefs.Hostname = "" + anyChange = true + default: + // The policy was set to a non-empty string, which means the admin wants + // to override the hostname stored in preferences. + if prefs.Hostname != hostnameFromPolicy { + prefs.Hostname = hostnameFromPolicy + anyChange = true + } + } + + // Only apply the exit node policy if the user hasn't overridden it. + if !b.overrideExitNodePolicy && b.applyExitNodeSysPolicyLocked(prefs) { + anyChange = true + } + + if alwaysOn, _ := b.polc.GetBoolean(pkey.AlwaysOn, false); alwaysOn && !b.overrideAlwaysOn && !prefs.WantRunning { + prefs.WantRunning = true + anyChange = true + } + + for _, opt := range preferencePolicies { + if po, err := b.polc.GetPreferenceOption(opt.key, ptype.ShowChoiceByPolicy); err == nil { + curVal := opt.get(prefs.View()) + newVal := po.ShouldEnable(curVal) + if curVal != newVal { + opt.set(prefs, newVal) + anyChange = true + } + } } return anyChange } +// applyExitNodeSysPolicyLocked applies the exit node policy settings to prefs +// and reports whether any change was made. +// +// b.mu must be held. +func (b *LocalBackend) applyExitNodeSysPolicyLocked(prefs *ipn.Prefs) (anyChange bool) { + if !buildfeatures.HasUseExitNode { + return false + } + if exitNodeIDStr, _ := b.polc.GetString(pkey.ExitNodeID, ""); exitNodeIDStr != "" { + exitNodeID := tailcfg.StableNodeID(exitNodeIDStr) + + // Try to parse the policy setting value as an "auto:"-prefixed [ipn.ExitNodeExpression], + // and update prefs if it differs from the current one. + // This includes cases where it was previously an expression but no longer is, + // or where it wasn't before but now is. + autoExitNode, useAutoExitNode := ipn.ParseAutoExitNodeString(exitNodeID) + if prefs.AutoExitNode != autoExitNode { + prefs.AutoExitNode = autoExitNode + anyChange = true + } + // Additionally, if the specified exit node ID is an expression, + // meaning an exit node is required but we don't yet have a valid exit node ID, + // we should set exitNodeID to a value that is never a valid [tailcfg.StableNodeID], + // to install a blackhole route and prevent accidental non-exit-node usage + // until the expression is evaluated and an actual exit node is selected. + // We use "auto:any" for this purpose, primarily for compatibility with + // older clients (in case a user downgrades to an earlier version) + // and GUIs/CLIs that have special handling for it. + if useAutoExitNode { + exitNodeID = unresolvedExitNodeID + } + + // If the current exit node ID doesn't match the one enforced by the policy setting, + // and the policy either requires a specific exit node ID, + // or requires an auto exit node ID and the current one isn't allowed, + // then update the exit node ID. + if prefs.ExitNodeID != exitNodeID { + if !useAutoExitNode || !isAllowedAutoExitNodeID(b.polc, prefs.ExitNodeID) { + prefs.ExitNodeID = exitNodeID + anyChange = true + } + } + + // If the exit node IP is set, clear it. When ExitNodeIP is set in the prefs, + // it takes precedence over the ExitNodeID. + if prefs.ExitNodeIP.IsValid() { + prefs.ExitNodeIP = netip.Addr{} + anyChange = true + } + } else if exitNodeIPStr, _ := b.polc.GetString(pkey.ExitNodeIP, ""); exitNodeIPStr != "" { + if prefs.AutoExitNode != "" { + prefs.AutoExitNode = "" // mutually exclusive with ExitNodeIP + anyChange = true + } + if exitNodeIP, err := netip.ParseAddr(exitNodeIPStr); err == nil { + if prefs.ExitNodeID != "" || prefs.ExitNodeIP != exitNodeIP { + anyChange = true + } + prefs.ExitNodeID = "" + prefs.ExitNodeIP = exitNodeIP + } + } + + return anyChange +} + +// registerSysPolicyWatch subscribes to syspolicy change notifications +// and immediately applies the effective syspolicy settings to the current profile. +func (b *LocalBackend) registerSysPolicyWatch() (unregister func(), err error) { + if unregister, err = b.polc.RegisterChangeCallback(b.sysPolicyChanged); err != nil { + return nil, fmt.Errorf("syspolicy: LocalBacked failed to register policy change callback: %v", err) + } + if prefs, anyChange := b.reconcilePrefs(); anyChange { + b.logf("syspolicy: changed initial profile prefs: %v", prefs.Pretty()) + } + b.refreshAllowedSuggestions() + return unregister, nil +} + +// reconcilePrefs overwrites the current profile's preferences with policies +// that may be configured by the system administrator in an OS-specific way. +// +// b.mu must not be held. +func (b *LocalBackend) reconcilePrefs() (_ ipn.PrefsView, anyChange bool) { + b.mu.Lock() + defer b.mu.Unlock() + + prefs := b.pm.CurrentPrefs().AsStruct() + if !b.reconcilePrefsLocked(prefs) { + return prefs.View(), false + } + return b.setPrefsLocked(prefs), true +} + +// sysPolicyChanged is a callback triggered by syspolicy when it detects +// a change in one or more syspolicy settings. +func (b *LocalBackend) sysPolicyChanged(policy policyclient.PolicyChange) { + if policy.HasChangedAnyOf(pkey.AlwaysOn, pkey.AlwaysOnOverrideWithReason) { + // If the AlwaysOn or the AlwaysOnOverrideWithReason policy has changed, + // we should reset the overrideAlwaysOn flag, as the override might + // no longer be valid. + b.mu.Lock() + b.overrideAlwaysOn = false + b.mu.Unlock() + } + + if policy.HasChangedAnyOf(pkey.ExitNodeID, pkey.ExitNodeIP, pkey.AllowExitNodeOverride) { + // Reset the exit node override if a policy that enforces exit node usage + // or allows the user to override automatic exit node selection has changed. + b.mu.Lock() + b.overrideExitNodePolicy = false + b.mu.Unlock() + } + + if buildfeatures.HasUseExitNode && policy.HasChanged(pkey.AllowedSuggestedExitNodes) { + b.refreshAllowedSuggestions() + // Re-evaluate exit node suggestion now that the policy setting has changed. + if _, err := b.SuggestExitNode(); err != nil && !errors.Is(err, ErrNoPreferredDERP) { + b.logf("failed to select auto exit node: %v", err) + } + // If [pkey.ExitNodeID] is set to `auto:any`, the suggested exit node ID + // will be used when [applySysPolicy] updates the current profile's prefs. + } + + if prefs, anyChange := b.reconcilePrefs(); anyChange { + b.logf("syspolicy: changed profile prefs: %v", prefs.Pretty()) + } +} + var _ controlclient.NetmapDeltaUpdater = (*LocalBackend)(nil) // UpdateNetmapDelta implements controlclient.NetmapDeltaUpdater. func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { - if !b.MagicConn().UpdateNetmapDelta(muts) { - return false - } - var notify *ipn.Notify // non-nil if we need to send a Notify defer func() { if notify != nil { b.send(*notify) } }() - unlock := b.lockAndGetUnlock() - defer unlock() - if !b.updateNetmapDeltaLocked(muts) { - return false - } - if b.netMap != nil && mutationsAreWorthyOfTellingIPNBus(muts) { - nm := ptr.To(*b.netMap) // shallow clone - nm.Peers = make([]tailcfg.NodeView, 0, len(b.peers)) - shouldAutoExitNode := shouldAutoExitNode() - for _, p := range b.peers { - nm.Peers = append(nm.Peers, p) - // If the auto exit node currently set goes offline, find another auto exit node. - if shouldAutoExitNode && b.pm.prefs.ExitNodeID() == p.StableID() && p.Online() != nil && !*p.Online() { - b.setAutoExitNodeIDLockedOnEntry(unlock) - return false + // Gross. See https://github.com/tailscale/tailscale/issues/16369 + b.settleEventBus() + defer b.settleEventBus() + + b.mu.Lock() + defer b.mu.Unlock() + + cn := b.currentNode() + cn.UpdateNetmapDelta(muts) + + // If auto exit nodes are enabled and our exit node went offline, + // we need to schedule picking a new one. + // TODO(nickkhyl): move the auto exit node logic to a feature package. + if prefs := b.pm.CurrentPrefs(); prefs.AutoExitNode().IsSet() { + exitNodeID := prefs.ExitNodeID() + for _, m := range muts { + mo, ok := m.(netmap.NodeMutationOnline) + if !ok || mo.Online { + continue + } + n, ok := cn.NodeByID(m.NodeIDBeingMutated()) + if !ok || n.StableID() != exitNodeID { + continue } + b.refreshExitNodeLocked() + break } - slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { - return cmp.Compare(a.ID(), b.ID()) - }) + } + + if cn.NetMap() != nil && mutationsAreWorthyOfRecalculatingSuggestedExitNode(muts, cn, b.lastSuggestedExitNode) { + // Recompute the suggested exit node + b.suggestExitNodeLocked() + } + + if cn.NetMap() != nil && mutationsAreWorthyOfTellingIPNBus(muts) { + + nm := cn.netMapWithPeers() notify = &ipn.Notify{NetMap: nm} } else if testenv.InTest() { // In tests, send an empty Notify as a wake-up so end-to-end @@ -1691,6 +2121,44 @@ func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bo return true } +// mustationsAreWorthyOfRecalculatingSuggestedExitNode reports whether any mutation type in muts is +// worthy of recalculating the suggested exit node. +func mutationsAreWorthyOfRecalculatingSuggestedExitNode(muts []netmap.NodeMutation, cn *nodeBackend, sid tailcfg.StableNodeID) bool { + if !buildfeatures.HasUseExitNode { + return false + } + for _, m := range muts { + n, ok := cn.NodeByID(m.NodeIDBeingMutated()) + if !ok { + // The node being mutated is not in the netmap. + continue + } + + // The previously suggested exit node itself is being mutated. + if sid != "" && n.StableID() == sid { + return true + } + + allowed := n.AllowedIPs().AsSlice() + isExitNode := slices.Contains(allowed, tsaddr.AllIPv4()) || slices.Contains(allowed, tsaddr.AllIPv6()) + // The node being mutated is not an exit node. We don't care about it - unless + // it was our previously suggested exit node which we catch above. + if !isExitNode { + continue + } + + // Some exit node is being mutated. We care about it if it's online + // or offline state has changed. We *might* eventually care about it for other reasons + // but for the sake of finding a "better" suggested exit node, this is probably + // sufficient. + switch m.(type) { + case netmap.NodeMutationOnline: + return true + } + } + return false +} + // mutationsAreWorthyOfTellingIPNBus reports whether any mutation type in muts is // worthy of spamming the IPN bus (the Windows & Mac GUIs, basically) to tell them // about the update. @@ -1707,68 +2175,62 @@ func mutationsAreWorthyOfTellingIPNBus(muts []netmap.NodeMutation) bool { return false } -func (b *LocalBackend) updateNetmapDeltaLocked(muts []netmap.NodeMutation) (handled bool) { - if b.netMap == nil || len(b.peers) == 0 { +// resolveAutoExitNodeLocked computes a suggested exit node and updates prefs +// to use it if AutoExitNode is enabled, and reports whether prefs was mutated. +// +// b.mu must be held. +func (b *LocalBackend) resolveAutoExitNodeLocked(prefs *ipn.Prefs) (prefsChanged bool) { + if !buildfeatures.HasUseExitNode { return false } - - // Locally cloned mutable nodes, to avoid calling AsStruct (clone) - // multiple times on a node if it's mutated multiple times in this - // call (e.g. its endpoints + online status both change) - var mutableNodes map[tailcfg.NodeID]*tailcfg.Node - - for _, m := range muts { - n, ok := mutableNodes[m.NodeIDBeingMutated()] - if !ok { - nv, ok := b.peers[m.NodeIDBeingMutated()] - if !ok { - // TODO(bradfitz): unexpected metric? - return false - } - n = nv.AsStruct() - mak.Set(&mutableNodes, nv.ID(), n) - } - m.Apply(n) + // As of 2025-07-08, the only supported auto exit node expression is [ipn.AnyExitNode]. + // + // However, to maintain forward compatibility with future auto exit node expressions, + // we treat any non-empty AutoExitNode as [ipn.AnyExitNode]. + // + // If and when we support additional auto exit node expressions, this method should be updated + // to handle them appropriately, while still falling back to [ipn.AnyExitNode] or a more appropriate + // default for unknown (or partially supported) expressions. + if !prefs.AutoExitNode.IsSet() { + return false } - for nid, n := range mutableNodes { - b.peers[nid] = n.View() + if _, err := b.suggestExitNodeLocked(); err != nil && !errors.Is(err, ErrNoPreferredDERP) { + b.logf("failed to select auto exit node: %v", err) // non-fatal, see below + } + var newExitNodeID tailcfg.StableNodeID + if !b.lastSuggestedExitNode.IsZero() { + // If we have a suggested exit node, use it. + newExitNodeID = b.lastSuggestedExitNode + } else if isAllowedAutoExitNodeID(b.polc, prefs.ExitNodeID) { + // If we don't have a suggested exit node, but the prefs already + // specify an allowed auto exit node ID, retain it. + newExitNodeID = prefs.ExitNodeID + } else { + // Otherwise, use [unresolvedExitNodeID] to install a blackhole route, + // preventing traffic from leaking to the local network until an actual + // exit node is selected. + newExitNodeID = unresolvedExitNodeID } - return true -} - -// setExitNodeID updates prefs to reference an exit node by ID, rather -// than by IP. It returns whether prefs was mutated. -func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNode tailcfg.StableNodeID) (prefsChanged bool) { - if exitNodeIDStr, _ := syspolicy.GetString(syspolicy.ExitNodeID, ""); exitNodeIDStr != "" { - exitNodeID := tailcfg.StableNodeID(exitNodeIDStr) - if shouldAutoExitNode() && lastSuggestedExitNode != "" { - exitNodeID = lastSuggestedExitNode - } - // Note: when exitNodeIDStr == "auto" && lastSuggestedExitNode == "", then exitNodeID is now "auto" which will never match a peer's node ID. - // When there is no a peer matching the node ID, traffic will blackhole, preventing accidental non-exit-node usage when a policy is in effect that requires an exit node. - changed := prefs.ExitNodeID != exitNodeID || prefs.ExitNodeIP.IsValid() - prefs.ExitNodeID = exitNodeID - prefs.ExitNodeIP = netip.Addr{} - return changed + if prefs.ExitNodeID != newExitNodeID { + prefs.ExitNodeID = newExitNodeID + prefsChanged = true } - - oldExitNodeID := prefs.ExitNodeID - if exitNodeIPStr, _ := syspolicy.GetString(syspolicy.ExitNodeIP, ""); exitNodeIPStr != "" { - exitNodeIP, err := netip.ParseAddr(exitNodeIPStr) - if exitNodeIP.IsValid() && err == nil { - prefsChanged = prefs.ExitNodeID != "" || prefs.ExitNodeIP != exitNodeIP - prefs.ExitNodeID = "" - prefs.ExitNodeIP = exitNodeIP - } + if prefs.ExitNodeIP.IsValid() { + prefs.ExitNodeIP = netip.Addr{} + prefsChanged = true } + return prefsChanged +} - if nm == nil { - // No netmap, can't resolve anything. +// resolveExitNodeIPLocked updates prefs to reference an exit node by ID, rather +// than by IP. It returns whether prefs was mutated. +// +// b.mu must be held. +func (b *LocalBackend) resolveExitNodeIPLocked(prefs *ipn.Prefs) (prefsChanged bool) { + if !buildfeatures.HasUseExitNode { return false } - - // If we have a desired IP on file, try to find the corresponding - // node. + // If we have a desired IP on file, try to find the corresponding node. if !prefs.ExitNodeIP.IsValid() { return false } @@ -1779,20 +2241,19 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod prefsChanged = true } - for _, peer := range nm.Peers { - for i := range peer.Addresses().Len() { - addr := peer.Addresses().At(i) - if !addr.IsSingleIP() || addr.Addr() != prefs.ExitNodeIP { - continue - } + cn := b.currentNode() + if nid, ok := cn.NodeByAddr(prefs.ExitNodeIP); ok { + if node, ok := cn.NodeByID(nid); ok { // Found the node being referenced, upgrade prefs to // reference it directly for next time. - prefs.ExitNodeID = peer.StableID() + prefs.ExitNodeID = node.StableID() prefs.ExitNodeIP = netip.Addr{} - return oldExitNodeID != prefs.ExitNodeID + // Cleared ExitNodeIP, so prefs changed + // even if the ID stayed the same. + prefsChanged = true + } } - return prefsChanged } @@ -1801,62 +2262,60 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod func (b *LocalBackend) setWgengineStatus(s *wgengine.Status, err error) { if err != nil { b.logf("wgengine status error: %v", err) - b.broadcastStatusChanged() return } if s == nil { b.logf("[unexpected] non-error wgengine update with status=nil: %v", s) - b.broadcastStatusChanged() return } b.mu.Lock() + defer b.mu.Unlock() + + // For now, only check this in the callback, but don't check it in setWgengineStatusLocked if s.AsOf.Before(b.lastStatusTime) { // Don't process a status update that is older than the one we have // already processed. (corp#2579) - b.mu.Unlock() return } b.lastStatusTime = s.AsOf + + b.setWgengineStatusLocked(s) +} + +// setWgengineStatusLocked updates LocalBackend's view of the engine status and +// updates the endpoints both in the backend and in the control client. +// +// Unlike setWgengineStatus it does not discard out-of-order updates, so +// statuses sent here are always processed. This is useful for ensuring we don't +// miss a "we shut down" status during backend shutdown even if other statuses +// arrive out of order. +// +// TODO(zofrex): we should ensure updates actually do arrive in order and move +// the out-of-order check into this function. +// +// b.mu must be held. +func (b *LocalBackend) setWgengineStatusLocked(s *wgengine.Status) { es := b.parseWgStatusLocked(s) cc := b.cc + + // TODO(zofrex): the only reason we even write this is to transition from + // "Starting" to "Running" in the call to state machine a few lines below + // this. Maybe we don't even need to store it at all. b.engineStatus = es - needUpdateEndpoints := !endpointsEqual(s.LocalAddrs, b.endpoints) + + needUpdateEndpoints := !slices.Equal(s.LocalAddrs, b.endpoints) if needUpdateEndpoints { b.endpoints = append([]tailcfg.Endpoint{}, s.LocalAddrs...) } - b.mu.Unlock() if cc != nil { if needUpdateEndpoints { cc.UpdateEndpoints(s.LocalAddrs) } - b.stateMachine() - } - b.broadcastStatusChanged() - b.send(ipn.Notify{Engine: &es}) -} - -func (b *LocalBackend) broadcastStatusChanged() { - // The sync.Cond docs say: "It is allowed but not required for the caller to hold c.L during the call." - // In this particular case, we must acquire b.statusLock. Otherwise we might broadcast before - // the waiter (in requestEngineStatusAndWait) starts to wait, in which case - // the waiter can get stuck indefinitely. See PR 2865. - b.statusLock.Lock() - b.statusChanged.Broadcast() - b.statusLock.Unlock() -} - -func endpointsEqual(x, y []tailcfg.Endpoint) bool { - if len(x) != len(y) { - return false - } - for i := range x { - if x[i] != y[i] { - return false - } + b.stateMachineLocked() } - return true + b.sendLocked(ipn.Notify{Engine: &es}) } // SetNotifyCallback sets the function to call when the backend has something to @@ -1901,33 +2360,11 @@ func (b *LocalBackend) SetControlClientGetterForTesting(newControlClient func(co b.ccGen = newControlClient } -// NodeViewByIDForTest returns the state of the node with the given ID -// for integration tests in another repo. -func (b *LocalBackend) NodeViewByIDForTest(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { - b.mu.Lock() - defer b.mu.Unlock() - n, ok := b.peers[id] - return n, ok -} - -// DisablePortMapperForTest disables the portmapper for tests. -// It must be called before Start. -func (b *LocalBackend) DisablePortMapperForTest() { - b.mu.Lock() - defer b.mu.Unlock() - b.portpoll = nil -} - // PeersForTest returns all the current peers, sorted by Node.ID, // for integration tests in another repo. func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { - b.mu.Lock() - defer b.mu.Unlock() - ret := xmaps.Values(b.peers) - slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { - return cmp.Compare(a.ID(), b.ID()) - }) - return ret + testenv.AssertInTest() + return b.currentNode().PeersForTest() } func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { @@ -1942,6 +2379,11 @@ func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { return b.ccGen } +// initOnce is called on the first call to [LocalBackend.Start]. +func (b *LocalBackend) initOnce() { + b.extHost.Init() +} + // Start applies the configuration specified in opts, and starts the // state machine. // @@ -1953,7 +2395,16 @@ func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { // actually a supported operation (it should be, but it's very unclear // from the following whether or not that is a safe transition). func (b *LocalBackend) Start(opts ipn.Options) error { + defer b.settleEventBus() // with b.mu unlocked + b.mu.Lock() + defer b.mu.Unlock() + return b.startLocked(opts) +} + +func (b *LocalBackend) startLocked(opts ipn.Options) error { b.logf("Start") + logf := logger.WithPrefix(b.logf, "Start: ") + b.startOnce.Do(b.initOnce) var clientToShutdown controlclient.Client defer func() { @@ -1961,8 +2412,6 @@ func (b *LocalBackend) Start(opts ipn.Options) error { clientToShutdown.Shutdown() } }() - unlock := b.lockAndGetUnlock() - defer unlock() if opts.UpdatePrefs != nil { if err := b.checkPrefsLocked(opts.UpdatePrefs); err != nil { @@ -1982,9 +2431,9 @@ func (b *LocalBackend) Start(opts ipn.Options) error { } if b.state != ipn.Running && b.conf == nil && opts.AuthKey == "" { - sysak, _ := syspolicy.GetString(syspolicy.AuthKey, "") + sysak, _ := b.polc.GetString(pkey.AuthKey, "") if sysak != "" { - b.logf("Start: setting opts.AuthKey by syspolicy, len=%v", len(sysak)) + logf("setting opts.AuthKey by syspolicy, len=%v", len(sysak)) opts.AuthKey = strings.TrimSpace(sysak) } } @@ -1996,6 +2445,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { hostinfo.Userspace.Set(b.sys.IsNetstack()) hostinfo.UserspaceRouter.Set(b.sys.IsNetstackRouter()) hostinfo.AppConnector.Set(b.appConnector != nil) + hostinfo.StateEncrypted = b.stateEncrypted() b.logf.JSON(1, "Hostinfo", hostinfo) // TODO(apenwarr): avoid the need to reinit controlclient. @@ -2012,25 +2462,55 @@ func (b *LocalBackend) Start(opts ipn.Options) error { hostinfo.Services = b.hostinfo.Services // keep any previous services } b.hostinfo = hostinfo - b.state = ipn.NoState + b.setStateLocked(ipn.NoState) + cn := b.currentNode() + + var prefsChanged bool + var prefsChangedWhy []string + newPrefs := b.pm.CurrentPrefs().AsStruct() if opts.UpdatePrefs != nil { - oldPrefs := b.pm.CurrentPrefs() - newPrefs := opts.UpdatePrefs.Clone() - newPrefs.Persist = oldPrefs.Persist().AsStruct() - pv := newPrefs.View() - if err := b.pm.SetPrefs(pv, ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { - b.logf("failed to save UpdatePrefs state: %v", err) + newPrefs = opts.UpdatePrefs.Clone() + prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "opts.UpdatePrefs") + } + // Apply any syspolicy overrides, resolve exit node ID, etc. + // As of 2025-07-03, this is primarily needed in two cases: + // - when opts.UpdatePrefs is not nil + // - when Always Mode is enabled and we need to set WantRunning to true + if b.reconcilePrefsLocked(newPrefs) { + prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "reconcilePrefsLocked") + } + + // neither UpdatePrefs or reconciliation should change Persist + newPrefs.Persist = b.pm.CurrentPrefs().Persist().AsStruct() + + if buildfeatures.HasTPM { + if genKey, ok := feature.HookGenerateAttestationKeyIfEmpty.GetOk(); ok { + newKey, err := genKey(newPrefs.Persist, logf) + if err != nil { + logf("failed to populate attestation key from TPM: %v", err) + } + if newKey { + prefsChanged = true + prefsChangedWhy = append(prefsChangedWhy, "newKey") + } } - b.setAtomicValuesFromPrefsLocked(pv) - } else { - b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) } - prefs := b.pm.CurrentPrefs() + if prefsChanged { + logf("updated prefs: %v, reason: %v", newPrefs.Pretty(), prefsChangedWhy) + if err := b.pm.SetPrefs(newPrefs.View(), cn.NetworkProfile()); err != nil { + logf("failed to save updated and reconciled prefs (but still using updated prefs in memory): %v", err) + } + } + prefs := newPrefs.View() + + // Reset the always-on override whenever Start is called. + b.resetAlwaysOnOverrideLocked() + b.setAtomicValuesFromPrefsLocked(prefs) + wantRunning := prefs.WantRunning() if wantRunning { if err := b.initMachineKeyLocked(); err != nil { @@ -2040,60 +2520,64 @@ func (b *LocalBackend) Start(opts ipn.Options) error { loggedOut := prefs.LoggedOut() - serverURL := prefs.ControlURLOrDefault() + serverURL := prefs.ControlURLOrDefault(b.polc) if inServerMode := prefs.ForceDaemon(); inServerMode || runtime.GOOS == "windows" { - b.logf("Start: serverMode=%v", inServerMode) + logf("serverMode=%v", inServerMode) } b.applyPrefsToHostinfoLocked(hostinfo, prefs) + b.updateWarnSync(prefs) - b.setNetMapLocked(nil) persistv := prefs.Persist().AsStruct() if persistv == nil { persistv = new(persist.Persist) } - b.updateFilterLocked(nil, ipn.PrefsView{}) - - if b.portpoll != nil { - b.portpollOnce.Do(func() { - go b.readPoller() - }) - } discoPublic := b.MagicConn().DiscoPublicKey() - var err error - isNetstack := b.sys.IsNetstackRouter() debugFlags := controlDebugFlags if isNetstack { debugFlags = append([]string{"netstack"}, debugFlags...) } + var ccShutdownCbs []func() + ccShutdown := func() { + for _, cb := range ccShutdownCbs { + cb() + } + } + + var c2nHandler http.Handler + if buildfeatures.HasC2N { + c2nHandler = http.HandlerFunc(b.handleC2N) + } + // TODO(apenwarr): The only way to change the ServerURL is to // re-run b.Start, because this is the only place we create a // new controlclient. EditPrefs allows you to overwrite ServerURL, // but it won't take effect until the next Start. cc, err := b.getNewControlClientFuncLocked()(controlclient.Options{ - GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(), - Logf: logger.WithPrefix(b.logf, "control: "), - Persist: *persistv, - ServerURL: serverURL, - AuthKey: opts.AuthKey, - Hostinfo: hostinfo, - HTTPTestClient: httpTestClient, - DiscoPublicKey: discoPublic, - DebugFlags: debugFlags, - HealthTracker: b.health, - Pinger: b, - PopBrowserURL: b.tellClientToBrowseToURL, - OnClientVersion: b.onClientVersion, - OnTailnetDefaultAutoUpdate: b.onTailnetDefaultAutoUpdate, - OnControlTime: b.em.onControlTime, - Dialer: b.Dialer(), - Observer: b, - C2NHandler: http.HandlerFunc(b.handleC2N), - DialPlan: &b.dialPlan, // pointer because it can't be copied - ControlKnobs: b.sys.ControlKnobs(), + GetMachinePrivateKey: b.createGetMachinePrivateKeyFunc(), + Logf: logger.WithPrefix(b.logf, "control: "), + Persist: *persistv, + ServerURL: serverURL, + AuthKey: opts.AuthKey, + Hostinfo: hostinfo, + HTTPTestClient: httpTestClient, + DiscoPublicKey: discoPublic, + DebugFlags: debugFlags, + HealthTracker: b.health, + PolicyClient: b.sys.PolicyClientOrDefault(), + Pinger: b, + PopBrowserURL: b.tellClientToBrowseToURL, + Dialer: b.Dialer(), + Observer: b, + C2NHandler: c2nHandler, + DialPlan: &b.dialPlan, // pointer because it can't be copied + ControlKnobs: b.sys.ControlKnobs(), + Shutdown: ccShutdown, + Bus: b.sys.Bus.Get(), + StartPaused: prefs.Sync().EqualBool(false), // Don't warn about broken Linux IP forwarding when // netstack is being used. @@ -2102,12 +2586,13 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if err != nil { return err } + ccShutdownCbs = b.extHost.NotifyNewControlClient(cc, b.pm.CurrentProfile()) b.setControlClientLocked(cc) endpoints := b.endpoints if err := b.initTKALocked(); err != nil { - b.logf("initTKALocked: %v", err) + logf("initTKALocked: %v", err) } var tkaHead string if b.tka != nil { @@ -2126,10 +2611,17 @@ func (b *LocalBackend) Start(opts ipn.Options) error { blid := b.backendLogID.String() b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) - b.sendLocked(ipn.Notify{ - BackendLogID: &blid, - Prefs: &prefs, - }) + b.sendToLocked(ipn.Notify{Prefs: &prefs}, allClients) + + // initialize Taildrive shares from saved state + if fs, ok := b.sys.DriveForRemote.GetOK(); ok { + currentShares := b.pm.CurrentPrefs().DriveShares() + var shares []*drive.Share + for _, share := range currentShares.All() { + shares = append(shares, share.AsStruct()) + } + fs.SetShares(shares) + } if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) { // If we know that we're either logged in or meant to be @@ -2141,7 +2633,30 @@ func (b *LocalBackend) Start(opts ipn.Options) error { // regress tsnet.Server restarts. cc.Login(controlclient.LoginDefault) } - b.stateMachineLockedOnEntry(unlock) + b.stateMachineLocked() + + return nil +} + +// addServiceIPs adds the IP addresses of any VIP Services sent from the +// coordination server to the list of addresses that we expect to handle. +func addServiceIPs(localNetsB *netipx.IPSetBuilder, selfNode tailcfg.NodeView) error { + if !selfNode.Valid() { + return nil + } + + serviceMap, err := tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceIPMappings](selfNode.CapMap(), tailcfg.NodeAttrServiceHost) + if err != nil { + return err + } + + for _, sm := range serviceMap { // typically there will be exactly one of these + for _, serviceAddrs := range sm { + for _, addr := range serviceAddrs { // typically there will be exactly two of these + localNetsB.Add(addr) + } + } + } return nil } @@ -2154,17 +2669,58 @@ var invalidPacketFilterWarnable = health.Register(&health.Warnable{ Text: health.StaticMessage("The coordination server sent an invalid packet filter permitting traffic to unlocked nodes; rejecting all packets for safety"), }) +// filterInputs holds the inputs to the packet filter. +// +// Any field changes or additions here should be accompanied by a change to +// [filterInputs.Equal] and [filterInputs.Clone] if necessary. (e.g. non-view +// and non-value fields) +type filterInputs struct { + HaveNetmap bool + Addrs views.Slice[netip.Prefix] + FilterMatch views.Slice[filter.Match] + LocalNets views.Slice[netipx.IPRange] + LogNets views.Slice[netipx.IPRange] + ShieldsUp bool + SSHPolicy tailcfg.SSHPolicyView +} + +func (fi *filterInputs) Equal(o *filterInputs) bool { + if fi == nil || o == nil { + return fi == o + } + return reflect.DeepEqual(fi, o) +} + +func (fi *filterInputs) Clone() *filterInputs { + if fi == nil { + return nil + } + v := *fi // all fields are shallow copyable + return &v +} + // updateFilterLocked updates the packet filter in wgengine based on the // given netMap and user preferences. // // b.mu must be held. -func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.PrefsView) { +func (b *LocalBackend) updateFilterLocked(prefs ipn.PrefsView) { + // TODO(nickkhyl) split this into two functions: + // - (*nodeBackend).RebuildFilters() (normalFilter, jailedFilter *filter.Filter, changed bool), + // which would return packet filters for the current state and whether they changed since the last call. + // - (*LocalBackend).updateFilters(), which would use the above to update the engine with the new filters, + // notify b.sshServer, etc. + // + // For this, we would need to plumb a few more things into the [nodeBackend]. Most importantly, + // the current [ipn.PrefsView]), but also maybe also a b.logf and a b.health? + // // NOTE(danderson): keep change detection as the first thing in // this function. Don't try to optimize by returning early, more // likely than not you'll just end up breaking the change // detection and end up with the wrong filter installed. This is // quite hard to debug, so save yourself the trouble. var ( + cn = b.currentNode() + netMap = cn.NetMap() haveNetmap = netMap != nil addrs views.Slice[netip.Prefix] packetFilter []filter.Match @@ -2183,39 +2739,45 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P } packetFilter = netMap.PacketFilter - if packetFilterPermitsUnlockedNodes(b.peers, packetFilter) { + if cn.unlockedNodesPermitted(packetFilter) { b.health.SetUnhealthy(invalidPacketFilterWarnable, nil) packetFilter = nil } else { b.health.SetHealthy(invalidPacketFilterWarnable) } + + if err := addServiceIPs(&localNetsB, netMap.SelfNode); err != nil { + b.logf("addServiceIPs: %v", err) + } } if prefs.Valid() { - for _, r := range prefs.AdvertiseRoutes().All() { - if r.Bits() == 0 { - // When offering a default route to the world, we - // filter out locally reachable LANs, so that the - // default route effectively appears to be a "guest - // wifi": you get internet access, but to additionally - // get LAN access the LAN(s) need to be offered - // explicitly as well. - localInterfaceRoutes, hostIPs, err := interfaceRoutes() - if err != nil { - b.logf("getting local interface routes: %v", err) - continue - } - s, err := shrinkDefaultRoute(r, localInterfaceRoutes, hostIPs) - if err != nil { - b.logf("computing default route filter: %v", err) - continue + if buildfeatures.HasAdvertiseRoutes { + for _, r := range prefs.AdvertiseRoutes().All() { + if r.Bits() == 0 { + // When offering a default route to the world, we + // filter out locally reachable LANs, so that the + // default route effectively appears to be a "guest + // wifi": you get internet access, but to additionally + // get LAN access the LAN(s) need to be offered + // explicitly as well. + localInterfaceRoutes, hostIPs, err := interfaceRoutes() + if err != nil { + b.logf("getting local interface routes: %v", err) + continue + } + s, err := shrinkDefaultRoute(r, localInterfaceRoutes, hostIPs) + if err != nil { + b.logf("computing default route filter: %v", err) + continue + } + localNetsB.AddSet(s) + } else { + localNetsB.AddPrefix(r) + // When advertising a non-default route, we assume + // this is a corporate subnet that should be present + // in the audit logs. + logNetsB.AddPrefix(r) } - localNetsB.AddSet(s) - } else { - localNetsB.AddPrefix(r) - // When advertising a non-default route, we assume - // this is a corporate subnet that should be present - // in the audit logs. - logNetsB.AddPrefix(r) } } @@ -2226,27 +2788,27 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P // The correct filter rules are synthesized by the coordination server // and sent down, but the address needs to be part of the 'local net' for the // filter package to even bother checking the filter rules, so we set them here. - if prefs.AppConnector().Advertise { + if buildfeatures.HasAppConnectors && prefs.AppConnector().Advertise { localNetsB.Add(netip.MustParseAddr("0.0.0.0")) localNetsB.Add(netip.MustParseAddr("::0")) } } localNets, _ := localNetsB.IPSet() logNets, _ := logNetsB.IPSet() - var sshPol tailcfg.SSHPolicy - if haveNetmap && netMap.SSHPolicy != nil { - sshPol = *netMap.SSHPolicy - } - - changed := deephash.Update(&b.filterHash, &struct { - HaveNetmap bool - Addrs views.Slice[netip.Prefix] - FilterMatch []filter.Match - LocalNets []netipx.IPRange - LogNets []netipx.IPRange - ShieldsUp bool - SSHPolicy tailcfg.SSHPolicy - }{haveNetmap, addrs, packetFilter, localNets.Ranges(), logNets.Ranges(), shieldsUp, sshPol}) + var sshPol tailcfg.SSHPolicyView + if buildfeatures.HasSSH && haveNetmap && netMap.SSHPolicy != nil { + sshPol = netMap.SSHPolicy.View() + } + + changed := checkchange.Update(&b.lastFilterInputs, &filterInputs{ + HaveNetmap: haveNetmap, + Addrs: addrs, + FilterMatch: views.SliceOf(packetFilter), + LocalNets: views.SliceOf(localNets.Ranges()), + LogNets: views.SliceOf(logNets.Ranges()), + ShieldsUp: shieldsUp, + SSHPolicy: sshPol, + }) if !changed { return } @@ -2272,126 +2834,10 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P b.e.SetJailedFilter(filter.NewShieldsUpFilter(localNets, logNets, oldJailedFilter, b.logf)) if b.sshServer != nil { - go b.sshServer.OnPolicyChange() - } -} - -// captivePortalWarnable is a Warnable which is set to an unhealthy state when a captive portal is detected. -var captivePortalWarnable = health.Register(&health.Warnable{ - Code: "captive-portal-detected", - Title: "Captive portal detected", - // High severity, because captive portals block all traffic and require user intervention. - Severity: health.SeverityHigh, - Text: health.StaticMessage("This network requires you to log in using your web browser."), - ImpactsConnectivity: true, -}) - -func (b *LocalBackend) checkCaptivePortalLoop(ctx context.Context) { - var tmr *time.Timer - - maybeStartTimer := func() { - // If there's an existing timer, nothing to do; just continue - // waiting for it to expire. Otherwise, create a new timer. - if tmr == nil { - tmr = time.NewTimer(captivePortalDetectionInterval) - } - } - maybeStopTimer := func() { - if tmr == nil { - return - } - if !tmr.Stop() { - <-tmr.C - } - tmr = nil - } - - for { - if ctx.Err() != nil { - maybeStopTimer() - return - } - - // First, see if we have a signal on our "healthy" channel, which - // takes priority over an existing timer. Because a select is - // nondeterministic, we explicitly check this channel before - // entering the main select below, so that we're guaranteed to - // stop the timer before starting captive portal detection. - select { - case needsCaptiveDetection := <-b.needsCaptiveDetection: - if needsCaptiveDetection { - maybeStartTimer() - } else { - maybeStopTimer() - } - default: - } - - var timerChan <-chan time.Time - if tmr != nil { - timerChan = tmr.C - } - select { - case <-ctx.Done(): - // All done; stop the timer and then exit. - maybeStopTimer() - return - case <-timerChan: - // Kick off captive portal check - b.performCaptiveDetection() - // nil out timer to force recreation - tmr = nil - case needsCaptiveDetection := <-b.needsCaptiveDetection: - if needsCaptiveDetection { - maybeStartTimer() - } else { - // Healthy; cancel any existing timer - maybeStopTimer() - } - } - } -} - -// performCaptiveDetection checks if captive portal detection is enabled via controlknob. If so, it runs -// the detection and updates the Warnable accordingly. -func (b *LocalBackend) performCaptiveDetection() { - if !b.shouldRunCaptivePortalDetection() { - return - } - - d := captivedetection.NewDetector(b.logf) - var dm *tailcfg.DERPMap - b.mu.Lock() - if b.netMap != nil { - dm = b.netMap.DERPMap - } - preferredDERP := 0 - if b.hostinfo != nil { - if b.hostinfo.NetInfo != nil { - preferredDERP = b.hostinfo.NetInfo.PreferredDERP - } - } - ctx := b.ctx - netMon := b.NetMon() - b.mu.Unlock() - found := d.Detect(ctx, netMon, dm, preferredDERP) - if found { - b.health.SetUnhealthy(captivePortalWarnable, health.Args{}) - } else { - b.health.SetHealthy(captivePortalWarnable) + b.goTracker.Go(b.sshServer.OnPolicyChange) } } -// shouldRunCaptivePortalDetection reports whether captive portal detection -// should be run. It is enabled by default, but can be disabled via a control -// knob. It is also only run when the user explicitly wants the backend to be -// running. -func (b *LocalBackend) shouldRunCaptivePortalDetection() bool { - b.mu.Lock() - defer b.mu.Unlock() - return !b.ControlKnobs().DisableCaptivePortalDetection.Load() && b.pm.prefs.WantRunning() -} - // packetFilterPermitsUnlockedNodes reports any peer in peers with the // UnsignedPeerAPIOnly bool set true has any of its allowed IPs in the packet // filter. @@ -2431,8 +2877,10 @@ func packetFilterPermitsUnlockedNodes(peers map[tailcfg.NodeID]tailcfg.NodeView, return false } +// TODO(nickkhyl): this should be non-existent with a proper [LocalBackend.updateFilterLocked]. +// See the comment in that function for more details. func (b *LocalBackend) setFilter(f *filter.Filter) { - b.filterAtomic.Store(f) + b.currentNode().setFilter(f) b.e.SetFilter(f) } @@ -2563,57 +3011,6 @@ func shrinkDefaultRoute(route netip.Prefix, localInterfaceRoutes *netipx.IPSet, return b.IPSet() } -// readPoller is a goroutine that receives service lists from -// b.portpoll and propagates them into the controlclient's HostInfo. -func (b *LocalBackend) readPoller() { - if !envknob.BoolDefaultTrue("TS_PORTLIST") { - return - } - - ticker, tickerChannel := b.clock.NewTicker(portlist.PollInterval()) - defer ticker.Stop() - for { - select { - case <-tickerChannel: - case <-b.ctx.Done(): - return - } - - if !b.shouldUploadServices() { - continue - } - - ports, changed, err := b.portpoll.Poll() - if err != nil { - b.logf("error polling for open ports: %v", err) - return - } - if !changed { - continue - } - sl := []tailcfg.Service{} - for _, p := range ports { - s := tailcfg.Service{ - Proto: tailcfg.ServiceProto(p.Proto), - Port: p.Port, - Description: p.Process, - } - if policy.IsInterestingService(s, version.OS()) { - sl = append(sl, s) - } - } - - b.mu.Lock() - if b.hostinfo == nil { - b.hostinfo = new(tailcfg.Hostinfo) - } - b.hostinfo.Services = sl - b.mu.Unlock() - - b.doSetHostinfoFilterServices() - } -} - // GetPushDeviceToken returns the push notification device token. func (b *LocalBackend) GetPushDeviceToken() string { return b.pushDeviceToken.Load() @@ -2654,34 +3051,28 @@ func applyConfigToHostinfo(hi *tailcfg.Hostinfo, c *conffile.Config) { // notifications. There is currently (2022-11-22) no mechanism provided to // detect when a message has been dropped. func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { - ch := make(chan *ipn.Notify, 128) + b.WatchNotificationsAs(ctx, nil, mask, onWatchAdded, fn) +} +// WatchNotificationsAs is like [LocalBackend.WatchNotifications] but takes an [ipnauth.Actor] +// as an additional parameter. If non-nil, the specified callback is invoked +// only for notifications relevant to this actor. +func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.Actor, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { + ch := make(chan *ipn.Notify, 128) sessionID := rands.HexString(16) - - origFn := fn - if mask&ipn.NotifyNoPrivateKeys != 0 { - fn = func(n *ipn.Notify) bool { - if n.NetMap == nil || n.NetMap.PrivateKey.IsZero() { - return origFn(n) - } - - // The netmap in n is shared across all watchers, so to mutate it for a - // single watcher we have to clone the notify and the netmap. We can - // make shallow clones, at least. - nm2 := *n.NetMap - n2 := *n - n2.NetMap = &nm2 - n2.NetMap.PrivateKey = key.NodePrivate{} - return origFn(&n2) - } + if mask&ipn.NotifyHealthActions == 0 { + // if UI does not support PrimaryAction in health warnings, append + // action URLs to the warning text instead. + fn = appendHealthActions(fn) } var ini *ipn.Notify b.mu.Lock() - const initialBits = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap | ipn.NotifyInitialDriveShares + const initialBits = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap | ipn.NotifyInitialDriveShares | ipn.NotifyInitialSuggestedExitNode if mask&initialBits != 0 { + cn := b.currentNode() ini = &ipn.Notify{Version: version.Long()} if mask&ipn.NotifyInitialState != 0 { ini.SessionID = sessionID @@ -2694,14 +3085,19 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa ini.Prefs = ptr.To(b.sanitizedPrefsLocked()) } if mask&ipn.NotifyInitialNetMap != 0 { - ini.NetMap = b.netMap + ini.NetMap = cn.NetMap() } - if mask&ipn.NotifyInitialDriveShares != 0 && b.driveSharingEnabledLocked() { + if mask&ipn.NotifyInitialDriveShares != 0 && b.DriveSharingEnabled() { ini.DriveShares = b.pm.prefs.DriveShares() } if mask&ipn.NotifyInitialHealthState != 0 { ini.Health = b.HealthTracker().CurrentState() } + if mask&ipn.NotifyInitialSuggestedExitNode != 0 { + if en, err := b.suggestExitNodeLocked(); err == nil { + ini.SuggestedExitNode = &en.ID + } + } } ctx, cancel := context.WithCancel(ctx) @@ -2709,12 +3105,16 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa session := &watchSession{ ch: ch, + owner: actor, sessionID: sessionID, cancel: cancel, } mak.Set(&b.notifyWatchers, sessionID, session) b.mu.Unlock() + metricCurrentWatchIPNBus.Add(1) + defer metricCurrentWatchIPNBus.Add(-1) + defer func() { b.mu.Lock() delete(b.notifyWatchers, sessionID) @@ -2743,22 +3143,46 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa // request every 2 seconds. // TODO(bradfitz): plumb this further and only send a Notify on change. if mask&ipn.NotifyWatchEngineUpdates != 0 { - go b.pollRequestEngineStatus(ctx) + b.goTracker.Go(func() { b.pollRequestEngineStatus(ctx) }) } - // TODO(marwan-at-work): check err // TODO(marwan-at-work): streaming background logs? defer b.DeleteForegroundSession(sessionID) - for { - select { - case <-ctx.Done(): - return - case n := <-ch: - if !fn(n) { - return + sender := &rateLimitingBusSender{fn: fn} + defer sender.close() + + if mask&ipn.NotifyRateLimit != 0 { + sender.interval = 3 * time.Second + } + + sender.Run(ctx, ch) +} + +// appendHealthActions returns an IPN listener func that wraps the supplied IPN +// listener func and transforms health messages passed to the wrapped listener. +// If health messages with PrimaryActions are present, it appends the label & +// url in the PrimaryAction to the text of the message. For use for clients that +// do not process the PrimaryAction. +func appendHealthActions(fn func(roNotify *ipn.Notify) (keepGoing bool)) func(*ipn.Notify) bool { + return func(n *ipn.Notify) bool { + if n.Health == nil || len(n.Health.Warnings) == 0 { + return fn(n) + } + + // Shallow clone the notify and health so we can mutate them + h2 := *n.Health + n2 := *n + n2.Health = &h2 + n2.Health.Warnings = make(map[health.WarnableCode]health.UnhealthyState, len(n.Health.Warnings)) + for k, v := range n.Health.Warnings { + if v.PrimaryAction != nil { + v.Text = fmt.Sprintf("%s %s: %s", v.Text, v.PrimaryAction.Label, v.PrimaryAction.URL) + v.PrimaryAction = nil } + n2.Health.Warnings[k] = v } + return fn(&n2) } } @@ -2789,11 +3213,7 @@ func (b *LocalBackend) DebugNotify(n ipn.Notify) { // // It should only be used via the LocalAPI's debug handler. func (b *LocalBackend) DebugNotifyLastNetMap() { - b.mu.Lock() - nm := b.netMap - b.mu.Unlock() - - if nm != nil { + if nm := b.currentNode().NetMap(); nm != nil { b.send(ipn.Notify{NetMap: nm}) } } @@ -2807,7 +3227,8 @@ func (b *LocalBackend) DebugNotifyLastNetMap() { func (b *LocalBackend) DebugForceNetmapUpdate() { b.mu.Lock() defer b.mu.Unlock() - nm := b.netMap + // TODO(nickkhyl): this all should be done in [LocalBackend.setNetMapLocked]. + nm := b.currentNode().NetMap() b.e.SetNetworkMap(nm) if nm != nil { b.MagicConn().SetDERPMap(nm.DERPMap) @@ -2821,6 +3242,12 @@ func (b *LocalBackend) DebugPickNewDERP() error { return b.sys.MagicSock.Get().DebugPickNewDERP() } +// DebugForcePreferDERP forwards to netcheck.DebugForcePreferDERP. +// See its docs. +func (b *LocalBackend) DebugForcePreferDERP(n int) { + b.sys.MagicSock.Get().DebugForcePreferDERP(n) +} + // send delivers n to the connected frontend and any API watchers from // LocalBackend.WatchNotifications (via the LocalAPI). // @@ -2831,13 +3258,81 @@ func (b *LocalBackend) DebugPickNewDERP() error { // // b.mu must not be held. func (b *LocalBackend) send(n ipn.Notify) { + b.sendTo(n, allClients) +} + +func (b *LocalBackend) sendLocked(n ipn.Notify) { + b.sendToLocked(n, allClients) +} + +// SendNotify sends a notification to the IPN bus, +// typically to the GUI client. +func (b *LocalBackend) SendNotify(n ipn.Notify) { + b.send(n) +} + +// notificationTarget describes a notification recipient. +// A zero value is valid and indicate that the notification +// should be broadcast to all active [watchSession]s. +type notificationTarget struct { + // userID is the OS-specific UID of the target user. + // If empty, the notification is not user-specific and + // will be broadcast to all connected users. + // TODO(nickkhyl): make this field cross-platform rather + // than Windows-specific. + userID ipn.WindowsUserID + // clientID identifies a client that should be the exclusive recipient + // of the notification. A zero value indicates that notification should + // be sent to all sessions of the specified user. + clientID ipnauth.ClientID +} + +var allClients = notificationTarget{} // broadcast to all connected clients + +// toNotificationTarget returns a [notificationTarget] that matches only actors +// representing the same user as the specified actor. If the actor represents +// a specific connected client, the [ipnauth.ClientID] must also match. +// If the actor is nil, the [notificationTarget] matches all actors. +func toNotificationTarget(actor ipnauth.Actor) notificationTarget { + t := notificationTarget{} + if actor != nil { + t.userID = actor.UserID() + t.clientID, _ = actor.ClientID() + } + return t +} + +// match reports whether the specified actor should receive notifications +// targeting t. If the actor is nil, it should only receive notifications +// intended for all users. +func (t notificationTarget) match(actor ipnauth.Actor) bool { + if t == allClients { + return true + } + if actor == nil { + return false + } + if t.userID != "" && t.userID != actor.UserID() { + return false + } + if t.clientID != ipnauth.NoClientID { + clientID, ok := actor.ClientID() + if !ok || clientID != t.clientID { + return false + } + } + return true +} + +// sendTo is like [LocalBackend.send] but allows specifying a recipient. +func (b *LocalBackend) sendTo(n ipn.Notify, recipient notificationTarget) { b.mu.Lock() defer b.mu.Unlock() - b.sendLocked(n) + b.sendToLocked(n, recipient) } -// sendLocked is like send, but assumes b.mu is already held. -func (b *LocalBackend) sendLocked(n ipn.Notify) { +// sendToLocked is like [LocalBackend.sendTo], but assumes b.mu is already held. +func (b *LocalBackend) sendToLocked(n ipn.Notify, recipient notificationTarget) { if n.Prefs != nil { n.Prefs = ptr.To(stripKeysFromPrefs(*n.Prefs)) } @@ -2845,59 +3340,38 @@ func (b *LocalBackend) sendLocked(n ipn.Notify) { n.Version = version.Long() } - apiSrv := b.peerAPIServer - if mayDeref(apiSrv).taildrop.HasFilesWaiting() { - n.FilesWaiting = &empty.Message{} + for _, f := range b.extHost.Hooks().MutateNotifyLocked { + f(&n) } for _, sess := range b.notifyWatchers { - select { - case sess.ch <- &n: - default: - // Drop the notification if the channel is full. + if recipient.match(sess.owner) { + select { + case sess.ch <- &n: + default: + // Drop the notification if the channel is full. + } } } } -func (b *LocalBackend) sendFileNotify() { - var n ipn.Notify - - b.mu.Lock() - for _, wakeWaiter := range b.fileWaiters { - wakeWaiter() - } - apiSrv := b.peerAPIServer - if apiSrv == nil { - b.mu.Unlock() - return - } - - // Make sure we always set n.IncomingFiles non-nil so it gets encoded - // in JSON to clients. They distinguish between empty and non-nil - // to know whether a Notify should be able about files. - n.IncomingFiles = apiSrv.taildrop.IncomingFiles() - b.mu.Unlock() - - sort.Slice(n.IncomingFiles, func(i, j int) bool { - return n.IncomingFiles[i].Started.Before(n.IncomingFiles[j].Started) - }) - - b.send(n) -} - -// setAuthURL sets the authURL and triggers [LocalBackend.popBrowserAuthNow] if the URL has changed. +// setAuthURLLocked sets the authURL and triggers [LocalBackend.popBrowserAuthNow] if the URL has changed. // This method is called when a new authURL is received from the control plane, meaning that either a user // has started a new interactive login (e.g., by running `tailscale login` or clicking Login in the GUI), // or the control plane was unable to authenticate this node non-interactively (e.g., due to key expiration). -// b.interact indicates whether an interactive login is in progress. +// A non-nil b.authActor indicates that an interactive login is in progress and was initiated by the specified actor. +// +// b.mu must be held. +// // If url is "", it is equivalent to calling [LocalBackend.resetAuthURLLocked] with b.mu held. -func (b *LocalBackend) setAuthURL(url string) { +func (b *LocalBackend) setAuthURLLocked(url string) { var popBrowser, keyExpired bool + var recipient ipnauth.Actor - b.mu.Lock() switch { case url == "": b.resetAuthURLLocked() + return case b.authURL != url: b.authURL = url b.authURLTime = b.clock.Now() @@ -2906,38 +3380,40 @@ func (b *LocalBackend) setAuthURL(url string) { popBrowser = true default: // Otherwise, only open it if the user explicitly requests interactive login. - popBrowser = b.interact + popBrowser = b.authActor != nil } keyExpired = b.keyExpired + recipient = b.authActor // or nil // Consume the StartLoginInteractive call, if any, that caused the control // plane to send us this URL. - b.interact = false - b.mu.Unlock() + b.authActor = nil if popBrowser { - b.popBrowserAuthNow(url, keyExpired) + b.popBrowserAuthNowLocked(url, keyExpired, recipient) } } -// popBrowserAuthNow shuts down the data plane and sends an auth URL -// to the connected frontend, if any. +// popBrowserAuthNowLocked shuts down the data plane and sends the URL to the recipient's +// [watchSession]s if the recipient is non-nil; otherwise, it sends the URL to all watchSessions. // keyExpired is the value of b.keyExpired upon entry and indicates // whether the node's key has expired. -// It must not be called with b.mu held. -func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool) { - b.logf("popBrowserAuthNow: url=%v, key-expired=%v, seamless-key-renewal=%v", url != "", keyExpired, b.seamlessRenewalEnabled()) +// +// b.mu must be held. +func (b *LocalBackend) popBrowserAuthNowLocked(url string, keyExpired bool, recipient ipnauth.Actor) { + b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v, seamless-key-renewal=%v", maybeUsernameOf(recipient), url != "", keyExpired, b.seamlessRenewalEnabled()) // Deconfigure the local network data plane if: // - seamless key renewal is not enabled; // - key is expired (in which case tailnet connectivity is down anyway). if !b.seamlessRenewalEnabled() || keyExpired { - b.blockEngineUpdates(true) - b.stopEngineAndWait() - } - b.tellClientToBrowseToURL(url) - if b.State() == ipn.Running { - b.enterState(ipn.Starting) + b.blockEngineUpdatesLocked(true) + b.stopEngineAndWaitLocked() + + if b.state == ipn.Running { + b.enterStateLocked(ipn.Starting) + } } + b.tellRecipientToBrowseToURLLocked(url, toNotificationTarget(recipient)) } // validPopBrowserURL reports whether urlStr is a valid value for a @@ -2945,6 +3421,16 @@ func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool) { // // b.mu must *not* be held. func (b *LocalBackend) validPopBrowserURL(urlStr string) bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.validPopBrowserURLLocked(urlStr) +} + +// validPopBrowserURLLocked reports whether urlStr is a valid value for a +// control server to send in a *URL field. +// +// b.mu must be held. +func (b *LocalBackend) validPopBrowserURLLocked(urlStr string) bool { if urlStr == "" { return false } @@ -2952,7 +3438,7 @@ func (b *LocalBackend) validPopBrowserURL(urlStr string) bool { if err != nil { return false } - serverURL := b.Prefs().ControlURLOrDefault() + serverURL := b.sanitizedPrefsLocked().ControlURLOrDefault(b.polc) if ipn.IsLoginServerSynonym(serverURL) { // When connected to the official Tailscale control plane, only allow // URLs from tailscale.com or its subdomains. @@ -2975,8 +3461,16 @@ func (b *LocalBackend) validPopBrowserURL(urlStr string) bool { } func (b *LocalBackend) tellClientToBrowseToURL(url string) { - if b.validPopBrowserURL(url) { - b.send(ipn.Notify{BrowseToURL: &url}) + b.mu.Lock() + defer b.mu.Unlock() + b.tellRecipientToBrowseToURLLocked(url, allClients) +} + +// tellRecipientToBrowseToURLLocked is like tellClientToBrowseToURL but allows specifying a recipient +// and b.mu must be held. +func (b *LocalBackend) tellRecipientToBrowseToURLLocked(url string, recipient notificationTarget) { + if b.validPopBrowserURLLocked(url) { + b.sendToLocked(ipn.Notify{BrowseToURL: &url}, recipient) } } @@ -2991,8 +3485,8 @@ func (b *LocalBackend) onClientVersion(v *tailcfg.ClientVersion) { } func (b *LocalBackend) onTailnetDefaultAutoUpdate(au bool) { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() prefs := b.pm.CurrentPrefs() if !prefs.Valid() { @@ -3010,18 +3504,22 @@ func (b *LocalBackend) onTailnetDefaultAutoUpdate(au bool) { // can still manually enable auto-updates on this node. return } - b.logf("using tailnet default auto-update setting: %v", au) - prefsClone := prefs.AsStruct() - prefsClone.AutoUpdate.Apply = opt.NewBool(au) - _, err := b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ - Prefs: *prefsClone, - AutoUpdateSet: ipn.AutoUpdatePrefsMask{ - ApplySet: true, - }, - }, unlock) - if err != nil { - b.logf("failed to apply tailnet-wide default for auto-updates (%v): %v", au, err) - return + if buildfeatures.HasClientUpdate && feature.CanAutoUpdate() { + b.logf("using tailnet default auto-update setting: %v", au) + prefsClone := prefs.AsStruct() + prefsClone.AutoUpdate.Apply = opt.NewBool(au) + _, err := b.editPrefsLocked( + ipnauth.Self, + &ipn.MaskedPrefs{ + Prefs: *prefsClone, + AutoUpdateSet: ipn.AutoUpdatePrefsMask{ + ApplySet: true, + }, + }) + if err != nil { + b.logf("failed to apply tailnet-wide default for auto-updates (%v): %v", au, err) + return + } } } @@ -3061,11 +3559,6 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { return nil } - var legacyMachineKey key.MachinePrivate - if p := b.pm.CurrentPrefs().Persist(); p.Valid() { - legacyMachineKey = p.LegacyFrontendPrivateMachineKey() - } - keyText, err := b.store.ReadState(ipn.MachineKeyStateKey) if err == nil { if err := b.machinePrivKey.UnmarshalText(keyText); err != nil { @@ -3074,9 +3567,6 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { if b.machinePrivKey.IsZero() { return fmt.Errorf("invalid zero key stored in %v key of %v", ipn.MachineKeyStateKey, b.store) } - if !legacyMachineKey.IsZero() && !legacyMachineKey.Equal(b.machinePrivKey) { - b.logf("frontend-provided legacy machine key ignored; used value from server state") - } return nil } if err != ipn.ErrStateNotExist { @@ -3086,12 +3576,8 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { // If we didn't find one already on disk and the prefs already // have a legacy machine key, use that. Otherwise generate a // new one. - if !legacyMachineKey.IsZero() { - b.machinePrivKey = legacyMachineKey - } else { - b.logf("generating new machine key") - b.machinePrivKey = key.NewMachine() - } + b.logf("generating new machine key") + b.machinePrivKey = key.NewMachine() keyText, _ = b.machinePrivKey.MarshalText() if err := ipn.WriteState(b.store, ipn.MachineKeyStateKey, keyText); err != nil { @@ -3117,12 +3603,9 @@ func (b *LocalBackend) clearMachineKeyLocked() error { return nil } -// setTCPPortsIntercepted populates b.shouldInterceptTCPPortAtomic with an -// efficient func for ShouldInterceptTCPPort to use, which is called on every -// incoming packet. -func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { +func generateInterceptTCPPortFunc(ports []uint16) func(uint16) bool { slices.Sort(ports) - uniq.ModifySlice(&ports) + ports = slices.Compact(ports) var f func(uint16) bool switch len(ports) { case 0: @@ -3151,7 +3634,23 @@ func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { } } } - b.shouldInterceptTCPPortAtomic.Store(f) + return f +} + +// setTCPPortsIntercepted populates b.shouldInterceptTCPPortAtomic with an +// efficient func for ShouldInterceptTCPPort to use, which is called on every +// incoming packet. +func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { + b.shouldInterceptTCPPortAtomic.Store(generateInterceptTCPPortFunc(ports)) +} + +func generateInterceptVIPServicesTCPPortFunc(svcAddrPorts map[netip.Addr]func(uint16) bool) func(netip.AddrPort) bool { + return func(ap netip.AddrPort) bool { + if f, ok := svcAddrPorts[ap.Addr()]; ok { + return f(ap.Port()) + } + return false + } } // setAtomicValuesFromPrefsLocked populates sshAtomicBool, containsViaIPFuncAtomic, @@ -3164,6 +3663,9 @@ func (b *LocalBackend) setAtomicValuesFromPrefsLocked(p ipn.PrefsView) { if !p.Valid() { b.containsViaIPFuncAtomic.Store(ipset.FalseContainsIPFunc()) b.setTCPPortsIntercepted(nil) + if f, ok := hookServeClearVIPServicesTCPPortsInterceptedLocked.GetOk(); ok { + f(b) + } b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} } else { @@ -3181,23 +3683,6 @@ func (b *LocalBackend) State() ipn.State { return b.state } -// InServerMode reports whether the Tailscale backend is explicitly running in -// "server mode" where it continues to run despite whatever the platform's -// default is. In practice, this is only used on Windows, where the default -// tailscaled behavior is to shut down whenever the GUI disconnects. -// -// On non-Windows platforms, this usually returns false (because people don't -// set unattended mode on other platforms) and also isn't checked on other -// platforms. -// -// TODO(bradfitz): rename to InWindowsUnattendedMode or something? Or make this -// return true on Linux etc and always be called? It's kinda messy now. -func (b *LocalBackend) InServerMode() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.pm.CurrentPrefs().ForceDaemon() -} - // CheckIPNConnectionAllowed returns an error if the specified actor should not // be allowed to connect or make requests to the LocalAPI currently. // @@ -3207,16 +3692,10 @@ func (b *LocalBackend) InServerMode() bool { func (b *LocalBackend) CheckIPNConnectionAllowed(actor ipnauth.Actor) error { b.mu.Lock() defer b.mu.Unlock() - serverModeUid := b.pm.CurrentUserID() - if serverModeUid == "" { - // Either this platform isn't a "multi-user" platform or we're not yet - // running as one. - return nil - } - if !b.pm.CurrentPrefs().ForceDaemon() { + if b.pm.CurrentUserID() == "" { + // There's no "current user" yet; allow the connection. return nil } - // Always allow Windows SYSTEM user to connect, // even if Tailscale is currently being used by another user. if actor.IsLocalSystem() { @@ -3227,10 +3706,21 @@ func (b *LocalBackend) CheckIPNConnectionAllowed(actor ipnauth.Actor) error { if uid == "" { return errors.New("empty user uid in connection identity") } - if uid != serverModeUid { - return fmt.Errorf("Tailscale running in server mode (%q); connection from %q not allowed", b.tryLookupUserName(string(serverModeUid)), b.tryLookupUserName(string(uid))) + if uid == b.pm.CurrentUserID() { + // The connection is from the current user; allow it. + return nil } - return nil + + // The connection is from a different user; block it. + var reason string + if b.pm.CurrentPrefs().ForceDaemon() { + reason = "running in server mode" + } else { + reason = "already in use" + } + return fmt.Errorf("Tailscale %s (%q); connection from %q not allowed", + reason, b.tryLookupUserName(string(b.pm.CurrentUserID())), + b.tryLookupUserName(string(uid))) } // tryLookupUserName tries to look up the username for the uid. @@ -3248,7 +3738,17 @@ func (b *LocalBackend) tryLookupUserName(uid string) string { // StartLoginInteractive attempts to pick up the in-progress flow where it left // off. func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { + return b.StartLoginInteractiveAs(ctx, nil) +} + +// StartLoginInteractiveAs is like StartLoginInteractive but takes an [ipnauth.Actor] +// as an additional parameter. If non-nil, the specified user is expected to complete +// the interactive login, and therefore will receive the BrowseToURL notification once +// the control plane sends us one. Otherwise, the notification will be delivered to all +// active [watchSession]s. +func (b *LocalBackend) StartLoginInteractiveAs(ctx context.Context, user ipnauth.Actor) error { b.mu.Lock() + defer b.mu.Unlock() if b.cc == nil { panic("LocalBackend.assertClient: b.cc == nil") } @@ -3261,17 +3761,16 @@ func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { hasValidURL := url != "" && timeSinceAuthURLCreated < ((7*24*time.Hour)-(1*time.Hour)) if !hasValidURL { // A user wants to log in interactively, but we don't have a valid authURL. - // Set a flag to indicate that interactive login is in progress, forcing - // a BrowseToURL notification once the authURL becomes available. - b.interact = true + // Remember the user who initiated the login, so that we can notify them + // once the authURL is available. + b.authActor = user } cc := b.cc - b.mu.Unlock() - b.logf("StartLoginInteractive: url=%v", hasValidURL) + b.logf("StartLoginInteractiveAs(%q): url=%v", maybeUsernameOf(user), hasValidURL) if hasValidURL { - b.popBrowserAuthNow(url, keyExpired) + b.popBrowserAuthNowLocked(url, keyExpired, user) } else { cc.Login(b.loginFlags | controlclient.LoginInteractive) } @@ -3316,6 +3815,9 @@ func (b *LocalBackend) Ping(ctx context.Context, ip netip.Addr, pingType tailcfg } func (b *LocalBackend) pingPeerAPI(ctx context.Context, ip netip.Addr) (peer tailcfg.NodeView, peerBase string, err error) { + if !buildfeatures.HasPeerAPIClient { + return peer, peerBase, feature.ErrUnavailable + } var zero tailcfg.NodeView ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() @@ -3381,23 +3883,8 @@ func (b *LocalBackend) parseWgStatusLocked(s *wgengine.Status) (ret ipn.EngineSt return ret } -// shouldUploadServices reports whether this node should include services -// in Hostinfo. When the user preferences currently request "shields up" -// mode, all inbound connections are refused, so services are not reported. -// Otherwise, shouldUploadServices respects NetMap.CollectServices. -func (b *LocalBackend) shouldUploadServices() bool { - b.mu.Lock() - defer b.mu.Unlock() - - p := b.pm.CurrentPrefs() - if !p.Valid() || b.netMap == nil { - return false // default to safest setting - } - return !p.ShieldsUp() && b.netMap.CollectServices -} - // SetCurrentUser is used to implement support for multi-user systems (only -// Windows 2022-11-25). On such systems, the uid is used to determine which +// Windows 2022-11-25). On such systems, the actor is used to determine which // user's state should be used. The current user is maintained by active // connections open to the backend. // @@ -3411,29 +3898,154 @@ func (b *LocalBackend) shouldUploadServices() bool { // unattended mode. The user must disable unattended mode before the user can be // changed. // -// On non-multi-user systems, the user should be set to nil. -// -// SetCurrentUser returns the ipn.WindowsUserID associated with the user -// when successful. -func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) (ipn.WindowsUserID, error) { - var uid ipn.WindowsUserID - if actor != nil { - uid = actor.UserID() +// On non-multi-user systems, the actor should be set to nil. +func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) { + b.mu.Lock() + defer b.mu.Unlock() + + var userIdentifier string + if user := cmp.Or(actor, b.currentUser); user != nil { + maybeUsername, _ := user.Username() + userIdentifier = cmp.Or(maybeUsername, string(user.UserID())) + } + + if actor != b.currentUser { + if c, ok := b.currentUser.(ipnauth.ActorCloser); ok { + c.Close() + } + b.currentUser = actor + } + + var action string + if actor == nil { + action = "disconnected" + } else { + action = "connected" } + reason := fmt.Sprintf("client %s (%s)", action, userIdentifier) + b.switchToBestProfileLocked(reason) +} + +// SwitchToBestProfile selects the best profile to use, +// as reported by [LocalBackend.resolveBestProfileLocked], and switches +// to it, unless it's already the current profile. The reason indicates +// why the profile is being switched, such as due to a client connecting +// or disconnecting, or a change in the desktop session state, and is used +// for logging. +func (b *LocalBackend) SwitchToBestProfile(reason string) { + b.mu.Lock() + defer b.mu.Unlock() + b.switchToBestProfileLocked(reason) +} - unlock := b.lockAndGetUnlock() - defer unlock() +// switchToBestProfileLocked is like [LocalBackend.SwitchToBestProfile], +// but b.mu must held on entry. +func (b *LocalBackend) switchToBestProfileLocked(reason string) { + oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc) + profile, background := b.resolveBestProfileLocked() + cp, switched, err := b.pm.SwitchToProfile(profile) + switch { + case !switched && cp.ID() == "": + if err != nil { + b.logf("%s: an error occurred; staying on empty profile: %v", reason, err) + } else { + b.logf("%s: staying on empty profile", reason) + } + case !switched: + if err != nil { + b.logf("%s: an error occurred; staying on profile %q (%s): %v", reason, cp.UserProfile().LoginName, cp.ID(), err) + } else { + b.logf("%s: staying on profile %q (%s)", reason, cp.UserProfile().LoginName, cp.ID()) + } + case cp.ID() == "": + b.logf("%s: disconnecting Tailscale", reason) + case background: + b.logf("%s: switching to background profile %q (%s)", reason, cp.UserProfile().LoginName, cp.ID()) + default: + b.logf("%s: switching to profile %q (%s)", reason, cp.UserProfile().LoginName, cp.ID()) + } + if !switched { + return + } + // As an optimization, only reset the dialPlan if the control URL changed. + if newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc); oldControlURL != newControlURL { + b.resetDialPlan() + } + if err := b.resetForProfileChangeLocked(); err != nil { + // TODO(nickkhyl): The actual reset cannot fail. However, + // the TKA initialization or [LocalBackend.Start] can fail. + // These errors are not critical as far as we're concerned. + // But maybe we should post a notification to the API watchers? + b.logf("failed switching profile to %q: %v", profile.ID(), err) + } +} - if b.pm.CurrentUserID() == uid { - return uid, nil +// resolveBestProfileLocked returns the best profile to use based on the current +// state of the backend, such as whether a GUI/CLI client is connected, whether +// the unattended mode is enabled, the current state of the desktop sessions, +// and other factors. +// +// It returns a read-only view of the profile and whether it is considered +// a background profile. A background profile is used when no OS user is actively +// using Tailscale, such as when no GUI/CLI client is connected and Unattended Mode +// is enabled (see also [LocalBackend.getBackgroundProfileLocked]). +// +// An invalid view indicates no profile, meaning Tailscale should disconnect +// and remain idle until a GUI or CLI client connects. +// A valid profile view with an empty [ipn.ProfileID] indicates a new profile that +// has not been persisted yet. +// +// b.mu must be held. +func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBackground bool) { + // TODO(nickkhyl): delegate all of this to the extensions and remove the distinction + // between "foreground" and "background" profiles as we migrate away from the concept + // of a single "current user" on Windows. See tailscale/corp#18342. + // + // If a GUI/CLI client is connected, use the connected user's profile, which means + // either the current profile if owned by the user, or their default profile. + if b.currentUser != nil { + profile := b.pm.CurrentProfile() + // TODO(nickkhyl): check if the current profile is allowed on the device, + // such as when [pkey.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + if uid := b.currentUser.UserID(); profile.LocalUserID() != uid { + profile = b.pm.DefaultUserProfile(uid) + } + return profile, false } - b.pm.SetCurrentUserID(uid) - if c, ok := b.currentUser.(ipnauth.ActorCloser); ok { - c.Close() + + // Otherwise, if on Windows, use the background profile if one is set. + // This includes staying on the current profile if Unattended Mode is enabled + // or if AlwaysOn mode is enabled and the current user is still signed in. + // If the returned background profileID is "", Tailscale will disconnect + // and remain idle until a GUI or CLI client connects. + if goos := envknob.GOOS(); goos == "windows" { + // If Unattended Mode is enabled for the current profile, keep using it. + if b.pm.CurrentPrefs().ForceDaemon() { + return b.pm.CurrentProfile(), true + } + // Otherwise, use the profile returned by the extension. + profile := b.extHost.DetermineBackgroundProfile(b.pm) + return profile, true } - b.currentUser = actor - b.resetForProfileChangeLockedOnEntry(unlock) - return uid, nil + + // On other platforms, however, Tailscale continues to run in the background + // using the current profile. + // + // TODO(nickkhyl): check if the current profile is allowed on the device, + // such as when [pkey.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + return b.pm.CurrentProfile(), false +} + +// CurrentUserForTest returns the current user and the associated WindowsUserID. +// It is used for testing only, and will be removed along with the rest of the +// "current user" functionality as we progress on the multi-user improvements (tailscale/corp#18342). +func (b *LocalBackend) CurrentUserForTest() (ipn.WindowsUserID, ipnauth.Actor) { + testenv.AssertInTest() + b.mu.Lock() + defer b.mu.Unlock() + return b.pm.CurrentUserID(), b.currentUser } func (b *LocalBackend) CheckPrefs(p *ipn.Prefs) error { @@ -3475,14 +4087,14 @@ func (b *LocalBackend) checkPrefsLocked(p *ipn.Prefs) error { if err := b.checkAutoUpdatePrefsLocked(p); err != nil { errs = append(errs, err) } - return multierr.New(errs...) + return errors.Join(errs...) } func (b *LocalBackend) checkSSHPrefsLocked(p *ipn.Prefs) error { if !p.RunSSH { return nil } - if err := envknob.CanRunTailscaleSSH(); err != nil { + if err := featureknob.CanRunTailscaleSSH(); err != nil { return err } if runtime.GOOS == "linux" { @@ -3491,13 +4103,12 @@ func (b *LocalBackend) checkSSHPrefsLocked(p *ipn.Prefs) error { if envknob.SSHIgnoreTailnetPolicy() || envknob.SSHPolicyFile() != "" { return nil } - if b.netMap != nil { - if !b.netMap.HasCap(tailcfg.CapabilitySSH) { - if b.isDefaultServerLocked() { - return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet. See https://tailscale.com/s/ssh") - } - return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet.") + // Assume that we do have the SSH capability if don't have a netmap yet. + if !b.currentNode().SelfHasCapOr(tailcfg.CapabilitySSH, true) { + if b.isDefaultServerLocked() { + return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet. See https://tailscale.com/s/ssh") } + return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet.") } return nil } @@ -3509,7 +4120,7 @@ func (b *LocalBackend) sshOnButUnusableHealthCheckMessageLocked() (healthMessage if envknob.SSHIgnoreTailnetPolicy() || envknob.SSHPolicyFile() != "" { return "development SSH policy in use" } - nm := b.netMap + nm := b.currentNode().NetMap() if nm == nil { return "" } @@ -3532,7 +4143,7 @@ func (b *LocalBackend) isDefaultServerLocked() bool { if !prefs.Valid() { return true // assume true until set otherwise } - return prefs.ControlURLOrDefault() == ipn.DefaultControlURL + return prefs.ControlURLOrDefault(b.polc) == ipn.DefaultControlURL } var exitNodeMisconfigurationWarnable = health.Register(&health.Warnable{ @@ -3547,12 +4158,14 @@ var exitNodeMisconfigurationWarnable = health.Register(&health.Warnable{ // updateExitNodeUsageWarning updates a warnable meant to notify users of // configuration issues that could break exit node usage. func updateExitNodeUsageWarning(p ipn.PrefsView, state *netmon.State, healthTracker *health.Tracker) { + if !buildfeatures.HasUseExitNode { + return + } var msg string if p.ExitNodeIP().IsValid() || p.ExitNodeID() != "" { warn, _ := netutil.CheckReversePathFiltering(state) - const comment = "please set rp_filter=2 instead of rp_filter=1; see https://github.com/tailscale/tailscale/issues/3310" if len(warn) > 0 { - msg = fmt.Sprintf("%s: %v, %s", healthmsg.WarnExitNodeUsage, warn, comment) + msg = fmt.Sprintf("%s: %v, %s", healthmsg.WarnExitNodeUsage, warn, healthmsg.DisableRPFilter) } } if len(msg) > 0 { @@ -3563,7 +4176,19 @@ func updateExitNodeUsageWarning(p ipn.PrefsView, state *netmon.State, healthTrac } func (b *LocalBackend) checkExitNodePrefsLocked(p *ipn.Prefs) error { - if (p.ExitNodeIP.IsValid() || p.ExitNodeID != "") && p.AdvertisesExitNode() { + tryingToUseExitNode := p.ExitNodeIP.IsValid() || p.ExitNodeID != "" + if !tryingToUseExitNode { + return nil + } + if !buildfeatures.HasUseExitNode { + return feature.ErrUnavailable + } + + if err := featureknob.CanUseExitNode(); err != nil { + return err + } + + if p.AdvertisesExitNode() { return errors.New("Cannot advertise an exit node and use an exit node at the same time.") } return nil @@ -3577,7 +4202,12 @@ func (b *LocalBackend) checkFunnelEnabledLocked(p *ipn.Prefs) error { } func (b *LocalBackend) checkAutoUpdatePrefsLocked(p *ipn.Prefs) error { - if p.AutoUpdate.Apply.EqualBool(true) && !clientupdate.CanAutoUpdate() { + if !buildfeatures.HasClientUpdate { + if p.AutoUpdate.Apply.EqualBool(true) { + return errors.New("Auto-update support is disabled in this build") + } + } + if p.AutoUpdate.Apply.EqualBool(true) && !feature.CanAutoUpdate() { return errors.New("Auto-updates are not supported on this platform.") } return nil @@ -3588,11 +4218,14 @@ func (b *LocalBackend) checkAutoUpdatePrefsLocked(p *ipn.Prefs) error { // On success, it returns the resulting prefs (or current prefs, in the case of no change). // Setting the value to false when use of an exit node is already false is not an error, // nor is true when the exit node is already in use. -func (b *LocalBackend) SetUseExitNodeEnabled(v bool) (ipn.PrefsView, error) { - unlock := b.lockAndGetUnlock() - defer unlock() +func (b *LocalBackend) SetUseExitNodeEnabled(actor ipnauth.Actor, v bool) (ipn.PrefsView, error) { + b.mu.Lock() + defer b.mu.Unlock() p0 := b.pm.CurrentPrefs() + if !buildfeatures.HasUseExitNode { + return p0, nil + } if v && p0.ExitNodeID() != "" { // Already on. return p0, nil @@ -3613,22 +4246,36 @@ func (b *LocalBackend) SetUseExitNodeEnabled(v bool) (ipn.PrefsView, error) { mp := &ipn.MaskedPrefs{} if v { mp.ExitNodeIDSet = true - mp.ExitNodeID = tailcfg.StableNodeID(p0.InternalExitNodePrior()) + mp.ExitNodeID = p0.InternalExitNodePrior() + if expr, ok := ipn.ParseAutoExitNodeString(mp.ExitNodeID); ok { + mp.AutoExitNodeSet = true + mp.AutoExitNode = expr + mp.ExitNodeID = unresolvedExitNodeID + } } else { mp.ExitNodeIDSet = true mp.ExitNodeID = "" + mp.AutoExitNodeSet = true + mp.AutoExitNode = "" mp.InternalExitNodePriorSet = true - mp.InternalExitNodePrior = p0.ExitNodeID() + if p0.AutoExitNode().IsSet() { + mp.InternalExitNodePrior = tailcfg.StableNodeID(ipn.AutoExitNodePrefix + p0.AutoExitNode()) + } else { + mp.InternalExitNodePrior = p0.ExitNodeID() + } } - return b.editPrefsLockedOnEntry(mp, unlock) + return b.editPrefsLocked(actor, mp) } // MaybeClearAppConnector clears the routes from any AppConnector if // AdvertiseRoutes has been set in the MaskedPrefs. func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error { + if !buildfeatures.HasAppConnectors { + return nil + } var err error - if b.appConnector != nil && mp.AdvertiseRoutesSet { - err = b.appConnector.ClearRoutes() + if ac := b.AppConnector(); ac != nil && mp.AdvertiseRoutesSet { + err = ac.ClearRoutes() if err != nil { b.logf("appc: clear routes error: %v", err) } @@ -3636,52 +4283,303 @@ func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error { return err } +// EditPrefs applies the changes in mp to the current prefs, +// acting as the tailscaled itself rather than a specific user. func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { + return b.EditPrefsAs(mp, ipnauth.Self) +} + +// EditPrefsAs is like EditPrefs, but makes the change as the specified actor. +// It returns an error if the actor is not allowed to make the change. +func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ipn.PrefsView, error) { if mp.SetsInternal() { return ipn.PrefsView{}, errors.New("can't set Internal fields") } + defer b.settleEventBus() + + b.mu.Lock() + defer b.mu.Unlock() + return b.editPrefsLocked(actor, mp) +} + +// checkEditPrefsAccessLocked checks whether the current user has access +// to apply the changes in mp to the given prefs. +// +// It returns an error if the user is not allowed, or nil otherwise. +// +// b.mu must be held. +func (b *LocalBackend) checkEditPrefsAccessLocked(actor ipnauth.Actor, prefs ipn.PrefsView, mp *ipn.MaskedPrefs) error { + var errs []error + + if mp.RunSSHSet && mp.RunSSH && !envknob.CanSSHD() { + errs = append(errs, errors.New("Tailscale SSH server administratively disabled")) + } + + // Check if the user is allowed to disconnect Tailscale. + if mp.WantRunningSet && !mp.WantRunning && b.pm.CurrentPrefs().WantRunning() { + if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.extHost.AuditLogger()); err != nil { + errs = append(errs, err) + } + } + + // Prevent users from changing exit node preferences + // when exit node usage is managed by policy. + if mp.ExitNodeIDSet || mp.ExitNodeIPSet || mp.AutoExitNodeSet { + isManaged, err := b.polc.HasAnyOf(pkey.ExitNodeID, pkey.ExitNodeIP) + if err != nil { + err = fmt.Errorf("policy check failed: %w", err) + } else if isManaged { + // Allow users to override ExitNode policy settings and select an exit node manually + // if permitted by [pkey.AllowExitNodeOverride]. + // + // Disabling exit node usage entirely is not allowed. + allowExitNodeOverride, _ := b.polc.GetBoolean(pkey.AllowExitNodeOverride, false) + if !allowExitNodeOverride || b.changeDisablesExitNodeLocked(prefs, mp) { + err = errManagedByPolicy + } + } + if err != nil { + errs = append(errs, fmt.Errorf("exit node cannot be changed: %w", err)) + } + } + + return errors.Join(errs...) +} + +// changeDisablesExitNodeLocked reports whether applying the change +// to the given prefs would disable exit node usage. +// +// In other words, it returns true if prefs.ExitNodeID is non-empty +// initially, but would become empty after applying the given change. +// +// It applies the same adjustments and resolves the exit node in the prefs +// as done during actual edits. While not optimal performance-wise, +// changing the exit node via LocalAPI isn't a hot path, and reusing +// the same logic ensures consistency and simplifies maintenance. +// +// b.mu must be held. +func (b *LocalBackend) changeDisablesExitNodeLocked(prefs ipn.PrefsView, change *ipn.MaskedPrefs) bool { + if !buildfeatures.HasUseExitNode { + return false + } + if !change.AutoExitNodeSet && !change.ExitNodeIDSet && !change.ExitNodeIPSet { + // The change does not affect exit node usage. + return false + } - // Zeroing the ExitNodeId via localAPI must also zero the prior exit node. - if mp.ExitNodeIDSet && mp.ExitNodeID == "" { + if prefs.ExitNodeID() == "" { + // Exit node usage is already disabled. + // Note that we do not check for ExitNodeIP here. + // If ExitNodeIP hasn't been resolved to a node, + // it's not enabled yet. + return false + } + + // First, apply the adjustments to a copy of the changes, + // e.g., clear AutoExitNode if ExitNodeID is set. + tmpChange := ptr.To(*change) + tmpChange.Prefs = *change.Prefs.Clone() + b.adjustEditPrefsLocked(prefs, tmpChange) + + // Then apply the adjusted changes to a copy of the current prefs, + // and resolve the exit node in the prefs. + tmpPrefs := prefs.AsStruct() + tmpPrefs.ApplyEdits(tmpChange) + b.resolveExitNodeInPrefsLocked(tmpPrefs) + + // If ExitNodeID is empty after applying the changes, + // but wasn't empty before, then the change disables + // exit node usage. + return tmpPrefs.ExitNodeID == "" +} + +// adjustEditPrefsLocked applies additional changes to mp if necessary, +// such as zeroing out mutually exclusive fields. +// +// It must not assume that the changes in mp will actually be applied. +// +// b.mu must be held. +func (b *LocalBackend) adjustEditPrefsLocked(prefs ipn.PrefsView, mp *ipn.MaskedPrefs) { + // Zeroing the ExitNodeID via localAPI must also zero the prior exit node. + if mp.ExitNodeIDSet && mp.ExitNodeID == "" && !mp.InternalExitNodePriorSet { mp.InternalExitNodePrior = "" mp.InternalExitNodePriorSet = true } - unlock := b.lockAndGetUnlock() - defer unlock() - return b.editPrefsLockedOnEntry(mp, unlock) + // Clear ExitNodeID if AutoExitNode is disabled and ExitNodeID is still unresolved. + if mp.AutoExitNodeSet && mp.AutoExitNode == "" && prefs.ExitNodeID() == unresolvedExitNodeID { + mp.ExitNodeIDSet = true + mp.ExitNodeID = "" + } + + // Disable automatic exit node selection if the user explicitly sets + // ExitNodeID or ExitNodeIP. + if (mp.ExitNodeIDSet || mp.ExitNodeIPSet) && !mp.AutoExitNodeSet { + mp.AutoExitNodeSet = true + mp.AutoExitNode = "" + } +} + +// onEditPrefsLocked is called when prefs are edited (typically, via LocalAPI), +// just before the changes in newPrefs are set for the current profile. +// +// The changes in mp have been allowed, but the resulting [ipn.Prefs] +// have not yet been applied and may be subject to reconciliation +// by [LocalBackend.reconcilePrefsLocked], either before or after being set. +// +// This method handles preference edits, typically initiated by the user, +// as opposed to reconfiguring the backend when the final prefs are set. +// +// b.mu must be held; mp must not be mutated by this method. +func (b *LocalBackend) onEditPrefsLocked(_ ipnauth.Actor, mp *ipn.MaskedPrefs, oldPrefs, newPrefs ipn.PrefsView) { + if mp.WantRunningSet && !mp.WantRunning && oldPrefs.WantRunning() { + // If a user has enough rights to disconnect, such as when [pkey.AlwaysOn] + // is disabled, or [pkey.AlwaysOnOverrideWithReason] is also set and the user + // provides a reason for disconnecting, then we should not force the "always on" + // mode on them until the policy changes, they switch to a different profile, etc. + b.overrideAlwaysOn = true + + if reconnectAfter, _ := b.polc.GetDuration(pkey.ReconnectAfter, 0); reconnectAfter > 0 { + b.startReconnectTimerLocked(reconnectAfter) + } + } + + if oldPrefs.WantRunning() != newPrefs.WantRunning() { + // Connecting to or disconnecting from Tailscale clears the override, + // unless the user is also explicitly changing the exit node (see below). + b.overrideExitNodePolicy = false + } + if mp.AutoExitNodeSet || mp.ExitNodeIDSet || mp.ExitNodeIPSet { + if allowExitNodeOverride, _ := b.polc.GetBoolean(pkey.AllowExitNodeOverride, false); allowExitNodeOverride { + // If applying exit node policy settings to the new prefs results in no change, + // the user is not overriding the policy. Otherwise, it is an override. + b.overrideExitNodePolicy = b.applyExitNodeSysPolicyLocked(newPrefs.AsStruct()) + } else { + // Overrides are not allowed; clear the override flag. + b.overrideExitNodePolicy = false + } + } + + // This is recorded here in the EditPrefs path, not the setPrefs path on purpose. + // recordForEdit records metrics related to edits and changes, not the final state. + // If, in the future, we want to record gauge-metrics related to the state of prefs, + // that should be done in the setPrefs path. + e := prefsMetricsEditEvent{ + change: mp, + pNew: newPrefs, + pOld: oldPrefs, + node: b.currentNode(), + lastSuggestedExitNode: b.lastSuggestedExitNode, + } + e.record() +} + +// startReconnectTimerLocked sets a timer to automatically set WantRunning to true +// after the specified duration. +func (b *LocalBackend) startReconnectTimerLocked(d time.Duration) { + if b.reconnectTimer != nil { + // Stop may return false if the timer has already fired, + // and the function has been called in its own goroutine, + // but lost the race to acquire b.mu. In this case, it'll + // end up as a no-op due to a reconnectTimer mismatch + // once it manages to acquire the lock. This is fine, and we + // don't need to check the return value. + b.reconnectTimer.Stop() + } + profileID := b.pm.CurrentProfile().ID() + var reconnectTimer tstime.TimerController + reconnectTimer = b.clock.AfterFunc(d, func() { + b.mu.Lock() + defer b.mu.Unlock() + + if b.reconnectTimer != reconnectTimer { + // We're either not the most recent timer, or we lost the race when + // the timer was stopped. No need to reconnect. + return + } + b.reconnectTimer = nil + + cp := b.pm.CurrentProfile() + if cp.ID() != profileID { + // The timer fired before the profile changed but we lost the race + // and acquired the lock shortly after. + // No need to reconnect. + return + } + + mp := &ipn.MaskedPrefs{WantRunningSet: true, Prefs: ipn.Prefs{WantRunning: true}} + if _, err := b.editPrefsLocked(ipnauth.Self, mp); err != nil { + b.logf("failed to automatically reconnect as %q after %v: %v", cp.Name(), d, err) + } else { + b.logf("automatically reconnected as %q after %v", cp.Name(), d) + } + }) + b.reconnectTimer = reconnectTimer + b.logf("reconnect for %q has been scheduled and will be performed in %v", b.pm.CurrentProfile().Name(), d) +} + +func (b *LocalBackend) resetAlwaysOnOverrideLocked() { + b.overrideAlwaysOn = false + b.stopReconnectTimerLocked() +} + +func (b *LocalBackend) stopReconnectTimerLocked() { + if b.reconnectTimer != nil { + // Stop may return false if the timer has already fired, + // and the function has been called in its own goroutine, + // but lost the race to acquire b.mu. + // In this case, it'll end up as a no-op due to a reconnectTimer + // mismatch (see [LocalBackend.startReconnectTimerLocked]) + // once it manages to acquire the lock. This is fine, and we + // don't need to check the return value. + b.reconnectTimer.Stop() + b.reconnectTimer = nil + } } -// Warning: b.mu must be held on entry, but it unlocks it on the way out. -// TODO(bradfitz): redo the locking on all these weird methods like this. -func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlockOnce) (ipn.PrefsView, error) { - defer unlock() // for error paths +// b.mu must be held. +func (b *LocalBackend) editPrefsLocked(actor ipnauth.Actor, mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { + p0 := b.pm.CurrentPrefs() + + // Check if the changes in mp are allowed. + if err := b.checkEditPrefsAccessLocked(actor, p0, mp); err != nil { + b.logf("EditPrefs(%v): %v", mp.Pretty(), err) + return ipn.PrefsView{}, err + } + + // Apply additional changes to mp if necessary, + // such as clearing mutually exclusive fields. + b.adjustEditPrefsLocked(p0, mp) if mp.EggSet { mp.EggSet = false b.egg = true - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } - p0 := b.pm.CurrentPrefs() + p1 := b.pm.CurrentPrefs().AsStruct() p1.ApplyEdits(mp) + if err := b.checkPrefsLocked(p1); err != nil { b.logf("EditPrefs check error: %v", err) return ipn.PrefsView{}, err } - if p1.RunSSH && !envknob.CanSSHD() { - b.logf("EditPrefs requests SSH, but disabled by envknob; returning error") - return ipn.PrefsView{}, errors.New("Tailscale SSH server administratively disabled.") - } + if p1.View().Equals(p0) { return stripKeysFromPrefs(p0), nil } + b.logf("EditPrefs: %v", mp.Pretty()) - newPrefs := b.setPrefsLockedOnEntry(p1, unlock) + + // Perform any actions required when prefs are edited (typically by a user), + // before the modified prefs are actually set for the current profile. + b.onEditPrefsLocked(actor, mp, p0, p1.View()) + + newPrefs := b.setPrefsLocked(p1) // Note: don't perform any actions for the new prefs here. Not // every prefs change goes through EditPrefs. Put your actions - // in setPrefsLocksOnEntry instead. + // in setPrefsLocked instead. // This should return the public prefs, not the private ones. return stripKeysFromPrefs(newPrefs), nil @@ -3697,46 +4595,31 @@ func (b *LocalBackend) checkProfileNameLocked(p *ipn.Prefs) error { // No profile with that name exists. That's fine. return nil } - if id != b.pm.CurrentProfile().ID { + if id != b.pm.CurrentProfile().ID() { // Name is already in use by another profile. return fmt.Errorf("profile name %q already in use", p.ProfileName) } return nil } -// wantIngressLocked reports whether this node has ingress configured. This bool -// is sent to the coordination server (in Hostinfo.WireIngress) as an -// optimization hint to know primarily which nodes are NOT using ingress, to -// avoid doing work for regular nodes. -// -// Even if the user's ServeConfig.AllowFunnel map was manually edited in raw -// mode and contains map entries with false values, sending true (from Len > 0) -// is still fine. This is only an optimization hint for the control plane and -// doesn't affect security or correctness. And we also don't expect people to -// modify their ServeConfig in raw mode. -func (b *LocalBackend) wantIngressLocked() bool { - return b.serveConfig.Valid() && b.serveConfig.HasAllowFunnel() -} - -// setPrefsLockedOnEntry requires b.mu be held to call it, but it -// unlocks b.mu when done. newp ownership passes to this function. -// It returns a readonly copy of the new prefs. -func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) ipn.PrefsView { - defer unlock() - - netMap := b.netMap +// setPrefsLocked requires b.mu be held to call it. +// newp ownership passes to this function. +// It returns a read-only copy of the new prefs. +func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { + cn := b.currentNode() + netMap := cn.NetMap() b.setAtomicValuesFromPrefsLocked(newp.View()) oldp := b.pm.CurrentPrefs() if oldp.Valid() { newp.Persist = oldp.Persist().AsStruct() // caller isn't allowed to override this } - // setExitNodeID returns whether it updated b.prefs, but - // everything in this function treats b.prefs as completely new - // anyway. No-op if no exit node resolution is needed. - setExitNodeID(newp, netMap, b.lastSuggestedExitNode) - // applySysPolicy does likewise so we can also ignore its return value. - applySysPolicy(newp) + // Apply reconciliation to the prefs, such as policy overrides, + // exit node resolution, and so on. The call returns whether it updated + // newp, but everything in this function treats newp as completely new + // anyway, so its return value can be ignored here. + b.reconcilePrefsLocked(newp) + // We do this to avoid holding the lock while doing everything else. oldHi := b.hostinfo @@ -3749,16 +4632,16 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) hostInfoChanged := !oldHi.Equal(newHi) cc := b.cc - b.updateFilterLocked(netMap, newp.View()) + b.updateFilterLocked(newp.View()) - if oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() { + if buildfeatures.HasSSH && oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() { if b.sshServer != nil { - go b.sshServer.Shutdown() + b.goTracker.Go(b.sshServer.Shutdown) b.sshServer = nil } } if netMap != nil { - newProfile := netMap.UserProfiles[netMap.User()] + newProfile := profileFromView(netMap.UserProfiles[netMap.User()]) if newLoginName := newProfile.LoginName; newLoginName != "" { if !oldp.Persist().Valid() { b.logf("active login: %s", newLoginName) @@ -3773,49 +4656,48 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) } prefs := newp.View() - if err := b.pm.SetPrefs(prefs, ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { + np := cmp.Or(cn.NetworkProfile(), b.pm.CurrentProfile().NetworkProfile()) + if err := b.pm.SetPrefs(prefs, np); err != nil { b.logf("failed to save new controlclient state: %v", err) + } else if prefs.WantRunning() { + // Reset the always-on override if WantRunning is true in the new prefs, + // such as when the user toggles the Connected switch in the GUI + // or runs `tailscale up`. + b.resetAlwaysOnOverrideLocked() } - if newp.AutoUpdate.Apply.EqualBool(true) { - if b.state != ipn.Running { - b.maybeStartOfflineAutoUpdate(newp.View()) - } - } else { - b.stopOfflineAutoUpdate() - } - - unlock.UnlockEarly() + b.pauseOrResumeControlClientLocked() // for prefs.Sync changes + b.updateWarnSync(prefs) if oldp.ShieldsUp() != newp.ShieldsUp || hostInfoChanged { - b.doSetHostinfoFilterServices() + b.doSetHostinfoFilterServicesLocked() } if netMap != nil { b.MagicConn().SetDERPMap(netMap.DERPMap) } - if !oldp.WantRunning() && newp.WantRunning { + if !oldp.WantRunning() && newp.WantRunning && cc != nil { b.logf("transitioning to running; doing Login...") cc.Login(controlclient.LoginDefault) } if oldp.WantRunning() != newp.WantRunning { - b.stateMachine() + b.stateMachineLocked() } else { - b.authReconfig() + b.authReconfigLocked() } - b.send(ipn.Notify{Prefs: &prefs}) + b.sendLocked(ipn.Notify{Prefs: &prefs}) return prefs } // GetPeerAPIPort returns the port number for the peerapi server // running on the provided IP. func (b *LocalBackend) GetPeerAPIPort(ip netip.Addr) (port uint16, ok bool) { + if !buildfeatures.HasPeerAPIServer { + return 0, false + } b.mu.Lock() defer b.mu.Unlock() for _, pln := range b.peerAPIListeners { @@ -3857,57 +4739,15 @@ var ( magicDNSIPv6 = tsaddr.TailscaleServiceIPv6() ) -// TCPHandlerForDst returns a TCP handler for connections to dst, or nil if -// no handler is needed. It also returns a list of TCP socket options to -// apply to the socket before calling the handler. -// TCPHandlerForDst is called both for connections to our node's local IP -// as well as to the service IP (quad 100). -func (b *LocalBackend) TCPHandlerForDst(src, dst netip.AddrPort) (handler func(c net.Conn) error, opts []tcpip.SettableSocketOption) { - // First handle internal connections to the service IP - hittingServiceIP := dst.Addr() == magicDNSIP || dst.Addr() == magicDNSIPv6 - if hittingServiceIP { - switch dst.Port() { - case 80: - // TODO(mpminardi): do we want to show an error message if the web client - // has been disabled instead of the more "basic" web UI? - if b.ShouldRunWebClient() { - return b.handleWebClientConn, opts - } - return b.HandleQuad100Port80Conn, opts - case DriveLocalPort: - return b.handleDriveConn, opts - } - } +// Hook exclusively for serve. +var ( + hookServeTCPHandlerForVIPService feature.Hook[func(b *LocalBackend, dst netip.AddrPort, src netip.AddrPort) (handler func(c net.Conn) error)] + hookTCPHandlerForServe feature.Hook[func(b *LocalBackend, dport uint16, srcAddr netip.AddrPort, f *funnelFlow) (handler func(net.Conn) error)] + hookServeUpdateServeTCPPortNetMapAddrListenersLocked feature.Hook[func(b *LocalBackend, ports []uint16)] - // Then handle external connections to the local IP. - if !b.isLocalIP(dst.Addr()) { - return nil, nil - } - if dst.Port() == 22 && b.ShouldRunSSH() { - // Use a higher keepalive idle time for SSH connections, as they are - // typically long lived and idle connections are more likely to be - // intentional. Ideally we would turn this off entirely, but we can't - // tell the difference between a long lived connection that is idle - // vs a connection that is dead because the peer has gone away. - // We pick 72h as that is typically sufficient for a long weekend. - opts = append(opts, ptr.To(tcpip.KeepaliveIdleOption(72*time.Hour))) - return b.handleSSHConn, opts - } - // TODO(will,sonia): allow customizing web client port ? - if dst.Port() == webClientPort && b.ShouldExposeRemoteWebClient() { - return b.handleWebClientConn, opts - } - if port, ok := b.GetPeerAPIPort(dst.Addr()); ok && dst.Port() == port { - return func(c net.Conn) error { - b.handlePeerAPIConn(src, dst, c) - return nil - }, opts - } - if handler := b.tcpHandlerForServe(dst.Port(), src, nil); handler != nil { - return handler, opts - } - return nil, nil -} + hookServeSetTCPPortsInterceptedFromNetmapAndPrefsLocked feature.Hook[func(b *LocalBackend, prefs ipn.PrefsView) (handlePorts []uint16)] + hookServeClearVIPServicesTCPPortsInterceptedLocked feature.Hook[func(*LocalBackend)] +) func (b *LocalBackend) handleDriveConn(conn net.Conn) error { fs, ok := b.sys.DriveForLocal.GetOK() @@ -3930,7 +4770,7 @@ func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) { }) } switch runtime.GOOS { - case "linux", "freebsd", "openbsd", "illumos", "darwin", "windows", "android", "ios": + case "linux", "freebsd", "openbsd", "illumos", "solaris", "darwin", "windows", "android", "ios": // These are the platforms currently supported by // net/dns/resolver/tsdns.go:Resolver.HandleExitNodeDNSQuery. ret = append(ret, tailcfg.Service{ @@ -3941,15 +4781,38 @@ func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) { return ret } +// PortlistServices is an eventbus topic for the portlist extension +// to advertise the running services on the host. +type PortlistServices []tailcfg.Service + +func (b *LocalBackend) setPortlistServices(sl []tailcfg.Service) { + if !buildfeatures.HasPortList { // redundant, but explicit for linker deadcode and humans + return + } + + b.mu.Lock() + if b.hostinfo == nil { + b.hostinfo = new(tailcfg.Hostinfo) + } + b.hostinfo.Services = sl + b.mu.Unlock() + + b.doSetHostinfoFilterServices() +} + // doSetHostinfoFilterServices calls SetHostinfo on the controlclient, // possibly after mangling the given hostinfo. // // TODO(danderson): we shouldn't be mangling hostinfo here after // painstakingly constructing it in twelvety other places. func (b *LocalBackend) doSetHostinfoFilterServices() { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() + b.doSetHostinfoFilterServicesLocked() +} +// b.mu must be held +func (b *LocalBackend) doSetHostinfoFilterServicesLocked() { cc := b.cc if cc == nil { // Control client isn't up yet. @@ -3966,27 +4829,50 @@ func (b *LocalBackend) doSetHostinfoFilterServices() { // TODO(maisem,bradfitz): store hostinfo as a view, not as a mutable struct. hi := *b.hostinfo // shallow copy - unlock.UnlockEarly() // Make a shallow copy of hostinfo so we can mutate // at the Service field. - if !b.shouldUploadServices() { + if f, ok := b.extHost.Hooks().ShouldUploadServices.GetOk(); !ok || !f() { hi.Services = []tailcfg.Service{} } + // Don't mutate hi.Service's underlying array. Append to // the slice with no free capacity. c := len(hi.Services) hi.Services = append(hi.Services[:c:c], peerAPIServices...) hi.PushDeviceToken = b.pushDeviceToken.Load() + + // Compare the expected ports from peerAPIServices to the actual ports in hi.Services. + expectedPorts := extractPeerAPIPorts(peerAPIServices) + actualPorts := extractPeerAPIPorts(hi.Services) + if expectedPorts != actualPorts { + b.logf("Hostinfo peerAPI ports changed: expected %v, got %v", expectedPorts, actualPorts) + } + cc.SetHostinfo(&hi) } +type portPair struct { + v4, v6 uint16 +} + +func extractPeerAPIPorts(services []tailcfg.Service) portPair { + var p portPair + for _, s := range services { + switch s.Proto { + case "peerapi4": + p.v4 = s.Port + case "peerapi6": + p.v6 = s.Port + } + } + return p +} + // NetMap returns the latest cached network map received from // controlclient, or nil if no network map was received yet. func (b *LocalBackend) NetMap() *netmap.NetworkMap { - b.mu.Lock() - defer b.mu.Unlock() - return b.netMap + return b.currentNode().NetMap() } func (b *LocalBackend) isEngineBlocked() bool { @@ -3995,21 +4881,24 @@ func (b *LocalBackend) isEngineBlocked() bool { return b.blocked } -// blockEngineUpdate sets b.blocked to block, while holding b.mu. Its -// indirect effect is to turn b.authReconfig() into a no-op if block -// is true. -func (b *LocalBackend) blockEngineUpdates(block bool) { +// blockEngineUpdatesLocked sets b.blocked to block. +// +// Its indirect effect is to turn b.authReconfig() into a no-op if block is +// true. +// +// b.mu must be held. +func (b *LocalBackend) blockEngineUpdatesLocked(block bool) { b.logf("blockEngineUpdates(%v)", block) - - b.mu.Lock() b.blocked = block - b.mu.Unlock() } // reconfigAppConnectorLocked updates the app connector state based on the // current network map and preferences. // b.mu must be held. func (b *LocalBackend) reconfigAppConnectorLocked(nm *netmap.NetworkMap, prefs ipn.PrefsView) { + if !buildfeatures.HasAppConnectors { + return + } const appConnectorCapName = "tailscale.com/app-connectors" defer func() { if b.hostinfo != nil { @@ -4017,36 +4906,34 @@ func (b *LocalBackend) reconfigAppConnectorLocked(nm *netmap.NetworkMap, prefs i } }() + // App connectors have been disabled. if !prefs.AppConnector().Advertise { + b.appConnector.Close() // clean up a previous connector (safe on nil) b.appConnector = nil return } - shouldAppCStoreRoutes := b.ControlKnobs().AppCStoreRoutes.Load() - if b.appConnector == nil || b.appConnector.ShouldStoreRoutes() != shouldAppCStoreRoutes { - var ri *appc.RouteInfo - var storeFunc func(*appc.RouteInfo) error - if shouldAppCStoreRoutes { - var err error - ri, err = b.readRouteInfoLocked() - if err != nil { - ri = &appc.RouteInfo{} - if err != ipn.ErrStateNotExist { - b.logf("Unsuccessful Read RouteInfo: ", err) - } - } - storeFunc = b.storeRouteInfo - } - b.appConnector = appc.NewAppConnector(b.logf, b, ri, storeFunc) + // We don't (yet) have an app connector configured, or the configured + // connector has a different route persistence setting. + shouldStoreRoutes := b.ControlKnobs().AppCStoreRoutes.Load() + if b.appConnector == nil || (shouldStoreRoutes != b.appConnector.ShouldStoreRoutes()) { + ri, err := b.readRouteInfoLocked() + if err != nil && err != ipn.ErrStateNotExist { + b.logf("Unsuccessful Read RouteInfo: %v", err) + } + b.appConnector.Close() // clean up a previous connector (safe on nil) + b.appConnector = appc.NewAppConnector(appc.Config{ + Logf: b.logf, + EventBus: b.sys.Bus.Get(), + RouteInfo: ri, + HasStoredRoutes: shouldStoreRoutes, + }) } if nm == nil { return } - // TODO(raggi): rework the view infrastructure so the large deep clone is no - // longer required - sn := nm.SelfNode.AsStruct() - attrs, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorAttr](sn.CapMap, appConnectorCapName) + attrs, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorAttr](nm.SelfNode.CapMap(), appConnectorCapName) if err != nil { b.logf("[unexpected] error parsing app connector mapcap: %v", err) return @@ -4076,31 +4963,78 @@ func (b *LocalBackend) reconfigAppConnectorLocked(nm *netmap.NetworkMap, prefs i b.appConnector.UpdateDomainsAndRoutes(domains, routes) } +func (b *LocalBackend) readvertiseAppConnectorRoutes() { + // Note: we should never call b.appConnector methods while holding b.mu. + // This can lead to a deadlock, like + // https://github.com/tailscale/corp/issues/25965. + // + // Grab a copy of the field, since b.mu only guards access to the + // b.appConnector field itself. + appConnector := b.AppConnector() + + if appConnector == nil { + return + } + domainRoutes := appConnector.DomainRoutes() + if domainRoutes == nil { + return + } + + // Re-advertise the stored routes, in case stored state got out of + // sync with previously advertised routes in prefs. + var prefixes []netip.Prefix + for _, ips := range domainRoutes { + for _, ip := range ips { + prefixes = append(prefixes, netip.PrefixFrom(ip, ip.BitLen())) + } + } + // Note: AdvertiseRoute will trim routes that are already + // advertised, so if everything is already being advertised this is + // a noop. + if err := b.AdvertiseRoute(prefixes...); err != nil { + b.logf("error advertising stored app connector routes: %v", err) + } +} + // authReconfig pushes a new configuration into wgengine, if engine // updates are not currently blocked, based on the cached netmap and // user prefs. func (b *LocalBackend) authReconfig() { b.mu.Lock() - blocked := b.blocked - prefs := b.pm.CurrentPrefs() - nm := b.netMap - hasPAC := b.prevIfState.HasPAC() - disableSubnetsIfPAC := nm.HasCap(tailcfg.NodeAttrDisableSubnetsIfPAC) - userDialUseRoutes := nm.HasCap(tailcfg.NodeAttrUserDialUseRoutes) - dohURL, dohURLOK := exitNodeCanProxyDNS(nm, b.peers, prefs.ExitNodeID()) - dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS()) - // If the current node is an app connector, ensure the app connector machine is started - b.reconfigAppConnectorLocked(nm, prefs) - b.mu.Unlock() + defer b.mu.Unlock() + b.authReconfigLocked() +} + +// authReconfigLocked is the locked version of [LocalBackend.authReconfig]. +// +// b.mu must be held. +func (b *LocalBackend) authReconfigLocked() { - if blocked { + if b.shutdownCalled { + b.logf("[v1] authReconfig: skipping because in shutdown") + return + } + if b.blocked { b.logf("[v1] authReconfig: blocked, skipping.") return } + + cn := b.currentNode() + + nm := cn.NetMap() if nm == nil { b.logf("[v1] authReconfig: netmap not yet valid. Skipping.") return } + + prefs := b.pm.CurrentPrefs() + hasPAC := b.prevIfState.HasPAC() + disableSubnetsIfPAC := cn.SelfHasCap(tailcfg.NodeAttrDisableSubnetsIfPAC) + dohURL, dohURLOK := cn.exitNodeCanProxyDNS(prefs.ExitNodeID()) + dcfg := cn.dnsConfigForNetmap(prefs, b.keyExpired, version.OS()) + // If the current node is an app connector, ensure the app connector machine is started + b.reconfigAppConnectorLocked(nm, prefs) + if !prefs.WantRunning() { b.logf("[v1] authReconfig: skipping because !WantRunning.") return @@ -4120,20 +5054,27 @@ func (b *LocalBackend) authReconfig() { // Keep the dialer updated about whether we're supposed to use // an exit node's DNS server (so SOCKS5/HTTP outgoing dials // can use it for name resolution) - if dohURLOK { - b.dialer.SetExitDNSDoH(dohURL) - } else { - b.dialer.SetExitDNSDoH("") + if buildfeatures.HasUseExitNode { + if dohURLOK { + b.dialer.SetExitDNSDoH(dohURL) + } else { + b.dialer.SetExitDNSDoH("") + } + } + + priv := b.pm.CurrentPrefs().Persist().PrivateNodeKey() + if !priv.IsZero() && priv.Public() != nm.NodeKey { + priv = key.NodePrivate{} } - cfg, err := nmcfg.WGCfg(nm, b.logf, flags, prefs.ExitNodeID()) + cfg, err := nmcfg.WGCfg(priv, nm, b.logf, flags, prefs.ExitNodeID()) if err != nil { b.logf("wgcfg: %v", err) return } - oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.ControlKnobs(), version.OS()) - rcfg := b.routerConfig(cfg, prefs, oneCGNATRoute) + oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.NetMon.Get(), b.sys.ControlKnobs(), version.OS()) + rcfg := b.routerConfigLocked(cfg, prefs, oneCGNATRoute) err = b.e.Reconfig(cfg, rcfg, dcfg) if err == wgengine.ErrNoChanges { @@ -4141,13 +5082,10 @@ func (b *LocalBackend) authReconfig() { } b.logf("[v1] authReconfig: ra=%v dns=%v 0x%02x: %v", prefs.RouteAll(), prefs.CorpDNS(), flags, err) - if userDialUseRoutes { - b.dialer.SetRoutes(rcfg.Routes, rcfg.LocalRoutes) - } else { - b.dialer.SetRoutes(nil, nil) + b.initPeerAPIListenerLocked() + if buildfeatures.HasAppConnectors { + go b.goTracker.Go(b.readvertiseAppConnectorRoutes) } - - b.initPeerAPIListener() } // shouldUseOneCGNATRoute reports whether we should prefer to make one big @@ -4155,7 +5093,7 @@ func (b *LocalBackend) authReconfig() { // // The versionOS is a Tailscale-style version ("iOS", "macOS") and not // a runtime.GOOS. -func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, versionOS string) bool { +func shouldUseOneCGNATRoute(logf logger.Logf, mon *netmon.Monitor, controlKnobs *controlknobs.Knobs, versionOS string) bool { if controlKnobs != nil { // Explicit enabling or disabling always take precedence. if v, ok := controlKnobs.OneCGNAT.Load().Get(); ok { @@ -4164,13 +5102,18 @@ func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, } } + if versionOS == "plan9" { + // Just temporarily during plan9 bringup to have fewer routes to debug. + return true + } + // Also prefer to do this on the Mac, so that we don't need to constantly // update the network extension configuration (which is disruptive to // Chrome, see https://github.com/tailscale/tailscale/issues/3102). Only // use fine-grained routes if another interfaces is also using the CGNAT // IP range. if versionOS == "macOS" { - hasCGNATInterface, err := netmon.HasCGNATInterface() + hasCGNATInterface, err := mon.HasCGNATInterface() if err != nil { logf("shouldUseOneCGNATRoute: Could not determine if any interfaces use CGNAT: %v", err) return false @@ -4183,193 +5126,6 @@ func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, return false } -// dnsConfigForNetmap returns a *dns.Config for the given netmap, -// prefs, client OS version, and cloud hosting environment. -// -// The versionOS is a Tailscale-style version ("iOS", "macOS") and not -// a runtime.GOOS. -func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config { - if nm == nil { - return nil - } - - // If the current node's key is expired, then we don't program any DNS - // configuration into the operating system. This ensures that if the - // DNS configuration specifies a DNS server that is only reachable over - // Tailscale, we don't break connectivity for the user. - // - // TODO(andrew-d): this also stops returning anything from quad-100; we - // could do the same thing as having "CorpDNS: false" and keep that but - // not program the OS? - if selfExpired { - return &dns.Config{} - } - - dcfg := &dns.Config{ - Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, - Hosts: map[dnsname.FQDN][]netip.Addr{}, - } - - // selfV6Only is whether we only have IPv6 addresses ourselves. - selfV6Only := nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs6) && - !nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs4) - dcfg.OnlyIPv6 = selfV6Only - - // Populate MagicDNS records. We do this unconditionally so that - // quad-100 can always respond to MagicDNS queries, even if the OS - // isn't configured to make MagicDNS resolution truly - // magic. Details in - // https://github.com/tailscale/tailscale/issues/1886. - set := func(name string, addrs views.Slice[netip.Prefix]) { - if addrs.Len() == 0 || name == "" { - return - } - fqdn, err := dnsname.ToFQDN(name) - if err != nil { - return // TODO: propagate error? - } - var have4 bool - for _, addr := range addrs.All() { - if addr.Addr().Is4() { - have4 = true - break - } - } - var ips []netip.Addr - for _, addr := range addrs.All() { - if selfV6Only { - if addr.Addr().Is6() { - ips = append(ips, addr.Addr()) - } - continue - } - // If this node has an IPv4 address, then - // remove peers' IPv6 addresses for now, as we - // don't guarantee that the peer node actually - // can speak IPv6 correctly. - // - // https://github.com/tailscale/tailscale/issues/1152 - // tracks adding the right capability reporting to - // enable AAAA in MagicDNS. - if addr.Addr().Is6() && have4 { - continue - } - ips = append(ips, addr.Addr()) - } - dcfg.Hosts[fqdn] = ips - } - set(nm.Name, nm.GetAddresses()) - for _, peer := range peers { - set(peer.Name(), peer.Addresses()) - } - for _, rec := range nm.DNS.ExtraRecords { - switch rec.Type { - case "", "A", "AAAA": - // Treat these all the same for now: infer from the value - default: - // TODO: more - continue - } - ip, err := netip.ParseAddr(rec.Value) - if err != nil { - // Ignore. - continue - } - fqdn, err := dnsname.ToFQDN(rec.Name) - if err != nil { - continue - } - dcfg.Hosts[fqdn] = append(dcfg.Hosts[fqdn], ip) - } - - if !prefs.CorpDNS() { - return dcfg - } - - for _, dom := range nm.DNS.Domains { - fqdn, err := dnsname.ToFQDN(dom) - if err != nil { - logf("[unexpected] non-FQDN search domain %q", dom) - } - dcfg.SearchDomains = append(dcfg.SearchDomains, fqdn) - } - if nm.DNS.Proxied { // actually means "enable MagicDNS" - for _, dom := range magicDNSRootDomains(nm) { - dcfg.Routes[dom] = nil // resolve internally with dcfg.Hosts - } - } - - addDefault := func(resolvers []*dnstype.Resolver) { - dcfg.DefaultResolvers = append(dcfg.DefaultResolvers, resolvers...) - } - - // If we're using an exit node and that exit node is new enough (1.19.x+) - // to run a DoH DNS proxy, then send all our DNS traffic through it. - if dohURL, ok := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID()); ok { - addDefault([]*dnstype.Resolver{{Addr: dohURL}}) - return dcfg - } - - // If the user has set default resolvers ("override local DNS"), prefer to - // use those resolvers as the default, otherwise if there are WireGuard exit - // node resolvers, use those as the default. - if len(nm.DNS.Resolvers) > 0 { - addDefault(nm.DNS.Resolvers) - } else { - if resolvers, ok := wireguardExitNodeDNSResolvers(nm, peers, prefs.ExitNodeID()); ok { - addDefault(resolvers) - } - } - - for suffix, resolvers := range nm.DNS.Routes { - fqdn, err := dnsname.ToFQDN(suffix) - if err != nil { - logf("[unexpected] non-FQDN route suffix %q", suffix) - } - - // Create map entry even if len(resolvers) == 0; Issue 2706. - // This lets the control plane send ExtraRecords for which we - // can authoritatively answer "name not exists" for when the - // control plane also sends this explicit but empty route - // making it as something we handle. - // - // While we're already populating it, might as well size the - // slice appropriately. - // Per #9498 the exact requirements of nil vs empty slice remain - // unclear, this is a haunted graveyard to be resolved. - dcfg.Routes[fqdn] = make([]*dnstype.Resolver, 0, len(resolvers)) - dcfg.Routes[fqdn] = append(dcfg.Routes[fqdn], resolvers...) - } - - // Set FallbackResolvers as the default resolvers in the - // scenarios that can't handle a purely split-DNS config. See - // https://github.com/tailscale/tailscale/issues/1743 for - // details. - switch { - case len(dcfg.DefaultResolvers) != 0: - // Default resolvers already set. - case !prefs.ExitNodeID().IsZero(): - // When using an exit node, we send all DNS traffic to the exit node, so - // we don't need a fallback resolver. - // - // However, if the exit node is too old to run a DoH DNS proxy, then we - // need to use a fallback resolver as it's very likely the LAN resolvers - // will become unreachable. - // - // This is especially important on Apple OSes, where - // adding the default route to the tunnel interface makes - // it "primary", and we MUST provide VPN-sourced DNS - // settings or we break all DNS resolution. - // - // https://github.com/tailscale/tailscale/issues/1713 - addDefault(nm.DNS.FallbackResolvers) - case len(dcfg.Routes) == 0: - // No settings requiring split DNS, no problem. - } - - return dcfg -} - // SetTCPHandlerForFunnelFlow sets the TCP handler for Funnel flows. // It should only be called before the LocalBackend is used. func (b *LocalBackend) SetTCPHandlerForFunnelFlow(h func(src netip.AddrPort, dstPort uint16) (handler func(net.Conn))) { @@ -4388,6 +5144,9 @@ func (b *LocalBackend) SetVarRoot(dir string) { // // It should only be called before the LocalBackend is used. func (b *LocalBackend) SetLogFlusher(flushFunc func()) { + if !buildfeatures.HasLogTail { + return + } b.logFlushFunc = flushFunc } @@ -4396,7 +5155,7 @@ func (b *LocalBackend) SetLogFlusher(flushFunc func()) { // // TryFlushLogs should not block. func (b *LocalBackend) TryFlushLogs() bool { - if b.logFlushFunc == nil { + if !buildfeatures.HasLogTail || b.logFlushFunc == nil { return false } b.logFlushFunc() @@ -4423,26 +5182,6 @@ func (b *LocalBackend) TailscaleVarRoot() string { return "" } -func (b *LocalBackend) fileRootLocked(uid tailcfg.UserID) string { - if v := b.directFileRoot; v != "" { - return v - } - varRoot := b.TailscaleVarRoot() - if varRoot == "" { - b.logf("Taildrop disabled; no state directory") - return "" - } - baseDir := fmt.Sprintf("%s-uid-%d", - strings.ReplaceAll(b.activeLogin, "@", "-"), - uid) - dir := filepath.Join(varRoot, "files", baseDir) - if err := os.MkdirAll(dir, 0700); err != nil { - b.logf("Taildrop disabled; error making directory: %v", err) - return "" - } - return dir -} - // closePeerAPIListenersLocked closes any existing PeerAPI listeners // and clears out the PeerAPI server state. // @@ -4450,6 +5189,9 @@ func (b *LocalBackend) fileRootLocked(uid tailcfg.UserID) string { // // b.mu must be held. func (b *LocalBackend) closePeerAPIListenersLocked() { + if !buildfeatures.HasPeerAPIServer { + return + } b.peerAPIServer = nil for _, pln := range b.peerAPIListeners { pln.Close() @@ -4467,20 +5209,34 @@ const peerAPIListenAsync = runtime.GOOS == "windows" || runtime.GOOS == "android func (b *LocalBackend) initPeerAPIListener() { b.mu.Lock() defer b.mu.Unlock() + b.initPeerAPIListenerLocked() +} + +// b.mu must be held. +func (b *LocalBackend) initPeerAPIListenerLocked() { + if !buildfeatures.HasPeerAPIServer { + return + } + b.logf("[v1] initPeerAPIListener: entered") + if b.shutdownCalled { + b.logf("[v1] initPeerAPIListener: shutting down") return } - if b.netMap == nil { + cn := b.currentNode() + nm := cn.NetMap() + if nm == nil { // We're called from authReconfig which checks that // netMap is non-nil, but if a concurrent Logout, // ResetForClientDisconnect, or Start happens when its // mutex was released, the netMap could be // nil'ed out (Issue 1996). Bail out early here if so. + b.logf("[v1] initPeerAPIListener: no netmap") return } - addrs := b.netMap.GetAddresses() + addrs := nm.GetAddresses() if addrs.Len() == len(b.peerAPIListeners) { allSame := true for i, pln := range b.peerAPIListeners { @@ -4491,32 +5247,21 @@ func (b *LocalBackend) initPeerAPIListener() { } if allSame { // Nothing to do. + b.logf("[v1] initPeerAPIListener: %d netmap addresses match existing listeners", addrs.Len()) return } } b.closePeerAPIListenersLocked() - selfNode := b.netMap.SelfNode - if !selfNode.Valid() || b.netMap.GetAddresses().Len() == 0 { + selfNode := nm.SelfNode + if !selfNode.Valid() || nm.GetAddresses().Len() == 0 { + b.logf("[v1] initPeerAPIListener: no addresses in netmap") return } - fileRoot := b.fileRootLocked(selfNode.User()) - if fileRoot == "" { - b.logf("peerapi starting without Taildrop directory configured") - } - ps := &peerAPIServer{ b: b, - taildrop: taildrop.ManagerOptions{ - Logf: b.logf, - Clock: tstime.DefaultClock{Clock: b.clock}, - State: b.store, - Dir: fileRoot, - DirectFileMode: b.directFileRoot != "", - SendFileNotify: b.sendFileNotify, - }.New(), } if dm, ok := b.sys.DNSManager.GetOK(); ok { ps.resolver = dm.Resolver() @@ -4532,6 +5277,7 @@ func (b *LocalBackend) initPeerAPIListener() { ln, err = ps.listen(a.Addr(), b.prevIfState) if err != nil { if peerAPIListenAsync { + b.logf("[v1] possibly transient peerapi listen(%q) error, will try again on linkChange: %v", a.Addr(), err) // Expected. But we fix it later in linkChange // ("peerAPIListeners too low"). continue @@ -4557,7 +5303,7 @@ func (b *LocalBackend) initPeerAPIListener() { b.peerAPIListeners = append(b.peerAPIListeners, pln) } - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } // magicDNSRootDomains returns the subset of nm.DNS.Domains that are the search domains for MagicDNS. @@ -4635,15 +5381,15 @@ func peerRoutes(logf logger.Logf, peers []wgcfg.Peer, cgnatThreshold int) (route } // routerConfig produces a router.Config from a wireguard config and IPN prefs. -func (b *LocalBackend) routerConfig(cfg *wgcfg.Config, prefs ipn.PrefsView, oneCGNATRoute bool) *router.Config { +// +// b.mu must be held. +func (b *LocalBackend) routerConfigLocked(cfg *wgcfg.Config, prefs ipn.PrefsView, oneCGNATRoute bool) *router.Config { singleRouteThreshold := 10_000 if oneCGNATRoute { singleRouteThreshold = 1 } - b.mu.Lock() - netfilterKind := b.capForcedNetfilter // protected by b.mu - b.mu.Unlock() + netfilterKind := b.capForcedNetfilter // protected by b.mu (hence the Locked suffix) if prefs.NetfilterKind() != "" { if netfilterKind != "" { @@ -4670,7 +5416,7 @@ func (b *LocalBackend) routerConfig(cfg *wgcfg.Config, prefs ipn.PrefsView, oneC NetfilterKind: netfilterKind, } - if distro.Get() == distro.Synology { + if buildfeatures.HasSynology && distro.Get() == distro.Synology { // Issue 1995: we don't use iptables on Synology. rs.NetfilterMode = preftype.NetfilterOff } @@ -4681,7 +5427,7 @@ func (b *LocalBackend) routerConfig(cfg *wgcfg.Config, prefs ipn.PrefsView, oneC // likely to break some functionality, but if the user expressed a // preference for routing remotely, we want to avoid leaking // traffic at the expense of functionality. - if prefs.ExitNodeID() != "" || prefs.ExitNodeIP().IsValid() { + if buildfeatures.HasUseExitNode && (prefs.ExitNodeID() != "" || prefs.ExitNodeIP().IsValid()) { var default4, default6 bool for _, route := range rs.Routes { switch route { @@ -4753,12 +5499,14 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip hi.RoutableIPs = prefs.AdvertiseRoutes().AsSlice() hi.RequestTags = prefs.AdvertiseTags().AsSlice() hi.ShieldsUp = prefs.ShieldsUp() - hi.AllowsUpdate = envknob.AllowsRemoteUpdate() || prefs.AutoUpdate().Apply.EqualBool(true) + hi.AllowsUpdate = buildfeatures.HasClientUpdate && (envknob.AllowsRemoteUpdate() || prefs.AutoUpdate().Apply.EqualBool(true)) - b.metrics.advertisedRoutes.Set(float64(tsaddr.WithoutExitRoute(prefs.AdvertiseRoutes()).Len())) + if buildfeatures.HasAdvertiseRoutes { + b.metrics.advertisedRoutes.Set(float64(tsaddr.WithoutExitRoute(prefs.AdvertiseRoutes()).Len())) + } var sshHostKeys []string - if prefs.RunSSH() && envknob.CanSSHD() { + if buildfeatures.HasSSH && prefs.RunSSH() && envknob.CanSSHD() { // TODO(bradfitz): this is called with b.mu held. Not ideal. // If the filesystem gets wedged or something we could block for // a long time. But probably fine. @@ -4770,62 +5518,86 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip } hi.SSH_HostKeys = sshHostKeys - // The Hostinfo.WantIngress field tells control whether this node wants to - // be wired up for ingress connections. If harmless if it's accidentally - // true; the actual policy is controlled in tailscaled by ServeConfig. But - // if this is accidentally false, then control may not configure DNS - // properly. This exists as an optimization to control to program fewer DNS - // records that have ingress enabled but are not actually being used. - hi.WireIngress = b.wantIngressLocked() - hi.AppConnector.Set(prefs.AppConnector().Advertise) + for _, f := range hookMaybeMutateHostinfoLocked { + f(b, hi, prefs) + } + + if buildfeatures.HasAppConnectors { + hi.AppConnector.Set(prefs.AppConnector().Advertise) + } + + // The [tailcfg.Hostinfo.ExitNodeID] field tells control which exit node + // was selected, if any. + // + // If auto exit node is enabled (via [ipn.Prefs.AutoExitNode] or + // [pkey.ExitNodeID]), or an exit node is specified by ExitNodeIP + // instead of ExitNodeID , and we don't yet have enough info to resolve + // it (usually due to missing netmap or net report), then ExitNodeID in + // the prefs may be invalid (typically, [unresolvedExitNodeID]) until + // the netmap is available. + // + // In this case, we shouldn't update the Hostinfo with the bogus + // ExitNodeID here; [LocalBackend.ResolveExitNode] will be called once + // the netmap and/or net report have been received to both pick the exit + // node and notify control of the change. + if buildfeatures.HasUseExitNode { + if sid := prefs.ExitNodeID(); sid != unresolvedExitNodeID { + hi.ExitNodeID = prefs.ExitNodeID() + } + } } -// enterState transitions the backend into newState, updating internal +// enterStateLocked transitions the backend into newState, updating internal // state and propagating events out as needed. // // TODO(danderson): while this isn't a lie, exactly, a ton of other // places twiddle IPN internal state without going through here, so // really this is more "one of several places in which random things // happen". -func (b *LocalBackend) enterState(newState ipn.State) { - unlock := b.lockAndGetUnlock() - b.enterStateLockedOnEntry(newState, unlock) -} - -// enterStateLockedOnEntry is like enterState but requires b.mu be held to call -// it, but it unlocks b.mu when done (via unlock, a once func). -func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlockOnce) { +// +// b.mu must be held. +func (b *LocalBackend) enterStateLocked(newState ipn.State) { + cn := b.currentNode() oldState := b.state - b.state = newState + b.setStateLocked(newState) prefs := b.pm.CurrentPrefs() // Some temporary (2024-05-05) debugging code to help us catch // https://github.com/tailscale/tailscale/issues/11962 in the act. if prefs.WantRunning() && - prefs.ControlURLOrDefault() == ipn.DefaultControlURL && + prefs.ControlURLOrDefault(b.polc) == ipn.DefaultControlURL && envknob.Bool("TS_PANIC_IF_HIT_MAIN_CONTROL") { panic("[unexpected] use of main control server in integration test") } - netMap := b.netMap + netMap := cn.NetMap() activeLogin := b.activeLogin authURL := b.authURL if newState == ipn.Running { - b.resetAuthURLLocked() + // TODO(zofrex): Is this needed? As of 2025-10-03 it doesn't seem to be + // necessary when logging in or authenticating. When do we need to reset it + // here, rather than the other places it is reset? We should test if it is + // necessary and add unit tests to cover those cases, or remove it. + if oldState != ipn.Running { + b.resetAuthURLLocked() + } // Start a captive portal detection loop if none has been // started. Create a new context if none is present, since it // can be shut down if we transition away from Running. - if b.captiveCancel == nil { - b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx) - go b.checkCaptivePortalLoop(b.captiveCtx) + if buildfeatures.HasCaptivePortal { + if b.captiveCancel == nil { + captiveCtx, captiveCancel := context.WithCancel(b.ctx) + b.captiveCtx, b.captiveCancel = captiveCtx, captiveCancel + b.goTracker.Go(func() { hookCheckCaptivePortalLoop.Get()(b, captiveCtx) }) + } } } else if oldState == ipn.Running { // Transitioning away from running. b.closePeerAPIListenersLocked() // Stop any existing captive portal detection loop. - if b.captiveCancel != nil { + if buildfeatures.HasCaptivePortal && b.captiveCancel != nil { b.captiveCancel() b.captiveCancel = nil @@ -4836,54 +5608,49 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock } b.pauseOrResumeControlClientLocked() - if newState == ipn.Running { - b.stopOfflineAutoUpdate() - } else { - b.maybeStartOfflineAutoUpdate(prefs) - } - - unlock.UnlockEarly() - // prefs may change irrespective of state; WantRunning should be explicitly // set before potential early return even if the state is unchanged. b.health.SetIPNState(newState.String(), prefs.Valid() && prefs.WantRunning()) if oldState == newState { return } + b.logf("Switching ipn state %v -> %v (WantRunning=%v, nm=%v)", oldState, newState, prefs.WantRunning(), netMap != nil) - b.send(ipn.Notify{State: &newState}) + b.sendLocked(ipn.Notify{State: &newState}) switch newState { case ipn.NeedsLogin: - systemd.Status("Needs login: %s", authURL) - if b.seamlessRenewalEnabled() { - break - } - b.blockEngineUpdates(true) + feature.SystemdStatus("Needs login: %s", authURL) + // always block updates on NeedsLogin even if seamless renewal is enabled, + // to prevent calls to authReconfig from reconfiguring the engine when our + // key has expired and we're waiting to authenticate to use the new key. + b.blockEngineUpdatesLocked(true) fallthrough - case ipn.Stopped: + case ipn.Stopped, ipn.NoState: + // Unconfigure the engine if it has stopped (WantRunning is set to false) + // or if we've switched to a different profile and the state is unknown. err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) if err != nil { b.logf("Reconfig(down): %v", err) } - if authURL == "" { - systemd.Status("Stopped; run 'tailscale up' to log in") + if newState == ipn.Stopped && authURL == "" { + feature.SystemdStatus("Stopped; run 'tailscale up' to log in") } case ipn.Starting, ipn.NeedsMachineAuth: - b.authReconfig() + b.authReconfigLocked() // Needed so that UpdateEndpoints can run - b.e.RequestStatus() + b.goTracker.Go(b.e.RequestStatus) case ipn.Running: - var addrStrs []string - addrs := netMap.GetAddresses() - for i := range addrs.Len() { - addrStrs = append(addrStrs, addrs.At(i).Addr().String()) + if feature.CanSystemdStatus { + var addrStrs []string + addrs := netMap.GetAddresses() + for _, p := range addrs.All() { + addrStrs = append(addrStrs, p.Addr().String()) + } + feature.SystemdStatus("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) } - systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) - case ipn.NoState: - // Do nothing. default: b.logf("[unexpected] unknown newState %#v", newState) } @@ -4912,7 +5679,8 @@ func (b *LocalBackend) NodeKey() key.NodePublic { func (b *LocalBackend) nextStateLocked() ipn.State { var ( cc = b.cc - netMap = b.netMap + cn = b.currentNode() + netMap = cn.NetMap() state = b.state blocked = b.blocked st = b.engineStatus @@ -4982,100 +5750,23 @@ func (b *LocalBackend) nextStateLocked() ipn.State { // that have happened. It is invoked from the various callbacks that // feed events into LocalBackend. // -// TODO(apenwarr): use a channel or something to prevent reentrancy? -// Or maybe just call the state machine from fewer places. -func (b *LocalBackend) stateMachine() { - unlock := b.lockAndGetUnlock() - b.stateMachineLockedOnEntry(unlock) -} - -// stateMachineLockedOnEntry is like stateMachine but requires b.mu be held to -// call it, but it unlocks b.mu when done (via unlock, a once func). -func (b *LocalBackend) stateMachineLockedOnEntry(unlock unlockOnce) { - b.enterStateLockedOnEntry(b.nextStateLocked(), unlock) -} - -// lockAndGetUnlock locks b.mu and returns a sync.OnceFunc function that will -// unlock it at most once. -// -// This is all very unfortunate but exists as a guardrail against the -// unfortunate "lockedOnEntry" methods in this package (primarily -// enterStateLockedOnEntry) that require b.mu held to be locked on entry to the -// function but unlock the mutex on their way out. As a stepping stone to -// cleaning things up (as of 2024-04-06), we at least pass the unlock func -// around now and defer unlock in the caller to avoid missing unlocks and double -// unlocks. TODO(bradfitz,maisem): make the locking in this package more -// traditional (simple). See https://github.com/tailscale/tailscale/issues/11649 -func (b *LocalBackend) lockAndGetUnlock() (unlock unlockOnce) { - b.mu.Lock() - var unlocked atomic.Bool - return func() bool { - if unlocked.CompareAndSwap(false, true) { - b.mu.Unlock() - return true - } - return false - } -} - -// unlockOnce is a func that unlocks only b.mu the first time it's called. -// Therefore it can be safely deferred to catch error paths, without worrying -// about double unlocks if a different point in the code later needs to explicitly -// unlock it first as well. It reports whether it was unlocked. -type unlockOnce func() bool - -// UnlockEarly unlocks the LocalBackend.mu. It panics if u returns false, -// indicating that this unlocker was already used. -// -// We're using this method to help us document & find the places that have -// atypical locking patterns. See -// https://github.com/tailscale/tailscale/issues/11649 for background. -// -// A normal unlock is a deferred one or an explicit b.mu.Unlock a few lines -// after the lock, without lots of control flow in-between. An "early" unlock is -// one that happens in weird places, like in various "LockedOnEntry" methods in -// this package that require the mutex to be locked on entry but unlock it -// somewhere in the middle (maybe several calls away) and then sometimes proceed -// to lock it again. -// -// The reason UnlockeEarly panics if already called is because these are the -// points at which it's assumed that the mutex is already held and it now needs -// to be released. If somebody already released it, that invariant was violated. -// On the other hand, simply calling u only returns false instead of panicking -// so you can defer it without care, confident you got all the error return -// paths which were previously done by hand. -func (u unlockOnce) UnlockEarly() { - if !u() { - panic("Unlock on already-called unlockOnce") - } +// requires b.mu to be held. +func (b *LocalBackend) stateMachineLocked() { + b.enterStateLocked(b.nextStateLocked()) } // stopEngineAndWait deconfigures the local network data plane, and -// waits for it to deliver a status update before returning. +// waits for it to deliver a status update indicating it has stopped +// before returning. // -// TODO(danderson): this may be racy. We could unblock upon receiving -// a status update that predates the "I've shut down" update. -func (b *LocalBackend) stopEngineAndWait() { +// b.mu must be held. +func (b *LocalBackend) stopEngineAndWaitLocked() { b.logf("stopEngineAndWait...") - b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) - b.requestEngineStatusAndWait() + st, _ := b.e.ResetAndStop() // TODO: what should we do if this returns an error? + b.setWgengineStatusLocked(st) b.logf("stopEngineAndWait: done.") } -// Requests the wgengine status, and does not return until the status -// was delivered (to the usual callback). -func (b *LocalBackend) requestEngineStatusAndWait() { - b.logf("requestEngineStatusAndWait") - - b.statusLock.Lock() - defer b.statusLock.Unlock() - - go b.e.RequestStatus() - b.logf("requestEngineStatusAndWait: waiting...") - b.statusChanged.Wait() // temporarily releases lock while waiting - b.logf("requestEngineStatusAndWait: got status update.") -} - // setControlClientLocked sets the control client to cc, // which may be nil. // @@ -5083,6 +5774,7 @@ func (b *LocalBackend) requestEngineStatusAndWait() { func (b *LocalBackend) setControlClientLocked(cc controlclient.Client) { b.cc = cc b.ccAuto, _ = cc.(*controlclient.Auto) + b.ignoreControlClientUpdates.Store(cc == nil) } // resetControlClientLocked sets b.cc to nil and returns the old value. If the @@ -5117,41 +5809,7 @@ func (b *LocalBackend) resetControlClientLocked() controlclient.Client { func (b *LocalBackend) resetAuthURLLocked() { b.authURL = "" b.authURLTime = time.Time{} - b.interact = false -} - -// ResetForClientDisconnect resets the backend for GUI clients running -// in interactive (non-headless) mode. This is currently used only by -// Windows. This causes all state to be cleared, lest an unrelated user -// connect to tailscaled next. But it does not trigger a logout; we -// don't want to the user to have to reauthenticate in the future -// when they restart the GUI. -func (b *LocalBackend) ResetForClientDisconnect() { - b.logf("LocalBackend.ResetForClientDisconnect") - - unlock := b.lockAndGetUnlock() - defer unlock() - - prevCC := b.resetControlClientLocked() - if prevCC != nil { - // Needs to happen without b.mu held. - defer prevCC.Shutdown() - } - - b.setNetMapLocked(nil) - b.pm.Reset() - if b.currentUser != nil { - if c, ok := b.currentUser.(ipnauth.ActorCloser); ok { - c.Close() - } - b.currentUser = nil - } - b.keyExpired = false - b.resetAuthURLLocked() - b.activeLogin = "" - b.resetDialPlan() - b.setAtomicValuesFromPrefsLocked(ipn.PrefsView{}) - b.enterStateLockedOnEntry(ipn.Stopped, unlock) + b.authActor = nil } func (b *LocalBackend) ShouldRunSSH() bool { return b.sshAtomicBool.Load() && envknob.CanSSHD() } @@ -5181,7 +5839,7 @@ func (b *LocalBackend) setWebClientAtomicBoolLocked(nm *netmap.NetworkMap) { shouldRun := !nm.HasCap(tailcfg.NodeAttrDisableWebClient) wasRunning := b.webClientAtomicBool.Swap(shouldRun) if wasRunning && !shouldRun { - go b.webClientShutdown() // stop web client + b.goTracker.Go(b.webClientShutdown) // stop web client } } @@ -5190,6 +5848,9 @@ func (b *LocalBackend) setWebClientAtomicBoolLocked(nm *netmap.NetworkMap) { // // b.mu must be held. func (b *LocalBackend) setExposeRemoteWebClientAtomicBoolLocked(prefs ipn.PrefsView) { + if !buildfeatures.HasWebClient { + return + } shouldExpose := prefs.Valid() && prefs.RunWebClient() b.exposeRemoteWebClientAtomicBool.Store(shouldExpose) } @@ -5206,12 +5867,12 @@ func (b *LocalBackend) ShouldHandleViaIP(ip netip.Addr) bool { // Logout logs out the current profile, if any, and waits for the logout to // complete. -func (b *LocalBackend) Logout(ctx context.Context) error { - unlock := b.lockAndGetUnlock() - defer unlock() +func (b *LocalBackend) Logout(ctx context.Context, actor ipnauth.Actor) error { + b.mu.Lock() if !b.hasNodeKeyLocked() { // Already logged out. + b.mu.Unlock() return nil } cc := b.cc @@ -5220,15 +5881,17 @@ func (b *LocalBackend) Logout(ctx context.Context) error { // delete it later. profile := b.pm.CurrentProfile() - _, err := b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ - WantRunningSet: true, - LoggedOutSet: true, - Prefs: ipn.Prefs{WantRunning: false, LoggedOut: true}, - }, unlock) + _, err := b.editPrefsLocked( + actor, + &ipn.MaskedPrefs{ + WantRunningSet: true, + LoggedOutSet: true, + Prefs: ipn.Prefs{WantRunning: false, LoggedOut: true}, + }) + b.mu.Unlock() if err != nil { return err } - // b.mu is now unlocked, after editPrefsLockedOnEntry. // Clear any previous dial plan(s), if set. b.resetDialPlan() @@ -5248,14 +5911,14 @@ func (b *LocalBackend) Logout(ctx context.Context) error { return err } - unlock = b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() - if err := b.pm.DeleteProfile(profile.ID); err != nil { + if err := b.pm.DeleteProfile(profile.ID()); err != nil { b.logf("error deleting profile: %v", err) return err } - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // setNetInfo sets b.hostinfo.NetInfo to ni, and passes ni along to the @@ -5296,290 +5959,203 @@ func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) { } cc.SetNetInfo(ni) if refresh { - unlock := b.lockAndGetUnlock() - defer unlock() - b.setAutoExitNodeIDLockedOnEntry(unlock) + b.RefreshExitNode() } } -func (b *LocalBackend) setAutoExitNodeIDLockedOnEntry(unlock unlockOnce) { - defer unlock() - - prefs := b.pm.CurrentPrefs() - if !prefs.Valid() { - b.logf("[unexpected]: received tailnet exit node ID pref change callback but current prefs are nil") - return - } - prefsClone := prefs.AsStruct() - newSuggestion, err := b.suggestExitNodeLocked(nil) - if err != nil { - b.logf("setAutoExitNodeID: %v", err) - return - } - prefsClone.ExitNodeID = newSuggestion.ID - _, err = b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ - Prefs: *prefsClone, - ExitNodeIDSet: true, - }, unlock) - if err != nil { - b.logf("setAutoExitNodeID: failed to apply exit node ID preference: %v", err) +// RefreshExitNode determines which exit node to use based on the current +// prefs and netmap and switches to it if needed. +func (b *LocalBackend) RefreshExitNode() { + if !buildfeatures.HasUseExitNode { return } + b.mu.Lock() + defer b.mu.Unlock() + b.refreshExitNodeLocked() } -// setNetMapLocked updates the LocalBackend state to reflect the newly -// received nm. If nm is nil, it resets all configuration as though -// Tailscale is turned off. -func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { - b.dialer.SetNetMap(nm) - if ns, ok := b.sys.Netstack.GetOK(); ok { - ns.UpdateNetstackIPs(nm) - } - var login string - if nm != nil { - login = cmp.Or(nm.UserProfiles[nm.User()].LoginName, "") - } - b.netMap = nm - b.updatePeersFromNetmapLocked(nm) - if login != b.activeLogin { - b.logf("active login: %v", login) - b.activeLogin = login +// refreshExitNodeLocked is like RefreshExitNode but requires b.mu be held. +func (b *LocalBackend) refreshExitNodeLocked() { + if b.resolveExitNodeLocked() { + b.authReconfigLocked() } - b.pauseOrResumeControlClientLocked() +} - if nm != nil { - b.health.SetControlHealth(nm.ControlHealth) - } else { - b.health.SetControlHealth(nil) +// resolveExitNodeLocked determines which exit node to use based on the current prefs +// and netmap. It updates the exit node ID in the prefs if needed, updates the +// exit node ID in the hostinfo if needed, sends a notification to clients, and +// returns true if the exit node has changed. +// +// It is the caller's responsibility to reconfigure routes and actually +// start using the selected exit node, if needed. +// +// b.mu must be held. +func (b *LocalBackend) resolveExitNodeLocked() (changed bool) { + if !buildfeatures.HasUseExitNode { + return false } - // Determine if file sharing is enabled - fs := nm.HasCap(tailcfg.CapabilityFileSharing) - if fs != b.capFileSharing { - osshare.SetFileSharingEnabled(fs, b.logf) + nm := b.currentNode().NetMap() + prefs := b.pm.CurrentPrefs().AsStruct() + if !b.resolveExitNodeInPrefsLocked(prefs) { + return } - b.capFileSharing = fs - if nm.HasCap(tailcfg.NodeAttrLinuxMustUseIPTables) { - b.capForcedNetfilter = "iptables" - } else if nm.HasCap(tailcfg.NodeAttrLinuxMustUseNfTables) { - b.capForcedNetfilter = "nftables" - } else { - b.capForcedNetfilter = "" // empty string means client can auto-detect + if err := b.pm.SetPrefs(prefs.View(), ipn.NetworkProfile{ + MagicDNSName: nm.MagicDNSSuffix(), + DomainName: nm.DomainName(), + DisplayName: nm.TailnetDisplayName(), + }); err != nil { + b.logf("failed to save exit node changes: %v", err) } - b.MagicConn().SetSilentDisco(b.ControlKnobs().SilentDisco.Load()) - b.MagicConn().SetProbeUDPLifetime(b.ControlKnobs().ProbeUDPLifetime.Load()) - - b.setDebugLogsByCapabilityLocked(nm) - - // See the netns package for documentation on what this capability does. - netns.SetBindToInterfaceByRoute(nm.HasCap(tailcfg.CapabilityBindToInterfaceByRoute)) - netns.SetDisableBindConnToInterface(nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterface)) - - b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs()) - if nm == nil { - b.nodeByAddr = nil - - // If there is no netmap, the client is going into a "turned off" - // state so reset the metrics. - b.metrics.approvedRoutes.Set(0) - b.metrics.primaryRoutes.Set(0) - return + // Send the resolved exit node to control via [tailcfg.Hostinfo]. + // [LocalBackend.applyPrefsToHostinfoLocked] usually sets the Hostinfo, + // but it deferred until this point because there was a bogus ExitNodeID + // in the prefs. + // + // TODO(sfllaw): Mutating b.hostinfo here is undesirable, mutating + // in-place doubly so. + sid := prefs.ExitNodeID + if sid != unresolvedExitNodeID && b.hostinfo.ExitNodeID != sid { + b.hostinfo.ExitNodeID = sid + b.goTracker.Go(b.doSetHostinfoFilterServices) } - // Update the nodeByAddr index. - if b.nodeByAddr == nil { - b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} + b.sendToLocked(ipn.Notify{Prefs: ptr.To(prefs.View())}, allClients) + return true +} + +// reconcilePrefsLocked applies policy overrides, exit node resolution, +// and other post-processing to the prefs, and reports whether the prefs +// were modified as a result. +// +// It must not perform any reconfiguration, as the prefs are not yet effective. +// +// b.mu must be held. +func (b *LocalBackend) reconcilePrefsLocked(prefs *ipn.Prefs) (changed bool) { + if buildfeatures.HasSystemPolicy && b.applySysPolicyLocked(prefs) { + changed = true } - // First pass, mark everything unwanted. - for k := range b.nodeByAddr { - b.nodeByAddr[k] = 0 + if buildfeatures.HasUseExitNode && b.resolveExitNodeInPrefsLocked(prefs) { + changed = true } - addNode := func(n tailcfg.NodeView) { - for _, ipp := range n.Addresses().All() { - if ipp.IsSingleIP() { - b.nodeByAddr[ipp.Addr()] = n.ID() - } - } + if changed { + b.logf("prefs reconciled: %v", prefs.Pretty()) } - if nm.SelfNode.Valid() { - addNode(nm.SelfNode) + return changed +} - var approved float64 - for _, route := range nm.SelfNode.AllowedIPs().All() { - if !views.SliceContains(nm.SelfNode.Addresses(), route) && !tsaddr.IsExitRoute(route) { - approved++ - } - } - b.metrics.approvedRoutes.Set(approved) - b.metrics.primaryRoutes.Set(float64(tsaddr.WithoutExitRoute(nm.SelfNode.PrimaryRoutes()).Len())) +// resolveExitNodeInPrefsLocked determines which exit node to use +// based on the specified prefs and netmap. It updates the exit node ID +// in the prefs if needed, and returns true if the exit node has changed. +// +// b.mu must be held. +func (b *LocalBackend) resolveExitNodeInPrefsLocked(prefs *ipn.Prefs) (changed bool) { + if !buildfeatures.HasUseExitNode { + return false } - for _, p := range nm.Peers { - addNode(p) + if b.resolveAutoExitNodeLocked(prefs) { + changed = true } - // Third pass, actually delete the unwanted items. - for k, v := range b.nodeByAddr { - if v == 0 { - delete(b.nodeByAddr, k) - } + if b.resolveExitNodeIPLocked(prefs) { + changed = true } - - b.updateDrivePeersLocked(nm) - b.driveNotifyCurrentSharesLocked() + return changed } -func (b *LocalBackend) updatePeersFromNetmapLocked(nm *netmap.NetworkMap) { - if nm == nil { - b.peers = nil - return - } +// setNetMapLocked updates the LocalBackend state to reflect the newly +// received nm. If nm is nil, it resets all configuration as though +// Tailscale is turned off. +func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { + oldSelf := b.currentNode().NetMap().SelfNodeOrZero() - // First pass, mark everything unwanted. - for k := range b.peers { - b.peers[k] = tailcfg.NodeView{} + b.dialer.SetNetMap(nm) + if ns, ok := b.sys.Netstack.GetOK(); ok { + ns.UpdateNetstackIPs(nm) } - - // Second pass, add everything wanted. - for _, p := range nm.Peers { - mak.Set(&b.peers, p.ID(), p) + var login string + if nm != nil { + login = cmp.Or(profileFromView(nm.UserProfiles[nm.User()]).LoginName, "") + } + b.currentNode().SetNetMap(nm) + if login != b.activeLogin { + b.logf("active login: %v", login) + b.activeLogin = login } + b.pauseOrResumeControlClientLocked() - // Third pass, remove deleted things. - for k, v := range b.peers { - if !v.Valid() { - delete(b.peers, k) + if nm != nil { + messages := make(map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage) + for id, msg := range nm.DisplayMessages { + if msg.PrimaryAction != nil && !b.validPopBrowserURLLocked(msg.PrimaryAction.URL) { + msg.PrimaryAction = nil + } + messages[id] = msg } + b.health.SetControlHealth(messages) + } else { + b.health.SetControlHealth(nil) } -} - -// responseBodyWrapper wraps an io.ReadCloser and stores -// the number of bytesRead. -type responseBodyWrapper struct { - io.ReadCloser - bytesRx int64 - bytesTx int64 - log logger.Logf - method string - statusCode int - contentType string - fileExtension string - shareNodeKey string - selfNodeKey string - contentLength int64 -} - -// logAccess logs the taildrive: access: log line. If the logger is nil, -// the log will not be written. -func (rbw *responseBodyWrapper) logAccess(err string) { - if rbw.log == nil { - return - } - - // Some operating systems create and copy lots of 0 length hidden files for - // tracking various states. Omit these to keep logs from being too verbose. - if rbw.contentLength > 0 { - rbw.log("taildrive: access: %s from %s to %s: status-code=%d ext=%q content-type=%q content-length=%.f tx=%.f rx=%.f err=%q", rbw.method, rbw.selfNodeKey, rbw.shareNodeKey, rbw.statusCode, rbw.fileExtension, rbw.contentType, roundTraffic(rbw.contentLength), roundTraffic(rbw.bytesTx), roundTraffic(rbw.bytesRx), err) - } -} -// Read implements the io.Reader interface. -func (rbw *responseBodyWrapper) Read(b []byte) (int, error) { - n, err := rbw.ReadCloser.Read(b) - rbw.bytesRx += int64(n) - if err != nil && !errors.Is(err, io.EOF) { - rbw.logAccess(err.Error()) + if runtime.GOOS == "linux" && buildfeatures.HasOSRouter { + if nm.HasCap(tailcfg.NodeAttrLinuxMustUseIPTables) { + b.capForcedNetfilter = "iptables" + } else if nm.HasCap(tailcfg.NodeAttrLinuxMustUseNfTables) { + b.capForcedNetfilter = "nftables" + } else { + b.capForcedNetfilter = "" // empty string means client can auto-detect + } } - return n, err -} + b.MagicConn().SetSilentDisco(b.ControlKnobs().SilentDisco.Load()) + b.MagicConn().SetProbeUDPLifetime(b.ControlKnobs().ProbeUDPLifetime.Load()) -// Close implements the io.Close interface. -func (rbw *responseBodyWrapper) Close() error { - err := rbw.ReadCloser.Close() - var errStr string - if err != nil { - errStr = err.Error() + if buildfeatures.HasDebug { + b.setDebugLogsByCapabilityLocked(nm) } - rbw.logAccess(errStr) - return err -} - -// driveTransport is an http.RoundTripper that wraps -// b.Dialer().PeerAPITransport() with metrics tracking. -type driveTransport struct { - b *LocalBackend - tr *http.Transport -} + // See the netns package for documentation on what these capability do. + netns.SetBindToInterfaceByRoute(b.logf, nm.HasCap(tailcfg.CapabilityBindToInterfaceByRoute)) + netns.SetDisableBindConnToInterface(b.logf, nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterface)) + netns.SetDisableBindConnToInterfaceAppleExt(b.logf, nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterfaceAppleExt)) -func (b *LocalBackend) newDriveTransport() *driveTransport { - return &driveTransport{ - b: b, - tr: b.Dialer().PeerAPITransport(), + b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs()) + if buildfeatures.HasServe { + b.ipVIPServiceMap = nm.GetIPVIPServiceMap() } -} -func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) { - // Some WebDAV clients include origin and refer headers, which peerapi does - // not like. Remove them. - req.Header.Del("origin") - req.Header.Del("referer") - - bw := &requestBodyWrapper{} - if req.Body != nil { - bw.ReadCloser = req.Body - req.Body = bw + if !oldSelf.Equal(nm.SelfNodeOrZero()) { + for _, f := range b.extHost.Hooks().OnSelfChange { + f(nm.SelfNode) + } } - defer func() { - contentType := "unknown" - switch req.Method { - case httpm.PUT: - if ct := req.Header.Get("Content-Type"); ct != "" { - contentType = ct - } - case httpm.GET: - if ct := resp.Header.Get("Content-Type"); ct != "" { - contentType = ct + if buildfeatures.HasAdvertiseRoutes { + if nm == nil { + // If there is no netmap, the client is going into a "turned off" + // state so reset the metrics. + b.metrics.approvedRoutes.Set(0) + } else if nm.SelfNode.Valid() { + var approved float64 + for _, route := range nm.SelfNode.AllowedIPs().All() { + if !views.SliceContains(nm.SelfNode.Addresses(), route) && !tsaddr.IsExitRoute(route) { + approved++ + } } - default: - return - } - - dt.b.mu.Lock() - selfNodeKey := dt.b.netMap.SelfNode.Key().ShortString() - dt.b.mu.Unlock() - n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) - shareNodeKey := "unknown" - if ok { - shareNodeKey = string(n.Key().ShortString()) - } - - rbw := responseBodyWrapper{ - log: dt.b.logf, - method: req.Method, - bytesTx: int64(bw.bytesRead), - selfNodeKey: selfNodeKey, - shareNodeKey: shareNodeKey, - contentType: contentType, - contentLength: resp.ContentLength, - fileExtension: parseDriveFileExtensionForLog(req.URL.Path), - statusCode: resp.StatusCode, - ReadCloser: resp.Body, + b.metrics.approvedRoutes.Set(approved) } + } - if resp.StatusCode >= 400 { - // in case of error response, just log immediately - rbw.logAccess("") - } else { - resp.Body = &rbw + if buildfeatures.HasDrive && nm != nil { + if f, ok := hookSetNetMapLockedDrive.GetOk(); ok { + f(b, nm) } - }() - - return dt.tr.RoundTrip(req) + } } +var hookSetNetMapLockedDrive feature.Hook[func(*LocalBackend, *netmap.NetworkMap)] + // roundTraffic rounds bytes. This is used to preserve user privacy within logs. func roundTraffic(bytes int64) float64 { var x float64 @@ -5618,48 +6194,6 @@ func (b *LocalBackend) setDebugLogsByCapabilityLocked(nm *netmap.NetworkMap) { } } -// reloadServeConfigLocked reloads the serve config from the store or resets the -// serve config to nil if not logged in. The "changed" parameter, when false, instructs -// the method to only run the reset-logic and not reload the store from memory to ensure -// foreground sessions are not removed if they are not saved on disk. -func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { - if b.netMap == nil || !b.netMap.SelfNode.Valid() || !prefs.Valid() || b.pm.CurrentProfile().ID == "" { - // We're not logged in, so we don't have a profile. - // Don't try to load the serve config. - b.lastServeConfJSON = mem.B(nil) - b.serveConfig = ipn.ServeConfigView{} - return - } - - confKey := ipn.ServeConfigKey(b.pm.CurrentProfile().ID) - // TODO(maisem,bradfitz): prevent reading the config from disk - // if the profile has not changed. - confj, err := b.store.ReadState(confKey) - if err != nil { - b.lastServeConfJSON = mem.B(nil) - b.serveConfig = ipn.ServeConfigView{} - return - } - if b.lastServeConfJSON.Equal(mem.B(confj)) { - return - } - b.lastServeConfJSON = mem.B(confj) - var conf ipn.ServeConfig - if err := json.Unmarshal(confj, &conf); err != nil { - b.logf("invalid ServeConfig %q in StateStore: %v", confKey, err) - b.serveConfig = ipn.ServeConfigView{} - return - } - - // remove inactive sessions - maps.DeleteFunc(conf.Foreground, func(sessionID string, sc *ipn.ServeConfig) bool { - _, ok := b.notifyWatchers[sessionID] - return !ok - }) - - b.serveConfig = conf.View() -} - // setTCPPortsInterceptedFromNetmapAndPrefsLocked calls setTCPPortsIntercepted with // the ports that tailscaled should handle as a function of b.netMap and b.prefs. // @@ -5679,79 +6213,42 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn. } } - b.reloadServeConfigLocked(prefs) - if b.serveConfig.Valid() { - servePorts := make([]uint16, 0, 3) - b.serveConfig.RangeOverTCPs(func(port uint16, _ ipn.TCPPortHandlerView) bool { - if port > 0 { - servePorts = append(servePorts, uint16(port)) - } - return true - }) - handlePorts = append(handlePorts, servePorts...) - - b.setServeProxyHandlersLocked() - - // don't listen on netmap addresses if we're in userspace mode - if !b.sys.IsNetstack() { - b.updateServeTCPPortNetMapAddrListenersLocked(servePorts) - } - } - // Kick off a Hostinfo update to control if WireIngress changed. - if wire := b.wantIngressLocked(); b.hostinfo != nil && b.hostinfo.WireIngress != wire { - b.logf("Hostinfo.WireIngress changed to %v", wire) - b.hostinfo.WireIngress = wire - go b.doSetHostinfoFilterServices() + if f, ok := hookServeSetTCPPortsInterceptedFromNetmapAndPrefsLocked.GetOk(); ok { + v := f(b, prefs) + handlePorts = append(handlePorts, v...) } + // Update funnel and service hash info in hostinfo and kick off control update if needed. + b.maybeSentHostinfoIfChangedLocked(prefs) b.setTCPPortsIntercepted(handlePorts) } -// setServeProxyHandlersLocked ensures there is an http proxy handler for each -// backend specified in serveConfig. It expects serveConfig to be valid and -// up-to-date, so should be called after reloadServeConfigLocked. -func (b *LocalBackend) setServeProxyHandlersLocked() { - if !b.serveConfig.Valid() { +// hookMaybeMutateHostinfoLocked is a hook that allows conditional features +// to mutate the provided hostinfo before it is sent to control. +// +// The hook function should return true if it mutated the hostinfo. +// +// The LocalBackend's mutex is held while calling. +var hookMaybeMutateHostinfoLocked feature.Hooks[func(*LocalBackend, *tailcfg.Hostinfo, ipn.PrefsView) bool] + +// maybeSentHostinfoIfChangedLocked updates the hostinfo.ServicesHash, hostinfo.WireIngress and +// hostinfo.IngressEnabled fields and kicks off a Hostinfo update if the values have changed. +// +// b.mu must be held. +func (b *LocalBackend) maybeSentHostinfoIfChangedLocked(prefs ipn.PrefsView) { + if b.hostinfo == nil { return } - var backends map[string]bool - b.serveConfig.RangeOverWebs(func(_ ipn.HostPort, conf ipn.WebServerConfigView) (cont bool) { - conf.Handlers().Range(func(_ string, h ipn.HTTPHandlerView) (cont bool) { - backend := h.Proxy() - if backend == "" { - // Only create proxy handlers for servers with a proxy backend. - return true - } - mak.Set(&backends, backend, true) - if _, ok := b.serveProxyHandlers.Load(backend); ok { - return true - } - - b.logf("serve: creating a new proxy handler for %s", backend) - p, err := b.proxyHandlerForBackend(backend) - if err != nil { - // The backend endpoint (h.Proxy) should have been validated by expandProxyTarget - // in the CLI, so just log the error here. - b.logf("[unexpected] could not create proxy for %v: %s", backend, err) - return true - } - b.serveProxyHandlers.Store(backend, p) - return true - }) - return true - }) - - // Clean up handlers for proxy backends that are no longer present - // in configuration. - b.serveProxyHandlers.Range(func(key, value any) bool { - backend := key.(string) - if !backends[backend] { - b.logf("serve: closing idle connections to %s", backend) - b.serveProxyHandlers.Delete(backend) - value.(*reverseProxy).close() + changed := false + for _, f := range hookMaybeMutateHostinfoLocked { + if f(b, b.hostinfo, prefs) { + changed = true } - return true - }) + } + // Kick off a Hostinfo update to control if ingress status has changed. + if changed { + b.goTracker.Go(b.doSetHostinfoFilterServices) + } } // operatorUserName returns the current pref's OperatorUser's name, or the @@ -5799,141 +6296,6 @@ func (b *LocalBackend) TestOnlyPublicKeys() (machineKey key.MachinePublic, nodeK return mk, nk } -func (b *LocalBackend) removeFileWaiter(handle set.Handle) { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.fileWaiters, handle) -} - -func (b *LocalBackend) addFileWaiter(wakeWaiter context.CancelFunc) set.Handle { - b.mu.Lock() - defer b.mu.Unlock() - return b.fileWaiters.Add(wakeWaiter) -} - -func (b *LocalBackend) WaitingFiles() ([]apitype.WaitingFile, error) { - b.mu.Lock() - apiSrv := b.peerAPIServer - b.mu.Unlock() - return mayDeref(apiSrv).taildrop.WaitingFiles() -} - -// AwaitWaitingFiles is like WaitingFiles but blocks while ctx is not done, -// waiting for any files to be available. -// -// On return, exactly one of the results will be non-empty or non-nil, -// respectively. -func (b *LocalBackend) AwaitWaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { - if ff, err := b.WaitingFiles(); err != nil || len(ff) > 0 { - return ff, err - } - - for { - gotFile, gotFileCancel := context.WithCancel(context.Background()) - defer gotFileCancel() - - handle := b.addFileWaiter(gotFileCancel) - defer b.removeFileWaiter(handle) - - // Now that we've registered ourselves, check again, in case - // of race. Otherwise there's a small window where we could - // miss a file arrival and wait forever. - if ff, err := b.WaitingFiles(); err != nil || len(ff) > 0 { - return ff, err - } - - select { - case <-gotFile.Done(): - if ff, err := b.WaitingFiles(); err != nil || len(ff) > 0 { - return ff, err - } - case <-ctx.Done(): - return nil, ctx.Err() - } - } -} - -func (b *LocalBackend) DeleteFile(name string) error { - b.mu.Lock() - apiSrv := b.peerAPIServer - b.mu.Unlock() - return mayDeref(apiSrv).taildrop.DeleteFile(name) -} - -func (b *LocalBackend) OpenFile(name string) (rc io.ReadCloser, size int64, err error) { - b.mu.Lock() - apiSrv := b.peerAPIServer - b.mu.Unlock() - return mayDeref(apiSrv).taildrop.OpenFile(name) -} - -// hasCapFileSharing reports whether the current node has the file -// sharing capability enabled. -func (b *LocalBackend) hasCapFileSharing() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.capFileSharing -} - -// FileTargets lists nodes that the current node can send files to. -func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { - var ret []*apitype.FileTarget - - b.mu.Lock() - defer b.mu.Unlock() - nm := b.netMap - if b.state != ipn.Running || nm == nil { - return nil, errors.New("not connected to the tailnet") - } - if !b.capFileSharing { - return nil, errors.New("file sharing not enabled by Tailscale admin") - } - for _, p := range b.peers { - if !b.peerIsTaildropTargetLocked(p) { - continue - } - if p.Hostinfo().OS() == "tvOS" { - continue - } - peerAPI := peerAPIBase(b.netMap, p) - if peerAPI == "" { - continue - } - ret = append(ret, &apitype.FileTarget{ - Node: p.AsStruct(), - PeerAPIURL: peerAPI, - }) - } - slices.SortFunc(ret, func(a, b *apitype.FileTarget) int { - return cmp.Compare(a.Node.Name, b.Node.Name) - }) - return ret, nil -} - -// peerIsTaildropTargetLocked reports whether p is a valid Taildrop file -// recipient from this node according to its ownership and the capabilities in -// the netmap. -// -// b.mu must be locked. -func (b *LocalBackend) peerIsTaildropTargetLocked(p tailcfg.NodeView) bool { - if b.netMap == nil || !p.Valid() { - return false - } - if b.netMap.User() == p.User() { - return true - } - if p.Addresses().Len() > 0 && - b.peerHasCapLocked(p.Addresses().At(0).Addr(), tailcfg.PeerCapabilityFileSharingTarget) { - // Explicitly noted in the netmap ACL caps as a target. - return true - } - return false -} - -func (b *LocalBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { - return b.peerCapsLocked(addr).HasCapability(wantCap) -} - // SetDNS adds a DNS record for the given domain name & TXT record // value. // @@ -5942,6 +6304,9 @@ func (b *LocalBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCap // This is the low-level interface. Other layers will provide more // friendly options to get HTTPS certs. func (b *LocalBackend) SetDNS(ctx context.Context, name, value string) error { + if !buildfeatures.HasACME { + return feature.ErrUnavailable + } req := &tailcfg.SetDNSRequest{ Version: 1, // TODO(bradfitz,maisem): use tailcfg.CurrentCapabilityVersion when using the Noise transport Type: "TXT", @@ -5972,8 +6337,7 @@ func (b *LocalBackend) SetDNS(ctx context.Context, name, value string) error { func peerAPIPorts(peer tailcfg.NodeView) (p4, p6 uint16) { svcs := peer.Hostinfo().Services() - for i := range svcs.Len() { - s := svcs.At(i) + for _, s := range svcs.All() { switch s.Proto { case tailcfg.PeerAPI4: p4 = s.Port @@ -5984,60 +6348,10 @@ func peerAPIPorts(peer tailcfg.NodeView) (p4, p6 uint16) { return } -// peerAPIURL returns an HTTP URL for the peer's peerapi service, -// without a trailing slash. -// -// If ip or port is the zero value then it returns the empty string. -func peerAPIURL(ip netip.Addr, port uint16) string { - if port == 0 || !ip.IsValid() { - return "" - } - return fmt.Sprintf("http://%v", netip.AddrPortFrom(ip, port)) -} - -// peerAPIBase returns the "http://ip:port" URL base to reach peer's peerAPI. -// It returns the empty string if the peer doesn't support the peerapi -// or there's no matching address family based on the netmap's own addresses. -func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { - if nm == nil || !peer.Valid() || !peer.Hostinfo().Valid() { - return "" - } - - var have4, have6 bool - addrs := nm.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) - if !a.IsSingleIP() { - continue - } - switch { - case a.Addr().Is4(): - have4 = true - case a.Addr().Is6(): - have6 = true - } - } - p4, p6 := peerAPIPorts(peer) - switch { - case have4 && p4 != 0: - return peerAPIURL(nodeIP(peer, netip.Addr.Is4), p4) - case have6 && p6 != 0: - return peerAPIURL(nodeIP(peer, netip.Addr.Is6), p6) - } - return "" -} - -func nodeIP(n tailcfg.NodeView, pred func(netip.Addr) bool) netip.Addr { - for i := range n.Addresses().Len() { - a := n.Addresses().At(i) - if a.IsSingleIP() && pred(a.Addr()) { - return a.Addr() - } - } - return netip.Addr{} -} - func (b *LocalBackend) CheckIPForwarding() error { + if !buildfeatures.HasAdvertiseRoutes { + return nil + } if b.sys.IsNetstackRouter() { return nil } @@ -6123,12 +6437,7 @@ func (b *LocalBackend) SetUDPGROForwarding() error { // DERPMap returns the current DERPMap in use, or nil if not connected. func (b *LocalBackend) DERPMap() *tailcfg.DERPMap { - b.mu.Lock() - defer b.mu.Unlock() - if b.netMap == nil { - return nil - } - return b.netMap.DERPMap + return b.currentNode().DERPMap() } // OfferingExitNode reports whether b is currently offering exit node @@ -6158,17 +6467,32 @@ func (b *LocalBackend) OfferingExitNode() bool { // OfferingAppConnector reports whether b is currently offering app // connector services. func (b *LocalBackend) OfferingAppConnector() bool { + if !buildfeatures.HasAppConnectors { + return false + } b.mu.Lock() defer b.mu.Unlock() return b.appConnector != nil } +// AppConnector returns the current AppConnector, or nil if not configured. +// +// TODO(nickkhyl): move app connectors to [nodeBackend], or perhaps a feature package? +func (b *LocalBackend) AppConnector() *appc.AppConnector { + if !buildfeatures.HasAppConnectors { + return nil + } + b.mu.Lock() + defer b.mu.Unlock() + return b.appConnector +} + // allowExitNodeDNSProxyToServeName reports whether the Exit Node DNS // proxy is allowed to serve responses for the provided DNS name. func (b *LocalBackend) allowExitNodeDNSProxyToServeName(name string) bool { b.mu.Lock() defer b.mu.Unlock() - nm := b.netMap + nm := b.NetMap() if nm == nil { return false } @@ -6209,11 +6533,28 @@ func (b *LocalBackend) SetExpirySooner(ctx context.Context, expiry time.Time) er return cc.SetExpirySooner(ctx, expiry) } +// SetDeviceAttrs does a synchronous call to the control plane to update +// the node's attributes. +// +// See docs on [tailcfg.SetDeviceAttributesRequest] for background. +func (b *LocalBackend) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) error { + b.mu.Lock() + cc := b.ccAuto + b.mu.Unlock() + if cc == nil { + return errors.New("not running") + } + return cc.SetDeviceAttrs(ctx, attrs) +} + // exitNodeCanProxyDNS reports the DoH base URL ("http://foo/dns-query") without query parameters // to exitNodeID's DoH service, if available. // // If exitNodeID is the zero valid, it returns "", false. func exitNodeCanProxyDNS(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) { + if !buildfeatures.HasUseExitNode { + return "", false + } if exitNodeID.IsZero() { return "", false } @@ -6261,8 +6602,8 @@ func peerCanProxyDNS(p tailcfg.NodeView) bool { // If p.Cap is not populated (e.g. older control server), then do the old // thing of searching through services. services := p.Hostinfo().Services() - for i := range services.Len() { - if s := services.At(i); s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { + for _, s := range services.All() { + if s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { return true } } @@ -6279,70 +6620,47 @@ func (b *LocalBackend) DebugReSTUN() error { return nil } -// ControlKnobs returns the node's control knobs. -func (b *LocalBackend) ControlKnobs() *controlknobs.Knobs { - return b.sys.ControlKnobs() -} - -// MagicConn returns the backend's *magicsock.Conn. -func (b *LocalBackend) MagicConn() *magicsock.Conn { - return b.sys.MagicSock.Get() -} +func (b *LocalBackend) DebugRotateDiscoKey() error { + if !buildfeatures.HasDebug { + return nil + } -type keyProvingNoiseRoundTripper struct { - b *LocalBackend -} + mc := b.MagicConn() + mc.RotateDiscoKey() -func (n keyProvingNoiseRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - b := n.b + newDiscoKey := mc.DiscoPublicKey() - var priv key.NodePrivate + if tunWrap, ok := b.sys.Tun.GetOK(); ok { + tunWrap.SetDiscoKey(newDiscoKey) + } b.mu.Lock() - cc := b.ccAuto - if nm := b.netMap; nm != nil { - priv = nm.PrivateKey - } + cc := b.cc b.mu.Unlock() - if cc == nil { - return nil, errors.New("no client") - } - if priv.IsZero() { - return nil, errors.New("no netmap or private key") - } - rt, ep, err := cc.GetSingleUseNoiseRoundTripper(req.Context()) - if err != nil { - return nil, err - } - if ep == nil || ep.NodeKeyChallenge.IsZero() { - go rt.RoundTrip(new(http.Request)) // return our reservation with a bogus request - return nil, errors.New("this coordination server does not support API calls over the Noise channel") + if cc != nil { + cc.SetDiscoPublicKey(newDiscoKey) } - // QueryEscape the node key since it has a colon in it. - nk := url.QueryEscape(priv.Public().String()) - req.SetBasicAuth(nk, "") + return nil +} - // genNodeProofHeaderValue returns the Tailscale-Node-Proof header's value to prove - // to chalPub that we control claimedPrivate. - genNodeProofHeaderValue := func(claimedPrivate key.NodePrivate, chalPub key.ChallengePublic) string { - // TODO(bradfitz): cache this somewhere? - box := claimedPrivate.SealToChallenge(chalPub, []byte(chalPub.String())) - return claimedPrivate.Public().String() + " " + base64.StdEncoding.EncodeToString(box) - } +func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.Addr] { + return b.MagicConn().PeerRelays() +} - // And prove we have the private key corresponding to the public key sent - // tin the basic auth username. - req.Header.Set("Tailscale-Node-Proof", genNodeProofHeaderValue(priv, ep.NodeKeyChallenge)) +// ControlKnobs returns the node's control knobs. +func (b *LocalBackend) ControlKnobs() *controlknobs.Knobs { + return b.sys.ControlKnobs() +} - return rt.RoundTrip(req) +// EventBus returns the node's event bus. +func (b *LocalBackend) EventBus() *eventbus.Bus { + return b.sys.Bus.Get() } -// KeyProvingNoiseRoundTripper returns an http.RoundTripper that uses the LocalBackend's -// DoNoiseRequest method and mutates the request to add an authorization header -// to prove the client's nodekey. -func (b *LocalBackend) KeyProvingNoiseRoundTripper() http.RoundTripper { - return keyProvingNoiseRoundTripper{b} +// MagicConn returns the backend's *magicsock.Conn. +func (b *LocalBackend) MagicConn() *magicsock.Conn { + return b.sys.MagicSock.Get() } // DoNoiseRequest sends a request to URL over the control plane @@ -6357,6 +6675,15 @@ func (b *LocalBackend) DoNoiseRequest(req *http.Request) (*http.Response, error) return cc.DoNoiseRequest(req) } +// ActiveSSHConns returns the number of active SSH connections, +// or 0 if SSH is not linked into the binary or available on the platform. +func (b *LocalBackend) ActiveSSHConns() int { + if b.sshServer == nil { + return 0 + } + return b.sshServer.NumActiveConns() +} + func (b *LocalBackend) sshServerOrInit() (_ SSHServer, err error) { b.mu.Lock() defer b.mu.Unlock() @@ -6373,6 +6700,13 @@ func (b *LocalBackend) sshServerOrInit() (_ SSHServer, err error) { return b.sshServer, nil } +var warnSyncDisabled = health.Register(&health.Warnable{ + Code: "sync-disabled", + Title: "Tailscale Sync is Disabled", + Severity: health.SeverityHigh, + Text: health.StaticMessage("Tailscale control plane syncing is disabled; run `tailscale set --sync` to restore"), +}) + var warnSSHSELinuxWarnable = health.Register(&health.Warnable{ Code: "ssh-unavailable-selinux-enabled", Title: "Tailscale SSH and SELinux", @@ -6388,6 +6722,14 @@ func (b *LocalBackend) updateSELinuxHealthWarning() { } } +func (b *LocalBackend) updateWarnSync(prefs ipn.PrefsView) { + if prefs.Sync().EqualBool(false) { + b.health.SetUnhealthy(warnSyncDisabled, nil) + } else { + b.health.SetHealthy(warnSyncDisabled) + } +} + func (b *LocalBackend) handleSSHConn(c net.Conn) (err error) { s, err := b.sshServerOrInit() if err != nil { @@ -6432,11 +6774,12 @@ func (b *LocalBackend) handleQuad100Port80Conn(w http.ResponseWriter, r *http.Re defer b.mu.Unlock() io.WriteString(w, "

Tailscale

\n") - if b.netMap == nil { + nm := b.currentNode().NetMap() + if nm == nil { io.WriteString(w, "No netmap.\n") return } - addrs := b.netMap.GetAddresses() + addrs := nm.GetAddresses() if addrs.Len() == 0 { io.WriteString(w, "No local addresses.\n") return @@ -6448,56 +6791,8 @@ func (b *LocalBackend) handleQuad100Port80Conn(w http.ResponseWriter, r *http.Re io.WriteString(w, "\n") } -func (b *LocalBackend) Doctor(ctx context.Context, logf logger.Logf) { - // We can write logs too fast for logtail to handle, even when - // opting-out of rate limits. Limit ourselves to at most one message - // per 20ms and a burst of 60 log lines, which should be fast enough to - // not block for too long but slow enough that we can upload all lines. - logf = logger.SlowLoggerWithClock(ctx, logf, 20*time.Millisecond, 60, b.clock.Now) - - var checks []doctor.Check - checks = append(checks, - permissions.Check{}, - routetable.Check{}, - ethtool.Check{}, - ) - - // Print a log message if any of the global DNS resolvers are Tailscale - // IPs; this can interfere with our ability to connect to the Tailscale - // controlplane. - checks = append(checks, doctor.CheckFunc("dns-resolvers", func(_ context.Context, logf logger.Logf) error { - b.mu.Lock() - nm := b.netMap - b.mu.Unlock() - if nm == nil { - return nil - } - - for i, resolver := range nm.DNS.Resolvers { - ipp, ok := resolver.IPPort() - if ok && tsaddr.IsTailscaleIP(ipp.Addr()) { - logf("resolver %d is a Tailscale address: %v", i, resolver) - } - } - for i, resolver := range nm.DNS.FallbackResolvers { - ipp, ok := resolver.IPPort() - if ok && tsaddr.IsTailscaleIP(ipp.Addr()) { - logf("fallback resolver %d is a Tailscale address: %v", i, resolver) - } - } - return nil - })) - - // TODO(andrew): more - - numChecks := len(checks) - checks = append(checks, doctor.CheckFunc("numchecks", func(_ context.Context, log logger.Logf) error { - log("%d checks", numChecks) - return nil - })) - - doctor.RunChecks(ctx, logf, checks...) -} +// HookDoctor is an optional hook for the "doctor" problem diagnosis feature. +var HookDoctor feature.Hook[func(context.Context, *LocalBackend, logger.Logf)] // SetDevStateStore updates the LocalBackend's state storage to the provided values. // @@ -6527,76 +6822,37 @@ func (b *LocalBackend) ShouldInterceptTCPPort(port uint16) bool { return b.shouldInterceptTCPPortAtomic.Load()(port) } +// ShouldInterceptVIPServiceTCPPort reports whether the given TCP port number +// to a VIP service should be intercepted by Tailscaled and handled in-process. +func (b *LocalBackend) ShouldInterceptVIPServiceTCPPort(ap netip.AddrPort) bool { + if !buildfeatures.HasServe { + return false + } + f := b.shouldInterceptVIPServicesTCPPortAtomic.Load() + if f == nil { + return false + } + return f(ap) +} + // SwitchProfile switches to the profile with the given id. // It will restart the backend on success. // If the profile is not known, it returns an errProfileNotFound. func (b *LocalBackend) SwitchProfile(profile ipn.ProfileID) error { - if b.CurrentProfile().ID == profile { - return nil - } - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() - oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault() - if err := b.pm.SwitchProfile(profile); err != nil { - return err + oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc) + if _, changed, err := b.pm.SwitchToProfileByID(profile); !changed || err != nil { + return err // nil if we're already on the target profile } - // As an optimization, only reset the dialPlan if the control URL - // changed; we treat an empty URL as "unknown" and always reset. - newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault() - if oldControlURL != newControlURL || oldControlURL == "" || newControlURL == "" { + // As an optimization, only reset the dialPlan if the control URL changed. + if newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(b.polc); oldControlURL != newControlURL { b.resetDialPlan() } - return b.resetForProfileChangeLockedOnEntry(unlock) -} - -func (b *LocalBackend) initTKALocked() error { - cp := b.pm.CurrentProfile() - if cp.ID == "" { - b.tka = nil - return nil - } - if b.tka != nil { - if b.tka.profile == cp.ID { - // Already initialized. - return nil - } - // As we're switching profiles, we need to reset the TKA to nil. - b.tka = nil - } - root := b.TailscaleVarRoot() - if root == "" { - b.tka = nil - b.logf("network-lock unavailable; no state directory") - return nil - } - - chonkDir := b.chonkPathLocked() - if _, err := os.Stat(chonkDir); err == nil { - // The directory exists, which means network-lock has been initialized. - storage, err := tka.ChonkDir(chonkDir) - if err != nil { - return fmt.Errorf("opening tailchonk: %v", err) - } - authority, err := tka.Open(storage) - if err != nil { - return fmt.Errorf("initializing tka: %v", err) - } - if err := authority.Compact(storage, tkaCompactionDefaults); err != nil { - b.logf("tka compaction failed: %v", err) - } - - b.tka = &tkaState{ - profile: cp.ID, - authority: authority, - storage: storage, - } - b.logf("tka initialized at head %x", authority.Head()) - } - - return nil + return b.resetForProfileChangeLocked() } // resetDialPlan resets the dialPlan for this LocalBackend. It will log if @@ -6610,39 +6866,54 @@ func (b *LocalBackend) resetDialPlan() { } } -// resetForProfileChangeLockedOnEntry resets the backend for a profile change. +// resetForProfileChangeLocked resets the backend for a profile change. // -// b.mu must held on entry. It is released on exit. -func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) error { - defer unlock() - +// b.mu must be held. +func (b *LocalBackend) resetForProfileChangeLocked() error { if b.shutdownCalled { // Prevent a call back to Start during Shutdown, which calls Logout for // ephemeral nodes, which can then call back here. But we're shutting // down, so no need to do any work. return nil } + newNode := newNodeBackend(b.ctx, b.logf, b.sys.Bus.Get()) + if oldNode := b.currentNodeAtomic.Swap(newNode); oldNode != nil { + oldNode.shutdown(errNodeContextChanged) + } + defer newNode.ready() b.setNetMapLocked(nil) // Reset netmap. + b.updateFilterLocked(ipn.PrefsView{}) // Reset the NetworkMap in the engine b.e.SetNetworkMap(new(netmap.NetworkMap)) - if err := b.initTKALocked(); err != nil { - return err + if prevCC := b.resetControlClientLocked(); prevCC != nil { + defer prevCC.Shutdown() } + // TKA errors should not prevent resetting the backend state. + // However, we should still return the error to the caller. + tkaErr := b.initTKALocked() b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} b.lastSuggestedExitNode = "" - b.enterStateLockedOnEntry(ipn.NoState, unlock) // Reset state; releases b.mu + b.keyExpired = false + b.overrideExitNodePolicy = false + b.resetAlwaysOnOverrideLocked() + b.extHost.NotifyProfileChange(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) + b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) + b.enterStateLocked(ipn.NoState) b.health.SetLocalLogConfigHealth(nil) - return b.Start(ipn.Options{}) + if tkaErr != nil { + return tkaErr + } + return b.startLocked(ipn.Options{}) } // DeleteProfile deletes a profile with the given ID. // If the profile is not known, it is a no-op. func (b *LocalBackend) DeleteProfile(p ipn.ProfileID) error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() - needToRestart := b.pm.CurrentProfile().ID == p + needToRestart := b.pm.CurrentProfile().ID() == p if err := b.pm.DeleteProfile(p); err != nil { if err == errProfileNotFound { return nil @@ -6652,12 +6923,12 @@ func (b *LocalBackend) DeleteProfile(p ipn.ProfileID) error { if !needToRestart { return nil } - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // CurrentProfile returns the current LoginProfile. // The value may be zero if the profile is not persisted. -func (b *LocalBackend) CurrentProfile() ipn.LoginProfile { +func (b *LocalBackend) CurrentProfile() ipn.LoginProfileView { b.mu.Lock() defer b.mu.Unlock() return b.pm.CurrentProfile() @@ -6665,20 +6936,20 @@ func (b *LocalBackend) CurrentProfile() ipn.LoginProfile { // NewProfile creates and switches to the new profile. func (b *LocalBackend) NewProfile() error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() - b.pm.NewProfile() + b.pm.SwitchToNewProfile() // The new profile doesn't yet have a ControlURL because it hasn't been // set. Conservatively reset the dialPlan. b.resetDialPlan() - return b.resetForProfileChangeLockedOnEntry(unlock) + return b.resetForProfileChangeLocked() } // ListProfiles returns a list of all LoginProfiles. -func (b *LocalBackend) ListProfiles() []ipn.LoginProfile { +func (b *LocalBackend) ListProfiles() []ipn.LoginProfileView { b.mu.Lock() defer b.mu.Unlock() return b.pm.Profiles() @@ -6689,12 +6960,11 @@ func (b *LocalBackend) ListProfiles() []ipn.LoginProfile { // backend is left with a new profile, ready for StartLoginInterative to be // called to register it as new node. func (b *LocalBackend) ResetAuth() error { - unlock := b.lockAndGetUnlock() - defer unlock() + b.mu.Lock() + defer b.mu.Unlock() - prevCC := b.resetControlClientLocked() - if prevCC != nil { - defer prevCC.Shutdown() // call must happen after release b.mu + if prevCC := b.resetControlClientLocked(); prevCC != nil { + defer prevCC.Shutdown() } if err := b.clearMachineKeyLocked(); err != nil { return err @@ -6703,49 +6973,7 @@ func (b *LocalBackend) ResetAuth() error { return err } b.resetDialPlan() // always reset if we're removing everything - return b.resetForProfileChangeLockedOnEntry(unlock) -} - -// StreamDebugCapture writes a pcap stream of packets traversing -// tailscaled to the provided response writer. -func (b *LocalBackend) StreamDebugCapture(ctx context.Context, w io.Writer) error { - var s *capture.Sink - - b.mu.Lock() - if b.debugSink == nil { - s = capture.New() - b.debugSink = s - b.e.InstallCaptureHook(s.LogPacket) - } else { - s = b.debugSink - } - b.mu.Unlock() - - unregister := s.RegisterOutput(w) - - select { - case <-ctx.Done(): - case <-s.WaitCh(): - } - unregister() - - // Shut down & uninstall the sink if there are no longer - // any outputs on it. - b.mu.Lock() - defer b.mu.Unlock() - - select { - case <-b.ctx.Done(): - return nil - default: - } - if b.debugSink != nil && b.debugSink.NumOutputs() == 0 { - s := b.debugSink - b.e.InstallCaptureHook(nil) - b.debugSink = nil - return s.Close() - } - return nil + return b.resetForProfileChangeLocked() } func (b *LocalBackend) GetPeerEndpointChanges(ctx context.Context, ip netip.Addr) ([]magicsock.EndpointChange, error) { @@ -6778,78 +7006,33 @@ func (b *LocalBackend) DebugBreakDERPConns() error { return b.MagicConn().DebugBreakDERPConns() } -func (b *LocalBackend) pushSelfUpdateProgress(up ipnstate.UpdateProgress) { - b.mu.Lock() - defer b.mu.Unlock() - b.selfUpdateProgress = append(b.selfUpdateProgress, up) - b.lastSelfUpdateState = up.Status -} - -func (b *LocalBackend) clearSelfUpdateProgress() { - b.mu.Lock() - defer b.mu.Unlock() - b.selfUpdateProgress = make([]ipnstate.UpdateProgress, 0) - b.lastSelfUpdateState = ipnstate.UpdateFinished -} - -func (b *LocalBackend) GetSelfUpdateProgress() []ipnstate.UpdateProgress { - b.mu.Lock() - defer b.mu.Unlock() - res := make([]ipnstate.UpdateProgress, len(b.selfUpdateProgress)) - copy(res, b.selfUpdateProgress) - return res -} - -func (b *LocalBackend) DoSelfUpdate() { - b.mu.Lock() - updateState := b.lastSelfUpdateState - b.mu.Unlock() - // don't start an update if one is already in progress - if updateState == ipnstate.UpdateInProgress { - return - } - b.clearSelfUpdateProgress() - b.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateInProgress, "")) - up, err := clientupdate.NewUpdater(clientupdate.Arguments{ - Logf: func(format string, args ...any) { - b.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateInProgress, fmt.Sprintf(format, args...))) - }, - }) - if err != nil { - b.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFailed, err.Error())) - } - err = up.Update() - if err != nil { - b.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFailed, err.Error())) - } else { - b.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFinished, "tailscaled did not restart; please restart Tailscale manually.")) - } -} - // ObserveDNSResponse passes a DNS response from the PeerAPI DNS server to the // App Connector to enable route discovery. -func (b *LocalBackend) ObserveDNSResponse(res []byte) { +func (b *LocalBackend) ObserveDNSResponse(res []byte) error { + if !buildfeatures.HasAppConnectors { + return nil + } var appConnector *appc.AppConnector b.mu.Lock() if b.appConnector == nil { b.mu.Unlock() - return + return nil } appConnector = b.appConnector b.mu.Unlock() - appConnector.ObserveDNSResponse(res) + return appConnector.ObserveDNSResponse(res) } // ErrDisallowedAutoRoute is returned by AdvertiseRoute when a route that is not allowed is requested. var ErrDisallowedAutoRoute = errors.New("route is not allowed") -// AdvertiseRoute implements the appc.RouteAdvertiser interface. It sets a new -// route advertisement if one is not already present in the existing routes. -// If the route is disallowed, ErrDisallowedAutoRoute is returned. +// AdvertiseRoute implements the appctype.RouteAdvertiser interface. It sets a +// new route advertisement if one is not already present in the existing +// routes. If the route is disallowed, ErrDisallowedAutoRoute is returned. func (b *LocalBackend) AdvertiseRoute(ipps ...netip.Prefix) error { finalRoutes := b.Prefs().AdvertiseRoutes().AsSlice() - newRoutes := false + var newRoutes []netip.Prefix for _, ipp := range ipps { if !allowedAutoRoute(ipp) { @@ -6865,13 +7048,14 @@ func (b *LocalBackend) AdvertiseRoute(ipps ...netip.Prefix) error { } finalRoutes = append(finalRoutes, ipp) - newRoutes = true + newRoutes = append(newRoutes, ipp) } - if !newRoutes { + if len(newRoutes) == 0 { return nil } + b.logf("advertising new app connector routes: %v", newRoutes) _, err := b.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AdvertiseRoutes: finalRoutes, @@ -6901,8 +7085,8 @@ func coveredRouteRangeNoDefault(finalRoutes []netip.Prefix, ipp netip.Prefix) bo return false } -// UnadvertiseRoute implements the appc.RouteAdvertiser interface. It removes -// a route advertisement if one is present in the existing routes. +// UnadvertiseRoute implements the appctype.RouteAdvertiser interface. It +// removes a route advertisement if one is present in the existing routes. func (b *LocalBackend) UnadvertiseRoute(toRemove ...netip.Prefix) error { currentRoutes := b.Prefs().AdvertiseRoutes().AsSlice() finalRoutes := currentRoutes[:0] @@ -6925,15 +7109,18 @@ func (b *LocalBackend) UnadvertiseRoute(toRemove ...netip.Prefix) error { // namespace a key with the profile manager's current profile key, if any func namespaceKeyForCurrentProfile(pm *profileManager, key ipn.StateKey) ipn.StateKey { - return pm.CurrentProfile().Key + "||" + key + return pm.CurrentProfile().Key() + "||" + key } const routeInfoStateStoreKey ipn.StateKey = "_routeInfo" -func (b *LocalBackend) storeRouteInfo(ri *appc.RouteInfo) error { +func (b *LocalBackend) storeRouteInfo(ri appctype.RouteInfo) error { + if !buildfeatures.HasAppConnectors { + return feature.ErrUnavailable + } b.mu.Lock() defer b.mu.Unlock() - if b.pm.CurrentProfile().ID == "" { + if b.pm.CurrentProfile().ID() == "" { return nil } key := namespaceKeyForCurrentProfile(b.pm, routeInfoStateStoreKey) @@ -6944,13 +7131,16 @@ func (b *LocalBackend) storeRouteInfo(ri *appc.RouteInfo) error { return b.pm.WriteState(key, bs) } -func (b *LocalBackend) readRouteInfoLocked() (*appc.RouteInfo, error) { - if b.pm.CurrentProfile().ID == "" { - return &appc.RouteInfo{}, nil +func (b *LocalBackend) readRouteInfoLocked() (*appctype.RouteInfo, error) { + if !buildfeatures.HasAppConnectors { + return nil, feature.ErrUnavailable + } + if b.pm.CurrentProfile().ID() == "" { + return &appctype.RouteInfo{}, nil } key := namespaceKeyForCurrentProfile(b.pm, routeInfoStateStoreKey) bs, err := b.pm.Store().ReadState(key) - ri := &appc.RouteInfo{} + ri := &appctype.RouteInfo{} if err != nil { return nil, err } @@ -6960,10 +7150,19 @@ func (b *LocalBackend) readRouteInfoLocked() (*appc.RouteInfo, error) { return ri, nil } -// seamlessRenewalEnabled reports whether seamless key renewals are enabled -// (i.e. we saw our self node with the SeamlessKeyRenewal attr in a netmap). -// This enables beta functionality of renewing node keys without breaking -// connections. +// ReadRouteInfo returns the app connector route information that is +// stored in prefs to be consistent across restarts. It should be up +// to date with the RouteInfo in memory being used by appc. +func (b *LocalBackend) ReadRouteInfo() (*appctype.RouteInfo, error) { + b.mu.Lock() + defer b.mu.Unlock() + return b.readRouteInfoLocked() +} + +// seamlessRenewalEnabled reports whether seamless key renewals are enabled. +// +// As of 2025-09-11, this is the default behaviour unless nodes receive +// [tailcfg.NodeAttrDisableSeamlessKeyRenewal] in their netmap. func (b *LocalBackend) seamlessRenewalEnabled() bool { return b.ControlKnobs().SeamlessKeyRenewal.Load() } @@ -6998,48 +7197,60 @@ func allowedAutoRoute(ipp netip.Prefix) bool { return true } -// mayDeref dereferences p if non-nil, otherwise it returns the zero value. -func mayDeref[T any](p *T) (v T) { - if p == nil { - return v - } - return *p -} - var ErrNoPreferredDERP = errors.New("no preferred DERP, try again later") -// suggestExitNodeLocked computes a suggestion based on the current netmap and last netcheck report. If -// there are multiple equally good options, one is selected at random, so the result is not stable. To be -// eligible for consideration, the peer must have NodeAttrSuggestExitNode in its CapMap. +// suggestExitNodeLocked computes a suggestion based on the current netmap and +// other optional factors. If there are multiple equally good options, one may +// be selected at random, so the result is not stable. To be eligible for +// consideration, the peer must have NodeAttrSuggestExitNode in its CapMap. // -// Currently, peers with a DERP home are preferred over those without (typically this means Mullvad). -// Peers are selected based on having a DERP home that is the lowest latency to this device. For peers -// without a DERP home, we look for geographic proximity to this device's DERP home. -// -// netMap is an optional netmap to use that overrides b.netMap (needed for SetControlClientStatus before b.netMap is updated). -// If netMap is nil, then b.netMap is used. -// -// b.mu.lock() must be held. -func (b *LocalBackend) suggestExitNodeLocked(netMap *netmap.NetworkMap) (response apitype.ExitNodeSuggestionResponse, err error) { - // netMap is an optional netmap to use that overrides b.netMap (needed for SetControlClientStatus before b.netMap is updated). If netMap is nil, then b.netMap is used. - if netMap == nil { - netMap = b.netMap +// b.mu must be held. +func (b *LocalBackend) suggestExitNodeLocked() (response apitype.ExitNodeSuggestionResponse, err error) { + if !buildfeatures.HasUseExitNode { + return response, feature.ErrUnavailable } lastReport := b.MagicConn().GetLastNetcheckReport(b.ctx) prevSuggestion := b.lastSuggestedExitNode - res, err := suggestExitNode(lastReport, netMap, prevSuggestion, randomRegion, randomNode, getAllowedSuggestions()) + res, err := suggestExitNode(lastReport, b.currentNode(), prevSuggestion, randomRegion, randomNode, b.getAllowedSuggestions()) if err != nil { return res, err } + if prevSuggestion != res.ID { + // Notify the clients via the IPN bus if the exit node suggestion has changed. + b.sendToLocked(ipn.Notify{SuggestedExitNode: &res.ID}, allClients) + } b.lastSuggestedExitNode = res.ID + return res, err } func (b *LocalBackend) SuggestExitNode() (response apitype.ExitNodeSuggestionResponse, err error) { + if !buildfeatures.HasUseExitNode { + return response, feature.ErrUnavailable + } b.mu.Lock() defer b.mu.Unlock() - return b.suggestExitNodeLocked(nil) + return b.suggestExitNodeLocked() +} + +// getAllowedSuggestions returns a set of exit nodes permitted by the most recent +// [pkey.AllowedSuggestedExitNodes] value. Callers must not mutate the returned set. +func (b *LocalBackend) getAllowedSuggestions() set.Set[tailcfg.StableNodeID] { + b.allowedSuggestedExitNodesMu.Lock() + defer b.allowedSuggestedExitNodesMu.Unlock() + return b.allowedSuggestedExitNodes +} + +// refreshAllowedSuggestions rebuilds the set of permitted exit nodes +// from the current [pkey.AllowedSuggestedExitNodes] value. +func (b *LocalBackend) refreshAllowedSuggestions() { + if !buildfeatures.HasUseExitNode { + return + } + b.allowedSuggestedExitNodesMu.Lock() + defer b.allowedSuggestedExitNodesMu.Unlock() + b.allowedSuggestedExitNodes = fillAllowedSuggestions(b.polc) } // selectRegionFunc returns a DERP region from the slice of candidate regions. @@ -7051,12 +7262,10 @@ type selectRegionFunc func(views.Slice[int]) int // choice. type selectNodeFunc func(nodes views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView -var getAllowedSuggestions = lazy.SyncFunc(fillAllowedSuggestions) - -func fillAllowedSuggestions() set.Set[tailcfg.StableNodeID] { - nodes, err := syspolicy.GetStringArray(syspolicy.AllowedSuggestedExitNodes, nil) +func fillAllowedSuggestions(polc policyclient.Client) set.Set[tailcfg.StableNodeID] { + nodes, err := polc.GetStringArray(pkey.AllowedSuggestedExitNodes, nil) if err != nil { - log.Printf("fillAllowedSuggestions: unable to look up %q policy: %v", syspolicy.AllowedSuggestedExitNodes, err) + log.Printf("fillAllowedSuggestions: unable to look up %q policy: %v", pkey.AllowedSuggestedExitNodes, err) return nil } if nodes == nil { @@ -7069,30 +7278,60 @@ func fillAllowedSuggestions() set.Set[tailcfg.StableNodeID] { return s } -func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSuggestion tailcfg.StableNodeID, selectRegion selectRegionFunc, selectNode selectNodeFunc, allowList set.Set[tailcfg.StableNodeID]) (res apitype.ExitNodeSuggestionResponse, err error) { +// suggestExitNode returns a suggestion for reasonably good exit node based on +// the current netmap and the previous suggestion. +func suggestExitNode(report *netcheck.Report, nb *nodeBackend, prevSuggestion tailcfg.StableNodeID, selectRegion selectRegionFunc, selectNode selectNodeFunc, allowList set.Set[tailcfg.StableNodeID]) (res apitype.ExitNodeSuggestionResponse, err error) { + switch { + case nb.SelfHasCap(tailcfg.NodeAttrTrafficSteering): + // The traffic-steering feature flag is enabled on this tailnet. + return suggestExitNodeUsingTrafficSteering(nb, allowList) + default: + return suggestExitNodeUsingDERP(report, nb, prevSuggestion, selectRegion, selectNode, allowList) + } +} + +// suggestExitNodeUsingDERP is the classic algorithm used to suggest exit nodes, +// before traffic steering was implemented. This handles the plain failover +// case, in addition to the optional Regional Routing. +// +// It computes a suggestion based on the current netmap and last netcheck +// report. If there are multiple equally good options, one is selected at +// random, so the result is not stable. To be eligible for consideration, the +// peer must have NodeAttrSuggestExitNode in its CapMap. +// +// Currently, peers with a DERP home are preferred over those without (typically +// this means Mullvad). Peers are selected based on having a DERP home that is +// the lowest latency to this device. For peers without a DERP home, we look for +// geographic proximity to this device's DERP home. +func suggestExitNodeUsingDERP(report *netcheck.Report, nb *nodeBackend, prevSuggestion tailcfg.StableNodeID, selectRegion selectRegionFunc, selectNode selectNodeFunc, allowList set.Set[tailcfg.StableNodeID]) (res apitype.ExitNodeSuggestionResponse, err error) { + // TODO(sfllaw): Context needs to be plumbed down here to support + // reachability testing. + ctx := context.TODO() + + netMap := nb.NetMap() if report == nil || report.PreferredDERP == 0 || netMap == nil || netMap.DERPMap == nil { return res, ErrNoPreferredDERP } - candidates := make([]tailcfg.NodeView, 0, len(netMap.Peers)) - for _, peer := range netMap.Peers { - if !peer.Valid() { - continue + // Use [nodeBackend.AppendMatchingPeers] instead of the netmap directly, + // since the netmap doesn't include delta updates (e.g., home DERP or Online + // status changes) from the control plane since the last full update. + candidates := nb.AppendMatchingPeers(nil, func(peer tailcfg.NodeView) bool { + if !peer.Valid() || !nb.PeerIsReachable(ctx, peer) { + return false } if allowList != nil && !allowList.Contains(peer.StableID()) { - continue - } - if peer.CapMap().Contains(tailcfg.NodeAttrSuggestExitNode) && tsaddr.ContainsExitRoutes(peer.AllowedIPs()) { - candidates = append(candidates, peer) + return false } - } + return peer.CapMap().Contains(tailcfg.NodeAttrSuggestExitNode) && tsaddr.ContainsExitRoutes(peer.AllowedIPs()) + }) if len(candidates) == 0 { return res, nil } if len(candidates) == 1 { peer := candidates[0] if hi := peer.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } res.ID = peer.StableID() @@ -7112,15 +7351,7 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug } distances := make([]nodeDistance, 0, len(candidates)) for _, c := range candidates { - if c.DERP() != "" { - ipp, err := netip.ParseAddrPort(c.DERP()) - if err != nil { - continue - } - if ipp.Addr() != tailcfg.DerpMagicIPAddr { - continue - } - regionID := int(ipp.Port()) + if regionID := c.HomeDERP(); regionID != 0 { candidatesByRegion[regionID] = append(candidatesByRegion[regionID], c) continue } @@ -7136,10 +7367,10 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug continue } loc := hi.Location() - if loc == nil { + if !loc.Valid() { continue } - distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude, loc.Longitude) + distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude(), loc.Longitude()) if distance < minDistance { minDistance = distance } @@ -7148,9 +7379,9 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug // First, try to select an exit node that has the closest DERP home, based on lastReport's DERP latency. // If there are no latency values, it returns an arbitrary region if len(candidatesByRegion) > 0 { - minRegion := minLatencyDERPRegion(xmaps.Keys(candidatesByRegion), report) + minRegion := minLatencyDERPRegion(slicesx.MapKeys(candidatesByRegion), report) if minRegion == 0 { - minRegion = selectRegion(views.SliceOf(xmaps.Keys(candidatesByRegion))) + minRegion = selectRegion(views.SliceOf(slicesx.MapKeys(candidatesByRegion))) } regionCandidates, ok := candidatesByRegion[minRegion] if !ok { @@ -7160,8 +7391,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7190,8 +7421,105 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc + } + } + return res, nil +} + +var ErrNoNetMap = errors.New("no network map, try again later") + +// suggestExitNodeUsingTrafficSteering uses traffic steering priority scores to +// pick one of the best exit nodes. These priorities are provided by Control in +// the node’s [tailcfg.Location]. To be eligible for consideration, the node +// must have NodeAttrSuggestExitNode in its CapMap. +func suggestExitNodeUsingTrafficSteering(nb *nodeBackend, allowed set.Set[tailcfg.StableNodeID]) (apitype.ExitNodeSuggestionResponse, error) { + // TODO(sfllaw): Context needs to be plumbed down here to support + // reachability testing. + ctx := context.TODO() + + nm := nb.NetMap() + if nm == nil { + return apitype.ExitNodeSuggestionResponse{}, ErrNoNetMap + } + + self := nb.Self() + if !self.Valid() { + return apitype.ExitNodeSuggestionResponse{}, ErrNoNetMap + } + + if !nb.SelfHasCap(tailcfg.NodeAttrTrafficSteering) { + panic("missing traffic-steering capability") + } + + nodes := nb.AppendMatchingPeers(nil, func(p tailcfg.NodeView) bool { + if !p.Valid() { + return false + } + if !nb.PeerIsReachable(ctx, p) { + return false + } + if allowed != nil && !allowed.Contains(p.StableID()) { + return false + } + if !p.CapMap().Contains(tailcfg.NodeAttrSuggestExitNode) { + return false + } + if !tsaddr.ContainsExitRoutes(p.AllowedIPs()) { + return false + } + return true + }) + + scores := make(map[tailcfg.NodeID]int, len(nodes)) + score := func(n tailcfg.NodeView) int { + id := n.ID() + s, ok := scores[id] + if !ok { + s = 0 // score of zero means incomparable + if hi := n.Hostinfo(); hi.Valid() { + if loc := hi.Location(); loc.Valid() { + s = loc.Priority() + } + } + scores[id] = s + } + return s + } + rdvHash := makeRendezvousHasher(self.ID()) + + var pick tailcfg.NodeView + if len(nodes) == 1 { + pick = nodes[0] + } + if len(nodes) > 1 { + // Find the highest scoring exit nodes. + slices.SortFunc(nodes, func(a, b tailcfg.NodeView) int { + c := cmp.Compare(score(b), score(a)) // Highest score first. + if c == 0 { + // Rendezvous hashing for reliably picking the + // same node from a list: tailscale/tailscale#16551. + return cmp.Compare(rdvHash(b.ID()), rdvHash(a.ID())) + } + return c + }) + + // TODO(sfllaw): add a temperature knob so that this client has + // a chance of picking the next best option. + pick = nodes[0] + } + + if !pick.Valid() { + return apitype.ExitNodeSuggestionResponse{}, nil + } + res := apitype.ExitNodeSuggestionResponse{ + ID: pick.StableID(), + Name: pick.Name(), + } + if hi := pick.Hostinfo(); hi.Valid() { + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7207,13 +7535,13 @@ func pickWeighted(candidates []tailcfg.NodeView) []tailcfg.NodeView { continue } loc := hi.Location() - if loc == nil || loc.Priority < maxWeight { + if !loc.Valid() || loc.Priority() < maxWeight { continue } - if maxWeight != loc.Priority { + if maxWeight != loc.Priority() { best = best[:0] } - maxWeight = loc.Priority + maxWeight = loc.Priority() best = append(best, c) } return best @@ -7280,85 +7608,95 @@ func longLatDistance(fromLat, fromLong, toLat, toLong float64) float64 { return earthRadiusMeters * c } -// shouldAutoExitNode checks for the auto exit node MDM policy. -func shouldAutoExitNode() bool { - exitNodeIDStr, _ := syspolicy.GetString(syspolicy.ExitNodeID, "") - return exitNodeIDStr == "auto:any" +// makeRendezvousHasher returns a function that hashes a node ID to a uint64. +// https://en.wikipedia.org/wiki/Rendezvous_hashing +func makeRendezvousHasher(seed tailcfg.NodeID) func(tailcfg.NodeID) uint64 { + en := binary.BigEndian + return func(n tailcfg.NodeID) uint64 { + var b [16]byte + en.PutUint64(b[:], uint64(seed)) + en.PutUint64(b[8:], uint64(n)) + v := sha256.Sum256(b[:]) + return en.Uint64(v[:]) + } } -// startAutoUpdate triggers an auto-update attempt. The actual update happens -// asynchronously. If another update is in progress, an error is returned. -func (b *LocalBackend) startAutoUpdate(logPrefix string) (retErr error) { - // Check if update was already started, and mark as started. - if !b.trySetC2NUpdateStarted() { - return errors.New("update already started") - } - defer func() { - // Clear the started flag if something failed. - if retErr != nil { - b.setC2NUpdateStarted(false) - } - }() +const ( + // unresolvedExitNodeID is a special [tailcfg.StableNodeID] value + // used as an exit node ID to install a blackhole route, preventing + // accidental non-exit-node usage until the [ipn.ExitNodeExpression] + // is evaluated and an actual exit node is selected. + // + // We use "auto:any" for compatibility with older, pre-[ipn.ExitNodeExpression] + // clients that have been using "auto:any" for this purpose for a long time. + unresolvedExitNodeID tailcfg.StableNodeID = "auto:any" +) - cmdTS, err := findCmdTailscale() - if err != nil { - return fmt.Errorf("failed to find cmd/tailscale binary: %w", err) - } - var ver struct { - Long string `json:"long"` - } - out, err := exec.Command(cmdTS, "version", "--json").Output() - if err != nil { - return fmt.Errorf("failed to find cmd/tailscale binary: %w", err) - } - if err := json.Unmarshal(out, &ver); err != nil { - return fmt.Errorf("invalid JSON from cmd/tailscale version --json: %w", err) +func isAllowedAutoExitNodeID(polc policyclient.Client, exitNodeID tailcfg.StableNodeID) bool { + if exitNodeID == "" { + return false // an exit node is required } - if ver.Long != version.Long() { - return fmt.Errorf("cmd/tailscale version %q does not match tailscaled version %q", ver.Long, version.Long()) - } - - cmd := tailscaleUpdateCmd(cmdTS) - buf := new(bytes.Buffer) - cmd.Stdout = buf - cmd.Stderr = buf - b.logf("%s: running %q", logPrefix, strings.Join(cmd.Args, " ")) - if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start cmd/tailscale update: %w", err) + if nodes, _ := polc.GetStringArray(pkey.AllowedSuggestedExitNodes, nil); nodes != nil { + return slices.Contains(nodes, string(exitNodeID)) } - - go func() { - if err := cmd.Wait(); err != nil { - b.logf("%s: update command failed: %v, output: %s", logPrefix, err, buf) - } else { - b.logf("%s: update attempt complete", logPrefix) - } - b.setC2NUpdateStarted(false) - }() - return nil + return true // no policy configured; allow all exit nodes } // srcIPHasCapForFilter is called by the packet filter when evaluating firewall // rules that require a source IP to have a certain node capability. // // TODO(bradfitz): optimize this later if/when it matters. +// TODO(nickkhyl): move this into [nodeBackend] along with [LocalBackend.updateFilterLocked]. func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool { if cap == "" { // Shouldn't happen, but just in case. // But the empty cap also shouldn't be found in Node.CapMap. return false } - - b.mu.Lock() - defer b.mu.Unlock() - - nodeID, ok := b.nodeByAddr[srcIP] + cn := b.currentNode() + nodeID, ok := cn.NodeByAddr(srcIP) if !ok { return false } - n, ok := b.peers[nodeID] + n, ok := cn.NodeByID(nodeID) if !ok { return false } return n.HasCap(cap) } + +// maybeUsernameOf returns the actor's username if the actor +// is non-nil and its username can be resolved. +func maybeUsernameOf(actor ipnauth.Actor) string { + var username string + if actor != nil { + username, _ = actor.Username() + } + return username +} + +var ( + metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") +) + +func (b *LocalBackend) stateEncrypted() opt.Bool { + switch runtime.GOOS { + case "android", "ios": + return opt.NewBool(true) + case "darwin": + switch { + case version.IsMacAppStore(): + return opt.NewBool(true) + case version.IsMacSysExt(): + sp, _ := b.polc.GetBoolean(pkey.EncryptState, true) + return opt.NewBool(sp) + default: + // Probably self-compiled tailscaled, we don't use the Keychain + // there. + return opt.NewBool(false) + } + default: + _, ok := b.store.(ipn.EncryptedStateStore) + return opt.NewBool(ok) + } +} diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index b0e12d500..f17fabb60 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -5,32 +5,42 @@ package ipnlocal import ( "context" + "encoding/binary" "encoding/json" "errors" "fmt" + "maps" "math" "net" "net/http" "net/netip" "os" + "path/filepath" "reflect" "slices" + "strings" "sync" + "sync/atomic" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + memro "go4.org/mem" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/appc" "tailscale.com/appc/appctest" - "tailscale.com/clientupdate" "tailscale.com/control/controlclient" "tailscale.com/drive" "tailscale.com/drive/driveimpl" + "tailscale.com/feature" + _ "tailscale.com/feature/condregister/portmapper" "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" + "tailscale.com/ipn/conffile" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/store/mem" "tailscale.com/net/netcheck" "tailscale.com/net/netmon" @@ -39,26 +49,35 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/tstest/deptest" + "tailscale.com/tstest/typewalk" + "tailscale.com/types/appctype" "tailscale.com/types/dnstype" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/opt" + "tailscale.com/types/persist" "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policytest" + "tailscale.com/util/syspolicy/source" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/filter/filtertype" "tailscale.com/wgengine/wgcfg" ) -func fakeStoreRoutes(*appc.RouteInfo) error { return nil } - func inRemove(ip netip.Addr) bool { for _, pfx := range removeFromDefaultRoute { if pfx.Contains(ip) { @@ -68,6 +87,18 @@ func inRemove(ip netip.Addr) bool { return false } +func makeNodeKeyFromID(nodeID tailcfg.NodeID) key.NodePublic { + raw := make([]byte, 32) + binary.BigEndian.PutUint64(raw[24:], uint64(nodeID)) + return key.NodePublicFromRaw32(memro.B(raw)) +} + +func makeDiscoKeyFromID(nodeID tailcfg.NodeID) (ret key.DiscoPublic) { + raw := make([]byte, 32) + binary.BigEndian.PutUint64(raw[24:], uint64(nodeID)) + return key.DiscoPublicFromRaw32(memro.B(raw)) +} + func TestShrinkDefaultRoute(t *testing.T) { tests := []struct { route string @@ -428,20 +459,39 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) { } func newTestLocalBackend(t testing.TB) *LocalBackend { + bus := eventbustest.NewBus(t) + return newTestLocalBackendWithSys(t, tsd.NewSystemWithBus(bus)) +} + +// newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System. +// If the state store or engine are not set in sys, they will be set to a new +// in-memory store and fake userspace engine, respectively. +func newTestLocalBackendWithSys(t testing.TB, sys *tsd.System) *LocalBackend { var logf logger.Logf = logger.Discard - sys := new(tsd.System) - store := new(mem.Store) - sys.Set(store) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) + if _, ok := sys.StateStore.GetOK(); !ok { + sys.Set(new(mem.Store)) + t.Log("Added memory store for testing") + } + if _, ok := sys.Engine.GetOK(); !ok { + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(eng.Close) + sys.Set(eng) + t.Log("Added fake userspace engine for testing") + } + if _, ok := sys.Dialer.GetOK(); !ok { + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(sys.Bus.Get()) + sys.Set(dialer) + t.Log("Added static dialer for testing") } - t.Cleanup(eng.Close) - sys.Set(eng) lb, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(lb.Shutdown) return lb } @@ -466,30 +516,30 @@ func TestLazyMachineKeyGeneration(t *testing.T) { func TestZeroExitNodeViaLocalAPI(t *testing.T) { lb := newTestLocalBackend(t) + user := &ipnauth.TestActor{} // Give it an initial exit node in use. - if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ + if _, err := lb.EditPrefsAs(&ipn.MaskedPrefs{ ExitNodeIDSet: true, Prefs: ipn.Prefs{ ExitNodeID: "foo", }, - }); err != nil { + }, user); err != nil { t.Fatalf("enabling first exit node: %v", err) } // SetUseExitNodeEnabled(false) "remembers" the prior exit node. - if _, err := lb.SetUseExitNodeEnabled(false); err != nil { + if _, err := lb.SetUseExitNodeEnabled(user, false); err != nil { t.Fatal("expected failure") } // Zero the exit node - pv, err := lb.EditPrefs(&ipn.MaskedPrefs{ + pv, err := lb.EditPrefsAs(&ipn.MaskedPrefs{ ExitNodeIDSet: true, Prefs: ipn.Prefs{ ExitNodeID: "", }, - }) - + }, user) if err != nil { t.Fatalf("enabling first exit node: %v", err) } @@ -499,34 +549,34 @@ func TestZeroExitNodeViaLocalAPI(t *testing.T) { if got, want := pv.InternalExitNodePrior(), tailcfg.StableNodeID(""); got != want { t.Fatalf("unexpected InternalExitNodePrior %q, want: %q", got, want) } - } func TestSetUseExitNodeEnabled(t *testing.T) { lb := newTestLocalBackend(t) + user := &ipnauth.TestActor{} // Can't turn it on if it never had an old value. - if _, err := lb.SetUseExitNodeEnabled(true); err == nil { + if _, err := lb.SetUseExitNodeEnabled(user, true); err == nil { t.Fatal("expected success") } // But we can turn it off when it's already off. - if _, err := lb.SetUseExitNodeEnabled(false); err != nil { + if _, err := lb.SetUseExitNodeEnabled(user, false); err != nil { t.Fatal("expected failure") } // Give it an initial exit node in use. - if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ + if _, err := lb.EditPrefsAs(&ipn.MaskedPrefs{ ExitNodeIDSet: true, Prefs: ipn.Prefs{ ExitNodeID: "foo", }, - }); err != nil { + }, user); err != nil { t.Fatalf("enabling first exit node: %v", err) } // Now turn off that exit node. - if prefs, err := lb.SetUseExitNodeEnabled(false); err != nil { + if prefs, err := lb.SetUseExitNodeEnabled(user, false); err != nil { t.Fatal("expected failure") } else { if g, w := prefs.ExitNodeID(), tailcfg.StableNodeID(""); g != w { @@ -538,7 +588,7 @@ func TestSetUseExitNodeEnabled(t *testing.T) { } // And turn it back on. - if prefs, err := lb.SetUseExitNodeEnabled(true); err != nil { + if prefs, err := lb.SetUseExitNodeEnabled(user, true); err != nil { t.Fatal("expected failure") } else { if g, w := prefs.ExitNodeID(), tailcfg.StableNodeID("foo"); g != w { @@ -550,3450 +600,6618 @@ func TestSetUseExitNodeEnabled(t *testing.T) { } // Verify we block setting an Internal field. - if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ + if _, err := lb.EditPrefsAs(&ipn.MaskedPrefs{ InternalExitNodePriorSet: true, - }); err == nil { + }, user); err == nil { t.Fatalf("unexpected success; want an error trying to set an internal field") } } -func TestFileTargets(t *testing.T) { - b := new(LocalBackend) - _, err := b.FileTargets() - if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { - t.Errorf("before connect: got %q; want %q", got, want) - } - - b.netMap = new(netmap.NetworkMap) - _, err = b.FileTargets() - if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { - t.Errorf("non-running netmap: got %q; want %q", got, want) - } - - b.state = ipn.Running - _, err = b.FileTargets() - if got, want := fmt.Sprint(err), "file sharing not enabled by Tailscale admin"; got != want { - t.Errorf("without cap: got %q; want %q", got, want) - } - - b.capFileSharing = true - got, err := b.FileTargets() - if err != nil { - t.Fatal(err) - } - if len(got) != 0 { - t.Fatalf("unexpected %d peers", len(got)) - } - - var peerMap map[tailcfg.NodeID]tailcfg.NodeView - mak.NonNil(&peerMap) - var nodeID tailcfg.NodeID - nodeID = 1234 - peer := &tailcfg.Node{ - ID: 1234, - Hostinfo: (&tailcfg.Hostinfo{OS: "tvOS"}).View(), - } - peerMap[nodeID] = peer.View() - b.peers = peerMap - got, err = b.FileTargets() - if err != nil { - t.Fatal(err) - } - if len(got) != 0 { - t.Fatalf("unexpected %d peers", len(got)) - } - // (other cases handled by TestPeerAPIBase above) +func makeExitNode(id tailcfg.NodeID, opts ...peerOptFunc) tailcfg.NodeView { + return makePeer(id, append([]peerOptFunc{withCap(26), withSuggest(), withExitRoutes()}, opts...)...) } -func TestInternalAndExternalInterfaces(t *testing.T) { - type interfacePrefix struct { - i netmon.Interface - pfx netip.Prefix - } +func TestConfigureExitNode(t *testing.T) { + controlURL := "https://localhost:1/" + exitNode1 := makeExitNode(1, withName("node-1"), withDERP(1), withAddresses(netip.MustParsePrefix("100.64.1.1/32"))) + exitNode2 := makeExitNode(2, withName("node-2"), withDERP(2), withAddresses(netip.MustParsePrefix("100.64.1.2/32"))) + selfNode := makeExitNode(3, withName("node-3"), withDERP(1), withAddresses(netip.MustParsePrefix("100.64.1.3/32"))) + clientNetmap := buildNetmapWithPeers(selfNode, exitNode1, exitNode2) - masked := func(ips ...interfacePrefix) (pfxs []netip.Prefix) { - for _, ip := range ips { - pfxs = append(pfxs, ip.pfx.Masked()) - } - return pfxs - } - iList := func(ips ...interfacePrefix) (il netmon.InterfaceList) { - for _, ip := range ips { - il = append(il, ip.i) - } - return il - } - newInterface := func(name, pfx string, wsl2, loopback bool) interfacePrefix { - ippfx := netip.MustParsePrefix(pfx) - ip := netmon.Interface{ - Interface: &net.Interface{}, - AltAddrs: []net.Addr{ - netipx.PrefixIPNet(ippfx), - }, - } - if loopback { - ip.Flags = net.FlagLoopback - } - if wsl2 { - ip.HardwareAddr = []byte{0x00, 0x15, 0x5d, 0x00, 0x00, 0x00} - } - return interfacePrefix{i: ip, pfx: ippfx} + report := &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 5 * time.Millisecond, + 2: 10 * time.Millisecond, + }, + PreferredDERP: 1, } - var ( - en0 = newInterface("en0", "10.20.2.5/16", false, false) - en1 = newInterface("en1", "192.168.1.237/24", false, false) - wsl = newInterface("wsl", "192.168.5.34/24", true, false) - loopback = newInterface("lo0", "127.0.0.1/8", false, true) - ) tests := []struct { - name string - goos string - il netmon.InterfaceList - wantInt []netip.Prefix - wantExt []netip.Prefix + name string + prefs ipn.Prefs + netMap *netmap.NetworkMap + report *netcheck.Report + changePrefs *ipn.MaskedPrefs + useExitNodeEnabled *bool + exitNodeIDPolicy *tailcfg.StableNodeID + exitNodeIPPolicy *netip.Addr + exitNodeAllowedIDs []tailcfg.StableNodeID // nil if all IDs are allowed for auto exit nodes + exitNodeAllowOverride bool // whether [pkey.AllowExitNodeOverride] should be set to true + wantChangePrefsErr error // if non-nil, the error we expect from [LocalBackend.EditPrefsAs] + wantPrefs ipn.Prefs + wantExitNodeToggleErr error // if non-nil, the error we expect from [LocalBackend.SetUseExitNodeEnabled] + wantHostinfoExitNodeID tailcfg.StableNodeID }{ { - name: "single-interface", - goos: "linux", - il: iList( - en0, - loopback, - ), - wantInt: masked(loopback), - wantExt: masked(en0), + name: "exit-node-id-via-prefs", // set exit node ID via prefs + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ExitNodeID: exitNode1.StableID()}, + ExitNodeIDSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, { - name: "multiple-interfaces", - goos: "linux", - il: iList( - en0, - en1, - wsl, - loopback, - ), - wantInt: masked(loopback), - wantExt: masked(en0, en1, wsl), + name: "exit-node-ip-via-prefs", // set exit node IP via prefs (should be resolved to an ID) + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ExitNodeIP: exitNode1.Addresses().At(0).Addr()}, + ExitNodeIPSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, { - name: "wsl2", - goos: "windows", - il: iList( - en0, - en1, - wsl, - loopback, - ), - wantInt: masked(loopback, wsl), - wantExt: masked(en0, en1), + name: "auto-exit-node-via-prefs/any", // set auto exit node via prefs + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{AutoExitNode: "any"}, + AutoExitNodeSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - gotInt, gotExt, err := internalAndExternalInterfacesFrom(tc.il, tc.goos) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(gotInt, tc.wantInt) { - t.Errorf("unexpected internal prefixes\ngot %v\nwant %v", gotInt, tc.wantInt) - } - if !reflect.DeepEqual(gotExt, tc.wantExt) { - t.Errorf("unexpected external prefixes\ngot %v\nwant %v", gotExt, tc.wantExt) - } - }) - } -} - -func TestPacketFilterPermitsUnlockedNodes(t *testing.T) { - tests := []struct { - name string - peers []*tailcfg.Node - filter []filter.Match - want bool - }{ { - name: "empty", - want: false, + name: "auto-exit-node-via-prefs/set-exit-node-id-via-prefs", // setting exit node ID explicitly should disable auto exit node + prefs: ipn.Prefs{ + ControlURL: controlURL, + AutoExitNode: "any", + ExitNodeID: exitNode1.StableID(), + }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ExitNodeID: exitNode2.StableID()}, + ExitNodeIDSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), + AutoExitNode: "", // should be unset + }, + wantHostinfoExitNodeID: exitNode2.StableID(), }, { - name: "no-unsigned", - peers: []*tailcfg.Node{ - {ID: 1}, + name: "auto-exit-node-via-prefs/any/no-report", // set auto exit node via prefs, but no report means we can't resolve the exit node ID + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - want: false, + netMap: clientNetmap, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{AutoExitNode: "any"}, + AutoExitNodeSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: unresolvedExitNodeID, // cannot resolve; traffic will be dropped + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: "", }, { - name: "unsigned-good", - peers: []*tailcfg.Node{ - {ID: 1, UnsignedPeerAPIOnly: true}, + name: "auto-exit-node-via-prefs/any/no-netmap", // similarly, but without a netmap (no exit node should be selected) + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - want: false, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{AutoExitNode: "any"}, + AutoExitNodeSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: unresolvedExitNodeID, // cannot resolve; traffic will be dropped + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: "", }, { - name: "unsigned-bad", - peers: []*tailcfg.Node{ - { - ID: 1, - UnsignedPeerAPIOnly: true, - AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.0/32"), - }, - }, + name: "auto-exit-node-via-prefs/foo", // set auto exit node via prefs with an unknown/unsupported expression + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - filter: []filter.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/32")}, - Dsts: []filter.NetPortRange{ - { - Net: netip.MustParsePrefix("100.99.0.0/32"), - }, - }, - }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{AutoExitNode: "foo"}, + AutoExitNodeSet: true, }, - want: true, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // unknown exit node expressions should work as "any" + AutoExitNode: "foo", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, { - name: "unsigned-bad-src-is-superset", - peers: []*tailcfg.Node{ - { - ID: 1, - UnsignedPeerAPIOnly: true, - AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.0/32"), - }, - }, + name: "auto-exit-node-via-prefs/off", // toggle the exit node off after it was set to "any" + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - filter: []filter.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/24")}, - Dsts: []filter.NetPortRange{ - { - Net: netip.MustParsePrefix("100.99.0.0/32"), - }, - }, - }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{AutoExitNode: "any"}, + AutoExitNodeSet: true, }, - want: true, + useExitNodeEnabled: ptr.To(false), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: "", + AutoExitNode: "", + InternalExitNodePrior: "auto:any", + }, + wantHostinfoExitNodeID: "", }, { - name: "unsigned-okay-because-no-dsts", - peers: []*tailcfg.Node{ - { - ID: 1, - UnsignedPeerAPIOnly: true, - AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.0/32"), - }, - }, + name: "auto-exit-node-via-prefs/on", // toggle the exit node on + prefs: ipn.Prefs{ + ControlURL: controlURL, + InternalExitNodePrior: "auto:any", }, - filter: []filter.Match{ - { - Srcs: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/32")}, - Caps: []filter.CapMatch{ - { - Dst: netip.MustParsePrefix("100.99.0.0/32"), - Cap: "foo", - }, - }, + netMap: clientNetmap, + report: report, + useExitNodeEnabled: ptr.To(true), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + AutoExitNode: "any", + InternalExitNodePrior: "auto:any", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + }, + { + name: "id-via-policy", // set exit node ID via syspolicy + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + }, + { + name: "id-via-policy/cannot-override-via-prefs/by-id", // syspolicy should take precedence over prefs + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: exitNode2.StableID(), // this should be ignored }, + ExitNodeIDSet: true, }, - want: false, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + wantChangePrefsErr: errManagedByPolicy, }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := packetFilterPermitsUnlockedNodes(peersMap(nodeViews(tt.peers)), tt.filter); got != tt.want { - t.Errorf("got %v, want %v", got, tt.want) - } - }) - } -} - -func TestStatusPeerCapabilities(t *testing.T) { - tests := []struct { - name string - peers []tailcfg.NodeView - expectedPeerCapabilities map[tailcfg.StableNodeID][]tailcfg.NodeCapability - expectedPeerCapMap map[tailcfg.StableNodeID]tailcfg.NodeCapMap - }{ { - name: "peers-with-capabilities", - peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - ID: 1, - StableID: "foo", - IsWireGuardOnly: true, - Hostinfo: (&tailcfg.Hostinfo{}).View(), - Capabilities: []tailcfg.NodeCapability{tailcfg.CapabilitySSH}, - CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ - tailcfg.CapabilitySSH: nil, - }), - }).View(), - (&tailcfg.Node{ - ID: 2, - StableID: "bar", - Hostinfo: (&tailcfg.Hostinfo{}).View(), - Capabilities: []tailcfg.NodeCapability{tailcfg.CapabilityAdmin}, - CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ - tailcfg.CapabilityAdmin: {`{"test": "true}`}, - }), - }).View(), + name: "id-via-policy/cannot-override-via-prefs/by-ip", // syspolicy should take precedence over prefs + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - expectedPeerCapabilities: map[tailcfg.StableNodeID][]tailcfg.NodeCapability{ - tailcfg.StableNodeID("foo"): {tailcfg.CapabilitySSH}, - tailcfg.StableNodeID("bar"): {tailcfg.CapabilityAdmin}, + netMap: clientNetmap, + exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeIP: exitNode2.Addresses().At(0).Addr(), // this should be ignored + }, + ExitNodeIPSet: true, }, - expectedPeerCapMap: map[tailcfg.StableNodeID]tailcfg.NodeCapMap{ - tailcfg.StableNodeID("foo"): (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ - tailcfg.CapabilitySSH: nil, - }), - tailcfg.StableNodeID("bar"): (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ - tailcfg.CapabilityAdmin: {`{"test": "true}`}, - }), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), }, + wantHostinfoExitNodeID: exitNode1.StableID(), + wantChangePrefsErr: errManagedByPolicy, }, { - name: "peers-without-capabilities", - peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - ID: 1, - StableID: "foo", - IsWireGuardOnly: true, - Hostinfo: (&tailcfg.Hostinfo{}).View(), - }).View(), - (&tailcfg.Node{ - ID: 2, - StableID: "bar", - Hostinfo: (&tailcfg.Hostinfo{}).View(), - }).View(), + name: "id-via-policy/cannot-override-via-prefs/by-auto-expr", // syspolicy should take precedence over prefs + prefs: ipn.Prefs{ + ControlURL: controlURL, }, + netMap: clientNetmap, + exitNodeIDPolicy: ptr.To(exitNode1.StableID()), + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AutoExitNode: "any", // this should be ignored + }, + AutoExitNodeSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + wantChangePrefsErr: errManagedByPolicy, }, - } - b := newTestLocalBackend(t) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b.setNetMapLocked(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - MachineAuthorized: true, - Addresses: ipps("100.101.101.101"), - }).View(), - Peers: tt.peers, - }) - got := b.Status() - for _, peer := range got.Peer { - if !reflect.DeepEqual(peer.Capabilities, tt.expectedPeerCapabilities[peer.ID]) { - t.Errorf("peer capabilities: expected %v got %v", tt.expectedPeerCapabilities, peer.Capabilities) - } - if !reflect.DeepEqual(peer.CapMap, tt.expectedPeerCapMap[peer.ID]) { - t.Errorf("peer capmap: expected %v got %v", tt.expectedPeerCapMap, peer.CapMap) - } - } - }) - } -} - -// legacyBackend was the interface between Tailscale frontends -// (e.g. cmd/tailscale, iOS/MacOS/Windows GUIs) and the tailscale -// backend (e.g. cmd/tailscaled) running on the same machine. -// (It has nothing to do with the interface between the backends -// and the cloud control plane.) -type legacyBackend interface { - // SetNotifyCallback sets the callback to be called on updates - // from the backend to the client. - SetNotifyCallback(func(ipn.Notify)) - // Start starts or restarts the backend, typically when a - // frontend client connects. - Start(ipn.Options) error -} - -// Verify that LocalBackend still implements the legacyBackend interface -// for now, at least until the macOS and iOS clients move off of it. -var _ legacyBackend = (*LocalBackend)(nil) - -func TestWatchNotificationsCallbacks(t *testing.T) { - b := new(LocalBackend) - n := new(ipn.Notify) - b.WatchNotifications(context.Background(), 0, func() { - b.mu.Lock() - defer b.mu.Unlock() - - // Ensure a watcher has been installed. - if len(b.notifyWatchers) != 1 { - t.Fatalf("unexpected number of watchers in new LocalBackend, want: 1 got: %v", len(b.notifyWatchers)) - } - // Send a notification. Range over notifyWatchers to get the channel - // because WatchNotifications doesn't expose the handle for it. - for _, sess := range b.notifyWatchers { - select { - case sess.ch <- n: - default: - t.Fatalf("could not send notification") - } - } - }, func(roNotify *ipn.Notify) bool { - if roNotify != n { - t.Fatalf("unexpected notification received. want: %v got: %v", n, roNotify) - } - return false - }) - - // Ensure watchers have been cleaned up. - b.mu.Lock() - defer b.mu.Unlock() - if len(b.notifyWatchers) != 0 { - t.Fatalf("unexpected number of watchers in new LocalBackend, want: 0 got: %v", len(b.notifyWatchers)) - } -} - -// tests LocalBackend.updateNetmapDeltaLocked -func TestUpdateNetmapDelta(t *testing.T) { - b := newTestLocalBackend(t) - if b.updateNetmapDeltaLocked(nil) { - t.Errorf("updateNetmapDeltaLocked() = true, want false with nil netmap") - } - - b.netMap = &netmap.NetworkMap{} - for i := range 5 { - b.netMap.Peers = append(b.netMap.Peers, (&tailcfg.Node{ID: (tailcfg.NodeID(i) + 1)}).View()) - } - b.updatePeersFromNetmapLocked(b.netMap) - - someTime := time.Unix(123, 0) - muts, ok := netmap.MutationsFromMapResponse(&tailcfg.MapResponse{ - PeersChangedPatch: []*tailcfg.PeerChange{ - { - NodeID: 1, - DERPRegion: 1, + { + name: "ip-via-policy", // set exit node IP via syspolicy (should be resolved to an ID) + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - { - NodeID: 2, - Online: ptr.To(true), + netMap: clientNetmap, + exitNodeIPPolicy: ptr.To(exitNode2.Addresses().At(0).Addr()), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), }, - { - NodeID: 3, - Online: ptr.To(false), + wantHostinfoExitNodeID: exitNode2.StableID(), + }, + { + name: "auto-any-via-policy", // set auto exit node via syspolicy (an exit node should be selected) + prefs: ipn.Prefs{ + ControlURL: controlURL, }, - { - NodeID: 4, - LastSeen: ptr.To(someTime), + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), + AutoExitNode: "any", }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, - }, someTime) - if !ok { - t.Fatal("netmap.MutationsFromMapResponse failed") - } - - if !b.updateNetmapDeltaLocked(muts) { - t.Fatalf("updateNetmapDeltaLocked() = false, want true with new netmap") - } - - wants := []*tailcfg.Node{ { - ID: 1, - DERP: "127.3.3.40:1", + name: "auto-any-via-policy/no-report", // set auto exit node via syspolicy without a netcheck report (no exit node should be selected) + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: nil, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: unresolvedExitNodeID, + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: "", }, { - ID: 2, - Online: ptr.To(true), + name: "auto-any-via-policy/no-netmap", // similarly, but without a netmap (no exit node should be selected) + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: nil, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: unresolvedExitNodeID, + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: "", }, { - ID: 3, - Online: ptr.To(false), + name: "auto-any-via-policy/no-netmap/with-existing", // set auto exit node via syspolicy without a netmap, but with a previously set exit node ID + prefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), // should be retained + }, + netMap: nil, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowedIDs: nil, // not configured, so all exit node IDs are implicitly allowed + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: exitNode2.StableID(), }, { - ID: 4, - LastSeen: ptr.To(someTime), + name: "auto-any-via-policy/no-netmap/with-allowed-existing", // same, but now with a syspolicy setting that explicitly allows the existing exit node ID + prefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), // should be retained + }, + netMap: nil, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowedIDs: []tailcfg.StableNodeID{ + exitNode2.StableID(), // the current exit node ID is allowed + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: exitNode2.StableID(), }, - } - for _, want := range wants { - gotv, ok := b.peers[want.ID] - if !ok { - t.Errorf("netmap.Peer %v missing from b.peers", want.ID) - continue - } - got := gotv.AsStruct() - if !reflect.DeepEqual(got, want) { - t.Errorf("netmap.Peer %v wrong.\n got: %v\nwant: %v", want.ID, logger.AsJSON(got), logger.AsJSON(want)) - } - } -} - -// tests WhoIs and indirectly that setNetMapLocked updates b.nodeByAddr correctly. -func TestWhoIs(t *testing.T) { - b := newTestLocalBackend(t) - b.setNetMapLocked(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - ID: 1, - User: 10, - Addresses: []netip.Prefix{netip.MustParsePrefix("100.101.102.103/32")}, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - ID: 2, - User: 20, - Addresses: []netip.Prefix{netip.MustParsePrefix("100.200.200.200/32")}, - }).View(), + { + name: "auto-any-via-policy/no-netmap/with-disallowed-existing", // same, but now with a syspolicy setting that does not allow the existing exit node ID + prefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), // not allowed by [pkey.AllowedSuggestedExitNodes] + }, + netMap: nil, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowedIDs: []tailcfg.StableNodeID{ + exitNode1.StableID(), // a different exit node ID; the current one is not allowed + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: unresolvedExitNodeID, // we don't have a netmap yet, and the current exit node ID is not allowed; block traffic + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: "", }, - UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{ - 10: { - DisplayName: "Myself", + { + name: "auto-any-via-policy/with-netmap/with-allowed-existing", // same, but now with a syspolicy setting that does not allow the existing exit node ID + prefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // not allowed by [pkey.AllowedSuggestedExitNodes] }, - 20: { - DisplayName: "Peer", + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowedIDs: []tailcfg.StableNodeID{ + exitNode2.StableID(), // a different exit node ID; the current one is not allowed + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), // we have a netmap; switch to the best allowed exit node + AutoExitNode: "any", }, + wantHostinfoExitNodeID: exitNode2.StableID(), }, - }) - tests := []struct { - q string - want tailcfg.NodeID // 0 means want ok=false - wantName string - }{ - {"100.101.102.103:0", 1, "Myself"}, - {"100.101.102.103:123", 1, "Myself"}, - {"100.200.200.200:0", 2, "Peer"}, - {"100.200.200.200:123", 2, "Peer"}, - {"100.4.0.4:404", 0, ""}, - } - for _, tt := range tests { - t.Run(tt.q, func(t *testing.T) { - nv, up, ok := b.WhoIs("", netip.MustParseAddrPort(tt.q)) - var got tailcfg.NodeID - if ok { - got = nv.ID() - } - if got != tt.want { - t.Errorf("got nodeID %v; want %v", got, tt.want) - } - if up.DisplayName != tt.wantName { - t.Errorf("got name %q; want %q", up.DisplayName, tt.wantName) - } - }) - } -} - -func TestWireguardExitNodeDNSResolvers(t *testing.T) { - type tc struct { - name string - id tailcfg.StableNodeID - peers []*tailcfg.Node - wantOK bool - wantResolvers []*dnstype.Resolver - } - - tests := []tc{ { - name: "no peers", - id: "1", - wantOK: false, - wantResolvers: nil, + name: "auto-any-via-policy/with-netmap/switch-to-better", // if all exit nodes are allowed, switch to the best one once we have a netmap + prefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), + }, + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // switch to the best exit node + AutoExitNode: "any", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, { - name: "non wireguard peer", - id: "1", - peers: []*tailcfg.Node{ - { - ID: 1, - StableID: "1", - IsWireGuardOnly: false, - ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + name: "auto-foo-via-policy", // set auto exit node via syspolicy with an unknown/unsupported expression + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:foo")), + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // unknown exit node expressions should work as "any" + AutoExitNode: "foo", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + }, + { + name: "auto-foo-via-edit-prefs", // set auto exit node via EditPrefs with an unknown/unsupported expression + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{AutoExitNode: "foo"}, + AutoExitNodeSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // unknown exit node expressions should work as "any" + AutoExitNode: "foo", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + }, + { + name: "auto-any-via-policy/toggle-off", // cannot toggle off the exit node if it was set via syspolicy + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + useExitNodeEnabled: ptr.To(false), // should fail with an error + wantExitNodeToggleErr: errManagedByPolicy, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // still enforced by the policy setting + AutoExitNode: "any", + InternalExitNodePrior: "", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + }, + { + name: "auto-any-via-policy/allow-override/change", // changing the exit node is allowed by [pkey.AllowExitNodeOverride] + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowOverride: true, // allow changing the exit node + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: exitNode2.StableID(), // change the exit node ID }, + ExitNodeIDSet: true, }, - wantOK: false, - wantResolvers: nil, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode2.StableID(), // overridden by user + AutoExitNode: "", // cleared, as we are setting the exit node ID explicitly + }, + wantHostinfoExitNodeID: exitNode2.StableID(), }, { - name: "no matching IDs", - id: "2", - peers: []*tailcfg.Node{ - { - ID: 1, - StableID: "1", - IsWireGuardOnly: true, - ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + name: "auto-any-via-policy/allow-override/clear", // clearing the exit node ID is not allowed by [pkey.AllowExitNodeOverride] + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowOverride: true, // allow changing, but not disabling, the exit node + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: "", // clearing the exit node ID disables the exit node and should not be allowed }, + ExitNodeIDSet: true, }, - wantOK: false, - wantResolvers: nil, + wantChangePrefsErr: errManagedByPolicy, // edit prefs should fail with an error + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // still enforced by the policy setting + AutoExitNode: "any", + InternalExitNodePrior: "", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, { - name: "wireguard peer", - id: "1", - peers: []*tailcfg.Node{ - { - ID: 1, - StableID: "1", - IsWireGuardOnly: true, - ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + name: "auto-any-via-policy/allow-override/toggle-off", // similarly, toggling off the exit node is not allowed even with [pkey.AllowExitNodeOverride] + prefs: ipn.Prefs{ + ControlURL: controlURL, + }, + netMap: clientNetmap, + report: report, + exitNodeIDPolicy: ptr.To(tailcfg.StableNodeID("auto:any")), + exitNodeAllowOverride: true, // allow changing, but not disabling, the exit node + useExitNodeEnabled: ptr.To(false), // should fail with an error + wantExitNodeToggleErr: errManagedByPolicy, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + ExitNodeID: exitNode1.StableID(), // still enforced by the policy setting + AutoExitNode: "any", + InternalExitNodePrior: "", + }, + wantHostinfoExitNodeID: exitNode1.StableID(), + }, + { + name: "auto-any-via-initial-prefs/no-netmap/clear-auto-exit-node", + prefs: ipn.Prefs{ + ControlURL: controlURL, + AutoExitNode: ipn.AnyExitNode, + }, + netMap: nil, // no netmap; exit node cannot be resolved + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AutoExitNode: "", // clear the auto exit node }, + AutoExitNodeSet: true, }, - wantOK: true, - wantResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + AutoExitNode: "", // cleared + ExitNodeID: "", // has never been resolved, so it should be cleared as well + }, + wantHostinfoExitNodeID: "", + }, + { + name: "auto-any-via-initial-prefs/with-netmap/clear-auto-exit-node", + prefs: ipn.Prefs{ + ControlURL: controlURL, + AutoExitNode: ipn.AnyExitNode, + }, + netMap: clientNetmap, // has a netmap; exit node will be resolved + report: report, + changePrefs: &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AutoExitNode: "", // clear the auto exit node + }, + AutoExitNodeSet: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: controlURL, + AutoExitNode: "", // cleared + ExitNodeID: exitNode1.StableID(), // a resolved exit node ID should be retained + }, + wantHostinfoExitNodeID: exitNode1.StableID(), }, } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() - for _, tc := range tests { - peers := peersMap(nodeViews(tc.peers)) - nm := &netmap.NetworkMap{} - gotResolvers, gotOK := wireguardExitNodeDNSResolvers(nm, peers, tc.id) + var pol policytest.Config + // Configure policy settings, if any. + if tt.exitNodeIDPolicy != nil { + pol.Set(pkey.ExitNodeID, string(*tt.exitNodeIDPolicy)) + } + if tt.exitNodeIPPolicy != nil { + pol.Set(pkey.ExitNodeIP, tt.exitNodeIPPolicy.String()) + } + if tt.exitNodeAllowedIDs != nil { + pol.Set(pkey.AllowedSuggestedExitNodes, toStrings(tt.exitNodeAllowedIDs)) + } + if tt.exitNodeAllowOverride { + pol.Set(pkey.AllowExitNodeOverride, true) + } - if gotOK != tc.wantOK || !resolversEqual(t, gotResolvers, tc.wantResolvers) { - t.Errorf("case: %s: got %v, %v, want %v, %v", tc.name, gotOK, gotResolvers, tc.wantOK, tc.wantResolvers) - } - } -} + // Create a new LocalBackend with the given prefs. + // Any syspolicy settings will be applied to the initial prefs. + sys := tsd.NewSystem() + sys.PolicyClient.Set(pol) + lb := newTestLocalBackendWithSys(t, sys) + lb.SetPrefsForTest(tt.prefs.Clone()) -func TestDNSConfigForNetmapForExitNodeConfigs(t *testing.T) { - type tc struct { - name string - exitNode tailcfg.StableNodeID - peers []tailcfg.NodeView - dnsConfig *tailcfg.DNSConfig - wantDefaultResolvers []*dnstype.Resolver - wantRoutes map[dnsname.FQDN][]*dnstype.Resolver - } + // Then set the netcheck report and netmap, if any. + if tt.report != nil { + lb.MagicConn().SetLastNetcheckReportForTest(t.Context(), tt.report) + } + if tt.netMap != nil { + lb.SetControlClientStatus(lb.cc, controlclient.Status{NetMap: tt.netMap}) + } - defaultResolvers := []*dnstype.Resolver{{Addr: "default.example.com"}} - wgResolvers := []*dnstype.Resolver{{Addr: "wg.example.com"}} - peers := []tailcfg.NodeView{ - (&tailcfg.Node{ - ID: 1, - StableID: "wg", - IsWireGuardOnly: true, - ExitNodeDNSResolvers: wgResolvers, - Hostinfo: (&tailcfg.Hostinfo{}).View(), - }).View(), - // regular tailscale exit node with DNS capabilities - (&tailcfg.Node{ - Cap: 26, - ID: 2, - StableID: "ts", - Hostinfo: (&tailcfg.Hostinfo{}).View(), - }).View(), - } - exitDOH := peerAPIBase(&netmap.NetworkMap{Peers: peers}, peers[0]) + "/dns-query" - routes := map[dnsname.FQDN][]*dnstype.Resolver{ - "route.example.com.": {{Addr: "route.example.com"}}, - } - stringifyRoutes := func(routes map[dnsname.FQDN][]*dnstype.Resolver) map[string][]*dnstype.Resolver { - if routes == nil { - return nil - } - m := make(map[string][]*dnstype.Resolver) - for k, v := range routes { - m[string(k)] = v - } - return m + user := &ipnauth.TestActor{} + // If we have a changePrefs, apply it. + if tt.changePrefs != nil { + _, err := lb.EditPrefsAs(tt.changePrefs, user) + checkError(t, err, tt.wantChangePrefsErr, true) + } + + // If we need to flip exit node toggle on or off, do it. + if tt.useExitNodeEnabled != nil { + _, err := lb.SetUseExitNodeEnabled(user, *tt.useExitNodeEnabled) + checkError(t, err, tt.wantExitNodeToggleErr, true) + } + + // Now check the prefs. + opts := []cmp.Option{ + cmpopts.EquateComparable(netip.Addr{}, netip.Prefix{}), + } + if diff := cmp.Diff(&tt.wantPrefs, lb.Prefs().AsStruct(), opts...); diff != "" { + t.Errorf("Prefs(+got -want): %v", diff) + } + + // And check Hostinfo. + if got := lb.hostinfo.ExitNodeID; got != tt.wantHostinfoExitNodeID { + t.Errorf("Hostinfo.ExitNodeID got %s, want %s", got, tt.wantHostinfoExitNodeID) + } + }) } +} - tests := []tc{ - { - name: "noExit/noRoutes/noResolver", - exitNode: "", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{}, - wantDefaultResolvers: nil, - wantRoutes: nil, - }, +func TestPrefsChangeDisablesExitNode(t *testing.T) { + tests := []struct { + name string + netMap *netmap.NetworkMap + prefs ipn.Prefs + change ipn.MaskedPrefs + wantDisablesExitNode bool + }{ { - name: "tsExit/noRoutes/noResolver", - exitNode: "ts", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{}, - wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, - wantRoutes: nil, + name: "has-exit-node-id/no-change", + prefs: ipn.Prefs{ + ExitNodeID: "test-exit-node", + }, + change: ipn.MaskedPrefs{}, + wantDisablesExitNode: false, }, { - name: "tsExit/noRoutes/defaultResolver", - exitNode: "ts", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{Resolvers: defaultResolvers}, - wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, - wantRoutes: nil, + name: "has-exit-node-ip/no-change", + prefs: ipn.Prefs{ + ExitNodeIP: netip.MustParseAddr("100.100.1.1"), + }, + change: ipn.MaskedPrefs{}, + wantDisablesExitNode: false, }, - - // The following two cases may need to be revisited. For a shared-in - // exit node split-DNS may effectively break, furthermore in the future - // if different nodes observe different DNS configurations, even a - // tailnet local exit node may present a different DNS configuration, - // which may not meet expectations in some use cases. - // In the case where a default resolver is set, the default resolver - // should also perhaps take precedence also. { - name: "tsExit/routes/noResolver", - exitNode: "ts", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(routes)}, - wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, - wantRoutes: nil, + name: "has-auto-exit-node/no-change", + prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + change: ipn.MaskedPrefs{}, + wantDisablesExitNode: false, }, { - name: "tsExit/routes/defaultResolver", - exitNode: "ts", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(routes), Resolvers: defaultResolvers}, - wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, - wantRoutes: nil, + name: "has-exit-node-id/non-exit-node-change", + prefs: ipn.Prefs{ + ExitNodeID: "test-exit-node", + }, + change: ipn.MaskedPrefs{ + WantRunningSet: true, + HostnameSet: true, + ExitNodeAllowLANAccessSet: true, + Prefs: ipn.Prefs{ + WantRunning: true, + Hostname: "test-hostname", + ExitNodeAllowLANAccess: true, + }, + }, + wantDisablesExitNode: false, }, - - // WireGuard exit nodes with DNS capabilities provide a "fallback" type - // behavior, they have a lower precedence than a default resolver, but - // otherwise allow split-DNS to operate as normal, and are used when - // there is no default resolver. { - name: "wgExit/noRoutes/noResolver", - exitNode: "wg", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{}, - wantDefaultResolvers: wgResolvers, - wantRoutes: nil, + name: "has-exit-node-ip/non-exit-node-change", + prefs: ipn.Prefs{ + ExitNodeIP: netip.MustParseAddr("100.100.1.1"), + }, + change: ipn.MaskedPrefs{ + WantRunningSet: true, + RouteAllSet: true, + ShieldsUpSet: true, + Prefs: ipn.Prefs{ + WantRunning: false, + RouteAll: false, + ShieldsUp: true, + }, + }, + wantDisablesExitNode: false, }, { - name: "wgExit/noRoutes/defaultResolver", - exitNode: "wg", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{Resolvers: defaultResolvers}, - wantDefaultResolvers: defaultResolvers, - wantRoutes: nil, + name: "has-auto-exit-node/non-exit-node-change", + prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + change: ipn.MaskedPrefs{ + CorpDNSSet: true, + RouteAllSet: true, + ExitNodeAllowLANAccessSet: true, + Prefs: ipn.Prefs{ + CorpDNS: true, + RouteAll: false, + ExitNodeAllowLANAccess: true, + }, + }, + wantDisablesExitNode: false, }, { - name: "wgExit/routes/defaultResolver", - exitNode: "wg", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(routes), Resolvers: defaultResolvers}, - wantDefaultResolvers: defaultResolvers, - wantRoutes: routes, + name: "has-exit-node-id/change-exit-node-id", + prefs: ipn.Prefs{ + ExitNodeID: "exit-node-1", + }, + change: ipn.MaskedPrefs{ + ExitNodeIDSet: true, + Prefs: ipn.Prefs{ + ExitNodeID: "exit-node-2", + }, + }, + wantDisablesExitNode: false, // changing the exit node ID does not disable it }, { - name: "wgExit/routes/noResolver", - exitNode: "wg", - peers: peers, - dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(routes)}, - wantDefaultResolvers: wgResolvers, - wantRoutes: routes, + name: "has-exit-node-id/enable-auto-exit-node", + prefs: ipn.Prefs{ + ExitNodeID: "exit-node-1", + }, + change: ipn.MaskedPrefs{ + AutoExitNodeSet: true, + Prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + }, + wantDisablesExitNode: false, // changing the exit node ID does not disable it + }, + { + name: "has-exit-node-id/clear-exit-node-id", + prefs: ipn.Prefs{ + ExitNodeID: "exit-node-1", + }, + change: ipn.MaskedPrefs{ + ExitNodeIDSet: true, + Prefs: ipn.Prefs{ + ExitNodeID: "", + }, + }, + wantDisablesExitNode: true, // clearing the exit node ID disables it + }, + { + name: "has-auto-exit-node/clear-exit-node-id", + prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + change: ipn.MaskedPrefs{ + ExitNodeIDSet: true, + Prefs: ipn.Prefs{ + ExitNodeID: "", + }, + }, + wantDisablesExitNode: true, // clearing the exit node ID disables auto exit node as well... + }, + { + name: "has-auto-exit-node/clear-exit-node-id/but-keep-auto-exit-node", + prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + change: ipn.MaskedPrefs{ + ExitNodeIDSet: true, + AutoExitNodeSet: true, + Prefs: ipn.Prefs{ + ExitNodeID: "", + AutoExitNode: ipn.AnyExitNode, + }, + }, + wantDisablesExitNode: false, // ... unless we explicitly keep the auto exit node enabled + }, + { + name: "has-auto-exit-node/clear-exit-node-ip", + prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + change: ipn.MaskedPrefs{ + ExitNodeIPSet: true, + Prefs: ipn.Prefs{ + ExitNodeIP: netip.Addr{}, + }, + }, + wantDisablesExitNode: false, // auto exit node is still enabled + }, + { + name: "has-auto-exit-node/clear-auto-exit-node", + prefs: ipn.Prefs{ + AutoExitNode: ipn.AnyExitNode, + }, + change: ipn.MaskedPrefs{ + AutoExitNodeSet: true, + Prefs: ipn.Prefs{ + AutoExitNode: "", + }, + }, + wantDisablesExitNode: true, // clearing the auto exit while the exit node ID is unresolved disables exit node usage }, } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - nm := &netmap.NetworkMap{ - Peers: tc.peers, - DNS: *tc.dnsConfig, + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb := newTestLocalBackend(t) + if tt.netMap != nil { + lb.SetControlClientStatus(lb.cc, controlclient.Status{NetMap: tt.netMap}) + } + // Set the initial prefs via SetPrefsForTest + // to apply necessary adjustments. + lb.SetPrefsForTest(tt.prefs.Clone()) + initialPrefs := lb.Prefs() + + // Check whether changeDisablesExitNodeLocked correctly identifies the change. + if got := lb.changeDisablesExitNodeLocked(initialPrefs, &tt.change); got != tt.wantDisablesExitNode { + t.Errorf("disablesExitNode: got %v; want %v", got, tt.wantDisablesExitNode) } - prefs := &ipn.Prefs{ExitNodeID: tc.exitNode, CorpDNS: true} - got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, t.Logf, "") - if !resolversEqual(t, got.DefaultResolvers, tc.wantDefaultResolvers) { - t.Errorf("DefaultResolvers: got %#v, want %#v", got.DefaultResolvers, tc.wantDefaultResolvers) + // Apply the change and check if it the actual behavior matches the expectation. + gotPrefs, err := lb.EditPrefsAs(&tt.change, &ipnauth.TestActor{}) + if err != nil { + t.Fatalf("EditPrefsAs failed: %v", err) } - if !routesEqual(t, got.Routes, tc.wantRoutes) { - t.Errorf("Routes: got %#v, want %#v", got.Routes, tc.wantRoutes) + gotDisabledExitNode := initialPrefs.ExitNodeID() != "" && gotPrefs.ExitNodeID() == "" + if gotDisabledExitNode != tt.wantDisablesExitNode { + t.Errorf("disabledExitNode: got %v; want %v", gotDisabledExitNode, tt.wantDisablesExitNode) } }) } } -func TestOfferingAppConnector(t *testing.T) { - for _, shouldStore := range []bool{false, true} { - b := newTestBackend(t) - if b.OfferingAppConnector() { - t.Fatal("unexpected offering app connector") - } - if shouldStore { - b.appConnector = appc.NewAppConnector(t.Logf, nil, &appc.RouteInfo{}, fakeStoreRoutes) - } else { - b.appConnector = appc.NewAppConnector(t.Logf, nil, nil, nil) - } - if !b.OfferingAppConnector() { - t.Fatal("unexpected not offering app connector") - } - } -} +func TestExitNodeNotifyOrder(t *testing.T) { + const controlURL = "https://localhost:1/" -func TestRouteAdvertiser(t *testing.T) { - b := newTestBackend(t) - testPrefix := netip.MustParsePrefix("192.0.0.8/32") + report := &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 5 * time.Millisecond, + 2: 10 * time.Millisecond, + }, + PreferredDERP: 1, + } - ra := appc.RouteAdvertiser(b) - must.Do(ra.AdvertiseRoute(testPrefix)) + exitNode1 := makeExitNode(1, withName("node-1"), withDERP(1), withAddresses(netip.MustParsePrefix("100.64.1.1/32"))) + exitNode2 := makeExitNode(2, withName("node-2"), withDERP(2), withAddresses(netip.MustParsePrefix("100.64.1.2/32"))) + selfNode := makeExitNode(3, withName("node-3"), withDERP(1), withAddresses(netip.MustParsePrefix("100.64.1.3/32"))) + clientNetmap := buildNetmapWithPeers(selfNode, exitNode1, exitNode2) - routes := b.Prefs().AdvertiseRoutes() - if routes.Len() != 1 || routes.At(0) != testPrefix { - t.Fatalf("got routes %v, want %v", routes, []netip.Prefix{testPrefix}) - } + lb := newTestLocalBackend(t) + lb.sys.MagicSock.Get().SetLastNetcheckReportForTest(lb.ctx, report) + lb.SetPrefsForTest(&ipn.Prefs{ + ControlURL: controlURL, + AutoExitNode: ipn.AnyExitNode, + }) - must.Do(ra.UnadvertiseRoute(testPrefix)) + nw := newNotificationWatcher(t, lb, ipnauth.Self) - routes = b.Prefs().AdvertiseRoutes() - if routes.Len() != 0 { - t.Fatalf("got routes %v, want none", routes) - } + // Updating the netmap should trigger both a netmap notification + // and an exit node ID notification (since an exit node is selected). + // The netmap notification should be sent first. + nw.watch(0, []wantedNotification{ + wantNetmapNotify(clientNetmap), + wantExitNodeIDNotify(exitNode1.StableID()), + }) + lb.SetControlClientStatus(lb.cc, controlclient.Status{NetMap: clientNetmap}) + nw.check() } -func TestRouterAdvertiserIgnoresContainedRoutes(t *testing.T) { - b := newTestBackend(t) - testPrefix := netip.MustParsePrefix("192.0.0.0/24") - ra := appc.RouteAdvertiser(b) - must.Do(ra.AdvertiseRoute(testPrefix)) - - routes := b.Prefs().AdvertiseRoutes() - if routes.Len() != 1 || routes.At(0) != testPrefix { - t.Fatalf("got routes %v, want %v", routes, []netip.Prefix{testPrefix}) +func wantNetmapNotify(want *netmap.NetworkMap) wantedNotification { + return wantedNotification{ + name: "Netmap", + cond: func(t testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + return n.NetMap == want + }, } +} - must.Do(ra.AdvertiseRoute(netip.MustParsePrefix("192.0.0.8/32"))) - - // the above /32 is not added as it is contained within the /24 - routes = b.Prefs().AdvertiseRoutes() - if routes.Len() != 1 || routes.At(0) != testPrefix { - t.Fatalf("got routes %v, want %v", routes, []netip.Prefix{testPrefix}) +func wantExitNodeIDNotify(want tailcfg.StableNodeID) wantedNotification { + return wantedNotification{ + name: fmt.Sprintf("ExitNodeID-%s", want), + cond: func(_ testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + return n.Prefs != nil && n.Prefs.Valid() && n.Prefs.ExitNodeID() == want + }, } } -func TestObserveDNSResponse(t *testing.T) { - for _, shouldStore := range []bool{false, true} { - b := newTestBackend(t) - - // ensure no error when no app connector is configured - b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) +func TestInternalAndExternalInterfaces(t *testing.T) { + type interfacePrefix struct { + i netmon.Interface + pfx netip.Prefix + } - rc := &appctest.RouteCollector{} - if shouldStore { - b.appConnector = appc.NewAppConnector(t.Logf, rc, &appc.RouteInfo{}, fakeStoreRoutes) - } else { - b.appConnector = appc.NewAppConnector(t.Logf, rc, nil, nil) + masked := func(ips ...interfacePrefix) (pfxs []netip.Prefix) { + for _, ip := range ips { + pfxs = append(pfxs, ip.pfx.Masked()) } - b.appConnector.UpdateDomains([]string{"example.com"}) - b.appConnector.Wait(context.Background()) - - b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) - b.appConnector.Wait(context.Background()) - wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} - if !slices.Equal(rc.Routes(), wantRoutes) { - t.Fatalf("got routes %v, want %v", rc.Routes(), wantRoutes) + return pfxs + } + iList := func(ips ...interfacePrefix) (il netmon.InterfaceList) { + for _, ip := range ips { + il = append(il, ip.i) } + return il } -} + newInterface := func(name, pfx string, wsl2, loopback bool) interfacePrefix { + ippfx := netip.MustParsePrefix(pfx) + ip := netmon.Interface{ + Interface: &net.Interface{}, + AltAddrs: []net.Addr{ + netipx.PrefixIPNet(ippfx), + }, + } + if loopback { + ip.Flags = net.FlagLoopback + } + if wsl2 { + ip.HardwareAddr = []byte{0x00, 0x15, 0x5d, 0x00, 0x00, 0x00} + } + return interfacePrefix{i: ip, pfx: ippfx} + } + var ( + en0 = newInterface("en0", "10.20.2.5/16", false, false) + en1 = newInterface("en1", "192.168.1.237/24", false, false) + wsl = newInterface("wsl", "192.168.5.34/24", true, false) + loopback = newInterface("lo0", "127.0.0.1/8", false, true) + ) -func TestCoveredRouteRangeNoDefault(t *testing.T) { tests := []struct { - existingRoute netip.Prefix - newRoute netip.Prefix - want bool + name string + goos string + il netmon.InterfaceList + wantInt []netip.Prefix + wantExt []netip.Prefix }{ { - existingRoute: netip.MustParsePrefix("192.0.0.1/32"), - newRoute: netip.MustParsePrefix("192.0.0.1/32"), - want: true, - }, - { - existingRoute: netip.MustParsePrefix("192.0.0.1/32"), - newRoute: netip.MustParsePrefix("192.0.0.2/32"), - want: false, - }, - { - existingRoute: netip.MustParsePrefix("192.0.0.0/24"), - newRoute: netip.MustParsePrefix("192.0.0.1/32"), - want: true, - }, - { - existingRoute: netip.MustParsePrefix("192.0.0.0/16"), - newRoute: netip.MustParsePrefix("192.0.0.0/24"), - want: true, + name: "single-interface", + goos: "linux", + il: iList( + en0, + loopback, + ), + wantInt: masked(loopback), + wantExt: masked(en0), }, { - existingRoute: netip.MustParsePrefix("0.0.0.0/0"), - newRoute: netip.MustParsePrefix("192.0.0.0/24"), - want: false, + name: "multiple-interfaces", + goos: "linux", + il: iList( + en0, + en1, + wsl, + loopback, + ), + wantInt: masked(loopback), + wantExt: masked(en0, en1, wsl), }, { - existingRoute: netip.MustParsePrefix("::/0"), - newRoute: netip.MustParsePrefix("2001:db8::/32"), - want: false, + name: "wsl2", + goos: "windows", + il: iList( + en0, + en1, + wsl, + loopback, + ), + wantInt: masked(loopback, wsl), + wantExt: masked(en0, en1), }, } - for _, tt := range tests { - got := coveredRouteRangeNoDefault([]netip.Prefix{tt.existingRoute}, tt.newRoute) - if got != tt.want { - t.Errorf("coveredRouteRange(%v, %v) = %v, want %v", tt.existingRoute, tt.newRoute, got, tt.want) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + gotInt, gotExt, err := internalAndExternalInterfacesFrom(tc.il, tc.goos) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(gotInt, tc.wantInt) { + t.Errorf("unexpected internal prefixes\ngot %v\nwant %v", gotInt, tc.wantInt) + } + if !reflect.DeepEqual(gotExt, tc.wantExt) { + t.Errorf("unexpected external prefixes\ngot %v\nwant %v", gotExt, tc.wantExt) + } + }) } } -func TestReconfigureAppConnector(t *testing.T) { - b := newTestBackend(t) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) - if b.appConnector != nil { - t.Fatal("unexpected app connector") - } - - b.EditPrefs(&ipn.MaskedPrefs{ +func TestPacketFilterPermitsUnlockedNodes(t *testing.T) { + tests := []struct { + name string + peers []*tailcfg.Node + filter []filter.Match + want bool + }{ + { + name: "empty", + want: false, + }, + { + name: "no-unsigned", + peers: []*tailcfg.Node{ + {ID: 1}, + }, + want: false, + }, + { + name: "unsigned-good", + peers: []*tailcfg.Node{ + {ID: 1, UnsignedPeerAPIOnly: true}, + }, + want: false, + }, + { + name: "unsigned-bad", + peers: []*tailcfg.Node{ + { + ID: 1, + UnsignedPeerAPIOnly: true, + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.0/32"), + }, + }, + }, + filter: []filter.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/32")}, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.99.0.0/32"), + }, + }, + }, + }, + want: true, + }, + { + name: "unsigned-bad-src-is-superset", + peers: []*tailcfg.Node{ + { + ID: 1, + UnsignedPeerAPIOnly: true, + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.0/32"), + }, + }, + }, + filter: []filter.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/24")}, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.99.0.0/32"), + }, + }, + }, + }, + want: true, + }, + { + name: "unsigned-okay-because-no-dsts", + peers: []*tailcfg.Node{ + { + ID: 1, + UnsignedPeerAPIOnly: true, + AllowedIPs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.0/32"), + }, + }, + }, + filter: []filter.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("100.64.0.0/32")}, + Caps: []filter.CapMatch{ + { + Dst: netip.MustParsePrefix("100.99.0.0/32"), + Cap: "foo", + }, + }, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetFilterPermitsUnlockedNodes(peersMap(nodeViews(tt.peers)), tt.filter); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} + +func TestStatusPeerCapabilities(t *testing.T) { + tests := []struct { + name string + peers []tailcfg.NodeView + expectedPeerCapabilities map[tailcfg.StableNodeID][]tailcfg.NodeCapability + expectedPeerCapMap map[tailcfg.StableNodeID]tailcfg.NodeCapMap + }{ + { + name: "peers-with-capabilities", + peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 1, + StableID: "foo", + Key: makeNodeKeyFromID(1), + IsWireGuardOnly: true, + Hostinfo: (&tailcfg.Hostinfo{}).View(), + Capabilities: []tailcfg.NodeCapability{tailcfg.CapabilitySSH}, + CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + tailcfg.CapabilitySSH: nil, + }), + }).View(), + (&tailcfg.Node{ + ID: 2, + StableID: "bar", + Key: makeNodeKeyFromID(2), + Hostinfo: (&tailcfg.Hostinfo{}).View(), + Capabilities: []tailcfg.NodeCapability{tailcfg.CapabilityAdmin}, + CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + tailcfg.CapabilityAdmin: {`{"test": "true}`}, + }), + }).View(), + (&tailcfg.Node{ + ID: 3, + StableID: "baz", + Key: makeNodeKeyFromID(3), + Hostinfo: (&tailcfg.Hostinfo{}).View(), + Capabilities: []tailcfg.NodeCapability{tailcfg.CapabilityOwner}, + CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + tailcfg.CapabilityOwner: nil, + }), + }).View(), + }, + expectedPeerCapabilities: map[tailcfg.StableNodeID][]tailcfg.NodeCapability{ + tailcfg.StableNodeID("foo"): {tailcfg.CapabilitySSH}, + tailcfg.StableNodeID("bar"): {tailcfg.CapabilityAdmin}, + tailcfg.StableNodeID("baz"): {tailcfg.CapabilityOwner}, + }, + expectedPeerCapMap: map[tailcfg.StableNodeID]tailcfg.NodeCapMap{ + tailcfg.StableNodeID("foo"): (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + tailcfg.CapabilitySSH: nil, + }), + tailcfg.StableNodeID("bar"): (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + tailcfg.CapabilityAdmin: {`{"test": "true}`}, + }), + tailcfg.StableNodeID("baz"): (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + tailcfg.CapabilityOwner: nil, + }), + }, + }, + { + name: "peers-without-capabilities", + peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 1, + StableID: "foo", + Key: makeNodeKeyFromID(1), + IsWireGuardOnly: true, + Hostinfo: (&tailcfg.Hostinfo{}).View(), + }).View(), + (&tailcfg.Node{ + ID: 2, + StableID: "bar", + Key: makeNodeKeyFromID(2), + Hostinfo: (&tailcfg.Hostinfo{}).View(), + }).View(), + }, + }, + } + b := newTestLocalBackend(t) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b.setNetMapLocked(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + MachineAuthorized: true, + Addresses: ipps("100.101.101.101"), + }).View(), + Peers: tt.peers, + }) + got := b.Status() + for _, peer := range got.Peer { + if !reflect.DeepEqual(peer.Capabilities, tt.expectedPeerCapabilities[peer.ID]) { + t.Errorf("peer capabilities: expected %v got %v", tt.expectedPeerCapabilities, peer.Capabilities) + } + if !reflect.DeepEqual(peer.CapMap, tt.expectedPeerCapMap[peer.ID]) { + t.Errorf("peer capmap: expected %v got %v", tt.expectedPeerCapMap, peer.CapMap) + } + } + }) + } +} + +// legacyBackend was the interface between Tailscale frontends +// (e.g. cmd/tailscale, iOS/MacOS/Windows GUIs) and the tailscale +// backend (e.g. cmd/tailscaled) running on the same machine. +// (It has nothing to do with the interface between the backends +// and the cloud control plane.) +type legacyBackend interface { + // SetNotifyCallback sets the callback to be called on updates + // from the backend to the client. + SetNotifyCallback(func(ipn.Notify)) + // Start starts or restarts the backend, typically when a + // frontend client connects. + Start(ipn.Options) error +} + +// Verify that LocalBackend still implements the legacyBackend interface +// for now, at least until the macOS and iOS clients move off of it. +var _ legacyBackend = (*LocalBackend)(nil) + +func TestWatchNotificationsCallbacks(t *testing.T) { + b := new(LocalBackend) + n := new(ipn.Notify) + b.WatchNotifications(context.Background(), 0, func() { + b.mu.Lock() + defer b.mu.Unlock() + + // Ensure a watcher has been installed. + if len(b.notifyWatchers) != 1 { + t.Fatalf("unexpected number of watchers in new LocalBackend, want: 1 got: %v", len(b.notifyWatchers)) + } + // Send a notification. Range over notifyWatchers to get the channel + // because WatchNotifications doesn't expose the handle for it. + for _, sess := range b.notifyWatchers { + select { + case sess.ch <- n: + default: + t.Fatalf("could not send notification") + } + } + }, func(roNotify *ipn.Notify) bool { + if roNotify != n { + t.Fatalf("unexpected notification received. want: %v got: %v", n, roNotify) + } + return false + }) + + // Ensure watchers have been cleaned up. + b.mu.Lock() + defer b.mu.Unlock() + if len(b.notifyWatchers) != 0 { + t.Fatalf("unexpected number of watchers in new LocalBackend, want: 0 got: %v", len(b.notifyWatchers)) + } +} + +// tests LocalBackend.updateNetmapDeltaLocked +func TestUpdateNetmapDelta(t *testing.T) { + b := newTestLocalBackend(t) + if b.currentNode().UpdateNetmapDelta(nil) { + t.Errorf("updateNetmapDeltaLocked() = true, want false with nil netmap") + } + + nm := &netmap.NetworkMap{} + for i := range 5 { + id := tailcfg.NodeID(i + 1) + nm.Peers = append(nm.Peers, (&tailcfg.Node{ + ID: id, + Key: makeNodeKeyFromID(id), + }).View()) + } + b.currentNode().SetNetMap(nm) + + someTime := time.Unix(123, 0) + muts, ok := netmap.MutationsFromMapResponse(&tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{ + { + NodeID: 1, + DERPRegion: 1, + }, + { + NodeID: 2, + Online: ptr.To(true), + }, + { + NodeID: 3, + Online: ptr.To(false), + }, + { + NodeID: 4, + LastSeen: ptr.To(someTime), + }, + }, + }, someTime) + if !ok { + t.Fatal("netmap.MutationsFromMapResponse failed") + } + + if !b.currentNode().UpdateNetmapDelta(muts) { + t.Fatalf("updateNetmapDeltaLocked() = false, want true with new netmap") + } + + wants := []*tailcfg.Node{ + { + ID: 1, + Key: makeNodeKeyFromID(1), + HomeDERP: 1, + }, + { + ID: 2, + Key: makeNodeKeyFromID(2), + Online: ptr.To(true), + }, + { + ID: 3, + Key: makeNodeKeyFromID(3), + Online: ptr.To(false), + }, + { + ID: 4, + Key: makeNodeKeyFromID(4), + LastSeen: ptr.To(someTime), + }, + } + for _, want := range wants { + gotv, ok := b.currentNode().NodeByID(want.ID) + if !ok { + t.Errorf("netmap.Peer %v missing from b.profile.Peers", want.ID) + continue + } + got := gotv.AsStruct() + if !reflect.DeepEqual(got, want) { + t.Errorf("netmap.Peer %v wrong.\n got: %v\nwant: %v", want.ID, logger.AsJSON(got), logger.AsJSON(want)) + } + } +} + +// tests WhoIs and indirectly that setNetMapLocked updates b.nodeByAddr correctly. +func TestWhoIs(t *testing.T) { + b := newTestLocalBackend(t) + b.setNetMapLocked(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + ID: 1, + User: 10, + Key: makeNodeKeyFromID(1), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.101.102.103/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 2, + User: 20, + Key: makeNodeKeyFromID(2), + Addresses: []netip.Prefix{netip.MustParsePrefix("100.200.200.200/32")}, + }).View(), + }, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + 10: (&tailcfg.UserProfile{ + DisplayName: "Myself", + }).View(), + 20: (&tailcfg.UserProfile{ + DisplayName: "Peer", + }).View(), + }, + }) + tests := []struct { + q string + want tailcfg.NodeID // 0 means want ok=false + wantName string + }{ + {"100.101.102.103:0", 1, "Myself"}, + {"100.101.102.103:123", 1, "Myself"}, + {"100.200.200.200:0", 2, "Peer"}, + {"100.200.200.200:123", 2, "Peer"}, + {"100.4.0.4:404", 0, ""}, + } + for _, tt := range tests { + t.Run(tt.q, func(t *testing.T) { + nv, up, ok := b.WhoIs("", netip.MustParseAddrPort(tt.q)) + var got tailcfg.NodeID + if ok { + got = nv.ID() + } + if got != tt.want { + t.Errorf("got nodeID %v; want %v", got, tt.want) + } + if up.DisplayName != tt.wantName { + t.Errorf("got name %q; want %q", up.DisplayName, tt.wantName) + } + }) + } +} + +func TestWireguardExitNodeDNSResolvers(t *testing.T) { + type tc struct { + name string + id tailcfg.StableNodeID + peers []*tailcfg.Node + wantOK bool + wantResolvers []*dnstype.Resolver + } + + tests := []tc{ + { + name: "no peers", + id: "1", + wantOK: false, + wantResolvers: nil, + }, + { + name: "non wireguard peer", + id: "1", + peers: []*tailcfg.Node{ + { + ID: 1, + StableID: "1", + IsWireGuardOnly: false, + ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + }, + wantOK: false, + wantResolvers: nil, + }, + { + name: "no matching IDs", + id: "2", + peers: []*tailcfg.Node{ + { + ID: 1, + StableID: "1", + IsWireGuardOnly: true, + ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + }, + wantOK: false, + wantResolvers: nil, + }, + { + name: "wireguard peer", + id: "1", + peers: []*tailcfg.Node{ + { + ID: 1, + StableID: "1", + IsWireGuardOnly: true, + ExitNodeDNSResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + }, + wantOK: true, + wantResolvers: []*dnstype.Resolver{{Addr: "dns.example.com"}}, + }, + } + + for _, tc := range tests { + peers := peersMap(nodeViews(tc.peers)) + nm := &netmap.NetworkMap{} + gotResolvers, gotOK := wireguardExitNodeDNSResolvers(nm, peers, tc.id) + + if gotOK != tc.wantOK || !resolversEqual(t, gotResolvers, tc.wantResolvers) { + t.Errorf("case: %s: got %v, %v, want %v, %v", tc.name, gotOK, gotResolvers, tc.wantOK, tc.wantResolvers) + } + } +} + +func TestDNSConfigForNetmapForExitNodeConfigs(t *testing.T) { + type tc struct { + name string + exitNode tailcfg.StableNodeID + peers []tailcfg.NodeView + dnsConfig *tailcfg.DNSConfig + wantDefaultResolvers []*dnstype.Resolver + wantRoutes map[dnsname.FQDN][]*dnstype.Resolver + } + + const tsUseWithExitNodeResolverAddr = "usewithexitnode.example.com" + defaultResolvers := []*dnstype.Resolver{ + {Addr: "default.example.com"}, + } + containsFlaggedResolvers := append([]*dnstype.Resolver{ + {Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}, + }, defaultResolvers...) + + wgResolvers := []*dnstype.Resolver{{Addr: "wg.example.com"}} + peers := []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 1, + StableID: "wg", + IsWireGuardOnly: true, + ExitNodeDNSResolvers: wgResolvers, + Hostinfo: (&tailcfg.Hostinfo{}).View(), + }).View(), + // regular tailscale exit node with DNS capabilities + (&tailcfg.Node{ + Cap: 26, + ID: 2, + StableID: "ts", + Hostinfo: (&tailcfg.Hostinfo{}).View(), + }).View(), + } + exitDOH := peerAPIBase(&netmap.NetworkMap{Peers: peers}, peers[0]) + "/dns-query" + baseRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "route.example.com.": {{Addr: "route.example.com"}}, + } + containsEmptyRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "route.example.com.": {{Addr: "route.example.com"}}, + "empty.example.com.": {}, + } + containsFlaggedRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "route.example.com.": {{Addr: "route.example.com"}}, + "withexit.example.com.": {{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + } + containsFlaggedAndEmptyRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "empty.example.com.": {}, + "route.example.com.": {{Addr: "route.example.com"}}, + "withexit.example.com.": {{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + } + flaggedRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "withexit.example.com.": {{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + } + emptyRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "empty.example.com.": {}, + } + flaggedAndEmptyRoutes := map[dnsname.FQDN][]*dnstype.Resolver{ + "empty.example.com.": {}, + "withexit.example.com.": {{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + } + + stringifyRoutes := func(routes map[dnsname.FQDN][]*dnstype.Resolver) map[string][]*dnstype.Resolver { + if routes == nil { + return nil + } + m := make(map[string][]*dnstype.Resolver) + for k, v := range routes { + m[string(k)] = v + } + return m + } + + tests := []tc{ + { + name: "noExit/noRoutes/noResolver", + exitNode: "", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{}, + wantDefaultResolvers: nil, + wantRoutes: nil, + }, + { + name: "tsExit/noRoutes/noResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: nil, + }, + { + name: "tsExit/noRoutes/defaultResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Resolvers: defaultResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: nil, + }, + { + name: "tsExit/noRoutes/flaggedResolverOnly", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Resolvers: containsFlaggedResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + wantRoutes: nil, + }, + + // When at tailscale exit node is in use, + // only routes that reference resolvers with the UseWithExitNode should be installed, + // as well as routes with 0-length resolver lists, which should be installed in all cases. + { + name: "tsExit/routes/noResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(baseRoutes)}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: nil, + }, + { + name: "tsExit/routes/defaultResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(baseRoutes), Resolvers: defaultResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: nil, + }, + { + name: "tsExit/routes/flaggedResolverOnly", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(baseRoutes), Resolvers: containsFlaggedResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + wantRoutes: nil, + }, + { + name: "tsExit/flaggedRoutesOnly/defaultResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(containsFlaggedRoutes), Resolvers: defaultResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: flaggedRoutes, + }, + { + name: "tsExit/flaggedRoutesOnly/flaggedResolverOnly", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(containsFlaggedRoutes), Resolvers: containsFlaggedResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + wantRoutes: flaggedRoutes, + }, + { + name: "tsExit/emptyRoutesOnly/defaultResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(containsEmptyRoutes), Resolvers: defaultResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: emptyRoutes, + }, + { + name: "tsExit/flaggedAndEmptyRoutesOnly/defaultResolver", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(containsFlaggedAndEmptyRoutes), Resolvers: defaultResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: exitDOH}}, + wantRoutes: flaggedAndEmptyRoutes, + }, + { + name: "tsExit/flaggedAndEmptyRoutesOnly/flaggedResolverOnly", + exitNode: "ts", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(containsFlaggedAndEmptyRoutes), Resolvers: containsFlaggedResolvers}, + wantDefaultResolvers: []*dnstype.Resolver{{Addr: tsUseWithExitNodeResolverAddr, UseWithExitNode: true}}, + wantRoutes: flaggedAndEmptyRoutes, + }, + + // WireGuard exit nodes with DNS capabilities provide a "fallback" type + // behavior, they have a lower precedence than a default resolver, but + // otherwise allow split-DNS to operate as normal, and are used when + // there is no default resolver. + { + name: "wgExit/noRoutes/noResolver", + exitNode: "wg", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{}, + wantDefaultResolvers: wgResolvers, + wantRoutes: nil, + }, + { + name: "wgExit/noRoutes/defaultResolver", + exitNode: "wg", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Resolvers: defaultResolvers}, + wantDefaultResolvers: defaultResolvers, + wantRoutes: nil, + }, + { + name: "wgExit/routes/defaultResolver", + exitNode: "wg", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(baseRoutes), Resolvers: defaultResolvers}, + wantDefaultResolvers: defaultResolvers, + wantRoutes: baseRoutes, + }, + { + name: "wgExit/routes/noResolver", + exitNode: "wg", + peers: peers, + dnsConfig: &tailcfg.DNSConfig{Routes: stringifyRoutes(baseRoutes)}, + wantDefaultResolvers: wgResolvers, + wantRoutes: baseRoutes, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + nm := &netmap.NetworkMap{ + Peers: tc.peers, + DNS: *tc.dnsConfig, + } + + prefs := &ipn.Prefs{ExitNodeID: tc.exitNode, CorpDNS: true} + got := dnsConfigForNetmap(nm, peersMap(tc.peers), prefs.View(), false, t.Logf, "") + if !resolversEqual(t, got.DefaultResolvers, tc.wantDefaultResolvers) { + t.Errorf("DefaultResolvers: got %#v, want %#v", got.DefaultResolvers, tc.wantDefaultResolvers) + } + if !routesEqual(t, got.Routes, tc.wantRoutes) { + t.Errorf("Routes: got %#v, want %#v", got.Routes, tc.wantRoutes) + } + }) + } +} + +func TestOfferingAppConnector(t *testing.T) { + for _, shouldStore := range []bool{false, true} { + b := newTestBackend(t) + bus := b.sys.Bus.Get() + if b.OfferingAppConnector() { + t.Fatal("unexpected offering app connector") + } + b.appConnector = appc.NewAppConnector(appc.Config{ + Logf: t.Logf, EventBus: bus, HasStoredRoutes: shouldStore, + }) + if !b.OfferingAppConnector() { + t.Fatal("unexpected not offering app connector") + } + } +} + +func TestRouteAdvertiser(t *testing.T) { + b := newTestBackend(t) + testPrefix := netip.MustParsePrefix("192.0.0.8/32") + + ra := appc.RouteAdvertiser(b) + must.Do(ra.AdvertiseRoute(testPrefix)) + + routes := b.Prefs().AdvertiseRoutes() + if routes.Len() != 1 || routes.At(0) != testPrefix { + t.Fatalf("got routes %v, want %v", routes, []netip.Prefix{testPrefix}) + } + + must.Do(ra.UnadvertiseRoute(testPrefix)) + + routes = b.Prefs().AdvertiseRoutes() + if routes.Len() != 0 { + t.Fatalf("got routes %v, want none", routes) + } +} + +func TestRouterAdvertiserIgnoresContainedRoutes(t *testing.T) { + b := newTestBackend(t) + testPrefix := netip.MustParsePrefix("192.0.0.0/24") + ra := appc.RouteAdvertiser(b) + must.Do(ra.AdvertiseRoute(testPrefix)) + + routes := b.Prefs().AdvertiseRoutes() + if routes.Len() != 1 || routes.At(0) != testPrefix { + t.Fatalf("got routes %v, want %v", routes, []netip.Prefix{testPrefix}) + } + + must.Do(ra.AdvertiseRoute(netip.MustParsePrefix("192.0.0.8/32"))) + + // the above /32 is not added as it is contained within the /24 + routes = b.Prefs().AdvertiseRoutes() + if routes.Len() != 1 || routes.At(0) != testPrefix { + t.Fatalf("got routes %v, want %v", routes, []netip.Prefix{testPrefix}) + } +} + +func TestObserveDNSResponse(t *testing.T) { + for _, shouldStore := range []bool{false, true} { + b := newTestBackend(t) + bus := b.sys.Bus.Get() + w := eventbustest.NewWatcher(t, bus) + + // ensure no error when no app connector is configured + if err := b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } + + rc := &appctest.RouteCollector{} + a := appc.NewAppConnector(appc.Config{ + Logf: t.Logf, + EventBus: bus, + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + a.UpdateDomains([]string{"example.com"}) + a.Wait(t.Context()) + b.appConnector = a + + if err := b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } + a.Wait(t.Context()) + wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} + if !slices.Equal(rc.Routes(), wantRoutes) { + t.Fatalf("got routes %v, want %v", rc.Routes(), wantRoutes) + } + + if err := eventbustest.Expect(w, + eqUpdate(appctype.RouteUpdate{Advertise: mustPrefix("192.0.0.8/32")}), + ); err != nil { + t.Error(err) + } + } +} + +func TestCoveredRouteRangeNoDefault(t *testing.T) { + tests := []struct { + existingRoute netip.Prefix + newRoute netip.Prefix + want bool + }{ + { + existingRoute: netip.MustParsePrefix("192.0.0.1/32"), + newRoute: netip.MustParsePrefix("192.0.0.1/32"), + want: true, + }, + { + existingRoute: netip.MustParsePrefix("192.0.0.1/32"), + newRoute: netip.MustParsePrefix("192.0.0.2/32"), + want: false, + }, + { + existingRoute: netip.MustParsePrefix("192.0.0.0/24"), + newRoute: netip.MustParsePrefix("192.0.0.1/32"), + want: true, + }, + { + existingRoute: netip.MustParsePrefix("192.0.0.0/16"), + newRoute: netip.MustParsePrefix("192.0.0.0/24"), + want: true, + }, + { + existingRoute: netip.MustParsePrefix("0.0.0.0/0"), + newRoute: netip.MustParsePrefix("192.0.0.0/24"), + want: false, + }, + { + existingRoute: netip.MustParsePrefix("::/0"), + newRoute: netip.MustParsePrefix("2001:db8::/32"), + want: false, + }, + } + + for _, tt := range tests { + got := coveredRouteRangeNoDefault([]netip.Prefix{tt.existingRoute}, tt.newRoute) + if got != tt.want { + t.Errorf("coveredRouteRange(%v, %v) = %v, want %v", tt.existingRoute, tt.newRoute, got, tt.want) + } + } +} + +func TestReconfigureAppConnector(t *testing.T) { + b := newTestBackend(t) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + if b.appConnector != nil { + t.Fatal("unexpected app connector") + } + + b.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AppConnector: ipn.AppConnectorPrefs{ Advertise: true, }, }, - AppConnectorSet: true, - }) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) - if b.appConnector == nil { - t.Fatal("expected app connector") + AppConnectorSet: true, + }) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + if b.appConnector == nil { + t.Fatal("expected app connector") + } + + appCfg := `{ + "name": "example", + "domains": ["example.com"], + "connectors": ["tag:example"] + }` + + nm := &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + Tags: []string{"tag:example"}, + CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + "tailscale.com/app-connectors": {tailcfg.RawMessage(appCfg)}, + }), + }).View(), + } + + b.currentNode().SetNetMap(nm) + + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + b.appConnector.Wait(context.Background()) + + want := []string{"example.com"} + if !slices.Equal(b.appConnector.Domains().AsSlice(), want) { + t.Fatalf("got domains %v, want %v", b.appConnector.Domains(), want) + } + if v, _ := b.hostinfo.AppConnector.Get(); !v { + t.Fatalf("expected app connector service") + } + + // disable the connector in order to assert that the service is removed + b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AppConnector: ipn.AppConnectorPrefs{ + Advertise: false, + }, + }, + AppConnectorSet: true, + }) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + if b.appConnector != nil { + t.Fatal("expected no app connector") + } + if v, _ := b.hostinfo.AppConnector.Get(); v { + t.Fatalf("expected no app connector service") + } +} + +func TestBackfillAppConnectorRoutes(t *testing.T) { + // Create backend with an empty app connector. + b := newTestBackend(t) + // newTestBackend creates a backend with a non-nil netmap, + // but this test requires a nil netmap. + // Otherwise, instead of backfilling, [LocalBackend.reconfigAppConnectorLocked] + // uses the domains and routes from netmap's [appctype.AppConnectorAttr]. + // Additionally, a non-nil netmap makes reconfigAppConnectorLocked + // asynchronous, resulting in a flaky test. + // Therefore, we set the netmap to nil to simulate a fresh backend start + // or a profile switch where the netmap is not yet available. + b.setNetMapLocked(nil) + if err := b.Start(ipn.Options{}); err != nil { + t.Fatal(err) + } + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AppConnector: ipn.AppConnectorPrefs{Advertise: true}, + }, + AppConnectorSet: true, + }); err != nil { + t.Fatal(err) + } + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + + // Smoke check that AdvertiseRoutes doesn't have the test IP. + ip := netip.MustParseAddr("1.2.3.4") + routes := b.Prefs().AdvertiseRoutes().AsSlice() + if slices.Contains(routes, netip.PrefixFrom(ip, ip.BitLen())) { + t.Fatalf("AdvertiseRoutes %v on a fresh backend already contains advertised route for %v", routes, ip) + } + + // Store the test IP in profile data, but not in Prefs.AdvertiseRoutes. + b.ControlKnobs().AppCStoreRoutes.Store(true) + if err := b.storeRouteInfo(appctype.RouteInfo{ + Domains: map[string][]netip.Addr{ + "example.com": {ip}, + }, + }); err != nil { + t.Fatal(err) + } + + // Mimic b.authReconfigure for the app connector bits. + b.mu.Lock() + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + b.mu.Unlock() + b.readvertiseAppConnectorRoutes() + + // Check that Prefs.AdvertiseRoutes got backfilled with routes stored in + // profile data. + routes = b.Prefs().AdvertiseRoutes().AsSlice() + if !slices.Contains(routes, netip.PrefixFrom(ip, ip.BitLen())) { + t.Fatalf("AdvertiseRoutes %v was not backfilled from stored app connector routes with %v", routes, ip) + } +} + +func resolversEqual(t *testing.T, a, b []*dnstype.Resolver) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + t.Errorf("resolversEqual: a == nil || b == nil : %#v != %#v", a, b) + return false + } + if len(a) != len(b) { + t.Errorf("resolversEqual: len(a) != len(b) : %#v != %#v", a, b) + return false + } + for i := range a { + if !a[i].Equal(b[i]) { + t.Errorf("resolversEqual: a != b [%d]: %v != %v", i, *a[i], *b[i]) + return false + } + } + return true +} + +func routesEqual(t *testing.T, a, b map[dnsname.FQDN][]*dnstype.Resolver) bool { + if len(a) != len(b) { + t.Logf("routes: len(a) != len(b): %d != %d", len(a), len(b)) + return false + } + for name := range a { + if !resolversEqual(t, a[name], b[name]) { + t.Logf("routes: a != b [%s]: %v != %v", name, a[name], b[name]) + return false + } + } + return true +} + +// dnsResponse is a test helper that creates a DNS response buffer for the given domain and address +func dnsResponse(domain, address string) []byte { + addr := netip.MustParseAddr(address) + b := dnsmessage.NewBuilder(nil, dnsmessage.Header{}) + b.EnableCompression() + b.StartAnswers() + switch addr.BitLen() { + case 32: + b.AResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AResource{ + A: addr.As4(), + }, + ) + case 128: + b.AAAAResource( + dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + TTL: 0, + }, + dnsmessage.AAAAResource{ + AAAA: addr.As16(), + }, + ) + default: + panic("invalid address length") + } + return must.Get(b.Finish()) +} + +func TestSetExitNodeIDPolicy(t *testing.T) { + zeroValHostinfoView := new(tailcfg.Hostinfo).View() + pfx := netip.MustParsePrefix + tests := []struct { + name string + exitNodeIPKey bool + exitNodeIDKey bool + exitNodeID string + exitNodeIP string + prefs *ipn.Prefs + exitNodeIPWant string + exitNodeIDWant string + autoExitNodeWant ipn.ExitNodeExpression + prefsChanged bool + nm *netmap.NetworkMap + lastSuggestedExitNode tailcfg.StableNodeID + }{ + { + name: "ExitNodeID key is set", + exitNodeIDKey: true, + exitNodeID: "123", + exitNodeIDWant: "123", + prefsChanged: true, + }, + { + name: "ExitNodeID key not set", + exitNodeIDKey: true, + exitNodeIDWant: "", + prefsChanged: false, + }, + { + name: "ExitNodeID key set, ExitNodeIP preference set", + exitNodeIDKey: true, + exitNodeID: "123", + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIDWant: "123", + prefsChanged: true, + }, + { + name: "ExitNodeID key not set, ExitNodeIP key set", + exitNodeIPKey: true, + exitNodeIP: "127.0.0.1", + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIPWant: "127.0.0.1", + prefsChanged: false, + }, + { + name: "ExitNodeIP key set, existing ExitNodeIP pref", + exitNodeIPKey: true, + exitNodeIP: "127.0.0.1", + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIPWant: "127.0.0.1", + prefsChanged: false, + }, + { + name: "existing preferences match policy", + exitNodeIDKey: true, + exitNodeID: "123", + prefs: &ipn.Prefs{ExitNodeID: tailcfg.StableNodeID("123")}, + exitNodeIDWant: "123", + prefsChanged: false, + }, + { + name: "ExitNodeIP set if net map does not have corresponding node", + exitNodeIPKey: true, + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIP: "127.0.0.1", + exitNodeIPWant: "127.0.0.1", + prefsChanged: false, + nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 201, + Name: "a.tailnet", + Key: makeNodeKeyFromID(201), + Addresses: []netip.Prefix{ + pfx("100.0.0.201/32"), + pfx("100::201/128"), + }, + }).View(), + (&tailcfg.Node{ + ID: 202, + Name: "b.tailnet", + Key: makeNodeKeyFromID(202), + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + }).View(), + }, + }, + }, + { + name: "ExitNodeIP cleared if net map has corresponding node - policy matches prefs", + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIPKey: true, + exitNodeIP: "127.0.0.1", + exitNodeIPWant: "", + exitNodeIDWant: "123", + prefsChanged: true, + nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 123, + Name: "a.tailnet", + StableID: tailcfg.StableNodeID("123"), + Key: makeNodeKeyFromID(123), + Addresses: []netip.Prefix{ + pfx("127.0.0.1/32"), + pfx("100::201/128"), + }, + Hostinfo: zeroValHostinfoView, + }).View(), + (&tailcfg.Node{ + ID: 202, + Name: "b.tailnet", + Key: makeNodeKeyFromID(202), + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + Hostinfo: zeroValHostinfoView, + }).View(), + }, + }, + }, + { + name: "ExitNodeIP cleared if net map has corresponding node - no policy set", + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIPWant: "", + exitNodeIDWant: "123", + prefsChanged: true, + nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 123, + Name: "a.tailnet", + StableID: tailcfg.StableNodeID("123"), + Key: makeNodeKeyFromID(123), + Addresses: []netip.Prefix{ + pfx("127.0.0.1/32"), + pfx("100::201/128"), + }, + Hostinfo: zeroValHostinfoView, + }).View(), + (&tailcfg.Node{ + ID: 202, + Name: "b.tailnet", + Key: makeNodeKeyFromID(202), + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + Hostinfo: zeroValHostinfoView, + }).View(), + }, + }, + }, + { + name: "ExitNodeIP cleared if net map has corresponding node - different exit node IP in policy", + exitNodeIPKey: true, + prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, + exitNodeIP: "100.64.5.6", + exitNodeIPWant: "", + exitNodeIDWant: "123", + prefsChanged: true, + nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", + Addresses: []netip.Prefix{ + pfx("100.102.103.104/32"), + pfx("100::123/128"), + }, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 123, + Name: "a.tailnet", + StableID: tailcfg.StableNodeID("123"), + Key: makeNodeKeyFromID(123), + Addresses: []netip.Prefix{ + pfx("100.64.5.6/32"), + pfx("100::201/128"), + }, + Hostinfo: zeroValHostinfoView, + }).View(), + (&tailcfg.Node{ + ID: 202, + Name: "b.tailnet", + Key: makeNodeKeyFromID(202), + Addresses: []netip.Prefix{ + pfx("100::202/128"), + }, + Hostinfo: zeroValHostinfoView, + }).View(), + }, + }, + }, + { + name: "ExitNodeID key is set to auto:any and last suggested exit node is populated", + exitNodeIDKey: true, + exitNodeID: "auto:any", + lastSuggestedExitNode: "123", + exitNodeIDWant: "123", + autoExitNodeWant: "any", + prefsChanged: true, + }, + { + name: "ExitNodeID key is set to auto:any and last suggested exit node is not populated", + exitNodeIDKey: true, + exitNodeID: "auto:any", + exitNodeIDWant: "auto:any", + autoExitNodeWant: "any", + prefsChanged: true, + }, + { + name: "ExitNodeID key is set to auto:foo and last suggested exit node is populated", + exitNodeIDKey: true, + exitNodeID: "auto:foo", + lastSuggestedExitNode: "123", + exitNodeIDWant: "123", + autoExitNodeWant: "foo", + prefsChanged: true, + }, + { + name: "ExitNodeID key is set to auto:foo and last suggested exit node is not populated", + exitNodeIDKey: true, + exitNodeID: "auto:foo", + exitNodeIDWant: "auto:any", // should be "auto:any" for compatibility with existing clients + autoExitNodeWant: "foo", + prefsChanged: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var polc policytest.Config + if test.exitNodeIDKey { + polc.Set(pkey.ExitNodeID, test.exitNodeID) + } + if test.exitNodeIPKey { + polc.Set(pkey.ExitNodeIP, test.exitNodeIP) + } + b := newTestBackend(t, polc) + + if test.nm == nil { + test.nm = new(netmap.NetworkMap) + } + if test.prefs == nil { + test.prefs = ipn.NewPrefs() + } + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) + pm.prefs = test.prefs.View() + b.currentNode().SetNetMap(test.nm) + b.pm = pm + b.lastSuggestedExitNode = test.lastSuggestedExitNode + prefs := b.pm.prefs.AsStruct() + if changed := b.reconcilePrefsLocked(prefs); changed != test.prefsChanged { + t.Errorf("wanted prefs changed %v, got prefs changed %v", test.prefsChanged, changed) + } + + // Both [LocalBackend.SetPrefsForTest] and [LocalBackend.EditPrefs] + // apply syspolicy settings to the current profile's preferences. Therefore, + // we pass the current, unmodified preferences and expect the effective + // preferences to change. + b.SetPrefsForTest(pm.CurrentPrefs().AsStruct()) + + if got := b.Prefs().ExitNodeID(); got != tailcfg.StableNodeID(test.exitNodeIDWant) { + t.Errorf("ExitNodeID: got %q; want %q", got, test.exitNodeIDWant) + } + if got := b.Prefs().ExitNodeIP(); test.exitNodeIPWant == "" { + if got.String() != "invalid IP" { + t.Errorf("ExitNodeIP: got %v want invalid IP", got) + } + } else if got.String() != test.exitNodeIPWant { + t.Errorf("ExitNodeIP: got %q; want %q", got, test.exitNodeIPWant) + } + if got := b.Prefs().AutoExitNode(); got != test.autoExitNodeWant { + t.Errorf("AutoExitNode: got %q; want %q", got, test.autoExitNodeWant) + } + }) + } +} + +func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { + peer1 := makePeer(1, withCap(26), withSuggest(), withOnline(true), withExitRoutes()) + peer2 := makePeer(2, withCap(26), withSuggest(), withOnline(true), withExitRoutes()) + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t1", + RegionID: 1, + }, + }, + }, + 2: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t2", + RegionID: 2, + }, + }, + }, + }, + } + report := &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 5 * time.Millisecond, + 3: 30 * time.Millisecond, + }, + PreferredDERP: 2, + } + tests := []struct { + name string + lastSuggestedExitNode tailcfg.StableNodeID + netmap *netmap.NetworkMap + muts []*tailcfg.PeerChange + exitNodeIDWant tailcfg.StableNodeID + report *netcheck.Report + }{ + { + // selected auto exit node goes offline + name: "exit-node-goes-offline", + // PreferredDERP is 2, and it's also the region with the lowest latency. + // So, peer2 should be selected as the exit node. + lastSuggestedExitNode: peer2.StableID(), + netmap: &netmap.NetworkMap{ + Peers: []tailcfg.NodeView{ + peer1, + peer2, + }, + DERPMap: derpMap, + }, + muts: []*tailcfg.PeerChange{ + { + NodeID: 1, + Online: ptr.To(true), + }, + { + NodeID: 2, + Online: ptr.To(false), // the selected exit node goes offline + }, + }, + exitNodeIDWant: peer1.StableID(), + report: report, + }, + { + // other exit node goes offline doesn't change selected auto exit node that's still online + name: "other-node-goes-offline", + lastSuggestedExitNode: peer2.StableID(), + netmap: &netmap.NetworkMap{ + Peers: []tailcfg.NodeView{ + peer1, + peer2, + }, + DERPMap: derpMap, + }, + muts: []*tailcfg.PeerChange{ + { + NodeID: 1, + Online: ptr.To(false), // a different exit node goes offline + }, + { + NodeID: 2, + Online: ptr.To(true), + }, + }, + exitNodeIDWant: peer2.StableID(), + report: report, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sys := tsd.NewSystem() + sys.PolicyClient.Set(policytest.Config{ + pkey.ExitNodeID: "auto:any", + }) + b := newTestLocalBackendWithSys(t, sys) + b.currentNode().SetNetMap(tt.netmap) + b.lastSuggestedExitNode = tt.lastSuggestedExitNode + b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, tt.report) + b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) + + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + someTime := time.Unix(123, 0) + muts, ok := netmap.MutationsFromMapResponse(&tailcfg.MapResponse{ + PeersChangedPatch: tt.muts, + }, someTime) + if !ok { + t.Fatal("netmap.MutationsFromMapResponse failed") + } + + if b.pm.prefs.ExitNodeID() != tt.lastSuggestedExitNode { + t.Fatalf("did not set exit node ID to last suggested exit node despite auto policy") + } + + was := b.goTracker.StartedGoroutines() + got := b.UpdateNetmapDelta(muts) + if !got { + t.Error("got false from UpdateNetmapDelta") + } + startedGoroutine := b.goTracker.StartedGoroutines() != was + + wantChange := tt.exitNodeIDWant != tt.lastSuggestedExitNode + if startedGoroutine != wantChange { + t.Errorf("got startedGoroutine %v, want %v", startedGoroutine, wantChange) + } + if startedGoroutine { + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } + } + b.mu.Lock() + gotExitNode := b.pm.prefs.ExitNodeID() + b.mu.Unlock() + if gotExitNode != tt.exitNodeIDWant { + t.Fatalf("exit node ID after UpdateNetmapDelta = %v; want %v", gotExitNode, tt.exitNodeIDWant) + } + }) + } +} + +func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { + polc := policytest.Config{ + pkey.ExitNodeID: "auto:any", + } + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + + b := newTestLocalBackendWithSys(t, sys) + hi := hostinfo.New() + ni := tailcfg.NetInfo{LinkType: "wired"} + hi.NetInfo = &ni + b.hostinfo = hi + k := key.NewMachine() + var cc *mockControl + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(sys.Bus.Get()) + opts := controlclient.Options{ + ServerURL: "https://example.com", + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return k, nil + }, + Dialer: dialer, + Logf: b.logf, + PolicyClient: polc, + } + cc = newClient(t, opts) + b.cc = cc + peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes()) + peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes()) + selfNode := tailcfg.Node{ + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.1.1/32"), + netip.MustParsePrefix("fe70::1/128"), + }, + HomeDERP: 2, + } + defaultDERPMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t1", + RegionID: 1, + }, + }, + }, + 2: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t2", + RegionID: 2, + }, + }, + }, + 3: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t3", + RegionID: 3, + }, + }, + }, + }, + } + b.currentNode().SetNetMap(&netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + peer1, + peer2, + }, + DERPMap: defaultDERPMap, + }) + b.lastSuggestedExitNode = peer1.StableID() + b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) + if eid := b.Prefs().ExitNodeID(); eid != peer1.StableID() { + t.Errorf("got initial exit node %v, want %v", eid, peer1.StableID()) + } + b.refreshAutoExitNode = true + b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 5 * time.Millisecond, + 3: 30 * time.Millisecond, + }, + PreferredDERP: 2, + }) + b.setNetInfo(&ni) + if eid := b.Prefs().ExitNodeID(); eid != peer2.StableID() { + t.Errorf("got final exit node %v, want %v", eid, peer2.StableID()) + } +} + +func TestSetControlClientStatusAutoExitNode(t *testing.T) { + peer1 := makePeer(1, withCap(26), withSuggest(), withExitRoutes(), withOnline(true), withNodeKey()) + peer2 := makePeer(2, withCap(26), withSuggest(), withExitRoutes(), withOnline(true), withNodeKey()) + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t1", + RegionID: 1, + }, + }, + }, + 2: { + Nodes: []*tailcfg.DERPNode{ + { + Name: "t2", + RegionID: 2, + }, + }, + }, + }, + } + report := &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 5 * time.Millisecond, + 3: 30 * time.Millisecond, + }, + PreferredDERP: 1, + } + nm := &netmap.NetworkMap{ + Peers: []tailcfg.NodeView{ + peer1, + peer2, + }, + DERPMap: derpMap, + } + + polc := policytest.Config{ + pkey.ExitNodeID: "auto:any", + } + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + + b := newTestLocalBackendWithSys(t, sys) + b.currentNode().SetNetMap(nm) + // Peer 2 should be the initial exit node, as it's better than peer 1 + // in terms of latency and DERP region. + b.lastSuggestedExitNode = peer2.StableID() + b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, report) + b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) + offlinePeer2 := makePeer(2, withCap(26), withSuggest(), withExitRoutes(), withOnline(false), withNodeKey()) + updatedNetmap := &netmap.NetworkMap{ + Peers: []tailcfg.NodeView{ + peer1, + offlinePeer2, + }, + DERPMap: derpMap, + } + b.SetControlClientStatus(b.cc, controlclient.Status{NetMap: updatedNetmap}) + // But now that peer 2 is offline, we should switch to peer 1. + wantExitNode := peer1.StableID() + gotExitNode := b.Prefs().ExitNodeID() + if gotExitNode != wantExitNode { + t.Errorf("did not switch exit nodes despite auto exit node going offline: got %q; want %q", gotExitNode, wantExitNode) + } +} + +func TestApplySysPolicy(t *testing.T) { + tests := []struct { + name string + prefs ipn.Prefs + wantPrefs ipn.Prefs + wantAnyChange bool + stringPolicies map[pkey.Key]string + }{ + { + name: "empty prefs without policies", + }, + { + name: "prefs set without policies", + prefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + ExitNodeAllowLANAccess: true, + CorpDNS: true, + RouteAll: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + ExitNodeAllowLANAccess: true, + CorpDNS: true, + RouteAll: true, + }, + }, + { + name: "empty prefs with policies", + wantPrefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + ExitNodeAllowLANAccess: true, + CorpDNS: true, + RouteAll: true, + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.ControlURL: "1", + pkey.EnableIncomingConnections: "never", + pkey.EnableServerMode: "always", + pkey.ExitNodeAllowLANAccess: "always", + pkey.EnableTailscaleDNS: "always", + pkey.EnableTailscaleSubnets: "always", + }, + }, + { + name: "prefs set with matching policies", + prefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + }, + stringPolicies: map[pkey.Key]string{ + pkey.ControlURL: "1", + pkey.EnableIncomingConnections: "never", + pkey.EnableServerMode: "always", + pkey.ExitNodeAllowLANAccess: "never", + pkey.EnableTailscaleDNS: "never", + pkey.EnableTailscaleSubnets: "never", + }, + }, + { + name: "prefs set with conflicting policies", + prefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + ExitNodeAllowLANAccess: false, + CorpDNS: true, + RouteAll: false, + }, + wantPrefs: ipn.Prefs{ + ControlURL: "2", + ShieldsUp: false, + ForceDaemon: false, + ExitNodeAllowLANAccess: true, + CorpDNS: false, + RouteAll: true, + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.ControlURL: "2", + pkey.EnableIncomingConnections: "always", + pkey.EnableServerMode: "never", + pkey.ExitNodeAllowLANAccess: "always", + pkey.EnableTailscaleDNS: "never", + pkey.EnableTailscaleSubnets: "always", + }, + }, + { + name: "prefs set with neutral policies", + prefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + ExitNodeAllowLANAccess: false, + CorpDNS: true, + RouteAll: true, + }, + wantPrefs: ipn.Prefs{ + ControlURL: "1", + ShieldsUp: true, + ForceDaemon: true, + ExitNodeAllowLANAccess: false, + CorpDNS: true, + RouteAll: true, + }, + stringPolicies: map[pkey.Key]string{ + pkey.EnableIncomingConnections: "user-decides", + pkey.EnableServerMode: "user-decides", + pkey.ExitNodeAllowLANAccess: "user-decides", + pkey.EnableTailscaleDNS: "user-decides", + pkey.EnableTailscaleSubnets: "user-decides", + }, + }, + { + name: "ControlURL", + wantPrefs: ipn.Prefs{ + ControlURL: "set", + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.ControlURL: "set", + }, + }, + { + name: "enable AutoUpdate apply does not unset check", + prefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + Apply: opt.NewBool(false), + }, + }, + wantPrefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + Apply: opt.NewBool(true), + }, + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.ApplyUpdates: "always", + }, + }, + { + name: "disable AutoUpdate apply does not unset check", + prefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + Apply: opt.NewBool(true), + }, + }, + wantPrefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + Apply: opt.NewBool(false), + }, + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.ApplyUpdates: "never", + }, + }, + { + name: "enable AutoUpdate check does not unset apply", + prefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: false, + Apply: opt.NewBool(true), + }, + }, + wantPrefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + Apply: opt.NewBool(true), + }, + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.CheckUpdates: "always", + }, + }, + { + name: "disable AutoUpdate check does not unset apply", + prefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + Apply: opt.NewBool(true), + }, + }, + wantPrefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: false, + Apply: opt.NewBool(true), + }, + }, + wantAnyChange: true, + stringPolicies: map[pkey.Key]string{ + pkey.CheckUpdates: "never", + }, + }, } - appCfg := `{ - "name": "example", - "domains": ["example.com"], - "connectors": ["tag:example"] - }` + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var polc policytest.Config + for k, v := range tt.stringPolicies { + polc.Set(k, v) + } - b.netMap.SelfNode = (&tailcfg.Node{ - Name: "example.ts.net", - Tags: []string{"tag:example"}, - CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ - "tailscale.com/app-connectors": {tailcfg.RawMessage(appCfg)}, - }), - }).View() + t.Run("unit", func(t *testing.T) { + prefs := tt.prefs.Clone() - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) - b.appConnector.Wait(context.Background()) + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) - want := []string{"example.com"} - if !slices.Equal(b.appConnector.Domains().AsSlice(), want) { - t.Fatalf("got domains %v, want %v", b.appConnector.Domains(), want) - } - if v, _ := b.hostinfo.AppConnector.Get(); !v { - t.Fatalf("expected app connector service") - } + lb := newTestLocalBackendWithSys(t, sys) + gotAnyChange := lb.applySysPolicyLocked(prefs) - // disable the connector in order to assert that the service is removed - b.EditPrefs(&ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - AppConnector: ipn.AppConnectorPrefs{ - Advertise: false, - }, - }, - AppConnectorSet: true, - }) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) - if b.appConnector != nil { - t.Fatal("expected no app connector") - } - if v, _ := b.hostinfo.AppConnector.Get(); v { - t.Fatalf("expected no app connector service") - } -} + if gotAnyChange && prefs.Equals(&tt.prefs) { + t.Errorf("anyChange but prefs is unchanged: %v", prefs.Pretty()) + } + if !gotAnyChange && !prefs.Equals(&tt.prefs) { + t.Errorf("!anyChange but prefs changed from %v to %v", tt.prefs.Pretty(), prefs.Pretty()) + } + if gotAnyChange != tt.wantAnyChange { + t.Errorf("anyChange=%v, want %v", gotAnyChange, tt.wantAnyChange) + } + if !prefs.Equals(&tt.wantPrefs) { + t.Errorf("prefs=%v, want %v", prefs.Pretty(), tt.wantPrefs.Pretty()) + } + }) -func resolversEqual(t *testing.T, a, b []*dnstype.Resolver) bool { - if a == nil && b == nil { - return true - } - if a == nil || b == nil { - t.Errorf("resolversEqual: a == nil || b == nil : %#v != %#v", a, b) - return false - } - if len(a) != len(b) { - t.Errorf("resolversEqual: len(a) != len(b) : %#v != %#v", a, b) - return false - } - for i := range a { - if !a[i].Equal(b[i]) { - t.Errorf("resolversEqual: a != b [%d]: %v != %v", i, *a[i], *b[i]) - return false - } - } - return true -} + t.Run("status update", func(t *testing.T) { + // Profile manager fills in blank ControlURL but it's not set + // in most test cases to avoid cluttering them, so adjust for + // that. + usePrefs := tt.prefs.Clone() + if usePrefs.ControlURL == "" { + usePrefs.ControlURL = ipn.DefaultControlURL + } + wantPrefs := tt.wantPrefs.Clone() + if wantPrefs.ControlURL == "" { + wantPrefs.ControlURL = ipn.DefaultControlURL + } -func routesEqual(t *testing.T, a, b map[dnsname.FQDN][]*dnstype.Resolver) bool { - if len(a) != len(b) { - t.Logf("routes: len(a) != len(b): %d != %d", len(a), len(b)) - return false - } - for name := range a { - if !resolversEqual(t, a[name], b[name]) { - t.Logf("routes: a != b [%s]: %v != %v", name, a[name], b[name]) - return false - } + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) + pm.prefs = usePrefs.View() + + b := newTestBackend(t, polc) + b.mu.Lock() + b.pm = pm + b.mu.Unlock() + + b.SetControlClientStatus(b.cc, controlclient.Status{}) + if !b.Prefs().Equals(wantPrefs.View()) { + t.Errorf("prefs=%v, want %v", b.Prefs().Pretty(), wantPrefs.Pretty()) + } + }) + }) } - return true } -// dnsResponse is a test helper that creates a DNS response buffer for the given domain and address -func dnsResponse(domain, address string) []byte { - addr := netip.MustParseAddr(address) - b := dnsmessage.NewBuilder(nil, dnsmessage.Header{}) - b.EnableCompression() - b.StartAnswers() - switch addr.BitLen() { - case 32: - b.AResource( - dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeA, - Class: dnsmessage.ClassINET, - TTL: 0, - }, - dnsmessage.AResource{ - A: addr.As4(), - }, - ) - case 128: - b.AAAAResource( - dnsmessage.ResourceHeader{ - Name: dnsmessage.MustNewName(domain), - Type: dnsmessage.TypeAAAA, - Class: dnsmessage.ClassINET, - TTL: 0, - }, - dnsmessage.AAAAResource{ - AAAA: addr.As16(), - }, - ) - default: - panic("invalid address length") +func TestPreferencePolicyInfo(t *testing.T) { + tests := []struct { + name string + initialValue bool + wantValue bool + wantChange bool + policyValue string + policyError error + }{ + { + name: "force enable modify", + initialValue: false, + wantValue: true, + wantChange: true, + policyValue: "always", + }, + { + name: "force enable unchanged", + initialValue: true, + wantValue: true, + policyValue: "always", + }, + { + name: "force disable modify", + initialValue: true, + wantValue: false, + wantChange: true, + policyValue: "never", + }, + { + name: "force disable unchanged", + initialValue: false, + wantValue: false, + policyValue: "never", + }, + { + name: "unforced enabled", + initialValue: true, + wantValue: true, + policyValue: "user-decides", + }, + { + name: "unforced disabled", + initialValue: false, + wantValue: false, + policyValue: "user-decides", + }, + { + name: "blank enabled", + initialValue: true, + wantValue: true, + policyValue: "", + }, + { + name: "blank disabled", + initialValue: false, + wantValue: false, + policyValue: "", + }, + { + name: "unset enabled", + initialValue: true, + wantValue: true, + policyError: syspolicy.ErrNoSuchKey, + }, + { + name: "unset disabled", + initialValue: false, + wantValue: false, + policyError: syspolicy.ErrNoSuchKey, + }, + { + name: "error enabled", + initialValue: true, + wantValue: true, + policyError: errors.New("test error"), + }, + { + name: "error disabled", + initialValue: false, + wantValue: false, + policyError: errors.New("test error"), + }, } - return must.Get(b.Finish()) -} -type errorSyspolicyHandler struct { - t *testing.T - err error - key syspolicy.Key - allowKeys map[syspolicy.Key]*string -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for _, pp := range preferencePolicies { + t.Run(string(pp.key), func(t *testing.T) { + t.Parallel() -func (h *errorSyspolicyHandler) ReadString(key string) (string, error) { - sk := syspolicy.Key(key) - if _, ok := h.allowKeys[sk]; !ok { - h.t.Errorf("ReadString: %q is not in list of permitted keys", h.key) - } - if sk == h.key { - return "", h.err - } - return "", syspolicy.ErrNoSuchKey -} + var polc policytest.Config + if tt.policyError != nil { + polc.Set(pp.key, tt.policyError) + } else { + polc.Set(pp.key, tt.policyValue) + } -func (h *errorSyspolicyHandler) ReadUInt64(key string) (uint64, error) { - h.t.Errorf("ReadUInt64(%q) unexpectedly called", key) - return 0, syspolicy.ErrNoSuchKey -} + prefs := defaultPrefs.AsStruct() + pp.set(prefs, tt.initialValue) -func (h *errorSyspolicyHandler) ReadBoolean(key string) (bool, error) { - h.t.Errorf("ReadBoolean(%q) unexpectedly called", key) - return false, syspolicy.ErrNoSuchKey -} + bus := eventbustest.NewBus(t) + sys := tsd.NewSystemWithBus(bus) + sys.PolicyClient.Set(polc) -func (h *errorSyspolicyHandler) ReadStringArray(key string) ([]string, error) { - h.t.Errorf("ReadStringArray(%q) unexpectedly called", key) - return nil, syspolicy.ErrNoSuchKey -} + lb := newTestLocalBackendWithSys(t, sys) + gotAnyChange := lb.applySysPolicyLocked(prefs) -type mockSyspolicyHandler struct { - t *testing.T - // stringPolicies is the collection of policies that we expect to see - // queried by the current test. If the policy is expected but unset, then - // use nil, otherwise use a string equal to the policy's desired value. - stringPolicies map[syspolicy.Key]*string - // stringArrayPolicies is the collection of policies that we expected to see - // queries by the current test, that return policy string arrays. - stringArrayPolicies map[syspolicy.Key][]string - // failUnknownPolicies is set if policies other than those in stringPolicies - // (uint64 or bool policies are not supported by mockSyspolicyHandler yet) - // should be considered a test failure if they are queried. - failUnknownPolicies bool + if gotAnyChange != tt.wantChange { + t.Errorf("anyChange=%v, want %v", gotAnyChange, tt.wantChange) + } + got := pp.get(prefs.View()) + if got != tt.wantValue { + t.Errorf("pref=%v, want %v", got, tt.wantValue) + } + }) + } + }) + } } -func (h *mockSyspolicyHandler) ReadString(key string) (string, error) { - if s, ok := h.stringPolicies[syspolicy.Key(key)]; ok { - if s == nil { - return "", syspolicy.ErrNoSuchKey - } - return *s, nil +func TestOnTailnetDefaultAutoUpdate(t *testing.T) { + tests := []struct { + before, after opt.Bool + container opt.Bool + tailnetDefault bool + }{ + { + before: opt.Bool(""), + tailnetDefault: true, + after: opt.NewBool(true), + }, + { + before: opt.Bool(""), + tailnetDefault: false, + after: opt.NewBool(false), + }, + { + before: opt.Bool("unset"), + tailnetDefault: true, + after: opt.NewBool(true), + }, + { + before: opt.Bool("unset"), + tailnetDefault: false, + after: opt.NewBool(false), + }, + { + before: opt.NewBool(false), + tailnetDefault: true, + after: opt.NewBool(false), + }, + { + before: opt.NewBool(true), + tailnetDefault: false, + after: opt.NewBool(true), + }, + { + before: opt.Bool(""), + container: opt.NewBool(true), + tailnetDefault: true, + after: opt.Bool(""), + }, + { + before: opt.NewBool(false), + container: opt.NewBool(true), + tailnetDefault: true, + after: opt.NewBool(false), + }, + { + before: opt.NewBool(true), + container: opt.NewBool(true), + tailnetDefault: false, + after: opt.NewBool(true), + }, } - if h.failUnknownPolicies { - h.t.Errorf("ReadString(%q) unexpectedly called", key) + for _, tt := range tests { + t.Run(fmt.Sprintf("before=%s,after=%s", tt.before, tt.after), func(t *testing.T) { + b := newTestBackend(t) + b.hostinfo = hostinfo.New() + b.hostinfo.Container = tt.container + p := ipn.NewPrefs() + p.AutoUpdate.Apply = tt.before + if err := b.pm.setPrefsNoPermCheck(p.View()); err != nil { + t.Fatal(err) + } + b.onTailnetDefaultAutoUpdate(tt.tailnetDefault) + want := tt.after + // On platforms that don't support auto-update we can never + // transition to auto-updates being enabled. The value should + // remain unchanged after onTailnetDefaultAutoUpdate. + if !feature.CanAutoUpdate() { + want = tt.before + } + if got := b.pm.CurrentPrefs().AutoUpdate().Apply; got != want { + t.Errorf("got: %q, want %q", got, want) + } + }) } - return "", syspolicy.ErrNoSuchKey } -func (h *mockSyspolicyHandler) ReadUInt64(key string) (uint64, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadUInt64(%q) unexpectedly called", key) +func TestTCPHandlerForDst(t *testing.T) { + b := newTestBackend(t) + tests := []struct { + desc string + dst string + intercept bool + }{ + { + desc: "intercept port 80 (Web UI) on quad100 IPv4", + dst: "100.100.100.100:80", + intercept: true, + }, + { + desc: "intercept port 80 (Web UI) on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:80", + intercept: true, + }, + { + desc: "don't intercept port 80 on local ip", + dst: "100.100.103.100:80", + intercept: false, + }, + { + desc: "intercept port 8080 (Taildrive) on quad100 IPv4", + dst: "[fd7a:115c:a1e0::53]:8080", + intercept: true, + }, + { + desc: "don't intercept port 8080 on local ip", + dst: "100.100.103.100:8080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on quad100 IPv4", + dst: "100.100.100.100:9080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:9080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on local ip", + dst: "100.100.103.100:9080", + intercept: false, + }, + } + for _, tt := range tests { + t.Run(tt.dst, func(t *testing.T) { + t.Log(tt.desc) + src := netip.MustParseAddrPort("100.100.102.100:51234") + h, _ := b.TCPHandlerForDst(src, netip.MustParseAddrPort(tt.dst)) + if !tt.intercept && h != nil { + t.Error("intercepted traffic we shouldn't have") + } else if tt.intercept && h == nil { + t.Error("failed to intercept traffic we should have") + } + }) } - return 0, syspolicy.ErrNoSuchKey } -func (h *mockSyspolicyHandler) ReadBoolean(key string) (bool, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadBoolean(%q) unexpectedly called", key) +func TestTCPHandlerForDstWithVIPService(t *testing.T) { + b := newTestBackend(t) + svcIPMap := tailcfg.ServiceIPMappings{ + "svc:foo": []netip.Addr{ + netip.MustParseAddr("100.101.101.101"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:6565:6565"), + }, + "svc:bar": []netip.Addr{ + netip.MustParseAddr("100.99.99.99"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:626b:628b"), + }, + "svc:baz": []netip.Addr{ + netip.MustParseAddr("100.133.133.133"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:8585:8585"), + }, + } + svcIPMapJSON, err := json.Marshal(svcIPMap) + if err != nil { + t.Fatal(err) + } + b.setNetMapLocked( + &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{tailcfg.RawMessage(svcIPMapJSON)}, + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + LoginName: "someone@example.com", + DisplayName: "Some One", + ProfilePicURL: "https://example.com/photo.jpg", + }).View(), + }, + }, + ) + + err = b.setServeConfigLocked( + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 882: {HTTP: true}, + 883: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.example.ts.net:882": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://127.0.0.1:3000"}, + }, + }, + "foo.example.ts.net:883": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Text: "test"}, + }, + }, + }, + }, + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 990: {TCPForward: "127.0.0.1:8443"}, + 991: {TCPForward: "127.0.0.1:5432", TerminateTLS: "bar.test.ts.net"}, + }, + }, + "svc:qux": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 600: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "qux.example.ts.net:600": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Text: "qux"}, + }, + }, + }, + }, + }, + }, + "", + ) + if err != nil { + t.Fatal(err) + } + + tests := []struct { + desc string + dst string + intercept bool + }{ + { + desc: "intercept port 80 (Web UI) on quad100 IPv4", + dst: "100.100.100.100:80", + intercept: true, + }, + { + desc: "intercept port 80 (Web UI) on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:80", + intercept: true, + }, + { + desc: "don't intercept port 80 on local ip", + dst: "100.100.103.100:80", + intercept: false, + }, + { + desc: "intercept port 8080 (Taildrive) on quad100 IPv4", + dst: "100.100.100.100:8080", + intercept: true, + }, + { + desc: "intercept port 8080 (Taildrive) on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:8080", + intercept: true, + }, + { + desc: "don't intercept port 8080 on local ip", + dst: "100.100.103.100:8080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on quad100 IPv4", + dst: "100.100.100.100:9080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:9080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on local ip", + dst: "100.100.103.100:9080", + intercept: false, + }, + // VIP service destinations + { + desc: "intercept port 882 (HTTP) on service foo IPv4", + dst: "100.101.101.101:882", + intercept: true, + }, + { + desc: "intercept port 882 (HTTP) on service foo IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:882", + intercept: true, + }, + { + desc: "intercept port 883 (HTTPS) on service foo IPv4", + dst: "100.101.101.101:883", + intercept: true, + }, + { + desc: "intercept port 883 (HTTPS) on service foo IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:883", + intercept: true, + }, + { + desc: "intercept port 990 (TCPForward) on service bar IPv4", + dst: "100.99.99.99:990", + intercept: true, + }, + { + desc: "intercept port 990 (TCPForward) on service bar IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:990", + intercept: true, + }, + { + desc: "intercept port 991 (TCPForward with TerminateTLS) on service bar IPv4", + dst: "100.99.99.99:990", + intercept: true, + }, + { + desc: "intercept port 991 (TCPForward with TerminateTLS) on service bar IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:990", + intercept: true, + }, + { + desc: "don't intercept port 4444 on service foo IPv4", + dst: "100.101.101.101:4444", + intercept: false, + }, + { + desc: "don't intercept port 4444 on service foo IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:4444", + intercept: false, + }, + { + desc: "don't intercept port 600 on unknown service IPv4", + dst: "100.22.22.22:883", + intercept: false, + }, + { + desc: "don't intercept port 600 on unknown service IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:883", + intercept: false, + }, + { + desc: "don't intercept port 600 (HTTPS) on service baz IPv4", + dst: "100.133.133.133:600", + intercept: false, + }, + { + desc: "don't intercept port 600 (HTTPS) on service baz IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:8585:8585]:600", + intercept: false, + }, } - return false, syspolicy.ErrNoSuchKey -} -func (h *mockSyspolicyHandler) ReadStringArray(key string) ([]string, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadStringArray(%q) unexpectedly called", key) - } - if s, ok := h.stringArrayPolicies[syspolicy.Key(key)]; ok { - if s == nil { - return []string{}, syspolicy.ErrNoSuchKey - } - return s, nil + for _, tt := range tests { + t.Run(tt.dst, func(t *testing.T) { + t.Log(tt.desc) + src := netip.MustParseAddrPort("100.100.102.100:51234") + h, _ := b.TCPHandlerForDst(src, netip.MustParseAddrPort(tt.dst)) + if !tt.intercept && h != nil { + t.Error("intercepted traffic we shouldn't have") + } else if tt.intercept && h == nil { + t.Error("failed to intercept traffic we should have") + } + }) } - return nil, syspolicy.ErrNoSuchKey } -func TestSetExitNodeIDPolicy(t *testing.T) { - pfx := netip.MustParsePrefix +func TestDriveManageShares(t *testing.T) { tests := []struct { - name string - exitNodeIPKey bool - exitNodeIDKey bool - exitNodeID string - exitNodeIP string - prefs *ipn.Prefs - exitNodeIPWant string - exitNodeIDWant string - prefsChanged bool - nm *netmap.NetworkMap - lastSuggestedExitNode tailcfg.StableNodeID + name string + disabled bool + existing []*drive.Share + add *drive.Share + remove string + rename [2]string + expect any }{ { - name: "ExitNodeID key is set", - exitNodeIDKey: true, - exitNodeID: "123", - exitNodeIDWant: "123", - prefsChanged: true, + name: "append", + existing: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " E "}, + expect: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + {Name: "e"}, + }, }, { - name: "ExitNodeID key not set", - exitNodeIDKey: true, - exitNodeIDWant: "", - prefsChanged: false, + name: "prepend", + existing: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " A "}, + expect: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, + {Name: "d"}, + }, }, { - name: "ExitNodeID key set, ExitNodeIP preference set", - exitNodeIDKey: true, - exitNodeID: "123", - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIDWant: "123", - prefsChanged: true, + name: "insert", + existing: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " C "}, + expect: []*drive.Share{ + {Name: "b"}, + {Name: "c"}, + {Name: "d"}, + }, }, { - name: "ExitNodeID key not set, ExitNodeIP key set", - exitNodeIPKey: true, - exitNodeIP: "127.0.0.1", - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIPWant: "127.0.0.1", - prefsChanged: false, + name: "replace", + existing: []*drive.Share{ + {Name: "b", Path: "i"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " B ", Path: "ii"}, + expect: []*drive.Share{ + {Name: "b", Path: "ii"}, + {Name: "d"}, + }, }, { - name: "ExitNodeIP key set, existing ExitNodeIP pref", - exitNodeIPKey: true, - exitNodeIP: "127.0.0.1", - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIPWant: "127.0.0.1", - prefsChanged: false, + name: "add_bad_name", + add: &drive.Share{Name: "$"}, + expect: drive.ErrInvalidShareName, }, { - name: "existing preferences match policy", - exitNodeIDKey: true, - exitNodeID: "123", - prefs: &ipn.Prefs{ExitNodeID: tailcfg.StableNodeID("123")}, - exitNodeIDWant: "123", - prefsChanged: false, + name: "add_disabled", + disabled: true, + add: &drive.Share{Name: "a"}, + expect: drive.ErrDriveNotEnabled, }, { - name: "ExitNodeIP set if net map does not have corresponding node", - exitNodeIPKey: true, - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIP: "127.0.0.1", - exitNodeIPWant: "127.0.0.1", - prefsChanged: false, - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - Addresses: []netip.Prefix{ - pfx("100.0.0.201/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, + name: "remove", + existing: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, + {Name: "c"}, + }, + remove: "b", + expect: []*drive.Share{ + {Name: "a"}, + {Name: "c"}, }, }, { - name: "ExitNodeIP cleared if net map has corresponding node - policy matches prefs", - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIPKey: true, - exitNodeIP: "127.0.0.1", - exitNodeIPWant: "", - exitNodeIDWant: "123", - prefsChanged: true, - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - StableID: tailcfg.StableNodeID("123"), - Addresses: []netip.Prefix{ - pfx("127.0.0.1/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, + name: "remove_non_existing", + existing: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, + {Name: "c"}, }, + remove: "D", + expect: os.ErrNotExist, }, { - name: "ExitNodeIP cleared if net map has corresponding node - no policy set", - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIPWant: "", - exitNodeIDWant: "123", - prefsChanged: true, - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - StableID: tailcfg.StableNodeID("123"), - Addresses: []netip.Prefix{ - pfx("127.0.0.1/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, + name: "remove_disabled", + disabled: true, + remove: "b", + expect: drive.ErrDriveNotEnabled, + }, + { + name: "rename", + existing: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, + }, + rename: [2]string{"a", " C "}, + expect: []*drive.Share{ + {Name: "b"}, + {Name: "c"}, }, }, { - name: "ExitNodeIP cleared if net map has corresponding node - different exit node IP in policy", - exitNodeIPKey: true, - prefs: &ipn.Prefs{ExitNodeIP: netip.MustParseAddr("127.0.0.1")}, - exitNodeIP: "100.64.5.6", - exitNodeIPWant: "", - exitNodeIDWant: "123", - prefsChanged: true, - nm: &netmap.NetworkMap{ - Name: "foo.tailnet", - SelfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{ - pfx("100.102.103.104/32"), - pfx("100::123/128"), - }, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - Name: "a.tailnet", - StableID: tailcfg.StableNodeID("123"), - Addresses: []netip.Prefix{ - pfx("100.64.5.6/32"), - pfx("100::201/128"), - }, - }).View(), - (&tailcfg.Node{ - Name: "b.tailnet", - Addresses: []netip.Prefix{ - pfx("100::202/128"), - }, - }).View(), - }, + name: "rename_not_exist", + existing: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, }, + rename: [2]string{"d", "c"}, + expect: os.ErrNotExist, }, { - name: "ExitNodeID key is set to auto and last suggested exit node is populated", - exitNodeIDKey: true, - exitNodeID: "auto:any", - lastSuggestedExitNode: "123", - exitNodeIDWant: "123", - prefsChanged: true, + name: "rename_exists", + existing: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, + }, + rename: [2]string{"a", "b"}, + expect: os.ErrExist, }, { - name: "ExitNodeID key is set to auto and last suggested exit node is not populated", - exitNodeIDKey: true, - exitNodeID: "auto:any", - prefsChanged: true, - exitNodeIDWant: "auto:any", + name: "rename_bad_name", + rename: [2]string{"a", "$"}, + expect: drive.ErrInvalidShareName, + }, + { + name: "rename_disabled", + disabled: true, + rename: [2]string{"a", "c"}, + expect: drive.ErrDriveNotEnabled, }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { + drive.DisallowShareAs = true + t.Cleanup(func() { + drive.DisallowShareAs = false + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { b := newTestBackend(t) - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: nil, - syspolicy.ExitNodeIP: nil, + b.mu.Lock() + if tt.existing != nil { + b.driveSetSharesLocked(tt.existing) + } + if !tt.disabled { + nm := ptr.To(*b.currentNode().NetMap()) + self := nm.SelfNode.AsStruct() + self.CapMap = tailcfg.NodeCapMap{tailcfg.NodeAttrsTaildriveShare: nil} + nm.SelfNode = self.View() + b.currentNode().SetNetMap(nm) + b.sys.Set(driveimpl.NewFileSystemForRemote(b.logf)) + } + b.mu.Unlock() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + result := make(chan views.SliceView[*drive.Share, drive.ShareView], 1) + + var wg sync.WaitGroup + wg.Add(1) + go b.WatchNotifications( + ctx, + 0, + func() { wg.Done() }, + func(n *ipn.Notify) bool { + select { + case result <- n.DriveShares: + default: + // + } + return false }, + ) + wg.Wait() + + var err error + switch { + case tt.add != nil: + err = b.DriveSetShare(tt.add) + case tt.remove != "": + err = b.DriveRemoveShare(tt.remove) + default: + err = b.DriveRenameShare(tt.rename[0], tt.rename[1]) } - if test.exitNodeIDKey { - msh.stringPolicies[syspolicy.ExitNodeID] = &test.exitNodeID + + switch e := tt.expect.(type) { + case error: + if !errors.Is(err, e) { + t.Errorf("expected error, want: %v got: %v", e, err) + } + case []*drive.Share: + if err != nil { + t.Errorf("unexpected error: %v", err) + } else { + r := <-result + + got, err := json.MarshalIndent(r, "", " ") + if err != nil { + t.Fatalf("can't marshal got: %v", err) + } + want, err := json.MarshalIndent(e, "", " ") + if err != nil { + t.Fatalf("can't marshal want: %v", err) + } + if diff := cmp.Diff(string(got), string(want)); diff != "" { + t.Errorf("wrong shares; (-got+want):%v", diff) + } + } } - if test.exitNodeIPKey { - msh.stringPolicies[syspolicy.ExitNodeIP] = &test.exitNodeIP + }) + } +} + +func TestValidPopBrowserURL(t *testing.T) { + b := newTestBackend(t) + tests := []struct { + desc string + controlURL string + popBrowserURL string + want bool + }{ + {"saas_login", "https://login.tailscale.com", "https://login.tailscale.com/a/foo", true}, + {"saas_controlplane", "https://controlplane.tailscale.com", "https://controlplane.tailscale.com/a/foo", true}, + {"saas_root", "https://login.tailscale.com", "https://tailscale.com/", true}, + {"saas_bad_hostname", "https://login.tailscale.com", "https://example.com/a/foo", false}, + {"localhost", "http://localhost", "http://localhost/a/foo", true}, + {"custom_control_url_https", "https://example.com", "https://example.com/a/foo", true}, + {"custom_control_url_https_diff_domain", "https://example.com", "https://other.com/a/foo", true}, + {"custom_control_url_http", "http://example.com", "http://example.com/a/foo", true}, + {"custom_control_url_http_diff_domain", "http://example.com", "http://other.com/a/foo", true}, + {"bad_scheme", "https://example.com", "http://example.com/a/foo", false}, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + ControlURLSet: true, + Prefs: ipn.Prefs{ + ControlURL: tt.controlURL, + }, + }); err != nil { + t.Fatal(err) } - syspolicy.SetHandlerForTest(t, msh) - if test.nm == nil { - test.nm = new(netmap.NetworkMap) + + got := b.validPopBrowserURL(tt.popBrowserURL) + if got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) } - if test.prefs == nil { - test.prefs = ipn.NewPrefs() + }) + } +} + +func TestRoundTraffic(t *testing.T) { + tests := []struct { + name string + bytes int64 + want float64 + }{ + {name: "under 5 bytes", bytes: 4, want: 4}, + {name: "under 1000 bytes", bytes: 987, want: 990}, + {name: "under 10_000 bytes", bytes: 8875, want: 8900}, + {name: "under 100_000 bytes", bytes: 77777, want: 78000}, + {name: "under 1_000_000 bytes", bytes: 666523, want: 670000}, + {name: "under 10_000_000 bytes", bytes: 22556677, want: 23000000}, + {name: "under 1_000_000_000 bytes", bytes: 1234234234, want: 1200000000}, + {name: "under 1_000_000_000 bytes", bytes: 123423423499, want: 123400000000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if result := roundTraffic(tt.bytes); result != tt.want { + t.Errorf("unexpected rounding got %v want %v", result, tt.want) } - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - pm.prefs = test.prefs.View() - b.netMap = test.nm - b.pm = pm - b.lastSuggestedExitNode = test.lastSuggestedExitNode - changed := setExitNodeID(b.pm.prefs.AsStruct(), test.nm, tailcfg.StableNodeID(test.lastSuggestedExitNode)) - b.SetPrefsForTest(pm.CurrentPrefs().AsStruct()) + }) + } +} + +func (b *LocalBackend) SetPrefsForTest(newp *ipn.Prefs) { + if newp == nil { + panic("SetPrefsForTest got nil prefs") + } + b.mu.Lock() + defer b.mu.Unlock() + b.setPrefsLocked(newp) +} - if got := b.pm.prefs.ExitNodeID(); got != tailcfg.StableNodeID(test.exitNodeIDWant) { - t.Errorf("got %v want %v", got, test.exitNodeIDWant) - } - if got := b.pm.prefs.ExitNodeIP(); test.exitNodeIPWant == "" { - if got.String() != "invalid IP" { - t.Errorf("got %v want invalid IP", got) - } - } else if got.String() != test.exitNodeIPWant { - t.Errorf("got %v want %v", got, test.exitNodeIPWant) - } +type peerOptFunc func(*tailcfg.Node) - if changed != test.prefsChanged { - t.Errorf("wanted prefs changed %v, got prefs changed %v", test.prefsChanged, changed) - } - }) +func makePeer(id tailcfg.NodeID, opts ...peerOptFunc) tailcfg.NodeView { + node := &tailcfg.Node{ + ID: id, + Key: makeNodeKeyFromID(id), + DiscoKey: makeDiscoKeyFromID(id), + StableID: tailcfg.StableNodeID(fmt.Sprintf("stable%d", id)), + Name: fmt.Sprintf("peer%d", id), + Online: ptr.To(true), + MachineAuthorized: true, + HomeDERP: int(id), + } + for _, opt := range opts { + opt(node) } + return node.View() } -func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { - peer1 := makePeer(1, withCap(26), withSuggest(), withExitRoutes()) - peer2 := makePeer(2, withCap(26), withSuggest(), withExitRoutes()) - derpMap := &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t1", - RegionID: 1, - }, - }, - }, - 2: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t2", - RegionID: 2, - }, - }, - }, - }, +func withName(name string) peerOptFunc { + return func(n *tailcfg.Node) { + n.Name = name } - report := &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10 * time.Millisecond, - 2: 5 * time.Millisecond, - 3: 30 * time.Millisecond, - }, - PreferredDERP: 2, +} + +func withDERP(region int) peerOptFunc { + return func(n *tailcfg.Node) { + n.HomeDERP = region } - tests := []struct { - name string - lastSuggestedExitNode tailcfg.StableNodeID - netmap *netmap.NetworkMap - muts []*tailcfg.PeerChange - exitNodeIDWant tailcfg.StableNodeID - updateNetmapDeltaResponse bool - report *netcheck.Report - }{ - { - name: "selected auto exit node goes offline", - lastSuggestedExitNode: peer1.StableID(), - netmap: &netmap.NetworkMap{ - Peers: []tailcfg.NodeView{ - peer1, - peer2, - }, - DERPMap: derpMap, - }, - muts: []*tailcfg.PeerChange{ - { - NodeID: 1, - Online: ptr.To(false), - }, - { - NodeID: 2, - Online: ptr.To(true), - }, - }, - exitNodeIDWant: peer2.StableID(), - updateNetmapDeltaResponse: false, - report: report, - }, - { - name: "other exit node goes offline doesn't change selected auto exit node that's still online", - lastSuggestedExitNode: peer2.StableID(), - netmap: &netmap.NetworkMap{ - Peers: []tailcfg.NodeView{ - peer1, - peer2, - }, - DERPMap: derpMap, - }, - muts: []*tailcfg.PeerChange{ - { - NodeID: 1, - Online: ptr.To(false), - }, - { - NodeID: 2, - Online: ptr.To(true), - }, - }, - exitNodeIDWant: peer2.StableID(), - updateNetmapDeltaResponse: true, - report: report, - }, +} + +func withoutDERP() peerOptFunc { + return func(n *tailcfg.Node) { + n.HomeDERP = 0 } - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, +} + +func withLocation(loc tailcfg.LocationView) peerOptFunc { + return func(n *tailcfg.Node) { + var hi *tailcfg.Hostinfo + if n.Hostinfo.Valid() { + hi = n.Hostinfo.AsStruct() + } else { + hi = new(tailcfg.Hostinfo) + } + hi.Location = loc.AsStruct() + + n.Hostinfo = hi.View() } - syspolicy.SetHandlerForTest(t, msh) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := newTestLocalBackend(t) - b.netMap = tt.netmap - b.updatePeersFromNetmapLocked(b.netMap) - b.lastSuggestedExitNode = tt.lastSuggestedExitNode - b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, tt.report) - b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) - someTime := time.Unix(123, 0) - muts, ok := netmap.MutationsFromMapResponse(&tailcfg.MapResponse{ - PeersChangedPatch: tt.muts, - }, someTime) - if !ok { - t.Fatal("netmap.MutationsFromMapResponse failed") - } - if b.pm.prefs.ExitNodeID() != tt.lastSuggestedExitNode { - t.Fatalf("did not set exit node ID to last suggested exit node despite auto policy") - } +} - got := b.UpdateNetmapDelta(muts) - if got != tt.updateNetmapDeltaResponse { - t.Fatalf("got %v expected %v from UpdateNetmapDelta", got, tt.updateNetmapDeltaResponse) - } - if b.pm.prefs.ExitNodeID() != tt.exitNodeIDWant { - t.Fatalf("did not get expected exit node id after UpdateNetmapDelta") - } - }) +func withLocationPriority(pri int) peerOptFunc { + return func(n *tailcfg.Node) { + var hi *tailcfg.Hostinfo + if n.Hostinfo.Valid() { + hi = n.Hostinfo.AsStruct() + } else { + hi = new(tailcfg.Hostinfo) + } + if hi.Location == nil { + hi.Location = new(tailcfg.Location) + } + hi.Location.Priority = pri + + n.Hostinfo = hi.View() } } -func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { - b := newTestLocalBackend(t) - hi := hostinfo.New() - ni := tailcfg.NetInfo{LinkType: "wired"} - hi.NetInfo = &ni - b.hostinfo = hi - k := key.NewMachine() - var cc *mockControl - opts := controlclient.Options{ - ServerURL: "https://example.com", - GetMachinePrivateKey: func() (key.MachinePrivate, error) { - return k, nil - }, - Dialer: tsdial.NewDialer(netmon.NewStatic()), - Logf: b.logf, +func withExitRoutes() peerOptFunc { + return func(n *tailcfg.Node) { + n.AllowedIPs = append(n.AllowedIPs, tsaddr.ExitRoutes()...) } - cc = newClient(t, opts) - b.cc = cc - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, +} + +func withSuggest() peerOptFunc { + return func(n *tailcfg.Node) { + mak.Set(&n.CapMap, tailcfg.NodeAttrSuggestExitNode, []tailcfg.RawMessage{}) } - syspolicy.SetHandlerForTest(t, msh) - peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes()) - peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes()) - selfNode := tailcfg.Node{ - Addresses: []netip.Prefix{ - netip.MustParsePrefix("100.64.1.1/32"), - netip.MustParsePrefix("fe70::1/128"), - }, - DERP: "127.3.3.40:2", +} + +func withCap(version tailcfg.CapabilityVersion) peerOptFunc { + return func(n *tailcfg.Node) { + n.Cap = version } - defaultDERPMap := &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t1", - RegionID: 1, - }, - }, - }, - 2: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t2", - RegionID: 2, - }, - }, - }, - 3: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t3", - RegionID: 3, - }, - }, - }, - }, +} + +func withOnline(isOnline bool) peerOptFunc { + return func(n *tailcfg.Node) { + n.Online = &isOnline } - b.netMap = &netmap.NetworkMap{ - SelfNode: selfNode.View(), - Peers: []tailcfg.NodeView{ - peer1, - peer2, - }, - DERPMap: defaultDERPMap, +} + +func withNodeKey() peerOptFunc { + return func(n *tailcfg.Node) { + n.Key = key.NewNode().Public() } - b.lastSuggestedExitNode = peer1.StableID() - b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) - if eid := b.Prefs().ExitNodeID(); eid != peer1.StableID() { - t.Errorf("got initial exit node %v, want %v", eid, peer1.StableID()) +} + +func withAddresses(addresses ...netip.Prefix) peerOptFunc { + return func(n *tailcfg.Node) { + n.Addresses = append(n.Addresses, addresses...) } - b.refreshAutoExitNode = true - b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10 * time.Millisecond, - 2: 5 * time.Millisecond, - 3: 30 * time.Millisecond, - }, - PreferredDERP: 2, - }) - b.setNetInfo(&ni) - if eid := b.Prefs().ExitNodeID(); eid != peer2.StableID() { - t.Errorf("got final exit node %v, want %v", eid, peer2.StableID()) +} + +func deterministicRegionForTest(t testing.TB, want views.Slice[int], use int) selectRegionFunc { + t.Helper() + + if !views.SliceContains(want, use) { + t.Errorf("invalid test: use %v is not in want %v", use, want) + } + + return func(got views.Slice[int]) int { + if !views.SliceEqualAnyOrder(got, want) { + t.Errorf("candidate regions = %v, want %v", got, want) + } + return use + } +} + +func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeID], wantLast tailcfg.StableNodeID, use tailcfg.StableNodeID) selectNodeFunc { + t.Helper() + + if !views.SliceContains(want, use) { + t.Errorf("invalid test: use %v is not in want %v", use, want) + } + + return func(got views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView { + var ret tailcfg.NodeView + + gotIDs := make([]tailcfg.StableNodeID, got.Len()) + for i, nv := range got.All() { + if !nv.Valid() { + t.Fatalf("invalid node at index %v", i) + } + gotIDs[i] = nv.StableID() + if nv.StableID() == use { + ret = nv + } + } + if !views.SliceEqualAnyOrder(views.SliceOf(gotIDs), want) { + t.Errorf("candidate nodes = %v, want %v", gotIDs, want) + } + if last != wantLast { + t.Errorf("last node = %v, want %v", last, wantLast) + } + if !ret.Valid() { + t.Fatalf("did not find matching node in %v, want %v", gotIDs, use) + } + + return ret } } -func TestSetControlClientStatusAutoExitNode(t *testing.T) { - peer1 := makePeer(1, withCap(26), withSuggest(), withExitRoutes(), withNodeKey()) - peer2 := makePeer(2, withCap(26), withSuggest(), withExitRoutes(), withNodeKey()) - derpMap := &tailcfg.DERPMap{ +func TestSuggestExitNode(t *testing.T) { + t.Parallel() + + defaultDERPMap := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t1", - RegionID: 1, - }, - }, - }, - 2: { - Nodes: []*tailcfg.DERPNode{ - { - Name: "t2", - RegionID: 2, - }, - }, + Latitude: 32, + Longitude: -97, }, + 2: {}, + 3: {}, }, } - report := &netcheck.Report{ + + preferred1Report := &netcheck.Report{ RegionLatency: map[int]time.Duration{ 1: 10 * time.Millisecond, - 2: 5 * time.Millisecond, + 2: 20 * time.Millisecond, 3: 30 * time.Millisecond, }, PreferredDERP: 1, } - nm := &netmap.NetworkMap{ - Peers: []tailcfg.NodeView{ - peer1, - peer2, + noLatency1Report := &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 0, + 2: 0, + 3: 0, }, - DERPMap: derpMap, + PreferredDERP: 1, } - b := newTestLocalBackend(t) - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), + preferredNoneReport := &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, }, + PreferredDERP: 0, } - syspolicy.SetHandlerForTest(t, msh) - b.netMap = nm - b.lastSuggestedExitNode = peer1.StableID() - b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, report) - b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) - firstExitNode := b.Prefs().ExitNodeID() - newPeer1 := makePeer(1, withCap(26), withSuggest(), withExitRoutes(), withOnline(false), withNodeKey()) - updatedNetmap := &netmap.NetworkMap{ - Peers: []tailcfg.NodeView{ - newPeer1, - peer2, - }, - DERPMap: derpMap, + + dallas := tailcfg.Location{ + Latitude: 32.779167, + Longitude: -96.808889, + Priority: 100, } - b.SetControlClientStatus(b.cc, controlclient.Status{NetMap: updatedNetmap}) - lastExitNode := b.Prefs().ExitNodeID() - if firstExitNode == lastExitNode { - t.Errorf("did not switch exit nodes despite auto exit node going offline") + sanJose := tailcfg.Location{ + Latitude: 37.3382082, + Longitude: -121.8863286, + Priority: 20, + } + fortWorth := tailcfg.Location{ + Latitude: 32.756389, + Longitude: -97.3325, + Priority: 150, + } + fortWorthLowPriority := tailcfg.Location{ + Latitude: 32.756389, + Longitude: -97.3325, + Priority: 100, } -} -func TestApplySysPolicy(t *testing.T) { - tests := []struct { - name string - prefs ipn.Prefs - wantPrefs ipn.Prefs - wantAnyChange bool - stringPolicies map[syspolicy.Key]string - }{ - { - name: "empty prefs without policies", - }, - { - name: "prefs set without policies", - prefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - ExitNodeAllowLANAccess: true, - CorpDNS: true, - RouteAll: true, - }, - wantPrefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - ExitNodeAllowLANAccess: true, - CorpDNS: true, - RouteAll: true, - }, - }, - { - name: "empty prefs with policies", - wantPrefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - ExitNodeAllowLANAccess: true, - CorpDNS: true, - RouteAll: true, - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.ControlURL: "1", - syspolicy.EnableIncomingConnections: "never", - syspolicy.EnableServerMode: "always", - syspolicy.ExitNodeAllowLANAccess: "always", - syspolicy.EnableTailscaleDNS: "always", - syspolicy.EnableTailscaleSubnets: "always", - }, - }, - { - name: "prefs set with matching policies", - prefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - }, - wantPrefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - }, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.ControlURL: "1", - syspolicy.EnableIncomingConnections: "never", - syspolicy.EnableServerMode: "always", - syspolicy.ExitNodeAllowLANAccess: "never", - syspolicy.EnableTailscaleDNS: "never", - syspolicy.EnableTailscaleSubnets: "never", - }, - }, - { - name: "prefs set with conflicting policies", - prefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - ExitNodeAllowLANAccess: false, - CorpDNS: true, - RouteAll: false, - }, - wantPrefs: ipn.Prefs{ - ControlURL: "2", - ShieldsUp: false, - ForceDaemon: false, - ExitNodeAllowLANAccess: true, - CorpDNS: false, - RouteAll: true, - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.ControlURL: "2", - syspolicy.EnableIncomingConnections: "always", - syspolicy.EnableServerMode: "never", - syspolicy.ExitNodeAllowLANAccess: "always", - syspolicy.EnableTailscaleDNS: "never", - syspolicy.EnableTailscaleSubnets: "always", - }, - }, - { - name: "prefs set with neutral policies", - prefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - ExitNodeAllowLANAccess: false, - CorpDNS: true, - RouteAll: true, - }, - wantPrefs: ipn.Prefs{ - ControlURL: "1", - ShieldsUp: true, - ForceDaemon: true, - ExitNodeAllowLANAccess: false, - CorpDNS: true, - RouteAll: true, - }, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.EnableIncomingConnections: "user-decides", - syspolicy.EnableServerMode: "user-decides", - syspolicy.ExitNodeAllowLANAccess: "user-decides", - syspolicy.EnableTailscaleDNS: "user-decides", - syspolicy.EnableTailscaleSubnets: "user-decides", - }, - }, - { - name: "ControlURL", - wantPrefs: ipn.Prefs{ - ControlURL: "set", - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.ControlURL: "set", - }, - }, - { - name: "enable AutoUpdate apply does not unset check", - prefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: true, - Apply: opt.NewBool(false), - }, - }, - wantPrefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: true, - Apply: opt.NewBool(true), - }, - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.ApplyUpdates: "always", - }, + peer1 := makePeer(1, + withExitRoutes(), + withSuggest()) + peer2DERP1 := makePeer(2, + withDERP(1), + withExitRoutes(), + withSuggest()) + peer3 := makePeer(3, + withExitRoutes(), + withSuggest()) + peer4DERP3 := makePeer(4, + withDERP(3), + withExitRoutes(), + withSuggest()) + dallasPeer5 := makePeer(5, + withName("Dallas"), + withoutDERP(), + withExitRoutes(), + withSuggest(), + withLocation(dallas.View())) + sanJosePeer6 := makePeer(6, + withName("San Jose"), + withoutDERP(), + withExitRoutes(), + withSuggest(), + withLocation(sanJose.View())) + fortWorthPeer7 := makePeer(7, + withName("Fort Worth"), + withoutDERP(), + withExitRoutes(), + withSuggest(), + withLocation(fortWorth.View())) + fortWorthPeer8LowPriority := makePeer(8, + withName("Fort Worth Low"), + withoutDERP(), + withExitRoutes(), + withSuggest(), + withLocation(fortWorthLowPriority.View())) + + selfNode := tailcfg.Node{ + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.1.1/32"), + netip.MustParsePrefix("fe70::1/128"), }, - { - name: "disable AutoUpdate apply does not unset check", - prefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: true, - Apply: opt.NewBool(true), - }, - }, - wantPrefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: true, - Apply: opt.NewBool(false), - }, - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.ApplyUpdates: "never", - }, + } + + defaultNetmap := &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + peer2DERP1, + peer3, }, - { - name: "enable AutoUpdate check does not unset apply", - prefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: false, - Apply: opt.NewBool(true), - }, - }, - wantPrefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: true, - Apply: opt.NewBool(true), - }, - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.CheckUpdates: "always", - }, + } + locationNetmap := &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + dallasPeer5, + sanJosePeer6, }, - { - name: "disable AutoUpdate check does not unset apply", - prefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: true, - Apply: opt.NewBool(true), - }, - }, - wantPrefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Check: false, - Apply: opt.NewBool(true), - }, - }, - wantAnyChange: true, - stringPolicies: map[syspolicy.Key]string{ - syspolicy.CheckUpdates: "never", - }, + } + largeNetmap := &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + peer1, + peer2DERP1, + peer3, + peer4DERP3, + dallasPeer5, + sanJosePeer6, + fortWorthPeer7, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: make(map[syspolicy.Key]*string, len(tt.stringPolicies)), - } - for p, v := range tt.stringPolicies { - v := v // construct a unique pointer for each policy value - msh.stringPolicies[p] = &v - } - syspolicy.SetHandlerForTest(t, msh) - - t.Run("unit", func(t *testing.T) { - prefs := tt.prefs.Clone() - - gotAnyChange := applySysPolicy(prefs) + tests := []struct { + name string - if gotAnyChange && prefs.Equals(&tt.prefs) { - t.Errorf("anyChange but prefs is unchanged: %v", prefs.Pretty()) - } - if !gotAnyChange && !prefs.Equals(&tt.prefs) { - t.Errorf("!anyChange but prefs changed from %v to %v", tt.prefs.Pretty(), prefs.Pretty()) - } - if gotAnyChange != tt.wantAnyChange { - t.Errorf("anyChange=%v, want %v", gotAnyChange, tt.wantAnyChange) - } - if !prefs.Equals(&tt.wantPrefs) { - t.Errorf("prefs=%v, want %v", prefs.Pretty(), tt.wantPrefs.Pretty()) - } - }) + lastReport *netcheck.Report + netMap *netmap.NetworkMap + lastSuggestion tailcfg.StableNodeID - t.Run("status update", func(t *testing.T) { - // Profile manager fills in blank ControlURL but it's not set - // in most test cases to avoid cluttering them, so adjust for - // that. - usePrefs := tt.prefs.Clone() - if usePrefs.ControlURL == "" { - usePrefs.ControlURL = ipn.DefaultControlURL - } - wantPrefs := tt.wantPrefs.Clone() - if wantPrefs.ControlURL == "" { - wantPrefs.ControlURL = ipn.DefaultControlURL - } + allowPolicy []tailcfg.StableNodeID - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - pm.prefs = usePrefs.View() + wantRegions []int + useRegion int - b := newTestBackend(t) - b.mu.Lock() - b.pm = pm - b.mu.Unlock() + wantNodes []tailcfg.StableNodeID - b.SetControlClientStatus(b.cc, controlclient.Status{}) - if !b.Prefs().Equals(wantPrefs.View()) { - t.Errorf("prefs=%v, want %v", b.Prefs().Pretty(), wantPrefs.Pretty()) - } - }) - }) - } -} + wantID tailcfg.StableNodeID + wantName string + wantLocation tailcfg.LocationView -func TestPreferencePolicyInfo(t *testing.T) { - tests := []struct { - name string - initialValue bool - wantValue bool - wantChange bool - policyValue string - policyError error + wantError error }{ { - name: "force enable modify", - initialValue: false, - wantValue: true, - wantChange: true, - policyValue: "always", + name: "2 exit nodes in same region", + lastReport: preferred1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + peer1, + peer2DERP1, + }, + }, + wantNodes: []tailcfg.StableNodeID{ + "stable1", + "stable2", + }, + wantName: "peer1", + wantID: "stable1", }, { - name: "force enable unchanged", - initialValue: true, - wantValue: true, - policyValue: "always", + name: "2 exit nodes different regions unknown latency", + lastReport: noLatency1Report, + netMap: defaultNetmap, + wantRegions: []int{1, 3}, // the only regions with peers + useRegion: 1, + wantName: "peer2", + wantID: "stable2", }, { - name: "force disable modify", - initialValue: true, - wantValue: false, - wantChange: true, - policyValue: "never", + name: "2 derp based exit nodes, different regions, equal latency", + lastReport: &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10, + 2: 20, + 3: 10, + }, + PreferredDERP: 1, + }, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + peer1, + peer3, + }, + }, + wantRegions: []int{1, 2}, + useRegion: 1, + wantName: "peer1", + wantID: "stable1", }, { - name: "force disable unchanged", - initialValue: false, - wantValue: false, - policyValue: "never", + name: "mullvad nodes, no derp based exit nodes", + lastReport: noLatency1Report, + netMap: locationNetmap, + wantID: "stable5", + wantLocation: dallas.View(), + wantName: "Dallas", }, { - name: "unforced enabled", - initialValue: true, - wantValue: true, - policyValue: "user-decides", + name: "nearby mullvad nodes with different priorities", + lastReport: noLatency1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + dallasPeer5, + sanJosePeer6, + fortWorthPeer7, + }, + }, + wantID: "stable7", + wantLocation: fortWorth.View(), + wantName: "Fort Worth", + }, + { + name: "nearby mullvad nodes with same priorities", + lastReport: noLatency1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + dallasPeer5, + sanJosePeer6, + fortWorthPeer8LowPriority, + }, + }, + wantNodes: []tailcfg.StableNodeID{"stable5", "stable8"}, + wantID: "stable5", + wantLocation: dallas.View(), + wantName: "Dallas", }, { - name: "unforced disabled", - initialValue: false, - wantValue: false, - policyValue: "user-decides", + name: "mullvad nodes, remaining node is not in preferred derp", + lastReport: noLatency1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + dallasPeer5, + sanJosePeer6, + peer4DERP3, + }, + }, + useRegion: 3, + wantID: "stable4", + wantName: "peer4", }, { - name: "blank enabled", - initialValue: true, - wantValue: true, - policyValue: "", + name: "no peers", + lastReport: noLatency1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + }, }, { - name: "blank disabled", - initialValue: false, - wantValue: false, - policyValue: "", + name: "nil report", + lastReport: nil, + netMap: largeNetmap, + wantError: ErrNoPreferredDERP, }, { - name: "unset enabled", - initialValue: true, - wantValue: true, - policyError: syspolicy.ErrNoSuchKey, + name: "no preferred derp region", + lastReport: preferredNoneReport, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + }, + wantError: ErrNoPreferredDERP, }, { - name: "unset disabled", - initialValue: false, - wantValue: false, - policyError: syspolicy.ErrNoSuchKey, + name: "nil netmap", + lastReport: noLatency1Report, + netMap: nil, + wantError: ErrNoPreferredDERP, }, { - name: "error enabled", - initialValue: true, - wantValue: true, - policyError: errors.New("test error"), + name: "nil derpmap", + lastReport: noLatency1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: nil, + Peers: []tailcfg.NodeView{ + dallasPeer5, + }, + }, + wantError: ErrNoPreferredDERP, }, { - name: "error disabled", - initialValue: false, - wantValue: false, - policyError: errors.New("test error"), + name: "missing suggestion capability", + lastReport: noLatency1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + makePeer(1, withExitRoutes()), + makePeer(2, withLocation(dallas.View()), withExitRoutes()), + }, + }, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for _, pp := range preferencePolicies { - t.Run(string(pp.key), func(t *testing.T) { - var h syspolicy.Handler - - allPolicies := make(map[syspolicy.Key]*string, len(preferencePolicies)+1) - allPolicies[syspolicy.ControlURL] = nil - for _, pp := range preferencePolicies { - allPolicies[pp.key] = nil - } - - if tt.policyError != nil { - h = &errorSyspolicyHandler{ - t: t, - err: tt.policyError, - key: pp.key, - allowKeys: allPolicies, - } - } else { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: allPolicies, - failUnknownPolicies: true, - } - msh.stringPolicies[pp.key] = &tt.policyValue - h = msh - } - syspolicy.SetHandlerForTest(t, h) - - prefs := defaultPrefs.AsStruct() - pp.set(prefs, tt.initialValue) - - gotAnyChange := applySysPolicy(prefs) - - if gotAnyChange != tt.wantChange { - t.Errorf("anyChange=%v, want %v", gotAnyChange, tt.wantChange) - } - got := pp.get(prefs.View()) - if got != tt.wantValue { - t.Errorf("pref=%v, want %v", got, tt.wantValue) - } - }) - } - }) - } -} - -func TestOnTailnetDefaultAutoUpdate(t *testing.T) { - tests := []struct { - before, after opt.Bool - container opt.Bool - tailnetDefault bool - }{ { - before: opt.Bool(""), - tailnetDefault: true, - after: opt.NewBool(true), + name: "prefer last node", + lastReport: preferred1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + peer1, + peer2DERP1, + }, + }, + lastSuggestion: "stable2", + wantNodes: []tailcfg.StableNodeID{ + "stable1", + "stable2", + }, + wantName: "peer2", + wantID: "stable2", }, { - before: opt.Bool(""), - tailnetDefault: false, - after: opt.NewBool(false), + name: "found better derp node", + lastSuggestion: "stable3", + lastReport: preferred1Report, + netMap: defaultNetmap, + wantID: "stable2", + wantName: "peer2", }, { - before: opt.Bool("unset"), - tailnetDefault: true, - after: opt.NewBool(true), + name: "prefer last mullvad node", + lastSuggestion: "stable2", + lastReport: preferred1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + dallasPeer5, + sanJosePeer6, + fortWorthPeer8LowPriority, + }, + }, + wantNodes: []tailcfg.StableNodeID{"stable5", "stable8"}, + wantID: "stable5", + wantName: "Dallas", + wantLocation: dallas.View(), }, { - before: opt.Bool("unset"), - tailnetDefault: false, - after: opt.NewBool(false), + name: "prefer better mullvad node", + lastSuggestion: "stable2", + lastReport: preferred1Report, + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + DERPMap: defaultDERPMap, + Peers: []tailcfg.NodeView{ + dallasPeer5, + sanJosePeer6, + fortWorthPeer7, + }, + }, + wantNodes: []tailcfg.StableNodeID{"stable7"}, + wantID: "stable7", + wantName: "Fort Worth", + wantLocation: fortWorth.View(), }, { - before: opt.NewBool(false), - tailnetDefault: true, - after: opt.NewBool(false), + name: "large netmap", + lastReport: preferred1Report, + netMap: largeNetmap, + wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"}, + wantID: "stable2", + wantName: "peer2", }, { - before: opt.NewBool(true), - tailnetDefault: false, - after: opt.NewBool(true), + name: "no allowed suggestions", + lastReport: preferred1Report, + netMap: largeNetmap, + allowPolicy: []tailcfg.StableNodeID{}, }, { - before: opt.Bool(""), - container: opt.NewBool(true), - tailnetDefault: true, - after: opt.Bool(""), + name: "only derp suggestions", + lastReport: preferred1Report, + netMap: largeNetmap, + allowPolicy: []tailcfg.StableNodeID{"stable1", "stable2", "stable3"}, + wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"}, + wantID: "stable2", + wantName: "peer2", }, { - before: opt.NewBool(false), - container: opt.NewBool(true), - tailnetDefault: true, - after: opt.NewBool(false), + name: "only mullvad suggestions", + lastReport: preferred1Report, + netMap: largeNetmap, + allowPolicy: []tailcfg.StableNodeID{"stable5", "stable6", "stable7"}, + wantID: "stable7", + wantName: "Fort Worth", + wantLocation: fortWorth.View(), }, { - before: opt.NewBool(true), - container: opt.NewBool(true), - tailnetDefault: false, - after: opt.NewBool(true), + name: "only worst derp", + lastReport: preferred1Report, + netMap: largeNetmap, + allowPolicy: []tailcfg.StableNodeID{"stable3"}, + wantID: "stable3", + wantName: "peer3", }, - } - for _, tt := range tests { - t.Run(fmt.Sprintf("before=%s,after=%s", tt.before, tt.after), func(t *testing.T) { - b := newTestBackend(t) - b.hostinfo = hostinfo.New() - b.hostinfo.Container = tt.container - p := ipn.NewPrefs() - p.AutoUpdate.Apply = tt.before - if err := b.pm.setPrefsNoPermCheck(p.View()); err != nil { - t.Fatal(err) + { + name: "only worst mullvad", + lastReport: preferred1Report, + netMap: largeNetmap, + allowPolicy: []tailcfg.StableNodeID{"stable6"}, + wantID: "stable6", + wantName: "San Jose", + wantLocation: sanJose.View(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wantRegions := tt.wantRegions + if wantRegions == nil { + wantRegions = []int{tt.useRegion} } - b.onTailnetDefaultAutoUpdate(tt.tailnetDefault) - want := tt.after - // On platforms that don't support auto-update we can never - // transition to auto-updates being enabled. The value should - // remain unchanged after onTailnetDefaultAutoUpdate. - if !clientupdate.CanAutoUpdate() && want.EqualBool(true) { - want = tt.before + selectRegion := deterministicRegionForTest(t, views.SliceOf(wantRegions), tt.useRegion) + + wantNodes := tt.wantNodes + if wantNodes == nil { + wantNodes = []tailcfg.StableNodeID{tt.wantID} } - if got := b.pm.CurrentPrefs().AutoUpdate().Apply; got != want { - t.Errorf("got: %q, want %q", got, want) + selectNode := deterministicNodeForTest(t, views.SliceOf(wantNodes), tt.lastSuggestion, tt.wantID) + + var allowList set.Set[tailcfg.StableNodeID] + if tt.allowPolicy != nil { + allowList = set.SetOf(tt.allowPolicy) + } + + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + defer nb.shutdown(errShutdown) + nb.SetNetMap(tt.netMap) + + got, err := suggestExitNode(tt.lastReport, nb, tt.lastSuggestion, selectRegion, selectNode, allowList) + if got.Name != tt.wantName { + t.Errorf("name=%v, want %v", got.Name, tt.wantName) + } + if got.ID != tt.wantID { + t.Errorf("ID=%v, want %v", got.ID, tt.wantID) + } + if tt.wantError == nil && err != nil { + t.Errorf("err=%v, want no error", err) + } + if tt.wantError != nil && !errors.Is(err, tt.wantError) { + t.Errorf("err=%v, want %v", err, tt.wantError) + } + if !reflect.DeepEqual(got.Location, tt.wantLocation) { + t.Errorf("location=%v, want %v", got.Location, tt.wantLocation) } }) } } -func TestTCPHandlerForDst(t *testing.T) { - b := newTestBackend(t) +func TestSuggestExitNodePickWeighted(t *testing.T) { + location10 := tailcfg.Location{ + Priority: 10, + } + location20 := tailcfg.Location{ + Priority: 20, + } tests := []struct { - desc string - dst string - intercept bool + name string + candidates []tailcfg.NodeView + wantIDs []tailcfg.StableNodeID }{ { - desc: "intercept port 80 (Web UI) on quad100 IPv4", - dst: "100.100.100.100:80", - intercept: true, - }, - { - desc: "intercept port 80 (Web UI) on quad100 IPv6", - dst: "[fd7a:115c:a1e0::53]:80", - intercept: true, - }, - { - desc: "don't intercept port 80 on local ip", - dst: "100.100.103.100:80", - intercept: false, + name: "different priorities", + candidates: []tailcfg.NodeView{ + makePeer(2, withExitRoutes(), withLocation(location20.View())), + makePeer(3, withExitRoutes(), withLocation(location10.View())), + }, + wantIDs: []tailcfg.StableNodeID{"stable2"}, }, { - desc: "intercept port 8080 (Taildrive) on quad100 IPv4", - dst: "100.100.100.100:8080", - intercept: true, + name: "same priorities", + candidates: []tailcfg.NodeView{ + makePeer(2, withExitRoutes(), withLocation(location10.View())), + makePeer(3, withExitRoutes(), withLocation(location10.View())), + }, + wantIDs: []tailcfg.StableNodeID{"stable2", "stable3"}, }, { - desc: "intercept port 8080 (Taildrive) on quad100 IPv6", - dst: "[fd7a:115c:a1e0::53]:8080", - intercept: true, + name: "<1 candidates", + candidates: []tailcfg.NodeView{}, }, { - desc: "don't intercept port 8080 on local ip", - dst: "100.100.103.100:8080", - intercept: false, + name: "1 candidate", + candidates: []tailcfg.NodeView{ + makePeer(2, withExitRoutes(), withLocation(location20.View())), + }, + wantIDs: []tailcfg.StableNodeID{"stable2"}, }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := pickWeighted(tt.candidates) + gotIDs := make([]tailcfg.StableNodeID, 0, len(got)) + for _, n := range got { + if !n.Valid() { + gotIDs = append(gotIDs, "") + continue + } + gotIDs = append(gotIDs, n.StableID()) + } + if !views.SliceEqualAnyOrder(views.SliceOf(gotIDs), views.SliceOf(tt.wantIDs)) { + t.Errorf("node IDs = %v, want %v", gotIDs, tt.wantIDs) + } + }) + } +} + +func TestSuggestExitNodeLongLatDistance(t *testing.T) { + tests := []struct { + name string + fromLat float64 + fromLong float64 + toLat float64 + toLong float64 + want float64 + }{ { - desc: "don't intercept port 9080 on quad100 IPv4", - dst: "100.100.100.100:9080", - intercept: false, + name: "zero values", + fromLat: 0, + fromLong: 0, + toLat: 0, + toLong: 0, + want: 0, }, { - desc: "don't intercept port 9080 on quad100 IPv6", - dst: "[fd7a:115c:a1e0::53]:9080", - intercept: false, + name: "valid values", + fromLat: 40.73061, + fromLong: -73.935242, + toLat: 37.3382082, + toLong: -121.8863286, + want: 4117266.873301274, }, { - desc: "don't intercept port 9080 on local ip", - dst: "100.100.103.100:9080", - intercept: false, + name: "valid values, locations in north and south of equator", + fromLat: 40.73061, + fromLong: -73.935242, + toLat: -33.861481, + toLong: 151.205475, + want: 15994089.144368416, }, } - + // The wanted values are computed using a more precise algorithm using the WGS84 model but + // longLatDistance uses a spherical approximation for simplicity. To account for this, we allow for + // 10km of error. for _, tt := range tests { - t.Run(tt.dst, func(t *testing.T) { - t.Log(tt.desc) - src := netip.MustParseAddrPort("100.100.102.100:51234") - h, _ := b.TCPHandlerForDst(src, netip.MustParseAddrPort(tt.dst)) - if !tt.intercept && h != nil { - t.Error("intercepted traffic we shouldn't have") - } else if tt.intercept && h == nil { - t.Error("failed to intercept traffic we should have") + t.Run(tt.name, func(t *testing.T) { + got := longLatDistance(tt.fromLat, tt.fromLong, tt.toLat, tt.toLong) + const maxError = 10000 // 10km + if math.Abs(got-tt.want) > maxError { + t.Errorf("distance=%vm, want within %vm of %vm", got, maxError, tt.want) } }) } } -func TestDriveManageShares(t *testing.T) { - tests := []struct { - name string - disabled bool - existing []*drive.Share - add *drive.Share - remove string - rename [2]string - expect any +func TestSuggestExitNodeTrafficSteering(t *testing.T) { + city := &tailcfg.Location{ + Country: "Canada", + CountryCode: "CA", + City: "Montreal", + CityCode: "MTR", + Latitude: 45.5053, + Longitude: -73.5525, + } + noLatLng := &tailcfg.Location{ + Country: "Canada", + CountryCode: "CA", + City: "Montreal", + CityCode: "MTR", + } + + selfNode := tailcfg.Node{ + ID: 0, // randomness is seeded off NetMap.SelfNode.ID + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.1.1/32"), + netip.MustParsePrefix("fe70::1/128"), + }, + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrTrafficSteering: []tailcfg.RawMessage{}, + }, + } + + for _, tt := range []struct { + name string + + netMap *netmap.NetworkMap + lastExit tailcfg.StableNodeID + allowPolicy []tailcfg.StableNodeID + + wantID tailcfg.StableNodeID + wantName string + wantLoc *tailcfg.Location + wantPri int + + wantErr error }{ { - name: "append", - existing: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - }, - add: &drive.Share{Name: " E "}, - expect: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - {Name: "e"}, - }, + name: "no-netmap", + netMap: nil, + wantErr: ErrNoNetMap, }, { - name: "prepend", - existing: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - }, - add: &drive.Share{Name: " A "}, - expect: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, - {Name: "d"}, + name: "no-nodes", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{}, }, + wantID: "", }, { - name: "insert", - existing: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, + name: "no-exit-nodes", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1), + }, }, - add: &drive.Share{Name: " C "}, - expect: []*drive.Share{ - {Name: "b"}, - {Name: "c"}, - {Name: "d"}, + wantID: "", + }, + { + name: "exit-node-without-suggestion", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes()), + }, }, + wantID: "", }, { - name: "replace", - existing: []*drive.Share{ - {Name: "b", Path: "i"}, - {Name: "d"}, + name: "suggested-exit-node-without-routes", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withSuggest()), + }, }, - add: &drive.Share{Name: " B ", Path: "ii"}, - expect: []*drive.Share{ - {Name: "b", Path: "ii"}, - {Name: "d"}, + wantID: "", + }, + { + name: "suggested-exit-node", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest()), + }, }, + wantID: "stable1", + wantName: "peer1", }, { - name: "add_bad_name", - add: &drive.Share{Name: "$"}, - expect: drive.ErrInvalidShareName, + name: "suggest-exit-node-stable-pick", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest()), + makePeer(2, + withExitRoutes(), + withSuggest()), + makePeer(3, + withExitRoutes(), + withSuggest()), + makePeer(4, + withExitRoutes(), + withSuggest()), + }, + }, + // Change this, if the hashing function changes. + wantID: "stable3", + wantName: "peer3", }, { - name: "add_disabled", - disabled: true, - add: &drive.Share{Name: "a"}, - expect: drive.ErrDriveNotEnabled, + name: "exit-nodes-with-and-without-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocationPriority(1)), + makePeer(2, + withExitRoutes(), + withSuggest()), + }, + }, + wantID: "stable1", + wantName: "peer1", + wantPri: 1, }, { - name: "remove", - existing: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, - {Name: "c"}, - }, - remove: "b", - expect: []*drive.Share{ - {Name: "a"}, - {Name: "c"}, + name: "exit-nodes-without-and-with-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest()), + makePeer(2, + withExitRoutes(), + withSuggest(), + withLocationPriority(1)), + }, }, + wantID: "stable2", + wantName: "peer2", + wantPri: 1, }, { - name: "remove_non_existing", - existing: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, - {Name: "c"}, + name: "exit-nodes-with-negative-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocationPriority(-1)), + makePeer(2, + withExitRoutes(), + withSuggest(), + withLocationPriority(-2)), + makePeer(3, + withExitRoutes(), + withSuggest(), + withLocationPriority(-3)), + makePeer(4, + withExitRoutes(), + withSuggest(), + withLocationPriority(-4)), + }, }, - remove: "D", - expect: os.ErrNotExist, + wantID: "stable1", + wantName: "peer1", + wantPri: -1, }, { - name: "remove_disabled", - disabled: true, - remove: "b", - expect: drive.ErrDriveNotEnabled, + name: "exit-nodes-no-priority-beats-negative-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocationPriority(-1)), + makePeer(2, + withExitRoutes(), + withSuggest(), + withLocationPriority(-2)), + makePeer(3, + withExitRoutes(), + withSuggest()), + }, + }, + wantID: "stable3", + wantName: "peer3", }, { - name: "rename", - existing: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, - }, - rename: [2]string{"a", " C "}, - expect: []*drive.Share{ - {Name: "b"}, - {Name: "c"}, + name: "exit-nodes-same-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocationPriority(1)), + makePeer(2, + withExitRoutes(), + withSuggest(), + withLocationPriority(2)), // top + makePeer(3, + withExitRoutes(), + withSuggest(), + withLocationPriority(1)), + makePeer(4, + withExitRoutes(), + withSuggest(), + withLocationPriority(2)), // top + makePeer(5, + withExitRoutes(), + withSuggest(), + withLocationPriority(2)), // top + makePeer(6, + withExitRoutes(), + withSuggest()), + makePeer(7, + withExitRoutes(), + withSuggest(), + withLocationPriority(2)), // top + }, }, + wantID: "stable5", + wantName: "peer5", + wantPri: 2, }, { - name: "rename_not_exist", - existing: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, + name: "suggested-exit-node-with-city", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocation(city.View())), + }, }, - rename: [2]string{"d", "c"}, - expect: os.ErrNotExist, + wantID: "stable1", + wantName: "peer1", + wantLoc: city, }, { - name: "rename_exists", - existing: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, + name: "suggested-exit-node-with-city-and-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocation(city.View()), + withLocationPriority(1)), + }, }, - rename: [2]string{"a", "b"}, - expect: os.ErrExist, + wantID: "stable1", + wantName: "peer1", + wantLoc: city, + wantPri: 1, }, { - name: "rename_bad_name", - rename: [2]string{"a", "$"}, - expect: drive.ErrInvalidShareName, + name: "suggested-exit-node-without-latlng", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocation(noLatLng.View())), + }, + }, + wantID: "stable1", + wantName: "peer1", + wantLoc: noLatLng, }, { - name: "rename_disabled", - disabled: true, - rename: [2]string{"a", "c"}, - expect: drive.ErrDriveNotEnabled, - }, - } - - drive.DisallowShareAs = true - t.Cleanup(func() { - drive.DisallowShareAs = false - }) - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := newTestBackend(t) - b.mu.Lock() - if tt.existing != nil { - b.driveSetSharesLocked(tt.existing) - } - if !tt.disabled { - self := b.netMap.SelfNode.AsStruct() - self.CapMap = tailcfg.NodeCapMap{tailcfg.NodeAttrsTaildriveShare: nil} - b.netMap.SelfNode = self.View() - b.sys.Set(driveimpl.NewFileSystemForRemote(b.logf)) - } - b.mu.Unlock() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - t.Cleanup(cancel) - - result := make(chan views.SliceView[*drive.Share, drive.ShareView], 1) - - var wg sync.WaitGroup - wg.Add(1) - go b.WatchNotifications( - ctx, - 0, - func() { wg.Done() }, - func(n *ipn.Notify) bool { - select { - case result <- n.DriveShares: - default: - // - } - return false - }, - ) - wg.Wait() - - var err error - switch { - case tt.add != nil: - err = b.DriveSetShare(tt.add) - case tt.remove != "": - err = b.DriveRemoveShare(tt.remove) - default: - err = b.DriveRenameShare(tt.rename[0], tt.rename[1]) - } - - switch e := tt.expect.(type) { - case error: - if !errors.Is(err, e) { - t.Errorf("expected error, want: %v got: %v", e, err) - } - case []*drive.Share: - if err != nil { - t.Errorf("unexpected error: %v", err) - } else { - r := <-result - - got, err := json.MarshalIndent(r, "", " ") - if err != nil { - t.Fatalf("can't marshal got: %v", err) - } - want, err := json.MarshalIndent(e, "", " ") - if err != nil { - t.Fatalf("can't marshal want: %v", err) - } - if diff := cmp.Diff(string(got), string(want)); diff != "" { - t.Errorf("wrong shares; (-got+want):%v", diff) - } - } - } - }) - } -} - -func TestValidPopBrowserURL(t *testing.T) { - b := newTestBackend(t) - tests := []struct { - desc string - controlURL string - popBrowserURL string - want bool - }{ - {"saas_login", "https://login.tailscale.com", "https://login.tailscale.com/a/foo", true}, - {"saas_controlplane", "https://controlplane.tailscale.com", "https://controlplane.tailscale.com/a/foo", true}, - {"saas_root", "https://login.tailscale.com", "https://tailscale.com/", true}, - {"saas_bad_hostname", "https://login.tailscale.com", "https://example.com/a/foo", false}, - {"localhost", "http://localhost", "http://localhost/a/foo", true}, - {"custom_control_url_https", "https://example.com", "https://example.com/a/foo", true}, - {"custom_control_url_https_diff_domain", "https://example.com", "https://other.com/a/foo", true}, - {"custom_control_url_http", "http://example.com", "http://example.com/a/foo", true}, - {"custom_control_url_http_diff_domain", "http://example.com", "http://other.com/a/foo", true}, - {"bad_scheme", "https://example.com", "http://example.com/a/foo", false}, - } - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - if _, err := b.EditPrefs(&ipn.MaskedPrefs{ - ControlURLSet: true, - Prefs: ipn.Prefs{ - ControlURL: tt.controlURL, + name: "suggested-exit-node-without-latlng-with-priority", + netMap: &netmap.NetworkMap{ + SelfNode: selfNode.View(), + Peers: []tailcfg.NodeView{ + makePeer(1, + withExitRoutes(), + withSuggest(), + withLocation(noLatLng.View()), + withLocationPriority(1)), }, - }); err != nil { - t.Fatal(err) - } - - got := b.validPopBrowserURL(tt.popBrowserURL) - if got != tt.want { - t.Errorf("got %v, want %v", got, tt.want) - } - }) - } -} - -func TestRoundTraffic(t *testing.T) { - tests := []struct { - name string - bytes int64 - want float64 - }{ - {name: "under 5 bytes", bytes: 4, want: 4}, - {name: "under 1000 bytes", bytes: 987, want: 990}, - {name: "under 10_000 bytes", bytes: 8875, want: 8900}, - {name: "under 100_000 bytes", bytes: 77777, want: 78000}, - {name: "under 1_000_000 bytes", bytes: 666523, want: 670000}, - {name: "under 10_000_000 bytes", bytes: 22556677, want: 23000000}, - {name: "under 1_000_000_000 bytes", bytes: 1234234234, want: 1200000000}, - {name: "under 1_000_000_000 bytes", bytes: 123423423499, want: 123400000000}, - } - - for _, tt := range tests { + }, + wantID: "stable1", + wantName: "peer1", + wantLoc: noLatLng, + wantPri: 1, + }, + } { t.Run(tt.name, func(t *testing.T) { - if result := roundTraffic(tt.bytes); result != tt.want { - t.Errorf("unexpected rounding got %v want %v", result, tt.want) + var allowList set.Set[tailcfg.StableNodeID] + if tt.allowPolicy != nil { + allowList = set.SetOf(tt.allowPolicy) } - }) - } -} - -func (b *LocalBackend) SetPrefsForTest(newp *ipn.Prefs) { - if newp == nil { - panic("SetPrefsForTest got nil prefs") - } - unlock := b.lockAndGetUnlock() - defer unlock() - b.setPrefsLockedOnEntry(newp, unlock) -} - -type peerOptFunc func(*tailcfg.Node) - -func makePeer(id tailcfg.NodeID, opts ...peerOptFunc) tailcfg.NodeView { - node := &tailcfg.Node{ - ID: id, - StableID: tailcfg.StableNodeID(fmt.Sprintf("stable%d", id)), - Name: fmt.Sprintf("peer%d", id), - DERP: fmt.Sprintf("127.3.3.40:%d", id), - } - for _, opt := range opts { - opt(node) - } - return node.View() -} - -func withName(name string) peerOptFunc { - return func(n *tailcfg.Node) { - n.Name = name - } -} - -func withDERP(region int) peerOptFunc { - return func(n *tailcfg.Node) { - n.DERP = fmt.Sprintf("127.3.3.40:%d", region) - } -} - -func withoutDERP() peerOptFunc { - return func(n *tailcfg.Node) { - n.DERP = "" - } -} - -func withLocation(loc tailcfg.LocationView) peerOptFunc { - return func(n *tailcfg.Node) { - var hi *tailcfg.Hostinfo - if n.Hostinfo.Valid() { - hi = n.Hostinfo.AsStruct() - } else { - hi = new(tailcfg.Hostinfo) - } - hi.Location = loc.AsStruct() - - n.Hostinfo = hi.View() - } -} - -func withExitRoutes() peerOptFunc { - return func(n *tailcfg.Node) { - n.AllowedIPs = append(n.AllowedIPs, tsaddr.ExitRoutes()...) - } -} - -func withSuggest() peerOptFunc { - return func(n *tailcfg.Node) { - mak.Set(&n.CapMap, tailcfg.NodeAttrSuggestExitNode, []tailcfg.RawMessage{}) - } -} -func withCap(version tailcfg.CapabilityVersion) peerOptFunc { - return func(n *tailcfg.Node) { - n.Cap = version - } -} + // HACK: NetMap.AllCaps is populated by Control: + if tt.netMap != nil { + caps := maps.Keys(tt.netMap.SelfNode.CapMap().AsMap()) + tt.netMap.AllCaps = set.SetOf(slices.Collect(caps)) + } -func withOnline(isOnline bool) peerOptFunc { - return func(n *tailcfg.Node) { - n.Online = &isOnline - } -} + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + defer nb.shutdown(errShutdown) + nb.SetNetMap(tt.netMap) -func withNodeKey() peerOptFunc { - return func(n *tailcfg.Node) { - n.Key = key.NewNode().Public() - } -} + got, err := suggestExitNodeUsingTrafficSteering(nb, allowList) + if tt.wantErr == nil && err != nil { + t.Fatalf("err=%v, want nil", err) + } + if tt.wantErr != nil && !errors.Is(err, tt.wantErr) { + t.Fatalf("err=%v, want %v", err, tt.wantErr) + } -func deterministicRegionForTest(t testing.TB, want views.Slice[int], use int) selectRegionFunc { - t.Helper() + if got.Name != tt.wantName { + t.Errorf("name=%q, want %q", got.Name, tt.wantName) + } - if !views.SliceContains(want, use) { - t.Errorf("invalid test: use %v is not in want %v", use, want) - } + if got.ID != tt.wantID { + t.Errorf("ID=%q, want %q", got.ID, tt.wantID) + } - return func(got views.Slice[int]) int { - if !views.SliceEqualAnyOrder(got, want) { - t.Errorf("candidate regions = %v, want %v", got, want) - } - return use + wantLoc := tt.wantLoc + if tt.wantPri != 0 { + if wantLoc == nil { + wantLoc = new(tailcfg.Location) + } + wantLoc.Priority = tt.wantPri + } + if diff := cmp.Diff(got.Location.AsStruct(), wantLoc); diff != "" { + t.Errorf("location mismatch (+want -got)\n%s", diff) + } + }) } } -func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeID], wantLast tailcfg.StableNodeID, use tailcfg.StableNodeID) selectNodeFunc { - t.Helper() - - if !views.SliceContains(want, use) { - t.Errorf("invalid test: use %v is not in want %v", use, want) +func TestMinLatencyDERPregion(t *testing.T) { + tests := []struct { + name string + regions []int + report *netcheck.Report + wantRegion int + }{ + { + name: "regions, no latency values", + regions: []int{1, 2, 3}, + wantRegion: 0, + report: &netcheck.Report{}, + }, + { + name: "regions, different latency values", + regions: []int{1, 2, 3}, + wantRegion: 2, + report: &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 5 * time.Millisecond, + 3: 30 * time.Millisecond, + }, + }, + }, + { + name: "regions, same values", + regions: []int{1, 2, 3}, + wantRegion: 1, + report: &netcheck.Report{ + RegionLatency: map[int]time.Duration{ + 1: 10 * time.Millisecond, + 2: 10 * time.Millisecond, + 3: 10 * time.Millisecond, + }, + }, + }, } - return func(got views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView { - var ret tailcfg.NodeView - - gotIDs := make([]tailcfg.StableNodeID, got.Len()) - for i := range got.Len() { - nv := got.At(i) - if !nv.Valid() { - t.Fatalf("invalid node at index %v", i) - } - - gotIDs[i] = nv.StableID() - if nv.StableID() == use { - ret = nv + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := minLatencyDERPRegion(tt.regions, tt.report) + if got != tt.wantRegion { + t.Errorf("got region %v want region %v", got, tt.wantRegion) } - } - if !views.SliceEqualAnyOrder(views.SliceOf(gotIDs), want) { - t.Errorf("candidate nodes = %v, want %v", gotIDs, want) - } - if last != wantLast { - t.Errorf("last node = %v, want %v", last, wantLast) - } - if !ret.Valid() { - t.Fatalf("did not find matching node in %v, want %v", gotIDs, use) - } - - return ret + }) } } -func TestSuggestExitNode(t *testing.T) { - t.Parallel() +func TestEnableAutoUpdates(t *testing.T) { + lb := newTestLocalBackend(t) - defaultDERPMap := &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - Latitude: 32, - Longitude: -97, + _, err := lb.EditPrefs(&ipn.MaskedPrefs{ + AutoUpdateSet: ipn.AutoUpdatePrefsMask{ + ApplySet: true, + }, + Prefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Apply: opt.NewBool(true), }, - 2: {}, - 3: {}, }, + }) + // Enabling may fail, depending on which environment we are running this + // test in. + wantErr := !feature.CanAutoUpdate() + gotErr := err != nil + if gotErr != wantErr { + t.Fatalf("enabling auto-updates: got error: %v (%v); want error: %v", gotErr, err, wantErr) } - preferred1Report := &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10 * time.Millisecond, - 2: 20 * time.Millisecond, - 3: 30 * time.Millisecond, - }, - PreferredDERP: 1, - } - noLatency1Report := &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 0, - 2: 0, - 3: 0, + // Disabling should always succeed. + if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ + AutoUpdateSet: ipn.AutoUpdatePrefsMask{ + ApplySet: true, }, - PreferredDERP: 1, - } - preferredNoneReport := &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10 * time.Millisecond, - 2: 20 * time.Millisecond, - 3: 30 * time.Millisecond, + Prefs: ipn.Prefs{ + AutoUpdate: ipn.AutoUpdatePrefs{ + Apply: opt.NewBool(false), + }, }, - PreferredDERP: 0, + }); err != nil { + t.Fatalf("disabling auto-updates: got error: %v", err) } +} - dallas := tailcfg.Location{ - Latitude: 32.779167, - Longitude: -96.808889, - Priority: 100, - } - sanJose := tailcfg.Location{ - Latitude: 37.3382082, - Longitude: -121.8863286, - Priority: 20, +func TestReadWriteRouteInfo(t *testing.T) { + // set up a backend with more than one profile + b := newTestBackend(t) + prof1 := ipn.LoginProfile{ID: "id1", Key: "key1"} + prof2 := ipn.LoginProfile{ID: "id2", Key: "key2"} + b.pm.knownProfiles["id1"] = prof1.View() + b.pm.knownProfiles["id2"] = prof2.View() + b.pm.currentProfile = prof1.View() + + // set up routeInfo + ri1 := appctype.RouteInfo{} + ri1.Wildcards = []string{"1"} + + ri2 := appctype.RouteInfo{} + ri2.Wildcards = []string{"2"} + + // read before write + readRi, err := b.readRouteInfoLocked() + if readRi != nil { + t.Fatalf("read before writing: want nil, got %v", readRi) } - fortWorth := tailcfg.Location{ - Latitude: 32.756389, - Longitude: -97.3325, - Priority: 150, + if err != ipn.ErrStateNotExist { + t.Fatalf("read before writing: want %v, got %v", ipn.ErrStateNotExist, err) } - fortWorthLowPriority := tailcfg.Location{ - Latitude: 32.756389, - Longitude: -97.3325, - Priority: 100, + + // write the first routeInfo + if err := b.storeRouteInfo(ri1); err != nil { + t.Fatal(err) } - peer1 := makePeer(1, - withExitRoutes(), - withSuggest()) - peer2DERP1 := makePeer(2, - withDERP(1), - withExitRoutes(), - withSuggest()) - peer3 := makePeer(3, - withExitRoutes(), - withSuggest()) - peer4DERP3 := makePeer(4, - withDERP(3), - withExitRoutes(), - withSuggest()) - dallasPeer5 := makePeer(5, - withName("Dallas"), - withoutDERP(), - withExitRoutes(), - withSuggest(), - withLocation(dallas.View())) - sanJosePeer6 := makePeer(6, - withName("San Jose"), - withoutDERP(), - withExitRoutes(), - withSuggest(), - withLocation(sanJose.View())) - fortWorthPeer7 := makePeer(7, - withName("Fort Worth"), - withoutDERP(), - withExitRoutes(), - withSuggest(), - withLocation(fortWorth.View())) - fortWorthPeer8LowPriority := makePeer(8, - withName("Fort Worth Low"), - withoutDERP(), - withExitRoutes(), - withSuggest(), - withLocation(fortWorthLowPriority.View())) + // write the other routeInfo as the other profile + if _, _, err := b.pm.SwitchToProfileByID("id2"); err != nil { + t.Fatal(err) + } + if err := b.storeRouteInfo(ri2); err != nil { + t.Fatal(err) + } - selfNode := tailcfg.Node{ - Addresses: []netip.Prefix{ - netip.MustParsePrefix("100.64.1.1/32"), - netip.MustParsePrefix("fe70::1/128"), - }, + // read the routeInfo of the first profile + if _, _, err := b.pm.SwitchToProfileByID("id1"); err != nil { + t.Fatal(err) + } + readRi, err = b.readRouteInfoLocked() + if err != nil { + t.Fatal(err) + } + if !slices.Equal(readRi.Wildcards, ri1.Wildcards) { + t.Fatalf("read prof1 routeInfo wildcards: want %v, got %v", ri1.Wildcards, readRi.Wildcards) } - defaultNetmap := &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - peer2DERP1, - peer3, - }, + // read the routeInfo of the second profile + if _, _, err := b.pm.SwitchToProfileByID("id2"); err != nil { + t.Fatal(err) } - locationNetmap := &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - dallasPeer5, - sanJosePeer6, - }, + readRi, err = b.readRouteInfoLocked() + if err != nil { + t.Fatal(err) } - largeNetmap := &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - peer1, - peer2DERP1, - peer3, - peer4DERP3, - dallasPeer5, - sanJosePeer6, - fortWorthPeer7, - }, + if !slices.Equal(readRi.Wildcards, ri2.Wildcards) { + t.Fatalf("read prof2 routeInfo wildcards: want %v, got %v", ri2.Wildcards, readRi.Wildcards) } +} +func TestFillAllowedSuggestions(t *testing.T) { tests := []struct { - name string - - lastReport *netcheck.Report - netMap *netmap.NetworkMap - lastSuggestion tailcfg.StableNodeID - - allowPolicy []tailcfg.StableNodeID + name string + allowPolicy []string + want []tailcfg.StableNodeID + }{ + { + name: "unset", + }, + { + name: "zero", + allowPolicy: []string{}, + want: []tailcfg.StableNodeID{}, + }, + { + name: "one", + allowPolicy: []string{"one"}, + want: []tailcfg.StableNodeID{"one"}, + }, + { + name: "many", + allowPolicy: []string{"one", "two", "three", "four"}, + want: []tailcfg.StableNodeID{"one", "three", "four", "two"}, // order should not matter + }, + { + name: "preserve case", + allowPolicy: []string{"ABC", "def", "gHiJ"}, + want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"}, + }, + } - wantRegions []int - useRegion int + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var pol policytest.Config + pol.Set(pkey.AllowedSuggestedExitNodes, tt.allowPolicy) - wantNodes []tailcfg.StableNodeID + got := fillAllowedSuggestions(pol) + if got == nil { + if tt.want == nil { + return + } + t.Errorf("got nil, want %v", tt.want) + } + if tt.want == nil { + t.Errorf("got %v, want nil", got) + } - wantID tailcfg.StableNodeID - wantName string - wantLocation tailcfg.LocationView + if !got.Equal(set.SetOf(tt.want)) { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} - wantError error +func TestNotificationTargetMatch(t *testing.T) { + tests := []struct { + name string + target notificationTarget + actor ipnauth.Actor + wantMatch bool }{ { - name: "2 exit nodes in same region", - lastReport: preferred1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - peer1, - peer2DERP1, - }, - }, - wantNodes: []tailcfg.StableNodeID{ - "stable1", - "stable2", - }, - wantName: "peer1", - wantID: "stable1", + name: "AllClients/Nil", + target: allClients, + actor: nil, + wantMatch: true, }, { - name: "2 exit nodes different regions unknown latency", - lastReport: noLatency1Report, - netMap: defaultNetmap, - wantRegions: []int{1, 3}, // the only regions with peers - useRegion: 1, - wantName: "peer2", - wantID: "stable2", + name: "AllClients/NoUID/NoCID", + target: allClients, + actor: &ipnauth.TestActor{}, + wantMatch: true, }, { - name: "2 derp based exit nodes, different regions, equal latency", - lastReport: &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10, - 2: 20, - 3: 10, - }, - PreferredDERP: 1, - }, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - peer1, - peer3, - }, - }, - wantRegions: []int{1, 2}, - useRegion: 1, - wantName: "peer1", - wantID: "stable1", + name: "AllClients/WithUID/NoCID", + target: allClients, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.NoClientID}, + wantMatch: true, }, { - name: "mullvad nodes, no derp based exit nodes", - lastReport: noLatency1Report, - netMap: locationNetmap, - wantID: "stable5", - wantLocation: dallas.View(), - wantName: "Dallas", + name: "AllClients/NoUID/WithCID", + target: allClients, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, }, { - name: "nearby mullvad nodes with different priorities", - lastReport: noLatency1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - dallasPeer5, - sanJosePeer6, - fortWorthPeer7, - }, - }, - wantID: "stable7", - wantLocation: fortWorth.View(), - wantName: "Fort Worth", + name: "AllClients/WithUID/WithCID", + target: allClients, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, }, { - name: "nearby mullvad nodes with same priorities", - lastReport: noLatency1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - dallasPeer5, - sanJosePeer6, - fortWorthPeer8LowPriority, - }, - }, - wantNodes: []tailcfg.StableNodeID{"stable5", "stable8"}, - wantID: "stable5", - wantLocation: dallas.View(), - wantName: "Dallas", + name: "FilterByUID/Nil", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: nil, + wantMatch: false, }, - { - name: "mullvad nodes, remaining node is not in preferred derp", - lastReport: noLatency1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - dallasPeer5, - sanJosePeer6, - peer4DERP3, - }, - }, - useRegion: 3, - wantID: "stable4", - wantName: "peer4", + { + name: "FilterByUID/NoUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{}, + wantMatch: false, }, { - name: "no peers", - lastReport: noLatency1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - }, + name: "FilterByUID/NoUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, }, { - name: "nil report", - lastReport: nil, - netMap: largeNetmap, - wantError: ErrNoPreferredDERP, + name: "FilterByUID/SameUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: true, }, { - name: "no preferred derp region", - lastReport: preferredNoneReport, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - }, - wantError: ErrNoPreferredDERP, + name: "FilterByUID/DifferentUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, + wantMatch: false, }, { - name: "nil netmap", - lastReport: noLatency1Report, - netMap: nil, - wantError: ErrNoPreferredDERP, + name: "FilterByUID/SameUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, }, { - name: "nil derpmap", - lastReport: noLatency1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: nil, - Peers: []tailcfg.NodeView{ - dallasPeer5, - }, - }, - wantError: ErrNoPreferredDERP, + name: "FilterByUID/DifferentUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, }, { - name: "missing suggestion capability", - lastReport: noLatency1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - makePeer(1, withExitRoutes()), - makePeer(2, withLocation(dallas.View()), withExitRoutes()), - }, - }, + name: "FilterByCID/Nil", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: nil, + wantMatch: false, }, { - name: "prefer last node", - lastReport: preferred1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - peer1, - peer2DERP1, - }, - }, - lastSuggestion: "stable2", - wantNodes: []tailcfg.StableNodeID{ - "stable1", - "stable2", - }, - wantName: "peer2", - wantID: "stable2", + name: "FilterByCID/NoUID/NoCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{}, + wantMatch: false, }, { - name: "found better derp node", - lastSuggestion: "stable3", - lastReport: preferred1Report, - netMap: defaultNetmap, - wantID: "stable2", - wantName: "peer2", + name: "FilterByCID/NoUID/SameCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, }, { - name: "prefer last mullvad node", - lastSuggestion: "stable2", - lastReport: preferred1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - dallasPeer5, - sanJosePeer6, - fortWorthPeer8LowPriority, - }, - }, - wantNodes: []tailcfg.StableNodeID{"stable5", "stable8"}, - wantID: "stable5", - wantName: "Dallas", - wantLocation: dallas.View(), + name: "FilterByCID/NoUID/DifferentCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, }, { - name: "prefer better mullvad node", - lastSuggestion: "stable2", - lastReport: preferred1Report, - netMap: &netmap.NetworkMap{ - SelfNode: selfNode.View(), - DERPMap: defaultDERPMap, - Peers: []tailcfg.NodeView{ - dallasPeer5, - sanJosePeer6, - fortWorthPeer7, - }, - }, - wantNodes: []tailcfg.StableNodeID{"stable7"}, - wantID: "stable7", - wantName: "Fort Worth", - wantLocation: fortWorth.View(), + name: "FilterByCID/WithUID/NoCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: false, }, { - name: "large netmap", - lastReport: preferred1Report, - netMap: largeNetmap, - wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"}, - wantID: "stable2", - wantName: "peer2", + name: "FilterByCID/WithUID/SameCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, }, { - name: "no allowed suggestions", - lastReport: preferred1Report, - netMap: largeNetmap, - allowPolicy: []tailcfg.StableNodeID{}, + name: "FilterByCID/WithUID/DifferentCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, }, { - name: "only derp suggestions", - lastReport: preferred1Report, - netMap: largeNetmap, - allowPolicy: []tailcfg.StableNodeID{"stable1", "stable2", "stable3"}, - wantNodes: []tailcfg.StableNodeID{"stable1", "stable2"}, - wantID: "stable2", - wantName: "peer2", + name: "FilterByUID+CID/Nil", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: nil, + wantMatch: false, }, { - name: "only mullvad suggestions", - lastReport: preferred1Report, - netMap: largeNetmap, - allowPolicy: []tailcfg.StableNodeID{"stable5", "stable6", "stable7"}, - wantID: "stable7", - wantName: "Fort Worth", - wantLocation: fortWorth.View(), + name: "FilterByUID+CID/NoUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{}, + wantMatch: false, }, { - name: "only worst derp", - lastReport: preferred1Report, - netMap: largeNetmap, - allowPolicy: []tailcfg.StableNodeID{"stable3"}, - wantID: "stable3", - wantName: "peer3", + name: "FilterByUID+CID/NoUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, }, { - name: "only worst mullvad", - lastReport: preferred1Report, - netMap: largeNetmap, - allowPolicy: []tailcfg.StableNodeID{"stable6"}, - wantID: "stable6", - wantName: "San Jose", - wantLocation: sanJose.View(), + name: "FilterByUID+CID/NoUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/SameUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/SameUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByUID+CID/SameUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - wantRegions := tt.wantRegions - if wantRegions == nil { - wantRegions = []int{tt.useRegion} + gotMatch := tt.target.match(tt.actor) + if gotMatch != tt.wantMatch { + t.Errorf("match: got %v; want %v", gotMatch, tt.wantMatch) } - selectRegion := deterministicRegionForTest(t, views.SliceOf(wantRegions), tt.useRegion) + }) + } +} - wantNodes := tt.wantNodes - if wantNodes == nil { - wantNodes = []tailcfg.StableNodeID{tt.wantID} +type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client + +func newLocalBackendWithTestControl(t testing.TB, enableLogging bool, newControl newTestControlFn) *LocalBackend { + bus := eventbustest.NewBus(t) + return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystemWithBus(bus), newControl) +} + +func newLocalBackendWithSysAndTestControl(t testing.TB, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + + if _, hasStore := sys.StateStore.GetOK(); !hasStore { + store := new(mem.Store) + sys.Set(store) + } + if _, hasEngine := sys.Engine.GetOK(); !hasEngine { + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) + } + + b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } + t.Cleanup(b.Shutdown) + + b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { + return newControl(t, opts), nil + }) + return b +} + +// notificationHandler is any function that can process (e.g., check) a notification. +// It returns whether the notification has been handled or should be passed to the next handler. +// The handler may be called from any goroutine, so it must avoid calling functions +// that are restricted to the goroutine running the test or benchmark function, +// such as [testing.common.FailNow] and [testing.common.Fatalf]. +type notificationHandler func(testing.TB, ipnauth.Actor, *ipn.Notify) bool + +// wantedNotification names a [notificationHandler] that processes a notification +// the test expects and wants to receive. The name is used to report notifications +// that haven't been received within the expected timeout. +type wantedNotification struct { + name string + cond notificationHandler +} + +// notificationWatcher observes [LocalBackend] notifications as the specified actor, +// reporting missing but expected notifications using [testing.common.Error], +// and delegating the handling of unexpected notifications to the [notificationHandler]s. +type notificationWatcher struct { + tb testing.TB + lb *LocalBackend + actor ipnauth.Actor + + mu sync.Mutex + mask ipn.NotifyWatchOpt + want []wantedNotification // notifications we want to receive + unexpected []notificationHandler // funcs that are called to check any other notifications + ctxCancel context.CancelFunc // cancels the outstanding [LocalBackend.WatchNotificationsAs] call + got []*ipn.Notify // all notifications, both wanted and unexpected, we've received so far + gotWanted []*ipn.Notify // only the expected notifications; holds nil for any notification that hasn't been received + gotWantedCh chan struct{} // closed when we have received the last wanted notification + doneCh chan struct{} // closed when [LocalBackend.WatchNotificationsAs] returns +} + +func newNotificationWatcher(tb testing.TB, lb *LocalBackend, actor ipnauth.Actor) *notificationWatcher { + return ¬ificationWatcher{tb: tb, lb: lb, actor: actor} +} + +func (w *notificationWatcher) watch(mask ipn.NotifyWatchOpt, wanted []wantedNotification, unexpected ...notificationHandler) { + w.tb.Helper() + + // Cancel any outstanding [LocalBackend.WatchNotificationsAs] calls. + w.mu.Lock() + ctxCancel := w.ctxCancel + doneCh := w.doneCh + w.mu.Unlock() + if doneCh != nil { + ctxCancel() + <-doneCh + } + + doneCh = make(chan struct{}) + gotWantedCh := make(chan struct{}) + ctx, ctxCancel := context.WithCancel(context.Background()) + w.tb.Cleanup(func() { + ctxCancel() + <-doneCh + }) + + w.mu.Lock() + w.mask = mask + w.want = wanted + w.unexpected = unexpected + w.ctxCancel = ctxCancel + w.got = nil + w.gotWanted = make([]*ipn.Notify, len(wanted)) + w.gotWantedCh = gotWantedCh + w.doneCh = doneCh + w.mu.Unlock() + + watchAddedCh := make(chan struct{}) + go func() { + defer close(doneCh) + if len(wanted) == 0 { + close(gotWantedCh) + if len(unexpected) == 0 { + close(watchAddedCh) + return } - selectNode := deterministicNodeForTest(t, views.SliceOf(wantNodes), tt.lastSuggestion, tt.wantID) + } - var allowList set.Set[tailcfg.StableNodeID] - if tt.allowPolicy != nil { - allowList = set.SetOf(tt.allowPolicy) + var nextWantIdx int + w.lb.WatchNotificationsAs(ctx, w.actor, w.mask, func() { close(watchAddedCh) }, func(notify *ipn.Notify) (keepGoing bool) { + w.tb.Helper() + + w.mu.Lock() + defer w.mu.Unlock() + w.got = append(w.got, notify) + + wanted := false + for i := nextWantIdx; i < len(w.want); i++ { + if wanted = w.want[i].cond(w.tb, w.actor, notify); wanted { + w.gotWanted[i] = notify + nextWantIdx = i + 1 + break + } } - got, err := suggestExitNode(tt.lastReport, tt.netMap, tt.lastSuggestion, selectRegion, selectNode, allowList) - if got.Name != tt.wantName { - t.Errorf("name=%v, want %v", got.Name, tt.wantName) + if wanted && nextWantIdx == len(w.want) { + close(w.gotWantedCh) + if len(w.unexpected) == 0 { + // If we have received the last wanted notification, + // and we don't have any handlers for the unexpected notifications, + // we can stop the watcher right away. + return false + } + } - if got.ID != tt.wantID { - t.Errorf("ID=%v, want %v", got.ID, tt.wantID) + + if !wanted { + // If we've received a notification we didn't expect, + // it could either be an unwanted notification caused by a bug + // or just a miscellaneous one that's irrelevant for the current test. + // Call unexpected notification handlers, if any, to + // check and fail the test if necessary. + for _, h := range w.unexpected { + if h(w.tb, w.actor, notify) { + break + } + } + } + + return true + }) + }() + <-watchAddedCh +} + +func (w *notificationWatcher) check() []*ipn.Notify { + w.tb.Helper() + + w.mu.Lock() + cancel := w.ctxCancel + gotWantedCh := w.gotWantedCh + checkUnexpected := len(w.unexpected) != 0 + doneCh := w.doneCh + w.mu.Unlock() + + // Wait for up to 10 seconds to receive expected notifications. + timeout := 10 * time.Second + for { + select { + case <-gotWantedCh: + if checkUnexpected { + gotWantedCh = nil + // But do not wait longer than 500ms for unexpected notifications after + // the expected notifications have been received. + timeout = 500 * time.Millisecond + continue + } + case <-doneCh: + // [LocalBackend.WatchNotificationsAs] has already returned, so no further + // notifications will be received. There's no reason to wait any longer. + case <-time.After(timeout): + } + cancel() + <-doneCh + break + } + + // Report missing notifications, if any, and log all received notifications, + // including both expected and unexpected ones. + w.mu.Lock() + defer w.mu.Unlock() + if hasMissing := slices.Contains(w.gotWanted, nil); hasMissing { + want := make([]string, len(w.want)) + got := make([]string, 0, len(w.want)) + for i, wn := range w.want { + want[i] = wn.name + if w.gotWanted[i] != nil { + got = append(got, wn.name) + } + } + w.tb.Errorf("Notifications(%s): got %q; want %q", actorDescriptionForTest(w.actor), strings.Join(got, ", "), strings.Join(want, ", ")) + for i, n := range w.got { + w.tb.Logf("%d. %v", i, n) + } + return nil + } + + return w.gotWanted +} + +func actorDescriptionForTest(actor ipnauth.Actor) string { + var parts []string + if actor != nil { + if name, _ := actor.Username(); name != "" { + parts = append(parts, name) + } + if uid := actor.UserID(); uid != "" { + parts = append(parts, string(uid)) + } + if clientID, _ := actor.ClientID(); clientID != ipnauth.NoClientID { + parts = append(parts, clientID.String()) + } + } + return fmt.Sprintf("Actor{%s}", strings.Join(parts, ", ")) +} + +func TestLoginNotifications(t *testing.T) { + const ( + enableLogging = true + controlURL = "https://localhost:1/" + loginURL = "https://localhost:1/1" + ) + + wantBrowseToURL := wantedNotification{ + name: "BrowseToURL", + cond: func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.BrowseToURL != nil && *n.BrowseToURL != loginURL { + t.Errorf("BrowseToURL (%s): got %q; want %q", actorDescriptionForTest(actor), *n.BrowseToURL, loginURL) + return false + } + return n.BrowseToURL != nil + }, + } + unexpectedBrowseToURL := func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.BrowseToURL != nil { + t.Errorf("Unexpected BrowseToURL(%s): %v", actorDescriptionForTest(actor), n) + return true + } + return false + } + + tests := []struct { + name string + logInAs ipnauth.Actor + urlExpectedBy []ipnauth.Actor + urlUnexpectedBy []ipnauth.Actor + }{ + { + name: "NoObservers", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{}, // ensure that it does not panic if no one is watching + }, + { + name: "SingleUser", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/NoCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}, &ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/OneWithCID", + logInAs: &ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/BothWithCID", + logInAs: &ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("456")}}, + }, + { + name: "DifferentUsers/NoCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "B"}}, + }, + { + name: "DifferentUsers/SameCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "B", CID: ipnauth.ClientIDFrom("123")}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ControlURLSet: true, Prefs: ipn.Prefs{ControlURL: controlURL}}); err != nil { + t.Fatalf("(*EditPrefs).Start(): %v", err) + } + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + sessions := make([]*notificationWatcher, 0, len(tt.urlExpectedBy)+len(tt.urlUnexpectedBy)) + for _, actor := range tt.urlExpectedBy { + session := newNotificationWatcher(t, lb, actor) + session.watch(0, []wantedNotification{wantBrowseToURL}) + sessions = append(sessions, session) } - if tt.wantError == nil && err != nil { - t.Errorf("err=%v, want no error", err) + for _, actor := range tt.urlUnexpectedBy { + session := newNotificationWatcher(t, lb, actor) + session.watch(0, nil, unexpectedBrowseToURL) + sessions = append(sessions, session) } - if tt.wantError != nil && !errors.Is(err, tt.wantError) { - t.Errorf("err=%v, want %v", err, tt.wantError) + + if err := lb.StartLoginInteractiveAs(context.Background(), tt.logInAs); err != nil { + t.Fatal(err) } - if !reflect.DeepEqual(got.Location, tt.wantLocation) { - t.Errorf("location=%v, want %v", got.Location, tt.wantLocation) + + lb.cc.(*mockControl).send(sendOpt{url: loginURL}) + + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, sess := range sessions { + go func() { // check all sessions in parallel + sess.check() + wg.Done() + }() } + wg.Wait() }) } } -func TestSuggestExitNodePickWeighted(t *testing.T) { - location10 := tailcfg.Location{ - Priority: 10, - } - location20 := tailcfg.Location{ - Priority: 20, +// TestConfigFileReload tests that the LocalBackend reloads its configuration +// when the configuration file changes. +func TestConfigFileReload(t *testing.T) { + type testCase struct { + name string + initial *conffile.Config + updated *conffile.Config + checkFn func(*testing.T, *LocalBackend) } - tests := []struct { - name string - candidates []tailcfg.NodeView - wantIDs []tailcfg.StableNodeID - }{ + tests := []testCase{ { - name: "different priorities", - candidates: []tailcfg.NodeView{ - makePeer(2, withExitRoutes(), withLocation(location20.View())), - makePeer(3, withExitRoutes(), withLocation(location10.View())), + name: "hostname_change", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + Hostname: ptr.To("initial-host"), + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + Hostname: ptr.To("updated-host"), + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if got := b.Prefs().Hostname(); got != "updated-host" { + t.Errorf("hostname = %q; want updated-host", got) + } }, - wantIDs: []tailcfg.StableNodeID{"stable2"}, }, { - name: "same priorities", - candidates: []tailcfg.NodeView{ - makePeer(2, withExitRoutes(), withLocation(location10.View())), - makePeer(3, withExitRoutes(), withLocation(location10.View())), + name: "start_advertising_services", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc", "svc:def"}, + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if got := b.Prefs().AdvertiseServices().AsSlice(); !reflect.DeepEqual(got, []string{"svc:abc", "svc:def"}) { + t.Errorf("AdvertiseServices = %v; want [svc:abc, svc:def]", got) + } }, - wantIDs: []tailcfg.StableNodeID{"stable2", "stable3"}, }, { - name: "<1 candidates", - candidates: []tailcfg.NodeView{}, + name: "change_advertised_services", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc", "svc:def"}, + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc", "svc:ghi"}, + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if got := b.Prefs().AdvertiseServices().AsSlice(); !reflect.DeepEqual(got, []string{"svc:abc", "svc:ghi"}) { + t.Errorf("AdvertiseServices = %v; want [svc:abc, svc:ghi]", got) + } + }, }, { - name: "1 candidate", - candidates: []tailcfg.NodeView{ - makePeer(2, withExitRoutes(), withLocation(location20.View())), + name: "unset_advertised_services", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc"}, + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if b.Prefs().AdvertiseServices().Len() != 0 { + t.Errorf("got %d AdvertiseServices wants none", b.Prefs().AdvertiseServices().Len()) + } }, - wantIDs: []tailcfg.StableNodeID{"stable2"}, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := pickWeighted(tt.candidates) - gotIDs := make([]tailcfg.StableNodeID, 0, len(got)) - for _, n := range got { - if !n.Valid() { - gotIDs = append(gotIDs, "") - continue - } - gotIDs = append(gotIDs, n.StableID()) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tailscale.conf") + + // Write initial config + initialJSON, err := json.Marshal(tc.initial.Parsed) + if err != nil { + t.Fatal(err) } - if !views.SliceEqualAnyOrder(views.SliceOf(gotIDs), views.SliceOf(tt.wantIDs)) { - t.Errorf("node IDs = %v, want %v", gotIDs, tt.wantIDs) + if err := os.WriteFile(path, initialJSON, 0644); err != nil { + t.Fatal(err) + } + + // Create backend with initial config + tc.initial.Path = path + tc.initial.Raw = initialJSON + sys := tsd.NewSystem() + sys.InitialConfig = tc.initial + b := newTestLocalBackendWithSys(t, sys) + + // Update config file + updatedJSON, err := json.Marshal(tc.updated.Parsed) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, updatedJSON, 0644); err != nil { + t.Fatal(err) + } + + // Trigger reload + if ok, err := b.ReloadConfig(); !ok || err != nil { + t.Fatalf("ReloadConfig() = %v, %v; want true, nil", ok, err) } + + // Check outcome + tc.checkFn(t, b) }) } } -func TestSuggestExitNodeLongLatDistance(t *testing.T) { +func TestGetVIPServices(t *testing.T) { tests := []struct { - name string - fromLat float64 - fromLong float64 - toLat float64 - toLong float64 - want float64 + name string + advertised []string + serveConfig *ipn.ServeConfig + want []*tailcfg.VIPService }{ { - name: "zero values", - fromLat: 0, - fromLong: 0, - toLat: 0, - toLong: 0, - want: 0, + "advertised-only", + []string{"svc:abc", "svc:def"}, + &ipn.ServeConfig{}, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Active: true, + }, + { + Name: "svc:def", + Active: true, + }, + }, }, { - name: "valid values", - fromLat: 40.73061, - fromLong: -73.935242, - toLat: 37.3382082, - toLong: -121.8863286, - want: 4117266.873301274, + "served-only", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + }, }, { - name: "valid values, locations in north and south of equator", - fromLat: 40.73061, - fromLong: -73.935242, - toLat: -33.861481, - toLong: 151.205475, - want: 15994089.144368416, + "served-and-advertised", + []string{"svc:abc"}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Active: true, + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + "served-and-advertised-different-service", + []string{"svc:def"}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + { + Name: "svc:def", + Active: true, + }, + }, + }, + { + "served-with-port-ranges-one-range-single", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + }}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 80, Last: 80}}}, + }, + }, + }, + { + "served-with-port-ranges-one-range-multiple", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + }}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 80, Last: 82}}}, + }, + }, + }, + { + "served-with-port-ranges-multiple-ranges", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + 1212: {HTTPS: true}, + 1213: {HTTPS: true}, + 1214: {HTTPS: true}, + }}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{ + {Proto: 6, Ports: tailcfg.PortRange{First: 80, Last: 82}}, + {Proto: 6, Ports: tailcfg.PortRange{First: 1212, Last: 1214}}, + }, + }, + }, }, } - // The wanted values are computed using a more precise algorithm using the WGS84 model but - // longLatDistance uses a spherical approximation for simplicity. To account for this, we allow for - // 10km of error. for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := longLatDistance(tt.fromLat, tt.fromLong, tt.toLat, tt.toLong) - const maxError = 10000 // 10km - if math.Abs(got-tt.want) > maxError { - t.Errorf("distance=%vm, want within %vm of %vm", got, maxError, tt.want) + lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + lb.serveConfig = tt.serveConfig.View() + prefs := &ipn.Prefs{ + AdvertiseServices: tt.advertised, + } + got := lb.vipServicesFromPrefsLocked(prefs.View()) + slices.SortFunc(got, func(a, b *tailcfg.VIPService) int { + return strings.Compare(a.Name.String(), b.Name.String()) + }) + if !reflect.DeepEqual(tt.want, got) { + t.Logf("want:") + for _, s := range tt.want { + t.Logf("%+v", s) + } + t.Logf("got:") + for _, s := range got { + t.Logf("%+v", s) + } + t.Fail() + return } }) } } -func TestMinLatencyDERPregion(t *testing.T) { +func TestUpdatePrefsOnSysPolicyChange(t *testing.T) { + const enableLogging = false + + type fieldChange struct { + name string + want any + } + + wantPrefsChanges := func(want ...fieldChange) *wantedNotification { + return &wantedNotification{ + name: "Prefs", + cond: func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.Prefs != nil { + prefs := reflect.Indirect(reflect.ValueOf(n.Prefs.AsStruct())) + for _, f := range want { + got := prefs.FieldByName(f.name).Interface() + if !reflect.DeepEqual(got, f.want) { + t.Errorf("%v: got %v; want %v", f.name, got, f.want) + } + } + } + return n.Prefs != nil + }, + } + } + + unexpectedPrefsChange := func(t testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + if n.Prefs != nil { + t.Errorf("Unexpected Prefs: %v", n.Prefs.Pretty()) + return true + } + return false + } + tests := []struct { - name string - regions []int - report *netcheck.Report - wantRegion int + name string + initialPrefs *ipn.Prefs + stringSettings []source.TestSetting[string] + want *wantedNotification }{ { - name: "regions, no latency values", - regions: []int{1, 2, 3}, - wantRegion: 0, - report: &netcheck.Report{}, + name: "ShieldsUp/True", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(pkey.EnableIncomingConnections, "never")}, + want: wantPrefsChanges(fieldChange{"ShieldsUp", true}), }, { - name: "regions, different latency values", - regions: []int{1, 2, 3}, - wantRegion: 2, - report: &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10 * time.Millisecond, - 2: 5 * time.Millisecond, - 3: 30 * time.Millisecond, - }, + name: "ShieldsUp/False", + initialPrefs: &ipn.Prefs{ShieldsUp: true}, + stringSettings: []source.TestSetting[string]{source.TestSettingOf(pkey.EnableIncomingConnections, "always")}, + want: wantPrefsChanges(fieldChange{"ShieldsUp", false}), + }, + { + name: "ExitNodeID", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(pkey.ExitNodeID, "foo")}, + want: wantPrefsChanges(fieldChange{"ExitNodeID", tailcfg.StableNodeID("foo")}), + }, + { + name: "EnableRunExitNode", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(pkey.EnableRunExitNode, "always")}, + want: wantPrefsChanges(fieldChange{"AdvertiseRoutes", []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}}), + }, + { + name: "Multiple", + initialPrefs: &ipn.Prefs{ + ExitNodeAllowLANAccess: true, + }, + stringSettings: []source.TestSetting[string]{ + source.TestSettingOf(pkey.EnableServerMode, "always"), + source.TestSettingOf(pkey.ExitNodeAllowLANAccess, "never"), + source.TestSettingOf(pkey.ExitNodeIP, "127.0.0.1"), }, + want: wantPrefsChanges( + fieldChange{"ForceDaemon", true}, + fieldChange{"ExitNodeAllowLANAccess", false}, + fieldChange{"ExitNodeIP", netip.MustParseAddr("127.0.0.1")}, + ), }, { - name: "regions, same values", - regions: []int{1, 2, 3}, - wantRegion: 1, - report: &netcheck.Report{ - RegionLatency: map[int]time.Duration{ - 1: 10 * time.Millisecond, - 2: 10 * time.Millisecond, - 3: 10 * time.Millisecond, - }, + name: "NoChange", + initialPrefs: &ipn.Prefs{ + CorpDNS: true, + ExitNodeID: "foo", + AdvertiseRoutes: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, }, + stringSettings: []source.TestSetting[string]{ + source.TestSettingOf(pkey.EnableTailscaleDNS, "always"), + source.TestSettingOf(pkey.ExitNodeID, "foo"), + source.TestSettingOf(pkey.EnableRunExitNode, "always"), + }, + want: nil, // syspolicy settings match the preferences; no change notification is expected. }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := minLatencyDERPRegion(tt.regions, tt.report) - if got != tt.wantRegion { - t.Errorf("got region %v want region %v", got, tt.wantRegion) + var polc policytest.Config + polc.EnableRegisterChangeCallback() + + sys := tsd.NewSystem() + sys.PolicyClient.Set(polc) + lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + opts.PolicyClient = polc + return newClient(tb, opts) + }) + if tt.initialPrefs != nil { + lb.SetPrefsForTest(tt.initialPrefs) + } + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + nw := newNotificationWatcher(t, lb, &ipnauth.TestActor{}) + if tt.want != nil { + nw.watch(0, []wantedNotification{*tt.want}) + } else { + nw.watch(0, nil, unexpectedPrefsChange) } + + var batch policytest.Config + for _, ss := range tt.stringSettings { + batch.Set(ss.Key, ss.Value) + } + polc.SetMultiple(batch) + + nw.check() }) } } -func TestShouldAutoExitNode(t *testing.T) { +func TestUpdateIngressAndServiceHashLocked(t *testing.T) { + prefs := ipn.NewPrefs().View() + previousSC := &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + } tests := []struct { - name string - exitNodeIDPolicyValue string - expectedBool bool + name string + hi *tailcfg.Hostinfo + hasPreviousSC bool // whether to overwrite the ServeConfig hash in the Hostinfo using previousSC + sc *ipn.ServeConfig + wantIngress bool + wantWireIngress bool + wantControlUpdate bool }{ { - name: "auto:any", - exitNodeIDPolicyValue: "auto:any", - expectedBool: true, + name: "no_hostinfo_no_serve_config", + hi: nil, + }, + { + name: "empty_hostinfo_no_serve_config", + hi: &tailcfg.Hostinfo{}, + }, + { + name: "empty_hostinfo_funnel_enabled", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + wantIngress: true, + wantWireIngress: false, // implied by wantIngress + wantControlUpdate: true, + }, + { + name: "empty_hostinfo_service_configured", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + wantControlUpdate: true, + }, + { + name: "empty_hostinfo_funnel_disabled", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block + wantControlUpdate: true, + }, + { + name: "empty_hostinfo_no_funnel_no_service", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + }, + }, + }, + { + name: "funnel_enabled_no_change", + hi: &tailcfg.Hostinfo{ + IngressEnabled: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + wantIngress: true, + wantWireIngress: false, // implied by wantIngress + }, + { + name: "service_hash_no_change", + hi: &tailcfg.Hostinfo{}, + hasPreviousSC: true, + sc: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + }, + { + name: "funnel_disabled_no_change", + hi: &tailcfg.Hostinfo{ + WireIngress: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block }, { - name: "no auto prefix", - exitNodeIDPolicyValue: "foo", - expectedBool: false, + name: "service_got_removed", + hi: &tailcfg.Hostinfo{}, + hasPreviousSC: true, + sc: &ipn.ServeConfig{}, + wantControlUpdate: true, }, { - name: "auto prefix but empty suffix", - exitNodeIDPolicyValue: "auto:", - expectedBool: false, + name: "funnel_changes_to_disabled", + hi: &tailcfg.Hostinfo{ + IngressEnabled: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block + wantControlUpdate: true, }, { - name: "auto prefix no colon", - exitNodeIDPolicyValue: "auto", - expectedBool: false, + name: "funnel_changes_to_enabled", + hi: &tailcfg.Hostinfo{ + WireIngress: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + wantIngress: true, + wantWireIngress: false, // implied by wantIngress + wantControlUpdate: true, }, { - name: "auto prefix invalid suffix", - exitNodeIDPolicyValue: "auto:foo", - expectedBool: false, + name: "both_funnel_and_service_changes", + hi: &tailcfg.Hostinfo{ + IngressEnabled: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block + wantControlUpdate: true, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To(tt.exitNodeIDPolicyValue), - }, + t.Parallel() + b := newTestLocalBackend(t) + b.hostinfo = tt.hi + if tt.hasPreviousSC { + b.mu.Lock() + b.serveConfig = previousSC.View() + b.hostinfo.ServicesHash = vipServiceHash(b.logf, b.vipServicesFromPrefsLocked(prefs)) + b.mu.Unlock() + } + b.serveConfig = tt.sc.View() + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + was := b.goTracker.StartedGoroutines() + b.maybeSentHostinfoIfChangedLocked(prefs) + + if tt.hi != nil { + if tt.hi.IngressEnabled != tt.wantIngress { + t.Errorf("IngressEnabled = %v, want %v", tt.hi.IngressEnabled, tt.wantIngress) + } + if tt.hi.WireIngress != tt.wantWireIngress { + t.Errorf("WireIngress = %v, want %v", tt.hi.WireIngress, tt.wantWireIngress) + } + b.mu.Lock() + svcHash := vipServiceHash(b.logf, b.vipServicesFromPrefsLocked(prefs)) + b.mu.Unlock() + if tt.hi.ServicesHash != svcHash { + t.Errorf("ServicesHash = %v, want %v", tt.hi.ServicesHash, svcHash) + } + } + + startedGoroutine := b.goTracker.StartedGoroutines() != was + if startedGoroutine != tt.wantControlUpdate { + t.Errorf("control update triggered = %v, want %v", startedGoroutine, tt.wantControlUpdate) } - syspolicy.SetHandlerForTest(t, msh) - got := shouldAutoExitNode() - if got != tt.expectedBool { - t.Fatalf("expected %v got %v for %v policy value", tt.expectedBool, got, tt.exitNodeIDPolicyValue) + + if startedGoroutine { + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } } }) } } -func TestEnableAutoUpdates(t *testing.T) { - lb := newTestLocalBackend(t) +// TestSrcCapPacketFilter tests that LocalBackend handles packet filters with +// SrcCaps instead of Srcs (IPs) +func TestSrcCapPacketFilter(t *testing.T) { + lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } - _, err := lb.EditPrefs(&ipn.MaskedPrefs{ - AutoUpdateSet: ipn.AutoUpdatePrefsMask{ - ApplySet: true, + var k key.NodePublic + must.Do(k.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261"))) + + controlClient := lb.cc.(*mockControl) + controlClient.send(sendOpt{nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + ID: 2, + Key: k, + CapMap: tailcfg.NodeCapMap{"cap-X": nil}, // node 2 has cap + }).View(), + (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("3.3.3.3/32")}, + ID: 3, + Key: k, + CapMap: tailcfg.NodeCapMap{}, // node 3 does not have the cap + }).View(), }, - Prefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Apply: opt.NewBool(true), + PacketFilter: []filtertype.Match{{ + IPProto: views.SliceOf([]ipproto.Proto{ipproto.TCP}), + SrcCaps: []tailcfg.NodeCapability{"cap-X"}, // cap in packet filter rule + Dsts: []filtertype.NetPortRange{{ + Net: netip.MustParsePrefix("1.1.1.1/32"), + Ports: filtertype.PortRange{ + First: 22, + Last: 22, + }, + }}, + }}, + }}) + + f := lb.GetFilterForTest() + res := f.Check(netip.MustParseAddr("2.2.2.2"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP) + if res != filter.Accept { + t.Errorf("Check(2.2.2.2, ...) = %s, want %s", res, filter.Accept) + } + + res = f.Check(netip.MustParseAddr("3.3.3.3"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP) + if !res.IsDrop() { + t.Error("IsDrop() for node without cap = false, want true") + } +} + +func TestDisplayMessages(t *testing.T) { + b := newTestLocalBackend(t) + + // Pretend we're in a map poll so health updates get processed + ht := b.HealthTracker() + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + b.mu.Lock() + defer b.mu.Unlock() + b.setNetMapLocked(&netmap.NetworkMap{ + DisplayMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test-message": { + Title: "Testing", }, }, }) - // Enabling may fail, depending on which environment we are running this - // test in. - wantErr := !clientupdate.CanAutoUpdate() - gotErr := err != nil - if gotErr != wantErr { - t.Fatalf("enabling auto-updates: got error: %v (%v); want error: %v", gotErr, err, wantErr) + + state := ht.CurrentState() + wantID := health.WarnableCode("control-health.test-message") + _, ok := state.Warnings[wantID] + + if !ok { + t.Errorf("no warning found with id %q", wantID) } +} - // Disabling should always succeed. - if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ - AutoUpdateSet: ipn.AutoUpdatePrefsMask{ - ApplySet: true, - }, - Prefs: ipn.Prefs{ - AutoUpdate: ipn.AutoUpdatePrefs{ - Apply: opt.NewBool(false), +// TestDisplayMessagesURLFilter tests that we filter out any URLs that are not +// valid as a pop browser URL (see [LocalBackend.validPopBrowserURL]). +func TestDisplayMessagesURLFilter(t *testing.T) { + b := newTestLocalBackend(t) + + // Pretend we're in a map poll so health updates get processed + ht := b.HealthTracker() + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + b.mu.Lock() + defer b.mu.Unlock() + b.setNetMapLocked(&netmap.NetworkMap{ + DisplayMessages: map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test-message": { + Title: "Testing", + Severity: tailcfg.SeverityHigh, + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "https://www.evil.com", + Label: "Phishing Link", + }, }, }, - }); err != nil { - t.Fatalf("disabling auto-updates: got error: %v", err) + }) + + state := ht.CurrentState() + wantID := health.WarnableCode("control-health.test-message") + got, ok := state.Warnings[wantID] + + if !ok { + t.Fatalf("no warning found with id %q", wantID) } -} -func TestReadWriteRouteInfo(t *testing.T) { - // set up a backend with more than one profile - b := newTestBackend(t) - prof1 := ipn.LoginProfile{ID: "id1", Key: "key1"} - prof2 := ipn.LoginProfile{ID: "id2", Key: "key2"} - b.pm.knownProfiles["id1"] = &prof1 - b.pm.knownProfiles["id2"] = &prof2 - b.pm.currentProfile = &prof1 + want := health.UnhealthyState{ + WarnableCode: wantID, + Title: "Testing", + Severity: health.SeverityHigh, + } - // set up routeInfo - ri1 := &appc.RouteInfo{} - ri1.Wildcards = []string{"1"} + if diff := cmp.Diff(want, got, cmpopts.IgnoreFields(health.UnhealthyState{}, "ETag")); diff != "" { + t.Errorf("Unexpected message content (-want/+got):\n%s", diff) + } +} - ri2 := &appc.RouteInfo{} - ri2.Wildcards = []string{"2"} +// TestDisplayMessageIPNBus checks that we send health messages appropriately +// based on whether the watcher has sent the [ipn.NotifyHealthActions] watch +// option or not. +func TestDisplayMessageIPNBus(t *testing.T) { + type test struct { + name string + mask ipn.NotifyWatchOpt + wantWarning health.UnhealthyState + } - // read before write - readRi, err := b.readRouteInfoLocked() - if readRi != nil { - t.Fatalf("read before writing: want nil, got %v", readRi) + msgs := map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage{ + "test-message": { + Title: "Message title", + Text: "Message text.", + Severity: tailcfg.SeverityMedium, + PrimaryAction: &tailcfg.DisplayMessageAction{ + URL: "https://example.com", + Label: "Learn more", + }, + }, } - if err != ipn.ErrStateNotExist { - t.Fatalf("read before writing: want %v, got %v", ipn.ErrStateNotExist, err) + + wantID := health.WarnableCode("control-health.test-message") + + for _, tt := range []test{ + { + name: "older-client-no-actions", + mask: 0, + wantWarning: health.UnhealthyState{ + WarnableCode: wantID, + Severity: health.SeverityMedium, + Title: "Message title", + Text: "Message text. Learn more: https://example.com", // PrimaryAction appended to text + PrimaryAction: nil, // PrimaryAction not included + }, + }, + { + name: "new-client-with-actions", + mask: ipn.NotifyHealthActions, + wantWarning: health.UnhealthyState{ + WarnableCode: wantID, + Severity: health.SeverityMedium, + Title: "Message title", + Text: "Message text.", + PrimaryAction: &health.UnhealthyStateAction{ + URL: "https://example.com", + Label: "Learn more", + }, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + + ipnWatcher := newNotificationWatcher(t, lb, nil) + ipnWatcher.watch(tt.mask, []wantedNotification{{ + name: fmt.Sprintf("warning with ID %q", wantID), + cond: func(_ testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + if n.Health == nil { + return false + } + got, ok := n.Health.Warnings[wantID] + if ok { + if diff := cmp.Diff(tt.wantWarning, got, cmpopts.IgnoreFields(health.UnhealthyState{}, "ETag")); diff != "" { + t.Errorf("unexpected warning details (-want/+got):\n%s", diff) + return true // we failed the test so tell the watcher we've seen what we need to to stop it waiting + } + } else { + got := slices.Collect(maps.Keys(n.Health.Warnings)) + t.Logf("saw warnings: %v", got) + } + return ok + }, + }}) + + lb.SetPrefsForTest(&ipn.Prefs{ + ControlURL: "https://localhost:1/", + WantRunning: true, + LoggedOut: false, + }) + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + cc := lb.cc.(*mockControl) + + // Assert that we are logged in and authorized, and also send our DisplayMessages + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), + DisplayMessages: msgs, + }}) + + // Tell the health tracker that we are in a map poll because + // mockControl doesn't tell it + lb.HealthTracker().GotStreamedMapResponse() + + // Assert that we got the expected notification + ipnWatcher.check() + }) } +} - // write the first routeInfo - if err := b.storeRouteInfo(ri1); err != nil { - t.Fatal(err) +func TestHardwareAttested(t *testing.T) { + b := new(LocalBackend) + + // default false + if got := b.HardwareAttested(); got != false { + t.Errorf("HardwareAttested() = %v, want false", got) } - // write the other routeInfo as the other profile - if err := b.pm.SwitchProfile("id2"); err != nil { - t.Fatal(err) + // set true + b.SetHardwareAttested() + if got := b.HardwareAttested(); got != true { + t.Errorf("HardwareAttested() = %v, want true after SetHardwareAttested()", got) } - if err := b.storeRouteInfo(ri2); err != nil { - t.Fatal(err) + + // repeat calls are safe; still true + b.SetHardwareAttested() + if got := b.HardwareAttested(); got != true { + t.Errorf("HardwareAttested() = %v, want true after second SetHardwareAttested()", got) } +} - // read the routeInfo of the first profile - if err := b.pm.SwitchProfile("id1"); err != nil { - t.Fatal(err) +func TestDeps(t *testing.T) { + deptest.DepChecker{ + OnImport: func(pkg string) { + switch pkg { + case "tailscale.com/util/syspolicy", + "tailscale.com/util/syspolicy/setting", + "tailscale.com/util/syspolicy/rsop": + t.Errorf("ipn/ipnlocal: importing syspolicy package %q is not allowed; only policyclient and its deps should be used by ipn/ipnlocal", pkg) + } + }, + }.Check(t) +} + +func checkError(tb testing.TB, got, want error, fatal bool) { + tb.Helper() + f := tb.Errorf + if fatal { + f = tb.Fatalf } - readRi, err = b.readRouteInfoLocked() - if err != nil { - t.Fatal(err) + if (want == nil) != (got == nil) || + (want != nil && got != nil && want.Error() != got.Error() && !errors.Is(got, want)) { + f("gotErr: %v; wantErr: %v", got, want) } - if !slices.Equal(readRi.Wildcards, ri1.Wildcards) { - t.Fatalf("read prof1 routeInfo wildcards: want %v, got %v", ri1.Wildcards, readRi.Wildcards) +} + +func toStrings[T ~string](in []T) []string { + out := make([]string, len(in)) + for i, v := range in { + out[i] = string(v) } + return out +} - // read the routeInfo of the second profile - if err := b.pm.SwitchProfile("id2"); err != nil { - t.Fatal(err) +type textUpdate struct { + Advertise []string + Unadvertise []string +} + +func routeUpdateToText(u appctype.RouteUpdate) textUpdate { + var out textUpdate + for _, p := range u.Advertise { + out.Advertise = append(out.Advertise, p.String()) } - readRi, err = b.readRouteInfoLocked() - if err != nil { - t.Fatal(err) + for _, p := range u.Unadvertise { + out.Unadvertise = append(out.Unadvertise, p.String()) } - if !slices.Equal(readRi.Wildcards, ri2.Wildcards) { - t.Fatalf("read prof2 routeInfo wildcards: want %v, got %v", ri2.Wildcards, readRi.Wildcards) + return out +} + +func mustPrefix(ss ...string) (out []netip.Prefix) { + for _, s := range ss { + out = append(out, netip.MustParsePrefix(s)) } + return } -func TestFillAllowedSuggestions(t *testing.T) { - tests := []struct { - name string - allowPolicy []string - want []tailcfg.StableNodeID - }{ - { - name: "unset", - }, - { - name: "zero", - allowPolicy: []string{}, - want: []tailcfg.StableNodeID{}, +// eqUpdate generates an eventbus test filter that matches an appctype.RouteUpdate +// message equal to want, or reports an error giving a human-readable diff. +// +// TODO(creachadair): This is copied from the appc test package, but we can't +// put it into the appctest package because the appc tests depend on it and +// that makes a cycle. Clean up those tests and put this somewhere common. +func eqUpdate(want appctype.RouteUpdate) func(appctype.RouteUpdate) error { + return func(got appctype.RouteUpdate) error { + if diff := cmp.Diff(routeUpdateToText(got), routeUpdateToText(want)); diff != "" { + return fmt.Errorf("wrong update (-got, +want):\n%s", diff) + } + return nil + } +} + +type fakeAttestationKey struct{ key.HardwareAttestationKey } + +func (f *fakeAttestationKey) Clone() key.HardwareAttestationKey { + return &fakeAttestationKey{} +} + +// TestStripKeysFromPrefs tests that LocalBackend's [stripKeysFromPrefs] (as used +// by sendNotify etc) correctly removes all private keys from an ipn.Notify. +// +// It does so by testing the the two ways that Notifys are sent: via sendNotify, +// and via extension hooks. +func TestStripKeysFromPrefs(t *testing.T) { + // genNotify generates a sample ipn.Notify with various private keys set + // at a certain path through the Notify data structure. + genNotify := map[string]func() ipn.Notify{ + "Notify.Prefs.Đļ.Persist.PrivateNodeKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{PrivateNodeKey: key.NewNode()}, + }).View()), + } }, - { - name: "one", - allowPolicy: []string{"one"}, - want: []tailcfg.StableNodeID{"one"}, + "Notify.Prefs.Đļ.Persist.OldPrivateNodeKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{OldPrivateNodeKey: key.NewNode()}, + }).View()), + } }, - { - name: "many", - allowPolicy: []string{"one", "two", "three", "four"}, - want: []tailcfg.StableNodeID{"one", "three", "four", "two"}, // order should not matter + "Notify.Prefs.Đļ.Persist.NetworkLockKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{NetworkLockKey: key.NewNLPrivate()}, + }).View()), + } }, - { - name: "preserve case", - allowPolicy: []string{"ABC", "def", "gHiJ"}, - want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"}, + "Notify.Prefs.Đļ.Persist.AttestationKey": func() ipn.Notify { + return ipn.Notify{ + Prefs: ptr.To((&ipn.Prefs{ + Persist: &persist.Persist{AttestationKey: new(fakeAttestationKey)}, + }).View()), + } }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mh := mockSyspolicyHandler{ - t: t, + + private := key.PrivateTypesForTest() + + for path := range typewalk.MatchingPaths(reflect.TypeFor[ipn.Notify](), private.Contains) { + t.Run(path.Name, func(t *testing.T) { + gen, ok := genNotify[path.Name] + if !ok { + t.Fatalf("no genNotify function for path %q", path.Name) } - if tt.allowPolicy != nil { - mh.stringArrayPolicies = map[syspolicy.Key][]string{ - syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy, - } + withKey := gen() + + if path.Walk(reflect.ValueOf(withKey)).IsZero() { + t.Fatalf("generated notify does not have non-zero value at path %q", path.Name) } - syspolicy.SetHandlerForTest(t, &mh) - got := fillAllowedSuggestions() - if got == nil { - if tt.want == nil { - return + h := &ExtensionHost{} + ch := make(chan *ipn.Notify, 1) + b := &LocalBackend{ + extHost: h, + notifyWatchers: map[string]*watchSession{ + "test": {ch: ch}, + }, + } + + var okay atomic.Int32 + testNotify := func(via string) func(*ipn.Notify) { + return func(n *ipn.Notify) { + if n == nil { + t.Errorf("notify from %s is nil", via) + return + } + if !path.Walk(reflect.ValueOf(*n)).IsZero() { + t.Errorf("notify from %s has non-zero value at path %q; key not stripped", via, path.Name) + } else { + okay.Add(1) + } } - t.Errorf("got nil, want %v", tt.want) } - if tt.want == nil { - t.Errorf("got %v, want nil", got) + + h.Hooks().MutateNotifyLocked.Add(testNotify("MutateNotifyLocked hook")) + + b.send(withKey) + + select { + case n := <-ch: + testNotify("watchSession")(n) + default: + t.Errorf("no notify sent to watcher channel") } - if !got.Equal(set.SetOf(tt.want)) { - t.Errorf("got %v, want %v", got, tt.want) + if got := okay.Load(); got != 2 { + t.Errorf("notify passed validation %d times; want 2", got) } }) } diff --git a/ipn/ipnlocal/loglines_test.go b/ipn/ipnlocal/loglines_test.go index f70987c0e..d831aa8b0 100644 --- a/ipn/ipnlocal/loglines_test.go +++ b/ipn/ipnlocal/loglines_test.go @@ -47,10 +47,10 @@ func TestLocalLogLines(t *testing.T) { idA := logid(0xaa) // set up a LocalBackend, super bare bones. No functional data. - sys := new(tsd.System) + sys := tsd.NewSystem() store := new(mem.Store) sys.Set(store) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatal(err) } diff --git a/ipn/ipnlocal/netstack.go b/ipn/ipnlocal/netstack.go new file mode 100644 index 000000000..f7ffd0305 --- /dev/null +++ b/ipn/ipnlocal/netstack.go @@ -0,0 +1,74 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netstack + +package ipnlocal + +import ( + "net" + "net/netip" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "tailscale.com/types/ptr" +) + +// TCPHandlerForDst returns a TCP handler for connections to dst, or nil if +// no handler is needed. It also returns a list of TCP socket options to +// apply to the socket before calling the handler. +// TCPHandlerForDst is called both for connections to our node's local IP +// as well as to the service IP (quad 100). +func (b *LocalBackend) TCPHandlerForDst(src, dst netip.AddrPort) (handler func(c net.Conn) error, opts []tcpip.SettableSocketOption) { + // First handle internal connections to the service IP + hittingServiceIP := dst.Addr() == magicDNSIP || dst.Addr() == magicDNSIPv6 + if hittingServiceIP { + switch dst.Port() { + case 80: + // TODO(mpminardi): do we want to show an error message if the web client + // has been disabled instead of the more "basic" web UI? + if b.ShouldRunWebClient() { + return b.handleWebClientConn, opts + } + return b.HandleQuad100Port80Conn, opts + case DriveLocalPort: + return b.handleDriveConn, opts + } + } + + if f, ok := hookServeTCPHandlerForVIPService.GetOk(); ok { + if handler := f(b, dst, src); handler != nil { + return handler, opts + } + } + // Then handle external connections to the local IP. + if !b.isLocalIP(dst.Addr()) { + return nil, nil + } + if dst.Port() == 22 && b.ShouldRunSSH() { + // Use a higher keepalive idle time for SSH connections, as they are + // typically long lived and idle connections are more likely to be + // intentional. Ideally we would turn this off entirely, but we can't + // tell the difference between a long lived connection that is idle + // vs a connection that is dead because the peer has gone away. + // We pick 72h as that is typically sufficient for a long weekend. + opts = append(opts, ptr.To(tcpip.KeepaliveIdleOption(72*time.Hour))) + return b.handleSSHConn, opts + } + // TODO(will,sonia): allow customizing web client port ? + if dst.Port() == webClientPort && b.ShouldExposeRemoteWebClient() { + return b.handleWebClientConn, opts + } + if port, ok := b.GetPeerAPIPort(dst.Addr()); ok && dst.Port() == port { + return func(c net.Conn) error { + b.handlePeerAPIConn(src, dst, c) + return nil + }, opts + } + if f, ok := hookTCPHandlerForServe.GetOk(); ok { + if handler := f(b, dst.Port(), src, nil); handler != nil { + return handler, opts + } + } + return nil, nil +} diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index d20bf94eb..78d4d236d 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package ipnlocal import ( @@ -21,6 +23,7 @@ import ( "slices" "time" + "tailscale.com/health" "tailscale.com/health/healthmsg" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -52,10 +55,68 @@ var ( type tkaState struct { profile ipn.ProfileID authority *tka.Authority - storage *tka.FS + storage tka.CompactableChonk filtered []ipnstate.TKAPeer } +func (b *LocalBackend) initTKALocked() error { + cp := b.pm.CurrentProfile() + if cp.ID() == "" { + b.tka = nil + return nil + } + if b.tka != nil { + if b.tka.profile == cp.ID() { + // Already initialized. + return nil + } + // As we're switching profiles, we need to reset the TKA to nil. + b.tka = nil + } + root := b.TailscaleVarRoot() + if root == "" { + b.tka = nil + b.logf("cannot fetch existing TKA state; no state directory for network-lock") + return nil + } + + chonkDir := b.chonkPathLocked() + if _, err := os.Stat(chonkDir); err == nil { + // The directory exists, which means network-lock has been initialized. + storage, err := tka.ChonkDir(chonkDir) + if err != nil { + return fmt.Errorf("opening tailchonk: %v", err) + } + authority, err := tka.Open(storage) + if err != nil { + return fmt.Errorf("initializing tka: %v", err) + } + + if err := authority.Compact(storage, tkaCompactionDefaults); err != nil { + b.logf("tka compaction failed: %v", err) + } + + b.tka = &tkaState{ + profile: cp.ID(), + authority: authority, + storage: storage, + } + b.logf("tka initialized at head %x", authority.Head()) + } + + return nil +} + +// noNetworkLockStateDirWarnable is a Warnable to warn the user that Tailnet Lock data +// (in particular, the list of AUMs in the TKA state) is being stored in memory and will +// be lost when tailscaled restarts. +var noNetworkLockStateDirWarnable = health.Register(&health.Warnable{ + Code: "no-tailnet-lock-state-dir", + Title: "No statedir for Tailnet Lock", + Severity: health.SeverityMedium, + Text: health.StaticMessage(healthmsg.InMemoryTailnetLockState), +}) + // tkaFilterNetmapLocked checks the signatures on each node key, dropping // nodes from the netmap whose signature does not verify. // @@ -239,8 +300,11 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie return nil } - if b.tka != nil || nm.TKAEnabled { - b.logf("tkaSyncIfNeeded: enabled=%v, head=%v", nm.TKAEnabled, nm.TKAHead) + isEnabled := b.tka != nil + wantEnabled := nm.TKAEnabled + + if isEnabled || wantEnabled { + b.logf("tkaSyncIfNeeded: isEnabled=%t, wantEnabled=%t, head=%v", isEnabled, wantEnabled, nm.TKAHead) } ourNodeKey, ok := prefs.Persist().PublicNodeKeyOK() @@ -248,8 +312,6 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie return errors.New("tkaSyncIfNeeded: no node key in prefs") } - isEnabled := b.tka != nil - wantEnabled := nm.TKAEnabled didJustEnable := false if isEnabled != wantEnabled { var ourHead tka.AUMHash @@ -294,6 +356,13 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie if err := b.tkaSyncLocked(ourNodeKey); err != nil { return fmt.Errorf("tka sync: %w", err) } + // Try to compact the TKA state, to avoid unbounded storage on nodes. + // + // We run this on every sync so that clients compact consistently. In many + // cases this will be a no-op. + if err := b.tka.authority.Compact(b.tka.storage, tkaCompactionDefaults); err != nil { + return fmt.Errorf("tka compact: %w", err) + } } return nil @@ -393,7 +462,7 @@ func (b *LocalBackend) tkaSyncLocked(ourNodeKey key.NodePublic) error { // b.mu must be held & TKA must be initialized. func (b *LocalBackend) tkaApplyDisablementLocked(secret []byte) error { if b.tka.authority.ValidDisablement(secret) { - if err := os.RemoveAll(b.chonkPathLocked()); err != nil { + if err := b.tka.storage.RemoveAll(); err != nil { return err } b.tka = nil @@ -407,7 +476,7 @@ func (b *LocalBackend) tkaApplyDisablementLocked(secret []byte) error { // // b.mu must be held. func (b *LocalBackend) chonkPathLocked() string { - return filepath.Join(b.TailscaleVarRoot(), "tka-profiles", string(b.pm.CurrentProfile().ID)) + return filepath.Join(b.TailscaleVarRoot(), "tka-profiles", string(b.pm.CurrentProfile().ID())) } // tkaBootstrapFromGenesisLocked initializes the local (on-disk) state of the @@ -415,10 +484,6 @@ func (b *LocalBackend) chonkPathLocked() string { // // b.mu must be held. func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, persist persist.PersistView) error { - if err := b.CanSupportNetworkLock(); err != nil { - return err - } - var genesis tka.AUM if err := genesis.Unserialize(g); err != nil { return fmt.Errorf("reading genesis: %v", err) @@ -430,54 +495,37 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per } bootstrapStateID := fmt.Sprintf("%d:%d", genesis.State.StateID1, genesis.State.StateID2) - for i := range persist.DisallowedTKAStateIDs().Len() { - stateID := persist.DisallowedTKAStateIDs().At(i) + for _, stateID := range persist.DisallowedTKAStateIDs().All() { if stateID == bootstrapStateID { return fmt.Errorf("TKA with stateID of %q is disallowed on this node", stateID) } } } - chonkDir := b.chonkPathLocked() - if err := os.Mkdir(filepath.Dir(chonkDir), 0755); err != nil && !os.IsExist(err) { - return fmt.Errorf("creating chonk root dir: %v", err) - } - if err := os.Mkdir(chonkDir, 0755); err != nil && !os.IsExist(err) { - return fmt.Errorf("mkdir: %v", err) - } - - chonk, err := tka.ChonkDir(chonkDir) - if err != nil { - return fmt.Errorf("chonk: %v", err) + root := b.TailscaleVarRoot() + var storage tka.CompactableChonk + if root == "" { + b.health.SetUnhealthy(noNetworkLockStateDirWarnable, nil) + b.logf("network-lock using in-memory storage; no state directory") + storage = tka.ChonkMem() + } else { + chonkDir := b.chonkPathLocked() + chonk, err := tka.ChonkDir(chonkDir) + if err != nil { + return fmt.Errorf("chonk: %v", err) + } + storage = chonk } - authority, err := tka.Bootstrap(chonk, genesis) + authority, err := tka.Bootstrap(storage, genesis) if err != nil { return fmt.Errorf("tka bootstrap: %v", err) } b.tka = &tkaState{ - profile: b.pm.CurrentProfile().ID, + profile: b.pm.CurrentProfile().ID(), authority: authority, - storage: chonk, - } - return nil -} - -// CanSupportNetworkLock returns nil if tailscaled is able to operate -// a local tailnet key authority (and hence enforce network lock). -func (b *LocalBackend) CanSupportNetworkLock() error { - if b.tka != nil { - // If the TKA is being used, it is supported. - return nil - } - - if b.TailscaleVarRoot() == "" { - return errors.New("network-lock is not supported in this configuration, try setting --statedir") + storage: storage, } - - // There's a var root (aka --statedir), so if network lock gets - // initialized we have somewhere to store our AUMs. That's all - // we need. return nil } @@ -517,9 +565,10 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { var selfAuthorized bool nodeKeySignature := &tka.NodeKeySignature{} - if b.netMap != nil { - selfAuthorized = b.tka.authority.NodeKeyAuthorized(b.netMap.SelfNode.Key(), b.netMap.SelfNode.KeySignature().AsSlice()) == nil - if err := nodeKeySignature.Unserialize(b.netMap.SelfNode.KeySignature().AsSlice()); err != nil { + nm := b.currentNode().NetMap() + if nm != nil { + selfAuthorized = b.tka.authority.NodeKeyAuthorized(nm.SelfNode.Key(), nm.SelfNode.KeySignature().AsSlice()) == nil + if err := nodeKeySignature.Unserialize(nm.SelfNode.KeySignature().AsSlice()); err != nil { b.logf("failed to decode self node key signature: %v", err) } } @@ -540,9 +589,9 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { } var visible []*ipnstate.TKAPeer - if b.netMap != nil { - visible = make([]*ipnstate.TKAPeer, len(b.netMap.Peers)) - for i, p := range b.netMap.Peers { + if nm != nil { + visible = make([]*ipnstate.TKAPeer, len(nm.Peers)) + for i, p := range nm.Peers { s := tkaStateFromPeer(p) visible[i] = &s } @@ -572,8 +621,7 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { TailscaleIPs: make([]netip.Addr, 0, p.Addresses().Len()), NodeKey: p.Key(), } - for i := range p.Addresses().Len() { - addr := p.Addresses().At(i) + for _, addr := range p.Addresses().All() { if addr.IsSingleIP() && tsaddr.IsTailscaleIP(addr.Addr()) { fp.TailscaleIPs = append(fp.TailscaleIPs, addr.Addr()) } @@ -595,24 +643,16 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { // The Finish RPC submits signatures for all these nodes, at which point // Control has everything it needs to atomically enable network lock. func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) error { - if err := b.CanSupportNetworkLock(); err != nil { - return err - } - var ourNodeKey key.NodePublic var nlPriv key.NLPrivate - b.mu.Lock() - - if !b.capTailnetLock { - b.mu.Unlock() - return errors.New("not permitted to enable tailnet lock") - } + b.mu.Lock() if p := b.pm.CurrentPrefs(); p.Valid() && p.Persist().Valid() && !p.Persist().PrivateNodeKey().IsZero() { ourNodeKey = p.Persist().PublicNodeKey() nlPriv = p.Persist().NetworkLockKey() } b.mu.Unlock() + if ourNodeKey.IsZero() || nlPriv.IsZero() { return errors.New("no node-key: is tailscale logged in?") } @@ -626,7 +666,7 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt // We use an in-memory tailchonk because we don't want to commit to // the filesystem until we've finished the initialization sequence, // just in case something goes wrong. - _, genesisAUM, err := tka.Create(&tka.Mem{}, tka.State{ + _, genesisAUM, err := tka.Create(tka.ChonkMem(), tka.State{ Keys: keys, // TODO(tom): s/tka.State.DisablementSecrets/tka.State.DisablementValues // This will center on consistent nomenclature: @@ -672,6 +712,13 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt return err } +// NetworkLockAllowed reports whether the node is allowed to use Tailnet Lock. +func (b *LocalBackend) NetworkLockAllowed() bool { + b.mu.Lock() + defer b.mu.Unlock() + return b.capTailnetLock +} + // Only use is in tests. func (b *LocalBackend) NetworkLockVerifySignatureForTest(nks tkatype.MarshaledSignature, nodeKey key.NodePublic) error { b.mu.Lock() @@ -704,16 +751,14 @@ func (b *LocalBackend) NetworkLockForceLocalDisable() error { id1, id2 := b.tka.authority.StateIDs() stateID := fmt.Sprintf("%d:%d", id1, id2) + cn := b.currentNode() newPrefs := b.pm.CurrentPrefs().AsStruct().Clone() // .Persist should always be initialized here. newPrefs.Persist.DisallowedTKAStateIDs = append(newPrefs.Persist.DisallowedTKAStateIDs, stateID) - if err := b.pm.SetPrefs(newPrefs.View(), ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { + if err := b.pm.SetPrefs(newPrefs.View(), cn.NetworkProfile()); err != nil { return fmt.Errorf("saving prefs: %w", err) } - if err := os.RemoveAll(b.chonkPathLocked()); err != nil { + if err := b.tka.storage.RemoveAll(); err != nil { return fmt.Errorf("deleting TKA state: %w", err) } b.tka = nil @@ -899,7 +944,7 @@ func (b *LocalBackend) NetworkLockLog(maxEntries int) ([]ipnstate.NetworkLockUpd if err == os.ErrNotExist { break } - return out, fmt.Errorf("reading AUM: %w", err) + return out, fmt.Errorf("reading AUM (%v): %w", cursor, err) } update := ipnstate.NetworkLockUpdate{ diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index 4b79136c8..5d22425a1 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package ipnlocal import ( @@ -15,6 +17,7 @@ import ( "path/filepath" "reflect" "testing" + "time" go4mem "go4.org/mem" @@ -28,26 +31,25 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tka" + "tailscale.com/tsd" + "tailscale.com/tstest" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/tkatype" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" "tailscale.com/util/set" ) -type observerFunc func(controlclient.Status) - -func (f observerFunc) SetControlClientStatus(_ controlclient.Client, s controlclient.Status) { - f(s) -} - func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} hi.NetInfo = &ni + bus := eventbustest.NewBus(t) k := key.NewMachine() + dialer := tsdial.NewDialer(netmon.NewStatic()) opts := controlclient.Options{ ServerURL: "https://example.com", Hostinfo: hi, @@ -56,11 +58,13 @@ func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto { }, HTTPTestClient: c, NoiseTestClient: c, - Observer: observerFunc(func(controlclient.Status) {}), - Dialer: tsdial.NewDialer(netmon.NewStatic()), + Dialer: dialer, + Bus: bus, + + SkipStartForTests: true, } - cc, err := controlclient.NewNoStart(opts) + cc, err := controlclient.New(opts) if err != nil { t.Fatal(err) } @@ -68,6 +72,7 @@ func fakeControlClient(t *testing.T, c *http.Client) *controlclient.Auto { } func fakeNoiseServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, *http.Client) { + t.Helper() ts := httptest.NewUnstartedServer(handler) ts.StartTLS() client := ts.Client() @@ -78,6 +83,17 @@ func fakeNoiseServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, return ts, client } +func setupProfileManager(t *testing.T, nodePriv key.NodePrivate, nlPriv key.NLPrivate) *profileManager { + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) + must.Do(pm.SetPrefs((&ipn.Prefs{ + Persist: &persist.Persist{ + PrivateNodeKey: nodePriv, + NetworkLockKey: nlPriv, + }, + }).View(), ipn.NetworkProfile{})) + return pm +} + func TestTKAEnablementFlow(t *testing.T) { nodePriv := key.NewNode() @@ -85,7 +101,7 @@ func TestTKAEnablementFlow(t *testing.T) { // our mock server can communicate. nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - a1, genesisAUM, err := tka.Create(&tka.Mem{}, tka.State{ + a1, genesisAUM, err := tka.Create(tka.ChonkMem(), tka.State{ Keys: []tka.Key{key}, DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, }, nlPriv) @@ -153,13 +169,7 @@ func TestTKAEnablementFlow(t *testing.T) { temp := t.TempDir() cc := fakeControlClient(t, client) - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) b := LocalBackend{ capTailnetLock: true, varRoot: temp, @@ -193,16 +203,10 @@ func TestTKADisablementFlow(t *testing.T) { nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -385,17 +389,11 @@ func TestTKASync(t *testing.T) { t.Run(tc.name, func(t *testing.T) { nodePriv := key.NewNode() nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Setup the tka authority on the control plane. key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - controlStorage := &tka.Mem{} + controlStorage := tka.ChonkMem() controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ Keys: []tka.Key{key, someKey}, DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, @@ -410,7 +408,7 @@ func TestTKASync(t *testing.T) { } temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) // Setup the TKA authority on the node. nodeStorage, err := tka.ChonkDir(tkaPath) @@ -526,7 +524,7 @@ func TestTKASync(t *testing.T) { }, } - // Finally, lets trigger a sync. + // Finally, let's trigger a sync. err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ TKAEnabled: true, TKAHead: controlAuthority.Head(), @@ -544,10 +542,220 @@ func TestTKASync(t *testing.T) { } } +// Whenever we run a TKA sync and get new state from control, we compact the +// local state. +func TestTKASyncTriggersCompact(t *testing.T) { + someKeyPriv := key.NewNLPrivate() + someKey := tka.Key{Kind: tka.Key25519, Public: someKeyPriv.Public().Verifier(), Votes: 1} + + disablementSecret := bytes.Repeat([]byte{0xa5}, 32) + + nodePriv := key.NewNode() + nlPriv := key.NewNLPrivate() + pm := setupProfileManager(t, nodePriv, nlPriv) + + // Create a clock, and roll it back by 30 days. + // + // Our compaction algorithm preserves AUMs received in the last 14 days, so + // we need to backdate the commit times to make the AUMs eligible for compaction. + clock := tstest.NewClock(tstest.ClockOpts{}) + clock.Advance(-30 * 24 * time.Hour) + + // Set up the TKA authority on the control plane. + key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + controlStorage := tka.ChonkMem() + controlStorage.SetClock(clock) + controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ + Keys: []tka.Key{key, someKey}, + DisablementSecrets: [][]byte{tka.DisablementKDF(disablementSecret)}, + }, nlPriv) + if err != nil { + t.Fatalf("tka.Create() failed: %v", err) + } + + // Fill the control plane TKA authority with a lot of AUMs, enough so that: + // + // 1. the chain of AUMs includes some checkpoints + // 2. the chain is long enough it would be trimmed if we ran the compaction + // algorithm with the defaults + for range 100 { + upd := controlAuthority.NewUpdater(nlPriv) + if err := upd.RemoveKey(someKey.MustID()); err != nil { + t.Fatalf("RemoveKey: %v", err) + } + if err := upd.AddKey(someKey); err != nil { + t.Fatalf("AddKey: %v", err) + } + aums, err := upd.Finalize(controlStorage) + if err != nil { + t.Fatalf("Finalize: %v", err) + } + if err := controlAuthority.Inform(controlStorage, aums); err != nil { + t.Fatalf("controlAuthority.Inform() failed: %v", err) + } + } + + // Set up the TKA authority on the node. + nodeStorage := tka.ChonkMem() + nodeStorage.SetClock(clock) + nodeAuthority, err := tka.Bootstrap(nodeStorage, bootstrap) + if err != nil { + t.Fatalf("tka.Bootstrap() failed: %v", err) + } + + // Make a mock control server. + ts, client := fakeNoiseServer(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + switch r.URL.Path { + case "/machine/tka/sync/offer": + body := new(tailcfg.TKASyncOfferRequest) + if err := json.NewDecoder(r.Body).Decode(body); err != nil { + t.Fatal(err) + } + t.Logf("got sync offer:\n%+v", body) + nodeOffer, err := toSyncOffer(body.Head, body.Ancestors) + if err != nil { + t.Fatal(err) + } + controlOffer, err := controlAuthority.SyncOffer(controlStorage) + if err != nil { + t.Fatal(err) + } + sendAUMs, err := controlAuthority.MissingAUMs(controlStorage, nodeOffer) + if err != nil { + t.Fatal(err) + } + + head, ancestors, err := fromSyncOffer(controlOffer) + if err != nil { + t.Fatal(err) + } + resp := tailcfg.TKASyncOfferResponse{ + Head: head, + Ancestors: ancestors, + MissingAUMs: make([]tkatype.MarshaledAUM, len(sendAUMs)), + } + for i, a := range sendAUMs { + resp.MissingAUMs[i] = a.Serialize() + } + + t.Logf("responding to sync offer with:\n%+v", resp) + w.WriteHeader(200) + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatal(err) + } + + case "/machine/tka/sync/send": + body := new(tailcfg.TKASyncSendRequest) + if err := json.NewDecoder(r.Body).Decode(body); err != nil { + t.Fatal(err) + } + t.Logf("got sync send:\n%+v", body) + + var remoteHead tka.AUMHash + if err := remoteHead.UnmarshalText([]byte(body.Head)); err != nil { + t.Fatalf("head unmarshal: %v", err) + } + toApply := make([]tka.AUM, len(body.MissingAUMs)) + for i, a := range body.MissingAUMs { + if err := toApply[i].Unserialize(a); err != nil { + t.Fatalf("decoding missingAUM[%d]: %v", i, err) + } + } + + if len(toApply) > 0 { + if err := controlAuthority.Inform(controlStorage, toApply); err != nil { + t.Fatalf("control.Inform(%+v) failed: %v", toApply, err) + } + } + head, err := controlAuthority.Head().MarshalText() + if err != nil { + t.Fatal(err) + } + + w.WriteHeader(200) + if err := json.NewEncoder(w).Encode(tailcfg.TKASyncSendResponse{ + Head: string(head), + }); err != nil { + t.Fatal(err) + } + + default: + t.Errorf("unhandled endpoint path: %v", r.URL.Path) + w.WriteHeader(404) + } + })) + defer ts.Close() + + // Setup the client. + cc := fakeControlClient(t, client) + b := LocalBackend{ + cc: cc, + ccAuto: cc, + logf: t.Logf, + pm: pm, + store: pm.Store(), + tka: &tkaState{ + authority: nodeAuthority, + storage: nodeStorage, + }, + } + + // Trigger a sync. + err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ + TKAEnabled: true, + TKAHead: controlAuthority.Head(), + }, pm.CurrentPrefs()) + if err != nil { + t.Errorf("tkaSyncIfNeeded() failed: %v", err) + } + + // Add a new AUM in control. + upd := controlAuthority.NewUpdater(nlPriv) + if err := upd.RemoveKey(someKey.MustID()); err != nil { + t.Fatalf("RemoveKey: %v", err) + } + aums, err := upd.Finalize(controlStorage) + if err != nil { + t.Fatalf("Finalize: %v", err) + } + if err := controlAuthority.Inform(controlStorage, aums); err != nil { + t.Fatalf("controlAuthority.Inform() failed: %v", err) + } + + // Run a second sync, which should trigger a compaction. + err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ + TKAEnabled: true, + TKAHead: controlAuthority.Head(), + }, pm.CurrentPrefs()) + if err != nil { + t.Errorf("tkaSyncIfNeeded() failed: %v", err) + } + + // Check that the node and control plane are in sync. + if nodeHead, controlHead := b.tka.authority.Head(), controlAuthority.Head(); nodeHead != controlHead { + t.Errorf("node head = %v, want %v", nodeHead, controlHead) + } + + // Check the node has compacted away some of its AUMs; that it has purged some AUMs which + // are still kept in the control plane. + nodeAUMs, err := b.tka.storage.AllAUMs() + if err != nil { + t.Errorf("AllAUMs() for node failed: %v", err) + } + controlAUMS, err := controlStorage.AllAUMs() + if err != nil { + t.Errorf("AllAUMs() for control failed: %v", err) + } + if len(nodeAUMs) == len(controlAUMS) { + t.Errorf("node has not compacted; it has the same number of AUMs as control (node = control = %d)", len(nodeAUMs)) + } +} + func TestTKAFilterNetmap(t *testing.T) { nlPriv := key.NewNLPrivate() nlKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - storage := &tka.Mem{} + storage := tka.ChonkMem() authority, _, err := tka.Create(storage, tka.State{ Keys: []tka.Key{nlKey}, DisablementSecrets: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, @@ -701,16 +909,10 @@ func TestTKADisable(t *testing.T) { disablementSecret := bytes.Repeat([]byte{0xa5}, 32) nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} chonk, err := tka.ChonkDir(tkaPath) @@ -770,7 +972,7 @@ func TestTKADisable(t *testing.T) { ccAuto: cc, logf: t.Logf, tka: &tkaState{ - profile: pm.CurrentProfile().ID, + profile: pm.CurrentProfile().ID(), authority: authority, storage: chonk, }, @@ -792,20 +994,14 @@ func TestTKASign(t *testing.T) { toSign := key.NewNode() nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. disablementSecret := bytes.Repeat([]byte{0xa5}, 32) key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -881,16 +1077,10 @@ func TestTKAForceDisable(t *testing.T) { nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -935,18 +1125,21 @@ func TestTKAForceDisable(t *testing.T) { defer ts.Close() cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - tka: &tkaState{ - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), + sys := tsd.NewSystem() + sys.Set(pm.Store()) + + b := newTestLocalBackendWithSys(t, sys) + b.SetVarRoot(temp) + b.SetControlClientGetterForTesting(func(controlclient.Options) (controlclient.Client, error) { + return cc, nil + }) + b.mu.Lock() + b.tka = &tkaState{ + authority: authority, + storage: chonk, } + b.pm = pm + b.mu.Unlock() if err := b.NetworkLockForceLocalDisable(); err != nil { t.Fatalf("NetworkLockForceLocalDisable() failed: %v", err) @@ -976,20 +1169,14 @@ func TestTKAAffectedSigs(t *testing.T) { // toSign := key.NewNode() nlPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. disablementSecret := bytes.Repeat([]byte{0xa5}, 32) tkaKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -1109,13 +1296,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { cosignPriv := key.NewNLPrivate() compromisedPriv := key.NewNLPrivate() - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: nlPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. disablementSecret := bytes.Repeat([]byte{0xa5}, 32) @@ -1124,7 +1305,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { compromisedKey := tka.Key{Kind: tka.Key25519, Public: compromisedPriv.Public().Verifier(), Votes: 1} temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -1200,13 +1381,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { // Cosign using the cosigning key. { - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) - must.Do(pm.SetPrefs((&ipn.Prefs{ - Persist: &persist.Persist{ - PrivateNodeKey: nodePriv, - NetworkLockKey: cosignPriv, - }, - }).View(), ipn.NetworkProfile{})) + pm := setupProfileManager(t, nodePriv, cosignPriv) b := LocalBackend{ varRoot: temp, logf: t.Logf, diff --git a/ipn/ipnlocal/node_backend.go b/ipn/ipnlocal/node_backend.go new file mode 100644 index 000000000..efef57ea4 --- /dev/null +++ b/ipn/ipnlocal/node_backend.go @@ -0,0 +1,872 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "cmp" + "context" + "net/netip" + "slices" + "sync" + "sync/atomic" + + "go4.org/netipx" + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn" + "tailscale.com/net/dns" + "tailscale.com/net/tsaddr" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/ptr" + "tailscale.com/types/views" + "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus" + "tailscale.com/util/mak" + "tailscale.com/util/slicesx" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" +) + +// nodeBackend is node-specific [LocalBackend] state. It is usually the current node. +// +// Its exported methods are safe for concurrent use, but the struct is not a snapshot of state at a given moment; +// its state can change between calls. For example, asking for the same value (e.g., netmap or prefs) twice +// may return different results. Returned values are immutable and safe for concurrent use. +// +// If both the [LocalBackend]'s internal mutex and the [nodeBackend] mutex must be held at the same time, +// the [LocalBackend] mutex must be acquired first. See the comment on the [LocalBackend] field for more details. +// +// Two pointers to different [nodeBackend] instances represent different local nodes. +// However, there's currently a bug where a new [nodeBackend] might not be created +// during an implicit node switch (see tailscale/corp#28014). +// +// In the future, we might want to include at least the following in this struct (in addition to the current fields). +// However, not everything should be exported or otherwise made available to the outside world (e.g. [ipnext] extensions, +// peer API handlers, etc.). +// - [ipn.State]: when the LocalBackend switches to a different [nodeBackend], it can update the state of the old one. +// - [ipn.LoginProfileView] and [ipn.Prefs]: we should update them when the [profileManager] reports changes to them. +// In the future, [profileManager] (and the corresponding methods of the [LocalBackend]) can be made optional, +// and something else could be used to set them once or update them as needed. +// - [tailcfg.HostinfoView]: it includes certain fields that are tied to the current profile/node/prefs. We should also +// update to build it once instead of mutating it in twelvety different places. +// - [filter.Filter] (normal and jailed, along with the filterHash): the nodeBackend could have a method to (re-)build +// the filter for the current netmap/prefs (see [LocalBackend.updateFilterLocked]), and it needs to track the current +// filters and their hash. +// - Fields related to a requested or required (re-)auth: authURL, authURLTime, authActor, keyExpired, etc. +// - [controlclient.Client]/[*controlclient.Auto]: the current control client. It is ties to a node identity. +// - [tkaState]: it is tied to the current profile / node. +// - Fields related to scheduled node expiration: nmExpiryTimer, numClientStatusCalls, [expiryManager]. +// +// It should not include any fields used by specific features that don't belong in [LocalBackend]. +// Even if they're tied to the local node, instead of moving them here, we should extract the entire feature +// into a separate package and have it install proper hooks. +type nodeBackend struct { + logf logger.Logf + + ctx context.Context // canceled by [nodeBackend.shutdown] + ctxCancel context.CancelCauseFunc // cancels ctx + + // filterAtomic is a stateful packet filter. Immutable once created, but can be + // replaced with a new one. + filterAtomic atomic.Pointer[filter.Filter] + + // initialized once and immutable + eventClient *eventbus.Client + filterPub *eventbus.Publisher[magicsock.FilterUpdate] + nodeViewsPub *eventbus.Publisher[magicsock.NodeViewsUpdate] + nodeMutsPub *eventbus.Publisher[magicsock.NodeMutationsUpdate] + derpMapViewPub *eventbus.Publisher[tailcfg.DERPMapView] + + // TODO(nickkhyl): maybe use sync.RWMutex? + mu syncs.Mutex // protects the following fields + + shutdownOnce sync.Once // guards calling [nodeBackend.shutdown] + readyCh chan struct{} // closed by [nodeBackend.ready]; nil after shutdown + + // NetMap is the most recently set full netmap from the controlclient. + // It can't be mutated in place once set. Because it can't be mutated in place, + // delta updates from the control server don't apply to it. Instead, use + // the peers map to get up-to-date information on the state of peers. + // In general, avoid using the netMap.Peers slice. We'd like it to go away + // as of 2023-09-17. + // TODO(nickkhyl): make it an atomic pointer to avoid the need for a mutex? + netMap *netmap.NetworkMap + + // peers is the set of current peers and their current values after applying + // delta node mutations as they come in (with mu held). The map values can be + // given out to callers, but the map itself can be mutated in place (with mu held) + // and must not escape the [nodeBackend]. + peers map[tailcfg.NodeID]tailcfg.NodeView + + // nodeByAddr maps nodes' own addresses (excluding subnet routes) to node IDs. + // It is mutated in place (with mu held) and must not escape the [nodeBackend]. + nodeByAddr map[netip.Addr]tailcfg.NodeID +} + +func newNodeBackend(ctx context.Context, logf logger.Logf, bus *eventbus.Bus) *nodeBackend { + ctx, ctxCancel := context.WithCancelCause(ctx) + nb := &nodeBackend{ + logf: logf, + ctx: ctx, + ctxCancel: ctxCancel, + eventClient: bus.Client("ipnlocal.nodeBackend"), + readyCh: make(chan struct{}), + } + // Default filter blocks everything and logs nothing. + noneFilter := filter.NewAllowNone(logger.Discard, &netipx.IPSet{}) + nb.filterAtomic.Store(noneFilter) + nb.filterPub = eventbus.Publish[magicsock.FilterUpdate](nb.eventClient) + nb.nodeViewsPub = eventbus.Publish[magicsock.NodeViewsUpdate](nb.eventClient) + nb.nodeMutsPub = eventbus.Publish[magicsock.NodeMutationsUpdate](nb.eventClient) + nb.derpMapViewPub = eventbus.Publish[tailcfg.DERPMapView](nb.eventClient) + nb.filterPub.Publish(magicsock.FilterUpdate{Filter: nb.filterAtomic.Load()}) + return nb +} + +// Context returns a context that is canceled when the [nodeBackend] shuts down, +// either because [LocalBackend] is switching to a different [nodeBackend] +// or is shutting down itself. +func (nb *nodeBackend) Context() context.Context { + return nb.ctx +} + +func (nb *nodeBackend) Self() tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return tailcfg.NodeView{} + } + return nb.netMap.SelfNode +} + +func (nb *nodeBackend) SelfUserID() tailcfg.UserID { + self := nb.Self() + if !self.Valid() { + return 0 + } + return self.User() +} + +// SelfHasCap reports whether the specified capability was granted to the self node in the most recent netmap. +func (nb *nodeBackend) SelfHasCap(wantCap tailcfg.NodeCapability) bool { + return nb.SelfHasCapOr(wantCap, false) +} + +// SelfHasCapOr is like [nodeBackend.SelfHasCap], but returns the specified default value +// if the netmap is not available yet. +func (nb *nodeBackend) SelfHasCapOr(wantCap tailcfg.NodeCapability, def bool) bool { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return def + } + return nb.netMap.AllCaps.Contains(wantCap) +} + +func (nb *nodeBackend) NetworkProfile() ipn.NetworkProfile { + nb.mu.Lock() + defer nb.mu.Unlock() + return ipn.NetworkProfile{ + // These are ok to call with nil netMap. + MagicDNSName: nb.netMap.MagicDNSSuffix(), + DomainName: nb.netMap.DomainName(), + DisplayName: nb.netMap.TailnetDisplayName(), + } +} + +// TODO(nickkhyl): update it to return a [tailcfg.DERPMapView]? +func (nb *nodeBackend) DERPMap() *tailcfg.DERPMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return nil + } + return nb.netMap.DERPMap +} + +func (nb *nodeBackend) NodeByAddr(ip netip.Addr) (_ tailcfg.NodeID, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + nid, ok := nb.nodeByAddr[ip] + return nid, ok +} + +func (nb *nodeBackend) NodeByKey(k key.NodePublic) (_ tailcfg.NodeID, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return 0, false + } + if self := nb.netMap.SelfNode; self.Valid() && self.Key() == k { + return self.ID(), true + } + // TODO(bradfitz,nickkhyl): add nodeByKey like nodeByAddr instead of walking peers. + for _, n := range nb.peers { + if n.Key() == k { + return n.ID(), true + } + } + return 0, false +} + +func (nb *nodeBackend) NodeByID(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap != nil { + if self := nb.netMap.SelfNode; self.Valid() && self.ID() == id { + return self, true + } + } + n, ok := nb.peers[id] + return n, ok +} + +func (nb *nodeBackend) PeerByStableID(id tailcfg.StableNodeID) (_ tailcfg.NodeView, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + for _, n := range nb.peers { + if n.StableID() == id { + return n, true + } + } + return tailcfg.NodeView{}, false +} + +func (nb *nodeBackend) UserByID(id tailcfg.UserID) (_ tailcfg.UserProfileView, ok bool) { + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() + if nm == nil { + return tailcfg.UserProfileView{}, false + } + u, ok := nm.UserProfiles[id] + return u, ok +} + +// Peers returns all the current peers in an undefined order. +func (nb *nodeBackend) Peers() []tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + return slicesx.MapValues(nb.peers) +} + +func (nb *nodeBackend) PeersForTest() []tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + ret := slicesx.MapValues(nb.peers) + slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), b.ID()) + }) + return ret +} + +func (nb *nodeBackend) CollectServices() bool { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.netMap != nil && nb.netMap.CollectServices +} + +// AppendMatchingPeers returns base with all peers that match pred appended. +// +// It acquires b.mu to read the netmap but releases it before calling pred. +func (nb *nodeBackend) AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView { + var peers []tailcfg.NodeView + + nb.mu.Lock() + if nb.netMap != nil { + // All fields on b.netMap are immutable, so this is + // safe to copy and use outside the lock. + peers = nb.netMap.Peers + } + nb.mu.Unlock() + + ret := base + for _, peer := range peers { + // The peers in b.netMap don't contain updates made via + // UpdateNetmapDelta. So only use PeerView in b.netMap for its NodeID, + // and then look up the latest copy in b.peers which is updated in + // response to UpdateNetmapDelta edits. + nb.mu.Lock() + peer, ok := nb.peers[peer.ID()] + nb.mu.Unlock() + if ok && pred(peer) { + ret = append(ret, peer) + } + } + return ret +} + +// PeerCaps returns the capabilities that remote src IP has to +// ths current node. +func (nb *nodeBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.peerCapsLocked(src) +} + +func (nb *nodeBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { + if nb.netMap == nil { + return nil + } + filt := nb.filterAtomic.Load() + if filt == nil { + return nil + } + addrs := nb.netMap.GetAddresses() + for i := range addrs.Len() { + a := addrs.At(i) + if !a.IsSingleIP() { + continue + } + dst := a.Addr() + if dst.BitLen() == src.BitLen() { // match on family + return filt.CapsWithValues(src, dst) + } + } + return nil +} + +// PeerHasCap reports whether the peer contains the given capability string, +// with any value(s). +func (nb *nodeBackend) PeerHasCap(peer tailcfg.NodeView, wantCap tailcfg.PeerCapability) bool { + if !peer.Valid() { + return false + } + + nb.mu.Lock() + defer nb.mu.Unlock() + for _, ap := range peer.Addresses().All() { + if nb.peerHasCapLocked(ap.Addr(), wantCap) { + return true + } + } + return false +} + +func (nb *nodeBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { + return nb.peerCapsLocked(addr).HasCapability(wantCap) +} + +func (nb *nodeBackend) PeerHasPeerAPI(p tailcfg.NodeView) bool { + return nb.PeerAPIBase(p) != "" +} + +// PeerAPIBase returns the "http://ip:port" URL base to reach peer's PeerAPI, +// or the empty string if the peer is invalid or doesn't support PeerAPI. +func (nb *nodeBackend) PeerAPIBase(p tailcfg.NodeView) string { + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() + return peerAPIBase(nm, p) +} + +// PeerIsReachable reports whether the current node can reach p. If the ctx is +// done, this function may return a result based on stale reachability data. +func (nb *nodeBackend) PeerIsReachable(ctx context.Context, p tailcfg.NodeView) bool { + if !nb.SelfHasCap(tailcfg.NodeAttrClientSideReachability) { + // Legacy behavior is to always trust the control plane, which + // isn’t always correct because the peer could be slow to check + // in so that control marks it as offline. + // See tailscale/corp#32686. + return p.Online().Get() + } + + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() + + if self := nm.SelfNode; self.Valid() && self.ID() == p.ID() { + // This node can always reach itself. + return true + } + return nb.peerIsReachable(ctx, p) +} + +func (nb *nodeBackend) peerIsReachable(ctx context.Context, p tailcfg.NodeView) bool { + // TODO(sfllaw): The following does not actually test for client-side + // reachability. This would require a mechanism that tracks whether the + // current node can actually reach this peer, either because they are + // already communicating or because they can ping each other. + // + // Instead, it makes the client ignore p.Online completely. + // + // See tailscale/corp#32686. + return true +} + +func nodeIP(n tailcfg.NodeView, pred func(netip.Addr) bool) netip.Addr { + for _, pfx := range n.Addresses().All() { + if pfx.IsSingleIP() && pred(pfx.Addr()) { + return pfx.Addr() + } + } + return netip.Addr{} +} + +func (nb *nodeBackend) NetMap() *netmap.NetworkMap { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.netMap +} + +func (nb *nodeBackend) netMapWithPeers() *netmap.NetworkMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return nil + } + nm := ptr.To(*nb.netMap) // shallow clone + nm.Peers = slicesx.MapValues(nb.peers) + slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), b.ID()) + }) + return nm +} + +func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { + nb.mu.Lock() + defer nb.mu.Unlock() + nb.netMap = nm + nb.updateNodeByAddrLocked() + nb.updatePeersLocked() + nv := magicsock.NodeViewsUpdate{} + if nm != nil { + nv.SelfNode = nm.SelfNode + nv.Peers = nm.Peers + nb.derpMapViewPub.Publish(nm.DERPMap.View()) + } else { + nb.derpMapViewPub.Publish(tailcfg.DERPMapView{}) + } + nb.nodeViewsPub.Publish(nv) +} + +func (nb *nodeBackend) updateNodeByAddrLocked() { + nm := nb.netMap + if nm == nil { + nb.nodeByAddr = nil + return + } + + // Update the nodeByAddr index. + if nb.nodeByAddr == nil { + nb.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} + } + // First pass, mark everything unwanted. + for k := range nb.nodeByAddr { + nb.nodeByAddr[k] = 0 + } + addNode := func(n tailcfg.NodeView) { + for _, ipp := range n.Addresses().All() { + if ipp.IsSingleIP() { + nb.nodeByAddr[ipp.Addr()] = n.ID() + } + } + } + if nm.SelfNode.Valid() { + addNode(nm.SelfNode) + } + for _, p := range nm.Peers { + addNode(p) + } + // Third pass, actually delete the unwanted items. + for k, v := range nb.nodeByAddr { + if v == 0 { + delete(nb.nodeByAddr, k) + } + } +} + +func (nb *nodeBackend) updatePeersLocked() { + nm := nb.netMap + if nm == nil { + nb.peers = nil + return + } + + // First pass, mark everything unwanted. + for k := range nb.peers { + nb.peers[k] = tailcfg.NodeView{} + } + + // Second pass, add everything wanted. + for _, p := range nm.Peers { + mak.Set(&nb.peers, p.ID(), p) + } + + // Third pass, remove deleted things. + for k, v := range nb.peers { + if !v.Valid() { + delete(nb.peers, k) + } + } +} + +func (nb *nodeBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil || len(nb.peers) == 0 { + return false + } + + // Locally cloned mutable nodes, to avoid calling AsStruct (clone) + // multiple times on a node if it's mutated multiple times in this + // call (e.g. its endpoints + online status both change) + var mutableNodes map[tailcfg.NodeID]*tailcfg.Node + + update := magicsock.NodeMutationsUpdate{ + Mutations: make([]netmap.NodeMutation, 0, len(muts)), + } + for _, m := range muts { + n, ok := mutableNodes[m.NodeIDBeingMutated()] + if !ok { + nv, ok := nb.peers[m.NodeIDBeingMutated()] + if !ok { + // TODO(bradfitz): unexpected metric? + return false + } + n = nv.AsStruct() + mak.Set(&mutableNodes, nv.ID(), n) + update.Mutations = append(update.Mutations, m) + } + m.Apply(n) + } + for nid, n := range mutableNodes { + nb.peers[nid] = n.View() + } + nb.nodeMutsPub.Publish(update) + return true +} + +// unlockedNodesPermitted reports whether any peer with theUnsignedPeerAPIOnly bool set true has any of its allowed IPs +// in the specified packet filter. +// +// TODO(nickkhyl): It is here temporarily until we can move the whole [LocalBackend.updateFilterLocked] here, +// but change it so it builds and returns a filter for the current netmap/prefs instead of re-configuring the engine filter. +// Something like (*nodeBackend).RebuildFilters() (filter, jailedFilter *filter.Filter, changed bool) perhaps? +func (nb *nodeBackend) unlockedNodesPermitted(packetFilter []filter.Match) bool { + nb.mu.Lock() + defer nb.mu.Unlock() + return packetFilterPermitsUnlockedNodes(nb.peers, packetFilter) +} + +func (nb *nodeBackend) filter() *filter.Filter { + return nb.filterAtomic.Load() +} + +func (nb *nodeBackend) setFilter(f *filter.Filter) { + nb.filterAtomic.Store(f) + nb.filterPub.Publish(magicsock.FilterUpdate{Filter: f}) +} + +func (nb *nodeBackend) dnsConfigForNetmap(prefs ipn.PrefsView, selfExpired bool, versionOS string) *dns.Config { + nb.mu.Lock() + defer nb.mu.Unlock() + return dnsConfigForNetmap(nb.netMap, nb.peers, prefs, selfExpired, nb.logf, versionOS) +} + +func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) { + if !buildfeatures.HasUseExitNode { + return "", false + } + nb.mu.Lock() + defer nb.mu.Unlock() + return exitNodeCanProxyDNS(nb.netMap, nb.peers, exitNodeID) +} + +// ready signals that [LocalBackend] has completed the switch to this [nodeBackend] +// and any pending calls to [nodeBackend.Wait] must be unblocked. +func (nb *nodeBackend) ready() { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.readyCh != nil { + close(nb.readyCh) + } +} + +// Wait blocks until [LocalBackend] completes the switch to this [nodeBackend] +// and calls [nodeBackend.ready]. It returns an error if the provided context +// is canceled or if the [nodeBackend] shuts down or is already shut down. +// +// It must not be called with the [LocalBackend]'s internal mutex held as [LocalBackend] +// may need to acquire it to complete the switch. +// +// TODO(nickkhyl): Relax this restriction once [LocalBackend]'s state machine +// runs in its own goroutine, or if we decide that waiting for the state machine +// restart to finish isn't necessary for [LocalBackend] to consider the switch complete. +// We mostly need this because of [LocalBackend.Start] acquiring b.mu and the fact that +// methods like [LocalBackend.SwitchProfile] must report any errors returned by it. +// Perhaps we could report those errors asynchronously as [health.Warnable]s? +func (nb *nodeBackend) Wait(ctx context.Context) error { + nb.mu.Lock() + readyCh := nb.readyCh + nb.mu.Unlock() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-nb.ctx.Done(): + return context.Cause(nb.ctx) + case <-readyCh: + return nil + } +} + +// shutdown shuts down the [nodeBackend] and cancels its context +// with the provided cause. +func (nb *nodeBackend) shutdown(cause error) { + nb.shutdownOnce.Do(func() { + nb.doShutdown(cause) + }) +} + +func (nb *nodeBackend) doShutdown(cause error) { + nb.mu.Lock() + defer nb.mu.Unlock() + nb.ctxCancel(cause) + nb.readyCh = nil + nb.eventClient.Close() +} + +// useWithExitNodeResolvers filters out resolvers so the ones that remain +// are all the ones marked for use with exit nodes. +func useWithExitNodeResolvers(resolvers []*dnstype.Resolver) []*dnstype.Resolver { + var filtered []*dnstype.Resolver + for _, res := range resolvers { + if res.UseWithExitNode { + filtered = append(filtered, res) + } + } + return filtered +} + +// useWithExitNodeRoutes filters out routes so the ones that remain +// are either zero-length resolver lists, or lists containing only +// resolvers marked for use with exit nodes. +func useWithExitNodeRoutes(routes map[string][]*dnstype.Resolver) map[string][]*dnstype.Resolver { + var filtered map[string][]*dnstype.Resolver + for suffix, resolvers := range routes { + // Suffixes with no resolvers represent a valid configuration, + // and should persist regardless of exit node considerations. + if len(resolvers) == 0 { + mak.Set(&filtered, suffix, make([]*dnstype.Resolver, 0)) + continue + } + + // In exit node contexts, we filter out resolvers not configured for use with + // exit nodes. If there are no such configured resolvers, there should not be an entry for that suffix. + filteredResolvers := useWithExitNodeResolvers(resolvers) + if len(filteredResolvers) > 0 { + mak.Set(&filtered, suffix, filteredResolvers) + } + } + + return filtered +} + +// dnsConfigForNetmap returns a *dns.Config for the given netmap, +// prefs, client OS version, and cloud hosting environment. +// +// The versionOS is a Tailscale-style version ("iOS", "macOS") and not +// a runtime.GOOS. +func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config { + if nm == nil { + return nil + } + if !buildfeatures.HasDNS { + return &dns.Config{} + } + + // If the current node's key is expired, then we don't program any DNS + // configuration into the operating system. This ensures that if the + // DNS configuration specifies a DNS server that is only reachable over + // Tailscale, we don't break connectivity for the user. + // + // TODO(andrew-d): this also stops returning anything from quad-100; we + // could do the same thing as having "CorpDNS: false" and keep that but + // not program the OS? + if selfExpired { + return &dns.Config{} + } + + dcfg := &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: map[dnsname.FQDN][]netip.Addr{}, + } + + // selfV6Only is whether we only have IPv6 addresses ourselves. + selfV6Only := nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs6) && + !nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs4) + dcfg.OnlyIPv6 = selfV6Only + + wantAAAA := nm.AllCaps.Contains(tailcfg.NodeAttrMagicDNSPeerAAAA) + + // Populate MagicDNS records. We do this unconditionally so that + // quad-100 can always respond to MagicDNS queries, even if the OS + // isn't configured to make MagicDNS resolution truly + // magic. Details in + // https://github.com/tailscale/tailscale/issues/1886. + set := func(name string, addrs views.Slice[netip.Prefix]) { + if addrs.Len() == 0 || name == "" { + return + } + fqdn, err := dnsname.ToFQDN(name) + if err != nil { + return // TODO: propagate error? + } + var have4 bool + for _, addr := range addrs.All() { + if addr.Addr().Is4() { + have4 = true + break + } + } + var ips []netip.Addr + for _, addr := range addrs.All() { + if selfV6Only { + if addr.Addr().Is6() { + ips = append(ips, addr.Addr()) + } + continue + } + // If this node has an IPv4 address, then + // remove peers' IPv6 addresses for now, as we + // don't guarantee that the peer node actually + // can speak IPv6 correctly. + // + // https://github.com/tailscale/tailscale/issues/1152 + // tracks adding the right capability reporting to + // enable AAAA in MagicDNS. + if addr.Addr().Is6() && have4 && !wantAAAA { + continue + } + ips = append(ips, addr.Addr()) + } + dcfg.Hosts[fqdn] = ips + } + set(nm.SelfName(), nm.GetAddresses()) + for _, peer := range peers { + set(peer.Name(), peer.Addresses()) + } + for _, rec := range nm.DNS.ExtraRecords { + switch rec.Type { + case "", "A", "AAAA": + // Treat these all the same for now: infer from the value + default: + // TODO: more + continue + } + ip, err := netip.ParseAddr(rec.Value) + if err != nil { + // Ignore. + continue + } + fqdn, err := dnsname.ToFQDN(rec.Name) + if err != nil { + continue + } + dcfg.Hosts[fqdn] = append(dcfg.Hosts[fqdn], ip) + } + + if !prefs.CorpDNS() { + return dcfg + } + + for _, dom := range nm.DNS.Domains { + fqdn, err := dnsname.ToFQDN(dom) + if err != nil { + logf("[unexpected] non-FQDN search domain %q", dom) + } + dcfg.SearchDomains = append(dcfg.SearchDomains, fqdn) + } + if nm.DNS.Proxied { // actually means "enable MagicDNS" + for _, dom := range magicDNSRootDomains(nm) { + dcfg.Routes[dom] = nil // resolve internally with dcfg.Hosts + } + } + + addDefault := func(resolvers []*dnstype.Resolver) { + dcfg.DefaultResolvers = append(dcfg.DefaultResolvers, resolvers...) + } + + addSplitDNSRoutes := func(routes map[string][]*dnstype.Resolver) { + for suffix, resolvers := range routes { + fqdn, err := dnsname.ToFQDN(suffix) + if err != nil { + logf("[unexpected] non-FQDN route suffix %q", suffix) + } + + // Create map entry even if len(resolvers) == 0; Issue 2706. + // This lets the control plane send ExtraRecords for which we + // can authoritatively answer "name not exists" for when the + // control plane also sends this explicit but empty route + // making it as something we handle. + dcfg.Routes[fqdn] = slices.Clone(resolvers) + } + } + + // If we're using an exit node and that exit node is new enough (1.19.x+) + // to run a DoH DNS proxy, then send all our DNS traffic through it, + // unless we find resolvers with UseWithExitNode set, in which case we use that. + if buildfeatures.HasUseExitNode { + if dohURL, ok := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID()); ok { + filtered := useWithExitNodeResolvers(nm.DNS.Resolvers) + if len(filtered) > 0 { + addDefault(filtered) + } else { + // If no default global resolvers with the override + // are configured, configure the exit node's resolver. + addDefault([]*dnstype.Resolver{{Addr: dohURL}}) + } + + addSplitDNSRoutes(useWithExitNodeRoutes(nm.DNS.Routes)) + return dcfg + } + } + + // If the user has set default resolvers ("override local DNS"), prefer to + // use those resolvers as the default, otherwise if there are WireGuard exit + // node resolvers, use those as the default. + if len(nm.DNS.Resolvers) > 0 { + addDefault(nm.DNS.Resolvers) + } else if buildfeatures.HasUseExitNode { + if resolvers, ok := wireguardExitNodeDNSResolvers(nm, peers, prefs.ExitNodeID()); ok { + addDefault(resolvers) + } + } + + // Add split DNS routes, with no regard to exit node configuration. + addSplitDNSRoutes(nm.DNS.Routes) + + // Set FallbackResolvers as the default resolvers in the + // scenarios that can't handle a purely split-DNS config. See + // https://github.com/tailscale/tailscale/issues/1743 for + // details. + switch { + case len(dcfg.DefaultResolvers) != 0: + // Default resolvers already set. + case !prefs.ExitNodeID().IsZero(): + // When using an exit node, we send all DNS traffic to the exit node, so + // we don't need a fallback resolver. + // + // However, if the exit node is too old to run a DoH DNS proxy, then we + // need to use a fallback resolver as it's very likely the LAN resolvers + // will become unreachable. + // + // This is especially important on Apple OSes, where + // adding the default route to the tunnel interface makes + // it "primary", and we MUST provide VPN-sourced DNS + // settings or we break all DNS resolution. + // + // https://github.com/tailscale/tailscale/issues/1713 + addDefault(nm.DNS.FallbackResolvers) + case len(dcfg.Routes) == 0: + // No settings requiring split DNS, no problem. + } + + return dcfg +} diff --git a/ipn/ipnlocal/node_backend_test.go b/ipn/ipnlocal/node_backend_test.go new file mode 100644 index 000000000..f6698bd4b --- /dev/null +++ b/ipn/ipnlocal/node_backend_test.go @@ -0,0 +1,192 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "errors" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/netmap" + "tailscale.com/types/ptr" + "tailscale.com/util/eventbus" +) + +func TestNodeBackendReadiness(t *testing.T) { + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + + // The node backend is not ready until [nodeBackend.ready] is called, + // and [nodeBackend.Wait] should fail with [context.DeadlineExceeded]. + ctx, cancelCtx := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancelCtx() + if err := nb.Wait(ctx); err != ctx.Err() { + t.Fatalf("Wait: got %v; want %v", err, ctx.Err()) + } + + // Start a goroutine to wait for the node backend to become ready. + waitDone := make(chan struct{}) + go func() { + if err := nb.Wait(context.Background()); err != nil { + t.Errorf("Wait: got %v; want nil", err) + } + close(waitDone) + }() + + // Call [nodeBackend.ready] to indicate that the node backend is now ready. + go nb.ready() + + // Once the backend is called, [nodeBackend.Wait] should return immediately without error. + if err := nb.Wait(context.Background()); err != nil { + t.Fatalf("Wait: got %v; want nil", err) + } + // And any pending waiters should also be unblocked. + <-waitDone +} + +func TestNodeBackendShutdown(t *testing.T) { + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + + shutdownCause := errors.New("test shutdown") + + // Start a goroutine to wait for the node backend to become ready. + // This test expects it to block until the node backend shuts down + // and then return the specified shutdown cause. + waitDone := make(chan struct{}) + go func() { + if err := nb.Wait(context.Background()); err != shutdownCause { + t.Errorf("Wait: got %v; want %v", err, shutdownCause) + } + close(waitDone) + }() + + // Call [nodeBackend.shutdown] to indicate that the node backend is shutting down. + nb.shutdown(shutdownCause) + + // Calling it again is fine, but should not change the shutdown cause. + nb.shutdown(errors.New("test shutdown again")) + + // After shutdown, [nodeBackend.Wait] should return with the specified shutdown cause. + if err := nb.Wait(context.Background()); err != shutdownCause { + t.Fatalf("Wait: got %v; want %v", err, shutdownCause) + } + // The context associated with the node backend should also be cancelled + // and its cancellation cause should match the shutdown cause. + if err := nb.Context().Err(); !errors.Is(err, context.Canceled) { + t.Fatalf("Context.Err: got %v; want %v", err, context.Canceled) + } + if cause := context.Cause(nb.Context()); cause != shutdownCause { + t.Fatalf("Cause: got %v; want %v", cause, shutdownCause) + } + // And any pending waiters should also be unblocked. + <-waitDone +} + +func TestNodeBackendReadyAfterShutdown(t *testing.T) { + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + + shutdownCause := errors.New("test shutdown") + nb.shutdown(shutdownCause) + nb.ready() // Calling ready after shutdown is a no-op, but should not panic, etc. + if err := nb.Wait(context.Background()); err != shutdownCause { + t.Fatalf("Wait: got %v; want %v", err, shutdownCause) + } +} + +func TestNodeBackendParentContextCancellation(t *testing.T) { + ctx, cancelCtx := context.WithCancel(context.Background()) + nb := newNodeBackend(ctx, tstest.WhileTestRunningLogger(t), eventbus.New()) + + cancelCtx() + + // Cancelling the parent context should cause [nodeBackend.Wait] + // to return with [context.Canceled]. + if err := nb.Wait(context.Background()); !errors.Is(err, context.Canceled) { + t.Fatalf("Wait: got %v; want %v", err, context.Canceled) + } + + // And the node backend's context should also be cancelled. + if err := nb.Context().Err(); !errors.Is(err, context.Canceled) { + t.Fatalf("Context.Err: got %v; want %v", err, context.Canceled) + } +} + +func TestNodeBackendConcurrentReadyAndShutdown(t *testing.T) { + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + + // Calling [nodeBackend.ready] and [nodeBackend.shutdown] concurrently + // should not cause issues, and [nodeBackend.Wait] should unblock, + // but the result of [nodeBackend.Wait] is intentionally undefined. + go nb.ready() + go nb.shutdown(errors.New("test shutdown")) + + nb.Wait(context.Background()) +} + +func TestNodeBackendReachability(t *testing.T) { + for _, tc := range []struct { + name string + + // Cap sets [tailcfg.NodeAttrClientSideReachability] on the self + // node. + // + // When disabled, the client relies on the control plane sending + // an accurate peer.Online flag. When enabled, the client + // ignores peer.Online and determines whether it can reach the + // peer node. + cap bool + + peer tailcfg.Node + want bool + }{ + { + name: "disabled/offline", + cap: false, + peer: tailcfg.Node{ + Online: ptr.To(false), + }, + want: false, + }, + { + name: "disabled/online", + cap: false, + peer: tailcfg.Node{ + Online: ptr.To(true), + }, + want: true, + }, + { + name: "enabled/offline", + cap: true, + peer: tailcfg.Node{ + Online: ptr.To(false), + }, + want: true, + }, + { + name: "enabled/online", + cap: true, + peer: tailcfg.Node{ + Online: ptr.To(true), + }, + want: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + nb := newNodeBackend(t.Context(), tstest.WhileTestRunningLogger(t), eventbus.New()) + nb.netMap = &netmap.NetworkMap{} + if tc.cap { + nb.netMap.AllCaps.Make() + nb.netMap.AllCaps.Add(tailcfg.NodeAttrClientSideReachability) + } + + got := nb.PeerIsReachable(t.Context(), tc.peer.View()) + if got != tc.want { + t.Errorf("got %v, want %v", got, tc.want) + } + }) + } +} diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index aa18c3588..a045086d4 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -15,48 +15,34 @@ import ( "net" "net/http" "net/netip" - "net/url" "os" - "path/filepath" "runtime" "slices" - "sort" "strconv" "strings" "sync" "time" - "github.com/kortschak/wol" "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/http/httpguts" - "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/hostinfo" - "tailscale.com/ipn" "tailscale.com/net/netaddr" "tailscale.com/net/netmon" "tailscale.com/net/netutil" "tailscale.com/net/sockstats" "tailscale.com/tailcfg" - "tailscale.com/taildrop" + "tailscale.com/types/netmap" "tailscale.com/types/views" "tailscale.com/util/clientmetric" - "tailscale.com/util/httphdr" - "tailscale.com/util/httpm" "tailscale.com/wgengine/filter" ) -const ( - taildrivePrefix = "/v0/drive" -) - var initListenConfig func(*net.ListenConfig, netip.Addr, *netmon.State, string) error -// addH2C is non-nil on platforms where we want to add H2C -// ("cleartext" HTTP/2) support to the peerAPI. -var addH2C func(*http.Server) - // peerDNSQueryHandler is implemented by tsdns.Resolver. type peerDNSQueryHandler interface { HandlePeerDNSQuery(context.Context, []byte, netip.AddrPort, func(name string) bool) (res []byte, err error) @@ -65,8 +51,6 @@ type peerDNSQueryHandler interface { type peerAPIServer struct { b *LocalBackend resolver peerDNSQueryHandler - - taildrop *taildrop.Manager } func (s *peerAPIServer) listen(ip netip.Addr, ifState *netmon.State) (ln net.Listener, err error) { @@ -148,6 +132,9 @@ type peerAPIListener struct { } func (pln *peerAPIListener) Close() error { + if !buildfeatures.HasPeerAPIServer { + return nil + } if pln.ln != nil { return pln.ln.Close() } @@ -155,6 +142,9 @@ func (pln *peerAPIListener) Close() error { } func (pln *peerAPIListener) serve() { + if !buildfeatures.HasPeerAPIServer { + return + } if pln.ln == nil { return } @@ -208,11 +198,11 @@ func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) { peerUser: peerUser, } httpServer := &http.Server{ - Handler: h, - } - if addH2C != nil { - addH2C(httpServer) + Handler: h, + Protocols: new(http.Protocols), } + httpServer.Protocols.SetHTTP1(true) + httpServer.Protocols.SetUnencryptedHTTP2(true) // over WireGuard; "unencrypted" means no TLS go httpServer.Serve(netutil.NewOneConnListener(c, nil)) } @@ -226,18 +216,48 @@ type peerAPIHandler struct { peerUser tailcfg.UserProfile // profile of peerNode } +// PeerAPIHandler is the interface implemented by [peerAPIHandler] and needed by +// module features registered via tailscale.com/feature/*. +type PeerAPIHandler interface { + Peer() tailcfg.NodeView + PeerCaps() tailcfg.PeerCapMap + CanDebug() bool // can remote node can debug this node (internal state, etc) + Self() tailcfg.NodeView + LocalBackend() *LocalBackend + IsSelfUntagged() bool // whether the peer is untagged and the same as this user + RemoteAddr() netip.AddrPort + Logf(format string, a ...any) +} + +func (h *peerAPIHandler) IsSelfUntagged() bool { + return !h.selfNode.IsTagged() && !h.peerNode.IsTagged() && h.isSelf +} +func (h *peerAPIHandler) Peer() tailcfg.NodeView { return h.peerNode } +func (h *peerAPIHandler) Self() tailcfg.NodeView { return h.selfNode } +func (h *peerAPIHandler) RemoteAddr() netip.AddrPort { return h.remoteAddr } +func (h *peerAPIHandler) LocalBackend() *LocalBackend { return h.ps.b } +func (h *peerAPIHandler) Logf(format string, a ...any) { + h.logf(format, a...) +} + func (h *peerAPIHandler) logf(format string, a ...any) { h.ps.b.logf("peerapi: "+format, a...) } +func (h *peerAPIHandler) logfv1(format string, a ...any) { + h.ps.b.logf("[v1] peerapi: "+format, a...) +} + // isAddressValid reports whether addr is a valid destination address for this // node originating from the peer. func (h *peerAPIHandler) isAddressValid(addr netip.Addr) bool { - if v := h.peerNode.SelfNodeV4MasqAddrForThisPeer(); v != nil { - return *v == addr + if !addr.IsValid() { + return false } - if v := h.peerNode.SelfNodeV6MasqAddrForThisPeer(); v != nil { - return *v == addr + v4MasqAddr, hasMasqV4 := h.peerNode.SelfNodeV4MasqAddrForThisPeer().GetOk() + v6MasqAddr, hasMasqV6 := h.peerNode.SelfNodeV6MasqAddrForThisPeer().GetOk() + if hasMasqV4 || hasMasqV6 { + return addr == v4MasqAddr || addr == v6MasqAddr } pfx := netip.PrefixFrom(addr, addr.BitLen()) return views.SliceContains(h.selfNode.Addresses(), pfx) @@ -300,7 +320,37 @@ func peerAPIRequestShouldGetSecurityHeaders(r *http.Request) bool { return false } +// RegisterPeerAPIHandler registers a PeerAPI handler. +// +// The path should be of the form "/v0/foo". +// +// It panics if the path is already registered. +func RegisterPeerAPIHandler(path string, f func(PeerAPIHandler, http.ResponseWriter, *http.Request)) { + if !buildfeatures.HasPeerAPIServer { + return + } + if _, ok := peerAPIHandlers[path]; ok { + panic(fmt.Sprintf("duplicate PeerAPI handler %q", path)) + } + peerAPIHandlers[path] = f + if strings.HasSuffix(path, "/") { + peerAPIHandlerPrefixes[path] = f + } +} + +var ( + peerAPIHandlers = map[string]func(PeerAPIHandler, http.ResponseWriter, *http.Request){} // by URL.Path + + // peerAPIHandlerPrefixes are the subset of peerAPIHandlers where + // the map key ends with a slash, indicating a prefix match. + peerAPIHandlerPrefixes = map[string]func(PeerAPIHandler, http.ResponseWriter, *http.Request){} +) + func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasPeerAPIServer { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } if err := h.validatePeerAPIRequest(r); err != nil { metricInvalidRequests.Add(1) h.logf("invalid request from %v: %v", h.remoteAddr, err) @@ -312,54 +362,48 @@ func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Content-Type-Options", "nosniff") } - if strings.HasPrefix(r.URL.Path, "/v0/put/") { - if r.Method == "PUT" { - metricPutCalls.Add(1) + for pfx, ph := range peerAPIHandlerPrefixes { + if strings.HasPrefix(r.URL.Path, pfx) { + ph(h, w, r) + return } - h.handlePeerPut(w, r) - return } - if strings.HasPrefix(r.URL.Path, "/dns-query") { + if buildfeatures.HasDNS && strings.HasPrefix(r.URL.Path, "/dns-query") { metricDNSCalls.Add(1) h.handleDNSQuery(w, r) return } - if strings.HasPrefix(r.URL.Path, taildrivePrefix) { - h.handleServeDrive(w, r) - return + if buildfeatures.HasDebug { + switch r.URL.Path { + case "/v0/goroutines": + h.handleServeGoroutines(w, r) + return + case "/v0/env": + h.handleServeEnv(w, r) + return + case "/v0/metrics": + h.handleServeMetrics(w, r) + return + case "/v0/magicsock": + h.handleServeMagicsock(w, r) + return + case "/v0/dnsfwd": + h.handleServeDNSFwd(w, r) + return + case "/v0/interfaces": + h.handleServeInterfaces(w, r) + return + case "/v0/sockstats": + h.handleServeSockStats(w, r) + return + } } - switch r.URL.Path { - case "/v0/goroutines": - h.handleServeGoroutines(w, r) - return - case "/v0/env": - h.handleServeEnv(w, r) - return - case "/v0/metrics": - h.handleServeMetrics(w, r) - return - case "/v0/magicsock": - h.handleServeMagicsock(w, r) - return - case "/v0/dnsfwd": - h.handleServeDNSFwd(w, r) + if ph, ok := peerAPIHandlers[r.URL.Path]; ok { + ph(h, w, r) return - case "/v0/wol": - metricWakeOnLANCalls.Add(1) - h.handleWakeOnLAN(w, r) - return - case "/v0/interfaces": - h.handleServeInterfaces(w, r) - return - case "/v0/doctor": - h.handleServeDoctor(w, r) - return - case "/v0/sockstats": - h.handleServeSockStats(w, r) - return - case "/v0/ingress": - metricIngressCalls.Add(1) - h.handleServeIngress(w, r) + } + if r.URL.Path != "/" { + http.Error(w, "unsupported peerapi path", http.StatusNotFound) return } who := h.peerUser.DisplayName @@ -375,67 +419,6 @@ This is my Tailscale device. Your device is %v. } } -func (h *peerAPIHandler) handleServeIngress(w http.ResponseWriter, r *http.Request) { - // http.Errors only useful if hitting endpoint manually - // otherwise rely on log lines when debugging ingress connections - // as connection is hijacked for bidi and is encrypted tls - if !h.canIngress() { - h.logf("ingress: denied; no ingress cap from %v", h.remoteAddr) - http.Error(w, "denied; no ingress cap", http.StatusForbidden) - return - } - logAndError := func(code int, publicMsg string) { - h.logf("ingress: bad request from %v: %s", h.remoteAddr, publicMsg) - http.Error(w, publicMsg, http.StatusMethodNotAllowed) - } - bad := func(publicMsg string) { - logAndError(http.StatusBadRequest, publicMsg) - } - if r.Method != "POST" { - logAndError(http.StatusMethodNotAllowed, "only POST allowed") - return - } - srcAddrStr := r.Header.Get("Tailscale-Ingress-Src") - if srcAddrStr == "" { - bad("Tailscale-Ingress-Src header not set") - return - } - srcAddr, err := netip.ParseAddrPort(srcAddrStr) - if err != nil { - bad("Tailscale-Ingress-Src header invalid; want ip:port") - return - } - target := ipn.HostPort(r.Header.Get("Tailscale-Ingress-Target")) - if target == "" { - bad("Tailscale-Ingress-Target header not set") - return - } - if _, _, err := net.SplitHostPort(string(target)); err != nil { - bad("Tailscale-Ingress-Target header invalid; want host:port") - return - } - - getConnOrReset := func() (net.Conn, bool) { - conn, _, err := w.(http.Hijacker).Hijack() - if err != nil { - h.logf("ingress: failed hijacking conn") - http.Error(w, "failed hijacking conn", http.StatusInternalServerError) - return nil, false - } - io.WriteString(conn, "HTTP/1.1 101 Switching Protocols\r\n\r\n") - return &ipn.FunnelConn{ - Conn: conn, - Src: srcAddr, - Target: target, - }, true - } - sendRST := func() { - http.Error(w, "denied", http.StatusForbidden) - } - - h.ps.b.HandleIngressTCPConn(h.peerNode, target, srcAddr, getConnOrReset, sendRST) -} - func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Request) { if !h.canDebug() { http.Error(w, "denied; no debug access", http.StatusForbidden) @@ -450,7 +433,7 @@ func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Re fmt.Fprintf(w, "

Could not get the default route: %s

\n", html.EscapeString(err.Error())) } - if hasCGNATInterface, err := netmon.HasCGNATInterface(); hasCGNATInterface { + if hasCGNATInterface, err := h.ps.b.sys.NetMon.Get().HasCGNATInterface(); hasCGNATInterface { fmt.Fprintln(w, "

There is another interface using the CGNAT range.

") } else if err != nil { fmt.Fprintf(w, "

Could not check for CGNAT interfaces: %s

\n", html.EscapeString(err.Error())) @@ -483,24 +466,6 @@ func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Re fmt.Fprintln(w, "") } -func (h *peerAPIHandler) handleServeDoctor(w http.ResponseWriter, r *http.Request) { - if !h.canDebug() { - http.Error(w, "denied; no debug access", http.StatusForbidden) - return - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - fmt.Fprintln(w, "

Doctor Output

") - - fmt.Fprintln(w, "
")
-
-	h.ps.b.Doctor(r.Context(), func(format string, args ...any) {
-		line := fmt.Sprintf(format, args...)
-		fmt.Fprintln(w, html.EscapeString(line))
-	})
-
-	fmt.Fprintln(w, "
") -} - func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Request) { if !h.canDebug() { http.Error(w, "denied; no debug access", http.StatusForbidden) @@ -599,14 +564,7 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req fmt.Fprintln(w, "") } -// canPutFile reports whether h can put a file ("Taildrop") to this node. -func (h *peerAPIHandler) canPutFile() bool { - if h.peerNode.UnsignedPeerAPIOnly() { - // Unsigned peers can't send files. - return false - } - return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityFileSharingSend) -} +func (h *peerAPIHandler) CanDebug() bool { return h.canDebug() } // canDebug reports whether h can debug this node (goroutines, metrics, // magicsock internal state, etc). @@ -622,14 +580,6 @@ func (h *peerAPIHandler) canDebug() bool { return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityDebugPeer) } -// canWakeOnLAN reports whether h can send a Wake-on-LAN packet from this node. -func (h *peerAPIHandler) canWakeOnLAN() bool { - if h.peerNode.UnsignedPeerAPIOnly() { - return false - } - return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityWakeOnLAN) -} - var allowSelfIngress = envknob.RegisterBool("TS_ALLOW_SELF_INGRESS") // canIngress reports whether h can send ingress requests to this node. @@ -638,117 +588,13 @@ func (h *peerAPIHandler) canIngress() bool { } func (h *peerAPIHandler) peerHasCap(wantCap tailcfg.PeerCapability) bool { - return h.peerCaps().HasCapability(wantCap) + return h.PeerCaps().HasCapability(wantCap) } -func (h *peerAPIHandler) peerCaps() tailcfg.PeerCapMap { +func (h *peerAPIHandler) PeerCaps() tailcfg.PeerCapMap { return h.ps.b.PeerCaps(h.remoteAddr.Addr()) } -func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { - if !h.canPutFile() { - http.Error(w, taildrop.ErrNoTaildrop.Error(), http.StatusForbidden) - return - } - if !h.ps.b.hasCapFileSharing() { - http.Error(w, taildrop.ErrNoTaildrop.Error(), http.StatusForbidden) - return - } - rawPath := r.URL.EscapedPath() - prefix, ok := strings.CutPrefix(rawPath, "/v0/put/") - if !ok { - http.Error(w, "misconfigured internals", http.StatusForbidden) - return - } - baseName, err := url.PathUnescape(prefix) - if err != nil { - http.Error(w, taildrop.ErrInvalidFileName.Error(), http.StatusBadRequest) - return - } - enc := json.NewEncoder(w) - switch r.Method { - case "GET": - id := taildrop.ClientID(h.peerNode.StableID()) - if prefix == "" { - // List all the partial files. - files, err := h.ps.taildrop.PartialFiles(id) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := enc.Encode(files); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - h.logf("json.Encoder.Encode error: %v", err) - return - } - } else { - // Stream all the block hashes for the specified file. - next, close, err := h.ps.taildrop.HashPartialFile(id, baseName) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer close() - for { - switch cs, err := next(); { - case err == io.EOF: - return - case err != nil: - http.Error(w, err.Error(), http.StatusInternalServerError) - h.logf("HashPartialFile.next error: %v", err) - return - default: - if err := enc.Encode(cs); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - h.logf("json.Encoder.Encode error: %v", err) - return - } - } - } - } - case "PUT": - t0 := h.ps.b.clock.Now() - id := taildrop.ClientID(h.peerNode.StableID()) - - var offset int64 - if rangeHdr := r.Header.Get("Range"); rangeHdr != "" { - ranges, ok := httphdr.ParseRange(rangeHdr) - if !ok || len(ranges) != 1 || ranges[0].Length != 0 { - http.Error(w, "invalid Range header", http.StatusBadRequest) - return - } - offset = ranges[0].Start - } - n, err := h.ps.taildrop.PutFile(taildrop.ClientID(fmt.Sprint(id)), baseName, r.Body, offset, r.ContentLength) - switch err { - case nil: - d := h.ps.b.clock.Since(t0).Round(time.Second / 10) - h.logf("got put of %s in %v from %v/%v", approxSize(n), d, h.remoteAddr.Addr(), h.peerNode.ComputedName) - io.WriteString(w, "{}\n") - case taildrop.ErrNoTaildrop: - http.Error(w, err.Error(), http.StatusForbidden) - case taildrop.ErrInvalidFileName: - http.Error(w, err.Error(), http.StatusBadRequest) - case taildrop.ErrFileExists: - http.Error(w, err.Error(), http.StatusConflict) - default: - http.Error(w, err.Error(), http.StatusInternalServerError) - } - default: - http.Error(w, "expected method GET or PUT", http.StatusMethodNotAllowed) - } -} - -func approxSize(n int64) string { - if n <= 1<<10 { - return "<=1KB" - } - if n <= 1<<20 { - return "<=1MB" - } - return fmt.Sprintf("~%dMB", n>>20) -} - func (h *peerAPIHandler) handleServeGoroutines(w http.ResponseWriter, r *http.Request) { if !h.canDebug() { http.Error(w, "denied; no debug access", http.StatusForbidden) @@ -803,6 +649,10 @@ func (h *peerAPIHandler) handleServeMetrics(w http.ResponseWriter, r *http.Reque } func (h *peerAPIHandler) handleServeDNSFwd(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDNS { + http.NotFound(w, r) + return + } if !h.canDebug() { http.Error(w, "denied; no debug access", http.StatusForbidden) return @@ -815,62 +665,10 @@ func (h *peerAPIHandler) handleServeDNSFwd(w http.ResponseWriter, r *http.Reques dh.ServeHTTP(w, r) } -func (h *peerAPIHandler) handleWakeOnLAN(w http.ResponseWriter, r *http.Request) { - if !h.canWakeOnLAN() { - http.Error(w, "no WoL access", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "bad method", http.StatusMethodNotAllowed) - return - } - macStr := r.FormValue("mac") - if macStr == "" { - http.Error(w, "missing 'mac' param", http.StatusBadRequest) - return - } - mac, err := net.ParseMAC(macStr) - if err != nil { - http.Error(w, "bad 'mac' param", http.StatusBadRequest) - return - } - var password []byte // TODO(bradfitz): support? does anything use WoL passwords? - st := h.ps.b.sys.NetMon.Get().InterfaceState() - if st == nil { - http.Error(w, "failed to get interfaces state", http.StatusInternalServerError) - return - } - var res struct { - SentTo []string - Errors []string - } - for ifName, ips := range st.InterfaceIPs { - for _, ip := range ips { - if ip.Addr().IsLoopback() || ip.Addr().Is6() { - continue - } - local := &net.UDPAddr{ - IP: ip.Addr().AsSlice(), - Port: 0, - } - remote := &net.UDPAddr{ - IP: net.IPv4bcast, - Port: 0, - } - if err := wol.Wake(mac, password, local, remote); err != nil { - res.Errors = append(res.Errors, err.Error()) - } else { - res.SentTo = append(res.SentTo, ifName) - } - break // one per interface is enough - } - } - sort.Strings(res.SentTo) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} - func (h *peerAPIHandler) replyToDNSQueries() bool { + if !buildfeatures.HasDNS { + return false + } if h.isSelf { // If the peer is owned by the same user, just allow it // without further checks. @@ -900,7 +698,7 @@ func (h *peerAPIHandler) replyToDNSQueries() bool { // but an app connector explicitly adds 0.0.0.0/32 (and the // IPv6 equivalent) to make this work (see updateFilterLocked // in LocalBackend). - f := b.filterAtomic.Load() + f := b.currentNode().filter() if f == nil { return false } @@ -922,7 +720,7 @@ func (h *peerAPIHandler) replyToDNSQueries() bool { // handleDNSQuery implements a DoH server (RFC 8484) over the peerapi. // It's not over HTTPS as the spec dictates, but rather HTTP-over-WireGuard. func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request) { - if h.ps.resolver == nil { + if !buildfeatures.HasDNS || h.ps.resolver == nil { http.Error(w, "DNS not wired up", http.StatusNotImplemented) return } @@ -963,8 +761,12 @@ func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request) // TODO(raggi): consider pushing the integration down into the resolver // instead to avoid re-parsing the DNS response for improved performance in // the future. - if h.ps.b.OfferingAppConnector() { - h.ps.b.ObserveDNSResponse(res) + if buildfeatures.HasAppConnectors && h.ps.b.OfferingAppConnector() { + if err := h.ps.b.ObserveDNSResponse(res); err != nil { + h.logf("ObserveDNSResponse error: %v", err) + // This is not fatal, we probably just failed to parse the upstream + // response. Return it to the caller anyway. + } } if pretty { @@ -986,7 +788,7 @@ func dohQuery(r *http.Request) (dnsQuery []byte, publicErr string) { case "GET": q64 := r.FormValue("dns") if q64 == "" { - return nil, "missing 'dns' parameter" + return nil, "missing ‘dns’ parameter; try '?dns=' (DoH standard) or use '?q=' for JSON debug mode" } if base64.RawURLEncoding.DecodedLen(len(q64)) > maxQueryLen { return nil, "query too large" @@ -1141,85 +943,46 @@ func (rbw *requestBodyWrapper) Read(b []byte) (int, error) { return n, err } -func (h *peerAPIHandler) handleServeDrive(w http.ResponseWriter, r *http.Request) { - if !h.ps.b.DriveSharingEnabled() { - h.logf("taildrive: not enabled") - http.Error(w, "taildrive not enabled", http.StatusNotFound) - return - } - - capsMap := h.peerCaps() - driveCaps, ok := capsMap[tailcfg.PeerCapabilityTaildrive] - if !ok { - h.logf("taildrive: not permitted") - http.Error(w, "taildrive not permitted", http.StatusForbidden) - return - } - - rawPerms := make([][]byte, 0, len(driveCaps)) - for _, cap := range driveCaps { - rawPerms = append(rawPerms, []byte(cap)) - } - - p, err := drive.ParsePermissions(rawPerms) - if err != nil { - h.logf("taildrive: error parsing permissions: %w", err.Error()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - fs, ok := h.ps.b.sys.DriveForRemote.GetOK() - if !ok { - h.logf("taildrive: not supported on platform") - http.Error(w, "taildrive not supported on platform", http.StatusNotFound) - return - } - wr := &httpResponseWrapper{ - ResponseWriter: w, - } - bw := &requestBodyWrapper{ - ReadCloser: r.Body, +// peerAPIURL returns an HTTP URL for the peer's peerapi service, +// without a trailing slash. +// +// If ip or port is the zero value then it returns the empty string. +func peerAPIURL(ip netip.Addr, port uint16) string { + if port == 0 || !ip.IsValid() { + return "" } - r.Body = bw - - if r.Method == httpm.PUT || r.Method == httpm.GET { - defer func() { - switch wr.statusCode { - case 304: - // 304s are particularly chatty so skip logging. - default: - contentType := "unknown" - if ct := wr.Header().Get("Content-Type"); ct != "" { - contentType = ct - } + return fmt.Sprintf("http://%v", netip.AddrPortFrom(ip, port)) +} - h.logf("taildrive: share: %s from %s to %s: status-code=%d ext=%q content-type=%q tx=%.f rx=%.f", r.Method, h.peerNode.Key().ShortString(), h.selfNode.Key().ShortString(), wr.statusCode, parseDriveFileExtensionForLog(r.URL.Path), contentType, roundTraffic(wr.contentLength), roundTraffic(bw.bytesRead)) - } - }() +// peerAPIBase returns the "http://ip:port" URL base to reach peer's peerAPI. +// It returns the empty string if the peer doesn't support the peerapi +// or there's no matching address family based on the netmap's own addresses. +func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { + if nm == nil || !peer.Valid() || !peer.Hostinfo().Valid() { + return "" } - r.URL.Path = strings.TrimPrefix(r.URL.Path, taildrivePrefix) - fs.ServeHTTPWithPerms(p, wr, r) -} - -// parseDriveFileExtensionForLog parses the file extension, if available. -// If a file extension is not present or parsable, the file extension is -// set to "unknown". If the file extension contains a double quote, it is -// replaced with "removed". -// All whitespace is removed from a parsed file extension. -// File extensions including the leading ., e.g. ".gif". -func parseDriveFileExtensionForLog(path string) string { - fileExt := "unknown" - if fe := filepath.Ext(path); fe != "" { - if strings.Contains(fe, "\"") { - // Do not log include file extensions with quotes within them. - return "removed" + var have4, have6 bool + addrs := nm.GetAddresses() + for _, a := range addrs.All() { + if !a.IsSingleIP() { + continue + } + switch { + case a.Addr().Is4(): + have4 = true + case a.Addr().Is6(): + have6 = true } - // Remove white space from user defined inputs. - fileExt = strings.ReplaceAll(fe, " ", "") } - - return fileExt + p4, p6 := peerAPIPorts(peer) + switch { + case have4 && p4 != 0: + return peerAPIURL(nodeIP(peer, netip.Addr.Is4), p4) + case have6 && p6 != 0: + return peerAPIURL(nodeIP(peer, netip.Addr.Is6), p6) + } + return "" } // newFakePeerAPIListener creates a new net.Listener that acts like @@ -1272,8 +1035,5 @@ var ( metricInvalidRequests = clientmetric.NewCounter("peerapi_invalid_requests") // Non-debug PeerAPI endpoints. - metricPutCalls = clientmetric.NewCounter("peerapi_put") - metricDNSCalls = clientmetric.NewCounter("peerapi_dns") - metricWakeOnLANCalls = clientmetric.NewCounter("peerapi_wol") - metricIngressCalls = clientmetric.NewCounter("peerapi_ingress") + metricDNSCalls = clientmetric.NewCounter("peerapi_dns") ) diff --git a/ipn/ipnlocal/peerapi_drive.go b/ipn/ipnlocal/peerapi_drive.go new file mode 100644 index 000000000..8dffacd9a --- /dev/null +++ b/ipn/ipnlocal/peerapi_drive.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive + +package ipnlocal + +import ( + "net/http" + "path/filepath" + "strings" + + "tailscale.com/drive" + "tailscale.com/tailcfg" + "tailscale.com/util/httpm" +) + +const ( + taildrivePrefix = "/v0/drive" +) + +func init() { + peerAPIHandlerPrefixes[taildrivePrefix] = handleServeDrive +} + +func handleServeDrive(hi PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + h := hi.(*peerAPIHandler) + + h.logfv1("taildrive: got %s request from %s", r.Method, h.peerNode.Key().ShortString()) + if !h.ps.b.DriveSharingEnabled() { + h.logf("taildrive: not enabled") + http.Error(w, "taildrive not enabled", http.StatusNotFound) + return + } + + capsMap := h.PeerCaps() + driveCaps, ok := capsMap[tailcfg.PeerCapabilityTaildrive] + if !ok { + h.logf("taildrive: not permitted") + http.Error(w, "taildrive not permitted", http.StatusForbidden) + return + } + + rawPerms := make([][]byte, 0, len(driveCaps)) + for _, cap := range driveCaps { + rawPerms = append(rawPerms, []byte(cap)) + } + + p, err := drive.ParsePermissions(rawPerms) + if err != nil { + h.logf("taildrive: error parsing permissions: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + fs, ok := h.ps.b.sys.DriveForRemote.GetOK() + if !ok { + h.logf("taildrive: not supported on platform") + http.Error(w, "taildrive not supported on platform", http.StatusNotFound) + return + } + wr := &httpResponseWrapper{ + ResponseWriter: w, + } + bw := &requestBodyWrapper{ + ReadCloser: r.Body, + } + r.Body = bw + + defer func() { + switch wr.statusCode { + case 304: + // 304s are particularly chatty so skip logging. + default: + log := h.logf + if r.Method != httpm.PUT && r.Method != httpm.GET { + log = h.logfv1 + } + contentType := "unknown" + if ct := wr.Header().Get("Content-Type"); ct != "" { + contentType = ct + } + + log("taildrive: share: %s from %s to %s: status-code=%d ext=%q content-type=%q tx=%.f rx=%.f", r.Method, h.peerNode.Key().ShortString(), h.selfNode.Key().ShortString(), wr.statusCode, parseDriveFileExtensionForLog(r.URL.Path), contentType, roundTraffic(wr.contentLength), roundTraffic(bw.bytesRead)) + } + }() + + r.URL.Path = strings.TrimPrefix(r.URL.Path, taildrivePrefix) + fs.ServeHTTPWithPerms(p, wr, r) +} + +// parseDriveFileExtensionForLog parses the file extension, if available. +// If a file extension is not present or parsable, the file extension is +// set to "unknown". If the file extension contains a double quote, it is +// replaced with "removed". +// All whitespace is removed from a parsed file extension. +// File extensions including the leading ., e.g. ".gif". +func parseDriveFileExtensionForLog(path string) string { + fileExt := "unknown" + if fe := filepath.Ext(path); fe != "" { + if strings.Contains(fe, "\"") { + // Do not log include file extensions with quotes within them. + return "removed" + } + // Remove white space from user defined inputs. + fileExt = strings.ReplaceAll(fe, " ", "") + } + + return fileExt +} diff --git a/ipn/ipnlocal/peerapi_h2c.go b/ipn/ipnlocal/peerapi_h2c.go deleted file mode 100644 index fbfa86398..000000000 --- a/ipn/ipnlocal/peerapi_h2c.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !ios && !android && !js - -package ipnlocal - -import ( - "net/http" - - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" -) - -func init() { - addH2C = func(s *http.Server) { - h2s := &http2.Server{} - s.Handler = h2c.NewHandler(s.Handler, h2s) - } -} diff --git a/ipn/ipnlocal/peerapi_test.go b/ipn/ipnlocal/peerapi_test.go index ff9b62769..3c9f57f1f 100644 --- a/ipn/ipnlocal/peerapi_test.go +++ b/ipn/ipnlocal/peerapi_test.go @@ -4,36 +4,29 @@ package ipnlocal import ( - "bytes" "context" "encoding/json" - "fmt" - "io" - "io/fs" - "math/rand" "net/http" "net/http/httptest" "net/netip" - "os" - "path/filepath" "slices" "strings" "testing" - "github.com/google/go-cmp/cmp" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/appc" "tailscale.com/appc/appctest" - "tailscale.com/client/tailscale/apitype" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" - "tailscale.com/taildrop" + "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/types/appctype" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" "tailscale.com/util/usermetric" "tailscale.com/wgengine" @@ -74,64 +67,18 @@ func bodyNotContains(sub string) check { } } -func fileHasSize(name string, size int) check { - return func(t *testing.T, e *peerAPITestEnv) { - root := e.ph.ps.taildrop.Dir() - if root == "" { - t.Errorf("no rootdir; can't check whether %q has size %v", name, size) - return - } - path := filepath.Join(root, name) - if fi, err := os.Stat(path); err != nil { - t.Errorf("fileHasSize(%q, %v): %v", name, size, err) - } else if fi.Size() != int64(size) { - t.Errorf("file %q has size %v; want %v", name, fi.Size(), size) - } - } -} - -func fileHasContents(name string, want string) check { - return func(t *testing.T, e *peerAPITestEnv) { - root := e.ph.ps.taildrop.Dir() - if root == "" { - t.Errorf("no rootdir; can't check contents of %q", name) - return - } - path := filepath.Join(root, name) - got, err := os.ReadFile(path) - if err != nil { - t.Errorf("fileHasContents: %v", err) - return - } - if string(got) != want { - t.Errorf("file contents = %q; want %q", got, want) - } - } -} - -func hexAll(v string) string { - var sb strings.Builder - for i := range len(v) { - fmt.Fprintf(&sb, "%%%02x", v[i]) - } - return sb.String() -} - func TestHandlePeerAPI(t *testing.T) { tests := []struct { - name string - isSelf bool // the peer sending the request is owned by us - capSharing bool // self node has file sharing capability - debugCap bool // self node has debug capability - omitRoot bool // don't configure - reqs []*http.Request - checks []check + name string + isSelf bool // the peer sending the request is owned by us + debugCap bool // self node has debug capability + reqs []*http.Request + checks []check }{ { - name: "not_peer_api", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, + name: "not_peer_api", + isSelf: true, + reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, checks: checks( httpStatus(200), bodyContains("This is my Tailscale device."), @@ -139,10 +86,9 @@ func TestHandlePeerAPI(t *testing.T) { ), }, { - name: "not_peer_api_not_owner", - isSelf: false, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, + name: "not_peer_api_not_owner", + isSelf: false, + reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, checks: checks( httpStatus(200), bodyContains("This is my Tailscale device."), @@ -173,255 +119,6 @@ func TestHandlePeerAPI(t *testing.T) { bodyContains("ServeHTTP"), ), }, - { - name: "reject_non_owner_put", - isSelf: false, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(http.StatusForbidden), - bodyContains("Taildrop disabled"), - ), - }, - { - name: "owner_without_cap", - isSelf: true, - capSharing: false, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(http.StatusForbidden), - bodyContains("Taildrop disabled"), - ), - }, - { - name: "owner_with_cap_no_rootdir", - omitRoot: true, - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(http.StatusForbidden), - bodyContains("Taildrop disabled; no storage directory"), - ), - }, - { - name: "bad_method", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("POST", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(405), - bodyContains("expected method GET or PUT"), - ), - }, - { - name: "put_zero_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasSize("foo", 0), - fileHasContents("foo", ""), - ), - }, - { - name: "put_non_zero_length_content_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents"))}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasSize("foo", len("contents")), - fileHasContents("foo", "contents"), - ), - }, - { - name: "put_non_zero_length_chunked", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", struct{ io.Reader }{strings.NewReader("contents")})}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasSize("foo", len("contents")), - fileHasContents("foo", "contents"), - ), - }, - { - name: "bad_filename_partial", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.partial", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_deleted", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.deleted", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_dot", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/.", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_empty", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_slash", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo/bar", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_dot", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("."), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_slash", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("/"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_backslash", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("\\"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_dotdot", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(".."), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_dotdot_out", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("foo/../../../../../etc/passwd"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_spaces_and_caps", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("Foo Bar.dat"), strings.NewReader("baz"))}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasContents("Foo Bar.dat", "baz"), - ), - }, - { - name: "put_unicode", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3"), strings.NewReader("ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"))}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasContents("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3", "ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"), - ), - }, - { - name: "put_invalid_utf8", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+(hexAll("😜")[:3]), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_null", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%00", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_non_printable", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%01", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_colon", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("nul:"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_surrounding_whitespace", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(" foo "), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, { name: "host-val/bad-ip", isSelf: true, @@ -449,72 +146,6 @@ func TestHandlePeerAPI(t *testing.T) { httpStatus(200), ), }, - { - name: "duplicate_zero_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{ - httptest.NewRequest("PUT", "/v0/put/foo", nil), - httptest.NewRequest("PUT", "/v0/put/foo", nil), - }, - checks: checks( - httpStatus(200), - func(t *testing.T, env *peerAPITestEnv) { - got, err := env.ph.ps.taildrop.WaitingFiles() - if err != nil { - t.Fatalf("WaitingFiles error: %v", err) - } - want := []apitype.WaitingFile{{Name: "foo", Size: 0}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) - } - }, - ), - }, - { - name: "duplicate_non_zero_length_content_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{ - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), - }, - checks: checks( - httpStatus(200), - func(t *testing.T, env *peerAPITestEnv) { - got, err := env.ph.ps.taildrop.WaitingFiles() - if err != nil { - t.Fatalf("WaitingFiles error: %v", err) - } - want := []apitype.WaitingFile{{Name: "foo", Size: 8}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) - } - }, - ), - }, - { - name: "duplicate_different_files", - isSelf: true, - capSharing: true, - reqs: []*http.Request{ - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("fizz")), - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("buzz")), - }, - checks: checks( - httpStatus(200), - func(t *testing.T, env *peerAPITestEnv) { - got, err := env.ph.ps.taildrop.WaitingFiles() - if err != nil { - t.Fatalf("WaitingFiles error: %v", err) - } - want := []apitype.WaitingFile{{Name: "foo", Size: 4}, {Name: "foo (1)", Size: 4}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) - } - }, - ), - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -527,12 +158,10 @@ func TestHandlePeerAPI(t *testing.T) { selfNode.CapMap = tailcfg.NodeCapMap{tailcfg.CapabilityDebug: nil} } var e peerAPITestEnv - lb := &LocalBackend{ - logf: e.logBuf.Logf, - capFileSharing: tt.capSharing, - netMap: &netmap.NetworkMap{SelfNode: selfNode.View()}, - clock: &tstest.Clock{}, - } + lb := newTestLocalBackend(t) + lb.logf = e.logBuf.Logf + lb.clock = &tstest.Clock{} + lb.currentNode().SetNetMap(&netmap.NetworkMap{SelfNode: selfNode.View()}) e.ph = &peerAPIHandler{ isSelf: tt.isSelf, selfNode: selfNode.View(), @@ -543,16 +172,6 @@ func TestHandlePeerAPI(t *testing.T) { b: lb, }, } - var rootDir string - if !tt.omitRoot { - rootDir = t.TempDir() - if e.ph.ps.taildrop == nil { - e.ph.ps.taildrop = taildrop.ManagerOptions{ - Logf: e.logBuf.Logf, - Dir: rootDir, - }.New() - } - } for _, req := range tt.reqs { e.rr = httptest.NewRecorder() if req.Host == "example.com" { @@ -563,76 +182,10 @@ func TestHandlePeerAPI(t *testing.T) { for _, f := range tt.checks { f(t, &e) } - if t.Failed() && rootDir != "" { - t.Logf("Contents of %s:", rootDir) - des, _ := fs.ReadDir(os.DirFS(rootDir), ".") - for _, de := range des { - fi, err := de.Info() - if err != nil { - t.Log(err) - } else { - t.Logf(" %v %5d %s", fi.Mode(), fi.Size(), de.Name()) - } - } - } }) } } -// Windows likes to hold on to file descriptors for some indeterminate -// amount of time after you close them and not let you delete them for -// a bit. So test that we work around that sufficiently. -func TestFileDeleteRace(t *testing.T) { - dir := t.TempDir() - ps := &peerAPIServer{ - b: &LocalBackend{ - logf: t.Logf, - capFileSharing: true, - clock: &tstest.Clock{}, - }, - taildrop: taildrop.ManagerOptions{ - Logf: t.Logf, - Dir: dir, - }.New(), - } - ph := &peerAPIHandler{ - isSelf: true, - peerNode: (&tailcfg.Node{ - ComputedName: "some-peer-name", - }).View(), - selfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.100.101/32")}, - }).View(), - ps: ps, - } - buf := make([]byte, 2<<20) - for range 30 { - rr := httptest.NewRecorder() - ph.ServeHTTP(rr, httptest.NewRequest("PUT", "http://100.100.100.101:123/v0/put/foo.txt", bytes.NewReader(buf[:rand.Intn(len(buf))]))) - if res := rr.Result(); res.StatusCode != 200 { - t.Fatal(res.Status) - } - wfs, err := ps.taildrop.WaitingFiles() - if err != nil { - t.Fatal(err) - } - if len(wfs) != 1 { - t.Fatalf("waiting files = %d; want 1", len(wfs)) - } - - if err := ps.taildrop.DeleteFile("foo.txt"); err != nil { - t.Fatal(err) - } - wfs, err = ps.taildrop.WaitingFiles() - if err != nil { - t.Fatal(err) - } - if len(wfs) != 0 { - t.Fatalf("waiting files = %d; want 0", len(wfs)) - } - } -} - func TestPeerAPIReplyToDNSQueries(t *testing.T) { var h peerAPIHandler @@ -643,17 +196,19 @@ func TestPeerAPIReplyToDNSQueries(t *testing.T) { h.isSelf = false h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") - ht := new(health.Tracker) - reg := new(usermetric.Registry) - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + sys := tsd.NewSystemWithBus(eventbustest.NewBus(t)) + + ht := health.NewTracker(sys.Bus.Get()) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) - h.ps = &peerAPIServer{ - b: &LocalBackend{ - e: eng, - pm: pm, - store: pm.Store(), - }, - } + reg := new(usermetric.Registry) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, sys.Bus.Get(), sys.Set) + sys.Set(pm.Store()) + sys.Set(eng) + + b := newTestLocalBackendWithSys(t, sys) + b.pm = pm + + h.ps = &peerAPIServer{b: b} if h.ps.b.OfferingExitNode() { t.Fatal("unexpectedly offering exit node") } @@ -695,26 +250,26 @@ func TestPeerAPIPrettyReplyCNAME(t *testing.T) { var h peerAPIHandler h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") - ht := new(health.Tracker) + sys := tsd.NewSystemWithBus(eventbustest.NewBus(t)) + + ht := health.NewTracker(sys.Bus.Get()) reg := new(usermetric.Registry) - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, sys.Bus.Get(), sys.Set) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) - var a *appc.AppConnector - if shouldStore { - a = appc.NewAppConnector(t.Logf, &appctest.RouteCollector{}, &appc.RouteInfo{}, fakeStoreRoutes) - } else { - a = appc.NewAppConnector(t.Logf, &appctest.RouteCollector{}, nil, nil) - } - h.ps = &peerAPIServer{ - b: &LocalBackend{ - e: eng, - pm: pm, - store: pm.Store(), - // configure as an app connector just to enable the API. - appConnector: a, - }, - } + a := appc.NewAppConnector(appc.Config{ + Logf: t.Logf, + EventBus: sys.Bus.Get(), + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + sys.Set(pm.Store()) + sys.Set(eng) + + b := newTestLocalBackendWithSys(t, sys) + b.pm = pm + b.appConnector = a // configure as an app connector just to enable the API. + h.ps = &peerAPIServer{b: b} h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) { b.CNAMEResource( dnsmessage.ResourceHeader{ @@ -764,31 +319,35 @@ func TestPeerAPIPrettyReplyCNAME(t *testing.T) { func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { for _, shouldStore := range []bool{false, true} { - ctx := context.Background() var h peerAPIHandler h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + sys := tsd.NewSystemWithBus(eventbustest.NewBus(t)) + bw := eventbustest.NewWatcher(t, sys.Bus.Get()) + rc := &appctest.RouteCollector{} - ht := new(health.Tracker) - reg := new(usermetric.Registry) - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + ht := health.NewTracker(sys.Bus.Get()) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) - var a *appc.AppConnector - if shouldStore { - a = appc.NewAppConnector(t.Logf, rc, &appc.RouteInfo{}, fakeStoreRoutes) - } else { - a = appc.NewAppConnector(t.Logf, rc, nil, nil) - } - h.ps = &peerAPIServer{ - b: &LocalBackend{ - e: eng, - pm: pm, - store: pm.Store(), - appConnector: a, - }, - } + + reg := new(usermetric.Registry) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, sys.Bus.Get(), sys.Set) + a := appc.NewAppConnector(appc.Config{ + Logf: t.Logf, + EventBus: sys.Bus.Get(), + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + sys.Set(pm.Store()) + sys.Set(eng) + + b := newTestLocalBackendWithSys(t, sys) + b.pm = pm + b.appConnector = a + + h.ps = &peerAPIServer{b: b} h.ps.b.appConnector.UpdateDomains([]string{"example.com"}) - h.ps.b.appConnector.Wait(ctx) + a.Wait(t.Context()) h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) { b.AResource( @@ -818,12 +377,18 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { if w.Code != http.StatusOK { t.Errorf("unexpected status code: %v", w.Code) } - h.ps.b.appConnector.Wait(ctx) + a.Wait(t.Context()) wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("got %v; want %v", rc.Routes(), wantRoutes) } + + if err := eventbustest.Expect(bw, + eqUpdate(appctype.RouteUpdate{Advertise: mustPrefix("192.0.0.8/32")}), + ); err != nil { + t.Error(err) + } } } @@ -833,27 +398,31 @@ func TestPeerAPIReplyToDNSQueriesAreObservedWithCNAMEFlattening(t *testing.T) { var h peerAPIHandler h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") - ht := new(health.Tracker) + sys := tsd.NewSystemWithBus(eventbustest.NewBus(t)) + bw := eventbustest.NewWatcher(t, sys.Bus.Get()) + + ht := health.NewTracker(sys.Bus.Get()) reg := new(usermetric.Registry) rc := &appctest.RouteCollector{} - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, sys.Bus.Get(), sys.Set) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) - var a *appc.AppConnector - if shouldStore { - a = appc.NewAppConnector(t.Logf, rc, &appc.RouteInfo{}, fakeStoreRoutes) - } else { - a = appc.NewAppConnector(t.Logf, rc, nil, nil) - } - h.ps = &peerAPIServer{ - b: &LocalBackend{ - e: eng, - pm: pm, - store: pm.Store(), - appConnector: a, - }, - } + a := appc.NewAppConnector(appc.Config{ + Logf: t.Logf, + EventBus: sys.Bus.Get(), + RouteAdvertiser: rc, + HasStoredRoutes: shouldStore, + }) + t.Cleanup(a.Close) + sys.Set(pm.Store()) + sys.Set(eng) + + b := newTestLocalBackendWithSys(t, sys) + b.pm = pm + b.appConnector = a + + h.ps = &peerAPIServer{b: b} h.ps.b.appConnector.UpdateDomains([]string{"www.example.com"}) - h.ps.b.appConnector.Wait(ctx) + a.Wait(ctx) h.ps.resolver = &fakeResolver{build: func(b *dnsmessage.Builder) { b.CNAMEResource( @@ -894,12 +463,18 @@ func TestPeerAPIReplyToDNSQueriesAreObservedWithCNAMEFlattening(t *testing.T) { if w.Code != http.StatusOK { t.Errorf("unexpected status code: %v", w.Code) } - h.ps.b.appConnector.Wait(ctx) + a.Wait(ctx) wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("got %v; want %v", rc.Routes(), wantRoutes) } + + if err := eventbustest.Expect(bw, + eqUpdate(appctype.RouteUpdate{Advertise: mustPrefix("192.0.0.8/32")}), + ); err != nil { + t.Error(err) + } } } diff --git a/ipn/ipnlocal/prefs_metrics.go b/ipn/ipnlocal/prefs_metrics.go new file mode 100644 index 000000000..34c5f5504 --- /dev/null +++ b/ipn/ipnlocal/prefs_metrics.go @@ -0,0 +1,103 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "errors" + + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" +) + +// Counter metrics for edit/change events +var ( + // metricExitNodeEnabled is incremented when the user enables an exit node independent of the node's characteristics. + metricExitNodeEnabled = clientmetric.NewCounter("prefs_exit_node_enabled") + // metricExitNodeEnabledSuggested is incremented when the user enables the suggested exit node. + metricExitNodeEnabledSuggested = clientmetric.NewCounter("prefs_exit_node_enabled_suggested") + // metricExitNodeEnabledMullvad is incremented when the user enables a Mullvad exit node. + metricExitNodeEnabledMullvad = clientmetric.NewCounter("prefs_exit_node_enabled_mullvad") + // metricWantRunningEnabled is incremented when WantRunning transitions from false to true. + metricWantRunningEnabled = clientmetric.NewCounter("prefs_want_running_enabled") + // metricWantRunningDisabled is incremented when WantRunning transitions from true to false. + metricWantRunningDisabled = clientmetric.NewCounter("prefs_want_running_disabled") +) + +type exitNodeProperty string + +const ( + exitNodeTypePreferred exitNodeProperty = "suggested" // The exit node is the last suggested exit node + exitNodeTypeMullvad exitNodeProperty = "mullvad" // The exit node is a Mullvad exit node +) + +// prefsMetricsEditEvent encapsulates information needed to record metrics related +// to any changes to preferences. +type prefsMetricsEditEvent struct { + change *ipn.MaskedPrefs // the preference mask used to update the preferences + pNew ipn.PrefsView // new preferences (after ApplyUpdates) + pOld ipn.PrefsView // old preferences (before ApplyUpdates) + node *nodeBackend // the node the event is associated with + lastSuggestedExitNode tailcfg.StableNodeID // the last suggested exit node +} + +// record records changes to preferences as clientmetrics. +func (e *prefsMetricsEditEvent) record() error { + if e.change == nil || e.node == nil { + return errors.New("prefsMetricsEditEvent: missing required fields") + } + + // Record up/down events. + if e.change.WantRunningSet && (e.pNew.WantRunning() != e.pOld.WantRunning()) { + if e.pNew.WantRunning() { + metricWantRunningEnabled.Add(1) + } else { + metricWantRunningDisabled.Add(1) + } + } + + // Record any changes to exit node settings. + if e.change.ExitNodeIDSet || e.change.ExitNodeIPSet { + if exitNodeTypes, ok := e.exitNodeType(e.pNew.ExitNodeID()); ok { + // We have switched to a valid exit node if ok is true. + metricExitNodeEnabled.Add(1) + + // We may have some additional characteristics we should also record. + for _, t := range exitNodeTypes { + switch t { + case exitNodeTypePreferred: + metricExitNodeEnabledSuggested.Add(1) + case exitNodeTypeMullvad: + metricExitNodeEnabledMullvad.Add(1) + } + } + } + } + return nil +} + +// exitNodeTypesLocked returns type of exit node for the given stable ID. +// An exit node may have multiple type (can be both mullvad and preferred +// simultaneously for example). +// +// This will return ok as true if the supplied stable ID resolves to a known peer, +// false otherwise. The caller is responsible for ensuring that the id belongs to +// an exit node. +func (e *prefsMetricsEditEvent) exitNodeType(id tailcfg.StableNodeID) (props []exitNodeProperty, isNode bool) { + if !buildfeatures.HasUseExitNode { + return nil, false + } + var peer tailcfg.NodeView + + if peer, isNode = e.node.PeerByStableID(id); isNode { + if tailcfg.StableNodeID(id) == e.lastSuggestedExitNode { + props = append(props, exitNodeTypePreferred) + } + if peer.IsWireGuardOnly() { + props = append(props, exitNodeTypeMullvad) + } + } + return props, isNode +} diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index b13f921d6..40a3c9887 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -13,17 +13,25 @@ import ( "slices" "strings" - "tailscale.com/clientupdate" "tailscale.com/envknob" + "tailscale.com/feature" "tailscale.com/health" "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/persist" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" + "tailscale.com/util/testenv" ) var debug = envknob.RegisterBool("TS_DEBUG_PROFILES") +// [profileManager] implements [ipnext.ProfileStore]. +var _ ipnext.ProfileStore = (*profileManager)(nil) + // profileManager is a wrapper around an [ipn.StateStore] that manages // multiple profiles and the current profile. // @@ -35,9 +43,31 @@ type profileManager struct { health *health.Tracker currentUserID ipn.WindowsUserID - knownProfiles map[ipn.ProfileID]*ipn.LoginProfile // always non-nil - currentProfile *ipn.LoginProfile // always non-nil - prefs ipn.PrefsView // always Valid. + knownProfiles map[ipn.ProfileID]ipn.LoginProfileView // always non-nil + currentProfile ipn.LoginProfileView // always Valid (once [newProfileManager] returns). + prefs ipn.PrefsView // always Valid (once [newProfileManager] returns). + + // StateChangeHook is an optional hook that is called when the current profile or prefs change, + // such as due to a profile switch or a change in the profile's preferences. + // It is typically set by the [LocalBackend] to invert the dependency between + // the [profileManager] and the [LocalBackend], so that instead of [LocalBackend] + // asking [profileManager] for the state, we can have [profileManager] call + // [LocalBackend] when the state changes. See also: + // https://github.com/tailscale/tailscale/pull/15791#discussion_r2060838160 + StateChangeHook ipnext.ProfileStateChangeCallback + + // extHost is the bridge between [profileManager] and the registered [ipnext.Extension]s. + // It may be nil in tests. A nil pointer is a valid, no-op host. + extHost *ExtensionHost +} + +// SetExtensionHost sets the [ExtensionHost] for the [profileManager]. +// The specified host will be notified about profile and prefs changes +// and will immediately be notified about the current profile and prefs. +// A nil host is a valid, no-op host. +func (pm *profileManager) SetExtensionHost(host *ExtensionHost) { + pm.extHost = host + host.NotifyProfileChange(pm.currentProfile, pm.prefs, false) } func (pm *profileManager) dlogf(format string, args ...any) { @@ -64,8 +94,7 @@ func (pm *profileManager) SetCurrentUserID(uid ipn.WindowsUserID) { if pm.currentUserID == uid { return } - pm.currentUserID = uid - if err := pm.SwitchToDefaultProfile(); err != nil { + if _, _, err := pm.SwitchToDefaultProfileForUser(uid); err != nil { // SetCurrentUserID should never fail and must always switch to the // user's default profile or create a new profile for the current user. // Until we implement multi-user support and the new permission model, @@ -73,65 +102,158 @@ func (pm *profileManager) SetCurrentUserID(uid ipn.WindowsUserID) { // that when SetCurrentUserID exits, the profile in pm.currentProfile // is either an existing profile owned by the user, or a new, empty profile. pm.logf("%q's default profile cannot be used; creating a new one: %v", uid, err) - pm.NewProfileForUser(uid) + pm.SwitchToNewProfileForUser(uid) } } -// DefaultUserProfileID returns [ipn.ProfileID] of the default (last used) profile for the specified user, -// or an empty string if the specified user does not have a default profile. -func (pm *profileManager) DefaultUserProfileID(uid ipn.WindowsUserID) ipn.ProfileID { +// SwitchToProfile switches to the specified profile and (temporarily, +// while the "current user" is still a thing on Windows; see tailscale/corp#18342) +// sets its owner as the current user. The profile must be a valid profile +// returned by the [profileManager], such as by [profileManager.Profiles], +// [profileManager.ProfileByID], or [profileManager.NewProfileForUser]. +// +// It is a shorthand for [profileManager.SetCurrentUserID] followed by +// [profileManager.SwitchProfileByID], but it is more efficient as it switches +// directly to the specified profile rather than switching to the user's +// default profile first. It is a no-op if the specified profile is already +// the current profile. +// +// As a special case, if the specified profile view is not valid, it resets +// both the current user and the profile to a new, empty profile not owned +// by any user. +// +// It returns the current profile and whether the call resulted in a profile change, +// or an error if the specified profile does not exist or its prefs could not be loaded. +// +// It may be called during [profileManager] initialization before [newProfileManager] returns +// and must check whether pm.currentProfile is Valid before using it. +func (pm *profileManager) SwitchToProfile(profile ipn.LoginProfileView) (cp ipn.LoginProfileView, changed bool, err error) { + prefs := defaultPrefs + switch { + case !profile.Valid(): + // Create a new profile that is not associated with any user. + profile = pm.NewProfileForUser("") + case profile == pm.currentProfile, + profile.ID() != "" && pm.currentProfile.Valid() && profile.ID() == pm.currentProfile.ID(), + profile.ID() == "" && profile.Equals(pm.currentProfile) && prefs.Equals(pm.prefs): + // The profile is already the current profile; no need to switch. + // + // It includes three cases: + // 1. The target profile and the current profile are aliases referencing the [ipn.LoginProfile]. + // The profile may be either a new (non-persisted) profile or an existing well-known profile. + // 2. The target profile is a well-known, persisted profile with the same ID as the current profile. + // 3. The target and the current profiles are both new (non-persisted) profiles and they are equal. + // At minimum, equality means that the profiles are owned by the same user on platforms that support it + // and the prefs are the same as well. + return pm.currentProfile, false, nil + case profile.ID() == "": + // Copy the specified profile to prevent accidental mutation. + profile = profile.AsStruct().View() + default: + // Find an existing profile by ID and load its prefs. + kp, ok := pm.knownProfiles[profile.ID()] + if !ok { + // The profile ID is not valid; it may have been deleted or never existed. + // As the target profile should have been returned by the [profileManager], + // this is unexpected and might indicate a bug in the code. + return pm.currentProfile, false, fmt.Errorf("[unexpected] %w: %s (%s)", errProfileNotFound, profile.Name(), profile.ID()) + } + profile = kp + if prefs, err = pm.loadSavedPrefs(profile.Key()); err != nil { + return pm.currentProfile, false, fmt.Errorf("failed to load profile prefs for %s (%s): %w", profile.Name(), profile.ID(), err) + } + } + + if profile.ID() == "" { // new profile that has never been persisted + metricNewProfile.Add(1) + } else { + metricSwitchProfile.Add(1) + } + + pm.prefs = prefs + pm.updateHealth() + pm.currentProfile = profile + pm.currentUserID = profile.LocalUserID() + if err := pm.setProfileAsUserDefault(profile); err != nil { + // This is not a fatal error; we've already switched to the profile. + // But if updating the default profile fails, we should log it. + pm.logf("failed to set %s (%s) as the default profile: %v", profile.Name(), profile.ID(), err) + } + + if f := pm.StateChangeHook; f != nil { + f(pm.currentProfile, pm.prefs, false) + } + // Do not call pm.extHost.NotifyProfileChange here; it is invoked in + // [LocalBackend.resetForProfileChangeLockedOnEntry] after the netmap reset. + // TODO(nickkhyl): Consider moving it here (or into the stateChangeCb handler + // in [LocalBackend]) once the profile/node state, including the netmap, + // is actually tied to the current profile. + + return profile, true, nil +} + +// DefaultUserProfile returns a read-only view of the default (last used) profile for the specified user. +// It returns a read-only view of a new, non-persisted profile if the specified user does not have a default profile. +func (pm *profileManager) DefaultUserProfile(uid ipn.WindowsUserID) ipn.LoginProfileView { // Read the CurrentProfileKey from the store which stores // the selected profile for the specified user. b, err := pm.store.ReadState(ipn.CurrentProfileKey(string(uid))) - pm.dlogf("DefaultUserProfileID: ReadState(%q) = %v, %v", string(uid), len(b), err) + pm.dlogf("DefaultUserProfile: ReadState(%q) = %v, %v", string(uid), len(b), err) if err == ipn.ErrStateNotExist || len(b) == 0 { if runtime.GOOS == "windows" { - pm.dlogf("DefaultUserProfileID: windows: migrating from legacy preferences") - profile, err := pm.migrateFromLegacyPrefs(uid, false) + pm.dlogf("DefaultUserProfile: windows: migrating from legacy preferences") + profile, err := pm.migrateFromLegacyPrefs(uid) if err == nil { - return profile.ID + return profile } pm.logf("failed to migrate from legacy preferences: %v", err) } - return "" + return pm.NewProfileForUser(uid) } pk := ipn.StateKey(string(b)) - prof := pm.findProfileByKey(pk) - if prof == nil { - pm.dlogf("DefaultUserProfileID: no profile found for key: %q", pk) - return "" + prof := pm.findProfileByKey(uid, pk) + if !prof.Valid() { + pm.dlogf("DefaultUserProfile: no profile found for key: %q", pk) + return pm.NewProfileForUser(uid) } - return prof.ID + return prof } // checkProfileAccess returns an [errProfileAccessDenied] if the current user // does not have access to the specified profile. -func (pm *profileManager) checkProfileAccess(profile *ipn.LoginProfile) error { - if pm.currentUserID != "" && profile.LocalUserID != pm.currentUserID { +func (pm *profileManager) checkProfileAccess(profile ipn.LoginProfileView) error { + return pm.checkProfileAccessAs(pm.currentUserID, profile) +} + +// checkProfileAccessAs returns an [errProfileAccessDenied] if the specified user +// does not have access to the specified profile. +func (pm *profileManager) checkProfileAccessAs(uid ipn.WindowsUserID, profile ipn.LoginProfileView) error { + if uid != "" && profile.LocalUserID() != uid { return errProfileAccessDenied } return nil } -// allProfiles returns all profiles accessible to the current user. +// allProfilesFor returns all profiles accessible to the specified user. // The returned profiles are sorted by Name. -func (pm *profileManager) allProfiles() (out []*ipn.LoginProfile) { +func (pm *profileManager) allProfilesFor(uid ipn.WindowsUserID) []ipn.LoginProfileView { + out := make([]ipn.LoginProfileView, 0, len(pm.knownProfiles)) for _, p := range pm.knownProfiles { - if pm.checkProfileAccess(p) == nil { + if pm.checkProfileAccessAs(uid, p) == nil { out = append(out, p) } } - slices.SortFunc(out, func(a, b *ipn.LoginProfile) int { - return cmp.Compare(a.Name, b.Name) + slices.SortFunc(out, func(a, b ipn.LoginProfileView) int { + return cmp.Compare(a.Name(), b.Name()) }) return out } -// matchingProfiles is like [profileManager.allProfiles], but returns only profiles +// matchingProfiles is like [profileManager.allProfilesFor], but returns only profiles // matching the given predicate. -func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out []*ipn.LoginProfile) { - all := pm.allProfiles() +func (pm *profileManager) matchingProfiles(uid ipn.WindowsUserID, f func(ipn.LoginProfileView) bool) (out []ipn.LoginProfileView) { + all := pm.allProfilesFor(uid) out = all[:0] for _, p := range all { if f(p) { @@ -144,11 +266,11 @@ func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out // findMatchingProfiles returns all profiles accessible to the current user // that represent the same node/user as prefs. // The returned profiles are sorted by Name. -func (pm *profileManager) findMatchingProfiles(prefs ipn.PrefsView) []*ipn.LoginProfile { - return pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.ControlURL == prefs.ControlURL() && - (p.UserProfile.ID == prefs.Persist().UserProfile().ID || - p.NodeID == prefs.Persist().NodeID()) +func (pm *profileManager) findMatchingProfiles(uid ipn.WindowsUserID, prefs ipn.PrefsView) []ipn.LoginProfileView { + return pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { + return p.ControlURL() == prefs.ControlURL() && + (p.UserProfile().ID == prefs.Persist().UserProfile().ID || + p.NodeID() == prefs.Persist().NodeID()) }) } @@ -156,19 +278,19 @@ func (pm *profileManager) findMatchingProfiles(prefs ipn.PrefsView) []*ipn.Login // given name. It returns "" if no such profile exists among profiles // accessible to the current user. func (pm *profileManager) ProfileIDForName(name string) ipn.ProfileID { - p := pm.findProfileByName(name) - if p == nil { + p := pm.findProfileByName(pm.currentUserID, name) + if !p.Valid() { return "" } - return p.ID + return p.ID() } -func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { - out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.Name == name +func (pm *profileManager) findProfileByName(uid ipn.WindowsUserID, name string) ipn.LoginProfileView { + out := pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { + return p.Name() == name && pm.checkProfileAccessAs(uid, p) == nil }) if len(out) == 0 { - return nil + return ipn.LoginProfileView{} } if len(out) > 1 { pm.logf("[unexpected] multiple profiles with the same name") @@ -176,12 +298,12 @@ func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { return out[0] } -func (pm *profileManager) findProfileByKey(key ipn.StateKey) *ipn.LoginProfile { - out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.Key == key +func (pm *profileManager) findProfileByKey(uid ipn.WindowsUserID, key ipn.StateKey) ipn.LoginProfileView { + out := pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { + return p.Key() == key && pm.checkProfileAccessAs(uid, p) == nil }) if len(out) == 0 { - return nil + return ipn.LoginProfileView{} } if len(out) > 1 { pm.logf("[unexpected] multiple profiles with the same key") @@ -194,19 +316,13 @@ func (pm *profileManager) setUnattendedModeAsConfigured() error { return nil } - if pm.currentProfile.Key != "" && pm.prefs.ForceDaemon() { - return pm.WriteState(ipn.ServerModeStartKey, []byte(pm.currentProfile.Key)) + if pm.currentProfile.Key() != "" && pm.prefs.ForceDaemon() { + return pm.WriteState(ipn.ServerModeStartKey, []byte(pm.currentProfile.Key())) } else { return pm.WriteState(ipn.ServerModeStartKey, nil) } } -// Reset unloads the current profile, if any. -func (pm *profileManager) Reset() { - pm.currentUserID = "" - pm.NewProfile() -} - // SetPrefs sets the current profile's prefs to the provided value. // It also saves the prefs to the [ipn.StateStore]. It stores a copy of the // provided prefs, which may be accessed via [profileManager.CurrentPrefs]. @@ -222,36 +338,67 @@ func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView, np ipn.NetworkProfile) } // Check if we already have an existing profile that matches the user/node. - if existing := pm.findMatchingProfiles(prefsIn); len(existing) > 0 { + if existing := pm.findMatchingProfiles(pm.currentUserID, prefsIn); len(existing) > 0 { // We already have a profile for this user/node we should reuse it. Also // cleanup any other duplicate profiles. cp = existing[0] existing = existing[1:] for _, p := range existing { // Clear the state. - if err := pm.store.WriteState(p.Key, nil); err != nil { + if err := pm.store.WriteState(p.Key(), nil); err != nil { // We couldn't delete the state, so keep the profile around. continue } // Remove the profile, knownProfiles will be persisted // in [profileManager.setProfilePrefs] below. - delete(pm.knownProfiles, p.ID) + delete(pm.knownProfiles, p.ID()) + } + } + // TODO(nickkhyl): Revisit how we handle implicit switching to a different profile, + // which occurs when prefsIn represents a node/user different from that of the + // currentProfile. It happens when a login (either reauth or user-initiated login) + // is completed with a different node/user identity than the one currently in use. + // + // Currently, we overwrite the existing profile prefs with the ones from prefsIn, + // where prefsIn is the previous profile's prefs with an updated Persist, LoggedOut, + // WantRunning and possibly other fields. This may not be the desired behavior. + // + // Additionally, LocalBackend doesn't treat it as a proper profile switch, meaning that + // [LocalBackend.resetForProfileChangeLockedOnEntry] is not called and certain + // node/profile-specific state may not be reset as expected. + // + // However, [profileManager] notifies [ipnext.Extension]s about the profile change, + // so features migrated from LocalBackend to external packages should not be affected. + // + // See tailscale/corp#28014. + if !cp.Equals(pm.currentProfile) { + const sameNode = false // implicit profile switch + pm.currentProfile = cp + pm.prefs = prefsIn.AsStruct().View() + if f := pm.StateChangeHook; f != nil { + f(cp, prefsIn, sameNode) } + pm.extHost.NotifyProfileChange(cp, prefsIn, sameNode) } - pm.currentProfile = cp - if err := pm.SetProfilePrefs(cp, prefsIn, np); err != nil { + cp, err := pm.setProfilePrefs(nil, prefsIn, np) + if err != nil { return err } return pm.setProfileAsUserDefault(cp) - } -// SetProfilePrefs is like [profileManager.SetPrefs], but sets prefs for the specified [ipn.LoginProfile] -// which is not necessarily the [profileManager.CurrentProfile]. It returns an [errProfileAccessDenied] -// if the specified profile is not accessible by the current user. -func (pm *profileManager) SetProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.PrefsView, np ipn.NetworkProfile) error { - if err := pm.checkProfileAccess(lp); err != nil { - return err +// setProfilePrefs is like [profileManager.SetPrefs], but sets prefs for the specified [ipn.LoginProfile], +// returning a read-only view of the updated profile on success. If the specified profile is nil, +// it defaults to the current profile. If the profile is not accessible by the current user, +// the method returns an [errProfileAccessDenied]. +func (pm *profileManager) setProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.PrefsView, np ipn.NetworkProfile) (ipn.LoginProfileView, error) { + isCurrentProfile := lp == nil || (lp.ID != "" && lp.ID == pm.currentProfile.ID()) + if isCurrentProfile { + lp = pm.CurrentProfile().AsStruct() + } + + if err := pm.checkProfileAccess(lp.View()); err != nil { + return ipn.LoginProfileView{}, err } // An empty profile.ID indicates that the profile is new, the node info wasn't available, @@ -291,23 +438,42 @@ func (pm *profileManager) SetProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.Pref lp.UserProfile = up lp.NetworkProfile = np + // Update the current profile view to reflect the changes + // if the specified profile is the current profile. + if isCurrentProfile { + // Always set pm.currentProfile to the new profile view for pointer equality. + // We check it further down the call stack. + lp := lp.View() + sameProfileInfo := lp.Equals(pm.currentProfile) + pm.currentProfile = lp + if !sameProfileInfo { + // But only invoke the callbacks if the profile info has actually changed. + const sameNode = true // just an info update; still the same node + pm.prefs = prefsIn.AsStruct().View() // suppress further callbacks for this change + if f := pm.StateChangeHook; f != nil { + f(lp, prefsIn, sameNode) + } + pm.extHost.NotifyProfileChange(lp, prefsIn, sameNode) + } + } + // An empty profile.ID indicates that the node info is not available yet, // and the profile doesn't need to be saved on disk. if lp.ID != "" { - pm.knownProfiles[lp.ID] = lp + pm.knownProfiles[lp.ID] = lp.View() if err := pm.writeKnownProfiles(); err != nil { - return err + return ipn.LoginProfileView{}, err } // Clone prefsIn and create a read-only view as a safety measure to // prevent accidental preference mutations, both externally and internally. - if err := pm.setProfilePrefsNoPermCheck(lp, prefsIn.AsStruct().View()); err != nil { - return err + if err := pm.setProfilePrefsNoPermCheck(lp.View(), prefsIn.AsStruct().View()); err != nil { + return ipn.LoginProfileView{}, err } } - return nil + return lp.View(), nil } -func newUnusedID(knownProfiles map[ipn.ProfileID]*ipn.LoginProfile) (ipn.ProfileID, ipn.StateKey) { +func newUnusedID(knownProfiles map[ipn.ProfileID]ipn.LoginProfileView) (ipn.ProfileID, ipn.StateKey) { var idb [2]byte for { rand.Read(idb[:]) @@ -326,14 +492,40 @@ func newUnusedID(knownProfiles map[ipn.ProfileID]*ipn.LoginProfile) (ipn.Profile // The method does not perform any additional checks on the specified // profile, such as verifying the caller's access rights or checking // if another profile for the same node already exists. -func (pm *profileManager) setProfilePrefsNoPermCheck(profile *ipn.LoginProfile, clonedPrefs ipn.PrefsView) error { +func (pm *profileManager) setProfilePrefsNoPermCheck(profile ipn.LoginProfileView, clonedPrefs ipn.PrefsView) error { isCurrentProfile := pm.currentProfile == profile if isCurrentProfile { + oldPrefs := pm.prefs pm.prefs = clonedPrefs + + // Sadly, profile prefs can be changed in multiple ways. + // It's pretty chaotic, and in many cases callers use + // unexported methods of the profile manager instead of + // going through [LocalBackend.setPrefsLockedOnEntry] + // or at least using [profileManager.SetPrefs]. + // + // While we should definitely clean this up to improve + // the overall structure of how prefs are set, which would + // also address current and future conflicts, such as + // competing features changing the same prefs, this method + // is currently the central place where we can detect all + // changes to the current profile's prefs. + // + // That said, regardless of the cleanup, we might want + // to keep the profileManager responsible for invoking + // profile- and prefs-related callbacks. + + if !clonedPrefs.Equals(oldPrefs) { + if f := pm.StateChangeHook; f != nil { + f(pm.currentProfile, clonedPrefs, true) + } + pm.extHost.NotifyProfilePrefsChanged(pm.currentProfile, oldPrefs, clonedPrefs) + } + pm.updateHealth() } - if profile.Key != "" { - if err := pm.writePrefsToStore(profile.Key, clonedPrefs); err != nil { + if profile.Key() != "" { + if err := pm.writePrefsToStore(profile.Key(), clonedPrefs); err != nil { return err } } else if !isCurrentProfile { @@ -362,38 +554,33 @@ func (pm *profileManager) writePrefsToStore(key ipn.StateKey, prefs ipn.PrefsVie } // Profiles returns the list of known profiles accessible to the current user. -func (pm *profileManager) Profiles() []ipn.LoginProfile { - allProfiles := pm.allProfiles() - out := make([]ipn.LoginProfile, len(allProfiles)) - for i, p := range allProfiles { - out[i] = *p - } - return out +func (pm *profileManager) Profiles() []ipn.LoginProfileView { + return pm.allProfilesFor(pm.currentUserID) } // ProfileByID returns a profile with the given id, if it is accessible to the current user. // If the profile exists but is not accessible to the current user, it returns an [errProfileAccessDenied]. // If the profile does not exist, it returns an [errProfileNotFound]. -func (pm *profileManager) ProfileByID(id ipn.ProfileID) (ipn.LoginProfile, error) { +func (pm *profileManager) ProfileByID(id ipn.ProfileID) (ipn.LoginProfileView, error) { kp, err := pm.profileByIDNoPermCheck(id) if err != nil { - return ipn.LoginProfile{}, err + return ipn.LoginProfileView{}, err } if err := pm.checkProfileAccess(kp); err != nil { - return ipn.LoginProfile{}, err + return ipn.LoginProfileView{}, err } - return *kp, nil + return kp, nil } // profileByIDNoPermCheck is like [profileManager.ProfileByID], but it doesn't // check user's access rights to the profile. -func (pm *profileManager) profileByIDNoPermCheck(id ipn.ProfileID) (*ipn.LoginProfile, error) { - if id == pm.currentProfile.ID { +func (pm *profileManager) profileByIDNoPermCheck(id ipn.ProfileID) (ipn.LoginProfileView, error) { + if id == pm.currentProfile.ID() { return pm.currentProfile, nil } kp, ok := pm.knownProfiles[id] if !ok { - return nil, errProfileNotFound + return ipn.LoginProfileView{}, errProfileNotFound } return kp, nil } @@ -412,55 +599,45 @@ func (pm *profileManager) ProfilePrefs(id ipn.ProfileID) (ipn.PrefsView, error) return pm.profilePrefs(kp) } -func (pm *profileManager) profilePrefs(p *ipn.LoginProfile) (ipn.PrefsView, error) { - if p.ID == pm.currentProfile.ID { +func (pm *profileManager) profilePrefs(p ipn.LoginProfileView) (ipn.PrefsView, error) { + if p.ID() == pm.currentProfile.ID() { return pm.prefs, nil } - return pm.loadSavedPrefs(p.Key) + return pm.loadSavedPrefs(p.Key()) } -// SwitchProfile switches to the profile with the given id. +// SwitchToProfileByID switches to the profile with the given id. +// It returns the current profile and whether the call resulted in a profile change. // If the profile exists but is not accessible to the current user, it returns an [errProfileAccessDenied]. // If the profile does not exist, it returns an [errProfileNotFound]. -func (pm *profileManager) SwitchProfile(id ipn.ProfileID) error { - metricSwitchProfile.Add(1) - - kp, ok := pm.knownProfiles[id] - if !ok { - return errProfileNotFound - } - if pm.currentProfile != nil && kp.ID == pm.currentProfile.ID && pm.prefs.Valid() { - return nil - } - - if err := pm.checkProfileAccess(kp); err != nil { - return fmt.Errorf("%w: profile %q is not accessible to the current user", err, id) +func (pm *profileManager) SwitchToProfileByID(id ipn.ProfileID) (_ ipn.LoginProfileView, changed bool, err error) { + if id == pm.currentProfile.ID() { + return pm.currentProfile, false, nil } - prefs, err := pm.loadSavedPrefs(kp.Key) + profile, err := pm.ProfileByID(id) if err != nil { - return err + return pm.currentProfile, false, err } - pm.prefs = prefs - pm.updateHealth() - pm.currentProfile = kp - return pm.setProfileAsUserDefault(kp) + return pm.SwitchToProfile(profile) } -// SwitchToDefaultProfile switches to the default (last used) profile for the current user. -// It creates a new one and switches to it if the current user does not have a default profile, +// SwitchToDefaultProfileForUser switches to the default (last used) profile for the specified user. +// It creates a new one and switches to it if the specified user does not have a default profile, // or returns an error if the default profile is inaccessible or could not be loaded. -func (pm *profileManager) SwitchToDefaultProfile() error { - if id := pm.DefaultUserProfileID(pm.currentUserID); id != "" { - return pm.SwitchProfile(id) - } - pm.NewProfileForUser(pm.currentUserID) - return nil +func (pm *profileManager) SwitchToDefaultProfileForUser(uid ipn.WindowsUserID) (_ ipn.LoginProfileView, changed bool, err error) { + return pm.SwitchToProfile(pm.DefaultUserProfile(uid)) +} + +// SwitchToDefaultProfile is like [profileManager.SwitchToDefaultProfileForUser], but switches +// to the default profile for the current user. +func (pm *profileManager) SwitchToDefaultProfile() (_ ipn.LoginProfileView, changed bool, err error) { + return pm.SwitchToDefaultProfileForUser(pm.currentUserID) } // setProfileAsUserDefault sets the specified profile as the default for the current user. // It returns an [errProfileAccessDenied] if the specified profile is not accessible to the current user. -func (pm *profileManager) setProfileAsUserDefault(profile *ipn.LoginProfile) error { - if profile.Key == "" { +func (pm *profileManager) setProfileAsUserDefault(profile ipn.LoginProfileView) error { + if profile.Key() == "" { // The profile has not been persisted yet; ignore it for now. return nil } @@ -468,11 +645,11 @@ func (pm *profileManager) setProfileAsUserDefault(profile *ipn.LoginProfile) err return errProfileAccessDenied } k := ipn.CurrentProfileKey(string(pm.currentUserID)) - return pm.WriteState(k, []byte(profile.Key)) + return pm.WriteState(k, []byte(profile.Key())) } -func (pm *profileManager) loadSavedPrefs(key ipn.StateKey) (ipn.PrefsView, error) { - bs, err := pm.store.ReadState(key) +func (pm *profileManager) loadSavedPrefs(k ipn.StateKey) (ipn.PrefsView, error) { + bs, err := pm.store.ReadState(k) if err == ipn.ErrStateNotExist || len(bs) == 0 { return defaultPrefs, nil } @@ -480,10 +657,18 @@ func (pm *profileManager) loadSavedPrefs(key ipn.StateKey) (ipn.PrefsView, error return ipn.PrefsView{}, err } savedPrefs := ipn.NewPrefs() + + // if supported by the platform, create an empty hardware attestation key to use when deserializing + // to avoid type exceptions from json.Unmarshaling into an interface{}. + hw, _ := key.NewEmptyHardwareAttestationKey() + savedPrefs.Persist = &persist.Persist{ + AttestationKey: hw, + } + if err := ipn.PrefsFromBytes(bs, savedPrefs); err != nil { return ipn.PrefsView{}, fmt.Errorf("parsing saved prefs: %v", err) } - pm.logf("using backend prefs for %q: %v", key, savedPrefs.Pretty()) + pm.logf("using backend prefs for %q: %v", k, savedPrefs.Pretty()) // Ignore any old stored preferences for https://login.tailscale.com // as the control server that would override the new default of @@ -500,17 +685,17 @@ func (pm *profileManager) loadSavedPrefs(key ipn.StateKey) (ipn.PrefsView, error // cause any EditPrefs calls to fail (other than disabling auto-updates). // // Reset AutoUpdate.Apply if we detect such invalid prefs. - if savedPrefs.AutoUpdate.Apply.EqualBool(true) && !clientupdate.CanAutoUpdate() { + if savedPrefs.AutoUpdate.Apply.EqualBool(true) && !feature.CanAutoUpdate() { savedPrefs.AutoUpdate.Apply.Clear() } return savedPrefs.View(), nil } -// CurrentProfile returns the current LoginProfile. +// CurrentProfile returns a read-only [ipn.LoginProfileView] of the current profile. // The value may be zero if the profile is not persisted. -func (pm *profileManager) CurrentProfile() ipn.LoginProfile { - return *pm.currentProfile +func (pm *profileManager) CurrentProfile() ipn.LoginProfileView { + return pm.currentProfile } // errProfileNotFound is returned by methods that accept a ProfileID @@ -532,8 +717,7 @@ var errProfileAccessDenied = errors.New("profile access denied") // This is useful for deleting the last profile. In other cases, it is // recommended to call [profileManager.SwitchProfile] first. func (pm *profileManager) DeleteProfile(id ipn.ProfileID) error { - metricDeleteProfile.Add(1) - if id == pm.currentProfile.ID { + if id == pm.currentProfile.ID() { return pm.deleteCurrentProfile() } kp, ok := pm.knownProfiles[id] @@ -550,9 +734,9 @@ func (pm *profileManager) deleteCurrentProfile() error { if err := pm.checkProfileAccess(pm.currentProfile); err != nil { return err } - if pm.currentProfile.ID == "" { + if pm.currentProfile.ID() == "" { // Deleting the in-memory only new profile, just create a new one. - pm.NewProfile() + pm.SwitchToNewProfile() return nil } return pm.deleteProfileNoPermCheck(pm.currentProfile) @@ -560,14 +744,15 @@ func (pm *profileManager) deleteCurrentProfile() error { // deleteProfileNoPermCheck is like [profileManager.DeleteProfile], // but it doesn't check user's access rights to the profile. -func (pm *profileManager) deleteProfileNoPermCheck(profile *ipn.LoginProfile) error { - if profile.ID == pm.currentProfile.ID { - pm.NewProfile() +func (pm *profileManager) deleteProfileNoPermCheck(profile ipn.LoginProfileView) error { + if profile.ID() == pm.currentProfile.ID() { + pm.SwitchToNewProfile() } - if err := pm.WriteState(profile.Key, nil); err != nil { + if err := pm.WriteState(profile.Key(), nil); err != nil { return err } - delete(pm.knownProfiles, profile.ID) + delete(pm.knownProfiles, profile.ID()) + metricDeleteProfile.Add(1) return pm.writeKnownProfiles() } @@ -578,8 +763,8 @@ func (pm *profileManager) DeleteAllProfilesForUser() error { currentProfileDeleted := false writeKnownProfiles := func() error { - if currentProfileDeleted || pm.currentProfile.ID == "" { - pm.NewProfile() + if currentProfileDeleted || pm.currentProfile.ID() == "" { + pm.SwitchToNewProfile() } return pm.writeKnownProfiles() } @@ -589,14 +774,14 @@ func (pm *profileManager) DeleteAllProfilesForUser() error { // Skip profiles we don't have access to. continue } - if err := pm.WriteState(kp.Key, nil); err != nil { + if err := pm.WriteState(kp.Key(), nil); err != nil { // Write to remove references to profiles we've already deleted, but // return the original error. writeKnownProfiles() return err } - delete(pm.knownProfiles, kp.ID) - if kp.ID == pm.currentProfile.ID { + delete(pm.knownProfiles, kp.ID()) + if kp.ID() == pm.currentProfile.ID() { currentProfileDeleted = true } } @@ -608,6 +793,7 @@ func (pm *profileManager) writeKnownProfiles() error { if err != nil { return err } + metricProfileCount.Set(int64(len(pm.knownProfiles))) return pm.WriteState(ipn.KnownProfilesStateKey, b) } @@ -618,44 +804,25 @@ func (pm *profileManager) updateHealth() { pm.health.SetAutoUpdatePrefs(pm.prefs.AutoUpdate().Check, pm.prefs.AutoUpdate().Apply) } -// NewProfile creates and switches to a new unnamed profile. The new profile is +// SwitchToNewProfile creates and switches to a new unnamed profile. The new profile is // not persisted until [profileManager.SetPrefs] is called with a logged-in user. -func (pm *profileManager) NewProfile() { - pm.NewProfileForUser(pm.currentUserID) +func (pm *profileManager) SwitchToNewProfile() { + pm.SwitchToNewProfileForUser(pm.currentUserID) } -// NewProfileForUser is like [profileManager.NewProfile], but it switches to the +// SwitchToNewProfileForUser is like [profileManager.SwitchToNewProfile], but it switches to the // specified user and sets that user as the profile owner for the new profile. -func (pm *profileManager) NewProfileForUser(uid ipn.WindowsUserID) { - pm.currentUserID = uid - - metricNewProfile.Add(1) - - pm.prefs = defaultPrefs - pm.updateHealth() - pm.currentProfile = &ipn.LoginProfile{LocalUserID: uid} +func (pm *profileManager) SwitchToNewProfileForUser(uid ipn.WindowsUserID) { + pm.SwitchToProfile(pm.NewProfileForUser(uid)) } -// newProfileWithPrefs creates a new profile with the specified prefs and assigns -// the specified uid as the profile owner. If switchNow is true, it switches to the -// newly created profile immediately. It returns the newly created profile on success, -// or an error on failure. -func (pm *profileManager) newProfileWithPrefs(uid ipn.WindowsUserID, prefs ipn.PrefsView, switchNow bool) (*ipn.LoginProfile, error) { - metricNewProfile.Add(1) +// zeroProfile is a read-only view of a new, empty profile that is not persisted to the store. +var zeroProfile = (&ipn.LoginProfile{}).View() - profile := &ipn.LoginProfile{LocalUserID: uid} - if err := pm.SetProfilePrefs(profile, prefs, ipn.NetworkProfile{}); err != nil { - return nil, err - } - if switchNow { - pm.currentProfile = profile - pm.prefs = prefs.AsStruct().View() - pm.updateHealth() - if err := pm.setProfileAsUserDefault(profile); err != nil { - return nil, err - } - } - return profile, nil +// NewProfileForUser creates a new profile for the specified user and returns a read-only view of it. +// It neither switches to the new profile nor persists it to the store. +func (pm *profileManager) NewProfileForUser(uid ipn.WindowsUserID) ipn.LoginProfileView { + return (&ipn.LoginProfile{LocalUserID: uid}).View() } // defaultPrefs is the default prefs for a new profile. This initializes before @@ -683,7 +850,10 @@ func (pm *profileManager) CurrentPrefs() ipn.PrefsView { // ReadStartupPrefsForTest reads the startup prefs from disk. It is only used for testing. func ReadStartupPrefsForTest(logf logger.Logf, store ipn.StateStore) (ipn.PrefsView, error) { - ht := new(health.Tracker) // in tests, don't care about the health status + testenv.AssertInTest() + bus := eventbus.New() + defer bus.Close() + ht := health.NewTracker(bus) // in tests, don't care about the health status pm, err := newProfileManager(store, logf, ht) if err != nil { return ipn.PrefsView{}, err @@ -711,8 +881,8 @@ func readAutoStartKey(store ipn.StateStore, goos string) (ipn.StateKey, error) { return ipn.StateKey(autoStartKey), nil } -func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]*ipn.LoginProfile, error) { - var knownProfiles map[ipn.ProfileID]*ipn.LoginProfile +func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]ipn.LoginProfileView, error) { + var knownProfiles map[ipn.ProfileID]ipn.LoginProfileView prfB, err := store.ReadState(ipn.KnownProfilesStateKey) switch err { case nil: @@ -720,7 +890,7 @@ func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]*ipn.LoginProfil return nil, fmt.Errorf("unmarshaling known profiles: %w", err) } case ipn.ErrStateNotExist: - knownProfiles = make(map[ipn.ProfileID]*ipn.LoginProfile) + knownProfiles = make(map[ipn.ProfileID]ipn.LoginProfileView) default: return nil, fmt.Errorf("calling ReadState on state store: %w", err) } @@ -739,6 +909,8 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt return nil, err } + metricProfileCount.Set(int64(len(knownProfiles))) + pm := &profileManager{ goos: goos, store: store, @@ -747,27 +919,9 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt health: ht, } + var initialProfile ipn.LoginProfileView if stateKey != "" { - for _, v := range knownProfiles { - if v.Key == stateKey { - pm.currentProfile = v - } - } - if pm.currentProfile == nil { - if suf, ok := strings.CutPrefix(string(stateKey), "user-"); ok { - pm.currentUserID = ipn.WindowsUserID(suf) - } - pm.NewProfile() - } else { - pm.currentUserID = pm.currentProfile.LocalUserID - } - prefs, err := pm.loadSavedPrefs(stateKey) - if err != nil { - return nil, err - } - if err := pm.setProfilePrefsNoPermCheck(pm.currentProfile, prefs); err != nil { - return nil, err - } + initialProfile = pm.findProfileByKey("", stateKey) // Most platform behavior is controlled by the goos parameter, however // some behavior is implied by build tag and fails when run on Windows, // so we explicitly avoid that behavior when running on Windows. @@ -778,28 +932,35 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt } else if len(knownProfiles) == 0 && goos != "windows" && runtime.GOOS != "windows" { // No known profiles, try a migration. pm.dlogf("no known profiles; trying to migrate from legacy prefs") - if _, err := pm.migrateFromLegacyPrefs(pm.currentUserID, true); err != nil { - return nil, err + if initialProfile, err = pm.migrateFromLegacyPrefs(pm.currentUserID); err != nil { + } - } else { - pm.NewProfile() } - + if !initialProfile.Valid() { + var initialUserID ipn.WindowsUserID + if suf, ok := strings.CutPrefix(string(stateKey), "user-"); ok { + initialUserID = ipn.WindowsUserID(suf) + } + initialProfile = pm.NewProfileForUser(initialUserID) + } + if _, _, err := pm.SwitchToProfile(initialProfile); err != nil { + return nil, err + } return pm, nil } -func (pm *profileManager) migrateFromLegacyPrefs(uid ipn.WindowsUserID, switchNow bool) (*ipn.LoginProfile, error) { +func (pm *profileManager) migrateFromLegacyPrefs(uid ipn.WindowsUserID) (ipn.LoginProfileView, error) { metricMigration.Add(1) sentinel, prefs, err := pm.loadLegacyPrefs(uid) if err != nil { metricMigrationError.Add(1) - return nil, fmt.Errorf("load legacy prefs: %w", err) + return ipn.LoginProfileView{}, fmt.Errorf("load legacy prefs: %w", err) } pm.dlogf("loaded legacy preferences; sentinel=%q", sentinel) - profile, err := pm.newProfileWithPrefs(uid, prefs, switchNow) + profile, err := pm.setProfilePrefs(&ipn.LoginProfile{LocalUserID: uid}, prefs, ipn.NetworkProfile{}) if err != nil { metricMigrationError.Add(1) - return nil, fmt.Errorf("migrating _daemon profile: %w", err) + return ipn.LoginProfileView{}, fmt.Errorf("migrating _daemon profile: %w", err) } pm.completeMigration(sentinel) pm.dlogf("completed legacy preferences migration with sentinel=%q", sentinel) @@ -809,8 +970,8 @@ func (pm *profileManager) migrateFromLegacyPrefs(uid ipn.WindowsUserID, switchNo func (pm *profileManager) requiresBackfill() bool { return pm != nil && - pm.currentProfile != nil && - pm.currentProfile.NetworkProfile.RequiresBackfill() + pm.currentProfile.Valid() && + pm.currentProfile.NetworkProfile().RequiresBackfill() } var ( @@ -818,6 +979,7 @@ var ( metricSwitchProfile = clientmetric.NewCounter("profiles_switch") metricDeleteProfile = clientmetric.NewCounter("profiles_delete") metricDeleteAllProfile = clientmetric.NewCounter("profiles_delete_all") + metricProfileCount = clientmetric.NewGauge("profiles_count") metricMigration = clientmetric.NewCounter("profiles_migration") metricMigrationError = clientmetric.NewCounter("profiles_migration_error") diff --git a/ipn/ipnlocal/profiles_test.go b/ipn/ipnlocal/profiles_test.go index 73e4f6535..95834284e 100644 --- a/ipn/ipnlocal/profiles_test.go +++ b/ipn/ipnlocal/profiles_test.go @@ -7,11 +7,13 @@ import ( "fmt" "os/user" "strconv" + "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "tailscale.com/clientupdate" + _ "tailscale.com/clientupdate" // for feature registration side effects + "tailscale.com/feature" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" @@ -19,13 +21,14 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/persist" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" ) func TestProfileCurrentUserSwitch(t *testing.T) { store := new(mem.Store) - pm, err := newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err := newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } @@ -33,7 +36,7 @@ func TestProfileCurrentUserSwitch(t *testing.T) { newProfile := func(t *testing.T, loginName string) ipn.PrefsView { id++ t.Helper() - pm.NewProfile() + pm.SwitchToNewProfile() p := pm.CurrentPrefs().AsStruct() p.Persist = &persist.Persist{ NodeID: tailcfg.StableNodeID(fmt.Sprint(id)), @@ -52,25 +55,25 @@ func TestProfileCurrentUserSwitch(t *testing.T) { pm.SetCurrentUserID("user1") newProfile(t, "user1") cp := pm.currentProfile - pm.DeleteProfile(cp.ID) - if pm.currentProfile == nil { + pm.DeleteProfile(cp.ID()) + if !pm.currentProfile.Valid() { t.Fatal("currentProfile is nil") - } else if pm.currentProfile.ID != "" { - t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID) + } else if pm.currentProfile.ID() != "" { + t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID()) } if !pm.CurrentPrefs().Equals(defaultPrefs) { t.Fatalf("CurrentPrefs() = %v, want emptyPrefs", pm.CurrentPrefs().Pretty()) } - pm, err = newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err = newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } pm.SetCurrentUserID("user1") - if pm.currentProfile == nil { + if !pm.currentProfile.Valid() { t.Fatal("currentProfile is nil") - } else if pm.currentProfile.ID != "" { - t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID) + } else if pm.currentProfile.ID() != "" { + t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID()) } if !pm.CurrentPrefs().Equals(defaultPrefs) { t.Fatalf("CurrentPrefs() = %v, want emptyPrefs", pm.CurrentPrefs().Pretty()) @@ -80,7 +83,7 @@ func TestProfileCurrentUserSwitch(t *testing.T) { func TestProfileList(t *testing.T) { store := new(mem.Store) - pm, err := newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err := newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } @@ -88,7 +91,7 @@ func TestProfileList(t *testing.T) { newProfile := func(t *testing.T, loginName string) ipn.PrefsView { id++ t.Helper() - pm.NewProfile() + pm.SwitchToNewProfile() p := pm.CurrentPrefs().AsStruct() p.Persist = &persist.Persist{ NodeID: tailcfg.StableNodeID(fmt.Sprint(id)), @@ -110,8 +113,8 @@ func TestProfileList(t *testing.T) { t.Fatalf("got %d profiles, want %d", len(got), len(want)) } for i, w := range want { - if got[i].Name != w { - t.Errorf("got profile %d name %q, want %q", i, got[i].Name, w) + if got[i].Name() != w { + t.Errorf("got profile %d name %q, want %q", i, got[i].Name(), w) } } } @@ -129,10 +132,10 @@ func TestProfileList(t *testing.T) { pm.SetCurrentUserID("user1") checkProfiles(t, "alice", "bob") - if lp := pm.findProfileByKey(carol.Key); lp != nil { + if lp := pm.findProfileByKey("user1", carol.Key()); lp.Valid() { t.Fatalf("found profile for user2 in user1's profile list") } - if lp := pm.findProfileByName(carol.Name); lp != nil { + if lp := pm.findProfileByName("user1", carol.Name()); lp.Valid() { t.Fatalf("found profile for user2 in user1's profile list") } @@ -148,6 +151,7 @@ func TestProfileDupe(t *testing.T) { ID: tailcfg.UserID(user), LoginName: fmt.Sprintf("user%d@example.com", user), }, + AttestationKey: nil, } } user1Node1 := newPersist(1, 1) @@ -162,7 +166,7 @@ func TestProfileDupe(t *testing.T) { must.Do(pm.SetPrefs(prefs.View(), ipn.NetworkProfile{})) } login := func(pm *profileManager, p *persist.Persist) { - pm.NewProfile() + pm.SwitchToNewProfile() reauth(pm, p) } @@ -284,7 +288,7 @@ func TestProfileDupe(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { store := new(mem.Store) - pm, err := newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err := newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } @@ -294,7 +298,7 @@ func TestProfileDupe(t *testing.T) { profs := pm.Profiles() var got []*persist.Persist for _, p := range profs { - prefs, err := pm.loadSavedPrefs(p.Key) + prefs, err := pm.loadSavedPrefs(p.Key()) if err != nil { t.Fatal(err) } @@ -317,7 +321,7 @@ func TestProfileDupe(t *testing.T) { func TestProfileManagement(t *testing.T) { store := new(mem.Store) - pm, err := newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err := newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } @@ -328,9 +332,9 @@ func TestProfileManagement(t *testing.T) { checkProfiles := func(t *testing.T) { t.Helper() prof := pm.CurrentProfile() - t.Logf("\tCurrentProfile = %q", prof) - if prof.Name != wantCurProfile { - t.Fatalf("CurrentProfile = %q; want %q", prof, wantCurProfile) + t.Logf("\tCurrentProfile = %q", prof.Name()) + if prof.Name() != wantCurProfile { + t.Fatalf("CurrentProfile = %q; want %q", prof.Name(), wantCurProfile) } profiles := pm.Profiles() wantLen := len(wantProfiles) @@ -349,13 +353,13 @@ func TestProfileManagement(t *testing.T) { t.Fatalf("CurrentPrefs = %v; want %v", p.Pretty(), wantProfiles[wantCurProfile].Pretty()) } for _, p := range profiles { - got, err := pm.loadSavedPrefs(p.Key) + got, err := pm.loadSavedPrefs(p.Key()) if err != nil { t.Fatal(err) } // Use Hostname as a proxy for all prefs. - if !got.Equals(wantProfiles[p.Name]) { - t.Fatalf("Prefs for profile %q =\n got=%+v\nwant=%v", p, got.Pretty(), wantProfiles[p.Name].Pretty()) + if !got.Equals(wantProfiles[p.Name()]) { + t.Fatalf("Prefs for profile %q =\n got=%+v\nwant=%v", p.Name(), got.Pretty(), wantProfiles[p.Name()].Pretty()) } } } @@ -399,7 +403,7 @@ func TestProfileManagement(t *testing.T) { checkProfiles(t) t.Logf("Create new profile") - pm.NewProfile() + pm.SwitchToNewProfile() wantCurProfile = "" wantProfiles[""] = defaultPrefs checkProfiles(t) @@ -415,14 +419,14 @@ func TestProfileManagement(t *testing.T) { t.Logf("Recreate profile manager from store") // Recreate the profile manager to ensure that it can load the profiles // from the store at startup. - pm, err = newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err = newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } checkProfiles(t) t.Logf("Delete default profile") - if err := pm.DeleteProfile(pm.findProfileByName("user@1.example.com").ID); err != nil { + if err := pm.DeleteProfile(pm.ProfileIDForName("user@1.example.com")); err != nil { t.Fatal(err) } delete(wantProfiles, "user@1.example.com") @@ -431,14 +435,14 @@ func TestProfileManagement(t *testing.T) { t.Logf("Recreate profile manager from store after deleting default profile") // Recreate the profile manager to ensure that it can load the profiles // from the store at startup. - pm, err = newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err = newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } checkProfiles(t) t.Logf("Create new profile - 2") - pm.NewProfile() + pm.SwitchToNewProfile() wantCurProfile = "" wantProfiles[""] = defaultPrefs checkProfiles(t) @@ -462,7 +466,7 @@ func TestProfileManagement(t *testing.T) { wantCurProfile = "user@2.example.com" checkProfiles(t) - if !clientupdate.CanAutoUpdate() { + if !feature.CanAutoUpdate() { t.Logf("Save an invalid AutoUpdate pref value") prefs := pm.CurrentPrefs().AsStruct() prefs.AutoUpdate.Apply.Set(true) @@ -473,7 +477,7 @@ func TestProfileManagement(t *testing.T) { t.Fatal("SetPrefs failed to save auto-update setting") } // Re-load profiles to trigger migration for invalid auto-update value. - pm, err = newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + pm, err = newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") if err != nil { t.Fatal(err) } @@ -495,7 +499,7 @@ func TestProfileManagementWindows(t *testing.T) { store := new(mem.Store) - pm, err := newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "windows") + pm, err := newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "windows") if err != nil { t.Fatal(err) } @@ -506,9 +510,9 @@ func TestProfileManagementWindows(t *testing.T) { checkProfiles := func(t *testing.T) { t.Helper() prof := pm.CurrentProfile() - t.Logf("\tCurrentProfile = %q", prof) - if prof.Name != wantCurProfile { - t.Fatalf("CurrentProfile = %q; want %q", prof, wantCurProfile) + t.Logf("\tCurrentProfile = %q", prof.Name()) + if prof.Name() != wantCurProfile { + t.Fatalf("CurrentProfile = %q; want %q", prof.Name(), wantCurProfile) } if p := pm.CurrentPrefs(); !p.Equals(wantProfiles[wantCurProfile]) { t.Fatalf("CurrentPrefs = %+v; want %+v", p.Pretty(), wantProfiles[wantCurProfile].Pretty()) @@ -550,7 +554,7 @@ func TestProfileManagementWindows(t *testing.T) { { t.Logf("Create new profile") - pm.NewProfile() + pm.SwitchToNewProfile() wantCurProfile = "" wantProfiles[""] = defaultPrefs checkProfiles(t) @@ -564,7 +568,7 @@ func TestProfileManagementWindows(t *testing.T) { t.Logf("Recreate profile manager from store, should reset prefs") // Recreate the profile manager to ensure that it can load the profiles // from the store at startup. - pm, err = newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "windows") + pm, err = newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "windows") if err != nil { t.Fatal(err) } @@ -587,7 +591,7 @@ func TestProfileManagementWindows(t *testing.T) { } // Recreate the profile manager to ensure that it starts with test profile. - pm, err = newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "windows") + pm, err = newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "windows") if err != nil { t.Fatal(err) } @@ -609,3 +613,537 @@ func TestDefaultPrefs(t *testing.T) { t.Errorf("defaultPrefs is %s, want %s; defaultPrefs should only modify WantRunning and LoggedOut, all other defaults should be in ipn.NewPrefs.", p2.Pretty(), p1.Pretty()) } } + +// mutPrefsFn is a function that mutates the prefs. +// Deserialization pre‑populates prefs with default (non‑zero) values. +// After saving prefs and reading them back, we may not get exactly what we set. +// For this reason, tests apply changes through a helper that mutates +// [ipn.NewPrefs] instead of hard‑coding expected values in each case. +type mutPrefsFn func(*ipn.Prefs) + +type profileState struct { + *ipn.LoginProfile + mutPrefs mutPrefsFn +} + +func (s *profileState) prefs() ipn.PrefsView { + prefs := ipn.NewPrefs() // apply changes to the default prefs + s.mutPrefs(prefs) + return prefs.View() +} + +type profileStateChange struct { + *ipn.LoginProfile + mutPrefs mutPrefsFn + sameNode bool +} + +func wantProfileChange(state profileState) profileStateChange { + return profileStateChange{ + LoginProfile: state.LoginProfile, + mutPrefs: state.mutPrefs, + sameNode: false, + } +} + +func wantPrefsChange(state profileState) profileStateChange { + return profileStateChange{ + LoginProfile: state.LoginProfile, + mutPrefs: state.mutPrefs, + sameNode: true, + } +} + +func makeDefaultPrefs(p *ipn.Prefs) { *p = *defaultPrefs.AsStruct() } + +func makeKnownProfileState(id int, nameSuffix string, uid ipn.WindowsUserID, mutPrefs mutPrefsFn) profileState { + lowerNameSuffix := strings.ToLower(nameSuffix) + nid := "node-" + tailcfg.StableNodeID(lowerNameSuffix) + up := tailcfg.UserProfile{ + ID: tailcfg.UserID(id), + LoginName: fmt.Sprintf("user-%s@example.com", lowerNameSuffix), + DisplayName: "User " + nameSuffix, + } + return profileState{ + LoginProfile: &ipn.LoginProfile{ + LocalUserID: uid, + Name: up.LoginName, + ID: ipn.ProfileID(fmt.Sprintf("%04X", id)), + Key: "profile-" + ipn.StateKey(nameSuffix), + NodeID: nid, + UserProfile: up, + }, + mutPrefs: func(p *ipn.Prefs) { + p.Hostname = "Hostname-" + nameSuffix + if mutPrefs != nil { + mutPrefs(p) // apply any additional changes + } + p.Persist = &persist.Persist{NodeID: nid, UserProfile: up} + }, + } +} + +func TestProfileStateChangeCallback(t *testing.T) { + t.Parallel() + + // A few well-known profiles to use in tests. + emptyProfile := profileState{ + LoginProfile: &ipn.LoginProfile{}, + mutPrefs: makeDefaultPrefs, + } + profile0000 := profileState{ + LoginProfile: &ipn.LoginProfile{ID: "0000", Key: "profile-0000"}, + mutPrefs: makeDefaultPrefs, + } + profileA := makeKnownProfileState(0xA, "A", "", nil) + profileB := makeKnownProfileState(0xB, "B", "", nil) + profileC := makeKnownProfileState(0xC, "C", "", nil) + + aliceUserID := ipn.WindowsUserID("S-1-5-21-1-2-3-4") + aliceEmptyProfile := profileState{ + LoginProfile: &ipn.LoginProfile{LocalUserID: aliceUserID}, + mutPrefs: makeDefaultPrefs, + } + bobUserID := ipn.WindowsUserID("S-1-5-21-3-4-5-6") + bobEmptyProfile := profileState{ + LoginProfile: &ipn.LoginProfile{LocalUserID: bobUserID}, + mutPrefs: makeDefaultPrefs, + } + bobKnownProfile := makeKnownProfileState(0xB0B, "Bob", bobUserID, nil) + + tests := []struct { + name string + initial *profileState // if non-nil, this is the initial profile and prefs to start wit + knownProfiles []profileState // known profiles we can switch to + action func(*profileManager) // action to take on the profile manager + wantChanges []profileStateChange // expected state changes + }{ + { + name: "no-changes", + action: func(*profileManager) { + // do nothing + }, + wantChanges: nil, + }, + { + name: "no-initial/new-profile", + action: func(pm *profileManager) { + // The profile manager is new and started with a new empty profile. + // This should not trigger a state change callback. + pm.SwitchToNewProfile() + }, + wantChanges: nil, + }, + { + name: "no-initial/new-profile-for-user", + action: func(pm *profileManager) { + // But switching to a new profile for a specific user should trigger + // a state change callback. + pm.SwitchToNewProfileForUser(aliceUserID) + }, + wantChanges: []profileStateChange{ + // We want a new empty profile (owned by the specified user) + // and the default prefs. + wantProfileChange(aliceEmptyProfile), + }, + }, + { + name: "with-initial/new-profile", + initial: &profile0000, + action: func(pm *profileManager) { + // And so does switching to a new profile when the initial profile + // is non-empty. + pm.SwitchToNewProfile() + }, + wantChanges: []profileStateChange{ + // We want a new empty profile and the default prefs. + wantProfileChange(emptyProfile), + }, + }, + { + name: "with-initial/new-profile/twice", + initial: &profile0000, + action: func(pm *profileManager) { + // If we switch to a new profile twice, we should only get one state change. + pm.SwitchToNewProfile() + pm.SwitchToNewProfile() + }, + wantChanges: []profileStateChange{ + // We want a new empty profile and the default prefs. + wantProfileChange(emptyProfile), + }, + }, + { + name: "with-initial/new-profile-for-user/twice", + initial: &profile0000, + action: func(pm *profileManager) { + // Unless we switch to a new profile for a specific user, + // in which case we should get a state change twice. + pm.SwitchToNewProfileForUser(aliceUserID) + pm.SwitchToNewProfileForUser(aliceUserID) // no change here + pm.SwitchToNewProfileForUser(bobUserID) + }, + wantChanges: []profileStateChange{ + // Both profiles are empty, but they are owned by different users. + wantProfileChange(aliceEmptyProfile), + wantProfileChange(bobEmptyProfile), + }, + }, + { + name: "with-initial/new-profile/twice/with-prefs-change", + initial: &profile0000, + action: func(pm *profileManager) { + // Or unless we switch to a new profile, change the prefs, + // then switch to a new profile again. Since the current + // profile is not empty after the prefs change, we should + // get state changes for all three actions. + pm.SwitchToNewProfile() + p := pm.CurrentPrefs().AsStruct() + p.WantRunning = true + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + pm.SwitchToNewProfile() + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), // new empty profile + wantPrefsChange(profileState{ // prefs change, same profile + LoginProfile: &ipn.LoginProfile{}, + mutPrefs: func(p *ipn.Prefs) { + *p = *defaultPrefs.AsStruct() + p.WantRunning = true + }, + }), + wantProfileChange(emptyProfile), // new empty profile again + }, + }, + { + name: "switch-to-profile/by-id", + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Switching to a known profile by ID should trigger a state change callback. + pm.SwitchToProfileByID(profileB.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileB), + }, + }, + { + name: "switch-to-profile/by-id/non-existent", + knownProfiles: []profileState{profileA, profileC}, // no profileB + action: func(pm *profileManager) { + // Switching to a non-existent profile should fail and not trigger a state change callback. + pm.SwitchToProfileByID(profileB.ID) + }, + wantChanges: []profileStateChange{}, + }, + { + name: "switch-to-profile/by-id/twice-same", + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // But only for the first switch. + // The second switch to the same profile should not trigger a state change callback. + pm.SwitchToProfileByID(profileB.ID) + pm.SwitchToProfileByID(profileB.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileB), + }, + }, + { + name: "switch-to-profile/by-id/many", + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Same idea, but with multiple switches. + pm.SwitchToProfileByID(profileB.ID) // switch to Profile-B + pm.SwitchToProfileByID(profileB.ID) // then to Profile-B again (no change) + pm.SwitchToProfileByID(profileC.ID) // then to Profile-C (change) + pm.SwitchToProfileByID(profileA.ID) // then to Profile-A (change) + pm.SwitchToProfileByID(profileB.ID) // then to Profile-B (change) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileB), + wantProfileChange(profileC), + wantProfileChange(profileA), + wantProfileChange(profileB), + }, + }, + { + name: "switch-to-profile/by-view", + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Switching to a known profile by an [ipn.LoginProfileView] + // should also trigger a state change callback. + pm.SwitchToProfile(profileB.View()) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileB), + }, + }, + { + name: "switch-to-profile/by-view/empty", + initial: &profile0000, + action: func(pm *profileManager) { + // SwitchToProfile supports switching to an empty profile. + emptyProfile := &ipn.LoginProfile{} + pm.SwitchToProfile(emptyProfile.View()) + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), + }, + }, + { + name: "switch-to-profile/by-view/non-existent", + knownProfiles: []profileState{profileA, profileC}, + action: func(pm *profileManager) { + // Switching to a an unknown profile by an [ipn.LoginProfileView] + // should fail and not trigger a state change callback. + pm.SwitchToProfile(profileB.View()) + }, + wantChanges: []profileStateChange{}, + }, + { + name: "switch-to-profile/by-view/empty-for-user", + initial: &profile0000, + action: func(pm *profileManager) { + // And switching to an empty profile for a specific user also works. + pm.SwitchToProfile(bobEmptyProfile.View()) + }, + wantChanges: []profileStateChange{ + wantProfileChange(bobEmptyProfile), + }, + }, + { + name: "switch-to-profile/by-view/invalid", + initial: &profile0000, + action: func(pm *profileManager) { + // Switching to an invalid profile should create and switch + // to a new empty profile. + pm.SwitchToProfile(ipn.LoginProfileView{}) + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), + }, + }, + { + name: "delete-profile/current", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Deleting the current profile should switch to a new empty profile. + pm.DeleteProfile(profileA.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), + }, + }, + { + name: "delete-profile/current-with-user", + initial: &bobKnownProfile, + knownProfiles: []profileState{profileA, profileB, profileC, bobKnownProfile}, + action: func(pm *profileManager) { + // Similarly, deleting the current profile for a specific user should switch + // to a new empty profile for that user (at least while the "current user" + // is still a thing on Windows). + pm.DeleteProfile(bobKnownProfile.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(bobEmptyProfile), + }, + }, + { + name: "delete-profile/non-current", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // But deleting a non-current profile should not trigger a state change callback. + pm.DeleteProfile(profileB.ID) + }, + wantChanges: []profileStateChange{}, + }, + { + name: "set-prefs/new-profile", + initial: &emptyProfile, // the current profile is empty + action: func(pm *profileManager) { + // The current profile is new and empty, but we can still set p. + // This should trigger a state change callback. + p := pm.CurrentPrefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + // Still an empty profile, but with new prefs. + wantPrefsChange(profileState{ + LoginProfile: emptyProfile.LoginProfile, + mutPrefs: func(p *ipn.Prefs) { + *p = *emptyProfile.prefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + }, + }), + }, + }, + { + name: "set-prefs/current-profile", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + p := pm.CurrentPrefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + wantPrefsChange(profileState{ + LoginProfile: profileA.LoginProfile, // same profile + mutPrefs: func(p *ipn.Prefs) { // but with new prefs + *p = *profileA.prefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + }, + }), + }, + }, + { + name: "set-prefs/current-profile/profile-name", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + p := pm.CurrentPrefs().AsStruct() + p.ProfileName = "This is User A" + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + // Still the same profile, but with a new profile name + // populated from the prefs. The prefs are also updated. + wantPrefsChange(profileState{ + LoginProfile: func() *ipn.LoginProfile { + p := profileA.Clone() + p.Name = "This is User A" + return p + }(), + mutPrefs: func(p *ipn.Prefs) { + *p = *profileA.prefs().AsStruct() + p.ProfileName = "This is User A" + }, + }), + }, + }, + { + name: "set-prefs/implicit-switch/from-new", + initial: &emptyProfile, // a new, empty profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // The user attempted to add a new profile but actually logged in as the same + // node/user as profileB. When [LocalBackend.SetControlClientStatus] calls + // [profileManager.SetPrefs] with the [persist.Persist] for profileB, we + // implicitly switch to that profile instead of creating a duplicate for the + // same node/user. + // + // TODO(nickkhyl): currently, [LocalBackend.SetControlClientStatus] uses the p + // of the current profile, not those of the profile we switch to. This is all wrong + // and should be fixed. But for now, we just test that the state change callback + // is called with the new profile and p. + p := pm.CurrentPrefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + // Calling [profileManager.SetPrefs] like this is effectively a profile switch + // rather than a prefs change. + wantProfileChange(profileState{ + LoginProfile: profileB.LoginProfile, + mutPrefs: func(p *ipn.Prefs) { + *p = *emptyProfile.prefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + }, + }), + }, + }, + { + name: "set-prefs/implicit-switch/from-other", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Same idea, but the current profile is profileA rather than a new empty profile. + // Note: this is all wrong. See the comment above and [profileManager.SetPrefs]. + p := pm.CurrentPrefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileState{ + LoginProfile: profileB.LoginProfile, + mutPrefs: func(p *ipn.Prefs) { + *p = *profileA.prefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + }, + }), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store := new(mem.Store) + pm, err := newProfileManagerWithGOOS(store, logger.Discard, health.NewTracker(eventbustest.NewBus(t)), "linux") + if err != nil { + t.Fatalf("newProfileManagerWithGOOS: %v", err) + } + for _, p := range tt.knownProfiles { + pm.writePrefsToStore(p.Key, p.prefs()) + pm.knownProfiles[p.ID] = p.View() + } + if err := pm.writeKnownProfiles(); err != nil { + t.Fatalf("writeKnownProfiles: %v", err) + } + + if tt.initial != nil { + pm.currentUserID = tt.initial.LocalUserID + pm.currentProfile = tt.initial.View() + pm.prefs = tt.initial.prefs() + } + + type stateChange struct { + Profile *ipn.LoginProfile + Prefs *ipn.Prefs + SameNode bool + } + wantChanges := make([]stateChange, 0, len(tt.wantChanges)) + for _, w := range tt.wantChanges { + wantPrefs := ipn.NewPrefs() + w.mutPrefs(wantPrefs) // apply changes to the default prefs + wantChanges = append(wantChanges, stateChange{ + Profile: w.LoginProfile, + Prefs: wantPrefs, + SameNode: w.sameNode, + }) + } + + gotChanges := make([]stateChange, 0, len(tt.wantChanges)) + pm.StateChangeHook = func(profile ipn.LoginProfileView, prefView ipn.PrefsView, sameNode bool) { + prefs := prefView.AsStruct() + prefs.Sync = prefs.Sync.Normalized() + gotChanges = append(gotChanges, stateChange{ + Profile: profile.AsStruct(), + Prefs: prefs, + SameNode: sameNode, + }) + } + + tt.action(pm) + + if diff := cmp.Diff(wantChanges, gotChanges, defaultCmpOpts...); diff != "" { + t.Errorf("StateChange callbacks: (-want +got): %v", diff) + } + }) + } +} diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 67d521f09..b5118873b 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -1,6 +1,10 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_serve + +// TODO: move this whole file to its own package, out of ipnlocal. + package ipnlocal import ( @@ -12,6 +16,7 @@ import ( "errors" "fmt" "io" + "maps" "mime" "net" "net/http" @@ -28,19 +33,39 @@ import ( "time" "unicode/utf8" - "golang.org/x/net/http2" + "github.com/pires/go-proxyproto" + "go4.org/mem" "tailscale.com/ipn" - "tailscale.com/logtail/backoff" "tailscale.com/net/netutil" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/lazy" "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/backoff" + "tailscale.com/util/clientmetric" "tailscale.com/util/ctxkey" "tailscale.com/util/mak" + "tailscale.com/util/slicesx" "tailscale.com/version" ) +func init() { + hookServeTCPHandlerForVIPService.Set((*LocalBackend).tcpHandlerForVIPService) + hookTCPHandlerForServe.Set((*LocalBackend).tcpHandlerForServe) + hookServeUpdateServeTCPPortNetMapAddrListenersLocked.Set((*LocalBackend).updateServeTCPPortNetMapAddrListenersLocked) + + hookServeSetTCPPortsInterceptedFromNetmapAndPrefsLocked.Set(serveSetTCPPortsInterceptedFromNetmapAndPrefsLocked) + hookServeClearVIPServicesTCPPortsInterceptedLocked.Set(func(b *LocalBackend) { + b.setVIPServicesTCPPortsInterceptedLocked(nil) + }) + + hookMaybeMutateHostinfoLocked.Add(maybeUpdateHostinfoServicesHashLocked) + hookMaybeMutateHostinfoLocked.Add(maybeUpdateHostinfoFunnelLocked) + + RegisterC2N("GET /vip-services", handleC2NVIPServicesGet) +} + const ( contentTypeHeader = "Content-Type" grpcBaseContentType = "application/grpc" @@ -54,11 +79,14 @@ var ErrETagMismatch = errors.New("etag mismatch") var serveHTTPContextKey ctxkey.Key[*serveHTTPContext] type serveHTTPContext struct { - SrcAddr netip.AddrPort - DestPort uint16 + SrcAddr netip.AddrPort + ForVIPService tailcfg.ServiceName // "" means local + DestPort uint16 // provides funnel-specific context, nil if not funneled Funnel *funnelFlow + // AppCapabilities lists all PeerCapabilities that should be forwarded by serve + AppCapabilities views.Slice[tailcfg.PeerCapability] } // funnelFlow represents a funneled connection initiated via IngressPeer @@ -221,6 +249,10 @@ func (s *localListener) handleListenersAccept(ln net.Listener) error { // // b.mu must be held. func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint16) { + if b.sys.IsNetstack() { + // don't listen on netmap addresses if we're in userspace mode + return + } // close existing listeners where port // is no longer in incoming ports list for ap, sl := range b.serveListeners { @@ -231,7 +263,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } } - nm := b.netMap + nm := b.NetMap() if nm == nil { b.logf("netMap is nil") return @@ -242,8 +274,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } addrs := nm.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) + for _, a := range addrs.All() { for _, p := range ports { addrPort := netip.AddrPortFrom(a.Addr(), p) if _, ok := b.serveListeners[addrPort]; ok { @@ -276,7 +307,13 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string return errors.New("can't reconfigure tailscaled when using a config file; config file is locked") } - nm := b.netMap + if config != nil { + if err := config.CheckValidServicesConfig(); err != nil { + return err + } + } + + nm := b.NetMap() if nm == nil { return errors.New("netMap is nil") } @@ -312,7 +349,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string bs = j } - profileID := b.pm.CurrentProfile().ID + profileID := b.pm.CurrentProfile().ID() confKey := ipn.ServeConfigKey(profileID) if err := b.store.WriteState(confKey, bs); err != nil { return fmt.Errorf("writing ServeConfig to StateStore: %w", err) @@ -327,7 +364,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string if b.serveConfig.Valid() { has = b.serveConfig.Foreground().Contains } - prevConfig.Foreground().Range(func(k string, v ipn.ServeConfigView) (cont bool) { + for k := range prevConfig.Foreground().All() { if !has(k) { for _, sess := range b.notifyWatchers { if sess.sessionID == k { @@ -335,8 +372,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string } } } - return true - }) + } } return nil @@ -434,6 +470,137 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer tailcfg.NodeView, target handler(c) } +func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcfg.VIPService { + // keyed by service name + var services map[tailcfg.ServiceName]*tailcfg.VIPService + if b.serveConfig.Valid() { + for svc, config := range b.serveConfig.Services().All() { + mak.Set(&services, svc, &tailcfg.VIPService{ + Name: svc, + Ports: config.ServicePortRange(), + }) + } + } + + for _, s := range prefs.AdvertiseServices().All() { + sn := tailcfg.ServiceName(s) + if services == nil || services[sn] == nil { + mak.Set(&services, sn, &tailcfg.VIPService{ + Name: sn, + }) + } + services[sn].Active = true + } + + servicesList := slicesx.MapValues(services) + // [slicesx.MapValues] provides the values in an indeterminate order, but since we'll + // be hashing a representation of this list later we want it to be in a consistent + // order. + slices.SortFunc(servicesList, func(a, b *tailcfg.VIPService) int { + return strings.Compare(a.Name.String(), b.Name.String()) + }) + return servicesList +} + +// tcpHandlerForVIPService returns a handler for a TCP connection to a VIP service +// that is being served via the ipn.ServeConfig. It returns nil if the destination +// address is not a VIP service or if the VIP service does not have a TCP handler set. +func (b *LocalBackend) tcpHandlerForVIPService(dstAddr, srcAddr netip.AddrPort) (handler func(net.Conn) error) { + b.mu.Lock() + sc := b.serveConfig + ipVIPServiceMap := b.ipVIPServiceMap + b.mu.Unlock() + + if !sc.Valid() { + return nil + } + + dport := dstAddr.Port() + + dstSvc, ok := ipVIPServiceMap[dstAddr.Addr()] + if !ok { + return nil + } + + tcph, ok := sc.FindServiceTCP(dstSvc, dstAddr.Port()) + if !ok { + b.logf("The destination service doesn't have a TCP handler set.") + return nil + } + + if tcph.HTTPS() || tcph.HTTP() { + hs := &http.Server{ + Handler: http.HandlerFunc(b.serveWebHandler), + BaseContext: func(_ net.Listener) context.Context { + return serveHTTPContextKey.WithValue(context.Background(), &serveHTTPContext{ + SrcAddr: srcAddr, + ForVIPService: dstSvc, + DestPort: dport, + }) + }, + } + if tcph.HTTPS() { + // TODO(kevinliang10): just leaving this TLS cert creation as if we don't have other + // hostnames, but for services this getTLSServeCetForPort will need a version that also take + // in the hostname. How to store the TLS cert is still being discussed. + hs.TLSConfig = &tls.Config{ + GetCertificate: b.getTLSServeCertForPort(dport, dstSvc), + } + return func(c net.Conn) error { + return hs.ServeTLS(netutil.NewOneConnListener(c, nil), "", "") + } + } + + return func(c net.Conn) error { + return hs.Serve(netutil.NewOneConnListener(c, nil)) + } + } + + if backDst := tcph.TCPForward(); backDst != "" { + return func(conn net.Conn) error { + defer conn.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst) + cancel() + if err != nil { + b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err) + return nil + } + defer backConn.Close() + if sni := tcph.TerminateTLS(); sni != "" { + conn = tls.Server(conn, &tls.Config{ + GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + pair, err := b.GetCertPEM(ctx, sni) + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(pair.CertPEM, pair.KeyPEM) + if err != nil { + return nil, err + } + return &cert, nil + }, + }) + } + + errc := make(chan error, 1) + go func() { + _, err := io.Copy(backConn, conn) + errc <- err + }() + go func() { + _, err := io.Copy(conn, backConn) + errc <- err + }() + return <-errc + } + } + + return nil +} + // tcpHandlerForServe returns a handler for a TCP connection to be served via // the ipn.ServeConfig. The funnelFlow can be nil if this is not a funneled // connection. @@ -464,7 +631,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort, } if tcph.HTTPS() { hs.TLSConfig = &tls.Config{ - GetCertificate: b.getTLSServeCertForPort(dport), + GetCertificate: b.getTLSServeCertForPort(dport, ""), } return func(c net.Conn) error { return hs.ServeTLS(netutil.NewOneConnListener(c, nil), "", "") @@ -505,10 +672,81 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort, }) } + var proxyHeader []byte + if ver := tcph.ProxyProtocol(); ver > 0 { + // backAddr is the final "destination" of the connection, + // which is the connection to the proxied-to backend. + backAddr := backConn.RemoteAddr().(*net.TCPAddr) + + // We always want to format the PROXY protocol + // header based on the IPv4 or IPv6-ness of + // the client. The SourceAddr and + // DestinationAddr need to match in type, so we + // need to be careful to not e.g. set a + // SourceAddr of type IPv6 and DestinationAddr + // of type IPv4. + // + // If this is an IPv6-mapped IPv4 address, + // though, unmap it. + proxySrcAddr := srcAddr + if proxySrcAddr.Addr().Is4In6() { + proxySrcAddr = netip.AddrPortFrom( + proxySrcAddr.Addr().Unmap(), + proxySrcAddr.Port(), + ) + } + + is4 := proxySrcAddr.Addr().Is4() + + var destAddr netip.Addr + if self := b.currentNode().Self(); self.Valid() { + if is4 { + destAddr = nodeIP(self, netip.Addr.Is4) + } else { + destAddr = nodeIP(self, netip.Addr.Is6) + } + } + if !destAddr.IsValid() { + // Pick a best-effort destination address of localhost. + if is4 { + destAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + } else { + destAddr = netip.IPv6Loopback() + } + } + + header := &proxyproto.Header{ + Version: byte(ver), + Command: proxyproto.PROXY, + SourceAddr: net.TCPAddrFromAddrPort(proxySrcAddr), + DestinationAddr: &net.TCPAddr{ + IP: destAddr.AsSlice(), + Port: backAddr.Port, + }, + } + if is4 { + header.TransportProtocol = proxyproto.TCPv4 + } else { + header.TransportProtocol = proxyproto.TCPv6 + } + var err error + proxyHeader, err = header.Format() + if err != nil { + b.logf("localbackend: failed to format proxy protocol header for port %v (from %v) to %s: %v", dport, srcAddr, backDst, err) + } + } + // TODO(bradfitz): do the RegisterIPPortIdentity and // UnregisterIPPortIdentity stuff that netstack does errc := make(chan error, 1) go func() { + if len(proxyHeader) > 0 { + if _, err := backConn.Write(proxyHeader); err != nil { + errc <- err + backConn.Close() // to ensure that the other side gets EOF + return + } + } _, err := io.Copy(backConn, conn) errc <- err }() @@ -528,7 +766,7 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, hostname := r.Host if r.TLS == nil { - tcd := "." + b.Status().CurrentTailnet.MagicDNSSuffix + tcd := "." + b.CurrentProfile().NetworkProfile().MagicDNSName if host, _, err := net.SplitHostPort(hostname); err == nil { hostname = host } @@ -544,7 +782,7 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, b.logf("[unexpected] localbackend: no serveHTTPContext in request") return z, "", false } - wsc, ok := b.webServerConfig(hostname, sctx.DestPort) + wsc, ok := b.webServerConfig(hostname, sctx.ForVIPService, sctx.DestPort) if !ok { return z, "", false } @@ -600,8 +838,8 @@ type reverseProxy struct { insecure bool backend string lb *LocalBackend - httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends - h2cTransport lazy.SyncValue[*http2.Transport] // transport for h2c backends + httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends + h2cTransport lazy.SyncValue[*http.Transport] // transport for h2c backends // closed tracks whether proxy is closed/currently closing. closed atomic.Bool } @@ -609,9 +847,7 @@ type reverseProxy struct { // close ensures that any open backend connections get closed. func (rp *reverseProxy) close() { rp.closed.Store(true) - if h2cT := rp.h2cTransport.Get(func() *http2.Transport { - return nil - }); h2cT != nil { + if h2cT := rp.h2cTransport.Get(func() *http.Transport { return nil }); h2cT != nil { h2cT.CloseIdleConnections() } if httpTransport := rp.httpTransport.Get(func() *http.Transport { @@ -645,9 +881,11 @@ func (rp *reverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { r.Out.Host = r.In.Host addProxyForwardedHeaders(r) rp.lb.addTailscaleIdentityHeaders(r) - }} - - // There is no way to autodetect h2c as per RFC 9113 + if err := rp.lb.addAppCapabilitiesHeader(r); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + }} // There is no way to autodetect h2c as per RFC 9113 // https://datatracker.ietf.org/doc/html/rfc9113#name-starting-http-2. // However, we assume that http:// proxy prefix in combination with the // protoccol being HTTP/2 is sufficient to detect h2c for our needs. Only use this for @@ -682,14 +920,17 @@ func (rp *reverseProxy) getTransport() *http.Transport { // getH2CTransport returns the Transport used for GRPC requests to the backend. // The Transport gets created lazily, at most once. -func (rp *reverseProxy) getH2CTransport() *http2.Transport { - return rp.h2cTransport.Get(func() *http2.Transport { - return &http2.Transport{ - AllowHTTP: true, - DialTLSContext: func(ctx context.Context, network string, addr string, _ *tls.Config) (net.Conn, error) { +func (rp *reverseProxy) getH2CTransport() http.RoundTripper { + return rp.h2cTransport.Get(func() *http.Transport { + var p http.Protocols + p.SetUnencryptedHTTP2(true) + tr := &http.Transport{ + Protocols: &p, + DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { return rp.lb.dialer.SystemDial(ctx, "tcp", rp.url.Host) }, } + return tr }) } @@ -766,6 +1007,53 @@ func encTailscaleHeaderValue(v string) string { return mime.QEncoding.Encode("utf-8", v) } +func (b *LocalBackend) addAppCapabilitiesHeader(r *httputil.ProxyRequest) error { + const appCapabilitiesHeaderName = "Tailscale-App-Capabilities" + r.Out.Header.Del(appCapabilitiesHeaderName) + + c, ok := serveHTTPContextKey.ValueOk(r.Out.Context()) + if !ok || c.Funnel != nil { + return nil + } + acceptCaps := c.AppCapabilities + if acceptCaps.IsNil() { + return nil + } + peerCaps := b.PeerCaps(c.SrcAddr.Addr()) + if peerCaps == nil { + return nil + } + + peerCapsFiltered := make(map[tailcfg.PeerCapability][]tailcfg.RawMessage, acceptCaps.Len()) + for _, cap := range acceptCaps.AsSlice() { + if peerCaps.HasCapability(cap) { + peerCapsFiltered[cap] = peerCaps[cap] + } + } + + peerCapsSerialized, err := json.Marshal(peerCapsFiltered) + if err != nil { + b.logf("serve: failed to serialize filtered PeerCapMap: %v", err) + return fmt.Errorf("unable to process app capabilities") + } + + r.Out.Header.Set(appCapabilitiesHeaderName, encTailscaleHeaderValue(string(peerCapsSerialized))) + return nil +} + +// parseRedirectWithCode parses a redirect string that may optionally start with +// a HTTP redirect status code ("3xx:"). +// Returns the status code and the final redirect URL. +// If no code prefix is found, returns http.StatusFound (302). +func parseRedirectWithCode(redirect string) (code int, url string) { + if len(redirect) >= 4 && redirect[3] == ':' { + if statusCode, err := strconv.Atoi(redirect[:3]); err == nil && statusCode >= 300 && statusCode <= 399 { + return statusCode, redirect[4:] + } + } + return http.StatusFound, redirect +} + // serveWebHandler is an http.HandlerFunc that maps incoming requests to the // correct *http. func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { @@ -779,6 +1067,13 @@ func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { io.WriteString(w, s) return } + if v := h.Redirect(); v != "" { + code, v := parseRedirectWithCode(v) + v = strings.ReplaceAll(v, "${HOST}", r.Host) + v = strings.ReplaceAll(v, "${REQUEST_URI}", r.RequestURI) + http.Redirect(w, r, v, code) + return + } if v := h.Path(); v != "" { b.serveFileOrDirectory(w, r, v, mountPoint) return @@ -789,6 +1084,12 @@ func (b *LocalBackend) serveWebHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, "unknown proxy destination", http.StatusInternalServerError) return } + // Inject app capabilities to forward into the request context + c, ok := serveHTTPContextKey.ValueOk(r.Context()) + if !ok { + return + } + c.AppCapabilities = h.AcceptAppCaps() h := p.(http.Handler) // Trim the mount point from the URL path before proxying. (#6571) if r.URL.Path != "/" { @@ -902,24 +1203,29 @@ func allNumeric(s string) bool { return s != "" } -func (b *LocalBackend) webServerConfig(hostname string, port uint16) (c ipn.WebServerConfigView, ok bool) { - key := ipn.HostPort(fmt.Sprintf("%s:%v", hostname, port)) - +func (b *LocalBackend) webServerConfig(hostname string, forVIPService tailcfg.ServiceName, port uint16) (c ipn.WebServerConfigView, ok bool) { b.mu.Lock() defer b.mu.Unlock() if !b.serveConfig.Valid() { return c, false } + if forVIPService != "" { + magicDNSSuffix := b.currentNode().NetMap().MagicDNSSuffix() + fqdn := strings.Join([]string{forVIPService.WithoutPrefix(), magicDNSSuffix}, ".") + key := ipn.HostPort(net.JoinHostPort(fqdn, fmt.Sprintf("%d", port))) + return b.serveConfig.FindServiceWeb(forVIPService, key) + } + key := ipn.HostPort(net.JoinHostPort(hostname, fmt.Sprintf("%d", port))) return b.serveConfig.FindWeb(key) } -func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (b *LocalBackend) getTLSServeCertForPort(port uint16, forVIPService tailcfg.ServiceName) func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { return func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { if hi == nil || hi.ServerName == "" { return nil, errors.New("no SNI ServerName") } - _, ok := b.webServerConfig(hi.ServerName, port) + _, ok := b.webServerConfig(hi.ServerName, forVIPService, port) if !ok { return nil, errors.New("no webserver configured for name/port") } @@ -937,3 +1243,326 @@ func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHe return &cert, nil } } + +// setServeProxyHandlersLocked ensures there is an http proxy handler for each +// backend specified in serveConfig. It expects serveConfig to be valid and +// up-to-date, so should be called after reloadServeConfigLocked. +func (b *LocalBackend) setServeProxyHandlersLocked() { + if !b.serveConfig.Valid() { + return + } + var backends map[string]bool + for _, conf := range b.serveConfig.Webs() { + for _, h := range conf.Handlers().All() { + backend := h.Proxy() + if backend == "" { + // Only create proxy handlers for servers with a proxy backend. + continue + } + mak.Set(&backends, backend, true) + if _, ok := b.serveProxyHandlers.Load(backend); ok { + continue + } + + b.logf("serve: creating a new proxy handler for %s", backend) + p, err := b.proxyHandlerForBackend(backend) + if err != nil { + // The backend endpoint (h.Proxy) should have been validated by expandProxyTarget + // in the CLI, so just log the error here. + b.logf("[unexpected] could not create proxy for %v: %s", backend, err) + continue + } + b.serveProxyHandlers.Store(backend, p) + } + } + + // Clean up handlers for proxy backends that are no longer present + // in configuration. + b.serveProxyHandlers.Range(func(key, value any) bool { + backend := key.(string) + if !backends[backend] { + b.logf("serve: closing idle connections to %s", backend) + b.serveProxyHandlers.Delete(backend) + value.(*reverseProxy).close() + } + return true + }) +} + +// VIPServices returns the list of tailnet services that this node +// is serving as a destination for. +// The returned memory is owned by the caller. +func (b *LocalBackend) VIPServices() []*tailcfg.VIPService { + b.mu.Lock() + defer b.mu.Unlock() + return b.vipServicesFromPrefsLocked(b.pm.CurrentPrefs()) +} + +func handleC2NVIPServicesGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + b.logf("c2n: GET /vip-services received") + var res tailcfg.C2NVIPServicesResponse + res.VIPServices = b.VIPServices() + res.ServicesHash = vipServiceHash(b.logf, res.VIPServices) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} + +var metricIngressCalls = clientmetric.NewCounter("peerapi_ingress") + +func init() { + RegisterPeerAPIHandler("/v0/ingress", handleServeIngress) + +} + +func handleServeIngress(ph PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + h := ph.(*peerAPIHandler) + metricIngressCalls.Add(1) + + // http.Errors only useful if hitting endpoint manually + // otherwise rely on log lines when debugging ingress connections + // as connection is hijacked for bidi and is encrypted tls + if !h.canIngress() { + h.logf("ingress: denied; no ingress cap from %v", h.remoteAddr) + http.Error(w, "denied; no ingress cap", http.StatusForbidden) + return + } + logAndError := func(code int, publicMsg string) { + h.logf("ingress: bad request from %v: %s", h.remoteAddr, publicMsg) + http.Error(w, publicMsg, code) + } + bad := func(publicMsg string) { + logAndError(http.StatusBadRequest, publicMsg) + } + if r.Method != "POST" { + logAndError(http.StatusMethodNotAllowed, "only POST allowed") + return + } + srcAddrStr := r.Header.Get("Tailscale-Ingress-Src") + if srcAddrStr == "" { + bad("Tailscale-Ingress-Src header not set") + return + } + srcAddr, err := netip.ParseAddrPort(srcAddrStr) + if err != nil { + bad("Tailscale-Ingress-Src header invalid; want ip:port") + return + } + target := ipn.HostPort(r.Header.Get("Tailscale-Ingress-Target")) + if target == "" { + bad("Tailscale-Ingress-Target header not set") + return + } + if _, _, err := net.SplitHostPort(string(target)); err != nil { + bad("Tailscale-Ingress-Target header invalid; want host:port") + return + } + + getConnOrReset := func() (net.Conn, bool) { + conn, _, err := w.(http.Hijacker).Hijack() + if err != nil { + h.logf("ingress: failed hijacking conn") + http.Error(w, "failed hijacking conn", http.StatusInternalServerError) + return nil, false + } + io.WriteString(conn, "HTTP/1.1 101 Switching Protocols\r\n\r\n") + return &ipn.FunnelConn{ + Conn: conn, + Src: srcAddr, + Target: target, + }, true + } + sendRST := func() { + http.Error(w, "denied", http.StatusForbidden) + } + + h.ps.b.HandleIngressTCPConn(h.peerNode, target, srcAddr, getConnOrReset, sendRST) +} + +// wantIngressLocked reports whether this node has ingress configured. This bool +// is sent to the coordination server (in Hostinfo.WireIngress) as an +// optimization hint to know primarily which nodes are NOT using ingress, to +// avoid doing work for regular nodes. +// +// Even if the user's ServeConfig.AllowFunnel map was manually edited in raw +// mode and contains map entries with false values, sending true (from Len > 0) +// is still fine. This is only an optimization hint for the control plane and +// doesn't affect security or correctness. And we also don't expect people to +// modify their ServeConfig in raw mode. +func (b *LocalBackend) wantIngressLocked() bool { + return b.serveConfig.Valid() && b.serveConfig.HasAllowFunnel() +} + +// hasIngressEnabledLocked reports whether the node has any funnel endpoint enabled. This bool is sent to control (in +// Hostinfo.IngressEnabled) to determine whether 'Funnel' badge should be displayed on this node in the admin panel. +func (b *LocalBackend) hasIngressEnabledLocked() bool { + return b.serveConfig.Valid() && b.serveConfig.IsFunnelOn() +} + +// shouldWireInactiveIngressLocked reports whether the node is in a state where funnel is not actively enabled, but it +// seems that it is intended to be used with funnel. +func (b *LocalBackend) shouldWireInactiveIngressLocked() bool { + return b.serveConfig.Valid() && !b.hasIngressEnabledLocked() && b.wantIngressLocked() +} + +func serveSetTCPPortsInterceptedFromNetmapAndPrefsLocked(b *LocalBackend, prefs ipn.PrefsView) (handlePorts []uint16) { + var vipServicesPorts map[tailcfg.ServiceName][]uint16 + + b.reloadServeConfigLocked(prefs) + if b.serveConfig.Valid() { + servePorts := make([]uint16, 0, 3) + for port := range b.serveConfig.TCPs() { + if port > 0 { + servePorts = append(servePorts, uint16(port)) + } + } + handlePorts = append(handlePorts, servePorts...) + + for svc, cfg := range b.serveConfig.Services().All() { + servicePorts := make([]uint16, 0, 3) + for port := range cfg.TCP().All() { + if port > 0 { + servicePorts = append(servicePorts, uint16(port)) + } + } + if _, ok := vipServicesPorts[svc]; !ok { + mak.Set(&vipServicesPorts, svc, servicePorts) + } else { + mak.Set(&vipServicesPorts, svc, append(vipServicesPorts[svc], servicePorts...)) + } + } + + b.setServeProxyHandlersLocked() + + // don't listen on netmap addresses if we're in userspace mode + if !b.sys.IsNetstack() { + b.updateServeTCPPortNetMapAddrListenersLocked(servePorts) + } + } + + b.setVIPServicesTCPPortsInterceptedLocked(vipServicesPorts) + + return handlePorts +} + +// reloadServeConfigLocked reloads the serve config from the store or resets the +// serve config to nil if not logged in. The "changed" parameter, when false, instructs +// the method to only run the reset-logic and not reload the store from memory to ensure +// foreground sessions are not removed if they are not saved on disk. +func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { + if !b.currentNode().Self().Valid() || !prefs.Valid() || b.pm.CurrentProfile().ID() == "" { + // We're not logged in, so we don't have a profile. + // Don't try to load the serve config. + b.lastServeConfJSON = mem.B(nil) + b.serveConfig = ipn.ServeConfigView{} + return + } + + confKey := ipn.ServeConfigKey(b.pm.CurrentProfile().ID()) + // TODO(maisem,bradfitz): prevent reading the config from disk + // if the profile has not changed. + confj, err := b.store.ReadState(confKey) + if err != nil { + b.lastServeConfJSON = mem.B(nil) + b.serveConfig = ipn.ServeConfigView{} + return + } + if b.lastServeConfJSON.Equal(mem.B(confj)) { + return + } + b.lastServeConfJSON = mem.B(confj) + var conf ipn.ServeConfig + if err := json.Unmarshal(confj, &conf); err != nil { + b.logf("invalid ServeConfig %q in StateStore: %v", confKey, err) + b.serveConfig = ipn.ServeConfigView{} + return + } + + // remove inactive sessions + maps.DeleteFunc(conf.Foreground, func(sessionID string, sc *ipn.ServeConfig) bool { + _, ok := b.notifyWatchers[sessionID] + return !ok + }) + + b.serveConfig = conf.View() +} + +func (b *LocalBackend) setVIPServicesTCPPortsInterceptedLocked(svcPorts map[tailcfg.ServiceName][]uint16) { + if len(svcPorts) == 0 { + b.shouldInterceptVIPServicesTCPPortAtomic.Store(func(netip.AddrPort) bool { return false }) + return + } + nm := b.currentNode().NetMap() + if nm == nil { + b.logf("can't set intercept function for Service TCP Ports, netMap is nil") + return + } + vipServiceIPMap := nm.GetVIPServiceIPMap() + if len(vipServiceIPMap) == 0 { + // No approved VIP Services + return + } + + svcAddrPorts := make(map[netip.Addr]func(uint16) bool) + // Only set the intercept function if the service has been assigned a VIP. + for svcName, ports := range svcPorts { + addrs, ok := vipServiceIPMap[svcName] + if !ok { + continue + } + interceptFn := generateInterceptTCPPortFunc(ports) + for _, addr := range addrs { + svcAddrPorts[addr] = interceptFn + } + } + + b.shouldInterceptVIPServicesTCPPortAtomic.Store(generateInterceptVIPServicesTCPPortFunc(svcAddrPorts)) +} + +func maybeUpdateHostinfoServicesHashLocked(b *LocalBackend, hi *tailcfg.Hostinfo, prefs ipn.PrefsView) bool { + latestHash := vipServiceHash(b.logf, b.vipServicesFromPrefsLocked(prefs)) + if hi.ServicesHash != latestHash { + hi.ServicesHash = latestHash + return true + } + return false +} + +func maybeUpdateHostinfoFunnelLocked(b *LocalBackend, hi *tailcfg.Hostinfo, prefs ipn.PrefsView) (changed bool) { + // The Hostinfo.IngressEnabled field is used to communicate to control whether + // the node has funnel enabled. + if ie := b.hasIngressEnabledLocked(); hi.IngressEnabled != ie { + b.logf("Hostinfo.IngressEnabled changed to %v", ie) + hi.IngressEnabled = ie + changed = true + } + // The Hostinfo.WireIngress field tells control whether the user intends + // to use funnel with this node even though it is not currently enabled. + // This is an optimization to control- Funnel requires creation of DNS + // records and because DNS propagation can take time, we want to ensure + // that the records exist for any node that intends to use funnel even + // if it's not enabled. If hi.IngressEnabled is true, control knows that + // DNS records are needed, so we can save bandwidth and not send + // WireIngress. + if wire := b.shouldWireInactiveIngressLocked(); hi.WireIngress != wire { + b.logf("Hostinfo.WireIngress changed to %v", wire) + hi.WireIngress = wire + changed = true + } + return changed +} + +func vipServiceHash(logf logger.Logf, services []*tailcfg.VIPService) string { + if len(services) == 0 { + return "" + } + h := sha256.New() + jh := json.NewEncoder(h) + if err := jh.Encode(services); err != nil { + logf("vipServiceHashLocked: %v", err) + return "" + } + var buf [sha256.Size]byte + h.Sum(buf[:0]) + return hex.EncodeToString(buf[:]) +} diff --git a/ipn/ipnlocal/serve_disabled.go b/ipn/ipnlocal/serve_disabled.go new file mode 100644 index 000000000..a97112941 --- /dev/null +++ b/ipn/ipnlocal/serve_disabled.go @@ -0,0 +1,34 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_serve + +// These are temporary (2025-09-13) stubs for when tailscaled is built with the +// ts_omit_serve build tag, disabling serve. +// +// TODO: move serve to a separate package, out of ipnlocal, and delete this +// file. One step at a time. + +package ipnlocal + +import ( + "tailscale.com/ipn" + "tailscale.com/tailcfg" +) + +const serveEnabled = false + +type localListener = struct{} + +func (b *LocalBackend) DeleteForegroundSession(sessionID string) error { + return nil +} + +type funnelFlow = struct{} + +func (*LocalBackend) hasIngressEnabledLocked() bool { return false } +func (*LocalBackend) shouldWireInactiveIngressLocked() bool { return false } + +func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcfg.VIPService { + return nil +} diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index 73e66c2b9..c3e5b2ff9 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_serve + package ipnlocal import ( @@ -13,6 +15,8 @@ import ( "encoding/json" "errors" "fmt" + "io" + "mime" "net/http" "net/http/httptest" "net/netip" @@ -24,6 +28,7 @@ import ( "testing" "time" + "tailscale.com/control/controlclient" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" @@ -33,9 +38,12 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" ) func TestExpandProxyArg(t *testing.T) { @@ -64,6 +72,41 @@ func TestExpandProxyArg(t *testing.T) { } } +func TestParseRedirectWithRedirectCode(t *testing.T) { + tests := []struct { + in string + wantCode int + wantURL string + }{ + {"301:https://example.com", 301, "https://example.com"}, + {"302:https://example.com", 302, "https://example.com"}, + {"303:/path", 303, "/path"}, + {"307:https://example.com/path?query=1", 307, "https://example.com/path?query=1"}, + {"308:https://example.com", 308, "https://example.com"}, + + {"https://example.com", 302, "https://example.com"}, + {"/path", 302, "/path"}, + {"http://example.com", 302, "http://example.com"}, + {"git://example.com", 302, "git://example.com"}, + + {"200:https://example.com", 302, "200:https://example.com"}, + {"404:https://example.com", 302, "404:https://example.com"}, + {"500:https://example.com", 302, "500:https://example.com"}, + {"30:https://example.com", 302, "30:https://example.com"}, + {"3:https://example.com", 302, "3:https://example.com"}, + {"3012:https://example.com", 302, "3012:https://example.com"}, + {"abc:https://example.com", 302, "abc:https://example.com"}, + {"301", 302, "301"}, + } + for _, tt := range tests { + gotCode, gotURL := parseRedirectWithCode(tt.in) + if gotCode != tt.wantCode || gotURL != tt.wantURL { + t.Errorf("parseRedirectWithCode(%q) = (%d, %q), want (%d, %q)", + tt.in, gotCode, gotURL, tt.wantCode, tt.wantURL) + } + } +} + func TestGetServeHandler(t *testing.T) { const serverName = "example.ts.net" conf1 := &ipn.ServeConfig{ @@ -239,11 +282,15 @@ func TestServeConfigForeground(t *testing.T) { err := b.SetServeConfig(&ipn.ServeConfig{ Foreground: map[string]*ipn.ServeConfig{ - session1: {TCP: map[uint16]*ipn.TCPPortHandler{ - 443: {TCPForward: "http://localhost:3000"}}, + session1: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {TCPForward: "http://localhost:3000"}, + }, }, - session2: {TCP: map[uint16]*ipn.TCPPortHandler{ - 999: {TCPForward: "http://localhost:4000"}}, + session2: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 999: {TCPForward: "http://localhost:4000"}, + }, }, }, }, "") @@ -266,8 +313,10 @@ func TestServeConfigForeground(t *testing.T) { 5000: {TCPForward: "http://localhost:5000"}, }, Foreground: map[string]*ipn.ServeConfig{ - session2: {TCP: map[uint16]*ipn.TCPPortHandler{ - 999: {TCPForward: "http://localhost:4000"}}, + session2: { + TCP: map[uint16]*ipn.TCPPortHandler{ + 999: {TCPForward: "http://localhost:4000"}, + }, }, }, }, "") @@ -296,6 +345,202 @@ func TestServeConfigForeground(t *testing.T) { } } +// TestServeConfigServices tests the side effects of setting the +// Services field in a ServeConfig. The Services field is a map +// of all services the current service host is serving. Unlike what we +// serve for node itself, there is no foreground and no local handlers +// for the services. So the only things we need to test are if the +// services configured are valid and if they correctly set intercept +// functions for netStack. +func TestServeConfigServices(t *testing.T) { + b := newTestBackend(t) + svcIPMap := tailcfg.ServiceIPMappings{ + "svc:foo": []netip.Addr{ + netip.MustParseAddr("100.101.101.101"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:6565:6565"), + }, + "svc:bar": []netip.Addr{ + netip.MustParseAddr("100.99.99.99"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:626b:628b"), + }, + } + svcIPMapJSON, err := json.Marshal(svcIPMap) + if err != nil { + t.Fatal(err) + } + + b.currentNode().SetNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{tailcfg.RawMessage(svcIPMapJSON)}, + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + LoginName: "someone@example.com", + DisplayName: "Some One", + ProfilePicURL: "https://example.com/photo.jpg", + }).View(), + }, + }) + + tests := []struct { + name string + conf *ipn.ServeConfig + expectedErr error + packetDstAddrPort []netip.AddrPort + intercepted bool + }{ + { + name: "no-services", + conf: &ipn.ServeConfig{}, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.101.101.101:443"), + }, + intercepted: false, + }, + { + name: "one-incorrectly-configured-service", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Tun: true, + }, + }, + }, + expectedErr: ipn.ErrServiceConfigHasBothTCPAndTun, + }, + { + // one correctly configured service with packet should be intercepted + name: "one-service-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.101.101.101:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:80"), + }, + intercepted: true, + }, + { + // one correctly configured service with packet should not be intercepted + name: "one-service-not-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.99.99.99:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:80"), + netip.MustParseAddrPort("100.101.101.101:82"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:82"), + }, + intercepted: false, + }, + { + // multiple correctly configured service with packet should be intercepted + name: "multiple-service-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.99.99.99:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:80"), + netip.MustParseAddrPort("100.101.101.101:81"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:81"), + }, + intercepted: true, + }, + { + // multiple correctly configured service with packet should not be intercepted + name: "multiple-service-not-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + // ips in capmap but port is not hosting service + netip.MustParseAddrPort("100.99.99.99:77"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:77"), + netip.MustParseAddrPort("100.101.101.101:85"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:85"), + // ips not in capmap + netip.MustParseAddrPort("100.102.102.102:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6666:6666]:80"), + }, + intercepted: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := b.SetServeConfig(tt.conf, "") + if err != nil && tt.expectedErr != nil { + if !errors.Is(err, tt.expectedErr) { + t.Fatalf("expected error %v,\n got %v", tt.expectedErr, err) + } + return + } + if err != nil { + t.Fatal(err) + } + for _, addrPort := range tt.packetDstAddrPort { + if tt.intercepted != b.ShouldInterceptVIPServiceTCPPort(addrPort) { + if tt.intercepted { + t.Fatalf("expected packet to be intercepted") + } else { + t.Fatalf("expected packet not to be intercepted") + } + } + } + }) + } +} + func TestServeConfigETag(t *testing.T) { b := newTestBackend(t) @@ -461,6 +706,7 @@ func TestServeHTTPProxyPath(t *testing.T) { }) } } + func TestServeHTTPProxyHeaders(t *testing.T) { b := newTestBackend(t) @@ -560,6 +806,156 @@ func TestServeHTTPProxyHeaders(t *testing.T) { } } +func TestServeHTTPProxyGrantHeader(t *testing.T) { + b := newTestBackend(t) + + nm := b.NetMap() + matches, err := filter.MatchesFromFilterRules([]tailcfg.FilterRule{ + { + SrcIPs: []string{"100.150.151.152"}, + CapGrant: []tailcfg.CapGrant{{ + Dsts: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.151/32"), + }, + CapMap: tailcfg.PeerCapMap{ + "example.com/cap/interesting": []tailcfg.RawMessage{ + `{"role": "đŸŋ"}`, + }, + }, + }}, + }, + { + SrcIPs: []string{"100.150.151.153"}, + CapGrant: []tailcfg.CapGrant{{ + Dsts: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.151/32"), + }, + CapMap: tailcfg.PeerCapMap{ + "example.com/cap/boring": []tailcfg.RawMessage{ + `{"role": "Viewer"}`, + }, + "example.com/cap/irrelevant": []tailcfg.RawMessage{ + `{"role": "Editor"}`, + }, + }, + }}, + }, + }) + if err != nil { + t.Fatal(err) + } + nm.PacketFilter = matches + b.SetControlClientStatus(nil, controlclient.Status{NetMap: nm}) + + // Start test serve endpoint. + testServ := httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + // Piping all the headers through the response writer + // so we can check their values in tests below. + for key, val := range r.Header { + w.Header().Add(key, strings.Join(val, ",")) + } + }, + )) + defer testServ.Close() + + conf := &ipn.ServeConfig{ + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "example.ts.net:443": {Handlers: map[string]*ipn.HTTPHandler{ + "/": { + Proxy: testServ.URL, + AcceptAppCaps: []tailcfg.PeerCapability{"example.com/cap/interesting", "example.com/cap/boring"}, + }, + }}, + }, + } + if err := b.SetServeConfig(conf, ""); err != nil { + t.Fatal(err) + } + + type headerCheck struct { + header string + want string + } + + tests := []struct { + name string + srcIP string + wantHeaders []headerCheck + }{ + { + name: "request-from-user-within-tailnet", + srcIP: "100.150.151.152", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.150.151.152"}, + {"Tailscale-User-Login", "someone@example.com"}, + {"Tailscale-User-Name", "Some One"}, + {"Tailscale-User-Profile-Pic", "https://example.com/photo.jpg"}, + {"Tailscale-Headers-Info", "https://tailscale.com/s/serve-headers"}, + {"Tailscale-App-Capabilities", `{"example.com/cap/interesting":[{"role":"đŸŋ"}]}`}, + }, + }, + { + name: "request-from-tagged-node-within-tailnet", + srcIP: "100.150.151.153", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.150.151.153"}, + {"Tailscale-User-Login", ""}, + {"Tailscale-User-Name", ""}, + {"Tailscale-User-Profile-Pic", ""}, + {"Tailscale-Headers-Info", ""}, + {"Tailscale-App-Capabilities", `{"example.com/cap/boring":[{"role":"Viewer"}]}`}, + }, + }, + { + name: "request-from-outside-tailnet", + srcIP: "100.160.161.162", + wantHeaders: []headerCheck{ + {"X-Forwarded-Proto", "https"}, + {"X-Forwarded-For", "100.160.161.162"}, + {"Tailscale-User-Login", ""}, + {"Tailscale-User-Name", ""}, + {"Tailscale-User-Profile-Pic", ""}, + {"Tailscale-Headers-Info", ""}, + {"Tailscale-App-Capabilities", ""}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + URL: &url.URL{Path: "/"}, + TLS: &tls.ConnectionState{ServerName: "example.ts.net"}, + } + req = req.WithContext(serveHTTPContextKey.WithValue(req.Context(), &serveHTTPContext{ + DestPort: 443, + SrcAddr: netip.MustParseAddrPort(tt.srcIP + ":1234"), // random src port for tests + })) + + w := httptest.NewRecorder() + b.serveWebHandler(w, req) + + // Verify the headers. The contract with users is that identity and grant headers containing non-ASCII + // UTF-8 characters will be Q-encoded. + h := w.Result().Header + dec := new(mime.WordDecoder) + for _, c := range tt.wantHeaders { + maybeEncoded := h.Get(c.header) + got, err := dec.DecodeHeader(maybeEncoded) + if err != nil { + t.Fatalf("invalid %q header; failed to decode: %v", maybeEncoded, err) + } + if got != c.want { + t.Errorf("invalid %q header; want=%q, got=%q", c.header, c.want, got) + } + } + }) + } +} + func Test_reverseProxyConfiguration(t *testing.T) { b := newTestBackend(t) type test struct { @@ -661,7 +1057,6 @@ func Test_reverseProxyConfiguration(t *testing.T) { wantsURL: mustCreateURL(t, "https://example3.com"), }, }) - } func mustCreateURL(t *testing.T, u string) url.URL { @@ -673,18 +1068,30 @@ func mustCreateURL(t *testing.T, u string) url.URL { return *uParsed } -func newTestBackend(t *testing.T) *LocalBackend { +func newTestBackend(t *testing.T, opts ...any) *LocalBackend { var logf logger.Logf = logger.Discard - const debug = true + const debug = false if debug { logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ") } - sys := &tsd.System{} + bus := eventbustest.NewBus(t) + sys := tsd.NewSystemWithBus(bus) + + for _, o := range opts { + switch v := o.(type) { + case policyclient.Client: + sys.PolicyClient.Set(v) + default: + panic(fmt.Sprintf("unsupported option type %T", v)) + } + } + e, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ SetSubsystem: sys.Set, - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { t.Fatal(err) @@ -700,52 +1107,59 @@ func newTestBackend(t *testing.T) *LocalBackend { dir := t.TempDir() b.SetVarRoot(dir) - pm := must.Get(newProfileManager(new(mem.Store), logf, new(health.Tracker))) - pm.currentProfile = &ipn.LoginProfile{ID: "id0"} + pm := must.Get(newProfileManager(new(mem.Store), logf, health.NewTracker(bus))) + pm.currentProfile = (&ipn.LoginProfile{ID: "id0"}).View() b.pm = pm - b.netMap = &netmap.NetworkMap{ + b.currentNode().SetNetMap(&netmap.NetworkMap{ SelfNode: (&tailcfg.Node{ Name: "example.ts.net", + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.151/32"), + }, }).View(), - UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{ - tailcfg.UserID(1): { + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ LoginName: "someone@example.com", DisplayName: "Some One", ProfilePicURL: "https://example.com/photo.jpg", - }, + }).View(), }, - } - b.peers = map[tailcfg.NodeID]tailcfg.NodeView{ - 152: (&tailcfg.Node{ - ID: 152, - ComputedName: "some-peer", - User: tailcfg.UserID(1), - }).View(), - 153: (&tailcfg.Node{ - ID: 153, - ComputedName: "some-tagged-peer", - Tags: []string{"tag:server", "tag:test"}, - User: tailcfg.UserID(1), - }).View(), - } - b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{ - netip.MustParseAddr("100.150.151.152"): 152, - netip.MustParseAddr("100.150.151.153"): 153, - } + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 152, + ComputedName: "some-peer", + User: tailcfg.UserID(1), + Key: makeNodeKeyFromID(152), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.152/32"), + }, + }).View(), + (&tailcfg.Node{ + ID: 153, + ComputedName: "some-tagged-peer", + Tags: []string{"tag:server", "tag:test"}, + User: tailcfg.UserID(1), + Key: makeNodeKeyFromID(153), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.153/32"), + }, + }).View(), + }, + }) return b } func TestServeFileOrDirectory(t *testing.T) { td := t.TempDir() writeFile := func(suffix, contents string) { - if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0600); err != nil { + if err := os.WriteFile(filepath.Join(td, suffix), []byte(contents), 0o600); err != nil { t.Fatal(err) } } writeFile("foo", "this is foo") writeFile("bar", "this is bar") - os.MkdirAll(filepath.Join(td, "subdir"), 0700) + os.MkdirAll(filepath.Join(td, "subdir"), 0o700) writeFile("subdir/file-a", "this is A") writeFile("subdir/file-b", "this is B") writeFile("subdir/file-c", "this is C") @@ -863,3 +1277,180 @@ func TestEncTailscaleHeaderValue(t *testing.T) { } } } + +func TestServeGRPCProxy(t *testing.T) { + const msg = "some-response\n" + backend := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Path-Was", r.RequestURI) + w.Header().Set("Proto-Was", r.Proto) + io.WriteString(w, msg) + })) + backend.EnableHTTP2 = true + backend.Config.Protocols = new(http.Protocols) + backend.Config.Protocols.SetHTTP1(true) + backend.Config.Protocols.SetUnencryptedHTTP2(true) + backend.Start() + defer backend.Close() + + backendURL := must.Get(url.Parse(backend.URL)) + + lb := newTestBackend(t) + rp := &reverseProxy{ + logf: t.Logf, + url: backendURL, + backend: backend.URL, + lb: lb, + } + + req := func(method, urlStr string, opt ...any) *http.Request { + req := httptest.NewRequest(method, urlStr, nil) + for _, o := range opt { + switch v := o.(type) { + case int: + req.ProtoMajor = v + case string: + req.Header.Set("Content-Type", v) + default: + panic(fmt.Sprintf("unsupported option type %T", v)) + } + } + return req + } + + tests := []struct { + name string + req *http.Request + wantPath string + wantProto string + wantBody string + }{ + { + name: "non-gRPC", + req: req("GET", "http://foo/bar"), + wantPath: "/bar", + wantProto: "HTTP/1.1", + }, + { + name: "gRPC-but-not-http2", + req: req("GET", "http://foo/bar", "application/grpc"), + wantPath: "/bar", + wantProto: "HTTP/1.1", + }, + { + name: "gRPC--http2", + req: req("GET", "http://foo/bar", 2, "application/grpc"), + wantPath: "/bar", + wantProto: "HTTP/2.0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + rp.ServeHTTP(rec, tt.req) + + res := rec.Result() + got := must.Get(io.ReadAll(res.Body)) + if got, want := res.Header.Get("Path-Was"), tt.wantPath; want != got { + t.Errorf("Path-Was %q, want %q", got, want) + } + if got, want := res.Header.Get("Proto-Was"), tt.wantProto; want != got { + t.Errorf("Proto-Was %q, want %q", got, want) + } + if string(got) != msg { + t.Errorf("got body %q, want %q", got, msg) + } + }) + } +} + +func TestServeHTTPRedirect(t *testing.T) { + b := newTestBackend(t) + + tests := []struct { + host string + path string + redirect string + reqURI string + wantCode int + wantLoc string + }{ + { + host: "hardcoded-root", + path: "/", + redirect: "https://example.com/", + reqURI: "/old", + wantCode: http.StatusFound, // 302 is the default + wantLoc: "https://example.com/", + }, + { + host: "template-host-and-uri", + path: "/", + redirect: "https://${HOST}${REQUEST_URI}", + reqURI: "/path?foo=bar", + wantCode: http.StatusFound, // 302 is the default + wantLoc: "https://template-host-and-uri/path?foo=bar", + }, + { + host: "custom-301", + path: "/", + redirect: "301:https://example.com/", + reqURI: "/old", + wantCode: http.StatusMovedPermanently, // 301 + wantLoc: "https://example.com/", + }, + { + host: "custom-307", + path: "/", + redirect: "307:https://example.com/new", + reqURI: "/old", + wantCode: http.StatusTemporaryRedirect, // 307 + wantLoc: "https://example.com/new", + }, + { + host: "custom-308", + path: "/", + redirect: "308:https://example.com/permanent", + reqURI: "/old", + wantCode: http.StatusPermanentRedirect, // 308 + wantLoc: "https://example.com/permanent", + }, + } + + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + conf := &ipn.ServeConfig{ + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ipn.HostPort(tt.host + ":80"): { + Handlers: map[string]*ipn.HTTPHandler{ + tt.path: {Redirect: tt.redirect}, + }, + }, + }, + } + if err := b.SetServeConfig(conf, ""); err != nil { + t.Fatal(err) + } + + req := &http.Request{ + Host: tt.host, + URL: &url.URL{Path: tt.path}, + RequestURI: tt.reqURI, + TLS: &tls.ConnectionState{ServerName: tt.host}, + } + req = req.WithContext(serveHTTPContextKey.WithValue(req.Context(), &serveHTTPContext{ + DestPort: 80, + SrcAddr: netip.MustParseAddrPort("1.2.3.4:1234"), + })) + + w := httptest.NewRecorder() + b.serveWebHandler(w, req) + + if w.Code != tt.wantCode { + t.Errorf("got status %d, want %d", w.Code, tt.wantCode) + } + if got := w.Header().Get("Location"); got != tt.wantLoc { + t.Errorf("got Location %q, want %q", got, tt.wantLoc) + } + }) + } +} diff --git a/ipn/ipnlocal/ssh.go b/ipn/ipnlocal/ssh.go index fbeb19bd1..e2c2f5067 100644 --- a/ipn/ipnlocal/ssh.go +++ b/ipn/ipnlocal/ssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9) && !ts_omit_ssh package ipnlocal @@ -24,10 +24,10 @@ import ( "strings" "sync" - "github.com/tailscale/golang-x-crypto/ssh" "go4.org/mem" + "golang.org/x/crypto/ssh" "tailscale.com/tailcfg" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/mak" ) @@ -80,30 +80,32 @@ func (b *LocalBackend) getSSHUsernames(req *tailcfg.C2NSSHUsernamesRequest) (*ta if err != nil { return nil, err } - lineread.Reader(bytes.NewReader(out), func(line []byte) error { + for line := range lineiter.Bytes(out) { line = bytes.TrimSpace(line) if len(line) == 0 || line[0] == '_' { - return nil + continue } add(string(line)) - return nil - }) + } default: - lineread.File("/etc/passwd", func(line []byte) error { + for lr := range lineiter.File("/etc/passwd") { + line, err := lr.Value() + if err != nil { + break + } line = bytes.TrimSpace(line) if len(line) == 0 || line[0] == '#' || line[0] == '_' { - return nil + continue } if mem.HasSuffix(mem.B(line), mem.S("/nologin")) || mem.HasSuffix(mem.B(line), mem.S("/false")) { - return nil + continue } colon := bytes.IndexByte(line, ':') if colon != -1 { add(string(line[:colon])) } - return nil - }) + } } return res, nil } diff --git a/ipn/ipnlocal/ssh_stub.go b/ipn/ipnlocal/ssh_stub.go index 7875ae311..6b2e36015 100644 --- a/ipn/ipnlocal/ssh_stub.go +++ b/ipn/ipnlocal/ssh_stub.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build ios || (!linux && !darwin && !freebsd && !openbsd) +//go:build ts_omit_ssh || ios || android || (!linux && !darwin && !freebsd && !openbsd && !plan9) package ipnlocal diff --git a/ipn/ipnlocal/ssh_test.go b/ipn/ipnlocal/ssh_test.go index 6e93b34f0..b24cd6732 100644 --- a/ipn/ipnlocal/ssh_test.go +++ b/ipn/ipnlocal/ssh_test.go @@ -13,6 +13,7 @@ import ( "tailscale.com/health" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" ) @@ -50,7 +51,7 @@ type fakeSSHServer struct { } func TestGetSSHUsernames(t *testing.T) { - pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) b := &LocalBackend{pm: pm, store: pm.Store()} b.sshServer = fakeSSHServer{} res, err := b.getSSHUsernames(new(tailcfg.C2NSSHUsernamesRequest)) diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index bebd0152b..152b375b0 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -5,26 +5,50 @@ package ipnlocal import ( "context" + "errors" + "fmt" + "math/rand/v2" + "net/netip" + "strings" "sync" "sync/atomic" "testing" "time" qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/control/controlclient" "tailscale.com/envknob" "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store/mem" + "tailscale.com/net/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/packet" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/persist" + "tailscale.com/types/preftype" + "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/mak" + "tailscale.com/util/must" "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg" + "tailscale.com/wgengine/wgint" ) // notifyThrottler receives notifications from an ipn.Backend, blocking @@ -35,8 +59,9 @@ type notifyThrottler struct { // ch gets replaced frequently. Lock the mutex before getting or // setting it, but not while waiting on it. - mu sync.Mutex - ch chan ipn.Notify + mu sync.Mutex + ch chan ipn.Notify + putErr error // set by put if the channel is full } // expect tells the throttler to expect count upcoming notifications. @@ -57,7 +82,11 @@ func (nt *notifyThrottler) put(n ipn.Notify) { case ch <- n: return default: - nt.t.Fatalf("put: channel full: %v", n) + err := fmt.Errorf("put: channel full: %v", n) + nt.t.Log(err) + nt.mu.Lock() + nt.putErr = err + nt.mu.Unlock() } } @@ -67,8 +96,13 @@ func (nt *notifyThrottler) drain(count int) []ipn.Notify { nt.t.Helper() nt.mu.Lock() ch := nt.ch + putErr := nt.putErr nt.mu.Unlock() + if putErr != nil { + nt.t.Fatalf("drain: previous call to put errored: %s", putErr) + } + nn := []ipn.Notify{} for i := range count { select { @@ -91,10 +125,11 @@ func (nt *notifyThrottler) drain(count int) []ipn.Notify { // in the controlclient.Client, so by controlling it, we can check that // the state machine works as expected. type mockControl struct { - tb testing.TB - logf logger.Logf - opts controlclient.Options - paused atomic.Bool + tb testing.TB + logf logger.Logf + opts controlclient.Options + paused atomic.Bool + controlClientID int64 mu sync.Mutex persist *persist.Persist @@ -105,12 +140,13 @@ type mockControl struct { func newClient(tb testing.TB, opts controlclient.Options) *mockControl { return &mockControl{ - tb: tb, - authBlocked: true, - logf: opts.Logf, - opts: opts, - shutdown: make(chan struct{}), - persist: opts.Persist.Clone(), + tb: tb, + authBlocked: true, + logf: opts.Logf, + opts: opts, + shutdown: make(chan struct{}), + persist: opts.Persist.Clone(), + controlClientID: rand.Int64(), } } @@ -146,9 +182,17 @@ func (cc *mockControl) populateKeys() (newKeys bool) { return newKeys } +type sendOpt struct { + err error + url string + loginFinished bool + nm *netmap.NetworkMap +} + // send publishes a controlclient.Status notification upstream. // (In our tests here, upstream is the ipnlocal.Local instance.) -func (cc *mockControl) send(err error, url string, loginFinished bool, nm *netmap.NetworkMap) { +func (cc *mockControl) send(opts sendOpt) { + err, url, loginFinished, nm := opts.err, opts.url, opts.loginFinished, opts.nm if loginFinished { cc.mu.Lock() cc.authBlocked = false @@ -162,14 +206,29 @@ func (cc *mockControl) send(err error, url string, loginFinished bool, nm *netma Err: err, } if loginFinished { - s.SetStateForTest(controlclient.StateAuthenticated) - } else if url == "" && err == nil && nm == nil { - s.SetStateForTest(controlclient.StateNotAuthenticated) + s.LoggedIn = true } cc.opts.Observer.SetControlClientStatus(cc, s) } } +func (cc *mockControl) authenticated(nm *netmap.NetworkMap) { + if selfUser, ok := nm.UserProfiles[nm.SelfNode.User()]; ok { + cc.persist.UserProfile = *selfUser.AsStruct() + } + cc.persist.NodeID = nm.SelfNode.StableID() + cc.send(sendOpt{loginFinished: true, nm: nm}) +} + +func (cc *mockControl) sendAuthURL(nm *netmap.NetworkMap) { + s := controlclient.Status{ + URL: "https://example.com/a/foo", + NetMap: nm, + Persist: cc.persist.View(), + } + cc.opts.Observer.SetControlClientStatus(cc, s) +} + // called records that a particular function name was called. func (cc *mockControl) called(s string) { cc.mu.Lock() @@ -257,6 +316,15 @@ func (cc *mockControl) UpdateEndpoints(endpoints []tailcfg.Endpoint) { cc.called("UpdateEndpoints") } +func (cc *mockControl) SetDiscoPublicKey(key key.DiscoPublic) { + cc.logf("SetDiscoPublicKey: %v", key) + cc.called("SetDiscoPublicKey") +} + +func (cc *mockControl) ClientID() int64 { + return cc.controlClientID +} + func (b *LocalBackend) nonInteractiveLoginForStateTest() { b.mu.Lock() if b.cc == nil { @@ -290,15 +358,23 @@ func (b *LocalBackend) nonInteractiveLoginForStateTest() { // predictable, but maybe a bit less thorough. This is more of an overall // state machine test than a test of the wgengine+magicsock integration. func TestStateMachine(t *testing.T) { + runTestStateMachine(t, false) +} + +func TestStateMachineSeamless(t *testing.T) { + runTestStateMachine(t, true) +} + +func runTestStateMachine(t *testing.T, seamless bool) { envknob.Setenv("TAILSCALE_USE_WIP_CODE", "1") defer envknob.Setenv("TAILSCALE_USE_WIP_CODE", "") c := qt.New(t) logf := tstest.WhileTestRunningLogger(t) - sys := new(tsd.System) + sys := tsd.NewSystem() store := new(testStateStorage) sys.Set(store) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) } @@ -309,7 +385,7 @@ func TestStateMachine(t *testing.T) { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } - b.DisablePortMapperForTest() + t.Cleanup(b.Shutdown) var cc, previousCC *mockControl b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { @@ -360,8 +436,11 @@ func TestStateMachine(t *testing.T) { // for it, so it doesn't count as Prefs.LoggedOut==true. c.Assert(prefs.LoggedOut(), qt.IsTrue) c.Assert(prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[1].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify notification indicates we need login (prefs show logged out) + c.Assert(nn[1].Prefs == nil || nn[1].Prefs.LoggedOut(), qt.IsTrue) + // Verify the actual facts about our state + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // Restart the state machine. @@ -381,8 +460,11 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[1].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify notification indicates we need login + c.Assert(nn[1].Prefs == nil || nn[1].Prefs.LoggedOut(), qt.IsTrue) + // Verify the actual facts about our state + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // Start non-interactive login with no token. @@ -399,7 +481,8 @@ func TestStateMachine(t *testing.T) { // (This behaviour is needed so that b.Login() won't // start connecting to an old account right away, if one // exists when you launch another login.) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need login + c.Assert(needsLogin(b), qt.IsTrue) } // Attempted non-interactive login with no key; indicate that @@ -414,7 +497,7 @@ func TestStateMachine(t *testing.T) { }, }) url1 := "https://localhost:1/1" - cc.send(nil, url1, false, nil) + cc.send(sendOpt{url: url1}) { cc.assertCalls() @@ -426,10 +509,11 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[1].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we need URL visit + c.Assert(hasAuthURL(b), qt.IsTrue) c.Assert(nn[2].BrowseToURL, qt.IsNotNil) c.Assert(url1, qt.Equals, *nn[2].BrowseToURL) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + c.Assert(isFullyAuthenticated(b), qt.IsFalse) } // Now we'll try an interactive login. @@ -444,7 +528,8 @@ func TestStateMachine(t *testing.T) { cc.assertCalls() c.Assert(nn[0].BrowseToURL, qt.IsNotNil) c.Assert(url1, qt.Equals, *nn[0].BrowseToURL) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need to complete login + c.Assert(needsLogin(b), qt.IsTrue) } // Sometimes users press the Login button again, in the middle of @@ -460,14 +545,15 @@ func TestStateMachine(t *testing.T) { notifies.drain(0) // backend asks control for another login sequence cc.assertCalls("Login") - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need login + c.Assert(needsLogin(b), qt.IsTrue) } // Provide a new interactive login URL. t.Logf("\n\nLogin2 (url response)") notifies.expect(1) url2 := "https://localhost:1/2" - cc.send(nil, url2, false, nil) + cc.send(sendOpt{url: url2}) { cc.assertCalls() @@ -476,7 +562,8 @@ func TestStateMachine(t *testing.T) { nn := notifies.drain(1) c.Assert(nn[0].BrowseToURL, qt.IsNotNil) c.Assert(url2, qt.Equals, *nn[0].BrowseToURL) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify we still need to complete login + c.Assert(needsLogin(b), qt.IsTrue) } // Pretend that the interactive login actually happened. @@ -487,7 +574,14 @@ func TestStateMachine(t *testing.T) { notifies.expect(3) cc.persist.UserProfile.LoginName = "user1" cc.persist.NodeID = "node1" - cc.send(nil, "", true, &netmap.NetworkMap{}) + + // even if seamless is being enabled by default rather than by policy, this is + // the point where it will first get enabled. + if seamless { + sys.ControlKnobs().SeamlessKeyRenewal.Store(true) + } + + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{}}) { nn := notifies.drain(3) // Arguably it makes sense to unpause now, since the machine @@ -501,10 +595,18 @@ func TestStateMachine(t *testing.T) { cc.assertCalls() c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(nn[2].State, qt.IsNotNil) c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user1") - c.Assert(ipn.NeedsMachineAuth, qt.Equals, *nn[2].State) - c.Assert(ipn.NeedsMachineAuth, qt.Equals, b.State()) + // nn[2] is a state notification after login + // Verify login finished but need machine auth using backend state + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(needsMachineAuth(b), qt.IsTrue) + nm := b.NetMap() + c.Assert(nm, qt.IsNotNil) + // For an empty netmap (after initial login), SelfNode may not be valid yet. + // In this case, we can't check MachineAuthorized, but needsMachineAuth already verified the state. + if nm.SelfNode.Valid() { + c.Assert(nm.SelfNode.MachineAuthorized(), qt.IsFalse) + } } // Pretend that the administrator has authorized our machine. @@ -516,14 +618,19 @@ func TestStateMachine(t *testing.T) { // but the current code is brittle. // (ie. I suspect it would be better to change false->true in send() // below, and do the same in the real controlclient.) - cc.send(nil, "", false, &netmap.NetworkMap{ + cc.send(sendOpt{nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + }}) { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + // nn[0] is a state notification after machine auth granted + c.Assert(len(nn), qt.Equals, 1) + // Verify machine authorized using backend state + nm := b.NetMap() + c.Assert(nm, qt.IsNotNil) + c.Assert(nm.SelfNode.Valid(), qt.IsTrue) + c.Assert(nm.SelfNode.MachineAuthorized(), qt.IsTrue) } // TODO: add a fake DERP server to our fake netmap, so we can @@ -546,9 +653,9 @@ func TestStateMachine(t *testing.T) { nn := notifies.drain(2) cc.assertCalls("pause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification, nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Stopped, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) } // The user changes their preference to WantRunning after all. @@ -564,69 +671,67 @@ func TestStateMachine(t *testing.T) { // BUG: Login isn't needed here. We never logged out. cc.assertCalls("Login", "unpause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification, nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) c.Assert(store.sawWrite(), qt.IsTrue) } - // undo the state hack above. - b.state = ipn.Starting - // User wants to logout. store.awaitWrite() t.Logf("\n\nLogout") notifies.expect(5) - b.Logout(context.Background()) + b.Logout(context.Background(), ipnauth.Self) { nn := notifies.drain(5) previousCC.assertCalls("pause", "Logout", "unpause", "Shutdown") + // nn[0] is state notification (Stopped) c.Assert(nn[0].State, qt.IsNotNil) c.Assert(*nn[0].State, qt.Equals, ipn.Stopped) - + // nn[1] is prefs notification after logout c.Assert(nn[1].Prefs, qt.IsNotNil) c.Assert(nn[1].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) cc.assertCalls("New") - c.Assert(nn[2].State, qt.IsNotNil) - c.Assert(*nn[2].State, qt.Equals, ipn.NoState) - - c.Assert(nn[3].Prefs, qt.IsNotNil) // emptyPrefs + // nn[2] is the initial state notification after New (NoState) + // nn[3] is prefs notification with emptyPrefs + c.Assert(nn[3].Prefs, qt.IsNotNil) c.Assert(nn[3].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[3].Prefs.WantRunning(), qt.IsFalse) - c.Assert(nn[4].State, qt.IsNotNil) - c.Assert(*nn[4].State, qt.Equals, ipn.NeedsLogin) - - c.Assert(b.State(), qt.Equals, ipn.NeedsLogin) - c.Assert(store.sawWrite(), qt.IsTrue) + // nn[4] is state notification (NeedsLogin) + // Verify logged out and needs new login using backend state + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // A second logout should be a no-op as we are in the NeedsLogin state. t.Logf("\n\nLogout2") notifies.expect(0) - b.Logout(context.Background()) + b.Logout(context.Background(), ipnauth.Self) { notifies.drain(0) cc.assertCalls() c.Assert(b.Prefs().LoggedOut(), qt.IsTrue) c.Assert(b.Prefs().WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify still needs login + c.Assert(needsLogin(b), qt.IsTrue) } // A third logout should also be a no-op as the cc should be in // AuthCantContinue state. t.Logf("\n\nLogout3") notifies.expect(3) - b.Logout(context.Background()) + b.Logout(context.Background(), ipnauth.Self) { notifies.drain(0) cc.assertCalls() c.Assert(b.Prefs().LoggedOut(), qt.IsTrue) c.Assert(b.Prefs().WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify still needs login + c.Assert(needsLogin(b), qt.IsTrue) } // Oh, you thought we were done? Ha! Now we have to test what @@ -649,11 +754,13 @@ func TestStateMachine(t *testing.T) { nn := notifies.drain(2) cc.assertCalls() c.Assert(nn[0].Prefs, qt.IsNotNil) - c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsTrue) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[1].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // Verify notification indicates we need login + c.Assert(nn[1].Prefs == nil || nn[1].Prefs.LoggedOut(), qt.IsTrue) + // Verify we need login after restart + c.Assert(needsLogin(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsFalse) } // Explicitly set the ControlURL to avoid defaulting to [ipn.DefaultControlURL]. @@ -679,7 +786,7 @@ func TestStateMachine(t *testing.T) { // an interactive login URL to visit. notifies.expect(2) url3 := "https://localhost:1/3" - cc.send(nil, url3, false, nil) + cc.send(sendOpt{url: url3}) { nn := notifies.drain(2) cc.assertCalls("Login") @@ -690,9 +797,9 @@ func TestStateMachine(t *testing.T) { notifies.expect(3) cc.persist.UserProfile.LoginName = "user2" cc.persist.NodeID = "node2" - cc.send(nil, "", true, &netmap.NetworkMap{ + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + }}) t.Logf("\n\nLoginFinished3") { nn := notifies.drain(3) @@ -704,8 +811,9 @@ func TestStateMachine(t *testing.T) { c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) // If a user initiates an interactive login, they also expect WantRunning to become true. c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) - c.Assert(nn[2].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[2].State) + // nn[2] is state notification (Starting) - verify using backend state + c.Assert(isWantRunning(b), qt.IsTrue) + c.Assert(isLoggedIn(b), qt.IsTrue) } // Now we've logged in successfully. Let's disconnect. @@ -719,9 +827,9 @@ func TestStateMachine(t *testing.T) { nn := notifies.drain(2) cc.assertCalls("pause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification (Stopped), nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Stopped, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) } @@ -734,23 +842,26 @@ func TestStateMachine(t *testing.T) { // b.Shutdown() explicitly ourselves. previousCC.assertShutdown(false) - // Note: unpause happens because ipn needs to get at least one netmap - // on startup, otherwise UIs can't show the node list, login - // name, etc when in state ipn.Stopped. - // Arguably they shouldn't try. But they currently do. nn := notifies.drain(2) - cc.assertCalls("New", "Login") + // We already have a netmap for this node, + // and WantRunning is false, so cc should be paused. + cc.assertCalls("New", "Login", "pause") c.Assert(nn[0].Prefs, qt.IsNotNil) - c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse) - c.Assert(*nn[1].State, qt.Equals, ipn.Stopped) + // nn[1] is state notification (Stopped) + // Verify backend shows we're not wanting to run + c.Assert(isWantRunning(b), qt.IsFalse) } // When logged in but !WantRunning, ipn leaves us unpaused to retrieve // the first netmap. Simulate that netmap being received, after which // it should pause us, to avoid wasting CPU retrieving unnecessarily - // additional netmap updates. + // additional netmap updates. Since our LocalBackend instance already + // has a netmap, we will reset it to nil to simulate the first netmap + // retrieval. + b.setNetMapLocked(nil) + cc.assertCalls("unpause") // // TODO: really the various GUIs and prefs should be refactored to // not require the netmap structure at all when starting while @@ -758,9 +869,9 @@ func TestStateMachine(t *testing.T) { // the control server at all when stopped). t.Logf("\n\nStart4 -> netmap") notifies.expect(0) - cc.send(nil, "", true, &netmap.NetworkMap{ + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + }}) { notifies.drain(0) cc.assertCalls("pause") @@ -778,9 +889,9 @@ func TestStateMachine(t *testing.T) { nn := notifies.drain(2) cc.assertCalls("Login", "unpause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification (Starting), nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) } // Disconnect. @@ -794,9 +905,9 @@ func TestStateMachine(t *testing.T) { nn := notifies.drain(2) cc.assertCalls("pause") // BUG: I would expect Prefs to change first, and state after. - c.Assert(nn[0].State, qt.IsNotNil) + // nn[0] is state notification (Stopped), nn[1] is prefs notification c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(ipn.Stopped, qt.Equals, *nn[0].State) + c.Assert(nn[1].Prefs.WantRunning(), qt.IsFalse) } // We want to try logging in as a different user, while Stopped. @@ -805,7 +916,7 @@ func TestStateMachine(t *testing.T) { notifies.expect(1) b.StartLoginInteractive(context.Background()) url4 := "https://localhost:1/4" - cc.send(nil, url4, false, nil) + cc.send(sendOpt{url: url4}) { nn := notifies.drain(1) // It might seem like WantRunning should switch to true here, @@ -827,9 +938,9 @@ func TestStateMachine(t *testing.T) { notifies.expect(3) cc.persist.UserProfile.LoginName = "user3" cc.persist.NodeID = "node3" - cc.send(nil, "", true, &netmap.NetworkMap{ + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + }}) { nn := notifies.drain(3) // BUG: pause() being called here is a bad sign. @@ -841,18 +952,19 @@ func TestStateMachine(t *testing.T) { cc.assertCalls("unpause") c.Assert(nn[0].LoginFinished, qt.IsNotNil) c.Assert(nn[1].Prefs, qt.IsNotNil) - c.Assert(nn[2].State, qt.IsNotNil) // Prefs after finishing the login, so LoginName updated. c.Assert(nn[1].Prefs.Persist().UserProfile().LoginName, qt.Equals, "user3") c.Assert(nn[1].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[1].Prefs.WantRunning(), qt.IsTrue) - c.Assert(ipn.Starting, qt.Equals, *nn[2].State) + // nn[2] is state notification (Starting) - verify using backend state + c.Assert(isWantRunning(b), qt.IsTrue) + c.Assert(isLoggedIn(b), qt.IsTrue) } // The last test case is the most common one: restarting when both // logged in and WantRunning. t.Logf("\n\nStart5") - notifies.expect(1) + notifies.expect(2) c.Assert(b.Start(ipn.Options{}), qt.IsNil) { // NOTE: cc.Shutdown() is correct here, since we didn't call @@ -860,59 +972,69 @@ func TestStateMachine(t *testing.T) { previousCC.assertShutdown(false) cc.assertCalls("New", "Login") - nn := notifies.drain(1) + nn := notifies.drain(2) cc.assertCalls() c.Assert(nn[0].Prefs, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[0].Prefs.WantRunning(), qt.IsTrue) - c.Assert(b.State(), qt.Equals, ipn.NoState) + // nn[1] is state notification (Starting) + // Verify we're authenticated with valid netmap using backend state + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsTrue) } // Control server accepts our valid key from before. t.Logf("\n\nLoginFinished5") - notifies.expect(1) - cc.send(nil, "", true, &netmap.NetworkMap{ + notifies.expect(0) + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + }}) { - nn := notifies.drain(1) + notifies.drain(0) cc.assertCalls() // NOTE: No LoginFinished message since no interactive // login was needed. - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) // NOTE: No prefs change this time. WantRunning stays true. // We were in Starting in the first place, so that doesn't - // change either. - c.Assert(ipn.Starting, qt.Equals, b.State()) + // change either, so we don't expect any notifications. + // Verify we're still authenticated with valid netmap + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsTrue) } t.Logf("\n\nExpireKey") notifies.expect(1) - cc.send(nil, "", false, &netmap.NetworkMap{ - Expiry: time.Now().Add(-time.Minute), - SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + cc.send(sendOpt{nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + KeyExpiry: time.Now().Add(-time.Minute), + MachineAuthorized: true, + }).View(), + }}) { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.NeedsLogin, qt.Equals, *nn[0].State) - c.Assert(ipn.NeedsLogin, qt.Equals, b.State()) + // nn[0] is state notification (NeedsLogin) due to key expiry + c.Assert(len(nn), qt.Equals, 1) + // Verify key expired, need new login using backend state + c.Assert(needsLogin(b), qt.IsTrue) c.Assert(b.isEngineBlocked(), qt.IsTrue) } t.Logf("\n\nExtendKey") notifies.expect(1) - cc.send(nil, "", false, &netmap.NetworkMap{ - Expiry: time.Now().Add(time.Minute), - SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + cc.send(sendOpt{nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + MachineAuthorized: true, + KeyExpiry: time.Now().Add(time.Minute), + }).View(), + }}) { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) - c.Assert(ipn.Starting, qt.Equals, b.State()) + // nn[0] is state notification (Starting) after key extension + c.Assert(len(nn), qt.Equals, 1) + // Verify key extended, authenticated again using backend state + c.Assert(isFullyAuthenticated(b), qt.IsTrue) + c.Assert(hasValidNetMap(b), qt.IsTrue) c.Assert(b.isEngineBlocked(), qt.IsFalse) } notifies.expect(1) @@ -921,17 +1043,18 @@ func TestStateMachine(t *testing.T) { { nn := notifies.drain(1) cc.assertCalls() - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Running, qt.Equals, *nn[0].State) - c.Assert(ipn.Running, qt.Equals, b.State()) + // nn[0] is state notification (Running) after DERP connection + c.Assert(len(nn), qt.Equals, 1) + // Verify we can route traffic using backend state + c.Assert(canRouteTraffic(b), qt.IsTrue) } } func TestEditPrefsHasNoKeys(t *testing.T) { logf := tstest.WhileTestRunningLogger(t) - sys := new(tsd.System) + sys := tsd.NewSystem() sys.Set(new(mem.Store)) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) } @@ -942,13 +1065,12 @@ func TestEditPrefsHasNoKeys(t *testing.T) { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(b.Shutdown) b.hostinfo = &tailcfg.Hostinfo{OS: "testos"} b.pm.SetPrefs((&ipn.Prefs{ Persist: &persist.Persist{ PrivateNodeKey: key.NewNode(), OldPrivateNodeKey: key.NewNode(), - - LegacyFrontendPrivateMachineKey: key.NewMachine(), }, }).View(), ipn.NetworkProfile{}) if p := b.pm.CurrentPrefs().Persist(); !p.Valid() || p.PrivateNodeKey().IsZero() { @@ -975,27 +1097,19 @@ func TestEditPrefsHasNoKeys(t *testing.T) { t.Errorf("OldPrivateNodeKey = %v; want zero", p.Persist().OldPrivateNodeKey()) } - if !p.Persist().LegacyFrontendPrivateMachineKey().IsZero() { - t.Errorf("LegacyFrontendPrivateMachineKey = %v; want zero", p.Persist().LegacyFrontendPrivateMachineKey()) - } - if !p.Persist().NetworkLockKey().IsZero() { t.Errorf("NetworkLockKey= %v; want zero", p.Persist().NetworkLockKey()) } } type testStateStorage struct { - mem mem.Store + mem.Store written atomic.Bool } -func (s *testStateStorage) ReadState(id ipn.StateKey) ([]byte, error) { - return s.mem.ReadState(id) -} - func (s *testStateStorage) WriteState(id ipn.StateKey, bs []byte) error { s.written.Store(true) - return s.mem.WriteState(id, bs) + return s.Store.WriteState(id, bs) } // awaitWrite clears the "I've seen writes" bit, in prep for a future @@ -1014,15 +1128,16 @@ func TestWGEngineStatusRace(t *testing.T) { t.Skip("test fails") c := qt.New(t) logf := tstest.WhileTestRunningLogger(t) - sys := new(tsd.System) + sys := tsd.NewSystem() sys.Set(new(mem.Store)) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set) + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.Bus.Get()) c.Assert(err, qt.IsNil) t.Cleanup(eng.Close) sys.Set(eng) b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) c.Assert(err, qt.IsNil) + t.Cleanup(b.Shutdown) var cc *mockControl b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { @@ -1049,9 +1164,9 @@ func TestWGEngineStatusRace(t *testing.T) { wantState(ipn.NeedsLogin) // Assert that we are logged in and authorized. - cc.send(nil, "", true, &netmap.NetworkMap{ + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), - }) + }}) wantState(ipn.Starting) // Simulate multiple concurrent callbacks from wgengine. @@ -1075,3 +1190,827 @@ func TestWGEngineStatusRace(t *testing.T) { wg.Wait() wantState(ipn.Running) } + +// TestEngineReconfigOnStateChange verifies that wgengine is properly reconfigured +// when the LocalBackend's state changes, such as when the user logs in, switches +// profiles, or disconnects from Tailscale. +func TestEngineReconfigOnStateChange(t *testing.T) { + enableLogging := false + connect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + disconnect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: false}, WantRunningSet: true} + node1 := buildNetmapWithPeers( + makePeer(1, withName("node-1"), withAddresses(netip.MustParsePrefix("100.64.1.1/32"))), + ) + node2 := buildNetmapWithPeers( + makePeer(2, withName("node-2"), withAddresses(netip.MustParsePrefix("100.64.1.2/32"))), + ) + node3 := buildNetmapWithPeers( + makePeer(3, withName("node-3"), withAddresses(netip.MustParsePrefix("100.64.1.3/32"))), + node1.SelfNode, + node2.SelfNode, + ) + routesWithQuad100 := func(extra ...netip.Prefix) []netip.Prefix { + return append(extra, netip.MustParsePrefix("100.100.100.100/32")) + } + hostsFor := func(nm *netmap.NetworkMap) map[dnsname.FQDN][]netip.Addr { + var hosts map[dnsname.FQDN][]netip.Addr + appendNode := func(n tailcfg.NodeView) { + addrs := make([]netip.Addr, 0, n.Addresses().Len()) + for _, addr := range n.Addresses().All() { + addrs = append(addrs, addr.Addr()) + } + mak.Set(&hosts, must.Get(dnsname.ToFQDN(n.Name())), addrs) + } + if nm != nil && nm.SelfNode.Valid() { + appendNode(nm.SelfNode) + } + for _, n := range nm.Peers { + appendNode(n) + } + return hosts + } + + tests := []struct { + name string + steps func(*testing.T, *LocalBackend, func() *mockControl) + wantState ipn.State + wantCfg *wgcfg.Config + wantRouterCfg *router.Config + wantDNSCfg *dns.Config + }{ + { + name: "Initial", + // The configs are nil until the the LocalBackend is started. + wantState: ipn.NoState, + wantCfg: nil, + wantRouterCfg: nil, + wantDNSCfg: nil, + }, + { + name: "Start", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + }, + // Once started, all configs must be reset and have their zero values. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + }, + // Same if WantRunning is true, but the auth is not completed yet. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + }, + // After the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Start/Connect/Login/Disconnect", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo2(t)(lb.EditPrefs(disconnect)) + }, + // After disconnecting, all configs must be reset and have their zero values. + wantState: ipn.Stopped, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + }, + // After switching to a new, empty profile, all configs should be reset + // and have their zero values until the auth is completed. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node2) + }, + // Once the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node2.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node2.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node2), + }, + }, + { + name: "Start/Connect/Login/SwitchProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + }, + // After switching to an existing profile, all configs must be reset + // and have their zero values until the (non-interactive) login is completed. + wantState: ipn.NoState, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/SwitchProfile/NonInteractiveLogin", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + cc().authenticated(node1) // complete the login + }, + // After switching profiles and completing the auth, the configs + // must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Start/Connect/Login/WithPeers", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node3) + }, + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + PublicKey: node1.SelfNode.Key(), + DiscoKey: node1.SelfNode.DiscoKey(), + }, + { + PublicKey: node2.SelfNode.Key(), + DiscoKey: node2.SelfNode.DiscoKey(), + }, + }, + Addresses: node3.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node3.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node3), + }, + }, + { + name: "Start/Connect/Login/Expire", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + cc().send(sendOpt{nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + KeyExpiry: time.Now().Add(-time.Minute), + }).View(), + }}) + }, + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/InitReauth", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + + // Start the re-auth process: + lb.StartLoginInteractive(context.Background()) + cc().sendAuthURL(node1) + }, + // Without seamless renewal, even starting a reauth tears down everything: + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/InitReauth/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + + // Start the re-auth process: + lb.StartLoginInteractive(context.Background()) + cc().sendAuthURL(node1) + + // Complete the re-auth process: + cc().authenticated(node1) + }, + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Seamless/Start/Connect/Login/InitReauth", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + lb.ControlKnobs().SeamlessKeyRenewal.Store(true) + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + + // Start the re-auth process: + lb.StartLoginInteractive(context.Background()) + cc().sendAuthURL(node1) + }, + // With seamless renewal, starting a reauth should leave everything up: + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Seamless/Start/Connect/Login/InitReauth/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + lb.ControlKnobs().SeamlessKeyRenewal.Store(true) + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + + // Start the re-auth process: + lb.StartLoginInteractive(context.Background()) + cc().sendAuthURL(node1) + + // Complete the re-auth process: + cc().authenticated(node1) + }, + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Seamless/Start/Connect/Login/Expire", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + lb.ControlKnobs().SeamlessKeyRenewal.Store(true) + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + cc().send(sendOpt{nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + KeyExpiry: time.Now().Add(-time.Minute), + }).View(), + }}) + }, + // Even with seamless, if the key we are using expires, we want to disconnect: + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb, engine, cc := newLocalBackendWithMockEngineAndControl(t, enableLogging) + + if tt.steps != nil { + tt.steps(t, lb, cc) + } + + // TODO(bradfitz): this whole event bus settling thing + // should be unnecessary once the bogus uses of eventbus + // are removed. (https://github.com/tailscale/tailscale/issues/16369) + lb.settleEventBus() + + if gotState := lb.State(); gotState != tt.wantState { + t.Errorf("State: got %v; want %v", gotState, tt.wantState) + } + + if engine.Config() != nil { + for _, p := range engine.Config().Peers { + pKey := p.PublicKey.UntypedHexString() + _, err := lb.MagicConn().ParseEndpoint(pKey) + if err != nil { + t.Errorf("ParseEndpoint(%q) failed: %v", pKey, err) + } + } + } + + opts := []cmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, key.DiscoPublic{}, netip.Addr{}, netip.Prefix{}), + } + if diff := cmp.Diff(tt.wantCfg, engine.Config(), opts...); diff != "" { + t.Errorf("wgcfg.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantRouterCfg, engine.RouterConfig(), opts...); diff != "" { + t.Errorf("router.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantDNSCfg, engine.DNSConfig(), opts...); diff != "" { + t.Errorf("dns.Config(+got -want): %v", diff) + } + }) + } +} + +// TestSendPreservesAuthURL tests that wgengine updates arriving in the middle of +// processing an auth URL doesn't result in the auth URL being cleared. +func TestSendPreservesAuthURL(t *testing.T) { + runTestSendPreservesAuthURL(t, false) +} + +func TestSendPreservesAuthURLSeamless(t *testing.T) { + runTestSendPreservesAuthURL(t, true) +} + +func runTestSendPreservesAuthURL(t *testing.T, seamless bool) { + var cc *mockControl + b := newLocalBackendWithTestControl(t, true, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + cc = newClient(t, opts) + return cc + }) + + t.Logf("Start") + b.Start(ipn.Options{ + UpdatePrefs: &ipn.Prefs{ + WantRunning: true, + ControlURL: "https://localhost:1/", + }, + }) + + t.Logf("LoginFinished") + cc.persist.UserProfile.LoginName = "user1" + cc.persist.NodeID = "node1" + + if seamless { + b.sys.ControlKnobs().SeamlessKeyRenewal.Store(true) + } + + cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), + }}) + + t.Logf("Running") + b.setWgengineStatus(&wgengine.Status{AsOf: time.Now(), DERPs: 1}, nil) + + t.Logf("Re-auth (StartLoginInteractive)") + b.StartLoginInteractive(t.Context()) + + t.Logf("Re-auth (receive URL)") + url1 := "https://localhost:1/1" + cc.send(sendOpt{url: url1}) + + // Don't need to wait on anything else - once .send completes, authURL should + // be set, and once .send has completed, any opportunities for a WG engine + // status update to trample it have ended as well. + if b.authURL == "" { + t.Fatalf("expected authURL to be set") + } else { + t.Log("authURL was set") + } +} + +func buildNetmapWithPeers(self tailcfg.NodeView, peers ...tailcfg.NodeView) *netmap.NetworkMap { + const ( + firstAutoUserID = tailcfg.UserID(10000) + domain = "example.com" + magicDNSSuffix = ".test.ts.net" + ) + + users := make(map[tailcfg.UserID]tailcfg.UserProfileView) + makeUserForNode := func(n *tailcfg.Node) { + var user *tailcfg.UserProfile + if n.User == 0 { + n.User = firstAutoUserID + tailcfg.UserID(n.ID) + user = &tailcfg.UserProfile{ + DisplayName: n.Name, + LoginName: n.Name, + } + } else if _, ok := users[n.User]; !ok { + user = &tailcfg.UserProfile{ + DisplayName: fmt.Sprintf("User %d", n.User), + LoginName: fmt.Sprintf("user-%d", n.User), + } + } + if user != nil { + user.ID = n.User + user.LoginName = strings.Join([]string{user.LoginName, domain}, "@") + users[n.User] = user.View() + } + } + + derpmap := &tailcfg.DERPMap{ + Regions: make(map[int]*tailcfg.DERPRegion), + } + makeDERPRegionForNode := func(n *tailcfg.Node) { + if n.HomeDERP == 0 { + return // no DERP region + } + if _, ok := derpmap.Regions[n.HomeDERP]; !ok { + r := &tailcfg.DERPRegion{ + RegionID: n.HomeDERP, + RegionName: fmt.Sprintf("Region %d", n.HomeDERP), + } + r.Nodes = append(r.Nodes, &tailcfg.DERPNode{ + Name: fmt.Sprintf("%da", n.HomeDERP), + RegionID: n.HomeDERP, + }) + derpmap.Regions[n.HomeDERP] = r + } + } + + updateNode := func(n tailcfg.NodeView) tailcfg.NodeView { + mut := n.AsStruct() + makeUserForNode(mut) + makeDERPRegionForNode(mut) + mut.Name = mut.Name + magicDNSSuffix + return mut.View() + } + + self = updateNode(self) + for i := range peers { + peers[i] = updateNode(peers[i]) + } + + return &netmap.NetworkMap{ + SelfNode: self, + Domain: domain, + Peers: peers, + UserProfiles: users, + DERPMap: derpmap, + } +} + +func mustDo(t *testing.T) func(error) { + t.Helper() + return func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func mustDo2(t *testing.T) func(any, error) { + t.Helper() + return func(_ any, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func newLocalBackendWithMockEngineAndControl(t *testing.T, enableLogging bool) (*LocalBackend, *mockEngine, func() *mockControl) { + t.Helper() + + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + + dialer := &tsdial.Dialer{Logf: logf} + dialer.SetNetMon(netmon.NewStatic()) + + bus := eventbustest.NewBus(t) + sys := tsd.NewSystemWithBus(bus) + sys.Set(dialer) + sys.Set(dialer.NetMon()) + dialer.SetBus(bus) + + magicConn, err := magicsock.NewConn(magicsock.Options{ + Logf: logf, + EventBus: sys.Bus.Get(), + NetMon: dialer.NetMon(), + Metrics: sys.UserMetricsRegistry(), + HealthTracker: sys.HealthTracker.Get(), + DisablePortMapper: true, + }) + if err != nil { + t.Fatalf("NewConn failed: %v", err) + } + magicConn.SetNetworkUp(dialer.NetMon().InterfaceState().AnyInterfaceUp()) + sys.Set(magicConn) + + engine := newMockEngine() + sys.Set(engine) + t.Cleanup(func() { + engine.Close() + <-engine.Done() + }) + + lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + return lb, engine, func() *mockControl { return lb.cc.(*mockControl) } +} + +var _ wgengine.Engine = (*mockEngine)(nil) + +// mockEngine implements [wgengine.Engine]. +type mockEngine struct { + done chan struct{} // closed when Close is called + + mu sync.Mutex // protects all following fields + closed bool + cfg *wgcfg.Config + routerCfg *router.Config + dnsCfg *dns.Config + + filter, jailedFilter *filter.Filter + + statusCb wgengine.StatusCallback +} + +func newMockEngine() *mockEngine { + return &mockEngine{ + done: make(chan struct{}), + } +} + +func (e *mockEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return errors.New("engine closed") + } + e.cfg = cfg + e.routerCfg = routerCfg + e.dnsCfg = dnsCfg + return nil +} + +func (e *mockEngine) Config() *wgcfg.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.cfg +} + +func (e *mockEngine) RouterConfig() *router.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.routerCfg +} + +func (e *mockEngine) DNSConfig() *dns.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.dnsCfg +} + +func (e *mockEngine) PeerForIP(netip.Addr) (_ wgengine.PeerForIP, ok bool) { + return wgengine.PeerForIP{}, false +} + +func (e *mockEngine) GetFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.filter +} + +func (e *mockEngine) SetFilter(f *filter.Filter) { + e.mu.Lock() + e.filter = f + e.mu.Unlock() +} + +func (e *mockEngine) GetJailedFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.jailedFilter +} + +func (e *mockEngine) SetJailedFilter(f *filter.Filter) { + e.mu.Lock() + e.jailedFilter = f + e.mu.Unlock() +} + +func (e *mockEngine) SetStatusCallback(cb wgengine.StatusCallback) { + e.mu.Lock() + e.statusCb = cb + e.mu.Unlock() +} + +func (e *mockEngine) RequestStatus() { + e.mu.Lock() + cb := e.statusCb + e.mu.Unlock() + if cb != nil { + cb(&wgengine.Status{AsOf: time.Now()}, nil) + } +} + +func (e *mockEngine) ResetAndStop() (*wgengine.Status, error) { + err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) + if err != nil { + return nil, err + } + return &wgengine.Status{AsOf: time.Now()}, nil +} + +func (e *mockEngine) PeerByKey(key.NodePublic) (_ wgint.Peer, ok bool) { + return wgint.Peer{}, false +} + +func (e *mockEngine) SetNetworkMap(*netmap.NetworkMap) {} + +func (e *mockEngine) UpdateStatus(*ipnstate.StatusBuilder) {} + +func (e *mockEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { + cb(&ipnstate.PingResult{IP: ip.String(), Err: "not implemented"}) +} + +func (e *mockEngine) InstallCaptureHook(packet.CaptureCallback) {} + +func (e *mockEngine) Close() { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return + } + e.closed = true + close(e.done) +} + +func (e *mockEngine) Done() <-chan struct{} { + return e.done +} + +// hasValidNetMap returns true if the backend has a valid network map with a valid self node. +func hasValidNetMap(b *LocalBackend) bool { + nm := b.NetMap() + return nm != nil && nm.SelfNode.Valid() +} + +// needsLogin returns true if the backend needs user login action. +// This is true when logged out, when an auth URL is present (interactive login in progress), +// or when the node key has expired. +func needsLogin(b *LocalBackend) bool { + // Note: b.Prefs() handles its own locking, so we lock only for authURL and keyExpired access + b.mu.Lock() + authURL := b.authURL + keyExpired := b.keyExpired + b.mu.Unlock() + return b.Prefs().LoggedOut() || authURL != "" || keyExpired +} + +// needsMachineAuth returns true if the user has logged in but the machine is not yet authorized. +// This includes the case where we have a netmap but no valid SelfNode yet (empty netmap after initial login). +func needsMachineAuth(b *LocalBackend) bool { + // Note: b.NetMap() and b.Prefs() handle their own locking + nm := b.NetMap() + prefs := b.Prefs() + if prefs.LoggedOut() || nm == nil { + return false + } + // If we have a valid SelfNode, check its MachineAuthorized status + if nm.SelfNode.Valid() { + return !nm.SelfNode.MachineAuthorized() + } + // Empty netmap (no SelfNode yet) after login also means we need machine auth + return true +} + +// hasAuthURL returns true if an authentication URL is present (user needs to visit a URL). +func hasAuthURL(b *LocalBackend) bool { + b.mu.Lock() + authURL := b.authURL + b.mu.Unlock() + return authURL != "" +} + +// canRouteTraffic returns true if the backend is capable of routing traffic. +// This requires a valid netmap, machine authorization, and WantRunning preference. +func canRouteTraffic(b *LocalBackend) bool { + // Note: b.NetMap() and b.Prefs() handle their own locking + nm := b.NetMap() + prefs := b.Prefs() + return nm != nil && + nm.SelfNode.Valid() && + nm.SelfNode.MachineAuthorized() && + prefs.WantRunning() +} + +// isFullyAuthenticated returns true if the user has completed login and no auth URL is pending. +func isFullyAuthenticated(b *LocalBackend) bool { + // Note: b.Prefs() handles its own locking, so we lock only for authURL access + b.mu.Lock() + authURL := b.authURL + b.mu.Unlock() + return !b.Prefs().LoggedOut() && authURL == "" +} + +// isWantRunning returns true if the WantRunning preference is set. +func isWantRunning(b *LocalBackend) bool { + return b.Prefs().WantRunning() +} + +// isLoggedIn returns true if the user is logged in (not logged out). +func isLoggedIn(b *LocalBackend) bool { + return !b.Prefs().LoggedOut() +} diff --git a/ipn/ipnlocal/taildrop.go b/ipn/ipnlocal/taildrop.go deleted file mode 100644 index db7d8e12a..000000000 --- a/ipn/ipnlocal/taildrop.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "maps" - "slices" - "strings" - - "tailscale.com/ipn" -) - -// UpdateOutgoingFiles updates b.outgoingFiles to reflect the given updates and -// sends an ipn.Notify with the full list of outgoingFiles. -func (b *LocalBackend) UpdateOutgoingFiles(updates map[string]*ipn.OutgoingFile) { - b.mu.Lock() - if b.outgoingFiles == nil { - b.outgoingFiles = make(map[string]*ipn.OutgoingFile, len(updates)) - } - maps.Copy(b.outgoingFiles, updates) - outgoingFiles := make([]*ipn.OutgoingFile, 0, len(b.outgoingFiles)) - for _, file := range b.outgoingFiles { - outgoingFiles = append(outgoingFiles, file) - } - b.mu.Unlock() - slices.SortFunc(outgoingFiles, func(a, b *ipn.OutgoingFile) int { - t := a.Started.Compare(b.Started) - if t != 0 { - return t - } - return strings.Compare(a.Name, b.Name) - }) - b.send(ipn.Notify{OutgoingFiles: outgoingFiles}) -} diff --git a/ipn/ipnlocal/tailnetlock_disabled.go b/ipn/ipnlocal/tailnetlock_disabled.go new file mode 100644 index 000000000..85cf4bd3f --- /dev/null +++ b/ipn/ipnlocal/tailnetlock_disabled.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_tailnetlock + +package ipnlocal + +import ( + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tka" + "tailscale.com/types/netmap" +) + +type tkaState struct { + authority *tka.Authority +} + +func (b *LocalBackend) initTKALocked() error { + return nil +} + +func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsView) error { + return nil +} + +func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) {} + +func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { + return &ipnstate.NetworkLockStatus{Enabled: false} +} diff --git a/ipn/ipnlocal/web_client.go b/ipn/ipnlocal/web_client.go index ccde9f01d..a3c9387e4 100644 --- a/ipn/ipnlocal/web_client.go +++ b/ipn/ipnlocal/web_client.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ios && !android +//go:build !ios && !android && !ts_omit_webclient package ipnlocal @@ -17,16 +17,17 @@ import ( "sync" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/web" - "tailscale.com/logtail/backoff" "tailscale.com/net/netutil" "tailscale.com/tailcfg" + "tailscale.com/tsconst" "tailscale.com/types/logger" + "tailscale.com/util/backoff" "tailscale.com/util/mak" ) -const webClientPort = web.ListenPort +const webClientPort = tsconst.WebListenPort // webClient holds state for the web interface for managing this // tailscale instance. The web interface is not used by default, @@ -36,16 +37,16 @@ type webClient struct { server *web.Server // or nil, initialized lazily - // lc optionally specifies a LocalClient to use to connect + // lc optionally specifies a local.Client to use to connect // to the localapi for this tailscaled instance. // If nil, a default is used. - lc *tailscale.LocalClient + lc *local.Client } // ConfigureWebClient configures b.web prior to use. -// Specifially, it sets b.web.lc to the provided LocalClient. +// Specifially, it sets b.web.lc to the provided local.Client. // If provided as nil, b.web.lc is cleared out. -func (b *LocalBackend) ConfigureWebClient(lc *tailscale.LocalClient) { +func (b *LocalBackend) ConfigureWebClient(lc *local.Client) { b.webClient.mu.Lock() defer b.webClient.mu.Unlock() b.webClient.lc = lc @@ -116,13 +117,14 @@ func (b *LocalBackend) handleWebClientConn(c net.Conn) error { // for each of the local device's Tailscale IP addresses. This is needed to properly // route local traffic when using kernel networking mode. func (b *LocalBackend) updateWebClientListenersLocked() { - if b.netMap == nil { + nm := b.currentNode().NetMap() + if nm == nil { return } - addrs := b.netMap.GetAddresses() - for i := range addrs.Len() { - addrPort := netip.AddrPortFrom(addrs.At(i).Addr(), webClientPort) + addrs := nm.GetAddresses() + for _, pfx := range addrs.All() { + addrPort := netip.AddrPortFrom(pfx.Addr(), webClientPort) if _, ok := b.webClientListeners[addrPort]; ok { continue // already listening } diff --git a/ipn/ipnlocal/web_client_stub.go b/ipn/ipnlocal/web_client_stub.go index 1dfc8c27c..787867b4f 100644 --- a/ipn/ipnlocal/web_client_stub.go +++ b/ipn/ipnlocal/web_client_stub.go @@ -1,22 +1,20 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build ios || android +//go:build ios || android || ts_omit_webclient package ipnlocal import ( "errors" "net" - - "tailscale.com/client/tailscale" ) const webClientPort = 5252 type webClient struct{} -func (b *LocalBackend) ConfigureWebClient(lc *tailscale.LocalClient) {} +func (b *LocalBackend) ConfigureWebClient(any) {} func (b *LocalBackend) webClientGetOrInit() error { return errors.New("not implemented") diff --git a/ipn/ipnserver/actor.go b/ipn/ipnserver/actor.go index 761c9816c..628e3c37c 100644 --- a/ipn/ipnserver/actor.go +++ b/ipn/ipnserver/actor.go @@ -12,6 +12,7 @@ import ( "runtime" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" "tailscale.com/types/logger" @@ -31,7 +32,13 @@ type actor struct { logf logger.Logf ci *ipnauth.ConnIdentity - isLocalSystem bool // whether the actor is the Windows' Local System identity. + clientID ipnauth.ClientID + userID ipn.WindowsUserID // cached Windows user ID of the connected client process. + // accessOverrideReason specifies the reason for overriding certain access restrictions, + // such as permitting a user to disconnect when the always-on mode is enabled, + // provided that such justification is allowed by the policy. + accessOverrideReason string + isLocalSystem bool // whether the actor is the Windows' Local System identity. } func newActor(logf logger.Logf, c net.Conn) (*actor, error) { @@ -39,7 +46,62 @@ func newActor(logf logger.Logf, c net.Conn) (*actor, error) { if err != nil { return nil, err } - return &actor{logf: logf, ci: ci, isLocalSystem: connIsLocalSystem(ci)}, nil + var clientID ipnauth.ClientID + if pid := ci.Pid(); pid != 0 { + // Derive [ipnauth.ClientID] from the PID of the connected client process. + // TODO(nickkhyl): This is transient and will be re-worked as we + // progress on tailscale/corp#18342. At minimum, we should use a 2-tuple + // (PID + StartTime) or a 3-tuple (PID + StartTime + UID) to identify + // the client process. This helps prevent security issues where a + // terminated client process's PID could be reused by a different + // process. This is not currently an issue as we allow only one user to + // connect anyway. + // Additionally, we should consider caching authentication results since + // operations like retrieving a username by SID might require network + // connectivity on domain-joined devices and/or be slow. + clientID = ipnauth.ClientIDFrom(pid) + } + return &actor{ + logf: logf, + ci: ci, + clientID: clientID, + userID: ci.WindowsUserID(), + isLocalSystem: connIsLocalSystem(ci), + }, + nil +} + +// actorWithAccessOverride returns a new actor that carries the specified +// reason for overriding certain access restrictions, if permitted by the +// policy. If the reason is "", it returns the base actor. +func actorWithAccessOverride(baseActor *actor, reason string) *actor { + if reason == "" { + return baseActor + } + return &actor{ + logf: baseActor.logf, + ci: baseActor.ci, + clientID: baseActor.clientID, + userID: baseActor.userID, + accessOverrideReason: reason, + isLocalSystem: baseActor.isLocalSystem, + } +} + +// CheckProfileAccess implements [ipnauth.Actor]. +func (a *actor) CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ipnauth.ProfileAccess, auditLogger ipnauth.AuditLogFunc) error { + // TODO(nickkhyl): return errors of more specific types and have them + // translated to the appropriate HTTP status codes in the API handler. + if profile.LocalUserID() != a.UserID() { + return errors.New("the target profile does not belong to the user") + } + switch requestedAccess { + case ipnauth.Disconnect: + // Disconnect is allowed if a user owns the profile and the policy permits it. + return ipnauth.CheckDisconnectPolicy(a, profile, a.accessOverrideReason, auditLogger) + default: + return errors.New("the requested operation is not allowed") + } } // IsLocalSystem implements [ipnauth.Actor]. @@ -54,13 +116,21 @@ func (a *actor) IsLocalAdmin(operatorUID string) bool { // UserID implements [ipnauth.Actor]. func (a *actor) UserID() ipn.WindowsUserID { - return a.ci.WindowsUserID() + return a.userID } func (a *actor) pid() int { return a.ci.Pid() } +// ClientID implements [ipnauth.Actor]. +func (a *actor) ClientID() (_ ipnauth.ClientID, ok bool) { + return a.clientID, a.clientID != ipnauth.NoClientID +} + +// Context implements [ipnauth.Actor]. +func (a *actor) Context() context.Context { return context.Background() } + // Username implements [ipnauth.Actor]. func (a *actor) Username() (string, error) { if a.ci == nil { @@ -75,8 +145,12 @@ func (a *actor) Username() (string, error) { } defer tok.Close() return tok.Username() - case "darwin", "linux": - uid, ok := a.ci.Creds().UserID() + case "darwin", "linux", "illumos", "solaris", "openbsd": + creds := a.ci.Creds() + if creds == nil { + return "", errors.New("peer credentials not implemented on this OS") + } + uid, ok := creds.UserID() if !ok { return "", errors.New("missing user ID") } @@ -91,11 +165,11 @@ func (a *actor) Username() (string, error) { } type actorOrError struct { - actor *actor + actor ipnauth.Actor err error } -func (a actorOrError) unwrap() (*actor, error) { +func (a actorOrError) unwrap() (ipnauth.Actor, error) { return a.actor, a.err } @@ -110,9 +184,15 @@ func contextWithActor(ctx context.Context, logf logger.Logf, c net.Conn) context return actorKey.WithValue(ctx, actorOrError{actor: actor, err: err}) } -// actorFromContext returns an [actor] associated with ctx, +// NewContextWithActorForTest returns a new context that carries the identity +// of the specified actor. It is used in tests only. +func NewContextWithActorForTest(ctx context.Context, actor ipnauth.Actor) context.Context { + return actorKey.WithValue(ctx, actorOrError{actor: actor}) +} + +// actorFromContext returns an [ipnauth.Actor] associated with ctx, // or an error if the context does not carry an actor's identity. -func actorFromContext(ctx context.Context) (*actor, error) { +func actorFromContext(ctx context.Context) (ipnauth.Actor, error) { return actorKey.Value(ctx).unwrap() } @@ -158,6 +238,11 @@ func connIsLocalAdmin(logf logger.Logf, ci *ipnauth.ConnIdentity, operatorUID st // Linux. fallthrough case "linux": + if !buildfeatures.HasUnixSocketIdentity { + // Everybody is an admin if support for unix socket identities + // is omitted for the build. + return true + } uid, ok := ci.Creds().UserID() if !ok { return false diff --git a/ipn/ipnserver/proxyconnect.go b/ipn/ipnserver/proxyconnect.go index 1094a79f9..7d41273bd 100644 --- a/ipn/ipnserver/proxyconnect.go +++ b/ipn/ipnserver/proxyconnect.go @@ -10,11 +10,13 @@ import ( "net" "net/http" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/logpolicy" ) // handleProxyConnectConn handles a CONNECT request to -// log.tailscale.io (or whatever the configured log server is). This +// log.tailscale.com (or whatever the configured log server is). This // is intended for use by the Windows GUI client to log via when an // exit node is in use, so the logs don't go out via the exit node and // instead go directly, like tailscaled's. The dialer tried to do that @@ -23,6 +25,10 @@ import ( // precludes that from working and instead the GUI fails to dial out. // So, go through tailscaled (with a CONNECT request) instead. func (s *Server) handleProxyConnectConn(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasOutboundProxy { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } ctx := r.Context() if r.Method != "CONNECT" { panic("[unexpected] miswired") diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 73b5e82ab..d473252e1 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -7,6 +7,7 @@ package ipnserver import ( "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -14,22 +15,27 @@ import ( "net" "net/http" "os/user" + "runtime" "strconv" "strings" "sync" "sync/atomic" "unicode" + "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" - "tailscale.com/ipn" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/localapi" "tailscale.com/net/netmon" "tailscale.com/types/logger" "tailscale.com/types/logid" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/set" - "tailscale.com/util/systemd" + "tailscale.com/util/testenv" ) // Server is an IPN backend and its set of 0 or more active localhost @@ -37,20 +43,14 @@ import ( type Server struct { lb atomic.Pointer[ipnlocal.LocalBackend] logf logger.Logf + bus *eventbus.Bus netMon *netmon.Monitor // must be non-nil backendLogID logid.PublicID - // resetOnZero is whether to call bs.Reset on transition from - // 1->0 active HTTP requests. That is, this is whether the backend is - // being run in "client mode" that requires an active GUI - // connection (such as on Windows by default). Even if this - // is true, the ForceDaemon pref can override this. - resetOnZero bool // mu guards the fields that follow. // lock order: mu, then LocalBackend.mu mu sync.Mutex - lastUserID ipn.WindowsUserID // tracks last userid; on change, Reset state for paranoia - activeReqs map[*http.Request]*actor + activeReqs map[*http.Request]ipnauth.Actor backendWaiter waiterSet // of LocalBackend waiters zeroReqWaiter waiterSet // of blockUntilZeroConnections waiters } @@ -122,6 +122,10 @@ func (s *Server) awaitBackend(ctx context.Context) (_ *ipnlocal.LocalBackend, ok // This is primarily for the Windows GUI, because wintun can take awhile to // come up. See https://github.com/tailscale/tailscale/issues/6522. func (s *Server) serveServerStatus(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug && runtime.GOOS != "windows" { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotFound) + return + } ctx := r.Context() w.Header().Set("Content-Type", "application/json") @@ -194,10 +198,28 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { defer onDone() if strings.HasPrefix(r.URL.Path, "/localapi/") { - lah := localapi.NewHandler(lb, s.logf, s.backendLogID) - lah.PermitRead, lah.PermitWrite = ci.Permissions(lb.OperatorUserID()) - lah.PermitCert = ci.CanFetchCerts() - lah.Actor = ci + if actor, ok := ci.(*actor); ok { + reason, err := base64.StdEncoding.DecodeString(r.Header.Get(apitype.RequestReasonHeader)) + if err != nil { + http.Error(w, "invalid reason header", http.StatusBadRequest) + return + } + ci = actorWithAccessOverride(actor, string(reason)) + } + + lah := localapi.NewHandler(localapi.HandlerConfig{ + Actor: ci, + Backend: lb, + Logf: s.logf, + LogID: s.backendLogID, + EventBus: lb.Sys().Bus.Get(), + }) + if actor, ok := ci.(*actor); ok { + lah.PermitRead, lah.PermitWrite = actor.Permissions(lb.OperatorUserID()) + lah.PermitCert = actor.CanFetchCerts() + } else if testenv.InTest() { + lah.PermitRead, lah.PermitWrite = true, true + } lah.ServeHTTP(w, r) return } @@ -230,11 +252,11 @@ func (e inUseOtherUserError) Unwrap() error { return e.error } // The returned error, when non-nil, will be of type inUseOtherUserError. // // s.mu must be held. -func (s *Server) checkConnIdentityLocked(ci *actor) error { +func (s *Server) checkConnIdentityLocked(ci ipnauth.Actor) error { // If clients are already connected, verify they're the same user. // This mostly matters on Windows at the moment. if len(s.activeReqs) > 0 { - var active *actor + var active ipnauth.Actor for _, active = range s.activeReqs { break } @@ -251,7 +273,9 @@ func (s *Server) checkConnIdentityLocked(ci *actor) error { if username, err := active.Username(); err == nil { fmt.Fprintf(&b, " by %s", username) } - fmt.Fprintf(&b, ", pid %d", active.pid()) + if active, ok := active.(*actor); ok { + fmt.Fprintf(&b, ", pid %d", active.pid()) + } return inUseOtherUserError{errors.New(b.String())} } } @@ -267,7 +291,7 @@ func (s *Server) checkConnIdentityLocked(ci *actor) error { // // This is primarily used for the Windows GUI, to block until one user's done // controlling the tailscaled process. -func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor *actor) error { +func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor ipnauth.Actor) error { inUse := func() bool { s.mu.Lock() defer s.mu.Unlock() @@ -277,7 +301,18 @@ func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor *actor) erro for inUse() { // Check whenever the connection count drops down to zero. ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) - <-ready + if inUse() { + // If the server was in use at the time of the initial check, + // but disconnected and was removed from the activeReqs map + // by the time we registered a waiter, the ready channel + // will never be closed, resulting in a deadlock. To avoid + // this, we can check again after registering the waiter. + // + // This method is planned for complete removal as part of the + // multi-user improvements in tailscale/corp#18342, + // and this approach should be fine as a temporary solution. + <-ready + } cleanup() if err := ctx.Err(); err != nil { return err @@ -291,6 +326,13 @@ func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor *actor) erro // Unix-like platforms and specifies the ID of a local user // (in the os/user.User.Uid string form) who is allowed // to operate tailscaled without being root or using sudo. +// +// Sandboxed macos clients must directly supply, or be able to read, +// an explicit token. Permission is inferred by validating that +// token. Sandboxed macos clients also don't use ipnserver.actor at all +// (and prior to that, they didn't use ipnauth.ConnIdentity) +// +// See safesocket and safesocket_darwin. func (a *actor) Permissions(operatorUID string) (read, write bool) { switch envknob.GOOS() { case "windows": @@ -303,7 +345,7 @@ func (a *actor) Permissions(operatorUID string) (read, write bool) { // checks here. Note that this permission model is being changed in // tailscale/corp#18342. return true, true - case "js": + case "js", "plan9": return true, true } if a.ci.IsUnixSock() { @@ -346,6 +388,9 @@ func isAllDigit(s string) bool { // connection. It's intended to give your non-root webserver access // (www-data, caddy, nginx, etc) to certs. func (a *actor) CanFetchCerts() bool { + if !buildfeatures.HasACME { + return false + } if a.ci.IsUnixSock() && a.ci.Creds() != nil { connUID, ok := a.ci.Creds().UserID() if ok && connUID == userIDFromString(envknob.String("TS_PERMIT_CERT_UID")) { @@ -361,23 +406,17 @@ func (a *actor) CanFetchCerts() bool { // The returned error may be of type [inUseOtherUserError]. // // onDone must be called when the HTTP request is done. -func (s *Server) addActiveHTTPRequest(req *http.Request, actor *actor) (onDone func(), err error) { +func (s *Server) addActiveHTTPRequest(req *http.Request, actor ipnauth.Actor) (onDone func(), err error) { + if runtime.GOOS != "windows" && !buildfeatures.HasUnixSocketIdentity { + return func() {}, nil + } + if actor == nil { return nil, errors.New("internal error: nil actor") } lb := s.mustBackend() - // If the connected user changes, reset the backend server state to make - // sure node keys don't leak between users. - var doReset bool - defer func() { - if doReset { - s.logf("identity changed; resetting server") - lb.ResetForClientDisconnect() - } - }() - s.mu.Lock() defer s.mu.Unlock() @@ -392,40 +431,25 @@ func (s *Server) addActiveHTTPRequest(req *http.Request, actor *actor) (onDone f // Tell the LocalBackend about the identity we're now running as, // unless its the SYSTEM user. That user is not a real account and // doesn't have a home directory. - uid, err := lb.SetCurrentUser(actor) - if err != nil { - return nil, err - } - if s.lastUserID != uid { - if s.lastUserID != "" { - doReset = true - } - s.lastUserID = uid - } + lb.SetCurrentUser(actor) } } onDone = func() { s.mu.Lock() + defer s.mu.Unlock() delete(s.activeReqs, req) - remain := len(s.activeReqs) - s.mu.Unlock() - - if remain == 0 && s.resetOnZero { - if lb.InServerMode() { - s.logf("client disconnected; staying alive in server mode") - } else { - s.logf("client disconnected; stopping server") - lb.ResetForClientDisconnect() - } + if len(s.activeReqs) != 0 { + // The server is not idle yet. + return } - // Wake up callers waiting for the server to be idle: - if remain == 0 { - s.mu.Lock() - s.zeroReqWaiter.wakeAll() - s.mu.Unlock() + if envknob.GOOS() == "windows" && !actor.IsLocalSystem() { + lb.SetCurrentUser(nil) } + + // Wake up callers waiting for the server to be idle: + s.zeroReqWaiter.wakeAll() } return onDone, nil @@ -437,15 +461,15 @@ func (s *Server) addActiveHTTPRequest(req *http.Request, actor *actor) (onDone f // // At some point, either before or after Run, the Server's SetLocalBackend // method must also be called before Server can do anything useful. -func New(logf logger.Logf, logID logid.PublicID, netMon *netmon.Monitor) *Server { +func New(logf logger.Logf, logID logid.PublicID, bus *eventbus.Bus, netMon *netmon.Monitor) *Server { if netMon == nil { panic("nil netMon") } return &Server{ backendLogID: logID, logf: logf, + bus: bus, netMon: netMon, - resetOnZero: envknob.GOOS() == "windows", } } @@ -486,17 +510,25 @@ func (s *Server) Run(ctx context.Context, ln net.Listener) error { runDone := make(chan struct{}) defer close(runDone) - // When the context is closed or when we return, whichever is first, close our listener + ec := s.bus.Client("ipnserver.Server") + defer ec.Close() + shutdownSub := eventbus.Subscribe[localapi.Shutdown](ec) + + // When the context is closed, a [localapi.Shutdown] event is received, + // or when we return, whichever is first, close our listener // and all open connections. go func() { select { + case <-shutdownSub.Events(): case <-ctx.Done(): case <-runDone: } ln.Close() }() - systemd.Ready() + if ready, ok := feature.HookSystemdReady.GetOk(); ok { + ready() + } hs := &http.Server{ Handler: http.HandlerFunc(s.serveHTTP), @@ -519,6 +551,10 @@ func (s *Server) Run(ctx context.Context, ln net.Listener) error { // Windows and via $DEBUG_LISTENER/debug/ipn when tailscaled's --debug flag // is used to run a debug server. func (s *Server) ServeHTMLStatus(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotFound) + return + } lb := s.lb.Load() if lb == nil { http.Error(w, "no LocalBackend", http.StatusServiceUnavailable) diff --git a/ipn/ipnserver/server_fortest.go b/ipn/ipnserver/server_fortest.go new file mode 100644 index 000000000..9aab3b276 --- /dev/null +++ b/ipn/ipnserver/server_fortest.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "net/http" + + "tailscale.com/ipn/ipnauth" +) + +// BlockWhileInUseByOtherForTest blocks while the actor can't connect to the server because +// the server is in use by a different actor. It is used in tests only. +func (s *Server) BlockWhileInUseByOtherForTest(ctx context.Context, actor ipnauth.Actor) error { + return s.blockWhileIdentityInUse(ctx, actor) +} + +// BlockWhileInUseForTest blocks until the server becomes idle (no active requests), +// or the specified context is done. It returns the context's error if it is done. +// It is used in tests only. +func (s *Server) BlockWhileInUseForTest(ctx context.Context) error { + ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) + + s.mu.Lock() + busy := len(s.activeReqs) != 0 + s.mu.Unlock() + + if busy { + <-ready + } + cleanup() + return ctx.Err() +} + +// ServeHTTPForTest responds to a single LocalAPI HTTP request. +// The request's context carries the actor that made the request +// and can be created with [NewContextWithActorForTest]. +// It is used in tests only. +func (s *Server) ServeHTTPForTest(w http.ResponseWriter, r *http.Request) { + s.serveHTTP(w, r) +} diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index b7d5ea144..713db9e50 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -1,46 +1,329 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package ipnserver +package ipnserver_test import ( "context" + "errors" + "runtime" + "strconv" "sync" "testing" + + "tailscale.com/client/local" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/lapitest" + "tailscale.com/tsd" + "tailscale.com/types/ptr" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policytest" ) -func TestWaiterSet(t *testing.T) { - var s waiterSet +func TestUserConnectDisconnectNonWindows(t *testing.T) { + enableLogging := false + if runtime.GOOS == "windows" { + setGOOSForTest(t, "linux") + } + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + // UserA connects and starts watching the IPN bus. + clientA := server.ClientWithName("UserA") + watcherA, _ := clientA.WatchIPNBus(ctx, 0) + + // The concept of "current user" is only relevant on Windows + // and it should not be set on non-Windows platforms. + server.CheckCurrentUser(nil) + + // Additionally, a different user should be able to connect and use the LocalAPI. + clientB := server.ClientWithName("UserB") + if _, gotErr := clientB.Status(ctx); gotErr != nil { + t.Fatalf("Status(%q): want nil; got %v", clientB.Username(), gotErr) + } + + // Watching the IPN bus should also work for UserB. + watcherB, _ := clientB.WatchIPNBus(ctx, 0) + + // And if we send a notification, both users should receive it. + wantErrMessage := "test error" + testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)} + server.Backend().DebugNotify(testNotify) + + if n, err := watcherA.Next(); err != nil { + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.Username(), err) + } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.Username(), wantErrMessage, gotErrMessage) + } + + if n, err := watcherB.Next(); err != nil { + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.Username(), err) + } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.Username(), wantErrMessage, gotErrMessage) + } +} + +func TestUserConnectDisconnectOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + client := server.ClientWithName("User") + _, cancelWatcher := client.WatchIPNBus(ctx, 0) + + // On Windows, however, the current user should be set to the user that connected. + server.CheckCurrentUser(client.Actor) + + // Cancel the IPN bus watcher request and wait for the server to unblock. + cancelWatcher() + server.BlockWhileInUse(ctx) + + // The current user should not be set after a disconnect, as no one is + // currently using the server. + server.CheckCurrentUser(nil) +} + +func TestIPNAlreadyInUseOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + // UserA connects and starts watching the IPN bus. + clientA := server.ClientWithName("UserA") + clientA.WatchIPNBus(ctx, 0) + + // While UserA is connected, UserB should not be able to connect. + clientB := server.ClientWithName("UserB") + if _, gotErr := clientB.Status(ctx); gotErr == nil { + t.Fatalf("Status(%q): want error; got nil", clientB.Username()) + } else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError { + t.Fatalf("Status(%q): want %q; got %q", clientB.Username(), wantError, gotErr.Error()) + } + + // Current user should still be UserA. + server.CheckCurrentUser(clientA.Actor) +} + +func TestSequentialOSUserSwitchingOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + connectDisconnectAsUser := func(name string) { + // User connects and starts watching the IPN bus. + client := server.ClientWithName(name) + watcher, cancelWatcher := client.WatchIPNBus(ctx, 0) + defer cancelWatcher() + go pumpIPNBus(watcher) + + // It should be the current user from the LocalBackend's perspective... + server.CheckCurrentUser(client.Actor) + // until it disconnects. + cancelWatcher() + server.BlockWhileInUse(ctx) + // Now, the current user should be unset. + server.CheckCurrentUser(nil) + } + + // UserA logs in, uses Tailscale for a bit, then logs out. + connectDisconnectAsUser("UserA") + // Same for UserB. + connectDisconnectAsUser("UserB") +} + +func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + connectDisconnectAsUser := func(name string) { + // User connects and starts watching the IPN bus. + client := server.ClientWithName(name) + watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState) + defer cancelWatcher() - wantLen := func(want int, when string) { - t.Helper() - if got := len(s); got != want { - t.Errorf("%s: len = %v; want %v", when, got, want) + runtime.Gosched() + + // Get the current user from the LocalBackend's perspective + // as soon as we're connected. + gotUID, gotActor := server.Backend().CurrentUserForTest() + + // Wait for the first notification to arrive. + // It will either be the initial state we've requested via [ipn.NotifyInitialState], + // returned by an actual handler, or a "fake" notification sent by the server + // itself to indicate that it is being used by someone else. + n, err := watcher.Next() + if err != nil { + t.Fatal(err) + } + + // If our user lost the race and the IPN is in use by another user, + // we should just return. For the sake of this test, we're not + // interested in waiting for the server to become idle. + if n.State != nil && *n.State == ipn.InUseOtherUser { + return + } + + // Otherwise, our user should have been the current user since the time we connected. + if gotUID != client.Actor.UserID() { + t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.Actor.UserID()) + return + } + if hasActor := gotActor != nil; !hasActor || gotActor != client.Actor { + t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.Actor) + return } + + // And should still be the current user (as they're still connected)... + server.CheckCurrentUser(client.Actor) } - wantLen(0, "initial") - var mu sync.Mutex - ctx, cancel := context.WithCancel(context.Background()) - ready, cleanup := s.add(&mu, ctx) - wantLen(1, "after add") + numIterations := 10 + for range numIterations { + numGoRoutines := 100 + var wg sync.WaitGroup + wg.Add(numGoRoutines) + for i := range numGoRoutines { + // User logs in, uses Tailscale for a bit, then logs out + // in parallel with other users doing the same. + go func() { + defer wg.Done() + connectDisconnectAsUser("User-" + strconv.Itoa(i)) + }() + } + wg.Wait() - select { - case <-ready: - t.Fatal("should not be ready") - default: + if err := server.BlockWhileInUse(ctx); err != nil { + t.Fatalf("BlockUntilIdle: %v", err) + } + + server.CheckCurrentUser(nil) } - s.wakeAll() - <-ready +} + +func TestBlockWhileIdentityInUse(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + // connectWaitDisconnectAsUser connects as a user with the specified name + // and keeps the IPN bus watcher alive until the context is canceled. + // It returns a channel that is closed when done. + connectWaitDisconnectAsUser := func(ctx context.Context, name string) <-chan struct{} { + client := server.ClientWithName(name) + watcher, cancelWatcher := client.WatchIPNBus(ctx, 0) - wantLen(1, "after fire") - cleanup() - wantLen(0, "after cleanup") + done := make(chan struct{}) + go func() { + defer cancelWatcher() + defer close(done) + for { + _, err := watcher.Next() + if err != nil { + // There's either an error or the request has been canceled. + break + } + } + }() + return done + } + + for range 100 { + // Connect as UserA, and keep the connection alive + // until disconnectUserA is called. + userAContext, disconnectUserA := context.WithCancel(ctx) + userADone := connectWaitDisconnectAsUser(userAContext, "UserA") + disconnectUserA() + // Check if userB can connect. Calling it directly increases + // the likelihood of triggering a deadlock due to a race condition + // in blockWhileIdentityInUse. But the issue also occurs during + // the normal execution path when UserB connects to the IPN server + // while UserA is disconnecting. + userB := server.MakeTestActor("UserB", "ClientB") + server.BlockWhileInUseByOther(ctx, userB) + <-userADone + } +} - // And again but on an already-expired ctx. - cancel() - ready, cleanup = s.add(&mu, ctx) - <-ready // shouldn't block - cleanup() - wantLen(0, "at end") +func TestShutdownViaLocalAPI(t *testing.T) { + t.Parallel() + + errAccessDeniedByPolicy := errors.New("Access denied: shutdown access denied by policy") + + tests := []struct { + name string + allowTailscaledRestart *bool + wantErr error + }{ + { + name: "AllowTailscaledRestart/NotConfigured", + allowTailscaledRestart: nil, + wantErr: errAccessDeniedByPolicy, + }, + { + name: "AllowTailscaledRestart/False", + allowTailscaledRestart: ptr.To(false), + wantErr: errAccessDeniedByPolicy, + }, + { + name: "AllowTailscaledRestart/True", + allowTailscaledRestart: ptr.To(true), + wantErr: nil, // shutdown should be allowed + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sys := tsd.NewSystem() + + var pol policytest.Config + if tt.allowTailscaledRestart != nil { + pol.Set(pkey.AllowTailscaledRestart, *tt.allowTailscaledRestart) + } + sys.Set(pol) + + server := lapitest.NewServer(t, lapitest.WithSys(sys)) + lc := server.ClientWithName("User") + + err := lc.ShutdownTailscaled(t.Context()) + checkError(t, err, tt.wantErr) + }) + } +} + +func checkError(tb testing.TB, got, want error) { + tb.Helper() + if (want == nil) != (got == nil) || + (want != nil && got != nil && want.Error() != got.Error() && !errors.Is(got, want)) { + tb.Fatalf("gotErr: %v; wantErr: %v", got, want) + } +} + +func setGOOSForTest(tb testing.TB, goos string) { + tb.Helper() + envknob.Setenv("TS_DEBUG_FAKE_GOOS", goos) + tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") }) +} + +func pumpIPNBus(watcher *local.IPNBusWatcher) { + for { + _, err := watcher.Next() + if err != nil { + break + } + } } diff --git a/ipn/ipnserver/waiterset_test.go b/ipn/ipnserver/waiterset_test.go new file mode 100644 index 000000000..b7d5ea144 --- /dev/null +++ b/ipn/ipnserver/waiterset_test.go @@ -0,0 +1,46 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. + cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go index 9f8bd34f6..e7ae2d62b 100644 --- a/ipn/ipnstate/ipnstate.go +++ b/ipn/ipnstate/ipnstate.go @@ -216,6 +216,11 @@ type PeerStatusLite struct { } // PeerStatus describes a peer node and its current state. +// WARNING: The fields in PeerStatus are merged by the AddPeer method in the StatusBuilder. +// When adding a new field to PeerStatus, you must update AddPeer to handle merging +// the new field. The AddPeer function is responsible for combining multiple updates +// to the same peer, and any new field that is not merged properly may lead to +// inconsistencies or lost data in the peer status. type PeerStatus struct { ID tailcfg.StableNodeID PublicKey key.NodePublic @@ -246,9 +251,10 @@ type PeerStatus struct { PrimaryRoutes *views.Slice[netip.Prefix] `json:",omitempty"` // Endpoints: - Addrs []string - CurAddr string // one of Addrs, or unique if roaming - Relay string // DERP region + Addrs []string + CurAddr string // one of Addrs, or unique if roaming + Relay string // DERP region + PeerRelay string // peer relay address (ip:port:vni) RxBytes int64 TxBytes int64 @@ -270,6 +276,12 @@ type PeerStatus struct { // PeerAPIURL are the URLs of the node's PeerAPI servers. PeerAPIURL []string + // TaildropTargetStatus represents the node's eligibility to have files shared to it. + TaildropTarget TaildropTargetStatus + + // Reason why this peer cannot receive files. Empty if CanReceiveFiles=true + NoFileSharingReason string + // Capabilities are capabilities that the node has. // They're free-form strings, but should be in the form of URLs/URIs // such as: @@ -318,6 +330,21 @@ type PeerStatus struct { Location *tailcfg.Location `json:",omitempty"` } +type TaildropTargetStatus int + +const ( + TaildropTargetUnknown TaildropTargetStatus = iota + TaildropTargetAvailable + TaildropTargetNoNetmapAvailable + TaildropTargetIpnStateNotRunning + TaildropTargetMissingCap + TaildropTargetOffline + TaildropTargetNoPeerInfo + TaildropTargetUnsupportedOS + TaildropTargetNoPeerAPI + TaildropTargetOwnedByOtherUser +) + // HasCap reports whether ps has the given capability. func (ps *PeerStatus) HasCap(cap tailcfg.NodeCapability) bool { return ps.CapMap.Contains(cap) @@ -367,7 +394,7 @@ func (sb *StatusBuilder) MutateSelfStatus(f func(*PeerStatus)) { } // AddUser adds a user profile to the status. -func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfile) { +func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfileView) { if sb.locked { log.Printf("[unexpected] ipnstate: AddUser after Locked") return @@ -377,7 +404,7 @@ func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfile) { sb.st.User = make(map[tailcfg.UserID]tailcfg.UserProfile) } - sb.st.User[id] = up + sb.st.User[id] = *up.AsStruct() } // AddIP adds a Tailscale IP address to the status. @@ -425,6 +452,9 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if v := st.Relay; v != "" { e.Relay = v } + if v := st.PeerRelay; v != "" { + e.PeerRelay = v + } if v := st.UserID; v != 0 { e.UserID = v } @@ -512,6 +542,9 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if v := st.Capabilities; v != nil { e.Capabilities = v } + if v := st.TaildropTarget; v != TaildropTargetUnknown { + e.TaildropTarget = v + } e.Location = st.Location } @@ -650,6 +683,8 @@ func osEmoji(os string) string { return "🐡" case "illumos": return "â˜€ī¸" + case "solaris": + return "đŸŒ¤ī¸" } return "đŸ‘Ŋ" } @@ -666,10 +701,17 @@ type PingResult struct { Err string LatencySeconds float64 - // Endpoint is the ip:port if direct UDP was used. - // It is not currently set for TSMP pings. + // Endpoint is a string of the form "{ip}:{port}" if direct UDP was used. It + // is not currently set for TSMP. Endpoint string + // PeerRelay is a string of the form "{ip}:{port}:vni:{vni}" if a peer + // relay was used. It is not currently set for TSMP. Note that this field + // is not omitted during JSON encoding if it contains a zero value. This is + // done for consistency with the Endpoint field; this structure is exposed + // externally via localAPI, so we want to maintain the existing convention. + PeerRelay string + // DERPRegionID is non-zero DERP region ID if DERP was used. // It is not currently set for TSMP pings. DERPRegionID int @@ -704,6 +746,7 @@ func (pr *PingResult) ToPingResponse(pingType tailcfg.PingType) *tailcfg.PingRes Err: pr.Err, LatencySeconds: pr.LatencySeconds, Endpoint: pr.Endpoint, + PeerRelay: pr.PeerRelay, DERPRegionID: pr.DERPRegionID, DERPRegionCode: pr.DERPRegionCode, PeerAPIPort: pr.PeerAPIPort, diff --git a/ipn/lapitest/backend.go b/ipn/lapitest/backend.go new file mode 100644 index 000000000..7a1c276a7 --- /dev/null +++ b/ipn/lapitest/backend.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "testing" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/store/mem" + "tailscale.com/types/logid" + "tailscale.com/wgengine" +) + +// NewBackend returns a new [ipnlocal.LocalBackend] for testing purposes. +// It fails the test if the specified options are invalid or if the backend cannot be created. +func NewBackend(tb testing.TB, opts ...Option) *ipnlocal.LocalBackend { + tb.Helper() + options, err := newOptions(tb, opts...) + if err != nil { + tb.Fatalf("NewBackend: %v", err) + } + return newBackend(options) +} + +func newBackend(opts *options) *ipnlocal.LocalBackend { + tb := opts.TB() + tb.Helper() + + sys := opts.Sys() + if _, ok := sys.StateStore.GetOK(); !ok { + sys.Set(&mem.Store{}) + } + + e, err := wgengine.NewFakeUserspaceEngine(opts.Logf(), sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + opts.tb.Fatalf("NewFakeUserspaceEngine: %v", err) + } + tb.Cleanup(e.Close) + sys.Set(e) + + b, err := ipnlocal.NewLocalBackend(opts.Logf(), logid.PublicID{}, sys, 0) + if err != nil { + tb.Fatalf("NewLocalBackend: %v", err) + } + tb.Cleanup(b.Shutdown) + b.SetControlClientGetterForTesting(opts.MakeControlClient) + return b +} + +// NewUnreachableControlClient is a [NewControlFn] that creates +// a new [controlclient.Client] for an unreachable control server. +func NewUnreachableControlClient(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) { + tb.Helper() + opts.ServerURL = "https://127.0.0.1:1" + cc, err := controlclient.New(opts) + if err != nil { + tb.Fatal(err) + } + return cc, nil +} diff --git a/ipn/lapitest/client.go b/ipn/lapitest/client.go new file mode 100644 index 000000000..6d22e938b --- /dev/null +++ b/ipn/lapitest/client.go @@ -0,0 +1,71 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "testing" + + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" +) + +// Client wraps a [local.Client] for testing purposes. +// It can be created using [Server.Client], [Server.ClientWithName], +// or [Server.ClientFor] and sends requests as the specified actor +// to the associated [Server]. +type Client struct { + tb testing.TB + // Client is the underlying [local.Client] wrapped by the test client. + // It is configured to send requests to the test server on behalf of the actor. + *local.Client + // Actor represents the user on whose behalf this client is making requests. + // The server uses it to determine the client's identity and permissions. + // The test can mutate the user to alter the actor's identity or permissions + // before making a new request. It is typically an [ipnauth.TestActor], + // unless the [Client] was created with s specific actor using [Server.ClientFor]. + Actor ipnauth.Actor +} + +// Username returns username of the client's owner. +func (c *Client) Username() string { + c.tb.Helper() + name, err := c.Actor.Username() + if err != nil { + c.tb.Fatalf("Client.Username: %v", err) + } + return name +} + +// WatchIPNBus is like [local.Client.WatchIPNBus] but returns a [local.IPNBusWatcher] +// that is closed when the test ends and a cancel function that stops the watcher. +// It fails the test if the underlying WatchIPNBus returns an error. +func (c *Client) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*local.IPNBusWatcher, context.CancelFunc) { + c.tb.Helper() + ctx, cancelWatcher := context.WithCancel(ctx) + c.tb.Cleanup(cancelWatcher) + watcher, err := c.Client.WatchIPNBus(ctx, mask) + name, _ := c.Actor.Username() + if err != nil { + c.tb.Fatalf("Client.WatchIPNBus(%q): %v", name, err) + } + c.tb.Cleanup(func() { watcher.Close() }) + return watcher, cancelWatcher +} + +// generateSequentialName generates a unique sequential name based on the given prefix and number n. +// It uses a base-26 encoding to create names like "User-A", "User-B", ..., "User-Z", "User-AA", etc. +func generateSequentialName(prefix string, n int) string { + n++ + name := "" + const numLetters = 'Z' - 'A' + 1 + for n > 0 { + n-- + remainder := byte(n % numLetters) + name = string([]byte{'A' + remainder}) + name + n = n / numLetters + } + return prefix + "-" + name +} diff --git a/ipn/lapitest/example_test.go b/ipn/lapitest/example_test.go new file mode 100644 index 000000000..57479199a --- /dev/null +++ b/ipn/lapitest/example_test.go @@ -0,0 +1,80 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "testing" + + "tailscale.com/ipn" +) + +func TestClientServer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create a server and two clients. + // Both clients represent the same user to make this work across platforms. + // On Windows we've been restricting the API usage to a single user at a time. + // While we're planning on changing this once a better permission model is in place, + // this test is currently limited to a single user (but more than one client is fine). + // Alternatively, we could override GOOS via envknobs to test as if we're + // on a different platform, but that would make the test depend on global state, etc. + s := NewServer(t, WithLogging(false)) + c1 := s.ClientWithName("User-A") + c2 := s.ClientWithName("User-A") + + // Start watching the IPN bus as the second client. + w2, _ := c2.WatchIPNBus(context.Background(), ipn.NotifyInitialPrefs) + + // We're supposed to get a notification about the initial prefs, + // and WantRunning should be false. + n, err := w2.Next() + for ; err == nil; n, err = w2.Next() { + if n.Prefs == nil { + // Ignore non-prefs notifications. + continue + } + if n.Prefs.WantRunning() { + t.Errorf("WantRunning(initial): got %v, want false", n.Prefs.WantRunning()) + } + break + } + if err != nil { + t.Fatalf("IPNBusWatcher.Next failed: %v", err) + } + + // Now send an EditPrefs request from the first client to set WantRunning to true. + change := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + gotPrefs, err := c1.EditPrefs(ctx, change) + if err != nil { + t.Fatalf("EditPrefs failed: %v", err) + } + if !gotPrefs.WantRunning { + t.Fatalf("EditPrefs.WantRunning: got %v, want true", gotPrefs.WantRunning) + } + + // We can check the backend directly to see if the prefs were set correctly. + if gotWantRunning := s.Backend().Prefs().WantRunning(); !gotWantRunning { + t.Fatalf("Backend.Prefs.WantRunning: got %v, want true", gotWantRunning) + } + + // And can also wait for the second client with an IPN bus watcher to receive the notification + // about the prefs change. + n, err = w2.Next() + for ; err == nil; n, err = w2.Next() { + if n.Prefs == nil { + // Ignore non-prefs notifications. + continue + } + if !n.Prefs.WantRunning() { + t.Fatalf("WantRunning(changed): got %v, want true", n.Prefs.WantRunning()) + } + break + } + if err != nil { + t.Fatalf("IPNBusWatcher.Next failed: %v", err) + } +} diff --git a/ipn/lapitest/opts.go b/ipn/lapitest/opts.go new file mode 100644 index 000000000..6eb1594da --- /dev/null +++ b/ipn/lapitest/opts.go @@ -0,0 +1,170 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "errors" + "fmt" + "testing" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" +) + +// Option is any optional configuration that can be passed to [NewServer] or [NewBackend]. +type Option interface { + apply(*options) error +} + +// options is the merged result of all applied [Option]s. +type options struct { + tb testing.TB + ctx lazy.SyncValue[context.Context] + logf lazy.SyncValue[logger.Logf] + sys lazy.SyncValue[*tsd.System] + newCC lazy.SyncValue[NewControlFn] + backend lazy.SyncValue[*ipnlocal.LocalBackend] +} + +// newOptions returns a new [options] struct with the specified [Option]s applied. +func newOptions(tb testing.TB, opts ...Option) (*options, error) { + options := &options{tb: tb} + for _, opt := range opts { + if err := opt.apply(options); err != nil { + return nil, fmt.Errorf("lapitest: %w", err) + } + } + return options, nil +} + +// TB returns the owning [*testing.T] or [*testing.B]. +func (o *options) TB() testing.TB { + return o.tb +} + +// Context returns the base context to be used by the server. +func (o *options) Context() context.Context { + return o.ctx.Get(context.Background) +} + +// Logf returns the [logger.Logf] to be used for logging. +func (o *options) Logf() logger.Logf { + return o.logf.Get(func() logger.Logf { return logger.Discard }) +} + +// Sys returns the [tsd.System] that contains subsystems to be used +// when creating a new [ipnlocal.LocalBackend]. +func (o *options) Sys() *tsd.System { + return o.sys.Get(func() *tsd.System { return tsd.NewSystem() }) +} + +// Backend returns the [ipnlocal.LocalBackend] to be used by the server. +// If a backend is provided via [WithBackend], it is used as-is. +// Otherwise, a new backend is created with the the [options] in o. +func (o *options) Backend() *ipnlocal.LocalBackend { + return o.backend.Get(func() *ipnlocal.LocalBackend { return newBackend(o) }) +} + +// MakeControlClient returns a new [controlclient.Client] to be used by newly +// created [ipnlocal.LocalBackend]s. It is only used if no backend is provided +// via [WithBackend]. +func (o *options) MakeControlClient(opts controlclient.Options) (controlclient.Client, error) { + newCC := o.newCC.Get(func() NewControlFn { return NewUnreachableControlClient }) + return newCC(o.tb, opts) +} + +type loggingOption struct{ enableLogging bool } + +// WithLogging returns an [Option] that enables or disables logging. +func WithLogging(enableLogging bool) Option { + return loggingOption{enableLogging: enableLogging} +} + +func (o loggingOption) apply(opts *options) error { + var logf logger.Logf + if o.enableLogging { + logf = tstest.WhileTestRunningLogger(opts.tb) + } else { + logf = logger.Discard + } + if !opts.logf.Set(logf) { + return errors.New("logging already configured") + } + return nil +} + +type contextOption struct{ ctx context.Context } + +// WithContext returns an [Option] that sets the base context to be used by the [Server]. +func WithContext(ctx context.Context) Option { + return contextOption{ctx: ctx} +} + +func (o contextOption) apply(opts *options) error { + if !opts.ctx.Set(o.ctx) { + return errors.New("context already configured") + } + return nil +} + +type sysOption struct{ sys *tsd.System } + +// WithSys returns an [Option] that sets the [tsd.System] to be used +// when creating a new [ipnlocal.LocalBackend]. +func WithSys(sys *tsd.System) Option { + return sysOption{sys: sys} +} + +func (o sysOption) apply(opts *options) error { + if !opts.sys.Set(o.sys) { + return errors.New("tsd.System already configured") + } + return nil +} + +type backendOption struct{ backend *ipnlocal.LocalBackend } + +// WithBackend returns an [Option] that configures the server to use the specified +// [ipnlocal.LocalBackend] instead of creating a new one. +// It is mutually exclusive with [WithControlClient]. +func WithBackend(backend *ipnlocal.LocalBackend) Option { + return backendOption{backend: backend} +} + +func (o backendOption) apply(opts *options) error { + if _, ok := opts.backend.Peek(); ok { + return errors.New("backend cannot be set when control client is already set") + } + if !opts.backend.Set(o.backend) { + return errors.New("backend already set") + } + return nil +} + +// NewControlFn is any function that creates a new [controlclient.Client] +// with the specified options. +type NewControlFn func(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) + +// WithControlClient returns an option that specifies a function to be used +// by the [ipnlocal.LocalBackend] when creating a new [controlclient.Client]. +// It is mutually exclusive with [WithBackend] and is only used if no backend +// has been provided. +func WithControlClient(newControl NewControlFn) Option { + return newControl +} + +func (fn NewControlFn) apply(opts *options) error { + if _, ok := opts.backend.Peek(); ok { + return errors.New("control client cannot be set when backend is already set") + } + if !opts.newCC.Set(fn) { + return errors.New("control client already set") + } + return nil +} diff --git a/ipn/lapitest/server.go b/ipn/lapitest/server.go new file mode 100644 index 000000000..457a338ab --- /dev/null +++ b/ipn/lapitest/server.go @@ -0,0 +1,324 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lapitest provides utilities for black-box testing of LocalAPI ([ipnserver]). +package lapitest + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnserver" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/types/ptr" + "tailscale.com/util/mak" + "tailscale.com/util/rands" +) + +// A Server is an in-process LocalAPI server that can be used in end-to-end tests. +type Server struct { + tb testing.TB + + ctx context.Context + cancelCtx context.CancelFunc + + lb *ipnlocal.LocalBackend + ipnServer *ipnserver.Server + + // mu protects the following fields. + mu sync.Mutex + started bool + httpServer *httptest.Server + actorsByName map[string]*ipnauth.TestActor + lastClientID int +} + +// NewUnstartedServer returns a new [Server] with the specified options without starting it. +func NewUnstartedServer(tb testing.TB, opts ...Option) *Server { + tb.Helper() + options, err := newOptions(tb, opts...) + if err != nil { + tb.Fatalf("invalid options: %v", err) + } + + s := &Server{tb: tb, lb: options.Backend()} + s.ctx, s.cancelCtx = context.WithCancel(options.Context()) + s.ipnServer = newUnstartedIPNServer(options) + s.httpServer = httptest.NewUnstartedServer(http.HandlerFunc(s.serveHTTP)) + s.httpServer.Config.Addr = "http://" + apitype.LocalAPIHost + s.httpServer.Config.BaseContext = func(_ net.Listener) context.Context { return s.ctx } + s.httpServer.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(options.Logf(), "lapitest: ")) + tb.Cleanup(s.Close) + return s +} + +// NewServer starts and returns a new [Server] with the specified options. +func NewServer(tb testing.TB, opts ...Option) *Server { + tb.Helper() + server := NewUnstartedServer(tb, opts...) + server.Start() + return server +} + +// Start starts the server from [NewUnstartedServer]. +func (s *Server) Start() { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if !s.started && s.httpServer != nil { + s.httpServer.Start() + s.started = true + } +} + +// Backend returns the underlying [ipnlocal.LocalBackend]. +func (s *Server) Backend() *ipnlocal.LocalBackend { + s.tb.Helper() + return s.lb +} + +// Client returns a new [Client] configured for making requests to the server +// as a new [ipnauth.TestActor] with a unique username and [ipnauth.ClientID]. +func (s *Server) Client() *Client { + s.tb.Helper() + user := s.MakeTestActor("", "") // generate a unique username and client ID + return s.ClientFor(user) +} + +// ClientWithName returns a new [Client] configured for making requests to the server +// as a new [ipnauth.TestActor] with the specified name and a unique [ipnauth.ClientID]. +func (s *Server) ClientWithName(name string) *Client { + s.tb.Helper() + user := s.MakeTestActor(name, "") // generate a unique client ID + return s.ClientFor(user) +} + +// ClientFor returns a new [Client] configured for making requests to the server +// as the specified actor. +func (s *Server) ClientFor(actor ipnauth.Actor) *Client { + s.tb.Helper() + client := &Client{ + tb: s.tb, + Actor: actor, + } + client.Client = &local.Client{Transport: newRoundTripper(client, s.httpServer)} + return client +} + +// MakeTestActor returns a new [ipnauth.TestActor] with the specified name and client ID. +// If the name is empty, a unique sequential name is generated. Likewise, +// if clientID is empty, a unique sequential client ID is generated. +func (s *Server) MakeTestActor(name string, clientID string) *ipnauth.TestActor { + s.tb.Helper() + + s.mu.Lock() + defer s.mu.Unlock() + + // Generate a unique sequential name if the provided name is empty. + if name == "" { + n := len(s.actorsByName) + name = generateSequentialName("User", n) + } + + if clientID == "" { + s.lastClientID += 1 + clientID = fmt.Sprintf("Client-%d", s.lastClientID) + } + + // Create a new base actor if one doesn't already exist for the given name. + baseActor := s.actorsByName[name] + if baseActor == nil { + baseActor = &ipnauth.TestActor{Name: name} + if envknob.GOOS() == "windows" { + // Historically, as of 2025-04-15, IPN does not distinguish between + // different users on non-Windows devices. Therefore, the UID, which is + // an [ipn.WindowsUserID], should only be populated when the actual or + // fake GOOS is Windows. + baseActor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actorsByName))) + } + mak.Set(&s.actorsByName, name, baseActor) + s.tb.Cleanup(func() { delete(s.actorsByName, name) }) + } + + // Create a shallow copy of the base actor and assign it the new client ID. + actor := ptr.To(*baseActor) + actor.CID = ipnauth.ClientIDFrom(clientID) + return actor +} + +// BlockWhileInUse blocks until the server becomes idle (no active requests), +// or the context is done. It returns the context's error if it is done. +// It is used in tests only. +func (s *Server) BlockWhileInUse(ctx context.Context) error { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer == nil { + return nil + } + return s.ipnServer.BlockWhileInUseForTest(ctx) +} + +// BlockWhileInUseByOther blocks while the specified actor can't connect to the server +// due to another actor being connected. +// It is used in tests only. +func (s *Server) BlockWhileInUseByOther(ctx context.Context, actor ipnauth.Actor) error { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer == nil { + return nil + } + return s.ipnServer.BlockWhileInUseByOtherForTest(ctx, actor) +} + +// CheckCurrentUser fails the test if the current user does not match the expected user. +// It is only used on Windows and will be removed as we progress on tailscale/corp#18342. +func (s *Server) CheckCurrentUser(want ipnauth.Actor) { + s.tb.Helper() + var wantUID ipn.WindowsUserID + if want != nil { + wantUID = want.UserID() + } + lb := s.Backend() + if lb == nil { + s.tb.Fatalf("Backend: nil") + } + gotUID, gotActor := lb.CurrentUserForTest() + if gotUID != wantUID { + s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID) + } + if hasActor := gotActor != nil; hasActor != (want != nil) || (want != nil && gotActor != want) { + s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want) + } +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { + actor, err := getActorForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + s.tb.Errorf("getActorForRequest: %v", err) + return + } + ctx := ipnserver.NewContextWithActorForTest(r.Context(), actor) + s.ipnServer.ServeHTTPForTest(w, r.Clone(ctx)) +} + +// Close shuts down the server and blocks until all outstanding requests on this server have completed. +func (s *Server) Close() { + s.tb.Helper() + s.mu.Lock() + server := s.httpServer + s.httpServer = nil + s.mu.Unlock() + + if server != nil { + server.Close() + } + s.cancelCtx() +} + +// newUnstartedIPNServer returns a new [ipnserver.Server] that exposes +// the specified [ipnlocal.LocalBackend] via LocalAPI, but does not start it. +// The opts carry additional configuration options. +func newUnstartedIPNServer(opts *options) *ipnserver.Server { + opts.TB().Helper() + lb := opts.Backend() + server := ipnserver.New(opts.Logf(), logid.PublicID{}, lb.EventBus(), lb.NetMon()) + server.SetLocalBackend(lb) + return server +} + +// roundTripper is a [http.RoundTripper] that sends requests to a [Server] +// on behalf of the [Client] who owns it. +type roundTripper struct { + client *Client + transport http.RoundTripper +} + +// newRoundTripper returns a new [http.RoundTripper] that sends requests +// to the specified server as the specified client. +func newRoundTripper(client *Client, server *httptest.Server) http.RoundTripper { + return &roundTripper{ + client: client, + transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var std net.Dialer + return std.DialContext(ctx, network, server.Listener.Addr().(*net.TCPAddr).String()) + }}, + } +} + +// requestIDHeaderName is the name of the header used to pass request IDs +// between the client and server. It is used to associate requests with their actors. +const requestIDHeaderName = "TS-Request-ID" + +// RoundTrip implements [http.RoundTripper] by sending the request to the [ipnserver.Server] +// on behalf of the owning [Client]. It registers each request for the duration +// of the call and associates it with the actor sending the request. +func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + reqID, unregister := registerRequest(rt.client.Actor) + defer unregister() + r = r.Clone(r.Context()) + r.Header.Set(requestIDHeaderName, reqID) + return rt.transport.RoundTrip(r) +} + +// getActorForRequest returns the actor for a given request. +// It returns an error if the request is not associated with an actor, +// such as when it wasn't sent by a [roundTripper]. +func getActorForRequest(r *http.Request) (ipnauth.Actor, error) { + reqID := r.Header.Get(requestIDHeaderName) + if reqID == "" { + return nil, fmt.Errorf("missing %s header", requestIDHeaderName) + } + actor, ok := getActorByRequestID(reqID) + if !ok { + return nil, fmt.Errorf("unknown request: %s", reqID) + } + return actor, nil +} + +var ( + inFlightRequestsMu sync.Mutex + inFlightRequests map[string]ipnauth.Actor +) + +// registerRequest associates a request with the specified actor and returns a unique request ID +// which can be used to retrieve the actor later. The returned function unregisters the request. +func registerRequest(actor ipnauth.Actor) (requestID string, unregister func()) { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + for { + requestID = rands.HexString(16) + if _, ok := inFlightRequests[requestID]; !ok { + break + } + } + mak.Set(&inFlightRequests, requestID, actor) + return requestID, func() { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + delete(inFlightRequests, requestID) + } +} + +// getActorByRequestID returns the actor associated with the specified request ID. +// It returns the actor and true if found, or nil and false if not. +func getActorByRequestID(requestID string) (ipnauth.Actor, bool) { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + actor, ok := inFlightRequests[requestID] + return actor, ok +} diff --git a/ipn/localapi/cert.go b/ipn/localapi/cert.go index 323406f7b..2313631cc 100644 --- a/ipn/localapi/cert.go +++ b/ipn/localapi/cert.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ios && !android && !js +//go:build !ios && !android && !js && !ts_omit_acme package localapi @@ -14,6 +14,10 @@ import ( "tailscale.com/ipn/ipnlocal" ) +func init() { + Register("cert/", (*Handler).serveCert) +} + func (h *Handler) serveCert(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite && !h.PermitCert { http.Error(w, "cert access denied", http.StatusForbidden) diff --git a/ipn/localapi/debug.go b/ipn/localapi/debug.go new file mode 100644 index 000000000..ae9cb01e0 --- /dev/null +++ b/ipn/localapi/debug.go @@ -0,0 +1,495 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_debug + +package localapi + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "reflect" + "slices" + "strconv" + "sync" + "time" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + "tailscale.com/util/httpm" +) + +func init() { + Register("component-debug-logging", (*Handler).serveComponentDebugLogging) + Register("debug", (*Handler).serveDebug) + Register("debug-rotate-disco-key", (*Handler).serveDebugRotateDiscoKey) + Register("dev-set-state-store", (*Handler).serveDevSetStateStore) + Register("debug-bus-events", (*Handler).serveDebugBusEvents) + Register("debug-bus-graph", (*Handler).serveEventBusGraph) + Register("debug-derp-region", (*Handler).serveDebugDERPRegion) + Register("debug-dial-types", (*Handler).serveDebugDialTypes) + Register("debug-log", (*Handler).serveDebugLog) + Register("debug-packet-filter-matches", (*Handler).serveDebugPacketFilterMatches) + Register("debug-packet-filter-rules", (*Handler).serveDebugPacketFilterRules) + Register("debug-peer-endpoint-changes", (*Handler).serveDebugPeerEndpointChanges) + Register("debug-optional-features", (*Handler).serveDebugOptionalFeatures) +} + +func (h *Handler) serveDebugPeerEndpointChanges(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "status access denied", http.StatusForbidden) + return + } + + ipStr := r.FormValue("ip") + if ipStr == "" { + http.Error(w, "missing 'ip' parameter", http.StatusBadRequest) + return + } + ip, err := netip.ParseAddr(ipStr) + if err != nil { + http.Error(w, "invalid IP", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + chs, err := h.b.GetPeerEndpointChanges(r.Context(), ip) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(chs) +} + +func (h *Handler) serveComponentDebugLogging(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + component := r.FormValue("component") + secs, _ := strconv.Atoi(r.FormValue("secs")) + err := h.b.SetComponentDebugLogging(component, h.clock.Now().Add(time.Duration(secs)*time.Second)) + var res struct { + Error string + } + if err != nil { + res.Error = err.Error() + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} + +func (h *Handler) serveDebugDialTypes(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug-dial-types access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) + return + } + + ip := r.FormValue("ip") + port := r.FormValue("port") + network := r.FormValue("network") + + addr := ip + ":" + port + if _, err := netip.ParseAddrPort(addr); err != nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "invalid address %q: %v", addr, err) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + + var bareDialer net.Dialer + + dialer := h.b.Dialer() + + var peerDialer net.Dialer + peerDialer.Control = dialer.PeerDialControlFunc() + + // Kick off a dial with each available dialer in parallel. + dialers := []struct { + name string + dial func(context.Context, string, string) (net.Conn, error) + }{ + {"SystemDial", dialer.SystemDial}, + {"UserDial", dialer.UserDial}, + {"PeerDial", peerDialer.DialContext}, + {"BareDial", bareDialer.DialContext}, + } + type result struct { + name string + conn net.Conn + err error + } + results := make(chan result, len(dialers)) + + var wg sync.WaitGroup + for _, dialer := range dialers { + dialer := dialer // loop capture + + wg.Add(1) + go func() { + defer wg.Done() + conn, err := dialer.dial(ctx, network, addr) + results <- result{dialer.name, conn, err} + }() + } + + wg.Wait() + for range len(dialers) { + res := <-results + fmt.Fprintf(w, "[%s] connected=%v err=%v\n", res.name, res.conn != nil, res.err) + if res.conn != nil { + res.conn.Close() + } + } +} + +func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, "debug not supported in this build", http.StatusNotImplemented) + return + } + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + // The action is normally in a POST form parameter, but + // some actions (like "notify") want a full JSON body, so + // permit some to have their action in a header. + var action string + switch v := r.Header.Get("Debug-Action"); v { + case "notify": + action = v + default: + action = r.FormValue("action") + } + var err error + switch action { + case "derp-set-homeless": + h.b.MagicConn().SetHomeless(true) + case "derp-unset-homeless": + h.b.MagicConn().SetHomeless(false) + case "rebind": + err = h.b.DebugRebind() + case "restun": + err = h.b.DebugReSTUN() + case "notify": + var n ipn.Notify + err = json.NewDecoder(r.Body).Decode(&n) + if err != nil { + break + } + h.b.DebugNotify(n) + case "notify-last-netmap": + h.b.DebugNotifyLastNetMap() + case "break-tcp-conns": + err = h.b.DebugBreakTCPConns() + case "break-derp-conns": + err = h.b.DebugBreakDERPConns() + case "force-netmap-update": + h.b.DebugForceNetmapUpdate() + case "control-knobs": + k := h.b.ControlKnobs() + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(k.AsDebugJSON()) + if err == nil { + return + } + case "pick-new-derp": + err = h.b.DebugPickNewDERP() + case "force-prefer-derp": + var n int + err = json.NewDecoder(r.Body).Decode(&n) + if err != nil { + break + } + h.b.DebugForcePreferDERP(n) + case "peer-relay-servers": + servers := h.b.DebugPeerRelayServers().Slice() + slices.SortFunc(servers, func(a, b netip.Addr) int { + return a.Compare(b) + }) + err = json.NewEncoder(w).Encode(servers) + if err == nil { + return + } + case "rotate-disco-key": + err = h.b.DebugRotateDiscoKey() + case "": + err = fmt.Errorf("missing parameter 'action'") + default: + err = fmt.Errorf("unknown action %q", action) + } + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, "done\n") +} + +func (h *Handler) serveDevSetStateStore(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + if err := h.b.SetDevStateStore(r.FormValue("key"), r.FormValue("value")); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, "done\n") +} + +func (h *Handler) serveDebugPacketFilterRules(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + nm := h.b.NetMap() + if nm == nil { + http.Error(w, "no netmap", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + + enc := json.NewEncoder(w) + enc.SetIndent("", "\t") + enc.Encode(nm.PacketFilterRules) +} + +func (h *Handler) serveDebugPacketFilterMatches(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + nm := h.b.NetMap() + if nm == nil { + http.Error(w, "no netmap", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + + enc := json.NewEncoder(w) + enc.SetIndent("", "\t") + enc.Encode(nm.PacketFilter) +} + +// debugEventError provides the JSON encoding of internal errors from event processing. +type debugEventError struct { + Error string +} + +// serveDebugBusEvents taps into the tailscaled/utils/eventbus and streams +// events to the client. +func (h *Handler) serveDebugBusEvents(w http.ResponseWriter, r *http.Request) { + // Require write access (~root) as the logs could contain something + // sensitive. + if !h.PermitWrite { + http.Error(w, "event bus access denied", http.StatusForbidden) + return + } + if r.Method != httpm.GET { + http.Error(w, "GET required", http.StatusMethodNotAllowed) + return + } + + bus, ok := h.LocalBackend().Sys().Bus.GetOK() + if !ok { + http.Error(w, "event bus not running", http.StatusNoContent) + return + } + + f, ok := w.(http.Flusher) + if !ok { + http.Error(w, "streaming unsupported", http.StatusInternalServerError) + return + } + + io.WriteString(w, `{"Event":"[event listener connected]\n"}`+"\n") + f.Flush() + + mon := bus.Debugger().WatchBus() + defer mon.Close() + + i := 0 + for { + select { + case <-r.Context().Done(): + fmt.Fprintf(w, `{"Event":"[event listener closed]\n"}`) + return + case <-mon.Done(): + return + case event := <-mon.Events(): + data := eventbus.DebugEvent{ + Count: i, + Type: reflect.TypeOf(event.Event).String(), + Event: event.Event, + From: event.From.Name(), + } + for _, client := range event.To { + data.To = append(data.To, client.Name()) + } + + if msg, err := json.Marshal(data); err != nil { + data.Event = debugEventError{Error: fmt.Sprintf( + "failed to marshal JSON for %T", event.Event, + )} + if errMsg, err := json.Marshal(data); err != nil { + fmt.Fprintf(w, + `{"Count": %d, "Event":"[ERROR] failed to marshal JSON for %T\n"}`, + i, event.Event) + } else { + w.Write(errMsg) + } + } else { + w.Write(msg) + } + f.Flush() + i++ + } + } +} + +// serveEventBusGraph taps into the event bus and dumps out the active graph of +// publishers and subscribers. It does not represent anything about the messages +// exchanged. +func (h *Handler) serveEventBusGraph(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.GET { + http.Error(w, "GET required", http.StatusMethodNotAllowed) + return + } + + bus, ok := h.LocalBackend().Sys().Bus.GetOK() + if !ok { + http.Error(w, "event bus not running", http.StatusPreconditionFailed) + return + } + + debugger := bus.Debugger() + clients := debugger.Clients() + + graph := map[string]eventbus.DebugTopic{} + + for _, client := range clients { + for _, pub := range debugger.PublishTypes(client) { + topic, ok := graph[pub.Name()] + if !ok { + topic = eventbus.DebugTopic{Name: pub.Name()} + } + topic.Publisher = client.Name() + graph[pub.Name()] = topic + } + for _, sub := range debugger.SubscribeTypes(client) { + topic, ok := graph[sub.Name()] + if !ok { + topic = eventbus.DebugTopic{Name: sub.Name()} + } + topic.Subscribers = append(topic.Subscribers, client.Name()) + graph[sub.Name()] = topic + } + } + + // The top level map is not really needed for the client, convert to a list. + topics := eventbus.DebugTopics{} + for _, v := range graph { + topics.Topics = append(topics.Topics, v) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(topics) +} + +func (h *Handler) serveDebugLog(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasLogTail { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + if !h.PermitRead { + http.Error(w, "debug-log access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) + return + } + defer h.b.TryFlushLogs() // kick off upload after we're done logging + + type logRequestJSON struct { + Lines []string + Prefix string + } + + var logRequest logRequestJSON + if err := json.NewDecoder(r.Body).Decode(&logRequest); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + prefix := logRequest.Prefix + if prefix == "" { + prefix = "debug-log" + } + logf := logger.WithPrefix(h.logf, prefix+": ") + + // We can write logs too fast for logtail to handle, even when + // opting-out of rate limits. Limit ourselves to at most one message + // per 20ms and a burst of 60 log lines, which should be fast enough to + // not block for too long but slow enough that we can upload all lines. + logf = logger.SlowLoggerWithClock(r.Context(), logf, 20*time.Millisecond, 60, h.clock.Now) + + for _, line := range logRequest.Lines { + logf("%s", line) + } + + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) serveDebugOptionalFeatures(w http.ResponseWriter, r *http.Request) { + of := &apitype.OptionalFeatures{ + Features: feature.Registered(), + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(of) +} + +func (h *Handler) serveDebugRotateDiscoKey(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + if err := h.b.DebugRotateDiscoKey(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, "done\n") +} diff --git a/ipn/localapi/debugderp.go b/ipn/localapi/debugderp.go index 85eb031e6..3edbc0856 100644 --- a/ipn/localapi/debugderp.go +++ b/ipn/localapi/debugderp.go @@ -1,9 +1,12 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_debug + package localapi import ( + "cmp" "context" "crypto/tls" "encoding/json" @@ -81,7 +84,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { client *http.Client = http.DefaultClient ) checkConn := func(derpNode *tailcfg.DERPNode) bool { - port := firstNonzero(derpNode.DERPPort, 443) + port := cmp.Or(derpNode.DERPPort, 443) var ( hasIPv4 bool @@ -89,7 +92,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { ) // Check IPv4 first - addr := net.JoinHostPort(firstNonzero(derpNode.IPv4, derpNode.HostName), strconv.Itoa(port)) + addr := net.JoinHostPort(cmp.Or(derpNode.IPv4, derpNode.HostName), strconv.Itoa(port)) conn, err := dialer.DialContext(ctx, "tcp4", addr) if err != nil { st.Errors = append(st.Errors, fmt.Sprintf("Error connecting to node %q @ %q over IPv4: %v", derpNode.HostName, addr, err)) @@ -98,7 +101,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // Upgrade to TLS and verify that works properly. tlsConn := tls.Client(conn, &tls.Config{ - ServerName: firstNonzero(derpNode.CertName, derpNode.HostName), + ServerName: cmp.Or(derpNode.CertName, derpNode.HostName), }) if err := tlsConn.HandshakeContext(ctx); err != nil { st.Errors = append(st.Errors, fmt.Sprintf("Error upgrading connection to node %q @ %q to TLS over IPv4: %v", derpNode.HostName, addr, err)) @@ -108,7 +111,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { } // Check IPv6 - addr = net.JoinHostPort(firstNonzero(derpNode.IPv6, derpNode.HostName), strconv.Itoa(port)) + addr = net.JoinHostPort(cmp.Or(derpNode.IPv6, derpNode.HostName), strconv.Itoa(port)) conn, err = dialer.DialContext(ctx, "tcp6", addr) if err != nil { st.Errors = append(st.Errors, fmt.Sprintf("Error connecting to node %q @ %q over IPv6: %v", derpNode.HostName, addr, err)) @@ -117,7 +120,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // Upgrade to TLS and verify that works properly. tlsConn := tls.Client(conn, &tls.Config{ - ServerName: firstNonzero(derpNode.CertName, derpNode.HostName), + ServerName: cmp.Or(derpNode.CertName, derpNode.HostName), // TODO(andrew-d): we should print more // detailed failure information on if/why TLS // verification fails @@ -166,7 +169,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { addr = addrs[0] } - addrPort := netip.AddrPortFrom(addr, uint16(firstNonzero(derpNode.STUNPort, 3478))) + addrPort := netip.AddrPortFrom(addr, uint16(cmp.Or(derpNode.STUNPort, 3478))) txID := stun.NewTxID() req := stun.Request(txID) @@ -227,49 +230,59 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // Start by checking whether we can establish a HTTP connection for _, derpNode := range reg.Nodes { - connSuccess := checkConn(derpNode) + if !derpNode.STUNOnly { + connSuccess := checkConn(derpNode) - // Verify that the /generate_204 endpoint works - captivePortalURL := "http://" + derpNode.HostName + "/generate_204" - resp, err := client.Get(captivePortalURL) - if err != nil { - st.Warnings = append(st.Warnings, fmt.Sprintf("Error making request to the captive portal check %q; is port 80 blocked?", captivePortalURL)) - } else { - resp.Body.Close() - } + // Verify that the /generate_204 endpoint works + captivePortalURL := fmt.Sprintf("http://%s/generate_204?t=%d", derpNode.HostName, time.Now().Unix()) + req, err := http.NewRequest("GET", captivePortalURL, nil) + if err != nil { + st.Warnings = append(st.Warnings, fmt.Sprintf("Internal error creating request for captive portal check: %v", err)) + continue + } + req.Header.Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform, max-age=0") + resp, err := client.Do(req) + if err != nil { + st.Warnings = append(st.Warnings, fmt.Sprintf("Error making request to the captive portal check %q; is port 80 blocked?", captivePortalURL)) + } else { + resp.Body.Close() + } - if !connSuccess { - continue - } + if !connSuccess { + continue + } - fakePrivKey := key.NewNode() - - // Next, repeatedly get the server key to see if the node is - // behind a load balancer (incorrectly). - serverPubKeys := make(map[key.NodePublic]bool) - for i := range 5 { - func() { - rc := derphttp.NewRegionClient(fakePrivKey, h.logf, h.b.NetMon(), func() *tailcfg.DERPRegion { - return &tailcfg.DERPRegion{ - RegionID: reg.RegionID, - RegionCode: reg.RegionCode, - RegionName: reg.RegionName, - Nodes: []*tailcfg.DERPNode{derpNode}, + fakePrivKey := key.NewNode() + + // Next, repeatedly get the server key to see if the node is + // behind a load balancer (incorrectly). + serverPubKeys := make(map[key.NodePublic]bool) + for i := range 5 { + func() { + rc := derphttp.NewRegionClient(fakePrivKey, h.logf, h.b.NetMon(), func() *tailcfg.DERPRegion { + return &tailcfg.DERPRegion{ + RegionID: reg.RegionID, + RegionCode: reg.RegionCode, + RegionName: reg.RegionName, + Nodes: []*tailcfg.DERPNode{derpNode}, + } + }) + if err := rc.Connect(ctx); err != nil { + st.Errors = append(st.Errors, fmt.Sprintf("Error connecting to node %q @ try %d: %v", derpNode.HostName, i, err)) + return } - }) - if err := rc.Connect(ctx); err != nil { - st.Errors = append(st.Errors, fmt.Sprintf("Error connecting to node %q @ try %d: %v", derpNode.HostName, i, err)) - return - } - if len(serverPubKeys) == 0 { - st.Info = append(st.Info, fmt.Sprintf("Successfully established a DERP connection with node %q", derpNode.HostName)) - } - serverPubKeys[rc.ServerPublicKey()] = true - }() - } - if len(serverPubKeys) > 1 { - st.Errors = append(st.Errors, fmt.Sprintf("Received multiple server public keys (%d); is the DERP server behind a load balancer?", len(serverPubKeys))) + if len(serverPubKeys) == 0 { + st.Info = append(st.Info, fmt.Sprintf("Successfully established a DERP connection with node %q", derpNode.HostName)) + } + serverPubKeys[rc.ServerPublicKey()] = true + }() + } + if len(serverPubKeys) > 1 { + st.Errors = append(st.Errors, fmt.Sprintf("Received multiple server public keys (%d); is the DERP server behind a load balancer?", len(serverPubKeys))) + } + } else { + st.Info = append(st.Info, fmt.Sprintf("Node %q is marked STUNOnly; skipped non-STUN checks", derpNode.HostName)) } // Send a STUN query to this node to verify whether or not it @@ -292,13 +305,3 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // issued in the first place, tell them specifically that the // cert is bad not just that the connection failed. } - -func firstNonzero[T comparable](items ...T) T { - var zero T - for _, item := range items { - if item != zero { - return item - } - } - return zero -} diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 7c076e8ab..d3503d302 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -7,23 +7,15 @@ package localapi import ( "bytes" "cmp" - "context" - "crypto/sha256" - "encoding/hex" + "crypto/subtle" "encoding/json" "errors" "fmt" "io" - "maps" - "mime" - "mime/multipart" "net" "net/http" - "net/http/httputil" "net/netip" "net/url" - "os" - "path" "runtime" "slices" "strconv" @@ -31,122 +23,150 @@ import ( "sync" "time" - "github.com/google/uuid" "golang.org/x/net/dns/dnsmessage" "tailscale.com/client/tailscale/apitype" - "tailscale.com/clientupdate" - "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/health/healthmsg" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnstate" "tailscale.com/logtail" - "tailscale.com/net/netmon" + "tailscale.com/net/netns" "tailscale.com/net/netutil" - "tailscale.com/net/portmapper" "tailscale.com/tailcfg" - "tailscale.com/taildrop" - "tailscale.com/tka" "tailscale.com/tstime" - "tailscale.com/types/dnstype" + "tailscale.com/types/appctype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/ptr" - "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" - "tailscale.com/util/httphdr" + "tailscale.com/util/eventbus" "tailscale.com/util/httpm" "tailscale.com/util/mak" "tailscale.com/util/osdiag" - "tailscale.com/util/progresstracking" "tailscale.com/util/rands" - "tailscale.com/util/testenv" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/version" "tailscale.com/wgengine/magicsock" ) -type localAPIHandler func(*Handler, http.ResponseWriter, *http.Request) +var ( + metricInvalidRequests = clientmetric.NewCounter("localapi_invalid_requests") + metricDebugMetricsCalls = clientmetric.NewCounter("localapi_debugmetric_requests") + metricUserMetricsCalls = clientmetric.NewCounter("localapi_usermetric_requests") + metricBugReportRequests = clientmetric.NewCounter("localapi_bugreport_requests") +) + +type LocalAPIHandler func(*Handler, http.ResponseWriter, *http.Request) // handler is the set of LocalAPI handlers, keyed by the part of the // Request.URL.Path after "/localapi/v0/". If the key ends with a trailing slash // then it's a prefix match. -var handler = map[string]localAPIHandler{ +var handler = map[string]LocalAPIHandler{ // The prefix match handlers end with a slash: - "cert/": (*Handler).serveCert, - "file-put/": (*Handler).serveFilePut, - "files/": (*Handler).serveFiles, "profiles/": (*Handler).serveProfiles, // The other /localapi/v0/NAME handlers are exact matches and contain only NAME // without a trailing slash: - "bugreport": (*Handler).serveBugReport, - "check-ip-forwarding": (*Handler).serveCheckIPForwarding, - "check-prefs": (*Handler).serveCheckPrefs, - "check-udp-gro-forwarding": (*Handler).serveCheckUDPGROForwarding, - "component-debug-logging": (*Handler).serveComponentDebugLogging, - "debug": (*Handler).serveDebug, - "debug-capture": (*Handler).serveDebugCapture, - "debug-derp-region": (*Handler).serveDebugDERPRegion, - "debug-dial-types": (*Handler).serveDebugDialTypes, - "debug-log": (*Handler).serveDebugLog, - "debug-packet-filter-matches": (*Handler).serveDebugPacketFilterMatches, - "debug-packet-filter-rules": (*Handler).serveDebugPacketFilterRules, - "debug-peer-endpoint-changes": (*Handler).serveDebugPeerEndpointChanges, - "debug-portmap": (*Handler).serveDebugPortmap, - "derpmap": (*Handler).serveDERPMap, - "dev-set-state-store": (*Handler).serveDevSetStateStore, - "dial": (*Handler).serveDial, - "dns-osconfig": (*Handler).serveDNSOSConfig, - "dns-query": (*Handler).serveDNSQuery, - "drive/fileserver-address": (*Handler).serveDriveServerAddr, - "drive/shares": (*Handler).serveShares, - "file-targets": (*Handler).serveFileTargets, - "goroutines": (*Handler).serveGoroutines, - "handle-push-message": (*Handler).serveHandlePushMessage, - "id-token": (*Handler).serveIDToken, - "login-interactive": (*Handler).serveLoginInteractive, - "logout": (*Handler).serveLogout, - "logtap": (*Handler).serveLogTap, - "metrics": (*Handler).serveMetrics, - "ping": (*Handler).servePing, - "pprof": (*Handler).servePprof, - "prefs": (*Handler).servePrefs, - "query-feature": (*Handler).serveQueryFeature, - "reload-config": (*Handler).reloadConfig, - "reset-auth": (*Handler).serveResetAuth, - "serve-config": (*Handler).serveServeConfig, - "set-dns": (*Handler).serveSetDNS, - "set-expiry-sooner": (*Handler).serveSetExpirySooner, - "set-gui-visible": (*Handler).serveSetGUIVisible, - "set-push-device-token": (*Handler).serveSetPushDeviceToken, - "set-udp-gro-forwarding": (*Handler).serveSetUDPGROForwarding, - "set-use-exit-node-enabled": (*Handler).serveSetUseExitNodeEnabled, - "start": (*Handler).serveStart, - "status": (*Handler).serveStatus, - "suggest-exit-node": (*Handler).serveSuggestExitNode, - "tka/affected-sigs": (*Handler).serveTKAAffectedSigs, - "tka/cosign-recovery-aum": (*Handler).serveTKACosignRecoveryAUM, - "tka/disable": (*Handler).serveTKADisable, - "tka/force-local-disable": (*Handler).serveTKALocalDisable, - "tka/generate-recovery-aum": (*Handler).serveTKAGenerateRecoveryAUM, - "tka/init": (*Handler).serveTKAInit, - "tka/log": (*Handler).serveTKALog, - "tka/modify": (*Handler).serveTKAModify, - "tka/sign": (*Handler).serveTKASign, - "tka/status": (*Handler).serveTKAStatus, - "tka/submit-recovery-aum": (*Handler).serveTKASubmitRecoveryAUM, - "tka/verify-deeplink": (*Handler).serveTKAVerifySigningDeeplink, - "tka/wrap-preauth-key": (*Handler).serveTKAWrapPreauthKey, - "update/check": (*Handler).serveUpdateCheck, - "update/install": (*Handler).serveUpdateInstall, - "update/progress": (*Handler).serveUpdateProgress, - "upload-client-metrics": (*Handler).serveUploadClientMetrics, - "usermetrics": (*Handler).serveUserMetrics, - "watch-ipn-bus": (*Handler).serveWatchIPNBus, - "whois": (*Handler).serveWhoIs, + "check-prefs": (*Handler).serveCheckPrefs, + "check-so-mark-in-use": (*Handler).serveCheckSOMarkInUse, + "derpmap": (*Handler).serveDERPMap, + "goroutines": (*Handler).serveGoroutines, + "login-interactive": (*Handler).serveLoginInteractive, + "logout": (*Handler).serveLogout, + "ping": (*Handler).servePing, + "prefs": (*Handler).servePrefs, + "reload-config": (*Handler).reloadConfig, + "reset-auth": (*Handler).serveResetAuth, + "set-expiry-sooner": (*Handler).serveSetExpirySooner, + "shutdown": (*Handler).serveShutdown, + "start": (*Handler).serveStart, + "status": (*Handler).serveStatus, + "whois": (*Handler).serveWhoIs, +} + +func init() { + if buildfeatures.HasAppConnectors { + Register("appc-route-info", (*Handler).serveGetAppcRouteInfo) + } + if buildfeatures.HasAdvertiseRoutes { + Register("check-ip-forwarding", (*Handler).serveCheckIPForwarding) + Register("check-udp-gro-forwarding", (*Handler).serveCheckUDPGROForwarding) + Register("set-udp-gro-forwarding", (*Handler).serveSetUDPGROForwarding) + } + if buildfeatures.HasUseExitNode && runtime.GOOS == "linux" { + Register("check-reverse-path-filtering", (*Handler).serveCheckReversePathFiltering) + } + if buildfeatures.HasClientMetrics { + Register("upload-client-metrics", (*Handler).serveUploadClientMetrics) + } + if buildfeatures.HasClientUpdate { + Register("update/check", (*Handler).serveUpdateCheck) + } + if buildfeatures.HasUseExitNode { + Register("suggest-exit-node", (*Handler).serveSuggestExitNode) + Register("set-use-exit-node-enabled", (*Handler).serveSetUseExitNodeEnabled) + } + if buildfeatures.HasACME { + Register("set-dns", (*Handler).serveSetDNS) + } + if buildfeatures.HasDebug { + Register("bugreport", (*Handler).serveBugReport) + Register("pprof", (*Handler).servePprof) + } + if buildfeatures.HasDebug || buildfeatures.HasServe { + Register("watch-ipn-bus", (*Handler).serveWatchIPNBus) + } + if buildfeatures.HasDNS { + Register("dns-osconfig", (*Handler).serveDNSOSConfig) + Register("dns-query", (*Handler).serveDNSQuery) + } + if buildfeatures.HasUserMetrics { + Register("usermetrics", (*Handler).serveUserMetrics) + } + if buildfeatures.HasServe { + Register("query-feature", (*Handler).serveQueryFeature) + } + if buildfeatures.HasOutboundProxy || buildfeatures.HasSSH { + Register("dial", (*Handler).serveDial) + } + if buildfeatures.HasClientMetrics || buildfeatures.HasDebug { + Register("metrics", (*Handler).serveMetrics) + } + if buildfeatures.HasDebug || buildfeatures.HasAdvertiseRoutes { + Register("disconnect-control", (*Handler).disconnectControl) + } + // Alpha/experimental/debug features. These should be moved to + // their own features if/when they graduate. + if buildfeatures.HasDebug { + Register("id-token", (*Handler).serveIDToken) + Register("alpha-set-device-attrs", (*Handler).serveSetDeviceAttrs) // see tailscale/corp#24690 + Register("handle-push-message", (*Handler).serveHandlePushMessage) + Register("set-push-device-token", (*Handler).serveSetPushDeviceToken) + } + if buildfeatures.HasDebug || runtime.GOOS == "windows" || runtime.GOOS == "darwin" { + Register("set-gui-visible", (*Handler).serveSetGUIVisible) + } + if buildfeatures.HasLogTail { + // TODO(bradfitz): separate out logtail tap functionality from upload + // functionality to make this possible? But seems unlikely people would + // want just this. They could "tail -f" or "journalctl -f" their logs + // themselves. + Register("logtap", (*Handler).serveLogTap) + } +} + +// Register registers a new LocalAPI handler for the given name. +func Register(name string, fn LocalAPIHandler) { + if _, ok := handler[name]; ok { + panic("duplicate LocalAPI handler registration: " + name) + } + handler[name] = fn } var ( @@ -159,10 +179,26 @@ var ( metrics = map[string]*clientmetric.Metric{} ) -// NewHandler creates a new LocalAPI HTTP handler. All parameters except netMon -// are required (if non-nil it's used to do faster interface lookups). -func NewHandler(b *ipnlocal.LocalBackend, logf logger.Logf, logID logid.PublicID) *Handler { - return &Handler{b: b, logf: logf, backendLogID: logID, clock: tstime.StdClock{}} +// NewHandler creates a new LocalAPI HTTP handler from the given config. +func NewHandler(cfg HandlerConfig) *Handler { + return &Handler{ + Actor: cfg.Actor, + b: cfg.Backend, + logf: cfg.Logf, + backendLogID: cfg.LogID, + clock: tstime.StdClock{}, + eventBus: cfg.EventBus, + } +} + +// HandlerConfig carries the settings for a local API handler. +// All fields are required. +type HandlerConfig struct { + Actor ipnauth.Actor + Backend *ipnlocal.LocalBackend + Logf logger.Logf + LogID logid.PublicID + EventBus *eventbus.Bus } type Handler struct { @@ -191,6 +227,15 @@ type Handler struct { logf logger.Logf backendLogID logid.PublicID clock tstime.Clock + eventBus *eventbus.Bus // read-only after initialization +} + +func (h *Handler) Logf(format string, args ...any) { + h.logf(format, args...) +} + +func (h *Handler) LocalBackend() *ipnlocal.LocalBackend { + return h.b } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -215,13 +260,14 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { http.Error(w, "auth required", http.StatusUnauthorized) return } - if pass != h.RequiredPassword { + if subtle.ConstantTimeCompare([]byte(pass), []byte(h.RequiredPassword)) == 0 { metricInvalidRequests.Add(1) http.Error(w, "bad password", http.StatusForbidden) return } } - if fn, ok := handlerForPath(r.URL.Path); ok { + if fn, route, ok := handlerForPath(r.URL.Path); ok { + h.logRequest(r.Method, route) fn(h, w, r) } else { http.NotFound(w, r) @@ -257,9 +303,9 @@ func (h *Handler) validHost(hostname string) bool { // handlerForPath returns the LocalAPI handler for the provided Request.URI.Path. // (the path doesn't include any query parameters) -func handlerForPath(urlPath string) (h localAPIHandler, ok bool) { +func handlerForPath(urlPath string) (h LocalAPIHandler, route string, ok bool) { if urlPath == "/" { - return (*Handler).serveLocalAPIRoot, true + return (*Handler).serveLocalAPIRoot, "/", true } suff, ok := strings.CutPrefix(urlPath, "/localapi/v0/") if !ok { @@ -267,22 +313,31 @@ func handlerForPath(urlPath string) (h localAPIHandler, ok bool) { // to people that they're not necessarily stable APIs. In practice we'll // probably need to keep them pretty stable anyway, but for now treat // them as an internal implementation detail. - return nil, false + return nil, "", false } if fn, ok := handler[suff]; ok { // Here we match exact handler suffixes like "status" or ones with a // slash already in their name, like "tka/status". - return fn, true + return fn, "/localapi/v0/" + suff, true } // Otherwise, it might be a prefix match like "files/*" which we look up // by the prefix including first trailing slash. if i := strings.IndexByte(suff, '/'); i != -1 { suff = suff[:i+1] if fn, ok := handler[suff]; ok { - return fn, true + return fn, "/localapi/v0/" + suff, true } } - return nil, false + return nil, "", false +} + +func (h *Handler) logRequest(method, route string) { + switch method { + case httpm.GET, httpm.HEAD, httpm.OPTIONS: + // don't log safe methods + default: + h.Logf("localapi: [%s] %s", method, route) + } } func (*Handler) serveLocalAPIRoot(w http.ResponseWriter, r *http.Request) { @@ -315,7 +370,7 @@ func (h *Handler) serveIDToken(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - httpReq, err := http.NewRequest("POST", "https://unused/machine/id-token", bytes.NewReader(b)) + httpReq, err := http.NewRequest(httpm.POST, "https://unused/machine/id-token", bytes.NewReader(b)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -338,7 +393,7 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { http.Error(w, "bugreport access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) return } @@ -382,8 +437,19 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { // OS-specific details h.logf.JSON(1, "UserBugReportOS", osdiag.SupportInfo(osdiag.LogSupportInfoReasonBugReport)) + // Tailnet Lock details + st := h.b.NetworkLockStatus() + if st.Enabled { + h.logf.JSON(1, "UserBugReportTailnetLockStatus", st) + if st.NodeKeySignature != nil { + h.logf("user bugreport tailnet lock signature: %s", st.NodeKeySignature.String()) + } + } + if defBool(r.URL.Query().Get("diagnose"), false) { - h.b.Doctor(r.Context(), logger.WithPrefix(h.logf, "diag: ")) + if f, ok := ipnlocal.HookDoctor.GetOk(); ok { + f(r.Context(), h.b, logger.WithPrefix(h.logf, "diag: ")) + } } w.Header().Set("Content-Type", "text/plain") fmt.Fprintln(w, startMarker) @@ -416,6 +482,8 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { // NOTE(andrew): if we have anything else we want to do while recording // a bugreport, we can add it here. + metricBugReportRequests.Add(1) + // Read from the client; this will also return when the client closes // the connection. var buf [1]byte @@ -444,6 +512,33 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) { h.serveWhoIsWithBackend(w, r, h.b) } +// serveSetDeviceAttrs is (as of 2024-12-30) an experimental LocalAPI handler to +// set device attributes via the control plane. +// +// See tailscale/corp#24690. +func (h *Handler) serveSetDeviceAttrs(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !h.PermitWrite { + http.Error(w, "set-device-attrs access denied", http.StatusForbidden) + return + } + if r.Method != httpm.PATCH { + http.Error(w, "only PATCH allowed", http.StatusMethodNotAllowed) + return + } + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := h.b.SetDeviceAttrs(ctx, req); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, "{}\n") +} + // localBackendWhoIsMethods is the subset of ipn.LocalBackend as needed // by the localapi WhoIs method. type localBackendWhoIsMethods interface { @@ -532,7 +627,7 @@ func (h *Handler) serveLogTap(w http.ResponseWriter, r *http.Request) { http.Error(w, "logtap access denied", http.StatusForbidden) return } - if r.Method != "GET" { + if r.Method != httpm.GET { http.Error(w, "GET required", http.StatusMethodNotAllowed) return } @@ -561,6 +656,7 @@ func (h *Handler) serveLogTap(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { + metricDebugMetricsCalls.Add(1) // Require write access out of paranoia that the metrics // might contain something sensitive. if !h.PermitWrite { @@ -571,630 +667,220 @@ func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { clientmetric.WritePrometheusExpositionFormat(w) } -// TODO(kradalby): Remove this once we have landed on a final set of -// metrics to export to clients and consider the metrics stable. -var debugUsermetricsEndpoint = envknob.RegisterBool("TS_DEBUG_USER_METRICS") - +// serveUserMetrics returns user-facing metrics in Prometheus text +// exposition format. func (h *Handler) serveUserMetrics(w http.ResponseWriter, r *http.Request) { - if !testenv.InTest() && !debugUsermetricsEndpoint() { - http.Error(w, "usermetrics debug flag not enabled", http.StatusForbidden) - return - } + metricUserMetricsCalls.Add(1) h.b.UserMetricsRegistry().Handler(w, r) } -func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { +// servePprofFunc is the implementation of Handler.servePprof, after auth, +// for platforms where we want to link it in. +var servePprofFunc func(http.ResponseWriter, *http.Request) + +func (h *Handler) servePprof(w http.ResponseWriter, r *http.Request) { + // Require write access out of paranoia that the profile dump + // might contain something sensitive. if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "POST required", http.StatusMethodNotAllowed) + http.Error(w, "profile access denied", http.StatusForbidden) return } - // The action is normally in a POST form parameter, but - // some actions (like "notify") want a full JSON body, so - // permit some to have their action in a header. - var action string - switch v := r.Header.Get("Debug-Action"); v { - case "notify": - action = v - default: - action = r.FormValue("action") - } - var err error - switch action { - case "derp-set-homeless": - h.b.MagicConn().SetHomeless(true) - case "derp-unset-homeless": - h.b.MagicConn().SetHomeless(false) - case "rebind": - err = h.b.DebugRebind() - case "restun": - err = h.b.DebugReSTUN() - case "notify": - var n ipn.Notify - err = json.NewDecoder(r.Body).Decode(&n) - if err != nil { - break - } - h.b.DebugNotify(n) - case "notify-last-netmap": - h.b.DebugNotifyLastNetMap() - case "break-tcp-conns": - err = h.b.DebugBreakTCPConns() - case "break-derp-conns": - err = h.b.DebugBreakDERPConns() - case "force-netmap-update": - h.b.DebugForceNetmapUpdate() - case "control-knobs": - k := h.b.ControlKnobs() - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(k.AsDebugJSON()) - if err == nil { - return - } - case "pick-new-derp": - err = h.b.DebugPickNewDERP() - case "": - err = fmt.Errorf("missing parameter 'action'") - default: - err = fmt.Errorf("unknown action %q", action) - } - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + if servePprofFunc == nil { + http.Error(w, "not implemented on this platform", http.StatusServiceUnavailable) return } - w.Header().Set("Content-Type", "text/plain") - io.WriteString(w, "done\n") + servePprofFunc(w, r) } -func (h *Handler) serveDevSetStateStore(w http.ResponseWriter, r *http.Request) { +// disconnectControl is the handler for local API /disconnect-control endpoint that shuts down control client, so that +// node no longer communicates with control. Doing this makes control consider this node inactive. This can be used +// before shutting down a replica of HA subnet router or app connector deployments to ensure that control tells the +// peers to switch over to another replica whilst still maintaining th existing peer connections. +func (h *Handler) disconnectControl(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "POST required", http.StatusMethodNotAllowed) + http.Error(w, "access denied", http.StatusForbidden) return } - if err := h.b.SetDevStateStore(r.FormValue("key"), r.FormValue("value")); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) return } - w.Header().Set("Content-Type", "text/plain") - io.WriteString(w, "done\n") + h.b.DisconnectControl() } -func (h *Handler) serveDebugPacketFilterRules(w http.ResponseWriter, r *http.Request) { +func (h *Handler) reloadConfig(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) + http.Error(w, "access denied", http.StatusForbidden) return } - nm := h.b.NetMap() - if nm == nil { - http.Error(w, "no netmap", http.StatusNotFound) + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + ok, err := h.b.ReloadConfig() + var res apitype.ReloadConfigResponse + res.Reloaded = ok + if err != nil { + res.Err = err.Error() return } w.Header().Set("Content-Type", "application/json") - - enc := json.NewEncoder(w) - enc.SetIndent("", "\t") - enc.Encode(nm.PacketFilterRules) + json.NewEncoder(w).Encode(&res) } -func (h *Handler) serveDebugPacketFilterMatches(w http.ResponseWriter, r *http.Request) { +func (h *Handler) serveResetAuth(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) + http.Error(w, "reset-auth modify access denied", http.StatusForbidden) return } - nm := h.b.NetMap() - if nm == nil { - http.Error(w, "no netmap", http.StatusNotFound) + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) return } - w.Header().Set("Content-Type", "application/json") - - enc := json.NewEncoder(w) - enc.SetIndent("", "\t") - enc.Encode(nm.PacketFilter) -} -func (h *Handler) serveDebugPortmap(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) + if err := h.b.ResetAuth(); err != nil { + http.Error(w, "reset-auth failed: "+err.Error(), http.StatusInternalServerError) return } - w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusNoContent) +} - dur, err := time.ParseDuration(r.FormValue("duration")) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) +func (h *Handler) serveCheckIPForwarding(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "IP forwarding check access denied", http.StatusForbidden) return } + var warning string + if err := h.b.CheckIPForwarding(); err != nil { + warning = err.Error() + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(struct { + Warning string + }{ + Warning: warning, + }) +} - gwSelf := r.FormValue("gateway_and_self") - - // Update portmapper debug flags - debugKnobs := &portmapper.DebugKnobs{VerboseLogs: true} - switch r.FormValue("type") { - case "": - case "pmp": - debugKnobs.DisablePCP = true - debugKnobs.DisableUPnP = true - case "pcp": - debugKnobs.DisablePMP = true - debugKnobs.DisableUPnP = true - case "upnp": - debugKnobs.DisablePCP = true - debugKnobs.DisablePMP = true - default: - http.Error(w, "unknown portmap debug type", http.StatusBadRequest) +// serveCheckSOMarkInUse reports whether SO_MARK is in use on the linux while +// running without TUN. For any other OS, it reports false. +func (h *Handler) serveCheckSOMarkInUse(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "SO_MARK check access denied", http.StatusForbidden) return } + usingSOMark := netns.UseSocketMark() + usingUserspaceNetworking := h.b.Sys().IsNetstack() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(struct { + UseSOMark bool + }{ + UseSOMark: usingSOMark || usingUserspaceNetworking, + }) +} - if defBool(r.FormValue("log_http"), false) { - debugKnobs.LogHTTP = true +func (h *Handler) serveCheckReversePathFiltering(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "reverse path filtering check access denied", http.StatusForbidden) + return } + var warning string - var ( - logLock sync.Mutex - handlerDone bool - ) - logf := func(format string, args ...any) { - if !strings.HasSuffix(format, "\n") { - format = format + "\n" - } - - logLock.Lock() - defer logLock.Unlock() - - // The portmapper can call this log function after the HTTP - // handler returns, which is not allowed and can cause a panic. - // If this happens, ignore the log lines since this typically - // occurs due to a client disconnect. - if handlerDone { - return - } - - // Write and flush each line to the client so that output is streamed - fmt.Fprintf(w, format, args...) - if f, ok := w.(http.Flusher); ok { - f.Flush() + state := h.b.Sys().NetMon.Get().InterfaceState() + warn, err := netutil.CheckReversePathFiltering(state) + if err == nil && len(warn) > 0 { + var msg strings.Builder + msg.WriteString(healthmsg.WarnExitNodeUsage + ":\n") + for _, w := range warn { + msg.WriteString("- " + w + "\n") } + msg.WriteString(healthmsg.DisableRPFilter) + warning = msg.String() } - defer func() { - logLock.Lock() - handlerDone = true - logLock.Unlock() - }() - - ctx, cancel := context.WithTimeout(r.Context(), dur) - defer cancel() - - done := make(chan bool, 1) - - var c *portmapper.Client - c = portmapper.NewClient(logger.WithPrefix(logf, "portmapper: "), h.b.NetMon(), debugKnobs, h.b.ControlKnobs(), func() { - logf("portmapping changed.") - logf("have mapping: %v", c.HaveMapping()) - - if ext, ok := c.GetCachedMappingOrStartCreatingOne(); ok { - logf("cb: mapping: %v", ext) - select { - case done <- true: - default: - } - return - } - logf("cb: no mapping") + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(struct { + Warning string + }{ + Warning: warning, }) - defer c.Close() +} - netMon, err := netmon.New(logger.WithPrefix(logf, "monitor: ")) - if err != nil { - logf("error creating monitor: %v", err) +func (h *Handler) serveCheckUDPGROForwarding(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "UDP GRO forwarding check access denied", http.StatusForbidden) return } - - gatewayAndSelfIP := func() (gw, self netip.Addr, ok bool) { - if a, b, ok := strings.Cut(gwSelf, "/"); ok { - gw = netip.MustParseAddr(a) - self = netip.MustParseAddr(b) - return gw, self, true - } - return netMon.GatewayAndSelfIP() + var warning string + if err := h.b.CheckUDPGROForwarding(); err != nil { + warning = err.Error() } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(struct { + Warning string + }{ + Warning: warning, + }) +} - c.SetGatewayLookupFunc(gatewayAndSelfIP) - - gw, selfIP, ok := gatewayAndSelfIP() - if !ok { - logf("no gateway or self IP; %v", netMon.InterfaceState()) +func (h *Handler) serveSetUDPGROForwarding(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasGRO { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) return } - logf("gw=%v; self=%v", gw, selfIP) - - uc, err := net.ListenPacket("udp", "0.0.0.0:0") - if err != nil { + if !h.PermitWrite { + http.Error(w, "UDP GRO forwarding set access denied", http.StatusForbidden) return } - defer uc.Close() - c.SetLocalPort(uint16(uc.LocalAddr().(*net.UDPAddr).Port)) - - res, err := c.Probe(ctx) - if err != nil { - logf("error in Probe: %v", err) - return + var warning string + if err := h.b.SetUDPGROForwarding(); err != nil { + warning = err.Error() } - logf("Probe: %+v", res) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(struct { + Warning string + }{ + Warning: warning, + }) +} - if !res.PCP && !res.PMP && !res.UPnP { - logf("no portmapping services available") +func (h *Handler) serveStatus(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "status access denied", http.StatusForbidden) return } - - if ext, ok := c.GetCachedMappingOrStartCreatingOne(); ok { - logf("mapping: %v", ext) + w.Header().Set("Content-Type", "application/json") + var st *ipnstate.Status + if defBool(r.FormValue("peers"), true) { + st = h.b.Status() } else { - logf("no mapping") - } - - select { - case <-done: - case <-ctx.Done(): - if r.Context().Err() == nil { - logf("serveDebugPortmap: context done: %v", ctx.Err()) - } else { - h.logf("serveDebugPortmap: context done: %v", ctx.Err()) - } + st = h.b.StatusWithoutPeers() } + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(st) } -func (h *Handler) serveComponentDebugLogging(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) - return - } - component := r.FormValue("component") - secs, _ := strconv.Atoi(r.FormValue("secs")) - err := h.b.SetComponentDebugLogging(component, h.clock.Now().Add(time.Duration(secs)*time.Second)) - var res struct { - Error string +// InUseOtherUserIPNStream reports whether r is a request for the watch-ipn-bus +// handler. If so, it writes an ipn.Notify InUseOtherUser message to the user +// and returns true. Otherwise it returns false, in which case it doesn't write +// to w. +// +// Unlike the regular watch-ipn-bus handler, this one doesn't block. The caller +// (in ipnserver.Server) provides the blocking until the connection is no longer +// in use. +func InUseOtherUserIPNStream(w http.ResponseWriter, r *http.Request, err error) (handled bool) { + if r.Method != httpm.GET || r.URL.Path != "/localapi/v0/watch-ipn-bus" { + return false } + js, err := json.Marshal(&ipn.Notify{ + Version: version.Long(), + State: ptr.To(ipn.InUseOtherUser), + ErrMessage: ptr.To(err.Error()), + }) if err != nil { - res.Error = err.Error() + return false } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} - -func (h *Handler) serveDebugDialTypes(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "debug-dial-types access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) - return - } - - ip := r.FormValue("ip") - port := r.FormValue("port") - network := r.FormValue("network") - - addr := ip + ":" + port - if _, err := netip.ParseAddrPort(addr); err != nil { - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "invalid address %q: %v", addr, err) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) - defer cancel() - - var bareDialer net.Dialer - - dialer := h.b.Dialer() - - var peerDialer net.Dialer - peerDialer.Control = dialer.PeerDialControlFunc() - - // Kick off a dial with each available dialer in parallel. - dialers := []struct { - name string - dial func(context.Context, string, string) (net.Conn, error) - }{ - {"SystemDial", dialer.SystemDial}, - {"UserDial", dialer.UserDial}, - {"PeerDial", peerDialer.DialContext}, - {"BareDial", bareDialer.DialContext}, - } - type result struct { - name string - conn net.Conn - err error - } - results := make(chan result, len(dialers)) - - var wg sync.WaitGroup - for _, dialer := range dialers { - dialer := dialer // loop capture - - wg.Add(1) - go func() { - defer wg.Done() - conn, err := dialer.dial(ctx, network, addr) - results <- result{dialer.name, conn, err} - }() - } - - wg.Wait() - for range len(dialers) { - res := <-results - fmt.Fprintf(w, "[%s] connected=%v err=%v\n", res.name, res.conn != nil, res.err) - if res.conn != nil { - res.conn.Close() - } - } -} - -// servePprofFunc is the implementation of Handler.servePprof, after auth, -// for platforms where we want to link it in. -var servePprofFunc func(http.ResponseWriter, *http.Request) - -func (h *Handler) servePprof(w http.ResponseWriter, r *http.Request) { - // Require write access out of paranoia that the profile dump - // might contain something sensitive. - if !h.PermitWrite { - http.Error(w, "profile access denied", http.StatusForbidden) - return - } - if servePprofFunc == nil { - http.Error(w, "not implemented on this platform", http.StatusServiceUnavailable) - return - } - servePprofFunc(w, r) -} - -func (h *Handler) reloadConfig(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - ok, err := h.b.ReloadConfig() - var res apitype.ReloadConfigResponse - res.Reloaded = ok - if err != nil { - res.Err = err.Error() - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(&res) -} - -func (h *Handler) serveResetAuth(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "reset-auth modify access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - if err := h.b.ResetAuth(); err != nil { - http.Error(w, "reset-auth failed: "+err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNoContent) -} - -func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case "GET": - if !h.PermitRead { - http.Error(w, "serve config denied", http.StatusForbidden) - return - } - config := h.b.ServeConfig() - bts, err := json.Marshal(config) - if err != nil { - http.Error(w, "error encoding config: "+err.Error(), http.StatusInternalServerError) - return - } - sum := sha256.Sum256(bts) - etag := hex.EncodeToString(sum[:]) - w.Header().Set("Etag", etag) - w.Header().Set("Content-Type", "application/json") - w.Write(bts) - case "POST": - if !h.PermitWrite { - http.Error(w, "serve config denied", http.StatusForbidden) - return - } - configIn := new(ipn.ServeConfig) - if err := json.NewDecoder(r.Body).Decode(configIn); err != nil { - writeErrorJSON(w, fmt.Errorf("decoding config: %w", err)) - return - } - - // require a local admin when setting a path handler - // TODO: roll-up this Windows-specific check into either PermitWrite - // or a global admin escalation check. - if err := authorizeServeConfigForGOOSAndUserContext(runtime.GOOS, configIn, h); err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) - return - } - - etag := r.Header.Get("If-Match") - if err := h.b.SetServeConfig(configIn, etag); err != nil { - if errors.Is(err, ipnlocal.ErrETagMismatch) { - http.Error(w, err.Error(), http.StatusPreconditionFailed) - return - } - writeErrorJSON(w, fmt.Errorf("updating config: %w", err)) - return - } - w.WriteHeader(http.StatusOK) - default: - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - } -} - -func authorizeServeConfigForGOOSAndUserContext(goos string, configIn *ipn.ServeConfig, h *Handler) error { - switch goos { - case "windows", "linux", "darwin": - default: - return nil - } - // Only check for local admin on tailscaled-on-mac (based on "sudo" - // permissions). On sandboxed variants (MacSys and AppStore), tailscaled - // cannot serve files outside of the sandbox and this check is not - // relevant. - if goos == "darwin" && version.IsSandboxedMacOS() { - return nil - } - if !configIn.HasPathHandler() { - return nil - } - if h.Actor.IsLocalAdmin(h.b.OperatorUserID()) { - return nil - } - switch goos { - case "windows": - return errors.New("must be a Windows local admin to serve a path") - case "linux", "darwin": - return errors.New("must be root, or be an operator and able to run 'sudo tailscale' to serve a path") - default: - // We filter goos at the start of the func, this default case - // should never happen. - panic("unreachable") - } - -} - -func (h *Handler) serveCheckIPForwarding(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "IP forwarding check access denied", http.StatusForbidden) - return - } - var warning string - if err := h.b.CheckIPForwarding(); err != nil { - warning = err.Error() - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(struct { - Warning string - }{ - Warning: warning, - }) -} - -func (h *Handler) serveCheckUDPGROForwarding(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "UDP GRO forwarding check access denied", http.StatusForbidden) - return - } - var warning string - if err := h.b.CheckUDPGROForwarding(); err != nil { - warning = err.Error() - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(struct { - Warning string - }{ - Warning: warning, - }) -} - -func (h *Handler) serveSetUDPGROForwarding(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "UDP GRO forwarding set access denied", http.StatusForbidden) - return - } - var warning string - if err := h.b.SetUDPGROForwarding(); err != nil { - warning = err.Error() - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(struct { - Warning string - }{ - Warning: warning, - }) -} - -func (h *Handler) serveStatus(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "status access denied", http.StatusForbidden) - return - } - w.Header().Set("Content-Type", "application/json") - var st *ipnstate.Status - if defBool(r.FormValue("peers"), true) { - st = h.b.Status() - } else { - st = h.b.StatusWithoutPeers() - } - e := json.NewEncoder(w) - e.SetIndent("", "\t") - e.Encode(st) -} - -func (h *Handler) serveDebugPeerEndpointChanges(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "status access denied", http.StatusForbidden) - return - } - - ipStr := r.FormValue("ip") - if ipStr == "" { - http.Error(w, "missing 'ip' parameter", http.StatusBadRequest) - return - } - ip, err := netip.ParseAddr(ipStr) - if err != nil { - http.Error(w, "invalid IP", http.StatusBadRequest) - return - } - w.Header().Set("Content-Type", "application/json") - chs, err := h.b.GetPeerEndpointChanges(r.Context(), ip) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - e := json.NewEncoder(w) - e.SetIndent("", "\t") - e.Encode(chs) -} - -// InUseOtherUserIPNStream reports whether r is a request for the watch-ipn-bus -// handler. If so, it writes an ipn.Notify InUseOtherUser message to the user -// and returns true. Otherwise it returns false, in which case it doesn't write -// to w. -// -// Unlike the regular watch-ipn-bus handler, this one doesn't block. The caller -// (in ipnserver.Server) provides the blocking until the connection is no longer -// in use. -func InUseOtherUserIPNStream(w http.ResponseWriter, r *http.Request, err error) (handled bool) { - if r.Method != "GET" || r.URL.Path != "/localapi/v0/watch-ipn-bus" { - return false - } - js, err := json.Marshal(&ipn.Notify{ - Version: version.Long(), - State: ptr.To(ipn.InUseOtherUser), - ErrMessage: ptr.To(err.Error()), - }) - if err != nil { - return false - } - js = append(js, '\n') + js = append(js, '\n') w.Header().Set("Content-Type", "application/json") w.Write(js) return true @@ -1220,19 +906,11 @@ func (h *Handler) serveWatchIPNBus(w http.ResponseWriter, r *http.Request) { } mask = ipn.NotifyWatchOpt(v) } - // Users with only read access must request private key filtering. If they - // don't filter out private keys, require write access. - if (mask & ipn.NotifyNoPrivateKeys) == 0 { - if !h.PermitWrite { - http.Error(w, "watch IPN bus access denied, must set ipn.NotifyNoPrivateKeys when not running as admin/root or operator", http.StatusForbidden) - return - } - } w.Header().Set("Content-Type", "application/json") ctx := r.Context() enc := json.NewEncoder(w) - h.b.WatchNotifications(ctx, mask, f.Flush, func(roNotify *ipn.Notify) (keepGoing bool) { + h.b.WatchNotificationsAs(ctx, h.Actor, mask, f.Flush, func(roNotify *ipn.Notify) (keepGoing bool) { err := enc.Encode(roNotify) if err != nil { h.logf("json.Encode: %v", err) @@ -1248,11 +926,11 @@ func (h *Handler) serveLoginInteractive(w http.ResponseWriter, r *http.Request) http.Error(w, "login access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "want POST", http.StatusBadRequest) return } - h.b.StartLoginInteractive(r.Context()) + h.b.StartLoginInteractiveAs(r.Context(), h.Actor) w.WriteHeader(http.StatusNoContent) return } @@ -1262,7 +940,7 @@ func (h *Handler) serveStart(w http.ResponseWriter, r *http.Request) { http.Error(w, "access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "want POST", http.StatusBadRequest) return } @@ -1285,11 +963,11 @@ func (h *Handler) serveLogout(w http.ResponseWriter, r *http.Request) { http.Error(w, "logout access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "want POST", http.StatusBadRequest) return } - err := h.b.Logout(r.Context()) + err := h.b.Logout(r.Context(), h.Actor) if err == nil { w.WriteHeader(http.StatusNoContent) return @@ -1304,7 +982,7 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) { } var prefs ipn.PrefsView switch r.Method { - case "PATCH": + case httpm.PATCH: if !h.PermitWrite { http.Error(w, "prefs write access denied", http.StatusForbidden) return @@ -1314,21 +992,23 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - if err := h.b.MaybeClearAppConnector(mp); err != nil { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(resJSON{Error: err.Error()}) - return + if buildfeatures.HasAppConnectors { + if err := h.b.MaybeClearAppConnector(mp); err != nil { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + json.NewEncoder(w).Encode(resJSON{Error: err.Error()}) + return + } } var err error - prefs, err = h.b.EditPrefs(mp) + prefs, err = h.b.EditPrefsAs(mp, h.Actor) if err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) json.NewEncoder(w).Encode(resJSON{Error: err.Error()}) return } - case "GET", "HEAD": + case httpm.GET, httpm.HEAD: prefs = h.b.Prefs() default: http.Error(w, "unsupported method", http.StatusMethodNotAllowed) @@ -1349,7 +1029,7 @@ func (h *Handler) serveCheckPrefs(w http.ResponseWriter, r *http.Request) { http.Error(w, "checkprefs access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "unsupported method", http.StatusMethodNotAllowed) return } @@ -1367,414 +1047,34 @@ func (h *Handler) serveCheckPrefs(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(res) } -func (h *Handler) serveFiles(w http.ResponseWriter, r *http.Request) { +// WriteErrorJSON writes a JSON object (with a single "error" string field) to w +// with the given error. If err is nil, "unexpected nil error" is used for the +// stringification instead. +func WriteErrorJSON(w http.ResponseWriter, err error) { + if err == nil { + err = errors.New("unexpected nil error") + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + type E struct { + Error string `json:"error"` + } + json.NewEncoder(w).Encode(E{err.Error()}) +} + +func (h *Handler) serveSetDNS(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "file access denied", http.StatusForbidden) + http.Error(w, "access denied", http.StatusForbidden) return } - suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/files/") - if !ok { - http.Error(w, "misconfigured", http.StatusInternalServerError) - return - } - if suffix == "" { - if r.Method != "GET" { - http.Error(w, "want GET to list files", http.StatusBadRequest) - return - } - ctx := r.Context() - if s := r.FormValue("waitsec"); s != "" && s != "0" { - d, err := strconv.Atoi(s) - if err != nil { - http.Error(w, "invalid waitsec", http.StatusBadRequest) - return - } - deadline := time.Now().Add(time.Duration(d) * time.Second) - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, deadline) - defer cancel() - } - wfs, err := h.b.AwaitWaitingFiles(ctx) - if err != nil && ctx.Err() == nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(wfs) - return - } - name, err := url.PathUnescape(suffix) - if err != nil { - http.Error(w, "bad filename", http.StatusBadRequest) - return - } - if r.Method == "DELETE" { - if err := h.b.DeleteFile(name); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNoContent) - return - } - rc, size, err := h.b.OpenFile(name) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer rc.Close() - w.Header().Set("Content-Length", fmt.Sprint(size)) - w.Header().Set("Content-Type", "application/octet-stream") - io.Copy(w, rc) -} - -func writeErrorJSON(w http.ResponseWriter, err error) { - if err == nil { - err = errors.New("unexpected nil error") - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - type E struct { - Error string `json:"error"` - } - json.NewEncoder(w).Encode(E{err.Error()}) -} - -func (h *Handler) serveFileTargets(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "access denied", http.StatusForbidden) - return - } - if r.Method != "GET" { - http.Error(w, "want GET to list targets", http.StatusBadRequest) - return - } - fts, err := h.b.FileTargets() - if err != nil { - writeErrorJSON(w, err) - return - } - mak.NonNilSliceForJSON(&fts) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(fts) -} - -// serveFilePut sends a file to another node. -// -// It's sometimes possible for clients to do this themselves, without -// tailscaled, except in the case of tailscaled running in -// userspace-networking ("netstack") mode, in which case tailscaled -// needs to a do a netstack dial out. -// -// Instead, the CLI also goes through tailscaled so it doesn't need to be -// aware of the network mode in use. -// -// macOS/iOS have always used this localapi method to simplify the GUI -// clients. -// -// The Windows client currently (2021-11-30) uses the peerapi (/v0/put/) -// directly, as the Windows GUI always runs in tun mode anyway. -// -// In addition to single file PUTs, this endpoint accepts multipart file -// POSTS encoded as multipart/form-data.The first part should be an -// application/json file that contains a manifest consisting of a JSON array of -// OutgoingFiles which wecan use for tracking progress even before reading the -// file parts. -// -// URL format: -// -// - PUT /localapi/v0/file-put/:stableID/:escaped-filename -// - POST /localapi/v0/file-put/:stableID -func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) { - metricFilePutCalls.Add(1) - - if !h.PermitWrite { - http.Error(w, "file access denied", http.StatusForbidden) - return - } - - if r.Method != "PUT" && r.Method != "POST" { - http.Error(w, "want PUT to put file", http.StatusBadRequest) - return - } - - fts, err := h.b.FileTargets() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - upath, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/file-put/") - if !ok { - http.Error(w, "misconfigured", http.StatusInternalServerError) - return - } - var peerIDStr, filenameEscaped string - if r.Method == "PUT" { - ok := false - peerIDStr, filenameEscaped, ok = strings.Cut(upath, "/") - if !ok { - http.Error(w, "bogus URL", http.StatusBadRequest) - return - } - } else { - peerIDStr = upath - } - peerID := tailcfg.StableNodeID(peerIDStr) - - var ft *apitype.FileTarget - for _, x := range fts { - if x.Node.StableID == peerID { - ft = x - break - } - } - if ft == nil { - http.Error(w, "node not found", http.StatusNotFound) - return - } - dstURL, err := url.Parse(ft.PeerAPIURL) - if err != nil { - http.Error(w, "bogus peer URL", http.StatusInternalServerError) - return - } - - // Periodically report progress of outgoing files. - outgoingFiles := make(map[string]*ipn.OutgoingFile) - t := time.NewTicker(1 * time.Second) - progressUpdates := make(chan ipn.OutgoingFile) - defer close(progressUpdates) - - go func() { - defer t.Stop() - defer h.b.UpdateOutgoingFiles(outgoingFiles) - for { - select { - case u, ok := <-progressUpdates: - if !ok { - return - } - outgoingFiles[u.ID] = &u - case <-t.C: - h.b.UpdateOutgoingFiles(outgoingFiles) - } - } - }() - - switch r.Method { - case "PUT": - file := ipn.OutgoingFile{ - ID: uuid.Must(uuid.NewRandom()).String(), - PeerID: peerID, - Name: filenameEscaped, - DeclaredSize: r.ContentLength, - } - h.singleFilePut(r.Context(), progressUpdates, w, r.Body, dstURL, file) - case "POST": - h.multiFilePost(progressUpdates, w, r, peerID, dstURL) - default: - http.Error(w, "want PUT to put file", http.StatusBadRequest) - return - } -} - -func (h *Handler) multiFilePost(progressUpdates chan (ipn.OutgoingFile), w http.ResponseWriter, r *http.Request, peerID tailcfg.StableNodeID, dstURL *url.URL) { - _, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) - if err != nil { - http.Error(w, fmt.Sprintf("invalid Content-Type for multipart POST: %s", err), http.StatusBadRequest) - return - } - - ww := &multiFilePostResponseWriter{} - defer func() { - if err := ww.Flush(w); err != nil { - h.logf("error: multiFilePostResponseWriter.Flush(): %s", err) - } - }() - - outgoingFilesByName := make(map[string]ipn.OutgoingFile) - first := true - mr := multipart.NewReader(r.Body, params["boundary"]) - for { - part, err := mr.NextPart() - if err == io.EOF { - // No more parts. - return - } else if err != nil { - http.Error(ww, fmt.Sprintf("failed to decode multipart/form-data: %s", err), http.StatusBadRequest) - return - } - - if first { - first = false - if part.Header.Get("Content-Type") != "application/json" { - http.Error(ww, "first MIME part must be a JSON map of filename -> size", http.StatusBadRequest) - return - } - - var manifest []ipn.OutgoingFile - err := json.NewDecoder(part).Decode(&manifest) - if err != nil { - http.Error(ww, fmt.Sprintf("invalid manifest: %s", err), http.StatusBadRequest) - return - } - - for _, file := range manifest { - outgoingFilesByName[file.Name] = file - progressUpdates <- file - } - - continue - } - - if !h.singleFilePut(r.Context(), progressUpdates, ww, part, dstURL, outgoingFilesByName[part.FileName()]) { - return - } - - if ww.statusCode >= 400 { - // put failed, stop immediately - h.logf("error: singleFilePut: failed with status %d", ww.statusCode) - return - } - } -} - -// multiFilePostResponseWriter is a buffering http.ResponseWriter that can be -// reused across multiple singleFilePut calls and then flushed to the client -// when all files have been PUT. -type multiFilePostResponseWriter struct { - header http.Header - statusCode int - body *bytes.Buffer -} - -func (ww *multiFilePostResponseWriter) Header() http.Header { - if ww.header == nil { - ww.header = make(http.Header) - } - return ww.header -} - -func (ww *multiFilePostResponseWriter) WriteHeader(statusCode int) { - ww.statusCode = statusCode -} - -func (ww *multiFilePostResponseWriter) Write(p []byte) (int, error) { - if ww.body == nil { - ww.body = bytes.NewBuffer(nil) - } - return ww.body.Write(p) -} - -func (ww *multiFilePostResponseWriter) Flush(w http.ResponseWriter) error { - if ww.header != nil { - maps.Copy(w.Header(), ww.header) - } - if ww.statusCode > 0 { - w.WriteHeader(ww.statusCode) - } - if ww.body != nil { - _, err := io.Copy(w, ww.body) - return err - } - return nil -} - -func (h *Handler) singleFilePut( - ctx context.Context, - progressUpdates chan (ipn.OutgoingFile), - w http.ResponseWriter, - body io.Reader, - dstURL *url.URL, - outgoingFile ipn.OutgoingFile, -) bool { - outgoingFile.Started = time.Now() - body = progresstracking.NewReader(body, 1*time.Second, func(n int, err error) { - outgoingFile.Sent = int64(n) - progressUpdates <- outgoingFile - }) - - fail := func() { - outgoingFile.Finished = true - outgoingFile.Succeeded = false - progressUpdates <- outgoingFile - } - - // Before we PUT a file we check to see if there are any existing partial file and if so, - // we resume the upload from where we left off by sending the remaining file instead of - // the full file. - var offset int64 - var resumeDuration time.Duration - remainingBody := io.Reader(body) - client := &http.Client{ - Transport: h.b.Dialer().PeerAPITransport(), - Timeout: 10 * time.Second, - } - req, err := http.NewRequestWithContext(ctx, "GET", dstURL.String()+"/v0/put/"+outgoingFile.Name, nil) - if err != nil { - http.Error(w, "bogus peer URL", http.StatusInternalServerError) - fail() - return false - } - switch resp, err := client.Do(req); { - case err != nil: - h.logf("could not fetch remote hashes: %v", err) - case resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotFound: - // noop; implies older peerapi without resume support - case resp.StatusCode != http.StatusOK: - h.logf("fetch remote hashes status code: %d", resp.StatusCode) - default: - resumeStart := time.Now() - dec := json.NewDecoder(resp.Body) - offset, remainingBody, err = taildrop.ResumeReader(body, func() (out taildrop.BlockChecksum, err error) { - err = dec.Decode(&out) - return out, err - }) - if err != nil { - h.logf("reader could not be fully resumed: %v", err) - } - resumeDuration = time.Since(resumeStart).Round(time.Millisecond) - } - - outReq, err := http.NewRequestWithContext(ctx, "PUT", "http://peer/v0/put/"+outgoingFile.Name, remainingBody) - if err != nil { - http.Error(w, "bogus outreq", http.StatusInternalServerError) - fail() - return false - } - outReq.ContentLength = outgoingFile.DeclaredSize - if offset > 0 { - h.logf("resuming put at offset %d after %v", offset, resumeDuration) - rangeHdr, _ := httphdr.FormatRange([]httphdr.Range{{Start: offset, Length: 0}}) - outReq.Header.Set("Range", rangeHdr) - if outReq.ContentLength >= 0 { - outReq.ContentLength -= offset - } - } - - rp := httputil.NewSingleHostReverseProxy(dstURL) - rp.Transport = h.b.Dialer().PeerAPITransport() - rp.ServeHTTP(w, outReq) - - outgoingFile.Finished = true - outgoingFile.Succeeded = true - progressUpdates <- outgoingFile - - return true -} - -func (h *Handler) serveSetDNS(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "access denied", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "want POST", http.StatusBadRequest) + if r.Method != httpm.POST { + http.Error(w, "want POST", http.StatusBadRequest) return } ctx := r.Context() err := h.b.SetDNS(ctx, r.FormValue("name"), r.FormValue("value")) if err != nil { - writeErrorJSON(w, err) + WriteErrorJSON(w, err) return } w.Header().Set("Content-Type", "application/json") @@ -1782,7 +1082,7 @@ func (h *Handler) serveSetDNS(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveDERPMap(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { + if r.Method != httpm.GET { http.Error(w, "want GET", http.StatusBadRequest) return } @@ -1799,7 +1099,7 @@ func (h *Handler) serveSetExpirySooner(w http.ResponseWriter, r *http.Request) { http.Error(w, "access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "POST required", http.StatusMethodNotAllowed) return } @@ -1827,7 +1127,7 @@ func (h *Handler) serveSetExpirySooner(w http.ResponseWriter, r *http.Request) { func (h *Handler) servePing(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "want POST", http.StatusBadRequest) return } @@ -1865,7 +1165,7 @@ func (h *Handler) servePing(w http.ResponseWriter, r *http.Request) { } res, err := h.b.Ping(ctx, ip, tailcfg.PingType(pingTypeStr), size) if err != nil { - writeErrorJSON(w, err) + WriteErrorJSON(w, err) return } w.Header().Set("Content-Type", "application/json") @@ -1873,7 +1173,7 @@ func (h *Handler) servePing(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "POST required", http.StatusMethodNotAllowed) return } @@ -1936,7 +1236,7 @@ func (h *Handler) serveSetPushDeviceToken(w http.ResponseWriter, r *http.Request http.Error(w, "set push device token access denied", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "unsupported method", http.StatusMethodNotAllowed) return } @@ -1954,7 +1254,7 @@ func (h *Handler) serveHandlePushMessage(w http.ResponseWriter, r *http.Request) http.Error(w, "handle push message not allowed", http.StatusForbidden) return } - if r.Method != "POST" { + if r.Method != httpm.POST { http.Error(w, "unsupported method", http.StatusMethodNotAllowed) return } @@ -1964,478 +1264,108 @@ func (h *Handler) serveHandlePushMessage(w http.ResponseWriter, r *http.Request) return } - // TODO(bradfitz): do something with pushMessageBody - h.logf("localapi: got push message: %v", logger.AsJSON(pushMessageBody)) - - w.WriteHeader(http.StatusNoContent) -} - -func (h *Handler) serveUploadClientMetrics(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) - return - } - type clientMetricJSON struct { - Name string `json:"name"` - Type string `json:"type"` // one of "counter" or "gauge" - Value int `json:"value"` // amount to increment metric by - } - - var clientMetrics []clientMetricJSON - if err := json.NewDecoder(r.Body).Decode(&clientMetrics); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - metricsMu.Lock() - defer metricsMu.Unlock() - - for _, m := range clientMetrics { - if metric, ok := metrics[m.Name]; ok { - metric.Add(int64(m.Value)) - } else { - if clientmetric.HasPublished(m.Name) { - http.Error(w, "Already have a metric named "+m.Name, http.StatusBadRequest) - return - } - var metric *clientmetric.Metric - switch m.Type { - case "counter": - metric = clientmetric.NewCounter(m.Name) - case "gauge": - metric = clientmetric.NewGauge(m.Name) - default: - http.Error(w, "Unknown metric type "+m.Type, http.StatusBadRequest) - return - } - metrics[m.Name] = metric - metric.Add(int64(m.Value)) - } - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(struct{}{}) -} - -func (h *Handler) serveTKAStatus(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "lock status access denied", http.StatusForbidden) - return - } - if r.Method != httpm.GET { - http.Error(w, "use GET", http.StatusMethodNotAllowed) - return - } - - j, err := json.MarshalIndent(h.b.NetworkLockStatus(), "", "\t") - if err != nil { - http.Error(w, "JSON encoding error", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write(j) -} - -func (h *Handler) serveSetGUIVisible(w http.ResponseWriter, r *http.Request) { - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - type setGUIVisibleRequest struct { - IsVisible bool // whether the Tailscale client UI is now presented to the user - SessionID string // the last SessionID sent to the client in ipn.Notify.SessionID - } - var req setGUIVisibleRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - // TODO(bradfitz): use `req.IsVisible == true` to flush netmap - - w.WriteHeader(http.StatusOK) -} - -func (h *Handler) serveSetUseExitNodeEnabled(w http.ResponseWriter, r *http.Request) { - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - if !h.PermitWrite { - http.Error(w, "access denied", http.StatusForbidden) - return - } - - v, err := strconv.ParseBool(r.URL.Query().Get("enabled")) - if err != nil { - http.Error(w, "invalid 'enabled' parameter", http.StatusBadRequest) - return - } - prefs, err := h.b.SetUseExitNodeEnabled(v) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - e := json.NewEncoder(w) - e.SetIndent("", "\t") - e.Encode(prefs) -} - -func (h *Handler) serveTKASign(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "lock sign access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - type signRequest struct { - NodeKey key.NodePublic - RotationPublic []byte - } - var req signRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - if err := h.b.NetworkLockSign(req.NodeKey, req.RotationPublic); err != nil { - http.Error(w, "signing failed: "+err.Error(), http.StatusInternalServerError) - return - } - - w.WriteHeader(http.StatusOK) -} - -func (h *Handler) serveTKAInit(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "lock init access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - type initRequest struct { - Keys []tka.Key - DisablementValues [][]byte - SupportDisablement []byte - } - var req initRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - if err := h.b.NetworkLockInit(req.Keys, req.DisablementValues, req.SupportDisablement); err != nil { - http.Error(w, "initialization failed: "+err.Error(), http.StatusInternalServerError) - return - } - - j, err := json.MarshalIndent(h.b.NetworkLockStatus(), "", "\t") - if err != nil { - http.Error(w, "JSON encoding error", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write(j) -} - -func (h *Handler) serveTKAModify(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - type modifyRequest struct { - AddKeys []tka.Key - RemoveKeys []tka.Key - } - var req modifyRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - if err := h.b.NetworkLockModify(req.AddKeys, req.RemoveKeys); err != nil { - http.Error(w, "network-lock modify failed: "+err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(204) -} - -func (h *Handler) serveTKAWrapPreauthKey(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - type wrapRequest struct { - TSKey string - TKAKey string // key.NLPrivate.MarshalText - } - var req wrapRequest - if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 12*1024)).Decode(&req); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - var priv key.NLPrivate - if err := priv.UnmarshalText([]byte(req.TKAKey)); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - wrappedKey, err := h.b.NetworkLockWrapPreauthKey(req.TSKey, priv) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusOK) - w.Write([]byte(wrappedKey)) -} - -func (h *Handler) serveTKAVerifySigningDeeplink(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "signing deeplink verification access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - type verifyRequest struct { - URL string - } - var req verifyRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON for verifyRequest body", http.StatusBadRequest) - return - } - - res := h.b.NetworkLockVerifySigningDeeplink(req.URL) - j, err := json.MarshalIndent(res, "", "\t") - if err != nil { - http.Error(w, "JSON encoding error", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write(j) -} - -func (h *Handler) serveTKADisable(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - body := io.LimitReader(r.Body, 1024*1024) - secret, err := io.ReadAll(body) - if err != nil { - http.Error(w, "reading secret", http.StatusBadRequest) - return - } - - if err := h.b.NetworkLockDisable(secret); err != nil { - http.Error(w, "network-lock disable failed: "+err.Error(), http.StatusBadRequest) - return - } - w.WriteHeader(http.StatusOK) -} - -func (h *Handler) serveTKALocalDisable(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "network-lock modify access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - - // Require a JSON stanza for the body as an additional CSRF protection. - var req struct{} - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - if err := h.b.NetworkLockForceLocalDisable(); err != nil { - http.Error(w, "network-lock local disable failed: "+err.Error(), http.StatusBadRequest) - return - } - w.WriteHeader(http.StatusOK) -} - -func (h *Handler) serveTKALog(w http.ResponseWriter, r *http.Request) { - if r.Method != httpm.GET { - http.Error(w, "use GET", http.StatusMethodNotAllowed) - return - } - - limit := 50 - if limitStr := r.FormValue("limit"); limitStr != "" { - l, err := strconv.Atoi(limitStr) - if err != nil { - http.Error(w, "parsing 'limit' parameter: "+err.Error(), http.StatusBadRequest) - return - } - limit = int(l) - } - - updates, err := h.b.NetworkLockLog(limit) - if err != nil { - http.Error(w, "reading log failed: "+err.Error(), http.StatusInternalServerError) - return - } - - j, err := json.MarshalIndent(updates, "", "\t") - if err != nil { - http.Error(w, "JSON encoding error", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - w.Write(j) + // TODO(bradfitz): do something with pushMessageBody + h.logf("localapi: got push message: %v", logger.AsJSON(pushMessageBody)) + + w.WriteHeader(http.StatusNoContent) } -func (h *Handler) serveTKAAffectedSigs(w http.ResponseWriter, r *http.Request) { +func (h *Handler) serveUploadClientMetrics(w http.ResponseWriter, r *http.Request) { if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) return } - keyID, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 2048)) - if err != nil { - http.Error(w, "reading body", http.StatusBadRequest) - return + type clientMetricJSON struct { + Name string `json:"name"` + Type string `json:"type"` // one of "counter" or "gauge" + Value int `json:"value"` // amount to increment metric by } - sigs, err := h.b.NetworkLockAffectedSigs(keyID) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + var clientMetrics []clientMetricJSON + if err := json.NewDecoder(r.Body).Decode(&clientMetrics); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) return } - j, err := json.MarshalIndent(sigs, "", "\t") - if err != nil { - http.Error(w, "JSON encoding error", http.StatusInternalServerError) - return + metricsMu.Lock() + defer metricsMu.Unlock() + + for _, m := range clientMetrics { + if metric, ok := metrics[m.Name]; ok { + metric.Add(int64(m.Value)) + } else { + if clientmetric.HasPublished(m.Name) { + http.Error(w, "Already have a metric named "+m.Name, http.StatusBadRequest) + return + } + var metric *clientmetric.Metric + switch m.Type { + case "counter": + metric = clientmetric.NewCounter(m.Name) + case "gauge": + metric = clientmetric.NewGauge(m.Name) + default: + http.Error(w, "Unknown metric type "+m.Type, http.StatusBadRequest) + return + } + metrics[m.Name] = metric + metric.Add(int64(m.Value)) + } } + w.Header().Set("Content-Type", "application/json") - w.Write(j) + json.NewEncoder(w).Encode(struct{}{}) } -func (h *Handler) serveTKAGenerateRecoveryAUM(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "access denied", http.StatusForbidden) - return - } +func (h *Handler) serveSetGUIVisible(w http.ResponseWriter, r *http.Request) { if r.Method != httpm.POST { http.Error(w, "use POST", http.StatusMethodNotAllowed) return } - type verifyRequest struct { - Keys []tkatype.KeyID - ForkFrom string + type setGUIVisibleRequest struct { + IsVisible bool // whether the Tailscale client UI is now presented to the user + SessionID string // the last SessionID sent to the client in ipn.Notify.SessionID } - var req verifyRequest + var req setGUIVisibleRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - http.Error(w, "invalid JSON for verifyRequest body", http.StatusBadRequest) + http.Error(w, "invalid JSON body", http.StatusBadRequest) return } - var forkFrom tka.AUMHash - if req.ForkFrom != "" { - if err := forkFrom.UnmarshalText([]byte(req.ForkFrom)); err != nil { - http.Error(w, "decoding fork-from: "+err.Error(), http.StatusBadRequest) - return - } - } + // TODO(bradfitz): use `req.IsVisible == true` to flush netmap - res, err := h.b.NetworkLockGenerateRecoveryAUM(req.Keys, forkFrom) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/octet-stream") - w.Write(res.Serialize()) + w.WriteHeader(http.StatusOK) } -func (h *Handler) serveTKACosignRecoveryAUM(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "access denied", http.StatusForbidden) +func (h *Handler) serveSetUseExitNodeEnabled(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasUseExitNode { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) return } if r.Method != httpm.POST { http.Error(w, "use POST", http.StatusMethodNotAllowed) return } - - body := io.LimitReader(r.Body, 1024*1024) - aumBytes, err := io.ReadAll(body) - if err != nil { - http.Error(w, "reading AUM", http.StatusBadRequest) - return - } - var aum tka.AUM - if err := aum.Unserialize(aumBytes); err != nil { - http.Error(w, "decoding AUM", http.StatusBadRequest) - return - } - - res, err := h.b.NetworkLockCosignRecoveryAUM(&aum) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/octet-stream") - w.Write(res.Serialize()) -} - -func (h *Handler) serveTKASubmitRecoveryAUM(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { http.Error(w, "access denied", http.StatusForbidden) return } - if r.Method != httpm.POST { - http.Error(w, "use POST", http.StatusMethodNotAllowed) - return - } - body := io.LimitReader(r.Body, 1024*1024) - aumBytes, err := io.ReadAll(body) + v, err := strconv.ParseBool(r.URL.Query().Get("enabled")) if err != nil { - http.Error(w, "reading AUM", http.StatusBadRequest) - return - } - var aum tka.AUM - if err := aum.Unserialize(aumBytes); err != nil { - http.Error(w, "decoding AUM", http.StatusBadRequest) + http.Error(w, "invalid 'enabled' parameter", http.StatusBadRequest) return } - - if err := h.b.NetworkLockSubmitRecoveryAUM(&aum); err != nil { + prefs, err := h.b.SetUseExitNodeEnabled(h.Actor, v) + if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(prefs) } // serveProfiles serves profile switching-related endpoints. Supported methods @@ -2494,8 +1424,8 @@ func (h *Handler) serveProfiles(w http.ResponseWriter, r *http.Request) { switch r.Method { case httpm.GET: profiles := h.b.ListProfiles() - profileIndex := slices.IndexFunc(profiles, func(p ipn.LoginProfile) bool { - return p.ID == profileID + profileIndex := slices.IndexFunc(profiles, func(p ipn.LoginProfileView) bool { + return p.ID() == profileID }) if profileIndex == -1 { http.Error(w, "Profile not found", http.StatusNotFound) @@ -2562,7 +1492,7 @@ func (h *Handler) serveQueryFeature(w http.ResponseWriter, r *http.Request) { } req, err := http.NewRequestWithContext(r.Context(), - "POST", "https://unused/machine/feature/query", bytes.NewReader(b)) + httpm.POST, "https://unused/machine/feature/query", bytes.NewReader(b)) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -2593,62 +1523,6 @@ func defBool(a string, def bool) bool { return v } -func (h *Handler) serveDebugCapture(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "POST required", http.StatusMethodNotAllowed) - return - } - - w.WriteHeader(http.StatusOK) - w.(http.Flusher).Flush() - h.b.StreamDebugCapture(r.Context(), w) -} - -func (h *Handler) serveDebugLog(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "debug-log access denied", http.StatusForbidden) - return - } - if r.Method != httpm.POST { - http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) - return - } - defer h.b.TryFlushLogs() // kick off upload after we're done logging - - type logRequestJSON struct { - Lines []string - Prefix string - } - - var logRequest logRequestJSON - if err := json.NewDecoder(r.Body).Decode(&logRequest); err != nil { - http.Error(w, "invalid JSON body", http.StatusBadRequest) - return - } - - prefix := logRequest.Prefix - if prefix == "" { - prefix = "debug-log" - } - logf := logger.WithPrefix(h.logf, prefix+": ") - - // We can write logs too fast for logtail to handle, even when - // opting-out of rate limits. Limit ourselves to at most one message - // per 20ms and a burst of 60 log lines, which should be fast enough to - // not block for too long but slow enough that we can upload all lines. - logf = logger.SlowLoggerWithClock(r.Context(), logf, 20*time.Millisecond, 60, h.clock.Now) - - for _, line := range logRequest.Lines { - logf("%s", line) - } - - w.WriteHeader(http.StatusNoContent) -} - // serveUpdateCheck returns the ClientVersion from Status, which contains // information on whether an update is available, and if so, what version, // *if* we support auto-updates on this platform. If we don't, this endpoint @@ -2656,17 +1530,10 @@ func (h *Handler) serveDebugLog(w http.ResponseWriter, r *http.Request) { // Effectively, it tells us whether serveUpdateInstall will be able to install // an update for us. func (h *Handler) serveUpdateCheck(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { + if r.Method != httpm.GET { http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) return } - - if !clientupdate.CanAutoUpdate() { - // if we don't support auto-update, just say that we're up to date - json.NewEncoder(w).Encode(tailcfg.ClientVersion{RunningLatest: true}) - return - } - cv := h.b.StatusWithoutPeers().ClientVersion // ipnstate.Status documentation notes that ClientVersion may be nil on some // platforms where this information is unavailable. In that case, return a @@ -2679,40 +1546,13 @@ func (h *Handler) serveUpdateCheck(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(cv) } -// serveUpdateInstall sends a request to the LocalBackend to start a Tailscale -// self-update. A successful response does not indicate whether the update -// succeeded, only that the request was accepted. Clients should use -// serveUpdateProgress after pinging this endpoint to check how the update is -// going. -func (h *Handler) serveUpdateInstall(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) - return - } - - w.WriteHeader(http.StatusAccepted) - - go h.b.DoSelfUpdate() -} - -// serveUpdateProgress returns the status of an in-progress Tailscale self-update. -// This is provided as a slice of ipnstate.UpdateProgress structs with various -// log messages in order from oldest to newest. If an update is not in progress, -// the returned slice will be empty. -func (h *Handler) serveUpdateProgress(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { - http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) - return - } - - ups := h.b.GetSelfUpdateProgress() - - json.NewEncoder(w).Encode(ups) -} - // serveDNSOSConfig serves the current system DNS configuration as a JSON object, if // supported by the OS. func (h *Handler) serveDNSOSConfig(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDNS { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } if r.Method != httpm.GET { http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) return @@ -2756,7 +1596,11 @@ func (h *Handler) serveDNSOSConfig(w http.ResponseWriter, r *http.Request) { // // The response if successful is a DNSQueryResponse JSON object. func (h *Handler) serveDNSQuery(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { + if !buildfeatures.HasDNS { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + if r.Method != httpm.GET { http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) return } @@ -2770,7 +1614,7 @@ func (h *Handler) serveDNSQuery(w http.ResponseWriter, r *http.Request) { queryType := q.Get("type") qt := dnsmessage.TypeA if queryType != "" { - t, err := dnstype.DNSMessageTypeForString(queryType) + t, err := dnsMessageTypeForString(queryType) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -2791,141 +1635,114 @@ func (h *Handler) serveDNSQuery(w http.ResponseWriter, r *http.Request) { }) } -// serveDriveServerAddr handles updates of the Taildrive file server address. -func (h *Handler) serveDriveServerAddr(w http.ResponseWriter, r *http.Request) { - if r.Method != "PUT" { - http.Error(w, "only PUT allowed", http.StatusMethodNotAllowed) +// dnsMessageTypeForString returns the dnsmessage.Type for the given string. +// For example, DNSMessageTypeForString("A") returns dnsmessage.TypeA. +func dnsMessageTypeForString(s string) (t dnsmessage.Type, err error) { + s = strings.TrimSpace(strings.ToUpper(s)) + switch s { + case "AAAA": + return dnsmessage.TypeAAAA, nil + case "ALL": + return dnsmessage.TypeALL, nil + case "A": + return dnsmessage.TypeA, nil + case "CNAME": + return dnsmessage.TypeCNAME, nil + case "HINFO": + return dnsmessage.TypeHINFO, nil + case "MINFO": + return dnsmessage.TypeMINFO, nil + case "MX": + return dnsmessage.TypeMX, nil + case "NS": + return dnsmessage.TypeNS, nil + case "OPT": + return dnsmessage.TypeOPT, nil + case "PTR": + return dnsmessage.TypePTR, nil + case "SOA": + return dnsmessage.TypeSOA, nil + case "SRV": + return dnsmessage.TypeSRV, nil + case "TXT": + return dnsmessage.TypeTXT, nil + case "WKS": + return dnsmessage.TypeWKS, nil + } + return 0, errors.New("unknown DNS message type: " + s) +} + +// serveSuggestExitNode serves a POST endpoint for returning a suggested exit node. +func (h *Handler) serveSuggestExitNode(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasUseExitNode { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) return } - - b, err := io.ReadAll(r.Body) + if r.Method != httpm.GET { + http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) + return + } + res, err := h.b.SuggestExitNode() if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + WriteErrorJSON(w, err) return } - - h.b.DriveSetServerAddr(string(b)) - w.WriteHeader(http.StatusCreated) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) } -// serveShares handles the management of Taildrive shares. -// -// PUT - adds or updates an existing share -// DELETE - removes a share -// GET - gets a list of all shares, sorted by name -// POST - renames an existing share -func (h *Handler) serveShares(w http.ResponseWriter, r *http.Request) { - if !h.b.DriveSharingEnabled() { - http.Error(w, `taildrive sharing not enabled, please add the attribute "drive:share" to this node in your ACLs' "nodeAttrs" section`, http.StatusForbidden) +// Shutdown is an eventbus value published when tailscaled shutdown +// is requested via LocalAPI. Its only consumer is [ipnserver.Server]. +type Shutdown struct{} + +// serveShutdown shuts down tailscaled. It requires write access +// and the [pkey.AllowTailscaledRestart] policy to be enabled. +// See tailscale/corp#32674. +func (h *Handler) serveShutdown(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.POST { + http.Error(w, "only POST allowed", http.StatusMethodNotAllowed) return } - switch r.Method { - case "PUT": - var share drive.Share - err := json.NewDecoder(r.Body).Decode(&share) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - share.Path = path.Clean(share.Path) - fi, err := os.Stat(share.Path) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - if !fi.IsDir() { - http.Error(w, "not a directory", http.StatusBadRequest) - return - } - if drive.AllowShareAs() { - // share as the connected user - username, err := h.Actor.Username() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - share.As = username - } - err = h.b.DriveSetShare(&share) - if err != nil { - if errors.Is(err, drive.ErrInvalidShareName) { - http.Error(w, "invalid share name", http.StatusBadRequest) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusCreated) - case "DELETE": - b, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - err = h.b.DriveRemoveShare(string(b)) - if err != nil { - if os.IsNotExist(err) { - http.Error(w, "share not found", http.StatusNotFound) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNoContent) - case "POST": - var names [2]string - err := json.NewDecoder(r.Body).Decode(&names) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - err = h.b.DriveRenameShare(names[0], names[1]) - if err != nil { - if os.IsNotExist(err) { - http.Error(w, "share not found", http.StatusNotFound) - return - } - if os.IsExist(err) { - http.Error(w, "share name already used", http.StatusBadRequest) - return - } - if errors.Is(err, drive.ErrInvalidShareName) { - http.Error(w, "invalid share name", http.StatusBadRequest) - return - } - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNoContent) - case "GET": - shares := h.b.DriveGetShares() - err := json.NewEncoder(w).Encode(shares) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - default: - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + + if !h.PermitWrite { + http.Error(w, "shutdown access denied", http.StatusForbidden) + return } -} -var ( - metricInvalidRequests = clientmetric.NewCounter("localapi_invalid_requests") + polc := h.b.Sys().PolicyClientOrDefault() + if permitShutdown, _ := polc.GetBoolean(pkey.AllowTailscaledRestart, false); !permitShutdown { + http.Error(w, "shutdown access denied by policy", http.StatusForbidden) + return + } - // User-visible LocalAPI endpoints. - metricFilePutCalls = clientmetric.NewCounter("localapi_file_put") -) + ec := h.eventBus.Client("localapi.Handler") + defer ec.Close() -// serveSuggestExitNode serves a POST endpoint for returning a suggested exit node. -func (h *Handler) serveSuggestExitNode(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + eventbus.Publish[Shutdown](ec).Publish(Shutdown{}) +} + +func (h *Handler) serveGetAppcRouteInfo(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasAppConnectors { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + if r.Method != httpm.GET { http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) return } - res, err := h.b.SuggestExitNode() + res, err := h.b.ReadRouteInfo() if err != nil { - writeErrorJSON(w, err) - return + if errors.Is(err, ipn.ErrStateNotExist) { + res = &appctype.RouteInfo{} + } else { + WriteErrorJSON(w, err) + return + } } w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(res) diff --git a/ipn/localapi/localapi_drive.go b/ipn/localapi/localapi_drive.go new file mode 100644 index 000000000..eb765ec2e --- /dev/null +++ b/ipn/localapi/localapi_drive.go @@ -0,0 +1,141 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_drive + +package localapi + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "os" + "path" + + "tailscale.com/drive" + "tailscale.com/util/httpm" +) + +func init() { + Register("drive/fileserver-address", (*Handler).serveDriveServerAddr) + Register("drive/shares", (*Handler).serveShares) +} + +// serveDriveServerAddr handles updates of the Taildrive file server address. +func (h *Handler) serveDriveServerAddr(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.PUT { + http.Error(w, "only PUT allowed", http.StatusMethodNotAllowed) + return + } + + b, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + h.b.DriveSetServerAddr(string(b)) + w.WriteHeader(http.StatusCreated) +} + +// serveShares handles the management of Taildrive shares. +// +// PUT - adds or updates an existing share +// DELETE - removes a share +// GET - gets a list of all shares, sorted by name +// POST - renames an existing share +func (h *Handler) serveShares(w http.ResponseWriter, r *http.Request) { + if !h.b.DriveSharingEnabled() { + http.Error(w, `taildrive sharing not enabled, please add the attribute "drive:share" to this node in your ACLs' "nodeAttrs" section`, http.StatusForbidden) + return + } + switch r.Method { + case httpm.PUT: + var share drive.Share + err := json.NewDecoder(r.Body).Decode(&share) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + share.Path = path.Clean(share.Path) + fi, err := os.Stat(share.Path) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if !fi.IsDir() { + http.Error(w, "not a directory", http.StatusBadRequest) + return + } + if drive.AllowShareAs() { + // share as the connected user + username, err := h.Actor.Username() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + share.As = username + } + err = h.b.DriveSetShare(&share) + if err != nil { + if errors.Is(err, drive.ErrInvalidShareName) { + http.Error(w, "invalid share name", http.StatusBadRequest) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusCreated) + case httpm.DELETE: + b, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = h.b.DriveRemoveShare(string(b)) + if err != nil { + if os.IsNotExist(err) { + http.Error(w, "share not found", http.StatusNotFound) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNoContent) + case httpm.POST: + var names [2]string + err := json.NewDecoder(r.Body).Decode(&names) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + err = h.b.DriveRenameShare(names[0], names[1]) + if err != nil { + if os.IsNotExist(err) { + http.Error(w, "share not found", http.StatusNotFound) + return + } + if os.IsExist(err) { + http.Error(w, "share name already used", http.StatusBadRequest) + return + } + if errors.Is(err, drive.ErrInvalidShareName) { + http.Error(w, "invalid share name", http.StatusBadRequest) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNoContent) + case httpm.GET: + shares := h.b.DriveGetShares() + err := json.NewEncoder(w).Encode(shares) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + default: + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + } +} diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index fa54a1e75..6bb9b5182 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -35,27 +35,24 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/slicesx" "tailscale.com/wgengine" ) -var _ ipnauth.Actor = (*testActor)(nil) - -type testActor struct { - uid ipn.WindowsUserID - name string - isLocalSystem bool - isLocalAdmin bool +func handlerForTest(t testing.TB, h *Handler) *Handler { + if h.Actor == nil { + h.Actor = &ipnauth.TestActor{} + } + if h.b == nil { + h.b = &ipnlocal.LocalBackend{} + } + if h.logf == nil { + h.logf = logger.TestLogger(t) + } + return h } -func (u *testActor) UserID() ipn.WindowsUserID { return u.uid } - -func (u *testActor) Username() (string, error) { return u.name, nil } - -func (u *testActor) IsLocalSystem() bool { return u.isLocalSystem } - -func (u *testActor) IsLocalAdmin(operatorUID string) bool { return u.isLocalAdmin } - func TestValidHost(t *testing.T) { tests := []struct { host string @@ -73,7 +70,7 @@ func TestValidHost(t *testing.T) { for _, test := range tests { t.Run(test.host, func(t *testing.T) { - h := &Handler{} + h := handlerForTest(t, &Handler{}) if got := h.validHost(test.host); got != test.valid { t.Errorf("validHost(%q)=%v, want %v", test.host, got, test.valid) } @@ -84,10 +81,9 @@ func TestValidHost(t *testing.T) { func TestSetPushDeviceToken(t *testing.T) { tstest.Replace(t, &validLocalHostForTesting, true) - h := &Handler{ + h := handlerForTest(t, &Handler{ PermitWrite: true, - b: &ipnlocal.LocalBackend{}, - } + }) s := httptest.NewServer(h) defer s.Close() c := s.Client() @@ -141,9 +137,9 @@ func (b whoIsBackend) PeerCaps(ip netip.Addr) tailcfg.PeerCapMap { // // And https://github.com/tailscale/tailscale/issues/12465 func TestWhoIsArgTypes(t *testing.T) { - h := &Handler{ + h := handlerForTest(t, &Handler{ PermitRead: true, - } + }) match := func() (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { return (&tailcfg.Node{ @@ -175,7 +171,6 @@ func TestWhoIsArgTypes(t *testing.T) { t.Fatalf("backend called with %v; want %v", k, keyStr) } return match() - }, peerCaps: map[netip.Addr]tailcfg.PeerCapMap{ netip.MustParseAddr("100.101.102.103"): map[tailcfg.PeerCapability][]tailcfg.RawMessage{ @@ -207,7 +202,10 @@ func TestWhoIsArgTypes(t *testing.T) { func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { newHandler := func(connIsLocalAdmin bool) *Handler { - return &Handler{Actor: &testActor{isLocalAdmin: connIsLocalAdmin}, b: newTestLocalBackend(t)} + return handlerForTest(t, &Handler{ + Actor: &ipnauth.TestActor{LocalAdmin: connIsLocalAdmin}, + b: newTestLocalBackend(t), + }) } tests := []struct { name string @@ -254,7 +252,7 @@ func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { } for _, tt := range tests { - for _, goos := range []string{"linux", "windows", "darwin"} { + for _, goos := range []string{"linux", "windows", "darwin", "illumos", "solaris"} { t.Run(goos+"-"+tt.name, func(t *testing.T) { err := authorizeServeConfigForGOOSAndUserContext(goos, tt.configIn, tt.h) gotErr := err != nil @@ -280,13 +278,17 @@ func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { }) } +// TestServeWatchIPNBus used to test that various WatchIPNBus mask flags +// changed the permissions required to access the endpoint. +// However, since the removal of the NotifyNoPrivateKeys flag requirement +// for read-only users, this test now only verifies that the endpoint +// behaves correctly based on the PermitRead and PermitWrite settings. func TestServeWatchIPNBus(t *testing.T) { tstest.Replace(t, &validLocalHostForTesting, true) tests := []struct { desc string permitRead, permitWrite bool - mask ipn.NotifyWatchOpt // extra bits in addition to ipn.NotifyInitialState wantStatus int }{ { @@ -296,20 +298,13 @@ func TestServeWatchIPNBus(t *testing.T) { wantStatus: http.StatusForbidden, }, { - desc: "read-initial-state", - permitRead: true, - permitWrite: false, - wantStatus: http.StatusForbidden, - }, - { - desc: "read-initial-state-no-private-keys", + desc: "read-only", permitRead: true, permitWrite: false, - mask: ipn.NotifyNoPrivateKeys, wantStatus: http.StatusOK, }, { - desc: "read-initial-state-with-private-keys", + desc: "read-and-write", permitRead: true, permitWrite: true, wantStatus: http.StatusOK, @@ -318,17 +313,17 @@ func TestServeWatchIPNBus(t *testing.T) { for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - h := &Handler{ + h := handlerForTest(t, &Handler{ PermitRead: tt.permitRead, PermitWrite: tt.permitWrite, b: newTestLocalBackend(t), - } + }) s := httptest.NewServer(h) defer s.Close() c := s.Client() ctx, cancel := context.WithCancel(context.Background()) - req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/localapi/v0/watch-ipn-bus?mask=%d", s.URL, ipn.NotifyInitialState|tt.mask), nil) + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/localapi/v0/watch-ipn-bus?mask=%d", s.URL, ipn.NotifyInitialState), nil) if err != nil { t.Fatal(err) } @@ -353,10 +348,10 @@ func TestServeWatchIPNBus(t *testing.T) { func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend { var logf logger.Logf = logger.Discard - sys := new(tsd.System) + sys := tsd.NewSystemWithBus(eventbustest.NewBus(t)) store := new(mem.Store) sys.Set(store) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) } @@ -366,6 +361,7 @@ func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(lb.Shutdown) return lb } diff --git a/ipn/localapi/pprof.go b/ipn/localapi/pprof.go index 8c9429b31..9476f721f 100644 --- a/ipn/localapi/pprof.go +++ b/ipn/localapi/pprof.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ios && !android && !js +//go:build !ios && !android && !js && !ts_omit_debug // We don't include it on mobile where we're more memory constrained and // there's no CLI to get at the results anyway. diff --git a/ipn/localapi/serve.go b/ipn/localapi/serve.go new file mode 100644 index 000000000..56c8b486c --- /dev/null +++ b/ipn/localapi/serve.go @@ -0,0 +1,108 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_serve + +package localapi + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "net/http" + "runtime" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/util/httpm" + "tailscale.com/version" +) + +func init() { + Register("serve-config", (*Handler).serveServeConfig) +} + +func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case httpm.GET: + if !h.PermitRead { + http.Error(w, "serve config denied", http.StatusForbidden) + return + } + config := h.b.ServeConfig() + bts, err := json.Marshal(config) + if err != nil { + http.Error(w, "error encoding config: "+err.Error(), http.StatusInternalServerError) + return + } + sum := sha256.Sum256(bts) + etag := hex.EncodeToString(sum[:]) + w.Header().Set("Etag", etag) + w.Header().Set("Content-Type", "application/json") + w.Write(bts) + case httpm.POST: + if !h.PermitWrite { + http.Error(w, "serve config denied", http.StatusForbidden) + return + } + configIn := new(ipn.ServeConfig) + if err := json.NewDecoder(r.Body).Decode(configIn); err != nil { + WriteErrorJSON(w, fmt.Errorf("decoding config: %w", err)) + return + } + + // require a local admin when setting a path handler + // TODO: roll-up this Windows-specific check into either PermitWrite + // or a global admin escalation check. + if err := authorizeServeConfigForGOOSAndUserContext(runtime.GOOS, configIn, h); err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + + etag := r.Header.Get("If-Match") + if err := h.b.SetServeConfig(configIn, etag); err != nil { + if errors.Is(err, ipnlocal.ErrETagMismatch) { + http.Error(w, err.Error(), http.StatusPreconditionFailed) + return + } + WriteErrorJSON(w, fmt.Errorf("updating config: %w", err)) + return + } + w.WriteHeader(http.StatusOK) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func authorizeServeConfigForGOOSAndUserContext(goos string, configIn *ipn.ServeConfig, h *Handler) error { + switch goos { + case "windows", "linux", "darwin", "illumos", "solaris": + default: + return nil + } + // Only check for local admin on tailscaled-on-mac (based on "sudo" + // permissions). On sandboxed variants (MacSys and AppStore), tailscaled + // cannot serve files outside of the sandbox and this check is not + // relevant. + if goos == "darwin" && version.IsSandboxedMacOS() { + return nil + } + if !configIn.HasPathHandler() { + return nil + } + if h.Actor.IsLocalAdmin(h.b.OperatorUserID()) { + return nil + } + switch goos { + case "windows": + return errors.New("must be a Windows local admin to serve a path") + case "linux", "darwin", "illumos", "solaris": + return errors.New("must be root, or be an operator and able to run 'sudo tailscale' to serve a path") + default: + // We filter goos at the start of the func, this default case + // should never happen. + panic("unreachable") + } +} diff --git a/ipn/localapi/syspolicy_api.go b/ipn/localapi/syspolicy_api.go new file mode 100644 index 000000000..edb82e042 --- /dev/null +++ b/ipn/localapi/syspolicy_api.go @@ -0,0 +1,68 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_syspolicy + +package localapi + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "tailscale.com/util/httpm" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" +) + +func init() { + Register("policy/", (*Handler).servePolicy) +} + +func (h *Handler) servePolicy(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "policy access denied", http.StatusForbidden) + return + } + + suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/policy/") + if !ok { + http.Error(w, "misconfigured", http.StatusInternalServerError) + return + } + + var scope setting.PolicyScope + if suffix == "" { + scope = setting.DefaultScope() + } else if err := scope.UnmarshalText([]byte(suffix)); err != nil { + http.Error(w, fmt.Sprintf("%q is not a valid scope", suffix), http.StatusBadRequest) + return + } + + policy, err := rsop.PolicyFor(scope) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + var effectivePolicy *setting.Snapshot + switch r.Method { + case httpm.GET: + effectivePolicy = policy.Get() + case httpm.POST: + effectivePolicy, err = policy.Reload() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + default: + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(effectivePolicy) +} diff --git a/ipn/localapi/tailnetlock.go b/ipn/localapi/tailnetlock.go new file mode 100644 index 000000000..e5f999bb8 --- /dev/null +++ b/ipn/localapi/tailnetlock.go @@ -0,0 +1,413 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tailnetlock + +package localapi + +import ( + "encoding/json" + "io" + "net/http" + "strconv" + + "tailscale.com/tka" + "tailscale.com/types/key" + "tailscale.com/types/tkatype" + "tailscale.com/util/httpm" +) + +func init() { + Register("tka/affected-sigs", (*Handler).serveTKAAffectedSigs) + Register("tka/cosign-recovery-aum", (*Handler).serveTKACosignRecoveryAUM) + Register("tka/disable", (*Handler).serveTKADisable) + Register("tka/force-local-disable", (*Handler).serveTKALocalDisable) + Register("tka/generate-recovery-aum", (*Handler).serveTKAGenerateRecoveryAUM) + Register("tka/init", (*Handler).serveTKAInit) + Register("tka/log", (*Handler).serveTKALog) + Register("tka/modify", (*Handler).serveTKAModify) + Register("tka/sign", (*Handler).serveTKASign) + Register("tka/status", (*Handler).serveTKAStatus) + Register("tka/submit-recovery-aum", (*Handler).serveTKASubmitRecoveryAUM) + Register("tka/verify-deeplink", (*Handler).serveTKAVerifySigningDeeplink) + Register("tka/wrap-preauth-key", (*Handler).serveTKAWrapPreauthKey) +} + +func (h *Handler) serveTKAStatus(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "lock status access denied", http.StatusForbidden) + return + } + if r.Method != httpm.GET { + http.Error(w, "use GET", http.StatusMethodNotAllowed) + return + } + + j, err := json.MarshalIndent(h.b.NetworkLockStatus(), "", "\t") + if err != nil { + http.Error(w, "JSON encoding error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(j) +} + +func (h *Handler) serveTKASign(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "lock sign access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type signRequest struct { + NodeKey key.NodePublic + RotationPublic []byte + } + var req signRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + if err := h.b.NetworkLockSign(req.NodeKey, req.RotationPublic); err != nil { + http.Error(w, "signing failed: "+err.Error(), http.StatusInternalServerError) + return + } + + w.WriteHeader(http.StatusOK) +} + +func (h *Handler) serveTKAInit(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "lock init access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type initRequest struct { + Keys []tka.Key + DisablementValues [][]byte + SupportDisablement []byte + } + var req initRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + if !h.b.NetworkLockAllowed() { + http.Error(w, "Tailnet Lock is not supported on your pricing plan", http.StatusForbidden) + return + } + + if err := h.b.NetworkLockInit(req.Keys, req.DisablementValues, req.SupportDisablement); err != nil { + http.Error(w, "initialization failed: "+err.Error(), http.StatusInternalServerError) + return + } + + j, err := json.MarshalIndent(h.b.NetworkLockStatus(), "", "\t") + if err != nil { + http.Error(w, "JSON encoding error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(j) +} + +func (h *Handler) serveTKAModify(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "network-lock modify access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type modifyRequest struct { + AddKeys []tka.Key + RemoveKeys []tka.Key + } + var req modifyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + if err := h.b.NetworkLockModify(req.AddKeys, req.RemoveKeys); err != nil { + http.Error(w, "network-lock modify failed: "+err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(204) +} + +func (h *Handler) serveTKAWrapPreauthKey(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "network-lock modify access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type wrapRequest struct { + TSKey string + TKAKey string // key.NLPrivate.MarshalText + } + var req wrapRequest + if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 12*1024)).Decode(&req); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + var priv key.NLPrivate + if err := priv.UnmarshalText([]byte(req.TKAKey)); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + wrappedKey, err := h.b.NetworkLockWrapPreauthKey(req.TSKey, priv) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(wrappedKey)) +} + +func (h *Handler) serveTKAVerifySigningDeeplink(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "signing deeplink verification access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type verifyRequest struct { + URL string + } + var req verifyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON for verifyRequest body", http.StatusBadRequest) + return + } + + res := h.b.NetworkLockVerifySigningDeeplink(req.URL) + j, err := json.MarshalIndent(res, "", "\t") + if err != nil { + http.Error(w, "JSON encoding error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(j) +} + +func (h *Handler) serveTKADisable(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "network-lock modify access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + body := io.LimitReader(r.Body, 1024*1024) + secret, err := io.ReadAll(body) + if err != nil { + http.Error(w, "reading secret", http.StatusBadRequest) + return + } + + if err := h.b.NetworkLockDisable(secret); err != nil { + http.Error(w, "network-lock disable failed: "+err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) +} + +func (h *Handler) serveTKALocalDisable(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "network-lock modify access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + // Require a JSON stanza for the body as an additional CSRF protection. + var req struct{} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) + return + } + + if err := h.b.NetworkLockForceLocalDisable(); err != nil { + http.Error(w, "network-lock local disable failed: "+err.Error(), http.StatusBadRequest) + return + } + w.WriteHeader(http.StatusOK) +} + +func (h *Handler) serveTKALog(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.GET { + http.Error(w, "use GET", http.StatusMethodNotAllowed) + return + } + + limit := 50 + if limitStr := r.FormValue("limit"); limitStr != "" { + lm, err := strconv.Atoi(limitStr) + if err != nil { + http.Error(w, "parsing 'limit' parameter: "+err.Error(), http.StatusBadRequest) + return + } + limit = int(lm) + } + + updates, err := h.b.NetworkLockLog(limit) + if err != nil { + http.Error(w, "reading log failed: "+err.Error(), http.StatusInternalServerError) + return + } + + j, err := json.MarshalIndent(updates, "", "\t") + if err != nil { + http.Error(w, "JSON encoding error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(j) +} + +func (h *Handler) serveTKAAffectedSigs(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + keyID, err := io.ReadAll(http.MaxBytesReader(w, r.Body, 2048)) + if err != nil { + http.Error(w, "reading body", http.StatusBadRequest) + return + } + + sigs, err := h.b.NetworkLockAffectedSigs(keyID) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + j, err := json.MarshalIndent(sigs, "", "\t") + if err != nil { + http.Error(w, "JSON encoding error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(j) +} + +func (h *Handler) serveTKAGenerateRecoveryAUM(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + type verifyRequest struct { + Keys []tkatype.KeyID + ForkFrom string + } + var req verifyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "invalid JSON for verifyRequest body", http.StatusBadRequest) + return + } + + var forkFrom tka.AUMHash + if req.ForkFrom != "" { + if err := forkFrom.UnmarshalText([]byte(req.ForkFrom)); err != nil { + http.Error(w, "decoding fork-from: "+err.Error(), http.StatusBadRequest) + return + } + } + + res, err := h.b.NetworkLockGenerateRecoveryAUM(req.Keys, forkFrom) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(res.Serialize()) +} + +func (h *Handler) serveTKACosignRecoveryAUM(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + body := io.LimitReader(r.Body, 1024*1024) + aumBytes, err := io.ReadAll(body) + if err != nil { + http.Error(w, "reading AUM", http.StatusBadRequest) + return + } + var aum tka.AUM + if err := aum.Unserialize(aumBytes); err != nil { + http.Error(w, "decoding AUM", http.StatusBadRequest) + return + } + + res, err := h.b.NetworkLockCosignRecoveryAUM(&aum) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(res.Serialize()) +} + +func (h *Handler) serveTKASubmitRecoveryAUM(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + + body := io.LimitReader(r.Body, 1024*1024) + aumBytes, err := io.ReadAll(body) + if err != nil { + http.Error(w, "reading AUM", http.StatusBadRequest) + return + } + var aum tka.AUM + if err := aum.Unserialize(aumBytes); err != nil { + http.Error(w, "decoding AUM", http.StatusBadRequest) + return + } + + if err := h.b.NetworkLockSubmitRecoveryAUM(&aum); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) +} diff --git a/ipn/prefs.go b/ipn/prefs.go index 5d61f0119..7f8216c60 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -5,6 +5,7 @@ package ipn import ( "bytes" + "cmp" "encoding/json" "errors" "fmt" @@ -19,6 +20,7 @@ import ( "tailscale.com/atomicfile" "tailscale.com/drive" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn/ipnstate" "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" @@ -28,7 +30,9 @@ import ( "tailscale.com/types/preftype" "tailscale.com/types/views" "tailscale.com/util/dnsname" - "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/version" ) // DefaultControlURL is the URL base of the control plane @@ -93,6 +97,25 @@ type Prefs struct { ExitNodeID tailcfg.StableNodeID ExitNodeIP netip.Addr + // AutoExitNode is an optional expression that specifies whether and how + // tailscaled should pick an exit node automatically. + // + // If specified, tailscaled will use an exit node based on the expression, + // and will re-evaluate the selection periodically as network conditions, + // available exit nodes, or policy settings change. A blackhole route will + // be installed to prevent traffic from escaping to the local network until + // an exit node is selected. It takes precedence over ExitNodeID and ExitNodeIP. + // + // If empty, tailscaled will not automatically select an exit node. + // + // If the specified expression is invalid or unsupported by the client, + // it falls back to the behavior of [AnyExitNode]. + // + // As of 2025-07-02, the only supported value is [AnyExitNode]. + // It's a string rather than a boolean to allow future extensibility + // (e.g., AutoExitNode = "mullvad" or AutoExitNode = "geo:us"). + AutoExitNode ExitNodeExpression `json:",omitempty"` + // InternalExitNodePrior is the most recently used ExitNodeID in string form. It is set by // the backend on transition from exit node on to off and used by the // backend. @@ -138,11 +161,10 @@ type Prefs struct { // connections. This overrides tailcfg.Hostinfo's ShieldsUp. ShieldsUp bool - // AdvertiseTags specifies groups that this node wants to join, for - // purposes of ACL enforcement. These can be referenced from the ACL - // security policy. Note that advertising a tag doesn't guarantee that - // the control server will allow you to take on the rights for that - // tag. + // AdvertiseTags specifies tags that should be applied to this node, for + // purposes of ACL enforcement. These can be referenced from the ACL policy + // document. Note that advertising a tag on the client doesn't guarantee + // that the control server will allow the node to adopt that tag. AdvertiseTags []string // Hostname is the hostname to use for identifying the node. If @@ -179,6 +201,18 @@ type Prefs struct { // node. AdvertiseRoutes []netip.Prefix + // AdvertiseServices specifies the list of services that this + // node can serve as a destination for. Note that an advertised + // service must still go through the approval process from the + // control server. + AdvertiseServices []string + + // Sync is whether this node should sync its configuration from + // the control plane. If unset, this defaults to true. + // This exists primarily for testing, to verify that netmap caching + // and offline operation work correctly. + Sync opt.Bool + // NoSNAT specifies whether to source NAT traffic going to // destinations in AdvertiseRoutes. The default is to apply source // NAT, which makes the traffic appear to come from the router @@ -228,10 +262,17 @@ type Prefs struct { // PostureChecking enables the collection of information used for device // posture checks. + // + // Note: this should be named ReportPosture, but it was shipped as + // PostureChecking in some early releases and this JSON field is written to + // disk, so we just keep its old name. (akin to CorpDNS which is an internal + // pref name that doesn't match the public interface) PostureChecking bool // NetfilterKind specifies what netfilter implementation to use. // + // It can be "iptables", "nftables", or "" to auto-detect. + // // Linux-only. NetfilterKind string @@ -239,9 +280,17 @@ type Prefs struct { // by name. DriveShares []*drive.Share + // RelayServerPort is the UDP port number for the relay server to bind to, + // on all interfaces. A non-nil zero value signifies a random unused port + // should be used. A nil value signifies relay server functionality + // should be disabled. This field is currently experimental, and therefore + // no guarantees are made about its current naming and functionality when + // non-nil/enabled. + RelayServerPort *int `json:",omitempty"` + // AllowSingleHosts was a legacy field that was always true // for the past 4.5 years. It controlled whether Tailscale - // peers got /32 or /127 routes for each other. + // peers got /32 or /128 routes for each other. // As of 2024-05-17 we're starting to ignore it, but to let // people still downgrade Tailscale versions and not break // all peer-to-peer networking we still write it to disk (as JSON) @@ -305,6 +354,7 @@ type MaskedPrefs struct { RouteAllSet bool `json:",omitempty"` ExitNodeIDSet bool `json:",omitempty"` ExitNodeIPSet bool `json:",omitempty"` + AutoExitNodeSet bool `json:",omitempty"` InternalExitNodePriorSet bool `json:",omitempty"` // Internal; can't be set by LocalAPI clients ExitNodeAllowLANAccessSet bool `json:",omitempty"` CorpDNSSet bool `json:",omitempty"` @@ -319,16 +369,19 @@ type MaskedPrefs struct { ForceDaemonSet bool `json:",omitempty"` EggSet bool `json:",omitempty"` AdvertiseRoutesSet bool `json:",omitempty"` + AdvertiseServicesSet bool `json:",omitempty"` + SyncSet bool `json:",omitzero"` NoSNATSet bool `json:",omitempty"` NoStatefulFilteringSet bool `json:",omitempty"` NetfilterModeSet bool `json:",omitempty"` OperatorUserSet bool `json:",omitempty"` ProfileNameSet bool `json:",omitempty"` - AutoUpdateSet AutoUpdatePrefsMask `json:",omitempty"` + AutoUpdateSet AutoUpdatePrefsMask `json:",omitzero"` AppConnectorSet bool `json:",omitempty"` PostureCheckingSet bool `json:",omitempty"` NetfilterKindSet bool `json:",omitempty"` DriveSharesSet bool `json:",omitempty"` + RelayServerPortSet bool `json:",omitempty"` } // SetsInternal reports whether mp has any of the Internal*Set field bools set @@ -486,17 +539,24 @@ func (p *Prefs) Pretty() string { return p.pretty(runtime.GOOS) } func (p *Prefs) pretty(goos string) string { var sb strings.Builder sb.WriteString("Prefs{") - fmt.Fprintf(&sb, "ra=%v ", p.RouteAll) - fmt.Fprintf(&sb, "dns=%v want=%v ", p.CorpDNS, p.WantRunning) - if p.RunSSH { + if buildfeatures.HasUseRoutes { + fmt.Fprintf(&sb, "ra=%v ", p.RouteAll) + } + if buildfeatures.HasDNS { + fmt.Fprintf(&sb, "dns=%v want=%v ", p.CorpDNS, p.WantRunning) + } + if buildfeatures.HasSSH && p.RunSSH { sb.WriteString("ssh=true ") } - if p.RunWebClient { + if buildfeatures.HasWebClient && p.RunWebClient { sb.WriteString("webclient=true ") } if p.LoggedOut { sb.WriteString("loggedout=true ") } + if p.Sync.EqualBool(false) { + sb.WriteString("sync=false ") + } if p.ForceDaemon { sb.WriteString("server=true ") } @@ -506,27 +566,37 @@ func (p *Prefs) pretty(goos string) string { if p.ShieldsUp { sb.WriteString("shields=true ") } - if p.ExitNodeIP.IsValid() { - fmt.Fprintf(&sb, "exit=%v lan=%t ", p.ExitNodeIP, p.ExitNodeAllowLANAccess) - } else if !p.ExitNodeID.IsZero() { - fmt.Fprintf(&sb, "exit=%v lan=%t ", p.ExitNodeID, p.ExitNodeAllowLANAccess) - } - if len(p.AdvertiseRoutes) > 0 || goos == "linux" { - fmt.Fprintf(&sb, "routes=%v ", p.AdvertiseRoutes) - } - if len(p.AdvertiseRoutes) > 0 || p.NoSNAT { - fmt.Fprintf(&sb, "snat=%v ", !p.NoSNAT) + if buildfeatures.HasUseExitNode { + if p.ExitNodeIP.IsValid() { + fmt.Fprintf(&sb, "exit=%v lan=%t ", p.ExitNodeIP, p.ExitNodeAllowLANAccess) + } else if !p.ExitNodeID.IsZero() { + fmt.Fprintf(&sb, "exit=%v lan=%t ", p.ExitNodeID, p.ExitNodeAllowLANAccess) + } + if p.AutoExitNode.IsSet() { + fmt.Fprintf(&sb, "auto=%v ", p.AutoExitNode) + } } - if len(p.AdvertiseRoutes) > 0 || p.NoStatefulFiltering.EqualBool(true) { - // Only print if we're advertising any routes, or the user has - // turned off stateful filtering (NoStatefulFiltering=true ⇒ - // StatefulFiltering=false). - bb, _ := p.NoStatefulFiltering.Get() - fmt.Fprintf(&sb, "statefulFiltering=%v ", !bb) + if buildfeatures.HasAdvertiseRoutes { + if len(p.AdvertiseRoutes) > 0 || goos == "linux" { + fmt.Fprintf(&sb, "routes=%v ", p.AdvertiseRoutes) + } + if len(p.AdvertiseRoutes) > 0 || p.NoSNAT { + fmt.Fprintf(&sb, "snat=%v ", !p.NoSNAT) + } + if len(p.AdvertiseRoutes) > 0 || p.NoStatefulFiltering.EqualBool(true) { + // Only print if we're advertising any routes, or the user has + // turned off stateful filtering (NoStatefulFiltering=true ⇒ + // StatefulFiltering=false). + bb, _ := p.NoStatefulFiltering.Get() + fmt.Fprintf(&sb, "statefulFiltering=%v ", !bb) + } } if len(p.AdvertiseTags) > 0 { fmt.Fprintf(&sb, "tags=%s ", strings.Join(p.AdvertiseTags, ",")) } + if len(p.AdvertiseServices) > 0 { + fmt.Fprintf(&sb, "services=%s ", strings.Join(p.AdvertiseServices, ",")) + } if goos == "linux" { fmt.Fprintf(&sb, "nf=%v ", p.NetfilterMode) } @@ -542,8 +612,15 @@ func (p *Prefs) pretty(goos string) string { if p.NetfilterKind != "" { fmt.Fprintf(&sb, "netfilterKind=%s ", p.NetfilterKind) } - sb.WriteString(p.AutoUpdate.Pretty()) - sb.WriteString(p.AppConnector.Pretty()) + if buildfeatures.HasClientUpdate { + sb.WriteString(p.AutoUpdate.Pretty()) + } + if buildfeatures.HasAppConnectors { + sb.WriteString(p.AppConnector.Pretty()) + } + if buildfeatures.HasRelayServer && p.RelayServerPort != nil { + fmt.Fprintf(&sb, "relayServerPort=%d ", *p.RelayServerPort) + } if p.Persist != nil { sb.WriteString(p.Persist.Pretty()) } else { @@ -570,7 +647,7 @@ func (p PrefsView) Equals(p2 PrefsView) bool { } func (p *Prefs) Equals(p2 *Prefs) bool { - if p == nil && p2 == nil { + if p == p2 { return true } if p == nil || p2 == nil { @@ -581,10 +658,12 @@ func (p *Prefs) Equals(p2 *Prefs) bool { p.RouteAll == p2.RouteAll && p.ExitNodeID == p2.ExitNodeID && p.ExitNodeIP == p2.ExitNodeIP && + p.AutoExitNode == p2.AutoExitNode && p.InternalExitNodePrior == p2.InternalExitNodePrior && p.ExitNodeAllowLANAccess == p2.ExitNodeAllowLANAccess && p.CorpDNS == p2.CorpDNS && p.RunSSH == p2.RunSSH && + p.Sync.Normalized() == p2.Sync.Normalized() && p.RunWebClient == p2.RunWebClient && p.WantRunning == p2.WantRunning && p.LoggedOut == p2.LoggedOut && @@ -596,15 +675,17 @@ func (p *Prefs) Equals(p2 *Prefs) bool { p.OperatorUser == p2.OperatorUser && p.Hostname == p2.Hostname && p.ForceDaemon == p2.ForceDaemon && - compareIPNets(p.AdvertiseRoutes, p2.AdvertiseRoutes) && - compareStrings(p.AdvertiseTags, p2.AdvertiseTags) && + slices.Equal(p.AdvertiseRoutes, p2.AdvertiseRoutes) && + slices.Equal(p.AdvertiseTags, p2.AdvertiseTags) && + slices.Equal(p.AdvertiseServices, p2.AdvertiseServices) && p.Persist.Equals(p2.Persist) && p.ProfileName == p2.ProfileName && p.AutoUpdate.Equals(p2.AutoUpdate) && p.AppConnector == p2.AppConnector && p.PostureChecking == p2.PostureChecking && slices.EqualFunc(p.DriveShares, p2.DriveShares, drive.SharesEqual) && - p.NetfilterKind == p2.NetfilterKind + p.NetfilterKind == p2.NetfilterKind && + compareIntPtrs(p.RelayServerPort, p2.RelayServerPort) } func (au AutoUpdatePrefs) Pretty() string { @@ -624,28 +705,14 @@ func (ap AppConnectorPrefs) Pretty() string { return "" } -func compareIPNets(a, b []netip.Prefix) bool { - if len(a) != len(b) { +func compareIntPtrs(a, b *int) bool { + if (a == nil) != (b == nil) { return false } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - -func compareStrings(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } + if a == nil { + return true } - return true + return *a == *b } // NewPrefs returns the default preferences to use. @@ -653,7 +720,8 @@ func NewPrefs() *Prefs { // Provide default values for options which might be missing // from the json data for any reason. The json can still // override them to false. - return &Prefs{ + + p := &Prefs{ // ControlURL is explicitly not set to signal that // it's not yet configured, which relaxes the CLI "up" // safety net features. It will get set to DefaultControlURL @@ -661,7 +729,6 @@ func NewPrefs() *Prefs { // later anyway. ControlURL: "", - RouteAll: true, CorpDNS: true, WantRunning: false, NetfilterMode: preftype.NetfilterOn, @@ -671,22 +738,24 @@ func NewPrefs() *Prefs { Apply: opt.Bool("unset"), }, } + p.RouteAll = p.DefaultRouteAll(runtime.GOOS) + return p } // ControlURLOrDefault returns the coordination server's URL base. // // If not configured, or if the configured value is a legacy name equivalent to // the default, then DefaultControlURL is returned instead. -func (p PrefsView) ControlURLOrDefault() string { - return p.Đļ.ControlURLOrDefault() +func (p PrefsView) ControlURLOrDefault(polc policyclient.Client) string { + return p.Đļ.ControlURLOrDefault(polc) } // ControlURLOrDefault returns the coordination server's URL base. // // If not configured, or if the configured value is a legacy name equivalent to // the default, then DefaultControlURL is returned instead. -func (p *Prefs) ControlURLOrDefault() string { - controlURL, err := syspolicy.GetString(syspolicy.ControlURL, p.ControlURL) +func (p *Prefs) ControlURLOrDefault(polc policyclient.Client) string { + controlURL, err := polc.GetString(pkey.ControlURL, p.ControlURL) if err != nil { controlURL = p.ControlURL } @@ -700,12 +769,26 @@ func (p *Prefs) ControlURLOrDefault() string { return DefaultControlURL } +// DefaultRouteAll returns the default value of [Prefs.RouteAll] as a function +// of the platform it's running on. +func (p *Prefs) DefaultRouteAll(goos string) bool { + switch goos { + case "windows", "android", "ios": + return true + case "darwin": + // Only true for macAppStore and macsys, false for darwin tailscaled. + return version.IsSandboxedMacOS() + default: + return false + } +} + // AdminPageURL returns the admin web site URL for the current ControlURL. -func (p PrefsView) AdminPageURL() string { return p.Đļ.AdminPageURL() } +func (p PrefsView) AdminPageURL(polc policyclient.Client) string { return p.Đļ.AdminPageURL(polc) } // AdminPageURL returns the admin web site URL for the current ControlURL. -func (p *Prefs) AdminPageURL() string { - url := p.ControlURLOrDefault() +func (p *Prefs) AdminPageURL(polc policyclient.Client) string { + url := p.ControlURLOrDefault(polc) if IsLoginServerSynonym(url) { // TODO(crawshaw): In future release, make this https://console.tailscale.com url = "https://login.tailscale.com" @@ -729,6 +812,9 @@ func (p *Prefs) AdvertisesExitNode() bool { // SetAdvertiseExitNode mutates p (if non-nil) to add or remove the two // /0 exit node routes. func (p *Prefs) SetAdvertiseExitNode(runExit bool) { + if !buildfeatures.HasAdvertiseExitNode { + return + } if p == nil { return } @@ -773,6 +859,7 @@ func isRemoteIP(st *ipnstate.Status, ip netip.Addr) bool { func (p *Prefs) ClearExitNode() { p.ExitNodeID = "" p.ExitNodeIP = netip.Addr{} + p.AutoExitNode = "" } // ExitNodeLocalIPError is returned when the requested IP address for an exit @@ -791,6 +878,9 @@ func exitNodeIPOfArg(s string, st *ipnstate.Status) (ip netip.Addr, err error) { } ip, err = netip.ParseAddr(s) if err == nil { + if !isRemoteIP(st, ip) { + return ip, ExitNodeLocalIPError{s} + } // If we're online already and have a netmap, double check that the IP // address specified is valid. if st.BackendState == "Running" { @@ -802,9 +892,6 @@ func exitNodeIPOfArg(s string, st *ipnstate.Status) (ip netip.Addr, err error) { return ip, fmt.Errorf("node %v is not advertising an exit node", ip) } } - if !isRemoteIP(st, ip) { - return ip, ExitNodeLocalIPError{s} - } return ip, nil } match := 0 @@ -880,10 +967,15 @@ func PrefsFromBytes(b []byte, base *Prefs) error { if len(b) == 0 { return nil } - return json.Unmarshal(b, base) } +func (p *Prefs) normalizeOptBools() { + if p.Sync == opt.ExplicitlyUnset { + p.Sync = "" + } +} + var jsonEscapedZero = []byte(`\u0000`) // LoadPrefsWindows loads a legacy relaynode config file into Prefs with @@ -932,6 +1024,7 @@ type WindowsUserID string type NetworkProfile struct { MagicDNSName string DomainName string + DisplayName string } // RequiresBackfill returns whether this object does not have all the data @@ -944,6 +1037,13 @@ func (n NetworkProfile) RequiresBackfill() bool { return n == NetworkProfile{} } +// DisplayNameOrDefault will always return a non-empty string. +// If there is a defined display name, it will return that. +// If they did not it will default to their domain name. +func (n NetworkProfile) DisplayNameOrDefault() string { + return cmp.Or(n.DisplayName, n.DomainName) +} + // LoginProfile represents a single login profile as managed // by the ProfileManager. type LoginProfile struct { @@ -989,3 +1089,68 @@ type LoginProfile struct { // into. ControlURL string } + +// Equals reports whether p and p2 are equal. +func (p LoginProfileView) Equals(p2 LoginProfileView) bool { + return p.Đļ.Equals(p2.Đļ) +} + +// Equals reports whether p and p2 are equal. +func (p *LoginProfile) Equals(p2 *LoginProfile) bool { + if p == p2 { + return true + } + if p == nil || p2 == nil { + return false + } + return p.ID == p2.ID && + p.Name == p2.Name && + p.NetworkProfile == p2.NetworkProfile && + p.Key == p2.Key && + p.UserProfile.Equal(&p2.UserProfile) && + p.NodeID == p2.NodeID && + p.LocalUserID == p2.LocalUserID && + p.ControlURL == p2.ControlURL +} + +// ExitNodeExpression is a string that specifies how an exit node +// should be selected. An empty string means that no exit node +// should be selected. +// +// As of 2025-07-02, the only supported value is [AnyExitNode]. +type ExitNodeExpression string + +// AnyExitNode indicates that the exit node should be automatically +// selected from the pool of available exit nodes, excluding any +// disallowed by policy (e.g., [syspolicy.AllowedSuggestedExitNodes]). +// The exact implementation is subject to change, but exit nodes +// offering the best performance will be preferred. +const AnyExitNode ExitNodeExpression = "any" + +// IsSet reports whether the expression is non-empty and can be used +// to select an exit node. +func (e ExitNodeExpression) IsSet() bool { + return e != "" +} + +const ( + // AutoExitNodePrefix is the prefix used in [syspolicy.ExitNodeID] values and CLI + // to indicate that the string following the prefix is an [ipn.ExitNodeExpression]. + AutoExitNodePrefix = "auto:" +) + +// ParseAutoExitNodeString attempts to parse the given string +// as an [ExitNodeExpression]. +// +// It returns the parsed expression and true on success, +// or an empty string and false if the input does not appear to be +// an [ExitNodeExpression] (i.e., it doesn't start with "auto:"). +// +// It is mainly used to parse the [syspolicy.ExitNodeID] value +// when it is set to "auto:" (e.g., auto:any). +func ParseAutoExitNodeString[T ~string](s T) (_ ExitNodeExpression, ok bool) { + if expr, ok := strings.CutPrefix(string(s), AutoExitNodePrefix); ok && expr != "" { + return ExitNodeExpression(expr), true + } + return "", false +} diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go index dcb999ef5..7c9c3ef43 100644 --- a/ipn/prefs_test.go +++ b/ipn/prefs_test.go @@ -23,6 +23,7 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" + "tailscale.com/util/syspolicy/policyclient" ) func fieldsOf(t reflect.Type) (fields []string) { @@ -40,6 +41,7 @@ func TestPrefsEqual(t *testing.T) { "RouteAll", "ExitNodeID", "ExitNodeIP", + "AutoExitNode", "InternalExitNodePrior", "ExitNodeAllowLANAccess", "CorpDNS", @@ -54,6 +56,8 @@ func TestPrefsEqual(t *testing.T) { "ForceDaemon", "Egg", "AdvertiseRoutes", + "AdvertiseServices", + "Sync", "NoSNAT", "NoStatefulFiltering", "NetfilterMode", @@ -64,6 +68,7 @@ func TestPrefsEqual(t *testing.T) { "PostureChecking", "NetfilterKind", "DriveShares", + "RelayServerPort", "AllowSingleHosts", "Persist", } @@ -72,6 +77,9 @@ func TestPrefsEqual(t *testing.T) { have, prefsHandles) } + relayServerPort := func(port int) *int { + return &port + } nets := func(strs ...string) (ns []netip.Prefix) { for _, s := range strs { n, err := netip.ParsePrefix(s) @@ -145,6 +153,17 @@ func TestPrefsEqual(t *testing.T) { true, }, + { + &Prefs{AutoExitNode: ""}, + &Prefs{AutoExitNode: "auto:any"}, + false, + }, + { + &Prefs{AutoExitNode: "auto:any"}, + &Prefs{AutoExitNode: "auto:any"}, + true, + }, + { &Prefs{}, &Prefs{ExitNodeAllowLANAccess: true}, @@ -330,6 +349,26 @@ func TestPrefsEqual(t *testing.T) { &Prefs{NetfilterKind: ""}, false, }, + { + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + true, + }, + { + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:amelie"}}, + false, + }, + { + &Prefs{RelayServerPort: relayServerPort(0)}, + &Prefs{RelayServerPort: nil}, + false, + }, + { + &Prefs{RelayServerPort: relayServerPort(0)}, + &Prefs{RelayServerPort: relayServerPort(1)}, + false, + }, } for i, tt := range tests { got := tt.a.Equals(tt.b) @@ -366,6 +405,7 @@ func checkPrefs(t *testing.T, p Prefs) { if err != nil { t.Fatalf("PrefsFromBytes(p2) failed: bytes=%q; err=%v\n", p2.ToBytes(), err) } + p2b.normalizeOptBools() p2p := p2.Pretty() p2bp := p2b.Pretty() t.Logf("\np2p: %#v\np2bp: %#v\n", p2p, p2bp) @@ -381,6 +421,42 @@ func checkPrefs(t *testing.T, p Prefs) { } } +// PrefsFromBytes documents that it preserves fields unset in the JSON. +// This verifies that stays true. +func TestPrefsFromBytesPreservesOldValues(t *testing.T) { + tests := []struct { + name string + old Prefs + json []byte + want Prefs + }{ + { + name: "preserve-control-url", + old: Prefs{ControlURL: "https://foo"}, + json: []byte(`{"RouteAll": true}`), + want: Prefs{ControlURL: "https://foo", RouteAll: true}, + }, + { + name: "opt.Bool", // test that we don't normalize it early + old: Prefs{Sync: "unset"}, + json: []byte(`{}`), + want: Prefs{Sync: "unset"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := tt.old // shallow + err := PrefsFromBytes(tt.json, &old) + if err != nil { + t.Fatalf("PrefsFromBytes failed: %v", err) + } + if !old.Equals(&tt.want) { + t.Fatalf("got %+v; want %+v", old, tt.want) + } + }) + } +} + func TestBasicPrefs(t *testing.T) { tstest.PanicOnLog() @@ -456,13 +532,6 @@ func TestPrefsPretty(t *testing.T) { "darwin", `Prefs{ra=false dns=false want=true tags=tag:foo,tag:bar url="http://localhost:1234" update=off Persist=nil}`, }, - { - Prefs{ - Persist: &persist.Persist{}, - }, - "linux", - `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{lm=, o=, n= u=""}}`, - }, { Prefs{ Persist: &persist.Persist{ @@ -470,7 +539,7 @@ func TestPrefsPretty(t *testing.T) { }, }, "linux", - `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{lm=, o=, n=[B1VKl] u=""}}`, + `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{o=, n=[B1VKl] u="" ak=-}}`, }, { Prefs{ @@ -560,6 +629,11 @@ func TestPrefsPretty(t *testing.T) { "linux", `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist=nil}`, }, + { + Prefs{Sync: "false"}, + "linux", + "Prefs{ra=false dns=false want=false sync=false routes=[] nf=off update=off Persist=nil}", + }, } for i, tt := range tests { got := tt.p.pretty(tt.os) @@ -866,6 +940,23 @@ func TestExitNodeIPOfArg(t *testing.T) { }, wantErr: `no node found in netmap with IP 1.2.3.4`, }, + { + name: "ip_is_self", + arg: "1.2.3.4", + st: &ipnstate.Status{ + TailscaleIPs: []netip.Addr{mustIP("1.2.3.4")}, + }, + wantErr: "cannot use 1.2.3.4 as an exit node as it is a local IP address to this machine", + }, + { + name: "ip_is_self_when_backend_running", + arg: "1.2.3.4", + st: &ipnstate.Status{ + BackendState: "Running", + TailscaleIPs: []netip.Addr{mustIP("1.2.3.4")}, + }, + wantErr: "cannot use 1.2.3.4 as an exit node as it is a local IP address to this machine", + }, { name: "ip_not_exit", arg: "1.2.3.4", @@ -1002,15 +1093,16 @@ func TestExitNodeIPOfArg(t *testing.T) { func TestControlURLOrDefault(t *testing.T) { var p Prefs - if got, want := p.ControlURLOrDefault(), DefaultControlURL; got != want { + polc := policyclient.NoPolicyClient{} + if got, want := p.ControlURLOrDefault(polc), DefaultControlURL; got != want { t.Errorf("got %q; want %q", got, want) } p.ControlURL = "http://foo.bar" - if got, want := p.ControlURLOrDefault(), "http://foo.bar"; got != want { + if got, want := p.ControlURLOrDefault(polc), "http://foo.bar"; got != want { t.Errorf("got %q; want %q", got, want) } p.ControlURL = "https://login.tailscale.com" - if got, want := p.ControlURLOrDefault(), DefaultControlURL; got != want { + if got, want := p.ControlURLOrDefault(polc), DefaultControlURL; got != want { t.Errorf("got %q; want %q", got, want) } } @@ -1099,3 +1191,62 @@ func TestPrefsDowngrade(t *testing.T) { t.Fatal("AllowSingleHosts should be true") } } + +func TestParseAutoExitNodeString(t *testing.T) { + tests := []struct { + name string + exitNodeID string + wantOk bool + wantExpr ExitNodeExpression + }{ + { + name: "empty expr", + exitNodeID: "", + wantOk: false, + wantExpr: "", + }, + { + name: "no auto prefix", + exitNodeID: "foo", + wantOk: false, + wantExpr: "", + }, + { + name: "auto:any", + exitNodeID: "auto:any", + wantOk: true, + wantExpr: AnyExitNode, + }, + { + name: "auto:foo", + exitNodeID: "auto:foo", + wantOk: true, + wantExpr: "foo", + }, + { + name: "auto prefix but empty suffix", + exitNodeID: "auto:", + wantOk: false, + wantExpr: "", + }, + { + name: "auto prefix no colon", + exitNodeID: "auto", + wantOk: false, + wantExpr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotExpr, gotOk := ParseAutoExitNodeString(tt.exitNodeID) + if gotOk != tt.wantOk || gotExpr != tt.wantExpr { + if tt.wantOk { + t.Fatalf("got %v (%q); want %v (%q)", gotOk, gotExpr, tt.wantOk, tt.wantExpr) + } else { + t.Fatalf("got %v (%q); want false", gotOk, gotExpr) + } + } + }) + } +} diff --git a/ipn/serve.go b/ipn/serve.go index 5c0a97ed3..74195191c 100644 --- a/ipn/serve.go +++ b/ipn/serve.go @@ -6,6 +6,7 @@ package ipn import ( "errors" "fmt" + "iter" "net" "net/netip" "net/url" @@ -15,7 +16,10 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/util/dnsname" "tailscale.com/util/mak" + "tailscale.com/util/set" ) // ServeConfigKey returns a StateKey that stores the @@ -24,6 +28,23 @@ func ServeConfigKey(profileID ProfileID) StateKey { return StateKey("_serve/" + profileID) } +// ServiceConfig contains the config information for a single service. +// it contains a bool to indicate if the service is in Tun mode (L3 forwarding). +// If the service is not in Tun mode, the service is configured by the L4 forwarding +// (TCP ports) and/or the L7 forwarding (http handlers) information. +type ServiceConfig struct { + // TCP are the list of TCP port numbers that tailscaled should handle for + // the Tailscale IP addresses. (not subnet routers, etc) + TCP map[uint16]*TCPPortHandler `json:",omitempty"` + + // Web maps from "$SNI_NAME:$PORT" to a set of HTTP handlers + // keyed by mount point ("/", "/foo", etc) + Web map[HostPort]*WebServerConfig `json:",omitempty"` + + // Tun determines if the service should be using L3 forwarding (Tun mode). + Tun bool `json:",omitempty"` +} + // ServeConfig is the JSON type stored in the StateStore for // StateKey "_serve/$PROFILE_ID" as returned by ServeConfigKey. type ServeConfig struct { @@ -35,16 +56,20 @@ type ServeConfig struct { // keyed by mount point ("/", "/foo", etc) Web map[HostPort]*WebServerConfig `json:",omitempty"` + // Services maps from service name (in the form "svc:dns-label") to a ServiceConfig. + // Which describes the L3, L4, and L7 forwarding information for the service. + Services map[tailcfg.ServiceName]*ServiceConfig `json:",omitempty"` + // AllowFunnel is the set of SNI:port values for which funnel // traffic is allowed, from trusted ingress peers. AllowFunnel map[HostPort]bool `json:",omitempty"` - // Foreground is a map of an IPN Bus session ID to an alternate foreground - // serve config that's valid for the life of that WatchIPNBus session ID. - // This. This allows the config to specify ephemeral configs that are - // used in the CLI's foreground mode to ensure ungraceful shutdowns - // of either the client or the LocalBackend does not expose ports - // that users are not aware of. + // Foreground is a map of an IPN Bus session ID to an alternate foreground serve config that's valid for the + // life of that WatchIPNBus session ID. This allows the config to specify ephemeral configs that are used + // in the CLI's foreground mode to ensure ungraceful shutdowns of either the client or the LocalBackend does not + // expose ports that users are not aware of. In practice this contains any serve config set via 'tailscale + // serve' command run without the '--bg' flag. ServeConfig contained by Foreground is not expected itself to contain + // another Foreground block. Foreground map[string]*ServeConfig `json:",omitempty"` // ETag is the checksum of the serve config that's populated @@ -125,6 +150,12 @@ type TCPPortHandler struct { // SNI name with this value. It is only used if TCPForward is non-empty. // (the HTTPS mode uses ServeConfig.Web) TerminateTLS string `json:",omitempty"` + + // ProxyProtocol indicates whether to send a PROXY protocol header + // before forwarding the connection to TCPForward. + // + // This is only valid if TCPForward is non-empty. + ProxyProtocol int `json:",omitzero"` } // HTTPHandler is either a path or a proxy to serve. @@ -136,32 +167,61 @@ type HTTPHandler struct { Text string `json:",omitempty"` // plaintext to serve (primarily for testing) + AcceptAppCaps []tailcfg.PeerCapability `json:",omitempty"` // peer capabilities to forward in grant header, e.g. example.com/cap/mon + + // Redirect, if not empty, is the target URL to redirect requests to. + // By default, we redirect with HTTP 302 (Found) status. + // If Redirect starts with ':', then we use that status instead. + // + // The target URL supports the following expansion variables: + // - ${HOST}: replaced with the request's Host header value + // - ${REQUEST_URI}: replaced with the request's full URI (path and query string) + Redirect string `json:",omitempty"` + // TODO(bradfitz): bool to not enumerate directories? TTL on mapping for - // temporary ones? Error codes? Redirects? + // temporary ones? Error codes? } // WebHandlerExists reports whether if the ServeConfig Web handler exists for // the given host:port and mount point. -func (sc *ServeConfig) WebHandlerExists(hp HostPort, mount string) bool { - h := sc.GetWebHandler(hp, mount) +func (sc *ServeConfig) WebHandlerExists(svcName tailcfg.ServiceName, hp HostPort, mount string) bool { + h := sc.GetWebHandler(svcName, hp, mount) return h != nil } // GetWebHandler returns the HTTPHandler for the given host:port and mount point. // Returns nil if the handler does not exist. -func (sc *ServeConfig) GetWebHandler(hp HostPort, mount string) *HTTPHandler { - if sc == nil || sc.Web[hp] == nil { +func (sc *ServeConfig) GetWebHandler(svcName tailcfg.ServiceName, hp HostPort, mount string) *HTTPHandler { + if sc == nil { + return nil + } + if svcName != "" { + if svc, ok := sc.Services[svcName]; ok && svc.Web != nil { + if webCfg, ok := svc.Web[hp]; ok { + return webCfg.Handlers[mount] + } + } + return nil + } + if sc.Web[hp] == nil { return nil } return sc.Web[hp].Handlers[mount] } -// GetTCPPortHandler returns the TCPPortHandler for the given port. -// If the port is not configured, nil is returned. -func (sc *ServeConfig) GetTCPPortHandler(port uint16) *TCPPortHandler { +// GetTCPPortHandler returns the TCPPortHandler for the given port. If the port +// is not configured, nil is returned. Parameter svcName can be tailcfg.NoService +// for local serve or a service name for a service hosted on node. +func (sc *ServeConfig) GetTCPPortHandler(port uint16, svcName tailcfg.ServiceName) *TCPPortHandler { if sc == nil { return nil } + if svcName != "" { + if svc, ok := sc.Services[svcName]; ok && svc != nil { + return svc.TCP[port] + } + return nil + } return sc.TCP[port] } @@ -203,34 +263,78 @@ func (sc *ServeConfig) IsTCPForwardingAny() bool { return false } -// IsTCPForwardingOnPort reports whether if ServeConfig is currently forwarding -// in TCPForward mode on the given port. This is exclusive of Web/HTTPS serving. -func (sc *ServeConfig) IsTCPForwardingOnPort(port uint16) bool { - if sc == nil || sc.TCP[port] == nil { +// IsTCPForwardingOnPort reports whether ServeConfig is currently forwarding +// in TCPForward mode on the given port for local or a service. svcName will +// either be noService (empty string) for local serve or a serviceName for service +// hosted on node. Notice TCPForwarding is exclusive with Web/HTTPS serving. +func (sc *ServeConfig) IsTCPForwardingOnPort(port uint16, svcName tailcfg.ServiceName) bool { + if sc == nil { return false } - return !sc.IsServingWeb(port) + + if svcName != "" { + svc, ok := sc.Services[svcName] + if !ok || svc == nil { + return false + } + if svc.TCP[port] == nil { + return false + } + } else if sc.TCP[port] == nil { + return false + } + return !sc.IsServingWeb(port, svcName) } -// IsServingWeb reports whether if ServeConfig is currently serving Web -// (HTTP/HTTPS) on the given port. This is exclusive of TCPForwarding. -func (sc *ServeConfig) IsServingWeb(port uint16) bool { - return sc.IsServingHTTP(port) || sc.IsServingHTTPS(port) +// IsServingWeb reports whether ServeConfig is currently serving Web (HTTP/HTTPS) +// on the given port for local or a service. svcName will be either tailcfg.NoService, +// or a serviceName for service hosted on node. This is exclusive with TCPForwarding. +func (sc *ServeConfig) IsServingWeb(port uint16, svcName tailcfg.ServiceName) bool { + return sc.IsServingHTTP(port, svcName) || sc.IsServingHTTPS(port, svcName) } -// IsServingHTTPS reports whether if ServeConfig is currently serving HTTPS on -// the given port. This is exclusive of HTTP and TCPForwarding. -func (sc *ServeConfig) IsServingHTTPS(port uint16) bool { - if sc == nil || sc.TCP[port] == nil { +// IsServingHTTPS reports whether ServeConfig is currently serving HTTPS on +// the given port for local or a service. svcName will be either tailcfg.NoService +// for local serve, or a serviceName for service hosted on node. This is exclusive +// with HTTP and TCPForwarding. +func (sc *ServeConfig) IsServingHTTPS(port uint16, svcName tailcfg.ServiceName) bool { + if sc == nil { return false } - return sc.TCP[port].HTTPS + var tcpHandlers map[uint16]*TCPPortHandler + if svcName != "" { + if svc := sc.Services[svcName]; svc != nil { + tcpHandlers = svc.TCP + } + } else { + tcpHandlers = sc.TCP + } + + th := tcpHandlers[port] + if th == nil { + return false + } + return th.HTTPS } -// IsServingHTTP reports whether if ServeConfig is currently serving HTTP on the -// given port. This is exclusive of HTTPS and TCPForwarding. -func (sc *ServeConfig) IsServingHTTP(port uint16) bool { - if sc == nil || sc.TCP[port] == nil { +// IsServingHTTP reports whether ServeConfig is currently serving HTTP on the +// given port for local or a service. svcName will be either tailcfg.NoService for +// local serve, or a serviceName for service hosted on node. This is exclusive +// with HTTPS and TCPForwarding. +func (sc *ServeConfig) IsServingHTTP(port uint16, svcName tailcfg.ServiceName) bool { + if sc == nil { + return false + } + if svcName != "" { + if svc := sc.Services[svcName]; svc != nil { + if svc.TCP[port] != nil { + return svc.TCP[port].HTTP + } + } + return false + } + + if sc.TCP[port] == nil { return false } return sc.TCP[port].HTTP @@ -256,21 +360,38 @@ func (sc *ServeConfig) FindConfig(port uint16) (*ServeConfig, bool) { // SetWebHandler sets the given HTTPHandler at the specified host, port, // and mount in the serve config. sc.TCP is also updated to reflect web -// serving usage of the given port. -func (sc *ServeConfig) SetWebHandler(handler *HTTPHandler, host string, port uint16, mount string, useTLS bool) { +// serving usage of the given port. The st argument is needed when setting +// a web handler for a service, otherwise it can be nil. mds is the Magic DNS +// suffix, which is used to recreate serve's host. +func (sc *ServeConfig) SetWebHandler(handler *HTTPHandler, host string, port uint16, mount string, useTLS bool, mds string) { if sc == nil { sc = new(ServeConfig) } - mak.Set(&sc.TCP, port, &TCPPortHandler{HTTPS: useTLS, HTTP: !useTLS}) - hp := HostPort(net.JoinHostPort(host, strconv.Itoa(int(port)))) - if _, ok := sc.Web[hp]; !ok { - mak.Set(&sc.Web, hp, new(WebServerConfig)) + tcpMap := &sc.TCP + webServerMap := &sc.Web + hostName := host + if svcName := tailcfg.AsServiceName(host); svcName != "" { + hostName = strings.Join([]string{svcName.WithoutPrefix(), mds}, ".") + svc, ok := sc.Services[svcName] + if !ok { + svc = new(ServiceConfig) + mak.Set(&sc.Services, svcName, svc) + } + tcpMap = &svc.TCP + webServerMap = &svc.Web } - mak.Set(&sc.Web[hp].Handlers, mount, handler) + mak.Set(tcpMap, port, &TCPPortHandler{HTTPS: useTLS, HTTP: !useTLS}) + hp := HostPort(net.JoinHostPort(hostName, strconv.Itoa(int(port)))) + webCfg, ok := (*webServerMap)[hp] + if !ok { + webCfg = new(WebServerConfig) + mak.Set(webServerMap, hp, webCfg) + } + mak.Set(&webCfg.Handlers, mount, handler) // TODO(tylersmalley): handle multiple web handlers from foreground mode - for k, v := range sc.Web[hp].Handlers { + for k, v := range webCfg.Handlers { if v == handler { continue } @@ -281,7 +402,7 @@ func (sc *ServeConfig) SetWebHandler(handler *HTTPHandler, host string, port uin m1 := strings.TrimSuffix(mount, "/") m2 := strings.TrimSuffix(k, "/") if m1 == m2 { - delete(sc.Web[hp].Handlers, k) + delete(webCfg.Handlers, k) } } } @@ -290,14 +411,31 @@ func (sc *ServeConfig) SetWebHandler(handler *HTTPHandler, host string, port uin // connections from the given port. If terminateTLS is true, TLS connections // are terminated with only the given host name permitted before passing them // to the fwdAddr. -func (sc *ServeConfig) SetTCPForwarding(port uint16, fwdAddr string, terminateTLS bool, host string) { +// +// If proxyProtocol is non-zero, the corresponding PROXY protocol version +// header is sent before forwarding the connection. +func (sc *ServeConfig) SetTCPForwarding(port uint16, fwdAddr string, terminateTLS bool, proxyProtocol int, host string) { if sc == nil { sc = new(ServeConfig) } - mak.Set(&sc.TCP, port, &TCPPortHandler{TCPForward: fwdAddr}) + tcpPortHandler := &sc.TCP + if svcName := tailcfg.AsServiceName(host); svcName != "" { + svcConfig, ok := sc.Services[svcName] + if !ok { + svcConfig = new(ServiceConfig) + mak.Set(&sc.Services, svcName, svcConfig) + } + tcpPortHandler = &svcConfig.TCP + } + + handler := &TCPPortHandler{ + TCPForward: fwdAddr, + ProxyProtocol: proxyProtocol, // can be 0 + } if terminateTLS { - sc.TCP[port].TerminateTLS = host + handler.TerminateTLS = host } + mak.Set(tcpPortHandler, port, handler) } // SetFunnel sets the sc.AllowFunnel value for the given host and port. @@ -320,9 +458,9 @@ func (sc *ServeConfig) SetFunnel(host string, port uint16, setOn bool) { } } -// RemoveWebHandler deletes the web handlers at all of the given mount points -// for the provided host and port in the serve config. If cleanupFunnel is -// true, this also removes the funnel value for this port if no handlers remain. +// RemoveWebHandler deletes the web handlers at all of the given mount points for the +// provided host and port in the serve config for the node (as opposed to a service). +// If cleanupFunnel is true, this also removes the funnel value for this port if no handlers remain. func (sc *ServeConfig) RemoveWebHandler(host string, port uint16, mounts []string, cleanupFunnel bool) { hp := HostPort(net.JoinHostPort(host, strconv.Itoa(int(port)))) @@ -350,9 +488,50 @@ func (sc *ServeConfig) RemoveWebHandler(host string, port uint16, mounts []strin } } +// RemoveServiceWebHandler deletes the web handlers at all of the given mount points +// for the provided host and port in the serve config for the given service. +func (sc *ServeConfig) RemoveServiceWebHandler(svcName tailcfg.ServiceName, hostName string, port uint16, mounts []string) { + hp := HostPort(net.JoinHostPort(hostName, strconv.Itoa(int(port)))) + + svc, ok := sc.Services[svcName] + if !ok || svc == nil { + return + } + + // Delete existing handler, then cascade delete if empty. + for _, m := range mounts { + delete(svc.Web[hp].Handlers, m) + } + if len(svc.Web[hp].Handlers) == 0 { + delete(svc.Web, hp) + delete(svc.TCP, port) + } + if len(svc.Web) == 0 && len(svc.TCP) == 0 { + delete(sc.Services, svcName) + } + if len(sc.Services) == 0 { + sc.Services = nil + } +} + // RemoveTCPForwarding deletes the TCP forwarding configuration for the given // port from the serve config. -func (sc *ServeConfig) RemoveTCPForwarding(port uint16) { +func (sc *ServeConfig) RemoveTCPForwarding(svcName tailcfg.ServiceName, port uint16) { + if svcName != "" { + if svc := sc.Services[svcName]; svc != nil { + delete(svc.TCP, port) + if len(svc.TCP) == 0 { + svc.TCP = nil + } + if len(svc.Web) == 0 && len(svc.TCP) == 0 { + delete(sc.Services, svcName) + } + if len(sc.Services) == 0 { + sc.Services = nil + } + } + return + } delete(sc.TCP, port) if len(sc.TCP) == 0 { sc.TCP = nil @@ -365,8 +544,7 @@ func (sc *ServeConfig) RemoveTCPForwarding(port uint16) { // View version of ServeConfig.IsFunnelOn. func (v ServeConfigView) IsFunnelOn() bool { return v.Đļ.IsFunnelOn() } -// IsFunnelOn reports whether if ServeConfig is currently allowing funnel -// traffic for any host:port. +// IsFunnelOn reports whether any funnel endpoint is currently enabled for this node. func (sc *ServeConfig) IsFunnelOn() bool { if sc == nil { return false @@ -376,6 +554,11 @@ func (sc *ServeConfig) IsFunnelOn() bool { return true } } + for _, conf := range sc.Foreground { + if conf.IsFunnelOn() { + return true + } + } return false } @@ -491,7 +674,8 @@ func CheckFunnelPort(wantedPort uint16, node *ipnstate.PeerStatus) error { // ExpandProxyTargetValue expands the supported target values to be proxied // allowing for input values to be a port number, a partial URL, or a full URL -// including a path. +// including a path. If it's for a service, remote addresses are allowed and +// there doesn't have to be a port specified. // // examples: // - 3000 @@ -501,17 +685,25 @@ func CheckFunnelPort(wantedPort uint16, node *ipnstate.PeerStatus) error { // - https://localhost:3000 // - https-insecure://localhost:3000 // - https-insecure://localhost:3000/foo +// - https://tailscale.com func ExpandProxyTargetValue(target string, supportedSchemes []string, defaultScheme string) (string, error) { const host = "127.0.0.1" + // empty target is invalid + if target == "" { + return "", fmt.Errorf("empty target") + } + // support target being a port number if port, err := strconv.ParseUint(target, 10, 16); err == nil { return fmt.Sprintf("%s://%s:%d", defaultScheme, host, port), nil } + hasScheme := true // prepend scheme if not present if !strings.Contains(target, "://") { target = defaultScheme + "://" + target + hasScheme = false } // make sure we can parse the target @@ -525,16 +717,28 @@ func ExpandProxyTargetValue(target string, supportedSchemes []string, defaultSch return "", fmt.Errorf("must be a URL starting with one of the supported schemes: %v", supportedSchemes) } - // validate the host. - switch u.Hostname() { - case "localhost", "127.0.0.1": - default: - return "", errors.New("only localhost or 127.0.0.1 proxies are currently supported") + // validate port according to host. + if u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" || u.Hostname() == "::1" { + // require port for localhost targets + if u.Port() == "" { + return "", fmt.Errorf("port required for localhost target %q", target) + } + } else { + validHN := dnsname.ValidHostname(u.Hostname()) == nil + validIP := net.ParseIP(u.Hostname()) != nil + if !validHN && !validIP { + return "", fmt.Errorf("invalid hostname or IP address %q", u.Hostname()) + } + // require scheme for non-localhost targets + if !hasScheme { + return "", fmt.Errorf("non-localhost target %q must include a scheme", target) + } } - - // validate the port port, err := strconv.ParseUint(u.Port(), 10, 16) if err != nil || port == 0 { + if u.Port() == "" { + return u.String(), nil // allow no port for remote destinations + } return "", fmt.Errorf("invalid port %q", u.Port()) } @@ -543,58 +747,78 @@ func ExpandProxyTargetValue(target string, supportedSchemes []string, defaultSch return u.String(), nil } -// RangeOverTCPs ranges over both background and foreground TCPs. -// If the returned bool from the given f is false, then this function stops -// iterating immediately and does not check other foreground configs. -func (v ServeConfigView) RangeOverTCPs(f func(port uint16, _ TCPPortHandlerView) bool) { - parentCont := true - v.TCP().Range(func(k uint16, v TCPPortHandlerView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { - if !parentCont { - return false +// TCPs returns an iterator over both background and foreground TCP +// listeners. +// +// The key is the port number. +func (v ServeConfigView) TCPs() iter.Seq2[uint16, TCPPortHandlerView] { + return func(yield func(uint16, TCPPortHandlerView) bool) { + for k, v := range v.TCP().All() { + if !yield(k, v) { + return + } } - v.TCP().Range(func(k uint16, v TCPPortHandlerView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - return parentCont - }) -} - -// RangeOverWebs ranges over both background and foreground Webs. -// If the returned bool from the given f is false, then this function stops -// iterating immediately and does not check other foreground configs. -func (v ServeConfigView) RangeOverWebs(f func(_ HostPort, conf WebServerConfigView) bool) { - parentCont := true - v.Web().Range(func(k HostPort, v WebServerConfigView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { - if !parentCont { - return false + for _, conf := range v.Foreground().All() { + for k, v := range conf.TCP().All() { + if !yield(k, v) { + return + } + } } - v.Web().Range(func(k HostPort, v WebServerConfigView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - return parentCont - }) + } +} + +// Webs returns an iterator over both background and foreground Web configurations. +func (v ServeConfigView) Webs() iter.Seq2[HostPort, WebServerConfigView] { + return func(yield func(HostPort, WebServerConfigView) bool) { + for k, v := range v.Web().All() { + if !yield(k, v) { + return + } + } + for _, conf := range v.Foreground().All() { + for k, v := range conf.Web().All() { + if !yield(k, v) { + return + } + } + } + for _, service := range v.Services().All() { + for k, v := range service.Web().All() { + if !yield(k, v) { + return + } + } + } + } +} + +// FindServiceTCP return the TCPPortHandlerView for the given service name and port. +func (v ServeConfigView) FindServiceTCP(svcName tailcfg.ServiceName, port uint16) (res TCPPortHandlerView, ok bool) { + svcCfg, ok := v.Services().GetOk(svcName) + if !ok { + return res, ok + } + return svcCfg.TCP().GetOk(port) +} + +func (v ServeConfigView) FindServiceWeb(svcName tailcfg.ServiceName, hp HostPort) (res WebServerConfigView, ok bool) { + if svcCfg, ok := v.Services().GetOk(svcName); ok { + if res, ok := svcCfg.Web().GetOk(hp); ok { + return res, ok + } + } + return res, ok } // FindTCP returns the first TCP that matches with the given port. It // prefers a foreground match first followed by a background search if none // existed. func (v ServeConfigView) FindTCP(port uint16) (res TCPPortHandlerView, ok bool) { - v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { - res, ok = v.TCP().GetOk(port) - return !ok - }) - if ok { - return res, ok + for _, conf := range v.Foreground().All() { + if res, ok := conf.TCP().GetOk(port); ok { + return res, ok + } } return v.TCP().GetOk(port) } @@ -603,12 +827,10 @@ func (v ServeConfigView) FindTCP(port uint16) (res TCPPortHandlerView, ok bool) // prefers a foreground match first followed by a background search if none // existed. func (v ServeConfigView) FindWeb(hp HostPort) (res WebServerConfigView, ok bool) { - v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { - res, ok = v.Web().GetOk(hp) - return !ok - }) - if ok { - return res, ok + for _, conf := range v.Foreground().All() { + if res, ok := conf.Web().GetOk(hp); ok { + return res, ok + } } return v.Web().GetOk(hp) } @@ -616,14 +838,15 @@ func (v ServeConfigView) FindWeb(hp HostPort) (res WebServerConfigView, ok bool) // HasAllowFunnel returns whether this config has at least one AllowFunnel // set in the background or foreground configs. func (v ServeConfigView) HasAllowFunnel() bool { - return v.AllowFunnel().Len() > 0 || func() bool { - var exists bool - v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { - exists = v.AllowFunnel().Len() > 0 - return !exists - }) - return exists - }() + if v.AllowFunnel().Len() > 0 { + return true + } + for _, conf := range v.Foreground().All() { + if conf.AllowFunnel().Len() > 0 { + return true + } + } + return false } // FindFunnel reports whether target exists in either the background AllowFunnel @@ -632,12 +855,73 @@ func (v ServeConfigView) HasFunnelForTarget(target HostPort) bool { if v.AllowFunnel().Get(target) { return true } - var exists bool - v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { - if exists = v.AllowFunnel().Get(target); exists { - return false + for _, conf := range v.Foreground().All() { + if conf.AllowFunnel().Get(target) { + return true } - return true - }) - return exists + } + return false +} + +// CheckValidServicesConfig reports whether the ServeConfig has +// invalid service configurations. +func (sc *ServeConfig) CheckValidServicesConfig() error { + for svcName, service := range sc.Services { + if err := service.checkValidConfig(); err != nil { + return fmt.Errorf("invalid service configuration for %q: %w", svcName, err) + } + } + return nil +} + +// ServicePortRange returns the list of tailcfg.ProtoPortRange that represents +// the proto/ports pairs that are being served by the service. +// +// Right now Tun mode is the only thing supports UDP, otherwise serve only supports TCP. +func (v ServiceConfigView) ServicePortRange() []tailcfg.ProtoPortRange { + if v.Tun() { + // If the service is in Tun mode, means service accept TCP/UDP on all ports. + return []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}} + } + tcp := int(ipproto.TCP) + + // Deduplicate the ports. + servePorts := make(set.Set[uint16]) + for port := range v.TCP().All() { + if port > 0 { + servePorts.Add(uint16(port)) + } + } + dedupedServePorts := servePorts.Slice() + slices.Sort(dedupedServePorts) + + var ranges []tailcfg.ProtoPortRange + for _, p := range dedupedServePorts { + if n := len(ranges); n > 0 && p == ranges[n-1].Ports.Last+1 { + ranges[n-1].Ports.Last = p + continue + } + ranges = append(ranges, tailcfg.ProtoPortRange{ + Proto: tcp, + Ports: tailcfg.PortRange{ + First: p, + Last: p, + }, + }) + } + return ranges +} + +// ErrServiceConfigHasBothTCPAndTun signals that a service +// in Tun mode cannot also has TCP or Web handlers set. +var ErrServiceConfigHasBothTCPAndTun = errors.New("the VIP Service configuration can not set TUN at the same time as TCP or Web") + +// checkValidConfig checks if the service configuration is valid. +// Currently, the only invalid configuration is when the service is in Tun mode +// and has TCP or Web handlers. +func (v *ServiceConfig) checkValidConfig() error { + if v.Tun && (len(v.TCP) > 0 || len(v.Web) > 0) { + return ErrServiceConfigHasBothTCPAndTun + } + return nil } diff --git a/ipn/serve_test.go b/ipn/serve_test.go index e9d8e8f32..063ff3a87 100644 --- a/ipn/serve_test.go +++ b/ipn/serve_test.go @@ -1,5 +1,6 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause + package ipn import ( @@ -127,6 +128,121 @@ func TestHasPathHandler(t *testing.T) { } } +func TestIsTCPForwardingOnPort(t *testing.T) { + tests := []struct { + name string + cfg ServeConfig + svcName tailcfg.ServiceName + port uint16 + want bool + }{ + { + name: "empty-config", + cfg: ServeConfig{}, + svcName: "", + port: 80, + want: false, + }, + { + name: "node-tcp-config-match", + cfg: ServeConfig{ + TCP: map[uint16]*TCPPortHandler{80: {TCPForward: "10.0.0.123:3000"}}, + }, + svcName: "", + port: 80, + want: true, + }, + { + name: "node-tcp-config-no-match", + cfg: ServeConfig{ + TCP: map[uint16]*TCPPortHandler{80: {TCPForward: "10.0.0.123:3000"}}, + }, + svcName: "", + port: 443, + want: false, + }, + { + name: "node-tcp-config-no-match-with-service", + cfg: ServeConfig{ + TCP: map[uint16]*TCPPortHandler{80: {TCPForward: "10.0.0.123:3000"}}, + }, + svcName: "svc:bar", + port: 80, + want: false, + }, + { + name: "node-web-config-no-match", + cfg: ServeConfig{ + TCP: map[uint16]*TCPPortHandler{80: {HTTPS: true}}, + Web: map[HostPort]*WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*HTTPHandler{ + "/": {Text: "Hello, world!"}, + }, + }, + }, + }, + svcName: "", + port: 80, + want: false, + }, + { + name: "service-tcp-config-match", + cfg: ServeConfig{ + Services: map[tailcfg.ServiceName]*ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*TCPPortHandler{80: {TCPForward: "10.0.0.123:3000"}}, + }, + }, + }, + svcName: "svc:foo", + port: 80, + want: true, + }, + { + name: "service-tcp-config-no-match", + cfg: ServeConfig{ + Services: map[tailcfg.ServiceName]*ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*TCPPortHandler{80: {TCPForward: "10.0.0.123:3000"}}, + }, + }, + }, + svcName: "svc:bar", + port: 80, + want: false, + }, + { + name: "service-web-config-no-match", + cfg: ServeConfig{ + Services: map[tailcfg.ServiceName]*ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*TCPPortHandler{80: {HTTPS: true}}, + Web: map[HostPort]*WebServerConfig{ + "foo.test.ts.net:80": { + Handlers: map[string]*HTTPHandler{ + "/": {Text: "Hello, world!"}, + }, + }, + }, + }, + }, + }, + svcName: "svc:foo", + port: 80, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.cfg.IsTCPForwardingOnPort(tt.port, tt.svcName) + if tt.want != got { + t.Errorf("IsTCPForwardingOnPort() = %v, want %v", got, tt.want) + } + }) + } +} + func TestExpandProxyTargetDev(t *testing.T) { tests := []struct { name string @@ -144,12 +260,16 @@ func TestExpandProxyTargetDev(t *testing.T) { {name: "https+insecure-scheme", input: "https+insecure://localhost:8080", expected: "https+insecure://localhost:8080"}, {name: "change-default-scheme", input: "localhost:8080", defaultScheme: "https", expected: "https://localhost:8080"}, {name: "change-supported-schemes", input: "localhost:8080", defaultScheme: "tcp", supportedSchemes: []string{"tcp"}, expected: "tcp://localhost:8080"}, + {name: "remote-target", input: "https://example.com:8080", expected: "https://example.com:8080"}, + {name: "remote-IP-target", input: "http://120.133.20.2:8080", expected: "http://120.133.20.2:8080"}, + {name: "remote-target-no-port", input: "https://example.com", expected: "https://example.com"}, // errors {name: "invalid-port", input: "localhost:9999999", wantErr: true}, + {name: "invalid-hostname", input: "192.168.1:8080", wantErr: true}, {name: "unsupported-scheme", input: "ftp://localhost:8080", expected: "", wantErr: true}, - {name: "not-localhost", input: "https://tailscale.com:8080", expected: "", wantErr: true}, {name: "empty-input", input: "", expected: "", wantErr: true}, + {name: "localhost-no-port", input: "localhost", expected: "", wantErr: true}, } for _, tt := range tests { @@ -182,3 +302,88 @@ func TestExpandProxyTargetDev(t *testing.T) { }) } } + +func TestIsFunnelOn(t *testing.T) { + tests := []struct { + name string + sc *ServeConfig + want bool + }{ + { + name: "nil_config", + }, + { + name: "empty_config", + sc: &ServeConfig{}, + }, + { + name: "funnel_enabled_in_background", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + want: true, + }, + { + name: "funnel_disabled_in_background", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + }, + { + name: "funnel_enabled_in_foreground", + sc: &ServeConfig{ + Foreground: map[string]*ServeConfig{ + "abc123": { + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + }, + }, + want: true, + }, + { + name: "funnel_disabled_in_both", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": false, + }, + Foreground: map[string]*ServeConfig{ + "abc123": { + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:8443": false, + }, + }, + }, + }, + }, + { + name: "funnel_enabled_in_both", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": true, + }, + Foreground: map[string]*ServeConfig{ + "abc123": { + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:8443": true, + }, + }, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.sc.IsFunnelOn(); got != tt.want { + t.Errorf("ServeConfig.IsFunnelOn() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ipn/store.go b/ipn/store.go index 550aa8cba..9da5288c0 100644 --- a/ipn/store.go +++ b/ipn/store.go @@ -113,3 +113,9 @@ func ReadStoreInt(store StateStore, id StateKey) (int64, error) { func PutStoreInt(store StateStore, id StateKey, val int64) error { return WriteState(store, id, fmt.Appendf(nil, "%d", val)) } + +// EncryptedStateStore is a marker interface implemented by StateStores that +// encrypt data at rest. +type EncryptedStateStore interface { + stateStoreIsEncrypted() +} diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index 0fb78d45a..78b72d0bc 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux && !ts_omit_aws +//go:build !ts_omit_aws // Package awsstore contains an ipn.StateStore implementation using AWS SSM. package awsstore @@ -10,7 +10,9 @@ import ( "context" "errors" "fmt" + "net/url" "regexp" + "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" @@ -18,16 +20,35 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssm" ssmTypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" "tailscale.com/ipn" + "tailscale.com/ipn/store" "tailscale.com/ipn/store/mem" "tailscale.com/types/logger" ) +func init() { + store.Register("arn:", func(logf logger.Logf, arg string) (ipn.StateStore, error) { + ssmARN, opts, err := ParseARNAndOpts(arg) + if err != nil { + return nil, err + } + return New(logf, ssmARN, opts...) + }) +} + const ( parameterNameRxStr = `^parameter(/.*)` ) var parameterNameRx = regexp.MustCompile(parameterNameRxStr) +// Option defines a functional option type for configuring awsStore. +type Option func(*storeOptions) + +// storeOptions holds optional settings for creating a new awsStore. +type storeOptions struct { + kmsKey string +} + // awsSSMClient is an interface allowing us to mock the couple of // API calls we are leveraging with the AWSStore provider type awsSSMClient interface { @@ -46,6 +67,10 @@ type awsStore struct { ssmClient awsSSMClient ssmARN arn.ARN + // kmsKey is optional. If empty, the parameter is stored in plaintext. + // If non-empty, the parameter is encrypted with this KMS key. + kmsKey string + memory mem.Store } @@ -57,30 +82,80 @@ type awsStore struct { // Tailscaled to only only store new state in-memory and // restarting Tailscaled can fail until you delete your state // from the AWS Parameter Store. -func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { - return newStore(ssmARN, nil) +// +// If you want to specify an optional KMS key, +// pass one or more Option objects, e.g. awsstore.WithKeyID("alias/my-key"). +func New(_ logger.Logf, ssmARN string, opts ...Option) (ipn.StateStore, error) { + // Apply all options to an empty storeOptions + var so storeOptions + for _, opt := range opts { + opt(&so) + } + + return newStore(ssmARN, so, nil) +} + +// WithKeyID sets the KMS key to be used for encryption. It can be +// a KeyID, an alias ("alias/my-key"), or a full ARN. +// +// If kmsKey is empty, the Option is a no-op. +func WithKeyID(kmsKey string) Option { + return func(o *storeOptions) { + o.kmsKey = kmsKey + } +} + +// ParseARNAndOpts parses an ARN and optional URL-encoded parameters +// from arg. +func ParseARNAndOpts(arg string) (ssmARN string, opts []Option, err error) { + ssmARN = arg + + // Support optional ?url-encoded-parameters. + if s, q, ok := strings.Cut(arg, "?"); ok { + ssmARN = s + q, err := url.ParseQuery(q) + if err != nil { + return "", nil, err + } + + for k := range q { + switch k { + default: + return "", nil, fmt.Errorf("unknown arn option parameter %q", k) + case "kmsKey": + // We allow an ARN, a key ID, or an alias name for kmsKeyID. + // If it doesn't look like an ARN and doesn't have a '/', + // prepend "alias/" for KMS alias references. + kmsKey := q.Get(k) + if kmsKey != "" && + !strings.Contains(kmsKey, "/") && + !strings.HasPrefix(kmsKey, "arn:") { + kmsKey = "alias/" + kmsKey + } + if kmsKey != "" { + opts = append(opts, WithKeyID(kmsKey)) + } + } + } + } + return ssmARN, opts, nil } // newStore is NewStore, but for tests. If client is non-nil, it's // used instead of making one. -func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { +func newStore(ssmARN string, so storeOptions, client awsSSMClient) (ipn.StateStore, error) { s := &awsStore{ ssmClient: client, + kmsKey: so.kmsKey, } var err error - - // Parse the ARN if s.ssmARN, err = arn.Parse(ssmARN); err != nil { return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) } - - // Validate the ARN corresponds to the SSM service if s.ssmARN.Service != "ssm" { return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) } - - // Validate the ARN corresponds to a parameter store resource if !parameterNameRx.MatchString(s.ssmARN.Resource) { return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) } @@ -96,12 +171,11 @@ func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { s.ssmClient = ssm.NewFromConfig(cfg) } - // Hydrate cache with the potentially current state + // Preload existing state, if any if err := s.LoadState(); err != nil { return nil, err } return s, nil - } // LoadState attempts to read the state from AWS SSM parameter store key. @@ -172,15 +246,21 @@ func (s *awsStore) persistState() error { // which is free. However, if it exceeds 4kb it switches the parameter to advanced tiering // doubling the capacity to 8kb per the following docs: // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ - _, err = s.ssmClient.PutParameter( - context.TODO(), - &ssm.PutParameterInput{ - Name: aws.String(s.ParameterName()), - Value: aws.String(string(bs)), - Overwrite: aws.Bool(true), - Tier: ssmTypes.ParameterTierIntelligentTiering, - Type: ssmTypes.ParameterTypeSecureString, - }, - ) + in := &ssm.PutParameterInput{ + Name: aws.String(s.ParameterName()), + Value: aws.String(string(bs)), + Overwrite: aws.Bool(true), + Tier: ssmTypes.ParameterTierIntelligentTiering, + Type: ssmTypes.ParameterTypeSecureString, + } + + // If kmsKey is specified, encrypt with that key + // NOTE: this input allows any alias, keyID or ARN + // If this isn't specified, AWS will use the default KMS key + if s.kmsKey != "" { + in.KeyId = aws.String(s.kmsKey) + } + + _, err = s.ssmClient.PutParameter(context.TODO(), in) return err } diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go deleted file mode 100644 index 8d2156ce9..000000000 --- a/ipn/store/awsstore/store_aws_stub.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_aws - -package awsstore - -import ( - "fmt" - "runtime" - - "tailscale.com/ipn" - "tailscale.com/types/logger" -) - -func New(logger.Logf, string) (ipn.StateStore, error) { - return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) -} diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go index f6c8fedb3..3cc23e48d 100644 --- a/ipn/store/awsstore/store_aws_test.go +++ b/ipn/store/awsstore/store_aws_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build !ts_omit_aws package awsstore @@ -65,7 +65,11 @@ func TestNewAWSStore(t *testing.T) { Resource: "parameter/foo", } - s, err := newStore(storeParameterARN.String(), mc) + opts := storeOptions{ + kmsKey: "arn:aws:kms:eu-west-1:123456789:key/MyCustomKey", + } + + s, err := newStore(storeParameterARN.String(), opts, mc) if err != nil { t.Fatalf("creating aws store failed: %v", err) } @@ -73,7 +77,7 @@ func TestNewAWSStore(t *testing.T) { // Build a brand new file store and check that both IDs written // above are still there. - s2, err := newStore(storeParameterARN.String(), mc) + s2, err := newStore(storeParameterARN.String(), opts, mc) if err != nil { t.Fatalf("creating second aws store failed: %v", err) } @@ -162,3 +166,54 @@ func testStoreSemantics(t *testing.T, store ipn.StateStore) { } } } + +func TestParseARNAndOpts(t *testing.T) { + tests := []struct { + name string + arg string + wantARN string + wantKey string + }{ + { + name: "no-key", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + }, + { + name: "custom-key", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=alias/MyCustomKey", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "alias/MyCustomKey", + }, + { + name: "bare-name", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=Bare", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "alias/Bare", + }, + { + name: "arn-arg", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=arn:foo", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "arn:foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arn, opts, err := ParseARNAndOpts(tt.arg) + if err != nil { + t.Fatalf("New: %v", err) + } + if arn != tt.wantARN { + t.Errorf("ARN = %q; want %q", arn, tt.wantARN) + } + var got storeOptions + for _, opt := range opts { + opt(&got) + } + if got.kmsKey != tt.wantKey { + t.Errorf("kmsKey = %q; want %q", got.kmsKey, tt.wantKey) + } + }) + } +} diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index 00950bd3b..f48237c05 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -7,27 +7,75 @@ package kubestore import ( "context" "fmt" + "log" "net" + "net/http" "os" "strings" "time" + "tailscale.com/envknob" "tailscale.com/ipn" + "tailscale.com/ipn/store" + "tailscale.com/ipn/store/mem" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" +) + +func init() { + store.Register("kube:", func(logf logger.Logf, path string) (ipn.StateStore, error) { + secretName := strings.TrimPrefix(path, "kube:") + return New(logf, secretName) + }) +} + +const ( + // timeout is the timeout for a single state update that includes calls to the API server to write or read a + // state Secret and emit an Event. + timeout = 30 * time.Second + + reasonTailscaleStateUpdated = "TailscaledStateUpdated" + reasonTailscaleStateLoaded = "TailscaleStateLoaded" + reasonTailscaleStateUpdateFailed = "TailscaleStateUpdateFailed" + reasonTailscaleStateLoadFailed = "TailscaleStateLoadFailed" + eventTypeWarning = "Warning" + eventTypeNormal = "Normal" + + keyTLSCert = "tls.crt" + keyTLSKey = "tls.key" ) // Store is an ipn.StateStore that uses a Kubernetes Secret for persistence. type Store struct { - client kubeclient.Client - canPatch bool - secretName string + client kubeclient.Client + canPatch bool + secretName string // state Secret + certShareMode string // 'ro', 'rw', or empty + podName string + + // memory holds the latest tailscale state. Writes write state to a kube + // Secret and memory, Reads read from memory. + memory mem.Store } -// New returns a new Store that persists to the named secret. -func New(_ logger.Logf, secretName string) (*Store, error) { - c, err := kubeclient.New() +// New returns a new Store that persists state to Kubernets Secret(s). +// Tailscale state is stored in a Secret named by the secretName parameter. +// TLS certs are stored and retrieved from state Secret or separate Secrets +// named after TLS endpoints if running in cert share mode. +func New(logf logger.Logf, secretName string) (*Store, error) { + c, err := newClient() + if err != nil { + return nil, err + } + return newWithClient(logf, c, secretName) +} + +func newClient() (kubeclient.Client, error) { + c, err := kubeclient.New("tailscale-state-store") if err != nil { return nil, err } @@ -35,15 +83,43 @@ func New(_ logger.Logf, secretName string) (*Store, error) { // Derive the API server address from the environment variables c.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) } + return c, nil +} + +func newWithClient(logf logger.Logf, c kubeclient.Client, secretName string) (*Store, error) { canPatch, _, err := c.CheckSecretPermissions(context.Background(), secretName) if err != nil { return nil, err } - return &Store{ + s := &Store{ client: c, canPatch: canPatch, secretName: secretName, - }, nil + podName: os.Getenv("POD_NAME"), + } + if envknob.IsCertShareReadWriteMode() { + s.certShareMode = "rw" + } else if envknob.IsCertShareReadOnlyMode() { + s.certShareMode = "ro" + } + + // Load latest state from kube Secret if it already exists. + if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist { + return nil, fmt.Errorf("error loading state from kube Secret: %w", err) + } + // If we are in cert share mode, pre-load existing shared certs. + if s.certShareMode == "rw" || s.certShareMode == "ro" { + sel := s.certSecretSelector() + if err := s.loadCerts(context.Background(), sel); err != nil { + // We will attempt to again retrieve the certs from Secrets when a request for an HTTPS endpoint + // is received. + log.Printf("[unexpected] error loading TLS certs: %v", err) + } + } + if s.certShareMode == "ro" { + go s.runCertReload(context.Background(), logf) + } + return s, nil } func (s *Store) SetDialer(d func(ctx context.Context, network, address string) (net.Conn, error)) { @@ -54,86 +130,327 @@ func (s *Store) String() string { return "kube.Store" } // ReadState implements the StateStore interface. func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + return s.memory.ReadState(ipn.StateKey(sanitizeKey(id))) +} - secret, err := s.client.GetSecret(ctx, s.secretName) - if err != nil { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { - return nil, ipn.ErrStateNotExist +// WriteState implements the StateStore interface. +func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { + defer func() { + if err == nil { + s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs) + } + }() + return s.updateSecret(map[string][]byte{string(id): bs}, s.secretName) +} + +// WriteTLSCertAndKey writes a TLS cert and key to domain.crt, domain.key fields +// of a Tailscale Kubernetes node's state Secret. +func (s *Store) WriteTLSCertAndKey(domain string, cert, key []byte) (err error) { + if s.certShareMode == "ro" { + log.Printf("[unexpected] TLS cert and key write in read-only mode") + } + if err := dnsname.ValidHostname(domain); err != nil { + return fmt.Errorf("invalid domain name %q: %w", domain, err) + } + secretName := s.secretName + data := map[string][]byte{ + domain + ".crt": cert, + domain + ".key": key, + } + // If we run in cert share mode, cert and key for a DNS name are written + // to a separate Secret. + if s.certShareMode == "rw" { + secretName = domain + data = map[string][]byte{ + keyTLSCert: cert, + keyTLSKey: key, } - return nil, err } - b, ok := secret.Data[sanitizeKey(id)] - if !ok { - return nil, ipn.ErrStateNotExist + if err := s.updateSecret(data, secretName); err != nil { + return fmt.Errorf("error writing TLS cert and key to Secret: %w", err) } - return b, nil + // TODO(irbekrm): certs for write replicas are currently not + // written to memory to avoid out of sync memory state after + // Ingress resources have been recreated. This means that TLS + // certs for write replicas are retrieved from the Secret on + // each HTTPS request. This is a temporary solution till we + // implement a Secret watch. + if s.certShareMode != "rw" { + s.memory.WriteState(ipn.StateKey(domain+".crt"), cert) + s.memory.WriteState(ipn.StateKey(domain+".key"), key) + } + return nil } -func sanitizeKey(k ipn.StateKey) string { - // The only valid characters in a Kubernetes secret key are alphanumeric, -, - // _, and . - return strings.Map(func(r rune) rune { - if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' { - return r +// ReadTLSCertAndKey reads a TLS cert and key from memory or from a +// domain-specific Secret. It first checks the in-memory store, if not found in +// memory and running cert store in read-only mode, looks up a Secret. +// Note that write replicas of HA Ingress always retrieve TLS certs from Secrets. +func (s *Store) ReadTLSCertAndKey(domain string) (cert, key []byte, err error) { + if err := dnsname.ValidHostname(domain); err != nil { + return nil, nil, fmt.Errorf("invalid domain name %q: %w", domain, err) + } + certKey := domain + ".crt" + keyKey := domain + ".key" + cert, err = s.memory.ReadState(ipn.StateKey(certKey)) + if err == nil { + key, err = s.memory.ReadState(ipn.StateKey(keyKey)) + if err == nil { + return cert, key, nil } - return '_' - }, string(k)) -} + } + if s.certShareMode == "" { + return nil, nil, ipn.ErrStateNotExist + } -// WriteState implements the StateStore interface. -func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - - secret, err := s.client.GetSecret(ctx, s.secretName) + secret, err := s.client.GetSecret(ctx, domain) if err != nil { if kubeclient.IsNotFoundErr(err) { + // TODO(irbekrm): we should return a more specific error + // that wraps ipn.ErrStateNotExist here. + return nil, nil, ipn.ErrStateNotExist + } + st, ok := err.(*kubeapi.Status) + if ok && st.Code == http.StatusForbidden && (s.certShareMode == "ro" || s.certShareMode == "rw") { + // In cert share mode, we read from a dedicated Secret per domain. + // To get here, we already had a cache miss from our in-memory + // store. For write replicas, that means it wasn't available on + // start and it wasn't written since. For read replicas, that means + // it wasn't available on start and it hasn't been reloaded in the + // background. So getting a "forbidden" error is an expected + // "not found" case where we've been asked for a cert we don't + // expect to issue, and so the forbidden error reflects that the + // operator didn't assign permission for a Secret for that domain. + // + // This code path gets triggered by the admin UI's machine page, + // which queries for the node's own TLS cert existing via the + // "tls-cert-status" c2n API. + return nil, nil, ipn.ErrStateNotExist + } + return nil, nil, fmt.Errorf("getting TLS Secret %q: %w", domain, err) + } + cert = secret.Data[keyTLSCert] + key = secret.Data[keyTLSKey] + if len(cert) == 0 || len(key) == 0 { + return nil, nil, ipn.ErrStateNotExist + } + // TODO(irbekrm): a read between these two separate writes would + // get a mismatched cert and key. Allow writing both cert and + // key to the memory store in a single, lock-protected operation. + // + // TODO(irbekrm): currently certs for write replicas of HA Ingress get + // retrieved from the cluster Secret on each HTTPS request to avoid a + // situation when after Ingress recreation stale certs are read from + // memory. + // Fix this by watching Secrets to ensure that memory store gets updated + // when Secrets are deleted. + if s.certShareMode == "ro" { + s.memory.WriteState(ipn.StateKey(certKey), cert) + s.memory.WriteState(ipn.StateKey(keyKey), key) + } + return cert, key, nil +} + +func (s *Store) updateSecret(data map[string][]byte, secretName string) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer func() { + if err != nil { + if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateUpdateFailed, err.Error()); err != nil { + log.Printf("kubestore: error creating tailscaled state update Event: %v", err) + } + } else { + if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateUpdated, "Successfully updated tailscaled state Secret"); err != nil { + log.Printf("kubestore: error creating tailscaled state Event: %v", err) + } + } + cancel() + }() + secret, err := s.client.GetSecret(ctx, secretName) + if err != nil { + // If the Secret does not exist, create it with the required data. + if kubeclient.IsNotFoundErr(err) && s.canCreateSecret(secretName) { return s.client.CreateSecret(ctx, &kubeapi.Secret{ TypeMeta: kubeapi.TypeMeta{ APIVersion: "v1", Kind: "Secret", }, ObjectMeta: kubeapi.ObjectMeta{ - Name: s.secretName, - }, - Data: map[string][]byte{ - sanitizeKey(id): bs, + Name: secretName, }, + Data: func(m map[string][]byte) map[string][]byte { + d := make(map[string][]byte, len(m)) + for key, val := range m { + d[sanitizeKey(key)] = val + } + return d + }(data), }) } - return err + return fmt.Errorf("error getting Secret %s: %w", secretName, err) } - if s.canPatch { - if len(secret.Data) == 0 { // if user has pre-created a blank Secret - m := []kubeclient.JSONPatch{ + if s.canPatchSecret(secretName) { + var m []kubeclient.JSONPatch + // If the user has pre-created a Secret with no data, we need to ensure the top level /data field. + if len(secret.Data) == 0 { + m = []kubeclient.JSONPatch{ { - Op: "add", - Path: "/data", - Value: map[string][]byte{sanitizeKey(id): bs}, + Op: "add", + Path: "/data", + Value: func(m map[string][]byte) map[string][]byte { + d := make(map[string][]byte, len(m)) + for key, val := range m { + d[sanitizeKey(key)] = val + } + return d + }(data), }, } - if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { - return fmt.Errorf("error patching Secret %s with a /data field: %v", s.secretName, err) + // If the Secret has data, patch it with the new data. + } else { + for key, val := range data { + m = append(m, kubeclient.JSONPatch{ + Op: "add", + Path: "/data/" + sanitizeKey(key), + Value: val, + }) } - return nil - } - m := []kubeclient.JSONPatch{ - { - Op: "add", - Path: "/data/" + sanitizeKey(id), - Value: bs, - }, } - if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { - return fmt.Errorf("error patching Secret %s with /data/%s field", s.secretName, sanitizeKey(id)) + if err := s.client.JSONPatchResource(ctx, secretName, kubeclient.TypeSecrets, m); err != nil { + return fmt.Errorf("error patching Secret %s: %w", secretName, err) } return nil } - secret.Data[sanitizeKey(id)] = bs + // No patch permissions, use UPDATE instead. + for key, val := range data { + mak.Set(&secret.Data, sanitizeKey(key), val) + } if err := s.client.UpdateSecret(ctx, secret); err != nil { + return fmt.Errorf("error updating Secret %s: %w", s.secretName, err) + } + return nil +} + +func (s *Store) loadState() (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + secret, err := s.client.GetSecret(ctx, s.secretName) + if err != nil { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { + return ipn.ErrStateNotExist + } + if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateLoadFailed, err.Error()); err != nil { + log.Printf("kubestore: error creating Event: %v", err) + } return err } - return err + if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateLoaded, "Successfully loaded tailscaled state from Secret"); err != nil { + log.Printf("kubestore: error creating Event: %v", err) + } + s.memory.LoadFromMap(secret.Data) + return nil +} + +// runCertReload relists and reloads all TLS certs for endpoints shared by this +// node from Secrets other than the state Secret to ensure that renewed certs get eventually loaded. +// It is not critical to reload a cert immediately after +// renewal, so a daily check is acceptable. +// Currently (3/2025) this is only used for the shared HA Ingress certs on 'read' replicas. +// Note that if shared certs are not found in memory on an HTTPS request, we +// do a Secret lookup, so this mechanism does not need to ensure that newly +// added Ingresses' certs get loaded. +func (s *Store) runCertReload(ctx context.Context, logf logger.Logf) { + ticker := time.NewTicker(time.Hour * 24) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sel := s.certSecretSelector() + if err := s.loadCerts(ctx, sel); err != nil { + logf("[unexpected] error reloading TLS certs: %v", err) + } + } + } +} + +// loadCerts lists all Secrets matching the provided selector and loads TLS +// certs and keys from those. +func (s *Store) loadCerts(ctx context.Context, sel map[string]string) error { + ss, err := s.client.ListSecrets(ctx, sel) + if err != nil { + return fmt.Errorf("error listing TLS Secrets: %w", err) + } + for _, secret := range ss.Items { + if !hasTLSData(&secret) { + continue + } + // Only load secrets that have valid domain names (ending in .ts.net) + if !strings.HasSuffix(secret.Name, ".ts.net") { + continue + } + s.memory.WriteState(ipn.StateKey(secret.Name)+".crt", secret.Data[keyTLSCert]) + s.memory.WriteState(ipn.StateKey(secret.Name)+".key", secret.Data[keyTLSKey]) + } + return nil +} + +// canCreateSecret returns true if this node should be allowed to create the given +// Secret in its namespace. +func (s *Store) canCreateSecret(secret string) bool { + // Only allow creating the state Secret (and not TLS Secrets). + return secret == s.secretName +} + +// canPatchSecret returns true if this node should be allowed to patch the given +// Secret. +func (s *Store) canPatchSecret(secret string) bool { + // For backwards compatibility reasons, setups where the proxies are not + // given PATCH permissions for state Secrets are allowed. For TLS + // Secrets, we should always have PATCH permissions. + if secret == s.secretName { + return s.canPatch + } + return true +} + +// certSecretSelector returns a label selector that can be used to list all +// Secrets that aren't Tailscale state Secrets and contain TLS certificates for +// HTTPS endpoints that this node serves. +// Currently (7/2025) this only applies to the Kubernetes Operator's ProxyGroup +// when spec.Type is "ingress" or "kube-apiserver". +func (s *Store) certSecretSelector() map[string]string { + if s.podName == "" { + return map[string]string{} + } + p := strings.LastIndex(s.podName, "-") + if p == -1 { + return map[string]string{} + } + pgName := s.podName[:p] + return map[string]string{ + kubetypes.LabelSecretType: kubetypes.LabelSecretTypeCerts, + kubetypes.LabelManaged: "true", + "tailscale.com/proxy-group": pgName, + } +} + +// hasTLSData returns true if the provided Secret contains non-empty TLS cert and key. +func hasTLSData(s *kubeapi.Secret) bool { + return len(s.Data[keyTLSCert]) != 0 && len(s.Data[keyTLSKey]) != 0 +} + +// sanitizeKey converts any value that can be converted to a string into a valid Kubernetes Secret key. +// Valid characters are alphanumeric, -, _, and . +// https://kubernetes.io/docs/concepts/configuration/secret/#restriction-names-data. +func sanitizeKey[T ~string](k T) string { + return strings.Map(func(r rune) rune { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' { + return r + } + return '_' + }, string(k)) } diff --git a/ipn/store/mem/store_mem.go b/ipn/store/mem/store_mem.go index f3a308ae5..6f474ce99 100644 --- a/ipn/store/mem/store_mem.go +++ b/ipn/store/mem/store_mem.go @@ -9,8 +9,10 @@ import ( "encoding/json" "sync" + xmaps "golang.org/x/exp/maps" "tailscale.com/ipn" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) // New returns a new Store. @@ -28,6 +30,7 @@ type Store struct { func (s *Store) String() string { return "mem.Store" } // ReadState implements the StateStore interface. +// It returns ipn.ErrStateNotExist if the state does not exist. func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { s.mu.Lock() defer s.mu.Unlock() @@ -39,6 +42,7 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { } // WriteState implements the StateStore interface. +// It never returns an error. func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { s.mu.Lock() defer s.mu.Unlock() @@ -49,6 +53,19 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { return nil } +// LoadFromMap loads the in-memory cache from the provided map. +// Any existing content is cleared, and the provided map is +// copied into the cache. +func (s *Store) LoadFromMap(m map[string][]byte) { + s.mu.Lock() + defer s.mu.Unlock() + xmaps.Clear(s.cache) + for k, v := range m { + mak.Set(&s.cache, ipn.StateKey(k), v) + } + return +} + // LoadFromJSON attempts to unmarshal json content into the // in-memory cache. func (s *Store) LoadFromJSON(data []byte) error { diff --git a/ipn/store/store_aws.go b/ipn/store/store_aws.go deleted file mode 100644 index e164f9de7..000000000 --- a/ipn/store/store_aws.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (ts_aws || (linux && (arm64 || amd64))) && !ts_omit_aws - -package store - -import ( - "tailscale.com/ipn/store/awsstore" -) - -func init() { - registerAvailableExternalStores = append(registerAvailableExternalStores, registerAWSStore) -} - -func registerAWSStore() { - Register("arn:", awsstore.New) -} diff --git a/ipn/store/store_kube.go b/ipn/store/store_kube.go deleted file mode 100644 index 8941620f6..000000000 --- a/ipn/store/store_kube.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build (ts_kube || (linux && (arm64 || amd64))) && !ts_omit_kube - -package store - -import ( - "strings" - - "tailscale.com/ipn" - "tailscale.com/ipn/store/kubestore" - "tailscale.com/types/logger" -) - -func init() { - registerAvailableExternalStores = append(registerAvailableExternalStores, registerKubeStore) -} - -func registerKubeStore() { - Register("kube:", func(logf logger.Logf, path string) (ipn.StateStore, error) { - secretName := strings.TrimPrefix(path, "kube:") - return kubestore.New(logf, secretName) - }) -} diff --git a/ipn/store/stores.go b/ipn/store/stores.go index 1a87fc548..bf175da41 100644 --- a/ipn/store/stores.go +++ b/ipn/store/stores.go @@ -7,10 +7,14 @@ package store import ( "bytes" "encoding/json" + "errors" "fmt" + "iter" + "maps" "os" "path/filepath" "runtime" + "slices" "strings" "sync" @@ -20,26 +24,22 @@ import ( "tailscale.com/paths" "tailscale.com/types/logger" "tailscale.com/util/mak" + "tailscale.com/util/testenv" ) // Provider returns a StateStore for the provided path. // The arg is of the form "prefix:rest", where prefix was previously registered with Register. type Provider func(logf logger.Logf, arg string) (ipn.StateStore, error) -var regOnce sync.Once - -var registerAvailableExternalStores []func() - -func registerDefaultStores() { +func init() { Register("mem:", mem.New) - - for _, f := range registerAvailableExternalStores { - f() - } } var knownStores map[string]Provider +// TPMPrefix is the path prefix used for TPM-encrypted StateStore. +const TPMPrefix = "tpmseal:" + // New returns a StateStore based on the provided arg // and registered stores. // The arg is of the form "prefix:rest", where prefix was previously @@ -53,19 +53,31 @@ var knownStores map[string]Provider // the suffix an AWS ARN for an SSM. // - (Linux-only) if the string begins with "kube:", // the suffix is a Kubernetes secret name +// - (Linux or Windows) if the string begins with "tpmseal:", the suffix is +// filepath that is sealed with the local TPM device. // - In all other cases, the path is treated as a filepath. func New(logf logger.Logf, path string) (ipn.StateStore, error) { - regOnce.Do(registerDefaultStores) for prefix, sf := range knownStores { if strings.HasPrefix(path, prefix) { // We can't strip the prefix here as some NewStoreFunc (like arn:) // expect the prefix. + if prefix == TPMPrefix { + if runtime.GOOS == "windows" { + path = TPMPrefix + TryWindowsAppDataMigration(logf, strings.TrimPrefix(path, TPMPrefix)) + } + if err := maybeMigrateLocalStateFile(logf, path); err != nil { + return nil, fmt.Errorf("failed to migrate existing state file to TPM-sealed format: %w", err) + } + } return sf(logf, path) } } if runtime.GOOS == "windows" { path = TryWindowsAppDataMigration(logf, path) } + if err := maybeMigrateLocalStateFile(logf, path); err != nil { + return nil, fmt.Errorf("failed to migrate existing TPM-sealed state file to plaintext format: %w", err) + } return NewFileStore(logf, path) } @@ -84,6 +96,29 @@ func Register(prefix string, fn Provider) { mak.Set(&knownStores, prefix, fn) } +// RegisterForTest registers a prefix to be used for NewStore in tests. An +// existing registered prefix will be replaced. +func RegisterForTest(t testenv.TB, prefix string, fn Provider) { + if len(prefix) == 0 { + panic("prefix is empty") + } + old := maps.Clone(knownStores) + t.Cleanup(func() { knownStores = old }) + + mak.Set(&knownStores, prefix, fn) +} + +// HasKnownProviderPrefix reports whether path uses one of the registered +// Provider prefixes. +func HasKnownProviderPrefix(path string) bool { + for prefix := range knownStores { + if strings.HasPrefix(path, prefix) { + return true + } + } + return false +} + // TryWindowsAppDataMigration attempts to copy the Windows state file // from its old location to the new location. (Issue 2856) // @@ -186,3 +221,123 @@ func (s *FileStore) WriteState(id ipn.StateKey, bs []byte) error { } return atomicfile.WriteFile(s.path, bs, 0600) } + +func (s *FileStore) All() iter.Seq2[ipn.StateKey, []byte] { + return func(yield func(ipn.StateKey, []byte) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + for k, v := range s.cache { + if !yield(k, v) { + break + } + } + } +} + +// Ensure FileStore implements ExportableStore for migration to/from +// tpm.tpmStore. +var _ ExportableStore = (*FileStore)(nil) + +// ExportableStore is an ipn.StateStore that can export all of its contents. +// This interface is optional to implement, and used for migrating the state +// between different store implementations. +type ExportableStore interface { + ipn.StateStore + + // All returns an iterator over all store keys. Using ReadState or + // WriteState is not safe while iterating and can lead to a deadlock. The + // order of keys in the iterator is not specified and may change between + // runs. + All() iter.Seq2[ipn.StateKey, []byte] +} + +func maybeMigrateLocalStateFile(logf logger.Logf, path string) error { + path, toTPM := strings.CutPrefix(path, TPMPrefix) + + // Extract JSON keys from the file on disk and guess what kind it is. + bs, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil + } + return err + } + var content map[string]any + if err := json.Unmarshal(bs, &content); err != nil { + return fmt.Errorf("failed to unmarshal %q: %w", path, err) + } + keys := slices.Sorted(maps.Keys(content)) + tpmKeys := []string{"key", "nonce", "data"} + slices.Sort(tpmKeys) + // TPM-sealed files will have exactly these keys. + existingFileSealed := slices.Equal(keys, tpmKeys) + // Plaintext files for nodes that registered at least once will have this + // key, plus other dynamic ones. + _, existingFilePlaintext := content["_machinekey"] + isTPM := existingFileSealed && !existingFilePlaintext + + if isTPM == toTPM { + // No migration needed. + return nil + } + + newTPMStore, ok := knownStores[TPMPrefix] + if !ok { + return errors.New("this build does not support TPM integration") + } + + // Open from (old format) and to (new format) stores for migration. The + // "to" store will be at tmpPath. + var from, to ipn.StateStore + tmpPath := path + ".tmp" + if toTPM { + // Migrate plaintext file to be TPM-sealed. + from, err = NewFileStore(logf, path) + if err != nil { + return fmt.Errorf("NewFileStore(%q): %w", path, err) + } + to, err = newTPMStore(logf, TPMPrefix+tmpPath) + if err != nil { + return fmt.Errorf("newTPMStore(%q): %w", tmpPath, err) + } + } else { + // Migrate TPM-selaed file to plaintext. + from, err = newTPMStore(logf, TPMPrefix+path) + if err != nil { + return fmt.Errorf("newTPMStore(%q): %w", path, err) + } + to, err = NewFileStore(logf, tmpPath) + if err != nil { + return fmt.Errorf("NewFileStore(%q): %w", tmpPath, err) + } + } + defer os.Remove(tmpPath) + + fromExp, ok := from.(ExportableStore) + if !ok { + return fmt.Errorf("%T does not implement the exportableStore interface", from) + } + + // Copy all the items. This is pretty inefficient, because both stores + // write the file to disk for each WriteState, but that's ok for a one-time + // migration. + for k, v := range fromExp.All() { + if err := to.WriteState(k, v); err != nil { + return err + } + } + + // Finally, overwrite the state file with the new one we created at + // tmpPath. + if err := atomicfile.Rename(tmpPath, path); err != nil { + return err + } + + if toTPM { + logf("migrated %q from plaintext to TPM-sealed format", path) + } else { + logf("migrated %q from TPM-sealed to plaintext format", path) + } + return nil +} diff --git a/ipn/store/stores_test.go b/ipn/store/stores_test.go index ea09e6ea6..1f0fc0fef 100644 --- a/ipn/store/stores_test.go +++ b/ipn/store/stores_test.go @@ -4,6 +4,7 @@ package store import ( + "maps" "path/filepath" "testing" @@ -14,10 +15,9 @@ import ( ) func TestNewStore(t *testing.T) { - regOnce.Do(registerDefaultStores) + oldKnownStores := maps.Clone(knownStores) t.Cleanup(func() { - knownStores = map[string]Provider{} - registerDefaultStores() + knownStores = oldKnownStores }) knownStores = map[string]Provider{} diff --git a/ipn/store_test.go b/ipn/store_test.go index fcc082d8a..4dd7321b9 100644 --- a/ipn/store_test.go +++ b/ipn/store_test.go @@ -5,6 +5,7 @@ package ipn import ( "bytes" + "iter" "sync" "testing" @@ -31,6 +32,19 @@ func (s *memStore) WriteState(k StateKey, v []byte) error { return nil } +func (s *memStore) All() iter.Seq2[StateKey, []byte] { + return func(yield func(StateKey, []byte) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + for k, v := range s.m { + if !yield(k, v) { + break + } + } + } +} + func TestWriteState(t *testing.T) { var ss StateStore = new(memStore) WriteState(ss, "foo", []byte("bar")) diff --git a/k8s-operator/api-proxy/doc.go b/k8s-operator/api-proxy/doc.go new file mode 100644 index 000000000..89d890959 --- /dev/null +++ b/k8s-operator/api-proxy/doc.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package apiproxy contains the Kubernetes API Proxy implementation used by +// k8s-operator and k8s-proxy. +package apiproxy diff --git a/k8s-operator/api-proxy/proxy.go b/k8s-operator/api-proxy/proxy.go new file mode 100644 index 000000000..762a52f1f --- /dev/null +++ b/k8s-operator/api-proxy/proxy.go @@ -0,0 +1,574 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package apiproxy + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "net/netip" + "net/url" + "strings" + "time" + + "go.uber.org/zap" + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/apiserver/pkg/endpoints/request" + "k8s.io/client-go/rest" + "k8s.io/client-go/transport" + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/envknob" + ksr "tailscale.com/k8s-operator/sessionrecording" + "tailscale.com/kube/kubetypes" + "tailscale.com/net/netx" + "tailscale.com/sessionrecording" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/util/clientmetric" + "tailscale.com/util/ctxkey" + "tailscale.com/util/set" +) + +var ( + // counterNumRequestsproxies counts the number of API server requests proxied via this proxy. + counterNumRequestsProxied = clientmetric.NewCounter("k8s_auth_proxy_requests_proxied") + whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil)) +) + +// NewAPIServerProxy creates a new APIServerProxy that's ready to start once Run +// is called. No network traffic will flow until Run is called. +// +// authMode controls how the proxy behaves: +// - true: the proxy is started and requests are impersonated using the +// caller's Tailscale identity and the rules defined in the tailnet ACLs. +// - false: the proxy is started and requests are passed through to the +// Kubernetes API without any auth modifications. +func NewAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, ts *tsnet.Server, mode kubetypes.APIServerProxyMode, https bool) (*APIServerProxy, error) { + if mode == kubetypes.APIServerProxyModeNoAuth { + restConfig = rest.AnonymousClientConfig(restConfig) + } + + cfg, err := restConfig.TransportConfig() + if err != nil { + return nil, fmt.Errorf("could not get rest.TransportConfig(): %w", err) + } + + tr := http.DefaultTransport.(*http.Transport).Clone() + tr.TLSClientConfig, err = transport.TLSConfigFor(cfg) + if err != nil { + return nil, fmt.Errorf("could not get transport.TLSConfigFor(): %w", err) + } + tr.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper) + + rt, err := transport.HTTPWrappersForConfig(cfg, tr) + if err != nil { + return nil, fmt.Errorf("could not get rest.TransportConfig(): %w", err) + } + + u, err := url.Parse(restConfig.Host) + if err != nil { + return nil, fmt.Errorf("failed to parse URL %w", err) + } + if u.Scheme == "" || u.Host == "" { + return nil, fmt.Errorf("the API server proxy requires host and scheme but got: %q", restConfig.Host) + } + + lc, err := ts.LocalClient() + if err != nil { + return nil, fmt.Errorf("could not get local client: %w", err) + } + + ap := &APIServerProxy{ + log: zlog, + lc: lc, + authMode: mode == kubetypes.APIServerProxyModeAuth, + https: https, + upstreamURL: u, + ts: ts, + sendEventFunc: sessionrecording.SendEvent, + eventsEnabled: envknob.Bool("TS_EXPERIMENTAL_KUBE_API_EVENTS"), + } + ap.rp = &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + ap.addImpersonationHeadersAsRequired(pr.Out) + }, + Transport: rt, + } + + return ap, nil +} + +// Run starts the HTTP server that authenticates requests using the +// Tailscale LocalAPI and then proxies them to the Kubernetes API. +// It listens on :443 and uses the Tailscale HTTPS certificate. +// +// It return when ctx is cancelled or ServeTLS fails. +func (ap *APIServerProxy) Run(ctx context.Context) error { + mux := http.NewServeMux() + mux.HandleFunc("/", ap.serveDefault) + mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecSPDY) + mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/exec", ap.serveExecWS) + mux.HandleFunc("POST /api/v1/namespaces/{namespace}/pods/{pod}/attach", ap.serveAttachSPDY) + mux.HandleFunc("GET /api/v1/namespaces/{namespace}/pods/{pod}/attach", ap.serveAttachWS) + + ap.hs = &http.Server{ + Handler: mux, + ErrorLog: zap.NewStdLog(ap.log.Desugar()), + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), + } + + mode := "noauth" + if ap.authMode { + mode = "auth" + } + var proxyLn net.Listener + var serve func(ln net.Listener) error + if ap.https { + var err error + proxyLn, err = ap.ts.Listen("tcp", ":443") + if err != nil { + return fmt.Errorf("could not listen on :443: %w", err) + } + serve = func(ln net.Listener) error { + return ap.hs.ServeTLS(ln, "", "") + } + + // Kubernetes uses SPDY for exec and port-forward, however SPDY is + // incompatible with HTTP/2; so disable HTTP/2 in the proxy. + ap.hs.TLSConfig = &tls.Config{ + GetCertificate: ap.lc.GetCertificate, + NextProtos: []string{"http/1.1"}, + } + } else { + var err error + proxyLn, err = net.Listen("tcp", "localhost:80") + if err != nil { + return fmt.Errorf("could not listen on :80: %w", err) + } + serve = ap.hs.Serve + } + + errs := make(chan error) + go func() { + ap.log.Infof("API server proxy in %s mode is listening on %s", mode, proxyLn.Addr()) + if err := serve(proxyLn); err != nil && err != http.ErrServerClosed { + errs <- fmt.Errorf("error serving: %w", err) + } + }() + + select { + case <-ctx.Done(): + case err := <-errs: + ap.hs.Close() + return err + } + + // Graceful shutdown with a timeout of 10s. + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return ap.hs.Shutdown(shutdownCtx) +} + +// APIServerProxy is an [net/http.Handler] that authenticates requests using the Tailscale +// LocalAPI and then proxies them to the Kubernetes API. +type APIServerProxy struct { + log *zap.SugaredLogger + lc *local.Client + rp *httputil.ReverseProxy + + authMode bool // Whether to run with impersonation using caller's tailnet identity. + https bool // Whether to serve on https for the device hostname; true for k8s-operator, false (and localhost) for k8s-proxy. + ts *tsnet.Server + hs *http.Server + upstreamURL *url.URL + + sendEventFunc func(ap netip.AddrPort, event io.Reader, dial netx.DialFunc) error + + // Flag used to enable sending API requests as events to tsrecorder. + eventsEnabled bool +} + +// serveDefault is the default handler for Kubernetes API server requests. +func (ap *APIServerProxy) serveDefault(w http.ResponseWriter, r *http.Request) { + who, err := ap.whoIs(r) + if err != nil { + ap.authError(w, err) + return + } + + if err = ap.recordRequestAsEvent(r, who); err != nil { + msg := fmt.Sprintf("error recording Kubernetes API request: %v", err) + ap.log.Errorf(msg) + http.Error(w, msg, http.StatusBadGateway) + return + } + + counterNumRequestsProxied.Add(1) + + ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) +} + +// serveExecSPDY serves '/exec' requests for sessions streamed over SPDY, +// optionally configuring the kubectl exec sessions to be recorded. +func (ap *APIServerProxy) serveExecSPDY(w http.ResponseWriter, r *http.Request) { + ap.sessionForProto(w, r, ksr.ExecSessionType, ksr.SPDYProtocol) +} + +// serveExecWS serves '/exec' requests for sessions streamed over WebSocket, +// optionally configuring the kubectl exec sessions to be recorded. +func (ap *APIServerProxy) serveExecWS(w http.ResponseWriter, r *http.Request) { + ap.sessionForProto(w, r, ksr.ExecSessionType, ksr.WSProtocol) +} + +// serveAttachSPDY serves '/attach' requests for sessions streamed over SPDY, +// optionally configuring the kubectl exec sessions to be recorded. +func (ap *APIServerProxy) serveAttachSPDY(w http.ResponseWriter, r *http.Request) { + ap.sessionForProto(w, r, ksr.AttachSessionType, ksr.SPDYProtocol) +} + +// serveAttachWS serves '/attach' requests for sessions streamed over WebSocket, +// optionally configuring the kubectl exec sessions to be recorded. +func (ap *APIServerProxy) serveAttachWS(w http.ResponseWriter, r *http.Request) { + ap.sessionForProto(w, r, ksr.AttachSessionType, ksr.WSProtocol) +} + +func (ap *APIServerProxy) sessionForProto(w http.ResponseWriter, r *http.Request, sessionType ksr.SessionType, proto ksr.Protocol) { + const ( + podNameKey = "pod" + namespaceNameKey = "namespace" + upgradeHeaderKey = "Upgrade" + ) + + who, err := ap.whoIs(r) + if err != nil { + ap.authError(w, err) + return + } + + if err = ap.recordRequestAsEvent(r, who); err != nil { + msg := fmt.Sprintf("error recording Kubernetes API request: %v", err) + ap.log.Errorf(msg) + http.Error(w, msg, http.StatusBadGateway) + return + } + + counterNumRequestsProxied.Add(1) + failOpen, addrs, err := determineRecorderConfig(who) + if err != nil { + ap.log.Errorf("error trying to determine whether the 'kubectl %s' session needs to be recorded: %v", sessionType, err) + return + } + if failOpen && len(addrs) == 0 { // will not record + ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) + return + } + ksr.CounterSessionRecordingsAttempted.Add(1) // at this point we know that users intended for this session to be recorded + if !failOpen && len(addrs) == 0 { + msg := fmt.Sprintf("forbidden: 'kubectl %s' session must be recorded, but no recorders are available.", sessionType) + ap.log.Error(msg) + http.Error(w, msg, http.StatusForbidden) + return + } + + wantsHeader := upgradeHeaderForProto[proto] + if h := r.Header.Get(upgradeHeaderKey); h != wantsHeader { + msg := fmt.Sprintf("[unexpected] unable to verify that streaming protocol is %s, wants Upgrade header %q, got: %q", proto, wantsHeader, h) + if failOpen { + msg = msg + "; failure mode is 'fail open'; continuing session without recording." + ap.log.Warn(msg) + ap.rp.ServeHTTP(w, r.WithContext(whoIsKey.WithValue(r.Context(), who))) + return + } + ap.log.Error(msg) + msg += "; failure mode is 'fail closed'; closing connection." + http.Error(w, msg, http.StatusForbidden) + return + } + + opts := ksr.HijackerOpts{ + Req: r, + W: w, + Proto: proto, + SessionType: sessionType, + TS: ap.ts, + Who: who, + Addrs: addrs, + FailOpen: failOpen, + Pod: r.PathValue(podNameKey), + Namespace: r.PathValue(namespaceNameKey), + Log: ap.log, + } + h := ksr.NewHijacker(opts) + + ap.rp.ServeHTTP(h, r.WithContext(whoIsKey.WithValue(r.Context(), who))) +} + +func (ap *APIServerProxy) recordRequestAsEvent(req *http.Request, who *apitype.WhoIsResponse) error { + if !ap.eventsEnabled { + return nil + } + + failOpen, addrs, err := determineRecorderConfig(who) + if err != nil { + return fmt.Errorf("error trying to determine whether the kubernetes api request needs to be recorded: %w", err) + } + if len(addrs) == 0 { + if failOpen { + return nil + } else { + return fmt.Errorf("forbidden: kubernetes api request must be recorded, but no recorders are available") + } + } + + factory := &request.RequestInfoFactory{ + APIPrefixes: sets.NewString("api", "apis"), + GrouplessAPIPrefixes: sets.NewString("api"), + } + + reqInfo, err := factory.NewRequestInfo(req) + if err != nil { + return fmt.Errorf("error parsing request %s %s: %w", req.Method, req.URL.Path, err) + } + + kubeReqInfo := sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: reqInfo.IsResourceRequest, + Path: reqInfo.Path, + Verb: reqInfo.Verb, + APIPrefix: reqInfo.APIPrefix, + APIGroup: reqInfo.APIGroup, + APIVersion: reqInfo.APIVersion, + Namespace: reqInfo.Namespace, + Resource: reqInfo.Resource, + Subresource: reqInfo.Subresource, + Name: reqInfo.Name, + Parts: reqInfo.Parts, + FieldSelector: reqInfo.FieldSelector, + LabelSelector: reqInfo.LabelSelector, + } + event := &sessionrecording.Event{ + Timestamp: time.Now().Unix(), + Kubernetes: kubeReqInfo, + Type: sessionrecording.KubernetesAPIEventType, + UserAgent: req.UserAgent(), + Request: sessionrecording.Request{ + Method: req.Method, + Path: req.URL.String(), + QueryParameters: req.URL.Query(), + }, + Source: sessionrecording.Source{ + NodeID: who.Node.StableID, + Node: strings.TrimSuffix(who.Node.Name, "."), + }, + } + + if !who.Node.IsTagged() { + event.Source.NodeUser = who.UserProfile.LoginName + event.Source.NodeUserID = who.UserProfile.ID + } else { + event.Source.NodeTags = who.Node.Tags + } + + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("failed to read body: %w", err) + } + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + event.Request.Body = bodyBytes + + var errs []error + // TODO: ChaosInTheCRD ensure that if there are multiple addrs timing out we don't experience slowdown on client waiting for response. + fail := true + for _, addr := range addrs { + data := new(bytes.Buffer) + if err := json.NewEncoder(data).Encode(event); err != nil { + return fmt.Errorf("error marshaling request event: %w", err) + } + + if err := ap.sendEventFunc(addr, data, ap.ts.Dial); err != nil { + if apiSupportErr, ok := err.(sessionrecording.EventAPINotSupportedErr); ok { + ap.log.Warnf(apiSupportErr.Error()) + fail = false + } else { + err := fmt.Errorf("error sending event to recorder with address %q: %v", addr.String(), err) + errs = append(errs, err) + } + } else { + return nil + } + } + + merr := errors.Join(errs...) + if fail && failOpen { + msg := fmt.Sprintf("[unexpected] failed to send event to recorders with errors: %s", merr.Error()) + msg = msg + "; failure mode is 'fail open'; continuing request without recording." + ap.log.Warn(msg) + return nil + } + + return merr +} + +func (ap *APIServerProxy) addImpersonationHeadersAsRequired(r *http.Request) { + r.URL.Scheme = ap.upstreamURL.Scheme + r.URL.Host = ap.upstreamURL.Host + if !ap.authMode { + // If we are not providing authentication, then we are just + // proxying to the Kubernetes API, so we don't need to do + // anything else. + return + } + + // We want to proxy to the Kubernetes API, but we want to use + // the caller's identity to do so. We do this by impersonating + // the caller using the Kubernetes User Impersonation feature: + // https://kubernetes.io/docs/reference/access-authn-authz/authentication/#user-impersonation + + // Out of paranoia, remove all authentication headers that might + // have been set by the client. + r.Header.Del("Authorization") + r.Header.Del("Impersonate-Group") + r.Header.Del("Impersonate-User") + r.Header.Del("Impersonate-Uid") + for k := range r.Header { + if strings.HasPrefix(k, "Impersonate-Extra-") { + r.Header.Del(k) + } + } + + // Now add the impersonation headers that we want. + if err := addImpersonationHeaders(r, ap.log); err != nil { + ap.log.Errorf("failed to add impersonation headers: %v", err) + } +} + +func (ap *APIServerProxy) whoIs(r *http.Request) (*apitype.WhoIsResponse, error) { + who, remoteErr := ap.lc.WhoIs(r.Context(), r.RemoteAddr) + if remoteErr == nil { + ap.log.Debugf("WhoIs from remote addr: %s", r.RemoteAddr) + return who, nil + } + + var fwdErr error + fwdFor := r.Header.Get("X-Forwarded-For") + if fwdFor != "" && !ap.https { + who, fwdErr = ap.lc.WhoIs(r.Context(), fwdFor) + if fwdErr == nil { + ap.log.Debugf("WhoIs from X-Forwarded-For header: %s", fwdFor) + return who, nil + } + } + + return nil, errors.Join(remoteErr, fwdErr) +} + +func (ap *APIServerProxy) authError(w http.ResponseWriter, err error) { + ap.log.Errorf("failed to authenticate caller: %v", err) + http.Error(w, "failed to authenticate caller", http.StatusInternalServerError) +} + +const ( + // oldCapabilityName is a legacy form of + // tailfcg.PeerCapabilityKubernetes capability. The only capability rule + // that is respected for this form is group impersonation - for + // backwards compatibility reasons. + // TODO (irbekrm): determine if anyone uses this and remove if possible. + oldCapabilityName = "https://" + tailcfg.PeerCapabilityKubernetes +) + +// addImpersonationHeaders adds the appropriate headers to r to impersonate the +// caller when proxying to the Kubernetes API. It uses the WhoIsResponse stashed +// in the context by the apiserverProxy. +func addImpersonationHeaders(r *http.Request, log *zap.SugaredLogger) error { + log = log.With("remote", r.RemoteAddr) + who := whoIsKey.Value(r.Context()) + rules, err := tailcfg.UnmarshalCapJSON[kubetypes.KubernetesCapRule](who.CapMap, tailcfg.PeerCapabilityKubernetes) + if len(rules) == 0 && err == nil { + // Try the old capability name for backwards compatibility. + rules, err = tailcfg.UnmarshalCapJSON[kubetypes.KubernetesCapRule](who.CapMap, oldCapabilityName) + } + if err != nil { + return fmt.Errorf("failed to unmarshal capability: %v", err) + } + + var groupsAdded set.Slice[string] + for _, rule := range rules { + if rule.Impersonate == nil { + continue + } + for _, group := range rule.Impersonate.Groups { + if groupsAdded.Contains(group) { + continue + } + r.Header.Add("Impersonate-Group", group) + groupsAdded.Add(group) + log.Debugf("adding group impersonation header for user group %s", group) + } + } + + if !who.Node.IsTagged() { + r.Header.Set("Impersonate-User", who.UserProfile.LoginName) + log.Debugf("adding user impersonation header for user %s", who.UserProfile.LoginName) + return nil + } + // "Impersonate-Group" requires "Impersonate-User" to be set, so we set it + // to the node FQDN for tagged nodes. + nodeName := strings.TrimSuffix(who.Node.Name, ".") + r.Header.Set("Impersonate-User", nodeName) + log.Debugf("adding user impersonation header for node name %s", nodeName) + + // For legacy behavior (before caps), set the groups to the nodes tags. + if groupsAdded.Slice().Len() == 0 { + for _, tag := range who.Node.Tags { + r.Header.Add("Impersonate-Group", tag) + log.Debugf("adding group impersonation header for node tag %s", tag) + } + } + return nil +} + +// determineRecorderConfig determines recorder config from requester's peer +// capabilities. Determines whether a 'kubectl exec' session from this requester +// needs to be recorded and what recorders the recording should be sent to. +func determineRecorderConfig(who *apitype.WhoIsResponse) (failOpen bool, recorderAddresses []netip.AddrPort, _ error) { + if who == nil { + return false, nil, errors.New("[unexpected] cannot determine caller") + } + failOpen = true + rules, err := tailcfg.UnmarshalCapJSON[kubetypes.KubernetesCapRule](who.CapMap, tailcfg.PeerCapabilityKubernetes) + if err != nil { + return failOpen, nil, fmt.Errorf("failed to unmarshal Kubernetes capability: %w", err) + } + if len(rules) == 0 { + return failOpen, nil, nil + } + + for _, rule := range rules { + if len(rule.RecorderAddrs) != 0 { + // TODO (irbekrm): here or later determine if the + // recorders behind those addrs are online - else we + // spend 30s trying to reach a recorder whose tailscale + // status is offline. + recorderAddresses = append(recorderAddresses, rule.RecorderAddrs...) + } + if rule.EnforceRecorder { + failOpen = false + } + } + return failOpen, recorderAddresses, nil +} + +var upgradeHeaderForProto = map[ksr.Protocol]string{ + ksr.SPDYProtocol: "SPDY/3.1", + ksr.WSProtocol: "websocket", +} diff --git a/k8s-operator/api-proxy/proxy_events_test.go b/k8s-operator/api-proxy/proxy_events_test.go new file mode 100644 index 000000000..8bcf48436 --- /dev/null +++ b/k8s-operator/api-proxy/proxy_events_test.go @@ -0,0 +1,549 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package apiproxy + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "reflect" + "testing" + + "go.uber.org/zap" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/net/netx" + "tailscale.com/sessionrecording" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" +) + +type fakeSender struct { + sent map[netip.AddrPort][]byte + err error + calls int +} + +func (s *fakeSender) Send(ap netip.AddrPort, event io.Reader, dial netx.DialFunc) error { + s.calls++ + if s.err != nil { + return s.err + } + if s.sent == nil { + s.sent = make(map[netip.AddrPort][]byte) + } + data, _ := io.ReadAll(event) + s.sent[ap] = data + return nil +} + +func (s *fakeSender) Reset() { + s.sent = nil + s.err = nil + s.calls = 0 +} + +func TestRecordRequestAsEvent(t *testing.T) { + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + sender := &fakeSender{} + ap := &APIServerProxy{ + log: zl.Sugar(), + ts: &tsnet.Server{}, + sendEventFunc: sender.Send, + eventsEnabled: true, + } + + defaultWho := &apitype.WhoIsResponse{ + Node: &tailcfg.Node{ + StableID: "stable-id", + Name: "node.ts.net.", + }, + UserProfile: &tailcfg.UserProfile{ + ID: 1, + LoginName: "user@example.com", + }, + CapMap: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityKubernetes: []tailcfg.RawMessage{ + tailcfg.RawMessage(`{"recorderAddrs":["127.0.0.1:1234"]}`), + tailcfg.RawMessage(`{"enforceRecorder": true}`), + }, + }, + } + + defaultSource := sessionrecording.Source{ + Node: "node.ts.net", + NodeID: "stable-id", + NodeUser: "user@example.com", + NodeUserID: 1, + } + + tests := []struct { + name string + req func() *http.Request + who *apitype.WhoIsResponse + setupSender func() + wantErr bool + wantEvent *sessionrecording.Event + wantNumCalls int + }{ + { + name: "request-with-dot-in-name", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/namespaces/default/pods/foo.bar", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/namespaces/default/pods/foo.bar", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/namespaces/default/pods/foo.bar", + Verb: "get", + APIPrefix: "api", + APIVersion: "v1", + Namespace: "default", + Resource: "pods", + Name: "foo.bar", + Parts: []string{"pods", "foo.bar"}, + }, + Source: defaultSource, + }, + }, + { + name: "request-with-dash-in-name", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/namespaces/default/pods/foo-bar", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/namespaces/default/pods/foo-bar", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/namespaces/default/pods/foo-bar", + Verb: "get", + APIPrefix: "api", + APIVersion: "v1", + Namespace: "default", + Resource: "pods", + Name: "foo-bar", + Parts: []string{"pods", "foo-bar"}, + }, + Source: defaultSource, + }, + }, + { + name: "request-with-query-parameter", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods?watch=true", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/pods?watch=true", + Body: nil, + QueryParameters: url.Values{"watch": []string{"true"}}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/pods", + Verb: "watch", + APIPrefix: "api", + APIVersion: "v1", + Resource: "pods", + Parts: []string{"pods"}, + }, + Source: defaultSource, + }, + }, + { + name: "request-with-label-selector", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods?labelSelector=app%3Dfoo", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/pods?labelSelector=app%3Dfoo", + Body: nil, + QueryParameters: url.Values{"labelSelector": []string{"app=foo"}}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/pods", + Verb: "list", + APIPrefix: "api", + APIVersion: "v1", + Resource: "pods", + Parts: []string{"pods"}, + LabelSelector: "app=foo", + }, + Source: defaultSource, + }, + }, + { + name: "request-with-field-selector", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods?fieldSelector=status.phase%3DRunning", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/pods?fieldSelector=status.phase%3DRunning", + Body: nil, + QueryParameters: url.Values{"fieldSelector": []string{"status.phase=Running"}}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/pods", + Verb: "list", + APIPrefix: "api", + APIVersion: "v1", + Resource: "pods", + Parts: []string{"pods"}, + FieldSelector: "status.phase=Running", + }, + Source: defaultSource, + }, + }, + { + name: "request-for-non-existent-resource", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/foo", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/foo", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/foo", + Verb: "list", + APIPrefix: "api", + APIVersion: "v1", + Resource: "foo", + Parts: []string{"foo"}, + }, + Source: defaultSource, + }, + }, + { + name: "basic-request", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/pods", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/pods", + Verb: "list", + APIPrefix: "api", + APIVersion: "v1", + Resource: "pods", + Parts: []string{"pods"}, + }, + Source: defaultSource, + }, + }, + { + name: "multiple-recorders", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods", nil) + }, + who: &apitype.WhoIsResponse{ + Node: defaultWho.Node, + UserProfile: defaultWho.UserProfile, + CapMap: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityKubernetes: []tailcfg.RawMessage{ + tailcfg.RawMessage(`{"recorderAddrs":["127.0.0.1:1234", "127.0.0.1:5678"]}`), + }, + }, + }, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + }, + { + name: "request-with-body", + req: func() *http.Request { + req := httptest.NewRequest("POST", "/api/v1/pods", bytes.NewBufferString(`{"foo":"bar"}`)) + req.Header.Set("Content-Type", "application/json") + return req + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "POST", + Path: "/api/v1/pods", + Body: json.RawMessage(`{"foo":"bar"}`), + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/pods", + Verb: "create", + APIPrefix: "api", + APIVersion: "v1", + Resource: "pods", + Parts: []string{"pods"}, + }, + Source: defaultSource, + }, + }, + { + name: "tagged-node", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods", nil) + }, + who: &apitype.WhoIsResponse{ + Node: &tailcfg.Node{ + StableID: "stable-id", + Name: "node.ts.net.", + Tags: []string{"tag:foo"}, + }, + UserProfile: &tailcfg.UserProfile{}, + CapMap: defaultWho.CapMap, + }, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/pods", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/pods", + Verb: "list", + APIPrefix: "api", + APIVersion: "v1", + Resource: "pods", + Parts: []string{"pods"}, + }, + Source: sessionrecording.Source{ + Node: "node.ts.net", + NodeID: "stable-id", + NodeTags: []string{"tag:foo"}, + }, + }, + }, + { + name: "no-recorders", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods", nil) + }, + who: &apitype.WhoIsResponse{ + Node: defaultWho.Node, + UserProfile: defaultWho.UserProfile, + CapMap: tailcfg.PeerCapMap{}, + }, + setupSender: func() { sender.Reset() }, + wantNumCalls: 0, + }, + { + name: "error-sending", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/pods", nil) + }, + who: defaultWho, + setupSender: func() { + sender.Reset() + sender.err = errors.New("send error") + }, + wantErr: true, + wantNumCalls: 1, + }, + { + name: "request-for-crd", + req: func() *http.Request { + return httptest.NewRequest("GET", "/apis/custom.example.com/v1/myresources", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/apis/custom.example.com/v1/myresources", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/apis/custom.example.com/v1/myresources", + Verb: "list", + APIPrefix: "apis", + APIGroup: "custom.example.com", + APIVersion: "v1", + Resource: "myresources", + Parts: []string{"myresources"}, + }, + Source: defaultSource, + }, + }, + { + name: "request-with-proxy-verb", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/namespaces/default/pods/foo/proxy", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/namespaces/default/pods/foo/proxy", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/namespaces/default/pods/foo/proxy", + Verb: "get", + APIPrefix: "api", + APIVersion: "v1", + Namespace: "default", + Resource: "pods", + Subresource: "proxy", + Name: "foo", + Parts: []string{"pods", "foo", "proxy"}, + }, + Source: defaultSource, + }, + }, + { + name: "request-with-complex-path", + req: func() *http.Request { + return httptest.NewRequest("GET", "/api/v1/namespaces/default/services/foo:8080/proxy-subpath/more/segments", nil) + }, + who: defaultWho, + setupSender: func() { sender.Reset() }, + wantNumCalls: 1, + wantEvent: &sessionrecording.Event{ + Type: sessionrecording.KubernetesAPIEventType, + Request: sessionrecording.Request{ + Method: "GET", + Path: "/api/v1/namespaces/default/services/foo:8080/proxy-subpath/more/segments", + Body: nil, + QueryParameters: url.Values{}, + }, + Kubernetes: sessionrecording.KubernetesRequestInfo{ + IsResourceRequest: true, + Path: "/api/v1/namespaces/default/services/foo:8080/proxy-subpath/more/segments", + Verb: "get", + APIPrefix: "api", + APIVersion: "v1", + Namespace: "default", + Resource: "services", + Subresource: "proxy-subpath", + Name: "foo:8080", + Parts: []string{"services", "foo:8080", "proxy-subpath", "more", "segments"}, + }, + Source: defaultSource, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupSender() + + req := tt.req() + err := ap.recordRequestAsEvent(req, tt.who) + + if (err != nil) != tt.wantErr { + t.Fatalf("recordRequestAsEvent() error = %v, wantErr %v", err, tt.wantErr) + } + + if sender.calls != tt.wantNumCalls { + t.Fatalf("expected %d calls to sender, got %d", tt.wantNumCalls, sender.calls) + } + + if tt.wantEvent != nil { + for _, sentData := range sender.sent { + var got sessionrecording.Event + if err := json.Unmarshal(sentData, &got); err != nil { + t.Fatalf("failed to unmarshal sent event: %v", err) + } + + got.Timestamp = 0 + tt.wantEvent.Timestamp = got.Timestamp + + got.UserAgent = "" + tt.wantEvent.UserAgent = "" + + if !bytes.Equal(got.Request.Body, tt.wantEvent.Request.Body) { + t.Errorf("sent event body does not match wanted event body.\nGot: %s\nWant: %s", string(got.Request.Body), string(tt.wantEvent.Request.Body)) + } + got.Request.Body = nil + tt.wantEvent.Request.Body = nil + + if !reflect.DeepEqual(&got, tt.wantEvent) { + t.Errorf("sent event does not match wanted event.\nGot: %#v\nWant: %#v", &got, tt.wantEvent) + } + } + } + }) + } +} diff --git a/cmd/k8s-operator/proxy_test.go b/k8s-operator/api-proxy/proxy_test.go similarity index 99% rename from cmd/k8s-operator/proxy_test.go rename to k8s-operator/api-proxy/proxy_test.go index d1d5733e7..71bf65648 100644 --- a/cmd/k8s-operator/proxy_test.go +++ b/k8s-operator/api-proxy/proxy_test.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package apiproxy import ( "net/http" diff --git a/k8s-operator/api.md b/k8s-operator/api.md index d343e6395..979d199cb 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -21,6 +21,37 @@ +#### APIServerProxyMode + +_Underlying type:_ _string_ + + + +_Validation:_ +- Enum: [auth noauth] +- Type: string + +_Appears in:_ +- [KubeAPIServerConfig](#kubeapiserverconfig) + + + +#### AppConnector + + + +AppConnector defines a Tailscale app connector node configured via Connector. + + + +_Appears in:_ +- [ConnectorSpec](#connectorspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `routes` _[Routes](#routes)_ | Routes are optional preconfigured routes for the domains routed via the app connector.
If not set, routes for the domains will be discovered dynamically.
If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may
also dynamically discover other routes.
https://tailscale.com/kb/1332/apps-best-practices#preconfiguration | | Format: cidr
MinItems: 1
Type: string
| + + #### Connector @@ -50,6 +81,23 @@ _Appears in:_ | `status` _[ConnectorStatus](#connectorstatus)_ | ConnectorStatus describes the status of the Connector. This is set
and managed by the Tailscale operator. | | | +#### ConnectorDevice + + + + + + + +_Appears in:_ +- [ConnectorStatus](#connectorstatus) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `hostname` _string_ | Hostname is the fully qualified domain name of the Connector replica.
If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the
node. | | | +| `tailnetIPs` _string array_ | TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6)
assigned to the Connector replica. | | | + + #### ConnectorList @@ -84,10 +132,13 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | | `tags` _[Tags](#tags)_ | Tags that the Tailscale node will be tagged with.
Defaults to [tag:k8s].
To autoapprove the subnet routes or exit node defined by a Connector,
you can configure Tailscale ACLs to give these tags the necessary
permissions.
See https://tailscale.com/kb/1337/acl-syntax#autoapprovers.
If you specify custom tags here, you must also make the operator an owner of these tags.
See https://tailscale.com/kb/1236/kubernetes-operator/#setting-up-the-kubernetes-operator.
Tags cannot be changed once a Connector node has been created.
Tag values must be in form ^tag:[a-zA-Z][a-zA-Z0-9-]*$. | | Pattern: `^tag:[a-zA-Z][a-zA-Z0-9-]*$`
Type: string
| -| `hostname` _[Hostname](#hostname)_ | Hostname is the tailnet hostname that should be assigned to the
Connector node. If unset, hostname defaults to name>-connector. Hostname can contain lower case letters, numbers and
dashes, it must not start or end with a dash and must be between 2
and 63 characters long. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$`
Type: string
| +| `hostname` _[Hostname](#hostname)_ | Hostname is the tailnet hostname that should be assigned to the
Connector node. If unset, hostname defaults to name>-connector. Hostname can contain lower case letters, numbers and
dashes, it must not start or end with a dash and must be between 2
and 63 characters long. This field should only be used when creating a connector
with an unspecified number of replicas, or a single replica. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$`
Type: string
| +| `hostnamePrefix` _[HostnamePrefix](#hostnameprefix)_ | HostnamePrefix specifies the hostname prefix for each
replica. Each device will have the integer number
from its StatefulSet pod appended to this prefix to form the full hostname.
HostnamePrefix can contain lower case letters, numbers and dashes, it
must not start with a dash and must be between 1 and 62 characters long. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}$`
Type: string
| | `proxyClass` _string_ | ProxyClass is the name of the ProxyClass custom resource that
contains configuration options that should be applied to the
resources created for this Connector. If unset, the operator will
create resources with the default configuration. | | | -| `subnetRouter` _[SubnetRouter](#subnetrouter)_ | SubnetRouter defines subnet routes that the Connector node should
expose to tailnet. If unset, none are exposed.
https://tailscale.com/kb/1019/subnets/ | | | -| `exitNode` _boolean_ | ExitNode defines whether the Connector node should act as a
Tailscale exit node. Defaults to false.
https://tailscale.com/kb/1103/exit-nodes | | | +| `subnetRouter` _[SubnetRouter](#subnetrouter)_ | SubnetRouter defines subnet routes that the Connector device should
expose to tailnet as a Tailscale subnet router.
https://tailscale.com/kb/1019/subnets/
If this field is unset, the device does not get configured as a Tailscale subnet router.
This field is mutually exclusive with the appConnector field. | | | +| `appConnector` _[AppConnector](#appconnector)_ | AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is
configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the
Connector does not act as an app connector.
Note that you will need to manually configure the permissions and the domains for the app connector via the
Admin panel.
Note also that the main tested and supported use case of this config option is to deploy an app connector on
Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose
cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have
tested or optimised for.
If you are using the app connector to access SaaS applications because you need a predictable egress IP that
can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows
via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT
device with a static IP address.
https://tailscale.com/kb/1281/app-connectors | | | +| `exitNode` _boolean_ | ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false.
This field is mutually exclusive with the appConnector field.
https://tailscale.com/kb/1103/exit-nodes | | | +| `replicas` _integer_ | Replicas specifies how many devices to create. Set this to enable
high availability for app connectors, subnet routers, or exit nodes.
https://tailscale.com/kb/1115/high-availability. Defaults to 1. | | Minimum: 0
| #### ConnectorStatus @@ -106,8 +157,10 @@ _Appears in:_ | `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#condition-v1-meta) array_ | List of status conditions to indicate the status of the Connector.
Known condition types are `ConnectorReady`. | | | | `subnetRoutes` _string_ | SubnetRoutes are the routes currently exposed to tailnet via this
Connector instance. | | | | `isExitNode` _boolean_ | IsExitNode is set to true if the Connector acts as an exit node. | | | +| `isAppConnector` _boolean_ | IsAppConnector is set to true if the Connector acts as an app connector. | | | | `tailnetIPs` _string array_ | TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6)
assigned to the Connector node. | | | -| `hostname` _string_ | Hostname is the fully qualified domain name of the Connector node.
If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the
node. | | | +| `hostname` _string_ | Hostname is the fully qualified domain name of the Connector node.
If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the
node. When using multiple replicas, this field will be populated with the
first replica's hostname. Use the Hostnames field for the full list
of hostnames. | | | +| `devices` _[ConnectorDevice](#connectordevice) array_ | Devices contains information on each device managed by the Connector resource. | | | #### Container @@ -124,10 +177,11 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | | `env` _[Env](#env) array_ | List of environment variables to set in the container.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#environment-variables
Note that environment variables provided here will take precedence
over Tailscale-specific environment variables set by the operator,
however running proxies with custom values for Tailscale environment
variables (i.e TS_USERSPACE) is not recommended and might break in
the future. | | | -| `image` _string_ | Container image name. By default images are pulled from
docker.io/tailscale/tailscale, but the official images are also
available at ghcr.io/tailscale/tailscale. Specifying image name here
will override any proxy image values specified via the Kubernetes
operator's Helm chart values or PROXY_IMAGE env var in the operator
Deployment.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | | +| `image` _string_ | Container image name. By default images are pulled from docker.io/tailscale,
but the official images are also available at ghcr.io/tailscale.
For all uses except on ProxyGroups of type "kube-apiserver", this image must
be either tailscale/tailscale, or an equivalent mirror of that image.
To apply to ProxyGroups of type "kube-apiserver", this image must be
tailscale/k8s-proxy or a mirror of that image.
For "tailscale/tailscale"-based proxies, specifying image name here will
override any proxy image values specified via the Kubernetes operator's
Helm chart values or PROXY_IMAGE env var in the operator Deployment.
For "tailscale/k8s-proxy"-based proxies, there is currently no way to
configure your own default, and this field is the only way to use a
custom image.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | | | `imagePullPolicy` _[PullPolicy](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#pullpolicy-v1-core)_ | Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | Enum: [Always Never IfNotPresent]
| | `resources` _[ResourceRequirements](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#resourcerequirements-v1-core)_ | Container resource requirements.
By default Tailscale Kubernetes operator does not apply any resource
requirements. The amount of resources required wil depend on the
amount of resources the operator needs to parse, usage patterns and
cluster size.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#resources | | | -| `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context by the operator.
By default the operator:
- sets 'privileged: true' for the init container
- set NET_ADMIN capability for tailscale container for proxies that
are created for Services or Connector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | +| `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context set by the operator.
By default the operator sets the Tailscale container and the Tailscale init container to privileged
for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup.
You can reduce the permissions of the Tailscale container to cap NET_ADMIN by
installing device plugin in your cluster and configuring the proxies tun device to be created
by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | +| `debug` _[Debug](#debug)_ | Configuration for enabling extra debug information in the container.
Not recommended for production use. | | | #### DNSConfig @@ -159,7 +213,6 @@ NB: if you want cluster workloads to be able to refer to Tailscale Ingress using its MagicDNS name, you must also annotate the Ingress resource with tailscale.com/experimental-forward-cluster-traffic-via-ingress annotation to ensure that the proxy created for the Ingress listens on its Pod IP address. -NB: Clusters where Pods get assigned IPv6 addresses only are currently not supported. @@ -230,6 +283,22 @@ _Appears in:_ | `nameserver` _[NameserverStatus](#nameserverstatus)_ | Nameserver describes the status of nameserver cluster resources. | | | +#### Debug + + + + + + + +_Appears in:_ +- [Container](#container) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enable` _boolean_ | Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/
and internal debug metrics endpoint at :9001/debug/metrics, where
9001 is a container port named "debug". The endpoints and their responses
may change in backwards incompatible ways in the future, and should not
be considered stable.
In 1.78.x and 1.80.x, this setting will default to the value of
.spec.metrics.enable, and requests to the "metrics" port matching the
mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x,
this setting will default to false, and no requests will be proxied. | | | + + #### Env @@ -273,9 +342,58 @@ _Validation:_ - Pattern: `^[a-z0-9][a-z0-9-]{0,61}$` - Type: string +_Appears in:_ +- [ConnectorSpec](#connectorspec) +- [ProxyGroupSpec](#proxygroupspec) + + + +#### KubeAPIServerConfig + + + +KubeAPIServerConfig contains configuration specific to the kube-apiserver ProxyGroup type. + + + _Appears in:_ - [ProxyGroupSpec](#proxygroupspec) +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `mode` _[APIServerProxyMode](#apiserverproxymode)_ | Mode to run the API server proxy in. Supported modes are auth and noauth.
In auth mode, requests from the tailnet proxied over to the Kubernetes
API server are additionally impersonated using the sender's tailnet identity.
If not specified, defaults to auth mode. | | Enum: [auth noauth]
Type: string
| +| `hostname` _string_ | Hostname is the hostname with which to expose the Kubernetes API server
proxies. Must be a valid DNS label no longer than 63 characters. If not
specified, the name of the ProxyGroup is used as the hostname. Must be
unique across the whole tailnet. | | Pattern: `^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$`
Type: string
| + + +#### LabelValue + +_Underlying type:_ _string_ + + + +_Validation:_ +- MaxLength: 63 +- Pattern: `^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$` +- Type: string + +_Appears in:_ +- [Labels](#labels) + + + +#### Labels + +_Underlying type:_ _[map[string]LabelValue](#map[string]labelvalue)_ + + + + + +_Appears in:_ +- [Pod](#pod) +- [ServiceMonitor](#servicemonitor) +- [StatefulSet](#statefulset) + #### Metrics @@ -291,7 +409,8 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9001/debug/metrics.
Defaults to false. | | | +| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9002/metrics.
A metrics Service named -metrics will also be created in the operator's namespace and will
serve the metrics at :9002/metrics.
In 1.78.x and 1.80.x, this field also serves as the default value for
.spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both
fields will independently default to false.
Defaults to false. | | | +| `serviceMonitor` _[ServiceMonitor](#servicemonitor)_ | Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics.
The ServiceMonitor will select the metrics Service that gets created when metrics are enabled.
The ingested metrics for each Service monitor will have labels to identify the proxy:
ts_proxy_type: ingress_service\|ingress_resource\|connector\|proxygroup
ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup)
ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped)
job: ts__[]_ | | | #### Name @@ -323,6 +442,9 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | | `image` _[NameserverImage](#nameserverimage)_ | Nameserver image. Defaults to tailscale/k8s-nameserver:unstable. | | | +| `service` _[NameserverService](#nameserverservice)_ | Service configuration. | | | +| `pod` _[NameserverPod](#nameserverpod)_ | Pod configuration. | | | +| `replicas` _integer_ | Replicas specifies how many Pods to create. Defaults to 1. | | Minimum: 0
| #### NameserverImage @@ -342,6 +464,38 @@ _Appears in:_ | `tag` _string_ | Tag defaults to unstable. | | | +#### NameserverPod + + + + + + + +_Appears in:_ +- [Nameserver](#nameserver) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | If specified, applies tolerations to the pods deployed by the DNSConfig resource. | | | + + +#### NameserverService + + + + + + + +_Appears in:_ +- [Nameserver](#nameserver) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `clusterIP` _string_ | ClusterIP sets the static IP of the service used by the nameserver. | | | + + #### NameserverStatus @@ -355,7 +509,24 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `ip` _string_ | IP is the ClusterIP of the Service fronting the deployed ts.net nameserver.
Currently you must manually update your cluster DNS config to add
this address as a stub nameserver for ts.net for cluster workloads to be
able to resolve MagicDNS names associated with egress or Ingress
proxies.
The IP address will change if you delete and recreate the DNSConfig. | | | +| `ip` _string_ | IP is the ClusterIP of the Service fronting the deployed ts.net nameserver.
Currently, you must manually update your cluster DNS config to add
this address as a stub nameserver for ts.net for cluster workloads to be
able to resolve MagicDNS names associated with egress or Ingress
proxies.
The IP address will change if you delete and recreate the DNSConfig. | | | + + +#### NodePortConfig + + + + + + + +_Appears in:_ +- [StaticEndpointsConfig](#staticendpointsconfig) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `ports` _[PortRange](#portrange) array_ | The port ranges from which the operator will select NodePorts for the Services.
You must ensure that firewall rules allow UDP ingress traffic for these ports
to the node's external IPs.
The ports must be in the range of service node ports for the cluster (default `30000-32767`).
See https://kubernetes.io/docs/concepts/services-networking/service/#type-nodeport. | | MinItems: 1
| +| `selector` _object (keys:string, values:string)_ | A selector which will be used to select the node's that will have their `ExternalIP`'s advertised
by the ProxyGroup as Static Endpoints. | | | #### Pod @@ -371,16 +542,40 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `labels` _object (keys:string, values:string)_ | Labels that will be added to the proxy Pod.
Any labels specified here will be merged with the default labels
applied to the Pod by the Tailscale Kubernetes operator.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | +| `labels` _[Labels](#labels)_ | Labels that will be added to the proxy Pod.
Any labels specified here will be merged with the default labels
applied to the Pod by the Tailscale Kubernetes operator.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | | `annotations` _object (keys:string, values:string)_ | Annotations that will be added to the proxy Pod.
Any annotations specified here will be merged with the default
annotations applied to the Pod by the Tailscale Kubernetes operator.
Annotations must be valid Kubernetes annotations.
https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set | | | | `affinity` _[Affinity](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#affinity-v1-core)_ | Proxy Pod's affinity rules.
By default, the Tailscale Kubernetes operator does not apply any affinity rules.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#affinity | | | | `tailscaleContainer` _[Container](#container)_ | Configuration for the proxy container running tailscale. | | | -| `tailscaleInitContainer` _[Container](#container)_ | Configuration for the proxy init container that enables forwarding. | | | +| `tailscaleInitContainer` _[Container](#container)_ | Configuration for the proxy init container that enables forwarding.
Not valid to apply to ProxyGroups of type "kube-apiserver". | | | | `securityContext` _[PodSecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#podsecuritycontext-v1-core)_ | Proxy Pod's security context.
By default Tailscale Kubernetes operator does not apply any Pod
security context.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context-2 | | | | `imagePullSecrets` _[LocalObjectReference](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#localobjectreference-v1-core) array_ | Proxy Pod's image pull Secrets.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#PodSpec | | | | `nodeName` _string_ | Proxy Pod's node name.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `nodeSelector` _object (keys:string, values:string)_ | Proxy Pod's node selector.
By default Tailscale Kubernetes operator does not apply any node
selector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Proxy Pod's tolerations.
By default Tailscale Kubernetes operator does not apply any
tolerations.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | +| `topologySpreadConstraints` _[TopologySpreadConstraint](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#topologyspreadconstraint-v1-core) array_ | Proxy Pod's topology spread constraints.
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ | | | +| `priorityClassName` _string_ | PriorityClassName for the proxy Pod.
By default Tailscale Kubernetes operator does not apply any priority class.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | +| `dnsPolicy` _[DNSPolicy](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#dnspolicy-v1-core)_ | DNSPolicy defines how DNS will be configured for the proxy Pod.
By default the Tailscale Kubernetes Operator does not set a DNS policy (uses cluster default).
https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-s-dns-policy | | Enum: [ClusterFirstWithHostNet ClusterFirst Default None]
| +| `dnsConfig` _[PodDNSConfig](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#poddnsconfig-v1-core)_ | DNSConfig defines DNS parameters for the proxy Pod in addition to those generated from DNSPolicy.
When DNSPolicy is set to "None", DNSConfig must be specified.
https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-dns-config | | | + + +#### PortRange + + + + + + + +_Appears in:_ +- [NodePortConfig](#nodeportconfig) +- [PortRanges](#portranges) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `port` _integer_ | port represents a port selected to be used. This is a required field. | | | +| `endPort` _integer_ | endPort indicates that the range of ports from port to endPort if set, inclusive,
should be used. This field cannot be defined if the port field is not defined.
The endPort must be either unset, or equal or greater than port. | | | + + #### ProxyClass @@ -449,6 +644,8 @@ _Appears in:_ | `statefulSet` _[StatefulSet](#statefulset)_ | Configuration parameters for the proxy's StatefulSet. Tailscale
Kubernetes operator deploys a StatefulSet for each of the user
configured proxies (Tailscale Ingress, Tailscale Service, Connector). | | | | `metrics` _[Metrics](#metrics)_ | Configuration for proxy metrics. Metrics are currently not supported
for egress proxies and for Ingress proxies that have been configured
with tailscale.com/experimental-forward-cluster-traffic-via-ingress
annotation. Note that the metrics are currently considered unstable
and will likely change in breaking ways in the future - we only
recommend that you use those for debugging purposes. | | | | `tailscale` _[TailscaleConfig](#tailscaleconfig)_ | TailscaleConfig contains options to configure the tailscale-specific
parameters of proxies. | | | +| `useLetsEncryptStagingEnvironment` _boolean_ | Set UseLetsEncryptStagingEnvironment to true to issue TLS
certificates for any HTTPS endpoints exposed to the tailnet from
LetsEncrypt's staging environment.
https://letsencrypt.org/docs/staging-environment/
This setting only affects Tailscale Ingress resources.
By default Ingress TLS certificates are issued from LetsEncrypt's
production environment.
Changing this setting true -> false, will result in any
existing certs being re-issued from the production environment.
Changing this setting false (default) -> true, when certs have already
been provisioned from production environment will NOT result in certs
being re-issued from the staging environment before they need to be
renewed. | | | +| `staticEndpoints` _[StaticEndpointsConfig](#staticendpointsconfig)_ | Configuration for 'static endpoints' on proxies in order to facilitate
direct connections from other devices on the tailnet.
See https://tailscale.com/kb/1445/kubernetes-operator-customization#static-endpoints. | | | #### ProxyClassStatus @@ -471,7 +668,23 @@ _Appears in:_ +ProxyGroup defines a set of Tailscale devices that will act as proxies. +Depending on spec.Type, it can be a group of egress, ingress, or kube-apiserver +proxies. In addition to running a highly available set of proxies, ingress +and egress ProxyGroups also allow for serving many annotated Services from a +single set of proxies to minimise resource consumption. + +For ingress and egress, use the tailscale.com/proxy-group annotation on a +Service to specify that the proxy should be implemented by a ProxyGroup +instead of a single dedicated proxy. + +More info: +* https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress +* https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress +For kube-apiserver, the ProxyGroup is a standalone resource. Use the +spec.kubeAPIServer field to configure options specific to the kube-apiserver +ProxyGroup type. @@ -522,11 +735,12 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `type` _[ProxyGroupType](#proxygrouptype)_ | Type of the ProxyGroup, either ingress or egress. Each set of proxies
managed by a single ProxyGroup definition operate as only ingress or
only egress proxies. | | Enum: [egress]
Type: string
| +| `type` _[ProxyGroupType](#proxygrouptype)_ | Type of the ProxyGroup proxies. Supported types are egress, ingress, and kube-apiserver.
Type is immutable once a ProxyGroup is created. | | Enum: [egress ingress kube-apiserver]
Type: string
| | `tags` _[Tags](#tags)_ | Tags that the Tailscale devices will be tagged with. Defaults to [tag:k8s].
If you specify custom tags here, make sure you also make the operator
an owner of these tags.
See https://tailscale.com/kb/1236/kubernetes-operator/#setting-up-the-kubernetes-operator.
Tags cannot be changed once a ProxyGroup device has been created.
Tag values must be in form ^tag:[a-zA-Z][a-zA-Z0-9-]*$. | | Pattern: `^tag:[a-zA-Z][a-zA-Z0-9-]*$`
Type: string
| -| `replicas` _integer_ | Replicas specifies how many replicas to create the StatefulSet with.
Defaults to 2. | | | +| `replicas` _integer_ | Replicas specifies how many replicas to create the StatefulSet with.
Defaults to 2. | | Minimum: 0
| | `hostnamePrefix` _[HostnamePrefix](#hostnameprefix)_ | HostnamePrefix is the hostname prefix to use for tailnet devices created
by the ProxyGroup. Each device will have the integer number from its
StatefulSet pod appended to this prefix to form the full hostname.
HostnamePrefix can contain lower case letters, numbers and dashes, it
must not start with a dash and must be between 1 and 62 characters long. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}$`
Type: string
| -| `proxyClass` _string_ | ProxyClass is the name of the ProxyClass custom resource that contains
configuration options that should be applied to the resources created
for this ProxyGroup. If unset, and no default ProxyClass is set, the
operator will create resources with the default configuration. | | | +| `proxyClass` _string_ | ProxyClass is the name of the ProxyClass custom resource that contains
configuration options that should be applied to the resources created
for this ProxyGroup. If unset, and there is no default ProxyClass
configured, the operator will create resources with the default
configuration. | | | +| `kubeAPIServer` _[KubeAPIServerConfig](#kubeapiserverconfig)_ | KubeAPIServer contains configuration specific to the kube-apiserver
ProxyGroup type. This field is only used when Type is set to "kube-apiserver". | | | #### ProxyGroupStatus @@ -542,8 +756,9 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#condition-v1-meta) array_ | List of status conditions to indicate the status of the ProxyGroup
resources. Known condition types are `ProxyGroupReady`. | | | +| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#condition-v1-meta) array_ | List of status conditions to indicate the status of the ProxyGroup
resources. Known condition types include `ProxyGroupReady` and
`ProxyGroupAvailable`.
* `ProxyGroupReady` indicates all ProxyGroup resources are reconciled and
all expected conditions are true.
* `ProxyGroupAvailable` indicates that at least one proxy is ready to
serve traffic.
For ProxyGroups of type kube-apiserver, there are two additional conditions:
* `KubeAPIServerProxyConfigured` indicates that at least one API server
proxy is configured and ready to serve traffic.
* `KubeAPIServerProxyValid` indicates that spec.kubeAPIServer config is
valid. | | | | `devices` _[TailnetDevice](#tailnetdevice) array_ | List of tailnet devices associated with the ProxyGroup StatefulSet. | | | +| `url` _string_ | URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if
any. Only applies to ProxyGroups of type kube-apiserver. | | | #### ProxyGroupType @@ -553,7 +768,7 @@ _Underlying type:_ _string_ _Validation:_ -- Enum: [egress] +- Enum: [egress ingress kube-apiserver] - Type: string _Appears in:_ @@ -565,7 +780,11 @@ _Appears in:_ +Recorder defines a tsrecorder device for recording SSH sessions. By default, +it will store recordings in a local ephemeral volume. If you want to persist +recordings, you can configure an S3-compatible API for storage. +More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder @@ -644,6 +863,24 @@ _Appears in:_ | `imagePullSecrets` _[LocalObjectReference](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#localobjectreference-v1-core) array_ | Image pull Secrets for Recorder Pods.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#PodSpec | | | | `nodeSelector` _object (keys:string, values:string)_ | Node selector rules for Recorder Pods. By default, the operator does
not apply any node selector rules.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Tolerations for Recorder Pods. By default, the operator does not apply
any tolerations.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | +| `serviceAccount` _[RecorderServiceAccount](#recorderserviceaccount)_ | Config for the ServiceAccount to create for the Recorder's StatefulSet.
By default, the operator will create a ServiceAccount with the same
name as the Recorder resource.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account | | | + + +#### RecorderServiceAccount + + + + + + + +_Appears in:_ +- [RecorderPod](#recorderpod) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `name` _string_ | Name of the ServiceAccount to create. Defaults to the name of the
Recorder resource.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account | | MaxLength: 253
Pattern: `^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$`
Type: string
| +| `annotations` _object (keys:string, values:string)_ | Annotations to add to the ServiceAccount.
https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set
You can use this to add IAM roles to the ServiceAccount (IRSA) instead of
providing static S3 credentials in a Secret.
https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
For example:
eks.amazonaws.com/role-arn: arn:aws:iam:::role/ | | | #### RecorderSpec @@ -745,6 +982,7 @@ _Validation:_ - Type: string _Appears in:_ +- [AppConnector](#appconnector) - [SubnetRouter](#subnetrouter) @@ -799,6 +1037,23 @@ _Appears in:_ | `name` _string_ | The name of a Kubernetes Secret in the operator's namespace that contains
credentials for writing to the configured bucket. Each key-value pair
from the secret's data will be mounted as an environment variable. It
should include keys for AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY if
using a static access key. | | | +#### ServiceMonitor + + + + + + + +_Appears in:_ +- [Metrics](#metrics) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enable` _boolean_ | If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. | | | +| `labels` _[Labels](#labels)_ | Labels to add to the ServiceMonitor.
Labels must be valid Kubernetes labels.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | + + #### StatefulSet @@ -812,11 +1067,27 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `labels` _object (keys:string, values:string)_ | Labels that will be added to the StatefulSet created for the proxy.
Any labels specified here will be merged with the default labels
applied to the StatefulSet by the Tailscale Kubernetes operator as
well as any other labels that might have been applied by other
actors.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | +| `labels` _[Labels](#labels)_ | Labels that will be added to the StatefulSet created for the proxy.
Any labels specified here will be merged with the default labels
applied to the StatefulSet by the Tailscale Kubernetes operator as
well as any other labels that might have been applied by other
actors.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | | `annotations` _object (keys:string, values:string)_ | Annotations that will be added to the StatefulSet created for the proxy.
Any Annotations specified here will be merged with the default annotations
applied to the StatefulSet by the Tailscale Kubernetes operator as
well as any other annotations that might have been applied by other
actors.
Annotations must be valid Kubernetes annotations.
https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set | | | | `pod` _[Pod](#pod)_ | Configuration for the proxy Pod. | | | +#### StaticEndpointsConfig + + + + + + + +_Appears in:_ +- [ProxyClassSpec](#proxyclassspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `nodePort` _[NodePortConfig](#nodeportconfig)_ | The configuration for static endpoints using NodePort Services. | | | + + #### Storage @@ -897,6 +1168,7 @@ _Appears in:_ | --- | --- | --- | --- | | `hostname` _string_ | Hostname is the fully qualified domain name of the device.
If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the
node. | | | | `tailnetIPs` _string array_ | TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6)
assigned to the device. | | | +| `staticEndpoints` _string array_ | StaticEndpoints are user configured, 'static' endpoints by which tailnet peers can reach this device. | | | #### TailscaleConfig diff --git a/k8s-operator/apis/v1alpha1/register.go b/k8s-operator/apis/v1alpha1/register.go index 70b411d12..0880ac975 100644 --- a/k8s-operator/apis/v1alpha1/register.go +++ b/k8s-operator/apis/v1alpha1/register.go @@ -10,6 +10,7 @@ import ( "tailscale.com/k8s-operator/apis" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -39,12 +40,18 @@ func init() { localSchemeBuilder.Register(addKnownTypes) GlobalScheme = runtime.NewScheme() + // Add core types if err := scheme.AddToScheme(GlobalScheme); err != nil { panic(fmt.Sprintf("failed to add k8s.io scheme: %s", err)) } + // Add tailscale.com types if err := AddToScheme(GlobalScheme); err != nil { panic(fmt.Sprintf("failed to add tailscale.com scheme: %s", err)) } + // Add apiextensions types (CustomResourceDefinitions/CustomResourceDefinitionLists) + if err := apiextensionsv1.AddToScheme(GlobalScheme); err != nil { + panic(fmt.Sprintf("failed to add apiextensions.k8s.io scheme: %s", err)) + } } // Adds the list of known types to api.Scheme. diff --git a/k8s-operator/apis/v1alpha1/types_connector.go b/k8s-operator/apis/v1alpha1/types_connector.go index 175d62eea..58457500f 100644 --- a/k8s-operator/apis/v1alpha1/types_connector.go +++ b/k8s-operator/apis/v1alpha1/types_connector.go @@ -22,7 +22,9 @@ var ConnectorKind = "Connector" // +kubebuilder:resource:scope=Cluster,shortName=cn // +kubebuilder:printcolumn:name="SubnetRoutes",type="string",JSONPath=`.status.subnetRoutes`,description="CIDR ranges exposed to tailnet by a subnet router defined via this Connector instance." // +kubebuilder:printcolumn:name="IsExitNode",type="string",JSONPath=`.status.isExitNode`,description="Whether this Connector instance defines an exit node." +// +kubebuilder:printcolumn:name="IsAppConnector",type="string",JSONPath=`.status.isAppConnector`,description="Whether this Connector instance is an app connector." // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ConnectorReady")].reason`,description="Status of the deployed Connector resources." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // Connector defines a Tailscale node that will be deployed in the cluster. The // node can be configured to act as a Tailscale subnet router and/or a Tailscale @@ -55,7 +57,10 @@ type ConnectorList struct { } // ConnectorSpec describes a Tailscale node to be deployed in the cluster. -// +kubebuilder:validation:XValidation:rule="has(self.subnetRouter) || self.exitNode == true",message="A Connector needs to be either an exit node or a subnet router, or both." +// +kubebuilder:validation:XValidation:rule="has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector)",message="A Connector needs to have at least one of exit node, subnet router or app connector configured." +// +kubebuilder:validation:XValidation:rule="!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))",message="The appConnector field is mutually exclusive with exitNode and subnetRouter fields." +// +kubebuilder:validation:XValidation:rule="!(has(self.hostname) && has(self.replicas) && self.replicas > 1)",message="The hostname field cannot be specified when replicas is greater than 1." +// +kubebuilder:validation:XValidation:rule="!(has(self.hostname) && has(self.hostnamePrefix))",message="The hostname and hostnamePrefix fields are mutually exclusive." type ConnectorSpec struct { // Tags that the Tailscale node will be tagged with. // Defaults to [tag:k8s]. @@ -73,25 +78,61 @@ type ConnectorSpec struct { // Connector node. If unset, hostname defaults to -connector. Hostname can contain lower case letters, numbers and // dashes, it must not start or end with a dash and must be between 2 - // and 63 characters long. + // and 63 characters long. This field should only be used when creating a connector + // with an unspecified number of replicas, or a single replica. // +optional Hostname Hostname `json:"hostname,omitempty"` + + // HostnamePrefix specifies the hostname prefix for each + // replica. Each device will have the integer number + // from its StatefulSet pod appended to this prefix to form the full hostname. + // HostnamePrefix can contain lower case letters, numbers and dashes, it + // must not start with a dash and must be between 1 and 62 characters long. + // +optional + HostnamePrefix HostnamePrefix `json:"hostnamePrefix,omitempty"` + // ProxyClass is the name of the ProxyClass custom resource that // contains configuration options that should be applied to the // resources created for this Connector. If unset, the operator will // create resources with the default configuration. // +optional ProxyClass string `json:"proxyClass,omitempty"` - // SubnetRouter defines subnet routes that the Connector node should - // expose to tailnet. If unset, none are exposed. + // SubnetRouter defines subnet routes that the Connector device should + // expose to tailnet as a Tailscale subnet router. // https://tailscale.com/kb/1019/subnets/ + // If this field is unset, the device does not get configured as a Tailscale subnet router. + // This field is mutually exclusive with the appConnector field. // +optional - SubnetRouter *SubnetRouter `json:"subnetRouter"` - // ExitNode defines whether the Connector node should act as a - // Tailscale exit node. Defaults to false. + SubnetRouter *SubnetRouter `json:"subnetRouter,omitempty"` + // AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + // configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + // Connector does not act as an app connector. + // Note that you will need to manually configure the permissions and the domains for the app connector via the + // Admin panel. + // Note also that the main tested and supported use case of this config option is to deploy an app connector on + // Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + // cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + // tested or optimised for. + // If you are using the app connector to access SaaS applications because you need a predictable egress IP that + // can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + // via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + // device with a static IP address. + // https://tailscale.com/kb/1281/app-connectors + // +optional + AppConnector *AppConnector `json:"appConnector,omitempty"` + + // ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + // This field is mutually exclusive with the appConnector field. // https://tailscale.com/kb/1103/exit-nodes // +optional ExitNode bool `json:"exitNode"` + + // Replicas specifies how many devices to create. Set this to enable + // high availability for app connectors, subnet routers, or exit nodes. + // https://tailscale.com/kb/1115/high-availability. Defaults to 1. + // +optional + // +kubebuilder:validation:Minimum=0 + Replicas *int32 `json:"replicas,omitempty"` } // SubnetRouter defines subnet routes that should be exposed to tailnet via a @@ -104,6 +145,17 @@ type SubnetRouter struct { AdvertiseRoutes Routes `json:"advertiseRoutes"` } +// AppConnector defines a Tailscale app connector node configured via Connector. +type AppConnector struct { + // Routes are optional preconfigured routes for the domains routed via the app connector. + // If not set, routes for the domains will be discovered dynamically. + // If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + // also dynamically discover other routes. + // https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + // +optional + Routes Routes `json:"routes"` +} + type Tags []Tag func (tags Tags) Stringify() []string { @@ -156,30 +208,62 @@ type ConnectorStatus struct { // IsExitNode is set to true if the Connector acts as an exit node. // +optional IsExitNode bool `json:"isExitNode"` + // IsAppConnector is set to true if the Connector acts as an app connector. + // +optional + IsAppConnector bool `json:"isAppConnector"` // TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) // assigned to the Connector node. // +optional TailnetIPs []string `json:"tailnetIPs,omitempty"` // Hostname is the fully qualified domain name of the Connector node. // If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the - // node. + // node. When using multiple replicas, this field will be populated with the + // first replica's hostname. Use the Hostnames field for the full list + // of hostnames. // +optional Hostname string `json:"hostname,omitempty"` + // Devices contains information on each device managed by the Connector resource. + // +optional + Devices []ConnectorDevice `json:"devices"` +} + +type ConnectorDevice struct { + // Hostname is the fully qualified domain name of the Connector replica. + // If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the + // node. + // +optional + Hostname string `json:"hostname"` + // TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) + // assigned to the Connector replica. + // +optional + TailnetIPs []string `json:"tailnetIPs,omitempty"` } type ConditionType string const ( - ConnectorReady ConditionType = `ConnectorReady` - ProxyClassready ConditionType = `ProxyClassReady` - ProxyGroupReady ConditionType = `ProxyGroupReady` - ProxyReady ConditionType = `TailscaleProxyReady` // a Tailscale-specific condition type for corev1.Service - RecorderReady ConditionType = `RecorderReady` - // EgressSvcValid is set to true if the user configured ExternalName Service for exposing a tailnet target on - // ProxyGroup nodes is valid. - EgressSvcValid ConditionType = `EgressSvcValid` - // EgressSvcConfigured is set to true if the configuration for the egress Service (proxy ConfigMap update, - // EndpointSlice for the Service) has been successfully applied. The Reason for this condition - // contains the name of the ProxyGroup and the hash of the Service ports and the tailnet target. - EgressSvcConfigured ConditionType = `EgressSvcConfigured` + ConnectorReady ConditionType = `ConnectorReady` + ProxyClassReady ConditionType = `ProxyClassReady` + ProxyGroupReady ConditionType = `ProxyGroupReady` // All proxy Pods running. + ProxyGroupAvailable ConditionType = `ProxyGroupAvailable` // At least one proxy Pod running. + ProxyReady ConditionType = `TailscaleProxyReady` // a Tailscale-specific condition type for corev1.Service + RecorderReady ConditionType = `RecorderReady` + // EgressSvcValid gets set on a user configured ExternalName Service that defines a tailnet target to be exposed + // on a ProxyGroup. + // Set to true if the user provided configuration is valid. + EgressSvcValid ConditionType = `TailscaleEgressSvcValid` + // EgressSvcConfigured gets set on a user configured ExternalName Service that defines a tailnet target to be exposed + // on a ProxyGroup. + // Set to true if the cluster resources for the service have been successfully configured. + EgressSvcConfigured ConditionType = `TailscaleEgressSvcConfigured` + // EgressSvcReady gets set on a user configured ExternalName Service that defines a tailnet target to be exposed + // on a ProxyGroup. + // Set to true if the service is ready to route cluster traffic. + EgressSvcReady ConditionType = `TailscaleEgressSvcReady` + + IngressSvcValid ConditionType = `TailscaleIngressSvcValid` + IngressSvcConfigured ConditionType = `TailscaleIngressSvcConfigured` + + KubeAPIServerProxyValid ConditionType = `KubeAPIServerProxyValid` // The kubeAPIServer config for the ProxyGroup is valid. + KubeAPIServerProxyConfigured ConditionType = `KubeAPIServerProxyConfigured` // At least one of the ProxyGroup's Pods is advertising the kube-apiserver proxy's hostname. ) diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 7f415bc34..670df3b95 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -6,6 +6,10 @@ package v1alpha1 import ( + "fmt" + "iter" + "strings" + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -16,6 +20,7 @@ var ProxyClassKind = "ProxyClass" // +kubebuilder:subresource:status // +kubebuilder:resource:scope=Cluster // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ProxyClassReady")].reason`,description="Status of the ProxyClass." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // ProxyClass describes a set of configuration parameters that can be applied to // proxy resources created by the Tailscale Kubernetes operator. @@ -66,6 +71,139 @@ type ProxyClassSpec struct { // parameters of proxies. // +optional TailscaleConfig *TailscaleConfig `json:"tailscale,omitempty"` + // Set UseLetsEncryptStagingEnvironment to true to issue TLS + // certificates for any HTTPS endpoints exposed to the tailnet from + // LetsEncrypt's staging environment. + // https://letsencrypt.org/docs/staging-environment/ + // This setting only affects Tailscale Ingress resources. + // By default Ingress TLS certificates are issued from LetsEncrypt's + // production environment. + // Changing this setting true -> false, will result in any + // existing certs being re-issued from the production environment. + // Changing this setting false (default) -> true, when certs have already + // been provisioned from production environment will NOT result in certs + // being re-issued from the staging environment before they need to be + // renewed. + // +optional + UseLetsEncryptStagingEnvironment bool `json:"useLetsEncryptStagingEnvironment,omitempty"` + // Configuration for 'static endpoints' on proxies in order to facilitate + // direct connections from other devices on the tailnet. + // See https://tailscale.com/kb/1445/kubernetes-operator-customization#static-endpoints. + // +optional + StaticEndpoints *StaticEndpointsConfig `json:"staticEndpoints,omitempty"` +} + +type StaticEndpointsConfig struct { + // The configuration for static endpoints using NodePort Services. + NodePort *NodePortConfig `json:"nodePort"` +} + +type NodePortConfig struct { + // The port ranges from which the operator will select NodePorts for the Services. + // You must ensure that firewall rules allow UDP ingress traffic for these ports + // to the node's external IPs. + // The ports must be in the range of service node ports for the cluster (default `30000-32767`). + // See https://kubernetes.io/docs/concepts/services-networking/service/#type-nodeport. + // +kubebuilder:validation:MinItems=1 + Ports []PortRange `json:"ports"` + // A selector which will be used to select the node's that will have their `ExternalIP`'s advertised + // by the ProxyGroup as Static Endpoints. + Selector map[string]string `json:"selector,omitempty"` +} + +// PortRanges is a list of PortRange(s) +type PortRanges []PortRange + +func (prs PortRanges) String() string { + var prStrings []string + + for _, pr := range prs { + prStrings = append(prStrings, pr.String()) + } + + return strings.Join(prStrings, ", ") +} + +// All allows us to iterate over all the ports in the PortRanges +func (prs PortRanges) All() iter.Seq[uint16] { + return func(yield func(uint16) bool) { + for _, pr := range prs { + end := pr.EndPort + if end == 0 { + end = pr.Port + } + + for port := pr.Port; port <= end; port++ { + if !yield(port) { + return + } + } + } + } +} + +// Contains reports whether port is in any of the PortRanges. +func (prs PortRanges) Contains(port uint16) bool { + for _, r := range prs { + if r.Contains(port) { + return true + } + } + + return false +} + +// ClashesWith reports whether the supplied PortRange clashes with any of the PortRanges. +func (prs PortRanges) ClashesWith(pr PortRange) bool { + for p := range prs.All() { + if pr.Contains(p) { + return true + } + } + + return false +} + +type PortRange struct { + // port represents a port selected to be used. This is a required field. + Port uint16 `json:"port"` + + // endPort indicates that the range of ports from port to endPort if set, inclusive, + // should be used. This field cannot be defined if the port field is not defined. + // The endPort must be either unset, or equal or greater than port. + // +optional + EndPort uint16 `json:"endPort,omitempty"` +} + +// Contains reports whether port is in pr. +func (pr PortRange) Contains(port uint16) bool { + switch pr.EndPort { + case 0: + return port == pr.Port + default: + return port >= pr.Port && port <= pr.EndPort + } +} + +// String returns the PortRange in a string form. +func (pr PortRange) String() string { + if pr.EndPort == 0 { + return fmt.Sprintf("%d", pr.Port) + } + + return fmt.Sprintf("%d-%d", pr.Port, pr.EndPort) +} + +// IsValid reports whether the port range is valid. +func (pr PortRange) IsValid() bool { + if pr.Port == 0 { + return false + } + if pr.EndPort == 0 { + return true + } + + return pr.Port <= pr.EndPort } type TailscaleConfig struct { @@ -87,7 +225,7 @@ type StatefulSet struct { // Label keys and values must be valid Kubernetes label keys and values. // https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set // +optional - Labels map[string]string `json:"labels,omitempty"` + Labels Labels `json:"labels,omitempty"` // Annotations that will be added to the StatefulSet created for the proxy. // Any Annotations specified here will be merged with the default annotations // applied to the StatefulSet by the Tailscale Kubernetes operator as @@ -109,7 +247,7 @@ type Pod struct { // Label keys and values must be valid Kubernetes label keys and values. // https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set // +optional - Labels map[string]string `json:"labels,omitempty"` + Labels Labels `json:"labels,omitempty"` // Annotations that will be added to the proxy Pod. // Any annotations specified here will be merged with the default // annotations applied to the Pod by the Tailscale Kubernetes operator. @@ -126,6 +264,7 @@ type Pod struct { // +optional TailscaleContainer *Container `json:"tailscaleContainer,omitempty"` // Configuration for the proxy init container that enables forwarding. + // Not valid to apply to ProxyGroups of type "kube-apiserver". // +optional TailscaleInitContainer *Container `json:"tailscaleInitContainer,omitempty"` // Proxy Pod's security context. @@ -154,16 +293,84 @@ type Pod struct { // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + // Proxy Pod's topology spread constraints. + // By default Tailscale Kubernetes operator does not apply any topology spread constraints. + // https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + // +optional + TopologySpreadConstraints []corev1.TopologySpreadConstraint `json:"topologySpreadConstraints,omitempty"` + // PriorityClassName for the proxy Pod. + // By default Tailscale Kubernetes operator does not apply any priority class. + // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling + // +optional + PriorityClassName string `json:"priorityClassName,omitempty"` + // DNSPolicy defines how DNS will be configured for the proxy Pod. + // By default the Tailscale Kubernetes Operator does not set a DNS policy (uses cluster default). + // https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-s-dns-policy + // +kubebuilder:validation:Enum=ClusterFirstWithHostNet;ClusterFirst;Default;None // +optional + DNSPolicy *corev1.DNSPolicy `json:"dnsPolicy,omitempty"` + // DNSConfig defines DNS parameters for the proxy Pod in addition to those generated from DNSPolicy. + // When DNSPolicy is set to "None", DNSConfig must be specified. + // https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#pod-dns-config + // +optional + DNSConfig *corev1.PodDNSConfig `json:"dnsConfig,omitempty"` } +// +kubebuilder:validation:XValidation:rule="!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)",message="ServiceMonitor can only be enabled if metrics are enabled" type Metrics struct { // Setting enable to true will make the proxy serve Tailscale metrics - // at :9001/debug/metrics. + // at :9002/metrics. + // A metrics Service named -metrics will also be created in the operator's namespace and will + // serve the metrics at :9002/metrics. + // + // In 1.78.x and 1.80.x, this field also serves as the default value for + // .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + // fields will independently default to false. + // // Defaults to false. Enable bool `json:"enable"` + // Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + // The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + // The ingested metrics for each Service monitor will have labels to identify the proxy: + // ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + // ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + // ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + // job: ts__[]_ + // +optional + ServiceMonitor *ServiceMonitor `json:"serviceMonitor"` } +type ServiceMonitor struct { + // If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + Enable bool `json:"enable"` + // Labels to add to the ServiceMonitor. + // Labels must be valid Kubernetes labels. + // https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + // +optional + Labels Labels `json:"labels"` +} + +type Labels map[string]LabelValue + +func (lb Labels) Parse() map[string]string { + if lb == nil { + return nil + } + m := make(map[string]string, len(lb)) + for k, v := range lb { + m[k] = string(v) + } + return m +} + +// We do not validate the values of the label keys here - it is done by the ProxyClass +// reconciler because the validation rules are too complex for a CRD validation markers regex. + +// +kubebuilder:validation:Type=string +// +kubebuilder:validation:Pattern=`^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$` +// +kubebuilder:validation:MaxLength=63 +type LabelValue string + type Container struct { // List of environment variables to set in the container. // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#environment-variables @@ -174,12 +381,21 @@ type Container struct { // the future. // +optional Env []Env `json:"env,omitempty"` - // Container image name. By default images are pulled from - // docker.io/tailscale/tailscale, but the official images are also - // available at ghcr.io/tailscale/tailscale. Specifying image name here - // will override any proxy image values specified via the Kubernetes - // operator's Helm chart values or PROXY_IMAGE env var in the operator - // Deployment. + // Container image name. By default images are pulled from docker.io/tailscale, + // but the official images are also available at ghcr.io/tailscale. + // + // For all uses except on ProxyGroups of type "kube-apiserver", this image must + // be either tailscale/tailscale, or an equivalent mirror of that image. + // To apply to ProxyGroups of type "kube-apiserver", this image must be + // tailscale/k8s-proxy or a mirror of that image. + // + // For "tailscale/tailscale"-based proxies, specifying image name here will + // override any proxy image values specified via the Kubernetes operator's + // Helm chart values or PROXY_IMAGE env var in the operator Deployment. + // For "tailscale/k8s-proxy"-based proxies, there is currently no way to + // configure your own default, and this field is the only way to use a + // custom image. + // // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image // +optional Image string `json:"image,omitempty"` @@ -197,14 +413,35 @@ type Container struct { // +optional Resources corev1.ResourceRequirements `json:"resources,omitempty"` // Container security context. - // Security context specified here will override the security context by the operator. - // By default the operator: - // - sets 'privileged: true' for the init container - // - set NET_ADMIN capability for tailscale container for proxies that - // are created for Services or Connector. + // Security context specified here will override the security context set by the operator. + // By default the operator sets the Tailscale container and the Tailscale init container to privileged + // for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + // You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + // installing device plugin in your cluster and configuring the proxies tun device to be created + // by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context // +optional SecurityContext *corev1.SecurityContext `json:"securityContext,omitempty"` + // Configuration for enabling extra debug information in the container. + // Not recommended for production use. + // +optional + Debug *Debug `json:"debug,omitempty"` +} + +type Debug struct { + // Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + // and internal debug metrics endpoint at :9001/debug/metrics, where + // 9001 is a container port named "debug". The endpoints and their responses + // may change in backwards incompatible ways in the future, and should not + // be considered stable. + // + // In 1.78.x and 1.80.x, this setting will default to the value of + // .spec.metrics.enable, and requests to the "metrics" port matching the + // mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + // this setting will default to false, and no requests will be proxied. + // + // +optional + Enable bool `json:"enable"` } type Env struct { diff --git a/k8s-operator/apis/v1alpha1/types_proxygroup.go b/k8s-operator/apis/v1alpha1/types_proxygroup.go index 92912a779..28fd9e009 100644 --- a/k8s-operator/apis/v1alpha1/types_proxygroup.go +++ b/k8s-operator/apis/v1alpha1/types_proxygroup.go @@ -13,7 +13,27 @@ import ( // +kubebuilder:subresource:status // +kubebuilder:resource:scope=Cluster,shortName=pg // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ProxyGroupReady")].reason`,description="Status of the deployed ProxyGroup resources." - +// +kubebuilder:printcolumn:name="URL",type="string",JSONPath=`.status.url`,description="URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if any. Only applies to ProxyGroups of type kube-apiserver." +// +kubebuilder:printcolumn:name="Type",type="string",JSONPath=`.spec.type`,description="ProxyGroup type." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" + +// ProxyGroup defines a set of Tailscale devices that will act as proxies. +// Depending on spec.Type, it can be a group of egress, ingress, or kube-apiserver +// proxies. In addition to running a highly available set of proxies, ingress +// and egress ProxyGroups also allow for serving many annotated Services from a +// single set of proxies to minimise resource consumption. +// +// For ingress and egress, use the tailscale.com/proxy-group annotation on a +// Service to specify that the proxy should be implemented by a ProxyGroup +// instead of a single dedicated proxy. +// +// More info: +// * https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress +// * https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress +// +// For kube-apiserver, the ProxyGroup is a standalone resource. Use the +// spec.kubeAPIServer field to configure options specific to the kube-apiserver +// ProxyGroup type. type ProxyGroup struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` @@ -37,9 +57,9 @@ type ProxyGroupList struct { } type ProxyGroupSpec struct { - // Type of the ProxyGroup, either ingress or egress. Each set of proxies - // managed by a single ProxyGroup definition operate as only ingress or - // only egress proxies. + // Type of the ProxyGroup proxies. Supported types are egress, ingress, and kube-apiserver. + // Type is immutable once a ProxyGroup is created. + // +kubebuilder:validation:XValidation:rule="self == oldSelf",message="ProxyGroup type is immutable" Type ProxyGroupType `json:"type"` // Tags that the Tailscale devices will be tagged with. Defaults to [tag:k8s]. @@ -54,7 +74,8 @@ type ProxyGroupSpec struct { // Replicas specifies how many replicas to create the StatefulSet with. // Defaults to 2. // +optional - Replicas *int `json:"replicas,omitempty"` + // +kubebuilder:validation:Minimum=0 + Replicas *int32 `json:"replicas,omitempty"` // HostnamePrefix is the hostname prefix to use for tailnet devices created // by the ProxyGroup. Each device will have the integer number from its @@ -66,15 +87,35 @@ type ProxyGroupSpec struct { // ProxyClass is the name of the ProxyClass custom resource that contains // configuration options that should be applied to the resources created - // for this ProxyGroup. If unset, and no default ProxyClass is set, the - // operator will create resources with the default configuration. + // for this ProxyGroup. If unset, and there is no default ProxyClass + // configured, the operator will create resources with the default + // configuration. // +optional ProxyClass string `json:"proxyClass,omitempty"` + + // KubeAPIServer contains configuration specific to the kube-apiserver + // ProxyGroup type. This field is only used when Type is set to "kube-apiserver". + // +optional + KubeAPIServer *KubeAPIServerConfig `json:"kubeAPIServer,omitempty"` } type ProxyGroupStatus struct { // List of status conditions to indicate the status of the ProxyGroup - // resources. Known condition types are `ProxyGroupReady`. + // resources. Known condition types include `ProxyGroupReady` and + // `ProxyGroupAvailable`. + // + // * `ProxyGroupReady` indicates all ProxyGroup resources are reconciled and + // all expected conditions are true. + // * `ProxyGroupAvailable` indicates that at least one proxy is ready to + // serve traffic. + // + // For ProxyGroups of type kube-apiserver, there are two additional conditions: + // + // * `KubeAPIServerProxyConfigured` indicates that at least one API server + // proxy is configured and ready to serve traffic. + // * `KubeAPIServerProxyValid` indicates that spec.kubeAPIServer config is + // valid. + // // +listType=map // +listMapKey=type // +optional @@ -85,6 +126,11 @@ type ProxyGroupStatus struct { // +listMapKey=hostname // +optional Devices []TailnetDevice `json:"devices,omitempty"` + + // URL of the kube-apiserver proxy advertised by the ProxyGroup devices, if + // any. Only applies to ProxyGroups of type kube-apiserver. + // +optional + URL string `json:"url,omitempty"` } type TailnetDevice struct { @@ -97,16 +143,50 @@ type TailnetDevice struct { // assigned to the device. // +optional TailnetIPs []string `json:"tailnetIPs,omitempty"` + + // StaticEndpoints are user configured, 'static' endpoints by which tailnet peers can reach this device. + // +optional + StaticEndpoints []string `json:"staticEndpoints,omitempty"` } // +kubebuilder:validation:Type=string -// +kubebuilder:validation:Enum=egress +// +kubebuilder:validation:Enum=egress;ingress;kube-apiserver type ProxyGroupType string const ( - ProxyGroupTypeEgress ProxyGroupType = "egress" + ProxyGroupTypeEgress ProxyGroupType = "egress" + ProxyGroupTypeIngress ProxyGroupType = "ingress" + ProxyGroupTypeKubernetesAPIServer ProxyGroupType = "kube-apiserver" +) + +// +kubebuilder:validation:Type=string +// +kubebuilder:validation:Enum=auth;noauth +type APIServerProxyMode string + +const ( + APIServerProxyModeAuth APIServerProxyMode = "auth" + APIServerProxyModeNoAuth APIServerProxyMode = "noauth" ) // +kubebuilder:validation:Type=string // +kubebuilder:validation:Pattern=`^[a-z0-9][a-z0-9-]{0,61}$` type HostnamePrefix string + +// KubeAPIServerConfig contains configuration specific to the kube-apiserver ProxyGroup type. +type KubeAPIServerConfig struct { + // Mode to run the API server proxy in. Supported modes are auth and noauth. + // In auth mode, requests from the tailnet proxied over to the Kubernetes + // API server are additionally impersonated using the sender's tailnet identity. + // If not specified, defaults to auth mode. + // +optional + Mode *APIServerProxyMode `json:"mode,omitempty"` + + // Hostname is the hostname with which to expose the Kubernetes API server + // proxies. Must be a valid DNS label no longer than 63 characters. If not + // specified, the name of the ProxyGroup is used as the hostname. Must be + // unique across the whole tailnet. + // +kubebuilder:validation:Type=string + // +kubebuilder:validation:Pattern=`^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?$` + // +optional + Hostname string `json:"hostname,omitempty"` +} diff --git a/k8s-operator/apis/v1alpha1/types_recorder.go b/k8s-operator/apis/v1alpha1/types_recorder.go index 3728154b4..16a610b26 100644 --- a/k8s-operator/apis/v1alpha1/types_recorder.go +++ b/k8s-operator/apis/v1alpha1/types_recorder.go @@ -15,7 +15,13 @@ import ( // +kubebuilder:resource:scope=Cluster,shortName=rec // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "RecorderReady")].reason`,description="Status of the deployed Recorder resources." // +kubebuilder:printcolumn:name="URL",type="string",JSONPath=`.status.devices[?(@.url != "")].url`,description="URL on which the UI is exposed if enabled." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" +// Recorder defines a tsrecorder device for recording SSH sessions. By default, +// it will store recordings in a local ephemeral volume. If you want to persist +// recordings, you can configure an S3-compatible API for storage. +// +// More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder type Recorder struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` @@ -136,6 +142,36 @@ type RecorderPod struct { // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + + // Config for the ServiceAccount to create for the Recorder's StatefulSet. + // By default, the operator will create a ServiceAccount with the same + // name as the Recorder resource. + // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + // +optional + ServiceAccount RecorderServiceAccount `json:"serviceAccount,omitempty"` +} + +type RecorderServiceAccount struct { + // Name of the ServiceAccount to create. Defaults to the name of the + // Recorder resource. + // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + // +kubebuilder:validation:Type=string + // +kubebuilder:validation:Pattern=`^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$` + // +kubebuilder:validation:MaxLength=253 + // +optional + Name string `json:"name,omitempty"` + + // Annotations to add to the ServiceAccount. + // https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set + // + // You can use this to add IAM roles to the ServiceAccount (IRSA) instead of + // providing static S3 credentials in a Secret. + // https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html + // + // For example: + // eks.amazonaws.com/role-arn: arn:aws:iam:::role/ + // +optional + Annotations map[string]string `json:"annotations,omitempty"` } type RecorderContainer struct { diff --git a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go index 60d212279..7991003b8 100644 --- a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go +++ b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go @@ -6,6 +6,7 @@ package v1alpha1 import ( + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -18,6 +19,7 @@ var DNSConfigKind = "DNSConfig" // +kubebuilder:subresource:status // +kubebuilder:resource:scope=Cluster,shortName=dc // +kubebuilder:printcolumn:name="NameserverIP",type="string",JSONPath=`.status.nameserver.ip`,description="Service IP address of the nameserver" +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // DNSConfig can be deployed to cluster to make a subset of Tailscale MagicDNS // names resolvable by cluster workloads. Use this if: A) you need to refer to @@ -44,7 +46,6 @@ var DNSConfigKind = "DNSConfig" // using its MagicDNS name, you must also annotate the Ingress resource with // tailscale.com/experimental-forward-cluster-traffic-via-ingress annotation to // ensure that the proxy created for the Ingress listens on its Pod IP address. -// NB: Clusters where Pods get assigned IPv6 addresses only are currently not supported. type DNSConfig struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` @@ -81,6 +82,16 @@ type Nameserver struct { // Nameserver image. Defaults to tailscale/k8s-nameserver:unstable. // +optional Image *NameserverImage `json:"image,omitempty"` + // Service configuration. + // +optional + Service *NameserverService `json:"service,omitempty"` + // Pod configuration. + // +optional + Pod *NameserverPod `json:"pod,omitempty"` + // Replicas specifies how many Pods to create. Defaults to 1. + // +optional + // +kubebuilder:validation:Minimum=0 + Replicas *int32 `json:"replicas,omitempty"` } type NameserverImage struct { @@ -92,6 +103,18 @@ type NameserverImage struct { Tag string `json:"tag,omitempty"` } +type NameserverService struct { + // ClusterIP sets the static IP of the service used by the nameserver. + // +optional + ClusterIP string `json:"clusterIP,omitempty"` +} + +type NameserverPod struct { + // If specified, applies tolerations to the pods deployed by the DNSConfig resource. + // +optional + Tolerations []corev1.Toleration `json:"tolerations,omitempty"` +} + type DNSConfigStatus struct { // +listType=map // +listMapKey=type @@ -104,7 +127,7 @@ type DNSConfigStatus struct { type NameserverStatus struct { // IP is the ClusterIP of the Service fronting the deployed ts.net nameserver. - // Currently you must manually update your cluster DNS config to add + // Currently, you must manually update your cluster DNS config to add // this address as a stub nameserver for ts.net for cluster workloads to be // able to resolve MagicDNS names associated with egress or Ingress // proxies. diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index b6b94ce3f..7492f1e54 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -13,6 +13,26 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *AppConnector) DeepCopyInto(out *AppConnector) { + *out = *in + if in.Routes != nil { + in, out := &in.Routes, &out.Routes + *out = make(Routes, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AppConnector. +func (in *AppConnector) DeepCopy() *AppConnector { + if in == nil { + return nil + } + out := new(AppConnector) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Connector) DeepCopyInto(out *Connector) { *out = *in @@ -40,6 +60,26 @@ func (in *Connector) DeepCopyObject() runtime.Object { return nil } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ConnectorDevice) DeepCopyInto(out *ConnectorDevice) { + *out = *in + if in.TailnetIPs != nil { + in, out := &in.TailnetIPs, &out.TailnetIPs + *out = make([]string, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ConnectorDevice. +func (in *ConnectorDevice) DeepCopy() *ConnectorDevice { + if in == nil { + return nil + } + out := new(ConnectorDevice) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ConnectorList) DeepCopyInto(out *ConnectorList) { *out = *in @@ -85,6 +125,16 @@ func (in *ConnectorSpec) DeepCopyInto(out *ConnectorSpec) { *out = new(SubnetRouter) (*in).DeepCopyInto(*out) } + if in.AppConnector != nil { + in, out := &in.AppConnector, &out.AppConnector + *out = new(AppConnector) + (*in).DeepCopyInto(*out) + } + if in.Replicas != nil { + in, out := &in.Replicas, &out.Replicas + *out = new(int32) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ConnectorSpec. @@ -112,6 +162,13 @@ func (in *ConnectorStatus) DeepCopyInto(out *ConnectorStatus) { *out = make([]string, len(*in)) copy(*out, *in) } + if in.Devices != nil { + in, out := &in.Devices, &out.Devices + *out = make([]ConnectorDevice, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ConnectorStatus. @@ -138,6 +195,11 @@ func (in *Container) DeepCopyInto(out *Container) { *out = new(corev1.SecurityContext) (*in).DeepCopyInto(*out) } + if in.Debug != nil { + in, out := &in.Debug, &out.Debug + *out = new(Debug) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Container. @@ -256,6 +318,21 @@ func (in *DNSConfigStatus) DeepCopy() *DNSConfigStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Debug) DeepCopyInto(out *Debug) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Debug. +func (in *Debug) DeepCopy() *Debug { + if in == nil { + return nil + } + out := new(Debug) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Env) DeepCopyInto(out *Env) { *out = *in @@ -271,9 +348,55 @@ func (in *Env) DeepCopy() *Env { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *KubeAPIServerConfig) DeepCopyInto(out *KubeAPIServerConfig) { + *out = *in + if in.Mode != nil { + in, out := &in.Mode, &out.Mode + *out = new(APIServerProxyMode) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new KubeAPIServerConfig. +func (in *KubeAPIServerConfig) DeepCopy() *KubeAPIServerConfig { + if in == nil { + return nil + } + out := new(KubeAPIServerConfig) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in Labels) DeepCopyInto(out *Labels) { + { + in := &in + *out = make(Labels, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Labels. +func (in Labels) DeepCopy() Labels { + if in == nil { + return nil + } + out := new(Labels) + in.DeepCopyInto(out) + return *out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Metrics) DeepCopyInto(out *Metrics) { *out = *in + if in.ServiceMonitor != nil { + in, out := &in.ServiceMonitor, &out.ServiceMonitor + *out = new(ServiceMonitor) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Metrics. @@ -294,6 +417,21 @@ func (in *Nameserver) DeepCopyInto(out *Nameserver) { *out = new(NameserverImage) **out = **in } + if in.Service != nil { + in, out := &in.Service, &out.Service + *out = new(NameserverService) + **out = **in + } + if in.Pod != nil { + in, out := &in.Pod, &out.Pod + *out = new(NameserverPod) + (*in).DeepCopyInto(*out) + } + if in.Replicas != nil { + in, out := &in.Replicas, &out.Replicas + *out = new(int32) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Nameserver. @@ -321,6 +459,43 @@ func (in *NameserverImage) DeepCopy() *NameserverImage { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *NameserverPod) DeepCopyInto(out *NameserverPod) { + *out = *in + if in.Tolerations != nil { + in, out := &in.Tolerations, &out.Tolerations + *out = make([]corev1.Toleration, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NameserverPod. +func (in *NameserverPod) DeepCopy() *NameserverPod { + if in == nil { + return nil + } + out := new(NameserverPod) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *NameserverService) DeepCopyInto(out *NameserverService) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NameserverService. +func (in *NameserverService) DeepCopy() *NameserverService { + if in == nil { + return nil + } + out := new(NameserverService) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *NameserverStatus) DeepCopyInto(out *NameserverStatus) { *out = *in @@ -336,12 +511,39 @@ func (in *NameserverStatus) DeepCopy() *NameserverStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *NodePortConfig) DeepCopyInto(out *NodePortConfig) { + *out = *in + if in.Ports != nil { + in, out := &in.Ports, &out.Ports + *out = make([]PortRange, len(*in)) + copy(*out, *in) + } + if in.Selector != nil { + in, out := &in.Selector, &out.Selector + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NodePortConfig. +func (in *NodePortConfig) DeepCopy() *NodePortConfig { + if in == nil { + return nil + } + out := new(NodePortConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Pod) DeepCopyInto(out *Pod) { *out = *in if in.Labels != nil { in, out := &in.Labels, &out.Labels - *out = make(map[string]string, len(*in)) + *out = make(Labels, len(*in)) for key, val := range *in { (*out)[key] = val } @@ -392,6 +594,23 @@ func (in *Pod) DeepCopyInto(out *Pod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.TopologySpreadConstraints != nil { + in, out := &in.TopologySpreadConstraints, &out.TopologySpreadConstraints + *out = make([]corev1.TopologySpreadConstraint, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.DNSPolicy != nil { + in, out := &in.DNSPolicy, &out.DNSPolicy + *out = new(corev1.DNSPolicy) + **out = **in + } + if in.DNSConfig != nil { + in, out := &in.DNSConfig, &out.DNSConfig + *out = new(corev1.PodDNSConfig) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Pod. @@ -404,6 +623,40 @@ func (in *Pod) DeepCopy() *Pod { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *PortRange) DeepCopyInto(out *PortRange) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PortRange. +func (in *PortRange) DeepCopy() *PortRange { + if in == nil { + return nil + } + out := new(PortRange) + in.DeepCopyInto(out) + return out +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in PortRanges) DeepCopyInto(out *PortRanges) { + { + in := &in + *out = make(PortRanges, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PortRanges. +func (in PortRanges) DeepCopy() PortRanges { + if in == nil { + return nil + } + out := new(PortRanges) + in.DeepCopyInto(out) + return *out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *ProxyClass) DeepCopyInto(out *ProxyClass) { *out = *in @@ -474,13 +727,18 @@ func (in *ProxyClassSpec) DeepCopyInto(out *ProxyClassSpec) { if in.Metrics != nil { in, out := &in.Metrics, &out.Metrics *out = new(Metrics) - **out = **in + (*in).DeepCopyInto(*out) } if in.TailscaleConfig != nil { in, out := &in.TailscaleConfig, &out.TailscaleConfig *out = new(TailscaleConfig) **out = **in } + if in.StaticEndpoints != nil { + in, out := &in.StaticEndpoints, &out.StaticEndpoints + *out = new(StaticEndpointsConfig) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProxyClassSpec. @@ -584,9 +842,14 @@ func (in *ProxyGroupSpec) DeepCopyInto(out *ProxyGroupSpec) { } if in.Replicas != nil { in, out := &in.Replicas, &out.Replicas - *out = new(int) + *out = new(int32) **out = **in } + if in.KubeAPIServer != nil { + in, out := &in.KubeAPIServer, &out.KubeAPIServer + *out = new(KubeAPIServerConfig) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ProxyGroupSpec. @@ -760,6 +1023,7 @@ func (in *RecorderPod) DeepCopyInto(out *RecorderPod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + in.ServiceAccount.DeepCopyInto(&out.ServiceAccount) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RecorderPod. @@ -772,6 +1036,28 @@ func (in *RecorderPod) DeepCopy() *RecorderPod { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RecorderServiceAccount) DeepCopyInto(out *RecorderServiceAccount) { + *out = *in + if in.Annotations != nil { + in, out := &in.Annotations, &out.Annotations + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RecorderServiceAccount. +func (in *RecorderServiceAccount) DeepCopy() *RecorderServiceAccount { + if in == nil { + return nil + } + out := new(RecorderServiceAccount) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *RecorderSpec) DeepCopyInto(out *RecorderSpec) { *out = *in @@ -939,12 +1225,34 @@ func (in *S3Secret) DeepCopy() *S3Secret { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ServiceMonitor) DeepCopyInto(out *ServiceMonitor) { + *out = *in + if in.Labels != nil { + in, out := &in.Labels, &out.Labels + *out = make(Labels, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceMonitor. +func (in *ServiceMonitor) DeepCopy() *ServiceMonitor { + if in == nil { + return nil + } + out := new(ServiceMonitor) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *StatefulSet) DeepCopyInto(out *StatefulSet) { *out = *in if in.Labels != nil { in, out := &in.Labels, &out.Labels - *out = make(map[string]string, len(*in)) + *out = make(Labels, len(*in)) for key, val := range *in { (*out)[key] = val } @@ -973,6 +1281,26 @@ func (in *StatefulSet) DeepCopy() *StatefulSet { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *StaticEndpointsConfig) DeepCopyInto(out *StaticEndpointsConfig) { + *out = *in + if in.NodePort != nil { + in, out := &in.NodePort, &out.NodePort + *out = new(NodePortConfig) + (*in).DeepCopyInto(*out) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StaticEndpointsConfig. +func (in *StaticEndpointsConfig) DeepCopy() *StaticEndpointsConfig { + if in == nil { + return nil + } + out := new(StaticEndpointsConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Storage) DeepCopyInto(out *Storage) { *out = *in @@ -1040,6 +1368,11 @@ func (in *TailnetDevice) DeepCopyInto(out *TailnetDevice) { *out = make([]string, len(*in)) copy(*out, *in) } + if in.StaticEndpoints != nil { + in, out := &in.StaticEndpoints, &out.StaticEndpoints + *out = make([]string, len(*in)) + copy(*out, *in) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TailnetDevice. diff --git a/k8s-operator/conditions.go b/k8s-operator/conditions.go index 2b4022c40..ae465a728 100644 --- a/k8s-operator/conditions.go +++ b/k8s-operator/conditions.go @@ -75,16 +75,6 @@ func RemoveServiceCondition(svc *corev1.Service, conditionType tsapi.ConditionTy }) } -func EgressServiceIsValidAndConfigured(svc *corev1.Service) bool { - for _, typ := range []tsapi.ConditionType{tsapi.EgressSvcValid, tsapi.EgressSvcConfigured} { - cond := GetServiceCondition(svc, typ) - if cond == nil || cond.Status != metav1.ConditionTrue { - return false - } - } - return true -} - // SetRecorderCondition ensures that Recorder status has a condition with the // given attributes. LastTransitionTime gets set every time condition's status // changes. @@ -93,6 +83,14 @@ func SetRecorderCondition(tsr *tsapi.Recorder, conditionType tsapi.ConditionType tsr.Status.Conditions = conds } +// SetProxyGroupCondition ensures that ProxyGroup status has a condition with the +// given attributes. LastTransitionTime gets set every time condition's status +// changes. +func SetProxyGroupCondition(pg *tsapi.ProxyGroup, conditionType tsapi.ConditionType, status metav1.ConditionStatus, reason, message string, gen int64, clock tstime.Clock, logger *zap.SugaredLogger) { + conds := updateCondition(pg.Status.Conditions, conditionType, status, reason, message, gen, clock, logger) + pg.Status.Conditions = conds +} + func updateCondition(conds []metav1.Condition, conditionType tsapi.ConditionType, status metav1.ConditionStatus, reason, message string, gen int64, clock tstime.Clock, logger *zap.SugaredLogger) []metav1.Condition { newCondition := metav1.Condition{ Type: string(conditionType), @@ -129,7 +127,7 @@ func updateCondition(conds []metav1.Condition, conditionType tsapi.ConditionType func ProxyClassIsReady(pc *tsapi.ProxyClass) bool { idx := xslices.IndexFunc(pc.Status.Conditions, func(cond metav1.Condition) bool { - return cond.Type == string(tsapi.ProxyClassready) + return cond.Type == string(tsapi.ProxyClassReady) }) if idx == -1 { return false @@ -139,14 +137,33 @@ func ProxyClassIsReady(pc *tsapi.ProxyClass) bool { } func ProxyGroupIsReady(pg *tsapi.ProxyGroup) bool { + cond := proxyGroupCondition(pg, tsapi.ProxyGroupReady) + return cond != nil && cond.Status == metav1.ConditionTrue && cond.ObservedGeneration == pg.Generation +} + +func ProxyGroupAvailable(pg *tsapi.ProxyGroup) bool { + cond := proxyGroupCondition(pg, tsapi.ProxyGroupAvailable) + return cond != nil && cond.Status == metav1.ConditionTrue +} + +func KubeAPIServerProxyValid(pg *tsapi.ProxyGroup) (valid bool, set bool) { + cond := proxyGroupCondition(pg, tsapi.KubeAPIServerProxyValid) + return cond != nil && cond.Status == metav1.ConditionTrue && cond.ObservedGeneration == pg.Generation, cond != nil +} + +func KubeAPIServerProxyConfigured(pg *tsapi.ProxyGroup) bool { + cond := proxyGroupCondition(pg, tsapi.KubeAPIServerProxyConfigured) + return cond != nil && cond.Status == metav1.ConditionTrue && cond.ObservedGeneration == pg.Generation +} + +func proxyGroupCondition(pg *tsapi.ProxyGroup, condType tsapi.ConditionType) *metav1.Condition { idx := xslices.IndexFunc(pg.Status.Conditions, func(cond metav1.Condition) bool { - return cond.Type == string(tsapi.ProxyGroupReady) + return cond.Type == string(condType) }) if idx == -1 { - return false + return nil } - cond := pg.Status.Conditions[idx] - return cond.Status == metav1.ConditionTrue && cond.ObservedGeneration == pg.Generation + return &pg.Status.Conditions[idx] } func DNSCfgIsReady(cfg *tsapi.DNSConfig) bool { @@ -159,3 +176,14 @@ func DNSCfgIsReady(cfg *tsapi.DNSConfig) bool { cond := cfg.Status.Conditions[idx] return cond.Status == metav1.ConditionTrue && cond.ObservedGeneration == cfg.Generation } + +func SvcIsReady(svc *corev1.Service) bool { + idx := xslices.IndexFunc(svc.Status.Conditions, func(cond metav1.Condition) bool { + return cond.Type == string(tsapi.ProxyReady) + }) + if idx == -1 { + return false + } + cond := svc.Status.Conditions[idx] + return cond.Status == metav1.ConditionTrue +} diff --git a/k8s-operator/sessionrecording/fakes/fakes.go b/k8s-operator/sessionrecording/fakes/fakes.go index 9eb1047e4..94853df19 100644 --- a/k8s-operator/sessionrecording/fakes/fakes.go +++ b/k8s-operator/sessionrecording/fakes/fakes.go @@ -10,13 +10,13 @@ package fakes import ( "bytes" "encoding/json" + "fmt" + "math/rand" "net" "sync" "testing" "time" - "math/rand" - "tailscale.com/sessionrecording" "tailscale.com/tstime" ) @@ -107,7 +107,13 @@ func CastLine(t *testing.T, p []byte, clock tstime.Clock) []byte { return append(j, '\n') } -func AsciinemaResizeMsg(t *testing.T, width, height int) []byte { +func AsciinemaCastResizeMsg(t *testing.T, width, height int) []byte { + msg := fmt.Sprintf(`[0,"r","%dx%d"]`, height, width) + + return append([]byte(msg), '\n') +} + +func AsciinemaCastHeaderMsg(t *testing.T, width, height int) []byte { t.Helper() ch := sessionrecording.CastHeader{ Width: width, diff --git a/k8s-operator/sessionrecording/hijacker.go b/k8s-operator/sessionrecording/hijacker.go index f8ef951d4..2d6c94710 100644 --- a/k8s-operator/sessionrecording/hijacker.go +++ b/k8s-operator/sessionrecording/hijacker.go @@ -4,13 +4,14 @@ //go:build !plan9 // Package sessionrecording contains functionality for recording Kubernetes API -// server proxy 'kubectl exec' sessions. +// server proxy 'kubectl exec/attach' sessions. package sessionrecording import ( "bufio" "bytes" "context" + "errors" "fmt" "io" "net" @@ -19,29 +20,34 @@ import ( "net/netip" "strings" - "github.com/pkg/errors" "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/spdy" "tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/k8s-operator/sessionrecording/ws" + "tailscale.com/net/netx" "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstime" "tailscale.com/util/clientmetric" - "tailscale.com/util/multierr" ) const ( - SPDYProtocol Protocol = "SPDY" - WSProtocol Protocol = "WebSocket" + SPDYProtocol Protocol = "SPDY" + WSProtocol Protocol = "WebSocket" + ExecSessionType SessionType = "exec" + AttachSessionType SessionType = "attach" ) // Protocol is the streaming protocol of the hijacked session. Supported // protocols are SPDY and WebSocket. type Protocol string +// SessionType is the type of session initiated with `kubectl` +// (`exec` or `attach`) +type SessionType string + var ( // CounterSessionRecordingsAttempted counts the number of session recording attempts. CounterSessionRecordingsAttempted = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_attempted") @@ -50,7 +56,7 @@ var ( counterSessionRecordingsUploaded = clientmetric.NewCounter("k8s_auth_proxy_session_recordings_uploaded") ) -func New(opts HijackerOpts) *Hijacker { +func NewHijacker(opts HijackerOpts) *Hijacker { return &Hijacker{ ts: opts.TS, req: opts.Req, @@ -62,25 +68,27 @@ func New(opts HijackerOpts) *Hijacker { failOpen: opts.FailOpen, proto: opts.Proto, log: opts.Log, + sessionType: opts.SessionType, connectToRecorder: sessionrecording.ConnectToRecorder, } } type HijackerOpts struct { - TS *tsnet.Server - Req *http.Request - W http.ResponseWriter - Who *apitype.WhoIsResponse - Addrs []netip.AddrPort - Log *zap.SugaredLogger - Pod string - Namespace string - FailOpen bool - Proto Protocol + TS *tsnet.Server + Req *http.Request + W http.ResponseWriter + Who *apitype.WhoIsResponse + Addrs []netip.AddrPort + Log *zap.SugaredLogger + Pod string + Namespace string + FailOpen bool + Proto Protocol + SessionType SessionType } // Hijacker implements [net/http.Hijacker] interface. -// It must be configured with an http request for a 'kubectl exec' session that +// It must be configured with an http request for a 'kubectl exec/attach' session that // needs to be recorded. It knows how to hijack the connection and configure for // the session contents to be sent to a tsrecorder instance. type Hijacker struct { @@ -89,12 +97,13 @@ type Hijacker struct { req *http.Request who *apitype.WhoIsResponse log *zap.SugaredLogger - pod string // pod being exec-d - ns string // namespace of the pod being exec-d + pod string // pod being exec/attach-d + ns string // namespace of the pod being exec/attach-d addrs []netip.AddrPort // tsrecorder addresses failOpen bool // whether to fail open if recording fails connectToRecorder RecorderDialFn - proto Protocol // streaming protocol + proto Protocol // streaming protocol + sessionType SessionType // subcommand, e.g., "exec, attach" } // RecorderDialFn dials the specified netip.AddrPorts that should be tsrecorder @@ -102,9 +111,9 @@ type Hijacker struct { // connection succeeds. In case of success, returns a list with a single // successful recording attempt and an error channel. If the connection errors // after having been established, an error is sent down the channel. -type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) +type RecorderDialFn func(context.Context, []netip.AddrPort, netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) -// Hijack hijacks a 'kubectl exec' session and configures for the session +// Hijack hijacks a 'kubectl exec/attach' session and configures for the session // contents to be sent to a recorder. func (h *Hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { h.log.Infof("recorder addrs: %v, failOpen: %v", h.addrs, h.failOpen) @@ -113,7 +122,7 @@ func (h *Hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, fmt.Errorf("error hijacking connection: %w", err) } - conn, err := h.setUpRecording(context.Background(), reqConn) + conn, err := h.setUpRecording(reqConn) if err != nil { return nil, nil, fmt.Errorf("error setting up session recording: %w", err) } @@ -124,7 +133,7 @@ func (h *Hijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { // spdyHijacker.addrs. Returns conn from provided opts, wrapped in recording // logic. If connecting to the recorder fails or an error is received during the // session and spdyHijacker.failOpen is false, connection will be closed. -func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, error) { +func (h *Hijacker) setUpRecording(conn net.Conn) (_ net.Conn, retErr error) { const ( // https://docs.asciinema.org/manual/asciicast/v2/ asciicastv2 = 2 @@ -137,7 +146,15 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, err error errChan <-chan error ) - h.log.Infof("kubectl exec session will be recorded, recorders: %v, fail open policy: %t", h.addrs, h.failOpen) + h.log.Infof("kubectl %s session will be recorded, recorders: %v, fail open policy: %t", h.sessionType, h.addrs, h.failOpen) + // NOTE: (ChaosInTheCRD) we want to use a dedicated context here, rather than the context from the request, + // otherwise the context can be cancelled by the client (kubectl) while we are still streaming to tsrecorder. + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + if retErr != nil { + cancel() + } + }() qp := h.req.URL.Query() container := strings.Join(qp[containerKey], "") var recorderAddr net.Addr @@ -156,11 +173,11 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, } msg = msg + "; failure mode is 'fail closed'; closing connection." if err := closeConnWithWarning(conn, msg); err != nil { - return nil, multierr.New(errors.New(msg), err) + return nil, errors.Join(errors.New(msg), err) } return nil, errors.New(msg) } else { - h.log.Infof("exec session to container %q in Pod %q namespace %q will be recorded, the recording will be sent to a tsrecorder instance at %q", container, h.pod, h.ns, recorderAddr) + h.log.Infof("%s session to container %q in Pod %q namespace %q will be recorded, the recording will be sent to a tsrecorder instance at %q", h.sessionType, container, h.pod, h.ns, recorderAddr) } cl := tstime.DefaultClock{} @@ -174,9 +191,10 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, SrcNode: strings.TrimSuffix(h.who.Node.Name, "."), SrcNodeID: h.who.Node.StableID, Kubernetes: &sessionrecording.Kubernetes{ - PodName: h.pod, - Namespace: h.ns, - Container: container, + PodName: h.pod, + Namespace: h.ns, + Container: container, + SessionType: string(h.sessionType), }, } if !h.who.Node.IsTagged() { @@ -189,14 +207,21 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, var lc net.Conn switch h.proto { case SPDYProtocol: - lc = spdy.New(conn, rec, ch, hasTerm, h.log) + lc, err = spdy.New(ctx, conn, rec, ch, hasTerm, h.log) + if err != nil { + return nil, fmt.Errorf("failed to initialize spdy connection: %w", err) + } case WSProtocol: - lc = ws.New(conn, rec, ch, hasTerm, h.log) + lc, err = ws.New(ctx, conn, rec, ch, hasTerm, h.log) + if err != nil { + return nil, fmt.Errorf("failed to initialize websocket connection: %w", err) + } default: return nil, fmt.Errorf("unknown protocol: %s", h.proto) } go func() { + defer cancel() var err error select { case <-ctx.Done(): @@ -208,7 +233,7 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, h.log.Info("finished uploading the recording") return } - msg := fmt.Sprintf("connection to the session recorder errorred: %v;", err) + msg := fmt.Sprintf("connection to the session recorder errored: %v;", err) if h.failOpen { msg += msg + "; failure mode is 'fail open'; continuing session without recording." h.log.Info(msg) @@ -220,7 +245,6 @@ func (h *Hijacker) setUpRecording(ctx context.Context, conn net.Conn) (net.Conn, if err := lc.Close(); err != nil { h.log.Infof("error closing recorder connections: %v", err) } - return }() return lc, nil } @@ -229,7 +253,7 @@ func closeConnWithWarning(conn net.Conn, msg string) error { b := io.NopCloser(bytes.NewBuffer([]byte(msg))) resp := http.Response{Status: http.StatusText(http.StatusForbidden), StatusCode: http.StatusForbidden, Body: b} if err := resp.Write(conn); err != nil { - return multierr.New(fmt.Errorf("error writing msg %q to conn: %v", msg, err), conn.Close()) + return errors.Join(fmt.Errorf("error writing msg %q to conn: %v", msg, err), conn.Close()) } return conn.Close() } diff --git a/k8s-operator/sessionrecording/hijacker_test.go b/k8s-operator/sessionrecording/hijacker_test.go index 440d9c942..fb45820a7 100644 --- a/k8s-operator/sessionrecording/hijacker_test.go +++ b/k8s-operator/sessionrecording/hijacker_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/netip" "net/url" @@ -20,6 +19,7 @@ import ( "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/fakes" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstest" @@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) { h := &Hijacker{ connectToRecorder: func(context.Context, []netip.AddrPort, - func(context.Context, string, string) (net.Conn, error), + netx.DialFunc, ) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) { if tt.failRecorderConnect { err = errors.New("test") @@ -91,11 +91,11 @@ func Test_Hijacker(t *testing.T) { who: &apitype.WhoIsResponse{Node: &tailcfg.Node{}, UserProfile: &tailcfg.UserProfile{}}, log: zl.Sugar(), ts: &tsnet.Server{}, - req: &http.Request{URL: &url.URL{}}, + req: &http.Request{URL: &url.URL{RawQuery: "tty=true"}}, proto: tt.proto, } ctx := context.Background() - _, err := h.setUpRecording(ctx, tc) + _, err := h.setUpRecording(tc) if (err != nil) != tt.wantsSetupErr { t.Errorf("spdyHijacker.setupRecording() error = %v, wantErr %v", err, tt.wantsSetupErr) return diff --git a/k8s-operator/sessionrecording/spdy/conn.go b/k8s-operator/sessionrecording/spdy/conn.go index 455c2225a..9fefca11f 100644 --- a/k8s-operator/sessionrecording/spdy/conn.go +++ b/k8s-operator/sessionrecording/spdy/conn.go @@ -4,11 +4,12 @@ //go:build !plan9 // Package spdy contains functionality for parsing SPDY streaming sessions. This -// is used for 'kubectl exec' session recording. +// is used for 'kubectl exec/attach' session recording. package spdy import ( "bytes" + "context" "encoding/binary" "encoding/json" "fmt" @@ -24,29 +25,50 @@ import ( ) // New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection. -// The connection must be a hijacked connection for a 'kubectl exec' session using SPDY. +// The connection must be a hijacked connection for a 'kubectl exec/attach' session using SPDY. // The hijacked connection is used to transmit SPDY streams between Kubernetes client ('kubectl') and the destination container. // Data read from the underlying network connection is data sent via one of the SPDY streams from the client to the container. // Data written to the underlying connection is data sent from the container to the client. // We parse the data and send everything for the stdout/stderr streams to the configured tsrecorder as an asciinema recording with the provided header. // https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#background-remotecommand-subprotocol -func New(nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) net.Conn { - return &conn{ - Conn: nc, - rec: rec, - ch: ch, - log: log, - hasTerm: hasTerm, - initialTermSizeSet: make(chan struct{}), +func New(ctx context.Context, nc net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) (net.Conn, error) { + lc := &conn{ + Conn: nc, + ctx: ctx, + rec: rec, + ch: ch, + log: log, + hasTerm: hasTerm, + initialCastHeaderSent: make(chan struct{}, 1), } + + // if there is no term, we don't need to wait for a resize message + if !hasTerm { + var err error + lc.writeCastHeaderOnce.Do(func() { + // If this is a session with a terminal attached, + // we must wait for the terminal width and + // height to be parsed from a resize message + // before sending CastHeader, else tsrecorder + // will not be able to play this recording. + err = lc.rec.WriteCastHeader(ch) + close(lc.initialCastHeaderSent) + }) + if err != nil { + return nil, fmt.Errorf("error writing CastHeader: %w", err) + } + } + + return lc, nil } // conn is a wrapper around net.Conn. It reads the bytestream for a 'kubectl -// exec' session streamed using SPDY protocol, sends session recording data to +// exec/attach' session streamed using SPDY protocol, sends session recording data to // the configured recorder and forwards the raw bytes to the original // destination. type conn struct { net.Conn + ctx context.Context // rec knows how to send data written to it to a tsrecorder instance. rec *tsrecorder.Client @@ -63,7 +85,7 @@ type conn struct { // CastHeader must be sent before any payload. If the session has a // terminal attached, the CastHeader must have '.Width' and '.Height' // fields set for the tsrecorder UI to be able to play the recording. - // For 'kubectl exec' sessions, terminal width and height are sent as a + // For 'kubectl exec/attach' sessions, terminal width and height are sent as a // resize message on resize stream from the client when the session // starts as well as at any time the client detects a terminal change. // We can intercept the resize message on Read calls. As there is no @@ -79,15 +101,10 @@ type conn struct { // writeCastHeaderOnce is used to ensure CastHeader gets sent to tsrecorder once. writeCastHeaderOnce sync.Once hasTerm bool // whether the session had TTY attached - // initialTermSizeSet channel gets sent a value once, when the Read has - // received a resize message and set the initial terminal size. It must - // be set to a buffered channel to prevent Reads being blocked on the - // first stdout/stderr write reading from the channel. - initialTermSizeSet chan struct{} - // sendInitialTermSizeSetOnce is used to ensure that a value is sent to - // initialTermSizeSet channel only once, when the initial resize message - // is received. - sendinitialTermSizeSetOnce sync.Once + // initialCastHeaderSent is a channel to ensure that the cast + // header is the first thing that is streamed to the session recorder. + // Otherwise the stream will fail. + initialCastHeaderSent chan struct{} zlibReqReader zlibReader // writeBuf is used to store data written to the connection that has not @@ -124,7 +141,7 @@ func (c *conn) Read(b []byte) (int, error) { } c.readBuf.Next(len(sf.Raw)) // advance buffer past the parsed frame - if !sf.Ctrl { // data frame + if !sf.Ctrl && c.hasTerm { // data frame switch sf.StreamID { case c.resizeStreamID.Load(): @@ -140,10 +157,19 @@ func (c *conn) Read(b []byte) (int, error) { // subsequent resize message, we need to send asciinema // resize message. var isInitialResize bool - c.sendinitialTermSizeSetOnce.Do(func() { + c.writeCastHeaderOnce.Do(func() { isInitialResize = true - close(c.initialTermSizeSet) // unblock sending of CastHeader + // If this is a session with a terminal attached, + // we must wait for the terminal width and + // height to be parsed from a resize message + // before sending CastHeader, else tsrecorder + // will not be able to play this recording. + err = c.rec.WriteCastHeader(c.ch) + close(c.initialCastHeaderSent) }) + if err != nil { + return 0, fmt.Errorf("error writing CastHeader: %w", err) + } if !isInitialResize { if err := c.rec.WriteResize(c.ch.Height, c.ch.Width); err != nil { return 0, fmt.Errorf("error writing resize message: %w", err) @@ -190,24 +216,14 @@ func (c *conn) Write(b []byte) (int, error) { if !sf.Ctrl { switch sf.StreamID { case c.stdoutStreamID.Load(), c.stderrStreamID.Load(): - var err error - c.writeCastHeaderOnce.Do(func() { - // If this is a session with a terminal attached, - // we must wait for the terminal width and - // height to be parsed from a resize message - // before sending CastHeader, else tsrecorder - // will not be able to play this recording. - if c.hasTerm { - c.log.Debugf("write: waiting for the initial terminal size to be set before proceeding with sending the first payload") - <-c.initialTermSizeSet + // we must wait for confirmation that the initial cast header was sent before proceeding with any more writes + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + case <-c.initialCastHeaderSent: + if err := c.rec.WriteOutput(sf.Payload); err != nil { + return 0, fmt.Errorf("error sending payload to session recorder: %w", err) } - err = c.rec.WriteCastHeader(c.ch) - }) - if err != nil { - return 0, fmt.Errorf("error writing CastHeader: %w", err) - } - if err := c.rec.WriteOutput(sf.Payload); err != nil { - return 0, fmt.Errorf("error sending payload to session recorder: %w", err) } } } diff --git a/k8s-operator/sessionrecording/spdy/conn_test.go b/k8s-operator/sessionrecording/spdy/conn_test.go index 3485d61c4..3c1cb8427 100644 --- a/k8s-operator/sessionrecording/spdy/conn_test.go +++ b/k8s-operator/sessionrecording/spdy/conn_test.go @@ -6,10 +6,12 @@ package spdy import ( + "context" "encoding/json" "fmt" "reflect" "testing" + "time" "go.uber.org/zap" "tailscale.com/k8s-operator/sessionrecording/fakes" @@ -29,15 +31,11 @@ func Test_Writes(t *testing.T) { } cl := tstest.NewClock(tstest.ClockOpts{}) tests := []struct { - name string - inputs [][]byte - wantForwarded []byte - wantRecorded []byte - firstWrite bool - width int - height int - sendInitialResize bool - hasTerm bool + name string + inputs [][]byte + wantForwarded []byte + wantRecorded []byte + hasTerm bool }{ { name: "single_write_control_frame_with_payload", @@ -78,24 +76,17 @@ func Test_Writes(t *testing.T) { wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), }, { - name: "single_first_write_stdout_data_frame_with_payload_sess_has_terminal", - inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, - wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...), - width: 10, - height: 20, - hasTerm: true, - firstWrite: true, - sendInitialResize: true, + name: "single_first_write_stdout_data_frame_with_payload_sess_has_terminal", + inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, + wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, + wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), + hasTerm: true, }, { name: "single_first_write_stdout_data_frame_with_payload_sess_does_not_have_terminal", inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}}, wantForwarded: []byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x5, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl)...), - width: 10, - height: 20, - firstWrite: true, + wantRecorded: fakes.CastLine(t, []byte{0x1, 0x2, 0x3, 0x4, 0x5}, cl), }, } for _, tt := range tests { @@ -104,29 +95,25 @@ func Test_Writes(t *testing.T) { sr := &fakes.TestSessionRecorder{} rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() c := &conn{ - Conn: tc, - log: zl.Sugar(), - rec: rec, - ch: sessionrecording.CastHeader{ - Width: tt.width, - Height: tt.height, - }, - initialTermSizeSet: make(chan struct{}), - hasTerm: tt.hasTerm, - } - if !tt.firstWrite { - // this test case does not intend to test that cast header gets written once - c.writeCastHeaderOnce.Do(func() {}) - } - if tt.sendInitialResize { - close(c.initialTermSizeSet) + ctx: ctx, + Conn: tc, + log: zl.Sugar(), + rec: rec, + ch: sessionrecording.CastHeader{}, + initialCastHeaderSent: make(chan struct{}), + hasTerm: tt.hasTerm, } + c.writeCastHeaderOnce.Do(func() { + close(c.initialCastHeaderSent) + }) + c.stdoutStreamID.Store(stdoutStreamID) c.stderrStreamID.Store(stderrStreamID) for i, input := range tt.inputs { - c.hasTerm = tt.hasTerm if _, err := c.Write(input); err != nil { t.Errorf("[%d] spdyRemoteConnRecorder.Write() unexpected error %v", i, err) } @@ -171,11 +158,25 @@ func Test_Reads(t *testing.T) { wantResizeStreamID uint32 wantWidth int wantHeight int + wantRecorded []byte resizeStreamIDBeforeRead uint32 }{ { name: "resize_data_frame_single_read", inputs: [][]byte{append([]byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg...)}, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), + resizeStreamIDBeforeRead: 1, + wantWidth: 10, + wantHeight: 20, + }, + { + name: "resize_data_frame_many", + inputs: [][]byte{ + append([]byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg...), + append([]byte{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg...), + }, + wantRecorded: append(fakes.AsciinemaCastHeaderMsg(t, 10, 20), fakes.AsciinemaCastResizeMsg(t, 10, 20)...), + resizeStreamIDBeforeRead: 1, wantWidth: 10, wantHeight: 20, @@ -183,6 +184,7 @@ func Test_Reads(t *testing.T) { { name: "resize_data_frame_two_reads", inputs: [][]byte{{0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, uint8(len(resizeMsg))}, resizeMsg}, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), resizeStreamIDBeforeRead: 1, wantWidth: 10, wantHeight: 20, @@ -215,11 +217,15 @@ func Test_Reads(t *testing.T) { tc := &fakes.TestConn{} sr := &fakes.TestSessionRecorder{} rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() c := &conn{ - Conn: tc, - log: zl.Sugar(), - rec: rec, - initialTermSizeSet: make(chan struct{}), + ctx: ctx, + Conn: tc, + log: zl.Sugar(), + rec: rec, + initialCastHeaderSent: make(chan struct{}), + hasTerm: true, } c.resizeStreamID.Store(tt.resizeStreamIDBeforeRead) @@ -251,6 +257,12 @@ func Test_Reads(t *testing.T) { t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height) } } + + // Assert that the expected bytes have been forwarded to the session recorder. + gotRecorded := sr.Bytes() + if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) { + t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", tt.wantRecorded, gotRecorded) + } }) } } diff --git a/k8s-operator/sessionrecording/tsrecorder/tsrecorder.go b/k8s-operator/sessionrecording/tsrecorder/tsrecorder.go index af5fcb8da..a5bdf7ddd 100644 --- a/k8s-operator/sessionrecording/tsrecorder/tsrecorder.go +++ b/k8s-operator/sessionrecording/tsrecorder/tsrecorder.go @@ -25,6 +25,7 @@ func New(conn io.WriteCloser, clock tstime.Clock, start time.Time, failOpen bool clock: clock, conn: conn, failOpen: failOpen, + logger: logger, } } diff --git a/k8s-operator/sessionrecording/ws/conn.go b/k8s-operator/sessionrecording/ws/conn.go index 86029f67b..a618f85fb 100644 --- a/k8s-operator/sessionrecording/ws/conn.go +++ b/k8s-operator/sessionrecording/ws/conn.go @@ -3,12 +3,13 @@ //go:build !plan9 -// package ws has functionality to parse 'kubectl exec' sessions streamed using +// package ws has functionality to parse 'kubectl exec/attach' sessions streamed using // WebSocket protocol. package ws import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -20,35 +21,56 @@ import ( "k8s.io/apimachinery/pkg/util/remotecommand" "tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/sessionrecording" - "tailscale.com/util/multierr" ) // New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection. -// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol. +// The connection must be a hijacked connection for a 'kubectl exec/attach' session using WebSocket protocol and a *.channel.k8s.io subprotocol. // The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes. // Data read from the underlying network connection is data sent via one of the streams from the client to the container. // Data written to the underlying connection is data sent from the container to the client. // We parse the data and send everything for the stdout/stderr streams to the configured tsrecorder as an asciinema recording with the provided header. // https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio -func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) net.Conn { - return &conn{ - Conn: c, - rec: rec, - ch: ch, - hasTerm: hasTerm, - log: log, - initialTermSizeSet: make(chan struct{}, 1), +func New(ctx context.Context, c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, hasTerm bool, log *zap.SugaredLogger) (net.Conn, error) { + lc := &conn{ + Conn: c, + ctx: ctx, + rec: rec, + ch: ch, + hasTerm: hasTerm, + log: log, + initialCastHeaderSent: make(chan struct{}, 1), } + + // if there is no term, we don't need to wait for a resize message + if !hasTerm { + var err error + lc.writeCastHeaderOnce.Do(func() { + // If this is a session with a terminal attached, + // we must wait for the terminal width and + // height to be parsed from a resize message + // before sending CastHeader, else tsrecorder + // will not be able to play this recording. + err = lc.rec.WriteCastHeader(ch) + close(lc.initialCastHeaderSent) + }) + if err != nil { + return nil, fmt.Errorf("error writing CastHeader: %w", err) + } + } + + return lc, nil } // conn is a wrapper around net.Conn. It reads the bytestream -// for a 'kubectl exec' session, sends session recording data to the configured +// for a 'kubectl exec/attach' session, sends session recording data to the configured // recorder and forwards the raw bytes to the original destination. // A new conn is created per session. -// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol. +// conn only knows to how to read a 'kubectl exec/attach' session that is streamed using WebSocket protocol. // https://www.rfc-editor.org/rfc/rfc6455 type conn struct { net.Conn + + ctx context.Context // rec knows how to send data to a tsrecorder instance. rec *tsrecorder.Client @@ -56,7 +78,7 @@ type conn struct { // CastHeader must be sent before any payload. If the session has a // terminal attached, the CastHeader must have '.Width' and '.Height' // fields set for the tsrecorder UI to be able to play the recording. - // For 'kubectl exec' sessions, terminal width and height are sent as a + // For 'kubectl exec/attach' sessions, terminal width and height are sent as a // resize message on resize stream from the client when the session // starts as well as at any time the client detects a terminal change. // We can intercept the resize message on Read calls. As there is no @@ -72,15 +94,10 @@ type conn struct { // writeCastHeaderOnce is used to ensure CastHeader gets sent to tsrecorder once. writeCastHeaderOnce sync.Once hasTerm bool // whether the session has TTY attached - // initialTermSizeSet channel gets sent a value once, when the Read has - // received a resize message and set the initial terminal size. It must - // be set to a buffered channel to prevent Reads being blocked on the - // first stdout/stderr write reading from the channel. - initialTermSizeSet chan struct{} - // sendInitialTermSizeSetOnce is used to ensure that a value is sent to - // initialTermSizeSet channel only once, when the initial resize message - // is received. - sendInitialTermSizeSetOnce sync.Once + // initialCastHeaderSent is a boolean that is set to ensure that the cast + // header is the first thing that is streamed to the session recorder. + // Otherwise the stream will fail. + initialCastHeaderSent chan struct{} log *zap.SugaredLogger @@ -130,6 +147,8 @@ func (c *conn) Read(b []byte) (int, error) { return 0, nil } + // TODO(tomhjp): If we get multiple frames in a single Read with different + // types, we may parse the second frame with the wrong type. typ := messageType(opcode(b)) if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment if typ, err = c.curReadMsgType(); err != nil { @@ -139,6 +158,8 @@ func (c *conn) Read(b []byte) (int, error) { // A control message can not be fragmented and we are not interested in // these messages. Just return. + // TODO(tomhjp): If we get multiple frames in a single Read, we may skip + // some non-control messages. if isControlMessage(typ) { return n, nil } @@ -151,54 +172,65 @@ func (c *conn) Read(b []byte) (int, error) { return n, nil } - readMsg := &message{typ: typ} // start a new message... - // ... or pick up an already started one if the previous fragment was not final. - if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() { - readMsg = c.currentReadMsg - } - if _, err := c.readBuf.Write(b[:n]); err != nil { return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) } - ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) - if err != nil { - return 0, fmt.Errorf("error parsing message: %v", err) - } - if !ok { // incomplete fragment - return n, nil - } - c.readBuf.Next(len(readMsg.raw)) - - if readMsg.isFinalized && !c.readMsgIsIncomplete() { - // Stream IDs for websocket streams are static. - // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 - if readMsg.streamID.Load() == remotecommand.StreamResize { - var msg tsrecorder.ResizeMsg - if err = json.Unmarshal(readMsg.payload, &msg); err != nil { - return 0, fmt.Errorf("error umarshalling resize message: %w", err) - } + for c.readBuf.Len() != 0 { + readMsg := &message{typ: typ} // start a new message... + // ... or pick up an already started one if the previous fragment was not final. + if c.readMsgIsIncomplete() { + readMsg = c.currentReadMsg + } - c.ch.Width = msg.Width - c.ch.Height = msg.Height - - // If this is initial resize message, the width and - // height will be sent in the CastHeader. If this is a - // subsequent resize message, we need to send asciinema - // resize message. - var isInitialResize bool - c.sendInitialTermSizeSetOnce.Do(func() { - isInitialResize = true - close(c.initialTermSizeSet) // unblock sending of CastHeader - }) - if !isInitialResize { - if err := c.rec.WriteResize(c.ch.Height, c.ch.Width); err != nil { - return 0, fmt.Errorf("error writing resize message: %w", err) + ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) + if err != nil { + return 0, fmt.Errorf("error parsing message: %v", err) + } + if !ok { // incomplete fragment + return n, nil + } + c.readBuf.Next(len(readMsg.raw)) + + if readMsg.isFinalized && !c.readMsgIsIncomplete() { + // we want to send stream resize messages for terminal sessions + // Stream IDs for websocket streams are static. + // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 + if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm { + var msg tsrecorder.ResizeMsg + if err = json.Unmarshal(readMsg.payload, &msg); err != nil { + return 0, fmt.Errorf("error umarshalling resize message: %w", err) + } + + c.ch.Width = msg.Width + c.ch.Height = msg.Height + + var isInitialResize bool + c.writeCastHeaderOnce.Do(func() { + isInitialResize = true + // If this is a session with a terminal attached, + // we must wait for the terminal width and + // height to be parsed from a resize message + // before sending CastHeader, else tsrecorder + // will not be able to play this recording. + err = c.rec.WriteCastHeader(c.ch) + close(c.initialCastHeaderSent) + }) + if err != nil { + return 0, fmt.Errorf("error writing CastHeader: %w", err) + } + + if !isInitialResize { + if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil { + return 0, fmt.Errorf("error writing resize message: %w", err) + } } } } + + c.currentReadMsg = readMsg } - c.currentReadMsg = readMsg + return n, nil } @@ -244,39 +276,33 @@ func (c *conn) Write(b []byte) (int, error) { c.log.Errorf("write: parsing a message errored: %v", err) return 0, fmt.Errorf("write: error parsing message: %v", err) } + c.currentWriteMsg = writeMsg if !ok { // incomplete fragment return len(b), nil } + c.writeBuf.Next(len(writeMsg.raw)) // advance frame if len(writeMsg.payload) != 0 && writeMsg.isFinalized { if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr { - var err error - c.writeCastHeaderOnce.Do(func() { - // If this is a session with a terminal attached, - // we must wait for the terminal width and - // height to be parsed from a resize message - // before sending CastHeader, else tsrecorder - // will not be able to play this recording. - if c.hasTerm { - c.log.Debug("waiting for terminal size to be set before starting to send recorded data") - <-c.initialTermSizeSet + // we must wait for confirmation that the initial cast header was sent before proceeding with any more writes + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + case <-c.initialCastHeaderSent: + if err := c.rec.WriteOutput(writeMsg.payload); err != nil { + return 0, fmt.Errorf("error writing message to recorder: %w", err) } - err = c.rec.WriteCastHeader(c.ch) - }) - if err != nil { - return 0, fmt.Errorf("error writing CastHeader: %w", err) - } - if err := c.rec.WriteOutput(writeMsg.payload); err != nil { - return 0, fmt.Errorf("error writing message to recorder: %v", err) } } } + _, err = c.Conn.Write(c.currentWriteMsg.raw) if err != nil { c.log.Errorf("write: error writing to conn: %v", err) } + return len(b), nil } @@ -289,7 +315,7 @@ func (c *conn) Close() error { c.closed = true connCloseErr := c.Conn.Close() recCloseErr := c.rec.Close() - return multierr.New(connCloseErr, recCloseErr) + return errors.Join(connCloseErr, recCloseErr) } // writeBufHasIncompleteFragment returns true if the latest data message @@ -321,6 +347,7 @@ func (c *conn) writeMsgIsIncomplete() bool { func (c *conn) readMsgIsIncomplete() bool { return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized } + func (c *conn) curReadMsgType() (messageType, error) { if c.currentReadMsg != nil { return c.currentReadMsg.typ, nil diff --git a/k8s-operator/sessionrecording/ws/conn_test.go b/k8s-operator/sessionrecording/ws/conn_test.go index 11174480b..87205c4e6 100644 --- a/k8s-operator/sessionrecording/ws/conn_test.go +++ b/k8s-operator/sessionrecording/ws/conn_test.go @@ -6,9 +6,12 @@ package ws import ( + "context" "fmt" "reflect" + "runtime/debug" "testing" + "time" "go.uber.org/zap" "k8s.io/apimachinery/pkg/util/remotecommand" @@ -26,46 +29,93 @@ func Test_conn_Read(t *testing.T) { // Resize stream ID + {"width": 10, "height": 20} testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d} lenResizeMsgPayload := byte(len(testResizeMsg)) - + cl := tstest.NewClock(tstest.ClockOpts{}) tests := []struct { - name string - inputs [][]byte - wantWidth int - wantHeight int + name string + inputs [][]byte + wantCastHeaderWidth int + wantCastHeaderHeight int + wantRecorded []byte }{ { name: "single_read_control_message", inputs: [][]byte{{0x88, 0x0}}, }, { - name: "single_read_resize_message", - inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)}, - wantWidth: 10, - wantHeight: 20, + name: "single_read_resize_message", + inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)}, + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), + }, + { + name: "resize_data_frame_many", + inputs: [][]byte{ + append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...), + append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...), + }, + wantRecorded: append(fakes.AsciinemaCastHeaderMsg(t, 10, 20), fakes.AsciinemaCastResizeMsg(t, 10, 20)...), + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + }, + { + name: "resize_data_frame_two_in_one_read", + inputs: [][]byte{ + fmt.Appendf(nil, "%s%s", + append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...), + append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...), + ), + }, + wantRecorded: append(fakes.AsciinemaCastHeaderMsg(t, 10, 20), fakes.AsciinemaCastResizeMsg(t, 10, 20)...), + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, }, { - name: "two_reads_resize_message", - inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}}, - wantWidth: 10, - wantHeight: 20, + name: "two_reads_resize_message", + inputs: [][]byte{ + // op, len, stream ID, `{"width` + {0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, + // op, len, stream ID, `:10,"height":20}` + {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}, + }, + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), }, { - name: "three_reads_resize_message_with_split_fragment", - inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}}, - wantWidth: 10, - wantHeight: 20, + name: "three_reads_resize_message_with_split_fragment", + inputs: [][]byte{ + // op, len, stream ID, `{"width"` + {0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, + // op, len, stream ID, `:10,"height` + {0x00, 0x0c, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, + // op, len, stream ID, `":20}` + {0x80, 0x06, 0x4, 0x22, 0x3a, 0x32, 0x30, 0x7d}, + }, + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + log := zl.Sugar() tc := &fakes.TestConn{} + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) tc.ResetReadBuf() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() c := &conn{ - Conn: tc, - log: zl.Sugar(), + ctx: ctx, + Conn: tc, + log: log, + hasTerm: true, + initialCastHeaderSent: make(chan struct{}), + rec: rec, } for i, input := range tt.inputs { - c.initialTermSizeSet = make(chan struct{}) if err := tc.WriteReadBufBytes(input); err != nil { t.Fatalf("writing bytes to test conn: %v", err) } @@ -75,14 +125,20 @@ func Test_conn_Read(t *testing.T) { return } } - if tt.wantHeight != 0 || tt.wantWidth != 0 { - if tt.wantWidth != c.ch.Width { - t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width) + + if tt.wantCastHeaderHeight != 0 || tt.wantCastHeaderWidth != 0 { + if tt.wantCastHeaderWidth != c.ch.Width { + t.Errorf("wants width: %v, got %v", tt.wantCastHeaderWidth, c.ch.Width) } - if tt.wantHeight != c.ch.Height { - t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height) + if tt.wantCastHeaderHeight != c.ch.Height { + t.Errorf("want height: %v, got %v", tt.wantCastHeaderHeight, c.ch.Height) } } + + gotRecorded := sr.Bytes() + if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) { + t.Errorf("expected bytes not recorded, wants\n%v\ngot\n%v", string(tt.wantRecorded), string(gotRecorded)) + } }) } } @@ -94,15 +150,11 @@ func Test_conn_Write(t *testing.T) { } cl := tstest.NewClock(tstest.ClockOpts{}) tests := []struct { - name string - inputs [][]byte - wantForwarded []byte - wantRecorded []byte - firstWrite bool - width int - height int - hasTerm bool - sendInitialResize bool + name string + inputs [][]byte + wantForwarded []byte + wantRecorded []byte + hasTerm bool }{ { name: "single_write_control_frame", @@ -130,10 +182,7 @@ func Test_conn_Write(t *testing.T) { name: "single_write_stdout_data_message_with_cast_header", inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, - wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...), - width: 10, - height: 20, - firstWrite: true, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), }, { name: "two_writes_stdout_data_message", @@ -148,15 +197,11 @@ func Test_conn_Write(t *testing.T) { wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), }, { - name: "three_writes_stdout_data_message_with_split_fragment_cast_header_with_terminal", - inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}}, - wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, - wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl)...), - height: 20, - width: 10, - hasTerm: true, - firstWrite: true, - sendInitialResize: true, + name: "three_writes_stdout_data_message_with_split_fragment_cast_header_with_terminal", + inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}}, + wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, + wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), + hasTerm: true, }, } for _, tt := range tests { @@ -164,24 +209,22 @@ func Test_conn_Write(t *testing.T) { tc := &fakes.TestConn{} sr := &fakes.TestSessionRecorder{} rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() c := &conn{ - Conn: tc, - log: zl.Sugar(), - ch: sessionrecording.CastHeader{ - Width: tt.width, - Height: tt.height, - }, - rec: rec, - initialTermSizeSet: make(chan struct{}), - hasTerm: tt.hasTerm, - } - if !tt.firstWrite { - // This test case does not intend to test that cast header gets written once. - c.writeCastHeaderOnce.Do(func() {}) - } - if tt.sendInitialResize { - close(c.initialTermSizeSet) + Conn: tc, + ctx: ctx, + log: zl.Sugar(), + ch: sessionrecording.CastHeader{}, + rec: rec, + initialCastHeaderSent: make(chan struct{}), + hasTerm: tt.hasTerm, } + + c.writeCastHeaderOnce.Do(func() { + close(c.initialCastHeaderSent) + }) + for i, input := range tt.inputs { _, err := c.Write(input) if err != nil { @@ -242,19 +285,28 @@ func Test_conn_WriteRand(t *testing.T) { sr := &fakes.TestSessionRecorder{} rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) for i := range 100 { - tc := &fakes.TestConn{} - c := &conn{ - Conn: tc, - log: zl.Sugar(), - rec: rec, - } - bb := fakes.RandomBytes(t) - for j, input := range bb { - f := func() { - c.Write(input) + t.Run(fmt.Sprintf("test_%d", i), func(t *testing.T) { + tc := &fakes.TestConn{} + c := &conn{ + Conn: tc, + log: zl.Sugar(), + rec: rec, + + ctx: context.Background(), // ctx must be non-nil. + initialCastHeaderSent: make(chan struct{}), } - testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg)) - } + // Never block for random data. + c.writeCastHeaderOnce.Do(func() { + close(c.initialCastHeaderSent) + }) + bb := fakes.RandomBytes(t) + for j, input := range bb { + f := func() { + c.Write(input) + } + testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg)) + } + }) } } @@ -262,7 +314,7 @@ func testPanic(t *testing.T, f func(), msg string) { t.Helper() defer func() { if r := recover(); r != nil { - t.Fatal(msg, r) + t.Fatal(msg, r, string(debug.Stack())) } }() f() diff --git a/k8s-operator/sessionrecording/ws/message.go b/k8s-operator/sessionrecording/ws/message.go index 713febec7..35667ae21 100644 --- a/k8s-operator/sessionrecording/ws/message.go +++ b/k8s-operator/sessionrecording/ws/message.go @@ -7,10 +7,10 @@ package ws import ( "encoding/binary" + "errors" "fmt" "sync/atomic" - "github.com/pkg/errors" "go.uber.org/zap" "golang.org/x/net/websocket" @@ -139,6 +139,8 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { return false, errors.New("[unexpected] received a message fragment with no stream ID") } + // Stream ID will be one of the constants from: + // https://github.com/kubernetes/kubernetes/blob/f9ed14bf9b1119a2e091f4b487a3b54930661034/staging/src/k8s.io/apimachinery/pkg/util/remotecommand/constants.go#L57-L64 streamID := uint32(msgPayload[0]) if !isInitialFragment && msg.streamID.Load() != streamID { return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID) diff --git a/k8s-operator/utils.go b/k8s-operator/utils.go index 497f31b60..2acbf338d 100644 --- a/k8s-operator/utils.go +++ b/k8s-operator/utils.go @@ -27,14 +27,16 @@ type Records struct { Version string `json:"version"` // IP4 contains a mapping of DNS names to IPv4 address(es). IP4 map[string][]string `json:"ip4"` + // IP6 contains a mapping of DNS names to IPv6 address(es). + // This field is optional and will be omitted from JSON if empty. + // It enables dual-stack DNS support in Kubernetes clusters. + // +optional + IP6 map[string][]string `json:"ip6,omitempty"` } -// TailscaledConfigFileNameForCap returns a tailscaled config file name in +// TailscaledConfigFileName returns a tailscaled config file name in // format expected by containerboot for the given CapVer. -func TailscaledConfigFileNameForCap(cap tailcfg.CapabilityVersion) string { - if cap < 95 { - return "tailscaled" - } +func TailscaledConfigFileName(cap tailcfg.CapabilityVersion) string { return fmt.Sprintf("cap-%v.hujson", cap) } diff --git a/kube/certs/certs.go b/kube/certs/certs.go new file mode 100644 index 000000000..8e2e5fb43 --- /dev/null +++ b/kube/certs/certs.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package certs implements logic to help multiple Kubernetes replicas share TLS +// certs for a common Tailscale Service. +package certs + +import ( + "context" + "fmt" + "net" + "slices" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/kube/localclient" + "tailscale.com/types/logger" + "tailscale.com/util/goroutines" + "tailscale.com/util/mak" +) + +// CertManager is responsible for issuing certificates for known domains and for +// maintaining a loop that re-attempts issuance daily. +// Currently cert manager logic is only run on ingress ProxyGroup replicas that are responsible for managing certs for +// HA Ingress HTTPS endpoints ('write' replicas). +type CertManager struct { + lc localclient.LocalClient + logf logger.Logf + tracker goroutines.Tracker // tracks running goroutines + mu sync.Mutex // guards the following + // certLoops contains a map of DNS names, for which we currently need to + // manage certs to cancel functions that allow stopping a goroutine when + // we no longer need to manage certs for the DNS name. + certLoops map[string]context.CancelFunc +} + +func NewCertManager(lc localclient.LocalClient, logf logger.Logf) *CertManager { + return &CertManager{ + lc: lc, + logf: logf, + } +} + +// EnsureCertLoops ensures that, for all currently managed Service HTTPS +// endpoints, there is a cert loop responsible for issuing and ensuring the +// renewal of the TLS certs. +// ServeConfig must not be nil. +func (cm *CertManager) EnsureCertLoops(ctx context.Context, sc *ipn.ServeConfig) error { + if sc == nil { + return fmt.Errorf("[unexpected] ensureCertLoops called with nil ServeConfig") + } + currentDomains := make(map[string]bool) + const httpsPort = "443" + for _, service := range sc.Services { + for hostPort := range service.Web { + domain, port, err := net.SplitHostPort(string(hostPort)) + if err != nil { + return fmt.Errorf("[unexpected] unable to parse HostPort %s", hostPort) + } + if port != httpsPort { // HA Ingress' HTTP endpoint + continue + } + currentDomains[domain] = true + } + } + cm.mu.Lock() + defer cm.mu.Unlock() + for domain := range currentDomains { + if _, exists := cm.certLoops[domain]; !exists { + cancelCtx, cancel := context.WithCancel(ctx) + mak.Set(&cm.certLoops, domain, cancel) + // Note that most of the issuance anyway happens + // serially because the cert client has a shared lock + // that's held during any issuance. + cm.tracker.Go(func() { cm.runCertLoop(cancelCtx, domain) }) + } + } + + // Stop goroutines for domain names that are no longer in the config. + for domain, cancel := range cm.certLoops { + if !currentDomains[domain] { + cancel() + delete(cm.certLoops, domain) + } + } + return nil +} + +// runCertLoop: +// - calls localAPI certificate endpoint to ensure that certs are issued for the +// given domain name +// - calls localAPI certificate endpoint daily to ensure that certs are renewed +// - if certificate issuance failed retries after an exponential backoff period +// starting at 1 minute and capped at 24 hours. Reset the backoff once issuance succeeds. +// Note that renewal check also happens when the node receives an HTTPS request and it is possible that certs get +// renewed at that point. Renewal here is needed to prevent the shared certs from expiry in edge cases where the 'write' +// replica does not get any HTTPS requests. +// https://letsencrypt.org/docs/integration-guide/#retrying-failures +func (cm *CertManager) runCertLoop(ctx context.Context, domain string) { + const ( + normalInterval = 24 * time.Hour // regular renewal check + initialRetry = 1 * time.Minute // initial backoff after a failure + maxRetryInterval = 24 * time.Hour // max backoff period + ) + + if err := cm.waitForCertDomain(ctx, domain); err != nil { + // Best-effort, log and continue with the issuing loop. + cm.logf("error waiting for cert domain %s: %v", domain, err) + } + + timer := time.NewTimer(0) // fire off timer immediately + defer timer.Stop() + retryCount := 0 + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + // We call the certificate endpoint, but don't do anything with the + // returned certs here. The call to the certificate endpoint will + // ensure that certs are issued/renewed as needed and stored in the + // relevant state store. For example, for HA Ingress 'write' replica, + // the cert and key will be stored in a Kubernetes Secret named after + // the domain for which we are issuing. + // + // Note that renewals triggered by the call to the certificates + // endpoint here and by renewal check triggered during a call to + // node's HTTPS endpoint share the same state/renewal lock mechanism, + // so we should not run into redundant issuances during concurrent + // renewal checks. + + // An issuance holds a shared lock, so we need to avoid a situation + // where other services cannot issue certs because a single one is + // holding the lock. + ctxT, cancel := context.WithTimeout(ctx, time.Second*300) + _, _, err := cm.lc.CertPair(ctxT, domain) + cancel() + if err != nil { + cm.logf("error refreshing certificate for %s: %v", domain, err) + } + var nextInterval time.Duration + // TODO(irbekrm): distinguish between LE rate limit errors and other + // error types like transient network errors. + if err == nil { + retryCount = 0 + nextInterval = normalInterval + } else { + retryCount++ + // Calculate backoff: initialRetry * 2^(retryCount-1) + // For retryCount=1: 1min * 2^0 = 1min + // For retryCount=2: 1min * 2^1 = 2min + // For retryCount=3: 1min * 2^2 = 4min + backoff := initialRetry * time.Duration(1<<(retryCount-1)) + if backoff > maxRetryInterval { + backoff = maxRetryInterval + } + nextInterval = backoff + cm.logf("Error refreshing certificate for %s (retry %d): %v. Will retry in %v\n", + domain, retryCount, err, nextInterval) + } + timer.Reset(nextInterval) + } + } +} + +// waitForCertDomain ensures the requested domain is in the list of allowed +// domains before issuing the cert for the first time. +func (cm *CertManager) waitForCertDomain(ctx context.Context, domain string) error { + w, err := cm.lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) + if err != nil { + return fmt.Errorf("error watching IPN bus: %w", err) + } + defer w.Close() + + for { + n, err := w.Next() + if err != nil { + return err + } + if n.NetMap == nil { + continue + } + + if slices.Contains(n.NetMap.DNS.CertDomains, domain) { + return nil + } + } +} diff --git a/kube/certs/certs_test.go b/kube/certs/certs_test.go new file mode 100644 index 000000000..8434f21ae --- /dev/null +++ b/kube/certs/certs_test.go @@ -0,0 +1,250 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package certs + +import ( + "context" + "log" + "testing" + "time" + + "tailscale.com/ipn" + "tailscale.com/kube/localclient" + "tailscale.com/tailcfg" + "tailscale.com/types/netmap" +) + +// TestEnsureCertLoops tests that the certManager correctly starts and stops +// update loops for certs when the serve config changes. It tracks goroutine +// count and uses that as a validator that the expected number of cert loops are +// running. +func TestEnsureCertLoops(t *testing.T) { + tests := []struct { + name string + initialConfig *ipn.ServeConfig + updatedConfig *ipn.ServeConfig + initialGoroutines int64 // after initial serve config is applied + updatedGoroutines int64 // after updated serve config is applied + wantErr bool + }{ + { + name: "empty_serve_config", + initialConfig: &ipn.ServeConfig{}, + initialGoroutines: 0, + }, + { + name: "nil_serve_config", + initialConfig: nil, + initialGoroutines: 0, + wantErr: true, + }, + { + name: "empty_to_one_service", + initialConfig: &ipn.ServeConfig{}, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 0, + updatedGoroutines: 1, + }, + { + name: "single_service", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 1, + }, + { + name: "multiple_services", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 2, // one loop per domain across all services + }, + { + name: "ignore_non_https_ports", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + "my-app.tailnetxyz.ts.net:80": {}, + }, + }, + }, + }, + initialGoroutines: 1, // only one loop for the 443 endpoint + }, + { + name: "remove_domain", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 2, // initially two loops (one per service) + updatedGoroutines: 1, // one loop after removing service2 + }, + { + name: "add_domain", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 1, + updatedGoroutines: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + notifyChan := make(chan ipn.Notify) + go func() { + for { + notifyChan <- ipn.Notify{ + NetMap: &netmap.NetworkMap{ + DNS: tailcfg.DNSConfig{ + CertDomains: []string{ + "my-app.tailnetxyz.ts.net", + "my-other-app.tailnetxyz.ts.net", + }, + }, + }, + } + } + }() + cm := &CertManager{ + lc: &localclient.FakeLocalClient{ + FakeIPNBusWatcher: localclient.FakeIPNBusWatcher{ + NotifyChan: notifyChan, + }, + }, + logf: log.Printf, + certLoops: make(map[string]context.CancelFunc), + } + + allDone := make(chan bool, 1) + defer cm.tracker.AddDoneCallback(func() { + cm.mu.Lock() + defer cm.mu.Unlock() + if cm.tracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + err := cm.EnsureCertLoops(ctx, tt.initialConfig) + if (err != nil) != tt.wantErr { + t.Fatalf("ensureCertLoops() error = %v", err) + } + + if got := cm.tracker.RunningGoroutines(); got != tt.initialGoroutines { + t.Errorf("after initial config: got %d running goroutines, want %d", got, tt.initialGoroutines) + } + + if tt.updatedConfig != nil { + if err := cm.EnsureCertLoops(ctx, tt.updatedConfig); err != nil { + t.Fatalf("ensureCertLoops() error on update = %v", err) + } + + // Although starting goroutines and cancelling + // the context happens in the main goroutine, it + // the actual goroutine exit when a context is + // cancelled does not- so wait for a bit for the + // running goroutine count to reach the expected + // number. + deadline := time.After(5 * time.Second) + for { + if got := cm.tracker.RunningGoroutines(); got == tt.updatedGoroutines { + break + } + select { + case <-deadline: + t.Fatalf("timed out waiting for goroutine count to reach %d, currently at %d", + tt.updatedGoroutines, cm.tracker.RunningGoroutines()) + case <-time.After(10 * time.Millisecond): + continue + } + } + } + + if tt.updatedGoroutines == 0 { + return // no goroutines to wait for + } + // cancel context to make goroutines exit + cancel() + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } + }) + } +} diff --git a/kube/egressservices/egressservices.go b/kube/egressservices/egressservices.go index f634458d9..56c874f31 100644 --- a/kube/egressservices/egressservices.go +++ b/kube/egressservices/egressservices.go @@ -9,16 +9,19 @@ package egressservices import ( - "encoding" - "fmt" + "encoding/json" "net/netip" - "strconv" - "strings" ) -// KeyEgressServices is name of the proxy state Secret field that contains the -// currently applied egress proxy config. -const KeyEgressServices = "egress-services" +const ( + // KeyEgressServices is name of the proxy state Secret field that contains the + // currently applied egress proxy config. + KeyEgressServices = "egress-services" + + // KeyHEPPings is the number of times an egress service health check endpoint needs to be pinged to ensure that + // each currently configured backend is hit. In practice, it depends on the number of ProxyGroup replicas. + KeyHEPPings = "hep-pings" +) // Configs contains the desired configuration for egress services keyed by // service name. @@ -27,11 +30,12 @@ type Configs map[string]Config // Config is an egress service configuration. // TODO(irbekrm): version this? type Config struct { + HealthCheckEndpoint string `json:"healthCheckEndpoint"` // TailnetTarget is the target to which cluster traffic for this service // should be proxied. TailnetTarget TailnetTarget `json:"tailnetTarget"` // Ports contains mappings for ports that can be accessed on the tailnet target. - Ports map[PortMap]struct{} `json:"ports"` + Ports PortMaps `json:"ports"` } // TailnetTarget is the tailnet target to which traffic for the egress service @@ -52,41 +56,44 @@ type PortMap struct { TargetPort uint16 `json:"targetPort"` } -// PortMap is used as a Config.Ports map key. Config needs to be serialized/deserialized to/from JSON. JSON only -// supports string map keys, so we need to implement TextMarshaler/TextUnmarshaler to convert PortMap to string and -// back. -var _ encoding.TextMarshaler = PortMap{} -var _ encoding.TextUnmarshaler = &PortMap{} - -func (pm *PortMap) UnmarshalText(t []byte) error { - tt := string(t) - ss := strings.Split(tt, ":") - if len(ss) != 3 { - return fmt.Errorf("error unmarshalling portmap from JSON, wants a portmap in form ::, got %q", tt) - } - pm.Protocol = ss[0] - matchPort, err := strconv.ParseUint(ss[1], 10, 16) - if err != nil { - return fmt.Errorf("error converting match port %q to uint16: %w", ss[1], err) +type PortMaps map[PortMap]struct{} + +// PortMaps is a list of PortMap structs, however, we want to use it as a set +// with efficient lookups in code. It implements custom JSON marshalling +// methods to convert between being a list in JSON and a set (map with empty +// values) in code. +var _ json.Marshaler = &PortMaps{} +var _ json.Marshaler = PortMaps{} +var _ json.Unmarshaler = &PortMaps{} + +func (p *PortMaps) UnmarshalJSON(data []byte) error { + *p = make(map[PortMap]struct{}) + + var v []PortMap + if err := json.Unmarshal(data, &v); err != nil { + return err } - pm.MatchPort = uint16(matchPort) - targetPort, err := strconv.ParseUint(ss[2], 10, 16) - if err != nil { - return fmt.Errorf("error converting target port %q to uint16: %w", ss[2], err) + + for _, pm := range v { + (*p)[pm] = struct{}{} } - pm.TargetPort = uint16(targetPort) + return nil } -func (pm PortMap) MarshalText() ([]byte, error) { - s := fmt.Sprintf("%s:%d:%d", pm.Protocol, pm.MatchPort, pm.TargetPort) - return []byte(s), nil +func (p PortMaps) MarshalJSON() ([]byte, error) { + v := make([]PortMap, 0, len(p)) + for pm := range p { + v = append(v, pm) + } + + return json.Marshal(v) } // Status represents the currently configured firewall rules for all egress // services for a proxy identified by the PodIP. type Status struct { - PodIP string `json:"podIP"` + PodIPv4 string `json:"podIPv4"` // All egress service status keyed by service name. Services map[string]*ServiceStatus `json:"services"` } @@ -94,7 +101,7 @@ type Status struct { // ServiceStatus is the currently configured firewall rules for an egress // service. type ServiceStatus struct { - Ports map[PortMap]struct{} `json:"ports"` + Ports PortMaps `json:"ports"` // TailnetTargetIPs are the tailnet target IPs that were used to // configure these firewall rules. For a TailnetTarget with IP set, this // is the same as IP. diff --git a/kube/egressservices/egressservices_test.go b/kube/egressservices/egressservices_test.go index 5e5651e77..806ad91be 100644 --- a/kube/egressservices/egressservices_test.go +++ b/kube/egressservices/egressservices_test.go @@ -5,8 +5,9 @@ package egressservices import ( "encoding/json" - "reflect" "testing" + + "github.com/google/go-cmp/cmp" ) func Test_jsonUnmarshalConfig(t *testing.T) { @@ -18,7 +19,7 @@ func Test_jsonUnmarshalConfig(t *testing.T) { }{ { name: "success", - bs: []byte(`{"ports":{"tcp:4003:80":{}}}`), + bs: []byte(`{"ports":[{"protocol":"tcp","matchPort":4003,"targetPort":80}]}`), wantsCfg: Config{Ports: map[PortMap]struct{}{{Protocol: "tcp", MatchPort: 4003, TargetPort: 80}: {}}}, }, { @@ -34,8 +35,8 @@ func Test_jsonUnmarshalConfig(t *testing.T) { if gotErr := json.Unmarshal(tt.bs, &cfg); (gotErr != nil) != tt.wantsErr { t.Errorf("json.Unmarshal returned error %v, wants error %v", gotErr, tt.wantsErr) } - if !reflect.DeepEqual(cfg, tt.wantsCfg) { - t.Errorf("json.Unmarshal produced Config %v, wants Config %v", cfg, tt.wantsCfg) + if diff := cmp.Diff(cfg, tt.wantsCfg); diff != "" { + t.Errorf("unexpected secrets (-got +want):\n%s", diff) } }) } @@ -54,12 +55,12 @@ func Test_jsonMarshalConfig(t *testing.T) { protocol: "tcp", matchPort: 4003, targetPort: 80, - wantsBs: []byte(`{"tailnetTarget":{"ip":"","fqdn":""},"ports":{"tcp:4003:80":{}}}`), + wantsBs: []byte(`{"healthCheckEndpoint":"","tailnetTarget":{"ip":"","fqdn":""},"ports":[{"protocol":"tcp","matchPort":4003,"targetPort":80}]}`), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg := Config{Ports: map[PortMap]struct{}{{ + cfg := Config{Ports: PortMaps{{ Protocol: tt.protocol, MatchPort: tt.matchPort, TargetPort: tt.targetPort}: {}}} @@ -68,8 +69,8 @@ func Test_jsonMarshalConfig(t *testing.T) { if gotErr != nil { t.Errorf("json.Marshal(%+#v) returned unexpected error %v", cfg, gotErr) } - if !reflect.DeepEqual(gotBs, tt.wantsBs) { - t.Errorf("json.Marshal(%+#v) returned '%v', wants '%v'", cfg, string(gotBs), string(tt.wantsBs)) + if diff := cmp.Diff(gotBs, tt.wantsBs); diff != "" { + t.Errorf("unexpected secrets (-got +want):\n%s", diff) } }) } diff --git a/kube/health/healthz.go b/kube/health/healthz.go new file mode 100644 index 000000000..c8cfcc7ec --- /dev/null +++ b/kube/health/healthz.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package health contains shared types and underlying methods for serving +// a `/healthz` endpoint for containerboot and k8s-proxy. +package health + +import ( + "context" + "fmt" + "net/http" + "sync" + + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/kube/kubetypes" + "tailscale.com/types/logger" +) + +// Healthz is a simple health check server, if enabled it returns 200 OK if +// this tailscale node currently has at least one tailnet IP address else +// returns 503. +type Healthz struct { + sync.Mutex + hasAddrs bool + podIPv4 string + logger logger.Logf +} + +func (h *Healthz) ServeHTTP(w http.ResponseWriter, r *http.Request) { + h.Lock() + defer h.Unlock() + + if h.hasAddrs { + w.Header().Add(kubetypes.PodIPv4Header, h.podIPv4) + if _, err := w.Write([]byte("ok")); err != nil { + http.Error(w, fmt.Sprintf("error writing status: %v", err), http.StatusInternalServerError) + } + } else { + http.Error(w, "node currently has no tailscale IPs", http.StatusServiceUnavailable) + } +} + +func (h *Healthz) Update(healthy bool) { + h.Lock() + defer h.Unlock() + + if h.hasAddrs != healthy { + h.logger("Setting healthy %v", healthy) + } + h.hasAddrs = healthy +} + +func (h *Healthz) MonitorHealth(ctx context.Context, lc *local.Client) error { + w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) + if err != nil { + return fmt.Errorf("failed to watch IPN bus: %w", err) + } + + for { + n, err := w.Next() + if err != nil { + return err + } + + if n.NetMap != nil { + h.Update(n.NetMap.SelfNode.Addresses().Len() != 0) + } + } +} + +// RegisterHealthHandlers registers a simple health handler at /healthz. +// A containerized tailscale instance is considered healthy if +// it has at least one tailnet IP address. +func RegisterHealthHandlers(mux *http.ServeMux, podIPv4 string, logger logger.Logf) *Healthz { + h := &Healthz{ + podIPv4: podIPv4, + logger: logger, + } + mux.Handle("GET /healthz", h) + return h +} diff --git a/kube/ingressservices/ingressservices.go b/kube/ingressservices/ingressservices.go new file mode 100644 index 000000000..f79410761 --- /dev/null +++ b/kube/ingressservices/ingressservices.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ingressservices contains shared types for exposing Kubernetes Services to tailnet. +// These are split into a separate package for consumption of +// non-Kubernetes shared libraries and binaries. Be mindful of not increasing +// dependency size for those consumers when adding anything new here. +package ingressservices + +import "net/netip" + +// IngressConfigKey is the key at which both the desired ingress firewall +// configuration is stored in the ingress proxies' ConfigMap and at which the +// recorded firewall configuration status is stored in the proxies' state +// Secrets. +const IngressConfigKey = "ingress-config.json" + +// Configs contains the desired configuration for ingress proxies firewall. Map +// keys are Tailscale Service names. +type Configs map[string]Config + +// GetConfig returns the desired configuration for the given Tailscale Service name. +func (cfgs *Configs) GetConfig(name string) *Config { + if cfgs == nil { + return nil + } + if cfg, ok := (*cfgs)[name]; ok { + return &cfg + } + return nil +} + +// Status contains the recorded firewall configuration status for a specific +// ingress proxy Pod. +// Pod IPs are used to identify the ingress proxy Pod. +type Status struct { + Configs Configs `json:"configs,omitempty"` + PodIPv4 string `json:"podIPv4,omitempty"` + PodIPv6 string `json:"podIPv6,omitempty"` +} + +// Config is an ingress service configuration. +type Config struct { + IPv4Mapping *Mapping `json:"IPv4Mapping,omitempty"` + IPv6Mapping *Mapping `json:"IPv6Mapping,omitempty"` +} + +// Mapping describes a rule that forwards traffic from Tailscale Service IP to a +// Kubernetes Service IP. +type Mapping struct { + TailscaleServiceIP netip.Addr `json:"TailscaleServiceIP"` + ClusterIP netip.Addr `json:"ClusterIP"` +} diff --git a/kube/k8s-proxy/conf/conf.go b/kube/k8s-proxy/conf/conf.go new file mode 100644 index 000000000..529495243 --- /dev/null +++ b/kube/k8s-proxy/conf/conf.go @@ -0,0 +1,128 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package conf contains code to load, manipulate, and access config file +// settings for k8s-proxy. +package conf + +import ( + "encoding/json" + "errors" + "fmt" + "net/netip" + + "github.com/tailscale/hujson" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/types/opt" +) + +const v1Alpha1 = "v1alpha1" + +// Config describes a config file. +type Config struct { + Raw []byte // raw bytes, in HuJSON form + Std []byte // standardized JSON form + Version string // "v1alpha1" + + // Parsed is the parsed config, converted from its raw bytes version to the + // latest known format. + Parsed ConfigV1Alpha1 +} + +// VersionedConfig allows specifying config at the root of the object, or in +// a versioned sub-object. +// e.g. {"version": "v1alpha1", "authKey": "abc123"} +// or {"version": "v1beta1", "a-beta-config": "a-beta-value", "v1alpha1": {"authKey": "abc123"}} +type VersionedConfig struct { + Version string `json:",omitempty"` // "v1alpha1" + + // Latest version of the config. + *ConfigV1Alpha1 + + // Backwards compatibility version(s) of the config. Fields and sub-fields + // from here should only be added to, never changed in place. + V1Alpha1 *ConfigV1Alpha1 `json:",omitempty"` + // V1Beta1 *ConfigV1Beta1 `json:",omitempty"` // Not yet used. +} + +type ConfigV1Alpha1 struct { + AuthKey *string `json:",omitempty"` // Tailscale auth key to use. + State *string `json:",omitempty"` // Path to the Tailscale state. + LogLevel *string `json:",omitempty"` // "debug", "info". Defaults to "info". + App *string `json:",omitempty"` // e.g. kubetypes.AppProxyGroupKubeAPIServer + ServerURL *string `json:",omitempty"` // URL of the Tailscale coordination server. + LocalAddr *string `json:",omitempty"` // The address to use for serving HTTP health checks and metrics (defaults to all interfaces). + LocalPort *uint16 `json:",omitempty"` // The port to use for serving HTTP health checks and metrics (defaults to 9002). + MetricsEnabled opt.Bool `json:",omitempty"` // Serve metrics on :/metrics. + HealthCheckEnabled opt.Bool `json:",omitempty"` // Serve health check on :/metrics. + + // TODO(tomhjp): The remaining fields should all be reloadable during + // runtime, but currently missing most of the APIServerProxy fields. + Hostname *string `json:",omitempty"` // Tailscale device hostname. + AcceptRoutes opt.Bool `json:",omitempty"` // Accepts routes advertised by other Tailscale nodes. + AdvertiseServices []string `json:",omitempty"` // Tailscale Services to advertise. + APIServerProxy *APIServerProxyConfig `json:",omitempty"` // Config specific to the API Server proxy. + StaticEndpoints []netip.AddrPort `json:",omitempty"` // StaticEndpoints are additional, user-defined endpoints that this node should advertise amongst its wireguard endpoints. +} + +type APIServerProxyConfig struct { + Enabled opt.Bool `json:",omitempty"` // Whether to enable the API Server proxy. + Mode *kubetypes.APIServerProxyMode `json:",omitempty"` // "auth" or "noauth" mode. + ServiceName *tailcfg.ServiceName `json:",omitempty"` // Name of the Tailscale Service to advertise. + IssueCerts opt.Bool `json:",omitempty"` // Whether this replica should issue TLS certs for the Tailscale Service. +} + +// Load reads and parses the config file at the provided path on disk. +func Load(raw []byte) (c Config, err error) { + c.Raw = raw + c.Std, err = hujson.Standardize(c.Raw) + if err != nil { + return c, fmt.Errorf("error parsing config as HuJSON/JSON: %w", err) + } + var ver VersionedConfig + if err := json.Unmarshal(c.Std, &ver); err != nil { + return c, fmt.Errorf("error parsing config: %w", err) + } + rootV1Alpha1 := (ver.Version == v1Alpha1) + backCompatV1Alpha1 := (ver.V1Alpha1 != nil) + switch { + case ver.Version == "": + return c, errors.New("error parsing config: no \"version\" field provided") + case rootV1Alpha1 && backCompatV1Alpha1: + // Exactly one of these should be set. + return c, errors.New("error parsing config: both root and v1alpha1 config provided") + case rootV1Alpha1 != backCompatV1Alpha1: + c.Version = v1Alpha1 + switch { + case rootV1Alpha1 && ver.ConfigV1Alpha1 != nil: + c.Parsed = *ver.ConfigV1Alpha1 + case backCompatV1Alpha1: + c.Parsed = *ver.V1Alpha1 + default: + c.Parsed = ConfigV1Alpha1{} + } + default: + return c, fmt.Errorf("error parsing config: unsupported \"version\" value %q; want \"%s\"", ver.Version, v1Alpha1) + } + + return c, nil +} + +func (c *Config) GetLocalAddr() string { + if c.Parsed.LocalAddr == nil { + return "[::]" + } + + return *c.Parsed.LocalAddr +} + +func (c *Config) GetLocalPort() uint16 { + if c.Parsed.LocalPort == nil { + return uint16(9002) + } + + return *c.Parsed.LocalPort +} diff --git a/kube/k8s-proxy/conf/conf_test.go b/kube/k8s-proxy/conf/conf_test.go new file mode 100644 index 000000000..3082be1ba --- /dev/null +++ b/kube/k8s-proxy/conf/conf_test.go @@ -0,0 +1,79 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package conf + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/ptr" +) + +// Test that the config file can be at the root of the object, or in a versioned sub-object. +// or {"version": "v1beta1", "a-beta-config": "a-beta-value", "v1alpha1": {"authKey": "abc123"}} +func TestVersionedConfig(t *testing.T) { + testCases := map[string]struct { + inputConfig string + expectedConfig ConfigV1Alpha1 + expectedError string + }{ + "root_config_v1alpha1": { + inputConfig: `{"version": "v1alpha1", "authKey": "abc123"}`, + expectedConfig: ConfigV1Alpha1{AuthKey: ptr.To("abc123")}, + }, + "backwards_compat_v1alpha1_config": { + // Client doesn't know about v1beta1, so it should read in v1alpha1. + inputConfig: `{"version": "v1beta1", "beta-key": "beta-value", "authKey": "def456", "v1alpha1": {"authKey": "abc123"}}`, + expectedConfig: ConfigV1Alpha1{AuthKey: ptr.To("abc123")}, + }, + "unknown_key_allowed": { + // Adding new keys to the config doesn't require a version bump. + inputConfig: `{"version": "v1alpha1", "unknown-key": "unknown-value", "authKey": "abc123"}`, + expectedConfig: ConfigV1Alpha1{AuthKey: ptr.To("abc123")}, + }, + "version_only_no_authkey": { + inputConfig: `{"version": "v1alpha1"}`, + expectedConfig: ConfigV1Alpha1{}, + }, + "both_config_v1alpha1": { + inputConfig: `{"version": "v1alpha1", "authKey": "abc123", "v1alpha1": {"authKey": "def456"}}`, + expectedError: "both root and v1alpha1 config provided", + }, + "empty_config": { + inputConfig: `{}`, + expectedError: `no "version" field provided`, + }, + "v1beta1_without_backwards_compat": { + inputConfig: `{"version": "v1beta1", "beta-key": "beta-value", "authKey": "def456"}`, + expectedError: `unsupported "version" value "v1beta1"; want "v1alpha1"`, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + cfg, err := Load([]byte(tc.inputConfig)) + switch { + case tc.expectedError == "" && err != nil: + t.Fatalf("unexpected error: %v", err) + case tc.expectedError != "": + if err == nil { + t.Fatalf("expected error %q, got nil", tc.expectedError) + } else if !strings.Contains(err.Error(), tc.expectedError) { + t.Fatalf("expected error %q, got %q", tc.expectedError, err.Error()) + } + return + } + if cfg.Version != "v1alpha1" { + t.Fatalf("expected version %q, got %q", "v1alpha1", cfg.Version) + } + // Diff actual vs expected config. + if diff := cmp.Diff(cfg.Parsed, tc.expectedConfig); diff != "" { + t.Fatalf("Unexpected parsed config (-got +want):\n%s", diff) + } + }) + } +} diff --git a/kube/kubeapi/api.go b/kube/kubeapi/api.go index 0e42437a6..e62bd6e2b 100644 --- a/kube/kubeapi/api.go +++ b/kube/kubeapi/api.go @@ -7,7 +7,9 @@ // dependency size for those consumers when adding anything new here. package kubeapi -import "time" +import ( + "time" +) // Note: The API types are copied from k8s.io/api{,machinery} to not introduce a // module dependency on the Kubernetes API as it pulls in many more dependencies. @@ -151,6 +153,65 @@ type Secret struct { Data map[string][]byte `json:"data,omitempty"` } +// SecretList is a list of Secret objects. +type SecretList struct { + TypeMeta `json:",inline"` + ObjectMeta `json:"metadata"` + + Items []Secret `json:"items,omitempty"` +} + +// Event contains a subset of fields from corev1.Event. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L7034 +// It is copied here to avoid having to import kube libraries. +type Event struct { + TypeMeta `json:",inline"` + ObjectMeta `json:"metadata"` + Message string `json:"message,omitempty"` + Reason string `json:"reason,omitempty"` + Source EventSource `json:"source,omitempty"` // who is emitting this Event + Type string `json:"type,omitempty"` // Normal or Warning + // InvolvedObject is the subject of the Event. `kubectl describe` will, for most object types, display any + // currently present cluster Events matching the object (but you probably want to set UID for this to work). + InvolvedObject ObjectReference `json:"involvedObject"` + Count int32 `json:"count,omitempty"` // how many times Event was observed + FirstTimestamp time.Time `json:"firstTimestamp,omitempty"` + LastTimestamp time.Time `json:"lastTimestamp,omitempty"` +} + +// EventSource includes a subset of fields from corev1.EventSource. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L7007 +// It is copied here to avoid having to import kube libraries. +type EventSource struct { + // Component is the name of the component that is emitting the Event. + Component string `json:"component,omitempty"` +} + +// ObjectReference contains a subset of fields from corev1.ObjectReference. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L6902 +// It is copied here to avoid having to import kube libraries. +type ObjectReference struct { + // Kind of the referent. + // More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + // +optional + Kind string `json:"kind,omitempty"` + // Namespace of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/ + // +optional + Namespace string `json:"namespace,omitempty"` + // Name of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + // +optional + Name string `json:"name,omitempty"` + // UID of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#uids + // +optional + UID string `json:"uid,omitempty"` + // API version of the referent. + // +optional + APIVersion string `json:"apiVersion,omitempty"` +} + // Status is a return value for calls that don't return other objects. type Status struct { TypeMeta `json:",inline"` @@ -186,6 +247,6 @@ type Status struct { Code int `json:"code,omitempty"` } -func (s *Status) Error() string { +func (s Status) Error() string { return s.Message } diff --git a/kube/kubeclient/client.go b/kube/kubeclient/client.go index e8ddec75d..0ed960f4d 100644 --- a/kube/kubeclient/client.go +++ b/kube/kubeclient/client.go @@ -15,6 +15,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" + "errors" "fmt" "io" "log" @@ -23,16 +24,20 @@ import ( "net/url" "os" "path/filepath" + "strings" "sync" "time" "tailscale.com/kube/kubeapi" - "tailscale.com/util/multierr" + "tailscale.com/tstime" ) const ( saPath = "/var/run/secrets/kubernetes.io/serviceaccount" defaultURL = "https://kubernetes.default.svc" + + TypeSecrets = "secrets" + typeEvents = "events" ) // rootPathForTests is set by tests to override the root path to the @@ -55,10 +60,16 @@ func readFile(n string) ([]byte, error) { // It expects to be run inside a cluster. type Client interface { GetSecret(context.Context, string) (*kubeapi.Secret, error) + ListSecrets(context.Context, map[string]string) (*kubeapi.SecretList, error) UpdateSecret(context.Context, *kubeapi.Secret) error CreateSecret(context.Context, *kubeapi.Secret) error + // Event attempts to ensure an event with the specified options associated with the Pod in which we are + // currently running. This is best effort - if the client is not able to create events, this operation will be a + // no-op. If there is already an Event with the given reason for the current Pod, it will get updated (only + // count and timestamp are expected to change), else a new event will be created. + Event(_ context.Context, typ, reason, msg string) error StrategicMergePatchSecret(context.Context, string, *kubeapi.Secret, string) error - JSONPatchSecret(context.Context, string, []JSONPatch) error + JSONPatchResource(_ context.Context, resourceName string, resourceType string, patches []JSONPatch) error CheckSecretPermissions(context.Context, string) (bool, bool, error) SetDialer(dialer func(context.Context, string, string) (net.Conn, error)) SetURL(string) @@ -66,15 +77,24 @@ type Client interface { type client struct { mu sync.Mutex + name string url string - ns string + podName string + podUID string + ns string // Pod namespace client *http.Client token string tokenExpiry time.Time + cl tstime.Clock + // hasEventsPerms is true if client can emit Events for the Pod in which it runs. If it is set to false any + // calls to Events() will be a no-op. + hasEventsPerms bool + // kubeAPIRequest sends a request to the kube API server. It can set to a fake in tests. + kubeAPIRequest kubeAPIRequestFunc } // New returns a new client -func New() (Client, error) { +func New(name string) (Client, error) { ns, err := readFile("namespace") if err != nil { return nil, err @@ -87,9 +107,11 @@ func New() (Client, error) { if ok := cp.AppendCertsFromPEM(caCert); !ok { return nil, fmt.Errorf("kube: error in creating root cert pool") } - return &client{ - url: defaultURL, - ns: string(ns), + c := &client{ + url: defaultURL, + ns: string(ns), + name: name, + cl: tstime.DefaultClock{}, client: &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -97,7 +119,10 @@ func New() (Client, error) { }, }, }, - }, nil + } + c.kubeAPIRequest = newKubeAPIRequest(c) + c.setEventPerms() + return c, nil } // SetURL sets the URL to use for the Kubernetes API. @@ -115,14 +140,14 @@ func (c *client) SetDialer(dialer func(ctx context.Context, network, addr string func (c *client) expireToken() { c.mu.Lock() defer c.mu.Unlock() - c.tokenExpiry = time.Now() + c.tokenExpiry = c.cl.Now() } func (c *client) getOrRenewToken() (string, error) { c.mu.Lock() defer c.mu.Unlock() tk, te := c.token, c.tokenExpiry - if time.Now().Before(te) { + if c.cl.Now().Before(te) { return tk, nil } @@ -131,17 +156,10 @@ func (c *client) getOrRenewToken() (string, error) { return "", err } c.token = string(tkb) - c.tokenExpiry = time.Now().Add(30 * time.Minute) + c.tokenExpiry = c.cl.Now().Add(30 * time.Minute) return c.token, nil } -func (c *client) secretURL(name string) string { - if name == "" { - return fmt.Sprintf("%s/api/v1/namespaces/%s/secrets", c.url, c.ns) - } - return fmt.Sprintf("%s/api/v1/namespaces/%s/secrets/%s", c.url, c.ns, name) -} - func getError(resp *http.Response) error { if resp.StatusCode == 200 || resp.StatusCode == 201 { // These are the only success codes returned by the Kubernetes API. @@ -161,36 +179,41 @@ func setHeader(key, value string) func(*http.Request) { } } -// doRequest performs an HTTP request to the Kubernetes API. -// If in is not nil, it is expected to be a JSON-encodable object and will be -// sent as the request body. -// If out is not nil, it is expected to be a pointer to an object that can be -// decoded from JSON. -// If the request fails with a 401, the token is expired and a new one is -// requested. -func (c *client) doRequest(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error { - req, err := c.newRequest(ctx, method, url, in) - if err != nil { - return err - } - for _, opt := range opts { - opt(req) - } - resp, err := c.client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if err := getError(resp); err != nil { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 401 { - c.expireToken() +type kubeAPIRequestFunc func(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error + +// newKubeAPIRequest returns a function that can perform an HTTP request to the Kubernetes API. +func newKubeAPIRequest(c *client) kubeAPIRequestFunc { + // If in is not nil, it is expected to be a JSON-encodable object and will be + // sent as the request body. + // If out is not nil, it is expected to be a pointer to an object that can be + // decoded from JSON. + // If the request fails with a 401, the token is expired and a new one is + // requested. + f := func(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error { + req, err := c.newRequest(ctx, method, url, in) + if err != nil { + return err } - return err - } - if out != nil { - return json.NewDecoder(resp.Body).Decode(out) + for _, opt := range opts { + opt(req) + } + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if err := getError(resp); err != nil { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 401 { + c.expireToken() + } + return err + } + if out != nil { + return json.NewDecoder(resp.Body).Decode(out) + } + return nil } - return nil + return f } func (c *client) newRequest(ctx context.Context, method, url string, in any) (*http.Request, error) { @@ -226,25 +249,39 @@ func (c *client) newRequest(ctx context.Context, method, url string, in any) (*h // GetSecret fetches the secret from the Kubernetes API. func (c *client) GetSecret(ctx context.Context, name string) (*kubeapi.Secret, error) { s := &kubeapi.Secret{Data: make(map[string][]byte)} - if err := c.doRequest(ctx, "GET", c.secretURL(name), nil, s); err != nil { + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL(name, TypeSecrets, ""), nil, s); err != nil { return nil, err } return s, nil } +// ListSecrets fetches the secret from the Kubernetes API. +func (c *client) ListSecrets(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) { + sl := new(kubeapi.SecretList) + s := make([]string, 0, len(selector)) + for key, val := range selector { + s = append(s, key+"="+url.QueryEscape(val)) + } + ss := strings.Join(s, ",") + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL("", TypeSecrets, ss), nil, sl); err != nil { + return nil, err + } + return sl, nil +} + // CreateSecret creates a secret in the Kubernetes API. func (c *client) CreateSecret(ctx context.Context, s *kubeapi.Secret) error { s.Namespace = c.ns - return c.doRequest(ctx, "POST", c.secretURL(""), s, nil) + return c.kubeAPIRequest(ctx, "POST", c.resourceURL("", TypeSecrets, ""), s, nil) } // UpdateSecret updates a secret in the Kubernetes API. func (c *client) UpdateSecret(ctx context.Context, s *kubeapi.Secret) error { - return c.doRequest(ctx, "PUT", c.secretURL(s.Name), s, nil) + return c.kubeAPIRequest(ctx, "PUT", c.resourceURL(s.Name, TypeSecrets, ""), s, nil) } // JSONPatch is a JSON patch operation. -// It currently (2023-03-02) only supports "add" and "remove" operations. +// It currently (2024-11-15) only supports "add", "remove" and "replace" operations. // // https://tools.ietf.org/html/rfc6902 type JSONPatch struct { @@ -253,22 +290,22 @@ type JSONPatch struct { Value any `json:"value,omitempty"` } -// JSONPatchSecret updates a secret in the Kubernetes API using a JSON patch. -// It currently (2023-03-02) only supports "add" and "remove" operations. -func (c *client) JSONPatchSecret(ctx context.Context, name string, patch []JSONPatch) error { - for _, p := range patch { +// JSONPatchResource updates a resource in the Kubernetes API using a JSON patch. +// It currently (2024-11-15) only supports "add", "remove" and "replace" operations. +func (c *client) JSONPatchResource(ctx context.Context, name, typ string, patches []JSONPatch) error { + for _, p := range patches { if p.Op != "remove" && p.Op != "add" && p.Op != "replace" { return fmt.Errorf("unsupported JSON patch operation: %q", p.Op) } } - return c.doRequest(ctx, "PATCH", c.secretURL(name), patch, nil, setHeader("Content-Type", "application/json-patch+json")) + return c.kubeAPIRequest(ctx, "PATCH", c.resourceURL(name, typ, ""), patches, nil, setHeader("Content-Type", "application/json-patch+json")) } // StrategicMergePatchSecret updates a secret in the Kubernetes API using a // strategic merge patch. // If a fieldManager is provided, it will be used to track the patch. func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s *kubeapi.Secret, fieldManager string) error { - surl := c.secretURL(name) + surl := c.resourceURL(name, TypeSecrets, "") if fieldManager != "" { uv := url.Values{ "fieldManager": {fieldManager}, @@ -277,7 +314,66 @@ func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s * } s.Namespace = c.ns s.Name = name - return c.doRequest(ctx, "PATCH", surl, s, nil, setHeader("Content-Type", "application/strategic-merge-patch+json")) + return c.kubeAPIRequest(ctx, "PATCH", surl, s, nil, setHeader("Content-Type", "application/strategic-merge-patch+json")) +} + +// Event tries to ensure an Event associated with the Pod in which we are running. It is best effort - the event will be +// created if the kube client on startup was able to determine the name and UID of this Pod from POD_NAME,POD_UID env +// vars and if permissions check for event creation succeeded. Events are keyed on opts.Reason- if an Event for the +// current Pod with that reason already exists, its count and first timestamp will be updated, else a new Event will be +// created. +func (c *client) Event(ctx context.Context, typ, reason, msg string) error { + if !c.hasEventsPerms { + return nil + } + name := c.nameForEvent(reason) + ev, err := c.getEvent(ctx, name) + now := c.cl.Now() + if err != nil { + if !IsNotFoundErr(err) { + return err + } + // Event not found - create it + ev := kubeapi.Event{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: name, + Namespace: c.ns, + }, + Type: typ, + Reason: reason, + Message: msg, + Source: kubeapi.EventSource{ + Component: c.name, + }, + InvolvedObject: kubeapi.ObjectReference{ + Name: c.podName, + Namespace: c.ns, + UID: c.podUID, + Kind: "Pod", + APIVersion: "v1", + }, + + FirstTimestamp: now, + LastTimestamp: now, + Count: 1, + } + return c.kubeAPIRequest(ctx, "POST", c.resourceURL("", typeEvents, ""), &ev, nil) + } + // If the Event already exists, we patch its count and last timestamp. This ensures that when users run 'kubectl + // describe pod...', they see the event just once (but with a message of how many times it has appeared over + // last timestamp - first timestamp period of time). + count := ev.Count + 1 + countPatch := JSONPatch{ + Op: "replace", + Value: count, + Path: "/count", + } + tsPatch := JSONPatch{ + Op: "replace", + Value: now, + Path: "/lastTimestamp", + } + return c.JSONPatchResource(ctx, name, typeEvents, []JSONPatch{countPatch, tsPatch}) } // CheckSecretPermissions checks the secret access permissions of the current @@ -293,7 +389,7 @@ func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s * func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) (canPatch, canCreate bool, err error) { var errs []error for _, verb := range []string{"get", "update"} { - ok, err := c.checkPermission(ctx, verb, secretName) + ok, err := c.checkPermission(ctx, verb, TypeSecrets, secretName) if err != nil { log.Printf("error checking %s permission on secret %s: %v", verb, secretName, err) } else if !ok { @@ -301,14 +397,14 @@ func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) } } if len(errs) > 0 { - return false, false, multierr.New(errs...) + return false, false, errors.Join(errs...) } - canPatch, err = c.checkPermission(ctx, "patch", secretName) + canPatch, err = c.checkPermission(ctx, "patch", TypeSecrets, secretName) if err != nil { log.Printf("error checking patch permission on secret %s: %v", secretName, err) return false, false, nil } - canCreate, err = c.checkPermission(ctx, "create", secretName) + canCreate, err = c.checkPermission(ctx, "create", TypeSecrets, secretName) if err != nil { log.Printf("error checking create permission on secret %s: %v", secretName, err) return false, false, nil @@ -316,19 +412,64 @@ func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) return canPatch, canCreate, nil } -// checkPermission reports whether the current pod has permission to use the -// given verb (e.g. get, update, patch, create) on secretName. -func (c *client) checkPermission(ctx context.Context, verb, secretName string) (bool, error) { +func IsNotFoundErr(err error) bool { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { + return true + } + return false +} + +// setEventPerms checks whether this client will be able to write tailscaled Events to its Pod and updates the state +// accordingly. If it determines that the client can not write Events, any subsequent calls to client.Event will be a +// no-op. +func (c *client) setEventPerms() { + name := os.Getenv("POD_NAME") + uid := os.Getenv("POD_UID") + hasPerms := false + defer func() { + c.podName = name + c.podUID = uid + c.hasEventsPerms = hasPerms + if !hasPerms { + log.Printf(`kubeclient: this client is not able to write tailscaled Events to the Pod in which it is running. + To help with future debugging you can make it able write Events by giving it get,create,patch permissions for Events in the Pod namespace + and setting POD_NAME, POD_UID env vars for the Pod.`) + } + }() + if name == "" || uid == "" { + return + } + for _, verb := range []string{"get", "create", "patch"} { + can, err := c.checkPermission(context.Background(), verb, typeEvents, "") + if err != nil { + log.Printf("kubeclient: error checking Events permissions: %v", err) + return + } + if !can { + return + } + } + hasPerms = true + return +} + +// checkPermission reports whether the current pod has permission to use the given verb (e.g. get, update, patch, +// create) on the given resource type. If name is not an empty string, will check the check will be for resource with +// the given name only. +func (c *client) checkPermission(ctx context.Context, verb, typ, name string) (bool, error) { + ra := map[string]any{ + "namespace": c.ns, + "verb": verb, + "resource": typ, + } + if name != "" { + ra["name"] = name + } sar := map[string]any{ "apiVersion": "authorization.k8s.io/v1", "kind": "SelfSubjectAccessReview", "spec": map[string]any{ - "resourceAttributes": map[string]any{ - "namespace": c.ns, - "verb": verb, - "resource": "secrets", - "name": secretName, - }, + "resourceAttributes": ra, }, } var res struct { @@ -337,15 +478,36 @@ func (c *client) checkPermission(ctx context.Context, verb, secretName string) ( } `json:"status"` } url := c.url + "/apis/authorization.k8s.io/v1/selfsubjectaccessreviews" - if err := c.doRequest(ctx, "POST", url, sar, &res); err != nil { + if err := c.kubeAPIRequest(ctx, "POST", url, sar, &res); err != nil { return false, err } return res.Status.Allowed, nil } -func IsNotFoundErr(err error) bool { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { - return true +// resourceURL returns a URL that can be used to interact with the given resource type and, if name is not empty string, +// the named resource of that type. +// Note that this only works for core/v1 resource types. +func (c *client) resourceURL(name, typ, sel string) string { + if name == "" { + url := fmt.Sprintf("%s/api/v1/namespaces/%s/%s", c.url, c.ns, typ) + if sel != "" { + url += "?labelSelector=" + sel + } + return url } - return false + return fmt.Sprintf("%s/api/v1/namespaces/%s/%s/%s", c.url, c.ns, typ, name) +} + +// nameForEvent returns a name for the Event that uniquely identifies Event with that reason for the current Pod. +func (c *client) nameForEvent(reason string) string { + return fmt.Sprintf("%s.%s.%s", c.podName, c.podUID, strings.ToLower(reason)) +} + +// getEvent fetches the event from the Kubernetes API. +func (c *client) getEvent(ctx context.Context, name string) (*kubeapi.Event, error) { + e := &kubeapi.Event{} + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL(name, typeEvents, ""), nil, e); err != nil { + return nil, err + } + return e, nil } diff --git a/kube/kubeclient/client_test.go b/kube/kubeclient/client_test.go new file mode 100644 index 000000000..8599e7e3c --- /dev/null +++ b/kube/kubeclient/client_test.go @@ -0,0 +1,227 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubeclient + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/kube/kubeapi" + "tailscale.com/tstest" +) + +func Test_client_Event(t *testing.T) { + cl := &tstest.Clock{} + tests := []struct { + name string + typ string + reason string + msg string + argSets []args + wantErr bool + }{ + { + name: "new_event_gets_created", + typ: "Normal", + reason: "TestReason", + msg: "TestMessage", + argSets: []args{ + { // request to GET event returns not found + wantsMethod: "GET", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + setErr: &kubeapi.Status{Code: 404}, + }, + { // sends POST request to create event + wantsMethod: "POST", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events", + wantsIn: &kubeapi.Event{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: "test-pod.test-uid.testreason", + Namespace: "test-ns", + }, + Type: "Normal", + Reason: "TestReason", + Message: "TestMessage", + Source: kubeapi.EventSource{ + Component: "test-client", + }, + InvolvedObject: kubeapi.ObjectReference{ + Name: "test-pod", + UID: "test-uid", + Namespace: "test-ns", + APIVersion: "v1", + Kind: "Pod", + }, + FirstTimestamp: cl.Now(), + LastTimestamp: cl.Now(), + Count: 1, + }, + }, + }, + }, + { + name: "existing_event_gets_patched", + typ: "Warning", + reason: "TestReason", + msg: "TestMsg", + argSets: []args{ + { // request to GET event does not error - this is enough to assume that event exists + wantsMethod: "GET", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + setOut: []byte(`{"count":2}`), + }, + { // sends PATCH request to update the event + wantsMethod: "PATCH", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + wantsIn: []JSONPatch{ + {Op: "replace", Path: "/count", Value: int32(3)}, + {Op: "replace", Path: "/lastTimestamp", Value: cl.Now()}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &client{ + cl: cl, + name: "test-client", + podName: "test-pod", + podUID: "test-uid", + url: "test-apiserver", + ns: "test-ns", + kubeAPIRequest: fakeKubeAPIRequest(t, tt.argSets), + hasEventsPerms: true, + } + if err := c.Event(context.Background(), tt.typ, tt.reason, tt.msg); (err != nil) != tt.wantErr { + t.Errorf("client.Event() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// TestReturnsKubeStatusError ensures HTTP error codes from the Kubernetes API +// server can always be extracted by casting the error to the *kubeapi.Status +// type, as lots of calling code relies on this cast succeeding. Note that +// transport errors are not expected or required to be of type *kubeapi.Status. +func TestReturnsKubeStatusError(t *testing.T) { + cl := clientForKubeHandler(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _ = json.NewEncoder(w).Encode(kubeapi.Status{Code: http.StatusForbidden, Message: "test error"}) + })) + + _, err := cl.GetSecret(t.Context(), "test-secret") + if err == nil { + t.Fatal("expected error, got nil") + } + if st, ok := err.(*kubeapi.Status); !ok || st.Code != http.StatusForbidden { + t.Fatalf("expected kubeapi.Status with code %d, got %T: %v", http.StatusForbidden, err, err) + } +} + +// clientForKubeHandler creates a client using the externally accessible package +// API to ensure it's testing behaviour as close to prod as possible. The passed +// in handler mocks the Kubernetes API server's responses to any HTTP requests +// made by the client. +func clientForKubeHandler(t *testing.T, handler http.Handler) Client { + t.Helper() + tmpDir := t.TempDir() + rootPathForTests = tmpDir + saDir := filepath.Join(tmpDir, "var", "run", "secrets", "kubernetes.io", "serviceaccount") + _ = os.MkdirAll(saDir, 0755) + _ = os.WriteFile(filepath.Join(saDir, "token"), []byte("test-token"), 0600) + _ = os.WriteFile(filepath.Join(saDir, "namespace"), []byte("test-namespace"), 0600) + _ = os.WriteFile(filepath.Join(saDir, "ca.crt"), []byte(ca), 0644) + cl, err := New("test-client") + if err != nil { + t.Fatalf("New() error = %v", err) + } + srv := httptest.NewServer(handler) + t.Cleanup(srv.Close) + cl.SetURL(srv.URL) + return cl +} + +// args is a set of values for testing a single call to client.kubeAPIRequest. +type args struct { + // wantsMethod is the expected value of 'method' arg. + wantsMethod string + // wantsURL is the expected value of 'url' arg. + wantsURL string + // wantsIn is the expected value of 'in' arg. + wantsIn any + // setOut can be set to a byte slice representing valid JSON. If set 'out' arg will get set to the unmarshalled + // JSON object. + setOut []byte + // setErr is the error that kubeAPIRequest will return. + setErr error +} + +// fakeKubeAPIRequest can be used to test that a series of calls to client.kubeAPIRequest gets called with expected +// values and to set these calls to return preconfigured values. 'argSets' should be set to a slice of expected +// arguments and should-be return values of a series of kubeAPIRequest calls. +func fakeKubeAPIRequest(t *testing.T, argSets []args) kubeAPIRequestFunc { + count := 0 + f := func(ctx context.Context, gotMethod, gotUrl string, gotIn, gotOut any, opts ...func(*http.Request)) error { + t.Helper() + if count >= len(argSets) { + t.Fatalf("unexpected call to client.kubeAPIRequest, expected %d calls, but got a %dth call", len(argSets), count+1) + } + a := argSets[count] + if gotMethod != a.wantsMethod { + t.Errorf("[%d] got method %q, wants method %q", count, gotMethod, a.wantsMethod) + } + if gotUrl != a.wantsURL { + t.Errorf("[%d] got URL %q, wants URL %q", count, gotUrl, a.wantsURL) + } + if d := cmp.Diff(gotIn, a.wantsIn); d != "" { + t.Errorf("[%d] unexpected payload (-want + got):\n%s", count, d) + } + if len(a.setOut) != 0 { + if err := json.Unmarshal(a.setOut, gotOut); err != nil { + t.Fatalf("[%d] error unmarshalling output: %v", count, err) + } + } + count++ + return a.setErr + } + return f +} + +const ca = `-----BEGIN CERTIFICATE----- +MIIFEDCCA3igAwIBAgIRANf5NdPojIfj70wMfJVYUg8wDQYJKoZIhvcNAQELBQAw +gZ8xHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTE6MDgGA1UECwwxZnJv +bWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWljaGFlbCBKLiBGcm9tYmVyZ2VyKTFB +MD8GA1UEAww4bWtjZXJ0IGZyb21iZXJnZXJAc3RhcmR1c3QubG9jYWwgKE1pY2hh +ZWwgSi4gRnJvbWJlcmdlcikwHhcNMjMwMjA3MjAzNDE4WhcNMzMwMjA3MjAzNDE4 +WjCBnzEeMBwGA1UEChMVbWtjZXJ0IGRldmVsb3BtZW50IENBMTowOAYDVQQLDDFm +cm9tYmVyZ2VyQHN0YXJkdXN0LmxvY2FsIChNaWNoYWVsIEouIEZyb21iZXJnZXIp +MUEwPwYDVQQDDDhta2NlcnQgZnJvbWJlcmdlckBzdGFyZHVzdC5sb2NhbCAoTWlj +aGFlbCBKLiBGcm9tYmVyZ2VyKTCCAaIwDQYJKoZIhvcNAQEBBQADggGPADCCAYoC +ggGBAL5uXNnrZ6dgjcvK0Hc7ZNUIRYEWst9qbO0P9H7le08pJ6d9T2BUWruZtVjk +Q12msv5/bVWHhVk8dZclI9FLXuMsIrocH8bsoP4wruPMyRyp6EedSKODN51fFSRv +/jHbS5vzUVAWTYy9qYmd6qL0uhsHCZCCT6gfigamHPUFKM3sHDn5ZHWvySMwcyGl +AicmPAIkBWqiCZAkB5+WM7+oyRLjmrIalfWIZYxW/rojGLwTfneHv6J5WjVQnpJB +ayWCzCzaiXukK9MeBWeTOe8UfVN0Engd74/rjLWvjbfC+uZSr6RVkZvs2jANLwPF +zgzBPHgRPfAhszU1NNAMjnNQ47+OMOTKRt7e6jYzhO5fyO1qVAAvGBqcfpj+JfDk +cccaUMhUvdiGrhGf1V1tN/PislxvALirzcFipjD01isBKwn0fxRugzvJNrjEo8RA +RvbcdeKcwex7M0o/Cd0+G2B13gZNOFvR33PmG7iTpp7IUrUKfQg28I83Sp8tMY3s +ljJSawIDAQABo0UwQzAOBgNVHQ8BAf8EBAMCAgQwEgYDVR0TAQH/BAgwBgEB/wIB +ADAdBgNVHQ4EFgQU18qto0Fa56kCi/HwfQuC9ECX7cAwDQYJKoZIhvcNAQELBQAD +ggGBAAzs96LwZVOsRSlBdQqMo8oMAvs7HgnYbXt8SqaACLX3+kJ3cV/vrCE3iJrW +ma4CiQbxS/HqsiZjota5m4lYeEevRnUDpXhp+7ugZTiz33Flm1RU99c9UYfQ+919 +ANPAKeqNpoPco/HF5Bz0ocepjcfKQrVZZNTj6noLs8o12FHBLO5976AcF9mqlNfh +8/F0gDJXq6+x7VT5y8u0rY004XKPRe3CklRt8kpeMiP6mhRyyUehOaHeIbNx8ubi +Pi44ByN/ueAnuRhF9zYtyZVZZOaSLysJge01tuPXF8rBXGruoJIv35xTTBa9BzaP +YDOGbGn1ZnajdNagHqCba8vjTLDSpqMvgRj3TFrGHdETA2LDQat38uVxX8gxm68K +va5Tyv7n+6BQ5YTpJjTPnmSJKaXZrrhdLPvG0OU2TxeEsvbcm5LFQofirOOw86Se +vzF2cQ94mmHRZiEk0Av3NO0jF93ELDrBCuiccVyEKq6TknuvPQlutCXKDOYSEb8I +MHctBg== +-----END CERTIFICATE-----` diff --git a/kube/kubeclient/fake_client.go b/kube/kubeclient/fake_client.go index 3cef3d27e..15ebb5f44 100644 --- a/kube/kubeclient/fake_client.go +++ b/kube/kubeclient/fake_client.go @@ -13,8 +13,13 @@ import ( var _ Client = &FakeClient{} type FakeClient struct { - GetSecretImpl func(context.Context, string) (*kubeapi.Secret, error) - CheckSecretPermissionsImpl func(ctx context.Context, name string) (bool, bool, error) + GetSecretImpl func(context.Context, string) (*kubeapi.Secret, error) + CheckSecretPermissionsImpl func(ctx context.Context, name string) (bool, bool, error) + CreateSecretImpl func(context.Context, *kubeapi.Secret) error + UpdateSecretImpl func(context.Context, *kubeapi.Secret) error + JSONPatchResourceImpl func(context.Context, string, string, []JSONPatch) error + ListSecretsImpl func(context.Context, map[string]string) (*kubeapi.SecretList, error) + StrategicMergePatchSecretImpl func(context.Context, string, *kubeapi.Secret, string) error } func (fc *FakeClient) CheckSecretPermissions(ctx context.Context, name string) (bool, bool, error) { @@ -26,11 +31,25 @@ func (fc *FakeClient) GetSecret(ctx context.Context, name string) (*kubeapi.Secr func (fc *FakeClient) SetURL(_ string) {} func (fc *FakeClient) SetDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) { } -func (fc *FakeClient) StrategicMergePatchSecret(context.Context, string, *kubeapi.Secret, string) error { - return nil +func (fc *FakeClient) StrategicMergePatchSecret(ctx context.Context, name string, s *kubeapi.Secret, fieldManager string) error { + return fc.StrategicMergePatchSecretImpl(ctx, name, s, fieldManager) } -func (fc *FakeClient) JSONPatchSecret(context.Context, string, []JSONPatch) error { +func (fc *FakeClient) Event(context.Context, string, string, string) error { return nil } -func (fc *FakeClient) UpdateSecret(context.Context, *kubeapi.Secret) error { return nil } -func (fc *FakeClient) CreateSecret(context.Context, *kubeapi.Secret) error { return nil } + +func (fc *FakeClient) JSONPatchResource(ctx context.Context, resource, name string, patches []JSONPatch) error { + return fc.JSONPatchResourceImpl(ctx, resource, name, patches) +} +func (fc *FakeClient) UpdateSecret(ctx context.Context, secret *kubeapi.Secret) error { + return fc.UpdateSecretImpl(ctx, secret) +} +func (fc *FakeClient) CreateSecret(ctx context.Context, secret *kubeapi.Secret) error { + return fc.CreateSecretImpl(ctx, secret) +} +func (fc *FakeClient) ListSecrets(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) { + if fc.ListSecretsImpl != nil { + return fc.ListSecretsImpl(ctx, selector) + } + return nil, nil +} diff --git a/kube/kubetypes/metrics.go b/kube/kubetypes/metrics.go deleted file mode 100644 index 021c1e26b..000000000 --- a/kube/kubetypes/metrics.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package kubetypes - -const ( - // Hostinfo App values for the Tailscale Kubernetes Operator components. - AppOperator = "k8s-operator" - AppAPIServerProxy = "k8s-operator-proxy" - AppIngressProxy = "k8s-operator-ingress-proxy" - AppIngressResource = "k8s-operator-ingress-resource" - AppEgressProxy = "k8s-operator-egress-proxy" - AppConnector = "k8s-operator-connector-resource" - - // Clientmetrics for Tailscale Kubernetes Operator components - MetricIngressProxyCount = "k8s_ingress_proxies" // L3 - MetricIngressResourceCount = "k8s_ingress_resources" // L7 - MetricEgressProxyCount = "k8s_egress_proxies" - MetricConnectorResourceCount = "k8s_connector_resources" - MetricConnectorWithSubnetRouterCount = "k8s_connector_subnetrouter_resources" - MetricConnectorWithExitNodeCount = "k8s_connector_exitnode_resources" - MetricNameserverCount = "k8s_nameserver_resources" - MetricRecorderCount = "k8s_recorder_resources" - MetricEgressServiceCount = "k8s_egress_service_resources" -) diff --git a/kube/kubetypes/types.go b/kube/kubetypes/types.go new file mode 100644 index 000000000..44b01fe1a --- /dev/null +++ b/kube/kubetypes/types.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubetypes + +import "fmt" + +const ( + // Hostinfo App values for the Tailscale Kubernetes Operator components. + AppOperator = "k8s-operator" + AppInProcessAPIServerProxy = "k8s-operator-proxy" + AppIngressProxy = "k8s-operator-ingress-proxy" + AppIngressResource = "k8s-operator-ingress-resource" + AppEgressProxy = "k8s-operator-egress-proxy" + AppConnector = "k8s-operator-connector-resource" + AppProxyGroupEgress = "k8s-operator-proxygroup-egress" + AppProxyGroupIngress = "k8s-operator-proxygroup-ingress" + AppProxyGroupKubeAPIServer = "k8s-operator-proxygroup-kube-apiserver" + + // Clientmetrics for Tailscale Kubernetes Operator components + MetricIngressProxyCount = "k8s_ingress_proxies" // L3 + MetricIngressResourceCount = "k8s_ingress_resources" // L7 + MetricIngressPGResourceCount = "k8s_ingress_pg_resources" // L7 on ProxyGroup + MetricServicePGResourceCount = "k8s_service_pg_resources" // L3 on ProxyGroup + MetricEgressProxyCount = "k8s_egress_proxies" + MetricConnectorResourceCount = "k8s_connector_resources" + MetricConnectorWithSubnetRouterCount = "k8s_connector_subnetrouter_resources" + MetricConnectorWithExitNodeCount = "k8s_connector_exitnode_resources" + MetricConnectorWithAppConnectorCount = "k8s_connector_appconnector_resources" + MetricNameserverCount = "k8s_nameserver_resources" + MetricRecorderCount = "k8s_recorder_resources" + MetricEgressServiceCount = "k8s_egress_service_resources" + MetricProxyGroupEgressCount = "k8s_proxygroup_egress_resources" + MetricProxyGroupIngressCount = "k8s_proxygroup_ingress_resources" + MetricProxyGroupAPIServerCount = "k8s_proxygroup_kube_apiserver_resources" + + // Keys that containerboot writes to state file that can be used to determine its state. + // fields set in Tailscale state Secret. These are mostly used by the Tailscale Kubernetes operator to determine + // the state of this tailscale device. + KeyDeviceID string = "device_id" // node stable ID of the device + KeyDeviceFQDN string = "device_fqdn" // device's tailnet hostname + KeyDeviceIPs string = "device_ips" // device's tailnet IPs + KeyPodUID string = "pod_uid" // Pod UID + // KeyCapVer contains Tailscale capability version of this proxy instance. + KeyCapVer string = "tailscale_capver" + // KeyHTTPSEndpoint is a name of a field that can be set to the value of any HTTPS endpoint currently exposed by + // this device to the tailnet. This is used by the Kubernetes operator Ingress proxy to communicate to the operator + // that cluster workloads behind the Ingress can now be accessed via the given DNS name over HTTPS. + KeyHTTPSEndpoint string = "https_endpoint" + ValueNoHTTPS string = "no-https" + + // Pod's IPv4 address header key as returned by containerboot health check endpoint. + PodIPv4Header string = "Pod-IPv4" + + EgessServicesPreshutdownEP = "/internal-egress-services-preshutdown" + + LabelManaged = "tailscale.com/managed" + LabelSecretType = "tailscale.com/secret-type" // "config", "state" "certs" + + LabelSecretTypeConfig = "config" + LabelSecretTypeState = "state" + LabelSecretTypeCerts = "certs" + + KubeAPIServerConfigFile = "config.hujson" + APIServerProxyModeAuth APIServerProxyMode = "auth" + APIServerProxyModeNoAuth APIServerProxyMode = "noauth" +) + +// APIServerProxyMode specifies whether the API server proxy will add +// impersonation headers to requests based on the caller's Tailscale identity. +// May be "auth" or "noauth". +type APIServerProxyMode string + +func (a *APIServerProxyMode) UnmarshalJSON(data []byte) error { + switch string(data) { + case `"auth"`: + *a = APIServerProxyModeAuth + case `"noauth"`: + *a = APIServerProxyModeNoAuth + default: + return fmt.Errorf("unknown APIServerProxyMode %q", data) + } + return nil +} diff --git a/kube/kubetypes/types_test.go b/kube/kubetypes/types_test.go new file mode 100644 index 000000000..ea1846b32 --- /dev/null +++ b/kube/kubetypes/types_test.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubetypes + +import ( + "encoding/json" + "testing" +) + +func TestUnmarshalAPIServerProxyMode(t *testing.T) { + tests := []struct { + data string + expected APIServerProxyMode + }{ + {data: `{"mode":"auth"}`, expected: APIServerProxyModeAuth}, + {data: `{"mode":"noauth"}`, expected: APIServerProxyModeNoAuth}, + {data: `{"mode":""}`, expected: ""}, + {data: `{"mode":"Auth"}`, expected: ""}, + {data: `{"mode":"unknown"}`, expected: ""}, + } + + for _, tc := range tests { + var s struct { + Mode *APIServerProxyMode `json:",omitempty"` + } + err := json.Unmarshal([]byte(tc.data), &s) + if tc.expected == "" { + if err == nil { + t.Errorf("expected error for %q, got none", tc.data) + } + continue + } + if err != nil { + t.Errorf("unexpected error for %q: %v", tc.data, err) + continue + } + if *s.Mode != tc.expected { + t.Errorf("for %q expected %q, got %q", tc.data, tc.expected, *s.Mode) + } + } +} diff --git a/kube/localclient/fake-client.go b/kube/localclient/fake-client.go new file mode 100644 index 000000000..7f0a08316 --- /dev/null +++ b/kube/localclient/fake-client.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package localclient + +import ( + "context" + "fmt" + + "tailscale.com/ipn" +) + +type FakeLocalClient struct { + FakeIPNBusWatcher +} + +func (f *FakeLocalClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) { + return &f.FakeIPNBusWatcher, nil +} + +func (f *FakeLocalClient) CertPair(ctx context.Context, domain string) ([]byte, []byte, error) { + return nil, nil, fmt.Errorf("CertPair not implemented") +} + +type FakeIPNBusWatcher struct { + NotifyChan chan ipn.Notify +} + +func (f *FakeIPNBusWatcher) Close() error { + return nil +} + +func (f *FakeIPNBusWatcher) Next() (ipn.Notify, error) { + return <-f.NotifyChan, nil +} diff --git a/kube/localclient/local-client.go b/kube/localclient/local-client.go new file mode 100644 index 000000000..550b3ae74 --- /dev/null +++ b/kube/localclient/local-client.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package localclient provides an interface for all the local.Client methods +// kube needs to use, so that we can easily mock it in tests. +package localclient + +import ( + "context" + "io" + + "tailscale.com/client/local" + "tailscale.com/ipn" +) + +// LocalClient is roughly a subset of the local.Client struct's methods, used +// for easier testing. +type LocalClient interface { + WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) + CertIssuer +} + +// IPNBusWatcher is local.IPNBusWatcher's methods restated in an interface to +// allow for easier mocking in tests. +type IPNBusWatcher interface { + io.Closer + Next() (ipn.Notify, error) +} + +type CertIssuer interface { + CertPair(context.Context, string) ([]byte, []byte, error) +} + +// New returns a LocalClient that wraps the provided local.Client. +func New(lc *local.Client) LocalClient { + return &localClient{lc: lc} +} + +type localClient struct { + lc *local.Client +} + +func (lc *localClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) { + return lc.lc.WatchIPNBus(ctx, mask) +} + +func (lc *localClient) CertPair(ctx context.Context, domain string) ([]byte, []byte, error) { + return lc.lc.CertPair(ctx, domain) +} diff --git a/kube/metrics/metrics.go b/kube/metrics/metrics.go new file mode 100644 index 000000000..0db683008 --- /dev/null +++ b/kube/metrics/metrics.go @@ -0,0 +1,81 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package metrics contains shared types and underlying methods for serving +// localapi metrics. This is primarily consumed by containerboot and k8s-proxy. +package metrics + +import ( + "fmt" + "io" + "net/http" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" +) + +// metrics is a simple metrics HTTP server, if enabled it forwards requests to +// the tailscaled's LocalAPI usermetrics endpoint at /localapi/v0/usermetrics. +type metrics struct { + debugEndpoint string + lc *local.Client +} + +func proxy(w http.ResponseWriter, r *http.Request, url string, do func(*http.Request) (*http.Response, error)) { + req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to construct request: %s", err), http.StatusInternalServerError) + return + } + req.Header = r.Header.Clone() + + resp, err := do(req) + if err != nil { + http.Error(w, fmt.Sprintf("failed to proxy request: %s", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + for key, val := range resp.Header { + for _, v := range val { + w.Header().Add(key, v) + } + } + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (m *metrics) handleMetrics(w http.ResponseWriter, r *http.Request) { + localAPIURL := "http://" + apitype.LocalAPIHost + "/localapi/v0/usermetrics" + proxy(w, r, localAPIURL, m.lc.DoLocalRequest) +} + +func (m *metrics) handleDebug(w http.ResponseWriter, r *http.Request) { + if m.debugEndpoint == "" { + http.Error(w, "debug endpoint not configured", http.StatusNotFound) + return + } + + debugURL := "http://" + m.debugEndpoint + r.URL.Path + proxy(w, r, debugURL, http.DefaultClient.Do) +} + +// registerMetricsHandlers registers a simple HTTP metrics handler at /metrics, forwarding +// requests to tailscaled's /localapi/v0/usermetrics API. +// +// In 1.78.x and 1.80.x, it also proxies debug paths to tailscaled's debug +// endpoint if configured to ease migration for a breaking change serving user +// metrics instead of debug metrics on the "metrics" port. +func RegisterMetricsHandlers(mux *http.ServeMux, lc *local.Client, debugAddrPort string) { + m := &metrics{ + lc: lc, + debugEndpoint: debugAddrPort, + } + + mux.HandleFunc("GET /metrics", m.handleMetrics) + mux.HandleFunc("/debug/", m.handleDebug) // TODO(tomhjp): Remove for 1.82.0 release. +} diff --git a/kube/services/services.go b/kube/services/services.go new file mode 100644 index 000000000..a9e50975c --- /dev/null +++ b/kube/services/services.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package services manages graceful shutdown of Tailscale Services advertised +// by Kubernetes clients. +package services + +import ( + "context" + "fmt" + "time" + + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/types/logger" +) + +// EnsureServicesNotAdvertised is a function that gets called on containerboot +// or k8s-proxy termination and ensures that any currently advertised Services +// get unadvertised to give clients time to switch to another node before this +// one is shut down. +func EnsureServicesNotAdvertised(ctx context.Context, lc *local.Client, logf logger.Logf) error { + prefs, err := lc.GetPrefs(ctx) + if err != nil { + return fmt.Errorf("error getting prefs: %w", err) + } + if len(prefs.AdvertiseServices) == 0 { + return nil + } + + logf("unadvertising services: %v", prefs.AdvertiseServices) + if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: nil, + }, + }); err != nil { + // EditPrefs only returns an error if it fails _set_ its local prefs. + // If it fails to _persist_ the prefs in state, we don't get an error + // and we continue waiting below, as control will failover as usual. + return fmt.Errorf("error setting prefs AdvertiseServices: %w", err) + } + + // Services use the same (failover XOR regional routing) mechanism that + // HA subnet routers use. Unfortunately we don't yet get a reliable signal + // from control that it's responded to our unadvertisement, so the best we + // can do is wait for 20 seconds, where 15s is the approximate maximum time + // it should take for control to choose a new primary, and 5s is for buffer. + // + // Note: There is no guarantee that clients have been _informed_ of the new + // primary no matter how long we wait. We would need a mechanism to await + // netmap updates for peers to know for sure. + // + // See https://tailscale.com/kb/1115/high-availability for more details. + // TODO(tomhjp): Wait for a netmap update instead of sleeping when control + // supports that. + select { + case <-ctx.Done(): + return nil + case <-time.After(20 * time.Second): + return nil + } +} diff --git a/kube/state/state.go b/kube/state/state.go new file mode 100644 index 000000000..2605f0952 --- /dev/null +++ b/kube/state/state.go @@ -0,0 +1,107 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package state updates state keys for tailnet client devices managed by the +// operator. These keys are used to signal readiness, metadata, and current +// configuration state to the operator. Client packages deployed by the operator +// include containerboot, tsrecorder, and k8s-proxy, but currently containerboot +// has its own implementation to manage the same keys. +package state + +import ( + "context" + "encoding/json" + "fmt" + + "tailscale.com/ipn" + "tailscale.com/kube/kubetypes" + klc "tailscale.com/kube/localclient" + "tailscale.com/tailcfg" + "tailscale.com/util/deephash" +) + +const ( + keyPodUID = ipn.StateKey(kubetypes.KeyPodUID) + keyCapVer = ipn.StateKey(kubetypes.KeyCapVer) + keyDeviceID = ipn.StateKey(kubetypes.KeyDeviceID) + keyDeviceIPs = ipn.StateKey(kubetypes.KeyDeviceIPs) + keyDeviceFQDN = ipn.StateKey(kubetypes.KeyDeviceFQDN) +) + +// SetInitialKeys sets Pod UID and cap ver and clears tailnet device state +// keys to help stop the operator using stale tailnet device state. +func SetInitialKeys(store ipn.StateStore, podUID string) error { + // Clear device state keys first so the operator knows if the pod UID + // matches, the other values are definitely not stale. + for _, key := range []ipn.StateKey{keyDeviceID, keyDeviceFQDN, keyDeviceIPs} { + if _, err := store.ReadState(key); err == nil { + if err := store.WriteState(key, nil); err != nil { + return fmt.Errorf("error writing %q to state store: %w", key, err) + } + } + } + + if err := store.WriteState(keyPodUID, []byte(podUID)); err != nil { + return fmt.Errorf("error writing pod UID to state store: %w", err) + } + if err := store.WriteState(keyCapVer, fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion)); err != nil { + return fmt.Errorf("error writing capability version to state store: %w", err) + } + + return nil +} + +// KeepKeysUpdated sets state store keys consistent with containerboot to +// signal proxy readiness to the operator. It runs until its context is +// cancelled or it hits an error. The passed in next function is expected to be +// from a local.IPNBusWatcher that is at least subscribed to +// ipn.NotifyInitialNetMap. +func KeepKeysUpdated(ctx context.Context, store ipn.StateStore, lc klc.LocalClient) error { + w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) + if err != nil { + return fmt.Errorf("error watching IPN bus: %w", err) + } + defer w.Close() + + var currentDeviceID, currentDeviceIPs, currentDeviceFQDN deephash.Sum + for { + n, err := w.Next() // Blocks on a streaming LocalAPI HTTP call. + if err != nil { + if err == ctx.Err() { + return nil + } + return err + } + if n.NetMap == nil { + continue + } + + if deviceID := n.NetMap.SelfNode.StableID(); deephash.Update(¤tDeviceID, &deviceID) { + if err := store.WriteState(keyDeviceID, []byte(deviceID)); err != nil { + return fmt.Errorf("failed to store device ID in state: %w", err) + } + } + + if fqdn := n.NetMap.SelfNode.Name(); deephash.Update(¤tDeviceFQDN, &fqdn) { + if err := store.WriteState(keyDeviceFQDN, []byte(fqdn)); err != nil { + return fmt.Errorf("failed to store device FQDN in state: %w", err) + } + } + + if addrs := n.NetMap.SelfNode.Addresses(); deephash.Update(¤tDeviceIPs, &addrs) { + var deviceIPs []string + for _, addr := range addrs.AsSlice() { + deviceIPs = append(deviceIPs, addr.Addr().String()) + } + deviceIPsValue, err := json.Marshal(deviceIPs) + if err != nil { + return err + } + if err := store.WriteState(keyDeviceIPs, deviceIPsValue); err != nil { + return fmt.Errorf("failed to store device IPs in state: %w", err) + } + } + } +} diff --git a/kube/state/state_test.go b/kube/state/state_test.go new file mode 100644 index 000000000..8701aa1b7 --- /dev/null +++ b/kube/state/state_test.go @@ -0,0 +1,196 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package state + +import ( + "bytes" + "fmt" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/ipn/store" + klc "tailscale.com/kube/localclient" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" +) + +func TestSetInitialStateKeys(t *testing.T) { + var ( + podUID = []byte("test-pod-uid") + expectedCapVer = fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion) + ) + for name, tc := range map[string]struct { + initial map[ipn.StateKey][]byte + expected map[ipn.StateKey][]byte + }{ + "empty_initial": { + initial: map[ipn.StateKey][]byte{}, + expected: map[ipn.StateKey][]byte{ + keyPodUID: podUID, + keyCapVer: expectedCapVer, + }, + }, + "existing_pod_uid_and_capver": { + initial: map[ipn.StateKey][]byte{ + keyPodUID: podUID, + keyCapVer: expectedCapVer, + }, + expected: map[ipn.StateKey][]byte{ + keyPodUID: podUID, + keyCapVer: expectedCapVer, + }, + }, + "all_keys_preexisting": { + initial: map[ipn.StateKey][]byte{ + keyPodUID: podUID, + keyCapVer: expectedCapVer, + keyDeviceID: []byte("existing-device-id"), + keyDeviceFQDN: []byte("existing-device-fqdn"), + keyDeviceIPs: []byte(`["1.2.3.4"]`), + }, + expected: map[ipn.StateKey][]byte{ + keyPodUID: podUID, + keyCapVer: expectedCapVer, + keyDeviceID: nil, + keyDeviceFQDN: nil, + keyDeviceIPs: nil, + }, + }, + } { + t.Run(name, func(t *testing.T) { + store, err := store.New(logger.Discard, "mem:") + if err != nil { + t.Fatalf("error creating in-memory store: %v", err) + } + + for key, value := range tc.initial { + if err := store.WriteState(key, value); err != nil { + t.Fatalf("error writing initial state key %q: %v", key, err) + } + } + + if err := SetInitialKeys(store, string(podUID)); err != nil { + t.Fatalf("setInitialStateKeys failed: %v", err) + } + + actual := make(map[ipn.StateKey][]byte) + for expectedKey, expectedValue := range tc.expected { + actualValue, err := store.ReadState(expectedKey) + if err != nil { + t.Errorf("error reading state key %q: %v", expectedKey, err) + continue + } + + actual[expectedKey] = actualValue + if !bytes.Equal(actualValue, expectedValue) { + t.Errorf("state key %q mismatch: expected %q, got %q", expectedKey, expectedValue, actualValue) + } + } + if diff := cmp.Diff(actual, tc.expected); diff != "" { + t.Errorf("state keys mismatch (-got +want):\n%s", diff) + } + }) + } +} + +func TestKeepStateKeysUpdated(t *testing.T) { + store := fakeStore{ + writeChan: make(chan string), + } + + errs := make(chan error) + notifyChan := make(chan ipn.Notify) + lc := &klc.FakeLocalClient{ + FakeIPNBusWatcher: klc.FakeIPNBusWatcher{ + NotifyChan: notifyChan, + }, + } + + go func() { + err := KeepKeysUpdated(t.Context(), store, lc) + if err != nil { + errs <- fmt.Errorf("keepStateKeysUpdated returned with error: %w", err) + } + }() + + for _, tc := range []struct { + name string + notify ipn.Notify + expected []string + }{ + { + name: "initial_not_authed", + notify: ipn.Notify{}, + expected: nil, + }, + { + name: "authed", + notify: ipn.Notify{ + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "test-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")}, + }).View(), + }, + }, + expected: []string{ + fmt.Sprintf("%s=%s", keyDeviceID, "TESTCTRL00000001"), + fmt.Sprintf("%s=%s", keyDeviceFQDN, "test-node.test.ts.net"), + fmt.Sprintf("%s=%s", keyDeviceIPs, `["100.64.0.1","fd7a:115c:a1e0:ab12:4843:cd96:0:1"]`), + }, + }, + { + name: "updated_fields", + notify: ipn.Notify{ + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "updated.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.250/32")}, + }).View(), + }, + }, + expected: []string{ + fmt.Sprintf("%s=%s", keyDeviceFQDN, "updated.test.ts.net"), + fmt.Sprintf("%s=%s", keyDeviceIPs, `["100.64.0.250"]`), + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + notifyChan <- tc.notify + for _, expected := range tc.expected { + select { + case got := <-store.writeChan: + if got != expected { + t.Errorf("expected %q, got %q", expected, got) + } + case err := <-errs: + t.Fatalf("unexpected error: %v", err) + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for expected write %q", expected) + } + } + }) + } +} + +type fakeStore struct { + writeChan chan string +} + +func (f fakeStore) ReadState(key ipn.StateKey) ([]byte, error) { + return nil, fmt.Errorf("ReadState not implemented") +} + +func (f fakeStore) WriteState(key ipn.StateKey, value []byte) error { + f.writeChan <- fmt.Sprintf("%s=%s", key, value) + return nil +} diff --git a/license_test.go b/license_test.go new file mode 100644 index 000000000..9b62c48ed --- /dev/null +++ b/license_test.go @@ -0,0 +1,117 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscaleroot + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "tailscale.com/util/set" +) + +func normalizeLineEndings(b []byte) []byte { + return bytes.ReplaceAll(b, []byte("\r\n"), []byte("\n")) +} + +// TestLicenseHeaders checks that all Go files in the tree +// directory tree have a correct-looking Tailscale license header. +func TestLicenseHeaders(t *testing.T) { + want := normalizeLineEndings([]byte(strings.TrimLeft(` +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +`, "\n"))) + + exceptions := set.Of( + // Subprocess test harness code + "util/winutil/testdata/testrestartableprocesses/main.go", + "util/winutil/subprocess_windows_test.go", + + // WireGuard copyright + "cmd/tailscale/cli/authenticode_windows.go", + "wgengine/router/osrouter/ifconfig_windows.go", + + // noiseexplorer.com copyright + "control/controlbase/noiseexplorer_test.go", + + // Generated eBPF management code + "derp/xdp/bpf_bpfeb.go", + "derp/xdp/bpf_bpfel.go", + + // Generated kube deepcopy funcs file starts with a Go build tag + an empty line + "k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go", + ) + + err := filepath.Walk(".", func(path string, fi os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("path %s: %v", path, err) + } + if exceptions.Contains(filepath.ToSlash(path)) { + return nil + } + base := filepath.Base(path) + switch base { + case ".git", "node_modules", "tempfork": + return filepath.SkipDir + } + switch base { + case "zsyscall_windows.go": + // Generated code. + return nil + } + + if strings.HasSuffix(base, ".config.ts") { + return nil + } + if strings.HasSuffix(base, "_string.go") { + // Generated file from go:generate stringer + return nil + } + + ext := filepath.Ext(base) + switch ext { + default: + return nil + case ".go", ".ts", ".tsx": + } + + buf := make([]byte, 512) + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + if n, err := io.ReadAtLeast(f, buf, 512); err != nil && err != io.ErrUnexpectedEOF { + return err + } else { + buf = buf[:n] + } + + buf = normalizeLineEndings(buf) + + bufNoTrunc := buf + if i := bytes.Index(buf, []byte("\npackage ")); i != -1 { + buf = buf[:i] + } + + if bytes.Contains(buf, want) { + return nil + } + + if bytes.Contains(bufNoTrunc, []byte("BSD-3-Clause\npackage ")) { + t.Errorf("file %s has license header as a package doc; add a blank line before the package line", path) + return nil + } + + t.Errorf("file %s is missing Tailscale copyright header:\n\n%s", path, want) + return nil + }) + if err != nil { + t.Fatalf("Walk: %v", err) + } +} diff --git a/licenses/README.md b/licenses/README.md new file mode 100644 index 000000000..46fe8b77f --- /dev/null +++ b/licenses/README.md @@ -0,0 +1,35 @@ +# Licenses + +This directory contains a list of dependencies, and their licenses, that are included in the Tailscale clients. +These lists are generated using the [go-licenses] tool to analyze all Go packages in the Tailscale binaries, +as well as a set of custom output templates that includes any additional non-Go dependencies. +For example, the clients for macOS and iOS include some additional Swift libraries. + +These lists are updated roughly every week, so it is possible to see the dependencies in a given release by looking at the release tag. +For example, the dependences for the 1.80.0 release of the macOS client can be seen at +. + +[go-licenses]: https://github.com/google/go-licenses + +## Other formats + +The go-licenses tool can output other formats like CSV, but that wouldn't include the non-Go dependencies. +We can generate a CSV file if that's really needed by running a regex over the markdown files: + +```sh +cat apple.md | grep "^ -" | sed -E "s/- \[(.*)\]\(.*?\) \(\[(.*)\]\((.*)\)\)/\1,\2,\3/" +``` + +## Reviewer instructions + +The majority of changes in this directory are from updating dependency versions. +In that case, only the URL for the license file will change to reflect the new version. +Occasionally, a dependency is added or removed, or the import path is changed. + +New dependencies require the closest review to ensure the license is acceptable. +Because we generate the license reports **after** dependencies are changed, +the new dependency would have already gone through one review when it was initially added. +This is just a secondary review to double-check the license. If in doubt, ask legal. + +Always do a normal GitHub code review on the license PR with a brief summary of what changed. +For example, see #13936 or #14064. Then approve and merge the PR. diff --git a/licenses/android.md b/licenses/android.md index ef53117e8..f578c17cb 100644 --- a/licenses/android.md +++ b/licenses/android.md @@ -9,79 +9,41 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.26.5/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.16.16/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.14.11/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.2.10/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.5.10/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.7.2/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.10.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.10.10/service/internal/presigned-url/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.44.7/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.18.7/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.21.7/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.26.7/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.19.0/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.19.0/internal/sync/singleflight/LICENSE)) - - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) - - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) - - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.11.1/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) - - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.18.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/ebf49471dced/LICENSE)) - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) + - [github.com/google/go-tpm](https://pkg.go.dev/github.com/google/go-tpm) ([Apache-2.0](https://github.com/google/go-tpm/blob/v0.9.4/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) - - [github.com/mdlayher/sdnotify](https://pkg.go.dev/github.com/mdlayher/sdnotify) ([MIT](https://github.com/mdlayher/sdnotify/blob/v1.0.0/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) - - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.21/LICENSE)) - - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - [github.com/tailscale/goupnp](https://pkg.go.dev/github.com/tailscale/goupnp) ([BSD-2-Clause](https://github.com/tailscale/goupnp/blob/c64d0f06ea05/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) + - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) - [github.com/tailscale/tailscale-android/libtailscale](https://pkg.go.dev/github.com/tailscale/tailscale-android/libtailscale) ([BSD-3-Clause](https://github.com/tailscale/tailscale-android/blob/HEAD/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/1d0488a3d7da/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE)) - - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) + - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - - [go4.org/intern](https://pkg.go.dev/go4.org/intern) ([BSD-3-Clause](https://github.com/go4org/intern/blob/ae77deb06f29/LICENSE)) - - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/4f986261bf13/LICENSE)) + - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [go4.org/unsafe/assume-no-moving-gc](https://pkg.go.dev/go4.org/unsafe/assume-no-moving-gc) ([BSD-3-Clause](https://github.com/go4org/unsafe-assume-no-moving-gc/blob/e7c30c78aeb2/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.26.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.38.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/939b2ce7:LICENSE)) - [golang.org/x/mobile](https://pkg.go.dev/golang.org/x/mobile) ([BSD-3-Clause](https://cs.opensource.google/go/x/mobile/+/81131f64:LICENSE)) - - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.20.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.28.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.23.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.23.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.17.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) - - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.24.0:LICENSE)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) - - [inet.af/netaddr](https://pkg.go.dev/inet.af/netaddr) ([BSD-3-Clause](Unknown)) + - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.24.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.40.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.14.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.33.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.32.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.25.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.11.0:LICENSE)) + - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.33.0:LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/apple.md b/licenses/apple.md index 4cb100c62..2a795ddbb 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -12,77 +12,71 @@ See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.27.28/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.28/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.12/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.16/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.1/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.11.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.11.18/service/internal/presigned-url/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.29.5/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.58/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.27/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.31/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.31/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.2/internal/ini/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.2/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.12.12/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.22.5/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.26.5/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.30.4/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.20.4/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.20.4/internal/sync/singleflight/LICENSE)) - - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) - - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.24.14/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.28.13/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.33.13/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) + - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.7.1/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.11.1/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.18.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/cc2cfa0554c3/LICENSE)) - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) - - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) + - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/2c02b8208cf8/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) + - [github.com/illarion/gonotify/v3](https://pkg.go.dev/github.com/illarion/gonotify/v3) ([MIT](https://github.com/illarion/gonotify/blob/v3.0.2/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/15c9b8791914/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.8/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.8/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.8/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.0/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.0/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.0/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/sdnotify](https://pkg.go.dev/github.com/mdlayher/sdnotify) ([MIT](https://github.com/mdlayher/sdnotify/blob/v1.0.0/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.21/LICENSE)) + - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.22/LICENSE)) - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - [github.com/tailscale/goupnp](https://pkg.go.dev/github.com/tailscale/goupnp) ([BSD-2-Clause](https://github.com/tailscale/goupnp/blob/c64d0f06ea05/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/1d0488a3d7da/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE)) - - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) + - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) + - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.5/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.43.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/df929982:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.46.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.17.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.37.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.36.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.30.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) ## Additional Dependencies diff --git a/licenses/tailscale.md b/licenses/tailscale.md index 544aa91ce..163a76d40 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -14,103 +14,93 @@ Some packages may only be included on certain architectures or operating systems - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) + - [fyne.io/systray](https://pkg.go.dev/fyne.io/systray) ([Apache-2.0](https://github.com/fyne-io/systray/blob/4856ac3adc3c/LICENSE)) + - [github.com/Kodeworks/golang-image-ico](https://pkg.go.dev/github.com/Kodeworks/golang-image-ico) ([BSD-3-Clause](https://github.com/Kodeworks/golang-image-ico/blob/73f0f4cfade9/LICENSE)) - [github.com/akutz/memconn](https://pkg.go.dev/github.com/akutz/memconn) ([Apache-2.0](https://github.com/akutz/memconn/blob/v0.1.0/LICENSE)) - [github.com/alexbrainman/sspi](https://pkg.go.dev/github.com/alexbrainman/sspi) ([BSD-3-Clause](https://github.com/alexbrainman/sspi/blob/1a75b4708caa/LICENSE)) - [github.com/anmitsu/go-shlex](https://pkg.go.dev/github.com/anmitsu/go-shlex) ([MIT](https://github.com/anmitsu/go-shlex/blob/38f4b401e2be/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.26.5/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.16.16/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.14.11/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.2.10/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.5.10/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.7.2/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.10.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.10.10/service/internal/presigned-url/LICENSE.txt)) + - [github.com/atotto/clipboard](https://pkg.go.dev/github.com/atotto/clipboard) ([BSD-3-Clause](https://github.com/atotto/clipboard/blob/v0.1.4/LICENSE)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.29.5/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.58/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.27/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.31/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.31/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.2/internal/ini/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.2/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.12.12/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.44.7/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.18.7/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.21.7/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.26.7/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.19.0/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.19.0/internal/sync/singleflight/LICENSE)) - - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.24.14/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.28.13/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.33.13/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) - - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) + - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.7.1/LICENSE)) - [github.com/creack/pty](https://pkg.go.dev/github.com/creack/pty) ([MIT](https://github.com/creack/pty/blob/v1.1.23/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) ([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/a09d6be7affa/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.11.1/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) + - [github.com/fogleman/gg](https://pkg.go.dev/github.com/fogleman/gg) ([MIT](https://github.com/fogleman/gg/blob/v1.3.0/LICENSE.md)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.18.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/ebf49471dced/LICENSE)) - [github.com/go-ole/go-ole](https://pkg.go.dev/github.com/go-ole/go-ole) ([MIT](https://github.com/go-ole/go-ole/blob/v1.3.0/LICENSE)) - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) + - [github.com/golang/freetype/raster](https://pkg.go.dev/github.com/golang/freetype/raster) ([Unknown](Unknown)) + - [github.com/golang/freetype/truetype](https://pkg.go.dev/github.com/golang/freetype/truetype) ([Unknown](Unknown)) - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - - [github.com/gorilla/csrf](https://pkg.go.dev/github.com/gorilla/csrf) ([BSD-3-Clause](https://github.com/gorilla/csrf/blob/v1.7.2/LICENSE)) - - [github.com/gorilla/securecookie](https://pkg.go.dev/github.com/gorilla/securecookie) ([BSD-3-Clause](https://github.com/gorilla/securecookie/blob/v1.1.2/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/kballard/go-shellquote](https://pkg.go.dev/github.com/kballard/go-shellquote) ([MIT](https://github.com/kballard/go-shellquote/blob/95032a82bc51/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/kr/fs](https://pkg.go.dev/github.com/kr/fs) ([BSD-3-Clause](https://github.com/kr/fs/blob/v0.1.0/LICENSE)) - [github.com/mattn/go-colorable](https://pkg.go.dev/github.com/mattn/go-colorable) ([MIT](https://github.com/mattn/go-colorable/blob/v0.1.13/LICENSE)) - [github.com/mattn/go-isatty](https://pkg.go.dev/github.com/mattn/go-isatty) ([MIT](https://github.com/mattn/go-isatty/blob/v0.0.20/LICENSE)) - - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) - - [github.com/mdlayher/sdnotify](https://pkg.go.dev/github.com/mdlayher/sdnotify) ([MIT](https://github.com/mdlayher/sdnotify/blob/v1.0.0/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - [github.com/peterbourgon/ff/v3](https://pkg.go.dev/github.com/peterbourgon/ff/v3) ([Apache-2.0](https://github.com/peterbourgon/ff/blob/v3.4.0/LICENSE)) - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.21/LICENSE)) + - [github.com/pires/go-proxyproto](https://pkg.go.dev/github.com/pires/go-proxyproto) ([Apache-2.0](https://github.com/pires/go-proxyproto/blob/v0.8.1/LICENSE)) - [github.com/pkg/sftp](https://pkg.go.dev/github.com/pkg/sftp) ([BSD-2-Clause](https://github.com/pkg/sftp/blob/v1.13.6/LICENSE)) - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/d3fa0460f47e/LICENSE.md)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/5db17b287bf1/LICENSE)) + - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/d4cd19a26976/LICENSE)) - [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/1d0488a3d7da/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md)) - - [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.12.0/LICENSE)) - - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE)) - - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) + - [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.14.0/LICENSE)) + - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/4f986261bf13/LICENSE)) + - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.16.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.38.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/939b2ce7:LICENSE)) + - [golang.org/x/image](https://pkg.go.dev/golang.org/x/image) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.40.0:LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.30.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.14.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.33.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.32.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.25.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.11.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) - - [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.30.3/LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) + - [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.32.0/LICENSE)) - [sigs.k8s.io/yaml](https://pkg.go.dev/sigs.k8s.io/yaml) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/LICENSE)) - [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE)) - - [software.sslmate.com/src/go-pkcs12](https://pkg.go.dev/software.sslmate.com/src/go-pkcs12) ([BSD-3-Clause](https://github.com/SSLMate/go-pkcs12/blob/v0.4.0/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) - [tailscale.com/tempfork/gliderlabs/ssh](https://pkg.go.dev/tailscale.com/tempfork/gliderlabs/ssh) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/gliderlabs/ssh/LICENSE)) - [tailscale.com/tempfork/spf13/cobra](https://pkg.go.dev/tailscale.com/tempfork/spf13/cobra) ([Apache-2.0](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/spf13/cobra/LICENSE.txt)) diff --git a/licenses/windows.md b/licenses/windows.md index e7f7f6f13..06a5712ce 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -10,73 +10,60 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) - - [github.com/alexbrainman/sspi](https://pkg.go.dev/github.com/alexbrainman/sspi) ([BSD-3-Clause](https://github.com/alexbrainman/sspi/blob/1a75b4708caa/LICENSE)) - [github.com/apenwarr/fixconsole](https://pkg.go.dev/github.com/apenwarr/fixconsole) ([Apache-2.0](https://github.com/apenwarr/fixconsole/blob/5a9f6489cc29/LICENSE)) - [github.com/apenwarr/w32](https://pkg.go.dev/github.com/apenwarr/w32) ([BSD-3-Clause](https://github.com/apenwarr/w32/blob/aa00fece76ab/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.27.28/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.28/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.12/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.16/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.1/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.11.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.11.18/service/internal/presigned-url/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.22.5/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.26.5/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.30.4/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.20.4/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.20.4/internal/sync/singleflight/LICENSE)) - - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) + - [github.com/beorn7/perks/quantile](https://pkg.go.dev/github.com/beorn7/perks/quantile) ([MIT](https://github.com/beorn7/perks/blob/v1.0.1/LICENSE)) + - [github.com/cespare/xxhash/v2](https://pkg.go.dev/github.com/cespare/xxhash/v2) ([MIT](https://github.com/cespare/xxhash/blob/v2.3.0/LICENSE.txt)) + - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) + - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.7.1/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) ([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/b75a8a7d7eb0/LICENSE)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) - - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/cc2cfa0554c3/LICENSE)) + - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/2c02b8208cf8/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) + - [github.com/google/go-cmp/cmp](https://pkg.go.dev/github.com/google/go-cmp/cmp) ([BSD-3-Clause](https://github.com/google/go-cmp/blob/v0.7.0/LICENSE)) - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - [github.com/gregjones/httpcache](https://pkg.go.dev/github.com/gregjones/httpcache) ([MIT](https://github.com/gregjones/httpcache/blob/901d90724c79/LICENSE.txt)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.8/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.8/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.8/zstd/internal/xxhash/LICENSE.txt)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.0/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.0/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.0/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) + - [github.com/munnerz/goautoneg](https://pkg.go.dev/github.com/munnerz/goautoneg) ([BSD-3-Clause](https://github.com/munnerz/goautoneg/blob/a7dc8b61c822/LICENSE)) - [github.com/nfnt/resize](https://pkg.go.dev/github.com/nfnt/resize) ([ISC](https://github.com/nfnt/resize/blob/83c6a9932646/LICENSE)) - [github.com/peterbourgon/diskv](https://pkg.go.dev/github.com/peterbourgon/diskv) ([MIT](https://github.com/peterbourgon/diskv/blob/v2.0.1/LICENSE)) + - [github.com/prometheus/client_golang/prometheus](https://pkg.go.dev/github.com/prometheus/client_golang/prometheus) ([Apache-2.0](https://github.com/prometheus/client_golang/blob/v1.23.2/LICENSE)) + - [github.com/prometheus/client_model/go](https://pkg.go.dev/github.com/prometheus/client_model/go) ([Apache-2.0](https://github.com/prometheus/client_model/blob/v0.6.2/LICENSE)) + - [github.com/prometheus/common](https://pkg.go.dev/github.com/prometheus/common) ([Apache-2.0](https://github.com/prometheus/common/blob/v0.66.1/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) - - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/52804fd3056a/LICENSE)) - - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/6580b55d49ca/LICENSE)) + - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/992244df8c5a/LICENSE)) + - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/963e260a8227/LICENSE)) + - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/f4da2b8ee071/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/tc-hib/winres](https://pkg.go.dev/github.com/tc-hib/winres) ([0BSD](https://github.com/tc-hib/winres/blob/v0.2.1/LICENSE)) - - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) + - [go.yaml.in/yaml/v2](https://pkg.go.dev/go.yaml.in/yaml/v2) ([Apache-2.0](https://github.com/yaml/go-yaml/blob/v2.4.2/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE)) - - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.43.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/df929982:LICENSE)) + - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) + - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.28.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.46.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.17.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.37.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.36.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) + - [google.golang.org/protobuf](https://pkg.go.dev/google.golang.org/protobuf) ([BSD-3-Clause](https://github.com/protocolbuffers/protobuf-go/blob/v1.36.8/LICENSE)) - [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE)) + - [gopkg.in/yaml.v3](https://pkg.go.dev/gopkg.in/yaml.v3) ([MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) ## Additional Dependencies diff --git a/log/sockstatlog/logger.go b/log/sockstatlog/logger.go index 3cc27c22d..8ddfabb86 100644 --- a/log/sockstatlog/logger.go +++ b/log/sockstatlog/logger.go @@ -17,6 +17,7 @@ import ( "sync/atomic" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/logpolicy" "tailscale.com/logtail" @@ -25,6 +26,7 @@ import ( "tailscale.com/net/sockstats" "tailscale.com/types/logger" "tailscale.com/types/logid" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" ) @@ -96,8 +98,8 @@ func SockstatLogID(logID logid.PublicID) logid.PrivateID { // // The netMon parameter is optional. It should be specified in environments where // Tailscaled is manipulating the routing table. -func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *netmon.Monitor, health *health.Tracker) (*Logger, error) { - if !sockstats.IsAvailable { +func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (*Logger, error) { + if !sockstats.IsAvailable || !buildfeatures.HasLogTail { return nil, nil } if netMon == nil { @@ -126,6 +128,7 @@ func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *ne PrivateID: SockstatLogID(logID), Collection: "sockstats.log.tailscale.io", Buffer: filch, + Bus: bus, CompressLogs: true, FlushDelayFn: func() time.Duration { // set flush delay to 100 years so it never flushes automatically @@ -143,33 +146,33 @@ func NewLogger(logdir string, logf logger.Logf, logID logid.PublicID, netMon *ne // SetLoggingEnabled enables or disables logging. // When disabled, socket stats are not polled and no new logs are written to disk. // Existing logs can still be fetched via the C2N API. -func (l *Logger) SetLoggingEnabled(v bool) { - old := l.enabled.Load() - if old != v && l.enabled.CompareAndSwap(old, v) { +func (lg *Logger) SetLoggingEnabled(v bool) { + old := lg.enabled.Load() + if old != v && lg.enabled.CompareAndSwap(old, v) { if v { - if l.eventCh == nil { + if lg.eventCh == nil { // eventCh should be large enough for the number of events that will occur within logInterval. // Add an extra second's worth of events to ensure we don't drop any. - l.eventCh = make(chan event, (logInterval+time.Second)/pollInterval) + lg.eventCh = make(chan event, (logInterval+time.Second)/pollInterval) } - l.ctx, l.cancelFn = context.WithCancel(context.Background()) - go l.poll() - go l.logEvents() + lg.ctx, lg.cancelFn = context.WithCancel(context.Background()) + go lg.poll() + go lg.logEvents() } else { - l.cancelFn() + lg.cancelFn() } } } -func (l *Logger) Write(p []byte) (int, error) { - return l.logger.Write(p) +func (lg *Logger) Write(p []byte) (int, error) { + return lg.logger.Write(p) } // poll fetches the current socket stats at the configured time interval, // calculates the delta since the last poll, // and writes any non-zero values to the logger event channel. // This method does not return. -func (l *Logger) poll() { +func (lg *Logger) poll() { // last is the last set of socket stats we saw. var lastStats *sockstats.SockStats var lastTime time.Time @@ -177,7 +180,7 @@ func (l *Logger) poll() { ticker := time.NewTicker(pollInterval) for { select { - case <-l.ctx.Done(): + case <-lg.ctx.Done(): ticker.Stop() return case t := <-ticker.C: @@ -193,7 +196,7 @@ func (l *Logger) poll() { if stats.CurrentInterfaceCellular { e.IsCellularInterface = 1 } - l.eventCh <- e + lg.eventCh <- e } } lastTime = t @@ -204,14 +207,14 @@ func (l *Logger) poll() { // logEvents reads events from the event channel at logInterval and logs them to disk. // This method does not return. -func (l *Logger) logEvents() { - enc := json.NewEncoder(l) +func (lg *Logger) logEvents() { + enc := json.NewEncoder(lg) flush := func() { for { select { - case e := <-l.eventCh: + case e := <-lg.eventCh: if err := enc.Encode(e); err != nil { - l.logf("sockstatlog: error encoding log: %v", err) + lg.logf("sockstatlog: error encoding log: %v", err) } default: return @@ -221,7 +224,7 @@ func (l *Logger) logEvents() { ticker := time.NewTicker(logInterval) for { select { - case <-l.ctx.Done(): + case <-lg.ctx.Done(): ticker.Stop() return case <-ticker.C: @@ -230,29 +233,29 @@ func (l *Logger) logEvents() { } } -func (l *Logger) LogID() string { - if l.logger == nil { +func (lg *Logger) LogID() string { + if lg.logger == nil { return "" } - return l.logger.PrivateID().Public().String() + return lg.logger.PrivateID().Public().String() } // Flush sends pending logs to the log server and flushes them from the local buffer. -func (l *Logger) Flush() { - l.logger.StartFlush() +func (lg *Logger) Flush() { + lg.logger.StartFlush() } -func (l *Logger) Shutdown(ctx context.Context) { - if l.cancelFn != nil { - l.cancelFn() +func (lg *Logger) Shutdown(ctx context.Context) { + if lg.cancelFn != nil { + lg.cancelFn() } - l.filch.Close() - l.logger.Shutdown(ctx) + lg.filch.Close() + lg.logger.Shutdown(ctx) type closeIdler interface { CloseIdleConnections() } - if tr, ok := l.tr.(closeIdler); ok { + if tr, ok := lg.tr.(closeIdler); ok { tr.CloseIdleConnections() } } diff --git a/log/sockstatlog/logger_test.go b/log/sockstatlog/logger_test.go index 31fb17e46..e5c2feb29 100644 --- a/log/sockstatlog/logger_test.go +++ b/log/sockstatlog/logger_test.go @@ -24,7 +24,7 @@ func TestResourceCleanup(t *testing.T) { if err != nil { t.Fatal(err) } - lg, err := NewLogger(td, logger.Discard, id.Public(), nil, nil) + lg, err := NewLogger(td, logger.Discard, id.Public(), nil, nil, nil) if err != nil { t.Fatal(err) } diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 0d2af77f2..f7491783a 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -31,6 +31,8 @@ import ( "golang.org/x/term" "tailscale.com/atomicfile" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/log/filelogger" @@ -41,16 +43,18 @@ import ( "tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/tlsdial" - "tailscale.com/net/tshttpproxy" "tailscale.com/paths" "tailscale.com/safesocket" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/must" "tailscale.com/util/racebuild" - "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/util/testenv" "tailscale.com/version" "tailscale.com/version/distro" @@ -64,7 +68,7 @@ var getLogTargetOnce struct { func getLogTarget() string { getLogTargetOnce.Do(func() { envTarget, _ := os.LookupEnv("TS_LOG_TARGET") - getLogTargetOnce.v, _ = syspolicy.GetString(syspolicy.LogTarget, envTarget) + getLogTargetOnce.v, _ = policyclient.Get().GetString(pkey.LogTarget, envTarget) }) return getLogTargetOnce.v @@ -104,6 +108,7 @@ type Policy struct { // Logtail is the logger. Logtail *logtail.Logger // PublicID is the logger's instance identifier. + // It may be the zero value if logging is not in use. PublicID logid.PublicID // Logf is where to write informational messages about this Logger. Logf logger.Logf @@ -188,8 +193,8 @@ type logWriter struct { logger *log.Logger } -func (l logWriter) Write(buf []byte) (int, error) { - l.logger.Printf("%s", buf) +func (lg logWriter) Write(buf []byte) (int, error) { + lg.logger.Printf("%s", buf) return len(buf), nil } @@ -223,6 +228,9 @@ func LogsDir(logf logger.Logf) string { logf("logpolicy: using LocalAppData dir %v", dir) return dir case "linux": + if distro.Get() == distro.JetKVM { + return "/userdata/tailscale/var" + } // STATE_DIRECTORY is set by systemd 240+ but we support older // systems-d. For example, Ubuntu 18.04 (Bionic Beaver) is 237. systemdStateDir := os.Getenv("STATE_DIRECTORY") @@ -230,6 +238,9 @@ func LogsDir(logf logger.Logf) string { logf("logpolicy: using $STATE_DIRECTORY, %q", systemdStateDir) return systemdStateDir } + case "js": + logf("logpolicy: no logs directory in the browser") + return "" } // Default to e.g. /var/lib/tailscale or /var/db/tailscale on Unix. @@ -446,25 +457,69 @@ func tryFixLogStateLocation(dir, cmdname string, logf logger.Logf) { } } -// New returns a new log policy (a logger and its instance ID) for a given -// collection name. -// -// The netMon parameter is optional. It should be specified in environments where -// Tailscaled is manipulating the routing table. -// -// The logf parameter is optional; if non-nil, information logs (e.g. when -// migrating state) are sent to that logger, and global changes to the log -// package are avoided. If nil, logs will be printed using log.Printf. +// Deprecated: Use [Options.New] instead. func New(collection string, netMon *netmon.Monitor, health *health.Tracker, logf logger.Logf) *Policy { - return NewWithConfigPath(collection, "", "", netMon, health, logf) + return Options{ + Collection: collection, + NetMon: netMon, + Health: health, + Logf: logf, + }.New() } -// NewWithConfigPath is identical to New, but uses the specified directory and -// command name. If either is empty, it derives them automatically. -// -// The netMon parameter is optional. It should be specified in environments where -// Tailscaled is manipulating the routing table. -func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, health *health.Tracker, logf logger.Logf) *Policy { +// Options is used to construct a [Policy]. +type Options struct { + // Collection is a required collection to upload logs under. + // Collection is a namespace for the type logs. + // For example, logs for a node use "tailnode.log.tailscale.io". + Collection string + + // Dir is an optional directory to store the log configuration. + // If empty, [LogsDir] is used. + Dir string + + // CmdName is an optional name of the current binary. + // If empty, [version.CmdName] is used. + CmdName string + + // NetMon is an optional parameter for monitoring. + // If non-nil, it's used to do faster interface lookups. + NetMon *netmon.Monitor + + // Health is an optional parameter for health status. + // If non-nil, it's used to construct the default HTTP client. + Health *health.Tracker + + // Bus is an optional parameter for communication on the eventbus. + // If non-nil, it's passed to logtail for use in interface monitoring. + // TODO(cmol): Make this non-optional when it's plumbed in by the clients. + Bus *eventbus.Bus + + // Logf is an optional logger to use. + // If nil, [log.Printf] will be used instead. + Logf logger.Logf + + // HTTPC is an optional client to use upload logs. + // If nil, [TransportOptions.New] is used to construct a new client + // with that particular transport sending logs to the default logs server. + HTTPC *http.Client + + // MaxBufferSize is the maximum size of the log buffer. + // This controls the amount of logs that can be temporarily stored + // before the logs can be successfully upload. + // If zero, a default buffer size is chosen. + MaxBufferSize int + + // MaxUploadSize is the maximum size per upload. + // This should only be set by clients that have been authenticated + // with the logging service as having a higher upload limit. + // If zero, a default upload size is chosen. + MaxUploadSize int +} + +// init initializes the log policy and returns a logtail.Config and the +// Policy. +func (opts Options) init(disableLogging bool) (*logtail.Config, *Policy) { if hostinfo.IsNATLabGuestVM() { // In NATLab Gokrazy instances, tailscaled comes up concurently with // DHCP and the doesn't have DNS for a while. Wait for DHCP first. @@ -492,23 +547,23 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, earlyErrBuf.WriteByte('\n') } - if dir == "" { - dir = LogsDir(earlyLogf) + if opts.Dir == "" { + opts.Dir = LogsDir(earlyLogf) } - if cmdName == "" { - cmdName = version.CmdName() + if opts.CmdName == "" { + opts.CmdName = version.CmdName() } - useStdLogger := logf == nil + useStdLogger := opts.Logf == nil if useStdLogger { - logf = log.Printf + opts.Logf = log.Printf } - tryFixLogStateLocation(dir, cmdName, logf) + tryFixLogStateLocation(opts.Dir, opts.CmdName, opts.Logf) - cfgPath := filepath.Join(dir, fmt.Sprintf("%s.log.conf", cmdName)) + cfgPath := filepath.Join(opts.Dir, fmt.Sprintf("%s.log.conf", opts.CmdName)) if runtime.GOOS == "windows" { - switch cmdName { + switch opts.CmdName { case "tailscaled": // Tailscale 1.14 and before stored state under %LocalAppData% // (usually "C:\WINDOWS\system32\config\systemprofile\AppData\Local" @@ -539,7 +594,7 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, cfgPath = paths.TryConfigFileMigration(earlyLogf, oldPath, cfgPath) case "tailscale-ipn": for _, oldBase := range []string{"wg64.log.conf", "wg32.log.conf"} { - oldConf := filepath.Join(dir, oldBase) + oldConf := filepath.Join(opts.Dir, oldBase) if fi, err := os.Stat(oldConf); err == nil && fi.Mode().IsRegular() { cfgPath = paths.TryConfigFileMigration(earlyLogf, oldConf, cfgPath) break @@ -552,44 +607,61 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, if err != nil { earlyLogf("logpolicy.ConfigFromFile %v: %v", cfgPath, err) } - if err := newc.Validate(collection); err != nil { + if err := newc.Validate(opts.Collection); err != nil { earlyLogf("logpolicy.Config.Validate for %v: %v", cfgPath, err) - newc = NewConfig(collection) + newc = NewConfig(opts.Collection) if err := newc.Save(cfgPath); err != nil { earlyLogf("logpolicy.Config.Save for %v: %v", cfgPath, err) } } conf := logtail.Config{ - Collection: newc.Collection, - PrivateID: newc.PrivateID, - Stderr: logWriter{console}, - CompressLogs: true, - HTTPC: &http.Client{Transport: NewLogtailTransport(logtail.DefaultHost, netMon, health, logf)}, - } - if collection == logtail.CollectionNode { + Collection: newc.Collection, + PrivateID: newc.PrivateID, + Stderr: logWriter{console}, + CompressLogs: true, + MaxUploadSize: opts.MaxUploadSize, + Bus: opts.Bus, + } + if opts.Collection == logtail.CollectionNode { conf.MetricsDelta = clientmetric.EncodeLogTailMetricsDelta conf.IncludeProcID = true conf.IncludeProcSequence = true } - if envknob.NoLogsNoSupport() || testenv.InTest() { - logf("You have disabled logging. Tailscale will not be able to provide support.") + if disableLogging { + opts.Logf("You have disabled logging. Tailscale will not be able to provide support.") conf.HTTPC = &http.Client{Transport: noopPretendSuccessTransport{}} } else { // Only attach an on-disk filch buffer if we are going to be sending logs. // No reason to persist them locally just to drop them later. - attachFilchBuffer(&conf, dir, cmdName, logf) + attachFilchBuffer(&conf, opts.Dir, opts.CmdName, opts.MaxBufferSize, opts.Logf) + conf.HTTPC = opts.HTTPC + logHost := logtail.DefaultHost if val := getLogTarget(); val != "" { - logf("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") - conf.BaseURL = val - u, _ := url.Parse(val) - conf.HTTPC = &http.Client{Transport: NewLogtailTransport(u.Host, netMon, health, logf)} + u, err := url.Parse(val) + if err != nil { + opts.Logf("logpolicy: invalid TS_LOG_TARGET %q: %v; using default log host", val, err) + } else if u.Host == "" { + opts.Logf("logpolicy: invalid TS_LOG_TARGET %q: missing host; using default log host", val) + } else { + opts.Logf("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") + conf.BaseURL = val + logHost = u.Host + } } + if conf.HTTPC == nil { + conf.HTTPC = &http.Client{Transport: TransportOptions{ + Host: logHost, + NetMon: opts.NetMon, + Health: opts.Health, + Logf: opts.Logf, + }.New()} + } } - lw := logtail.NewLogger(conf, logf) + lw := logtail.NewLogger(conf, opts.Logf) var logOutput io.Writer = lw @@ -607,28 +679,36 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, log.SetOutput(logOutput) } - logf("Program starting: v%v, Go %v: %#v", + opts.Logf("Program starting: v%v, Go %v: %#v", version.Long(), goVersion(), os.Args) - logf("LogID: %v", newc.PublicID) + opts.Logf("LogID: %v", newc.PublicID) if earlyErrBuf.Len() != 0 { - logf("%s", earlyErrBuf.Bytes()) + opts.Logf("%s", earlyErrBuf.Bytes()) } - return &Policy{ + return &conf, &Policy{ Logtail: lw, PublicID: newc.PublicID, - Logf: logf, + Logf: opts.Logf, } } +// New returns a new log policy (a logger and its instance ID). +func (opts Options) New() *Policy { + disableLogging := envknob.NoLogsNoSupport() || testenv.InTest() || runtime.GOOS == "plan9" || !buildfeatures.HasLogTail + _, policy := opts.init(disableLogging) + return policy +} + // attachFilchBuffer creates an on-disk ring buffer using filch and attaches // it to the logtail config. Note that this is optional; if no buffer is set, // logtail will use an in-memory buffer. -func attachFilchBuffer(conf *logtail.Config, dir, cmdName string, logf logger.Logf) { +func attachFilchBuffer(conf *logtail.Config, dir, cmdName string, maxFileSize int, logf logger.Logf) { filchOptions := filch.Options{ ReplaceStderr: redirectStderrToLogPanics(), + MaxFileSize: maxFileSize, } filchPrefix := filepath.Join(dir, cmdName) @@ -705,7 +785,7 @@ func (p *Policy) Shutdown(ctx context.Context) error { // // The netMon parameter is optional. It should be specified in environments where // Tailscaled is manipulating the routing table. -func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) func(ctx context.Context, netw, addr string) (net.Conn, error) { +func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) netx.DialFunc { if netMon == nil { netMon = netmon.NewStatic() } @@ -760,26 +840,55 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor, return c, err } -// NewLogtailTransport returns an HTTP Transport particularly suited to uploading -// logs to the given host name. See DialContext for details on how it works. -// -// The netMon parameter is optional. It should be specified in environments where -// Tailscaled is manipulating the routing table. -// -// The logf parameter is optional; if non-nil, logs are printed using the -// provided function; if nil, log.Printf will be used instead. +// Deprecated: Use [TransportOptions.New] instead. func NewLogtailTransport(host string, netMon *netmon.Monitor, health *health.Tracker, logf logger.Logf) http.RoundTripper { - if testenv.InTest() { + return TransportOptions{Host: host, NetMon: netMon, Health: health, Logf: logf}.New() +} + +// TransportOptions is used to construct an [http.RoundTripper]. +type TransportOptions struct { + // Host is the optional hostname of the logs server. + // If empty, then [logtail.DefaultHost] is used. + Host string + + // NetMon is an optional parameter for monitoring. + // If non-nil, it's used to do faster interface lookups. + NetMon *netmon.Monitor + + // Health is an optional parameter for health status. + // If non-nil, it's used to construct the default HTTP client. + Health *health.Tracker + + // Logf is an optional logger to use. + // If nil, [log.Printf] will be used instead. + Logf logger.Logf + + // TLSClientConfig is an optional TLS configuration to use. + // If non-nil, the configuration will be cloned. + TLSClientConfig *tls.Config +} + +// New returns an HTTP Transport particularly suited to uploading logs +// to the given host name. See [DialContext] for details on how it works. +func (opts TransportOptions) New() http.RoundTripper { + if testenv.InTest() || envknob.NoLogsNoSupport() { return noopPretendSuccessTransport{} } - if netMon == nil { - netMon = netmon.NewStatic() + if opts.NetMon == nil { + opts.NetMon = netmon.NewStatic() } // Start with a copy of http.DefaultTransport and tweak it a bit. tr := http.DefaultTransport.(*http.Transport).Clone() + if opts.TLSClientConfig != nil { + tr.TLSClientConfig = opts.TLSClientConfig.Clone() + } - tr.Proxy = tshttpproxy.ProxyFromEnvironment - tshttpproxy.SetTransportGetProxyConnectHeader(tr) + if buildfeatures.HasUseProxy { + tr.Proxy = feature.HookProxyFromEnvironment.GetOrNil() + if set, ok := feature.HookProxySetTransportGetProxyConnectHeader.GetOk(); ok { + set(tr) + } + } // We do our own zstd compression on uploads, and responses never contain any payload, // so don't send "Accept-Encoding: gzip" to save a few bytes on the wire, since there @@ -787,10 +896,10 @@ func NewLogtailTransport(host string, netMon *netmon.Monitor, health *health.Tra tr.DisableCompression = true // Log whenever we dial: - if logf == nil { - logf = log.Printf + if opts.Logf == nil { + opts.Logf = log.Printf } - tr.DialContext = MakeDialFunc(netMon, logf) + tr.DialContext = MakeDialFunc(opts.NetMon, opts.Logf) // We're uploading logs ideally infrequently, with specific timing that will // change over time. Try to keep the connection open, to avoid repeatedly @@ -812,8 +921,8 @@ func NewLogtailTransport(host string, netMon *netmon.Monitor, health *health.Tra tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{} } - tr.TLSClientConfig = tlsdial.Config(host, health, tr.TLSClientConfig) - // Force TLS 1.3 since we know log.tailscale.io supports it. + tr.TLSClientConfig = tlsdial.Config(opts.Health, tr.TLSClientConfig) + // Force TLS 1.3 since we know log.tailscale.com supports it. tr.TLSClientConfig.MinVersion = tls.VersionTLS13 return tr diff --git a/logpolicy/logpolicy_test.go b/logpolicy/logpolicy_test.go index fdbfe4506..c09e590bb 100644 --- a/logpolicy/logpolicy_test.go +++ b/logpolicy/logpolicy_test.go @@ -4,33 +4,127 @@ package logpolicy import ( + "net/http" "os" "reflect" "testing" + + "tailscale.com/logtail" ) -func TestLogHost(t *testing.T) { +func resetLogTarget() { + os.Unsetenv("TS_LOG_TARGET") v := reflect.ValueOf(&getLogTargetOnce).Elem() - reset := func() { - v.Set(reflect.Zero(v.Type())) - } - defer reset() + v.Set(reflect.Zero(v.Type())) +} + +func TestLogHost(t *testing.T) { + defer resetLogTarget() tests := []struct { env string want string }{ - {"", "log.tailscale.io"}, + {"", logtail.DefaultHost}, {"http://foo.com", "foo.com"}, {"https://foo.com", "foo.com"}, {"https://foo.com/", "foo.com"}, {"https://foo.com:123/", "foo.com"}, } for _, tt := range tests { - reset() + resetLogTarget() os.Setenv("TS_LOG_TARGET", tt.env) if got := LogHost(); got != tt.want { t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) } } } +func TestOptions(t *testing.T) { + defer resetLogTarget() + + tests := []struct { + name string + opts func() Options + wantBaseURL string + }{ + { + name: "default", + opts: func() Options { return Options{} }, + wantBaseURL: "", + }, + { + name: "custom_baseurl", + opts: func() Options { + os.Setenv("TS_LOG_TARGET", "http://localhost:1234") + return Options{} + }, + wantBaseURL: "http://localhost:1234", + }, + { + name: "custom_httpc_and_baseurl", + opts: func() Options { + os.Setenv("TS_LOG_TARGET", "http://localhost:12345") + return Options{HTTPC: &http.Client{Transport: noopPretendSuccessTransport{}}} + }, + wantBaseURL: "http://localhost:12345", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetLogTarget() + config, policy := tt.opts().init(false) + if policy == nil { + t.Fatal("unexpected nil policy") + } + if config.BaseURL != tt.wantBaseURL { + t.Errorf("got %q, want %q", config.BaseURL, tt.wantBaseURL) + } + policy.Close() + }) + } +} + +// TestInvalidLogTarget is a test for #17792 +func TestInvalidLogTarget(t *testing.T) { + defer resetLogTarget() + + tests := []struct { + name string + logTarget string + }{ + { + name: "invalid_url_no_scheme", + logTarget: "not a url at all", + }, + { + name: "malformed_url", + logTarget: "ht!tp://invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetLogTarget() + os.Setenv("TS_LOG_TARGET", tt.logTarget) + + opts := Options{ + Collection: "test.log.tailscale.io", + Logf: t.Logf, + } + + // This should not panic even with invalid log target + config, policy := opts.init(false) + if policy == nil { + t.Fatal("expected non-nil policy") + } + defer policy.Close() + + // When log target is invalid, it should fall back to the invalid value + // but not crash. BaseURL should remain empty + if config.BaseURL != "" { + t.Errorf("got BaseURL=%q, want empty", config.BaseURL) + } + }) + } +} diff --git a/logpolicy/maybe_syspolicy.go b/logpolicy/maybe_syspolicy.go new file mode 100644 index 000000000..8b2836c97 --- /dev/null +++ b/logpolicy/maybe_syspolicy.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_syspolicy + +package logpolicy + +import _ "tailscale.com/feature/syspolicy" diff --git a/logtail/api.md b/logtail/api.md index 8ec0b69c0..20726e209 100644 --- a/logtail/api.md +++ b/logtail/api.md @@ -6,14 +6,14 @@ retrieving, and processing log entries. # Overview HTTP requests are received at the service **base URL** -[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded +[https://log.tailscale.com](https://log.tailscale.com), and return JSON-encoded responses using standard HTTP response codes. Authorization for the configuration and retrieval APIs is done with a secret API key passed as the HTTP basic auth username. Secret keys are generated via the web UI at base URL. An example of using basic auth with curl: - curl -u : https://log.tailscale.io/collections + curl -u : https://log.tailscale.com/collections In the future, an HTTP header will allow using MessagePack instead of JSON. diff --git a/logtail/buffer.go b/logtail/buffer.go index c9f2e1ad0..82c9b4610 100644 --- a/logtail/buffer.go +++ b/logtail/buffer.go @@ -1,13 +1,16 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_logtail + package logtail import ( "bytes" "errors" "fmt" - "sync" + + "tailscale.com/syncs" ) type Buffer interface { @@ -34,7 +37,7 @@ type memBuffer struct { next []byte pending chan qentry - dropMu sync.Mutex + dropMu syncs.Mutex dropCount int } diff --git a/logtail/config.go b/logtail/config.go new file mode 100644 index 000000000..bf47dd8aa --- /dev/null +++ b/logtail/config.go @@ -0,0 +1,67 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package logtail + +import ( + "io" + "net/http" + "time" + + "tailscale.com/tstime" + "tailscale.com/types/logid" + "tailscale.com/util/eventbus" +) + +// DefaultHost is the default host name to upload logs to when +// Config.BaseURL isn't provided. +const DefaultHost = "log.tailscale.com" + +const defaultFlushDelay = 2 * time.Second + +const ( + // CollectionNode is the name of a logtail Config.Collection + // for tailscaled (or equivalent: IPNExtension, Android app). + CollectionNode = "tailnode.log.tailscale.io" +) + +type Config struct { + Collection string // collection name, a domain name + PrivateID logid.PrivateID // private ID for the primary log stream + CopyPrivateID logid.PrivateID // private ID for a log stream that is a superset of this log stream + BaseURL string // if empty defaults to "https://log.tailscale.com" + HTTPC *http.Client // if empty defaults to http.DefaultClient + SkipClientTime bool // if true, client_time is not written to logs + LowMemory bool // if true, logtail minimizes memory use + Clock tstime.Clock // if set, Clock.Now substitutes uses of time.Now + Stderr io.Writer // if set, logs are sent here instead of os.Stderr + Bus *eventbus.Bus // if set, uses the eventbus for awaitInternetUp instead of callback + StderrLevel int // max verbosity level to write to stderr; 0 means the non-verbose messages only + Buffer Buffer // temp storage, if nil a MemoryBuffer + CompressLogs bool // whether to compress the log uploads + MaxUploadSize int // maximum upload size; 0 means using the default + + // MetricsDelta, if non-nil, is a func that returns an encoding + // delta in clientmetrics to upload alongside existing logs. + // It can return either an empty string (for nothing) or a string + // that's safe to embed in a JSON string literal without further escaping. + MetricsDelta func() string + + // FlushDelayFn, if non-nil is a func that returns how long to wait to + // accumulate logs before uploading them. 0 or negative means to upload + // immediately. + // + // If nil, a default value is used. (currently 2 seconds) + FlushDelayFn func() time.Duration + + // IncludeProcID, if true, results in an ephemeral process identifier being + // included in logs. The ID is random and not guaranteed to be globally + // unique, but it can be used to distinguish between different instances + // running with same PrivateID. + IncludeProcID bool + + // IncludeProcSequence, if true, results in an ephemeral sequence number + // being included in the logs. The sequence number is incremented for each + // log message sent, but is not persisted across process restarts. + IncludeProcSequence bool +} diff --git a/logtail/example/logadopt/logadopt.go b/logtail/example/logadopt/logadopt.go index 984a8a35a..eba3f9311 100644 --- a/logtail/example/logadopt/logadopt.go +++ b/logtail/example/logadopt/logadopt.go @@ -25,7 +25,7 @@ func main() { } log.SetFlags(0) - req, err := http.NewRequest("POST", "https://log.tailscale.io/instances", strings.NewReader(url.Values{ + req, err := http.NewRequest("POST", "https://log.tailscale.com/instances", strings.NewReader(url.Values{ "collection": []string{*collection}, "instances": []string{*publicID}, "adopt": []string{"true"}, diff --git a/logtail/example/logreprocess/demo.sh b/logtail/example/logreprocess/demo.sh index 4ec819a67..583929c12 100755 --- a/logtail/example/logreprocess/demo.sh +++ b/logtail/example/logreprocess/demo.sh @@ -13,7 +13,7 @@ # # Then generate a LOGTAIL_API_KEY and two test collections by visiting: # -# https://log.tailscale.io +# https://log.tailscale.com # # Then set the three variables below. trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT diff --git a/logtail/example/logreprocess/logreprocess.go b/logtail/example/logreprocess/logreprocess.go index 5dbf76578..aae65df9f 100644 --- a/logtail/example/logreprocess/logreprocess.go +++ b/logtail/example/logreprocess/logreprocess.go @@ -37,7 +37,7 @@ func main() { }() } - req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) + req, err := http.NewRequest("GET", "https://log.tailscale.com/c/"+*collection+"?stream=true", nil) if err != nil { log.Fatal(err) } diff --git a/logtail/logtail.go b/logtail/logtail.go index bb4232c34..c1e43258a 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -1,11 +1,14 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package logtail sends logs to log.tailscale.io. +//go:build !ts_omit_logtail + +// Package logtail sends logs to log.tailscale.com. package logtail import ( "bytes" + "cmp" "context" "crypto/rand" "encoding/binary" @@ -14,9 +17,7 @@ import ( "log" mrand "math/rand/v2" "net/http" - "net/netip" "os" - "regexp" "runtime" "slices" "strconv" @@ -24,14 +25,15 @@ import ( "sync/atomic" "time" + "github.com/creachadair/msync/trigger" "github.com/go-json-experiment/json/jsontext" "tailscale.com/envknob" "tailscale.com/net/netmon" "tailscale.com/net/sockstats" - "tailscale.com/net/tsaddr" "tailscale.com/tstime" tslogger "tailscale.com/types/logger" "tailscale.com/types/logid" + "tailscale.com/util/eventbus" "tailscale.com/util/set" "tailscale.com/util/truncate" "tailscale.com/util/zstdframe" @@ -141,13 +143,14 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { if !cfg.CopyPrivateID.IsZero() { urlSuffix = "?copyId=" + cfg.CopyPrivateID.String() } - l := &Logger{ + logger := &Logger{ privateID: cfg.PrivateID, stderr: cfg.Stderr, stderrLevel: int64(cfg.StderrLevel), httpc: cfg.HTTPC, url: cfg.BaseURL + "/c/" + cfg.Collection + "/" + cfg.PrivateID.String() + urlSuffix, buffer: cfg.Buffer, + maxUploadSize: cfg.MaxUploadSize, skipClientTime: cfg.SkipClientTime, drainWake: make(chan struct{}, 1), sentinel: make(chan int32, 16), @@ -161,15 +164,21 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { shutdownStart: make(chan struct{}), shutdownDone: make(chan struct{}), } - l.SetSockstatsLabel(sockstats.LabelLogtailLogger) - l.compressLogs = cfg.CompressLogs + + if cfg.Bus != nil { + logger.eventClient = cfg.Bus.Client("logtail.Logger") + // Subscribe to change deltas from NetMon to detect when the network comes up. + eventbus.SubscribeFunc(logger.eventClient, logger.onChangeDelta) + } + logger.SetSockstatsLabel(sockstats.LabelLogtailLogger) + logger.compressLogs = cfg.CompressLogs ctx, cancel := context.WithCancel(context.Background()) - l.uploadCancel = cancel + logger.uploadCancel = cancel - go l.uploading(ctx) - l.Write([]byte("logtail started")) - return l + go logger.uploading(ctx) + logger.Write([]byte("logtail started")) + return logger } // Logger writes logs, splitting them as configured between local @@ -182,6 +191,7 @@ type Logger struct { skipClientTime bool netMonitor *netmon.Monitor buffer Buffer + maxUploadSize int drainWake chan struct{} // signal to speed up drain drainBuf []byte // owned by drainPending for reuse flushDelayFn func() time.Duration // negative or zero return value to upload aggressively, or >0 to batch at this delay @@ -195,6 +205,8 @@ type Logger struct { privateID logid.PrivateID httpDoCalls atomic.Int32 sockstatsLabel atomicSocktatsLabel + eventClient *eventbus.Client + networkIsUp trigger.Cond // set/reset by netmon.ChangeDelta events procID uint32 includeProcSequence bool @@ -203,6 +215,7 @@ type Logger struct { procSequence uint64 flushTimer tstime.TimerController // used when flushDelay is >0 writeBuf [bufferSize]byte // owned by Write for reuse + bytesBuf bytes.Buffer // owned by appendTextOrJSONLocked for reuse jsonDec jsontext.Decoder // owned by appendTextOrJSONLocked for reuse shutdownStartMu sync.Mutex // guards the closing of shutdownStart @@ -218,27 +231,27 @@ func (p *atomicSocktatsLabel) Store(label sockstats.Label) { p.p.Store(uint32(la // SetVerbosityLevel controls the verbosity level that should be // written to stderr. 0 is the default (not verbose). Levels 1 or higher // are increasingly verbose. -func (l *Logger) SetVerbosityLevel(level int) { - atomic.StoreInt64(&l.stderrLevel, int64(level)) +func (lg *Logger) SetVerbosityLevel(level int) { + atomic.StoreInt64(&lg.stderrLevel, int64(level)) } // SetNetMon sets the network monitor. // // It should not be changed concurrently with log writes and should // only be set once. -func (l *Logger) SetNetMon(lm *netmon.Monitor) { - l.netMonitor = lm +func (lg *Logger) SetNetMon(lm *netmon.Monitor) { + lg.netMonitor = lm } // SetSockstatsLabel sets the label used in sockstat logs to identify network traffic from this logger. -func (l *Logger) SetSockstatsLabel(label sockstats.Label) { - l.sockstatsLabel.Store(label) +func (lg *Logger) SetSockstatsLabel(label sockstats.Label) { + lg.sockstatsLabel.Store(label) } // PrivateID returns the logger's private log ID. // // It exists for internal use only. -func (l *Logger) PrivateID() logid.PrivateID { return l.privateID } +func (lg *Logger) PrivateID() logid.PrivateID { return lg.privateID } // Shutdown gracefully shuts down the logger while completing any // remaining uploads. @@ -246,30 +259,33 @@ func (l *Logger) PrivateID() logid.PrivateID { return l.privateID } // It will block, continuing to try and upload unless the passed // context object interrupts it by being done. // If the shutdown is interrupted, an error is returned. -func (l *Logger) Shutdown(ctx context.Context) error { +func (lg *Logger) Shutdown(ctx context.Context) error { done := make(chan struct{}) go func() { select { case <-ctx.Done(): - l.uploadCancel() - <-l.shutdownDone - case <-l.shutdownDone: + lg.uploadCancel() + <-lg.shutdownDone + case <-lg.shutdownDone: } close(done) - l.httpc.CloseIdleConnections() + lg.httpc.CloseIdleConnections() }() - l.shutdownStartMu.Lock() + if lg.eventClient != nil { + lg.eventClient.Close() + } + lg.shutdownStartMu.Lock() select { - case <-l.shutdownStart: - l.shutdownStartMu.Unlock() + case <-lg.shutdownStart: + lg.shutdownStartMu.Unlock() return nil default: } - close(l.shutdownStart) - l.shutdownStartMu.Unlock() + close(lg.shutdownStart) + lg.shutdownStartMu.Unlock() - io.WriteString(l, "logger closing down\n") + io.WriteString(lg, "logger closing down\n") <-done return nil @@ -279,8 +295,8 @@ func (l *Logger) Shutdown(ctx context.Context) error { // process, and any associated goroutines. // // Deprecated: use Shutdown -func (l *Logger) Close() { - l.Shutdown(context.Background()) +func (lg *Logger) Close() { + lg.Shutdown(context.Background()) } // drainBlock is called by drainPending when there are no logs to drain. @@ -290,11 +306,11 @@ func (l *Logger) Close() { // // If the caller specified FlushInterface, drainWake is only sent to // periodically. -func (l *Logger) drainBlock() (shuttingDown bool) { +func (lg *Logger) drainBlock() (shuttingDown bool) { select { - case <-l.shutdownStart: + case <-lg.shutdownStart: return true - case <-l.drainWake: + case <-lg.drainWake: } return false } @@ -302,13 +318,13 @@ func (l *Logger) drainBlock() (shuttingDown bool) { // drainPending drains and encodes a batch of logs from the buffer for upload. // If no logs are available, drainPending blocks until logs are available. // The returned buffer is only valid until the next call to drainPending. -func (l *Logger) drainPending() (b []byte) { - b = l.drainBuf[:0] +func (lg *Logger) drainPending() (b []byte) { + b = lg.drainBuf[:0] b = append(b, '[') defer func() { b = bytes.TrimRight(b, ",") b = append(b, ']') - l.drainBuf = b + lg.drainBuf = b if len(b) <= len("[]") { b = nil } @@ -316,13 +332,13 @@ func (l *Logger) drainPending() (b []byte) { maxLen := maxSize for len(b) < maxLen { - line, err := l.buffer.TryReadLine() + line, err := lg.buffer.TryReadLine() switch { case err == io.EOF: return b case err != nil: b = append(b, '{') - b = l.appendMetadata(b, false, true, 0, 0, "reading ringbuffer: "+err.Error(), nil, 0) + b = lg.appendMetadata(b, false, true, 0, 0, "reading ringbuffer: "+err.Error(), nil, 0) b = bytes.TrimRight(b, ",") b = append(b, '}') return b @@ -336,10 +352,10 @@ func (l *Logger) drainPending() (b []byte) { // in our buffer from a previous large write, let it go. if cap(b) > bufferSize { b = bytes.Clone(b) - l.drainBuf = b + lg.drainBuf = b } - if shuttingDown := l.drainBlock(); shuttingDown { + if shuttingDown := lg.drainBlock(); shuttingDown { return b } continue @@ -356,18 +372,18 @@ func (l *Logger) drainPending() (b []byte) { default: // This is probably a log added to stderr by filch // outside of the logtail logger. Encode it. - if !l.explainedRaw { - fmt.Fprintf(l.stderr, "RAW-STDERR: ***\n") - fmt.Fprintf(l.stderr, "RAW-STDERR: *** Lines prefixed with RAW-STDERR below bypassed logtail and probably come from a previous run of the program\n") - fmt.Fprintf(l.stderr, "RAW-STDERR: ***\n") - fmt.Fprintf(l.stderr, "RAW-STDERR:\n") - l.explainedRaw = true + if !lg.explainedRaw { + fmt.Fprintf(lg.stderr, "RAW-STDERR: ***\n") + fmt.Fprintf(lg.stderr, "RAW-STDERR: *** Lines prefixed with RAW-STDERR below bypassed logtail and probably come from a previous run of the program\n") + fmt.Fprintf(lg.stderr, "RAW-STDERR: ***\n") + fmt.Fprintf(lg.stderr, "RAW-STDERR:\n") + lg.explainedRaw = true } - fmt.Fprintf(l.stderr, "RAW-STDERR: %s", b) + fmt.Fprintf(lg.stderr, "RAW-STDERR: %s", b) // Do not add a client time, as it could be really old. // Do not include instance key or ID either, // since this came from a different instance. - b = l.appendText(b, line, true, 0, 0, 0) + b = lg.appendText(b, line, true, 0, 0, 0) } b = append(b, ',') } @@ -375,14 +391,14 @@ func (l *Logger) drainPending() (b []byte) { } // This is the goroutine that repeatedly uploads logs in the background. -func (l *Logger) uploading(ctx context.Context) { - defer close(l.shutdownDone) +func (lg *Logger) uploading(ctx context.Context) { + defer close(lg.shutdownDone) for { - body := l.drainPending() + body := lg.drainPending() origlen := -1 // sentinel value: uncompressed // Don't attempt to compress tiny bodies; not worth the CPU cycles. - if l.compressLogs && len(body) > 256 { + if lg.compressLogs && len(body) > 256 { zbody := zstdframe.AppendEncode(nil, body, zstdframe.FastestCompression, zstdframe.LowMemory(true)) @@ -399,20 +415,20 @@ func (l *Logger) uploading(ctx context.Context) { var numFailures int var firstFailure time.Time for len(body) > 0 && ctx.Err() == nil { - retryAfter, err := l.upload(ctx, body, origlen) + retryAfter, err := lg.upload(ctx, body, origlen) if err != nil { numFailures++ - firstFailure = l.clock.Now() + firstFailure = lg.clock.Now() - if !l.internetUp() { - fmt.Fprintf(l.stderr, "logtail: internet down; waiting\n") - l.awaitInternetUp(ctx) + if !lg.internetUp() { + fmt.Fprintf(lg.stderr, "logtail: internet down; waiting\n") + lg.awaitInternetUp(ctx) continue } // Only print the same message once. if currError := err.Error(); lastError != currError { - fmt.Fprintf(l.stderr, "logtail: upload: %v\n", err) + fmt.Fprintf(lg.stderr, "logtail: upload: %v\n", err) lastError = currError } @@ -425,31 +441,55 @@ func (l *Logger) uploading(ctx context.Context) { } else { // Only print a success message after recovery. if numFailures > 0 { - fmt.Fprintf(l.stderr, "logtail: upload succeeded after %d failures and %s\n", numFailures, l.clock.Since(firstFailure).Round(time.Second)) + fmt.Fprintf(lg.stderr, "logtail: upload succeeded after %d failures and %s\n", numFailures, lg.clock.Since(firstFailure).Round(time.Second)) } break } } select { - case <-l.shutdownStart: + case <-lg.shutdownStart: return default: } } } -func (l *Logger) internetUp() bool { - if l.netMonitor == nil { - // No way to tell, so assume it is. +func (lg *Logger) internetUp() bool { + select { + case <-lg.networkIsUp.Ready(): return true + default: + if lg.netMonitor == nil { + return true // No way to tell, so assume it is. + } + return lg.netMonitor.InterfaceState().AnyInterfaceUp() + } +} + +// onChangeDelta is an eventbus subscriber function that handles +// [netmon.ChangeDelta] events to detect whether the Internet is expected to be +// reachable. +func (lg *Logger) onChangeDelta(delta *netmon.ChangeDelta) { + if delta.New.AnyInterfaceUp() { + fmt.Fprintf(lg.stderr, "logtail: internet back up\n") + lg.networkIsUp.Set() + } else { + fmt.Fprintf(lg.stderr, "logtail: network changed, but is not up\n") + lg.networkIsUp.Reset() } - return l.netMonitor.InterfaceState().AnyInterfaceUp() } -func (l *Logger) awaitInternetUp(ctx context.Context) { +func (lg *Logger) awaitInternetUp(ctx context.Context) { + if lg.eventClient != nil { + select { + case <-lg.networkIsUp.Ready(): + case <-ctx.Done(): + } + return + } upc := make(chan bool, 1) - defer l.netMonitor.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { + defer lg.netMonitor.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { if delta.New.AnyInterfaceUp() { select { case upc <- true: @@ -457,12 +497,12 @@ func (l *Logger) awaitInternetUp(ctx context.Context) { } } })() - if l.internetUp() { + if lg.internetUp() { return } select { case <-upc: - fmt.Fprintf(l.stderr, "logtail: internet back up\n") + fmt.Fprintf(lg.stderr, "logtail: internet back up\n") case <-ctx.Done(): } } @@ -470,13 +510,13 @@ func (l *Logger) awaitInternetUp(ctx context.Context) { // upload uploads body to the log server. // origlen indicates the pre-compression body length. // origlen of -1 indicates that the body is not compressed. -func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAfter time.Duration, err error) { +func (lg *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAfter time.Duration, err error) { const maxUploadTime = 45 * time.Second - ctx = sockstats.WithSockStats(ctx, l.sockstatsLabel.Load(), l.Logf) + ctx = sockstats.WithSockStats(ctx, lg.sockstatsLabel.Load(), lg.Logf) ctx, cancel := context.WithTimeout(ctx, maxUploadTime) defer cancel() - req, err := http.NewRequestWithContext(ctx, "POST", l.url, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(ctx, "POST", lg.url, bytes.NewReader(body)) if err != nil { // I know of no conditions under which this could fail. // Report it very loudly. @@ -489,7 +529,7 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAft } if runtime.GOOS == "js" { // We once advertised we'd accept optional client certs (for internal use) - // on log.tailscale.io but then Tailscale SSH js/wasm clients prompted + // on log.tailscale.com but then Tailscale SSH js/wasm clients prompted // users (on some browsers?) to pick a client cert. We'll fix the server's // TLS ServerHello, but we can also fix it client side for good measure. // @@ -507,8 +547,8 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAft compressedNote = "compressed" } - l.httpDoCalls.Add(1) - resp, err := l.httpc.Do(req) + lg.httpDoCalls.Add(1) + resp, err := lg.httpc.Do(req) if err != nil { return 0, fmt.Errorf("log upload of %d bytes %s failed: %v", len(body), compressedNote, err) } @@ -527,16 +567,16 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAft // // TODO(bradfitz): this apparently just returns nil, as of tailscale/corp@9c2ec35. // Finish cleaning this up. -func (l *Logger) Flush() error { +func (lg *Logger) Flush() error { return nil } // StartFlush starts a log upload, if anything is pending. // // If l is nil, StartFlush is a no-op. -func (l *Logger) StartFlush() { - if l != nil { - l.tryDrainWake() +func (lg *Logger) StartFlush() { + if lg != nil { + lg.tryDrainWake() } } @@ -552,41 +592,41 @@ var debugWakesAndUploads = envknob.RegisterBool("TS_DEBUG_LOGTAIL_WAKES") // tryDrainWake tries to send to lg.drainWake, to cause an uploading wakeup. // It does not block. -func (l *Logger) tryDrainWake() { - l.flushPending.Store(false) +func (lg *Logger) tryDrainWake() { + lg.flushPending.Store(false) if debugWakesAndUploads() { // Using println instead of log.Printf here to avoid recursing back into // ourselves. - println("logtail: try drain wake, numHTTP:", l.httpDoCalls.Load()) + println("logtail: try drain wake, numHTTP:", lg.httpDoCalls.Load()) } select { - case l.drainWake <- struct{}{}: + case lg.drainWake <- struct{}{}: default: } } -func (l *Logger) sendLocked(jsonBlob []byte) (int, error) { +func (lg *Logger) sendLocked(jsonBlob []byte) (int, error) { tapSend(jsonBlob) if logtailDisabled.Load() { return len(jsonBlob), nil } - n, err := l.buffer.Write(jsonBlob) + n, err := lg.buffer.Write(jsonBlob) flushDelay := defaultFlushDelay - if l.flushDelayFn != nil { - flushDelay = l.flushDelayFn() + if lg.flushDelayFn != nil { + flushDelay = lg.flushDelayFn() } if flushDelay > 0 { - if l.flushPending.CompareAndSwap(false, true) { - if l.flushTimer == nil { - l.flushTimer = l.clock.AfterFunc(flushDelay, l.tryDrainWake) + if lg.flushPending.CompareAndSwap(false, true) { + if lg.flushTimer == nil { + lg.flushTimer = lg.clock.AfterFunc(flushDelay, lg.tryDrainWake) } else { - l.flushTimer.Reset(flushDelay) + lg.flushTimer.Reset(flushDelay) } } } else { - l.tryDrainWake() + lg.tryDrainWake() } return n, err } @@ -594,13 +634,13 @@ func (l *Logger) sendLocked(jsonBlob []byte) (int, error) { // appendMetadata appends optional "logtail", "metrics", and "v" JSON members. // This assumes dst is already within a JSON object. // Each member is comma-terminated. -func (l *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, procID uint32, procSequence uint64, errDetail string, errData jsontext.Value, level int) []byte { +func (lg *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, procID uint32, procSequence uint64, errDetail string, errData jsontext.Value, level int) []byte { // Append optional logtail metadata. if !skipClientTime || procID != 0 || procSequence != 0 || errDetail != "" || errData != nil { dst = append(dst, `"logtail":{`...) if !skipClientTime { dst = append(dst, `"client_time":"`...) - dst = l.clock.Now().UTC().AppendFormat(dst, time.RFC3339Nano) + dst = lg.clock.Now().UTC().AppendFormat(dst, time.RFC3339Nano) dst = append(dst, '"', ',') } if procID != 0 { @@ -633,8 +673,8 @@ func (l *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, pr } // Append optional metrics metadata. - if !skipMetrics && l.metricsDelta != nil { - if d := l.metricsDelta(); d != "" { + if !skipMetrics && lg.metricsDelta != nil { + if d := lg.metricsDelta(); d != "" { dst = append(dst, `"metrics":"`...) dst = append(dst, d...) dst = append(dst, '"', ',') @@ -654,10 +694,10 @@ func (l *Logger) appendMetadata(dst []byte, skipClientTime, skipMetrics bool, pr } // appendText appends a raw text message in the Tailscale JSON log entry format. -func (l *Logger) appendText(dst, src []byte, skipClientTime bool, procID uint32, procSequence uint64, level int) []byte { +func (lg *Logger) appendText(dst, src []byte, skipClientTime bool, procID uint32, procSequence uint64, level int) []byte { dst = slices.Grow(dst, len(src)) dst = append(dst, '{') - dst = l.appendMetadata(dst, skipClientTime, false, procID, procSequence, "", nil, level) + dst = lg.appendMetadata(dst, skipClientTime, false, procID, procSequence, "", nil, level) if len(src) == 0 { dst = bytes.TrimRight(dst, ",") return append(dst, "}\n"...) @@ -686,28 +726,30 @@ func appendTruncatedString(dst, src []byte, n int) []byte { return dst } -func (l *Logger) AppendTextOrJSONLocked(dst, src []byte) []byte { - l.clock = tstime.StdClock{} - return l.appendTextOrJSONLocked(dst, src, 0) -} - // appendTextOrJSONLocked appends a raw text message or a raw JSON object // in the Tailscale JSON log format. -func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { - if l.includeProcSequence { - l.procSequence++ +func (lg *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { + if lg.includeProcSequence { + lg.procSequence++ } if len(src) == 0 || src[0] != '{' { - return l.appendText(dst, src, l.skipClientTime, l.procID, l.procSequence, level) + return lg.appendText(dst, src, lg.skipClientTime, lg.procID, lg.procSequence, level) } // Check whether the input is a valid JSON object and // whether it contains the reserved "logtail" name at the top-level. var logtailKeyOffset, logtailValOffset, logtailValLength int validJSON := func() bool { - // TODO(dsnet): Avoid allocation of bytes.Buffer struct. - dec := &l.jsonDec - dec.Reset(bytes.NewBuffer(src)) + // The jsontext.NewDecoder API operates on an io.Reader, for which + // bytes.Buffer provides a means to convert a []byte into an io.Reader. + // However, bytes.NewBuffer normally allocates unless + // we immediately shallow copy it into a pre-allocated Buffer struct. + // See https://go.dev/issue/67004. + lg.bytesBuf = *bytes.NewBuffer(src) + defer func() { lg.bytesBuf = bytes.Buffer{} }() // avoid pinning src + + dec := &lg.jsonDec + dec.Reset(&lg.bytesBuf) if tok, err := dec.ReadToken(); tok.Kind() != '{' || err != nil { return false } @@ -739,7 +781,7 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // Treat invalid JSON as a raw text message. if !validJSON { - return l.appendText(dst, src, l.skipClientTime, l.procID, l.procSequence, level) + return lg.appendText(dst, src, lg.skipClientTime, lg.procID, lg.procSequence, level) } // Check whether the JSON payload is too large. @@ -747,12 +789,13 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // That's okay as the Tailscale log service limit is actually 2*maxSize. // However, so long as logging applications aim to target the maxSize limit, // there should be no trouble eventually uploading logs. - if len(src) > maxSize { + maxLen := cmp.Or(lg.maxUploadSize, maxSize) + if len(src) > maxLen { errDetail := fmt.Sprintf("entry too large: %d bytes", len(src)) - errData := appendTruncatedString(nil, src, maxSize/len(`\uffff`)) // escaping could increase size + errData := appendTruncatedString(nil, src, maxLen/len(`\uffff`)) // escaping could increase size dst = append(dst, '{') - dst = l.appendMetadata(dst, l.skipClientTime, true, l.procID, l.procSequence, errDetail, errData, level) + dst = lg.appendMetadata(dst, lg.skipClientTime, true, lg.procID, lg.procSequence, errDetail, errData, level) dst = bytes.TrimRight(dst, ",") return append(dst, "}\n"...) } @@ -769,7 +812,7 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { } dst = slices.Grow(dst, len(src)) dst = append(dst, '{') - dst = l.appendMetadata(dst, l.skipClientTime, true, l.procID, l.procSequence, errDetail, errData, level) + dst = lg.appendMetadata(dst, lg.skipClientTime, true, lg.procID, lg.procSequence, errDetail, errData, level) if logtailValLength > 0 { // Exclude original logtail member from the message. dst = appendWithoutNewline(dst, src[len("{"):logtailKeyOffset]) @@ -796,82 +839,42 @@ func appendWithoutNewline(dst, src []byte) []byte { } // Logf logs to l using the provided fmt-style format and optional arguments. -func (l *Logger) Logf(format string, args ...any) { - fmt.Fprintf(l, format, args...) +func (lg *Logger) Logf(format string, args ...any) { + fmt.Fprintf(lg, format, args...) } -var obscureIPs = envknob.RegisterBool("TS_OBSCURE_LOGGED_IPS") - // Write logs an encoded JSON blob. // // If the []byte passed to Write is not an encoded JSON blob, // then contents is fit into a JSON blob and written. // // This is intended as an interface for the stdlib "log" package. -func (l *Logger) Write(buf []byte) (int, error) { +func (lg *Logger) Write(buf []byte) (int, error) { if len(buf) == 0 { return 0, nil } inLen := len(buf) // length as provided to us, before modifications to downstream writers level, buf := parseAndRemoveLogLevel(buf) - if l.stderr != nil && l.stderr != io.Discard && int64(level) <= atomic.LoadInt64(&l.stderrLevel) { + if lg.stderr != nil && lg.stderr != io.Discard && int64(level) <= atomic.LoadInt64(&lg.stderrLevel) { if buf[len(buf)-1] == '\n' { - l.stderr.Write(buf) + lg.stderr.Write(buf) } else { // The log package always line-terminates logs, // so this is an uncommon path. withNL := append(buf[:len(buf):len(buf)], '\n') - l.stderr.Write(withNL) + lg.stderr.Write(withNL) } } - if obscureIPs() { - buf = redactIPs(buf) - } - - l.writeLock.Lock() - defer l.writeLock.Unlock() + lg.writeLock.Lock() + defer lg.writeLock.Unlock() - b := l.appendTextOrJSONLocked(l.writeBuf[:0], buf, level) - _, err := l.sendLocked(b) + b := lg.appendTextOrJSONLocked(lg.writeBuf[:0], buf, level) + _, err := lg.sendLocked(b) return inLen, err } -var ( - regexMatchesIPv6 = regexp.MustCompile(`([0-9a-fA-F]{1,4}):([0-9a-fA-F]{1,4}):([0-9a-fA-F:]{1,4})*`) - regexMatchesIPv4 = regexp.MustCompile(`(\d{1,3})\.(\d{1,3})\.\d{1,3}\.\d{1,3}`) -) - -// redactIPs is a helper function used in Write() to redact IPs (other than tailscale IPs). -// This function takes a log line as a byte slice and -// uses regex matching to parse and find IP addresses. Based on if the IP address is IPv4 or -// IPv6, it parses and replaces the end of the addresses with an "x". This function returns the -// log line with the IPs redacted. -func redactIPs(buf []byte) []byte { - out := regexMatchesIPv6.ReplaceAllFunc(buf, func(b []byte) []byte { - ip, err := netip.ParseAddr(string(b)) - if err != nil || tsaddr.IsTailscaleIP(ip) { - return b // don't change this one - } - - prefix := bytes.Split(b, []byte(":")) - return bytes.Join(append(prefix[:2], []byte("x")), []byte(":")) - }) - - out = regexMatchesIPv4.ReplaceAllFunc(out, func(b []byte) []byte { - ip, err := netip.ParseAddr(string(b)) - if err != nil || tsaddr.IsTailscaleIP(ip) { - return b // don't change this one - } - - prefix := bytes.Split(b, []byte(".")) - return bytes.Join(append(prefix[:2], []byte("x.x")), []byte(".")) - }) - - return []byte(out) -} - var ( openBracketV = []byte("[v") v1 = []byte("[v1] ") diff --git a/logtail/logtail_omit.go b/logtail/logtail_omit.go new file mode 100644 index 000000000..814fd3be9 --- /dev/null +++ b/logtail/logtail_omit.go @@ -0,0 +1,44 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_logtail + +package logtail + +import ( + "context" + + tslogger "tailscale.com/types/logger" + "tailscale.com/types/logid" +) + +// Noop implementations of everything when ts_omit_logtail is set. + +type Logger struct{} + +type Buffer any + +func Disable() {} + +func NewLogger(cfg Config, logf tslogger.Logf) *Logger { + return &Logger{} +} + +func (*Logger) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (*Logger) Logf(format string, args ...any) {} +func (*Logger) Shutdown(ctx context.Context) error { return nil } +func (*Logger) SetVerbosityLevel(level int) {} + +func (l *Logger) SetSockstatsLabel(label any) {} + +func (l *Logger) PrivateID() logid.PrivateID { return logid.PrivateID{} } +func (l *Logger) StartFlush() {} + +func RegisterLogTap(dst chan<- string) (unregister func()) { + return func() {} +} + +func (*Logger) SetNetMon(any) {} diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index 3ea630406..b618fc0d7 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -15,9 +15,9 @@ import ( "time" "github.com/go-json-experiment/json/jsontext" - "tailscale.com/envknob" "tailscale.com/tstest" "tailscale.com/tstime" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" ) @@ -29,10 +29,11 @@ func TestFastShutdown(t *testing.T) { func(w http.ResponseWriter, r *http.Request) {})) defer testServ.Close() - l := NewLogger(Config{ + logger := NewLogger(Config{ BaseURL: testServ.URL, + Bus: eventbustest.NewBus(t), }, t.Logf) - err := l.Shutdown(ctx) + err := logger.Shutdown(ctx) if err != nil { t.Error(err) } @@ -63,7 +64,10 @@ func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { t.Cleanup(ts.srv.Close) - l := NewLogger(Config{BaseURL: ts.srv.URL}, t.Logf) + logger := NewLogger(Config{ + BaseURL: ts.srv.URL, + Bus: eventbustest.NewBus(t), + }, t.Logf) // There is always an initial "logtail started" message body := <-ts.uploaded @@ -71,14 +75,14 @@ func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { t.Errorf("unknown start logging statement: %q", string(body)) } - return &ts, l + return &ts, logger } func TestDrainPendingMessages(t *testing.T) { - ts, l := NewLogtailTestHarness(t) + ts, logger := NewLogtailTestHarness(t) for range logLines { - l.Write([]byte("log line")) + logger.Write([]byte("log line")) } // all of the "log line" messages usually arrive at once, but poll if needed. @@ -92,14 +96,14 @@ func TestDrainPendingMessages(t *testing.T) { // if we never find count == logLines, the test will eventually time out. } - err := l.Shutdown(context.Background()) + err := logger.Shutdown(context.Background()) if err != nil { t.Error(err) } } func TestEncodeAndUploadMessages(t *testing.T) { - ts, l := NewLogtailTestHarness(t) + ts, logger := NewLogtailTestHarness(t) tests := []struct { name string @@ -119,7 +123,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } for _, tt := range tests { - io.WriteString(l, tt.log) + io.WriteString(logger, tt.log) body := <-ts.uploaded data := unmarshalOne(t, body) @@ -140,7 +144,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } } - err := l.Shutdown(context.Background()) + err := logger.Shutdown(context.Background()) if err != nil { t.Error(err) } @@ -316,90 +320,11 @@ func TestLoggerWriteResult(t *testing.T) { t.Errorf("mismatch.\n got: %#q\nwant: %#q", back, want) } } -func TestRedact(t *testing.T) { - envknob.Setenv("TS_OBSCURE_LOGGED_IPS", "true") - tests := []struct { - in string - want string - }{ - // tests for ipv4 addresses - { - "120.100.30.47", - "120.100.x.x", - }, - { - "192.167.0.1/65", - "192.167.x.x/65", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 10.0.0.222:41641 mtu=1360 tx=d81a8a35a0ce", - "node [5Btdd] d:e89a3384f526d251 now using 10.0.x.x:41641 mtu=1360 tx=d81a8a35a0ce", - }, - //tests for ipv6 addresses - { - "2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "2001:0db8:x", - }, - { - "2345:0425:2CA1:0000:0000:0567:5673:23b5", - "2345:0425:x", - }, - { - "2601:645:8200:edf0::c9de/64", - "2601:645:x/64", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 2051:0000:140F::875B:131C mtu=1360 tx=d81a8a35a0ce", - "node [5Btdd] d:e89a3384f526d251 now using 2051:0000:x mtu=1360 tx=d81a8a35a0ce", - }, - { - "2601:645:8200:edf0::c9de/64 2601:645:8200:edf0:1ce9:b17d:71f5:f6a3/64", - "2601:645:x/64 2601:645:x/64", - }, - //tests for tailscale ip addresses - { - "100.64.5.6", - "100.64.5.6", - }, - { - "fd7a:115c:a1e0::/96", - "fd7a:115c:a1e0::/96", - }, - //tests for ipv6 and ipv4 together - { - "192.167.0.1 2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "192.167.x.x 2001:0db8:x", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 10.0.0.222:41641 mtu=1360 tx=d81a8a35a0ce 2345:0425:2CA1::0567:5673:23b5", - "node [5Btdd] d:e89a3384f526d251 now using 10.0.x.x:41641 mtu=1360 tx=d81a8a35a0ce 2345:0425:x", - }, - { - "100.64.5.6 2091:0db8:85a3:0000:0000:8a2e:0370:7334", - "100.64.5.6 2091:0db8:x", - }, - { - "192.167.0.1 120.100.30.47 2041:0000:140F::875B:131B", - "192.167.x.x 120.100.x.x 2041:0000:x", - }, - { - "fd7a:115c:a1e0::/96 192.167.0.1 2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "fd7a:115c:a1e0::/96 192.167.x.x 2001:0db8:x", - }, - } - - for _, tt := range tests { - gotBuf := redactIPs([]byte(tt.in)) - if string(gotBuf) != tt.want { - t.Errorf("for %q,\n got: %#q\nwant: %#q\n", tt.in, gotBuf, tt.want) - } - } -} func TestAppendMetadata(t *testing.T) { - var l Logger - l.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) - l.metricsDelta = func() string { return "metrics" } + var lg Logger + lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) + lg.metricsDelta = func() string { return "metrics" } for _, tt := range []struct { skipClientTime bool @@ -425,7 +350,7 @@ func TestAppendMetadata(t *testing.T) { {procID: 1, procSeq: 2, errDetail: "error", errData: jsontext.Value(`["something","bad","happened"]`), level: 2, want: `"logtail":{"client_time":"2000-01-01T00:00:00Z","proc_id":1,"proc_seq":2,"error":{"detail":"error","bad_data":["something","bad","happened"]}},"metrics":"metrics","v":2,`}, } { - got := string(l.appendMetadata(nil, tt.skipClientTime, tt.skipMetrics, tt.procID, tt.procSeq, tt.errDetail, tt.errData, tt.level)) + got := string(lg.appendMetadata(nil, tt.skipClientTime, tt.skipMetrics, tt.procID, tt.procSeq, tt.errDetail, tt.errData, tt.level)) if got != tt.want { t.Errorf("appendMetadata(%v, %v, %v, %v, %v, %v, %v):\n\tgot %s\n\twant %s", tt.skipClientTime, tt.skipMetrics, tt.procID, tt.procSeq, tt.errDetail, tt.errData, tt.level, got, tt.want) } @@ -437,10 +362,10 @@ func TestAppendMetadata(t *testing.T) { } func TestAppendText(t *testing.T) { - var l Logger - l.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) - l.metricsDelta = func() string { return "metrics" } - l.lowMem = true + var lg Logger + lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) + lg.metricsDelta = func() string { return "metrics" } + lg.lowMem = true for _, tt := range []struct { text string @@ -457,7 +382,7 @@ func TestAppendText(t *testing.T) { {text: "\b\f\n\r\t\"\\", want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z"},"metrics":"metrics","text":"\b\f\n\r\t\"\\"}`}, {text: "x" + strings.Repeat("😐", maxSize), want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z"},"metrics":"metrics","text":"x` + strings.Repeat("😐", 1023) + `â€Ļ+1044484"}`}, } { - got := string(l.appendText(nil, []byte(tt.text), tt.skipClientTime, tt.procID, tt.procSeq, tt.level)) + got := string(lg.appendText(nil, []byte(tt.text), tt.skipClientTime, tt.procID, tt.procSeq, tt.level)) if !strings.HasSuffix(got, "\n") { t.Errorf("`%s` does not end with a newline", got) } @@ -472,10 +397,10 @@ func TestAppendText(t *testing.T) { } func TestAppendTextOrJSON(t *testing.T) { - var l Logger - l.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) - l.metricsDelta = func() string { return "metrics" } - l.lowMem = true + var lg Logger + lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) + lg.metricsDelta = func() string { return "metrics" } + lg.lowMem = true for _, tt := range []struct { in string @@ -494,7 +419,7 @@ func TestAppendTextOrJSON(t *testing.T) { {in: `{ "fizz" : "buzz" , "logtail" : "duplicate" , "wizz" : "wuzz" }`, want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z","error":{"detail":"duplicate logtail member","bad_data":"duplicate"}}, "fizz" : "buzz" , "wizz" : "wuzz"}`}, {in: `{"long":"` + strings.Repeat("a", maxSize) + `"}`, want: `{"logtail":{"client_time":"2000-01-01T00:00:00Z","error":{"detail":"entry too large: 262155 bytes","bad_data":"{\"long\":\"` + strings.Repeat("a", 43681) + `â€Ļ+218465"}}}`}, } { - got := string(l.appendTextOrJSONLocked(nil, []byte(tt.in), tt.level)) + got := string(lg.appendTextOrJSONLocked(nil, []byte(tt.in), tt.level)) if !strings.HasSuffix(got, "\n") { t.Errorf("`%s` does not end with a newline", got) } @@ -536,21 +461,21 @@ var testdataTextLog = []byte(`netcheck: report: udp=true v6=false v6os=true mapv var testdataJSONLog = []byte(`{"end":"2024-04-08T21:39:15.715291586Z","nodeId":"nQRJBE7CNTRL","physicalTraffic":[{"dst":"127.x.x.x:2","src":"100.x.x.x:0","txBytes":148,"txPkts":1},{"dst":"127.x.x.x:2","src":"100.x.x.x:0","txBytes":148,"txPkts":1},{"dst":"98.x.x.x:1025","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"24.x.x.x:49973","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"73.x.x.x:41641","rxBytes":732,"rxPkts":6,"src":"100.x.x.x:0","txBytes":820,"txPkts":7},{"dst":"75.x.x.x:1025","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"75.x.x.x:41641","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"174.x.x.x:35497","rxBytes":13008,"rxPkts":98,"src":"100.x.x.x:0","txBytes":26688,"txPkts":150},{"dst":"47.x.x.x:41641","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5},{"dst":"64.x.x.x:41641","rxBytes":640,"rxPkts":5,"src":"100.x.x.x:0","txBytes":640,"txPkts":5}],"start":"2024-04-08T21:39:11.099495616Z","virtualTraffic":[{"dst":"100.x.x.x:33008","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32984","proto":6,"src":"100.x.x.x:22","txBytes":1340,"txPkts":10},{"dst":"100.x.x.x:32998","proto":6,"src":"100.x.x.x:22","txBytes":1020,"txPkts":10},{"dst":"100.x.x.x:32994","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:32980","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32950","proto":6,"src":"100.x.x.x:22","txBytes":1340,"txPkts":10},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:53332","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:0","proto":1,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32966","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:57882","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:53326","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:57892","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:32934","proto":6,"src":"100.x.x.x:22","txBytes":8712,"txPkts":55},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32942","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:32964","proto":6,"src":"100.x.x.x:22","txBytes":1260,"txPkts":10},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:0","proto":1,"rxBytes":420,"rxPkts":5,"src":"100.x.x.x:0","txBytes":420,"txPkts":5},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:37238","txBytes":60,"txPkts":1},{"dst":"100.x.x.x:22","proto":6,"src":"100.x.x.x:37252","txBytes":60,"txPkts":1}]}`) func BenchmarkWriteText(b *testing.B) { - var l Logger - l.clock = tstime.StdClock{} - l.buffer = discardBuffer{} + var lg Logger + lg.clock = tstime.StdClock{} + lg.buffer = discardBuffer{} b.ReportAllocs() for range b.N { - must.Get(l.Write(testdataTextLog)) + must.Get(lg.Write(testdataTextLog)) } } func BenchmarkWriteJSON(b *testing.B) { - var l Logger - l.clock = tstime.StdClock{} - l.buffer = discardBuffer{} + var lg Logger + lg.clock = tstime.StdClock{} + lg.buffer = discardBuffer{} b.ReportAllocs() for range b.N { - must.Get(l.Write(testdataJSONLog)) + must.Get(lg.Write(testdataJSONLog)) } } diff --git a/maths/ewma.go b/maths/ewma.go new file mode 100644 index 000000000..0897b73e4 --- /dev/null +++ b/maths/ewma.go @@ -0,0 +1,72 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package maths contains additional mathematical functions or structures not +// found in the standard library. +package maths + +import ( + "math" + "time" +) + +// EWMA is an exponentially weighted moving average supporting updates at +// irregular intervals with at most nanosecond resolution. +// The zero value will compute a half-life of 1 second. +// It is not safe for concurrent use. +// TODO(raggi): de-duplicate with tstime/rate.Value, which has a more complex +// and synchronized interface and does not provide direct access to the stable +// value. +type EWMA struct { + value float64 // current value of the average + lastTime int64 // time of last update in unix nanos + halfLife float64 // half-life in seconds +} + +// NewEWMA creates a new EWMA with the specified half-life. If halfLifeSeconds +// is 0, it defaults to 1. +func NewEWMA(halfLifeSeconds float64) *EWMA { + return &EWMA{ + halfLife: halfLifeSeconds, + } +} + +// Update adds a new sample to the average. If t is zero or precedes the last +// update, the update is ignored. +func (e *EWMA) Update(value float64, t time.Time) { + if t.IsZero() { + return + } + hl := e.halfLife + if hl == 0 { + hl = 1 + } + tn := t.UnixNano() + if e.lastTime == 0 { + e.value = value + e.lastTime = tn + return + } + + dt := (time.Duration(tn-e.lastTime) * time.Nanosecond).Seconds() + if dt < 0 { + // drop out of order updates + return + } + + // decay = 2^(-dt/halfLife) + decay := math.Exp2(-dt / hl) + e.value = e.value*decay + value*(1-decay) + e.lastTime = tn +} + +// Get returns the current value of the average +func (e *EWMA) Get() float64 { + return e.value +} + +// Reset clears the EWMA to its initial state +func (e *EWMA) Reset() { + e.value = 0 + e.lastTime = 0 +} diff --git a/maths/ewma_test.go b/maths/ewma_test.go new file mode 100644 index 000000000..307078a38 --- /dev/null +++ b/maths/ewma_test.go @@ -0,0 +1,178 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package maths + +import ( + "slices" + "testing" + "time" +) + +// some real world latency samples. +var ( + latencyHistory1 = []int{ + 14, 12, 15, 6, 19, 12, 13, 13, 13, 16, 17, 11, 17, 11, 14, 15, 14, 15, + 16, 16, 17, 14, 12, 16, 18, 14, 14, 11, 15, 15, 25, 11, 15, 14, 12, 15, + 13, 12, 13, 15, 11, 13, 15, 14, 14, 15, 12, 15, 18, 12, 15, 22, 12, 13, + 10, 14, 16, 15, 16, 11, 14, 17, 18, 20, 16, 11, 16, 14, 5, 15, 17, 12, + 15, 11, 15, 20, 12, 17, 12, 17, 15, 12, 12, 11, 14, 15, 11, 20, 14, 13, + 11, 12, 13, 13, 11, 13, 11, 15, 13, 13, 14, 12, 11, 12, 12, 14, 11, 13, + 12, 12, 12, 19, 14, 13, 13, 14, 11, 12, 10, 11, 15, 12, 14, 11, 11, 14, + 14, 12, 12, 11, 14, 12, 11, 12, 14, 11, 12, 15, 12, 14, 12, 12, 21, 16, + 21, 12, 16, 9, 11, 16, 14, 13, 14, 12, 13, 16, + } + latencyHistory2 = []int{ + 18, 20, 21, 21, 20, 23, 18, 18, 20, 21, 20, 19, 22, 18, 20, 20, 19, 21, + 21, 22, 22, 19, 18, 22, 22, 19, 20, 17, 16, 11, 25, 16, 18, 21, 17, 22, + 19, 18, 22, 21, 20, 18, 22, 17, 17, 20, 19, 10, 19, 16, 19, 25, 17, 18, + 15, 20, 21, 20, 23, 22, 22, 22, 19, 22, 22, 17, 22, 20, 20, 19, 21, 22, + 20, 19, 17, 22, 16, 16, 20, 22, 17, 19, 21, 16, 20, 22, 19, 21, 20, 19, + 13, 14, 23, 19, 16, 10, 19, 15, 15, 17, 16, 18, 14, 16, 18, 22, 20, 18, + 18, 21, 15, 19, 18, 19, 18, 20, 17, 19, 21, 19, 20, 19, 20, 20, 17, 14, + 17, 17, 18, 21, 20, 18, 18, 17, 16, 17, 17, 20, 22, 19, 20, 21, 21, 20, + 21, 24, 20, 18, 12, 17, 18, 17, 19, 19, 19, + } +) + +func TestEWMALatencyHistory(t *testing.T) { + type result struct { + t time.Time + v float64 + s int + } + + for _, latencyHistory := range [][]int{latencyHistory1, latencyHistory2} { + startTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + halfLife := 30.0 + + ewma := NewEWMA(halfLife) + + var results []result + sum := 0.0 + for i, latency := range latencyHistory { + t := startTime.Add(time.Duration(i) * time.Second) + ewma.Update(float64(latency), t) + sum += float64(latency) + + results = append(results, result{t, ewma.Get(), latency}) + } + mean := sum / float64(len(latencyHistory)) + min := float64(slices.Min(latencyHistory)) + max := float64(slices.Max(latencyHistory)) + + t.Logf("EWMA Latency History (half-life: %.1f seconds):", halfLife) + t.Logf("Mean latency: %.2f ms", mean) + t.Logf("Range: [%.1f, %.1f]", min, max) + + t.Log("Samples: ") + sparkline := []rune("▁▂▃▄▅▆▇█") + var sampleLine []rune + for _, r := range results { + idx := int(((float64(r.s) - min) / (max - min)) * float64(len(sparkline)-1)) + if idx >= len(sparkline) { + idx = len(sparkline) - 1 + } + sampleLine = append(sampleLine, sparkline[idx]) + } + t.Log(string(sampleLine)) + + t.Log("EWMA: ") + var ewmaLine []rune + for _, r := range results { + idx := int(((r.v - min) / (max - min)) * float64(len(sparkline)-1)) + if idx >= len(sparkline) { + idx = len(sparkline) - 1 + } + ewmaLine = append(ewmaLine, sparkline[idx]) + } + t.Log(string(ewmaLine)) + t.Log("") + + t.Logf("Time | Sample | Value | Value - Sample") + t.Logf("") + + for _, result := range results { + t.Logf("%10s | % 6d | % 5.2f | % 5.2f", result.t.Format("15:04:05"), result.s, result.v, result.v-float64(result.s)) + } + + // check that all results are greater than the min, and less than the max of the input, + // and they're all close to the mean. + for _, result := range results { + if result.v < float64(min) || result.v > float64(max) { + t.Errorf("result %f out of range [%f, %f]", result.v, min, max) + } + + if result.v < mean*0.9 || result.v > mean*1.1 { + t.Errorf("result %f not close to mean %f", result.v, mean) + } + } + } +} + +func TestHalfLife(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + ewma := NewEWMA(30.0) + ewma.Update(10, start) + ewma.Update(0, start.Add(30*time.Second)) + + if ewma.Get() != 5 { + t.Errorf("expected 5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(60*time.Second)) + if ewma.Get() != 7.5 { + t.Errorf("expected 7.5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(90*time.Second)) + if ewma.Get() != 8.75 { + t.Errorf("expected 8.75, got %f", ewma.Get()) + } +} + +func TestZeroValue(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + var ewma EWMA + ewma.Update(10, start) + ewma.Update(0, start.Add(time.Second)) + + if ewma.Get() != 5 { + t.Errorf("expected 5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(2*time.Second)) + if ewma.Get() != 7.5 { + t.Errorf("expected 7.5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(3*time.Second)) + if ewma.Get() != 8.75 { + t.Errorf("expected 8.75, got %f", ewma.Get()) + } +} + +func TestReset(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + ewma := NewEWMA(30.0) + ewma.Update(10, start) + ewma.Update(0, start.Add(30*time.Second)) + + if ewma.Get() != 5 { + t.Errorf("expected 5, got %f", ewma.Get()) + } + + ewma.Reset() + + if ewma.Get() != 0 { + t.Errorf("expected 0, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(90*time.Second)) + if ewma.Get() != 10 { + t.Errorf("expected 10, got %f", ewma.Get()) + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index a07ddccae..19966d395 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -11,6 +11,8 @@ import ( "io" "slices" "strings" + + "tailscale.com/syncs" ) // Set is a string-to-Var map variable that satisfies the expvar.Var @@ -37,6 +39,8 @@ type Set struct { type LabelMap struct { Label string expvar.Map + // shardedIntMu orders the initialization of new shardedint keys + shardedIntMu syncs.Mutex } // SetInt64 sets the *Int value stored under the given map key. @@ -44,6 +48,19 @@ func (m *LabelMap) SetInt64(key string, v int64) { m.Get(key).Set(v) } +// Add adds delta to the any int-like value stored under the given map key. +func (m *LabelMap) Add(key string, delta int64) { + type intAdder interface { + Add(delta int64) + } + o := m.Map.Get(key) + if o == nil { + m.Map.Add(key, delta) + return + } + o.(intAdder).Add(delta) +} + // Get returns a direct pointer to the expvar.Int for key, creating it // if necessary. func (m *LabelMap) Get(key string) *expvar.Int { @@ -51,6 +68,23 @@ func (m *LabelMap) Get(key string) *expvar.Int { return m.Map.Get(key).(*expvar.Int) } +// GetShardedInt returns a direct pointer to the syncs.ShardedInt for key, +// creating it if necessary. +func (m *LabelMap) GetShardedInt(key string) *syncs.ShardedInt { + i := m.Map.Get(key) + if i == nil { + m.shardedIntMu.Lock() + defer m.shardedIntMu.Unlock() + i = m.Map.Get(key) + if i != nil { + return i.(*syncs.ShardedInt) + } + i = syncs.NewShardedInt() + m.Set(key, i) + } + return i.(*syncs.ShardedInt) +} + // GetIncrFunc returns a function that increments the expvar.Int named by key. // // Most callers should not need this; it exists to satisfy an diff --git a/metrics/metrics_test.go b/metrics/metrics_test.go index 45bf39e56..a808d5a73 100644 --- a/metrics/metrics_test.go +++ b/metrics/metrics_test.go @@ -21,6 +21,15 @@ func TestLabelMap(t *testing.T) { if g, w := m.Get("bar").Value(), int64(2); g != w { t.Errorf("bar = %v; want %v", g, w) } + m.GetShardedInt("sharded").Add(5) + if g, w := m.GetShardedInt("sharded").Value(), int64(5); g != w { + t.Errorf("sharded = %v; want %v", g, w) + } + m.Add("sharded", 1) + if g, w := m.GetShardedInt("sharded").Value(), int64(6); g != w { + t.Errorf("sharded = %v; want %v", g, w) + } + m.Add("neverbefore", 1) } func TestCurrentFileDescriptors(t *testing.T) { diff --git a/net/ace/ace.go b/net/ace/ace.go new file mode 100644 index 000000000..47e780313 --- /dev/null +++ b/net/ace/ace.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ace implements a Dialer that dials via a Tailscale ACE (CONNECT) +// proxy. +// +// TODO: document this more, when it's more done. As of 2025-09-17, it's in +// development. +package ace + +import ( + "bufio" + "cmp" + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "net/netip" + "sync/atomic" +) + +// Dialer is an HTTP CONNECT proxy dialer to dial the control plane via an ACE +// proxy. +type Dialer struct { + ACEHost string + ACEHostIP netip.Addr // optional; if non-zero, use this IP instead of DNS + ACEPort int // zero means 443 + + // NetDialer optionally specifies the underlying dialer to use to reach the + // ACEHost. If nil, net.Dialer.DialContext is used. + NetDialer func(ctx context.Context, network, address string) (net.Conn, error) +} + +func (d *Dialer) netDialer() func(ctx context.Context, network, address string) (net.Conn, error) { + if d.NetDialer != nil { + return d.NetDialer + } + var std net.Dialer + return std.DialContext +} + +func (d *Dialer) acePort() int { return cmp.Or(d.ACEPort, 443) } + +func (d *Dialer) Dial(ctx context.Context, network, address string) (_ net.Conn, err error) { + if network != "tcp" { + return nil, errors.New("only TCP is supported") + } + + var targetHost string + if d.ACEHostIP.IsValid() { + targetHost = d.ACEHostIP.String() + } else { + targetHost = d.ACEHost + } + + cc, err := d.netDialer()(ctx, "tcp", net.JoinHostPort(targetHost, fmt.Sprint(d.acePort()))) + if err != nil { + return nil, err + } + + // Now that we've dialed, we're about to do three potentially blocking + // operations: the TLS handshake, the CONNECT write, and the HTTP response + // read. To make our context work over all that, we use a context.AfterFunc + // to start a goroutine that'll tear down the underlying connection if the + // context expires. + // + // To prevent races, we use an atomic.Bool to guard access to the underlying + // connection being either good or bad. Only one goroutine (the success path + // in this goroutine after the ReadResponse or the AfterFunc's failure + // goroutine) will compare-and-swap it from false to true. + var done atomic.Bool + stop := context.AfterFunc(ctx, func() { + if done.CompareAndSwap(false, true) { + cc.Close() + } + }) + defer func() { + if err != nil { + if ctx.Err() != nil { + // Prefer the context error. The other error is likely a side + // effect of the context expiring and our tearing down of the + // underlying connection, and is thus probably something like + // "use of closed network connection", which isn't useful (and + // actually misleading) for the caller. + err = ctx.Err() + } + stop() + cc.Close() + } + }() + + tc := tls.Client(cc, &tls.Config{ServerName: d.ACEHost}) + if err := tc.Handshake(); err != nil { + return nil, err + } + + // TODO(tailscale/corp#32484): send proxy-auth header + if _, err := fmt.Fprintf(tc, "CONNECT %s HTTP/1.1\r\nHost: %s\r\n\r\n", address, d.ACEHost); err != nil { + return nil, err + } + + br := bufio.NewReader(tc) + connRes, err := http.ReadResponse(br, &http.Request{Method: "CONNECT"}) + if err != nil { + return nil, fmt.Errorf("reading CONNECT response: %w", err) + } + + // Now that we're done with blocking operations, mark the connection + // as good, to prevent the context's AfterFunc from closing it. + if !stop() || !done.CompareAndSwap(false, true) { + // We lost a race and the context expired. + return nil, ctx.Err() + } + + if connRes.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ACE CONNECT response: %s", connRes.Status) + } + + if br.Buffered() > 0 { + return nil, fmt.Errorf("unexpected %d bytes of buffered data after ACE CONNECT", br.Buffered()) + } + return tc, nil +} diff --git a/net/art/stride_table.go b/net/art/stride_table.go index 5ff0455fe..5050df245 100644 --- a/net/art/stride_table.go +++ b/net/art/stride_table.go @@ -303,21 +303,21 @@ func formatPrefixTable(addr uint8, len int) string { // // For example, childPrefixOf("192.168.0.0/16", 8) == "192.168.8.0/24". func childPrefixOf(parent netip.Prefix, stride uint8) netip.Prefix { - l := parent.Bits() - if l%8 != 0 { + ln := parent.Bits() + if ln%8 != 0 { panic("parent prefix is not 8-bit aligned") } - if l >= parent.Addr().BitLen() { + if ln >= parent.Addr().BitLen() { panic("parent prefix cannot be extended further") } - off := l / 8 + off := ln / 8 if parent.Addr().Is4() { bs := parent.Addr().As4() bs[off] = stride - return netip.PrefixFrom(netip.AddrFrom4(bs), l+8) + return netip.PrefixFrom(netip.AddrFrom4(bs), ln+8) } else { bs := parent.Addr().As16() bs[off] = stride - return netip.PrefixFrom(netip.AddrFrom16(bs), l+8) + return netip.PrefixFrom(netip.AddrFrom16(bs), ln+8) } } diff --git a/net/art/stride_table_test.go b/net/art/stride_table_test.go index bff2bb7c5..4ccef1fe0 100644 --- a/net/art/stride_table_test.go +++ b/net/art/stride_table_test.go @@ -377,8 +377,8 @@ func pfxMask(pfxLen int) uint8 { func allPrefixes() []slowEntry[int] { ret := make([]slowEntry[int], 0, lastHostIndex) for i := 1; i < lastHostIndex+1; i++ { - a, l := inversePrefixIndex(i) - ret = append(ret, slowEntry[int]{a, l, i}) + a, ln := inversePrefixIndex(i) + ret = append(ret, slowEntry[int]{a, ln, i}) } return ret } diff --git a/net/bakedroots/bakedroots.go b/net/bakedroots/bakedroots.go new file mode 100644 index 000000000..b268b1546 --- /dev/null +++ b/net/bakedroots/bakedroots.go @@ -0,0 +1,147 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package bakedroots contains WebPKI CA roots we bake into the tailscaled binary, +// lest the system's CA roots be missing them (or entirely empty). +package bakedroots + +import ( + "crypto/x509" + "fmt" + "sync" + + "tailscale.com/util/testenv" +) + +// Get returns the baked-in roots. +// +// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 & X2 roots. +func Get() *x509.CertPool { + roots.once.Do(func() { + roots.parsePEM(append( + []byte(letsEncryptX1), + letsEncryptX2..., + )) + }) + return roots.p +} + +// ResetForTest resets the cached roots for testing, +// optionally setting them to caPEM if non-nil. +func ResetForTest(tb testenv.TB, caPEM []byte) { + if !testenv.InTest() { + panic("not in test") + } + tb.Setenv("ASSERT_NOT_PARALLEL_TEST", "1") // panics if tb's Parallel was called + + roots = rootsOnce{} + if caPEM != nil { + roots.once.Do(func() { roots.parsePEM(caPEM) }) + tb.Cleanup(func() { + // Reset the roots to real roots for any following test. + roots = rootsOnce{} + }) + } +} + +var roots rootsOnce + +type rootsOnce struct { + once sync.Once + p *x509.CertPool +} + +func (r *rootsOnce) parsePEM(caPEM []byte) { + p := x509.NewCertPool() + if !p.AppendCertsFromPEM(caPEM) { + panic(fmt.Sprintf("bogus PEM: %q", caPEM)) + } + r.p = p +} + +/* +letsEncryptX1 is the LetsEncrypt X1 root: + +Certificate: + + Data: + Version: 3 (0x2) + Serial Number: + 82:10:cf:b0:d2:40:e3:59:44:63:e0:bb:63:82:8b:00 + Signature Algorithm: sha256WithRSAEncryption + Issuer: C = US, O = Internet Security Research Group, CN = ISRG Root X1 + Validity + Not Before: Jun 4 11:04:38 2015 GMT + Not After : Jun 4 11:04:38 2035 GMT + Subject: C = US, O = Internet Security Research Group, CN = ISRG Root X1 + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (4096 bit) + +We bake it into the binary as a fallback verification root, +in case the system we're running on doesn't have it. +(Tailscale runs on some ancient devices.) + +To test that this code is working on Debian/Ubuntu: + +$ sudo mv /usr/share/ca-certificates/mozilla/ISRG_Root_X1.crt{,.old} +$ sudo update-ca-certificates + +Then restart tailscaled. To also test dnsfallback's use of it, nuke +your /etc/resolv.conf and it should still start & run fine. +*/ +const letsEncryptX1 = ` +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- +` + +// letsEncryptX2 is the ISRG Root X2. +// +// Subject: O = Internet Security Research Group, CN = ISRG Root X2 +// Key type: ECDSA P-384 +// Validity: until 2035-09-04 (generated 2020-09-04) +const letsEncryptX2 = ` +-----BEGIN CERTIFICATE----- +MIICGzCCAaGgAwIBAgIQQdKd0XLq7qeAwSxs6S+HUjAKBggqhkjOPQQDAzBPMQsw +CQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJuZXQgU2VjdXJpdHkgUmVzZWFyY2gg +R3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBYMjAeFw0yMDA5MDQwMDAwMDBaFw00 +MDA5MTcxNjAwMDBaME8xCzAJBgNVBAYTAlVTMSkwJwYDVQQKEyBJbnRlcm5ldCBT +ZWN1cml0eSBSZXNlYXJjaCBHcm91cDEVMBMGA1UEAxMMSVNSRyBSb290IFgyMHYw +EAYHKoZIzj0CAQYFK4EEACIDYgAEzZvVn4CDCuwJSvMWSj5cz3es3mcFDR0HttwW ++1qLFNvicWDEukWVEYmO6gbf9yoWHKS5xcUy4APgHoIYOIvXRdgKam7mAHf7AlF9 +ItgKbppbd9/w+kHsOdx1ymgHDB/qo0IwQDAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0T +AQH/BAUwAwEB/zAdBgNVHQ4EFgQUfEKWrt5LSDv6kviejM9ti6lyN5UwCgYIKoZI +zj0EAwMDaAAwZQIwe3lORlCEwkSHRhtFcP9Ymd70/aTSVaYgLXTWNLxBo1BfASdW +tL4ndQavEi51mI38AjEAi/V3bNTIZargCyzuFJ0nN6T5U6VR5CmD1/iQMVtCnwr1 +/q4AaOeMSQ+2b1tbFfLn +-----END CERTIFICATE----- +` diff --git a/net/bakedroots/bakedroots_test.go b/net/bakedroots/bakedroots_test.go new file mode 100644 index 000000000..8ba502a78 --- /dev/null +++ b/net/bakedroots/bakedroots_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package bakedroots + +import ( + "slices" + "testing" +) + +func TestBakedInRoots(t *testing.T) { + ResetForTest(t, nil) + p := Get() + got := p.Subjects() + if len(got) != 2 { + t.Errorf("subjects = %v; want 2", len(got)) + } + + // TODO(bradfitz): is there a way to easily make this test prettier without + // writing a DER decoder? I'm not seeing how. + var name []string + for _, der := range got { + name = append(name, string(der)) + } + want := []string{ + "0O1\v0\t\x06\x03U\x04\x06\x13\x02US1)0'\x06\x03U\x04\n\x13 Internet Security Research Group1\x150\x13\x06\x03U\x04\x03\x13\fISRG Root X1", + "0O1\v0\t\x06\x03U\x04\x06\x13\x02US1)0'\x06\x03U\x04\n\x13 Internet Security Research Group1\x150\x13\x06\x03U\x04\x03\x13\fISRG Root X2", + } + if !slices.Equal(name, want) { + t.Errorf("subjects = %q; want %q", name, want) + } +} diff --git a/net/batching/conn.go b/net/batching/conn.go new file mode 100644 index 000000000..77cdf8c84 --- /dev/null +++ b/net/batching/conn.go @@ -0,0 +1,47 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package batching implements a socket optimized for increased throughput. +package batching + +import ( + "net/netip" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" + "tailscale.com/net/packet" + "tailscale.com/types/nettype" +) + +var ( + // This acts as a compile-time check for our usage of ipv6.Message in + // [Conn] for both IPv6 and IPv4 operations. + _ ipv6.Message = ipv4.Message{} +) + +// Conn is a nettype.PacketConn that provides batched i/o using +// platform-specific optimizations, e.g. {recv,send}mmsg & UDP GSO/GRO. +// +// Conn originated from (and is still used by) magicsock where its API was +// strongly influenced by [wireguard-go/conn.Bind] constraints, namely +// wireguard-go's ownership of packet memory. +type Conn interface { + nettype.PacketConn + // ReadBatch reads messages from [Conn] into msgs. It returns the number of + // messages the caller should evaluate for nonzero len, as a zero len + // message may fall on either side of a nonzero. + // + // Each [ipv6.Message.OOB] must be sized to at least MinControlMessageSize(). + ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) + // WriteBatchTo writes buffs to addr. + // + // If geneve.VNI.IsSet(), then geneve is encoded into the space preceding + // offset, and offset must equal [packet.GeneveFixedHeaderLength]. If + // !geneve.VNI.IsSet() then the space preceding offset is ignored. + // + // len(buffs) must be <= batchSize supplied in TryUpgradeToConn(). + // + // WriteBatchTo may return a [neterror.ErrUDPGSODisabled] error if UDP GSO + // was disabled as a result of a send error. + WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error +} diff --git a/net/batching/conn_default.go b/net/batching/conn_default.go new file mode 100644 index 000000000..37d644f50 --- /dev/null +++ b/net/batching/conn_default.go @@ -0,0 +1,23 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package batching + +import ( + "tailscale.com/types/nettype" +) + +// TryUpgradeToConn is no-op on all platforms except linux. +func TryUpgradeToConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn { + return pconn +} + +var controlMessageSize = 0 + +func MinControlMessageSize() int { + return controlMessageSize +} + +const IdealBatchSize = 1 diff --git a/wgengine/magicsock/batching_conn_linux.go b/net/batching/conn_linux.go similarity index 86% rename from wgengine/magicsock/batching_conn_linux.go rename to net/batching/conn_linux.go index 25bf974b0..bd7ac25be 100644 --- a/wgengine/magicsock/batching_conn_linux.go +++ b/net/batching/conn_linux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package magicsock +package batching import ( "encoding/binary" @@ -22,6 +22,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/hostinfo" "tailscale.com/net/neterror" + "tailscale.com/net/packet" "tailscale.com/types/nettype" ) @@ -42,10 +43,15 @@ type xnetBatchWriter interface { WriteBatch([]ipv6.Message, int) (int, error) } +var ( + // [linuxBatchingConn] implements [Conn]. + _ Conn = (*linuxBatchingConn)(nil) +) + // linuxBatchingConn is a UDP socket that provides batched i/o. It implements -// batchingConn. +// [Conn]. type linuxBatchingConn struct { - pc nettype.PacketConn + pc *net.UDPConn xpc xnetBatchReaderWriter rxOffload bool // supports UDP GRO or similar txOffload atomic.Bool // supports UDP GSO or similar @@ -92,9 +98,13 @@ const ( maxIPv6PayloadLen = 1<<16 - 1 - 8 ) -// coalesceMessages iterates msgs, coalescing them where possible while -// maintaining datagram order. All msgs have their Addr field set to addr. -func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, msgs []ipv6.Message) int { +// coalesceMessages iterates 'buffs', setting and coalescing them in 'msgs' +// where possible while maintaining datagram order. +// +// All msgs have their Addr field set to addr. +// +// All msgs[i].Buffers[0] are preceded by a Geneve header (geneve) if geneve.VNI.IsSet(). +func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, geneve packet.GeneveHeader, buffs [][]byte, msgs []ipv6.Message, offset int) int { var ( base = -1 // index of msg we are currently coalescing into gsoSize int // segmentation size of msgs[base] @@ -105,7 +115,13 @@ func (c *linuxBatchingConn) coalesceMessages(addr *net.UDPAddr, buffs [][]byte, if addr.IP.To4() == nil { maxPayloadLen = maxIPv6PayloadLen } + vniIsSet := geneve.VNI.IsSet() for i, buff := range buffs { + if vniIsSet { + geneve.Encode(buff) + } else { + buff = buff[offset:] + } if i > 0 { msgLen := len(buff) baseLenBefore := len(msgs[base].Buffers[0]) @@ -162,7 +178,7 @@ func (c *linuxBatchingConn) putSendBatch(batch *sendBatch) { c.sendBatchPool.Put(batch) } -func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { +func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error { batch := c.getSendBatch() defer c.putSendBatch(batch) if addr.Addr().Is6() { @@ -181,10 +197,17 @@ func (c *linuxBatchingConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) er ) retry: if c.txOffload.Load() { - n = c.coalesceMessages(batch.ua, buffs, batch.msgs) + n = c.coalesceMessages(batch.ua, geneve, buffs, batch.msgs, offset) } else { + vniIsSet := geneve.VNI.IsSet() + if vniIsSet { + offset -= packet.GeneveFixedHeaderLength + } for i := range buffs { - batch.msgs[i].Buffers[0] = buffs[i] + if vniIsSet { + geneve.Encode(buffs[i]) + } + batch.msgs[i].Buffers[0] = buffs[i][offset:] batch.msgs[i].Addr = batch.ua batch.msgs[i].OOB = batch.msgs[i].OOB[:0] } @@ -204,11 +227,7 @@ retry: } func (c *linuxBatchingConn) SyscallConn() (syscall.RawConn, error) { - sc, ok := c.pc.(syscall.Conn) - if !ok { - return nil, errUnsupportedConnType - } - return sc.SyscallConn() + return c.pc.SyscallConn() } func (c *linuxBatchingConn) writeBatch(msgs []ipv6.Message) error { @@ -334,7 +353,7 @@ func getGSOSizeFromControl(control []byte) (int, error) { ) for len(rem) > unix.SizeofCmsghdr { - hdr, data, rem, err = unix.ParseOneSocketControlMessage(control) + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) if err != nil { return 0, fmt.Errorf("error parsing socket control message: %w", err) } @@ -364,9 +383,10 @@ func setGSOSizeInControl(control *[]byte, gsoSize uint16) { *control = (*control)[:unix.CmsgSpace(2)] } -// tryUpgradeToBatchingConn probes the capabilities of the OS and pconn, and -// upgrades pconn to a *linuxBatchingConn if appropriate. -func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { +// TryUpgradeToConn probes the capabilities of the OS and pconn, and upgrades +// pconn to a [Conn] if appropriate. A batch size of [IdealBatchSize] is +// suggested for the best performance. +func TryUpgradeToConn(pconn nettype.PacketConn, network string, batchSize int) nettype.PacketConn { if runtime.GOOS != "linux" { // Exclude Android. return pconn @@ -388,7 +408,7 @@ func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSiz return pconn } b := &linuxBatchingConn{ - pc: pconn, + pc: uc, getGSOSizeFromControl: getGSOSizeFromControl, setGSOSizeInControl: setGSOSizeInControl, sendBatchPool: sync.Pool{ @@ -422,3 +442,19 @@ func tryUpgradeToBatchingConn(pconn nettype.PacketConn, network string, batchSiz b.txOffload.Store(txOffload) return b } + +var controlMessageSize = -1 // bomb if used for allocation before init + +func init() { + // controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control + // message. These contain a single uint16 of data. + controlMessageSize = unix.CmsgSpace(2) +} + +// MinControlMessageSize returns the minimum control message size required to +// support read batching via [Conn.ReadBatch]. +func MinControlMessageSize() int { + return controlMessageSize +} + +const IdealBatchSize = 128 diff --git a/wgengine/magicsock/batching_conn_linux_test.go b/net/batching/conn_linux_test.go similarity index 55% rename from wgengine/magicsock/batching_conn_linux_test.go rename to net/batching/conn_linux_test.go index 5c22bf1c7..c2cc463eb 100644 --- a/wgengine/magicsock/batching_conn_linux_test.go +++ b/net/batching/conn_linux_test.go @@ -1,14 +1,18 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package magicsock +package batching import ( "encoding/binary" "net" "testing" + "unsafe" + "github.com/tailscale/wireguard-go/conn" "golang.org/x/net/ipv6" + "golang.org/x/sys/unix" + "tailscale.com/net/packet" ) func setGSOSize(control *[]byte, gsoSize uint16) { @@ -154,58 +158,119 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { getGSOSizeFromControl: getGSOSize, } + withGeneveSpace := func(len, cap int) []byte { + return make([]byte, len+packet.GeneveFixedHeaderLength, cap+packet.GeneveFixedHeaderLength) + } + + geneve := packet.GeneveHeader{ + Protocol: packet.GeneveProtocolWireGuard, + } + geneve.VNI.Set(1) + cases := []struct { name string buffs [][]byte + geneve packet.GeneveHeader wantLens []int wantGSO []int }{ { name: "one message no coalesce", buffs: [][]byte{ - make([]byte, 1, 1), + withGeneveSpace(1, 1), }, wantLens: []int{1}, wantGSO: []int{0}, }, + { + name: "one message no coalesce vni.isSet", + buffs: [][]byte{ + withGeneveSpace(1, 1), + }, + geneve: geneve, + wantLens: []int{1 + packet.GeneveFixedHeaderLength}, + wantGSO: []int{0}, + }, { name: "two messages equal len coalesce", buffs: [][]byte{ - make([]byte, 1, 2), - make([]byte, 1, 1), + withGeneveSpace(1, 2), + withGeneveSpace(1, 1), }, wantLens: []int{2}, wantGSO: []int{1}, }, + { + name: "two messages equal len coalesce vni.isSet", + buffs: [][]byte{ + withGeneveSpace(1, 2+packet.GeneveFixedHeaderLength), + withGeneveSpace(1, 1), + }, + geneve: geneve, + wantLens: []int{2 + (2 * packet.GeneveFixedHeaderLength)}, + wantGSO: []int{1 + packet.GeneveFixedHeaderLength}, + }, { name: "two messages unequal len coalesce", buffs: [][]byte{ - make([]byte, 2, 3), - make([]byte, 1, 1), + withGeneveSpace(2, 3), + withGeneveSpace(1, 1), }, wantLens: []int{3}, wantGSO: []int{2}, }, + { + name: "two messages unequal len coalesce vni.isSet", + buffs: [][]byte{ + withGeneveSpace(2, 3+packet.GeneveFixedHeaderLength), + withGeneveSpace(1, 1), + }, + geneve: geneve, + wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength)}, + wantGSO: []int{2 + packet.GeneveFixedHeaderLength}, + }, { name: "three messages second unequal len coalesce", buffs: [][]byte{ - make([]byte, 2, 3), - make([]byte, 1, 1), - make([]byte, 2, 2), + withGeneveSpace(2, 3), + withGeneveSpace(1, 1), + withGeneveSpace(2, 2), }, wantLens: []int{3, 2}, wantGSO: []int{2, 0}, }, + { + name: "three messages second unequal len coalesce vni.isSet", + buffs: [][]byte{ + withGeneveSpace(2, 3+(2*packet.GeneveFixedHeaderLength)), + withGeneveSpace(1, 1), + withGeneveSpace(2, 2), + }, + geneve: geneve, + wantLens: []int{3 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength}, + wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0}, + }, { name: "three messages limited cap coalesce", buffs: [][]byte{ - make([]byte, 2, 4), - make([]byte, 2, 2), - make([]byte, 2, 2), + withGeneveSpace(2, 4), + withGeneveSpace(2, 2), + withGeneveSpace(2, 2), }, wantLens: []int{4, 2}, wantGSO: []int{2, 0}, }, + { + name: "three messages limited cap coalesce vni.isSet", + buffs: [][]byte{ + withGeneveSpace(2, 4+packet.GeneveFixedHeaderLength), + withGeneveSpace(2, 2), + withGeneveSpace(2, 2), + }, + geneve: geneve, + wantLens: []int{4 + (2 * packet.GeneveFixedHeaderLength), 2 + packet.GeneveFixedHeaderLength}, + wantGSO: []int{2 + packet.GeneveFixedHeaderLength, 0}, + }, } for _, tt := range cases { @@ -219,7 +284,7 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { msgs[i].Buffers = make([][]byte, 1) msgs[i].OOB = make([]byte, 0, 2) } - got := c.coalesceMessages(addr, tt.buffs, msgs) + got := c.coalesceMessages(addr, tt.geneve, tt.buffs, msgs, packet.GeneveFixedHeaderLength) if got != len(tt.wantLens) { t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) } @@ -242,3 +307,45 @@ func Test_linuxBatchingConn_coalesceMessages(t *testing.T) { }) } } + +func TestMinReadBatchMsgsLen(t *testing.T) { + // So long as magicsock uses [Conn], and [wireguard-go/conn.Bind] API is + // shaped for wireguard-go to control packet memory, these values should be + // aligned. + if IdealBatchSize != conn.IdealBatchSize { + t.Fatalf("IdealBatchSize: %d != conn.IdealBatchSize(): %d", IdealBatchSize, conn.IdealBatchSize) + } +} + +func Test_getGSOSizeFromControl_MultipleMessages(t *testing.T) { + // Test that getGSOSizeFromControl correctly parses UDP_GRO when it's not the first control message. + const expectedGSOSize = 1420 + + // First message: IP_TOS + firstMsgLen := unix.CmsgSpace(1) + firstMsg := make([]byte, firstMsgLen) + hdr1 := (*unix.Cmsghdr)(unsafe.Pointer(&firstMsg[0])) + hdr1.Level = unix.SOL_IP + hdr1.Type = unix.IP_TOS + hdr1.SetLen(unix.CmsgLen(1)) + firstMsg[unix.SizeofCmsghdr] = 0 + + // Second message: UDP_GRO + secondMsgLen := unix.CmsgSpace(2) + secondMsg := make([]byte, secondMsgLen) + hdr2 := (*unix.Cmsghdr)(unsafe.Pointer(&secondMsg[0])) + hdr2.Level = unix.SOL_UDP + hdr2.Type = unix.UDP_GRO + hdr2.SetLen(unix.CmsgLen(2)) + binary.NativeEndian.PutUint16(secondMsg[unix.SizeofCmsghdr:], expectedGSOSize) + + control := append(firstMsg, secondMsg...) + + gsoSize, err := getGSOSizeFromControl(control) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gsoSize != expectedGSOSize { + t.Errorf("got GSO size %d, want %d", gsoSize, expectedGSOSize) + } +} diff --git a/net/captivedetection/captivedetection.go b/net/captivedetection/captivedetection.go index c6e8bca3a..3ec820b79 100644 --- a/net/captivedetection/captivedetection.go +++ b/net/captivedetection/captivedetection.go @@ -11,18 +11,21 @@ import ( "net" "net/http" "runtime" + "strconv" "strings" "sync" "syscall" "time" "tailscale.com/net/netmon" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/logger" ) // Detector checks whether the system is behind a captive portal. type Detector struct { + clock func() time.Time // httpClient is the HTTP client that is used for captive portal detection. It is configured // to not follow redirects, have a short timeout and no keep-alive. @@ -30,7 +33,7 @@ type Detector struct { // currIfIndex is the index of the interface that is currently being used by the httpClient. currIfIndex int // mu guards currIfIndex. - mu sync.Mutex + mu syncs.Mutex // logf is the logger used for logging messages. If it is nil, log.Printf is used. logf logger.Logf } @@ -52,6 +55,13 @@ func NewDetector(logf logger.Logf) *Detector { return d } +func (d *Detector) Now() time.Time { + if d.clock != nil { + return d.clock() + } + return time.Now() +} + // Timeout is the timeout for captive portal detection requests. Because the captive portal intercepting our requests // is usually located on the LAN, this is a relatively short timeout. const Timeout = 3 * time.Second @@ -136,26 +146,31 @@ func interfaceNameDoesNotNeedCaptiveDetection(ifName string, goos string) bool { func (d *Detector) detectOnInterface(ctx context.Context, ifIndex int, endpoints []Endpoint) bool { defer d.httpClient.CloseIdleConnections() - d.logf("[v2] %d available captive portal detection endpoints: %v", len(endpoints), endpoints) + use := min(len(endpoints), 5) + endpoints = endpoints[:use] + d.logf("[v2] %d available captive portal detection endpoints; trying %v", len(endpoints), use) // We try to detect the captive portal more quickly by making requests to multiple endpoints concurrently. var wg sync.WaitGroup resultCh := make(chan bool, len(endpoints)) - for i, e := range endpoints { - if i >= 5 { - // Try a maximum of 5 endpoints, break out (returning false) if we run of attempts. - break - } + // Once any goroutine detects a captive portal, we shut down the others. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, e := range endpoints { wg.Add(1) go func(endpoint Endpoint) { defer wg.Done() found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, ifIndex) if err != nil { - d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + if ctx.Err() == nil { + d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + } return } if found { + cancel() // one match is good enough resultCh <- true } }(e) @@ -182,10 +197,16 @@ func (d *Detector) verifyCaptivePortalEndpoint(ctx context.Context, e Endpoint, ctx, cancel := context.WithTimeout(ctx, Timeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", e.URL.String(), nil) + u := *e.URL + v := u.Query() + v.Add("t", strconv.Itoa(int(d.Now().Unix()))) + u.RawQuery = v.Encode() + + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { return false, err } + req.Header.Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform, max-age=0") // Attach the Tailscale challenge header if the endpoint supports it. Not all captive portal detection endpoints // support this, so we only attach it if the endpoint does. diff --git a/net/captivedetection/captivedetection_test.go b/net/captivedetection/captivedetection_test.go index e74273afd..0778e07df 100644 --- a/net/captivedetection/captivedetection_test.go +++ b/net/captivedetection/captivedetection_test.go @@ -5,12 +5,21 @@ package captivedetection import ( "context" + "net/http" + "net/http/httptest" + "net/url" "runtime" + "strconv" "sync" + "sync/atomic" "testing" + "time" - "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/derp/derpserver" "tailscale.com/net/netmon" + "tailscale.com/syncs" + "tailscale.com/tstest/nettest" + "tailscale.com/util/must" ) func TestAvailableEndpointsAlwaysAtLeastTwo(t *testing.T) { @@ -36,25 +45,110 @@ func TestDetectCaptivePortalReturnsFalse(t *testing.T) { } } -func TestAllEndpointsAreUpAndReturnExpectedResponse(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13019") +func TestEndpointsAreUpAndReturnExpectedResponse(t *testing.T) { + nettest.SkipIfNoNetwork(t) + d := NewDetector(t.Logf) endpoints := availableEndpoints(nil, 0, t.Logf, runtime.GOOS) + t.Logf("testing %d endpoints", len(endpoints)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var good atomic.Bool var wg sync.WaitGroup + sem := syncs.NewSemaphore(5) for _, e := range endpoints { wg.Add(1) go func(endpoint Endpoint) { defer wg.Done() - found, err := d.verifyCaptivePortalEndpoint(context.Background(), endpoint, 0) - if err != nil { - t.Errorf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + + if !sem.AcquireContext(ctx) { + return + } + defer sem.Release() + + found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, 0) + if err != nil && ctx.Err() == nil { + t.Logf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) } if found { - t.Errorf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint) + t.Logf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint) + return } + good.Store(true) + t.Logf("endpoint good: %v", endpoint) + cancel() }(e) } wg.Wait() + + if !good.Load() { + t.Errorf("no good endpoints found") + } +} + +func TestCaptivePortalRequest(t *testing.T) { + d := NewDetector(t.Logf) + now := time.Now() + d.clock = func() time.Time { return now } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("expected GET, got %q", r.Method) + } + if r.URL.Path != "/generate_204" { + t.Errorf("expected /generate_204, got %q", r.URL.Path) + } + q := r.URL.Query() + if got, want := q.Get("t"), strconv.Itoa(int(now.Unix())); got != want { + t.Errorf("timestamp param; got %v, want %v", got, want) + } + w.Header().Set("X-Tailscale-Response", "response "+r.Header.Get("X-Tailscale-Challenge")) + + w.WriteHeader(http.StatusNoContent) + })) + defer s.Close() + + e := Endpoint{ + URL: must.Get(url.Parse(s.URL + "/generate_204")), + StatusCode: 204, + ExpectedContent: "", + SupportsTailscaleChallenge: true, + } + + found, err := d.verifyCaptivePortalEndpoint(ctx, e, 0) + if err != nil { + t.Fatalf("verifyCaptivePortalEndpoint = %v, %v", found, err) + } + if found { + t.Errorf("verifyCaptivePortalEndpoint = %v, want false", found) + } +} + +func TestAgainstDERPHandler(t *testing.T) { + d := NewDetector(t.Logf) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := httptest.NewServer(http.HandlerFunc(derpserver.ServeNoContent)) + defer s.Close() + e := Endpoint{ + URL: must.Get(url.Parse(s.URL + "/generate_204")), + StatusCode: 204, + ExpectedContent: "", + SupportsTailscaleChallenge: true, + } + found, err := d.verifyCaptivePortalEndpoint(ctx, e, 0) + if err != nil { + t.Fatalf("verifyCaptivePortalEndpoint = %v, %v", found, err) + } + if found { + t.Errorf("verifyCaptivePortalEndpoint = %v, want false", found) + } } diff --git a/net/captivedetection/endpoints.go b/net/captivedetection/endpoints.go index 450ed4a1c..57b3e5335 100644 --- a/net/captivedetection/endpoints.go +++ b/net/captivedetection/endpoints.go @@ -89,7 +89,7 @@ func availableEndpoints(derpMap *tailcfg.DERPMap, preferredDERPRegionID int, log // Use the DERP IPs as captive portal detection endpoints. Using IPs is better than hostnames // because they do not depend on DNS resolution. for _, region := range derpMap.Regions { - if region.Avoid { + if region.Avoid || region.NoMeasureNoHome { continue } for _, node := range region.Nodes { diff --git a/net/connectproxy/connectproxy.go b/net/connectproxy/connectproxy.go new file mode 100644 index 000000000..4bf687502 --- /dev/null +++ b/net/connectproxy/connectproxy.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package connectproxy contains some CONNECT proxy code. +package connectproxy + +import ( + "context" + "io" + "log" + "net" + "net/http" + "time" + + "tailscale.com/net/netx" + "tailscale.com/types/logger" +) + +// Handler is an HTTP CONNECT proxy handler. +type Handler struct { + // Dial, if non-nil, is an alternate dialer to use + // instead of the default dialer. + Dial netx.DialFunc + + // Logf, if non-nil, is an alterate logger to + // use instead of log.Printf. + Logf logger.Logf + + // Check, if non-nil, validates the CONNECT target. + Check func(hostPort string) error +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if r.Method != "CONNECT" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + dial := h.Dial + if dial == nil { + var d net.Dialer + dial = d.DialContext + } + logf := h.Logf + if logf == nil { + logf = log.Printf + } + + hostPort := r.RequestURI + if h.Check != nil { + if err := h.Check(hostPort); err != nil { + logf("CONNECT target %q not allowed: %v", hostPort, err) + http.Error(w, "Invalid CONNECT target", http.StatusForbidden) + return + } + } + + ctx, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + back, err := dial(ctx, "tcp", hostPort) + if err != nil { + logf("error CONNECT dialing %v: %v", hostPort, err) + http.Error(w, "Connect failure", http.StatusBadGateway) + return + } + defer back.Close() + + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError) + return + } + c, br, err := hj.Hijack() + if err != nil { + logf("CONNECT hijack: %v", err) + return + } + defer c.Close() + + io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n") + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(c, back) + errc <- err + }() + go func() { + _, err := io.Copy(back, br) + errc <- err + }() + <-errc +} diff --git a/net/connstats/stats.go b/net/connstats/stats.go deleted file mode 100644 index dbcd946b8..000000000 --- a/net/connstats/stats.go +++ /dev/null @@ -1,222 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package connstats maintains statistics about connections -// flowing through a TUN device (which operate at the IP layer). -package connstats - -import ( - "context" - "net/netip" - "sync" - "time" - - "golang.org/x/sync/errgroup" - "tailscale.com/net/packet" - "tailscale.com/net/tsaddr" - "tailscale.com/types/netlogtype" -) - -// Statistics maintains counters for every connection. -// All methods are safe for concurrent use. -// The zero value is ready for use. -type Statistics struct { - maxConns int // immutable once set - - mu sync.Mutex - connCnts - - connCntsCh chan connCnts - shutdownCtx context.Context - shutdown context.CancelFunc - group errgroup.Group -} - -type connCnts struct { - start time.Time - end time.Time - virtual map[netlogtype.Connection]netlogtype.Counts - physical map[netlogtype.Connection]netlogtype.Counts -} - -// NewStatistics creates a data structure for tracking connection statistics -// that periodically dumps the virtual and physical connection counts -// depending on whether the maxPeriod or maxConns is exceeded. -// The dump function is called from a single goroutine. -// Shutdown must be called to cleanup resources. -func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics { - s := &Statistics{maxConns: maxConns} - s.connCntsCh = make(chan connCnts, 256) - s.shutdownCtx, s.shutdown = context.WithCancel(context.Background()) - s.group.Go(func() error { - // TODO(joetsai): Using a ticker is problematic on mobile platforms - // where waking up a process every maxPeriod when there is no activity - // is a drain on battery life. Switch this instead to instead use - // a time.Timer that is triggered upon network activity. - ticker := new(time.Ticker) - if maxPeriod > 0 { - ticker = time.NewTicker(maxPeriod) - defer ticker.Stop() - } - - for { - var cc connCnts - select { - case cc = <-s.connCntsCh: - case <-ticker.C: - cc = s.extract() - case <-s.shutdownCtx.Done(): - cc = s.extract() - } - if len(cc.virtual)+len(cc.physical) > 0 && dump != nil { - dump(cc.start, cc.end, cc.virtual, cc.physical) - } - if s.shutdownCtx.Err() != nil { - return nil - } - } - }) - return s -} - -// UpdateTxVirtual updates the counters for a transmitted IP packet -// The source and destination of the packet directly correspond with -// the source and destination in netlogtype.Connection. -func (s *Statistics) UpdateTxVirtual(b []byte) { - s.updateVirtual(b, false) -} - -// UpdateRxVirtual updates the counters for a received IP packet. -// The source and destination of the packet are inverted with respect to -// the source and destination in netlogtype.Connection. -func (s *Statistics) UpdateRxVirtual(b []byte) { - s.updateVirtual(b, true) -} - -var ( - tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP() - tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6() -) - -func (s *Statistics) updateVirtual(b []byte, receive bool) { - var p packet.Parsed - p.Decode(b) - conn := netlogtype.Connection{Proto: p.IPProto, Src: p.Src, Dst: p.Dst} - if receive { - conn.Src, conn.Dst = conn.Dst, conn.Src - } - - // Network logging is defined as traffic between two Tailscale nodes. - // Traffic with the internal Tailscale service is not with another node - // and should not be logged. It also happens to be a high volume - // amount of discrete traffic flows (e.g., DNS lookups). - switch conn.Dst.Addr() { - case tailscaleServiceIPv4, tailscaleServiceIPv6: - return - } - - s.mu.Lock() - defer s.mu.Unlock() - cnts, found := s.virtual[conn] - if !found && !s.preInsertConn() { - return - } - if receive { - cnts.RxPackets++ - cnts.RxBytes += uint64(len(b)) - } else { - cnts.TxPackets++ - cnts.TxBytes += uint64(len(b)) - } - s.virtual[conn] = cnts -} - -// UpdateTxPhysical updates the counters for a transmitted wireguard packet -// The src is always a Tailscale IP address, representing some remote peer. -// The dst is a remote IP address and port that corresponds -// with some physical peer backing the Tailscale IP address. -func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, n int) { - s.updatePhysical(src, dst, n, false) -} - -// UpdateRxPhysical updates the counters for a received wireguard packet. -// The src is always a Tailscale IP address, representing some remote peer. -// The dst is a remote IP address and port that corresponds -// with some physical peer backing the Tailscale IP address. -func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, n int) { - s.updatePhysical(src, dst, n, true) -} - -func (s *Statistics) updatePhysical(src netip.Addr, dst netip.AddrPort, n int, receive bool) { - conn := netlogtype.Connection{Src: netip.AddrPortFrom(src, 0), Dst: dst} - - s.mu.Lock() - defer s.mu.Unlock() - cnts, found := s.physical[conn] - if !found && !s.preInsertConn() { - return - } - if receive { - cnts.RxPackets++ - cnts.RxBytes += uint64(n) - } else { - cnts.TxPackets++ - cnts.TxBytes += uint64(n) - } - s.physical[conn] = cnts -} - -// preInsertConn updates the maps to handle insertion of a new connection. -// It reports false if insertion is not allowed (i.e., after shutdown). -func (s *Statistics) preInsertConn() bool { - // Check whether insertion of a new connection will exceed maxConns. - if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 { - // Extract the current statistics and send it to the serializer. - // Avoid blocking the network packet handling path. - select { - case s.connCntsCh <- s.extractLocked(): - default: - // TODO(joetsai): Log that we are dropping an entire connCounts. - } - } - - // Initialize the maps if nil. - if s.virtual == nil && s.physical == nil { - s.start = time.Now().UTC() - s.virtual = make(map[netlogtype.Connection]netlogtype.Counts) - s.physical = make(map[netlogtype.Connection]netlogtype.Counts) - } - - return s.shutdownCtx.Err() == nil -} - -func (s *Statistics) extract() connCnts { - s.mu.Lock() - defer s.mu.Unlock() - return s.extractLocked() -} - -func (s *Statistics) extractLocked() connCnts { - if len(s.virtual)+len(s.physical) == 0 { - return connCnts{} - } - s.end = time.Now().UTC() - cc := s.connCnts - s.connCnts = connCnts{} - return cc -} - -// TestExtract synchronously extracts the current network statistics map -// and resets the counters. This should only be used for testing purposes. -func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) { - cc := s.extract() - return cc.virtual, cc.physical -} - -// Shutdown performs a final flush of statistics. -// Statistics for any subsequent calls to Update will be dropped. -// It is safe to call Shutdown concurrently and repeatedly. -func (s *Statistics) Shutdown(context.Context) error { - s.shutdown() - return s.group.Wait() -} diff --git a/net/connstats/stats_test.go b/net/connstats/stats_test.go deleted file mode 100644 index ae0bca8a5..000000000 --- a/net/connstats/stats_test.go +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package connstats - -import ( - "context" - "encoding/binary" - "fmt" - "math/rand" - "net/netip" - "runtime" - "sync" - "testing" - "time" - - qt "github.com/frankban/quicktest" - "tailscale.com/cmd/testwrapper/flakytest" - "tailscale.com/types/ipproto" - "tailscale.com/types/netlogtype" -) - -func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPort, size uint16) (out []byte) { - var ipHdr [20]byte - ipHdr[0] = 4<<4 | 5 - binary.BigEndian.PutUint16(ipHdr[2:], size) - ipHdr[9] = byte(proto) - *(*[4]byte)(ipHdr[12:]) = srcAddr - *(*[4]byte)(ipHdr[16:]) = dstAddr - out = append(out, ipHdr[:]...) - switch proto { - case ipproto.TCP: - var tcpHdr [20]byte - binary.BigEndian.PutUint16(tcpHdr[0:], srcPort) - binary.BigEndian.PutUint16(tcpHdr[2:], dstPort) - out = append(out, tcpHdr[:]...) - case ipproto.UDP: - var udpHdr [8]byte - binary.BigEndian.PutUint16(udpHdr[0:], srcPort) - binary.BigEndian.PutUint16(udpHdr[2:], dstPort) - out = append(out, udpHdr[:]...) - default: - panic(fmt.Sprintf("unknown proto: %d", proto)) - } - return append(out, make([]byte, int(size)-len(out))...) -} - -// TestInterval ensures that we receive at least one call to `dump` using only -// maxPeriod. -func TestInterval(t *testing.T) { - c := qt.New(t) - - const maxPeriod = 10 * time.Millisecond - const maxConns = 2048 - - gotDump := make(chan struct{}, 1) - stats := NewStatistics(maxPeriod, maxConns, func(_, _ time.Time, _, _ map[netlogtype.Connection]netlogtype.Counts) { - select { - case gotDump <- struct{}{}: - default: - } - }) - defer stats.Shutdown(context.Background()) - - srcAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - dstAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - srcPort := uint16(rand.Intn(16)) - dstPort := uint16(rand.Intn(16)) - size := uint16(64 + rand.Intn(1024)) - p := testPacketV4(ipproto.TCP, srcAddr.As4(), dstAddr.As4(), srcPort, dstPort, size) - stats.UpdateRxVirtual(p) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - select { - case <-ctx.Done(): - c.Fatal("didn't receive dump within context deadline") - case <-gotDump: - } -} - -func TestConcurrent(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7030") - c := qt.New(t) - - const maxPeriod = 10 * time.Millisecond - const maxConns = 10 - virtualAggregate := make(map[netlogtype.Connection]netlogtype.Counts) - stats := NewStatistics(maxPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) { - c.Assert(start.IsZero(), qt.IsFalse) - c.Assert(end.IsZero(), qt.IsFalse) - c.Assert(end.Before(start), qt.IsFalse) - c.Assert(len(virtual) > 0 && len(virtual) <= maxConns, qt.IsTrue) - c.Assert(len(physical) == 0, qt.IsTrue) - for conn, cnts := range virtual { - virtualAggregate[conn] = virtualAggregate[conn].Add(cnts) - } - }) - defer stats.Shutdown(context.Background()) - var wants []map[netlogtype.Connection]netlogtype.Counts - gots := make([]map[netlogtype.Connection]netlogtype.Counts, runtime.NumCPU()) - var group sync.WaitGroup - for i := range gots { - group.Add(1) - go func(i int) { - defer group.Done() - gots[i] = make(map[netlogtype.Connection]netlogtype.Counts) - rn := rand.New(rand.NewSource(time.Now().UnixNano())) - var p []byte - var t netlogtype.Connection - for j := 0; j < 1000; j++ { - delay := rn.Intn(10000) - if p == nil || rn.Intn(64) == 0 { - proto := ipproto.TCP - if rn.Intn(2) == 0 { - proto = ipproto.UDP - } - srcAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - dstAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))}) - srcPort := uint16(rand.Intn(16)) - dstPort := uint16(rand.Intn(16)) - size := uint16(64 + rand.Intn(1024)) - p = testPacketV4(proto, srcAddr.As4(), dstAddr.As4(), srcPort, dstPort, size) - t = netlogtype.Connection{Proto: proto, Src: netip.AddrPortFrom(srcAddr, srcPort), Dst: netip.AddrPortFrom(dstAddr, dstPort)} - } - t2 := t - receive := rn.Intn(2) == 0 - if receive { - t2.Src, t2.Dst = t2.Dst, t2.Src - } - - cnts := gots[i][t2] - if receive { - stats.UpdateRxVirtual(p) - cnts.RxPackets++ - cnts.RxBytes += uint64(len(p)) - } else { - cnts.TxPackets++ - cnts.TxBytes += uint64(len(p)) - stats.UpdateTxVirtual(p) - } - gots[i][t2] = cnts - time.Sleep(time.Duration(rn.Intn(1 + delay))) - } - }(i) - } - group.Wait() - c.Assert(stats.Shutdown(context.Background()), qt.IsNil) - wants = append(wants, virtualAggregate) - - got := make(map[netlogtype.Connection]netlogtype.Counts) - want := make(map[netlogtype.Connection]netlogtype.Counts) - mergeMaps(got, gots...) - mergeMaps(want, wants...) - c.Assert(got, qt.DeepEquals, want) -} - -func mergeMaps(dst map[netlogtype.Connection]netlogtype.Counts, srcs ...map[netlogtype.Connection]netlogtype.Counts) { - for _, src := range srcs { - for conn, cnts := range src { - dst[conn] = dst[conn].Add(cnts) - } - } -} - -func Benchmark(b *testing.B) { - // TODO: Test IPv6 packets? - b.Run("SingleRoutine/SameConn", func(b *testing.B) { - p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789) - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := NewStatistics(0, 0, nil) - for j := 0; j < 1e3; j++ { - s.UpdateTxVirtual(p) - } - } - }) - b.Run("SingleRoutine/UniqueConns", func(b *testing.B) { - p := testPacketV4(ipproto.UDP, [4]byte{}, [4]byte{}, 0, 0, 789) - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := NewStatistics(0, 0, nil) - for j := 0; j < 1e3; j++ { - binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination - s.UpdateTxVirtual(p) - } - } - }) - b.Run("MultiRoutine/SameConn", func(b *testing.B) { - p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789) - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := NewStatistics(0, 0, nil) - var group sync.WaitGroup - for j := 0; j < runtime.NumCPU(); j++ { - group.Add(1) - go func() { - defer group.Done() - for k := 0; k < 1e3; k++ { - s.UpdateTxVirtual(p) - } - }() - } - group.Wait() - } - }) - b.Run("MultiRoutine/UniqueConns", func(b *testing.B) { - ps := make([][]byte, runtime.NumCPU()) - for i := range ps { - ps[i] = testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 0, 0, 789) - } - b.ResetTimer() - b.ReportAllocs() - for range b.N { - s := NewStatistics(0, 0, nil) - var group sync.WaitGroup - for j := 0; j < runtime.NumCPU(); j++ { - group.Add(1) - go func(j int) { - defer group.Done() - p := ps[j] - j *= 1e3 - for k := 0; k < 1e3; k++ { - binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination - s.UpdateTxVirtual(p) - } - }(j) - } - group.Wait() - } - }) -} diff --git a/net/dns/config.go b/net/dns/config.go index 67d3d753c..2425b304d 100644 --- a/net/dns/config.go +++ b/net/dns/config.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:generate go run tailscale.com/cmd/viewer --type=Config --clonefunc + // Package dns contains code to configure and manage DNS settings. package dns @@ -8,8 +10,12 @@ import ( "bufio" "fmt" "net/netip" + "reflect" + "slices" "sort" + "tailscale.com/control/controlknobs" + "tailscale.com/envknob" "tailscale.com/net/dns/publicdns" "tailscale.com/net/dns/resolver" "tailscale.com/net/tsaddr" @@ -47,11 +53,28 @@ type Config struct { OnlyIPv6 bool } -func (c *Config) serviceIP() netip.Addr { +var magicDNSDualStack = envknob.RegisterBool("TS_DEBUG_MAGIC_DNS_DUAL_STACK") + +// serviceIPs returns the list of service IPs where MagicDNS is reachable. +// +// The provided knobs may be nil. +func (c *Config) serviceIPs(knobs *controlknobs.Knobs) []netip.Addr { if c.OnlyIPv6 { - return tsaddr.TailscaleServiceIPv6() + return []netip.Addr{tsaddr.TailscaleServiceIPv6()} } - return tsaddr.TailscaleServiceIP() + + // TODO(bradfitz,mikeodr,raggi): include IPv6 here too; tailscale/tailscale#15404 + // And add a controlknobs knob to disable dual stack. + // + // For now, opt-in for testing. + if magicDNSDualStack() { + return []netip.Addr{ + tsaddr.TailscaleServiceIP(), + tsaddr.TailscaleServiceIPv6(), + } + } + + return []netip.Addr{tsaddr.TailscaleServiceIP()} } // WriteToBufioWriter write a debug version of c for logs to w, omitting @@ -162,21 +185,16 @@ func sameResolverNames(a, b []*dnstype.Resolver) bool { if a[i].Addr != b[i].Addr { return false } - if !sameIPs(a[i].BootstrapResolution, b[i].BootstrapResolution) { + if !slices.Equal(a[i].BootstrapResolution, b[i].BootstrapResolution) { return false } } return true } -func sameIPs(a, b []netip.Addr) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } +func (c *Config) Equal(o *Config) bool { + if c == nil || o == nil { + return c == o } - return true + return reflect.DeepEqual(c, o) } diff --git a/net/dns/dbus.go b/net/dns/dbus.go new file mode 100644 index 000000000..c53e8b720 --- /dev/null +++ b/net/dns/dbus.go @@ -0,0 +1,59 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !android && !ts_omit_dbus + +package dns + +import ( + "context" + "time" + + "github.com/godbus/dbus/v5" +) + +func init() { + optDBusPing.Set(dbusPing) + optDBusReadString.Set(dbusReadString) +} + +func dbusPing(name, objectPath string) error { + conn, err := dbus.SystemBus() + if err != nil { + // DBus probably not running. + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + obj := conn.Object(name, dbus.ObjectPath(objectPath)) + call := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0) + return call.Err +} + +// dbusReadString reads a string property from the provided name and object +// path. property must be in "interface.member" notation. +func dbusReadString(name, objectPath, iface, member string) (string, error) { + conn, err := dbus.SystemBus() + if err != nil { + // DBus probably not running. + return "", err + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + obj := conn.Object(name, dbus.ObjectPath(objectPath)) + + var result dbus.Variant + err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, iface, member).Store(&result) + if err != nil { + return "", err + } + + if s, ok := result.Value().(string); ok { + return s, nil + } + return result.String(), nil +} diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go index 3ffc796e0..63fd80c12 100644 --- a/net/dns/debian_resolvconf.go +++ b/net/dns/debian_resolvconf.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || freebsd || openbsd +//go:build (linux && !android) || freebsd || openbsd package dns diff --git a/net/dns/direct.go b/net/dns/direct.go index aaff18fcb..59eb06964 100644 --- a/net/dns/direct.go +++ b/net/dns/direct.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !android && !ios + package dns import ( @@ -21,6 +23,7 @@ import ( "sync" "time" + "tailscale.com/feature" "tailscale.com/health" "tailscale.com/net/dns/resolvconffile" "tailscale.com/net/tsaddr" @@ -413,6 +416,73 @@ func (m *directManager) GetBaseConfig() (OSConfig, error) { return oscfg, nil } +// HookWatchFile is a hook for watching file changes, for platforms that support it. +// The function is called with a directory and filename to watch, and a callback +// to call when the file changes. It returns an error if the watch could not be set up. +var HookWatchFile feature.Hook[func(ctx context.Context, dir, filename string, cb func()) error] + +func (m *directManager) runFileWatcher() { + watchFile, ok := HookWatchFile.GetOk() + if !ok { + return + } + if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil { + // This is all best effort for now, so surface warnings to users. + m.logf("dns: inotify: %s", err) + } +} + +var resolvTrampleWarnable = health.Register(&health.Warnable{ + Code: "resolv-conf-overwritten", + Severity: health.SeverityMedium, + Title: "DNS configuration issue", + Text: health.StaticMessage("System DNS config not ideal. /etc/resolv.conf overwritten. See https://tailscale.com/s/dns-fight"), +}) + +// checkForFileTrample checks whether /etc/resolv.conf has been trampled +// by another program on the system. (e.g. a DHCP client) +func (m *directManager) checkForFileTrample() { + m.mu.Lock() + want := m.wantResolvConf + lastWarn := m.lastWarnContents + m.mu.Unlock() + + if want == nil { + return + } + + cur, err := m.fs.ReadFile(resolvConf) + if err != nil { + m.logf("trample: read error: %v", err) + return + } + if bytes.Equal(cur, want) { + m.health.SetHealthy(resolvTrampleWarnable) + if lastWarn != nil { + m.mu.Lock() + m.lastWarnContents = nil + m.mu.Unlock() + m.logf("trample: resolv.conf again matches expected content") + } + return + } + if bytes.Equal(cur, lastWarn) { + // We already logged about this, so not worth doing it again. + return + } + + m.mu.Lock() + m.lastWarnContents = cur + m.mu.Unlock() + + show := cur + if len(show) > 1024 { + show = show[:1024] + } + m.logf("trample: resolv.conf changed from what we expected. did some other program interfere? current contents: %q", show) + m.health.SetUnhealthy(resolvTrampleWarnable, nil) +} + func (m *directManager) Close() error { m.ctxClose() diff --git a/net/dns/direct_linux.go b/net/dns/direct_linux.go deleted file mode 100644 index bdeefb352..000000000 --- a/net/dns/direct_linux.go +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dns - -import ( - "bytes" - "context" - - "github.com/illarion/gonotify/v2" - "tailscale.com/health" -) - -func (m *directManager) runFileWatcher() { - ctx, cancel := context.WithCancel(m.ctx) - defer cancel() - in, err := gonotify.NewInotify(ctx) - if err != nil { - // Oh well, we tried. This is all best effort for now, to - // surface warnings to users. - m.logf("dns: inotify new: %v", err) - return - } - - const events = gonotify.IN_ATTRIB | - gonotify.IN_CLOSE_WRITE | - gonotify.IN_CREATE | - gonotify.IN_DELETE | - gonotify.IN_MODIFY | - gonotify.IN_MOVE - - if err := in.AddWatch("/etc/", events); err != nil { - m.logf("dns: inotify addwatch: %v", err) - return - } - for { - events, err := in.Read() - if ctx.Err() != nil { - return - } - if err != nil { - m.logf("dns: inotify read: %v", err) - return - } - var match bool - for _, ev := range events { - if ev.Name == resolvConf { - match = true - break - } - } - if !match { - continue - } - m.checkForFileTrample() - } -} - -var resolvTrampleWarnable = health.Register(&health.Warnable{ - Code: "resolv-conf-overwritten", - Severity: health.SeverityMedium, - Title: "Linux DNS configuration issue", - Text: health.StaticMessage("Linux DNS config not ideal. /etc/resolv.conf overwritten. See https://tailscale.com/s/dns-fight"), -}) - -// checkForFileTrample checks whether /etc/resolv.conf has been trampled -// by another program on the system. (e.g. a DHCP client) -func (m *directManager) checkForFileTrample() { - m.mu.Lock() - want := m.wantResolvConf - lastWarn := m.lastWarnContents - m.mu.Unlock() - - if want == nil { - return - } - - cur, err := m.fs.ReadFile(resolvConf) - if err != nil { - m.logf("trample: read error: %v", err) - return - } - if bytes.Equal(cur, want) { - m.health.SetHealthy(resolvTrampleWarnable) - if lastWarn != nil { - m.mu.Lock() - m.lastWarnContents = nil - m.mu.Unlock() - m.logf("trample: resolv.conf again matches expected content") - } - return - } - if bytes.Equal(cur, lastWarn) { - // We already logged about this, so not worth doing it again. - return - } - - m.mu.Lock() - m.lastWarnContents = cur - m.mu.Unlock() - - show := cur - if len(show) > 1024 { - show = show[:1024] - } - m.logf("trample: resolv.conf changed from what we expected. did some other program interfere? current contents: %q", show) - m.health.SetUnhealthy(resolvTrampleWarnable, nil) -} diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go deleted file mode 100644 index c221ca1be..000000000 --- a/net/dns/direct_notlinux.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package dns - -func (m *directManager) runFileWatcher() { - // Not implemented on other platforms. Maybe it could resort to polling. -} diff --git a/net/dns/dns_clone.go b/net/dns/dns_clone.go new file mode 100644 index 000000000..807bfce23 --- /dev/null +++ b/net/dns/dns_clone.go @@ -0,0 +1,74 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT. + +package dns + +import ( + "net/netip" + + "tailscale.com/types/dnstype" + "tailscale.com/util/dnsname" +) + +// Clone makes a deep copy of Config. +// The result aliases no memory with the original. +func (src *Config) Clone() *Config { + if src == nil { + return nil + } + dst := new(Config) + *dst = *src + if src.DefaultResolvers != nil { + dst.DefaultResolvers = make([]*dnstype.Resolver, len(src.DefaultResolvers)) + for i := range dst.DefaultResolvers { + if src.DefaultResolvers[i] == nil { + dst.DefaultResolvers[i] = nil + } else { + dst.DefaultResolvers[i] = src.DefaultResolvers[i].Clone() + } + } + } + if dst.Routes != nil { + dst.Routes = map[dnsname.FQDN][]*dnstype.Resolver{} + for k := range src.Routes { + dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) + } + } + dst.SearchDomains = append(src.SearchDomains[:0:0], src.SearchDomains...) + if dst.Hosts != nil { + dst.Hosts = map[dnsname.FQDN][]netip.Addr{} + for k := range src.Hosts { + dst.Hosts[k] = append([]netip.Addr{}, src.Hosts[k]...) + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ConfigCloneNeedsRegeneration = Config(struct { + DefaultResolvers []*dnstype.Resolver + Routes map[dnsname.FQDN][]*dnstype.Resolver + SearchDomains []dnsname.FQDN + Hosts map[dnsname.FQDN][]netip.Addr + OnlyIPv6 bool +}{}) + +// Clone duplicates src into dst and reports whether it succeeded. +// To succeed, must be of types <*T, *T> or <*T, **T>, +// where T is one of Config. +func Clone(dst, src any) bool { + switch src := src.(type) { + case *Config: + switch dst := dst.(type) { + case *Config: + *dst = *src.Clone() + return true + case **Config: + *dst = src.Clone() + return true + } + } + return false +} diff --git a/net/dns/dns_view.go b/net/dns/dns_view.go new file mode 100644 index 000000000..c7ce376cb --- /dev/null +++ b/net/dns/dns_view.go @@ -0,0 +1,138 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by tailscale/cmd/viewer; DO NOT EDIT. + +package dns + +import ( + jsonv1 "encoding/json" + "errors" + "net/netip" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/types/dnstype" + "tailscale.com/types/views" + "tailscale.com/util/dnsname" +) + +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=Config + +// View returns a read-only view of Config. +func (p *Config) View() ConfigView { + return ConfigView{Đļ: p} +} + +// ConfigView provides a read-only view over Config. +// +// Its methods should only be called if `Valid()` returns true. +type ConfigView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *Config +} + +// Valid reports whether v's underlying value is non-nil. +func (v ConfigView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v ConfigView) AsStruct() *Config { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ConfigView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ConfigView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *ConfigView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x Config + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *ConfigView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Config + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// DefaultResolvers are the DNS resolvers to use for DNS names +// which aren't covered by more specific per-domain routes below. +// If empty, the OS's default resolvers (the ones that predate +// Tailscale altering the configuration) are used. +func (v ConfigView) DefaultResolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { + return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.Đļ.DefaultResolvers) +} + +// Routes maps a DNS suffix to the resolvers that should be used +// for queries that fall within that suffix. +// If a query doesn't match any entry in Routes, the +// DefaultResolvers are used. +// A Routes entry with no resolvers means the route should be +// authoritatively answered using the contents of Hosts. +func (v ConfigView) Routes() views.MapFn[dnsname.FQDN, []*dnstype.Resolver, views.SliceView[*dnstype.Resolver, dnstype.ResolverView]] { + return views.MapFnOf(v.Đļ.Routes, func(t []*dnstype.Resolver) views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { + return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](t) + }) +} + +// SearchDomains are DNS suffixes to try when expanding +// single-label queries. +func (v ConfigView) SearchDomains() views.Slice[dnsname.FQDN] { + return views.SliceOf(v.Đļ.SearchDomains) +} + +// Hosts maps DNS FQDNs to their IPs, which can be a mix of IPv4 +// and IPv6. +// Queries matching entries in Hosts are resolved locally by +// 100.100.100.100 without leaving the machine. +// Adding an entry to Hosts merely creates the record. If you want +// it to resolve, you also need to add appropriate routes to +// Routes. +func (v ConfigView) Hosts() views.MapSlice[dnsname.FQDN, netip.Addr] { + return views.MapSliceOf(v.Đļ.Hosts) +} + +// OnlyIPv6, if true, uses the IPv6 service IP (for MagicDNS) +// instead of the IPv4 version (100.100.100.100). +func (v ConfigView) OnlyIPv6() bool { return v.Đļ.OnlyIPv6 } +func (v ConfigView) Equal(v2 ConfigView) bool { return v.Đļ.Equal(v2.Đļ) } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ConfigViewNeedsRegeneration = Config(struct { + DefaultResolvers []*dnstype.Resolver + Routes map[dnsname.FQDN][]*dnstype.Resolver + SearchDomains []dnsname.FQDN + Hosts map[dnsname.FQDN][]netip.Addr + OnlyIPv6 bool +}{}) diff --git a/net/dns/manager.go b/net/dns/manager.go index 51a0fa12c..de99fe646 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -8,6 +8,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "io" "net" "net/netip" @@ -18,22 +19,27 @@ import ( "sync/atomic" "time" - xmaps "golang.org/x/exp/maps" "tailscale.com/control/controlknobs" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/dns/resolver" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/syncs" - "tailscale.com/tstime/rate" "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/policyclient" ) var ( errFullQueue = errors.New("request queue full") + // ErrNoDNSConfig is returned by RecompileDNSConfig when the Manager + // has no existing DNS configuration. + ErrNoDNSConfig = errors.New("no DNS configuration") ) // maxActiveQueries returns the maximal number of DNS requests that can @@ -59,16 +65,17 @@ type Manager struct { knobs *controlknobs.Knobs // or nil goos string // if empty, gets set to runtime.GOOS - mu sync.Mutex // guards following - // config is the last configuration we successfully compiled or nil if there - // was any failure applying the last configuration. - config *Config + mu sync.Mutex // guards following + config *Config // Tracks the last viable DNS configuration set by Set. nil on failures other than compilation failures or if set has never been called. } // NewManagers created a new manager from the given config. // // knobs may be nil. func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker, dialer *tsdial.Dialer, linkSel resolver.ForwardLinkSelector, knobs *controlknobs.Knobs, goos string) *Manager { + if !buildfeatures.HasDNS { + return nil + } if dialer == nil { panic("nil Dialer") } @@ -89,34 +96,46 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker, goos: goos, } - // Rate limit our attempts to correct our DNS configuration. - limiter := rate.NewLimiter(1.0/5.0, 1) - - // This will recompile the DNS config, which in turn will requery the system - // DNS settings. The recovery func should triggered only when we are missing - // upstream nameservers and require them to forward a query. - m.resolver.SetMissingUpstreamRecovery(func() { - m.mu.Lock() - defer m.mu.Unlock() - if m.config == nil { - return - } - - if limiter.Allow() { - m.logf("DNS resolution failed due to missing upstream nameservers. Recompiling DNS configuration.") - m.setLocked(*m.config) - } - }) - m.ctx, m.ctxCancel = context.WithCancel(context.Background()) m.logf("using %T", m.os) return m } // Resolver returns the Manager's DNS Resolver. -func (m *Manager) Resolver() *resolver.Resolver { return m.resolver } +func (m *Manager) Resolver() *resolver.Resolver { + if !buildfeatures.HasDNS { + return nil + } + return m.resolver +} + +// RecompileDNSConfig recompiles the last attempted DNS configuration, which has +// the side effect of re-querying the OS's interface nameservers. This should be used +// on platforms where the interface nameservers can change. Darwin, for example, +// where the nameservers aren't always available when we process a major interface +// change event, or platforms where the nameservers may change while tunnel is up. +// +// This should be called if it is determined that [OSConfigurator.GetBaseConfig] may +// give a better or different result than when [Manager.Set] was last called. The +// logic for making that determination is up to the caller. +// +// It returns [ErrNoDNSConfig] if [Manager.Set] has never been called. +func (m *Manager) RecompileDNSConfig() error { + if !buildfeatures.HasDNS { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + if m.config != nil { + return m.setLocked(*m.config) + } + return ErrNoDNSConfig +} func (m *Manager) Set(cfg Config) error { + if !buildfeatures.HasDNS { + return nil + } m.mu.Lock() defer m.mu.Unlock() return m.setLocked(cfg) @@ -124,6 +143,9 @@ func (m *Manager) Set(cfg Config) error { // GetBaseConfig returns the current base OS DNS configuration as provided by the OSConfigurator. func (m *Manager) GetBaseConfig() (OSConfig, error) { + if !buildfeatures.HasDNS { + panic("unreachable") + } return m.os.GetBaseConfig() } @@ -133,15 +155,15 @@ func (m *Manager) GetBaseConfig() (OSConfig, error) { func (m *Manager) setLocked(cfg Config) error { syncs.AssertLocked(&m.mu) - // On errors, the 'set' config is cleared. - m.config = nil - m.logf("Set: %v", logger.ArgWriter(func(w *bufio.Writer) { cfg.WriteToBufioWriter(w) })) rcfg, ocfg, err := m.compileConfig(cfg) if err != nil { + // On a compilation failure, set m.config set for later reuse by + // [Manager.RecompileDNSConfig] and return the error. + m.config = &cfg return err } @@ -153,14 +175,16 @@ func (m *Manager) setLocked(cfg Config) error { })) if err := m.resolver.SetConfig(rcfg); err != nil { + m.config = nil return err } if err := m.os.SetDNS(ocfg); err != nil { - m.health.SetDNSOSHealth(err) + m.config = nil + m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()}) return err } - m.health.SetDNSOSHealth(nil) + m.health.SetHealthy(osConfigurationSetWarnable) m.config = &cfg return nil @@ -203,7 +227,7 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) { if len(hostsMap) == 0 { return nil } - hosts = xmaps.Values(hostsMap) + hosts = slicesx.MapValues(hostsMap) slices.SortFunc(hosts, func(a, b *HostEntry) int { if len(a.Hosts) == 0 && len(b.Hosts) == 0 { return 0 @@ -217,6 +241,26 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) { return hosts } +var osConfigurationReadWarnable = health.Register(&health.Warnable{ + Code: "dns-read-os-config-failed", + Title: "Failed to read system DNS configuration", + Text: func(args health.Args) string { + return fmt.Sprintf("Tailscale failed to fetch the DNS configuration of your device: %v", args[health.ArgError]) + }, + Severity: health.SeverityLow, + DependsOn: []*health.Warnable{health.NetworkStatusWarnable}, +}) + +var osConfigurationSetWarnable = health.Register(&health.Warnable{ + Code: "dns-set-os-config-failed", + Title: "Failed to set system DNS configuration", + Text: func(args health.Args) string { + return fmt.Sprintf("Tailscale failed to set the DNS configuration of your device: %v", args[health.ArgError]) + }, + Severity: health.SeverityMedium, + DependsOn: []*health.Warnable{health.NetworkStatusWarnable}, +}) + // compileConfig converts cfg into a quad-100 resolver configuration // and an OS-level configuration. func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig, err error) { @@ -225,8 +269,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // the OS. rcfg.Hosts = cfg.Hosts routes := map[dnsname.FQDN][]*dnstype.Resolver{} // assigned conditionally to rcfg.Routes below. + var propagateHostsToOS bool for suffix, resolvers := range cfg.Routes { if len(resolvers) == 0 { + propagateHostsToOS = true rcfg.LocalDomains = append(rcfg.LocalDomains, suffix) } else { routes[suffix] = resolvers @@ -235,13 +281,13 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // Similarly, the OS always gets search paths. ocfg.SearchDomains = cfg.SearchDomains - if m.goos == "windows" { + if propagateHostsToOS && m.goos == "windows" { ocfg.Hosts = compileHostEntries(cfg) } // Deal with trivial configs first. switch { - case !cfg.needsOSResolver(): + case !cfg.needsOSResolver() || runtime.GOOS == "plan9": // Set search domains, but nothing else. This also covers the // case where cfg is entirely zero, in which case these // configs clear all Tailscale DNS settings. @@ -264,7 +310,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // through quad-100. rcfg.Routes = routes rcfg.Routes["."] = cfg.DefaultResolvers - ocfg.Nameservers = []netip.Addr{cfg.serviceIP()} + ocfg.Nameservers = cfg.serviceIPs(m.knobs) return rcfg, ocfg, nil } @@ -302,7 +348,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // or routes + MagicDNS, or just MagicDNS, or on an OS that cannot // split-DNS. Install a split config pointing at quad-100. rcfg.Routes = routes - ocfg.Nameservers = []netip.Addr{cfg.serviceIP()} + ocfg.Nameservers = cfg.serviceIPs(m.knobs) var baseCfg *OSConfig // base config; non-nil if/when known @@ -312,7 +358,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // that as the forwarder for all DNS traffic that quad-100 doesn't handle. if isApple || !m.os.SupportsSplitDNS() { // If the OS can't do native split-dns, read out the underlying - // resolver config and blend it into our config. + // resolver config and blend it into our config. On apple platforms, [OSConfigurator.GetBaseConfig] + // has a tendency to temporarily fail if called immediately following + // an interface change. These failures should be retried if/when the OS + // indicates that the DNS configuration has changed via [RecompileDNSConfig]. cfg, err := m.os.GetBaseConfig() if err == nil { baseCfg = &cfg @@ -320,9 +369,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // This is currently (2022-10-13) expected on certain iOS and macOS // builds. } else { - m.health.SetDNSOSHealth(err) + m.health.SetUnhealthy(osConfigurationReadWarnable, health.Args{health.ArgError: err.Error()}) return resolver.Config{}, OSConfig{}, err } + m.health.SetHealthy(osConfigurationReadWarnable) } if baseCfg == nil { @@ -528,6 +578,9 @@ func (m *Manager) HandleTCPConn(conn net.Conn, srcAddr netip.AddrPort) { } func (m *Manager) Down() error { + if !buildfeatures.HasDNS { + return nil + } m.ctxCancel() if err := m.os.Close(); err != nil { return err @@ -537,6 +590,9 @@ func (m *Manager) Down() error { } func (m *Manager) FlushCaches() error { + if !buildfeatures.HasDNS { + return nil + } return flushCaches() } @@ -545,14 +601,18 @@ func (m *Manager) FlushCaches() error { // No other state needs to be instantiated before this runs. // // health must not be nil -func CleanUp(logf logger.Logf, netMon *netmon.Monitor, health *health.Tracker, interfaceName string) { - oscfg, err := NewOSConfigurator(logf, nil, nil, interfaceName) +func CleanUp(logf logger.Logf, netMon *netmon.Monitor, bus *eventbus.Bus, health *health.Tracker, interfaceName string) { + if !buildfeatures.HasDNS { + return + } + oscfg, err := NewOSConfigurator(logf, health, policyclient.Get(), nil, interfaceName) if err != nil { logf("creating dns cleanup: %v", err) return } d := &tsdial.Dialer{Logf: logf} d.SetNetMon(netMon) + d.SetBus(bus) dns := NewManager(logf, oscfg, health, d, nil, nil, runtime.GOOS) if err := dns.Down(); err != nil { logf("dns down: %v", err) diff --git a/net/dns/manager_darwin.go b/net/dns/manager_darwin.go index ccfafaa45..d73ad71a8 100644 --- a/net/dns/manager_darwin.go +++ b/net/dns/manager_darwin.go @@ -14,12 +14,13 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/util/mak" + "tailscale.com/util/syspolicy/policyclient" ) // NewOSConfigurator creates a new OS configurator. // // The health tracker and the knobs may be nil and are ignored on this platform. -func NewOSConfigurator(logf logger.Logf, _ *health.Tracker, _ *controlknobs.Knobs, ifName string) (OSConfigurator, error) { +func NewOSConfigurator(logf logger.Logf, _ *health.Tracker, _ policyclient.Client, _ *controlknobs.Knobs, ifName string) (OSConfigurator, error) { return &darwinConfigurator{logf: logf, ifName: ifName}, nil } diff --git a/net/dns/manager_default.go b/net/dns/manager_default.go index 11dea5ca8..1a86690c5 100644 --- a/net/dns/manager_default.go +++ b/net/dns/manager_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux && !freebsd && !openbsd && !windows && !darwin +//go:build (!linux || android) && !freebsd && !openbsd && !windows && !darwin && !illumos && !solaris && !plan9 package dns @@ -9,11 +9,12 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/health" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) // NewOSConfigurator creates a new OS configurator. // // The health tracker and the knobs may be nil and are ignored on this platform. -func NewOSConfigurator(logger.Logf, *health.Tracker, *controlknobs.Knobs, string) (OSConfigurator, error) { +func NewOSConfigurator(logger.Logf, *health.Tracker, policyclient.Client, *controlknobs.Knobs, string) (OSConfigurator, error) { return NewNoopManager() } diff --git a/net/dns/manager_freebsd.go b/net/dns/manager_freebsd.go index 1ec9ea841..3237fb382 100644 --- a/net/dns/manager_freebsd.go +++ b/net/dns/manager_freebsd.go @@ -10,12 +10,13 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/health" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) // NewOSConfigurator creates a new OS configurator. // // The health tracker may be nil; the knobs may be nil and are ignored on this platform. -func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, _ string) (OSConfigurator, error) { +func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ policyclient.Client, _ *controlknobs.Knobs, _ string) (OSConfigurator, error) { bs, err := os.ReadFile("/etc/resolv.conf") if os.IsNotExist(err) { return newDirectManager(logf, health), nil diff --git a/net/dns/manager_linux.go b/net/dns/manager_linux.go index 3ba3022b6..4304df261 100644 --- a/net/dns/manager_linux.go +++ b/net/dns/manager_linux.go @@ -1,11 +1,12 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package dns import ( "bytes" - "context" "errors" "fmt" "os" @@ -13,13 +14,15 @@ import ( "sync" "time" - "github.com/godbus/dbus/v5" "tailscale.com/control/controlknobs" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/netaddr" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" - "tailscale.com/util/cmpver" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/version/distro" ) type kv struct { @@ -32,18 +35,59 @@ func (kv kv) String() string { var publishOnce sync.Once +// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete. +// +// This is particularly useful because certain conditions can cause indefinite hangs +// (such as improper dbus auth followed by contextless dbus.Object.Call). +// Such operations should be wrapped in a timeout context. +const reconfigTimeout = time.Second + +// Set unless ts_omit_networkmanager +var ( + optNewNMManager feature.Hook[func(ifName string) (OSConfigurator, error)] + optNMIsUsingResolved feature.Hook[func() error] + optNMVersionBetween feature.Hook[func(v1, v2 string) (bool, error)] +) + +// Set unless ts_omit_resolved +var ( + optNewResolvedManager feature.Hook[func(logf logger.Logf, health *health.Tracker, interfaceName string) (OSConfigurator, error)] +) + +// Set unless ts_omit_dbus +var ( + optDBusPing feature.Hook[func(name, objectPath string) error] + optDBusReadString feature.Hook[func(name, objectPath, iface, member string) (string, error)] +) + // NewOSConfigurator created a new OS configurator. // // The health tracker may be nil; the knobs may be nil and are ignored on this platform. -func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, interfaceName string) (ret OSConfigurator, err error) { +func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ policyclient.Client, _ *controlknobs.Knobs, interfaceName string) (ret OSConfigurator, err error) { + if !buildfeatures.HasDNS || distro.Get() == distro.JetKVM { + return NewNoopManager() + } + env := newOSConfigEnv{ - fs: directFS{}, - dbusPing: dbusPing, - dbusReadString: dbusReadString, - nmIsUsingResolved: nmIsUsingResolved, - nmVersionBetween: nmVersionBetween, - resolvconfStyle: resolvconfStyle, + fs: directFS{}, + resolvconfStyle: resolvconfStyle, + } + if f, ok := optDBusPing.GetOk(); ok { + env.dbusPing = f + } else { + env.dbusPing = func(_, _ string) error { return errors.ErrUnsupported } } + if f, ok := optDBusReadString.GetOk(); ok { + env.dbusReadString = f + } else { + env.dbusReadString = func(_, _, _, _ string) (string, error) { return "", errors.ErrUnsupported } + } + if f, ok := optNMIsUsingResolved.GetOk(); ok { + env.nmIsUsingResolved = f + } else { + env.nmIsUsingResolved = func() error { return errors.ErrUnsupported } + } + env.nmVersionBetween, _ = optNMVersionBetween.GetOk() // GetOk to not panic if nil; unused if optNMIsUsingResolved returns an error mode, err := dnsMode(logf, health, env) if err != nil { return nil, err @@ -58,17 +102,24 @@ func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs case "direct": return newDirectManagerOnFS(logf, health, env.fs), nil case "systemd-resolved": - return newResolvedManager(logf, health, interfaceName) + if f, ok := optNewResolvedManager.GetOk(); ok { + return f(logf, health, interfaceName) + } + return nil, fmt.Errorf("tailscaled was built without DNS %q support", mode) case "network-manager": - return newNMManager(interfaceName) + if f, ok := optNewNMManager.GetOk(); ok { + return f(interfaceName) + } + return nil, fmt.Errorf("tailscaled was built without DNS %q support", mode) case "debian-resolvconf": return newDebianResolvconfManager(logf) case "openresolv": return newOpenresolvManager(logf) default: logf("[unexpected] detected unknown DNS mode %q, using direct manager as last resort", mode) - return newDirectManagerOnFS(logf, health, env.fs), nil } + + return newDirectManagerOnFS(logf, health, env.fs), nil } // newOSConfigEnv are the funcs newOSConfigurator needs, pulled out for testing. @@ -284,50 +335,6 @@ func dnsMode(logf logger.Logf, health *health.Tracker, env newOSConfigEnv) (ret } } -func nmVersionBetween(first, last string) (bool, error) { - conn, err := dbus.SystemBus() - if err != nil { - // DBus probably not running. - return false, err - } - - nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager")) - v, err := nm.GetProperty("org.freedesktop.NetworkManager.Version") - if err != nil { - return false, err - } - - version, ok := v.Value().(string) - if !ok { - return false, fmt.Errorf("unexpected type %T for NM version", v.Value()) - } - - outside := cmpver.Compare(version, first) < 0 || cmpver.Compare(version, last) > 0 - return !outside, nil -} - -func nmIsUsingResolved() error { - conn, err := dbus.SystemBus() - if err != nil { - // DBus probably not running. - return err - } - - nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager/DnsManager")) - v, err := nm.GetProperty("org.freedesktop.NetworkManager.DnsManager.Mode") - if err != nil { - return fmt.Errorf("getting NM mode: %w", err) - } - mode, ok := v.Value().(string) - if !ok { - return fmt.Errorf("unexpected type %T for NM DNS mode", v.Value()) - } - if mode != "systemd-resolved" { - return errors.New("NetworkManager is not using systemd-resolved for DNS") - } - return nil -} - // resolvedIsActuallyResolver reports whether the system is using // systemd-resolved as the resolver. There are two different ways to // use systemd-resolved: @@ -388,44 +395,3 @@ func isLibnssResolveUsed(env newOSConfigEnv) error { } return fmt.Errorf("libnss_resolve not used") } - -func dbusPing(name, objectPath string) error { - conn, err := dbus.SystemBus() - if err != nil { - // DBus probably not running. - return err - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - obj := conn.Object(name, dbus.ObjectPath(objectPath)) - call := obj.CallWithContext(ctx, "org.freedesktop.DBus.Peer.Ping", 0) - return call.Err -} - -// dbusReadString reads a string property from the provided name and object -// path. property must be in "interface.member" notation. -func dbusReadString(name, objectPath, iface, member string) (string, error) { - conn, err := dbus.SystemBus() - if err != nil { - // DBus probably not running. - return "", err - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - - obj := conn.Object(name, dbus.ObjectPath(objectPath)) - - var result dbus.Variant - err = obj.CallWithContext(ctx, "org.freedesktop.DBus.Properties.Get", 0, iface, member).Store(&result) - if err != nil { - return "", err - } - - if s, ok := result.Value().(string); ok { - return s, nil - } - return result.String(), nil -} diff --git a/net/dns/manager_openbsd.go b/net/dns/manager_openbsd.go index 1a1c4390c..6168a9e08 100644 --- a/net/dns/manager_openbsd.go +++ b/net/dns/manager_openbsd.go @@ -11,6 +11,7 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/health" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) type kv struct { @@ -24,7 +25,7 @@ func (kv kv) String() string { // NewOSConfigurator created a new OS configurator. // // The health tracker may be nil; the knobs may be nil and are ignored on this platform. -func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) { +func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ policyclient.Client, _ *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) { return newOSConfigurator(logf, health, interfaceName, newOSConfigEnv{ rcIsResolvd: rcIsResolvd, diff --git a/net/dns/manager_plan9.go b/net/dns/manager_plan9.go new file mode 100644 index 000000000..ef1ceea17 --- /dev/null +++ b/net/dns/manager_plan9.go @@ -0,0 +1,182 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO: man 6 ndb | grep -e 'suffix.*same line' +// to detect Russ's https://9fans.topicbox.com/groups/9fans/T9c9d81b5801a0820/ndb-suffix-specific-dns-changes + +package dns + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "os" + "regexp" + "strings" + "unicode" + + "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/types/logger" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/policyclient" +) + +func NewOSConfigurator(logf logger.Logf, ht *health.Tracker, _ policyclient.Client, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) { + return &plan9DNSManager{ + logf: logf, + ht: ht, + knobs: knobs, + }, nil +} + +type plan9DNSManager struct { + logf logger.Logf + ht *health.Tracker + knobs *controlknobs.Knobs +} + +// netNDBBytesWithoutTailscale returns raw (the contents of /net/ndb) with any +// Tailscale bits removed. +func netNDBBytesWithoutTailscale(raw []byte) ([]byte, error) { + var ret bytes.Buffer + bs := bufio.NewScanner(bytes.NewReader(raw)) + removeLine := set.Set[string]{} + for bs.Scan() { + t := bs.Text() + if rest, ok := strings.CutPrefix(t, "#tailscaled-added-line:"); ok { + removeLine.Add(strings.TrimSpace(rest)) + continue + } + trimmed := strings.TrimSpace(t) + if removeLine.Contains(trimmed) { + removeLine.Delete(trimmed) + continue + } + + // Also remove any DNS line referencing *.ts.net. This is + // Tailscale-specific (and won't work with, say, Headscale), but + // the Headscale case will be covered by the #tailscaled-added-line + // logic above, assuming the user didn't delete those comments. + if (strings.HasPrefix(trimmed, "dns=") || strings.Contains(trimmed, "dnsdomain=")) && + strings.HasSuffix(trimmed, ".ts.net") { + continue + } + + ret.WriteString(t) + ret.WriteByte('\n') + } + return ret.Bytes(), bs.Err() +} + +// setNDBSuffix adds lines to tsFree (the contents of /net/ndb already cleaned +// of Tailscale-added lines) to add the optional DNS search domain (e.g. +// "foo.ts.net") and DNS server to it. +func setNDBSuffix(tsFree []byte, suffix string) []byte { + suffix = strings.TrimSuffix(suffix, ".") + if suffix == "" { + return tsFree + } + var buf bytes.Buffer + bs := bufio.NewScanner(bytes.NewReader(tsFree)) + var added []string + addLine := func(s string) { + added = append(added, strings.TrimSpace(s)) + buf.WriteString(s) + } + for bs.Scan() { + buf.Write(bs.Bytes()) + buf.WriteByte('\n') + + t := bs.Text() + if suffix != "" && len(added) == 0 && strings.HasPrefix(t, "\tdns=") { + addLine(fmt.Sprintf("\tdns=100.100.100.100 suffix=%s\n", suffix)) + addLine(fmt.Sprintf("\tdnsdomain=%s\n", suffix)) + } + } + bufTrim := bytes.TrimLeftFunc(buf.Bytes(), unicode.IsSpace) + if len(added) == 0 { + return bufTrim + } + var ret bytes.Buffer + for _, s := range added { + ret.WriteString("#tailscaled-added-line: ") + ret.WriteString(s) + ret.WriteString("\n") + } + ret.WriteString("\n") + ret.Write(bufTrim) + return ret.Bytes() +} + +func (m *plan9DNSManager) SetDNS(c OSConfig) error { + ndbOnDisk, err := os.ReadFile("/net/ndb") + if err != nil { + return err + } + + tsFree, err := netNDBBytesWithoutTailscale(ndbOnDisk) + if err != nil { + return err + } + + var suffix string + if len(c.SearchDomains) > 0 { + suffix = string(c.SearchDomains[0]) + } + + newBuf := setNDBSuffix(tsFree, suffix) + if !bytes.Equal(newBuf, ndbOnDisk) { + if err := os.WriteFile("/net/ndb", newBuf, 0644); err != nil { + return fmt.Errorf("writing /net/ndb: %w", err) + } + if f, err := os.OpenFile("/net/dns", os.O_RDWR, 0); err == nil { + if _, err := io.WriteString(f, "refresh\n"); err != nil { + f.Close() + return fmt.Errorf("/net/dns refresh write: %w", err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("/net/dns refresh close: %w", err) + } + } + } + + return nil +} + +func (m *plan9DNSManager) SupportsSplitDNS() bool { return false } + +func (m *plan9DNSManager) Close() error { + // TODO(bradfitz): remove the Tailscale bits from /net/ndb ideally + return nil +} + +var dnsRegex = regexp.MustCompile(`\bdns=(\d+\.\d+\.\d+\.\d+)\b`) + +func (m *plan9DNSManager) GetBaseConfig() (OSConfig, error) { + var oc OSConfig + f, err := os.Open("/net/ndb") + if err != nil { + return oc, err + } + defer f.Close() + bs := bufio.NewScanner(f) + for bs.Scan() { + m := dnsRegex.FindSubmatch(bs.Bytes()) + if m == nil { + continue + } + addr, err := netip.ParseAddr(string(m[1])) + if err != nil { + continue + } + oc.Nameservers = append(oc.Nameservers, addr) + } + if err := bs.Err(); err != nil { + return oc, err + } + + return oc, nil +} diff --git a/net/dns/manager_plan9_test.go b/net/dns/manager_plan9_test.go new file mode 100644 index 000000000..806fdb68e --- /dev/null +++ b/net/dns/manager_plan9_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build plan9 + +package dns + +import "testing" + +func TestNetNDBBytesWithoutTailscale(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "no-tailscale", + raw: "# This is a comment\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n", + want: "# This is a comment\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n", + }, + { + name: "remove-by-comments", + raw: "# This is a comment\n#tailscaled-added-line: dns=100.100.100.100\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tdns=100.100.100.100\n\tsys=gnot\n", + want: "# This is a comment\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n", + }, + { + name: "remove-by-ts.net", + raw: "Some line\n\tdns=100.100.100.100 suffix=foo.ts.net\n\tfoo=bar\n", + want: "Some line\n\tfoo=bar\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := netNDBBytesWithoutTailscale([]byte(tt.raw)) + if err != nil { + t.Fatal(err) + } + if string(got) != tt.want { + t.Errorf("GOT:\n%s\n\nWANT:\n%s\n", string(got), tt.want) + } + }) + } +} + +func TestSetNDBSuffix(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "set", + raw: "ip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n\tdns=100.100.100.100\n\n# foo\n", + want: `#tailscaled-added-line: dns=100.100.100.100 suffix=foo.ts.net +#tailscaled-added-line: dnsdomain=foo.ts.net + +ip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2 + sys=gnot + dns=100.100.100.100 + dns=100.100.100.100 suffix=foo.ts.net + dnsdomain=foo.ts.net + +# foo +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := setNDBSuffix([]byte(tt.raw), "foo.ts.net") + if string(got) != tt.want { + t.Errorf("wrong value\n GOT %q:\n%s\n\nWANT %q:\n%s\n", got, got, tt.want, tt.want) + } + }) + } + +} diff --git a/net/dns/manager_solaris.go b/net/dns/manager_solaris.go new file mode 100644 index 000000000..de7e72bb5 --- /dev/null +++ b/net/dns/manager_solaris.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" +) + +func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ policyclient.Client, _ *controlknobs.Knobs, iface string) (OSConfigurator, error) { + return newDirectManager(logf, health), nil +} diff --git a/net/dns/manager_tcp_test.go b/net/dns/manager_tcp_test.go index f4c42791e..dcdc88c7a 100644 --- a/net/dns/manager_tcp_test.go +++ b/net/dns/manager_tcp_test.go @@ -20,6 +20,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tstest" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus/eventbustest" ) func mkDNSRequest(domain dnsname.FQDN, tp dns.Type, modify func(*dns.Builder)) []byte { @@ -89,7 +90,10 @@ func TestDNSOverTCP(t *testing.T) { SearchDomains: fqdns("coffee.shop"), }, } - m := NewManager(t.Logf, &f, new(health.Tracker), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "") + bus := eventbustest.NewBus(t) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) + m := NewManager(t.Logf, &f, health.NewTracker(bus), dialer, nil, nil, "") m.resolver.TestOnlySetHook(f.SetResolver) m.Set(Config{ Hosts: hosts( @@ -174,7 +178,10 @@ func TestDNSOverTCP_TooLarge(t *testing.T) { SearchDomains: fqdns("coffee.shop"), }, } - m := NewManager(log, &f, new(health.Tracker), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "") + bus := eventbustest.NewBus(t) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) + m := NewManager(log, &f, health.NewTracker(bus), dialer, nil, nil, "") m.resolver.TestOnlySetHook(f.SetResolver) m.Set(Config{ Hosts: hosts("andrew.ts.com.", "1.2.3.4"), diff --git a/net/dns/manager_test.go b/net/dns/manager_test.go index 366e08bbf..92b660007 100644 --- a/net/dns/manager_test.go +++ b/net/dns/manager_test.go @@ -4,6 +4,7 @@ package dns import ( + "errors" "net/netip" "runtime" "strings" @@ -18,14 +19,16 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/types/dnstype" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus/eventbustest" ) type fakeOSConfigurator struct { SplitDNS bool BaseConfig OSConfig - OSConfig OSConfig - ResolverConfig resolver.Config + OSConfig OSConfig + ResolverConfig resolver.Config + GetBaseConfigErr *error } func (c *fakeOSConfigurator) SetDNS(cfg OSConfig) error { @@ -45,6 +48,9 @@ func (c *fakeOSConfigurator) SupportsSplitDNS() bool { } func (c *fakeOSConfigurator) GetBaseConfig() (OSConfig, error) { + if c.GetBaseConfigErr != nil { + return OSConfig{}, *c.GetBaseConfigErr + } return c.BaseConfig, nil } @@ -836,6 +842,76 @@ func TestManager(t *testing.T) { }, goos: "darwin", }, + { + name: "populate-hosts-magicdns", + in: Config{ + Routes: upstreams( + "corp.com", "2.2.2.2", + "ts.com", ""), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + SearchDomains: fqdns("ts.com", "universe.tf"), + }, + split: true, + os: OSConfig{ + Hosts: []*HostEntry{ + { + Addr: netip.MustParseAddr("2.3.4.5"), + Hosts: []string{ + "bradfitz.ts.com.", + "bradfitz", + }, + }, + { + Addr: netip.MustParseAddr("1.2.3.4"), + Hosts: []string{ + "dave.ts.com.", + "dave", + }, + }, + }, + Nameservers: mustIPs("100.100.100.100"), + SearchDomains: fqdns("ts.com", "universe.tf"), + MatchDomains: fqdns("corp.com", "ts.com"), + }, + rs: resolver.Config{ + Routes: upstreams("corp.com.", "2.2.2.2"), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + LocalDomains: fqdns("ts.com."), + }, + goos: "windows", + }, + { + // Regression test for https://github.com/tailscale/tailscale/issues/14428 + name: "nopopulate-hosts-nomagicdns", + in: Config{ + Routes: upstreams( + "corp.com", "2.2.2.2", + "ts.com", "1.1.1.1"), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + SearchDomains: fqdns("ts.com", "universe.tf"), + }, + split: true, + os: OSConfig{ + Nameservers: mustIPs("100.100.100.100"), + SearchDomains: fqdns("ts.com", "universe.tf"), + MatchDomains: fqdns("corp.com", "ts.com"), + }, + rs: resolver.Config{ + Routes: upstreams( + "corp.com.", "2.2.2.2", + "ts.com", "1.1.1.1"), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + }, + goos: "windows", + }, } trIP := cmp.Transformer("ipStr", func(ip netip.Addr) string { return ip.String() }) @@ -857,7 +933,10 @@ func TestManager(t *testing.T) { goos = "linux" } knobs := &controlknobs.Knobs{} - m := NewManager(t.Logf, &f, new(health.Tracker), tsdial.NewDialer(netmon.NewStatic()), nil, knobs, goos) + bus := eventbustest.NewBus(t) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) + m := NewManager(t.Logf, &f, health.NewTracker(bus), dialer, nil, knobs, goos) m.resolver.TestOnlySetHook(f.SetResolver) if err := m.Set(test.in); err != nil { @@ -949,3 +1028,53 @@ func upstreams(strs ...string) (ret map[dnsname.FQDN][]*dnstype.Resolver) { } return ret } + +func TestConfigRecompilation(t *testing.T) { + fakeErr := errors.New("fake os configurator error") + f := &fakeOSConfigurator{} + f.GetBaseConfigErr = &fakeErr + f.BaseConfig = OSConfig{ + Nameservers: mustIPs("1.1.1.1"), + } + + config := Config{ + Routes: upstreams("ts.net", "69.4.2.0", "foo.ts.net", ""), + SearchDomains: fqdns("foo.ts.net"), + } + + bus := eventbustest.NewBus(t) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) + m := NewManager(t.Logf, f, health.NewTracker(bus), dialer, nil, nil, "darwin") + + var managerConfig *resolver.Config + m.resolver.TestOnlySetHook(func(cfg resolver.Config) { + managerConfig = &cfg + }) + + // Initial set should error out and store the config + if err := m.Set(config); err == nil { + t.Fatalf("Want non-nil error. Got nil") + } + if m.config == nil { + t.Fatalf("Want persisted config. Got nil.") + } + if managerConfig != nil { + t.Fatalf("Want nil managerConfig. Got %v", managerConfig) + } + + // Clear the error. We should take the happy path now and + // set m.manager's Config. + f.GetBaseConfigErr = nil + + // Recompilation without an error should succeed and set m.config and m.manager's [resolver.Config] + if err := m.RecompileDNSConfig(); err != nil { + t.Fatalf("Want nil error. Got err %v", err) + } + if m.config == nil { + t.Fatalf("Want non-nil config. Got nil") + } + if managerConfig == nil { + t.Fatalf("Want non nil managerConfig. Got nil") + } +} diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 250a25573..5ccadbab2 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -8,13 +8,14 @@ import ( "bytes" "errors" "fmt" + "maps" "net/netip" "os" "os/exec" "path/filepath" + "slices" "sort" "strings" - "sync" "syscall" "time" @@ -25,8 +26,12 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/envknob" "tailscale.com/health" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/ptype" "tailscale.com/util/winutil" ) @@ -42,19 +47,26 @@ type windowsManager struct { knobs *controlknobs.Knobs // or nil nrptDB *nrptRuleDatabase wslManager *wslManager + polc policyclient.Client - mu sync.Mutex + unregisterPolicyChangeCb func() // called when the manager is closing + + mu syncs.Mutex closing bool } // NewOSConfigurator created a new OS configurator. // // The health tracker and the knobs may be nil. -func NewOSConfigurator(logf logger.Logf, health *health.Tracker, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) { +func NewOSConfigurator(logf logger.Logf, health *health.Tracker, polc policyclient.Client, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) { + if polc == nil { + panic("nil policyclient.Client") + } ret := &windowsManager{ logf: logf, guid: interfaceName, knobs: knobs, + polc: polc, wslManager: newWSLManager(logf, health), } @@ -62,6 +74,11 @@ func NewOSConfigurator(logf logger.Logf, health *health.Tracker, knobs *controlk ret.nrptDB = newNRPTRuleDatabase(logf) } + var err error + if ret.unregisterPolicyChangeCb, err = polc.RegisterChangeCallback(ret.sysPolicyChanged); err != nil { + logf("error registering policy change callback: %v", err) // non-fatal + } + go func() { // Log WSL status once at startup. if distros, err := wslDistros(); err != nil { @@ -140,9 +157,8 @@ func (m *windowsManager) setSplitDNS(resolvers []netip.Addr, domains []dnsname.F return m.nrptDB.WriteSplitDNSConfig(servers, domains) } -func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) { - b := bytes.ReplaceAll(prevHostsFile, []byte("\r\n"), []byte("\n")) - sc := bufio.NewScanner(bytes.NewReader(b)) +func setTailscaleHosts(logf logger.Logf, prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) { + sc := bufio.NewScanner(bytes.NewReader(prevHostsFile)) const ( header = "# TailscaleHostsSectionStart" footer = "# TailscaleHostsSectionEnd" @@ -151,6 +167,32 @@ func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) "# This section contains MagicDNS entries for Tailscale.", "# Do not edit this section manually.", } + + prevEntries := make(map[netip.Addr][]string) + addPrevEntry := func(line string) { + if line == "" || line[0] == '#' { + return + } + + parts := strings.Split(line, " ") + if len(parts) < 1 { + return + } + + addr, err := netip.ParseAddr(parts[0]) + if err != nil { + logf("Parsing address from hosts: %v", err) + return + } + + prevEntries[addr] = parts[1:] + } + + nextEntries := make(map[netip.Addr][]string, len(hosts)) + for _, he := range hosts { + nextEntries[he.Addr] = he.Hosts + } + var out bytes.Buffer var inSection bool for sc.Scan() { @@ -164,26 +206,34 @@ func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) continue } if inSection { + addPrevEntry(line) continue } - fmt.Fprintln(&out, line) + fmt.Fprintf(&out, "%s\r\n", line) } if err := sc.Err(); err != nil { return nil, err } + + unchanged := maps.EqualFunc(prevEntries, nextEntries, func(a, b []string) bool { + return slices.Equal(a, b) + }) + if unchanged { + return nil, nil + } + if len(hosts) > 0 { - fmt.Fprintln(&out, header) + fmt.Fprintf(&out, "%s\r\n", header) for _, c := range comments { - fmt.Fprintln(&out, c) + fmt.Fprintf(&out, "%s\r\n", c) } - fmt.Fprintln(&out) + fmt.Fprintf(&out, "\r\n") for _, he := range hosts { - fmt.Fprintf(&out, "%s %s\n", he.Addr, strings.Join(he.Hosts, " ")) + fmt.Fprintf(&out, "%s %s\r\n", he.Addr, strings.Join(he.Hosts, " ")) } - fmt.Fprintln(&out) - fmt.Fprintln(&out, footer) + fmt.Fprintf(&out, "\r\n%s\r\n", footer) } - return bytes.ReplaceAll(out.Bytes(), []byte("\n"), []byte("\r\n")), nil + return out.Bytes(), nil } // setHosts sets the hosts file to contain the given host entries. @@ -197,10 +247,15 @@ func (m *windowsManager) setHosts(hosts []*HostEntry) error { if err != nil { return err } - outB, err := setTailscaleHosts(b, hosts) + outB, err := setTailscaleHosts(m.logf, b, hosts) if err != nil { return err } + if outB == nil { + // No change to hosts file, therefore no write necessary. + return nil + } + const fileMode = 0 // ignored on windows. // This can fail spuriously with an access denied error, so retry it a @@ -322,11 +377,9 @@ func (m *windowsManager) SetDNS(cfg OSConfig) error { // configuration only, routing one set of things to the "split" // resolver and the rest to the primary. - // Unconditionally disable dynamic DNS updates and NetBIOS on our - // interfaces. - if err := m.disableDynamicUpdates(); err != nil { - m.logf("disableDynamicUpdates error: %v\n", err) - } + // Reconfigure DNS registration according to the [syspolicy.DNSRegistration] + // policy setting, and unconditionally disable NetBIOS on our interfaces. + m.reconfigureDNSRegistration() if err := m.disableNetBIOS(); err != nil { m.logf("disableNetBIOS error: %v\n", err) } @@ -445,6 +498,10 @@ func (m *windowsManager) Close() error { m.closing = true m.mu.Unlock() + if m.unregisterPolicyChangeCb != nil { + m.unregisterPolicyChangeCb() + } + err := m.SetDNS(OSConfig{}) if m.nrptDB != nil { m.nrptDB.Close() @@ -453,15 +510,62 @@ func (m *windowsManager) Close() error { return err } -// disableDynamicUpdates sets the appropriate registry values to prevent the -// Windows DHCP client from sending dynamic DNS updates for our interface to -// AD domain controllers. -func (m *windowsManager) disableDynamicUpdates() error { +// sysPolicyChanged is a callback triggered by [syspolicy] when it detects +// a change in one or more syspolicy settings. +func (m *windowsManager) sysPolicyChanged(policy policyclient.PolicyChange) { + if policy.HasChanged(pkey.EnableDNSRegistration) { + m.reconfigureDNSRegistration() + } +} + +// reconfigureDNSRegistration configures the DNS registration settings +// using the [syspolicy.DNSRegistration] policy setting, if it is set. +// If the policy is not configured, it disables DNS registration. +func (m *windowsManager) reconfigureDNSRegistration() { + // Disable DNS registration by default (if the policy setting is not configured). + // This is primarily for historical reasons and to avoid breaking existing + // setups that rely on this behavior. + enableDNSRegistration, err := m.polc.GetPreferenceOption(pkey.EnableDNSRegistration, ptype.NeverByPolicy) + if err != nil { + m.logf("error getting DNSRegistration policy setting: %v", err) // non-fatal; we'll use the default + } + + if enableDNSRegistration.Show() { + // "Show" reports whether the policy setting is configured as "user-decides". + // The name is a bit unfortunate in this context, as we don't actually "show" anything. + // Still, if the admin configured the policy as "user-decides", we shouldn't modify + // the adapter's settings and should leave them up to the user (admin rights required) + // or the system defaults. + return + } + + // Otherwise, if the policy setting is configured as "always" or "never", + // we should configure the adapter accordingly. + if err := m.configureDNSRegistration(enableDNSRegistration.IsAlways()); err != nil { + m.logf("error configuring DNS registration: %v", err) + } +} + +// configureDNSRegistration sets the appropriate registry values to allow or prevent +// the Windows DHCP client from registering Tailscale IP addresses with DNS +// and sending dynamic updates for our interface to AD domain controllers. +func (m *windowsManager) configureDNSRegistration(enabled bool) error { prefixen := []winutil.RegistryPathPrefix{ winutil.IPv4TCPIPInterfacePrefix, winutil.IPv6TCPIPInterfacePrefix, } + var ( + registrationEnabled = uint32(0) + disableDynamicUpdate = uint32(1) + maxNumberOfAddressesToRegister = uint32(0) + ) + if enabled { + registrationEnabled = 1 + disableDynamicUpdate = 0 + maxNumberOfAddressesToRegister = 1 + } + for _, prefix := range prefixen { k, err := m.openInterfaceKey(prefix) if err != nil { @@ -469,13 +573,13 @@ func (m *windowsManager) disableDynamicUpdates() error { } defer k.Close() - if err := k.SetDWordValue("RegistrationEnabled", 0); err != nil { + if err := k.SetDWordValue("RegistrationEnabled", registrationEnabled); err != nil { return err } - if err := k.SetDWordValue("DisableDynamicUpdate", 1); err != nil { + if err := k.SetDWordValue("DisableDynamicUpdate", disableDynamicUpdate); err != nil { return err } - if err := k.SetDWordValue("MaxNumberOfAddressesToRegister", 0); err != nil { + if err := k.SetDWordValue("MaxNumberOfAddressesToRegister", maxNumberOfAddressesToRegister); err != nil { return err } } diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index 62c4dd9fb..aa538a0f6 100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -15,7 +15,9 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" + "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/util/winutil" "tailscale.com/util/winutil/gp" ) @@ -24,9 +26,56 @@ const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}" func TestHostFileNewLines(t *testing.T) { in := []byte("#foo\r\n#bar\n#baz\n") - want := []byte("#foo\r\n#bar\r\n#baz\r\n") + want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n") - got, err := setTailscaleHosts(in, nil) + he := []*HostEntry{ + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.1"), + Hosts: []string{"aaron"}, + }, + } + got, err := setTailscaleHosts(logger.Discard, in, he) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %q, want %q\n", got, want) + } +} + +func TestHostFileUnchanged(t *testing.T) { + in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n") + + he := []*HostEntry{ + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.1"), + Hosts: []string{"aaron"}, + }, + } + got, err := setTailscaleHosts(logger.Discard, in, he) + if err != nil { + t.Fatal(err) + } + if got != nil { + t.Errorf("got %q, want nil\n", got) + } +} + +func TestHostFileChanged(t *testing.T) { + in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n\r\n# TailscaleHostsSectionEnd\r\n") + want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n192.168.1.2 aaron2\r\n\r\n# TailscaleHostsSectionEnd\r\n") + + he := []*HostEntry{ + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.1"), + Hosts: []string{"aaron1"}, + }, + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.2"), + Hosts: []string{"aaron2"}, + }, + } + got, err := setTailscaleHosts(logger.Discard, in, he) if err != nil { t.Fatal(err) } @@ -85,7 +134,7 @@ func TestManagerWindowsGPCopy(t *testing.T) { } defer delIfKey() - cfg, err := NewOSConfigurator(logf, nil, nil, fakeInterface.String()) + cfg, err := NewOSConfigurator(logf, nil, policyclient.NoPolicyClient{}, nil, fakeInterface.String()) if err != nil { t.Fatalf("NewOSConfigurator: %v\n", err) } @@ -214,7 +263,7 @@ func runTest(t *testing.T, isLocal bool) { } defer delIfKey() - cfg, err := NewOSConfigurator(logf, nil, nil, fakeInterface.String()) + cfg, err := NewOSConfigurator(logf, nil, policyclient.NoPolicyClient{}, nil, fakeInterface.String()) if err != nil { t.Fatalf("NewOSConfigurator: %v\n", err) } @@ -501,8 +550,8 @@ func genRandomSubdomains(t *testing.T, n int) []dnsname.FQDN { const charset = "abcdefghijklmnopqrstuvwxyz" for len(domains) < cap(domains) { - l := r.Intn(19) + 1 - b := make([]byte, l) + ln := r.Intn(19) + 1 + b := make([]byte, ln) for i := range b { b[i] = charset[r.Intn(len(charset))] } diff --git a/net/dns/nm.go b/net/dns/nm.go index adb33cdb7..a88d29b37 100644 --- a/net/dns/nm.go +++ b/net/dns/nm.go @@ -1,12 +1,14 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android && !ts_omit_networkmanager package dns import ( "context" + "encoding/binary" + "errors" "fmt" "net" "net/netip" @@ -14,8 +16,8 @@ import ( "time" "github.com/godbus/dbus/v5" - "github.com/josharian/native" "tailscale.com/net/tsaddr" + "tailscale.com/util/cmpver" "tailscale.com/util/dnsname" ) @@ -25,13 +27,6 @@ const ( lowerPriority = int32(200) // lower than all builtin auto priorities ) -// reconfigTimeout is the time interval within which Manager.{Up,Down} should complete. -// -// This is particularly useful because certain conditions can cause indefinite hangs -// (such as improper dbus auth followed by contextless dbus.Object.Call). -// Such operations should be wrapped in a timeout context. -const reconfigTimeout = time.Second - // nmManager uses the NetworkManager DBus API. type nmManager struct { interfaceName string @@ -39,7 +34,13 @@ type nmManager struct { dnsManager dbus.BusObject } -func newNMManager(interfaceName string) (*nmManager, error) { +func init() { + optNewNMManager.Set(newNMManager) + optNMIsUsingResolved.Set(nmIsUsingResolved) + optNMVersionBetween.Set(nmVersionBetween) +} + +func newNMManager(interfaceName string) (OSConfigurator, error) { conn, err := dbus.SystemBus() if err != nil { return nil, err @@ -137,7 +138,7 @@ func (m *nmManager) trySet(ctx context.Context, config OSConfig) error { for _, ip := range config.Nameservers { b := ip.As16() if ip.Is4() { - dnsv4 = append(dnsv4, native.Endian.Uint32(b[12:])) + dnsv4 = append(dnsv4, binary.NativeEndian.Uint32(b[12:])) } else { dnsv6 = append(dnsv6, b[:]) } @@ -389,3 +390,47 @@ func (m *nmManager) Close() error { // settings when the tailscale interface goes away. return nil } + +func nmVersionBetween(first, last string) (bool, error) { + conn, err := dbus.SystemBus() + if err != nil { + // DBus probably not running. + return false, err + } + + nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager")) + v, err := nm.GetProperty("org.freedesktop.NetworkManager.Version") + if err != nil { + return false, err + } + + version, ok := v.Value().(string) + if !ok { + return false, fmt.Errorf("unexpected type %T for NM version", v.Value()) + } + + outside := cmpver.Compare(version, first) < 0 || cmpver.Compare(version, last) > 0 + return !outside, nil +} + +func nmIsUsingResolved() error { + conn, err := dbus.SystemBus() + if err != nil { + // DBus probably not running. + return err + } + + nm := conn.Object("org.freedesktop.NetworkManager", dbus.ObjectPath("/org/freedesktop/NetworkManager/DnsManager")) + v, err := nm.GetProperty("org.freedesktop.NetworkManager.DnsManager.Mode") + if err != nil { + return fmt.Errorf("getting NM mode: %w", err) + } + mode, ok := v.Value().(string) + if !ok { + return fmt.Errorf("unexpected type %T for NM DNS mode", v.Value()) + } + if mode != "systemd-resolved" { + return errors.New("NetworkManager is not using systemd-resolved for DNS") + } + return nil +} diff --git a/net/dns/openresolv.go b/net/dns/openresolv.go index 0b5c87a3b..c9562b6a9 100644 --- a/net/dns/openresolv.go +++ b/net/dns/openresolv.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || freebsd || openbsd +//go:build (linux && !android) || freebsd || openbsd package dns diff --git a/net/dns/osconfig.go b/net/dns/osconfig.go index 842c5ac60..af4c0f01f 100644 --- a/net/dns/osconfig.go +++ b/net/dns/osconfig.go @@ -11,6 +11,7 @@ import ( "slices" "strings" + "tailscale.com/feature/buildfeatures" "tailscale.com/types/logger" "tailscale.com/util/dnsname" ) @@ -158,6 +159,10 @@ func (a OSConfig) Equal(b OSConfig) bool { // Fixes https://github.com/tailscale/tailscale/issues/5669 func (a OSConfig) Format(f fmt.State, verb rune) { logger.ArgWriter(func(w *bufio.Writer) { + if !buildfeatures.HasDNS { + w.WriteString(`{DNS-unlinked}`) + return + } w.WriteString(`{Nameservers:[`) for i, ns := range a.Nameservers { if i != 0 { diff --git a/net/dns/publicdns/publicdns.go b/net/dns/publicdns/publicdns.go index 0dbd3ab82..b8a7f8809 100644 --- a/net/dns/publicdns/publicdns.go +++ b/net/dns/publicdns/publicdns.go @@ -17,6 +17,8 @@ import ( "strconv" "strings" "sync" + + "tailscale.com/feature/buildfeatures" ) // dohOfIP maps from public DNS IPs to their DoH base URL. @@ -163,6 +165,9 @@ const ( // populate is called once to initialize the knownDoH and dohIPsOfBase maps. func populate() { + if !buildfeatures.HasDNS { + return + } // Cloudflare // https://developers.cloudflare.com/1.1.1.1/ip-addresses/ addDoH("1.1.1.1", "https://cloudflare-dns.com/dns-query") diff --git a/net/dns/recursive/recursive.go b/net/dns/recursive/recursive.go deleted file mode 100644 index eb23004d8..000000000 --- a/net/dns/recursive/recursive.go +++ /dev/null @@ -1,621 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package recursive implements a simple recursive DNS resolver. -package recursive - -import ( - "context" - "errors" - "fmt" - "net" - "net/netip" - "slices" - "strings" - "time" - - "github.com/miekg/dns" - "tailscale.com/envknob" - "tailscale.com/net/netns" - "tailscale.com/types/logger" - "tailscale.com/util/dnsname" - "tailscale.com/util/mak" - "tailscale.com/util/multierr" - "tailscale.com/util/slicesx" -) - -const ( - // maxDepth is how deep from the root nameservers we'll recurse when - // resolving; passing this limit will instead return an error. - // - // maxDepth must be at least 20 to resolve "console.aws.amazon.com", - // which is a domain with a moderately complicated DNS setup. The - // current value of 30 was chosen semi-arbitrarily to ensure that we - // have about 50% headroom. - maxDepth = 30 - // numStartingServers is the number of root nameservers that we use as - // initial candidates for our recursion. - numStartingServers = 3 - // udpQueryTimeout is the amount of time we wait for a UDP response - // from a nameserver before falling back to a TCP connection. - udpQueryTimeout = 5 * time.Second - - // These constants aren't typed in the DNS package, so we create typed - // versions here to avoid having to do repeated type casts. - qtypeA dns.Type = dns.Type(dns.TypeA) - qtypeAAAA dns.Type = dns.Type(dns.TypeAAAA) -) - -var ( - // ErrMaxDepth is returned when recursive resolving exceeds the maximum - // depth limit for this package. - ErrMaxDepth = fmt.Errorf("exceeded max depth %d when resolving", maxDepth) - - // ErrAuthoritativeNoResponses is the error returned when an - // authoritative nameserver indicates that there are no responses to - // the given query. - ErrAuthoritativeNoResponses = errors.New("authoritative server returned no responses") - - // ErrNoResponses is returned when our resolution process completes - // with no valid responses from any nameserver, but no authoritative - // server explicitly returned NXDOMAIN. - ErrNoResponses = errors.New("no responses to query") -) - -var rootServersV4 = []netip.Addr{ - netip.MustParseAddr("198.41.0.4"), // a.root-servers.net - netip.MustParseAddr("170.247.170.2"), // b.root-servers.net - netip.MustParseAddr("192.33.4.12"), // c.root-servers.net - netip.MustParseAddr("199.7.91.13"), // d.root-servers.net - netip.MustParseAddr("192.203.230.10"), // e.root-servers.net - netip.MustParseAddr("192.5.5.241"), // f.root-servers.net - netip.MustParseAddr("192.112.36.4"), // g.root-servers.net - netip.MustParseAddr("198.97.190.53"), // h.root-servers.net - netip.MustParseAddr("192.36.148.17"), // i.root-servers.net - netip.MustParseAddr("192.58.128.30"), // j.root-servers.net - netip.MustParseAddr("193.0.14.129"), // k.root-servers.net - netip.MustParseAddr("199.7.83.42"), // l.root-servers.net - netip.MustParseAddr("202.12.27.33"), // m.root-servers.net -} - -var rootServersV6 = []netip.Addr{ - netip.MustParseAddr("2001:503:ba3e::2:30"), // a.root-servers.net - netip.MustParseAddr("2801:1b8:10::b"), // b.root-servers.net - netip.MustParseAddr("2001:500:2::c"), // c.root-servers.net - netip.MustParseAddr("2001:500:2d::d"), // d.root-servers.net - netip.MustParseAddr("2001:500:a8::e"), // e.root-servers.net - netip.MustParseAddr("2001:500:2f::f"), // f.root-servers.net - netip.MustParseAddr("2001:500:12::d0d"), // g.root-servers.net - netip.MustParseAddr("2001:500:1::53"), // h.root-servers.net - netip.MustParseAddr("2001:7fe::53"), // i.root-servers.net - netip.MustParseAddr("2001:503:c27::2:30"), // j.root-servers.net - netip.MustParseAddr("2001:7fd::1"), // k.root-servers.net - netip.MustParseAddr("2001:500:9f::42"), // l.root-servers.net - netip.MustParseAddr("2001:dc3::35"), // m.root-servers.net -} - -var debug = envknob.RegisterBool("TS_DEBUG_RECURSIVE_DNS") - -// Resolver is a recursive DNS resolver that is designed for looking up A and AAAA records. -type Resolver struct { - // Dialer is used to create outbound connections. If nil, a zero - // net.Dialer will be used instead. - Dialer netns.Dialer - - // Logf is the logging function to use; if none is specified, then logs - // will be dropped. - Logf logger.Logf - - // NoIPv6, if set, will prevent this package from querying for AAAA - // records and will avoid contacting nameservers over IPv6. - NoIPv6 bool - - // Test mocks - testQueryHook func(name dnsname.FQDN, nameserver netip.Addr, protocol string, qtype dns.Type) (*dns.Msg, error) - testExchangeHook func(nameserver netip.Addr, network string, msg *dns.Msg) (*dns.Msg, error) - rootServers []netip.Addr - timeNow func() time.Time - - // Caching - // NOTE(andrew): if we make resolution parallel, this needs a mutex - queryCache map[dnsQuery]dnsMsgWithExpiry - - // Possible future additions: - // - Additional nameservers? From the system maybe? - // - NoIPv4 for IPv4 - // - DNS-over-HTTPS or DNS-over-TLS support -} - -// queryState stores all state during the course of a single query -type queryState struct { - // rootServers are the root nameservers to start from - rootServers []netip.Addr - - // TODO: metrics? -} - -type dnsQuery struct { - nameserver netip.Addr - name dnsname.FQDN - qtype dns.Type -} - -func (q dnsQuery) String() string { - return fmt.Sprintf("dnsQuery{nameserver:%q,name:%q,qtype:%v}", q.nameserver.String(), q.name, q.qtype) -} - -type dnsMsgWithExpiry struct { - *dns.Msg - expiresAt time.Time -} - -func (r *Resolver) now() time.Time { - if r.timeNow != nil { - return r.timeNow() - } - return time.Now() -} - -func (r *Resolver) logf(format string, args ...any) { - if r.Logf == nil { - return - } - r.Logf(format, args...) -} - -func (r *Resolver) depthlogf(depth int, format string, args ...any) { - if r.Logf == nil || !debug() { - return - } - prefix := fmt.Sprintf("[%d] %s", depth, strings.Repeat(" ", depth)) - r.Logf(prefix+format, args...) -} - -var defaultDialer net.Dialer - -func (r *Resolver) dialer() netns.Dialer { - if r.Dialer != nil { - return r.Dialer - } - - return &defaultDialer -} - -func (r *Resolver) newState() *queryState { - var rootServers []netip.Addr - if len(r.rootServers) > 0 { - rootServers = r.rootServers - } else { - // Select a random subset of root nameservers to start from, since if - // we don't get responses from those, something else has probably gone - // horribly wrong. - roots4 := slices.Clone(rootServersV4) - slicesx.Shuffle(roots4) - roots4 = roots4[:numStartingServers] - - var roots6 []netip.Addr - if !r.NoIPv6 { - roots6 = slices.Clone(rootServersV6) - slicesx.Shuffle(roots6) - roots6 = roots6[:numStartingServers] - } - - // Interleave the root servers so that we try to contact them over - // IPv4, then IPv6, IPv4, IPv6, etc. - rootServers = slicesx.Interleave(roots4, roots6) - } - - return &queryState{ - rootServers: rootServers, - } -} - -// Resolve will perform a recursive DNS resolution for the provided name, -// starting at a randomly-chosen root DNS server, and return the A and AAAA -// responses as a slice of netip.Addrs along with the minimum TTL for the -// returned records. -func (r *Resolver) Resolve(ctx context.Context, name string) (addrs []netip.Addr, minTTL time.Duration, err error) { - dnsName, err := dnsname.ToFQDN(name) - if err != nil { - return nil, 0, err - } - - qstate := r.newState() - - r.logf("querying IPv4 addresses for: %q", name) - addrs4, minTTL4, err4 := r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeA) - - var ( - addrs6 []netip.Addr - minTTL6 time.Duration - err6 error - ) - if !r.NoIPv6 { - r.logf("querying IPv6 addresses for: %q", name) - addrs6, minTTL6, err6 = r.resolveRecursiveFromRoot(ctx, qstate, 0, dnsName, qtypeAAAA) - } - - if err4 != nil && err6 != nil { - if err4 == err6 { - return nil, 0, err4 - } - - return nil, 0, multierr.New(err4, err6) - } - if err4 != nil { - return addrs6, minTTL6, nil - } else if err6 != nil { - return addrs4, minTTL4, nil - } - - minTTL = minTTL4 - if minTTL6 < minTTL { - minTTL = minTTL6 - } - - addrs = append(addrs4, addrs6...) - if len(addrs) == 0 { - return nil, 0, ErrNoResponses - } - - slicesx.Shuffle(addrs) - return addrs, minTTL, nil -} - -func (r *Resolver) resolveRecursiveFromRoot( - ctx context.Context, - qstate *queryState, - depth int, - name dnsname.FQDN, // what we're querying - qtype dns.Type, -) ([]netip.Addr, time.Duration, error) { - r.depthlogf(depth, "resolving %q from root (type: %v)", name, qtype) - - var depthError bool - for _, server := range qstate.rootServers { - addrs, minTTL, err := r.resolveRecursive(ctx, qstate, depth, name, server, qtype) - if err == nil { - return addrs, minTTL, err - } else if errors.Is(err, ErrAuthoritativeNoResponses) { - return nil, 0, ErrAuthoritativeNoResponses - } else if errors.Is(err, ErrMaxDepth) { - depthError = true - } - } - - if depthError { - return nil, 0, ErrMaxDepth - } - return nil, 0, ErrNoResponses -} - -func (r *Resolver) resolveRecursive( - ctx context.Context, - qstate *queryState, - depth int, - name dnsname.FQDN, // what we're querying - nameserver netip.Addr, - qtype dns.Type, -) ([]netip.Addr, time.Duration, error) { - if depth == maxDepth { - r.depthlogf(depth, "not recursing past maximum depth") - return nil, 0, ErrMaxDepth - } - - // Ask this nameserver for an answer. - resp, err := r.queryNameserver(ctx, depth, name, nameserver, qtype) - if err != nil { - return nil, 0, err - } - - // If we get an actual answer from the nameserver, then return it. - var ( - answers []netip.Addr - cnames []dnsname.FQDN - minTTL = 24 * 60 * 60 // 24 hours in seconds - ) - for _, answer := range resp.Answer { - if crec, ok := answer.(*dns.CNAME); ok { - cnameFQDN, err := dnsname.ToFQDN(crec.Target) - if err != nil { - r.logf("bad CNAME %q returned: %v", crec.Target, err) - continue - } - - cnames = append(cnames, cnameFQDN) - continue - } - - addr := addrFromRecord(answer) - if !addr.IsValid() { - r.logf("[unexpected] invalid record in %T answer", answer) - } else if addr.Is4() && qtype != qtypeA { - r.logf("[unexpected] got IPv4 answer but qtype=%v", qtype) - } else if addr.Is6() && qtype != qtypeAAAA { - r.logf("[unexpected] got IPv6 answer but qtype=%v", qtype) - } else { - answers = append(answers, addr) - minTTL = min(minTTL, int(answer.Header().Ttl)) - } - } - - if len(answers) > 0 { - r.depthlogf(depth, "got answers for %q: %v", name, answers) - return answers, time.Duration(minTTL) * time.Second, nil - } - - r.depthlogf(depth, "no answers for %q", name) - - // If we have a non-zero number of CNAMEs, then try resolving those - // (from the root again) and return the first one that succeeds. - // - // TODO: return the union of all responses? - // TODO: parallelism? - if len(cnames) > 0 { - r.depthlogf(depth, "got CNAME responses for %q: %v", name, cnames) - } - var cnameDepthError bool - for _, cname := range cnames { - answers, minTTL, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, cname, qtype) - if err == nil { - return answers, minTTL, nil - } else if errors.Is(err, ErrAuthoritativeNoResponses) { - return nil, 0, ErrAuthoritativeNoResponses - } else if errors.Is(err, ErrMaxDepth) { - cnameDepthError = true - } - } - - // If this is an authoritative response, then we know that continuing - // to look further is not going to result in any answers and we should - // bail out. - if resp.MsgHdr.Authoritative { - // If we failed to recurse into a CNAME due to a depth limit, - // propagate that here. - if cnameDepthError { - return nil, 0, ErrMaxDepth - } - - r.depthlogf(depth, "got authoritative response with no answers; stopping") - return nil, 0, ErrAuthoritativeNoResponses - } - - r.depthlogf(depth, "got %d NS responses and %d ADDITIONAL responses for %q", len(resp.Ns), len(resp.Extra), name) - - // No CNAMEs and no answers; see if we got any AUTHORITY responses, - // which indicate which nameservers to query next. - var authorities []dnsname.FQDN - for _, rr := range resp.Ns { - ns, ok := rr.(*dns.NS) - if !ok { - continue - } - - nsName, err := dnsname.ToFQDN(ns.Ns) - if err != nil { - r.logf("unexpected bad NS name %q: %v", ns.Ns, err) - continue - } - - authorities = append(authorities, nsName) - } - - // Also check for "glue" records, which are IP addresses provided by - // the DNS server for authority responses; these are required when the - // authority server is a subdomain of what's being resolved. - glueRecords := make(map[dnsname.FQDN][]netip.Addr) - for _, rr := range resp.Extra { - name, err := dnsname.ToFQDN(rr.Header().Name) - if err != nil { - r.logf("unexpected bad Name %q in Extra addr: %v", rr.Header().Name, err) - continue - } - - if addr := addrFromRecord(rr); addr.IsValid() { - glueRecords[name] = append(glueRecords[name], addr) - } else { - r.logf("unexpected bad Extra %T addr", rr) - } - } - - // Try authorities with glue records first, to minimize the number of - // additional DNS queries that we need to make. - authoritiesGlue, authoritiesNoGlue := slicesx.Partition(authorities, func(aa dnsname.FQDN) bool { - return len(glueRecords[aa]) > 0 - }) - - authorityDepthError := false - - r.depthlogf(depth, "authorities with glue records for recursion: %v", authoritiesGlue) - for _, authority := range authoritiesGlue { - for _, nameserver := range glueRecords[authority] { - answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) - if err == nil { - return answers, minTTL, nil - } else if errors.Is(err, ErrAuthoritativeNoResponses) { - return nil, 0, ErrAuthoritativeNoResponses - } else if errors.Is(err, ErrMaxDepth) { - authorityDepthError = true - } - } - } - - r.depthlogf(depth, "authorities with no glue records for recursion: %v", authoritiesNoGlue) - for _, authority := range authoritiesNoGlue { - // First, resolve the IP for the authority server from the - // root, querying for both IPv4 and IPv6 addresses regardless - // of what the current question type is. - // - // TODO: check for infinite recursion; it'll get caught by our - // recursion depth, but we want to bail early. - for _, authorityQtype := range []dns.Type{qtypeAAAA, qtypeA} { - answers, _, err := r.resolveRecursiveFromRoot(ctx, qstate, depth+1, authority, authorityQtype) - if err != nil { - r.depthlogf(depth, "error querying authority %q: %v", authority, err) - continue - } - r.depthlogf(depth, "resolved authority %q (type %v) to: %v", authority, authorityQtype, answers) - - // Now, query this authority for the final address. - for _, nameserver := range answers { - answers, minTTL, err := r.resolveRecursive(ctx, qstate, depth+1, name, nameserver, qtype) - if err == nil { - return answers, minTTL, nil - } else if errors.Is(err, ErrAuthoritativeNoResponses) { - return nil, 0, ErrAuthoritativeNoResponses - } else if errors.Is(err, ErrMaxDepth) { - authorityDepthError = true - } - } - } - } - - if authorityDepthError { - return nil, 0, ErrMaxDepth - } - return nil, 0, ErrNoResponses -} - -// queryNameserver sends a query for "name" to the nameserver "nameserver" for -// records of type "qtype", trying both UDP and TCP connections as -// appropriate. -func (r *Resolver) queryNameserver( - ctx context.Context, - depth int, - name dnsname.FQDN, // what we're querying - nameserver netip.Addr, // destination of query - qtype dns.Type, -) (*dns.Msg, error) { - // TODO(andrew): we should QNAME minimisation here to avoid sending the - // full name to intermediate/root nameservers. See: - // https://www.rfc-editor.org/rfc/rfc7816 - - // Handle the case where UDP is blocked by adding an explicit timeout - // for the UDP portion of this query. - udpCtx, udpCtxCancel := context.WithTimeout(ctx, udpQueryTimeout) - defer udpCtxCancel() - - msg, err := r.queryNameserverProto(udpCtx, depth, name, nameserver, "udp", qtype) - if err == nil { - return msg, nil - } - - msg, err2 := r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) - if err2 == nil { - return msg, nil - } - - return nil, multierr.New(err, err2) -} - -// queryNameserverProto sends a query for "name" to the nameserver "nameserver" -// for records of type "qtype" over the provided protocol (either "udp" -// or "tcp"), and returns the DNS response or an error. -func (r *Resolver) queryNameserverProto( - ctx context.Context, - depth int, - name dnsname.FQDN, // what we're querying - nameserver netip.Addr, // destination of query - protocol string, - qtype dns.Type, -) (resp *dns.Msg, err error) { - if r.testQueryHook != nil { - return r.testQueryHook(name, nameserver, protocol, qtype) - } - - now := r.now() - nameserverStr := nameserver.String() - - cacheKey := dnsQuery{ - nameserver: nameserver, - name: name, - qtype: qtype, - } - cacheEntry, ok := r.queryCache[cacheKey] - if ok && cacheEntry.expiresAt.Before(now) { - r.depthlogf(depth, "using cached response from %s about %q (type: %v)", nameserverStr, name, qtype) - return cacheEntry.Msg, nil - } - - var network string - if nameserver.Is4() { - network = protocol + "4" - } else { - network = protocol + "6" - } - - // Prepare a message asking for an appropriately-typed record - // for the name we're querying. - m := new(dns.Msg) - m.SetQuestion(name.WithTrailingDot(), uint16(qtype)) - - // Allow mocking out the network components with our exchange hook. - if r.testExchangeHook != nil { - resp, err = r.testExchangeHook(nameserver, network, m) - } else { - // Dial the current nameserver using our dialer. - var nconn net.Conn - nconn, err = r.dialer().DialContext(ctx, network, net.JoinHostPort(nameserverStr, "53")) - if err != nil { - return nil, err - } - - var c dns.Client // TODO: share? - conn := &dns.Conn{ - Conn: nconn, - UDPSize: c.UDPSize, - } - - // Send the DNS request to the current nameserver. - r.depthlogf(depth, "asking %s over %s about %q (type: %v)", nameserverStr, protocol, name, qtype) - resp, _, err = c.ExchangeWithConnContext(ctx, m, conn) - } - if err != nil { - return nil, err - } - - // If the message was truncated and we're using UDP, re-run with TCP. - if resp.MsgHdr.Truncated && protocol == "udp" { - r.depthlogf(depth, "response message truncated; re-running query with TCP") - resp, err = r.queryNameserverProto(ctx, depth, name, nameserver, "tcp", qtype) - if err != nil { - return nil, err - } - } - - // Find minimum expiry for all records in this message. - var minTTL int - for _, rr := range resp.Answer { - minTTL = min(minTTL, int(rr.Header().Ttl)) - } - for _, rr := range resp.Ns { - minTTL = min(minTTL, int(rr.Header().Ttl)) - } - for _, rr := range resp.Extra { - minTTL = min(minTTL, int(rr.Header().Ttl)) - } - - mak.Set(&r.queryCache, cacheKey, dnsMsgWithExpiry{ - Msg: resp, - expiresAt: now.Add(time.Duration(minTTL) * time.Second), - }) - return resp, nil -} - -func addrFromRecord(rr dns.RR) netip.Addr { - switch v := rr.(type) { - case *dns.A: - ip, ok := netip.AddrFromSlice(v.A) - if !ok || !ip.Is4() { - return netip.Addr{} - } - return ip - case *dns.AAAA: - ip, ok := netip.AddrFromSlice(v.AAAA) - if !ok || !ip.Is6() { - return netip.Addr{} - } - return ip - } - return netip.Addr{} -} diff --git a/net/dns/recursive/recursive_test.go b/net/dns/recursive/recursive_test.go deleted file mode 100644 index d47e4cebf..000000000 --- a/net/dns/recursive/recursive_test.go +++ /dev/null @@ -1,742 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package recursive - -import ( - "context" - "errors" - "flag" - "fmt" - "net" - "net/netip" - "reflect" - "strings" - "testing" - "time" - - "slices" - - "github.com/miekg/dns" - "tailscale.com/envknob" - "tailscale.com/tstest" -) - -const testDomain = "tailscale.com" - -// Recursively resolving the AWS console requires being able to handle CNAMEs, -// glue records, falling back from UDP to TCP for oversize queries, and more; -// it's a great integration test for DNS resolution and they can handle the -// traffic :) -const complicatedTestDomain = "console.aws.amazon.com" - -var flagNetworkAccess = flag.Bool("enable-network-access", false, "run tests that need external network access") - -func init() { - envknob.Setenv("TS_DEBUG_RECURSIVE_DNS", "true") -} - -func newResolver(tb testing.TB) *Resolver { - clock := tstest.NewClock(tstest.ClockOpts{ - Step: 50 * time.Millisecond, - }) - return &Resolver{ - Logf: tb.Logf, - timeNow: clock.Now, - } -} - -func TestResolve(t *testing.T) { - if !*flagNetworkAccess { - t.SkipNow() - } - - ctx := context.Background() - r := newResolver(t) - addrs, minTTL, err := r.Resolve(ctx, testDomain) - if err != nil { - t.Fatal(err) - } - - t.Logf("addrs: %+v", addrs) - t.Logf("minTTL: %v", minTTL) - if len(addrs) < 1 { - t.Fatalf("expected at least one address") - } - - if minTTL <= 10*time.Second || minTTL >= 24*time.Hour { - t.Errorf("invalid minimum TTL: %v", minTTL) - } - - var has4, has6 bool - for _, addr := range addrs { - has4 = has4 || addr.Is4() - has6 = has6 || addr.Is6() - } - - if !has4 { - t.Errorf("expected at least one IPv4 address") - } - if !has6 { - t.Errorf("expected at least one IPv6 address") - } -} - -func TestResolveComplicated(t *testing.T) { - if !*flagNetworkAccess { - t.SkipNow() - } - - ctx := context.Background() - r := newResolver(t) - addrs, minTTL, err := r.Resolve(ctx, complicatedTestDomain) - if err != nil { - t.Fatal(err) - } - - t.Logf("addrs: %+v", addrs) - t.Logf("minTTL: %v", minTTL) - if len(addrs) < 1 { - t.Fatalf("expected at least one address") - } - - if minTTL <= 10*time.Second || minTTL >= 24*time.Hour { - t.Errorf("invalid minimum TTL: %v", minTTL) - } -} - -func TestResolveNoIPv6(t *testing.T) { - if !*flagNetworkAccess { - t.SkipNow() - } - - r := newResolver(t) - r.NoIPv6 = true - - addrs, _, err := r.Resolve(context.Background(), testDomain) - if err != nil { - t.Fatal(err) - } - - t.Logf("addrs: %+v", addrs) - if len(addrs) < 1 { - t.Fatalf("expected at least one address") - } - - for _, addr := range addrs { - if addr.Is6() { - t.Errorf("got unexpected IPv6 address: %v", addr) - } - } -} - -func TestResolveFallbackToTCP(t *testing.T) { - var udpCalls, tcpCalls int - hook := func(nameserver netip.Addr, network string, req *dns.Msg) (*dns.Msg, error) { - if strings.HasPrefix(network, "udp") { - t.Logf("got %q query; returning truncated result", network) - udpCalls++ - resp := &dns.Msg{} - resp.SetReply(req) - resp.Truncated = true - return resp, nil - } - - t.Logf("got %q query; returning real result", network) - tcpCalls++ - resp := &dns.Msg{} - resp.SetReply(req) - resp.Answer = append(resp.Answer, &dns.A{ - Hdr: dns.RR_Header{ - Name: req.Question[0].Name, - Rrtype: req.Question[0].Qtype, - Class: dns.ClassINET, - Ttl: 300, - }, - A: net.IPv4(1, 2, 3, 4), - }) - return resp, nil - } - - r := newResolver(t) - r.testExchangeHook = hook - - ctx := context.Background() - resp, err := r.queryNameserverProto(ctx, 0, "tailscale.com", netip.MustParseAddr("9.9.9.9"), "udp", dns.Type(dns.TypeA)) - if err != nil { - t.Fatal(err) - } - - if len(resp.Answer) < 1 { - t.Fatalf("no answers in response: %v", resp) - } - rrA, ok := resp.Answer[0].(*dns.A) - if !ok { - t.Fatalf("invalid RR type: %T", resp.Answer[0]) - } - if !rrA.A.Equal(net.IPv4(1, 2, 3, 4)) { - t.Errorf("wanted A response 1.2.3.4, got: %v", rrA.A) - } - if tcpCalls != 1 { - t.Errorf("got %d, want 1 TCP calls", tcpCalls) - } - if udpCalls != 1 { - t.Errorf("got %d, want 1 UDP calls", udpCalls) - } - - // Verify that we're cached and re-run to fetch from the cache. - if len(r.queryCache) < 1 { - t.Errorf("wanted entries in the query cache") - } - - resp2, err := r.queryNameserverProto(ctx, 0, "tailscale.com", netip.MustParseAddr("9.9.9.9"), "udp", dns.Type(dns.TypeA)) - if err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(resp, resp2) { - t.Errorf("expected equal responses; old=%+v new=%+v", resp, resp2) - } - - // We didn't make any more network requests since we loaded from the cache. - if tcpCalls != 1 { - t.Errorf("got %d, want 1 TCP calls", tcpCalls) - } - if udpCalls != 1 { - t.Errorf("got %d, want 1 UDP calls", udpCalls) - } -} - -func dnsIPRR(name string, addr netip.Addr) dns.RR { - if addr.Is4() { - return &dns.A{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: net.IP(addr.AsSlice()), - } - } - - return &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: net.IP(addr.AsSlice()), - } -} - -func cnameRR(name, target string) dns.RR { - return &dns.CNAME{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeCNAME, - Class: dns.ClassINET, - Ttl: 300, - }, - Target: target, - } -} - -func nsRR(name, target string) dns.RR { - return &dns.NS{ - Hdr: dns.RR_Header{ - Name: name, - Rrtype: dns.TypeNS, - Class: dns.ClassINET, - Ttl: 300, - }, - Ns: target, - } -} - -type mockReply struct { - name string - qtype dns.Type - resp *dns.Msg -} - -type replyMock struct { - tb testing.TB - replies map[netip.Addr][]mockReply -} - -func (r *replyMock) exchangeHook(nameserver netip.Addr, network string, req *dns.Msg) (*dns.Msg, error) { - if len(req.Question) != 1 { - r.tb.Fatalf("unsupported multiple or empty question: %v", req.Question) - } - question := req.Question[0] - - replies := r.replies[nameserver] - if len(replies) == 0 { - r.tb.Fatalf("no configured replies for nameserver: %v", nameserver) - } - - for _, reply := range replies { - if reply.name == question.Name && reply.qtype == dns.Type(question.Qtype) { - return reply.resp.Copy(), nil - } - } - - r.tb.Fatalf("no replies found for query %q of type %v to %v", question.Name, question.Qtype, nameserver) - panic("unreachable") -} - -// responses for mocking, shared between the following tests -var ( - rootServerAddr = netip.MustParseAddr("198.41.0.4") // a.root-servers.net. - comNSAddr = netip.MustParseAddr("192.5.6.30") // a.gtld-servers.net. - - // DNS response from the root nameservers for a .com nameserver - comRecord = &dns.Msg{ - Ns: []dns.RR{nsRR("com.", "a.gtld-servers.net.")}, - Extra: []dns.RR{dnsIPRR("a.gtld-servers.net.", comNSAddr)}, - } - - // Random Amazon nameservers that we use in glue records - amazonNS = netip.MustParseAddr("205.251.192.197") - amazonNSv6 = netip.MustParseAddr("2600:9000:5306:1600::1") - - // Nameservers for the tailscale.com domain - tailscaleNameservers = &dns.Msg{ - Ns: []dns.RR{ - nsRR("tailscale.com.", "ns-197.awsdns-24.com."), - nsRR("tailscale.com.", "ns-557.awsdns-05.net."), - nsRR("tailscale.com.", "ns-1558.awsdns-02.co.uk."), - nsRR("tailscale.com.", "ns-1359.awsdns-41.org."), - }, - Extra: []dns.RR{ - dnsIPRR("ns-197.awsdns-24.com.", amazonNS), - }, - } -) - -func TestBasicRecursion(t *testing.T) { - mock := &replyMock{ - tb: t, - replies: map[netip.Addr][]mockReply{ - // Query to the root server returns the .com server + a glue record - rootServerAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - }, - - // Query to the ".com" server return the nameservers for tailscale.com - comNSAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - }, - - // Query to the actual nameserver works. - amazonNS: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{ - dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131")), - dnsIPRR("tailscale.com.", netip.MustParseAddr("76.223.15.28")), - }, - }}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{ - dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b")), - dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5")), - }, - }}, - }, - }, - } - - r := newResolver(t) - r.testExchangeHook = mock.exchangeHook - r.rootServers = []netip.Addr{rootServerAddr} - - // Query for tailscale.com, verify we get the right responses - ctx := context.Background() - addrs, minTTL, err := r.Resolve(ctx, "tailscale.com") - if err != nil { - t.Fatal(err) - } - wantAddrs := []netip.Addr{ - netip.MustParseAddr("13.248.141.131"), - netip.MustParseAddr("76.223.15.28"), - netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"), - netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5"), - } - slices.SortFunc(addrs, func(x, y netip.Addr) int { return strings.Compare(x.String(), y.String()) }) - slices.SortFunc(wantAddrs, func(x, y netip.Addr) int { return strings.Compare(x.String(), y.String()) }) - - if !reflect.DeepEqual(addrs, wantAddrs) { - t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs) - } - - const wantMinTTL = 5 * time.Minute - if minTTL != wantMinTTL { - t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL) - } -} - -func TestNoAnswers(t *testing.T) { - mock := &replyMock{ - tb: t, - replies: map[netip.Addr][]mockReply{ - // Query to the root server returns the .com server + a glue record - rootServerAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - }, - - // Query to the ".com" server return the nameservers for tailscale.com - comNSAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - }, - - // Query to the actual nameserver returns no responses, authoritatively. - amazonNS: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{}, - }}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{}, - }}, - }, - }, - } - - r := &Resolver{ - Logf: t.Logf, - testExchangeHook: mock.exchangeHook, - rootServers: []netip.Addr{rootServerAddr}, - } - - // Query for tailscale.com, verify we get the right responses - _, _, err := r.Resolve(context.Background(), "tailscale.com") - if err == nil { - t.Fatalf("got no error, want error") - } - if !errors.Is(err, ErrAuthoritativeNoResponses) { - t.Fatalf("got err=%v, want %v", err, ErrAuthoritativeNoResponses) - } -} - -func TestRecursionCNAME(t *testing.T) { - mock := &replyMock{ - tb: t, - replies: map[netip.Addr][]mockReply{ - // Query to the root server returns the .com server + a glue record - rootServerAddr: { - {name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, - {name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - - {name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, - {name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - }, - - // Query to the ".com" server return the nameservers for tailscale.com - comNSAddr: { - {name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - {name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - - {name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - {name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - }, - - // Query to the actual nameserver works. - amazonNS: { - {name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{cnameRR("subdomain.otherdomain.com.", "subdomain.tailscale.com.")}, - }}, - {name: "subdomain.otherdomain.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{cnameRR("subdomain.otherdomain.com.", "subdomain.tailscale.com.")}, - }}, - - {name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131"))}, - }}, - {name: "subdomain.tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"))}, - }}, - }, - }, - } - - r := &Resolver{ - Logf: t.Logf, - testExchangeHook: mock.exchangeHook, - rootServers: []netip.Addr{rootServerAddr}, - } - - // Query for tailscale.com, verify we get the right responses - addrs, minTTL, err := r.Resolve(context.Background(), "subdomain.otherdomain.com") - if err != nil { - t.Fatal(err) - } - wantAddrs := []netip.Addr{ - netip.MustParseAddr("13.248.141.131"), - netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"), - } - slices.SortFunc(addrs, func(x, y netip.Addr) int { return strings.Compare(x.String(), y.String()) }) - slices.SortFunc(wantAddrs, func(x, y netip.Addr) int { return strings.Compare(x.String(), y.String()) }) - - if !reflect.DeepEqual(addrs, wantAddrs) { - t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs) - } - - const wantMinTTL = 5 * time.Minute - if minTTL != wantMinTTL { - t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL) - } -} - -func TestRecursionNoGlue(t *testing.T) { - coukNS := netip.MustParseAddr("213.248.216.1") - coukRecord := &dns.Msg{ - Ns: []dns.RR{nsRR("com.", "dns1.nic.uk.")}, - Extra: []dns.RR{dnsIPRR("dns1.nic.uk.", coukNS)}, - } - - intermediateNS := netip.MustParseAddr("205.251.193.66") // g-ns-322.awsdns-02.co.uk. - intermediateRecord := &dns.Msg{ - Ns: []dns.RR{nsRR("awsdns-02.co.uk.", "g-ns-322.awsdns-02.co.uk.")}, - Extra: []dns.RR{dnsIPRR("g-ns-322.awsdns-02.co.uk.", intermediateNS)}, - } - - const amazonNameserver = "ns-1558.awsdns-02.co.uk." - tailscaleNameservers := &dns.Msg{ - Ns: []dns.RR{ - nsRR("tailscale.com.", amazonNameserver), - }, - } - - tailscaleResponses := []mockReply{ - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("13.248.141.131"))}, - }}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{dnsIPRR("tailscale.com.", netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"))}, - }}, - } - - mock := &replyMock{ - tb: t, - replies: map[netip.Addr][]mockReply{ - rootServerAddr: { - // Query to the root server returns the .com server + a glue record - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - - // Querying the .co.uk nameserver returns the .co.uk nameserver + a glue record. - {name: amazonNameserver, qtype: dns.Type(dns.TypeA), resp: coukRecord}, - {name: amazonNameserver, qtype: dns.Type(dns.TypeAAAA), resp: coukRecord}, - }, - - // Queries to the ".com" server return the nameservers - // for tailscale.com, which don't contain a glue - // record. - comNSAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - }, - - // Queries to the ".co.uk" nameserver returns the - // address of the intermediate Amazon nameserver. - coukNS: { - {name: amazonNameserver, qtype: dns.Type(dns.TypeA), resp: intermediateRecord}, - {name: amazonNameserver, qtype: dns.Type(dns.TypeAAAA), resp: intermediateRecord}, - }, - - // Queries to the intermediate nameserver returns an - // answer for the final Amazon nameserver. - intermediateNS: { - {name: amazonNameserver, qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{dnsIPRR(amazonNameserver, amazonNS)}, - }}, - {name: amazonNameserver, qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{dnsIPRR(amazonNameserver, amazonNSv6)}, - }}, - }, - - // Queries to the actual nameserver work and return - // responses to the query. - amazonNS: tailscaleResponses, - amazonNSv6: tailscaleResponses, - }, - } - - r := newResolver(t) - r.testExchangeHook = mock.exchangeHook - r.rootServers = []netip.Addr{rootServerAddr} - - // Query for tailscale.com, verify we get the right responses - addrs, minTTL, err := r.Resolve(context.Background(), "tailscale.com") - if err != nil { - t.Fatal(err) - } - wantAddrs := []netip.Addr{ - netip.MustParseAddr("13.248.141.131"), - netip.MustParseAddr("2600:9000:a602:b1e6:86d:8165:5e8c:295b"), - } - slices.SortFunc(addrs, func(x, y netip.Addr) int { return strings.Compare(x.String(), y.String()) }) - slices.SortFunc(wantAddrs, func(x, y netip.Addr) int { return strings.Compare(x.String(), y.String()) }) - - if !reflect.DeepEqual(addrs, wantAddrs) { - t.Errorf("got addrs=%+v; want %+v", addrs, wantAddrs) - } - - const wantMinTTL = 5 * time.Minute - if minTTL != wantMinTTL { - t.Errorf("got minTTL=%+v; want %+v", minTTL, wantMinTTL) - } -} - -func TestRecursionLimit(t *testing.T) { - mock := &replyMock{ - tb: t, - replies: map[netip.Addr][]mockReply{}, - } - - // Fill out a CNAME chain equal to our recursion limit; we won't get - // this far since each CNAME is more than 1 level "deep", but this - // ensures that we have more than the limit. - for i := range maxDepth + 1 { - curr := fmt.Sprintf("%d-tailscale.com.", i) - - tailscaleNameservers := &dns.Msg{ - Ns: []dns.RR{nsRR(curr, "ns-197.awsdns-24.com.")}, - Extra: []dns.RR{dnsIPRR("ns-197.awsdns-24.com.", amazonNS)}, - } - - // Query to the root server returns the .com server + a glue record - mock.replies[rootServerAddr] = append(mock.replies[rootServerAddr], - mockReply{name: curr, qtype: dns.Type(dns.TypeA), resp: comRecord}, - mockReply{name: curr, qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - ) - - // Query to the ".com" server return the nameservers for NN-tailscale.com - mock.replies[comNSAddr] = append(mock.replies[comNSAddr], - mockReply{name: curr, qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - mockReply{name: curr, qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - ) - - // Queries to the nameserver return a CNAME for the n+1th server. - next := fmt.Sprintf("%d-tailscale.com.", i+1) - mock.replies[amazonNS] = append(mock.replies[amazonNS], - mockReply{ - name: curr, - qtype: dns.Type(dns.TypeA), - resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{cnameRR(curr, next)}, - }, - }, - mockReply{ - name: curr, - qtype: dns.Type(dns.TypeAAAA), - resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{cnameRR(curr, next)}, - }, - }, - ) - } - - r := newResolver(t) - r.testExchangeHook = mock.exchangeHook - r.rootServers = []netip.Addr{rootServerAddr} - - // Query for the first node in the chain, 0-tailscale.com, and verify - // we get a max-depth error. - ctx := context.Background() - _, _, err := r.Resolve(ctx, "0-tailscale.com") - if err == nil { - t.Fatal("expected error, got nil") - } else if !errors.Is(err, ErrMaxDepth) { - t.Fatalf("got err=%v, want ErrMaxDepth", err) - } -} - -func TestInvalidResponses(t *testing.T) { - mock := &replyMock{ - tb: t, - replies: map[netip.Addr][]mockReply{ - // Query to the root server returns the .com server + a glue record - rootServerAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: comRecord}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: comRecord}, - }, - - // Query to the ".com" server return the nameservers for tailscale.com - comNSAddr: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: tailscaleNameservers}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: tailscaleNameservers}, - }, - - // Query to the actual nameserver returns an invalid IP address - amazonNS: { - {name: "tailscale.com.", qtype: dns.Type(dns.TypeA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - Answer: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{ - Name: "tailscale.com.", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - // Note: this is an IPv6 addr in an IPv4 response - A: net.IP(netip.MustParseAddr("2600:9000:a51d:27c1:1530:b9ef:2a6:b9e5").AsSlice()), - }}, - }}, - {name: "tailscale.com.", qtype: dns.Type(dns.TypeAAAA), resp: &dns.Msg{ - MsgHdr: dns.MsgHdr{Authoritative: true}, - // This an IPv4 response to an IPv6 query - Answer: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{ - Name: "tailscale.com.", - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: net.IP(netip.MustParseAddr("13.248.141.131").AsSlice()), - }}, - }}, - }, - }, - } - - r := &Resolver{ - Logf: t.Logf, - testExchangeHook: mock.exchangeHook, - rootServers: []netip.Addr{rootServerAddr}, - } - - // Query for tailscale.com, verify we get no responses since the - // addresses are invalid. - _, _, err := r.Resolve(context.Background(), "tailscale.com") - if err == nil { - t.Fatalf("got no error, want error") - } - if !errors.Is(err, ErrAuthoritativeNoResponses) { - t.Fatalf("got err=%v, want %v", err, ErrAuthoritativeNoResponses) - } -} - -// TODO(andrew): test for more edge cases that aren't currently covered: -// * Nameservers that cross between IPv4 and IPv6 -// * Authoritative no replies after following CNAME -// * Authoritative no replies after following non-glue NS record -// * Error querying non-glue NS record followed by success diff --git a/net/dns/resolvd.go b/net/dns/resolvd.go index 9b067eb07..ad1a99c11 100644 --- a/net/dns/resolvd.go +++ b/net/dns/resolvd.go @@ -57,6 +57,7 @@ func (m *resolvdManager) SetDNS(config OSConfig) error { if len(newSearch) > 1 { newResolvConf = append(newResolvConf, []byte(strings.Join(newSearch, " "))...) + newResolvConf = append(newResolvConf, '\n') } err = m.fs.WriteFile(resolvConf, newResolvConf, 0644) @@ -123,6 +124,6 @@ func (m resolvdManager) readResolvConf() (config OSConfig, err error) { } func removeSearchLines(orig []byte) []byte { - re := regexp.MustCompile(`(?m)^search\s+.+$`) + re := regexp.MustCompile(`(?ms)^search\s+.+$`) return re.ReplaceAll(orig, []byte("")) } diff --git a/net/dns/resolved.go b/net/dns/resolved.go index d82d3fc31..d8f63c9d6 100644 --- a/net/dns/resolved.go +++ b/net/dns/resolved.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android && !ts_omit_resolved package dns @@ -15,8 +15,8 @@ import ( "github.com/godbus/dbus/v5" "golang.org/x/sys/unix" "tailscale.com/health" - "tailscale.com/logtail/backoff" "tailscale.com/types/logger" + "tailscale.com/util/backoff" "tailscale.com/util/dnsname" ) @@ -70,7 +70,11 @@ type resolvedManager struct { configCR chan changeRequest // tracks OSConfigs changes and error responses } -func newResolvedManager(logf logger.Logf, health *health.Tracker, interfaceName string) (*resolvedManager, error) { +func init() { + optNewResolvedManager.Set(newResolvedManager) +} + +func newResolvedManager(logf logger.Logf, health *health.Tracker, interfaceName string) (OSConfigurator, error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { return nil, err @@ -163,9 +167,9 @@ func (m *resolvedManager) run(ctx context.Context) { } conn.Signal(signals) - // Reset backoff and SetNSOSHealth after successful on reconnect. + // Reset backoff and set osConfigurationSetWarnable to healthy after a successful reconnect. bo.BackOff(ctx, nil) - m.health.SetDNSOSHealth(nil) + m.health.SetHealthy(osConfigurationSetWarnable) return nil } @@ -243,9 +247,12 @@ func (m *resolvedManager) run(ctx context.Context) { // Set health while holding the lock, because this will // graciously serialize the resync's health outcome with a // concurrent SetDNS call. - m.health.SetDNSOSHealth(err) + if err != nil { m.logf("failed to configure systemd-resolved: %v", err) + m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()}) + } else { + m.health.SetHealthy(osConfigurationSetWarnable) } } } diff --git a/net/dns/resolver/debug.go b/net/dns/resolver/debug.go index da195d49d..a41462e18 100644 --- a/net/dns/resolver/debug.go +++ b/net/dns/resolver/debug.go @@ -8,14 +8,18 @@ import ( "html" "net/http" "strconv" - "sync" "sync/atomic" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" + "tailscale.com/syncs" ) func init() { + if !buildfeatures.HasDNS { + return + } health.RegisterDebugHandler("dnsfwd", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { n, _ := strconv.Atoi(r.FormValue("n")) if n <= 0 { @@ -35,7 +39,7 @@ func init() { var fwdLogAtomic atomic.Pointer[fwdLog] type fwdLog struct { - mu sync.Mutex + mu syncs.Mutex pos int // ent[pos] is next entry ent []fwdLogEntry } diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index 846ca3d5e..5adc43efc 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -17,6 +17,7 @@ import ( "net/http" "net/netip" "net/url" + "runtime" "sort" "strings" "sync" @@ -26,13 +27,17 @@ import ( dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/control/controlknobs" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/dns/publicdns" "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tsdial" + "tailscale.com/syncs" "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/types/nettype" @@ -215,18 +220,19 @@ type resolverAndDelay struct { // forwarder forwards DNS packets to a number of upstream nameservers. type forwarder struct { - logf logger.Logf - netMon *netmon.Monitor // always non-nil - linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it - dialer *tsdial.Dialer - health *health.Tracker // always non-nil + logf logger.Logf + netMon *netmon.Monitor // always non-nil + linkSel ForwardLinkSelector // TODO(bradfitz): remove this when tsdial.Dialer absorbs it + dialer *tsdial.Dialer + health *health.Tracker // always non-nil + verboseFwd bool // if true, log all DNS forwarding controlKnobs *controlknobs.Knobs // or nil ctx context.Context // good until Close ctxCancel context.CancelFunc // closes ctx - mu sync.Mutex // guards following + mu syncs.Mutex // guards following dohClient map[string]*http.Client // urlBase -> client @@ -243,26 +249,23 @@ type forwarder struct { // /etc/resolv.conf is missing/corrupt, and the peerapi ExitDNS stub // resolver lookup. cloudHostFallback []resolverAndDelay - - // missingUpstreamRecovery, if non-nil, is set called when a SERVFAIL is - // returned due to missing upstream resolvers. - // - // This should attempt to properly (re)set the upstream resolvers. - missingUpstreamRecovery func() } func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, health *health.Tracker, knobs *controlknobs.Knobs) *forwarder { + if !buildfeatures.HasDNS { + return nil + } if netMon == nil { panic("nil netMon") } f := &forwarder{ - logf: logger.WithPrefix(logf, "forward: "), - netMon: netMon, - linkSel: linkSel, - dialer: dialer, - health: health, - controlKnobs: knobs, - missingUpstreamRecovery: func() {}, + logf: logger.WithPrefix(logf, "forward: "), + netMon: netMon, + linkSel: linkSel, + dialer: dialer, + health: health, + controlKnobs: knobs, + verboseFwd: verboseDNSForward(), } f.ctx, f.ctxCancel = context.WithCancel(context.Background()) return f @@ -487,6 +490,10 @@ func (f *forwarder) sendDoH(ctx context.Context, urlBase string, c *http.Client, defer hres.Body.Close() if hres.StatusCode != 200 { metricDNSFwdDoHErrorStatus.Add(1) + if hres.StatusCode/100 == 5 { + // Translate 5xx HTTP server errors into SERVFAIL DNS responses. + return nil, fmt.Errorf("%w: %s", errServerFailure, hres.Status) + } return nil, errors.New(hres.Status) } if ct := hres.Header.Get("Content-Type"); ct != dohType { @@ -516,7 +523,7 @@ var ( // // send expects the reply to have the same txid as txidOut. func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDelay) (ret []byte, err error) { - if verboseDNSForward() { + if f.verboseFwd { id := forwarderCount.Add(1) domain, typ, _ := nameFromQuery(fq.packet) f.logf("forwarder.send(%q, %d, %v, %d) [%d] ...", rr.name.Addr, fq.txid, typ, len(domain), id) @@ -525,6 +532,9 @@ func (f *forwarder) send(ctx context.Context, fq *forwardQuery, rr resolverAndDe }() } if strings.HasPrefix(rr.name.Addr, "http://") { + if !buildfeatures.HasPeerAPIClient { + return nil, feature.ErrUnavailable + } return f.sendDoH(ctx, rr.name.Addr, f.dialer.PeerAPIHTTPClient(), fq.packet) } if strings.HasPrefix(rr.name.Addr, "https://") { @@ -735,18 +745,38 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } -func (f *forwarder) getDialerType() dnscache.DialContextFunc { - if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() { - // It is safe to use UserDial as it dials external servers without going through Tailscale - // and closes connections on interface change in the same way as SystemDial does, - // thus preventing DNS resolution issues when switching between WiFi and cellular, - // but can also dial an internal DNS server on the Tailnet or via a subnet router. - // - // TODO(nickkhyl): Update tsdial.Dialer to reuse the bart.Table we create in net/tstun.Wrapper - // to avoid having two bart tables in memory, especially on iOS. Once that's done, - // we can get rid of the nodeAttr/control knob and always use UserDial for DNS. - // - // See https://github.com/tailscale/tailscale/issues/12027. +var optDNSForwardUseRoutes = envknob.RegisterOptBool("TS_DEBUG_DNS_FORWARD_USE_ROUTES") + +// ShouldUseRoutes reports whether the DNS resolver should consider routes when dialing +// upstream nameservers via TCP. +// +// If true, routes should be considered ([tsdial.Dialer.UserDial]), otherwise defer +// to the system routes ([tsdial.Dialer.SystemDial]). +// +// TODO(nickkhyl): Update [tsdial.Dialer] to reuse the bart.Table we create in net/tstun.Wrapper +// to avoid having two bart tables in memory, especially on iOS. Once that's done, +// we can get rid of the nodeAttr/control knob and always use UserDial for DNS. +// +// See tailscale/tailscale#12027. +func ShouldUseRoutes(knobs *controlknobs.Knobs) bool { + if !buildfeatures.HasDNS { + return false + } + switch runtime.GOOS { + case "android", "ios": + // On mobile platforms with lower memory limits (e.g., 50MB on iOS), + // this behavior is still gated by the "user-dial-routes" nodeAttr. + return knobs != nil && knobs.UserDialUseRoutes.Load() + default: + // On all other platforms, it is the default behavior, + // but it can be overridden with the "TS_DEBUG_DNS_FORWARD_USE_ROUTES" env var. + doNotUseRoutes := optDNSForwardUseRoutes().EqualBool(false) + return !doNotUseRoutes + } +} + +func (f *forwarder) getDialerType() netx.DialFunc { + if ShouldUseRoutes(f.controlKnobs) { return f.dialer.UserDial } return f.dialer.SystemDial @@ -916,10 +946,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo metricDNSFwdDropBonjour.Add(1) res, err := nxDomainResponse(query) if err != nil { - f.logf("error parsing bonjour query: %v", err) - // Returning an error will cause an internal retry, there is - // nothing we can do if parsing failed. Just drop the packet. - return nil + return err } select { case <-ctx.Done(): @@ -942,19 +969,9 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: ""}) f.logf("no upstream resolvers set, returning SERVFAIL") - // Attempt to recompile the DNS configuration - // If we are being asked to forward queries and we have no - // nameservers, the network is in a bad state. - if f.missingUpstreamRecovery != nil { - f.missingUpstreamRecovery() - } - res, err := servfailResponse(query) if err != nil { - f.logf("building servfail response: %v", err) - // Returning an error will cause an internal retry, there is - // nothing we can do if parsing failed. Just drop the packet. - return nil + return err } select { case <-ctx.Done(): @@ -975,7 +992,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } defer fq.closeOnCtxDone.Close() - if verboseDNSForward() { + if f.verboseFwd { domainSha256 := sha256.Sum256([]byte(domain)) domainSig := base64.RawStdEncoding.EncodeToString(domainSha256[:3]) f.logf("request(%d, %v, %d, %s) %d...", fq.txid, typ, len(domain), domainSig, len(fq.packet)) @@ -1020,7 +1037,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo metricDNSFwdErrorContext.Add(1) return fmt.Errorf("waiting to send response: %w", ctx.Err()) case responseChan <- packet{v, query.family, query.addr}: - if verboseDNSForward() { + if f.verboseFwd { f.logf("response(%d, %v, %d) = %d, nil", fq.txid, typ, len(domain), len(v)) } metricDNSFwdSuccess.Add(1) @@ -1050,9 +1067,10 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo } f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) case responseChan <- res: - if verboseDNSForward() { + if f.verboseFwd { f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) } + return nil } } return firstErr diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 09d810901..ec491c581 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -7,13 +7,11 @@ import ( "bytes" "context" "encoding/binary" - "errors" "flag" "fmt" "io" "net" "net/netip" - "os" "reflect" "slices" "strings" @@ -24,11 +22,12 @@ import ( dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/control/controlknobs" - "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" + "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/util/eventbus/eventbustest" ) func (rr resolverAndDelay) String() string { @@ -123,7 +122,6 @@ func TestResolversWithDelays(t *testing.T) { } }) } - } func TestGetRCode(t *testing.T) { @@ -277,6 +275,8 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on tb.Fatal("cannot skip both UDP and TCP servers") } + logf := tstest.WhileTestRunningLogger(tb) + tcpResponse := make([]byte, len(response)+2) binary.BigEndian.PutUint16(tcpResponse, uint16(len(response))) copy(tcpResponse[2:], response) @@ -330,13 +330,13 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on // Read the length header, then the buffer var length uint16 if err := binary.Read(conn, binary.BigEndian, &length); err != nil { - tb.Logf("error reading length header: %v", err) + logf("error reading length header: %v", err) return } req := make([]byte, length) n, err := io.ReadFull(conn, req) if err != nil { - tb.Logf("error reading query: %v", err) + logf("error reading query: %v", err) return } req = req[:n] @@ -344,7 +344,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on // Write response if _, err := conn.Write(tcpResponse); err != nil { - tb.Logf("error writing response: %v", err) + logf("error writing response: %v", err) return } } @@ -368,7 +368,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on handleUDP := func(addr netip.AddrPort, req []byte) { onRequest(false, req) if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil { - tb.Logf("error writing response: %v", err) + logf("error writing response: %v", err) } } @@ -391,19 +391,12 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on tb.Cleanup(func() { tcpLn.Close() udpLn.Close() - tb.Logf("waiting for listeners to finish...") + logf("waiting for listeners to finish...") wg.Wait() }) return } -func enableDebug(tb testing.TB) { - const debugKnob = "TS_DEBUG_DNS_FORWARD_SEND" - oldVal := os.Getenv(debugKnob) - envknob.Setenv(debugKnob, "true") - tb.Cleanup(func() { envknob.Setenv(debugKnob, oldVal) }) -} - func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) { name := dns.MustNewName(domain) @@ -450,22 +443,26 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) return } -func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) ([]byte, error) { - netMon, err := netmon.New(tb.Logf) +func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { + logf := tstest.WhileTestRunningLogger(tb) + bus := eventbustest.NewBus(tb) + netMon, err := netmon.New(bus, logf) if err != nil { tb.Fatal(err) } var dialer tsdial.Dialer dialer.SetNetMon(netMon) + dialer.SetBus(bus) - fwd := newForwarder(tb.Logf, netMon, nil, &dialer, new(health.Tracker), nil) + fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil) if modify != nil { modify(fwd) } - rr := resolverAndDelay{ - name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)}, + resolvers := make([]resolverAndDelay, len(ports)) + for i, port := range ports { + resolvers[i].name = &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)} } rpkt := packet{ @@ -477,7 +474,7 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa rchan := make(chan packet, 1) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) tb.Cleanup(cancel) - err = fwd.forwardWithDestChan(ctx, rpkt, rchan, rr) + err = fwd.forwardWithDestChan(ctx, rpkt, rchan, resolvers...) select { case res := <-rchan: return res.bs, err @@ -486,17 +483,73 @@ func runTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwa } } -func mustRunTestQuery(tb testing.TB, port uint16, request []byte, modify func(*forwarder)) []byte { - resp, err := runTestQuery(tb, port, request, modify) +// makeTestRequest returns a new TypeA request for the given domain. +func makeTestRequest(tb testing.TB, domain string) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{}) + builder.StartQuestions() + builder.Question(dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + }) + request, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return request +} + +// makeTestResponse returns a new Type A response for the given domain, +// with the specified status code and zero or more addresses. +func makeTestResponse(tb testing.TB, domain string, code dns.RCode, addrs ...netip.Addr) []byte { + tb.Helper() + name := dns.MustNewName(domain) + builder := dns.NewBuilder(nil, dns.Header{ + Response: true, + Authoritative: true, + RCode: code, + }) + builder.StartQuestions() + q := dns.Question{ + Name: name, + Type: dns.TypeA, + Class: dns.ClassINET, + } + builder.Question(q) + if len(addrs) > 0 { + builder.StartAnswers() + for _, addr := range addrs { + builder.AResource(dns.ResourceHeader{ + Name: q.Name, + Class: q.Class, + TTL: 120, + }, dns.AResource{ + A: addr.As4(), + }) + } + } + response, err := builder.Finish() + if err != nil { + tb.Fatal(err) + } + return response +} + +func mustRunTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) []byte { + resp, err := runTestQuery(tb, request, modify, ports...) if err != nil { tb.Fatalf("error making request: %v", err) } return resp } -func TestForwarderTCPFallback(t *testing.T) { - enableDebug(t) +func beVerbose(f *forwarder) { + f.verboseFwd = true +} +func TestForwarderTCPFallback(t *testing.T) { const domain = "large-dns-response.tailscale.com." // Make a response that's very large, containing a bunch of localhost addresses. @@ -516,7 +569,7 @@ func TestForwarderTCPFallback(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, beVerbose, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -532,8 +585,6 @@ func TestForwarderTCPFallback(t *testing.T) { // Test to ensure that if the UDP listener is unresponsive, we always make a // TCP request even if we never get a response. func TestForwarderTCPFallbackTimeout(t *testing.T) { - enableDebug(t) - const domain = "large-dns-response.tailscale.com." // Make a response that's very large, containing a bunch of localhost addresses. @@ -554,7 +605,7 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, nil) + resp := mustRunTestQuery(t, request, beVerbose, port) if !bytes.Equal(resp, largeResponse) { t.Errorf("invalid response\ngot: %+v\nwant: %+v", resp, largeResponse) } @@ -564,8 +615,6 @@ func TestForwarderTCPFallbackTimeout(t *testing.T) { } func TestForwarderTCPFallbackDisabled(t *testing.T) { - enableDebug(t) - const domain = "large-dns-response.tailscale.com." // Make a response that's very large, containing a bunch of localhost addresses. @@ -585,11 +634,12 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { } }) - resp := mustRunTestQuery(t, port, request, func(fwd *forwarder) { + resp := mustRunTestQuery(t, request, func(fwd *forwarder) { + fwd.verboseFwd = true // Disable retries for this test. fwd.controlKnobs = &controlknobs.Knobs{} fwd.controlKnobs.DisableDNSForwarderTCPRetries.Store(true) - }) + }, port) wantResp := append([]byte(nil), largeResponse[:maxResponseBytes]...) @@ -608,46 +658,13 @@ func TestForwarderTCPFallbackDisabled(t *testing.T) { // Test to ensure that we propagate DNS errors func TestForwarderTCPFallbackError(t *testing.T) { - enableDebug(t) - const domain = "error-response.tailscale.com." // Our response is a SERVFAIL - response := func() []byte { - name := dns.MustNewName(domain) - - builder := dns.NewBuilder(nil, dns.Header{ - Response: true, - RCode: dns.RCodeServerFailure, - }) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: name, - Type: dns.TypeA, - Class: dns.ClassINET, - }) - response, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return response - }() + response := makeTestResponse(t, domain, dns.RCodeServerFailure) // Our request is a single A query for the domain in the answer, above. - request := func() []byte { - builder := dns.NewBuilder(nil, dns.Header{}) - builder.StartQuestions() - builder.Question(dns.Question{ - Name: dns.MustNewName(domain), - Type: dns.TypeA, - Class: dns.ClassINET, - }) - request, err := builder.Finish() - if err != nil { - t.Fatal(err) - } - return request - }() + request := makeTestRequest(t, domain) var sawRequest atomic.Bool port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { @@ -657,14 +674,139 @@ func TestForwarderTCPFallbackError(t *testing.T) { } }) - _, err := runTestQuery(t, port, request, nil) + resp, err := runTestQuery(t, request, beVerbose, port) if !sawRequest.Load() { t.Error("did not see DNS request") } - if err == nil { - t.Error("wanted error, got nil") - } else if !errors.Is(err, errServerFailure) { - t.Errorf("wanted errServerFailure, got: %v", err) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + var parser dns.Parser + respHeader, err := parser.Start(resp) + if err != nil { + t.Fatalf("parser.Start() failed: %v", err) + } + if got, want := respHeader.RCode, dns.RCodeServerFailure; got != want { + t.Errorf("wanted %v, got %v", want, got) + } +} + +// Test to ensure that if we have more than one resolver, and at least one of them +// returns a successful response, we propagate it. +func TestForwarderWithManyResolvers(t *testing.T) { + const domain = "example.com." + request := makeTestRequest(t, domain) + + tests := []struct { + name string + responses [][]byte // upstream responses + wantResponses [][]byte // we should receive one of these from the forwarder + }{ + { + name: "Success", + responses: [][]byte{ // All upstream servers returned successful, but different, response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + wantResponses: [][]byte{ // We may forward whichever response is received first. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.2")), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.3")), + }, + }, + { + name: "ServFail", + responses: [][]byte{ // All upstream servers returned a SERVFAIL. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, + { + name: "ServFail+Success", + responses: [][]byte{ // All upstream servers fail except for one. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + wantResponses: [][]byte{ // We should forward the successful response. + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "NXDomain", + responses: [][]byte{ // All upstream servers returned NXDOMAIN. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + }, + wantResponses: [][]byte{ + makeTestResponse(t, domain, dns.RCodeNameError), + }, + }, + { + name: "NXDomain+Success", + responses: [][]byte{ // All upstream servers returned NXDOMAIN except for one. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // However, only SERVFAIL are considered to be errors. Therefore, we may forward any response. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "Refused", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), + }, + }, + { + name: "MixFail", + responses: [][]byte{ // All upstream servers return different failures. + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + makeTestResponse(t, domain, dns.RCodeNameError), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ports := make([]uint16, len(tt.responses)) + for i := range tt.responses { + ports[i] = runDNSServer(t, nil, tt.responses[i], func(isTCP bool, gotRequest []byte) {}) + } + gotResponse, err := runTestQuery(t, request, beVerbose, ports...) + if err != nil { + t.Fatalf("wanted nil, got %v", err) + } + responseOk := slices.ContainsFunc(tt.wantResponses, func(wantResponse []byte) bool { + return slices.Equal(gotResponse, wantResponse) + }) + if !responseOk { + t.Errorf("invalid response\ngot: %+v\nwant: %+v", gotResponse, tt.wantResponses[0]) + } + }) } } @@ -713,7 +855,7 @@ func TestNXDOMAINIncludesQuestion(t *testing.T) { port := runDNSServer(t, nil, response, func(isTCP bool, gotRequest []byte) { }) - res, err := runTestQuery(t, port, request, nil) + res, err := runTestQuery(t, request, beVerbose, port) if err != nil { t.Fatal(err) } diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index d196ad4d6..3185cbe2b 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -25,6 +25,8 @@ import ( dns "golang.org/x/net/dns/dnsmessage" "tailscale.com/control/controlknobs" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/dns/resolvconffile" "tailscale.com/net/netaddr" @@ -212,7 +214,7 @@ type Resolver struct { closed chan struct{} // mu guards the following fields from being updated while used. - mu sync.Mutex + mu syncs.Mutex localDomains []dnsname.FQDN hostToIP map[dnsname.FQDN][]netip.Addr ipToHost map[netip.Addr]dnsname.FQDN @@ -251,18 +253,12 @@ func New(logf logger.Logf, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, h return r } -// SetMissingUpstreamRecovery sets a callback to be called upon encountering -// a SERVFAIL due to missing upstream resolvers. -// -// This call should only happen before the resolver is used. It is not safe -// for concurrent use. -func (r *Resolver) SetMissingUpstreamRecovery(f func()) { - r.forwarder.missingUpstreamRecovery = f -} - func (r *Resolver) TestOnlySetHook(hook func(Config)) { r.saveConfigForTests = hook } func (r *Resolver) SetConfig(cfg Config) error { + if !buildfeatures.HasDNS { + return nil + } if r.saveConfigForTests != nil { r.saveConfigForTests(cfg) } @@ -288,6 +284,9 @@ func (r *Resolver) SetConfig(cfg Config) error { // Close shuts down the resolver and ensures poll goroutines have exited. // The Resolver cannot be used again after Close is called. func (r *Resolver) Close() { + if !buildfeatures.HasDNS { + return + } select { case <-r.closed: return @@ -305,6 +304,9 @@ func (r *Resolver) Close() { const dnsQueryTimeout = 10 * time.Second func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from netip.AddrPort) ([]byte, error) { + if !buildfeatures.HasDNS { + return nil, feature.ErrUnavailable + } metricDNSQueryLocal.Add(1) select { case <-r.closed: @@ -321,15 +323,7 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net defer cancel() err = r.forwarder.forwardWithDestChan(ctx, packet{bs, family, from}, responses) if err != nil { - select { - // Best effort: use any error response sent by forwardWithDestChan. - // This is present in some errors paths, such as when all upstream - // DNS servers replied with an error. - case resp := <-responses: - return resp.bs, err - default: - return nil, err - } + return nil, err } return (<-responses).bs, nil } @@ -340,6 +334,9 @@ func (r *Resolver) Query(ctx context.Context, bs []byte, family string, from net // GetUpstreamResolvers returns the resolvers that would be used to resolve // the given FQDN. func (r *Resolver) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver { + if !buildfeatures.HasDNS { + return nil + } return r.forwarder.GetUpstreamResolvers(name) } @@ -368,6 +365,9 @@ func parseExitNodeQuery(q []byte) *response { // and a nil error. // TODO: figure out if we even need an error result. func (r *Resolver) HandlePeerDNSQuery(ctx context.Context, q []byte, from netip.AddrPort, allowName func(name string) bool) (res []byte, err error) { + if !buildfeatures.HasDNS { + return nil, feature.ErrUnavailable + } metricDNSExitProxyQuery.Add(1) ch := make(chan packet, 1) @@ -392,7 +392,7 @@ func (r *Resolver) HandlePeerDNSQuery(ctx context.Context, q []byte, from netip. // but for now that's probably good enough. Later we'll // want to blend in everything from scutil --dns. fallthrough - case "linux", "freebsd", "openbsd", "illumos", "ios": + case "linux", "freebsd", "openbsd", "illumos", "solaris", "ios": nameserver, err := stubResolverForOS() if err != nil { r.logf("stubResolverForOS: %v", err) @@ -444,6 +444,9 @@ var debugExitNodeDNSNetPkg = envknob.RegisterBool("TS_DEBUG_EXIT_NODE_DNS_NET_PK // response contains the pre-serialized response, which notably // includes the original question and its header. func handleExitNodeDNSQueryWithNetPkg(ctx context.Context, logf logger.Logf, resolver *net.Resolver, resp *response) (res []byte, err error) { + if !buildfeatures.HasDNS { + return nil, feature.ErrUnavailable + } logf = logger.WithPrefix(logf, "exitNodeDNSQueryWithNetPkg: ") if resp.Question.Class != dns.ClassINET { return nil, errors.New("unsupported class") @@ -1264,6 +1267,9 @@ func (r *Resolver) respondReverse(query []byte, name dnsname.FQDN, resp *respons // respond returns a DNS response to query if it can be resolved locally. // Otherwise, it returns errNotOurName. func (r *Resolver) respond(query []byte) ([]byte, error) { + if !buildfeatures.HasDNS { + return nil, feature.ErrUnavailable + } parser := dnsParserPool.Get().(*dnsParser) defer dnsParserPool.Put(parser) diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index e2c4750b5..f0dbb48b3 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -31,6 +31,7 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus/eventbustest" ) var ( @@ -352,10 +353,13 @@ func TestRDNSNameToIPv6(t *testing.T) { } func newResolver(t testing.TB) *Resolver { + bus := eventbustest.NewBus(t) + dialer := tsdial.NewDialer(netmon.NewStatic()) + dialer.SetBus(bus) return New(t.Logf, nil, // no link selector - tsdial.NewDialer(netmon.NewStatic()), - new(health.Tracker), + dialer, + health.NewTracker(bus), nil, // no control knobs ) } @@ -1059,7 +1063,9 @@ func TestForwardLinkSelection(t *testing.T) { // routes differently. specialIP := netaddr.IPv4(1, 2, 3, 4) - netMon, err := netmon.New(logger.WithPrefix(t.Logf, ".... netmon: ")) + bus := eventbustest.NewBus(t) + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, ".... netmon: ")) if err != nil { t.Fatal(err) } @@ -1070,7 +1076,7 @@ func TestForwardLinkSelection(t *testing.T) { return "special" } return "" - }), new(tsdial.Dialer), new(health.Tracker), nil /* no control knobs */) + }), new(tsdial.Dialer), health.NewTracker(bus), nil /* no control knobs */) // Test non-special IP. if got, err := fwd.packetListener(netip.Addr{}); err != nil { @@ -1102,10 +1108,6 @@ type linkSelFunc func(ip netip.Addr) string func (f linkSelFunc) PickLink(ip netip.Addr) string { return f(ip) } func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("skipping test on Windows; waiting for golang.org/issue/33097") - } - records := []any{ "no-records.test.", dnsHandler(), @@ -1401,9 +1403,6 @@ func TestHandleExitNodeDNSQueryWithNetPkg(t *testing.T) { // newWrapResolver returns a resolver that uses r (via handleExitNodeDNSQueryWithNetPkg) // to make DNS requests. func newWrapResolver(r *net.Resolver) *net.Resolver { - if runtime.GOOS == "windows" { - panic("doesn't work on Windows") // golang.org/issue/33097 - } return &net.Resolver{ PreferGo: true, Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -1503,8 +1502,8 @@ func TestServfail(t *testing.T) { r.SetConfig(cfg) pkt, err := syncRespond(r, dnspacket("test.site.", dns.TypeA, noEdns)) - if !errors.Is(err, errServerFailure) { - t.Errorf("err = %v, want %v", err, errServerFailure) + if err != nil { + t.Fatalf("err = %v, want nil", err) } wantPkt := []byte{ diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 2cbea6c0f..e222b983f 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -19,10 +19,13 @@ import ( "time" "tailscale.com/envknob" + "tailscale.com/net/netx" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/cloudenv" "tailscale.com/util/singleflight" "tailscale.com/util/slicesx" + "tailscale.com/util/testenv" ) var zaddr netip.Addr @@ -62,6 +65,10 @@ type Resolver struct { // If nil, net.DefaultResolver is used. Forward *net.Resolver + // LookupIPForTest, if non-nil and in tests, handles requests instead + // of the usual mechanisms. + LookupIPForTest func(ctx context.Context, host string) ([]netip.Addr, error) + // LookupIPFallback optionally provides a backup DNS mechanism // to use if Forward returns an error or no results. LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error) @@ -91,7 +98,7 @@ type Resolver struct { sf singleflight.Group[string, ipRes] - mu sync.Mutex + mu syncs.Mutex ipCache map[string]ipCacheEntry } @@ -199,6 +206,9 @@ func (r *Resolver) LookupIP(ctx context.Context, host string) (ip, v6 netip.Addr } allIPs = append(allIPs, naIP) } + if !ip.IsValid() && v6.IsValid() { + ip = v6 + } r.dlogf("returning %d static results", len(allIPs)) return } @@ -283,7 +293,13 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Add lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host)) defer lookupCancel() - ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host) + + var ips []netip.Addr + if r.LookupIPForTest != nil && testenv.InTest() { + ips, err = r.LookupIPForTest(ctx, host) + } else { + ips, err = r.fwd().LookupNetIP(lookupCtx, "ip", host) + } if err != nil || len(ips) == 0 { if resolver, ok := r.cloudHostResolver(); ok { r.dlogf("resolving %q via cloud resolver", host) @@ -355,10 +371,8 @@ func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Ad } } -type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) - // Dialer returns a wrapped DialContext func that uses the provided dnsCache. -func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { +func Dialer(fwd netx.DialFunc, dnsCache *Resolver) netx.DialFunc { d := &dialer{ fwd: fwd, dnsCache: dnsCache, @@ -369,7 +383,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { // dialer is the config and accumulated state for a dial func returned by Dialer. type dialer struct { - fwd DialContextFunc + fwd netx.DialFunc dnsCache *Resolver mu sync.Mutex @@ -461,7 +475,7 @@ type dialCall struct { d *dialer network, address, host, port string - mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu + mu syncs.Mutex // lock ordering: dialer.mu, then dialCall.mu fails map[netip.Addr]error // set of IPs that failed to dial thus far } @@ -653,7 +667,7 @@ func v6addrs(aa []netip.Addr) (ret []netip.Addr) { // TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext. // It returns a *tls.Conn type on success. // On TLS cert validation failure, it can invoke a backup DNS resolution strategy. -func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc { +func TLSDialer(fwd netx.DialFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) netx.DialFunc { tcpDialer := Dialer(fwd, dnsCache) return func(ctx context.Context, network, address string) (net.Conn, error) { host, _, err := net.SplitHostPort(address) diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index ef4249b74..58bb6cd7f 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -11,6 +11,7 @@ import ( "net" "net/netip" "reflect" + "slices" "testing" "time" @@ -240,3 +241,60 @@ func TestShouldTryBootstrap(t *testing.T) { }) } } + +func TestSingleHostStaticResult(t *testing.T) { + v4 := netip.MustParseAddr("0.0.0.1") + v6 := netip.MustParseAddr("2001::a") + + tests := []struct { + name string + static []netip.Addr + wantIP netip.Addr + wantIP6 netip.Addr + wantAll []netip.Addr + }{ + { + name: "just-v6", + static: []netip.Addr{v6}, + wantIP: v6, + wantIP6: v6, + wantAll: []netip.Addr{v6}, + }, + { + name: "just-v4", + static: []netip.Addr{v4}, + wantIP: v4, + wantIP6: netip.Addr{}, + wantAll: []netip.Addr{v4}, + }, + { + name: "v6-then-v4", + static: []netip.Addr{v6, v4}, + wantIP: v4, + wantIP6: v6, + wantAll: []netip.Addr{v6, v4}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := &Resolver{ + SingleHost: "example.com", + SingleHostStaticResult: tt.static, + } + ip, ip6, all, err := r.LookupIP(context.Background(), "example.com") + if err != nil { + t.Fatal(err) + } + if ip != tt.wantIP { + t.Errorf("got ip %v; want %v", ip, tt.wantIP) + } + if ip6 != tt.wantIP6 { + t.Errorf("got ip6 %v; want %v", ip6, tt.wantIP6) + } + if !slices.Equal(all, tt.wantAll) { + t.Errorf("got all %v; want %v", all, tt.wantAll) + } + }) + } +} diff --git a/net/dnscache/messagecache_test.go b/net/dnscache/messagecache_test.go index 41fc33448..0bedfa5ad 100644 --- a/net/dnscache/messagecache_test.go +++ b/net/dnscache/messagecache_test.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "net" - "runtime" "testing" "time" @@ -249,14 +248,6 @@ func TestGetDNSQueryCacheKey(t *testing.T) { } func getGoNetPacketDNSQuery(name string) []byte { - if runtime.GOOS == "windows" { - // On Windows, Go's net.Resolver doesn't use the DNS client. - // See https://github.com/golang/go/issues/33097 which - // was approved but not yet implemented. - // For now just pretend it's implemented to make this test - // pass on Windows with complicated the caller. - return makeQ(123, name) - } res := make(chan []byte, 1) r := &net.Resolver{ PreferGo: true, diff --git a/net/dnsfallback/dnsfallback.go b/net/dnsfallback/dnsfallback.go index 4c5d5fa2f..74b625970 100644 --- a/net/dnsfallback/dnsfallback.go +++ b/net/dnsfallback/dnsfallback.go @@ -22,35 +22,20 @@ import ( "net/url" "os" "reflect" - "slices" "sync/atomic" "time" "tailscale.com/atomicfile" - "tailscale.com/envknob" + "tailscale.com/feature" "tailscale.com/health" - "tailscale.com/net/dns/recursive" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/tlsdial" - "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" "tailscale.com/types/logger" - "tailscale.com/util/clientmetric" - "tailscale.com/util/singleflight" "tailscale.com/util/slicesx" ) -var ( - optRecursiveResolver = envknob.RegisterOptBool("TS_DNSFALLBACK_RECURSIVE_RESOLVER") - disableRecursiveResolver = envknob.RegisterBool("TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER") // legacy pre-1.52 env knob name -) - -type resolveResult struct { - addrs []netip.Addr - minTTL time.Duration -} - // MakeLookupFunc creates a function that can be used to resolve hostnames // (e.g. as a LookupIPFallback from dnscache.Resolver). // The netMon parameter is optional; if non-nil it's used to do faster interface lookups. @@ -68,145 +53,13 @@ type fallbackResolver struct { logf logger.Logf netMon *netmon.Monitor // or nil healthTracker *health.Tracker // or nil - sf singleflight.Group[string, resolveResult] // for tests waitForCompare bool } func (fr *fallbackResolver) Lookup(ctx context.Context, host string) ([]netip.Addr, error) { - // If they've explicitly disabled the recursive resolver with the legacy - // TS_DNSFALLBACK_DISABLE_RECURSIVE_RESOLVER envknob or not set the - // newer TS_DNSFALLBACK_RECURSIVE_RESOLVER to true, then don't use the - // recursive resolver. (tailscale/corp#15261) In the future, we might - // change the default (the opt.Bool being unset) to mean enabled. - if disableRecursiveResolver() || !optRecursiveResolver().EqualBool(true) { - return lookup(ctx, host, fr.logf, fr.healthTracker, fr.netMon) - } - - addrsCh := make(chan []netip.Addr, 1) - - // Run the recursive resolver in the background so we can - // compare the results. For tests, we also allow waiting for the - // comparison to complete; normally, we do this entirely asynchronously - // so as not to block the caller. - var done chan struct{} - if fr.waitForCompare { - done = make(chan struct{}) - go func() { - defer close(done) - fr.compareWithRecursive(ctx, addrsCh, host) - }() - } else { - go fr.compareWithRecursive(ctx, addrsCh, host) - } - - addrs, err := lookup(ctx, host, fr.logf, fr.healthTracker, fr.netMon) - if err != nil { - addrsCh <- nil - return nil, err - } - - addrsCh <- slices.Clone(addrs) - if fr.waitForCompare { - select { - case <-done: - case <-ctx.Done(): - } - } - return addrs, nil -} - -// compareWithRecursive is responsible for comparing the DNS resolution -// performed via the "normal" path (bootstrap DNS requests to the DERP servers) -// with DNS resolution performed with our in-process recursive DNS resolver. -// -// It will select on addrsCh to read exactly one set of addrs (returned by the -// "normal" path) and compare against the results returned by the recursive -// resolver. If ctx is canceled, then it will abort. -func (fr *fallbackResolver) compareWithRecursive( - ctx context.Context, - addrsCh <-chan []netip.Addr, - host string, -) { - logf := logger.WithPrefix(fr.logf, "recursive: ") - - // Ensure that we catch panics while we're testing this - // code path; this should never panic, but we don't - // want to take down the process by having the panic - // propagate to the top of the goroutine's stack and - // then terminate. - defer func() { - if r := recover(); r != nil { - logf("bootstrap DNS: recovered panic: %v", r) - metricRecursiveErrors.Add(1) - } - }() - - // Don't resolve the same host multiple times - // concurrently; if we end up in a tight loop, this can - // take up a lot of CPU. - var didRun bool - result, err, _ := fr.sf.Do(host, func() (resolveResult, error) { - didRun = true - resolver := &recursive.Resolver{ - Dialer: netns.NewDialer(logf, fr.netMon), - Logf: logf, - } - addrs, minTTL, err := resolver.Resolve(ctx, host) - if err != nil { - logf("error using recursive resolver: %v", err) - metricRecursiveErrors.Add(1) - return resolveResult{}, err - } - return resolveResult{addrs, minTTL}, nil - }) - - // The singleflight function handled errors; return if - // there was one. Additionally, don't bother doing the - // comparison if we waited on another singleflight - // caller; the results are likely to be the same, so - // rather than spam the logs we can just exit and let - // the singleflight call that did execute do the - // comparison. - // - // Returning here is safe because the addrsCh channel - // is buffered, so the main function won't block even - // if we never read from it. - if err != nil || !didRun { - return - } - - addrs, minTTL := result.addrs, result.minTTL - compareAddr := func(a, b netip.Addr) int { return a.Compare(b) } - slices.SortFunc(addrs, compareAddr) - - // Wait for a response from the main function; try this once before we - // check whether the context is canceled since selects are - // nondeterministic. - var oldAddrs []netip.Addr - select { - case oldAddrs = <-addrsCh: - // All good; continue - default: - // Now block. - select { - case oldAddrs = <-addrsCh: - case <-ctx.Done(): - return - } - } - slices.SortFunc(oldAddrs, compareAddr) - - matches := slices.Equal(addrs, oldAddrs) - - logf("bootstrap DNS comparison: matches=%v oldAddrs=%v addrs=%v minTTL=%v", matches, oldAddrs, addrs, minTTL) - - if matches { - metricRecursiveMatches.Add(1) - } else { - metricRecursiveMismatches.Add(1) - } + return lookup(ctx, host, fr.logf, fr.healthTracker, fr.netMon) } func lookup(ctx context.Context, host string, logf logger.Logf, ht *health.Tracker, netMon *netmon.Monitor) ([]netip.Addr, error) { @@ -282,11 +135,11 @@ func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr dialer := netns.NewDialer(logf, netMon) tr := http.DefaultTransport.(*http.Transport).Clone() tr.DisableKeepAlives = true // This transport is meant to be used once. - tr.Proxy = tshttpproxy.ProxyFromEnvironment + tr.Proxy = feature.HookProxyFromEnvironment.GetOrNil() tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443")) } - tr.TLSClientConfig = tlsdial.Config(serverName, ht, tr.TLSClientConfig) + tr.TLSClientConfig = tlsdial.Config(ht, tr.TLSClientConfig) c := &http.Client{Transport: tr} req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil) if err != nil { @@ -428,9 +281,3 @@ func SetCachePath(path string, logf logger.Logf) { cachedDERPMap.Store(dm) logf("[v2] dnsfallback: SetCachePath loaded cached DERP map") } - -var ( - metricRecursiveMatches = clientmetric.NewCounter("dnsfallback_recursive_matches") - metricRecursiveMismatches = clientmetric.NewCounter("dnsfallback_recursive_mismatches") - metricRecursiveErrors = clientmetric.NewCounter("dnsfallback_recursive_errors") -) diff --git a/net/dnsfallback/dnsfallback_test.go b/net/dnsfallback/dnsfallback_test.go index 16f5027d4..7f8810574 100644 --- a/net/dnsfallback/dnsfallback_test.go +++ b/net/dnsfallback/dnsfallback_test.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/tailcfg" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) func TestGetDERPMap(t *testing.T) { @@ -185,7 +186,10 @@ func TestLookup(t *testing.T) { logf, closeLogf := logger.LogfCloser(t.Logf) defer closeLogf() - netMon, err := netmon.New(logf) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logf) if err != nil { t.Fatal(err) } diff --git a/net/ipset/ipset.go b/net/ipset/ipset.go index 622fd61d0..27c1e27ed 100644 --- a/net/ipset/ipset.go +++ b/net/ipset/ipset.go @@ -82,8 +82,8 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool pathForTest("bart") // Built a bart table. t := &bart.Table[struct{}]{} - for i := range addrs.Len() { - t.Insert(addrs.At(i), struct{}{}) + for _, p := range addrs.All() { + t.Insert(p, struct{}{}) } return bartLookup(t) } @@ -99,8 +99,8 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool // General case: pathForTest("ip-map") m := set.Set[netip.Addr]{} - for i := range addrs.Len() { - m.Add(addrs.At(i).Addr()) + for _, p := range addrs.All() { + m.Add(p.Addr()) } return ipInMap(m) } diff --git a/net/ktimeout/ktimeout_linux_test.go b/net/ktimeout/ktimeout_linux_test.go index a367bfd4a..0330923a9 100644 --- a/net/ktimeout/ktimeout_linux_test.go +++ b/net/ktimeout/ktimeout_linux_test.go @@ -4,21 +4,26 @@ package ktimeout import ( + "context" "net" "testing" "time" - "golang.org/x/net/nettest" "golang.org/x/sys/unix" "tailscale.com/util/must" ) func TestSetUserTimeout(t *testing.T) { - l := must.Get(nettest.NewLocalListener("tcp")) - defer l.Close() + lc := net.ListenConfig{} + // As of 2025-02-19, MPTCP does not support TCP_USER_TIMEOUT socket option + // set in ktimeout.UserTimeout above. + lc.SetMultipathTCP(false) + + ln := must.Get(lc.Listen(context.Background(), "tcp", "localhost:0")) + defer ln.Close() var err error - if e := must.Get(l.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { + if e := must.Get(ln.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { err = SetUserTimeout(fd, 0) }); e != nil { t.Fatal(e) @@ -26,12 +31,12 @@ func TestSetUserTimeout(t *testing.T) { if err != nil { t.Fatal(err) } - v := must.Get(unix.GetsockoptInt(int(must.Get(l.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) + v := must.Get(unix.GetsockoptInt(int(must.Get(ln.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) if v != 0 { t.Errorf("TCP_USER_TIMEOUT: got %v; want 0", v) } - if e := must.Get(l.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { + if e := must.Get(ln.(*net.TCPListener).SyscallConn()).Control(func(fd uintptr) { err = SetUserTimeout(fd, 30*time.Second) }); e != nil { t.Fatal(e) @@ -39,7 +44,7 @@ func TestSetUserTimeout(t *testing.T) { if err != nil { t.Fatal(err) } - v = must.Get(unix.GetsockoptInt(int(must.Get(l.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) + v = must.Get(unix.GetsockoptInt(int(must.Get(ln.(*net.TCPListener).File()).Fd()), unix.SOL_TCP, unix.TCP_USER_TIMEOUT)) if v != 30000 { t.Errorf("TCP_USER_TIMEOUT: got %v; want 30000", v) } diff --git a/net/ktimeout/ktimeout_test.go b/net/ktimeout/ktimeout_test.go index 7befa3b1a..b534f046c 100644 --- a/net/ktimeout/ktimeout_test.go +++ b/net/ktimeout/ktimeout_test.go @@ -14,11 +14,11 @@ func ExampleUserTimeout() { lc := net.ListenConfig{ Control: UserTimeout(30 * time.Second), } - l, err := lc.Listen(context.TODO(), "tcp", "127.0.0.1:0") + ln, err := lc.Listen(context.TODO(), "tcp", "127.0.0.1:0") if err != nil { fmt.Printf("error: %v", err) return } - l.Close() + ln.Close() // Output: } diff --git a/net/memnet/listener.go b/net/memnet/listener.go index d84a2e443..dded97995 100644 --- a/net/memnet/listener.go +++ b/net/memnet/listener.go @@ -22,6 +22,7 @@ type Listener struct { ch chan Conn closeOnce sync.Once closed chan struct{} + onClose func() // or nil // NewConn, if non-nil, is called to create a new pair of connections // when dialing. If nil, NewConn is used. @@ -38,24 +39,29 @@ func Listen(addr string) *Listener { } // Addr implements net.Listener.Addr. -func (l *Listener) Addr() net.Addr { - return l.addr +func (ln *Listener) Addr() net.Addr { + return ln.addr } // Close closes the pipe listener. -func (l *Listener) Close() error { - l.closeOnce.Do(func() { - close(l.closed) +func (ln *Listener) Close() error { + var cleanup func() + ln.closeOnce.Do(func() { + cleanup = ln.onClose + close(ln.closed) }) + if cleanup != nil { + cleanup() + } return nil } // Accept blocks until a new connection is available or the listener is closed. -func (l *Listener) Accept() (net.Conn, error) { +func (ln *Listener) Accept() (net.Conn, error) { select { - case c := <-l.ch: + case c := <-ln.ch: return c, nil - case <-l.closed: + case <-ln.closed: return nil, net.ErrClosed } } @@ -64,18 +70,18 @@ func (l *Listener) Accept() (net.Conn, error) { // The provided Context must be non-nil. If the context expires before the // connection is complete, an error is returned. Once successfully connected // any expiration of the context will not affect the connection. -func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { +func (ln *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) { if !strings.HasSuffix(network, "tcp") { return nil, net.UnknownNetworkError(network) } - if connAddr(addr) != l.addr { + if connAddr(addr) != ln.addr { return nil, &net.AddrError{ Err: "invalid address", Addr: addr, } } - newConn := l.NewConn + newConn := ln.NewConn if newConn == nil { newConn = func(network, addr string, maxBuf int) (Conn, Conn) { return NewConn(addr, maxBuf) @@ -92,9 +98,9 @@ func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, select { case <-ctx.Done(): return nil, ctx.Err() - case <-l.closed: + case <-ln.closed: return nil, net.ErrClosed - case l.ch <- s: + case ln.ch <- s: return c, nil } } diff --git a/net/memnet/listener_test.go b/net/memnet/listener_test.go index 73b67841a..b6ceb3dfa 100644 --- a/net/memnet/listener_test.go +++ b/net/memnet/listener_test.go @@ -9,10 +9,10 @@ import ( ) func TestListener(t *testing.T) { - l := Listen("srv.local") - defer l.Close() + ln := Listen("srv.local") + defer ln.Close() go func() { - c, err := l.Accept() + c, err := ln.Accept() if err != nil { t.Error(err) return @@ -20,11 +20,11 @@ func TestListener(t *testing.T) { defer c.Close() }() - if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + if c, err := ln.Dial(context.Background(), "tcp", "invalid"); err == nil { c.Close() t.Fatalf("dial to invalid address succeeded") } - c, err := l.Dial(context.Background(), "tcp", "srv.local") + c, err := ln.Dial(context.Background(), "tcp", "srv.local") if err != nil { t.Fatalf("dial failed: %v", err) return diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index c8799bc17..db9e3872f 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -6,3 +6,87 @@ // in tests and other situations where you don't want to use the // network. package memnet + +import ( + "context" + "fmt" + "net" + "net/netip" + + "tailscale.com/net/netx" + "tailscale.com/syncs" +) + +var _ netx.Network = (*Network)(nil) + +// Network implements [Network] using an in-memory network, usually +// used for testing. +// +// As of 2025-04-08, it only supports TCP. +// +// Its zero value is a valid [netx.Network] implementation. +type Network struct { + mu syncs.Mutex + lns map[string]*Listener // address -> listener +} + +func (m *Network) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) + } + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.lns == nil { + m.lns = make(map[string]*Listener) + } + port := ap.Port() + for { + if port == 0 { + port = 33000 + } + key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) + _, ok := m.lns[key] + if ok { + if ap.Port() != 0 { + return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) + } + port++ + continue + } + ln := Listen(key) + m.lns[key] = ln + ln.onClose = func() { + m.mu.Lock() + delete(m.lns, key) + m.mu.Unlock() + } + return ln, nil + } +} + +func (m *Network) NewLocalTCPListener() net.Listener { + ln, err := m.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) + } + return ln +} + +func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) + } + m.mu.Lock() + ln, ok := m.lns[address] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) + } + return ln.Dial(ctx, network, address) +} diff --git a/net/memnet/memnet_test.go b/net/memnet/memnet_test.go new file mode 100644 index 000000000..38086cec0 --- /dev/null +++ b/net/memnet/memnet_test.go @@ -0,0 +1,23 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package memnet + +import "testing" + +func TestListenAddressReuse(t *testing.T) { + var nw Network + ln1, err := nw.Listen("tcp", "127.0.0.1:80") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + if _, err := nw.Listen("tcp", "127.0.0.1:80"); err == nil { + t.Errorf("listen on in-use address succeeded") + } + if err := ln1.Close(); err != nil { + t.Fatalf("close failed: %v", err) + } + if _, err := nw.Listen("tcp", "127.0.0.1:80"); err != nil { + t.Errorf("listen on same address after close failed: %v", err) + } +} diff --git a/net/netaddr/netaddr.go b/net/netaddr/netaddr.go index 1ab6c053a..a04acd57a 100644 --- a/net/netaddr/netaddr.go +++ b/net/netaddr/netaddr.go @@ -34,7 +34,7 @@ func FromStdIPNet(std *net.IPNet) (prefix netip.Prefix, ok bool) { } ip = ip.Unmap() - if l := len(std.Mask); l != net.IPv4len && l != net.IPv6len { + if ln := len(std.Mask); ln != net.IPv4len && ln != net.IPv6len { // Invalid mask. return netip.Prefix{}, false } diff --git a/net/netcheck/captiveportal.go b/net/netcheck/captiveportal.go new file mode 100644 index 000000000..ad11f19a0 --- /dev/null +++ b/net/netcheck/captiveportal.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_captiveportal + +package netcheck + +import ( + "context" + "time" + + "tailscale.com/net/captivedetection" + "tailscale.com/tailcfg" +) + +func init() { + hookStartCaptivePortalDetection.Set(startCaptivePortalDetection) +} + +func startCaptivePortalDetection(ctx context.Context, rs *reportState, dm *tailcfg.DERPMap, preferredDERP int) (done <-chan struct{}, stop func()) { + c := rs.c + + // NOTE(andrew): we can't simply add this goroutine to the + // `NewWaitGroupChan` below, since we don't wait for that + // waitgroup to finish when exiting this function and thus get + // a data race. + ch := make(chan struct{}) + + tmr := time.AfterFunc(c.captivePortalDelay(), func() { + defer close(ch) + d := captivedetection.NewDetector(c.logf) + found := d.Detect(ctx, c.NetMon, dm, preferredDERP) + rs.report.CaptivePortal.Set(found) + }) + + stop = func() { + // Don't cancel our captive portal check if we're + // explicitly doing a verbose netcheck. + if c.Verbose { + return + } + + if tmr.Stop() { + // Stopped successfully; need to close the + // signal channel ourselves. + close(ch) + return + } + + // Did not stop; do nothing and it'll finish by itself + // and close the signal channel. + } + + return ch, stop +} diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 003b5fbf8..c5a3d2392 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -23,16 +23,18 @@ import ( "syscall" "time" - "github.com/tcnksm/go-httpstat" + "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/envknob" - "tailscale.com/net/captivedetection" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/hostinfo" "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/ping" - "tailscale.com/net/portmapper" + "tailscale.com/net/portmapper/portmappertype" "tailscale.com/net/sockstats" "tailscale.com/net/stun" "tailscale.com/syncs" @@ -85,13 +87,14 @@ const ( // Report contains the result of a single netcheck. type Report struct { - UDP bool // a UDP STUN round trip completed - IPv6 bool // an IPv6 STUN round trip completed - IPv4 bool // an IPv4 STUN round trip completed - IPv6CanSend bool // an IPv6 packet was able to be sent - IPv4CanSend bool // an IPv4 packet was able to be sent - OSHasIPv6 bool // could bind a socket to ::1 - ICMPv4 bool // an ICMPv4 round trip completed + Now time.Time // the time the report was run + UDP bool // a UDP STUN round trip completed + IPv6 bool // an IPv6 STUN round trip completed + IPv4 bool // an IPv4 STUN round trip completed + IPv6CanSend bool // an IPv6 packet was able to be sent + IPv4CanSend bool // an IPv4 packet was able to be sent + OSHasIPv6 bool // could bind a socket to ::1 + ICMPv4 bool // an ICMPv4 round trip completed // MappingVariesByDestIP is whether STUN results depend which // STUN server you're talking to (on IPv4). @@ -172,25 +175,14 @@ func (r *Report) Clone() *Report { return nil } r2 := *r - r2.RegionLatency = cloneDurationMap(r2.RegionLatency) - r2.RegionV4Latency = cloneDurationMap(r2.RegionV4Latency) - r2.RegionV6Latency = cloneDurationMap(r2.RegionV6Latency) + r2.RegionLatency = maps.Clone(r2.RegionLatency) + r2.RegionV4Latency = maps.Clone(r2.RegionV4Latency) + r2.RegionV6Latency = maps.Clone(r2.RegionV6Latency) r2.GlobalV4Counters = maps.Clone(r2.GlobalV4Counters) r2.GlobalV6Counters = maps.Clone(r2.GlobalV6Counters) return &r2 } -func cloneDurationMap(m map[int]time.Duration) map[int]time.Duration { - if m == nil { - return nil - } - m2 := make(map[int]time.Duration, len(m)) - for k, v := range m { - m2[k] = v - } - return m2 -} - // Client generates Reports describing the result of both passive and active // network configuration probing. It provides two different modes of report, a // full report (see MakeNextReportFull) and a more lightweight incremental @@ -224,7 +216,7 @@ type Client struct { // PortMapper, if non-nil, is used for portmap queries. // If nil, portmap discovery is not done. - PortMapper *portmapper.Client // lazily initialized on first use + PortMapper portmappertype.Client // UseDNSCache controls whether this client should use a // *dnscache.Resolver to resolve DERP hostnames, when no IP address is @@ -235,11 +227,15 @@ type Client struct { // If false, the default net.Resolver will be used, with no caching. UseDNSCache bool + // if non-zero, force this DERP region to be preferred in all reports where + // the DERP is found to be reachable. + ForcePreferredDERP int + // For tests testEnoughRegions int testCaptivePortalDelay time.Duration - mu sync.Mutex // guards following + mu syncs.Mutex // guards following nextFull bool // do a full region scan, even if last != nil prev map[time.Time]*Report // some previous reports last *Report // most recent report @@ -391,10 +387,14 @@ type probePlan map[string][]probe // sortRegions returns the regions of dm first sorted // from fastest to slowest (based on the 'last' report), // end in regions that have no data. -func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) { +func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*tailcfg.DERPRegion) { prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions)) for _, reg := range dm.Regions { - if reg.Avoid { + if reg.NoMeasureNoHome { + continue + } + // include an otherwise avoid region if it is the current preferred region + if reg.Avoid && reg.RegionID != preferredDERP { continue } prev = append(prev, reg) @@ -419,9 +419,19 @@ func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) // a full report, all regions are scanned.) const numIncrementalRegions = 3 -// makeProbePlan generates the probe plan for a DERPMap, given the most -// recent report and whether IPv6 is configured on an interface. -func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (plan probePlan) { +// makeProbePlan generates the probe plan for a DERPMap, given the most recent +// report and the current home DERP. preferredDERP is passed independently of +// last (report) because last is currently nil'd to indicate a desire for a full +// netcheck. +// +// TODO(raggi,jwhited): refactor the callers and this function to be more clear +// about full vs. incremental netchecks, and remove the need for the history +// hiding. This was avoided in an incremental change due to exactly this kind of +// distant coupling. +// TODO(raggi): change from "preferred DERP" from a historical report to "home +// DERP" as in what DERP is the current home connection, this would further +// reduce flap events. +func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, preferredDERP int) (plan probePlan) { if last == nil || len(last.RegionLatency) == 0 { return makeProbePlanInitial(dm, ifState) } @@ -432,9 +442,34 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl had4 := len(last.RegionV4Latency) > 0 had6 := len(last.RegionV6Latency) > 0 hadBoth := have6if && had4 && had6 - for ri, reg := range sortRegions(dm, last) { - if ri == numIncrementalRegions { - break + // #13969 ensure that the home region is always probed. + // If a netcheck has unstable latency, such as a user with large amounts of + // bufferbloat or a highly congested connection, there are cases where a full + // netcheck may observe a one-off high latency to the current home DERP. Prior + // to the forced inclusion of the home DERP, this would result in an + // incremental netcheck following such an event to cause a home DERP move, with + // restoration back to the home DERP on the next full netcheck ~5 minutes later + // - which is highly disruptive when it causes shifts in geo routed subnet + // routers. By always including the home DERP in the incremental netcheck, we + // ensure that the home DERP is always probed, even if it observed a recent + // poor latency sample. This inclusion enables the latency history checks in + // home DERP selection to still take effect. + // planContainsHome indicates whether the home DERP has been added to the probePlan, + // if there is no prior home, then there's no home to additionally include. + planContainsHome := preferredDERP == 0 + for ri, reg := range sortRegions(dm, last, preferredDERP) { + regIsHome := reg.RegionID == preferredDERP + if ri >= numIncrementalRegions { + // planned at least numIncrementalRegions regions and that includes the + // last home region (or there was none), plan complete. + if planContainsHome { + break + } + // planned at least numIncrementalRegions regions, but not the home region, + // check if this is the home region, if not, skip it. + if !regIsHome { + continue + } } var p4, p6 []probe do4 := have4if @@ -445,7 +480,7 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl tries := 1 isFastestTwo := ri < 2 - if isFastestTwo { + if isFastestTwo || regIsHome { tries = 2 } else if hadBoth { // For dual stack machines, make the 3rd & slower nodes alternate @@ -456,14 +491,15 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl do4, do6 = false, true } } - if !isFastestTwo && !had6 { + if !regIsHome && !isFastestTwo && !had6 { do6 = false } - if reg.RegionID == last.PreferredDERP { + if regIsHome { // But if we already had a DERP home, try extra hard to // make sure it's there so we don't flip flop around. tries = 4 + planContainsHome = true } for try := 0; try < tries; try++ { @@ -503,6 +539,10 @@ func makeProbePlanInitial(dm *tailcfg.DERPMap, ifState *netmon.State) (plan prob plan = make(probePlan) for _, reg := range dm.Regions { + if reg.NoMeasureNoHome || len(reg.Nodes) == 0 { + continue + } + var p4 []probe var p6 []probe for try := 0; try < 3; try++ { @@ -557,7 +597,7 @@ type reportState struct { stopProbeCh chan struct{} waitPortMap sync.WaitGroup - mu sync.Mutex + mu syncs.Mutex report *Report // to be returned by GetReport inFlight map[stun.TxID]func(netip.AddrPort) // called without c.mu held gotEP4 netip.AddrPort @@ -691,7 +731,7 @@ func (rs *reportState) probePortMapServices() { res, err := rs.c.PortMapper.Probe(context.Background()) if err != nil { - if !errors.Is(err, portmapper.ErrGatewayRange) { + if !errors.Is(err, portmappertype.ErrGatewayRange) { // "skipping portmap; gateway range likely lacks support" // is not very useful, and too spammy on cloud systems. // If there are other errors, we want to log those. @@ -715,6 +755,7 @@ func newReport() *Report { // GetReportOpts contains options that can be passed to GetReport. Unless // specified, all fields are optional and can be left as their zero value. +// At most one of OnlyTCP443 or OnlySTUN may be set. type GetReportOpts struct { // GetLastDERPActivity is a callback that, if provided, should return // the absolute time that the calling code last communicated with a @@ -727,6 +768,8 @@ type GetReportOpts struct { // OnlyTCP443 constrains netcheck reporting to measurements over TCP port // 443. OnlyTCP443 bool + // OnlySTUN constrains netcheck reporting to STUN measurements over UDP. + OnlySTUN bool } // getLastDERPActivity calls o.GetLastDERPActivity if both o and @@ -738,6 +781,14 @@ func (o *GetReportOpts) getLastDERPActivity(region int) time.Time { return o.GetLastDERPActivity(region) } +func (c *Client) SetForcePreferredDERP(region int) { + c.mu.Lock() + defer c.mu.Unlock() + c.ForcePreferredDERP = region +} + +var hookStartCaptivePortalDetection feature.Hook[func(ctx context.Context, rs *reportState, dm *tailcfg.DERPMap, preferredDERP int) (<-chan struct{}, func())] + // GetReport gets a report. The 'opts' argument is optional and can be nil. // Callers are discouraged from passing a ctx with an arbitrary deadline as this // may cause GetReport to return prematurely before all reporting methods have @@ -746,6 +797,13 @@ func (o *GetReportOpts) getLastDERPActivity(region int) time.Time { // // It may not be called concurrently with itself. func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetReportOpts) (_ *Report, reterr error) { + onlySTUN := false + if opts != nil && opts.OnlySTUN { + if opts.OnlyTCP443 { + return nil, errors.New("netcheck: only one of OnlySTUN or OnlyTCP443 may be set in opts") + } + onlySTUN = true + } defer func() { if reterr != nil { metricNumGetReportError.Add(1) @@ -784,9 +842,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.curState = rs last := c.last - // Even if we're doing a non-incremental update, we may want to try our - // preferred DERP region for captive portal detection. Save that, if we - // have it. + // Extract preferredDERP from the last report, if available. This will be used + // in captive portal detection and DERP flapping suppression. Ideally this would + // be the current active home DERP rather than the last report preferred DERP, + // but only the latter is presently available. var preferredDERP int if last != nil { preferredDERP = last.PreferredDERP @@ -819,7 +878,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.curState = nil }() - if runtime.GOOS == "js" || runtime.GOOS == "tamago" { + if runtime.GOOS == "js" || runtime.GOOS == "tamago" || (runtime.GOOS == "plan9" && hostinfo.IsInVM86()) { + if onlySTUN { + return nil, errors.New("platform is restricted to HTTP, but OnlySTUN is set in opts") + } if err := c.runHTTPOnlyChecks(ctx, last, rs, dm); err != nil { return nil, err } @@ -843,7 +905,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe var plan probePlan if opts == nil || !opts.OnlyTCP443 { - plan = makeProbePlan(dm, ifState, last) + plan = makeProbePlan(dm, ifState, last, preferredDERP) } // If we're doing a full probe, also check for a captive portal. We @@ -851,38 +913,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe // it's unnecessary. captivePortalDone := syncs.ClosedChan() captivePortalStop := func() {} - if !rs.incremental { - // NOTE(andrew): we can't simply add this goroutine to the - // `NewWaitGroupChan` below, since we don't wait for that - // waitgroup to finish when exiting this function and thus get - // a data race. - ch := make(chan struct{}) - captivePortalDone = ch - - tmr := time.AfterFunc(c.captivePortalDelay(), func() { - defer close(ch) - d := captivedetection.NewDetector(c.logf) - found := d.Detect(ctx, c.NetMon, dm, preferredDERP) - rs.report.CaptivePortal.Set(found) - }) - - captivePortalStop = func() { - // Don't cancel our captive portal check if we're - // explicitly doing a verbose netcheck. - if c.Verbose { - return - } - - if tmr.Stop() { - // Stopped successfully; need to close the - // signal channel ourselves. - close(ch) - return - } - - // Did not stop; do nothing and it'll finish by itself - // and close the signal channel. - } + if buildfeatures.HasCaptivePortal && !rs.incremental && !onlySTUN { + start := hookStartCaptivePortalDetection.Get() + captivePortalDone, captivePortalStop = start(ctx, rs, dm, preferredDERP) } wg := syncs.NewWaitGroupChan() @@ -925,18 +958,18 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe rs.stopTimers() // Try HTTPS and ICMP latency check if all STUN probes failed due to - // UDP presumably being blocked. + // UDP presumably being blocked, and we are not constrained to only STUN. // TODO: this should be moved into the probePlan, using probeProto probeHTTPS. - if !rs.anyUDP() && ctx.Err() == nil { + if !rs.anyUDP() && ctx.Err() == nil && !onlySTUN { var wg sync.WaitGroup var need []*tailcfg.DERPRegion for rid, reg := range dm.Regions { - if !rs.haveRegionLatency(rid) && regionHasDERPNode(reg) { + if !rs.haveRegionLatency(rid) && regionHasDERPNode(reg) && !reg.Avoid && !reg.NoMeasureNoHome { need = append(need, reg) } } if len(need) > 0 { - if !opts.OnlyTCP443 { + if opts == nil || !opts.OnlyTCP443 { // Kick off ICMP in parallel to HTTPS checks; we don't // reuse the same WaitGroup for those probes because we // need to close the underlying Pinger after a timeout @@ -960,9 +993,9 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.logf("[v1] netcheck: measuring HTTPS latency of %v (%d): %v", reg.RegionCode, reg.RegionID, err) } else { rs.mu.Lock() - if l, ok := rs.report.RegionLatency[reg.RegionID]; !ok { + if latency, ok := rs.report.RegionLatency[reg.RegionID]; !ok { mak.Set(&rs.report.RegionLatency, reg.RegionID, d) - } else if l >= d { + } else if latency >= d { rs.report.RegionLatency[reg.RegionID] = d } // We set these IPv4 and IPv6 but they're not really used @@ -1001,7 +1034,7 @@ func (c *Client) finishAndStoreReport(rs *reportState, dm *tailcfg.DERPMap) *Rep } // runHTTPOnlyChecks is the netcheck done by environments that can -// only do HTTP requests, such as ws/wasm. +// only do HTTP requests, such as js/wasm. func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *reportState, dm *tailcfg.DERPMap) error { var regions []*tailcfg.DERPRegion if rs.incremental && last != nil { @@ -1013,9 +1046,25 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report } if len(regions) == 0 { for _, dr := range dm.Regions { + if dr.NoMeasureNoHome { + continue + } regions = append(regions, dr) } } + + if len(regions) == 1 && hostinfo.IsInVM86() { + // If we only have 1 region that's probably and we're in a + // network-limited v86 environment, don't actually probe it. Just fake + // some results. + rg := regions[0] + if len(rg.Nodes) > 0 { + node := rg.Nodes[0] + rs.addNodeLatency(node, netip.AddrPort{}, 999*time.Millisecond) + return nil + } + } + c.logf("running HTTP-only netcheck against %v regions", len(regions)) var wg sync.WaitGroup @@ -1024,7 +1073,6 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report continue } wg.Add(1) - rg := rg go func() { defer wg.Done() node := rg.Nodes[0] @@ -1057,10 +1105,11 @@ func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *report return nil } +// measureHTTPSLatency measures HTTP request latency to the DERP region, but +// only returns success if an HTTPS request to the region succeeds. func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegion) (time.Duration, netip.Addr, error) { metricHTTPSend.Add(1) - var result httpstat.Result - ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(ctx, &result), httpsProbeTimeout) + ctx, cancel := context.WithTimeout(ctx, httpsProbeTimeout) defer cancel() var ip netip.Addr @@ -1068,6 +1117,8 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio dc := derphttp.NewNetcheckClient(c.logf, c.NetMon) defer dc.Close() + // DialRegionTLS may dial multiple times if a node is not available, as such + // it does not have stable timing to measure. tlsConn, tcpConn, node, err := dc.DialRegionTLS(ctx, reg) if err != nil { return 0, ip, err @@ -1085,6 +1136,8 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio connc := make(chan *tls.Conn, 1) connc <- tlsConn + // make an HTTP request to measure, as this enables us to account for MITM + // overhead in e.g. corp environments that have HTTP MITM in front of DERP. tr := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, errors.New("unexpected DialContext dial") @@ -1100,12 +1153,17 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio } hc := &http.Client{Transport: tr} + // This is the request that will be measured, the request and response + // should be small enough to fit into a single packet each way unless the + // connection has already become unstable. req, err := http.NewRequestWithContext(ctx, "GET", "https://"+node.HostName+"/derp/latency-check", nil) if err != nil { return 0, ip, err } + startTime := c.timeNow() resp, err := hc.Do(req) + reqDur := c.timeNow().Sub(startTime) if err != nil { return 0, ip, err } @@ -1122,17 +1180,22 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio if err != nil { return 0, ip, err } - result.End(c.timeNow()) - // TODO: decide best timing heuristic here. - // Maybe the server should return the tcpinfo_rtt? - return result.ServerProcessing, ip, nil + // return the connection duration, not the request duration, as this is the + // best approximation of the RTT latency to the node. Note that the + // connection setup performs happy-eyeballs and TLS so there are additional + // overheads. + return reqDur, ip, nil } func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, need []*tailcfg.DERPRegion) error { if len(need) == 0 { return nil } + if runtime.GOOS == "plan9" { + // ICMP isn't implemented. + return nil + } ctx, done := context.WithTimeout(ctx, icmpProbeTimeout) defer done() @@ -1151,9 +1214,9 @@ func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, nee } else if ok { c.logf("[v1] ICMP latency of %v (%d): %v", reg.RegionCode, reg.RegionID, d) rs.mu.Lock() - if l, ok := rs.report.RegionLatency[reg.RegionID]; !ok { + if latency, ok := rs.report.RegionLatency[reg.RegionID]; !ok { mak.Set(&rs.report.RegionLatency, reg.RegionID, d) - } else if l >= d { + } else if latency >= d { rs.report.RegionLatency[reg.RegionID] = d } @@ -1178,17 +1241,19 @@ func (c *Client) measureICMPLatency(ctx context.Context, reg *tailcfg.DERPRegion // Try pinging the first node in the region node := reg.Nodes[0] - // Get the IPAddr by asking for the UDP address that we would use for - // STUN and then using that IP. - // - // TODO(andrew-d): this is a bit ugly - nodeAddr := c.nodeAddr(ctx, node, probeIPv4) - if !nodeAddr.IsValid() { + if node.STUNPort < 0 { + // If STUN is disabled on a node, interpret that as meaning don't measure latency. + return 0, false, nil + } + const unusedPort = 0 + stunAddrPort, ok := c.nodeAddrPort(ctx, node, unusedPort, probeIPv4) + if !ok { return 0, false, fmt.Errorf("no address for node %v (v4-for-icmp)", node.Name) } + ip := stunAddrPort.Addr() addr := &net.IPAddr{ - IP: net.IP(nodeAddr.Addr().AsSlice()), - Zone: nodeAddr.Addr().Zone(), + IP: net.IP(ip.AsSlice()), + Zone: ip.Zone(), } // Use the unique node.Name field as the packet data to reduce the @@ -1232,6 +1297,9 @@ func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) { if r.CaptivePortal != "" { fmt.Fprintf(w, " captiveportal=%v", r.CaptivePortal) } + if c.ForcePreferredDERP != 0 { + fmt.Fprintf(w, " force=%v", c.ForcePreferredDERP) + } fmt.Fprintf(w, " derp=%v", r.PreferredDERP) if r.PreferredDERP != 0 { fmt.Fprintf(w, " derpdist=") @@ -1277,6 +1345,15 @@ const ( // even without receiving a STUN response. // Note: must remain higher than the derp package frameReceiveRecordRate PreferredDERPFrameTime = 8 * time.Second + // PreferredDERPKeepAliveTimeout is 2x the DERP Keep Alive timeout. If there + // is no latency data to make judgements from, but we have heard from our + // current DERP region inside of 2x the KeepAlive window, don't switch DERP + // regions yet, keep the current region. This prevents region flapping / + // home DERP removal during short periods of packet loss where the DERP TCP + // connection may itself naturally recover. + // TODO(raggi): expose shared time bounds from the DERP package rather than + // duplicating them here. + PreferredDERPKeepAliveTimeout = 2 * derp.KeepAlive ) // addReportHistoryAndSetPreferredDERP adds r to the set of recent Reports @@ -1293,6 +1370,7 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, c.prev = map[time.Time]*Report{} } now := c.timeNow() + r.Now = now.UTC() c.prev[now] = r c.last = r @@ -1360,13 +1438,10 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, // the STUN probe) since we started the netcheck, or in the past 2s, as // another signal for "this region is still working". heardFromOldRegionRecently := false + prevRegionLastHeard := rs.opts.getLastDERPActivity(prevDERP) if changingPreferred { - if lastHeard := rs.opts.getLastDERPActivity(prevDERP); !lastHeard.IsZero() { - now := c.timeNow() - - heardFromOldRegionRecently = lastHeard.After(rs.start) - heardFromOldRegionRecently = heardFromOldRegionRecently || lastHeard.After(now.Add(-PreferredDERPFrameTime)) - } + heardFromOldRegionRecently = prevRegionLastHeard.After(rs.start) + heardFromOldRegionRecently = heardFromOldRegionRecently || prevRegionLastHeard.After(now.Add(-PreferredDERPFrameTime)) } // The old region is accessible if we've heard from it via a non-STUN @@ -1389,6 +1464,24 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, // which undoes any region change we made above. r.PreferredDERP = prevDERP } + if c.ForcePreferredDERP != 0 { + // If the forced DERP region probed successfully, or has recent traffic, + // use it. + _, haveLatencySample := r.RegionLatency[c.ForcePreferredDERP] + lastHeard := rs.opts.getLastDERPActivity(c.ForcePreferredDERP) + recentActivity := lastHeard.After(rs.start) + recentActivity = recentActivity || lastHeard.After(now.Add(-PreferredDERPFrameTime)) + + if haveLatencySample || recentActivity { + r.PreferredDERP = c.ForcePreferredDERP + } + } + // If there was no latency data to make judgements on, but there is an + // active DERP connection that has at least been doing KeepAlive recently, + // keep it, rather than dropping it. + if r.PreferredDERP == 0 && prevRegionLastHeard.After(now.Add(-PreferredDERPKeepAliveTimeout)) { + r.PreferredDERP = prevDERP + } } func updateLatency(m map[int]time.Duration, regionID int, d time.Duration) { @@ -1434,8 +1527,8 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe return } - addr := c.nodeAddr(ctx, node, probe.proto) - if !addr.IsValid() { + addr, ok := c.nodeAddrPort(ctx, node, node.STUNPort, probe.proto) + if !ok { c.logf("netcheck.runProbe: named node %q has no %v address", probe.node, probe.proto) return } @@ -1484,12 +1577,20 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe c.vlogf("sent to %v", addr) } -// proto is 4 or 6 -// If it returns nil, the node is skipped. -func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeProto) (ap netip.AddrPort) { - port := cmp.Or(n.STUNPort, 3478) +// nodeAddrPort returns the IP:port to send a STUN queries to for a given node. +// +// The provided port should be n.STUNPort, which may be negative to disable STUN. +// If STUN is disabled for this node, it returns ok=false. +// The port parameter is separate for the ICMP caller to provide a fake value. +// +// proto is [probeIPv4] or [probeIPv6]. +func (c *Client) nodeAddrPort(ctx context.Context, n *tailcfg.DERPNode, port int, proto probeProto) (_ netip.AddrPort, ok bool) { + var zero netip.AddrPort if port < 0 || port > 1<<16-1 { - return + return zero, false + } + if port == 0 { + port = 3478 } if n.STUNTestIP != "" { ip, err := netip.ParseAddr(n.STUNTestIP) @@ -1502,7 +1603,7 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if proto == probeIPv6 && ip.Is4() { return } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } switch proto { @@ -1510,20 +1611,20 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if n.IPv4 != "" { ip, _ := netip.ParseAddr(n.IPv4) if !ip.Is4() { - return + return zero, false } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } case probeIPv6: if n.IPv6 != "" { ip, _ := netip.ParseAddr(n.IPv6) if !ip.Is6() { - return + return zero, false } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } default: - return + return zero, false } // The default lookup function if we don't set UseDNSCache is to use net.DefaultResolver. @@ -1565,13 +1666,13 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP addrs, err := lookupIPAddr(ctx, n.HostName) for _, a := range addrs { if (a.Is4() && probeIsV4) || (a.Is6() && !probeIsV4) { - return netip.AddrPortFrom(a, uint16(port)) + return netip.AddrPortFrom(a, uint16(port)), true } } if err != nil { c.logf("netcheck: DNS lookup error for %q (node %q region %v): %v", n.HostName, n.Name, n.RegionID, err) } - return + return zero, false } func regionHasDERPNode(r *tailcfg.DERPRegion) bool { diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 02076f8d4..6830e7f27 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -18,6 +18,7 @@ import ( "testing" "time" + "tailscale.com/derp" "tailscale.com/net/netmon" "tailscale.com/net/stun/stuntest" "tailscale.com/tailcfg" @@ -28,6 +29,9 @@ func newTestClient(t testing.TB) *Client { c := &Client{ NetMon: netmon.NewStatic(), Logf: t.Logf, + TimeNow: func() time.Time { + return time.Unix(1729624521, 0) + }, } return c } @@ -38,7 +42,7 @@ func TestBasic(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() if err := c.Standalone(ctx, "127.0.0.1:0"); err != nil { @@ -52,6 +56,9 @@ func TestBasic(t *testing.T) { if !r.UDP { t.Error("want UDP") } + if r.Now.IsZero() { + t.Error("Now is zero") + } if len(r.RegionLatency) != 1 { t.Errorf("expected 1 key in DERPLatency; got %+v", r.RegionLatency) } @@ -117,7 +124,7 @@ func TestWorksWhenUDPBlocked(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() r, err := c.GetReport(ctx, dm, nil) @@ -130,6 +137,14 @@ func TestWorksWhenUDPBlocked(t *testing.T) { want := newReport() + // The Now field can't be compared with reflect.DeepEqual; check using + // the Equal method and then overwrite it so that the comparison below + // succeeds. + if !r.Now.Equal(c.TimeNow()) { + t.Errorf("Now = %v; want %v", r.Now, c.TimeNow()) + } + want.Now = r.Now + // The IPv4CanSend flag gets set differently across platforms. // On Windows this test detects false, while on Linux detects true. // That's not relevant to this test, so just accept what we're @@ -187,6 +202,7 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { steps []step homeParams *tailcfg.DERPHomeParams opts *GetReportOpts + forcedDERP int // if non-zero, force this DERP to be the preferred one wantDERP int // want PreferredDERP on final step wantPrevLen int // wanted len(c.prev) }{ @@ -343,12 +359,107 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { wantPrevLen: 3, wantDERP: 2, // moved to d2 since d1 is gone }, + { + name: "preferred_derp_hysteresis_no_switch_pct", + steps: []step{ + {0 * time.Second, report("d1", 34*time.Millisecond, "d2", 35*time.Millisecond)}, + {1 * time.Second, report("d1", 34*time.Millisecond, "d2", 23*time.Millisecond)}, + }, + wantPrevLen: 2, + wantDERP: 1, // diff is 11ms, but d2 is greater than 2/3s of d1 + }, + { + name: "forced_two", + steps: []step{ + {time.Second, report("d1", 2, "d2", 3)}, + {2 * time.Second, report("d1", 4, "d2", 3)}, + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 2, + }, + { + name: "forced_two_unavailable", + steps: []step{ + {time.Second, report("d1", 2, "d2", 1)}, + {2 * time.Second, report("d1", 4)}, + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 1, + }, + { + name: "forced_two_no_probe_recent_activity", + steps: []step{ + {time.Second, report("d1", 2)}, + {2 * time.Second, report("d1", 4)}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + 2: startTime.Add(time.Second), + }), + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 2, + }, + { + name: "forced_two_no_probe_no_recent_activity", + steps: []step{ + {time.Second, report("d1", 2)}, + {PreferredDERPFrameTime + time.Second, report("d1", 4)}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + 2: startTime, + }), + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 1, + }, + { + name: "no_data_keep_home", + steps: []step{ + {0, report("d1", 2, "d2", 3)}, + {30 * time.Second, report()}, + {2 * time.Second, report()}, + {2 * time.Second, report()}, + {2 * time.Second, report()}, + {2 * time.Second, report()}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + }), + }, + wantPrevLen: 6, + wantDERP: 1, + }, + { + name: "no_data_home_expires", + steps: []step{ + {0, report("d1", 2, "d2", 3)}, + {30 * time.Second, report()}, + {2 * derp.KeepAlive, report()}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + }), + }, + wantPrevLen: 3, + wantDERP: 0, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fakeTime := startTime c := &Client{ - TimeNow: func() time.Time { return fakeTime }, + TimeNow: func() time.Time { return fakeTime }, + ForcePreferredDERP: tt.forcedDERP, } dm := &tailcfg.DERPMap{HomeParams: tt.homeParams} rs := &reportState{ @@ -378,7 +489,7 @@ func TestMakeProbePlan(t *testing.T) { basicMap := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{}, } - for rid := 1; rid <= 5; rid++ { + for rid := 1; rid <= 6; rid++ { var nodes []*tailcfg.DERPNode for nid := 0; nid < rid; nid++ { nodes = append(nodes, &tailcfg.DERPNode{ @@ -390,8 +501,9 @@ func TestMakeProbePlan(t *testing.T) { }) } basicMap.Regions[rid] = &tailcfg.DERPRegion{ - RegionID: rid, - Nodes: nodes, + RegionID: rid, + Nodes: nodes, + NoMeasureNoHome: rid == 6, } } @@ -576,6 +688,40 @@ func TestMakeProbePlan(t *testing.T) { "region-3-v4": []probe{p("3a", 4)}, }, }, + { + // #13969: ensure that the prior/current home region is always included in + // probe plans, so that we don't flap between regions due to a single major + // netcheck having excluded the home region due to a spuriously high sample. + name: "ensure_home_region_inclusion", + dm: basicMap, + have6if: true, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 50 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + RegionV4Latency: map[int]time.Duration{ + 1: 50 * time.Millisecond, + 2: 20 * time.Millisecond, + }, + RegionV6Latency: map[int]time.Duration{ + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + PreferredDERP: 1, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 60*ms), p("1a", 4, 220*ms), p("1a", 4, 330*ms)}, + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 60*ms), p("1a", 6, 220*ms), p("1a", 6, 330*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)}, + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 36*ms)}, + "region-3-v6": []probe{p("3a", 6), p("3b", 6, 36*ms)}, + "region-4-v4": []probe{p("4a", 4)}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -583,7 +729,11 @@ func TestMakeProbePlan(t *testing.T) { HaveV6: tt.have6if, HaveV4: !tt.no4, } - got := makeProbePlan(tt.dm, ifState, tt.last) + preferredDERP := 0 + if tt.last != nil { + preferredDERP = tt.last.PreferredDERP + } + got := makeProbePlan(tt.dm, ifState, tt.last, preferredDERP) if !reflect.DeepEqual(got, tt.want) { t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want) } @@ -756,7 +906,7 @@ func TestSortRegions(t *testing.T) { report.RegionLatency[3] = time.Second * time.Duration(6) report.RegionLatency[4] = time.Second * time.Duration(0) report.RegionLatency[5] = time.Second * time.Duration(2) - sortedMap := sortRegions(unsortedMap, report) + sortedMap := sortRegions(unsortedMap, report, 0) // Sorting by latency this should result in rid: 5, 2, 1, 3 // rid 4 with latency 0 should be at the end @@ -826,8 +976,8 @@ func TestNodeAddrResolve(t *testing.T) { c.UseDNSCache = tt t.Run("IPv4", func(t *testing.T) { - ap := c.nodeAddr(ctx, dn, probeIPv4) - if !ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dn, dn.STUNPort, probeIPv4) + if !ok { t.Fatal("expected valid AddrPort") } if !ap.Addr().Is4() { @@ -841,8 +991,8 @@ func TestNodeAddrResolve(t *testing.T) { t.Skipf("IPv6 may not work on this machine") } - ap := c.nodeAddr(ctx, dn, probeIPv6) - if !ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dn, dn.STUNPort, probeIPv6) + if !ok { t.Fatal("expected valid AddrPort") } if !ap.Addr().Is6() { @@ -851,8 +1001,8 @@ func TestNodeAddrResolve(t *testing.T) { t.Logf("got IPv6 addr: %v", ap) }) t.Run("IPv6 Failure", func(t *testing.T) { - ap := c.nodeAddr(ctx, dnV4Only, probeIPv6) - if ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dnV4Only, dn.STUNPort, probeIPv6) + if ok { t.Fatalf("expected no addr but got: %v", ap) } t.Logf("correctly got invalid addr") @@ -872,3 +1022,30 @@ func TestReportTimeouts(t *testing.T) { t.Errorf("ReportTimeout (%v) cannot be less than httpsProbeTimeout (%v)", ReportTimeout, httpsProbeTimeout) } } + +func TestNoUDPNilGetReportOpts(t *testing.T) { + blackhole, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to open blackhole STUN listener: %v", err) + } + defer blackhole.Close() + + dm := stuntest.DERPMapOf(blackhole.LocalAddr().String()) + for _, region := range dm.Regions { + for _, n := range region.Nodes { + n.STUNOnly = false // exercise ICMP & HTTPS probing + } + } + + c := newTestClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := c.GetReport(ctx, dm, nil) + if err != nil { + t.Fatal(err) + } + if r.UDP { + t.Fatal("unexpected working UDP") + } +} diff --git a/net/netcheck/standalone.go b/net/netcheck/standalone.go index c72d7005f..b4523a832 100644 --- a/net/netcheck/standalone.go +++ b/net/netcheck/standalone.go @@ -13,7 +13,6 @@ import ( "tailscale.com/net/stun" "tailscale.com/types/logger" "tailscale.com/types/nettype" - "tailscale.com/util/multierr" ) // Standalone creates the necessary UDP sockets on the given bindAddr and starts @@ -62,7 +61,7 @@ func (c *Client) Standalone(ctx context.Context, bindAddr string) error { // If both v4 and v6 failed, report an error, otherwise let one succeed. if len(errs) == 2 { - return multierr.New(errs...) + return errors.Join(errs...) } return nil } diff --git a/net/netkernelconf/netkernelconf_default.go b/net/netkernelconf/netkernelconf_default.go index ec1b2e619..3e160e5ed 100644 --- a/net/netkernelconf/netkernelconf_default.go +++ b/net/netkernelconf/netkernelconf_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android package netkernelconf diff --git a/net/netkernelconf/netkernelconf_linux.go b/net/netkernelconf/netkernelconf_linux.go index 51ed8ea99..2a4f0a049 100644 --- a/net/netkernelconf/netkernelconf_linux.go +++ b/net/netkernelconf/netkernelconf_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package netkernelconf import ( diff --git a/net/netmon/defaultroute_darwin.go b/net/netmon/defaultroute_darwin.go index 4efe2f1aa..57f7e22b7 100644 --- a/net/netmon/defaultroute_darwin.go +++ b/net/netmon/defaultroute_darwin.go @@ -6,6 +6,8 @@ package netmon import ( + "errors" + "fmt" "log" "net" @@ -16,14 +18,26 @@ var ( lastKnownDefaultRouteIfName syncs.AtomicValue[string] ) -// UpdateLastKnownDefaultRouteInterface is called by ipn-go-bridge in the iOS app when +// UpdateLastKnownDefaultRouteInterface is called by ipn-go-bridge from apple network extensions when // our NWPathMonitor instance detects a network path transition. func UpdateLastKnownDefaultRouteInterface(ifName string) { if ifName == "" { return } if old := lastKnownDefaultRouteIfName.Swap(ifName); old != ifName { - log.Printf("defaultroute_darwin: update from Swift, ifName = %s (was %s)", ifName, old) + interfaces, err := netInterfaces() + if err != nil { + log.Printf("defaultroute_darwin: UpdateLastKnownDefaultRouteInterface could not get interfaces: %v", err) + return + } + + netif, err := getInterfaceByName(ifName, interfaces) + if err != nil { + log.Printf("defaultroute_darwin: UpdateLastKnownDefaultRouteInterface could not find interface index for %s: %v", ifName, err) + return + } + + log.Printf("defaultroute_darwin: updated last known default if from OS, ifName = %s index: %d (was %s)", ifName, netif.Index, old) } } @@ -40,57 +54,69 @@ func defaultRoute() (d DefaultRouteDetails, err error) { // // If for any reason the Swift machinery didn't work and we don't get any updates, we will // fallback to the BSD logic. + osRoute, osRouteErr := OSDefaultRoute() + if osRouteErr == nil { + // If we got a valid interface from the OS, use it. + d.InterfaceName = osRoute.InterfaceName + d.InterfaceIndex = osRoute.InterfaceIndex + return d, nil + } - // Start by getting all available interfaces. - interfaces, err := netInterfaces() + // Fallback to the BSD logic + idx, err := DefaultRouteInterfaceIndex() if err != nil { - log.Printf("defaultroute_darwin: could not get interfaces: %v", err) - return d, ErrNoGatewayIndexFound + return d, err } - - getInterfaceByName := func(name string) *Interface { - for _, ifc := range interfaces { - if ifc.Name != name { - continue - } - - if !ifc.IsUp() { - log.Printf("defaultroute_darwin: %s is down", name) - return nil - } - - addrs, _ := ifc.Addrs() - if len(addrs) == 0 { - log.Printf("defaultroute_darwin: %s has no addresses", name) - return nil - } - return &ifc - } - return nil + iface, err := net.InterfaceByIndex(idx) + if err != nil { + return d, err } + d.InterfaceName = iface.Name + d.InterfaceIndex = idx + return d, nil +} + +// OSDefaultRoute returns the DefaultRouteDetails for the default interface as provided by the OS +// via UpdateLastKnownDefaultRouteInterface. If UpdateLastKnownDefaultRouteInterface has not been called, +// the interface name is not valid, or we cannot find its index, an error is returned. +func OSDefaultRoute() (d DefaultRouteDetails, err error) { // Did Swift set lastKnownDefaultRouteInterface? If so, we should use it and don't bother // with anything else. However, for sanity, do check whether Swift gave us with an interface - // that exists, is up, and has an address. + // that exists, is up, and has an address and is not the tunnel itself. if swiftIfName := lastKnownDefaultRouteIfName.Load(); swiftIfName != "" { - ifc := getInterfaceByName(swiftIfName) - if ifc != nil { + // Start by getting all available interfaces. + interfaces, err := netInterfaces() + if err != nil { + log.Printf("defaultroute_darwin: could not get interfaces: %v", err) + return d, err + } + + if ifc, err := getInterfaceByName(swiftIfName, interfaces); err == nil { d.InterfaceName = ifc.Name d.InterfaceIndex = ifc.Index return d, nil } } + err = errors.New("no os provided default route interface found") + return d, err +} - // Fallback to the BSD logic - idx, err := DefaultRouteInterfaceIndex() - if err != nil { - return d, err - } - iface, err := net.InterfaceByIndex(idx) - if err != nil { - return d, err +func getInterfaceByName(name string, interfaces []Interface) (*Interface, error) { + for _, ifc := range interfaces { + if ifc.Name != name { + continue + } + + if !ifc.IsUp() { + return nil, fmt.Errorf("defaultroute_darwin: %s is down", name) + } + + addrs, _ := ifc.Addrs() + if len(addrs) == 0 { + return nil, fmt.Errorf("defaultroute_darwin: %s has no addresses", name) + } + return &ifc, nil } - d.InterfaceName = iface.Name - d.InterfaceIndex = idx - return d, nil + return nil, errors.New("no interfaces found") } diff --git a/net/netmon/interfaces_android.go b/net/netmon/interfaces_android.go index a96423eb6..26104e879 100644 --- a/net/netmon/interfaces_android.go +++ b/net/netmon/interfaces_android.go @@ -5,7 +5,6 @@ package netmon import ( "bytes" - "errors" "log" "net/netip" "os/exec" @@ -15,7 +14,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/net/netaddr" "tailscale.com/syncs" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) var ( @@ -34,11 +33,6 @@ func init() { var procNetRouteErr atomic.Bool -// errStopReading is a sentinel error value used internally by -// lineread.File callers to stop reading. It doesn't escape to -// callers/users. -var errStopReading = errors.New("stop reading") - /* Parse 10.0.0.1 out of: @@ -54,44 +48,42 @@ func likelyHomeRouterIPAndroid() (ret netip.Addr, myIP netip.Addr, ok bool) { } lineNum := 0 var f []mem.RO - err := lineread.File(procNetRoutePath, func(line []byte) error { + for lr := range lineiter.File(procNetRoutePath) { + line, err := lr.Value() + if err != nil { + procNetRouteErr.Store(true) + return likelyHomeRouterIP() + } + lineNum++ if lineNum == 1 { // Skip header line. - return nil + continue } if lineNum > maxProcNetRouteRead { - return errStopReading + break } f = mem.AppendFields(f[:0], mem.B(line)) if len(f) < 4 { - return nil + continue } gwHex, flagsHex := f[2], f[3] flags, err := mem.ParseUint(flagsHex, 16, 16) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } if flags&(unix.RTF_UP|unix.RTF_GATEWAY) != unix.RTF_UP|unix.RTF_GATEWAY { - return nil + continue } ipu32, err := mem.ParseUint(gwHex, 16, 32) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24)) if ip.IsPrivate() { ret = ip - return errStopReading + break } - return nil - }) - if errors.Is(err, errStopReading) { - err = nil - } - if err != nil { - procNetRouteErr.Store(true) - return likelyHomeRouterIP() } if ret.IsValid() { // Try to get the local IP of the interface associated with @@ -144,23 +136,26 @@ func likelyHomeRouterIPHelper() (ret netip.Addr, _ netip.Addr, ok bool) { return } // Search for line like "default via 10.0.2.2 dev radio0 table 1016 proto static mtu 1500 " - lineread.Reader(out, func(line []byte) error { + for lr := range lineiter.Reader(out) { + line, err := lr.Value() + if err != nil { + break + } const pfx = "default via " if !mem.HasPrefix(mem.B(line), mem.S(pfx)) { - return nil + continue } line = line[len(pfx):] sp := bytes.IndexByte(line, ' ') if sp == -1 { - return nil + continue } ipb := line[:sp] if ip, err := netip.ParseAddr(string(ipb)); err == nil && ip.Is4() { ret = ip log.Printf("interfaces: found Android default route %v", ip) } - return nil - }) + } cmd.Process.Kill() cmd.Wait() return ret, netip.Addr{}, ret.IsValid() diff --git a/net/netmon/interfaces_darwin.go b/net/netmon/interfaces_darwin.go index b175f980a..126040350 100644 --- a/net/netmon/interfaces_darwin.go +++ b/net/netmon/interfaces_darwin.go @@ -7,12 +7,12 @@ import ( "fmt" "net" "strings" - "sync" "syscall" "unsafe" "golang.org/x/net/route" "golang.org/x/sys/unix" + "tailscale.com/syncs" "tailscale.com/util/mak" ) @@ -26,7 +26,7 @@ func parseRoutingTable(rib []byte) ([]route.Message, error) { } var ifNames struct { - sync.Mutex + syncs.Mutex m map[int]string // ifindex => name } diff --git a/net/netmon/interfaces_darwin_test.go b/net/netmon/interfaces_darwin_test.go index d34040d60..c3d40a6f0 100644 --- a/net/netmon/interfaces_darwin_test.go +++ b/net/netmon/interfaces_darwin_test.go @@ -4,14 +4,13 @@ package netmon import ( - "errors" "io" "net/netip" "os/exec" "testing" "go4.org/mem" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version" ) @@ -73,31 +72,34 @@ func likelyHomeRouterIPDarwinExec() (ret netip.Addr, netif string, ok bool) { defer io.Copy(io.Discard, stdout) // clear the pipe to prevent hangs var f []mem.RO - lineread.Reader(stdout, func(lineb []byte) error { + for lr := range lineiter.Reader(stdout) { + lineb, err := lr.Value() + if err != nil { + break + } line := mem.B(lineb) if !mem.Contains(line, mem.S("default")) { - return nil + continue } f = mem.AppendFields(f[:0], line) if len(f) < 4 || !f[0].EqualString("default") { - return nil + continue } ipm, flagsm, netifm := f[1], f[2], f[3] if !mem.Contains(flagsm, mem.S("G")) { - return nil + continue } if mem.Contains(flagsm, mem.S("I")) { - return nil + continue } ip, err := netip.ParseAddr(string(mem.Append(nil, ipm))) if err == nil && ip.IsPrivate() { ret = ip netif = netifm.StringCopy() // We've found what we're looking for. - return errStopReadingNetstatTable + break } - return nil - }) + } return ret, netif, ret.IsValid() } @@ -111,4 +113,24 @@ func TestFetchRoutingTable(t *testing.T) { } } -var errStopReadingNetstatTable = errors.New("found private gateway") +func TestUpdateLastKnownDefaultRouteInterface(t *testing.T) { + // Pick some interface on the machine + interfaces, err := netInterfaces() + if err != nil || len(interfaces) == 0 { + t.Fatalf("netInterfaces() error: %v", err) + } + + // Set it as our last known default route interface + ifName := interfaces[0].Name + UpdateLastKnownDefaultRouteInterface(ifName) + + // And make sure we can get it back + route, err := OSDefaultRoute() + if err != nil { + t.Fatalf("OSDefaultRoute() error: %v", err) + } + want, got := ifName, route.InterfaceName + if want != got { + t.Errorf("OSDefaultRoute() = %q, want %q", got, want) + } +} diff --git a/net/netmon/interfaces_linux.go b/net/netmon/interfaces_linux.go index 299f3101e..a9b93c0a1 100644 --- a/net/netmon/interfaces_linux.go +++ b/net/netmon/interfaces_linux.go @@ -22,8 +22,9 @@ import ( "github.com/mdlayher/netlink" "go4.org/mem" "golang.org/x/sys/unix" + "tailscale.com/feature/buildfeatures" "tailscale.com/net/netaddr" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) func init() { @@ -32,11 +33,6 @@ func init() { var procNetRouteErr atomic.Bool -// errStopReading is a sentinel error value used internally by -// lineread.File callers to stop reading. It doesn't escape to -// callers/users. -var errStopReading = errors.New("stop reading") - /* Parse 10.0.0.1 out of: @@ -46,50 +42,51 @@ ens18 00000000 0100000A 0003 0 0 0 00000000 ens18 0000000A 00000000 0001 0 0 0 0000FFFF 0 0 0 */ func likelyHomeRouterIPLinux() (ret netip.Addr, myIP netip.Addr, ok bool) { + if !buildfeatures.HasPortMapper { + return + } if procNetRouteErr.Load() { // If we failed to read /proc/net/route previously, don't keep trying. return ret, myIP, false } lineNum := 0 var f []mem.RO - err := lineread.File(procNetRoutePath, func(line []byte) error { + for lr := range lineiter.File(procNetRoutePath) { + line, err := lr.Value() + if err != nil { + procNetRouteErr.Store(true) + log.Printf("interfaces: failed to read /proc/net/route: %v", err) + return ret, myIP, false + } lineNum++ if lineNum == 1 { // Skip header line. - return nil + continue } if lineNum > maxProcNetRouteRead { - return errStopReading + break } f = mem.AppendFields(f[:0], mem.B(line)) if len(f) < 4 { - return nil + continue } gwHex, flagsHex := f[2], f[3] flags, err := mem.ParseUint(flagsHex, 16, 16) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } if flags&(unix.RTF_UP|unix.RTF_GATEWAY) != unix.RTF_UP|unix.RTF_GATEWAY { - return nil + continue } ipu32, err := mem.ParseUint(gwHex, 16, 32) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24)) if ip.IsPrivate() { ret = ip - return errStopReading + break } - return nil - }) - if errors.Is(err, errStopReading) { - err = nil - } - if err != nil { - procNetRouteErr.Store(true) - log.Printf("interfaces: failed to read /proc/net/route: %v", err) } if ret.IsValid() { // Try to get the local IP of the interface associated with diff --git a/net/netmon/interfaces_test.go b/net/netmon/interfaces_test.go index edd4f6d6e..e4274819f 100644 --- a/net/netmon/interfaces_test.go +++ b/net/netmon/interfaces_test.go @@ -13,7 +13,7 @@ import ( ) func TestGetState(t *testing.T) { - st, err := GetState() + st, err := getState("") if err != nil { t.Fatal(err) } diff --git a/net/netmon/interfaces_windows.go b/net/netmon/interfaces_windows.go index 00b686e59..d6625ead3 100644 --- a/net/netmon/interfaces_windows.go +++ b/net/netmon/interfaces_windows.go @@ -13,6 +13,7 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" + "tailscale.com/feature/buildfeatures" "tailscale.com/tsconst" ) @@ -22,7 +23,9 @@ const ( func init() { likelyHomeRouterIP = likelyHomeRouterIPWindows - getPAC = getPACWindows + if buildfeatures.HasUseProxy { + getPAC = getPACWindows + } } func likelyHomeRouterIPWindows() (ret netip.Addr, _ netip.Addr, ok bool) { @@ -244,6 +247,9 @@ const ( ) func getPACWindows() string { + if !buildfeatures.HasUseProxy { + return "" + } var res *uint16 r, _, e := detectAutoProxyConfigURL.Call( winHTTP_AUTO_DETECT_TYPE_DHCP|winHTTP_AUTO_DETECT_TYPE_DNS_A, diff --git a/net/netmon/loghelper.go b/net/netmon/loghelper.go new file mode 100644 index 000000000..675762cd1 --- /dev/null +++ b/net/netmon/loghelper.go @@ -0,0 +1,44 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "context" + "sync" + + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" +) + +// LinkChangeLogLimiter returns a new [logger.Logf] that logs each unique +// format string to the underlying logger only once per major LinkChange event. +// +// The logger stops tracking seen format strings when the provided context is +// done. +func LinkChangeLogLimiter(ctx context.Context, logf logger.Logf, nm *Monitor) logger.Logf { + var formatSeen sync.Map // map[string]bool + sub := eventbus.SubscribeFunc(nm.b, func(cd ChangeDelta) { + // If we're in a major change or a time jump, clear the seen map. + if cd.Major || cd.TimeJumped { + formatSeen.Clear() + } + }) + context.AfterFunc(ctx, sub.Close) + return func(format string, args ...any) { + // We only store 'true' in the map, so if it's present then it + // means we've already logged this format string. + _, loaded := formatSeen.LoadOrStore(format, true) + if loaded { + // TODO(andrew-d): we may still want to log this + // message every N minutes (1x/hour?) even if it's been + // seen, so that debugging doesn't require searching + // back in the logs for an unbounded amount of time. + // + // See: https://github.com/tailscale/tailscale/issues/13145 + return + } + + logf(format, args...) + } +} diff --git a/net/netmon/loghelper_test.go b/net/netmon/loghelper_test.go new file mode 100644 index 000000000..ca3b1284c --- /dev/null +++ b/net/netmon/loghelper_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "bytes" + "context" + "fmt" + "testing" + "testing/synctest" + + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" +) + +func TestLinkChangeLogLimiter(t *testing.T) { synctest.Test(t, syncTestLinkChangeLogLimiter) } + +func syncTestLinkChangeLogLimiter(t *testing.T) { + bus := eventbus.New() + defer bus.Close() + mon, err := New(bus, t.Logf) + if err != nil { + t.Fatal(err) + } + defer mon.Close() + + var logBuffer bytes.Buffer + logf := func(format string, args ...any) { + t.Logf("captured log: "+format, args...) + + if format[len(format)-1] != '\n' { + format += "\n" + } + fmt.Fprintf(&logBuffer, format, args...) + } + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + logf = LinkChangeLogLimiter(ctx, logf, mon) + + // Log once, which should write to our log buffer. + logf("hello %s", "world") + if got := logBuffer.String(); got != "hello world\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Log again, which should not write to our log buffer. + logf("hello %s", "andrew") + if got := logBuffer.String(); got != "hello world\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Log a different message, which should write to our log buffer. + logf("other message") + if got := logBuffer.String(); got != "hello world\nother message\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Synthesize a fake major change event, which should clear the format + // string cache and allow the next log to write to our log buffer. + // + // InjectEvent doesn't work because it's not a major event, so we + // instead inject the event ourselves. + injector := eventbustest.NewInjector(t, bus) + eventbustest.Inject(injector, ChangeDelta{Major: true}) + synctest.Wait() + + logf("hello %s", "world") + want := "hello world\nother message\nhello world\n" + if got := logBuffer.String(); got != want { + t.Errorf("unexpected log buffer contents, got: %q, want, %q", got, want) + } + + // Canceling the context we passed to LinkChangeLogLimiter should + // unregister the callback from the netmon. + cancel() + synctest.Wait() + + mon.mu.Lock() + if len(mon.cbs) != 0 { + t.Errorf("expected no callbacks, got %v", mon.cbs) + } + mon.mu.Unlock() +} diff --git a/net/netmon/netmon.go b/net/netmon/netmon.go index 47b540d6a..657da04d5 100644 --- a/net/netmon/netmon.go +++ b/net/netmon/netmon.go @@ -14,8 +14,11 @@ import ( "sync" "time" + "tailscale.com/feature/buildfeatures" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/set" ) @@ -50,7 +53,10 @@ type osMon interface { // Monitor represents a monitoring instance. type Monitor struct { - logf logger.Logf + logf logger.Logf + b *eventbus.Client + changed *eventbus.Publisher[ChangeDelta] + om osMon // nil means not supported on this platform change chan bool // send false to wake poller, true to also force ChangeDeltas be sent stop chan struct{} // closed on Stop @@ -60,9 +66,8 @@ type Monitor struct { // and not change at runtime. tsIfName string // tailscale interface name, if known/set ("tailscale0", "utun3", ...) - mu sync.Mutex // guards all following fields + mu syncs.Mutex // guards all following fields cbs set.HandleSet[ChangeFunc] - ruleDelCB set.HandleSet[RuleDeleteCallback] ifState *State gwValid bool // whether gw and gwSelfIP are valid gw netip.Addr // our gateway's IP @@ -81,9 +86,6 @@ type ChangeFunc func(*ChangeDelta) // ChangeDelta describes the difference between two network states. type ChangeDelta struct { - // Monitor is the network monitor that sent this delta. - Monitor *Monitor - // Old is the old interface state, if known. // It's nil if the old state is unknown. // Do not mutate it. @@ -114,21 +116,23 @@ type ChangeDelta struct { // New instantiates and starts a monitoring instance. // The returned monitor is inactive until it's started by the Start method. // Use RegisterChangeCallback to get notified of network changes. -func New(logf logger.Logf) (*Monitor, error) { +func New(bus *eventbus.Bus, logf logger.Logf) (*Monitor, error) { logf = logger.WithPrefix(logf, "monitor: ") m := &Monitor{ logf: logf, + b: bus.Client("netmon"), change: make(chan bool, 1), stop: make(chan struct{}), lastWall: wallTime(), } + m.changed = eventbus.Publish[ChangeDelta](m.b) st, err := m.interfaceStateUncached() if err != nil { return nil, err } m.ifState = st - m.om, err = newOSMon(logf, m) + m.om, err = newOSMon(bus, logf, m) if err != nil { return nil, err } @@ -161,7 +165,7 @@ func (m *Monitor) InterfaceState() *State { } func (m *Monitor) interfaceStateUncached() (*State, error) { - return GetState() + return getState(m.tsIfName) } // SetTailscaleInterfaceName sets the name of the Tailscale interface. For @@ -179,6 +183,9 @@ func (m *Monitor) SetTailscaleInterfaceName(ifName string) { // It's the same as interfaces.LikelyHomeRouterIP, but it caches the // result until the monitor detects a network change. func (m *Monitor) GatewayAndSelfIP() (gw, myIP netip.Addr, ok bool) { + if !buildfeatures.HasPortMapper { + return + } if m.static { return } @@ -218,29 +225,6 @@ func (m *Monitor) RegisterChangeCallback(callback ChangeFunc) (unregister func() } } -// RuleDeleteCallback is a callback when a Linux IP policy routing -// rule is deleted. The table is the table number (52, 253, 354) and -// priority is the priority order number (for Tailscale rules -// currently: 5210, 5230, 5250, 5270) -type RuleDeleteCallback func(table uint8, priority uint32) - -// RegisterRuleDeleteCallback adds callback to the set of parties to be -// notified (in their own goroutine) when a Linux ip rule is deleted. -// To remove this callback, call unregister (or close the monitor). -func (m *Monitor) RegisterRuleDeleteCallback(callback RuleDeleteCallback) (unregister func()) { - if m.static { - return func() {} - } - m.mu.Lock() - defer m.mu.Unlock() - handle := m.ruleDelCB.Add(callback) - return func() { - m.mu.Lock() - defer m.mu.Unlock() - delete(m.ruleDelCB, handle) - } -} - // Start starts the monitor. // A monitor can only be started & closed once. func (m *Monitor) Start() { @@ -353,10 +337,6 @@ func (m *Monitor) pump() { time.Sleep(time.Second) continue } - if rdm, ok := msg.(ipRuleDeletedMessage); ok { - m.notifyRuleDeleted(rdm) - continue - } if msg.ignore() { continue } @@ -364,14 +344,6 @@ func (m *Monitor) pump() { } } -func (m *Monitor) notifyRuleDeleted(rdm ipRuleDeletedMessage) { - m.mu.Lock() - defer m.mu.Unlock() - for _, cb := range m.ruleDelCB { - go cb(rdm.table, rdm.priority) - } -} - // isInterestingInterface reports whether the provided interface should be // considered when checking for network state changes. // The ips parameter should be the IPs of the provided interface. @@ -431,8 +403,7 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { return } - delta := &ChangeDelta{ - Monitor: m, + delta := ChangeDelta{ Old: oldState, New: newState, TimeJumped: timeJumped, @@ -441,7 +412,6 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { delta.Major = m.IsMajorChangeFrom(oldState, newState) if delta.Major { m.gwValid = false - m.ifState = newState if s1, s2 := oldState.String(), delta.New.String(); s1 == s2 { m.logf("[unexpected] network state changed, but stringification didn't: %v", s1) @@ -449,6 +419,7 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { m.logf("[unexpected] new: %s", jsonSummary(newState)) } } + m.ifState = newState // See if we have a queued or new time jump signal. if timeJumped { m.resetTimeJumpedLocked() @@ -465,8 +436,9 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { if delta.TimeJumped { metricChangeTimeJump.Add(1) } + m.changed.Publish(delta) for _, cb := range m.cbs { - go cb(delta) + go cb(&delta) } } @@ -596,7 +568,7 @@ func (m *Monitor) pollWallTime() { // // We don't do this on mobile platforms for battery reasons, and because these // platforms don't really sleep in the same way. -const shouldMonitorTimeJump = runtime.GOOS != "android" && runtime.GOOS != "ios" +const shouldMonitorTimeJump = runtime.GOOS != "android" && runtime.GOOS != "ios" && runtime.GOOS != "plan9" // checkWallTimeAdvanceLocked reports whether wall time jumped more than 150% of // pollWallTimeInterval, indicating we probably just came out of sleep. Once a @@ -617,10 +589,3 @@ func (m *Monitor) checkWallTimeAdvanceLocked() bool { func (m *Monitor) resetTimeJumpedLocked() { m.timeJumped = false } - -type ipRuleDeletedMessage struct { - table uint8 - priority uint32 -} - -func (ipRuleDeletedMessage) ignore() bool { return true } diff --git a/net/netmon/netmon_darwin.go b/net/netmon/netmon_darwin.go index cc6301125..9c5e76475 100644 --- a/net/netmon/netmon_darwin.go +++ b/net/netmon/netmon_darwin.go @@ -13,6 +13,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/net/netaddr" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) const debugRouteMessages = false @@ -24,7 +25,7 @@ type unspecifiedMessage struct{} func (unspecifiedMessage) ignore() bool { return false } -func newOSMon(logf logger.Logf, _ *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, _ *Monitor) (osMon, error) { fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) if err != nil { return nil, err @@ -56,7 +57,19 @@ func (m *darwinRouteMon) Receive() (message, error) { if err != nil { return nil, err } - msgs, err := route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) + msgs, err := func() (msgs []route.Message, err error) { + defer func() { + // #14201: permanent panic protection, as we have been burned by + // ParseRIB panics too many times. + msg := recover() + if msg != nil { + msgs = nil + m.logf("[unexpected] netmon: panic in route.ParseRIB from % 02x", m.buf[:n]) + err = fmt.Errorf("panic in route.ParseRIB: %s", msg) + } + }() + return route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) + }() if err != nil { if debugRouteMessages { m.logf("read %d bytes (% 02x), failed to parse RIB: %v", n, m.buf[:n], err) diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go index 30480a1d3..842cbdb0d 100644 --- a/net/netmon/netmon_freebsd.go +++ b/net/netmon/netmon_freebsd.go @@ -10,6 +10,7 @@ import ( "strings" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) // unspecifiedMessage is a minimal message implementation that should not @@ -24,7 +25,7 @@ type devdConn struct { conn net.Conn } -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) { conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") if err != nil { logf("devd dial error: %v, falling back to polling method", err) diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go index dd23dd342..a1077c257 100644 --- a/net/netmon/netmon_linux.go +++ b/net/netmon/netmon_linux.go @@ -16,6 +16,7 @@ import ( "tailscale.com/envknob" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") @@ -27,15 +28,26 @@ type unspecifiedMessage struct{} func (unspecifiedMessage) ignore() bool { return false } +// RuleDeleted reports that one of Tailscale's policy routing rules +// was deleted. +type RuleDeleted struct { + // Table is the table number that the deleted rule referenced. + Table uint8 + // Priority is the lookup priority of the deleted rule. + Priority uint32 +} + // nlConn wraps a *netlink.Conn and returns a monitor.Message // instead of a netlink.Message. Currently, messages are discarded, // but down the line, when messages trigger different logic depending // on the type of event, this provides the capability of handling // each architecture-specific message in a generic fashion. type nlConn struct { - logf logger.Logf - conn *netlink.Conn - buffered []netlink.Message + busClient *eventbus.Client + rulesDeleted *eventbus.Publisher[RuleDeleted] + logf logger.Logf + conn *netlink.Conn + buffered []netlink.Message // addrCache maps interface indices to a set of addresses, and is // used to suppress duplicate RTM_NEWADDR messages. It is populated @@ -44,7 +56,7 @@ type nlConn struct { addrCache map[uint32]map[netip.Addr]bool } -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { +func newOSMon(bus *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) { conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ // Routes get us most of the events of interest, but we need // address as well to cover things like DHCP deciding to give @@ -59,12 +71,22 @@ func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") return newPollingMon(logf, m) } - return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil + client := bus.Client("netmon-iprules") + return &nlConn{ + busClient: client, + rulesDeleted: eventbus.Publish[RuleDeleted](client), + logf: logf, + conn: conn, + addrCache: make(map[uint32]map[netip.Addr]bool), + }, nil } func (c *nlConn) IsInterestingInterface(iface string) bool { return true } -func (c *nlConn) Close() error { return c.conn.Close() } +func (c *nlConn) Close() error { + c.busClient.Close() + return c.conn.Close() +} func (c *nlConn) Receive() (message, error) { if len(c.buffered) == 0 { @@ -219,14 +241,15 @@ func (c *nlConn) Receive() (message, error) { // On `ip -4 rule del pref 5210 table main`, logs: // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} } - rdm := ipRuleDeletedMessage{ - table: rmsg.Table, - priority: rmsg.Attributes.Priority, + rd := RuleDeleted{ + Table: rmsg.Table, + Priority: rmsg.Attributes.Priority, } + c.rulesDeleted.Publish(rd) if debugNetlinkMessages() { - c.logf("%+v", rdm) + c.logf("%+v", rd) } - return rdm, nil + return ignoreMessage{}, nil case unix.RTM_NEWLINK, unix.RTM_DELLINK: // This is an unhandled message, but don't print an error. // See https://github.com/tailscale/tailscale/issues/6806 diff --git a/net/netmon/netmon_linux_test.go b/net/netmon/netmon_linux_test.go index d09fac26a..75d7c6465 100644 --- a/net/netmon/netmon_linux_test.go +++ b/net/netmon/netmon_linux_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package netmon import ( diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go index 3d6f94731..3b5ef6fe9 100644 --- a/net/netmon/netmon_polling.go +++ b/net/netmon/netmon_polling.go @@ -7,9 +7,10 @@ package netmon import ( "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) { return newPollingMon(logf, m) } diff --git a/net/netmon/netmon_test.go b/net/netmon/netmon_test.go index ce55d1946..6a87cedb8 100644 --- a/net/netmon/netmon_test.go +++ b/net/netmon/netmon_test.go @@ -7,15 +7,21 @@ import ( "flag" "net" "net/netip" + "reflect" "sync/atomic" "testing" "time" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" ) func TestMonitorStartClose(t *testing.T) { - mon, err := New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -26,7 +32,10 @@ func TestMonitorStartClose(t *testing.T) { } func TestMonitorJustClose(t *testing.T) { - mon, err := New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -36,7 +45,10 @@ func TestMonitorJustClose(t *testing.T) { } func TestMonitorInjectEvent(t *testing.T) { - mon, err := New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -58,6 +70,23 @@ func TestMonitorInjectEvent(t *testing.T) { } } +func TestMonitorInjectEventOnBus(t *testing.T) { + bus := eventbustest.NewBus(t) + + mon, err := New(bus, t.Logf) + if err != nil { + t.Fatal(err) + } + defer mon.Close() + tw := eventbustest.NewWatcher(t, bus) + + mon.Start() + mon.InjectEvent() + if err := eventbustest.Expect(tw, eventbustest.Type[ChangeDelta]()); err != nil { + t.Error(err) + } +} + var ( monitor = flag.String("monitor", "", `go into monitor mode like 'route monitor'; test never terminates. Value can be either "raw" or "callback"`) monitorDuration = flag.Duration("monitor-duration", 0, "if non-zero, how long to run TestMonitorMode. Zero means forever.") @@ -67,11 +96,15 @@ func TestMonitorMode(t *testing.T) { switch *monitor { case "": t.Skip("skipping non-test without --monitor") - case "raw", "callback": + case "raw", "callback", "eventbus": default: - t.Skipf(`invalid --monitor value: must be "raw" or "callback"`) + t.Skipf(`invalid --monitor value: must be "raw", "callback" or "eventbus"`) } - mon, err := New(t.Logf) + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -110,6 +143,16 @@ func TestMonitorMode(t *testing.T) { mon.Start() <-done t.Logf("%v callbacks", n) + case "eventbus": + time.AfterFunc(*monitorDuration, bus.Close) + n := 0 + mon.Start() + eventbustest.Expect(tw, func(event *ChangeDelta) (bool, error) { + n++ + t.Logf("cb: changed=%v, ifSt=%v", event.Major, event.New) + return false, nil // Return false, indicating we wanna look for more events + }) + t.Logf("%v events", n) } } @@ -225,6 +268,45 @@ func TestIsMajorChangeFrom(t *testing.T) { }) } } +func TestForeachInterface(t *testing.T) { + tests := []struct { + name string + addrs []net.Addr + want []string + }{ + { + name: "Mixed_IPv4_and_IPv6", + addrs: []net.Addr{ + &net.IPNet{IP: net.IPv4(1, 2, 3, 4), Mask: net.CIDRMask(24, 32)}, + &net.IPAddr{IP: net.IP{5, 6, 7, 8}, Zone: ""}, + &net.IPNet{IP: net.ParseIP("2001:db8::1"), Mask: net.CIDRMask(64, 128)}, + &net.IPAddr{IP: net.ParseIP("2001:db8::2"), Zone: ""}, + }, + want: []string{"1.2.3.4", "5.6.7.8", "2001:db8::1", "2001:db8::2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got []string + ifaces := InterfaceList{ + { + Interface: &net.Interface{Name: "eth0"}, + AltAddrs: tt.addrs, + }, + } + ifaces.ForeachInterface(func(iface Interface, prefixes []netip.Prefix) { + for _, prefix := range prefixes { + ip := prefix.Addr() + got = append(got, ip.String()) + } + }) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} type testOSMon struct { osMon diff --git a/net/netmon/netmon_windows.go b/net/netmon/netmon_windows.go index ddf13a2e4..718724b6d 100644 --- a/net/netmon/netmon_windows.go +++ b/net/netmon/netmon_windows.go @@ -13,6 +13,7 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) var ( @@ -45,7 +46,7 @@ type winMon struct { noDeadlockTicker *time.Ticker } -func newOSMon(logf logger.Logf, pm *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, pm *Monitor) (osMon, error) { m := &winMon{ logf: logf, isActive: pm.isActive, diff --git a/net/netmon/state.go b/net/netmon/state.go index d9b360f5e..27e3524e8 100644 --- a/net/netmon/state.go +++ b/net/netmon/state.go @@ -15,12 +15,19 @@ import ( "strings" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/hostinfo" "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" - "tailscale.com/net/tshttpproxy" + "tailscale.com/util/mak" ) +// forceAllIPv6Endpoints is a debug knob that when set forces the client to +// report all IPv6 endpoints rather than trim endpoints that are siblings on the +// same interface and subnet. +var forceAllIPv6Endpoints = envknob.RegisterBool("TS_DEBUG_FORCE_ALL_IPV6_ENDPOINTS") + // LoginEndpointForProxyDetermination is the URL used for testing // which HTTP proxy the system should use. var LoginEndpointForProxyDetermination = "https://controlplane.tailscale.com/" @@ -65,6 +72,7 @@ func LocalAddresses() (regular, loopback []netip.Addr, err error) { if err != nil { return nil, nil, err } + var subnets map[netip.Addr]int for _, a := range addrs { switch v := a.(type) { case *net.IPNet: @@ -102,7 +110,15 @@ func LocalAddresses() (regular, loopback []netip.Addr, err error) { if ip.Is4() { regular4 = append(regular4, ip) } else { - regular6 = append(regular6, ip) + curMask, _ := netip.AddrFromSlice(v.IP.Mask(v.Mask)) + // Limit the number of addresses reported per subnet for + // IPv6, as we have seen some nodes with extremely large + // numbers of assigned addresses being carved out of + // same-subnet allocations. + if forceAllIPv6Endpoints() || subnets[curMask] < 2 { + regular6 = append(regular6, ip) + } + mak.Set(&subnets, curMask, subnets[curMask]+1) } } } @@ -167,6 +183,10 @@ func (ifaces InterfaceList) ForeachInterfaceAddress(fn func(Interface, netip.Pre if pfx, ok := netaddr.FromStdIPNet(v); ok { fn(iface, pfx) } + case *net.IPAddr: + if ip, ok := netip.AddrFromSlice(v.IP); ok { + fn(iface, netip.PrefixFrom(ip, ip.BitLen())) + } } } } @@ -199,6 +219,10 @@ func (ifaces InterfaceList) ForeachInterface(fn func(Interface, []netip.Prefix)) if pfx, ok := netaddr.FromStdIPNet(v); ok { pfxs = append(pfxs, pfx) } + case *net.IPAddr: + if ip, ok := netip.AddrFromSlice(v.IP); ok { + pfxs = append(pfxs, netip.PrefixFrom(ip, ip.BitLen())) + } } } sort.Slice(pfxs, func(i, j int) bool { @@ -446,21 +470,22 @@ func isTailscaleInterface(name string, ips []netip.Prefix) bool { // getPAC, if non-nil, returns the current PAC file URL. var getPAC func() string -// GetState returns the state of all the current machine's network interfaces. +// getState returns the state of all the current machine's network interfaces. // // It does not set the returned State.IsExpensive. The caller can populate that. // -// Deprecated: use netmon.Monitor.InterfaceState instead. -func GetState() (*State, error) { +// optTSInterfaceName is the name of the Tailscale interface, if known. +func getState(optTSInterfaceName string) (*State, error) { s := &State{ InterfaceIPs: make(map[string][]netip.Prefix), Interface: make(map[string]Interface), } if err := ForeachInterface(func(ni Interface, pfxs []netip.Prefix) { + isTSInterfaceName := optTSInterfaceName != "" && ni.Name == optTSInterfaceName ifUp := ni.IsUp() s.Interface[ni.Name] = ni s.InterfaceIPs[ni.Name] = append(s.InterfaceIPs[ni.Name], pfxs...) - if !ifUp || isTailscaleInterface(ni.Name, pfxs) { + if !ifUp || isTSInterfaceName || isTailscaleInterface(ni.Name, pfxs) { return } for _, pfx := range pfxs { @@ -485,13 +510,15 @@ func GetState() (*State, error) { } } - if s.AnyInterfaceUp() { + if buildfeatures.HasUseProxy && s.AnyInterfaceUp() { req, err := http.NewRequest("GET", LoginEndpointForProxyDetermination, nil) if err != nil { return nil, err } - if u, err := tshttpproxy.ProxyFromEnvironment(req); err == nil && u != nil { - s.HTTPProxy = u.String() + if proxyFromEnv, ok := feature.HookProxyFromEnvironment.GetOk(); ok { + if u, err := proxyFromEnv(req); err == nil && u != nil { + s.HTTPProxy = u.String() + } } if getPAC != nil { s.PAC = getPAC() @@ -554,6 +581,9 @@ var disableLikelyHomeRouterIPSelf = envknob.RegisterBool("TS_DEBUG_DISABLE_LIKEL // the LAN using that gateway. // This is used as the destination for UPnP, NAT-PMP, PCP, etc queries. func LikelyHomeRouterIP() (gateway, myIP netip.Addr, ok bool) { + if !buildfeatures.HasPortMapper { + return + } // If we don't have a way to get the home router IP, then we can't do // anything; just return. if likelyHomeRouterIP == nil { @@ -740,11 +770,12 @@ func DefaultRoute() (DefaultRouteDetails, error) { // HasCGNATInterface reports whether there are any non-Tailscale interfaces that // use a CGNAT IP range. -func HasCGNATInterface() (bool, error) { +func (m *Monitor) HasCGNATInterface() (bool, error) { hasCGNATInterface := false cgnatRange := tsaddr.CGNATRange() err := ForeachInterface(func(i Interface, pfxs []netip.Prefix) { - if hasCGNATInterface || !i.IsUp() || isTailscaleInterface(i.Name, pfxs) { + isTSInterfaceName := m.tsIfName != "" && i.Name == m.tsIfName + if hasCGNATInterface || !i.IsUp() || isTSInterfaceName || isTailscaleInterface(i.Name, pfxs) { return } for _, pfx := range pfxs { diff --git a/net/netns/netns.go b/net/netns/netns.go index a473506fa..81ab5e2a2 100644 --- a/net/netns/netns.go +++ b/net/netns/netns.go @@ -17,6 +17,7 @@ import ( "context" "net" "net/netip" + "runtime" "sync/atomic" "tailscale.com/net/netknob" @@ -39,18 +40,36 @@ var bindToInterfaceByRoute atomic.Bool // setting the TS_BIND_TO_INTERFACE_BY_ROUTE. // // Currently, this only changes the behaviour on macOS and Windows. -func SetBindToInterfaceByRoute(v bool) { - bindToInterfaceByRoute.Store(v) +func SetBindToInterfaceByRoute(logf logger.Logf, v bool) { + if bindToInterfaceByRoute.Swap(v) != v { + logf("netns: bindToInterfaceByRoute changed to %v", v) + } } var disableBindConnToInterface atomic.Bool // SetDisableBindConnToInterface disables the (normal) behavior of binding -// connections to the default network interface. +// connections to the default network interface on Darwin nodes. // -// Currently, this only has an effect on Darwin. -func SetDisableBindConnToInterface(v bool) { - disableBindConnToInterface.Store(v) +// Unless you intended to disable this for tailscaled on macos (which is likely +// to break things), you probably wanted to set +// SetDisableBindConnToInterfaceAppleExt which will disable explicit interface +// binding only when tailscaled is running inside a network extension process. +func SetDisableBindConnToInterface(logf logger.Logf, v bool) { + if disableBindConnToInterface.Swap(v) != v { + logf("netns: disableBindConnToInterface changed to %v", v) + } +} + +var disableBindConnToInterfaceAppleExt atomic.Bool + +// SetDisableBindConnToInterfaceAppleExt disables the (normal) behavior of binding +// connections to the default network interface but only on Apple clients where +// tailscaled is running inside a network extension. +func SetDisableBindConnToInterfaceAppleExt(logf logger.Logf, v bool) { + if runtime.GOOS == "darwin" && disableBindConnToInterfaceAppleExt.Swap(v) != v { + logf("netns: disableBindConnToInterfaceAppleExt changed to %v", v) + } } // Listener returns a new net.Listener with its Control hook func diff --git a/net/netns/netns_darwin.go b/net/netns/netns_darwin.go index ac5e89d76..ff05a3f31 100644 --- a/net/netns/netns_darwin.go +++ b/net/netns/netns_darwin.go @@ -21,6 +21,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/version" ) func control(logf logger.Logf, netMon *netmon.Monitor) func(network, address string, c syscall.RawConn) error { @@ -33,18 +34,14 @@ var bindToInterfaceByRouteEnv = envknob.RegisterBool("TS_BIND_TO_INTERFACE_BY_RO var errInterfaceStateInvalid = errors.New("interface state invalid") -// controlLogf marks c as necessary to dial in a separate network namespace. -// -// It's intentionally the same signature as net.Dialer.Control -// and net.ListenConfig.Control. +// controlLogf binds c to a particular interface as necessary to dial the +// provided (network, address). func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address string, c syscall.RawConn) error { - if isLocalhost(address) { - // Don't bind to an interface for localhost connections. + if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) { return nil } - if disableBindConnToInterface.Load() { - logf("netns_darwin: binding connection to interfaces disabled") + if isLocalhost(address) { return nil } @@ -78,10 +75,38 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) return -1, errInterfaceStateInvalid } - if iface, ok := state.Interface[state.DefaultRouteInterface]; ok { - return iface.Index, nil + // Netmon's cached view of the default inteface + cachedIdx, ok := state.Interface[state.DefaultRouteInterface] + // OSes view (if available) of the default interface + osIf, osIferr := netmon.OSDefaultRoute() + + idx := -1 + errOut := errInterfaceStateInvalid + // Preferentially choose the OS's view of the default if index. Due to the way darwin sets the delegated + // interface on tunnel creation only, it is possible for netmon to have a stale view of the default and + // netmon's view is often temporarily wrong during network transitions, or for us to not have the + // the the oses view of the defaultIf yet. + if osIferr == nil { + idx = osIf.InterfaceIndex + errOut = nil + } else if ok { + idx = cachedIdx.Index + errOut = nil + } + + if osIferr == nil && ok && (osIf.InterfaceIndex != cachedIdx.Index) { + logf("netns: [unexpected] os default if %q (%d) != netmon cached if %q (%d)", osIf.InterfaceName, osIf.InterfaceIndex, cachedIdx.Name, cachedIdx.Index) } - return -1, errInterfaceStateInvalid + + // Sanity check to make sure we didn't pick the tailscale interface + if tsif, err2 := tailscaleInterface(); tsif != nil && err2 == nil && errOut == nil { + if tsif.Index == idx { + idx = -1 + errOut = errInterfaceStateInvalid + } + } + + return idx, errOut } useRoute := bindToInterfaceByRoute.Load() || bindToInterfaceByRouteEnv() @@ -100,7 +125,7 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) idx, err := interfaceIndexFor(addr, true /* canRecurse */) if err != nil { - logf("netns: error in interfaceIndexFor: %v", err) + logf("netns: error getting interface index for %q: %v", address, err) return defaultIdx() } @@ -108,10 +133,13 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) // if so, we fall back to binding from the default. tsif, err2 := tailscaleInterface() if err2 == nil && tsif != nil && tsif.Index == idx { - logf("[unexpected] netns: interfaceIndexFor returned Tailscale interface") + // note: with an exit node enabled, this is almost always true. defaultIdx() is the + // right thing to do here. return defaultIdx() } + logf("netns: completed success interfaceIndexFor(%s) = %d", address, idx) + return idx, err } diff --git a/net/netns/netns_default.go b/net/netns/netns_default.go index 94f24d8fa..58c593664 100644 --- a/net/netns/netns_default.go +++ b/net/netns/netns_default.go @@ -20,3 +20,7 @@ func control(logger.Logf, *netmon.Monitor) func(network, address string, c sysca func controlC(network, address string, c syscall.RawConn) error { return nil } + +func UseSocketMark() bool { + return false +} diff --git a/net/netns/netns_dw.go b/net/netns/netns_dw.go index f92ba9462..b9f750e8a 100644 --- a/net/netns/netns_dw.go +++ b/net/netns/netns_dw.go @@ -25,3 +25,7 @@ func parseAddress(address string) (addr netip.Addr, err error) { return netip.ParseAddr(host) } + +func UseSocketMark() bool { + return false +} diff --git a/net/netns/netns_linux.go b/net/netns/netns_linux.go index aaf6dab4a..609f524b5 100644 --- a/net/netns/netns_linux.go +++ b/net/netns/netns_linux.go @@ -15,8 +15,8 @@ import ( "golang.org/x/sys/unix" "tailscale.com/envknob" "tailscale.com/net/netmon" + "tailscale.com/tsconst" "tailscale.com/types/logger" - "tailscale.com/util/linuxfw" ) // socketMarkWorksOnce is the sync.Once & cached value for useSocketMark. @@ -111,7 +111,7 @@ func controlC(network, address string, c syscall.RawConn) error { } func setBypassMark(fd uintptr) error { - if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, linuxfw.TailscaleBypassMarkNum); err != nil { + if err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_MARK, tsconst.LinuxBypassMarkNum); err != nil { return fmt.Errorf("setting SO_MARK bypass: %w", err) } return nil diff --git a/net/netns/socks.go b/net/netns/socks.go index eea69d865..9a137db7f 100644 --- a/net/netns/socks.go +++ b/net/netns/socks.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ios && !js +//go:build !ios && !js && !android && !ts_omit_useproxy package netns diff --git a/net/netns/zsyscall_windows.go b/net/netns/zsyscall_windows.go index 07e2181be..3d8f06e09 100644 --- a/net/netns/zsyscall_windows.go +++ b/net/netns/zsyscall_windows.go @@ -45,7 +45,7 @@ var ( ) func getBestInterfaceEx(sockaddr *winipcfg.RawSockaddrInet, bestIfaceIndex *uint32) (ret error) { - r0, _, _ := syscall.Syscall(procGetBestInterfaceEx.Addr(), 2, uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex)), 0) + r0, _, _ := syscall.SyscallN(procGetBestInterfaceEx.Addr(), uintptr(unsafe.Pointer(sockaddr)), uintptr(unsafe.Pointer(bestIfaceIndex))) if r0 != 0 { ret = syscall.Errno(r0) } diff --git a/net/netutil/ip_forward.go b/net/netutil/ip_forward.go index 48cee68ea..c64a9e426 100644 --- a/net/netutil/ip_forward.go +++ b/net/netutil/ip_forward.go @@ -63,6 +63,11 @@ func CheckIPForwarding(routes []netip.Prefix, state *netmon.State) (warn, err er switch runtime.GOOS { case "dragonfly", "freebsd", "netbsd", "openbsd": return fmt.Errorf("Subnet routing and exit nodes only work with additional manual configuration on %v, and is not currently officially supported.", runtime.GOOS), nil + case "illumos", "solaris": + _, err := ipForwardingEnabledSunOS(ipv4, "") + if err != nil { + return nil, fmt.Errorf("Couldn't check system's IP forwarding configuration, subnet routing/exit nodes may not work: %w%s", err, "") + } } return nil, nil } @@ -325,3 +330,24 @@ func reversePathFilterValueLinux(iface string) (int, error) { } return v, nil } + +func ipForwardingEnabledSunOS(p protocol, iface string) (bool, error) { + var proto string + if p == ipv4 { + proto = "ipv4" + } else if p == ipv6 { + proto = "ipv6" + } else { + return false, fmt.Errorf("unknown protocol") + } + + ipadmCmd := "\"ipadm show-prop " + proto + " -p forwarding -o CURRENT -c\"" + bs, err := exec.Command("ipadm", "show-prop", proto, "-p", "forwarding", "-o", "CURRENT", "-c").Output() + if err != nil { + return false, fmt.Errorf("couldn't check %s (%v).\nSubnet routes won't work without IP forwarding.", ipadmCmd, err) + } + if string(bs) != "on\n" { + return false, fmt.Errorf("IP forwarding is set to off. Subnet routes won't work. Try 'routeadm -u -e %s-forwarding'", proto) + } + return true, nil +} diff --git a/net/netutil/netutil.go b/net/netutil/netutil.go index bc64e8fdc..5c42f51c6 100644 --- a/net/netutil/netutil.go +++ b/net/netutil/netutil.go @@ -8,7 +8,8 @@ import ( "bufio" "io" "net" - "sync" + + "tailscale.com/syncs" ) // NewOneConnListener returns a net.Listener that returns c on its @@ -29,7 +30,7 @@ func NewOneConnListener(c net.Conn, addr net.Addr) net.Listener { type oneConnListener struct { addr net.Addr - mu sync.Mutex + mu syncs.Mutex conn net.Conn } diff --git a/net/netutil/netutil_test.go b/net/netutil/netutil_test.go index fdc26b02f..0523946e6 100644 --- a/net/netutil/netutil_test.go +++ b/net/netutil/netutil_test.go @@ -10,6 +10,7 @@ import ( "testing" "tailscale.com/net/netmon" + "tailscale.com/util/eventbus" ) type conn struct { @@ -72,7 +73,10 @@ func TestCheckReversePathFiltering(t *testing.T) { if runtime.GOOS != "linux" { t.Skipf("skipping on %s", runtime.GOOS) } - netMon, err := netmon.New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, t.Logf) if err != nil { t.Fatal(err) } diff --git a/net/netx/netx.go b/net/netx/netx.go new file mode 100644 index 000000000..014daa9a7 --- /dev/null +++ b/net/netx/netx.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netx contains types to describe and abstract over how dialing and +// listening are performed. +package netx + +import ( + "context" + "fmt" + "net" +) + +// DialFunc is a function that dials a network address. +// +// It's the type implemented by net.Dialer.DialContext or required +// by net/http.Transport.DialContext, etc. +type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// Network describes a network that can listen and dial. The two common +// implementations are [RealNetwork], using the net package to use the real +// network, or [memnet.Network], using an in-memory network (typically for testing) +type Network interface { + NewLocalTCPListener() net.Listener + Listen(network, address string) (net.Listener, error) + Dial(ctx context.Context, network, address string) (net.Conn, error) +} + +// RealNetwork returns a Network implementation that uses the real +// net package. +func RealNetwork() Network { return realNetwork{} } + +// realNetwork implements [Network] using the real net package. +type realNetwork struct{} + +func (realNetwork) Listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} + +func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, address) +} + +func (realNetwork) NewLocalTCPListener() net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on either IPv4 or IPv6 localhost port: %v", err)) + } + } + return ln +} diff --git a/net/packet/capture.go b/net/packet/capture.go new file mode 100644 index 000000000..dd0ca411f --- /dev/null +++ b/net/packet/capture.go @@ -0,0 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "io" + "net/netip" + "time" +) + +// Callback describes a function which is called to +// record packets when debugging packet-capture. +// Such callbacks must not take ownership of the +// provided data slice: it may only copy out of it +// within the lifetime of the function. +type CaptureCallback func(CapturePath, time.Time, []byte, CaptureMeta) + +// CaptureSink is the minimal interface from [tailscale.com/feature/capture]'s +// Sink type that is needed by the core (magicsock/LocalBackend/wgengine/etc). +// This lets the relativel heavy feature/capture package be optionally linked. +type CaptureSink interface { + // Close closes + Close() error + + // NumOutputs returns the number of outputs registered with the sink. + NumOutputs() int + + // CaptureCallback returns a callback which can be used to + // write packets to the sink. + CaptureCallback() CaptureCallback + + // WaitCh returns a channel which blocks until + // the sink is closed. + WaitCh() <-chan struct{} + + // RegisterOutput connects an output to this sink, which + // will be written to with a pcap stream as packets are logged. + // A function is returned which unregisters the output when + // called. + // + // If w implements io.Closer, it will be closed upon error + // or when the sink is closed. If w implements http.Flusher, + // it will be flushed periodically. + RegisterOutput(w io.Writer) (unregister func()) +} + +// CaptureMeta contains metadata that is used when debugging. +type CaptureMeta struct { + DidSNAT bool // SNAT was performed & the address was updated. + OriginalSrc netip.AddrPort // The source address before SNAT was performed. + DidDNAT bool // DNAT was performed & the address was updated. + OriginalDst netip.AddrPort // The destination address before DNAT was performed. +} + +// CapturePath describes where in the data path the packet was captured. +type CapturePath uint8 + +// CapturePath values +const ( + // FromLocal indicates the packet was logged as it traversed the FromLocal path: + // i.e.: A packet from the local system into the TUN. + FromLocal CapturePath = 0 + // FromPeer indicates the packet was logged upon reception from a remote peer. + FromPeer CapturePath = 1 + // SynthesizedToLocal indicates the packet was generated from within tailscaled, + // and is being routed to the local machine's network stack. + SynthesizedToLocal CapturePath = 2 + // SynthesizedToPeer indicates the packet was generated from within tailscaled, + // and is being routed to a remote Wireguard peer. + SynthesizedToPeer CapturePath = 3 + + // PathDisco indicates the packet is information about a disco frame. + PathDisco CapturePath = 254 +) diff --git a/net/packet/checksum/checksum.go b/net/packet/checksum/checksum.go index 547ea3a35..4b5b82174 100644 --- a/net/packet/checksum/checksum.go +++ b/net/packet/checksum/checksum.go @@ -8,8 +8,6 @@ import ( "encoding/binary" "net/netip" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/header" "tailscale.com/net/packet" "tailscale.com/types/ipproto" ) @@ -88,13 +86,13 @@ func updateV4PacketChecksums(p *packet.Parsed, old, new netip.Addr) { tr := p.Transport() switch p.IPProto { case ipproto.UDP, ipproto.DCCP: - if len(tr) < header.UDPMinimumSize { + if len(tr) < minUDPSize { // Not enough space for a UDP header. return } updateV4Checksum(tr[6:8], o4[:], n4[:]) case ipproto.TCP: - if len(tr) < header.TCPMinimumSize { + if len(tr) < minTCPSize { // Not enough space for a TCP header. return } @@ -112,34 +110,60 @@ func updateV4PacketChecksums(p *packet.Parsed, old, new netip.Addr) { } } +const ( + minUDPSize = 8 + minTCPSize = 20 + minICMPv6Size = 8 + minIPv6Header = 40 + + offsetICMPv6Checksum = 2 + offsetUDPChecksum = 6 + offsetTCPChecksum = 16 +) + // updateV6PacketChecksums updates the checksums in the packet buffer. // p is modified in place. // If p.IPProto is unknown, no checksums are updated. func updateV6PacketChecksums(p *packet.Parsed, old, new netip.Addr) { - if len(p.Buffer()) < 40 { + if len(p.Buffer()) < minIPv6Header { // Not enough space for an IPv6 header. return } - o6, n6 := tcpip.AddrFrom16Slice(old.AsSlice()), tcpip.AddrFrom16Slice(new.AsSlice()) + o6, n6 := old.As16(), new.As16() // Now update the transport layer checksums, where applicable. tr := p.Transport() switch p.IPProto { case ipproto.ICMPv6: - if len(tr) < header.ICMPv6MinimumSize { + if len(tr) < minICMPv6Size { return } - header.ICMPv6(tr).UpdateChecksumPseudoHeaderAddress(o6, n6) + + ss := tr[offsetICMPv6Checksum:] + xsum := binary.BigEndian.Uint16(ss) + binary.BigEndian.PutUint16(ss, + ^checksumUpdate2ByteAlignedAddress(^xsum, o6, n6)) + case ipproto.UDP, ipproto.DCCP: - if len(tr) < header.UDPMinimumSize { + if len(tr) < minUDPSize { return } - header.UDP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) + ss := tr[offsetUDPChecksum:] + xsum := binary.BigEndian.Uint16(ss) + xsum = ^xsum + xsum = checksumUpdate2ByteAlignedAddress(xsum, o6, n6) + xsum = ^xsum + binary.BigEndian.PutUint16(ss, xsum) case ipproto.TCP: - if len(tr) < header.TCPMinimumSize { + if len(tr) < minTCPSize { return } - header.TCP(tr).UpdateChecksumPseudoHeaderAddress(o6, n6, true) + ss := tr[offsetTCPChecksum:] + xsum := binary.BigEndian.Uint16(ss) + xsum = ^xsum + xsum = checksumUpdate2ByteAlignedAddress(xsum, o6, n6) + xsum = ^xsum + binary.BigEndian.PutUint16(ss, xsum) case ipproto.SCTP: // No transport layer update required. } @@ -195,3 +219,77 @@ func updateV4Checksum(oldSum, old, new []byte) { hcPrime := ^uint16(cPrime) binary.BigEndian.PutUint16(oldSum, hcPrime) } + +// checksumUpdate2ByteAlignedAddress updates an address in a calculated +// checksum. +// +// The addresses must have the same length and must contain an even number +// of bytes. The address MUST begin at a 2-byte boundary in the original buffer. +// +// This implementation is copied from gVisor, but updated to use [16]byte. +func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new [16]byte) uint16 { + const uint16Bytes = 2 + + oldAddr := old[:] + newAddr := new[:] + + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + for len(oldAddr) != 0 { + // Convert the 2 byte sequences to uint16 values then apply the increment + // update. + xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(oldAddr[0])<<8)+uint16(oldAddr[1]), (uint16(newAddr[0])<<8)+uint16(newAddr[1])) + oldAddr = oldAddr[uint16Bytes:] + newAddr = newAddr[uint16Bytes:] + } + + return xsum +} + +// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated +// checksum. +// +// The value MUST begin at a 2-byte boundary in the original buffer. +// +// This implementation is copied from gVisor. +func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 { + // As per RFC 1071 page 4, + // (4) Incremental Update + // + // ... + // + // To update the checksum, simply add the differences of the + // sixteen bit integers that have been changed. To see why this + // works, observe that every 16-bit integer has an additive inverse + // and that addition is associative. From this it follows that + // given the original value m, the new value m', and the old + // checksum C, the new checksum C' is: + // + // C' = C + (-m) + m' = C + (m' - m) + if old == new { + return xsum + } + return checksumCombine(xsum, checksumCombine(new, ^old)) +} + +// checksumCombine combines the two uint16 to form their checksum. This is done +// by adding them and the carry. +// +// Note that checksum a must have been computed on an even number of bytes. +// +// This implementation is copied from gVisor. +func checksumCombine(a, b uint16) uint16 { + v := uint32(a) + uint32(b) + return uint16(v + v>>16) +} diff --git a/net/packet/geneve.go b/net/packet/geneve.go new file mode 100644 index 000000000..71b365ae8 --- /dev/null +++ b/net/packet/geneve.go @@ -0,0 +1,134 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "errors" + "io" +) + +const ( + // GeneveFixedHeaderLength is the length of the fixed size portion of the + // Geneve header, in bytes. + GeneveFixedHeaderLength = 8 +) + +const ( + // GeneveProtocolDisco is the IEEE 802 Ethertype number used to represent + // the Tailscale Disco protocol in a Geneve header. + GeneveProtocolDisco uint16 = 0x7A11 + // GeneveProtocolWireGuard is the IEEE 802 Ethertype number used to represent the + // WireGuard protocol in a Geneve header. + GeneveProtocolWireGuard uint16 = 0x7A12 +) + +// VirtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network +// identifier. Its methods are NOT thread-safe. +type VirtualNetworkID struct { + _vni uint32 +} + +const ( + vniSetMask uint32 = 0xFF000000 + vniGetMask uint32 = ^vniSetMask +) + +// IsSet returns true if Set() had been called previously, otherwise false. +func (v *VirtualNetworkID) IsSet() bool { + return v._vni&vniSetMask != 0 +} + +// Set sets the provided VNI. If VNI exceeds the 3-byte storage it will be +// clamped. +func (v *VirtualNetworkID) Set(vni uint32) { + v._vni = vni | vniSetMask +} + +// Get returns the VNI value. +func (v *VirtualNetworkID) Get() uint32 { + return v._vni & vniGetMask +} + +// GeneveHeader represents the fixed size Geneve header from RFC8926. +// TLVs/options are not implemented/supported. +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Ver| Opt Len |O|C| Rsvd. | Protocol Type | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Virtual Network Identifier (VNI) | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type GeneveHeader struct { + // Ver (2 bits): The current version number is 0. Packets received by a + // tunnel endpoint with an unknown version MUST be dropped. Transit devices + // interpreting Geneve packets with an unknown version number MUST treat + // them as UDP packets with an unknown payload. + Version uint8 + + // Protocol Type (16 bits): The type of protocol data unit appearing after + // the Geneve header. This follows the Ethertype [ETYPES] convention, with + // Ethernet itself being represented by the value 0x6558. + Protocol uint16 + + // Virtual Network Identifier (VNI) (24 bits): An identifier for a unique + // element of a virtual network. In many situations, this may represent an + // L2 segment; however, the control plane defines the forwarding semantics + // of decapsulated packets. The VNI MAY be used as part of ECMP forwarding + // decisions or MAY be used as a mechanism to distinguish between + // overlapping address spaces contained in the encapsulated packet when load + // balancing across CPUs. + VNI VirtualNetworkID + + // O (1 bit): Control packet. This packet contains a control message. + // Control messages are sent between tunnel endpoints. Tunnel endpoints MUST + // NOT forward the payload, and transit devices MUST NOT attempt to + // interpret it. Since control messages are less frequent, it is RECOMMENDED + // that tunnel endpoints direct these packets to a high-priority control + // queue (for example, to direct the packet to a general purpose CPU from a + // forwarding Application-Specific Integrated Circuit (ASIC) or to separate + // out control traffic on a NIC). Transit devices MUST NOT alter forwarding + // behavior on the basis of this bit, such as ECMP link selection. + Control bool +} + +var ErrGeneveVNIUnset = errors.New("VNI is unset") + +// Encode encodes GeneveHeader into b. If len(b) < [GeneveFixedHeaderLength] an +// [io.ErrShortBuffer] error is returned. If !h.VNI.IsSet() then an +// [ErrGeneveVNIUnset] error is returned. +func (h *GeneveHeader) Encode(b []byte) error { + if len(b) < GeneveFixedHeaderLength { + return io.ErrShortBuffer + } + if !h.VNI.IsSet() { + return ErrGeneveVNIUnset + } + if h.Version > 3 { + return errors.New("version must be <= 3") + } + b[0] = 0 + b[1] = 0 + b[0] |= h.Version << 6 + if h.Control { + b[1] |= 0x80 + } + binary.BigEndian.PutUint16(b[2:], h.Protocol) + binary.BigEndian.PutUint32(b[4:], h.VNI.Get()<<8) + return nil +} + +// Decode decodes GeneveHeader from b. If len(b) < [GeneveFixedHeaderLength] an +// [io.ErrShortBuffer] error is returned. +func (h *GeneveHeader) Decode(b []byte) error { + if len(b) < GeneveFixedHeaderLength { + return io.ErrShortBuffer + } + h.Version = b[0] >> 6 + if b[1]&0x80 != 0 { + h.Control = true + } + h.Protocol = binary.BigEndian.Uint16(b[2:]) + h.VNI.Set(binary.BigEndian.Uint32(b[4:]) >> 8) + return nil +} diff --git a/net/packet/geneve_test.go b/net/packet/geneve_test.go new file mode 100644 index 000000000..be9784998 --- /dev/null +++ b/net/packet/geneve_test.go @@ -0,0 +1,84 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "math" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/types/ptr" +) + +func TestGeneveHeader(t *testing.T) { + in := GeneveHeader{ + Version: 3, + Protocol: GeneveProtocolDisco, + Control: true, + } + in.VNI.Set(1<<24 - 1) + b := make([]byte, GeneveFixedHeaderLength) + err := in.Encode(b) + if err != nil { + t.Fatal(err) + } + out := GeneveHeader{} + err = out.Decode(b) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(out, in, cmpopts.EquateComparable(VirtualNetworkID{})); diff != "" { + t.Fatalf("wrong results (-got +want)\n%s", diff) + } +} + +func TestVirtualNetworkID(t *testing.T) { + tests := []struct { + name string + set *uint32 + want uint32 + }{ + { + "don't Set", + nil, + 0, + }, + { + "Set 0", + ptr.To(uint32(0)), + 0, + }, + { + "Set 1", + ptr.To(uint32(1)), + 1, + }, + { + "Set math.MaxUint32", + ptr.To(uint32(math.MaxUint32)), + 1<<24 - 1, + }, + { + "Set max 3-byte value", + ptr.To(uint32(1<<24 - 1)), + 1<<24 - 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := VirtualNetworkID{} + if tt.set != nil { + v.Set(*tt.set) + } + if v.IsSet() != (tt.set != nil) { + t.Fatalf("IsSet: %v != wantIsSet: %v", v.IsSet(), tt.set != nil) + } + if v.Get() != tt.want { + t.Fatalf("Get(): %v != want: %v", v.Get(), tt.want) + } + }) + } +} diff --git a/net/packet/header.go b/net/packet/header.go index dbe84429a..fa66a8641 100644 --- a/net/packet/header.go +++ b/net/packet/header.go @@ -8,6 +8,7 @@ import ( "math" ) +const igmpHeaderLength = 8 const tcpHeaderLength = 20 const sctpHeaderLength = 12 diff --git a/net/packet/packet.go b/net/packet/packet.go index c9521ad46..34b63aadd 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -34,14 +34,6 @@ const ( TCPECNBits TCPFlag = TCPECNEcho | TCPCWR ) -// CaptureMeta contains metadata that is used when debugging. -type CaptureMeta struct { - DidSNAT bool // SNAT was performed & the address was updated. - OriginalSrc netip.AddrPort // The source address before SNAT was performed. - DidDNAT bool // DNAT was performed & the address was updated. - OriginalDst netip.AddrPort // The destination address before DNAT was performed. -} - // Parsed is a minimal decoding of a packet suitable for use in filters. type Parsed struct { // b is the byte buffer that this decodes. @@ -59,10 +51,11 @@ type Parsed struct { IPVersion uint8 // IPProto is the IP subprotocol (UDP, TCP, etc.). Valid iff IPVersion != 0. IPProto ipproto.Proto - // SrcIP4 is the source address. Family matches IPVersion. Port is - // valid iff IPProto == TCP || IPProto == UDP. + // Src is the source address. Family matches IPVersion. Port is + // valid iff IPProto == TCP || IPProto == UDP || IPProto == SCTP. Src netip.AddrPort - // DstIP4 is the destination address. Family matches IPVersion. + // Dst is the destination address. Family matches IPVersion. Port is + // valid iff IPProto == TCP || IPProto == UDP || IPProto == SCTP. Dst netip.AddrPort // TCPFlags is the packet's TCP flag bits. Valid iff IPProto == TCP. TCPFlags TCPFlag @@ -168,14 +161,8 @@ func (q *Parsed) decode4(b []byte) { if fragOfs == 0 { // This is the first fragment - if moreFrags && len(sub) < minFragBlks { - // Suspiciously short first fragment, dump it. - q.IPProto = unknown - return - } - // otherwise, this is either non-fragmented (the usual case) - // or a big enough initial fragment that we can read the - // whole subprotocol header. + // Every protocol below MUST check that it has at least one entire + // transport header in order to protect against fragment confusion. switch q.IPProto { case ipproto.ICMPv4: if len(sub) < icmp4HeaderLength { @@ -187,6 +174,10 @@ func (q *Parsed) decode4(b []byte) { q.dataofs = q.subofs + icmp4HeaderLength return case ipproto.IGMP: + if len(sub) < igmpHeaderLength { + q.IPProto = unknown + return + } // Keep IPProto, but don't parse anything else // out. return @@ -219,6 +210,15 @@ func (q *Parsed) decode4(b []byte) { q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) return case ipproto.TSMP: + // Strictly disallow fragmented TSMP + if moreFrags { + q.IPProto = unknown + return + } + if len(sub) < minTSMPSize { + q.IPProto = unknown + return + } // Inter-tailscale messages. q.dataofs = q.subofs return @@ -231,8 +231,11 @@ func (q *Parsed) decode4(b []byte) { } else { // This is a fragment other than the first one. if fragOfs < minFragBlks { - // First frag was suspiciously short, so we can't - // trust the followup either. + // disallow fragment offsets that are potentially inside of a + // transport header. This is notably asymmetric with the + // first-packet limit, that may allow a first-packet that requires a + // shorter offset than this limit, but without state to tie this + // to the first fragment we can not allow shorter packets. q.IPProto = unknown return } @@ -322,6 +325,10 @@ func (q *Parsed) decode6(b []byte) { q.Dst = withPort(q.Dst, binary.BigEndian.Uint16(sub[2:4])) return case ipproto.TSMP: + if len(sub) < minTSMPSize { + q.IPProto = unknown + return + } // Inter-tailscale messages. q.dataofs = q.subofs return diff --git a/net/packet/packet_test.go b/net/packet/packet_test.go index 4fc804a4f..09c2c101d 100644 --- a/net/packet/packet_test.go +++ b/net/packet/packet_test.go @@ -385,6 +385,124 @@ var sctpDecode = Parsed{ Dst: mustIPPort("100.74.70.3:456"), } +var ipv4ShortFirstFragmentBuffer = []byte{ + // IP header (20 bytes) + 0x45, 0x00, 0x00, 0x4f, // Total length 79 bytes + 0x00, 0x01, 0x20, 0x00, // ID, Flags (MoreFragments set, offset 0) + 0x40, 0x06, 0x00, 0x00, // TTL, Protocol (TCP), Checksum + 0x01, 0x02, 0x03, 0x04, // Source IP + 0x05, 0x06, 0x07, 0x08, // Destination IP + // TCP header (20 bytes), but packet is truncated to 59 bytes of TCP data + // (total 79 bytes, 20 for IP) + 0x00, 0x7b, 0x02, 0x37, 0x00, 0x00, 0x12, 0x34, 0x00, 0x00, 0x00, 0x00, + 0x50, 0x12, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + // Payload (39 bytes) + 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, +} + +var ipv4ShortFirstFragmentDecode = Parsed{ + b: ipv4ShortFirstFragmentBuffer, + subofs: 20, + dataofs: 40, + length: len(ipv4ShortFirstFragmentBuffer), + IPVersion: 4, + IPProto: ipproto.TCP, + Src: mustIPPort("1.2.3.4:123"), + Dst: mustIPPort("5.6.7.8:567"), + TCPFlags: 0x12, // SYN + ACK +} + +var ipv4SmallOffsetFragmentBuffer = []byte{ + // IP header (20 bytes) + 0x45, 0x00, 0x00, 0x28, // Total length 40 bytes + 0x00, 0x01, 0x20, 0x08, // ID, Flags (MoreFragments set, offset 8 bytes (0x08 / 8 = 1)) + 0x40, 0x06, 0x00, 0x00, // TTL, Protocol (TCP), Checksum + 0x01, 0x02, 0x03, 0x04, // Source IP + 0x05, 0x06, 0x07, 0x08, // Destination IP + // Payload (20 bytes) - this would be part of the TCP header in a real scenario + 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, 0x61, + 0x61, 0x61, 0x61, 0x61, +} + +var ipv4SmallOffsetFragmentDecode = Parsed{ + b: ipv4SmallOffsetFragmentBuffer, + subofs: 20, // subofs will still be set based on IHL + dataofs: 0, // It's unknown, so dataofs should be 0 + length: len(ipv4SmallOffsetFragmentBuffer), + IPVersion: 4, + IPProto: ipproto.Unknown, // Expected to be Unknown + Src: mustIPPort("1.2.3.4:0"), + Dst: mustIPPort("5.6.7.8:0"), +} + +// First fragment packet missing exactly one byte of the TCP header +var ipv4OneByteShortTCPHeaderBuffer = []byte{ + // IP header (20 bytes) + 0x45, 0x00, 0x00, 0x27, // Total length 51 bytes (20 IP + 19 TCP) + 0x00, 0x01, 0x20, 0x00, // ID, Flags (MoreFragments set, offset 0) + 0x40, 0x06, 0x00, 0x00, // TTL, Protocol (TCP), Checksum + 0x01, 0x02, 0x03, 0x04, // Source IP + 0x05, 0x06, 0x07, 0x08, // Destination IP + // TCP header - only 19 bytes (one byte short of the required 20) + 0x00, 0x7b, 0x02, 0x37, // Source port, Destination port + 0x00, 0x00, 0x12, 0x34, // Sequence number + 0x00, 0x00, 0x00, 0x00, // Acknowledgment number + 0x50, 0x12, 0x01, 0x00, // Data offset, flags, window size + 0x00, 0x00, 0x00, // Checksum (missing the last byte of urgent pointer) +} + +// IPv4 packet with maximum header length (60 bytes = 15 words) and a TCP header that's +// one byte short of being complete +var ipv4MaxHeaderShortTCPBuffer = []byte{ + // IP header with max options (60 bytes) + 0x4F, 0x00, 0x00, 0x4F, // Version (4) + IHL (15), ToS, Total length 79 bytes (60 IP + 19 TCP) + 0x00, 0x01, 0x20, 0x00, // ID, Flags (MoreFragments set, offset 0) + 0x40, 0x06, 0x00, 0x00, // TTL, Protocol (TCP), Checksum + 0x01, 0x02, 0x03, 0x04, // Source IP + 0x05, 0x06, 0x07, 0x08, // Destination IP + // IPv4 options (40 bytes) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + 0x01, 0x01, 0x01, 0x01, // 4 NOP options (padding) + // TCP header - only 19 bytes (one byte short of the required 20) + 0x00, 0x7b, 0x02, 0x37, // Source port, Destination port + 0x00, 0x00, 0x12, 0x34, // Sequence number + 0x00, 0x00, 0x00, 0x00, // Acknowledgment number + 0x50, 0x12, 0x01, 0x00, // Data offset, flags, window size + 0x00, 0x00, 0x00, // Checksum (missing the last byte of urgent pointer) +} + +var ipv4MaxHeaderShortTCPDecode = Parsed{ + b: ipv4MaxHeaderShortTCPBuffer, + subofs: 60, // 60 bytes for full IPv4 header with max options + dataofs: 0, // It's unknown, so dataofs should be 0 + length: len(ipv4MaxHeaderShortTCPBuffer), + IPVersion: 4, + IPProto: ipproto.Unknown, // Expected to be Unknown + Src: mustIPPort("1.2.3.4:0"), + Dst: mustIPPort("5.6.7.8:0"), +} + +var ipv4OneByteShortTCPHeaderDecode = Parsed{ + b: ipv4OneByteShortTCPHeaderBuffer, + subofs: 20, + dataofs: 0, // It's unknown, so dataofs should be 0 + length: len(ipv4OneByteShortTCPHeaderBuffer), + IPVersion: 4, + IPProto: ipproto.Unknown, // Expected to be Unknown + Src: mustIPPort("1.2.3.4:0"), + Dst: mustIPPort("5.6.7.8:0"), +} + func TestParsedString(t *testing.T) { tests := []struct { name string @@ -450,6 +568,10 @@ func TestDecode(t *testing.T) { {"ipv4_sctp", sctpBuffer, sctpDecode}, {"ipv4_frag", tcp4MediumFragmentBuffer, tcp4MediumFragmentDecode}, {"ipv4_fragtooshort", tcp4ShortFragmentBuffer, tcp4ShortFragmentDecode}, + {"ipv4_short_first_fragment", ipv4ShortFirstFragmentBuffer, ipv4ShortFirstFragmentDecode}, + {"ipv4_small_offset_fragment", ipv4SmallOffsetFragmentBuffer, ipv4SmallOffsetFragmentDecode}, + {"ipv4_one_byte_short_tcp_header", ipv4OneByteShortTCPHeaderBuffer, ipv4OneByteShortTCPHeaderDecode}, + {"ipv4_max_header_short_tcp", ipv4MaxHeaderShortTCPBuffer, ipv4MaxHeaderShortTCPDecode}, {"ip97", mustHexDecode("4500 0019 d186 4000 4061 751d 644a 4603 6449 e549 6865 6c6c 6f"), Parsed{ IPVersion: 4, diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go index 4e004cca2..0ea321e84 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -15,10 +15,11 @@ import ( "fmt" "net/netip" - "tailscale.com/net/flowtrack" "tailscale.com/types/ipproto" ) +const minTSMPSize = 7 // the rejected body is 7 bytes + // TailscaleRejectedHeader is a TSMP message that says that one // Tailscale node has rejected the connection from another. Unlike a // TCP RST, this includes a reason. @@ -56,10 +57,6 @@ type TailscaleRejectedHeader struct { const rejectFlagBitMaybeBroken = 0x1 -func (rh TailscaleRejectedHeader) Flow() flowtrack.Tuple { - return flowtrack.MakeTuple(rh.Proto, rh.Src, rh.Dst) -} - func (rh TailscaleRejectedHeader) String() string { return fmt.Sprintf("TSMP-reject-flow{%s %s > %s}: %s", rh.Proto, rh.Src, rh.Dst, rh.Reason) } diff --git a/net/ping/ping.go b/net/ping/ping.go index 01f3dcf2c..8e16a692a 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -10,6 +10,7 @@ import ( "context" "crypto/rand" "encoding/binary" + "errors" "fmt" "io" "log" @@ -22,9 +23,9 @@ import ( "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/mak" - "tailscale.com/util/multierr" ) const ( @@ -64,7 +65,7 @@ type Pinger struct { wg sync.WaitGroup // Following fields protected by mu - mu sync.Mutex + mu syncs.Mutex // conns is a map of "type" to net.PacketConn, type is either // "ip4:icmp" or "ip6:icmp" conns map[string]net.PacketConn @@ -157,17 +158,17 @@ func (p *Pinger) Close() error { p.conns = nil p.mu.Unlock() - var errors []error + var errs []error for _, c := range conns { if err := c.Close(); err != nil { - errors = append(errors, err) + errs = append(errs, err) } } p.wg.Wait() p.cleanupOutstanding() - return multierr.New(errors...) + return errors.Join(errs...) } func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { diff --git a/net/portmapper/igd_test.go b/net/portmapper/igd_test.go index 5c24d03aa..77015f5bf 100644 --- a/net/portmapper/igd_test.go +++ b/net/portmapper/igd_test.go @@ -14,11 +14,13 @@ import ( "sync/atomic" "testing" - "tailscale.com/control/controlknobs" "tailscale.com/net/netaddr" "tailscale.com/net/netmon" "tailscale.com/syncs" + "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + "tailscale.com/util/testenv" ) // TestIGD is an IGD (Internet Gateway Device) for testing. It supports fake @@ -63,7 +65,8 @@ type igdCounters struct { invalidPCPMapPkt int32 } -func NewTestIGD(logf logger.Logf, t TestIGDOptions) (*TestIGD, error) { +func NewTestIGD(tb testenv.TB, t TestIGDOptions) (*TestIGD, error) { + logf := tstest.WhileTestRunningLogger(tb) d := &TestIGD{ doPMP: t.PMP, doPCP: t.PCP, @@ -258,15 +261,29 @@ func (d *TestIGD) handlePCPQuery(pkt []byte, src netip.AddrPort) { } } -func newTestClient(t *testing.T, igd *TestIGD) *Client { +// newTestClient configures a new test client connected to igd for mapping updates. +// If bus == nil, a new empty event bus is constructed that is cleaned up when t exits. +// A cleanup for the resulting client is also added to t. +func newTestClient(t *testing.T, igd *TestIGD, bus *eventbus.Bus) *Client { + if bus == nil { + bus = eventbus.New() + t.Log("Created empty event bus for test client") + t.Cleanup(bus.Close) + } var c *Client - c = NewClient(t.Logf, netmon.NewStatic(), nil, new(controlknobs.Knobs), func() { - t.Logf("port map changed") - t.Logf("have mapping: %v", c.HaveMapping()) + c = NewClient(Config{ + Logf: tstest.WhileTestRunningLogger(t), + NetMon: netmon.NewStatic(), + EventBus: bus, + OnChange: func() { // TODO(creachadair): Remove. + t.Logf("port map changed") + t.Logf("have mapping: %v", c.HaveMapping()) + }, }) c.testPxPPort = igd.TestPxPPort() c.testUPnPPort = igd.TestUPnPPort() c.netMon = netmon.NewStatic() c.SetGatewayLookupFunc(testIPAndGateway) + t.Cleanup(func() { c.Close() }) return c } diff --git a/net/portmapper/pmpresultcode_string.go b/net/portmapper/pmpresultcode_string.go index 603636ade..18d911d94 100644 --- a/net/portmapper/pmpresultcode_string.go +++ b/net/portmapper/pmpresultcode_string.go @@ -24,8 +24,9 @@ const _pmpResultCode_name = "OKUnsupportedVersionNotAuthorizedNetworkFailureOutO var _pmpResultCode_index = [...]uint8{0, 2, 20, 33, 47, 61, 78} func (i pmpResultCode) String() string { - if i >= pmpResultCode(len(_pmpResultCode_index)-1) { + idx := int(i) - 0 + if i < 0 || idx >= len(_pmpResultCode_index)-1 { return "pmpResultCode(" + strconv.FormatInt(int64(i), 10) + ")" } - return _pmpResultCode_name[_pmpResultCode_index[i]:_pmpResultCode_index[i+1]] + return _pmpResultCode_name[_pmpResultCode_index[idx]:_pmpResultCode_index[idx+1]] } diff --git a/net/portmapper/portmapper.go b/net/portmapper/portmapper.go index 71b55b8a7..16a981d1d 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -8,29 +8,36 @@ package portmapper import ( "context" "encoding/binary" - "errors" "fmt" "io" "net" "net/http" "net/netip" "slices" - "sync" "sync/atomic" "time" "go4.org/mem" - "tailscale.com/control/controlknobs" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/net/netaddr" "tailscale.com/net/neterror" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/portmapper/portmappertype" "tailscale.com/net/sockstats" "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" +) + +var ( + ErrNoPortMappingServices = portmappertype.ErrNoPortMappingServices + ErrGatewayRange = portmappertype.ErrGatewayRange + ErrGatewayIPv6 = portmappertype.ErrGatewayIPv6 + ErrPortMappingDisabled = portmappertype.ErrPortMappingDisabled ) var disablePortMapperEnv = envknob.RegisterBool("TS_DISABLE_PORTMAPPER") @@ -48,15 +55,33 @@ type DebugKnobs struct { LogHTTP bool // Disable* disables a specific service from mapping. - DisableUPnP bool - DisablePMP bool - DisablePCP bool + // If the funcs are nil or return false, the service is not disabled. + // Use the corresponding accessor methods without the "Func" suffix + // to check whether a service is disabled. + DisableUPnPFunc func() bool + DisablePMPFunc func() bool + DisablePCPFunc func() bool // DisableAll, if non-nil, is a func that reports whether all port // mapping attempts should be disabled. DisableAll func() bool } +// DisableUPnP reports whether UPnP is disabled. +func (k *DebugKnobs) DisableUPnP() bool { + return k != nil && k.DisableUPnPFunc != nil && k.DisableUPnPFunc() +} + +// DisablePMP reports whether NAT-PMP is disabled. +func (k *DebugKnobs) DisablePMP() bool { + return k != nil && k.DisablePMPFunc != nil && k.DisablePMPFunc() +} + +// DisablePCP reports whether PCP is disabled. +func (k *DebugKnobs) DisablePCP() bool { + return k != nil && k.DisablePCPFunc != nil && k.DisablePCPFunc() +} + func (k *DebugKnobs) disableAll() bool { if disablePortMapperEnv() { return true @@ -84,16 +109,20 @@ const trustServiceStillAvailableDuration = 10 * time.Minute // Client is a port mapping client. type Client struct { + // The following two fields must both be non-nil. + // Both are immutable after construction. + pubClient *eventbus.Client + updates *eventbus.Publisher[portmappertype.Mapping] + logf logger.Logf netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand - controlKnobs *controlknobs.Knobs ipAndGateway func() (gw, ip netip.Addr, ok bool) onChange func() // or nil debug DebugKnobs testPxPPort uint16 // if non-zero, pxpPort to use for tests testUPnPPort uint16 // if non-zero, uPnPPort to use for tests - mu sync.Mutex // guards following, and all fields thereof + mu syncs.Mutex // guards following, and all fields thereof // runningCreate is whether we're currently working on creating // a port mapping (whether GetCachedMappingOrStartCreatingOne kicked @@ -124,6 +153,8 @@ type Client struct { mapping mapping // non-nil if we have a mapping } +var _ portmappertype.Client = (*Client)(nil) + func (c *Client) vlogf(format string, args ...any) { if c.debug.VerboseLogs { c.logf(format, args...) @@ -153,7 +184,6 @@ type mapping interface { MappingDebug() string } -// HaveMapping reports whether we have a current valid mapping. func (c *Client) HaveMapping() bool { c.mu.Lock() defer c.mu.Unlock() @@ -201,32 +231,52 @@ func (m *pmpMapping) Release(ctx context.Context) { uc.WriteToUDPAddrPort(pkt, m.gw) } -// NewClient returns a new portmapping client. -// -// The netMon parameter is required. -// -// The debug argument allows configuring the behaviour of the portmapper for -// debugging; if nil, a sensible set of defaults will be used. -// -// The controlKnobs, if non-nil, specifies the control knobs from the control -// plane that might disable portmapping. -// -// The optional onChange argument specifies a func to run in a new goroutine -// whenever the port mapping status has changed. If nil, it doesn't make a -// callback. -func NewClient(logf logger.Logf, netMon *netmon.Monitor, debug *DebugKnobs, controlKnobs *controlknobs.Knobs, onChange func()) *Client { - if netMon == nil { - panic("nil netMon") +// Config carries the settings for a [Client]. +type Config struct { + // EventBus, which must be non-nil, is used for event publication and + // subscription by portmapper clients created from this config. + EventBus *eventbus.Bus + + // Logf is called to generate text logs for the client. If nil, logger.Discard is used. + Logf logger.Logf + + // NetMon is the network monitor used by the client. It must be non-nil. + NetMon *netmon.Monitor + + // DebugKnobs, if non-nil, configure the behaviour of the portmapper for + // debugging. If nil, a sensible set of defaults will be used. + DebugKnobs *DebugKnobs + + // OnChange is called to run in a new goroutine whenever the port mapping + // status has changed. If nil, no callback is issued. + OnChange func() +} + +// NewClient constructs a new portmapping [Client] from c. It will panic if any +// required parameters are omitted. +func NewClient(c Config) *Client { + switch { + case c.NetMon == nil: + panic("nil NetMon") + case c.EventBus == nil: + panic("nil EventBus") } ret := &Client{ - logf: logf, - netMon: netMon, - ipAndGateway: netmon.LikelyHomeRouterIP, // TODO(bradfitz): move this to method on netMon - onChange: onChange, - controlKnobs: controlKnobs, + logf: c.Logf, + netMon: c.NetMon, + onChange: c.OnChange, + } + if buildfeatures.HasPortMapper { + // TODO(bradfitz): move this to method on netMon + ret.ipAndGateway = netmon.LikelyHomeRouterIP } - if debug != nil { - ret.debug = *debug + ret.pubClient = c.EventBus.Client("portmapper") + ret.updates = eventbus.Publish[portmappertype.Mapping](ret.pubClient) + if ret.logf == nil { + ret.logf = logger.Discard + } + if c.DebugKnobs != nil { + ret.debug = *c.DebugKnobs } return ret } @@ -256,6 +306,9 @@ func (c *Client) Close() error { } c.closed = true c.invalidateMappingsLocked(true) + c.updates.Close() + c.pubClient.Close() + // TODO: close some future ever-listening UDP socket(s), // waiting for multicast announcements from router. return nil @@ -417,13 +470,6 @@ func IsNoMappingError(err error) bool { return ok } -var ( - ErrNoPortMappingServices = errors.New("no port mapping services were found") - ErrGatewayRange = errors.New("skipping portmap; gateway range likely lacks support") - ErrGatewayIPv6 = errors.New("skipping portmap; no IPv6 support for portmapping") - ErrPortMappingDisabled = errors.New("port mapping is disabled") -) - // GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any. // If there's not one, it starts up a background goroutine to create one. // If the background goroutine ends up creating one, the onChange hook registered with the @@ -467,10 +513,29 @@ func (c *Client) createMapping() { c.runningCreate = false }() - if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil { + mapping, _, err := c.createOrGetMapping(ctx) + if err != nil { + if !IsNoMappingError(err) { + c.logf("createOrGetMapping: %v", err) + } + return + } else if mapping == nil { + return + + // TODO(creachadair): This was already logged in createOrGetMapping. + // It really should not happen at all, but we will need to untangle + // the control flow to eliminate that possibility. Meanwhile, this + // mitigates a panic downstream, cf. #16662. + } + c.updates.Publish(portmappertype.Mapping{ + External: mapping.External(), + Type: mapping.MappingType(), + GoodUntil: mapping.GoodUntil(), + }) + // TODO(creachadair): Remove this entirely once there are no longer any + // places where the callback is set. + if c.onChange != nil { go c.onChange() - } else if err != nil && !IsNoMappingError(err) { - c.logf("createOrGetMapping: %v", err) } } @@ -482,19 +547,19 @@ var wildcardIP = netip.MustParseAddr("0.0.0.0") // // If no mapping is available, the error will be of type // NoMappingError; see IsNoMappingError. -func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPort, err error) { +func (c *Client) createOrGetMapping(ctx context.Context) (mapping mapping, external netip.AddrPort, err error) { if c.debug.disableAll() { - return netip.AddrPort{}, NoMappingError{ErrPortMappingDisabled} + return nil, netip.AddrPort{}, NoMappingError{ErrPortMappingDisabled} } - if c.debug.DisableUPnP && c.debug.DisablePCP && c.debug.DisablePMP { - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + if c.debug.DisableUPnP() && c.debug.DisablePCP() && c.debug.DisablePMP() { + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } gw, myIP, ok := c.gatewayAndSelfIP() if !ok { - return netip.AddrPort{}, NoMappingError{ErrGatewayRange} + return nil, netip.AddrPort{}, NoMappingError{ErrGatewayRange} } if gw.Is6() { - return netip.AddrPort{}, NoMappingError{ErrGatewayIPv6} + return nil, netip.AddrPort{}, NoMappingError{ErrGatewayIPv6} } now := time.Now() @@ -523,6 +588,17 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor return } + // TODO(creachadair): This is more subtle than it should be. Ideally we + // would just return the mapping directly, but there are many different + // paths through the function with carefully-balanced locks, and not all + // the paths have a mapping to return. As a workaround, while we're here + // doing cleanup under the lock, grab the final mapping value and return + // it, so the caller does not need to grab the lock again and potentially + // race with a later update. The mapping itself is concurrency-safe. + // + // We should restructure this code so the locks are properly scoped. + mapping = c.mapping + // Print the internal details of each mapping if we're being verbose. if c.debug.VerboseLogs { c.logf("successfully obtained mapping: now=%d external=%v type=%s mapping=%s", @@ -548,19 +624,19 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if now.Before(m.RenewAfter()) { defer c.mu.Unlock() reusedExisting = true - return m.External(), nil + return nil, m.External(), nil } // The mapping might still be valid, so just try to renew it. prevPort = m.External().Port() } - if c.debug.DisablePCP && c.debug.DisablePMP { + if c.debug.DisablePCP() && c.debug.DisablePMP() { c.mu.Unlock() if external, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok { - return external, nil + return nil, external, nil } c.vlogf("fallback to UPnP due to PCP and PMP being disabled failed") - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } // If we just did a Probe (e.g. via netchecker) but didn't @@ -587,16 +663,16 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor c.mu.Unlock() // fallback to UPnP portmapping if external, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok { - return external, nil + return nil, external, nil } c.vlogf("fallback to UPnP due to no PCP and PMP failed") - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } c.mu.Unlock() uc, err := c.listenPacket(ctx, "udp4", ":0") if err != nil { - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } defer uc.Close() @@ -605,7 +681,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor pxpAddr := netip.AddrPortFrom(gw, c.pxpPort()) - preferPCP := !c.debug.DisablePCP && (c.debug.DisablePMP || (!haveRecentPMP && haveRecentPCP)) + preferPCP := !c.debug.DisablePCP() && (c.debug.DisablePMP() || (!haveRecentPMP && haveRecentPCP)) // Create a mapping, defaulting to PMP unless only PCP was seen recently. if preferPCP { @@ -616,7 +692,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if neterror.TreatAsLostUDP(err) { err = NoMappingError{ErrNoPortMappingServices} } - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } } else { // Ask for our external address if needed. @@ -625,7 +701,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if neterror.TreatAsLostUDP(err) { err = NoMappingError{ErrNoPortMappingServices} } - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } } @@ -634,7 +710,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if neterror.TreatAsLostUDP(err) { err = NoMappingError{ErrNoPortMappingServices} } - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } } @@ -643,13 +719,13 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor n, src, err := uc.ReadFromUDPAddrPort(res) if err != nil { if ctx.Err() == context.Canceled { - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } // fallback to UPnP portmapping if mapping, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok { - return mapping, nil + return nil, mapping, nil } - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } src = netaddr.Unmap(src) if !src.IsValid() { @@ -665,7 +741,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor continue } if pres.ResultCode != 0 { - return netip.AddrPort{}, NoMappingError{fmt.Errorf("PMP response Op=0x%x,Res=0x%x", pres.OpCode, pres.ResultCode)} + return nil, netip.AddrPort{}, NoMappingError{fmt.Errorf("PMP response Op=0x%x,Res=0x%x", pres.OpCode, pres.ResultCode)} } if pres.OpCode == pmpOpReply|pmpOpMapPublicAddr { m.external = netip.AddrPortFrom(pres.PublicAddr, m.external.Port()) @@ -683,7 +759,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if err != nil { c.logf("failed to get PCP mapping: %v", err) // PCP should only have a single packet response - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } pcpMapping.c = c pcpMapping.internal = m.internal @@ -691,10 +767,10 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor c.mu.Lock() defer c.mu.Unlock() c.mapping = pcpMapping - return pcpMapping.external, nil + return pcpMapping, pcpMapping.external, nil default: c.logf("unknown PMP/PCP version number: %d %v", version, res[:n]) - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } } @@ -702,7 +778,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor c.mu.Lock() defer c.mu.Unlock() c.mapping = m - return m.external, nil + return nil, m.external, nil } } } @@ -790,19 +866,13 @@ func parsePMPResponse(pkt []byte) (res pmpResponse, ok bool) { return res, true } -type ProbeResult struct { - PCP bool - PMP bool - UPnP bool -} - // Probe returns a summary of which port mapping services are // available on the network. // // If a probe has run recently and there haven't been any network changes since, // the returned result might be server from the Client's cache, without // sending any network traffic. -func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { +func (c *Client) Probe(ctx context.Context) (res portmappertype.ProbeResult, err error) { if c.debug.disableAll() { return res, ErrPortMappingDisabled } @@ -837,19 +907,19 @@ func (c *Client) Probe(ctx context.Context) (res ProbeResult, err error) { // https://github.com/tailscale/tailscale/issues/1001 if c.sawPMPRecently() { res.PMP = true - } else if !c.debug.DisablePMP { + } else if !c.debug.DisablePMP() { metricPMPSent.Add(1) uc.WriteToUDPAddrPort(pmpReqExternalAddrPacket, pxpAddr) } if c.sawPCPRecently() { res.PCP = true - } else if !c.debug.DisablePCP { + } else if !c.debug.DisablePCP() { metricPCPSent.Add(1) uc.WriteToUDPAddrPort(pcpAnnounceRequest(myIP), pxpAddr) } if c.sawUPnPRecently() { res.UPnP = true - } else if !c.debug.DisableUPnP { + } else if !c.debug.DisableUPnP() { // Strictly speaking, you discover UPnP services by sending an // SSDP query (which uPnPPacket is) to udp/1900 on the SSDP // multicast address, and then get a flood of responses back diff --git a/net/portmapper/portmapper_test.go b/net/portmapper/portmapper_test.go index d321b720a..a697a3908 100644 --- a/net/portmapper/portmapper_test.go +++ b/net/portmapper/portmapper_test.go @@ -11,21 +11,22 @@ import ( "testing" "time" - "tailscale.com/control/controlknobs" + "tailscale.com/net/portmapper/portmappertype" + "tailscale.com/util/eventbus/eventbustest" ) func TestCreateOrGetMapping(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf, nil, nil, new(controlknobs.Knobs), nil) + c := NewClient(Config{Logf: t.Logf}) defer c.Close() c.SetLocalPort(1234) for i := range 2 { if i > 0 { time.Sleep(100 * time.Millisecond) } - ext, err := c.createOrGetMapping(context.Background()) + _, ext, err := c.createOrGetMapping(context.Background()) t.Logf("Got: %v, %v", ext, err) } } @@ -34,7 +35,7 @@ func TestClientProbe(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf, nil, nil, new(controlknobs.Knobs), nil) + c := NewClient(Config{Logf: t.Logf}) defer c.Close() for i := range 3 { if i > 0 { @@ -49,26 +50,25 @@ func TestClientProbeThenMap(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf, nil, nil, new(controlknobs.Knobs), nil) + c := NewClient(Config{Logf: t.Logf}) defer c.Close() c.debug.VerboseLogs = true c.SetLocalPort(1234) res, err := c.Probe(context.Background()) t.Logf("Probe: %+v, %v", res, err) - ext, err := c.createOrGetMapping(context.Background()) + _, ext, err := c.createOrGetMapping(context.Background()) t.Logf("createOrGetMapping: %v, %v", ext, err) } func TestProbeIntegration(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{PMP: true, PCP: true, UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{PMP: true, PCP: true, UPnP: true}) if err != nil { t.Fatal(err) } defer igd.Close() - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on pxp=%v, upnp=%v", c.testPxPPort, c.testUPnPPort) - defer c.Close() res, err := c.Probe(context.Background()) if err != nil { @@ -95,14 +95,13 @@ func TestProbeIntegration(t *testing.T) { } func TestPCPIntegration(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{PMP: false, PCP: true, UPnP: false}) + igd, err := NewTestIGD(t, TestIGDOptions{PMP: false, PCP: true, UPnP: false}) if err != nil { t.Fatal(err) } defer igd.Close() - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) res, err := c.Probe(context.Background()) if err != nil { t.Fatalf("probe failed: %v", err) @@ -114,7 +113,7 @@ func TestPCPIntegration(t *testing.T) { t.Fatalf("probe did not see pcp: %+v", res) } - external, err := c.createOrGetMapping(context.Background()) + _, external, err := c.createOrGetMapping(context.Background()) if err != nil { t.Fatalf("failed to get mapping: %v", err) } @@ -136,3 +135,22 @@ func TestGetUPnPErrorsMetric(t *testing.T) { getUPnPErrorsMetric(0) getUPnPErrorsMetric(-100) } + +func TestUpdateEvent(t *testing.T) { + igd, err := NewTestIGD(t, TestIGDOptions{PCP: true}) + if err != nil { + t.Fatalf("Create test gateway: %v", err) + } + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + c := newTestClient(t, igd, bus) + if _, err := c.Probe(t.Context()); err != nil { + t.Fatalf("Probe failed: %v", err) + } + c.GetCachedMappingOrStartCreatingOne() + if err := eventbustest.Expect(tw, eventbustest.Type[portmappertype.Mapping]()); err != nil { + t.Error(err.Error()) + } +} diff --git a/net/portmapper/portmappertype/portmappertype.go b/net/portmapper/portmappertype/portmappertype.go new file mode 100644 index 000000000..cc8358a4a --- /dev/null +++ b/net/portmapper/portmappertype/portmappertype.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package portmappertype defines the net/portmapper interface, which may or may not be +// linked into the binary. +package portmappertype + +import ( + "context" + "errors" + "net/netip" + "time" + + "tailscale.com/feature" + "tailscale.com/net/netmon" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" +) + +// HookNewPortMapper is a hook to install the portmapper creation function. +// It must be set by an init function when buildfeatures.HasPortmapper is true. +var HookNewPortMapper feature.Hook[func(logf logger.Logf, + bus *eventbus.Bus, + netMon *netmon.Monitor, + disableUPnPOrNil, + onlyTCP443OrNil func() bool) Client] + +var ( + ErrNoPortMappingServices = errors.New("no port mapping services were found") + ErrGatewayRange = errors.New("skipping portmap; gateway range likely lacks support") + ErrGatewayIPv6 = errors.New("skipping portmap; no IPv6 support for portmapping") + ErrPortMappingDisabled = errors.New("port mapping is disabled") +) + +// ProbeResult is the result of a portmapper probe, saying +// which port mapping protocols were discovered. +type ProbeResult struct { + PCP bool + PMP bool + UPnP bool +} + +// Client is the interface implemented by a portmapper client. +type Client interface { + // Probe returns a summary of which port mapping services are available on + // the network. + // + // If a probe has run recently and there haven't been any network changes + // since, the returned result might be server from the Client's cache, + // without sending any network traffic. + Probe(context.Context) (ProbeResult, error) + + // HaveMapping reports whether we have a current valid mapping. + HaveMapping() bool + + // SetGatewayLookupFunc set the func that returns the machine's default + // gateway IP, and the primary IP address for that gateway. It must be + // called before the client is used. If not called, + // interfaces.LikelyHomeRouterIP is used. + SetGatewayLookupFunc(f func() (gw, myIP netip.Addr, ok bool)) + + // NoteNetworkDown should be called when the network has transitioned to a down state. + // It's too late to release port mappings at this point (the user might've just turned off + // their wifi), but we can make sure we invalidate mappings for later when the network + // comes back. + NoteNetworkDown() + + // GetCachedMappingOrStartCreatingOne quickly returns with our current cached portmapping, if any. + // If there's not one, it starts up a background goroutine to create one. + // If the background goroutine ends up creating one, the onChange hook registered with the + // NewClient constructor (if any) will fire. + GetCachedMappingOrStartCreatingOne() (external netip.AddrPort, ok bool) + + // SetLocalPort updates the local port number to which we want to port + // map UDP traffic + SetLocalPort(localPort uint16) + + Close() error +} + +// Mapping is an event recording the allocation of a port mapping. +type Mapping struct { + External netip.AddrPort + Type string + GoodUntil time.Time + + // TODO(creachadair): Record whether we reused an existing mapping? +} diff --git a/net/portmapper/select_test.go b/net/portmapper/select_test.go index 9e99c9a9d..af2e35cbf 100644 --- a/net/portmapper/select_test.go +++ b/net/portmapper/select_test.go @@ -28,7 +28,7 @@ func TestSelectBestService(t *testing.T) { } // Run a fake IGD server to respond to UPnP requests. - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -163,9 +163,8 @@ func TestSelectBestService(t *testing.T) { Desc: rootDesc, Control: tt.control, }) - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on upnp=%v", c.testUPnPPort) - defer c.Close() // Ensure that we're using the HTTP client that talks to our test IGD server ctx := context.Background() diff --git a/net/portmapper/upnp.go b/net/portmapper/upnp.go index f1199f0a6..d65d6e94d 100644 --- a/net/portmapper/upnp.go +++ b/net/portmapper/upnp.go @@ -209,7 +209,7 @@ func addAnyPortMapping( // The meta is the most recently parsed UDP discovery packet response // from the Internet Gateway Device. func getUPnPRootDevice(ctx context.Context, logf logger.Logf, debug DebugKnobs, gw netip.Addr, meta uPnPDiscoResponse) (rootDev *goupnp.RootDevice, loc *url.URL, err error) { - if debug.DisableUPnP { + if debug.DisableUPnP() { return nil, nil, nil } @@ -434,7 +434,7 @@ func (c *Client) getUPnPPortMapping( internal netip.AddrPort, prevPort uint16, ) (external netip.AddrPort, ok bool) { - if disableUPnpEnv() || c.debug.DisableUPnP || (c.controlKnobs != nil && c.controlKnobs.DisableUPnP.Load()) { + if disableUPnpEnv() || c.debug.DisableUPnP() { return netip.AddrPort{}, false } @@ -610,8 +610,9 @@ func (c *Client) tryUPnPPortmapWithDevice( } // From the UPnP spec: http://upnp.org/specs/gw/UPnP-gw-WANIPConnection-v2-Service.pdf + // 402: Invalid Args (see: https://github.com/tailscale/tailscale/issues/15223) // 725: OnlyPermanentLeasesSupported - if ok && code == 725 { + if ok && (code == 402 || code == 725) { newPort, err = addAnyPortMapping( ctx, client, @@ -620,7 +621,7 @@ func (c *Client) tryUPnPPortmapWithDevice( internal.Addr().String(), 0, // permanent ) - c.vlogf("addAnyPortMapping: 725 retry %v, err=%q", newPort, err) + c.vlogf("addAnyPortMapping: errcode=%d retried: port=%v err=%v", code, newPort, err) } } if err != nil { diff --git a/net/portmapper/upnp_test.go b/net/portmapper/upnp_test.go index c41b535a5..a954b2bea 100644 --- a/net/portmapper/upnp_test.go +++ b/net/portmapper/upnp_test.go @@ -18,6 +18,7 @@ import ( "sync/atomic" "testing" + "tailscale.com/net/portmapper/portmappertype" "tailscale.com/tstest" ) @@ -533,7 +534,7 @@ func TestGetUPnPClient(t *testing.T) { } func TestGetUPnPPortMapping(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -586,9 +587,8 @@ func TestGetUPnPPortMapping(t *testing.T) { }, }) - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on upnp=%v", c.testUPnPPort) - defer c.Close() c.debug.VerboseLogs = true @@ -628,13 +628,102 @@ func TestGetUPnPPortMapping(t *testing.T) { } } +func TestGetUPnPPortMapping_LeaseDuration(t *testing.T) { + testCases := []struct { + name string + resp string + }{ + {"only_permanent_leases", testAddPortMappingPermanentLease}, + {"invalid_args", testAddPortMappingPermanentLease_InvalidArgs}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + // This is a very basic fake UPnP server handler. + var sawRequestWithLease atomic.Bool + handlers := map[string]any{ + "AddPortMapping": func(body []byte) (int, string) { + // Decode a minimal body to determine whether we skip the request or not. + var req struct { + Protocol string `xml:"NewProtocol"` + InternalPort string `xml:"NewInternalPort"` + ExternalPort string `xml:"NewExternalPort"` + InternalClient string `xml:"NewInternalClient"` + LeaseDuration string `xml:"NewLeaseDuration"` + } + if err := xml.Unmarshal(body, &req); err != nil { + t.Errorf("bad request: %v", err) + return http.StatusBadRequest, "bad request" + } + + if req.Protocol != "UDP" { + t.Errorf(`got Protocol=%q, want "UDP"`, req.Protocol) + } + if req.LeaseDuration != "0" { + // Return a fake error to ensure that we fall back to a permanent lease. + sawRequestWithLease.Store(true) + return http.StatusOK, tc.resp + } + + return http.StatusOK, testAddPortMappingResponse + }, + "GetExternalIPAddress": testGetExternalIPAddressResponse, + "GetStatusInfo": testGetStatusInfoResponse, + "DeletePortMapping": "", // Do nothing for test + } + + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) + if err != nil { + t.Fatal(err) + } + defer igd.Close() + + igd.SetUPnPHandler(&upnpServer{ + t: t, + Desc: testRootDesc, + Control: map[string]map[string]any{ + "/ctl/IPConn": handlers, + "/upnp/control/yomkmsnooi/wanipconn-1": handlers, + }, + }) + + ctx := context.Background() + c := newTestClient(t, igd, nil) + c.debug.VerboseLogs = true + t.Logf("Listening on upnp=%v", c.testUPnPPort) + + // Actually test the UPnP port mapping. + mustProbeUPnP(t, ctx, c) + + gw, myIP, ok := c.gatewayAndSelfIP() + if !ok { + t.Fatalf("could not get gateway and self IP") + } + t.Logf("gw=%v myIP=%v", gw, myIP) + + ext, ok := c.getUPnPPortMapping(ctx, gw, netip.AddrPortFrom(myIP, 12345), 0) + if !ok { + t.Fatal("could not get UPnP port mapping") + } + if got, want := ext.Addr(), netip.MustParseAddr("123.123.123.123"); got != want { + t.Errorf("bad external address; got %v want %v", got, want) + } + if !sawRequestWithLease.Load() { + t.Errorf("wanted request with lease, but didn't see one") + } + t.Logf("external IP: %v", ext) + }) + } +} + // TestGetUPnPPortMapping_NoValidServices tests that getUPnPPortMapping doesn't // crash when a valid UPnP response with no supported services is discovered // and parsed. // // See https://github.com/tailscale/tailscale/issues/10911 func TestGetUPnPPortMapping_NoValidServices(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -645,8 +734,7 @@ func TestGetUPnPPortMapping_NoValidServices(t *testing.T) { Desc: noSupportedServicesRootDesc, }) - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) c.debug.VerboseLogs = true ctx := context.Background() @@ -666,7 +754,7 @@ func TestGetUPnPPortMapping_NoValidServices(t *testing.T) { // Tests the legacy behaviour with the pre-UPnP standard portmapping service. func TestGetUPnPPortMapping_Legacy(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -688,8 +776,7 @@ func TestGetUPnPPortMapping_Legacy(t *testing.T) { }, }) - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) c.debug.VerboseLogs = true ctx := context.Background() @@ -710,15 +797,14 @@ func TestGetUPnPPortMapping_Legacy(t *testing.T) { } func TestGetUPnPPortMappingNoResponses(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } defer igd.Close() - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on upnp=%v", c.testUPnPPort) - defer c.Close() c.debug.VerboseLogs = true @@ -827,7 +913,7 @@ func TestGetUPnPPortMapping_Invalid(t *testing.T) { "127.0.0.1", } { t.Run(responseAddr, func(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -849,8 +935,7 @@ func TestGetUPnPPortMapping_Invalid(t *testing.T) { }, }) - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) c.debug.VerboseLogs = true ctx := context.Background() @@ -955,7 +1040,7 @@ func (u *upnpServer) handleControl(w http.ResponseWriter, r *http.Request, handl } } -func mustProbeUPnP(tb testing.TB, ctx context.Context, c *Client) ProbeResult { +func mustProbeUPnP(tb testing.TB, ctx context.Context, c *Client) portmappertype.ProbeResult { tb.Helper() res, err := c.Probe(ctx) if err != nil { @@ -1045,6 +1130,23 @@ const testAddPortMappingPermanentLease = ` ` +const testAddPortMappingPermanentLease_InvalidArgs = ` + + + + SOAP:Client + UPnPError + + + 402 + Invalid Args + + + + + +` + const testAddPortMappingResponse = ` diff --git a/net/routetable/routetable_linux.go b/net/routetable/routetable_linux.go index 88dc8535a..0b2cb305d 100644 --- a/net/routetable/routetable_linux.go +++ b/net/routetable/routetable_linux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android package routetable diff --git a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go index 35c83e374..e547ab0ac 100644 --- a/net/routetable/routetable_other.go +++ b/net/routetable/routetable_other.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux && !darwin && !freebsd +//go:build android || (!linux && !darwin && !freebsd) package routetable diff --git a/net/sockopts/sockopts.go b/net/sockopts/sockopts.go new file mode 100644 index 000000000..0c0ee7692 --- /dev/null +++ b/net/sockopts/sockopts.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package sockopts contains logic for applying socket options. +package sockopts + +import ( + "net" + "runtime" + + "tailscale.com/types/nettype" +) + +// BufferDirection represents either the read/receive or write/send direction +// of a socket buffer. +type BufferDirection string + +const ( + ReadDirection BufferDirection = "read" + WriteDirection BufferDirection = "write" +) + +func portableSetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) error { + if runtime.GOOS == "plan9" { + // Not supported. Don't try. Avoid logspam. + return nil + } + var err error + if c, ok := pconn.(*net.UDPConn); ok { + if direction == WriteDirection { + err = c.SetWriteBuffer(size) + } else { + err = c.SetReadBuffer(size) + } + } + return err +} diff --git a/net/sockopts/sockopts_default.go b/net/sockopts/sockopts_default.go new file mode 100644 index 000000000..3cc8679b5 --- /dev/null +++ b/net/sockopts/sockopts_default.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package sockopts + +import ( + "tailscale.com/types/nettype" +) + +// SetBufferSize sets pconn's buffer to size for direction. size may be silently +// capped depending on platform. +// +// errForce is only relevant for Linux, and will always be nil otherwise, +// but we maintain a consistent cross-platform API. +// +// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op. +func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) { + return nil, portableSetBufferSize(pconn, direction, size) +} diff --git a/net/sockopts/sockopts_linux.go b/net/sockopts/sockopts_linux.go new file mode 100644 index 000000000..5d778d380 --- /dev/null +++ b/net/sockopts/sockopts_linux.go @@ -0,0 +1,40 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package sockopts + +import ( + "net" + "syscall" + + "tailscale.com/types/nettype" +) + +// SetBufferSize sets pconn's buffer to size for direction. It attempts +// (errForce) to set SO_SNDBUFFORCE or SO_RECVBUFFORCE which can overcome the +// limit of net.core.{r,w}mem_max, but require CAP_NET_ADMIN. It falls back to +// the portable implementation (errPortable) if that fails, which may be +// silently capped to net.core.{r,w}mem_max. +// +// If pconn is not a [*net.UDPConn], then SetBufferSize is no-op. +func SetBufferSize(pconn nettype.PacketConn, direction BufferDirection, size int) (errForce error, errPortable error) { + opt := syscall.SO_RCVBUFFORCE + if direction == WriteDirection { + opt = syscall.SO_SNDBUFFORCE + } + if c, ok := pconn.(*net.UDPConn); ok { + var rc syscall.RawConn + rc, errForce = c.SyscallConn() + if errForce == nil { + rc.Control(func(fd uintptr) { + errForce = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, opt, size) + }) + } + if errForce != nil { + errPortable = portableSetBufferSize(pconn, direction, size) + } + } + return errForce, errPortable +} diff --git a/wgengine/magicsock/magicsock_notwindows.go b/net/sockopts/sockopts_notwindows.go similarity index 52% rename from wgengine/magicsock/magicsock_notwindows.go rename to net/sockopts/sockopts_notwindows.go index 7c31c8202..f1bc7fd44 100644 --- a/wgengine/magicsock/magicsock_notwindows.go +++ b/net/sockopts/sockopts_notwindows.go @@ -3,11 +3,13 @@ //go:build !windows -package magicsock +package sockopts import ( - "tailscale.com/types/logger" "tailscale.com/types/nettype" ) -func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) {} +// SetICMPErrImmunity is no-op on non-Windows. +func SetICMPErrImmunity(pconn nettype.PacketConn) error { + return nil +} diff --git a/wgengine/magicsock/magicsock_unix_test.go b/net/sockopts/sockopts_unix_test.go similarity index 87% rename from wgengine/magicsock/magicsock_unix_test.go rename to net/sockopts/sockopts_unix_test.go index b0700a8eb..ebb4354ac 100644 --- a/wgengine/magicsock/magicsock_unix_test.go +++ b/net/sockopts/sockopts_unix_test.go @@ -3,7 +3,7 @@ //go:build unix -package magicsock +package sockopts import ( "net" @@ -13,7 +13,7 @@ import ( "tailscale.com/types/nettype" ) -func TestTrySetSocketBuffer(t *testing.T) { +func TestSetBufferSize(t *testing.T) { c, err := net.ListenPacket("udp", ":0") if err != nil { t.Fatal(err) @@ -42,7 +42,8 @@ func TestTrySetSocketBuffer(t *testing.T) { curRcv, curSnd := getBufs() - trySetSocketBuffer(c.(nettype.PacketConn), t.Logf) + SetBufferSize(c.(nettype.PacketConn), ReadDirection, 7<<20) + SetBufferSize(c.(nettype.PacketConn), WriteDirection, 7<<20) newRcv, newSnd := getBufs() diff --git a/wgengine/magicsock/magicsock_windows.go b/net/sockopts/sockopts_windows.go similarity index 67% rename from wgengine/magicsock/magicsock_windows.go rename to net/sockopts/sockopts_windows.go index fe2a80e0b..1e6c3f69d 100644 --- a/wgengine/magicsock/magicsock_windows.go +++ b/net/sockopts/sockopts_windows.go @@ -3,28 +3,31 @@ //go:build windows -package magicsock +package sockopts import ( + "fmt" "net" "unsafe" "golang.org/x/sys/windows" - "tailscale.com/types/logger" "tailscale.com/types/nettype" ) -func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { +// SetICMPErrImmunity sets socket options on pconn to prevent ICMP reception, +// e.g. ICMP Port Unreachable, from surfacing as a syscall error. +// +// If pconn is not a [*net.UDPConn], then SetICMPErrImmunity is no-op. +func SetICMPErrImmunity(pconn nettype.PacketConn) error { c, ok := pconn.(*net.UDPConn) if !ok { // not a UDP connection; nothing to do - return + return nil } sysConn, err := c.SyscallConn() if err != nil { - logf("trySetUDPSocketOptions: getting SyscallConn failed: %v", err) - return + return fmt.Errorf("SetICMPErrImmunity: getting SyscallConn failed: %v", err) } // Similar to https://github.com/golang/go/issues/5834 (which involved @@ -50,9 +53,10 @@ func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { ) }) if ioctlErr != nil { - logf("trySetUDPSocketOptions: could not set SIO_UDP_NETRESET: %v", ioctlErr) + return fmt.Errorf("SetICMPErrImmunity: could not set SIO_UDP_NETRESET: %v", ioctlErr) } if err != nil { - logf("trySetUDPSocketOptions: SyscallConn.Control failed: %v", err) + return fmt.Errorf("SetICMPErrImmunity: SyscallConn.Control failed: %v", err) } + return nil } diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index 0d651537f..2e277147b 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -81,6 +81,12 @@ const ( addrTypeNotSupported replyCode = 8 ) +// UDP conn default buffer size and read timeout. +const ( + bufferSize = 8 * 1024 + readTimeout = 5 * time.Second +) + // Server is a SOCKS5 proxy server. type Server struct { // Logf optionally specifies the logger to use. @@ -114,10 +120,10 @@ func (s *Server) logf(format string, args ...any) { } // Serve accepts and handles incoming connections on the given listener. -func (s *Server) Serve(l net.Listener) error { - defer l.Close() +func (s *Server) Serve(ln net.Listener) error { + defer ln.Close() for { - c, err := l.Accept() + c, err := ln.Accept() if err != nil { return err } @@ -143,7 +149,8 @@ type Conn struct { clientConn net.Conn request *request - udpClientAddr net.Addr + udpClientAddr net.Addr + udpTargetConns map[socksAddr]net.Conn } // Run starts the new connection. @@ -276,15 +283,6 @@ func (c *Conn) handleUDP() error { } defer clientUDPConn.Close() - serverUDPConn, err := net.ListenPacket("udp", "[::]:0") - if err != nil { - res := errorResponse(generalFailure) - buf, _ := res.marshal() - c.clientConn.Write(buf) - return err - } - defer serverUDPConn.Close() - bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) if err != nil { return err @@ -305,25 +303,32 @@ func (c *Conn) handleUDP() error { } c.clientConn.Write(buf) - return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) + return c.transferUDP(c.clientConn, clientUDPConn) } -func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { +func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - const bufferSize = 8 * 1024 - const readTimeout = 5 * time.Second // client -> target go func() { defer cancel() + + c.udpTargetConns = make(map[socksAddr]net.Conn) + // close all target udp connections when the client connection is closed + defer func() { + for _, conn := range c.udpTargetConns { + _ = conn.Close() + } + }() + buf := make([]byte, bufferSize) for { select { case <-ctx.Done(): return default: - err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) + err := c.handleUDPRequest(ctx, clientConn, buf) if err != nil { if isTimeout(err) { continue @@ -337,21 +342,44 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. RFC1928 + _, err := io.Copy(io.Discard, associatedTCP) + if err != nil { + err = fmt.Errorf("udp associated tcp conn: %w", err) + } + return err +} + +func (c *Conn) getOrDialTargetConn( + ctx context.Context, + clientConn net.PacketConn, + targetAddr socksAddr, +) (net.Conn, error) { + conn, exist := c.udpTargetConns[targetAddr] + if exist { + return conn, nil + } + conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort()) + if err != nil { + return nil, err + } + c.udpTargetConns[targetAddr] = conn + // target -> client go func() { - defer cancel() buf := make([]byte, bufferSize) for { select { case <-ctx.Done(): return default: - err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout) + err := c.handleUDPResponse(clientConn, targetAddr, conn, buf) if err != nil { if isTimeout(err) { continue } - if errors.Is(err, net.ErrClosed) { + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { return } c.logf("udp transfer: handle udp response fail: %v", err) @@ -360,20 +388,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() - // A UDP association terminates when the TCP connection that the UDP - // ASSOCIATE request arrived on terminates. RFC1928 - _, err := io.Copy(io.Discard, associatedTCP) - if err != nil { - err = fmt.Errorf("udp associated tcp conn: %w", err) - } - return err + return conn, nil } func (c *Conn) handleUDPRequest( + ctx context.Context, clientConn net.PacketConn, - targetConn net.PacketConn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) @@ -386,38 +407,35 @@ func (c *Conn) handleUDPRequest( if err != nil { return fmt.Errorf("parse udp request: %w", err) } - targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort()) + + targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr) if err != nil { - c.logf("resolve target addr fail: %v", err) + return fmt.Errorf("dial target %s fail: %w", req.addr, err) } - nn, err := targetConn.WriteTo(data, targetAddr) + nn, err := targetConn.Write(data) if err != nil { - return fmt.Errorf("write to target %s fail: %w", targetAddr, err) + return fmt.Errorf("write to target %s fail: %w", req.addr, err) } if nn != len(data) { - return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) + return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite) } return nil } func (c *Conn) handleUDPResponse( - targetConn net.PacketConn, clientConn net.PacketConn, + targetAddr socksAddr, + targetConn net.Conn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) - n, addr, err := targetConn.ReadFrom(buf) + n, err := targetConn.Read(buf) if err != nil { return fmt.Errorf("read from target: %w", err) } - host, port, err := splitHostPort(addr.String()) - if err != nil { - return fmt.Errorf("split host port: %w", err) - } - hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}} + hdr := udpRequest{addr: targetAddr} pkt, err := hdr.marshal() if err != nil { return fmt.Errorf("marshal udp request: %w", err) @@ -627,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) { pkt = binary.BigEndian.AppendUint16(pkt, s.port) return pkt, nil } + func (s socksAddr) hostPort() string { return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) } +func (s socksAddr) String() string { + return s.hostPort() +} + // response contains the contents of // a response packet sent from the proxy // to the client. diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index 11ea59d4b..bc6fac79f 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) { func TestUDP(t *testing.T) { // backend UDP server which we'll use SOCKS5 to connect to - listener, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) + newUDPEchoServer := func() net.PacketConn { + listener, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + go udpEchoServer(listener) + return listener } - backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port - go udpEchoServer(listener) + + const echoServerNumber = 3 + echoServerListener := make([]net.PacketConn, echoServerNumber) + for i := 0; i < echoServerNumber; i++ { + echoServerListener[i] = newUDPEchoServer() + } + defer func() { + for i := 0; i < echoServerNumber; i++ { + _ = echoServerListener[i].Close() + } + }() // SOCKS5 server socks5, err := net.Listen("tcp", ":0") @@ -184,84 +197,93 @@ func TestUDP(t *testing.T) { socks5Port := socks5.Addr().(*net.TCPAddr).Port go socks5Server(socks5) - // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) - if err != nil { - t.Fatal(err) - } - _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth - if err != nil { - t.Fatal(err) - } - buf := make([]byte, 1024) - n, err := conn.Read(buf) // server hello - if err != nil { - t.Fatal(err) - } - if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 { - t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) - } + // make a socks5 udpAssociate conn + newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) { + // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) + if err != nil { + t.Fatal(err) + } + _, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) // server hello + if err != nil { + t.Fatal(err) + } + if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired { + t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) + } - targetAddr := socksAddr{ - addrType: domainName, - addr: "localhost", - port: uint16(backendServerPort), - } - targetAddrPkt, err := targetAddr.marshal() - if err != nil { - t.Fatal(err) - } - _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust - if err != nil { - t.Fatal(err) - } + targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} + targetAddrPkt, err := targetAddr.marshal() + if err != nil { + t.Fatal(err) + } + _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust + if err != nil { + t.Fatal(err) + } - n, err = conn.Read(buf) // server response - if err != nil { - t.Fatal(err) - } - if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) { - t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + n, err = conn.Read(buf) // server response + if err != nil { + t.Fatal(err) + } + if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) { + t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + } + udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) + if err != nil { + t.Fatal(err) + } + + return conn, udpProxySocksAddr } - udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) - if err != nil { - t.Fatal(err) + + conn, udpProxySocksAddr := newUdpAssociateConn() + defer conn.Close() + + sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) { + udpPayload, err := (&udpRequest{addr: addr}).marshal() + if err != nil { + t.Fatal(err) + } + udpPayload = append(udpPayload, body...) + _, err = socks5UDPConn.Write(udpPayload) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := socks5UDPConn.Read(buf) + if err != nil { + t.Fatal(err) + } + _, responseBody, err = parseUDPRequest(buf[:n]) + if err != nil { + t.Fatal(err) + } + return responseBody } udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort()) if err != nil { t.Fatal(err) } - udpConn, err := net.DialUDP("udp", nil, udpProxyAddr) - if err != nil { - t.Fatal(err) - } - udpPayload, err := (&udpRequest{addr: targetAddr}).marshal() - if err != nil { - t.Fatal(err) - } - udpPayload = append(udpPayload, []byte("Test")...) - _, err = udpConn.Write(udpPayload) // send udp package - if err != nil { - t.Fatal(err) - } - n, _, err = udpConn.ReadFrom(buf) - if err != nil { - t.Fatal(err) - } - _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response - if err != nil { - t.Fatal(err) - } - if string(responseBody) != "Test" { - t.Fatalf("got: %q want: Test", responseBody) - } - err = udpConn.Close() + socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr) if err != nil { t.Fatal(err) } - err = conn.Close() - if err != nil { - t.Fatal(err) + defer socks5UDPConn.Close() + + for i := 0; i < echoServerNumber; i++ { + port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port + addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)} + requestBody := []byte(fmt.Sprintf("Test %d", i)) + responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody) + if !bytes.Equal(requestBody, responseBody) { + t.Fatalf("got: %q want: %q", responseBody, requestBody) + } } } diff --git a/net/sockstats/label_string.go b/net/sockstats/label_string.go index f9a111ad7..cc503d943 100644 --- a/net/sockstats/label_string.go +++ b/net/sockstats/label_string.go @@ -28,8 +28,9 @@ const _Label_name = "ControlClientAutoControlClientDialerDERPHTTPClientLogtailLo var _Label_index = [...]uint8{0, 17, 36, 50, 63, 78, 93, 107, 123, 140, 157, 169, 186, 201} func (i Label) String() string { - if i >= Label(len(_Label_index)-1) { + idx := int(i) - 0 + if i < 0 || idx >= len(_Label_index)-1 { return "Label(" + strconv.FormatInt(int64(i), 10) + ")" } - return _Label_name[_Label_index[i]:_Label_index[i+1]] + return _Label_name[_Label_index[idx]:_Label_index[idx+1]] } diff --git a/net/sockstats/sockstats_tsgo.go b/net/sockstats/sockstats_tsgo.go index 2d1ccd5a3..aa875df9a 100644 --- a/net/sockstats/sockstats_tsgo.go +++ b/net/sockstats/sockstats_tsgo.go @@ -10,14 +10,15 @@ import ( "fmt" "net" "strings" - "sync" "sync/atomic" "syscall" "time" "tailscale.com/net/netmon" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" + "tailscale.com/version" ) const IsAvailable = true @@ -39,7 +40,7 @@ var sockStats = struct { // mu protects fields in this group (but not the fields within // sockStatCounters). It should not be held in the per-read/write // callbacks. - mu sync.Mutex + mu syncs.Mutex countersByLabel map[Label]*sockStatCounters knownInterfaces map[int]string // interface index -> name usedInterfaces map[int]int // set of interface indexes @@ -156,7 +157,11 @@ func withSockStats(ctx context.Context, label Label, logf logger.Logf) context.C } } willOverwrite := func(trace *net.SockTrace) { - logf("sockstats: trace %q was overwritten by another", label) + if version.IsUnstableBuild() { + // Only spam about this in dev builds. + // See https://github.com/tailscale/tailscale/issues/13731 for known problems. + logf("sockstats: trace %q was overwritten by another", label) + } } return net.WithSockTrace(ctx, &net.SockTrace{ @@ -274,7 +279,13 @@ func setNetMon(netMon *netmon.Monitor) { if ifName == "" { return } - ifIndex := state.Interface[ifName].Index + // DefaultRouteInterface and Interface are gathered at different points in time. + // Check for existence first, to avoid a nil pointer dereference. + iface, ok := state.Interface[ifName] + if !ok { + return + } + ifIndex := iface.Index sockStats.mu.Lock() defer sockStats.mu.Unlock() // Ignore changes to unknown interfaces -- it would require diff --git a/net/speedtest/speedtest.go b/net/speedtest/speedtest.go index 7ab0881cc..a462dbeec 100644 --- a/net/speedtest/speedtest.go +++ b/net/speedtest/speedtest.go @@ -24,7 +24,7 @@ const ( // conduct the test. type config struct { Version int `json:"version"` - TestDuration time.Duration `json:"time"` + TestDuration time.Duration `json:"time,format:nano"` Direction Direction `json:"direction"` } diff --git a/net/speedtest/speedtest_server.go b/net/speedtest/speedtest_server.go index 9dd78b195..72f85fa15 100644 --- a/net/speedtest/speedtest_server.go +++ b/net/speedtest/speedtest_server.go @@ -17,9 +17,9 @@ import ( // connections and handles each one in a goroutine. Because it runs in an infinite loop, // this function only returns if any of the speedtests return with errors, or if the // listener is closed. -func Serve(l net.Listener) error { +func Serve(ln net.Listener) error { for { - conn, err := l.Accept() + conn, err := ln.Accept() if errors.Is(err, net.ErrClosed) { return nil } diff --git a/net/speedtest/speedtest_test.go b/net/speedtest/speedtest_test.go index 55dcbeea1..bb8f2676a 100644 --- a/net/speedtest/speedtest_test.go +++ b/net/speedtest/speedtest_test.go @@ -4,20 +4,30 @@ package speedtest import ( + "flag" "net" "testing" "time" + + "tailscale.com/cmd/testwrapper/flakytest" ) +var manualTest = flag.Bool("do-speedtest", false, "if true, run the speedtest TestDownload test. Otherwise skip it because it's slow and flaky; see https://github.com/tailscale/tailscale/issues/17338") + func TestDownload(t *testing.T) { + if !*manualTest { + t.Skip("skipping slow test without --do-speedtest") + } + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/17338") + // start a listener and find the port where the server will be listening. - l, err := net.Listen("tcp", ":0") + ln, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) } - t.Cleanup(func() { l.Close() }) + t.Cleanup(func() { ln.Close() }) - serverIP := l.Addr().String() + serverIP := ln.Addr().String() t.Log("server IP found:", serverIP) type state struct { @@ -30,7 +40,7 @@ func TestDownload(t *testing.T) { stateChan := make(chan state, 1) go func() { - err := Serve(l) + err := Serve(ln) stateChan <- state{err: err} }() @@ -74,7 +84,7 @@ func TestDownload(t *testing.T) { }) // causes the server goroutine to finish - l.Close() + ln.Close() testState := <-stateChan if testState.err != nil { diff --git a/net/tlsdial/blockblame/blockblame.go b/net/tlsdial/blockblame/blockblame.go new file mode 100644 index 000000000..5b48dc009 --- /dev/null +++ b/net/tlsdial/blockblame/blockblame.go @@ -0,0 +1,120 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package blockblame blames specific firewall manufacturers for blocking Tailscale, +// by analyzing the SSL certificate presented when attempting to connect to a remote +// server. +package blockblame + +import ( + "crypto/x509" + "strings" + "sync" + + "tailscale.com/feature/buildfeatures" +) + +// VerifyCertificate checks if the given certificate c is issued by a firewall manufacturer +// that is known to block Tailscale connections. It returns true and the Manufacturer of +// the equipment if it is, or false and nil if it is not. +func VerifyCertificate(c *x509.Certificate) (m *Manufacturer, ok bool) { + if !buildfeatures.HasDebug { + return nil, false + } + for _, m := range manufacturers() { + if m.match != nil && m.match(c) { + return m, true + } + } + return nil, false +} + +// Manufacturer represents a firewall manufacturer that may be blocking Tailscale. +type Manufacturer struct { + // Name is the name of the firewall manufacturer to be + // mentioned in health warning messages, e.g. "Fortinet". + Name string + // match is a function that returns true if the given certificate looks like it might + // be issued by this manufacturer. + match matchFunc +} + +func manufacturers() []*Manufacturer { + manufacturersOnce.Do(func() { + manufacturersList = []*Manufacturer{ + { + Name: "Aruba Networks", + match: issuerContains("Aruba"), + }, + { + Name: "Cisco", + match: issuerContains("Cisco"), + }, + { + Name: "Fortinet", + match: matchAny( + issuerContains("Fortinet"), + certEmail("support@fortinet.com"), + ), + }, + { + Name: "Huawei", + match: certEmail("mobile@huawei.com"), + }, + { + Name: "Palo Alto Networks", + match: matchAny( + issuerContains("Palo Alto Networks"), + issuerContains("PAN-FW"), + ), + }, + { + Name: "Sophos", + match: issuerContains("Sophos"), + }, + { + Name: "Ubiquiti", + match: matchAny( + issuerContains("UniFi"), + issuerContains("Ubiquiti"), + ), + }, + } + }) + return manufacturersList +} + +var ( + manufacturersOnce sync.Once + manufacturersList []*Manufacturer +) + +type matchFunc func(*x509.Certificate) bool + +func issuerContains(s string) matchFunc { + return func(c *x509.Certificate) bool { + return strings.Contains(strings.ToLower(c.Issuer.String()), strings.ToLower(s)) + } +} + +func certEmail(v string) matchFunc { + return func(c *x509.Certificate) bool { + for _, email := range c.EmailAddresses { + if strings.Contains(strings.ToLower(email), strings.ToLower(v)) { + return true + } + } + return false + } +} + +func matchAny(fs ...matchFunc) matchFunc { + return func(c *x509.Certificate) bool { + for _, f := range fs { + if f(c) { + return true + } + } + return false + } +} diff --git a/net/tlsdial/blockblame/blockblame_test.go b/net/tlsdial/blockblame/blockblame_test.go new file mode 100644 index 000000000..6d3592c60 --- /dev/null +++ b/net/tlsdial/blockblame/blockblame_test.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package blockblame + +import ( + "crypto/x509" + "encoding/pem" + "testing" +) + +const controlplaneDotTailscaleDotComPEM = ` +-----BEGIN CERTIFICATE----- +MIIDkzCCAxqgAwIBAgISA2GOahsftpp59yuHClbDuoduMAoGCCqGSM49BAMDMDIx +CzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MQswCQYDVQQDEwJF +NjAeFw0yNDEwMTIxNjE2NDVaFw0yNTAxMTAxNjE2NDRaMCUxIzAhBgNVBAMTGmNv +bnRyb2xwbGFuZS50YWlsc2NhbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD +QgAExfraDUc1t185zuGtZlnPDtEJJSDBqvHN4vQcXSzSTPSAdDYHcA8fL5woU2Kg +jK/2C0wm/rYy2Rre/ulhkS4wB6OCAhswggIXMA4GA1UdDwEB/wQEAwIHgDAdBgNV +HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E +FgQUpArnpDj8Yh6NTgMOZjDPx0TuLmcwHwYDVR0jBBgwFoAUkydGmAOpUWiOmNbE +QkjbI79YlNIwVQYIKwYBBQUHAQEESTBHMCEGCCsGAQUFBzABhhVodHRwOi8vZTYu +by5sZW5jci5vcmcwIgYIKwYBBQUHMAKGFmh0dHA6Ly9lNi5pLmxlbmNyLm9yZy8w +JQYDVR0RBB4wHIIaY29udHJvbHBsYW5lLnRhaWxzY2FsZS5jb20wEwYDVR0gBAww +CjAIBgZngQwBAgEwggEDBgorBgEEAdZ5AgQCBIH0BIHxAO8AdgDgkrP8DB3I52g2 +H95huZZNClJ4GYpy1nLEsE2lbW9UBAAAAZKBujCyAAAEAwBHMEUCIQDHMgUaL4H9 +ZJa090ZOpBeEVu3+t+EF4HlHI1NqAai6uQIgeY/lLfjAXfcVgxBHHR4zjd0SzhaP +TREHXzwxzN/8blkAdQDPEVbu1S58r/OHW9lpLpvpGnFnSrAX7KwB0lt3zsw7CAAA +AZKBujh8AAAEAwBGMEQCICQwhMk45t9aiFjfwOC/y6+hDbszqSCpIv63kFElweUy +AiAqTdkqmbqUVpnav5JdWkNERVAIlY4jqrThLsCLZYbNszAKBggqhkjOPQQDAwNn +ADBkAjALyfgAt1XQp1uSfxy4GapR5OsmjEMBRVq6IgsPBlCRBfmf0Q3/a6mF0pjb +Sj4oa+cCMEhZk4DmBTIdZY9zjuh8s7bXNfKxUQS0pEhALtXqyFr+D5dF7JcQo9+s +Z98JY7/PCA== +-----END CERTIFICATE-----` + +func TestVerifyCertificateOurControlPlane(t *testing.T) { + p, _ := pem.Decode([]byte(controlplaneDotTailscaleDotComPEM)) + if p == nil { + t.Fatalf("failed to extract certificate bytes for controlplane.tailscale.com") + return + } + cert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + return + } + m, found := VerifyCertificate(cert) + if found { + t.Fatalf("expected to not get a result for the controlplane.tailscale.com certificate") + } + if m != nil { + t.Fatalf("expected nil manufacturer for controlplane.tailscale.com certificate") + } +} diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index a49e7f0f7..ee4771d8d 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -12,6 +12,7 @@ package tlsdial import ( "bytes" "context" + "crypto/sha256" "crypto/tls" "crypto/x509" "errors" @@ -20,23 +21,22 @@ import ( "net" "net/http" "os" + "strings" "sync" "sync/atomic" "time" + "tailscale.com/derp/derpconst" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/hostinfo" + "tailscale.com/net/bakedroots" + "tailscale.com/net/tlsdial/blockblame" ) var counterFallbackOK int32 // atomic -// If SSLKEYLOGFILE is set, it's a file to which we write our TLS private keys -// in a way that WireShark can read. -// -// See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format -var sslKeyLogFile = os.Getenv("SSLKEYLOGFILE") - var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL") // tlsdialWarningPrinted tracks whether we've printed a warning about a given @@ -44,26 +44,50 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL") // Headscale, etc. var tlsdialWarningPrinted sync.Map // map[string]bool -// Config returns a tls.Config for connecting to a server. +var mitmBlockWarnable = health.Register(&health.Warnable{ + Code: "blockblame-mitm-detected", + Title: "Network may be blocking Tailscale", + Text: func(args health.Args) string { + return fmt.Sprintf("Network equipment from %q may be blocking Tailscale traffic on this network. Connect to another network, or contact your network administrator for assistance.", args["manufacturer"]) + }, + Severity: health.SeverityMedium, + ImpactsConnectivity: true, +}) + +// Config returns a tls.Config for connecting to a server that +// uses system roots for validation but, if those fail, also tries +// the baked-in LetsEncrypt roots as a fallback validation method. +// // If base is non-nil, it's cloned as the base config before // being configured and returned. // If ht is non-nil, it's used to report health errors. -func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { +func Config(ht *health.Tracker, base *tls.Config) *tls.Config { var conf *tls.Config if base == nil { conf = new(tls.Config) } else { conf = base.Clone() } - conf.ServerName = host - if n := sslKeyLogFile; n != "" { - f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - log.Fatal(err) + // Note: we do NOT set conf.ServerName here (as we accidentally did + // previously), as this path is also used when dialing an HTTPS proxy server + // (through which we'll send a CONNECT request to get a TCP connection to do + // the real TCP connection) because host is the ultimate hostname, but this + // tls.Config is used for both the proxy and the ultimate target. + + if buildfeatures.HasDebug { + // If SSLKEYLOGFILE is set, it's a file to which we write our TLS private keys + // in a way that WireShark can read. + // + // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format + if n := os.Getenv("SSLKEYLOGFILE"); n != "" { + f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + log.Fatal(err) + } + log.Printf("WARNING: writing to SSLKEYLOGFILE %v", n) + conf.KeyLogWriter = f } - log.Printf("WARNING: writing to SSLKEYLOGFILE %v", n) - conf.KeyLogWriter = f } if conf.InsecureSkipVerify { @@ -78,20 +102,39 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // (with the baked-in fallback root) in the VerifyConnection hook. conf.InsecureSkipVerify = true conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) { - if host == "log.tailscale.io" && hostinfo.IsNATLabGuestVM() { - // Allow log.tailscale.io TLS MITM for integration tests when + dialedHost := cs.ServerName + + if dialedHost == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() { + // Allow log.tailscale.com TLS MITM for integration tests when // the client's running within a NATLab VM. return nil } // Perform some health checks on this certificate before we do // any verification. + var cert *x509.Certificate var selfSignedIssuer string - if certs := cs.PeerCertificates; len(certs) > 0 && certIsSelfSigned(certs[0]) { - selfSignedIssuer = certs[0].Issuer.String() + if certs := cs.PeerCertificates; len(certs) > 0 { + cert = certs[0] + if certIsSelfSigned(cert) { + selfSignedIssuer = cert.Issuer.String() + } } if ht != nil { defer func() { + if retErr != nil && cert != nil { + // Is it a MITM SSL certificate from a well-known network appliance manufacturer? + // Show a dedicated warning. + m, ok := blockblame.VerifyCertificate(cert) + if ok { + log.Printf("tlsdial: server cert seen while dialing %q looks like %q equipment (could be blocking Tailscale)", dialedHost, m.Name) + ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name}) + } else { + ht.SetHealthy(mitmBlockWarnable) + } + } else { + ht.SetHealthy(mitmBlockWarnable) + } if retErr != nil && selfSignedIssuer != "" { // Self-signed certs are never valid. // @@ -103,7 +146,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { ht.SetTLSConnectionError(cs.ServerName, nil) if selfSignedIssuer != "" { // Log the self-signed issuer, but don't treat it as an error. - log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", host, selfSignedIssuer) + log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", dialedHost, selfSignedIssuer) } } }() @@ -112,7 +155,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // First try doing x509 verification with the system's // root CA pool. opts := x509.VerifyOptions{ - DNSName: cs.ServerName, + DNSName: dialedHost, Intermediates: x509.NewCertPool(), } for _, cert := range cs.PeerCertificates[1:] { @@ -120,22 +163,22 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { } _, errSys := cs.PeerCertificates[0].Verify(opts) if debug() { - log.Printf("tlsdial(sys %q): %v", host, errSys) + log.Printf("tlsdial(sys %q): %v", dialedHost, errSys) + } + if !buildfeatures.HasBakedRoots || (errSys == nil && !debug()) { + return errSys } - // Always verify with our baked-in Let's Encrypt certificate, - // so we can log an informational message. This is useful for - // detecting SSL MiTM. - opts.Roots = bakedInRoots() + // If we have baked-in LetsEncrypt roots and we either failed above, or + // debug logging is enabled, also verify with LetsEncrypt. + opts.Roots = bakedroots.Get() _, bakedErr := cs.PeerCertificates[0].Verify(opts) if debug() { - log.Printf("tlsdial(bake %q): %v", host, bakedErr) + log.Printf("tlsdial(bake %q): %v", dialedHost, bakedErr) } else if bakedErr != nil { - if _, loaded := tlsdialWarningPrinted.LoadOrStore(host, true); !loaded { - if errSys == nil { - log.Printf("tlsdial: warning: server cert for %q is not a Let's Encrypt cert", host) - } else { - log.Printf("tlsdial: error: server cert for %q failed to verify and is not a Let's Encrypt cert", host) + if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded { + if errSys != nil { + log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost) } } } @@ -170,9 +213,6 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { c.ServerName = certDNSName return } - if c.VerifyPeerCertificate != nil { - panic("refusing to override tls.Config.VerifyPeerCertificate") - } // Set InsecureSkipVerify to prevent crypto/tls from doing its // own cert verification, but do the same work that it'd do // (but using certDNSName) in the VerifyPeerCertificate hook. @@ -202,10 +242,10 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { if debug() { log.Printf("tlsdial(sys %q/%q): %v", c.ServerName, certDNSName, errSys) } - if errSys == nil { - return nil + if !buildfeatures.HasBakedRoots || errSys == nil { + return errSys } - opts.Roots = bakedInRoots() + opts.Roots = bakedroots.Get() _, err := certs[0].Verify(opts) if debug() { log.Printf("tlsdial(bake %q/%q): %v", c.ServerName, certDNSName, err) @@ -217,99 +257,63 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { } } +// SetConfigExpectedCertHash configures c's VerifyPeerCertificate function to +// require that exactly 1 cert is presented (not counting any present MetaCert), +// and that the hex of its SHA256 hash is equal to wantFullCertSHA256Hex and +// that it's a valid cert for c.ServerName. +func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) { + if c.VerifyPeerCertificate != nil { + panic("refusing to override tls.Config.VerifyPeerCertificate") + } + + // Set InsecureSkipVerify to prevent crypto/tls from doing its + // own cert verification, but do the same work that it'd do + // (but using certDNSName) in the VerifyConnection hook. + c.InsecureSkipVerify = true + + c.VerifyConnection = func(cs tls.ConnectionState) error { + dialedHost := cs.ServerName + var sawGoodCert bool + + for _, cert := range cs.PeerCertificates { + if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) { + continue + } + if sawGoodCert { + return errors.New("unexpected multiple certs presented") + } + if fmt.Sprintf("%02x", sha256.Sum256(cert.Raw)) != wantFullCertSHA256Hex { + return fmt.Errorf("cert hash does not match expected cert hash") + } + if dialedHost != "" { // it's empty when dialing a derper by IP with no hostname + if err := cert.VerifyHostname(dialedHost); err != nil { + return fmt.Errorf("cert does not match server name %q: %w", dialedHost, err) + } + } + now := time.Now() + if now.After(cert.NotAfter) { + return fmt.Errorf("cert expired %v", cert.NotAfter) + } + if now.Before(cert.NotBefore) { + return fmt.Errorf("cert not yet valid until %v; is your clock correct?", cert.NotBefore) + } + sawGoodCert = true + } + if !sawGoodCert { + return errors.New("expected cert not presented") + } + return nil + } +} + // NewTransport returns a new HTTP transport that verifies TLS certs using this // package, including its baked-in LetsEncrypt fallback roots. func NewTransport() *http.Transport { return &http.Transport{ DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } var d tls.Dialer - d.Config = Config(host, nil, nil) + d.Config = Config(nil, nil) return d.DialContext(ctx, network, addr) }, } } - -/* -letsEncryptX1 is the LetsEncrypt X1 root: - -Certificate: - - Data: - Version: 3 (0x2) - Serial Number: - 82:10:cf:b0:d2:40:e3:59:44:63:e0:bb:63:82:8b:00 - Signature Algorithm: sha256WithRSAEncryption - Issuer: C = US, O = Internet Security Research Group, CN = ISRG Root X1 - Validity - Not Before: Jun 4 11:04:38 2015 GMT - Not After : Jun 4 11:04:38 2035 GMT - Subject: C = US, O = Internet Security Research Group, CN = ISRG Root X1 - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - RSA Public-Key: (4096 bit) - -We bake it into the binary as a fallback verification root, -in case the system we're running on doesn't have it. -(Tailscale runs on some ancient devices.) - -To test that this code is working on Debian/Ubuntu: - -$ sudo mv /usr/share/ca-certificates/mozilla/ISRG_Root_X1.crt{,.old} -$ sudo update-ca-certificates - -Then restart tailscaled. To also test dnsfallback's use of it, nuke -your /etc/resolv.conf and it should still start & run fine. -*/ -const letsEncryptX1 = ` ------BEGIN CERTIFICATE----- -MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw -TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh -cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 -WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu -ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY -MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc -h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ -0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U -A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW -T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH -B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC -B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv -KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn -OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn -jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw -qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI -rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV -HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq -hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL -ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ -3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK -NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 -ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur -TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC -jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc -oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq -4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA -mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d -emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= ------END CERTIFICATE----- -` - -var bakedInRootsOnce struct { - sync.Once - p *x509.CertPool -} - -func bakedInRoots() *x509.CertPool { - bakedInRootsOnce.Do(func() { - p := x509.NewCertPool() - if !p.AppendCertsFromPEM([]byte(letsEncryptX1)) { - panic("bogus PEM") - } - bakedInRootsOnce.p = p - }) - return bakedInRootsOnce.p -} diff --git a/net/tlsdial/tlsdial_test.go b/net/tlsdial/tlsdial_test.go index 26814ebbd..a288d7653 100644 --- a/net/tlsdial/tlsdial_test.go +++ b/net/tlsdial/tlsdial_test.go @@ -4,37 +4,23 @@ package tlsdial import ( - "crypto/x509" "io" "net" "net/http" "os" "os/exec" "path/filepath" - "reflect" "runtime" "sync/atomic" "testing" "tailscale.com/health" + "tailscale.com/net/bakedroots" + "tailscale.com/util/eventbus/eventbustest" ) -func resetOnce() { - rv := reflect.ValueOf(&bakedInRootsOnce).Elem() - rv.Set(reflect.Zero(rv.Type())) -} - -func TestBakedInRoots(t *testing.T) { - resetOnce() - p := bakedInRoots() - got := p.Subjects() - if len(got) != 1 { - t.Errorf("subjects = %v; want 1", len(got)) - } -} - func TestFallbackRootWorks(t *testing.T) { - defer resetOnce() + defer bakedroots.ResetForTest(t, nil) const debug = false if runtime.GOOS != "linux" { @@ -69,14 +55,7 @@ func TestFallbackRootWorks(t *testing.T) { if err != nil { t.Fatal(err) } - resetOnce() - bakedInRootsOnce.Do(func() { - p := x509.NewCertPool() - if !p.AppendCertsFromPEM(caPEM) { - t.Fatal("failed to add") - } - bakedInRootsOnce.p = p - }) + bakedroots.ResetForTest(t, caPEM) ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -107,8 +86,8 @@ func TestFallbackRootWorks(t *testing.T) { }, DisableKeepAlives: true, // for test cleanup ease } - ht := new(health.Tracker) - tr.TLSClientConfig = Config("tlsdial.test", ht, tr.TLSClientConfig) + ht := health.NewTracker(eventbustest.NewBus(t)) + tr.TLSClientConfig = Config(ht, tr.TLSClientConfig) c := &http.Client{Transport: tr} ctr0 := atomic.LoadInt32(&counterFallbackOK) diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index 880695387..06e6a26dd 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -66,15 +66,21 @@ const ( TailscaleServiceIPv6String = "fd7a:115c:a1e0::53" ) -// IsTailscaleIP reports whether ip is an IP address in a range that +// IsTailscaleIP reports whether IP is an IP address in a range that // Tailscale assigns from. func IsTailscaleIP(ip netip.Addr) bool { if ip.Is4() { - return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip) + return IsTailscaleIPv4(ip) } return TailscaleULARange().Contains(ip) } +// IsTailscaleIPv4 reports whether an IPv4 IP is an IP address that +// Tailscale assigns from. +func IsTailscaleIPv4(ip netip.Addr) bool { + return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip) +} + // TailscaleULARange returns the IPv6 Unique Local Address range that // is the superset range that Tailscale assigns out of. func TailscaleULARange() netip.Prefix { @@ -180,8 +186,7 @@ func PrefixIs6(p netip.Prefix) bool { return p.Addr().Is6() } // IPv6 /0 route. func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { var v4, v6 bool - for i := range rr.Len() { - r := rr.At(i) + for _, r := range rr.All() { if r == allIPv4 { v4 = true } else if r == allIPv6 { @@ -194,8 +199,8 @@ func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { // ContainsExitRoute reports whether rr contains at least one of IPv4 or // IPv6 /0 (exit) routes. func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { - for i := range rr.Len() { - if rr.At(i).Bits() == 0 { + for _, r := range rr.All() { + if r.Bits() == 0 { return true } } @@ -205,8 +210,8 @@ func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { // ContainsNonExitSubnetRoutes reports whether v contains Subnet // Routes other than ExitNode Routes. func ContainsNonExitSubnetRoutes(rr views.Slice[netip.Prefix]) bool { - for i := range rr.Len() { - if rr.At(i).Bits() != 0 { + for _, r := range rr.All() { + if r.Bits() != 0 { return true } } diff --git a/net/tsaddr/tsaddr_test.go b/net/tsaddr/tsaddr_test.go index 4aa2f8c60..9ac1ce303 100644 --- a/net/tsaddr/tsaddr_test.go +++ b/net/tsaddr/tsaddr_test.go @@ -222,3 +222,71 @@ func TestContainsExitRoute(t *testing.T) { } } } + +func TestIsTailscaleIPv4(t *testing.T) { + tests := []struct { + in netip.Addr + want bool + }{ + { + in: netip.MustParseAddr("100.67.19.57"), + want: true, + }, + { + in: netip.MustParseAddr("10.10.10.10"), + want: false, + }, + { + + in: netip.MustParseAddr("fd7a:115c:a1e0:3f2b:7a1d:4e88:9c2b:7f01"), + want: false, + }, + { + in: netip.MustParseAddr("bc9d:0aa0:1f0a:69ab:eb5c:28e0:5456:a518"), + want: false, + }, + { + in: netip.MustParseAddr("100.115.92.157"), + want: false, + }, + } + for _, tt := range tests { + if got := IsTailscaleIPv4(tt.in); got != tt.want { + t.Errorf("IsTailscaleIPv4(%v) = %v, want %v", tt.in, got, tt.want) + } + } +} + +func TestIsTailscaleIP(t *testing.T) { + tests := []struct { + in netip.Addr + want bool + }{ + { + in: netip.MustParseAddr("100.67.19.57"), + want: true, + }, + { + in: netip.MustParseAddr("10.10.10.10"), + want: false, + }, + { + + in: netip.MustParseAddr("fd7a:115c:a1e0:3f2b:7a1d:4e88:9c2b:7f01"), + want: true, + }, + { + in: netip.MustParseAddr("bc9d:0aa0:1f0a:69ab:eb5c:28e0:5456:a518"), + want: false, + }, + { + in: netip.MustParseAddr("100.115.92.157"), + want: false, + }, + } + for _, tt := range tests { + if got := IsTailscaleIP(tt.in); got != tt.want { + t.Errorf("IsTailscaleIP(%v) = %v, want %v", tt.in, got, tt.want) + } + } +} diff --git a/net/tsdial/dnsmap.go b/net/tsdial/dnsmap.go index f5d13861b..37fedd14c 100644 --- a/net/tsdial/dnsmap.go +++ b/net/tsdial/dnsmap.go @@ -36,14 +36,14 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { suffix := nm.MagicDNSSuffix() have4 := false addrs := nm.GetAddresses() - if nm.Name != "" && addrs.Len() > 0 { + if name := nm.SelfName(); name != "" && addrs.Len() > 0 { ip := addrs.At(0).Addr() - ret[canonMapKey(nm.Name)] = ip - if dnsname.HasSuffix(nm.Name, suffix) { - ret[canonMapKey(dnsname.TrimSuffix(nm.Name, suffix))] = ip + ret[canonMapKey(name)] = ip + if dnsname.HasSuffix(name, suffix) { + ret[canonMapKey(dnsname.TrimSuffix(name, suffix))] = ip } - for i := range addrs.Len() { - if addrs.At(i).Addr().Is4() { + for _, p := range addrs.All() { + if p.Addr().Is4() { have4 = true } } @@ -52,9 +52,8 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if p.Name() == "" { continue } - for i := range p.Addresses().Len() { - a := p.Addresses().At(i) - ip := a.Addr() + for _, pfx := range p.Addresses().All() { + ip := pfx.Addr() if ip.Is4() && !have4 { continue } diff --git a/net/tsdial/dnsmap_test.go b/net/tsdial/dnsmap_test.go index 43461a135..41a957f18 100644 --- a/net/tsdial/dnsmap_test.go +++ b/net/tsdial/dnsmap_test.go @@ -31,8 +31,8 @@ func TestDNSMapFromNetworkMap(t *testing.T) { { name: "self", nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -47,8 +47,8 @@ func TestDNSMapFromNetworkMap(t *testing.T) { { name: "self_and_peers", nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100.102.103.104/32"), pfx("100::123/128"), @@ -82,8 +82,8 @@ func TestDNSMapFromNetworkMap(t *testing.T) { { name: "self_has_v6_only", nm: &netmap.NetworkMap{ - Name: "foo.tailnet", SelfNode: (&tailcfg.Node{ + Name: "foo.tailnet.", Addresses: []netip.Prefix{ pfx("100::123/128"), }, diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 3606dd67f..065c01384 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -19,14 +19,19 @@ import ( "time" "github.com/gaissmai/bart" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/net/dnscache" "tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/testenv" "tailscale.com/version" @@ -43,6 +48,13 @@ func NewDialer(netMon *netmon.Monitor) *Dialer { return d } +// NewFromFuncForDebug is like NewDialer but takes a netx.DialFunc +// and no netMon. It's meant exclusively for the "tailscale debug ts2021" +// debug command, and perhaps tests. +func NewFromFuncForDebug(logf logger.Logf, dial netx.DialFunc) *Dialer { + return &Dialer{sysDialForTest: dial, Logf: logf} +} + // Dialer dials out of tailscaled, while taking care of details while // handling the dozens of edge cases depending on the server mode // (TUN, netstack), the OS network sandboxing style (macOS/iOS @@ -71,10 +83,11 @@ type Dialer struct { netnsDialerOnce sync.Once netnsDialer netns.Dialer + sysDialForTest netx.DialFunc // or nil routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface - mu sync.Mutex + mu syncs.Mutex closed bool dns dnsMap tunName string // tun device name @@ -84,6 +97,9 @@ type Dialer struct { dnsCache *dnscache.MessageCache // nil until first non-empty SetExitDNSDoH nextSysConnID int activeSysConns map[int]net.Conn // active connections not yet closed + bus *eventbus.Bus // only used for comparison with already set bus. + eventClient *eventbus.Client + eventBusSubs eventbus.Monitor } // sysConn wraps a net.Conn that was created using d.SystemDial. @@ -123,6 +139,9 @@ func (d *Dialer) TUNName() string { // // For example, "http://100.68.82.120:47830/dns-query". func (d *Dialer) SetExitDNSDoH(doh string) { + if !buildfeatures.HasUseExitNode { + return + } d.mu.Lock() defer d.mu.Unlock() if d.exitDNSDoHBase == doh { @@ -149,12 +168,16 @@ func (d *Dialer) SetRoutes(routes, localRoutes []netip.Prefix) { for _, r := range localRoutes { rt.Insert(r, false) } + d.logf("tsdial: bart table size: %d", rt.Size()) } d.routes.Store(rt) } func (d *Dialer) Close() error { + if d.eventClient != nil { + d.eventBusSubs.Close() + } d.mu.Lock() defer d.mu.Unlock() d.closed = true @@ -183,6 +206,14 @@ func (d *Dialer) SetNetMon(netMon *netmon.Monitor) { d.netMonUnregister = nil } d.netMon = netMon + // Having multiple watchers could lead to problems, + // so remove the eventClient if it exists. + // This should really not happen, but better checking for it than not. + // TODO(cmol): Should this just be a panic? + if d.eventClient != nil { + d.eventBusSubs.Close() + d.eventClient = nil + } d.netMonUnregister = d.netMon.RegisterChangeCallback(d.linkChanged) } @@ -194,6 +225,38 @@ func (d *Dialer) NetMon() *netmon.Monitor { return d.netMon } +func (d *Dialer) SetBus(bus *eventbus.Bus) { + d.mu.Lock() + defer d.mu.Unlock() + if d.bus == bus { + return + } else if d.bus != nil { + panic("different eventbus has already been set") + } + // Having multiple watchers could lead to problems, + // so unregister the callback if it exists. + if d.netMonUnregister != nil { + d.netMonUnregister() + } + d.bus = bus + d.eventClient = bus.Client("tsdial.Dialer") + d.eventBusSubs = d.eventClient.Monitor(d.linkChangeWatcher(d.eventClient)) +} + +func (d *Dialer) linkChangeWatcher(ec *eventbus.Client) func(*eventbus.Client) { + linkChangeSub := eventbus.Subscribe[netmon.ChangeDelta](ec) + return func(ec *eventbus.Client) { + for { + select { + case <-ec.Done(): + return + case cd := <-linkChangeSub.Events(): + d.linkChanged(&cd) + } + } + } +} + var ( metricLinkChangeConnClosed = clientmetric.NewCounter("tsdial_linkchange_closes") metricChangeDeltaNoDefaultRoute = clientmetric.NewCounter("tsdial_changedelta_no_default_route") @@ -242,7 +305,7 @@ func changeAffectsConn(delta *netmon.ChangeDelta, conn net.Conn) bool { // In a few cases, we don't have a new DefaultRouteInterface (e.g. on // Android; see tailscale/corp#19124); if so, pessimistically assume // that all connections are affected. - if delta.New.DefaultRouteInterface == "" { + if delta.New.DefaultRouteInterface == "" && runtime.GOOS != "plan9" { return true } @@ -319,7 +382,7 @@ func (d *Dialer) userDialResolve(ctx context.Context, network, addr string) (net } var r net.Resolver - if exitDNSDoH != "" && runtime.GOOS != "windows" { // Windows: https://github.com/golang/go/issues/33097 + if buildfeatures.HasUseExitNode && buildfeatures.HasPeerAPIClient && exitDNSDoH != "" { r.PreferGo = true r.Dial = func(ctx context.Context, network, address string) (net.Conn, error) { return &dohConn{ @@ -361,13 +424,20 @@ func (d *Dialer) logf(format string, args ...any) { } } +// SetSystemDialerForTest sets an alternate function to use for SystemDial +// instead of netns.Dialer. This is intended for use with nettest.MemoryNetwork. +func (d *Dialer) SetSystemDialerForTest(fn netx.DialFunc) { + testenv.AssertInTest() + d.sysDialForTest = fn +} + // SystemDial connects to the provided network address without going over // Tailscale. It prefers going over the default interface and closes existing // connections if the default interface changes. It is used to connect to // Control and (in the future, as of 2022-04-27) DERPs.. func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn, error) { d.mu.Lock() - if d.netMon == nil { + if d.netMon == nil && d.sysDialForTest == nil { d.mu.Unlock() if testenv.InTest() { panic("SystemDial requires a netmon.Monitor; call SetNetMon first") @@ -380,10 +450,16 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn return nil, net.ErrClosed } - d.netnsDialerOnce.Do(func() { - d.netnsDialer = netns.NewDialer(d.logf, d.netMon) - }) - c, err := d.netnsDialer.DialContext(ctx, network, addr) + var c net.Conn + var err error + if d.sysDialForTest != nil { + c, err = d.sysDialForTest(ctx, network, addr) + } else { + d.netnsDialerOnce.Do(func() { + d.netnsDialer = netns.NewDialer(d.logf, d.netMon) + }) + c, err = d.netnsDialer.DialContext(ctx, network, addr) + } if err != nil { return nil, err } @@ -443,6 +519,9 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, // network must a "tcp" type, and addr must be an ip:port. Name resolution // is not supported. func (d *Dialer) dialPeerAPI(ctx context.Context, network, addr string) (net.Conn, error) { + if !buildfeatures.HasPeerAPIClient { + return nil, feature.ErrUnavailable + } switch network { case "tcp", "tcp6", "tcp4": default: @@ -485,6 +564,9 @@ func (d *Dialer) getPeerDialer() *net.Dialer { // The returned Client must not be mutated; it's owned by the Dialer // and shared by callers. func (d *Dialer) PeerAPIHTTPClient() *http.Client { + if !buildfeatures.HasPeerAPIClient { + panic("unreachable") + } d.peerClientOnce.Do(func() { t := http.DefaultTransport.(*http.Transport).Clone() t.Dial = nil diff --git a/net/tshttpproxy/tshttpproxy.go b/net/tshttpproxy/tshttpproxy.go index 2ca440b57..0456009ed 100644 --- a/net/tshttpproxy/tshttpproxy.go +++ b/net/tshttpproxy/tshttpproxy.go @@ -7,6 +7,7 @@ package tshttpproxy import ( "context" + "errors" "fmt" "log" "net" @@ -38,6 +39,23 @@ var ( proxyFunc func(*url.URL) (*url.URL, error) ) +// SetProxyFunc can be used by clients to set a platform-specific function for proxy resolution. +// If config is set when this function is called, an error will be returned. +// The provided function should return a proxy URL for the given request URL, +// nil if no proxy is enabled for the request URL, or an error if proxy settings cannot be resolved. +func SetProxyFunc(fn func(*url.URL) (*url.URL, error)) error { + mu.Lock() + defer mu.Unlock() + + // Allow override only if config is not set + if config != nil { + return errors.New("tshttpproxy: SetProxyFunc can only be called when config is not set") + } + + proxyFunc = fn + return nil +} + func getProxyFunc() func(*url.URL) (*url.URL, error) { // Create config/proxyFunc if it's not created mu.Lock() diff --git a/net/tshttpproxy/tshttpproxy_linux.go b/net/tshttpproxy/tshttpproxy_linux.go index b241c256d..7e086e492 100644 --- a/net/tshttpproxy/tshttpproxy_linux.go +++ b/net/tshttpproxy/tshttpproxy_linux.go @@ -9,6 +9,7 @@ import ( "net/http" "net/url" + "tailscale.com/feature/buildfeatures" "tailscale.com/version/distro" ) @@ -17,7 +18,7 @@ func init() { } func linuxSysProxyFromEnv(req *http.Request) (*url.URL, error) { - if distro.Get() == distro.Synology { + if buildfeatures.HasSynology && distro.Get() == distro.Synology { return synologyProxyFromConfigCached(req) } return nil, nil diff --git a/net/tshttpproxy/tshttpproxy_synology.go b/net/tshttpproxy/tshttpproxy_synology.go index cda957648..e28844f7d 100644 --- a/net/tshttpproxy/tshttpproxy_synology.go +++ b/net/tshttpproxy/tshttpproxy_synology.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) // These vars are overridden for tests. @@ -47,7 +47,7 @@ func synologyProxyFromConfigCached(req *http.Request) (*url.URL, error) { var err error modtime := mtime(synologyProxyConfigPath) - if modtime != cache.updated { + if !modtime.Equal(cache.updated) { cache.httpProxy, cache.httpsProxy, err = synologyProxiesFromConfig() cache.updated = modtime } @@ -76,21 +76,22 @@ func synologyProxiesFromConfig() (*url.URL, *url.URL, error) { func parseSynologyConfig(r io.Reader) (*url.URL, *url.URL, error) { cfg := map[string]string{} - if err := lineread.Reader(r, func(line []byte) error { + for lr := range lineiter.Reader(r) { + line, err := lr.Value() + if err != nil { + return nil, nil, err + } // accept and skip over empty lines line = bytes.TrimSpace(line) if len(line) == 0 { - return nil + continue } key, value, ok := strings.Cut(string(line), "=") if !ok { - return fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) + return nil, nil, fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) } cfg[string(key)] = string(value) - return nil - }); err != nil { - return nil, nil, err } if cfg["proxy_enabled"] != "yes" { diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index 3061740f3..b6e8b948c 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -41,7 +41,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) { t.Fatalf("got %s, %v; want nil, nil", val, err) } - if got, want := cache.updated, time.Unix(0, 0); got != want { + if got, want := cache.updated.UTC(), time.Unix(0, 0).UTC(); !got.Equal(want) { t.Fatalf("got %s, want %s", got, want) } if cache.httpProxy != nil { diff --git a/net/tshttpproxy/tshttpproxy_windows.go b/net/tshttpproxy/tshttpproxy_windows.go index 06a1f5ae4..7163c7863 100644 --- a/net/tshttpproxy/tshttpproxy_windows.go +++ b/net/tshttpproxy/tshttpproxy_windows.go @@ -18,6 +18,7 @@ import ( "unsafe" "github.com/alexbrainman/sspi/negotiate" + "github.com/dblohm7/wingoes" "golang.org/x/sys/windows" "tailscale.com/hostinfo" "tailscale.com/syncs" @@ -97,9 +98,7 @@ func proxyFromWinHTTPOrCache(req *http.Request) (*url.URL, error) { } if err == windows.ERROR_INVALID_PARAMETER { metricErrInvalidParameters.Add(1) - // Seen on Windows 8.1. (https://github.com/tailscale/tailscale/issues/879) - // TODO(bradfitz): figure this out. - setNoProxyUntil(time.Hour) + setNoProxyUntil(10 * time.Second) proxyErrorf("tshttpproxy: winhttp: GetProxyForURL(%q): ERROR_INVALID_PARAMETER [unexpected]", urlStr) return nil, nil } @@ -238,17 +237,30 @@ func (pi *winHTTPProxyInfo) free() { } } -var proxyForURLOpts = &winHTTPAutoProxyOptions{ - DwFlags: winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG | winHTTP_AUTOPROXY_AUTO_DETECT, - DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP, // | winHTTP_AUTO_DETECT_TYPE_DNS_A, -} +var getProxyForURLOpts = sync.OnceValue(func() *winHTTPAutoProxyOptions { + opts := &winHTTPAutoProxyOptions{ + DwFlags: winHTTP_AUTOPROXY_AUTO_DETECT, + DwAutoDetectFlags: winHTTP_AUTO_DETECT_TYPE_DHCP | winHTTP_AUTO_DETECT_TYPE_DNS_A, + } + // Support for the WINHTTP_AUTOPROXY_ALLOW_AUTOCONFIG flag was added in Windows 10, version 1703. + // + // Using it on earlier versions causes GetProxyForURL to fail with ERROR_INVALID_PARAMETER, + // which prevents proxy detection and can lead to failures reaching the control server + // on environments where a proxy is required. + // + // https://web.archive.org/web/20250529044903/https://learn.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_autoproxy_options + if wingoes.IsWin10BuildOrGreater(wingoes.Win10Build1703) { + opts.DwFlags |= winHTTP_AUTOPROXY_ALLOW_AUTOCONFIG + } + return opts +}) func (hi winHTTPInternet) GetProxyForURL(urlStr string) (string, error) { var out winHTTPProxyInfo err := winHTTPGetProxyForURL( hi, windows.StringToUTF16Ptr(urlStr), - proxyForURLOpts, + getProxyForURLOpts(), &out, ) if err != nil { diff --git a/net/tshttpproxy/zsyscall_windows.go b/net/tshttpproxy/zsyscall_windows.go index c07e9ee03..5dcfae83e 100644 --- a/net/tshttpproxy/zsyscall_windows.go +++ b/net/tshttpproxy/zsyscall_windows.go @@ -48,7 +48,7 @@ var ( ) func globalFree(hglobal winHGlobal) (err error) { - r1, _, e1 := syscall.Syscall(procGlobalFree.Addr(), 1, uintptr(hglobal), 0, 0) + r1, _, e1 := syscall.SyscallN(procGlobalFree.Addr(), uintptr(hglobal)) if r1 == 0 { err = errnoErr(e1) } @@ -56,7 +56,7 @@ func globalFree(hglobal winHGlobal) (err error) { } func winHTTPCloseHandle(whi winHTTPInternet) (err error) { - r1, _, e1 := syscall.Syscall(procWinHttpCloseHandle.Addr(), 1, uintptr(whi), 0, 0) + r1, _, e1 := syscall.SyscallN(procWinHttpCloseHandle.Addr(), uintptr(whi)) if r1 == 0 { err = errnoErr(e1) } @@ -64,7 +64,7 @@ func winHTTPCloseHandle(whi winHTTPInternet) (err error) { } func winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAutoProxyOptions, proxyInfo *winHTTPProxyInfo) (err error) { - r1, _, e1 := syscall.Syscall6(procWinHttpGetProxyForUrl.Addr(), 4, uintptr(whi), uintptr(unsafe.Pointer(url)), uintptr(unsafe.Pointer(options)), uintptr(unsafe.Pointer(proxyInfo)), 0, 0) + r1, _, e1 := syscall.SyscallN(procWinHttpGetProxyForUrl.Addr(), uintptr(whi), uintptr(unsafe.Pointer(url)), uintptr(unsafe.Pointer(options)), uintptr(unsafe.Pointer(proxyInfo))) if r1 == 0 { err = errnoErr(e1) } @@ -72,7 +72,7 @@ func winHTTPGetProxyForURL(whi winHTTPInternet, url *uint16, options *winHTTPAut } func winHTTPOpen(agent *uint16, accessType uint32, proxy *uint16, proxyBypass *uint16, flags uint32) (whi winHTTPInternet, err error) { - r0, _, e1 := syscall.Syscall6(procWinHttpOpen.Addr(), 5, uintptr(unsafe.Pointer(agent)), uintptr(accessType), uintptr(unsafe.Pointer(proxy)), uintptr(unsafe.Pointer(proxyBypass)), uintptr(flags), 0) + r0, _, e1 := syscall.SyscallN(procWinHttpOpen.Addr(), uintptr(unsafe.Pointer(agent)), uintptr(accessType), uintptr(unsafe.Pointer(proxy)), uintptr(unsafe.Pointer(proxyBypass)), uintptr(flags)) whi = winHTTPInternet(r0) if whi == 0 { err = errnoErr(e1) diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go deleted file mode 100644 index 7a7b40fc2..000000000 --- a/net/tstun/linkattrs_notlinux.go +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package tstun - -import "github.com/tailscale/wireguard-go/tun" - -func setLinkAttrs(iface tun.Device) error { - return nil -} diff --git a/net/tstun/mtu_test.go b/net/tstun/mtu_test.go index 8d165bfd3..ec31e45ce 100644 --- a/net/tstun/mtu_test.go +++ b/net/tstun/mtu_test.go @@ -1,5 +1,6 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause + package tstun import ( diff --git a/net/tstun/netstack_disabled.go b/net/tstun/netstack_disabled.go new file mode 100644 index 000000000..c1266b305 --- /dev/null +++ b/net/tstun/netstack_disabled.go @@ -0,0 +1,69 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_netstack + +package tstun + +type netstack_PacketBuffer struct { + GSOOptions netstack_GSO +} + +func (*netstack_PacketBuffer) DecRef() { panic("unreachable") } +func (*netstack_PacketBuffer) Size() int { panic("unreachable") } + +type netstack_GSOType int + +const ( + netstack_GSONone netstack_GSOType = iota + netstack_GSOTCPv4 + netstack_GSOTCPv6 + netstack_GSOGvisor +) + +type netstack_GSO struct { + // Type is one of GSONone, GSOTCPv4, etc. + Type netstack_GSOType + // NeedsCsum is set if the checksum offload is enabled. + NeedsCsum bool + // CsumOffset is offset after that to place checksum. + CsumOffset uint16 + + // Mss is maximum segment size. + MSS uint16 + // L3Len is L3 (IP) header length. + L3HdrLen uint16 + + // MaxSize is maximum GSO packet size. + MaxSize uint32 +} + +func (p *netstack_PacketBuffer) NetworkHeader() slicer { + panic("unreachable") +} + +func (p *netstack_PacketBuffer) TransportHeader() slicer { + panic("unreachable") +} + +func (p *netstack_PacketBuffer) ToBuffer() netstack_Buffer { panic("unreachable") } + +func (p *netstack_PacketBuffer) Data() asRanger { + panic("unreachable") +} + +type asRanger struct{} + +func (asRanger) AsRange() toSlicer { panic("unreachable") } + +type toSlicer struct{} + +func (toSlicer) ToSlice() []byte { panic("unreachable") } + +type slicer struct{} + +func (s slicer) Slice() []byte { panic("unreachable") } + +type netstack_Buffer struct{} + +func (netstack_Buffer) Flatten() []byte { panic("unreachable") } diff --git a/net/tstun/netstack_enabled.go b/net/tstun/netstack_enabled.go new file mode 100644 index 000000000..8fc1a2e20 --- /dev/null +++ b/net/tstun/netstack_enabled.go @@ -0,0 +1,22 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netstack + +package tstun + +import ( + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +type ( + netstack_PacketBuffer = stack.PacketBuffer + netstack_GSO = stack.GSO +) + +const ( + netstack_GSONone = stack.GSONone + netstack_GSOTCPv4 = stack.GSOTCPv4 + netstack_GSOTCPv6 = stack.GSOTCPv6 + netstack_GSOGvisor = stack.GSOGvisor +) diff --git a/net/tstun/tap_unsupported.go b/net/tstun/tap_unsupported.go deleted file mode 100644 index 6792b229f..000000000 --- a/net/tstun/tap_unsupported.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_tap - -package tstun - -func (*Wrapper) handleTAPFrame([]byte) bool { panic("unreachable") } diff --git a/net/tstun/tstun_stub.go b/net/tstun/tstun_stub.go index 7a4f71a09..d21eda6b0 100644 --- a/net/tstun/tstun_stub.go +++ b/net/tstun/tstun_stub.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build plan9 || aix +//go:build aix || solaris || illumos package tstun diff --git a/net/tstun/tun.go b/net/tstun/tun.go index 66e209d1a..19b0a53f5 100644 --- a/net/tstun/tun.go +++ b/net/tstun/tun.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !wasm && !plan9 && !tamago && !aix +//go:build !wasm && !tamago && !aix && !solaris && !illumos // Package tun creates a tuntap device, working around OS-specific // quirks if necessary. @@ -9,16 +9,27 @@ package tstun import ( "errors" + "fmt" + "log" + "os" "runtime" "strings" "time" "github.com/tailscale/wireguard-go/tun" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/types/logger" ) -// createTAP is non-nil on Linux. -var createTAP func(tapName, bridgeName string) (tun.Device, error) +// CreateTAP is the hook maybe set by feature/tap. +var CreateTAP feature.Hook[func(logf logger.Logf, tapName, bridgeName string) (tun.Device, error)] + +// HookSetLinkAttrs is the hook maybe set by feature/linkspeed. +var HookSetLinkAttrs feature.Hook[func(tun.Device) error] + +// modprobeTunHook is a Linux-specific hook to run "/sbin/modprobe tun". +var modprobeTunHook feature.Hook[func() error] // New returns a tun.Device for the requested device name, along with // the OS-dependent name that was allocated to the device. @@ -29,7 +40,7 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { if runtime.GOOS != "linux" { return nil, "", errors.New("tap only works on Linux") } - if createTAP == nil { // if the ts_omit_tap tag is used + if !CreateTAP.IsSet() { // if the ts_omit_tap tag is used return nil, "", errors.New("tap is not supported in this build") } f := strings.Split(tunName, ":") @@ -42,9 +53,27 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { default: return nil, "", errors.New("bogus tap argument") } - dev, err = createTAP(tapName, bridgeName) + dev, err = CreateTAP.Get()(logf, tapName, bridgeName) } else { - dev, err = tun.CreateTUN(tunName, int(DefaultTUNMTU())) + if runtime.GOOS == "plan9" { + cleanUpPlan9Interfaces() + } + // Try to create the TUN device up to two times. If it fails + // the first time and we're on Linux, try a desperate + // "modprobe tun" to load the tun module and try again. + for try := range 2 { + dev, err = tun.CreateTUN(tunName, int(DefaultTUNMTU())) + if err == nil || !modprobeTunHook.IsSet() { + if try > 0 { + logf("created TUN device %q after doing `modprobe tun`", tunName) + } + break + } + if modprobeTunHook.Get()() != nil { + // modprobe failed; no point trying again. + break + } + } } if err != nil { return nil, "", err @@ -53,8 +82,12 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { dev.Close() return nil, "", err } - if err := setLinkAttrs(dev); err != nil { - logf("setting link attributes: %v", err) + if buildfeatures.HasLinkSpeed { + if f, ok := HookSetLinkAttrs.GetOk(); ok { + if err := f(dev); err != nil { + logf("setting link attributes: %v", err) + } + } } name, err := interfaceName(dev) if err != nil { @@ -64,6 +97,36 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { return dev, name, nil } +func cleanUpPlan9Interfaces() { + maybeUnbind := func(n int) { + b, err := os.ReadFile(fmt.Sprintf("/net/ipifc/%d/status", n)) + if err != nil { + return + } + status := string(b) + if !(strings.HasPrefix(status, "device maxtu ") || + strings.Contains(status, "fd7a:115c:a1e0:")) { + return + } + f, err := os.OpenFile(fmt.Sprintf("/net/ipifc/%d/ctl", n), os.O_RDWR, 0) + if err != nil { + return + } + defer f.Close() + if _, err := fmt.Fprintf(f, "unbind\n"); err != nil { + log.Printf("unbind interface %v: %v", n, err) + return + } + log.Printf("tun: unbound stale interface %v", n) + } + + // A common case: after unclean shutdown we might leave interfaces + // behind. Look for our straggler(s) and clean them up. + for n := 2; n < 5; n++ { + maybeUnbind(n) + } +} + // tunDiagnoseFailure, if non-nil, does OS-specific diagnostics of why // TUN failed to work. var tunDiagnoseFailure func(tunName string, logf logger.Logf, err error) diff --git a/net/tstun/tun_linux.go b/net/tstun/tun_linux.go index 9600ceb77..05cf58c17 100644 --- a/net/tstun/tun_linux.go +++ b/net/tstun/tun_linux.go @@ -17,6 +17,14 @@ import ( func init() { tunDiagnoseFailure = diagnoseLinuxTUNFailure + modprobeTunHook.Set(func() error { + _, err := modprobeTun() + return err + }) +} + +func modprobeTun() ([]byte, error) { + return exec.Command("/sbin/modprobe", "tun").CombinedOutput() } func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) { @@ -36,7 +44,7 @@ func diagnoseLinuxTUNFailure(tunName string, logf logger.Logf, createErr error) kernel := utsReleaseField(&un) logf("Linux kernel version: %s", kernel) - modprobeOut, err := exec.Command("/sbin/modprobe", "tun").CombinedOutput() + modprobeOut, err := modprobeTun() if err == nil { logf("'modprobe tun' successful") // Either tun is currently loaded, or it's statically diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index dcd43d571..db4f689bf 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -22,10 +22,8 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "go4.org/mem" - "gvisor.dev/gvisor/pkg/tcpip/stack" "tailscale.com/disco" - tsmetrics "tailscale.com/metrics" - "tailscale.com/net/connstats" + "tailscale.com/feature/buildfeatures" "tailscale.com/net/packet" "tailscale.com/net/packet/checksum" "tailscale.com/net/tsaddr" @@ -34,9 +32,9 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/netlogfunc" "tailscale.com/util/clientmetric" "tailscale.com/util/usermetric" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/netstack/gro" "tailscale.com/wgengine/wgcfg" @@ -53,7 +51,8 @@ const PacketStartOffset = device.MessageTransportHeaderSize // of a packet that can be injected into a tstun.Wrapper. const MaxPacketSize = device.MaxContentSize -const tapDebug = false // for super verbose TAP debugging +// TAPDebug is whether super verbose TAP debugging is enabled. +const TAPDebug = false var ( // ErrClosed is returned when attempting an operation on a closed Wrapper. @@ -109,9 +108,7 @@ type Wrapper struct { lastActivityAtomic mono.Time // time of last send or receive destIPActivity syncs.AtomicValue[map[netip.Addr]func()] - //lint:ignore U1000 used in tap_linux.go - destMACAtomic syncs.AtomicValue[[6]byte] - discoKey syncs.AtomicValue[key.DiscoPublic] + discoKey syncs.AtomicValue[key.DiscoPublic] // timeNow, if non-nil, will be used to obtain the current time. timeNow func() time.Time @@ -206,33 +203,23 @@ type Wrapper struct { // disableTSMPRejected disables TSMP rejected responses. For tests. disableTSMPRejected bool - // stats maintains per-connection counters. - stats atomic.Pointer[connstats.Statistics] + // connCounter maintains per-connection counters. + connCounter syncs.AtomicValue[netlogfunc.ConnectionCounter] - captureHook syncs.AtomicValue[capture.Callback] + captureHook syncs.AtomicValue[packet.CaptureCallback] metrics *metrics } type metrics struct { - inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel] - outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel] + inboundDroppedPacketsTotal *usermetric.MultiLabelMap[usermetric.DropLabels] + outboundDroppedPacketsTotal *usermetric.MultiLabelMap[usermetric.DropLabels] } func registerMetrics(reg *usermetric.Registry) *metrics { return &metrics{ - inboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel]( - reg, - "tailscaled_inbound_dropped_packets_total", - "counter", - "Counts the number of dropped packets received by the node from other peers", - ), - outboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel]( - reg, - "tailscaled_outbound_dropped_packets_total", - "counter", - "Counts the number of packets dropped while being sent to other peers", - ), + inboundDroppedPacketsTotal: reg.DroppedPacketsInbound(), + outboundDroppedPacketsTotal: reg.DroppedPacketsOutbound(), } } @@ -240,7 +227,7 @@ func registerMetrics(reg *usermetric.Registry) *metrics { type tunInjectedRead struct { // Only one of packet or data should be set, and are read in that order of // precedence. - packet *stack.PacketBuffer + packet *netstack_PacketBuffer data []byte } @@ -257,12 +244,6 @@ type tunVectorReadResult struct { dataOffset int } -type setWrapperer interface { - // setWrapper enables the underlying TUN/TAP to have access to the Wrapper. - // It MUST be called only once during initialization, other usage is unsafe. - setWrapper(*Wrapper) -} - // Start unblocks any Wrapper.Read calls that have already started // and makes the Wrapper functional. // @@ -313,10 +294,6 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) w.bufferConsumed <- struct{}{} w.noteActivity() - if sw, ok := w.tdev.(setWrapperer); ok { - sw.setWrapper(w) - } - return w } @@ -334,7 +311,9 @@ func (t *Wrapper) now() time.Time { // // The map ownership passes to the Wrapper. It must be non-nil. func (t *Wrapper) SetDestIPActivityFuncs(m map[netip.Addr]func()) { - t.destIPActivity.Store(m) + if buildfeatures.HasLazyWG { + t.destIPActivity.Store(m) + } } // SetDiscoKey sets the current discovery key. @@ -459,12 +438,18 @@ const ethernetFrameSize = 14 // 2 six byte MACs, 2 bytes ethertype func (t *Wrapper) pollVector() { sizes := make([]int, len(t.vectorBuffer)) readOffset := PacketStartOffset + reader := t.tdev.Read if t.isTAP { - readOffset = PacketStartOffset - ethernetFrameSize + type tapReader interface { + ReadEthernet(buffs [][]byte, sizes []int, offset int) (int, error) + } + if r, ok := t.tdev.(tapReader); ok { + readOffset = PacketStartOffset - ethernetFrameSize + reader = r.ReadEthernet + } } for range t.bufferConsumed { - DoRead: for i := range t.vectorBuffer { t.vectorBuffer[i] = t.vectorBuffer[i][:cap(t.vectorBuffer[i])] } @@ -474,8 +459,8 @@ func (t *Wrapper) pollVector() { if t.isClosed() { return } - n, err = t.tdev.Read(t.vectorBuffer[:], sizes, readOffset) - if t.isTAP && tapDebug { + n, err = reader(t.vectorBuffer[:], sizes, readOffset) + if t.isTAP && TAPDebug { s := fmt.Sprintf("% x", t.vectorBuffer[0][:]) for strings.HasSuffix(s, " 00") { s = strings.TrimSuffix(s, " 00") @@ -486,21 +471,6 @@ func (t *Wrapper) pollVector() { for i := range sizes[:n] { t.vectorBuffer[i] = t.vectorBuffer[i][:readOffset+sizes[i]] } - if t.isTAP { - if err == nil { - ethernetFrame := t.vectorBuffer[0][readOffset:] - if t.handleTAPFrame(ethernetFrame) { - goto DoRead - } - } - // Fall through. We got an IP packet. - if sizes[0] >= ethernetFrameSize { - t.vectorBuffer[0] = t.vectorBuffer[0][:readOffset+sizes[0]-ethernetFrameSize] - } - if tapDebug { - t.logf("tap regular frame: %x", t.vectorBuffer[0][PacketStartOffset:PacketStartOffset+sizes[0]]) - } - } t.sendVectorOutbound(tunVectorReadResult{ data: t.vectorBuffer[:n], dataOffset: PacketStartOffset, @@ -823,10 +793,21 @@ func (pc *peerConfigTable) outboundPacketIsJailed(p *packet.Parsed) bool { return c.jailed } +// SetIPer is the interface expected to be implemented by the TAP implementation +// of tun.Device. +type SetIPer interface { + // SetIP sets the IP addresses of the TAP device. + SetIP(ipV4, ipV6 netip.Addr) error +} + // SetWGConfig is called when a new NetworkMap is received. func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { + if t.isTAP { + if sip, ok := t.tdev.(SetIPer); ok { + sip.SetIP(findV4(wcfg.Addresses), findV6(wcfg.Addresses)) + } + } cfg := peerConfigTableFromWGConfig(wcfg) - old := t.peerConfig.Swap(cfg) if !reflect.DeepEqual(old, cfg) { t.logf("peer config: %v", cfg) @@ -896,11 +877,13 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf return filter.Drop, gro } - if filt.RunOut(p, t.filterFlags) != filter.Accept { + if resp, reason := filt.RunOut(p, t.filterFlags); resp != filter.Accept { metricPacketOutDropFilter.Add(1) - t.metrics.outboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonACL, - }, 1) + if reason != "" { + t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: reason, + }, 1) + } return filter.Drop, gro } @@ -925,9 +908,23 @@ func (t *Wrapper) IdleDuration() time.Duration { return mono.Since(t.lastActivityAtomic.LoadAtomic()) } +func (t *Wrapper) awaitStart() { + for { + select { + case <-t.startCh: + return + case <-time.After(1 * time.Second): + // Multiple times while remixing tailscaled I (Brad) have forgotten + // to call Start and then wasted far too much time debugging. + // I do not wish that debugging on anyone else. Hopefully this'll help: + t.logf("tstun: awaiting Wrapper.Start call") + } + } +} + func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { if !t.started.Load() { - <-t.startCh + t.awaitStart() } // packet from OS read and sent to WG res, ok := <-t.vectorOutbound @@ -952,13 +949,15 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { for _, data := range res.data { p.Decode(data[res.dataOffset:]) - if m := t.destIPActivity.Load(); m != nil { - if fn := m[p.Dst.Addr()]; fn != nil { - fn() + if buildfeatures.HasLazyWG { + if m := t.destIPActivity.Load(); m != nil { + if fn := m[p.Dst.Addr()]; fn != nil { + fn() + } } } - if captHook != nil { - captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) + if buildfeatures.HasCapture && captHook != nil { + captHook(packet.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) } if !t.disableFilter { var response filter.Response @@ -968,6 +967,11 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { continue } } + if buildfeatures.HasNetLog { + if update := t.connCounter.Load(); update != nil { + updateConnCounter(update, p.Buffer(), false) + } + } // Make sure to do SNAT after filtering, so that any flow tracking in // the filter sees the original source address. See #12133. @@ -977,9 +981,6 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { panic(fmt.Sprintf("short copy: %d != %d", n, len(data)-res.dataOffset)) } sizes[buffsPos] = n - if stats := t.stats.Load(); stats != nil { - stats.UpdateTxVirtual(p.Buffer()) - } buffsPos++ } if buffsGRO != nil { @@ -1002,7 +1003,10 @@ const ( minTCPHeaderSize = 20 ) -func stackGSOToTunGSO(pkt []byte, gso stack.GSO) (tun.GSOOptions, error) { +func stackGSOToTunGSO(pkt []byte, gso netstack_GSO) (tun.GSOOptions, error) { + if !buildfeatures.HasNetstack { + panic("unreachable") + } options := tun.GSOOptions{ CsumStart: gso.L3HdrLen, CsumOffset: gso.CsumOffset, @@ -1010,12 +1014,12 @@ func stackGSOToTunGSO(pkt []byte, gso stack.GSO) (tun.GSOOptions, error) { NeedsCsum: gso.NeedsCsum, } switch gso.Type { - case stack.GSONone: + case netstack_GSONone: options.GSOType = tun.GSONone return options, nil - case stack.GSOTCPv4: + case netstack_GSOTCPv4: options.GSOType = tun.GSOTCPv4 - case stack.GSOTCPv6: + case netstack_GSOTCPv6: options.GSOType = tun.GSOTCPv6 default: return tun.GSOOptions{}, fmt.Errorf("unsupported gVisor GSOType: %v", gso.Type) @@ -1038,7 +1042,10 @@ func stackGSOToTunGSO(pkt []byte, gso stack.GSO) (tun.GSOOptions, error) { // both before and after partial checksum updates where later checksum // offloading still expects a partial checksum. // TODO(jwhited): plumb partial checksum awareness into net/packet/checksum. -func invertGSOChecksum(pkt []byte, gso stack.GSO) { +func invertGSOChecksum(pkt []byte, gso netstack_GSO) { + if !buildfeatures.HasNetstack { + panic("unreachable") + } if gso.NeedsCsum != true { return } @@ -1052,10 +1059,13 @@ func invertGSOChecksum(pkt []byte, gso stack.GSO) { // injectedRead handles injected reads, which bypass filters. func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []int, offset int) (n int, err error) { - var gso stack.GSO + var gso netstack_GSO pkt := outBuffs[0][offset:] if res.packet != nil { + if !buildfeatures.HasNetstack { + panic("unreachable") + } bufN := copy(pkt, res.packet.NetworkHeader().Slice()) bufN += copy(pkt[bufN:], res.packet.TransportHeader().Slice()) bufN += copy(pkt[bufN:], res.packet.Data().AsRange().ToSlice()) @@ -1078,9 +1088,11 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i pc.snat(p) invertGSOChecksum(pkt, gso) - if m := t.destIPActivity.Load(); m != nil { - if fn := m[p.Dst.Addr()]; fn != nil { - fn() + if buildfeatures.HasLazyWG { + if m := t.destIPActivity.Load(); m != nil { + if fn := m[p.Dst.Addr()]; fn != nil { + fn() + } } } @@ -1093,9 +1105,11 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i n, err = tun.GSOSplit(pkt, gsoOptions, outBuffs, sizes, offset) } - if stats := t.stats.Load(); stats != nil { - for i := 0; i < n; i++ { - stats.UpdateTxVirtual(outBuffs[i][offset : offset+sizes[i]]) + if buildfeatures.HasNetLog { + if update := t.connCounter.Load(); update != nil { + for i := 0; i < n; i++ { + updateConnCounter(update, outBuffs[i][offset:offset+sizes[i]], false) + } } } @@ -1104,9 +1118,9 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i return n, err } -func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { +func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { if captHook != nil { - captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) + captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) } if p.IPProto == ipproto.TSMP { @@ -1170,8 +1184,8 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca if outcome != filter.Accept { metricPacketInDropFilter.Add(1) - t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonACL, + t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonACL, }, 1) // Tell them, via TSMP, we're dropping them due to the ACL. @@ -1251,8 +1265,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { t.noteActivity() _, err := t.tdevWrite(buffs, offset) if err != nil { - t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonError, + t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonError, }, int64(len(buffs))) } return len(buffs), err @@ -1261,9 +1275,11 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { } func (t *Wrapper) tdevWrite(buffs [][]byte, offset int) (int, error) { - if stats := t.stats.Load(); stats != nil { - for i := range buffs { - stats.UpdateRxVirtual((buffs)[i][offset:]) + if buildfeatures.HasNetLog { + if update := t.connCounter.Load(); update != nil { + for i := range buffs { + updateConnCounter(update, buffs[i][offset:], true) + } } } return t.tdev.Write(buffs, offset) @@ -1301,7 +1317,10 @@ func (t *Wrapper) SetJailedFilter(filt *filter.Filter) { // // This path is typically used to deliver synthesized packets to the // host networking stack. -func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer, buffs [][]byte, sizes []int) error { +func (t *Wrapper) InjectInboundPacketBuffer(pkt *netstack_PacketBuffer, buffs [][]byte, sizes []int) error { + if !buildfeatures.HasNetstack { + panic("unreachable") + } buf := buffs[0][PacketStartOffset:] bufN := copy(buf, pkt.NetworkHeader().Slice()) @@ -1320,7 +1339,7 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer, buffs [][]b p.Decode(buf) captHook := t.captureHook.Load() if captHook != nil { - captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) + captHook(packet.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) } invertGSOChecksum(buf, pkt.GSOOptions) @@ -1440,7 +1459,10 @@ func (t *Wrapper) InjectOutbound(pkt []byte) error { // InjectOutboundPacketBuffer logically behaves as InjectOutbound. It takes ownership of one // reference count on the packet, and the packet may be mutated. The packet refcount will be // decremented after the injected buffer has been read. -func (t *Wrapper) InjectOutboundPacketBuffer(pkt *stack.PacketBuffer) error { +func (t *Wrapper) InjectOutboundPacketBuffer(pkt *netstack_PacketBuffer) error { + if !buildfeatures.HasNetstack { + panic("unreachable") + } size := pkt.Size() if size > MaxPacketSize { pkt.DecRef() @@ -1452,7 +1474,7 @@ func (t *Wrapper) InjectOutboundPacketBuffer(pkt *stack.PacketBuffer) error { } if capt := t.captureHook.Load(); capt != nil { b := pkt.ToBuffer() - capt(capture.SynthesizedToPeer, t.now(), b.Flatten(), packet.CaptureMeta{}) + capt(packet.SynthesizedToPeer, t.now(), b.Flatten(), packet.CaptureMeta{}) } t.injectOutbound(tunInjectedRead{packet: pkt}) @@ -1476,10 +1498,12 @@ func (t *Wrapper) Unwrap() tun.Device { return t.tdev } -// SetStatistics specifies a per-connection statistics aggregator. +// SetConnectionCounter specifies a per-connection statistics aggregator. // Nil may be specified to disable statistics gathering. -func (t *Wrapper) SetStatistics(stats *connstats.Statistics) { - t.stats.Store(stats) +func (t *Wrapper) SetConnectionCounter(fn netlogfunc.ConnectionCounter) { + if buildfeatures.HasNetLog { + t.connCounter.Store(fn) + } } var ( @@ -1494,20 +1518,19 @@ var ( metricPacketOutDropSelfDisco = clientmetric.NewCounter("tstun_out_to_wg_drop_self_disco") ) -type DropReason string - -const ( - DropReasonACL DropReason = "acl" - DropReasonError DropReason = "error" -) - -type dropPacketLabel struct { - // Reason indicates what we have done with the packet, and has the following values: - // - acl (rejected packets because of ACL) - // - error (rejected packets because of an error) - Reason DropReason +func (t *Wrapper) InstallCaptureHook(cb packet.CaptureCallback) { + if !buildfeatures.HasCapture { + return + } + t.captureHook.Store(cb) } -func (t *Wrapper) InstallCaptureHook(cb capture.Callback) { - t.captureHook.Store(cb) +func updateConnCounter(update netlogfunc.ConnectionCounter, b []byte, receive bool) { + var p packet.Parsed + p.Decode(b) + if receive { + update(p.IPProto, p.Dst, p.Src, 1, len(b), true) + } else { + update(p.IPProto, p.Src, p.Dst, 1, len(b), false) + } } diff --git a/net/tstun/wrap_linux.go b/net/tstun/wrap_linux.go index 136ddfe1e..7498f107b 100644 --- a/net/tstun/wrap_linux.go +++ b/net/tstun/wrap_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !ts_omit_gro + package tstun import ( diff --git a/net/tstun/wrap_noop.go b/net/tstun/wrap_noop.go index c743072ca..8ad04bafe 100644 --- a/net/tstun/wrap_noop.go +++ b/net/tstun/wrap_noop.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || ts_omit_gro package tstun diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 0ed0075b6..75cf5afb2 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -5,7 +5,6 @@ package tstun import ( "bytes" - "context" "encoding/binary" "encoding/hex" "expvar" @@ -27,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "tailscale.com/disco" - "tailscale.com/net/connstats" "tailscale.com/net/netaddr" "tailscale.com/net/packet" "tailscale.com/tstest" @@ -40,7 +38,6 @@ import ( "tailscale.com/types/views" "tailscale.com/util/must" "tailscale.com/util/usermetric" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" ) @@ -371,9 +368,8 @@ func TestFilter(t *testing.T) { }() var buf [MaxPacketSize]byte - stats := connstats.NewStatistics(0, 0, nil) - defer stats.Shutdown(context.Background()) - tun.SetStatistics(stats) + var stats netlogtype.CountsByConnection + tun.SetConnectionCounter(stats.Add) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var n int @@ -381,9 +377,10 @@ func TestFilter(t *testing.T) { var filtered bool sizes := make([]int, 1) - tunStats, _ := stats.TestExtract() + tunStats := stats.Clone() + stats.Reset() if len(tunStats) > 0 { - t.Errorf("connstats.Statistics.Extract = %v, want {}", stats) + t.Errorf("netlogtype.CountsByConnection = %v, want {}", tunStats) } if tt.dir == in { @@ -416,7 +413,8 @@ func TestFilter(t *testing.T) { } } - got, _ := stats.TestExtract() + got := stats.Clone() + stats.Reset() want := map[netlogtype.Connection]netlogtype.Counts{} var wasUDP bool if !tt.drop { @@ -441,19 +439,19 @@ func TestFilter(t *testing.T) { } var metricInboundDroppedPacketsACL, metricInboundDroppedPacketsErr, metricOutboundDroppedPacketsACL int64 - if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok { + if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok { metricInboundDroppedPacketsACL = m.Value() } - if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonError}).(*expvar.Int); ok { + if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonError}).(*expvar.Int); ok { metricInboundDroppedPacketsErr = m.Value() } - if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok { + if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok { metricOutboundDroppedPacketsACL = m.Value() } assertMetricPackets(t, "inACL", 3, metricInboundDroppedPacketsACL) assertMetricPackets(t, "inError", 0, metricInboundDroppedPacketsErr) - assertMetricPackets(t, "outACL", 1, metricOutboundDroppedPacketsACL) + assertMetricPackets(t, "outACL", 0, metricOutboundDroppedPacketsACL) } func assertMetricPackets(t *testing.T, metricName string, want, got int64) { @@ -871,14 +869,14 @@ func TestPeerCfg_NAT(t *testing.T) { // with the correct parameters when various packet operations are performed. func TestCaptureHook(t *testing.T) { type captureRecord struct { - path capture.Path + path packet.CapturePath now time.Time pkt []byte meta packet.CaptureMeta } var captured []captureRecord - hook := func(path capture.Path, now time.Time, pkt []byte, meta packet.CaptureMeta) { + hook := func(path packet.CapturePath, now time.Time, pkt []byte, meta packet.CaptureMeta) { captured = append(captured, captureRecord{ path: path, now: now, @@ -935,19 +933,19 @@ func TestCaptureHook(t *testing.T) { // Assert that the right packets are captured. want := []captureRecord{ { - path: capture.FromPeer, + path: packet.FromPeer, pkt: []byte("Write1"), }, { - path: capture.FromPeer, + path: packet.FromPeer, pkt: []byte("Write2"), }, { - path: capture.SynthesizedToLocal, + path: packet.SynthesizedToLocal, pkt: []byte("InjectInboundPacketBuffer"), }, { - path: capture.SynthesizedToPeer, + path: packet.SynthesizedToPeer, pkt: []byte("InjectOutboundPacketBuffer"), }, } diff --git a/net/udprelay/endpoint/endpoint.go b/net/udprelay/endpoint/endpoint.go new file mode 100644 index 000000000..0d2a14e96 --- /dev/null +++ b/net/udprelay/endpoint/endpoint.go @@ -0,0 +1,64 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package endpoint contains types relating to UDP relay server endpoints. It +// does not import tailscale.com/net/udprelay. +package endpoint + +import ( + "net/netip" + "time" + + "tailscale.com/tstime" + "tailscale.com/types/key" +) + +// ServerRetryAfter is the default +// [tailscale.com/net/udprelay.ErrServerNotReady.RetryAfter] value. +const ServerRetryAfter = time.Second * 3 + +// ServerEndpoint contains details for an endpoint served by a +// [tailscale.com/net/udprelay.Server]. +type ServerEndpoint struct { + // ServerDisco is the Server's Disco public key used as part of the 3-way + // bind handshake. Server will use the same ServerDisco for its lifetime. + // ServerDisco value in combination with LamportID value represents a + // unique ServerEndpoint allocation. + ServerDisco key.DiscoPublic + + // ClientDisco are the Disco public keys of the relay participants permitted + // to handshake with this endpoint. + ClientDisco [2]key.DiscoPublic + + // LamportID is unique and monotonically non-decreasing across + // ServerEndpoint allocations for the lifetime of Server. It enables clients + // to dedup and resolve allocation event order. Clients may race to allocate + // on the same Server, and signal ServerEndpoint details via alternative + // channels, e.g. DERP. Additionally, Server.AllocateEndpoint() requests may + // not result in a new allocation depending on existing server-side endpoint + // state. Therefore, where clients have local, existing state that contains + // ServerDisco and LamportID values matching a newly learned endpoint, these + // can be considered one and the same. If ServerDisco is equal, but + // LamportID is unequal, LamportID comparison determines which + // ServerEndpoint was allocated most recently. + LamportID uint64 + + // AddrPorts are the IP:Port candidate pairs the Server may be reachable + // over. + AddrPorts []netip.AddrPort + + // VNI (Virtual Network Identifier) is the Geneve header VNI the Server + // will use for transmitted packets, and expects for received packets + // associated with this endpoint. + VNI uint32 + + // BindLifetime is amount of time post-allocation the Server will consider + // the endpoint active while it has yet to be bound via 3-way bind handshake + // from both client parties. + BindLifetime tstime.GoDuration + + // SteadyStateLifetime is the amount of time post 3-way bind handshake from + // both client parties the Server will consider the endpoint active lacking + // bidirectional data flow. + SteadyStateLifetime tstime.GoDuration +} diff --git a/net/udprelay/endpoint/endpoint_test.go b/net/udprelay/endpoint/endpoint_test.go new file mode 100644 index 000000000..f12a6e2f6 --- /dev/null +++ b/net/udprelay/endpoint/endpoint_test.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package endpoint + +import ( + "encoding/json" + "math" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tstime" + "tailscale.com/types/key" +) + +func TestServerEndpointJSONUnmarshal(t *testing.T) { + tests := []struct { + name string + json []byte + wantErr bool + }{ + { + name: "valid", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: false, + }, + { + name: "invalid ServerDisco", + json: []byte(`{"ServerDisco":"1","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid LamportID", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":1.1,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid AddrPorts", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid VNI", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":18446744073709551615,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid BindLifetime", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"5","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid SteadyStateLifetime", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5"}`), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out ServerEndpoint + err := json.Unmarshal(tt.json, &out) + if tt.wantErr != (err != nil) { + t.Fatalf("wantErr: %v (err == nil): %v", tt.wantErr, err == nil) + } + if tt.wantErr { + return + } + }) + } +} + +func TestServerEndpointJSONMarshal(t *testing.T) { + tests := []struct { + name string + serverEndpoint ServerEndpoint + }{ + { + name: "valid roundtrip", + serverEndpoint: ServerEndpoint{ + ServerDisco: key.NewDisco().Public(), + LamportID: uint64(math.MaxUint64), + AddrPorts: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:1"), netip.MustParseAddrPort("127.0.0.2:2")}, + VNI: 1<<24 - 1, + BindLifetime: tstime.GoDuration{Duration: time.Second * 30}, + SteadyStateLifetime: tstime.GoDuration{Duration: time.Minute * 5}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := json.Marshal(&tt.serverEndpoint) + if err != nil { + t.Fatal(err) + } + var got ServerEndpoint + err = json.Unmarshal(b, &got) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(got, tt.serverEndpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("ServerEndpoint unequal (-got +want)\n%s", diff) + } + }) + } +} diff --git a/net/udprelay/server.go b/net/udprelay/server.go new file mode 100644 index 000000000..7138cec7a --- /dev/null +++ b/net/udprelay/server.go @@ -0,0 +1,889 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package udprelay contains constructs for relaying Disco and WireGuard packets +// between Tailscale clients over UDP. This package is currently considered +// experimental. +package udprelay + +import ( + "bytes" + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "sync" + "time" + + "go4.org/mem" + "golang.org/x/net/ipv6" + "tailscale.com/disco" + "tailscale.com/net/batching" + "tailscale.com/net/netaddr" + "tailscale.com/net/netcheck" + "tailscale.com/net/netmon" + "tailscale.com/net/packet" + "tailscale.com/net/sockopts" + "tailscale.com/net/stun" + "tailscale.com/net/udprelay/endpoint" + "tailscale.com/net/udprelay/status" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/nettype" + "tailscale.com/types/views" + "tailscale.com/util/eventbus" + "tailscale.com/util/set" +) + +const ( + // defaultBindLifetime is somewhat arbitrary. We attempt to account for + // high latency between client and [Server], and high latency between + // clients over side channels, e.g. DERP, used to exchange + // [endpoint.ServerEndpoint] details. So, a total of 3 paths with + // potentially high latency. Using a conservative 10s "high latency" bounds + // for each path we end up at a 30s total. It is worse to set an aggressive + // bind lifetime as this may lead to path discovery failure, vs dealing with + // a slight increase of [Server] resource utilization (VNIs, RAM, etc) while + // tracking endpoints that won't bind. + defaultBindLifetime = time.Second * 30 + defaultSteadyStateLifetime = time.Minute * 5 +) + +// Server implements an experimental UDP relay server. +type Server struct { + // The following fields are initialized once and never mutated. + logf logger.Logf + disco key.DiscoPrivate + discoPublic key.DiscoPublic + bindLifetime time.Duration + steadyStateLifetime time.Duration + bus *eventbus.Bus + uc4 batching.Conn // always non-nil + uc4Port uint16 // always nonzero + uc6 batching.Conn // may be nil if IPv6 bind fails during initialization + uc6Port uint16 // may be zero if IPv6 bind fails during initialization + closeOnce sync.Once + wg sync.WaitGroup + closeCh chan struct{} + netChecker *netcheck.Client + + mu sync.Mutex // guards the following fields + derpMap *tailcfg.DERPMap + onlyStaticAddrPorts bool // no dynamic addr port discovery when set + staticAddrPorts views.Slice[netip.AddrPort] // static ip:port pairs set with [Server.SetStaticAddrPorts] + dynamicAddrPorts []netip.AddrPort // dynamically discovered ip:port pairs + closed bool + lamportID uint64 + nextVNI uint32 + byVNI map[uint32]*serverEndpoint + byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint +} + +const ( + minVNI = uint32(1) + maxVNI = uint32(1<<24 - 1) + totalPossibleVNI = maxVNI - minVNI + 1 +) + +// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. +// serverEndpoint methods are not thread-safe. +type serverEndpoint struct { + // discoPubKeys contains the key.DiscoPublic of the served clients. The + // indexing of this array aligns with the following fields, e.g. + // discoSharedSecrets[0] is the shared secret to use when sealing + // Disco protocol messages for transmission towards discoPubKeys[0]. + discoPubKeys key.SortedPairOfDiscoPublic + discoSharedSecrets [2]key.DiscoShared + handshakeGeneration [2]uint32 // or zero if a handshake has never started for that relay leg + handshakeAddrPorts [2]netip.AddrPort // or zero value if a handshake has never started for that relay leg + boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg + lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time + challenge [2][disco.BindUDPRelayChallengeLen]byte + packetsRx [2]uint64 // num packets received from/sent by each client after they are bound + bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound + + lamportID uint64 + vni uint32 + allocatedAt time.Time +} + +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { + if senderIndex != 0 && senderIndex != 1 { + return nil, netip.AddrPort{} + } + + otherSender := 0 + if senderIndex == 0 { + otherSender = 1 + } + + validateVNIAndRemoteKey := func(common disco.BindUDPRelayEndpointCommon) error { + if common.VNI != e.vni { + return errors.New("mismatching VNI") + } + if common.RemoteKey.Compare(e.discoPubKeys.Get()[otherSender]) != 0 { + return errors.New("mismatching RemoteKey") + } + return nil + } + + switch discoMsg := discoMsg.(type) { + case *disco.BindUDPRelayEndpoint: + err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) + if err != nil { + // silently drop + return nil, netip.AddrPort{} + } + if discoMsg.Generation == 0 { + // Generation must be nonzero, silently drop + return nil, netip.AddrPort{} + } + if e.handshakeGeneration[senderIndex] == discoMsg.Generation { + // we've seen this generation before, silently drop + return nil, netip.AddrPort{} + } + e.handshakeGeneration[senderIndex] = discoMsg.Generation + e.handshakeAddrPorts[senderIndex] = from + m := new(disco.BindUDPRelayEndpointChallenge) + m.VNI = e.vni + m.Generation = discoMsg.Generation + m.RemoteKey = e.discoPubKeys.Get()[otherSender] + rand.Read(e.challenge[senderIndex][:]) + copy(m.Challenge[:], e.challenge[senderIndex][:]) + reply := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco} + gh.VNI.Set(e.vni) + err = gh.Encode(reply) + if err != nil { + return nil, netip.AddrPort{} + } + reply = append(reply, disco.Magic...) + reply = serverDisco.AppendTo(reply) + box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) + reply = append(reply, box...) + return reply, from + case *disco.BindUDPRelayEndpointAnswer: + err := validateVNIAndRemoteKey(discoMsg.BindUDPRelayEndpointCommon) + if err != nil { + // silently drop + return nil, netip.AddrPort{} + } + generation := e.handshakeGeneration[senderIndex] + if generation == 0 || // we have no active handshake + generation != discoMsg.Generation || // mismatching generation for the active handshake + e.handshakeAddrPorts[senderIndex] != from || // mismatching source for the active handshake + !bytes.Equal(e.challenge[senderIndex][:], discoMsg.Challenge[:]) { // mismatching answer for the active handshake + // silently drop + return nil, netip.AddrPort{} + } + // Handshake complete. Update the binding for this sender. + e.boundAddrPorts[senderIndex] = from + e.lastSeen[senderIndex] = time.Now() // record last seen as bound time + return nil, netip.AddrPort{} + default: + // unexpected message types, silently drop + return nil, netip.AddrPort{} + } +} + +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { + senderRaw, isDiscoMsg := disco.Source(b) + if !isDiscoMsg { + // Not a Disco message + return nil, netip.AddrPort{} + } + sender := key.DiscoPublicFromRaw32(mem.B(senderRaw)) + senderIndex := -1 + switch { + case sender.Compare(e.discoPubKeys.Get()[0]) == 0: + senderIndex = 0 + case sender.Compare(e.discoPubKeys.Get()[1]) == 0: + senderIndex = 1 + default: + // unknown Disco public key + return nil, netip.AddrPort{} + } + + const headerLen = len(disco.Magic) + key.DiscoPublicRawLen + discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:]) + if !ok { + // unable to decrypt the Disco payload + return nil, netip.AddrPort{} + } + + discoMsg, err := disco.Parse(discoPayload) + if err != nil { + // unable to parse the Disco payload + return nil, netip.AddrPort{} + } + + return e.handleDiscoControlMsg(from, senderIndex, discoMsg, serverDisco) +} + +func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, serverDisco key.DiscoPublic) (write []byte, to netip.AddrPort) { + if !gh.Control { + if !e.isBound() { + // not a control packet, but serverEndpoint isn't bound + return nil, netip.AddrPort{} + } + switch { + case from == e.boundAddrPorts[0]: + e.lastSeen[0] = time.Now() + e.packetsRx[0]++ + e.bytesRx[0] += uint64(len(b)) + return b, e.boundAddrPorts[1] + case from == e.boundAddrPorts[1]: + e.lastSeen[1] = time.Now() + e.packetsRx[1]++ + e.bytesRx[1] += uint64(len(b)) + return b, e.boundAddrPorts[0] + default: + // unrecognized source + return nil, netip.AddrPort{} + } + } + + if gh.Protocol != packet.GeneveProtocolDisco { + // control packet, but not Disco + return nil, netip.AddrPort{} + } + + msg := b[packet.GeneveFixedHeaderLength:] + return e.handleSealedDiscoControlMsg(from, msg, serverDisco) +} + +func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool { + if !e.isBound() { + if now.Sub(e.allocatedAt) > bindLifetime { + return true + } + return false + } + if now.Sub(e.lastSeen[0]) > steadyStateLifetime || now.Sub(e.lastSeen[1]) > steadyStateLifetime { + return true + } + return false +} + +// isBound returns true if both clients have completed a 3-way handshake, +// otherwise false. +func (e *serverEndpoint) isBound() bool { + return e.boundAddrPorts[0].IsValid() && + e.boundAddrPorts[1].IsValid() +} + +// NewServer constructs a [Server] listening on port. If port is zero, then +// port selection is left up to the host networking stack. If +// onlyStaticAddrPorts is true, then dynamic addr:port discovery will be +// disabled, and only addr:port's set via [Server.SetStaticAddrPorts] will be +// used. +func NewServer(logf logger.Logf, port int, onlyStaticAddrPorts bool) (s *Server, err error) { + s = &Server{ + logf: logf, + disco: key.NewDisco(), + bindLifetime: defaultBindLifetime, + steadyStateLifetime: defaultSteadyStateLifetime, + closeCh: make(chan struct{}), + onlyStaticAddrPorts: onlyStaticAddrPorts, + byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), + nextVNI: minVNI, + byVNI: make(map[uint32]*serverEndpoint), + } + s.discoPublic = s.disco.Public() + + // TODO(creachadair): Find a way to plumb this in during initialization. + // As-written, messages published here will not be seen by other components + // in a running client. + bus := eventbus.New() + s.bus = bus + netMon, err := netmon.New(s.bus, logf) + if err != nil { + return nil, err + } + s.netChecker = &netcheck.Client{ + NetMon: netMon, + Logf: logger.WithPrefix(logf, "netcheck: "), + SendPacket: func(b []byte, addrPort netip.AddrPort) (int, error) { + if addrPort.Addr().Is4() { + return s.uc4.WriteToUDPAddrPort(b, addrPort) + } else if s.uc6 != nil { + return s.uc6.WriteToUDPAddrPort(b, addrPort) + } else { + return 0, errors.New("IPv6 socket is not bound") + } + }, + } + + err = s.listenOn(port) + if err != nil { + return nil, err + } + + if !s.onlyStaticAddrPorts { + s.wg.Add(1) + go s.addrDiscoveryLoop() + } + + s.wg.Add(1) + go s.packetReadLoop(s.uc4, s.uc6, true) + if s.uc6 != nil { + s.wg.Add(1) + go s.packetReadLoop(s.uc6, s.uc4, false) + } + s.wg.Add(1) + go s.endpointGCLoop() + + return s, nil +} + +func (s *Server) addrDiscoveryLoop() { + defer s.wg.Done() + + timer := time.NewTimer(0) // fire immediately + defer timer.Stop() + + getAddrPorts := func() ([]netip.AddrPort, error) { + var addrPorts set.Set[netip.AddrPort] + addrPorts.Make() + + // get local addresses + ips, _, err := netmon.LocalAddresses() + if err != nil { + return nil, err + } + for _, ip := range ips { + if ip.IsValid() { + if ip.Is4() { + addrPorts.Add(netip.AddrPortFrom(ip, s.uc4Port)) + } else { + addrPorts.Add(netip.AddrPortFrom(ip, s.uc6Port)) + } + } + } + + dm := s.getDERPMap() + if dm == nil { + // We don't have a DERPMap which is required to dynamically + // discover external addresses, but we can return the endpoints we + // do have. + return addrPorts.Slice(), nil + } + + // get addrPorts as visible from DERP + netCheckerCtx, netCheckerCancel := context.WithTimeout(context.Background(), netcheck.ReportTimeout) + defer netCheckerCancel() + rep, err := s.netChecker.GetReport(netCheckerCtx, dm, &netcheck.GetReportOpts{ + OnlySTUN: true, + }) + if err != nil { + return nil, err + } + // Add STUN-discovered endpoints with their observed ports. + v4Addrs, v6Addrs := rep.GetGlobalAddrs() + for _, addr := range v4Addrs { + if addr.IsValid() { + addrPorts.Add(addr) + } + } + for _, addr := range v6Addrs { + if addr.IsValid() { + addrPorts.Add(addr) + } + } + + if len(v4Addrs) >= 1 && v4Addrs[0].IsValid() { + // If they're behind a hard NAT and are using a fixed + // port locally, assume they might've added a static + // port mapping on their router to the same explicit + // port that the relay is running with. Worst case + // it's an invalid candidate mapping. + if rep.MappingVariesByDestIP.EqualBool(true) && s.uc4Port != 0 { + addrPorts.Add(netip.AddrPortFrom(v4Addrs[0].Addr(), s.uc4Port)) + } + } + return addrPorts.Slice(), nil + } + + for { + select { + case <-timer.C: + // Mirror magicsock behavior for duration between STUN. We consider + // 30s a min bound for NAT timeout. + timer.Reset(tstime.RandomDurationBetween(20*time.Second, 26*time.Second)) + addrPorts, err := getAddrPorts() + if err != nil { + s.logf("error discovering IP:port candidates: %v", err) + } + s.mu.Lock() + s.dynamicAddrPorts = addrPorts + s.mu.Unlock() + case <-s.closeCh: + return + } + } +} + +// This is a compile-time assertion that [singlePacketConn] implements the +// [batching.Conn] interface. +var _ batching.Conn = (*singlePacketConn)(nil) + +// singlePacketConn implements [batching.Conn] with single packet syscall +// operations. +type singlePacketConn struct { + *net.UDPConn +} + +func (c *singlePacketConn) ReadBatch(msgs []ipv6.Message, _ int) (int, error) { + n, ap, err := c.UDPConn.ReadFromUDPAddrPort(msgs[0].Buffers[0]) + if err != nil { + return 0, err + } + msgs[0].N = n + msgs[0].Addr = net.UDPAddrFromAddrPort(netaddr.Unmap(ap)) + return 1, nil +} + +func (c *singlePacketConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort, geneve packet.GeneveHeader, offset int) error { + for _, buff := range buffs { + if geneve.VNI.IsSet() { + geneve.Encode(buff) + } else { + buff = buff[offset:] + } + _, err := c.UDPConn.WriteToUDPAddrPort(buff, addr) + if err != nil { + return err + } + } + return nil +} + +// UDP socket read/write buffer size (7MB). At the time of writing (2025-08-21) +// this value was heavily influenced by magicsock, with similar motivations for +// its increase relative to typical defaults, e.g. long fat networks and +// reducing packet loss around crypto/syscall-induced delay. +const socketBufferSize = 7 << 20 + +func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { + directions := []sockopts.BufferDirection{sockopts.ReadDirection, sockopts.WriteDirection} + for _, direction := range directions { + errForce, errPortable := sockopts.SetBufferSize(pconn, direction, socketBufferSize) + if errForce != nil { + logf("[warning] failed to force-set UDP %v buffer size to %d: %v; using kernel default values (impacts throughput only)", direction, socketBufferSize, errForce) + } + if errPortable != nil { + logf("failed to set UDP %v buffer size to %d: %v", direction, socketBufferSize, errPortable) + } + } + + err := sockopts.SetICMPErrImmunity(pconn) + if err != nil { + logf("failed to set ICMP error immunity: %v", err) + } +} + +// listenOn binds an IPv4 and IPv6 socket to port. We consider it successful if +// we manage to bind the IPv4 socket. +// +// The requested port may be zero, in which case port selection is left up to +// the host networking stack. We make no attempt to bind a consistent port +// across IPv4 and IPv6 if the requested port is zero. +// +// TODO: make these "re-bindable" in similar fashion to magicsock as a means to +// deal with EDR software closing them. http://go/corp/30118. We could re-use +// [magicsock.RebindingConn], which would also remove the need for +// [singlePacketConn], as [magicsock.RebindingConn] also handles fallback to +// single packet syscall operations. +func (s *Server) listenOn(port int) error { + for _, network := range []string{"udp4", "udp6"} { + uc, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + if network == "udp4" { + return err + } else { + s.logf("ignoring IPv6 bind failure: %v", err) + break + } + } + trySetUDPSocketOptions(uc, s.logf) + // TODO: set IP_PKTINFO sockopt + _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) + if err != nil { + uc.Close() + if s.uc4 != nil { + s.uc4.Close() + } + return err + } + portUint, err := strconv.ParseUint(boundPortStr, 10, 16) + if err != nil { + uc.Close() + if s.uc4 != nil { + s.uc4.Close() + } + return err + } + pc := batching.TryUpgradeToConn(uc, network, batching.IdealBatchSize) + bc, ok := pc.(batching.Conn) + if !ok { + bc = &singlePacketConn{uc} + } + if network == "udp4" { + s.uc4 = bc + s.uc4Port = uint16(portUint) + } else { + s.uc6 = bc + s.uc6Port = uint16(portUint) + } + s.logf("listening on %s:%d", network, portUint) + } + return nil +} + +// Close closes the server. +func (s *Server) Close() error { + s.closeOnce.Do(func() { + s.uc4.Close() + if s.uc6 != nil { + s.uc6.Close() + } + close(s.closeCh) + s.wg.Wait() + // s.mu must not be held while s.wg.Wait'ing, otherwise we can + // deadlock. The goroutines we are waiting on to return can also + // acquire s.mu. + s.mu.Lock() + defer s.mu.Unlock() + clear(s.byVNI) + clear(s.byDisco) + s.closed = true + s.bus.Close() + }) + return nil +} + +func (s *Server) endpointGCLoop() { + defer s.wg.Done() + ticker := time.NewTicker(s.bindLifetime) + defer ticker.Stop() + + gc := func() { + now := time.Now() + // TODO: consider performance implications of scanning all endpoints and + // holding s.mu for the duration. Keep it simple (and slow) for now. + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.byDisco { + if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { + delete(s.byDisco, k) + delete(s.byVNI, v.vni) + } + } + } + + for { + select { + case <-ticker.C: + gc() + case <-s.closeCh: + return + } + } +} + +func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to netip.AddrPort) { + if stun.Is(b) && b[1] == 0x01 { + // A b[1] value of 0x01 (STUN method binding) is sufficiently + // non-overlapping with the Geneve header where the LSB is always 0 + // (part of 6 "reserved" bits). + s.netChecker.ReceiveSTUNPacket(b, from) + return nil, netip.AddrPort{} + } + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + return nil, netip.AddrPort{} + } + // TODO: consider performance implications of holding s.mu for the remainder + // of this method, which does a bunch of disco/crypto work depending. Keep + // it simple (and slow) for now. + s.mu.Lock() + defer s.mu.Unlock() + e, ok := s.byVNI[gh.VNI.Get()] + if !ok { + // unknown VNI + return nil, netip.AddrPort{} + } + + return e.handlePacket(from, gh, b, s.discoPublic) +} + +func (s *Server) packetReadLoop(readFromSocket, otherSocket batching.Conn, readFromSocketIsIPv4 bool) { + defer func() { + // We intentionally close the [Server] if we encounter a socket read + // error below, at least until socket "re-binding" is implemented as + // part of http://go/corp/30118. + // + // Decrementing this [sync.WaitGroup] _before_ calling [Server.Close] is + // intentional as [Server.Close] waits on it. + s.wg.Done() + s.Close() + }() + + msgs := make([]ipv6.Message, batching.IdealBatchSize) + for i := range msgs { + msgs[i].OOB = make([]byte, batching.MinControlMessageSize()) + msgs[i].Buffers = make([][]byte, 1) + msgs[i].Buffers[0] = make([]byte, 1<<16-1) + } + writeBuffsByDest := make(map[netip.AddrPort][][]byte, batching.IdealBatchSize) + + for { + for i := range msgs { + msgs[i] = ipv6.Message{Buffers: msgs[i].Buffers, OOB: msgs[i].OOB[:cap(msgs[i].OOB)]} + } + + // TODO: extract laddr from IP_PKTINFO for use in reply + // ReadBatch will split coalesced datagrams before returning, which + // WriteBatchTo will re-coalesce further down. We _could_ be more + // efficient and not split datagrams that belong to the same VNI if they + // are non-control/handshake packets. We pay the memmove/memcopy + // performance penalty for now in the interest of simple single packet + // handlers. + n, err := readFromSocket.ReadBatch(msgs, 0) + if err != nil { + s.logf("error reading from socket(%v): %v", readFromSocket.LocalAddr(), err) + return + } + + for _, msg := range msgs[:n] { + if msg.N == 0 { + continue + } + buf := msg.Buffers[0][:msg.N] + from := msg.Addr.(*net.UDPAddr).AddrPort() + write, to := s.handlePacket(from, buf) + if !to.IsValid() { + continue + } + if from.Addr().Is4() == to.Addr().Is4() || otherSocket != nil { + buffs, ok := writeBuffsByDest[to] + if !ok { + buffs = make([][]byte, 0, batching.IdealBatchSize) + } + buffs = append(buffs, write) + writeBuffsByDest[to] = buffs + } else { + // This is unexpected. We should never produce a packet to write + // to the "other" socket if the other socket is nil/unbound. + // [server.handlePacket] has to see a packet from a particular + // address family at least once in order for it to return a + // packet to write towards a dest for the same address family. + s.logf("[unexpected] packet from: %v produced packet to: %v while otherSocket is nil", from, to) + } + } + + for dest, buffs := range writeBuffsByDest { + // Write the packet batches via the socket associated with the + // destination's address family. If source and destination address + // families are matching we tx on the same socket the packet was + // received, otherwise we use the "other" socket. [Server] makes no + // use of dual-stack sockets. + if dest.Addr().Is4() == readFromSocketIsIPv4 { + readFromSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0) + } else { + otherSocket.WriteBatchTo(buffs, dest, packet.GeneveHeader{}, 0) + } + delete(writeBuffsByDest, dest) + } + } +} + +var ErrServerClosed = errors.New("server closed") + +// ErrServerNotReady indicates the server is not ready. Allocation should be +// requested after waiting for at least RetryAfter duration. +type ErrServerNotReady struct { + RetryAfter time.Duration +} + +func (e ErrServerNotReady) Error() string { + return fmt.Sprintf("server not ready, retry after %v", e.RetryAfter) +} + +// getNextVNILocked returns the next available VNI. It implements the +// "Traditional BSD Port Selection Algorithm" from RFC6056. This algorithm does +// not attempt to obfuscate the selection, i.e. the selection is predictable. +// For now, we favor simplicity and reducing VNI re-use over more complex +// ephemeral port (VNI) selection algorithms. +func (s *Server) getNextVNILocked() (uint32, error) { + for i := uint32(0); i < totalPossibleVNI; i++ { + vni := s.nextVNI + if vni == maxVNI { + s.nextVNI = minVNI + } else { + s.nextVNI++ + } + _, ok := s.byVNI[vni] + if !ok { + return vni, nil + } + } + return 0, errors.New("VNI pool exhausted") +} + +// getAllAddrPortsCopyLocked returns a copy of the combined +// [Server.staticAddrPorts] and [Server.dynamicAddrPorts] slices. +func (s *Server) getAllAddrPortsCopyLocked() []netip.AddrPort { + addrPorts := make([]netip.AddrPort, 0, len(s.dynamicAddrPorts)+s.staticAddrPorts.Len()) + addrPorts = append(addrPorts, s.staticAddrPorts.AsSlice()...) + addrPorts = append(addrPorts, slices.Clone(s.dynamicAddrPorts)...) + return addrPorts +} + +// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair +// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB +// it is returned without modification/reallocation. AllocateEndpoint returns +// the following notable errors: +// 1. [ErrServerClosed] if the server has been closed. +// 2. [ErrServerNotReady] if the server is not ready. +func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.ServerEndpoint, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return endpoint.ServerEndpoint{}, ErrServerClosed + } + + if s.staticAddrPorts.Len() == 0 && len(s.dynamicAddrPorts) == 0 { + return endpoint.ServerEndpoint{}, ErrServerNotReady{RetryAfter: endpoint.ServerRetryAfter} + } + + if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { + return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString()) + } + + pair := key.NewSortedPairOfDiscoPublic(discoA, discoB) + e, ok := s.byDisco[pair] + if ok { + // Return the existing allocation. Clients can resolve duplicate + // [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID]. + // + // TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt + // to give the client a more accurate picture of the bind window. + return endpoint.ServerEndpoint{ + ServerDisco: s.discoPublic, + // Returning the "latest" addrPorts for an existing allocation is + // the simple choice. It may not be the best depending on client + // behaviors and endpoint state (bound or not). We might want to + // consider storing them (maybe interning) in the [*serverEndpoint] + // at allocation time. + ClientDisco: pair.Get(), + AddrPorts: s.getAllAddrPortsCopyLocked(), + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime}, + }, nil + } + + vni, err := s.getNextVNILocked() + if err != nil { + return endpoint.ServerEndpoint{}, err + } + + s.lamportID++ + e = &serverEndpoint{ + discoPubKeys: pair, + lamportID: s.lamportID, + allocatedAt: time.Now(), + vni: vni, + } + e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0]) + e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1]) + + s.byDisco[pair] = e + s.byVNI[e.vni] = e + + s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString()) + return endpoint.ServerEndpoint{ + ServerDisco: s.discoPublic, + ClientDisco: pair.Get(), + AddrPorts: s.getAllAddrPortsCopyLocked(), + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime}, + }, nil +} + +// extractClientInfo constructs a [status.ClientInfo] for one of the two peer +// relay clients involved in this session. +func extractClientInfo(idx int, ep *serverEndpoint) status.ClientInfo { + if idx != 0 && idx != 1 { + panic(fmt.Sprintf("idx passed to extractClientInfo() must be 0 or 1; got %d", idx)) + } + + return status.ClientInfo{ + Endpoint: ep.boundAddrPorts[idx], + ShortDisco: ep.discoPubKeys.Get()[idx].ShortString(), + PacketsTx: ep.packetsRx[idx], + BytesTx: ep.bytesRx[idx], + } +} + +// GetSessions returns a slice of peer relay session statuses, with each +// entry containing detailed info about the server and clients involved in +// each session. This information is intended for debugging/status UX, and +// should not be relied on for any purpose outside of that. +func (s *Server) GetSessions() []status.ServerSession { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + var sessions = make([]status.ServerSession, 0, len(s.byDisco)) + for _, se := range s.byDisco { + c1 := extractClientInfo(0, se) + c2 := extractClientInfo(1, se) + sessions = append(sessions, status.ServerSession{ + VNI: se.vni, + Client1: c1, + Client2: c2, + }) + } + return sessions +} + +// SetDERPMapView sets the [tailcfg.DERPMapView] to use for future netcheck +// reports. +func (s *Server) SetDERPMapView(view tailcfg.DERPMapView) { + s.mu.Lock() + defer s.mu.Unlock() + if !view.Valid() { + s.derpMap = nil + return + } + s.derpMap = view.AsStruct() +} + +func (s *Server) getDERPMap() *tailcfg.DERPMap { + s.mu.Lock() + defer s.mu.Unlock() + return s.derpMap +} + +// SetStaticAddrPorts sets addr:port pairs the [Server] will advertise +// as candidates it is potentially reachable over, in combination with +// dynamically discovered pairs. This replaces any previously-provided static +// values. +func (s *Server) SetStaticAddrPorts(addrPorts views.Slice[netip.AddrPort]) { + s.mu.Lock() + defer s.mu.Unlock() + s.staticAddrPorts = addrPorts +} diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go new file mode 100644 index 000000000..6c3d61658 --- /dev/null +++ b/net/udprelay/server_test.go @@ -0,0 +1,354 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package udprelay + +import ( + "bytes" + "net" + "net/netip" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go4.org/mem" + "tailscale.com/disco" + "tailscale.com/net/packet" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +type testClient struct { + vni uint32 + handshakeGeneration uint32 + local key.DiscoPrivate + remote key.DiscoPublic + server key.DiscoPublic + uc *net.UDPConn +} + +func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, remote, server key.DiscoPublic) *testClient { + rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} + uc, err := net.DialUDP("udp", nil, rAddr) + if err != nil { + t.Fatal(err) + } + return &testClient{ + vni: vni, + handshakeGeneration: 1, + local: local, + remote: remote, + server: server, + uc: uc, + } +} + +func (c *testClient) write(t *testing.T, b []byte) { + _, err := c.uc.Write(b) + if err != nil { + t.Fatal(err) + } +} + +func (c *testClient) read(t *testing.T) []byte { + c.uc.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 1<<16-1) + n, err := c.uc.Read(b) + if err != nil { + t.Fatal(err) + } + return b[:n] +} + +func (c *testClient) writeDataPkt(t *testing.T, b []byte) { + pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b)) + gh := packet.GeneveHeader{Control: false, Protocol: packet.GeneveProtocolWireGuard} + gh.VNI.Set(c.vni) + err := gh.Encode(pkt) + if err != nil { + t.Fatal(err) + } + pkt = append(pkt, b...) + c.write(t, pkt) +} + +func (c *testClient) readDataPkt(t *testing.T) []byte { + b := c.read(t) + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + t.Fatal(err) + } + if gh.Protocol != packet.GeneveProtocolWireGuard { + t.Fatal("unexpected geneve protocol") + } + if gh.Control { + t.Fatal("unexpected control") + } + if gh.VNI.Get() != c.vni { + t.Fatal("unexpected vni") + } + return b[packet.GeneveFixedHeaderLength:] +} + +func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) { + pkt := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, Protocol: packet.GeneveProtocolDisco} + gh.VNI.Set(c.vni) + err := gh.Encode(pkt) + if err != nil { + t.Fatal(err) + } + pkt = append(pkt, disco.Magic...) + pkt = c.local.Public().AppendTo(pkt) + box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil)) + pkt = append(pkt, box...) + c.write(t, pkt) +} + +func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message { + b := c.read(t) + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + t.Fatal(err) + } + if gh.Protocol != packet.GeneveProtocolDisco { + t.Fatal("unexpected geneve protocol") + } + if !gh.Control { + t.Fatal("unexpected non-control") + } + if gh.VNI.Get() != c.vni { + t.Fatal("unexpected vni") + } + b = b[packet.GeneveFixedHeaderLength:] + headerLen := len(disco.Magic) + key.DiscoPublicRawLen + if len(b) < headerLen { + t.Fatal("disco message too short") + } + sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen])) + if sender.Compare(c.server) != 0 { + t.Fatal("unknown disco public key") + } + payload, ok := c.local.Shared(c.server).Open(b[headerLen:]) + if !ok { + t.Fatal("failed to open sealed disco msg") + } + msg, err := disco.Parse(payload) + if err != nil { + t.Fatal("failed to parse disco payload") + } + return msg +} + +func (c *testClient) handshake(t *testing.T) { + generation := c.handshakeGeneration + c.handshakeGeneration++ + common := disco.BindUDPRelayEndpointCommon{ + VNI: c.vni, + Generation: generation, + RemoteKey: c.remote, + } + c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{ + BindUDPRelayEndpointCommon: common, + }) + msg := c.readControlDiscoMsg(t) + challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge) + if !ok { + t.Fatal("unexpected disco message type") + } + if challenge.Generation != common.Generation { + t.Fatalf("rx'd challenge.Generation (%d) != %d", challenge.Generation, common.Generation) + } + if challenge.VNI != common.VNI { + t.Fatalf("rx'd challenge.VNI (%d) != %d", challenge.VNI, common.VNI) + } + if challenge.RemoteKey != common.RemoteKey { + t.Fatalf("rx'd challenge.RemoteKey (%v) != %v", challenge.RemoteKey, common.RemoteKey) + } + answer := &disco.BindUDPRelayEndpointAnswer{ + BindUDPRelayEndpointCommon: common, + } + answer.Challenge = challenge.Challenge + c.writeControlDiscoMsg(t, answer) +} + +func (c *testClient) close() { + c.uc.Close() +} + +func TestServer(t *testing.T) { + discoA := key.NewDisco() + discoB := key.NewDisco() + + cases := []struct { + name string + staticAddrs []netip.Addr + forceClientsMixedAF bool + }{ + { + name: "over ipv4", + staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1")}, + }, + { + name: "over ipv6", + staticAddrs: []netip.Addr{netip.MustParseAddr("::1")}, + }, + { + name: "mixed address families", + staticAddrs: []netip.Addr{netip.MustParseAddr("127.0.0.1"), netip.MustParseAddr("::1")}, + forceClientsMixedAF: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + server, err := NewServer(t.Logf, 0, true) + if err != nil { + t.Fatal(err) + } + defer server.Close() + addrPorts := make([]netip.AddrPort, 0, len(tt.staticAddrs)) + for _, addr := range tt.staticAddrs { + if addr.Is4() { + addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc4Port)) + } else if server.uc6Port != 0 { + addrPorts = append(addrPorts, netip.AddrPortFrom(addr, server.uc6Port)) + } + } + server.SetStaticAddrPorts(views.SliceOf(addrPorts)) + + endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + + // We expect the same endpoint details pre-handshake. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } + + if len(endpoint.AddrPorts) < 1 { + t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) + } + tcAServerEndpointAddr := endpoint.AddrPorts[0] + tcA := newTestClient(t, endpoint.VNI, tcAServerEndpointAddr, discoA, discoB.Public(), endpoint.ServerDisco) + defer tcA.close() + tcBServerEndpointAddr := tcAServerEndpointAddr + if tt.forceClientsMixedAF { + foundMixedAF := false + for _, addr := range endpoint.AddrPorts { + if addr.Addr().Is4() != tcBServerEndpointAddr.Addr().Is4() { + tcBServerEndpointAddr = addr + foundMixedAF = true + } + } + if !foundMixedAF { + t.Fatal("force clients to mixed address families is set, but relay server lacks address family diversity") + } + } + tcB := newTestClient(t, endpoint.VNI, tcBServerEndpointAddr, discoB, discoA.Public(), endpoint.ServerDisco) + defer tcB.close() + + for i := 0; i < 2; i++ { + // We handshake both clients twice to guarantee server-side + // packet reading goroutines, which are independent across + // address families, have seen an answer from both clients + // before proceeding. This is needed because the test assumes + // that B's handshake is complete (the first send is A->B below), + // but the server may not have handled B's handshake answer + // before it handles A's data pkt towards B. + // + // Data transmissions following "re-handshakes" orient so that + // the sender is the same as the party that performed the + // handshake, for the same reasons. + // + // [magicsock.relayManager] is not prone to this issue as both + // parties transmit data packets immediately following their + // handshake answer. + tcA.handshake(t) + tcB.handshake(t) + } + + dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + // We expect the same endpoint details post-handshake. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } + + txToB := []byte{1, 2, 3} + tcA.writeDataPkt(t, txToB) + rxFromA := tcB.readDataPkt(t) + if !bytes.Equal(txToB, rxFromA) { + t.Fatal("unexpected msg A->B") + } + + txToA := []byte{4, 5, 6} + tcB.writeDataPkt(t, txToA) + rxFromB := tcA.readDataPkt(t) + if !bytes.Equal(txToA, rxFromB) { + t.Fatal("unexpected msg B->A") + } + + tcAOnNewPort := newTestClient(t, endpoint.VNI, tcAServerEndpointAddr, discoA, discoB.Public(), endpoint.ServerDisco) + tcAOnNewPort.handshakeGeneration = tcA.handshakeGeneration + 1 + defer tcAOnNewPort.close() + + // Handshake client A on a new source IP:port, verify we can send packets on the new binding + tcAOnNewPort.handshake(t) + + fromAOnNewPort := []byte{7, 8, 9} + tcAOnNewPort.writeDataPkt(t, fromAOnNewPort) + rxFromA = tcB.readDataPkt(t) + if !bytes.Equal(fromAOnNewPort, rxFromA) { + t.Fatal("unexpected msg A->B") + } + + tcBOnNewPort := newTestClient(t, endpoint.VNI, tcBServerEndpointAddr, discoB, discoA.Public(), endpoint.ServerDisco) + tcBOnNewPort.handshakeGeneration = tcB.handshakeGeneration + 1 + defer tcBOnNewPort.close() + + // Handshake client B on a new source IP:port, verify we can send packets on the new binding + tcBOnNewPort.handshake(t) + + fromBOnNewPort := []byte{7, 8, 9} + tcBOnNewPort.writeDataPkt(t, fromBOnNewPort) + rxFromB = tcAOnNewPort.readDataPkt(t) + if !bytes.Equal(fromBOnNewPort, rxFromB) { + t.Fatal("unexpected msg B->A") + } + }) + } +} + +func TestServer_getNextVNILocked(t *testing.T) { + t.Parallel() + c := qt.New(t) + s := &Server{ + nextVNI: minVNI, + byVNI: make(map[uint32]*serverEndpoint), + } + for i := uint64(0); i < uint64(totalPossibleVNI); i++ { + vni, err := s.getNextVNILocked() + if err != nil { // using quicktest here triples test time + t.Fatal(err) + } + s.byVNI[vni] = nil + } + c.Assert(s.nextVNI, qt.Equals, minVNI) + _, err := s.getNextVNILocked() + c.Assert(err, qt.IsNotNil) + delete(s.byVNI, minVNI) + _, err = s.getNextVNILocked() + c.Assert(err, qt.IsNil) +} diff --git a/net/udprelay/status/status.go b/net/udprelay/status/status.go new file mode 100644 index 000000000..3866efada --- /dev/null +++ b/net/udprelay/status/status.go @@ -0,0 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package status contains types relating to the status of peer relay sessions +// between peer relay client nodes via a peer relay server. +package status + +import ( + "net/netip" +) + +// ServerStatus contains the listening UDP port and active sessions (if any) for +// this node's peer relay server at a point in time. +type ServerStatus struct { + // UDPPort is the UDP port number that the peer relay server forwards over, + // as configured by the user with 'tailscale set --relay-server-port='. + // If the port has not been configured, UDPPort will be nil. + UDPPort *int + // Sessions is a slice of detailed status information about each peer + // relay session that this node's peer relay server is involved with. It + // may be empty. + Sessions []ServerSession +} + +// ClientInfo contains status-related information about a single peer relay +// client involved in a single peer relay session. +type ClientInfo struct { + // Endpoint is the [netip.AddrPort] of this peer relay client's underlay + // endpoint participating in the session, or a zero value if the client + // has not completed a handshake. + Endpoint netip.AddrPort + // ShortDisco is a string representation of this peer relay client's disco + // public key. + // + // TODO: disco keys are pretty meaningless to end users, and they are also + // ephemeral. We really need node keys (or translation to first ts addr), + // but those are not fully plumbed into the [udprelay.Server]. Disco keys + // can also be ambiguous to a node key, but we could add node key into a + // [disco.AllocateUDPRelayEndpointRequest] in similar fashion to + // [disco.Ping]. There's also the problem of netmap trimming, where we + // can't verify a node key maps to a disco key. + ShortDisco string + // PacketsTx is the number of packets this peer relay client has sent to + // the other client via the relay server after completing a handshake. This + // is identical to the number of packets that the peer relay server has + // received from this client. + PacketsTx uint64 + // BytesTx is the total overlay bytes this peer relay client has sent to + // the other client via the relay server after completing a handshake. This + // is identical to the total overlay bytes that the peer relay server has + // received from this client. + BytesTx uint64 +} + +// ServerSession contains status information for a single session between two +// peer relay clients, which are relayed via one peer relay server. This is the +// status as seen by the peer relay server; each client node may have a +// different view of the session's current status based on connectivity and +// where the client is in the peer relay endpoint setup (allocation, binding, +// pinging, active). +type ServerSession struct { + // VNI is the Virtual Network Identifier for this peer relay session, which + // comes from the Geneve header and is unique to this session. + VNI uint32 + // Client1 contains status information about one of the two peer relay + // clients involved in this session. Note that 'Client1' does NOT mean this + // was/wasn't the allocating client, or the first client to bind, etc; this + // is just one client of two. + Client1 ClientInfo + // Client2 contains status information about one of the two peer relay + // clients involved in this session. Note that 'Client2' does NOT mean this + // was/wasn't the allocating client, or the second client to bind, etc; this + // is just one client of two. + Client2 ClientInfo +} diff --git a/net/wsconn/wsconn.go b/net/wsconn/wsconn.go index 22b511ea8..9e44da59c 100644 --- a/net/wsconn/wsconn.go +++ b/net/wsconn/wsconn.go @@ -2,9 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause // Package wsconn contains an adapter type that turns -// a websocket connection into a net.Conn. It a temporary fork of the -// netconn.go file from the github.com/coder/websocket package while we wait for -// https://github.com/nhooyr/websocket/pull/350 to be merged. +// a websocket connection into a net.Conn. package wsconn import ( @@ -14,11 +12,11 @@ import ( "math" "net" "os" - "sync" "sync/atomic" "time" "github.com/coder/websocket" + "tailscale.com/syncs" ) // NetConn converts a *websocket.Conn into a net.Conn. @@ -104,7 +102,7 @@ type netConn struct { reading atomic.Bool afterReadDeadline atomic.Bool - readMu sync.Mutex + readMu syncs.Mutex // eofed is true if the reader should return io.EOF from the Read call. // // +checklocks:readMu diff --git a/packages/deb/deb.go b/packages/deb/deb.go index 30e3f2b4d..cab0fea07 100644 --- a/packages/deb/deb.go +++ b/packages/deb/deb.go @@ -166,14 +166,14 @@ var ( func findArchAndVersion(control []byte) (arch string, version string, err error) { b := bytes.NewBuffer(control) for { - l, err := b.ReadBytes('\n') + ln, err := b.ReadBytes('\n') if err != nil { return "", "", err } - if bytes.HasPrefix(l, archKey) { - arch = string(bytes.TrimSpace(l[len(archKey):])) - } else if bytes.HasPrefix(l, versionKey) { - version = string(bytes.TrimSpace(l[len(versionKey):])) + if bytes.HasPrefix(ln, archKey) { + arch = string(bytes.TrimSpace(ln[len(archKey):])) + } else if bytes.HasPrefix(ln, versionKey) { + version = string(bytes.TrimSpace(ln[len(versionKey):])) } if arch != "" && version != "" { return arch, version, nil diff --git a/paths/paths.go b/paths/paths.go index 28c3be02a..6c9c3fa6c 100644 --- a/paths/paths.go +++ b/paths/paths.go @@ -6,6 +6,7 @@ package paths import ( + "log" "os" "path/filepath" "runtime" @@ -70,6 +71,37 @@ func DefaultTailscaledStateFile() string { return "" } +// DefaultTailscaledStateDir returns the default state directory +// to use for tailscaled, for use when the user provided neither +// a state directory or state file path to use. +// +// It returns the empty string if there's no reasonable default. +func DefaultTailscaledStateDir() string { + if runtime.GOOS == "plan9" { + home, err := os.UserHomeDir() + if err != nil { + log.Fatalf("failed to get home directory: %v", err) + } + return filepath.Join(home, "tailscale-state") + } + return filepath.Dir(DefaultTailscaledStateFile()) +} + +// MakeAutomaticStateDir reports whether the platform +// automatically creates the state directory for tailscaled +// when it's absent. +func MakeAutomaticStateDir() bool { + switch runtime.GOOS { + case "plan9": + return true + case "linux": + if distro.Get() == distro.JetKVM { + return true + } + } + return false +} + // MkStateDir ensures that dirPath, the daemon's configuration directory // containing machine keys etc, both exists and has the correct permissions. // We want it to only be accessible to the user the daemon is running under. diff --git a/paths/paths_unix.go b/paths/paths_unix.go index 6a2b28733..d317921d5 100644 --- a/paths/paths_unix.go +++ b/paths/paths_unix.go @@ -21,8 +21,11 @@ func init() { } func statePath() string { + if runtime.GOOS == "linux" && distro.Get() == distro.JetKVM { + return "/userdata/tailscale/var/tailscaled.state" + } switch runtime.GOOS { - case "linux": + case "linux", "illumos", "solaris": return "/var/lib/tailscale/tailscaled.state" case "freebsd", "openbsd": return "/var/db/tailscale/tailscaled.state" diff --git a/pkgdoc_test.go b/pkgdoc_test.go index be08a358b..0f4a45528 100644 --- a/pkgdoc_test.go +++ b/pkgdoc_test.go @@ -26,6 +26,9 @@ func TestPackageDocs(t *testing.T) { if err != nil { return err } + if fi.Mode().IsDir() && path == ".git" { + return filepath.SkipDir // No documentation lives in .git + } if fi.Mode().IsRegular() && strings.HasSuffix(path, ".go") { if strings.HasSuffix(path, "_test.go") { return nil diff --git a/portlist/portlist_plan9.go b/portlist/portlist_plan9.go new file mode 100644 index 000000000..77f8619f9 --- /dev/null +++ b/portlist/portlist_plan9.go @@ -0,0 +1,122 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "bufio" + "bytes" + "os" + "strconv" + "strings" + "time" +) + +func init() { + newOSImpl = newPlan9Impl + + pollInterval = 5 * time.Second +} + +type plan9Impl struct { + known map[protoPort]*portMeta // inode string => metadata + + br *bufio.Reader // reused + portsBuf []Port + includeLocalhost bool +} + +type protoPort struct { + proto string + port uint16 +} + +type portMeta struct { + port Port + keep bool +} + +func newPlan9Impl(includeLocalhost bool) osImpl { + return &plan9Impl{ + known: map[protoPort]*portMeta{}, + br: bufio.NewReader(bytes.NewReader(nil)), + includeLocalhost: includeLocalhost, + } +} + +func (*plan9Impl) Close() error { return nil } + +func (im *plan9Impl) AppendListeningPorts(base []Port) ([]Port, error) { + ret := base + + des, err := os.ReadDir("/proc") + if err != nil { + return nil, err + } + for _, de := range des { + if !de.IsDir() { + continue + } + pidStr := de.Name() + pid, err := strconv.Atoi(pidStr) + if err != nil { + continue + } + st, _ := os.ReadFile("/proc/" + pidStr + "/fd") + if !bytes.Contains(st, []byte("/net/tcp/clone")) { + continue + } + args, _ := os.ReadFile("/proc/" + pidStr + "/args") + procName := string(bytes.TrimSpace(args)) + // term% cat /proc/417/fd + // /usr/glenda + // 0 r M 35 (0000000000000001 0 00) 16384 260 /dev/cons + // 1 w c 0 (000000000000000a 0 00) 0 471 /dev/null + // 2 w M 35 (0000000000000001 0 00) 16384 108 /dev/cons + // 3 rw I 0 (000000000000002c 0 00) 0 14 /net/tcp/clone + for line := range bytes.Lines(st) { + if !bytes.Contains(line, []byte("/net/tcp/clone")) { + continue + } + f := strings.Fields(string(line)) + if len(f) < 10 { + continue + } + if f[9] != "/net/tcp/clone" { + continue + } + qid, err := strconv.ParseUint(strings.TrimPrefix(f[4], "("), 16, 64) + if err != nil { + continue + } + tcpN := (qid >> 5) & (1<<12 - 1) + tcpNStr := strconv.FormatUint(tcpN, 10) + st, _ := os.ReadFile("/net/tcp/" + tcpNStr + "/status") + if !bytes.Contains(st, []byte("Listen ")) { + // Unexpected. Or a race. + continue + } + bl, _ := os.ReadFile("/net/tcp/" + tcpNStr + "/local") + i := bytes.LastIndexByte(bl, '!') + if i == -1 { + continue + } + if bytes.HasPrefix(bl, []byte("127.0.0.1!")) && !im.includeLocalhost { + continue + } + portStr := strings.TrimSpace(string(bl[i+1:])) + port, _ := strconv.Atoi(portStr) + if port == 0 { + continue + } + ret = append(ret, Port{ + Proto: "tcp", + Port: uint16(port), + Process: procName, + Pid: pid, + }) + } + } + + return sortAndDedup(ret), nil +} diff --git a/posture/serialnumber_ios.go b/posture/serialnumber_ios.go deleted file mode 100644 index 55d0e438b..000000000 --- a/posture/serialnumber_ios.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package posture - -import ( - "fmt" - - "tailscale.com/types/logger" - "tailscale.com/util/syspolicy" -) - -// GetSerialNumbers returns the serial number of the iOS/tvOS device as reported by an -// MDM solution. It requires configuration via the DeviceSerialNumber system policy. -// This is the only way to gather serial numbers on iOS and tvOS. -func GetSerialNumbers(_ logger.Logf) ([]string, error) { - s, err := syspolicy.GetString(syspolicy.DeviceSerialNumber, "") - if err != nil { - return nil, fmt.Errorf("failed to get serial number from MDM: %v", err) - } - if s != "" { - return []string{s}, nil - } - return nil, nil -} diff --git a/posture/serialnumber_macos.go b/posture/serialnumber_macos.go index 48355d313..18c929107 100644 --- a/posture/serialnumber_macos.go +++ b/posture/serialnumber_macos.go @@ -59,10 +59,11 @@ import ( "strings" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) // GetSerialNumber returns the platform serial sumber as reported by IOKit. -func GetSerialNumbers(_ logger.Logf) ([]string, error) { +func GetSerialNumbers(policyclient.Client, logger.Logf) ([]string, error) { csn := C.getSerialNumber() serialNumber := C.GoString(csn) diff --git a/posture/serialnumber_macos_test.go b/posture/serialnumber_macos_test.go index 9f0ce1c6a..9d9b9f578 100644 --- a/posture/serialnumber_macos_test.go +++ b/posture/serialnumber_macos_test.go @@ -11,6 +11,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/util/cibuild" + "tailscale.com/util/syspolicy/policyclient" ) func TestGetSerialNumberMac(t *testing.T) { @@ -20,7 +21,7 @@ func TestGetSerialNumberMac(t *testing.T) { t.Skip() } - sns, err := GetSerialNumbers(logger.Discard) + sns, err := GetSerialNumbers(policyclient.NoPolicyClient{}, logger.Discard) if err != nil { t.Fatalf("failed to get serial number: %s", err) } diff --git a/posture/serialnumber_notmacos.go b/posture/serialnumber_notmacos.go index 8b91738b0..132fa08f6 100644 --- a/posture/serialnumber_notmacos.go +++ b/posture/serialnumber_notmacos.go @@ -13,6 +13,7 @@ import ( "github.com/digitalocean/go-smbios/smbios" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) // getByteFromSmbiosStructure retrieves a 8-bit unsigned integer at the given specOffset. @@ -71,7 +72,7 @@ func init() { numOfTables = len(validTables) } -func GetSerialNumbers(logf logger.Logf) ([]string, error) { +func GetSerialNumbers(polc policyclient.Client, logf logger.Logf) ([]string, error) { // Find SMBIOS data in operating system-specific location. rc, _, err := smbios.Stream() if err != nil { diff --git a/posture/serialnumber_notmacos_test.go b/posture/serialnumber_notmacos_test.go index f2a15e037..da5aada85 100644 --- a/posture/serialnumber_notmacos_test.go +++ b/posture/serialnumber_notmacos_test.go @@ -12,6 +12,7 @@ import ( "testing" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) func TestGetSerialNumberNotMac(t *testing.T) { @@ -21,7 +22,7 @@ func TestGetSerialNumberNotMac(t *testing.T) { // Comment out skip for local testing. t.Skip() - sns, err := GetSerialNumbers(logger.Discard) + sns, err := GetSerialNumbers(policyclient.NoPolicyClient{}, logger.Discard) if err != nil { t.Fatalf("failed to get serial number: %s", err) } diff --git a/posture/serialnumber_stub.go b/posture/serialnumber_stub.go index cdabf03e5..854a0014b 100644 --- a/posture/serialnumber_stub.go +++ b/posture/serialnumber_stub.go @@ -1,13 +1,12 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// android: not implemented // js: not implemented // plan9: not implemented // solaris: currently unsupported by go-smbios: // https://github.com/digitalocean/go-smbios/pull/21 -//go:build android || solaris || plan9 || js || wasm || tamago || aix || (darwin && !cgo && !ios) +//go:build solaris || plan9 || js || wasm || tamago || aix || (darwin && !cgo && !ios) package posture @@ -15,9 +14,10 @@ import ( "errors" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) // GetSerialNumber returns client machine serial number(s). -func GetSerialNumbers(_ logger.Logf) ([]string, error) { +func GetSerialNumbers(polc policyclient.Client, _ logger.Logf) ([]string, error) { return nil, errors.New("not implemented") } diff --git a/posture/serialnumber_syspolicy.go b/posture/serialnumber_syspolicy.go new file mode 100644 index 000000000..64a154a2c --- /dev/null +++ b/posture/serialnumber_syspolicy.go @@ -0,0 +1,28 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build android || ios + +package posture + +import ( + "fmt" + + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" +) + +// GetSerialNumbers returns the serial number of the device as reported by an +// MDM solution. It requires configuration via the DeviceSerialNumber system policy. +// This is the only way to gather serial numbers on iOS, tvOS and Android. +func GetSerialNumbers(polc policyclient.Client, _ logger.Logf) ([]string, error) { + s, err := polc.GetString(pkey.DeviceSerialNumber, "") + if err != nil { + return nil, fmt.Errorf("failed to get serial number from MDM: %v", err) + } + if s != "" { + return []string{s}, nil + } + return nil, nil +} diff --git a/posture/serialnumber_test.go b/posture/serialnumber_test.go index fac4392fa..6db3651e2 100644 --- a/posture/serialnumber_test.go +++ b/posture/serialnumber_test.go @@ -7,10 +7,11 @@ import ( "testing" "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/policyclient" ) func TestGetSerialNumber(t *testing.T) { // ensure GetSerialNumbers is implemented // or covered by a stub on a given platform. - _, _ = GetSerialNumbers(logger.Discard) + _, _ = GetSerialNumbers(policyclient.NoPolicyClient{}, logger.Discard) } diff --git a/prober/derp.go b/prober/derp.go index 0dadbe8c2..22843b53a 100644 --- a/prober/derp.go +++ b/prober/derp.go @@ -8,24 +8,35 @@ import ( "cmp" "context" crand "crypto/rand" + "crypto/tls" + "encoding/binary" "encoding/json" "errors" "expvar" "fmt" + "io" "log" + "maps" "net" "net/http" + "net/netip" + "slices" "strconv" "strings" "sync" "time" "github.com/prometheus/client_golang/prometheus" - "tailscale.com/client/tailscale" + wgconn "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "go4.org/netipx" + "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netmon" "tailscale.com/net/stun" + "tailscale.com/net/tstun" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -37,22 +48,32 @@ import ( type derpProber struct { p *Prober derpMapURL string // or "local" + meshKey key.DERPMesh udpInterval time.Duration meshInterval time.Duration tlsInterval time.Duration // Optional bandwidth probing. - bwInterval time.Duration - bwProbeSize int64 + bwInterval time.Duration + bwProbeSize int64 + bwTUNIPv4Prefix *netip.Prefix // or nil to not use TUN + + // Optional queuing delay probing. + qdPacketsPerSecond int // in packets per second + qdPacketTimeout time.Duration + + // Optionally restrict probes to a single regionCodeOrID. + regionCodeOrID string // Probe class for fetching & updating the DERP map. ProbeMap ProbeClass // Probe classes for probing individual derpers. - tlsProbeFn func(string) ProbeClass + tlsProbeFn func(string, *tls.Config) ProbeClass udpProbeFn func(string, int) ProbeClass meshProbeFn func(string, string) ProbeClass bwProbeFn func(string, string, int64) ProbeClass + qdProbeFn func(string, string, int, time.Duration, key.DERPMesh) ProbeClass sync.Mutex lastDERPMap *tailcfg.DERPMap @@ -65,11 +86,30 @@ type DERPOpt func(*derpProber) // WithBandwidthProbing enables bandwidth probing. When enabled, a payload of // `size` bytes will be regularly transferred through each DERP server, and each -// pair of DERP servers in every region. -func WithBandwidthProbing(interval time.Duration, size int64) DERPOpt { +// pair of DERP servers in every region. If tunAddress is specified, probes will +// use a TCP connection over a TUN device at this address in order to exercise +// TCP-in-TCP in similar fashion to TCP over Tailscale via DERP. +func WithBandwidthProbing(interval time.Duration, size int64, tunAddress string) DERPOpt { return func(d *derpProber) { d.bwInterval = interval d.bwProbeSize = size + if tunAddress != "" { + prefix, err := netip.ParsePrefix(fmt.Sprintf("%s/30", tunAddress)) + if err != nil { + log.Fatalf("failed to parse IP prefix from bw-tun-ipv4-addr: %v", err) + } + d.bwTUNIPv4Prefix = &prefix + } + } +} + +// WithQueuingDelayProbing enables/disables queuing delay probing. qdSendRate +// is the number of packets sent per second. qdTimeout is the amount of time +// after which a sent packet is considered to have timed out. +func WithQueuingDelayProbing(qdPacketsPerSecond int, qdPacketTimeout time.Duration) DERPOpt { + return func(d *derpProber) { + d.qdPacketsPerSecond = qdPacketsPerSecond + d.qdPacketTimeout = qdPacketTimeout } } @@ -97,6 +137,20 @@ func WithTLSProbing(interval time.Duration) DERPOpt { } } +// WithRegionCodeOrID restricts probing to the specified region identified by its code +// (e.g. "lax") or its id (e.g. "17"). This is case sensitive. +func WithRegionCodeOrID(regionCode string) DERPOpt { + return func(d *derpProber) { + d.regionCodeOrID = regionCode + } +} + +func WithMeshKey(meshKey key.DERPMesh) DERPOpt { + return func(d *derpProber) { + d.meshKey = meshKey + } +} + // DERP creates a new derpProber. // // If derpMapURL is "local", the DERPMap is fetched via @@ -119,6 +173,7 @@ func DERP(p *Prober, derpMapURL string, opts ...DERPOpt) (*derpProber, error) { d.udpProbeFn = d.ProbeUDP d.meshProbeFn = d.probeMesh d.bwProbeFn = d.probeBandwidth + d.qdProbeFn = d.probeQueuingDelay return d, nil } @@ -135,6 +190,10 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { defer d.Unlock() for _, region := range d.lastDERPMap.Regions { + if d.skipRegion(region) { + continue + } + for _, server := range region.Nodes { labels := Labels{ "region": region.RegionCode, @@ -148,7 +207,7 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { if d.probes[n] == nil { log.Printf("adding DERP TLS probe for %s (%s) every %v", server.Name, region.RegionName, d.tlsInterval) derpPort := cmp.Or(server.DERPPort, 443) - d.probes[n] = d.p.Run(n, d.tlsInterval, labels, d.tlsProbeFn(fmt.Sprintf("%s:%d", server.HostName, derpPort))) + d.probes[n] = d.p.Run(n, d.tlsInterval, labels, d.tlsProbeFn(fmt.Sprintf("%s:%d", server.HostName, derpPort), nil)) } } @@ -181,14 +240,27 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { } } - if d.bwInterval > 0 && d.bwProbeSize > 0 { + if d.bwInterval != 0 && d.bwProbeSize > 0 { n := fmt.Sprintf("derp/%s/%s/%s/bw", region.RegionCode, server.Name, to.Name) wantProbes[n] = true if d.probes[n] == nil { - log.Printf("adding DERP bandwidth probe for %s->%s (%s) %v bytes every %v", server.Name, to.Name, region.RegionName, d.bwProbeSize, d.bwInterval) + tunString := "" + if d.bwTUNIPv4Prefix != nil { + tunString = " (TUN)" + } + log.Printf("adding%s DERP bandwidth probe for %s->%s (%s) %v bytes every %v", tunString, server.Name, to.Name, region.RegionName, d.bwProbeSize, d.bwInterval) d.probes[n] = d.p.Run(n, d.bwInterval, labels, d.bwProbeFn(server.Name, to.Name, d.bwProbeSize)) } } + + if d.qdPacketsPerSecond > 0 { + n := fmt.Sprintf("derp/%s/%s/%s/qd", region.RegionCode, server.Name, to.Name) + wantProbes[n] = true + if d.probes[n] == nil { + log.Printf("adding DERP queuing delay probe for %s->%s (%s)", server.Name, to.Name, region.RegionName) + d.probes[n] = d.p.Run(n, -10*time.Second, labels, d.qdProbeFn(server.Name, to.Name, d.qdPacketsPerSecond, d.qdPacketTimeout, d.meshKey)) + } + } } } } @@ -204,7 +276,7 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { return nil } -// probeMesh returs a probe class that sends a test packet through a pair of DERP +// probeMesh returns a probe class that sends a test packet through a pair of DERP // servers (or just one server, if 'from' and 'to' are the same). 'from' and 'to' // are expected to be names (DERPNode.Name) of two DERP servers in the same region. func (d *derpProber) probeMesh(from, to string) ProbeClass { @@ -220,14 +292,14 @@ func (d *derpProber) probeMesh(from, to string) ProbeClass { } dm := d.lastDERPMap - return derpProbeNodePair(ctx, dm, fromN, toN) + return derpProbeNodePair(ctx, dm, fromN, toN, d.meshKey) }, Class: "derp_mesh", Labels: Labels{"derp_path": derpPath}, } } -// probeBandwidth returs a probe class that sends a payload of a given size +// probeBandwidth returns a probe class that sends a payload of a given size // through a pair of DERP servers (or just one server, if 'from' and 'to' are // the same). 'from' and 'to' are expected to be names (DERPNode.Name) of two // DERP servers in the same region. @@ -236,26 +308,244 @@ func (d *derpProber) probeBandwidth(from, to string, size int64) ProbeClass { if from == to { derpPath = "single" } - var transferTime expvar.Float + var transferTimeSeconds expvar.Float + var totalBytesTransferred expvar.Float + return ProbeClass{ + Probe: func(ctx context.Context) error { + fromN, toN, err := d.getNodePair(from, to) + if err != nil { + return err + } + return derpProbeBandwidth(ctx, d.lastDERPMap, fromN, toN, size, &transferTimeSeconds, &totalBytesTransferred, d.bwTUNIPv4Prefix, d.meshKey) + }, + Class: "derp_bw", + Labels: Labels{ + "derp_path": derpPath, + "tcp_in_tcp": strconv.FormatBool(d.bwTUNIPv4Prefix != nil), + }, + Metrics: func(lb prometheus.Labels) []prometheus.Metric { + metrics := []prometheus.Metric{ + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_probe_size_bytes", "Payload size of the bandwidth prober", nil, lb), prometheus.GaugeValue, float64(size)), + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_transfer_time_seconds_total", "Time it took to transfer data", nil, lb), prometheus.CounterValue, transferTimeSeconds.Value()), + } + if d.bwTUNIPv4Prefix != nil { + // For TCP-in-TCP probes, also record cumulative bytes transferred. + metrics = append(metrics, prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_bytes_total", "Amount of data transferred", nil, lb), prometheus.CounterValue, totalBytesTransferred.Value())) + } + return metrics + }, + } +} + +// probeQueuingDelay returns a probe class that continuously sends packets +// through a pair of DERP servers (or just one server, if 'from' and 'to' are +// the same) at a rate of `packetsPerSecond` packets per second in order to +// measure queuing delays. Packets arriving after `packetTimeout` don't contribute +// to the queuing delay measurement and are recorded as dropped. 'from' and 'to' are +// expected to be names (DERPNode.Name) of two DERP servers in the same region, +// and may refer to the same server. +func (d *derpProber) probeQueuingDelay(from, to string, packetsPerSecond int, packetTimeout time.Duration, meshKey key.DERPMesh) ProbeClass { + derpPath := "mesh" + if from == to { + derpPath = "single" + } + var packetsDropped expvar.Float + qdh := newHistogram([]float64{.005, .01, .025, .05, .1, .25, .5, 1}) return ProbeClass{ Probe: func(ctx context.Context) error { fromN, toN, err := d.getNodePair(from, to) if err != nil { return err } - return derpProbeBandwidth(ctx, d.lastDERPMap, fromN, toN, size, &transferTime) + return derpProbeQueuingDelay(ctx, d.lastDERPMap, fromN, toN, packetsPerSecond, packetTimeout, &packetsDropped, qdh, meshKey) }, - Class: "derp_bw", + Class: "derp_qd", Labels: Labels{"derp_path": derpPath}, - Metrics: func(l prometheus.Labels) []prometheus.Metric { - return []prometheus.Metric{ - prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_probe_size_bytes", "Payload size of the bandwidth prober", nil, l), prometheus.GaugeValue, float64(size)), - prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_transfer_time_seconds_total", "Time it took to transfer data", nil, l), prometheus.CounterValue, transferTime.Value()), + Metrics: func(lb prometheus.Labels) []prometheus.Metric { + qdh.mx.Lock() + result := []prometheus.Metric{ + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_qd_probe_dropped_packets", "Total packets dropped", nil, lb), prometheus.CounterValue, float64(packetsDropped.Value())), + prometheus.MustNewConstHistogram(prometheus.NewDesc("derp_qd_probe_delays_seconds", "Distribution of queuing delays", nil, lb), qdh.count, qdh.sum, maps.Clone(qdh.bucketedCounts)), } + qdh.mx.Unlock() + return result }, } } +// derpProbeQueuingDelay continuously sends data between two local DERP clients +// connected to two DERP servers in order to measure queuing delays. From and to +// can be the same server. +func derpProbeQueuingDelay(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, packetsPerSecond int, packetTimeout time.Duration, packetsDropped *expvar.Float, qdh *histogram, meshKey key.DERPMesh) (err error) { + // This probe uses clients with isProber=false to avoid spamming the derper + // logs with every packet sent by the queuing delay probe. + fromc, err := newConn(ctx, dm, from, false, meshKey) + if err != nil { + return err + } + defer fromc.Close() + toc, err := newConn(ctx, dm, to, false, meshKey) + if err != nil { + return err + } + defer toc.Close() + + // Wait a bit for from's node to hear about to existing on the + // other node in the region, in the case where the two nodes + // are different. + if from.Name != to.Name { + time.Sleep(100 * time.Millisecond) // pretty arbitrary + } + + if err := runDerpProbeQueuingDelayContinously(ctx, from, to, fromc, toc, packetsPerSecond, packetTimeout, packetsDropped, qdh); err != nil { + // Record pubkeys on failed probes to aid investigation. + return fmt.Errorf("%s -> %s: %w", + fromc.SelfPublicKey().ShortString(), + toc.SelfPublicKey().ShortString(), err) + } + return nil +} + +func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg.DERPNode, fromc, toc *derphttp.Client, packetsPerSecond int, packetTimeout time.Duration, packetsDropped *expvar.Float, qdh *histogram) error { + // Make sure all goroutines have finished. + var wg sync.WaitGroup + defer wg.Wait() + + // Close the clients to make sure goroutines that are reading/writing from them terminate. + defer fromc.Close() + defer toc.Close() + + type txRecord struct { + at time.Time + seq uint64 + } + // txRecords is sized to hold enough transmission records to keep timings + // for packets up to their timeout. As records age out of the front of this + // list, if the associated packet arrives, we won't have a txRecord for it + // and will consider it to have timed out. + txRecords := make([]txRecord, 0, packetsPerSecond*int(packetTimeout.Seconds())) + var txRecordsMu sync.Mutex + + // applyTimeouts walks over txRecords and expires any records that are older + // than packetTimeout, recording in metrics that they were removed. + applyTimeouts := func() { + txRecordsMu.Lock() + defer txRecordsMu.Unlock() + + now := time.Now() + recs := txRecords[:0] + for _, r := range txRecords { + if now.Sub(r.at) > packetTimeout { + packetsDropped.Add(1) + } else { + recs = append(recs, r) + } + } + txRecords = recs + } + + // Send the packets. + sendErrC := make(chan error, 1) + // TODO: construct a disco CallMeMaybe in the same fashion as magicsock, e.g. magic bytes, src pub, seal payload. + // DERP server handling of disco may vary from non-disco, and we may want to measure queue delay of both. + pkt := make([]byte, 260) // the same size as a CallMeMaybe packet observed on a Tailscale client. + crand.Read(pkt) + + wg.Add(1) + go func() { + defer wg.Done() + t := time.NewTicker(time.Second / time.Duration(packetsPerSecond)) + defer t.Stop() + + toDERPPubKey := toc.SelfPublicKey() + seq := uint64(0) + for { + select { + case <-ctx.Done(): + return + case <-t.C: + applyTimeouts() + txRecordsMu.Lock() + if len(txRecords) == cap(txRecords) { + txRecords = slices.Delete(txRecords, 0, 1) + packetsDropped.Add(1) + log.Printf("unexpected: overflow in txRecords") + } + txRecords = append(txRecords, txRecord{time.Now(), seq}) + txRecordsMu.Unlock() + binary.BigEndian.PutUint64(pkt, seq) + seq++ + if err := fromc.Send(toDERPPubKey, pkt); err != nil { + sendErrC <- fmt.Errorf("sending packet %w", err) + return + } + } + } + }() + + // Receive the packets. + recvFinishedC := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + defer close(recvFinishedC) // to break out of 'select' below. + fromDERPPubKey := fromc.SelfPublicKey() + for { + m, err := toc.Recv() + if err != nil { + recvFinishedC <- err + return + } + switch v := m.(type) { + case derp.ReceivedPacket: + now := time.Now() + if v.Source != fromDERPPubKey { + recvFinishedC <- fmt.Errorf("got data packet from unexpected source, %v", v.Source) + return + } + seq := binary.BigEndian.Uint64(v.Data) + txRecordsMu.Lock() + findTxRecord: + for i, record := range txRecords { + switch { + case record.seq == seq: + rtt := now.Sub(record.at) + qdh.add(rtt.Seconds()) + txRecords = slices.Delete(txRecords, i, i+1) + break findTxRecord + case record.seq > seq: + // No sent time found, probably a late arrival already + // recorded as drop by sender when deleted. + break findTxRecord + case record.seq < seq: + continue + } + } + txRecordsMu.Unlock() + + case derp.KeepAliveMessage: + // Silently ignore. + + default: + log.Printf("%v: ignoring Recv frame type %T", to.Name, v) + // Loop. + } + } + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("timeout: %w", ctx.Err()) + case err := <-sendErrC: + return fmt.Errorf("error sending via %q: %w", from.Name, err) + case err := <-recvFinishedC: + if err != nil { + return fmt.Errorf("error receiving from %q: %w", to.Name, err) + } + } + return nil +} + // getNodePair returns DERPNode objects for two DERP servers based on their // short names. func (d *derpProber) getNodePair(n1, n2 string) (ret1, ret2 *tailcfg.DERPNode, _ error) { @@ -272,7 +562,7 @@ func (d *derpProber) getNodePair(n1, n2 string) (ret1, ret2 *tailcfg.DERPNode, _ return ret1, ret2, nil } -var tsLocalClient tailscale.LocalClient +var tsLocalClient local.Client // updateMap refreshes the locally-cached DERP map. func (d *derpProber) updateMap(ctx context.Context) error { @@ -316,6 +606,10 @@ func (d *derpProber) updateMap(ctx context.Context) error { d.lastDERPMapAt = time.Now() d.nodes = make(map[string]*tailcfg.DERPNode) for _, reg := range d.lastDERPMap.Regions { + if d.skipRegion(reg) { + continue + } + for _, n := range reg.Nodes { if existing, ok := d.nodes[n.Name]; ok { return fmt.Errorf("derpmap has duplicate nodes: %+v and %+v", existing, n) @@ -330,14 +624,30 @@ func (d *derpProber) updateMap(ctx context.Context) error { } func (d *derpProber) ProbeUDP(ipaddr string, port int) ProbeClass { + initLabels := make(Labels) + ip := net.ParseIP(ipaddr) + + if ip.To4() != nil { + initLabels["address_family"] = "ipv4" + } else if ip.To16() != nil { // Will return an IPv4 as 16 byte, so ensure the check for IPv4 precedes this + initLabels["address_family"] = "ipv6" + } else { + initLabels["address_family"] = "unknown" + } + return ProbeClass{ Probe: func(ctx context.Context) error { return derpProbeUDP(ctx, ipaddr, port) }, - Class: "derp_udp", + Class: "derp_udp", + Labels: initLabels, } } +func (d *derpProber) skipRegion(region *tailcfg.DERPRegion) bool { + return d.regionCodeOrID != "" && region.RegionCode != d.regionCodeOrID && strconv.Itoa(region.RegionID) != d.regionCodeOrID +} + func derpProbeUDP(ctx context.Context, ipStr string, port int) error { pc, err := net.ListenPacket("udp", ":0") if err != nil { @@ -389,16 +699,18 @@ func derpProbeUDP(ctx context.Context, ipStr string, port int) error { } // derpProbeBandwidth sends a payload of a given size between two local -// DERP clients connected to two DERP servers. -func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, size int64, transferTime *expvar.Float) (err error) { +// DERP clients connected to two DERP servers.If tunIPv4Address is specified, +// probes will use a TCP connection over a TUN device at this address in order +// to exercise TCP-in-TCP in similar fashion to TCP over Tailscale via DERP. +func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, size int64, transferTimeSeconds, totalBytesTransferred *expvar.Float, tunIPv4Prefix *netip.Prefix, meshKey key.DERPMesh) (err error) { // This probe uses clients with isProber=false to avoid spamming the derper logs with every packet // sent by the bandwidth probe. - fromc, err := newConn(ctx, dm, from, false) + fromc, err := newConn(ctx, dm, from, false, meshKey) if err != nil { return err } defer fromc.Close() - toc, err := newConn(ctx, dm, to, false) + toc, err := newConn(ctx, dm, to, false, meshKey) if err != nil { return err } @@ -411,10 +723,13 @@ func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tail time.Sleep(100 * time.Millisecond) // pretty arbitrary } - start := time.Now() - defer func() { transferTime.Add(time.Since(start).Seconds()) }() + if tunIPv4Prefix != nil { + err = derpProbeBandwidthTUN(ctx, transferTimeSeconds, totalBytesTransferred, from, to, fromc, toc, size, tunIPv4Prefix) + } else { + err = derpProbeBandwidthDirect(ctx, transferTimeSeconds, from, to, fromc, toc, size) + } - if err := runDerpProbeNodePair(ctx, from, to, fromc, toc, size); err != nil { + if err != nil { // Record pubkeys on failed probes to aid investigation. return fmt.Errorf("%s -> %s: %w", fromc.SelfPublicKey().ShortString(), @@ -425,13 +740,13 @@ func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tail // derpProbeNodePair sends a small packet between two local DERP clients // connected to two DERP servers. -func derpProbeNodePair(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode) (err error) { - fromc, err := newConn(ctx, dm, from, true) +func derpProbeNodePair(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, meshKey key.DERPMesh) (err error) { + fromc, err := newConn(ctx, dm, from, true, meshKey) if err != nil { return err } defer fromc.Close() - toc, err := newConn(ctx, dm, to, true) + toc, err := newConn(ctx, dm, to, true, meshKey) if err != nil { return err } @@ -494,9 +809,10 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc // Send the packets. sendc := make(chan error, 1) go func() { + toDERPPubKey := toc.SelfPublicKey() for idx, pkt := range pkts { inFlight.AcquireContext(ctx) - if err := fromc.Send(toc.SelfPublicKey(), pkt); err != nil { + if err := fromc.Send(toDERPPubKey, pkt); err != nil { sendc <- fmt.Errorf("sending packet %d: %w", idx, err) return } @@ -508,6 +824,7 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc go func() { defer close(recvc) // to break out of 'select' below. idx := 0 + fromDERPPubKey := fromc.SelfPublicKey() for { m, err := toc.Recv() if err != nil { @@ -517,10 +834,12 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc switch v := m.(type) { case derp.ReceivedPacket: inFlight.Release() - if v.Source != fromc.SelfPublicKey() { + if v.Source != fromDERPPubKey { recvc <- fmt.Errorf("got data packet %d from unexpected source, %v", idx, v.Source) return } + // This assumes that the packets are received reliably and in order. + // The DERP protocol does not guarantee this, but this probe assumes it. if got, want := v.Data, pkts[idx]; !bytes.Equal(got, want) { recvc <- fmt.Errorf("unexpected data packet %d (out of %d)", idx, len(pkts)) return @@ -554,13 +873,284 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc return nil } -func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isProber bool) (*derphttp.Client, error) { +// derpProbeBandwidthDirect takes two DERP clients (fromc and toc) connected to two +// DERP servers (from and to) and sends a test payload of a given size from one +// to another using runDerpProbeNodePair. The time taken to finish the transfer is +// recorded in `transferTimeSeconds`. +func derpProbeBandwidthDirect(ctx context.Context, transferTimeSeconds *expvar.Float, from, to *tailcfg.DERPNode, fromc, toc *derphttp.Client, size int64) error { + start := time.Now() + defer func() { transferTimeSeconds.Add(time.Since(start).Seconds()) }() + + return runDerpProbeNodePair(ctx, from, to, fromc, toc, size) +} + +// derpProbeBandwidthTUNMu ensures that TUN bandwidth probes don't run concurrently. +// This is necessary to avoid conflicts trying to create the TUN device, and +// it also has the nice benefit of preventing concurrent bandwidth probes from +// influencing each other's results. +// +// This guards derpProbeBandwidthTUN. +var derpProbeBandwidthTUNMu sync.Mutex + +// derpProbeBandwidthTUN takes two DERP clients (fromc and toc) connected to two +// DERP servers (from and to) and sends a test payload of a given size from one +// to another over a TUN device at an address at the start of the usable host IP +// range that the given tunAddress lives in. The time taken to finish the transfer +// is recorded in `transferTimeSeconds`. +func derpProbeBandwidthTUN(ctx context.Context, transferTimeSeconds, totalBytesTransferred *expvar.Float, from, to *tailcfg.DERPNode, fromc, toc *derphttp.Client, size int64, prefix *netip.Prefix) error { + // Make sure all goroutines have finished. + var wg sync.WaitGroup + defer wg.Wait() + + // Close the clients to make sure goroutines that are reading/writing from them terminate. + defer fromc.Close() + defer toc.Close() + + ipRange := netipx.RangeOfPrefix(*prefix) + // Start of the usable host IP range from the address we have been passed in. + ifAddr := ipRange.From().Next() + // Destination address to dial. This is the next address in the range from + // our ifAddr to ensure that the underlying networking stack is actually being + // utilized instead of being optimized away and treated as a loopback. Packets + // sent to this address will be routed over the TUN. + destinationAddr := ifAddr.Next() + + derpProbeBandwidthTUNMu.Lock() + defer derpProbeBandwidthTUNMu.Unlock() + + // Temporarily set up a TUN device with which to simulate a real client TCP connection + // tunneling over DERP. Use `tstun.DefaultTUNMTU()` (e.g., 1280) as our MTU as this is + // the minimum safe MTU used by Tailscale. + dev, err := tun.CreateTUN(tunName, int(tstun.DefaultTUNMTU())) + if err != nil { + return fmt.Errorf("failed to create TUN device: %w", err) + } + defer func() { + if err := dev.Close(); err != nil { + log.Printf("failed to close TUN device: %s", err) + } + }() + mtu, err := dev.MTU() + if err != nil { + return fmt.Errorf("failed to get TUN MTU: %w", err) + } + + name, err := dev.Name() + if err != nil { + return fmt.Errorf("failed to get device name: %w", err) + } + + // Perform platform specific configuration of the TUN device. + err = configureTUN(*prefix, name) + if err != nil { + return fmt.Errorf("failed to configure tun: %w", err) + } + + // Depending on platform, we need some space for headers at the front + // of TUN I/O op buffers. The below constant is more than enough space + // for any platform that this might run on. + tunStartOffset := device.MessageTransportHeaderSize + + // This goroutine reads packets from the TUN device and evaluates if they + // are IPv4 packets destined for loopback via DERP. If so, it performs L3 NAT + // (swap src/dst) and writes them towards DERP in order to loopback via the + // `toc` DERP client. It only reports errors to `tunReadErrC`. + wg.Add(1) + tunReadErrC := make(chan error, 1) + go func() { + defer wg.Done() + + numBufs := wgconn.IdealBatchSize + bufs := make([][]byte, 0, numBufs) + sizes := make([]int, numBufs) + for range numBufs { + bufs = append(bufs, make([]byte, mtu+tunStartOffset)) + } + + destinationAddrBytes := destinationAddr.AsSlice() + scratch := make([]byte, 4) + toDERPPubKey := toc.SelfPublicKey() + for { + n, err := dev.Read(bufs, sizes, tunStartOffset) + if err != nil { + tunReadErrC <- err + return + } + + for i := range n { + pkt := bufs[i][tunStartOffset : sizes[i]+tunStartOffset] + // Skip everything except valid IPv4 packets + if len(pkt) < 20 { + // Doesn't even have a full IPv4 header + continue + } + if pkt[0]>>4 != 4 { + // Not IPv4 + continue + } + + if !bytes.Equal(pkt[16:20], destinationAddrBytes) { + // Unexpected dst address + continue + } + + copy(scratch, pkt[12:16]) + copy(pkt[12:16], pkt[16:20]) + copy(pkt[16:20], scratch) + + if err := fromc.Send(toDERPPubKey, pkt); err != nil { + tunReadErrC <- err + return + } + } + } + }() + + // This goroutine reads packets from the `toc` DERP client and writes them towards the TUN. + // It only reports errors to `recvErrC` channel. + wg.Add(1) + recvErrC := make(chan error, 1) + go func() { + defer wg.Done() + + buf := make([]byte, mtu+tunStartOffset) + bufs := make([][]byte, 1) + + fromDERPPubKey := fromc.SelfPublicKey() + for { + m, err := toc.Recv() + if err != nil { + recvErrC <- fmt.Errorf("failed to receive: %w", err) + return + } + switch v := m.(type) { + case derp.ReceivedPacket: + if v.Source != fromDERPPubKey { + recvErrC <- fmt.Errorf("got data packet from unexpected source, %v", v.Source) + return + } + pkt := v.Data + copy(buf[tunStartOffset:], pkt) + bufs[0] = buf[:len(pkt)+tunStartOffset] + if _, err := dev.Write(bufs, tunStartOffset); err != nil { + recvErrC <- fmt.Errorf("failed to write to TUN device: %w", err) + return + } + case derp.KeepAliveMessage: + // Silently ignore. + default: + log.Printf("%v: ignoring Recv frame type %T", to.Name, v) + // Loop. + } + } + }() + + // Start a listener to receive the data + ln, err := net.Listen("tcp", net.JoinHostPort(ifAddr.String(), "0")) + if err != nil { + return fmt.Errorf("failed to listen: %s", err) + } + defer ln.Close() + + // 128KB by default + const writeChunkSize = 128 << 10 + + randData := make([]byte, writeChunkSize) + _, err = crand.Read(randData) + if err != nil { + return fmt.Errorf("failed to initialize random data: %w", err) + } + + // Dial ourselves + _, port, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + return fmt.Errorf("failed to split address %q: %w", ln.Addr().String(), err) + } + + connAddr := net.JoinHostPort(destinationAddr.String(), port) + conn, err := net.Dial("tcp", connAddr) + if err != nil { + return fmt.Errorf("failed to dial address %q: %w", connAddr, err) + } + defer conn.Close() + + // Timing only includes the actual sending and receiving of data. + start := time.Now() + + // This goroutine reads data from the TCP stream being looped back via DERP. + // It reports to `readFinishedC` when `size` bytes have been read, or if an + // error occurs. + wg.Add(1) + readFinishedC := make(chan error, 1) + go func() { + defer wg.Done() + + readConn, err := ln.Accept() + if err != nil { + readFinishedC <- err + return + } + defer readConn.Close() + deadline, ok := ctx.Deadline() + if ok { + // Don't try reading past our context's deadline. + if err := readConn.SetReadDeadline(deadline); err != nil { + readFinishedC <- fmt.Errorf("unable to set read deadline: %w", err) + return + } + } + n, err := io.CopyN(io.Discard, readConn, size) + // Measure transfer time and bytes transferred irrespective of whether it succeeded or failed. + transferTimeSeconds.Add(time.Since(start).Seconds()) + totalBytesTransferred.Add(float64(n)) + readFinishedC <- err + }() + + // This goroutine sends data to the TCP stream being looped back via DERP. + // It only reports errors to `sendErrC`. + wg.Add(1) + sendErrC := make(chan error, 1) + go func() { + defer wg.Done() + + for wrote := 0; wrote < int(size); wrote += len(randData) { + b := randData + if wrote+len(randData) > int(size) { + // This is the last chunk and we don't need the whole thing + b = b[0 : int(size)-wrote] + } + if _, err := conn.Write(b); err != nil { + sendErrC <- fmt.Errorf("failed to write to conn: %w", err) + return + } + } + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("timeout: %w", ctx.Err()) + case err := <-tunReadErrC: + return fmt.Errorf("error reading from TUN via %q: %w", from.Name, err) + case err := <-sendErrC: + return fmt.Errorf("error sending via %q: %w", from.Name, err) + case err := <-recvErrC: + return fmt.Errorf("error receiving from %q: %w", to.Name, err) + case err := <-readFinishedC: + if err != nil { + return fmt.Errorf("error reading from %q to TUN: %w", to.Name, err) + } + } + + return nil +} + +func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isProber bool, meshKey key.DERPMesh) (*derphttp.Client, error) { // To avoid spamming the log with regular connection messages. - l := logger.Filtered(log.Printf, func(s string) bool { + logf := logger.Filtered(log.Printf, func(s string) bool { return !strings.Contains(s, "derphttp.Client.Connect: connecting to") }) priv := key.NewNode() - dc := derphttp.NewRegionClient(priv, l, netmon.NewStatic(), func() *tailcfg.DERPRegion { + dc := derphttp.NewRegionClient(priv, logf, netmon.NewStatic(), func() *tailcfg.DERPRegion { rid := n.RegionID return &tailcfg.DERPRegion{ RegionID: rid, @@ -570,22 +1160,27 @@ func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isPr } }) dc.IsProber = isProber + dc.MeshKey = meshKey err := dc.Connect(ctx) if err != nil { return nil, err } - cs, ok := dc.TLSConnectionState() - if !ok { - dc.Close() - return nil, errors.New("no TLS state") - } - if len(cs.PeerCertificates) == 0 { - dc.Close() - return nil, errors.New("no peer certificates") - } - if cs.ServerName != n.HostName { - dc.Close() - return nil, fmt.Errorf("TLS server name %q != derp hostname %q", cs.ServerName, n.HostName) + + // Only verify TLS state if this is a prober. + if isProber { + cs, ok := dc.TLSConnectionState() + if !ok { + dc.Close() + return nil, errors.New("no TLS state") + } + if len(cs.PeerCertificates) == 0 { + dc.Close() + return nil, errors.New("no peer certificates") + } + if cs.ServerName != n.HostName { + dc.Close() + return nil, fmt.Errorf("TLS server name %q != derp hostname %q", cs.ServerName, n.HostName) + } } errc := make(chan error, 1) @@ -599,7 +1194,7 @@ func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isPr case derp.ServerInfoMessage: errc <- nil default: - errc <- fmt.Errorf("unexpected first message type %T", errc) + errc <- fmt.Errorf("unexpected first message type %T", m) } }() select { diff --git a/prober/derp_test.go b/prober/derp_test.go index a34292a23..08a65d697 100644 --- a/prober/derp_test.go +++ b/prober/derp_test.go @@ -16,6 +16,7 @@ import ( "tailscale.com/derp" "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" "tailscale.com/net/netmon" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -44,6 +45,19 @@ func TestDerpProber(t *testing.T) { }, }, }, + 1: { + RegionID: 1, + RegionCode: "one", + Nodes: []*tailcfg.DERPNode{ + { + Name: "n3", + RegionID: 0, + HostName: "derpn3.tailscale.test", + IPv4: "1.1.1.1", + IPv6: "::1", + }, + }, + }, }, } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -58,16 +72,17 @@ func TestDerpProber(t *testing.T) { clk := newFakeTime() p := newForTest(clk.Now, clk.NewTicker) dp := &derpProber{ - p: p, - derpMapURL: srv.URL, - tlsInterval: time.Second, - tlsProbeFn: func(_ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, - udpInterval: time.Second, - udpProbeFn: func(_ string, _ int) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, - meshInterval: time.Second, - meshProbeFn: func(_, _ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, - nodes: make(map[string]*tailcfg.DERPNode), - probes: make(map[string]*Probe), + p: p, + derpMapURL: srv.URL, + tlsInterval: time.Second, + tlsProbeFn: func(_ string, _ *tls.Config) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, + udpInterval: time.Second, + udpProbeFn: func(_ string, _ int) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, + meshInterval: time.Second, + meshProbeFn: func(_, _ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, + nodes: make(map[string]*tailcfg.DERPNode), + probes: make(map[string]*Probe), + regionCodeOrID: "zero", } if err := dp.probeMapFn(context.Background()); err != nil { t.Errorf("unexpected probeMapFn() error: %s", err) @@ -84,9 +99,9 @@ func TestDerpProber(t *testing.T) { // Add one more node and check that probes got created. dm.Regions[0].Nodes = append(dm.Regions[0].Nodes, &tailcfg.DERPNode{ - Name: "n3", + Name: "n4", RegionID: 0, - HostName: "derpn3.tailscale.test", + HostName: "derpn4.tailscale.test", IPv4: "1.1.1.1", IPv6: "::1", }) @@ -113,17 +128,30 @@ func TestDerpProber(t *testing.T) { if len(dp.probes) != 4 { t.Errorf("unexpected probes: %+v", dp.probes) } + + // Stop filtering regions. + dp.regionCodeOrID = "" + if err := dp.probeMapFn(context.Background()); err != nil { + t.Errorf("unexpected probeMapFn() error: %s", err) + } + if len(dp.nodes) != 2 { + t.Errorf("unexpected nodes: %+v", dp.nodes) + } + // 6 regular probes + 2 mesh probe + if len(dp.probes) != 8 { + t.Errorf("unexpected probes: %+v", dp.probes) + } } func TestRunDerpProbeNodePair(t *testing.T) { // os.Setenv("DERP_DEBUG_LOGS", "true") serverPrivateKey := key.NewNode() - s := derp.NewServer(serverPrivateKey, t.Logf) + s := derpserver.New(serverPrivateKey, t.Logf) defer s.Close() httpsrv := &http.Server{ TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), - Handler: derphttp.Handler(s), + Handler: derpserver.Handler(s), } ln, err := net.Listen("tcp4", "localhost:0") if err != nil { diff --git a/prober/dns_example_test.go b/prober/dns_example_test.go index a8326fd72..089816919 100644 --- a/prober/dns_example_test.go +++ b/prober/dns_example_test.go @@ -5,6 +5,7 @@ package prober_test import ( "context" + "crypto/tls" "flag" "fmt" "log" @@ -40,7 +41,7 @@ func ExampleForEachAddr() { // This function is called every time we discover a new IP address to check. makeTLSProbe := func(addr netip.Addr) []*prober.Probe { - pf := prober.TLSWithIP(*hostname, netip.AddrPortFrom(addr, 443)) + pf := prober.TLSWithIP(netip.AddrPortFrom(addr, 443), &tls.Config{ServerName: *hostname}) if *verbose { logger := logger.WithPrefix(log.Printf, fmt.Sprintf("[tls %s]: ", addr)) pf = probeLogWrapper(logger, pf) diff --git a/prober/histogram.go b/prober/histogram.go new file mode 100644 index 000000000..c544a5f79 --- /dev/null +++ b/prober/histogram.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober + +import ( + "slices" + "sync" +) + +// histogram serves as an adapter to the Prometheus histogram datatype. +// The prober framework passes labels at custom metric collection time that +// it expects to be coupled with the returned metrics. See ProbeClass.Metrics +// and its call sites. Native prometheus histograms cannot be collected while +// injecting more labels. Instead we use this type and pass observations + +// collection labels to prometheus.MustNewConstHistogram() at prometheus +// metric collection time. +type histogram struct { + count uint64 + sum float64 + buckets []float64 + bucketedCounts map[float64]uint64 + mx sync.Mutex +} + +// newHistogram constructs a histogram that buckets data based on the given +// slice of upper bounds. +func newHistogram(buckets []float64) *histogram { + slices.Sort(buckets) + return &histogram{ + buckets: buckets, + bucketedCounts: make(map[float64]uint64, len(buckets)), + } +} + +func (h *histogram) add(v float64) { + h.mx.Lock() + defer h.mx.Unlock() + + h.count++ + h.sum += v + + for _, b := range h.buckets { + if v > b { + continue + } + h.bucketedCounts[b] += 1 + } +} diff --git a/prober/histogram_test.go b/prober/histogram_test.go new file mode 100644 index 000000000..dbb5eda67 --- /dev/null +++ b/prober/histogram_test.go @@ -0,0 +1,29 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestHistogram(t *testing.T) { + h := newHistogram([]float64{1, 2}) + h.add(0.5) + h.add(1) + h.add(1.5) + h.add(2) + h.add(2.5) + + if diff := cmp.Diff(h.count, uint64(5)); diff != "" { + t.Errorf("wrong count; (-got+want):%v", diff) + } + if diff := cmp.Diff(h.sum, 7.5); diff != "" { + t.Errorf("wrong sum; (-got+want):%v", diff) + } + if diff := cmp.Diff(h.bucketedCounts, map[float64]uint64{1: 2, 2: 4}); diff != "" { + t.Errorf("wrong bucketedCounts; (-got+want):%v", diff) + } +} diff --git a/prober/prober.go b/prober/prober.go index 2a43628bd..6b904dd97 100644 --- a/prober/prober.go +++ b/prober/prober.go @@ -7,6 +7,8 @@ package prober import ( + "bytes" + "cmp" "container/ring" "context" "encoding/json" @@ -16,10 +18,13 @@ import ( "maps" "math/rand" "net/http" + "slices" "sync" "time" "github.com/prometheus/client_golang/prometheus" + "golang.org/x/sync/errgroup" + "tailscale.com/syncs" "tailscale.com/tsweb" ) @@ -44,6 +49,14 @@ type ProbeClass struct { // exposed by this probe class. Labels Labels + // Timeout is the maximum time the probe function is allowed to run before + // its context is cancelled. Defaults to 80% of the scheduling interval. + Timeout time.Duration + + // Concurrency is the maximum number of concurrent probe executions + // allowed for this probe class. Defaults to 1. + Concurrency int + // Metrics allows a probe class to export custom Metrics. Can be nil. Metrics func(prometheus.Labels) []prometheus.Metric } @@ -94,6 +107,9 @@ func newForTest(now func() time.Time, newTicker func(time.Duration) ticker) *Pro // Run executes probe class function every interval, and exports probe results under probeName. // +// If interval is negative, the probe will run continuously. If it encounters a failure while +// running continuously, it will pause for -1*interval and then retry. +// // Registering a probe under an already-registered name panics. func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc ProbeClass) *Probe { p.mu.Lock() @@ -102,25 +118,25 @@ func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc Prob panic(fmt.Sprintf("probe named %q already registered", name)) } - l := prometheus.Labels{ + lb := prometheus.Labels{ "name": name, "class": pc.Class, } for k, v := range pc.Labels { - l[k] = v + lb[k] = v } for k, v := range labels { - l[k] = v + lb[k] = v } - probe := newProbe(p, name, interval, l, pc) + probe := newProbe(p, name, interval, lb, pc) p.probes[name] = probe go probe.loop() return probe } // newProbe creates a new Probe with the given parameters, but does not start it. -func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Labels, pc ProbeClass) *Probe { +func newProbe(p *Prober, name string, interval time.Duration, lg prometheus.Labels, pc ProbeClass) *Probe { ctx, cancel := context.WithCancel(context.Background()) probe := &Probe{ prober: p, @@ -128,25 +144,28 @@ func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Label cancel: cancel, stopped: make(chan struct{}), + runSema: syncs.NewSemaphore(cmp.Or(pc.Concurrency, 1)), + name: name, probeClass: pc, interval: interval, + timeout: cmp.Or(pc.Timeout, time.Duration(float64(interval)*0.8)), initialDelay: initialDelay(name, interval), successHist: ring.New(recentHistSize), latencyHist: ring.New(recentHistSize), metrics: prometheus.NewRegistry(), - metricLabels: l, - mInterval: prometheus.NewDesc("interval_secs", "Probe interval in seconds", nil, l), - mStartTime: prometheus.NewDesc("start_secs", "Latest probe start time (seconds since epoch)", nil, l), - mEndTime: prometheus.NewDesc("end_secs", "Latest probe end time (seconds since epoch)", nil, l), - mLatency: prometheus.NewDesc("latency_millis", "Latest probe latency (ms)", nil, l), - mResult: prometheus.NewDesc("result", "Latest probe result (1 = success, 0 = failure)", nil, l), + metricLabels: lg, + mInterval: prometheus.NewDesc("interval_secs", "Probe interval in seconds", nil, lg), + mStartTime: prometheus.NewDesc("start_secs", "Latest probe start time (seconds since epoch)", nil, lg), + mEndTime: prometheus.NewDesc("end_secs", "Latest probe end time (seconds since epoch)", nil, lg), + mLatency: prometheus.NewDesc("latency_millis", "Latest probe latency (ms)", nil, lg), + mResult: prometheus.NewDesc("result", "Latest probe result (1 = success, 0 = failure)", nil, lg), mAttempts: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "attempts_total", Help: "Total number of probing attempts", ConstLabels: l, + Name: "attempts_total", Help: "Total number of probing attempts", ConstLabels: lg, }, []string{"status"}), mSeconds: prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "seconds_total", Help: "Total amount of time spent executing the probe", ConstLabels: l, + Name: "seconds_total", Help: "Total amount of time spent executing the probe", ConstLabels: lg, }, []string{"status"}), } if p.metrics != nil { @@ -223,11 +242,12 @@ type Probe struct { ctx context.Context cancel context.CancelFunc // run to initiate shutdown stopped chan struct{} // closed when shutdown is complete - runMu sync.Mutex // ensures only one probe runs at a time + runSema syncs.Semaphore // restricts concurrency per probe name string probeClass ProbeClass interval time.Duration + timeout time.Duration initialDelay time.Duration tick ticker @@ -256,6 +276,11 @@ type Probe struct { latencyHist *ring.Ring } +// IsContinuous indicates that this is a continuous probe. +func (p *Probe) IsContinuous() bool { + return p.interval < 0 +} + // Close shuts down the Probe and unregisters it from its Prober. // It is safe to Run a new probe of the same name after Close returns. func (p *Probe) Close() error { @@ -274,26 +299,43 @@ func (p *Probe) loop() { t := p.prober.newTicker(p.initialDelay) select { case <-t.Chan(): - p.run() case <-p.ctx.Done(): t.Stop() return } t.Stop() - } else { - p.run() } if p.prober.once { + p.run() return } + if p.IsContinuous() { + // Probe function is going to run continuously. + for { + p.run() + // Wait and then retry if probe fails. We use the inverse of the + // configured negative interval as our sleep period. + // TODO(percy):implement exponential backoff, possibly using util/backoff. + select { + case <-time.After(-1 * p.interval): + p.run() + case <-p.ctx.Done(): + return + } + } + } + p.tick = p.prober.newTicker(p.interval) defer p.tick.Stop() for { + // Run the probe in a new goroutine every tick. Default concurrency & timeout + // settings will ensure that only one probe is running at a time. + go p.run() + select { case <-p.tick.Chan(): - p.run() case <-p.ctx.Done(): return } @@ -307,8 +349,13 @@ func (p *Probe) loop() { // that the probe either succeeds or fails before the next cycle is scheduled to // start. func (p *Probe) run() (pi ProbeInfo, err error) { - p.runMu.Lock() - defer p.runMu.Unlock() + // Probes are scheduled each p.interval, so we don't wait longer than that. + semaCtx, cancel := context.WithTimeout(p.ctx, p.interval) + defer cancel() + if !p.runSema.AcquireContext(semaCtx) { + return pi, fmt.Errorf("probe %s: context cancelled", p.name) + } + defer p.runSema.Release() p.recordStart() defer func() { @@ -320,15 +367,21 @@ func (p *Probe) run() (pi ProbeInfo, err error) { if r := recover(); r != nil { log.Printf("probe %s panicked: %v", p.name, r) err = fmt.Errorf("panic: %v", r) - p.recordEnd(err) + p.recordEndLocked(err) } }() - timeout := time.Duration(float64(p.interval) * 0.8) - ctx, cancel := context.WithTimeout(p.ctx, timeout) - defer cancel() + ctx := p.ctx + if !p.IsContinuous() { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, p.timeout) + defer cancel() + } err = p.probeClass.Probe(ctx) - p.recordEnd(err) + + p.mu.Lock() + defer p.mu.Unlock() + p.recordEndLocked(err) if err != nil { log.Printf("probe %s: %v", p.name, err) } @@ -342,10 +395,8 @@ func (p *Probe) recordStart() { p.mu.Unlock() } -func (p *Probe) recordEnd(err error) { +func (p *Probe) recordEndLocked(err error) { end := p.prober.now() - p.mu.Lock() - defer p.mu.Unlock() p.end = end p.succeeded = err == nil p.lastErr = err @@ -356,15 +407,29 @@ func (p *Probe) recordEnd(err error) { p.mSeconds.WithLabelValues("ok").Add(latency.Seconds()) p.latencyHist.Value = latency p.latencyHist = p.latencyHist.Next() + p.mAttempts.WithLabelValues("fail").Add(0) + p.mSeconds.WithLabelValues("fail").Add(0) } else { p.latency = 0 p.mAttempts.WithLabelValues("fail").Inc() p.mSeconds.WithLabelValues("fail").Add(latency.Seconds()) + p.mAttempts.WithLabelValues("ok").Add(0) + p.mSeconds.WithLabelValues("ok").Add(0) } p.successHist.Value = p.succeeded p.successHist = p.successHist.Next() } +// ProbeStatus indicates the status of a probe. +type ProbeStatus string + +const ( + ProbeStatusUnknown = "unknown" + ProbeStatusRunning = "running" + ProbeStatusFailed = "failed" + ProbeStatusSucceeded = "succeeded" +) + // ProbeInfo is a snapshot of the configuration and state of a Probe. type ProbeInfo struct { Name string @@ -374,7 +439,7 @@ type ProbeInfo struct { Start time.Time End time.Time Latency time.Duration - Result bool + Status ProbeStatus Error string RecentResults []bool RecentLatencies []time.Duration @@ -402,6 +467,10 @@ func (pb ProbeInfo) RecentMedianLatency() time.Duration { return pb.RecentLatencies[len(pb.RecentLatencies)/2] } +func (pb ProbeInfo) Continuous() bool { + return pb.Interval < 0 +} + // ProbeInfo returns the state of all probes. func (p *Prober) ProbeInfo() map[string]ProbeInfo { out := map[string]ProbeInfo{} @@ -429,17 +498,22 @@ func (probe *Probe) probeInfoLocked() ProbeInfo { Labels: probe.metricLabels, Start: probe.start, End: probe.end, - Result: probe.succeeded, } - if probe.lastErr != nil { + inf.Status = ProbeStatusUnknown + if probe.end.Before(probe.start) { + inf.Status = ProbeStatusRunning + } else if probe.succeeded { + inf.Status = ProbeStatusSucceeded + } else if probe.lastErr != nil { + inf.Status = ProbeStatusFailed inf.Error = probe.lastErr.Error() } if probe.latency > 0 { inf.Latency = probe.latency } probe.latencyHist.Do(func(v any) { - if l, ok := v.(time.Duration); ok { - inf.RecentLatencies = append(inf.RecentLatencies, l) + if latency, ok := v.(time.Duration); ok { + inf.RecentLatencies = append(inf.RecentLatencies, latency) } }) probe.successHist.Do(func(v any) { @@ -467,7 +541,7 @@ func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error { p.mu.Lock() probe, ok := p.probes[name] p.mu.Unlock() - if !ok { + if !ok || probe.IsContinuous() { return tsweb.Error(http.StatusNotFound, fmt.Sprintf("unknown probe %q", name), nil) } @@ -488,22 +562,84 @@ func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error { PreviousSuccessRatio: prevInfo.RecentSuccessRatio(), PreviousMedianLatency: prevInfo.RecentMedianLatency(), } - w.WriteHeader(respStatus) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(respStatus) if err := json.NewEncoder(w).Encode(resp); err != nil { return tsweb.Error(http.StatusInternalServerError, "error encoding JSON response", err) } return nil } - stats := fmt.Sprintf("Last %d probes: success rate %d%%, median latency %v\n", - len(prevInfo.RecentResults), - int(prevInfo.RecentSuccessRatio()*100), prevInfo.RecentMedianLatency()) + stats := fmt.Sprintf("Last %d probes (including this one): success rate %d%%, median latency %v\n", + len(info.RecentResults), + int(info.RecentSuccessRatio()*100), info.RecentMedianLatency()) if err != nil { return tsweb.Error(respStatus, fmt.Sprintf("Probe failed: %s\n%s", err.Error(), stats), err) } w.WriteHeader(respStatus) - w.Write([]byte(fmt.Sprintf("Probe succeeded in %v\n%s", info.Latency, stats))) + fmt.Fprintf(w, "Probe succeeded in %v\n%s", info.Latency, stats) + return nil +} + +type RunHandlerAllResponse struct { + Results map[string]RunHandlerResponse +} + +func (p *Prober) RunAllHandler(w http.ResponseWriter, r *http.Request) error { + excluded := r.URL.Query()["exclude"] + + probes := make(map[string]*Probe) + p.mu.Lock() + for _, probe := range p.probes { + if !probe.IsContinuous() && !slices.Contains(excluded, probe.name) { + probes[probe.name] = probe + } + } + p.mu.Unlock() + + // Do not abort running probes just because one of them has failed. + g := new(errgroup.Group) + + var resultsMu sync.Mutex + results := make(map[string]RunHandlerResponse) + + for name, probe := range probes { + g.Go(func() error { + probe.mu.Lock() + prevInfo := probe.probeInfoLocked() + probe.mu.Unlock() + + info, err := probe.run() + + resultsMu.Lock() + results[name] = RunHandlerResponse{ + ProbeInfo: info, + PreviousSuccessRatio: prevInfo.RecentSuccessRatio(), + PreviousMedianLatency: prevInfo.RecentMedianLatency(), + } + resultsMu.Unlock() + return err + }) + } + + respStatus := http.StatusOK + if err := g.Wait(); err != nil { + respStatus = http.StatusFailedDependency + } + + // Return serialized JSON response if the client requested JSON + resp := &RunHandlerAllResponse{ + Results: results, + } + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(resp); err != nil { + return tsweb.Error(http.StatusInternalServerError, "error encoding JSON response", err) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(respStatus) + w.Write(b.Bytes()) + return nil } @@ -531,7 +667,8 @@ func (p *Probe) Collect(ch chan<- prometheus.Metric) { if !p.start.IsZero() { ch <- prometheus.MustNewConstMetric(p.mStartTime, prometheus.GaugeValue, float64(p.start.Unix())) } - if p.end.IsZero() { + // For periodic probes that haven't ended, don't collect probe metrics yet. + if p.end.IsZero() && !p.IsContinuous() { return } ch <- prometheus.MustNewConstMetric(p.mEndTime, prometheus.GaugeValue, float64(p.end.Unix())) @@ -582,8 +719,8 @@ func initialDelay(seed string, interval time.Duration) time.Duration { // Labels is a set of metric labels used by a prober. type Labels map[string]string -func (l Labels) With(k, v string) Labels { - new := maps.Clone(l) +func (lb Labels) With(k, v string) Labels { + new := maps.Clone(lb) new[k] = v return new } diff --git a/prober/prober_test.go b/prober/prober_test.go index 742a914b2..1e045fa89 100644 --- a/prober/prober_test.go +++ b/prober/prober_test.go @@ -9,7 +9,10 @@ import ( "errors" "fmt" "io" + "net/http" "net/http/httptest" + "net/url" + "regexp" "strings" "sync" "sync/atomic" @@ -149,6 +152,74 @@ func TestProberTimingSpread(t *testing.T) { notCalled() } +func TestProberTimeout(t *testing.T) { + clk := newFakeTime() + p := newForTest(clk.Now, clk.NewTicker) + + var done sync.WaitGroup + done.Add(1) + pfunc := FuncProbe(func(ctx context.Context) error { + defer done.Done() + select { + case <-ctx.Done(): + return ctx.Err() + } + }) + pfunc.Timeout = time.Microsecond + probe := p.Run("foo", 30*time.Second, nil, pfunc) + waitActiveProbes(t, p, clk, 1) + done.Wait() + probe.mu.Lock() + info := probe.probeInfoLocked() + probe.mu.Unlock() + wantInfo := ProbeInfo{ + Name: "foo", + Interval: 30 * time.Second, + Labels: map[string]string{"class": "", "name": "foo"}, + Status: ProbeStatusFailed, + Error: "context deadline exceeded", + RecentResults: []bool{false}, + RecentLatencies: nil, + } + if diff := cmp.Diff(wantInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Latency")); diff != "" { + t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff) + } + if got := info.Latency; got > time.Second { + t.Errorf("info.Latency = %v, want at most 1s", got) + } +} + +func TestProberConcurrency(t *testing.T) { + clk := newFakeTime() + p := newForTest(clk.Now, clk.NewTicker) + + var ran atomic.Int64 + stopProbe := make(chan struct{}) + pfunc := FuncProbe(func(ctx context.Context) error { + ran.Add(1) + <-stopProbe + return nil + }) + pfunc.Timeout = time.Hour + pfunc.Concurrency = 3 + p.Run("foo", time.Second, nil, pfunc) + waitActiveProbes(t, p, clk, 1) + + for range 50 { + clk.Advance(time.Second) + } + + if err := tstest.WaitFor(convergenceTimeout, func() error { + if got, want := ran.Load(), int64(3); got != want { + return fmt.Errorf("expected %d probes to run concurrently, got %d", want, got) + } + return nil + }); err != nil { + t.Fatal(err) + } + close(stopProbe) +} + func TestProberRun(t *testing.T) { clk := newFakeTime() p := newForTest(clk.Now, clk.NewTicker) @@ -316,7 +387,7 @@ func TestProberProbeInfo(t *testing.T) { Interval: probeInterval, Labels: map[string]string{"class": "", "name": "probe1"}, Latency: 500 * time.Millisecond, - Result: true, + Status: ProbeStatusSucceeded, RecentResults: []bool{true}, RecentLatencies: []time.Duration{500 * time.Millisecond}, }, @@ -324,6 +395,7 @@ func TestProberProbeInfo(t *testing.T) { Name: "probe2", Interval: probeInterval, Labels: map[string]string{"class": "", "name": "probe2"}, + Status: ProbeStatusFailed, Error: "error2", RecentResults: []bool{false}, RecentLatencies: nil, // no latency for failed probes @@ -349,7 +421,7 @@ func TestProbeInfoRecent(t *testing.T) { }{ { name: "no_runs", - wantProbeInfo: ProbeInfo{}, + wantProbeInfo: ProbeInfo{Status: ProbeStatusUnknown}, wantRecentSuccessRatio: 0, wantRecentMedianLatency: 0, }, @@ -358,7 +430,7 @@ func TestProbeInfoRecent(t *testing.T) { results: []probeResult{{latency: 100 * time.Millisecond, err: nil}}, wantProbeInfo: ProbeInfo{ Latency: 100 * time.Millisecond, - Result: true, + Status: ProbeStatusSucceeded, RecentResults: []bool{true}, RecentLatencies: []time.Duration{100 * time.Millisecond}, }, @@ -369,7 +441,7 @@ func TestProbeInfoRecent(t *testing.T) { name: "single_failure", results: []probeResult{{latency: 100 * time.Millisecond, err: errors.New("error123")}}, wantProbeInfo: ProbeInfo{ - Result: false, + Status: ProbeStatusFailed, RecentResults: []bool{false}, RecentLatencies: nil, Error: "error123", @@ -390,7 +462,7 @@ func TestProbeInfoRecent(t *testing.T) { {latency: 80 * time.Millisecond, err: nil}, }, wantProbeInfo: ProbeInfo{ - Result: true, + Status: ProbeStatusSucceeded, Latency: 80 * time.Millisecond, RecentResults: []bool{false, true, true, false, true, true, false, true}, RecentLatencies: []time.Duration{ @@ -420,7 +492,7 @@ func TestProbeInfoRecent(t *testing.T) { {latency: 110 * time.Millisecond, err: nil}, }, wantProbeInfo: ProbeInfo{ - Result: true, + Status: ProbeStatusSucceeded, Latency: 110 * time.Millisecond, RecentResults: []bool{true, true, true, true, true, true, true, true, true, true}, RecentLatencies: []time.Duration{ @@ -449,9 +521,11 @@ func TestProbeInfoRecent(t *testing.T) { for _, r := range tt.results { probe.recordStart() clk.Advance(r.latency) - probe.recordEnd(r.err) + probe.recordEndLocked(r.err) } + probe.mu.Lock() info := probe.probeInfoLocked() + probe.mu.Unlock() if diff := cmp.Diff(tt.wantProbeInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Interval")); diff != "" { t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff) } @@ -473,7 +547,7 @@ func TestProberRunHandler(t *testing.T) { probeFunc func(context.Context) error wantResponseCode int wantJSONResponse RunHandlerResponse - wantPlaintextResponse string + wantPlaintextResponse *regexp.Regexp }{ { name: "success", @@ -483,12 +557,12 @@ func TestProberRunHandler(t *testing.T) { ProbeInfo: ProbeInfo{ Name: "success", Interval: probeInterval, - Result: true, + Status: ProbeStatusSucceeded, RecentResults: []bool{true, true}, }, PreviousSuccessRatio: 1, }, - wantPlaintextResponse: "Probe succeeded", + wantPlaintextResponse: regexp.MustCompile("(?s)Probe succeeded .*Last 2 probes.*success rate 100%"), }, { name: "failure", @@ -498,12 +572,12 @@ func TestProberRunHandler(t *testing.T) { ProbeInfo: ProbeInfo{ Name: "failure", Interval: probeInterval, - Result: false, + Status: ProbeStatusFailed, Error: "error123", RecentResults: []bool{false, false}, }, }, - wantPlaintextResponse: "Probe failed", + wantPlaintextResponse: regexp.MustCompile("(?s)Probe failed: .*Last 2 probes.*success rate 0%"), }, } @@ -515,29 +589,51 @@ func TestProberRunHandler(t *testing.T) { defer probe.Close() <-probe.stopped // wait for the first run. - w := httptest.NewRecorder() + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.Handle("/prober/run/", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{})) + + req, err := http.NewRequest("GET", server.URL+"/prober/run/?name="+tt.name, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } - req := httptest.NewRequest("GET", "/prober/run/?name="+tt.name, nil) if reqJSON { req.Header.Set("Accept", "application/json") } - tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{}).ServeHTTP(w, req) - if w.Result().StatusCode != tt.wantResponseCode { - t.Errorf("unexpected response code: got %d, want %d", w.Code, tt.wantResponseCode) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to make request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tt.wantResponseCode { + t.Errorf("unexpected response code: got %d, want %d", resp.StatusCode, tt.wantResponseCode) } if reqJSON { + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: got %q, want application/json", resp.Header.Get("Content-Type")) + } var gotJSON RunHandlerResponse - if err := json.Unmarshal(w.Body.Bytes(), &gotJSON); err != nil { - t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, w.Body.String()) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if err := json.Unmarshal(body, &gotJSON); err != nil { + t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, body) } if diff := cmp.Diff(tt.wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" { t.Errorf("unexpected JSON response (-want +got):\n%s", diff) } } else { - body, _ := io.ReadAll(w.Result().Body) - if !strings.Contains(string(body), tt.wantPlaintextResponse) { - t.Errorf("unexpected response body: got %q, want to contain %q", body, tt.wantPlaintextResponse) + body, _ := io.ReadAll(resp.Body) + if !tt.wantPlaintextResponse.MatchString(string(body)) { + t.Errorf("unexpected response body: got %q, want to match %q", body, tt.wantPlaintextResponse) } } }) @@ -546,6 +642,190 @@ func TestProberRunHandler(t *testing.T) { } +func TestRunAllHandler(t *testing.T) { + clk := newFakeTime() + + tests := []struct { + name string + probeFunc []func(context.Context) error + wantResponseCode int + wantJSONResponse RunHandlerAllResponse + wantPlaintextResponse string + }{ + { + name: "successProbe", + probeFunc: []func(context.Context) error{func(context.Context) error { return nil }, func(context.Context) error { return nil }}, + wantResponseCode: http.StatusOK, + wantJSONResponse: RunHandlerAllResponse{ + Results: map[string]RunHandlerResponse{ + "successProbe-0": { + ProbeInfo: ProbeInfo{ + Name: "successProbe-0", + Interval: probeInterval, + Status: ProbeStatusSucceeded, + RecentResults: []bool{true, true}, + }, + PreviousSuccessRatio: 1, + }, + "successProbe-1": { + ProbeInfo: ProbeInfo{ + Name: "successProbe-1", + Interval: probeInterval, + Status: ProbeStatusSucceeded, + RecentResults: []bool{true, true}, + }, + PreviousSuccessRatio: 1, + }, + }, + }, + wantPlaintextResponse: "Probe successProbe-0: succeeded\n\tLast run: 0s\n\tPrevious success rate: 100.0%\n\tPrevious median latency: 0s\nProbe successProbe-1: succeeded\n\tLast run: 0s\n\tPrevious success rate: 100.0%\n\tPrevious median latency: 0s\n\n", + }, + { + name: "successAndFailureProbes", + probeFunc: []func(context.Context) error{func(context.Context) error { return nil }, func(context.Context) error { return fmt.Errorf("error2") }}, + wantResponseCode: http.StatusFailedDependency, + wantJSONResponse: RunHandlerAllResponse{ + Results: map[string]RunHandlerResponse{ + "successAndFailureProbes-0": { + ProbeInfo: ProbeInfo{ + Name: "successAndFailureProbes-0", + Interval: probeInterval, + Status: ProbeStatusSucceeded, + RecentResults: []bool{true, true}, + }, + PreviousSuccessRatio: 1, + }, + "successAndFailureProbes-1": { + ProbeInfo: ProbeInfo{ + Name: "successAndFailureProbes-1", + Interval: probeInterval, + Status: ProbeStatusFailed, + Error: "error2", + RecentResults: []bool{false, false}, + }, + }, + }, + }, + wantPlaintextResponse: "Probe successAndFailureProbes-0: succeeded\n\tLast run: 0s\n\tPrevious success rate: 100.0%\n\tPrevious median latency: 0s\nProbe successAndFailureProbes-1: failed\n\tLast run: 0s\n\tPrevious success rate: 0.0%\n\tPrevious median latency: 0s\n\n\tLast error: error2\n\n", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + p := newForTest(clk.Now, clk.NewTicker).WithOnce(true) + for i, pfunc := range tc.probeFunc { + probe := p.Run(fmt.Sprintf("%s-%d", tc.name, i), probeInterval, nil, FuncProbe(pfunc)) + defer probe.Close() + <-probe.stopped // wait for the first run. + } + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.Handle("/prober/runall/", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunAllHandler), tsweb.HandlerOptions{})) + + req, err := http.NewRequest("GET", server.URL+"/prober/runall", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to make request: %v", err) + } + + if resp.StatusCode != tc.wantResponseCode { + t.Errorf("unexpected response code: got %d, want %d", resp.StatusCode, tc.wantResponseCode) + } + + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: got %q, want application/json", resp.Header.Get("Content-Type")) + } + var gotJSON RunHandlerAllResponse + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if err := json.Unmarshal(body, &gotJSON); err != nil { + t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, body) + } + if diff := cmp.Diff(tc.wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" { + t.Errorf("unexpected JSON response (-want +got):\n%s", diff) + } + + }) + } + +} + +func TestExcludeInRunAll(t *testing.T) { + clk := newFakeTime() + p := newForTest(clk.Now, clk.NewTicker).WithOnce(true) + + wantJSONResponse := RunHandlerAllResponse{ + Results: map[string]RunHandlerResponse{ + "includedProbe": { + ProbeInfo: ProbeInfo{ + Name: "includedProbe", + Interval: probeInterval, + Status: ProbeStatusSucceeded, + RecentResults: []bool{true, true}, + }, + PreviousSuccessRatio: 1, + }, + }, + } + + p.Run("includedProbe", probeInterval, nil, FuncProbe(func(context.Context) error { return nil })) + p.Run("excludedProbe", probeInterval, nil, FuncProbe(func(context.Context) error { return nil })) + p.Run("excludedOtherProbe", probeInterval, nil, FuncProbe(func(context.Context) error { return nil })) + + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.Handle("/prober/runall", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunAllHandler), tsweb.HandlerOptions{})) + + req, err := http.NewRequest("GET", server.URL+"/prober/runall", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + // Exclude probes with "excluded" in their name + req.URL.RawQuery = url.Values{ + "exclude": []string{"excludedProbe", "excludedOtherProbe"}, + }.Encode() + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to make request: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Errorf("unexpected response code: got %d, want %d", resp.StatusCode, http.StatusOK) + } + + var gotJSON RunHandlerAllResponse + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if err := json.Unmarshal(body, &gotJSON); err != nil { + t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, body) + } + + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: got %q, want application/json", resp.Header.Get("Content-Type")) + } + + if diff := cmp.Diff(wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" { + t.Errorf("unexpected JSON response (-want +got):\n%s", diff) + } +} + type fakeTicker struct { ch chan time.Time interval time.Duration diff --git a/prober/status.go b/prober/status.go index aa9ef99d0..20fbeec58 100644 --- a/prober/status.go +++ b/prober/status.go @@ -62,8 +62,9 @@ func (p *Prober) StatusHandler(opts ...statusHandlerOpt) tsweb.ReturnHandlerFunc return func(w http.ResponseWriter, r *http.Request) error { type probeStatus struct { ProbeInfo - TimeSinceLast time.Duration - Links map[string]template.URL + TimeSinceLastStart time.Duration + TimeSinceLastEnd time.Duration + Links map[string]template.URL } vars := struct { Title string @@ -81,12 +82,15 @@ func (p *Prober) StatusHandler(opts ...statusHandlerOpt) tsweb.ReturnHandlerFunc for name, info := range p.ProbeInfo() { vars.TotalProbes++ - if !info.Result { + if info.Error != "" { vars.UnhealthyProbes++ } s := probeStatus{ProbeInfo: info} + if !info.Start.IsZero() { + s.TimeSinceLastStart = time.Since(info.Start).Truncate(time.Second) + } if !info.End.IsZero() { - s.TimeSinceLast = time.Since(info.End).Truncate(time.Second) + s.TimeSinceLastEnd = time.Since(info.End).Truncate(time.Second) } for textTpl, urlTpl := range params.probeLinks { text, err := renderTemplate(textTpl, info) diff --git a/prober/status.html b/prober/status.html index ff0f06c13..d26588da1 100644 --- a/prober/status.html +++ b/prober/status.html @@ -73,8 +73,9 @@ Name Probe Class & Labels Interval - Last Attempt - Success + Last Finished + Last Started + Status Latency Last Error @@ -85,9 +86,11 @@ {{$name}} {{range $text, $url := $probeInfo.Links}}
- + {{if not $probeInfo.Continuous}} + + {{end}} {{end}} {{$probeInfo.Class}}
@@ -97,28 +100,48 @@ {{end}} - {{$probeInfo.Interval}} - - {{if $probeInfo.TimeSinceLast}} - {{$probeInfo.TimeSinceLast.String}} ago
+ + {{if $probeInfo.Continuous}} + Continuous + {{else}} + {{$probeInfo.Interval}} + {{end}} + + + {{if $probeInfo.TimeSinceLastEnd}} + {{$probeInfo.TimeSinceLastEnd.String}} ago
{{$probeInfo.End.Format "2006-01-02T15:04:05Z07:00"}} {{else}} Never {{end}} + + {{if $probeInfo.TimeSinceLastStart}} + {{$probeInfo.TimeSinceLastStart.String}} ago
+ {{$probeInfo.Start.Format "2006-01-02T15:04:05Z07:00"}} + {{else}} + Never + {{end}} + - {{if $probeInfo.Result}} - {{$probeInfo.Result}} + {{if $probeInfo.Error}} + {{$probeInfo.Status}} {{else}} - {{$probeInfo.Result}} + {{$probeInfo.Status}} {{end}}
-
Recent: {{$probeInfo.RecentResults}}
-
Mean: {{$probeInfo.RecentSuccessRatio}}
+ {{if not $probeInfo.Continuous}} +
Recent: {{$probeInfo.RecentResults}}
+
Mean: {{$probeInfo.RecentSuccessRatio}}
+ {{end}} - {{$probeInfo.Latency.String}} -
Recent: {{$probeInfo.RecentLatencies}}
-
Median: {{$probeInfo.RecentMedianLatency}}
+ {{if $probeInfo.Continuous}} + n/a + {{else}} + {{$probeInfo.Latency.String}} +
Recent: {{$probeInfo.RecentLatencies}}
+
Median: {{$probeInfo.RecentMedianLatency}}
+ {{end}} {{$probeInfo.Error}} diff --git a/prober/tls.go b/prober/tls.go index 787df05c2..3ce535435 100644 --- a/prober/tls.go +++ b/prober/tls.go @@ -4,56 +4,54 @@ package prober import ( - "bytes" "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io" - "net" "net/http" "net/netip" + "slices" "time" - - "github.com/pkg/errors" - "golang.org/x/crypto/ocsp" - "tailscale.com/util/multierr" ) const expiresSoon = 7 * 24 * time.Hour // 7 days from now +// Let’s Encrypt promises to issue certificates with CRL servers after 2025-05-07: +// https://letsencrypt.org/2024/12/05/ending-ocsp/ +// https://github.com/tailscale/tailscale/issues/15912 +const letsEncryptStartedStaplingCRL int64 = 1746576000 // 2025-05-07 00:00:00 UTC // TLS returns a Probe that healthchecks a TLS endpoint. // // The ProbeFunc connects to a hostPort (host:port string), does a TLS // handshake, verifies that the hostname matches the presented certificate, // checks certificate validity time and OCSP revocation status. -func TLS(hostPort string) ProbeClass { +// +// The TLS config is optional and may be nil. +func TLS(hostPort string, config *tls.Config) ProbeClass { return ProbeClass{ Probe: func(ctx context.Context) error { - certDomain, _, err := net.SplitHostPort(hostPort) - if err != nil { - return err - } - return probeTLS(ctx, certDomain, hostPort) + return probeTLS(ctx, config, hostPort) }, Class: "tls", } } -// TLSWithIP is like TLS, but dials the provided dialAddr instead -// of using DNS resolution. The certDomain is the expected name in -// the cert (and the SNI name to send). -func TLSWithIP(certDomain string, dialAddr netip.AddrPort) ProbeClass { +// TLSWithIP is like TLS, but dials the provided dialAddr instead of using DNS +// resolution. Use config.ServerName to send SNI and validate the name in the +// cert. +func TLSWithIP(dialAddr netip.AddrPort, config *tls.Config) ProbeClass { return ProbeClass{ Probe: func(ctx context.Context) error { - return probeTLS(ctx, certDomain, dialAddr.String()) + return probeTLS(ctx, config, dialAddr.String()) }, Class: "tls", } } -func probeTLS(ctx context.Context, certDomain string, dialHostPort string) error { - dialer := &tls.Dialer{Config: &tls.Config{ServerName: certDomain}} +func probeTLS(ctx context.Context, config *tls.Config, dialHostPort string) error { + dialer := &tls.Dialer{Config: config} conn, err := dialer.DialContext(ctx, "tcp", dialHostPort) if err != nil { return fmt.Errorf("connecting to %q: %w", dialHostPort, err) @@ -70,7 +68,7 @@ func probeTLS(ctx context.Context, certDomain string, dialHostPort string) error func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr error) { var errs []error defer func() { - returnerr = multierr.New(errs...) + returnerr = errors.Join(errs...) }() latestAllowedExpiration := time.Now().Add(expiresSoon) @@ -106,50 +104,59 @@ func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr } } - if len(leafCert.OCSPServer) == 0 { - errs = append(errs, fmt.Errorf("no OCSP server presented in leaf cert for %v", leafCert.Subject)) + if len(leafCert.CRLDistributionPoints) == 0 { + if !slices.Contains(leafCert.Issuer.Organization, "Let's Encrypt") { + // LE certs contain a CRL, but certs from other CAs might not. + return + } + if leafCert.NotBefore.Before(time.Unix(letsEncryptStartedStaplingCRL, 0)) { + // Certificate might not have a CRL. + return + } + errs = append(errs, fmt.Errorf("no CRL server presented in leaf cert for %v", leafCert.Subject)) return } - ocspResp, err := getOCSPResponse(ctx, leafCert.OCSPServer[0], leafCert, issuerCert) + err := checkCertCRL(ctx, leafCert.CRLDistributionPoints[0], leafCert, issuerCert) if err != nil { - errs = append(errs, errors.Wrapf(err, "OCSP verification failed for %v", leafCert.Subject)) - return - } - - if ocspResp.Status == ocsp.Unknown { - errs = append(errs, fmt.Errorf("unknown OCSP verification status for %v", leafCert.Subject)) - } - - if ocspResp.Status == ocsp.Revoked { - errs = append(errs, fmt.Errorf("cert for %v has been revoked on %v, reason: %v", leafCert.Subject, ocspResp.RevokedAt, ocspResp.RevocationReason)) + errs = append(errs, fmt.Errorf("CRL verification failed for %v: %w", leafCert.Subject, err)) } return } -func getOCSPResponse(ctx context.Context, ocspServer string, leafCert, issuerCert *x509.Certificate) (*ocsp.Response, error) { - reqb, err := ocsp.CreateRequest(leafCert, issuerCert, nil) +func checkCertCRL(ctx context.Context, crlURL string, leafCert, issuerCert *x509.Certificate) error { + hreq, err := http.NewRequestWithContext(ctx, "GET", crlURL, nil) if err != nil { - return nil, errors.Wrap(err, "could not create OCSP request") + return fmt.Errorf("could not create CRL GET request: %w", err) } - hreq, err := http.NewRequestWithContext(ctx, "POST", ocspServer, bytes.NewReader(reqb)) - if err != nil { - return nil, errors.Wrap(err, "could not create OCSP POST request") - } - hreq.Header.Add("Content-Type", "application/ocsp-request") - hreq.Header.Add("Accept", "application/ocsp-response") hresp, err := http.DefaultClient.Do(hreq) if err != nil { - return nil, errors.Wrap(err, "OCSP request failed") + return fmt.Errorf("CRL request failed: %w", err) } defer hresp.Body.Close() if hresp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("ocsp: non-200 status code from OCSP server: %s", hresp.Status) + return fmt.Errorf("crl: non-200 status code from CRL server: %s", hresp.Status) } lr := io.LimitReader(hresp.Body, 10<<20) // 10MB - ocspB, err := io.ReadAll(lr) + crlB, err := io.ReadAll(lr) if err != nil { - return nil, err + return err + } + + crl, err := x509.ParseRevocationList(crlB) + if err != nil { + return fmt.Errorf("could not parse CRL: %w", err) + } + + if err := crl.CheckSignatureFrom(issuerCert); err != nil { + return fmt.Errorf("could not verify CRL signature: %w", err) } - return ocsp.ParseResponse(ocspB, issuerCert) + + for _, revoked := range crl.RevokedCertificateEntries { + if revoked.SerialNumber.Cmp(leafCert.SerialNumber) == 0 { + return fmt.Errorf("cert for %v has been revoked on %v, reason: %v", leafCert.Subject, revoked.RevocationTime, revoked.ReasonCode) + } + } + + return nil } diff --git a/prober/tls_test.go b/prober/tls_test.go index 5bfb739db..86fba91b9 100644 --- a/prober/tls_test.go +++ b/prober/tls_test.go @@ -6,7 +6,7 @@ package prober import ( "bytes" "context" - "crypto" + "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -20,8 +20,6 @@ import ( "strings" "testing" "time" - - "golang.org/x/crypto/ocsp" ) var leafCert = x509.Certificate{ @@ -85,7 +83,7 @@ func TestTLSConnection(t *testing.T) { srv.StartTLS() defer srv.Close() - err = probeTLS(context.Background(), "fail.example.com", srv.Listener.Addr().String()) + err = probeTLS(context.Background(), &tls.Config{ServerName: "fail.example.com"}, srv.Listener.Addr().String()) // The specific error message here is platform-specific ("certificate is not trusted" // on macOS and "certificate signed by unknown authority" on Linux), so only check // that it contains the word 'certificate'. @@ -118,11 +116,6 @@ func TestCertExpiration(t *testing.T) { }, "one of the certs expires in", }, - { - "valid duration but no OCSP", - func() *x509.Certificate { return &leafCert }, - "no OCSP server presented in leaf cert for CN=tlsprobe.test", - }, } { t.Run(tt.name, func(t *testing.T) { cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert()}} @@ -134,100 +127,211 @@ func TestCertExpiration(t *testing.T) { } } -type ocspServer struct { - issuer *x509.Certificate - responderCert *x509.Certificate - template *ocsp.Response - priv crypto.Signer +type CRLServer struct { + crlBytes []byte } -func (s *ocspServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if s.template == nil { +func (s *CRLServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.crlBytes == nil { w.WriteHeader(http.StatusInternalServerError) return } - resp, err := ocsp.CreateResponse(s.issuer, s.responderCert, *s.template, s.priv) + w.Header().Set("Content-Type", "application/pkix-crl") + w.WriteHeader(http.StatusOK) + w.Write(s.crlBytes) +} + +// someECDSAKey{1,2,3} are different EC private keys in PEM format +// as generated by: +// +// openssl ecparam -name prime256v1 -genkey -noout -out - +// +// They're used in tests to avoid burning CPU at test time to just +// to make some arbitrary test keys. +const ( + someECDSAKey1 = ` +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIDKggO47Si0/JgqF0q9m0HfQ92lbERWsBaKS5YihtuheoAoGCCqGSM49 +AwEHoUQDQgAE/JtNZkfFmAGQJHW5Xgz0Eoyi9MKVxl77sXjIFDMX233QDIWPEM/B +vmNMvdFkuYBjwbq6H+SNf1NXRNladEGU/Q== +-----END EC PRIVATE KEY----- +` + someECDSAKey2 = ` +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIPIJhRf4MpzLil1ZKcRqMx+jPeJXw96KtYYzV2AcgBzgoAoGCCqGSM49 +AwEHoUQDQgAEhA9CSWFmUvdvXMzyt+as+6f+0luydHU1x/gEksVByYIgYxahaGts +xbSKj6F2WgAN/ok1gFLqhH3UWMNVthM1wA== +-----END EC PRIVATE KEY----- +` + someECDSAKey3 = ` +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIKgZ1OJjK2St9O0i52N1K+IgSiu2/NSMk9Yt2+kDMHd7oAoGCCqGSM49 +AwEHoUQDQgAExFp80etkjy/AEUtSgJjXRA39jTU7eiEmCGRREewFQhwcEscBEfrg +6NN31r9YlEs+hZ8gXE1L3Deu6jn5jW3pig== +-----END EC PRIVATE KEY----- +` +) + +// parseECKey parses an EC private key from a PEM-encoded string. +func parseECKey(t *testing.T, pemPriv string) *ecdsa.PrivateKey { + t.Helper() + block, _ := pem.Decode([]byte(pemPriv)) + if block == nil { + t.Fatal("failed to decode PEM") + } + key, err := x509.ParseECPrivateKey(block.Bytes) if err != nil { - panic(err) + t.Fatalf("failed to parse EC key: %v", err) } - w.Write(resp) + return key } -func TestOCSP(t *testing.T) { - issuerKey, err := rsa.GenerateKey(rand.Reader, 4096) +func TestCRL(t *testing.T) { + // Generate CA key and self-signed CA cert + caKey := parseECKey(t, someECDSAKey1) + + caTpl := issuerCertTpl + caTpl.BasicConstraintsValid = true + caTpl.IsCA = true + caTpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature + caTpl.SignatureAlgorithm = x509.ECDSAWithSHA256 + caBytes, err := x509.CreateCertificate(rand.Reader, &caTpl, &caTpl, &caKey.PublicKey, caKey) if err != nil { t.Fatal(err) } - issuerBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &issuerKey.PublicKey, issuerKey) + caCert, err := x509.ParseCertificate(caBytes) if err != nil { t.Fatal(err) } - issuerCert, err := x509.ParseCertificate(issuerBytes) + + // Issue a leaf cert signed by the CA + leaf := leafCert + leaf.SerialNumber = big.NewInt(20001) + leaf.SignatureAlgorithm = x509.ECDSAWithSHA256 + leaf.Issuer = caCert.Subject + leafKey := parseECKey(t, someECDSAKey2) + leafBytes, err := x509.CreateCertificate(rand.Reader, &leaf, caCert, &leafKey.PublicKey, caKey) if err != nil { t.Fatal(err) } - - responderKey, err := rsa.GenerateKey(rand.Reader, 4096) + leafCertParsed, err := x509.ParseCertificate(leafBytes) if err != nil { t.Fatal(err) } - // issuer cert template re-used here, but with a different key - responderBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &responderKey.PublicKey, responderKey) + + // Catch no CRL set by Let's Encrypt date. + noCRLCert := leafCert + noCRLCert.SerialNumber = big.NewInt(20002) + noCRLCert.CRLDistributionPoints = []string{} + noCRLCert.NotBefore = time.Unix(letsEncryptStartedStaplingCRL, 0).Add(-48 * time.Hour) + noCRLCert.Issuer = caCert.Subject + noCRLCert.SignatureAlgorithm = x509.ECDSAWithSHA256 + noCRLCertKey := parseECKey(t, someECDSAKey3) + noCRLStapledBytes, err := x509.CreateCertificate(rand.Reader, &noCRLCert, caCert, &noCRLCertKey.PublicKey, caKey) if err != nil { t.Fatal(err) } - responderCert, err := x509.ParseCertificate(responderBytes) + noCRLStapledParsed, err := x509.ParseCertificate(noCRLStapledBytes) if err != nil { t.Fatal(err) } - handler := &ocspServer{ - issuer: issuerCert, - responderCert: responderCert, - priv: issuerKey, - } - srv := httptest.NewUnstartedServer(handler) - srv.Start() + crlServer := &CRLServer{crlBytes: nil} + srv := httptest.NewServer(crlServer) defer srv.Close() - cert := leafCert - cert.OCSPServer = append(cert.OCSPServer, srv.URL) - key, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - t.Fatal(err) + // Create a CRL that revokes the leaf cert using x509.CreateRevocationList + now := time.Now() + revoked := []x509.RevocationListEntry{{ + SerialNumber: leaf.SerialNumber, + RevocationTime: now, + ReasonCode: 1, // Key compromise + }} + rl := x509.RevocationList{ + SignatureAlgorithm: caCert.SignatureAlgorithm, + Issuer: caCert.Subject, + ThisUpdate: now, + NextUpdate: now.Add(24 * time.Hour), + RevokedCertificateEntries: revoked, + Number: big.NewInt(1), } - certBytes, err := x509.CreateCertificate(rand.Reader, &cert, issuerCert, &key.PublicKey, issuerKey) + rlBytes, err := x509.CreateRevocationList(rand.Reader, &rl, caCert, caKey) if err != nil { t.Fatal(err) } - parsed, err := x509.ParseCertificate(certBytes) + + emptyRlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{Number: big.NewInt(2)}, caCert, caKey) if err != nil { t.Fatal(err) } for _, tt := range []struct { - name string - resp *ocsp.Response - wantErr string + name string + cert *x509.Certificate + crlBytes []byte + issuer pkix.Name + wantErr string }{ - {"good response", &ocsp.Response{Status: ocsp.Good}, ""}, - {"unknown response", &ocsp.Response{Status: ocsp.Unknown}, "unknown OCSP verification status for CN=tlsprobe.test"}, - {"revoked response", &ocsp.Response{Status: ocsp.Revoked}, "cert for CN=tlsprobe.test has been revoked"}, - {"error 500 from ocsp", nil, "non-200 status code from OCSP"}, + { + "ValidCert", + leafCertParsed, + emptyRlBytes, + caCert.Issuer, + "", + }, + { + "RevokedCert", + leafCertParsed, + rlBytes, + caCert.Issuer, + "has been revoked on", + }, + { + "EmptyCRL", + leafCertParsed, + emptyRlBytes, + caCert.Issuer, + "", + }, + { + "NoCRLLetsEncrypt", + leafCertParsed, + nil, + pkix.Name{CommonName: "tlsprobe.test", Organization: []string{"Let's Encrypt"}}, + "no CRL server presented in leaf cert for", + }, + { + "NoCRLOtherCA", + leafCertParsed, + nil, + caCert.Issuer, + "", + }, + { + "NotBeforeCRLStaplingDate", + noCRLStapledParsed, + nil, + caCert.Issuer, + "", + }, } { t.Run(tt.name, func(t *testing.T) { - handler.template = tt.resp - if handler.template != nil { - handler.template.SerialNumber = big.NewInt(1337) + tt.cert.Issuer = tt.issuer + cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert, caCert}} + if tt.crlBytes != nil { + crlServer.crlBytes = tt.crlBytes + tt.cert.CRLDistributionPoints = []string{srv.URL} + } else { + crlServer.crlBytes = nil + tt.cert.CRLDistributionPoints = []string{} } - cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{parsed, issuerCert}} err := validateConnState(context.Background(), cs) if err == nil && tt.wantErr == "" { return } - if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + if err == nil || tt.wantErr == "" || !strings.Contains(err.Error(), tt.wantErr) { t.Errorf("unexpected error %q; want %q", err, tt.wantErr) } }) diff --git a/prober/tun_darwin.go b/prober/tun_darwin.go new file mode 100644 index 000000000..0ef22e41e --- /dev/null +++ b/prober/tun_darwin.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package prober + +import ( + "fmt" + "net/netip" + "os/exec" + + "go4.org/netipx" +) + +const tunName = "utun" + +func configureTUN(addr netip.Prefix, tunname string) error { + cmd := exec.Command("ifconfig", tunname, "inet", addr.String(), addr.Addr().String()) + res, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to add address: %w (%s)", err, string(res)) + } + + net := netipx.PrefixIPNet(addr) + nip := net.IP.Mask(net.Mask) + nstr := fmt.Sprintf("%v/%d", nip, addr.Bits()) + cmd = exec.Command("route", "-q", "-n", "add", "-inet", nstr, "-iface", addr.Addr().String()) + res, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to add route: %w (%s)", err, string(res)) + } + + return nil +} diff --git a/prober/tun_default.go b/prober/tun_default.go new file mode 100644 index 000000000..93a5b07fd --- /dev/null +++ b/prober/tun_default.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package prober + +import ( + "fmt" + "net/netip" + "runtime" +) + +const tunName = "unused" + +func configureTUN(addr netip.Prefix, tunname string) error { + return fmt.Errorf("not implemented on " + runtime.GOOS) +} diff --git a/prober/tun_linux.go b/prober/tun_linux.go new file mode 100644 index 000000000..52a31efbb --- /dev/null +++ b/prober/tun_linux.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package prober + +import ( + "fmt" + "net/netip" + + "github.com/tailscale/netlink" + "go4.org/netipx" +) + +const tunName = "derpprobe" + +func configureTUN(addr netip.Prefix, tunname string) error { + link, err := netlink.LinkByName(tunname) + if err != nil { + return fmt.Errorf("failed to look up link %q: %w", tunname, err) + } + + // We need to bring the TUN device up before assigning an address. This + // allows the OS to automatically create a route for it. Otherwise, we'd + // have to manually create the route. + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring tun %q up: %w", tunname, err) + } + + if err := netlink.AddrReplace(link, &netlink.Addr{IPNet: netipx.PrefixIPNet(addr)}); err != nil { + return fmt.Errorf("failed to add address: %w", err) + } + + return nil +} diff --git a/proxymap/proxymap.go b/proxymap/proxymap.go index dfe6f2d58..20dc96c84 100644 --- a/proxymap/proxymap.go +++ b/proxymap/proxymap.go @@ -9,9 +9,9 @@ import ( "fmt" "net/netip" "strings" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/util/mak" ) @@ -22,7 +22,7 @@ import ( // ask tailscaled (via the LocalAPI WhoIs method) the Tailscale identity that a // given localhost:port corresponds to. type Mapper struct { - mu sync.Mutex + mu syncs.Mutex // m holds the mapping from localhost IP:ports to Tailscale IPs. It is // keyed first by the protocol ("tcp" or "udp"), then by the IP:port. diff --git a/pull-toolchain.sh b/pull-toolchain.sh index f5a19e7d7..eb8febf6b 100755 --- a/pull-toolchain.sh +++ b/pull-toolchain.sh @@ -11,6 +11,10 @@ if [ "$upstream" != "$current" ]; then echo "$upstream" >go.toolchain.rev fi -if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev)" ]; then +./tool/go version 2>/dev/null | awk '{print $3}' | sed 's/^go//' > go.toolchain.version + +./update-flake.sh + +if [ -n "$(git diff-index --name-only HEAD -- go.toolchain.rev go.toolchain.rev.sri go.toolchain.version)" ]; then echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 fi diff --git a/release/dist/cli/cli.go b/release/dist/cli/cli.go index 9b861ddd7..f4480cbdb 100644 --- a/release/dist/cli/cli.go +++ b/release/dist/cli/cli.go @@ -65,6 +65,7 @@ func CLI(getTargets func() ([]dist.Target, error)) *ffcli.Command { fs.StringVar(&buildArgs.manifest, "manifest", "", "manifest file to write") fs.BoolVar(&buildArgs.verbose, "verbose", false, "verbose logging") fs.StringVar(&buildArgs.webClientRoot, "web-client-root", "", "path to root of web client source to build") + fs.StringVar(&buildArgs.outPath, "out", "", "path to write output artifacts (defaults to '$PWD/dist' if not set)") return fs })(), LongHelp: strings.TrimSpace(` @@ -156,6 +157,7 @@ var buildArgs struct { manifest string verbose bool webClientRoot string + outPath string } func runBuild(ctx context.Context, filters []string, targets []dist.Target) error { @@ -172,7 +174,11 @@ func runBuild(ctx context.Context, filters []string, targets []dist.Target) erro if err != nil { return fmt.Errorf("getting working directory: %w", err) } - b, err := dist.NewBuild(wd, filepath.Join(wd, "dist")) + outPath := filepath.Join(wd, "dist") + if buildArgs.outPath != "" { + outPath = buildArgs.outPath + } + b, err := dist.NewBuild(wd, outPath) if err != nil { return fmt.Errorf("creating build context: %w", err) } diff --git a/release/dist/dist.go b/release/dist/dist.go index 802d9041b..6fb010299 100644 --- a/release/dist/dist.go +++ b/release/dist/dist.go @@ -20,7 +20,6 @@ import ( "sync" "time" - "tailscale.com/util/multierr" "tailscale.com/version/mkversion" ) @@ -176,7 +175,7 @@ func (b *Build) Build(targets []Target) (files []string, err error) { } sort.Strings(files) - return files, multierr.New(errs...) + return files, errors.Join(errs...) } // Once runs fn if Once hasn't been called with name before. diff --git a/release/dist/qnap/files/scripts/Dockerfile.qpkg b/release/dist/qnap/files/scripts/Dockerfile.qpkg index 135d5d20f..8e99630d1 100644 --- a/release/dist/qnap/files/scripts/Dockerfile.qpkg +++ b/release/dist/qnap/files/scripts/Dockerfile.qpkg @@ -1,9 +1,21 @@ -FROM ubuntu:20.04 +FROM ubuntu:24.04 RUN apt-get update -y && \ apt-get install -y --no-install-recommends \ git-core \ - ca-certificates -RUN git clone https://github.com/qnap-dev/QDK.git + ca-certificates \ + apt-transport-https \ + gnupg \ + curl \ + patch + +# Install Google Cloud PKCS11 module +RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg +RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list +RUN apt-get update -y && apt-get install -y --no-install-recommends google-cloud-cli libengine-pkcs11-openssl +RUN curl -L https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/pkcs11-v1.7/libkmsp11-1.7-linux-amd64.tar.gz | tar xz + +# Install QNAP QDK (force a specific version to pick up updates) +RUN git clone https://github.com/tailscale/QDK.git && cd /QDK && git reset --hard 8478a990decf0b0bb259ae11c636e66bfeff2433 RUN cd /QDK && ./InstallToUbuntu.sh install -ENV PATH="/usr/share/QDK/bin:${PATH}" \ No newline at end of file +ENV PATH="/usr/share/QDK/bin:${PATH}" diff --git a/release/dist/qnap/files/scripts/sign-qpkg.sh b/release/dist/qnap/files/scripts/sign-qpkg.sh new file mode 100755 index 000000000..1dacb876f --- /dev/null +++ b/release/dist/qnap/files/scripts/sign-qpkg.sh @@ -0,0 +1,43 @@ +#! /usr/bin/env bash +set -xeu + +mkdir -p "$HOME/.config/gcloud" +echo "$GCLOUD_CREDENTIALS_BASE64" | base64 --decode > /root/.config/gcloud/application_default_credentials.json +gcloud config set project "$GCLOUD_PROJECT" + +echo "--- +tokens: + - key_ring: \"$GCLOUD_KEYRING\" +log_directory: "/tmp/kmsp11" +" > pkcs11-config.yaml +chmod 0600 pkcs11-config.yaml + +export KMS_PKCS11_CONFIG=`readlink -f pkcs11-config.yaml` +export PKCS11_MODULE_PATH=/libkmsp11-1.7-linux-amd64/libkmsp11.so + +# Verify signature of pkcs11 module +# See https://github.com/GoogleCloudPlatform/kms-integrations/blob/master/kmsp11/docs/user_guide.md#downloading-and-verifying-the-library +echo "-----BEGIN PUBLIC KEY----- +MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEtfLbXkHUVc9oUPTNyaEK3hIwmuGRoTtd +6zDhwqjJuYaMwNd1aaFQLMawTwZgR0Xn27ymVWtqJHBe0FU9BPIQ+SFmKw+9jSwu +/FuqbJnLmTnWMJ1jRCtyHNZawvv2wbiB +-----END PUBLIC KEY-----" > pkcs11-release-signing-key.pem +openssl dgst -sha384 -verify pkcs11-release-signing-key.pem -signature "$PKCS11_MODULE_PATH.sig" "$PKCS11_MODULE_PATH" + +echo "$QNAP_SIGNING_CERT_BASE64" | base64 --decode > signer.pem + +echo "$QNAP_SIGNING_CERT_INTERMEDIARIES_BASE64" | base64 --decode > certs.pem + +openssl cms \ + -sign \ + -binary \ + -nodetach \ + -engine pkcs11 \ + -keyform engine \ + -inkey "pkcs11:object=$QNAP_SIGNING_KEY_NAME" \ + -keyopt rsa_padding_mode:pss \ + -keyopt rsa_pss_saltlen:digest \ + -signer signer.pem \ + -certfile certs.pem \ + -in "$1" \ + -out - diff --git a/release/dist/qnap/pkgs.go b/release/dist/qnap/pkgs.go index 9df649ddb..5062011f0 100644 --- a/release/dist/qnap/pkgs.go +++ b/release/dist/qnap/pkgs.go @@ -27,8 +27,12 @@ type target struct { } type signer struct { - privateKeyPath string - certificatePath string + gcloudCredentialsBase64 string + gcloudProject string + gcloudKeyring string + keyName string + certificateBase64 string + certificateIntermediariesBase64 string } func (t *target) String() string { @@ -66,7 +70,8 @@ func (t *target) buildQPKG(b *dist.Build, qnapBuilds *qnapBuilds, inner *innerPk filename := fmt.Sprintf("Tailscale_%s-%s_%s.qpkg", b.Version.Short, qnapTag, t.arch) filePath := filepath.Join(b.Out, filename) - cmd := b.Command(b.Repo, "docker", "run", "--rm", + args := []string{"run", "--rm", + "--network=host", "-e", fmt.Sprintf("ARCH=%s", t.arch), "-e", fmt.Sprintf("TSTAG=%s", b.Version.Short), "-e", fmt.Sprintf("QNAPTAG=%s", qnapTag), @@ -76,10 +81,29 @@ func (t *target) buildQPKG(b *dist.Build, qnapBuilds *qnapBuilds, inner *innerPk "-v", fmt.Sprintf("%s:/Tailscale", filepath.Join(qnapBuilds.tmpDir, "files/Tailscale")), "-v", fmt.Sprintf("%s:/build-qpkg.sh", filepath.Join(qnapBuilds.tmpDir, "files/scripts/build-qpkg.sh")), "-v", fmt.Sprintf("%s:/out", b.Out), + } + + if t.signer != nil { + log.Println("Will sign with Google Cloud HSM") + args = append(args, + "-e", fmt.Sprintf("GCLOUD_CREDENTIALS_BASE64=%s", t.signer.gcloudCredentialsBase64), + "-e", fmt.Sprintf("GCLOUD_PROJECT=%s", t.signer.gcloudProject), + "-e", fmt.Sprintf("GCLOUD_KEYRING=%s", t.signer.gcloudKeyring), + "-e", fmt.Sprintf("QNAP_SIGNING_KEY_NAME=%s", t.signer.keyName), + "-e", fmt.Sprintf("QNAP_SIGNING_CERT_BASE64=%s", t.signer.certificateBase64), + "-e", fmt.Sprintf("QNAP_SIGNING_CERT_INTERMEDIARIES_BASE64=%s", t.signer.certificateIntermediariesBase64), + "-e", fmt.Sprintf("QNAP_SIGNING_SCRIPT=%s", "/sign-qpkg.sh"), + "-v", fmt.Sprintf("%s:/sign-qpkg.sh", filepath.Join(qnapBuilds.tmpDir, "files/scripts/sign-qpkg.sh")), + ) + } + + args = append(args, "build.tailscale.io/qdk:latest", "/build-qpkg.sh", ) + cmd := b.Command(b.Repo, "docker", args...) + // dist.Build runs target builds in parallel goroutines by default. // For QNAP, this is an issue because the underlaying qbuild builder will // create tmp directories in the shared docker image that end up conflicting @@ -176,32 +200,6 @@ func newQNAPBuilds(b *dist.Build, signer *signer) (*qnapBuilds, error) { return nil, err } - if signer != nil { - log.Print("Setting up qnap signing files") - - key, err := os.ReadFile(signer.privateKeyPath) - if err != nil { - return nil, err - } - cert, err := os.ReadFile(signer.certificatePath) - if err != nil { - return nil, err - } - - // QNAP's qbuild command expects key and cert files to be in the root - // of the project directory (in our case release/dist/qnap/Tailscale). - // So here, we copy the key and cert over to the project folder for the - // duration of qnap package building and then delete them on close. - - keyPath := filepath.Join(m.tmpDir, "files/Tailscale/private_key") - if err := os.WriteFile(keyPath, key, 0400); err != nil { - return nil, err - } - certPath := filepath.Join(m.tmpDir, "files/Tailscale/certificate") - if err := os.WriteFile(certPath, cert, 0400); err != nil { - return nil, err - } - } return m, nil } diff --git a/release/dist/qnap/targets.go b/release/dist/qnap/targets.go index a069dd623..0a0213954 100644 --- a/release/dist/qnap/targets.go +++ b/release/dist/qnap/targets.go @@ -3,16 +3,32 @@ package qnap -import "tailscale.com/release/dist" +import ( + "slices" + + "tailscale.com/release/dist" +) // Targets defines the dist.Targets for QNAP devices. // -// If privateKeyPath and certificatePath are both provided non-empty, -// these targets will be signed for QNAP app store release with built. -func Targets(privateKeyPath, certificatePath string) []dist.Target { +// If all parameters are provided non-empty, then the build will be signed using +// a Google Cloud hosted key. +// +// gcloudCredentialsBase64 is the JSON credential for connecting to Google Cloud, base64 encoded. +// gcloudKeyring is the full path to the Google Cloud keyring containing the signing key. +// keyName is the name of the key. +// certificateBase64 is the PEM certificate to use in the signature, base64 encoded. +func Targets(gcloudCredentialsBase64, gcloudProject, gcloudKeyring, keyName, certificateBase64, certificateIntermediariesBase64 string) []dist.Target { var signerInfo *signer - if privateKeyPath != "" && certificatePath != "" { - signerInfo = &signer{privateKeyPath, certificatePath} + if !slices.Contains([]string{gcloudCredentialsBase64, gcloudProject, gcloudKeyring, keyName, certificateBase64, certificateIntermediariesBase64}, "") { + signerInfo = &signer{ + gcloudCredentialsBase64: gcloudCredentialsBase64, + gcloudProject: gcloudProject, + gcloudKeyring: gcloudKeyring, + keyName: keyName, + certificateBase64: certificateBase64, + certificateIntermediariesBase64: certificateIntermediariesBase64, + } } return []dist.Target{ &target{ diff --git a/release/dist/synology/pkgs.go b/release/dist/synology/pkgs.go index 7802470e1..ab89dbee3 100644 --- a/release/dist/synology/pkgs.go +++ b/release/dist/synology/pkgs.go @@ -155,8 +155,22 @@ func (t *target) mkInfo(b *dist.Build, uncompressedSz int64) []byte { f("os_min_ver", "6.0.1-7445") f("os_max_ver", "7.0-40000") case 7: - f("os_min_ver", "7.0-40000") - f("os_max_ver", "") + if t.packageCenter { + switch t.dsmMinorVersion { + case 0: + f("os_min_ver", "7.0-40000") + f("os_max_ver", "7.2-60000") + case 2: + f("os_min_ver", "7.2-60000") + default: + panic(fmt.Sprintf("unsupported DSM major.minor version %s", t.dsmVersionString())) + } + } else { + // We do not clamp the os_max_ver currently for non-package center builds as + // the binaries for 7.0 and 7.2 are identical. + f("os_min_ver", "7.0-40000") + f("os_max_ver", "") + } default: panic(fmt.Sprintf("unsupported DSM major version %d", t.dsmMajorVersion)) } diff --git a/safesocket/safesocket.go b/safesocket/safesocket.go index 991fddf5f..287cdca59 100644 --- a/safesocket/safesocket.go +++ b/safesocket/safesocket.go @@ -11,6 +11,9 @@ import ( "net" "runtime" "time" + + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" ) type closeable interface { @@ -31,7 +34,8 @@ func ConnCloseWrite(c net.Conn) error { } var processStartTime = time.Now() -var tailscaledProcExists = func() bool { return false } // set by safesocket_ps.go + +var tailscaledProcExists feature.Hook[func() bool] // tailscaledStillStarting reports whether tailscaled is probably // still starting up. That is, it reports whether the caller should @@ -50,7 +54,8 @@ func tailscaledStillStarting() bool { if d > 5*time.Second { return false } - return tailscaledProcExists() + f, ok := tailscaledProcExists.GetOk() + return ok && f() } // ConnectContext connects to tailscaled using a unix socket or named pipe. @@ -61,7 +66,11 @@ func ConnectContext(ctx context.Context, path string) (net.Conn, error) { if ctx.Err() != nil { return nil, ctx.Err() } - time.Sleep(250 * time.Millisecond) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(250 * time.Millisecond): + } continue } return c, err @@ -100,7 +109,12 @@ func LocalTCPPortAndToken() (port int, token string, err error) { // PlatformUsesPeerCreds reports whether the current platform uses peer credentials // to authenticate connections. -func PlatformUsesPeerCreds() bool { return GOOSUsesPeerCreds(runtime.GOOS) } +func PlatformUsesPeerCreds() bool { + if !buildfeatures.HasUnixSocketIdentity { + return false + } + return GOOSUsesPeerCreds(runtime.GOOS) +} // GOOSUsesPeerCreds is like PlatformUsesPeerCreds but takes a // runtime.GOOS value instead of using the current one. diff --git a/safesocket/safesocket_darwin.go b/safesocket/safesocket_darwin.go index 62e6f7e6d..e2b3ea458 100644 --- a/safesocket/safesocket_darwin.go +++ b/safesocket/safesocket_darwin.go @@ -6,8 +6,11 @@ package safesocket import ( "bufio" "bytes" + crand "crypto/rand" "errors" "fmt" + "io/fs" + "log" "net" "os" "os/exec" @@ -17,6 +20,7 @@ import ( "sync" "time" + "golang.org/x/sys/unix" "tailscale.com/version" ) @@ -24,96 +28,287 @@ func init() { localTCPPortAndToken = localTCPPortAndTokenDarwin } -// localTCPPortAndTokenMacsys returns the localhost TCP port number and auth token -// from /Library/Tailscale. -// -// In that case the files are: +const sameUserProofTokenLength = 10 + +type safesocketDarwin struct { + mu sync.Mutex + token string // safesocket auth token + port int // safesocket port + sameuserproofFD *os.File // File descriptor for macos app store sameuserproof file + sharedDir string // Shared directory for location of sameuserproof file + + checkConn bool // If true, check macsys safesocket port before returning it + isMacSysExt func() bool // Reports true if this binary is the macOS System Extension + isMacGUIApp func() bool // Reports true if running as a macOS GUI app (Tailscale.app) +} + +var ssd = safesocketDarwin{ + isMacSysExt: version.IsMacSysExt, + isMacGUIApp: func() bool { return version.IsMacAppStoreGUI() || version.IsMacSysGUI() }, + checkConn: true, + sharedDir: "/Library/Tailscale", +} + +// There are three ways a Darwin binary can be run: as the Mac App Store (macOS) +// standalone notarized (macsys), or a separate CLI (tailscale) that was +// built or downloaded. // -// /Library/Tailscale/ipnport => $port (symlink with localhost port number target) -// /Library/Tailscale/sameuserproof-$port is a file with auth -func localTCPPortAndTokenMacsys() (port int, token string, err error) { +// The macOS and macsys binaries can communicate directly via XPC with +// the NEPacketTunnelProvider managed tailscaled process and are responsible for +// calling SetCredentials when they need to operate as a CLI. + +// A built/downloaded CLI binary will not be managing the NEPacketTunnelProvider +// hosting tailscaled directly and must source the credentials from a 'sameuserproof' file. +// This file is written to sharedDir when tailscaled/NEPacketTunnelProvider +// calls InitListenerDarwin. + +// localTCPPortAndTokenDarwin returns the localhost TCP port number and auth token +// either from the sameuserproof mechanism, or source and set directly from the +// NEPacketTunnelProvider managed tailscaled process when the CLI is invoked +// from the Tailscale.app GUI. +func localTCPPortAndTokenDarwin() (port int, token string, err error) { + ssd.mu.Lock() + defer ssd.mu.Unlock() + + switch { + case ssd.port != 0 && ssd.token != "": + // If something has explicitly set our credentials (typically non-standalone macos binary), use them. + return ssd.port, ssd.token, nil + case !ssd.isMacGUIApp(): + // We're not a GUI app (probably cmd/tailscale), so try falling back to sameuserproof. + // If portAndTokenFromSameUserProof returns an error here, cmd/tailscale will + // attempt to use the default unix socket mechanism supported by tailscaled. + return portAndTokenFromSameUserProof() + default: + return 0, "", ErrTokenNotFound + } +} + +// SetCredentials sets an token and port used to authenticate safesocket generated +// by the NEPacketTunnelProvider tailscaled process. This is only used when running +// the CLI via Tailscale.app. +func SetCredentials(token string, port int) { + ssd.mu.Lock() + defer ssd.mu.Unlock() + + if ssd.token != "" || ssd.port != 0 { + // Not fatal, but likely programmer error. Credentials do not change. + log.Printf("warning: SetCredentials credentials already set") + } + + ssd.token = token + ssd.port = port +} + +// InitListenerDarwin initializes the listener for the CLI commands +// and localapi HTTP server and sets the port/token. This will override +// any credentials set explicitly via SetCredentials(). Calling this mulitple times +// has no effect. The listener and it's corresponding token/port is initialized only once. +func InitListenerDarwin(sharedDir string) (*net.Listener, error) { + ssd.mu.Lock() + defer ssd.mu.Unlock() - const dir = "/Library/Tailscale" - portStr, err := os.Readlink(filepath.Join(dir, "ipnport")) + ln := onceListener.ln + if ln != nil { + return ln, nil + } + + var err error + ln, err = localhostListener() if err != nil { - return 0, "", err + log.Printf("InitListenerDarwin: listener initialization failed") + return nil, err } - port, err = strconv.Atoi(portStr) + + port, err := localhostTCPPort() if err != nil { - return 0, "", err + log.Printf("localhostTCPPort: listener initialization failed") + return nil, err } - authb, err := os.ReadFile(filepath.Join(dir, "sameuserproof-"+portStr)) + + token, err := getToken() if err != nil { - return 0, "", err + log.Printf("localhostTCPPort: getToken failed") + return nil, err } - auth := strings.TrimSpace(string(authb)) - if auth == "" { - return 0, "", errors.New("empty auth token in sameuserproof file") + + if port == 0 || token == "" { + log.Printf("localhostTCPPort: Invalid token or port") + return nil, fmt.Errorf("invalid localhostTCPPort: returned 0") } - // The above files exist forever after the first run of - // /Applications/Tailscale.app, so check we can connect to avoid returning a - // port nothing is listening on. Connect to "127.0.0.1" rather than - // "localhost" due to #7851. - conn, err := net.DialTimeout("tcp", "127.0.0.1:"+portStr, time.Second) + ssd.sharedDir = sharedDir + ssd.token = token + ssd.port = port + + // Write the port and token to a sameuserproof file + err = initSameUserProofToken(sharedDir, port, token) if err != nil { - return 0, "", err + // Not fatal + log.Printf("initSameUserProofToken: failed: %v", err) } - conn.Close() - return port, auth, nil + return ln, nil } -var warnAboutRootOnce sync.Once +var onceListener struct { + once sync.Once + ln *net.Listener +} -func localTCPPortAndTokenDarwin() (port int, token string, err error) { - // There are two ways this binary can be run: as the Mac App Store sandboxed binary, - // or a normal binary that somebody built or download and are being run from outside - // the sandbox. Detect which way we're running and then figure out how to connect - // to the local daemon. - - if dir := os.Getenv("TS_MACOS_CLI_SHARED_DIR"); dir != "" { - // First see if we're running as the non-AppStore "macsys" variant. - if version.IsMacSys() { - if port, token, err := localTCPPortAndTokenMacsys(); err == nil { - return port, token, nil +func localhostTCPPort() (int, error) { + if onceListener.ln == nil { + return 0, fmt.Errorf("listener not initialized") + } + + ln, err := localhostListener() + if err != nil { + return 0, err + } + + return (*ln).Addr().(*net.TCPAddr).Port, nil +} + +func localhostListener() (*net.Listener, error) { + onceListener.once.Do(func() { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + return + } + onceListener.ln = &ln + }) + if onceListener.ln == nil { + return nil, fmt.Errorf("failed to get TCP listener") + } + return onceListener.ln, nil +} + +var onceToken struct { + once sync.Once + token string +} + +func getToken() (string, error) { + onceToken.once.Do(func() { + buf := make([]byte, sameUserProofTokenLength) + if _, err := crand.Read(buf); err != nil { + return + } + t := fmt.Sprintf("%x", buf) + onceToken.token = t + }) + if onceToken.token == "" { + return "", fmt.Errorf("failed to generate token") + } + + return onceToken.token, nil +} + +// initSameUserProofToken writes the port and token to a sameuserproof +// file owned by the current user. We leave the file open to allow us +// to discover it via lsof. +// +// "sameuserproof" is intended to convey that the user attempting to read +// the credentials from the file is the same user that wrote them. For +// standalone macsys where tailscaled is running as root, we set group +// permissions to allow users in the admin group to read the file. +func initSameUserProofToken(sharedDir string, port int, token string) error { + var err error + + // Guard against bad sharedDir + old, err := os.ReadDir(sharedDir) + if err == os.ErrNotExist { + log.Printf("failed to read shared dir %s: %v", sharedDir, err) + return err + } + + // Remove all old sameuserproof files + for _, fi := range old { + if name := fi.Name(); strings.HasPrefix(name, "sameuserproof-") { + err := os.Remove(filepath.Join(sharedDir, name)) + if err != nil { + log.Printf("failed to remove %s: %v", name, err) } } + } - // The current binary (this process) is sandboxed. The user is - // running the CLI via /Applications/Tailscale.app/Contents/MacOS/Tailscale - // which sets the TS_MACOS_CLI_SHARED_DIR environment variable. - fis, err := os.ReadDir(dir) + var baseFile string + var perm fs.FileMode + if ssd.isMacSysExt() { + perm = 0640 // allow wheel to read + baseFile = fmt.Sprintf("sameuserproof-%d", port) + portFile := filepath.Join(sharedDir, "ipnport") + err := os.Remove(portFile) if err != nil { - return 0, "", err + log.Printf("failed to remove portfile %s: %v", portFile, err) } - for _, fi := range fis { - name := filepath.Base(fi.Name()) - // Look for name like "sameuserproof-61577-2ae2ec9e0aa2005784f1" - // to extract out the port number and token. - if strings.HasPrefix(name, "sameuserproof-") { - f := strings.SplitN(name, "-", 3) - if len(f) == 3 { - if port, err := strconv.Atoi(f[1]); err == nil { - return port, f[2], nil - } - } - } + symlinkErr := os.Symlink(fmt.Sprint(port), portFile) + if symlinkErr != nil { + log.Printf("failed to symlink portfile: %v", symlinkErr) } - if os.Geteuid() == 0 { - // Log a warning as the clue to the user, in case the error - // message is swallowed. Only do this once since we may retry - // multiple times to connect, and don't want to spam. - warnAboutRootOnce.Do(func() { - fmt.Fprintf(os.Stderr, "Warning: The CLI is running as root from within a sandboxed binary. It cannot reach the local tailscaled, please try again as a regular user.\n") - }) + } else { + perm = 0666 + baseFile = fmt.Sprintf("sameuserproof-%d-%s", port, token) + } + + path := filepath.Join(sharedDir, baseFile) + ssd.sameuserproofFD, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, perm) + log.Printf("initSameUserProofToken : done=%v", err == nil) + + if ssd.isMacSysExt() && err == nil { + fmt.Fprintf(ssd.sameuserproofFD, "%s\n", token) + + // Macsys runs as root so ownership of this file will be + // root/wheel. Change ownership to root/admin which will let all members + // of the admin group to read it. + unix.Fchown(int(ssd.sameuserproofFD.Fd()), 0, 80 /* admin */) + } + + return err +} + +// readMacsysSameuserproof returns the localhost TCP port number and auth token +// from a sameuserproof file written to /Library/Tailscale. +// +// In that case the files are: +// +// /Library/Tailscale/ipnport => $port (symlink with localhost port number target) +// /Library/Tailscale/sameuserproof-$port is a file containing only the auth token as a hex string. +func readMacsysSameUserProof() (port int, token string, err error) { + portStr, err := os.Readlink(filepath.Join(ssd.sharedDir, "ipnport")) + if err != nil { + return 0, "", err + } + port, err = strconv.Atoi(portStr) + if err != nil { + return 0, "", err + } + authb, err := os.ReadFile(filepath.Join(ssd.sharedDir, "sameuserproof-"+portStr)) + if err != nil { + return 0, "", err + } + auth := strings.TrimSpace(string(authb)) + if auth == "" { + return 0, "", errors.New("empty auth token in sameuserproof file") + } + + if ssd.checkConn { + // Files may be stale and there is no guarantee that the sameuserproof + // derived port is open and valid. Check it before returning it. + conn, err := net.DialTimeout("tcp", "127.0.0.1:"+portStr, time.Second) + if err != nil { + return 0, "", err } - return 0, "", fmt.Errorf("failed to find sandboxed sameuserproof-* file in TS_MACOS_CLI_SHARED_DIR %q", dir) + conn.Close() } - // The current process is running outside the sandbox, so use - // lsof to find the IPNExtension (the Mac App Store variant). + return port, auth, nil +} +// readMacosSameUserProof searches for open sameuserproof files belonging +// to the current user and the IPNExtension (macOS App Store) process and returns a +// port and token. +func readMacosSameUserProof() (port int, token string, err error) { cmd := exec.Command("lsof", "-n", // numeric sockets; don't do DNS lookups, etc "-a", // logical AND remaining options @@ -122,39 +317,45 @@ func localTCPPortAndTokenDarwin() (port int, token string, err error) { "-F", // machine-readable output ) out, err := cmd.Output() - if err != nil { - // Before returning an error, see if we're running the - // macsys variant at the normal location. - if port, token, err := localTCPPortAndTokenMacsys(); err == nil { + + if err == nil { + bs := bufio.NewScanner(bytes.NewReader(out)) + subStr := []byte(".tailscale.ipn.macos/sameuserproof-") + for bs.Scan() { + line := bs.Bytes() + i := bytes.Index(line, subStr) + if i == -1 { + continue + } + f := strings.SplitN(string(line[i+len(subStr):]), "-", 2) + if len(f) != 2 { + continue + } + portStr, token := f[0], f[1] + port, err := strconv.Atoi(portStr) + if err != nil { + return 0, "", fmt.Errorf("invalid port %q found in lsof", portStr) + } + return port, token, nil } - - return 0, "", fmt.Errorf("failed to run '%s' looking for IPNExtension: %w", cmd, err) } - bs := bufio.NewScanner(bytes.NewReader(out)) - subStr := []byte(".tailscale.ipn.macos/sameuserproof-") - for bs.Scan() { - line := bs.Bytes() - i := bytes.Index(line, subStr) - if i == -1 { - continue - } - f := strings.SplitN(string(line[i+len(subStr):]), "-", 2) - if len(f) != 2 { - continue - } - portStr, token := f[0], f[1] - port, err := strconv.Atoi(portStr) - if err != nil { - return 0, "", fmt.Errorf("invalid port %q found in lsof", portStr) - } + return 0, "", ErrTokenNotFound +} + +func portAndTokenFromSameUserProof() (port int, token string, err error) { + // When we're cmd/tailscale, we have no idea what tailscaled is, so we'll try + // macos, then macsys and finally, fallback to tailscaled via a unix socket + // if both of those return an error. You can run macos or macsys and + // tailscaled at the same time, but we are forced to choose one and the GUI + // clients are first in line here. You cannot run macos and macsys simultaneously. + if port, token, err := readMacosSameUserProof(); err == nil { return port, token, nil } - // Before returning an error, see if we're running the - // macsys variant at the normal location. - if port, token, err := localTCPPortAndTokenMacsys(); err == nil { + if port, token, err := readMacsysSameUserProof(); err == nil { return port, token, nil } + return 0, "", ErrTokenNotFound } diff --git a/safesocket/safesocket_darwin_test.go b/safesocket/safesocket_darwin_test.go new file mode 100644 index 000000000..e52959ad5 --- /dev/null +++ b/safesocket/safesocket_darwin_test.go @@ -0,0 +1,190 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safesocket + +import ( + "os" + "strings" + "testing" + + "tailscale.com/tstest" +) + +// TestSetCredentials verifies that calling SetCredentials +// sets the port and token correctly and that LocalTCPPortAndToken +// returns the given values. +func TestSetCredentials(t *testing.T) { + const ( + wantToken = "token" + wantPort = 123 + ) + + tstest.Replace(t, &ssd.isMacGUIApp, func() bool { return false }) + SetCredentials(wantToken, wantPort) + + gotPort, gotToken, err := LocalTCPPortAndToken() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPort != wantPort { + t.Errorf("port: got %d, want %d", gotPort, wantPort) + } + + if gotToken != wantToken { + t.Errorf("token: got %s, want %s", gotToken, wantToken) + } +} + +// TestFallbackToSameuserproof verifies that we fallback to the +// sameuserproof file via LocalTCPPortAndToken when we're running +// +// s cmd/tailscale +func TestFallbackToSameuserproof(t *testing.T) { + dir := t.TempDir() + const ( + wantToken = "token" + wantPort = 123 + ) + + // Mimics cmd/tailscale falling back to sameuserproof + tstest.Replace(t, &ssd.isMacGUIApp, func() bool { return false }) + tstest.Replace(t, &ssd.sharedDir, dir) + tstest.Replace(t, &ssd.checkConn, false) + + // Behave as macSysExt when initializing sameuserproof + tstest.Replace(t, &ssd.isMacSysExt, func() bool { return true }) + if err := initSameUserProofToken(dir, wantPort, wantToken); err != nil { + t.Fatalf("initSameUserProofToken: %v", err) + } + + gotPort, gotToken, err := LocalTCPPortAndToken() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPort != wantPort { + t.Errorf("port: got %d, want %d", gotPort, wantPort) + } + + if gotToken != wantToken { + t.Errorf("token: got %s, want %s", gotToken, wantToken) + } +} + +// TestInitListenerDarwin verifies that InitListenerDarwin +// returns a listener and a non-zero port and non-empty token. +func TestInitListenerDarwin(t *testing.T) { + temp := t.TempDir() + tstest.Replace(t, &ssd.isMacGUIApp, func() bool { return false }) + + ln, err := InitListenerDarwin(temp) + if err != nil || ln == nil { + t.Fatalf("InitListenerDarwin failed: %v", err) + } + defer (*ln).Close() + + port, token, err := LocalTCPPortAndToken() + if err != nil { + t.Fatalf("LocalTCPPortAndToken failed: %v", err) + } + + if port == 0 { + t.Errorf("port: got %d, want non-zero", port) + } + + if token == "" { + t.Errorf("token: got %s, want non-empty", token) + } +} + +func TestTokenGeneration(t *testing.T) { + token, err := getToken() + if err != nil { + t.Fatalf("getToken: %v", err) + } + + // Verify token length (hex string is 2x byte length) + wantLen := sameUserProofTokenLength * 2 + if got := len(token); got != wantLen { + t.Errorf("token length: got %d, want %d", got, wantLen) + } + + // Verify token persistence + subsequentToken, err := getToken() + if err != nil { + t.Fatalf("subsequent getToken: %v", err) + } + if subsequentToken != token { + t.Errorf("subsequent token: got %q, want %q", subsequentToken, token) + } +} + +// TestSameUserProofToken verifies that the sameuserproof file +// is created and read correctly for the macsys variant +func TestMacsysSameuserproof(t *testing.T) { + dir := t.TempDir() + + tstest.Replace(t, &ssd.isMacSysExt, func() bool { return true }) + tstest.Replace(t, &ssd.checkConn, false) + tstest.Replace(t, &ssd.sharedDir, dir) + + const ( + wantToken = "token" + wantPort = 123 + ) + + if err := initSameUserProofToken(dir, wantPort, wantToken); err != nil { + t.Fatalf("initSameUserProofToken: %v", err) + } + + gotPort, gotToken, err := readMacsysSameUserProof() + if err != nil { + t.Fatalf("readMacOSSameUserProof: %v", err) + } + + if gotPort != wantPort { + t.Errorf("port: got %d, want %d", gotPort, wantPort) + } + if wantToken != gotToken { + t.Errorf("token: got %s, want %s", wantToken, gotToken) + } + assertFileCount(t, dir, 1, "sameuserproof-") +} + +// TestMacosSameuserproof verifies that the sameuserproof file +// is created correctly for the macos variant +func TestMacosSameuserproof(t *testing.T) { + dir := t.TempDir() + wantToken := "token" + wantPort := 123 + + initSameUserProofToken(dir, wantPort, wantToken) + + // initSameUserProofToken should never leave duplicates + initSameUserProofToken(dir, wantPort, wantToken) + + // we can't just call readMacosSameUserProof because it relies on lsof + // and makes some assumptions about the user. But we can make sure + // the file exists + assertFileCount(t, dir, 1, "sameuserproof-") +} + +func assertFileCount(t *testing.T, dir string, want int, prefix string) { + t.Helper() + + files, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("[unexpected] error: %v", err) + } + count := 0 + for _, file := range files { + if strings.HasPrefix(file.Name(), prefix) { + count += 1 + } + } + if count != want { + t.Errorf("files: got %d, want 1", count) + } +} diff --git a/safesocket/safesocket_plan9.go b/safesocket/safesocket_plan9.go index 196c1df9c..c8a5e3b05 100644 --- a/safesocket/safesocket_plan9.go +++ b/safesocket/safesocket_plan9.go @@ -7,119 +7,13 @@ package safesocket import ( "context" - "fmt" "net" - "os" - "syscall" - "time" - - "golang.org/x/sys/plan9" ) -// Plan 9's devsrv srv(3) is a server registry and -// it is conventionally bound to "/srv" in the default -// namespace. It is "a one level directory for holding -// already open channels to services". Post one end of -// a pipe to "/srv/tailscale.sock" and use the other -// end for communication with a requestor. Plan 9 pipes -// are bidirectional. - -type plan9SrvAddr string - -func (sl plan9SrvAddr) Network() string { - return "/srv" -} - -func (sl plan9SrvAddr) String() string { - return string(sl) -} - -// There is no net.FileListener for Plan 9 at this time -type plan9SrvListener struct { - name string - srvf *os.File - file *os.File -} - -func (sl *plan9SrvListener) Accept() (net.Conn, error) { - // sl.file is the server end of the pipe that's - // connected to /srv/tailscale.sock - return plan9FileConn{name: sl.name, file: sl.file}, nil -} - -func (sl *plan9SrvListener) Close() error { - sl.file.Close() - return sl.srvf.Close() -} - -func (sl *plan9SrvListener) Addr() net.Addr { - return plan9SrvAddr(sl.name) -} - -type plan9FileConn struct { - name string - file *os.File -} - -func (fc plan9FileConn) Read(b []byte) (n int, err error) { - return fc.file.Read(b) -} -func (fc plan9FileConn) Write(b []byte) (n int, err error) { - return fc.file.Write(b) -} -func (fc plan9FileConn) Close() error { - return fc.file.Close() -} -func (fc plan9FileConn) LocalAddr() net.Addr { - return plan9SrvAddr(fc.name) -} -func (fc plan9FileConn) RemoteAddr() net.Addr { - return plan9SrvAddr(fc.name) -} -func (fc plan9FileConn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 -} -func (fc plan9FileConn) SetReadDeadline(t time.Time) error { - return syscall.EPLAN9 -} -func (fc plan9FileConn) SetWriteDeadline(t time.Time) error { - return syscall.EPLAN9 -} - func connect(_ context.Context, path string) (net.Conn, error) { - f, err := os.OpenFile(path, os.O_RDWR, 0666) - if err != nil { - return nil, err - } - - return plan9FileConn{name: path, file: f}, nil + return net.Dial("tcp", "localhost:5252") } -// Create an entry in /srv, open a pipe, write the -// client end to the entry and return the server -// end of the pipe to the caller. When the server -// end of the pipe is closed, /srv name associated -// with it will be removed (controlled by ORCLOSE flag) func listen(path string) (net.Listener, error) { - const O_RCLOSE = 64 // remove on close; should be in plan9 package - var pip [2]int - - err := plan9.Pipe(pip[:]) - if err != nil { - return nil, err - } - defer plan9.Close(pip[1]) - - srvfd, err := plan9.Create(path, plan9.O_WRONLY|plan9.O_CLOEXEC|O_RCLOSE, 0600) - if err != nil { - return nil, err - } - srv := os.NewFile(uintptr(srvfd), path) - - _, err = fmt.Fprintf(srv, "%d", pip[1]) - if err != nil { - return nil, err - } - - return &plan9SrvListener{name: path, srvf: srv, file: os.NewFile(uintptr(pip[0]), path)}, nil + return net.Listen("tcp", "localhost:5252") } diff --git a/safesocket/safesocket_ps.go b/safesocket/safesocket_ps.go index f7d97f7fd..d3f409df5 100644 --- a/safesocket/safesocket_ps.go +++ b/safesocket/safesocket_ps.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || windows || darwin || freebsd +//go:build ((linux && !android) || windows || (darwin && !ios) || freebsd) && !ts_omit_cliconndiag package safesocket @@ -12,7 +12,7 @@ import ( ) func init() { - tailscaledProcExists = func() bool { + tailscaledProcExists.Set(func() bool { procs, err := ps.Processes() if err != nil { return false @@ -30,5 +30,5 @@ func init() { } } return false - } + }) } diff --git a/safeweb/http.go b/safeweb/http.go index 9130b42d3..d085fcb88 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -71,28 +71,78 @@ package safeweb import ( "cmp" + "context" crand "crypto/rand" "fmt" "log" + "maps" "net" "net/http" "net/url" "path" + "slices" "strings" "github.com/gorilla/csrf" ) -// The default Content-Security-Policy header. -var defaultCSP = strings.Join([]string{ - `default-src 'self'`, // origin is the only valid source for all content types - `script-src 'self'`, // disallow inline javascript - `frame-ancestors 'none'`, // disallow framing of the page - `form-action 'self'`, // disallow form submissions to other origins - `base-uri 'self'`, // disallow base URIs from other origins - `block-all-mixed-content`, // disallow mixed content when serving over HTTPS - `object-src 'self'`, // disallow embedding of resources from other origins -}, "; ") +// CSP is the value of a Content-Security-Policy header. Keys are CSP +// directives (like "default-src") and values are source expressions (like +// "'self'" or "https://tailscale.com"). A nil slice value is allowed for some +// directives like "upgrade-insecure-requests" that don't expect a list of +// source definitions. +type CSP map[string][]string + +// DefaultCSP is the recommended CSP to use when not loading resources from +// other domains and not embedding the current website. If you need to tweak +// the CSP, it is recommended to extend DefaultCSP instead of writing your own +// from scratch. +func DefaultCSP() CSP { + return CSP{ + "default-src": {"self"}, // origin is the only valid source for all content types + "frame-ancestors": {"none"}, // disallow framing of the page + "form-action": {"self"}, // disallow form submissions to other origins + "base-uri": {"self"}, // disallow base URIs from other origins + // TODO(awly): consider upgrade-insecure-requests in SecureContext + // instead, as this is deprecated. + "block-all-mixed-content": nil, // disallow mixed content when serving over HTTPS + } +} + +// Set sets the values for a given directive. Empty values are allowed, if the +// directive doesn't expect any (like "upgrade-insecure-requests"). +func (csp CSP) Set(directive string, values ...string) { + csp[directive] = values +} + +// Add adds a source expression to an existing directive. +func (csp CSP) Add(directive, value string) { + csp[directive] = append(csp[directive], value) +} + +// Del deletes a directive and all its values. +func (csp CSP) Del(directive string) { + delete(csp, directive) +} + +func (csp CSP) String() string { + keys := slices.Collect(maps.Keys(csp)) + slices.Sort(keys) + var s strings.Builder + for _, k := range keys { + s.WriteString(k) + for _, v := range csp[k] { + // Special values like 'self', 'none', 'unsafe-inline', etc., must + // be quoted. Do it implicitly as a convenience here. + if !strings.Contains(v, ".") && len(v) > 1 && v[0] != '\'' && v[len(v)-1] != '\'' { + v = "'" + v + "'" + } + s.WriteString(" " + v) + } + s.WriteString("; ") + } + return strings.TrimSpace(s.String()) +} // The default Strict-Transport-Security header. This header tells the browser // to exclusively use HTTPS for all requests to the origin for the next year. @@ -130,6 +180,9 @@ type Config struct { // startup. CSRFSecret []byte + // CSP is the Content-Security-Policy header to return with BrowserMux + // responses. + CSP CSP // CSPAllowInlineStyles specifies whether to include `style-src: // unsafe-inline` in the Content-Security-Policy header to permit the use of // inline CSS. @@ -144,6 +197,12 @@ type Config struct { // BrowserMux when SecureContext is true. // If empty, it defaults to max-age of 1 year. StrictTransportSecurityOptions string + + // HTTPServer, if specified, is the underlying http.Server that safeweb will + // use to serve requests. If nil, a new http.Server will be created. + // Do not use the Handler field of http.Server, as it will be ignored. + // Instead, set your handlers using APIMux and BrowserMux. + HTTPServer *http.Server } func (c *Config) setDefaults() error { @@ -162,6 +221,10 @@ func (c *Config) setDefaults() error { } } + if c.CSP == nil { + c.CSP = DefaultCSP() + } + return nil } @@ -193,17 +256,25 @@ func NewServer(config Config) (*Server, error) { if config.CookiesSameSiteLax { sameSite = csrf.SameSiteLaxMode } + if config.CSPAllowInlineStyles { + if _, ok := config.CSP["style-src"]; ok { + config.CSP.Add("style-src", "unsafe-inline") + } else { + config.CSP.Set("style-src", "self", "unsafe-inline") + } + } s := &Server{ Config: config, - csp: defaultCSP, + csp: config.CSP.String(), // only set Secure flag on CSRF cookies if we are in a secure context // as otherwise the browser will reject the cookie csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite)), } - if config.CSPAllowInlineStyles { - s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'` + s.h = cmp.Or(config.HTTPServer, &http.Server{}) + if s.h.Handler != nil { + return nil, fmt.Errorf("use safeweb.Config.APIMux and safeweb.Config.BrowserMux instead of http.Server.Handler") } - s.h = &http.Server{Handler: s} + s.h.Handler = s return s, nil } @@ -215,12 +286,27 @@ const ( browserHandler ) +func (h handlerType) String() string { + switch h { + case browserHandler: + return "browser" + case apiHandler: + return "api" + default: + return "unknown" + } +} + // checkHandlerType returns either apiHandler or browserHandler, depending on // whether apiPattern or browserPattern is more specific (i.e. which pattern // contains more pathname components). If they are equally specific, it returns // unknownHandler. func checkHandlerType(apiPattern, browserPattern string) handlerType { - c := cmp.Compare(strings.Count(path.Clean(apiPattern), "/"), strings.Count(path.Clean(browserPattern), "/")) + apiPattern, browserPattern = path.Clean(apiPattern), path.Clean(browserPattern) + c := cmp.Compare(strings.Count(apiPattern, "/"), strings.Count(browserPattern, "/")) + if apiPattern == "/" || browserPattern == "/" { + c = cmp.Compare(len(apiPattern), len(browserPattern)) + } switch { case c > 0: return apiHandler @@ -232,6 +318,12 @@ func checkHandlerType(apiPattern, browserPattern string) handlerType { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // if we are not in a secure context, signal to the CSRF middleware that + // TLS-only header checks should be skipped + if !s.Config.SecureContext { + r = csrf.PlaintextHTTPRequest(r) + } + _, bp := s.BrowserMux.Handler(r) _, ap := s.APIMux.Handler(r) switch { @@ -284,6 +376,7 @@ func (s *Server) serveBrowser(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Security-Policy", s.csp) w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Referer-Policy", "same-origin") + w.Header().Set("Cross-Origin-Opener-Policy", "same-origin") if s.SecureContext { w.Header().Set("Strict-Transport-Security", cmp.Or(s.StrictTransportSecurityOptions, DefaultStrictTransportSecurityOptions)) } @@ -331,3 +424,7 @@ func (s *Server) ListenAndServe(addr string) error { func (s *Server) Close() error { return s.h.Close() } + +// Shutdown gracefully shuts down the server without interrupting any active +// connections. It has the same semantics as[http.Server.Shutdown]. +func (s *Server) Shutdown(ctx context.Context) error { return s.h.Shutdown(ctx) } diff --git a/safeweb/http_test.go b/safeweb/http_test.go index 843da08aa..852ce326b 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/gorilla/csrf" @@ -240,18 +241,26 @@ func TestCSRFProtection(t *testing.T) { func TestContentSecurityPolicyHeader(t *testing.T) { tests := []struct { name string + csp CSP apiRoute bool - wantCSP bool + wantCSP string }{ { - name: "default routes get CSP headers", - apiRoute: false, - wantCSP: true, + name: "default CSP", + wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`, + }, + { + name: "custom CSP", + csp: CSP{ + "default-src": {"'self'", "https://tailscale.com"}, + "upgrade-insecure-requests": nil, + }, + wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`, }, { name: "`/api/*` routes do not get CSP headers", apiRoute: true, - wantCSP: false, + wantCSP: "", }, } @@ -264,9 +273,9 @@ func TestContentSecurityPolicyHeader(t *testing.T) { var s *Server var err error if tt.apiRoute { - s, err = NewServer(Config{APIMux: h}) + s, err = NewServer(Config{APIMux: h, CSP: tt.csp}) } else { - s, err = NewServer(Config{BrowserMux: h}) + s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp}) } if err != nil { t.Fatal(err) @@ -278,8 +287,8 @@ func TestContentSecurityPolicyHeader(t *testing.T) { s.h.Handler.ServeHTTP(w, req) resp := w.Result() - if (resp.Header.Get("Content-Security-Policy") == "") == tt.wantCSP { - t.Fatalf("content security policy want: %v; got: %v", tt.wantCSP, resp.Header.Get("Content-Security-Policy")) + if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP { + t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got) } }) } @@ -396,7 +405,7 @@ func TestCSPAllowInlineStyles(t *testing.T) { csp := resp.Header.Get("Content-Security-Policy") allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'") if allowsStyles != allow { - t.Fatalf("CSP inline styles want: %v; got: %v", allow, allowsStyles) + t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp) } }) } @@ -526,13 +535,13 @@ func TestGetMoreSpecificPattern(t *testing.T) { { desc: "same prefix", a: "/foo/bar/quux", - b: "/foo/bar/", + b: "/foo/bar/", // path.Clean will strip the trailing slash. want: apiHandler, }, { desc: "almost same prefix, but not a path component", a: "/goat/sheep/cheese", - b: "/goat/sheepcheese/", + b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash. want: apiHandler, }, { @@ -553,6 +562,12 @@ func TestGetMoreSpecificPattern(t *testing.T) { b: "///////", want: unknownHandler, }, + { + desc: "root-level", + a: "/latest", + b: "/", // path.Clean will NOT strip the trailing slash. + want: apiHandler, + }, } { t.Run(tt.desc, func(t *testing.T) { got := checkHandlerType(tt.a, tt.b) @@ -609,3 +624,26 @@ func TestStrictTransportSecurityOptions(t *testing.T) { }) } } + +func TestOverrideHTTPServer(t *testing.T) { + s, err := NewServer(Config{}) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + if s.h.IdleTimeout != 0 { + t.Fatalf("got %v; want 0", s.h.IdleTimeout) + } + + c := http.Server{ + IdleTimeout: 10 * time.Second, + } + + s, err = NewServer(Config{HTTPServer: &c}) + if err != nil { + t.Fatalf("NewServer: %v", err) + } + + if s.h.IdleTimeout != c.IdleTimeout { + t.Fatalf("got %v; want %v", s.h.IdleTimeout, c.IdleTimeout) + } +} diff --git a/scripts/check_license_headers.sh b/scripts/check_license_headers.sh deleted file mode 100755 index 8345afab7..000000000 --- a/scripts/check_license_headers.sh +++ /dev/null @@ -1,77 +0,0 @@ -#!/bin/sh -# -# Copyright (c) Tailscale Inc & AUTHORS -# SPDX-License-Identifier: BSD-3-Clause -# -# check_license_headers.sh checks that all Go files in the given -# directory tree have a correct-looking Tailscale license header. - -check_file() { - got=$1 - - want=$(cat <&2 - exit 1 -fi - -fail=0 -for file in $(find $1 \( -name '*.go' -or -name '*.tsx' -or -name '*.ts' -not -name '*.config.ts' \) -not -path '*/.git/*' -not -path '*/node_modules/*'); do - case $file in - $1/tempfork/*) - # Skip, tempfork of third-party code - ;; - $1/wgengine/router/ifconfig_windows.go) - # WireGuard copyright. - ;; - $1/cmd/tailscale/cli/authenticode_windows.go) - # WireGuard copyright. - ;; - *_string.go) - # Generated file from go:generate stringer - ;; - $1/control/controlbase/noiseexplorer_test.go) - # Noiseexplorer.com copyright. - ;; - */zsyscall_windows.go) - # Generated syscall wrappers - ;; - $1/util/winutil/subprocess_windows_test.go) - # Subprocess test harness code - ;; - $1/util/winutil/testdata/testrestartableprocesses/main.go) - # Subprocess test harness code - ;; - *$1/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go) - # Generated kube deepcopy funcs file starts with a Go build tag + an empty line - header="$(head -5 $file | tail -n+3 )" - ;; - $1/derp/xdp/bpf_bpfe*.go) - # Generated eBPF management code - ;; - *) - header="$(head -2 $file)" - ;; - esac - if [ ! -z "$header" ]; then - if ! check_file "$header"; then - fail=1 - echo "${file#$1/} doesn't have the right copyright header:" - echo "$header" | sed -e 's/^/ /g' - fi - fi -done - -if [ $fail -ne 0 ]; then - exit 1 -fi diff --git a/scripts/installer.sh b/scripts/installer.sh index 19911ee23..e5b6cd23b 100755 --- a/scripts/installer.sh +++ b/scripts/installer.sh @@ -42,6 +42,8 @@ main() { # - VERSION_CODENAME: the codename of the OS release, if any (e.g. "buster") # - UBUNTU_CODENAME: if it exists, use instead of VERSION_CODENAME . /etc/os-release + VERSION_MAJOR="${VERSION_ID:-}" + VERSION_MAJOR="${VERSION_MAJOR%%.*}" case "$ID" in ubuntu|pop|neon|zorin|tuxedo) OS="ubuntu" @@ -53,10 +55,10 @@ main() { PACKAGETYPE="apt" # Third-party keyrings became the preferred method of # installation in Ubuntu 20.04. - if expr "$VERSION_ID" : "2.*" >/dev/null; then - APT_KEY_TYPE="keyring" - else + if [ "$VERSION_MAJOR" -lt 20 ]; then APT_KEY_TYPE="legacy" + else + APT_KEY_TYPE="keyring" fi ;; debian) @@ -68,7 +70,15 @@ main() { if [ -z "${VERSION_ID:-}" ]; then # rolling release. If you haven't kept current, that's on you. APT_KEY_TYPE="keyring" - elif [ "$VERSION_ID" -lt 11 ]; then + # Parrot Security is a special case that uses ID=debian + elif [ "$NAME" = "Parrot Security" ]; then + # All versions new enough to have this behaviour prefer keyring + # and their VERSION_ID is not consistent with Debian. + APT_KEY_TYPE="keyring" + # They don't specify the Debian version they're based off in os-release + # but Parrot 6 is based on Debian 12 Bookworm. + VERSION=bookworm + elif [ "$VERSION_MAJOR" -lt 11 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -86,7 +96,7 @@ main() { VERSION="$VERSION_CODENAME" fi PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 5 ]; then + if [ "$VERSION_MAJOR" -lt 5 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -96,16 +106,27 @@ main() { OS="ubuntu" VERSION="$UBUNTU_CODENAME" PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 6 ]; then + if [ "$VERSION_MAJOR" -lt 6 ]; then + APT_KEY_TYPE="legacy" + else + APT_KEY_TYPE="keyring" + fi + ;; + industrial-os) + OS="debian" + PACKAGETYPE="apt" + if [ "$VERSION_MAJOR" -lt 5 ]; then + VERSION="buster" APT_KEY_TYPE="legacy" else + VERSION="bullseye" APT_KEY_TYPE="keyring" fi ;; parrot|mendel) OS="debian" PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 5 ]; then + if [ "$VERSION_MAJOR" -lt 5 ]; then VERSION="buster" APT_KEY_TYPE="legacy" else @@ -131,7 +152,7 @@ main() { PACKAGETYPE="apt" # Third-party keyrings became the preferred method of # installation in Raspbian 11 (Bullseye). - if [ "$VERSION_ID" -lt 11 ]; then + if [ "$VERSION_MAJOR" -lt 11 ]; then APT_KEY_TYPE="legacy" else APT_KEY_TYPE="keyring" @@ -140,12 +161,11 @@ main() { kali) OS="debian" PACKAGETYPE="apt" - YEAR="$(echo "$VERSION_ID" | cut -f1 -d.)" APT_SYSTEMCTL_START=true # Third-party keyrings became the preferred method of # installation in Debian 11 (Bullseye), which Kali switched # to in roughly 2021.x releases - if [ "$YEAR" -lt 2021 ]; then + if [ "$VERSION_MAJOR" -lt 2021 ]; then # Kali VERSION_ID is "kali-rolling", which isn't distinguishing VERSION="buster" APT_KEY_TYPE="legacy" @@ -154,10 +174,10 @@ main() { APT_KEY_TYPE="keyring" fi ;; - Deepin) # https://github.com/tailscale/tailscale/issues/7862 + Deepin|deepin) # https://github.com/tailscale/tailscale/issues/7862 OS="debian" PACKAGETYPE="apt" - if [ "$VERSION_ID" -lt 20 ]; then + if [ "$VERSION_MAJOR" -lt 20 ]; then APT_KEY_TYPE="legacy" VERSION="buster" else @@ -165,9 +185,28 @@ main() { VERSION="bullseye" fi ;; + pika) + PACKAGETYPE="apt" + # All versions of PikaOS are new enough to prefer keyring + APT_KEY_TYPE="keyring" + # Older versions of PikaOS are based on Ubuntu rather than Debian + if [ "$VERSION_MAJOR" -lt 4 ]; then + OS="ubuntu" + VERSION="$UBUNTU_CODENAME" + else + OS="debian" + VERSION="$DEBIAN_CODENAME" + fi + ;; + sparky) + OS="debian" + PACKAGETYPE="apt" + VERSION="$DEBIAN_CODENAME" + APT_KEY_TYPE="keyring" + ;; centos) OS="$ID" - VERSION="$VERSION_ID" + VERSION="$VERSION_MAJOR" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then PACKAGETYPE="yum" @@ -175,15 +214,18 @@ main() { ;; ol) OS="oracle" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then PACKAGETYPE="yum" fi ;; - rhel) + rhel|miraclelinux) OS="$ID" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + if [ "$ID" = "miraclelinux" ]; then + OS="rhel" + fi + VERSION="$VERSION_MAJOR" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then PACKAGETYPE="yum" @@ -206,7 +248,7 @@ main() { ;; xenenterprise) OS="centos" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="yum" ;; opensuse-leap|sles) @@ -224,12 +266,12 @@ main() { VERSION="leap/15.4" PACKAGETYPE="zypper" ;; - arch|archarm|endeavouros|blendos|garuda) + arch|archarm|endeavouros|blendos|garuda|archcraft|cachyos) OS="arch" VERSION="" # rolling release PACKAGETYPE="pacman" ;; - manjaro|manjaro-arm) + manjaro|manjaro-arm|biglinux) OS="manjaro" VERSION="" # rolling release PACKAGETYPE="pacman" @@ -250,6 +292,14 @@ main() { echo "services.tailscale.enable = true;" exit 1 ;; + bazzite) + echo "Bazzite comes with Tailscale installed by default." + echo "Please enable Tailscale by running the following commands as root:" + echo + echo "ujust enable-tailscale" + echo "tailscale up" + exit 1 + ;; void) OS="$ID" VERSION="" # rolling release @@ -262,7 +312,7 @@ main() { ;; freebsd) OS="$ID" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="pkg" ;; osmc) @@ -273,7 +323,7 @@ main() { ;; photon) OS="photon" - VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" + VERSION="$VERSION_MAJOR" PACKAGETYPE="tdnf" ;; @@ -369,7 +419,9 @@ main() { ;; freebsd) if [ "$VERSION" != "12" ] && \ - [ "$VERSION" != "13" ] + [ "$VERSION" != "13" ] && \ + [ "$VERSION" != "14" ] && \ + [ "$VERSION" != "15" ] then OS_UNSUPPORTED=1 fi @@ -465,10 +517,13 @@ main() { legacy) $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.asc" | $SUDO apt-key add - $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.list" | $SUDO tee /etc/apt/sources.list.d/tailscale.list + $SUDO chmod 0644 /etc/apt/sources.list.d/tailscale.list ;; keyring) $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.noarmor.gpg" | $SUDO tee /usr/share/keyrings/tailscale-archive-keyring.gpg >/dev/null + $SUDO chmod 0644 /usr/share/keyrings/tailscale-archive-keyring.gpg $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.tailscale-keyring.list" | $SUDO tee /etc/apt/sources.list.d/tailscale.list + $SUDO chmod 0644 /etc/apt/sources.list.d/tailscale.list ;; esac $SUDO apt-get update @@ -488,9 +543,41 @@ main() { set +x ;; dnf) + # DNF 5 has a different argument format; determine which one we have. + DNF_VERSION="3" + if LANG=C.UTF-8 dnf --version | grep -q '^dnf5 version'; then + DNF_VERSION="5" + fi + + # The 'config-manager' plugin wasn't implemented when + # DNF5 was released; detect that and use the old + # version if necessary. + if [ "$DNF_VERSION" = "5" ]; then + set -x + $SUDO dnf install -y 'dnf-command(config-manager)' && DNF_HAVE_CONFIG_MANAGER=1 || DNF_HAVE_CONFIG_MANAGER=0 + set +x + + if [ "$DNF_HAVE_CONFIG_MANAGER" != "1" ]; then + if type dnf-3 >/dev/null; then + DNF_VERSION="3" + else + echo "dnf 5 detected, but 'dnf-command(config-manager)' not available and dnf-3 not found" + exit 1 + fi + fi + fi + set -x - $SUDO dnf install -y 'dnf-command(config-manager)' - $SUDO dnf config-manager --add-repo "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + if [ "$DNF_VERSION" = "3" ]; then + $SUDO dnf install -y 'dnf-command(config-manager)' + $SUDO dnf config-manager --add-repo "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + elif [ "$DNF_VERSION" = "5" ]; then + # Already installed config-manager, above. + $SUDO dnf config-manager addrepo --from-repofile="https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + else + echo "unexpected: unknown dnf version $DNF_VERSION" + exit 1 + fi $SUDO dnf install -y tailscale $SUDO systemctl enable --now tailscaled set +x @@ -519,7 +606,7 @@ main() { ;; pkg) set -x - $SUDO pkg install tailscale + $SUDO pkg install --yes tailscale $SUDO service tailscaled enable $SUDO service tailscaled start set +x diff --git a/sessionrecording/connect.go b/sessionrecording/connect.go index db966ba2c..9d20b41f9 100644 --- a/sessionrecording/connect.go +++ b/sessionrecording/connect.go @@ -7,6 +7,7 @@ package sessionrecording import ( "context" + "encoding/json" "errors" "fmt" "io" @@ -14,12 +15,29 @@ import ( "net/http" "net/http/httptrace" "net/netip" + "sync/atomic" "time" + "tailscale.com/net/netx" "tailscale.com/tailcfg" - "tailscale.com/util/multierr" + "tailscale.com/util/httpm" ) +const ( + // Timeout for an individual DialFunc call for a single recorder address. + perDialAttemptTimeout = 5 * time.Second + // Timeout for the V2 API HEAD probe request (supportsV2). + http2ProbeTimeout = 10 * time.Second + // Maximum timeout for trying all available recorders, including V2 API + // probes and dial attempts. + allDialAttemptsTimeout = 30 * time.Second +) + +// uploadAckWindow is the period of time to wait for an ackFrame from recorder +// before terminating the connection. This is a variable to allow overriding it +// in tests. +var uploadAckWindow = 30 * time.Second + // ConnectToRecorder connects to the recorder at any of the provided addresses. // It returns the first successful response, or a multierr if all attempts fail. // @@ -32,19 +50,15 @@ import ( // attempts are in order the recorder(s) was attempted. If successful a // successful connection is made, the last attempt in the slice is the // attempt for connected recorder. -func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { +func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { if len(recs) == 0 { return nil, nil, nil, errors.New("no recorders configured") } // We use a special context for dialing the recorder, so that we can // limit the time we spend dialing to 30 seconds and still have an // unbounded context for the upload. - dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + dialCtx, dialCancel := context.WithTimeout(ctx, allDialAttemptsTimeout) defer dialCancel() - hc, err := SessionRecordingClientForDialer(dialCtx, dial) - if err != nil { - return nil, nil, nil, err - } var errs []error var attempts []*tailcfg.SSHRecordingAttempt @@ -54,74 +68,321 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con } attempts = append(attempts, attempt) - // We dial the recorder and wait for it to send a 100-continue - // response before returning from this function. This ensures that - // the recorder is ready to accept the recording. - - // got100 is closed when we receive the 100-continue response. - got100 := make(chan struct{}) - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - Got100Continue: func() { - close(got100) - }, - }) - - pr, pw := io.Pipe() - req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr) + var pw io.WriteCloser + var errChan <-chan error + var err error + hc := clientHTTP2(dialCtx, dial) + // We need to probe V2 support using a separate HEAD request. Sending + // an HTTP/2 POST request to a HTTP/1 server will just "hang" until the + // request body is closed (instead of returning a 404 as one would + // expect). Sending a HEAD request without a body does not have that + // problem. + if supportsV2(ctx, hc, ap) { + pw, errChan, err = connectV2(ctx, hc, ap) + } else { + pw, errChan, err = connectV1(ctx, clientHTTP1(dialCtx, dial), ap) + } if err != nil { - err = fmt.Errorf("recording: error starting recording: %w", err) + err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err) attempt.FailureMessage = err.Error() errs = append(errs, err) continue } - // We set the Expect header to 100-continue, so that the recorder - // will send a 100-continue response before it starts reading the - // request body. - req.Header.Set("Expect", "100-continue") + return pw, attempts, errChan, nil + } + return nil, attempts, nil, errors.Join(errs...) +} - // errChan is used to indicate the result of the request. - errChan := make(chan error, 1) - go func() { - resp, err := hc.Do(req) - if err != nil { - errChan <- fmt.Errorf("recording: error starting recording: %w", err) +// supportsV2 checks whether a recorder instance supports the /v2/record +// endpoint. +func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool { + ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/record", ap), nil) + if err != nil { + return false + } + resp, err := hc.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK && resp.ProtoMajor > 1 +} + +// supportsEvent checks whether a recorder instance supports the /v2/event +// endpoint. +func supportsEvent(ctx context.Context, hc *http.Client, ap netip.AddrPort) (bool, error) { + ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/event", ap), nil) + if err != nil { + return false, err + } + resp, err := hc.Do(req) + if err != nil { + return false, err + } + + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return true, nil + } + + if resp.StatusCode != http.StatusNotFound { + body, err := io.ReadAll(resp.Body) + if err != nil { + // Handle the case where reading the body itself fails + return false, fmt.Errorf("server returned non-OK status: %s, and failed to read body: %w", resp.Status, err) + } + + return false, fmt.Errorf("server returned non-OK status: %d: %s", resp.StatusCode, string(body)) + } + + return false, nil +} + +const addressNotSupportEventv2 = `recorder at address %q does not support "/v2/event" endpoint` + +type EventAPINotSupportedErr struct { + ap netip.AddrPort +} + +func (e EventAPINotSupportedErr) Error() string { + return fmt.Sprintf(addressNotSupportEventv2, e.ap) +} + +// SendEvent sends an event the tsrecorders /v2/event endpoint. +func SendEvent(ap netip.AddrPort, event io.Reader, dial netx.DialFunc) (retErr error) { + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + if retErr != nil { + cancel() + } + }() + + client := clientHTTP1(ctx, dial) + + supported, err := supportsEvent(ctx, client, ap) + if err != nil { + return fmt.Errorf("error checking support for `/v2/event` endpoint: %w", err) + } + + if !supported { + return EventAPINotSupportedErr{ + ap: ap, + } + } + + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/event", ap.String()), event) + if err != nil { + return fmt.Errorf("error creating request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("error sending request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + // Handle the case where reading the body itself fails + return fmt.Errorf("server returned non-OK status: %s, and failed to read body: %w", resp.Status, err) + } + + return fmt.Errorf("server returned non-OK status: %d: %s", resp.StatusCode, string(body)) + } + + return nil +} + +// connectV1 connects to the legacy /record endpoint on the recorder. It is +// used for backwards-compatibility with older tsrecorder instances. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + // We dial the recorder and wait for it to send a 100-continue + // response before returning from this function. This ensures that + // the recorder is ready to accept the recording. + + // got100 is closed when we receive the 100-continue response. + got100 := make(chan struct{}) + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got100Continue: func() { + close(got100) + }, + }) + + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), pr) + if err != nil { + return nil, nil, err + } + // We set the Expect header to 100-continue, so that the recorder + // will send a 100-continue response before it starts reading the + // request body. + req.Header.Set("Expect", "100-continue") + + // errChan is used to indicate the result of the request. + errChan := make(chan error, 1) + go func() { + defer close(errChan) + resp, err := hc.Do(req) + if err != nil { + errChan <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + return + } + }() + select { + case <-got100: + return pw, errChan, nil + case err := <-errChan: + // If we get an error before we get the 100-continue response, + // we need to try another recorder. + if err == nil { + // If the error is nil, we got a 200 response, which + // is unexpected as we haven't sent any data yet. + err = errors.New("recording: unexpected EOF") + } + return nil, nil, err + } +} + +// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2. +// It explicitly tracks ack frames sent in the response and terminates the +// connection if sent recording data is un-acked for uploadAckWindow. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + pr, pw := io.Pipe() + upload := &readCounter{r: pr} + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload) + if err != nil { + return nil, nil, err + } + + // With HTTP/2, hc.Do will not block while the request body is being sent. + // It will return immediately and allow us to consume the response body at + // the same time. + resp, err := hc.Do(req) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status) + } + + errChan := make(chan error, 1) + acks := make(chan int64) + // Read acks from the response and send them to the acks channel. + go func() { + defer close(errChan) + defer close(acks) + defer resp.Body.Close() + defer pw.Close() + dec := json.NewDecoder(resp.Body) + for { + var frame v2ResponseFrame + if err := dec.Decode(&frame); err != nil { + if !errors.Is(err, io.EOF) { + errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err) + } return } - if resp.StatusCode != 200 { - errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + if frame.Error != "" { + errChan <- fmt.Errorf("recording: received error from the recorder: %q", frame.Error) return } - errChan <- nil - }() - select { - case <-got100: - case err := <-errChan: - // If we get an error before we get the 100-continue response, - // we need to try another recorder. - if err == nil { - // If the error is nil, we got a 200 response, which - // is unexpected as we haven't sent any data yet. - err = errors.New("recording: unexpected EOF") + select { + case acks <- frame.Ack: + case <-ctx.Done(): + return } - attempt.FailureMessage = err.Error() - errs = append(errs, err) - continue // try the next recorder } - return pw, attempts, errChan, nil - } - return nil, attempts, nil, multierr.New(errs...) + }() + // Track acks from the acks channel. + go func() { + // Hack for tests: some tests modify uploadAckWindow and reset it when + // the test ends. This can race with t.Reset call below. Making a copy + // here is a lazy workaround to not wait for this goroutine to exit in + // the test cases. + uploadAckWindow := uploadAckWindow + // This timer fires if we didn't receive an ack for too long. + t := time.NewTimer(uploadAckWindow) + defer t.Stop() + for { + select { + case <-t.C: + // Close the pipe which terminates the connection and cleans up + // other goroutines. Note that tsrecorder will send us ack + // frames even if there is no new data to ack. This helps + // detect broken recorder connection if the session is idle. + pr.CloseWithError(errNoAcks) + resp.Body.Close() + return + case _, ok := <-acks: + if !ok { + // acks channel closed means that the goroutine reading them + // finished, which means that the request has ended. + return + } + // TODO(awly): limit how far behind the received acks can be. This + // should handle scenarios where a session suddenly dumps a lot of + // output. + t.Reset(uploadAckWindow) + case <-ctx.Done(): + return + } + } + }() + + return pw, errChan, nil } -// SessionRecordingClientForDialer returns an http.Client that uses a clone of -// the provided Dialer's PeerTransport to dial connections. This is used to make -// requests to the session recording server to upload session recordings. It -// uses the provided dialCtx to dial connections, and limits a single dial to 5 -// seconds. -func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() +var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s") + +type v2ResponseFrame struct { + // Ack is the number of bytes received from the client so far. The bytes + // are not guaranteed to be durably stored yet. + Ack int64 `json:"ack,omitempty"` + // Error is an error encountered while storing the recording. Error is only + // ever set as the last frame in the response. + Error string `json:"error,omitempty"` +} + +// readCounter is an io.Reader that counts how many bytes were read. +type readCounter struct { + r io.Reader + sent atomic.Int64 +} +func (u *readCounter) Read(buf []byte) (int, error) { + n, err := u.r.Read(buf) + u.sent.Add(int64(n)) + return n, err +} + +// clientHTTP1 returns a claassic http.Client with a per-dial context. It uses +// dialCtx and adds a 5s timeout to it. +func clientHTTP1(dialCtx context.Context, dial netx.DialFunc) *http.Client { + tr := http.DefaultTransport.(*http.Transport).Clone() tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) defer cancel() go func() { select { @@ -132,7 +393,30 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context. }() return dial(perAttemptCtx, network, addr) } + return &http.Client{Transport: tr} +} + +// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c +// requests (HTTP/2 over plaintext). Unfortunately the same client does not +// work for HTTP/1 so we need to split these up. +func clientHTTP2(dialCtx context.Context, dial netx.DialFunc) *http.Client { + var p http.Protocols + p.SetUnencryptedHTTP2(true) return &http.Client{ - Transport: tr, - }, nil + Transport: &http.Transport{ + Protocols: &p, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) + defer cancel() + go func() { + select { + case <-perAttemptCtx.Done(): + case <-dialCtx.Done(): + cancel() + } + }() + return dial(perAttemptCtx, network, addr) + }, + }, + } } diff --git a/sessionrecording/connect_test.go b/sessionrecording/connect_test.go new file mode 100644 index 000000000..e834828f5 --- /dev/null +++ b/sessionrecording/connect_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sessionrecording + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "strings" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "tailscale.com/net/memnet" +) + +func TestConnectToRecorder(t *testing.T) { + tests := []struct { + desc string + http2 bool + // setup returns a recorder server mux, and a channel which sends the + // hash of the recording uploaded to it. The channel is expected to + // fire only once. + setup func(t *testing.T) (*http.ServeMux, <-chan []byte) + wantErr bool + }{ + { + desc: "v1 recorder", + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + uploadHash <- hash.Sum(nil) + }) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder", + http2: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + body := &readCounter{r: r.Body} + hash := sha256.New() + ctx, cancel := context.WithCancel(r.Context()) + go func() { + defer cancel() + if _, err := io.Copy(hash, body); err != nil { + t.Error(err) + } + }() + + // Send acks for received bytes. + tick := time.NewTicker(time.Millisecond) + defer tick.Stop() + enc := json.NewEncoder(w) + outer: + for { + select { + case <-ctx.Done(): + break outer + case <-tick.C: + if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil { + t.Errorf("writing ack frame: %v", err) + break outer + } + } + } + + uploadHash <- hash.Sum(nil) + }) + // Probing HEAD endpoint which always returns 200 OK. + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder no acks", + http2: true, + wantErr: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + // Make the client no-ack timeout quick for the test. + oldAckWindow := uploadAckWindow + uploadAckWindow = 100 * time.Millisecond + t.Cleanup(func() { uploadAckWindow = oldAckWindow }) + + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + // Consume the whole request body but don't send any acks + // back. + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + // Goes in the channel buffer, non-blocking. + uploadHash <- hash.Sum(nil) + + // Block until the parent test case ends to prevent the + // request termination. We want to exercise the ack + // tracking logic specifically. + ctx, cancel := context.WithCancel(r.Context()) + t.Cleanup(cancel) + <-ctx.Done() + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + mux, uploadHash := tt.setup(t) + + memNet := &memnet.Network{} + ln := memNet.NewLocalTCPListener() + + srv := &httptest.Server{ + Config: &http.Server{Handler: mux}, + Listener: ln, + } + + if tt.http2 { + // Wire up h2c-compatible HTTP/2 server. This is optional + // because the v1 recorder didn't support HTTP/2 and we try to + // mimic that. + s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, s) + if err := http2.ConfigureServer(srv.Config, s); err != nil { + t.Errorf("configuring HTTP/2 support in server: %v", err) + } + } + srv.Start() + t.Cleanup(srv.Close) + + ctx := context.Background() + w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(ln.Addr().String())}, memNet.Dial) + if err != nil { + t.Fatalf("ConnectToRecorder: %v", err) + } + + // Send some random data and hash it to compare with the recorded + // data hash. + hash := sha256.New() + const numBytes = 1 << 20 // 1MB + if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil { + t.Fatalf("writing recording data: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("closing recording stream: %v", err) + } + if err := <-errc; err != nil && !tt.wantErr { + t.Fatalf("error from the channel: %v", err) + } else if err == nil && tt.wantErr { + t.Fatalf("did not receive expected error from the channel") + } + + if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) { + t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv) + } + }) + } +} + +func TestSendEvent(t *testing.T) { + t.Run("supported", func(t *testing.T) { + eventBody := `{"foo":"bar"}` + eventRecieved := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("POST /v2/event", func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error(err) + } + eventRecieved <- body + w.WriteHeader(http.StatusOK) + }) + + srv := httptest.NewUnstartedServer(mux) + s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, s) + if err := http2.ConfigureServer(srv.Config, s); err != nil { + t.Fatalf("configuring HTTP/2 support in server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + + d := new(net.Dialer) + addr := netip.MustParseAddrPort(srv.Listener.Addr().String()) + err := SendEvent(addr, bytes.NewBufferString(eventBody), d.DialContext) + if err != nil { + t.Fatalf("SendEvent: %v", err) + } + + if recv := string(<-eventRecieved); recv != eventBody { + t.Errorf("mismatch in event body, sent %q, received %q", eventBody, recv) + } + }) + + t.Run("not_supported", func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }) + + srv := httptest.NewUnstartedServer(mux) + s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, s) + if err := http2.ConfigureServer(srv.Config, s); err != nil { + t.Fatalf("configuring HTTP/2 support in server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + + d := new(net.Dialer) + addr := netip.MustParseAddrPort(srv.Listener.Addr().String()) + err := SendEvent(addr, nil, d.DialContext) + if err == nil { + t.Fatal("expected an error, got nil") + } + if !strings.Contains(err.Error(), fmt.Sprintf(addressNotSupportEventv2, srv.Listener.Addr().String())) { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("server_error", func(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("HEAD /v2/event", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + mux.HandleFunc("POST /v2/event", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + + srv := httptest.NewUnstartedServer(mux) + s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, s) + if err := http2.ConfigureServer(srv.Config, s); err != nil { + t.Fatalf("configuring HTTP/2 support in server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + + d := new(net.Dialer) + addr := netip.MustParseAddrPort(srv.Listener.Addr().String()) + err := SendEvent(addr, nil, d.DialContext) + if err == nil { + t.Fatal("expected an error, got nil") + } + if !strings.Contains(err.Error(), "server returned non-OK status") { + t.Fatalf("unexpected error: %v", err) + } + }) +} diff --git a/sessionrecording/event.go b/sessionrecording/event.go new file mode 100644 index 000000000..8f8172cc4 --- /dev/null +++ b/sessionrecording/event.go @@ -0,0 +1,118 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sessionrecording + +import ( + "net/url" + + "tailscale.com/tailcfg" +) + +const ( + KubernetesAPIEventType = "kubernetes-api-request" +) + +// Event represents the top-level structure of a tsrecorder event. +type Event struct { + // Type specifies the kind of event being recorded (e.g., "kubernetes-api-request"). + Type string `json:"type"` + + // ID is a reference of the path that this event is stored at in tsrecorder + ID string `json:"id"` + + // Timestamp is the time when the event was recorded represented as a unix timestamp. + Timestamp int64 `json:"timestamp"` + + // UserAgent is the UerAgent specified in the request, which helps identify + // the client software that initiated the request. + UserAgent string `json:"userAgent"` + + // Request holds details of the HTTP request. + Request Request `json:"request"` + + // Kubernetes contains Kubernetes-specific information about the request (if + // the type is `kubernetes-api-request`) + Kubernetes KubernetesRequestInfo `json:"kubernetes"` + + // Source provides details about the client that initiated the request. + Source Source `json:"source"` + + // Destination provides details about the node receiving the request. + Destination Destination `json:"destination"` +} + +// copied from https://github.com/kubernetes/kubernetes/blob/11ade2f7dd264c2f52a4a1342458abbbaa3cb2b1/staging/src/k8s.io/apiserver/pkg/endpoints/request/requestinfo.go#L44 +// KubernetesRequestInfo contains Kubernetes specific information in the request (if the type is `kubernetes-api-request`) +type KubernetesRequestInfo struct { + // IsResourceRequest indicates whether or not the request is for an API resource or subresource + IsResourceRequest bool + // Path is the URL path of the request + Path string + // Verb is the kube verb associated with the request for API requests, not the http verb. This includes things like list and watch. + // for non-resource requests, this is the lowercase http verb + Verb string + + APIPrefix string + APIGroup string + APIVersion string + + Namespace string + // Resource is the name of the resource being requested. This is not the kind. For example: pods + Resource string + // Subresource is the name of the subresource being requested. This is a different resource, scoped to the parent resource, but it may have a different kind. + // For instance, /pods has the resource "pods" and the kind "Pod", while /pods/foo/status has the resource "pods", the sub resource "status", and the kind "Pod" + // (because status operates on pods). The binding resource for a pod though may be /pods/foo/binding, which has resource "pods", subresource "binding", and kind "Binding". + Subresource string + // Name is empty for some verbs, but if the request directly indicates a name (not in body content) then this field is filled in. + Name string + // Parts are the path parts for the request, always starting with /{resource}/{name} + Parts []string + + // FieldSelector contains the unparsed field selector from a request. It is only present if the apiserver + // honors field selectors for the verb this request is associated with. + FieldSelector string + // LabelSelector contains the unparsed field selector from a request. It is only present if the apiserver + // honors field selectors for the verb this request is associated with. + LabelSelector string +} + +type Source struct { + // Node is the FQDN of the node originating the connection. + // It is also the MagicDNS name for the node. + // It does not have a trailing dot. + // e.g. "host.tail-scale.ts.net" + Node string `json:"node"` + + // NodeID is the node ID of the node originating the connection. + NodeID tailcfg.StableNodeID `json:"nodeID"` + + // Tailscale-specific fields: + // NodeTags is the list of tags on the node originating the connection (if any). + NodeTags []string `json:"nodeTags,omitempty"` + + // NodeUserID is the user ID of the node originating the connection (if not tagged). + NodeUserID tailcfg.UserID `json:"nodeUserID,omitempty"` // if not tagged + + // NodeUser is the LoginName of the node originating the connection (if not tagged). + NodeUser string `json:"nodeUser,omitempty"` +} + +type Destination struct { + // Node is the FQDN of the node receiving the connection. + // It is also the MagicDNS name for the node. + // It does not have a trailing dot. + // e.g. "host.tail-scale.ts.net" + Node string `json:"node"` + + // NodeID is the node ID of the node receiving the connection. + NodeID tailcfg.StableNodeID `json:"nodeID"` +} + +// Request holds information about a request. +type Request struct { + Method string `json:"method"` + Path string `json:"path"` + Body []byte `json:"body"` + QueryParameters url.Values `json:"queryParameters"` +} diff --git a/sessionrecording/header.go b/sessionrecording/header.go index 4806f6585..220852216 100644 --- a/sessionrecording/header.go +++ b/sessionrecording/header.go @@ -62,17 +62,18 @@ type CastHeader struct { ConnectionID string `json:"connectionID"` // Fields that are only set for Kubernetes API server proxy session recordings: - Kubernetes *Kubernetes `json:"kubernetes,omitempty"` } -// Kubernetes contains 'kubectl exec' session specific information for +// Kubernetes contains 'kubectl exec/attach' session specific information for // tsrecorder. type Kubernetes struct { - // PodName is the name of the Pod being exec-ed. + // PodName is the name of the Pod the session was recorded for. PodName string - // Namespace is the namespace in which is the Pod that is being exec-ed. + // Namespace is the namespace in which the Pod the session was recorded for exists in. Namespace string - // Container is the container being exec-ed. + // Container is the container the session was recorded for. Container string + // SessionType is the type of session that was executed (e.g., exec, attach) + SessionType string } diff --git a/shell.nix b/shell.nix index 4d2e24366..ffb28a183 100644 --- a/shell.nix +++ b/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-xO1DuLWi6/lpA9ubA2ZYVJM+CkVNA5IaVGZxX9my0j0= +# nix-direnv cache busting line: sha256-sGPgML2YM/XNWfsAdDZvzWHagcydwCmR6nKOHJj5COs= diff --git a/smallzstd/testdata b/smallzstd/testdata deleted file mode 100644 index 76640fdc5..000000000 --- a/smallzstd/testdata +++ /dev/null @@ -1,14 +0,0 @@ -{"logtail":{"client_time":"2020-07-01T14:49:40.196597018-07:00","server_time":"2020-07-01T21:49:40.198371511Z"},"text":"9.8M/25.6M magicsock: starting endpoint update (periodic)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:40.345925455-07:00","server_time":"2020-07-01T21:49:40.347904717Z"},"text":"9.9M/25.6M netcheck: udp=true v6=false mapvarydest=false hair=false v4a=202.188.7.1:41641 derp=2 derpdist=1v4:7ms,2v4:3ms,4v4:18ms\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347155742-07:00","server_time":"2020-07-01T21:49:43.34828658Z"},"text":"9.9M/25.6M control: map response long-poll timed out!\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347539333-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.9M/25.6M control: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347767812-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M control: sendStatus: mapRoutine1: state:authenticated\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347817165-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M blockEngineUpdates(false)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.347989028-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"10.0M/25.6M wgcfg: [SViTM] skipping subnet route\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.349997554-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M Received error: PollNetMap: context canceled\n"} -{"logtail":{"client_time":"2020-07-01T14:49:43.350072606-07:00","server_time":"2020-07-01T21:49:43.358809354Z"},"text":"9.3M/25.6M control: mapRoutine: backoff: 30136 msec\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998364646-07:00","server_time":"2020-07-01T21:49:47.999333754Z"},"text":"9.5M/25.6M [W1NbE] - [UcppE] Send handshake init [127.3.3.40:1, 6.1.1.6:37388*, 10.3.2.6:41641]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.99881914-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: adding connection to derp-1 for [W1NbE]\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.998904932-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M magicsock: 2 active derp conns: derp-1=cr0s,wr0s derp-2=cr16h0m0s,wr14h38m0s\n"} -{"logtail":{"client_time":"2020-07-01T14:49:47.999045606-07:00","server_time":"2020-07-01T21:49:48.009859543Z"},"text":"9.6M/25.6M derphttp.Client.Recv: connecting to derp-1 (nyc)\n"} -{"logtail":{"client_time":"2020-07-01T14:49:48.091104119-07:00","server_time":"2020-07-01T21:49:48.09280535Z"},"text":"9.6M/25.6M magicsock: rx [W1NbE] from 6.1.1.6:37388 (1/3), set as new priority\n"} diff --git a/smallzstd/zstd.go b/smallzstd/zstd.go deleted file mode 100644 index 1d8085422..000000000 --- a/smallzstd/zstd.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package smallzstd produces zstd encoders and decoders optimized for -// low memory usage, at the expense of compression efficiency. -// -// This package is optimized primarily for the memory cost of -// compressing and decompressing data. We reduce this cost in two -// major ways: disable parallelism within the library (i.e. don't use -// multiple CPU cores to decompress), and drop the compression window -// down from the defaults of 4-16MiB, to 8kiB. -// -// Decompressors cost 2x the window size in RAM to run, so by using an -// 8kiB window, we can run ~1000 more decompressors per unit of memory -// than with the defaults. -// -// Depending on context, the benefit is either being able to run more -// decoders (e.g. in our logs processing system), or having a lower -// memory footprint when using compression in network protocols -// (e.g. in tailscaled, which should have a minimal RAM cost). -package smallzstd - -import ( - "io" - - "github.com/klauspost/compress/zstd" -) - -// WindowSize is the window size used for zstd compression. Decoder -// memory usage scales linearly with WindowSize. -const WindowSize = 8 << 10 // 8kiB - -// NewDecoder returns a zstd.Decoder configured for low memory usage, -// at the expense of decompression performance. -func NewDecoder(r io.Reader, options ...zstd.DOption) (*zstd.Decoder, error) { - defaults := []zstd.DOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithDecoderConcurrency(1), - // Default is to allocate more upfront for performance. We - // prefer lower memory use and a bit of GC load. - zstd.WithDecoderLowmem(true), - // You might expect to see zstd.WithDecoderMaxMemory - // here. However, it's not terribly safe to use if you're - // doing stateless decoding, because it sets the maximum - // amount of memory the decompressed data can occupy, rather - // than the window size of the zstd stream. This means a very - // compressible piece of data might violate the max memory - // limit here, even if the window size (and thus total memory - // required to decompress the data) is small. - // - // As a result, we don't set a decoder limit here, and rely on - // the encoder below producing "cheap" streams. Callers are - // welcome to set their own max memory setting, if - // contextually there is a clearly correct value (e.g. it's - // known from the upper layer protocol that the decoded data - // can never be more than 1MiB). - } - - return zstd.NewReader(r, append(defaults, options...)...) -} - -// NewEncoder returns a zstd.Encoder configured for low memory usage, -// both during compression and at decompression time, at the expense -// of performance and compression efficiency. -func NewEncoder(w io.Writer, options ...zstd.EOption) (*zstd.Encoder, error) { - defaults := []zstd.EOption{ - // Default is GOMAXPROCS, which costs many KiB in stacks. - zstd.WithEncoderConcurrency(1), - // Default is several MiB, which bloats both encoders and - // their corresponding decoders. - zstd.WithWindowSize(WindowSize), - // Encode zero-length inputs in a way that the `zstd` utility - // can read, because interoperability is handy. - zstd.WithZeroFrames(true), - } - - return zstd.NewWriter(w, append(defaults, options...)...) -} diff --git a/smallzstd/zstd_test.go b/smallzstd/zstd_test.go deleted file mode 100644 index d1225bfac..000000000 --- a/smallzstd/zstd_test.go +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package smallzstd - -import ( - "os" - "testing" - - "github.com/klauspost/compress/zstd" -) - -func BenchmarkSmallEncoder(b *testing.B) { - benchEncoder(b, func() (*zstd.Encoder, error) { return NewEncoder(nil) }) -} - -func BenchmarkSmallEncoderWithBuild(b *testing.B) { - benchEncoderWithConstruction(b, func() (*zstd.Encoder, error) { return NewEncoder(nil) }) -} - -func BenchmarkStockEncoder(b *testing.B) { - benchEncoder(b, func() (*zstd.Encoder, error) { return zstd.NewWriter(nil) }) -} - -func BenchmarkStockEncoderWithBuild(b *testing.B) { - benchEncoderWithConstruction(b, func() (*zstd.Encoder, error) { return zstd.NewWriter(nil) }) -} - -func BenchmarkSmallDecoder(b *testing.B) { - benchDecoder(b, func() (*zstd.Decoder, error) { return NewDecoder(nil) }) -} - -func BenchmarkSmallDecoderWithBuild(b *testing.B) { - benchDecoderWithConstruction(b, func() (*zstd.Decoder, error) { return NewDecoder(nil) }) -} - -func BenchmarkStockDecoder(b *testing.B) { - benchDecoder(b, func() (*zstd.Decoder, error) { return zstd.NewReader(nil) }) -} - -func BenchmarkStockDecoderWithBuild(b *testing.B) { - benchDecoderWithConstruction(b, func() (*zstd.Decoder, error) { return zstd.NewReader(nil) }) -} - -func benchEncoder(b *testing.B, mk func() (*zstd.Encoder, error)) { - b.ReportAllocs() - - in := testdata(b) - out := make([]byte, 0, 10<<10) // 10kiB - - e, err := mk() - if err != nil { - b.Fatalf("making encoder: %v", err) - } - - b.ResetTimer() - for range b.N { - e.EncodeAll(in, out) - } -} - -func benchEncoderWithConstruction(b *testing.B, mk func() (*zstd.Encoder, error)) { - b.ReportAllocs() - - in := testdata(b) - out := make([]byte, 0, 10<<10) // 10kiB - - b.ResetTimer() - for range b.N { - e, err := mk() - if err != nil { - b.Fatalf("making encoder: %v", err) - } - - e.EncodeAll(in, out) - } -} - -func benchDecoder(b *testing.B, mk func() (*zstd.Decoder, error)) { - b.ReportAllocs() - - in := compressedTestdata(b) - out := make([]byte, 0, 10<<10) - - d, err := mk() - if err != nil { - b.Fatalf("creating decoder: %v", err) - } - - b.ResetTimer() - for range b.N { - d.DecodeAll(in, out) - } -} - -func benchDecoderWithConstruction(b *testing.B, mk func() (*zstd.Decoder, error)) { - b.ReportAllocs() - - in := compressedTestdata(b) - out := make([]byte, 0, 10<<10) - - b.ResetTimer() - for range b.N { - d, err := mk() - if err != nil { - b.Fatalf("creating decoder: %v", err) - } - - d.DecodeAll(in, out) - } -} - -func testdata(b *testing.B) []byte { - b.Helper() - in, err := os.ReadFile("testdata") - if err != nil { - b.Fatalf("reading testdata: %v", err) - } - return in -} - -func compressedTestdata(b *testing.B) []byte { - b.Helper() - uncomp := testdata(b) - e, err := NewEncoder(nil) - if err != nil { - b.Fatalf("creating encoder: %v", err) - } - return e.EncodeAll(uncomp, nil) -} diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 7748376b2..f75646771 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -7,16 +7,18 @@ // and groups to the specified `--uid`, `--gid` and `--groups`, and // then launches the requested `--cmd`. -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd package tailssh import ( + "context" "encoding/json" "errors" "flag" "fmt" "io" + "io/fs" "log" "log/syslog" "os" @@ -29,6 +31,7 @@ import ( "strings" "sync/atomic" "syscall" + "time" "github.com/creack/pty" "github.com/pkg/sftp" @@ -43,6 +46,14 @@ import ( "tailscale.com/version/distro" ) +const ( + linux = "linux" + darwin = "darwin" + freebsd = "freebsd" + openbsd = "openbsd" + windows = "windows" +) + func init() { childproc.Add("ssh", beIncubator) childproc.Add("sftp", beSFTP) @@ -63,16 +74,50 @@ var maybeStartLoginSession = func(dlogf logger.Logf, ia incubatorArgs) (close fu return nil } +// truePaths are the common locations to find the true binary, in likelihood order. +var truePaths = [...]string{"/usr/bin/true", "/bin/true"} + +// tryExecInDir tries to run a command in dir and returns nil if it succeeds. +// Otherwise, it returns a filesystem error or a timeout error if the command +// took too long. +func tryExecInDir(ctx context.Context, dir string) error { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + run := func(path string) error { + cmd := exec.CommandContext(ctx, path) + cmd.Dir = dir + return cmd.Run() + } + + // Assume that the following executables exist, are executable, and + // immediately return. + if runtime.GOOS == windows { + windir := os.Getenv("windir") + return run(filepath.Join(windir, "system32", "doskey.exe")) + } + // Execute the first "true" we find in the list. + for _, path := range truePaths { + // Note: LookPath does not consult $PATH when passed multi-label paths. + if p, err := exec.LookPath(path); err == nil { + return run(p) + } + } + return exec.ErrNotFound +} + // newIncubatorCommand returns a new exec.Cmd configured with // `tailscaled be-child ssh` as the entrypoint. // -// If ss.srv.tailscaledPath is empty, this method is equivalent to -// exec.CommandContext. +// If ss.srv.tailscaledPath is empty, this method is almost equivalent to +// exec.CommandContext. It will refuse to run in SFTP-mode. It will simulate the +// behavior of SSHD when by falling back to the root directory if it cannot run +// a command in the user’s home directory. // // The returned Cmd.Env is guaranteed to be nil; the caller populates it. func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err error) { defer func() { - if cmd.Env != nil { + if cmd != nil && cmd.Env != nil { panic("internal error") } }() @@ -97,7 +142,35 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err loginShell := ss.conn.localUser.LoginShell() args := shellArgs(isShell, ss.RawCommand()) logf("directly running %s %q", loginShell, args) - return exec.CommandContext(ss.ctx, loginShell, args...), nil + cmd = exec.CommandContext(ss.ctx, loginShell, args...) + + // While running directly instead of using `tailscaled be-child`, + // do what sshd does by running inside the home directory, + // falling back to the root directory it doesn't have permissions. + // This can happen if the system has networked home directories, + // i.e. NFS or SMB, which enable root-squashing by default. + cmd.Dir = ss.conn.localUser.HomeDir + err := tryExecInDir(ss.ctx, cmd.Dir) + switch { + case errors.Is(err, exec.ErrNotFound): + // /bin/true might not be installed on a barebones system, + // so we assume that the home directory does not exist. + cmd.Dir = "/" + case errors.Is(err, fs.ErrPermission) || errors.Is(err, fs.ErrNotExist): + // Ensure that cmd.Dir is the source of the error. + var pathErr *fs.PathError + if errors.As(err, &pathErr) && pathErr.Path == cmd.Dir { + // If we cannot run loginShell in localUser.HomeDir, + // we will try to run this command in the root directory. + cmd.Dir = "/" + } else { + return nil, err + } + case err != nil: + return nil, err + } + + return cmd, nil } lu := ss.conn.localUser @@ -126,7 +199,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err // We have to check the below outside of the incubator process, because it // relies on the "getenforce" command being on the PATH, which it is not // when in the incubator. - if runtime.GOOS == "linux" && hostinfo.IsSELinuxEnforcing() { + if runtime.GOOS == linux && hostinfo.IsSELinuxEnforcing() { incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing") } @@ -171,7 +244,10 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err } } - return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil + cmd = exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...) + // The incubator will chdir into the home directory after it drops privileges. + cmd.Dir = "/" + return cmd, nil } var debugIncubator bool @@ -210,8 +286,6 @@ type incubatorArgs struct { debugTest bool isSELinuxEnforcing bool encodedEnv string - allowListEnvKeys string - forwardedEnviron []string } func parseIncubatorArgs(args []string) (incubatorArgs, error) { @@ -246,31 +320,47 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) { ia.gids = append(ia.gids, gid) } - ia.forwardedEnviron = os.Environ() + return ia, nil +} + +// forwardedEnviron returns the concatenation of the current environment with +// any environment variables specified in ia.encodedEnv. +// +// It also returns allowedExtraKeys, containing the env keys that were passed in +// to ia.encodedEnv. +func (ia incubatorArgs) forwardedEnviron() (env, allowedExtraKeys []string, err error) { + environ := os.Environ() + // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding - ia.allowListEnvKeys = "SSH_AUTH_SOCK" + // TODO(bradfitz,percy): why is this listed specially? If the parent wanted to included + // it, couldn't it have just passed it to the incubator in encodedEnv? + // If it didn't, no reason for us to pass it to "su -w ..." if it's not in our env + // anyway? (Surely we don't want to inherit the tailscaled parent SSH_AUTH_SOCK, if any) + allowedExtraKeys = []string{"SSH_AUTH_SOCK"} if ia.encodedEnv != "" { unquoted, err := strconv.Unquote(ia.encodedEnv) if err != nil { - return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + return nil, nil, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) } var extraEnviron []string err = json.Unmarshal([]byte(unquoted), &extraEnviron) if err != nil { - return ia, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + return nil, nil, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) } - ia.forwardedEnviron = append(ia.forwardedEnviron, extraEnviron...) + environ = append(environ, extraEnviron...) - for _, v := range extraEnviron { - ia.allowListEnvKeys = fmt.Sprintf("%s,%s", ia.allowListEnvKeys, strings.Split(v, "=")[0]) + for _, kv := range extraEnviron { + if k, _, ok := strings.Cut(kv, "="); ok { + allowedExtraKeys = append(allowedExtraKeys, k) + } } } - return ia, nil + return environ, allowedExtraKeys, nil } // beIncubator is the entrypoint to the `tailscaled be-child ssh` subcommand. @@ -426,13 +516,13 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error { // Only the macOS version of the login command supports executing a // command, all other versions only support launching a shell without // taking any arguments. - if !ia.isShell && runtime.GOOS != "darwin" { + if !ia.isShell && runtime.GOOS != darwin { dlogf("won't use login because we're not in a shell or on macOS") return nil } switch runtime.GOOS { - case "linux", "freebsd", "openbsd": + case linux, freebsd, openbsd: if !ia.hasTTY { dlogf("can't use login because of missing TTY") // We can only use the login command if a shell was requested with @@ -450,8 +540,13 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error { loginArgs := ia.loginArgs(loginCmdPath) dlogf("logging in with %+v", loginArgs) + environ, _, err := ia.forwardedEnviron() + if err != nil { + return err + } + // If Exec works, the Go code will not proceed past this: - err = unix.Exec(loginCmdPath, loginArgs, ia.forwardedEnviron) + err = unix.Exec(loginCmdPath, loginArgs, environ) // If we made it here, Exec failed. return err @@ -484,9 +579,14 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { defer sessionCloser() } + environ, allowListEnvKeys, err := ia.forwardedEnviron() + if err != nil { + return false, err + } + loginArgs := []string{ su, - "-w", ia.allowListEnvKeys, + "-w", strings.Join(allowListEnvKeys, ","), "-l", ia.localUser, } @@ -498,7 +598,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { dlogf("logging in with %+v", loginArgs) // If Exec works, the Go code will not proceed past this: - err = unix.Exec(su, loginArgs, ia.forwardedEnviron) + err = unix.Exec(su, loginArgs, environ) // If we made it here, Exec failed. return true, err @@ -511,7 +611,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { func findSU(dlogf logger.Logf, ia incubatorArgs) string { // Currently, we only support falling back to su on Linux. This // potentially could work on BSDs as well, but requires testing. - if runtime.GOOS != "linux" { + if runtime.GOOS != linux { return "" } @@ -527,11 +627,16 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string { return "" } + _, allowListEnvKeys, err := ia.forwardedEnviron() + if err != nil { + return "" + } + // First try to execute su -w -l -c true // to make sure su supports the necessary arguments. err = exec.Command( su, - "-w", ia.allowListEnvKeys, + "-w", strings.Join(allowListEnvKeys, ","), "-l", ia.localUser, "-c", "true", @@ -558,10 +663,15 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { return err } + environ, _, err := ia.forwardedEnviron() + if err != nil { + return err + } + args := shellArgs(ia.isShell, ia.cmd) dlogf("running %s %q", ia.loginShell, args) - cmd := newCommand(ia.hasTTY, ia.loginShell, ia.forwardedEnviron, args) - err := cmd.Run() + cmd := newCommand(ia.hasTTY, ia.loginShell, environ, args) + err = cmd.Run() if ee, ok := err.(*exec.ExitError); ok { ps := ee.ProcessState code := ps.ExitCode() @@ -637,7 +747,7 @@ func doDropPrivileges(dlogf logger.Logf, wantUid, wantGid int, supplementaryGrou euid := os.Geteuid() egid := os.Getegid() - if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" { + if runtime.GOOS == darwin || runtime.GOOS == freebsd { // On FreeBSD and Darwin, the first entry returned from the // getgroups(2) syscall is the egid, and changing it with // setgroups(2) changes the egid of the process. This is @@ -736,7 +846,6 @@ func (ss *sshSession) launchProcess() error { } cmd := ss.cmd - cmd.Dir = "/" cmd.Env = envForUser(ss.conn.localUser) for _, kv := range ss.Environ() { if acceptEnvPair(kv) { @@ -992,10 +1101,10 @@ func (ss *sshSession) startWithStdPipes() (err error) { func envForUser(u *userMeta) []string { return []string{ - fmt.Sprintf("SHELL=" + u.LoginShell()), - fmt.Sprintf("USER=" + u.Username), - fmt.Sprintf("HOME=" + u.HomeDir), - fmt.Sprintf("PATH=" + defaultPathForUser(&u.User)), + fmt.Sprintf("SHELL=%s", u.LoginShell()), + fmt.Sprintf("USER=%s", u.Username), + fmt.Sprintf("HOME=%s", u.HomeDir), + fmt.Sprintf("PATH=%s", defaultPathForUser(&u.User)), } } @@ -1029,7 +1138,7 @@ func fileExists(path string) bool { // loginArgs returns the arguments to use to exec the login binary. func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { switch runtime.GOOS { - case "darwin": + case darwin: args := []string{ loginCmdPath, "-f", // already authenticated @@ -1049,7 +1158,7 @@ func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { } return args - case "linux": + case linux: if distro.Get() == distro.Arch && !fileExists("/etc/pam.d/remote") { // See https://github.com/tailscale/tailscale/issues/4924 // @@ -1059,7 +1168,7 @@ func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { return []string{loginCmdPath, "-f", ia.localUser, "-p"} } return []string{loginCmdPath, "-f", ia.localUser, "-h", ia.remoteIP, "-p"} - case "freebsd", "openbsd": + case freebsd, openbsd: return []string{loginCmdPath, "-fp", "-h", ia.remoteIP, ia.localUser} } panic("unimplemented") @@ -1067,6 +1176,10 @@ func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { func shellArgs(isShell bool, cmd string) []string { if isShell { + if runtime.GOOS == freebsd || runtime.GOOS == openbsd { + // bsd shells don't support the "-l" option, so we can't run as a login shell + return []string{} + } return []string{"-l"} } else { return []string{"-c", cmd} @@ -1074,7 +1187,7 @@ func shellArgs(isShell bool, cmd string) []string { } func setGroups(groupIDs []int) error { - if runtime.GOOS == "darwin" && len(groupIDs) > 16 { + if runtime.GOOS == darwin && len(groupIDs) > 16 { // darwin returns "invalid argument" if more than 16 groups are passed to syscall.Setgroups // some info can be found here: // https://opensource.apple.com/source/samba/samba-187.8/patches/support-darwin-initgroups-syscall.auto.html diff --git a/ssh/tailssh/incubator_linux.go b/ssh/tailssh/incubator_linux.go index bcbe0e240..4dfb9f27c 100644 --- a/ssh/tailssh/incubator_linux.go +++ b/ssh/tailssh/incubator_linux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android package tailssh diff --git a/ssh/tailssh/incubator_plan9.go b/ssh/tailssh/incubator_plan9.go new file mode 100644 index 000000000..61b6a54eb --- /dev/null +++ b/ssh/tailssh/incubator_plan9.go @@ -0,0 +1,421 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains the plan9-specific version of the incubator. Tailscaled +// launches the incubator as the same user as it was launched as. The +// incubator then registers a new session with the OS, sets its UID +// and groups to the specified `--uid`, `--gid` and `--groups`, and +// then launches the requested `--cmd`. + +package tailssh + +import ( + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync/atomic" + + "github.com/go4org/plan9netshell" + "github.com/pkg/sftp" + "tailscale.com/cmd/tailscaled/childproc" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" +) + +func init() { + childproc.Add("ssh", beIncubator) + childproc.Add("sftp", beSFTP) + childproc.Add("plan9-netshell", beNetshell) +} + +// newIncubatorCommand returns a new exec.Cmd configured with +// `tailscaled be-child ssh` as the entrypoint. +// +// If ss.srv.tailscaledPath is empty, this method is equivalent to +// exec.CommandContext. +// +// The returned Cmd.Env is guaranteed to be nil; the caller populates it. +func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err error) { + defer func() { + if cmd.Env != nil { + panic("internal error") + } + }() + + var isSFTP, isShell bool + switch ss.Subsystem() { + case "sftp": + isSFTP = true + case "": + isShell = ss.RawCommand() == "" + default: + panic(fmt.Sprintf("unexpected subsystem: %v", ss.Subsystem())) + } + + if ss.conn.srv.tailscaledPath == "" { + if isSFTP { + // SFTP relies on the embedded Go-based SFTP server in tailscaled, + // so without tailscaled, we can't serve SFTP. + return nil, errors.New("no tailscaled found on path, can't serve SFTP") + } + + loginShell := ss.conn.localUser.LoginShell() + logf("directly running /bin/rc -c %q", ss.RawCommand()) + return exec.CommandContext(ss.ctx, loginShell, "-c", ss.RawCommand()), nil + } + + lu := ss.conn.localUser + ci := ss.conn.info + remoteUser := ci.uprof.LoginName + if ci.node.IsTagged() { + remoteUser = strings.Join(ci.node.Tags().AsSlice(), ",") + } + + incubatorArgs := []string{ + "be-child", + "ssh", + // TODO: "--uid=" + lu.Uid, + // TODO: "--gid=" + lu.Gid, + "--local-user=" + lu.Username, + "--home-dir=" + lu.HomeDir, + "--remote-user=" + remoteUser, + "--remote-ip=" + ci.src.Addr().String(), + "--has-tty=false", // updated in-place by startWithPTY + "--tty-name=", // updated in-place by startWithPTY + } + + nm := ss.conn.srv.lb.NetMap() + forceV1Behavior := nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) + if forceV1Behavior { + incubatorArgs = append(incubatorArgs, "--force-v1-behavior") + } + + if debugTest.Load() { + incubatorArgs = append(incubatorArgs, "--debug-test") + } + + switch { + case isSFTP: + // Note that we include both the `--sftp` flag and a command to launch + // tailscaled as `be-child sftp`. If login or su is available, and + // we're not running with tailcfg.NodeAttrSSHBehaviorV1, this will + // result in serving SFTP within a login shell, with full PAM + // integration. Otherwise, we'll serve SFTP in the incubator process + // with no PAM integration. + incubatorArgs = append(incubatorArgs, "--sftp", fmt.Sprintf("--cmd=%s be-child sftp", ss.conn.srv.tailscaledPath)) + case isShell: + incubatorArgs = append(incubatorArgs, "--shell") + default: + incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand()) + } + + allowSendEnv := nm.HasCap(tailcfg.NodeAttrSSHEnvironmentVariables) + if allowSendEnv { + env, err := filterEnv(ss.conn.acceptEnv, ss.Session.Environ()) + if err != nil { + return nil, err + } + + if len(env) > 0 { + encoded, err := json.Marshal(env) + if err != nil { + return nil, fmt.Errorf("failed to encode environment: %w", err) + } + incubatorArgs = append(incubatorArgs, fmt.Sprintf("--encoded-env=%q", encoded)) + } + } + + return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil +} + +var debugTest atomic.Bool + +type stdRWC struct{} + +func (stdRWC) Read(p []byte) (n int, err error) { + return os.Stdin.Read(p) +} + +func (stdRWC) Write(b []byte) (n int, err error) { + return os.Stdout.Write(b) +} + +func (stdRWC) Close() error { + os.Exit(0) + return nil +} + +type incubatorArgs struct { + localUser string + homeDir string + remoteUser string + remoteIP string + ttyName string + hasTTY bool + cmd string + isSFTP bool + isShell bool + forceV1Behavior bool + debugTest bool + isSELinuxEnforcing bool + encodedEnv string +} + +func parseIncubatorArgs(args []string) (incubatorArgs, error) { + var ia incubatorArgs + + flags := flag.NewFlagSet("", flag.ExitOnError) + flags.StringVar(&ia.localUser, "local-user", "", "the user to run as") + flags.StringVar(&ia.homeDir, "home-dir", "/", "the user's home directory") + flags.StringVar(&ia.remoteUser, "remote-user", "", "the remote user/tags") + flags.StringVar(&ia.remoteIP, "remote-ip", "", "the remote Tailscale IP") + flags.StringVar(&ia.ttyName, "tty-name", "", "the tty name (pts/3)") + flags.BoolVar(&ia.hasTTY, "has-tty", false, "is the output attached to a tty") + flags.StringVar(&ia.cmd, "cmd", "", "the cmd to launch, including all arguments (ignored in sftp mode)") + flags.BoolVar(&ia.isShell, "shell", false, "is launching a shell (with no cmds)") + flags.BoolVar(&ia.isSFTP, "sftp", false, "run sftp server (cmd is ignored)") + flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable") + flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode") + flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode") + flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format") + flags.Parse(args) + return ia, nil +} + +func (ia incubatorArgs) forwardedEnviron() ([]string, string, error) { + environ := os.Environ() + // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding + allowListKeys := "SSH_AUTH_SOCK" + + if ia.encodedEnv != "" { + unquoted, err := strconv.Unquote(ia.encodedEnv) + if err != nil { + return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + } + + var extraEnviron []string + + err = json.Unmarshal([]byte(unquoted), &extraEnviron) + if err != nil { + return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + } + + environ = append(environ, extraEnviron...) + + for _, v := range extraEnviron { + allowListKeys = fmt.Sprintf("%s,%s", allowListKeys, strings.Split(v, "=")[0]) + } + } + + return environ, allowListKeys, nil +} + +func beNetshell(args []string) error { + plan9netshell.Main() + return nil +} + +// beIncubator is the entrypoint to the `tailscaled be-child ssh` subcommand. +// It is responsible for informing the system of a new login session for the +// user. This is sometimes necessary for mounting home directories and +// decrypting file systems. +// +// Tailscaled launches the incubator as the same user as it was launched as. +func beIncubator(args []string) error { + // To defend against issues like https://golang.org/issue/1435, + // defensively lock our current goroutine's thread to the current + // system thread before we start making any UID/GID/group changes. + // + // This shouldn't matter on Linux because syscall.AllThreadsSyscall is + // used to invoke syscalls on all OS threads, but (as of 2023-03-23) + // that function is not implemented on all platforms. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + ia, err := parseIncubatorArgs(args) + if err != nil { + return err + } + if ia.isSFTP && ia.isShell { + return fmt.Errorf("--sftp and --shell are mutually exclusive") + } + + if ia.isShell { + plan9netshell.Main() + return nil + } + + dlogf := logger.Discard + if ia.debugTest { + // In testing, we don't always have syslog, so log to a temp file. + if logFile, err := os.OpenFile("/tmp/tailscalessh.log", os.O_APPEND|os.O_WRONLY, 0666); err == nil { + lf := log.New(logFile, "", 0) + dlogf = func(msg string, args ...any) { + lf.Printf(msg, args...) + logFile.Sync() + } + defer logFile.Close() + } + } + + return handleInProcess(dlogf, ia) +} + +func handleInProcess(dlogf logger.Logf, ia incubatorArgs) error { + if ia.isSFTP { + return handleSFTPInProcess(dlogf, ia) + } + return handleSSHInProcess(dlogf, ia) +} + +func handleSFTPInProcess(dlogf logger.Logf, ia incubatorArgs) error { + dlogf("handling sftp") + + return serveSFTP() +} + +// beSFTP serves SFTP in-process. +func beSFTP(args []string) error { + return serveSFTP() +} + +func serveSFTP() error { + server, err := sftp.NewServer(stdRWC{}) + if err != nil { + return err + } + // TODO(https://github.com/pkg/sftp/pull/554): Revert the check for io.EOF, + // when sftp is patched to report clean termination. + if err := server.Serve(); err != nil && err != io.EOF { + return err + } + return nil +} + +// handleSSHInProcess is a last resort if we couldn't use login or su. It +// registers a new session with the OS, sets its UID, GID and groups to the +// specified values, and then launches the requested `--cmd` in the user's +// login shell. +func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { + + environ, _, err := ia.forwardedEnviron() + if err != nil { + return err + } + + dlogf("running /bin/rc -c %q", ia.cmd) + cmd := newCommand("/bin/rc", environ, []string{"-c", ia.cmd}) + err = cmd.Run() + if ee, ok := err.(*exec.ExitError); ok { + ps := ee.ProcessState + code := ps.ExitCode() + if code < 0 { + // TODO(bradfitz): do we need to also check the syscall.WaitStatus + // and make our process look like it also died by signal/same signal + // as our child process? For now we just do the exit code. + fmt.Fprintf(os.Stderr, "[tailscale-ssh: process died: %v]\n", ps.String()) + code = 1 // for now. so we don't exit with negative + } + os.Exit(code) + } + return err +} + +func newCommand(cmdPath string, cmdEnviron []string, cmdArgs []string) *exec.Cmd { + cmd := exec.Command(cmdPath, cmdArgs...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = cmdEnviron + + return cmd +} + +// launchProcess launches an incubator process for the provided session. +// It is responsible for configuring the process execution environment. +// The caller can wait for the process to exit by calling cmd.Wait(). +// +// It sets ss.cmd, stdin, stdout, and stderr. +func (ss *sshSession) launchProcess() error { + var err error + ss.cmd, err = ss.newIncubatorCommand(ss.logf) + if err != nil { + return err + } + + cmd := ss.cmd + cmd.Dir = "/" + cmd.Env = append(os.Environ(), envForUser(ss.conn.localUser)...) + for _, kv := range ss.Environ() { + if acceptEnvPair(kv) { + cmd.Env = append(cmd.Env, kv) + } + } + + ci := ss.conn.info + cmd.Env = append(cmd.Env, + fmt.Sprintf("SSH_CLIENT=%s %d %d", ci.src.Addr(), ci.src.Port(), ci.dst.Port()), + fmt.Sprintf("SSH_CONNECTION=%s %d %s %d", ci.src.Addr(), ci.src.Port(), ci.dst.Addr(), ci.dst.Port()), + ) + + if ss.agentListener != nil { + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_AUTH_SOCK=%s", ss.agentListener.Addr())) + } + + return ss.startWithStdPipes() +} + +// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. +func (ss *sshSession) startWithStdPipes() (err error) { + var rdStdin, wrStdout, wrStderr io.ReadWriteCloser + defer func() { + if err != nil { + closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr) + } + }() + if ss.cmd == nil { + return errors.New("nil cmd") + } + if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil { + return err + } + if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil { + return err + } + if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil { + return err + } + ss.cmd.Stdin = rdStdin + ss.cmd.Stdout = wrStdout + ss.cmd.Stderr = wrStderr + ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr} + return ss.cmd.Start() +} + +func envForUser(u *userMeta) []string { + return []string{ + fmt.Sprintf("user=%s", u.Username), + fmt.Sprintf("home=%s", u.HomeDir), + fmt.Sprintf("path=%s", defaultPathForUser(&u.User)), + } +} + +// acceptEnvPair reports whether the environment variable key=value pair +// should be accepted from the client. It uses the same default as OpenSSH +// AcceptEnv. +func acceptEnvPair(kv string) bool { + k, _, ok := strings.Cut(kv, "=") + if !ok { + return false + } + _ = k + return true // permit anything on plan9 during bringup, for debugging at least +} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 9ade1847e..7d12ab45f 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 // Package tailssh is an SSH server integrated into Tailscale. package tailssh @@ -10,7 +10,6 @@ import ( "bytes" "context" "crypto/rand" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -30,10 +29,9 @@ import ( "syscall" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" "tailscale.com/envknob" "tailscale.com/ipn/ipnlocal" - "tailscale.com/logtail/backoff" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/sessionrecording" @@ -42,10 +40,10 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/netmap" + "tailscale.com/util/backoff" "tailscale.com/util/clientmetric" "tailscale.com/util/httpm" "tailscale.com/util/mak" - "tailscale.com/util/slicesx" ) var ( @@ -53,6 +51,11 @@ var ( sshDisableSFTP = envknob.RegisterBool("TS_SSH_DISABLE_SFTP") sshDisableForwarding = envknob.RegisterBool("TS_SSH_DISABLE_FORWARDING") sshDisablePTY = envknob.RegisterBool("TS_SSH_DISABLE_PTY") + + // errTerminal is an empty gossh.PartialSuccessError (with no 'Next' + // authentication methods that may proceed), which results in the SSH + // server immediately disconnecting the client. + errTerminal = &gossh.PartialSuccessError{} ) const ( @@ -80,16 +83,14 @@ type server struct { logf logger.Logf tailscaledPath string - pubKeyHTTPClient *http.Client // or nil for http.DefaultClient - timeNow func() time.Time // or nil for time.Now + timeNow func() time.Time // or nil for time.Now sessionWaitGroup sync.WaitGroup // mu protects the following - mu sync.Mutex - activeConns map[*conn]bool // set; value is always true - fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL - shutdownCalled bool + mu sync.Mutex + activeConns map[*conn]bool // set; value is always true + shutdownCalled bool } func (srv *server) now() time.Time { @@ -202,9 +203,11 @@ func (srv *server) OnPolicyChange() { // Setup and discover server info // - ServerConfigCallback // -// Do the user auth -// - NoClientAuthHandler -// - PublicKeyHandler (only if NoClientAuthHandler returns errPubKeyRequired) +// Get access to a ServerPreAuthConn (useful for sending banners) +// +// Do the user auth with a NoClientAuthCallback. If user specified +// a username ending in "+password", follow this with password auth +// (to work around buggy SSH clients that don't work with noauth). // // Once auth is done, the conn can be multiplexed with multiple sessions and // channels concurrently. At which point any of the following can be called @@ -224,20 +227,16 @@ type conn struct { idH string connID string // ID that's shared with control - // anyPasswordIsOkay is whether the client is authorized but has requested - // password-based auth to work around their buggy SSH client. When set, we - // accept any password in the PasswordHandler. - anyPasswordIsOkay bool // set by NoClientAuthCallback + // spac is a [gossh.ServerPreAuthConn] used for sending auth banners. + // Banners cannot be sent after auth completes. + spac gossh.ServerPreAuthConn - action0 *tailcfg.SSHAction // set by doPolicyAuth; first matching action - currentAction *tailcfg.SSHAction // set by doPolicyAuth, updated by resolveNextAction - finalAction *tailcfg.SSHAction // set by doPolicyAuth or resolveNextAction - finalActionErr error // set by doPolicyAuth or resolveNextAction + action0 *tailcfg.SSHAction // set by clientAuth + finalAction *tailcfg.SSHAction // set by clientAuth - info *sshConnInfo // set by setInfo - localUser *userMeta // set by doPolicyAuth - userGroupIDs []string // set by doPolicyAuth - pubKey gossh.PublicKey // set by doPolicyAuth + info *sshConnInfo // set by setInfo + localUser *userMeta // set by clientAuth + userGroupIDs []string // set by clientAuth acceptEnv []string // mu protects the following fields. @@ -260,172 +259,190 @@ func (c *conn) vlogf(format string, args ...any) { } } -// isAuthorized walks through the action chain and returns nil if the connection -// is authorized. If the connection is not authorized, it returns -// errDenied. If the action chain resolution fails, it returns the -// resolution error. -func (c *conn) isAuthorized(ctx ssh.Context) error { - action := c.currentAction - for { - if action.Accept { - if c.pubKey != nil { - metricPublicKeyAccepts.Add(1) - } - return nil - } - if action.Reject || action.HoldAndDelegate == "" { - return errDenied - } - var err error - action, err = c.resolveNextAction(ctx) - if err != nil { - return err - } - if action.Message != "" { - if err := ctx.SendAuthBanner(action.Message); err != nil { - return err - } - } - } -} - // errDenied is returned by auth callbacks when a connection is denied by the -// policy. -var errDenied = errors.New("ssh: access denied") - -// errPubKeyRequired is returned by NoClientAuthCallback to make the client -// resort to public-key auth; not user visible. -var errPubKeyRequired = errors.New("ssh publickey required") - -// NoClientAuthCallback implements gossh.NoClientAuthCallback and is called by -// the ssh.Server when the client first connects with the "none" -// authentication method. -// -// It is responsible for continuing policy evaluation from BannerCallback (or -// starting it afresh). It returns an error if the policy evaluation fails, or -// if the decision is "reject" -// -// It either returns nil (accept) or errPubKeyRequired or errDenied -// (reject). The errors may be wrapped. -func (c *conn) NoClientAuthCallback(ctx ssh.Context) error { - if c.insecureSkipTailscaleAuth { - return nil - } - if err := c.doPolicyAuth(ctx, nil /* no pub key */); err != nil { - return err +// policy. It writes the message to an auth banner and then returns an empty +// gossh.PartialSuccessError in order to stop processing authentication +// attempts and immediately disconnect the client. +func (c *conn) errDenied(message string) error { + if message == "" { + message = "tailscale: access denied" } - if err := c.isAuthorized(ctx); err != nil { - return err - } - - // Let users specify a username ending in +password to force password auth. - // This exists for buggy SSH clients that get confused by success from - // "none" auth. - if strings.HasSuffix(ctx.User(), forcePasswordSuffix) { - c.anyPasswordIsOkay = true - return errors.New("any password please") // not shown to users + if err := c.spac.SendAuthBanner(message); err != nil { + c.logf("failed to send auth banner: %s", err) } - return nil + return errTerminal } -func (c *conn) nextAuthMethodCallback(cm gossh.ConnMetadata, prevErrors []error) (nextMethod []string) { - switch { - case c.anyPasswordIsOkay: - nextMethod = append(nextMethod, "password") - case slicesx.LastEqual(prevErrors, errPubKeyRequired): - nextMethod = append(nextMethod, "publickey") +// errBanner writes the given message to an auth banner and then returns an +// empty gossh.PartialSuccessError in order to stop processing authentication +// attempts and immediately disconnect the client. The contents of err is not +// leaked in the auth banner, but it is logged to the server's log. +func (c *conn) errBanner(message string, err error) error { + if err != nil { + c.logf("%s: %s", message, err) } - - // The fake "tailscale" method is always appended to next so OpenSSH renders - // that in parens as the final failure. (It also shows up in "ssh -v", etc) - nextMethod = append(nextMethod, "tailscale") - return + if err := c.spac.SendAuthBanner("tailscale: " + message + "\n"); err != nil { + c.logf("failed to send auth banner: %s", err) + } + return errTerminal } -// fakePasswordHandler is our implementation of the PasswordHandler hook that -// checks whether the user's password is correct. But we don't actually use -// passwords. This exists only for when the user's username ends in "+password" -// to signal that their SSH client is buggy and gets confused by auth type -// "none" succeeding and they want our SSH server to require a dummy password -// prompt instead. We then accept any password since we've already authenticated -// & authorized them. -func (c *conn) fakePasswordHandler(ctx ssh.Context, password string) bool { - return c.anyPasswordIsOkay +// errUnexpected is returned by auth callbacks that encounter an unexpected +// error, such as being unable to send an auth banner. It sends an empty +// gossh.PartialSuccessError to tell gossh.Server to stop processing +// authentication attempts and instead disconnect immediately. +func (c *conn) errUnexpected(err error) error { + c.logf("terminal error: %s", err) + return errTerminal } -// PublicKeyHandler implements ssh.PublicKeyHandler is called by the -// ssh.Server when the client presents a public key. -func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { - if err := c.doPolicyAuth(ctx, pubKey); err != nil { - // TODO(maisem/bradfitz): surface the error here. - c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err) - return err +// clientAuth is responsible for performing client authentication. +// +// If policy evaluation fails, it returns an error. +// If access is denied, it returns an error. This must always be an empty +// gossh.PartialSuccessError to prevent further authentication methods from +// being tried. +func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retErr error) { + defer func() { + if pse, ok := retErr.(*gossh.PartialSuccessError); ok { + if pse.Next.GSSAPIWithMICConfig != nil || + pse.Next.KeyboardInteractiveCallback != nil || + pse.Next.PasswordCallback != nil || + pse.Next.PublicKeyCallback != nil { + panic("clientAuth attempted to return a non-empty PartialSuccessError") + } + } else if retErr != nil { + panic(fmt.Sprintf("clientAuth attempted to return a non-PartialSuccessError error of type: %t", retErr)) + } + }() + + if c.insecureSkipTailscaleAuth { + return &gossh.Permissions{}, nil } - if err := c.isAuthorized(ctx); err != nil { - return err + + if err := c.setInfo(cm); err != nil { + return nil, c.errBanner("failed to get connection info", err) } - c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey))) - return nil -} -// doPolicyAuth verifies that conn can proceed with the specified (optional) -// pubKey. It returns nil if the matching policy action is Accept or -// HoldAndDelegate. If pubKey is nil, there was no policy match but there is a -// policy that might match a public key it returns errPubKeyRequired. Otherwise, -// it returns errDenied. -func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error { - if err := c.setInfo(ctx); err != nil { - c.logf("failed to get conninfo: %v", err) - return errDenied - } - a, localUser, acceptEnv, err := c.evaluatePolicy(pubKey) - if err != nil { - if pubKey == nil && c.havePubKeyPolicy() { - return errPubKeyRequired - } - return fmt.Errorf("%w: %v", errDenied, err) - } - c.action0 = a - c.currentAction = a - c.pubKey = pubKey - c.acceptEnv = acceptEnv - if a.Message != "" { - if err := ctx.SendAuthBanner(a.Message); err != nil { - return fmt.Errorf("SendBanner: %w", err) - } + action, localUser, acceptEnv, result := c.evaluatePolicy() + switch result { + case accepted: + // do nothing + case rejectedUser: + return nil, c.errBanner(fmt.Sprintf("tailnet policy does not permit you to SSH as user %q", c.info.sshUser), nil) + case rejected, noPolicy: + return nil, c.errBanner("tailnet policy does not permit you to SSH to this node", fmt.Errorf("failed to evaluate policy, result: %s", result)) + default: + return nil, c.errBanner("failed to evaluate tailnet policy", fmt.Errorf("failed to evaluate policy, result: %s", result)) } - if a.Accept || a.HoldAndDelegate != "" { - if a.Accept { - c.finalAction = a - } + + c.action0 = action + + if action.Accept || action.HoldAndDelegate != "" { + // Immediately look up user information for purposes of generating + // hold and delegate URL (if necessary). lu, err := userLookup(localUser) if err != nil { - c.logf("failed to look up %v: %v", localUser, err) - ctx.SendAuthBanner(fmt.Sprintf("failed to look up %v\r\n", localUser)) - return err + return nil, c.errBanner(fmt.Sprintf("failed to look up local user %q ", localUser), err) } gids, err := lu.GroupIds() if err != nil { - c.logf("failed to look up local user's group IDs: %v", err) - return err + return nil, c.errBanner("failed to look up local user's group IDs", err) } c.userGroupIDs = gids c.localUser = lu - return nil + c.acceptEnv = acceptEnv } - if a.Reject { - c.finalAction = a - return errDenied + + for { + switch { + case action.Accept: + metricTerminalAccept.Add(1) + if action.Message != "" { + if err := c.spac.SendAuthBanner(action.Message); err != nil { + return nil, c.errUnexpected(fmt.Errorf("error sending auth welcome message: %w", err)) + } + } + c.finalAction = action + return &gossh.Permissions{}, nil + case action.Reject: + metricTerminalReject.Add(1) + c.finalAction = action + return nil, c.errDenied(action.Message) + case action.HoldAndDelegate != "": + if action.Message != "" { + if err := c.spac.SendAuthBanner(action.Message); err != nil { + return nil, c.errUnexpected(fmt.Errorf("error sending hold and delegate message: %w", err)) + } + } + + url := action.HoldAndDelegate + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + metricHolds.Add(1) + url = c.expandDelegateURLLocked(url) + + var err error + action, err = c.fetchSSHAction(ctx, url) + if err != nil { + metricTerminalFetchError.Add(1) + return nil, c.errBanner("failed to fetch next SSH action", fmt.Errorf("fetch failed from %s: %w", url, err)) + } + default: + metricTerminalMalformed.Add(1) + return nil, c.errBanner("reached Action that had neither Accept, Reject, nor HoldAndDelegate", nil) + } } - // Shouldn't get here, but: - return errDenied } // ServerConfig implements ssh.ServerConfigCallback. func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ - NoClientAuth: true, // required for the NoClientAuthCallback to run - NextAuthMethodCallback: c.nextAuthMethodCallback, + PreAuthConnCallback: func(spac gossh.ServerPreAuthConn) { + c.spac = spac + }, + NoClientAuth: true, // required for the NoClientAuthCallback to run + NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + // First perform client authentication, which can potentially + // involve multiple steps (for example prompting user to log in to + // Tailscale admin panel to confirm identity). + perms, err := c.clientAuth(cm) + if err != nil { + return nil, err + } + + // Authentication succeeded. Buggy SSH clients get confused by + // success from the "none" auth method. As a workaround, let users + // specify a username ending in "+password" to force password auth. + // The actual value of the password doesn't matter. + if strings.HasSuffix(cm.User(), forcePasswordSuffix) { + return nil, &gossh.PartialSuccessError{ + Next: gossh.ServerAuthCallbacks{ + PasswordCallback: func(_ gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + return &gossh.Permissions{}, nil + }, + }, + } + } + + return perms, nil + }, + PasswordCallback: func(cm gossh.ConnMetadata, pword []byte) (*gossh.Permissions, error) { + // Some clients don't request 'none' authentication. Instead, they + // immediately supply a password. We humor them by accepting the + // password, but authenticate as usual, ignoring the actual value of + // the password. + return c.clientAuth(cm) + }, + PublicKeyCallback: func(cm gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + // Some clients don't request 'none' authentication. Instead, they + // immediately supply a public key. We humor them by accepting the + // key, but authenticate as usual, ignoring the actual content of + // the key. + return c.clientAuth(cm) + }, } } @@ -436,7 +453,7 @@ func (srv *server) newConn() (*conn, error) { // Stop accepting new connections. // Connections in the auth phase are handled in handleConnPostSSHAuth. // Existing sessions are terminated by Shutdown. - return nil, errDenied + return nil, errors.New("server is shutting down") } srv.mu.Unlock() c := &conn{srv: srv} @@ -447,10 +464,6 @@ func (srv *server) newConn() (*conn, error) { Version: "Tailscale", ServerConfigCallback: c.ServerConfig, - NoClientAuthHandler: c.NoClientAuthCallback, - PublicKeyHandler: c.PublicKeyHandler, - PasswordHandler: c.fakePasswordHandler, - Handler: c.handleSessionPostSSHAuth, LocalPortForwardingCallback: c.mayForwardLocalPortTo, ReversePortForwardingCallback: c.mayReversePortForwardTo, @@ -516,34 +529,6 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de return false } -// havePubKeyPolicy reports whether any policy rule may provide access by means -// of a ssh.PublicKey. -func (c *conn) havePubKeyPolicy() bool { - if c.info == nil { - panic("havePubKeyPolicy called before setInfo") - } - // Is there any rule that looks like it'd require a public key for this - // sshUser? - pol, ok := c.sshPolicy() - if !ok { - return false - } - for _, r := range pol.Rules { - if c.ruleExpired(r) { - continue - } - if mapLocalUser(r.SSHUsers, c.info.sshUser) == "" { - continue - } - for _, p := range r.Principals { - if len(p.PubKeys) > 0 && c.principalMatchesTailscaleIdentity(p) { - return true - } - } - } - return false -} - // sshPolicy returns the SSHPolicy for current node. // If there is no SSHPolicy in the netmap, it returns a debugPolicy // if one is defined. @@ -589,16 +574,16 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { return netip.AddrPortFrom(tanetaddr.Unmap(), uint16(ta.Port)) } -// connInfo returns a populated sshConnInfo from the provided arguments, +// connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. -func (c *conn) setInfo(ctx ssh.Context) error { +func (c *conn) setInfo(cm gossh.ConnMetadata) error { if c.info != nil { return nil } ci := &sshConnInfo{ - sshUser: strings.TrimSuffix(ctx.User(), forcePasswordSuffix), - src: toIPPort(ctx.RemoteAddr()), - dst: toIPPort(ctx.LocalAddr()), + sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix), + src: toIPPort(cm.RemoteAddr()), + dst: toIPPort(cm.LocalAddr()), } if !tsaddr.IsTailscaleIP(ci.dst.Addr()) { return fmt.Errorf("tailssh: rejecting non-Tailscale local address %v", ci.dst) @@ -613,122 +598,29 @@ func (c *conn) setInfo(ctx ssh.Context) error { ci.node = node ci.uprof = uprof - c.idH = ctx.SessionID() + c.idH = string(cm.SessionID()) c.info = ci c.logf("handling conn: %v", ci.String()) return nil } -// evaluatePolicy returns the SSHAction and localUser after evaluating -// the SSHPolicy for this conn. The pubKey may be nil for "none" auth. -func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) { - pol, ok := c.sshPolicy() - if !ok { - return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no SSH policy") - } - a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol, pubKey) - if !ok { - return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no matching policy") - } - return a, localUser, acceptEnv, nil -} - -// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like -// "https://github.com/foo.keys") -type pubKeyCacheEntry struct { - lines []string - etag string // if sent by server - at time.Time -} +type evalResult string const ( - pubKeyCacheDuration = time.Minute // how long to cache non-empty public keys - pubKeyCacheEmptyDuration = 15 * time.Second // how long to cache empty responses + noPolicy evalResult = "no policy" + rejected evalResult = "rejected" + rejectedUser evalResult = "rejected user" + accepted evalResult = "accept" ) -func (srv *server) fetchPublicKeysURLCached(url string) (ce pubKeyCacheEntry, ok bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - // Mostly don't care about the size of this cache. Clean rarely. - if m := srv.fetchPublicKeysCache; len(m) > 50 { - tooOld := srv.now().Add(pubKeyCacheDuration * 10) - for k, ce := range m { - if ce.at.Before(tooOld) { - delete(m, k) - } - } - } - ce, ok = srv.fetchPublicKeysCache[url] +// evaluatePolicy returns the SSHAction and localUser after evaluating +// the SSHPolicy for this conn. +func (c *conn) evaluatePolicy() (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, result evalResult) { + pol, ok := c.sshPolicy() if !ok { - return ce, false - } - maxAge := pubKeyCacheDuration - if len(ce.lines) == 0 { - maxAge = pubKeyCacheEmptyDuration - } - return ce, srv.now().Sub(ce.at) < maxAge -} - -func (srv *server) pubKeyClient() *http.Client { - if srv.pubKeyHTTPClient != nil { - return srv.pubKeyHTTPClient + return nil, "", nil, noPolicy } - return http.DefaultClient -} - -// fetchPublicKeysURL fetches the public keys from a URL. The strings are in the -// the typical public key "type base64-string [comment]" format seen at e.g. -// https://github.com/USER.keys -func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { - if !strings.HasPrefix(url, "https://") { - return nil, errors.New("invalid URL scheme") - } - - ce, ok := srv.fetchPublicKeysURLCached(url) - if ok { - return ce.lines, nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return nil, err - } - if ce.etag != "" { - req.Header.Add("If-None-Match", ce.etag) - } - res, err := srv.pubKeyClient().Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - var lines []string - var etag string - switch res.StatusCode { - default: - err = fmt.Errorf("unexpected status %v", res.Status) - srv.logf("fetching public keys from %s: %v", url, err) - case http.StatusNotModified: - lines = ce.lines - etag = ce.etag - case http.StatusOK: - var all []byte - all, err = io.ReadAll(io.LimitReader(res.Body, 4<<10)) - if s := strings.TrimSpace(string(all)); s != "" { - lines = strings.Split(s, "\n") - } - etag = res.Header.Get("Etag") - } - - srv.mu.Lock() - defer srv.mu.Unlock() - mak.Set(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{ - at: srv.now(), - lines: lines, - etag: etag, - }) - return lines, err + return c.evalSSHPolicy(pol) } // handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication, @@ -758,62 +650,6 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { ss.run() } -// resolveNextAction starts at c.currentAction and makes it way through the -// action chain one step at a time. An action without a HoldAndDelegate is -// considered the final action. Once a final action is reached, this function -// will keep returning that action. It updates c.currentAction to the next -// action in the chain. When the final action is reached, it also sets -// c.finalAction to the final action. -func (c *conn) resolveNextAction(sctx ssh.Context) (action *tailcfg.SSHAction, err error) { - if c.finalAction != nil || c.finalActionErr != nil { - return c.finalAction, c.finalActionErr - } - - defer func() { - if action != nil { - c.currentAction = action - if action.Accept || action.Reject { - c.finalAction = action - } - } - if err != nil { - c.finalActionErr = err - } - }() - - ctx, cancel := context.WithCancel(sctx) - defer cancel() - - // Loop processing/fetching Actions until one reaches a - // terminal state (Accept, Reject, or invalid Action), or - // until fetchSSHAction times out due to the context being - // done (client disconnect) or its 30 minute timeout passes. - // (Which is a long time for somebody to see login - // instructions and go to a URL to do something.) - action = c.currentAction - if action.Accept || action.Reject { - if action.Reject { - metricTerminalReject.Add(1) - } else { - metricTerminalAccept.Add(1) - } - return action, nil - } - url := action.HoldAndDelegate - if url == "" { - metricTerminalMalformed.Add(1) - return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") - } - metricHolds.Add(1) - url = c.expandDelegateURLLocked(url) - nextAction, err := c.fetchSSHAction(ctx, url) - if err != nil { - metricTerminalFetchError.Add(1) - return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) - } - return nextAction, nil -} - func (c *conn) expandDelegateURLLocked(actionURL string) string { nm := c.srv.lb.NetMap() ci := c.info @@ -832,18 +668,6 @@ func (c *conn) expandDelegateURLLocked(actionURL string) string { ).Replace(actionURL) } -func (c *conn) expandPublicKeyURL(pubKeyURL string) string { - if !strings.Contains(pubKeyURL, "$") { - return pubKeyURL - } - loginName := c.info.uprof.LoginName - localPart, _, _ := strings.Cut(loginName, "@") - return strings.NewReplacer( - "$LOGINNAME_EMAIL", loginName, - "$LOGINNAME_LOCALPART", localPart, - ).Replace(pubKeyURL) -} - // sshSession is an accepted Tailscale SSH session. type sshSession struct { ssh.Session @@ -894,9 +718,9 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession { // isStillValid reports whether the conn is still valid. func (c *conn) isStillValid() bool { - a, localUser, _, err := c.evaluatePolicy(c.pubKey) - c.vlogf("stillValid: %+v %v %v", a, localUser, err) - if err != nil { + a, localUser, _, result := c.evaluatePolicy() + c.vlogf("stillValid: %+v %v %v", a, localUser, result) + if result != accepted { return false } if !a.Accept && a.HoldAndDelegate == "" { @@ -1091,7 +915,7 @@ func (ss *sshSession) run() { defer t.Stop() } - if euid := os.Geteuid(); euid != 0 { + if euid := os.Geteuid(); euid != 0 && runtime.GOOS != "plan9" { if lu.Uid != fmt.Sprint(euid) { ss.logf("can't switch to user %q from process euid %v", lu.Username, euid) fmt.Fprintf(ss, "can't switch user\r\n") @@ -1170,7 +994,7 @@ func (ss *sshSession) run() { if err != nil && !errors.Is(err, io.EOF) { isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) if !isErrBecauseProcessExited { - logf("stdout copy: %v, %T", err) + logf("stdout copy: %v", err) ss.cancelCtx(err) } } @@ -1277,13 +1101,20 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool { return r.RuleExpires.Before(c.srv.now()) } -func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) { +func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, result evalResult) { + failedOnUser := false for _, r := range pol.Rules { - if a, localUser, acceptEnv, err := c.matchRule(r, pubKey); err == nil { - return a, localUser, acceptEnv, true + if a, localUser, acceptEnv, err := c.matchRule(r); err == nil { + return a, localUser, acceptEnv, accepted + } else if errors.Is(err, errUserMatch) { + failedOnUser = true } } - return nil, "", nil, false + result = rejected + if failedOnUser { + result = rejectedUser + } + return nil, "", nil, result } // internal errors for testing; they don't escape to callers or logs. @@ -1296,7 +1127,7 @@ var ( errInvalidConn = errors.New("invalid connection state") ) -func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) { +func (c *conn) matchRule(r *tailcfg.SSHRule) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) { defer func() { c.vlogf("matchRule(%+v): %v", r, err) }() @@ -1317,6 +1148,9 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg if c.ruleExpired(r) { return nil, "", nil, errRuleExpired } + if !c.anyPrincipalMatches(r.Principals) { + return nil, "", nil, errPrincipalMatch + } if !r.Action.Reject { // For all but Reject rules, SSHUsers is required. // If SSHUsers is nil or empty, mapLocalUser will return an @@ -1326,11 +1160,6 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg return nil, "", nil, errUserMatch } } - if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil { - return nil, "", nil, err - } else if !ok { - return nil, "", nil, errPrincipalMatch - } return r.Action, localUser, r.AcceptEnv, nil } @@ -1345,30 +1174,20 @@ func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser return v } -func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal, pubKey gossh.PublicKey) (bool, error) { +func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool { for _, p := range ps { if p == nil { continue } - if ok, err := c.principalMatches(p, pubKey); err != nil { - return false, err - } else if ok { - return true, nil + if c.principalMatchesTailscaleIdentity(p) { + return true } } - return false, nil -} - -func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey) (bool, error) { - if !c.principalMatchesTailscaleIdentity(p) { - return false, nil - } - return c.principalMatchesPubKey(p, pubKey) + return false } // principalMatchesTailscaleIdentity reports whether one of p's four fields // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). -// This function does not consider PubKeys. func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { ci := c.info if p.Any { @@ -1388,42 +1207,6 @@ func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { return false } -func (c *conn) principalMatchesPubKey(p *tailcfg.SSHPrincipal, clientPubKey gossh.PublicKey) (bool, error) { - if len(p.PubKeys) == 0 { - return true, nil - } - if clientPubKey == nil { - return false, nil - } - knownKeys := p.PubKeys - if len(knownKeys) == 1 && strings.HasPrefix(knownKeys[0], "https://") { - var err error - knownKeys, err = c.srv.fetchPublicKeysURL(c.expandPublicKeyURL(knownKeys[0])) - if err != nil { - return false, err - } - } - for _, knownKey := range knownKeys { - if pubKeyMatchesAuthorizedKey(clientPubKey, knownKey) { - return true, nil - } - } - return false, nil -} - -func pubKeyMatchesAuthorizedKey(pubKey ssh.PublicKey, wantKey string) bool { - wantKeyType, rest, ok := strings.Cut(wantKey, " ") - if !ok { - return false - } - if pubKey.Type() != wantKeyType { - return false - } - wantKeyB64, _, _ := strings.Cut(rest, " ") - wantKeyData, _ := base64.StdEncoding.DecodeString(wantKeyB64) - return len(wantKeyData) > 0 && bytes.Equal(pubKey.Marshal(), wantKeyData) -} - func randBytes(n int) []byte { b := make([]byte, n) if _, err := rand.Read(b); err != nil { @@ -1520,9 +1303,14 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { go func() { err := <-errChan if err == nil { - // Success. - ss.logf("recording: finished uploading recording") - return + select { + case <-ss.ctx.Done(): + // Success. + ss.logf("recording: finished uploading recording") + return + default: + err = errors.New("recording upload ended before the SSH session") + } } if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 { lastAttempt := attempts[len(attempts)-1] @@ -1744,7 +1532,6 @@ func envEq(a, b string) bool { var ( metricActiveSessions = clientmetric.NewGauge("ssh_active_sessions") metricIncomingConnections = clientmetric.NewCounter("ssh_incoming_connections") - metricPublicKeyAccepts = clientmetric.NewCounter("ssh_publickey_accepts") // accepted subset of ssh_publickey_connections metricTerminalAccept = clientmetric.NewCounter("ssh_terminalaction_accept") metricTerminalReject = clientmetric.NewCounter("ssh_terminalaction_reject") metricTerminalMalformed = clientmetric.NewCounter("ssh_terminalaction_malformed") diff --git a/ssh/tailssh/tailssh_integration_test.go b/ssh/tailssh/tailssh_integration_test.go index 1799d3400..9ab26e169 100644 --- a/ssh/tailssh/tailssh_integration_test.go +++ b/ssh/tailssh/tailssh_integration_test.go @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause //go:build integrationtest -// +build integrationtest package tailssh @@ -32,8 +31,8 @@ import ( "github.com/bramvdbogaerde/go-scp" "github.com/google/go-cmp/cmp" "github.com/pkg/sftp" - gossh "github.com/tailscale/golang-x-crypto/ssh" "golang.org/x/crypto/ssh" + gossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -410,6 +409,48 @@ func TestSSHAgentForwarding(t *testing.T) { } } +// TestIntegrationParamiko attempts to connect to Tailscale SSH using the +// paramiko Python library. This library does not request 'none' auth. This +// test ensures that Tailscale SSH can correctly handle clients that don't +// request 'none' auth and instead immediately authenticate with a public key +// or password. +func TestIntegrationParamiko(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { + debugTest.Store(false) + }) + + addr := testServer(t, "testuser", true, false) + host, port, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("Failed to split addr %q: %s", addr, err) + } + + out, err := exec.Command("python3", "-c", fmt.Sprintf(` +import paramiko.client as pm +from paramiko.ecdsakey import ECDSAKey +client = pm.SSHClient() +client.set_missing_host_key_policy(pm.AutoAddPolicy) +client.connect('%s', port=%s, username='testuser', pkey=ECDSAKey.generate(), allow_agent=False, look_for_keys=False) +client.exec_command('pwd') +`, host, port)).CombinedOutput() + if err != nil { + t.Fatalf("failed to connect with Paramiko using public key auth: %s\n%q", err, string(out)) + } + + out, err = exec.Command("python3", "-c", fmt.Sprintf(` +import paramiko.client as pm +from paramiko.ecdsakey import ECDSAKey +client = pm.SSHClient() +client.set_missing_host_key_policy(pm.AutoAddPolicy) +client.connect('%s', port=%s, username='testuser', password='doesntmatter', allow_agent=False, look_for_keys=False) +client.exec_command('pwd') +`, host, port)).CombinedOutput() + if err != nil { + t.Fatalf("failed to connect with Paramiko using password auth: %s\n%q", err, string(out)) + } +} + func fallbackToSUAvailable() bool { if runtime.GOOS != "linux" { return false diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 9e4f5ffd3..3b6d3c52c 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -8,9 +8,10 @@ package tailssh import ( "bytes" "context" + "crypto/ecdsa" "crypto/ed25519" + "crypto/elliptic" "crypto/rand" - "crypto/sha256" "encoding/json" "errors" "fmt" @@ -32,7 +33,10 @@ import ( "testing" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" + "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/memnet" @@ -40,15 +44,15 @@ import ( "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" + testssh "tailscale.com/tempfork/sshtest/ssh" "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/key" - "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/ptr" "tailscale.com/util/cibuild" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/must" "tailscale.com/version/distro" "tailscale.com/wgengine" @@ -225,9 +229,9 @@ func TestMatchRule(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, - srv: &server{logf: t.Logf}, + srv: &server{logf: tstest.WhileTestRunningLogger(t)}, } - got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil) + got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule) if err != tt.wantErr { t.Errorf("err = %v; want %v", err, tt.wantErr) } @@ -250,7 +254,7 @@ func TestEvalSSHPolicy(t *testing.T) { name string policy *tailcfg.SSHPolicy ci *sshConnInfo - wantMatch bool + wantResult evalResult wantUser string wantAcceptEnv []string }{ @@ -296,10 +300,20 @@ func TestEvalSSHPolicy(t *testing.T) { ci: &sshConnInfo{sshUser: "alice"}, wantUser: "thealice", wantAcceptEnv: []string{"EXAMPLE", "?_?", "TEST_*"}, - wantMatch: true, + wantResult: accepted, }, { - name: "no-matches-returns-failure", + name: "no-matches-returns-rejected", + policy: &tailcfg.SSHPolicy{ + Rules: []*tailcfg.SSHRule{}, + }, + ci: &sshConnInfo{sshUser: "alice"}, + wantUser: "", + wantAcceptEnv: nil, + wantResult: rejected, + }, + { + name: "no-user-matches-returns-rejected-user", policy: &tailcfg.SSHPolicy{ Rules: []*tailcfg.SSHRule{ { @@ -337,23 +351,23 @@ func TestEvalSSHPolicy(t *testing.T) { ci: &sshConnInfo{sshUser: "alice"}, wantUser: "", wantAcceptEnv: nil, - wantMatch: false, + wantResult: rejectedUser, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, - srv: &server{logf: t.Logf}, + srv: &server{logf: tstest.WhileTestRunningLogger(t)}, } - got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil) - if match != tt.wantMatch { - t.Errorf("match = %v; want %v", match, tt.wantMatch) + got, gotUser, gotAcceptEnv, result := c.evalSSHPolicy(tt.policy) + if result != tt.wantResult { + t.Errorf("result = %v; want %v", result, tt.wantResult) } if gotUser != tt.wantUser { t.Errorf("user = %q; want %q", gotUser, tt.wantUser) } - if tt.wantMatch == true && got == nil { + if tt.wantResult == accepted && got == nil { t.Errorf("expected non-nil action on success") } if !slices.Equal(gotAcceptEnv, tt.wantAcceptEnv) { @@ -464,7 +478,7 @@ func (ts *localState) NodeKey() key.NodePublic { func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { return &tailcfg.SSHRule{ SSHUsers: map[string]string{ - "*": currentUser, + "alice": currentUser, }, Action: action, Principals: []*tailcfg.SSHPrincipal{ @@ -476,18 +490,19 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { } func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7707") + if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } var handler http.HandlerFunc - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { handler(w, r) - })) - defer recordingServer.Close() + }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -507,9 +522,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { defer s.Shutdown() const sshUser = "alice" - cfg := &gossh.ClientConfig{ + cfg := &testssh.ClientConfig{ User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: testssh.InsecureIgnoreHostKey(), } tests := []struct { @@ -533,9 +548,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { { name: "upload-fails-after-starting", handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() r.Body.Read(make([]byte, 1)) time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusInternalServerError) }, sshCommand: "echo hello && sleep 1 && echo world", wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n", @@ -548,18 +564,19 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + s.logf = tstest.WhileTestRunningLogger(t) tstest.Replace(t, &handler, tt.handler) sc, dc := memnet.NewTCPConn(src, dst, 1024) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) return } - client := gossh.NewClient(c, chans, reqs) + client := testssh.NewClient(c, chans, reqs) defer client.Close() session, err := client.NewSession() if err != nil { @@ -597,12 +614,12 @@ func TestMultipleRecorders(t *testing.T) { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } done := make(chan struct{}) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer close(done) - io.ReadAll(r.Body) w.WriteHeader(http.StatusOK) - })) - defer recordingServer.Close() + w.(http.Flusher).Flush() + io.ReadAll(r.Body) + }) badRecorder, err := net.Listen("tcp", ":0") if err != nil { t.Fatal(err) @@ -610,18 +627,12 @@ func TestMultipleRecorders(t *testing.T) { badRecorderAddr := badRecorder.Addr().String() badRecorder.Close() - badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - })) - defer badRecordingServer500.Close() - - badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - })) - defer badRecordingServer200.Close() + badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -630,7 +641,6 @@ func TestMultipleRecorders(t *testing.T) { Recorders: []netip.AddrPort{ netip.MustParseAddrPort(badRecorderAddr), netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()), - netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()), netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), }, OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ @@ -647,21 +657,21 @@ func TestMultipleRecorders(t *testing.T) { sc, dc := memnet.NewTCPConn(src, dst, 1024) const sshUser = "alice" - cfg := &gossh.ClientConfig{ + cfg := &testssh.ClientConfig{ User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: testssh.InsecureIgnoreHostKey(), } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) return } - client := gossh.NewClient(c, chans, reqs) + client := testssh.NewClient(c, chans, reqs) defer client.Close() session, err := client.NewSession() if err != nil { @@ -701,19 +711,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) { } var recording []byte ctx, cancel := context.WithTimeout(context.Background(), time.Second) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer cancel() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + var err error recording, err = io.ReadAll(r.Body) if err != nil { t.Error(err) return } - })) - defer recordingServer.Close() + }) s := &server{ - logf: logger.Discard, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -736,21 +748,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) { sc, dc := memnet.NewTCPConn(src, dst, 1024) const sshUser = "alice" - cfg := &gossh.ClientConfig{ + cfg := &testssh.ClientConfig{ User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: testssh.InsecureIgnoreHostKey(), } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) return } - client := gossh.NewClient(c, chans, reqs) + client := testssh.NewClient(c, chans, reqs) defer client.Close() session, err := client.NewSession() if err != nil { @@ -790,6 +802,11 @@ func TestSSHAuthFlow(t *testing.T) { Accept: true, Message: "Welcome to Tailscale SSH!", }) + bobRule := newSSHRule(&tailcfg.SSHAction{ + Accept: true, + Message: "Welcome to Tailscale SSH!", + }) + bobRule.SSHUsers = map[string]string{"bob": "bob"} rejectRule := newSSHRule(&tailcfg.SSHAction{ Reject: true, Message: "Go Away!", @@ -808,7 +825,17 @@ func TestSSHAuthFlow(t *testing.T) { state: &localState{ sshEnabled: true, }, - authErr: true, + authErr: true, + wantBanners: []string{"tailscale: tailnet policy does not permit you to SSH to this node\n"}, + }, + { + name: "user-mismatch", + state: &localState{ + sshEnabled: true, + matchingRule: bobRule, + }, + authErr: true, + wantBanners: []string{`tailscale: tailnet policy does not permit you to SSH as user "alice"` + "\n"}, }, { name: "accept", @@ -885,87 +912,160 @@ func TestSSHAuthFlow(t *testing.T) { }, } s := &server{ - logf: logger.Discard, + logf: tstest.WhileTestRunningLogger(t), } defer s.Shutdown() src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - sc, dc := memnet.NewTCPConn(src, dst, 1024) - s.lb = tc.state - sshUser := "alice" - if tc.sshUser != "" { - sshUser = tc.sshUser - } - var passwordUsed atomic.Bool - cfg := &gossh.ClientConfig{ - User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - Auth: []gossh.AuthMethod{ - gossh.PasswordCallback(func() (secret string, err error) { - if !tc.usesPassword { - t.Error("unexpected use of PasswordCallback") - return "", errors.New("unexpected use of PasswordCallback") - } + for _, authMethods := range [][]string{nil, {"publickey", "password"}, {"password", "publickey"}} { + t.Run(fmt.Sprintf("%s-skip-none-auth-%v", tc.name, strings.Join(authMethods, "-then-")), func(t *testing.T) { + s.logf = tstest.WhileTestRunningLogger(t) + + sc, dc := memnet.NewTCPConn(src, dst, 1024) + s.lb = tc.state + sshUser := "alice" + if tc.sshUser != "" { + sshUser = tc.sshUser + } + + wantBanners := slices.Clone(tc.wantBanners) + noneAuthEnabled := len(authMethods) == 0 + + var publicKeyUsed atomic.Bool + var passwordUsed atomic.Bool + var methods []testssh.AuthMethod + + for _, authMethod := range authMethods { + switch authMethod { + case "publickey": + methods = append(methods, + testssh.PublicKeysCallback(func() (signers []testssh.Signer, err error) { + publicKeyUsed.Store(true) + key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return nil, err + } + sig, err := testssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + return []testssh.Signer{sig}, nil + })) + case "password": + methods = append(methods, testssh.PasswordCallback(func() (secret string, err error) { + passwordUsed.Store(true) + return "any-pass", nil + })) + } + } + + if noneAuthEnabled && tc.usesPassword { + methods = append(methods, testssh.PasswordCallback(func() (secret string, err error) { passwordUsed.Store(true) return "any-pass", nil - }), - }, - BannerCallback: func(message string) error { - if len(tc.wantBanners) == 0 { - t.Errorf("unexpected banner: %q", message) - } else if message != tc.wantBanners[0] { - t.Errorf("banner = %q; want %q", message, tc.wantBanners[0]) - } else { - t.Logf("banner = %q", message) - tc.wantBanners = tc.wantBanners[1:] + })) + } + + cfg := &testssh.ClientConfig{ + User: sshUser, + HostKeyCallback: testssh.InsecureIgnoreHostKey(), + SkipNoneAuth: !noneAuthEnabled, + Auth: methods, + BannerCallback: func(message string) error { + if len(wantBanners) == 0 { + t.Errorf("unexpected banner: %q", message) + } else if message != wantBanners[0] { + t.Errorf("banner = %q; want %q", message, wantBanners[0]) + } else { + t.Logf("banner = %q", message) + wantBanners = wantBanners[1:] + } + return nil + }, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + if err != nil { + if !tc.authErr { + t.Errorf("client: %v", err) + } + return + } else if tc.authErr { + c.Close() + t.Errorf("client: expected error, got nil") + return } - return nil - }, - } - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) - if err != nil { - if !tc.authErr { + client := testssh.NewClient(c, chans, reqs) + defer client.Close() + session, err := client.NewSession() + if err != nil { t.Errorf("client: %v", err) + return } - return - } else if tc.authErr { - c.Close() - t.Errorf("client: expected error, got nil") - return + defer session.Close() + _, err = session.CombinedOutput("echo Ran echo!") + if err != nil { + t.Errorf("client: %v", err) + } + }() + if err := s.HandleSSHConn(dc); err != nil { + t.Errorf("unexpected error: %v", err) } - client := gossh.NewClient(c, chans, reqs) - defer client.Close() - session, err := client.NewSession() - if err != nil { - t.Errorf("client: %v", err) - return + wg.Wait() + if len(wantBanners) > 0 { + t.Errorf("missing banners: %v", wantBanners) } - defer session.Close() - _, err = session.CombinedOutput("echo Ran echo!") - if err != nil { - t.Errorf("client: %v", err) + + // Check to see which callbacks were invoked. + // + // When `none` auth is enabled, the public key callback should + // never fire, and the password callback should only fire if + // authentication succeeded and the client was trying to force + // password authentication by connecting with the '-password' + // username suffix. + // + // When skipping `none` auth, the first callback should always + // fire, and the 2nd callback should fire only if + // authentication failed. + wantPublicKey := false + wantPassword := false + if noneAuthEnabled { + wantPassword = !tc.authErr && tc.usesPassword + } else { + for i, authMethod := range authMethods { + switch authMethod { + case "publickey": + wantPublicKey = i == 0 || tc.authErr + case "password": + wantPassword = i == 0 || tc.authErr + } + } } - }() - if err := s.HandleSSHConn(dc); err != nil { - t.Errorf("unexpected error: %v", err) - } - wg.Wait() - if len(tc.wantBanners) > 0 { - t.Errorf("missing banners: %v", tc.wantBanners) - } - }) + + if wantPublicKey && !publicKeyUsed.Load() { + t.Error("public key should have been attempted") + } else if !wantPublicKey && publicKeyUsed.Load() { + t.Errorf("public key should not have been attempted") + } + + if wantPassword && !passwordUsed.Load() { + t.Error("password should have been attempted") + } else if !wantPassword && passwordUsed.Load() { + t.Error("password should not have been attempted") + } + }) + } } } func TestSSH(t *testing.T) { - var logf logger.Logf = t.Logf - sys := &tsd.System{} - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + logf := tstest.WhileTestRunningLogger(t) + sys := tsd.NewSystem() + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatal(err) } @@ -1123,98 +1223,12 @@ func TestSSH(t *testing.T) { func parseEnv(out []byte) map[string]string { e := map[string]string{} - lineread.Reader(bytes.NewReader(out), func(line []byte) error { - i := bytes.IndexByte(line, '=') - if i == -1 { - return nil - } - e[string(line[:i])] = string(line[i+1:]) - return nil - }) - return e -} - -func TestPublicKeyFetching(t *testing.T) { - var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32 - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32((&reqsTotal), 1) - etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path))) - w.Header().Set("Etag", etag) - if v := r.Header.Get("If-None-Match"); v != "" { - if v == etag { - atomic.AddInt32(&reqsIfNoneMatchHit, 1) - w.WriteHeader(304) - return - } - atomic.AddInt32(&reqsIfNoneMatchMiss, 1) - } - io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n") - })) - ts.StartTLS() - defer ts.Close() - keys := ts.URL - - clock := &tstest.Clock{} - srv := &server{ - pubKeyHTTPClient: ts.Client(), - timeNow: clock.Now, - } - for range 2 { - got, err := srv.fetchPublicKeysURL(keys + "/alice.keys") - if err != nil { - t.Fatal(err) - } - if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) { - t.Errorf("got %q; want %q", got, want) + for line := range lineiter.Bytes(out) { + if i := bytes.IndexByte(line, '='); i != -1 { + e[string(line[:i])] = string(line[i+1:]) } } - if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want { - t.Errorf("got %d requests; want %d", got, want) - } - if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want { - t.Errorf("got %d etag hits; want %d", got, want) - } - clock.Advance(5 * time.Minute) - got, err := srv.fetchPublicKeysURL(keys + "/alice.keys") - if err != nil { - t.Fatal(err) - } - if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) { - t.Errorf("got %q; want %q", got, want) - } - if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want { - t.Errorf("got %d requests; want %d", got, want) - } - if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want { - t.Errorf("got %d etag hits; want %d", got, want) - } - if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want { - t.Errorf("got %d etag misses; want %d", got, want) - } - -} - -func TestExpandPublicKeyURL(t *testing.T) { - c := &conn{ - info: &sshConnInfo{ - uprof: tailcfg.UserProfile{ - LoginName: "bar@baz.tld", - }, - }, - } - if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want { - t.Errorf("basic: got %q; want %q", got, want) - } - if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want { - t.Errorf("localpart: got %q; want %q", got, want) - } - if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email=bar@baz.tld"; got != want { - t.Errorf("email: got %q; want %q", got, want) - } - c.info = new(sshConnInfo) - if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want { - t.Errorf("on empty: got %q; want %q", got, want) - } + return e } func TestAcceptEnvPair(t *testing.T) { @@ -1302,3 +1316,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) { t.Errorf("os/user.User has %v fields; this package assumes %v", got, want) } } + +func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) { + t.Errorf("v1 recording endpoint called") + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + mux.HandleFunc("POST /v2/record", handleRecord) + + h2s := &http2.Server{} + srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s)) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring HTTP/2 support in recording server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + return srv +} diff --git a/ssh/tailssh/testcontainers/Dockerfile b/ssh/tailssh/testcontainers/Dockerfile index c94c961d3..4ef1c1eb0 100644 --- a/ssh/tailssh/testcontainers/Dockerfile +++ b/ssh/tailssh/testcontainers/Dockerfile @@ -3,9 +3,12 @@ FROM ${BASE} ARG BASE -RUN echo "Install openssh, needed for scp." -RUN if echo "$BASE" | grep "ubuntu:"; then apt-get update -y && apt-get install -y openssh-client; fi -RUN if echo "$BASE" | grep "alpine:"; then apk add openssh; fi +RUN echo "Install openssh, needed for scp. Also install python3" +RUN if echo "$BASE" | grep "ubuntu:"; then apt-get update -y && apt-get install -y openssh-client python3 python3-pip; fi +RUN if echo "$BASE" | grep "alpine:"; then apk add openssh python3 py3-pip; fi + +RUN echo "Install paramiko" +RUN pip3 install paramiko==3.5.1 || pip3 install --break-system-packages paramiko==3.5.1 # Note - on Ubuntu, we do not create the user's home directory, pam_mkhomedir will do that # for us, and we want to test that PAM gets triggered by Tailscale SSH. @@ -33,6 +36,8 @@ RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH +RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi +RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationParamiko RUN echo "Then run tests as non-root user testuser and make sure tests still pass." RUN touch /tmp/tailscalessh.log diff --git a/ssh/tailssh/user.go b/ssh/tailssh/user.go index 33ebb4db7..ac92c762a 100644 --- a/ssh/tailssh/user.go +++ b/ssh/tailssh/user.go @@ -1,12 +1,11 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 package tailssh import ( - "io" "os" "os/exec" "os/user" @@ -18,7 +17,7 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/hostinfo" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/osuser" "tailscale.com/version/distro" ) @@ -49,6 +48,9 @@ func userLookup(username string) (*userMeta, error) { } func (u *userMeta) LoginShell() string { + if runtime.GOOS == "plan9" { + return "/bin/rc" + } if u.loginShellCached != "" { // This field should be populated on Linux, at least, because // func userLookup on Linux uses "getent" to look up the user @@ -86,6 +88,9 @@ func defaultPathForUser(u *user.User) string { if s := defaultPathTmpl(); s != "" { return expandDefaultPathTmpl(s, u) } + if runtime.GOOS == "plan9" { + return "/bin" + } isRoot := u.Uid == "0" switch distro.Get() { case distro.Debian: @@ -110,15 +115,16 @@ func defaultPathForUser(u *user.User) string { } func defaultPathForUserOnNixOS(u *user.User) string { - var path string - lineread.File("/etc/pam/environment", func(lineb []byte) error { + for lr := range lineiter.File("/etc/pam/environment") { + lineb, err := lr.Value() + if err != nil { + return "" + } if v := pathFromPAMEnvLine(lineb, u); v != "" { - path = v - return io.EOF // stop iteration + return v } - return nil - }) - return path + } + return "" } func pathFromPAMEnvLine(line []byte, u *user.User) (path string) { diff --git a/syncs/locked.go b/syncs/locked.go index d2048665d..d2e9edef7 100644 --- a/syncs/locked.go +++ b/syncs/locked.go @@ -8,7 +8,7 @@ import ( ) // AssertLocked panics if m is not locked. -func AssertLocked(m *sync.Mutex) { +func AssertLocked(m *Mutex) { if m.TryLock() { m.Unlock() panic("mutex is not locked") @@ -16,7 +16,7 @@ func AssertLocked(m *sync.Mutex) { } // AssertRLocked panics if rw is not locked for reading or writing. -func AssertRLocked(rw *sync.RWMutex) { +func AssertRLocked(rw *RWMutex) { if rw.TryLock() { rw.Unlock() panic("mutex is not locked") diff --git a/syncs/mutex.go b/syncs/mutex.go new file mode 100644 index 000000000..e61d1d1ab --- /dev/null +++ b/syncs/mutex.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_mutex_debug + +package syncs + +import "sync" + +// Mutex is an alias for sync.Mutex. +// +// It's only not a sync.Mutex when built with the ts_mutex_debug build tag. +type Mutex = sync.Mutex + +// RWMutex is an alias for sync.RWMutex. +// +// It's only not a sync.RWMutex when built with the ts_mutex_debug build tag. +type RWMutex = sync.RWMutex diff --git a/syncs/mutex_debug.go b/syncs/mutex_debug.go new file mode 100644 index 000000000..14b52ffe3 --- /dev/null +++ b/syncs/mutex_debug.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_mutex_debug + +package syncs + +import "sync" + +type Mutex struct { + sync.Mutex +} + +type RWMutex struct { + sync.RWMutex +} + +// TODO(bradfitz): actually track stuff when in debug mode. diff --git a/syncs/shardedint.go b/syncs/shardedint.go new file mode 100644 index 000000000..28c4168d5 --- /dev/null +++ b/syncs/shardedint.go @@ -0,0 +1,69 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "encoding/json" + "strconv" + "sync/atomic" + + "golang.org/x/sys/cpu" +) + +// ShardedInt provides a sharded atomic int64 value that optimizes high +// frequency (Mhz range and above) writes in highly parallel workloads. +// The zero value is not safe for use; use [NewShardedInt]. +// ShardedInt implements the expvar.Var interface. +type ShardedInt struct { + sv *ShardValue[intShard] +} + +// NewShardedInt returns a new [ShardedInt]. +func NewShardedInt() *ShardedInt { + return &ShardedInt{ + sv: NewShardValue[intShard](), + } +} + +// Add adds delta to the value. +func (m *ShardedInt) Add(delta int64) { + m.sv.One(func(v *intShard) { + v.Add(delta) + }) +} + +type intShard struct { + atomic.Int64 + _ cpu.CacheLinePad // avoid false sharing of neighboring shards +} + +// Value returns the current value. +func (m *ShardedInt) Value() int64 { + var v int64 + for s := range m.sv.All { + v += s.Load() + } + return v +} + +// GetDistribution returns the current value in each shard. +// This is intended for observability/debugging only. +func (m *ShardedInt) GetDistribution() []int64 { + v := make([]int64, 0, m.sv.Len()) + for s := range m.sv.All { + v = append(v, s.Load()) + } + return v +} + +// String implements the expvar.Var interface +func (m *ShardedInt) String() string { + v, _ := json.Marshal(m.Value()) + return string(v) +} + +// AppendText implements the encoding.TextAppender interface +func (m *ShardedInt) AppendText(b []byte) ([]byte, error) { + return strconv.AppendInt(b, m.Value(), 10), nil +} diff --git a/syncs/shardedint_test.go b/syncs/shardedint_test.go new file mode 100644 index 000000000..815a739d1 --- /dev/null +++ b/syncs/shardedint_test.go @@ -0,0 +1,120 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs_test + +import ( + "expvar" + "sync" + "testing" + + . "tailscale.com/syncs" + "tailscale.com/tstest" +) + +var ( + _ expvar.Var = (*ShardedInt)(nil) + // TODO(raggi): future go version: + // _ encoding.TextAppender = (*ShardedInt)(nil) +) + +func BenchmarkShardedInt(b *testing.B) { + b.ReportAllocs() + + b.Run("expvar", func(b *testing.B) { + var m expvar.Int + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + m.Add(1) + } + }) + }) + + b.Run("sharded int", func(b *testing.B) { + m := NewShardedInt() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + m.Add(1) + } + }) + }) +} + +func TestShardedInt(t *testing.T) { + t.Run("basics", func(t *testing.T) { + m := NewShardedInt() + if got, want := m.Value(), int64(0); got != want { + t.Errorf("got %v, want %v", got, want) + } + m.Add(1) + if got, want := m.Value(), int64(1); got != want { + t.Errorf("got %v, want %v", got, want) + } + m.Add(2) + if got, want := m.Value(), int64(3); got != want { + t.Errorf("got %v, want %v", got, want) + } + m.Add(-1) + if got, want := m.Value(), int64(2); got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("high concurrency", func(t *testing.T) { + m := NewShardedInt() + wg := sync.WaitGroup{} + numWorkers := 1000 + numIncrements := 1000 + wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + go func() { + defer wg.Done() + for i := 0; i < numIncrements; i++ { + m.Add(1) + } + }() + } + wg.Wait() + if got, want := m.Value(), int64(numWorkers*numIncrements); got != want { + t.Errorf("got %v, want %v", got, want) + } + for i, shard := range m.GetDistribution() { + t.Logf("shard %d: %d", i, shard) + } + }) + + t.Run("encoding.TextAppender", func(t *testing.T) { + m := NewShardedInt() + m.Add(1) + b := make([]byte, 0, 10) + b, err := m.AppendText(b) + if err != nil { + t.Fatal(err) + } + if got, want := string(b), "1"; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("allocs", func(t *testing.T) { + m := NewShardedInt() + tstest.MinAllocsPerRun(t, 0, func() { + m.Add(1) + _ = m.Value() + }) + + // TODO(raggi): fix access to expvar's internal append based + // interface, unfortunately it's not currently closed for external + // use, this will alloc when it escapes. + tstest.MinAllocsPerRun(t, 0, func() { + m.Add(1) + _ = m.String() + }) + + b := make([]byte, 0, 10) + tstest.MinAllocsPerRun(t, 0, func() { + m.Add(1) + m.AppendText(b) + }) + }) +} diff --git a/syncs/shardvalue.go b/syncs/shardvalue.go new file mode 100644 index 000000000..b1474477c --- /dev/null +++ b/syncs/shardvalue.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +// TODO(raggi): this implementation is still imperfect as it will still result +// in cross CPU sharing periodically, we instead really want a per-CPU shard +// key, but the limitations of calling platform code make reaching for even the +// getcpu vdso very painful. See https://github.com/golang/go/issues/18802, and +// hopefully one day we can replace with a primitive that falls out of that +// work. + +// ShardValue contains a value sharded over a set of shards. +// In order to be useful, T should be aligned to cache lines. +// Users must organize that usage in One and All is concurrency safe. +// The zero value is not safe for use; use [NewShardValue]. +type ShardValue[T any] struct { + shards []T + + //lint:ignore U1000 unused under tailscale_go builds. + pool shardValuePool +} + +// Len returns the number of shards. +func (sp *ShardValue[T]) Len() int { + return len(sp.shards) +} + +// All yields a pointer to the value in each shard. +func (sp *ShardValue[T]) All(yield func(*T) bool) { + for i := range sp.shards { + if !yield(&sp.shards[i]) { + return + } + } +} diff --git a/syncs/shardvalue_go.go b/syncs/shardvalue_go.go new file mode 100644 index 000000000..9b9d252a7 --- /dev/null +++ b/syncs/shardvalue_go.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !tailscale_go + +package syncs + +import ( + "runtime" + "sync" + "sync/atomic" +) + +type shardValuePool struct { + atomic.Int64 + sync.Pool +} + +// NewShardValue constructs a new ShardValue[T] with a shard per CPU. +func NewShardValue[T any]() *ShardValue[T] { + sp := &ShardValue[T]{ + shards: make([]T, runtime.NumCPU()), + } + sp.pool.New = func() any { + i := sp.pool.Add(1) - 1 + return &sp.shards[i%int64(len(sp.shards))] + } + return sp +} + +// One yields a pointer to a single shard value with best-effort P-locality. +func (sp *ShardValue[T]) One(yield func(*T)) { + v := sp.pool.Get().(*T) + yield(v) + sp.pool.Put(v) +} diff --git a/syncs/shardvalue_tailscale.go b/syncs/shardvalue_tailscale.go new file mode 100644 index 000000000..8ef778ff3 --- /dev/null +++ b/syncs/shardvalue_tailscale.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(raggi): update build tag after toolchain update +//go:build tailscale_go + +package syncs + +import ( + "runtime" +) + +//lint:ignore U1000 unused under tailscale_go builds. +type shardValuePool struct{} + +// NewShardValue constructs a new ShardValue[T] with a shard per CPU. +func NewShardValue[T any]() *ShardValue[T] { + return &ShardValue[T]{shards: make([]T, runtime.NumCPU())} +} + +// One yields a pointer to a single shard value with best-effort P-locality. +func (sp *ShardValue[T]) One(f func(*T)) { + f(&sp.shards[runtime.TailscaleCurrentP()%len(sp.shards)]) +} diff --git a/syncs/shardvalue_test.go b/syncs/shardvalue_test.go new file mode 100644 index 000000000..8f6ac6414 --- /dev/null +++ b/syncs/shardvalue_test.go @@ -0,0 +1,119 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "math" + "runtime" + "sync" + "sync/atomic" + "testing" + + "golang.org/x/sys/cpu" +) + +func TestShardValue(t *testing.T) { + type intVal struct { + atomic.Int64 + _ cpu.CacheLinePad + } + + t.Run("One", func(t *testing.T) { + sv := NewShardValue[intVal]() + sv.One(func(v *intVal) { + v.Store(10) + }) + + var v int64 + for i := range sv.shards { + v += sv.shards[i].Load() + } + if v != 10 { + t.Errorf("got %v, want 10", v) + } + }) + + t.Run("All", func(t *testing.T) { + sv := NewShardValue[intVal]() + for i := range sv.shards { + sv.shards[i].Store(int64(i)) + } + + var total int64 + sv.All(func(v *intVal) bool { + total += v.Load() + return true + }) + // triangle coefficient lower one order due to 0 index + want := int64(len(sv.shards) * (len(sv.shards) - 1) / 2) + if total != want { + t.Errorf("got %v, want %v", total, want) + } + }) + + t.Run("Len", func(t *testing.T) { + sv := NewShardValue[intVal]() + if got, want := sv.Len(), runtime.NumCPU(); got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("distribution", func(t *testing.T) { + sv := NewShardValue[intVal]() + + goroutines := 1000 + iterations := 10000 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + sv.One(func(v *intVal) { + v.Add(1) + }) + } + }() + } + wg.Wait() + + var ( + total int64 + distribution []int64 + ) + t.Logf("distribution:") + sv.All(func(v *intVal) bool { + total += v.Load() + distribution = append(distribution, v.Load()) + t.Logf("%d", v.Load()) + return true + }) + + if got, want := total, int64(goroutines*iterations); got != want { + t.Errorf("got %v, want %v", got, want) + } + if got, want := len(distribution), runtime.NumCPU(); got != want { + t.Errorf("got %v, want %v", got, want) + } + + mean := total / int64(len(distribution)) + for _, v := range distribution { + if v < mean/10 || v > mean*10 { + t.Logf("distribution is very unbalanced: %v", distribution) + } + } + t.Logf("mean: %d", mean) + + var standardDev int64 + for _, v := range distribution { + standardDev += ((v - mean) * (v - mean)) + } + standardDev = int64(math.Sqrt(float64(standardDev / int64(len(distribution))))) + t.Logf("stdev: %d", standardDev) + + if standardDev > mean/3 { + t.Logf("standard deviation is too high: %v", standardDev) + } + }) +} diff --git a/syncs/syncs.go b/syncs/syncs.go index 0d40204d2..3b37bca08 100644 --- a/syncs/syncs.go +++ b/syncs/syncs.go @@ -6,6 +6,7 @@ package syncs import ( "context" + "iter" "sync" "sync/atomic" @@ -24,6 +25,7 @@ func initClosedChan() <-chan struct{} { } // AtomicValue is the generic version of [atomic.Value]. +// See [MutexValue] for guidance on whether to use this type. type AtomicValue[T any] struct { v atomic.Value } @@ -65,12 +67,79 @@ func (v *AtomicValue[T]) Swap(x T) (old T) { if oldV != nil { return oldV.(wrappedValue[T]).v } - return old + return old // zero value of T } // CompareAndSwap executes the compare-and-swap operation for the Value. +// It panics if T is not comparable. func (v *AtomicValue[T]) CompareAndSwap(oldV, newV T) (swapped bool) { - return v.v.CompareAndSwap(wrappedValue[T]{oldV}, wrappedValue[T]{newV}) + var zero T + return v.v.CompareAndSwap(wrappedValue[T]{oldV}, wrappedValue[T]{newV}) || + // In the edge-case where [atomic.Value.Store] is uninitialized + // and trying to compare with the zero value of T, + // then compare-and-swap with the nil any value. + (any(oldV) == any(zero) && v.v.CompareAndSwap(any(nil), wrappedValue[T]{newV})) +} + +// MutexValue is a value protected by a mutex. +// +// AtomicValue, [MutexValue], [atomic.Pointer] are similar and +// overlap in their use cases. +// +// - Use [atomic.Pointer] if the value being stored is a pointer and +// you only ever need load and store operations. +// An atomic pointer only occupies 1 word of memory. +// +// - Use [MutexValue] if the value being stored is not a pointer or +// you need the ability for a mutex to protect a set of operations +// performed on the value. +// A mutex-guarded value occupies 1 word of memory plus +// the memory representation of T. +// +// - AtomicValue is useful for non-pointer types that happen to +// have the memory layout of a single pointer. +// Examples include a map, channel, func, or a single field struct +// that contains any prior types. +// An atomic value occupies 2 words of memory. +// Consequently, Storing of non-pointer types always allocates. +// +// Note that [AtomicValue] has the ability to report whether it was set +// while [MutexValue] lacks the ability to detect if the value was set +// and it happens to be the zero value of T. If such a use case is +// necessary, then you could consider wrapping T in [opt.Value]. +type MutexValue[T any] struct { + mu sync.Mutex + v T +} + +// WithLock calls f with a pointer to the value while holding the lock. +// The provided pointer must not leak beyond the scope of the call. +func (m *MutexValue[T]) WithLock(f func(p *T)) { + m.mu.Lock() + defer m.mu.Unlock() + f(&m.v) +} + +// Load returns a shallow copy of the underlying value. +func (m *MutexValue[T]) Load() T { + m.mu.Lock() + defer m.mu.Unlock() + return m.v +} + +// Store stores a shallow copy of the provided value. +func (m *MutexValue[T]) Store(v T) { + m.mu.Lock() + defer m.mu.Unlock() + m.v = v +} + +// Swap stores new into m and returns the previous value. +func (m *MutexValue[T]) Swap(new T) (old T) { + m.mu.Lock() + defer m.mu.Unlock() + old, m.v = m.v, new + return old } // WaitGroupChan is like a sync.WaitGroup, but has a chan that closes @@ -132,6 +201,13 @@ func NewSemaphore(n int) Semaphore { return Semaphore{c: make(chan struct{}, n)} } +// Len reports the number of in-flight acquisitions. +// It is incremented whenever the semaphore is acquired. +// It is decremented whenever the semaphore is released. +func (s Semaphore) Len() int { + return len(s.c) +} + // Acquire blocks until a resource is acquired. func (s Semaphore) Acquire() { s.c <- struct{}{} @@ -252,16 +328,47 @@ func (m *Map[K, V]) Delete(key K) { delete(m.m, key) } -// Range iterates over the map in an undefined order calling f for each entry. -// Iteration stops if f returns false. Map changes are blocked during iteration. +// Keys iterates over all keys in the map in an undefined order. // A read lock is held for the entire duration of the iteration. // Use the [WithLock] method instead to mutate the map during iteration. -func (m *Map[K, V]) Range(f func(key K, value V) bool) { - m.mu.RLock() - defer m.mu.RUnlock() - for k, v := range m.m { - if !f(k, v) { - return +func (m *Map[K, V]) Keys() iter.Seq[K] { + return func(yield func(K) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for k := range m.m { + if !yield(k) { + return + } + } + } +} + +// Values iterates over all values in the map in an undefined order. +// A read lock is held for the entire duration of the iteration. +// Use the [WithLock] method instead to mutate the map during iteration. +func (m *Map[K, V]) Values() iter.Seq[V] { + return func(yield func(V) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for _, v := range m.m { + if !yield(v) { + return + } + } + } +} + +// All iterates over all entries in the map in an undefined order. +// A read lock is held for the entire duration of the iteration. +// Use the [WithLock] method instead to mutate the map during iteration. +func (m *Map[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + m.mu.RLock() + defer m.mu.RUnlock() + for k, v := range m.m { + if !yield(k, v) { + return + } } } } @@ -272,6 +379,9 @@ func (m *Map[K, V]) Range(f func(key K, value V) bool) { func (m *Map[K, V]) WithLock(f func(m2 map[K]V)) { m.mu.Lock() defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[K]V) + } f(m.m) } @@ -299,19 +409,3 @@ func (m *Map[K, V]) Swap(key K, value V) (oldValue V) { mak.Set(&m.m, key, value) return oldValue } - -// WaitGroup is identical to [sync.WaitGroup], -// but provides a Go method to start a goroutine. -type WaitGroup struct{ sync.WaitGroup } - -// Go calls the given function in a new goroutine. -// It automatically increments the counter before execution and -// automatically decrements the counter after execution. -// It must not be called concurrently with Wait. -func (wg *WaitGroup) Go(f func()) { - wg.Add(1) - go func() { - defer wg.Done() - f() - }() -} diff --git a/syncs/syncs_test.go b/syncs/syncs_test.go index 0748dcb72..a546b8d0a 100644 --- a/syncs/syncs_test.go +++ b/syncs/syncs_test.go @@ -7,7 +7,9 @@ import ( "context" "io" "os" + "sync" "testing" + "time" "github.com/google/go-cmp/cmp" ) @@ -63,6 +65,56 @@ func TestAtomicValue(t *testing.T) { t.Fatalf("LoadOk = (%v, %v), want (nil, true)", got, gotOk) } } + + { + c1, c2, c3 := make(chan struct{}), make(chan struct{}), make(chan struct{}) + var v AtomicValue[chan struct{}] + if v.CompareAndSwap(c1, c2) != false { + t.Fatalf("CompareAndSwap = true, want false") + } + if v.CompareAndSwap(nil, c1) != true { + t.Fatalf("CompareAndSwap = false, want true") + } + if v.CompareAndSwap(c2, c3) != false { + t.Fatalf("CompareAndSwap = true, want false") + } + if v.CompareAndSwap(c1, c2) != true { + t.Fatalf("CompareAndSwap = false, want true") + } + } +} + +func TestMutexValue(t *testing.T) { + var v MutexValue[time.Time] + if n := int(testing.AllocsPerRun(1000, func() { + v.Store(v.Load()) + v.WithLock(func(*time.Time) {}) + })); n != 0 { + t.Errorf("AllocsPerRun = %d, want 0", n) + } + + now := time.Now() + v.Store(now) + if !v.Load().Equal(now) { + t.Errorf("Load = %v, want %v", v.Load(), now) + } + + var group sync.WaitGroup + var v2 MutexValue[int] + var sum int + for i := range 10 { + group.Go(func() { + old1 := v2.Load() + old2 := v2.Swap(old1 + i) + delta := old2 - old1 + v2.WithLock(func(p *int) { *p += delta }) + }) + sum += i + } + group.Wait() + if v2.Load() != sum { + t.Errorf("Load = %v, want %v", v2.Load(), sum) + } } func TestWaitGroupChan(t *testing.T) { @@ -110,10 +162,20 @@ func TestClosedChan(t *testing.T) { func TestSemaphore(t *testing.T) { s := NewSemaphore(2) + assertLen := func(want int) { + t.Helper() + if got := s.Len(); got != want { + t.Fatalf("Len = %d, want %d", got, want) + } + } + + assertLen(0) s.Acquire() + assertLen(1) if !s.TryAcquire() { t.Fatal("want true") } + assertLen(2) if s.TryAcquire() { t.Fatal("want false") } @@ -123,11 +185,15 @@ func TestSemaphore(t *testing.T) { t.Fatal("want false") } s.Release() + assertLen(1) if !s.AcquireContext(context.Background()) { t.Fatal("want true") } + assertLen(2) s.Release() + assertLen(1) s.Release() + assertLen(0) } func TestMap(t *testing.T) { @@ -160,10 +226,9 @@ func TestMap(t *testing.T) { } got := map[string]int{} want := map[string]int{"one": 1, "two": 2, "three": 3} - m.Range(func(k string, v int) bool { + for k, v := range m.All() { got[k] = v - return true - }) + } if d := cmp.Diff(got, want); d != "" { t.Errorf("Range mismatch (-got +want):\n%s", d) } @@ -178,17 +243,16 @@ func TestMap(t *testing.T) { m.Delete("noexist") got = map[string]int{} want = map[string]int{} - m.Range(func(k string, v int) bool { + for k, v := range m.All() { got[k] = v - return true - }) + } if d := cmp.Diff(got, want); d != "" { t.Errorf("Range mismatch (-got +want):\n%s", d) } t.Run("LoadOrStore", func(t *testing.T) { var m Map[string, string] - var wg WaitGroup + var wg sync.WaitGroup var ok1, ok2 bool wg.Go(func() { _, ok1 = m.LoadOrStore("", "") }) wg.Go(func() { _, ok2 = m.LoadOrStore("", "") }) diff --git a/tailcfg/c2ntypes.go b/tailcfg/c2ntypes.go index 54efb736e..d78baef1c 100644 --- a/tailcfg/c2ntypes.go +++ b/tailcfg/c2ntypes.go @@ -5,7 +5,10 @@ package tailcfg -import "net/netip" +import ( + "encoding/json" + "net/netip" +) // C2NSSHUsernamesRequest is the request for the /ssh/usernames. // A GET request without a request body is equivalent to the zero value of this type. @@ -102,3 +105,44 @@ type C2NTLSCertInfo struct { // TODO(bradfitz): add fields for whether an ACME fetch is currently in // process and when it started, etc. } + +// C2NVIPServicesResponse is the response (from node to control) from the +// /vip-services handler. +// +// It returns the list of VIPServices that the node is currently serving with +// their port info and whether they are active or not. It also returns a hash of +// the response to allow the control server to detect changes. +type C2NVIPServicesResponse struct { + // VIPServices is the list of VIP services that the node is currently serving. + VIPServices []*VIPService `json:",omitempty"` + + // ServicesHash is the hash of VIPServices to allow the control server to detect + // changes. This value matches what is reported in latest [Hostinfo.ServicesHash]. + ServicesHash string +} + +// C2NDebugNetmapRequest is the request (from control to node) for the +// /debug/netmap handler. +type C2NDebugNetmapRequest struct { + // Candidate is an optional full MapResponse to be used for generating a candidate + // network map. If unset, only the current network map is returned. + Candidate *MapResponse `json:"candidate,omitzero"` + + // OmitFields is an optional list of netmap fields to omit from the response. + // If unset, no fields are omitted. + OmitFields []string `json:"omitFields,omitzero"` +} + +// C2NDebugNetmapResponse is the response (from node to control) from the +// /debug/netmap handler. It contains the current network map and, if a +// candidate full MapResponse was provided in the request, a candidate network +// map generated from it. +// To avoid import cycles, and reflect the non-stable nature of +// netmap.NetworkMap values, they are returned as json.RawMessage. +type C2NDebugNetmapResponse struct { + // Current is the current network map (netmap.NetworkMap). + Current json.RawMessage `json:"current"` + + // Candidate is a network map produced based on the candidate MapResponse. + Candidate json.RawMessage `json:"candidate,omitzero"` +} diff --git a/tailcfg/derpmap.go b/tailcfg/derpmap.go index 056152157..e05559f3e 100644 --- a/tailcfg/derpmap.go +++ b/tailcfg/derpmap.go @@ -96,12 +96,32 @@ type DERPRegion struct { Latitude float64 `json:",omitempty"` Longitude float64 `json:",omitempty"` - // Avoid is whether the client should avoid picking this as its home - // region. The region should only be used if a peer is there. - // Clients already using this region as their home should migrate - // away to a new region without Avoid set. + // Avoid is whether the client should avoid picking this as its home region. + // The region should only be used if a peer is there. Clients already using + // this region as their home should migrate away to a new region without + // Avoid set. + // + // Deprecated: because of bugs in past implementations combined with unclear + // docs that caused people to think the bugs were intentional, this field is + // deprecated. It was never supposed to cause STUN/DERP measurement probes, + // but due to bugs, it sometimes did. And then some parts of the code began + // to rely on that property. But then we were unable to use this field for + // its original purpose, nor its later imagined purpose, because various + // parts of the codebase thought it meant one thing and others thought it + // meant another. But it did something in the middle instead. So we're retiring + // it. Use NoMeasureNoHome instead. Avoid bool `json:",omitempty"` + // NoMeasureNoHome says that this regions should not be measured for its + // latency distance (STUN, HTTPS, etc) or availability (e.g. captive portal + // checks) and should never be selected as the node's home region. However, + // if a peer declares this region as its home, then this client is allowed + // to connect to it for the purpose of communicating with that peer. + // + // This is what the now deprecated Avoid bool was supposed to mean + // originally but had implementation bugs and documentation omissions. + NoMeasureNoHome bool `json:",omitempty"` + // Nodes are the DERP nodes running in this region, in // priority order for the current client. Client TLS // connections should ideally only go to the first entry @@ -139,6 +159,12 @@ type DERPNode struct { // name. If empty, HostName is used. If CertName is non-empty, // HostName is only used for the TCP dial (if IPv4/IPv6 are // not present) + TLS ClientHello. + // + // As a special case, if CertName starts with "sha256-raw:", + // then the rest of the string is a hex-encoded SHA256 of the + // cert to expect. This is used for self-signed certs. + // In this case, the HostName field will typically be an IP + // address literal. CertName string `json:",omitempty"` // IPv4 optionally forces an IPv4 address to use, instead of using DNS. diff --git a/tailcfg/proto_port_range.go b/tailcfg/proto_port_range.go index f65c58804..03505dbd1 100644 --- a/tailcfg/proto_port_range.go +++ b/tailcfg/proto_port_range.go @@ -5,7 +5,6 @@ package tailcfg import ( "errors" - "fmt" "strconv" "strings" @@ -70,14 +69,7 @@ func (ppr ProtoPortRange) String() string { buf.Write(text) buf.Write([]byte(":")) } - pr := ppr.Ports - if pr.First == pr.Last { - fmt.Fprintf(&buf, "%d", pr.First) - } else if pr == PortRangeAny { - buf.WriteByte('*') - } else { - fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) - } + buf.WriteString(ppr.Ports.String()) return buf.String() } @@ -104,7 +96,7 @@ func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { if !strings.Contains(ipProtoPort, ":") { ipProtoPort = "*:" + ipProtoPort } - protoStr, portRange, err := parseHostPortRange(ipProtoPort) + protoStr, portRange, err := ParseHostPortRange(ipProtoPort) if err != nil { return nil, err } @@ -126,9 +118,9 @@ func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { return ppr, nil } -// parseHostPortRange parses hostport as HOST:PORTS where HOST is +// ParseHostPortRange parses hostport as HOST:PORTS where HOST is // returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges. -func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { +func ParseHostPortRange(hostport string) (host string, ports PortRange, err error) { hostport = strings.ToLower(hostport) colon := strings.LastIndexByte(hostport, ':') if colon < 0 { diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index df50a8603..41e0a0b28 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -5,7 +5,7 @@ // the node and the coordination server. package tailcfg -//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile --clonefunc +//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService,SSHPolicy --clonefunc import ( "bytes" @@ -17,16 +17,20 @@ import ( "net/netip" "reflect" "slices" + "strconv" "strings" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" "tailscale.com/types/structs" "tailscale.com/types/tkatype" + "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/slicesx" + "tailscale.com/util/vizerror" ) // CapabilityVersion represents the client's capability level. That @@ -142,44 +146,103 @@ type CapabilityVersion int // - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers // - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information // - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT -// - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542) +// - 100: 2024-06-18: Initial support for filtertype.Match.SrcCaps - actually usable in capver 109 (issue #12542) // - 101: 2024-07-01: Client supports SSH agent forwarding when handling connections with /bin/su // - 102: 2024-07-12: NodeAttrDisableMagicSockCryptoRouting support // - 103: 2024-07-24: Client supports NodeAttrDisableCaptivePortalDetection // - 104: 2024-08-03: SelfNodeV6MasqAddrForThisPeer now works // - 105: 2024-08-05: Fixed SSH behavior on systems that use busybox (issue #12849) // - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5) -const CurrentCapabilityVersion CapabilityVersion = 106 - -type StableID string - +// - 107: 2024-10-30: add App Connector to conffile (PR #13942) +// - 108: 2024-11-08: Client sends ServicesHash in Hostinfo, understands c2n GET /vip-services. +// - 109: 2024-11-18: Client supports filtertype.Match.SrcCaps (issue #12542) +// - 110: 2024-12-12: removed never-before-used Tailscale SSH public key support (#14373) +// - 111: 2025-01-14: Client supports a peer having Node.HomeDERP (issue #14636) +// - 112: 2025-01-14: Client interprets AllowedIPs of nil as meaning same as Addresses +// - 113: 2025-01-20: Client communicates to control whether funnel is enabled by sending Hostinfo.IngressEnabled (#14688) +// - 114: 2025-01-30: NodeAttrMaxKeyDuration CapMap defined, clients might use it (no tailscaled code change) (#14829) +// - 115: 2025-03-07: Client understands DERPRegion.NoMeasureNoHome. +// - 116: 2025-05-05: Client serves MagicDNS "AAAA" if NodeAttrMagicDNSPeerAAAA set on self node +// - 117: 2025-05-28: Client understands DisplayMessages (structured health messages), but not necessarily PrimaryAction. +// - 118: 2025-07-01: Client sends Hostinfo.StateEncrypted to report whether the state file is encrypted at rest (#15830) +// - 119: 2025-07-10: Client uses Hostinfo.Location.Priority to prioritize one route over another. +// - 120: 2025-07-15: Client understands peer relay disco messages, and implements peer client and relay server functions +// - 121: 2025-07-19: Client understands peer relay endpoint alloc with [disco.AllocateUDPRelayEndpointRequest] & [disco.AllocateUDPRelayEndpointResponse] +// - 122: 2025-07-21: Client sends Hostinfo.ExitNodeID to report which exit node it has selected, if any. +// - 123: 2025-07-28: fix deadlock regression from cryptokey routing change (issue #16651) +// - 124: 2025-08-08: removed NodeAttrDisableMagicSockCryptoRouting support, crypto routing is now mandatory +// - 125: 2025-08-11: dnstype.Resolver adds UseWithExitNode field. +// - 126: 2025-09-17: Client uses seamless key renewal unless disabled by control (tailscale/corp#31479) +// - 127: 2025-09-19: can handle C2N /debug/netmap. +// - 128: 2025-10-02: can handle C2N /debug/health. +// - 129: 2025-10-04: Fixed sleep/wake deadlock in magicsock when using peer relay (PR #17449) +// - 130: 2025-10-06: client can send key.HardwareAttestationPublic and key.HardwareAttestationKeySignature in MapRequest +const CurrentCapabilityVersion CapabilityVersion = 130 + +// ID is an integer ID for a user, node, or login allocated by the +// control plane. +// +// To be nice, control plane servers should not use int64s that are too large to +// fit in a JavaScript number (see JavaScript's Number.MAX_SAFE_INTEGER). +// The Tailscale-hosted control plane stopped allocating large integers in +// March 2023 but nodes prior to that may have IDs larger than +// MAX_SAFE_INTEGER (2^53 – 1). +// +// IDs must not be zero or negative. type ID int64 +// UserID is an [ID] for a [User]. type UserID ID func (u UserID) IsZero() bool { return u == 0 } +// LoginID is an [ID] for a [Login]. +// +// It is not used in the Tailscale client, but is used in the control plane. type LoginID ID func (u LoginID) IsZero() bool { return u == 0 } +// NodeID is a unique integer ID for a node. +// +// It's global within a control plane URL ("tailscale up --login-server") and is +// (as of 2025-01-06) never re-used even after a node is deleted. +// +// To be nice, control plane servers should not use int64s that are too large to +// fit in a JavaScript number (see JavaScript's Number.MAX_SAFE_INTEGER). +// The Tailscale-hosted control plane stopped allocating large integers in +// March 2023 but nodes prior to that may have node IDs larger than +// MAX_SAFE_INTEGER (2^53 – 1). +// +// NodeIDs are not stable across control plane URLs. For more stable URLs, +// see [StableNodeID]. type NodeID ID func (u NodeID) IsZero() bool { return u == 0 } -type StableNodeID StableID +// StableNodeID is a string form of [NodeID]. +// +// Different control plane servers should ideally have different StableNodeID +// suffixes for different sites or regions. +// +// Being a string, it's safer to use in JavaScript without worrying about the +// size of the integer, as documented on [NodeID]. +// +// But in general, Tailscale APIs can accept either a [NodeID] integer or a +// [StableNodeID] string when referring to a node. +type StableNodeID string func (u StableNodeID) IsZero() bool { return u == "" } -// User is an IPN user. +// User is a Tailscale user. // // A user can have multiple logins associated with it (e.g. gmail and github oauth). // (Note: none of our UIs support this yet.) @@ -192,34 +255,30 @@ func (u StableNodeID) IsZero() bool { // have a general gmail address login associated with the user. type User struct { ID UserID - LoginName string `json:"-"` // not stored, filled from Login // TODO REMOVE - DisplayName string // if non-empty overrides Login field - ProfilePicURL string // if non-empty overrides Login field - Logins []LoginID - Created time.Time + DisplayName string // if non-empty overrides Login field + ProfilePicURL string `json:",omitzero"` // if non-empty overrides Login field + Created time.Time `json:",omitzero"` } +// Login is a user from a specific identity provider, not associated with any +// particular tailnet. type Login struct { _ structs.Incomparable - ID LoginID - Provider string - LoginName string - DisplayName string - ProfilePicURL string + ID LoginID // unused in the Tailscale client + Provider string // "google", "github", "okta_foo", etc. + LoginName string // an email address or "email-ish" string (like alice@github) + DisplayName string // from the IdP + ProfilePicURL string `json:",omitzero"` // from the IdP } -// A UserProfile is display-friendly data for a user. +// A UserProfile is display-friendly data for a [User]. // It includes the LoginName for display purposes but *not* the Provider. // It also includes derived data from one of the user's logins. type UserProfile struct { ID UserID LoginName string // "alice@smith.com"; for display purposes only (provider is not listed) DisplayName string // "Alice Smith" - ProfilePicURL string - - // Roles exists for legacy reasons, to keep old macOS clients - // happy. It JSON marshals as []. - Roles emptyStructJSONSlice + ProfilePicURL string `json:",omitzero"` } func (p *UserProfile) Equal(p2 *UserProfile) bool { @@ -235,16 +294,6 @@ func (p *UserProfile) Equal(p2 *UserProfile) bool { p.ProfilePicURL == p2.ProfilePicURL } -type emptyStructJSONSlice struct{} - -var emptyJSONSliceBytes = []byte("[]") - -func (emptyStructJSONSlice) MarshalJSON() ([]byte, error) { - return emptyJSONSliceBytes, nil -} - -func (emptyStructJSONSlice) UnmarshalJSON([]byte) error { return nil } - // RawMessage is a raw encoded JSON value. It implements Marshaler and // Unmarshaler and can be used to delay JSON decoding or precompute a JSON // encoding. @@ -279,6 +328,7 @@ func MarshalCapJSON[T any](capRule T) (RawMessage, error) { return RawMessage(string(bs)), nil } +// Node is a Tailscale device in a tailnet. type Node struct { ID NodeID StableID StableNodeID @@ -295,30 +345,48 @@ type Node struct { User UserID // Sharer, if non-zero, is the user who shared this node, if different than User. - Sharer UserID `json:",omitempty"` + Sharer UserID `json:",omitzero"` Key key.NodePublic - KeyExpiry time.Time // the zero value if this node does not expire + KeyExpiry time.Time `json:",omitzero"` // the zero value if this node does not expire KeySignature tkatype.MarshaledSignature `json:",omitempty"` - Machine key.MachinePublic - DiscoKey key.DiscoPublic - Addresses []netip.Prefix // IP addresses of this Node directly - AllowedIPs []netip.Prefix // range of IP addresses to route to this node - Endpoints []netip.AddrPort `json:",omitempty"` // IP+port (public via STUN, and local LANs) + Machine key.MachinePublic `json:",omitzero"` + DiscoKey key.DiscoPublic `json:",omitzero"` - // DERP is this node's home DERP region ID integer, but shoved into an + // Addresses are the IP addresses of this Node directly. + Addresses []netip.Prefix + + // AllowedIPs are the IP ranges to route to this node. + // + // As of CapabilityVersion 112, this may be nil (null or undefined) on the wire + // to mean the same as Addresses. Internally, it is always filled in with + // its possibly-implicit value. + AllowedIPs []netip.Prefix `json:",omitzero"` // _not_ omitempty; only nil is special + + Endpoints []netip.AddrPort `json:",omitempty"` // IP+port (public via STUN, and local LANs) + + // LegacyDERPString is this node's home LegacyDERPString region ID integer, but shoved into an // IP:port string for legacy reasons. The IP address is always "127.3.3.40" // (a loopback address (127) followed by the digits over the letters DERP on - // a QWERTY keyboard (3.3.40)). The "port number" is the home DERP region ID + // a QWERTY keyboard (3.3.40)). The "port number" is the home LegacyDERPString region ID // integer. // - // TODO(bradfitz): simplify this legacy mess; add a new HomeDERPRegionID int - // field behind a new capver bump. - DERP string `json:",omitempty"` // DERP-in-IP:port ("127.3.3.40:N") endpoint + // Deprecated: HomeDERP has replaced this, but old servers might still send + // this field. See tailscale/tailscale#14636. Do not use this field in code + // other than in the upgradeNode func, which canonicalizes it to HomeDERP + // if it arrives as a LegacyDERPString string on the wire. + LegacyDERPString string `json:"DERP,omitzero"` // DERP-in-IP:port ("127.3.3.40:N") endpoint - Hostinfo HostinfoView - Created time.Time - Cap CapabilityVersion `json:",omitempty"` // if non-zero, the node's capability version; old servers might not send + // HomeDERP is the modern version of the DERP string field, with just an + // integer. The client advertises support for this as of capver 111. + // + // HomeDERP may be zero if not (yet) known, but ideally always be non-zero + // for magicsock connectivity to function normally. + HomeDERP int `json:",omitzero"` // DERP region ID of the node's home DERP + + Hostinfo HostinfoView `json:",omitzero"` + Created time.Time `json:",omitzero"` + Cap CapabilityVersion `json:",omitzero"` // if non-zero, the node's capability version; old servers might not send // Tags are the list of ACL tags applied to this node. // Tags take the form of `tag:` where value starts @@ -385,25 +453,25 @@ type Node struct { // it do anything. It is the tailscaled client's job to double-check the // MapResponse's PacketFilter to verify that its AllowedIPs will not be // accepted by the packet filter. - UnsignedPeerAPIOnly bool `json:",omitempty"` + UnsignedPeerAPIOnly bool `json:",omitzero"` // The following three computed fields hold the various names that can // be used for this node in UIs. They are populated from controlclient // (not from control) by calling node.InitDisplayNames. These can be // used directly or accessed via node.DisplayName or node.DisplayNames. - ComputedName string `json:",omitempty"` // MagicDNS base name (for normal non-shared-in nodes), FQDN (without trailing dot, for shared-in nodes), or Hostname (if no MagicDNS) + ComputedName string `json:",omitzero"` // MagicDNS base name (for normal non-shared-in nodes), FQDN (without trailing dot, for shared-in nodes), or Hostname (if no MagicDNS) computedHostIfDifferent string // hostname, if different than ComputedName, otherwise empty - ComputedNameWithHost string `json:",omitempty"` // either "ComputedName" or "ComputedName (computedHostIfDifferent)", if computedHostIfDifferent is set + ComputedNameWithHost string `json:",omitzero"` // either "ComputedName" or "ComputedName (computedHostIfDifferent)", if computedHostIfDifferent is set // DataPlaneAuditLogID is the per-node logtail ID used for data plane audit logging. - DataPlaneAuditLogID string `json:",omitempty"` + DataPlaneAuditLogID string `json:",omitzero"` // Expired is whether this node's key has expired. Control may send // this; clients are only allowed to set this from false to true. On // the client, this is calculated client-side based on a timestamp sent // from control, to avoid clock skew issues. - Expired bool `json:",omitempty"` + Expired bool `json:",omitzero"` // SelfNodeV4MasqAddrForThisPeer is the IPv4 that this peer knows the current node as. // It may be empty if the peer knows the current node by its native @@ -418,7 +486,7 @@ type Node struct { // This only applies to traffic originating from the current node to the // peer or any of its subnets. Traffic originating from subnet routes will // not be masqueraded (e.g. in case of --snat-subnet-routes). - SelfNodeV4MasqAddrForThisPeer *netip.Addr `json:",omitempty"` + SelfNodeV4MasqAddrForThisPeer *netip.Addr `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // SelfNodeV6MasqAddrForThisPeer is the IPv6 that this peer knows the current node as. // It may be empty if the peer knows the current node by its native @@ -433,17 +501,17 @@ type Node struct { // This only applies to traffic originating from the current node to the // peer or any of its subnets. Traffic originating from subnet routes will // not be masqueraded (e.g. in case of --snat-subnet-routes). - SelfNodeV6MasqAddrForThisPeer *netip.Addr `json:",omitempty"` + SelfNodeV6MasqAddrForThisPeer *netip.Addr `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // IsWireGuardOnly indicates that this is a non-Tailscale WireGuard peer, it // is not expected to speak Disco or DERP, and it must have Endpoints in // order to be reachable. - IsWireGuardOnly bool `json:",omitempty"` + IsWireGuardOnly bool `json:",omitzero"` // IsJailed indicates that this node is jailed and should not be allowed // initiate connections, however outbound connections to it should still be // allowed. - IsJailed bool `json:",omitempty"` + IsJailed bool `json:",omitzero"` // ExitNodeDNSResolvers is the list of DNS servers that should be used when this // node is marked IsWireGuardOnly and being used as an exit node. @@ -559,6 +627,11 @@ func (n *Node) InitDisplayNames(networkMagicDNSSuffix string) { n.ComputedNameWithHost = nameWithHost } +// MachineStatus is the state of a [Node]'s approval into a tailnet. +// +// A "node" and a "machine" are often 1:1, but technically a Tailscale +// daemon has one machine key and can have multiple nodes (e.g. different +// users on Windows) for that one machine key. type MachineStatus int const ( @@ -754,10 +827,10 @@ type Location struct { // Because it contains pointers (slices), this type should not be used // as a value type. type Hostinfo struct { - IPNVersion string `json:",omitempty"` // version of this code (in version.Long format) - FrontendLogID string `json:",omitempty"` // logtail ID of frontend instance - BackendLogID string `json:",omitempty"` // logtail ID of backend instance - OS string `json:",omitempty"` // operating system the client runs on (a version.OS value) + IPNVersion string `json:",omitzero"` // version of this code (in version.Long format) + FrontendLogID string `json:",omitzero"` // logtail ID of frontend instance + BackendLogID string `json:",omitzero"` // logtail ID of backend instance + OS string `json:",omitzero"` // operating system the client runs on (a version.OS value) // OSVersion is the version of the OS, if available. // @@ -769,51 +842,166 @@ type Hostinfo struct { // string on Linux, like "Debian 10.4; kernel=xxx; container; env=kn" and so // on. As of Tailscale 1.32, this is simply the kernel version on Linux, like // "5.10.0-17-amd64". - OSVersion string `json:",omitempty"` + OSVersion string `json:",omitzero"` - Container opt.Bool `json:",omitempty"` // whether the client is running in a container - Env string `json:",omitempty"` // a hostinfo.EnvType in string form - Distro string `json:",omitempty"` // "debian", "ubuntu", "nixos", ... - DistroVersion string `json:",omitempty"` // "20.04", ... - DistroCodeName string `json:",omitempty"` // "jammy", "bullseye", ... + Container opt.Bool `json:",omitzero"` // best-effort whether the client is running in a container + Env string `json:",omitzero"` // a hostinfo.EnvType in string form + Distro string `json:",omitzero"` // "debian", "ubuntu", "nixos", ... + DistroVersion string `json:",omitzero"` // "20.04", ... + DistroCodeName string `json:",omitzero"` // "jammy", "bullseye", ... // App is used to disambiguate Tailscale clients that run using tsnet. - App string `json:",omitempty"` // "k8s-operator", "golinks", ... - - Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux - Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) - DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3") - PushDeviceToken string `json:",omitempty"` // macOS/iOS APNs device token for notifications (and Android in the future) - Hostname string `json:",omitempty"` // name of the host the client runs on - ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections - ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user - NoLogsNoSupport bool `json:",omitempty"` // indicates that the user has opted out of sending logs and support - WireIngress bool `json:",omitempty"` // indicates that the node wants the option to receive ingress connections - AllowsUpdate bool `json:",omitempty"` // indicates that the node has opted-in to admin-console-drive remote updates - Machine string `json:",omitempty"` // the current host's machine type (uname -m) - GoArch string `json:",omitempty"` // GOARCH value (of the built binary) - GoArchVar string `json:",omitempty"` // GOARM, GOAMD64, etc (of the built binary) - GoVersion string `json:",omitempty"` // Go version binary was built with + App string `json:",omitzero"` // "k8s-operator", "golinks", ... + + Desktop opt.Bool `json:",omitzero"` // if a desktop was detected on Linux + Package string `json:",omitzero"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) + DeviceModel string `json:",omitzero"` // mobile phone model ("Pixel 3a", "iPhone12,3") + PushDeviceToken string `json:",omitzero"` // macOS/iOS APNs device token for notifications (and Android in the future) + Hostname string `json:",omitzero"` // name of the host the client runs on + ShieldsUp bool `json:",omitzero"` // indicates whether the host is blocking incoming connections + ShareeNode bool `json:",omitzero"` // indicates this node exists in netmap because it's owned by a shared-to user + NoLogsNoSupport bool `json:",omitzero"` // indicates that the user has opted out of sending logs and support + // WireIngress indicates that the node would like to be wired up server-side + // (DNS, etc) to be able to use Tailscale Funnel, even if it's not currently + // enabled. For example, the user might only use it for intermittent + // foreground CLI serve sessions, for which they'd like it to work right + // away, even if it's disabled most of the time. As an optimization, this is + // only sent if IngressEnabled is false, as IngressEnabled implies that this + // option is true. + WireIngress bool `json:",omitzero"` + IngressEnabled bool `json:",omitzero"` // if the node has any funnel endpoint enabled + AllowsUpdate bool `json:",omitzero"` // indicates that the node has opted-in to admin-console-drive remote updates + Machine string `json:",omitzero"` // the current host's machine type (uname -m) + GoArch string `json:",omitzero"` // GOARCH value (of the built binary) + GoArchVar string `json:",omitzero"` // GOARM, GOAMD64, etc (of the built binary) + GoVersion string `json:",omitzero"` // Go version binary was built with RoutableIPs []netip.Prefix `json:",omitempty"` // set of IP ranges this client can route RequestTags []string `json:",omitempty"` // set of ACL tags this node wants to claim WoLMACs []string `json:",omitempty"` // MAC address(es) to send Wake-on-LAN packets to wake this node (lowercase hex w/ colons) Services []Service `json:",omitempty"` // services advertised by this machine - NetInfo *NetInfo `json:",omitempty"` + NetInfo *NetInfo `json:",omitzero"` SSH_HostKeys []string `json:"sshHostKeys,omitempty"` // if advertised - Cloud string `json:",omitempty"` - Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode - UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode - AppConnector opt.Bool `json:",omitempty"` // if the client is running the app-connector service + Cloud string `json:",omitzero"` + Userspace opt.Bool `json:",omitzero"` // if the client is running in userspace (netstack) mode + UserspaceRouter opt.Bool `json:",omitzero"` // if the client's subnet router is running in userspace (netstack) mode + AppConnector opt.Bool `json:",omitzero"` // if the client is running the app-connector service + ServicesHash string `json:",omitzero"` // opaque hash of the most recent list of tailnet services, change in hash indicates config should be fetched via c2n + ExitNodeID StableNodeID `json:",omitzero"` // the client’s selected exit node, empty when unselected. // Location represents geographical location data about a // Tailscale host. Location is optional and only set if // explicitly declared by a node. - Location *Location `json:",omitempty"` + Location *Location `json:",omitzero"` + + TPM *TPMInfo `json:",omitzero"` // TPM device metadata, if available + // StateEncrypted reports whether the node state is stored encrypted on + // disk. The actual mechanism is platform-specific: + // * Apple nodes use the Keychain + // * Linux and Windows nodes use the TPM + // * Android apps use EncryptedSharedPreferences + StateEncrypted opt.Bool `json:",omitzero"` // NOTE: any new fields containing pointers in this type // require changes to Hostinfo.Equal. } +// TPMInfo contains information about a TPM 2.0 device present on a node. +// All fields are read from TPM_CAP_TPM_PROPERTIES, see Part 2, section 6.13 of +// https://trustedcomputinggroup.org/resource/tpm-library-specification/. +type TPMInfo struct { + // Manufacturer is a 4-letter code from section 4.1 of + // https://trustedcomputinggroup.org/resource/vendor-id-registry/, + // for example "MSFT" for Microsoft. + // Read from TPM_PT_MANUFACTURER. + Manufacturer string `json:",omitzero"` + // Vendor is a vendor ID string, up to 16 characters. + // Read from TPM_PT_VENDOR_STRING_*. + Vendor string `json:",omitzero"` + // Model is a vendor-defined TPM model. + // Read from TPM_PT_VENDOR_TPM_TYPE. + Model int `json:",omitzero"` + // FirmwareVersion is the version number of the firmware. + // Read from TPM_PT_FIRMWARE_VERSION_*. + FirmwareVersion uint64 `json:",omitzero"` + // SpecRevision is the TPM 2.0 spec revision encoded as a single number. All + // revisions can be found at + // https://trustedcomputinggroup.org/resource/tpm-library-specification/. + // Before revision 184, TCG used the "01.83" format for revision 183. + SpecRevision int `json:",omitzero"` + + // FamilyIndicator is the TPM spec family, like "2.0". + // Read from TPM_PT_FAMILY_INDICATOR. + FamilyIndicator string `json:",omitzero"` +} + +// Present reports whether a TPM device is present on this machine. +func (t *TPMInfo) Present() bool { return t != nil } + +// ServiceName is the name of a service, of the form `svc:dns-label`. Services +// represent some kind of application provided for users of the tailnet with a +// MagicDNS name and possibly dedicated IP addresses. Currently (2024-01-21), +// the only type of service is [VIPService]. +// This is not related to the older [Service] used in [Hostinfo.Services]. +type ServiceName string + +// AsServiceName reports whether the given string is a valid service name. +// If so returns the name as a [tailcfg.ServiceName], otherwise returns "". +func AsServiceName(s string) ServiceName { + svcName := ServiceName(s) + if err := svcName.Validate(); err != nil { + return "" + } + return svcName +} + +// Validate validates if the service name is formatted correctly. +// We only allow valid DNS labels, since the expectation is that these will be +// used as parts of domain names. All errors are [vizerror.Error]. +func (sn ServiceName) Validate() error { + bareName, ok := strings.CutPrefix(string(sn), "svc:") + if !ok { + return vizerror.Errorf("%q is not a valid service name: must start with 'svc:'", sn) + } + if bareName == "" { + return vizerror.Errorf("%q is not a valid service name: must not be empty after the 'svc:' prefix", sn) + } + return dnsname.ValidLabel(bareName) +} + +// String implements [fmt.Stringer]. +func (sn ServiceName) String() string { + return string(sn) +} + +// WithoutPrefix is the name of the service without the `svc:` prefix, used for +// DNS names. If the name does not include the prefix (which means +// [ServiceName.Validate] would return an error) then it returns "". +func (sn ServiceName) WithoutPrefix() string { + bareName, ok := strings.CutPrefix(string(sn), "svc:") + if !ok { + return "" + } + return bareName +} + +// VIPService represents a service created on a tailnet from the +// perspective of a node providing that service. These services +// have an virtual IP (VIP) address pair distinct from the node's IPs. +type VIPService struct { + // Name is the name of the service. The Name uniquely identifies a service + // on a particular tailnet, and so also corresponds uniquely to the pair of + // IP addresses belonging to the VIP service. + Name ServiceName + + // Ports specify which ProtoPorts are made available by this node + // on the service's IPs. + Ports []ProtoPortRange + + // Active specifies whether new requests for the service should be + // sent to this node by control. + Active bool +} + // TailscaleSSHEnabled reports whether or not this node is acting as a // Tailscale SSH server. func (hi *Hostinfo) TailscaleSSHEnabled() bool { @@ -824,53 +1012,41 @@ func (hi *Hostinfo) TailscaleSSHEnabled() bool { func (v HostinfoView) TailscaleSSHEnabled() bool { return v.Đļ.TailscaleSSHEnabled() } -// TailscaleFunnelEnabled reports whether or not this node has explicitly -// enabled Funnel. -func (hi *Hostinfo) TailscaleFunnelEnabled() bool { - return hi != nil && hi.WireIngress -} - -func (v HostinfoView) TailscaleFunnelEnabled() bool { return v.Đļ.TailscaleFunnelEnabled() } - // NetInfo contains information about the host's network state. type NetInfo struct { // MappingVariesByDestIP says whether the host's NAT mappings // vary based on the destination IP. - MappingVariesByDestIP opt.Bool - - // HairPinning is their router does hairpinning. - // It reports true even if there's no NAT involved. - HairPinning opt.Bool + MappingVariesByDestIP opt.Bool `json:",omitzero"` // WorkingIPv6 is whether the host has IPv6 internet connectivity. - WorkingIPv6 opt.Bool + WorkingIPv6 opt.Bool `json:",omitzero"` // OSHasIPv6 is whether the OS supports IPv6 at all, regardless of // whether IPv6 internet connectivity is available. - OSHasIPv6 opt.Bool + OSHasIPv6 opt.Bool `json:",omitzero"` // WorkingUDP is whether the host has UDP internet connectivity. - WorkingUDP opt.Bool + WorkingUDP opt.Bool `json:",omitzero"` // WorkingICMPv4 is whether ICMPv4 works. // Empty means not checked. - WorkingICMPv4 opt.Bool + WorkingICMPv4 opt.Bool `json:",omitzero"` // HavePortMap is whether we have an existing portmap open // (UPnP, PMP, or PCP). - HavePortMap bool `json:",omitempty"` + HavePortMap bool `json:",omitzero"` // UPnP is whether UPnP appears present on the LAN. // Empty means not checked. - UPnP opt.Bool + UPnP opt.Bool `json:",omitzero"` // PMP is whether NAT-PMP appears present on the LAN. // Empty means not checked. - PMP opt.Bool + PMP opt.Bool `json:",omitzero"` // PCP is whether PCP appears present on the LAN. // Empty means not checked. - PCP opt.Bool + PCP opt.Bool `json:",omitzero"` // PreferredDERP is this node's preferred (home) DERP region ID. // This is where the node expects to be contacted to begin a @@ -879,10 +1055,10 @@ type NetInfo struct { // that are located elsewhere) but PreferredDERP is the region ID // that the node subscribes to traffic at. // Zero means disconnected or unknown. - PreferredDERP int + PreferredDERP int `json:",omitzero"` // LinkType is the current link type, if known. - LinkType string `json:",omitempty"` // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc) + LinkType string `json:",omitzero"` // "wired", "wifi", "mobile" (LTE, 4G, 3G, etc) // DERPLatency is the fastest recent time to reach various // DERP STUN servers, in seconds. The map key is the @@ -900,7 +1076,7 @@ type NetInfo struct { // "{nft,ift}-REASON", like "nft-forced" or "ipt-default". Empty means // either not Linux or a configuration in which the host firewall rules // are not managed by tailscaled. - FirewallMode string `json:",omitempty"` + FirewallMode string `json:",omitzero"` // Update BasicallyEqual when adding fields. } @@ -909,13 +1085,16 @@ func (ni *NetInfo) String() string { if ni == nil { return "NetInfo(nil)" } - return fmt.Sprintf("NetInfo{varies=%v hairpin=%v ipv6=%v ipv6os=%v udp=%v icmpv4=%v derp=#%v portmap=%v link=%q firewallmode=%q}", - ni.MappingVariesByDestIP, ni.HairPinning, ni.WorkingIPv6, + return fmt.Sprintf("NetInfo{varies=%v ipv6=%v ipv6os=%v udp=%v icmpv4=%v derp=#%v portmap=%v link=%q firewallmode=%q}", + ni.MappingVariesByDestIP, ni.WorkingIPv6, ni.OSHasIPv6, ni.WorkingUDP, ni.WorkingICMPv4, ni.PreferredDERP, ni.portMapSummary(), ni.LinkType, ni.FirewallMode) } func (ni *NetInfo) portMapSummary() string { + if !buildfeatures.HasPortMapper { + return "x" + } if !ni.HavePortMap && ni.UPnP == "" && ni.PMP == "" && ni.PCP == "" { return "?" } @@ -950,7 +1129,6 @@ func (ni *NetInfo) BasicallyEqual(ni2 *NetInfo) bool { return true } return ni.MappingVariesByDestIP == ni2.MappingVariesByDestIP && - ni.HairPinning == ni2.HairPinning && ni.WorkingIPv6 == ni2.WorkingIPv6 && ni.OSHasIPv6 == ni2.OSHasIPv6 && ni.WorkingUDP == ni2.WorkingUDP && @@ -975,68 +1153,6 @@ func (h *Hostinfo) Equal(h2 *Hostinfo) bool { return reflect.DeepEqual(h, h2) } -// HowUnequal returns a list of paths through Hostinfo where h and h2 differ. -// If they differ in nil-ness, the path is "nil", otherwise the path is like -// "ShieldsUp" or "NetInfo.nil" or "NetInfo.PCP". -func (h *Hostinfo) HowUnequal(h2 *Hostinfo) (path []string) { - return appendStructPtrDiff(nil, "", reflect.ValueOf(h), reflect.ValueOf(h2)) -} - -func appendStructPtrDiff(base []string, pfx string, p1, p2 reflect.Value) (ret []string) { - ret = base - if p1.IsNil() && p2.IsNil() { - return base - } - mkPath := func(b string) string { - if pfx == "" { - return b - } - return pfx + "." + b - } - if p1.IsNil() || p2.IsNil() { - return append(base, mkPath("nil")) - } - v1, v2 := p1.Elem(), p2.Elem() - t := v1.Type() - for i, n := 0, t.NumField(); i < n; i++ { - sf := t.Field(i) - switch sf.Type.Kind() { - case reflect.String: - if v1.Field(i).String() != v2.Field(i).String() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Bool: - if v1.Field(i).Bool() != v2.Field(i).Bool() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if v1.Field(i).Int() != v2.Field(i).Int() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - if v1.Field(i).Uint() != v2.Field(i).Uint() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Slice, reflect.Map: - if !reflect.DeepEqual(v1.Field(i).Interface(), v2.Field(i).Interface()) { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Ptr: - if sf.Type.Elem().Kind() == reflect.Struct { - ret = appendStructPtrDiff(ret, sf.Name, v1.Field(i), v2.Field(i)) - continue - } - } - panic(fmt.Sprintf("unsupported type at %s: %s", mkPath(sf.Name), sf.Type.String())) - } - return ret -} - // SignatureType specifies a scheme for signing RegisterRequest messages. It // specifies the crypto algorithms to use, the contents of what is signed, and // any other relevant details. Historically, requests were unsigned so the zero @@ -1117,11 +1233,11 @@ type RegisterResponseAuth struct { AuthKey string `json:",omitempty"` } -// RegisterRequest is sent by a client to register the key for a node. -// It is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, -// using the local machine key, and sent to: +// RegisterRequest is a request to register a key for a node. // -// https://login.tailscale.com/machine/ +// This is JSON-encoded and sent over the control plane connection to: +// +// POST https:///machine/register. type RegisterRequest struct { _ structs.Incomparable @@ -1237,10 +1353,9 @@ type Endpoint struct { // The request includes a copy of the client's current set of WireGuard // endpoints and general host information. // -// The request is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, -// using the local machine key, and sent to: +// This is JSON-encoded and sent over the control plane connection to: // -// https://login.tailscale.com/machine//map +// POST https:///machine/map type MapRequest struct { // Version is incremented whenever the client code changes enough that // we want to signal to the control server that we're capable of something @@ -1249,11 +1364,22 @@ type MapRequest struct { // For current values and history, see the CapabilityVersion type's docs. Version CapabilityVersion - Compress string // "zstd" or "" (no compression) - KeepAlive bool // whether server should send keep-alives back to us + Compress string `json:",omitzero"` // "zstd" or "" (no compression) + KeepAlive bool `json:",omitzero"` // whether server should send keep-alives back to us NodeKey key.NodePublic DiscoKey key.DiscoPublic + // HardwareAttestationKey is the public key of the node's hardware-backed + // identity attestation key, if any. + HardwareAttestationKey key.HardwareAttestationPublic `json:",omitzero"` + // HardwareAttestationKeySignature is the signature of + // "$UNIX_TIMESTAMP|$NODE_KEY" using its hardware attestation key, if any. + HardwareAttestationKeySignature []byte `json:",omitempty"` + // HardwareAttestationKeySignatureTimestamp is the time at which the + // HardwareAttestationKeySignature was created, if any. This UNIX timestamp + // value is prepended to the node key when signing. + HardwareAttestationKeySignatureTimestamp time.Time `json:",omitzero"` + // Stream is whether the client wants to receive multiple MapResponses over // the same HTTP connection. // @@ -1262,7 +1388,7 @@ type MapRequest struct { // // If true and Version >= 68, the server should treat this as a read-only // request and ignore any Hostinfo or other fields that might be set. - Stream bool + Stream bool `json:",omitzero"` // Hostinfo is the client's current Hostinfo. Although it is always included // in the request, the server may choose to ignore it when Stream is true @@ -1279,14 +1405,14 @@ type MapRequest struct { // // The server may choose to ignore the request for any reason and start a // new map session. This is only applicable when Stream is true. - MapSessionHandle string `json:",omitempty"` + MapSessionHandle string `json:",omitzero"` // MapSessionSeq is the sequence number in the map session identified by // MapSesssionHandle that was most recently processed by the client. // It is only applicable when MapSessionHandle is specified. // If the server chooses to honor the MapSessionHandle request, only sequence // numbers greater than this value will be returned. - MapSessionSeq int64 `json:",omitempty"` + MapSessionSeq int64 `json:",omitzero"` // Endpoints are the client's magicsock UDP ip:port endpoints (IPv4 or IPv6). // These can be ignored if Stream is true and Version >= 68. @@ -1297,7 +1423,7 @@ type MapRequest struct { // TKAHead describes the hash of the latest AUM applied to the local // tailnet key authority, if one is operating. // It is encoded as tka.AUMHash.MarshalText. - TKAHead string `json:",omitempty"` + TKAHead string `json:",omitzero"` // ReadOnly was set when client just wanted to fetch the MapResponse, // without updating their Endpoints. The intended use was for clients to @@ -1305,7 +1431,7 @@ type MapRequest struct { // update. // // Deprecated: always false as of Version 68. - ReadOnly bool `json:",omitempty"` + ReadOnly bool `json:",omitzero"` // OmitPeers is whether the client is okay with the Peers list being omitted // in the response. @@ -1321,7 +1447,7 @@ type MapRequest struct { // If OmitPeers is true, Stream is false, but ReadOnly is true, // then all the response fields are included. (This is what the client does // when initially fetching the DERP map.) - OmitPeers bool `json:",omitempty"` + OmitPeers bool `json:",omitzero"` // DebugFlags is a list of strings specifying debugging and // development features to enable in handling this map @@ -1336,6 +1462,12 @@ type MapRequest struct { // * "warn-router-unhealthy": client's Router implementation is // having problems. DebugFlags []string `json:",omitempty"` + + // ConnectionHandleForTest, if non-empty, is an opaque string sent by the client that + // identifies this specific connection to the server. The server may choose to + // use this handle to identify the connection for debugging or testing + // purposes. It has no semantic meaning. + ConnectionHandleForTest string `json:",omitzero"` } // PortRange represents a range of UDP or TCP port numbers. @@ -1351,11 +1483,20 @@ func (pr PortRange) Contains(port uint16) bool { var PortRangeAny = PortRange{0, 65535} +func (pr PortRange) String() string { + if pr.First == pr.Last { + return strconv.FormatUint(uint64(pr.First), 10) + } else if pr == PortRangeAny { + return "*" + } + return fmt.Sprintf("%d-%d", pr.First, pr.Last) +} + // NetPortRange represents a range of ports that's allowed for one or more IPs. type NetPortRange struct { _ structs.Incomparable IP string // IP, CIDR, Range, or "*" (same formats as FilterRule.SrcIPs) - Bits *int // deprecated; the 2020 way to turn IP into a CIDR. See FilterRule.SrcBits. + Bits *int `json:",omitempty"` // deprecated; the 2020 way to turn IP into a CIDR. See FilterRule.SrcBits. Ports PortRange } @@ -1413,11 +1554,23 @@ const ( // user groups as Kubernetes user groups. This capability is read by // peers that are Tailscale Kubernetes operator instances. PeerCapabilityKubernetes PeerCapability = "tailscale.com/cap/kubernetes" + + // PeerCapabilityRelay grants the ability for a peer to allocate relay + // endpoints. + PeerCapabilityRelay PeerCapability = "tailscale.com/cap/relay" + // PeerCapabilityRelayTarget grants the current node the ability to allocate + // relay endpoints to the peer which has this capability. + PeerCapabilityRelayTarget PeerCapability = "tailscale.com/cap/relay-target" + + // PeerCapabilityTsIDP grants a peer tsidp-specific + // capabilities, such as the ability to add user groups to the OIDC + // claim + PeerCapabilityTsIDP PeerCapability = "tailscale.com/cap/tsidp" ) // NodeCapMap is a map of capabilities to their optional values. It is valid for // a capability to have no values (nil slice); such capabilities can be tested -// for by using the Contains method. +// for by using the [NodeCapMap.Contains] method. // // See [NodeCapability] for more information on keys. type NodeCapMap map[NodeCapability][]RawMessage @@ -1431,12 +1584,19 @@ func (c NodeCapMap) Equal(c2 NodeCapMap) bool { // If cap does not exist in cm, it returns (nil, nil). // It returns an error if the values cannot be unmarshaled into the provided type. func UnmarshalNodeCapJSON[T any](cm NodeCapMap, cap NodeCapability) ([]T, error) { - vals, ok := cm[cap] + return UnmarshalNodeCapViewJSON[T](views.MapSliceOf(cm), cap) +} + +// UnmarshalNodeCapViewJSON unmarshals each JSON value in cm.Get(cap) as T. +// If cap does not exist in cm, it returns (nil, nil). +// It returns an error if the values cannot be unmarshaled into the provided type. +func UnmarshalNodeCapViewJSON[T any](cm views.MapSlice[NodeCapability, RawMessage], cap NodeCapability) ([]T, error) { + vals, ok := cm.GetOk(cap) if !ok { return nil, nil } - out := make([]T, 0, len(vals)) - for _, v := range vals { + out := make([]T, 0, vals.Len()) + for _, v := range vals.All() { var t T if err := json.Unmarshal([]byte(v), &t); err != nil { return nil, err @@ -1466,12 +1626,19 @@ type PeerCapMap map[PeerCapability][]RawMessage // If cap does not exist in cm, it returns (nil, nil). // It returns an error if the values cannot be unmarshaled into the provided type. func UnmarshalCapJSON[T any](cm PeerCapMap, cap PeerCapability) ([]T, error) { - vals, ok := cm[cap] + return UnmarshalCapViewJSON[T](views.MapSliceOf(cm), cap) +} + +// UnmarshalCapViewJSON unmarshals each JSON value in cm.Get(cap) as T. +// If cap does not exist in cm, it returns (nil, nil). +// It returns an error if the values cannot be unmarshaled into the provided type. +func UnmarshalCapViewJSON[T any](cm views.MapSlice[PeerCapability, RawMessage], cap PeerCapability) ([]T, error) { + vals, ok := cm.GetOk(cap) if !ok { return nil, nil } - out := make([]T, 0, len(vals)) - for _, v := range vals { + out := make([]T, 0, vals.Len()) + for _, v := range vals.All() { var t T if err := json.Unmarshal([]byte(v), &t); err != nil { return nil, err @@ -1591,12 +1758,11 @@ type DNSConfig struct { // in the network map, aka MagicDNS. // Despite the (legacy) name, does not necessarily cause request // proxying to be enabled. - Proxied bool `json:",omitempty"` - - // The following fields are only set and used by - // MapRequest.Version >=9 and <14. + Proxied bool `json:",omitzero"` - // Nameservers are the IP addresses of the nameservers to use. + // Nameservers are the IP addresses of the global nameservers to use. + // + // Deprecated: this is only set and used by MapRequest.Version >=9 and <14. Use Resolvers instead. Nameservers []netip.Addr `json:",omitempty"` // CertDomains are the set of DNS names for which the control @@ -1629,7 +1795,7 @@ type DNSConfig struct { // TempCorpIssue13969 is a temporary (2023-08-16) field for an internal hack day prototype. // It contains a user inputed URL that should have a list of domains to be blocked. // See https://github.com/tailscale/corp/issues/13969. - TempCorpIssue13969 string `json:",omitempty"` + TempCorpIssue13969 string `json:",omitzero"` } // DNSRecord is an extra DNS record to add to MagicDNS. @@ -1641,7 +1807,7 @@ type DNSRecord struct { // Type is the DNS record type. // Empty means A or AAAA, depending on value. // Other values are currently ignored. - Type string `json:",omitempty"` + Type string `json:",omitzero"` // Value is the IP address in string form. // TODO(bradfitz): if we ever add support for record types @@ -1666,9 +1832,14 @@ const ( PingPeerAPI PingType = "peerapi" ) -// PingRequest with no IP and Types is a request to send an HTTP request to prove the -// long-polling client is still connected. -// PingRequest with Types and IP, will send a ping to the IP and send a POST +// PingRequest is a request from the control plane to the local node to probe +// something. +// +// A PingRequest with no IP and Types is a request from the control plane to the +// local node to send an HTTP request to a URL to prove the long-polling client +// is still connected. +// +// A PingRequest with Types and IP, will send a ping to the IP and send a POST // request containing a PingResponse to the URL containing results. type PingRequest struct { // URL is the URL to reply to the PingRequest to. @@ -1684,11 +1855,11 @@ type PingRequest struct { // URLIsNoise, if true, means that the client should hit URL over the Noise // transport instead of TLS. - URLIsNoise bool `json:",omitempty"` + URLIsNoise bool `json:",omitzero"` // Log is whether to log about this ping in the success case. // For failure cases, the client will log regardless. - Log bool `json:",omitempty"` + Log bool `json:",omitzero"` // Types is the types of ping that are initiated. Can be any PingType, comma // separated, e.g. "disco,TSMP" @@ -1698,10 +1869,10 @@ type PingRequest struct { // node's c2n handler and the HTTP response sent in a POST to URL. For c2n, // the value of URLIsNoise is ignored and only the Noise transport (back to // the control plane) will be used, as if URLIsNoise were true. - Types string `json:",omitempty"` + Types string `json:",omitzero"` // IP is the ping target, when needed by the PingType(s) given in Types. - IP netip.Addr + IP netip.Addr `json:",omitzero"` // Payload is the ping payload. // @@ -1729,10 +1900,14 @@ type PingResponse struct { // omitted, Err should contain information as to the cause. LatencySeconds float64 `json:",omitempty"` - // Endpoint is the ip:port if direct UDP was used. - // It is not currently set for TSMP pings. + // Endpoint is a string of the form "{ip}:{port}" if direct UDP was used. It + // is not currently set for TSMP. Endpoint string `json:",omitempty"` + // PeerRelay is a string of the form "{ip}:{port}:vni:{vni}" if a peer + // relay was used. It is not currently set for TSMP. + PeerRelay string `json:",omitempty"` + // DERPRegionID is non-zero DERP region ID if DERP was used. // It is not currently set for TSMP pings. DERPRegionID int `json:",omitempty"` @@ -1833,7 +2008,7 @@ type MapResponse struct { // PeersChangedPatch, if non-nil, means that node(s) have changed. // This is a lighter version of the older PeersChanged support that - // only supports certain types of updates + // only supports certain types of updates. // // These are applied after Peers* above, but in practice the // control server should only send these on their own, without @@ -1914,13 +2089,31 @@ type MapResponse struct { // plane's perspective. A nil value means no change from the previous // MapResponse. A non-nil 0-length slice restores the health to good (no // known problems). A non-zero length slice are the list of problems that - // the control place sees. + // the control plane sees. + // + // Either this will be set, or DisplayMessages will be set, but not both. // // Note that this package's type, due its use of a slice and omitempty, is // unable to marshal a zero-length non-nil slice. The control server needs // to marshal this type using a separate type. See MapResponse docs. Health []string `json:",omitempty"` + // DisplayMessages sets the health state of the node from the control + // plane's perspective. + // + // Either this will be set, or Health will be set, but not both. + // + // The map keys are IDs that uniquely identify the type of health issue. The + // map values are the messages. If the server sends down a map with entries, + // the client treats it as a patch: new entries are added, keys with a value + // of nil are deleted, existing entries with new values are updated. A nil + // map and an empty map both mean no change has occurred since the last + // update. + // + // As a special case, the map key "*" with a value of nil means to clear all + // prior display messages before processing the other map entries. + DisplayMessages map[DisplayMessageID]*DisplayMessage `json:",omitempty"` + // SSHPolicy, if non-nil, updates the SSH policy for how incoming // SSH connections should be handled. SSHPolicy *SSHPolicy `json:",omitempty"` @@ -1962,12 +2155,90 @@ type MapResponse struct { // auto-update setting doesn't change if the tailnet admin flips the // default after the node registered. DefaultAutoUpdate opt.Bool `json:",omitempty"` +} - // MaxKeyDuration describes the MaxKeyDuration setting for the tailnet. - // If zero, the value is unchanged. - MaxKeyDuration time.Duration `json:",omitempty"` +// DisplayMessage represents a health state of the node from the control plane's +// perspective. It is deliberately similar to [health.Warnable] as both get +// converted into [health.UnhealthyState] to be sent to the GUI. +type DisplayMessage struct { + // Title is a string that the GUI uses as title for this message. The title + // should be short and fit in a single line. It should not end in a period. + // + // Example: "Network may be blocking Tailscale". + // + // See the various instantiations of [health.Warnable] for more examples. + Title string + + // Text is an extended string that the GUI will display to the user. This + // could be multiple sentences explaining the issue in more detail. + // + // Example: "macOS Screen Time seems to be blocking Tailscale. Try disabling + // Screen Time in System Settings > Screen Time > Content & Privacy > Access + // to Web Content." + // + // See the various instantiations of [health.Warnable] for more examples. + Text string + + // Severity is the severity of the DisplayMessage, which the GUI can use to + // determine how to display it. Maps to [health.Severity]. + Severity DisplayMessageSeverity + + // ImpactsConnectivity is whether the health problem will impact the user's + // ability to connect to the Internet or other nodes on the tailnet, which + // the GUI can use to determine how to display it. + ImpactsConnectivity bool `json:",omitempty"` + + // Primary action, if present, represents the action to allow the user to + // take when interacting with this message. For example, if the + // DisplayMessage is shown via a notification, the action label might be a + // button on that notification and clicking the button would open the URL. + PrimaryAction *DisplayMessageAction `json:",omitempty"` +} + +// DisplayMessageAction represents an action (URL and link) to be presented to +// the user associated with a [DisplayMessage]. +type DisplayMessageAction struct { + // URL is the URL to navigate to when the user interacts with this action + URL string + + // Label is the call to action for the UI to display on the UI element that + // will open the URL (such as a button or link). For example, "Sign in" or + // "Learn more". + Label string +} + +// DisplayMessageID is a string that uniquely identifies the kind of health +// issue (e.g. "session-expired"). +type DisplayMessageID string + +// Equal returns true iff all fields are equal. +func (m DisplayMessage) Equal(o DisplayMessage) bool { + return m.Title == o.Title && + m.Text == o.Text && + m.Severity == o.Severity && + m.ImpactsConnectivity == o.ImpactsConnectivity && + (m.PrimaryAction == nil) == (o.PrimaryAction == nil) && + (m.PrimaryAction == nil || (m.PrimaryAction.URL == o.PrimaryAction.URL && + m.PrimaryAction.Label == o.PrimaryAction.Label)) } +// DisplayMessageSeverity represents how serious a [DisplayMessage] is. Analogous +// to health.Severity. +type DisplayMessageSeverity string + +const ( + // SeverityHigh is the highest severity level, used for critical errors that need immediate attention. + // On platforms where the client GUI can deliver notifications, a SeverityHigh message will trigger + // a modal notification. + SeverityHigh DisplayMessageSeverity = "high" + // SeverityMedium is used for errors that are important but not critical. This won't trigger a modal + // notification, however it will be displayed in a more visible way than a SeverityLow message. + SeverityMedium DisplayMessageSeverity = "medium" + // SeverityLow is used for less important notices that don't need immediate attention. The user will + // have to go to a Settings window, or another "hidden" GUI location to see these messages. + SeverityLow DisplayMessageSeverity = "low" +) + // ClientVersion is information about the latest client version that's available // for the client (and whether they're already running it). // @@ -2013,7 +2284,14 @@ type ControlDialPlan struct { // connecting to the control server. type ControlIPCandidate struct { // IP is the address to attempt connecting to. - IP netip.Addr + IP netip.Addr `json:",omitzero"` + + // ACEHost, if non-empty, means that the client should connect to the + // control plane using an HTTPS CONNECT request to the provided hostname. If + // the IP field is also set, then the IP is the IP address of the ACEHost + // (and not the control plane) and DNS should not be used. The target (the + // argument to CONNECT) is always the control plane's hostname, not an IP. + ACEHost string `json:",omitempty"` // DialStartSec is the number of seconds after the beginning of the // connection process to wait before trying this candidate. @@ -2054,10 +2332,10 @@ type Debug struct { Exit *int `json:",omitempty"` } -func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) } -func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) } -func (id LoginID) String() string { return fmt.Sprintf("loginid:%x", int64(id)) } -func (id NodeID) String() string { return fmt.Sprintf("nodeid:%x", int64(id)) } +func (id ID) String() string { return fmt.Sprintf("id:%d", int64(id)) } +func (id UserID) String() string { return fmt.Sprintf("userid:%d", int64(id)) } +func (id LoginID) String() string { return fmt.Sprintf("loginid:%d", int64(id)) } +func (id NodeID) String() string { return fmt.Sprintf("nodeid:%d", int64(id)) } // Equal reports whether n and n2 are equal. func (n *Node) Equal(n2 *Node) bool { @@ -2081,7 +2359,8 @@ func (n *Node) Equal(n2 *Node) bool { slicesx.EqualSameNil(n.AllowedIPs, n2.AllowedIPs) && slicesx.EqualSameNil(n.PrimaryRoutes, n2.PrimaryRoutes) && slicesx.EqualSameNil(n.Endpoints, n2.Endpoints) && - n.DERP == n2.DERP && + n.LegacyDERPString == n2.LegacyDERPString && + n.HomeDERP == n2.HomeDERP && n.Cap == n2.Cap && n.Hostinfo.Equal(n2.Hostinfo) && n.Created.Equal(n2.Created) && @@ -2155,12 +2434,16 @@ type NodeCapability string const ( CapabilityFileSharing NodeCapability = "https://tailscale.com/cap/file-sharing" CapabilityAdmin NodeCapability = "https://tailscale.com/cap/is-admin" + CapabilityOwner NodeCapability = "https://tailscale.com/cap/is-owner" CapabilitySSH NodeCapability = "https://tailscale.com/cap/ssh" // feature enabled/available CapabilitySSHRuleIn NodeCapability = "https://tailscale.com/cap/ssh-rule-in" // some SSH rule reach this node CapabilityDataPlaneAuditLogs NodeCapability = "https://tailscale.com/cap/data-plane-audit-logs" // feature enabled CapabilityDebug NodeCapability = "https://tailscale.com/cap/debug" // exposes debug endpoints over the PeerAPI CapabilityHTTPS NodeCapability = "https" + // CapabilityMacUIV2 makes the macOS GUI enable its v2 mode. + CapabilityMacUIV2 NodeCapability = "https://tailscale.com/cap/mac-ui-v2" + // CapabilityBindToInterfaceByRoute changes how Darwin nodes create // sockets (in the net/netns package). See that package for more // details on the behaviour of this capability. @@ -2177,6 +2460,10 @@ const ( // of connections to the default network interface on Darwin nodes. CapabilityDebugDisableBindConnToInterface NodeCapability = "https://tailscale.com/cap/debug-disable-bind-conn-to-interface" + // CapabilityDebugDisableBindConnToInterface disables the automatic binding + // of connections to the default network interface on Darwin nodes using network extensions + CapabilityDebugDisableBindConnToInterfaceAppleExt NodeCapability = "https://tailscale.com/cap/debug-disable-bind-conn-to-interface-apple-ext" + // CapabilityTailnetLock indicates the node may initialize tailnet lock. CapabilityTailnetLock NodeCapability = "https://tailscale.com/cap/tailnet-lock" @@ -2276,8 +2563,19 @@ const ( // This cannot be set simultaneously with NodeAttrLinuxMustUseIPTables. NodeAttrLinuxMustUseNfTables NodeCapability = "linux-netfilter?v=nftables" - // NodeAttrSeamlessKeyRenewal makes clients enable beta functionality - // of renewing node keys without breaking connections. + // NodeAttrDisableSeamlessKeyRenewal disables seamless key renewal, which is + // enabled by default in clients as of 2025-09-17 (1.90 and later). + // + // We will use this attribute to manage the rollout, and disable seamless in + // clients with known bugs. + // http://go/seamless-key-renewal + NodeAttrDisableSeamlessKeyRenewal NodeCapability = "disable-seamless-key-renewal" + + // NodeAttrSeamlessKeyRenewal was used to opt-in to seamless key renewal + // during its private alpha. + // + // Deprecated: NodeAttrSeamlessKeyRenewal is deprecated as of CapabilityVersion 126, + // because seamless key renewal is now enabled by default. NodeAttrSeamlessKeyRenewal NodeCapability = "seamless-key-renewal" // NodeAttrProbeUDPLifetime makes the client probe UDP path lifetime at the @@ -2347,26 +2645,93 @@ const ( // NodeAttrDisableMagicSockCryptoRouting disables the use of the // magicsock cryptorouting hook. See tailscale/corp#20732. + // + // Deprecated: NodeAttrDisableMagicSockCryptoRouting is deprecated as of + // CapabilityVersion 124, CryptoRouting is now mandatory. See tailscale/corp#31083. NodeAttrDisableMagicSockCryptoRouting NodeCapability = "disable-magicsock-crypto-routing" // NodeAttrDisableCaptivePortalDetection instructs the client to not perform captive portal detection // automatically when the network state changes. NodeAttrDisableCaptivePortalDetection NodeCapability = "disable-captive-portal-detection" + // NodeAttrDisableSkipStatusQueue is set when the node should disable skipping + // of queued netmap.NetworkMap between the controlclient and LocalBackend. + // See tailscale/tailscale#14768. + NodeAttrDisableSkipStatusQueue NodeCapability = "disable-skip-status-queue" + // NodeAttrSSHEnvironmentVariables enables logic for handling environment variables sent // via SendEnv in the SSH server and applying them to the SSH session. NodeAttrSSHEnvironmentVariables NodeCapability = "ssh-env-vars" + + // NodeAttrServiceHost indicates the VIP Services for which the client is + // approved to act as a service host, and which IP addresses are assigned + // to those VIP Services. Any VIP Services that the client is not + // advertising can be ignored. + // Each value of this key in [NodeCapMap] is of type [ServiceIPMappings]. + // If multiple values of this key exist, they should be merged in sequence + // (replace conflicting keys). + NodeAttrServiceHost NodeCapability = "service-host" + + // NodeAttrMaxKeyDuration represents the MaxKeyDuration setting on the + // tailnet. The value of this key in [NodeCapMap] will be only one entry of + // type float64 representing the duration in seconds. This cap will be + // omitted if the tailnet's MaxKeyDuration is the default. + NodeAttrMaxKeyDuration NodeCapability = "tailnet.maxKeyDuration" + + // NodeAttrNativeIPV4 contains the IPV4 address of the node in its + // native tailnet. This is currently only sent to Hello, in its + // peer node list. + NodeAttrNativeIPV4 NodeCapability = "native-ipv4" + + // NodeAttrDisableRelayServer prevents the node from acting as an underlay + // UDP relay server. There are no expected values for this key; the key + // only needs to be present in [NodeCapMap] to take effect. + NodeAttrDisableRelayServer NodeCapability = "disable-relay-server" + + // NodeAttrDisableRelayClient prevents the node from both allocating UDP + // relay server endpoints itself, and from using endpoints allocated by + // its peers. This attribute can be added to the node dynamically; if added + // while the node is already running, the node will be unable to allocate + // endpoints after it next updates its network map, and will be immediately + // unable to use new paths via a UDP relay server. Setting this attribute + // dynamically does not remove any existing paths, including paths that + // traverse a UDP relay server. There are no expected values for this key + // in [NodeCapMap]; the key only needs to be present in [NodeCapMap] to + // take effect. + NodeAttrDisableRelayClient NodeCapability = "disable-relay-client" + + // NodeAttrMagicDNSPeerAAAA is a capability that tells the node's MagicDNS + // server to answer AAAA queries about its peers. See tailscale/tailscale#1152. + NodeAttrMagicDNSPeerAAAA NodeCapability = "magicdns-aaaa" + + // NodeAttrTrafficSteering configures the node to use the traffic + // steering subsystem for via routes. See tailscale/corp#29966. + NodeAttrTrafficSteering NodeCapability = "traffic-steering" + + // NodeAttrTailnetDisplayName is an optional alternate name for the tailnet + // to be displayed to the user. + // If empty or absent, a default is used. + // If this value is present and set by a user this will only include letters, + // numbers, apostrophe, spaces, and hyphens. This may not be true for the default. + // Values can look like "foo.com" or "Foo's Test Tailnet - Staging". + NodeAttrTailnetDisplayName NodeCapability = "tailnet-display-name" + + // NodeAttrClientSideReachability configures the node to determine + // reachability itself when choosing connectors. When absent, the + // default behavior is to trust the control plane when it claims that a + // node is no longer online, but that is not a reliable signal. + NodeAttrClientSideReachability = "client-side-reachability" ) // SetDNSRequest is a request to add a DNS record. // -// This is used for ACME DNS-01 challenges (so people can use -// LetsEncrypt, etc). +// This is used to let tailscaled clients complete their ACME DNS-01 challenges +// (so people can use LetsEncrypt, etc) to get TLS certificates for +// their foo.bar.ts.net MagicDNS names. // -// The request is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, -// using the local machine key, and sent to: +// This is JSON-encoded and sent over the control plane connection to: // -// https://login.tailscale.com/machine//set-dns +// POST https:///machine/set-dns type SetDNSRequest struct { // Version is the client's capabilities // (CurrentCapabilityVersion) when using the Noise transport. @@ -2396,7 +2761,12 @@ type SetDNSRequest struct { type SetDNSResponse struct{} // HealthChangeRequest is the JSON request body type used to report -// node health changes to https:///machine//update-health. +// node health changes to: +// +// POST https:///machine/update-health. +// +// As of 2025-10-02, we stopped sending this to the control plane proactively. +// It was never useful enough with its current design and needs more thought. type HealthChangeRequest struct { Subsys string // a health.Subsystem value in string form Error string // or empty if cleared @@ -2406,6 +2776,38 @@ type HealthChangeRequest struct { NodeKey key.NodePublic } +// SetDeviceAttributesRequest is a request to update the +// current node's device posture attributes. +// +// As of 2024-12-30, this is an experimental dev feature +// for internal testing. See tailscale/corp#24690. +// +// This is JSON-encoded and sent over the control plane connection to: +// +// PATCH https:///machine/set-device-attr +type SetDeviceAttributesRequest struct { + // Version is the current binary's [CurrentCapabilityVersion]. + Version CapabilityVersion + + // NodeKey identifies the node to modify. It should be the currently active + // node and is an error if not. + NodeKey key.NodePublic + + // Update is a map of device posture attributes to update. + // Attributes not in the map are left unchanged. + Update AttrUpdate +} + +// AttrUpdate is a map of attributes to update. +// Attributes not in the map are left unchanged. +// The value can be a string, float64, bool, or nil to delete. +// +// See https://tailscale.com/s/api-device-posture-attrs. +// +// TODO(bradfitz): add struct type for specifying optional associated data +// for each attribute value, like an expiry time? +type AttrUpdate map[string]any + // SSHPolicy is the policy for how to handle incoming SSH connections // over Tailscale. type SSHPolicy struct { @@ -2481,16 +2883,13 @@ type SSHPrincipal struct { Any bool `json:"any,omitempty"` // if true, match any connection // TODO(bradfitz): add StableUserID, once that exists - // PubKeys, if non-empty, means that this SSHPrincipal only - // matches if one of these public keys is presented by the user. + // UnusedPubKeys was public key support. It never became an official product + // feature and so as of 2024-12-12 is being removed. + // This stub exists to remind us not to re-use the JSON field name "pubKeys" + // in the future if we bring it back with different semantics. // - // As a special case, if len(PubKeys) == 1 and PubKeys[0] starts - // with "https://", then it's fetched (like https://github.com/username.keys). - // In that case, the following variable expansions are also supported - // in the URL: - // * $LOGINNAME_EMAIL ("foo@bar.com" or "foo@github") - // * $LOGINNAME_LOCALPART (the "foo" from either of the above) - PubKeys []string `json:"pubKeys,omitempty"` + // Deprecated: do not use. It does nothing. + UnusedPubKeys []string `json:"pubKeys,omitempty"` } // SSHAction is how to handle an incoming connection. @@ -2512,7 +2911,7 @@ type SSHAction struct { // SessionDuration, if non-zero, is how long the session can stay open // before being forcefully terminated. - SessionDuration time.Duration `json:"sessionDuration,omitempty"` + SessionDuration time.Duration `json:"sessionDuration,omitempty,format:nano"` // AllowAgentForwarding, if true, allows accepted connections to forward // the ssh agent if requested. @@ -2575,6 +2974,8 @@ type SSHRecorderFailureAction struct { // SSHEventNotifyRequest is the JSON payload sent to the NotifyURL // for an SSH event. +// +// POST https:///[...varies, sent in SSH policy...] type SSHEventNotifyRequest struct { // EventType is the type of notify request being sent. EventType SSHEventType @@ -2635,36 +3036,36 @@ type SSHRecordingAttempt struct { FailureMessage string } -// QueryFeatureRequest is a request sent to "/machine/feature/query" -// to get instructions on how to enable a feature, such as Funnel, -// for the node's tailnet. +// QueryFeatureRequest is a request sent to "POST /machine/feature/query" to get +// instructions on how to enable a feature, such as Funnel, for the node's +// tailnet. // // See QueryFeatureResponse for response structure. type QueryFeatureRequest struct { // Feature is the string identifier for a feature. - Feature string `json:",omitempty"` + Feature string `json:",omitzero"` // NodeKey is the client's current node key. - NodeKey key.NodePublic `json:",omitempty"` + NodeKey key.NodePublic `json:",omitzero"` } // QueryFeatureResponse is the response to an QueryFeatureRequest. // See cli.enableFeatureInteractive for usage. type QueryFeatureResponse struct { // Complete is true when the feature is already enabled. - Complete bool `json:",omitempty"` + Complete bool `json:",omitzero"` // Text holds lines to display in the CLI with information // about the feature and how to enable it. // // Lines are separated by newline characters. The final // newline may be omitted. - Text string `json:",omitempty"` + Text string `json:",omitzero"` // URL is the link for the user to visit to take action on // enabling the feature. // // When empty, there is no action for this user to take. - URL string `json:",omitempty"` + URL string `json:",omitzero"` // ShouldWait specifies whether the CLI should block and // wait for the user to enable the feature. @@ -2677,7 +3078,7 @@ type QueryFeatureResponse struct { // // The CLI can watch the IPN notification bus for changes in // required node capabilities to know when to continue. - ShouldWait bool `json:",omitempty"` + ShouldWait bool `json:",omitzero"` } // WebClientAuthResponse is the response to a web client authentication request @@ -2687,15 +3088,15 @@ type WebClientAuthResponse struct { // ID is a unique identifier for the session auth request. // It can be supplied to "/machine/webclient/wait" to pause until // the session authentication has been completed. - ID string `json:",omitempty"` + ID string `json:",omitzero"` // URL is the link for the user to visit to authenticate the session. // // When empty, there is no action for the user to take. - URL string `json:",omitempty"` + URL string `json:",omitzero"` // Complete is true when the session authentication has been completed. - Complete bool `json:",omitempty"` + Complete bool `json:",omitzero"` } // OverTLSPublicKeyResponse is the JSON response to /key?v= @@ -2726,7 +3127,7 @@ type OverTLSPublicKeyResponse struct { // The token can be presented to any resource provider which offers OIDC // Federation. // -// It is JSON-encoded and sent over Noise to "/machine/id-token". +// It is JSON-encoded and sent over Noise to "POST /machine/id-token". type TokenRequest struct { // CapVersion is the client's current CapabilityVersion. CapVersion CapabilityVersion @@ -2771,10 +3172,10 @@ type PeerChange struct { // DERPRegion, if non-zero, means that NodeID's home DERP // region ID is now this number. - DERPRegion int `json:",omitempty"` + DERPRegion int `json:",omitzero"` // Cap, if non-zero, means that NodeID's capability version has changed. - Cap CapabilityVersion `json:",omitempty"` + Cap CapabilityVersion `json:",omitzero"` // CapMap, if non-nil, means that NodeID's capability map has changed. CapMap NodeCapMap `json:",omitempty"` @@ -2784,23 +3185,23 @@ type PeerChange struct { Endpoints []netip.AddrPort `json:",omitempty"` // Key, if non-nil, means that the NodeID's wireguard public key changed. - Key *key.NodePublic `json:",omitempty"` + Key *key.NodePublic `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // KeySignature, if non-nil, means that the signature of the wireguard // public key has changed. KeySignature tkatype.MarshaledSignature `json:",omitempty"` // DiscoKey, if non-nil, means that the NodeID's discokey changed. - DiscoKey *key.DiscoPublic `json:",omitempty"` + DiscoKey *key.DiscoPublic `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // Online, if non-nil, means that the NodeID's online status changed. - Online *bool `json:",omitempty"` + Online *bool `json:",omitzero"` // LastSeen, if non-nil, means that the NodeID's online status changed. - LastSeen *time.Time `json:",omitempty"` + LastSeen *time.Time `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 // KeyExpiry, if non-nil, changes the NodeID's key expiry. - KeyExpiry *time.Time `json:",omitempty"` + KeyExpiry *time.Time `json:",omitzero"` // TODO: de-pointer: tailscale/tailscale#17978 } // DerpMagicIP is a fake WireGuard endpoint IP address that means to @@ -2841,3 +3242,51 @@ type EarlyNoise struct { // For some request types, the header may have multiple values. (e.g. OldNodeKey // vs NodeKey) const LBHeader = "Ts-Lb" + +// ServiceIPMappings maps ServiceName to lists of IP addresses. This is used +// as the value of the [NodeAttrServiceHost] capability, to inform service hosts +// what IP addresses they need to listen on for each service that they are +// advertising. +// +// This is of the form: +// +// { +// "svc:samba": ["100.65.32.1", "fd7a:115c:a1e0::1234"], +// "svc:web": ["100.102.42.3", "fd7a:115c:a1e0::abcd"], +// } +// +// where the IP addresses are the IPs of the VIP services. These IPs are also +// provided in AllowedIPs, but this lets the client know which services +// correspond to those IPs. Any services that don't correspond to a service +// this client is hosting can be ignored. +type ServiceIPMappings map[ServiceName][]netip.Addr + +// ClientAuditAction represents an auditable action that a client can report to the +// control plane. These actions must correspond to the supported actions +// in the control plane. +type ClientAuditAction string + +const ( + // AuditNodeDisconnect action is sent when a node has disconnected + // from the control plane. The details must include a reason in the Details + // field, either generated, or entered by the user. + AuditNodeDisconnect = ClientAuditAction("DISCONNECT_NODE") +) + +// AuditLogRequest represents an audit log request to be sent to the control plane. +// +// This is JSON-encoded and sent over the control plane connection to: +// POST https:///machine/audit-log +type AuditLogRequest struct { + // Version is the client's current CapabilityVersion. + Version CapabilityVersion `json:",omitzero"` + // NodeKey is the client's current node key. + NodeKey key.NodePublic `json:",omitzero"` + // Action is the action to be logged. It must correspond to a known action in the control plane. + Action ClientAuditAction `json:",omitzero"` + // Details is an opaque string, specific to the action being logged. Empty strings may not + // be valid depending on the action being logged. + Details string `json:",omitzero"` + // Timestamp is the time at which the audit log was generated on the node. + Timestamp time.Time `json:",omitzero"` +} diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 61564f3f8..751b7c288 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -26,17 +26,14 @@ func (src *User) Clone() *User { } dst := new(User) *dst = *src - dst.Logins = append(src.Logins[:0:0], src.Logins...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _UserCloneNeedsRegeneration = User(struct { ID UserID - LoginName string DisplayName string ProfilePicURL string - Logins []LoginID Created time.Time }{}) @@ -102,7 +99,8 @@ var _NodeCloneNeedsRegeneration = Node(struct { Addresses []netip.Prefix AllowedIPs []netip.Prefix Endpoints []netip.AddrPort - DERP string + LegacyDERPString string + HomeDERP int Hostinfo HostinfoView Created time.Time Cap CapabilityVersion @@ -143,6 +141,9 @@ func (src *Hostinfo) Clone() *Hostinfo { if dst.Location != nil { dst.Location = ptr.To(*src.Location) } + if dst.TPM != nil { + dst.TPM = ptr.To(*src.TPM) + } return dst } @@ -168,6 +169,7 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { ShareeNode bool NoLogsNoSupport bool WireIngress bool + IngressEnabled bool AllowsUpdate bool Machine string GoArch string @@ -183,7 +185,11 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { Userspace opt.Bool UserspaceRouter opt.Bool AppConnector opt.Bool + ServicesHash string + ExitNodeID StableNodeID Location *Location + TPM *TPMInfo + StateEncrypted opt.Bool }{}) // Clone makes a deep copy of NetInfo. @@ -201,7 +207,6 @@ func (src *NetInfo) Clone() *NetInfo { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _NetInfoCloneNeedsRegeneration = NetInfo(struct { MappingVariesByDestIP opt.Bool - HairPinning opt.Bool WorkingIPv6 opt.Bool OSHasIPv6 opt.Bool WorkingUDP opt.Bool @@ -301,7 +306,6 @@ func (src *RegisterResponse) Clone() *RegisterResponse { } dst := new(RegisterResponse) *dst = *src - dst.User = *src.User.Clone() dst.NodeKeySignature = append(src.NodeKeySignature[:0:0], src.NodeKeySignature...) return dst } @@ -417,13 +421,14 @@ func (src *DERPRegion) Clone() *DERPRegion { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _DERPRegionCloneNeedsRegeneration = DERPRegion(struct { - RegionID int - RegionCode string - RegionName string - Latitude float64 - Longitude float64 - Avoid bool - Nodes []*DERPNode + RegionID int + RegionCode string + RegionName string + Latitude float64 + Longitude float64 + Avoid bool + NoMeasureNoHome bool + Nodes []*DERPNode }{}) // Clone makes a deep copy of DERPMap. @@ -555,17 +560,17 @@ func (src *SSHPrincipal) Clone() *SSHPrincipal { } dst := new(SSHPrincipal) *dst = *src - dst.PubKeys = append(src.PubKeys[:0:0], src.PubKeys...) + dst.UnusedPubKeys = append(src.UnusedPubKeys[:0:0], src.UnusedPubKeys...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _SSHPrincipalCloneNeedsRegeneration = SSHPrincipal(struct { - Node StableNodeID - NodeIP string - UserLogin string - Any bool - PubKeys []string + Node StableNodeID + NodeIP string + UserLogin string + Any bool + UnusedPubKeys []string }{}) // Clone makes a deep copy of ControlDialPlan. @@ -624,12 +629,56 @@ var _UserProfileCloneNeedsRegeneration = UserProfile(struct { LoginName string DisplayName string ProfilePicURL string - Roles emptyStructJSONSlice +}{}) + +// Clone makes a deep copy of VIPService. +// The result aliases no memory with the original. +func (src *VIPService) Clone() *VIPService { + if src == nil { + return nil + } + dst := new(VIPService) + *dst = *src + dst.Ports = append(src.Ports[:0:0], src.Ports...) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _VIPServiceCloneNeedsRegeneration = VIPService(struct { + Name ServiceName + Ports []ProtoPortRange + Active bool +}{}) + +// Clone makes a deep copy of SSHPolicy. +// The result aliases no memory with the original. +func (src *SSHPolicy) Clone() *SSHPolicy { + if src == nil { + return nil + } + dst := new(SSHPolicy) + *dst = *src + if src.Rules != nil { + dst.Rules = make([]*SSHRule, len(src.Rules)) + for i := range dst.Rules { + if src.Rules[i] == nil { + dst.Rules[i] = nil + } else { + dst.Rules[i] = src.Rules[i].Clone() + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _SSHPolicyCloneNeedsRegeneration = SSHPolicy(struct { + Rules []*SSHRule }{}) // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile. +// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService,SSHPolicy. func Clone(dst, src any) bool { switch src := src.(type) { case *User: @@ -803,6 +852,24 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *VIPService: + switch dst := dst.(type) { + case *VIPService: + *dst = *src.Clone() + return true + case **VIPService: + *dst = src.Clone() + return true + } + case *SSHPolicy: + switch dst := dst.(type) { + case *SSHPolicy: + *dst = *src.Clone() + return true + case **SSHPolicy: + *dst = src.Clone() + return true + } } return false } diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index 0d0636677..6691263eb 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -10,7 +10,6 @@ import ( "reflect" "regexp" "strconv" - "strings" "testing" "time" @@ -51,6 +50,7 @@ func TestHostinfoEqual(t *testing.T) { "ShareeNode", "NoLogsNoSupport", "WireIngress", + "IngressEnabled", "AllowsUpdate", "Machine", "GoArch", @@ -66,7 +66,11 @@ func TestHostinfoEqual(t *testing.T) { "Userspace", "UserspaceRouter", "AppConnector", + "ServicesHash", + "ExitNodeID", "Location", + "TPM", + "StateEncrypted", } if have := fieldsOf(reflect.TypeFor[Hostinfo]()); !reflect.DeepEqual(have, hiHandles) { t.Errorf("Hostinfo.Equal check might be out of sync\nfields: %q\nhandled: %q\n", @@ -240,87 +244,56 @@ func TestHostinfoEqual(t *testing.T) { &Hostinfo{AppConnector: opt.Bool("false")}, false, }, - } - for i, tt := range tests { - got := tt.a.Equal(tt.b) - if got != tt.want { - t.Errorf("%d. Equal = %v; want %v", i, got, tt.want) - } - } -} - -func TestHostinfoHowEqual(t *testing.T) { - tests := []struct { - a, b *Hostinfo - want []string - }{ { - a: nil, - b: nil, - want: nil, + &Hostinfo{ServicesHash: "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049"}, + &Hostinfo{ServicesHash: "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049"}, + true, }, { - a: new(Hostinfo), - b: nil, - want: []string{"nil"}, + &Hostinfo{ServicesHash: "084c799cd551dd1d8d5c5f9a5d593b2e931f5e36122ee5c793c1d08a19839cc0"}, + &Hostinfo{}, + false, }, { - a: nil, - b: new(Hostinfo), - want: []string{"nil"}, + &Hostinfo{IngressEnabled: true}, + &Hostinfo{}, + false, }, { - a: new(Hostinfo), - b: new(Hostinfo), - want: nil, + &Hostinfo{IngressEnabled: true}, + &Hostinfo{IngressEnabled: true}, + true, }, { - a: &Hostinfo{ - IPNVersion: "1", - ShieldsUp: false, - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("1.2.3.0/24")}, - }, - b: &Hostinfo{ - IPNVersion: "2", - ShieldsUp: true, - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("1.2.3.0/25")}, - }, - want: []string{"IPNVersion", "ShieldsUp", "RoutableIPs"}, + &Hostinfo{IngressEnabled: false}, + &Hostinfo{}, + true, }, { - a: &Hostinfo{ - IPNVersion: "1", - }, - b: &Hostinfo{ - IPNVersion: "2", - NetInfo: new(NetInfo), - }, - want: []string{"IPNVersion", "NetInfo.nil"}, - }, - { - a: &Hostinfo{ - IPNVersion: "1", - NetInfo: &NetInfo{ - WorkingIPv6: "true", - HavePortMap: true, - LinkType: "foo", - PreferredDERP: 123, - DERPLatency: map[string]float64{ - "foo": 1.0, - }, - }, - }, - b: &Hostinfo{ - IPNVersion: "2", - NetInfo: &NetInfo{}, - }, - want: []string{"IPNVersion", "NetInfo.WorkingIPv6", "NetInfo.HavePortMap", "NetInfo.PreferredDERP", "NetInfo.LinkType", "NetInfo.DERPLatency"}, + &Hostinfo{IngressEnabled: false}, + &Hostinfo{IngressEnabled: true}, + false, + }, + { + &Hostinfo{ExitNodeID: "stable-exit"}, + &Hostinfo{ExitNodeID: "stable-exit"}, + true, + }, + { + &Hostinfo{ExitNodeID: ""}, + &Hostinfo{}, + true, + }, + { + &Hostinfo{ExitNodeID: ""}, + &Hostinfo{ExitNodeID: "stable-exit"}, + false, }, } for i, tt := range tests { - got := tt.a.HowUnequal(tt.b) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("%d. got %q; want %q", i, got, tt.want) + got := tt.a.Equal(tt.b) + if got != tt.want { + t.Errorf("%d. Equal = %v; want %v", i, got, tt.want) } } } @@ -356,7 +329,7 @@ func TestNodeEqual(t *testing.T) { nodeHandles := []string{ "ID", "StableID", "Name", "User", "Sharer", "Key", "KeyExpiry", "KeySignature", "Machine", "DiscoKey", - "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", + "Addresses", "AllowedIPs", "Endpoints", "LegacyDERPString", "HomeDERP", "Hostinfo", "Created", "Cap", "Tags", "PrimaryRoutes", "LastSeen", "Online", "MachineAuthorized", "Capabilities", "CapMap", @@ -519,8 +492,13 @@ func TestNodeEqual(t *testing.T) { true, }, { - &Node{DERP: "foo"}, - &Node{DERP: "bar"}, + &Node{LegacyDERPString: "foo"}, + &Node{LegacyDERPString: "bar"}, + false, + }, + { + &Node{HomeDERP: 1}, + &Node{HomeDERP: 2}, false, }, { @@ -629,7 +607,6 @@ func TestNodeEqual(t *testing.T) { func TestNetInfoFields(t *testing.T) { handled := []string{ "MappingVariesByDestIP", - "HairPinning", "WorkingIPv6", "OSHasIPv6", "WorkingUDP", @@ -655,7 +632,6 @@ func TestCloneUser(t *testing.T) { u *User }{ {"nil_logins", &User{}}, - {"zero_logins", &User{Logins: make([]LoginID, 0)}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -689,28 +665,6 @@ func TestCloneNode(t *testing.T) { } } -func TestUserProfileJSONMarshalForMac(t *testing.T) { - // Old macOS clients had a bug where they required - // UserProfile.Roles to be non-null. Lock that in - // 1.0.x/1.2.x clients are gone in the wild. - // See mac commit 0242c08a2ca496958027db1208f44251bff8488b (Sep 30). - // It was fixed in at least 1.4.x, and perhaps 1.2.x. - j, err := json.Marshal(UserProfile{}) - if err != nil { - t.Fatal(err) - } - const wantSub = `"Roles":[]` - if !strings.Contains(string(j), wantSub) { - t.Fatalf("didn't contain %#q; got: %s", wantSub, j) - } - - // And back: - var up UserProfile - if err := json.Unmarshal(j, &up); err != nil { - t.Fatalf("Unmarshal: %v", err) - } -} - func TestEndpointTypeMarshal(t *testing.T) { eps := []EndpointType{ EndpointUnknownType, @@ -940,3 +894,132 @@ func TestCheckTag(t *testing.T) { }) } } + +func TestDisplayMessageEqual(t *testing.T) { + type test struct { + name string + value1 DisplayMessage + value2 DisplayMessage + wantEqual bool + } + + for _, test := range []test{ + { + name: "same", + value1: DisplayMessage{ + Title: "title", + Text: "text", + Severity: SeverityHigh, + ImpactsConnectivity: false, + PrimaryAction: &DisplayMessageAction{ + URL: "https://example.com", + Label: "Open", + }, + }, + value2: DisplayMessage{ + Title: "title", + Text: "text", + Severity: SeverityHigh, + ImpactsConnectivity: false, + PrimaryAction: &DisplayMessageAction{ + URL: "https://example.com", + Label: "Open", + }, + }, + wantEqual: true, + }, + { + name: "different-title", + value1: DisplayMessage{ + Title: "title", + }, + value2: DisplayMessage{ + Title: "different title", + }, + wantEqual: false, + }, + { + name: "different-text", + value1: DisplayMessage{ + Text: "some text", + }, + value2: DisplayMessage{ + Text: "different text", + }, + wantEqual: false, + }, + { + name: "different-severity", + value1: DisplayMessage{ + Severity: SeverityHigh, + }, + value2: DisplayMessage{ + Severity: SeverityMedium, + }, + wantEqual: false, + }, + { + name: "different-impactsConnectivity", + value1: DisplayMessage{ + ImpactsConnectivity: true, + }, + value2: DisplayMessage{ + ImpactsConnectivity: false, + }, + wantEqual: false, + }, + { + name: "different-primaryAction-nil-non-nil", + value1: DisplayMessage{}, + value2: DisplayMessage{ + PrimaryAction: &DisplayMessageAction{ + URL: "https://example.com", + Label: "Open", + }, + }, + wantEqual: false, + }, + { + name: "different-primaryAction-url", + value1: DisplayMessage{ + PrimaryAction: &DisplayMessageAction{ + URL: "https://example.com", + Label: "Open", + }, + }, + value2: DisplayMessage{ + PrimaryAction: &DisplayMessageAction{ + URL: "https://zombo.com", + Label: "Open", + }, + }, + wantEqual: false, + }, + { + name: "different-primaryAction-label", + value1: DisplayMessage{ + PrimaryAction: &DisplayMessageAction{ + URL: "https://example.com", + Label: "Open", + }, + }, + value2: DisplayMessage{ + PrimaryAction: &DisplayMessageAction{ + URL: "https://example.com", + Label: "Learn more", + }, + }, + wantEqual: false, + }, + } { + t.Run(test.name, func(t *testing.T) { + got := test.value1.Equal(test.value2) + + if got != test.wantEqual { + value1 := must.Get(json.MarshalIndent(test.value1, "", " ")) + value2 := must.Get(json.MarshalIndent(test.value2, "", " ")) + t.Errorf("value1.Equal(value2): got %t, want %t\nvalue1:\n%s\nvalue2:\n%s", got, test.wantEqual, value1, value2) + } + }) + } +} diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index a3e19b0dc..dbd29a87a 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -6,11 +6,13 @@ package tailcfg import ( - "encoding/json" + jsonv1 "encoding/json" "errors" "net/netip" "time" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/opt" @@ -19,9 +21,9 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService,SSHPolicy -// View returns a readonly view of User. +// View returns a read-only view of User. func (p *User) View() UserView { return UserView{Đļ: p} } @@ -37,7 +39,7 @@ type UserView struct { Đļ *User } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v UserView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -49,8 +51,17 @@ func (v UserView) AsStruct() *User { return v.Đļ.Clone() } -func (v UserView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v UserView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v UserView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *UserView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -59,31 +70,44 @@ func (v *UserView) UnmarshalJSON(b []byte) error { return nil } var x User - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *UserView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x User + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } -func (v UserView) ID() UserID { return v.Đļ.ID } -func (v UserView) LoginName() string { return v.Đļ.LoginName } -func (v UserView) DisplayName() string { return v.Đļ.DisplayName } -func (v UserView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } -func (v UserView) Logins() views.Slice[LoginID] { return views.SliceOf(v.Đļ.Logins) } -func (v UserView) Created() time.Time { return v.Đļ.Created } +func (v UserView) ID() UserID { return v.Đļ.ID } + +// if non-empty overrides Login field +func (v UserView) DisplayName() string { return v.Đļ.DisplayName } + +// if non-empty overrides Login field +func (v UserView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } +func (v UserView) Created() time.Time { return v.Đļ.Created } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _UserViewNeedsRegeneration = User(struct { ID UserID - LoginName string DisplayName string ProfilePicURL string - Logins []LoginID Created time.Time }{}) -// View returns a readonly view of Node. +// View returns a read-only view of Node. func (p *Node) View() NodeView { return NodeView{Đļ: p} } @@ -99,7 +123,7 @@ type NodeView struct { Đļ *Node } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v NodeView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -111,8 +135,17 @@ func (v NodeView) AsStruct() *Node { return v.Đļ.Clone() } -func (v NodeView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v NodeView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v NodeView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *NodeView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -121,7 +154,20 @@ func (v *NodeView) UnmarshalJSON(b []byte) error { return nil } var x Node - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *NodeView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Node + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -130,70 +176,202 @@ func (v *NodeView) UnmarshalJSON(b []byte) error { func (v NodeView) ID() NodeID { return v.Đļ.ID } func (v NodeView) StableID() StableNodeID { return v.Đļ.StableID } -func (v NodeView) Name() string { return v.Đļ.Name } -func (v NodeView) User() UserID { return v.Đļ.User } -func (v NodeView) Sharer() UserID { return v.Đļ.Sharer } -func (v NodeView) Key() key.NodePublic { return v.Đļ.Key } -func (v NodeView) KeyExpiry() time.Time { return v.Đļ.KeyExpiry } + +// Name is the FQDN of the node. +// It is also the MagicDNS name for the node. +// It has a trailing dot. +// e.g. "host.tail-scale.ts.net." +func (v NodeView) Name() string { return v.Đļ.Name } + +// User is the user who created the node. If ACL tags are in use for the +// node then it doesn't reflect the ACL identity that the node is running +// as. +func (v NodeView) User() UserID { return v.Đļ.User } + +// Sharer, if non-zero, is the user who shared this node, if different than User. +func (v NodeView) Sharer() UserID { return v.Đļ.Sharer } +func (v NodeView) Key() key.NodePublic { return v.Đļ.Key } + +// the zero value if this node does not expire +func (v NodeView) KeyExpiry() time.Time { return v.Đļ.KeyExpiry } func (v NodeView) KeySignature() views.ByteSlice[tkatype.MarshaledSignature] { return views.ByteSliceOf(v.Đļ.KeySignature) } -func (v NodeView) Machine() key.MachinePublic { return v.Đļ.Machine } -func (v NodeView) DiscoKey() key.DiscoPublic { return v.Đļ.DiscoKey } -func (v NodeView) Addresses() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.Addresses) } -func (v NodeView) AllowedIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.AllowedIPs) } -func (v NodeView) Endpoints() views.Slice[netip.AddrPort] { return views.SliceOf(v.Đļ.Endpoints) } -func (v NodeView) DERP() string { return v.Đļ.DERP } -func (v NodeView) Hostinfo() HostinfoView { return v.Đļ.Hostinfo } -func (v NodeView) Created() time.Time { return v.Đļ.Created } -func (v NodeView) Cap() CapabilityVersion { return v.Đļ.Cap } -func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.Đļ.Tags) } +func (v NodeView) Machine() key.MachinePublic { return v.Đļ.Machine } +func (v NodeView) DiscoKey() key.DiscoPublic { return v.Đļ.DiscoKey } + +// Addresses are the IP addresses of this Node directly. +func (v NodeView) Addresses() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.Addresses) } + +// AllowedIPs are the IP ranges to route to this node. +// +// As of CapabilityVersion 112, this may be nil (null or undefined) on the wire +// to mean the same as Addresses. Internally, it is always filled in with +// its possibly-implicit value. +func (v NodeView) AllowedIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.AllowedIPs) } + +// IP+port (public via STUN, and local LANs) +func (v NodeView) Endpoints() views.Slice[netip.AddrPort] { return views.SliceOf(v.Đļ.Endpoints) } + +// LegacyDERPString is this node's home LegacyDERPString region ID integer, but shoved into an +// IP:port string for legacy reasons. The IP address is always "127.3.3.40" +// (a loopback address (127) followed by the digits over the letters DERP on +// a QWERTY keyboard (3.3.40)). The "port number" is the home LegacyDERPString region ID +// integer. +// +// Deprecated: HomeDERP has replaced this, but old servers might still send +// this field. See tailscale/tailscale#14636. Do not use this field in code +// other than in the upgradeNode func, which canonicalizes it to HomeDERP +// if it arrives as a LegacyDERPString string on the wire. +func (v NodeView) LegacyDERPString() string { return v.Đļ.LegacyDERPString } + +// HomeDERP is the modern version of the DERP string field, with just an +// integer. The client advertises support for this as of capver 111. +// +// HomeDERP may be zero if not (yet) known, but ideally always be non-zero +// for magicsock connectivity to function normally. +func (v NodeView) HomeDERP() int { return v.Đļ.HomeDERP } +func (v NodeView) Hostinfo() HostinfoView { return v.Đļ.Hostinfo } +func (v NodeView) Created() time.Time { return v.Đļ.Created } + +// if non-zero, the node's capability version; old servers might not send +func (v NodeView) Cap() CapabilityVersion { return v.Đļ.Cap } + +// Tags are the list of ACL tags applied to this node. +// Tags take the form of `tag:` where value starts +// with a letter and only contains alphanumerics and dashes `-`. +// Some valid tag examples: +// +// `tag:prod` +// `tag:database` +// `tag:lab-1` +func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.Đļ.Tags) } + +// PrimaryRoutes are the routes from AllowedIPs that this node +// is currently the primary subnet router for, as determined +// by the control plane. It does not include the self address +// values from Addresses that are in AllowedIPs. func (v NodeView) PrimaryRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.PrimaryRoutes) } -func (v NodeView) LastSeen() *time.Time { - if v.Đļ.LastSeen == nil { - return nil - } - x := *v.Đļ.LastSeen - return &x -} -func (v NodeView) Online() *bool { - if v.Đļ.Online == nil { - return nil - } - x := *v.Đļ.Online - return &x +// LastSeen is when the node was last online. It is not +// updated when Online is true. It is nil if the current +// node doesn't have permission to know, or the node +// has never been online. +func (v NodeView) LastSeen() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.Đļ.LastSeen) } -func (v NodeView) MachineAuthorized() bool { return v.Đļ.MachineAuthorized } +// Online is whether the node is currently connected to the +// coordination server. A value of nil means unknown, or the +// current node doesn't have permission to know. +func (v NodeView) Online() views.ValuePointer[bool] { return views.ValuePointerOf(v.Đļ.Online) } + +// TODO(crawshaw): replace with MachineStatus +func (v NodeView) MachineAuthorized() bool { return v.Đļ.MachineAuthorized } + +// Capabilities are capabilities that the node has. +// They're free-form strings, but should be in the form of URLs/URIs +// such as: +// +// "https://tailscale.com/cap/is-admin" +// "https://tailscale.com/cap/file-sharing" +// +// Deprecated: use CapMap instead. See https://github.com/tailscale/tailscale/issues/11508 func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.SliceOf(v.Đļ.Capabilities) } +// CapMap is a map of capabilities to their optional argument/data values. +// +// It is valid for a capability to not have any argument/data values; such +// capabilities can be tested for using the HasCap method. These type of +// capabilities are used to indicate that a node has a capability, but there +// is no additional data associated with it. These were previously +// represented by the Capabilities field, but can now be represented by +// CapMap with an empty value. +// +// See NodeCapability for more information on keys. +// +// Metadata about nodes can be transmitted in 3 ways: +// 1. MapResponse.Node.CapMap describes attributes that affect behavior for +// this node, such as which features have been enabled through the admin +// panel and any associated configuration details. +// 2. MapResponse.PacketFilter(s) describes access (both IP and application +// based) that should be granted to peers. +// 3. MapResponse.Peers[].CapMap describes attributes regarding a peer node, +// such as which features the peer supports or if that peer is preferred +// for a particular task vs other peers that could also be chosen. func (v NodeView) CapMap() views.MapSlice[NodeCapability, RawMessage] { return views.MapSliceOf(v.Đļ.CapMap) } -func (v NodeView) UnsignedPeerAPIOnly() bool { return v.Đļ.UnsignedPeerAPIOnly } -func (v NodeView) ComputedName() string { return v.Đļ.ComputedName } + +// UnsignedPeerAPIOnly means that this node is not signed nor subject to TKA +// restrictions. However, in exchange for that privilege, it does not get +// network access. It can only access this node's peerapi, which may not let +// it do anything. It is the tailscaled client's job to double-check the +// MapResponse's PacketFilter to verify that its AllowedIPs will not be +// accepted by the packet filter. +func (v NodeView) UnsignedPeerAPIOnly() bool { return v.Đļ.UnsignedPeerAPIOnly } + +// MagicDNS base name (for normal non-shared-in nodes), FQDN (without trailing dot, for shared-in nodes), or Hostname (if no MagicDNS) +func (v NodeView) ComputedName() string { return v.Đļ.ComputedName } + +// either "ComputedName" or "ComputedName (computedHostIfDifferent)", if computedHostIfDifferent is set func (v NodeView) ComputedNameWithHost() string { return v.Đļ.ComputedNameWithHost } -func (v NodeView) DataPlaneAuditLogID() string { return v.Đļ.DataPlaneAuditLogID } -func (v NodeView) Expired() bool { return v.Đļ.Expired } -func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { - if v.Đļ.SelfNodeV4MasqAddrForThisPeer == nil { - return nil - } - x := *v.Đļ.SelfNodeV4MasqAddrForThisPeer - return &x -} -func (v NodeView) SelfNodeV6MasqAddrForThisPeer() *netip.Addr { - if v.Đļ.SelfNodeV6MasqAddrForThisPeer == nil { - return nil - } - x := *v.Đļ.SelfNodeV6MasqAddrForThisPeer - return &x +// DataPlaneAuditLogID is the per-node logtail ID used for data plane audit logging. +func (v NodeView) DataPlaneAuditLogID() string { return v.Đļ.DataPlaneAuditLogID } + +// Expired is whether this node's key has expired. Control may send +// this; clients are only allowed to set this from false to true. On +// the client, this is calculated client-side based on a timestamp sent +// from control, to avoid clock skew issues. +func (v NodeView) Expired() bool { return v.Đļ.Expired } + +// SelfNodeV4MasqAddrForThisPeer is the IPv4 that this peer knows the current node as. +// It may be empty if the peer knows the current node by its native +// IPv4 address. +// This field is only populated in a MapResponse for peers and not +// for the current node. +// +// If set, it should be used to masquerade traffic originating from the +// current node to this peer. The masquerade address is only relevant +// for this peer and not for other peers. +// +// This only applies to traffic originating from the current node to the +// peer or any of its subnets. Traffic originating from subnet routes will +// not be masqueraded (e.g. in case of --snat-subnet-routes). +func (v NodeView) SelfNodeV4MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.Đļ.SelfNodeV4MasqAddrForThisPeer) +} + +// SelfNodeV6MasqAddrForThisPeer is the IPv6 that this peer knows the current node as. +// It may be empty if the peer knows the current node by its native +// IPv6 address. +// This field is only populated in a MapResponse for peers and not +// for the current node. +// +// If set, it should be used to masquerade traffic originating from the +// current node to this peer. The masquerade address is only relevant +// for this peer and not for other peers. +// +// This only applies to traffic originating from the current node to the +// peer or any of its subnets. Traffic originating from subnet routes will +// not be masqueraded (e.g. in case of --snat-subnet-routes). +func (v NodeView) SelfNodeV6MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.Đļ.SelfNodeV6MasqAddrForThisPeer) } +// IsWireGuardOnly indicates that this is a non-Tailscale WireGuard peer, it +// is not expected to speak Disco or DERP, and it must have Endpoints in +// order to be reachable. func (v NodeView) IsWireGuardOnly() bool { return v.Đļ.IsWireGuardOnly } -func (v NodeView) IsJailed() bool { return v.Đļ.IsJailed } + +// IsJailed indicates that this node is jailed and should not be allowed +// initiate connections, however outbound connections to it should still be +// allowed. +func (v NodeView) IsJailed() bool { return v.Đļ.IsJailed } + +// ExitNodeDNSResolvers is the list of DNS servers that should be used when this +// node is marked IsWireGuardOnly and being used as an exit node. func (v NodeView) ExitNodeDNSResolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.Đļ.ExitNodeDNSResolvers) } @@ -214,7 +392,8 @@ var _NodeViewNeedsRegeneration = Node(struct { Addresses []netip.Prefix AllowedIPs []netip.Prefix Endpoints []netip.AddrPort - DERP string + LegacyDERPString string + HomeDERP int Hostinfo HostinfoView Created time.Time Cap CapabilityVersion @@ -238,7 +417,7 @@ var _NodeViewNeedsRegeneration = Node(struct { ExitNodeDNSResolvers []*dnstype.Resolver }{}) -// View returns a readonly view of Hostinfo. +// View returns a read-only view of Hostinfo. func (p *Hostinfo) View() HostinfoView { return HostinfoView{Đļ: p} } @@ -254,7 +433,7 @@ type HostinfoView struct { Đļ *Hostinfo } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v HostinfoView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -266,8 +445,17 @@ func (v HostinfoView) AsStruct() *Hostinfo { return v.Đļ.Clone() } -func (v HostinfoView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v HostinfoView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v HostinfoView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *HostinfoView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -276,56 +464,165 @@ func (v *HostinfoView) UnmarshalJSON(b []byte) error { return nil } var x Hostinfo - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v HostinfoView) IPNVersion() string { return v.Đļ.IPNVersion } -func (v HostinfoView) FrontendLogID() string { return v.Đļ.FrontendLogID } -func (v HostinfoView) BackendLogID() string { return v.Đļ.BackendLogID } -func (v HostinfoView) OS() string { return v.Đļ.OS } -func (v HostinfoView) OSVersion() string { return v.Đļ.OSVersion } -func (v HostinfoView) Container() opt.Bool { return v.Đļ.Container } -func (v HostinfoView) Env() string { return v.Đļ.Env } -func (v HostinfoView) Distro() string { return v.Đļ.Distro } -func (v HostinfoView) DistroVersion() string { return v.Đļ.DistroVersion } -func (v HostinfoView) DistroCodeName() string { return v.Đļ.DistroCodeName } -func (v HostinfoView) App() string { return v.Đļ.App } -func (v HostinfoView) Desktop() opt.Bool { return v.Đļ.Desktop } -func (v HostinfoView) Package() string { return v.Đļ.Package } -func (v HostinfoView) DeviceModel() string { return v.Đļ.DeviceModel } -func (v HostinfoView) PushDeviceToken() string { return v.Đļ.PushDeviceToken } -func (v HostinfoView) Hostname() string { return v.Đļ.Hostname } -func (v HostinfoView) ShieldsUp() bool { return v.Đļ.ShieldsUp } -func (v HostinfoView) ShareeNode() bool { return v.Đļ.ShareeNode } -func (v HostinfoView) NoLogsNoSupport() bool { return v.Đļ.NoLogsNoSupport } -func (v HostinfoView) WireIngress() bool { return v.Đļ.WireIngress } -func (v HostinfoView) AllowsUpdate() bool { return v.Đļ.AllowsUpdate } -func (v HostinfoView) Machine() string { return v.Đļ.Machine } -func (v HostinfoView) GoArch() string { return v.Đļ.GoArch } -func (v HostinfoView) GoArchVar() string { return v.Đļ.GoArchVar } -func (v HostinfoView) GoVersion() string { return v.Đļ.GoVersion } -func (v HostinfoView) RoutableIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.RoutableIPs) } -func (v HostinfoView) RequestTags() views.Slice[string] { return views.SliceOf(v.Đļ.RequestTags) } -func (v HostinfoView) WoLMACs() views.Slice[string] { return views.SliceOf(v.Đļ.WoLMACs) } -func (v HostinfoView) Services() views.Slice[Service] { return views.SliceOf(v.Đļ.Services) } -func (v HostinfoView) NetInfo() NetInfoView { return v.Đļ.NetInfo.View() } -func (v HostinfoView) SSH_HostKeys() views.Slice[string] { return views.SliceOf(v.Đļ.SSH_HostKeys) } -func (v HostinfoView) Cloud() string { return v.Đļ.Cloud } -func (v HostinfoView) Userspace() opt.Bool { return v.Đļ.Userspace } -func (v HostinfoView) UserspaceRouter() opt.Bool { return v.Đļ.UserspaceRouter } -func (v HostinfoView) AppConnector() opt.Bool { return v.Đļ.AppConnector } -func (v HostinfoView) Location() *Location { - if v.Đļ.Location == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *HostinfoView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Hostinfo + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err } - x := *v.Đļ.Location - return &x + v.Đļ = &x + return nil } +// version of this code (in version.Long format) +func (v HostinfoView) IPNVersion() string { return v.Đļ.IPNVersion } + +// logtail ID of frontend instance +func (v HostinfoView) FrontendLogID() string { return v.Đļ.FrontendLogID } + +// logtail ID of backend instance +func (v HostinfoView) BackendLogID() string { return v.Đļ.BackendLogID } + +// operating system the client runs on (a version.OS value) +func (v HostinfoView) OS() string { return v.Đļ.OS } + +// OSVersion is the version of the OS, if available. +// +// For Android, it's like "10", "11", "12", etc. For iOS and macOS it's like +// "15.6.1" or "12.4.0". For Windows it's like "10.0.19044.1889". For +// FreeBSD it's like "12.3-STABLE". +// +// For Linux, prior to Tailscale 1.32, we jammed a bunch of fields into this +// string on Linux, like "Debian 10.4; kernel=xxx; container; env=kn" and so +// on. As of Tailscale 1.32, this is simply the kernel version on Linux, like +// "5.10.0-17-amd64". +func (v HostinfoView) OSVersion() string { return v.Đļ.OSVersion } + +// best-effort whether the client is running in a container +func (v HostinfoView) Container() opt.Bool { return v.Đļ.Container } + +// a hostinfo.EnvType in string form +func (v HostinfoView) Env() string { return v.Đļ.Env } + +// "debian", "ubuntu", "nixos", ... +func (v HostinfoView) Distro() string { return v.Đļ.Distro } + +// "20.04", ... +func (v HostinfoView) DistroVersion() string { return v.Đļ.DistroVersion } + +// "jammy", "bullseye", ... +func (v HostinfoView) DistroCodeName() string { return v.Đļ.DistroCodeName } + +// App is used to disambiguate Tailscale clients that run using tsnet. +func (v HostinfoView) App() string { return v.Đļ.App } + +// if a desktop was detected on Linux +func (v HostinfoView) Desktop() opt.Bool { return v.Đļ.Desktop } + +// Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) +func (v HostinfoView) Package() string { return v.Đļ.Package } + +// mobile phone model ("Pixel 3a", "iPhone12,3") +func (v HostinfoView) DeviceModel() string { return v.Đļ.DeviceModel } + +// macOS/iOS APNs device token for notifications (and Android in the future) +func (v HostinfoView) PushDeviceToken() string { return v.Đļ.PushDeviceToken } + +// name of the host the client runs on +func (v HostinfoView) Hostname() string { return v.Đļ.Hostname } + +// indicates whether the host is blocking incoming connections +func (v HostinfoView) ShieldsUp() bool { return v.Đļ.ShieldsUp } + +// indicates this node exists in netmap because it's owned by a shared-to user +func (v HostinfoView) ShareeNode() bool { return v.Đļ.ShareeNode } + +// indicates that the user has opted out of sending logs and support +func (v HostinfoView) NoLogsNoSupport() bool { return v.Đļ.NoLogsNoSupport } + +// WireIngress indicates that the node would like to be wired up server-side +// (DNS, etc) to be able to use Tailscale Funnel, even if it's not currently +// enabled. For example, the user might only use it for intermittent +// foreground CLI serve sessions, for which they'd like it to work right +// away, even if it's disabled most of the time. As an optimization, this is +// only sent if IngressEnabled is false, as IngressEnabled implies that this +// option is true. +func (v HostinfoView) WireIngress() bool { return v.Đļ.WireIngress } + +// if the node has any funnel endpoint enabled +func (v HostinfoView) IngressEnabled() bool { return v.Đļ.IngressEnabled } + +// indicates that the node has opted-in to admin-console-drive remote updates +func (v HostinfoView) AllowsUpdate() bool { return v.Đļ.AllowsUpdate } + +// the current host's machine type (uname -m) +func (v HostinfoView) Machine() string { return v.Đļ.Machine } + +// GOARCH value (of the built binary) +func (v HostinfoView) GoArch() string { return v.Đļ.GoArch } + +// GOARM, GOAMD64, etc (of the built binary) +func (v HostinfoView) GoArchVar() string { return v.Đļ.GoArchVar } + +// Go version binary was built with +func (v HostinfoView) GoVersion() string { return v.Đļ.GoVersion } + +// set of IP ranges this client can route +func (v HostinfoView) RoutableIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.RoutableIPs) } + +// set of ACL tags this node wants to claim +func (v HostinfoView) RequestTags() views.Slice[string] { return views.SliceOf(v.Đļ.RequestTags) } + +// MAC address(es) to send Wake-on-LAN packets to wake this node (lowercase hex w/ colons) +func (v HostinfoView) WoLMACs() views.Slice[string] { return views.SliceOf(v.Đļ.WoLMACs) } + +// services advertised by this machine +func (v HostinfoView) Services() views.Slice[Service] { return views.SliceOf(v.Đļ.Services) } +func (v HostinfoView) NetInfo() NetInfoView { return v.Đļ.NetInfo.View() } + +// if advertised +func (v HostinfoView) SSH_HostKeys() views.Slice[string] { return views.SliceOf(v.Đļ.SSH_HostKeys) } +func (v HostinfoView) Cloud() string { return v.Đļ.Cloud } + +// if the client is running in userspace (netstack) mode +func (v HostinfoView) Userspace() opt.Bool { return v.Đļ.Userspace } + +// if the client's subnet router is running in userspace (netstack) mode +func (v HostinfoView) UserspaceRouter() opt.Bool { return v.Đļ.UserspaceRouter } + +// if the client is running the app-connector service +func (v HostinfoView) AppConnector() opt.Bool { return v.Đļ.AppConnector } + +// opaque hash of the most recent list of tailnet services, change in hash indicates config should be fetched via c2n +func (v HostinfoView) ServicesHash() string { return v.Đļ.ServicesHash } + +// the client’s selected exit node, empty when unselected. +func (v HostinfoView) ExitNodeID() StableNodeID { return v.Đļ.ExitNodeID } + +// Location represents geographical location data about a +// Tailscale host. Location is optional and only set if +// explicitly declared by a node. +func (v HostinfoView) Location() LocationView { return v.Đļ.Location.View() } + +// TPM device metadata, if available +func (v HostinfoView) TPM() views.ValuePointer[TPMInfo] { return views.ValuePointerOf(v.Đļ.TPM) } + +// StateEncrypted reports whether the node state is stored encrypted on +// disk. The actual mechanism is platform-specific: +// - Apple nodes use the Keychain +// - Linux and Windows nodes use the TPM +// - Android apps use EncryptedSharedPreferences +func (v HostinfoView) StateEncrypted() opt.Bool { return v.Đļ.StateEncrypted } func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -350,6 +647,7 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { ShareeNode bool NoLogsNoSupport bool WireIngress bool + IngressEnabled bool AllowsUpdate bool Machine string GoArch string @@ -365,10 +663,14 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { Userspace opt.Bool UserspaceRouter opt.Bool AppConnector opt.Bool + ServicesHash string + ExitNodeID StableNodeID Location *Location + TPM *TPMInfo + StateEncrypted opt.Bool }{}) -// View returns a readonly view of NetInfo. +// View returns a read-only view of NetInfo. func (p *NetInfo) View() NetInfoView { return NetInfoView{Đļ: p} } @@ -384,7 +686,7 @@ type NetInfoView struct { Đļ *NetInfo } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v NetInfoView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -396,8 +698,17 @@ func (v NetInfoView) AsStruct() *NetInfo { return v.Đļ.Clone() } -func (v NetInfoView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v NetInfoView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v NetInfoView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *NetInfoView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -406,34 +717,94 @@ func (v *NetInfoView) UnmarshalJSON(b []byte) error { return nil } var x NetInfo - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *NetInfoView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x NetInfo + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } +// MappingVariesByDestIP says whether the host's NAT mappings +// vary based on the destination IP. func (v NetInfoView) MappingVariesByDestIP() opt.Bool { return v.Đļ.MappingVariesByDestIP } -func (v NetInfoView) HairPinning() opt.Bool { return v.Đļ.HairPinning } -func (v NetInfoView) WorkingIPv6() opt.Bool { return v.Đļ.WorkingIPv6 } -func (v NetInfoView) OSHasIPv6() opt.Bool { return v.Đļ.OSHasIPv6 } -func (v NetInfoView) WorkingUDP() opt.Bool { return v.Đļ.WorkingUDP } -func (v NetInfoView) WorkingICMPv4() opt.Bool { return v.Đļ.WorkingICMPv4 } -func (v NetInfoView) HavePortMap() bool { return v.Đļ.HavePortMap } -func (v NetInfoView) UPnP() opt.Bool { return v.Đļ.UPnP } -func (v NetInfoView) PMP() opt.Bool { return v.Đļ.PMP } -func (v NetInfoView) PCP() opt.Bool { return v.Đļ.PCP } -func (v NetInfoView) PreferredDERP() int { return v.Đļ.PreferredDERP } -func (v NetInfoView) LinkType() string { return v.Đļ.LinkType } +// WorkingIPv6 is whether the host has IPv6 internet connectivity. +func (v NetInfoView) WorkingIPv6() opt.Bool { return v.Đļ.WorkingIPv6 } + +// OSHasIPv6 is whether the OS supports IPv6 at all, regardless of +// whether IPv6 internet connectivity is available. +func (v NetInfoView) OSHasIPv6() opt.Bool { return v.Đļ.OSHasIPv6 } + +// WorkingUDP is whether the host has UDP internet connectivity. +func (v NetInfoView) WorkingUDP() opt.Bool { return v.Đļ.WorkingUDP } + +// WorkingICMPv4 is whether ICMPv4 works. +// Empty means not checked. +func (v NetInfoView) WorkingICMPv4() opt.Bool { return v.Đļ.WorkingICMPv4 } + +// HavePortMap is whether we have an existing portmap open +// (UPnP, PMP, or PCP). +func (v NetInfoView) HavePortMap() bool { return v.Đļ.HavePortMap } + +// UPnP is whether UPnP appears present on the LAN. +// Empty means not checked. +func (v NetInfoView) UPnP() opt.Bool { return v.Đļ.UPnP } + +// PMP is whether NAT-PMP appears present on the LAN. +// Empty means not checked. +func (v NetInfoView) PMP() opt.Bool { return v.Đļ.PMP } + +// PCP is whether PCP appears present on the LAN. +// Empty means not checked. +func (v NetInfoView) PCP() opt.Bool { return v.Đļ.PCP } + +// PreferredDERP is this node's preferred (home) DERP region ID. +// This is where the node expects to be contacted to begin a +// peer-to-peer connection. The node might be be temporarily +// connected to multiple DERP servers (to speak to other nodes +// that are located elsewhere) but PreferredDERP is the region ID +// that the node subscribes to traffic at. +// Zero means disconnected or unknown. +func (v NetInfoView) PreferredDERP() int { return v.Đļ.PreferredDERP } + +// LinkType is the current link type, if known. +func (v NetInfoView) LinkType() string { return v.Đļ.LinkType } + +// DERPLatency is the fastest recent time to reach various +// DERP STUN servers, in seconds. The map key is the +// "regionID-v4" or "-v6"; it was previously the DERP server's +// STUN host:port. +// +// This should only be updated rarely, or when there's a +// material change, as any change here also gets uploaded to +// the control plane. func (v NetInfoView) DERPLatency() views.Map[string, float64] { return views.MapOf(v.Đļ.DERPLatency) } -func (v NetInfoView) FirewallMode() string { return v.Đļ.FirewallMode } -func (v NetInfoView) String() string { return v.Đļ.String() } + +// FirewallMode encodes both which firewall mode was selected and why. +// It is Linux-specific (at least as of 2023-08-19) and is meant to help +// debug iptables-vs-nftables issues. The string is of the form +// "{nft,ift}-REASON", like "nft-forced" or "ipt-default". Empty means +// either not Linux or a configuration in which the host firewall rules +// are not managed by tailscaled. +func (v NetInfoView) FirewallMode() string { return v.Đļ.FirewallMode } +func (v NetInfoView) String() string { return v.Đļ.String() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _NetInfoViewNeedsRegeneration = NetInfo(struct { MappingVariesByDestIP opt.Bool - HairPinning opt.Bool WorkingIPv6 opt.Bool OSHasIPv6 opt.Bool WorkingUDP opt.Bool @@ -448,7 +819,7 @@ var _NetInfoViewNeedsRegeneration = NetInfo(struct { FirewallMode string }{}) -// View returns a readonly view of Login. +// View returns a read-only view of Login. func (p *Login) View() LoginView { return LoginView{Đļ: p} } @@ -464,7 +835,7 @@ type LoginView struct { Đļ *Login } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v LoginView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -476,8 +847,17 @@ func (v LoginView) AsStruct() *Login { return v.Đļ.Clone() } -func (v LoginView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v LoginView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v LoginView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *LoginView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -486,17 +866,39 @@ func (v *LoginView) UnmarshalJSON(b []byte) error { return nil } var x Login - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *LoginView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Login + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } -func (v LoginView) ID() LoginID { return v.Đļ.ID } -func (v LoginView) Provider() string { return v.Đļ.Provider } -func (v LoginView) LoginName() string { return v.Đļ.LoginName } -func (v LoginView) DisplayName() string { return v.Đļ.DisplayName } +// unused in the Tailscale client +func (v LoginView) ID() LoginID { return v.Đļ.ID } + +// "google", "github", "okta_foo", etc. +func (v LoginView) Provider() string { return v.Đļ.Provider } + +// an email address or "email-ish" string (like alice@github) +func (v LoginView) LoginName() string { return v.Đļ.LoginName } + +// from the IdP +func (v LoginView) DisplayName() string { return v.Đļ.DisplayName } + +// from the IdP func (v LoginView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -509,7 +911,7 @@ var _LoginViewNeedsRegeneration = Login(struct { ProfilePicURL string }{}) -// View returns a readonly view of DNSConfig. +// View returns a read-only view of DNSConfig. func (p *DNSConfig) View() DNSConfigView { return DNSConfigView{Đļ: p} } @@ -525,7 +927,7 @@ type DNSConfigView struct { Đļ *DNSConfig } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DNSConfigView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -537,8 +939,17 @@ func (v DNSConfigView) AsStruct() *DNSConfig { return v.Đļ.Clone() } -func (v DNSConfigView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v DNSConfigView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v DNSConfigView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *DNSConfigView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -547,33 +958,102 @@ func (v *DNSConfigView) UnmarshalJSON(b []byte) error { return nil } var x DNSConfig - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *DNSConfigView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x DNSConfig + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } +// Resolvers are the DNS resolvers to use, in order of preference. func (v DNSConfigView) Resolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.Đļ.Resolvers) } +// Routes maps DNS name suffixes to a set of DNS resolvers to +// use. It is used to implement "split DNS" and other advanced DNS +// routing overlays. +// +// Map keys are fully-qualified DNS name suffixes; they may +// optionally contain a trailing dot but no leading dot. +// +// If the value is an empty slice, that means the suffix should still +// be handled by Tailscale's built-in resolver (100.100.100.100), such +// as for the purpose of handling ExtraRecords. func (v DNSConfigView) Routes() views.MapFn[string, []*dnstype.Resolver, views.SliceView[*dnstype.Resolver, dnstype.ResolverView]] { return views.MapFnOf(v.Đļ.Routes, func(t []*dnstype.Resolver) views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](t) }) } + +// FallbackResolvers is like Resolvers, but is only used if a +// split DNS configuration is requested in a configuration that +// doesn't work yet without explicit default resolvers. +// https://github.com/tailscale/tailscale/issues/1743 func (v DNSConfigView) FallbackResolvers() views.SliceView[*dnstype.Resolver, dnstype.ResolverView] { return views.SliceOfViews[*dnstype.Resolver, dnstype.ResolverView](v.Đļ.FallbackResolvers) } -func (v DNSConfigView) Domains() views.Slice[string] { return views.SliceOf(v.Đļ.Domains) } -func (v DNSConfigView) Proxied() bool { return v.Đļ.Proxied } + +// Domains are the search domains to use. +// Search domains must be FQDNs, but *without* the trailing dot. +func (v DNSConfigView) Domains() views.Slice[string] { return views.SliceOf(v.Đļ.Domains) } + +// Proxied turns on automatic resolution of hostnames for devices +// in the network map, aka MagicDNS. +// Despite the (legacy) name, does not necessarily cause request +// proxying to be enabled. +func (v DNSConfigView) Proxied() bool { return v.Đļ.Proxied } + +// Nameservers are the IP addresses of the global nameservers to use. +// +// Deprecated: this is only set and used by MapRequest.Version >=9 and <14. Use Resolvers instead. func (v DNSConfigView) Nameservers() views.Slice[netip.Addr] { return views.SliceOf(v.Đļ.Nameservers) } -func (v DNSConfigView) CertDomains() views.Slice[string] { return views.SliceOf(v.Đļ.CertDomains) } + +// CertDomains are the set of DNS names for which the control +// plane server will assist with provisioning TLS +// certificates. See SetDNSRequest, which can be used to +// answer dns-01 ACME challenges for e.g. LetsEncrypt. +// These names are FQDNs without trailing periods, and without +// any "_acme-challenge." prefix. +func (v DNSConfigView) CertDomains() views.Slice[string] { return views.SliceOf(v.Đļ.CertDomains) } + +// ExtraRecords contains extra DNS records to add to the +// MagicDNS config. func (v DNSConfigView) ExtraRecords() views.Slice[DNSRecord] { return views.SliceOf(v.Đļ.ExtraRecords) } + +// ExitNodeFilteredSuffixes are the DNS suffixes that the +// node, when being an exit node DNS proxy, should not answer. +// +// The entries do not contain trailing periods and are always +// all lowercase. +// +// If an entry starts with a period, it's a suffix match (but +// suffix ".a.b" doesn't match "a.b"; a prefix is required). +// +// If an entry does not start with a period, it's an exact +// match. +// +// Matches are case insensitive. func (v DNSConfigView) ExitNodeFilteredSet() views.Slice[string] { return views.SliceOf(v.Đļ.ExitNodeFilteredSet) } + +// TempCorpIssue13969 is a temporary (2023-08-16) field for an internal hack day prototype. +// It contains a user inputed URL that should have a list of domains to be blocked. +// See https://github.com/tailscale/corp/issues/13969. func (v DNSConfigView) TempCorpIssue13969() string { return v.Đļ.TempCorpIssue13969 } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -590,7 +1070,7 @@ var _DNSConfigViewNeedsRegeneration = DNSConfig(struct { TempCorpIssue13969 string }{}) -// View returns a readonly view of RegisterResponse. +// View returns a read-only view of RegisterResponse. func (p *RegisterResponse) View() RegisterResponseView { return RegisterResponseView{Đļ: p} } @@ -606,7 +1086,7 @@ type RegisterResponseView struct { Đļ *RegisterResponse } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v RegisterResponseView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -618,8 +1098,17 @@ func (v RegisterResponseView) AsStruct() *RegisterResponse { return v.Đļ.Clone() } -func (v RegisterResponseView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v RegisterResponseView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v RegisterResponseView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *RegisterResponseView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -628,21 +1117,46 @@ func (v *RegisterResponseView) UnmarshalJSON(b []byte) error { return nil } var x RegisterResponse - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *RegisterResponseView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x RegisterResponse + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } -func (v RegisterResponseView) User() UserView { return v.Đļ.User.View() } -func (v RegisterResponseView) Login() Login { return v.Đļ.Login } -func (v RegisterResponseView) NodeKeyExpired() bool { return v.Đļ.NodeKeyExpired } +func (v RegisterResponseView) User() User { return v.Đļ.User } +func (v RegisterResponseView) Login() Login { return v.Đļ.Login } + +// if true, the NodeKey needs to be replaced +func (v RegisterResponseView) NodeKeyExpired() bool { return v.Đļ.NodeKeyExpired } + +// TODO(crawshaw): move to using MachineStatus func (v RegisterResponseView) MachineAuthorized() bool { return v.Đļ.MachineAuthorized } -func (v RegisterResponseView) AuthURL() string { return v.Đļ.AuthURL } + +// if set, authorization pending +func (v RegisterResponseView) AuthURL() string { return v.Đļ.AuthURL } + +// If set, this is the current node-key signature that needs to be +// re-signed for the node's new node-key. func (v RegisterResponseView) NodeKeySignature() views.ByteSlice[tkatype.MarshaledSignature] { return views.ByteSliceOf(v.Đļ.NodeKeySignature) } + +// Error indicates that authorization failed. If this is non-empty, +// other status fields should be ignored. func (v RegisterResponseView) Error() string { return v.Đļ.Error } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -656,7 +1170,7 @@ var _RegisterResponseViewNeedsRegeneration = RegisterResponse(struct { Error string }{}) -// View returns a readonly view of RegisterResponseAuth. +// View returns a read-only view of RegisterResponseAuth. func (p *RegisterResponseAuth) View() RegisterResponseAuthView { return RegisterResponseAuthView{Đļ: p} } @@ -672,7 +1186,7 @@ type RegisterResponseAuthView struct { Đļ *RegisterResponseAuth } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v RegisterResponseAuthView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -684,8 +1198,17 @@ func (v RegisterResponseAuthView) AsStruct() *RegisterResponseAuth { return v.Đļ.Clone() } -func (v RegisterResponseAuthView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v RegisterResponseAuthView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v RegisterResponseAuthView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *RegisterResponseAuthView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -694,19 +1217,29 @@ func (v *RegisterResponseAuthView) UnmarshalJSON(b []byte) error { return nil } var x RegisterResponseAuth - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v RegisterResponseAuthView) Oauth2Token() *Oauth2Token { - if v.Đļ.Oauth2Token == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *RegisterResponseAuthView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x RegisterResponseAuth + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err } - x := *v.Đļ.Oauth2Token - return &x + v.Đļ = &x + return nil +} + +// used by pre-1.66 Android only +func (v RegisterResponseAuthView) Oauth2Token() views.ValuePointer[Oauth2Token] { + return views.ValuePointerOf(v.Đļ.Oauth2Token) } func (v RegisterResponseAuthView) AuthKey() string { return v.Đļ.AuthKey } @@ -718,7 +1251,7 @@ var _RegisterResponseAuthViewNeedsRegeneration = RegisterResponseAuth(struct { AuthKey string }{}) -// View returns a readonly view of RegisterRequest. +// View returns a read-only view of RegisterRequest. func (p *RegisterRequest) View() RegisterRequestView { return RegisterRequestView{Đļ: p} } @@ -734,7 +1267,7 @@ type RegisterRequestView struct { Đļ *RegisterRequest } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v RegisterRequestView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -746,8 +1279,17 @@ func (v RegisterRequestView) AsStruct() *RegisterRequest { return v.Đļ.Clone() } -func (v RegisterRequestView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v RegisterRequestView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v RegisterRequestView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *RegisterRequestView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -756,40 +1298,89 @@ func (v *RegisterRequestView) UnmarshalJSON(b []byte) error { return nil } var x RegisterRequest - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *RegisterRequestView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x RegisterRequest + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// Version is the client's capabilities when using the Noise +// transport. +// +// When using the original nacl crypto_box transport, the +// value must be 1. func (v RegisterRequestView) Version() CapabilityVersion { return v.Đļ.Version } func (v RegisterRequestView) NodeKey() key.NodePublic { return v.Đļ.NodeKey } func (v RegisterRequestView) OldNodeKey() key.NodePublic { return v.Đļ.OldNodeKey } func (v RegisterRequestView) NLKey() key.NLPublic { return v.Đļ.NLKey } func (v RegisterRequestView) Auth() RegisterResponseAuthView { return v.Đļ.Auth.View() } -func (v RegisterRequestView) Expiry() time.Time { return v.Đļ.Expiry } -func (v RegisterRequestView) Followup() string { return v.Đļ.Followup } -func (v RegisterRequestView) Hostinfo() HostinfoView { return v.Đļ.Hostinfo.View() } -func (v RegisterRequestView) Ephemeral() bool { return v.Đļ.Ephemeral } + +// Expiry optionally specifies the requested key expiry. +// The server policy may override. +// As a special case, if Expiry is in the past and NodeKey is +// the node's current key, the key is expired. +func (v RegisterRequestView) Expiry() time.Time { return v.Đļ.Expiry } + +// response waits until AuthURL is visited +func (v RegisterRequestView) Followup() string { return v.Đļ.Followup } +func (v RegisterRequestView) Hostinfo() HostinfoView { return v.Đļ.Hostinfo.View() } + +// Ephemeral is whether the client is requesting that this +// node be considered ephemeral and be automatically deleted +// when it stops being active. +func (v RegisterRequestView) Ephemeral() bool { return v.Đļ.Ephemeral } + +// NodeKeySignature is the node's own node-key signature, re-signed +// for its new node key using its network-lock key. +// +// This field is set when the client retries registration after learning +// its NodeKeySignature (which is in need of rotation). func (v RegisterRequestView) NodeKeySignature() views.ByteSlice[tkatype.MarshaledSignature] { return views.ByteSliceOf(v.Đļ.NodeKeySignature) } + +// The following fields are not used for SignatureNone and are required for +// SignatureV1: func (v RegisterRequestView) SignatureType() SignatureType { return v.Đļ.SignatureType } -func (v RegisterRequestView) Timestamp() *time.Time { - if v.Đļ.Timestamp == nil { - return nil - } - x := *v.Đļ.Timestamp - return &x + +// creation time of request to prevent replay +func (v RegisterRequestView) Timestamp() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.Đļ.Timestamp) } +// X.509 certificate for client device func (v RegisterRequestView) DeviceCert() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.Đļ.DeviceCert) } + +// as described by SignatureType func (v RegisterRequestView) Signature() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.Đļ.Signature) } + +// Tailnet is an optional identifier specifying the name of the recommended or required +// network that the node should join. Its exact form should not be depended on; new +// forms are coming later. The identifier is generally a domain name (for an organization) +// or e-mail address (for a personal account on a shared e-mail provider). It is the same name +// used by the API, as described in /api.md#tailnet. +// If Tailnet begins with the prefix "required:" then the server should prevent logging in to a different +// network than the one specified. Otherwise, the server should recommend the specified network +// but still permit logging in to other networks. +// If empty, no recommendation is offered to the server and the login page should show all options. func (v RegisterRequestView) Tailnet() string { return v.Đļ.Tailnet } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -812,7 +1403,7 @@ var _RegisterRequestViewNeedsRegeneration = RegisterRequest(struct { Tailnet string }{}) -// View returns a readonly view of DERPHomeParams. +// View returns a read-only view of DERPHomeParams. func (p *DERPHomeParams) View() DERPHomeParamsView { return DERPHomeParamsView{Đļ: p} } @@ -828,7 +1419,7 @@ type DERPHomeParamsView struct { Đļ *DERPHomeParams } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPHomeParamsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -840,8 +1431,17 @@ func (v DERPHomeParamsView) AsStruct() *DERPHomeParams { return v.Đļ.Clone() } -func (v DERPHomeParamsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v DERPHomeParamsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v DERPHomeParamsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *DERPHomeParamsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -850,23 +1450,49 @@ func (v *DERPHomeParamsView) UnmarshalJSON(b []byte) error { return nil } var x DERPHomeParams - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v DERPHomeParamsView) RegionScore() views.Map[int, float64] { - return views.MapOf(v.Đļ.RegionScore) -} - -// A compilation failure here means this code must be regenerated, with the command at the top of this file. -var _DERPHomeParamsViewNeedsRegeneration = DERPHomeParams(struct { - RegionScore map[int]float64 +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *DERPHomeParamsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x DERPHomeParams + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// RegionScore scales latencies of DERP regions by a given scaling +// factor when determining which region to use as the home +// ("preferred") DERP. Scores in the range (0, 1) will cause this +// region to be proportionally more preferred, and scores in the range +// (1, ∞) will penalize a region. +// +// If a region is not present in this map, it is treated as having a +// score of 1.0. +// +// Scores should not be 0 or negative; such scores will be ignored. +// +// A nil map means no change from the previous value (if any); an empty +// non-nil map can be sent to reset all scores back to 1.0. +func (v DERPHomeParamsView) RegionScore() views.Map[int, float64] { + return views.MapOf(v.Đļ.RegionScore) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _DERPHomeParamsViewNeedsRegeneration = DERPHomeParams(struct { + RegionScore map[int]float64 }{}) -// View returns a readonly view of DERPRegion. +// View returns a read-only view of DERPRegion. func (p *DERPRegion) View() DERPRegionView { return DERPRegionView{Đļ: p} } @@ -882,7 +1508,7 @@ type DERPRegionView struct { Đļ *DERPRegion } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPRegionView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -894,8 +1520,17 @@ func (v DERPRegionView) AsStruct() *DERPRegion { return v.Đļ.Clone() } -func (v DERPRegionView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v DERPRegionView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v DERPRegionView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *DERPRegionView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -904,35 +1539,108 @@ func (v *DERPRegionView) UnmarshalJSON(b []byte) error { return nil } var x DERPRegion - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *DERPRegionView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x DERPRegion + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } -func (v DERPRegionView) RegionID() int { return v.Đļ.RegionID } +// RegionID is a unique integer for a geographic region. +// +// It corresponds to the legacy derpN.tailscale.com hostnames +// used by older clients. (Older clients will continue to resolve +// derpN.tailscale.com when contacting peers, rather than use +// the server-provided DERPMap) +// +// RegionIDs must be non-zero, positive, and guaranteed to fit +// in a JavaScript number. +// +// RegionIDs in range 900-999 are reserved for end users to run their +// own DERP nodes. +func (v DERPRegionView) RegionID() int { return v.Đļ.RegionID } + +// RegionCode is a short name for the region. It's usually a popular +// city or airport code in the region: "nyc", "sf", "sin", +// "fra", etc. func (v DERPRegionView) RegionCode() string { return v.Đļ.RegionCode } + +// RegionName is a long English name for the region: "New York City", +// "San Francisco", "Singapore", "Frankfurt", etc. func (v DERPRegionView) RegionName() string { return v.Đļ.RegionName } + +// Latitude, Longitude are optional geographical coordinates of the DERP region's city, in degrees. func (v DERPRegionView) Latitude() float64 { return v.Đļ.Latitude } func (v DERPRegionView) Longitude() float64 { return v.Đļ.Longitude } -func (v DERPRegionView) Avoid() bool { return v.Đļ.Avoid } + +// Avoid is whether the client should avoid picking this as its home region. +// The region should only be used if a peer is there. Clients already using +// this region as their home should migrate away to a new region without +// Avoid set. +// +// Deprecated: because of bugs in past implementations combined with unclear +// docs that caused people to think the bugs were intentional, this field is +// deprecated. It was never supposed to cause STUN/DERP measurement probes, +// but due to bugs, it sometimes did. And then some parts of the code began +// to rely on that property. But then we were unable to use this field for +// its original purpose, nor its later imagined purpose, because various +// parts of the codebase thought it meant one thing and others thought it +// meant another. But it did something in the middle instead. So we're retiring +// it. Use NoMeasureNoHome instead. +func (v DERPRegionView) Avoid() bool { return v.Đļ.Avoid } + +// NoMeasureNoHome says that this regions should not be measured for its +// latency distance (STUN, HTTPS, etc) or availability (e.g. captive portal +// checks) and should never be selected as the node's home region. However, +// if a peer declares this region as its home, then this client is allowed +// to connect to it for the purpose of communicating with that peer. +// +// This is what the now deprecated Avoid bool was supposed to mean +// originally but had implementation bugs and documentation omissions. +func (v DERPRegionView) NoMeasureNoHome() bool { return v.Đļ.NoMeasureNoHome } + +// Nodes are the DERP nodes running in this region, in +// priority order for the current client. Client TLS +// connections should ideally only go to the first entry +// (falling back to the second if necessary). STUN packets +// should go to the first 1 or 2. +// +// If nodes within a region route packets amongst themselves, +// but not to other regions. That said, each user/domain +// should get a the same preferred node order, so if all nodes +// for a user/network pick the first one (as they should, when +// things are healthy), the inter-cluster routing is minimal +// to zero. func (v DERPRegionView) Nodes() views.SliceView[*DERPNode, DERPNodeView] { return views.SliceOfViews[*DERPNode, DERPNodeView](v.Đļ.Nodes) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _DERPRegionViewNeedsRegeneration = DERPRegion(struct { - RegionID int - RegionCode string - RegionName string - Latitude float64 - Longitude float64 - Avoid bool - Nodes []*DERPNode + RegionID int + RegionCode string + RegionName string + Latitude float64 + Longitude float64 + Avoid bool + NoMeasureNoHome bool + Nodes []*DERPNode }{}) -// View returns a readonly view of DERPMap. +// View returns a read-only view of DERPMap. func (p *DERPMap) View() DERPMapView { return DERPMapView{Đļ: p} } @@ -948,7 +1656,7 @@ type DERPMapView struct { Đļ *DERPMap } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPMapView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -960,8 +1668,17 @@ func (v DERPMapView) AsStruct() *DERPMap { return v.Đļ.Clone() } -func (v DERPMapView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v DERPMapView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v DERPMapView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *DERPMapView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -970,20 +1687,46 @@ func (v *DERPMapView) UnmarshalJSON(b []byte) error { return nil } var x DERPMap - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *DERPMapView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x DERPMap + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// HomeParams, if non-nil, is a change in home parameters. +// +// The rest of the DEPRMap fields, if zero, means unchanged. func (v DERPMapView) HomeParams() DERPHomeParamsView { return v.Đļ.HomeParams.View() } +// Regions is the set of geographic regions running DERP node(s). +// +// It's keyed by the DERPRegion.RegionID. +// +// The numbers are not necessarily contiguous. func (v DERPMapView) Regions() views.MapFn[int, *DERPRegion, DERPRegionView] { return views.MapFnOf(v.Đļ.Regions, func(t *DERPRegion) DERPRegionView { return t.View() }) } + +// OmitDefaultRegions specifies to not use Tailscale's DERP servers, and only use those +// specified in this DERPMap. If there are none set outside of the defaults, this is a noop. +// +// This field is only meaningful if the Regions map is non-nil (indicating a change). func (v DERPMapView) OmitDefaultRegions() bool { return v.Đļ.OmitDefaultRegions } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -993,7 +1736,7 @@ var _DERPMapViewNeedsRegeneration = DERPMap(struct { OmitDefaultRegions bool }{}) -// View returns a readonly view of DERPNode. +// View returns a read-only view of DERPNode. func (p *DERPNode) View() DERPNodeView { return DERPNodeView{Đļ: p} } @@ -1009,7 +1752,7 @@ type DERPNodeView struct { Đļ *DERPNode } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPNodeView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1021,8 +1764,17 @@ func (v DERPNodeView) AsStruct() *DERPNode { return v.Đļ.Clone() } -func (v DERPNodeView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v DERPNodeView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v DERPNodeView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *DERPNodeView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1031,25 +1783,94 @@ func (v *DERPNodeView) UnmarshalJSON(b []byte) error { return nil } var x DERPNode - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v DERPNodeView) Name() string { return v.Đļ.Name } -func (v DERPNodeView) RegionID() int { return v.Đļ.RegionID } -func (v DERPNodeView) HostName() string { return v.Đļ.HostName } -func (v DERPNodeView) CertName() string { return v.Đļ.CertName } -func (v DERPNodeView) IPv4() string { return v.Đļ.IPv4 } -func (v DERPNodeView) IPv6() string { return v.Đļ.IPv6 } -func (v DERPNodeView) STUNPort() int { return v.Đļ.STUNPort } -func (v DERPNodeView) STUNOnly() bool { return v.Đļ.STUNOnly } -func (v DERPNodeView) DERPPort() int { return v.Đļ.DERPPort } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *DERPNodeView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x DERPNode + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// Name is a unique node name (across all regions). +// It is not a host name. +// It's typically of the form "1b", "2a", "3b", etc. (region +// ID + suffix within that region) +func (v DERPNodeView) Name() string { return v.Đļ.Name } + +// RegionID is the RegionID of the DERPRegion that this node +// is running in. +func (v DERPNodeView) RegionID() int { return v.Đļ.RegionID } + +// HostName is the DERP node's hostname. +// +// It is required but need not be unique; multiple nodes may +// have the same HostName but vary in configuration otherwise. +func (v DERPNodeView) HostName() string { return v.Đļ.HostName } + +// CertName optionally specifies the expected TLS cert common +// name. If empty, HostName is used. If CertName is non-empty, +// HostName is only used for the TCP dial (if IPv4/IPv6 are +// not present) + TLS ClientHello. +// +// As a special case, if CertName starts with "sha256-raw:", +// then the rest of the string is a hex-encoded SHA256 of the +// cert to expect. This is used for self-signed certs. +// In this case, the HostName field will typically be an IP +// address literal. +func (v DERPNodeView) CertName() string { return v.Đļ.CertName } + +// IPv4 optionally forces an IPv4 address to use, instead of using DNS. +// If empty, A record(s) from DNS lookups of HostName are used. +// If the string is not an IPv4 address, IPv4 is not used; the +// conventional string to disable IPv4 (and not use DNS) is +// "none". +func (v DERPNodeView) IPv4() string { return v.Đļ.IPv4 } + +// IPv6 optionally forces an IPv6 address to use, instead of using DNS. +// If empty, AAAA record(s) from DNS lookups of HostName are used. +// If the string is not an IPv6 address, IPv6 is not used; the +// conventional string to disable IPv6 (and not use DNS) is +// "none". +func (v DERPNodeView) IPv6() string { return v.Đļ.IPv6 } + +// Port optionally specifies a STUN port to use. +// Zero means 3478. +// To disable STUN on this node, use -1. +func (v DERPNodeView) STUNPort() int { return v.Đļ.STUNPort } + +// STUNOnly marks a node as only a STUN server and not a DERP +// server. +func (v DERPNodeView) STUNOnly() bool { return v.Đļ.STUNOnly } + +// DERPPort optionally provides an alternate TLS port number +// for the DERP HTTPS server. +// +// If zero, 443 is used. +func (v DERPNodeView) DERPPort() int { return v.Đļ.DERPPort } + +// InsecureForTests is used by unit tests to disable TLS verification. +// It should not be set by users. func (v DERPNodeView) InsecureForTests() bool { return v.Đļ.InsecureForTests } -func (v DERPNodeView) STUNTestIP() string { return v.Đļ.STUNTestIP } -func (v DERPNodeView) CanPort80() bool { return v.Đļ.CanPort80 } + +// STUNTestIP is used in tests to override the STUN server's IP. +// If empty, it's assumed to be the same as the DERP server. +func (v DERPNodeView) STUNTestIP() string { return v.Đļ.STUNTestIP } + +// CanPort80 specifies whether this DERP node is accessible over HTTP +// on port 80 specifically. This is used for captive portal checks. +func (v DERPNodeView) CanPort80() bool { return v.Đļ.CanPort80 } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _DERPNodeViewNeedsRegeneration = DERPNode(struct { @@ -1067,7 +1888,7 @@ var _DERPNodeViewNeedsRegeneration = DERPNode(struct { CanPort80 bool }{}) -// View returns a readonly view of SSHRule. +// View returns a read-only view of SSHRule. func (p *SSHRule) View() SSHRuleView { return SSHRuleView{Đļ: p} } @@ -1083,7 +1904,7 @@ type SSHRuleView struct { Đļ *SSHRule } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v SSHRuleView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1095,8 +1916,17 @@ func (v SSHRuleView) AsStruct() *SSHRule { return v.Đļ.Clone() } -func (v SSHRuleView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v SSHRuleView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v SSHRuleView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *SSHRuleView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1105,28 +1935,69 @@ func (v *SSHRuleView) UnmarshalJSON(b []byte) error { return nil } var x SSHRule - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v SSHRuleView) RuleExpires() *time.Time { - if v.Đļ.RuleExpires == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *SSHRuleView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") } - x := *v.Đļ.RuleExpires - return &x + var x SSHRule + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil } +// RuleExpires, if non-nil, is when this rule expires. +// +// For example, a (principal,sshuser) tuple might be granted +// prompt-free SSH access for N minutes, so this rule would be +// before a expiration-free rule for the same principal that +// required an auth prompt. This permits the control plane to +// be out of the path for already-authorized SSH pairs. +// +// Once a rule matches, the lifetime of any accepting connection +// is subject to the SSHAction.SessionExpires time, if any. +func (v SSHRuleView) RuleExpires() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.Đļ.RuleExpires) +} + +// Principals matches an incoming connection. If the connection +// matches anything in this list and also matches SSHUsers, +// then Action is applied. func (v SSHRuleView) Principals() views.SliceView[*SSHPrincipal, SSHPrincipalView] { return views.SliceOfViews[*SSHPrincipal, SSHPrincipalView](v.Đļ.Principals) } +// SSHUsers are the SSH users that this rule matches. It is a +// map from either ssh-user|"*" => local-user. The map must +// contain a key for either ssh-user or, as a fallback, "*" to +// match anything. If it does, the map entry's value is the +// actual user that's logged in. +// If the map value is the empty string (for either the +// requested SSH user or "*"), the rule doesn't match. +// If the map value is "=", it means the ssh-user should map +// directly to the local-user. +// It may be nil if the Action is reject. func (v SSHRuleView) SSHUsers() views.Map[string, string] { return views.MapOf(v.Đļ.SSHUsers) } -func (v SSHRuleView) Action() SSHActionView { return v.Đļ.Action.View() } -func (v SSHRuleView) AcceptEnv() views.Slice[string] { return views.SliceOf(v.Đļ.AcceptEnv) } + +// Action is the outcome to task. +// A nil or invalid action means to deny. +func (v SSHRuleView) Action() SSHActionView { return v.Đļ.Action.View() } + +// AcceptEnv is a slice of environment variable names that are allowlisted +// for the SSH rule in the policy file. +// +// AcceptEnv values may contain * and ? wildcard characters which match against +// an arbitrary number of characters or a single character respectively. +func (v SSHRuleView) AcceptEnv() views.Slice[string] { return views.SliceOf(v.Đļ.AcceptEnv) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _SSHRuleViewNeedsRegeneration = SSHRule(struct { @@ -1137,7 +2008,7 @@ var _SSHRuleViewNeedsRegeneration = SSHRule(struct { AcceptEnv []string }{}) -// View returns a readonly view of SSHAction. +// View returns a read-only view of SSHAction. func (p *SSHAction) View() SSHActionView { return SSHActionView{Đļ: p} } @@ -1153,7 +2024,7 @@ type SSHActionView struct { Đļ *SSHAction } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v SSHActionView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1165,8 +2036,17 @@ func (v SSHActionView) AsStruct() *SSHAction { return v.Đļ.Clone() } -func (v SSHActionView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v SSHActionView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v SSHActionView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *SSHActionView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1175,28 +2055,83 @@ func (v *SSHActionView) UnmarshalJSON(b []byte) error { return nil } var x SSHAction - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v SSHActionView) Message() string { return v.Đļ.Message } -func (v SSHActionView) Reject() bool { return v.Đļ.Reject } -func (v SSHActionView) Accept() bool { return v.Đļ.Accept } -func (v SSHActionView) SessionDuration() time.Duration { return v.Đļ.SessionDuration } -func (v SSHActionView) AllowAgentForwarding() bool { return v.Đļ.AllowAgentForwarding } -func (v SSHActionView) HoldAndDelegate() string { return v.Đļ.HoldAndDelegate } -func (v SSHActionView) AllowLocalPortForwarding() bool { return v.Đļ.AllowLocalPortForwarding } -func (v SSHActionView) AllowRemotePortForwarding() bool { return v.Đļ.AllowRemotePortForwarding } -func (v SSHActionView) Recorders() views.Slice[netip.AddrPort] { return views.SliceOf(v.Đļ.Recorders) } -func (v SSHActionView) OnRecordingFailure() *SSHRecorderFailureAction { - if v.Đļ.OnRecordingFailure == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *SSHActionView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") } - x := *v.Đļ.OnRecordingFailure - return &x + var x SSHAction + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// Message, if non-empty, is shown to the user before the +// action occurs. +func (v SSHActionView) Message() string { return v.Đļ.Message } + +// Reject, if true, terminates the connection. This action +// has higher priority that Accept, if given. +// The reason this is exists is primarily so a response +// from HoldAndDelegate has a way to stop the poll. +func (v SSHActionView) Reject() bool { return v.Đļ.Reject } + +// Accept, if true, accepts the connection immediately +// without further prompts. +func (v SSHActionView) Accept() bool { return v.Đļ.Accept } + +// SessionDuration, if non-zero, is how long the session can stay open +// before being forcefully terminated. +func (v SSHActionView) SessionDuration() time.Duration { return v.Đļ.SessionDuration } + +// AllowAgentForwarding, if true, allows accepted connections to forward +// the ssh agent if requested. +func (v SSHActionView) AllowAgentForwarding() bool { return v.Đļ.AllowAgentForwarding } + +// HoldAndDelegate, if non-empty, is a URL that serves an +// outcome verdict. The connection will be accepted and will +// block until the provided long-polling URL serves a new +// SSHAction JSON value. The URL must be fetched using the +// Noise transport (in package control/control{base,http}). +// If the long poll breaks before returning a complete HTTP +// response, it should be re-fetched as long as the SSH +// session is open. +// +// The following variables in the URL are expanded by tailscaled: +// +// - $SRC_NODE_IP (URL escaped) +// - $SRC_NODE_ID (Node.ID as int64 string) +// - $DST_NODE_IP (URL escaped) +// - $DST_NODE_ID (Node.ID as int64 string) +// - $SSH_USER (URL escaped, ssh user requested) +// - $LOCAL_USER (URL escaped, local user mapped) +func (v SSHActionView) HoldAndDelegate() string { return v.Đļ.HoldAndDelegate } + +// AllowLocalPortForwarding, if true, allows accepted connections +// to use local port forwarding if requested. +func (v SSHActionView) AllowLocalPortForwarding() bool { return v.Đļ.AllowLocalPortForwarding } + +// AllowRemotePortForwarding, if true, allows accepted connections +// to use remote port forwarding if requested. +func (v SSHActionView) AllowRemotePortForwarding() bool { return v.Đļ.AllowRemotePortForwarding } + +// Recorders defines the destinations of the SSH session recorders. +// The recording will be uploaded to http://addr:port/record. +func (v SSHActionView) Recorders() views.Slice[netip.AddrPort] { return views.SliceOf(v.Đļ.Recorders) } + +// OnRecorderFailure is the action to take if recording fails. +// If nil, the default action is to fail open. +func (v SSHActionView) OnRecordingFailure() views.ValuePointer[SSHRecorderFailureAction] { + return views.ValuePointerOf(v.Đļ.OnRecordingFailure) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -1213,7 +2148,7 @@ var _SSHActionViewNeedsRegeneration = SSHAction(struct { OnRecordingFailure *SSHRecorderFailureAction }{}) -// View returns a readonly view of SSHPrincipal. +// View returns a read-only view of SSHPrincipal. func (p *SSHPrincipal) View() SSHPrincipalView { return SSHPrincipalView{Đļ: p} } @@ -1229,7 +2164,7 @@ type SSHPrincipalView struct { Đļ *SSHPrincipal } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v SSHPrincipalView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1241,8 +2176,17 @@ func (v SSHPrincipalView) AsStruct() *SSHPrincipal { return v.Đļ.Clone() } -func (v SSHPrincipalView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v SSHPrincipalView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v SSHPrincipalView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *SSHPrincipalView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1251,29 +2195,55 @@ func (v *SSHPrincipalView) UnmarshalJSON(b []byte) error { return nil } var x SSHPrincipal - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v SSHPrincipalView) Node() StableNodeID { return v.Đļ.Node } -func (v SSHPrincipalView) NodeIP() string { return v.Đļ.NodeIP } -func (v SSHPrincipalView) UserLogin() string { return v.Đļ.UserLogin } -func (v SSHPrincipalView) Any() bool { return v.Đļ.Any } -func (v SSHPrincipalView) PubKeys() views.Slice[string] { return views.SliceOf(v.Đļ.PubKeys) } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *SSHPrincipalView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x SSHPrincipal + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v SSHPrincipalView) Node() StableNodeID { return v.Đļ.Node } +func (v SSHPrincipalView) NodeIP() string { return v.Đļ.NodeIP } + +// email-ish: foo@example.com, bar@github +func (v SSHPrincipalView) UserLogin() string { return v.Đļ.UserLogin } + +// if true, match any connection +func (v SSHPrincipalView) Any() bool { return v.Đļ.Any } + +// UnusedPubKeys was public key support. It never became an official product +// feature and so as of 2024-12-12 is being removed. +// This stub exists to remind us not to re-use the JSON field name "pubKeys" +// in the future if we bring it back with different semantics. +// +// Deprecated: do not use. It does nothing. +func (v SSHPrincipalView) UnusedPubKeys() views.Slice[string] { + return views.SliceOf(v.Đļ.UnusedPubKeys) +} // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _SSHPrincipalViewNeedsRegeneration = SSHPrincipal(struct { - Node StableNodeID - NodeIP string - UserLogin string - Any bool - PubKeys []string + Node StableNodeID + NodeIP string + UserLogin string + Any bool + UnusedPubKeys []string }{}) -// View returns a readonly view of ControlDialPlan. +// View returns a read-only view of ControlDialPlan. func (p *ControlDialPlan) View() ControlDialPlanView { return ControlDialPlanView{Đļ: p} } @@ -1289,7 +2259,7 @@ type ControlDialPlanView struct { Đļ *ControlDialPlan } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ControlDialPlanView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1301,8 +2271,17 @@ func (v ControlDialPlanView) AsStruct() *ControlDialPlan { return v.Đļ.Clone() } -func (v ControlDialPlanView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ControlDialPlanView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ControlDialPlanView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *ControlDialPlanView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1311,13 +2290,27 @@ func (v *ControlDialPlanView) UnmarshalJSON(b []byte) error { return nil } var x ControlDialPlan - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *ControlDialPlanView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x ControlDialPlan + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// An empty list means the default: use DNS (unspecified which DNS). func (v ControlDialPlanView) Candidates() views.Slice[ControlIPCandidate] { return views.SliceOf(v.Đļ.Candidates) } @@ -1327,7 +2320,7 @@ var _ControlDialPlanViewNeedsRegeneration = ControlDialPlan(struct { Candidates []ControlIPCandidate }{}) -// View returns a readonly view of Location. +// View returns a read-only view of Location. func (p *Location) View() LocationView { return LocationView{Đļ: p} } @@ -1343,7 +2336,7 @@ type LocationView struct { Đļ *Location } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v LocationView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1355,8 +2348,17 @@ func (v LocationView) AsStruct() *Location { return v.Đļ.Clone() } -func (v LocationView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v LocationView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v LocationView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *LocationView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1365,20 +2367,55 @@ func (v *LocationView) UnmarshalJSON(b []byte) error { return nil } var x Location - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v LocationView) Country() string { return v.Đļ.Country } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *LocationView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Location + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// User friendly country name, with proper capitalization ("Canada") +func (v LocationView) Country() string { return v.Đļ.Country } + +// ISO 3166-1 alpha-2 in upper case ("CA") func (v LocationView) CountryCode() string { return v.Đļ.CountryCode } -func (v LocationView) City() string { return v.Đļ.City } -func (v LocationView) CityCode() string { return v.Đļ.CityCode } -func (v LocationView) Latitude() float64 { return v.Đļ.Latitude } -func (v LocationView) Longitude() float64 { return v.Đļ.Longitude } -func (v LocationView) Priority() int { return v.Đļ.Priority } + +// User friendly city name, with proper capitalization ("Squamish") +func (v LocationView) City() string { return v.Đļ.City } + +// CityCode is a short code representing the city in upper case. +// CityCode is used to disambiguate a city from another location +// with the same city name. It uniquely identifies a particular +// geographical location, within the tailnet. +// IATA, ICAO or ISO 3166-2 codes are recommended ("YSE") +func (v LocationView) CityCode() string { return v.Đļ.CityCode } + +// Latitude, Longitude are optional geographical coordinates of the node, in degrees. +// No particular accuracy level is promised; the coordinates may simply be the center of the city or country. +func (v LocationView) Latitude() float64 { return v.Đļ.Latitude } +func (v LocationView) Longitude() float64 { return v.Đļ.Longitude } + +// Priority determines the order of use of an exit node when a +// location based preference matches more than one exit node, +// the node with the highest priority wins. Nodes of equal +// probability may be selected arbitrarily. +// +// A value of 0 means the exit node does not have a priority +// preference. A negative int is not allowed. +func (v LocationView) Priority() int { return v.Đļ.Priority } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _LocationViewNeedsRegeneration = Location(struct { @@ -1391,7 +2428,7 @@ var _LocationViewNeedsRegeneration = Location(struct { Priority int }{}) -// View returns a readonly view of UserProfile. +// View returns a read-only view of UserProfile. func (p *UserProfile) View() UserProfileView { return UserProfileView{Đļ: p} } @@ -1407,7 +2444,7 @@ type UserProfileView struct { Đļ *UserProfile } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v UserProfileView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1419,8 +2456,17 @@ func (v UserProfileView) AsStruct() *UserProfile { return v.Đļ.Clone() } -func (v UserProfileView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v UserProfileView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v UserProfileView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *UserProfileView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -1429,18 +2475,34 @@ func (v *UserProfileView) UnmarshalJSON(b []byte) error { return nil } var x UserProfile - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *UserProfileView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x UserProfile + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } -func (v UserProfileView) ID() UserID { return v.Đļ.ID } -func (v UserProfileView) LoginName() string { return v.Đļ.LoginName } +func (v UserProfileView) ID() UserID { return v.Đļ.ID } + +// "alice@smith.com"; for display purposes only (provider is not listed) +func (v UserProfileView) LoginName() string { return v.Đļ.LoginName } + +// "Alice Smith" func (v UserProfileView) DisplayName() string { return v.Đļ.DisplayName } func (v UserProfileView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } -func (v UserProfileView) Roles() emptyStructJSONSlice { return v.Đļ.Roles } func (v UserProfileView) Equal(v2 UserProfileView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -1449,5 +2511,182 @@ var _UserProfileViewNeedsRegeneration = UserProfile(struct { LoginName string DisplayName string ProfilePicURL string - Roles emptyStructJSONSlice +}{}) + +// View returns a read-only view of VIPService. +func (p *VIPService) View() VIPServiceView { + return VIPServiceView{Đļ: p} +} + +// VIPServiceView provides a read-only view over VIPService. +// +// Its methods should only be called if `Valid()` returns true. +type VIPServiceView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *VIPService +} + +// Valid reports whether v's underlying value is non-nil. +func (v VIPServiceView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v VIPServiceView) AsStruct() *VIPService { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v VIPServiceView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v VIPServiceView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *VIPServiceView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x VIPService + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *VIPServiceView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x VIPService + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// Name is the name of the service. The Name uniquely identifies a service +// on a particular tailnet, and so also corresponds uniquely to the pair of +// IP addresses belonging to the VIP service. +func (v VIPServiceView) Name() ServiceName { return v.Đļ.Name } + +// Ports specify which ProtoPorts are made available by this node +// on the service's IPs. +func (v VIPServiceView) Ports() views.Slice[ProtoPortRange] { return views.SliceOf(v.Đļ.Ports) } + +// Active specifies whether new requests for the service should be +// sent to this node by control. +func (v VIPServiceView) Active() bool { return v.Đļ.Active } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _VIPServiceViewNeedsRegeneration = VIPService(struct { + Name ServiceName + Ports []ProtoPortRange + Active bool +}{}) + +// View returns a read-only view of SSHPolicy. +func (p *SSHPolicy) View() SSHPolicyView { + return SSHPolicyView{Đļ: p} +} + +// SSHPolicyView provides a read-only view over SSHPolicy. +// +// Its methods should only be called if `Valid()` returns true. +type SSHPolicyView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *SSHPolicy +} + +// Valid reports whether v's underlying value is non-nil. +func (v SSHPolicyView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v SSHPolicyView) AsStruct() *SSHPolicy { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v SSHPolicyView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v SSHPolicyView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *SSHPolicyView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x SSHPolicy + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *SSHPolicyView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x SSHPolicy + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// Rules are the rules to process for an incoming SSH connection. The first +// matching rule takes its action and stops processing further rules. +// +// When an incoming connection first starts, all rules are evaluated in +// "none" auth mode, where the client hasn't even been asked to send a +// public key. All SSHRule.Principals requiring a public key won't match. If +// a rule matches on the first pass and its Action is reject, the +// authentication fails with that action's rejection message, if any. +// +// If the first pass rule evaluation matches nothing without matching an +// Action with Reject set, the rules are considered to see whether public +// keys might still result in a match. If not, "none" auth is terminated +// before proceeding to public key mode. If so, the client is asked to try +// public key authentication and the rules are evaluated again for each of +// the client's present keys. +func (v SSHPolicyView) Rules() views.SliceView[*SSHRule, SSHRuleView] { + return views.SliceOfViews[*SSHRule, SSHRuleView](v.Đļ.Rules) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _SSHPolicyViewNeedsRegeneration = SSHPolicy(struct { + Rules []*SSHRule }{}) diff --git a/taildrop/send.go b/taildrop/send.go deleted file mode 100644 index 0dff71b24..000000000 --- a/taildrop/send.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "crypto/sha256" - "errors" - "io" - "os" - "path/filepath" - "sync" - "time" - - "tailscale.com/envknob" - "tailscale.com/ipn" - "tailscale.com/tstime" - "tailscale.com/version/distro" -) - -type incomingFileKey struct { - id ClientID - name string // e.g., "foo.jpeg" -} - -type incomingFile struct { - clock tstime.DefaultClock - - started time.Time - size int64 // or -1 if unknown; never 0 - w io.Writer // underlying writer - sendFileNotify func() // called when done - partialPath string // non-empty in direct mode - finalPath string // not used in direct mode - - mu sync.Mutex - copied int64 - done bool - lastNotify time.Time -} - -func (f *incomingFile) Write(p []byte) (n int, err error) { - n, err = f.w.Write(p) - - var needNotify bool - defer func() { - if needNotify { - f.sendFileNotify() - } - }() - if n > 0 { - f.mu.Lock() - defer f.mu.Unlock() - f.copied += int64(n) - now := f.clock.Now() - if f.lastNotify.IsZero() || now.Sub(f.lastNotify) > time.Second { - f.lastNotify = now - needNotify = true - } - } - return n, err -} - -// PutFile stores a file into [Manager.Dir] from a given client id. -// The baseName must be a base filename without any slashes. -// The length is the expected length of content to read from r, -// it may be negative to indicate that it is unknown. -// It returns the length of the entire file. -// -// If there is a failure reading from r, then the partial file is not deleted -// for some period of time. The [Manager.PartialFiles] and [Manager.HashPartialFile] -// methods may be used to list all partial files and to compute the hash for a -// specific partial file. This allows the client to determine whether to resume -// a partial file. While resuming, PutFile may be called again with a non-zero -// offset to specify where to resume receiving data at. -func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, length int64) (int64, error) { - switch { - case m == nil || m.opts.Dir == "": - return 0, ErrNoTaildrop - case !envknob.CanTaildrop(): - return 0, ErrNoTaildrop - case distro.Get() == distro.Unraid && !m.opts.DirectFileMode: - return 0, ErrNotAccessible - } - dstPath, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return 0, err - } - - redactAndLogError := func(action string, err error) error { - err = redactError(err) - m.opts.Logf("put %v error: %v", action, err) - return err - } - - // Check whether there is an in-progress transfer for the file. - partialPath := dstPath + id.partialSuffix() - inFileKey := incomingFileKey{id, baseName} - inFile, loaded := m.incomingFiles.LoadOrInit(inFileKey, func() *incomingFile { - inFile := &incomingFile{ - clock: m.opts.Clock, - started: m.opts.Clock.Now(), - size: length, - sendFileNotify: m.opts.SendFileNotify, - } - if m.opts.DirectFileMode { - inFile.partialPath = partialPath - inFile.finalPath = dstPath - } - return inFile - }) - if loaded { - return 0, ErrFileExists - } - defer m.incomingFiles.Delete(inFileKey) - m.deleter.Remove(filepath.Base(partialPath)) // avoid deleting the partial file while receiving - - // Create (if not already) the partial file with read-write permissions. - f, err := os.OpenFile(partialPath, os.O_CREATE|os.O_RDWR, 0666) - if err != nil { - return 0, redactAndLogError("Create", err) - } - defer func() { - f.Close() // best-effort to cleanup dangling file handles - if err != nil { - m.deleter.Insert(filepath.Base(partialPath)) // mark partial file for eventual deletion - } - }() - inFile.w = f - - // Record that we have started to receive at least one file. - // This is used by the deleter upon a cold-start to scan the directory - // for any files that need to be deleted. - if m.opts.State != nil { - if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { - if err := m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1}); err != nil { - m.opts.Logf("WriteState error: %v", err) // non-fatal error - } - } - } - - // A positive offset implies that we are resuming an existing file. - // Seek to the appropriate offset and truncate the file. - if offset != 0 { - currLength, err := f.Seek(0, io.SeekEnd) - if err != nil { - return 0, redactAndLogError("Seek", err) - } - if offset < 0 || offset > currLength { - return 0, redactAndLogError("Seek", err) - } - if _, err := f.Seek(offset, io.SeekStart); err != nil { - return 0, redactAndLogError("Seek", err) - } - if err := f.Truncate(offset); err != nil { - return 0, redactAndLogError("Truncate", err) - } - } - - // Copy the contents of the file. - copyLength, err := io.Copy(inFile, r) - if err != nil { - return 0, redactAndLogError("Copy", err) - } - if length >= 0 && copyLength != length { - return 0, redactAndLogError("Copy", errors.New("copied an unexpected number of bytes")) - } - if err := f.Close(); err != nil { - return 0, redactAndLogError("Close", err) - } - fileLength := offset + copyLength - - inFile.mu.Lock() - inFile.done = true - inFile.mu.Unlock() - - // File has been successfully received, rename the partial file - // to the final destination filename. If a file of that name already exists, - // then try multiple times with variations of the filename. - computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) { - return sha256File(partialPath) - }) - maxRetries := 10 - for ; maxRetries > 0; maxRetries-- { - // Atomically rename the partial file as the destination file if it doesn't exist. - // Otherwise, it returns the length of the current destination file. - // The operation is atomic. - dstLength, err := func() (int64, error) { - m.renameMu.Lock() - defer m.renameMu.Unlock() - switch fi, err := os.Stat(dstPath); { - case os.IsNotExist(err): - return -1, os.Rename(partialPath, dstPath) - case err != nil: - return -1, err - default: - return fi.Size(), nil - } - }() - if err != nil { - return 0, redactAndLogError("Rename", err) - } - if dstLength < 0 { - break // we successfully renamed; so stop - } - - // Avoid the final rename if a destination file has the same contents. - // - // Note: this is best effort and copying files from iOS from the Media Library - // results in processing on the iOS side which means the size and shas of the - // same file can be different. - if dstLength == fileLength { - partialSum, err := computePartialSum() - if err != nil { - return 0, redactAndLogError("Rename", err) - } - dstSum, err := sha256File(dstPath) - if err != nil { - return 0, redactAndLogError("Rename", err) - } - if dstSum == partialSum { - if err := os.Remove(partialPath); err != nil { - return 0, redactAndLogError("Remove", err) - } - break // we successfully found a content match; so stop - } - } - - // Choose a new destination filename and try again. - dstPath = NextFilename(dstPath) - inFile.finalPath = dstPath - } - if maxRetries <= 0 { - return 0, errors.New("too many retries trying to rename partial file") - } - m.totalReceived.Add(1) - m.opts.SendFileNotify() - return fileLength, nil -} - -func sha256File(file string) (out [sha256.Size]byte, err error) { - h := sha256.New() - f, err := os.Open(file) - if err != nil { - return out, err - } - defer f.Close() - if _, err := io.Copy(h, f); err != nil { - return out, err - } - return [sha256.Size]byte(h.Sum(nil)), nil -} diff --git a/tempfork/acme/README.md b/tempfork/acme/README.md new file mode 100644 index 000000000..def357fc1 --- /dev/null +++ b/tempfork/acme/README.md @@ -0,0 +1,14 @@ +# tempfork/acme + +This is a vendored copy of Tailscale's https://github.com/tailscale/golang-x-crypto, +which is a fork of golang.org/x/crypto/acme. + +See https://github.com/tailscale/tailscale/issues/10238 for unforking +status. + +The https://github.com/tailscale/golang-x-crypto location exists to +let us do rebases from upstream easily, and then we update tempfork/acme +in the same commit we go get github.com/tailscale/golang-x-crypto@main. +See the comment on the TestSyncedToUpstream test for details. That +test should catch that forgotten step. + diff --git a/tempfork/acme/acme.go b/tempfork/acme/acme.go new file mode 100644 index 000000000..bbddb9551 --- /dev/null +++ b/tempfork/acme/acme.go @@ -0,0 +1,866 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package acme provides an implementation of the +// Automatic Certificate Management Environment (ACME) spec, +// most famously used by Let's Encrypt. +// +// The initial implementation of this package was based on an early version +// of the spec. The current implementation supports only the modern +// RFC 8555 but some of the old API surface remains for compatibility. +// While code using the old API will still compile, it will return an error. +// Note the deprecation comments to update your code. +// +// See https://tools.ietf.org/html/rfc8555 for the spec. +// +// Most common scenarios will want to use autocert subdirectory instead, +// which provides automatic access to certificates from Let's Encrypt +// and any other ACME-based CA. +package acme + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "math/big" + "net/http" + "strings" + "sync" + "time" +) + +const ( + // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. + LetsEncryptURL = "https://acme-v02.api.letsencrypt.org/directory" + + // ALPNProto is the ALPN protocol name used by a CA server when validating + // tls-alpn-01 challenges. + // + // Package users must ensure their servers can negotiate the ACME ALPN in + // order for tls-alpn-01 challenge verifications to succeed. + // See the crypto/tls package's Config.NextProtos field. + ALPNProto = "acme-tls/1" +) + +// idPeACMEIdentifier is the OID for the ACME extension for the TLS-ALPN challenge. +// https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05#section-5.1 +var idPeACMEIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} + +const ( + maxChainLen = 5 // max depth and breadth of a certificate chain + maxCertSize = 1 << 20 // max size of a certificate, in DER bytes + // Used for decoding certs from application/pem-certificate-chain response, + // the default when in RFC mode. + maxCertChainSize = maxCertSize * maxChainLen + + // Max number of collected nonces kept in memory. + // Expect usual peak of 1 or 2. + maxNonces = 100 +) + +// Client is an ACME client. +// +// The only required field is Key. An example of creating a client with a new key +// is as follows: +// +// key, err := rsa.GenerateKey(rand.Reader, 2048) +// if err != nil { +// log.Fatal(err) +// } +// client := &Client{Key: key} +type Client struct { + // Key is the account key used to register with a CA and sign requests. + // Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey. + // + // The following algorithms are supported: + // RS256, ES256, ES384 and ES512. + // See RFC 7518 for more details about the algorithms. + Key crypto.Signer + + // HTTPClient optionally specifies an HTTP client to use + // instead of http.DefaultClient. + HTTPClient *http.Client + + // DirectoryURL points to the CA directory endpoint. + // If empty, LetsEncryptURL is used. + // Mutating this value after a successful call of Client's Discover method + // will have no effect. + DirectoryURL string + + // RetryBackoff computes the duration after which the nth retry of a failed request + // should occur. The value of n for the first call on failure is 1. + // The values of r and resp are the request and response of the last failed attempt. + // If the returned value is negative or zero, no more retries are done and an error + // is returned to the caller of the original method. + // + // Requests which result in a 4xx client error are not retried, + // except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests. + // + // If RetryBackoff is nil, a truncated exponential backoff algorithm + // with the ceiling of 10 seconds is used, where each subsequent retry n + // is done after either ("Retry-After" + jitter) or (2^n seconds + jitter), + // preferring the former if "Retry-After" header is found in the resp. + // The jitter is a random value up to 1 second. + RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration + + // UserAgent is prepended to the User-Agent header sent to the ACME server, + // which by default is this package's name and version. + // + // Reusable libraries and tools in particular should set this value to be + // identifiable by the server, in case they are causing issues. + UserAgent string + + cacheMu sync.Mutex + dir *Directory // cached result of Client's Discover method + // KID is the key identifier provided by the CA. If not provided it will be + // retrieved from the CA by making a call to the registration endpoint. + KID KeyID + + noncesMu sync.Mutex + nonces map[string]struct{} // nonces collected from previous responses +} + +// accountKID returns a key ID associated with c.Key, the account identity +// provided by the CA during RFC based registration. +// It assumes c.Discover has already been called. +// +// accountKID requires at most one network roundtrip. +// It caches only successful result. +// +// When in pre-RFC mode or when c.getRegRFC responds with an error, accountKID +// returns noKeyID. +func (c *Client) accountKID(ctx context.Context) KeyID { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + if c.KID != noKeyID { + return c.KID + } + a, err := c.getRegRFC(ctx) + if err != nil { + return noKeyID + } + c.KID = KeyID(a.URI) + return c.KID +} + +var errPreRFC = errors.New("acme: server does not support the RFC 8555 version of ACME") + +// Discover performs ACME server discovery using c.DirectoryURL. +// +// It caches successful result. So, subsequent calls will not result in +// a network round-trip. This also means mutating c.DirectoryURL after successful call +// of this method will have no effect. +func (c *Client) Discover(ctx context.Context) (Directory, error) { + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + if c.dir != nil { + return *c.dir, nil + } + + res, err := c.get(ctx, c.directoryURL(), wantStatus(http.StatusOK)) + if err != nil { + return Directory{}, err + } + defer res.Body.Close() + c.addNonce(res.Header) + + var v struct { + Reg string `json:"newAccount"` + Authz string `json:"newAuthz"` + Order string `json:"newOrder"` + Revoke string `json:"revokeCert"` + Nonce string `json:"newNonce"` + KeyChange string `json:"keyChange"` + RenewalInfo string `json:"renewalInfo"` + Meta struct { + Terms string `json:"termsOfService"` + Website string `json:"website"` + CAA []string `json:"caaIdentities"` + ExternalAcct bool `json:"externalAccountRequired"` + } + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return Directory{}, err + } + if v.Order == "" { + return Directory{}, errPreRFC + } + c.dir = &Directory{ + RegURL: v.Reg, + AuthzURL: v.Authz, + OrderURL: v.Order, + RevokeURL: v.Revoke, + NonceURL: v.Nonce, + KeyChangeURL: v.KeyChange, + RenewalInfoURL: v.RenewalInfo, + Terms: v.Meta.Terms, + Website: v.Meta.Website, + CAA: v.Meta.CAA, + ExternalAccountRequired: v.Meta.ExternalAcct, + } + return *c.dir, nil +} + +func (c *Client) directoryURL() string { + if c.DirectoryURL != "" { + return c.DirectoryURL + } + return LetsEncryptURL +} + +// CreateCert was part of the old version of ACME. It is incompatible with RFC 8555. +// +// Deprecated: this was for the pre-RFC 8555 version of ACME. Callers should use CreateOrderCert. +func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) { + return nil, "", errPreRFC +} + +// FetchCert retrieves already issued certificate from the given url, in DER format. +// It retries the request until the certificate is successfully retrieved, +// context is cancelled by the caller or an error response is received. +// +// If the bundle argument is true, the returned value also contains the CA (issuer) +// certificate chain. +// +// FetchCert returns an error if the CA's response or chain was unreasonably large. +// Callers are encouraged to parse the returned value to ensure the certificate is valid +// and has expected features. +func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.fetchCertRFC(ctx, url, bundle) +} + +// RevokeCert revokes a previously issued certificate cert, provided in DER format. +// +// The key argument, used to sign the request, must be authorized +// to revoke the certificate. It's up to the CA to decide which keys are authorized. +// For instance, the key pair of the certificate may be authorized. +// If the key is nil, c.Key is used instead. +func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error { + if _, err := c.Discover(ctx); err != nil { + return err + } + return c.revokeCertRFC(ctx, key, cert, reason) +} + +// FetchRenewalInfo retrieves the RenewalInfo from Directory.RenewalInfoURL. +func (c *Client) FetchRenewalInfo(ctx context.Context, leaf []byte) (*RenewalInfo, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + parsedLeaf, err := x509.ParseCertificate(leaf) + if err != nil { + return nil, fmt.Errorf("parsing leaf certificate: %w", err) + } + + renewalURL := c.getRenewalURL(parsedLeaf) + + res, err := c.get(ctx, renewalURL, wantStatus(http.StatusOK)) + if err != nil { + return nil, fmt.Errorf("fetching renewal info: %w", err) + } + defer res.Body.Close() + + var info RenewalInfo + if err := json.NewDecoder(res.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("parsing renewal info response: %w", err) + } + return &info, nil +} + +func (c *Client) getRenewalURL(cert *x509.Certificate) string { + // See https://www.ietf.org/archive/id/draft-ietf-acme-ari-04.html#name-the-renewalinfo-resource + // for how the request URL is built. + url := c.dir.RenewalInfoURL + if !strings.HasSuffix(url, "/") { + url += "/" + } + return url + certRenewalIdentifier(cert) +} + +func certRenewalIdentifier(cert *x509.Certificate) string { + aki := base64.RawURLEncoding.EncodeToString(cert.AuthorityKeyId) + serial := base64.RawURLEncoding.EncodeToString(cert.SerialNumber.Bytes()) + return aki + "." + serial +} + +// AcceptTOS always returns true to indicate the acceptance of a CA's Terms of Service +// during account registration. See Register method of Client for more details. +func AcceptTOS(tosURL string) bool { return true } + +// Register creates a new account with the CA using c.Key. +// It returns the registered account. The account acct is not modified. +// +// The registration may require the caller to agree to the CA's Terms of Service (TOS). +// If so, and the account has not indicated the acceptance of the terms (see Account for details), +// Register calls prompt with a TOS URL provided by the CA. Prompt should report +// whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS. +// +// When interfacing with an RFC-compliant CA, non-RFC 8555 fields of acct are ignored +// and prompt is called if Directory's Terms field is non-zero. +// Also see Error's Instance field for when a CA requires already registered accounts to agree +// to an updated Terms of Service. +func (c *Client) Register(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) { + if c.Key == nil { + return nil, errors.New("acme: client.Key must be set to Register") + } + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.registerRFC(ctx, acct, prompt) +} + +// GetReg retrieves an existing account associated with c.Key. +// +// The url argument is a legacy artifact of the pre-RFC 8555 API +// and is ignored. +func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.getRegRFC(ctx) +} + +// UpdateReg updates an existing registration. +// It returns an updated account copy. The provided account is not modified. +// +// The account's URI is ignored and the account URL associated with +// c.Key is used instead. +func (c *Client) UpdateReg(ctx context.Context, acct *Account) (*Account, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + return c.updateRegRFC(ctx, acct) +} + +// AccountKeyRollover attempts to transition a client's account key to a new key. +// On success client's Key is updated which is not concurrency safe. +// On failure an error will be returned. +// The new key is already registered with the ACME provider if the following is true: +// - error is of type acme.Error +// - StatusCode should be 409 (Conflict) +// - Location header will have the KID of the associated account +// +// More about account key rollover can be found at +// https://tools.ietf.org/html/rfc8555#section-7.3.5. +func (c *Client) AccountKeyRollover(ctx context.Context, newKey crypto.Signer) error { + return c.accountKeyRollover(ctx, newKey) +} + +// Authorize performs the initial step in the pre-authorization flow, +// as opposed to order-based flow. +// The caller will then need to choose from and perform a set of returned +// challenges using c.Accept in order to successfully complete authorization. +// +// Once complete, the caller can use AuthorizeOrder which the CA +// should provision with the already satisfied authorization. +// For pre-RFC CAs, the caller can proceed directly to requesting a certificate +// using CreateCert method. +// +// If an authorization has been previously granted, the CA may return +// a valid authorization which has its Status field set to StatusValid. +// +// More about pre-authorization can be found at +// https://tools.ietf.org/html/rfc8555#section-7.4.1. +func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) { + return c.authorize(ctx, "dns", domain) +} + +// AuthorizeIP is the same as Authorize but requests IP address authorization. +// Clients which successfully obtain such authorization may request to issue +// a certificate for IP addresses. +// +// See the ACME spec extension for more details about IP address identifiers: +// https://tools.ietf.org/html/draft-ietf-acme-ip. +func (c *Client) AuthorizeIP(ctx context.Context, ipaddr string) (*Authorization, error) { + return c.authorize(ctx, "ip", ipaddr) +} + +func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + type authzID struct { + Type string `json:"type"` + Value string `json:"value"` + } + req := struct { + Resource string `json:"resource"` + Identifier authzID `json:"identifier"` + }{ + Resource: "new-authz", + Identifier: authzID{Type: typ, Value: val}, + } + res, err := c.post(ctx, nil, c.dir.AuthzURL, req, wantStatus(http.StatusCreated)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var v wireAuthz + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + if v.Status != StatusPending && v.Status != StatusValid { + return nil, fmt.Errorf("acme: unexpected status: %s", v.Status) + } + return v.authorization(res.Header.Get("Location")), nil +} + +// GetAuthorization retrieves an authorization identified by the given URL. +// +// If a caller needs to poll an authorization until its status is final, +// see the WaitAuthorization method. +func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + var v wireAuthz + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + return v.authorization(url), nil +} + +// RevokeAuthorization relinquishes an existing authorization identified +// by the given URL. +// The url argument is an Authorization.URI value. +// +// If successful, the caller will be required to obtain a new authorization +// using the Authorize or AuthorizeOrder methods before being able to request +// a new certificate for the domain associated with the authorization. +// +// It does not revoke existing certificates. +func (c *Client) RevokeAuthorization(ctx context.Context, url string) error { + if _, err := c.Discover(ctx); err != nil { + return err + } + + req := struct { + Resource string `json:"resource"` + Status string `json:"status"` + Delete bool `json:"delete"` + }{ + Resource: "authz", + Status: "deactivated", + Delete: true, + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return err + } + defer res.Body.Close() + return nil +} + +// WaitAuthorization polls an authorization at the given URL +// until it is in one of the final states, StatusValid or StatusInvalid, +// the ACME CA responded with a 4xx error code, or the context is done. +// +// It returns a non-nil Authorization only if its Status is StatusValid. +// In all other cases WaitAuthorization returns an error. +// If the Status is StatusInvalid, the returned error is of type *AuthorizationError. +func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + for { + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) + if err != nil { + return nil, err + } + + var raw wireAuthz + err = json.NewDecoder(res.Body).Decode(&raw) + res.Body.Close() + switch { + case err != nil: + // Skip and retry. + case raw.Status == StatusValid: + return raw.authorization(url), nil + case raw.Status == StatusInvalid: + return nil, raw.error(url) + } + + // Exponential backoff is implemented in c.get above. + // This is just to prevent continuously hitting the CA + // while waiting for a final authorization status. + d := retryAfter(res.Header.Get("Retry-After")) + if d == 0 { + // Given that the fastest challenges TLS-SNI and HTTP-01 + // require a CA to make at least 1 network round trip + // and most likely persist a challenge state, + // this default delay seems reasonable. + d = time.Second + } + t := time.NewTimer(d) + select { + case <-ctx.Done(): + t.Stop() + return nil, ctx.Err() + case <-t.C: + // Retry. + } + } +} + +// GetChallenge retrieves the current status of an challenge. +// +// A client typically polls a challenge status using this method. +func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted)) + if err != nil { + return nil, err + } + + defer res.Body.Close() + v := wireChallenge{URI: url} + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + return v.challenge(), nil +} + +// Accept informs the server that the client accepts one of its challenges +// previously obtained with c.Authorize. +// +// The server will then perform the validation asynchronously. +func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + payload := json.RawMessage("{}") + if len(chal.Payload) != 0 { + payload = chal.Payload + } + res, err := c.post(ctx, nil, chal.URI, payload, wantStatus( + http.StatusOK, // according to the spec + http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md) + )) + if err != nil { + return nil, err + } + defer res.Body.Close() + + var v wireChallenge + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid response: %v", err) + } + return v.challenge(), nil +} + +// DNS01ChallengeRecord returns a DNS record value for a dns-01 challenge response. +// A TXT record containing the returned value must be provisioned under +// "_acme-challenge" name of the domain being validated. +// +// The token argument is a Challenge.Token value. +func (c *Client) DNS01ChallengeRecord(token string) (string, error) { + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return "", err + } + b := sha256.Sum256([]byte(ka)) + return base64.RawURLEncoding.EncodeToString(b[:]), nil +} + +// HTTP01ChallengeResponse returns the response for an http-01 challenge. +// Servers should respond with the value to HTTP requests at the URL path +// provided by HTTP01ChallengePath to validate the challenge and prove control +// over a domain name. +// +// The token argument is a Challenge.Token value. +func (c *Client) HTTP01ChallengeResponse(token string) (string, error) { + return keyAuth(c.Key.Public(), token) +} + +// HTTP01ChallengePath returns the URL path at which the response for an http-01 challenge +// should be provided by the servers. +// The response value can be obtained with HTTP01ChallengeResponse. +// +// The token argument is a Challenge.Token value. +func (c *Client) HTTP01ChallengePath(token string) string { + return "/.well-known/acme-challenge/" + token +} + +// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response. +// +// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec. +func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return tls.Certificate{}, "", err + } + b := sha256.Sum256([]byte(ka)) + h := hex.EncodeToString(b[:]) + name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:]) + cert, err = tlsChallengeCert([]string{name}, opt) + if err != nil { + return tls.Certificate{}, "", err + } + return cert, name, nil +} + +// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response. +// +// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec. +func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) { + b := sha256.Sum256([]byte(token)) + h := hex.EncodeToString(b[:]) + sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:]) + + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return tls.Certificate{}, "", err + } + b = sha256.Sum256([]byte(ka)) + h = hex.EncodeToString(b[:]) + sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:]) + + cert, err = tlsChallengeCert([]string{sanA, sanB}, opt) + if err != nil { + return tls.Certificate{}, "", err + } + return cert, sanA, nil +} + +// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response. +// Servers can present the certificate to validate the challenge and prove control +// over a domain name. For more details on TLS-ALPN-01 see +// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3 +// +// The token argument is a Challenge.Token value. +// If a WithKey option is provided, its private part signs the returned cert, +// and the public part is used to specify the signee. +// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve. +// +// The returned certificate is valid for the next 24 hours and must be presented only when +// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol +// has been specified. +func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) { + ka, err := keyAuth(c.Key.Public(), token) + if err != nil { + return tls.Certificate{}, err + } + shasum := sha256.Sum256([]byte(ka)) + extValue, err := asn1.Marshal(shasum[:]) + if err != nil { + return tls.Certificate{}, err + } + acmeExtension := pkix.Extension{ + Id: idPeACMEIdentifier, + Critical: true, + Value: extValue, + } + + tmpl := defaultTLSChallengeCertTemplate() + + var newOpt []CertOption + for _, o := range opt { + switch o := o.(type) { + case *certOptTemplate: + t := *(*x509.Certificate)(o) // shallow copy is ok + tmpl = &t + default: + newOpt = append(newOpt, o) + } + } + tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension) + newOpt = append(newOpt, WithTemplate(tmpl)) + return tlsChallengeCert([]string{domain}, newOpt) +} + +// popNonce returns a nonce value previously stored with c.addNonce +// or fetches a fresh one from c.dir.NonceURL. +// If NonceURL is empty, it first tries c.directoryURL() and, failing that, +// the provided url. +func (c *Client) popNonce(ctx context.Context, url string) (string, error) { + c.noncesMu.Lock() + defer c.noncesMu.Unlock() + if len(c.nonces) == 0 { + if c.dir != nil && c.dir.NonceURL != "" { + return c.fetchNonce(ctx, c.dir.NonceURL) + } + dirURL := c.directoryURL() + v, err := c.fetchNonce(ctx, dirURL) + if err != nil && url != dirURL { + v, err = c.fetchNonce(ctx, url) + } + return v, err + } + var nonce string + for nonce = range c.nonces { + delete(c.nonces, nonce) + break + } + return nonce, nil +} + +// clearNonces clears any stored nonces +func (c *Client) clearNonces() { + c.noncesMu.Lock() + defer c.noncesMu.Unlock() + c.nonces = make(map[string]struct{}) +} + +// addNonce stores a nonce value found in h (if any) for future use. +func (c *Client) addNonce(h http.Header) { + v := nonceFromHeader(h) + if v == "" { + return + } + c.noncesMu.Lock() + defer c.noncesMu.Unlock() + if len(c.nonces) >= maxNonces { + return + } + if c.nonces == nil { + c.nonces = make(map[string]struct{}) + } + c.nonces[v] = struct{}{} +} + +func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) { + r, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return "", err + } + resp, err := c.doNoRetry(ctx, r) + if err != nil { + return "", err + } + defer resp.Body.Close() + nonce := nonceFromHeader(resp.Header) + if nonce == "" { + if resp.StatusCode > 299 { + return "", responseError(resp) + } + return "", errors.New("acme: nonce not found") + } + return nonce, nil +} + +func nonceFromHeader(h http.Header) string { + return h.Get("Replay-Nonce") +} + +// linkHeader returns URI-Reference values of all Link headers +// with relation-type rel. +// See https://tools.ietf.org/html/rfc5988#section-5 for details. +func linkHeader(h http.Header, rel string) []string { + var links []string + for _, v := range h["Link"] { + parts := strings.Split(v, ";") + for _, p := range parts { + p = strings.TrimSpace(p) + if !strings.HasPrefix(p, "rel=") { + continue + } + if v := strings.Trim(p[4:], `"`); v == rel { + links = append(links, strings.Trim(parts[0], "<>")) + } + } + } + return links +} + +// keyAuth generates a key authorization string for a given token. +func keyAuth(pub crypto.PublicKey, token string) (string, error) { + th, err := JWKThumbprint(pub) + if err != nil { + return "", err + } + return fmt.Sprintf("%s.%s", token, th), nil +} + +// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges. +func defaultTLSChallengeCertTemplate() *x509.Certificate { + return &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } +} + +// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges +// with the given SANs and auto-generated public/private key pair. +// The Subject Common Name is set to the first SAN to aid debugging. +// To create a cert with a custom key pair, specify WithKey option. +func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { + var key crypto.Signer + tmpl := defaultTLSChallengeCertTemplate() + for _, o := range opt { + switch o := o.(type) { + case *certOptKey: + if key != nil { + return tls.Certificate{}, errors.New("acme: duplicate key option") + } + key = o.key + case *certOptTemplate: + t := *(*x509.Certificate)(o) // shallow copy is ok + tmpl = &t + default: + // package's fault, if we let this happen: + panic(fmt.Sprintf("unsupported option type %T", o)) + } + } + if key == nil { + var err error + if key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil { + return tls.Certificate{}, err + } + } + tmpl.DNSNames = san + if len(san) > 0 { + tmpl.Subject.CommonName = san[0] + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) + if err != nil { + return tls.Certificate{}, err + } + return tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: key, + }, nil +} + +// encodePEM returns b encoded as PEM with block of type typ. +func encodePEM(typ string, b []byte) []byte { + pb := &pem.Block{Type: typ, Bytes: b} + return pem.EncodeToMemory(pb) +} + +// timeNow is time.Now, except in tests which can mess with it. +var timeNow = time.Now diff --git a/tempfork/acme/acme_test.go b/tempfork/acme/acme_test.go new file mode 100644 index 000000000..f0c45aea9 --- /dev/null +++ b/tempfork/acme/acme_test.go @@ -0,0 +1,970 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "sort" + "strings" + "testing" + "time" +) + +// newTestClient creates a client with a non-nil Directory so that it skips +// the discovery which is otherwise done on the first call of almost every +// exported method. +func newTestClient() *Client { + return &Client{ + Key: testKeyEC, + dir: &Directory{}, // skip discovery + } +} + +// newTestClientWithMockDirectory creates a client with a non-nil Directory +// that contains mock field values. +func newTestClientWithMockDirectory() *Client { + return &Client{ + Key: testKeyEC, + dir: &Directory{ + RenewalInfoURL: "https://example.com/acme/renewal-info/", + }, + } +} + +// Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided +// interface. +func decodeJWSRequest(t *testing.T, v interface{}, r io.Reader) { + // Decode request + var req struct{ Payload string } + if err := json.NewDecoder(r).Decode(&req); err != nil { + t.Fatal(err) + } + payload, err := base64.RawURLEncoding.DecodeString(req.Payload) + if err != nil { + t.Fatal(err) + } + err = json.Unmarshal(payload, v) + if err != nil { + t.Fatal(err) + } +} + +type jwsHead struct { + Alg string + Nonce string + URL string `json:"url"` + KID string `json:"kid"` + JWK map[string]string `json:"jwk"` +} + +func decodeJWSHead(r io.Reader) (*jwsHead, error) { + var req struct{ Protected string } + if err := json.NewDecoder(r).Decode(&req); err != nil { + return nil, err + } + b, err := base64.RawURLEncoding.DecodeString(req.Protected) + if err != nil { + return nil, err + } + var head jwsHead + if err := json.Unmarshal(b, &head); err != nil { + return nil, err + } + return &head, nil +} + +func TestRegisterWithoutKey(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "test-nonce") + return + } + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, `{}`) + })) + defer ts.Close() + // First verify that using a complete client results in success. + c := Client{ + Key: testKeyEC, + DirectoryURL: ts.URL, + dir: &Directory{RegURL: ts.URL}, + } + if _, err := c.Register(context.Background(), &Account{}, AcceptTOS); err != nil { + t.Fatalf("c.Register() = %v; want success with a complete test client", err) + } + c.Key = nil + if _, err := c.Register(context.Background(), &Account{}, AcceptTOS); err == nil { + t.Error("c.Register() from client without key succeeded, wanted error") + } +} + +func TestAuthorize(t *testing.T) { + tt := []struct{ typ, value string }{ + {"dns", "example.com"}, + {"ip", "1.2.3.4"}, + } + for _, test := range tt { + t.Run(test.typ, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "test-nonce") + return + } + if r.Method != "POST" { + t.Errorf("r.Method = %q; want POST", r.Method) + } + + var j struct { + Resource string + Identifier struct { + Type string + Value string + } + } + decodeJWSRequest(t, &j, r.Body) + + // Test request + if j.Resource != "new-authz" { + t.Errorf("j.Resource = %q; want new-authz", j.Resource) + } + if j.Identifier.Type != test.typ { + t.Errorf("j.Identifier.Type = %q; want %q", j.Identifier.Type, test.typ) + } + if j.Identifier.Value != test.value { + t.Errorf("j.Identifier.Value = %q; want %q", j.Identifier.Value, test.value) + } + + w.Header().Set("Location", "https://ca.tld/acme/auth/1") + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{ + "identifier": {"type":%q,"value":%q}, + "status":"pending", + "challenges":[ + { + "type":"http-01", + "status":"pending", + "uri":"https://ca.tld/acme/challenge/publickey/id1", + "token":"token1" + }, + { + "type":"tls-sni-01", + "status":"pending", + "uri":"https://ca.tld/acme/challenge/publickey/id2", + "token":"token2" + } + ], + "combinations":[[0],[1]] + }`, test.typ, test.value) + })) + defer ts.Close() + + var ( + auth *Authorization + err error + ) + cl := Client{ + Key: testKeyEC, + DirectoryURL: ts.URL, + dir: &Directory{AuthzURL: ts.URL}, + } + switch test.typ { + case "dns": + auth, err = cl.Authorize(context.Background(), test.value) + case "ip": + auth, err = cl.AuthorizeIP(context.Background(), test.value) + default: + t.Fatalf("unknown identifier type: %q", test.typ) + } + if err != nil { + t.Fatal(err) + } + + if auth.URI != "https://ca.tld/acme/auth/1" { + t.Errorf("URI = %q; want https://ca.tld/acme/auth/1", auth.URI) + } + if auth.Status != "pending" { + t.Errorf("Status = %q; want pending", auth.Status) + } + if auth.Identifier.Type != test.typ { + t.Errorf("Identifier.Type = %q; want %q", auth.Identifier.Type, test.typ) + } + if auth.Identifier.Value != test.value { + t.Errorf("Identifier.Value = %q; want %q", auth.Identifier.Value, test.value) + } + + if n := len(auth.Challenges); n != 2 { + t.Fatalf("len(auth.Challenges) = %d; want 2", n) + } + + c := auth.Challenges[0] + if c.Type != "http-01" { + t.Errorf("c.Type = %q; want http-01", c.Type) + } + if c.URI != "https://ca.tld/acme/challenge/publickey/id1" { + t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id1", c.URI) + } + if c.Token != "token1" { + t.Errorf("c.Token = %q; want token1", c.Token) + } + + c = auth.Challenges[1] + if c.Type != "tls-sni-01" { + t.Errorf("c.Type = %q; want tls-sni-01", c.Type) + } + if c.URI != "https://ca.tld/acme/challenge/publickey/id2" { + t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id2", c.URI) + } + if c.Token != "token2" { + t.Errorf("c.Token = %q; want token2", c.Token) + } + + combs := [][]int{{0}, {1}} + if !reflect.DeepEqual(auth.Combinations, combs) { + t.Errorf("auth.Combinations: %+v\nwant: %+v\n", auth.Combinations, combs) + } + + }) + } +} + +func TestAuthorizeValid(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "nonce") + return + } + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"status":"valid"}`)) + })) + defer ts.Close() + client := Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{AuthzURL: ts.URL}, + } + _, err := client.Authorize(context.Background(), "example.com") + if err != nil { + t.Errorf("err = %v", err) + } +} + +func TestWaitAuthorization(t *testing.T) { + t.Run("wait loop", func(t *testing.T) { + var count int + authz, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Retry-After", "0") + if count > 1 { + fmt.Fprintf(w, `{"status":"valid"}`) + return + } + fmt.Fprintf(w, `{"status":"pending"}`) + }) + if err != nil { + t.Fatalf("non-nil error: %v", err) + } + if authz == nil { + t.Fatal("authz is nil") + } + }) + t.Run("invalid status", func(t *testing.T) { + _, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"status":"invalid"}`) + }) + if _, ok := err.(*AuthorizationError); !ok { + t.Errorf("err is %v (%T); want non-nil *AuthorizationError", err, err) + } + }) + t.Run("invalid status with error returns the authorization error", func(t *testing.T) { + _, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{ + "type": "dns-01", + "status": "invalid", + "error": { + "type": "urn:ietf:params:acme:error:caa", + "detail": "CAA record for prevents issuance", + "status": 403 + }, + "url": "https://acme-v02.api.letsencrypt.org/acme/chall-v3/xxx/xxx", + "token": "xxx", + "validationRecord": [ + { + "hostname": "" + } + ] + }`) + }) + + want := &AuthorizationError{ + Errors: []error{ + (&wireError{ + Status: 403, + Type: "urn:ietf:params:acme:error:caa", + Detail: "CAA record for prevents issuance", + }).error(nil), + }, + } + + _, ok := err.(*AuthorizationError) + if !ok { + t.Errorf("err is %T; want non-nil *AuthorizationError", err) + } + + if err.Error() != want.Error() { + t.Errorf("err is %v; want %v", err, want) + } + }) + t.Run("non-retriable error", func(t *testing.T) { + const code = http.StatusBadRequest + _, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + }) + res, ok := err.(*Error) + if !ok { + t.Fatalf("err is %v (%T); want a non-nil *Error", err, err) + } + if res.StatusCode != code { + t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, code) + } + }) + for _, code := range []int{http.StatusTooManyRequests, http.StatusInternalServerError} { + t.Run(fmt.Sprintf("retriable %d error", code), func(t *testing.T) { + var count int + authz, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Retry-After", "0") + if count > 1 { + fmt.Fprintf(w, `{"status":"valid"}`) + return + } + w.WriteHeader(code) + }) + if err != nil { + t.Fatalf("non-nil error: %v", err) + } + if authz == nil { + t.Fatal("authz is nil") + } + }) + } + t.Run("context cancel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := runWaitAuthorization(ctx, t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "60") + fmt.Fprintf(w, `{"status":"pending"}`) + time.AfterFunc(1*time.Millisecond, cancel) + }) + if err == nil { + t.Error("err is nil") + } + }) +} + +func runWaitAuthorization(ctx context.Context, t *testing.T, h http.HandlerFunc) (*Authorization, error) { + t.Helper() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", fmt.Sprintf("bad-test-nonce-%v", time.Now().UnixNano())) + h(w, r) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{}, + KID: "some-key-id", // set to avoid lookup attempt + } + return client.WaitAuthorization(ctx, ts.URL) +} + +func TestRevokeAuthorization(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "nonce") + return + } + switch r.URL.Path { + case "/1": + var req struct { + Resource string + Status string + Delete bool + } + decodeJWSRequest(t, &req, r.Body) + if req.Resource != "authz" { + t.Errorf("req.Resource = %q; want authz", req.Resource) + } + if req.Status != "deactivated" { + t.Errorf("req.Status = %q; want deactivated", req.Status) + } + if !req.Delete { + t.Errorf("req.Delete is false") + } + case "/2": + w.WriteHeader(http.StatusBadRequest) + } + })) + defer ts.Close() + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, // don't dial outside of localhost + dir: &Directory{}, // don't do discovery + } + ctx := context.Background() + if err := client.RevokeAuthorization(ctx, ts.URL+"/1"); err != nil { + t.Errorf("err = %v", err) + } + if client.RevokeAuthorization(ctx, ts.URL+"/2") == nil { + t.Error("nil error") + } +} + +func TestFetchCertCancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + var err error + go func() { + cl := newTestClient() + _, err = cl.FetchCert(ctx, ts.URL, false) + close(done) + }() + cancel() + <-done + if err != context.Canceled { + t.Errorf("err = %v; want %v", err, context.Canceled) + } +} + +func TestFetchCertDepth(t *testing.T) { + var count byte + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + if count > maxChainLen+1 { + t.Errorf("count = %d; want at most %d", count, maxChainLen+1) + w.WriteHeader(http.StatusInternalServerError) + } + w.Header().Set("Link", fmt.Sprintf("<%s>;rel=up", ts.URL)) + w.Write([]byte{count}) + })) + defer ts.Close() + cl := newTestClient() + _, err := cl.FetchCert(context.Background(), ts.URL, true) + if err == nil { + t.Errorf("err is nil") + } +} + +func TestFetchCertBreadth(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < maxChainLen+1; i++ { + w.Header().Add("Link", fmt.Sprintf("<%s>;rel=up", ts.URL)) + } + w.Write([]byte{1}) + })) + defer ts.Close() + cl := newTestClient() + _, err := cl.FetchCert(context.Background(), ts.URL, true) + if err == nil { + t.Errorf("err is nil") + } +} + +func TestFetchCertSize(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := bytes.Repeat([]byte{1}, maxCertSize+1) + w.Write(b) + })) + defer ts.Close() + cl := newTestClient() + _, err := cl.FetchCert(context.Background(), ts.URL, false) + if err == nil { + t.Errorf("err is nil") + } +} + +const ( + leafPEM = `-----BEGIN CERTIFICATE----- +MIIEizCCAvOgAwIBAgIRAITApw7R8HSs7GU7cj8dEyUwDQYJKoZIhvcNAQELBQAw +gYUxHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTEtMCsGA1UECwwkY3Bh +bG1lckBwdW1wa2luLmxvY2FsIChDaHJpcyBQYWxtZXIpMTQwMgYDVQQDDCtta2Nl +cnQgY3BhbG1lckBwdW1wa2luLmxvY2FsIChDaHJpcyBQYWxtZXIpMB4XDTIzMDcx +MjE4MjIxNloXDTI1MTAxMjE4MjIxNlowWDEnMCUGA1UEChMebWtjZXJ0IGRldmVs +b3BtZW50IGNlcnRpZmljYXRlMS0wKwYDVQQLDCRjcGFsbWVyQHB1bXBraW4ubG9j +YWwgKENocmlzIFBhbG1lcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQDNDO8P4MI9jaqVcPtF8C4GgHnTP5EK3U9fgyGApKGxTpicMQkA6z4GXwUP/Fvq +7RuCU9Wg7By5VetKIHF7FxkxWkUMrssr7mV8v6mRCh/a5GqDs14aj5ucjLQAJV74 +tLAdrCiijQ1fkPWc82fob+LkfKWGCWw7Cxf6ZtEyC8jz/DnfQXUvOiZS729ndGF7 +FobKRfIoirD+GI2NTYIp3LAUFSPR6HXTe7HAg8J81VoUKli8z504+FebfMmHePm/ +zIfiI0njAj4czOlZD56/oLsV0WRUizFjafHHUFz1HVdfFw8Qf9IOOTydYOe8M5i0 +lVbVO5G+HP+JDn3cr9MT41B9AgMBAAGjgaEwgZ4wDgYDVR0PAQH/BAQDAgWgMBMG +A1UdJQQMMAoGCCsGAQUFBwMBMB8GA1UdIwQYMBaAFPpL4Q0O7Z7voTkjn2rrFCsf +s8TbMFYGA1UdEQRPME2CC2V4YW1wbGUuY29tgg0qLmV4YW1wbGUuY29tggxleGFt +cGxlLnRlc3SCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkq +hkiG9w0BAQsFAAOCAYEAMlOb7lrHuSxwcnAu7mL1ysTGqKn1d2TyDJAN5W8YFY+4 +XLpofNkK2UzZ0t9LQRnuFUcjmfqmfplh5lpC7pKmtL4G5Qcdc+BczQWcopbxd728 +sht9BKRkH+Bo1I+1WayKKNXW+5bsMv4CH641zxaMBlzjEnPvwKkNaGLMH3x5lIeX +GGgkKNXwVtINmyV+lTNVtu2IlHprxJGCjRfEuX7mEv6uRnqz3Wif+vgyh3MBgM/1 +dUOsTBNH4a6Jl/9VPSOfRdQOStqIlwTa/J1bhTvivsYt1+eWjLnsQJLgZQqwKvYH +BJ30gAk1oNnuSkx9dHbx4mO+4mB9oIYUALXUYakb8JHTOnuMSj9qelVj5vjVxl9q +KRitptU+kLYRA4HSgUXrhDIm4Q6D/w8/ascPqQ3HxPIDFLe+gTofEjqnnsnQB29L +gWpI8l5/MtXAOMdW69eEovnADc2pgaiif0T+v9nNKBc5xfDZHnrnqIqVzQEwL5Qv +niQI8IsWD5LcQ1Eg7kCq +-----END CERTIFICATE-----` +) + +func TestGetRenewalURL(t *testing.T) { + leaf, _ := pem.Decode([]byte(leafPEM)) + + parsedLeaf, err := x509.ParseCertificate(leaf.Bytes) + if err != nil { + t.Fatal(err) + } + + client := newTestClientWithMockDirectory() + urlString := client.getRenewalURL(parsedLeaf) + + parsedURL, err := url.Parse(urlString) + if err != nil { + t.Fatal(err) + } + if scheme := parsedURL.Scheme; scheme == "" { + t.Fatalf("malformed URL scheme: %q from %q", scheme, urlString) + } + if host := parsedURL.Host; host == "" { + t.Fatalf("malformed URL host: %q from %q", host, urlString) + } + if parsedURL.RawQuery != "" { + t.Fatalf("malformed URL: should not have a query") + } + path := parsedURL.EscapedPath() + slash := strings.LastIndex(path, "/") + if slash == -1 { + t.Fatalf("malformed URL path: %q from %q", path, urlString) + } + certID := path[slash+1:] + if certID == "" { + t.Fatalf("missing certificate identifier in URL path: %q from %q", path, urlString) + } + certIDParts := strings.Split(certID, ".") + if len(certIDParts) != 2 { + t.Fatalf("certificate identifier should consist of 2 base64-encoded values separated by a dot: %q from %q", certID, urlString) + } + if _, err := base64.RawURLEncoding.DecodeString(certIDParts[0]); err != nil { + t.Fatalf("malformed AKI part in certificate identifier: %q from %q: %v", certIDParts[0], urlString, err) + } + if _, err := base64.RawURLEncoding.DecodeString(certIDParts[1]); err != nil { + t.Fatalf("malformed Serial part in certificate identifier: %q from %q: %v", certIDParts[1], urlString, err) + } + +} + +func TestUnmarshalRenewalInfo(t *testing.T) { + renewalInfoJSON := `{ + "suggestedWindow": { + "start": "2021-01-03T00:00:00Z", + "end": "2021-01-07T00:00:00Z" + }, + "explanationURL": "https://example.com/docs/example-mass-reissuance-event" + }` + expectedStart := time.Date(2021, time.January, 3, 0, 0, 0, 0, time.UTC) + expectedEnd := time.Date(2021, time.January, 7, 0, 0, 0, 0, time.UTC) + + var info RenewalInfo + if err := json.Unmarshal([]byte(renewalInfoJSON), &info); err != nil { + t.Fatal(err) + } + if _, err := url.Parse(info.ExplanationURL); err != nil { + t.Fatal(err) + } + if !info.SuggestedWindow.Start.Equal(expectedStart) { + t.Fatalf("%v != %v", expectedStart, info.SuggestedWindow.Start) + } + if !info.SuggestedWindow.End.Equal(expectedEnd) { + t.Fatalf("%v != %v", expectedEnd, info.SuggestedWindow.End) + } +} + +func TestNonce_add(t *testing.T) { + var c Client + c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) + c.addNonce(http.Header{"Replay-Nonce": {}}) + c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) + + nonces := map[string]struct{}{"nonce": {}} + if !reflect.DeepEqual(c.nonces, nonces) { + t.Errorf("c.nonces = %q; want %q", c.nonces, nonces) + } +} + +func TestNonce_addMax(t *testing.T) { + c := &Client{nonces: make(map[string]struct{})} + for i := 0; i < maxNonces; i++ { + c.nonces[fmt.Sprintf("%d", i)] = struct{}{} + } + c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) + if n := len(c.nonces); n != maxNonces { + t.Errorf("len(c.nonces) = %d; want %d", n, maxNonces) + } +} + +func TestNonce_fetch(t *testing.T) { + tests := []struct { + code int + nonce string + }{ + {http.StatusOK, "nonce1"}, + {http.StatusBadRequest, "nonce2"}, + {http.StatusOK, ""}, + } + var i int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + t.Errorf("%d: r.Method = %q; want HEAD", i, r.Method) + } + w.Header().Set("Replay-Nonce", tests[i].nonce) + w.WriteHeader(tests[i].code) + })) + defer ts.Close() + for ; i < len(tests); i++ { + test := tests[i] + c := newTestClient() + n, err := c.fetchNonce(context.Background(), ts.URL) + if n != test.nonce { + t.Errorf("%d: n=%q; want %q", i, n, test.nonce) + } + switch { + case err == nil && test.nonce == "": + t.Errorf("%d: n=%q, err=%v; want non-nil error", i, n, err) + case err != nil && test.nonce != "": + t.Errorf("%d: n=%q, err=%v; want %q", i, n, err, test.nonce) + } + } +} + +func TestNonce_fetchError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer ts.Close() + c := newTestClient() + _, err := c.fetchNonce(context.Background(), ts.URL) + e, ok := err.(*Error) + if !ok { + t.Fatalf("err is %T; want *Error", err) + } + if e.StatusCode != http.StatusTooManyRequests { + t.Errorf("e.StatusCode = %d; want %d", e.StatusCode, http.StatusTooManyRequests) + } +} + +func TestNonce_popWhenEmpty(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + t.Errorf("r.Method = %q; want HEAD", r.Method) + } + switch r.URL.Path { + case "/dir-with-nonce": + w.Header().Set("Replay-Nonce", "dirnonce") + case "/new-nonce": + w.Header().Set("Replay-Nonce", "newnonce") + case "/dir-no-nonce", "/empty": + // No nonce in the header. + default: + t.Errorf("Unknown URL: %s", r.URL) + } + })) + defer ts.Close() + ctx := context.Background() + + tt := []struct { + dirURL, popURL, nonce string + wantOK bool + }{ + {ts.URL + "/dir-with-nonce", ts.URL + "/new-nonce", "dirnonce", true}, + {ts.URL + "/dir-no-nonce", ts.URL + "/new-nonce", "newnonce", true}, + {ts.URL + "/dir-no-nonce", ts.URL + "/empty", "", false}, + } + for _, test := range tt { + t.Run(fmt.Sprintf("nonce:%s wantOK:%v", test.nonce, test.wantOK), func(t *testing.T) { + c := Client{DirectoryURL: test.dirURL} + v, err := c.popNonce(ctx, test.popURL) + if !test.wantOK { + if err == nil { + t.Fatalf("c.popNonce(%q) returned nil error", test.popURL) + } + return + } + if err != nil { + t.Fatalf("c.popNonce(%q): %v", test.popURL, err) + } + if v != test.nonce { + t.Errorf("c.popNonce(%q) = %q; want %q", test.popURL, v, test.nonce) + } + }) + } +} + +func TestLinkHeader(t *testing.T) { + h := http.Header{"Link": { + `;rel="next"`, + `; rel=recover`, + `; foo=bar; rel="terms-of-service"`, + `;rel="next"`, + }} + tests := []struct { + rel string + out []string + }{ + {"next", []string{"https://example.com/acme/new-authz", "dup"}}, + {"recover", []string{"https://example.com/acme/recover-reg"}}, + {"terms-of-service", []string{"https://example.com/acme/terms"}}, + {"empty", nil}, + } + for i, test := range tests { + if v := linkHeader(h, test.rel); !reflect.DeepEqual(v, test.out) { + t.Errorf("%d: linkHeader(%q): %v; want %v", i, test.rel, v, test.out) + } + } +} + +func TestTLSSNI01ChallengeCert(t *testing.T) { + const ( + token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA" + // echo -n | shasum -a 256 + san = "dbbd5eefe7b4d06eb9d1d9f5acb4c7cd.a27d320e4b30332f0b6cb441734ad7b0.acme.invalid" + ) + + tlscert, name, err := newTestClient().TLSSNI01ChallengeCert(token) + if err != nil { + t.Fatal(err) + } + + if n := len(tlscert.Certificate); n != 1 { + t.Fatalf("len(tlscert.Certificate) = %d; want 1", n) + } + cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatal(err) + } + if len(cert.DNSNames) != 1 || cert.DNSNames[0] != san { + t.Fatalf("cert.DNSNames = %v; want %q", cert.DNSNames, san) + } + if cert.DNSNames[0] != name { + t.Errorf("cert.DNSNames[0] != name: %q vs %q", cert.DNSNames[0], name) + } + if cn := cert.Subject.CommonName; cn != san { + t.Errorf("cert.Subject.CommonName = %q; want %q", cn, san) + } +} + +func TestTLSSNI02ChallengeCert(t *testing.T) { + const ( + token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA" + // echo -n evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA | shasum -a 256 + sanA = "7ea0aaa69214e71e02cebb18bb867736.09b730209baabf60e43d4999979ff139.token.acme.invalid" + // echo -n | shasum -a 256 + sanB = "dbbd5eefe7b4d06eb9d1d9f5acb4c7cd.a27d320e4b30332f0b6cb441734ad7b0.ka.acme.invalid" + ) + + tlscert, name, err := newTestClient().TLSSNI02ChallengeCert(token) + if err != nil { + t.Fatal(err) + } + + if n := len(tlscert.Certificate); n != 1 { + t.Fatalf("len(tlscert.Certificate) = %d; want 1", n) + } + cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatal(err) + } + names := []string{sanA, sanB} + if !reflect.DeepEqual(cert.DNSNames, names) { + t.Fatalf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names) + } + sort.Strings(cert.DNSNames) + i := sort.SearchStrings(cert.DNSNames, name) + if i >= len(cert.DNSNames) || cert.DNSNames[i] != name { + t.Errorf("%v doesn't have %q", cert.DNSNames, name) + } + if cn := cert.Subject.CommonName; cn != sanA { + t.Errorf("CommonName = %q; want %q", cn, sanA) + } +} + +func TestTLSALPN01ChallengeCert(t *testing.T) { + const ( + token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA" + keyAuth = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA." + testKeyECThumbprint + // echo -n | shasum -a 256 + h = "0420dbbd5eefe7b4d06eb9d1d9f5acb4c7cda27d320e4b30332f0b6cb441734ad7b0" + domain = "example.com" + ) + + extValue, err := hex.DecodeString(h) + if err != nil { + t.Fatal(err) + } + + tlscert, err := newTestClient().TLSALPN01ChallengeCert(token, domain) + if err != nil { + t.Fatal(err) + } + + if n := len(tlscert.Certificate); n != 1 { + t.Fatalf("len(tlscert.Certificate) = %d; want 1", n) + } + cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatal(err) + } + names := []string{domain} + if !reflect.DeepEqual(cert.DNSNames, names) { + t.Fatalf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names) + } + if cn := cert.Subject.CommonName; cn != domain { + t.Errorf("CommonName = %q; want %q", cn, domain) + } + acmeExts := []pkix.Extension{} + for _, ext := range cert.Extensions { + if idPeACMEIdentifier.Equal(ext.Id) { + acmeExts = append(acmeExts, ext) + } + } + if len(acmeExts) != 1 { + t.Errorf("acmeExts = %v; want exactly one", acmeExts) + } + if !acmeExts[0].Critical { + t.Errorf("acmeExt.Critical = %v; want true", acmeExts[0].Critical) + } + if bytes.Compare(acmeExts[0].Value, extValue) != 0 { + t.Errorf("acmeExt.Value = %v; want %v", acmeExts[0].Value, extValue) + } + +} + +func TestTLSChallengeCertOpt(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{Organization: []string{"Test"}}, + DNSNames: []string{"should-be-overwritten"}, + } + opts := []CertOption{WithKey(key), WithTemplate(tmpl)} + + client := newTestClient() + cert1, _, err := client.TLSSNI01ChallengeCert("token", opts...) + if err != nil { + t.Fatal(err) + } + cert2, _, err := client.TLSSNI02ChallengeCert("token", opts...) + if err != nil { + t.Fatal(err) + } + + for i, tlscert := range []tls.Certificate{cert1, cert2} { + // verify generated cert private key + tlskey, ok := tlscert.PrivateKey.(*rsa.PrivateKey) + if !ok { + t.Errorf("%d: tlscert.PrivateKey is %T; want *rsa.PrivateKey", i, tlscert.PrivateKey) + continue + } + if tlskey.D.Cmp(key.D) != 0 { + t.Errorf("%d: tlskey.D = %v; want %v", i, tlskey.D, key.D) + } + // verify generated cert public key + x509Cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + tlspub, ok := x509Cert.PublicKey.(*rsa.PublicKey) + if !ok { + t.Errorf("%d: x509Cert.PublicKey is %T; want *rsa.PublicKey", i, x509Cert.PublicKey) + continue + } + if tlspub.N.Cmp(key.N) != 0 { + t.Errorf("%d: tlspub.N = %v; want %v", i, tlspub.N, key.N) + } + // verify template option + sn := big.NewInt(2) + if x509Cert.SerialNumber.Cmp(sn) != 0 { + t.Errorf("%d: SerialNumber = %v; want %v", i, x509Cert.SerialNumber, sn) + } + org := []string{"Test"} + if !reflect.DeepEqual(x509Cert.Subject.Organization, org) { + t.Errorf("%d: Subject.Organization = %+v; want %+v", i, x509Cert.Subject.Organization, org) + } + for _, v := range x509Cert.DNSNames { + if !strings.HasSuffix(v, ".acme.invalid") { + t.Errorf("%d: invalid DNSNames element: %q", i, v) + } + } + } +} + +func TestHTTP01Challenge(t *testing.T) { + const ( + token = "xxx" + // thumbprint is precomputed for testKeyEC in jws_test.go + value = token + "." + testKeyECThumbprint + urlpath = "/.well-known/acme-challenge/" + token + ) + client := newTestClient() + val, err := client.HTTP01ChallengeResponse(token) + if err != nil { + t.Fatal(err) + } + if val != value { + t.Errorf("val = %q; want %q", val, value) + } + if path := client.HTTP01ChallengePath(token); path != urlpath { + t.Errorf("path = %q; want %q", path, urlpath) + } +} + +func TestDNS01ChallengeRecord(t *testing.T) { + // echo -n xxx. | \ + // openssl dgst -binary -sha256 | \ + // base64 | tr -d '=' | tr '/+' '_-' + const value = "8DERMexQ5VcdJ_prpPiA0mVdp7imgbCgjsG4SqqNMIo" + + val, err := newTestClient().DNS01ChallengeRecord("xxx") + if err != nil { + t.Fatal(err) + } + if val != value { + t.Errorf("val = %q; want %q", val, value) + } +} diff --git a/tempfork/acme/http.go b/tempfork/acme/http.go new file mode 100644 index 000000000..d92ff232f --- /dev/null +++ b/tempfork/acme/http.go @@ -0,0 +1,344 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "bytes" + "context" + "crypto" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "runtime/debug" + "strconv" + "strings" + "time" +) + +// retryTimer encapsulates common logic for retrying unsuccessful requests. +// It is not safe for concurrent use. +type retryTimer struct { + // backoffFn provides backoff delay sequence for retries. + // See Client.RetryBackoff doc comment. + backoffFn func(n int, r *http.Request, res *http.Response) time.Duration + // n is the current retry attempt. + n int +} + +func (t *retryTimer) inc() { + t.n++ +} + +// backoff pauses the current goroutine as described in Client.RetryBackoff. +func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error { + d := t.backoffFn(t.n, r, res) + if d <= 0 { + return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n) + } + wakeup := time.NewTimer(d) + defer wakeup.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-wakeup.C: + return nil + } +} + +func (c *Client) retryTimer() *retryTimer { + f := c.RetryBackoff + if f == nil { + f = defaultBackoff + } + return &retryTimer{backoffFn: f} +} + +// defaultBackoff provides default Client.RetryBackoff implementation +// using a truncated exponential backoff algorithm, +// as described in Client.RetryBackoff. +// +// The n argument is always bounded between 1 and 30. +// The returned value is always greater than 0. +func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration { + const max = 10 * time.Second + var jitter time.Duration + if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil { + // Set the minimum to 1ms to avoid a case where + // an invalid Retry-After value is parsed into 0 below, + // resulting in the 0 returned value which would unintentionally + // stop the retries. + jitter = (1 + time.Duration(x.Int64())) * time.Millisecond + } + if v, ok := res.Header["Retry-After"]; ok { + return retryAfter(v[0]) + jitter + } + + if n < 1 { + n = 1 + } + if n > 30 { + n = 30 + } + d := time.Duration(1< max { + return max + } + return d +} + +// retryAfter parses a Retry-After HTTP header value, +// trying to convert v into an int (seconds) or use http.ParseTime otherwise. +// It returns zero value if v cannot be parsed. +func retryAfter(v string) time.Duration { + if i, err := strconv.Atoi(v); err == nil { + return time.Duration(i) * time.Second + } + t, err := http.ParseTime(v) + if err != nil { + return 0 + } + return t.Sub(timeNow()) +} + +// resOkay is a function that reports whether the provided response is okay. +// It is expected to keep the response body unread. +type resOkay func(*http.Response) bool + +// wantStatus returns a function which reports whether the code +// matches the status code of a response. +func wantStatus(codes ...int) resOkay { + return func(res *http.Response) bool { + for _, code := range codes { + if code == res.StatusCode { + return true + } + } + return false + } +} + +// get issues an unsigned GET request to the specified URL. +// It returns a non-error value only when ok reports true. +// +// get retries unsuccessful attempts according to c.RetryBackoff +// until the context is done or a non-retriable error is received. +func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { + retry := c.retryTimer() + for { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + res, err := c.doNoRetry(ctx, req) + switch { + case err != nil: + return nil, err + case ok(res): + return res, nil + case isRetriable(res.StatusCode): + retry.inc() + resErr := responseError(res) + res.Body.Close() + // Ignore the error value from retry.backoff + // and return the one from last retry, as received from the CA. + if retry.backoff(ctx, req, res) != nil { + return nil, resErr + } + default: + defer res.Body.Close() + return nil, responseError(res) + } + } +} + +// postAsGet is POST-as-GET, a replacement for GET in RFC 8555 +// as described in https://tools.ietf.org/html/rfc8555#section-6.3. +// It makes a POST request in KID form with zero JWS payload. +// See nopayload doc comments in jws.go. +func (c *Client) postAsGet(ctx context.Context, url string, ok resOkay) (*http.Response, error) { + return c.post(ctx, nil, url, noPayload, ok) +} + +// post issues a signed POST request in JWS format using the provided key +// to the specified URL. If key is nil, c.Key is used instead. +// It returns a non-error value only when ok reports true. +// +// post retries unsuccessful attempts according to c.RetryBackoff +// until the context is done or a non-retriable error is received. +// It uses postNoRetry to make individual requests. +func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) { + retry := c.retryTimer() + for { + res, req, err := c.postNoRetry(ctx, key, url, body) + if err != nil { + return nil, err + } + if ok(res) { + return res, nil + } + resErr := responseError(res) + res.Body.Close() + switch { + // Check for bad nonce before isRetriable because it may have been returned + // with an unretriable response code such as 400 Bad Request. + case isBadNonce(resErr): + // Consider any previously stored nonce values to be invalid. + c.clearNonces() + case !isRetriable(res.StatusCode): + return nil, resErr + } + retry.inc() + // Ignore the error value from retry.backoff + // and return the one from last retry, as received from the CA. + if err := retry.backoff(ctx, req, res); err != nil { + return nil, resErr + } + } +} + +// postNoRetry signs the body with the given key and POSTs it to the provided url. +// It is used by c.post to retry unsuccessful attempts. +// The body argument must be JSON-serializable. +// +// If key argument is nil, c.Key is used to sign the request. +// If key argument is nil and c.accountKID returns a non-zero keyID, +// the request is sent in KID form. Otherwise, JWK form is used. +// +// In practice, when interfacing with RFC-compliant CAs most requests are sent in KID form +// and JWK is used only when KID is unavailable: new account endpoint and certificate +// revocation requests authenticated by a cert key. +// See jwsEncodeJSON for other details. +func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, *http.Request, error) { + kid := noKeyID + if key == nil { + if c.Key == nil { + return nil, nil, errors.New("acme: Client.Key must be populated to make POST requests") + } + key = c.Key + kid = c.accountKID(ctx) + } + nonce, err := c.popNonce(ctx, url) + if err != nil { + return nil, nil, err + } + b, err := jwsEncodeJSON(body, key, kid, nonce, url) + if err != nil { + return nil, nil, err + } + req, err := http.NewRequest("POST", url, bytes.NewReader(b)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/jose+json") + res, err := c.doNoRetry(ctx, req) + if err != nil { + return nil, nil, err + } + c.addNonce(res.Header) + return res, req, nil +} + +// doNoRetry issues a request req, replacing its context (if any) with ctx. +func (c *Client) doNoRetry(ctx context.Context, req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", c.userAgent()) + res, err := c.httpClient().Do(req.WithContext(ctx)) + if err != nil { + select { + case <-ctx.Done(): + // Prefer the unadorned context error. + // (The acme package had tests assuming this, previously from ctxhttp's + // behavior, predating net/http supporting contexts natively) + // TODO(bradfitz): reconsider this in the future. But for now this + // requires no test updates. + return nil, ctx.Err() + default: + return nil, err + } + } + return res, nil +} + +func (c *Client) httpClient() *http.Client { + if c.HTTPClient != nil { + return c.HTTPClient + } + return http.DefaultClient +} + +// packageVersion is the version of the module that contains this package, for +// sending as part of the User-Agent header. +var packageVersion string + +func init() { + // Set packageVersion if the binary was built in modules mode and x/crypto + // was not replaced with a different module. + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + for _, m := range info.Deps { + if m.Path != "golang.org/x/crypto" { + continue + } + if m.Replace == nil { + packageVersion = m.Version + } + break + } +} + +// userAgent returns the User-Agent header value. It includes the package name, +// the module version (if available), and the c.UserAgent value (if set). +func (c *Client) userAgent() string { + ua := "golang.org/x/crypto/acme" + if packageVersion != "" { + ua += "@" + packageVersion + } + if c.UserAgent != "" { + ua = c.UserAgent + " " + ua + } + return ua +} + +// isBadNonce reports whether err is an ACME "badnonce" error. +func isBadNonce(err error) bool { + // According to the spec badNonce is urn:ietf:params:acme:error:badNonce. + // However, ACME servers in the wild return their versions of the error. + // See https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4 + // and https://github.com/letsencrypt/boulder/blob/0e07eacb/docs/acme-divergences.md#section-66. + ae, ok := err.(*Error) + return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") +} + +// isRetriable reports whether a request can be retried +// based on the response status code. +// +// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code. +// Callers should parse the response and check with isBadNonce. +func isRetriable(code int) bool { + return code <= 399 || code >= 500 || code == http.StatusTooManyRequests +} + +// responseError creates an error of Error type from resp. +func responseError(resp *http.Response) error { + // don't care if ReadAll returns an error: + // json.Unmarshal will fail in that case anyway + b, _ := io.ReadAll(resp.Body) + e := &wireError{Status: resp.StatusCode} + if err := json.Unmarshal(b, e); err != nil { + // this is not a regular error response: + // populate detail with anything we received, + // e.Status will already contain HTTP response code value + e.Detail = string(b) + if e.Detail == "" { + e.Detail = resp.Status + } + } + return e.error(resp.Header) +} diff --git a/tempfork/acme/http_test.go b/tempfork/acme/http_test.go new file mode 100644 index 000000000..d124e4e21 --- /dev/null +++ b/tempfork/acme/http_test.go @@ -0,0 +1,255 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" +) + +func TestDefaultBackoff(t *testing.T) { + tt := []struct { + nretry int + retryAfter string // Retry-After header + out time.Duration // expected min; max = min + jitter + }{ + {-1, "", time.Second}, // verify the lower bound is 1 + {0, "", time.Second}, // verify the lower bound is 1 + {100, "", 10 * time.Second}, // verify the ceiling + {1, "3600", time.Hour}, // verify the header value is used + {1, "", 1 * time.Second}, + {2, "", 2 * time.Second}, + {3, "", 4 * time.Second}, + {4, "", 8 * time.Second}, + } + for i, test := range tt { + r := httptest.NewRequest("GET", "/", nil) + resp := &http.Response{Header: http.Header{}} + if test.retryAfter != "" { + resp.Header.Set("Retry-After", test.retryAfter) + } + d := defaultBackoff(test.nretry, r, resp) + max := test.out + time.Second // + max jitter + if d < test.out || max < d { + t.Errorf("%d: defaultBackoff(%v) = %v; want between %v and %v", i, test.nretry, d, test.out, max) + } + } +} + +func TestErrorResponse(t *testing.T) { + s := `{ + "status": 400, + "type": "urn:acme:error:xxx", + "detail": "text" + }` + res := &http.Response{ + StatusCode: 400, + Status: "400 Bad Request", + Body: io.NopCloser(strings.NewReader(s)), + Header: http.Header{"X-Foo": {"bar"}}, + } + err := responseError(res) + v, ok := err.(*Error) + if !ok { + t.Fatalf("err = %+v (%T); want *Error type", err, err) + } + if v.StatusCode != 400 { + t.Errorf("v.StatusCode = %v; want 400", v.StatusCode) + } + if v.ProblemType != "urn:acme:error:xxx" { + t.Errorf("v.ProblemType = %q; want urn:acme:error:xxx", v.ProblemType) + } + if v.Detail != "text" { + t.Errorf("v.Detail = %q; want text", v.Detail) + } + if !reflect.DeepEqual(v.Header, res.Header) { + t.Errorf("v.Header = %+v; want %+v", v.Header, res.Header) + } +} + +func TestPostWithRetries(t *testing.T) { + var count int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Replay-Nonce", fmt.Sprintf("nonce%d", count)) + if r.Method == "HEAD" { + // We expect the client to do 2 head requests to fetch + // nonces, one to start and another after getting badNonce + return + } + + head, err := decodeJWSHead(r.Body) + switch { + case err != nil: + t.Errorf("decodeJWSHead: %v", err) + case head.Nonce == "": + t.Error("head.Nonce is empty") + case head.Nonce == "nonce1": + // Return a badNonce error to force the call to retry. + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`)) + return + } + // Make client.Authorize happy; we're not testing its result. + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"status":"valid"}`)) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{AuthzURL: ts.URL}, + } + // This call will fail with badNonce, causing a retry + if _, err := client.Authorize(context.Background(), "example.com"); err != nil { + t.Errorf("client.Authorize 1: %v", err) + } + if count != 3 { + t.Errorf("total requests count: %d; want 3", count) + } +} + +func TestRetryErrorType(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "nonce") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"type":"rateLimited"}`)) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + RetryBackoff: func(n int, r *http.Request, res *http.Response) time.Duration { + // Do no retries. + return 0 + }, + dir: &Directory{AuthzURL: ts.URL}, + } + + t.Run("post", func(t *testing.T) { + testRetryErrorType(t, func() error { + _, err := client.Authorize(context.Background(), "example.com") + return err + }) + }) + t.Run("get", func(t *testing.T) { + testRetryErrorType(t, func() error { + _, err := client.GetAuthorization(context.Background(), ts.URL) + return err + }) + }) +} + +func testRetryErrorType(t *testing.T, callClient func() error) { + t.Helper() + err := callClient() + if err == nil { + t.Fatal("client.Authorize returned nil error") + } + acmeErr, ok := err.(*Error) + if !ok { + t.Fatalf("err is %v (%T); want *Error", err, err) + } + if acmeErr.StatusCode != http.StatusTooManyRequests { + t.Errorf("acmeErr.StatusCode = %d; want %d", acmeErr.StatusCode, http.StatusTooManyRequests) + } + if acmeErr.ProblemType != "rateLimited" { + t.Errorf("acmeErr.ProblemType = %q; want 'rateLimited'", acmeErr.ProblemType) + } +} + +func TestRetryBackoffArgs(t *testing.T) { + const resCode = http.StatusInternalServerError + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "test-nonce") + w.WriteHeader(resCode) + })) + defer ts.Close() + + // Canceled in backoff. + ctx, cancel := context.WithCancel(context.Background()) + + var nretry int + backoff := func(n int, r *http.Request, res *http.Response) time.Duration { + nretry++ + if n != nretry { + t.Errorf("n = %d; want %d", n, nretry) + } + if nretry == 3 { + cancel() + } + + if r == nil { + t.Error("r is nil") + } + if res.StatusCode != resCode { + t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, resCode) + } + return time.Millisecond + } + + client := &Client{ + Key: testKey, + RetryBackoff: backoff, + dir: &Directory{AuthzURL: ts.URL}, + } + if _, err := client.Authorize(ctx, "example.com"); err == nil { + t.Error("err is nil") + } + if nretry != 3 { + t.Errorf("nretry = %d; want 3", nretry) + } +} + +func TestUserAgent(t *testing.T) { + for _, custom := range []string{"", "CUSTOM_UA"} { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log(r.UserAgent()) + if s := "golang.org/x/crypto/acme"; !strings.Contains(r.UserAgent(), s) { + t.Errorf("expected User-Agent to contain %q, got %q", s, r.UserAgent()) + } + if !strings.Contains(r.UserAgent(), custom) { + t.Errorf("expected User-Agent to contain %q, got %q", custom, r.UserAgent()) + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"newOrder": "sure"}`)) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + UserAgent: custom, + } + if _, err := client.Discover(context.Background()); err != nil { + t.Errorf("client.Discover: %v", err) + } + } +} + +func TestAccountKidLoop(t *testing.T) { + // if Client.postNoRetry is called with a nil key argument + // then Client.Key must be set, otherwise we fall into an + // infinite loop (which also causes a deadlock). + client := &Client{dir: &Directory{OrderURL: ":)"}} + _, _, err := client.postNoRetry(context.Background(), nil, "", nil) + if err == nil { + t.Fatal("Client.postNoRetry didn't fail with a nil key") + } + expected := "acme: Client.Key must be populated to make POST requests" + if err.Error() != expected { + t.Fatalf("Unexpected error returned: wanted %q, got %q", expected, err.Error()) + } +} diff --git a/tempfork/acme/jws.go b/tempfork/acme/jws.go new file mode 100644 index 000000000..b38828d85 --- /dev/null +++ b/tempfork/acme/jws.go @@ -0,0 +1,257 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "crypto" + "crypto/ecdsa" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + _ "crypto/sha512" // need for EC keys + "encoding/asn1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" +) + +// KeyID is the account key identity provided by a CA during registration. +type KeyID string + +// noKeyID indicates that jwsEncodeJSON should compute and use JWK instead of a KID. +// See jwsEncodeJSON for details. +const noKeyID = KeyID("") + +// noPayload indicates jwsEncodeJSON will encode zero-length octet string +// in a JWS request. This is called POST-as-GET in RFC 8555 and is used to make +// authenticated GET requests via POSTing with an empty payload. +// See https://tools.ietf.org/html/rfc8555#section-6.3 for more details. +const noPayload = "" + +// noNonce indicates that the nonce should be omitted from the protected header. +// See jwsEncodeJSON for details. +const noNonce = "" + +// jsonWebSignature can be easily serialized into a JWS following +// https://tools.ietf.org/html/rfc7515#section-3.2. +type jsonWebSignature struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Sig string `json:"signature"` +} + +// jwsEncodeJSON signs claimset using provided key and a nonce. +// The result is serialized in JSON format containing either kid or jwk +// fields based on the provided KeyID value. +// +// The claimset is marshalled using json.Marshal unless it is a string. +// In which case it is inserted directly into the message. +// +// If kid is non-empty, its quoted value is inserted in the protected header +// as "kid" field value. Otherwise, JWK is computed using jwkEncode and inserted +// as "jwk" field value. The "jwk" and "kid" fields are mutually exclusive. +// +// If nonce is non-empty, its quoted value is inserted in the protected header. +// +// See https://tools.ietf.org/html/rfc7515#section-7. +func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid KeyID, nonce, url string) ([]byte, error) { + if key == nil { + return nil, errors.New("nil key") + } + alg, sha := jwsHasher(key.Public()) + if alg == "" || !sha.Available() { + return nil, ErrUnsupportedKey + } + headers := struct { + Alg string `json:"alg"` + KID string `json:"kid,omitempty"` + JWK json.RawMessage `json:"jwk,omitempty"` + Nonce string `json:"nonce,omitempty"` + URL string `json:"url"` + }{ + Alg: alg, + Nonce: nonce, + URL: url, + } + switch kid { + case noKeyID: + jwk, err := jwkEncode(key.Public()) + if err != nil { + return nil, err + } + headers.JWK = json.RawMessage(jwk) + default: + headers.KID = string(kid) + } + phJSON, err := json.Marshal(headers) + if err != nil { + return nil, err + } + phead := base64.RawURLEncoding.EncodeToString([]byte(phJSON)) + var payload string + if val, ok := claimset.(string); ok { + payload = val + } else { + cs, err := json.Marshal(claimset) + if err != nil { + return nil, err + } + payload = base64.RawURLEncoding.EncodeToString(cs) + } + hash := sha.New() + hash.Write([]byte(phead + "." + payload)) + sig, err := jwsSign(key, sha, hash.Sum(nil)) + if err != nil { + return nil, err + } + enc := jsonWebSignature{ + Protected: phead, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(sig), + } + return json.Marshal(&enc) +} + +// jwsWithMAC creates and signs a JWS using the given key and the HS256 +// algorithm. kid and url are included in the protected header. rawPayload +// should not be base64-URL-encoded. +func jwsWithMAC(key []byte, kid, url string, rawPayload []byte) (*jsonWebSignature, error) { + if len(key) == 0 { + return nil, errors.New("acme: cannot sign JWS with an empty MAC key") + } + header := struct { + Algorithm string `json:"alg"` + KID string `json:"kid"` + URL string `json:"url,omitempty"` + }{ + // Only HMAC-SHA256 is supported. + Algorithm: "HS256", + KID: kid, + URL: url, + } + rawProtected, err := json.Marshal(header) + if err != nil { + return nil, err + } + protected := base64.RawURLEncoding.EncodeToString(rawProtected) + payload := base64.RawURLEncoding.EncodeToString(rawPayload) + + h := hmac.New(sha256.New, key) + if _, err := h.Write([]byte(protected + "." + payload)); err != nil { + return nil, err + } + mac := h.Sum(nil) + + return &jsonWebSignature{ + Protected: protected, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(mac), + }, nil +} + +// jwkEncode encodes public part of an RSA or ECDSA key into a JWK. +// The result is also suitable for creating a JWK thumbprint. +// https://tools.ietf.org/html/rfc7517 +func jwkEncode(pub crypto.PublicKey) (string, error) { + switch pub := pub.(type) { + case *rsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.3.1 + n := pub.N + e := big.NewInt(int64(pub.E)) + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`, + base64.RawURLEncoding.EncodeToString(e.Bytes()), + base64.RawURLEncoding.EncodeToString(n.Bytes()), + ), nil + case *ecdsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.2.1 + p := pub.Curve.Params() + n := p.BitSize / 8 + if p.BitSize%8 != 0 { + n++ + } + x := pub.X.Bytes() + if n > len(x) { + x = append(make([]byte, n-len(x)), x...) + } + y := pub.Y.Bytes() + if n > len(y) { + y = append(make([]byte, n-len(y)), y...) + } + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`, + p.Name, + base64.RawURLEncoding.EncodeToString(x), + base64.RawURLEncoding.EncodeToString(y), + ), nil + } + return "", ErrUnsupportedKey +} + +// jwsSign signs the digest using the given key. +// The hash is unused for ECDSA keys. +func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) { + switch pub := key.Public().(type) { + case *rsa.PublicKey: + return key.Sign(rand.Reader, digest, hash) + case *ecdsa.PublicKey: + sigASN1, err := key.Sign(rand.Reader, digest, hash) + if err != nil { + return nil, err + } + + var rs struct{ R, S *big.Int } + if _, err := asn1.Unmarshal(sigASN1, &rs); err != nil { + return nil, err + } + + rb, sb := rs.R.Bytes(), rs.S.Bytes() + size := pub.Params().BitSize / 8 + if size%8 > 0 { + size++ + } + sig := make([]byte, size*2) + copy(sig[size-len(rb):], rb) + copy(sig[size*2-len(sb):], sb) + return sig, nil + } + return nil, ErrUnsupportedKey +} + +// jwsHasher indicates suitable JWS algorithm name and a hash function +// to use for signing a digest with the provided key. +// It returns ("", 0) if the key is not supported. +func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) { + switch pub := pub.(type) { + case *rsa.PublicKey: + return "RS256", crypto.SHA256 + case *ecdsa.PublicKey: + switch pub.Params().Name { + case "P-256": + return "ES256", crypto.SHA256 + case "P-384": + return "ES384", crypto.SHA384 + case "P-521": + return "ES512", crypto.SHA512 + } + } + return "", 0 +} + +// JWKThumbprint creates a JWK thumbprint out of pub +// as specified in https://tools.ietf.org/html/rfc7638. +func JWKThumbprint(pub crypto.PublicKey) (string, error) { + jwk, err := jwkEncode(pub) + if err != nil { + return "", err + } + b := sha256.Sum256([]byte(jwk)) + return base64.RawURLEncoding.EncodeToString(b[:]), nil +} diff --git a/tempfork/acme/jws_test.go b/tempfork/acme/jws_test.go new file mode 100644 index 000000000..d5f00ba2d --- /dev/null +++ b/tempfork/acme/jws_test.go @@ -0,0 +1,550 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "math/big" + "testing" +) + +// The following shell command alias is used in the comments +// throughout this file: +// alias b64raw="base64 -w0 | tr -d '=' | tr '/+' '_-'" + +const ( + // Modulus in raw base64: + // 4xgZ3eRPkwoRvy7qeRUbmMDe0V-xH9eWLdu0iheeLlrmD2mqWXfP9IeSKApbn34 + // g8TuAS9g5zhq8ELQ3kmjr-KV86GAMgI6VAcGlq3QrzpTCf_30Ab7-zawrfRaFON + // a1HwEzPY1KHnGVkxJc85gNkwYI9SY2RHXtvln3zs5wITNrdosqEXeaIkVYBEhbh + // Nu54pp3kxo6TuWLi9e6pXeWetEwmlBwtWZlPoib2j3TxLBksKZfoyFyek380mHg + // JAumQ_I2fjj98_97mk3ihOY4AgVdCDj1z_GCoZkG5Rq7nbCGyosyKWyDX00Zs-n + // NqVhoLeIvXC4nnWdJMZ6rogxyQQ + testKeyPEM = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq +WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30 +Ab7+zawrfRaFONa1HwEzPY1KHnGVkxJc85gNkwYI9SY2RHXtvln3zs5wITNrdosq +EXeaIkVYBEhbhNu54pp3kxo6TuWLi9e6pXeWetEwmlBwtWZlPoib2j3TxLBksKZf +oyFyek380mHgJAumQ/I2fjj98/97mk3ihOY4AgVdCDj1z/GCoZkG5Rq7nbCGyosy +KWyDX00Zs+nNqVhoLeIvXC4nnWdJMZ6rogxyQQIDAQABAoIBACIEZTOI1Kao9nmV +9IeIsuaR1Y61b9neOF/MLmIVIZu+AAJFCMB4Iw11FV6sFodwpEyeZhx2WkpWVN+H +r19eGiLX3zsL0DOdqBJoSIHDWCCMxgnYJ6nvS0nRxX3qVrBp8R2g12Ub+gNPbmFm +ecf/eeERIVxfifd9VsyRu34eDEvcmKFuLYbElFcPh62xE3x12UZvV/sN7gXbawpP +G+w255vbE5MoaKdnnO83cTFlcHvhn24M/78qP7Te5OAeelr1R89kYxQLpuGe4fbS +zc6E3ym5Td6urDetGGrSY1Eu10/8sMusX+KNWkm+RsBRbkyKq72ks/qKpOxOa+c6 +9gm+Y8ECgYEA/iNUyg1ubRdH11p82l8KHtFC1DPE0V1gSZsX29TpM5jS4qv46K+s +8Ym1zmrORM8x+cynfPx1VQZQ34EYeCMIX212ryJ+zDATl4NE0I4muMvSiH9vx6Xc +7FmhNnaYzPsBL5Tm9nmtQuP09YEn8poiOJFiDs/4olnD5ogA5O4THGkCgYEA5MIL +qWYBUuqbEWLRtMruUtpASclrBqNNsJEsMGbeqBJmoMxdHeSZckbLOrqm7GlMyNRJ +Ne/5uWRGSzaMYuGmwsPpERzqEvYFnSrpjW5YtXZ+JtxFXNVfm9Z1gLLgvGpOUCIU +RbpoDckDe1vgUuk3y5+DjZihs+rqIJ45XzXTzBkCgYBWuf3segruJZy5rEKhTv+o +JqeUvRn0jNYYKFpLBeyTVBrbie6GkbUGNIWbrK05pC+c3K9nosvzuRUOQQL1tJbd +4gA3oiD9U4bMFNr+BRTHyZ7OQBcIXdz3t1qhuHVKtnngIAN1p25uPlbRFUNpshnt +jgeVoHlsBhApcs5DUc+pyQKBgDzeHPg/+g4z+nrPznjKnktRY1W+0El93kgi+J0Q +YiJacxBKEGTJ1MKBb8X6sDurcRDm22wMpGfd9I5Cv2v4GsUsF7HD/cx5xdih+G73 +c4clNj/k0Ff5Nm1izPUno4C+0IOl7br39IPmfpSuR6wH/h6iHQDqIeybjxyKvT1G +N0rRAoGBAKGD+4ZI/E1MoJ5CXB8cDDMHagbE3cq/DtmYzE2v1DFpQYu5I4PCm5c7 +EQeIP6dZtv8IMgtGIb91QX9pXvP0aznzQKwYIA8nZgoENCPfiMTPiEDT9e/0lObO +9XWsXpbSTsRPj0sv1rB+UzBJ0PgjK4q2zOF0sNo7b1+6nlM3BWPx +-----END RSA PRIVATE KEY----- +` + + // This thumbprint is for the testKey defined above. + testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ" + + // openssl ecparam -name secp256k1 -genkey -noout + testKeyECPEM = ` +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIK07hGLr0RwyUdYJ8wbIiBS55CjnkMD23DWr+ccnypWLoAoGCCqGSM49 +AwEHoUQDQgAE5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HThqIrvawF5 +QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ== +-----END EC PRIVATE KEY----- +` + // openssl ecparam -name secp384r1 -genkey -noout + testKeyEC384PEM = ` +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDAQ4lNtXRORWr1bgKR1CGysr9AJ9SyEk4jiVnlUWWUChmSNL+i9SLSD +Oe/naPqXJ6CgBwYFK4EEACKhZANiAAQzKtj+Ms0vHoTX5dzv3/L5YMXOWuI5UKRj +JigpahYCqXD2BA1j0E/2xt5vlPf+gm0PL+UHSQsCokGnIGuaHCsJAp3ry0gHQEke +WYXapUUFdvaK1R2/2hn5O+eiQM8YzCg= +-----END EC PRIVATE KEY----- +` + // openssl ecparam -name secp521r1 -genkey -noout + testKeyEC512PEM = ` +-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIBSNZKFcWzXzB/aJClAb305ibalKgtDA7+70eEkdPt28/3LZMM935Z +KqYHh/COcxuu3Kt8azRAUz3gyr4zZKhlKUSgBwYFK4EEACOhgYkDgYYABAHUNKbx +7JwC7H6pa2sV0tERWhHhB3JmW+OP6SUgMWryvIKajlx73eS24dy4QPGrWO9/ABsD +FqcRSkNVTXnIv6+0mAF25knqIBIg5Q8M9BnOu9GGAchcwt3O7RDHmqewnJJDrbjd +GGnm6rb+NnWR9DIopM0nKNkToWoF/hzopxu4Ae/GsQ== +-----END EC PRIVATE KEY----- +` + // 1. openssl ec -in key.pem -noout -text + // 2. remove first byte, 04 (the header); the rest is X and Y + // 3. convert each with: echo | xxd -r -p | b64raw + testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ" + testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk" + testKeyEC384PubX = "MyrY_jLNLx6E1-Xc79_y-WDFzlriOVCkYyYoKWoWAqlw9gQNY9BP9sbeb5T3_oJt" + testKeyEC384PubY = "Dy_lB0kLAqJBpyBrmhwrCQKd68tIB0BJHlmF2qVFBXb2itUdv9oZ-TvnokDPGMwo" + testKeyEC512PubX = "AdQ0pvHsnALsfqlraxXS0RFaEeEHcmZb44_pJSAxavK8gpqOXHvd5Lbh3LhA8atY738AGwMWpxFKQ1VNeci_r7SY" + testKeyEC512PubY = "AXbmSeogEiDlDwz0Gc670YYByFzC3c7tEMeap7CckkOtuN0Yaebqtv42dZH0MiikzSco2ROhagX-HOinG7gB78ax" + + // echo -n '{"crv":"P-256","kty":"EC","x":"","y":""}' | \ + // openssl dgst -binary -sha256 | b64raw + testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU" +) + +var ( + testKey *rsa.PrivateKey + testKeyEC *ecdsa.PrivateKey + testKeyEC384 *ecdsa.PrivateKey + testKeyEC512 *ecdsa.PrivateKey +) + +func init() { + testKey = parseRSA(testKeyPEM, "testKeyPEM") + testKeyEC = parseEC(testKeyECPEM, "testKeyECPEM") + testKeyEC384 = parseEC(testKeyEC384PEM, "testKeyEC384PEM") + testKeyEC512 = parseEC(testKeyEC512PEM, "testKeyEC512PEM") +} + +func decodePEM(s, name string) []byte { + d, _ := pem.Decode([]byte(s)) + if d == nil { + panic("no block found in " + name) + } + return d.Bytes +} + +func parseRSA(s, name string) *rsa.PrivateKey { + b := decodePEM(s, name) + k, err := x509.ParsePKCS1PrivateKey(b) + if err != nil { + panic(fmt.Sprintf("%s: %v", name, err)) + } + return k +} + +func parseEC(s, name string) *ecdsa.PrivateKey { + b := decodePEM(s, name) + k, err := x509.ParseECPrivateKey(b) + if err != nil { + panic(fmt.Sprintf("%s: %v", name, err)) + } + return k +} + +func TestJWSEncodeJSON(t *testing.T) { + claims := struct{ Msg string }{"Hello JWS"} + // JWS signed with testKey and "nonce" as the nonce value + // JSON-serialized JWS fields are split for easier testing + const ( + // {"alg":"RS256","jwk":{"e":"AQAB","kty":"RSA","n":"..."},"nonce":"nonce","url":"url"} + protected = "eyJhbGciOiJSUzI1NiIsImp3ayI6eyJlIjoiQVFBQiIsImt0eSI6" + + "IlJTQSIsIm4iOiI0eGdaM2VSUGt3b1J2eTdxZVJVYm1NRGUwVi14" + + "SDllV0xkdTBpaGVlTGxybUQybXFXWGZQOUllU0tBcGJuMzRnOFR1" + + "QVM5ZzV6aHE4RUxRM2ttanItS1Y4NkdBTWdJNlZBY0dscTNRcnpw" + + "VENmXzMwQWI3LXphd3JmUmFGT05hMUh3RXpQWTFLSG5HVmt4SmM4" + + "NWdOa3dZSTlTWTJSSFh0dmxuM3pzNXdJVE5yZG9zcUVYZWFJa1ZZ" + + "QkVoYmhOdTU0cHAza3hvNlR1V0xpOWU2cFhlV2V0RXdtbEJ3dFda" + + "bFBvaWIyajNUeExCa3NLWmZveUZ5ZWszODBtSGdKQXVtUV9JMmZq" + + "ajk4Xzk3bWszaWhPWTRBZ1ZkQ0RqMXpfR0NvWmtHNVJxN25iQ0d5" + + "b3N5S1d5RFgwMFpzLW5OcVZob0xlSXZYQzRubldkSk1aNnJvZ3h5" + + "UVEifSwibm9uY2UiOiJub25jZSIsInVybCI6InVybCJ9" + // {"Msg":"Hello JWS"} + payload = "eyJNc2ciOiJIZWxsbyBKV1MifQ" + // printf '.' | openssl dgst -binary -sha256 -sign testKey | b64raw + signature = "YFyl_xz1E7TR-3E1bIuASTr424EgCvBHjt25WUFC2VaDjXYV0Rj_" + + "Hd3dJ_2IRqBrXDZZ2n4ZeA_4mm3QFwmwyeDwe2sWElhb82lCZ8iX" + + "uFnjeOmSOjx-nWwPa5ibCXzLq13zZ-OBV1Z4oN_TuailQeRoSfA3" + + "nO8gG52mv1x2OMQ5MAFtt8jcngBLzts4AyhI6mBJ2w7Yaj3ZCriq" + + "DWA3GLFvvHdW1Ba9Z01wtGT2CuZI7DUk_6Qj1b3BkBGcoKur5C9i" + + "bUJtCkABwBMvBQNyD3MmXsrRFRTgvVlyU_yMaucYm7nmzEr_2PaQ" + + "50rFt_9qOfJ4sfbLtG1Wwae57BQx1g" + ) + + b, err := jwsEncodeJSON(claims, testKey, noKeyID, "nonce", "url") + if err != nil { + t.Fatal(err) + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Fatal(err) + } + if jws.Protected != protected { + t.Errorf("protected:\n%s\nwant:\n%s", jws.Protected, protected) + } + if jws.Payload != payload { + t.Errorf("payload:\n%s\nwant:\n%s", jws.Payload, payload) + } + if jws.Signature != signature { + t.Errorf("signature:\n%s\nwant:\n%s", jws.Signature, signature) + } +} + +func TestJWSEncodeNoNonce(t *testing.T) { + kid := KeyID("https://example.org/account/1") + claims := "RawString" + const ( + // {"alg":"ES256","kid":"https://example.org/account/1","nonce":"nonce","url":"url"} + protected = "eyJhbGciOiJFUzI1NiIsImtpZCI6Imh0dHBzOi8vZXhhbXBsZS5vcmcvYWNjb3VudC8xIiwidXJsIjoidXJsIn0" + // "Raw String" + payload = "RawString" + ) + + b, err := jwsEncodeJSON(claims, testKeyEC, kid, "", "url") + if err != nil { + t.Fatal(err) + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Fatal(err) + } + if jws.Protected != protected { + t.Errorf("protected:\n%s\nwant:\n%s", jws.Protected, protected) + } + if jws.Payload != payload { + t.Errorf("payload:\n%s\nwant:\n%s", jws.Payload, payload) + } + + sig, err := base64.RawURLEncoding.DecodeString(jws.Signature) + if err != nil { + t.Fatalf("jws.Signature: %v", err) + } + r, s := big.NewInt(0), big.NewInt(0) + r.SetBytes(sig[:len(sig)/2]) + s.SetBytes(sig[len(sig)/2:]) + h := sha256.Sum256([]byte(protected + "." + payload)) + if !ecdsa.Verify(testKeyEC.Public().(*ecdsa.PublicKey), h[:], r, s) { + t.Error("invalid signature") + } +} + +func TestJWSEncodeKID(t *testing.T) { + kid := KeyID("https://example.org/account/1") + claims := struct{ Msg string }{"Hello JWS"} + // JWS signed with testKeyEC + const ( + // {"alg":"ES256","kid":"https://example.org/account/1","nonce":"nonce","url":"url"} + protected = "eyJhbGciOiJFUzI1NiIsImtpZCI6Imh0dHBzOi8vZXhhbXBsZS5" + + "vcmcvYWNjb3VudC8xIiwibm9uY2UiOiJub25jZSIsInVybCI6InVybCJ9" + // {"Msg":"Hello JWS"} + payload = "eyJNc2ciOiJIZWxsbyBKV1MifQ" + ) + + b, err := jwsEncodeJSON(claims, testKeyEC, kid, "nonce", "url") + if err != nil { + t.Fatal(err) + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Fatal(err) + } + if jws.Protected != protected { + t.Errorf("protected:\n%s\nwant:\n%s", jws.Protected, protected) + } + if jws.Payload != payload { + t.Errorf("payload:\n%s\nwant:\n%s", jws.Payload, payload) + } + + sig, err := base64.RawURLEncoding.DecodeString(jws.Signature) + if err != nil { + t.Fatalf("jws.Signature: %v", err) + } + r, s := big.NewInt(0), big.NewInt(0) + r.SetBytes(sig[:len(sig)/2]) + s.SetBytes(sig[len(sig)/2:]) + h := sha256.Sum256([]byte(protected + "." + payload)) + if !ecdsa.Verify(testKeyEC.Public().(*ecdsa.PublicKey), h[:], r, s) { + t.Error("invalid signature") + } +} + +func TestJWSEncodeJSONEC(t *testing.T) { + tt := []struct { + key *ecdsa.PrivateKey + x, y string + alg, crv string + }{ + {testKeyEC, testKeyECPubX, testKeyECPubY, "ES256", "P-256"}, + {testKeyEC384, testKeyEC384PubX, testKeyEC384PubY, "ES384", "P-384"}, + {testKeyEC512, testKeyEC512PubX, testKeyEC512PubY, "ES512", "P-521"}, + } + for i, test := range tt { + claims := struct{ Msg string }{"Hello JWS"} + b, err := jwsEncodeJSON(claims, test.key, noKeyID, "nonce", "url") + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + b, err = base64.RawURLEncoding.DecodeString(jws.Protected) + if err != nil { + t.Errorf("%d: jws.Protected: %v", i, err) + } + var head struct { + Alg string + Nonce string + URL string `json:"url"` + KID string `json:"kid"` + JWK struct { + Crv string + Kty string + X string + Y string + } `json:"jwk"` + } + if err := json.Unmarshal(b, &head); err != nil { + t.Errorf("%d: jws.Protected: %v", i, err) + } + if head.Alg != test.alg { + t.Errorf("%d: head.Alg = %q; want %q", i, head.Alg, test.alg) + } + if head.Nonce != "nonce" { + t.Errorf("%d: head.Nonce = %q; want nonce", i, head.Nonce) + } + if head.URL != "url" { + t.Errorf("%d: head.URL = %q; want 'url'", i, head.URL) + } + if head.KID != "" { + // We used noKeyID in jwsEncodeJSON: expect no kid value. + t.Errorf("%d: head.KID = %q; want empty", i, head.KID) + } + if head.JWK.Crv != test.crv { + t.Errorf("%d: head.JWK.Crv = %q; want %q", i, head.JWK.Crv, test.crv) + } + if head.JWK.Kty != "EC" { + t.Errorf("%d: head.JWK.Kty = %q; want EC", i, head.JWK.Kty) + } + if head.JWK.X != test.x { + t.Errorf("%d: head.JWK.X = %q; want %q", i, head.JWK.X, test.x) + } + if head.JWK.Y != test.y { + t.Errorf("%d: head.JWK.Y = %q; want %q", i, head.JWK.Y, test.y) + } + } +} + +type customTestSigner struct { + sig []byte + pub crypto.PublicKey +} + +func (s *customTestSigner) Public() crypto.PublicKey { return s.pub } +func (s *customTestSigner) Sign(io.Reader, []byte, crypto.SignerOpts) ([]byte, error) { + return s.sig, nil +} + +func TestJWSEncodeJSONCustom(t *testing.T) { + claims := struct{ Msg string }{"hello"} + const ( + // printf '{"Msg":"hello"}' | b64raw + payload = "eyJNc2ciOiJoZWxsbyJ9" + // printf 'testsig' | b64raw + testsig = "dGVzdHNpZw" + + // the example P256 curve point from https://tools.ietf.org/html/rfc7515#appendix-A.3.1 + // encoded as ASN.1â€Ļ + es256stdsig = "MEUCIA7RIVN5Y2xIPC9/FVgH1AKjsigDOvl8fheBmsMWnqZlAiEA" + + "xQoH04w8cOXY8S2vCEpUgKZlkMXyk1Cajz9/ioOjVNU" + // â€Ļand RFC7518 (https://tools.ietf.org/html/rfc7518#section-3.4) + es256jwsig = "DtEhU3ljbEg8L38VWAfUAqOyKAM6-Xx-F4GawxaepmXFCgfTjDxw" + + "5djxLa8ISlSApmWQxfKTUJqPP3-Kg6NU1Q" + + // printf '{"alg":"ES256","jwk":{"crv":"P-256","kty":"EC","x":,"y":},"nonce":"nonce","url":"url"}' | b64raw + es256phead = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJjcnYiOiJQLTI1NiIsImt0" + + "eSI6IkVDIiwieCI6IjVsaEV1ZzV4SzR4QkRaMm5BYmF4THRhTGl2" + + "ODVieEo3ZVBkMWRrTzIzSFEiLCJ5IjoiNGFpSzcyc0JlVUFHa3Yw" + + "VGFMc213b2tZVVl5TnhHc1M1RU1JS3dzTklLayJ9LCJub25jZSI6" + + "Im5vbmNlIiwidXJsIjoidXJsIn0" + + // {"alg":"RS256","jwk":{"e":"AQAB","kty":"RSA","n":"..."},"nonce":"nonce","url":"url"} + rs256phead = "eyJhbGciOiJSUzI1NiIsImp3ayI6eyJlIjoiQVFBQiIsImt0eSI6" + + "IlJTQSIsIm4iOiI0eGdaM2VSUGt3b1J2eTdxZVJVYm1NRGUwVi14" + + "SDllV0xkdTBpaGVlTGxybUQybXFXWGZQOUllU0tBcGJuMzRnOFR1" + + "QVM5ZzV6aHE4RUxRM2ttanItS1Y4NkdBTWdJNlZBY0dscTNRcnpw" + + "VENmXzMwQWI3LXphd3JmUmFGT05hMUh3RXpQWTFLSG5HVmt4SmM4" + + "NWdOa3dZSTlTWTJSSFh0dmxuM3pzNXdJVE5yZG9zcUVYZWFJa1ZZ" + + "QkVoYmhOdTU0cHAza3hvNlR1V0xpOWU2cFhlV2V0RXdtbEJ3dFda" + + "bFBvaWIyajNUeExCa3NLWmZveUZ5ZWszODBtSGdKQXVtUV9JMmZq" + + "ajk4Xzk3bWszaWhPWTRBZ1ZkQ0RqMXpfR0NvWmtHNVJxN25iQ0d5" + + "b3N5S1d5RFgwMFpzLW5OcVZob0xlSXZYQzRubldkSk1aNnJvZ3h5" + + "UVEifSwibm9uY2UiOiJub25jZSIsInVybCI6InVybCJ9" + ) + + tt := []struct { + alg, phead string + pub crypto.PublicKey + stdsig, jwsig string + }{ + {"ES256", es256phead, testKeyEC.Public(), es256stdsig, es256jwsig}, + {"RS256", rs256phead, testKey.Public(), testsig, testsig}, + } + for _, tc := range tt { + tc := tc + t.Run(tc.alg, func(t *testing.T) { + stdsig, err := base64.RawStdEncoding.DecodeString(tc.stdsig) + if err != nil { + t.Errorf("couldn't decode test vector: %v", err) + } + signer := &customTestSigner{ + sig: stdsig, + pub: tc.pub, + } + + b, err := jwsEncodeJSON(claims, signer, noKeyID, "nonce", "url") + if err != nil { + t.Fatal(err) + } + var j jsonWebSignature + if err := json.Unmarshal(b, &j); err != nil { + t.Fatal(err) + } + if j.Protected != tc.phead { + t.Errorf("j.Protected = %q\nwant %q", j.Protected, tc.phead) + } + if j.Payload != payload { + t.Errorf("j.Payload = %q\nwant %q", j.Payload, payload) + } + if j.Sig != tc.jwsig { + t.Errorf("j.Sig = %q\nwant %q", j.Sig, tc.jwsig) + } + }) + } +} + +func TestJWSWithMAC(t *testing.T) { + // Example from RFC 7520 Section 4.4.3. + // https://tools.ietf.org/html/rfc7520#section-4.4.3 + b64Key := "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" + rawPayload := []byte("It\xe2\x80\x99s a dangerous business, Frodo, going out your " + + "door. You step onto the road, and if you don't keep your feet, " + + "there\xe2\x80\x99s no knowing where you might be swept off " + + "to.") + protected := "eyJhbGciOiJIUzI1NiIsImtpZCI6IjAxOGMwYWU1LTRkOWItNDcxYi1iZmQ2LW" + + "VlZjMxNGJjNzAzNyJ9" + payload := "SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywg" + + "Z29pbmcgb3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9h" + + "ZCwgYW5kIGlmIHlvdSBkb24ndCBrZWVwIHlvdXIgZmVldCwgdGhlcmXi" + + "gJlzIG5vIGtub3dpbmcgd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9m" + + "ZiB0by4" + sig := "s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0" + + key, err := base64.RawURLEncoding.DecodeString(b64Key) + if err != nil { + t.Fatalf("unable to decode key: %q", b64Key) + } + got, err := jwsWithMAC(key, "018c0ae5-4d9b-471b-bfd6-eef314bc7037", "", rawPayload) + if err != nil { + t.Fatalf("jwsWithMAC() = %q", err) + } + if got.Protected != protected { + t.Errorf("got.Protected = %q\nwant %q", got.Protected, protected) + } + if got.Payload != payload { + t.Errorf("got.Payload = %q\nwant %q", got.Payload, payload) + } + if got.Sig != sig { + t.Errorf("got.Signature = %q\nwant %q", got.Sig, sig) + } +} + +func TestJWSWithMACError(t *testing.T) { + p := "{}" + if _, err := jwsWithMAC(nil, "", "", []byte(p)); err == nil { + t.Errorf("jwsWithMAC(nil, ...) = success; want err") + } +} + +func TestJWKThumbprintRSA(t *testing.T) { + // Key example from RFC 7638 + const base64N = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAt" + + "VT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn6" + + "4tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FD" + + "W2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n9" + + "1CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINH" + + "aQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + const base64E = "AQAB" + const expected = "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs" + + b, err := base64.RawURLEncoding.DecodeString(base64N) + if err != nil { + t.Fatalf("Error parsing example key N: %v", err) + } + n := new(big.Int).SetBytes(b) + + b, err = base64.RawURLEncoding.DecodeString(base64E) + if err != nil { + t.Fatalf("Error parsing example key E: %v", err) + } + e := new(big.Int).SetBytes(b) + + pub := &rsa.PublicKey{N: n, E: int(e.Uint64())} + th, err := JWKThumbprint(pub) + if err != nil { + t.Error(err) + } + if th != expected { + t.Errorf("thumbprint = %q; want %q", th, expected) + } +} + +func TestJWKThumbprintEC(t *testing.T) { + // Key example from RFC 7520 + // expected was computed with + // printf '{"crv":"P-521","kty":"EC","x":"","y":""}' | \ + // openssl dgst -binary -sha256 | b64raw + const ( + base64X = "AHKZLLOsCOzz5cY97ewNUajB957y-C-U88c3v13nmGZx6sYl_oJXu9A5RkT" + + "KqjqvjyekWF-7ytDyRXYgCF5cj0Kt" + base64Y = "AdymlHvOiLxXkEhayXQnNCvDX4h9htZaCJN34kfmC6pV5OhQHiraVySsUda" + + "QkAgDPrwQrJmbnX9cwlGfP-HqHZR1" + expected = "dHri3SADZkrush5HU_50AoRhcKFryN-PI6jPBtPL55M" + ) + + b, err := base64.RawURLEncoding.DecodeString(base64X) + if err != nil { + t.Fatalf("Error parsing example key X: %v", err) + } + x := new(big.Int).SetBytes(b) + + b, err = base64.RawURLEncoding.DecodeString(base64Y) + if err != nil { + t.Fatalf("Error parsing example key Y: %v", err) + } + y := new(big.Int).SetBytes(b) + + pub := &ecdsa.PublicKey{Curve: elliptic.P521(), X: x, Y: y} + th, err := JWKThumbprint(pub) + if err != nil { + t.Error(err) + } + if th != expected { + t.Errorf("thumbprint = %q; want %q", th, expected) + } +} + +func TestJWKThumbprintErrUnsupportedKey(t *testing.T) { + _, err := JWKThumbprint(struct{}{}) + if err != ErrUnsupportedKey { + t.Errorf("err = %q; want %q", err, ErrUnsupportedKey) + } +} diff --git a/tempfork/acme/rfc8555.go b/tempfork/acme/rfc8555.go new file mode 100644 index 000000000..3eaf935fd --- /dev/null +++ b/tempfork/acme/rfc8555.go @@ -0,0 +1,486 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "context" + "crypto" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "time" +) + +// DeactivateReg permanently disables an existing account associated with c.Key. +// A deactivated account can no longer request certificate issuance or access +// resources related to the account, such as orders or authorizations. +// +// It only works with CAs implementing RFC 8555. +func (c *Client) DeactivateReg(ctx context.Context) error { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return err + } + url := string(c.accountKID(ctx)) + if url == "" { + return ErrNoAccount + } + req := json.RawMessage(`{"status": "deactivated"}`) + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return err + } + res.Body.Close() + return nil +} + +// registerRFC is equivalent to c.Register but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. +func (c *Client) registerRFC(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) { + c.cacheMu.Lock() // guard c.kid access + defer c.cacheMu.Unlock() + + req := struct { + TermsAgreed bool `json:"termsOfServiceAgreed,omitempty"` + Contact []string `json:"contact,omitempty"` + ExternalAccountBinding *jsonWebSignature `json:"externalAccountBinding,omitempty"` + }{ + Contact: acct.Contact, + } + if c.dir.Terms != "" { + req.TermsAgreed = prompt(c.dir.Terms) + } + + // set 'externalAccountBinding' field if requested + if acct.ExternalAccountBinding != nil { + eabJWS, err := c.encodeExternalAccountBinding(acct.ExternalAccountBinding) + if err != nil { + return nil, fmt.Errorf("acme: failed to encode external account binding: %v", err) + } + req.ExternalAccountBinding = eabJWS + } + + res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus( + http.StatusOK, // account with this key already registered + http.StatusCreated, // new account created + )) + if err != nil { + return nil, err + } + + defer res.Body.Close() + a, err := responseAccount(res) + if err != nil { + return nil, err + } + // Cache Account URL even if we return an error to the caller. + // It is by all means a valid and usable "kid" value for future requests. + c.KID = KeyID(a.URI) + if res.StatusCode == http.StatusOK { + return nil, ErrAccountAlreadyExists + } + return a, nil +} + +// encodeExternalAccountBinding will encode an external account binding stanza +// as described in https://tools.ietf.org/html/rfc8555#section-7.3.4. +func (c *Client) encodeExternalAccountBinding(eab *ExternalAccountBinding) (*jsonWebSignature, error) { + jwk, err := jwkEncode(c.Key.Public()) + if err != nil { + return nil, err + } + return jwsWithMAC(eab.Key, eab.KID, c.dir.RegURL, []byte(jwk)) +} + +// updateRegRFC is equivalent to c.UpdateReg but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. +func (c *Client) updateRegRFC(ctx context.Context, a *Account) (*Account, error) { + url := string(c.accountKID(ctx)) + if url == "" { + return nil, ErrNoAccount + } + req := struct { + Contact []string `json:"contact,omitempty"` + }{ + Contact: a.Contact, + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseAccount(res) +} + +// getRegRFC is equivalent to c.GetReg but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. +func (c *Client) getRegRFC(ctx context.Context) (*Account, error) { + req := json.RawMessage(`{"onlyReturnExisting": true}`) + res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus(http.StatusOK)) + if e, ok := err.(*Error); ok && e.ProblemType == "urn:ietf:params:acme:error:accountDoesNotExist" { + return nil, ErrNoAccount + } + if err != nil { + return nil, err + } + + defer res.Body.Close() + return responseAccount(res) +} + +func responseAccount(res *http.Response) (*Account, error) { + var v struct { + Status string + Contact []string + Orders string + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid account response: %v", err) + } + return &Account{ + URI: res.Header.Get("Location"), + Status: v.Status, + Contact: v.Contact, + OrdersURL: v.Orders, + }, nil +} + +// accountKeyRollover attempts to perform account key rollover. +// On success it will change client.Key to the new key. +func (c *Client) accountKeyRollover(ctx context.Context, newKey crypto.Signer) error { + dir, err := c.Discover(ctx) // Also required by c.accountKID + if err != nil { + return err + } + kid := c.accountKID(ctx) + if kid == noKeyID { + return ErrNoAccount + } + oldKey, err := jwkEncode(c.Key.Public()) + if err != nil { + return err + } + payload := struct { + Account string `json:"account"` + OldKey json.RawMessage `json:"oldKey"` + }{ + Account: string(kid), + OldKey: json.RawMessage(oldKey), + } + inner, err := jwsEncodeJSON(payload, newKey, noKeyID, noNonce, dir.KeyChangeURL) + if err != nil { + return err + } + + res, err := c.post(ctx, nil, dir.KeyChangeURL, base64.RawURLEncoding.EncodeToString(inner), wantStatus(http.StatusOK)) + if err != nil { + return err + } + defer res.Body.Close() + c.Key = newKey + return nil +} + +// AuthorizeOrder initiates the order-based application for certificate issuance, +// as opposed to pre-authorization in Authorize. +// It is only supported by CAs implementing RFC 8555. +// +// The caller then needs to fetch each authorization with GetAuthorization, +// identify those with StatusPending status and fulfill a challenge using Accept. +// Once all authorizations are satisfied, the caller will typically want to poll +// order status using WaitOrder until it's in StatusReady state. +// To finalize the order and obtain a certificate, the caller submits a CSR with CreateOrderCert. +func (c *Client) AuthorizeOrder(ctx context.Context, id []AuthzID, opt ...OrderOption) (*Order, error) { + dir, err := c.Discover(ctx) + if err != nil { + return nil, err + } + + req := struct { + Identifiers []wireAuthzID `json:"identifiers"` + NotBefore string `json:"notBefore,omitempty"` + NotAfter string `json:"notAfter,omitempty"` + Replaces string `json:"replaces,omitempty"` + }{} + for _, v := range id { + req.Identifiers = append(req.Identifiers, wireAuthzID{ + Type: v.Type, + Value: v.Value, + }) + } + for _, o := range opt { + switch o := o.(type) { + case orderNotBeforeOpt: + req.NotBefore = time.Time(o).Format(time.RFC3339) + case orderNotAfterOpt: + req.NotAfter = time.Time(o).Format(time.RFC3339) + case orderReplacesCert: + req.Replaces = certRenewalIdentifier(o.cert) + case orderReplacesCertDER: + cert, err := x509.ParseCertificate(o) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate being replaced: %w", err) + } + req.Replaces = certRenewalIdentifier(cert) + default: + // Package's fault if we let this happen. + panic(fmt.Sprintf("unsupported order option type %T", o)) + } + } + + res, err := c.post(ctx, nil, dir.OrderURL, req, wantStatus(http.StatusCreated)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseOrder(res) +} + +// GetOrder retrives an order identified by the given URL. +// For orders created with AuthorizeOrder, the url value is Order.URI. +// +// If a caller needs to poll an order until its status is final, +// see the WaitOrder method. +func (c *Client) GetOrder(ctx context.Context, url string) (*Order, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseOrder(res) +} + +// WaitOrder polls an order from the given URL until it is in one of the final states, +// StatusReady, StatusValid or StatusInvalid, the CA responded with a non-retryable error +// or the context is done. +// +// It returns a non-nil Order only if its Status is StatusReady or StatusValid. +// In all other cases WaitOrder returns an error. +// If the Status is StatusInvalid, the returned error is of type *OrderError. +func (c *Client) WaitOrder(ctx context.Context, url string) (*Order, error) { + if _, err := c.Discover(ctx); err != nil { + return nil, err + } + for { + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + o, err := responseOrder(res) + res.Body.Close() + switch { + case err != nil: + // Skip and retry. + case o.Status == StatusInvalid: + return nil, &OrderError{OrderURL: o.URI, Status: o.Status} + case o.Status == StatusReady || o.Status == StatusValid: + return o, nil + } + + d := retryAfter(res.Header.Get("Retry-After")) + if d == 0 { + // Default retry-after. + // Same reasoning as in WaitAuthorization. + d = time.Second + } + t := time.NewTimer(d) + select { + case <-ctx.Done(): + t.Stop() + return nil, ctx.Err() + case <-t.C: + // Retry. + } + } +} + +func responseOrder(res *http.Response) (*Order, error) { + var v struct { + Status string + Expires time.Time + Identifiers []wireAuthzID + NotBefore time.Time + NotAfter time.Time + Error *wireError + Authorizations []string + Finalize string + Certificate string + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: error reading order: %v", err) + } + o := &Order{ + URI: res.Header.Get("Location"), + Status: v.Status, + Expires: v.Expires, + NotBefore: v.NotBefore, + NotAfter: v.NotAfter, + AuthzURLs: v.Authorizations, + FinalizeURL: v.Finalize, + CertURL: v.Certificate, + } + for _, id := range v.Identifiers { + o.Identifiers = append(o.Identifiers, AuthzID{Type: id.Type, Value: id.Value}) + } + if v.Error != nil { + o.Error = v.Error.error(nil /* headers */) + } + return o, nil +} + +// CreateOrderCert submits the CSR (Certificate Signing Request) to a CA at the specified URL. +// The URL is the FinalizeURL field of an Order created with AuthorizeOrder. +// +// If the bundle argument is true, the returned value also contain the CA (issuer) +// certificate chain. Otherwise, only a leaf certificate is returned. +// The returned URL can be used to re-fetch the certificate using FetchCert. +// +// This method is only supported by CAs implementing RFC 8555. See CreateCert for pre-RFC CAs. +// +// CreateOrderCert returns an error if the CA's response is unreasonably large. +// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features. +func (c *Client) CreateOrderCert(ctx context.Context, url string, csr []byte, bundle bool) (der [][]byte, certURL string, err error) { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return nil, "", err + } + + // RFC describes this as "finalize order" request. + req := struct { + CSR string `json:"csr"` + }{ + CSR: base64.RawURLEncoding.EncodeToString(csr), + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return nil, "", err + } + defer res.Body.Close() + o, err := responseOrder(res) + if err != nil { + return nil, "", err + } + + // Wait for CA to issue the cert if they haven't. + if o.Status != StatusValid { + o, err = c.WaitOrder(ctx, o.URI) + } + if err != nil { + return nil, "", err + } + // The only acceptable status post finalize and WaitOrder is "valid". + if o.Status != StatusValid { + return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status} + } + crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle) + return crt, o.CertURL, err +} + +// fetchCertRFC downloads issued certificate from the given URL. +// It expects the CA to respond with PEM-encoded certificate chain. +// +// The URL argument is the CertURL field of Order. +func (c *Client) fetchCertRFC(ctx context.Context, url string, bundle bool) ([][]byte, error) { + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // Get all the bytes up to a sane maximum. + // Account very roughly for base64 overhead. + const max = maxCertChainSize + maxCertChainSize/33 + b, err := io.ReadAll(io.LimitReader(res.Body, max+1)) + if err != nil { + return nil, fmt.Errorf("acme: fetch cert response stream: %v", err) + } + if len(b) > max { + return nil, errors.New("acme: certificate chain is too big") + } + + // Decode PEM chain. + var chain [][]byte + for { + var p *pem.Block + p, b = pem.Decode(b) + if p == nil { + break + } + if p.Type != "CERTIFICATE" { + return nil, fmt.Errorf("acme: invalid PEM cert type %q", p.Type) + } + + chain = append(chain, p.Bytes) + if !bundle { + return chain, nil + } + if len(chain) > maxChainLen { + return nil, errors.New("acme: certificate chain is too long") + } + } + if len(chain) == 0 { + return nil, errors.New("acme: certificate chain is empty") + } + return chain, nil +} + +// sends a cert revocation request in either JWK form when key is non-nil or KID form otherwise. +func (c *Client) revokeCertRFC(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error { + req := &struct { + Cert string `json:"certificate"` + Reason int `json:"reason"` + }{ + Cert: base64.RawURLEncoding.EncodeToString(cert), + Reason: int(reason), + } + res, err := c.post(ctx, key, c.dir.RevokeURL, req, wantStatus(http.StatusOK)) + if err != nil { + if isAlreadyRevoked(err) { + // Assume it is not an error to revoke an already revoked cert. + return nil + } + return err + } + defer res.Body.Close() + return nil +} + +func isAlreadyRevoked(err error) bool { + e, ok := err.(*Error) + return ok && e.ProblemType == "urn:ietf:params:acme:error:alreadyRevoked" +} + +// ListCertAlternates retrieves any alternate certificate chain URLs for the +// given certificate chain URL. These alternate URLs can be passed to FetchCert +// in order to retrieve the alternate certificate chains. +// +// If there are no alternate issuer certificate chains, a nil slice will be +// returned. +func (c *Client) ListCertAlternates(ctx context.Context, url string) ([]string, error) { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return nil, err + } + + res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + + // We don't need the body but we need to discard it so we don't end up + // preventing keep-alive + if _, err := io.Copy(io.Discard, res.Body); err != nil { + return nil, fmt.Errorf("acme: cert alternates response stream: %v", err) + } + alts := linkHeader(res.Header, "alternate") + return alts, nil +} diff --git a/tempfork/acme/rfc8555_test.go b/tempfork/acme/rfc8555_test.go new file mode 100644 index 000000000..ec51a7a5e --- /dev/null +++ b/tempfork/acme/rfc8555_test.go @@ -0,0 +1,1024 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "sync" + "testing" + "time" +) + +// While contents of this file is pertinent only to RFC8555, +// it is complementary to the tests in the other _test.go files +// many of which are valid for both pre- and RFC8555. +// This will make it easier to clean up the tests once non-RFC compliant +// code is removed. + +func TestRFC_Discover(t *testing.T) { + const ( + nonce = "https://example.com/acme/new-nonce" + reg = "https://example.com/acme/new-acct" + order = "https://example.com/acme/new-order" + authz = "https://example.com/acme/new-authz" + revoke = "https://example.com/acme/revoke-cert" + keychange = "https://example.com/acme/key-change" + metaTerms = "https://example.com/acme/terms/2017-5-30" + metaWebsite = "https://www.example.com/" + metaCAA = "example.com" + ) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{ + "newNonce": %q, + "newAccount": %q, + "newOrder": %q, + "newAuthz": %q, + "revokeCert": %q, + "keyChange": %q, + "meta": { + "termsOfService": %q, + "website": %q, + "caaIdentities": [%q], + "externalAccountRequired": true + } + }`, nonce, reg, order, authz, revoke, keychange, metaTerms, metaWebsite, metaCAA) + })) + defer ts.Close() + c := &Client{DirectoryURL: ts.URL} + dir, err := c.Discover(context.Background()) + if err != nil { + t.Fatal(err) + } + if dir.NonceURL != nonce { + t.Errorf("dir.NonceURL = %q; want %q", dir.NonceURL, nonce) + } + if dir.RegURL != reg { + t.Errorf("dir.RegURL = %q; want %q", dir.RegURL, reg) + } + if dir.OrderURL != order { + t.Errorf("dir.OrderURL = %q; want %q", dir.OrderURL, order) + } + if dir.AuthzURL != authz { + t.Errorf("dir.AuthzURL = %q; want %q", dir.AuthzURL, authz) + } + if dir.RevokeURL != revoke { + t.Errorf("dir.RevokeURL = %q; want %q", dir.RevokeURL, revoke) + } + if dir.KeyChangeURL != keychange { + t.Errorf("dir.KeyChangeURL = %q; want %q", dir.KeyChangeURL, keychange) + } + if dir.Terms != metaTerms { + t.Errorf("dir.Terms = %q; want %q", dir.Terms, metaTerms) + } + if dir.Website != metaWebsite { + t.Errorf("dir.Website = %q; want %q", dir.Website, metaWebsite) + } + if len(dir.CAA) == 0 || dir.CAA[0] != metaCAA { + t.Errorf("dir.CAA = %q; want [%q]", dir.CAA, metaCAA) + } + if !dir.ExternalAccountRequired { + t.Error("dir.Meta.ExternalAccountRequired is false") + } +} + +func TestRFC_popNonce(t *testing.T) { + var count int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The Client uses only Directory.NonceURL when specified. + // Expect no other URL paths. + if r.URL.Path != "/new-nonce" { + t.Errorf("r.URL.Path = %q; want /new-nonce", r.URL.Path) + } + if count > 0 { + w.WriteHeader(http.StatusTooManyRequests) + return + } + count++ + w.Header().Set("Replay-Nonce", "second") + })) + cl := &Client{ + DirectoryURL: ts.URL, + dir: &Directory{NonceURL: ts.URL + "/new-nonce"}, + } + cl.addNonce(http.Header{"Replay-Nonce": {"first"}}) + + for i, nonce := range []string{"first", "second"} { + v, err := cl.popNonce(context.Background(), "") + if err != nil { + t.Errorf("%d: cl.popNonce: %v", i, err) + } + if v != nonce { + t.Errorf("%d: cl.popNonce = %q; want %q", i, v, nonce) + } + } + // No more nonces and server replies with an error past first nonce fetch. + // Expected to fail. + if _, err := cl.popNonce(context.Background(), ""); err == nil { + t.Error("last cl.popNonce returned nil error") + } +} + +func TestRFC_postKID(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/new-nonce": + w.Header().Set("Replay-Nonce", "nonce") + case "/new-account": + w.Header().Set("Location", "/account-1") + w.Write([]byte(`{"status":"valid"}`)) + case "/post": + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if head.KID != "/account-1" { + t.Errorf("head.KID = %q; want /account-1", head.KID) + } + if len(head.JWK) != 0 { + t.Errorf("head.JWK = %q; want zero map", head.JWK) + } + if v := ts.URL + "/post"; head.URL != v { + t.Errorf("head.URL = %q; want %q", head.URL, v) + } + + var payload struct{ Msg string } + decodeJWSRequest(t, &payload, bytes.NewReader(b)) + if payload.Msg != "ping" { + t.Errorf("payload.Msg = %q; want ping", payload.Msg) + } + w.Write([]byte("pong")) + default: + t.Errorf("unhandled %s %s", r.Method, r.URL) + w.WriteHeader(http.StatusBadRequest) + } + })) + defer ts.Close() + + ctx := context.Background() + cl := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{ + NonceURL: ts.URL + "/new-nonce", + RegURL: ts.URL + "/new-account", + OrderURL: "/force-rfc-mode", + }, + } + req := json.RawMessage(`{"msg":"ping"}`) + res, err := cl.post(ctx, nil /* use kid */, ts.URL+"/post", req, wantStatus(http.StatusOK)) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + b, _ := io.ReadAll(res.Body) // don't care about err - just checking b + if string(b) != "pong" { + t.Errorf("res.Body = %q; want pong", b) + } +} + +// acmeServer simulates a subset of RFC 8555 compliant CA. +// +// TODO: We also have x/crypto/acme/autocert/acmetest and startACMEServerStub in autocert_test.go. +// It feels like this acmeServer is a sweet spot between usefulness and added complexity. +// Also, acmetest and startACMEServerStub were both written for draft-02, no RFC support. +// The goal is to consolidate all into one ACME test server. +type acmeServer struct { + ts *httptest.Server + handler map[string]http.HandlerFunc // keyed by r.URL.Path + + mu sync.Mutex + nnonce int +} + +func newACMEServer() *acmeServer { + return &acmeServer{handler: make(map[string]http.HandlerFunc)} +} + +func (s *acmeServer) handle(path string, f func(http.ResponseWriter, *http.Request)) { + s.handler[path] = http.HandlerFunc(f) +} + +func (s *acmeServer) start() { + s.ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + // Directory request. + if r.URL.Path == "/" { + fmt.Fprintf(w, `{ + "newNonce": %q, + "newAccount": %q, + "newOrder": %q, + "newAuthz": %q, + "revokeCert": %q, + "keyChange": %q, + "meta": {"termsOfService": %q} + }`, + s.url("/acme/new-nonce"), + s.url("/acme/new-account"), + s.url("/acme/new-order"), + s.url("/acme/new-authz"), + s.url("/acme/revoke-cert"), + s.url("/acme/key-change"), + s.url("/terms"), + ) + return + } + + // All other responses contain a nonce value unconditionally. + w.Header().Set("Replay-Nonce", s.nonce()) + if r.URL.Path == "/acme/new-nonce" { + return + } + + h := s.handler[r.URL.Path] + if h == nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Unhandled %s", r.URL.Path) + return + } + h.ServeHTTP(w, r) + })) +} + +func (s *acmeServer) close() { + s.ts.Close() +} + +func (s *acmeServer) url(path string) string { + return s.ts.URL + path +} + +func (s *acmeServer) nonce() string { + s.mu.Lock() + defer s.mu.Unlock() + s.nnonce++ + return fmt.Sprintf("nonce%d", s.nnonce) +} + +func (s *acmeServer) error(w http.ResponseWriter, e *wireError) { + w.WriteHeader(e.Status) + json.NewEncoder(w).Encode(e) +} + +func TestRFC_Register(t *testing.T) { + const email = "mailto:user@example.org" + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusCreated) // 201 means new account created + fmt.Fprintf(w, `{ + "status": "valid", + "contact": [%q], + "orders": %q + }`, email, s.url("/accounts/1/orders")) + + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) == 0 { + t.Error("head.JWK is empty") + } + + var req struct{ Contact []string } + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if len(req.Contact) != 1 || req.Contact[0] != email { + t.Errorf("req.Contact = %q; want [%q]", req.Contact, email) + } + }) + s.start() + defer s.close() + + ctx := context.Background() + cl := &Client{ + Key: testKeyEC, + DirectoryURL: s.url("/"), + } + + var didPrompt bool + a := &Account{Contact: []string{email}} + acct, err := cl.Register(ctx, a, func(tos string) bool { + didPrompt = true + terms := s.url("/terms") + if tos != terms { + t.Errorf("tos = %q; want %q", tos, terms) + } + return true + }) + if err != nil { + t.Fatal(err) + } + okAccount := &Account{ + URI: s.url("/accounts/1"), + Status: StatusValid, + Contact: []string{email}, + OrdersURL: s.url("/accounts/1/orders"), + } + if !reflect.DeepEqual(acct, okAccount) { + t.Errorf("acct = %+v; want %+v", acct, okAccount) + } + if !didPrompt { + t.Error("tos prompt wasn't called") + } + if v := cl.accountKID(ctx); v != KeyID(okAccount.URI) { + t.Errorf("account kid = %q; want %q", v, okAccount.URI) + } +} + +func TestRFC_RegisterExternalAccountBinding(t *testing.T) { + eab := &ExternalAccountBinding{ + KID: "kid-1", + Key: []byte("secret"), + } + + type protected struct { + Algorithm string `json:"alg"` + KID string `json:"kid"` + URL string `json:"url"` + } + const email = "mailto:user@example.org" + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + if r.Method != "POST" { + t.Errorf("r.Method = %q; want POST", r.Method) + } + + var j struct { + Protected string + Contact []string + TermsOfServiceAgreed bool + ExternalaccountBinding struct { + Protected string + Payload string + Signature string + } + } + decodeJWSRequest(t, &j, r.Body) + protData, err := base64.RawURLEncoding.DecodeString(j.ExternalaccountBinding.Protected) + if err != nil { + t.Fatal(err) + } + + var prot protected + err = json.Unmarshal(protData, &prot) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(j.Contact, []string{email}) { + t.Errorf("j.Contact = %v; want %v", j.Contact, []string{email}) + } + if !j.TermsOfServiceAgreed { + t.Error("j.TermsOfServiceAgreed = false; want true") + } + + // Ensure same KID. + if prot.KID != eab.KID { + t.Errorf("j.ExternalAccountBinding.KID = %s; want %s", prot.KID, eab.KID) + } + // Ensure expected Algorithm. + if prot.Algorithm != "HS256" { + t.Errorf("j.ExternalAccountBinding.Alg = %s; want %s", + prot.Algorithm, "HS256") + } + + // Ensure same URL as outer JWS. + url := fmt.Sprintf("http://%s/acme/new-account", r.Host) + if prot.URL != url { + t.Errorf("j.ExternalAccountBinding.URL = %s; want %s", + prot.URL, url) + } + + // Ensure payload is base64URL encoded string of JWK in outer JWS + jwk, err := jwkEncode(testKeyEC.Public()) + if err != nil { + t.Fatal(err) + } + decodedPayload, err := base64.RawURLEncoding.DecodeString(j.ExternalaccountBinding.Payload) + if err != nil { + t.Fatal(err) + } + if jwk != string(decodedPayload) { + t.Errorf("j.ExternalAccountBinding.Payload = %s; want %s", decodedPayload, jwk) + } + + // Check signature on inner external account binding JWS + hmac := hmac.New(sha256.New, []byte("secret")) + _, err = hmac.Write([]byte(j.ExternalaccountBinding.Protected + "." + j.ExternalaccountBinding.Payload)) + if err != nil { + t.Fatal(err) + } + mac := hmac.Sum(nil) + encodedMAC := base64.RawURLEncoding.EncodeToString(mac) + + if !bytes.Equal([]byte(encodedMAC), []byte(j.ExternalaccountBinding.Signature)) { + t.Errorf("j.ExternalAccountBinding.Signature = %v; want %v", + []byte(j.ExternalaccountBinding.Signature), encodedMAC) + } + + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusCreated) + b, _ := json.Marshal([]string{email}) + fmt.Fprintf(w, `{"status":"valid","orders":"%s","contact":%s}`, s.url("/accounts/1/orders"), b) + }) + s.start() + defer s.close() + + ctx := context.Background() + cl := &Client{ + Key: testKeyEC, + DirectoryURL: s.url("/"), + } + + var didPrompt bool + a := &Account{Contact: []string{email}, ExternalAccountBinding: eab} + acct, err := cl.Register(ctx, a, func(tos string) bool { + didPrompt = true + terms := s.url("/terms") + if tos != terms { + t.Errorf("tos = %q; want %q", tos, terms) + } + return true + }) + if err != nil { + t.Fatal(err) + } + okAccount := &Account{ + URI: s.url("/accounts/1"), + Status: StatusValid, + Contact: []string{email}, + OrdersURL: s.url("/accounts/1/orders"), + } + if !reflect.DeepEqual(acct, okAccount) { + t.Errorf("acct = %+v; want %+v", acct, okAccount) + } + if !didPrompt { + t.Error("tos prompt wasn't called") + } + if v := cl.accountKID(ctx); v != KeyID(okAccount.URI) { + t.Errorf("account kid = %q; want %q", v, okAccount.URI) + } +} + +func TestRFC_RegisterExisting(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) // 200 means account already exists + w.Write([]byte(`{"status": "valid"}`)) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + _, err := cl.Register(context.Background(), &Account{}, AcceptTOS) + if err != ErrAccountAlreadyExists { + t.Errorf("err = %v; want %v", err, ErrAccountAlreadyExists) + } + kid := KeyID(s.url("/accounts/1")) + if v := cl.accountKID(context.Background()); v != kid { + t.Errorf("account kid = %q; want %q", v, kid) + } +} + +func TestRFC_UpdateReg(t *testing.T) { + const email = "mailto:user@example.org" + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + var didUpdate bool + s.handle("/accounts/1", func(w http.ResponseWriter, r *http.Request) { + didUpdate = true + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) != 0 { + t.Error("head.JWK is non-zero") + } + kid := s.url("/accounts/1") + if head.KID != kid { + t.Errorf("head.KID = %q; want %q", head.KID, kid) + } + + var req struct{ Contact []string } + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if len(req.Contact) != 1 || req.Contact[0] != email { + t.Errorf("req.Contact = %q; want [%q]", req.Contact, email) + } + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + _, err := cl.UpdateReg(context.Background(), &Account{Contact: []string{email}}) + if err != nil { + t.Error(err) + } + if !didUpdate { + t.Error("UpdateReg didn't update the account") + } +} + +func TestRFC_GetReg(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + + head, err := decodeJWSHead(r.Body) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) == 0 { + t.Error("head.JWK is empty") + } + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + acct, err := cl.GetReg(context.Background(), "") + if err != nil { + t.Fatal(err) + } + okAccount := &Account{ + URI: s.url("/accounts/1"), + Status: StatusValid, + } + if !reflect.DeepEqual(acct, okAccount) { + t.Errorf("acct = %+v; want %+v", acct, okAccount) + } +} + +func TestRFC_GetRegNoAccount(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + s.error(w, &wireError{ + Status: http.StatusBadRequest, + Type: "urn:ietf:params:acme:error:accountDoesNotExist", + }) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if _, err := cl.GetReg(context.Background(), ""); err != ErrNoAccount { + t.Errorf("err = %v; want %v", err, ErrNoAccount) + } +} + +func TestRFC_GetRegOtherError(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if _, err := cl.GetReg(context.Background(), ""); err == nil || err == ErrNoAccount { + t.Errorf("GetReg: %v; want any other non-nil err", err) + } +} + +func TestRFC_AccountKeyRollover(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + s.handle("/acme/key-change", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if err := cl.AccountKeyRollover(context.Background(), testKeyEC384); err != nil { + t.Errorf("AccountKeyRollover: %v, wanted no error", err) + } else if cl.Key != testKeyEC384 { + t.Error("AccountKeyRollover did not rotate the client key") + } +} + +func TestRFC_DeactivateReg(t *testing.T) { + const email = "mailto:user@example.org" + curStatus := StatusValid + + type account struct { + Status string `json:"status"` + Contact []string `json:"contact"` + AcceptTOS bool `json:"termsOfServiceAgreed"` + Orders string `json:"orders"` + } + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) // 200 means existing account + json.NewEncoder(w).Encode(account{ + Status: curStatus, + Contact: []string{email}, + AcceptTOS: true, + Orders: s.url("/accounts/1/orders"), + }) + + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) == 0 { + t.Error("head.JWK is empty") + } + + var req struct { + Status string `json:"status"` + Contact []string `json:"contact"` + AcceptTOS bool `json:"termsOfServiceAgreed"` + OnlyExisting bool `json:"onlyReturnExisting"` + } + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if !req.OnlyExisting { + t.Errorf("req.OnlyReturnExisting = %t; want = %t", req.OnlyExisting, true) + } + }) + s.handle("/accounts/1", func(w http.ResponseWriter, r *http.Request) { + if curStatus == StatusValid { + curStatus = StatusDeactivated + w.WriteHeader(http.StatusOK) + } else { + s.error(w, &wireError{ + Status: http.StatusUnauthorized, + Type: "urn:ietf:params:acme:error:unauthorized", + }) + } + var req account + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) != 0 { + t.Error("head.JWK is not empty") + } + if !strings.HasSuffix(head.KID, "/accounts/1") { + t.Errorf("head.KID = %q; want suffix /accounts/1", head.KID) + } + + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if req.Status != StatusDeactivated { + t.Errorf("req.Status = %q; want = %q", req.Status, StatusDeactivated) + } + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if err := cl.DeactivateReg(context.Background()); err != nil { + t.Errorf("DeactivateReg: %v, wanted no error", err) + } + if err := cl.DeactivateReg(context.Background()); err == nil { + t.Errorf("DeactivateReg: %v, wanted error for unauthorized", err) + } +} + +func TestRF_DeactivateRegNoAccount(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + s.error(w, &wireError{ + Status: http.StatusBadRequest, + Type: "urn:ietf:params:acme:error:accountDoesNotExist", + }) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if err := cl.DeactivateReg(context.Background()); !errors.Is(err, ErrNoAccount) { + t.Errorf("DeactivateReg: %v, wanted ErrNoAccount", err) + } +} + +func TestRFC_AuthorizeOrder(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + s.handle("/acme/new-order", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{ + "status": "pending", + "expires": "2019-09-01T00:00:00Z", + "notBefore": "2019-08-31T00:00:00Z", + "notAfter": "2019-09-02T00:00:00Z", + "identifiers": [{"type":"dns", "value":"example.org"}], + "authorizations": [%q] + }`, s.url("/authz/1")) + }) + s.start() + defer s.close() + + prevCertDER, _ := pem.Decode([]byte(leafPEM)) + prevCert, err := x509.ParseCertificate(prevCertDER.Bytes) + if err != nil { + t.Fatal(err) + } + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + o, err := cl.AuthorizeOrder(context.Background(), DomainIDs("example.org"), + WithOrderNotBefore(time.Date(2019, 8, 31, 0, 0, 0, 0, time.UTC)), + WithOrderNotAfter(time.Date(2019, 9, 2, 0, 0, 0, 0, time.UTC)), + WithOrderReplacesCert(prevCert), + ) + if err != nil { + t.Fatal(err) + } + okOrder := &Order{ + URI: s.url("/orders/1"), + Status: StatusPending, + Expires: time.Date(2019, 9, 1, 0, 0, 0, 0, time.UTC), + NotBefore: time.Date(2019, 8, 31, 0, 0, 0, 0, time.UTC), + NotAfter: time.Date(2019, 9, 2, 0, 0, 0, 0, time.UTC), + Identifiers: []AuthzID{AuthzID{Type: "dns", Value: "example.org"}}, + AuthzURLs: []string{s.url("/authz/1")}, + } + if !reflect.DeepEqual(o, okOrder) { + t.Errorf("AuthorizeOrder = %+v; want %+v", o, okOrder) + } +} + +func TestRFC_GetOrder(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + s.handle("/orders/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{ + "status": "invalid", + "expires": "2019-09-01T00:00:00Z", + "notBefore": "2019-08-31T00:00:00Z", + "notAfter": "2019-09-02T00:00:00Z", + "identifiers": [{"type":"dns", "value":"example.org"}], + "authorizations": ["/authz/1"], + "finalize": "/orders/1/fin", + "certificate": "/orders/1/cert", + "error": {"type": "badRequest"} + }`)) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + o, err := cl.GetOrder(context.Background(), s.url("/orders/1")) + if err != nil { + t.Fatal(err) + } + okOrder := &Order{ + URI: s.url("/orders/1"), + Status: StatusInvalid, + Expires: time.Date(2019, 9, 1, 0, 0, 0, 0, time.UTC), + NotBefore: time.Date(2019, 8, 31, 0, 0, 0, 0, time.UTC), + NotAfter: time.Date(2019, 9, 2, 0, 0, 0, 0, time.UTC), + Identifiers: []AuthzID{AuthzID{Type: "dns", Value: "example.org"}}, + AuthzURLs: []string{"/authz/1"}, + FinalizeURL: "/orders/1/fin", + CertURL: "/orders/1/cert", + Error: &Error{ProblemType: "badRequest"}, + } + if !reflect.DeepEqual(o, okOrder) { + t.Errorf("GetOrder = %+v\nwant %+v", o, okOrder) + } +} + +func TestRFC_WaitOrder(t *testing.T) { + for _, st := range []string{StatusReady, StatusValid} { + t.Run(st, func(t *testing.T) { + testWaitOrderStatus(t, st) + }) + } +} + +func testWaitOrderStatus(t *testing.T, okStatus string) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + var count int + s.handle("/orders/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusOK) + s := StatusPending + if count > 0 { + s = okStatus + } + fmt.Fprintf(w, `{"status": %q}`, s) + count++ + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + order, err := cl.WaitOrder(context.Background(), s.url("/orders/1")) + if err != nil { + t.Fatalf("WaitOrder: %v", err) + } + if order.Status != okStatus { + t.Errorf("order.Status = %q; want %q", order.Status, okStatus) + } +} + +func TestRFC_WaitOrderError(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + var count int + s.handle("/orders/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusOK) + s := StatusPending + if count > 0 { + s = StatusInvalid + } + fmt.Fprintf(w, `{"status": %q}`, s) + count++ + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + _, err := cl.WaitOrder(context.Background(), s.url("/orders/1")) + if err == nil { + t.Fatal("WaitOrder returned nil error") + } + e, ok := err.(*OrderError) + if !ok { + t.Fatalf("err = %v (%T); want OrderError", err, err) + } + if e.OrderURL != s.url("/orders/1") { + t.Errorf("e.OrderURL = %q; want %q", e.OrderURL, s.url("/orders/1")) + } + if e.Status != StatusInvalid { + t.Errorf("e.Status = %q; want %q", e.Status, StatusInvalid) + } +} + +func TestRFC_CreateOrderCert(t *testing.T) { + q := &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: "example.org"}, + } + csr, err := x509.CreateCertificateRequest(rand.Reader, q, testKeyEC) + if err != nil { + t.Fatal(err) + } + + tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)} + leaf, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testKeyEC.PublicKey, testKeyEC) + if err != nil { + t.Fatal(err) + } + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.Write([]byte(`{"status": "valid"}`)) + }) + var count int + s.handle("/pleaseissue", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/pleaseissue")) + st := StatusProcessing + if count > 0 { + st = StatusValid + } + fmt.Fprintf(w, `{"status":%q, "certificate":%q}`, st, s.url("/crt")) + count++ + }) + s.handle("/crt", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/pem-certificate-chain") + pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: leaf}) + }) + s.start() + defer s.close() + ctx := context.Background() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + cert, curl, err := cl.CreateOrderCert(ctx, s.url("/pleaseissue"), csr, true) + if err != nil { + t.Fatalf("CreateOrderCert: %v", err) + } + if _, err := x509.ParseCertificate(cert[0]); err != nil { + t.Errorf("ParseCertificate: %v", err) + } + if !reflect.DeepEqual(cert[0], leaf) { + t.Errorf("cert and leaf bytes don't match") + } + if u := s.url("/crt"); curl != u { + t.Errorf("curl = %q; want %q", curl, u) + } +} + +func TestRFC_AlreadyRevokedCert(t *testing.T) { + s := newACMEServer() + s.handle("/acme/revoke-cert", func(w http.ResponseWriter, r *http.Request) { + s.error(w, &wireError{ + Status: http.StatusBadRequest, + Type: "urn:ietf:params:acme:error:alreadyRevoked", + }) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + err := cl.RevokeCert(context.Background(), testKeyEC, []byte{0}, CRLReasonUnspecified) + if err != nil { + t.Fatalf("RevokeCert: %v", err) + } +} + +func TestRFC_ListCertAlternates(t *testing.T) { + s := newACMEServer() + s.handle("/crt", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/pem-certificate-chain") + w.Header().Add("Link", `;rel="alternate"`) + w.Header().Add("Link", `; rel="alternate"`) + w.Header().Add("Link", `; rel="index"`) + }) + s.handle("/crt2", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/pem-certificate-chain") + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + crts, err := cl.ListCertAlternates(context.Background(), s.url("/crt")) + if err != nil { + t.Fatalf("ListCertAlternates: %v", err) + } + want := []string{"https://example.com/crt/2", "https://example.com/crt/3"} + if !reflect.DeepEqual(crts, want) { + t.Errorf("ListCertAlternates(/crt): %v; want %v", crts, want) + } + crts, err = cl.ListCertAlternates(context.Background(), s.url("/crt2")) + if err != nil { + t.Fatalf("ListCertAlternates: %v", err) + } + if crts != nil { + t.Errorf("ListCertAlternates(/crt2): %v; want nil", crts) + } +} diff --git a/tempfork/acme/sync_to_upstream_test.go b/tempfork/acme/sync_to_upstream_test.go new file mode 100644 index 000000000..e22c8c1a8 --- /dev/null +++ b/tempfork/acme/sync_to_upstream_test.go @@ -0,0 +1,70 @@ +package acme + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + _ "github.com/tailscale/golang-x-crypto/acme" // so it's on disk for the test +) + +// Verify that the files tempfork/acme/*.go (other than this test file) match the +// files in "github.com/tailscale/golang-x-crypto/acme" which is where we develop +// our fork of golang.org/x/crypto/acme and merge with upstream, but then we vendor +// just its acme package into tailscale.com/tempfork/acme. +// +// Development workflow: +// +// - make a change in github.com/tailscale/golang-x-crypto/acme +// - merge it (ideally with golang.org/x/crypto/acme too) +// - rebase github.com/tailscale/golang-x-crypto/acme with upstream x/crypto/acme +// as needed +// - in the tailscale.com repo, run "go get github.com/tailscale/golang-x-crypto/acme@main" +// - run go test ./tempfork/acme to watch it fail; the failure includes +// a shell command you should run to copy the *.go files from tailscale/golang-x-crypto +// to tailscale.com. +// - watch tests pass. git add it all. +// - send PR to tailscale.com +func TestSyncedToUpstream(t *testing.T) { + const pkg = "github.com/tailscale/golang-x-crypto/acme" + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", pkg).Output() + if err != nil { + t.Fatalf("failed to find %s's location o disk: %v", pkg, err) + } + xDir := strings.TrimSpace(string(out)) + + t.Logf("at %s", xDir) + scanDir := func(dir string) map[string]string { + m := map[string]string{} // filename => Go contents + ents, err := os.ReadDir(dir) + if err != nil { + t.Fatal(err) + } + for _, de := range ents { + name := de.Name() + if name == "sync_to_upstream_test.go" { + continue + } + if !strings.HasSuffix(name, ".go") { + continue + } + b, err := os.ReadFile(filepath.Join(dir, name)) + if err != nil { + t.Fatal(err) + } + m[name] = strings.ReplaceAll(string(b), "\r", "") + } + + return m + } + + want := scanDir(xDir) + got := scanDir(".") + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("files differ (-want +got):\n%s", diff) + t.Errorf("to fix, run from module root:\n\ncp %s/*.go ./tempfork/acme && ./tool/go mod tidy\n", xDir) + } +} diff --git a/tempfork/acme/types.go b/tempfork/acme/types.go new file mode 100644 index 000000000..0142469d8 --- /dev/null +++ b/tempfork/acme/types.go @@ -0,0 +1,667 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "crypto" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "net/http" + "strings" + "time" +) + +// ACME status values of Account, Order, Authorization and Challenge objects. +// See https://tools.ietf.org/html/rfc8555#section-7.1.6 for details. +const ( + StatusDeactivated = "deactivated" + StatusExpired = "expired" + StatusInvalid = "invalid" + StatusPending = "pending" + StatusProcessing = "processing" + StatusReady = "ready" + StatusRevoked = "revoked" + StatusUnknown = "unknown" + StatusValid = "valid" +) + +// CRLReasonCode identifies the reason for a certificate revocation. +type CRLReasonCode int + +// CRL reason codes as defined in RFC 5280. +const ( + CRLReasonUnspecified CRLReasonCode = 0 + CRLReasonKeyCompromise CRLReasonCode = 1 + CRLReasonCACompromise CRLReasonCode = 2 + CRLReasonAffiliationChanged CRLReasonCode = 3 + CRLReasonSuperseded CRLReasonCode = 4 + CRLReasonCessationOfOperation CRLReasonCode = 5 + CRLReasonCertificateHold CRLReasonCode = 6 + CRLReasonRemoveFromCRL CRLReasonCode = 8 + CRLReasonPrivilegeWithdrawn CRLReasonCode = 9 + CRLReasonAACompromise CRLReasonCode = 10 +) + +var ( + // ErrUnsupportedKey is returned when an unsupported key type is encountered. + ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported") + + // ErrAccountAlreadyExists indicates that the Client's key has already been registered + // with the CA. It is returned by Register method. + ErrAccountAlreadyExists = errors.New("acme: account already exists") + + // ErrNoAccount indicates that the Client's key has not been registered with the CA. + ErrNoAccount = errors.New("acme: account does not exist") +) + +// A Subproblem describes an ACME subproblem as reported in an Error. +type Subproblem struct { + // Type is a URI reference that identifies the problem type, + // typically in a "urn:acme:error:xxx" form. + Type string + // Detail is a human-readable explanation specific to this occurrence of the problem. + Detail string + // Instance indicates a URL that the client should direct a human user to visit + // in order for instructions on how to agree to the updated Terms of Service. + // In such an event CA sets StatusCode to 403, Type to + // "urn:ietf:params:acme:error:userActionRequired", and adds a Link header with relation + // "terms-of-service" containing the latest TOS URL. + Instance string + // Identifier may contain the ACME identifier that the error is for. + Identifier *AuthzID +} + +func (sp Subproblem) String() string { + str := fmt.Sprintf("%s: ", sp.Type) + if sp.Identifier != nil { + str += fmt.Sprintf("[%s: %s] ", sp.Identifier.Type, sp.Identifier.Value) + } + str += sp.Detail + return str +} + +// Error is an ACME error, defined in Problem Details for HTTP APIs doc +// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem. +type Error struct { + // StatusCode is The HTTP status code generated by the origin server. + StatusCode int + // ProblemType is a URI reference that identifies the problem type, + // typically in a "urn:acme:error:xxx" form. + ProblemType string + // Detail is a human-readable explanation specific to this occurrence of the problem. + Detail string + // Instance indicates a URL that the client should direct a human user to visit + // in order for instructions on how to agree to the updated Terms of Service. + // In such an event CA sets StatusCode to 403, ProblemType to + // "urn:ietf:params:acme:error:userActionRequired" and a Link header with relation + // "terms-of-service" containing the latest TOS URL. + Instance string + // Header is the original server error response headers. + // It may be nil. + Header http.Header + // Subproblems may contain more detailed information about the individual problems + // that caused the error. This field is only sent by RFC 8555 compatible ACME + // servers. Defined in RFC 8555 Section 6.7.1. + Subproblems []Subproblem +} + +func (e *Error) Error() string { + str := fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail) + if len(e.Subproblems) > 0 { + str += fmt.Sprintf("; subproblems:") + for _, sp := range e.Subproblems { + str += fmt.Sprintf("\n\t%s", sp) + } + } + return str +} + +// AuthorizationError indicates that an authorization for an identifier +// did not succeed. +// It contains all errors from Challenge items of the failed Authorization. +type AuthorizationError struct { + // URI uniquely identifies the failed Authorization. + URI string + + // Identifier is an AuthzID.Value of the failed Authorization. + Identifier string + + // Errors is a collection of non-nil error values of Challenge items + // of the failed Authorization. + Errors []error +} + +func (a *AuthorizationError) Error() string { + e := make([]string, len(a.Errors)) + for i, err := range a.Errors { + e[i] = err.Error() + } + + if a.Identifier != "" { + return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; ")) + } + + return fmt.Sprintf("acme: authorization error: %s", strings.Join(e, "; ")) +} + +// OrderError is returned from Client's order related methods. +// It indicates the order is unusable and the clients should start over with +// AuthorizeOrder. +// +// The clients can still fetch the order object from CA using GetOrder +// to inspect its state. +type OrderError struct { + OrderURL string + Status string +} + +func (oe *OrderError) Error() string { + return fmt.Sprintf("acme: order %s status: %s", oe.OrderURL, oe.Status) +} + +// RateLimit reports whether err represents a rate limit error and +// any Retry-After duration returned by the server. +// +// See the following for more details on rate limiting: +// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6 +func RateLimit(err error) (time.Duration, bool) { + e, ok := err.(*Error) + if !ok { + return 0, false + } + // Some CA implementations may return incorrect values. + // Use case-insensitive comparison. + if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") { + return 0, false + } + if e.Header == nil { + return 0, true + } + return retryAfter(e.Header.Get("Retry-After")), true +} + +// Account is a user account. It is associated with a private key. +// Non-RFC 8555 fields are empty when interfacing with a compliant CA. +type Account struct { + // URI is the account unique ID, which is also a URL used to retrieve + // account data from the CA. + // When interfacing with RFC 8555-compliant CAs, URI is the "kid" field + // value in JWS signed requests. + URI string + + // Contact is a slice of contact info used during registration. + // See https://tools.ietf.org/html/rfc8555#section-7.3 for supported + // formats. + Contact []string + + // Status indicates current account status as returned by the CA. + // Possible values are StatusValid, StatusDeactivated, and StatusRevoked. + Status string + + // OrdersURL is a URL from which a list of orders submitted by this account + // can be fetched. + OrdersURL string + + // The terms user has agreed to. + // A value not matching CurrentTerms indicates that the user hasn't agreed + // to the actual Terms of Service of the CA. + // + // It is non-RFC 8555 compliant. Package users can store the ToS they agree to + // during Client's Register call in the prompt callback function. + AgreedTerms string + + // Actual terms of a CA. + // + // It is non-RFC 8555 compliant. Use Directory's Terms field. + // When a CA updates their terms and requires an account agreement, + // a URL at which instructions to do so is available in Error's Instance field. + CurrentTerms string + + // Authz is the authorization URL used to initiate a new authz flow. + // + // It is non-RFC 8555 compliant. Use Directory's AuthzURL or OrderURL. + Authz string + + // Authorizations is a URI from which a list of authorizations + // granted to this account can be fetched via a GET request. + // + // It is non-RFC 8555 compliant and is obsoleted by OrdersURL. + Authorizations string + + // Certificates is a URI from which a list of certificates + // issued for this account can be fetched via a GET request. + // + // It is non-RFC 8555 compliant and is obsoleted by OrdersURL. + Certificates string + + // ExternalAccountBinding represents an arbitrary binding to an account of + // the CA which the ACME server is tied to. + // See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details. + ExternalAccountBinding *ExternalAccountBinding +} + +// ExternalAccountBinding contains the data needed to form a request with +// an external account binding. +// See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details. +type ExternalAccountBinding struct { + // KID is the Key ID of the symmetric MAC key that the CA provides to + // identify an external account from ACME. + KID string + + // Key is the bytes of the symmetric key that the CA provides to identify + // the account. Key must correspond to the KID. + Key []byte +} + +func (e *ExternalAccountBinding) String() string { + return fmt.Sprintf("&{KID: %q, Key: redacted}", e.KID) +} + +// Directory is ACME server discovery data. +// See https://tools.ietf.org/html/rfc8555#section-7.1.1 for more details. +type Directory struct { + // NonceURL indicates an endpoint where to fetch fresh nonce values from. + NonceURL string + + // RegURL is an account endpoint URL, allowing for creating new accounts. + // Pre-RFC 8555 CAs also allow modifying existing accounts at this URL. + RegURL string + + // OrderURL is used to initiate the certificate issuance flow + // as described in RFC 8555. + OrderURL string + + // AuthzURL is used to initiate identifier pre-authorization flow. + // Empty string indicates the flow is unsupported by the CA. + AuthzURL string + + // CertURL is a new certificate issuance endpoint URL. + // It is non-RFC 8555 compliant and is obsoleted by OrderURL. + CertURL string + + // RevokeURL is used to initiate a certificate revocation flow. + RevokeURL string + + // KeyChangeURL allows to perform account key rollover flow. + KeyChangeURL string + + // RenewalInfoURL allows to perform certificate renewal using the ACME + // Renewal Information (ARI) Extension. + RenewalInfoURL string + + // Terms is a URI identifying the current terms of service. + Terms string + + // Website is an HTTP or HTTPS URL locating a website + // providing more information about the ACME server. + Website string + + // CAA consists of lowercase hostname elements, which the ACME server + // recognises as referring to itself for the purposes of CAA record validation + // as defined in RFC 6844. + CAA []string + + // ExternalAccountRequired indicates that the CA requires for all account-related + // requests to include external account binding information. + ExternalAccountRequired bool +} + +// Order represents a client's request for a certificate. +// It tracks the request flow progress through to issuance. +type Order struct { + // URI uniquely identifies an order. + URI string + + // Status represents the current status of the order. + // It indicates which action the client should take. + // + // Possible values are StatusPending, StatusReady, StatusProcessing, StatusValid and StatusInvalid. + // Pending means the CA does not believe that the client has fulfilled the requirements. + // Ready indicates that the client has fulfilled all the requirements and can submit a CSR + // to obtain a certificate. This is done with Client's CreateOrderCert. + // Processing means the certificate is being issued. + // Valid indicates the CA has issued the certificate. It can be downloaded + // from the Order's CertURL. This is done with Client's FetchCert. + // Invalid means the certificate will not be issued. Users should consider this order + // abandoned. + Status string + + // Expires is the timestamp after which CA considers this order invalid. + Expires time.Time + + // Identifiers contains all identifier objects which the order pertains to. + Identifiers []AuthzID + + // NotBefore is the requested value of the notBefore field in the certificate. + NotBefore time.Time + + // NotAfter is the requested value of the notAfter field in the certificate. + NotAfter time.Time + + // AuthzURLs represents authorizations to complete before a certificate + // for identifiers specified in the order can be issued. + // It also contains unexpired authorizations that the client has completed + // in the past. + // + // Authorization objects can be fetched using Client's GetAuthorization method. + // + // The required authorizations are dictated by CA policies. + // There may not be a 1:1 relationship between the identifiers and required authorizations. + // Required authorizations can be identified by their StatusPending status. + // + // For orders in the StatusValid or StatusInvalid state these are the authorizations + // which were completed. + AuthzURLs []string + + // FinalizeURL is the endpoint at which a CSR is submitted to obtain a certificate + // once all the authorizations are satisfied. + FinalizeURL string + + // CertURL points to the certificate that has been issued in response to this order. + CertURL string + + // The error that occurred while processing the order as received from a CA, if any. + Error *Error +} + +// OrderOption allows customizing Client.AuthorizeOrder call. +type OrderOption interface { + privateOrderOpt() +} + +// WithOrderNotBefore sets order's NotBefore field. +func WithOrderNotBefore(t time.Time) OrderOption { + return orderNotBeforeOpt(t) +} + +// WithOrderNotAfter sets order's NotAfter field. +func WithOrderNotAfter(t time.Time) OrderOption { + return orderNotAfterOpt(t) +} + +type orderNotBeforeOpt time.Time + +func (orderNotBeforeOpt) privateOrderOpt() {} + +type orderNotAfterOpt time.Time + +func (orderNotAfterOpt) privateOrderOpt() {} + +// WithOrderReplacesCert indicates that this Order is for a replacement of an +// existing certificate. +// See https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5 +func WithOrderReplacesCert(cert *x509.Certificate) OrderOption { + return orderReplacesCert{cert} +} + +type orderReplacesCert struct { + cert *x509.Certificate +} + +func (orderReplacesCert) privateOrderOpt() {} + +// WithOrderReplacesCertDER indicates that this Order is for a replacement of +// an existing DER-encoded certificate. +// See https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5 +func WithOrderReplacesCertDER(der []byte) OrderOption { + return orderReplacesCertDER(der) +} + +type orderReplacesCertDER []byte + +func (orderReplacesCertDER) privateOrderOpt() {} + +// Authorization encodes an authorization response. +type Authorization struct { + // URI uniquely identifies a authorization. + URI string + + // Status is the current status of an authorization. + // Possible values are StatusPending, StatusValid, StatusInvalid, StatusDeactivated, + // StatusExpired and StatusRevoked. + Status string + + // Identifier is what the account is authorized to represent. + Identifier AuthzID + + // The timestamp after which the CA considers the authorization invalid. + Expires time.Time + + // Wildcard is true for authorizations of a wildcard domain name. + Wildcard bool + + // Challenges that the client needs to fulfill in order to prove possession + // of the identifier (for pending authorizations). + // For valid authorizations, the challenge that was validated. + // For invalid authorizations, the challenge that was attempted and failed. + // + // RFC 8555 compatible CAs require users to fuflfill only one of the challenges. + Challenges []*Challenge + + // A collection of sets of challenges, each of which would be sufficient + // to prove possession of the identifier. + // Clients must complete a set of challenges that covers at least one set. + // Challenges are identified by their indices in the challenges array. + // If this field is empty, the client needs to complete all challenges. + // + // This field is unused in RFC 8555. + Combinations [][]int +} + +// AuthzID is an identifier that an account is authorized to represent. +type AuthzID struct { + Type string // The type of identifier, "dns" or "ip". + Value string // The identifier itself, e.g. "example.org". +} + +// DomainIDs creates a slice of AuthzID with "dns" identifier type. +func DomainIDs(names ...string) []AuthzID { + a := make([]AuthzID, len(names)) + for i, v := range names { + a[i] = AuthzID{Type: "dns", Value: v} + } + return a +} + +// IPIDs creates a slice of AuthzID with "ip" identifier type. +// Each element of addr is textual form of an address as defined +// in RFC 1123 Section 2.1 for IPv4 and in RFC 5952 Section 4 for IPv6. +func IPIDs(addr ...string) []AuthzID { + a := make([]AuthzID, len(addr)) + for i, v := range addr { + a[i] = AuthzID{Type: "ip", Value: v} + } + return a +} + +// wireAuthzID is ACME JSON representation of authorization identifier objects. +type wireAuthzID struct { + Type string `json:"type"` + Value string `json:"value"` +} + +// wireAuthz is ACME JSON representation of Authorization objects. +type wireAuthz struct { + Identifier wireAuthzID + Status string + Expires time.Time + Wildcard bool + Challenges []wireChallenge + Combinations [][]int + Error *wireError +} + +func (z *wireAuthz) authorization(uri string) *Authorization { + a := &Authorization{ + URI: uri, + Status: z.Status, + Identifier: AuthzID{Type: z.Identifier.Type, Value: z.Identifier.Value}, + Expires: z.Expires, + Wildcard: z.Wildcard, + Challenges: make([]*Challenge, len(z.Challenges)), + Combinations: z.Combinations, // shallow copy + } + for i, v := range z.Challenges { + a.Challenges[i] = v.challenge() + } + return a +} + +func (z *wireAuthz) error(uri string) *AuthorizationError { + err := &AuthorizationError{ + URI: uri, + Identifier: z.Identifier.Value, + } + + if z.Error != nil { + err.Errors = append(err.Errors, z.Error.error(nil)) + } + + for _, raw := range z.Challenges { + if raw.Error != nil { + err.Errors = append(err.Errors, raw.Error.error(nil)) + } + } + + return err +} + +// Challenge encodes a returned CA challenge. +// Its Error field may be non-nil if the challenge is part of an Authorization +// with StatusInvalid. +type Challenge struct { + // Type is the challenge type, e.g. "http-01", "tls-alpn-01", "dns-01". + Type string + + // URI is where a challenge response can be posted to. + URI string + + // Token is a random value that uniquely identifies the challenge. + Token string + + // Status identifies the status of this challenge. + // In RFC 8555, possible values are StatusPending, StatusProcessing, StatusValid, + // and StatusInvalid. + Status string + + // Validated is the time at which the CA validated this challenge. + // Always zero value in pre-RFC 8555. + Validated time.Time + + // Error indicates the reason for an authorization failure + // when this challenge was used. + // The type of a non-nil value is *Error. + Error error + + // Payload is the JSON-formatted payload that the client sends + // to the server to indicate it is ready to respond to the challenge. + // When unset, it defaults to an empty JSON object: {}. + // For most challenges, the client must not set Payload, + // see https://tools.ietf.org/html/rfc8555#section-7.5.1. + // Payload is used only for newer challenges (such as "device-attest-01") + // where the client must send additional data for the server to validate + // the challenge. + Payload json.RawMessage +} + +// wireChallenge is ACME JSON challenge representation. +type wireChallenge struct { + URL string `json:"url"` // RFC + URI string `json:"uri"` // pre-RFC + Type string + Token string + Status string + Validated time.Time + Error *wireError +} + +func (c *wireChallenge) challenge() *Challenge { + v := &Challenge{ + URI: c.URL, + Type: c.Type, + Token: c.Token, + Status: c.Status, + } + if v.URI == "" { + v.URI = c.URI // c.URL was empty; use legacy + } + if v.Status == "" { + v.Status = StatusPending + } + if c.Error != nil { + v.Error = c.Error.error(nil) + } + return v +} + +// wireError is a subset of fields of the Problem Details object +// as described in https://tools.ietf.org/html/rfc7807#section-3.1. +type wireError struct { + Status int + Type string + Detail string + Instance string + Subproblems []Subproblem +} + +func (e *wireError) error(h http.Header) *Error { + err := &Error{ + StatusCode: e.Status, + ProblemType: e.Type, + Detail: e.Detail, + Instance: e.Instance, + Header: h, + Subproblems: e.Subproblems, + } + return err +} + +// CertOption is an optional argument type for the TLS ChallengeCert methods for +// customizing a temporary certificate for TLS-based challenges. +type CertOption interface { + privateCertOpt() +} + +// WithKey creates an option holding a private/public key pair. +// The private part signs a certificate, and the public part represents the signee. +func WithKey(key crypto.Signer) CertOption { + return &certOptKey{key} +} + +type certOptKey struct { + key crypto.Signer +} + +func (*certOptKey) privateCertOpt() {} + +// WithTemplate creates an option for specifying a certificate template. +// See x509.CreateCertificate for template usage details. +// +// In TLS ChallengeCert methods, the template is also used as parent, +// resulting in a self-signed certificate. +// The DNSNames field of t is always overwritten for tls-sni challenge certs. +func WithTemplate(t *x509.Certificate) CertOption { + return (*certOptTemplate)(t) +} + +type certOptTemplate x509.Certificate + +func (*certOptTemplate) privateCertOpt() {} + +// RenewalInfoWindow describes the time frame during which the ACME client +// should attempt to renew, using the ACME Renewal Info Extension. +type RenewalInfoWindow struct { + Start time.Time `json:"start"` + End time.Time `json:"end"` +} + +// RenewalInfo describes the suggested renewal window for a given certificate, +// returned from an ACME server, using the ACME Renewal Info Extension. +type RenewalInfo struct { + SuggestedWindow RenewalInfoWindow `json:"suggestedWindow"` + ExplanationURL string `json:"explanationURL"` +} diff --git a/tempfork/acme/types_test.go b/tempfork/acme/types_test.go new file mode 100644 index 000000000..59ce7e760 --- /dev/null +++ b/tempfork/acme/types_test.go @@ -0,0 +1,219 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "errors" + "net/http" + "reflect" + "testing" + "time" +) + +func TestExternalAccountBindingString(t *testing.T) { + eab := ExternalAccountBinding{ + KID: "kid", + Key: []byte("key"), + } + got := eab.String() + want := `&{KID: "kid", Key: redacted}` + if got != want { + t.Errorf("eab.String() = %q, want: %q", got, want) + } +} + +func TestRateLimit(t *testing.T) { + now := time.Date(2017, 04, 27, 10, 0, 0, 0, time.UTC) + f := timeNow + defer func() { timeNow = f }() + timeNow = func() time.Time { return now } + + h120, hTime := http.Header{}, http.Header{} + h120.Set("Retry-After", "120") + hTime.Set("Retry-After", "Tue Apr 27 11:00:00 2017") + + err1 := &Error{ + ProblemType: "urn:ietf:params:acme:error:nolimit", + Header: h120, + } + err2 := &Error{ + ProblemType: "urn:ietf:params:acme:error:rateLimited", + Header: h120, + } + err3 := &Error{ + ProblemType: "urn:ietf:params:acme:error:rateLimited", + Header: nil, + } + err4 := &Error{ + ProblemType: "urn:ietf:params:acme:error:rateLimited", + Header: hTime, + } + + tt := []struct { + err error + res time.Duration + ok bool + }{ + {nil, 0, false}, + {errors.New("dummy"), 0, false}, + {err1, 0, false}, + {err2, 2 * time.Minute, true}, + {err3, 0, true}, + {err4, time.Hour, true}, + } + for i, test := range tt { + res, ok := RateLimit(test.err) + if ok != test.ok { + t.Errorf("%d: RateLimit(%+v): ok = %v; want %v", i, test.err, ok, test.ok) + continue + } + if res != test.res { + t.Errorf("%d: RateLimit(%+v) = %v; want %v", i, test.err, res, test.res) + } + } +} + +func TestAuthorizationError(t *testing.T) { + tests := []struct { + desc string + err *AuthorizationError + msg string + }{ + { + desc: "when auth error identifier is set", + err: &AuthorizationError{ + Identifier: "domain.com", + Errors: []error{ + (&wireError{ + Status: 403, + Type: "urn:ietf:params:acme:error:caa", + Detail: "CAA record for domain.com prevents issuance", + }).error(nil), + }, + }, + msg: "acme: authorization error for domain.com: 403 urn:ietf:params:acme:error:caa: CAA record for domain.com prevents issuance", + }, + + { + desc: "when auth error identifier is unset", + err: &AuthorizationError{ + Errors: []error{ + (&wireError{ + Status: 403, + Type: "urn:ietf:params:acme:error:caa", + Detail: "CAA record for domain.com prevents issuance", + }).error(nil), + }, + }, + msg: "acme: authorization error: 403 urn:ietf:params:acme:error:caa: CAA record for domain.com prevents issuance", + }, + } + + for _, tt := range tests { + if tt.err.Error() != tt.msg { + t.Errorf("got: %s\nwant: %s", tt.err, tt.msg) + } + } +} + +func TestSubproblems(t *testing.T) { + tests := []struct { + wire wireError + expectedOut Error + }{ + { + wire: wireError{ + Status: 1, + Type: "urn:error", + Detail: "it's an error", + }, + expectedOut: Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + }, + }, + { + wire: wireError{ + Status: 1, + Type: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + }, + }, + }, + expectedOut: Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + }, + }, + }, + }, + { + wire: wireError{ + Status: 1, + Type: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + Identifier: &AuthzID{Type: "dns", Value: "example"}, + }, + }, + }, + expectedOut: Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + Identifier: &AuthzID{Type: "dns", Value: "example"}, + }, + }, + }, + }, + } + + for _, tc := range tests { + out := tc.wire.error(nil) + if !reflect.DeepEqual(*out, tc.expectedOut) { + t.Errorf("Unexpected error: wanted %v, got %v", tc.expectedOut, *out) + } + } +} + +func TestErrorStringerWithSubproblems(t *testing.T) { + err := Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + }, + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + Identifier: &AuthzID{Type: "dns", Value: "example"}, + }, + }, + } + expectedStr := "1 urn:error: it's an error; subproblems:\n\turn:error:sub: it's a subproblem\n\turn:error:sub: [dns: example] it's a subproblem" + if err.Error() != expectedStr { + t.Errorf("Unexpected error string: wanted %q, got %q", expectedStr, err.Error()) + } +} diff --git a/tempfork/acme/version_go112.go b/tempfork/acme/version_go112.go new file mode 100644 index 000000000..cc5fab604 --- /dev/null +++ b/tempfork/acme/version_go112.go @@ -0,0 +1,27 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.12 + +package acme + +import "runtime/debug" + +func init() { + // Set packageVersion if the binary was built in modules mode and x/crypto + // was not replaced with a different module. + info, ok := debug.ReadBuildInfo() + if !ok { + return + } + for _, m := range info.Deps { + if m.Path != "golang.org/x/crypto" { + continue + } + if m.Replace == nil { + packageVersion = m.Version + } + break + } +} diff --git a/tempfork/gliderlabs/ssh/agent.go b/tempfork/gliderlabs/ssh/agent.go index 86a5bce7f..99e84c1e5 100644 --- a/tempfork/gliderlabs/ssh/agent.go +++ b/tempfork/gliderlabs/ssh/agent.go @@ -7,7 +7,7 @@ import ( "path" "sync" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) const ( diff --git a/tempfork/gliderlabs/ssh/context.go b/tempfork/gliderlabs/ssh/context.go index d43de6f09..505a43dbf 100644 --- a/tempfork/gliderlabs/ssh/context.go +++ b/tempfork/gliderlabs/ssh/context.go @@ -6,7 +6,7 @@ import ( "net" "sync" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // contextKey is a value for use with context.WithValue. It's used as @@ -55,8 +55,6 @@ var ( // ContextKeyPublicKey is a context key for use with Contexts in this package. // The associated value will be of type PublicKey. ContextKeyPublicKey = &contextKey{"public-key"} - - ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} ) // Context is a package specific context interface. It exposes connection @@ -91,8 +89,6 @@ type Context interface { // SetValue allows you to easily write new values into the underlying context. SetValue(key, value interface{}) - - SendAuthBanner(banner string) error } type sshContext struct { @@ -121,7 +117,6 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { ctx.SetValue(ContextKeyUser, conn.User()) ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) - ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) } func (ctx *sshContext) SetValue(key, value interface{}) { @@ -158,7 +153,3 @@ func (ctx *sshContext) LocalAddr() net.Addr { func (ctx *sshContext) Permissions() *Permissions { return ctx.Value(ContextKeyPermissions).(*Permissions) } - -func (ctx *sshContext) SendAuthBanner(msg string) error { - return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) -} diff --git a/tempfork/gliderlabs/ssh/options.go b/tempfork/gliderlabs/ssh/options.go index aa87a4f39..29c8ef141 100644 --- a/tempfork/gliderlabs/ssh/options.go +++ b/tempfork/gliderlabs/ssh/options.go @@ -3,7 +3,7 @@ package ssh import ( "os" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // PasswordAuth returns a functional option that sets PasswordHandler on the server. diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go index 7cf6f376c..47342b0f6 100644 --- a/tempfork/gliderlabs/ssh/options_test.go +++ b/tempfork/gliderlabs/ssh/options_test.go @@ -8,7 +8,7 @@ import ( "sync/atomic" "testing" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index 1086a72ca..473e5fbd6 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -8,7 +8,7 @@ import ( "sync" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // ErrServerClosed is returned by the Server's Serve, ListenAndServe, diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go index 0a4a21e53..a7a9a3eeb 100644 --- a/tempfork/gliderlabs/ssh/session.go +++ b/tempfork/gliderlabs/ssh/session.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/anmitsu/go-shlex" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // Session provides access to information about an SSH session and methods diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go index a60be5ec1..fe61a9d96 100644 --- a/tempfork/gliderlabs/ssh/session_test.go +++ b/tempfork/gliderlabs/ssh/session_test.go @@ -9,7 +9,7 @@ import ( "net" "testing" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) func (srv *Server) serveOnce(l net.Listener) error { diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 644cb257d..54bd31ec2 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -4,7 +4,7 @@ import ( "crypto/subtle" "net" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) type Signal string @@ -105,7 +105,7 @@ type Pty struct { // requested by the client as part of the pty-req. These are outlined as // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. // - // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). + // The opcodes are defined as constants in golang.org/x/crypto/ssh (VINTR,VQUIT,etc.). // Boolean opcodes have values 0 or 1. Modes gossh.TerminalModes } diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index 056a0c734..335fda657 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -7,7 +7,7 @@ import ( "strconv" "sync" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) const ( diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go index 118b5d53a..b3ba60a9b 100644 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ b/tempfork/gliderlabs/ssh/tcpip_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) var sampleServerResponse = []byte("Hello world") diff --git a/tempfork/gliderlabs/ssh/util.go b/tempfork/gliderlabs/ssh/util.go index e3b5716a3..3bee06dcd 100644 --- a/tempfork/gliderlabs/ssh/util.go +++ b/tempfork/gliderlabs/ssh/util.go @@ -5,7 +5,7 @@ import ( "crypto/rsa" "encoding/binary" - "github.com/tailscale/golang-x-crypto/ssh" + "golang.org/x/crypto/ssh" ) func generateSigner() (ssh.Signer, error) { diff --git a/tempfork/gliderlabs/ssh/wrap.go b/tempfork/gliderlabs/ssh/wrap.go index 17867d751..d1f2b161e 100644 --- a/tempfork/gliderlabs/ssh/wrap.go +++ b/tempfork/gliderlabs/ssh/wrap.go @@ -1,6 +1,6 @@ package ssh -import gossh "github.com/tailscale/golang-x-crypto/ssh" +import gossh "golang.org/x/crypto/ssh" // PublicKey is an abstraction of different types of public keys. type PublicKey interface { diff --git a/tempfork/httprec/httprec.go b/tempfork/httprec/httprec.go new file mode 100644 index 000000000..07ca673fe --- /dev/null +++ b/tempfork/httprec/httprec.go @@ -0,0 +1,220 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httprec is a copy of the Go standard library's httptest.ResponseRecorder +// type, which we want to use in non-test code without pulling in the rest of +// the httptest package and its test certs, etc. +package httprec + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" +) + +// ResponseRecorder is an implementation of [http.ResponseWriter] that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + // Code is the HTTP response code set by WriteHeader. + // + // Note that if a Handler never calls WriteHeader or Write, + // this might end up being 0, rather than the implicit + // http.StatusOK. To get the implicit value, use the Result + // method. + Code int + + // HeaderMap contains the headers explicitly set by the Handler. + // It is an internal detail. + // + // Deprecated: HeaderMap exists for historical compatibility + // and should not be used. To access the headers returned by a handler, + // use the Response.Header map as returned by the Result method. + HeaderMap http.Header + + // Body is the buffer to which the Handler's Write calls are sent. + // If nil, the Writes are silently discarded. + Body *bytes.Buffer + + // Flushed is whether the Handler called Flush. + Flushed bool + + result *http.Response // cache of Result's return value + snapHeader http.Header // snapshot of HeaderMap at first Write + wroteHeader bool +} + +// NewRecorder returns an initialized [ResponseRecorder]. +func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: 200, + } +} + +// Header implements [http.ResponseWriter]. It returns the response +// headers to mutate within a handler. To test the headers that were +// written after a handler completes, use the [ResponseRecorder.Result] method and see +// the returned Response value's Header. +func (rw *ResponseRecorder) Header() http.Header { + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m +} + +// writeHeader writes a header if it was not written yet and +// detects Content-Type if needed. +// +// bytes or str are the beginning of the response body. +// We pass both to avoid unnecessarily generate garbage +// in rw.WriteString which was created for performance reasons. +// Non-nil bytes win. +func (rw *ResponseRecorder) writeHeader(b []byte, str string) { + if rw.wroteHeader { + return + } + if len(str) > 512 { + str = str[:512] + } + + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" + if !hasType && !hasTE { + if b == nil { + b = []byte(str) + } + m.Set("Content-Type", http.DetectContentType(b)) + } + + rw.WriteHeader(200) +} + +// Write implements http.ResponseWriter. The data in buf is written to +// rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + rw.writeHeader(buf, "") + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteString implements [io.StringWriter]. The data in str is written +// to rw.Body, if not nil. +func (rw *ResponseRecorder) WriteString(str string) (int, error) { + rw.writeHeader(nil, str) + if rw.Body != nil { + rw.Body.WriteString(str) + } + return len(str), nil +} + +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at https://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +// WriteHeader implements [http.ResponseWriter]. +func (rw *ResponseRecorder) WriteHeader(code int) { + if rw.wroteHeader { + return + } + + checkWriteHeaderCode(code) + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + rw.snapHeader = rw.HeaderMap.Clone() +} + +// Flush implements [http.Flusher]. To test whether Flush was +// called, see rw.Flushed. +func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + rw.Flushed = true +} + +// Result returns the response generated by the handler. +// +// The returned Response will have at least its StatusCode, +// Header, Body, and optionally Trailer populated. +// More fields may be populated in the future, so callers should +// not DeepEqual the result in tests. +// +// The Response.Header is a snapshot of the headers at the time of the +// first write call, or at the time of this call, if the handler never +// did a write. +// +// The Response.Body is guaranteed to be non-nil and Body.Read call is +// guaranteed to not return any error other than [io.EOF]. +// +// Result must only be called after the handler has finished running. +func (rw *ResponseRecorder) Result() *http.Response { + if rw.result != nil { + return rw.result + } + if rw.snapHeader == nil { + rw.snapHeader = rw.HeaderMap.Clone() + } + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rw.Code, + Header: rw.snapHeader, + } + rw.result = res + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) + if rw.Body != nil { + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) + } else { + res.Body = http.NoBody + } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) + return res +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. +func parseContentLength(cl string) int64 { + cl = textproto.TrimString(cl) + if cl == "" { + return -1 + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return -1 + } + return int64(n) +} diff --git a/tempfork/sshtest/README.md b/tempfork/sshtest/README.md new file mode 100644 index 000000000..30c74f525 --- /dev/null +++ b/tempfork/sshtest/README.md @@ -0,0 +1,9 @@ +# sshtest + +This contains packages that are forked & locally hacked up for use +in tests. + +Notably, `golang.org/x/crypto/ssh` was copied to +`tailscale.com/tempfork/sshtest/ssh` to permit adding behaviors specific +to testing (for testing Tailscale SSH) that aren't necessarily desirable +to have upstream. diff --git a/tempfork/sshtest/ssh/benchmark_test.go b/tempfork/sshtest/ssh/benchmark_test.go new file mode 100644 index 000000000..b356330b4 --- /dev/null +++ b/tempfork/sshtest/ssh/benchmark_test.go @@ -0,0 +1,127 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + panic(fmt.Sprintf("Client: %v", err)) + } + ch, incoming, err := newCh.Accept() + if err != nil { + panic(fmt.Sprintf("Accept: %v", err)) + } + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + panic(fmt.Sprintf("ReadFull: %v", err)) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/tempfork/sshtest/ssh/buffer.go b/tempfork/sshtest/ssh/buffer.go new file mode 100644 index 000000000..1ab07d078 --- /dev/null +++ b/tempfork/sshtest/ssh/buffer.go @@ -0,0 +1,97 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive io.EOF. +func (b *buffer) eof() { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// Read reads data from the internal buffer in buf. Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/tempfork/sshtest/ssh/buffer_test.go b/tempfork/sshtest/ssh/buffer_test.go new file mode 100644 index 000000000..d5781cb3d --- /dev/null +++ b/tempfork/sshtest/ssh/buffer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "testing" +) + +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") + +func TestBufferReadwrite(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + r, _ := b.Read(make([]byte, 10)) + if r != 10 { + t.Fatalf("Expected written == read == 10, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + r, _ = b.Read(make([]byte, 10)) + if r != 5 { + t.Fatalf("Expected written == read == 5, written: 5, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:10]) + r, _ = b.Read(make([]byte, 5)) + if r != 5 { + t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + b.write(alphabet[5:15]) + r, _ = b.Read(make([]byte, 10)) + r2, _ := b.Read(make([]byte, 10)) + if r != 10 || r2 != 5 || 15 != r+r2 { + t.Fatal("Expected written == read == 15") + } +} + +func TestBufferClose(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + b.eof() + _, err := b.Read(make([]byte, 5)) + if err != nil { + t.Fatal("expected read of 5 to not return EOF") + } + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err := b.Read(make([]byte, 5)) + r2, err2 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || err != nil || err2 != nil { + t.Fatal("expected reads of 5 and 5") + } + + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err = b.Read(make([]byte, 5)) + r2, err2 = b.Read(make([]byte, 10)) + r3, err3 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { + t.Fatal("expected reads of 5 and 5 and 0, with EOF") + } + + b = newBuffer() + b.write(make([]byte, 5)) + b.write(make([]byte, 10)) + b.eof() + r, err = b.Read(make([]byte, 9)) + r2, err2 = b.Read(make([]byte, 3)) + r3, err3 = b.Read(make([]byte, 3)) + r4, err4 := b.Read(make([]byte, 10)) + if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { + t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) + } + if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { + t.Fatal("Expected written == read == 15", r, r2, r3, r4) + } +} diff --git a/tempfork/sshtest/ssh/certs.go b/tempfork/sshtest/ssh/certs.go new file mode 100644 index 000000000..27d0e14aa --- /dev/null +++ b/tempfork/sshtest/ssh/certs.go @@ -0,0 +1,611 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" + "time" +) + +// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear +// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms. +// Unlike key algorithm names, these are not passed to AlgorithmSigner nor +// returned by MultiAlgorithmSigner and don't appear in the Signature.Format +// field. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" + CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" + CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com" + + // CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a + // Certificate.Type (or PublicKey.Type), but only in + // ClientConfig.HostKeyAlgorithms. + CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com" + CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com" +) + +const ( + // Deprecated: use CertAlgoRSAv01. + CertSigAlgoRSAv01 = CertAlgoRSAv01 + // Deprecated: use CertAlgoRSASHA256v01. + CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01 + // Deprecated: use CertAlgoRSASHA512v01. + CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01 +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string + Blob []byte + Rest []byte `ssh:"rest"` +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the +// PublicKey interface, so it can be unmarshaled using +// ParsePublicKey. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) + } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. + if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +type algorithmOpenSSHCertSigner struct { + *openSSHCertSigner + algorithmSigner AlgorithmSigner +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) { + return nil, errors.New("ssh: signer and cert have different public key") + } + + switch s := signer.(type) { + case MultiAlgorithmSigner: + return &multiAlgorithmSigner{ + AlgorithmSigner: &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, s}, + supportedAlgorithms: s.Algorithms(), + }, nil + case AlgorithmSigner: + return &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, s}, nil + default: + return &openSSHCertSigner{cert, signer}, nil + } +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsUserAuthority should return true if the key is recognized as an + // authority for the given user certificate. This allows for + // certificates to be signed by other certificates. This must be set + // if this CertChecker will be checking user certificates. + IsUserAuthority func(auth PublicKey) bool + + // IsHostAuthority should report whether the key is recognized as + // an authority for this host. This allows for certificates to be + // signed by other keys, and for those other keys to only be valid + // signers for particular hostnames. This must be set if this + // CertChecker will be checking host certificates. + IsHostAuthority func(auth PublicKey, address string) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback HostKeyCallback + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. +func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + if !c.IsHostAuthority(cert.SignatureKey, addr) { + return fmt.Errorf("ssh: no authorities for hostname: %v", addr) + } + + hostname, _, err := net.SplitHostPort(addr) + if err != nil { + return err + } + + // Pass hostname only as principal for host certificates (consistent with OpenSSH) + return c.CheckCert(hostname, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + if !c.IsUserAuthority(cert.SignatureKey) { + return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. +func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial) + } + + for opt := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert signs the certificate with an authority, setting the Nonce, +// SignatureKey, and Signature fields. If the authority implements the +// MultiAlgorithmSigner interface the first algorithm in the list is used. This +// is useful if you want to sign with a specific algorithm. +func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + if v, ok := authority.(MultiAlgorithmSigner); ok { + if len(v.Algorithms()) == 0 { + return errors.New("the provided authority has no signature algorithm") + } + // Use the first algorithm in the list. + sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0]) + if err != nil { + return err + } + c.Signature = sig + return nil + } else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA { + // Default to KeyAlgoRSASHA512 for ssh-rsa signers. + // TODO: consider using KeyAlgoRSASHA256 as default. + sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512) + if err != nil { + return err + } + c.Signature = sig + return nil + } + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +// certKeyAlgoNames is a mapping from known certificate algorithm names to the +// corresponding public key signature algorithm. +// +// This map must be kept in sync with the one in agent/client.go. +var certKeyAlgoNames = map[string]string{ + CertAlgoRSAv01: KeyAlgoRSA, + CertAlgoRSASHA256v01: KeyAlgoRSASHA256, + CertAlgoRSASHA512v01: KeyAlgoRSASHA512, + CertAlgoDSAv01: KeyAlgoDSA, + CertAlgoECDSA256v01: KeyAlgoECDSA256, + CertAlgoECDSA384v01: KeyAlgoECDSA384, + CertAlgoECDSA521v01: KeyAlgoECDSA521, + CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256, + CertAlgoED25519v01: KeyAlgoED25519, + CertAlgoSKED25519v01: KeyAlgoSKED25519, +} + +// underlyingAlgo returns the signature algorithm associated with algo (which is +// an advertised or negotiated public key or host key algorithm). These are +// usually the same, except for certificate algorithms. +func underlyingAlgo(algo string) string { + if a, ok := certKeyAlgoNames[algo]; ok { + return a + } + return algo +} + +// certificateAlgo returns the certificate algorithms that uses the provided +// underlying signature algorithm. +func certificateAlgo(algo string) (certAlgo string, ok bool) { + for certName, algoName := range certKeyAlgoNames { + if algoName == algo { + return certName, true + } + } + return "", false +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the certificate algorithm name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + certName, ok := certificateAlgo(c.Key.Type()) + if !ok { + panic("unknown certificate type for key type " + c.Key.Type()) + } + return certName +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. +func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + switch out.Format { + case KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01, KeyAlgoSKED25519, CertAlgoSKED25519v01: + out.Rest = in + return out, nil, ok + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/tempfork/sshtest/ssh/certs_test.go b/tempfork/sshtest/ssh/certs_test.go new file mode 100644 index 000000000..6208bb37a --- /dev/null +++ b/tempfork/sshtest/ssh/certs_test.go @@ -0,0 +1,406 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "io" + "net" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh/testdata" +) + +func TestParseCert(t *testing.T) { + authKeyBytes := bytes.TrimSuffix(testdata.SSHCertificates["rsa"], []byte(" host.example.com\n")) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 +// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub +// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN +// Critical Options: +// +// force-command /bin/sleep +// source-address 192.168.1.0/24 +// +// Extensions: +// +// permit-X11-forwarding +// permit-agent-forwarding +// permit-port-forwarding +// permit-pty +// permit-user-rc +const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` + +func TestParseCertWithOptions(t *testing.T) { + opts := map[string]string{ + "source-address": "192.168.1.0/24", + "force-command": "/bin/sleep", + } + exts := map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + authKeyBytes := []byte(exampleSSHCertWithOptions) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + cert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + if !reflect.DeepEqual(cert.CriticalOptions, opts) { + t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) + } + if !reflect.DeepEqual(cert.Extensions, exts) { + t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) + } + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey(testdata.SSHCertificates["rsa-user-testcertificate"]) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsUserAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("testcertificate", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, + } + if err := checker.CheckCert("testcertificate", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsUserAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } + + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } + } +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsHostAuthority: func(p PublicKey, addr string) bool { + return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, test := range []struct { + addr string + succeed bool + certSignerAlgorithms []string // Empty means no algorithm restrictions. + clientHostKeyAlgorithms []string + }{ + {addr: "hostname:22", succeed: true}, + { + addr: "hostname:22", + succeed: true, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{CertAlgoRSASHA512v01}, + }, + { + addr: "hostname:22", + succeed: false, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{CertAlgoRSAv01}, + }, + { + addr: "hostname:22", + succeed: false, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{KeyAlgoRSASHA512}, // Not a certificate algorithm. + }, + {addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22' + {addr: "lasthost:22", succeed: false}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errc := make(chan error) + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + if len(test.certSignerAlgorithms) > 0 { + mas, err := NewSignerWithAlgorithms(certSigner.(AlgorithmSigner), test.certSignerAlgorithms) + if err != nil { + errc <- err + return + } + conf.AddHostKey(mas) + } else { + conf.AddHostKey(certSigner) + } + _, _, _, err := NewServerConn(c1, &conf) + errc <- err + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + HostKeyAlgorithms: test.clientHostKeyAlgorithms, + } + _, _, _, err = NewClientConn(c2, test.addr, config) + + if (err == nil) != test.succeed { + t.Errorf("NewClientConn(%q): %v", test.addr, err) + } + + err = <-errc + if (err == nil) != test.succeed { + t.Errorf("NewServerConn(%q): %v", test.addr, err) + } + } +} + +type legacyRSASigner struct { + Signer +} + +func (s *legacyRSASigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + v, ok := s.Signer.(AlgorithmSigner) + if !ok { + return nil, fmt.Errorf("invalid signer") + } + return v.SignWithAlgorithm(rand, data, KeyAlgoRSA) +} + +func TestCertTypes(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSignerSHA256, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer SHA256: %v", err) + } + // Algorithms are in order of preference, we expect rsa-sha2-512 to be used. + multiAlgoSignerSHA512, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer SHA512: %v", err) + } + + var testVars = []struct { + name string + signer Signer + algo string + }{ + {CertAlgoECDSA256v01, testSigners["ecdsap256"], ""}, + {CertAlgoECDSA384v01, testSigners["ecdsap384"], ""}, + {CertAlgoECDSA521v01, testSigners["ecdsap521"], ""}, + {CertAlgoED25519v01, testSigners["ed25519"], ""}, + {CertAlgoRSAv01, testSigners["rsa"], KeyAlgoRSASHA256}, + {"legacyRSASigner", &legacyRSASigner{testSigners["rsa"]}, KeyAlgoRSA}, + {"multiAlgoRSASignerSHA256", multiAlgoSignerSHA256, KeyAlgoRSASHA256}, + {"multiAlgoRSASignerSHA512", multiAlgoSignerSHA512, KeyAlgoRSASHA512}, + {CertAlgoDSAv01, testSigners["dsa"], ""}, + } + + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating host key: %v", err) + } + + signer, err := NewSignerFromKey(k) + if err != nil { + t.Fatalf("error generating signer for ssh listener: %v", err) + } + + conf := &ServerConfig{ + PublicKeyCallback: func(c ConnMetadata, k PublicKey) (*Permissions, error) { + return new(Permissions), nil + }, + } + conf.AddHostKey(signer) + + for _, m := range testVars { + t.Run(m.name, func(t *testing.T) { + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, conf) + + priv := m.signer + if err != nil { + t.Fatalf("error generating ssh pubkey: %v", err) + } + + cert := &Certificate{ + CertType: UserCert, + Key: priv.PublicKey(), + } + cert.SignCert(rand.Reader, priv) + + certSigner, err := NewCertSigner(cert, priv) + if err != nil { + t.Fatalf("error generating cert signer: %v", err) + } + + if m.algo != "" && cert.Signature.Format != m.algo { + t.Errorf("expected %q signature format, got %q", m.algo, cert.Signature.Format) + } + + config := &ClientConfig{ + User: "user", + HostKeyCallback: func(h string, r net.Addr, k PublicKey) error { return nil }, + Auth: []AuthMethod{PublicKeys(certSigner)}, + } + + _, _, _, err = NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("error connecting: %v", err) + } + }) + } +} + +func TestCertSignWithMultiAlgorithmSigner(t *testing.T) { + type testcase struct { + sigAlgo string + algorithms []string + } + cases := []testcase{ + { + sigAlgo: KeyAlgoRSA, + algorithms: []string{KeyAlgoRSA, KeyAlgoRSASHA512}, + }, + { + sigAlgo: KeyAlgoRSASHA256, + algorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512}, + }, + { + sigAlgo: KeyAlgoRSASHA512, + algorithms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}, + }, + } + + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + + for _, c := range cases { + t.Run(c.sigAlgo, func(t *testing.T) { + signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algorithms) + if err != nil { + t.Fatalf("NewSignerWithAlgorithms error: %v", err) + } + if err := cert.SignCert(rand.Reader, signer); err != nil { + t.Fatalf("SignCert error: %v", err) + } + if cert.Signature.Format != c.sigAlgo { + t.Fatalf("got signature format %q, want %q", cert.Signature.Format, c.sigAlgo) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/channel.go b/tempfork/sshtest/ssh/channel.go new file mode 100644 index 000000000..cc0bb7ab6 --- /dev/null +++ b/tempfork/sshtest/ssh/channel.go @@ -0,0 +1,645 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. + channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + // If the channel is closed before a reply is returned, io.EOF + // is returned. + SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. +type channel struct { + // R/O after creation + chanType string + extraData []byte + localId, remoteId uint32 + + // maxIncomingPayload and maxRemotePayload are the maximum + // payload sizes of normal and extended data packets for + // receiving and sending, respectively. The wire packet will + // be 9 or 13 bytes larger (excluding encryption overhead). + maxIncomingPayload uint32 + maxRemotePayload uint32 + + mux *mux + + // decided is set to true if an accept or reject message has been sent + // (for outbound channels) or received (for inbound channels). + decided bool + + // direction contains either channelOutbound, for channels created + // locally, or channelInbound, for channels created by the peer. + direction channelDirection + + // Pending internal channel messages. + msg chan interface{} + + // Since requests have no ID, there can be only one request + // with WantReply=true outstanding. This lock is held by a + // goroutine that has such an outgoing request pending. + sentRequestMu sync.Mutex + + incomingRequests chan *Request + + sentEOF bool + + // thread-safe data + remoteWin window + pending *buffer + extPending *buffer + + // windowMu protects myWindow, the flow-control window, and myConsumed, + // the number of bytes consumed since we last increased myWindow + windowMu sync.Mutex + myWindow uint32 + myConsumed uint32 + + // writeMu serializes calls to mux.conn.writePacket() and + // protects sentClose and packetPool. This mutex must be + // different from windowMu, as writePacket can block if there + // is a key exchange pending. + writeMu sync.Mutex + sentClose bool + + // packetPool has a buffer for each extended channel ID to + // save allocations during writes. + packetPool map[uint32][]byte +} + +// writePacket sends a packet. If the packet is a channel close, it updates +// sentClose. This method takes the lock c.writeMu. +func (ch *channel) writePacket(packet []byte) error { + ch.writeMu.Lock() + if ch.sentClose { + ch.writeMu.Unlock() + return io.EOF + } + ch.sentClose = (packet[0] == msgChannelClose) + err := ch.mux.conn.writePacket(packet) + ch.writeMu.Unlock() + return err +} + +func (ch *channel) sendMessage(msg interface{}) error { + if debugMux { + log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg) + } + + p := Marshal(msg) + binary.BigEndian.PutUint32(p[1:], ch.remoteId) + return ch.writePacket(p) +} + +// WriteExtended writes data to a specific extended stream. These streams are +// used, for example, for stderr. +func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { + if ch.sentEOF { + return 0, io.EOF + } + // 1 byte message type, 4 bytes remoteId, 4 bytes data length + opCode := byte(msgChannelData) + headerLength := uint32(9) + if extendedCode > 0 { + headerLength += 4 + opCode = msgChannelExtendedData + } + + ch.writeMu.Lock() + packet := ch.packetPool[extendedCode] + // We don't remove the buffer from packetPool, so + // WriteExtended calls from different goroutines will be + // flagged as errors by the race detector. + ch.writeMu.Unlock() + + for len(data) > 0 { + space := min(ch.maxRemotePayload, len(data)) + if space, err = ch.remoteWin.reserve(space); err != nil { + return n, err + } + if want := headerLength + space; uint32(cap(packet)) < want { + packet = make([]byte, want) + } else { + packet = packet[:want] + } + + todo := data[:space] + + packet[0] = opCode + binary.BigEndian.PutUint32(packet[1:], ch.remoteId) + if extendedCode > 0 { + binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) + } + binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) + copy(packet[headerLength:], todo) + if err = ch.writePacket(packet); err != nil { + return n, err + } + + n += len(todo) + data = data[len(todo):] + } + + ch.writeMu.Lock() + ch.packetPool[extendedCode] = packet + ch.writeMu.Unlock() + + return n, err +} + +func (ch *channel) handleData(packet []byte) error { + headerLen := 9 + isExtendedData := packet[0] == msgChannelExtendedData + if isExtendedData { + headerLen = 13 + } + if len(packet) < headerLen { + // malformed data packet + return parseError(packet[0]) + } + + var extended uint32 + if isExtendedData { + extended = binary.BigEndian.Uint32(packet[5:]) + } + + length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) + if length == 0 { + return nil + } + if length > ch.maxIncomingPayload { + // TODO(hanwen): should send Disconnect? + return errors.New("ssh: incoming packet exceeds maximum payload size") + } + + data := packet[headerLen:] + if length != uint32(len(data)) { + return errors.New("ssh: wrong packet length") + } + + ch.windowMu.Lock() + if ch.myWindow < length { + ch.windowMu.Unlock() + // TODO(hanwen): should send Disconnect with reason? + return errors.New("ssh: remote side wrote too much") + } + ch.myWindow -= length + ch.windowMu.Unlock() + + if extended == 1 { + ch.extPending.write(data) + } else if extended > 0 { + // discard other extended data. + } else { + ch.pending.write(data) + } + return nil +} + +func (c *channel) adjustWindow(adj uint32) error { + c.windowMu.Lock() + // Since myConsumed and myWindow are managed on our side, and can never + // exceed the initial window setting, we don't worry about overflow. + c.myConsumed += adj + var sendAdj uint32 + if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) || + (c.myWindow < channelWindowSize/2) { + sendAdj = c.myConsumed + c.myConsumed = 0 + c.myWindow += sendAdj + } + c.windowMu.Unlock() + if sendAdj == 0 { + return nil + } + return c.sendMessage(windowAdjustMsg{ + AdditionalBytes: sendAdj, + }) +} + +func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { + switch extended { + case 1: + n, err = c.extPending.Read(data) + case 0: + n, err = c.pending.Read(data) + default: + return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) + } + + if n > 0 { + err = c.adjustWindow(uint32(n)) + // sendWindowAdjust can return io.EOF if the remote + // peer has closed the connection, however we want to + // defer forwarding io.EOF to the caller of Read until + // the buffer has been drained. + if n > 0 && err == io.EOF { + err = nil + } + } + + return n, err +} + +func (c *channel) close() { + c.pending.eof() + c.extPending.eof() + close(c.msg) + close(c.incomingRequests) + c.writeMu.Lock() + // This is not necessary for a normal channel teardown, but if + // there was another error, it is. + c.sentClose = true + c.writeMu.Unlock() + // Unblock writers. + c.remoteWin.close() +} + +// responseMessageReceived is called when a success or failure message is +// received on a channel to check that such a message is reasonable for the +// given channel. +func (ch *channel) responseMessageReceived() error { + if ch.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if ch.decided { + return errors.New("ssh: duplicate response received for channel") + } + ch.decided = true + return nil +} + +func (ch *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return ch.handleData(packet) + case msgChannelClose: + ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId}) + ch.mux.chanList.remove(ch.localId) + ch.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + ch.extPending.eof() + ch.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + ch.mux.chanList.remove(msg.PeersID) + ch.msg <- msg + case *channelOpenConfirmMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + ch.remoteId = msg.MyID + ch.maxRemotePayload = msg.MaxPacketSize + ch.remoteWin.add(msg.MyWindow) + ch.msg <- msg + case *windowAdjustMsg: + if !ch.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: ch, + } + + ch.incomingRequests <- &req + default: + ch.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, chanSize), + msg: make(chan interface{}, chanSize), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (ch *channel) Accept() (Channel, <-chan *Request, error) { + if ch.decided { + return nil, nil, errDecidedAlready + } + ch.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersID: ch.remoteId, + MyID: ch.localId, + MyWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + } + ch.decided = true + if err := ch.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersID: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersID: ch.remoteId}) +} + +func (ch *channel) Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersID: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersID: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersID: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersID: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/tempfork/sshtest/ssh/cipher.go b/tempfork/sshtest/ssh/cipher.go new file mode 100644 index 000000000..0533786f4 --- /dev/null +++ b/tempfork/sshtest/ssh/cipher.go @@ -0,0 +1,789 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. +type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type cipherMode struct { + keySize int + ivSize int + create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) +} + +func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + stream, err := createFunc(key, iv) + if err != nil { + return nil, err + } + + var streamDump []byte + if skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + mac := macModes[algs.MAC].new(macKey) + return &streamPacketCipher{ + mac: mac, + etm: macModes[algs.MAC].etm, + macResult: make([]byte, mac.Size()), + cipher: stream, + }, nil + } +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*cipherMode{ + // Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. + "aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + + // Ciphers from RFC 4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, streamCipherMode(1536, newRC4)}, + "arcfour256": {32, 0, streamCipherMode(1536, newRC4)}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC 4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, streamCipherMode(0, newRC4)}, + + // AEAD ciphers + gcm128CipherID: {16, 12, newGCMCipher}, + gcm256CipherID: {32, 12, newGCMCipher}, + chacha20Poly1305ID: {64, 0, newChaCha20Cipher}, + + // CBC mode is insecure and so is not included in the default config. + // (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely + // needed, it's possible to specify a custom Config to enable it. + // You should expect that an active attacker can recover plaintext if + // you do. + aes128cbcID: {16, aes.BlockSize, newAESCBCCipher}, + + // 3des-cbc is insecure and is not included in the default + // config. + tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + etm bool + + // The following members are to avoid per-packet allocations. + prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readCipherPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + var encryptedPaddingLength [1]byte + if s.mac != nil && s.etm { + copy(encryptedPaddingLength[:], s.prefix[4:5]) + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } else { + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + if s.etm { + s.mac.Write(s.prefix[:4]) + s.mac.Write(encryptedPaddingLength[:]) + } else { + s.mac.Write(s.prefix[:]) + } + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. + if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + + if s.mac != nil && s.etm { + s.mac.Write(data) + } + + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + if !s.etm { + s.mac.Write(data) + } + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writeCipherPacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + aadlen := 0 + if s.mac != nil && s.etm { + // packet length is not encrypted for EtM modes + aadlen = 4 + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + + if s.etm { + // For EtM algorithms, the packet length must stay unencrypted, + // but the following data (padding length) must be encrypted + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } + + s.mac.Write(s.prefix[:]) + + if !s.etm { + // For non-EtM algorithms, the algorithm is applied on unencrypted data + s.mac.Write(packet) + s.mac.Write(padding) + } + } + + if !(s.mac != nil && s.etm) { + // For EtM algorithms, the padding length has already been encrypted + // and the packet length must remain unencrypted + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if s.mac != nil && s.etm { + // For EtM algorithms, packet and padding must be encrypted + s.mac.Write(packet) + s.mac.Write(padding) + } + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. + padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + if len(plain) == 0 { + return nil, errors.New("ssh: empty packet") + } + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + macSize uint32 + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte + + // Amount of data we should still read to hide which + // verification error triggered. + oracleCamouflage uint32 +} + +func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + cbc := &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + } + if cbc.mac != nil { + cbc.macSize = uint32(cbc.mac.Size()) + } + + return cbc, nil +} + +func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +// cbcError represents a verification error that may leak information. +type cbcError string + +func (e cbcError) Error() string { return string(e) } + +func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + p, err := c.readCipherPacketLeaky(seqNum, r) + if err != nil { + if _, ok := err.(cbcError); ok { + // Verification error: read a fixed amount of + // data, to make distinguishing between + // failing MAC and failing length check more + // difficult. + io.CopyN(io.Discard, r, int64(c.oracleCamouflage)) + } + } + return p, err +} + +func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { + blockSize := c.decrypter.BlockSize() + + // Read the header, which will include some of the subsequent data in the + // case of block ciphers - this is copied back to the payload later. + // How many bytes of payload/padding will be read with this first read. + firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) + firstBlock := c.packetData[:firstBlockLength] + if _, err := io.ReadFull(r, firstBlock); err != nil { + return nil, err + } + + c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength + + c.decrypter.CryptBlocks(firstBlock, firstBlock) + length := binary.BigEndian.Uint32(firstBlock[:4]) + if length > maxPacket { + return nil, cbcError("ssh: packet too large") + } + if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { + // The minimum size of a packet is 16 (or the cipher block size, whichever + // is larger) bytes. + return nil, cbcError("ssh: packet too small") + } + // The length of the packet (including the length field but not the MAC) must + // be a multiple of the block size or 8, whichever is larger. + if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { + return nil, cbcError("ssh: invalid packet length multiple") + } + + paddingLength := uint32(firstBlock[4]) + if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { + return nil, cbcError("ssh: invalid packet length") + } + + // Positions within the c.packetData buffer: + macStart := 4 + length + paddingStart := macStart - paddingLength + + // Entire packet size, starting before length, ending at end of mac. + entirePacketSize := macStart + c.macSize + + // Ensure c.packetData is large enough for the entire packet data. + if uint32(cap(c.packetData)) < entirePacketSize { + // Still need to upsize and copy, but this should be rare at runtime, only + // on upsizing the packetData buffer. + c.packetData = make([]byte, entirePacketSize) + copy(c.packetData, firstBlock) + } else { + c.packetData = c.packetData[:entirePacketSize] + } + + n, err := io.ReadFull(r, c.packetData[firstBlockLength:]) + if err != nil { + return nil, err + } + c.oracleCamouflage -= uint32(n) + + remainingCrypted := c.packetData[firstBlockLength:macStart] + c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) + + mac := c.packetData[macStart:] + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData[:macStart]) + c.macResult = c.mac.Sum(c.macResult[:0]) + if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { + return nil, cbcError("ssh: MAC failure") + } + } + + return c.packetData[prefixLen:paddingStart], nil +} + +func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) + + // Length of encrypted portion of the packet (header, payload, padding). + // Enforce minimum padding and packet size. + encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) + // Enforce block size. + encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize + + length := encLength - 4 + paddingLength := int(length) - (1 + len(packet)) + + // Overall buffer contains: header, payload, padding, mac. + // Space for the MAC is reserved in the capacity but not the slice length. + bufferSize := encLength + c.macSize + if uint32(cap(c.packetData)) < bufferSize { + c.packetData = make([]byte, encLength, bufferSize) + } else { + c.packetData = c.packetData[:encLength] + } + + p := c.packetData + + // Packet header. + binary.BigEndian.PutUint32(p, length) + p = p[4:] + p[0] = byte(paddingLength) + + // Payload. + p = p[1:] + copy(p, packet) + + // Padding. + p = p[len(packet):] + if _, err := io.ReadFull(rand, p); err != nil { + return err + } + + if c.mac != nil { + c.mac.Reset() + binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) + c.mac.Write(c.seqNumBytes[:]) + c.mac.Write(c.packetData) + // The MAC is now appended into the capacity reserved for it earlier. + c.packetData = c.mac.Sum(c.packetData) + } + + c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) + + if _, err := w.Write(c.packetData); err != nil { + return err + } + + return nil +} + +const chacha20Poly1305ID = "chacha20-poly1305@openssh.com" + +// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com +// AEAD, which is described here: +// +// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00 +// +// the methods here also implement padding, which RFC 4253 Section 6 +// also requires of stream ciphers. +type chacha20Poly1305Cipher struct { + lengthKey [32]byte + contentKey [32]byte + buf []byte +} + +func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + if len(key) != 64 { + panic(len(key)) + } + + c := &chacha20Poly1305Cipher{ + buf: make([]byte, 256), + } + + copy(c.contentKey[:], key[:32]) + copy(c.lengthKey[:], key[32:]) + return c, nil +} + +func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return nil, err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + encryptedLength := c.buf[:4] + if _, err := io.ReadFull(r, encryptedLength); err != nil { + return nil, err + } + + var lenBytes [4]byte + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return nil, err + } + ls.XORKeyStream(lenBytes[:], encryptedLength) + + length := binary.BigEndian.Uint32(lenBytes[:]) + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + contentEnd := 4 + length + packetEnd := contentEnd + poly1305.TagSize + if uint32(cap(c.buf)) < packetEnd { + c.buf = make([]byte, packetEnd) + copy(c.buf[:], encryptedLength) + } else { + c.buf = c.buf[:packetEnd] + } + + if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil { + return nil, err + } + + var mac [poly1305.TagSize]byte + copy(mac[:], c.buf[contentEnd:packetEnd]) + if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) { + return nil, errors.New("ssh: MAC failure") + } + + plain := c.buf[4:contentEnd] + s.XORKeyStream(plain, plain) + + if len(plain) == 0 { + return nil, errors.New("ssh: empty packet") + } + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. + return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding)+1 >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + + plain = plain[1 : len(plain)-int(padding)] + + return plain, nil +} + +func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + // There is no blocksize, so fall back to multiple of 8 byte + // padding, as described in RFC 4253, Sec 6. + const packetSizeMultiple = 8 + + padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple + if padding < 4 { + padding += packetSizeMultiple + } + + // size (4 bytes), padding (1), payload, padding, tag. + totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize + if cap(c.buf) < totalLength { + c.buf = make([]byte, totalLength) + } else { + c.buf = c.buf[:totalLength] + } + + binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding)) + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return err + } + ls.XORKeyStream(c.buf, c.buf[:4]) + c.buf[4] = byte(padding) + copy(c.buf[5:], payload) + packetEnd := 5 + len(payload) + padding + if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil { + return err + } + + s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd]) + + var mac [poly1305.TagSize]byte + poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey) + + copy(c.buf[packetEnd:], mac[:]) + + if _, err := w.Write(c.buf); err != nil { + return err + } + return nil +} diff --git a/tempfork/sshtest/ssh/cipher_test.go b/tempfork/sshtest/ssh/cipher_test.go new file mode 100644 index 000000000..fe339862c --- /dev/null +++ b/tempfork/sshtest/ssh/cipher_test.go @@ -0,0 +1,231 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/rand" + "encoding/binary" + "io" + "testing" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range supportedCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("supported cipher %q is unknown", cipherAlgo) + } + } + for _, cipherAlgo := range preferredCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("preferred cipher %q is unknown", cipherAlgo) + } + } +} + +func TestPacketCiphers(t *testing.T) { + defaultMac := "hmac-sha2-256" + defaultCipher := "aes128-ctr" + for cipher := range cipherModes { + t.Run("cipher="+cipher, + func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) }) + } + for mac := range macModes { + t.Run("mac="+mac, + func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) }) + } +} + +func testPacketCipher(t *testing.T, cipher, mac string) { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: mac, + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { + t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err) + } + + packet, err := server.readCipherPacket(0, buf) + if err != nil { + t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err) + } + + if string(packet) != want { + t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want) + } +} + +func TestCBCOracleCounterMeasure(t *testing.T) { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: aes128cbcID, + MAC: "hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writeCipherPacket: %v", err) + } + + packetSize := buf.Len() + buf.Write(make([]byte, 2*maxPacket)) + + // We corrupt each byte, but this usually will only test the + // 'packet too large' or 'MAC failure' cases. + lastRead := -1 + for i := 0; i < packetSize; i++ { + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + fresh := &bytes.Buffer{} + fresh.Write(buf.Bytes()) + fresh.Bytes()[i] ^= 0x01 + + before := fresh.Len() + _, err = server.readCipherPacket(0, fresh) + if err == nil { + t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i) + continue + } + if _, ok := err.(cbcError); !ok { + t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) + continue + } + + after := fresh.Len() + bytesRead := before - after + if bytesRead < maxPacket { + t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) + continue + } + + if i > 0 && bytesRead != lastRead { + t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) + } + lastRead = bytesRead + } +} + +func TestCVE202143565(t *testing.T) { + tests := []struct { + cipher string + constructPacket func(packetCipher) io.Reader + }{ + { + cipher: gcm128CipherID, + constructPacket: func(client packetCipher) io.Reader { + internalCipher := client.(*gcmCipher) + b := &bytes.Buffer{} + prefix := [4]byte{} + if _, err := b.Write(prefix[:]); err != nil { + t.Fatal(err) + } + internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:]) + if _, err := b.Write(internalCipher.buf); err != nil { + t.Fatal(err) + } + internalCipher.incIV() + + return b + }, + }, + { + cipher: chacha20Poly1305ID, + constructPacket: func(client packetCipher) io.Reader { + internalCipher := client.(*chacha20Poly1305Cipher) + b := &bytes.Buffer{} + + nonce := make([]byte, 12) + s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce) + if err != nil { + t.Fatal(err) + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + internalCipher.buf = make([]byte, 4+poly1305.TagSize) + binary.BigEndian.PutUint32(internalCipher.buf, 0) + ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce) + if err != nil { + t.Fatal(err) + } + ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4]) + if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil { + t.Fatal(err) + } + + s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4]) + + var tag [poly1305.TagSize]byte + poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey) + + copy(internalCipher.buf[4:], tag[:]) + + if _, err := b.Write(internalCipher.buf); err != nil { + t.Fatal(err) + } + + return b + }, + }, + } + + for _, tc := range tests { + mac := "hmac-sha2-256" + + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: tc.cipher, + MAC: mac, + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) + } + + b := tc.constructPacket(client) + + wantErr := "ssh: empty packet" + _, err = server.readCipherPacket(0, b) + if err == nil { + t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac) + } else if err.Error() != wantErr { + t.Fatalf("readCipherPacket(%q, %q): unexpected error, got %q, want %q", tc.cipher, mac, err, wantErr) + } + } +} diff --git a/tempfork/sshtest/ssh/client.go b/tempfork/sshtest/ssh/client.go new file mode 100644 index 000000000..5876e6421 --- /dev/null +++ b/tempfork/sshtest/ssh/client.go @@ -0,0 +1,290 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "net" + "os" + "sync" + "time" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. +type Client struct { + Conn + + handleForwardsOnce sync.Once // guards calling (*Client).handleForwards + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, chanSize) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. +func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.HostKeyCallback == nil { + c.Close() + return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback") + } + + conn := &connection{ + sshConn: sshConn{conn: c, user: fullConf.User}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. +func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + return err + } + + c.sessionID = c.transport.getSessionID() + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key exchange. +// algo is the negotiated algorithm, and may be a certificate type. +func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + if a := underlyingAlgo(algo); sig.Format != a { + return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a) + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. +func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// HostKeyCallback is the function type used for verifying server +// keys. A HostKeyCallback must return nil if the host key is OK, or +// an error to reject it. It receives the hostname as passed to Dial +// or NewClientConn. The remote address is the RemoteAddr of the +// net.Conn underlying the SSH connection. +type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + +// BannerCallback is the function type used for treat the banner sent by +// the server. A BannerCallback receives the message sent by the remote server. +type BannerCallback func(message string) error + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. + User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback is called during the cryptographic + // handshake to validate the server's host key. The client + // configuration must supply this callback for the connection + // to succeed. The functions InsecureIgnoreHostKey or + // FixedHostKey can be used for simplistic host key checks. + HostKeyCallback HostKeyCallback + + // BannerCallback is called during the SSH dance to display a custom + // server's message. The client configuration can supply this callback to + // handle it as wished. The function BannerDisplayStderr can be used for + // simplistic display on Stderr. + BannerCallback BannerCallback + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string + + // HostKeyAlgorithms lists the public key algorithms that the client will + // accept from the server for host key authentication, in order of + // preference. If empty, a reasonable default is used. Any + // string returned from a PublicKey.Type method may be used, or + // any of the CertAlgo and KeyAlgo constants. + HostKeyAlgorithms []string + + // Timeout is the maximum amount of time for the TCP connection to establish. + // + // A Timeout of zero means no timeout. + Timeout time.Duration + + // SkipNoneAuth allows skipping the initial "none" auth request. This is unusual + // behavior, but it is allowed by [RFC4252 5.2](https://datatracker.ietf.org/doc/html/rfc4252#section-5.2), + // and some clients in the wild behave like this. One such client is the paramiko Python + // library, which is used in pgadmin4 via the sshtunnel library. + // When SkipNoneAuth is true, the client will attempt all configured + // [AuthMethod]s until one works, or it runs out. + SkipNoneAuth bool +} + +// InsecureIgnoreHostKey returns a function that can be used for +// ClientConfig.HostKeyCallback to accept any host key. It should +// not be used for production code. +func InsecureIgnoreHostKey() HostKeyCallback { + return func(hostname string, remote net.Addr, key PublicKey) error { + return nil + } +} + +type fixedHostKey struct { + key PublicKey +} + +func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error { + if f.key == nil { + return fmt.Errorf("ssh: required host key was nil") + } + if !bytes.Equal(key.Marshal(), f.key.Marshal()) { + return fmt.Errorf("ssh: host key mismatch") + } + return nil +} + +// FixedHostKey returns a function for use in +// ClientConfig.HostKeyCallback to accept only a specific host key. +func FixedHostKey(key PublicKey) HostKeyCallback { + hk := &fixedHostKey{key} + return hk.check +} + +// BannerDisplayStderr returns a function that can be used for +// ClientConfig.BannerCallback to display banners on os.Stderr. +func BannerDisplayStderr() BannerCallback { + return func(banner string) error { + _, err := os.Stderr.WriteString(banner) + + return err + } +} diff --git a/tempfork/sshtest/ssh/client_auth.go b/tempfork/sshtest/ssh/client_auth.go new file mode 100644 index 000000000..af25a4f01 --- /dev/null +++ b/tempfork/sshtest/ssh/client_auth.go @@ -0,0 +1,805 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" +) + +type authResult int + +const ( + authFailure authResult = iota + authPartialSuccess + authSuccess +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. +func (c *connection) clientAuthenticate(config *ClientConfig) error { + // initiate user auth session + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + return err + } + packet, err := c.transport.readPacket() + if err != nil { + return err + } + // The server may choose to send a SSH_MSG_EXT_INFO at this point (if we + // advertised willingness to receive one, which we always do) or not. See + // RFC 8308, Section 2.4. + extensions := make(map[string][]byte) + if len(packet) > 0 && packet[0] == msgExtInfo { + var extInfo extInfoMsg + if err := Unmarshal(packet, &extInfo); err != nil { + return err + } + payload := extInfo.Payload + for i := uint32(0); i < extInfo.NumExtensions; i++ { + name, rest, ok := parseString(payload) + if !ok { + return parseError(msgExtInfo) + } + value, rest, ok := parseString(rest) + if !ok { + return parseError(msgExtInfo) + } + extensions[string(name)] = value + payload = rest + } + packet, err = c.transport.readPacket() + if err != nil { + return err + } + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + return err + } + + // during the authentication phase the client first attempts the "none" method + // then any untried methods suggested by the server. + var tried []string + var lastMethods []string + + sessionID := c.transport.getSessionID() + var auth AuthMethod + if !config.SkipNoneAuth { + auth = AuthMethod(new(noneAuth)) + } else if len(config.Auth) > 0 { + auth = config.Auth[0] + for _, a := range config.Auth { + lastMethods = append(lastMethods, a.method()) + } + } + for auth != nil { + ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions) + if err != nil { + // On disconnect, return error immediately + if _, ok := err.(*disconnectMsg); ok { + return err + } + // We return the error later if there is no other method left to + // try. + ok = authFailure + } + if ok == authSuccess { + // success + return nil + } else if ok == authFailure { + if m := auth.method(); !contains(tried, m) { + tried = append(tried, m) + } + } + if methods == nil { + methods = lastMethods + } + lastMethods = methods + + auth = nil + + findNext: + for _, a := range config.Auth { + candidateMethod := a.method() + if contains(tried, candidateMethod) { + continue + } + for _, meth := range methods { + if meth == candidateMethod { + auth = a + break findNext + } + } + } + + if auth == nil && err != nil { + // We have an error and there are no other authentication methods to + // try, so we return it. + return err + } + } + return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried) +} + +func contains(list []string, e string) bool { + for _, s := range list { + if s == e { + return true + } + } + return false +} + +// An AuthMethod represents an instance of an RFC 4252 authentication method. +type AuthMethod interface { + // auth authenticates user over transport t. + // Returns true if authentication is successful. + // If authentication is not successful, a []string of alternative + // method names is returned. If the slice is nil, it will be ignored + // and the previous set of possible methods will be reused. + auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) + + // method returns the RFC 4252 method name. + method() string +} + +// "none" authentication, RFC 4252 section 5.2. +type noneAuth int + +func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + if err := c.writePacket(Marshal(&userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: "none", + })); err != nil { + return authFailure, nil, err + } + + return handleAuthResponse(c) +} + +func (n *noneAuth) method() string { + return "none" +} + +// passwordCallback is an AuthMethod that fetches the password through +// a function call, e.g. by prompting the user. +type passwordCallback func() (password string, err error) + +func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + type passwordAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + Reply bool + Password string + } + + pw, err := cb() + // REVIEW NOTE: is there a need to support skipping a password attempt? + // The program may only find out that the user doesn't have a password + // when prompting. + if err != nil { + return authFailure, nil, err + } + + if err := c.writePacket(Marshal(&passwordAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + Reply: false, + Password: pw, + })); err != nil { + return authFailure, nil, err + } + + return handleAuthResponse(c) +} + +func (cb passwordCallback) method() string { + return "password" +} + +// Password returns an AuthMethod using the given password. +func Password(secret string) AuthMethod { + return passwordCallback(func() (string, error) { return secret, nil }) +} + +// PasswordCallback returns an AuthMethod that uses a callback for +// fetching a password. +func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) { + var as MultiAlgorithmSigner + keyFormat := signer.PublicKey().Type() + + // If the signer implements MultiAlgorithmSigner we use the algorithms it + // support, if it implements AlgorithmSigner we assume it supports all + // algorithms, otherwise only the key format one. + switch s := signer.(type) { + case MultiAlgorithmSigner: + as = s + case AlgorithmSigner: + as = &multiAlgorithmSigner{ + AlgorithmSigner: s, + supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)), + } + default: + as = &multiAlgorithmSigner{ + AlgorithmSigner: algorithmSignerWrapper{signer}, + supportedAlgorithms: []string{underlyingAlgo(keyFormat)}, + } + } + + getFallbackAlgo := func() (string, error) { + // Fallback to use if there is no "server-sig-algs" extension or a + // common algorithm cannot be found. We use the public key format if the + // MultiAlgorithmSigner supports it, otherwise we return an error. + if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) { + return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v", + underlyingAlgo(keyFormat), keyFormat, as.Algorithms()) + } + return keyFormat, nil + } + + extPayload, ok := extensions["server-sig-algs"] + if !ok { + // If there is no "server-sig-algs" extension use the fallback + // algorithm. + algo, err := getFallbackAlgo() + return as, algo, err + } + + // The server-sig-algs extension only carries underlying signature + // algorithm, but we are trying to select a protocol-level public key + // algorithm, which might be a certificate type. Extend the list of server + // supported algorithms to include the corresponding certificate algorithms. + serverAlgos := strings.Split(string(extPayload), ",") + for _, algo := range serverAlgos { + if certAlgo, ok := certificateAlgo(algo); ok { + serverAlgos = append(serverAlgos, certAlgo) + } + } + + // Filter algorithms based on those supported by MultiAlgorithmSigner. + var keyAlgos []string + for _, algo := range algorithmsForKeyFormat(keyFormat) { + if contains(as.Algorithms(), underlyingAlgo(algo)) { + keyAlgos = append(keyAlgos, algo) + } + } + + algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos) + if err != nil { + // If there is no overlap, return the fallback algorithm to support + // servers that fail to list all supported algorithms. + algo, err := getFallbackAlgo() + return as, algo, err + } + return as, algo, nil +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) { + // Authentication is performed by sending an enquiry to test if a key is + // acceptable to the remote. If the key is acceptable, the client will + // attempt to authenticate with the valid key. If not the client will repeat + // the process with the remaining keys. + + signers, err := cb() + if err != nil { + return authFailure, nil, err + } + var methods []string + var errSigAlgo error + + origSignersLen := len(signers) + for idx := 0; idx < len(signers); idx++ { + signer := signers[idx] + pub := signer.PublicKey() + as, algo, err := pickSignatureAlgorithm(signer, extensions) + if err != nil && errSigAlgo == nil { + // If we cannot negotiate a signature algorithm store the first + // error so we can return it to provide a more meaningful message if + // no other signers work. + errSigAlgo = err + continue + } + ok, err := validateKey(pub, algo, user, c) + if err != nil { + return authFailure, nil, err + } + // OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512 + // in the "server-sig-algs" extension but doesn't support these + // algorithms for certificate authentication, so if the server rejects + // the key try to use the obtained algorithm as if "server-sig-algs" had + // not been implemented if supported from the algorithm signer. + if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 { + if contains(as.Algorithms(), KeyAlgoRSA) { + // We retry using the compat algorithm after all signers have + // been tried normally. + signers = append(signers, &multiAlgorithmSigner{ + AlgorithmSigner: as, + supportedAlgorithms: []string{KeyAlgoRSA}, + }) + } + } + if !ok { + continue + } + + pubKey := pub.Marshal() + data := buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, algo, pubKey) + sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo)) + if err != nil { + return authFailure, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: algo, + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err = handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + + // If authentication succeeds or the list of available methods does not + // contain the "publickey" method, do not attempt to authenticate with any + // other keys. According to RFC 4252 Section 7, the latter can occur when + // additional authentication methods are required. + if success == authSuccess || !contains(methods, cb.method()) { + return success, methods, err + } + } + + return authFailure, methods, errSigAlgo +} + +// validateKey validates the key provided is acceptable to the server. +func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: algo, + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return false, err + } + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + // According to RFC 4252 Section 7 the algorithm in + // SSH_MSG_USERAUTH_PK_OK should match that of the request but some + // servers send the key type instead. OpenSSH allows any algorithm + // that matches the public key, so we do the same. + // https://github.com/openssh/openssh-portable/blob/86bdd385/sshconnect2.c#L709 + if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) { + return false, nil + } + if !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. +func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (authResult, []string, error) { + gotMsgExtInfo := false + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + case msgExtInfo: + // Ignore post-authentication RFC 8308 extensions, once. + if gotMsgExtInfo { + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + gotMsgExtInfo = true + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +func handleBannerResponse(c packetConn, packet []byte) error { + var msg userAuthBannerMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + transport, ok := c.(*handshakeTransport) + if !ok { + return nil + } + + if transport.bannerCallback != nil { + return transport.bannerCallback(msg.Message) + } + + return nil +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. After +// successful authentication, the server may send a challenge with no +// questions, for which the name and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns an AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return authFailure, nil, err + } + + gotMsgExtInfo := false + gotUserAuthInfoRequest := false + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + continue + case msgExtInfo: + // Ignore post-authentication RFC 8308 extensions, once. + if gotMsgExtInfo { + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + gotMsgExtInfo = true + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + if !gotUserAuthInfoRequest { + return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + gotUserAuthInfoRequest = true + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return authFailure, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.Name, msg.Instruction, prompts, echos) + if err != nil { + return authFailure, nil, err + } + + if len(answers) != len(prompts) { + return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts)) + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return authFailure, nil, err + } + } +} + +type retryableAuthMethod struct { + authMethod AuthMethod + maxTries int +} + +func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) { + for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { + ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions) + if ok != authFailure || err != nil { // either success, partial success or error terminate + return ok, methods, err + } + } + return ok, methods, err +} + +func (r *retryableAuthMethod) method() string { + return r.authMethod.method() +} + +// RetryableAuthMethod is a decorator for other auth methods enabling them to +// be retried up to maxTries before considering that AuthMethod itself failed. +// If maxTries is <= 0, will retry indefinitely +// +// This is useful for interactive clients using challenge/response type +// authentication (e.g. Keyboard-Interactive, Password, etc) where the user +// could mistype their response resulting in the server issuing a +// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4 +// [keyboard-interactive]); Without this decorator, the non-retryable +// AuthMethod would be removed from future consideration, and never tried again +// (and so the user would never be able to retry their entry). +func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { + return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} +} + +// GSSAPIWithMICAuthMethod is an AuthMethod with "gssapi-with-mic" authentication. +// See RFC 4462 section 3 +// gssAPIClient is implementation of the GSSAPIClient interface, see the definition of the interface for details. +// target is the server host you want to log in to. +func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod { + if gssAPIClient == nil { + panic("gss-api client must be not nil with enable gssapi-with-mic") + } + return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target} +} + +type gssAPIWithMICCallback struct { + gssAPIClient GSSAPIClient + target string +} + +func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + m := &userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: g.method(), + } + // The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST. + // See RFC 4462 section 3.2. + m.Payload = appendU32(m.Payload, 1) + m.Payload = appendString(m.Payload, string(krb5OID)) + if err := c.writePacket(Marshal(m)); err != nil { + return authFailure, nil, err + } + // The server responds to the SSH_MSG_USERAUTH_REQUEST with either an + // SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or + // with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE. + // See RFC 4462 section 3.3. + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,so I don't want to check + // selected mech if it is valid. + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + userAuthGSSAPIResp := &userAuthGSSAPIResponse{} + if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil { + return authFailure, nil, err + } + // Start the loop into the exchange token. + // See RFC 4462 section 3.4. + var token []byte + defer g.gssAPIClient.DeleteSecContext() + for { + // Initiates the establishment of a security context between the application and a remote peer. + nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false) + if err != nil { + return authFailure, nil, err + } + if len(nextToken) > 0 { + if err := c.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: nextToken, + })); err != nil { + return authFailure, nil, err + } + } + if !needContinue { + break + } + packet, err = c.readPacket() + if err != nil { + return authFailure, nil, err + } + switch packet[0] { + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthGSSAPIError: + userAuthGSSAPIErrorResp := &userAuthGSSAPIError{} + if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil { + return authFailure, nil, err + } + return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+ + "Major Status: %d\n"+ + "Minor Status: %d\n"+ + "Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus, + userAuthGSSAPIErrorResp.Message) + case msgUserAuthGSSAPIToken: + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return authFailure, nil, err + } + token = userAuthGSSAPITokenReq.Token + } + } + // Binding Encryption Keys. + // See RFC 4462 section 3.5. + micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic") + micToken, err := g.gssAPIClient.GetMIC(micField) + if err != nil { + return authFailure, nil, err + } + if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{ + MIC: micToken, + })); err != nil { + return authFailure, nil, err + } + return handleAuthResponse(c) +} + +func (g *gssAPIWithMICCallback) method() string { + return "gssapi-with-mic" +} diff --git a/tempfork/sshtest/ssh/client_auth_test.go b/tempfork/sshtest/ssh/client_auth_test.go new file mode 100644 index 000000000..ec27133a3 --- /dev/null +++ b/tempfork/sshtest/ssh/client_auth_test.go @@ -0,0 +1,1384 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "runtime" + "strings" + "testing" +) + +type keyboardInteractive map[string]string + +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) + } + return answers, nil +} + +// reused internally by tests +var clientPassword = "tiger" + +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig. Returns both client and server side errors. +func tryAuth(t *testing.T, config *ClientConfig) error { + err, _ := tryAuthBothSides(t, config, nil) + return err +} + +// tryAuthWithGSSAPIWithMICConfig runs a handshake with a given config against an SSH server +// with a given GSSAPIWithMICConfig and config serverConfig. Returns both client and server side errors. +func tryAuthWithGSSAPIWithMICConfig(t *testing.T, clientConfig *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) error { + err, _ := tryAuthBothSides(t, clientConfig, gssAPIWithMICConfig) + return err +} + +// tryAuthBothSides runs the handshake and returns the resulting errors from both sides of the connection. +func tryAuthBothSides(t *testing.T, config *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) (clientError error, serverAuthErrors []error) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsUserAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", + "instruction", + []string{"question1", "question2"}, + []bool{true, true}) + if err != nil { + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil + } + return nil, errors.New("keyboard-interactive failed") + }, + GSSAPIWithMICConfig: gssAPIWithMICConfig, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err, serverAuthErrors +} + +type loggingAlgorithmSigner struct { + used []string + AlgorithmSigner +} + +func (l *loggingAlgorithmSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + l.used = append(l.used, "[Sign]") + return l.AlgorithmSigner.Sign(rand, data) +} + +func (l *loggingAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + l.used = append(l.used, algorithm) + return l.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +func TestClientAuthPublicKey(t *testing.T) { + signer := &loggingAlgorithmSigner{AlgorithmSigner: testSigners["rsa"].(AlgorithmSigner)} + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(signer), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 { + t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used) + } +} + +// TestClientAuthNoSHA2 tests a ssh-rsa Signer that doesn't implement AlgorithmSigner. +func TestClientAuthNoSHA2(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(&legacyRSASigner{testSigners["rsa"]}), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +// TestClientAuthThirdKey checks that the third configured can succeed. If we +// were to do three attempts for each key (rsa-sha2-256, rsa-sha2-512, ssh-rsa), +// we'd hit the six maximum attempts before reaching it. +func TestClientAuthThirdKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa-openssh-format"], + testSigners["rsa-openssh-format"], testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodFallback(t *testing.T) { + var passwordCalled bool + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PasswordCallback( + func() (string, error) { + passwordCalled = true + return "WRONG", nil + }), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + if passwordCalled { + t.Errorf("password auth tried before public-key auth.") + } +} + +func TestAuthMethodWrongPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "answer2", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "WRONG", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") + } +} + +// the mock server will only authenticate ssh-rsa keys +func TestAuthMethodInvalidPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("dsa private key should not have authenticated with rsa public key") + } +} + +// the client should authenticate with the second key +func TestAuthMethodRSAandDSA(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with rsa key: %v", err) + } +} + +type invalidAlgSigner struct { + Signer +} + +func (s *invalidAlgSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + sig, err := s.Signer.Sign(rand, data) + if sig != nil { + sig.Format = "invalid" + } + return sig, err +} + +func TestMethodInvalidAlgorithm(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(&invalidAlgSigner{testSigners["rsa"]}), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + err, serverErrors := tryAuthBothSides(t, config, nil) + if err == nil { + t.Fatalf("login succeeded") + } + + found := false + want := "algorithm \"invalid\"" + + var errStrings []string + for _, err := range serverErrors { + found = found || (err != nil && strings.Contains(err.Error(), want)) + errStrings = append(errStrings, err.Error()) + } + if !found { + t.Errorf("server got error %q, want substring %q", errStrings, want) + } +} + +func TestClientHMAC(t *testing.T) { + for _, mac := range supportedMACs { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + Config: Config{ + MACs: []string{mac}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) + } + } +} + +// issue 4285. +func TestClientUnsupportedCipher(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + Ciphers: []string{"aes128-cbc"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil { + t.Errorf("expected no ciphers in common") + } +} + +func TestClientUnsupportedKex(t *testing.T) { + if os.Getenv("GO_BUILDER_NAME") != "" { + t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198") + } + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + KeyExchanges: []string{"non-existent-kex"}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { + t.Errorf("got %v, expected 'common algorithm'", err) + } +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + // should succeed + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + // corrupted signature + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + // revoked + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + // sign with wrong key + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritative key") + } + + // host cert + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + // principal specified + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + // wrong principal specified + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + // added critical option + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + // allowed source address + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + // disallowed source address + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") + } +} + +func testPermissionsPassing(withPermissions bool, t *testing.T) { + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "nopermissions" { + return nil, nil + } + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if withPermissions { + clientConfig.User = "permissions" + } else { + clientConfig.User = "nopermissions" + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatal(err) + } + if p := serverConn.Permissions; (p != nil) != withPermissions { + t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) + } +} + +func TestPermissionsPassing(t *testing.T) { + testPermissionsPassing(true, t) +} + +func TestNoPermissionsPassing(t *testing.T) { + testPermissionsPassing(false, t) +} + +func TestRetryableAuth(t *testing.T) { + n := 0 + passwords := []string{"WRONG1", "WRONG2"} + + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 2), + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if n != 2 { + t.Fatalf("Did not try all passwords") + } +} + +func ExampleRetryableAuthMethod() { + user := "testuser" + NumberOfPrompts := 3 + + // Normally this would be a callback that prompts the user to answer the + // provided questions + Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return []string{"answer1", "answer2"}, nil + } + + config := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), + }, + } + + host := "mysshserver" + netConn, err := net.Dial("tcp", host) + if err != nil { + log.Fatal(err) + } + + sshConn, _, _, err := NewClientConn(netConn, host, config) + if err != nil { + log.Fatal(err) + } + _ = sshConn +} + +// Test if username is received on server side when NoClientAuth is used +func TestClientAuthNone(t *testing.T) { + user := "testuser" + serverConfig := &ServerConfig{ + NoClientAuth: true, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + User: user, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatalf("newServer: %v", err) + } + if serverConn.User() != user { + t.Fatalf("server: got %q, want %q", serverConn.User(), user) + } +} + +// Test if authentication attempts are limited on server when MaxAuthTries is set +func TestClientAuthMaxAuthTries(t *testing.T) { + user := "testuser" + + serverConfig := &ServerConfig{ + MaxAuthTries: 2, + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == "right" { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + }) + + for tries := 2; tries < 4; tries++ { + n := tries + clientConfig := &ClientConfig{ + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + n-- + if n == 0 { + return "right", nil + } + return "wrong", nil + }), tries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errCh := make(chan error, 1) + + go func() { + _, err := newServer(c1, serverConfig) + errCh <- err + }() + _, _, _, cliErr := NewClientConn(c2, "", clientConfig) + srvErr := <-errCh + + if tries > serverConfig.MaxAuthTries { + if cliErr == nil { + t.Fatalf("client: got no error, want %s", expectedErr) + } else if cliErr.Error() != expectedErr.Error() { + t.Fatalf("client: got %s, want %s", err, expectedErr) + } + var authErr *ServerAuthError + if !errors.As(srvErr, &authErr) { + t.Errorf("expected ServerAuthError, got: %v", srvErr) + } + } else { + if cliErr != nil { + t.Fatalf("client: got %s, want no error", cliErr) + } + } + } +} + +// Test if authentication attempts are correctly limited on server +// when more public keys are provided then MaxAuthTries +func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) { + signers := []Signer{} + for i := 0; i < 6; i++ { + signers = append(signers, testSigners["dsa"]) + } + + validConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, validConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + }) + invalidConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(append(signers, testSigners["rsa"])...), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, invalidConfig); err == nil { + t.Fatalf("client: got no error, want %s", expectedErr) + } else if err.Error() != expectedErr.Error() { + // On Windows we can see a WSAECONNABORTED error + // if the client writes another authentication request + // before the client goroutine reads the disconnection + // message. See issue 50805. + if runtime.GOOS == "windows" && strings.Contains(err.Error(), "wsarecv: An established connection was aborted") { + // OK. + } else { + t.Fatalf("client: got %s, want %s", err, expectedErr) + } + } +} + +// Test whether authentication errors are being properly logged if all +// authentication methods have been exhausted +func TestClientAuthErrorList(t *testing.T) { + publicKeyErr := errors.New("This is an error from PublicKeyCallback") + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + serverConfig := &ServerConfig{ + PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) { + return nil, publicKeyErr + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + _, err = newServer(c1, serverConfig) + if err == nil { + t.Fatal("newServer: got nil, expected errors") + } + + authErrs, ok := err.(*ServerAuthError) + if !ok { + t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err) + } + for i, e := range authErrs.Errors { + switch i { + case 0: + if e != ErrNoAuth { + t.Fatalf("errors: got error %v, want ErrNoAuth", e) + } + case 1: + if e != publicKeyErr { + t.Fatalf("errors: got %v, want %v", e, publicKeyErr) + } + default: + t.Fatalf("errors: got %v, expected 2 errors", authErrs.Errors) + } + } +} + +func TestAuthMethodGSSAPIWithMIC(t *testing.T) { + type testcase struct { + config *ClientConfig + gssConfig *GSSAPIWithMICConfig + clientWantErr string + serverWantErr string + } + testcases := []*testcase{ + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User()+"@DOMAIN" { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + return nil, fmt.Errorf("user is not allowed to login") + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + serverWantErr: "user is not allowed to login", + clientWantErr: "ssh: handshake failed: ssh: unable to authenticate", + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User() { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-invalid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + clientWantErr: "ssh: handshake failed: got \"server-invalid-token-1\", want token \"server-valid-token-1\"", + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("invalid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User() { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + serverWantErr: "got MICToken \"invalid-mic\", want \"valid-mic\"", + clientWantErr: "ssh: handshake failed: ssh: unable to authenticate", + }, + } + + for i, c := range testcases { + clientErr, serverErrs := tryAuthBothSides(t, c.config, c.gssConfig) + if (c.clientWantErr == "") != (clientErr == nil) { + t.Fatalf("client got %v, want %s, case %d", clientErr, c.clientWantErr, i) + } + if (c.serverWantErr == "") != (len(serverErrs) == 2 && serverErrs[1] == nil || len(serverErrs) == 1) { + t.Fatalf("server got err %v, want %s", serverErrs, c.serverWantErr) + } + if c.clientWantErr != "" { + if clientErr != nil && !strings.Contains(clientErr.Error(), c.clientWantErr) { + t.Fatalf("client got %v, want %s, case %d", clientErr, c.clientWantErr, i) + } + } + found := false + var errStrings []string + if c.serverWantErr != "" { + for _, err := range serverErrs { + found = found || (err != nil && strings.Contains(err.Error(), c.serverWantErr)) + errStrings = append(errStrings, err.Error()) + } + if !found { + t.Errorf("server got error %q, want substring %q, case %d", errStrings, c.serverWantErr, i) + } + } + } +} + +func TestCompatibleAlgoAndSignatures(t *testing.T) { + type testcase struct { + algo string + sigFormat string + compatible bool + } + testcases := []*testcase{ + { + KeyAlgoRSA, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSA, + KeyAlgoRSASHA256, + true, + }, + { + KeyAlgoRSA, + KeyAlgoRSASHA512, + true, + }, + { + KeyAlgoRSASHA256, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSASHA256, + true, + }, + { + KeyAlgoRSASHA256, + KeyAlgoRSASHA512, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSA, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSASHA256, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA256v01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA512v01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA512v01, + KeyAlgoRSASHA256, + true, + }, + { + CertAlgoRSASHA256v01, + CertAlgoRSAv01, + true, + }, + { + CertAlgoRSAv01, + CertAlgoRSASHA512v01, + true, + }, + { + KeyAlgoECDSA256, + KeyAlgoRSA, + false, + }, + { + KeyAlgoECDSA256, + KeyAlgoECDSA521, + false, + }, + { + KeyAlgoECDSA256, + KeyAlgoECDSA256, + true, + }, + { + KeyAlgoECDSA256, + KeyAlgoED25519, + false, + }, + { + KeyAlgoED25519, + KeyAlgoED25519, + true, + }, + } + + for _, c := range testcases { + if isAlgoCompatible(c.algo, c.sigFormat) != c.compatible { + t.Errorf("algorithm %q, signature format %q, expected compatible to be %t", c.algo, c.sigFormat, c.compatible) + } + } +} + +func TestPickSignatureAlgorithm(t *testing.T) { + type testcase struct { + name string + extensions map[string][]byte + } + cases := []testcase{ + { + name: "server with empty server-sig-algs", + extensions: map[string][]byte{ + "server-sig-algs": []byte(``), + }, + }, + { + name: "server with no server-sig-algs", + extensions: nil, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + signer, ok := testSigners["rsa"].(MultiAlgorithmSigner) + if !ok { + t.Fatalf("rsa test signer does not implement the MultiAlgorithmSigner interface") + } + // The signer supports the public key algorithm which is then returned. + _, algo, err := pickSignatureAlgorithm(signer, c.extensions) + if err != nil { + t.Fatalf("got %v, want no error", err) + } + if algo != signer.PublicKey().Type() { + t.Fatalf("got algo %q, want %q", algo, signer.PublicKey().Type()) + } + // Test a signer that uses a certificate algorithm as the public key + // type. + cert := &Certificate{ + CertType: UserCert, + Key: signer.PublicKey(), + } + cert.SignCert(rand.Reader, signer) + + certSigner, err := NewCertSigner(cert, signer) + if err != nil { + t.Fatalf("error generating cert signer: %v", err) + } + // The signer supports the public key algorithm and the + // public key format is a certificate type so the cerificate + // algorithm matching the key format must be returned + _, algo, err = pickSignatureAlgorithm(certSigner, c.extensions) + if err != nil { + t.Fatalf("got %v, want no error", err) + } + if algo != certSigner.PublicKey().Type() { + t.Fatalf("got algo %q, want %q", algo, certSigner.PublicKey().Type()) + } + signer, err = NewSignerWithAlgorithms(signer.(AlgorithmSigner), []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create signer with algorithms: %v", err) + } + // The signer does not support the public key algorithm so an error + // is returned. + _, _, err = pickSignatureAlgorithm(signer, c.extensions) + if err == nil { + t.Fatal("got no error, no common public key signature algorithm error expected") + } + }) + } +} + +// configurablePublicKeyCallback is a public key callback that allows to +// configure the signature algorithm and format. This way we can emulate the +// behavior of buggy clients. +type configurablePublicKeyCallback struct { + signer AlgorithmSigner + signatureAlgo string + signatureFormat string +} + +func (cb configurablePublicKeyCallback) method() string { + return "publickey" +} + +func (cb configurablePublicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) { + pub := cb.signer.PublicKey() + + ok, err := validateKey(pub, cb.signatureAlgo, user, c) + if err != nil { + return authFailure, nil, err + } + if !ok { + return authFailure, nil, fmt.Errorf("invalid public key") + } + + pubKey := pub.Marshal() + data := buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, cb.signatureAlgo, pubKey) + sign, err := cb.signer.SignWithAlgorithm(rand, data, underlyingAlgo(cb.signatureFormat)) + if err != nil { + return authFailure, nil, err + } + + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: cb.signatureAlgo, + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err := handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + if success == authSuccess || !contains(methods, cb.method()) { + return success, methods, err + } + + return authFailure, methods, nil +} + +func TestPublicKeyAndAlgoCompatibility(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + configurablePublicKeyCallback{ + signer: certSigner.(AlgorithmSigner), + signatureAlgo: KeyAlgoRSASHA256, + signatureFormat: KeyAlgoRSASHA256, + }, + }, + } + if err := tryAuth(t, clientConfig); err == nil { + t.Error("cert login passed with incompatible public key type and algorithm") + } +} + +func TestClientAuthGPGAgentCompat(t *testing.T) { + clientConfig := &ClientConfig{ + User: "testuser", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + // algorithm rsa-sha2-512 and signature format ssh-rsa. + configurablePublicKeyCallback{ + signer: testSigners["rsa"].(AlgorithmSigner), + signatureAlgo: KeyAlgoRSASHA512, + signatureFormat: KeyAlgoRSA, + }, + }, + } + if err := tryAuth(t, clientConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestCertAuthOpenSSHCompat(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + // algorithm ssh-rsa-cert-v01@openssh.com and signature format + // rsa-sha2-256. + configurablePublicKeyCallback{ + signer: certSigner.(AlgorithmSigner), + signatureAlgo: CertAlgoRSAv01, + signatureFormat: KeyAlgoRSASHA256, + }, + }, + } + if err := tryAuth(t, clientConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) { + const maxAuthTries = 2 + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // Start testserver + serverConfig := &ServerConfig{ + MaxAuthTries: maxAuthTries, + KeyboardInteractiveCallback: func(c ConnMetadata, + client KeyboardInteractiveChallenge) (*Permissions, error) { + // Fail keyboard-interactive authentication early before + // any prompt is sent to client. + return nil, errors.New("keyboard-interactive auth failed") + }, + PasswordCallback: func(c ConnMetadata, + pass []byte) (*Permissions, error) { + if string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + serverDone := make(chan struct{}) + go func() { + defer func() { serverDone <- struct{}{} }() + conn, chans, reqs, err := NewServerConn(c2, serverConfig) + if err != nil { + return + } + _ = conn.Close() + + discarderDone := make(chan struct{}) + go func() { + defer func() { discarderDone <- struct{}{} }() + DiscardRequests(reqs) + }() + for newChannel := range chans { + newChannel.Reject(Prohibited, + "testserver not accepting requests") + } + + <-discarderDone + }() + + // Connect to testserver, expect KeyboardInteractive() to be not called, + // PasswordCallback() to be called and connection to succeed. + passwordCallbackCalled := false + clientConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractive(func(name, + instruction string, questions []string, + echos []bool) ([]string, error) { + t.Errorf("unexpected call to KeyboardInteractive()") + return []string{clientPassword}, nil + }), maxAuthTries), + RetryableAuthMethod(PasswordCallback(func() (secret string, + err error) { + t.Logf("PasswordCallback()") + passwordCallbackCalled = true + return clientPassword, nil + }), maxAuthTries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + conn, _, _, err := NewClientConn(c1, "", clientConfig) + if err != nil { + t.Errorf("unexpected NewClientConn() error: %v", err) + } + if conn != nil { + conn.Close() + } + + // Wait for server to finish. + <-serverDone + + if !passwordCallbackCalled { + t.Errorf("expected PasswordCallback() to be called") + } +} diff --git a/tempfork/sshtest/ssh/client_test.go b/tempfork/sshtest/ssh/client_test.go new file mode 100644 index 000000000..2621f0ea5 --- /dev/null +++ b/tempfork/sshtest/ssh/client_test.go @@ -0,0 +1,367 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "strings" + "testing" +) + +func TestClientVersion(t *testing.T) { + for _, tt := range []struct { + name string + version string + multiLine string + wantErr bool + }{ + { + name: "default version", + version: packageVersion, + }, + { + name: "custom version", + version: "SSH-2.0-CustomClientVersionString", + }, + { + name: "good multi line version", + version: packageVersion, + multiLine: strings.Repeat("ignored\r\n", 20), + }, + { + name: "bad multi line version", + version: packageVersion, + multiLine: "bad multi line version", + wantErr: true, + }, + { + name: "long multi line version", + version: packageVersion, + multiLine: strings.Repeat("long multi line version\r\n", 50)[:256], + wantErr: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go func() { + if tt.multiLine != "" { + c1.Write([]byte(tt.multiLine)) + } + NewClientConn(c1, "", &ClientConfig{ + ClientVersion: tt.version, + HostKeyCallback: InsecureIgnoreHostKey(), + }) + c1.Close() + }() + conf := &ServerConfig{NoClientAuth: true} + conf.AddHostKey(testSigners["rsa"]) + conn, _, _, err := NewServerConn(c2, conf) + if err == nil == tt.wantErr { + t.Fatalf("got err %v; wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + // Don't verify the version on an expected error. + return + } + if got := string(conn.ClientVersion()); got != tt.version { + t.Fatalf("got %q; want %q", got, tt.version) + } + }) + } +} + +func TestHostKeyCheck(t *testing.T) { + for _, tt := range []struct { + name string + wantError string + key PublicKey + }{ + {"no callback", "must specify HostKeyCallback", nil}, + {"correct key", "", testSigners["rsa"].PublicKey()}, + {"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + + go NewServerConn(c1, serverConf) + clientConf := ClientConfig{ + User: "user", + } + if tt.key != nil { + clientConf.HostKeyCallback = FixedHostKey(tt.key) + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) + } + } +} + +func TestVerifyHostKeySignature(t *testing.T) { + for _, tt := range []struct { + key string + signAlgo string + verifyAlgo string + wantError string + }{ + {"rsa", KeyAlgoRSA, KeyAlgoRSA, ""}, + {"rsa", KeyAlgoRSASHA256, KeyAlgoRSASHA256, ""}, + {"rsa", KeyAlgoRSA, KeyAlgoRSASHA512, `ssh: invalid signature algorithm "ssh-rsa", expected "rsa-sha2-512"`}, + {"ed25519", KeyAlgoED25519, KeyAlgoED25519, ""}, + } { + key := testSigners[tt.key].PublicKey() + s, ok := testSigners[tt.key].(AlgorithmSigner) + if !ok { + t.Fatalf("needed an AlgorithmSigner") + } + sig, err := s.SignWithAlgorithm(rand.Reader, []byte("test"), tt.signAlgo) + if err != nil { + t.Fatalf("couldn't sign: %q", err) + } + + b := bytes.Buffer{} + writeString(&b, []byte(sig.Format)) + writeString(&b, sig.Blob) + + result := kexResult{Signature: b.Bytes(), H: []byte("test")} + + err = verifyHostKeySignature(key, tt.verifyAlgo, &result) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("got error %q, expecting %q", err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("succeeded, but want error string %q", tt.wantError) + } + } +} + +func TestBannerCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + BannerCallback: func(conn ConnMetadata) string { + return "Hello World" + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + var receivedBanner string + var bannerCount int + clientConf := ClientConfig{ + Auth: []AuthMethod{ + Password("123"), + }, + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(message string) error { + bannerCount++ + receivedBanner = message + return nil + }, + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + + if bannerCount != 1 { + t.Errorf("got %d banners; want 1", bannerCount) + } + + expected := "Hello World" + if receivedBanner != expected { + t.Fatalf("got %s; want %s", receivedBanner, expected) + } +} + +func TestNewClientConn(t *testing.T) { + errHostKeyMismatch := errors.New("host key mismatch") + + for _, tt := range []struct { + name string + user string + simulateHostKeyMismatch HostKeyCallback + }{ + { + name: "good user field for ConnMetadata", + user: "testuser", + }, + { + name: "empty user field for ConnMetadata", + user: "", + }, + { + name: "host key mismatch", + user: "testuser", + simulateHostKeyMismatch: func(hostname string, remote net.Addr, key PublicKey) error { + return fmt.Errorf("%w: %s", errHostKeyMismatch, bytes.TrimSpace(MarshalAuthorizedKey(key))) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: tt.user, + Auth: []AuthMethod{ + Password("testpw"), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if tt.simulateHostKeyMismatch != nil { + clientConf.HostKeyCallback = tt.simulateHostKeyMismatch + } + + clientConn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + if tt.simulateHostKeyMismatch != nil && errors.Is(err, errHostKeyMismatch) { + return + } + t.Fatal(err) + } + + if userGot := clientConn.User(); userGot != tt.user { + t.Errorf("got user %q; want user %q", userGot, tt.user) + } + }) + } +} + +func TestUnsupportedAlgorithm(t *testing.T) { + for _, tt := range []struct { + name string + config Config + wantError string + }{ + { + "unsupported KEX", + Config{ + KeyExchanges: []string{"unsupported"}, + }, + "no common algorithm", + }, + { + "unsupported and supported KEXs", + Config{ + KeyExchanges: []string{"unsupported", kexAlgoCurve25519SHA256}, + }, + "", + }, + { + "unsupported cipher", + Config{ + Ciphers: []string{"unsupported"}, + }, + "no common algorithm", + }, + { + "unsupported and supported ciphers", + Config{ + Ciphers: []string{"unsupported", chacha20Poly1305ID}, + }, + "", + }, + { + "unsupported MAC", + Config{ + MACs: []string{"unsupported"}, + // MAC is used for non AAED ciphers. + Ciphers: []string{"aes256-ctr"}, + }, + "no common algorithm", + }, + { + "unsupported and supported MACs", + Config{ + MACs: []string{"unsupported", "hmac-sha2-256-etm@openssh.com"}, + // MAC is used for non AAED ciphers. + Ciphers: []string{"aes256-ctr"}, + }, + "", + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + Config: tt.config, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "testuser", + Config: tt.config, + Auth: []AuthMethod{ + Password("testpw"), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/common.go b/tempfork/sshtest/ssh/common.go new file mode 100644 index 000000000..7e9c2cbc6 --- /dev/null +++ b/tempfork/sshtest/ssh/common.go @@ -0,0 +1,476 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/rand" + "fmt" + "io" + "math" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// supportedCiphers lists ciphers we support but might not recommend. +var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", gcm256CipherID, + chacha20Poly1305ID, + "arcfour256", "arcfour128", "arcfour", + aes128cbcID, + tripledescbcID, +} + +// preferredCiphers specifies the default preference for ciphers. +var preferredCiphers = []string{ + "aes128-gcm@openssh.com", gcm256CipherID, + chacha20Poly1305ID, + "aes128-ctr", "aes192-ctr", "aes256-ctr", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. +var supportedKexAlgos = []string{ + kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1, + kexAlgoDH1SHA1, +} + +// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden +// for the server half. +var serverForbiddenKexAlgos = map[string]struct{}{ + kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests + kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests +} + +// preferredKexAlgos specifies the default preference for key-exchange +// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm +// is disabled by default because it is a bit slower than the others. +var preferredKexAlgos = []string{ + kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH, + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA256, kexAlgoDH14SHA1, +} + +// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. +var supportedHostKeyAlgos = []string{ + CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA256, KeyAlgoRSASHA512, + KeyAlgoRSA, KeyAlgoDSA, + + KeyAlgoED25519, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported signature algorithms to their +// respective hashes needed for signing and verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoRSASHA256: crypto.SHA256, + KeyAlgoRSASHA512: crypto.SHA512, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + // KeyAlgoED25519 doesn't pre-hash. + KeyAlgoSKECDSA256: crypto.SHA256, + KeyAlgoSKED25519: crypto.SHA256, +} + +// algorithmsForKeyFormat returns the supported signature algorithms for a given +// public key format (PublicKey.Type), in order of preference. See RFC 8332, +// Section 2. See also the note in sendKexInit on backwards compatibility. +func algorithmsForKeyFormat(keyFormat string) []string { + switch keyFormat { + case KeyAlgoRSA: + return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA} + case CertAlgoRSAv01: + return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01} + default: + return []string{keyFormat} + } +} + +// isRSA returns whether algo is a supported RSA algorithm, including certificate +// algorithms. +func isRSA(algo string) bool { + algos := algorithmsForKeyFormat(KeyAlgoRSA) + return contains(algos, underlyingAlgo(algo)) +} + +func isRSACert(algo string) bool { + _, ok := certKeyAlgoNames[algo] + if !ok { + return false + } + return isRSA(algo) +} + +// supportedPubKeyAuthAlgos specifies the supported client public key +// authentication algorithms. Note that this doesn't include certificate types +// since those use the underlying algorithm. This list is sent to the client if +// it supports the server-sig-algs extension. Order is irrelevant. +var supportedPubKeyAuthAlgos = []string{ + KeyAlgoED25519, + KeyAlgoSKED25519, KeyAlgoSKECDSA256, + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA, + KeyAlgoDSA, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommon(what string, client []string, server []string) (common string, err error) { + for _, c := range client { + for _, s := range server { + if c == s { + return c, nil + } + } + } + return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) +} + +// directionAlgorithms records algorithm choices in one direction (either read or write) +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + +// rekeyBytes returns a rekeying intervals in bytes. +func (a *directionAlgorithms) rekeyBytes() int64 { + // According to RFC 4344 block ciphers should rekey after + // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is + // 128. + switch a.Cipher { + case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID: + return 16 * (1 << 32) + + } + + // For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data. + return 1 << 30 +} + +var aeadCiphers = map[string]bool{ + gcm128CipherID: true, + gcm256CipherID: true, + chacha20Poly1305ID: true, +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { + result := &algorithms{} + + result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if err != nil { + return + } + + result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if err != nil { + return + } + + stoc, ctos := &result.w, &result.r + if isClient { + ctos, stoc = stoc, ctos + } + + ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if err != nil { + return + } + + stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if err != nil { + return + } + + if !aeadCiphers[ctos.Cipher] { + ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if err != nil { + return + } + } + + if !aeadCiphers[stoc.Cipher] { + stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if err != nil { + return + } + } + + ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if err != nil { + return + } + + stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) + if err != nil { + return + } + + return result, nil +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, a size suitable for the chosen cipher is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a default set + // of algorithms is used. Unsupported values are silently ignored. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible default is + // used. Unsupported values are silently ignored. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default is + // used. Unsupported values are silently ignored. + MACs []string +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = preferredCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // Ignore the cipher if we have no cipherModes definition. + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = preferredKexAlgos + } + var kexs []string + for _, k := range c.KeyExchanges { + if kexAlgoMap[k] != nil { + // Ignore the KEX if we have no kexAlgoMap definition. + kexs = append(kexs, k) + } + } + c.KeyExchanges = kexs + + if c.MACs == nil { + c.MACs = supportedMACs + } + var macs []string + for _, m := range c.MACs { + if macModes[m] != nil { + // Ignore the MAC if we have no macModes definition. + macs = append(macs, m) + } + } + c.MACs = macs + + if c.RekeyThreshold == 0 { + // cipher specific default + } else if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } else if c.RekeyThreshold >= math.MaxInt64 { + // Avoid weirdness if somebody uses -1 as a threshold. + c.RekeyThreshold = math.MaxInt64 + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. algo is the advertised +// algorithm, and may be a certificate type. +func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo string + PubKey []byte + }{ + sessionID, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. +func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/tempfork/sshtest/ssh/common_test.go b/tempfork/sshtest/ssh/common_test.go new file mode 100644 index 000000000..a7beee8e8 --- /dev/null +++ b/tempfork/sshtest/ssh/common_test.go @@ -0,0 +1,176 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "reflect" + "testing" +) + +func TestFindAgreedAlgorithms(t *testing.T) { + initKex := func(k *kexInitMsg) { + if k.KexAlgos == nil { + k.KexAlgos = []string{"kex1"} + } + if k.ServerHostKeyAlgos == nil { + k.ServerHostKeyAlgos = []string{"hostkey1"} + } + if k.CiphersClientServer == nil { + k.CiphersClientServer = []string{"cipher1"} + + } + if k.CiphersServerClient == nil { + k.CiphersServerClient = []string{"cipher1"} + + } + if k.MACsClientServer == nil { + k.MACsClientServer = []string{"mac1"} + + } + if k.MACsServerClient == nil { + k.MACsServerClient = []string{"mac1"} + + } + if k.CompressionClientServer == nil { + k.CompressionClientServer = []string{"compression1"} + + } + if k.CompressionServerClient == nil { + k.CompressionServerClient = []string{"compression1"} + + } + if k.LanguagesClientServer == nil { + k.LanguagesClientServer = []string{"language1"} + + } + if k.LanguagesServerClient == nil { + k.LanguagesServerClient = []string{"language1"} + + } + } + + initDirAlgs := func(a *directionAlgorithms) { + if a.Cipher == "" { + a.Cipher = "cipher1" + } + if a.MAC == "" { + a.MAC = "mac1" + } + if a.Compression == "" { + a.Compression = "compression1" + } + } + + initAlgs := func(a *algorithms) { + if a.kex == "" { + a.kex = "kex1" + } + if a.hostKey == "" { + a.hostKey = "hostkey1" + } + initDirAlgs(&a.r) + initDirAlgs(&a.w) + } + + type testcase struct { + name string + clientIn, serverIn kexInitMsg + wantClient, wantServer algorithms + wantErr bool + } + + cases := []testcase{ + { + name: "standard", + }, + + { + name: "no common hostkey", + serverIn: kexInitMsg{ + ServerHostKeyAlgos: []string{"hostkey2"}, + }, + wantErr: true, + }, + + { + name: "no common kex", + serverIn: kexInitMsg{ + KexAlgos: []string{"kex2"}, + }, + wantErr: true, + }, + + { + name: "no common cipher", + serverIn: kexInitMsg{ + CiphersClientServer: []string{"cipher2"}, + }, + wantErr: true, + }, + + { + name: "client decides cipher", + serverIn: kexInitMsg{ + CiphersClientServer: []string{"cipher1", "cipher2"}, + CiphersServerClient: []string{"cipher2", "cipher3"}, + }, + clientIn: kexInitMsg{ + CiphersClientServer: []string{"cipher2", "cipher1"}, + CiphersServerClient: []string{"cipher3", "cipher2"}, + }, + wantClient: algorithms{ + r: directionAlgorithms{ + Cipher: "cipher3", + }, + w: directionAlgorithms{ + Cipher: "cipher2", + }, + }, + wantServer: algorithms{ + w: directionAlgorithms{ + Cipher: "cipher3", + }, + r: directionAlgorithms{ + Cipher: "cipher2", + }, + }, + }, + + // TODO(hanwen): fix and add tests for AEAD ignoring + // the MACs field + } + + for i := range cases { + initKex(&cases[i].clientIn) + initKex(&cases[i].serverIn) + initAlgs(&cases[i].wantClient) + initAlgs(&cases[i].wantServer) + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn) + clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn) + + serverHasErr := serverErr != nil + clientHasErr := clientErr != nil + if c.wantErr != serverHasErr || c.wantErr != clientHasErr { + t.Fatalf("got client/server error (%v, %v), want hasError %v", + clientErr, serverErr, c.wantErr) + + } + if c.wantErr { + return + } + + if !reflect.DeepEqual(serverAlgs, &c.wantServer) { + t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer) + } + if !reflect.DeepEqual(clientAlgs, &c.wantClient) { + t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/connection.go b/tempfork/sshtest/ssh/connection.go new file mode 100644 index 000000000..8f345ee92 --- /dev/null +++ b/tempfork/sshtest/ssh/connection.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + User() string + + // SessionID returns the session hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the server's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC 4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. + OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshConn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/tempfork/sshtest/ssh/doc.go b/tempfork/sshtest/ssh/doc.go new file mode 100644 index 000000000..f5d352fe3 --- /dev/null +++ b/tempfork/sshtest/ssh/doc.go @@ -0,0 +1,23 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + + [PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 + +This package does not fall under the stability promise of the Go language itself, +so its API may be changed when pressing needs arise. +*/ +package ssh diff --git a/tempfork/sshtest/ssh/example_test.go b/tempfork/sshtest/ssh/example_test.go new file mode 100644 index 000000000..97b3b6aba --- /dev/null +++ b/tempfork/sshtest/ssh/example_test.go @@ -0,0 +1,400 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh_test + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/rsa" + "fmt" + "log" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +func ExampleNewServerConn() { + // Public key authentication is done by comparing + // the public key of a received connection + // with the entries in the authorized_keys file. + authorizedKeysBytes, err := os.ReadFile("authorized_keys") + if err != nil { + log.Fatalf("Failed to load authorized_keys, err: %v", err) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + log.Fatal(err) + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + // Remove to disable password auth. + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + + // Remove to disable public key auth. + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake: ", err) + } + log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) + + var wg sync.WaitGroup + defer wg.Wait() + + // The incoming Request channel must be serviced. + wg.Add(1) + go func() { + ssh.DiscardRequests(reqs) + wg.Done() + }() + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + wg.Add(1) + go func(in <-chan *ssh.Request) { + for req := range in { + req.Reply(req.Type == "shell", nil) + } + wg.Done() + }(requests) + + term := terminal.NewTerminal(channel, "> ") + + wg.Add(1) + go func() { + defer func() { + channel.Close() + wg.Done() + }() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + } +} + +func ExampleServerConfig_AddHostKey() { + // Minimal ServerConfig supporting only password authentication. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + // Restrict host key algorithms to disable ssh-rsa. + signer, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512}) + if err != nil { + log.Fatal("Failed to create private key with restricted algorithms: ", err) + } + config.AddHostKey(signer) +} + +func ExampleClientConfig_HostKeyCallback() { + // Every client must provide a host key check. Here is a + // simple-minded parse of OpenSSH's known_hosts file + host := "hostname" + file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + var hostKey ssh.PublicKey + for scanner.Scan() { + fields := strings.Split(scanner.Text(), " ") + if len(fields) != 3 { + continue + } + if strings.Contains(fields[0], host) { + var err error + hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) + if err != nil { + log.Fatalf("error parsing %q: %v", fields[2], err) + } + break + } + } + + if hostKey == nil { + log.Fatalf("no hostkey for %s", host) + } + + config := ssh.ClientConfig{ + User: os.Getenv("USER"), + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + _, err = ssh.Dial("tcp", host+":22", &config) + log.Println(err) +} + +func ExampleDial() { + var hostKey ssh.PublicKey + // An SSH client is represented with a ClientConn. + // + // To authenticate with the remote server you must pass at least one + // implementation of AuthMethod via the Auth field in ClientConfig, + // and provide a HostKeyCallback. + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + client, err := ssh.Dial("tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + defer client.Close() + + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := client.NewSession() + if err != nil { + log.Fatal("Failed to create session: ", err) + } + defer session.Close() + + // Once a Session is created, you can execute a single command on + // the remote side using the Run method. + var b bytes.Buffer + session.Stdout = &b + if err := session.Run("/usr/bin/whoami"); err != nil { + log.Fatal("Failed to run: " + err.Error()) + } + fmt.Println(b.String()) +} + +func ExamplePublicKeys() { + var hostKey ssh.PublicKey + // A public key may be used to authenticate against the remote + // server by using an unencrypted PEM-encoded private key file. + // + // If you have an encrypted private key, the crypto/x509 package + // can be used to decrypt it. + key, err := os.ReadFile("/home/user/.ssh/id_rsa") + if err != nil { + log.Fatalf("unable to read private key: %v", err) + } + + // Create the Signer for this private key. + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + log.Fatalf("unable to parse private key: %v", err) + } + + config := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + // Use the PublicKeys method for remote authentication. + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + // Connect to the remote server and perform the SSH handshake. + client, err := ssh.Dial("tcp", "host.com:22", config) + if err != nil { + log.Fatalf("unable to connect: %v", err) + } + defer client.Close() +} + +func ExampleClient_Listen() { + var hostKey ssh.PublicKey + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + // Dial your ssh server. + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + + // Request the remote side to open port 8080 on all interfaces. + l, err := conn.Listen("tcp", "0.0.0.0:8080") + if err != nil { + log.Fatal("unable to register tcp forward: ", err) + } + defer l.Close() + + // Serve HTTP with your SSH server acting as a reverse proxy. + http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + fmt.Fprintf(resp, "Hello world!\n") + })) +} + +func ExampleSession_RequestPty() { + var hostKey ssh.PublicKey + // Create client config + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + // Connect to ssh server + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + // Create a session + session, err := conn.NewSession() + if err != nil { + log.Fatal("unable to create session: ", err) + } + defer session.Close() + // Set up terminal modes + modes := ssh.TerminalModes{ + ssh.ECHO: 0, // disable echoing + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + // Request pseudo terminal + if err := session.RequestPty("xterm", 40, 80, modes); err != nil { + log.Fatal("request for pseudo terminal failed: ", err) + } + // Start remote shell + if err := session.Shell(); err != nil { + log.Fatal("failed to start shell: ", err) + } +} + +func ExampleCertificate_SignCert() { + // Sign a certificate with a specific algorithm. + privateKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + log.Fatal("unable to generate RSA key: ", err) + } + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + log.Fatal("unable to get RSA public key: ", err) + } + caKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + log.Fatal("unable to generate CA key: ", err) + } + signer, err := ssh.NewSignerFromKey(caKey) + if err != nil { + log.Fatal("unable to generate signer from key: ", err) + } + mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256}) + if err != nil { + log.Fatal("unable to create signer with algorithms: ", err) + } + certificate := ssh.Certificate{ + Key: publicKey, + CertType: ssh.UserCert, + } + if err := certificate.SignCert(rand.Reader, mas); err != nil { + log.Fatal("unable to sign certificate: ", err) + } + // Save the public key to a file and check that rsa-sha-256 is used for + // signing: + // ssh-keygen -L -f + fmt.Println(string(ssh.MarshalAuthorizedKey(&certificate))) +} diff --git a/tempfork/sshtest/ssh/handshake.go b/tempfork/sshtest/ssh/handshake.go new file mode 100644 index 000000000..fef687db0 --- /dev/null +++ b/tempfork/sshtest/ssh/handshake.go @@ -0,0 +1,816 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "strings" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// chanSize sets the amount of buffering SSH connections. This is +// primarily for testing: setting chanSize=0 uncovers deadlocks more +// quickly. +const chanSize = 16 + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error + + // setStrictMode sets the strict KEX mode, notably triggering + // sequence number resets on sending or receiving msgNewKeys. + // If the sequence number is already > 1 when setStrictMode + // is called, an error is returned. + setStrictMode() error + + // setInitialKEXDone indicates to the transport that the initial key exchange + // was completed + setInitialKEXDone() +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + // hostKeys is non-empty if we are the server. In that case, + // it contains all host keys that can be used to sign the + // connection. + hostKeys []Signer + + // publicKeyAuthAlgorithms is non-empty if we are the server. In that case, + // it contains the supported client public key authentication algorithms. + publicKeyAuthAlgorithms []string + + // hostKeyAlgorithms is non-empty if we are the client. In that case, + // we accept these key types from the server as host key. + hostKeyAlgorithms []string + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + mu sync.Mutex + writeError error + sentInitPacket []byte + sentInitMsg *kexInitMsg + pendingPackets [][]byte // Used when a key exchange is in progress. + writePacketsLeft uint32 + writeBytesLeft int64 + userAuthComplete bool // whether the user authentication phase is complete + + // If the read loop wants to schedule a kex, it pings this + // channel, and the write loop will send out a kex + // message. + requestKex chan struct{} + + // If the other side requests or confirms a kex, its kexInit + // packet is sent here for the write loop to find it. + startKex chan *pendingKex + kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits + + // data for host key checking + hostKeyCallback HostKeyCallback + dialAddress string + remoteAddr net.Addr + + // bannerCallback is non-empty if we are the client and it has been set in + // ClientConfig. In that case it is called during the user authentication + // dance to handle a custom server's message. + bannerCallback BannerCallback + + // Algorithms agreed in the last key exchange. + algorithms *algorithms + + // Counters exclusively owned by readLoop. + readPacketsLeft uint32 + readBytesLeft int64 + + // The session ID or nil if first kex did not complete yet. + sessionID []byte + + // strictMode indicates if the other side of the handshake indicated + // that we should be following the strict KEX protocol restrictions. + strictMode bool +} + +type pendingKex struct { + otherInit []byte + done chan error +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, chanSize), + requestKex: make(chan struct{}, 1), + startKex: make(chan *pendingKex), + kexLoopDone: make(chan struct{}), + + config: config, + } + t.resetReadThresholds() + t.resetWriteThresholds() + + // We always start with a mandatory key exchange. + t.requestKex <- struct{}{} + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + t.bannerCallback = config.BannerCallback + if config.HostKeyAlgorithms != nil { + t.hostKeyAlgorithms = config.HostKeyAlgorithms + } else { + t.hostKeyAlgorithms = supportedHostKeyAlgos + } + go t.readLoop() + go t.kexLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms + go t.readLoop() + go t.kexLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.sessionID +} + +// waitSession waits for the session to be established. This should be +// the first thing to call after instantiating handshakeTransport. +func (t *handshakeTransport) waitSession() error { + p, err := t.readPacket() + if err != nil { + return err + } + if p[0] != msgNewKeys { + return fmt.Errorf("ssh: first packet should be msgNewKeys") + } + + return nil +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) printPacket(p []byte, write bool) { + action := "got" + if write { + action = "sent" + } + + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) + } else { + msg, err := decode(p) + log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err) + } +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + first := true + for { + p, err := t.readOnePacket(first) + first = false + if err != nil { + t.readError = err + close(t.incoming) + break + } + // If this is the first kex, and strict KEX mode is enabled, + // we don't ignore any messages, as they may be used to manipulate + // the packet sequence numbers. + if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) { + continue + } + t.incoming <- p + } + + // Stop writers too. + t.recordWriteError(t.readError) + + // Unblock the writer should it wait for this. + close(t.startKex) + + // Don't close t.requestKex; it's also written to from writePacket. +} + +func (t *handshakeTransport) pushPacket(p []byte) error { + if debugHandshake { + t.printPacket(p, true) + } + return t.conn.writePacket(p) +} + +func (t *handshakeTransport) getWriteError() error { + t.mu.Lock() + defer t.mu.Unlock() + return t.writeError +} + +func (t *handshakeTransport) recordWriteError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil && err != nil { + t.writeError = err + } +} + +func (t *handshakeTransport) requestKeyExchange() { + select { + case t.requestKex <- struct{}{}: + default: + // something already requested a kex, so do nothing. + } +} + +func (t *handshakeTransport) resetWriteThresholds() { + t.writePacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.writeBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.writeBytesLeft = t.algorithms.w.rekeyBytes() + } else { + t.writeBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) kexLoop() { + +write: + for t.getWriteError() == nil { + var request *pendingKex + var sent bool + + for request == nil || !sent { + var ok bool + select { + case request, ok = <-t.startKex: + if !ok { + break write + } + case <-t.requestKex: + break + } + + if !sent { + if err := t.sendKexInit(); err != nil { + t.recordWriteError(err) + break + } + sent = true + } + } + + if err := t.getWriteError(); err != nil { + if request != nil { + request.done <- err + } + break + } + + // We're not servicing t.requestKex, but that is OK: + // we never block on sending to t.requestKex. + + // We're not servicing t.startKex, but the remote end + // has just sent us a kexInitMsg, so it can't send + // another key change request, until we close the done + // channel on the pendingKex request. + + err := t.enterKeyExchange(request.otherInit) + + t.mu.Lock() + t.writeError = err + t.sentInitPacket = nil + t.sentInitMsg = nil + + t.resetWriteThresholds() + + // we have completed the key exchange. Since the + // reader is still blocked, it is safe to clear out + // the requestKex channel. This avoids the situation + // where: 1) we consumed our own request for the + // initial kex, and 2) the kex from the remote side + // caused another send on the requestKex channel, + clear: + for { + select { + case <-t.requestKex: + // + default: + break clear + } + } + + request.done <- t.writeError + + // kex finished. Push packets that we received while + // the kex was in progress. Don't look at t.startKex + // and don't increment writtenSinceKex: if we trigger + // another kex while we are still busy with the last + // one, things will become very confusing. + for _, p := range t.pendingPackets { + t.writeError = t.pushPacket(p) + if t.writeError != nil { + break + } + } + t.pendingPackets = t.pendingPackets[:0] + t.mu.Unlock() + } + + // Unblock reader. + t.conn.Close() + + // drain startKex channel. We don't service t.requestKex + // because nobody does blocking sends there. + for request := range t.startKex { + request.done <- t.getWriteError() + } + + // Mark that the loop is done so that Close can return. + close(t.kexLoopDone) +} + +// The protocol uses uint32 for packet counters, so we can't let them +// reach 1<<32. We will actually read and write more packets than +// this, though: the other side may send more packets, and after we +// hit this limit on writing we will send a few more packets for the +// key exchange itself. +const packetRekeyThreshold = (1 << 31) + +func (t *handshakeTransport) resetReadThresholds() { + t.readPacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.readBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.readBytesLeft = t.algorithms.r.rekeyBytes() + } else { + t.readBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + if t.readPacketsLeft > 0 { + t.readPacketsLeft-- + } else { + t.requestKeyExchange() + } + + if t.readBytesLeft > 0 { + t.readBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if debugHandshake { + t.printPacket(p, false) + } + + if first && p[0] != msgKexInit { + return nil, fmt.Errorf("ssh: first packet should be msgKexInit") + } + + if p[0] != msgKexInit { + return p, nil + } + + firstKex := t.sessionID == nil + + kex := pendingKex{ + done: make(chan error, 1), + otherInit: p, + } + t.startKex <- &kex + err = <-kex.done + + if debugHandshake { + log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) + } + + if err != nil { + return nil, err + } + + t.resetReadThresholds() + + // By default, a key exchange is hidden from higher layers by + // translating it into msgIgnore. + successPacket := []byte{msgIgnore} + if firstKex { + // sendKexInit() for the first kex waits for + // msgNewKeys so the authentication process is + // guaranteed to happen over an encrypted transport. + successPacket = []byte{msgNewKeys} + } + + return successPacket, nil +} + +const ( + kexStrictClient = "kex-strict-c-v00@openssh.com" + kexStrictServer = "kex-strict-s-v00@openssh.com" +) + +// sendKexInit sends a key change message. +func (t *handshakeTransport) sendKexInit() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.sentInitMsg != nil { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + return nil + } + + msg := &kexInitMsg{ + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + // We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm, + // and possibly to add the ext-info extension algorithm. Since the slice may be the + // user owned KeyExchanges, we create our own slice in order to avoid using user + // owned memory by mistake. + msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info + msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) + + isServer := len(t.hostKeys) > 0 + if isServer { + for _, k := range t.hostKeys { + // If k is a MultiAlgorithmSigner, we restrict the signature + // algorithms. If k is a AlgorithmSigner, presume it supports all + // signature algorithms associated with the key format. If k is not + // an AlgorithmSigner, we can only assume it only supports the + // algorithms that matches the key format. (This means that Sign + // can't pick a different default). + keyFormat := k.PublicKey().Type() + + switch s := k.(type) { + case MultiAlgorithmSigner: + for _, algo := range algorithmsForKeyFormat(keyFormat) { + if contains(s.Algorithms(), underlyingAlgo(algo)) { + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo) + } + } + case AlgorithmSigner: + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...) + default: + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) + } + } + + if t.sessionID == nil { + msg.KexAlgos = append(msg.KexAlgos, kexStrictServer) + } + } else { + msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + + // As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what + // algorithms the server supports for public key authentication. See RFC + // 8308, Section 2.1. + // + // We also send the strict KEX mode extension algorithm, in order to opt + // into the strict KEX mode. + if firstKeyExchange := t.sessionID == nil; firstKeyExchange { + msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") + msg.KexAlgos = append(msg.KexAlgos, kexStrictClient) + } + + } + + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.pushPacket(packetCopy); err != nil { + return err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + + return nil +} + +var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase") + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + case msgUserAuthBanner: + if t.userAuthComplete { + return errSendBannerPhase + } + case msgUserAuthSuccess: + t.userAuthComplete = true + } + + if t.writeError != nil { + return t.writeError + } + + if t.sentInitMsg != nil { + // Copy the packet so the writer can reuse the buffer. + cp := make([]byte, len(p)) + copy(cp, p) + t.pendingPackets = append(t.pendingPackets, cp) + return nil + } + + if t.writeBytesLeft > 0 { + t.writeBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if t.writePacketsLeft > 0 { + t.writePacketsLeft-- + } else { + t.requestKeyExchange() + } + + if err := t.pushPacket(p); err != nil { + t.writeError = err + } + + return nil +} + +func (t *handshakeTransport) Close() error { + // Close the connection. This should cause the readLoop goroutine to wake up + // and close t.startKex, which will shut down kexLoop if running. + err := t.conn.Close() + + // Wait for the kexLoop goroutine to complete. + // At that point we know that the readLoop goroutine is complete too, + // because kexLoop itself waits for readLoop to close the startKex channel. + <-t.kexLoopDone + + return err +} + +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: t.sentInitPacket, + } + + clientInit := otherInit + serverInit := t.sentInitMsg + isClient := len(t.hostKeys) == 0 + if isClient { + clientInit, serverInit = serverInit, clientInit + + magics.clientKexInit = t.sentInitPacket + magics.serverKexInit = otherInitPacket + } + + var err error + t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit) + if err != nil { + return err + } + + if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) { + t.strictMode = true + if err := t.conn.setStrictMode(); err != nil { + return err + } + } + + // We don't send FirstKexFollows, but we handle receiving it. + // + // RFC 4253 section 7 defines the kex and the agreement method for + // first_kex_packet_follows. It states that the guessed packet + // should be ignored if the "kex algorithm and/or the host + // key algorithm is guessed wrong (server and client have + // different preferred algorithm), or if any of the other + // algorithms cannot be agreed upon". The other algorithms have + // already been checked above so the kex algorithm and host key + // algorithm are checked here. + if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[t.algorithms.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, &magics) + } else { + result, err = t.client(kex, &magics) + } + + if err != nil { + return err + } + + firstKeyExchange := t.sessionID == nil + if firstKeyExchange { + t.sessionID = result.H + } + result.SessionID = t.sessionID + + if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { + return err + } + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + + // On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO + // message with the server-sig-algs extension if the client supports it. See + // RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9. + if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") { + supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",") + extInfo := &extInfoMsg{ + NumExtensions: 2, + Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1), + } + extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs")) + extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...) + extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList)) + extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...) + extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com")) + extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...) + extInfo.Payload = appendInt(extInfo.Payload, 1) + extInfo.Payload = append(extInfo.Payload, "0"...) + if err := t.conn.writePacket(Marshal(extInfo)); err != nil { + return err + } + } + + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + if firstKeyExchange { + // Indicates to the transport that the first key exchange is completed + // after receiving SSH_MSG_NEWKEYS. + t.conn.setInitialKEXDone() + } + + return nil +} + +// algorithmSignerWrapper is an AlgorithmSigner that only supports the default +// key format algorithm. +// +// This is technically a violation of the AlgorithmSigner interface, but it +// should be unreachable given where we use this. Anyway, at least it returns an +// error instead of panicing or producing an incorrect signature. +type algorithmSignerWrapper struct { + Signer +} + +func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != underlyingAlgo(a.PublicKey().Type()) { + return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm") + } + return a.Sign(rand, data) +} + +func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner { + for _, k := range hostKeys { + if s, ok := k.(MultiAlgorithmSigner); ok { + if !contains(s.Algorithms(), underlyingAlgo(algo)) { + continue + } + } + + if algo == k.PublicKey().Type() { + return algorithmSignerWrapper{k} + } + + k, ok := k.(AlgorithmSigner) + if !ok { + continue + } + for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) { + if algo == a { + return k + } + } + } + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { + hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey) + if hostKey == nil { + return nil, errors.New("ssh: internal error: negotiated unsupported signature type") + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil { + return nil, err + } + + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/tempfork/sshtest/ssh/handshake_test.go b/tempfork/sshtest/ssh/handshake_test.go new file mode 100644 index 000000000..2bc607b64 --- /dev/null +++ b/tempfork/sshtest/ssh/handshake_test.go @@ -0,0 +1,1021 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + listener, err = net.Listen("tcp", "[::1]:0") + if err != nil { + return nil, nil, err + } + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +// noiseTransport inserts ignore messages to check that the read loop +// and the key exchange filters out these messages. +type noiseTransport struct { + keyingTransport +} + +func (t *noiseTransport) writePacket(p []byte) error { + ignore := []byte{msgIgnore} + if err := t.keyingTransport.writePacket(ignore); err != nil { + return err + } + debug := []byte{msgDebug, 1, 2, 3} + if err := t.keyingTransport.writePacket(debug); err != nil { + return err + } + + return t.keyingTransport.writePacket(p) +} + +func addNoiseTransport(t keyingTransport) keyingTransport { + return &noiseTransport{t} +} + +// handshakePair creates two handshakeTransports connected with each +// other. If the noise argument is true, both transports will try to +// confuse the other side by sending ignore and debug messages. +func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + if noise { + trC = addNoiseTransport(trC) + trS = addNoiseTransport(trS) + } + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + if err := server.waitSession(); err != nil { + return nil, nil, fmt.Errorf("server.waitSession: %v", err) + } + if err := client.waitSession(); err != nil { + return nil, nil, fmt.Errorf("client.waitSession: %v", err) + } + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // Let first kex complete normally. + <-checker.called + + clientDone := make(chan int, 0) + gotHalf := make(chan int, 0) + const N = 20 + errorCh := make(chan error, 1) + + go func() { + defer close(clientDone) + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress. We do this twice, so we test + // that the packet buffer is reset correctly. + for i := 0; i < N; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + errorCh <- err + trC.Close() + return + } + if (i % 10) == 5 { + <-gotHalf + // halfway through, we request a key change. + trC.requestKeyExchange() + + // Wait until we can be sure the key + // change has really started before we + // write more. + <-checker.called + } + if (i % 10) == 7 { + // write some packets until the kex + // completes, to test buffering of + // packets. + checker.waitCall <- 1 + } + } + errorCh <- nil + }() + + // Server checks that client messages come in cleanly + i := 0 + for ; i < N; i++ { + p, err := trS.readPacket() + if err != nil && err != io.EOF { + t.Fatalf("server error: %v", err) + } + if (i % 10) == 5 { + gotHalf <- 1 + } + + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %v, want %v", i, p, want) + } + } + <-clientDone + if err := <-errorCh; err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i != N { + t.Errorf("received %d messages, want 10.", i) + } + + close(checker.called) + if _, ok := <-checker.called; ok { + // If all went well, we registered exactly 2 key changes: one + // that establishes the session, and one that we requested + // additionally. + t.Fatalf("got another host key checks after 2 handshakes") + } +} + +func TestForceFirstKex(t *testing.T) { + // like handshakePair, but must access the keyingTransport. + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + + // This is the disallowed packet: + trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) + + // Rest of the setup. + trS = newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server := newServerTransport(trS, v, v, serverConf) + + defer client.Close() + defer server.Close() + + // We setup the initial key exchange, but the remote side + // tries to send serviceRequestMsg in cleartext, which is + // disallowed. + + if err := server.waitSession(); err == nil { + t.Errorf("server first kex init should reject unexpected packet") + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &syncChecker{ + called: make(chan int, 10), + waitCall: nil, + } + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + input := make([]byte, 251) + input[0] = msgRequestSuccess + + done := make(chan int, 1) + const numPacket = 5 + go func() { + defer close(done) + j := 0 + for ; j < numPacket; j++ { + if p, err := trS.readPacket(); err != nil { + break + } else if !bytes.Equal(input, p) { + t.Errorf("got packet type %d, want %d", p[0], input[0]) + } + } + + if j != numPacket { + t.Errorf("got %d, want 5 messages", j) + } + }() + + <-checker.called + + for i := 0; i < numPacket; i++ { + p := make([]byte, len(input)) + copy(p, input) + if err := trC.writePacket(p); err != nil { + t.Errorf("writePacket: %v", err) + } + if i == 2 { + // Make sure the kex is in progress. + <-checker.called + } + + } + <-done +} + +type syncChecker struct { + waitCall chan int + called chan int +} + +func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + c.called <- 1 + if c.waitCall != nil { + <-c.waitCall + } + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{ + called: make(chan int, 2), + waitCall: nil, + } + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + + // While we read out the packet, a key change will be + // initiated. + errorCh := make(chan error, 1) + go func() { + _, err := trC.readPacket() + errorCh <- err + }() + + if err := <-errorCh; err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} + +// errorKeyingTransport generates errors after a given number of +// read/write operations. +type errorKeyingTransport struct { + packetConn + readLeft, writeLeft int +} + +func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { + return nil +} + +func (n *errorKeyingTransport) getSessionID() []byte { + return nil +} + +func (n *errorKeyingTransport) writePacket(packet []byte) error { + if n.writeLeft == 0 { + n.Close() + return errors.New("barf") + } + + n.writeLeft-- + return n.packetConn.writePacket(packet) +} + +func (n *errorKeyingTransport) readPacket() ([]byte, error) { + if n.readLeft == 0 { + n.Close() + return nil, errors.New("barf") + } + + n.readLeft-- + return n.packetConn.readPacket() +} + +func (n *errorKeyingTransport) setStrictMode() error { return nil } + +func (n *errorKeyingTransport) setInitialKEXDone() {} + +func TestHandshakeErrorHandlingRead(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, false) + } +} + +func TestHandshakeErrorHandlingWrite(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, false) + } +} + +func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, true) + } +} + +func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, true) + } +} + +// testHandshakeErrorHandlingN runs handshakes, injecting errors. If +// handshakeTransport deadlocks, the go runtime will detect it and +// panic. +func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { + if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" { + t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS) + } + msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) + + a, b := memPipe() + defer a.Close() + defer b.Close() + + key := testSigners["ecdsa"] + serverConf := Config{RekeyThreshold: minRekeyThreshold} + serverConf.SetDefaults() + serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) + serverConn.hostKeys = []Signer{key} + go serverConn.readLoop() + go serverConn.kexLoop() + + clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} + clientConf.SetDefaults() + clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) + clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} + clientConn.hostKeyCallback = InsecureIgnoreHostKey() + go clientConn.readLoop() + go clientConn.kexLoop() + + var wg sync.WaitGroup + + for _, hs := range []packetConn{serverConn, clientConn} { + if !coupled { + wg.Add(2) + go func(c packetConn) { + for i := 0; ; i++ { + str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) + err := c.writePacket(Marshal(&serviceRequestMsg{str})) + if err != nil { + break + } + } + wg.Done() + c.Close() + }(hs) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + } + wg.Done() + }(hs) + } else { + wg.Add(1) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + if err := c.writePacket(msg); err != nil { + break + } + + } + wg.Done() + }(hs) + } + } + wg.Wait() +} + +func TestDisconnect(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + errMsg := &disconnectMsg{ + Reason: 42, + Message: "such is life", + } + trC.writePacket(Marshal(errMsg)) + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgRequestSuccess { + t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 2 succeeded") + } else if !reflect.DeepEqual(err, errMsg) { + t.Errorf("got error %#v, want %#v", err, errMsg) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 3 succeeded") + } +} + +func TestHandshakeRekeyDefault(t *testing.T) { + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{"aes128-ctr"}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + trC.Close() + + rgb := (1024 + trC.readBytesLeft) >> 30 + wgb := (1024 + trC.writeBytesLeft) >> 30 + + if rgb != 64 { + t.Errorf("got rekey after %dG read, want 64G", rgb) + } + if wgb != 64 { + t.Errorf("got rekey after %dG write, want 64G", wgb) + } +} + +func TestHandshakeAEADCipherNoMAC(t *testing.T) { + for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} { + checker := &syncChecker{ + called: make(chan int, 1), + } + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{cipher}, + MACs: []string{}, + }, + HostKeyCallback: checker.Check, + } + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + <-checker.called + } +} + +// TestNoSHA2Support tests a host key Signer that is not an AlgorithmSigner and +// therefore can't do SHA-2 signatures. Ensures the server does not advertise +// support for them in this case. +func TestNoSHA2Support(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]}) + go func() { + _, _, _, err := NewServerConn(c1, serverConf) + if err != nil { + t.Error(err) + } + }() + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + } + + if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil { + t.Fatal(err) + } +} + +func TestMultiAlgoSignerHandshake(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(multiAlgoSigner) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + HostKeyAlgorithms: []string{KeyAlgoRSASHA512}, + } + + if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil { + t.Fatal(err) + } +} + +func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // ssh-rsa is disabled server side + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(multiAlgoSigner) + go NewServerConn(c1, serverConf) + + // the client only supports ssh-rsa + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + HostKeyAlgorithms: []string{KeyAlgoRSA}, + } + + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with no common hostkey algorithm") + } +} + +func TestPickIncompatibleHostKeyAlgo(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA) + if signer != nil { + t.Fatal("incompatible signer returned") + } +} + +func TestStrictKEXResetSeqFirstKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + // check that the sequence number counters. We reset after msgNewKeys, but + // then the server immediately writes msgExtInfo, and we close the + // transports so we expect read 2, write 0 on the client and read 1, write 1 + // on the server. + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // Request a key exchange, which should cause the sequence numbers to reset + checker.waitCall <- 1 + trC.requestKeyExchange() + <-checker.called + + // write a packet on the client, and then read it, to verify the key change has actually happened, since + // the HostKeyCallback is called _during_ the handshake, so isn't actually indicative of the handshake + // finishing. + dummyPacket := []byte{99} + if err := trS.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if p, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestSeqNumIncrease(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 || + trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXUnexpectedMsg(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + // Check that unexpected messages during the handshake cause failure + _, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true) + if err == nil { + t.Fatal("handshake should fail when there are unexpected messages during the handshake") + } + + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false) + if err != nil { + t.Fatalf("handshake failed: %s", err) + } + + // Check that ignore/debug pacekts are still ignored outside of the handshake + if err := trC.writePacket([]byte{msgIgnore}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if err := trC.writePacket([]byte{msgDebug}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + dummyPacket := []byte{99} + if err := trC.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + + if p, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } +} + +func TestStrictKEXMixed(t *testing.T) { + // Test that we still support a mixed connection, where one side sends kex-strict but the other + // side doesn't. + + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe failed: %s", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + trS = addNoiseTransport(trS) + + clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }} + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + + transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version")) + transport.hostKeys = serverConf.hostKeys + transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms + + readOneFailure := make(chan error, 1) + go func() { + if _, err := transport.readOnePacket(true); err != nil { + readOneFailure <- err + } + }() + + // Basically sendKexInit, but without the kex-strict extension algorithm + msg := &kexInitMsg{ + KexAlgos: transport.config.KeyExchanges, + CiphersClientServer: transport.config.Ciphers, + CiphersServerClient: transport.config.Ciphers, + MACsClientServer: transport.config.MACs, + MACsServerClient: transport.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + ServerHostKeyAlgos: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}, + } + packet := Marshal(msg) + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + if err := transport.pushPacket(packetCopy); err != nil { + t.Fatalf("pushPacket: %s", err) + } + transport.sentInitMsg = msg + transport.sentInitPacket = packet + + if err := transport.getWriteError(); err != nil { + t.Fatalf("getWriteError failed: %s", err) + } + var request *pendingKex + select { + case err = <-readOneFailure: + t.Fatalf("server readOnePacket failed: %s", err) + case request = <-transport.startKex: + break + } + + // We expect the following calls to fail if the side which does not support + // kex-strict sends unexpected/ignored packets during the handshake, even if + // the other side does support kex-strict. + + if err := transport.enterKeyExchange(request.otherInit); err != nil { + t.Fatalf("enterKeyExchange failed: %s", err) + } + if err := client.waitSession(); err != nil { + t.Fatalf("client.waitSession: %v", err) + } +} diff --git a/tempfork/sshtest/ssh/kex.go b/tempfork/sshtest/ssh/kex.go new file mode 100644 index 000000000..8a05f7990 --- /dev/null +++ b/tempfork/sshtest/ssh/kex.go @@ -0,0 +1,786 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256" + kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" + kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org" + kexAlgoCurve25519SHA256 = "curve25519-sha256" + + // For the following kex only the client half contains a production + // ready implementation. The server half only consists of a minimal + // implementation to satisfy the automated tests. + kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1" + kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte + + // Shared secret. See also RFC 4253, section 8. + K []byte + + // Host key as hashed into H. + HostKey []byte + + // Signature of H. + Signature []byte + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash + + // The session ID, which is the first H computed. This is used + // to derive key material inside the transport. + SessionID []byte +} + +// handshakeMagics contains data that is always included in the +// session hash. +type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. algo is the negotiated algorithm, and may + // be a certificate type. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p, pMinus1 *big.Int + hashFunc crypto.Hash +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + var x *big.Int + for { + var err error + if x, err = rand.Int(randSource, group.pMinus1); err != nil { + return nil, err + } + if x.Sign() > 0 { + break + } + } + + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + ki, err := group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := group.hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: group.hashFunc, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + var y *big.Int + for { + if y, err = rand.Int(randSource, group.pMinus1); err != nil { + return + } + if y.Sign() > 0 { + break + } + } + + Y := new(big.Int).Exp(group.g, y, group.p) + ki, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := group.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H, algo) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: group.hashFunc, + }, err +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. +type ecdh struct { + curve elliptic.Curve +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, rand, H, algo) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in + // RFC 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + hashFunc: crypto.SHA1, + } + + // This are the groups called diffie-hellman-group14-sha1 and + // diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268, + // and Oakley Group 14 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + group14 := &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: group14.g, p: group14.p, pMinus1: group14.pMinus1, + hashFunc: crypto.SHA1, + } + kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{ + g: group14.g, p: group14.p, pMinus1: group14.pMinus1, + hashFunc: crypto.SHA256, + } + + // This is the group called diffie-hellman-group16-sha512 in RFC + // 8268 and Oakley Group 16 in RFC 3526. + p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + hashFunc: crypto.SHA512, + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} + kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} + kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{} + kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1} + kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256} +} + +// curve25519sha256 implements the curve25519-sha256 (formerly known as +// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731. +type curve25519sha256 struct{} + +type curve25519KeyPair struct { + priv [32]byte + pub [32]byte +} + +func (kp *curve25519KeyPair) generate(rand io.Reader) error { + if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { + return err + } + curve25519.ScalarBaseMult(&kp.pub, &kp.priv) + return nil +} + +// curve25519Zeros is just an array of 32 zero bytes so that we have something +// convenient to compare against in order to reject curve25519 points with the +// wrong order. +var curve25519Zeros [32]byte + +func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + if len(reply.EphemeralPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var servPub, secret [32]byte + copy(servPub[:], reply.EphemeralPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &servPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kp.pub[:]) + writeString(h, reply.EphemeralPubKey) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: crypto.SHA256, + }, nil +} + +func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexInit kexECDHInitMsg + if err = Unmarshal(packet, &kexInit); err != nil { + return + } + + if len(kexInit.ClientPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + var clientPub, secret [32]byte + copy(clientPub[:], kexInit.ClientPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &clientPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexInit.ClientPubKey) + writeString(h, kp.pub[:]) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + sig, err := signAndMarshal(priv, rand, H, algo) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: kp.pub[:], + HostKey: hostKeyBytes, + Signature: sig, + } + if err := c.writePacket(Marshal(&reply)); err != nil { + return nil, err + } + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA256, + }, nil +} + +// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and +// diffie-hellman-group-exchange-sha256 key agreement protocols, +// as described in RFC 4419 +type dhGEXSHA struct { + hashFunc crypto.Hash +} + +const ( + dhGroupExchangeMinimumBits = 2048 + dhGroupExchangePreferredBits = 2048 + dhGroupExchangeMaximumBits = 8192 +) + +func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + // Send GexRequest + kexDHGexRequest := kexDHGexRequestMsg{ + MinBits: dhGroupExchangeMinimumBits, + PreferedBits: dhGroupExchangePreferredBits, + MaxBits: dhGroupExchangeMaximumBits, + } + if err := c.writePacket(Marshal(&kexDHGexRequest)); err != nil { + return nil, err + } + + // Receive GexGroup + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var msg kexDHGexGroupMsg + if err = Unmarshal(packet, &msg); err != nil { + return nil, err + } + + // reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits + if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits { + return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen()) + } + + // Check if g is safe by verifying that 1 < g < p-1 + pMinusOne := new(big.Int).Sub(msg.P, bigOne) + if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 { + return nil, fmt.Errorf("ssh: server provided gex g is not safe") + } + + // Send GexInit + pHalf := new(big.Int).Rsh(msg.P, 1) + x, err := rand.Int(randSource, pHalf) + if err != nil { + return nil, err + } + X := new(big.Int).Exp(msg.G, x, msg.P) + kexDHGexInit := kexDHGexInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil { + return nil, err + } + + // Receive GexReply + packet, err = c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexReply kexDHGexReplyMsg + if err = Unmarshal(packet, &kexDHGexReply); err != nil { + return nil, err + } + + if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P) + + // Check if k is safe by verifying that k > 1 and k < p - 1 + if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 { + return nil, fmt.Errorf("ssh: derived k is not safe") + } + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, kexDHGexReply.HostKey) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, msg.P) + writeInt(h, msg.G) + writeInt(h, X) + writeInt(h, kexDHGexReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHGexReply.HostKey, + Signature: kexDHGexReply.Signature, + Hash: gex.hashFunc, + }, nil +} + +// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256. +// +// This is a minimal implementation to satisfy the automated tests. +func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + // Receive GexRequest + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHGexRequest kexDHGexRequestMsg + if err = Unmarshal(packet, &kexDHGexRequest); err != nil { + return + } + + // Send GexGroup + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + g := big.NewInt(2) + + msg := &kexDHGexGroupMsg{ + P: p, + G: g, + } + if err := c.writePacket(Marshal(msg)); err != nil { + return nil, err + } + + // Receive GexInit + packet, err = c.readPacket() + if err != nil { + return + } + var kexDHGexInit kexDHGexInitMsg + if err = Unmarshal(packet, &kexDHGexInit); err != nil { + return + } + + pHalf := new(big.Int).Rsh(p, 1) + + y, err := rand.Int(randSource, pHalf) + if err != nil { + return + } + Y := new(big.Int).Exp(g, y, p) + + pMinusOne := new(big.Int).Sub(p, bigOne) + if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + kInt := new(big.Int).Exp(kexDHGexInit.X, y, p) + + hostKeyBytes := priv.PublicKey().Marshal() + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, p) + writeInt(h, g) + writeInt(h, kexDHGexInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H, algo) + if err != nil { + return nil, err + } + + kexDHGexReply := kexDHGexReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHGexReply) + + err = c.writePacket(packet) + + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: gex.hashFunc, + }, err +} diff --git a/tempfork/sshtest/ssh/kex_test.go b/tempfork/sshtest/ssh/kex_test.go new file mode 100644 index 000000000..cb7f66a50 --- /dev/null +++ b/tempfork/sshtest/ssh/kex_test.go @@ -0,0 +1,106 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Key exchange tests. + +import ( + "crypto/rand" + "fmt" + "reflect" + "sync" + "testing" +) + +// Runs multiple key exchanges concurrent to detect potential data races with +// kex obtained from the global kexAlgoMap. +// This test needs to be executed using the race detector in order to detect +// race conditions. +func TestKexes(t *testing.T) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + t.Run(name, func(t *testing.T) { + wg := sync.WaitGroup{} + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a, b := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + go func() { + r, e := kex.Client(a, rand.Reader, &magics) + a.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type()) + b.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + if clientRes.err != nil { + t.Errorf("client: %v", clientRes.err) + } + if serverRes.err != nil { + t.Errorf("server: %v", serverRes.err) + } + if !reflect.DeepEqual(clientRes.result, serverRes.result) { + t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) + } + }() + } + wg.Wait() + }) + } +} + +func BenchmarkKexes(b *testing.B) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + t1, t2 := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + + go func() { + r, e := kex.Client(t1, rand.Reader, &magics) + t1.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(t2, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type()) + t2.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + + if clientRes.err != nil { + panic(fmt.Sprintf("client: %v", clientRes.err)) + } + if serverRes.err != nil { + panic(fmt.Sprintf("server: %v", serverRes.err)) + } + } + }) + } +} diff --git a/tempfork/sshtest/ssh/keys.go b/tempfork/sshtest/ssh/keys.go new file mode 100644 index 000000000..4a3d769d9 --- /dev/null +++ b/tempfork/sshtest/ssh/keys.go @@ -0,0 +1,1626 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "strings" +) + +// Public key algorithms names. These values can appear in PublicKey.Type, +// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner +// arguments. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" + KeyAlgoED25519 = "ssh-ed25519" + KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com" + + // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not + // public key formats, so they can't appear as a PublicKey.Type. The + // corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2. + KeyAlgoRSASHA256 = "rsa-sha2-256" + KeyAlgoRSASHA512 = "rsa-sha2-512" +) + +const ( + // Deprecated: use KeyAlgoRSA. + SigAlgoRSA = KeyAlgoRSA + // Deprecated: use KeyAlgoRSASHA256. + SigAlgoRSASHA2256 = KeyAlgoRSASHA256 + // Deprecated: use KeyAlgoRSASHA512. + SigAlgoRSASHA2512 = KeyAlgoRSASHA512 +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. +func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case KeyAlgoSKECDSA256: + return parseSKECDSA(in) + case KeyAlgoED25519: + return parseED25519(in) + case KeyAlgoSKED25519: + return parseSKEd25519(in) + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: + cert, err := parseCert(in, certKeyAlgoNames[algo]) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseKnownHosts parses an entry in the format of the known_hosts file. +// +// The known_hosts format is documented in the sshd(8) manual page. This +// function will parse a single entry from in. On successful return, marker +// will contain the optional marker value (i.e. "cert-authority" or "revoked") +// or else be empty, hosts will contain the hosts that this entry matches, +// pubKey will contain the public key and comment will contain any trailing +// comment at the end of the line. See the sshd(8) manual page for the various +// forms that a host string can take. +// +// The unparsed remainder of the input will be returned in rest. This function +// can be called repeatedly to parse multiple entries. +// +// If no entries were found in the input then err will be io.EOF. Otherwise a +// non-nil err value indicates a parse error. +func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + // Strip out the beginning of the known_host key. + // This is either an optional marker or a (set of) hostname(s). + keyFields := bytes.Fields(in) + if len(keyFields) < 3 || len(keyFields) > 5 { + return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") + } + + // keyFields[0] is either "@cert-authority", "@revoked" or a comma separated + // list of hosts + marker := "" + if keyFields[0][0] == '@' { + marker = string(keyFields[0][1:]) + keyFields = keyFields[1:] + } + + hosts := string(keyFields[0]) + // keyFields[1] contains the key type (e.g. “ssh-rsa”). + // However, that information is duplicated inside the + // base64-encoded key and so is ignored here. + + key := bytes.Join(keyFields[2:], []byte(" ")) + if pubKey, comment, err = parseAuthorizedKey(key); err != nil { + return "", nil, nil, "", nil, err + } + + return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil + } + + return "", nil, nil, "", nil, io.EOF +} + +// ParseAuthorizedKey parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. +func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. +func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// MarshalPrivateKey returns a PEM block with the private key serialized in the +// OpenSSH format. +func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) { + return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler) +} + +// PublicKey represents a public key using an unspecified algorithm. +// +// Some PublicKeys provided by this package also implement CryptoPublicKey. +type PublicKey interface { + // Type returns the key format name, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, with the name + // prefix. To unmarshal the returned data, use the ParsePublicKey function. + Marshal() []byte + + // Verify that sig is a signature on the given data using this key. This + // method will hash the data appropriately first. sig.Format is allowed to + // be any signature algorithm compatible with the key type, the caller + // should check if it has more stringent requirements. + Verify(data []byte, sig *Signature) error +} + +// CryptoPublicKey, if implemented by a PublicKey, +// returns the underlying crypto.PublicKey form of the key. +type CryptoPublicKey interface { + CryptoPublicKey() crypto.PublicKey +} + +// A Signer can create signatures that verify against a public key. +// +// Some Signers provided by this package also implement MultiAlgorithmSigner. +type Signer interface { + // PublicKey returns the associated PublicKey. + PublicKey() PublicKey + + // Sign returns a signature for the given data. This method will hash the + // data appropriately first. The signature algorithm is expected to match + // the key format returned by the PublicKey.Type method (and not to be any + // alternative algorithm supported by the key format). + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +// An AlgorithmSigner is a Signer that also supports specifying an algorithm to +// use for signing. +// +// An AlgorithmSigner can't advertise the algorithms it supports, unless it also +// implements MultiAlgorithmSigner, so it should be prepared to be invoked with +// every algorithm supported by the public key format. +type AlgorithmSigner interface { + Signer + + // SignWithAlgorithm is like Signer.Sign, but allows specifying a desired + // signing algorithm. Callers may pass an empty string for the algorithm in + // which case the AlgorithmSigner will use a default algorithm. This default + // doesn't currently control any behavior in this package. + SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) +} + +// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms +// supported by that signer. +type MultiAlgorithmSigner interface { + AlgorithmSigner + + // Algorithms returns the available algorithms in preference order. The list + // must not be empty, and it must not include certificate types. + Algorithms() []string +} + +// NewSignerWithAlgorithms returns a signer restricted to the specified +// algorithms. The algorithms must be set in preference order. The list must not +// be empty, and it must not include certificate types. An error is returned if +// the specified algorithms are incompatible with the public key type. +func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) { + if len(algorithms) == 0 { + return nil, errors.New("ssh: please specify at least one valid signing algorithm") + } + var signerAlgos []string + supportedAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type())) + if s, ok := signer.(*multiAlgorithmSigner); ok { + signerAlgos = s.Algorithms() + } else { + signerAlgos = supportedAlgos + } + + for _, algo := range algorithms { + if !contains(supportedAlgos, algo) { + return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q", + algo, signer.PublicKey().Type()) + } + if !contains(signerAlgos, algo) { + return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo) + } + } + return &multiAlgorithmSigner{ + AlgorithmSigner: signer, + supportedAlgorithms: algorithms, + }, nil +} + +type multiAlgorithmSigner struct { + AlgorithmSigner + supportedAlgorithms []string +} + +func (s *multiAlgorithmSigner) Algorithms() []string { + return s.supportedAlgorithms +} + +func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool { + if algorithm == "" { + algorithm = underlyingAlgo(s.PublicKey().Type()) + } + for _, algo := range s.supportedAlgorithms { + if algorithm == algo { + return true + } + } + return false +} + +func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if !s.isAlgorithmSupported(algorithm) { + return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms) + } + return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. +func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + // RSA publickey struct layout should match the struct used by + // parseRSACert in the x/crypto/ssh/agent package. + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + supportedAlgos := algorithmsForKeyFormat(r.Type()) + if !contains(supportedAlgos, sig.Format) { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + hash := hashFuncs[sig.Format] + h := hash.New() + h.Write(data) + digest := h.Sum(nil) + + // Signatures in PKCS1v15 must match the key's modulus in + // length. However with SSH, some signers provide RSA + // signatures which are missing the MSB 0's of the bignum + // represented. With ssh-rsa signatures, this is encouraged by + // the spec (even though e.g. OpenSSH will give the full + // length unconditionally). With rsa-sha2-* signatures, the + // verifier is allowed to support these, even though they are + // out of spec. See RFC 4253 Section 6.6 for ssh-rsa and RFC + // 8332 Section 3 for rsa-sha2-* details. + // + // In practice: + // * OpenSSH always allows "short" signatures: + // https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L526 + // but always generates padded signatures: + // https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L439 + // + // * PuTTY versions 0.81 and earlier will generate short + // signatures for all RSA signature variants. Note that + // PuTTY is embedded in other software, such as WinSCP and + // FileZilla. At the time of writing, a patch has been + // applied to PuTTY to generate padded signatures for + // rsa-sha2-*, but not yet released: + // https://git.tartarus.org/?p=simon/putty.git;a=commitdiff;h=a5bcf3d384e1bf15a51a6923c3724cbbee022d8e + // + // * SSH.NET versions 2024.0.0 and earlier will generate short + // signatures for all RSA signature variants, fixed in 2024.1.0: + // https://github.com/sshnet/SSH.NET/releases/tag/2024.1.0 + // + // As a result, we pad these up to the key size by inserting + // leading 0's. + // + // Note that support for short signatures with rsa-sha2-* may + // be removed in the future due to such signatures not being + // allowed by the spec. + blob := sig.Blob + keySize := (*rsa.PublicKey)(r).Size() + if len(blob) < keySize { + padded := make([]byte, keySize) + copy(padded[keySize-len(blob):], blob) + blob = padded + } + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, blob) +} + +func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*rsa.PublicKey)(r) +} + +type dsaPublicKey dsa.PublicKey + +func (k *dsaPublicKey) Type() string { + return "ssh-dss" +} + +func checkDSAParams(param *dsa.Parameters) error { + // SSH specifies FIPS 186-2, which only provided a single size + // (1024 bits) DSA key. FIPS 186-3 allows for larger key + // sizes, which would confuse SSH. + if l := param.P.BitLen(); l != 1024 { + return fmt.Errorf("ssh: unsupported DSA key size %d", l) + } + + return nil +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + param := dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + } + if err := checkDSAParams(¶m); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: param, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + // DSA publickey struct layout should match the struct used by + // parseDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := hashFuncs[sig.Format].New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. + if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*dsa.PublicKey)(k) +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + return k.SignWithAlgorithm(rand, data, k.PublicKey().Type()) +} + +func (k *dsaPrivateKey) Algorithms() []string { + return []string{k.PublicKey().Type()} +} + +func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != "" && algorithm != k.PublicKey().Type() { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm) + } + + h := hashFuncs[k.PublicKey().Type()].New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (k *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + k.nistID() +} + +func (k *ecdsaPublicKey) nistID() string { + switch k.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +type ed25519PublicKey ed25519.PublicKey + +func (k ed25519PublicKey) Type() string { + return KeyAlgoED25519 +} + +func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + return ed25519PublicKey(w.KeyBytes), w.Rest, nil +} + +func (k ed25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + }{ + KeyAlgoED25519, + []byte(k), + } + return Marshal(&w) +} + +func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k); l != ed25519.PublicKeySize { + return fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + + if ok := ed25519.Verify(ed25519.PublicKey(k), b, sig.Blob); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return ed25519.PublicKey(k) +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. +func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(ecdsa.PublicKey) + + switch w.Curve { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), w.Rest, nil +} + +func (k *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + // ECDSA publickey struct layout should match the struct used by + // parseECDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + ID string + Key []byte + }{ + k.Type(), + k.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := hashFuncs[sig.Format].New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*ecdsa.PublicKey)(k) +} + +// skFields holds the additional fields present in U2F/FIDO2 signatures. +// See openssh/PROTOCOL.u2f 'SSH U2F Signatures' for details. +type skFields struct { + // Flags contains U2F/FIDO2 flags such as 'user present' + Flags byte + // Counter is a monotonic signature counter which can be + // used to detect concurrent use of a private key, should + // it be extracted from hardware. + Counter uint32 +} + +type skECDSAPublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ecdsa.PublicKey +} + +func (k *skECDSAPublicKey) Type() string { + return KeyAlgoSKECDSA256 +} + +func (k *skECDSAPublicKey) nistID() string { + return "nistp256" +} + +func parseSKECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(skECDSAPublicKey) + key.application = w.Application + + if w.Curve != "nistp256" { + return nil, nil, errors.New("ssh: unsupported curve") + } + key.Curve = elliptic.P256() + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + + return key, w.Rest, nil +} + +func (k *skECDSAPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + w := struct { + Name string + ID string + Key []byte + Application string + }{ + k.Type(), + k.nistID(), + keyBytes, + k.application, + } + + return Marshal(&w) +} + +func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := hashFuncs[sig.Format].New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var ecSig struct { + R *big.Int + S *big.Int + } + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + h.Reset() + h.Write(original) + digest := h.Sum(nil) + + if ecdsa.Verify((*ecdsa.PublicKey)(&k.PublicKey), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *skECDSAPublicKey) CryptoPublicKey() crypto.PublicKey { + return &k.PublicKey +} + +type skEd25519PublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ed25519.PublicKey +} + +func (k *skEd25519PublicKey) Type() string { + return KeyAlgoSKED25519 +} + +func parseSKEd25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + key := new(skEd25519PublicKey) + key.application = w.Application + key.PublicKey = ed25519.PublicKey(w.KeyBytes) + + return key, w.Rest, nil +} + +func (k *skEd25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + Application string + }{ + KeyAlgoSKED25519, + []byte(k.PublicKey), + k.application, + } + return Marshal(&w) +} + +func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k.PublicKey); l != ed25519.PublicKeySize { + return fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + h := hashFuncs[sig.Format].New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var edSig struct { + Signature []byte `ssh:"rest"` + } + + if err := Unmarshal(sig.Blob, &edSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + if ok := ed25519.Verify(k.PublicKey, original, edSig.Signature); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k *skEd25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return k.PublicKey +} + +// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, +// *ecdsa.PrivateKey or any other crypto.Signer and returns a +// corresponding Signer instance. ECDSA keys must use P-256, P-384 or +// P-521. DSA keys must use parameter size L1024N160. +func NewSignerFromKey(key interface{}) (Signer, error) { + switch key := key.(type) { + case crypto.Signer: + return NewSignerFromSigner(key) + case *dsa.PrivateKey: + return newDSAPrivateKey(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) { + if err := checkDSAParams(&key.PublicKey.Parameters); err != nil { + return nil, err + } + + return &dsaPrivateKey{key}, nil +} + +type wrappedSigner struct { + signer crypto.Signer + pubKey PublicKey +} + +// NewSignerFromSigner takes any crypto.Signer implementation and +// returns a corresponding Signer interface. This can be used, for +// example, with keys kept in hardware modules. +func NewSignerFromSigner(signer crypto.Signer) (Signer, error) { + pubKey, err := NewPublicKey(signer.Public()) + if err != nil { + return nil, err + } + + return &wrappedSigner{signer, pubKey}, nil +} + +func (s *wrappedSigner) PublicKey() PublicKey { + return s.pubKey +} + +func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.SignWithAlgorithm(rand, data, s.pubKey.Type()) +} + +func (s *wrappedSigner) Algorithms() []string { + return algorithmsForKeyFormat(s.pubKey.Type()) +} + +func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm == "" { + algorithm = s.pubKey.Type() + } + + if !contains(s.Algorithms(), algorithm) { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type()) + } + + hashFunc := hashFuncs[algorithm] + var digest []byte + if hashFunc != 0 { + h := hashFunc.New() + h.Write(data) + digest = h.Sum(nil) + } else { + digest = data + } + + signature, err := s.signer.Sign(rand, digest, hashFunc) + if err != nil { + return nil, err + } + + // crypto.Signer.Sign is expected to return an ASN.1-encoded signature + // for ECDSA and DSA, but that's not the encoding expected by SSH, so + // re-encode. + switch s.pubKey.(type) { + case *ecdsaPublicKey, *dsaPublicKey: + type asn1Signature struct { + R, S *big.Int + } + asn1Sig := new(asn1Signature) + _, err := asn1.Unmarshal(signature, asn1Sig) + if err != nil { + return nil, err + } + + switch s.pubKey.(type) { + case *ecdsaPublicKey: + signature = Marshal(asn1Sig) + + case *dsaPublicKey: + signature = make([]byte, 40) + r := asn1Sig.R.Bytes() + s := asn1Sig.S.Bytes() + copy(signature[20-len(r):20], r) + copy(signature[40-len(s):40], s) + } + } + + return &Signature{ + Format: algorithm, + Blob: signature, + }, nil +} + +// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, +// or ed25519.PublicKey returns a corresponding PublicKey instance. +// ECDSA keys must use P-256, P-384 or P-521. +func NewPublicKey(key interface{}) (PublicKey, error) { + switch key := key.(type) { + case *rsa.PublicKey: + return (*rsaPublicKey)(key), nil + case *ecdsa.PublicKey: + if !supportedEllipticCurve(key.Curve) { + return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported") + } + return (*ecdsaPublicKey)(key), nil + case *dsa.PublicKey: + return (*dsaPublicKey)(key), nil + case ed25519.PublicKey: + if l := len(key); l != ed25519.PublicKeySize { + return nil, fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + return ed25519PublicKey(key), nil + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports +// the same keys as ParseRawPrivateKey. If the private key is encrypted, it +// will return a PassphraseMissingError. +func ParsePrivateKey(pemBytes []byte) (Signer, error) { + key, err := ParseRawPrivateKey(pemBytes) + if err != nil { + return nil, err + } + + return NewSignerFromKey(key) +} + +// encryptedBlock tells whether a private key is +// encrypted by examining its Proc-Type header +// for a mention of ENCRYPTED +// according to RFC 1421 Section 4.6.1.1. +func encryptedBlock(block *pem.Block) bool { + return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED") +} + +// A PassphraseMissingError indicates that parsing this private key requires a +// passphrase. Use ParsePrivateKeyWithPassphrase. +type PassphraseMissingError struct { + // PublicKey will be set if the private key format includes an unencrypted + // public key along with the encrypted private key. + PublicKey PublicKey +} + +func (*PassphraseMissingError) Error() string { + return "ssh: this private key is passphrase protected" +} + +// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports +// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH +// formats. If the private key is encrypted, it will return a PassphraseMissingError. +func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("ssh: no key found") + } + + if encryptedBlock(block) { + return nil, &PassphraseMissingError{} + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + // RFC5208 - https://tools.ietf.org/html/rfc5208 + case "PRIVATE KEY": + return x509.ParsePKCS8PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(block.Bytes) + case "DSA PRIVATE KEY": + return ParseDSAPrivateKey(block.Bytes) + case "OPENSSH PRIVATE KEY": + return parseOpenSSHPrivateKey(block.Bytes, unencryptedOpenSSHKey) + default: + return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) + } +} + +// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as +// specified by the OpenSSL DSA man page. +func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Pub *big.Int + Priv *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Pub, + }, + X: k.Priv, + }, nil +} + +func unencryptedOpenSSHKey(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) { + if kdfName != "none" || cipherName != "none" { + return nil, &PassphraseMissingError{} + } + if kdfOpts != "" { + return nil, errors.New("ssh: invalid openssh private key") + } + return privKeyBlock, nil +} + +func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) { + key := generateOpenSSHPadding(privKeyBlock, 8) + return key, "none", "none", "", nil +} + +const privateKeyAuthMagic = "openssh-key-v1\x00" + +type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error) +type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error) + +type openSSHEncryptedPrivateKey struct { + CipherName string + KdfName string + KdfOpts string + NumKeys uint32 + PubKey []byte + PrivKeyBlock []byte +} + +type openSSHPrivateKey struct { + Check1 uint32 + Check2 uint32 + Keytype string + Rest []byte `ssh:"rest"` +} + +type openSSHRSAPrivateKey struct { + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int + P *big.Int + Q *big.Int + Comment string + Pad []byte `ssh:"rest"` +} + +type openSSHEd25519PrivateKey struct { + Pub []byte + Priv []byte + Comment string + Pad []byte `ssh:"rest"` +} + +type openSSHECDSAPrivateKey struct { + Curve string + Pub []byte + D *big.Int + Comment string + Pad []byte `ssh:"rest"` +} + +// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt +// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used +// as the decrypt function to parse an unencrypted private key. See +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key. +func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) { + if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic { + return nil, errors.New("ssh: invalid openssh private key format") + } + remaining := key[len(privateKeyAuthMagic):] + + var w openSSHEncryptedPrivateKey + if err := Unmarshal(remaining, &w); err != nil { + return nil, err + } + if w.NumKeys != 1 { + // We only support single key files, and so does OpenSSH. + // https://github.com/openssh/openssh-portable/blob/4103a3ec7/sshkey.c#L4171 + return nil, errors.New("ssh: multi-key files are not supported") + } + + privKeyBlock, err := decrypt(w.CipherName, w.KdfName, w.KdfOpts, w.PrivKeyBlock) + if err != nil { + if err, ok := err.(*PassphraseMissingError); ok { + pub, errPub := ParsePublicKey(w.PubKey) + if errPub != nil { + return nil, fmt.Errorf("ssh: failed to parse embedded public key: %v", errPub) + } + err.PublicKey = pub + } + return nil, err + } + + var pk1 openSSHPrivateKey + if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 { + if w.CipherName != "none" { + return nil, x509.IncorrectPasswordError + } + return nil, errors.New("ssh: malformed OpenSSH key") + } + + switch pk1.Keytype { + case KeyAlgoRSA: + var key openSSHRSAPrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: key.N, + E: int(key.E.Int64()), + }, + D: key.D, + Primes: []*big.Int{key.P, key.Q}, + } + + if err := pk.Validate(); err != nil { + return nil, err + } + + pk.Precompute() + + return pk, nil + case KeyAlgoED25519: + var key openSSHEd25519PrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if len(key.Priv) != ed25519.PrivateKeySize { + return nil, errors.New("ssh: private key unexpected length") + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) + copy(pk, key.Priv) + return &pk, nil + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + var key openSSHECDSAPrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + var curve elliptic.Curve + switch key.Curve { + case "nistp256": + curve = elliptic.P256() + case "nistp384": + curve = elliptic.P384() + case "nistp521": + curve = elliptic.P521() + default: + return nil, errors.New("ssh: unhandled elliptic curve: " + key.Curve) + } + + X, Y := elliptic.Unmarshal(curve, key.Pub) + if X == nil || Y == nil { + return nil, errors.New("ssh: failed to unmarshal public key") + } + + if key.D.Cmp(curve.Params().N) >= 0 { + return nil, errors.New("ssh: scalar is out of range") + } + + x, y := curve.ScalarBaseMult(key.D.Bytes()) + if x.Cmp(X) != 0 || y.Cmp(Y) != 0 { + return nil, errors.New("ssh: public key does not match private key") + } + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: curve, + X: X, + Y: Y, + }, + D: key.D, + }, nil + default: + return nil, errors.New("ssh: unhandled key type") + } +} + +func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) { + var w openSSHEncryptedPrivateKey + var pk1 openSSHPrivateKey + + // Random check bytes. + var check uint32 + if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil { + return nil, err + } + + pk1.Check1 = check + pk1.Check2 = check + w.NumKeys = 1 + + // Use a []byte directly on ed25519 keys. + if k, ok := key.(*ed25519.PrivateKey); ok { + key = *k + } + + switch k := key.(type) { + case *rsa.PrivateKey: + E := new(big.Int).SetInt64(int64(k.PublicKey.E)) + // Marshal public key: + // E and N are in reversed order in the public and private key. + pubKey := struct { + KeyType string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + E, k.PublicKey.N, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHRSAPrivateKey{ + N: k.PublicKey.N, + E: E, + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comment: comment, + } + pk1.Keytype = KeyAlgoRSA + pk1.Rest = Marshal(key) + case ed25519.PrivateKey: + pub := make([]byte, ed25519.PublicKeySize) + priv := make([]byte, ed25519.PrivateKeySize) + copy(pub, k[32:]) + copy(priv, k) + + // Marshal public key. + pubKey := struct { + KeyType string + Pub []byte + }{ + KeyAlgoED25519, pub, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHEd25519PrivateKey{ + Pub: pub, + Priv: priv, + Comment: comment, + } + pk1.Keytype = KeyAlgoED25519 + pk1.Rest = Marshal(key) + case *ecdsa.PrivateKey: + var curve, keyType string + switch name := k.Curve.Params().Name; name { + case "P-256": + curve = "nistp256" + keyType = KeyAlgoECDSA256 + case "P-384": + curve = "nistp384" + keyType = KeyAlgoECDSA384 + case "P-521": + curve = "nistp521" + keyType = KeyAlgoECDSA521 + default: + return nil, errors.New("ssh: unhandled elliptic curve " + name) + } + + pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y) + + // Marshal public key. + pubKey := struct { + KeyType string + Curve string + Pub []byte + }{ + keyType, curve, pub, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHECDSAPrivateKey{ + Curve: curve, + Pub: pub, + D: k.D, + Comment: comment, + } + pk1.Keytype = keyType + pk1.Rest = Marshal(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", k) + } + + var err error + // Add padding and encrypt the key if necessary. + w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1)) + if err != nil { + return nil, err + } + + b := Marshal(w) + block := &pem.Block{ + Type: "OPENSSH PRIVATE KEY", + Bytes: append([]byte(privateKeyAuthMagic), b...), + } + return block, nil +} + +func checkOpenSSHKeyPadding(pad []byte) error { + for i, b := range pad { + if int(b) != i+1 { + return errors.New("ssh: padding not as expected") + } + } + return nil +} + +func generateOpenSSHPadding(block []byte, blockSize int) []byte { + for i, l := 0, len(block); (l+i)%blockSize != 0; i++ { + block = append(block, byte(i+1)) + } + return block +} + +// FingerprintLegacyMD5 returns the user presentation of the key's +// fingerprint as described by RFC 4716 section 4. +func FingerprintLegacyMD5(pubKey PublicKey) string { + md5sum := md5.Sum(pubKey.Marshal()) + hexarray := make([]string, len(md5sum)) + for i, c := range md5sum { + hexarray[i] = hex.EncodeToString([]byte{c}) + } + return strings.Join(hexarray, ":") +} + +// FingerprintSHA256 returns the user presentation of the key's +// fingerprint as unpadded base64 encoded sha256 hash. +// This format was introduced from OpenSSH 6.8. +// https://www.openssh.com/txt/release-6.8 +// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding) +func FingerprintSHA256(pubKey PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} diff --git a/tempfork/sshtest/ssh/keys_test.go b/tempfork/sshtest/ssh/keys_test.go new file mode 100644 index 000000000..bf1f0be1b --- /dev/null +++ b/tempfork/sshtest/ssh/keys_test.go @@ -0,0 +1,724 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "golang.org/x/crypto/ssh/testdata" +) + +func rawKey(pub PublicKey) interface{} { + switch k := pub.(type) { + case *rsaPublicKey: + return (*rsa.PublicKey)(k) + case *dsaPublicKey: + return (*dsa.PublicKey)(k) + case *ecdsaPublicKey: + return (*ecdsa.PublicKey)(k) + case ed25519PublicKey: + return (ed25519.PublicKey)(k) + case *Certificate: + return k + } + panic("unknown key type") +} + +func TestKeyMarshalParse(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) + } + + k1 := rawKey(pub) + k2 := rawKey(roundtrip) + + if !reflect.DeepEqual(k1, k2) { + t.Errorf("got %#v in roundtrip, want %#v", k2, k1) + } + } +} + +func TestUnsupportedCurves(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err) + } + + if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err) + } +} + +func TestNewPublicKey(t *testing.T) { + for _, k := range testSigners { + raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } + pub, err := NewPublicKey(raw) + if err != nil { + t.Errorf("NewPublicKey(%#v): %v", raw, err) + } + if !reflect.DeepEqual(k.PublicKey(), pub) { + t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) + } + } +} + +func TestKeySignVerify(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + + data := []byte("sign me") + sig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } +} + +func TestKeySignWithAlgorithmVerify(t *testing.T) { + for k, priv := range testSigners { + if algorithmSigner, ok := priv.(MultiAlgorithmSigner); !ok { + t.Errorf("Signers %q constructed by ssh package should always implement the MultiAlgorithmSigner interface: %T", k, priv) + } else { + pub := priv.PublicKey() + data := []byte("sign me") + + signWithAlgTestCase := func(algorithm string, expectedAlg string) { + sig, err := algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + if sig.Format != expectedAlg { + t.Errorf("signature format did not match requested signature algorithm: %s != %s", sig.Format, expectedAlg) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } + + // Using the empty string as the algorithm name should result in the same signature format as the algorithm-free Sign method. + defaultSig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + signWithAlgTestCase("", defaultSig.Format) + + // RSA keys are the only ones which currently support more than one signing algorithm + if pub.Type() == KeyAlgoRSA { + for _, algorithm := range []string{KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512} { + signWithAlgTestCase(algorithm, algorithm) + } + } + } + } +} + +func TestKeySignWithShortSignature(t *testing.T) { + signer := testSigners["rsa"].(AlgorithmSigner) + pub := signer.PublicKey() + // Note: data obtained by empirically trying until a result + // starting with 0 appeared + tests := []struct { + algorithm string + data []byte + }{ + { + algorithm: KeyAlgoRSA, + data: []byte("sign me92"), + }, + { + algorithm: KeyAlgoRSASHA256, + data: []byte("sign me294"), + }, + { + algorithm: KeyAlgoRSASHA512, + data: []byte("sign me60"), + }, + } + + for _, tt := range tests { + sig, err := signer.SignWithAlgorithm(rand.Reader, tt.data, tt.algorithm) + if err != nil { + t.Fatalf("Sign(%T): %v", signer, err) + } + if sig.Blob[0] != 0 { + t.Errorf("%s: Expected signature with a leading 0", tt.algorithm) + } + sig.Blob = sig.Blob[1:] + if err := pub.Verify(tt.data, sig); err != nil { + t.Errorf("publicKey.Verify(%s): %v", tt.algorithm, err) + } + } +} + +func TestParseRSAPrivateKey(t *testing.T) { + key := testPrivateKeys["rsa"] + + rsa, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *rsa.PrivateKey", rsa) + } + + if err := rsa.Validate(); err != nil { + t.Errorf("Validate: %v", err) + } +} + +func TestParseECPrivateKey(t *testing.T) { + key := testPrivateKeys["ecdsa"] + + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) + } + + if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { + t.Fatalf("public key does not validate.") + } +} + +func TestParseDSA(t *testing.T) { + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) + if err != nil { + t.Fatalf("ParsePrivateKey returned error: %s", err) + } + + data := []byte("sign me") + sig, err := s.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("dsa.Sign: %v", err) + } + + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +func TestMarshalPrivateKey(t *testing.T) { + tests := []struct { + name string + }{ + {"rsa-openssh-format"}, + {"ed25519"}, + {"p256-openssh-format"}, + {"p384-openssh-format"}, + {"p521-openssh-format"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, ok := testPrivateKeys[tt.name] + if !ok { + t.Fatalf("cannot find key %s", tt.name) + } + + block, err := MarshalPrivateKey(expected, "test@golang.org") + if err != nil { + t.Fatalf("cannot marshal %s: %v", tt.name, err) + } + + key, err := ParseRawPrivateKey(pem.EncodeToMemory(block)) + if err != nil { + t.Fatalf("cannot parse %s: %v", tt.name, err) + } + + if !reflect.DeepEqual(expected, key) { + t.Errorf("unexpected marshaled key %s", tt.name) + } + }) + } +} + +type testAuthResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []testAuthResult) { + rest := authKeys + var values []testAuthResult + for len(rest) > 0 { + var r testAuthResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []testAuthResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + testAuthorizedKeys(t, []byte(authOptions), []testAuthResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []testAuthResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []testAuthResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []testAuthResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) + } +} + +var knownHostsParseTests = []struct { + input string + err string + + marker string + comment string + hosts []string + rest string +}{ + { + "", + "EOF", + + "", "", nil, "", + }, + { + "# Just a comment", + "EOF", + + "", "", nil, "", + }, + { + " \t ", + "EOF", + + "", "", nil, "", + }, + { + "localhost ssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line", + "", + + "", "comment comment", []string{"localhost"}, "next line", + }, + { + "localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}", + "", + + "marker", "", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd", + "short read", + + "", "", nil, "", + }, +} + +func TestKnownHostsParsing(t *testing.T) { + rsaPub, rsaPubSerialized := getTestKey() + + for i, test := range knownHostsParseTests { + var expectedKey PublicKey + const rsaKeyToken = "{RSAPUB}" + + input := test.input + if strings.Contains(input, rsaKeyToken) { + expectedKey = rsaPub + input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1) + } + + marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input)) + if err != nil { + if len(test.err) == 0 { + t.Errorf("#%d: unexpectedly failed with %q", i, err) + } else if !strings.Contains(err.Error(), test.err) { + t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err) + } + continue + } else if len(test.err) != 0 { + t.Errorf("#%d: succeeded but expected error including %q", i, test.err) + continue + } + + if !reflect.DeepEqual(expectedKey, pubKey) { + t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey) + } + + if marker != test.marker { + t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker) + } + + if comment != test.comment { + t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment) + } + + if !reflect.DeepEqual(test.hosts, hosts) { + t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts) + } + + if rest := string(rest); rest != test.rest { + t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest) + } + } +} + +func TestFingerprintLegacyMD5(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintLegacyMD5(pub) + want := "b7:ef:d3:d5:89:29:52:96:9f:df:47:41:4d:15:37:f4" // ssh-keygen -lf -E md5 rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestFingerprintSHA256(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintSHA256(pub) + want := "SHA256:fi5+D7UmDZDE9Q2sAVvvlpcQSIakN4DERdINgXd2AnE" // ssh-keygen -lf rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestInvalidKeys(t *testing.T) { + keyTypes := []string{ + "RSA PRIVATE KEY", + "PRIVATE KEY", + "EC PRIVATE KEY", + "DSA PRIVATE KEY", + "OPENSSH PRIVATE KEY", + } + + for _, keyType := range keyTypes { + for _, dataLen := range []int{0, 1, 2, 5, 10, 20} { + data := make([]byte, dataLen) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + pem.Encode(&buf, &pem.Block{ + Type: keyType, + Bytes: data, + }) + + // This test is just to ensure that the function + // doesn't panic so the return value is ignored. + ParseRawPrivateKey(buf.Bytes()) + } + } +} + +func TestSKKeys(t *testing.T) { + for _, d := range testdata.SKData { + pk, _, _, _, err := ParseAuthorizedKey(d.PubKey) + if err != nil { + t.Fatalf("parseAuthorizedKey returned error: %v", err) + } + + sigBuf := make([]byte, hex.DecodedLen(len(d.HexSignature))) + if _, err := hex.Decode(sigBuf, d.HexSignature); err != nil { + t.Fatalf("hex.Decode() failed: %v", err) + } + + dataBuf := make([]byte, hex.DecodedLen(len(d.HexData))) + if _, err := hex.Decode(dataBuf, d.HexData); err != nil { + t.Fatalf("hex.Decode() failed: %v", err) + } + + sig, _, ok := parseSignature(sigBuf) + if !ok { + t.Fatalf("parseSignature(%v) failed", sigBuf) + } + + // Test that good data and signature pass verification + if err := pk.Verify(dataBuf, sig); err != nil { + t.Errorf("%s: PublicKey.Verify(%v, %v) failed: %v", d.Name, dataBuf, sig, err) + } + + // Invalid data being passed in + invalidData := []byte("INVALID DATA") + if err := pk.Verify(invalidData, sig); err == nil { + t.Errorf("%s with invalid data: PublicKey.Verify(%v, %v) passed unexpectedly", d.Name, invalidData, sig) + } + + // Change byte in blob to corrup signature + sig.Blob[5] = byte('A') + // Corrupted data being passed in + if err := pk.Verify(dataBuf, sig); err == nil { + t.Errorf("%s with corrupted signature: PublicKey.Verify(%v, %v) passed unexpectedly", d.Name, dataBuf, sig) + } + } +} + +func TestNewSignerWithAlgos(t *testing.T) { + algorithSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + _, err := NewSignerWithAlgorithms(algorithSigner, nil) + if err == nil { + t.Error("signer with algos created with no algorithms") + } + + _, err = NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoED25519}) + if err == nil { + t.Error("signer with algos created with invalid algorithms") + } + + _, err = NewSignerWithAlgorithms(algorithSigner, []string{CertAlgoRSASHA256v01}) + if err == nil { + t.Error("signer with algos created with certificate algorithms") + } + + mas, err := NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Errorf("unable to create signer with valid algorithms: %v", err) + } + + _, err = NewSignerWithAlgorithms(mas, []string{KeyAlgoRSA}) + if err == nil { + t.Error("signer with algos created with restricted algorithms") + } +} + +func TestCryptoPublicKey(t *testing.T) { + for _, priv := range testSigners { + p1 := priv.PublicKey() + key, ok := p1.(CryptoPublicKey) + if !ok { + continue + } + p2, err := NewPublicKey(key.CryptoPublicKey()) + if err != nil { + t.Fatalf("NewPublicKey(CryptoPublicKey) failed for %s, got: %v", p1.Type(), err) + } + if !reflect.DeepEqual(p1, p2) { + t.Errorf("got %#v in NewPublicKey, want %#v", p2, p1) + } + } + for _, d := range testdata.SKData { + p1, _, _, _, err := ParseAuthorizedKey(d.PubKey) + if err != nil { + t.Fatalf("parseAuthorizedKey returned error: %v", err) + } + k1, ok := p1.(CryptoPublicKey) + if !ok { + t.Fatalf("%T does not implement CryptoPublicKey", p1) + } + + var p2 PublicKey + switch pub := k1.CryptoPublicKey().(type) { + case *ecdsa.PublicKey: + p2 = &skECDSAPublicKey{ + application: "ssh:", + PublicKey: *pub, + } + case ed25519.PublicKey: + p2 = &skEd25519PublicKey{ + application: "ssh:", + PublicKey: pub, + } + default: + t.Fatalf("unexpected type %T from CryptoPublicKey()", pub) + } + if !reflect.DeepEqual(p1, p2) { + t.Errorf("got %#v, want %#v", p2, p1) + } + } +} diff --git a/tempfork/sshtest/ssh/mac.go b/tempfork/sshtest/ssh/mac.go new file mode 100644 index 000000000..06a1b2750 --- /dev/null +++ b/tempfork/sshtest/ssh/mac.go @@ -0,0 +1,68 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "hash" +) + +type macMode struct { + keySize int + etm bool + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, + "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha2-512": {64, false, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, + "hmac-sha2-256": {32, false, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha1": {20, false, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + }}, + "hmac-sha1-96": {20, false, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/tempfork/sshtest/ssh/mempipe_test.go b/tempfork/sshtest/ssh/mempipe_test.go new file mode 100644 index 000000000..f27339c51 --- /dev/null +++ b/tempfork/sshtest/ssh/mempipe_test.go @@ -0,0 +1,124 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" + "testing" +) + +// An in-memory packetConn. It is safe to call Close and writePacket +// from different goroutines. +type memTransport struct { + eof bool + pending [][]byte + write *memTransport + writeCount uint64 + sync.Mutex + *sync.Cond +} + +func (t *memTransport) readPacket() ([]byte, error) { + t.Lock() + defer t.Unlock() + for { + if len(t.pending) > 0 { + r := t.pending[0] + t.pending = t.pending[1:] + return r, nil + } + if t.eof { + return nil, io.EOF + } + t.Cond.Wait() + } +} + +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { + return io.EOF + } + t.eof = true + t.Cond.Broadcast() + return nil +} + +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + +func (t *memTransport) writePacket(p []byte) error { + t.write.Lock() + defer t.write.Unlock() + if t.write.eof { + return io.EOF + } + c := make([]byte, len(p)) + copy(c, p) + t.write.pending = append(t.write.pending, c) + t.write.Cond.Signal() + t.writeCount++ + return nil +} + +func (t *memTransport) getWriteCount() uint64 { + t.write.Lock() + defer t.write.Unlock() + return t.writeCount +} + +func memPipe() (a, b packetConn) { + t1 := memTransport{} + t2 := memTransport{} + t1.write = &t2 + t2.write = &t1 + t1.Cond = sync.NewCond(&t1.Mutex) + t2.Cond = sync.NewCond(&t2.Mutex) + return &t1, &t2 +} + +func TestMemPipe(t *testing.T) { + a, b := memPipe() + if err := a.writePacket([]byte{42}); err != nil { + t.Fatalf("writePacket: %v", err) + } + if wc := a.(*memTransport).getWriteCount(); wc != 1 { + t.Fatalf("got %v, want 1", wc) + } + if err := a.Close(); err != nil { + t.Fatal("Close: ", err) + } + p, err := b.readPacket() + if err != nil { + t.Fatal("readPacket: ", err) + } + if len(p) != 1 || p[0] != 42 { + t.Fatalf("got %v, want {42}", p) + } + p, err = b.readPacket() + if err != io.EOF { + t.Fatalf("got %v, %v, want EOF", p, err) + } + if wc := b.(*memTransport).getWriteCount(); wc != 0 { + t.Fatalf("got %v, want 0", wc) + } +} + +func TestDoubleClose(t *testing.T) { + a, _ := memPipe() + err := a.Close() + if err != nil { + t.Errorf("Close: %v", err) + } + err = a.Close() + if err != io.EOF { + t.Errorf("expect EOF on double close.") + } +} diff --git a/tempfork/sshtest/ssh/messages.go b/tempfork/sshtest/ssh/messages.go new file mode 100644 index 000000000..b55f86056 --- /dev/null +++ b/tempfork/sshtest/ssh/messages.go @@ -0,0 +1,891 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. + +// Diffie-Hellman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4419, section 5. +const msgKexDHGexGroup = 31 + +type kexDHGexGroupMsg struct { + P *big.Int `sshtype:"31"` + G *big.Int +} + +const msgKexDHGexInit = 32 + +type kexDHGexInitMsg struct { + X *big.Int `sshtype:"32"` +} + +const msgKexDHGexReply = 33 + +type kexDHGexReplyMsg struct { + HostKey []byte `sshtype:"33"` + Y *big.Int + Signature []byte +} + +const msgKexDHGexRequest = 34 + +type kexDHGexRequestMsg struct { + MinBits uint32 `sshtype:"34"` + PreferedBits uint32 + MaxBits uint32 +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. +const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 8308, section 2.3 +const msgExtInfo = 7 + +type extInfoMsg struct { + NumExtensions uint32 `sshtype:"7"` + Payload []byte `ssh:"rest"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// Used for debug printouts of packets. +type userAuthSuccessMsg struct { +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4252, section 5.1 +const msgUserAuthSuccess = 52 + +// See RFC 4252, section 5.4 +const msgUserAuthBanner = 53 + +type userAuthBannerMsg struct { + Message string `sshtype:"53"` + // unused, but required to allow message parsing + Language string +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + Name string `sshtype:"60"` + Instruction string + Language string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersID uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersID uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersID uint32 `sshtype:"91"` + MyID uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersID uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersID uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersID uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersID uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersID uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersID uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersID uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// See RFC 4462, section 3 +const msgUserAuthGSSAPIResponse = 60 + +type userAuthGSSAPIResponse struct { + SupportMech []byte `sshtype:"60"` +} + +const msgUserAuthGSSAPIToken = 61 + +type userAuthGSSAPIToken struct { + Token []byte `sshtype:"61"` +} + +const msgUserAuthGSSAPIMIC = 66 + +type userAuthGSSAPIMIC struct { + MIC []byte `sshtype:"66"` +} + +// See RFC 4462, section 3.9 +const msgUserAuthGSSAPIErrTok = 64 + +type userAuthGSSAPIErrTok struct { + ErrorToken []byte `sshtype:"64"` +} + +// See RFC 4462, section 3.8 +const msgUserAuthGSSAPIError = 65 + +type userAuthGSSAPIError struct { + MajorStatus uint32 `sshtype:"65"` + MinorStatus uint32 + Message string + LanguageTag string +} + +// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9 +const msgPing = 192 + +type pingMsg struct { + Data string `sshtype:"192"` +} + +// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9 +const msgPong = 193 + +type pongMsg struct { + Data string `sshtype:"193"` +} + +// typeTags returns the possible type bytes for the given reflect.Type, which +// should be a struct. The possible values are separated by a '|' character. +func typeTags(structType reflect.Type) (tags []byte) { + tagStr := structType.Field(0).Tag.Get("sshtype") + + for _, tag := range strings.Split(tagStr, "|") { + i, err := strconv.Atoi(tag) + if err == nil { + tags = append(tags, byte(i)) + } + } + + return tags +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a '|'-separated set of numbers +// in decimal, the packet must start with one of those numbers. In +// case of error, Unmarshal returns a ParseError or +// UnexpectedMessageError. +func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedTypes := typeTags(structType) + + var expectedType byte + if len(expectedTypes) > 0 { + expectedType = expectedTypes[0] + } + + if len(data) == 0 { + return parseError(expectedType) + } + + if len(expectedTypes) > 0 { + goodType := false + for _, e := range expectedTypes { + if e > 0 && data[0] == e { + goodType = true + break + } + } + if !goodType { + return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgTypes := typeTags(v.Type()) + if len(msgTypes) > 0 { + out = append(out, msgTypes[0]) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. +func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgExtInfo: + msg = new(extInfoMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthSuccess: + return new(userAuthSuccessMsg), nil + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelData: + msg = new(channelDataMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + case msgUserAuthGSSAPIToken: + msg = new(userAuthGSSAPIToken) + case msgUserAuthGSSAPIMIC: + msg = new(userAuthGSSAPIMIC) + case msgUserAuthGSSAPIErrTok: + msg = new(userAuthGSSAPIErrTok) + case msgUserAuthGSSAPIError: + msg = new(userAuthGSSAPIError) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +var packetTypeNames = map[byte]string{ + msgDisconnect: "disconnectMsg", + msgServiceRequest: "serviceRequestMsg", + msgServiceAccept: "serviceAcceptMsg", + msgExtInfo: "extInfoMsg", + msgKexInit: "kexInitMsg", + msgKexDHInit: "kexDHInitMsg", + msgKexDHReply: "kexDHReplyMsg", + msgUserAuthRequest: "userAuthRequestMsg", + msgUserAuthSuccess: "userAuthSuccessMsg", + msgUserAuthFailure: "userAuthFailureMsg", + msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg", + msgGlobalRequest: "globalRequestMsg", + msgRequestSuccess: "globalRequestSuccessMsg", + msgRequestFailure: "globalRequestFailureMsg", + msgChannelOpen: "channelOpenMsg", + msgChannelData: "channelDataMsg", + msgChannelOpenConfirm: "channelOpenConfirmMsg", + msgChannelOpenFailure: "channelOpenFailureMsg", + msgChannelWindowAdjust: "windowAdjustMsg", + msgChannelEOF: "channelEOFMsg", + msgChannelClose: "channelCloseMsg", + msgChannelRequest: "channelRequestMsg", + msgChannelSuccess: "channelRequestSuccessMsg", + msgChannelFailure: "channelRequestFailureMsg", +} diff --git a/tempfork/sshtest/ssh/messages_test.go b/tempfork/sshtest/ssh/messages_test.go new file mode 100644 index 000000000..e79076412 --- /dev/null +++ b/tempfork/sshtest/ssh/messages_test.go @@ -0,0 +1,288 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break + } + } +} + +func TestUnmarshalEmptyPacket(t *testing.T) { + var b []byte + var m channelRequestSuccessMsg + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") + } +} + +func TestUnmarshalUnexpectedPacket(t *testing.T) { + type S struct { + I uint32 `sshtype:"43"` + S string + B bool + } + + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 + roundtrip := S{} + err := Unmarshal(packet, &roundtrip) + if err == nil { + t.Fatal("expected error, not nil") + } +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) + } +} + +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := Marshal(s) + roundtrip := S{} + Unmarshal(packet, &roundtrip) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := Marshal(s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + +func TestUnmarshalShortKexInitPacket(t *testing.T) { + // This used to panic. + // Issue 11348 + packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} + kim := &kexInitMsg{} + if err := Unmarshal(packet, kim); err == nil { + t.Error("truncated packet unmarshaled without error") + } +} + +func TestMarshalMultiTag(t *testing.T) { + var res struct { + A uint32 `sshtype:"1|2"` + } + + good1 := struct { + A uint32 `sshtype:"1"` + }{ + 1, + } + good2 := struct { + A uint32 `sshtype:"2"` + }{ + 1, + } + + if e := Unmarshal(Marshal(good1), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + if e := Unmarshal(Marshal(good2), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + bad1 := struct { + A uint32 `sshtype:"3"` + }{ + 1, + } + if e := Unmarshal(Marshal(bad1), &res); e == nil { + t.Errorf("bad struct unmarshaled without error") + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} + +var ( + _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) +) + +func BenchmarkMarshalKexInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexInitMsg) + } +} + +func BenchmarkUnmarshalKexInitMsg(b *testing.B) { + m := new(kexInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexInit, m) + } +} + +func BenchmarkMarshalKexDHInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexDHInitMsg) + } +} + +func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { + m := new(kexDHInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexDHInit, m) + } +} diff --git a/tempfork/sshtest/ssh/mux.go b/tempfork/sshtest/ssh/mux.go new file mode 100644 index 000000000..d2d24c635 --- /dev/null +++ b/tempfork/sshtest/ssh/mux.go @@ -0,0 +1,357 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. +func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, chanSize), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, chanSize), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + if debugMux { + log.Printf("send global(%d): %#v", m.chanList.offset, msg) + } + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. +func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + case msgPing: + var msg pingMsg + if err := Unmarshal(packet, &msg); err != nil { + return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err) + } + return m.sendMessage(pongMsg(msg)) + } + + // assume a channel packet. + if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return m.handleUnknownChannelPacket(id, packet) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersID: msg.PeersID, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersID + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersID: ch.localId, + } + if err := m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} + +func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + // RFC 4254 section 5.4 says unrecognized channel requests should + // receive a failure response. + case *channelRequestMsg: + if msg.WantReply { + return m.sendMessage(channelRequestFailureMsg{ + PeersID: msg.PeersID, + }) + } + return nil + default: + return fmt.Errorf("ssh: invalid channel %d", id) + } +} diff --git a/tempfork/sshtest/ssh/mux_test.go b/tempfork/sshtest/ssh/mux_test.go new file mode 100644 index 000000000..21f0ac3e3 --- /dev/null +++ b/tempfork/sshtest/ssh/mux_test.go @@ -0,0 +1,839 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the 2nd +// channel. +func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Error("no incoming channel") + close(res) + return + } + if newCh.ChannelType() != "chan" { + t.Errorf("got type %q want chan", newCh.ChannelType()) + newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType())) + close(res) + return + } + ch, _, err := newCh.Accept() + if err != nil { + t.Errorf("accept: %v", err) + close(res) + return + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + w := <-res + if w == nil { + t.Fatal("unable to get write channel") + } + + return w, ch, c +} + +// Test that stderr and stdout can be addressed from different +// goroutines. This is intended for use with the race detector. +func TestMuxChannelExtendedThreadSafety(t *testing.T) { + writer, reader, mux := channelPair(t) + defer writer.Close() + defer reader.Close() + defer mux.Close() + + var wr, rd sync.WaitGroup + magic := "hello world" + + wr.Add(2) + go func() { + io.WriteString(writer, magic) + wr.Done() + }() + go func() { + io.WriteString(writer.Stderr(), magic) + wr.Done() + }() + + rd.Add(2) + go func() { + c, err := io.ReadAll(reader) + if string(c) != magic { + t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + go func() { + c, err := io.ReadAll(reader.Stderr()) + if string(c) != magic { + t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + + wr.Wait() + writer.CloseWrite() + rd.Wait() +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write([]byte(magic)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } +} + +func TestMuxChannelReadUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != nil { + t.Errorf("Write: %v", err) + } + writer.Close() + }() + + writer.remoteWin.waitWriterBlocked() + + buf := make([]byte, 32768) + for { + _, err := reader.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Read: %v", err) + } + } +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + + ch, ok := <-server.incomingChannels + if !ok { + t.Error("cannot accept channel") + return + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String()) + return + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + if ok { + t.Errorf("SendRequest(no): %v", ok) + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxUnknownChannelRequests(t *testing.T) { + clientPipe, serverPipe := memPipe() + client := newMux(clientPipe) + defer serverPipe.Close() + defer client.Close() + + kDone := make(chan error, 1) + go func() { + // Ignore unknown channel messages that don't want a reply. + err := serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: 1, + Request: "keepalive@openssh.com", + WantReply: false, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Send a keepalive, which should get a channel failure message + // in response. + err = serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: 2, + Request: "keepalive@openssh.com", + WantReply: true, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + packet, err := serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + decoded, err := decode(packet) + if err != nil { + kDone <- fmt.Errorf("decode failed: %w", err) + return + } + + switch msg := decoded.(type) { + case *channelRequestFailureMsg: + if msg.PeersID != 2 { + kDone <- fmt.Errorf("received response to wrong message: %v", msg) + return + + } + default: + kDone <- fmt.Errorf("unexpected channel message: %v", msg) + return + } + + kDone <- nil + + // Receive and respond to the keepalive to confirm the mux is + // still processing requests. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgGlobalRequest { + kDone <- errors.New("expected global request") + return + } + + err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ + Data: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return + } + + close(kDone) + }() + + // Wait for the server to send the keepalive message and receive back a + // response. + if err := <-kDone; err != nil { + t.Fatal(err) + } + + // Confirm client hasn't closed. + if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { + t.Fatalf("failed to send keepalive: %v", err) + } + + // Wait for the server to shut down. + if err := <-kDone; err != nil { + t.Fatal(err) + } +} + +func TestMuxClosedChannel(t *testing.T) { + clientPipe, serverPipe := memPipe() + client := newMux(clientPipe) + defer serverPipe.Close() + defer client.Close() + + kDone := make(chan error, 1) + go func() { + // Open the channel. + packet, err := serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelOpen { + kDone <- errors.New("expected chan open") + return + } + + var openMsg channelOpenMsg + if err := Unmarshal(packet, &openMsg); err != nil { + kDone <- fmt.Errorf("unmarshal: %w", err) + return + } + + // Send back the opened channel confirmation. + err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{ + PeersID: openMsg.PeersID, + MyID: 0, + MyWindow: 0, + MaxPacketSize: channelMaxPacket, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Close the channel. + err = serverPipe.writePacket(Marshal(channelCloseMsg{ + PeersID: openMsg.PeersID, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Send a keepalive message on the channel we just closed. + err = serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: openMsg.PeersID, + Request: "keepalive@openssh.com", + WantReply: true, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Receive the channel closed response. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelClose { + kDone <- errors.New("expected channel close") + return + } + + // Receive the keepalive response failure. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelFailure { + kDone <- errors.New("expected channel failure") + return + } + kDone <- nil + + // Receive and respond to the keepalive to confirm the mux is + // still processing requests. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgGlobalRequest { + kDone <- errors.New("expected global request") + return + } + + err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ + Data: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return + } + + close(kDone) + }() + + // Open a channel. + ch, err := client.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + defer ch.Close() + + // Wait for the server to close the channel and send the keepalive. + <-kDone + + // Make sure the channel closed. + if _, ok := <-ch.incomingRequests; ok { + t.Fatalf("channel not closed") + } + + // Confirm client hasn't closed + if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { + t.Fatalf("failed to send keepalive: %v", err) + } + + // Wait for the server to shut down. + <-kDone +} + +func TestMuxGlobalRequest(t *testing.T) { + var sawPeek bool + var wg sync.WaitGroup + defer func() { + wg.Wait() + if !sawPeek { + t.Errorf("never saw 'peek' request") + } + }() + + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + wg.Add(1) + go func() { + defer wg.Done() + for r := range serverMux.incomingRequests { + sawPeek = sawPeek || r.Type == "peek" + if r.WantReply { + err := r.Reply(r.Type == "yes", + append([]byte(r.Type), r.Payload...)) + if err != nil { + t.Errorf("AckRequest: %v", err) + } + } + } + }() + + _, _, err := clientMux.SendRequest("peek", false, nil) + if err != nil { + t.Errorf("SendRequest: %v", err) + } + + ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) + if !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { + t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", + ok, data, err) + } + + if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { + t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", + ok, data, err) + } +} + +func TestMuxGlobalRequestUnblock(t *testing.T) { + clientMux, serverMux := muxPair() + defer serverMux.Close() + defer clientMux.Close() + + result := make(chan error, 1) + go func() { + _, _, err := clientMux.SendRequest("hello", true, nil) + result <- err + }() + + <-serverMux.incomingRequests + serverMux.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", io.EOF) + } +} + +func TestMuxChannelRequestUnblock(t *testing.T) { + a, b, connB := channelPair(t) + defer a.Close() + defer b.Close() + defer connB.Close() + + result := make(chan error, 1) + go func() { + _, err := a.SendRequest("hello", true, nil) + result <- err + }() + + <-b.incomingRequests + connB.conn.Close() + err := <-result + + if err != io.EOF { + t.Errorf("want EOF, got %v", err) + } +} + +func TestMuxCloseChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + defer r.Close() + defer w.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.Close(); err != nil { + t.Errorf("w.Close: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after Close", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxCloseWriteChannel(t *testing.T) { + r, w, mux := channelPair(t) + defer mux.Close() + + result := make(chan error, 1) + go func() { + var b [1024]byte + _, err := r.Read(b[:]) + result <- err + }() + if err := w.CloseWrite(); err != nil { + t.Errorf("w.CloseWrite: %v", err) + } + + if _, err := w.Write([]byte("hello")); err != io.EOF { + t.Errorf("got err %v, want io.EOF after CloseWrite", err) + } + + if err := <-result; err != io.EOF { + t.Errorf("got %v (%T), want io.EOF", err, err) + } +} + +func TestMuxInvalidRecord(t *testing.T) { + a, b := muxPair() + defer a.Close() + defer b.Close() + + packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], 29348723 /* invalid channel id */) + marshalUint32(packet[5:], 1) + packet[9] = 42 + + a.conn.writePacket(packet) + go a.SendRequest("hello", false, nil) + // 'a' wrote an invalid packet, so 'b' has exited. + req, ok := <-b.incomingRequests + if ok { + t.Errorf("got request %#v after receiving invalid packet", req) + } +} + +func TestZeroWindowAdjust(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + go func() { + io.WriteString(a, "hello") + // bogus adjust. + a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := io.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + a.SendRequest("hello", false, nil) + wg.Done() + }() + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +func TestMuxChannelWindowDeferredUpdates(t *testing.T) { + s, c, mux := channelPair(t) + cTransport := mux.conn.(*memTransport) + defer s.Close() + defer c.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + data := make([]byte, 1024) + + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write(data) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + cWritesInit := cTransport.getWriteCount() + buf := make([]byte, 1) + for i := 0; i < len(data); i++ { + n, err := c.Read(buf) + if n != len(buf) || err != nil { + t.Fatalf("Read: %v, %v", n, err) + } + } + cWrites := cTransport.getWriteCount() - cWritesInit + // reading 1 KiB should not cause any window updates to be sent, but allow + // for some unexpected writes + if cWrites > 30 { + t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites) + } +} + +// Don't ship code with debug=true. +func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } + if debugTransport { + t.Error("transport debug switched on") + } +} diff --git a/tempfork/sshtest/ssh/server.go b/tempfork/sshtest/ssh/server.go new file mode 100644 index 000000000..1839ddc6a --- /dev/null +++ b/tempfork/sshtest/ssh/server.go @@ -0,0 +1,933 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "strings" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a user. +// The Permissions value for a successful authentication attempt is +// available in ServerConn, so it can be used to pass information from +// the user-authentication phase to the application layer. +type Permissions struct { + // CriticalOptions indicate restrictions to the default + // permissions, and are typically used in conjunction with + // user certificates. The standard for SSH certificates + // defines "force-command" (only allow the given command to + // execute) and "source-address" (only allow connections from + // the given address). The SSH package currently only enforces + // the "source-address" critical option. It is up to server + // implementations to enforce other critical options, such as + // "force-command", by checking them after the SSH handshake + // is successful. In general, SSH servers should reject + // connections that specify critical options that are unknown + // or not supported. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Lack of support for an + // extension does not preclude authenticating a user. Common + // extensions are "permit-agent-forwarding", + // "permit-X11-forwarding". The Go SSH library currently does + // not act on any extension, and it is up to server + // implementations to honor them. Extensions can be used to + // pass data from the authentication callbacks to the server + // application layer. + Extensions map[string]string +} + +type GSSAPIWithMICConfig struct { + // AllowLogin, must be set, is called when gssapi-with-mic + // authentication is selected (RFC 4462 section 3). The srcName is from the + // results of the GSS-API authentication. The format is username@DOMAIN. + // GSSAPI just guarantees to the server who the user is, but not if they can log in, and with what permissions. + // This callback is called after the user identity is established with GSSAPI to decide if the user can login with + // which permissions. If the user is allowed to login, it should return a nil error. + AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error) + + // Server must be set. It's the implementation + // of the GSSAPIServer interface. See GSSAPIServer interface for details. + Server GSSAPIServer +} + +// SendAuthBanner implements [ServerPreAuthConn]. +func (s *connection) SendAuthBanner(msg string) error { + return s.transport.writePacket(Marshal(&userAuthBannerMsg{ + Message: msg, + })) +} + +func (*connection) unexportedMethodForFutureProofing() {} + +// ServerPreAuthConn is the interface available on an incoming server +// connection before authentication has completed. +type ServerPreAuthConn interface { + unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB + + ConnMetadata + + // SendAuthBanner sends a banner message to the client. + // It returns an error once the authentication phase has ended. + SendAuthBanner(string) error +} + +// ServerConfig holds server specific configuration data. +type ServerConfig struct { + // Config contains configuration shared between client and server. + Config + + // PublicKeyAuthAlgorithms specifies the supported client public key + // authentication algorithms. Note that this should not include certificate + // types since those use the underlying algorithm. This list is sent to the + // client if it supports the server-sig-algs extension. Order is irrelevant. + // If unspecified then a default set of algorithms is used. + PublicKeyAuthAlgorithms []string + + hostKeys []Signer + + // NoClientAuth is true if clients are allowed to connect without + // authenticating. + // To determine NoClientAuth at runtime, set NoClientAuth to true + // and the optional NoClientAuthCallback to a non-nil value. + NoClientAuth bool + + // NoClientAuthCallback, if non-nil, is called when a user + // attempts to authenticate with auth method "none". + // NoClientAuth must also be set to true for this be used, or + // this func is unused. + NoClientAuthCallback func(ConnMetadata) (*Permissions, error) + + // MaxAuthTries specifies the maximum number of authentication attempts + // permitted per connection. If set to a negative number, the number of + // attempts are unlimited. If set to zero, the number of attempts are limited + // to 6. + MaxAuthTries int + + // PasswordCallback, if non-nil, is called when a user + // attempts to authenticate using a password. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback, if non-nil, is called when a client + // offers a public key for authentication. It must return a nil error + // if the given public key can be used to authenticate the + // given user. For example, see CertChecker.Authenticate. A + // call to this function does not guarantee that the key + // offered is in fact used to authenticate. To record any data + // depending on the public key, store it inside a + // Permissions.Extensions entry. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback, if non-nil, is called when + // keyboard-interactive authentication is selected (RFC + // 4256). The client object's Challenge function should be + // used to query the user. The callback may offer multiple + // Challenge rounds. To avoid information leaks, the client + // should be presented a challenge even if the user is + // unknown. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // AuthLogCallback, if non-nil, is called to log all authentication + // attempts. + AuthLogCallback func(conn ConnMetadata, method string, err error) + + // PreAuthConnCallback, if non-nil, is called upon receiving a new connection + // before any authentication has started. The provided ServerPreAuthConn + // can be used at any time before authentication is complete, including + // after this callback has returned. + PreAuthConnCallback func(ServerPreAuthConn) + + // ServerVersion is the version identification string to announce in + // the public handshake. + // If empty, a reasonable default is used. + // Note that RFC 4253 section 4.2 requires that this string start with + // "SSH-2.0-". + ServerVersion string + + // BannerCallback, if present, is called and the return string is sent to + // the client after key exchange completed but before authentication. + BannerCallback func(conn ConnMetadata) string + + // GSSAPIWithMICConfig includes gssapi server and callback, which if both non-nil, is used + // when gssapi-with-mic authentication is selected (RFC 4462 section 3). + GSSAPIWithMICConfig *GSSAPIWithMICConfig +} + +// AddHostKey adds a private key as a host key. If an existing host +// key exists with the same public key format, it is replaced. Each server +// config must have at least one host key. +func (s *ServerConfig) AddHostKey(key Signer) { + for i, k := range s.hostKeys { + if k.PublicKey().Type() == key.PublicKey().Type() { + s.hostKeys[i] = key + return + } + } + + s.hostKeys = append(s.hostKeys, key) +} + +// cachedPubKey contains the results of querying whether a public key is +// acceptable for a user. This is a FIFO cache. +type cachedPubKey struct { + user string + pubKeyData []byte + result error + perms *Permissions +} + +// maxCachedPubKeys is the number of cache entries we store. +// +// Due to consistent misuse of the PublicKeyCallback API, we have reduced this +// to 1, such that the only key in the cache is the most recently seen one. This +// forces the behavior that the last call to PublicKeyCallback will always be +// with the key that is used for authentication. +const maxCachedPubKeys = 1 + +// pubKeyCache caches tests for public keys. Since SSH clients +// will query whether a public key is acceptable before attempting to +// authenticate with it, we end up with duplicate queries for public +// key validity. The cache only applies to a single ServerConn. +type pubKeyCache struct { + keys []cachedPubKey +} + +// get returns the result for a given user/algo/key tuple. +func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { + for _, k := range c.keys { + if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { + return k, true + } + } + return cachedPubKey{}, false +} + +// add adds the given tuple to the cache. +func (c *pubKeyCache) add(candidate cachedPubKey) { + if len(c.keys) >= maxCachedPubKeys { + c.keys = c.keys[1:] + } + c.keys = append(c.keys, candidate) +} + +// ServerConn is an authenticated SSH connection, as seen from the +// server +type ServerConn struct { + Conn + + // If the succeeding authentication callback returned a + // non-nil Permissions pointer, it is stored here. + Permissions *Permissions +} + +// NewServerConn starts a new SSH server with c as the underlying +// transport. It starts with a handshake and, if the handshake is +// unsuccessful, it closes the connection and returns an error. The +// Request and NewChannel channels must be serviced, or the connection +// will hang. +// +// The returned error may be of type *ServerAuthError for +// authentication errors. +func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.MaxAuthTries == 0 { + fullConf.MaxAuthTries = 6 + } + if len(fullConf.PublicKeyAuthAlgorithms) == 0 { + fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos + } else { + for _, algo := range fullConf.PublicKeyAuthAlgorithms { + if !contains(supportedPubKeyAuthAlgos, algo) { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo) + } + } + } + // Check if the config contains any unsupported key exchanges + for _, kex := range fullConf.KeyExchanges { + if _, ok := serverForbiddenKexAlgos[kex]; ok { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex) + } + } + + s := &connection{ + sshConn: sshConn{conn: c}, + } + perms, err := s.serverHandshake(&fullConf) + if err != nil { + c.Close() + return nil, nil, nil, err + } + return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil +} + +// signAndMarshal signs the data with the appropriate algorithm, +// and serializes the result in SSH wire format. algo is the negotiate +// algorithm and may be a certificate type. +func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) { + sig, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo)) + if err != nil { + return nil, err + } + + return Marshal(sig), nil +} + +// handshake performs key exchange and user authentication. +func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { + if len(config.hostKeys) == 0 { + return nil, errors.New("ssh: server has no host keys") + } + + if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && + config.KeyboardInteractiveCallback == nil && (config.GSSAPIWithMICConfig == nil || + config.GSSAPIWithMICConfig.AllowLogin == nil || config.GSSAPIWithMICConfig.Server == nil) { + return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") + } + + if config.ServerVersion != "" { + s.serverVersion = []byte(config.ServerVersion) + } else { + s.serverVersion = []byte(packageVersion) + } + var err error + s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) + if err != nil { + return nil, err + } + + tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) + s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) + + if err := s.transport.waitSession(); err != nil { + return nil, err + } + + // We just did the key change, so the session ID is established. + s.sessionID = s.transport.getSessionID() + + var packet []byte + if packet, err = s.transport.readPacket(); err != nil { + return nil, err + } + + var serviceRequest serviceRequestMsg + if err = Unmarshal(packet, &serviceRequest); err != nil { + return nil, err + } + if serviceRequest.Service != serviceUserAuth { + return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") + } + serviceAccept := serviceAcceptMsg{ + Service: serviceUserAuth, + } + if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { + return nil, err + } + + perms, err := s.serverAuthenticate(config) + if err != nil { + return nil, err + } + s.mux = newMux(s.transport) + return perms, err +} + +func checkSourceAddress(addr net.Addr, sourceAddrs string) error { + if addr == nil { + return errors.New("ssh: no address known for client, but source-address match required") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) + } + + for _, sourceAddr := range strings.Split(sourceAddrs, ",") { + if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { + if allowedIP.Equal(tcpAddr.IP) { + return nil + } + } else { + _, ipNet, err := net.ParseCIDR(sourceAddr) + if err != nil { + return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) + } + + if ipNet.Contains(tcpAddr.IP) { + return nil + } + } + } + + return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) +} + +func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection, + sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) { + gssAPIServer := gssapiConfig.Server + defer gssAPIServer.DeleteSecContext() + var srcName string + for { + var ( + outToken []byte + needContinue bool + ) + outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token) + if err != nil { + return err, nil, nil + } + if len(outToken) != 0 { + if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: outToken, + })); err != nil { + return nil, nil, err + } + } + if !needContinue { + break + } + packet, err := s.transport.readPacket() + if err != nil { + return nil, nil, err + } + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return nil, nil, err + } + token = userAuthGSSAPITokenReq.Token + } + packet, err := s.transport.readPacket() + if err != nil { + return nil, nil, err + } + userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{} + if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil { + return nil, nil, err + } + mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method) + if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil { + return err, nil, nil + } + perms, authErr = gssapiConfig.AllowLogin(s, srcName) + return authErr, perms, nil +} + +// isAlgoCompatible checks if the signature format is compatible with the +// selected algorithm taking into account edge cases that occur with old +// clients. +func isAlgoCompatible(algo, sigFormat string) bool { + // Compatibility for old clients. + // + // For certificate authentication with OpenSSH 7.2-7.7 signature format can + // be rsa-sha2-256 or rsa-sha2-512 for the algorithm + // ssh-rsa-cert-v01@openssh.com. + // + // With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512 + // for signature format ssh-rsa. + if isRSA(algo) && isRSA(sigFormat) { + return true + } + // Standard case: the underlying algorithm must match the signature format. + return underlyingAlgo(algo) == sigFormat +} + +// ServerAuthError represents server authentication errors and is +// sometimes returned by NewServerConn. It appends any authentication +// errors that may occur, and is returned if all of the authentication +// methods provided by the user failed to authenticate. +type ServerAuthError struct { + // Errors contains authentication errors returned by the authentication + // callback methods. The first entry is typically ErrNoAuth. + Errors []error +} + +func (l ServerAuthError) Error() string { + var errs []string + for _, err := range l.Errors { + errs = append(errs, err.Error()) + } + return "[" + strings.Join(errs, ", ") + "]" +} + +// ServerAuthCallbacks defines server-side authentication callbacks. +type ServerAuthCallbacks struct { + // PasswordCallback behaves like [ServerConfig.PasswordCallback]. + PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) + + // PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback]. + PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback]. + KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) + + // GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig]. + GSSAPIWithMICConfig *GSSAPIWithMICConfig +} + +// PartialSuccessError can be returned by any of the [ServerConfig] +// authentication callbacks to indicate to the client that authentication has +// partially succeeded, but further steps are required. +type PartialSuccessError struct { + // Next defines the authentication callbacks to apply to further steps. The + // available methods communicated to the client are based on the non-nil + // ServerAuthCallbacks fields. + Next ServerAuthCallbacks +} + +func (p *PartialSuccessError) Error() string { + return "ssh: authenticated with partial success" +} + +// ErrNoAuth is the error value returned if no +// authentication method has been passed yet. This happens as a normal +// part of the authentication loop, since the client first tries +// 'none' authentication to discover available methods. +// It is returned in ServerAuthError.Errors from NewServerConn. +var ErrNoAuth = errors.New("ssh: no auth passed yet") + +// BannerError is an error that can be returned by authentication handlers in +// ServerConfig to send a banner message to the client. +type BannerError struct { + Err error + Message string +} + +func (b *BannerError) Unwrap() error { + return b.Err +} + +func (b *BannerError) Error() string { + if b.Err == nil { + return b.Message + } + return b.Err.Error() +} + +func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { + if config.PreAuthConnCallback != nil { + config.PreAuthConnCallback(s) + } + + sessionID := s.transport.getSessionID() + var cache pubKeyCache + var perms *Permissions + + authFailures := 0 + noneAuthCount := 0 + var authErrs []error + var calledBannerCallback bool + partialSuccessReturned := false + // Set the initial authentication callbacks from the config. They can be + // changed if a PartialSuccessError is returned. + authConfig := ServerAuthCallbacks{ + PasswordCallback: config.PasswordCallback, + PublicKeyCallback: config.PublicKeyCallback, + KeyboardInteractiveCallback: config.KeyboardInteractiveCallback, + GSSAPIWithMICConfig: config.GSSAPIWithMICConfig, + } + +userAuthLoop: + for { + if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 { + discMsg := &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + } + + if err := s.transport.writePacket(Marshal(discMsg)); err != nil { + return nil, err + } + authErrs = append(authErrs, discMsg) + return nil, &ServerAuthError{Errors: authErrs} + } + + var userAuthReq userAuthRequestMsg + if packet, err := s.transport.readPacket(); err != nil { + if err == io.EOF { + return nil, &ServerAuthError{Errors: authErrs} + } + return nil, err + } else if err = Unmarshal(packet, &userAuthReq); err != nil { + return nil, err + } + + if userAuthReq.Service != serviceSSH { + return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) + } + + if s.user != userAuthReq.User && partialSuccessReturned { + return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q", + s.user, userAuthReq.User) + } + + s.user = userAuthReq.User + + if !calledBannerCallback && config.BannerCallback != nil { + calledBannerCallback = true + if msg := config.BannerCallback(s); msg != "" { + if err := s.SendAuthBanner(msg); err != nil { + return nil, err + } + } + } + + perms = nil + authErr := ErrNoAuth + + switch userAuthReq.Method { + case "none": + noneAuthCount++ + // We don't allow none authentication after a partial success + // response. + if config.NoClientAuth && !partialSuccessReturned { + if config.NoClientAuthCallback != nil { + perms, authErr = config.NoClientAuthCallback(s) + } else { + authErr = nil + } + } + case "password": + if authConfig.PasswordCallback == nil { + authErr = errors.New("ssh: password auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 || payload[0] != 0 { + return nil, parseError(msgUserAuthRequest) + } + payload = payload[1:] + password, payload, ok := parseString(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + + perms, authErr = authConfig.PasswordCallback(s, password) + case "keyboard-interactive": + if authConfig.KeyboardInteractiveCallback == nil { + authErr = errors.New("ssh: keyboard-interactive auth not configured") + break + } + + prompter := &sshClientKeyboardInteractive{s} + perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge) + case "publickey": + if authConfig.PublicKeyCallback == nil { + authErr = errors.New("ssh: publickey auth not configured") + break + } + payload := userAuthReq.Payload + if len(payload) < 1 { + return nil, parseError(msgUserAuthRequest) + } + isQuery := payload[0] == 0 + payload = payload[1:] + algoBytes, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + algo := string(algoBytes) + if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) + break + } + + pubKeyData, payload, ok := parseString(payload) + if !ok { + return nil, parseError(msgUserAuthRequest) + } + + pubKey, err := ParsePublicKey(pubKeyData) + if err != nil { + return nil, err + } + + candidate, ok := cache.get(s.user, pubKeyData) + if !ok { + candidate.user = s.user + candidate.pubKeyData = pubKeyData + candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey) + _, isPartialSuccessError := candidate.result.(*PartialSuccessError) + + if (candidate.result == nil || isPartialSuccessError) && + candidate.perms != nil && + candidate.perms.CriticalOptions != nil && + candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { + if err := checkSourceAddress( + s.RemoteAddr(), + candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil { + candidate.result = err + } + } + cache.add(candidate) + } + + if isQuery { + // The client can query if the given public key + // would be okay. + + if len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + _, isPartialSuccessError := candidate.result.(*PartialSuccessError) + if candidate.result == nil || isPartialSuccessError { + okMsg := userAuthPubKeyOkMsg{ + Algo: algo, + PubKey: pubKeyData, + } + if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { + return nil, err + } + continue userAuthLoop + } + authErr = candidate.result + } else { + sig, payload, ok := parseSignature(payload) + if !ok || len(payload) > 0 { + return nil, parseError(msgUserAuthRequest) + } + // Ensure the declared public key algo is compatible with the + // decoded one. This check will ensure we don't accept e.g. + // ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public + // key type. The algorithm and public key type must be + // consistent: both must be certificate algorithms, or neither. + if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) { + authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q", + pubKey.Type(), algo) + break + } + // Ensure the public key algo and signature algo + // are supported. Compare the private key + // algorithm name that corresponds to algo with + // sig.Format. This is usually the same, but + // for certs, the names differ. + if !contains(config.PublicKeyAuthAlgorithms, sig.Format) { + authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format) + break + } + if !isAlgoCompatible(algo, sig.Format) { + authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo) + break + } + + signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData) + + if err := pubKey.Verify(signedData, sig); err != nil { + return nil, err + } + + authErr = candidate.result + perms = candidate.perms + } + case "gssapi-with-mic": + if authConfig.GSSAPIWithMICConfig == nil { + authErr = errors.New("ssh: gssapi-with-mic auth not configured") + break + } + gssapiConfig := authConfig.GSSAPIWithMICConfig + userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload) + if err != nil { + return nil, parseError(msgUserAuthRequest) + } + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication. + if userAuthRequestGSSAPI.N == 0 { + authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported") + break + } + var i uint32 + present := false + for i = 0; i < userAuthRequestGSSAPI.N; i++ { + if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) { + present = true + break + } + } + if !present { + authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism") + break + } + // Initial server response, see RFC 4462 section 3.3. + if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{ + SupportMech: krb5OID, + })); err != nil { + return nil, err + } + // Exchange token, see RFC 4462 section 3.4. + packet, err := s.transport.readPacket() + if err != nil { + return nil, err + } + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return nil, err + } + authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID, + userAuthReq) + if err != nil { + return nil, err + } + default: + authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) + } + + authErrs = append(authErrs, authErr) + + if config.AuthLogCallback != nil { + config.AuthLogCallback(s, userAuthReq.Method, authErr) + } + + var bannerErr *BannerError + if errors.As(authErr, &bannerErr) { + if bannerErr.Message != "" { + if err := s.SendAuthBanner(bannerErr.Message); err != nil { + return nil, err + } + } + } + + if authErr == nil { + break userAuthLoop + } + + var failureMsg userAuthFailureMsg + + if partialSuccess, ok := authErr.(*PartialSuccessError); ok { + // After a partial success error we don't allow changing the user + // name and execute the NoClientAuthCallback. + partialSuccessReturned = true + + // In case a partial success is returned, the server may send + // a new set of authentication methods. + authConfig = partialSuccess.Next + + // Reset pubkey cache, as the new PublicKeyCallback might + // accept a different set of public keys. + cache = pubKeyCache{} + + // Send back a partial success message to the user. + failureMsg.PartialSuccess = true + } else { + // Allow initial attempt of 'none' without penalty. + if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 { + authFailures++ + } + if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries { + // If we have hit the max attempts, don't bother sending the + // final SSH_MSG_USERAUTH_FAILURE message, since there are + // no more authentication methods which can be attempted, + // and this message may cause the client to re-attempt + // authentication while we send the disconnect message. + // Continue, and trigger the disconnect at the start of + // the loop. + // + // The SSH specification is somewhat confusing about this, + // RFC 4252 Section 5.1 requires each authentication failure + // be responded to with a respective SSH_MSG_USERAUTH_FAILURE + // message, but Section 4 says the server should disconnect + // after some number of attempts, but it isn't explicit which + // message should take precedence (i.e. should there be a failure + // message than a disconnect message, or if we are going to + // disconnect, should we only send that message.) + // + // Either way, OpenSSH disconnects immediately after the last + // failed authentication attempt, and given they are typically + // considered the golden implementation it seems reasonable + // to match that behavior. + continue + } + } + + if authConfig.PasswordCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "password") + } + if authConfig.PublicKeyCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "publickey") + } + if authConfig.KeyboardInteractiveCallback != nil { + failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") + } + if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil && + authConfig.GSSAPIWithMICConfig.AllowLogin != nil { + failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic") + } + + if len(failureMsg.Methods) == 0 { + return nil, errors.New("ssh: no authentication methods available") + } + + if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil { + return nil, err + } + } + + if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { + return nil, err + } + return perms, nil +} + +// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by +// asking the client on the other side of a ServerConn. +type sshClientKeyboardInteractive struct { + *connection +} + +func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) { + if len(questions) != len(echos) { + return nil, errors.New("ssh: echos and questions must have equal length") + } + + var prompts []byte + for i := range questions { + prompts = appendString(prompts, questions[i]) + prompts = appendBool(prompts, echos[i]) + } + + if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ + Name: name, + Instruction: instruction, + NumPrompts: uint32(len(questions)), + Prompts: prompts, + })); err != nil { + return nil, err + } + + packet, err := c.transport.readPacket() + if err != nil { + return nil, err + } + if packet[0] != msgUserAuthInfoResponse { + return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) + } + packet = packet[1:] + + n, packet, ok := parseUint32(packet) + if !ok || int(n) != len(questions) { + return nil, parseError(msgUserAuthInfoResponse) + } + + for i := uint32(0); i < n; i++ { + ans, rest, ok := parseString(packet) + if !ok { + return nil, parseError(msgUserAuthInfoResponse) + } + + answers = append(answers, string(ans)) + packet = rest + } + if len(packet) != 0 { + return nil, errors.New("ssh: junk at end of message") + } + + return answers, nil +} diff --git a/tempfork/sshtest/ssh/server_multi_auth_test.go b/tempfork/sshtest/ssh/server_multi_auth_test.go new file mode 100644 index 000000000..3b3980243 --- /dev/null +++ b/tempfork/sshtest/ssh/server_multi_auth_test.go @@ -0,0 +1,412 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "strings" + "testing" +) + +func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAuthErrors []error + + serverConfig.AddHostKey(testSigners["rsa"]) + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err == nil { + c.Close() + } + return serverAuthErrors, err +} + +func TestMultiStepAuth(t *testing.T) { + // This user can login with password, public key or public key + password. + username := "testuser" + // This user can login with public key + password only. + usernameSecondFactor := "testuser_second_factor" + errPwdAuthFailed := errors.New("password auth failed") + errWrongSequence := errors.New("wrong sequence") + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == usernameSecondFactor { + return nil, errWrongSequence + } + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + if conn.User() == usernameSecondFactor { + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + }, + } + } + return nil, nil + } + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + } + + clientConfig := &ClientConfig{ + User: usernameSecondFactor, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + + // The error sequence is: + // - no auth passed yet + // - partial success + // - nil + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1]) + } + // Now test a wrong sequence. + clientConfig.Auth = []AuthMethod{ + Password(clientPassword), + PublicKeys(testSigners["rsa"]), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with wrong sequence must fail") + } + // The error sequence is: + // - no auth passed yet + // - wrong sequence + // - partial success + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if serverAuthErrors[1] != errWrongSequence { + t.Fatal("server not returned wrong sequence") + } + if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok { + t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2]) + } + // Now test using a correct sequence but a wrong password before the right + // one. + n := 0 + passwords := []string{"WRONG", "WRONG", clientPassword} + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 3), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - wrong password + // - wrong password + // - nil + if len(serverAuthErrors) != 5 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + if serverAuthErrors[2] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + if serverAuthErrors[3] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + // Only password authentication should fail. + clientConfig.Auth = []AuthMethod{ + Password(clientPassword), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with password only must fail") + } + // The error sequence is: + // - no auth passed yet + // - wrong sequence + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if serverAuthErrors[1] != errWrongSequence { + t.Fatal("server not returned wrong sequence") + } + + // Only public key authentication should fail. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with public key only must fail") + } + // The error sequence is: + // - no auth passed yet + // - partial success + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // Public key and wrong password. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password("WRONG"), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with wrong password after public key must fail") + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - password auth failed + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + if serverAuthErrors[2] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + + // Public key, public key again and then correct password. Public key + // authentication is attempted only once because the partial success error + // returns only "password" as the allowed authentication method. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - nil + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // The unrestricted username can do anything + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } + + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } + + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } +} + +func TestDynamicAuthCallbacks(t *testing.T) { + user1 := "user1" + user2 := "user2" + errInvalidCredentials := errors.New("invalid credentials") + + serverConfig := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + switch conn.User() { + case user1: + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == user1 && string(password) == clientPassword { + return nil, nil + } + return nil, errInvalidCredentials + }, + }, + } + case user2: + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + if conn.User() == user2 { + return nil, nil + } + } + return nil, errInvalidCredentials + }, + }, + } + default: + return nil, errInvalidCredentials + } + }, + } + + clientConfig := &ClientConfig{ + User: user1, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - partial success + // - nil + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + clientConfig = &ClientConfig{ + User: user2, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - partial success + // - nil + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // user1 cannot login with public key + clientConfig = &ClientConfig{ + User: user1, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("user1 login with public key must fail") + } + if !strings.Contains(err.Error(), "no supported methods remain") { + t.Errorf("got %v, expected 'no supported methods remain'", err) + } + if len(serverAuthErrors) != 1 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + // user2 cannot login with password + clientConfig = &ClientConfig{ + User: user2, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("user2 login with password must fail") + } + if !strings.Contains(err.Error(), "no supported methods remain") { + t.Errorf("got %v, expected 'no supported methods remain'", err) + } + if len(serverAuthErrors) != 1 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } +} diff --git a/tempfork/sshtest/ssh/server_test.go b/tempfork/sshtest/ssh/server_test.go new file mode 100644 index 000000000..c2b24f47c --- /dev/null +++ b/tempfork/sshtest/ssh/server_test.go @@ -0,0 +1,478 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { + for _, tt := range []struct { + name string + key Signer + wantError bool + }{ + {"rsa", testSigners["rsa"], false}, + {"dsa", testSigners["dsa"], true}, + {"ed25519", testSigners["ed25519"], true}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(tt.key), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + if !tt.wantError { + t.Errorf("%s: got unexpected error %q", tt.name, err.Error()) + } + } else if tt.wantError { + t.Errorf("%s: succeeded, but want error", tt.name) + } + <-done + } +} + +func TestMaxAuthTriesNoneMethod(t *testing.T) { + username := "testuser" + serverConfig := &ServerConfig{ + MaxAuthTries: 2, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errors.New("invalid credentials") + }, + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAuthErrors []error + + serverConfig.AddHostKey(testSigners["rsa"]) + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + go newServer(c1, serverConfig) + + clientConfig := ClientConfig{ + User: username, + HostKeyCallback: InsecureIgnoreHostKey(), + } + clientConfig.SetDefaults() + // Our client will send 'none' auth only once, so we need to send the + // requests manually. + c := &connection{ + sshConn: sshConn{ + conn: c2, + user: username, + clientVersion: []byte(packageVersion), + }, + } + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + t.Fatalf("unable to exchange version: %v", err) + } + c.transport = newClientTransport( + newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */), + c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + t.Fatalf("unable to wait session: %v", err) + } + c.sessionID = c.transport.getSessionID() + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + t.Fatalf("unable to send ssh-userauth message: %v", err) + } + packet, err := c.transport.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) > 0 && packet[0] == msgExtInfo { + packet, err = c.transport.readPacket() + if err != nil { + t.Fatal(err) + } + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + t.Fatal(err) + } + for i := 0; i <= serverConfig.MaxAuthTries; i++ { + auth := new(noneAuth) + _, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil) + if i < serverConfig.MaxAuthTries { + if err != nil { + t.Fatal(err) + } + continue + } + if err == nil { + t.Fatal("client: got no error") + } else if !strings.Contains(err.Error(), "too many authentication failures") { + t.Fatalf("client: got unexpected error: %v", err) + } + } + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + for _, err := range serverAuthErrors { + if !errors.Is(err, ErrNoAuth) { + t.Errorf("go error: %v; want: %v", err, ErrNoAuth) + } + } +} + +func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) { + username := "testuser" + serverConfig := &ServerConfig{ + MaxAuthTries: 1, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errors.New("invalid credentials") + }, + } + clientConfig := &ClientConfig{ + User: username, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if !errors.Is(serverAuthErrors[0], ErrNoAuth) { + t.Errorf("go error: %v; want: %v", serverAuthErrors[0], ErrNoAuth) + } + if serverAuthErrors[1] != nil { + t.Errorf("unexpected error: %v", serverAuthErrors[1]) + } +} + +func TestNewServerConnValidationErrors(t *testing.T) { + serverConf := &ServerConfig{ + PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01}, + } + c := &markerConn{} + _, _, _, err := NewServerConn(c, serverConf) + if err == nil { + t.Fatal("NewServerConn with invalid public key auth algorithms succeeded") + } + if !c.isClosed() { + t.Fatal("NewServerConn with invalid public key auth algorithms left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with invalid public key auth algorithms used connection") + } + + serverConf = &ServerConfig{ + Config: Config{ + KeyExchanges: []string{kexAlgoDHGEXSHA256}, + }, + } + c = &markerConn{} + _, _, _, err = NewServerConn(c, serverConf) + if err == nil { + t.Fatal("NewServerConn with unsupported key exchange succeeded") + } + if !c.isClosed() { + t.Fatal("NewServerConn with unsupported key exchange left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with unsupported key exchange used connection") + } +} + +func TestBannerError(t *testing.T) { + serverConfig := &ServerConfig{ + BannerCallback: func(ConnMetadata) string { + return "banner from BannerCallback" + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + err := &BannerError{ + Err: errors.New("error from NoClientAuthCallback"), + Message: "banner from NoClientAuthCallback", + } + return nil, fmt.Errorf("wrapped: %w", err) + }, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, &BannerError{ + Err: errors.New("error from PublicKeyCallback"), + Message: "banner from PublicKeyCallback", + } + }, + KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) { + return nil, &BannerError{ + Err: nil, // make sure that a nil inner error is allowed + Message: "banner from KeyboardInteractiveCallback", + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{"letmein"}, nil + }), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + banners = append(banners, msg) + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "banner from BannerCallback", + "banner from NoClientAuthCallback", + "banner from PublicKeyCallback", + "banner from KeyboardInteractiveCallback", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } +} + +func TestPublicKeyCallbackLastSeen(t *testing.T) { + var lastSeenKey PublicKey + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + lastSeenKey = key + fmt.Printf("seen %#v\n", key) + if _, ok := key.(*dsaPublicKey); !ok { + return nil, errors.New("nope") + } + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + <-done + + expectedPublicKey := testSigners["dsa"].PublicKey().Marshal() + lastSeenMarshalled := lastSeenKey.Marshal() + if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) { + t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey()) + } +} + +func TestPreAuthConnAndBanners(t *testing.T) { + testDone := make(chan struct{}) + defer close(testDone) + + authConnc := make(chan ServerPreAuthConn, 1) + serverConfig := &ServerConfig{ + PreAuthConnCallback: func(c ServerPreAuthConn) { + t.Logf("got ServerPreAuthConn: %v", c) + authConnc <- c // for use later in the test + for _, s := range []string{"hello1", "hello2"} { + if err := c.SendAuthBanner(s); err != nil { + t.Errorf("failed to send banner %q: %v", s, err) + } + } + // Now start a goroutine to spam SendAuthBanner in hopes + // of hitting a race. + go func() { + for { + select { + case <-testDone: + return + default: + if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase { + t.Errorf("unexpected error from SendAuthBanner: %v", err) + } + time.Sleep(5 * time.Millisecond) + } + } + }() + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + t.Logf("got NoClientAuthCallback") + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + if msg != "attempted-race" { + banners = append(banners, msg) + } + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "hello1", + "hello2", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } + + // Now that we're authenticated, verify that use of SendBanner + // is an error. + var bc ServerPreAuthConn + select { + case bc = <-authConnc: + default: + t.Fatal("expected ServerPreAuthConn") + } + if err := bc.SendAuthBanner("wrong-phase"); err == nil { + t.Error("unexpected success of SendAuthBanner after authentication") + } else if err != errSendBannerPhase { + t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase) + } +} + +type markerConn struct { + closed uint32 + used uint32 +} + +func (c *markerConn) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *markerConn) isUsed() bool { + return atomic.LoadUint32(&c.used) != 0 +} + +func (c *markerConn) Close() error { + atomic.StoreUint32(&c.closed, 1) + return nil +} + +func (c *markerConn) Read(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.EOF + } +} + +func (c *markerConn) Write(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.ErrClosedPipe + } +} + +func (*markerConn) LocalAddr() net.Addr { return nil } +func (*markerConn) RemoteAddr() net.Addr { return nil } + +func (*markerConn) SetDeadline(t time.Time) error { return nil } +func (*markerConn) SetReadDeadline(t time.Time) error { return nil } +func (*markerConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/tempfork/sshtest/ssh/session.go b/tempfork/sshtest/ssh/session.go new file mode 100644 index 000000000..acef62259 --- /dev/null +++ b/tempfork/sshtest/ssh/session.go @@ -0,0 +1,647 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + IUTF8 = 42 // RFC 8160 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. + // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of io.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. +func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.7. +type ptyWindowChangeMsg struct { + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 +} + +// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. +func (s *Session) WindowChange(h, w int) error { + req := ptyWindowChangeMsg{ + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + } + _, err := s.ch.SendRequest("window-change", false, Marshal(&req)) + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. +type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. +func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return errors.New("ssh: could not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + wm.status = int(binary.BigEndian.Uint32(msg.Payload)) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + // signal was not sent either. RFC 4254 + // section 6.10 recommends against this + // behavior, but it is allowed, so we let + // clients handle it. + return &ExitMissingError{} + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + + return &ExitError{wm} +} + +// ExitMissingError is returned if a session is torn down cleanly, but +// the server sends no confirmation of the exit status. +type ExitMissingError struct{} + +func (e *ExitMissingError) Error() string { + return "wait: remote command exited without exit status or exit signal" +} + +func (s *Session) stdin() { + if s.stdinpipe { + return + } + var stdin io.Reader + if s.Stdin == nil { + stdin = new(bytes.Buffer) + } else { + r, w := io.Pipe() + go func() { + _, err := io.Copy(w, s.Stdin) + w.CloseWithError(err) + }() + stdin, s.stdinPipeWriter = r, w + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.ch, stdin) + if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { + err = err1 + } + return err + }) +} + +func (s *Session) stdout() { + if s.stdoutpipe { + return + } + if s.Stdout == nil { + s.Stdout = io.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stdout, s.ch) + return err + }) +} + +func (s *Session) stderr() { + if s.stderrpipe { + return + } + if s.Stderr == nil { + s.Stderr = io.Discard + } + s.copyFuncs = append(s.copyFuncs, func() error { + _, err := io.Copy(s.Stderr, s.ch.Stderr()) + return err + }) +} + +// sessionStdin reroutes Close to CloseWrite. +type sessionStdin struct { + io.Writer + ch Channel +} + +func (s *sessionStdin) Close() error { + return s.ch.CloseWrite() +} + +// StdinPipe returns a pipe that will be connected to the +// remote command's standard input when the command starts. +func (s *Session) StdinPipe() (io.WriteCloser, error) { + if s.Stdin != nil { + return nil, errors.New("ssh: Stdin already set") + } + if s.started { + return nil, errors.New("ssh: StdinPipe after process started") + } + s.stdinpipe = true + return &sessionStdin{s.ch, s.ch}, nil +} + +// StdoutPipe returns a pipe that will be connected to the +// remote command's standard output when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StdoutPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StdoutPipe() (io.Reader, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.started { + return nil, errors.New("ssh: StdoutPipe after process started") + } + s.stdoutpipe = true + return s.ch, nil +} + +// StderrPipe returns a pipe that will be connected to the +// remote command's standard error when the command starts. +// There is a fixed amount of buffering that is shared between +// stdout and stderr streams. If the StderrPipe reader is +// not serviced fast enough it may eventually cause the +// remote command to block. +func (s *Session) StderrPipe() (io.Reader, error) { + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + if s.started { + return nil, errors.New("ssh: StderrPipe after process started") + } + s.stderrpipe = true + return s.ch.Stderr(), nil +} + +// newSession returns a new interactive session on the remote host. +func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { + s := &Session{ + ch: ch, + } + s.exitStatus = make(chan error, 1) + go func() { + s.exitStatus <- s.wait(reqs) + }() + + return s, nil +} + +// An ExitError reports unsuccessful completion of a remote command. +type ExitError struct { + Waitmsg +} + +func (e *ExitError) Error() string { + return e.Waitmsg.String() +} + +// Waitmsg stores the information about an exited remote command +// as reported by Wait. +type Waitmsg struct { + status int + signal string + msg string + lang string +} + +// ExitStatus returns the exit status of the remote command. +func (w Waitmsg) ExitStatus() int { + return w.status +} + +// Signal returns the exit signal of the remote command if +// it was terminated violently. +func (w Waitmsg) Signal() string { + return w.signal +} + +// Msg returns the exit message given by the remote command +func (w Waitmsg) Msg() string { + return w.msg +} + +// Lang returns the language tag. See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + str := fmt.Sprintf("Process exited with status %v", w.status) + if w.signal != "" { + str += fmt.Sprintf(" from signal %v", w.signal) + } + if w.msg != "" { + str += fmt.Sprintf(". Reason was: %v", w.msg) + } + return str +} diff --git a/tempfork/sshtest/ssh/session_test.go b/tempfork/sshtest/ssh/session_test.go new file mode 100644 index 000000000..807a913e5 --- /dev/null +++ b/tempfork/sshtest/ssh/session_test.go @@ -0,0 +1,892 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session tests. + +import ( + "bytes" + crypto_rand "crypto/rand" + "errors" + "io" + "math/rand" + "net" + "sync" + "testing" + + "golang.org/x/crypto/ssh/terminal" +) + +type serverType func(Channel, <-chan *Request, *testing.T) + +// dial constructs a new test server and returns a *ClientConn. +func dial(handler serverType, t *testing.T) *Client { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer func() { + c1.Close() + wg.Done() + }() + conf := ServerConfig{ + NoClientAuth: true, + } + conf.AddHostKey(testSigners["rsa"]) + + conn, chans, reqs, err := NewServerConn(c1, &conf) + if err != nil { + t.Errorf("Unable to handshake: %v", err) + return + } + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + + for newCh := range chans { + if newCh.ChannelType() != "session" { + newCh.Reject(UnknownChannelType, "unknown channel type") + continue + } + + ch, inReqs, err := newCh.Accept() + if err != nil { + t.Errorf("Accept: %v", err) + continue + } + wg.Add(1) + go func() { + handler(ch, inReqs, t) + wg.Done() + }() + } + if err := conn.Wait(); err != io.EOF { + t.Logf("server exit reason: %v", err) + } + }() + + config := &ClientConfig{ + User: "testuser", + HostKeyCallback: InsecureIgnoreHostKey(), + } + + conn, chans, reqs, err := NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("unable to dial remote side: %v", err) + } + + return NewClient(conn, chans, reqs) +} + +// Test a simple string is returned to session.Stdout. +func TestSessionShell(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout := new(bytes.Buffer) + session.Stdout = stdout + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %s", err) + } + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + actual := stdout.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. + +// Test a simple string is returned via StdoutPipe. +func TestSessionStdoutPipe(t *testing.T) { + conn := dial(shellHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("Unable to request StdoutPipe(): %v", err) + } + var buf bytes.Buffer + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + done := make(chan bool, 1) + go func() { + if _, err := io.Copy(&buf, stdout); err != nil { + t.Errorf("Copy of stdout failed: %v", err) + } + done <- true + }() + if err := session.Wait(); err != nil { + t.Fatalf("Remote command did not exit cleanly: %v", err) + } + <-done + actual := buf.String() + if actual != "golang" { + t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) + } +} + +// Test that a simple string is returned via the Output helper, +// and that stderr is discarded. +func TestSessionOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.Output("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + w := "this-is-stdout." + g := string(buf) + if g != w { + t.Error("Remote command did not return expected string:") + t.Logf("want %q", w) + t.Logf("got %q", g) + } +} + +// Test that both stdout and stderr are returned +// via the CombinedOutput helper. +func TestSessionCombinedOutput(t *testing.T) { + conn := dial(fixedOutputHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler + if err != nil { + t.Error("Remote command did not exit cleanly:", err) + } + const stdout = "this-is-stdout." + const stderr = "this-is-stderr." + g := string(buf) + if g != stdout+stderr && g != stderr+stdout { + t.Error("Remote command did not return expected string:") + t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) + t.Logf("got %q", g) + } +} + +// Test non-0 exit status is returned correctly. +func TestExitStatusNonZero(t *testing.T) { + conn := dial(exitStatusNonZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) + } +} + +// Test 0 exit status is returned correctly. +func TestExitStatusZero(t *testing.T) { + conn := dial(exitStatusZeroHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got %v", err) + } +} + +// Test exit signal and status are both returned correctly. +func TestExitSignalAndStatus(t *testing.T) { + conn := dial(exitSignalAndStatusHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 15 { + t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestKnownExitSignalOnly(t *testing.T) { + conn := dial(exitSignalHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "TERM" || e.ExitStatus() != 143 { + t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +// Test exit signal and status are both returned correctly. +func TestUnknownExitSignal(t *testing.T) { + conn := dial(exitSignalUnknownHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + e, ok := err.(*ExitError) + if !ok { + t.Fatalf("expected *ExitError but got %T", err) + } + if e.Signal() != "SYS" || e.ExitStatus() != 128 { + t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) + } +} + +func TestExitWithoutStatusOrSignal(t *testing.T) { + conn := dial(exitWithoutSignalOrStatus, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatalf("Unable to request new session: %v", err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err == nil { + t.Fatalf("expected command to fail but it didn't") + } + if _, ok := err.(*ExitMissingError); !ok { + t.Fatalf("got %T want *ExitMissingError", err) + } +} + +// windowTestBytes is the number of bytes that we'll send to the SSH server. +const windowTestBytes = 16000 * 200 + +// TestServerWindow writes random data to the server. The server is expected to echo +// the same data back, which is compared against the original. +func TestServerWindow(t *testing.T) { + origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) + origBytes := origBuf.Bytes() + + conn := dial(echoHandler, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + + serverStdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + + result := make(chan []byte) + go func() { + defer close(result) + echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) + serverStdout, err := session.StdoutPipe() + if err != nil { + t.Errorf("StdoutPipe failed: %v", err) + return + } + n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) + if err != nil && err != io.EOF { + t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) + } + result <- echoedBuf.Bytes() + }() + + written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) + if err != nil { + t.Errorf("failed to copy origBuf to serverStdin: %v", err) + } else if written != windowTestBytes { + t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes) + } + + echoedBytes := <-result + + if !bytes.Equal(origBytes, echoedBytes) { + t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) + } +} + +// Verify the client can handle a keepalive packet from the server. +func TestClientHandlesKeepalives(t *testing.T) { + conn := dial(channelKeepaliveSender, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got: %v", err) + } +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Errmsg string + Lang string +} + +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) + } +} + +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(0, ch, t) +} + +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) +} + +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) + sendSignal("TERM", ch, t) +} + +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("TERM", ch, t) +} + +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("SYS", ch, t) +} + +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) +} + +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "golang") + readLine(shell, t) + sendStatus(0, ch, t) +} + +// Ignores the command, writes fixed strings to stderr and stdout. +// Strings are "this-is-stdout." and "this-is-stderr.". +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + _, err := ch.Read(nil) + + req, ok := <-in + if !ok { + t.Fatalf("error: expected channel request, got: %#v", err) + return + } + + // ignore request, always send some text + req.Reply(true, nil) + + _, err = io.WriteString(ch, "this-is-stdout.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + sendStatus(0, ch, t) +} + +func readLine(shell *terminal.Terminal, t *testing.T) { + if _, err := shell.ReadLine(); err != nil && err != io.EOF { + t.Errorf("unable to read line: %v", err) + } +} + +func sendStatus(status uint32, ch Channel, t *testing.T) { + msg := exitStatusMsg{ + Status: status, + } + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { + t.Errorf("unable to send status: %v", err) + } +} + +func sendSignal(signal string, ch Channel, t *testing.T) { + sig := exitSignalMsg{ + Signal: signal, + CoreDumped: false, + Errmsg: "Process terminated", + Lang: "en-GB-oed", + } + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { + t.Errorf("unable to send signal: %v", err) + } +} + +func discardHandler(ch Channel, t *testing.T) { + defer ch.Close() + io.Copy(io.Discard, ch) +} + +func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { + t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) + } +} + +// copyNRandomly copies n bytes from src to dst. It uses a variable, and random, +// buffer size to exercise more code paths. +func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { + var ( + buf = make([]byte, 32*1024) + written int + remaining = n + ) + for remaining > 0 { + l := rand.Intn(1 << 15) + if remaining < l { + l = remaining + } + nr, er := src.Read(buf[:l]) + nw, ew := dst.Write(buf[:nr]) + remaining -= nw + written += nw + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + if er != nil && er != io.EOF { + return written, er + } + } + return written, nil +} + +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { + t.Errorf("unable to send channel keepalive request: %v", err) + } + sendStatus(0, ch, t) +} + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := io.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + data, err := io.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} + +func TestSessionID(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverID := make(chan []byte, 1) + clientID := make(chan []byte, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: "user", + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + srvErrCh := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + conn, chans, reqs, err := NewServerConn(c1, serverConf) + srvErrCh <- err + if err != nil { + return + } + serverID <- conn.SessionID() + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + cliErrCh := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + cliErrCh <- err + if err != nil { + return + } + clientID <- conn.SessionID() + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + if err := <-srvErrCh; err != nil { + t.Fatalf("server handshake: %v", err) + } + + if err := <-cliErrCh; err != nil { + t.Fatalf("client handshake: %v", err) + } + + s := <-serverID + c := <-clientID + if bytes.Compare(s, c) != 0 { + t.Errorf("server session ID (%x) != client session ID (%x)", s, c) + } else if len(s) == 0 { + t.Errorf("client and server SessionID were empty.") + } +} + +type noReadConn struct { + readSeen bool + net.Conn +} + +func (c *noReadConn) Close() error { + return nil +} + +func (c *noReadConn) Read(b []byte) (int, error) { + c.readSeen = true + return 0, errors.New("noReadConn error") +} + +func TestInvalidServerConfiguration(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serveConn := noReadConn{Conn: c1} + serverConf := &ServerConfig{} + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") + } + + serverConf.AddHostKey(testSigners["ecdsa"]) + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") + } +} + +func TestHostKeyAlgorithms(t *testing.T) { + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.AddHostKey(testSigners["ecdsa"]) + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + connect := func(clientConf *ClientConfig, want string) { + var alg string + clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { + alg = key.Type() + return nil + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + if alg != want { + t.Errorf("selected key algorithm %s, want %s", alg, want) + } + } + + // By default, we get the preferred algorithm, which is ECDSA 256. + + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + } + connect(clientConf, KeyAlgoECDSA256) + + // Client asks for RSA explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} + connect(clientConf, KeyAlgoRSA) + + // Client asks for RSA-SHA2-512 explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512} + // We get back an "ssh-rsa" key but the verification happened + // with an RSA-SHA2-512 signature. + connect(clientConf, KeyAlgoRSA) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() + clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with unknown hostkey algorithm") + } +} + +func TestServerClientAuthCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + userCh := make(chan string, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + userCh <- conn.User() + return nil, nil + }, + } + const someUsername = "some-username" + + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: someUsername, + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + _, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Errorf("server handshake: %v", err) + userCh <- "error" + return + } + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + conn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + return + } + conn.Close() + + got := <-userCh + if got != someUsername { + t.Errorf("username = %q; want %q", got, someUsername) + } +} diff --git a/tempfork/sshtest/ssh/ssh_gss.go b/tempfork/sshtest/ssh/ssh_gss.go new file mode 100644 index 000000000..24bd7c8e8 --- /dev/null +++ b/tempfork/sshtest/ssh/ssh_gss.go @@ -0,0 +1,139 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/asn1" + "errors" +) + +var krb5OID []byte + +func init() { + krb5OID, _ = asn1.Marshal(krb5Mesh) +} + +// GSSAPIClient provides the API to plug-in GSSAPI authentication for client logins. +type GSSAPIClient interface { + // InitSecContext initiates the establishment of a security context for GSS-API between the + // ssh client and ssh server. Initially the token parameter should be specified as nil. + // The routine may return a outputToken which should be transferred to + // the ssh server, where the ssh server will present it to + // AcceptSecContext. If no token need be sent, InitSecContext will indicate this by setting + // needContinue to false. To complete the context + // establishment, one or more reply tokens may be required from the ssh + // server;if so, InitSecContext will return a needContinue which is true. + // In this case, InitSecContext should be called again when the + // reply token is received from the ssh server, passing the reply + // token to InitSecContext via the token parameters. + // See RFC 2743 section 2.2.1 and RFC 4462 section 3.4. + InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) + // GetMIC generates a cryptographic MIC for the SSH2 message, and places + // the MIC in a token for transfer to the ssh server. + // The contents of the MIC field are obtained by calling GSS_GetMIC() + // over the following, using the GSS-API context that was just + // established: + // string session identifier + // byte SSH_MSG_USERAUTH_REQUEST + // string user name + // string service + // string "gssapi-with-mic" + // See RFC 2743 section 2.3.1 and RFC 4462 3.5. + GetMIC(micFiled []byte) ([]byte, error) + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins. +type GSSAPIServer interface { + // AcceptSecContext allows a remotely initiated security context between the application + // and a remote peer to be established by the ssh client. The routine may return a + // outputToken which should be transferred to the ssh client, + // where the ssh client will present it to InitSecContext. + // If no token need be sent, AcceptSecContext will indicate this + // by setting the needContinue to false. To + // complete the context establishment, one or more reply tokens may be + // required from the ssh client. if so, AcceptSecContext + // will return a needContinue which is true, in which case it + // should be called again when the reply token is received from the ssh + // client, passing the token to AcceptSecContext via the + // token parameters. + // The srcName return value is the authenticated username. + // See RFC 2743 section 2.2.2 and RFC 4462 section 3.4. + AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) + // VerifyMIC verifies that a cryptographic MIC, contained in the token parameter, + // fits the supplied message is received from the ssh client. + // See RFC 2743 section 2.3.2. + VerifyMIC(micField []byte, micToken []byte) error + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +var ( + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication, + // so we also support the krb5 mechanism only. + // See RFC 1964 section 1. + krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2} +) + +// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST +// See RFC 4462 section 3.2. +type userAuthRequestGSSAPI struct { + N uint32 + OIDS []asn1.ObjectIdentifier +} + +func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) { + n, rest, ok := parseUint32(payload) + if !ok { + return nil, errors.New("parse uint32 failed") + } + s := &userAuthRequestGSSAPI{ + N: n, + OIDS: make([]asn1.ObjectIdentifier, n), + } + for i := 0; i < int(n); i++ { + var ( + desiredMech []byte + err error + ) + desiredMech, rest, ok = parseString(rest) + if !ok { + return nil, errors.New("parse string failed") + } + if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil { + return nil, err + } + + } + return s, nil +} + +// See RFC 4462 section 3.6. +func buildMIC(sessionID string, username string, service string, authMethod string) []byte { + out := make([]byte, 0, 0) + out = appendString(out, sessionID) + out = append(out, msgUserAuthRequest) + out = appendString(out, username) + out = appendString(out, service) + out = appendString(out, authMethod) + return out +} diff --git a/tempfork/sshtest/ssh/ssh_gss_test.go b/tempfork/sshtest/ssh/ssh_gss_test.go new file mode 100644 index 000000000..39a111288 --- /dev/null +++ b/tempfork/sshtest/ssh/ssh_gss_test.go @@ -0,0 +1,109 @@ +package ssh + +import ( + "fmt" + "testing" +) + +func TestParseGSSAPIPayload(t *testing.T) { + payload := []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0b, 0x06, 0x09, + 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02} + res, err := parseGSSAPIPayload(payload) + if err != nil { + t.Fatal(err) + } + if ok := res.OIDS[0].Equal(krb5Mesh); !ok { + t.Fatalf("got %v, want %v", res, krb5Mesh) + } +} + +func TestBuildMIC(t *testing.T) { + sessionID := []byte{134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53, + 183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121} + username := "testuser" + service := "ssh-connection" + authMethod := "gssapi-with-mic" + expected := []byte{0, 0, 0, 32, 134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53, 183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121, 50, 0, 0, 0, 8, 116, 101, 115, 116, 117, 115, 101, 114, 0, 0, 0, 14, 115, 115, 104, 45, 99, 111, 110, 110, 101, 99, 116, 105, 111, 110, 0, 0, 0, 15, 103, 115, 115, 97, 112, 105, 45, 119, 105, 116, 104, 45, 109, 105, 99} + result := buildMIC(string(sessionID), username, service, authMethod) + if string(result) != string(expected) { + t.Fatalf("buildMic: got %v, want %v", result, expected) + } +} + +type exchange struct { + outToken string + expectedToken string +} + +type FakeClient struct { + exchanges []*exchange + round int + mic []byte + maxRound int +} + +func (f *FakeClient) InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) { + if token == nil { + if f.exchanges[f.round].expectedToken != "" { + err = fmt.Errorf("got empty token, want %q", f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } else { + if string(token) != string(f.exchanges[f.round].expectedToken) { + err = fmt.Errorf("got %q, want token %q", token, f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } + f.round++ + needContinue = f.round < f.maxRound + return +} + +func (f *FakeClient) GetMIC(micField []byte) ([]byte, error) { + return f.mic, nil +} + +func (f *FakeClient) DeleteSecContext() error { + return nil +} + +type FakeServer struct { + exchanges []*exchange + round int + expectedMIC []byte + srcName string + maxRound int +} + +func (f *FakeServer) AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) { + if token == nil { + if f.exchanges[f.round].expectedToken != "" { + err = fmt.Errorf("got empty token, want %q", f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } else { + if string(token) != string(f.exchanges[f.round].expectedToken) { + err = fmt.Errorf("got %q, want token %q", token, f.exchanges[f.round].expectedToken) + } else { + outputToken = []byte(f.exchanges[f.round].outToken) + } + } + f.round++ + needContinue = f.round < f.maxRound + srcName = f.srcName + return +} + +func (f *FakeServer) VerifyMIC(micField []byte, micToken []byte) error { + if string(micToken) != string(f.expectedMIC) { + return fmt.Errorf("got MICToken %q, want %q", micToken, f.expectedMIC) + } + return nil +} + +func (f *FakeServer) DeleteSecContext() error { + return nil +} diff --git a/tempfork/sshtest/ssh/streamlocal.go b/tempfork/sshtest/ssh/streamlocal.go new file mode 100644 index 000000000..b171b330b --- /dev/null +++ b/tempfork/sshtest/ssh/streamlocal.go @@ -0,0 +1,116 @@ +package ssh + +import ( + "errors" + "io" + "net" +) + +// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message +// with "direct-streamlocal@openssh.com" string. +// +// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235 +type streamLocalChannelOpenDirectMsg struct { + socketPath string + reserved0 string + reserved1 uint32 +} + +// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message +// with "forwarded-streamlocal@openssh.com" string. +type forwardedStreamLocalPayload struct { + SocketPath string + Reserved0 string +} + +// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message +// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string. +type streamLocalChannelForwardMsg struct { + socketPath string +} + +// ListenUnix is similar to ListenTCP but uses a Unix domain socket. +func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + m := streamLocalChannelForwardMsg{ + socketPath, + } + // send message + ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer") + } + ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"}) + + return &unixListener{socketPath, c, ch}, nil +} + +func (c *Client) dialStreamLocal(socketPath string) (Channel, error) { + msg := streamLocalChannelOpenDirectMsg{ + socketPath: socketPath, + } + ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type unixListener struct { + socketPath string + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *unixListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: l.socketPath, + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + }, nil +} + +// Close closes the listener. +func (l *unixListener) Close() error { + // this also closes the listener. + l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"}) + m := streamLocalChannelForwardMsg{ + l.socketPath, + } + ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *unixListener) Addr() net.Addr { + return &net.UnixAddr{ + Name: l.socketPath, + Net: "unix", + } +} diff --git a/tempfork/sshtest/ssh/tcpip.go b/tempfork/sshtest/ssh/tcpip.go new file mode 100644 index 000000000..ef5059a11 --- /dev/null +++ b/tempfork/sshtest/ssh/tcpip.go @@ -0,0 +1,509 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "context" + "errors" + "fmt" + "io" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" +) + +// Listen requests the remote peer open a listening socket on +// addr. Incoming connections will be available by calling Accept on +// the returned net.Listener. The listener must be serviced, or the +// SSH connection may hang. +// N must be "tcp", "tcp4", "tcp6", or "unix". +func (c *Client) Listen(n, addr string) (net.Listener, error) { + switch n { + case "tcp", "tcp4", "tcp6": + laddr, err := net.ResolveTCPAddr(n, addr) + if err != nil { + return nil, err + } + return c.ListenTCP(laddr) + case "unix": + return c.ListenUnix(addr) + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// Automatic port allocation is broken with OpenSSH before 6.0. See +// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In +// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, +// rather than the actual port number. This means you can never open +// two different listeners with auto allocated ports. We work around +// this by trying explicit ports until we succeed. + +const openSSHPrefix = "OpenSSH_" + +var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) + +// isBrokenOpenSSHVersion returns true if the given version string +// specifies a version of OpenSSH that is known to have a bug in port +// forwarding. +func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// handleForwards starts goroutines handling forwarded connections. +// It's called on first use by (*Client).ListenTCP to not launch +// goroutines until needed. +func (c *Client) handleForwards() { + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip")) + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com")) +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. +func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.Addr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. +type forward struct { + newCh NewChannel // the ssh client channel underlying this forward + raddr net.Addr // the raddr of the incoming connection +} + +func (l *forwardList) add(addr net.Addr) chan forward { + l.Lock() + defer l.Unlock() + f := forwardEntry{ + laddr: addr, + c: make(chan forward, 1), + } + l.entries = append(l.entries, f) + return f.c +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. +func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { + if port == 0 || port > 65535 { + return nil, fmt.Errorf("ssh: port number out of range: %d", port) + } + ip := net.ParseIP(string(addr)) + if ip == nil { + return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) + } + return &net.TCPAddr{IP: ip, Port: int(port)}, nil +} + +func (l *forwardList) handleChannels(in <-chan NewChannel) { + for ch := range in { + var ( + laddr net.Addr + raddr net.Addr + err error + ) + switch channelType := ch.ChannelType(); channelType { + case "forwarded-tcpip": + var payload forwardedTCPPayload + if err = Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) + continue + } + + // RFC 4254 section 7.2 specifies that incoming + // addresses should list the address, in string + // format. It is implied that this should be an IP + // address, as it would be impossible to connect to it + // otherwise. + laddr, err = parseTCPAddr(payload.Addr, payload.Port) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort) + if err != nil { + ch.Reject(ConnectionFailed, err.Error()) + continue + } + + case "forwarded-streamlocal@openssh.com": + var payload forwardedStreamLocalPayload + if err = Unmarshal(ch.ExtraData(), &payload); err != nil { + ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error()) + continue + } + laddr = &net.UnixAddr{ + Name: payload.SocketPath, + Net: "unix", + } + raddr = &net.UnixAddr{ + Name: "@", + Net: "unix", + } + default: + panic(fmt.Errorf("ssh: unknown channel type %s", channelType)) + } + if ok := l.forward(laddr, raddr, ch); !ok { + // Section 7.2, implementations MUST reject spurious incoming + // connections. + ch.Reject(Prohibited, "no forward for address") + continue + } + + } +} + +// remove removes the forward entry, and the channel feeding its +// listener. +func (l *forwardList) remove(addr net.Addr) { + l.Lock() + defer l.Unlock() + for i, f := range l.entries { + if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() { + l.entries = append(l.entries[:i], l.entries[i+1:]...) + close(f.c) + return + } + } +} + +// closeAll closes and clears all forwards. +func (l *forwardList) closeAll() { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + close(f.c) + } + l.entries = nil +} + +func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { + l.Lock() + defer l.Unlock() + for _, f := range l.entries { + if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() { + f.c <- forward{newCh: ch, raddr: raddr} + return true + } + } + return false +} + +type tcpListener struct { + laddr *net.TCPAddr + + conn *Client + in <-chan forward +} + +// Accept waits for and returns the next connection to the listener. +func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// DialContext initiates a connection to the addr from the remote host. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See func Dial for additional information. +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + var ch Channel + switch n { + case "tcp", "tcp4", "tcp6": + // Parse the address into host and numeric port. + host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + return &chanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil + case "unix": + var err error + ch, err = c.dialStreamLocal(addr) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: addr, + Net: "unix", + }, + }, nil + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. +func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// chanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type chanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *chanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *chanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *chanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. +func (t *chanConn) SetReadDeadline(deadline time.Time) error { + // for compatibility with previous version, + // the error message contains "tcpChan" + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *chanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/tempfork/sshtest/ssh/tcpip_test.go b/tempfork/sshtest/ssh/tcpip_test.go new file mode 100644 index 000000000..4d8511472 --- /dev/null +++ b/tempfork/sshtest/ssh/tcpip_test.go @@ -0,0 +1,53 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "context" + "net" + "testing" + "time" +) + +func TestAutoPortListenBroken(t *testing.T) { + broken := "SSH-2.0-OpenSSH_5.9hh11" + works := "SSH-2.0-OpenSSH_6.1" + if !isBrokenOpenSSHVersion(broken) { + t.Errorf("version %q not marked as broken", broken) + } + if isBrokenOpenSSHVersion(works) { + t.Errorf("version %q marked as broken", works) + } +} + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + // Belt and suspenders assertion, since package net does not + // declare a ContextDialer type. + var _ ContextDialer = &net.Dialer{} + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} diff --git a/tempfork/sshtest/ssh/testdata_test.go b/tempfork/sshtest/ssh/testdata_test.go new file mode 100644 index 000000000..2da8c79dc --- /dev/null +++ b/tempfork/sshtest/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. + +package ssh + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/tempfork/sshtest/ssh/transport.go b/tempfork/sshtest/ssh/transport.go new file mode 100644 index 000000000..0424d2d37 --- /dev/null +++ b/tempfork/sshtest/ssh/transport.go @@ -0,0 +1,380 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bufio" + "bytes" + "errors" + "io" + "log" +) + +// debugTransport if set, will print packet types as they go over the +// wire. No message decoding is done, to minimize the impact on timing. +const debugTransport = false + +const ( + gcm128CipherID = "aes128-gcm@openssh.com" + gcm256CipherID = "aes256-gcm@openssh.com" + aes128cbcID = "aes128-cbc" + tripledescbcID = "3des-cbc" +) + +// packetConn represents a transport that implements packet based +// operations. +type packetConn interface { + // Encrypt and send a packet of data to the remote peer. + writePacket(packet []byte) error + + // Read a packet from the connection. The read is blocking, + // i.e. if error is nil, then the returned byte slice is + // always non-empty. + readPacket() ([]byte, error) + + // Close closes the write-side of the connection. + Close() error +} + +// transport is the keyingTransport that implements the SSH packet +// protocol. +type transport struct { + reader connectionState + writer connectionState + + bufReader *bufio.Reader + bufWriter *bufio.Writer + rand io.Reader + isClient bool + io.Closer + + strictMode bool + initialKEXDone bool +} + +// packetCipher represents a combination of SSH encryption/MAC +// protocol. A single instance should be used for one direction only. +type packetCipher interface { + // writeCipherPacket encrypts the packet and writes it to w. The + // contents of the packet are generally scrambled. + writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error + + // readCipherPacket reads and decrypts a packet of data. The + // returned packet may be overwritten by future calls of + // readPacket. + readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error) +} + +// connectionState represents one side (read or write) of the +// connection. This is necessary because each direction has its own +// keys, and can even have its own algorithms +type connectionState struct { + packetCipher + seqNum uint32 + dir direction + pendingKeyChange chan packetCipher +} + +func (t *transport) setStrictMode() error { + if t.reader.seqNum != 1 { + return errors.New("ssh: sequence number != 1 when strict KEX mode requested") + } + t.strictMode = true + return nil +} + +func (t *transport) setInitialKEXDone() { + t.initialKEXDone = true +} + +// prepareKeyChange sets up key material for a keychange. The key changes in +// both directions are triggered by reading and writing a msgNewKey packet +// respectively. +func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { + ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult) + if err != nil { + return err + } + t.reader.pendingKeyChange <- ciph + + ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult) + if err != nil { + return err + } + t.writer.pendingKeyChange <- ciph + + return nil +} + +func (t *transport) printPacket(p []byte, write bool) { + if len(p) == 0 { + return + } + who := "server" + if t.isClient { + who = "client" + } + what := "read" + if write { + what = "write" + } + + log.Println(what, who, p[0]) +} + +// Read and decrypt next packet. +func (t *transport) readPacket() (p []byte, err error) { + for { + p, err = t.reader.readPacket(t.bufReader, t.strictMode) + if err != nil { + break + } + // in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX + if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) { + break + } + } + if debugTransport { + t.printPacket(p, false) + } + + return p, err +} + +func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) { + packet, err := s.packetCipher.readCipherPacket(s.seqNum, r) + s.seqNum++ + if err == nil && len(packet) == 0 { + err = errors.New("ssh: zero length packet") + } + + if len(packet) > 0 { + switch packet[0] { + case msgNewKeys: + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + if strictMode { + s.seqNum = 0 + } + default: + return nil, errors.New("ssh: got bogus newkeys message") + } + + case msgDisconnect: + // Transform a disconnect message into an + // error. Since this is lowest level at which + // we interpret message types, doing it here + // ensures that we don't have to handle it + // elsewhere. + var msg disconnectMsg + if err := Unmarshal(packet, &msg); err != nil { + return nil, err + } + return nil, &msg + } + } + + // The packet may point to an internal buffer, so copy the + // packet out here. + fresh := make([]byte, len(packet)) + copy(fresh, packet) + + return fresh, err +} + +func (t *transport) writePacket(packet []byte) error { + if debugTransport { + t.printPacket(packet, true) + } + return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode) +} + +func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error { + changeKeys := len(packet) > 0 && packet[0] == msgNewKeys + + err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet) + if err != nil { + return err + } + if err = w.Flush(); err != nil { + return err + } + s.seqNum++ + if changeKeys { + select { + case cipher := <-s.pendingKeyChange: + s.packetCipher = cipher + if strictMode { + s.seqNum = 0 + } + default: + panic("ssh: no key material for msgNewKeys") + } + } + return err +} + +func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { + t := &transport{ + bufReader: bufio.NewReader(rwc), + bufWriter: bufio.NewWriter(rwc), + rand: rand, + reader: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + writer: connectionState{ + packetCipher: &streamPacketCipher{cipher: noneCipher{}}, + pendingKeyChange: make(chan packetCipher, 1), + }, + Closer: rwc, + } + t.isClient = isClient + + if isClient { + t.reader.dir = serverKeys + t.writer.dir = clientKeys + } else { + t.reader.dir = clientKeys + t.writer.dir = serverKeys + } + + return t +} + +type direction struct { + ivTag []byte + keyTag []byte + macKeyTag []byte +} + +var ( + serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} + clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} +) + +// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as +// described in RFC 4253, section 6.4. direction should either be serverKeys +// (to setup server->client keys) or clientKeys (for client->server keys). +func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + cipherMode := cipherModes[algs.Cipher] + + iv := make([]byte, cipherMode.ivSize) + key := make([]byte, cipherMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + + var macKey []byte + if !aeadCiphers[algs.Cipher] { + macMode := macModes[algs.MAC] + macKey = make([]byte, macMode.keySize) + generateKeyMaterial(macKey, d.macKeyTag, kex) + } + + return cipherModes[algs.Cipher].create(key, iv, macKey, algs) +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. + if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for length := 0; length < maxVersionStringBytes; length++ { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + if !bytes.HasPrefix(versionString, []byte("SSH-")) { + // RFC 4253 says we need to ignore all version string lines + // except the one containing the SSH version (provided that + // all the lines do not exceed 255 bytes in total). + versionString = versionString[:0] + continue + } + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. + if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} diff --git a/tempfork/sshtest/ssh/transport_test.go b/tempfork/sshtest/ssh/transport_test.go new file mode 100644 index 000000000..8445e1e56 --- /dev/null +++ b/tempfork/sshtest/ssh/transport_test.go @@ -0,0 +1,113 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "strings" + "testing" +) + +func TestReadVersion(t *testing.T) { + longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253] + multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n" + cases := map[string]string{ + "SSH-2.0-bla\r\n": "SSH-2.0-bla", + "SSH-2.0-bla\n": "SSH-2.0-bla", + multiLineVersion: "SSH-2.0-bla", + longVersion + "\r\n": longVersion, + } + + for in, want := range cases { + result, err := readVersion(bytes.NewBufferString(in)) + if err != nil { + t.Errorf("readVersion(%q): %s", in, err) + } + got := string(result) + if got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func TestReadVersionError(t *testing.T) { + longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253] + multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n" + cases := []string{ + longVersion + "too-long\r\n", + multiLineVersion, + } + for _, in := range cases { + if _, err := readVersion(bytes.NewBufferString(in)); err == nil { + t.Errorf("readVersion(%q) should have failed", in) + } + } +} + +func TestExchangeVersionsBasic(t *testing.T) { + v := "SSH-2.0-bla" + buf := bytes.NewBufferString(v + "\r\n") + them, err := exchangeVersions(buf, []byte("xyz")) + if err != nil { + t.Errorf("exchangeVersions: %v", err) + } + + if want := "SSH-2.0-bla"; string(them) != want { + t.Errorf("got %q want %q for our version", them, want) + } +} + +func TestExchangeVersions(t *testing.T) { + cases := []string{ + "not\x000allowed", + "not allowed\x01\r\n", + } + for _, c := range cases { + buf := bytes.NewBufferString("SSH-2.0-bla\r\n") + if _, err := exchangeVersions(buf, []byte(c)); err == nil { + t.Errorf("exchangeVersions(%q): should have failed", c) + } + } +} + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +} diff --git a/tka/aum.go b/tka/aum.go index 07a34b4f6..b8c4b6c9e 100644 --- a/tka/aum.go +++ b/tka/aum.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( @@ -29,8 +31,8 @@ func (h AUMHash) String() string { // UnmarshalText implements encoding.TextUnmarshaler. func (h *AUMHash) UnmarshalText(text []byte) error { - if l := base32StdNoPad.DecodedLen(len(text)); l != len(h) { - return fmt.Errorf("tka.AUMHash.UnmarshalText: text wrong length: %d, want %d", l, len(text)) + if ln := base32StdNoPad.DecodedLen(len(text)); ln != len(h) { + return fmt.Errorf("tka.AUMHash.UnmarshalText: text wrong length: %d, want %d", ln, len(text)) } if _, err := base32StdNoPad.Decode(h[:], text); err != nil { return fmt.Errorf("tka.AUMHash.UnmarshalText: %w", err) @@ -53,6 +55,17 @@ func (h AUMHash) IsZero() bool { return h == (AUMHash{}) } +// PrevAUMHash represents the BLAKE2s digest of an Authority Update Message (AUM). +// Unlike an AUMHash, this can be empty if there is no previous AUM hash +// (which occurs in the genesis AUM). +type PrevAUMHash []byte + +// String returns the PrevAUMHash encoded as base32. +// This is suitable for use as a filename, and for storing in text-preferred media. +func (h PrevAUMHash) String() string { + return base32StdNoPad.EncodeToString(h[:]) +} + // AUMKind describes valid AUM types. type AUMKind uint8 @@ -117,8 +130,8 @@ func (k AUMKind) String() string { // behavior of old clients (which will ignore the field). // - No floats! type AUM struct { - MessageKind AUMKind `cbor:"1,keyasint"` - PrevAUMHash []byte `cbor:"2,keyasint"` + MessageKind AUMKind `cbor:"1,keyasint"` + PrevAUMHash PrevAUMHash `cbor:"2,keyasint"` // Key encodes a public key to be added to the key authority. // This field is used for AddKey AUMs. @@ -224,7 +237,7 @@ func (a *AUM) Serialize() tkatype.MarshaledAUM { // Further, experience with other attempts (JWS/JWT,SAML,X509 etc) has // taught us that even subtle behaviors such as how you handle invalid // or unrecognized fields + any invariants in subsequent re-serialization - // can easily lead to security-relevant logic bugs. Its certainly possible + // can easily lead to security-relevant logic bugs. It's certainly possible // to invent a workable scheme by massaging a JSON parsing library, though // profoundly unwise. // diff --git a/tka/aum_test.go b/tka/aum_test.go index 4297efabf..833a02654 100644 --- a/tka/aum_test.go +++ b/tka/aum_test.go @@ -5,6 +5,8 @@ package tka import ( "bytes" + "encoding/base64" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -156,6 +158,80 @@ func TestSerialization(t *testing.T) { } } +func fromBase64(s string) []byte { + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + panic(fmt.Sprintf("base64 decode failed: %v", err)) + } + return data +} + +// This test verifies that we can read AUMs which were serialized with +// older versions of our code. +func TestDeserializeExistingAUMs(t *testing.T) { + for _, tt := range []struct { + Name string + Data []byte + Want AUM + }{ + { + // This is an AUM which was created in a test tailnet, and encoded + // on 2025-11-07 with commit d4c5b27. + Name: "genesis-aum-2025-11-07", + Data: fromBase64("pAEFAvYFpQH2AopYII0sLaLSEZU3W5DT1dG2WYnzjCBr4tXtVbCT2LvA9LS6WCAQhwVGDiUGRiu3P63gucZ/8otjt2DXyk+OBjbh5iWx1Fgg5VU4oRQiMoq5qK00McfpwtmjcheVammLCRwzdp2Zje9YIHDoOXe4ogPSy7lfA/veyPCKM6iZe3PTgzhQZ4W5Sh7wWCBYQtiQ6NcRlyVARJxgAj1BbbvdJQ0t4m+vHqU1J02oDlgg2sksJA+COfsBkrohwHBWlbKrpS8Mvigpl+enuHw9rIJYIB/+CUBBBLUz0KeHu7NKrg5ZEhjjPUWhNcf9QTNHjuNWWCCJuxqPZ6/IASPTmAERaoKnBNH/D+zY4p4TUGHR4fACjFggMtDAipPutgcxKnU9Tg2663gP3KlTQfztV3hBwiePZdRYIGYeD2erBkRouSL20lOnWHHlRq5kmNfN6xFb2CTaPjnXA4KjAQECAQNYIADftG3yaitV/YMoKSBP45zgyeodClumN9ZaeQg/DmCEowEBAgEDWCBRKbmWSzOyHXbHJuYn8s7dmMPDzxmIjgBoA80cBYgItAQbEWOrxfqJzIkFG/5uNUp0s/ScF4GiAVggAN+0bfJqK1X9gygpIE/jnODJ6h0KW6Y31lp5CD8OYIQCWEAENvzblKV2qx6PED5YdGy8kWa7nxEnaeuMmS5Wkx0n7CXs0XxD5f2NIE+pSv9cOsNkfYNndQkYD7ne33hQOsQM"), + Want: AUM{ + MessageKind: AUMCheckpoint, + State: &State{ + DisablementSecrets: [][]byte{ + fromBase64("jSwtotIRlTdbkNPV0bZZifOMIGvi1e1VsJPYu8D0tLo="), + fromBase64("EIcFRg4lBkYrtz+t4LnGf/KLY7dg18pPjgY24eYlsdQ="), + fromBase64("5VU4oRQiMoq5qK00McfpwtmjcheVammLCRwzdp2Zje8="), + fromBase64("cOg5d7iiA9LLuV8D+97I8IozqJl7c9ODOFBnhblKHvA="), + fromBase64("WELYkOjXEZclQEScYAI9QW273SUNLeJvrx6lNSdNqA4="), + fromBase64("2sksJA+COfsBkrohwHBWlbKrpS8Mvigpl+enuHw9rII="), + fromBase64("H/4JQEEEtTPQp4e7s0quDlkSGOM9RaE1x/1BM0eO41Y="), + fromBase64("ibsaj2evyAEj05gBEWqCpwTR/w/s2OKeE1Bh0eHwAow="), + fromBase64("MtDAipPutgcxKnU9Tg2663gP3KlTQfztV3hBwiePZdQ="), + fromBase64("Zh4PZ6sGRGi5IvbSU6dYceVGrmSY183rEVvYJNo+Odc="), + }, + Keys: []Key{ + { + Kind: Key25519, + Votes: 1, + Public: fromBase64("AN+0bfJqK1X9gygpIE/jnODJ6h0KW6Y31lp5CD8OYIQ="), + }, + { + Kind: Key25519, + Votes: 1, + Public: fromBase64("USm5lkszsh12xybmJ/LO3ZjDw88ZiI4AaAPNHAWICLQ="), + }, + }, + StateID1: 1253033988139371657, + StateID2: 18333649726973670556, + }, + Signatures: []tkatype.Signature{ + { + KeyID: fromBase64("AN+0bfJqK1X9gygpIE/jnODJ6h0KW6Y31lp5CD8OYIQ="), + Signature: fromBase64("BDb825SldqsejxA+WHRsvJFmu58RJ2nrjJkuVpMdJ+wl7NF8Q+X9jSBPqUr/XDrDZH2DZ3UJGA+53t94UDrEDA=="), + }, + }, + }, + }, + } { + t.Run(tt.Name, func(t *testing.T) { + var got AUM + + if err := got.Unserialize(tt.Data); err != nil { + t.Fatalf("Unserialize: %v", err) + } + + if diff := cmp.Diff(got, tt.Want); diff != "" { + t.Fatalf("wrong AUM (-got, +want):\n%s", diff) + } + }) + } +} + func TestAUMWeight(t *testing.T) { var fakeKeyID [blake2s.Size]byte testingRand(t, 1).Read(fakeKeyID[:]) diff --git a/tka/builder.go b/tka/builder.go index c14ba2330..ab2364d85 100644 --- a/tka/builder.go +++ b/tka/builder.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( @@ -67,6 +69,11 @@ func (b *UpdateBuilder) AddKey(key Key) error { if _, err := b.state.GetKey(keyID); err == nil { return fmt.Errorf("cannot add key %v: already exists", key) } + + if len(b.state.Keys) >= maxKeys { + return fmt.Errorf("cannot add key %v: maximum number of keys reached", key) + } + return b.mkUpdate(AUM{MessageKind: AUMAddKey, Key: &key}) } @@ -107,7 +114,7 @@ func (b *UpdateBuilder) generateCheckpoint() error { } } - // Checkpoints cant specify a parent AUM. + // Checkpoints can't specify a parent AUM. state.LastAUMHash = nil return b.mkUpdate(AUM{MessageKind: AUMCheckpoint, State: &state}) } @@ -129,7 +136,7 @@ func (b *UpdateBuilder) Finalize(storage Chonk) ([]AUM, error) { needCheckpoint = false break } - return nil, fmt.Errorf("reading AUM: %v", err) + return nil, fmt.Errorf("reading AUM (%v): %v", cursor, err) } if aum.MessageKind == AUMCheckpoint { diff --git a/tka/builder_test.go b/tka/builder_test.go index 666af9ad0..3fd32f64e 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -5,6 +5,7 @@ package tka import ( "crypto/ed25519" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -27,7 +28,7 @@ func TestAuthorityBuilderAddKey(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -57,6 +58,50 @@ func TestAuthorityBuilderAddKey(t *testing.T) { t.Errorf("could not read new key: %v", err) } } +func TestAuthorityBuilderMaxKey(t *testing.T) { + pub, priv := testingKey25519(t, 1) + key := Key{Kind: Key25519, Public: pub, Votes: 2} + + storage := ChonkMem() + a, _, err := Create(storage, State{ + Keys: []Key{key}, + DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, + }, signer25519(priv)) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + for i := 0; i <= maxKeys; i++ { + pub2, _ := testingKey25519(t, int64(2+i)) + key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + + b := a.NewUpdater(signer25519(priv)) + err := b.AddKey(key2) + if i < maxKeys-1 { + if err != nil { + t.Fatalf("AddKey(%v) failed: %v", key2, err) + } + } else { + // Too many keys. + if err == nil { + t.Fatalf("AddKey(%v) succeeded unexpectedly", key2) + } + continue + } + + updates, err := b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + + if err := a.Inform(storage, updates); err != nil { + t.Fatalf("could not apply generated updates: %v", err) + } + if _, err := a.state.GetKey(key2.MustID()); err != nil { + t.Errorf("could not read new key: %v", err) + } + } +} func TestAuthorityBuilderRemoveKey(t *testing.T) { pub, priv := testingKey25519(t, 1) @@ -64,7 +109,7 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) { pub2, _ := testingKey25519(t, 2) key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key, key2}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -90,13 +135,27 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) { if _, err := a.state.GetKey(key2.MustID()); err != ErrNoSuchKey { t.Errorf("GetKey(key2).err = %v, want %v", err, ErrNoSuchKey) } + + // Check that removing the remaining key errors out. + b = a.NewUpdater(signer25519(priv)) + if err := b.RemoveKey(key.MustID()); err != nil { + t.Fatalf("RemoveKey(%v) failed: %v", key, err) + } + updates, err = b.Finalize(storage) + if err != nil { + t.Fatalf("Finalize() failed: %v", err) + } + wantErr := "cannot remove the last key" + if err := a.Inform(storage, updates); err == nil || !strings.Contains(err.Error(), wantErr) { + t.Fatalf("expected Inform() to return error %q, got: %v", wantErr, err) + } } func TestAuthorityBuilderSetKeyVote(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -132,7 +191,7 @@ func TestAuthorityBuilderSetKeyMeta(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -168,7 +227,7 @@ func TestAuthorityBuilderMultiple(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, @@ -216,7 +275,7 @@ func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - storage := &Mem{} + storage := ChonkMem() a, _, err := Create(storage, State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, diff --git a/tka/chaintest_test.go b/tka/chaintest_test.go index 5811f9c83..a3122b5d1 100644 --- a/tka/chaintest_test.go +++ b/tka/chaintest_test.go @@ -285,25 +285,25 @@ func (c *testChain) makeAUM(v *testchainNode) AUM { // Chonk returns a tailchonk containing all AUMs. func (c *testChain) Chonk() Chonk { - var out Mem + out := ChonkMem() for _, update := range c.AUMs { if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil { panic(err) } } - return &out + return out } // ChonkWith returns a tailchonk containing the named AUMs. func (c *testChain) ChonkWith(names ...string) Chonk { - var out Mem + out := ChonkMem() for _, name := range names { update := c.AUMs[name] if err := out.CommitVerifiedAUMs([]AUM{update}); err != nil { panic(err) } } - return &out + return out } type testchainOpt struct { diff --git a/tka/deeplink.go b/tka/deeplink.go index 5cf24fc5c..5570a19d7 100644 --- a/tka/deeplink.go +++ b/tka/deeplink.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( diff --git a/tka/disabled_stub.go b/tka/disabled_stub.go new file mode 100644 index 000000000..4c4afa370 --- /dev/null +++ b/tka/disabled_stub.go @@ -0,0 +1,160 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_tailnetlock + +package tka + +import ( + "crypto/ed25519" + "errors" + + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/tkatype" +) + +type Authority struct { + head AUM + oldestAncestor AUM + state State +} + +func (*Authority) Head() AUMHash { return AUMHash{} } + +// MarshalText returns a dummy value explaining that Tailnet Lock +// is not compiled in to this binary. +// +// We need to be able to marshal AUMHash to text because it's included +// in [netmap.NetworkMap], which gets serialised as JSON in the +// c2n /debug/netmap endpoint. +// +// We provide a basic marshaller so that endpoint works correctly +// with nodes that omit Tailnet Lock support, but we don't want the +// base32 dependency used for the regular marshaller, and we don't +// need unmarshalling support at time of writing (2025-11-18). +func (h AUMHash) MarshalText() ([]byte, error) { + return []byte(""), nil +} + +func (h *AUMHash) UnmarshalText(text []byte) error { + return errors.New("tailnet lock is not supported by this binary") +} + +type State struct{} + +// AUMKind describes valid AUM types. +type AUMKind uint8 + +type AUMHash [32]byte + +type AUM struct { + MessageKind AUMKind `cbor:"1,keyasint"` + PrevAUMHash []byte `cbor:"2,keyasint"` + + // Key encodes a public key to be added to the key authority. + // This field is used for AddKey AUMs. + Key *Key `cbor:"3,keyasint,omitempty"` + + // KeyID references a public key which is part of the key authority. + // This field is used for RemoveKey and UpdateKey AUMs. + KeyID tkatype.KeyID `cbor:"4,keyasint,omitempty"` + + // State describes the full state of the key authority. + // This field is used for Checkpoint AUMs. + State *State `cbor:"5,keyasint,omitempty"` + + // Votes and Meta describe properties of a key in the key authority. + // These fields are used for UpdateKey AUMs. + Votes *uint `cbor:"6,keyasint,omitempty"` + Meta map[string]string `cbor:"7,keyasint,omitempty"` + + // Signatures lists the signatures over this AUM. + // CBOR key 23 is the last key which can be encoded as a single byte. + Signatures []tkatype.Signature `cbor:"23,keyasint,omitempty"` +} + +type Chonk interface { + // AUM returns the AUM with the specified digest. + // + // If the AUM does not exist, then os.ErrNotExist is returned. + AUM(hash AUMHash) (AUM, error) + + // ChildAUMs returns all AUMs with a specified previous + // AUM hash. + ChildAUMs(prevAUMHash AUMHash) ([]AUM, error) + + // CommitVerifiedAUMs durably stores the provided AUMs. + // Callers MUST ONLY provide AUMs which are verified (specifically, + // a call to aumVerify() must return a nil error). + // as the implementation assumes that only verified AUMs are stored. + CommitVerifiedAUMs(updates []AUM) error + + // Heads returns AUMs for which there are no children. In other + // words, the latest AUM in all possible chains (the 'leaves'). + Heads() ([]AUM, error) + + // SetLastActiveAncestor is called to record the oldest-known AUM + // that contributed to the current state. This value is used as + // a hint on next startup to determine which chain to pick when computing + // the current state, if there are multiple distinct chains. + SetLastActiveAncestor(hash AUMHash) error + + // LastActiveAncestor returns the oldest-known AUM that was (in a + // previous run) an ancestor of the current state. This is used + // as a hint to pick the correct chain in the event that the Chonk stores + // multiple distinct chains. + LastActiveAncestor() (*AUMHash, error) +} + +// SigKind describes valid NodeKeySignature types. +type SigKind uint8 + +type NodeKeySignature struct { + // SigKind identifies the variety of signature. + SigKind SigKind `cbor:"1,keyasint"` + // Pubkey identifies the key.NodePublic which is being authorized. + // SigCredential signatures do not use this field. + Pubkey []byte `cbor:"2,keyasint,omitempty"` + + // KeyID identifies which key in the tailnet key authority should + // be used to verify this signature. Only set for SigDirect and + // SigCredential signature kinds. + KeyID []byte `cbor:"3,keyasint,omitempty"` + + // Signature is the packed (R, S) ed25519 signature over all other + // fields of the structure. + Signature []byte `cbor:"4,keyasint,omitempty"` + + // Nested describes a NodeKeySignature which authorizes the node-key + // used as Pubkey. Only used for SigRotation signatures. + Nested *NodeKeySignature `cbor:"5,keyasint,omitempty"` + + // WrappingPubkey specifies the ed25519 public key which must be used + // to sign a Signature which embeds this one. + // + // For SigRotation signatures multiple levels deep, intermediate + // signatures may omit this value, in which case the parent WrappingPubkey + // is used. + // + // SigCredential signatures use this field to specify the public key + // they are certifying, following the usual semanticsfor WrappingPubkey. + WrappingPubkey []byte `cbor:"6,keyasint,omitempty"` +} + +type DeeplinkValidationResult struct { +} + +func DecodeWrappedAuthkey(wrappedAuthKey string, logf logger.Logf) (authKey string, isWrapped bool, sig *NodeKeySignature, priv ed25519.PrivateKey) { + return wrappedAuthKey, false, nil, nil +} + +func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.MarshaledSignature) (tkatype.MarshaledSignature, error) { + return nil, nil +} + +func SignByCredential(privKey []byte, wrapped *NodeKeySignature, nodeKey key.NodePublic) (tkatype.MarshaledSignature, error) { + return nil, nil +} + +func (s NodeKeySignature) String() string { return "" } diff --git a/tka/key.go b/tka/key.go index 07736795d..dca1b4416 100644 --- a/tka/key.go +++ b/tka/key.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" - "github.com/hdevalence/ed25519consensus" "tailscale.com/types/tkatype" ) @@ -136,24 +135,3 @@ func (k Key) StaticValidate() error { } return nil } - -// Verify returns a nil error if the signature is valid over the -// provided AUM BLAKE2s digest, using the given key. -func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { - // NOTE(tom): Even if we can compute the public from the KeyID, - // its possible for the KeyID to be attacker-controlled - // so we should use the public contained in the state machine. - switch key.Kind { - case Key25519: - if len(key.Public) != ed25519.PublicKeySize { - return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) - } - if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { - return nil - } - return errors.New("invalid signature") - - default: - return fmt.Errorf("unhandled key type: %v", key.Kind) - } -} diff --git a/tka/key_test.go b/tka/key_test.go index e912f89c4..327de1a0e 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -42,7 +42,7 @@ func TestVerify25519(t *testing.T) { aum := AUM{ MessageKind: AUMRemoveKey, KeyID: []byte{1, 2, 3, 4}, - // Signatures is set to crap so we are sure its ignored in the sigHash computation. + // Signatures is set to crap so we are sure it's ignored in the sigHash computation. Signatures: []tkatype.Signature{{KeyID: []byte{45, 42}}}, } sigHash := aum.SigHash() @@ -72,7 +72,7 @@ func TestNLPrivate(t *testing.T) { // Test that key.NLPrivate implements Signer by making a new // authority. k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} - _, aum, err := Create(&Mem{}, State{ + _, aum, err := Create(ChonkMem(), State{ Keys: []Key{k}, DisablementSecrets: [][]byte{bytes.Repeat([]byte{1}, 32)}, }, p) @@ -89,7 +89,7 @@ func TestNLPrivate(t *testing.T) { t.Error("signature did not verify") } - // We manually compute the keyID, so make sure its consistent with + // We manually compute the keyID, so make sure it's consistent with // tka.Key.ID(). if !bytes.Equal(k.MustID(), p.KeyID()) { t.Errorf("private.KeyID() & tka KeyID differ: %x != %x", k.MustID(), p.KeyID()) diff --git a/tka/scenario_test.go b/tka/scenario_test.go index 89a8111e1..a0361a130 100644 --- a/tka/scenario_test.go +++ b/tka/scenario_test.go @@ -204,7 +204,7 @@ func TestNormalPropagation(t *testing.T) { `) control := s.mkNode("control") - // Lets say theres a node with some updates! + // Let's say there's a node with some updates! n1 := s.mkNodeWithForks("n1", true, map[string]*testChain{ "L2": newTestchain(t, `L3 -> L4`), }) diff --git a/tka/sig.go b/tka/sig.go index c82f9715c..46d598ad9 100644 --- a/tka/sig.go +++ b/tka/sig.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( @@ -275,7 +277,7 @@ func (s *NodeKeySignature) verifySignature(nodeKey key.NodePublic, verificationK // Recurse to verify the signature on the nested structure. var nestedPub key.NodePublic // SigCredential signatures certify an indirection key rather than a node - // key, so theres no need to check the node key. + // key, so there's no need to check the node key. if s.Nested.SigKind != SigCredential { if err := nestedPub.UnmarshalBinary(s.Nested.Pubkey); err != nil { return fmt.Errorf("nested pubkey: %v", err) diff --git a/tka/sig_test.go b/tka/sig_test.go index d64575e7c..c5c03ef2e 100644 --- a/tka/sig_test.go +++ b/tka/sig_test.go @@ -76,8 +76,8 @@ func TestSigNested(t *testing.T) { if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil { t.Fatalf("verifySignature(oldNode) failed: %v", err) } - if l := sigChainLength(nestedSig); l != 1 { - t.Errorf("nestedSig chain length = %v, want 1", l) + if ln := sigChainLength(nestedSig); ln != 1 { + t.Errorf("nestedSig chain length = %v, want 1", ln) } // The signature authorizing the rotation, signed by the @@ -93,8 +93,8 @@ func TestSigNested(t *testing.T) { if err := sig.verifySignature(node.Public(), k); err != nil { t.Fatalf("verifySignature(node) failed: %v", err) } - if l := sigChainLength(sig); l != 2 { - t.Errorf("sig chain length = %v, want 2", l) + if ln := sigChainLength(sig); ln != 2 { + t.Errorf("sig chain length = %v, want 2", ln) } // Test verification fails if the wrong verification key is provided @@ -119,7 +119,7 @@ func TestSigNested(t *testing.T) { } // Test verification fails if the outer signature is signed with a - // different public key to whats specified in WrappingPubkey + // different public key to what's specified in WrappingPubkey sig.Signature = ed25519.Sign(priv, sigHash[:]) if err := sig.verifySignature(node.Public(), k); err == nil { t.Error("verifySignature(node) succeeded with different signature") @@ -275,7 +275,7 @@ func TestSigCredential(t *testing.T) { } // Test verification fails if the outer signature is signed with a - // different public key to whats specified in WrappingPubkey + // different public key to what's specified in WrappingPubkey sig.Signature = ed25519.Sign(priv, sigHash[:]) if err := sig.verifySignature(node.Public(), k); err == nil { t.Error("verifySignature(node) succeeded with different signature") @@ -507,7 +507,7 @@ func TestDecodeWrappedAuthkey(t *testing.T) { } func TestResignNKS(t *testing.T) { - // Tailnet lock keypair of a signing node. + // Tailnet Lock keypair of a signing node. authPub, authPriv := testingKey25519(t, 1) authKey := Key{Kind: Key25519, Public: authPub, Votes: 2} diff --git a/tka/state.go b/tka/state.go index 0a459bd9a..95a319bd9 100644 --- a/tka/state.go +++ b/tka/state.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( @@ -138,7 +140,7 @@ func (s State) checkDisablement(secret []byte) bool { // Specifically, the rules are: // - The last AUM hash must match (transitively, this implies that this // update follows the last update message applied to the state machine) -// - Or, the state machine knows no parent (its brand new). +// - Or, the state machine knows no parent (it's brand new). func (s State) parentMatches(update AUM) bool { if s.LastAUMHash == nil { return true diff --git a/tka/state_test.go b/tka/state_test.go index 060bd9350..32b656314 100644 --- a/tka/state_test.go +++ b/tka/state_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( diff --git a/tka/sync.go b/tka/sync.go index 6131f54d0..e3a858c15 100644 --- a/tka/sync.go +++ b/tka/sync.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( @@ -52,7 +54,7 @@ const ( // can then be applied locally with Inform(). // // This SyncOffer + AUM exchange should be performed by both ends, -// because its possible that either end has AUMs that the other needs +// because it's possible that either end has AUMs that the other needs // to find out about. func (a *Authority) SyncOffer(storage Chonk) (SyncOffer, error) { oldest := a.oldestAncestor.Hash() @@ -121,7 +123,7 @@ func computeSyncIntersection(storage Chonk, localOffer, remoteOffer SyncOffer) ( } // Case: 'head intersection' - // If we have the remote's head, its more likely than not that + // If we have the remote's head, it's more likely than not that // we have updates that build on that head. To confirm this, // we iterate backwards through our chain to see if the given // head is an ancestor of our current chain. @@ -163,7 +165,7 @@ func computeSyncIntersection(storage Chonk, localOffer, remoteOffer SyncOffer) ( // Case: 'tail intersection' // So we don't have a clue what the remote's head is, but // if one of the ancestors they gave us is part of our chain, - // then theres an intersection, which is a starting point for + // then there's an intersection, which is a starting point for // the remote to send us AUMs from. // // We iterate the list of ancestors in order because the remote diff --git a/tka/sync_test.go b/tka/sync_test.go index 7250eacf7..ea14a37e5 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -346,7 +346,7 @@ func TestSyncSimpleE2E(t *testing.T) { optKey("key", key, priv), optSignAllUsing("key")) - nodeStorage := &Mem{} + nodeStorage := ChonkMem() node, err := Bootstrap(nodeStorage, c.AUMs["G1"]) if err != nil { t.Fatalf("node Bootstrap() failed: %v", err) @@ -357,7 +357,7 @@ func TestSyncSimpleE2E(t *testing.T) { t.Fatalf("control Open() failed: %v", err) } - // Control knows the full chain, node only knows the genesis. Lets see + // Control knows the full chain, node only knows the genesis. Let's see // if they can sync. nodeOffer, err := node.SyncOffer(nodeStorage) if err != nil { diff --git a/tka/tailchonk.go b/tka/tailchonk.go index 32d2215de..a55033bcd 100644 --- a/tka/tailchonk.go +++ b/tka/tailchonk.go @@ -1,19 +1,26 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package tka import ( "bytes" "errors" "fmt" + "log" + "maps" "os" "path/filepath" + "slices" "sync" "time" "github.com/fxamacker/cbor/v2" "tailscale.com/atomicfile" + "tailscale.com/tstime" + "tailscale.com/util/testenv" ) // Chonk implementations provide durable storage for AUMs and other @@ -71,38 +78,69 @@ type CompactableChonk interface { // PurgeAUMs permanently and irrevocably deletes the specified // AUMs from storage. PurgeAUMs(hashes []AUMHash) error + + // RemoveAll permanently and completely clears the TKA state. This should + // be called when the user disables Tailnet Lock. + RemoveAll() error } // Mem implements in-memory storage of TKA state, suitable for -// tests. +// tests or cases where filesystem storage is unavailable. // // Mem implements the Chonk interface. +// +// Mem is thread-safe. type Mem struct { - l sync.RWMutex + mu sync.RWMutex aums map[AUMHash]AUM + commitTimes map[AUMHash]time.Time + clock tstime.Clock + + // parentIndex is a map of AUMs to the AUMs for which they are + // the parent. + // + // For example, if parent index is {1 -> {2, 3, 4}}, that means + // that AUMs 2, 3, 4 all have aum.PrevAUMHash = 1. parentIndex map[AUMHash][]AUMHash lastActiveAncestor *AUMHash } +// ChonkMem returns an implementation of Chonk which stores TKA state +// in-memory. +func ChonkMem() *Mem { + return &Mem{ + clock: tstime.DefaultClock{}, + } +} + +// SetClock sets the clock used by [Mem]. This is only for use in tests, +// and will panic if called from non-test code. +func (c *Mem) SetClock(clock tstime.Clock) { + if !testenv.InTest() { + panic("used SetClock in non-test code") + } + c.clock = clock +} + func (c *Mem) SetLastActiveAncestor(hash AUMHash) error { - c.l.Lock() - defer c.l.Unlock() + c.mu.Lock() + defer c.mu.Unlock() c.lastActiveAncestor = &hash return nil } func (c *Mem) LastActiveAncestor() (*AUMHash, error) { - c.l.RLock() - defer c.l.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() return c.lastActiveAncestor, nil } // Heads returns AUMs for which there are no children. In other // words, the latest AUM in all chains (the 'leaf'). func (c *Mem) Heads() ([]AUM, error) { - c.l.RLock() - defer c.l.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() out := make([]AUM, 0, 6) // An AUM is a 'head' if there are no nodes for which it is the parent. @@ -116,8 +154,8 @@ func (c *Mem) Heads() ([]AUM, error) { // AUM returns the AUM with the specified digest. func (c *Mem) AUM(hash AUMHash) (AUM, error) { - c.l.RLock() - defer c.l.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() aum, ok := c.aums[hash] if !ok { return AUM{}, os.ErrNotExist @@ -125,24 +163,11 @@ func (c *Mem) AUM(hash AUMHash) (AUM, error) { return aum, nil } -// Orphans returns all AUMs which do not have a parent. -func (c *Mem) Orphans() ([]AUM, error) { - c.l.RLock() - defer c.l.RUnlock() - out := make([]AUM, 0, 6) - for _, a := range c.aums { - if _, ok := a.Parent(); !ok { - out = append(out, a) - } - } - return out, nil -} - // ChildAUMs returns all AUMs with a specified previous // AUM hash. func (c *Mem) ChildAUMs(prevAUMHash AUMHash) ([]AUM, error) { - c.l.RLock() - defer c.l.RUnlock() + c.mu.RLock() + defer c.mu.RUnlock() out := make([]AUM, 0, 6) for _, entry := range c.parentIndex[prevAUMHash] { out = append(out, c.aums[entry]) @@ -156,17 +181,19 @@ func (c *Mem) ChildAUMs(prevAUMHash AUMHash) ([]AUM, error) { // as the rest of the TKA implementation assumes that only // verified AUMs are stored. func (c *Mem) CommitVerifiedAUMs(updates []AUM) error { - c.l.Lock() - defer c.l.Unlock() + c.mu.Lock() + defer c.mu.Unlock() if c.aums == nil { c.parentIndex = make(map[AUMHash][]AUMHash, 64) c.aums = make(map[AUMHash]AUM, 64) + c.commitTimes = make(map[AUMHash]time.Time, 64) } updateLoop: for _, aum := range updates { aumHash := aum.Hash() c.aums[aumHash] = aum + c.commitTimes[aumHash] = c.clock.Now() parent, ok := aum.Parent() if ok { @@ -182,6 +209,71 @@ updateLoop: return nil } +// RemoveAll permanently and completely clears the TKA state. +func (c *Mem) RemoveAll() error { + c.mu.Lock() + defer c.mu.Unlock() + c.aums = nil + c.commitTimes = nil + c.parentIndex = nil + c.lastActiveAncestor = nil + return nil +} + +// AllAUMs returns all AUMs stored in the chonk. +func (c *Mem) AllAUMs() ([]AUMHash, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + return slices.Collect(maps.Keys(c.aums)), nil +} + +// CommitTime returns the time at which the AUM was committed. +// +// If the AUM does not exist, then os.ErrNotExist is returned. +func (c *Mem) CommitTime(h AUMHash) (time.Time, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + t, ok := c.commitTimes[h] + if ok { + return t, nil + } else { + return time.Time{}, os.ErrNotExist + } +} + +// PurgeAUMs marks the specified AUMs for deletion from storage. +func (c *Mem) PurgeAUMs(hashes []AUMHash) error { + c.mu.Lock() + defer c.mu.Unlock() + + for _, h := range hashes { + // Remove the deleted AUM from the list of its parents' children. + // + // However, we leave the list of this AUM's children in parentIndex, + // so we can find them later in ChildAUMs(). + if aum, ok := c.aums[h]; ok { + parent, hasParent := aum.Parent() + if hasParent { + c.parentIndex[parent] = slices.DeleteFunc( + c.parentIndex[parent], + func(other AUMHash) bool { return bytes.Equal(h[:], other[:]) }, + ) + if len(c.parentIndex[parent]) == 0 { + delete(c.parentIndex, parent) + } + } + } + + // Delete this AUM from the list of AUMs and commit times. + delete(c.aums, h) + delete(c.commitTimes, h) + } + + return nil +} + // FS implements filesystem storage of TKA state. // // FS implements the Chonk interface. @@ -193,6 +285,10 @@ type FS struct { // ChonkDir returns an implementation of Chonk which uses the // given directory to store TKA state. func ChonkDir(dir string) (*FS, error) { + if err := os.MkdirAll(dir, 0755); err != nil && !os.IsExist(err) { + return nil, fmt.Errorf("creating chonk root dir: %v", err) + } + stat, err := os.Stat(dir) if err != nil { return nil, err @@ -217,10 +313,14 @@ func ChonkDir(dir string) (*FS, error) { // CBOR was chosen because we are already using it and it serializes // much smaller than JSON for AUMs. The 'keyasint' thing isn't essential // but again it saves a bunch of bytes. +// +// We have removed the following fields from fsHashInfo, but they may be +// present in data stored in existing deployments. Do not reuse these values, +// to avoid getting unexpected values from legacy data: +// - cbor:1, Children type fsHashInfo struct { - Children []AUMHash `cbor:"1,keyasint"` - AUM *AUM `cbor:"2,keyasint"` - CreatedUnix int64 `cbor:"3,keyasint,omitempty"` + AUM *AUM `cbor:"2,keyasint"` + CreatedUnix int64 `cbor:"3,keyasint,omitempty"` // PurgedUnix is set when the AUM is deleted. The value is // the unix epoch at the time it was deleted. @@ -296,32 +396,15 @@ func (c *FS) ChildAUMs(prevAUMHash AUMHash) ([]AUM, error) { c.mu.RLock() defer c.mu.RUnlock() - info, err := c.get(prevAUMHash) - if err != nil { - if os.IsNotExist(err) { - // not knowing about this hash is not an error - return nil, nil - } - return nil, err - } - // NOTE(tom): We don't check PurgedUnix here because 'purged' - // only applies to that specific AUM (i.e. info.AUM) and not to - // any information about children stored against that hash. + var out []AUM - out := make([]AUM, len(info.Children)) - for i, h := range info.Children { - c, err := c.get(h) - if err != nil { - // We expect any AUM recorded as a child on its parent to exist. - return nil, fmt.Errorf("reading child %d of %x: %v", i, h, err) - } - if c.AUM == nil || c.PurgedUnix > 0 { - return nil, fmt.Errorf("child %d of %x: AUM not stored", i, h) + err := c.scanHashes(func(info *fsHashInfo) { + if info.AUM != nil && bytes.Equal(info.AUM.PrevAUMHash, prevAUMHash[:]) { + out = append(out, *info.AUM) } - out[i] = *c.AUM - } + }) - return out, nil + return out, err } func (c *FS) get(h AUMHash) (*fsHashInfo, error) { @@ -357,13 +440,50 @@ func (c *FS) Heads() ([]AUM, error) { c.mu.RLock() defer c.mu.RUnlock() + // Scan the complete list of AUMs, and build a list of all parent hashes. + // This tells us which AUMs have children. + var parentHashes []AUMHash + + allAUMs, err := c.AllAUMs() + if err != nil { + return nil, err + } + + for _, h := range allAUMs { + aum, err := c.AUM(h) + if err != nil { + return nil, err + } + parent, hasParent := aum.Parent() + if !hasParent { + continue + } + if !slices.Contains(parentHashes, parent) { + parentHashes = append(parentHashes, parent) + } + } + + // Now scan a second time, and only include AUMs which weren't marked as + // the parent of any other AUM. out := make([]AUM, 0, 6) // 6 is arbitrary. - err := c.scanHashes(func(info *fsHashInfo) { - if len(info.Children) == 0 && info.AUM != nil && info.PurgedUnix == 0 { - out = append(out, *info.AUM) + + for _, h := range allAUMs { + if slices.Contains(parentHashes, h) { + continue } - }) - return out, err + aum, err := c.AUM(h) + if err != nil { + return nil, err + } + out = append(out, aum) + } + + return out, nil +} + +// RemoveAll permanently and completely clears the TKA state. +func (c *FS) RemoveAll() error { + return os.RemoveAll(c.base) } // AllAUMs returns all AUMs stored in the chonk. @@ -373,7 +493,7 @@ func (c *FS) AllAUMs() ([]AUMHash, error) { out := make([]AUMHash, 0, 6) // 6 is arbitrary. err := c.scanHashes(func(info *fsHashInfo) { - if info.AUM != nil && info.PurgedUnix == 0 { + if info.AUM != nil { out = append(out, info.AUM.Hash()) } }) @@ -394,14 +514,24 @@ func (c *FS) scanHashes(eachHashInfo func(*fsHashInfo)) error { return fmt.Errorf("reading prefix dir: %v", err) } for _, file := range files { + // Ignore files whose names aren't valid AUM hashes, which may be + // temporary files which are partway through being written, or other + // files added by the OS (like .DS_Store) which we can ignore. + // TODO(alexc): it might be useful to append a suffix like `.aum` to + // filenames, so we can more easily distinguish between AUMs and + // arbitrary other files. var h AUMHash if err := h.UnmarshalText([]byte(file.Name())); err != nil { - return fmt.Errorf("invalid aum file: %s: %w", file.Name(), err) + log.Printf("ignoring unexpected non-AUM: %s: %v", file.Name(), err) + continue } info, err := c.get(h) if err != nil { return fmt.Errorf("reading %x: %v", h, err) } + if info.PurgedUnix > 0 { + continue + } eachHashInfo(info) } @@ -456,24 +586,6 @@ func (c *FS) CommitVerifiedAUMs(updates []AUM) error { for i, aum := range updates { h := aum.Hash() - // We keep track of children against their parent so that - // ChildAUMs() do not need to scan all AUMs. - parent, hasParent := aum.Parent() - if hasParent { - err := c.commit(parent, func(info *fsHashInfo) { - // Only add it if its not already there. - for i := range info.Children { - if info.Children[i] == h { - return - } - } - info.Children = append(info.Children, h) - }) - if err != nil { - return fmt.Errorf("committing update[%d] to parent %x: %v", i, parent, err) - } - } - err := c.commit(h, func(info *fsHashInfo) { info.PurgedUnix = 0 // just in-case it was set for some reason info.AUM = &aum @@ -576,7 +688,7 @@ const ( ) // markActiveChain marks AUMs in the active chain. -// All AUMs that are within minChain ancestors of head are +// All AUMs that are within minChain ancestors of head, or are marked as young, are // marked retainStateActive, and all remaining ancestors are // marked retainStateCandidate. // @@ -602,27 +714,30 @@ func markActiveChain(storage Chonk, verdict map[AUMHash]retainState, minChain in // We've reached the end of the chain we have stored. return h, nil } - return AUMHash{}, fmt.Errorf("reading active chain (retainStateActive) (%d): %w", i, err) + return AUMHash{}, fmt.Errorf("reading active chain (retainStateActive) (%d, %v): %w", i, parent, err) } } // If we got this far, we have at least minChain AUMs stored, and minChain number // of ancestors have been marked for retention. We now continue to iterate backwards - // till we find an AUM which we can compact to (a Checkpoint AUM). + // till we find an AUM which we can compact to: either a Checkpoint AUM which is old + // enough, or the genesis AUM. for { h := next.Hash() verdict[h] |= retainStateActive + + parent, hasParent := next.Parent() + isYoung := verdict[h]&retainStateYoung != 0 + if next.MessageKind == AUMCheckpoint { lastActiveAncestor = h - break + if !isYoung || !hasParent { + break + } } - parent, hasParent := next.Parent() - if !hasParent { - return AUMHash{}, errors.New("reached genesis AUM without finding an appropriate lastActiveAncestor") - } if next, err = storage.AUM(parent); err != nil { - return AUMHash{}, fmt.Errorf("searching for compaction target: %w", err) + return AUMHash{}, fmt.Errorf("searching for compaction target (%v): %w", parent, err) } } @@ -638,7 +753,7 @@ func markActiveChain(storage Chonk, verdict map[AUMHash]retainState, minChain in // We've reached the end of the chain we have stored. break } - return AUMHash{}, fmt.Errorf("reading active chain (retainStateCandidate): %w", err) + return AUMHash{}, fmt.Errorf("reading active chain (retainStateCandidate, %v): %w", parent, err) } } @@ -676,7 +791,7 @@ func markAncestorIntersectionAUMs(storage Chonk, verdict map[AUMHash]retainState toScan := make([]AUMHash, 0, len(verdict)) for h, v := range verdict { if (v & retainAUMMask) == 0 { - continue // not marked for retention, so dont need to consider it + continue // not marked for retention, so don't need to consider it } if h == candidateAncestor { continue @@ -750,7 +865,7 @@ func markAncestorIntersectionAUMs(storage Chonk, verdict map[AUMHash]retainState if didAdjustCandidateAncestor { var next AUM if next, err = storage.AUM(candidateAncestor); err != nil { - return AUMHash{}, fmt.Errorf("searching for compaction target: %w", err) + return AUMHash{}, fmt.Errorf("searching for compaction target (%v): %w", candidateAncestor, err) } for { @@ -766,7 +881,7 @@ func markAncestorIntersectionAUMs(storage Chonk, verdict map[AUMHash]retainState return AUMHash{}, errors.New("reached genesis AUM without finding an appropriate candidateAncestor") } if next, err = storage.AUM(parent); err != nil { - return AUMHash{}, fmt.Errorf("searching for compaction target: %w", err) + return AUMHash{}, fmt.Errorf("searching for compaction target (%v): %w", parent, err) } } } @@ -779,7 +894,7 @@ func markDescendantAUMs(storage Chonk, verdict map[AUMHash]retainState) error { toScan := make([]AUMHash, 0, len(verdict)) for h, v := range verdict { if v&retainAUMMask == 0 { - continue // not marked, so dont need to mark descendants + continue // not marked, so don't need to mark descendants } toScan = append(toScan, h) } @@ -825,12 +940,12 @@ func Compact(storage CompactableChonk, head AUMHash, opts CompactionOptions) (la verdict[h] = 0 } - if lastActiveAncestor, err = markActiveChain(storage, verdict, opts.MinChain, head); err != nil { - return AUMHash{}, fmt.Errorf("marking active chain: %w", err) - } if err := markYoungAUMs(storage, verdict, opts.MinAge); err != nil { return AUMHash{}, fmt.Errorf("marking young AUMs: %w", err) } + if lastActiveAncestor, err = markActiveChain(storage, verdict, opts.MinChain, head); err != nil { + return AUMHash{}, fmt.Errorf("marking active chain: %w", err) + } if err := markDescendantAUMs(storage, verdict); err != nil { return AUMHash{}, fmt.Errorf("marking descendant AUMs: %w", err) } diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index 86d5642a3..eeb6edfff 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -5,9 +5,9 @@ package tka import ( "bytes" - "fmt" "os" "path/filepath" + "slices" "sync" "testing" "time" @@ -15,8 +15,17 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "golang.org/x/crypto/blake2s" + "tailscale.com/types/key" + "tailscale.com/util/must" ) +// This package has implementation-specific tests for Mem and FS. +// +// We also have tests for the Chonk interface in `chonktest`, which exercises +// both Mem and FS. Those tests are in a separate package so they can be shared +// with other repos; we don't call the shared test helpers from this package +// to avoid creating a circular dependency. + // randHash derives a fake blake2s hash from the test name // and the given seed. func randHash(t *testing.T, seed int64) [blake2s.Size]byte { @@ -26,134 +35,12 @@ func randHash(t *testing.T, seed int64) [blake2s.Size]byte { } func TestImplementsChonk(t *testing.T) { - impls := []Chonk{&Mem{}, &FS{}} + impls := []Chonk{ChonkMem(), &FS{}} t.Logf("chonks: %v", impls) } -func TestTailchonk_ChildAUMs(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - parentHash := randHash(t, 1) - data := []AUM{ - { - MessageKind: AUMRemoveKey, - KeyID: []byte{1, 2}, - PrevAUMHash: parentHash[:], - }, - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - - if err := chonk.CommitVerifiedAUMs(data); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - stored, err := chonk.ChildAUMs(parentHash) - if err != nil { - t.Fatalf("ChildAUMs failed: %v", err) - } - if diff := cmp.Diff(data, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - -func TestTailchonk_AUMMissing(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - var notExists AUMHash - notExists[:][0] = 42 - if _, err := chonk.AUM(notExists); err != os.ErrNotExist { - t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) - } - }) - } -} - -func TestTailchonkMem_Orphans(t *testing.T) { - chonk := Mem{} - - parentHash := randHash(t, 1) - orphan := AUM{MessageKind: AUMNoOp} - aums := []AUM{ - orphan, - // A parent is specified, so we shouldnt see it in GetOrphans() - { - MessageKind: AUMRemoveKey, - KeyID: []byte{3, 4}, - PrevAUMHash: parentHash[:], - }, - } - if err := chonk.CommitVerifiedAUMs(aums); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - - stored, err := chonk.Orphans() - if err != nil { - t.Fatalf("Orphans failed: %v", err) - } - if diff := cmp.Diff([]AUM{orphan}, stored); diff != "" { - t.Errorf("stored AUM differs (-want, +got):\n%s", diff) - } -} - -func TestTailchonk_ReadChainFromHead(t *testing.T) { - for _, chonk := range []Chonk{&Mem{}, &FS{base: t.TempDir()}} { - - t.Run(fmt.Sprintf("%T", chonk), func(t *testing.T) { - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} - - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) - } - // t.Logf("genesis hash = %X", genesis.Hash()) - // t.Logf("intermediate hash = %X", intermediate.Hash()) - // t.Logf("leaf hash = %X", leaf.Hash()) - - // Read the chain from the leaf backwards. - gotLeafs, err := chonk.Heads() - if err != nil { - t.Fatalf("Heads failed: %v", err) - } - if diff := cmp.Diff([]AUM{leaf}, gotLeafs); diff != "" { - t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) - } - - parent, _ := gotLeafs[0].Parent() - gotIntermediate, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { - t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) - } - - parent, _ = gotIntermediate.Parent() - gotGenesis, err := chonk.AUM(parent) - if err != nil { - t.Fatalf("AUM() failed: %v", err) - } - if diff := cmp.Diff(genesis, gotGenesis); diff != "" { - t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) - } - }) - } -} - func TestTailchonkFS_Commit(t *testing.T) { - chonk := &FS{base: t.TempDir()} + chonk := must.Get(ChonkDir(t.TempDir())) parentHash := randHash(t, 1) aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} @@ -171,9 +58,6 @@ func TestTailchonkFS_Commit(t *testing.T) { if _, err := os.Stat(filepath.Join(dir, base)); err != nil { t.Errorf("stat of AUM file failed: %v", err) } - if _, err := os.Stat(filepath.Join(chonk.base, "M7", "M7LL2NDB4NKCZIUPVS6RDM2GUOIMW6EEAFVBWMVCPUANQJPHT3SQ")); err != nil { - t.Errorf("stat of AUM parent failed: %v", err) - } info, err := chonk.get(aum.Hash()) if err != nil { @@ -185,7 +69,7 @@ func TestTailchonkFS_Commit(t *testing.T) { } func TestTailchonkFS_CommitTime(t *testing.T) { - chonk := &FS{base: t.TempDir()} + chonk := must.Get(ChonkDir(t.TempDir())) parentHash := randHash(t, 1) aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} @@ -201,57 +85,83 @@ func TestTailchonkFS_CommitTime(t *testing.T) { } } -func TestTailchonkFS_PurgeAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} +// If we were interrupted while writing a temporary file, AllAUMs() +// should ignore it when scanning the AUM directory. +func TestTailchonkFS_IgnoreTempFile(t *testing.T) { + base := t.TempDir() + chonk := must.Get(ChonkDir(base)) parentHash := randHash(t, 1) aum := AUM{MessageKind: AUMNoOp, PrevAUMHash: parentHash[:]} + must.Do(chonk.CommitVerifiedAUMs([]AUM{aum})) - if err := chonk.CommitVerifiedAUMs([]AUM{aum}); err != nil { - t.Fatal(err) - } - if err := chonk.PurgeAUMs([]AUMHash{aum.Hash()}); err != nil { - t.Fatal(err) + writeAUMFile := func(filename, contents string) { + t.Helper() + if err := os.MkdirAll(filepath.Join(base, filename[0:2]), os.ModePerm); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(base, filename[0:2], filename), []byte(contents), 0600); err != nil { + t.Fatal(err) + } } - if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { - t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) + // Check that calling AllAUMs() returns the single committed AUM + got, err := chonk.AllAUMs() + if err != nil { + t.Fatalf("AllAUMs() failed: %v", err) + } + want := []AUMHash{aum.Hash()} + if !slices.Equal(got, want) { + t.Fatalf("AllAUMs() is wrong: got %v, want %v", got, want) } - info, err := chonk.get(aum.Hash()) + // Write some temporary files which are named like partially-committed AUMs, + // then check that AllAUMs() only returns the single committed AUM. + writeAUMFile("AUM1234.tmp", "incomplete AUM\n") + writeAUMFile("AUM1234.tmp_123", "second incomplete AUM\n") + + got, err = chonk.AllAUMs() if err != nil { - t.Fatal(err) + t.Fatalf("AllAUMs() failed: %v", err) } - if info.PurgedUnix == 0 { - t.Errorf("recently-created AUM PurgedUnix = %d, want non-zero", info.PurgedUnix) + if !slices.Equal(got, want) { + t.Fatalf("AllAUMs() is wrong: got %v, want %v", got, want) } } -func TestTailchonkFS_AllAUMs(t *testing.T) { - chonk := &FS{base: t.TempDir()} - genesis := AUM{MessageKind: AUMRemoveKey, KeyID: []byte{1, 2}} - gHash := genesis.Hash() - intermediate := AUM{PrevAUMHash: gHash[:]} - iHash := intermediate.Hash() - leaf := AUM{PrevAUMHash: iHash[:]} +// If we use a non-existent directory with filesystem Chonk storage, +// it's automatically created. +func TestTailchonkFS_CreateChonkDir(t *testing.T) { + base := filepath.Join(t.TempDir(), "a", "b", "c") - commitSet := []AUM{ - genesis, - intermediate, - leaf, - } - if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { - t.Fatalf("CommitVerifiedAUMs failed: %v", err) + chonk, err := ChonkDir(base) + if err != nil { + t.Fatalf("ChonkDir: %v", err) } - hashes, err := chonk.AllAUMs() + aum := AUM{MessageKind: AUMNoOp} + must.Do(chonk.CommitVerifiedAUMs([]AUM{aum})) + + got, err := chonk.AUM(aum.Hash()) if err != nil { - t.Fatal(err) + t.Errorf("Chonk.AUM: %v", err) } - hashesLess := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 + if diff := cmp.Diff(got, aum); diff != "" { + t.Errorf("wrong AUM; (-got+want):%v", diff) } - if diff := cmp.Diff([]AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { - t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) + + if _, err := os.Stat(base); err != nil { + t.Errorf("os.Stat: %v", err) + } +} + +// You can't use a file as the root of your filesystem Chonk storage. +func TestTailchonkFS_CannotUseFile(t *testing.T) { + base := filepath.Join(t.TempDir(), "tka_storage.txt") + must.Do(os.WriteFile(base, []byte("this won't work"), 0644)) + + _, err := ChonkDir(base) + if err == nil { + t.Fatal("ChonkDir succeeded; expected an error") } } @@ -319,7 +229,7 @@ func TestMarkActiveChain(t *testing.T) { verdict := make(map[AUMHash]retainState, len(tc.chain)) // Build the state of the tailchonk for tests. - storage := &Mem{} + storage := ChonkMem() var prev AUMHash for i := range tc.chain { if !prev.IsZero() { @@ -611,11 +521,12 @@ func (c *compactingChonkFake) CommitTime(hash AUMHash) (time.Time, error) { return c.aumAge[hash], nil } +func hashesLess(x, y AUMHash) bool { + return bytes.Compare(x[:], y[:]) < 0 +} + func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { - lessHashes := func(a, b AUMHash) bool { - return bytes.Compare(a[:], b[:]) < 0 - } - if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(lessHashes)); diff != "" { + if diff := cmp.Diff(c.wantDelete, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { c.t.Errorf("deletion set differs (-want, +got):\n%s", diff) } return nil @@ -623,7 +534,7 @@ func (c *compactingChonkFake) PurgeAUMs(hashes []AUMHash) error { // Avoid go vet complaining about copying a lock value func cloneMem(src, dst *Mem) { - dst.l = sync.RWMutex{} + dst.mu = sync.RWMutex{} dst.aums = src.aums dst.parentIndex = src.parentIndex dst.lastActiveAncestor = src.lastActiveAncestor @@ -691,3 +602,33 @@ func TestCompact(t *testing.T) { } } } + +func TestCompactLongButYoung(t *testing.T) { + ourPriv := key.NewNLPrivate() + ourKey := Key{Kind: Key25519, Public: ourPriv.Public().Verifier(), Votes: 1} + someOtherKey := Key{Kind: Key25519, Public: key.NewNLPrivate().Public().Verifier(), Votes: 1} + + storage := ChonkMem() + auth, _, err := Create(storage, State{ + Keys: []Key{ourKey, someOtherKey}, + DisablementSecrets: [][]byte{DisablementKDF(bytes.Repeat([]byte{0xa5}, 32))}, + }, ourPriv) + if err != nil { + t.Fatalf("tka.Create() failed: %v", err) + } + + genesis := auth.Head() + + for range 100 { + upd := auth.NewUpdater(ourPriv) + must.Do(upd.RemoveKey(someOtherKey.MustID())) + must.Do(upd.AddKey(someOtherKey)) + aums := must.Get(upd.Finalize(storage)) + must.Do(auth.Inform(storage, aums)) + } + + lastActiveAncestor := must.Get(Compact(storage, auth.Head(), CompactionOptions{MinChain: 5, MinAge: time.Hour})) + if lastActiveAncestor != genesis { + t.Errorf("last active ancestor = %v, want %v", lastActiveAncestor, genesis) + } +} diff --git a/tka/tka.go b/tka/tka.go index 04b712660..ed029c82e 100644 --- a/tka/tka.go +++ b/tka/tka.go @@ -1,7 +1,9 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package tka (WIP) implements the Tailnet Key Authority. +//go:build !ts_omit_tailnetlock + +// Package tka implements the Tailnet Key Authority (TKA) for Tailnet Lock. package tka import ( @@ -92,7 +94,7 @@ func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int // candidates.Oldest needs to be computed by working backwards from // head as far as we can. - iterAgain := true // if theres still work to be done. + iterAgain := true // if there's still work to be done. for i := 0; iterAgain; i++ { if i >= maxIter { return nil, fmt.Errorf("iteration limit exceeded (%d)", maxIter) @@ -100,14 +102,14 @@ func computeChainCandidates(storage Chonk, lastKnownOldest *AUMHash, maxIter int iterAgain = false for j := range candidates { - parent, hasParent := candidates[j].Oldest.Parent() + parentHash, hasParent := candidates[j].Oldest.Parent() if hasParent { - parent, err := storage.AUM(parent) + parent, err := storage.AUM(parentHash) if err != nil { if err == os.ErrNotExist { continue } - return nil, fmt.Errorf("reading parent: %v", err) + return nil, fmt.Errorf("reading parent %s: %v", parentHash, err) } candidates[j].Oldest = parent if lastKnownOldest != nil && *lastKnownOldest == parent.Hash() { @@ -208,7 +210,7 @@ func fastForwardWithAdvancer( } nextAUM, err := storage.AUM(*startState.LastAUMHash) if err != nil { - return AUM{}, State{}, fmt.Errorf("reading next: %v", err) + return AUM{}, State{}, fmt.Errorf("reading next (%v): %v", *startState.LastAUMHash, err) } curs := nextAUM @@ -293,9 +295,9 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) } // If we got here, the current state is dependent on the previous. - // Keep iterating backwards till thats not the case. + // Keep iterating backwards till that's not the case. if curs, err = storage.AUM(parent); err != nil { - return State{}, fmt.Errorf("reading parent: %v", err) + return State{}, fmt.Errorf("reading parent (%v): %v", parent, err) } } @@ -322,7 +324,7 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) return curs.Hash() == wantHash }) // fastForward only terminates before the done condition if it - // doesnt have any later AUMs to process. This cant be the case + // doesn't have any later AUMs to process. This can't be the case // as we've already iterated through them above so they must exist, // but we check anyway to be super duper sure. if err == nil && *state.LastAUMHash != wantHash { @@ -334,13 +336,13 @@ func computeStateAt(storage Chonk, maxIter int, wantHash AUMHash) (State, error) // computeActiveAncestor determines which ancestor AUM to use as the // ancestor of the valid chain. // -// If all the chains end up having the same ancestor, then thats the +// If all the chains end up having the same ancestor, then that's the // only possible ancestor, ezpz. However if there are multiple distinct // ancestors, that means there are distinct chains, and we need some // hint to choose what to use. For that, we rely on the chainsThroughActive // bit, which signals to us that that ancestor was part of the // chain in a previous run. -func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) { +func computeActiveAncestor(chains []chain) (AUMHash, error) { // Dedupe possible ancestors, tracking if they were part of // the active chain on a previous run. ancestors := make(map[AUMHash]bool, len(chains)) @@ -355,7 +357,7 @@ func computeActiveAncestor(storage Chonk, chains []chain) (AUMHash, error) { } } - // Theres more than one, so we need to use the ancestor that was + // There's more than one, so we need to use the ancestor that was // part of the active chain in a previous iteration. // Note that there can only be one distinct ancestor that was // formerly part of the active chain, because AUMs can only have @@ -389,8 +391,12 @@ func computeActiveChain(storage Chonk, lastKnownOldest *AUMHash, maxIter int) (c return chain{}, fmt.Errorf("computing candidates: %v", err) } + if len(chains) == 0 { + return chain{}, errors.New("no chain candidates in AUM storage") + } + // Find the right ancestor. - oldestHash, err := computeActiveAncestor(storage, chains) + oldestHash, err := computeActiveAncestor(chains) if err != nil { return chain{}, fmt.Errorf("computing ancestor: %v", err) } @@ -440,6 +446,13 @@ func aumVerify(aum AUM, state State, isGenesisAUM bool) error { return fmt.Errorf("signature %d: %v", i, err) } } + + if aum.MessageKind == AUMRemoveKey && len(state.Keys) == 1 { + if kid, err := state.Keys[0].ID(); err == nil && bytes.Equal(aum.KeyID, kid) { + return errors.New("cannot remove the last key in the state") + } + } + return nil } @@ -466,7 +479,7 @@ func (a *Authority) Head() AUMHash { // Open initializes an existing TKA from the given tailchonk. // // Only use this if the current node has initialized an Authority before. -// If a TKA exists on other nodes but theres nothing locally, use Bootstrap(). +// If a TKA exists on other nodes but there's nothing locally, use Bootstrap(). // If no TKA exists anywhere and you are creating it for the first // time, use New(). func Open(storage Chonk) (*Authority, error) { @@ -579,14 +592,14 @@ func (a *Authority) InformIdempotent(storage Chonk, updates []AUM) (Authority, e toCommit := make([]AUM, 0, len(updates)) prevHash := a.Head() - // The state at HEAD is the current state of the authority. Its likely + // The state at HEAD is the current state of the authority. It's likely // to be needed, so we prefill it rather than computing it. stateAt[prevHash] = a.state // Optimization: If the set of updates is a chain building from // the current head, EG: // ==> updates[0] ==> updates[1] ... - // Then theres no need to recompute the resulting state from the + // Then there's no need to recompute the resulting state from the // stored ancestor, because the last state computed during iteration // is the new state. This should be the common case. // isHeadChain keeps track of this. @@ -766,8 +779,8 @@ func (a *Authority) findParentForRewrite(storage Chonk, removeKeys []tkatype.Key } } if !keyTrusted { - // Success: the revoked keys are not trusted! - // Lets check that our key was trusted to ensure + // Success: the revoked keys are not trusted. + // Check that our key was trusted to ensure // we can sign a fork from here. if _, err := state.GetKey(ourKey); err == nil { break diff --git a/tka/tka_test.go b/tka/tka_test.go index 9e3c4e79d..78af7400d 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -253,7 +253,7 @@ func TestOpenAuthority(t *testing.T) { } // Construct the state of durable storage. - chonk := &Mem{} + chonk := ChonkMem() err := chonk.CommitVerifiedAUMs([]AUM{g1, i1, l1, i2, i3, l2, l3, g2, l4}) if err != nil { t.Fatal(err) @@ -275,7 +275,7 @@ func TestOpenAuthority(t *testing.T) { } func TestOpenAuthority_EmptyErrors(t *testing.T) { - _, err := Open(&Mem{}) + _, err := Open(ChonkMem()) if err == nil { t.Error("Expected an error initializing an empty authority, got nil") } @@ -319,7 +319,7 @@ func TestCreateBootstrapAuthority(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - a1, genesisAUM, err := Create(&Mem{}, State{ + a1, genesisAUM, err := Create(ChonkMem(), State{ Keys: []Key{key}, DisablementSecrets: [][]byte{DisablementKDF([]byte{1, 2, 3})}, }, signer25519(priv)) @@ -327,7 +327,7 @@ func TestCreateBootstrapAuthority(t *testing.T) { t.Fatalf("Create() failed: %v", err) } - a2, err := Bootstrap(&Mem{}, genesisAUM) + a2, err := Bootstrap(ChonkMem(), genesisAUM) if err != nil { t.Fatalf("Bootstrap() failed: %v", err) } @@ -366,7 +366,7 @@ func TestAuthorityInformNonLinear(t *testing.T) { optKey("key", key, priv), optSignAllUsing("key")) - storage := &Mem{} + storage := ChonkMem() a, err := Bootstrap(storage, c.AUMs["G1"]) if err != nil { t.Fatalf("Bootstrap() failed: %v", err) @@ -411,7 +411,7 @@ func TestAuthorityInformLinear(t *testing.T) { optKey("key", key, priv), optSignAllUsing("key")) - storage := &Mem{} + storage := ChonkMem() a, err := Bootstrap(storage, c.AUMs["G1"]) if err != nil { t.Fatalf("Bootstrap() failed: %v", err) @@ -444,7 +444,7 @@ func TestInteropWithNLKey(t *testing.T) { pub2 := key.NewNLPrivate().Public() pub3 := key.NewNLPrivate().Public() - a, _, err := Create(&Mem{}, State{ + a, _, err := Create(ChonkMem(), State{ Keys: []Key{ { Kind: Key25519, diff --git a/tka/verify.go b/tka/verify.go new file mode 100644 index 000000000..ed0ecea66 --- /dev/null +++ b/tka/verify.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tailnetlock + +package tka + +import ( + "crypto/ed25519" + "errors" + "fmt" + + "github.com/hdevalence/ed25519consensus" + "tailscale.com/types/tkatype" +) + +// signatureVerify returns a nil error if the signature is valid over the +// provided AUM BLAKE2s digest, using the given key. +func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { + // NOTE(tom): Even if we can compute the public from the KeyID, + // it's possible for the KeyID to be attacker-controlled + // so we should use the public contained in the state machine. + switch key.Kind { + case Key25519: + if len(key.Public) != ed25519.PublicKeySize { + return fmt.Errorf("ed25519 key has wrong length: %d", len(key.Public)) + } + if ed25519consensus.Verify(ed25519.PublicKey(key.Public), aumDigest[:], s.Signature) { + return nil + } + return errors.New("invalid signature") + + default: + return fmt.Errorf("unhandled key type: %v", key.Kind) + } +} diff --git a/tka/verify_disabled.go b/tka/verify_disabled.go new file mode 100644 index 000000000..ba72f93e2 --- /dev/null +++ b/tka/verify_disabled.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_tailnetlock + +package tka + +import ( + "errors" + + "tailscale.com/types/tkatype" +) + +// signatureVerify returns a nil error if the signature is valid over the +// provided AUM BLAKE2s digest, using the given key. +func signatureVerify(s *tkatype.Signature, aumDigest tkatype.AUMSigHash, key Key) error { + return errors.New("tailnetlock disabled in build") +} diff --git a/tool/go-win.ps1 b/tool/go-win.ps1 new file mode 100644 index 000000000..49313ffba --- /dev/null +++ b/tool/go-win.ps1 @@ -0,0 +1,64 @@ +<# + go.ps1 – Tailscale Go toolchain fetching wrapper for Windows/PowerShell + â€ĸ Reads go.toolchain.rev one dir above this script + â€ĸ If the requested commit hash isn't cached, downloads and unpacks + https://github.com/tailscale/go/releases/download/build-${REV}/${OS}-${ARCH}.tar.gz + â€ĸ Finally execs the toolchain's "go" binary, forwarding all args & exit-code +#> + +param( + [Parameter(ValueFromRemainingArguments = $true)] + [string[]] $Args +) + +Set-StrictMode -Version Latest +$ErrorActionPreference = 'Stop' + +if ($env:CI -eq 'true' -and $env:NODEBUG -ne 'true') { + $VerbosePreference = 'Continue' +} + +$repoRoot = Resolve-Path (Join-Path $PSScriptRoot '..') +$REV = (Get-Content (Join-Path $repoRoot 'go.toolchain.rev') -Raw).Trim() + +if ([IO.Path]::IsPathRooted($REV)) { + $toolchain = $REV +} else { + if (-not [string]::IsNullOrWhiteSpace($env:TSGO_CACHE_ROOT)) { + $cacheRoot = $env:TSGO_CACHE_ROOT + } else { + $cacheRoot = Join-Path $env:USERPROFILE '.cache\tsgo' + } + + $toolchain = Join-Path $cacheRoot $REV + $marker = "$toolchain.extracted" + + if (-not (Test-Path $marker)) { + Write-Host "# Downloading Go toolchain $REV" -ForegroundColor Cyan + if (Test-Path $toolchain) { Remove-Item -Recurse -Force $toolchain } + + # Removing the marker file again (even though it shouldn't still exist) + # because the equivalent Bash script also does so (to guard against + # concurrent cache fills?). + # TODO(bradfitz): remove this and add some proper locking instead? + if (Test-Path $marker ) { Remove-Item -Force $marker } + + New-Item -ItemType Directory -Path $cacheRoot -Force | Out-Null + + $url = "https://github.com/tailscale/go/releases/download/build-$REV/windows-amd64.tar.gz" + $tgz = "$toolchain.tar.gz" + Invoke-WebRequest -Uri $url -OutFile $tgz -UseBasicParsing -ErrorAction Stop + + New-Item -ItemType Directory -Path $toolchain -Force | Out-Null + tar --strip-components=1 -xzf $tgz -C $toolchain + Remove-Item $tgz + Set-Content -Path $marker -Value $REV + } +} + +$goExe = Join-Path $toolchain 'bin\go.exe' +if (-not (Test-Path $goExe)) { throw "go executable not found at $goExe" } + +& $goExe @Args +exit $LASTEXITCODE + diff --git a/tool/go.cmd b/tool/go.cmd new file mode 100644 index 000000000..b7b5d0483 --- /dev/null +++ b/tool/go.cmd @@ -0,0 +1,36 @@ +@echo off +rem Checking for PowerShell Core using PowerShell for Windows... +powershell -NoProfile -NonInteractive -Command "& {Get-Command -Name pwsh -ErrorAction Stop}" > NUL +if ERRORLEVEL 1 ( + rem Ask the user whether they should install the dependencies. Note that this + rem code path never runs in CI because pwsh is always explicitly installed. + + rem Time out after 5 minutes, defaulting to 'N' + choice /c yn /t 300 /d n /m "PowerShell Core is required. Install now" + if ERRORLEVEL 2 ( + echo Aborting due to unmet dependencies. + exit /b 1 + ) + + rem Check for a .NET Core runtime using PowerShell for Windows... + powershell -NoProfile -NonInteractive -Command "& {if (-not (dotnet --list-runtimes | Select-String 'Microsoft\.NETCore\.App' -Quiet)) {exit 1}}" > NUL + rem Install .NET Core if missing to provide PowerShell Core's runtime library. + if ERRORLEVEL 1 ( + rem Time out after 5 minutes, defaulting to 'N' + choice /c yn /t 300 /d n /m "PowerShell Core requires .NET Core for its runtime library. Install now" + if ERRORLEVEL 2 ( + echo Aborting due to unmet dependencies. + exit /b 1 + ) + + winget install --accept-package-agreements --id Microsoft.DotNet.Runtime.8 -e --source winget + ) + + rem Now install PowerShell Core. + winget install --accept-package-agreements --id Microsoft.PowerShell -e --source winget + if ERRORLEVEL 0 echo Please re-run this script within a new console session to pick up PATH changes. + rem Either way we didn't build, so return 1. + exit /b 1 +) + +pwsh -NoProfile -ExecutionPolicy Bypass "%~dp0..\tool\gocross\gocross-wrapper.ps1" %* diff --git a/tool/gocross/autoflags.go b/tool/gocross/autoflags.go index c66cab55a..b28d3bc5d 100644 --- a/tool/gocross/autoflags.go +++ b/tool/gocross/autoflags.go @@ -35,7 +35,7 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ cc = "cc" targetOS = cmp.Or(env.Get("GOOS", ""), nativeGOOS) targetArch = cmp.Or(env.Get("GOARCH", ""), nativeGOARCH) - buildFlags = []string{"-trimpath"} + buildFlags = []string{} cgoCflags = []string{"-O3", "-std=gnu11", "-g"} cgoLdflags []string ldflags []string @@ -47,6 +47,10 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ subcommand = argv[1] } + if subcommand != "test" { + buildFlags = append(buildFlags, "-trimpath") + } + switch subcommand { case "build", "env", "install", "run", "test", "list": default: @@ -146,7 +150,11 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ case env.IsSet("MACOSX_DEPLOYMENT_TARGET"): xcodeFlags = append(xcodeFlags, "-mmacosx-version-min="+env.Get("MACOSX_DEPLOYMENT_TARGET", "")) case env.IsSet("TVOS_DEPLOYMENT_TARGET"): - xcodeFlags = append(xcodeFlags, "-mtvos-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + if env.Get("TARGET_DEVICE_PLATFORM_NAME", "") == "appletvsimulator" { + xcodeFlags = append(xcodeFlags, "-mtvos-simulator-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + } else { + xcodeFlags = append(xcodeFlags, "-mtvos-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + } default: return nil, nil, fmt.Errorf("invoked by Xcode but couldn't figure out deployment target. Did Xcode change its envvars again?") } diff --git a/tool/gocross/autoflags_test.go b/tool/gocross/autoflags_test.go index 8f24dd8a3..a0f3edfd2 100644 --- a/tool/gocross/autoflags_test.go +++ b/tool/gocross/autoflags_test.go @@ -163,7 +163,6 @@ GOTOOLCHAIN=local (was ) TS_LINK_FAIL_REFLECT=0 (was )`, wantArgv: []string{ "gocross", "test", - "-trimpath", "-tags=tailscale_go,osusergo,netgo", "-ldflags", "-X tailscale.com/version.longStamp=1.2.3-long -X tailscale.com/version.shortStamp=1.2.3 -X tailscale.com/version.gitCommitStamp=abcd -X tailscale.com/version.extraGitCommitStamp=defg '-extldflags=-static'", "-race", diff --git a/tool/gocross/exec_other.go b/tool/gocross/exec_other.go index 8d4df0db3..4dd74f84d 100644 --- a/tool/gocross/exec_other.go +++ b/tool/gocross/exec_other.go @@ -6,15 +6,25 @@ package main import ( + "errors" "os" "os/exec" ) func doExec(cmd string, args []string, env []string) error { - c := exec.Command(cmd, args...) + c := exec.Command(cmd, args[1:]...) c.Env = env c.Stdin = os.Stdin c.Stdout = os.Stdout c.Stderr = os.Stderr - return c.Run() + err := c.Run() + + // Propagate ExitErrors within this func to give us similar semantics to + // the Unix variant. + var ee *exec.ExitError + if errors.As(err, &ee) { + os.Exit(ee.ExitCode()) + } + + return err } diff --git a/tool/gocross/gocross-wrapper.ps1 b/tool/gocross/gocross-wrapper.ps1 new file mode 100644 index 000000000..fe0b46996 --- /dev/null +++ b/tool/gocross/gocross-wrapper.ps1 @@ -0,0 +1,226 @@ +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause + +#Requires -Version 7.4 + +$ErrorActionPreference = 'Stop' +Set-StrictMode -Version 3.0 + +if (($Env:CI -eq 'true') -and ($Env:NOPWSHDEBUG -ne 'true')) { + Set-PSDebug -Trace 1 +} + +<# + .DESCRIPTION + Copies the script's $args variable into an array, which is easier to work with + when preparing to start child processes. +#> +function Copy-ScriptArgs { + $list = [System.Collections.Generic.List[string]]::new($Script:args.Count) + foreach ($arg in $Script:args) { + $list.Add($arg) + } + return $list.ToArray() +} + +<# + .DESCRIPTION + Copies the current environment into a hashtable, which is easier to work with + when preparing to start child processes. +#> +function Copy-Environment { + $result = @{} + foreach ($pair in (Get-Item -Path Env:)) { + $result[$pair.Key] = $pair.Value + } + return $result +} + +<# + .DESCRIPTION + Outputs the fully-qualified path to the repository's root directory. This + function expects to be run from somewhere within a git repository. + The directory containing the git executable must be somewhere in the PATH. +#> +function Get-RepoRoot { + Get-Command -Name 'git' | Out-Null + $repoRoot = & git rev-parse --show-toplevel + if ($LASTEXITCODE -ne 0) { + throw "failed obtaining repo root: git failed with code $LASTEXITCODE" + } + + # Git outputs a path containing forward slashes. Canonicalize. + return [System.IO.Path]::GetFullPath($repoRoot) +} + +<# + .DESCRIPTION + Runs the provided ScriptBlock in a child scope, restoring any changes to the + current working directory once the script block completes. +#> +function Start-ChildScope { + param ( + [Parameter(Mandatory = $true)] + [ScriptBlock]$ScriptBlock + ) + + $initialLocation = Get-Location + try { + Invoke-Command -ScriptBlock $ScriptBlock + } + finally { + Set-Location -Path $initialLocation + } +} + +<# + .SYNOPSIS + Write-Output with timestamps prepended to each line. +#> +function Write-Log { + param ($message) + $timestamp = (Get-Date).ToString('yyyy-MM-dd HH:mm:ss') + Write-Output "$timestamp - $message" +} + +$bootstrapScriptBlock = { + + $repoRoot = Get-RepoRoot + + Set-Location -LiteralPath $repoRoot + + switch -Wildcard -File .\go.toolchain.rev { + "/*" { $toolchain = $_ } + default { + $rev = $_ + $tsgo = Join-Path $Env:USERPROFILE '.cache' 'tsgo' + $toolchain = Join-Path $tsgo $rev + if (-not (Test-Path -LiteralPath "$toolchain.extracted" -PathType Leaf -ErrorAction SilentlyContinue)) { + New-Item -Force -Path $tsgo -ItemType Directory | Out-Null + Remove-Item -Force -Recurse -LiteralPath $toolchain -ErrorAction SilentlyContinue + Write-Log "Downloading Go toolchain $rev" + + # Values from https://web.archive.org/web/20250227081443/https://learn.microsoft.com/en-us/dotnet/api/system.runtime.interopservices.architecture?view=net-9.0 + $cpuArch = ([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture | Out-String -NoNewline) + # Comparison in switch is case-insensitive by default. + switch ($cpuArch) { + 'x86' { $goArch = '386' } + 'x64' { $goArch = 'amd64' } + default { $goArch = $cpuArch } + } + + Invoke-WebRequest -Uri "https://github.com/tailscale/go/releases/download/build-$rev/windows-$goArch.tar.gz" -OutFile "$toolchain.tar.gz" + try { + New-Item -Force -Path $toolchain -ItemType Directory | Out-Null + Start-ChildScope -ScriptBlock { + Set-Location -LiteralPath $toolchain + tar --strip-components=1 -xf "$toolchain.tar.gz" + if ($LASTEXITCODE -ne 0) { + throw "tar failed with exit code $LASTEXITCODE" + } + } + $rev | Out-File -FilePath "$toolchain.extracted" + } + finally { + Remove-Item -Force "$toolchain.tar.gz" -ErrorAction Continue + } + + # Cleanup old toolchains. + $maxDays = 90 + $oldFiles = Get-ChildItem -Path $tsgo -Filter '*.extracted' -File -Recurse -Depth 1 | Where-Object { $_.LastWriteTime -lt (Get-Date).AddDays(-$maxDays) } + foreach ($file in $oldFiles) { + Write-Log "Cleaning up old Go toolchain $($file.Basename)" + Remove-Item -LiteralPath $file.FullName -Force -ErrorAction Continue + $dirName = Join-Path $file.DirectoryName $file.Basename -Resolve -ErrorAction Continue + if ($dirName -and (Test-Path -LiteralPath $dirName -PathType Container -ErrorAction Continue)) { + Remove-Item -LiteralPath $dirName -Recurse -Force -ErrorAction Continue + } + } + } + } + } + + if ($Env:TS_USE_GOCROSS -ne '1') { + return + } + + if (Test-Path -LiteralPath $toolchain -PathType Container -ErrorAction SilentlyContinue) { + $goMod = Join-Path $repoRoot 'go.mod' -Resolve + $goLine = Get-Content -LiteralPath $goMod | Select-String -Pattern '^go (.*)$' -List + $wantGoMinor = $goLine.Matches.Groups[1].Value.split('.')[1] + $versionFile = Join-Path $toolchain 'VERSION' + if (Test-Path -LiteralPath $versionFile -PathType Leaf -ErrorAction SilentlyContinue) { + try { + $haveGoMinor = ((Get-Content -LiteralPath $versionFile -TotalCount 1).split('.')[1]) -replace 'rc.*', '' + } + catch { + } + } + + if ([string]::IsNullOrEmpty($haveGoMinor) -or ($haveGoMinor -lt $wantGoMinor)) { + Remove-Item -Force -Recurse -LiteralPath $toolchain -ErrorAction Continue + Remove-Item -Force -LiteralPath "$toolchain.extracted" -ErrorAction Continue + } + } + + $wantVer = & git rev-parse HEAD + $gocrossOk = $false + $gocrossPath = '.\gocross.exe' + if (Get-Command -Name $gocrossPath -CommandType Application -ErrorAction SilentlyContinue) { + $gotVer = & $gocrossPath gocross-version 2> $null + if ($gotVer -eq $wantVer) { + $gocrossOk = $true + } + } + + if (-not $gocrossOk) { + $goBuildEnv = Copy-Environment + $goBuildEnv['CGO_ENABLED'] = '0' + # Start-Process's -Environment arg applies diffs, so instead of removing + # these variables from $goBuildEnv, we must set them to $null to indicate + # that they should be cleared. + $goBuildEnv['GOOS'] = $null + $goBuildEnv['GOARCH'] = $null + $goBuildEnv['GO111MODULE'] = $null + $goBuildEnv['GOROOT'] = $null + + $procExe = Join-Path $toolchain 'bin' 'go.exe' -Resolve + $proc = Start-Process -FilePath $procExe -WorkingDirectory $repoRoot -Environment $goBuildEnv -ArgumentList 'build', '-o', $gocrossPath, "-ldflags=-X=tailscale.com/version.gitCommitStamp=$wantVer", 'tailscale.com/tool/gocross' -NoNewWindow -Wait -PassThru + if ($proc.ExitCode -ne 0) { + throw 'error building gocross' + } + } + +} # bootstrapScriptBlock + +Start-ChildScope -ScriptBlock $bootstrapScriptBlock + +$repoRoot = Get-RepoRoot + +$execEnv = Copy-Environment +# Start-Process's -Environment arg applies diffs, so instead of removing +# these variables from $execEnv, we must set them to $null to indicate +# that they should be cleared. +$execEnv['GOROOT'] = $null + +$argList = Copy-ScriptArgs + +if ($Env:TS_USE_GOCROSS -ne '1') { + $revFile = Join-Path $repoRoot 'go.toolchain.rev' -Resolve + switch -Wildcard -File $revFile { + "/*" { $toolchain = $_ } + default { + $rev = $_ + $tsgo = Join-Path $Env:USERPROFILE '.cache' 'tsgo' + $toolchain = Join-Path $tsgo $rev -Resolve + } + } + + $procExe = Join-Path $toolchain 'bin' 'go.exe' -Resolve + $proc = Start-Process -FilePath $procExe -WorkingDirectory $repoRoot -Environment $execEnv -ArgumentList $argList -NoNewWindow -Wait -PassThru + exit $proc.ExitCode +} + +$procExe = Join-Path $repoRoot 'gocross.exe' -Resolve +$proc = Start-Process -FilePath $procExe -WorkingDirectory $repoRoot -Environment $execEnv -ArgumentList $argList -NoNewWindow -Wait -PassThru +exit $proc.ExitCode diff --git a/tool/gocross/gocross-wrapper.sh b/tool/gocross/gocross-wrapper.sh index 6817b6e4e..d93b137aa 100755 --- a/tool/gocross/gocross-wrapper.sh +++ b/tool/gocross/gocross-wrapper.sh @@ -3,8 +3,11 @@ # SPDX-License-Identifier: BSD-3-Clause # # gocross-wrapper.sh is a wrapper that can be aliased to 'go', which -# transparently builds gocross using a "bootstrap" Go toolchain, and -# then invokes gocross. +# transparently runs the version of github.com/tailscale/go as specified repo's +# go.toolchain.rev file. +# +# It also conditionally (if TS_USE_GOCROSS=1) builds gocross and uses it as a go +# wrapper to inject certain go flags. set -euo pipefail @@ -12,6 +15,12 @@ if [[ "${CI:-}" == "true" && "${NOBASHDEBUG:-}" != "true" ]]; then set -x fi +if [[ "${OSTYPE:-}" == "cygwin" || "${OSTYPE:-}" == "msys" ]]; then + hash pwsh 2>/dev/null || { echo >&2 "This operation requires PowerShell Core."; exit 1; } + pwsh -NoProfile -ExecutionPolicy Bypass "${BASH_SOURCE%/*}/gocross-wrapper.ps1" "$@" + exit +fi + # Locate a bootstrap toolchain and (re)build gocross if necessary. We run all of # this in a subshell because posix shell semantics make it very easy to # accidentally mutate the input environment that will get passed to gocross at @@ -67,15 +76,24 @@ case "$REV" in rm -f "$toolchain.tar.gz" # Do some cleanup of old toolchains while we're here. - for hash in $(find "$HOME/.cache/tsgo" -type f -maxdepth 1 -name '*.extracted' -mtime 90 -exec basename {} \; | sed 's/.extracted$//'); do + for hash in $(find "$HOME/.cache/tsgo" -maxdepth 1 -type f -name '*.extracted' -mtime 90 -exec basename {} \; | sed 's/.extracted$//'); do echo "# Cleaning up old Go toolchain $hash" >&2 rm -rf "$HOME/.cache/tsgo/$hash" rm -rf "$HOME/.cache/tsgo/$hash.extracted" + rm -rf "$HOME/.cache/tsgoroot/$hash" done fi ;; esac +# gocross is opt-in as of 2025-06-16. See tailscale/corp#26717. +# It's primarily used for xcode builds, and a bit still for Windows. +# In the past we needed it for git version stamping on Linux etc, but +# Go does that itself nowadays. +if [ "${TS_USE_GOCROSS:-}" != "1" ]; then + exit 0 # out of subshell +fi + if [[ -d "$toolchain" ]]; then # A toolchain exists, but is it recent enough to compile gocross? If not, # wipe it out so that the next if block fetches a usable one. @@ -119,4 +137,29 @@ if [[ "$gocross_ok" == "0" ]]; then fi ) # End of the subshell execution. -exec "${BASH_SOURCE%/*}/../../gocross" "$@" +repo_root="${BASH_SOURCE%/*}/../.." + +# Some scripts/package systems set GOROOT even though they should only be +# setting $PATH. Stop them from breaking builds - go(1) respects GOROOT and +# so if it is left on here, compilation units depending on our Go fork will +# fail (such as those which depend on our net/ patches). +unset GOROOT + +# gocross is opt-in as of 2025-06-16. See tailscale/corp#26717 +# and comment above in this file. +if [ "${TS_USE_GOCROSS:-}" != "1" ]; then + read -r REV <"${repo_root}/go.toolchain.rev" + case "$REV" in + /*) + toolchain="$REV" + ;; + *) + # If the prior subshell completed successfully, this toolchain location + # should be valid at this point. + toolchain="$HOME/.cache/tsgo/$REV" + ;; + esac + exec "$toolchain/bin/go" "$@" +fi + +exec "${repo_root}/gocross" "$@" diff --git a/tool/gocross/gocross.go b/tool/gocross/gocross.go index 8011c1095..41fab3d58 100644 --- a/tool/gocross/gocross.go +++ b/tool/gocross/gocross.go @@ -15,9 +15,10 @@ import ( "fmt" "os" "path/filepath" + "runtime/debug" + "strings" "tailscale.com/atomicfile" - "tailscale.com/version" ) func main() { @@ -28,8 +29,19 @@ func main() { // any time. switch os.Args[1] { case "gocross-version": - fmt.Println(version.GetMeta().GitCommit) - os.Exit(0) + bi, ok := debug.ReadBuildInfo() + if !ok { + fmt.Fprintln(os.Stderr, "failed getting build info") + os.Exit(1) + } + for _, s := range bi.Settings { + if s.Key == "vcs.revision" { + fmt.Println(s.Value) + os.Exit(0) + } + } + fmt.Fprintln(os.Stderr, "did not find vcs.revision in build info") + os.Exit(1) case "is-gocross": // This subcommand exits with an error code when called on a // regular go binary, so it can be used to detect when `go` is @@ -57,8 +69,13 @@ func main() { fmt.Fprintf(os.Stderr, "usage: gocross write-wrapper-script \n") os.Exit(1) } - if err := atomicfile.WriteFile(os.Args[2], wrapperScript, 0755); err != nil { - fmt.Fprintf(os.Stderr, "writing wrapper script: %v\n", err) + if err := atomicfile.WriteFile(os.Args[2], wrapperScriptBash, 0755); err != nil { + fmt.Fprintf(os.Stderr, "writing bash wrapper script: %v\n", err) + os.Exit(1) + } + psFileName := strings.TrimSuffix(os.Args[2], filepath.Ext(os.Args[2])) + ".ps1" + if err := atomicfile.WriteFile(psFileName, wrapperScriptPowerShell, 0644); err != nil { + fmt.Fprintf(os.Stderr, "writing PowerShell wrapper script: %v\n", err) os.Exit(1) } os.Exit(0) @@ -85,9 +102,9 @@ func main() { path := filepath.Join(toolchain, "bin") + string(os.PathListSeparator) + os.Getenv("PATH") env.Set("PATH", path) - debug("Input: %s\n", formatArgv(os.Args)) - debug("Command: %s\n", formatArgv(newArgv)) - debug("Set the following flags/envvars:\n%s\n", env.Diff()) + debugf("Input: %s\n", formatArgv(os.Args)) + debugf("Command: %s\n", formatArgv(newArgv)) + debugf("Set the following flags/envvars:\n%s\n", env.Diff()) args = newArgv if err := env.Apply(); err != nil { @@ -97,13 +114,20 @@ func main() { } - doExec(filepath.Join(toolchain, "bin/go"), args, os.Environ()) + // Note that doExec only returns if the exec call failed. + if err := doExec(filepath.Join(toolchain, "bin", "go"), args, os.Environ()); err != nil { + fmt.Fprintf(os.Stderr, "executing process: %v\n", err) + os.Exit(1) + } } //go:embed gocross-wrapper.sh -var wrapperScript []byte +var wrapperScriptBash []byte + +//go:embed gocross-wrapper.ps1 +var wrapperScriptPowerShell []byte -func debug(format string, args ...any) { +func debugf(format string, args ...any) { debug := os.Getenv("GOCROSS_DEBUG") var ( out *os.File diff --git a/tool/gocross/gocross_test.go b/tool/gocross/gocross_test.go new file mode 100644 index 000000000..82afd268c --- /dev/null +++ b/tool/gocross/gocross_test.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + BadDeps: map[string]string{ + "tailscale.com/tailcfg": "circular dependency via go generate", + "tailscale.com/version": "circular dependency via go generate", + }, + }.Check(t) +} diff --git a/tool/gocross/gocross_wrapper_test.go b/tool/gocross/gocross_wrapper_test.go index 2b0f016a2..6937ccec7 100644 --- a/tool/gocross/gocross_wrapper_test.go +++ b/tool/gocross/gocross_wrapper_test.go @@ -15,13 +15,13 @@ import ( func TestGocrossWrapper(t *testing.T) { for i := range 2 { // once to build gocross; second to test it's cached cmd := exec.Command("./gocross-wrapper.sh", "version") - cmd.Env = append(os.Environ(), "CI=true", "NOBASHDEBUG=false") // for "set -x" verbosity + cmd.Env = append(os.Environ(), "CI=true", "NOBASHDEBUG=false", "TS_USE_GOCROSS=1") // for "set -x" verbosity out, err := cmd.CombinedOutput() if err != nil { t.Fatalf("gocross-wrapper.sh failed: %v\n%s", err, out) } if i > 0 && !strings.Contains(string(out), "gocross_ok=1\n") { - t.Errorf("expected to find 'gocross-ok=1'; got output:\n%s", out) + t.Errorf("expected to find 'gocross_ok=1'; got output:\n%s", out) } } } diff --git a/tool/gocross/gocross_wrapper_windows_test.go b/tool/gocross/gocross_wrapper_windows_test.go new file mode 100644 index 000000000..aa4277425 --- /dev/null +++ b/tool/gocross/gocross_wrapper_windows_test.go @@ -0,0 +1,25 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "os" + "os/exec" + "strings" + "testing" +) + +func TestGocrossWrapper(t *testing.T) { + for i := range 2 { // once to build gocross; second to test it's cached + cmd := exec.Command("pwsh", "-NoProfile", "-ExecutionPolicy", "Bypass", ".\\gocross-wrapper.ps1", "version") + cmd.Env = append(os.Environ(), "CI=true", "NOPWSHDEBUG=false", "TS_USE_GOCROSS=1") // for Set-PSDebug verbosity + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("gocross-wrapper.ps1 failed: %v\n%s", err, out) + } + if i > 0 && !strings.Contains(string(out), "$gocrossOk = $true\r\n") { + t.Errorf("expected to find '$gocrossOk = $true'; got output:\n%s", out) + } + } +} diff --git a/tool/gocross/toolchain.go b/tool/gocross/toolchain.go index e701662f5..9cf7f892b 100644 --- a/tool/gocross/toolchain.go +++ b/tool/gocross/toolchain.go @@ -60,9 +60,17 @@ func getToolchain() (toolchainDir, gorootDir string, err error) { return "", "", err } - cache := filepath.Join(os.Getenv("HOME"), ".cache") + homeDir, err := os.UserHomeDir() + if err != nil { + return "", "", err + } + + // We use ".cache" instead of os.UserCacheDir for legacy reasons and we + // don't want to break that on platforms where the latter returns a different + // result. + cache := filepath.Join(homeDir, ".cache") toolchainDir = filepath.Join(cache, "tsgo", rev) - gorootDir = filepath.Join(toolchainDir, "gocross-goroot") + gorootDir = filepath.Join(cache, "tsgoroot", rev) // You might wonder why getting the toolchain also provisions and returns a // path suitable for use as GOROOT. Wonder no longer! diff --git a/tool/listpkgs/listpkgs.go b/tool/listpkgs/listpkgs.go new file mode 100644 index 000000000..400bf90c1 --- /dev/null +++ b/tool/listpkgs/listpkgs.go @@ -0,0 +1,206 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// listpkgs prints the import paths that match the Go package patterns +// given on the command line and conditionally filters them in various ways. +package main + +import ( + "bufio" + "flag" + "fmt" + "go/build/constraint" + "log" + "os" + "slices" + "strings" + "sync" + + "golang.org/x/tools/go/packages" +) + +var ( + ignore3p = flag.Bool("ignore-3p", false, "ignore third-party packages forked/vendored into Tailscale") + goos = flag.String("goos", "", "GOOS to use for loading packages (default: current OS)") + goarch = flag.String("goarch", "", "GOARCH to use for loading packages (default: current architecture)") + withTagsAllStr = flag.String("with-tags-all", "", "if non-empty, a comma-separated list of builds tags to require (a package will only be listed if it contains all of these build tags)") + withoutTagsAnyStr = flag.String("without-tags-any", "", "if non-empty, a comma-separated list of build constraints to exclude (a package will be omitted if it contains any of these build tags)") + shard = flag.String("shard", "", "if non-empty, a string of the form 'N/M' to only print packages in shard N of M (e.g. '1/3', '2/3', '3/3/' for different thirds of the list)") +) + +func main() { + flag.Parse() + + patterns := flag.Args() + if len(patterns) == 0 { + flag.Usage() + os.Exit(1) + } + + cfg := &packages.Config{ + Mode: packages.LoadFiles, + Env: os.Environ(), + } + if *goos != "" { + cfg.Env = append(cfg.Env, "GOOS="+*goos) + } + if *goarch != "" { + cfg.Env = append(cfg.Env, "GOARCH="+*goarch) + } + + pkgs, err := packages.Load(cfg, patterns...) + if err != nil { + log.Fatalf("loading packages: %v", err) + } + + var withoutAny []string + if *withoutTagsAnyStr != "" { + withoutAny = strings.Split(*withoutTagsAnyStr, ",") + } + var withAll []string + if *withTagsAllStr != "" { + withAll = strings.Split(*withTagsAllStr, ",") + } + + seen := map[string]bool{} + matches := 0 +Pkg: + for _, pkg := range pkgs { + if pkg.PkgPath == "" { // malformed (shouldn’t happen) + continue + } + if seen[pkg.PkgPath] { + continue // suppress duplicates when patterns overlap + } + seen[pkg.PkgPath] = true + + pkgPath := pkg.PkgPath + + if *ignore3p && isThirdParty(pkgPath) { + continue + } + if withAll != nil { + for _, t := range withAll { + if !hasBuildTag(pkg, t) { + continue Pkg + } + } + } + for _, t := range withoutAny { + if hasBuildTag(pkg, t) { + continue Pkg + } + } + matches++ + + if *shard != "" { + var n, m int + if _, err := fmt.Sscanf(*shard, "%d/%d", &n, &m); err != nil || n < 1 || m < 1 { + log.Fatalf("invalid shard format %q; expected 'N/M'", *shard) + } + if m > 0 && (matches-1)%m != n-1 { + continue // not in this shard + } + } + fmt.Println(pkgPath) + } + + // If any package had errors (e.g. missing deps) report them via packages.PrintErrors. + // This mirrors `go list` behaviour when -e is *not* supplied. + if packages.PrintErrors(pkgs) > 0 { + os.Exit(1) + } +} + +func isThirdParty(pkg string) bool { + return strings.HasPrefix(pkg, "tailscale.com/tempfork/") +} + +// hasBuildTag reports whether any source file in pkg mentions `tag` +// in a //go:build constraint. +func hasBuildTag(pkg *packages.Package, tag string) bool { + all := slices.Concat(pkg.CompiledGoFiles, pkg.OtherFiles, pkg.IgnoredFiles) + suffix := "_" + tag + ".go" + for _, name := range all { + if strings.HasSuffix(name, suffix) { + return true + } + ok, err := fileMentionsTag(name, tag) + if err != nil { + log.Printf("reading %s: %v", name, err) + continue + } + if ok { + return true + } + } + return false +} + +// tagSet is a set of build tags. +// The values are always true. We avoid non-std set types +// to make this faster to "go run" on empty caches. +type tagSet map[string]bool + +var ( + mu sync.Mutex + fileTags = map[string]tagSet{} // abs path -> set of build tags mentioned in file +) + +func getFileTags(filename string) (tagSet, error) { + mu.Lock() + tags, ok := fileTags[filename] + mu.Unlock() + if ok { + return tags, nil + } + + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + ts := make(tagSet) + s := bufio.NewScanner(f) + for s.Scan() { + line := s.Text() + if strings.TrimSpace(line) == "" { + continue // still in leading blank lines + } + if !strings.HasPrefix(line, "//") { + // hit real code – done with header comments + // TODO(bradfitz): care about /* */ comments? + break + } + if !strings.HasPrefix(line, "//go:build") { + continue // some other comment + } + expr, err := constraint.Parse(line) + if err != nil { + return nil, fmt.Errorf("parsing %q: %w", line, err) + } + // Call Eval to populate ts with the tags mentioned in the expression. + // We don't care about the result, just the side effect of populating ts. + expr.Eval(func(tag string) bool { + ts[tag] = true + return true // arbitrary + }) + } + if err := s.Err(); err != nil { + return nil, fmt.Errorf("reading %s: %w", filename, err) + } + + mu.Lock() + defer mu.Unlock() + fileTags[filename] = ts + return tags, nil +} + +func fileMentionsTag(filename, tag string) (bool, error) { + tags, err := getFileTags(filename) + if err != nil { + return false, err + } + return tags[tag], nil +} diff --git a/tool/node.rev b/tool/node.rev index 17719ce25..7d41c735d 100644 --- a/tool/node.rev +++ b/tool/node.rev @@ -1 +1 @@ -18.20.4 +22.14.0 diff --git a/tsconsensus/authorization.go b/tsconsensus/authorization.go new file mode 100644 index 000000000..bd8e2f39a --- /dev/null +++ b/tsconsensus/authorization.go @@ -0,0 +1,134 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "context" + "errors" + "net/netip" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" + "tailscale.com/types/views" + "tailscale.com/util/set" +) + +type statusGetter interface { + getStatus(context.Context) (*ipnstate.Status, error) +} + +type tailscaleStatusGetter struct { + ts *tsnet.Server + + mu sync.Mutex // protects the following + lastStatus *ipnstate.Status + lastStatusTime time.Time +} + +func (sg *tailscaleStatusGetter) fetchStatus(ctx context.Context) (*ipnstate.Status, error) { + lc, err := sg.ts.LocalClient() + if err != nil { + return nil, err + } + return lc.Status(ctx) +} + +func (sg *tailscaleStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) { + sg.mu.Lock() + defer sg.mu.Unlock() + if sg.lastStatus != nil && time.Since(sg.lastStatusTime) < 1*time.Second { + return sg.lastStatus, nil + } + status, err := sg.fetchStatus(ctx) + if err != nil { + return nil, err + } + sg.lastStatus = status + sg.lastStatusTime = time.Now() + return status, nil +} + +type authorization struct { + sg statusGetter + tag string + + mu sync.Mutex + peers *peers // protected by mu +} + +func newAuthorization(ts *tsnet.Server, tag string) *authorization { + return &authorization{ + sg: &tailscaleStatusGetter{ + ts: ts, + }, + tag: tag, + } +} + +func (a *authorization) Refresh(ctx context.Context) error { + tStatus, err := a.sg.getStatus(ctx) + if err != nil { + return err + } + if tStatus == nil { + return errors.New("no status") + } + if tStatus.BackendState != ipn.Running.String() { + return errors.New("ts Server is not running") + } + a.mu.Lock() + defer a.mu.Unlock() + a.peers = newPeers(tStatus, a.tag) + return nil +} + +func (a *authorization) AllowsHost(addr netip.Addr) bool { + a.mu.Lock() + defer a.mu.Unlock() + if a.peers == nil { + return false + } + return a.peers.addrs.Contains(addr) +} + +func (a *authorization) SelfAllowed() bool { + a.mu.Lock() + defer a.mu.Unlock() + if a.peers == nil { + return false + } + return a.peers.status.Self.Tags != nil && views.SliceContains(*a.peers.status.Self.Tags, a.tag) +} + +func (a *authorization) AllowedPeers() views.Slice[*ipnstate.PeerStatus] { + a.mu.Lock() + defer a.mu.Unlock() + if a.peers == nil { + return views.Slice[*ipnstate.PeerStatus]{} + } + return views.SliceOf(a.peers.statuses) +} + +type peers struct { + status *ipnstate.Status + addrs set.Set[netip.Addr] + statuses []*ipnstate.PeerStatus +} + +func newPeers(status *ipnstate.Status, tag string) *peers { + ps := &peers{ + status: status, + addrs: set.Set[netip.Addr]{}, + } + for _, p := range status.Peer { + if p.Tags != nil && views.SliceContains(*p.Tags, tag) { + ps.statuses = append(ps.statuses, p) + ps.addrs.AddSlice(p.TailscaleIPs) + } + } + return ps +} diff --git a/tsconsensus/authorization_test.go b/tsconsensus/authorization_test.go new file mode 100644 index 000000000..e0023f4ff --- /dev/null +++ b/tsconsensus/authorization_test.go @@ -0,0 +1,230 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "context" + "fmt" + "net/netip" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +type testStatusGetter struct { + status *ipnstate.Status +} + +func (sg testStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) { + return sg.status, nil +} + +const testTag string = "tag:clusterTag" + +func makeAuthTestPeer(i int, tags views.Slice[string]) *ipnstate.PeerStatus { + return &ipnstate.PeerStatus{ + ID: tailcfg.StableNodeID(fmt.Sprintf("%d", i)), + Tags: &tags, + TailscaleIPs: []netip.Addr{ + netip.AddrFrom4([4]byte{100, 0, 0, byte(i)}), + netip.MustParseAddr(fmt.Sprintf("fd7a:115c:a1e0:0::%d", i)), + }, + } +} + +func makeAuthTestPeers(tags [][]string) []*ipnstate.PeerStatus { + peers := make([]*ipnstate.PeerStatus, len(tags)) + for i, ts := range tags { + peers[i] = makeAuthTestPeer(i, views.SliceOf(ts)) + } + return peers +} + +func authForStatus(s *ipnstate.Status) *authorization { + return &authorization{ + sg: testStatusGetter{ + status: s, + }, + tag: testTag, + } +} + +func authForPeers(self *ipnstate.PeerStatus, peers []*ipnstate.PeerStatus) *authorization { + s := &ipnstate.Status{ + BackendState: ipn.Running.String(), + Self: self, + Peer: map[key.NodePublic]*ipnstate.PeerStatus{}, + } + for _, p := range peers { + s.Peer[key.NewNode().Public()] = p + } + return authForStatus(s) +} + +func TestAuthRefreshErrorsNotRunning(t *testing.T) { + tests := []struct { + in *ipnstate.Status + expected string + }{ + { + in: nil, + expected: "no status", + }, + { + in: &ipnstate.Status{ + BackendState: "NeedsMachineAuth", + }, + expected: "ts Server is not running", + }, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + ctx := t.Context() + a := authForStatus(tt.in) + err := a.Refresh(ctx) + if err == nil { + t.Fatalf("expected err to be non-nil") + } + if err.Error() != tt.expected { + t.Fatalf("expected: %s, got: %s", tt.expected, err.Error()) + } + }) + } +} + +func TestAuthUnrefreshed(t *testing.T) { + a := authForStatus(nil) + if a.AllowsHost(netip.MustParseAddr("100.0.0.1")) { + t.Fatalf("never refreshed authorization, allowsHost: expected false, got true") + } + gotAllowedPeers := a.AllowedPeers() + if gotAllowedPeers.Len() != 0 { + t.Fatalf("never refreshed authorization, allowedPeers: expected [], got %v", gotAllowedPeers) + } + if a.SelfAllowed() != false { + t.Fatalf("never refreshed authorization, selfAllowed: expected false got true") + } +} + +func TestAuthAllowsHost(t *testing.T) { + peerTags := [][]string{ + {"woo"}, + nil, + {"woo", testTag}, + {testTag}, + } + peers := makeAuthTestPeers(peerTags) + + tests := []struct { + name string + peerStatus *ipnstate.PeerStatus + expected bool + }{ + { + name: "tagged with different tag", + peerStatus: peers[0], + expected: false, + }, + { + name: "not tagged", + peerStatus: peers[1], + expected: false, + }, + { + name: "tags includes testTag", + peerStatus: peers[2], + expected: true, + }, + { + name: "only tag is testTag", + peerStatus: peers[3], + expected: true, + }, + } + + a := authForPeers(nil, peers) + err := a.Refresh(t.Context()) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // test we get the expected result for any of the peers TailscaleIPs + for _, addr := range tt.peerStatus.TailscaleIPs { + got := a.AllowsHost(addr) + if got != tt.expected { + t.Fatalf("allowed for peer with tags: %v, expected: %t, got %t", tt.peerStatus.Tags, tt.expected, got) + } + } + }) + } +} + +func TestAuthAllowedPeers(t *testing.T) { + ctx := t.Context() + peerTags := [][]string{ + {"woo"}, + nil, + {"woo", testTag}, + {testTag}, + } + peers := makeAuthTestPeers(peerTags) + a := authForPeers(nil, peers) + err := a.Refresh(ctx) + if err != nil { + t.Fatal(err) + } + ps := a.AllowedPeers() + if ps.Len() != 2 { + t.Fatalf("expected: 2, got: %d", ps.Len()) + } + for _, i := range []int{2, 3} { + if !ps.ContainsFunc(func(p *ipnstate.PeerStatus) bool { + return p.ID == peers[i].ID + }) { + t.Fatalf("expected peers[%d] to be in AllowedPeers because it is tagged with testTag", i) + } + } +} + +func TestAuthSelfAllowed(t *testing.T) { + tests := []struct { + name string + in []string + expected bool + }{ + { + name: "self has different tag", + in: []string{"woo"}, + expected: false, + }, + { + name: "selfs tags include testTag", + in: []string{"woo", testTag}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + self := makeAuthTestPeer(0, views.SliceOf(tt.in)) + a := authForPeers(self, nil) + err := a.Refresh(ctx) + if err != nil { + t.Fatal(err) + } + got := a.SelfAllowed() + if got != tt.expected { + t.Fatalf("expected: %t, got: %t", tt.expected, got) + } + }) + } +} diff --git a/tsconsensus/bolt_store.go b/tsconsensus/bolt_store.go new file mode 100644 index 000000000..ca347cfc0 --- /dev/null +++ b/tsconsensus/bolt_store.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !loong64 + +package tsconsensus + +import ( + "github.com/hashicorp/raft" + raftboltdb "github.com/hashicorp/raft-boltdb/v2" +) + +func boltStore(path string) (raft.StableStore, raft.LogStore, error) { + store, err := raftboltdb.NewBoltStore(path) + if err != nil { + return nil, nil, err + } + return store, store, nil +} diff --git a/tsconsensus/bolt_store_no_bolt.go b/tsconsensus/bolt_store_no_bolt.go new file mode 100644 index 000000000..33b3bd6c7 --- /dev/null +++ b/tsconsensus/bolt_store_no_bolt.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build loong64 + +package tsconsensus + +import ( + "errors" + + "github.com/hashicorp/raft" +) + +func boltStore(path string) (raft.StableStore, raft.LogStore, error) { + // "github.com/hashicorp/raft-boltdb/v2" doesn't build on loong64 + // see https://github.com/hashicorp/raft-boltdb/issues/27 + return nil, nil, errors.New("not implemented") +} diff --git a/tsconsensus/http.go b/tsconsensus/http.go new file mode 100644 index 000000000..d2a44015f --- /dev/null +++ b/tsconsensus/http.go @@ -0,0 +1,182 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "time" + + "tailscale.com/util/httpm" +) + +type joinRequest struct { + RemoteHost string + RemoteID string +} + +type commandClient struct { + port uint16 + httpClient *http.Client +} + +func (rac *commandClient) url(host string, path string) string { + return fmt.Sprintf("http://%s:%d%s", host, rac.port, path) +} + +const maxBodyBytes = 1024 * 1024 + +func readAllMaxBytes(r io.Reader) ([]byte, error) { + return io.ReadAll(io.LimitReader(r, maxBodyBytes+1)) +} + +func (rac *commandClient) join(host string, jr joinRequest) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rBs, err := json.Marshal(jr) + if err != nil { + return err + } + url := rac.url(host, "/join") + req, err := http.NewRequestWithContext(ctx, httpm.POST, url, bytes.NewReader(rBs)) + if err != nil { + return err + } + resp, err := rac.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + respBs, err := readAllMaxBytes(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs)) + } + return nil +} + +func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + url := rac.url(host, "/executeCommand") + req, err := http.NewRequestWithContext(ctx, httpm.POST, url, bytes.NewReader(bs)) + if err != nil { + return CommandResult{}, err + } + resp, err := rac.httpClient.Do(req) + if err != nil { + return CommandResult{}, err + } + defer resp.Body.Close() + respBs, err := readAllMaxBytes(resp.Body) + if err != nil { + return CommandResult{}, err + } + if resp.StatusCode != 200 { + return CommandResult{}, fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs)) + } + var cr CommandResult + if err = json.Unmarshal(respBs, &cr); err != nil { + return CommandResult{}, err + } + return cr, nil +} + +type authedHandler struct { + auth *authorization + handler http.Handler +} + +func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + err := h.auth.Refresh(r.Context()) + if err != nil { + log.Printf("error authedHandler ServeHTTP refresh auth: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + a, err := addrFromServerAddress(r.RemoteAddr) + if err != nil { + log.Printf("error authedHandler ServeHTTP refresh auth: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + allowed := h.auth.AllowsHost(a) + if !allowed { + http.Error(w, "peer not allowed", http.StatusForbidden) + return + } + h.handler.ServeHTTP(w, r) +} + +func (c *Consensus) handleJoinHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes+1)) + var jr joinRequest + err := decoder.Decode(&jr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _, err = decoder.Token() + if !errors.Is(err, io.EOF) { + http.Error(w, "Request body must only contain a single JSON object", http.StatusBadRequest) + return + } + if jr.RemoteHost == "" { + http.Error(w, "Required: remoteAddr", http.StatusBadRequest) + return + } + if jr.RemoteID == "" { + http.Error(w, "Required: remoteID", http.StatusBadRequest) + return + } + err = c.handleJoin(jr) + if err != nil { + log.Printf("join handler error: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } +} + +func (c *Consensus) handleExecuteCommandHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + decoder := json.NewDecoder(r.Body) + var cmd Command + err := decoder.Decode(&cmd) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + result, err := c.executeCommandLocally(cmd) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(result); err != nil { + log.Printf("error encoding execute command result: %v", err) + return + } +} + +func (c *Consensus) makeCommandMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("POST /join", c.handleJoinHTTP) + mux.HandleFunc("POST /executeCommand", c.handleExecuteCommandHTTP) + return mux +} + +func (c *Consensus) makeCommandHandler(auth *authorization) http.Handler { + return authedHandler{ + handler: c.makeCommandMux(), + auth: auth, + } +} diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go new file mode 100644 index 000000000..c84e83454 --- /dev/null +++ b/tsconsensus/monitor.go @@ -0,0 +1,158 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "slices" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" + "tailscale.com/util/dnsname" +) + +type status struct { + Status *ipnstate.Status + RaftState string +} + +type monitor struct { + ts *tsnet.Server + con *Consensus + sg statusGetter +} + +func (m *monitor) getStatus(ctx context.Context) (status, error) { + tStatus, err := m.sg.getStatus(ctx) + if err != nil { + return status{}, err + } + return status{Status: tStatus, RaftState: m.con.raft.State().String()}, nil +} + +func serveMonitor(c *Consensus, ts *tsnet.Server, listenAddr string) (*http.Server, error) { + ln, err := ts.Listen("tcp", listenAddr) + if err != nil { + return nil, err + } + m := &monitor{con: c, ts: ts, sg: &tailscaleStatusGetter{ + ts: ts, + }} + mux := http.NewServeMux() + mux.HandleFunc("GET /full", m.handleFullStatus) + mux.HandleFunc("GET /{$}", m.handleSummaryStatus) + mux.HandleFunc("GET /netmap", m.handleNetmap) + mux.HandleFunc("POST /dial", m.handleDial) + srv := &http.Server{Handler: mux} + go func() { + err := srv.Serve(ln) + log.Printf("MonitorHTTP stopped serving with error: %v", err) + }() + return srv, nil +} + +func (m *monitor) handleFullStatus(w http.ResponseWriter, r *http.Request) { + s, err := m.getStatus(r.Context()) + if err != nil { + log.Printf("monitor: error getStatus: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(s); err != nil { + log.Printf("monitor: error encoding full status: %v", err) + return + } +} + +func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { + s, err := m.getStatus(r.Context()) + if err != nil { + log.Printf("monitor: error getStatus: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + lines := []string{} + for _, p := range s.Status.Peer { + if p.Online { + name := dnsname.FirstLabel(p.DNSName) + lines = append(lines, fmt.Sprintf("%s\t\t%d\t%d\t%t", name, p.RxBytes, p.TxBytes, p.Active)) + } + } + _, err = w.Write([]byte(fmt.Sprintf("RaftState: %s\n", s.RaftState))) + if err != nil { + log.Printf("monitor: error writing status: %v", err) + return + } + + slices.Sort(lines) + for _, ln := range lines { + _, err = w.Write([]byte(fmt.Sprintf("%s\n", ln))) + if err != nil { + log.Printf("monitor: error writing status: %v", err) + return + } + } +} + +func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) { + lc, err := m.ts.LocalClient() + if err != nil { + log.Printf("monitor: error LocalClient: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + watcher, err := lc.WatchIPNBus(r.Context(), ipn.NotifyInitialNetMap) + if err != nil { + log.Printf("monitor: error WatchIPNBus: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + defer watcher.Close() + + n, err := watcher.Next() + if err != nil { + log.Printf("monitor: error watcher.Next: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + encoder := json.NewEncoder(w) + encoder.SetIndent("", "\t") + if err := encoder.Encode(n); err != nil { + log.Printf("monitor: error encoding netmap: %v", err) + return + } +} + +func (m *monitor) handleDial(w http.ResponseWriter, r *http.Request) { + var dialParams struct { + Addr string + } + defer r.Body.Close() + bs, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) + if err != nil { + log.Printf("monitor: error reading body: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + err = json.Unmarshal(bs, &dialParams) + if err != nil { + log.Printf("monitor: error unmarshalling json: %v", err) + http.Error(w, "", http.StatusBadRequest) + return + } + c, err := m.ts.Dial(r.Context(), "tcp", dialParams.Addr) + if err != nil { + log.Printf("monitor: error dialing: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + c.Close() + w.Write([]byte("ok\n")) +} diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go new file mode 100644 index 000000000..1f7dc1b7b --- /dev/null +++ b/tsconsensus/tsconsensus.go @@ -0,0 +1,547 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tsconsensus implements a consensus algorithm for a group of tsnet.Servers +// +// The Raft consensus algorithm relies on you implementing a state machine that will give the same +// result to a given command as long as the same logs have been applied in the same order. +// +// tsconsensus uses the hashicorp/raft library to implement leader elections and log application. +// +// tsconsensus provides: +// - cluster peer discovery based on tailscale tags +// - executing a command on the leader +// - communication between cluster peers over tailscale using tsnet +// +// Users implement a state machine that satisfies the raft.FSM interface, with the business logic they desire. +// When changes to state are needed any node may +// - create a Command instance with serialized Args. +// - call ExecuteCommand with the Command instance +// this will propagate the command to the leader, +// and then from the reader to every node via raft. +// - the state machine then can implement raft.Apply, and dispatch commands via the Command.Name +// returning a CommandResult with an Err or a serialized Result. +package tsconsensus + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net" + "net/http" + "net/netip" + "path/filepath" + "time" + + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/raft" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" + "tailscale.com/types/views" +) + +func raftAddr(host netip.Addr, cfg Config) string { + return netip.AddrPortFrom(host, cfg.RaftPort).String() +} + +func addrFromServerAddress(sa string) (netip.Addr, error) { + addrPort, err := netip.ParseAddrPort(sa) + if err != nil { + return netip.Addr{}, err + } + return addrPort.Addr(), nil +} + +// A selfRaftNode is the info we need to talk to hashicorp/raft about our node. +// We specify the ID and Addr on Consensus Start, and then use it later for raft +// operations such as BootstrapCluster and AddVoter. +type selfRaftNode struct { + id string + hostAddr netip.Addr +} + +// A Config holds configurable values such as ports and timeouts. +// Use DefaultConfig to get a useful Config. +type Config struct { + CommandPort uint16 + RaftPort uint16 + MonitorPort uint16 + Raft *raft.Config + MaxConnPool int + ConnTimeout time.Duration + ServeDebugMonitor bool + StateDirPath string +} + +// DefaultConfig returns a Config populated with default values ready for use. +func DefaultConfig() Config { + raftConfig := raft.DefaultConfig() + // these values are 2x the raft DefaultConfig + raftConfig.HeartbeatTimeout = 2000 * time.Millisecond + raftConfig.ElectionTimeout = 2000 * time.Millisecond + raftConfig.LeaderLeaseTimeout = 1000 * time.Millisecond + + return Config{ + CommandPort: 6271, + RaftPort: 6270, + MonitorPort: 8081, + Raft: raftConfig, + MaxConnPool: 5, + ConnTimeout: 5 * time.Second, + } +} + +// StreamLayer implements an interface asked for by raft.NetworkTransport. +// It does the raft interprocess communication via tailscale. +type StreamLayer struct { + net.Listener + s *tsnet.Server + auth *authorization + shutdownCtx context.Context +} + +// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial. +func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(sl.shutdownCtx, timeout) + defer cancel() + authorized, err := sl.addrAuthorized(ctx, string(address)) + if err != nil { + return nil, err + } + if !authorized { + return nil, errors.New("dial: peer is not allowed") + } + return sl.s.Dial(ctx, "tcp", string(address)) +} + +func (sl StreamLayer) addrAuthorized(ctx context.Context, address string) (bool, error) { + addr, err := addrFromServerAddress(address) + if err != nil { + // bad RemoteAddr is not authorized + return false, nil + } + err = sl.auth.Refresh(ctx) + if err != nil { + // might be authorized, we couldn't tell + return false, err + } + return sl.auth.AllowsHost(addr), nil +} + +func (sl StreamLayer) Accept() (net.Conn, error) { + ctx, cancel := context.WithCancel(sl.shutdownCtx) + defer cancel() + for { + conn, err := sl.Listener.Accept() + if err != nil || conn == nil { + return conn, err + } + addr := conn.RemoteAddr() + if addr == nil { + conn.Close() + return nil, errors.New("conn has no remote addr") + } + authorized, err := sl.addrAuthorized(ctx, addr.String()) + if err != nil { + conn.Close() + return nil, err + } + if !authorized { + log.Printf("StreamLayer accept: unauthorized: %s", addr) + conn.Close() + continue + } + return conn, err + } +} + +type BootstrapOpts struct { + Tag string + FollowOnly bool +} + +// Start returns a pointer to a running Consensus instance. +// Calling it with a *tsnet.Server will cause that server to join or start a consensus cluster +// with other nodes on the tailnet tagged with the clusterTag. The *tsnet.Server will run the state +// machine defined by the raft.FSM also provided, and keep it in sync with the other cluster members' +// state machines using Raft. +func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, bootstrapOpts BootstrapOpts, cfg Config) (*Consensus, error) { + if bootstrapOpts.Tag == "" { + return nil, errors.New("cluster tag must be provided") + } + + cc := commandClient{ + port: cfg.CommandPort, + httpClient: ts.HTTPClient(), + } + v4, _ := ts.TailscaleIPs() + // TODO(fran) support tailnets that have ipv4 disabled + self := selfRaftNode{ + id: v4.String(), + hostAddr: v4, + } + shutdownCtx, shutdownCtxCancel := context.WithCancel(ctx) + c := Consensus{ + commandClient: &cc, + self: self, + config: cfg, + shutdownCtxCancel: shutdownCtxCancel, + } + + auth := newAuthorization(ts, bootstrapOpts.Tag) + err := auth.Refresh(shutdownCtx) + if err != nil { + return nil, fmt.Errorf("auth refresh: %w", err) + } + if !auth.SelfAllowed() { + return nil, errors.New("this node is not tagged with the cluster tag") + } + + srv, err := c.serveCommandHTTP(ts, auth) + if err != nil { + return nil, err + } + c.cmdHttpServer = srv + + // after startRaft it's possible some other raft node that has us in their configuration will get + // in contact, so by the time we do anything else we may already be a functioning member + // of a consensus + r, err := startRaft(shutdownCtx, ts, &fsm, c.self, auth, cfg) + if err != nil { + return nil, err + } + c.raft = r + + // we may already be in a consensus (see comment above before startRaft) but we're going to + // try to bootstrap anyway in case this is a fresh start. + err = c.bootstrap(shutdownCtx, auth, bootstrapOpts) + if err != nil { + if errors.Is(err, raft.ErrCantBootstrap) { + // Raft cluster state can be persisted, if we try to call raft.BootstrapCluster when + // we already have cluster state it will return raft.ErrCantBootstrap. It's safe to + // ignore (according to the comment in the raft code), and we can expect that the other + // nodes of the cluster will become available at some point and we can get back into the + // consensus. + log.Print("Bootstrap: raft has cluster state, waiting for peers") + } else { + return nil, err + } + } + + if cfg.ServeDebugMonitor { + srv, err = serveMonitor(&c, ts, netip.AddrPortFrom(c.self.hostAddr, cfg.MonitorPort).String()) + if err != nil { + return nil, err + } + c.monitorHttpServer = srv + } + + return &c, nil +} + +func startRaft(shutdownCtx context.Context, ts *tsnet.Server, fsm *raft.FSM, self selfRaftNode, auth *authorization, cfg Config) (*raft.Raft, error) { + cfg.Raft.LocalID = raft.ServerID(self.id) + + var logStore raft.LogStore + var stableStore raft.StableStore + var snapStore raft.SnapshotStore + + if cfg.StateDirPath == "" { + // comments in raft code say to only use for tests + logStore = raft.NewInmemStore() + stableStore = raft.NewInmemStore() + snapStore = raft.NewInmemSnapshotStore() + } else { + var err error + stableStore, logStore, err = boltStore(filepath.Join(cfg.StateDirPath, "store")) + if err != nil { + return nil, err + } + snaplogger := hclog.New(&hclog.LoggerOptions{ + Name: "raft-snap", + Output: cfg.Raft.LogOutput, + Level: hclog.LevelFromString(cfg.Raft.LogLevel), + }) + snapStore, err = raft.NewFileSnapshotStoreWithLogger(filepath.Join(cfg.StateDirPath, "snapstore"), 2, snaplogger) + if err != nil { + return nil, err + } + } + + // opens the listener on the raft port, raft will close it when it thinks it's appropriate + ln, err := ts.Listen("tcp", raftAddr(self.hostAddr, cfg)) + if err != nil { + return nil, err + } + + transportLogger := hclog.New(&hclog.LoggerOptions{ + Name: "raft-net", + Output: cfg.Raft.LogOutput, + Level: hclog.LevelFromString(cfg.Raft.LogLevel), + }) + + transport := raft.NewNetworkTransportWithLogger(StreamLayer{ + s: ts, + Listener: ln, + auth: auth, + shutdownCtx: shutdownCtx, + }, + cfg.MaxConnPool, + cfg.ConnTimeout, + transportLogger) + + return raft.NewRaft(cfg.Raft, *fsm, logStore, stableStore, snapStore, transport) +} + +// A Consensus is the consensus algorithm for a tsnet.Server +// It wraps a raft.Raft instance and performs the peer discovery +// and command execution on the leader. +type Consensus struct { + raft *raft.Raft + commandClient *commandClient + self selfRaftNode + config Config + cmdHttpServer *http.Server + monitorHttpServer *http.Server + shutdownCtxCancel context.CancelFunc +} + +func (c *Consensus) bootstrapTryToJoinAnyTarget(targets views.Slice[*ipnstate.PeerStatus]) bool { + log.Printf("Bootstrap: Trying to find cluster: num targets to try: %d", targets.Len()) + for _, p := range targets.All() { + if !p.Online { + log.Printf("Bootstrap: Trying to find cluster: tailscale reports not online: %s", p.TailscaleIPs[0]) + continue + } + log.Printf("Bootstrap: Trying to find cluster: trying %s", p.TailscaleIPs[0]) + err := c.commandClient.join(p.TailscaleIPs[0].String(), joinRequest{ + RemoteHost: c.self.hostAddr.String(), + RemoteID: c.self.id, + }) + if err != nil { + log.Printf("Bootstrap: Trying to find cluster: could not join %s: %v", p.TailscaleIPs[0], err) + continue + } + log.Printf("Bootstrap: Trying to find cluster: joined %s", p.TailscaleIPs[0]) + return true + } + return false +} + +func (c *Consensus) retryFollow(ctx context.Context, auth *authorization) bool { + waitFor := 500 * time.Millisecond + nRetries := 10 + attemptCount := 1 + for true { + log.Printf("Bootstrap: trying to follow any cluster member: attempt %v", attemptCount) + joined := c.bootstrapTryToJoinAnyTarget(auth.AllowedPeers()) + if joined || attemptCount == nRetries { + return joined + } + log.Printf("Bootstrap: Failed to follow. Retrying in %v", waitFor) + time.Sleep(waitFor) + waitFor *= 2 + attemptCount++ + auth.Refresh(ctx) + } + return false +} + +// bootstrap tries to join a raft cluster, or start one. +// +// We need to do the very first raft cluster configuration, but after that raft manages it. +// bootstrap is called at start up, and we may not currently be aware of what the cluster config might be, +// our node may already be in it. Try to join the raft cluster of all the other nodes we know about, and +// if unsuccessful, assume we are the first and try to start our own. If the FollowOnly option is set, only try +// to join, never start our own. +// +// It's possible for bootstrap to start an errant breakaway cluster if for example all nodes are having a fresh +// start, they're racing bootstrap and multiple nodes were unable to join a peer and so start their own new cluster. +// To avoid this operators should either ensure bootstrap is called for a single node first and allow it to become +// leader before starting the other nodes. Or start all but one node with the FollowOnly option. +// +// We have a list of expected cluster members already from control (the members of the tailnet with the tag) +// so we could do the initial configuration with all servers specified. +// Choose to start with just this machine in the raft configuration instead, as: +// - We want to handle machines joining after start anyway. +// - Not all tagged nodes tailscale believes are active are necessarily actually responsive right now, +// so let each node opt in when able. +func (c *Consensus) bootstrap(ctx context.Context, auth *authorization, opts BootstrapOpts) error { + if opts.FollowOnly { + joined := c.retryFollow(ctx, auth) + if !joined { + return errors.New("unable to join cluster") + } + return nil + } + + joined := c.bootstrapTryToJoinAnyTarget(auth.AllowedPeers()) + if joined { + return nil + } + log.Printf("Bootstrap: Trying to find cluster: unsuccessful, starting as leader: %s", c.self.hostAddr.String()) + f := c.raft.BootstrapCluster( + raft.Configuration{ + Servers: []raft.Server{ + { + ID: raft.ServerID(c.self.id), + Address: raft.ServerAddress(c.raftAddr(c.self.hostAddr)), + }, + }, + }) + return f.Error() +} + +// ExecuteCommand propagates a Command to be executed on the leader. Which +// uses raft to Apply it to the followers. +func (c *Consensus) ExecuteCommand(cmd Command) (CommandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return CommandResult{}, err + } + result, err := c.executeCommandLocally(cmd) + var leErr lookElsewhereError + for errors.As(err, &leErr) { + result, err = c.commandClient.executeCommand(leErr.where, b) + } + return result, err +} + +// Stop attempts to gracefully shutdown various components. +func (c *Consensus) Stop(ctx context.Context) error { + fut := c.raft.Shutdown() + err := fut.Error() + if err != nil { + log.Printf("Stop: Error in Raft Shutdown: %v", err) + } + c.shutdownCtxCancel() + err = c.cmdHttpServer.Shutdown(ctx) + if err != nil { + log.Printf("Stop: Error in command HTTP Shutdown: %v", err) + } + if c.monitorHttpServer != nil { + err = c.monitorHttpServer.Shutdown(ctx) + if err != nil { + log.Printf("Stop: Error in monitor HTTP Shutdown: %v", err) + } + } + return nil +} + +// A Command is a representation of a state machine action. +type Command struct { + // The Name can be used to dispatch the command when received. + Name string + // The Args are serialized for transport. + Args json.RawMessage +} + +// A CommandResult is a representation of the result of a state +// machine action. +type CommandResult struct { + // Err is any error that occurred on the node that tried to execute the command, + // including any error from the underlying operation and deserialization problems etc. + Err error + // Result is serialized for transport. + Result json.RawMessage +} + +type lookElsewhereError struct { + where string +} + +func (e lookElsewhereError) Error() string { + return fmt.Sprintf("not the leader, try: %s", e.where) +} + +var errLeaderUnknown = errors.New("leader unknown") + +func (c *Consensus) serveCommandHTTP(ts *tsnet.Server, auth *authorization) (*http.Server, error) { + ln, err := ts.Listen("tcp", c.commandAddr(c.self.hostAddr)) + if err != nil { + return nil, err + } + srv := &http.Server{Handler: c.makeCommandHandler(auth)} + go func() { + err := srv.Serve(ln) + log.Printf("CmdHttp stopped serving with err: %v", err) + }() + return srv, nil +} + +func (c *Consensus) getLeader() (string, error) { + raftLeaderAddr, _ := c.raft.LeaderWithID() + leaderAddr := (string)(raftLeaderAddr) + if leaderAddr == "" { + // Raft doesn't know who the leader is. + return "", errLeaderUnknown + } + // Raft gives us the address with the raft port, we don't always want that. + host, _, err := net.SplitHostPort(leaderAddr) + return host, err +} + +func (c *Consensus) executeCommandLocally(cmd Command) (CommandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return CommandResult{}, err + } + f := c.raft.Apply(b, 0) + err = f.Error() + result := f.Response() + if errors.Is(err, raft.ErrNotLeader) { + leader, err := c.getLeader() + if err != nil { + // we know we're not leader but we were unable to give the address of the leader + return CommandResult{}, err + } + return CommandResult{}, lookElsewhereError{where: leader} + } + if result == nil { + result = CommandResult{} + } + return result.(CommandResult), err +} + +func (c *Consensus) handleJoin(jr joinRequest) error { + addr, err := netip.ParseAddr(jr.RemoteHost) + if err != nil { + return err + } + remoteAddr := c.raftAddr(addr) + f := c.raft.AddVoter(raft.ServerID(jr.RemoteID), raft.ServerAddress(remoteAddr), 0, 0) + if f.Error() != nil { + return f.Error() + } + return nil +} + +func (c *Consensus) raftAddr(host netip.Addr) string { + return raftAddr(host, c.config) +} + +func (c *Consensus) commandAddr(host netip.Addr) string { + return netip.AddrPortFrom(host, c.config.CommandPort).String() +} + +// GetClusterConfiguration returns the result of the underlying raft instance's GetConfiguration +func (c *Consensus) GetClusterConfiguration() (raft.Configuration, error) { + fut := c.raft.GetConfiguration() + err := fut.Error() + if err != nil { + return raft.Configuration{}, err + } + return fut.Configuration(), nil +} + +// DeleteClusterServer returns the result of the underlying raft instance's RemoveServer +func (c *Consensus) DeleteClusterServer(id raft.ServerID) (uint64, error) { + fut := c.raft.RemoveServer(id, 0, 1*time.Second) + err := fut.Error() + if err != nil { + return 0, err + } + return fut.Index(), nil +} diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go new file mode 100644 index 000000000..7f89eb48a --- /dev/null +++ b/tsconsensus/tsconsensus_test.go @@ -0,0 +1,771 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" + "runtime" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/raft" + "tailscale.com/client/tailscale" + "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/ipn/store/mem" + "tailscale.com/net/netns" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/nettest" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/cibuild" + "tailscale.com/util/racebuild" +) + +type fsm struct { + mu sync.Mutex + applyEvents []string +} + +func commandWith(t *testing.T, s string) []byte { + jsonArgs, err := json.Marshal(s) + if err != nil { + t.Fatal(err) + } + bs, err := json.Marshal(Command{ + Args: jsonArgs, + }) + if err != nil { + t.Fatal(err) + } + return bs +} + +func fromCommand(bs []byte) (string, error) { + var cmd Command + err := json.Unmarshal(bs, &cmd) + if err != nil { + return "", err + } + var args string + err = json.Unmarshal(cmd.Args, &args) + if err != nil { + return "", err + } + return args, nil +} + +func (f *fsm) Apply(lg *raft.Log) any { + f.mu.Lock() + defer f.mu.Unlock() + s, err := fromCommand(lg.Data) + if err != nil { + return CommandResult{ + Err: err, + } + } + f.applyEvents = append(f.applyEvents, s) + result, err := json.Marshal(len(f.applyEvents)) + if err != nil { + panic("should be able to Marshal that?") + } + return CommandResult{ + Result: result, + } +} + +func (f *fsm) numEvents() int { + f.mu.Lock() + defer f.mu.Unlock() + return len(f.applyEvents) +} + +func (f *fsm) eventsMatch(es []string) bool { + f.mu.Lock() + defer f.mu.Unlock() + return cmp.Equal(es, f.applyEvents) +} + +func (f *fsm) Snapshot() (raft.FSMSnapshot, error) { + return nil, nil +} + +func (f *fsm) Restore(rc io.ReadCloser) error { + return nil +} + +func testConfig(t *testing.T) { + if runtime.GOOS == "windows" && cibuild.On() { + t.Skip("cmd/natc isn't supported on Windows, so skipping tsconsensus tests on CI for now; see https://github.com/tailscale/tailscale/issues/16340") + } + // -race AND Parallel makes things start to take too long. + if !racebuild.On { + t.Parallel() + } + nettest.SkipIfNoNetwork(t) +} + +func startControl(t testing.TB) (control *testcontrol.Server, controlURL string) { + t.Helper() + // tailscale/corp#4520: don't use netns for tests. + netns.SetEnabled(false) + t.Cleanup(func() { + netns.SetEnabled(true) + }) + + derpLogf := logger.Discard + derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") + control = &testcontrol.Server{ + DERPMap: derpMap, + DNSConfig: &tailcfg.DNSConfig{ + Proxied: true, + }, + MagicDNSDomain: "tail-scale.ts.net", + } + control.HTTPTestServer = httptest.NewUnstartedServer(control) + control.HTTPTestServer.Start() + t.Cleanup(control.HTTPTestServer.Close) + controlURL = control.HTTPTestServer.URL + t.Logf("testcontrol listening on %s", controlURL) + return control, controlURL +} + +func startNode(t testing.TB, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) { + t.Helper() + + tmp := filepath.Join(t.TempDir(), hostname) + os.MkdirAll(tmp, 0755) + s := &tsnet.Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: hostname, + Store: new(mem.Store), + Ephemeral: true, + } + t.Cleanup(func() { s.Close() }) + + status, err := s.Up(ctx) + if err != nil { + t.Fatal(err) + } + return s, status.Self.PublicKey, status.TailscaleIPs[0] +} + +func waitForNodesToBeTaggedInStatus(t testing.TB, ctx context.Context, ts *tsnet.Server, nodeKeys []key.NodePublic, tag string) { + t.Helper() + waitFor(t, "nodes tagged in status", func() bool { + lc, err := ts.LocalClient() + if err != nil { + t.Fatal(err) + } + status, err := lc.Status(ctx) + if err != nil { + t.Fatalf("error getting status: %v", err) + } + for _, k := range nodeKeys { + var tags *views.Slice[string] + if k == status.Self.PublicKey { + tags = status.Self.Tags + } else { + tags = status.Peer[k].Tags + } + if tag == "" { + if tags != nil && tags.Len() != 0 { + return false + } + } else { + if tags == nil { + return false + } + if tags.Len() != 1 || tags.At(0) != tag { + return false + } + } + } + return true + }, 2*time.Second) +} + +func tagNodes(t testing.TB, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) { + t.Helper() + for _, key := range nodeKeys { + n := control.Node(key) + if tag == "" { + if len(n.Tags) != 1 { + t.Fatalf("expected tags to have one tag") + } + n.Tags = nil + } else { + if len(n.Tags) != 0 { + // if we want this to work with multiple tags we'll have to change the logic + // for checking if a tag got removed yet. + t.Fatalf("expected tags to be empty") + } + n.Tags = append(n.Tags, tag) + } + b := true + n.Online = &b + control.UpdateNode(n) + } +} + +func addIDedLogger(id string, c Config) Config { + // logs that identify themselves + c.Raft.Logger = hclog.New(&hclog.LoggerOptions{ + Name: fmt.Sprintf("raft: %s", id), + Output: c.Raft.LogOutput, + Level: hclog.LevelFromString(c.Raft.LogLevel), + }) + return c +} + +func warnLogConfig() Config { + c := DefaultConfig() + // fewer logs from raft + c.Raft.LogLevel = "WARN" + // timeouts long enough that we can form a cluster under -race + c.Raft.LeaderLeaseTimeout = 2 * time.Second + c.Raft.HeartbeatTimeout = 4 * time.Second + c.Raft.ElectionTimeout = 4 * time.Second + return c +} + +func TestStart(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + control, controlURL := startControl(t) + ctx := context.Background() + one, k, _ := startNode(t, ctx, controlURL, "one") + + clusterTag := "tag:whatever" + // nodes must be tagged with the cluster tag, to find each other + tagNodes(t, control, []key.NodePublic{k}, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, one, []key.NodePublic{k}, clusterTag) + + sm := &fsm{} + r, err := Start(ctx, one, sm, BootstrapOpts{Tag: clusterTag}, warnLogConfig()) + if err != nil { + t.Fatal(err) + } + defer r.Stop(ctx) +} + +func waitFor(t testing.TB, msg string, condition func() bool, waitBetweenTries time.Duration) { + t.Helper() + try := 0 + for true { + try++ + done := condition() + if done { + t.Logf("waitFor success: %s: after %d tries", msg, try) + return + } + time.Sleep(waitBetweenTries) + } +} + +type participant struct { + c *Consensus + sm *fsm + ts *tsnet.Server + key key.NodePublic +} + +// starts and tags the *tsnet.Server nodes with the control, waits for the nodes to make successful +// LocalClient Status calls that show the first node as Online. +func startNodesAndWaitForPeerStatus(t testing.TB, ctx context.Context, clusterTag string, nNodes int) ([]*participant, *testcontrol.Server, string) { + t.Helper() + ps := make([]*participant, nNodes) + keysToTag := make([]key.NodePublic, nNodes) + localClients := make([]*tailscale.LocalClient, nNodes) + control, controlURL := startControl(t) + for i := 0; i < nNodes; i++ { + ts, key, _ := startNode(t, ctx, controlURL, fmt.Sprintf("node %d", i)) + ps[i] = &participant{ts: ts, key: key} + keysToTag[i] = key + lc, err := ts.LocalClient() + if err != nil { + t.Fatalf("%d: error getting local client: %v", i, err) + } + localClients[i] = lc + } + tagNodes(t, control, keysToTag, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, keysToTag, clusterTag) + fxCameOnline := func() bool { + // all the _other_ nodes see the first as online + for i := 1; i < nNodes; i++ { + status, err := localClients[i].Status(ctx) + if err != nil { + t.Fatalf("%d: error getting status: %v", i, err) + } + if !status.Peer[ps[0].key].Online { + return false + } + } + return true + } + waitFor(t, "other nodes see node 1 online in ts status", fxCameOnline, 2*time.Second) + return ps, control, controlURL +} + +// populates participants with their consensus fields, waits for all nodes to show all nodes +// as part of the same consensus cluster. Starts the first participant first and waits for it to +// become leader before adding other nodes. +func createConsensusCluster(t testing.TB, ctx context.Context, clusterTag string, participants []*participant, cfg Config) { + t.Helper() + participants[0].sm = &fsm{} + myCfg := addIDedLogger("0", cfg) + first, err := Start(ctx, participants[0].ts, participants[0].sm, BootstrapOpts{Tag: clusterTag}, myCfg) + if err != nil { + t.Fatal(err) + } + fxFirstIsLeader := func() bool { + return first.raft.State() == raft.Leader + } + waitFor(t, "node 0 is leader", fxFirstIsLeader, 2*time.Second) + participants[0].c = first + + for i := 1; i < len(participants); i++ { + participants[i].sm = &fsm{} + myCfg := addIDedLogger(fmt.Sprintf("%d", i), cfg) + c, err := Start(ctx, participants[i].ts, participants[i].sm, BootstrapOpts{Tag: clusterTag}, myCfg) + if err != nil { + t.Fatal(err) + } + participants[i].c = c + } + + fxRaftConfigContainsAll := func() bool { + for i := 0; i < len(participants); i++ { + fut := participants[i].c.raft.GetConfiguration() + err = fut.Error() + if err != nil { + t.Fatalf("%d: Getting Configuration errored: %v", i, err) + } + if len(fut.Configuration().Servers) != len(participants) { + return false + } + } + return true + } + waitFor(t, "all raft machines have all servers in their config", fxRaftConfigContainsAll, time.Second*2) +} + +func TestApply(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 2) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + fut := ps[0].c.raft.Apply(commandWith(t, "woo"), 2*time.Second) + err := fut.Error() + if err != nil { + t.Fatalf("Raft Apply Error: %v", err) + } + + want := []string{"woo"} + fxBothMachinesHaveTheApply := func() bool { + return ps[0].sm.eventsMatch(want) && ps[1].sm.eventsMatch(want) + } + waitFor(t, "the apply event made it into both state machines", fxBothMachinesHaveTheApply, time.Second*1) +} + +// calls ExecuteCommand on each participant and checks that all participants get all commands +func assertCommandsWorkOnAnyNode(t testing.TB, participants []*participant) { + t.Helper() + want := []string{} + for i, p := range participants { + si := fmt.Sprintf("%d", i) + want = append(want, si) + bs, err := json.Marshal(si) + if err != nil { + t.Fatal(err) + } + res, err := p.c.ExecuteCommand(Command{Args: bs}) + if err != nil { + t.Fatalf("%d: Error ExecuteCommand: %v", i, err) + } + if res.Err != nil { + t.Fatalf("%d: Result Error ExecuteCommand: %v", i, res.Err) + } + var retVal int + err = json.Unmarshal(res.Result, &retVal) + if err != nil { + t.Fatal(err) + } + // the test implementation of the fsm returns the count of events that have been received + if retVal != i+1 { + t.Fatalf("Result, want %d, got %d", i+1, retVal) + } + + fxEventsInAll := func() bool { + for _, pOther := range participants { + if !pOther.sm.eventsMatch(want) { + return false + } + } + return true + } + waitFor(t, "event makes it to all", fxEventsInAll, time.Second*1) + } +} + +func TestConfig(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + // test all is well with non default ports + cfg.CommandPort = 12347 + cfg.RaftPort = 11882 + mp := uint16(8798) + cfg.MonitorPort = mp + cfg.ServeDebugMonitor = true + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + assertCommandsWorkOnAnyNode(t, ps) + + url := fmt.Sprintf("http://%s:%d/", ps[0].c.self.hostAddr.String(), mp) + httpClientOnTailnet := ps[1].ts.HTTPClient() + rsp, err := httpClientOnTailnet.Get(url) + if err != nil { + t.Fatal(err) + } + if rsp.StatusCode != 200 { + t.Fatalf("monitor status want %d, got %d", 200, rsp.StatusCode) + } + defer rsp.Body.Close() + reader := bufio.NewReader(rsp.Body) + line1, err := reader.ReadString('\n') + if err != nil { + t.Fatal(err) + } + // Not a great assertion because it relies on the format of the response. + if !strings.HasPrefix(line1, "RaftState:") { + t.Fatalf("getting monitor status, first line, want something that starts with 'RaftState:', got '%s'", line1) + } +} + +func TestFollowerFailover(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + smThree := ps[2].sm + + fut := ps[0].c.raft.Apply(commandWith(t, "a"), 2*time.Second) + futTwo := ps[0].c.raft.Apply(commandWith(t, "b"), 2*time.Second) + err := fut.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + err = futTwo.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + + wantFirstTwoEvents := []string{"a", "b"} + fxAllMachinesHaveTheApplies := func() bool { + return ps[0].sm.eventsMatch(wantFirstTwoEvents) && + ps[1].sm.eventsMatch(wantFirstTwoEvents) && + smThree.eventsMatch(wantFirstTwoEvents) + } + waitFor(t, "the apply events made it into all state machines", fxAllMachinesHaveTheApplies, time.Second*1) + + //a follower goes loses contact with the cluster + ps[2].c.Stop(ctx) + + // applies still make it to one and two + futThree := ps[0].c.raft.Apply(commandWith(t, "c"), 2*time.Second) + futFour := ps[0].c.raft.Apply(commandWith(t, "d"), 2*time.Second) + err = futThree.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + err = futFour.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + wantFourEvents := []string{"a", "b", "c", "d"} + fxAliveMachinesHaveTheApplies := func() bool { + return ps[0].sm.eventsMatch(wantFourEvents) && + ps[1].sm.eventsMatch(wantFourEvents) && + smThree.eventsMatch(wantFirstTwoEvents) + } + waitFor(t, "the apply events made it into eligible state machines", fxAliveMachinesHaveTheApplies, time.Second*1) + + // follower comes back + smThreeAgain := &fsm{} + cfg = addIDedLogger("2 after restarting", warnLogConfig()) + rThreeAgain, err := Start(ctx, ps[2].ts, smThreeAgain, BootstrapOpts{Tag: clusterTag}, cfg) + if err != nil { + t.Fatal(err) + } + defer rThreeAgain.Stop(ctx) + fxThreeGetsCaughtUp := func() bool { + return smThreeAgain.eventsMatch(wantFourEvents) + } + waitFor(t, "the apply events made it into the third node when it appeared with an empty state machine", fxThreeGetsCaughtUp, time.Second*2) + if !smThree.eventsMatch(wantFirstTwoEvents) { + t.Fatalf("Expected smThree to remain on 2 events: got %d", smThree.numEvents()) + } +} + +func TestRejoin(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + // 1st node gets a redundant second join request from the second node + ps[0].c.handleJoin(joinRequest{ + RemoteHost: ps[1].c.self.hostAddr.String(), + RemoteID: ps[1].c.self.id, + }) + + tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node joiner") + tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{keyJoiner}, clusterTag) + smJoiner := &fsm{} + cJoiner, err := Start(ctx, tsJoiner, smJoiner, BootstrapOpts{Tag: clusterTag}, cfg) + if err != nil { + t.Fatal(err) + } + ps = append(ps, &participant{ + sm: smJoiner, + c: cJoiner, + ts: tsJoiner, + key: keyJoiner, + }) + + assertCommandsWorkOnAnyNode(t, ps) +} + +func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + assertCommandsWorkOnAnyNode(t, ps) + + untaggedNode, _, _ := startNode(t, ctx, controlURL, "untagged node") + + taggedNode, taggedKey, _ := startNode(t, ctx, controlURL, "untagged node") + tagNodes(t, control, []key.NodePublic{taggedKey}, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{taggedKey}, clusterTag) + + // surface area: command http, peer tcp + //untagged + ipv4, _ := ps[0].ts.TailscaleIPs() + sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort) + + getErrorFromTryingToSend := func(s *tsnet.Server) error { + ctx := context.Background() + conn, err := s.Dial(ctx, "tcp", sAddr) + if err != nil { + t.Fatalf("unexpected Dial err: %v", err) + } + fmt.Fprintf(conn, "hellllllloooooo") + status, err := bufio.NewReader(conn).ReadString('\n') + if status != "" { + t.Fatalf("node sending non-raft message should get empty response, got: '%s' for: %s", status, s.Hostname) + } + if err == nil { + t.Fatalf("node sending non-raft message should get an error but got nil err for: %s", s.Hostname) + } + return err + } + + isNetErr := func(err error) bool { + var netErr net.Error + return errors.As(err, &netErr) + } + + err := getErrorFromTryingToSend(untaggedNode) + if !isNetErr(err) { + t.Fatalf("untagged node trying to send should get a net.Error, got: %v", err) + } + // we still get an error trying to send but it's EOF the target node was happy to talk + // to us but couldn't understand what we said. + err = getErrorFromTryingToSend(taggedNode) + if isNetErr(err) { + t.Fatalf("tagged node trying to send should not get a net.Error, got: %v", err) + } +} + +func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, control, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + + // make a StreamLayer for ps[0] + ts := ps[0].ts + auth := newAuthorization(ts, clusterTag) + + port := 19841 + lns := make([]net.Listener, 3) + for i, p := range ps { + ln, err := p.ts.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + t.Fatal(err) + } + lns[i] = ln + } + + sl := StreamLayer{ + s: ts, + Listener: lns[0], + auth: auth, + shutdownCtx: ctx, + } + + ip1, _ := ps[1].ts.TailscaleIPs() + a1 := raft.ServerAddress(fmt.Sprintf("%s:%d", ip1, port)) + + ip2, _ := ps[2].ts.TailscaleIPs() + a2 := raft.ServerAddress(fmt.Sprintf("%s:%d", ip2, port)) + + // both can be dialed... + conn, err := sl.Dial(a1, 2*time.Second) + if err != nil { + t.Fatal(err) + } + conn.Close() + + conn, err = sl.Dial(a2, 2*time.Second) + if err != nil { + t.Fatal(err) + } + conn.Close() + + // untag ps[2] + tagNodes(t, control, []key.NodePublic{ps[2].key}, "") + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{ps[2].key}, "") + + // now only ps[1] can be dialed + conn, err = sl.Dial(a1, 2*time.Second) + if err != nil { + t.Fatal(err) + } + conn.Close() + + _, err = sl.Dial(a2, 2*time.Second) + if err.Error() != "dial: peer is not allowed" { + t.Fatalf("expected dial: peer is not allowed, got: %v", err) + } + +} + +func TestOnlyTaggedPeersCanJoin(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + tsJoiner, _, _ := startNode(t, ctx, controlURL, "joiner node") + + ipv4, _ := tsJoiner.TailscaleIPs() + url := fmt.Sprintf("http://%s/join", ps[0].c.commandAddr(ps[0].c.self.hostAddr)) + payload, err := json.Marshal(joinRequest{ + RemoteHost: ipv4.String(), + RemoteID: "node joiner", + }) + if err != nil { + t.Fatal(err) + } + body := bytes.NewBuffer(payload) + req, err := http.NewRequest("POST", url, body) + if err != nil { + t.Fatal(err) + } + resp, err := tsJoiner.HTTPClient().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusForbidden, resp.StatusCode) + } + rBody, _ := io.ReadAll(resp.Body) + sBody := strings.TrimSpace(string(rBody)) + expected := "peer not allowed" + if sBody != expected { + t.Fatalf("join req when not tagged, expected body: %s, got: %s", expected, sBody) + } +} + +func TestFollowOnly(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + + // start the leader + _, err := Start(ctx, ps[0].ts, ps[0].sm, BootstrapOpts{Tag: clusterTag}, cfg) + if err != nil { + t.Fatal(err) + } + + // start the follower with FollowOnly + _, err = Start(ctx, ps[1].ts, ps[1].sm, BootstrapOpts{Tag: clusterTag, FollowOnly: true}, cfg) + if err != nil { + t.Fatal(err) + } +} diff --git a/tsconst/health.go b/tsconst/health.go new file mode 100644 index 000000000..5db9b1fc2 --- /dev/null +++ b/tsconst/health.go @@ -0,0 +1,26 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconst + +const ( + HealthWarnableUpdateAvailable = "update-available" + HealthWarnableSecurityUpdateAvailable = "security-update-available" + HealthWarnableIsUsingUnstableVersion = "is-using-unstable-version" + HealthWarnableNetworkStatus = "network-status" + HealthWarnableWantRunningFalse = "wantrunning-false" + HealthWarnableLocalLogConfigError = "local-log-config-error" + HealthWarnableLoginState = "login-state" + HealthWarnableNotInMapPoll = "not-in-map-poll" + HealthWarnableNoDERPHome = "no-derp-home" + HealthWarnableNoDERPConnection = "no-derp-connection" + HealthWarnableDERPTimedOut = "derp-timed-out" + HealthWarnableDERPRegionError = "derp-region-error" + HealthWarnableNoUDP4Bind = "no-udp4-bind" + HealthWarnableMapResponseTimeout = "mapresponse-timeout" + HealthWarnableTLSConnectionFailed = "tls-connection-failed" + HealthWarnableMagicsockReceiveFuncError = "magicsock-receive-func-error" + HealthWarnableTestWarnable = "test-warnable" + HealthWarnableApplyDiskConfig = "apply-disk-config" + HealthWarnableWarmingUp = "warming-up" +) diff --git a/tsconst/linuxfw.go b/tsconst/linuxfw.go new file mode 100644 index 000000000..ce571e402 --- /dev/null +++ b/tsconst/linuxfw.go @@ -0,0 +1,43 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconst + +// Linux firewall constants used by Tailscale. + +// The following bits are added to packet marks for Tailscale use. +// +// We tried to pick bits sufficiently out of the way that it's +// unlikely to collide with existing uses. We have 4 bytes of mark +// bits to play with. We leave the lower byte alone on the assumption +// that sysadmins would use those. Kubernetes uses a few bits in the +// second byte, so we steer clear of that too. +// +// Empirically, most of the documentation on packet marks on the +// internet gives the impression that the marks are 16 bits +// wide. Based on this, we theorize that the upper two bytes are +// relatively unused in the wild, and so we consume bits 16:23 (the +// third byte). +// +// The constants are in the iptables/iproute2 string format for +// matching and setting the bits, so they can be directly embedded in +// commands. +const ( + // The mask for reading/writing the 'firewall mask' bits on a packet. + // See the comment on the const block on why we only use the third byte. + // + // We claim bits 16:23 entirely. For now we only use the lower four + // bits, leaving the higher 4 bits for future use. + LinuxFwmarkMask = "0xff0000" + LinuxFwmarkMaskNum = 0xff0000 + + // Packet is from Tailscale and to a subnet route destination, so + // is allowed to be routed through this machine. + LinuxSubnetRouteMark = "0x40000" + LinuxSubnetRouteMarkNum = 0x40000 + + // Packet was originated by tailscaled itself, and must not be + // routed over the Tailscale network. + LinuxBypassMark = "0x80000" + LinuxBypassMarkNum = 0x80000 +) diff --git a/tsconst/interface.go b/tsconst/tsconst.go similarity index 100% rename from tsconst/interface.go rename to tsconst/tsconst.go diff --git a/tsconst/webclient.go b/tsconst/webclient.go new file mode 100644 index 000000000..d4b3c8db5 --- /dev/null +++ b/tsconst/webclient.go @@ -0,0 +1,9 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconst + +// WebListenPort is the static port used for the web client when run inside +// tailscaled. (5252 are the numbers above the letters "TSTS" on a qwerty +// keyboard.) +const WebListenPort = 5252 diff --git a/tsd/tsd.go b/tsd/tsd.go index acd09560c..8223254da 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -32,6 +32,8 @@ import ( "tailscale.com/net/tstun" "tailscale.com/proxymap" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" + "tailscale.com/util/syspolicy/policyclient" "tailscale.com/util/usermetric" "tailscale.com/wgengine" "tailscale.com/wgengine/magicsock" @@ -39,7 +41,12 @@ import ( ) // System contains all the subsystems of a Tailscale node (tailscaled, etc.) +// +// A valid System value must always have a non-nil Bus populated. Callers must +// ensure this before using the value further. Call [NewSystem] to obtain a +// value ready to use. type System struct { + Bus SubSystem[*eventbus.Bus] Dialer SubSystem[*tsdial.Dialer] DNSManager SubSystem[*dns.Manager] // can get its *resolver.Resolver from DNSManager.Resolver Engine SubSystem[wgengine.Engine] @@ -52,6 +59,8 @@ type System struct { Netstack SubSystem[NetstackImpl] // actually a *netstack.Impl DriveForLocal SubSystem[drive.FileSystemForLocal] DriveForRemote SubSystem[drive.FileSystemForRemote] + PolicyClient SubSystem[policyclient.Client] + HealthTracker SubSystem[*health.Tracker] // InitialConfig is initial server config, if any. // It is nil if the node is not in declarative mode. @@ -66,14 +75,37 @@ type System struct { controlKnobs controlknobs.Knobs proxyMap proxymap.Mapper - healthTracker health.Tracker userMetricsRegistry usermetric.Registry } +// NewSystem constructs a new otherwise-empty [System] with a +// freshly-constructed event bus populated. +func NewSystem() *System { return NewSystemWithBus(eventbus.New()) } + +// NewSystemWithBus constructs a new otherwise-empty [System] with an +// eventbus provided by the caller. The provided bus must not be nil. +// This is mainly intended for testing; for production use call [NewBus]. +func NewSystemWithBus(bus *eventbus.Bus) *System { + if bus == nil { + panic("nil eventbus") + } + sys := new(System) + sys.Set(bus) + + tracker := health.NewTracker(bus) + sys.Set(tracker) + + return sys +} + +// LocalBackend is a fake name for *ipnlocal.LocalBackend to avoid an import cycle. +type LocalBackend = any + // NetstackImpl is the interface that *netstack.Impl implements. // It's an interface for circular dependency reasons: netstack.Impl // references LocalBackend, and LocalBackend has a tsd.System. type NetstackImpl interface { + Start(LocalBackend) error UpdateNetstackIPs(*netmap.NetworkMap) } @@ -82,6 +114,8 @@ type NetstackImpl interface { // has already been set. func (s *System) Set(v any) { switch v := v.(type) { + case *eventbus.Bus: + s.Bus.Set(v) case *netmon.Monitor: s.NetMon.Set(v) case *dns.Manager: @@ -110,6 +144,10 @@ func (s *System) Set(v any) { s.DriveForLocal.Set(v) case drive.FileSystemForRemote: s.DriveForRemote.Set(v) + case policyclient.Client: + s.PolicyClient.Set(v) + case *health.Tracker: + s.HealthTracker.Set(v) default: panic(fmt.Sprintf("unknown type %T", v)) } @@ -139,16 +177,20 @@ func (s *System) ProxyMapper() *proxymap.Mapper { return &s.proxyMap } -// HealthTracker returns the system health tracker. -func (s *System) HealthTracker() *health.Tracker { - return &s.healthTracker -} - // UserMetricsRegistry returns the system usermetrics. func (s *System) UserMetricsRegistry() *usermetric.Registry { return &s.userMetricsRegistry } +// PolicyClientOrDefault returns the policy client if set or a no-op default +// otherwise. It always returns a non-nil value. +func (s *System) PolicyClientOrDefault() policyclient.Client { + if client, ok := s.PolicyClient.GetOK(); ok { + return client + } + return policyclient.Get() +} + // SubSystem represents some subsystem of the Tailscale node daemon. // // A subsystem can be set to a value, and then later retrieved. A subsystem diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt new file mode 100644 index 000000000..7d5ec0a60 --- /dev/null +++ b/tsnet/depaware.txt @@ -0,0 +1,576 @@ +tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) + + filippo.io/edwards25519 from github.com/hdevalence/ed25519consensus + filippo.io/edwards25519/field from filippo.io/edwards25519 + W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ + W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate + W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy + LDW github.com/coder/websocket from tailscale.com/util/eventbus + LDW github.com/coder/websocket/internal/errd from github.com/coder/websocket + LDW github.com/coder/websocket/internal/util from github.com/coder/websocket + LDW github.com/coder/websocket/internal/xsync from github.com/coder/websocket + github.com/creachadair/msync/trigger from tailscale.com/logtail + W đŸ’Ŗ github.com/dblohm7/wingoes from tailscale.com/net/tshttpproxy+ + W đŸ’Ŗ github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ + W đŸ’Ŗ github.com/dblohm7/wingoes/com/automation from tailscale.com/util/osdiag/internal/wsc + W github.com/dblohm7/wingoes/internal from github.com/dblohm7/wingoes/com + W đŸ’Ŗ github.com/dblohm7/wingoes/pe from tailscale.com/util/osdiag+ + github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart + github.com/go-json-experiment/json from tailscale.com/types/opt+ + github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns + github.com/golang/groupcache/lru from tailscale.com/net/dnscache + github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + DI github.com/google/uuid from github.com/prometheus-community/pro-bing + github.com/hdevalence/ed25519consensus from tailscale.com/tka + L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon + L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink + github.com/klauspost/compress from github.com/klauspost/compress/zstd + github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 + github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd + github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe + github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd + L đŸ’Ŗ github.com/mdlayher/netlink from github.com/jsimonetti/rtnetlink+ + L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + LA đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ + LDW đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket + github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal + DI github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack + L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/net/netkernelconf + W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket + W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio + W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio + W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs + W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ + github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ + github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper + github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ + github.com/tailscale/goupnp/scpd from github.com/tailscale/goupnp + github.com/tailscale/goupnp/soap from github.com/tailscale/goupnp+ + github.com/tailscale/goupnp/ssdp from github.com/tailscale/goupnp + LDW github.com/tailscale/hujson from tailscale.com/ipn/conffile + LDAI github.com/tailscale/peercred from tailscale.com/ipn/ipnauth + LDW github.com/tailscale/web-client-prebuilt from tailscale.com/client/web + đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ + W đŸ’Ŗ github.com/tailscale/wireguard-go/conn/winrio from github.com/tailscale/wireguard-go/conn + đŸ’Ŗ github.com/tailscale/wireguard-go/device from tailscale.com/net/tstun+ + đŸ’Ŗ github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device + W đŸ’Ŗ github.com/tailscale/wireguard-go/ipc/namedpipe from github.com/tailscale/wireguard-go/ipc + github.com/tailscale/wireguard-go/ratelimiter from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/replay from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ + github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device + đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ + github.com/x448/float16 from github.com/fxamacker/cbor/v2 + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ + go4.org/netipx from tailscale.com/ipn/ipnlocal+ + W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun + W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/dns+ + gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ + gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer + đŸ’Ŗ gvisor.dev/gvisor/pkg/buffer from gvisor.dev/gvisor/pkg/tcpip+ + gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs + đŸ’Ŗ gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/state/wire+ + gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log + gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context+ + gvisor.dev/gvisor/pkg/rand from gvisor.dev/gvisor/pkg/tcpip+ + gvisor.dev/gvisor/pkg/refs from gvisor.dev/gvisor/pkg/buffer+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/sleep from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + đŸ’Ŗ gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/atomicbitops+ + gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state + đŸ’Ŗ gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/atomicbitops+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/sync/locking from gvisor.dev/gvisor/pkg/tcpip/stack + gvisor.dev/gvisor/pkg/tcpip from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/adapters/gonet from tailscale.com/wgengine/netstack + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/checksum from gvisor.dev/gvisor/pkg/buffer+ + gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ + gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 + gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + LDWA gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack/gro + gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop from gvisor.dev/gvisor/pkg/tcpip/transport/raw + gvisor.dev/gvisor/pkg/tcpip/transport/packet from gvisor.dev/gvisor/pkg/tcpip/transport/raw + gvisor.dev/gvisor/pkg/tcpip/transport/raw from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/transport/tcp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack from gvisor.dev/gvisor/pkg/tcpip/stack + gvisor.dev/gvisor/pkg/tcpip/transport/udp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ + tailscale.com from tailscale.com/version + tailscale.com/appc from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/local from tailscale.com/client/web+ + tailscale.com/client/tailscale from tailscale.com/internal/client/tailscale + tailscale.com/client/tailscale/apitype from tailscale.com/client/local+ + LDW tailscale.com/client/web from tailscale.com/ipn/ipnlocal + tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ + tailscale.com/control/controlclient from tailscale.com/ipn/ipnext+ + tailscale.com/control/controlhttp from tailscale.com/control/ts2021 + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp + tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/control/ts2021 from tailscale.com/control/controlclient + tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp/derphttp+ + tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+ + tailscale.com/disco from tailscale.com/net/tstun+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/ipn/ipnext+ + tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ + tailscale.com/feature/c2n from tailscale.com/tsnet + tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock + tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet + tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet + tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet + tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey + tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper + tailscale.com/feature/syspolicy from tailscale.com/logpolicy + tailscale.com/feature/useproxy from tailscale.com/feature/condregister/useproxy + tailscale.com/health from tailscale.com/control/controlclient+ + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal+ + tailscale.com/hostinfo from tailscale.com/client/web+ + tailscale.com/internal/client/tailscale from tailscale.com/tsnet+ + tailscale.com/ipn from tailscale.com/client/local+ + tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ + đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ + tailscale.com/ipn/localapi from tailscale.com/tsnet + tailscale.com/ipn/store from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ + tailscale.com/kube/kubetypes from tailscale.com/envknob + LDW tailscale.com/licenses from tailscale.com/client/web + tailscale.com/log/filelogger from tailscale.com/logpolicy + tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal + tailscale.com/logpolicy from tailscale.com/ipn/ipnlocal+ + tailscale.com/logtail from tailscale.com/control/controlclient+ + tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ + tailscale.com/metrics from tailscale.com/tsweb+ + tailscale.com/net/bakedroots from tailscale.com/ipn/ipnlocal+ + đŸ’Ŗ tailscale.com/net/batching from tailscale.com/wgengine/magicsock + tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ + tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ + tailscale.com/net/dnsfallback from tailscale.com/control/controlclient+ + tailscale.com/net/flowtrack from tailscale.com/wgengine+ + tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/memnet from tailscale.com/tsnet + tailscale.com/net/netaddr from tailscale.com/ipn+ + tailscale.com/net/netcheck from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/neterror from tailscale.com/net/dns/resolver+ + tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal + tailscale.com/net/netknob from tailscale.com/logpolicy+ + đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/control/controlclient+ + đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/ping from tailscale.com/net/netcheck+ + tailscale.com/net/portmapper from tailscale.com/feature/portmapper + tailscale.com/net/portmapper/portmappertype from tailscale.com/net/netcheck+ + tailscale.com/net/proxymux from tailscale.com/tsnet + đŸ’Ŗ tailscale.com/net/sockopts from tailscale.com/wgengine/magicsock + tailscale.com/net/socks5 from tailscale.com/tsnet + tailscale.com/net/sockstats from tailscale.com/control/controlclient+ + tailscale.com/net/stun from tailscale.com/ipn/localapi+ + tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/tsaddr from tailscale.com/client/web+ + tailscale.com/net/tsdial from tailscale.com/control/controlclient+ + đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/feature/useproxy + tailscale.com/net/tstun from tailscale.com/tsd+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock + tailscale.com/net/udprelay/status from tailscale.com/client/local + tailscale.com/omit from tailscale.com/ipn/conffile + tailscale.com/paths from tailscale.com/client/local+ + tailscale.com/proxymap from tailscale.com/tsd+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ + tailscale.com/syncs from tailscale.com/control/controlhttp+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal + tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock + tailscale.com/tempfork/httprec from tailscale.com/feature/c2n + tailscale.com/tka from tailscale.com/client/local+ + tailscale.com/tsconst from tailscale.com/ipn/ipnlocal+ + tailscale.com/tsd from tailscale.com/ipn/ipnext+ + tailscale.com/tstime from tailscale.com/control/controlclient+ + tailscale.com/tstime/mono from tailscale.com/net/tstun+ + tailscale.com/tstime/rate from tailscale.com/wgengine/filter + LDW tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/tsweb+ + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/bools from tailscale.com/tsnet+ + tailscale.com/types/dnstype from tailscale.com/client/local+ + tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/client/local+ + tailscale.com/types/lazy from tailscale.com/hostinfo+ + tailscale.com/types/logger from tailscale.com/appc+ + tailscale.com/types/logid from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogfunc from tailscale.com/net/tstun+ + tailscale.com/types/netlogtype from tailscale.com/wgengine/netlog + tailscale.com/types/netmap from tailscale.com/control/controlclient+ + tailscale.com/types/nettype from tailscale.com/ipn/localapi+ + tailscale.com/types/opt from tailscale.com/control/controlknobs+ + tailscale.com/types/persist from tailscale.com/control/controlclient+ + tailscale.com/types/preftype from tailscale.com/ipn+ + tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter + tailscale.com/types/structs from tailscale.com/control/controlclient+ + tailscale.com/types/tkatype from tailscale.com/client/local+ + tailscale.com/types/views from tailscale.com/appc+ + tailscale.com/util/backoff from tailscale.com/control/controlclient+ + tailscale.com/util/checkchange from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/cibuild from tailscale.com/health+ + tailscale.com/util/clientmetric from tailscale.com/appc+ + tailscale.com/util/cloudenv from tailscale.com/hostinfo+ + LW tailscale.com/util/cmpver from tailscale.com/net/dns+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ + đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting + LA đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics + tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/client/local+ + tailscale.com/util/execqueue from tailscale.com/appc+ + tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal + tailscale.com/util/groupmember from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash + tailscale.com/util/httpm from tailscale.com/client/web+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ + tailscale.com/util/mak from tailscale.com/appc+ + tailscale.com/util/must from tailscale.com/logpolicy+ + tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + đŸ’Ŗ tailscale.com/util/osdiag from tailscale.com/ipn/localapi + W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag + tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal + tailscale.com/util/race from tailscale.com/net/dns/resolver + tailscale.com/util/racebuild from tailscale.com/logpolicy + tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/ringlog from tailscale.com/wgengine/magicsock + tailscale.com/util/set from tailscale.com/control/controlclient+ + tailscale.com/util/singleflight from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/appc+ + tailscale.com/util/syspolicy from tailscale.com/feature/syspolicy + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/pkey from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/policyclient from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/ptype from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/rsop from tailscale.com/ipn/localapi+ + tailscale.com/util/syspolicy/setting from tailscale.com/client/local+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/control/controlclient+ + tailscale.com/util/truncate from tailscale.com/logtail + tailscale.com/util/usermetric from tailscale.com/health+ + tailscale.com/util/vizerror from tailscale.com/tailcfg+ + đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/hostinfo+ + W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/util/osdiag + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ + W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal + W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ + tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ + tailscale.com/version from tailscale.com/client/web+ + tailscale.com/version/distro from tailscale.com/client/web+ + tailscale.com/wgengine from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ + tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ + đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/netlog from tailscale.com/wgengine + tailscale.com/wgengine/netstack from tailscale.com/tsnet + tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+ + tailscale.com/wgengine/router from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ + tailscale.com/wgengine/wglog from tailscale.com/wgengine + golang.org/x/crypto/argon2 from tailscale.com/tka + golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ + golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/chacha20poly1305 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key + golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box + golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device + golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal + LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh + golang.org/x/exp/constraints from tailscale.com/tsweb/varz+ + golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ + golang.org/x/net/bpf from github.com/mdlayher/netlink+ + golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ + golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal + golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy + golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + LDW golang.org/x/net/internal/socks from golang.org/x/net/proxy + golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + LDW golang.org/x/net/proxy from tailscale.com/net/netns + DI golang.org/x/net/route from tailscale.com/net/netmon+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials + golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey + golang.org/x/oauth2/internal from golang.org/x/oauth2+ + golang.org/x/sync/errgroup from github.com/mdlayher/socket+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ + LDAI golang.org/x/sys/unix from github.com/jsimonetti/rtnetlink/internal/unix+ + W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ + W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ + W golang.org/x/sys/windows/svc from golang.org/x/sys/windows/svc/mgr+ + W golang.org/x/sys/windows/svc/mgr from tailscale.com/util/winutil + golang.org/x/term from tailscale.com/logpolicy + golang.org/x/text/secure/bidirule from golang.org/x/net/idna + golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ + golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ + golang.org/x/text/unicode/norm from golang.org/x/net/idna + golang.org/x/time/rate from gvisor.dev/gvisor/pkg/log+ + vendor/golang.org/x/crypto/chacha20 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + vendor/golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + vendor/golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + vendor/golang.org/x/crypto/internal/alias from vendor/golang.org/x/crypto/chacha20+ + vendor/golang.org/x/crypto/internal/poly1305 from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/net/dns/dnsmessage from net + vendor/golang.org/x/net/http/httpguts from net/http+ + vendor/golang.org/x/net/http/httpproxy from net/http + vendor/golang.org/x/net/http2/hpack from net/http+ + vendor/golang.org/x/net/idna from net/http+ + vendor/golang.org/x/sys/cpu from vendor/golang.org/x/crypto/chacha20poly1305 + vendor/golang.org/x/text/secure/bidirule from vendor/golang.org/x/net/idna + vendor/golang.org/x/text/transform from vendor/golang.org/x/text/secure/bidirule+ + vendor/golang.org/x/text/unicode/bidi from vendor/golang.org/x/net/idna+ + vendor/golang.org/x/text/unicode/norm from vendor/golang.org/x/net/idna + bufio from compress/flate+ + bytes from bufio+ + cmp from encoding/json+ + compress/flate from compress/gzip+ + compress/gzip from internal/profile+ + W compress/zlib from debug/pe + container/heap from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + container/list from crypto/tls+ + context from crypto/tls+ + crypto from crypto/ecdh+ + crypto/aes from crypto/internal/hpke+ + crypto/cipher from crypto/aes+ + crypto/des from crypto/tls+ + crypto/dsa from crypto/x509+ + crypto/ecdh from crypto/ecdsa+ + crypto/ecdsa from crypto/tls+ + crypto/ed25519 from crypto/tls+ + crypto/elliptic from crypto/ecdsa+ + crypto/fips140 from crypto/tls/internal/fips140tls+ + crypto/hkdf from crypto/internal/hpke+ + crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls+ + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140cache from crypto/ecdsa+ + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ + crypto/md5 from crypto/tls+ + LD crypto/mlkem from golang.org/x/crypto/ssh + crypto/rand from crypto/ed25519+ + crypto/rc4 from crypto/tls+ + crypto/rsa from crypto/tls+ + crypto/sha1 from crypto/tls+ + crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash + crypto/sha512 from crypto/ecdsa+ + crypto/subtle from crypto/cipher+ + crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls/internal/fips140tls from crypto/tls + crypto/x509 from crypto/tls+ + DI crypto/x509/internal/macos from crypto/x509 + crypto/x509/pkix from crypto/x509+ + DI database/sql/driver from github.com/google/uuid + W debug/dwarf from debug/pe + W debug/pe from github.com/dblohm7/wingoes/pe + embed from github.com/tailscale/web-client-prebuilt+ + encoding from encoding/json+ + encoding/asn1 from crypto/x509+ + encoding/base32 from github.com/fxamacker/cbor/v2+ + encoding/base64 from encoding/json+ + encoding/binary from compress/gzip+ + encoding/hex from crypto/x509+ + encoding/json from expvar+ + encoding/pem from crypto/tls+ + encoding/xml from github.com/tailscale/goupnp+ + errors from bufio+ + expvar from tailscale.com/health+ + flag from tailscale.com/util/testenv + fmt from compress/flate+ + hash from crypto+ + W hash/adler32 from compress/zlib + hash/crc32 from compress/gzip+ + hash/maphash from go4.org/mem + html from html/template+ + LDW html/template from tailscale.com/util/eventbus + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + LDW internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + DI internal/routebsd from net + internal/runtime/atomic from internal/runtime/exithook+ + LA internal/runtime/cgroup from runtime + internal/runtime/exithook from runtime + internal/runtime/gc from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/strconv from internal/runtime/cgroup+ + internal/runtime/sys from crypto/subtle+ + LA internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/synctest from sync + internal/syscall/execenv from os+ + LDAI internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/trace/tracev2 from runtime+ + internal/unsafeheader from internal/reflectlite+ + io from bufio+ + io/fs from crypto/x509+ + io/ioutil from github.com/godbus/dbus/v5+ + iter from bytes+ + log from expvar+ + log/internal from log + maps from crypto/x509+ + math from compress/flate+ + math/big from crypto/dsa+ + math/bits from bytes+ + math/rand from github.com/fxamacker/cbor/v2+ + math/rand/v2 from crypto/ecdsa+ + mime from mime/multipart+ + mime/multipart from net/http + mime/quotedprintable from mime/multipart + net from crypto/tls+ + net/http from expvar+ + net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httputil from tailscale.com/client/web+ + net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/internal/httpcommon from net/http + LDW net/http/pprof from tailscale.com/ipn/localapi+ + net/netip from crypto/x509+ + net/textproto from github.com/coder/websocket+ + net/url from crypto/x509+ + os from crypto/internal/sysrand+ + os/exec from github.com/godbus/dbus/v5+ + os/user from github.com/godbus/dbus/v5+ + path from debug/dwarf+ + path/filepath from crypto/x509+ + reflect from crypto/x509+ + regexp from github.com/tailscale/goupnp/httpu+ + regexp/syntax from regexp + runtime from crypto/internal/fips140+ + runtime/debug from github.com/coder/websocket/internal/xsync+ + runtime/pprof from net/http/pprof+ + LDW runtime/trace from net/http/pprof + slices from crypto/tls+ + sort from compress/flate+ + strconv from compress/flate+ + strings from bufio+ + W structs from internal/syscall/windows + sync from compress/flate+ + sync/atomic from context+ + syscall from crypto/internal/sysrand+ + text/tabwriter from runtime/pprof + LDW text/template from html/template + LDW text/template/parse from html/template+ + time from compress/gzip+ + unicode from bytes+ + unicode/utf16 from crypto/x509+ + unicode/utf8 from bufio+ + unique from net/netip + unsafe from bytes+ + weak from unique+ diff --git a/tsnet/packet_filter_test.go b/tsnet/packet_filter_test.go new file mode 100644 index 000000000..455400eaa --- /dev/null +++ b/tsnet/packet_filter_test.go @@ -0,0 +1,250 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsnet + +import ( + "context" + "fmt" + "net/netip" + "testing" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/util/must" + "tailscale.com/wgengine/filter" +) + +// waitFor blocks until a NetMap is seen on the IPN bus that satisfies the given +// function f. Note: has no timeout, should be called with a ctx that has an +// appropriate timeout set. +func waitFor(t testing.TB, ctx context.Context, s *Server, f func(*netmap.NetworkMap) bool) error { + t.Helper() + watcher, err := s.localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) + if err != nil { + t.Fatalf("error watching IPN bus: %s", err) + } + defer watcher.Close() + + for { + n, err := watcher.Next() + if err != nil { + return fmt.Errorf("getting next ipn.Notify from IPN bus: %w", err) + } + if n.NetMap != nil { + if f(n.NetMap) { + return nil + } + } + } +} + +// TestPacketFilterFromNetmap tests all of the client code for processing +// netmaps and turning them into packet filters together. Only the control-plane +// side is mocked out. +func TestPacketFilterFromNetmap(t *testing.T) { + tstest.Shard(t) + t.Parallel() + + var key key.NodePublic + must.Do(key.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261"))) + + type check struct { + src string + dst string + port uint16 + want filter.Response + } + + tests := []struct { + name string + mapResponse *tailcfg.MapResponse + waitTest func(*netmap.NetworkMap) bool + + incrementalMapResponse *tailcfg.MapResponse // optional + incrementalWaitTest func(*netmap.NetworkMap) bool // optional + + checks []check + }{ + { + name: "IP_based_peers", + mapResponse: &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }, + Peers: []*tailcfg.Node{{ + ID: 2, + Name: "foo", + Key: key, + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + CapMap: nil, + }}, + PacketFilter: []tailcfg.FilterRule{{ + SrcIPs: []string{"2.2.2.2/32"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.1.1.1/32", + Ports: tailcfg.PortRange{ + First: 22, + Last: 22, + }, + }}, + IPProto: []int{int(ipproto.TCP)}, + }}, + }, + waitTest: func(nm *netmap.NetworkMap) bool { + return len(nm.Peers) > 0 + }, + checks: []check{ + {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, + {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port + {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src + {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst + }, + }, + { + name: "capmap_based_peers", + mapResponse: &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }, + Peers: []*tailcfg.Node{{ + ID: 2, + Name: "foo", + Key: key, + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + CapMap: tailcfg.NodeCapMap{"X": nil}, + }}, + PacketFilter: []tailcfg.FilterRule{{ + SrcIPs: []string{"cap:X"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.1.1.1/32", + Ports: tailcfg.PortRange{ + First: 22, + Last: 22, + }, + }}, + IPProto: []int{int(ipproto.TCP)}, + }}, + }, + waitTest: func(nm *netmap.NetworkMap) bool { + return len(nm.Peers) > 0 + }, + checks: []check{ + {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, + {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port + {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src + {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst + }, + }, + { + name: "capmap_based_peers_changed", + mapResponse: &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + CapMap: tailcfg.NodeCapMap{"X-sigil": nil}, + }, + PacketFilter: []tailcfg.FilterRule{{ + SrcIPs: []string{"cap:label-1"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.1.1.1/32", + Ports: tailcfg.PortRange{ + First: 22, + Last: 22, + }, + }}, + IPProto: []int{int(ipproto.TCP)}, + }}, + }, + waitTest: func(nm *netmap.NetworkMap) bool { + return nm.SelfNode.HasCap("X-sigil") + }, + incrementalMapResponse: &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{{ + ID: 2, + Name: "foo", + Key: key, + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + CapMap: tailcfg.NodeCapMap{"label-1": nil}, + }}, + }, + incrementalWaitTest: func(nm *netmap.NetworkMap) bool { + return len(nm.Peers) > 0 + }, + checks: []check{ + {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, + {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port + {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src + {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) + defer cancel() + + controlURL, c := startControl(t) + s, _, pubKey := startServer(t, ctx, controlURL, "node") + + if test.waitTest(s.lb.NetMap()) { + t.Fatal("waitTest already passes before sending initial netmap: this will be flaky") + } + + if !c.AddRawMapResponse(pubKey, test.mapResponse) { + t.Fatalf("could not send map response to %s", pubKey) + } + + if err := waitFor(t, ctx, s, test.waitTest); err != nil { + t.Fatalf("waitFor: %s", err) + } + + pf := s.lb.GetFilterForTest() + + for _, check := range test.checks { + got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP) + + want := check.want + if test.incrementalMapResponse != nil { + want = filter.Drop + } + if got != want { + t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want) + } + } + + if test.incrementalMapResponse != nil { + if test.incrementalWaitTest == nil { + t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set") + } + + if test.incrementalWaitTest(s.lb.NetMap()) { + t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky") + } + + if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) { + t.Fatalf("could not send map response to %s", pubKey) + } + + if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil { + t.Fatalf("waitFor: %s", err) + } + + pf := s.lb.GetFilterForTest() + + for _, check := range test.checks { + got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP) + if got != check.want { + t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want) + } + } + } + + }) + } +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 0be33ba8a..14747650f 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -26,12 +26,18 @@ import ( "sync" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/control/controlclient" "tailscale.com/envknob" + _ "tailscale.com/feature/c2n" + _ "tailscale.com/feature/condregister/oauthkey" + _ "tailscale.com/feature/condregister/portmapper" + _ "tailscale.com/feature/condregister/useproxy" "tailscale.com/health" "tailscale.com/hostinfo" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/localapi" @@ -46,6 +52,7 @@ import ( "tailscale.com/net/socks5" "tailscale.com/net/tsdial" "tailscale.com/tsd" + "tailscale.com/types/bools" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/nettype" @@ -78,7 +85,7 @@ type Server struct { // If nil, a new FileStore is initialized at `Dir/tailscaled.state`. // See tailscale.com/ipn/store for supported stores. // - // Logs will automatically be uploaded to log.tailscale.io, + // Logs will automatically be uploaded to log.tailscale.com, // where the configuration file for logging will be saved at // `Dir/tailscaled.log.conf`. Store ipn.StateStore @@ -121,22 +128,29 @@ type Server struct { // field at zero unless you know what you are doing. Port uint16 + // AdvertiseTags specifies tags that should be applied to this node, for + // purposes of ACL enforcement. These can be referenced from the ACL policy + // document. Note that advertising a tag on the client doesn't guarantee + // that the control server will allow the node to adopt that tag. + AdvertiseTags []string + getCertForTesting func(*tls.ClientHelloInfo) (*tls.Certificate, error) initOnce sync.Once initErr error lb *ipnlocal.LocalBackend + sys *tsd.System netstack *netstack.Impl netMon *netmon.Monitor rootPath string // the state directory hostname string shutdownCtx context.Context shutdownCancel context.CancelFunc - proxyCred string // SOCKS5 proxy auth for loopbackListener - localAPICred string // basic auth password for loopbackListener - loopbackListener net.Listener // optional loopback for localapi and proxies - localAPIListener net.Listener // in-memory, used by localClient - localClient *tailscale.LocalClient // in-memory + proxyCred string // SOCKS5 proxy auth for loopbackListener + localAPICred string // basic auth password for loopbackListener + loopbackListener net.Listener // optional loopback for localapi and proxies + localAPIListener net.Listener // in-memory, used by localClient + localClient *local.Client // in-memory localAPIServer *http.Server logbuffer *filch.Filch logtail *logtail.Logger @@ -168,9 +182,41 @@ func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, e if err := s.Start(); err != nil { return nil, err } + if err := s.awaitRunning(ctx); err != nil { + return nil, err + } return s.dialer.UserDial(ctx, network, address) } +// awaitRunning waits until the backend is in state Running. +// If the backend is in state Starting, it blocks until it reaches +// a terminal state (such as Stopped, NeedsMachineAuth) +// or the context expires. +func (s *Server) awaitRunning(ctx context.Context) error { + st := s.lb.State() + for { + if err := ctx.Err(); err != nil { + return err + } + switch st { + case ipn.Running: + return nil + case ipn.NeedsLogin, ipn.Starting: + // Even after LocalBackend.Start, the state machine is still briefly + // in the "NeedsLogin" state. So treat that as also "Starting" and + // wait for us to get out of that state. + s.lb.WatchNotifications(ctx, ipn.NotifyInitialState, nil, func(n *ipn.Notify) (keepGoing bool) { + if n.State != nil { + st = *n.State + } + return st == ipn.NeedsLogin || st == ipn.Starting + }) + default: + return fmt.Errorf("tsnet: backend in state %v", st) + } + } +} + // HTTPClient returns an HTTP client that is configured to connect over Tailscale. // // This is useful if you need to have your tsnet services connect to other devices on @@ -187,7 +233,7 @@ func (s *Server) HTTPClient() *http.Client { // // It will start the server if it has not been started yet. If the server's // already been started successfully, it doesn't return an error. -func (s *Server) LocalClient() (*tailscale.LocalClient, error) { +func (s *Server) LocalClient() (*local.Client, error) { if err := s.Start(); err != nil { return nil, err } @@ -238,7 +284,13 @@ func (s *Server) Loopback() (addr string, proxyCred, localAPICred string, err er // out the CONNECT code from tailscaled/proxy.go that uses // httputil.ReverseProxy and adding auth support. go func() { - lah := localapi.NewHandler(s.lb, s.logf, s.logid) + lah := localapi.NewHandler(localapi.HandlerConfig{ + Actor: ipnauth.Self, + Backend: s.lb, + Logf: s.logf, + LogID: s.logid, + EventBus: s.sys.Bus.Get(), + }) lah.PermitWrite = true lah.PermitRead = true lah.RequiredPassword = s.localAPICred @@ -298,7 +350,7 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) { return nil, fmt.Errorf("tsnet.Up: %w", err) } - watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) + watcher, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialState) if err != nil { return nil, fmt.Errorf("tsnet.Up: %w", err) } @@ -398,8 +450,8 @@ func (s *Server) Close() error { for _, ln := range s.listeners { ln.closeLocked() } - wg.Wait() + s.sys.Bus.Get().Close() s.closed = true return nil } @@ -432,8 +484,7 @@ func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { return } addrs := nm.GetAddresses() - for i := range addrs.Len() { - addr := addrs.At(i) + for _, addr := range addrs.All() { ip := addr.Addr() if ip.Is6() { ip6 = ip @@ -446,6 +497,16 @@ func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { return ip4, ip6 } +// LogtailWriter returns an [io.Writer] that writes to Tailscale's logging service and will be only visible to Tailscale's +// support team. Logs written there cannot be retrieved by the user. This method always returns a non-nil value. +func (s *Server) LogtailWriter() io.Writer { + if s.logtail == nil { + return io.Discard + } + + return s.logtail +} + func (s *Server) getAuthKey() string { if v := s.AuthKey; v != "" { return v @@ -469,6 +530,11 @@ func (s *Server) start() (reterr error) { // directory and hostname when they're not supplied. But we can fall // back to "tsnet" as well. exe = "tsnet" + case "ios": + // When compiled as a framework (via TailscaleKit in libtailscale), + // os.Executable() returns an error, so fall back to "tsnet" there + // too. + exe = "tsnet" default: return err } @@ -493,10 +559,7 @@ func (s *Server) start() (reterr error) { if err != nil { return err } - s.rootPath, err = getTSNetDir(s.logf, confDir, prog) - if err != nil { - return err - } + s.rootPath = filepath.Join(confDir, "tsnet-"+prog) } if err := os.MkdirAll(s.rootPath, 0700); err != nil { return err @@ -517,25 +580,28 @@ func (s *Server) start() (reterr error) { s.Logf(format, a...) } - sys := new(tsd.System) - if err := s.startLogger(&closePool, sys.HealthTracker(), tsLogf); err != nil { + sys := tsd.NewSystem() + s.sys = sys + if err := s.startLogger(&closePool, sys.HealthTracker.Get(), tsLogf); err != nil { return err } - s.netMon, err = netmon.New(tsLogf) + s.netMon, err = netmon.New(sys.Bus.Get(), tsLogf) if err != nil { return err } closePool.add(s.netMon) s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used) + s.dialer.SetBus(sys.Bus.Get()) eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{ + EventBus: sys.Bus.Get(), ListenPort: s.Port, NetMon: s.netMon, Dialer: s.dialer, SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), Metrics: sys.UserMetricsRegistry(), }) if err != nil { @@ -543,10 +609,10 @@ func (s *Server) start() (reterr error) { } closePool.add(s.dialer) sys.Set(eng) - sys.HealthTracker().SetMetricsRegistry(sys.UserMetricsRegistry()) + sys.HealthTracker.Get().SetMetricsRegistry(sys.UserMetricsRegistry()) // TODO(oxtoacart): do we need to support Taildrive on tsnet, and if so, how? - ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { return fmt.Errorf("netstack.Create: %w", err) } @@ -565,7 +631,9 @@ func (s *Server) start() (reterr error) { // Note: don't just return ns.DialContextTCP or we'll return // *gonet.TCPConn(nil) instead of a nil interface which trips up // callers. - tcpConn, err := ns.DialContextTCP(ctx, dst) + v4, v6 := s.TailscaleIPs() + src := bools.IfElse(dst.Addr().Is6(), v6, v4) + tcpConn, err := ns.DialContextTCPWithBind(ctx, src, dst) if err != nil { return nil, err } @@ -575,7 +643,9 @@ func (s *Server) start() (reterr error) { // Note: don't just return ns.DialContextUDP or we'll return // *gonet.UDPConn(nil) instead of a nil interface which trips up // callers. - udpConn, err := ns.DialContextUDP(ctx, dst) + v4, v6 := s.TailscaleIPs() + src := bools.IfElse(dst.Addr().Is6(), v6, v4) + udpConn, err := ns.DialContextUDPWithBind(ctx, src, dst) if err != nil { return nil, err } @@ -613,7 +683,16 @@ func (s *Server) start() (reterr error) { prefs.WantRunning = true prefs.ControlURL = s.ControlURL prefs.RunWebClient = s.RunWebClient + prefs.AdvertiseTags = s.AdvertiseTags authKey := s.getAuthKey() + // Try to use an OAuth secret to generate an auth key if that functionality + // is available. + if f, ok := tailscale.HookResolveAuthKey.GetOk(); ok { + authKey, err = f(s.shutdownCtx, s.getAuthKey(), prefs.AdvertiseTags) + if err != nil { + return fmt.Errorf("resolving auth key: %w", err) + } + } err = lb.Start(ipn.Options{ UpdatePrefs: prefs, AuthKey: authKey, @@ -633,7 +712,13 @@ func (s *Server) start() (reterr error) { go s.printAuthURLLoop() // Run the localapi handler, to allow fetching LetsEncrypt certs. - lah := localapi.NewHandler(lb, tsLogf, s.logid) + lah := localapi.NewHandler(localapi.HandlerConfig{ + Actor: ipnauth.Self, + Backend: lb, + Logf: tsLogf, + LogID: s.logid, + EventBus: sys.Bus.Get(), + }) lah.PermitWrite = true lah.PermitRead = true @@ -641,7 +726,7 @@ func (s *Server) start() (reterr error) { // nettest.Listen provides a in-memory pipe based implementation for net.Conn. lal := memnet.Listen("local-tailscaled.sock:80") s.localAPIListener = lal - s.localClient = &tailscale.LocalClient{Dial: lal.Dial} + s.localClient = &local.Client{Dial: lal.Dial} s.localAPIServer = &http.Server{Handler: lah} s.lb.ConfigureWebClient(s.localClient) go func() { @@ -684,6 +769,7 @@ func (s *Server) startLogger(closePool *closeOnErrorPool, health *health.Tracker Stderr: io.Discard, // log everything to Buffer Buffer: s.logbuffer, CompressLogs: true, + Bus: s.sys.Bus.Get(), HTTPC: &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, s.netMon, health, tsLogf)}, MetricsDelta: clientmetric.EncodeLogTailMetricsDelta, } @@ -848,65 +934,6 @@ func (s *Server) getUDPHandlerForFlow(src, dst netip.AddrPort) (handler func(net return func(c nettype.ConnPacketConn) { ln.handle(c) }, true } -// getTSNetDir usually just returns filepath.Join(confDir, "tsnet-"+prog) -// with no error. -// -// One special case is that it renames old "tslib-" directories to -// "tsnet-", and that rename might return an error. -// -// TODO(bradfitz): remove this maybe 6 months after 2022-03-17, -// once people (notably Tailscale corp services) have updated. -func getTSNetDir(logf logger.Logf, confDir, prog string) (string, error) { - oldPath := filepath.Join(confDir, "tslib-"+prog) - newPath := filepath.Join(confDir, "tsnet-"+prog) - - fi, err := os.Lstat(oldPath) - if os.IsNotExist(err) { - // Common path. - return newPath, nil - } - if err != nil { - return "", err - } - if !fi.IsDir() { - return "", fmt.Errorf("expected old tslib path %q to be a directory; got %v", oldPath, fi.Mode()) - } - - // At this point, oldPath exists and is a directory. But does - // the new path exist? - - fi, err = os.Lstat(newPath) - if err == nil && fi.IsDir() { - // New path already exists somehow. Ignore the old one and - // don't try to migrate it. - return newPath, nil - } - if err != nil && !os.IsNotExist(err) { - return "", err - } - if err := os.Rename(oldPath, newPath); err != nil { - return "", err - } - logf("renamed old tsnet state storage directory %q to %q", oldPath, newPath) - return newPath, nil -} - -// APIClient returns a tailscale.Client that can be used to make authenticated -// requests to the Tailscale control server. -// It requires the user to set tailscale.I_Acknowledge_This_API_Is_Unstable. -func (s *Server) APIClient() (*tailscale.Client, error) { - if !tailscale.I_Acknowledge_This_API_Is_Unstable { - return nil, errors.New("use of Client without setting I_Acknowledge_This_API_Is_Unstable") - } - if err := s.Start(); err != nil { - return nil, err - } - - c := tailscale.NewClient("-", nil) - c.HTTPClient = &http.Client{Transport: s.lb.KeyProvingNoiseRoundTripper()} - return c, nil -} - // Listen announces only on the Tailscale network. // It will start the server if it has not been started yet. // @@ -1013,13 +1040,33 @@ type FunnelOption interface { funnelOption() } -type funnelOnly int +type funnelOnly struct{} func (funnelOnly) funnelOption() {} // FunnelOnly configures the listener to only respond to connections from Tailscale Funnel. // The local tailnet will not be able to connect to the listener. -func FunnelOnly() FunnelOption { return funnelOnly(1) } +func FunnelOnly() FunnelOption { return funnelOnly{} } + +type funnelTLSConfig struct{ conf *tls.Config } + +func (f funnelTLSConfig) funnelOption() {} + +// FunnelTLSConfig configures the TLS configuration for [Server.ListenFunnel] +// +// This is rarely needed but can permit requiring client certificates, specific +// ciphers suites, etc. +// +// The provided conf should at least be able to get a certificate, setting +// GetCertificate, Certificates or GetConfigForClient appropriately. +// The most common configuration is to set GetCertificate to +// Server.LocalClient's GetCertificate method. +// +// Unless [FunnelOnly] is also used, the configuration is also used for +// in-tailnet connections that don't arrive over Funnel. +func FunnelTLSConfig(conf *tls.Config) FunnelOption { + return funnelTLSConfig{conf: conf} +} // ListenFunnel announces on the public internet using Tailscale Funnel. // @@ -1052,6 +1099,26 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L return nil, err } + // Process, validate opts. + lnOn := listenOnBoth + var tlsConfig *tls.Config + for _, opt := range opts { + switch v := opt.(type) { + case funnelTLSConfig: + if v.conf == nil { + return nil, errors.New("invalid nil FunnelTLSConfig") + } + tlsConfig = v.conf + case funnelOnly: + lnOn = listenOnFunnel + default: + return nil, fmt.Errorf("unknown opts FunnelOption type %T", v) + } + } + if tlsConfig == nil { + tlsConfig = &tls.Config{GetCertificate: s.getCert} + } + ctx := context.Background() st, err := s.Up(ctx) if err != nil { @@ -1089,19 +1156,11 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L } // Start a funnel listener. - lnOn := listenOnBoth - for _, opt := range opts { - if _, ok := opt.(funnelOnly); ok { - lnOn = listenOnFunnel - } - } ln, err := s.listen(network, addr, lnOn) if err != nil { return nil, err } - return tls.NewListener(ln, &tls.Config{ - GetCertificate: s.getCert, - }), nil + return tls.NewListener(ln, tlsConfig), nil } type listenOn string @@ -1178,7 +1237,8 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro keys: keys, addr: addr, - conn: make(chan net.Conn), + closedc: make(chan struct{}), + conn: make(chan net.Conn), } s.mu.Lock() for _, key := range keys { @@ -1197,6 +1257,12 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro return ln, nil } +// GetRootPath returns the root path of the tsnet server. +// This is where the state file and other data is stored. +func (s *Server) GetRootPath() string { + return s.rootPath +} + // CapturePcap can be called by the application code compiled with tsnet to save a pcap // of packets which the netstack within tsnet sees. This is expected to be useful during // debugging, probably not useful for production. @@ -1226,6 +1292,13 @@ func (s *Server) CapturePcap(ctx context.Context, pcapFile string) error { return nil } +// Sys returns a handle to the Tailscale subsystems of this node. +// +// This is not a stable API, nor are the APIs of the returned subsystems. +func (s *Server) Sys() *tsd.System { + return s.sys +} + type listenKey struct { network string host netip.Addr // or zero value for unspecified @@ -1234,19 +1307,21 @@ type listenKey struct { } type listener struct { - s *Server - keys []listenKey - addr string - conn chan net.Conn - closed bool // guarded by s.mu + s *Server + keys []listenKey + addr string + conn chan net.Conn // unbuffered, never closed + closedc chan struct{} // closed on [listener.Close] + closed bool // guarded by s.mu } func (ln *listener) Accept() (net.Conn, error) { - c, ok := <-ln.conn - if !ok { + select { + case c := <-ln.conn: + return c, nil + case <-ln.closedc: return nil, fmt.Errorf("tsnet: %w", net.ErrClosed) } - return c, nil } func (ln *listener) Addr() net.Addr { return addr{ln} } @@ -1268,21 +1343,22 @@ func (ln *listener) closeLocked() error { delete(ln.s.listeners, key) } } - close(ln.conn) + close(ln.closedc) ln.closed = true return nil } func (ln *listener) handle(c net.Conn) { - t := time.NewTimer(time.Second) - defer t.Stop() select { case ln.conn <- c: - case <-t.C: + return + case <-ln.closedc: + case <-ln.s.shutdownCtx.Done(): + case <-time.After(time.Second): // TODO(bradfitz): this isn't ideal. Think about how // we how we want to do pushback. - c.Close() } + c.Close() } // Server returns the tsnet Server associated with the listener. diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index b95061d38..f1531d013 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -36,13 +36,14 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "golang.org/x/net/proxy" + "tailscale.com/client/local" "tailscale.com/cmd/testwrapper/flakytest" - "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" "tailscale.com/tailcfg" "tailscale.com/tstest" + "tailscale.com/tstest/deptest" "tailscale.com/tstest/integration" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" @@ -120,6 +121,7 @@ func startControl(t *testing.T) (controlURL string, control *testcontrol.Server) Proxied: true, }, MagicDNSDomain: "tail-scale.ts.net", + Logf: t.Logf, } control.HTTPTestServer = httptest.NewUnstartedServer(control) control.HTTPTestServer.Start() @@ -221,7 +223,7 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) getCertForTesting: testCertRoot.getCert, } if *verboseNodes { - s.Logf = log.Printf + s.Logf = t.Logf } t.Cleanup(func() { s.Close() }) @@ -232,30 +234,98 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) return s, status.TailscaleIPs[0], status.Self.PublicKey } +func TestDialBlocks(t *testing.T) { + tstest.Shard(t) + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + + // Make one tsnet that blocks until it's up. + s1, _, _ := startServer(t, ctx, controlURL, "s1") + + ln, err := s1.Listen("tcp", ":8080") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + // Then make another tsnet node that will only be woken up + // upon the first dial. + tmp := filepath.Join(t.TempDir(), "s2") + os.MkdirAll(tmp, 0755) + s2 := &Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, + getCertForTesting: testCertRoot.getCert, + } + if *verboseNodes { + s2.Logf = log.Printf + } + t.Cleanup(func() { s2.Close() }) + + c, err := s2.Dial(ctx, "tcp", "s1:8080") + if err != nil { + t.Fatal(err) + } + defer c.Close() +} + +// TestConn tests basic TCP connections between two tsnet Servers, s1 and s2: +// +// - s1, a subnet router, first listens on its TCP :8081. +// - s2 can connect to s1:8081 +// - s2 cannot connect to s1:8082 (no listener) +// - s2 can dial through the subnet router functionality (getting a synthetic RST +// that we verify we generated & saw) func TestConn(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() controlURL, c := startControl(t) s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1") - s2, _, _ := startServer(t, ctx, controlURL, "s2") - s1.lb.EditPrefs(&ipn.MaskedPrefs{ + // Track whether we saw an attempted dial to 192.0.2.1:8081. + var saw192DocNetDial atomic.Bool + s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { + t.Logf("s1: fallback TCP handler called for %v -> %v", src, dst) + if dst.String() == "192.0.2.1:8081" { + saw192DocNetDial.Store(true) + } + return nil, true // nil handler but intercept=true means to send RST + }) + + lc1 := must.Get(s1.LocalClient()) + + must.Get(lc1.EditPrefs(ctx, &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")}, }, AdvertiseRoutesSet: true, - }) + })) c.SetSubnetRoutes(s1PubKey, []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")}) - lc2, err := s2.LocalClient() - if err != nil { - t.Fatal(err) - } + // Start s2 after s1 is fully set up, including advertising its routes, + // otherwise the test is flaky if the test starts dialing through s2 before + // our test control server has told s2 about s1's routes. + s2, _, _ := startServer(t, ctx, controlURL, "s2") + lc2 := must.Get(s2.LocalClient()) + + must.Get(lc2.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + RouteAll: true, + }, + RouteAllSet: true, + })) // ping to make sure the connection is up. - res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingTSMP) if err != nil { t.Fatal(err) } @@ -268,12 +338,26 @@ func TestConn(t *testing.T) { } defer ln.Close() - w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) - if err != nil { - t.Fatal(err) - } + s1Conns := make(chan net.Conn) + go func() { + for { + c, err := ln.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + t.Errorf("s1.Accept: %v", err) + return + } + select { + case s1Conns <- c: + case <-ctx.Done(): + c.Close() + } + } + }() - r, err := ln.Accept() + w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) if err != nil { t.Fatal(err) } @@ -283,32 +367,56 @@ func TestConn(t *testing.T) { t.Fatal(err) } - got := make([]byte, len(want)) - if _, err := io.ReadAtLeast(r, got, len(got)); err != nil { - t.Fatal(err) - } - t.Logf("got: %q", got) - if string(got) != want { - t.Errorf("got %q, want %q", got, want) + select { + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for connection") + case r := <-s1Conns: + got := make([]byte, len(want)) + _, err := io.ReadAtLeast(r, got, len(got)) + r.Close() + if err != nil { + t.Fatal(err) + } + t.Logf("got: %q", got) + if string(got) != want { + t.Errorf("got %q, want %q", got, want) + } } + // Dial a non-existent port on s1 and expect it to fail. _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8082", s1ip)) // some random port if err == nil { t.Fatalf("unexpected success; should have seen a connection refused error") } - - // s1 is a subnet router for TEST-NET-1 (192.0.2.0/24). Lets dial to that - // subnet from s2 to ensure a listener without an IP address (i.e. ":8081") - // only matches destination IPs corresponding to the node's IP, and not - // to any random IP a subnet is routing. - _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", "192.0.2.1")) + t.Logf("got expected failure: %v", err) + + // s1 is a subnet router for TEST-NET-1 (192.0.2.0/24). Let's dial to that + // subnet from s2 to ensure a listener without an IP address (i.e. our + // ":8081" listen above) only matches destination IPs corresponding to the + // s1 node's IP addresses, and not to any random IP of a subnet it's routing. + // + // The RegisterFallbackTCPHandler on s1 above handles sending a RST when the + // TCP SYN arrives from s2. But we bound it to 5 seconds lest a regression + // like tailscale/tailscale#17805 recur. + s2dialer := s2.Sys().Dialer.Get() + s2dialer.SetSystemDialerForTest(func(ctx context.Context, netw, addr string) (net.Conn, error) { + t.Logf("s2: unexpected system dial called for %s %s", netw, addr) + return nil, fmt.Errorf("system dialer called unexpectedly for %s %s", netw, addr) + }) + docCtx, docCancel := context.WithTimeout(ctx, 5*time.Second) + defer docCancel() + _, err = s2.Dial(docCtx, "tcp", "192.0.2.1:8081") if err == nil { t.Fatalf("unexpected success; should have seen a connection refused error") } + if !saw192DocNetDial.Load() { + t.Errorf("expected s1's fallback TCP handler to have been called for 192.0.2.1:8081") + } } func TestLoopbackLocalAPI(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/8557") + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -384,6 +492,7 @@ func TestLoopbackLocalAPI(t *testing.T) { func TestLoopbackSOCKS5(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/8198") + tstest.Shard(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -434,6 +543,7 @@ func TestLoopbackSOCKS5(t *testing.T) { } func TestTailscaleIPs(t *testing.T) { + tstest.Shard(t) controlURL, _ := startControl(t) tmp := t.TempDir() @@ -476,6 +586,7 @@ func TestTailscaleIPs(t *testing.T) { // TestListenerCleanup is a regression test to verify that s.Close doesn't // deadlock if a listener is still open. func TestListenerCleanup(t *testing.T) { + tstest.Shard(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -494,11 +605,31 @@ func TestListenerCleanup(t *testing.T) { if err := ln.Close(); !errors.Is(err, net.ErrClosed) { t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err) } + + // Verify that handling a connection from gVisor (from a packet arriving) + // after a listener closed doesn't panic (previously: sending on a closed + // channel) or hang. + c := &closeTrackConn{} + ln.(*listener).handle(c) + if !c.closed { + t.Errorf("c.closed = false, want true") + } +} + +type closeTrackConn struct { + net.Conn + closed bool +} + +func (wc *closeTrackConn) Close() error { + wc.closed = true + return nil } // tests https://github.com/tailscale/tailscale/issues/6973 -- that we can start a tsnet server, // stop it, and restart it, even on Windows. func TestStartStopStartGetsSameIP(t *testing.T) { + tstest.Shard(t) controlURL, _ := startControl(t) tmp := t.TempDir() @@ -510,7 +641,7 @@ func TestStartStopStartGetsSameIP(t *testing.T) { Dir: tmps1, ControlURL: controlURL, Hostname: "s1", - Logf: logger.TestLogger(t), + Logf: tstest.WhileTestRunningLogger(t), } } s1 := newServer() @@ -548,6 +679,7 @@ func TestStartStopStartGetsSameIP(t *testing.T) { } func TestFunnel(t *testing.T) { + tstest.Shard(t) ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second) defer dialCancel() @@ -608,6 +740,38 @@ func TestFunnel(t *testing.T) { } } +func TestListenerClose(t *testing.T) { + tstest.Shard(t) + ctx := context.Background() + controlURL, _ := startControl(t) + + s1, _, _ := startServer(t, ctx, controlURL, "s1") + + ln, err := s1.Listen("tcp", ":8080") + if err != nil { + t.Fatal(err) + } + + errc := make(chan error, 1) + go func() { + c, err := ln.Accept() + if c != nil { + c.Close() + } + errc <- err + }() + + ln.Close() + select { + case err := <-errc: + if !errors.Is(err, net.ErrClosed) { + t.Errorf("unexpected error: %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for Accept to return") + } +} + func dialIngressConn(from, to *Server, target string) (net.Conn, error) { toLC := must.Get(to.LocalClient()) toStatus := must.Get(toLC.StatusWithoutPeers(context.Background())) @@ -657,6 +821,7 @@ func (c *bufferedConn) Read(b []byte) (int, error) { } func TestFallbackTCPHandler(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -699,6 +864,7 @@ func TestFallbackTCPHandler(t *testing.T) { } func TestCapturePcap(t *testing.T) { + tstest.Shard(t) const timeLimit = 120 ctx, cancel := context.WithTimeout(context.Background(), timeLimit*time.Second) defer cancel() @@ -752,6 +918,7 @@ func TestCapturePcap(t *testing.T) { } func TestUDPConn(t *testing.T) { + tstest.Shard(t) tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -821,16 +988,6 @@ func TestUDPConn(t *testing.T) { } } -// testWarnable is a Warnable that is used within this package for testing purposes only. -var testWarnable = health.Register(&health.Warnable{ - Code: "test-warnable-tsnet", - Title: "Test warnable", - Severity: health.SeverityLow, - Text: func(args health.Args) string { - return args[health.ArgError] - }, -}) - func parseMetrics(m []byte) (map[string]float64, error) { metrics := make(map[string]float64) @@ -864,25 +1021,225 @@ func promMetricLabelsStr(labels []*dto.LabelPair) string { } var b strings.Builder b.WriteString("{") - for i, l := range labels { + for i, lb := range labels { if i > 0 { b.WriteString(",") } - b.WriteString(fmt.Sprintf("%s=%q", l.GetName(), l.GetValue())) + b.WriteString(fmt.Sprintf("%s=%q", lb.GetName(), lb.GetValue())) } b.WriteString("}") return b.String() } -func TestUserMetrics(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") - tstest.ResourceCheck(t) +// sendData sends a given amount of bytes from s1 to s2. +func sendData(logf func(format string, args ...any), ctx context.Context, bytesCount int, s1, s2 *Server, s1ip, s2ip netip.Addr) error { + lb := must.Get(s1.Listen("tcp", fmt.Sprintf("%s:8081", s1ip))) + defer lb.Close() + + // Dial to s1 from s2 + w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)) + if err != nil { + return err + } + defer w.Close() + + stopReceive := make(chan struct{}) + defer close(stopReceive) + allReceived := make(chan error) + defer close(allReceived) + + go func() { + conn, err := lb.Accept() + if err != nil { + allReceived <- err + return + } + conn.SetWriteDeadline(time.Now().Add(30 * time.Second)) + + total := 0 + recvStart := time.Now() + for { + got := make([]byte, bytesCount) + n, err := conn.Read(got) + if err != nil { + allReceived <- fmt.Errorf("failed reading packet, %s", err) + return + } + got = got[:n] + + select { + case <-stopReceive: + return + default: + } + + total += n + logf("received %d/%d bytes, %.2f %%", total, bytesCount, (float64(total) / (float64(bytesCount)) * 100)) + + // Validate the received bytes to be the same as the sent bytes. + for _, b := range string(got) { + if b != 'A' { + allReceived <- fmt.Errorf("received unexpected byte: %c", b) + return + } + } + + if total == bytesCount { + break + } + } + + logf("all received, took: %s", time.Since(recvStart).String()) + allReceived <- nil + }() + + sendStart := time.Now() + w.SetWriteDeadline(time.Now().Add(30 * time.Second)) + if _, err := w.Write(bytes.Repeat([]byte("A"), bytesCount)); err != nil { + stopReceive <- struct{}{} + return err + } + + logf("all sent (%s), waiting for all packets (%d) to be received", time.Since(sendStart).String(), bytesCount) + err, _ = <-allReceived + if err != nil { + return err + } + + return nil +} + +func TestUserMetricsByteCounters(t *testing.T) { + tstest.Shard(t) + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") + defer s1.Close() + s2, s2ip, _ := startServer(t, ctx, controlURL, "s2") + defer s2.Close() + + lc1, err := s1.LocalClient() + if err != nil { + t.Fatal(err) + } + + lc2, err := s2.LocalClient() + if err != nil { + t.Fatal(err) + } + + // Force an update to the netmap to ensure that the metrics are up-to-date. + s1.lb.DebugForceNetmapUpdate() + s2.lb.DebugForceNetmapUpdate() + + // Wait for both nodes to have a peer in their netmap. + waitForCondition(t, "waiting for netmaps to contain peer", 90*time.Second, func() bool { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + status1, err := lc1.Status(ctx) + if err != nil { + t.Logf("getting status: %s", err) + return false + } + status2, err := lc2.Status(ctx) + if err != nil { + t.Logf("getting status: %s", err) + return false + } + return len(status1.Peers()) > 0 && len(status2.Peers()) > 0 + }) + + // ping to make sure the connection is up. + res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatalf("pinging: %s", err) + } + t.Logf("ping success: %#+v", res) + + mustDirect(t, t.Logf, lc1, lc2) + + // 1 megabytes + bytesToSend := 1 * 1024 * 1024 + + // This asserts generates some traffic, it is factored out + // of TestUDPConn. + start := time.Now() + err = sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip) + if err != nil { + t.Fatalf("Failed to send packets: %v", err) + } + t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String()) + + ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelLc() + metrics1, err := lc1.UserMetrics(ctxLc) + if err != nil { + t.Fatal(err) + } + + parsedMetrics1, err := parseMetrics(metrics1) + if err != nil { + t.Fatal(err) + } + + // Allow the metrics for the bytes sent to be off by 15%. + bytesSentTolerance := 1.15 + + t.Logf("Metrics1:\n%s\n", metrics1) + + // Verify that the amount of data recorded in bytes is higher or equal to the data sent + inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] + if inboundBytes1 < float64(bytesToSend) { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) + } + + // But ensure that it is not too much higher than the data sent. + if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) + } + + metrics2, err := lc2.UserMetrics(ctx) + if err != nil { + t.Fatal(err) + } + + parsedMetrics2, err := parseMetrics(metrics2) + if err != nil { + t.Fatal(err) + } + + t.Logf("Metrics2:\n%s\n", metrics2) + + // Verify that the amount of data recorded in bytes is higher or equal than the data sent. + outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] + if outboundBytes2 < float64(bytesToSend) { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) + } + + // But ensure that it is not too much higher than the data sent. + if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) + } +} + +func TestUserMetricsRouteGauges(t *testing.T) { + tstest.Shard(t) + // Windows does not seem to support or report back routes when running in + // userspace via tsnet. So, we skip this check on Windows. + // TODO(kradalby): Figure out if this is correct. + if runtime.GOOS == "windows" { + t.Skipf("skipping on windows") + } ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() controlURL, c := startControl(t) - s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1") + s1, _, s1PubKey := startServer(t, ctx, controlURL, "s1") + defer s1.Close() s2, _, _ := startServer(t, ctx, controlURL, "s2") + defer s2.Close() s1.lb.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ @@ -911,24 +1268,11 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - // ping to make sure the connection is up. - res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) - if err != nil { - t.Fatalf("pinging: %s", err) - } - t.Logf("ping success: %#+v", res) - - ht := s1.lb.HealthTracker() - ht.SetUnhealthy(testWarnable, health.Args{"Text": "Hello world 1"}) - // Force an update to the netmap to ensure that the metrics are up-to-date. s1.lb.DebugForceNetmapUpdate() s2.lb.DebugForceNetmapUpdate() wantRoutes := float64(2) - if runtime.GOOS == "windows" { - wantRoutes = 0 - } // Wait for the routes to be propagated to node 1 to ensure // that the metrics are up-to-date. @@ -940,12 +1284,6 @@ func TestUserMetrics(t *testing.T) { t.Logf("getting status: %s", err) return false } - if runtime.GOOS == "windows" { - // Windows does not seem to support or report back routes when running in - // userspace via tsnet. So, we skip this check on Windows. - // TODO(kradalby): Figure out if this is correct. - return true - } // Wait for the primary routes to reach our desired routes, which is wantRoutes + 1, because // the PrimaryRoutes list will contain a exit node route, which the metric does not count. return status1.Self.PrimaryRoutes != nil && status1.Self.PrimaryRoutes.Len() == int(wantRoutes)+1 @@ -958,11 +1296,6 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - status1, err := lc1.Status(ctxLc) - if err != nil { - t.Fatal(err) - } - parsedMetrics1, err := parseMetrics(metrics1) if err != nil { t.Fatal(err) @@ -985,28 +1318,11 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_approved_routes: got %v, want %v", got, want) } - // Validate the health counter metric against the status of the node - if got, want := parsedMetrics1[`tailscaled_health_messages{type="warning"}`], float64(len(status1.Health)); got != want { - t.Errorf("metrics1, tailscaled_health_messages: got %v, want %v", got, want) - } - - // The node is the primary subnet router for 2 routes: - // - 192.0.2.0/24 - // - 192.0.5.1/32 - if got, want := parsedMetrics1["tailscaled_primary_routes"], wantRoutes; got != want { - t.Errorf("metrics1, tailscaled_primary_routes: got %v, want %v", got, want) - } - metrics2, err := lc2.UserMetrics(ctx) if err != nil { t.Fatal(err) } - status2, err := lc2.Status(ctx) - if err != nil { - t.Fatal(err) - } - parsedMetrics2, err := parseMetrics(metrics2) if err != nil { t.Fatal(err) @@ -1023,16 +1339,6 @@ func TestUserMetrics(t *testing.T) { if got, want := parsedMetrics2["tailscaled_approved_routes"], 0.0; got != want { t.Errorf("metrics2, tailscaled_approved_routes: got %v, want %v", got, want) } - - // Validate the health counter metric against the status of the node - if got, want := parsedMetrics2[`tailscaled_health_messages{type="warning"}`], float64(len(status2.Health)); got != want { - t.Errorf("metrics2, tailscaled_health_messages: got %v, want %v", got, want) - } - - // The node is the primary subnet router for 0 routes - if got, want := parsedMetrics2["tailscaled_primary_routes"], 0.0; got != want { - t.Errorf("metrics2, tailscaled_primary_routes: got %v, want %v", got, want) - } } func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() bool) { @@ -1044,3 +1350,46 @@ func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() } t.Fatalf("waiting for condition: %s", msg) } + +// mustDirect ensures there is a direct connection between LocalClient 1 and 2 +func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) { + t.Helper() + lastLog := time.Now().Add(-time.Minute) + // See https://github.com/tailscale/tailscale/issues/654 + // and https://github.com/tailscale/tailscale/issues/3247 for discussions of this deadline. + for deadline := time.Now().Add(30 * time.Second); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + status1, err := lc1.Status(ctx) + if err != nil { + continue + } + status2, err := lc2.Status(ctx) + if err != nil { + continue + } + pst := status1.Peer[status2.Self.PublicKey] + if pst.CurAddr != "" { + logf("direct link %s->%s found with addr %s", status1.Self.HostName, status2.Self.HostName, pst.CurAddr) + return + } + if now := time.Now(); now.Sub(lastLog) > time.Second { + logf("no direct path %s->%s yet, addrs %v", status1.Self.HostName, status2.Self.HostName, pst.Addrs) + lastLog = now + } + } + t.Error("magicsock did not find a direct path from lc1 to lc2") +} + +func TestDeps(t *testing.T) { + tstest.Shard(t) + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + OnDep: func(dep string) { + if strings.Contains(dep, "portlist") { + t.Errorf("unexpected dep: %q", dep) + } + }, + }.Check(t) +} diff --git a/tstest/archtest/qemu_test.go b/tstest/archtest/qemu_test.go index 8b59ae5d9..68ec38851 100644 --- a/tstest/archtest/qemu_test.go +++ b/tstest/archtest/qemu_test.go @@ -33,7 +33,6 @@ func TestInQemu(t *testing.T) { } inCI := cibuild.On() for _, arch := range arches { - arch := arch t.Run(arch.Goarch, func(t *testing.T) { t.Parallel() qemuUser := "qemu-" + arch.Qarch diff --git a/tstest/chonktest/chonktest.go b/tstest/chonktest/chonktest.go new file mode 100644 index 000000000..404f1ec47 --- /dev/null +++ b/tstest/chonktest/chonktest.go @@ -0,0 +1,303 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package chonktest contains a shared set of tests for the Chonk +// interface used to store AUM messages in Tailnet Lock, which we can +// share between different implementations. +package chonktest + +import ( + "bytes" + "encoding/binary" + "errors" + "math/rand" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "golang.org/x/crypto/blake2s" + "tailscale.com/tka" + "tailscale.com/util/must" +) + +// returns a random source based on the test name + extraSeed. +func testingRand(t *testing.T, extraSeed int64) *rand.Rand { + var seed int64 + if err := binary.Read(bytes.NewBuffer([]byte(t.Name())), binary.LittleEndian, &seed); err != nil { + panic(err) + } + return rand.New(rand.NewSource(seed + extraSeed)) +} + +// randHash derives a fake blake2s hash from the test name +// and the given seed. +func randHash(t *testing.T, seed int64) [blake2s.Size]byte { + var out [blake2s.Size]byte + testingRand(t, seed).Read(out[:]) + return out +} + +func hashesLess(x, y tka.AUMHash) bool { + return bytes.Compare(x[:], y[:]) < 0 +} + +func aumHashesLess(x, y tka.AUM) bool { + return hashesLess(x.Hash(), y.Hash()) +} + +// RunChonkTests is a set of tests for the behaviour of a Chonk. +// +// Any implementation of Chonk should pass these tests, so we know all +// Chonks behave in the same way. If you want to test behaviour that's +// specific to one implementation, write a separate test. +func RunChonkTests(t *testing.T, newChonk func(*testing.T) tka.Chonk) { + t.Run("ChildAUMs", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + parentHash := randHash(t, 1) + data := []tka.AUM{ + { + MessageKind: tka.AUMRemoveKey, + KeyID: []byte{1, 2}, + PrevAUMHash: parentHash[:], + }, + { + MessageKind: tka.AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + + if err := chonk.CommitVerifiedAUMs(data); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + stored, err := chonk.ChildAUMs(parentHash) + if err != nil { + t.Fatalf("ChildAUMs failed: %v", err) + } + if diff := cmp.Diff(data, stored, cmpopts.SortSlices(aumHashesLess)); diff != "" { + t.Errorf("stored AUM differs (-want, +got):\n%s", diff) + } + }) + + t.Run("AUMMissing", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + var notExists tka.AUMHash + notExists[:][0] = 42 + if _, err := chonk.AUM(notExists); err != os.ErrNotExist { + t.Errorf("chonk.AUM(notExists).err = %v, want %v", err, os.ErrNotExist) + } + }) + + t.Run("ReadChainFromHead", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + genesis := tka.AUM{MessageKind: tka.AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := tka.AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := tka.AUM{PrevAUMHash: iHash[:]} + + commitSet := []tka.AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + t.Logf("genesis hash = %X", genesis.Hash()) + t.Logf("intermediate hash = %X", intermediate.Hash()) + t.Logf("leaf hash = %X", leaf.Hash()) + + // Read the chain from the leaf backwards. + gotLeafs, err := chonk.Heads() + if err != nil { + t.Fatalf("Heads failed: %v", err) + } + if diff := cmp.Diff([]tka.AUM{leaf}, gotLeafs); diff != "" { + t.Fatalf("leaf AUM differs (-want, +got):\n%s", diff) + } + + parent, _ := gotLeafs[0].Parent() + gotIntermediate, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(intermediate, gotIntermediate); diff != "" { + t.Errorf("intermediate AUM differs (-want, +got):\n%s", diff) + } + + parent, _ = gotIntermediate.Parent() + gotGenesis, err := chonk.AUM(parent) + if err != nil { + t.Fatalf("AUM() failed: %v", err) + } + if diff := cmp.Diff(genesis, gotGenesis); diff != "" { + t.Errorf("genesis AUM differs (-want, +got):\n%s", diff) + } + }) + + t.Run("LastActiveAncestor", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + + aum := tka.AUM{MessageKind: tka.AUMRemoveKey, KeyID: []byte{1, 2}} + hash := aum.Hash() + + if err := chonk.SetLastActiveAncestor(hash); err != nil { + t.Fatal(err) + } + got, err := chonk.LastActiveAncestor() + if err != nil { + t.Fatal(err) + } + if got == nil || hash.String() != got.String() { + t.Errorf("LastActiveAncestor=%s, want %s", got, hash) + } + }) +} + +// RunCompactableChonkTests is a set of tests for the behaviour of a +// CompactableChonk. +// +// Any implementation of CompactableChonk should pass these tests, so we +// know all CompactableChonk behave in the same way. If you want to test +// behaviour that's specific to one implementation, write a separate test. +func RunCompactableChonkTests(t *testing.T, newChonk func(t *testing.T) tka.CompactableChonk) { + t.Run("PurgeAUMs", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + parentHash := randHash(t, 1) + aum := tka.AUM{MessageKind: tka.AUMNoOp, PrevAUMHash: parentHash[:]} + + if err := chonk.CommitVerifiedAUMs([]tka.AUM{aum}); err != nil { + t.Fatal(err) + } + if err := chonk.PurgeAUMs([]tka.AUMHash{aum.Hash()}); err != nil { + t.Fatal(err) + } + + if _, err := chonk.AUM(aum.Hash()); err != os.ErrNotExist { + t.Errorf("AUM() on purged AUM returned err = %v, want ErrNotExist", err) + } + }) + + t.Run("AllAUMs", func(t *testing.T) { + chonk := newChonk(t) + genesis := tka.AUM{MessageKind: tka.AUMRemoveKey, KeyID: []byte{1, 2}} + gHash := genesis.Hash() + intermediate := tka.AUM{PrevAUMHash: gHash[:]} + iHash := intermediate.Hash() + leaf := tka.AUM{PrevAUMHash: iHash[:]} + + commitSet := []tka.AUM{ + genesis, + intermediate, + leaf, + } + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + hashes, err := chonk.AllAUMs() + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff([]tka.AUMHash{genesis.Hash(), intermediate.Hash(), leaf.Hash()}, hashes, cmpopts.SortSlices(hashesLess)); diff != "" { + t.Fatalf("AllAUMs() output differs (-want, +got):\n%s", diff) + } + }) + + t.Run("ChildAUMsOfPurgedAUM", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + parent := tka.AUM{MessageKind: tka.AUMRemoveKey, KeyID: []byte{0, 0}} + + parentHash := parent.Hash() + + child1 := tka.AUM{MessageKind: tka.AUMAddKey, KeyID: []byte{1, 1}, PrevAUMHash: parentHash[:]} + child2 := tka.AUM{MessageKind: tka.AUMAddKey, KeyID: []byte{2, 2}, PrevAUMHash: parentHash[:]} + child3 := tka.AUM{MessageKind: tka.AUMAddKey, KeyID: []byte{3, 3}, PrevAUMHash: parentHash[:]} + + child2Hash := child2.Hash() + grandchild2A := tka.AUM{MessageKind: tka.AUMAddKey, KeyID: []byte{2, 2, 2, 2}, PrevAUMHash: child2Hash[:]} + grandchild2B := tka.AUM{MessageKind: tka.AUMAddKey, KeyID: []byte{2, 2, 2, 2, 2}, PrevAUMHash: child2Hash[:]} + + commitSet := []tka.AUM{parent, child1, child2, child3, grandchild2A, grandchild2B} + + if err := chonk.CommitVerifiedAUMs(commitSet); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + // Check the set of hashes is correct + childHashes := must.Get(chonk.ChildAUMs(parentHash)) + if diff := cmp.Diff([]tka.AUM{child1, child2, child3}, childHashes, cmpopts.SortSlices(aumHashesLess)); diff != "" { + t.Fatalf("ChildAUMs() output differs (-want, +got):\n%s", diff) + } + + // Purge the parent AUM, and check the set of child AUMs is unchanged + chonk.PurgeAUMs([]tka.AUMHash{parent.Hash()}) + + childHashes = must.Get(chonk.ChildAUMs(parentHash)) + if diff := cmp.Diff([]tka.AUM{child1, child2, child3}, childHashes, cmpopts.SortSlices(aumHashesLess)); diff != "" { + t.Fatalf("ChildAUMs() output differs (-want, +got):\n%s", diff) + } + + // Now purge one of the child AUMs, and check it no longer appears as a child of the parent + chonk.PurgeAUMs([]tka.AUMHash{child3.Hash()}) + + childHashes = must.Get(chonk.ChildAUMs(parentHash)) + if diff := cmp.Diff([]tka.AUM{child1, child2}, childHashes, cmpopts.SortSlices(aumHashesLess)); diff != "" { + t.Fatalf("ChildAUMs() output differs (-want, +got):\n%s", diff) + } + }) + + t.Run("RemoveAll", func(t *testing.T) { + t.Parallel() + chonk := newChonk(t) + parentHash := randHash(t, 1) + data := []tka.AUM{ + { + MessageKind: tka.AUMRemoveKey, + KeyID: []byte{1, 2}, + PrevAUMHash: parentHash[:], + }, + { + MessageKind: tka.AUMRemoveKey, + KeyID: []byte{3, 4}, + PrevAUMHash: parentHash[:], + }, + } + + if err := chonk.CommitVerifiedAUMs(data); err != nil { + t.Fatalf("CommitVerifiedAUMs failed: %v", err) + } + + // Check we can retrieve the AUMs we just stored + for _, want := range data { + got, err := chonk.AUM(want.Hash()) + if err != nil { + t.Fatalf("could not get %s: %v", want.Hash(), err) + } + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("stored AUM %s differs (-want, +got):\n%s", want.Hash(), diff) + } + } + + // Call RemoveAll() to drop all the AUM state + if err := chonk.RemoveAll(); err != nil { + t.Fatalf("RemoveAll failed: %v", err) + } + + // Check we can no longer retrieve the previously-stored AUMs + for _, want := range data { + aum, err := chonk.AUM(want.Hash()) + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist for %s, instead got aum=%v, err=%v", want.Hash(), aum, err) + } + } + }) +} diff --git a/tstest/chonktest/tailchonk_test.go b/tstest/chonktest/tailchonk_test.go new file mode 100644 index 000000000..d9343e916 --- /dev/null +++ b/tstest/chonktest/tailchonk_test.go @@ -0,0 +1,59 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package chonktest + +import ( + "testing" + + "tailscale.com/tka" + "tailscale.com/util/must" +) + +func TestImplementsChonk(t *testing.T) { + for _, tt := range []struct { + name string + newChonk func(t *testing.T) tka.Chonk + }{ + { + name: "Mem", + newChonk: func(t *testing.T) tka.Chonk { + return tka.ChonkMem() + }, + }, + { + name: "FS", + newChonk: func(t *testing.T) tka.Chonk { + return must.Get(tka.ChonkDir(t.TempDir())) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + RunChonkTests(t, tt.newChonk) + }) + } +} + +func TestImplementsCompactableChonk(t *testing.T) { + for _, tt := range []struct { + name string + newChonk func(t *testing.T) tka.CompactableChonk + }{ + { + name: "Mem", + newChonk: func(t *testing.T) tka.CompactableChonk { + return tka.ChonkMem() + }, + }, + { + name: "FS", + newChonk: func(t *testing.T) tka.CompactableChonk { + return must.Get(tka.ChonkDir(t.TempDir())) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + RunCompactableChonkTests(t, tt.newChonk) + }) + } +} diff --git a/tstest/clock_test.go b/tstest/clock_test.go index d5816564a..2ebaf752a 100644 --- a/tstest/clock_test.go +++ b/tstest/clock_test.go @@ -56,7 +56,6 @@ func TestClockWithDefinedStartTime(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() clock := NewClock(ClockOpts{ @@ -118,7 +117,6 @@ func TestClockWithDefaultStartTime(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() clock := NewClock(ClockOpts{ @@ -277,7 +275,6 @@ func TestClockSetStep(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() clock := NewClock(ClockOpts{ @@ -426,7 +423,6 @@ func TestClockAdvance(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() clock := NewClock(ClockOpts{ @@ -876,7 +872,6 @@ func TestSingleTicker(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() var realTimeClockForTestClock tstime.Clock @@ -1377,7 +1372,6 @@ func TestSingleTimer(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() var realTimeClockForTestClock tstime.Clock @@ -1911,7 +1905,6 @@ func TestClockFollowRealTime(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() realTimeClock := NewClock(tt.realTimeClockOpts) @@ -2364,7 +2357,6 @@ func TestAfterFunc(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() var realTimeClockForTestClock tstime.Clock @@ -2468,7 +2460,6 @@ func TestSince(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() clock := NewClock(ClockOpts{ diff --git a/tstest/deptest/deptest.go b/tstest/deptest/deptest.go index 57db2b79a..c0b6d8b8c 100644 --- a/tstest/deptest/deptest.go +++ b/tstest/deptest/deptest.go @@ -13,14 +13,23 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "strings" + "sync" "testing" + + "tailscale.com/util/set" ) type DepChecker struct { - GOOS string // optional - GOARCH string // optional - BadDeps map[string]string // package => why + GOOS string // optional + GOARCH string // optional + OnDep func(string) // if non-nil, called per dependency + OnImport func(string) // if non-nil, called per import + BadDeps map[string]string // package => why + WantDeps set.Set[string] // packages expected + Tags string // comma-separated + ExtraEnv []string // extra environment for "go list" (e.g. CGO_ENABLED=1) } func (c DepChecker) Check(t *testing.T) { @@ -29,7 +38,7 @@ func (c DepChecker) Check(t *testing.T) { t.Skip("skipping dep tests on windows hosts") } t.Helper() - cmd := exec.Command("go", "list", "-json", ".") + cmd := exec.Command("go", "list", "-json", "-tags="+c.Tags, ".") var extraEnv []string if c.GOOS != "" { extraEnv = append(extraEnv, "GOOS="+c.GOOS) @@ -37,23 +46,63 @@ func (c DepChecker) Check(t *testing.T) { if c.GOARCH != "" { extraEnv = append(extraEnv, "GOARCH="+c.GOARCH) } + extraEnv = append(extraEnv, c.ExtraEnv...) cmd.Env = append(os.Environ(), extraEnv...) out, err := cmd.Output() if err != nil { t.Fatal(err) } var res struct { - Deps []string + Imports []string + Deps []string } if err := json.Unmarshal(out, &res); err != nil { t.Fatal(err) } + tsRoot := sync.OnceValue(func() string { + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", "tailscale.com").Output() + if err != nil { + t.Fatalf("failed to find tailscale.com root: %v", err) + } + return strings.TrimSpace(string(out)) + }) + + if c.OnImport != nil { + for _, imp := range res.Imports { + c.OnImport(imp) + } + } + for _, dep := range res.Deps { + if c.OnDep != nil { + c.OnDep(dep) + } if why, ok := c.BadDeps[dep]; ok { t.Errorf("package %q is not allowed as a dependency (env: %q); reason: %s", dep, extraEnv, why) } } + // Make sure the BadDeps packages actually exists. If they got renamed or + // moved around, we should update the test referencing the old name. + // Doing this in the general case requires network access at runtime + // (resolving a package path to its module, possibly doing the ?go-get=1 + // meta tag dance), so we just check the common case of + // "tailscale.com/*" packages for now, with the assumption that all + // "tailscale.com/*" packages are in the same module, which isn't + // necessarily true in the general case. + for dep := range c.BadDeps { + if suf, ok := strings.CutPrefix(dep, "tailscale.com/"); ok { + pkgDir := filepath.Join(tsRoot(), suf) + if _, err := os.Stat(pkgDir); err != nil { + t.Errorf("listed BadDep %q doesn't seem to exist anymore: %v", dep, err) + } + } + } + for dep := range c.WantDeps { + if !slices.Contains(res.Deps, dep) { + t.Errorf("expected package %q to be a dependency (env: %q)", dep, extraEnv) + } + } t.Logf("got %d dependencies", len(res.Deps)) } diff --git a/tstest/integration/capmap_test.go b/tstest/integration/capmap_test.go new file mode 100644 index 000000000..0ee05be2f --- /dev/null +++ b/tstest/integration/capmap_test.go @@ -0,0 +1,147 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package integration + +import ( + "errors" + "testing" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" +) + +// TestPeerCapMap tests that the node capability map (CapMap) is included in peer information. +func TestPeerCapMap(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t) + + // Spin up two nodes. + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + n1.AwaitListening() + n1.MustUp() + n1.AwaitRunning() + + n2 := NewTestNode(t, env) + d2 := n2.StartDaemon() + n2.AwaitListening() + n2.MustUp() + n2.AwaitRunning() + + n1.AwaitIP4() + n2.AwaitIP4() + + // Get the nodes from the control server. + nodes := env.Control.AllNodes() + if len(nodes) != 2 { + t.Fatalf("expected 2 nodes, got %d nodes", len(nodes)) + } + + // Figure out which node is which by comparing keys. + st1 := n1.MustStatus() + var tn1, tn2 *tailcfg.Node + for _, n := range nodes { + if n.Key == st1.Self.PublicKey { + tn1 = n + } else { + tn2 = n + } + } + + // Set CapMap on both nodes. + caps := make(tailcfg.NodeCapMap) + caps["example:custom"] = []tailcfg.RawMessage{`"value"`} + caps["example:enabled"] = []tailcfg.RawMessage{`true`} + + env.Control.SetNodeCapMap(tn1.Key, caps) + env.Control.SetNodeCapMap(tn2.Key, caps) + + // Check that nodes see each other's CapMap. + if err := tstest.WaitFor(10*time.Second, func() error { + st1 := n1.MustStatus() + st2 := n2.MustStatus() + + if len(st1.Peer) == 0 || len(st2.Peer) == 0 { + return errors.New("no peers") + } + + // Check n1 sees n2's CapMap. + p1 := st1.Peer[st1.Peers()[0]] + if p1.CapMap == nil { + return errors.New("peer CapMap is nil") + } + if p1.CapMap["example:custom"] == nil || p1.CapMap["example:enabled"] == nil { + return errors.New("peer CapMap missing entries") + } + + // Check n2 sees n1's CapMap. + p2 := st2.Peer[st2.Peers()[0]] + if p2.CapMap == nil { + return errors.New("peer CapMap is nil") + } + if p2.CapMap["example:custom"] == nil || p2.CapMap["example:enabled"] == nil { + return errors.New("peer CapMap missing entries") + } + + return nil + }); err != nil { + t.Fatal(err) + } + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) +} + +// TestSetNodeCapMap tests that SetNodeCapMap updates are propagated to peers. +func TestSetNodeCapMap(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t) + + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + n1.AwaitListening() + n1.MustUp() + n1.AwaitRunning() + + nodes := env.Control.AllNodes() + if len(nodes) != 1 { + t.Fatalf("expected 1 node, got %d nodes", len(nodes)) + } + node1 := nodes[0] + + // Set initial CapMap. + caps := make(tailcfg.NodeCapMap) + caps["test:state"] = []tailcfg.RawMessage{`"initial"`} + env.Control.SetNodeCapMap(node1.Key, caps) + + // Start second node and verify it sees the first node's CapMap. + n2 := NewTestNode(t, env) + d2 := n2.StartDaemon() + n2.AwaitListening() + n2.MustUp() + n2.AwaitRunning() + + if err := tstest.WaitFor(5*time.Second, func() error { + st := n2.MustStatus() + if len(st.Peer) == 0 { + return errors.New("no peers") + } + p := st.Peer[st.Peers()[0]] + if p.CapMap == nil || p.CapMap["test:state"] == nil { + return errors.New("peer CapMap not set") + } + if string(p.CapMap["test:state"][0]) != `"initial"` { + return errors.New("wrong CapMap value") + } + return nil + }); err != nil { + t.Fatal(err) + } + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) +} diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index 36a92759f..6700205cf 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -9,92 +9,206 @@ package integration import ( "bytes" + "context" "crypto/tls" "encoding/json" + "flag" "fmt" "io" "log" "net" "net/http" "net/http/httptest" + "net/netip" "os" "os/exec" "path" "path/filepath" + "regexp" "runtime" + "strconv" "strings" "sync" "testing" "time" "go4.org/mem" - "tailscale.com/derp" - "tailscale.com/derp/derphttp" + "tailscale.com/client/local" + "tailscale.com/derp/derpserver" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/store" "tailscale.com/net/stun/stuntest" + "tailscale.com/safesocket" + "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/nettype" + "tailscale.com/util/rands" "tailscale.com/util/zstdframe" "tailscale.com/version" ) -// CleanupBinaries cleans up any resources created by calls to BinaryDir, TailscaleBinary, or TailscaledBinary. -// It should be called from TestMain after all tests have completed. -func CleanupBinaries() { - buildOnce.Do(func() {}) - if binDir != "" { - os.RemoveAll(binDir) +var ( + verboseTailscaled = flag.Bool("verbose-tailscaled", false, "verbose tailscaled logging") + verboseTailscale = flag.Bool("verbose-tailscale", false, "verbose tailscale CLI logging") +) + +// MainError is an error that's set if an error conditions happens outside of a +// context where a testing.TB is available. The caller can check it in its TestMain +// as a last ditch place to report errors. +var MainError syncs.AtomicValue[error] + +// Binaries contains the paths to the tailscale and tailscaled binaries. +type Binaries struct { + Dir string + Tailscale BinaryInfo + Tailscaled BinaryInfo +} + +// BinaryInfo describes a tailscale or tailscaled binary. +type BinaryInfo struct { + Path string // abs path to tailscale or tailscaled binary + Size int64 + + // FD and FDmu are set on Unix to efficiently copy the binary to a new + // test's automatically-cleaned-up temp directory. + FD *os.File // for Unix (macOS, Linux, ...) + FDMu sync.Locker + + // Contents is used on Windows instead of FD to copy the binary between + // test directories. (On Windows you can't keep an FD open while an earlier + // test's temp directories are deleted.) + // This burns some memory and costs more in I/O, but oh well. + Contents []byte +} + +func (b BinaryInfo) CopyTo(dir string) (BinaryInfo, error) { + ret := b + ret.Path = filepath.Join(dir, path.Base(b.Path)) + + switch runtime.GOOS { + case "linux": + // TODO(bradfitz): be fancy and use linkat with AT_EMPTY_PATH to avoid + // copying? I couldn't get it to work, though. + // For now, just do the same thing as every other Unix and copy + // the binary. + fallthrough + case "darwin", "freebsd", "openbsd", "netbsd": + f, err := os.OpenFile(ret.Path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o755) + if err != nil { + return BinaryInfo{}, err + } + b.FDMu.Lock() + b.FD.Seek(0, 0) + size, err := io.Copy(f, b.FD) + b.FDMu.Unlock() + if err != nil { + f.Close() + return BinaryInfo{}, fmt.Errorf("copying %q: %w", b.Path, err) + } + if size != b.Size { + f.Close() + return BinaryInfo{}, fmt.Errorf("copy %q: size mismatch: %d != %d", b.Path, size, b.Size) + } + if err := f.Close(); err != nil { + return BinaryInfo{}, err + } + return ret, nil + case "windows": + return ret, os.WriteFile(ret.Path, b.Contents, 0o755) + default: + return BinaryInfo{}, fmt.Errorf("unsupported OS %q", runtime.GOOS) } } -// BinaryDir returns a directory containing test tailscale and tailscaled binaries. -// If any test calls BinaryDir, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func BinaryDir(tb testing.TB) string { +// GetBinaries create a temp directory using tb and builds (or copies previously +// built) cmd/tailscale and cmd/tailscaled binaries into that directory. +// +// It fails tb if the build or binary copies fail. +func GetBinaries(tb testing.TB) *Binaries { + dir := tb.TempDir() buildOnce.Do(func() { - binDir, buildErr = buildTestBinaries() + buildErr = buildTestBinaries(dir) }) if buildErr != nil { tb.Fatal(buildErr) } - return binDir -} - -// TailscaleBinary returns the path to the test tailscale binary. -// If any test calls TailscaleBinary, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func TailscaleBinary(tb testing.TB) string { - return filepath.Join(BinaryDir(tb), "tailscale"+exe()) -} - -// TailscaledBinary returns the path to the test tailscaled binary. -// If any test calls TailscaleBinary, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func TailscaledBinary(tb testing.TB) string { - return filepath.Join(BinaryDir(tb), "tailscaled"+exe()) + if binariesCache.Dir == dir { + return binariesCache + } + ts, err := binariesCache.Tailscale.CopyTo(dir) + if err != nil { + tb.Fatalf("copying tailscale binary: %v", err) + } + tsd, err := binariesCache.Tailscaled.CopyTo(dir) + if err != nil { + tb.Fatalf("copying tailscaled binary: %v", err) + } + return &Binaries{ + Dir: dir, + Tailscale: ts, + Tailscaled: tsd, + } } var ( - buildOnce sync.Once - buildErr error - binDir string + buildOnce sync.Once + buildErr error + binariesCache *Binaries ) // buildTestBinaries builds tailscale and tailscaled. -// It returns the dir containing the binaries. -func buildTestBinaries() (string, error) { - bindir, err := os.MkdirTemp("", "") +// On success, it initializes [binariesCache]. +func buildTestBinaries(dir string) error { + getBinaryInfo := func(name string) (BinaryInfo, error) { + bi := BinaryInfo{Path: filepath.Join(dir, name+exe())} + fi, err := os.Stat(bi.Path) + if err != nil { + return BinaryInfo{}, fmt.Errorf("stat %q: %v", bi.Path, err) + } + bi.Size = fi.Size() + + switch runtime.GOOS { + case "windows": + bi.Contents, err = os.ReadFile(bi.Path) + if err != nil { + return BinaryInfo{}, fmt.Errorf("read %q: %v", bi.Path, err) + } + default: + bi.FD, err = os.OpenFile(bi.Path, os.O_RDONLY, 0) + if err != nil { + return BinaryInfo{}, fmt.Errorf("open %q: %v", bi.Path, err) + } + bi.FDMu = new(sync.Mutex) + // Note: bi.FD is copied around between tests but never closed, by + // design. It will be closed when the process exits, and that will + // close the inode that we're copying the bytes from for each test. + } + return bi, nil + } + err := build(dir, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") + if err != nil { + return err + } + b := &Binaries{ + Dir: dir, + } + b.Tailscale, err = getBinaryInfo("tailscale") if err != nil { - return "", err + return err } - err = build(bindir, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") + b.Tailscaled, err = getBinaryInfo("tailscaled") if err != nil { - os.RemoveAll(bindir) - return "", err + return err } - return bindir, nil + binariesCache = b + return nil } func build(outDir string, targets ...string) error { @@ -182,14 +296,14 @@ func exe() string { func RunDERPAndSTUN(t testing.TB, logf logger.Logf, ipAddress string) (derpMap *tailcfg.DERPMap) { t.Helper() - d := derp.NewServer(key.NewNode(), logf) + d := derpserver.New(key.NewNode(), logf) ln, err := net.Listen("tcp", net.JoinHostPort(ipAddress, "0")) if err != nil { t.Fatal(err) } - httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d)) + httpsrv := httptest.NewUnstartedServer(derpserver.Handler(d)) httpsrv.Listener.Close() httpsrv.Listener = ln httpsrv.Config.ErrorLog = logger.StdLogger(logf) @@ -361,3 +475,667 @@ func (lc *LogCatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(200) // must have no content, but not a 204 } + +// TestEnv contains the test environment (set of servers) used by one +// or more nodes. +type TestEnv struct { + t testing.TB + tunMode bool + cli string + daemon string + loopbackPort *int + neverDirectUDP bool + relayServerUseLoopback bool + + LogCatcher *LogCatcher + LogCatcherServer *httptest.Server + + Control *testcontrol.Server + ControlServer *httptest.Server + + TrafficTrap *trafficTrap + TrafficTrapServer *httptest.Server +} + +// ControlURL returns e.ControlServer.URL, panicking if it's the empty string, +// which it should never be in tests. +func (e *TestEnv) ControlURL() string { + s := e.ControlServer.URL + if s == "" { + panic("control server not set") + } + return s +} + +// TestEnvOpt represents an option that can be passed to NewTestEnv. +type TestEnvOpt interface { + ModifyTestEnv(*TestEnv) +} + +// ConfigureControl is a test option that configures the test control server. +type ConfigureControl func(*testcontrol.Server) + +func (f ConfigureControl) ModifyTestEnv(te *TestEnv) { + f(te.Control) +} + +// NewTestEnv starts a bunch of services and returns a new test environment. +// NewTestEnv arranges for the environment's resources to be cleaned up on exit. +func NewTestEnv(t testing.TB, opts ...TestEnvOpt) *TestEnv { + if runtime.GOOS == "windows" { + t.Skip("not tested/working on Windows yet") + } + derpMap := RunDERPAndSTUN(t, logger.Discard, "127.0.0.1") + logc := new(LogCatcher) + control := &testcontrol.Server{ + Logf: logger.WithPrefix(t.Logf, "testcontrol: "), + DERPMap: derpMap, + } + control.HTTPTestServer = httptest.NewUnstartedServer(control) + trafficTrap := new(trafficTrap) + binaries := GetBinaries(t) + e := &TestEnv{ + t: t, + cli: binaries.Tailscale.Path, + daemon: binaries.Tailscaled.Path, + LogCatcher: logc, + LogCatcherServer: httptest.NewServer(logc), + Control: control, + ControlServer: control.HTTPTestServer, + TrafficTrap: trafficTrap, + TrafficTrapServer: httptest.NewServer(trafficTrap), + } + for _, o := range opts { + o.ModifyTestEnv(e) + } + control.HTTPTestServer.Start() + t.Cleanup(func() { + // Shut down e. + if err := e.TrafficTrap.Err(); err != nil { + e.t.Errorf("traffic trap: %v", err) + e.t.Logf("logs: %s", e.LogCatcher.logsString()) + } + e.LogCatcherServer.Close() + e.TrafficTrapServer.Close() + e.ControlServer.Close() + }) + t.Logf("control URL: %v", e.ControlURL()) + return e +} + +// TestNode is a machine with a tailscale & tailscaled. +// Currently, the test is simplistic and user==node==machine. +// That may grow complexity later to test more. +type TestNode struct { + env *TestEnv + tailscaledParser *nodeOutputParser + + dir string // temp dir for sock & state + configFile string // or empty for none + sockFile string + stateFile string + upFlagGOOS string // if non-empty, sets TS_DEBUG_UP_FLAG_GOOS for cmd/tailscale CLI + encryptState bool + + mu sync.Mutex + onLogLine []func([]byte) + lc *local.Client +} + +// NewTestNode allocates a temp directory for a new test node. +// The node is not started automatically. +func NewTestNode(t *testing.T, env *TestEnv) *TestNode { + dir := t.TempDir() + sockFile := filepath.Join(dir, "tailscale.sock") + if len(sockFile) >= 104 { + // Maximum length for a unix socket on darwin. Try something else. + sockFile = filepath.Join(os.TempDir(), rands.HexString(8)+".sock") + t.Cleanup(func() { os.Remove(sockFile) }) + } + n := &TestNode{ + env: env, + dir: dir, + sockFile: sockFile, + stateFile: filepath.Join(dir, "tailscaled.state"), // matches what cmd/tailscaled uses + } + + // Look for a data race or panic. + // Once we see the start marker, start logging the rest. + var sawRace bool + var sawPanic bool + n.addLogLineHook(func(line []byte) { + lineB := mem.B(line) + if mem.Contains(lineB, mem.S("DEBUG-ADDR=")) { + t.Log(strings.TrimSpace(string(line))) + } + if mem.Contains(lineB, mem.S("WARNING: DATA RACE")) { + sawRace = true + } + if mem.HasPrefix(lineB, mem.S("panic: ")) { + sawPanic = true + } + if sawRace || sawPanic { + t.Logf("%s", line) + } + }) + + return n +} + +func (n *TestNode) LocalClient() *local.Client { + n.mu.Lock() + defer n.mu.Unlock() + if n.lc == nil { + tr := &http.Transport{} + n.lc = &local.Client{ + Socket: n.sockFile, + UseSocketOnly: true, + } + n.env.t.Cleanup(tr.CloseIdleConnections) + } + return n.lc +} + +func (n *TestNode) diskPrefs() *ipn.Prefs { + t := n.env.t + t.Helper() + if _, err := os.ReadFile(n.stateFile); err != nil { + t.Fatalf("reading prefs: %v", err) + } + fs, err := store.New(nil, n.stateFile) + if err != nil { + t.Fatalf("reading prefs, NewFileStore: %v", err) + } + p, err := ipnlocal.ReadStartupPrefsForTest(t.Logf, fs) + if err != nil { + t.Fatalf("reading prefs, ReadDiskPrefsForTest: %v", err) + } + return p.AsStruct() +} + +// AwaitResponding waits for n's tailscaled to be up enough to be +// responding, but doesn't wait for any particular state. +func (n *TestNode) AwaitResponding() { + t := n.env.t + t.Helper() + n.AwaitListening() + + st := n.MustStatus() + t.Logf("Status: %s", st.BackendState) + + if err := tstest.WaitFor(20*time.Second, func() error { + const sub = `Program starting: ` + if !n.env.LogCatcher.logsContains(mem.S(sub)) { + return fmt.Errorf("log catcher didn't see %#q; got %s", sub, n.env.LogCatcher.logsString()) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// addLogLineHook registers a hook f to be called on each tailscaled +// log line output. +func (n *TestNode) addLogLineHook(f func([]byte)) { + n.mu.Lock() + defer n.mu.Unlock() + n.onLogLine = append(n.onLogLine, f) +} + +// socks5AddrChan returns a channel that receives the address (e.g. "localhost:23874") +// of the node's SOCKS5 listener, once started. +func (n *TestNode) socks5AddrChan() <-chan string { + ch := make(chan string, 1) + n.addLogLineHook(func(line []byte) { + const sub = "SOCKS5 listening on " + i := mem.Index(mem.B(line), mem.S(sub)) + if i == -1 { + return + } + addr := strings.TrimSpace(string(line)[i+len(sub):]) + select { + case ch <- addr: + default: + } + }) + return ch +} + +func (n *TestNode) AwaitSocksAddr(ch <-chan string) string { + t := n.env.t + t.Helper() + timer := time.NewTimer(10 * time.Second) + defer timer.Stop() + select { + case v := <-ch: + return v + case <-timer.C: + t.Fatal("timeout waiting for node to log its SOCK5 listening address") + panic("unreachable") + } +} + +// nodeOutputParser parses stderr of tailscaled processes, calling the +// per-line callbacks previously registered via +// testNode.addLogLineHook. +type nodeOutputParser struct { + allBuf bytes.Buffer + pendLineBuf bytes.Buffer + n *TestNode +} + +func (op *nodeOutputParser) Write(p []byte) (n int, err error) { + tn := op.n + tn.mu.Lock() + defer tn.mu.Unlock() + + op.allBuf.Write(p) + n, err = op.pendLineBuf.Write(p) + op.parseLinesLocked() + return +} + +func (op *nodeOutputParser) parseLinesLocked() { + n := op.n + buf := op.pendLineBuf.Bytes() + for len(buf) > 0 { + nl := bytes.IndexByte(buf, '\n') + if nl == -1 { + break + } + line := buf[:nl+1] + buf = buf[nl+1:] + + for _, f := range n.onLogLine { + f(line) + } + } + if len(buf) == 0 { + op.pendLineBuf.Reset() + } else { + io.CopyN(io.Discard, &op.pendLineBuf, int64(op.pendLineBuf.Len()-len(buf))) + } +} + +type Daemon struct { + Process *os.Process +} + +func (d *Daemon) MustCleanShutdown(t testing.TB) { + d.Process.Signal(os.Interrupt) + ps, err := d.Process.Wait() + if err != nil { + t.Fatalf("tailscaled Wait: %v", err) + } + if ps.ExitCode() != 0 { + t.Errorf("tailscaled ExitCode = %d; want 0", ps.ExitCode()) + } +} + +// awaitTailscaledRunnable tries to run `tailscaled --version` until it +// works. This is an unsatisfying workaround for ETXTBSY we were seeing +// on GitHub Actions that aren't understood. It's not clear what's holding +// a writable fd to tailscaled after `go install` completes. +// See https://github.com/tailscale/tailscale/issues/15868. +func (n *TestNode) awaitTailscaledRunnable() error { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(10*time.Second, func() error { + out, err := exec.Command(n.env.daemon, "--version").CombinedOutput() + if err == nil { + return nil + } + t.Logf("error running tailscaled --version: %v, %s", err, out) + return err + }); err != nil { + return fmt.Errorf("gave up trying to run tailscaled: %v", err) + } + return nil +} + +// StartDaemon starts the node's tailscaled, failing if it fails to start. +// StartDaemon ensures that the process will exit when the test completes. +func (n *TestNode) StartDaemon() *Daemon { + return n.StartDaemonAsIPNGOOS(runtime.GOOS) +} + +func (n *TestNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { + t := n.env.t + + if err := n.awaitTailscaledRunnable(); err != nil { + t.Fatalf("awaitTailscaledRunnable: %v", err) + } + + cmd := exec.Command(n.env.daemon) + cmd.Args = append(cmd.Args, + "--statedir="+n.dir, + "--socket="+n.sockFile, + "--socks5-server=localhost:0", + "--debug=localhost:0", + ) + if *verboseTailscaled { + cmd.Args = append(cmd.Args, "-verbose=2") + } + if !n.env.tunMode { + cmd.Args = append(cmd.Args, + "--tun=userspace-networking", + ) + } + if n.configFile != "" { + cmd.Args = append(cmd.Args, "--config="+n.configFile) + } + if n.encryptState { + cmd.Args = append(cmd.Args, "--encrypt-state") + } + cmd.Env = append(os.Environ(), + "TS_DEBUG_PERMIT_HTTP_C2N=1", + "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, + "HTTP_PROXY="+n.env.TrafficTrapServer.URL, + "HTTPS_PROXY="+n.env.TrafficTrapServer.URL, + "TS_DEBUG_FAKE_GOOS="+ipnGOOS, + "TS_LOGS_DIR="+t.TempDir(), + "TS_NETCHECK_GENERATE_204_URL="+n.env.ControlServer.URL+"/generate_204", + "TS_ASSUME_NETWORK_UP_FOR_TEST=1", // don't pause control client in airplane mode (no wifi, etc) + "TS_PANIC_IF_HIT_MAIN_CONTROL=1", + "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost + "TS_DEBUG_LOG_RATE=all", + ) + if n.env.loopbackPort != nil { + cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) + } + if n.env.neverDirectUDP { + cmd.Env = append(cmd.Env, "TS_DEBUG_NEVER_DIRECT_UDP=1") + } + if n.env.relayServerUseLoopback { + cmd.Env = append(cmd.Env, "TS_DEBUG_RELAY_SERVER_ADDRS=::1,127.0.0.1") + } + if version.IsRace() { + cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") + } + n.tailscaledParser = &nodeOutputParser{n: n} + cmd.Stderr = n.tailscaledParser + if *verboseTailscaled { + cmd.Stdout = os.Stdout + cmd.Stderr = io.MultiWriter(cmd.Stderr, os.Stderr) + } + if runtime.GOOS != "windows" { + pr, pw, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { pw.Close() }) + cmd.ExtraFiles = append(cmd.ExtraFiles, pr) + cmd.Env = append(cmd.Env, "TS_PARENT_DEATH_FD=3") + } + if err := cmd.Start(); err != nil { + t.Fatalf("starting tailscaled: %v", err) + } + t.Cleanup(func() { cmd.Process.Kill() }) + return &Daemon{ + Process: cmd.Process, + } +} + +func (n *TestNode) MustUp(extraArgs ...string) { + t := n.env.t + t.Helper() + args := []string{ + "up", + "--login-server=" + n.env.ControlURL(), + "--reset", + } + args = append(args, extraArgs...) + cmd := n.Tailscale(args...) + t.Logf("Running %v ...", cmd) + cmd.Stdout = nil // in case --verbose-tailscale was set + cmd.Stderr = nil // in case --verbose-tailscale was set + if b, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("up: %v, %v", string(b), err) + } +} + +func (n *TestNode) MustDown() { + t := n.env.t + t.Logf("Running down ...") + if err := n.Tailscale("down", "--accept-risk=all").Run(); err != nil { + t.Fatalf("down: %v", err) + } +} + +func (n *TestNode) MustLogOut() { + t := n.env.t + t.Logf("Running logout ...") + if err := n.Tailscale("logout").Run(); err != nil { + t.Fatalf("logout: %v", err) + } +} + +func (n *TestNode) Ping(otherNode *TestNode) error { + t := n.env.t + ip := otherNode.AwaitIP4().String() + t.Logf("Running ping %v (from %v)...", ip, n.AwaitIP4()) + return n.Tailscale("ping", ip).Run() +} + +// AwaitListening waits for the tailscaled to be serving local clients +// over its localhost IPC mechanism. (Unix socket, etc) +func (n *TestNode) AwaitListening() { + t := n.env.t + if err := tstest.WaitFor(20*time.Second, func() (err error) { + c, err := safesocket.ConnectContext(context.Background(), n.sockFile) + if err == nil { + c.Close() + } + return err + }); err != nil { + t.Fatal(err) + } +} + +func (n *TestNode) AwaitIPs() []netip.Addr { + t := n.env.t + t.Helper() + var addrs []netip.Addr + if err := tstest.WaitFor(20*time.Second, func() error { + cmd := n.Tailscale("ip") + cmd.Stdout = nil // in case --verbose-tailscale was set + cmd.Stderr = nil // in case --verbose-tailscale was set + out, err := cmd.Output() + if err != nil { + return err + } + ips := string(out) + ipslice := strings.Fields(ips) + addrs = make([]netip.Addr, len(ipslice)) + + for i, ip := range ipslice { + netIP, err := netip.ParseAddr(ip) + if err != nil { + t.Fatal(err) + } + addrs[i] = netIP + } + return nil + }); err != nil { + t.Fatalf("awaiting an IP address: %v", err) + } + if len(addrs) == 0 { + t.Fatalf("returned IP address was blank") + } + return addrs +} + +// AwaitIP4 returns the IPv4 address of n. +func (n *TestNode) AwaitIP4() netip.Addr { + t := n.env.t + t.Helper() + ips := n.AwaitIPs() + return ips[0] +} + +// AwaitIP6 returns the IPv6 address of n. +func (n *TestNode) AwaitIP6() netip.Addr { + t := n.env.t + t.Helper() + ips := n.AwaitIPs() + return ips[1] +} + +// AwaitRunning waits for n to reach the IPN state "Running". +func (n *TestNode) AwaitRunning() { + t := n.env.t + t.Helper() + n.AwaitBackendState("Running") +} + +func (n *TestNode) AwaitBackendState(state string) { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(20*time.Second, func() error { + st, err := n.Status() + if err != nil { + return err + } + if st.BackendState != state { + return fmt.Errorf("in state %q; want %q", st.BackendState, state) + } + return nil + }); err != nil { + t.Fatalf("failure/timeout waiting for transition to Running status: %v", err) + } +} + +// AwaitNeedsLogin waits for n to reach the IPN state "NeedsLogin". +func (n *TestNode) AwaitNeedsLogin() { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(20*time.Second, func() error { + st, err := n.Status() + if err != nil { + return err + } + if st.BackendState != "NeedsLogin" { + return fmt.Errorf("in state %q", st.BackendState) + } + return nil + }); err != nil { + t.Fatalf("failure/timeout waiting for transition to NeedsLogin status: %v", err) + } +} + +func (n *TestNode) TailscaleForOutput(arg ...string) *exec.Cmd { + cmd := n.Tailscale(arg...) + cmd.Stdout = nil + cmd.Stderr = nil + return cmd +} + +// Tailscale returns a command that runs the tailscale CLI with the provided arguments. +// It does not start the process. +func (n *TestNode) Tailscale(arg ...string) *exec.Cmd { + cmd := exec.Command(n.env.cli) + cmd.Args = append(cmd.Args, "--socket="+n.sockFile) + cmd.Args = append(cmd.Args, arg...) + cmd.Dir = n.dir + cmd.Env = append(os.Environ(), + "TS_DEBUG_UP_FLAG_GOOS="+n.upFlagGOOS, + "TS_LOGS_DIR="+n.env.t.TempDir(), + ) + if *verboseTailscale { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + } + return cmd +} + +func (n *TestNode) Status() (*ipnstate.Status, error) { + cmd := n.Tailscale("status", "--json") + cmd.Stdout = nil // in case --verbose-tailscale was set + cmd.Stderr = nil // in case --verbose-tailscale was set + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("running tailscale status: %v, %s", err, out) + } + st := new(ipnstate.Status) + if err := json.Unmarshal(out, st); err != nil { + return nil, fmt.Errorf("decoding tailscale status JSON: %w\njson:\n%s", err, out) + } + return st, nil +} + +func (n *TestNode) MustStatus() *ipnstate.Status { + tb := n.env.t + tb.Helper() + st, err := n.Status() + if err != nil { + tb.Fatal(err) + } + return st +} + +// trafficTrap is an HTTP proxy handler to note whether any +// HTTP traffic tries to leave localhost from tailscaled. We don't +// expect any, so any request triggers a failure. +type trafficTrap struct { + atomicErr syncs.AtomicValue[error] +} + +func (tt *trafficTrap) Err() error { + return tt.atomicErr.Load() +} + +func (tt *trafficTrap) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var got bytes.Buffer + r.Write(&got) + err := fmt.Errorf("unexpected HTTP request via proxy: %s", got.Bytes()) + MainError.Store(err) + if tt.Err() == nil { + // Best effort at remembering the first request. + tt.atomicErr.Store(err) + } + log.Printf("Error: %v", err) + w.WriteHeader(403) +} + +type authURLParserWriter struct { + t *testing.T + buf bytes.Buffer + // Handle login URLs, and count how many times they were seen + authURLFn func(urlStr string) error + // Handle machine approval URLs, and count how many times they were seen. + deviceApprovalURLFn func(urlStr string) error +} + +// Note: auth URLs from testcontrol look slightly different to real auth URLs, +// e.g. http://127.0.0.1:60456/auth/96af2ff7e04ae1499a9a +var authURLRx = regexp.MustCompile(`(https?://\S+/auth/\S+)`) + +// Looks for any device approval URL, which is any URL ending with `/admin` +// e.g. http://127.0.0.1:60456/admin +var deviceApprovalURLRx = regexp.MustCompile(`(https?://\S+/admin)[^\S]`) + +func (w *authURLParserWriter) Write(p []byte) (n int, err error) { + w.t.Helper() + w.t.Logf("received bytes: %s", string(p)) + n, err = w.buf.Write(p) + + defer w.buf.Reset() // so it's not matched again + + m := authURLRx.FindSubmatch(w.buf.Bytes()) + if m != nil { + urlStr := string(m[1]) + if err := w.authURLFn(urlStr); err != nil { + return 0, err + } + } + + m = deviceApprovalURLRx.FindSubmatch(w.buf.Bytes()) + if m != nil && w.deviceApprovalURLFn != nil { + urlStr := string(m[1]) + if err := w.deviceApprovalURLFn(urlStr); err != nil { + return 0, err + } + } + + return n, err +} diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 70c5d68c3..9d75cfc29 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -13,7 +13,6 @@ import ( "flag" "fmt" "io" - "log" "net" "net/http" "net/http/httptest" @@ -25,54 +24,43 @@ import ( "runtime" "strconv" "strings" - "sync" "sync/atomic" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/miekg/dns" "go4.org/mem" + "tailscale.com/client/local" "tailscale.com/client/tailscale" - "tailscale.com/clientupdate" "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/feature" + _ "tailscale.com/feature/clientupdate" + "tailscale.com/hostinfo" "tailscale.com/ipn" - "tailscale.com/ipn/ipnlocal" - "tailscale.com/ipn/ipnstate" - "tailscale.com/ipn/store" "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" - "tailscale.com/safesocket" - "tailscale.com/syncs" + "tailscale.com/net/udprelay/status" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" - "tailscale.com/types/logger" + "tailscale.com/types/netmap" "tailscale.com/types/opt" "tailscale.com/types/ptr" - "tailscale.com/util/dnsname" "tailscale.com/util/must" - "tailscale.com/util/rands" - "tailscale.com/version" + "tailscale.com/util/set" ) -var ( - verboseTailscaled = flag.Bool("verbose-tailscaled", false, "verbose tailscaled logging") - verboseTailscale = flag.Bool("verbose-tailscale", false, "verbose tailscale CLI logging") -) - -var mainError syncs.AtomicValue[error] - func TestMain(m *testing.M) { // Have to disable UPnP which hits the network, otherwise it fails due to HTTP proxy. os.Setenv("TS_DISABLE_UPNP", "true") flag.Parse() v := m.Run() - CleanupBinaries() if v != 0 { os.Exit(v) } - if err := mainError.Load(); err != nil { + if err := MainError.Load(); err != nil { fmt.Fprintf(os.Stderr, "FAIL: %v\n", err) os.Exit(1) } @@ -87,9 +75,9 @@ func TestTUNMode(t *testing.T) { t.Skip("skipping when not root") } tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -104,8 +92,8 @@ func TestTUNMode(t *testing.T) { func TestOneNodeUpNoAuth(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -122,8 +110,8 @@ func TestOneNodeUpNoAuth(t *testing.T) { func TestOneNodeExpiredKey(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -159,8 +147,8 @@ func TestOneNodeExpiredKey(t *testing.T) { func TestControlKnobs(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() defer d1.MustCleanShutdown(t) @@ -187,11 +175,34 @@ func TestControlKnobs(t *testing.T) { } } +func TestExpectedFeaturesLinked(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) + + d1 := n1.StartDaemon() + n1.AwaitResponding() + lc := n1.LocalClient() + got, err := lc.QueryOptionalFeatures(t.Context()) + if err != nil { + t.Fatal(err) + } + if !got.Features["portmapper"] { + t.Errorf("optional feature portmapper unexpectedly not found: got %v", got.Features) + } + + d1.MustCleanShutdown(t) + + t.Logf("number of HTTP logcatcher requests: %v", env.LogCatcher.numRequests()) +} + func TestCollectPanic(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15865") tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n := newTestNode(t, env) + env := NewTestEnv(t) + n := NewTestNode(t, env) cmd := exec.Command(env.daemon, "--cleanup") cmd.Env = append(os.Environ(), @@ -221,9 +232,9 @@ func TestCollectPanic(t *testing.T) { func TestControlTimeLogLine(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) env.LogCatcher.StoreRawJSON() - n := newTestNode(t, env) + n := NewTestNode(t, env) n.StartDaemon() n.AwaitResponding() @@ -245,8 +256,8 @@ func TestControlTimeLogLine(t *testing.T) { func TestStateSavedOnStart(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -264,7 +275,7 @@ func TestStateSavedOnStart(t *testing.T) { n1.MustDown() // And change the hostname to something: - if err := n1.Tailscale("up", "--login-server="+n1.env.controlURL(), "--hostname=foo").Run(); err != nil { + if err := n1.Tailscale("up", "--login-server="+n1.env.ControlURL(), "--hostname=foo").Run(); err != nil { t.Fatalf("up: %v", err) } @@ -279,48 +290,426 @@ func TestStateSavedOnStart(t *testing.T) { d1.MustCleanShutdown(t) } +// This handler receives auth URLs, and logs into control. +// +// It counts how many URLs it sees, and will fail the test if it +// sees multiple login URLs. +func completeLogin(t *testing.T, control *testcontrol.Server, counter *atomic.Int32) func(string) error { + return func(urlStr string) error { + t.Logf("saw auth URL %q", urlStr) + if control.CompleteAuth(urlStr) { + if counter.Add(1) > 1 { + err := errors.New("completed multiple auth URLs") + t.Error(err) + return err + } + t.Logf("completed login to %s", urlStr) + return nil + } else { + err := fmt.Errorf("failed to complete initial login to %q", urlStr) + t.Fatal(err) + return err + } + } +} + +// This handler receives device approval URLs, and approves the device. +// +// It counts how many URLs it sees, and will fail the test if it +// sees multiple device approval URLs, or if you try to approve a device +// with the wrong control server. +func completeDeviceApproval(t *testing.T, node *TestNode, counter *atomic.Int32) func(string) error { + return func(urlStr string) error { + control := node.env.Control + nodeKey := node.MustStatus().Self.PublicKey + t.Logf("saw device approval URL %q", urlStr) + if control.CompleteDeviceApproval(node.env.ControlURL(), urlStr, &nodeKey) { + if counter.Add(1) > 1 { + err := errors.New("completed multiple device approval URLs") + t.Error(err) + return err + } + t.Log("completed device approval") + return nil + } else { + err := errors.New("failed to complete device approval") + t.Fatal(err) + return err + } + } +} + func TestOneNodeUpAuth(t *testing.T) { - tstest.Shard(t) - tstest.Parallel(t) - env := newTestEnv(t, configureControl(func(control *testcontrol.Server) { - control.RequireAuth = true - })) + type step struct { + args []string + // + // Do we expect to log in again with a new /auth/ URL? + wantAuthURL bool + // + // Do we expect to need a device approval URL? + wantDeviceApprovalURL bool + } + + for _, tt := range []struct { + name string + args []string + // + // What auth key should we use for control? + authKey string + // + // Do we require device approval in the tailnet? + requireDeviceApproval bool + // + // What CLI commands should we run in this test? + steps []step + }{ + { + name: "up", + steps: []step{ + {args: []string{"up"}, wantAuthURL: true}, + }, + }, + { + name: "up-with-machine-auth", + steps: []step{ + {args: []string{"up"}, wantAuthURL: true, wantDeviceApprovalURL: true}, + }, + requireDeviceApproval: true, + }, + { + name: "up-with-force-reauth", + steps: []step{ + {args: []string{"up", "--force-reauth"}, wantAuthURL: true}, + }, + }, + { + name: "up-with-auth-key", + authKey: "opensesame", + steps: []step{ + {args: []string{"up", "--auth-key=opensesame"}}, + }, + }, + { + name: "up-with-auth-key-with-machine-auth", + authKey: "opensesame", + steps: []step{ + { + args: []string{"up", "--auth-key=opensesame"}, + wantAuthURL: false, + wantDeviceApprovalURL: true, + }, + }, + requireDeviceApproval: true, + }, + { + name: "up-with-force-reauth-and-auth-key", + authKey: "opensesame", + steps: []step{ + {args: []string{"up", "--force-reauth", "--auth-key=opensesame"}}, + }, + }, + { + name: "up-after-login", + steps: []step{ + {args: []string{"up"}, wantAuthURL: true}, + {args: []string{"up"}, wantAuthURL: false}, + }, + }, + { + name: "up-after-login-with-machine-approval", + steps: []step{ + {args: []string{"up"}, wantAuthURL: true, wantDeviceApprovalURL: true}, + {args: []string{"up"}, wantAuthURL: false, wantDeviceApprovalURL: false}, + }, + requireDeviceApproval: true, + }, + { + name: "up-with-force-reauth-after-login", + steps: []step{ + {args: []string{"up"}, wantAuthURL: true}, + {args: []string{"up", "--force-reauth"}, wantAuthURL: true}, + }, + }, + { + name: "up-with-force-reauth-after-login-with-machine-approval", + steps: []step{ + {args: []string{"up"}, wantAuthURL: true, wantDeviceApprovalURL: true}, + {args: []string{"up", "--force-reauth"}, wantAuthURL: true, wantDeviceApprovalURL: false}, + }, + requireDeviceApproval: true, + }, + { + name: "up-with-auth-key-after-login", + authKey: "opensesame", + steps: []step{ + {args: []string{"up", "--auth-key=opensesame"}}, + {args: []string{"up", "--auth-key=opensesame"}}, + }, + }, + { + name: "up-with-force-reauth-and-auth-key-after-login", + authKey: "opensesame", + steps: []step{ + {args: []string{"up", "--auth-key=opensesame"}}, + {args: []string{"up", "--force-reauth", "--auth-key=opensesame"}}, + }, + }, + } { + tstest.Shard(t) - n1 := newTestNode(t, env) - d1 := n1.StartDaemon() + for _, useSeamlessKeyRenewal := range []bool{true, false} { + name := tt.name + if useSeamlessKeyRenewal { + name += "-with-seamless" + } + t.Run(name, func(t *testing.T) { + tstest.Parallel(t) + + env := NewTestEnv(t, ConfigureControl( + func(control *testcontrol.Server) { + if tt.authKey != "" { + control.RequireAuthKey = tt.authKey + } else { + control.RequireAuth = true + } + + if tt.requireDeviceApproval { + control.RequireMachineAuth = true + } + + control.AllNodesSameUser = true + + if useSeamlessKeyRenewal { + control.DefaultNodeCapabilities = &tailcfg.NodeCapMap{ + tailcfg.NodeAttrSeamlessKeyRenewal: []tailcfg.RawMessage{}, + } + } + }, + )) - n1.AwaitListening() + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + defer d1.MustCleanShutdown(t) - st := n1.MustStatus() - t.Logf("Status: %s", st.BackendState) + for i, step := range tt.steps { + t.Logf("Running step %d", i) + cmdArgs := append(step.args, "--login-server="+env.ControlURL()) - t.Logf("Running up --login-server=%s ...", env.controlURL()) + t.Logf("Running command: %s", strings.Join(cmdArgs, " ")) - cmd := n1.Tailscale("up", "--login-server="+env.controlURL()) - var authCountAtomic int32 - cmd.Stdout = &authURLParserWriter{fn: func(urlStr string) error { - if env.Control.CompleteAuth(urlStr) { - atomic.AddInt32(&authCountAtomic, 1) - t.Logf("completed auth path %s", urlStr) - return nil + var authURLCount atomic.Int32 + var deviceApprovalURLCount atomic.Int32 + + handler := &authURLParserWriter{t: t, + authURLFn: completeLogin(t, env.Control, &authURLCount), + deviceApprovalURLFn: completeDeviceApproval(t, n1, &deviceApprovalURLCount), + } + + cmd := n1.Tailscale(cmdArgs...) + cmd.Stdout = handler + cmd.Stdout = handler + cmd.Stderr = cmd.Stdout + if err := cmd.Run(); err != nil { + t.Fatalf("up: %v", err) + } + + n1.AwaitRunning() + + var wantAuthURLCount int32 + if step.wantAuthURL { + wantAuthURLCount = 1 + } + if n := authURLCount.Load(); n != wantAuthURLCount { + t.Errorf("Auth URLs completed = %d; want %d", n, wantAuthURLCount) + } + + var wantDeviceApprovalURLCount int32 + if step.wantDeviceApprovalURL { + wantDeviceApprovalURLCount = 1 + } + if n := deviceApprovalURLCount.Load(); n != wantDeviceApprovalURLCount { + t.Errorf("Device approval URLs completed = %d; want %d", n, wantDeviceApprovalURLCount) + } + } + }) } - err := fmt.Errorf("Failed to complete auth path to %q", urlStr) - t.Log(err) - return err + } +} + +// Returns true if the error returned by [exec.Run] fails with a non-zero +// exit code, false otherwise. +func isNonZeroExitCode(err error) bool { + if err == nil { + return false + } + + exitError, ok := err.(*exec.ExitError) + if !ok { + return false + } + + return exitError.ExitCode() != 0 +} + +// If we interrupt `tailscale up` and then run it again, we should only +// print a single auth URL. +func TestOneNodeUpInterruptedAuth(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + + env := NewTestEnv(t, ConfigureControl( + func(control *testcontrol.Server) { + control.RequireAuth = true + control.AllNodesSameUser = true + }, + )) + + n := NewTestNode(t, env) + d := n.StartDaemon() + defer d.MustCleanShutdown(t) + + cmdArgs := []string{"up", "--login-server=" + env.ControlURL()} + + // The first time we run the command, we wait for an auth URL to be + // printed, and then we cancel the command -- equivalent to ^C. + // + // At this point, we've connected to control to get an auth URL, + // and printed it in the CLI, but not clicked it. + t.Logf("Running command for the first time: %s", strings.Join(cmdArgs, " ")) + cmd1 := n.Tailscale(cmdArgs...) + + // This handler watches for auth URLs in stdout, then cancels the + // running `tailscale up` CLI command. + cmd1.Stdout = &authURLParserWriter{t: t, authURLFn: func(urlStr string) error { + t.Logf("saw auth URL %q", urlStr) + cmd1.Process.Kill() + return nil }} - cmd.Stderr = cmd.Stdout - if err := cmd.Run(); err != nil { + cmd1.Stderr = cmd1.Stdout + + if err := cmd1.Run(); !isNonZeroExitCode(err) { + t.Fatalf("Command did not fail with non-zero exit code: %q", err) + } + + // Because we didn't click the auth URL, we should still be in NeedsLogin. + n.AwaitBackendState("NeedsLogin") + + // The second time we run the command, we click the first auth URL we see + // and check that we log in correctly. + // + // In #17361, there was a bug where we'd print two auth URLs, and you could + // click either auth URL and log in to control, but logging in through the + // first URL would leave `tailscale up` hanging. + // + // Using `authURLHandler` ensures we only print the new, correct auth URL. + // + // If we print both URLs, it will throw an error because it only expects + // to log in with one auth URL. + // + // If we only print the stale auth URL, the test will timeout because + // `tailscale up` will never return. + t.Logf("Running command for the second time: %s", strings.Join(cmdArgs, " ")) + + var authURLCount atomic.Int32 + + cmd2 := n.Tailscale(cmdArgs...) + cmd2.Stdout = &authURLParserWriter{ + t: t, authURLFn: completeLogin(t, env.Control, &authURLCount), + } + cmd2.Stderr = cmd2.Stdout + + if err := cmd2.Run(); err != nil { t.Fatalf("up: %v", err) } - t.Logf("Got IP: %v", n1.AwaitIP4()) - n1.AwaitRunning() + if urls := authURLCount.Load(); urls != 1 { + t.Errorf("Auth URLs completed = %d; want %d", urls, 1) + } + + n.AwaitRunning() +} + +// If we interrupt `tailscale up` and login successfully, but don't +// complete the device approval, we should see the device approval URL +// when we run `tailscale up` a second time. +func TestOneNodeUpInterruptedDeviceApproval(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + + env := NewTestEnv(t, ConfigureControl( + func(control *testcontrol.Server) { + control.RequireAuth = true + control.RequireMachineAuth = true + control.AllNodesSameUser = true + }, + )) - if n := atomic.LoadInt32(&authCountAtomic); n != 1 { - t.Errorf("Auth URLs completed = %d; want 1", n) + n := NewTestNode(t, env) + d := n.StartDaemon() + defer d.MustCleanShutdown(t) + + // The first time we run the command, we: + // + // * set a custom login URL + // * wait for an auth URL to be printed + // * click it to complete the login process + // * wait for a device approval URL to be printed + // * cancel the command, equivalent to ^C + // + // At this point, we've logged in to control, but our node isn't + // approved to connect to the tailnet. + cmd1Args := []string{"up", "--login-server=" + env.ControlURL()} + t.Logf("Running command: %s", strings.Join(cmd1Args, " ")) + cmd1 := n.Tailscale(cmd1Args...) + + handler1 := &authURLParserWriter{t: t, + authURLFn: completeLogin(t, env.Control, &atomic.Int32{}), + deviceApprovalURLFn: func(urlStr string) error { + t.Logf("saw device approval URL %q", urlStr) + cmd1.Process.Kill() + return nil + }, } + cmd1.Stdout = handler1 + cmd1.Stderr = cmd1.Stdout - d1.MustCleanShutdown(t) + if err := cmd1.Run(); !isNonZeroExitCode(err) { + t.Fatalf("Command did not fail with non-zero exit code: %q", err) + } + + // Because we logged in but we didn't complete the device approval, we + // should be in state NeedsMachineAuth. + n.AwaitBackendState("NeedsMachineAuth") + + // The second time we run the command, we expect not to get an auth URL + // and go straight to the device approval URL. We don't need to pass the + // login server, because `tailscale up` should remember our control URL. + cmd2Args := []string{"up"} + t.Logf("Running command: %s", strings.Join(cmd2Args, " ")) + + var deviceApprovalURLCount atomic.Int32 + + cmd2 := n.Tailscale(cmd2Args...) + cmd2.Stdout = &authURLParserWriter{t: t, + authURLFn: func(urlStr string) error { + t.Fatalf("got unexpected auth URL: %q", urlStr) + cmd2.Process.Kill() + return nil + }, + deviceApprovalURLFn: completeDeviceApproval(t, n, &deviceApprovalURLCount), + } + cmd2.Stderr = cmd2.Stdout + + if err := cmd2.Run(); err != nil { + t.Fatalf("up: %v", err) + } + + wantDeviceApprovalURLCount := int32(1) + if n := deviceApprovalURLCount.Load(); n != wantDeviceApprovalURLCount { + t.Errorf("Device approval URLs completed = %d; want %d", n, wantDeviceApprovalURLCount) + } + + n.AwaitRunning() } func TestConfigFileAuthKey(t *testing.T) { @@ -328,11 +717,11 @@ func TestConfigFileAuthKey(t *testing.T) { tstest.Shard(t) t.Parallel() const authKey = "opensesame" - env := newTestEnv(t, configureControl(func(control *testcontrol.Server) { + env := NewTestEnv(t, ConfigureControl(func(control *testcontrol.Server) { control.RequireAuthKey = authKey })) - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) n1.configFile = filepath.Join(n1.dir, "config.json") authKeyFile := filepath.Join(n1.dir, "my-auth-key") must.Do(os.WriteFile(authKeyFile, fmt.Appendf(nil, "%s\n", authKey), 0666)) @@ -353,14 +742,14 @@ func TestConfigFileAuthKey(t *testing.T) { func TestTwoNodes(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) // Create two nodes: - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) n1SocksAddrCh := n1.socks5AddrChan() d1 := n1.StartDaemon() - n2 := newTestNode(t, env) + n2 := NewTestNode(t, env) n2SocksAddrCh := n2.socks5AddrChan() d2 := n2.StartDaemon() @@ -379,7 +768,7 @@ func TestTwoNodes(t *testing.T) { defer n2.mu.Unlock() rxNoDates := regexp.MustCompile(`(?m)^\d{4}.\d{2}.\d{2}.\d{2}:\d{2}:\d{2}`) - cleanLog := func(n *testNode) []byte { + cleanLog := func(n *TestNode) []byte { b := n.tailscaledParser.allBuf.Bytes() b = rxNoDates.ReplaceAll(b, nil) return b @@ -439,10 +828,10 @@ func TestTwoNodes(t *testing.T) { func TestIncrementalMapUpdatePeersRemoved(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) // Create one node: - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitListening() n1.MustUp() @@ -454,7 +843,7 @@ func TestIncrementalMapUpdatePeersRemoved(t *testing.T) { } tnode1 := all[0] - n2 := newTestNode(t, env) + n2 := NewTestNode(t, env) d2 := n2.StartDaemon() n2.AwaitListening() n2.MustUp() @@ -524,8 +913,8 @@ func TestNodeAddressIPFields(t *testing.T) { tstest.Shard(t) flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7008") tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitListening() @@ -551,8 +940,8 @@ func TestNodeAddressIPFields(t *testing.T) { func TestAddPingRequest(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) n1.StartDaemon() n1.AwaitListening() @@ -605,25 +994,9 @@ func TestC2NPingRequest(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) - gotPing := make(chan bool, 1) - env.Control.HandleC2N = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != "POST" { - t.Errorf("unexpected ping method %q", r.Method) - } - got, err := io.ReadAll(r.Body) - if err != nil { - t.Errorf("ping body read error: %v", err) - } - const want = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: text/plain; charset=utf-8\r\n\r\nabc" - if string(got) != want { - t.Errorf("body error\n got: %q\nwant: %q", got, want) - } - gotPing <- true - }) - - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) n1.StartDaemon() n1.AwaitListening() @@ -646,27 +1019,33 @@ func TestC2NPingRequest(t *testing.T) { } cancel() - pr := &tailcfg.PingRequest{ - URL: fmt.Sprintf("https://unused/some-c2n-path/ping-%d", try), - Log: true, - Types: "c2n", - Payload: []byte("POST /echo HTTP/1.0\r\nContent-Length: 3\r\n\r\nabc"), + ctx, cancel = context.WithTimeout(t.Context(), 2*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", "/echo", bytes.NewReader([]byte("abc"))) + if err != nil { + t.Errorf("failed to create request: %v", err) + continue } - if !env.Control.AddPingRequest(nodeKey, pr) { - t.Logf("failed to AddPingRequest") + r, err := env.Control.NodeRoundTripper(nodeKey).RoundTrip(req) + if err != nil { + t.Errorf("RoundTrip failed: %v", err) continue } - - // Wait for PingRequest to come back - pingTimeout := time.NewTimer(2 * time.Second) - defer pingTimeout.Stop() - select { - case <-gotPing: - t.Logf("got ping; success") - return - case <-pingTimeout.C: - // Try again. + if r.StatusCode != 200 { + t.Errorf("unexpected status code: %d", r.StatusCode) + continue + } + b, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("error reading body: %v", err) + continue + } + if string(b) != "abc" { + t.Errorf("body = %q; want %q", b, "abc") + continue } + return } t.Error("all ping attempts failed") } @@ -676,8 +1055,8 @@ func TestC2NPingRequest(t *testing.T) { func TestNoControlConnWhenDown(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -715,8 +1094,8 @@ func TestNoControlConnWhenDown(t *testing.T) { func TestOneNodeUpWindowsStyle(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) n1.upFlagGOOS = "windows" d1 := n1.StartDaemonAsIPNGOOS("windows") @@ -733,11 +1112,12 @@ func TestOneNodeUpWindowsStyle(t *testing.T) { // jailed node cannot initiate connections to the other node however the other // node can initiate connections to the jailed node. func TestClientSideJailing(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/17419") tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - registerNode := func() (*testNode, key.NodePublic) { - n := newTestNode(t, env) + env := NewTestEnv(t) + registerNode := func() (*TestNode, key.NodePublic) { + n := NewTestNode(t, env) n.StartDaemon() n.AwaitListening() n.MustUp() @@ -755,11 +1135,11 @@ func TestClientSideJailing(t *testing.T) { defer ln.Close() port := uint16(ln.Addr().(*net.TCPAddr).Port) - lc1 := &tailscale.LocalClient{ + lc1 := &local.Client{ Socket: n1.sockFile, UseSocketOnly: true, } - lc2 := &tailscale.LocalClient{ + lc2 := &local.Client{ Socket: n2.sockFile, UseSocketOnly: true, } @@ -789,7 +1169,7 @@ func TestClientSideJailing(t *testing.T) { }, } - testDial := func(t *testing.T, lc *tailscale.LocalClient, ip netip.Addr, port uint16, shouldFail bool) { + testDial := func(t *testing.T, lc *local.Client, ip netip.Addr, port uint16, shouldFail bool) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -851,9 +1231,9 @@ func TestNATPing(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) for _, v6 := range []bool{false, true} { - env := newTestEnv(t) - registerNode := func() (*testNode, key.NodePublic) { - n := newTestNode(t, env) + env := NewTestEnv(t) + registerNode := func() (*TestNode, key.NodePublic) { + n := NewTestNode(t, env) n.StartDaemon() n.AwaitListening() n.MustUp() @@ -978,11 +1358,11 @@ func TestNATPing(t *testing.T) { func TestLogoutRemovesAllPeers(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) // Spin up some nodes. - nodes := make([]*testNode, 2) + nodes := make([]*TestNode, 2) for i := range nodes { - nodes[i] = newTestNode(t, env) + nodes[i] = NewTestNode(t, env) nodes[i].StartDaemon() nodes[i].AwaitResponding() nodes[i].MustUp() @@ -1031,14 +1411,14 @@ func TestLogoutRemovesAllPeers(t *testing.T) { } func TestAutoUpdateDefaults(t *testing.T) { - if !clientupdate.CanAutoUpdate() { + if !feature.CanAutoUpdate() { t.Skip("auto-updates not supported on this platform") } tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) - checkDefault := func(n *testNode, want bool) error { + checkDefault := func(n *TestNode, want bool) error { enabled, ok := n.diskPrefs().AutoUpdate.Apply.Get() if !ok { return fmt.Errorf("auto-update for node is unset, should be set as %v", want) @@ -1049,7 +1429,7 @@ func TestAutoUpdateDefaults(t *testing.T) { return nil } - sendAndCheckDefault := func(t *testing.T, n *testNode, send, want bool) { + sendAndCheckDefault := func(t *testing.T, n *TestNode, send, want bool) { t.Helper() if !env.Control.AddRawMapResponse(n.MustStatus().Self.PublicKey, &tailcfg.MapResponse{ DefaultAutoUpdate: opt.NewBool(send), @@ -1065,11 +1445,11 @@ func TestAutoUpdateDefaults(t *testing.T) { tests := []struct { desc string - run func(t *testing.T, n *testNode) + run func(t *testing.T, n *TestNode) }{ { desc: "tailnet-default-false", - run: func(t *testing.T, n *testNode) { + run: func(t *testing.T, n *TestNode) { // First received default "false". sendAndCheckDefault(t, n, false, false) // Should not be changed even if sent "true" later. @@ -1083,7 +1463,7 @@ func TestAutoUpdateDefaults(t *testing.T) { }, { desc: "tailnet-default-true", - run: func(t *testing.T, n *testNode) { + run: func(t *testing.T, n *TestNode) { // First received default "true". sendAndCheckDefault(t, n, true, true) // Should not be changed even if sent "false" later. @@ -1097,7 +1477,7 @@ func TestAutoUpdateDefaults(t *testing.T) { }, { desc: "user-sets-first", - run: func(t *testing.T, n *testNode) { + run: func(t *testing.T, n *TestNode) { // User sets auto-update first, before receiving defaults. if out, err := n.TailscaleForOutput("set", "--auto-update=false").CombinedOutput(); err != nil { t.Fatalf("failed to disable auto-update on node: %v\noutput: %s", err, out) @@ -1110,7 +1490,7 @@ func TestAutoUpdateDefaults(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - n := newTestNode(t, env) + n := NewTestNode(t, env) d := n.StartDaemon() defer d.MustCleanShutdown(t) @@ -1132,25 +1512,16 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { if os.Getuid() != 0 { t.Skip("skipping when not root") } - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() n1.MustUp() - - wantIP4 := n1.AwaitIP4() n1.AwaitRunning() - status, err := n1.Status() - if err != nil { - t.Fatalf("failed to get node status: %v", err) - } - selfDNSName, err := dnsname.ToFQDN(status.Self.DNSName) - if err != nil { - t.Fatalf("error converting self dns name to fqdn: %v", err) - } + const dnsSymbolicFQDN = "magicdns.localhost-tailscale-daemon." cases := []struct { network string @@ -1166,9 +1537,9 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { }, } for _, c := range cases { - err = tstest.WaitFor(time.Second*5, func() error { + err := tstest.WaitFor(time.Second*5, func() error { m := new(dns.Msg) - m.SetQuestion(selfDNSName.WithTrailingDot(), dns.TypeA) + m.SetQuestion(dnsSymbolicFQDN, dns.TypeA) conn, err := net.DialTimeout(c.network, net.JoinHostPort(c.serviceAddr.String(), "53"), time.Second*1) if err != nil { return err @@ -1193,8 +1564,8 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { return fmt.Errorf("unexpected answer type: %s", resp.Answer[0]) } gotAddr = answer.A - if !bytes.Equal(gotAddr, wantIP4.AsSlice()) { - return fmt.Errorf("got (%s) != want (%s)", gotAddr, wantIP4) + if !bytes.Equal(gotAddr, tsaddr.TailscaleServiceIP().AsSlice()) { + return fmt.Errorf("got (%s) != want (%s)", gotAddr, tsaddr.TailscaleServiceIP()) } return nil }) @@ -1214,12 +1585,12 @@ func TestNetstackTCPLoopback(t *testing.T) { t.Skip("skipping when not root") } - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true loopbackPort := 5201 env.loopbackPort = &loopbackPort loopbackPortStr := strconv.Itoa(loopbackPort) - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -1356,11 +1727,11 @@ func TestNetstackUDPLoopback(t *testing.T) { t.Skip("skipping when not root") } - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true loopbackPort := 5201 env.loopbackPort = &loopbackPort - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -1495,580 +1866,383 @@ func TestNetstackUDPLoopback(t *testing.T) { d1.MustCleanShutdown(t) } -// testEnv contains the test environment (set of servers) used by one -// or more nodes. -type testEnv struct { - t testing.TB - tunMode bool - cli string - daemon string - loopbackPort *int +func TestEncryptStateMigration(t *testing.T) { + if !hostinfo.New().TPM.Present() { + t.Skip("TPM not available") + } + if runtime.GOOS != "linux" && runtime.GOOS != "windows" { + t.Skip("--encrypt-state for tailscaled state not supported on this platform") + } + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t) + n := NewTestNode(t, env) - LogCatcher *LogCatcher - LogCatcherServer *httptest.Server + runNode := func(t *testing.T, wantStateKeys []string) { + t.Helper() - Control *testcontrol.Server - ControlServer *httptest.Server + // Run the node. + d := n.StartDaemon() + n.AwaitResponding() + n.MustUp() + n.AwaitRunning() - TrafficTrap *trafficTrap - TrafficTrapServer *httptest.Server -} + // Check the contents of the state file. + buf, err := os.ReadFile(n.stateFile) + if err != nil { + t.Fatalf("reading %q: %v", n.stateFile, err) + } + t.Logf("state file content:\n%s", buf) + var content map[string]any + if err := json.Unmarshal(buf, &content); err != nil { + t.Fatalf("parsing %q: %v", n.stateFile, err) + } + for _, k := range wantStateKeys { + if _, ok := content[k]; !ok { + t.Errorf("state file is missing key %q", k) + } + } -// controlURL returns e.ControlServer.URL, panicking if it's the empty string, -// which it should never be in tests. -func (e *testEnv) controlURL() string { - s := e.ControlServer.URL - if s == "" { - panic("control server not set") + // Stop the node. + d.MustCleanShutdown(t) } - return s -} - -type testEnvOpt interface { - modifyTestEnv(*testEnv) -} -type configureControl func(*testcontrol.Server) - -func (f configureControl) modifyTestEnv(te *testEnv) { - f(te.Control) -} - -// newTestEnv starts a bunch of services and returns a new test environment. -// newTestEnv arranges for the environment's resources to be cleaned up on exit. -func newTestEnv(t testing.TB, opts ...testEnvOpt) *testEnv { - if runtime.GOOS == "windows" { - t.Skip("not tested/working on Windows yet") - } - derpMap := RunDERPAndSTUN(t, logger.Discard, "127.0.0.1") - logc := new(LogCatcher) - control := &testcontrol.Server{ - DERPMap: derpMap, - } - control.HTTPTestServer = httptest.NewUnstartedServer(control) - trafficTrap := new(trafficTrap) - e := &testEnv{ - t: t, - cli: TailscaleBinary(t), - daemon: TailscaledBinary(t), - LogCatcher: logc, - LogCatcherServer: httptest.NewServer(logc), - Control: control, - ControlServer: control.HTTPTestServer, - TrafficTrap: trafficTrap, - TrafficTrapServer: httptest.NewServer(trafficTrap), - } - for _, o := range opts { - o.modifyTestEnv(e) - } - control.HTTPTestServer.Start() - t.Cleanup(func() { - // Shut down e. - if err := e.TrafficTrap.Err(); err != nil { - e.t.Errorf("traffic trap: %v", err) - e.t.Logf("logs: %s", e.LogCatcher.logsString()) - } - e.LogCatcherServer.Close() - e.TrafficTrapServer.Close() - e.ControlServer.Close() + wantPlaintextStateKeys := []string{"_machinekey", "_current-profile", "_profiles"} + wantEncryptedStateKeys := []string{"key", "nonce", "data"} + t.Run("regular-state", func(t *testing.T) { + n.encryptState = false + runNode(t, wantPlaintextStateKeys) + }) + t.Run("migrate-to-encrypted", func(t *testing.T) { + n.encryptState = true + runNode(t, wantEncryptedStateKeys) + }) + t.Run("migrate-to-plaintext", func(t *testing.T) { + n.encryptState = false + runNode(t, wantPlaintextStateKeys) }) - t.Logf("control URL: %v", e.controlURL()) - return e } -// testNode is a machine with a tailscale & tailscaled. -// Currently, the test is simplistic and user==node==machine. -// That may grow complexity later to test more. -type testNode struct { - env *testEnv - tailscaledParser *nodeOutputParser - - dir string // temp dir for sock & state - configFile string // or empty for none - sockFile string - stateFile string - upFlagGOOS string // if non-empty, sets TS_DEBUG_UP_FLAG_GOOS for cmd/tailscale CLI - - mu sync.Mutex - onLogLine []func([]byte) -} +// TestPeerRelayPing creates three nodes with one acting as a peer relay. +// The test succeeds when "tailscale ping" flows through the peer +// relay between all 3 nodes, and "tailscale debug peer-relay-sessions" returns +// expected values. +func TestPeerRelayPing(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/17251") + tstest.Shard(t) + tstest.Parallel(t) -// newTestNode allocates a temp directory for a new test node. -// The node is not started automatically. -func newTestNode(t *testing.T, env *testEnv) *testNode { - dir := t.TempDir() - sockFile := filepath.Join(dir, "tailscale.sock") - if len(sockFile) >= 104 { - // Maximum length for a unix socket on darwin. Try something else. - sockFile = filepath.Join(os.TempDir(), rands.HexString(8)+".sock") - t.Cleanup(func() { os.Remove(sockFile) }) - } - n := &testNode{ - env: env, - dir: dir, - sockFile: sockFile, - stateFile: filepath.Join(dir, "tailscale.state"), - } - - // Look for a data race. Once we see the start marker, start logging the rest. - var sawRace bool - var sawPanic bool - n.addLogLineHook(func(line []byte) { - lineB := mem.B(line) - if mem.Contains(lineB, mem.S("WARNING: DATA RACE")) { - sawRace = true - } - if mem.HasPrefix(lineB, mem.S("panic: ")) { - sawPanic = true - } - if sawRace || sawPanic { - t.Logf("%s", line) - } - }) + env := NewTestEnv(t, ConfigureControl(func(server *testcontrol.Server) { + server.PeerRelayGrants = true + })) + env.neverDirectUDP = true + env.relayServerUseLoopback = true - return n -} + n1 := NewTestNode(t, env) + n2 := NewTestNode(t, env) + peerRelay := NewTestNode(t, env) -func (n *testNode) diskPrefs() *ipn.Prefs { - t := n.env.t - t.Helper() - if _, err := os.ReadFile(n.stateFile); err != nil { - t.Fatalf("reading prefs: %v", err) - } - fs, err := store.NewFileStore(nil, n.stateFile) - if err != nil { - t.Fatalf("reading prefs, NewFileStore: %v", err) - } - p, err := ipnlocal.ReadStartupPrefsForTest(t.Logf, fs) - if err != nil { - t.Fatalf("reading prefs, ReadDiskPrefsForTest: %v", err) + allNodes := []*TestNode{n1, n2, peerRelay} + wantPeerRelayServers := make(set.Set[string]) + for _, n := range allNodes { + n.StartDaemon() + n.AwaitResponding() + n.MustUp() + wantPeerRelayServers.Add(n.AwaitIP4().String()) + n.AwaitRunning() } - return p.AsStruct() -} - -// AwaitResponding waits for n's tailscaled to be up enough to be -// responding, but doesn't wait for any particular state. -func (n *testNode) AwaitResponding() { - t := n.env.t - t.Helper() - n.AwaitListening() - st := n.MustStatus() - t.Logf("Status: %s", st.BackendState) - - if err := tstest.WaitFor(20*time.Second, func() error { - const sub = `Program starting: ` - if !n.env.LogCatcher.logsContains(mem.S(sub)) { - return fmt.Errorf("log catcher didn't see %#q; got %s", sub, n.env.LogCatcher.logsString()) - } - return nil - }); err != nil { + if err := peerRelay.Tailscale("set", "--relay-server-port=0").Run(); err != nil { t.Fatal(err) } -} -// addLogLineHook registers a hook f to be called on each tailscaled -// log line output. -func (n *testNode) addLogLineHook(f func([]byte)) { - n.mu.Lock() - defer n.mu.Unlock() - n.onLogLine = append(n.onLogLine, f) -} + errCh := make(chan error) + for _, a := range allNodes { + go func() { + err := tstest.WaitFor(time.Second*5, func() error { + out, err := a.Tailscale("debug", "peer-relay-servers").CombinedOutput() + if err != nil { + return fmt.Errorf("debug peer-relay-servers failed: %v", err) + } + servers := make([]string, 0) + err = json.Unmarshal(out, &servers) + if err != nil { + return fmt.Errorf("failed to unmarshal debug peer-relay-servers: %v", err) + } + gotPeerRelayServers := make(set.Set[string]) + for _, server := range servers { + gotPeerRelayServers.Add(server) + } + if !gotPeerRelayServers.Equal(wantPeerRelayServers) { + return fmt.Errorf("got peer relay servers: %v want: %v", gotPeerRelayServers, wantPeerRelayServers) + } + return nil + }) + errCh <- err + }() + } + for range allNodes { + err := <-errCh + if err != nil { + t.Fatal(err) + } + } -// socks5AddrChan returns a channel that receives the address (e.g. "localhost:23874") -// of the node's SOCKS5 listener, once started. -func (n *testNode) socks5AddrChan() <-chan string { - ch := make(chan string, 1) - n.addLogLineHook(func(line []byte) { - const sub = "SOCKS5 listening on " - i := mem.Index(mem.B(line), mem.S(sub)) - if i == -1 { - return + pingPairs := make([][2]*TestNode, 0) + for _, a := range allNodes { + for _, z := range allNodes { + if a == z { + continue + } + pingPairs = append(pingPairs, [2]*TestNode{a, z}) } - addr := strings.TrimSpace(string(line)[i+len(sub):]) - select { - case ch <- addr: - default: + } + for _, pair := range pingPairs { + go func() { + a := pair[0] + z := pair[1] + err := tstest.WaitFor(time.Second*10, func() error { + remoteKey := z.MustStatus().Self.PublicKey + if err := a.Tailscale("ping", "--until-direct=false", "--c=1", "--timeout=1s", z.AwaitIP4().String()).Run(); err != nil { + return err + } + remotePeer, ok := a.MustStatus().Peer[remoteKey] + if !ok { + return fmt.Errorf("%v->%v remote peer not found", a.MustStatus().Self.ID, z.MustStatus().Self.ID) + } + if len(remotePeer.PeerRelay) == 0 { + return fmt.Errorf("%v->%v not using peer relay, curAddr=%v relay=%v", a.MustStatus().Self.ID, z.MustStatus().Self.ID, remotePeer.CurAddr, remotePeer.Relay) + } + t.Logf("%v->%v using peer relay addr: %v", a.MustStatus().Self.ID, z.MustStatus().Self.ID, remotePeer.PeerRelay) + return nil + }) + errCh <- err + }() + } + for range pingPairs { + err := <-errCh + if err != nil { + t.Fatal(err) } - }) - return ch -} - -func (n *testNode) AwaitSocksAddr(ch <-chan string) string { - t := n.env.t - t.Helper() - timer := time.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case v := <-ch: - return v - case <-timer.C: - t.Fatal("timeout waiting for node to log its SOCK5 listening address") - panic("unreachable") } -} - -// nodeOutputParser parses stderr of tailscaled processes, calling the -// per-line callbacks previously registered via -// testNode.addLogLineHook. -type nodeOutputParser struct { - allBuf bytes.Buffer - pendLineBuf bytes.Buffer - n *testNode -} - -func (op *nodeOutputParser) Write(p []byte) (n int, err error) { - tn := op.n - tn.mu.Lock() - defer tn.mu.Unlock() - - op.allBuf.Write(p) - n, err = op.pendLineBuf.Write(p) - op.parseLinesLocked() - return -} -func (op *nodeOutputParser) parseLinesLocked() { - n := op.n - buf := op.pendLineBuf.Bytes() - for len(buf) > 0 { - nl := bytes.IndexByte(buf, '\n') - if nl == -1 { + allControlNodes := env.Control.AllNodes() + wantSessionsForDiscoShorts := make(set.Set[[2]string]) + for i, a := range allControlNodes { + if i == len(allControlNodes)-1 { break } - line := buf[:nl+1] - buf = buf[nl+1:] - - for _, f := range n.onLogLine { - f(line) + for _, z := range allControlNodes[i+1:] { + wantSessionsForDiscoShorts.Add([2]string{a.DiscoKey.ShortString(), z.DiscoKey.ShortString()}) } } - if len(buf) == 0 { - op.pendLineBuf.Reset() - } else { - io.CopyN(io.Discard, &op.pendLineBuf, int64(op.pendLineBuf.Len()-len(buf))) - } -} -type Daemon struct { - Process *os.Process -} - -func (d *Daemon) MustCleanShutdown(t testing.TB) { - d.Process.Signal(os.Interrupt) - ps, err := d.Process.Wait() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + debugSessions, err := peerRelay.LocalClient().DebugPeerRelaySessions(ctx) + cancel() if err != nil { - t.Fatalf("tailscaled Wait: %v", err) + t.Fatalf("debug peer-relay-sessions failed: %v", err) } - if ps.ExitCode() != 0 { - t.Errorf("tailscaled ExitCode = %d; want 0", ps.ExitCode()) + if len(debugSessions.Sessions) != len(wantSessionsForDiscoShorts) { + t.Errorf("got %d peer relay sessions, want %d", len(debugSessions.Sessions), len(wantSessionsForDiscoShorts)) } -} - -// StartDaemon starts the node's tailscaled, failing if it fails to start. -// StartDaemon ensures that the process will exit when the test completes. -func (n *testNode) StartDaemon() *Daemon { - return n.StartDaemonAsIPNGOOS(runtime.GOOS) -} - -func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { - t := n.env.t - cmd := exec.Command(n.env.daemon) - cmd.Args = append(cmd.Args, - "--state="+n.stateFile, - "--socket="+n.sockFile, - "--socks5-server=localhost:0", - ) - if *verboseTailscaled { - cmd.Args = append(cmd.Args, "-verbose=2") - } - if !n.env.tunMode { - cmd.Args = append(cmd.Args, - "--tun=userspace-networking", - ) - } - if n.configFile != "" { - cmd.Args = append(cmd.Args, "--config="+n.configFile) - } - cmd.Env = append(os.Environ(), - "TS_CONTROL_IS_PLAINTEXT_HTTP=1", - "TS_DEBUG_PERMIT_HTTP_C2N=1", - "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, - "HTTP_PROXY="+n.env.TrafficTrapServer.URL, - "HTTPS_PROXY="+n.env.TrafficTrapServer.URL, - "TS_DEBUG_FAKE_GOOS="+ipnGOOS, - "TS_LOGS_DIR="+t.TempDir(), - "TS_NETCHECK_GENERATE_204_URL="+n.env.ControlServer.URL+"/generate_204", - "TS_ASSUME_NETWORK_UP_FOR_TEST=1", // don't pause control client in airplane mode (no wifi, etc) - "TS_PANIC_IF_HIT_MAIN_CONTROL=1", - "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost - "TS_DEBUG_LOG_RATE=all", - ) - if n.env.loopbackPort != nil { - cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) - } - if version.IsRace() { - cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") - } - n.tailscaledParser = &nodeOutputParser{n: n} - cmd.Stderr = n.tailscaledParser - if *verboseTailscaled { - cmd.Stdout = os.Stdout - cmd.Stderr = io.MultiWriter(cmd.Stderr, os.Stderr) - } - if runtime.GOOS != "windows" { - pr, pw, err := os.Pipe() - if err != nil { - t.Fatal(err) + for _, session := range debugSessions.Sessions { + if !wantSessionsForDiscoShorts.Contains([2]string{session.Client1.ShortDisco, session.Client2.ShortDisco}) && + !wantSessionsForDiscoShorts.Contains([2]string{session.Client2.ShortDisco, session.Client1.ShortDisco}) { + t.Errorf("peer relay session for disco keys %s<->%s not found in debug peer-relay-sessions: %+v", session.Client1.ShortDisco, session.Client2.ShortDisco, debugSessions.Sessions) + } + for _, client := range []status.ClientInfo{session.Client1, session.Client2} { + if client.BytesTx == 0 { + t.Errorf("unexpected 0 bytes TX counter in peer relay session: %+v", session) + } + if client.PacketsTx == 0 { + t.Errorf("unexpected 0 packets TX counter in peer relay session: %+v", session) + } + if !client.Endpoint.IsValid() { + t.Errorf("unexpected endpoint zero value in peer relay session: %+v", session) + } + if len(client.ShortDisco) == 0 { + t.Errorf("unexpected zero len short disco in peer relay session: %+v", session) + } } - t.Cleanup(func() { pw.Close() }) - cmd.ExtraFiles = append(cmd.ExtraFiles, pr) - cmd.Env = append(cmd.Env, "TS_PARENT_DEATH_FD=3") - } - if err := cmd.Start(); err != nil { - t.Fatalf("starting tailscaled: %v", err) - } - t.Cleanup(func() { cmd.Process.Kill() }) - return &Daemon{ - Process: cmd.Process, } } -func (n *testNode) MustUp(extraArgs ...string) { - t := n.env.t - t.Helper() - args := []string{ - "up", - "--login-server=" + n.env.controlURL(), - "--reset", - } - args = append(args, extraArgs...) - cmd := n.Tailscale(args...) - t.Logf("Running %v ...", cmd) - cmd.Stdout = nil // in case --verbose-tailscale was set - cmd.Stderr = nil // in case --verbose-tailscale was set - if b, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("up: %v, %v", string(b), err) - } -} +func TestC2NDebugNetmap(t *testing.T) { + tstest.Shard(t) + tstest.Parallel(t) + env := NewTestEnv(t, ConfigureControl(func(s *testcontrol.Server) { + s.CollectServices = opt.False + })) -func (n *testNode) MustDown() { - t := n.env.t - t.Logf("Running down ...") - if err := n.Tailscale("down", "--accept-risk=all").Run(); err != nil { - t.Fatalf("down: %v", err) - } -} + var testNodes []*TestNode + var nodes []*tailcfg.Node + for i := range 2 { + n := NewTestNode(t, env) + d := n.StartDaemon() + defer d.MustCleanShutdown(t) + + n.AwaitResponding() + n.MustUp() + n.AwaitRunning() + testNodes = append(testNodes, n) -func (n *testNode) MustLogOut() { - t := n.env.t - t.Logf("Running logout ...") - if err := n.Tailscale("logout").Run(); err != nil { - t.Fatalf("logout: %v", err) + controlNodes := env.Control.AllNodes() + if len(controlNodes) != i+1 { + t.Fatalf("expected %d nodes, got %d nodes", i+1, len(controlNodes)) + } + for _, cn := range controlNodes { + if n.MustStatus().Self.PublicKey == cn.Key { + nodes = append(nodes, cn) + break + } + } } -} -func (n *testNode) Ping(otherNode *testNode) error { - t := n.env.t - ip := otherNode.AwaitIP4().String() - t.Logf("Running ping %v (from %v)...", ip, n.AwaitIP4()) - return n.Tailscale("ping", ip).Run() -} + // getC2NNetmap fetches the current netmap. If a candidate map response is provided, + // a candidate netmap is also fetched and compared to the current netmap. + getC2NNetmap := func(node key.NodePublic, cand *tailcfg.MapResponse) *netmap.NetworkMap { + t.Helper() + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() -// AwaitListening waits for the tailscaled to be serving local clients -// over its localhost IPC mechanism. (Unix socket, etc) -func (n *testNode) AwaitListening() { - t := n.env.t - if err := tstest.WaitFor(20*time.Second, func() (err error) { - c, err := safesocket.ConnectContext(context.Background(), n.sockFile) - if err == nil { - c.Close() + var req *http.Request + if cand != nil { + body := must.Get(json.Marshal(&tailcfg.C2NDebugNetmapRequest{Candidate: cand})) + req = must.Get(http.NewRequestWithContext(ctx, "POST", "/debug/netmap", bytes.NewReader(body))) + } else { + req = must.Get(http.NewRequestWithContext(ctx, "GET", "/debug/netmap", nil)) } - return err - }); err != nil { - t.Fatal(err) - } -} + httpResp := must.Get(env.Control.NodeRoundTripper(node).RoundTrip(req)) + defer httpResp.Body.Close() -func (n *testNode) AwaitIPs() []netip.Addr { - t := n.env.t - t.Helper() - var addrs []netip.Addr - if err := tstest.WaitFor(20*time.Second, func() error { - cmd := n.Tailscale("ip") - cmd.Stdout = nil // in case --verbose-tailscale was set - cmd.Stderr = nil // in case --verbose-tailscale was set - out, err := cmd.Output() - if err != nil { - return err + if httpResp.StatusCode != 200 { + t.Errorf("unexpected status code: %d", httpResp.StatusCode) + return nil } - ips := string(out) - ipslice := strings.Fields(ips) - addrs = make([]netip.Addr, len(ipslice)) - for i, ip := range ipslice { - netIP, err := netip.ParseAddr(ip) - if err != nil { - t.Fatal(err) + respBody := must.Get(io.ReadAll(httpResp.Body)) + var resp tailcfg.C2NDebugNetmapResponse + must.Do(json.Unmarshal(respBody, &resp)) + + var current netmap.NetworkMap + must.Do(json.Unmarshal(resp.Current, ¤t)) + + // Check candidate netmap if we sent a map response. + if cand != nil { + var candidate netmap.NetworkMap + must.Do(json.Unmarshal(resp.Candidate, &candidate)) + if diff := cmp.Diff(current.SelfNode, candidate.SelfNode); diff != "" { + t.Errorf("SelfNode differs (-current +candidate):\n%s", diff) + } + if diff := cmp.Diff(current.Peers, candidate.Peers); diff != "" { + t.Errorf("Peers differ (-current +candidate):\n%s", diff) } - addrs[i] = netIP } - return nil - }); err != nil { - t.Fatalf("awaiting an IP address: %v", err) - } - if len(addrs) == 0 { - t.Fatalf("returned IP address was blank") + return ¤t } - return addrs -} -// AwaitIP4 returns the IPv4 address of n. -func (n *testNode) AwaitIP4() netip.Addr { - t := n.env.t - t.Helper() - ips := n.AwaitIPs() - return ips[0] -} + for _, n := range nodes { + mr := must.Get(env.Control.MapResponse(&tailcfg.MapRequest{NodeKey: n.Key})) + nm := getC2NNetmap(n.Key, mr) -// AwaitIP6 returns the IPv6 address of n. -func (n *testNode) AwaitIP6() netip.Addr { - t := n.env.t - t.Helper() - ips := n.AwaitIPs() - return ips[1] -} + // Make sure peers do not have "testcap" initially (we'll change this later). + if len(nm.Peers) != 1 || nm.Peers[0].CapMap().Contains("testcap") { + t.Fatalf("expected 1 peer without testcap, got: %v", nm.Peers) + } -// AwaitRunning waits for n to reach the IPN state "Running". -func (n *testNode) AwaitRunning() { - n.AwaitBackendState("Running") -} + // Make sure nodes think each other are offline initially. + if nm.Peers[0].Online().Get() { + t.Fatalf("expected 1 peer to be offline, got: %v", nm.Peers) + } + } -func (n *testNode) AwaitBackendState(state string) { - t := n.env.t - t.Helper() - if err := tstest.WaitFor(20*time.Second, func() error { - st, err := n.Status() - if err != nil { - return err + // Send a delta update to n0, setting "testcap" on node 1. + env.Control.AddRawMapResponse(nodes[0].Key, &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: nodes[1].ID, CapMap: tailcfg.NodeCapMap{"testcap": []tailcfg.RawMessage{}}, + }}, + }) + + // node 0 should see node 1 with "testcap". + must.Do(tstest.WaitFor(5*time.Second, func() error { + st := testNodes[0].MustStatus() + p, ok := st.Peer[nodes[1].Key] + if !ok { + return fmt.Errorf("node 0 (%s) doesn't see node 1 (%s) as peer\n%v", nodes[0].Key, nodes[1].Key, st) } - if st.BackendState != state { - return fmt.Errorf("in state %q; want %q", st.BackendState, state) + if _, ok := p.CapMap["testcap"]; !ok { + return fmt.Errorf("node 0 (%s) sees node 1 (%s) as peer but without testcap\n%v", nodes[0].Key, nodes[1].Key, p) } return nil - }); err != nil { - t.Fatalf("failure/timeout waiting for transition to Running status: %v", err) + })) + + // Check that node 0's current netmap has "testcap" for node 1. + nm := getC2NNetmap(nodes[0].Key, nil) + if len(nm.Peers) != 1 || !nm.Peers[0].CapMap().Contains("testcap") { + t.Errorf("current netmap missing testcap: %v", nm.Peers[0].CapMap()) } -} -// AwaitNeedsLogin waits for n to reach the IPN state "NeedsLogin". -func (n *testNode) AwaitNeedsLogin() { - t := n.env.t - t.Helper() - if err := tstest.WaitFor(20*time.Second, func() error { - st, err := n.Status() - if err != nil { - return err - } - if st.BackendState != "NeedsLogin" { - return fmt.Errorf("in state %q", st.BackendState) + // Send a delta update to n1, marking node 0 as online. + env.Control.AddRawMapResponse(nodes[1].Key, &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: nodes[0].ID, Online: ptr.To(true), + }}, + }) + + // node 1 should see node 0 as online. + must.Do(tstest.WaitFor(5*time.Second, func() error { + st := testNodes[1].MustStatus() + p, ok := st.Peer[nodes[0].Key] + if !ok || !p.Online { + return fmt.Errorf("node 0 (%s) doesn't see node 1 (%s) as an online peer\n%v", nodes[0].Key, nodes[1].Key, st) } return nil - }); err != nil { - t.Fatalf("failure/timeout waiting for transition to NeedsLogin status: %v", err) + })) + + // The netmap from node 1 should show node 0 as online. + nm = getC2NNetmap(nodes[1].Key, nil) + if len(nm.Peers) != 1 || !nm.Peers[0].Online().Get() { + t.Errorf("expected peer to be online; got %+v", nm.Peers[0].AsStruct()) } } -func (n *testNode) TailscaleForOutput(arg ...string) *exec.Cmd { - cmd := n.Tailscale(arg...) - cmd.Stdout = nil - cmd.Stderr = nil - return cmd -} +func TestNetworkLock(t *testing.T) { -// Tailscale returns a command that runs the tailscale CLI with the provided arguments. -// It does not start the process. -func (n *testNode) Tailscale(arg ...string) *exec.Cmd { - cmd := exec.Command(n.env.cli) - cmd.Args = append(cmd.Args, "--socket="+n.sockFile) - cmd.Args = append(cmd.Args, arg...) - cmd.Dir = n.dir - cmd.Env = append(os.Environ(), - "TS_DEBUG_UP_FLAG_GOOS="+n.upFlagGOOS, - "TS_LOGS_DIR="+n.env.t.TempDir(), - ) - if *verboseTailscale { - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - } - return cmd -} + // If you run `tailscale lock log` on a node where Tailnet Lock isn't + // enabled, you get an error explaining that. + t.Run("log-when-not-enabled", func(t *testing.T) { + tstest.Shard(t) + t.Parallel() -func (n *testNode) Status() (*ipnstate.Status, error) { - cmd := n.Tailscale("status", "--json") - cmd.Stdout = nil // in case --verbose-tailscale was set - cmd.Stderr = nil // in case --verbose-tailscale was set - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("running tailscale status: %v, %s", err, out) - } - st := new(ipnstate.Status) - if err := json.Unmarshal(out, st); err != nil { - return nil, fmt.Errorf("decoding tailscale status JSON: %w", err) - } - return st, nil -} + env := NewTestEnv(t) + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + defer d1.MustCleanShutdown(t) -func (n *testNode) MustStatus() *ipnstate.Status { - tb := n.env.t - tb.Helper() - st, err := n.Status() - if err != nil { - tb.Fatal(err) - } - return st -} + n1.MustUp() + n1.AwaitRunning() -// trafficTrap is an HTTP proxy handler to note whether any -// HTTP traffic tries to leave localhost from tailscaled. We don't -// expect any, so any request triggers a failure. -type trafficTrap struct { - atomicErr syncs.AtomicValue[error] -} + cmdArgs := []string{"lock", "log"} + t.Logf("Running command: %s", strings.Join(cmdArgs, " ")) -func (tt *trafficTrap) Err() error { - return tt.atomicErr.Load() -} + var outBuf, errBuf bytes.Buffer -func (tt *trafficTrap) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var got bytes.Buffer - r.Write(&got) - err := fmt.Errorf("unexpected HTTP request via proxy: %s", got.Bytes()) - mainError.Store(err) - if tt.Err() == nil { - // Best effort at remembering the first request. - tt.atomicErr.Store(err) - } - log.Printf("Error: %v", err) - w.WriteHeader(403) -} + cmd := n1.Tailscale(cmdArgs...) + cmd.Stdout = &outBuf + cmd.Stderr = &errBuf -type authURLParserWriter struct { - buf bytes.Buffer - fn func(urlStr string) error -} + if err := cmd.Run(); !isNonZeroExitCode(err) { + t.Fatalf("command did not fail with non-zero exit code: %q", err) + } -var authURLRx = regexp.MustCompile(`(https?://\S+/auth/\S+)`) + if outBuf.String() != "" { + t.Fatalf("stdout: want '', got %q", outBuf.String()) + } -func (w *authURLParserWriter) Write(p []byte) (n int, err error) { - n, err = w.buf.Write(p) - m := authURLRx.FindSubmatch(w.buf.Bytes()) - if m != nil { - urlStr := string(m[1]) - w.buf.Reset() // so it's not matched again - if err := w.fn(urlStr); err != nil { - return 0, err + wantErr := "Tailnet Lock is not enabled\n" + if errBuf.String() != wantErr { + t.Fatalf("stderr: want %q, got %q", wantErr, errBuf.String()) } - } - return n, err + }) } diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index 535515588..15f126985 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -32,6 +32,7 @@ import ( ) var ( + runVMTests = flag.Bool("run-vm-tests", false, "run tests that require a VM") logTailscaled = flag.Bool("log-tailscaled", false, "log tailscaled output") pcapFile = flag.String("pcap", "", "write pcap to file") ) @@ -59,8 +60,25 @@ func newNatTest(tb testing.TB) *natTest { base: filepath.Join(modRoot, "gokrazy/natlabapp.qcow2"), } + if !*runVMTests { + tb.Skip("skipping heavy test; set --run-vm-tests to run") + } + if _, err := os.Stat(nt.base); err != nil { - tb.Skipf("skipping test; base image %q not found", nt.base) + if !os.IsNotExist(err) { + tb.Fatal(err) + } + tb.Logf("building VM image...") + cmd := exec.Command("make", "natlab") + cmd.Dir = filepath.Join(modRoot, "gokrazy") + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + if err := cmd.Run(); err != nil { + tb.Fatalf("Error running 'make natlab' in gokrazy directory") + } + if _, err := os.Stat(nt.base); err != nil { + tb.Skipf("still can't find VM image: %v", err) + } } nt.kernel, err = findKernelPath(filepath.Join(modRoot, "gokrazy/natlabapp/builddir/github.com/tailscale/gokrazy-kernel/go.mod")) @@ -218,6 +236,22 @@ func hard(c *vnet.Config) *vnet.Node { fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT)) } +func hardNoDERPOrEndoints(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT), + vnet.TailscaledEnv{ + Key: "TS_DEBUG_STRIP_ENDPOINTS", + Value: "1", + }, + vnet.TailscaledEnv{ + Key: "TS_DEBUG_STRIP_HOME_DERP", + Value: "1", + }, + ) +} + func hardPMP(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -492,6 +526,26 @@ func TestEasyEasy(t *testing.T) { nt.want(routeDirect) } +// Issue tailscale/corp#26438: use learned DERP route as send path of last +// resort +// +// See (*magicsock.Conn).fallbackDERPRegionForPeer and its comment for +// background. +// +// This sets up a test with two nodes that must use DERP to communicate but the +// target of the ping (the second node) additionally is not getting DERP or +// Endpoint updates from the control plane. (Or rather, it's getting them but is +// configured to scrub them right when they come off the network before being +// processed) This then tests whether node2, upon receiving a packet, will be +// able to reply to node1 since it knows neither node1's endpoints nor its home +// DERP. The only reply route it can use is that fact that it just received a +// packet over a particular DERP from that peer. +func TestFallbackDERPRegionForPeer(t *testing.T) { + nt := newNatTest(t) + nt.runTest(hard, hardNoDERPOrEndoints) + nt.want(routeDERP) +} + func TestSingleJustIPv6(t *testing.T) { nt := newNatTest(t) nt.runTest(just6) diff --git a/tstest/integration/tailscaled_deps_test_darwin.go b/tstest/integration/tailscaled_deps_test_darwin.go index 6676ee22c..217188f75 100644 --- a/tstest/integration/tailscaled_deps_test_darwin.go +++ b/tstest/integration/tailscaled_deps_test_darwin.go @@ -11,12 +11,15 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. _ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature" + _ "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -33,7 +36,6 @@ import ( _ "tailscale.com/net/proxymux" _ "tailscale.com/net/socks5" _ "tailscale.com/net/tsdial" - _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" @@ -47,8 +49,10 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" - _ "tailscale.com/util/multierr" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/osshare" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/policyclient" _ "tailscale.com/version" _ "tailscale.com/version/distro" _ "tailscale.com/wgengine" diff --git a/tstest/integration/tailscaled_deps_test_freebsd.go b/tstest/integration/tailscaled_deps_test_freebsd.go index 6676ee22c..217188f75 100644 --- a/tstest/integration/tailscaled_deps_test_freebsd.go +++ b/tstest/integration/tailscaled_deps_test_freebsd.go @@ -11,12 +11,15 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. _ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature" + _ "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -33,7 +36,6 @@ import ( _ "tailscale.com/net/proxymux" _ "tailscale.com/net/socks5" _ "tailscale.com/net/tsdial" - _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" @@ -47,8 +49,10 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" - _ "tailscale.com/util/multierr" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/osshare" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/policyclient" _ "tailscale.com/version" _ "tailscale.com/version/distro" _ "tailscale.com/wgengine" diff --git a/tstest/integration/tailscaled_deps_test_linux.go b/tstest/integration/tailscaled_deps_test_linux.go index 6676ee22c..217188f75 100644 --- a/tstest/integration/tailscaled_deps_test_linux.go +++ b/tstest/integration/tailscaled_deps_test_linux.go @@ -11,12 +11,15 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. _ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature" + _ "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -33,7 +36,6 @@ import ( _ "tailscale.com/net/proxymux" _ "tailscale.com/net/socks5" _ "tailscale.com/net/tsdial" - _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" @@ -47,8 +49,10 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" - _ "tailscale.com/util/multierr" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/osshare" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/policyclient" _ "tailscale.com/version" _ "tailscale.com/version/distro" _ "tailscale.com/wgengine" diff --git a/tstest/integration/tailscaled_deps_test_openbsd.go b/tstest/integration/tailscaled_deps_test_openbsd.go index 6676ee22c..217188f75 100644 --- a/tstest/integration/tailscaled_deps_test_openbsd.go +++ b/tstest/integration/tailscaled_deps_test_openbsd.go @@ -11,12 +11,15 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. _ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature" + _ "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -33,7 +36,6 @@ import ( _ "tailscale.com/net/proxymux" _ "tailscale.com/net/socks5" _ "tailscale.com/net/tsdial" - _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" @@ -47,8 +49,10 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" - _ "tailscale.com/util/multierr" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/osshare" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/policyclient" _ "tailscale.com/version" _ "tailscale.com/version/distro" _ "tailscale.com/wgengine" diff --git a/tstest/integration/tailscaled_deps_test_windows.go b/tstest/integration/tailscaled_deps_test_windows.go index bbf46d8c2..f3cd5e75b 100644 --- a/tstest/integration/tailscaled_deps_test_windows.go +++ b/tstest/integration/tailscaled_deps_test_windows.go @@ -18,22 +18,27 @@ import ( _ "golang.org/x/sys/windows/svc/mgr" _ "golang.zx2c4.com/wintun" _ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" + _ "tailscale.com/cmd/tailscaled/tailscaledhooks" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature" + _ "tailscale.com/feature/buildfeatures" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" + _ "tailscale.com/ipn/auditlog" _ "tailscale.com/ipn/conffile" + _ "tailscale.com/ipn/desktop" _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" - _ "tailscale.com/logtail/backoff" _ "tailscale.com/net/dns" _ "tailscale.com/net/dnsfallback" _ "tailscale.com/net/netmon" @@ -41,7 +46,6 @@ import ( _ "tailscale.com/net/proxymux" _ "tailscale.com/net/socks5" _ "tailscale.com/net/tsdial" - _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" _ "tailscale.com/safesocket" @@ -53,12 +57,15 @@ import ( _ "tailscale.com/types/key" _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" + _ "tailscale.com/util/backoff" _ "tailscale.com/util/clientmetric" - _ "tailscale.com/util/multierr" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/osdiag" _ "tailscale.com/util/osshare" - _ "tailscale.com/util/syspolicy" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/policyclient" _ "tailscale.com/util/winutil" + _ "tailscale.com/util/winutil/gp" _ "tailscale.com/version" _ "tailscale.com/version/distro" _ "tailscale.com/wf" diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index bbcf277d1..f9a33705b 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -5,7 +5,9 @@ package testcontrol import ( + "bufio" "bytes" + "cmp" "context" "encoding/binary" "encoding/json" @@ -26,13 +28,16 @@ import ( "time" "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/types/opt" "tailscale.com/types/ptr" + "tailscale.com/util/httpm" "tailscale.com/util/mak" "tailscale.com/util/must" "tailscale.com/util/rands" @@ -45,14 +50,30 @@ const msgLimit = 1 << 20 // encrypted message length limit // Server is a control plane server. Its zero value is ready for use. // Everything is stored in-memory in one tailnet. type Server struct { - Logf logger.Logf // nil means to use the log package - DERPMap *tailcfg.DERPMap // nil means to use prod DERP map - RequireAuth bool - RequireAuthKey string // required authkey for all nodes - Verbose bool - DNSConfig *tailcfg.DNSConfig // nil means no DNS config - MagicDNSDomain string - HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + Logf logger.Logf // nil means to use the log package + DERPMap *tailcfg.DERPMap // nil means to use prod DERP map + RequireAuth bool + RequireAuthKey string // required authkey for all nodes + RequireMachineAuth bool + Verbose bool + DNSConfig *tailcfg.DNSConfig // nil means no DNS config + MagicDNSDomain string + C2NResponses syncs.Map[string, func(*http.Response)] // token => onResponse func + + // PeerRelayGrants, if true, inserts relay capabilities into the wildcard + // grants rules. + PeerRelayGrants bool + + // AllNodesSameUser, if true, makes all created nodes + // belong to the same user. + AllNodesSameUser bool + + // DefaultNodeCapabilities overrides the capability map sent to each client. + DefaultNodeCapabilities *tailcfg.NodeCapMap + + // CollectServices, if non-empty, sets whether the control server asks + // for service updates. If empty, the default is "true". + CollectServices opt.Bool // ExplicitBaseURL or HTTPTestServer must be set. ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL @@ -95,9 +116,9 @@ type Server struct { logins map[key.NodePublic]*tailcfg.Login updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath - nodeKeyAuthed map[key.NodePublic]bool // key => true once authenticated - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + nodeKeyAuthed set.Set[key.NodePublic] + msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. } // BaseURL returns the server's base URL, without trailing slash. @@ -174,6 +195,52 @@ func (s *Server) AddPingRequest(nodeKeyDst key.NodePublic, pr *tailcfg.PingReque return s.addDebugMessage(nodeKeyDst, pr) } +// c2nRoundTripper is an http.RoundTripper that sends requests to a node via C2N. +type c2nRoundTripper struct { + s *Server + n key.NodePublic +} + +func (s *Server) NodeRoundTripper(n key.NodePublic) http.RoundTripper { + return c2nRoundTripper{s, n} +} + +func (rt c2nRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := req.Context() + resc := make(chan *http.Response, 1) + if err := rt.s.SendC2N(rt.n, req, func(r *http.Response) { resc <- r }); err != nil { + return nil, err + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-resc: + return r, nil + } +} + +// SendC2N sends req to node. When the response is received, onRes is called. +func (s *Server) SendC2N(node key.NodePublic, req *http.Request, onRes func(*http.Response)) error { + var buf bytes.Buffer + if err := req.Write(&buf); err != nil { + return err + } + + token := rands.HexString(10) + pr := &tailcfg.PingRequest{ + URL: "https://unused/c2n/" + token, + Log: true, + Types: "c2n", + Payload: buf.Bytes(), + } + s.C2NResponses.Store(token, onRes) + if !s.AddPingRequest(node, pr) { + s.C2NResponses.Delete(token) + return fmt.Errorf("node %v not connected", node) + } + return nil +} + // AddRawMapResponse delivers the raw MapResponse mr to nodeKeyDst. It's meant // for testing incremental map updates. // @@ -260,9 +327,7 @@ func (s *Server) initMux() { s.mux.HandleFunc("/key", s.serveKey) s.mux.HandleFunc("/machine/", s.serveMachine) s.mux.HandleFunc("/ts2021", s.serveNoiseUpgrade) - if s.HandleC2N != nil { - s.mux.Handle("/some-c2n-path/", s.HandleC2N) - } + s.mux.HandleFunc("/c2n/", s.serveC2N) } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -276,6 +341,37 @@ func (s *Server) serveUnhandled(w http.ResponseWriter, r *http.Request) { go panic(fmt.Sprintf("testcontrol.Server received unhandled request: %s", got.Bytes())) } +// serveC2N handles a POST from a node containing a c2n response. +func (s *Server) serveC2N(w http.ResponseWriter, r *http.Request) { + if err := func() error { + if r.Method != httpm.POST { + return errors.New("POST required") + } + token, ok := strings.CutPrefix(r.URL.Path, "/c2n/") + if !ok { + return fmt.Errorf("invalid path %q", r.URL.Path) + } + + onRes, ok := s.C2NResponses.Load(token) + if !ok { + return fmt.Errorf("unknown c2n token %q", token) + } + s.C2NResponses.Delete(token) + + res, err := http.ReadResponse(bufio.NewReader(r.Body), nil) + if err != nil { + return fmt.Errorf("error reading c2n response: %w", err) + } + onRes(res) + return nil + }(); err != nil { + s.logf("testcontrol: %s", err) + http.Error(w, err.Error(), 500) + return + } + w.WriteHeader(http.StatusNoContent) +} + type peerMachinePublicContextKey struct{} func (s *Server) serveNoiseUpgrade(w http.ResponseWriter, r *http.Request) { @@ -288,7 +384,7 @@ func (s *Server) serveNoiseUpgrade(w http.ResponseWriter, r *http.Request) { s.mu.Lock() noisePrivate := s.noisePrivKey s.mu.Unlock() - cc, err := controlhttp.AcceptHTTP(ctx, w, r, noisePrivate, nil) + cc, err := controlhttpserver.AcceptHTTP(ctx, w, r, noisePrivate, nil) if err != nil { log.Printf("AcceptHTTP: %v", err) return @@ -476,13 +572,22 @@ func (s *Server) AddFakeNode() { // TODO: send updates to other (non-fake?) nodes } -func (s *Server) AllUsers() (users []*tailcfg.User) { +func (s *Server) allUserProfiles() (res []tailcfg.UserProfile) { s.mu.Lock() defer s.mu.Unlock() - for _, u := range s.users { - users = append(users, u.Clone()) + for k, u := range s.users { + up := tailcfg.UserProfile{ + ID: u.ID, + DisplayName: u.DisplayName, + } + if login, ok := s.logins[k]; ok { + up.LoginName = login.LoginName + up.ProfilePicURL = cmp.Or(up.ProfilePicURL, login.ProfilePicURL) + up.DisplayName = cmp.Or(up.DisplayName, login.DisplayName) + } + res = append(res, up) } - return users + return res } func (s *Server) AllNodes() (nodes []*tailcfg.Node) { @@ -512,6 +617,10 @@ func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) return u, s.logins[nodeKey] } id := tailcfg.UserID(len(s.users) + 1) + if s.AllNodesSameUser { + id = 123 + } + s.logf("Created user %v for node %s", id, nodeKey) loginName := fmt.Sprintf("user-%d@%s", id, domain) displayName := fmt.Sprintf("User %d", id) login := &tailcfg.Login{ @@ -523,9 +632,7 @@ func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) } user := &tailcfg.User{ ID: id, - LoginName: loginName, DisplayName: displayName, - Logins: []tailcfg.LoginID{login.ID}, } s.users[nodeKey] = user s.logins[nodeKey] = login @@ -574,14 +681,35 @@ func (s *Server) CompleteAuth(authPathOrURL string) bool { if ap.nodeKey.IsZero() { panic("zero AuthPath.NodeKey") } - if s.nodeKeyAuthed == nil { - s.nodeKeyAuthed = map[key.NodePublic]bool{} - } - s.nodeKeyAuthed[ap.nodeKey] = true + s.nodeKeyAuthed.Make() + s.nodeKeyAuthed.Add(ap.nodeKey) ap.CompleteSuccessfully() return true } +// Complete the device approval for this node. +// +// This function returns false if the node does not exist, or you try to +// approve a device against a different control server. +func (s *Server) CompleteDeviceApproval(controlUrl string, urlStr string, nodeKey *key.NodePublic) bool { + s.mu.Lock() + defer s.mu.Unlock() + + node, ok := s.nodes[*nodeKey] + if !ok { + return false + } + + if urlStr != controlUrl+"/admin" { + return false + } + + sendUpdate(s.updates[node.ID], updateSelfChanged) + + node.MachineAuthorized = true + return true +} + func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) { msg, err := io.ReadAll(io.LimitReader(r.Body, msgLimit)) r.Body.Close() @@ -630,6 +758,25 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. // some follow-ups? For now all are successes. } + // The in-memory list of nodes, users, and logins is keyed by + // the node key. If the node key changes, update all the data stores + // to use the new node key. + s.mu.Lock() + if _, oldNodeKeyOk := s.nodes[req.OldNodeKey]; oldNodeKeyOk { + if _, newNodeKeyOk := s.nodes[req.NodeKey]; !newNodeKeyOk { + s.nodes[req.OldNodeKey].Key = req.NodeKey + s.nodes[req.NodeKey] = s.nodes[req.OldNodeKey] + + s.users[req.NodeKey] = s.users[req.OldNodeKey] + s.logins[req.NodeKey] = s.logins[req.OldNodeKey] + + delete(s.nodes, req.OldNodeKey) + delete(s.users, req.OldNodeKey) + delete(s.logins, req.OldNodeKey) + } + } + s.mu.Unlock() + nk := req.NodeKey user, login := s.getUser(nk) @@ -637,36 +784,50 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. if s.nodes == nil { s.nodes = map[key.NodePublic]*tailcfg.Node{} } + _, ok := s.nodes[nk] + machineAuthorized := !s.RequireMachineAuth + if !ok { - machineAuthorized := true // TODO: add Server.RequireMachineAuth + nodeID := len(s.nodes) + 1 + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(nodeID>>8), uint8(nodeID)), 32) + v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) - v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) + allowedIPs := []netip.Prefix{ + v4Prefix, + v6Prefix, + } - allowedIPs := []netip.Prefix{ - v4Prefix, - v6Prefix, - } + var capMap tailcfg.NodeCapMap + if s.DefaultNodeCapabilities != nil { + capMap = *s.DefaultNodeCapabilities + } else { + capMap = tailcfg.NodeCapMap{ + tailcfg.CapabilityHTTPS: []tailcfg.RawMessage{}, + tailcfg.NodeAttrFunnel: []tailcfg.RawMessage{}, + tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, + tailcfg.CapabilityFunnelPorts + "?ports=8080,443": []tailcfg.RawMessage{}, + } + } - s.nodes[nk] = &tailcfg.Node{ - ID: tailcfg.NodeID(user.ID), - StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), - User: user.ID, - Machine: mkey, - Key: req.NodeKey, - MachineAuthorized: machineAuthorized, - Addresses: allowedIPs, - AllowedIPs: allowedIPs, - Hostinfo: req.Hostinfo.View(), - Name: req.Hostinfo.Hostname, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityHTTPS, - tailcfg.NodeAttrFunnel, - tailcfg.CapabilityFunnelPorts + "?ports=8080,443", - }, + node := &tailcfg.Node{ + ID: tailcfg.NodeID(nodeID), + StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(nodeID))), + User: user.ID, + Machine: mkey, + Key: req.NodeKey, + MachineAuthorized: machineAuthorized, + Addresses: allowedIPs, + AllowedIPs: allowedIPs, + Hostinfo: req.Hostinfo.View(), + Name: req.Hostinfo.Hostname, + Cap: req.Version, + CapMap: capMap, + Capabilities: slices.Collect(maps.Keys(capMap)), + } + s.nodes[nk] = node } requireAuth := s.RequireAuth - if requireAuth && s.nodeKeyAuthed[nk] { + if requireAuth && s.nodeKeyAuthed.Contains(nk) { requireAuth = false } allExpired := s.allExpired @@ -793,11 +954,12 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi endpoints := filterInvalidIPv6Endpoints(req.Endpoints) node.Endpoints = endpoints node.DiscoKey = req.DiscoKey + node.Cap = req.Version if req.Hostinfo != nil { node.Hostinfo = req.Hostinfo.View() if ni := node.Hostinfo.NetInfo(); ni.Valid() { if ni.PreferredDERP() != 0 { - node.DERP = fmt.Sprintf("127.3.3.40:%d", ni.PreferredDERP()) + node.HomeDERP = ni.PreferredDERP() } } } @@ -831,15 +993,17 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi w.WriteHeader(200) for { - if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok { - if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { - s.logf("sendMapMsg of raw message: %v", err) - return - } - if streaming { + // Only send raw map responses to the streaming poll, to avoid a + // non-streaming map request beating the streaming poll in a race and + // potentially dropping the map response. + if streaming { + if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok { + if err := s.sendMapMsg(w, compress, resBytes); err != nil { + s.logf("sendMapMsg of raw message: %v", err) + return + } continue } - return } if s.canGenerateAutomaticMapResponseFor(req.NodeKey) { @@ -864,7 +1028,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi s.logf("json.Marshal: %v", err) return } - if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { + if err := s.sendMapMsg(w, compress, resBytes); err != nil { return } } @@ -895,7 +1059,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi } break keepAliveLoop case <-keepAliveTimerCh: - if err := s.sendMapMsg(w, mkey, compress, keepAliveMsg); err != nil { + if err := s.sendMapMsg(w, compress, keepAliveMsg); err != nil { return } } @@ -909,14 +1073,21 @@ var keepAliveMsg = &struct { KeepAlive: true, } -func packetFilterWithIngressCaps() []tailcfg.FilterRule { +func packetFilterWithIngress(addRelayCaps bool) []tailcfg.FilterRule { out := slices.Clone(tailcfg.FilterAllowAll) + caps := []tailcfg.PeerCapability{ + tailcfg.PeerCapabilityIngress, + } + if addRelayCaps { + caps = append(caps, tailcfg.PeerCapabilityRelay) + caps = append(caps, tailcfg.PeerCapabilityRelayTarget) + } out = append(out, tailcfg.FilterRule{ SrcIPs: []string{"*"}, CapGrant: []tailcfg.CapGrant{ { Dsts: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, - Caps: []tailcfg.PeerCapability{tailcfg.PeerCapabilityIngress}, + Caps: caps, }, }, }) @@ -941,13 +1112,12 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, node.CapMap = nodeCapMap node.Capabilities = append(node.Capabilities, tailcfg.NodeAttrDisableUPnP) - user, _ := s.getUser(nk) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) dns := s.DNSConfig if dns != nil && s.MagicDNSDomain != "" { dns = dns.Clone() dns.CertDomains = []string{ - fmt.Sprintf(node.Hostinfo.Hostname() + "." + s.MagicDNSDomain), + node.Hostinfo.Hostname() + "." + s.MagicDNSDomain, } } @@ -955,8 +1125,8 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, Node: node, DERPMap: s.DERPMap, Domain: domain, - CollectServices: "true", - PacketFilter: packetFilterWithIngressCaps(), + CollectServices: cmp.Or(s.CollectServices, opt.True), + PacketFilter: packetFilterWithIngress(s.PeerRelayGrants), DNSConfig: dns, ControlTime: &t, } @@ -981,7 +1151,11 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, s.mu.Lock() peerAddress := s.masquerades[p.Key][node.Key] routes := s.nodeSubnetRoutes[p.Key] + peerCapMap := maps.Clone(s.nodeCapMaps[p.Key]) s.mu.Unlock() + if peerCapMap != nil { + p.CapMap = peerCapMap + } if peerAddress.IsValid() { if peerAddress.Is6() { p.Addresses[1] = netip.PrefixFrom(peerAddress, peerAddress.BitLen()) @@ -1001,15 +1175,9 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, sort.Slice(res.Peers, func(i, j int) bool { return res.Peers[i].ID < res.Peers[j].ID }) - for _, u := range s.AllUsers() { - res.UserProfiles = append(res.UserProfiles, tailcfg.UserProfile{ - ID: u.ID, - LoginName: u.LoginName, - DisplayName: u.DisplayName, - }) - } + res.UserProfiles = s.allUserProfiles() - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(node.ID>>8), uint8(node.ID)), 32) v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) res.Node.Addresses = []netip.Prefix{ @@ -1040,18 +1208,25 @@ func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { func (s *Server) hasPendingRawMapMessage(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - _, ok := s.msgToSend[nk].(*tailcfg.MapResponse) + _, ok := s.msgToSend[nk] return ok } func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok bool) { s.mu.Lock() defer s.mu.Unlock() - mr, ok := s.msgToSend[nk].(*tailcfg.MapResponse) + mr, ok := s.msgToSend[nk] if !ok { return nil, false } delete(s.msgToSend, nk) + + // If it's a bare PingRequest, wrap it in a MapResponse. + switch pr := mr.(type) { + case *tailcfg.PingRequest: + mr = &tailcfg.MapResponse{PingRequest: pr} + } + var err error mapResJSON, err = json.Marshal(mr) if err != nil { @@ -1060,7 +1235,7 @@ func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok boo return mapResJSON, true } -func (s *Server) sendMapMsg(w http.ResponseWriter, mkey key.MachinePublic, compress bool, msg any) error { +func (s *Server) sendMapMsg(w http.ResponseWriter, compress bool, msg any) error { resBytes, err := s.encode(compress, msg) if err != nil { return err diff --git a/tstest/integration/vms/README.md b/tstest/integration/vms/README.md index 519c3d000..a68ed0514 100644 --- a/tstest/integration/vms/README.md +++ b/tstest/integration/vms/README.md @@ -1,7 +1,6 @@ # End-to-End VM-based Integration Testing -This test spins up a bunch of common linux distributions and then tries to get -them to connect to a +These tests spin up a Tailscale client in a Linux VM and try to connect it to [`testcontrol`](https://pkg.go.dev/tailscale.com/tstest/integration/testcontrol) server. @@ -55,26 +54,6 @@ If you pass the `-no-s3` flag to `go test`, the S3 step will be skipped in favor of downloading the images directly from upstream sources, which may cause the test to fail in odd places. -### Distribution Picking - -This test runs on a large number of distributions. By default it tries to run -everything, which may or may not be ideal for you. If you only want to test a -subset of distributions, you can use the `--distro-regex` flag to match a subset -of distributions using a [regular expression](https://golang.org/pkg/regexp/) -such as like this: - -```console -$ go test -run-vm-tests -distro-regex centos -``` - -This would run all tests on all versions of CentOS. - -```console -$ go test -run-vm-tests -distro-regex '(debian|ubuntu)' -``` - -This would run all tests on all versions of Debian and Ubuntu. - ### Ram Limiting This test uses a lot of memory. In order to avoid making machines run out of diff --git a/tstest/integration/vms/distros.hujson b/tstest/integration/vms/distros.hujson index 049091ed5..2c90f9a2f 100644 --- a/tstest/integration/vms/distros.hujson +++ b/tstest/integration/vms/distros.hujson @@ -12,24 +12,16 @@ // /var/log/cloud-init-output.log for what you messed up. [ { - "Name": "ubuntu-18-04", - "URL": "https://cloud-images.ubuntu.com/releases/bionic/release-20210817/ubuntu-18.04-server-cloudimg-amd64.img", - "SHA256Sum": "1ee1039f0b91c8367351413b5b5f56026aaf302fd5f66f17f8215132d6e946d2", + "Name": "ubuntu-24-04", + "URL": "https://cloud-images.ubuntu.com/noble/20250523/noble-server-cloudimg-amd64.img", + "SHA256Sum": "0e865619967706765cdc8179fb9929202417ab3a0719d77d8c8942d38aa9611b", "MemoryMegs": 512, "PackageManager": "apt", "InitSystem": "systemd" }, { - "Name": "ubuntu-20-04", - "URL": "https://cloud-images.ubuntu.com/releases/focal/release-20210819/ubuntu-20.04-server-cloudimg-amd64.img", - "SHA256Sum": "99e25e6e344e3a50a081235e825937238a3d51b099969e107ef66f0d3a1f955e", - "MemoryMegs": 512, - "PackageManager": "apt", - "InitSystem": "systemd" - }, - { - "Name": "nixos-21-11", - "URL": "channel:nixos-21.11", + "Name": "nixos-25-05", + "URL": "channel:nixos-25.05", "SHA256Sum": "lolfakesha", "MemoryMegs": 512, "PackageManager": "nix", diff --git a/tstest/integration/vms/harness_test.go b/tstest/integration/vms/harness_test.go index 1e080414d..256227d6c 100644 --- a/tstest/integration/vms/harness_test.go +++ b/tstest/integration/vms/harness_test.go @@ -134,11 +134,12 @@ func newHarness(t *testing.T) *Harness { loginServer := fmt.Sprintf("http://%s", ln.Addr()) t.Logf("loginServer: %s", loginServer) + binaries := integration.GetBinaries(t) h := &Harness{ pubKey: string(pubkey), - binaryDir: integration.BinaryDir(t), - cli: integration.TailscaleBinary(t), - daemon: integration.TailscaledBinary(t), + binaryDir: binaries.Dir, + cli: binaries.Tailscale.Path, + daemon: binaries.Tailscaled.Path, signer: signer, loginServerURL: loginServer, cs: cs, diff --git a/tstest/integration/vms/nixos_test.go b/tstest/integration/vms/nixos_test.go index c2998ff3c..02b040fed 100644 --- a/tstest/integration/vms/nixos_test.go +++ b/tstest/integration/vms/nixos_test.go @@ -97,7 +97,7 @@ let # Wrap tailscaled with the ip and iptables commands. wrapProgram $out/bin/tailscaled --prefix PATH : ${ - lib.makeBinPath [ iproute iptables ] + lib.makeBinPath [ iproute2 iptables ] } # Install systemd unit. @@ -127,6 +127,9 @@ in { # yolo, this vm can sudo freely. security.sudo.wheelNeedsPassword = false; + # nix considers squid insecure, but this is fine for a test. + nixpkgs.config.permittedInsecurePackages = [ "squid-7.0.1" ]; + # Enable cloud-init so we can set VM hostnames and the like the same as other # distros. This will also take care of SSH keys. It's pretty handy. services.cloud-init = { diff --git a/tstest/integration/vms/opensuse_leap_15_1_test.go b/tstest/integration/vms/opensuse_leap_15_1_test.go deleted file mode 100644 index 7d3ac579e..000000000 --- a/tstest/integration/vms/opensuse_leap_15_1_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !plan9 - -package vms - -import ( - "encoding/json" - "os" - "path/filepath" - "testing" - - "github.com/google/uuid" -) - -/* - The images that we use for OpenSUSE Leap 15.1 have an issue that makes the - nocloud backend[1] for cloud-init just not work. As a distro-specific - workaround, we're gonna pretend to be OpenStack. - - TODO(Xe): delete once we no longer need to support OpenSUSE Leap 15.1. - - [1]: https://cloudinit.readthedocs.io/en/latest/topics/datasources/nocloud.html -*/ - -type openSUSELeap151MetaData struct { - Zone string `json:"availability_zone"` // nova - Hostname string `json:"hostname"` // opensuse-leap-15-1 - LaunchIndex string `json:"launch_index"` // 0 - Meta openSUSELeap151MetaDataMeta `json:"meta"` // some openstack metadata we don't need to care about - Name string `json:"name"` // opensuse-leap-15-1 - UUID string `json:"uuid"` // e9c664cd-b116-433b-aa61-7ff420163dcd -} - -type openSUSELeap151MetaDataMeta struct { - Role string `json:"role"` // server - DSMode string `json:"dsmode"` // local - Essential string `json:"essential"` // essential -} - -func hackOpenSUSE151UserData(t *testing.T, d Distro, dir string) bool { - if d.Name != "opensuse-leap-15-1" { - return false - } - - t.Log("doing OpenSUSE Leap 15.1 hack") - osDir := filepath.Join(dir, "openstack", "latest") - err := os.MkdirAll(osDir, 0755) - if err != nil { - t.Fatalf("can't make metadata home: %v", err) - } - - metadata, err := json.Marshal(openSUSELeap151MetaData{ - Zone: "nova", - Hostname: d.Name, - LaunchIndex: "0", - Meta: openSUSELeap151MetaDataMeta{ - Role: "server", - DSMode: "local", - Essential: "false", - }, - Name: d.Name, - UUID: uuid.New().String(), - }) - if err != nil { - t.Fatalf("can't encode metadata: %v", err) - } - err = os.WriteFile(filepath.Join(osDir, "meta_data.json"), metadata, 0666) - if err != nil { - t.Fatalf("can't write to meta_data.json: %v", err) - } - - data, err := os.ReadFile(filepath.Join(dir, "user-data")) - if err != nil { - t.Fatalf("can't read user_data: %v", err) - } - - err = os.WriteFile(filepath.Join(osDir, "user_data"), data, 0666) - if err != nil { - t.Fatalf("can't create output user_data: %v", err) - } - - return true -} diff --git a/tstest/integration/vms/regex_flag.go b/tstest/integration/vms/regex_flag.go deleted file mode 100644 index 02e399ecd..000000000 --- a/tstest/integration/vms/regex_flag.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import "regexp" - -type regexValue struct { - r *regexp.Regexp -} - -func (r *regexValue) String() string { - if r.r == nil { - return "" - } - - return r.r.String() -} - -func (r *regexValue) Set(val string) error { - if rex, err := regexp.Compile(val); err != nil { - return err - } else { - r.r = rex - return nil - } -} - -func (r regexValue) Unwrap() *regexp.Regexp { return r.r } diff --git a/tstest/integration/vms/regex_flag_test.go b/tstest/integration/vms/regex_flag_test.go deleted file mode 100644 index 0f4e5f8f7..000000000 --- a/tstest/integration/vms/regex_flag_test.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package vms - -import ( - "flag" - "testing" -) - -func TestRegexFlag(t *testing.T) { - var v regexValue - fs := flag.NewFlagSet(t.Name(), flag.PanicOnError) - fs.Var(&v, "regex", "regex to parse") - - const want = `.*` - fs.Parse([]string{"-regex", want}) - if v.Unwrap().String() != want { - t.Fatalf("got wrong regex: %q, wanted: %q", v.Unwrap().String(), want) - } -} diff --git a/tstest/integration/vms/top_level_test.go b/tstest/integration/vms/top_level_test.go index c107fd89c..5db237b6e 100644 --- a/tstest/integration/vms/top_level_test.go +++ b/tstest/integration/vms/top_level_test.go @@ -14,17 +14,13 @@ import ( expect "github.com/tailscale/goexpect" ) -func TestRunUbuntu1804(t *testing.T) { +func TestRunUbuntu2404(t *testing.T) { testOneDistribution(t, 0, Distros[0]) } -func TestRunUbuntu2004(t *testing.T) { - testOneDistribution(t, 1, Distros[1]) -} - -func TestRunNixos2111(t *testing.T) { +func TestRunNixos2505(t *testing.T) { t.Parallel() - testOneDistribution(t, 2, Distros[2]) + testOneDistribution(t, 1, Distros[1]) } // TestMITMProxy is a smoke test for derphttp through a MITM proxy. @@ -39,13 +35,7 @@ func TestRunNixos2111(t *testing.T) { func TestMITMProxy(t *testing.T) { t.Parallel() setupTests(t) - distro := Distros[2] // nixos-21.11 - - if distroRex.Unwrap().MatchString(distro.Name) { - t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) - } else { - t.Skip("regex not matched") - } + distro := Distros[1] // nixos-25.05 ctx, done := context.WithCancel(context.Background()) t.Cleanup(done) diff --git a/tstest/integration/vms/vms_test.go b/tstest/integration/vms/vms_test.go index 6d73a3f78..c3a3775de 100644 --- a/tstest/integration/vms/vms_test.go +++ b/tstest/integration/vms/vms_test.go @@ -15,7 +15,6 @@ import ( "os" "os/exec" "path/filepath" - "regexp" "strconv" "strings" "sync" @@ -28,7 +27,6 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/sync/semaphore" "tailscale.com/tstest" - "tailscale.com/tstest/integration" "tailscale.com/types/logger" ) @@ -44,20 +42,8 @@ var ( useVNC = flag.Bool("use-vnc", false, "if set, display guest vms over VNC") verboseLogcatcher = flag.Bool("verbose-logcatcher", true, "if set, print logcatcher to t.Logf") verboseQemu = flag.Bool("verbose-qemu", true, "if set, print qemu console to t.Logf") - distroRex = func() *regexValue { - result := ®exValue{r: regexp.MustCompile(`.*`)} - flag.Var(result, "distro-regex", "The regex that matches what distros should be run") - return result - }() ) -func TestMain(m *testing.M) { - flag.Parse() - v := m.Run() - integration.CleanupBinaries() - os.Exit(v) -} - func TestDownloadImages(t *testing.T) { if !*runVMTests { t.Skip("not running integration tests (need --run-vm-tests)") @@ -67,9 +53,6 @@ func TestDownloadImages(t *testing.T) { distro := d t.Run(distro.Name, func(t *testing.T) { t.Parallel() - if !distroRex.Unwrap().MatchString(distro.Name) { - t.Skipf("distro name %q doesn't match regex: %s", distro.Name, distroRex) - } if strings.HasPrefix(distro.Name, "nixos") { t.Skip("NixOS is built on the fly, no need to download it") } @@ -183,10 +166,6 @@ func mkSeed(t *testing.T, d Distro, sshKey, hostURL, tdir string, port int) { filepath.Join(dir, "user-data"), } - if hackOpenSUSE151UserData(t, d, dir) { - args = append(args, filepath.Join(dir, "openstack")) - } - run(t, tdir, "genisoimage", args...) } @@ -205,14 +184,14 @@ type ipMapping struct { // it is difficult to be 100% sure. This function should be used with care. It // will probably do what you want, but it is very easy to hold this wrong. func getProbablyFreePortNumber() (int, error) { - l, err := net.Listen("tcp", ":0") + ln, err := net.Listen("tcp", ":0") if err != nil { return 0, err } - defer l.Close() + defer ln.Close() - _, port, err := net.SplitHostPort(l.Addr().String()) + _, port, err := net.SplitHostPort(ln.Addr().String()) if err != nil { return 0, err } @@ -255,12 +234,6 @@ var ramsem struct { func testOneDistribution(t *testing.T, n int, distro Distro) { setupTests(t) - if distroRex.Unwrap().MatchString(distro.Name) { - t.Logf("%s matches %s", distro.Name, distroRex.Unwrap()) - } else { - t.Skip("regex not matched") - } - ctx, done := context.WithCancel(context.Background()) t.Cleanup(done) diff --git a/tstest/iosdeps/iosdeps_test.go b/tstest/iosdeps/iosdeps_test.go index 08df9a930..b533724eb 100644 --- a/tstest/iosdeps/iosdeps_test.go +++ b/tstest/iosdeps/iosdeps_test.go @@ -14,9 +14,17 @@ func TestDeps(t *testing.T) { GOOS: "ios", GOARCH: "arm64", BadDeps: map[string]string{ - "testing": "do not use testing package in production code", - "text/template": "linker bloat (MethodByName)", - "html/template": "linker bloat (MethodByName)", + "testing": "do not use testing package in production code", + "text/template": "linker bloat (MethodByName)", + "html/template": "linker bloat (MethodByName)", + "tailscale.com/net/wsconn": "https://github.com/tailscale/tailscale/issues/13762", + "github.com/coder/websocket": "https://github.com/tailscale/tailscale/issues/13762", + "github.com/mitchellh/go-ps": "https://github.com/tailscale/tailscale/pull/13759", + "database/sql/driver": "iOS doesn't use an SQL database", + "github.com/google/uuid": "see tailscale/tailscale#13760", + "tailscale.com/clientupdate/distsign": "downloads via AppStore, not distsign", + "github.com/tailscale/hujson": "no config file support on iOS", + "tailscale.com/feature/capture": "no debug packet capture on iOS", }, }.Check(t) } diff --git a/tstest/log.go b/tstest/log.go index cb67c609a..d081c819d 100644 --- a/tstest/log.go +++ b/tstest/log.go @@ -13,6 +13,7 @@ import ( "go4.org/mem" "tailscale.com/types/logger" + "tailscale.com/util/testenv" ) type testLogWriter struct { @@ -149,7 +150,7 @@ func (ml *MemLogger) String() string { // WhileTestRunningLogger returns a logger.Logf that logs to t.Logf until the // test finishes, at which point it no longer logs anything. -func WhileTestRunningLogger(t testing.TB) logger.Logf { +func WhileTestRunningLogger(t testenv.TB) logger.Logf { var ( mu sync.RWMutex done bool diff --git a/tstest/mts/mts.go b/tstest/mts/mts.go new file mode 100644 index 000000000..c10d69d8d --- /dev/null +++ b/tstest/mts/mts.go @@ -0,0 +1,599 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || darwin + +// The mts ("Multiple Tailscale") command runs multiple tailscaled instances for +// development, managing their directories and sockets, and lets you easily direct +// tailscale CLI commands to them. +package main + +import ( + "bufio" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "maps" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "regexp" + "slices" + "strings" + "sync" + "syscall" + "time" + + "tailscale.com/client/local" + "tailscale.com/types/bools" + "tailscale.com/types/lazy" + "tailscale.com/util/mak" +) + +func usage(args ...any) { + var format string + if len(args) > 0 { + format, args = args[0].(string), args[1:] + } + if format != "" { + format = strings.TrimSpace(format) + "\n\n" + fmt.Fprintf(os.Stderr, format, args...) + } + io.WriteString(os.Stderr, strings.TrimSpace(` +usage: + + mts server # manage tailscaled instances + mts server run # run the mts server (parent process of all tailscaled) + mts server list # list all tailscaled and their state + mts server list # show details of named instance + mts server add # add+start new named tailscaled + mts server start # start a previously added tailscaled + mts server stop # stop & remove a named tailscaled + mts server rm # stop & remove a named tailscaled + mts server logs [-f] # get/follow tailscaled logs + + mts [tailscale CLI args] # run Tailscale CLI against a named instance + e.g. + mts gmail1 up + mts github2 status --json + `)+"\n") + os.Exit(1) +} + +func main() { + // Don't use flag.Parse here; we mostly just delegate through + // to the Tailscale CLI. + + if len(os.Args) < 2 { + usage() + } + firstArg, args := os.Args[1], os.Args[2:] + if firstArg == "server" || firstArg == "s" { + if err := runMTSServer(args); err != nil { + log.Fatal(err) + } + } else { + var c Client + inst := firstArg + c.RunCommand(inst, args) + } +} + +func runMTSServer(args []string) error { + if len(args) == 0 { + usage() + } + cmd, args := args[0], args[1:] + if cmd == "run" { + var s Server + return s.Run() + } + + // Commands other than "run" all use the HTTP client to + // hit the mts server over its unix socket. + var c Client + + switch cmd { + default: + usage("unknown mts server subcommand %q", cmd) + case "list", "ls": + list, err := c.List() + if err != nil { + return err + } + if len(args) == 0 { + names := slices.Sorted(maps.Keys(list.Instances)) + for _, name := range names { + running := list.Instances[name].Running + fmt.Printf("%10s %s\n", bools.IfElse(running, "RUNNING", "stopped"), name) + } + } else { + for _, name := range args { + inst, ok := list.Instances[name] + if !ok { + return fmt.Errorf("no instance named %q", name) + } + je := json.NewEncoder(os.Stdout) + je.SetIndent("", " ") + if err := je.Encode(inst); err != nil { + return err + } + } + } + + case "rm": + if len(args) == 0 { + return fmt.Errorf("missing instance name(s) to remove") + } + log.SetFlags(0) + for _, name := range args { + ok, err := c.Remove(name) + if err != nil { + return err + } + if ok { + log.Printf("%s deleted.", name) + } else { + log.Printf("%s didn't exist.", name) + } + } + case "stop": + if len(args) == 0 { + return fmt.Errorf("missing instance name(s) to stop") + } + log.SetFlags(0) + for _, name := range args { + ok, err := c.Stop(name) + if err != nil { + return err + } + if ok { + log.Printf("%s stopped.", name) + } else { + log.Printf("%s didn't exist.", name) + } + } + case "start", "restart": + list, err := c.List() + if err != nil { + return err + } + shouldStop := cmd == "restart" + for _, arg := range args { + is, ok := list.Instances[arg] + if !ok { + return fmt.Errorf("no instance named %q", arg) + } + if is.Running { + if shouldStop { + if _, err := c.Stop(arg); err != nil { + return fmt.Errorf("stopping %q: %w", arg, err) + } + } else { + log.SetFlags(0) + log.Printf("%s already running.", arg) + continue + } + } + // Creating an existing one starts it up. + if err := c.Create(arg); err != nil { + return fmt.Errorf("starting %q: %w", arg, err) + } + } + case "add": + if len(args) == 0 { + return fmt.Errorf("missing instance name(s) to add") + } + for _, name := range args { + if err := c.Create(name); err != nil { + return fmt.Errorf("creating %q: %w", name, err) + } + } + case "logs": + fs := flag.NewFlagSet("logs", flag.ExitOnError) + fs.Usage = func() { usage() } + follow := fs.Bool("f", false, "follow logs") + fs.Parse(args) + log.Printf("Parsed; following=%v, args=%q", *follow, fs.Args()) + if fs.NArg() != 1 { + usage() + } + cmd := bools.IfElse(*follow, "tail", "cat") + args := []string{cmd} + if *follow { + args = append(args, "-f") + } + path, err := exec.LookPath(cmd) + if err != nil { + return fmt.Errorf("looking up %q: %w", cmd, err) + } + args = append(args, instLogsFile(fs.Arg(0))) + log.Fatal(syscall.Exec(path, args, os.Environ())) + } + return nil +} + +type Client struct { +} + +func (c *Client) client() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", mtsSock()) + }, + }, + } +} + +func getJSON[T any](res *http.Response, err error) (T, error) { + var ret T + if err != nil { + return ret, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + body, _ := io.ReadAll(res.Body) + return ret, fmt.Errorf("unexpected status: %v: %s", res.Status, body) + } + if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { + return ret, err + } + return ret, nil +} + +func (c *Client) List() (listResponse, error) { + return getJSON[listResponse](c.client().Get("http://mts/list")) +} + +func (c *Client) Remove(name string) (found bool, err error) { + return getJSON[bool](c.client().PostForm("http://mts/rm", url.Values{ + "name": []string{name}, + })) +} + +func (c *Client) Stop(name string) (found bool, err error) { + return getJSON[bool](c.client().PostForm("http://mts/stop", url.Values{ + "name": []string{name}, + })) +} + +func (c *Client) Create(name string) error { + req, err := http.NewRequest("POST", "http://mts/create/"+name, nil) + if err != nil { + return err + } + resp, err := c.client().Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("unexpected status: %v: %s", resp.Status, body) + } + return nil +} + +func (c *Client) RunCommand(name string, args []string) { + sock := instSock(name) + lc := &local.Client{ + Socket: sock, + UseSocketOnly: true, + } + probeCtx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + if _, err := lc.StatusWithoutPeers(probeCtx); err != nil { + log.Fatalf("instance %q not running? start with 'mts server start %q'; got error: %v", name, name, err) + } + args = append([]string{"run", "tailscale.com/cmd/tailscale", "--socket=" + sock}, args...) + cmd := exec.Command("go", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + err := cmd.Run() + if err == nil { + os.Exit(0) + } + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + panic(err) +} + +type Server struct { + lazyTailscaled lazy.GValue[string] + + mu sync.Mutex + cmds map[string]*exec.Cmd // running tailscaled instances +} + +func (s *Server) tailscaled() string { + v, err := s.lazyTailscaled.GetErr(func() (string, error) { + out, err := exec.Command("go", "list", "-f", "{{.Target}}", "tailscale.com/cmd/tailscaled").CombinedOutput() + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil + }) + if err != nil { + panic(err) + } + return v +} + +func (s *Server) Run() error { + if err := os.MkdirAll(mtsRoot(), 0700); err != nil { + return err + } + sock := mtsSock() + os.Remove(sock) + log.Printf("Multi-Tailscaled Server running; listening on %q ...", sock) + ln, err := net.Listen("unix", sock) + if err != nil { + return err + } + return http.Serve(ln, s) +} + +var validNameRx = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +func validInstanceName(name string) bool { + return validNameRx.MatchString(name) +} + +func (s *Server) InstanceRunning(name string) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.cmds[name] + return ok +} + +func (s *Server) Stop(name string) { + s.mu.Lock() + defer s.mu.Unlock() + if cmd, ok := s.cmds[name]; ok { + if err := cmd.Process.Kill(); err != nil { + log.Printf("error killing %q: %v", name, err) + } + delete(s.cmds, name) + } +} + +func (s *Server) RunInstance(name string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.cmds[name]; ok { + return fmt.Errorf("instance %q already running", name) + } + + if !validInstanceName(name) { + return fmt.Errorf("invalid instance name %q", name) + } + dir := filepath.Join(mtsRoot(), name) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + env := os.Environ() + env = append(env, "TS_DEBUG_LOG_RATE=all") + if ef, err := os.Open(instEnvFile(name)); err == nil { + defer ef.Close() + sc := bufio.NewScanner(ef) + for sc.Scan() { + t := strings.TrimSpace(sc.Text()) + if strings.HasPrefix(t, "#") || !strings.Contains(t, "=") { + continue + } + env = append(env, t) + } + } else if os.IsNotExist(err) { + // Write an example one. + os.WriteFile(instEnvFile(name), fmt.Appendf(nil, "# Example mts env.txt file; uncomment/add stuff you want for %q\n\n#TS_DEBUG_MAP=1\n#TS_DEBUG_REGISTER=1\n#TS_NO_LOGS_NO_SUPPORT=1\n", name), 0600) + } + + extraArgs := []string{"--verbose=1"} + if af, err := os.Open(instArgsFile(name)); err == nil { + extraArgs = nil // clear default args + defer af.Close() + sc := bufio.NewScanner(af) + for sc.Scan() { + t := strings.TrimSpace(sc.Text()) + if strings.HasPrefix(t, "#") || t == "" { + continue + } + extraArgs = append(extraArgs, t) + } + } else if os.IsNotExist(err) { + // Write an example one. + os.WriteFile(instArgsFile(name), fmt.Appendf(nil, "# Example mts args.txt file for instance %q.\n# One line per extra arg to tailscaled; no magic string quoting\n\n--verbose=1\n#--socks5-server=127.0.0.1:5000\n", name), 0600) + } + + log.Printf("Running Tailscale daemon %q in %q", name, dir) + + args := []string{ + "--tun=userspace-networking", + "--statedir=" + filepath.Join(dir), + "--socket=" + filepath.Join(dir, "tailscaled.sock"), + } + args = append(args, extraArgs...) + + cmd := exec.Command(s.tailscaled(), args...) + cmd.Dir = dir + cmd.Env = env + + out, err := cmd.StdoutPipe() + if err != nil { + return err + } + cmd.Stderr = cmd.Stdout + + logs := instLogsFile(name) + logFile, err := os.OpenFile(logs, os.O_CREATE|os.O_WRONLY|os.O_APPEND|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("opening logs file: %w", err) + } + + go func() { + bs := bufio.NewScanner(out) + for bs.Scan() { + // TODO(bradfitz): record in memory too, serve via HTTP + line := strings.TrimSpace(bs.Text()) + fmt.Fprintf(logFile, "%s\n", line) + fmt.Printf("tailscaled[%s]: %s\n", name, line) + } + }() + + if err := cmd.Start(); err != nil { + return err + } + go func() { + err := cmd.Wait() + logFile.Close() + log.Printf("Tailscale daemon %q exited: %v", name, err) + s.mu.Lock() + defer s.mu.Unlock() + delete(s.cmds, name) + }() + + mak.Set(&s.cmds, name, cmd) + return nil +} + +type listResponse struct { + // Instances maps instance name to its details. + Instances map[string]listResponseInstance `json:"instances"` +} + +type listResponseInstance struct { + Name string `json:"name"` + Dir string `json:"dir"` + Sock string `json:"sock"` + Running bool `json:"running"` + Env string `json:"env"` + Args string `json:"args"` + Logs string `json:"logs"` +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", " ") + e.Encode(v) +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/list" { + var res listResponse + for _, name := range s.InstanceNames() { + mak.Set(&res.Instances, name, listResponseInstance{ + Name: name, + Dir: instDir(name), + Sock: instSock(name), + Running: s.InstanceRunning(name), + Env: instEnvFile(name), + Args: instArgsFile(name), + Logs: instLogsFile(name), + }) + } + writeJSON(w, res) + return + } + if r.URL.Path == "/rm" || r.URL.Path == "/stop" { + shouldRemove := r.URL.Path == "/rm" + if r.Method != "POST" { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + target := r.FormValue("name") + var ok bool + for _, name := range s.InstanceNames() { + if name != target { + continue + } + ok = true + s.Stop(name) + if shouldRemove { + if err := os.RemoveAll(instDir(name)); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + break + } + writeJSON(w, ok) + return + } + if inst, ok := strings.CutPrefix(r.URL.Path, "/create/"); ok { + if !s.InstanceRunning(inst) { + if err := s.RunInstance(inst); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + fmt.Fprintf(w, "OK\n") + return + } + if r.URL.Path == "/" { + fmt.Fprintf(w, "This is mts, the multi-tailscaled server.\n") + return + } + http.NotFound(w, r) +} + +func (s *Server) InstanceNames() []string { + var ret []string + des, err := os.ReadDir(mtsRoot()) + if err != nil { + if os.IsNotExist(err) { + return nil + } + panic(err) + } + for _, de := range des { + if !de.IsDir() { + continue + } + ret = append(ret, de.Name()) + } + return ret +} + +func mtsRoot() string { + dir, err := os.UserConfigDir() + if err != nil { + panic(err) + } + return filepath.Join(dir, "multi-tailscale-dev") +} + +func instDir(name string) string { + return filepath.Join(mtsRoot(), name) +} + +func instSock(name string) string { + return filepath.Join(instDir(name), "tailscaled.sock") +} + +func instEnvFile(name string) string { + return filepath.Join(mtsRoot(), name, "env.txt") +} + +func instArgsFile(name string) string { + return filepath.Join(mtsRoot(), name, "args.txt") +} + +func instLogsFile(name string) string { + return filepath.Join(mtsRoot(), name, "logs.txt") +} + +func mtsSock() string { + return filepath.Join(mtsRoot(), "mts.sock") +} diff --git a/tstest/natlab/natlab.go b/tstest/natlab/natlab.go index 92a4ccb68..ffa02eee4 100644 --- a/tstest/natlab/natlab.go +++ b/tstest/natlab/natlab.go @@ -684,10 +684,11 @@ func (m *Machine) ListenPacket(ctx context.Context, network, address string) (ne ipp := netip.AddrPortFrom(ip, port) c := &conn{ - m: m, - fam: fam, - ipp: ipp, - in: make(chan *Packet, 100), // arbitrary + m: m, + fam: fam, + ipp: ipp, + closedCh: make(chan struct{}), + in: make(chan *Packet, 100), // arbitrary } switch c.fam { case 0: @@ -716,70 +717,28 @@ type conn struct { fam uint8 // 0, 4, or 6 ipp netip.AddrPort - mu sync.Mutex - closed bool - readDeadline time.Time - activeReads map[*activeRead]bool - in chan *Packet -} + closeOnce sync.Once + closedCh chan struct{} // closed by Close -type activeRead struct { - cancel context.CancelFunc -} - -// canRead reports whether we can do a read. -func (c *conn) canRead() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return net.ErrClosed - } - if !c.readDeadline.IsZero() && c.readDeadline.Before(time.Now()) { - return errors.New("read deadline exceeded") - } - return nil -} - -func (c *conn) registerActiveRead(ar *activeRead, active bool) { - c.mu.Lock() - defer c.mu.Unlock() - if c.activeReads == nil { - c.activeReads = make(map[*activeRead]bool) - } - if active { - c.activeReads[ar] = true - } else { - delete(c.activeReads, ar) - } + in chan *Packet } func (c *conn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return nil - } - c.closed = true - switch c.fam { - case 0: - c.m.unregisterConn4(c) - c.m.unregisterConn6(c) - case 4: - c.m.unregisterConn4(c) - case 6: - c.m.unregisterConn6(c) - } - c.breakActiveReadsLocked() + c.closeOnce.Do(func() { + switch c.fam { + case 0: + c.m.unregisterConn4(c) + c.m.unregisterConn6(c) + case 4: + c.m.unregisterConn4(c) + case 6: + c.m.unregisterConn6(c) + } + close(c.closedCh) + }) return nil } -func (c *conn) breakActiveReadsLocked() { - for ar := range c.activeReads { - ar.cancel() - } - c.activeReads = nil -} - func (c *conn) LocalAddr() net.Addr { return &net.UDPAddr{ IP: c.ipp.Addr().AsSlice(), @@ -809,25 +768,13 @@ func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (c *conn) ReadFromUDPAddrPort(p []byte) (n int, addr netip.AddrPort, err error) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - ar := &activeRead{cancel: cancel} - - if err := c.canRead(); err != nil { - return 0, netip.AddrPort{}, err - } - - c.registerActiveRead(ar, true) - defer c.registerActiveRead(ar, false) - select { + case <-c.closedCh: + return 0, netip.AddrPort{}, net.ErrClosed case pkt := <-c.in: n = copy(p, pkt.Payload) pkt.Trace("PacketConn.ReadFrom") return n, pkt.Src, nil - case <-ctx.Done(): - return 0, netip.AddrPort{}, context.DeadlineExceeded } } @@ -857,18 +804,5 @@ func (c *conn) SetWriteDeadline(t time.Time) error { panic("SetWriteDeadline unsupported; TODO when needed") } func (c *conn) SetReadDeadline(t time.Time) error { - c.mu.Lock() - defer c.mu.Unlock() - - now := time.Now() - if t.After(now) { - panic("SetReadDeadline in the future not yet supported; TODO?") - } - - if !t.IsZero() && t.Before(now) { - c.breakActiveReadsLocked() - } - c.readDeadline = t - - return nil + panic("SetReadDeadline unsupported; TODO when needed") } diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index cf71a6674..07b181540 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "slices" + "time" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" @@ -120,6 +121,8 @@ func (c *Config) AddNode(opts ...any) *Node { n.err = fmt.Errorf("unknown NodeOption %q", o) } } + case MAC: + n.mac = o default: if n.err == nil { n.err = fmt.Errorf("unknown AddNode option type %T", o) @@ -279,10 +282,28 @@ type Network struct { svcs set.Set[NetworkService] + latency time.Duration // latency applied to interface writes + lossRate float64 // chance of packet loss (0.0 to 1.0) + // ... err error // carried error } +// SetLatency sets the simulated network latency for this network. +func (n *Network) SetLatency(d time.Duration) { + n.latency = d +} + +// SetPacketLoss sets the packet loss rate for this network 0.0 (no loss) to 1.0 (total loss). +func (n *Network) SetPacketLoss(rate float64) { + if rate < 0 { + rate = 0 + } else if rate > 1 { + rate = 1 + } + n.lossRate = rate +} + // SetBlackholedIPv4 sets whether the network should blackhole all IPv4 traffic // out to the Internet. (DHCP etc continues to work on the LAN.) func (n *Network) SetBlackholedIPv4(v bool) { @@ -361,6 +382,8 @@ func (s *Server) initFromConfig(c *Config) error { wanIP4: conf.wanIP4, lanIP4: conf.lanIP4, breakWAN4: conf.breakWAN4, + latency: conf.latency, + lossRate: conf.lossRate, nodesByIP4: map[netip.Addr]*node{}, nodesByMAC: map[MAC]*node{}, logf: logger.WithPrefix(s.logf, fmt.Sprintf("[net-%v] ", conf.mac)), diff --git a/tstest/natlab/vnet/conf_test.go b/tstest/natlab/vnet/conf_test.go index 15d3c69ef..6566ac8cf 100644 --- a/tstest/natlab/vnet/conf_test.go +++ b/tstest/natlab/vnet/conf_test.go @@ -3,7 +3,10 @@ package vnet -import "testing" +import ( + "testing" + "time" +) func TestConfig(t *testing.T) { tests := []struct { @@ -18,6 +21,16 @@ func TestConfig(t *testing.T) { c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) }, }, + { + name: "latency-and-loss", + setup: func(c *Config) { + n1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24", EasyNAT, NATPMP) + n1.SetLatency(time.Second) + n1.SetPacketLoss(0.1) + c.AddNode(n1) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) + }, + }, { name: "indirect", setup: func(c *Config) { diff --git a/tstest/natlab/vnet/vip.go b/tstest/natlab/vnet/vip.go index c75f17cee..190c9e75f 100644 --- a/tstest/natlab/vnet/vip.go +++ b/tstest/natlab/vnet/vip.go @@ -17,7 +17,7 @@ var ( fakeControl = newVIP("control.tailscale", 3) fakeDERP1 = newVIP("derp1.tailscale", "33.4.0.1") // 3340=DERP; 1=derp 1 fakeDERP2 = newVIP("derp2.tailscale", "33.4.0.2") // 3340=DERP; 2=derp 2 - fakeLogCatcher = newVIP("log.tailscale.io", 4) + fakeLogCatcher = newVIP("log.tailscale.com", 4) fakeSyslog = newVIP("syslog.tailscale", 9) ) diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index 919ae1fa1..49d47f029 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -50,10 +50,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" - "tailscale.com/client/tailscale" - "tailscale.com/derp" - "tailscale.com/derp/derphttp" + "tailscale.com/client/local" + "tailscale.com/derp/derpserver" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/stun" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -88,6 +88,9 @@ func (s *Server) PopulateDERPMapIPs() error { if n.IPv4 != "" { s.derpIPs.Add(netip.MustParseAddr(n.IPv4)) } + if n.IPv6 != "" { + s.derpIPs.Add(netip.MustParseAddr(n.IPv6)) + } } } return nil @@ -394,7 +397,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { } } -// serveLogCatchConn serves a TCP connection to "log.tailscale.io", speaking the +// serveLogCatchConn serves a TCP connection to "log.tailscale.com", speaking the // logtail/logcatcher protocol. // // We terminate TLS with an arbitrary cert; the client is configured to not @@ -515,6 +518,8 @@ type network struct { wanIP4 netip.Addr // router's LAN IPv4, if any lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24) breakWAN4 bool // break WAN IPv4 connectivity + latency time.Duration // latency applied to interface writes + lossRate float64 // probability of dropping a packet (0.0 to 1.0) nodesByIP4 map[netip.Addr]*node // by LAN IPv4 nodesByMAC map[MAC]*node logf func(format string, args ...any) @@ -595,7 +600,7 @@ func (n *node) String() string { } type derpServer struct { - srv *derp.Server + srv *derpserver.Server handler http.Handler tlsConfig *tls.Config } @@ -606,12 +611,12 @@ func newDERPServer() *derpServer { ts.Close() ds := &derpServer{ - srv: derp.NewServer(key.NewNode(), logger.Discard), + srv: derpserver.New(key.NewNode(), logger.Discard), tlsConfig: ts.TLS, // self-signed; test client configure to not check } var mux http.ServeMux - mux.Handle("/derp", derphttp.Handler(ds.srv)) - mux.HandleFunc("/generate_204", derphttp.ServeNoContent) + mux.Handle("/derp", derpserver.Handler(ds.srv)) + mux.HandleFunc("/generate_204", derpserver.ServeNoContent) ds.handler = &mux return ds @@ -644,7 +649,7 @@ type Server struct { mu sync.Mutex agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all - agentDialer map[*node]DialFunc + agentDialer map[*node]netx.DialFunc } func (s *Server) logf(format string, args ...any) { @@ -659,8 +664,6 @@ func (s *Server) SetLoggerForTest(logf func(format string, args ...any)) { s.optLogf = logf } -type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) - var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -974,13 +977,12 @@ func (n *network) writeEth(res []byte) bool { if dstMAC.IsBroadcast() || (n.v6 && etherType == layers.EthernetTypeIPv6 && dstMAC == macAllNodes) { num := 0 - n.writers.Range(func(mac MAC, nw networkWriter) bool { + for mac, nw := range n.writers.All() { if mac != srcMAC { num++ - nw.write(res) + n.conditionedWrite(nw, res) } - return true - }) + } return num > 0 } if srcMAC == dstMAC { @@ -988,7 +990,7 @@ func (n *network) writeEth(res []byte) bool { return false } if nw, ok := n.writers.Load(dstMAC); ok { - nw.write(res) + n.conditionedWrite(nw, res) return true } @@ -1001,6 +1003,23 @@ func (n *network) writeEth(res []byte) bool { return false } +func (n *network) conditionedWrite(nw networkWriter, packet []byte) { + if n.lossRate > 0 && rand.Float64() < n.lossRate { + // packet lost + return + } + if n.latency > 0 { + // copy the packet as there's no guarantee packet is owned long enough. + // TODO(raggi): this could be optimized substantially if necessary, + // a pool of buffers and a cheaper delay mechanism are both obvious improvements. + var pkt = make([]byte, len(packet)) + copy(pkt, packet) + time.AfterFunc(n.latency, func() { nw.write(pkt) }) + } else { + nw.write(packet) + } +} + var ( macAllNodes = MAC{0: 0x33, 1: 0x33, 5: 0x01} macAllRouters = MAC{0: 0x33, 1: 0x33, 5: 0x02} @@ -2105,11 +2124,11 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { } type NodeAgentClient struct { - *tailscale.LocalClient + *local.Client HTTPClient *http.Client } -func (s *Server) NodeAgentDialer(n *Node) DialFunc { +func (s *Server) NodeAgentDialer(n *Node) netx.DialFunc { s.mu.Lock() defer s.mu.Unlock() @@ -2130,7 +2149,7 @@ func (s *Server) NodeAgentDialer(n *Node) DialFunc { func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient { d := s.NodeAgentDialer(n) return &NodeAgentClient{ - LocalClient: &tailscale.LocalClient{ + Client: &local.Client{ UseSocketOnly: true, OmitAuth: true, Dial: d, diff --git a/tstest/nettest/nettest.go b/tstest/nettest/nettest.go index 47c8857a5..c78677dd4 100644 --- a/tstest/nettest/nettest.go +++ b/tstest/nettest/nettest.go @@ -6,11 +6,22 @@ package nettest import ( + "context" + "flag" + "net" + "net/http" + "net/http/httptest" + "sync" "testing" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" + "tailscale.com/net/netx" + "tailscale.com/util/testenv" ) +var useMemNet = flag.Bool("use-test-memnet", false, "prefer using in-memory network for tests") + // SkipIfNoNetwork skips the test if it looks like there's no network // access. func SkipIfNoNetwork(t testing.TB) { @@ -19,3 +30,89 @@ func SkipIfNoNetwork(t testing.TB) { t.Skip("skipping; test requires network but no interface is up") } } + +// PreferMemNetwork reports whether the --use-test-memnet flag is set. +func PreferMemNetwork() bool { + return *useMemNet +} + +// GetNetwork returns the appropriate Network implementation based on +// whether the --use-test-memnet flag is set. +// +// Each call generates a new network. +func GetNetwork(tb testing.TB) netx.Network { + var n netx.Network + if PreferMemNetwork() { + n = &memnet.Network{} + } else { + n = netx.RealNetwork() + } + + detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb) + if detectLeaks { + tb.Cleanup(func() { + // TODO: leak detection, making sure no connections + // remain at the end of the test. For real network, + // snapshot conns in pid table before & after. + }) + } + return n +} + +// NewHTTPServer starts and returns a new [httptest.Server]. +// The caller should call Close when finished, to shut it down. +func NewHTTPServer(net netx.Network, handler http.Handler) *httptest.Server { + ts := NewUnstartedHTTPServer(net, handler) + ts.Start() + return ts +} + +// NewUnstartedHTTPServer returns a new [httptest.Server] but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedHTTPServer(nw netx.Network, handler http.Handler) *httptest.Server { + s := &httptest.Server{ + Config: &http.Server{Handler: handler}, + } + ln := nw.NewLocalTCPListener() + s.Listener = &listenerOnAddrOnce{ + Listener: ln, + fn: func() { + c := s.Client() + if c == nil { + // This httptest.Server.Start initialization order has been true + // for over 10 years. Let's keep counting on it. + panic("httptest.Server: Client not initialized before Addr called") + } + if c.Transport == nil { + c.Transport = &http.Transport{} + } + tr := c.Transport.(*http.Transport) + if tr.Dial != nil || tr.DialContext != nil { + panic("unexpected non-nil Dial or DialContext in httptest.Server.Client.Transport") + } + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return nw.Dial(ctx, network, addr) + } + }, + } + return s +} + +// listenerOnAddrOnce is a net.Listener that wraps another net.Listener +// and calls a function the first time its Addr is called. +type listenerOnAddrOnce struct { + net.Listener + once sync.Once + fn func() +} + +func (ln *listenerOnAddrOnce) Addr() net.Addr { + ln.once.Do(func() { + ln.fn() + }) + return ln.Listener.Addr() +} diff --git a/tstest/resource.go b/tstest/resource.go index a3c292094..f50bb3330 100644 --- a/tstest/resource.go +++ b/tstest/resource.go @@ -7,10 +7,10 @@ import ( "bytes" "runtime" "runtime/pprof" + "slices" + "strings" "testing" "time" - - "github.com/google/go-cmp/cmp" ) // ResourceCheck takes a snapshot of the current goroutines and registers a @@ -29,7 +29,8 @@ func ResourceCheck(tb testing.TB) { startN, startStacks := goroutines() tb.Cleanup(func() { if tb.Failed() { - // Something else went wrong. + // Test has failed - but this doesn't catch panics due to + // https://github.com/golang/go/issues/49929. return } // Goroutines might be still exiting. @@ -43,8 +44,24 @@ func ResourceCheck(tb testing.TB) { if endN <= startN { return } - tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks)) - tb.Fatalf("goroutine count: expected %d, got %d\n", startN, endN) + + // Parse and print goroutines. + start := parseGoroutines(startStacks) + end := parseGoroutines(endStacks) + if testing.Verbose() { + tb.Logf("goroutines start:\n%s", printGoroutines(start)) + tb.Logf("goroutines end:\n%s", printGoroutines(end)) + } + + // Print goroutine diff, omitting tstest.ResourceCheck goroutines. + self := func(g goroutine) bool { return bytes.Contains(g.stack, []byte("\ttailscale.com/tstest.goroutines+")) } + start.goroutines = slices.DeleteFunc(start.goroutines, self) + end.goroutines = slices.DeleteFunc(end.goroutines, self) + tb.Logf("goroutine diff (-start +end):\n%s", diffGoroutines(start, end)) + + // tb.Failed() above won't report on panics, so we shouldn't call Fatal + // here or we risk suppressing reporting of the panic. + tb.Errorf("goroutine count: expected %d, got %d\n", startN, endN) }) } @@ -54,3 +71,208 @@ func goroutines() (int, []byte) { p.WriteTo(b, 1) return p.Count(), b.Bytes() } + +// parseGoroutines takes pprof/goroutines?debug=1 -formatted output sorted by +// count, and splits it into a separate list of goroutines with count and stack +// separated. +// +// Example input: +// +// goroutine profile: total 408 +// 48 @ 0x47bc0e 0x136c6b9 0x136c69e 0x136c7ab 0x1379809 0x13797fa 0x483da1 +// # 0x136c6b8 gvisor.dev/gvisor/pkg/sync.Gopark+0x78 gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sync/runtime_unsafe.go:33 +// # 0x136c69d gvisor.dev/gvisor/pkg/sleep.(*Sleeper).nextWaker+0x5d gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sleep/sleep_unsafe.go:210 +// # 0x136c7aa gvisor.dev/gvisor/pkg/sleep.(*Sleeper).fetch+0x2a gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sleep/sleep_unsafe.go:257 +// # 0x1379808 gvisor.dev/gvisor/pkg/sleep.(*Sleeper).Fetch+0xa8 gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sleep/sleep_unsafe.go:280 +// # 0x13797f9 gvisor.dev/gvisor/pkg/tcpip/transport/tcp.(*processor).start+0x99 gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/tcpip/transport/tcp/dispatcher.go:291 +// +// 48 @ 0x47bc0e 0x413705 0x4132b2 0x10fc905 0x483da1 +// # 0x10fc904 github.com/tailscale/wireguard-go/device.(*Device).RoutineDecryption+0x184 github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/receive.go:245 +// +// 48 @ 0x47bc0e 0x413705 0x4132b2 0x10fcd2a 0x483da1 +// # 0x10fcd29 github.com/tailscale/wireguard-go/device.(*Device).RoutineHandshake+0x169 github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/receive.go:279 +// +// 48 @ 0x47bc0e 0x413705 0x4132b2 0x1100ba7 0x483da1 +// # 0x1100ba6 github.com/tailscale/wireguard-go/device.(*Device).RoutineEncryption+0x186 github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/send.go:451 +// +// 26 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +// # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +// +// 13 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +// # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +// +// 7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +// # 0x10fda4c github.com/tailscale/wireguard-go/device.(*Peer).RoutineSequentialReceiver+0x16c github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/receive.go:443 +func parseGoroutines(g []byte) goroutineDump { + head, tail, ok := bytes.Cut(g, []byte("\n")) + if !ok { + return goroutineDump{head: head} + } + + raw := bytes.Split(tail, []byte("\n\n")) + parsed := make([]goroutine, 0, len(raw)) + for _, s := range raw { + count, rem, ok := bytes.Cut(s, []byte(" @ ")) + if !ok { + continue + } + header, stack, _ := bytes.Cut(rem, []byte("\n")) + sort := slices.Clone(header) + reverseWords(sort) + parsed = append(parsed, goroutine{count, header, stack, sort}) + } + + return goroutineDump{head, parsed} +} + +type goroutineDump struct { + head []byte + goroutines []goroutine +} + +// goroutine is a parsed stack trace in pprof goroutine output, e.g. +// "10 @ 0x100 0x001\n# 0x100 test() test.go\n# 0x001 main() test.go". +type goroutine struct { + count []byte // e.g. "10" + header []byte // e.g. "0x100 0x001" + stack []byte // e.g. "# 0x100 test() test.go\n# 0x001 main() test.go" + + // sort is the same pointers as in header, but in reverse order so that we + // can place related goroutines near each other by sorting on this field. + // E.g. "0x001 0x100". + sort []byte +} + +func (g goroutine) Compare(h goroutine) int { + return bytes.Compare(g.sort, h.sort) +} + +// reverseWords repositions the words in b such that they are reversed. +// Words are separated by spaces. New lines are not considered. +// https://sketch.dev/sk/a4ef +func reverseWords(b []byte) { + if len(b) == 0 { + return + } + + // First, reverse the entire slice. + reverse(b) + + // Then reverse each word individually. + start := 0 + for i := 0; i <= len(b); i++ { + if i == len(b) || b[i] == ' ' { + reverse(b[start:i]) + start = i + 1 + } + } +} + +// reverse reverses bytes in place +func reverse(b []byte) { + for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { + b[i], b[j] = b[j], b[i] + } +} + +// printGoroutines returns a text representation of h, gs equivalent to the +// pprof ?debug=1 input parsed by parseGoroutines, except the goroutines are +// sorted in an order easier for diffing. +func printGoroutines(g goroutineDump) []byte { + var b bytes.Buffer + b.Write(g.head) + + slices.SortFunc(g.goroutines, goroutine.Compare) + for _, g := range g.goroutines { + b.WriteString("\n\n") + b.Write(g.count) + b.WriteString(" @ ") + b.Write(g.header) + b.WriteString("\n") + if len(g.stack) > 0 { + b.Write(g.stack) + } + } + + return b.Bytes() +} + +// diffGoroutines returns a diff between goroutines of gx and gy. +// Goroutines present in gx and absent from gy are prefixed with "-". +// Goroutines absent from gx and present in gy are prefixed with "+". +// Goroutines present in both but with different counts only show a prefix on the count line. +func diffGoroutines(x, y goroutineDump) string { + hx, hy := x.head, y.head + gx, gy := x.goroutines, y.goroutines + var b strings.Builder + if !bytes.Equal(hx, hy) { + b.WriteString("- ") + b.Write(hx) + b.WriteString("\n+ ") + b.Write(hy) + b.WriteString("\n") + } + + slices.SortFunc(gx, goroutine.Compare) + slices.SortFunc(gy, goroutine.Compare) + + writeHeader := func(prefix string, g goroutine) { + b.WriteString(prefix) + b.Write(g.count) + b.WriteString(" @ ") + b.Write(g.header) + b.WriteString("\n") + } + writeStack := func(prefix string, g goroutine) { + s := g.stack + for { + var h []byte + h, s, _ = bytes.Cut(s, []byte("\n")) + if len(h) == 0 && len(s) == 0 { + break + } + b.WriteString(prefix) + b.Write(h) + b.WriteString("\n") + } + } + + i, j := 0, 0 + for { + var d int + switch { + case i < len(gx) && j < len(gy): + d = gx[i].Compare(gy[j]) + case i < len(gx): + d = -1 + case j < len(gy): + d = 1 + default: + return b.String() + } + + switch d { + case -1: + b.WriteString("\n") + writeHeader("- ", gx[i]) + writeStack("- ", gx[i]) + i++ + + case +1: + b.WriteString("\n") + writeHeader("+ ", gy[j]) + writeStack("+ ", gy[j]) + j++ + + case 0: + if !bytes.Equal(gx[i].count, gy[j].count) { + b.WriteString("\n") + writeHeader("- ", gx[i]) + writeHeader("+ ", gy[j]) + writeStack(" ", gy[j]) + } + i++ + j++ + } + } +} diff --git a/tstest/resource_test.go b/tstest/resource_test.go new file mode 100644 index 000000000..7199ac5d1 --- /dev/null +++ b/tstest/resource_test.go @@ -0,0 +1,256 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestPrintGoroutines(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + { + name: "empty", + in: "goroutine profile: total 0\n", + want: "goroutine profile: total 0", + }, + { + name: "single goroutine", + in: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: `goroutine profile: total 1 + +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + }, + { + name: "multiple goroutines sorted", + in: `goroutine profile: total 14 +7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +# 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 + +7 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + want: `goroutine profile: total 14 + +7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +# 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 + +7 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := string(printGoroutines(parseGoroutines([]byte(tt.in)))) + if got != tt.want { + t.Errorf("printGoroutines() = %q, want %q, diff:\n%s", got, tt.want, cmp.Diff(tt.want, got)) + } + }) + } +} + +func TestDiffPprofGoroutines(t *testing.T) { + tests := []struct { + name string + x, y string + want string + }{ + { + name: "no difference", + x: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261`, + y: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: "", + }, + { + name: "different counts", + x: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + y: `goroutine profile: total 2 +2 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: `- goroutine profile: total 1 ++ goroutine profile: total 2 + +- 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 ++ 2 @ 0x47bc0e 0x458e57 0x847587 0x483da1 + # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + }, + { + name: "new goroutine", + x: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + y: `goroutine profile: total 2 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 + +1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + want: `- goroutine profile: total 1 ++ goroutine profile: total 2 + ++ 1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 ++ # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + }, + { + name: "removed goroutine", + x: `goroutine profile: total 2 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 + +1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + y: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: `- goroutine profile: total 2 ++ goroutine profile: total 1 + +- 1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +- # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + }, + { + name: "removed many goroutine", + x: `goroutine profile: total 2 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 + +1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + y: `goroutine profile: total 0`, + want: `- goroutine profile: total 2 ++ goroutine profile: total 0 + +- 1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +- # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 + +- 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +- # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + }, + { + name: "invalid input x", + x: "invalid", + y: "goroutine profile: total 0\n", + want: "- invalid\n+ goroutine profile: total 0\n", + }, + { + name: "invalid input y", + x: "goroutine profile: total 0\n", + y: "invalid", + want: "- goroutine profile: total 0\n+ invalid\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := diffGoroutines( + parseGoroutines([]byte(tt.x)), + parseGoroutines([]byte(tt.y)), + ) + if got != tt.want { + t.Errorf("diffPprofGoroutines() diff:\ngot:\n%s\nwant:\n%s\ndiff (-want +got):\n%s", got, tt.want, cmp.Diff(tt.want, got)) + } + }) + } +} + +func TestParseGoroutines(t *testing.T) { + tests := []struct { + name string + in string + wantHeader string + wantCount int + }{ + { + name: "empty profile", + in: "goroutine profile: total 0\n", + wantHeader: "goroutine profile: total 0", + wantCount: 0, + }, + { + name: "single goroutine", + in: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + wantHeader: "goroutine profile: total 1", + wantCount: 1, + }, + { + name: "multiple goroutines", + in: `goroutine profile: total 14 +7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +# 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 + +7 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + wantHeader: "goroutine profile: total 14", + wantCount: 2, + }, + { + name: "invalid format", + in: "invalid", + wantHeader: "invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := parseGoroutines([]byte(tt.in)) + + if got := string(g.head); got != tt.wantHeader { + t.Errorf("parseGoroutines() header = %q, want %q", got, tt.wantHeader) + } + if got := len(g.goroutines); got != tt.wantCount { + t.Errorf("parseGoroutines() goroutine count = %d, want %d", got, tt.wantCount) + } + + // Verify that the sort field is correctly reversed + for _, g := range g.goroutines { + original := strings.Fields(string(g.header)) + sorted := strings.Fields(string(g.sort)) + if len(original) != len(sorted) { + t.Errorf("sort field has different number of words: got %d, want %d", len(sorted), len(original)) + continue + } + for i := 0; i < len(original); i++ { + if original[i] != sorted[len(sorted)-1-i] { + t.Errorf("sort field word mismatch at position %d: got %q, want %q", i, sorted[len(sorted)-1-i], original[i]) + } + } + } + }) + } +} diff --git a/tstest/tailmac/Swift/Common/Config.swift b/tstest/tailmac/Swift/Common/Config.swift index 01d5069b0..18b68ae9b 100644 --- a/tstest/tailmac/Swift/Common/Config.swift +++ b/tstest/tailmac/Swift/Common/Config.swift @@ -14,6 +14,7 @@ class Config: Codable { var mac = "52:cc:cc:cc:cc:01" var ethermac = "52:cc:cc:cc:ce:01" var port: UInt32 = 51009 + var sharedDir: String? // The virtual machines ID. Also double as the directory name under which // we will store configuration, block device, etc. diff --git a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift index 00f999a15..c0961c883 100644 --- a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift +++ b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift @@ -141,5 +141,18 @@ struct TailMacConfigHelper { func createKeyboardConfiguration() -> VZKeyboardConfiguration { return VZMacKeyboardConfiguration() } + + func createDirectoryShareConfiguration(tag: String) -> VZDirectorySharingDeviceConfiguration? { + guard let dir = config.sharedDir else { return nil } + + let sharedDir = VZSharedDirectory(url: URL(fileURLWithPath: dir), readOnly: false) + let share = VZSingleDirectoryShare(directory: sharedDir) + + // Create the VZVirtioFileSystemDeviceConfiguration and assign it a unique tag. + let sharingConfiguration = VZVirtioFileSystemDeviceConfiguration(tag: tag) + sharingConfiguration.share = share + + return sharingConfiguration + } } diff --git a/tstest/tailmac/Swift/Host/HostCli.swift b/tstest/tailmac/Swift/Host/HostCli.swift index 1318a09fa..c31478cc3 100644 --- a/tstest/tailmac/Swift/Host/HostCli.swift +++ b/tstest/tailmac/Swift/Host/HostCli.swift @@ -19,10 +19,12 @@ var config: Config = Config() extension HostCli { struct Run: ParsableCommand { @Option var id: String + @Option var share: String? mutating func run() { - print("Running vm with identifier \(id)") config = Config(id) + config.sharedDir = share + print("Running vm with identifier \(id) and sharedDir \(share ?? "")") _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) } } diff --git a/tstest/tailmac/Swift/Host/VMController.swift b/tstest/tailmac/Swift/Host/VMController.swift index 8774894c1..fe4a3828b 100644 --- a/tstest/tailmac/Swift/Host/VMController.swift +++ b/tstest/tailmac/Swift/Host/VMController.swift @@ -95,6 +95,13 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()] virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()] + if let dir = config.sharedDir, let shareConfig = helper.createDirectoryShareConfiguration(tag: "vmshare") { + print("Sharing \(dir) as vmshare. Use: mount_virtiofs vmshare in the guest to mount.") + virtualMachineConfiguration.directorySharingDevices = [shareConfig] + } else { + print("No shared directory created. \(config.sharedDir ?? "none") was requested.") + } + try! virtualMachineConfiguration.validate() try! virtualMachineConfiguration.validateSaveRestoreSupport() diff --git a/tstest/tailmac/Swift/TailMac/TailMac.swift b/tstest/tailmac/Swift/TailMac/TailMac.swift index 56f651696..84aa5e498 100644 --- a/tstest/tailmac/Swift/TailMac/TailMac.swift +++ b/tstest/tailmac/Swift/TailMac/TailMac.swift @@ -95,12 +95,16 @@ extension Tailmac { extension Tailmac { struct Run: ParsableCommand { @Option(help: "The vm identifier") var id: String + @Option(help: "Optional share directory") var share: String? @Flag(help: "Tail the TailMac log output instead of returning immediatly") var tail mutating func run() { let process = Process() let stdOutPipe = Pipe() - let appPath = "./Host.app/Contents/MacOS/Host" + + let executablePath = CommandLine.arguments[0] + let executableDirectory = (executablePath as NSString).deletingLastPathComponent + let appPath = executableDirectory + "/Host.app/Contents/MacOS/Host" process.executableURL = URL( fileURLWithPath: appPath, @@ -109,10 +113,15 @@ extension Tailmac { ) if !FileManager.default.fileExists(atPath: appPath) { - fatalError("Could not find Host.app. This must be co-located with the tailmac utility") + fatalError("Could not find Host.app at \(appPath). This must be co-located with the tailmac utility") } - process.arguments = ["run", "--id", id] + var args = ["run", "--id", id] + if let share { + args.append("--share") + args.append(share) + } + process.arguments = args do { process.standardOutput = stdOutPipe @@ -121,26 +130,18 @@ extension Tailmac { fatalError("Unable to launch the vm process") } - // This doesn't print until we exit which is not ideal, but at least we - // get the output if tail != 0 { + // (jonathan)TODO: How do we get the process output in real time? + // The child process only seems to flush to stdout on completion let outHandle = stdOutPipe.fileHandleForReading - - let queue = OperationQueue() - NotificationCenter.default.addObserver( - forName: NSNotification.Name.NSFileHandleDataAvailable, - object: outHandle, queue: queue) - { - notification -> Void in - let data = outHandle.availableData + outHandle.readabilityHandler = { handle in + let data = handle.availableData if data.count > 0 { if let str = String(data: data, encoding: String.Encoding.utf8) { print(str) } } - outHandle.waitForDataInBackgroundAndNotify() } - outHandle.waitForDataInBackgroundAndNotify() process.waitUntilExit() } } diff --git a/tstest/tlstest/tlstest.go b/tstest/tlstest/tlstest.go new file mode 100644 index 000000000..76ec0e7e2 --- /dev/null +++ b/tstest/tlstest/tlstest.go @@ -0,0 +1,187 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tlstest contains code to help test Tailscale's TLS support without +// depending on real WebPKI roots or certificates during tests. +package tlstest + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + _ "embed" + "encoding/pem" + "fmt" + "math/big" + "sync" + "time" +) + +// TestRootCA returns a self-signed ECDSA root CA certificate (as PEM) for +// testing purposes. +// +// Typical use in a test is like: +// +// bakedroots.ResetForTest(t, tlstest.TestRootCA()) +func TestRootCA() []byte { + return bytes.Clone(testRootCAOncer()) +} + +// cache for [privateKey], so it always returns the same key for a given domain. +var ( + mu sync.Mutex + privateKeys = make(map[string][]byte) // domain -> private key PEM +) + +// caDomain is a fake domain name to repreesnt the private key for the root CA. +const caDomain = "_root" + +// privateKey returns a PEM-encoded test ECDSA private key for the given domain. +func privateKey(domain string) (pemBytes []byte) { + mu.Lock() + defer mu.Unlock() + if pemBytes, ok := privateKeys[domain]; ok { + return bytes.Clone(pemBytes) + } + defer func() { privateKeys[domain] = bytes.Clone(pemBytes) }() + + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(fmt.Sprintf("failed to generate ECDSA key for %q: %v", domain, err)) + } + der, err := x509.MarshalECPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("failed to marshal ECDSA key for %q: %v", domain, err)) + } + var buf bytes.Buffer + if err := pem.Encode(&buf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: der}); err != nil { + panic(fmt.Sprintf("failed to encode PEM: %v", err)) + } + return buf.Bytes() +} + +var testRootCAOncer = sync.OnceValue(func() []byte { + key := rootCAKey() + now := time.Now().Add(-time.Hour) + tpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "Tailscale Unit Test ECDSA Root", + Organization: []string{"Tailscale Test Org"}, + }, + NotBefore: now, + NotAfter: now.AddDate(5, 0, 0), + + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, + SubjectKeyId: mustSKID(&key.PublicKey), + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key) + if err != nil { + panic(err) + } + return pemCert(der) +}) + +func pemCert(der []byte) []byte { + var buf bytes.Buffer + if err := pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: der}); err != nil { + panic(fmt.Sprintf("failed to encode PEM: %v", err)) + } + return buf.Bytes() +} + +var rootCAKey = sync.OnceValue(func() *ecdsa.PrivateKey { + return mustParsePEM(privateKey(caDomain), x509.ParseECPrivateKey) +}) + +func mustParsePEM[T any](pemBytes []byte, parse func([]byte) (T, error)) T { + block, rest := pem.Decode(pemBytes) + if block == nil || len(rest) > 0 { + panic("invalid PEM") + } + v, err := parse(block.Bytes) + if err != nil { + panic(fmt.Sprintf("invalid PEM: %v", err)) + } + return v +} + +// Domain is a fake domain name used in TLS tests. +// +// They don't have real DNS records. Tests are expected to fake DNS +// lookups and dials for these domains. +type Domain string + +// ProxyServer is a domain name for a hypothetical proxy server. +const ( + ProxyServer = Domain("proxy.tstest") + + // ControlPlane is a domain name for a test control plane server. + ControlPlane = Domain("controlplane.tstest") + + // Derper is a domain name for a test DERP server. + Derper = Domain("derp.tstest") +) + +// ServerTLSConfig returns a TLS configuration suitable for a server +// using the KeyPair's certificate and private key. +func (d Domain) ServerTLSConfig() *tls.Config { + cert, err := tls.X509KeyPair(d.CertPEM(), privateKey(string(d))) + if err != nil { + panic("invalid TLS key pair: " + err.Error()) + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + } +} + +// KeyPEM returns a PEM-encoded private key for the domain. +func (d Domain) KeyPEM() []byte { + return privateKey(string(d)) +} + +// CertPEM returns a PEM-encoded certificate for the domain. +func (d Domain) CertPEM() []byte { + caCert := mustParsePEM(TestRootCA(), x509.ParseCertificate) + caPriv := mustParsePEM(privateKey(caDomain), x509.ParseECPrivateKey) + leafKey := mustParsePEM(d.KeyPEM(), x509.ParseECPrivateKey) + + serial, err := rand.Int(rand.Reader, big.NewInt(0).Lsh(big.NewInt(1), 128)) + if err != nil { + panic(err) + } + + now := time.Now().Add(-time.Hour) + tpl := &x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{CommonName: string(d)}, + NotBefore: now, + NotAfter: now.AddDate(2, 0, 0), + + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{string(d)}, + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, caCert, &leafKey.PublicKey, caPriv) + if err != nil { + panic(err) + } + return pemCert(der) +} + +func mustSKID(pub *ecdsa.PublicKey) []byte { + skid, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + panic(err) + } + return skid[:20] // same as x509 library +} diff --git a/tstest/tlstest/tlstest_test.go b/tstest/tlstest/tlstest_test.go new file mode 100644 index 000000000..8497b872e --- /dev/null +++ b/tstest/tlstest/tlstest_test.go @@ -0,0 +1,21 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tlstest + +import ( + "testing" +) + +func TestPrivateKey(t *testing.T) { + a := privateKey("a.tstest") + a2 := privateKey("a.tstest") + b := privateKey("b.tstest") + + if string(a) != string(a2) { + t.Errorf("a and a2 should be equal") + } + if string(a) == string(b) { + t.Errorf("a and b should not be equal") + } +} diff --git a/tstest/tstest.go b/tstest/tstest.go index 2d0d1351e..169450686 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -14,8 +14,8 @@ import ( "time" "tailscale.com/envknob" - "tailscale.com/logtail/backoff" "tailscale.com/types/logger" + "tailscale.com/util/backoff" "tailscale.com/util/cibuild" ) diff --git a/tstest/typewalk/typewalk.go b/tstest/typewalk/typewalk.go new file mode 100644 index 000000000..b22505351 --- /dev/null +++ b/tstest/typewalk/typewalk.go @@ -0,0 +1,106 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package typewalk provides utilities to walk Go types using reflection. +package typewalk + +import ( + "iter" + "reflect" + "strings" +) + +// Path describes a path via a type where a private key may be found, +// along with a function to test whether a reflect.Value at that path is +// non-zero. +type Path struct { + // Name is the path from the root type, suitable for using as a t.Run name. + Name string + + // Walk returns the reflect.Value at the end of the path, given a root + // reflect.Value. + Walk func(root reflect.Value) (leaf reflect.Value) +} + +// MatchingPaths returns a sequence of [Path] for all paths +// within the given type that end in a type matching match. +func MatchingPaths(rt reflect.Type, match func(reflect.Type) bool) iter.Seq[Path] { + // valFromRoot is a function that, given a reflect.Value of the root struct, + // returns the reflect.Value at some path within it. + type valFromRoot func(reflect.Value) reflect.Value + + return func(yield func(Path) bool) { + var walk func(reflect.Type, valFromRoot) + var path []string + var done bool + seen := map[reflect.Type]bool{} + + walk = func(t reflect.Type, getV valFromRoot) { + if seen[t] { + return + } + seen[t] = true + defer func() { seen[t] = false }() + if done { + return + } + if match(t) { + if !yield(Path{ + Name: strings.Join(path, "."), + Walk: getV, + }) { + done = true + } + return + } + switch t.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Array: + walk(t.Elem(), func(root reflect.Value) reflect.Value { + v := getV(root) + return v.Elem() + }) + case reflect.Struct: + for i := range t.NumField() { + sf := t.Field(i) + fieldName := sf.Name + if fieldName == "_" { + continue + } + path = append(path, fieldName) + walk(sf.Type, func(root reflect.Value) reflect.Value { + return getV(root).FieldByName(fieldName) + }) + path = path[:len(path)-1] + if done { + return + } + } + case reflect.Map: + walk(t.Elem(), func(root reflect.Value) reflect.Value { + v := getV(root) + if v.Len() == 0 { + return reflect.Zero(t.Elem()) + } + iter := v.MapRange() + iter.Next() + return iter.Value() + }) + if done { + return + } + walk(t.Key(), func(root reflect.Value) reflect.Value { + v := getV(root) + if v.Len() == 0 { + return reflect.Zero(t.Key()) + } + iter := v.MapRange() + iter.Next() + return iter.Key() + }) + } + } + + path = append(path, rt.Name()) + walk(rt, func(v reflect.Value) reflect.Value { return v }) + } +} diff --git a/tstime/tstime.go b/tstime/tstime.go index 1c006355f..6e5b7f9f4 100644 --- a/tstime/tstime.go +++ b/tstime/tstime.go @@ -6,6 +6,7 @@ package tstime import ( "context" + "encoding" "strconv" "strings" "time" @@ -183,3 +184,40 @@ func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { func (StdClock) Since(t time.Time) time.Duration { return time.Since(t) } + +// GoDuration is a [time.Duration] but JSON serializes with [time.Duration.String]. +// +// Note that this format is specific to Go and non-standard, +// but excels in being most humanly readable compared to alternatives. +// The wider industry still lacks consensus for the representation +// of a time duration in humanly-readable text. +// See https://go.dev/issue/71631 for more discussion. +// +// Regardless of how the industry evolves into the future, +// this type explicitly uses the Go format. +type GoDuration struct{ time.Duration } + +var ( + _ encoding.TextAppender = (*GoDuration)(nil) + _ encoding.TextMarshaler = (*GoDuration)(nil) + _ encoding.TextUnmarshaler = (*GoDuration)(nil) +) + +func (d GoDuration) AppendText(b []byte) ([]byte, error) { + // The String method is inlineable (see https://go.dev/cl/520602), + // so this may not allocate since the string does not escape. + return append(b, d.String()...), nil +} + +func (d GoDuration) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +func (d *GoDuration) UnmarshalText(b []byte) error { + d2, err := time.ParseDuration(string(b)) + if err != nil { + return err + } + d.Duration = d2 + return nil +} diff --git a/tstime/tstime_test.go b/tstime/tstime_test.go index 3ffeaf0ff..556ad4e8b 100644 --- a/tstime/tstime_test.go +++ b/tstime/tstime_test.go @@ -4,8 +4,11 @@ package tstime import ( + "encoding/json" "testing" "time" + + "tailscale.com/util/must" ) func TestParseDuration(t *testing.T) { @@ -34,3 +37,17 @@ func TestParseDuration(t *testing.T) { } } } + +func TestGoDuration(t *testing.T) { + wantDur := GoDuration{time.Hour + time.Minute + time.Second + time.Millisecond + time.Microsecond + time.Nanosecond} + gotJSON := string(must.Get(json.Marshal(wantDur))) + wantJSON := `"1h1m1.001001001s"` + if gotJSON != wantJSON { + t.Errorf("json.Marshal(%v) = %s, want %s", wantDur, gotJSON, wantJSON) + } + var gotDur GoDuration + must.Do(json.Unmarshal([]byte(wantJSON), &gotDur)) + if gotDur != wantDur { + t.Errorf("json.Unmarshal(%s) = %v, want %v", wantJSON, gotDur, wantDur) + } +} diff --git a/tsweb/debug.go b/tsweb/debug.go index 6db3f25cf..4c0fabaff 100644 --- a/tsweb/debug.go +++ b/tsweb/debug.go @@ -9,12 +9,11 @@ import ( "html" "io" "net/http" - "net/http/pprof" "net/url" "os" "runtime" - "tailscale.com/tsweb/promvarz" + "tailscale.com/feature" "tailscale.com/tsweb/varz" "tailscale.com/version" ) @@ -34,8 +33,14 @@ type DebugHandler struct { kvs []func(io.Writer) // output one
  • ...
  • each, see KV() urls []string // one
  • ...
  • block with link each sections []func(io.Writer, *http.Request) // invoked in registration order prior to outputting + title string // title displayed on index page } +// PrometheusHandler is an optional hook to enable native Prometheus +// support in the debug handler. It is disabled by default. Import the +// tailscale.com/tsweb/promvarz package to enable this feature. +var PrometheusHandler feature.Hook[func(*DebugHandler)] + // Debugger returns the DebugHandler registered on mux at /debug/, // creating it if necessary. func Debugger(mux *http.ServeMux) *DebugHandler { @@ -44,25 +49,21 @@ func Debugger(mux *http.ServeMux) *DebugHandler { return d } ret := &DebugHandler{ - mux: mux, + mux: mux, + title: fmt.Sprintf("%s debug", version.CmdName()), } mux.Handle("/debug/", ret) ret.KVFunc("Uptime", func() any { return varz.Uptime() }) ret.KV("Version", version.Long()) ret.Handle("vars", "Metrics (Go)", expvar.Handler()) - ret.Handle("varz", "Metrics (Prometheus)", http.HandlerFunc(promvarz.Handler)) - ret.Handle("pprof/", "pprof (index)", http.HandlerFunc(pprof.Index)) - // the CPU profile handler is special because it responds - // streamily, unlike every other pprof handler. This means it's - // not made available through pprof.Index the way all the other - // pprof types are, you have to register the CPU profile handler - // separately. Use HandleSilent for that to not pollute the human - // debug list with a link that produces streaming line noise if - // you click it. - ret.HandleSilent("pprof/profile", http.HandlerFunc(pprof.Profile)) - ret.URL("/debug/pprof/goroutine?debug=1", "Goroutines (collapsed)") - ret.URL("/debug/pprof/goroutine?debug=2", "Goroutines (full)") + if PrometheusHandler.IsSet() { + PrometheusHandler.Get()(ret) + } else { + ret.Handle("varz", "Metrics (Prometheus)", http.HandlerFunc(varz.Handler)) + } + + addProfilingHandlers(ret) ret.Handle("gc", "force GC", http.HandlerFunc(gcHandler)) hostname, err := os.Hostname() if err == nil { @@ -85,7 +86,7 @@ func (d *DebugHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { AddBrowserHeaders(w) f := func(format string, args ...any) { fmt.Fprintf(w, format, args...) } - f("

    %s debug

      ", version.CmdName()) + f("

      %s

        ", html.EscapeString(d.title)) for _, kv := range d.kvs { kv(w) } @@ -103,14 +104,20 @@ func (d *DebugHandler) handle(slug string, handler http.Handler) string { return href } -// Handle registers handler at /debug/ and creates a descriptive -// entry in /debug/ for it. +// Handle registers handler at /debug/ and adds a link to it +// on /debug/ with the provided description. func (d *DebugHandler) Handle(slug, desc string, handler http.Handler) { href := d.handle(slug, handler) d.URL(href, desc) } -// HandleSilent registers handler at /debug/. It does not create +// Handle registers handler at /debug/ and adds a link to it +// on /debug/ with the provided description. +func (d *DebugHandler) HandleFunc(slug, desc string, handler http.HandlerFunc) { + d.Handle(slug, desc, handler) +} + +// HandleSilent registers handler at /debug/. It does not add // a descriptive entry in /debug/ for it. This should be used // sparingly, for things that need to be registered but would pollute // the list of debug links. @@ -118,6 +125,14 @@ func (d *DebugHandler) HandleSilent(slug string, handler http.Handler) { d.handle(slug, handler) } +// HandleSilent registers handler at /debug/. It does not add +// a descriptive entry in /debug/ for it. This should be used +// sparingly, for things that need to be registered but would pollute +// the list of debug links. +func (d *DebugHandler) HandleSilentFunc(slug string, handler http.HandlerFunc) { + d.HandleSilent(slug, handler) +} + // KV adds a key/value list item to /debug/. func (d *DebugHandler) KV(k string, v any) { val := html.EscapeString(fmt.Sprintf("%v", v)) @@ -149,6 +164,11 @@ func (d *DebugHandler) Section(f func(w io.Writer, r *http.Request)) { d.sections = append(d.sections, f) } +// Title sets the title at the top of the debug page. +func (d *DebugHandler) Title(title string) { + d.title = title +} + func gcHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte("running GC...\n")) if f, ok := w.(http.Flusher); ok { diff --git a/tsweb/pprof_default.go b/tsweb/pprof_default.go new file mode 100644 index 000000000..7d22a6161 --- /dev/null +++ b/tsweb/pprof_default.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm + +package tsweb + +import ( + "net/http" + "net/http/pprof" +) + +func addProfilingHandlers(d *DebugHandler) { + // pprof.Index serves everything that runtime/pprof.Lookup finds: + // goroutine, threadcreate, heap, allocs, block, mutex + d.Handle("pprof/", "pprof (index)", http.HandlerFunc(pprof.Index)) + // But register the other ones from net/http/pprof directly: + d.HandleSilent("pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + d.HandleSilent("pprof/profile", http.HandlerFunc(pprof.Profile)) + d.HandleSilent("pprof/symbol", http.HandlerFunc(pprof.Symbol)) + d.HandleSilent("pprof/trace", http.HandlerFunc(pprof.Trace)) + d.URL("/debug/pprof/goroutine?debug=1", "Goroutines (collapsed)") + d.URL("/debug/pprof/goroutine?debug=2", "Goroutines (full)") +} diff --git a/tsweb/pprof_js.go b/tsweb/pprof_js.go new file mode 100644 index 000000000..1212b37e8 --- /dev/null +++ b/tsweb/pprof_js.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build js && wasm + +package tsweb + +func addProfilingHandlers(d *DebugHandler) { + // No pprof in js builds, pprof doesn't work and bloats the build. +} diff --git a/tsweb/promvarz/promvarz.go b/tsweb/promvarz/promvarz.go index d0e1e52ba..1d978c767 100644 --- a/tsweb/promvarz/promvarz.go +++ b/tsweb/promvarz/promvarz.go @@ -11,12 +11,21 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/common/expfmt" + "tailscale.com/tsweb" "tailscale.com/tsweb/varz" ) -// Handler returns Prometheus metrics exported by our expvar converter +func init() { + tsweb.PrometheusHandler.Set(registerVarz) +} + +func registerVarz(debug *tsweb.DebugHandler) { + debug.Handle("varz", "Metrics (Prometheus)", http.HandlerFunc(handler)) +} + +// handler returns Prometheus metrics exported by our expvar converter // and the official Prometheus client. -func Handler(w http.ResponseWriter, r *http.Request) { +func handler(w http.ResponseWriter, r *http.Request) { if err := gatherNativePrometheusMetrics(w); err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error())) diff --git a/tsweb/promvarz/promvarz_test.go b/tsweb/promvarz/promvarz_test.go index a3f4e66f1..cffbbec22 100644 --- a/tsweb/promvarz/promvarz_test.go +++ b/tsweb/promvarz/promvarz_test.go @@ -1,5 +1,6 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause + package promvarz import ( @@ -23,7 +24,7 @@ func TestHandler(t *testing.T) { testVar1.Set(42) testVar2.Set(4242) - svr := httptest.NewServer(http.HandlerFunc(Handler)) + svr := httptest.NewServer(http.HandlerFunc(handler)) defer svr.Close() want := ` diff --git a/tsweb/request_id.go b/tsweb/request_id.go index 8516b8f72..46e523852 100644 --- a/tsweb/request_id.go +++ b/tsweb/request_id.go @@ -6,9 +6,10 @@ package tsweb import ( "context" "net/http" + "time" "tailscale.com/util/ctxkey" - "tailscale.com/util/fastuuid" + "tailscale.com/util/rands" ) // RequestID is an opaque identifier for a HTTP request, used to correlate @@ -41,10 +42,12 @@ const RequestIDHeader = "X-Tailscale-Request-Id" // GenerateRequestID generates a new request ID with the current format. func GenerateRequestID() RequestID { - // REQ-1 indicates the version of the RequestID pattern. It is - // currently arbitrary but allows for forward compatible - // transitions if needed. - return RequestID("REQ-1" + fastuuid.NewUUID().String()) + // Return a string of the form "REQ-<...>" + // Previously we returned "REQ-1". + // Now we return "REQ-2" version, where the "2" doubles as the year 2YYY + // in a leading date. + now := time.Now().UTC() + return RequestID("REQ-" + now.Format("20060102150405") + rands.HexString(16)) } // SetRequestID is an HTTP middleware that injects a RequestID in the diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 9ddb3fad5..869b4cc8e 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -15,7 +15,6 @@ import ( "io" "net" "net/http" - _ "net/http/pprof" "net/netip" "net/url" "os" @@ -629,8 +628,8 @@ type loggingResponseWriter struct { // from r, or falls back to logf. If a nil logger is given, the logs are // discarded. func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Request) *loggingResponseWriter { - if l, ok := logger.LogfKey.ValueOk(r.Context()); ok && l != nil { - logf = l + if lg, ok := logger.LogfKey.ValueOk(r.Context()); ok && lg != nil { + logf = lg } if logf == nil { logf = logger.Discard @@ -643,46 +642,46 @@ func newLogResponseWriter(logf logger.Logf, w http.ResponseWriter, r *http.Reque } // WriteHeader implements [http.ResponseWriter]. -func (l *loggingResponseWriter) WriteHeader(statusCode int) { - if l.code != 0 { - l.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", l.code, statusCode) +func (lg *loggingResponseWriter) WriteHeader(statusCode int) { + if lg.code != 0 { + lg.logf("[unexpected] HTTP handler set statusCode twice (%d and %d)", lg.code, statusCode) return } - if l.ctx.Err() == nil { - l.code = statusCode + if lg.ctx.Err() == nil { + lg.code = statusCode } - l.ResponseWriter.WriteHeader(statusCode) + lg.ResponseWriter.WriteHeader(statusCode) } // Write implements [http.ResponseWriter]. -func (l *loggingResponseWriter) Write(bs []byte) (int, error) { - if l.code == 0 { - l.code = 200 +func (lg *loggingResponseWriter) Write(bs []byte) (int, error) { + if lg.code == 0 { + lg.code = 200 } - n, err := l.ResponseWriter.Write(bs) - l.bytes += n + n, err := lg.ResponseWriter.Write(bs) + lg.bytes += n return n, err } // Hijack implements http.Hijacker. Note that hijacking can still fail // because the wrapped ResponseWriter is not required to implement // Hijacker, as this breaks HTTP/2. -func (l *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { - h, ok := l.ResponseWriter.(http.Hijacker) +func (lg *loggingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, ok := lg.ResponseWriter.(http.Hijacker) if !ok { return nil, nil, errors.New("ResponseWriter is not a Hijacker") } conn, buf, err := h.Hijack() if err == nil { - l.hijacked = true + lg.hijacked = true } return conn, buf, err } -func (l loggingResponseWriter) Flush() { - f, _ := l.ResponseWriter.(http.Flusher) +func (lg loggingResponseWriter) Flush() { + f, _ := lg.ResponseWriter.(http.Flusher) if f == nil { - l.logf("[unexpected] tried to Flush a ResponseWriter that can't flush") + lg.logf("[unexpected] tried to Flush a ResponseWriter that can't flush") return } f.Flush() diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 13840c012..d4c9721e9 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -1307,6 +1307,28 @@ func TestBucket(t *testing.T) { } } +func TestGenerateRequestID(t *testing.T) { + t0 := time.Now() + got := GenerateRequestID() + t.Logf("Got: %q", got) + if !strings.HasPrefix(string(got), "REQ-2") { + t.Errorf("expect REQ-2 prefix; got %q", got) + } + const wantLen = len("REQ-2024112022140896f8ead3d3f3be27") + if len(got) != wantLen { + t.Fatalf("len = %d; want %d", len(got), wantLen) + } + d := got[len("REQ-"):][:14] + timeBack, err := time.Parse("20060102150405", string(d)) + if err != nil { + t.Fatalf("parsing time back: %v", err) + } + elapsed := timeBack.Sub(t0) + if elapsed > 3*time.Second { // allow for slow github actions runners :) + t.Fatalf("time back was %v; want within 3s", elapsed) + } +} + func ExampleMiddlewareStack() { // setHeader returns a middleware that sets header k = vs. setHeader := func(k string, vs ...string) Middleware { diff --git a/tsweb/varz/varz.go b/tsweb/varz/varz.go index 561b24877..b1c66b859 100644 --- a/tsweb/varz/varz.go +++ b/tsweb/varz/varz.go @@ -5,28 +5,41 @@ package varz import ( + "bufio" "cmp" "expvar" "fmt" "io" "net/http" + "os" + "path/filepath" "reflect" "runtime" "sort" + "strconv" "strings" "sync" "time" "unicode" "unicode/utf8" + "golang.org/x/exp/constraints" "tailscale.com/metrics" + "tailscale.com/syncs" + "tailscale.com/types/logger" "tailscale.com/version" ) +// StaticStringVar returns a new expvar.Var that always returns s. +func StaticStringVar(s string) expvar.Var { + var v any = s // box s into an interface just once + return expvar.Func(func() any { return v }) +} + func init() { expvar.Publish("process_start_unix_time", expvar.Func(func() any { return timeStart.Unix() })) - expvar.Publish("version", expvar.Func(func() any { return version.Long() })) - expvar.Publish("go_version", expvar.Func(func() any { return runtime.Version() })) + expvar.Publish("version", StaticStringVar(version.Long())) + expvar.Publish("go_version", StaticStringVar(runtime.Version())) expvar.Publish("counter_uptime_sec", expvar.Func(func() any { return int64(Uptime().Seconds()) })) expvar.Publish("gauge_goroutines", expvar.Func(func() any { return runtime.NumGoroutine() })) } @@ -124,6 +137,9 @@ func writePromExpVar(w io.Writer, prefix string, kv expvar.KeyValue) { case *expvar.Int: fmt.Fprintf(w, "# TYPE %s %s\n%s %v\n", name, cmp.Or(typ, "counter"), name, v.Value()) return + case *syncs.ShardedInt: + fmt.Fprintf(w, "# TYPE %s %s\n%s %v\n", name, cmp.Or(typ, "counter"), name, v.Value()) + return case *expvar.Float: fmt.Fprintf(w, "# TYPE %s %s\n%s %v\n", name, cmp.Or(typ, "gauge"), name, v.Value()) return @@ -179,7 +195,11 @@ func writePromExpVar(w io.Writer, prefix string, kv expvar.KeyValue) { return } if vs, ok := v.(string); ok && strings.HasSuffix(name, "version") { - fmt.Fprintf(w, "%s{version=%q} 1\n", name, vs) + if name == "version" { + fmt.Fprintf(w, "%s{version=%q,binary=%q} 1\n", name, vs, binaryName()) + } else { + fmt.Fprintf(w, "%s{version=%q} 1\n", name, vs) + } return } switch v := v.(type) { @@ -298,6 +318,18 @@ func ExpvarDoHandler(expvarDoFunc func(f func(expvar.KeyValue))) func(http.Respo } } +var binaryName = sync.OnceValue(func() string { + exe, err := os.Executable() + if err != nil { + return "" + } + exe2, err := filepath.EvalSymlinks(exe) + if err != nil { + return filepath.Base(exe) + } + return filepath.Base(exe2) +}) + // PrometheusMetricsReflectRooter is an optional interface that expvar.Var implementations // can implement to indicate that they should be walked recursively with reflect to find // sets of fields to export. @@ -310,21 +342,52 @@ type PrometheusMetricsReflectRooter interface { var expvarDo = expvar.Do // pulled out for tests -func writeMemstats(w io.Writer, ms *runtime.MemStats) { - out := func(name, typ string, v uint64, help string) { - if help != "" { - fmt.Fprintf(w, "# HELP memstats_%s %s\n", name, help) - } - fmt.Fprintf(w, "# TYPE memstats_%s %s\nmemstats_%s %v\n", name, typ, name, v) +func writeMemstat[V constraints.Integer | constraints.Float](bw *bufio.Writer, typ, name string, v V, help string) { + if help != "" { + bw.WriteString("# HELP memstats_") + bw.WriteString(name) + bw.WriteString(" ") + bw.WriteString(help) + bw.WriteByte('\n') + } + bw.WriteString("# TYPE memstats_") + bw.WriteString(name) + bw.WriteString(" ") + bw.WriteString(typ) + bw.WriteByte('\n') + bw.WriteString("memstats_") + bw.WriteString(name) + bw.WriteByte(' ') + rt := reflect.TypeOf(v) + switch { + case rt == reflect.TypeFor[int]() || + rt == reflect.TypeFor[uint]() || + rt == reflect.TypeFor[int8]() || + rt == reflect.TypeFor[uint8]() || + rt == reflect.TypeFor[int16]() || + rt == reflect.TypeFor[uint16]() || + rt == reflect.TypeFor[int32]() || + rt == reflect.TypeFor[uint32]() || + rt == reflect.TypeFor[int64]() || + rt == reflect.TypeFor[uint64]() || + rt == reflect.TypeFor[uintptr](): + bw.Write(strconv.AppendInt(bw.AvailableBuffer(), int64(v), 10)) + case rt == reflect.TypeFor[float32]() || rt == reflect.TypeFor[float64](): + bw.Write(strconv.AppendFloat(bw.AvailableBuffer(), float64(v), 'f', -1, 64)) } - g := func(name string, v uint64, help string) { out(name, "gauge", v, help) } - c := func(name string, v uint64, help string) { out(name, "counter", v, help) } - g("heap_alloc", ms.HeapAlloc, "current bytes of allocated heap objects (up/down smoothly)") - c("total_alloc", ms.TotalAlloc, "cumulative bytes allocated for heap objects") - g("sys", ms.Sys, "total bytes of memory obtained from the OS") - c("mallocs", ms.Mallocs, "cumulative count of heap objects allocated") - c("frees", ms.Frees, "cumulative count of heap objects freed") - c("num_gc", uint64(ms.NumGC), "number of completed GC cycles") + bw.WriteByte('\n') +} + +func writeMemstats(w io.Writer, ms *runtime.MemStats) { + fmt.Fprintf(w, "%v", logger.ArgWriter(func(bw *bufio.Writer) { + writeMemstat(bw, "gauge", "heap_alloc", ms.HeapAlloc, "current bytes of allocated heap objects (up/down smoothly)") + writeMemstat(bw, "counter", "total_alloc", ms.TotalAlloc, "cumulative bytes allocated for heap objects") + writeMemstat(bw, "gauge", "sys", ms.Sys, "total bytes of memory obtained from the OS") + writeMemstat(bw, "counter", "mallocs", ms.Mallocs, "cumulative count of heap objects allocated") + writeMemstat(bw, "counter", "frees", ms.Frees, "cumulative count of heap objects freed") + writeMemstat(bw, "counter", "num_gc", ms.NumGC, "number of completed GC cycles") + writeMemstat(bw, "gauge", "gc_cpu_fraction", ms.GCCPUFraction, "fraction of CPU time used by GC") + })) } // sortedStructField is metadata about a struct field used both for sorting once diff --git a/tsweb/varz/varz_test.go b/tsweb/varz/varz_test.go index 7e094b0e7..5bbacbe35 100644 --- a/tsweb/varz/varz_test.go +++ b/tsweb/varz/varz_test.go @@ -4,14 +4,18 @@ package varz import ( + "bytes" "expvar" "net/http/httptest" "reflect" + "runtime" "strings" "testing" "tailscale.com/metrics" + "tailscale.com/syncs" "tailscale.com/tstest" + "tailscale.com/util/racebuild" "tailscale.com/version" ) @@ -280,6 +284,20 @@ foo_foo_a 1 foo_foo_b 1 `) + "\n", }, + { + "metrics_sharded_int", + "counter_api_status_code", + func() *syncs.ShardedInt { + m := syncs.NewShardedInt() + m.Add(40) + m.Add(2) + return m + }(), + strings.TrimSpace(` +# TYPE api_status_code counter +api_status_code 42 + `) + "\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -418,3 +436,75 @@ func TestVarzHandlerSorting(t *testing.T) { } } } + +func TestWriteMemestats(t *testing.T) { + memstats := &runtime.MemStats{ + Alloc: 1, + TotalAlloc: 2, + Sys: 3, + Lookups: 4, + Mallocs: 5, + Frees: 6, + HeapAlloc: 7, + HeapSys: 8, + HeapIdle: 9, + HeapInuse: 10, + HeapReleased: 11, + HeapObjects: 12, + StackInuse: 13, + StackSys: 14, + MSpanInuse: 15, + MSpanSys: 16, + MCacheInuse: 17, + MCacheSys: 18, + BuckHashSys: 19, + GCSys: 20, + OtherSys: 21, + NextGC: 22, + LastGC: 23, + PauseTotalNs: 24, + // PauseNs: [256]int64{}, + NumGC: 26, + NumForcedGC: 27, + GCCPUFraction: 0.28, + } + + var buf bytes.Buffer + writeMemstats(&buf, memstats) + lines := strings.Split(buf.String(), "\n") + + checkFor := func(name, typ, value string) { + var foundType, foundValue bool + for _, line := range lines { + if line == "memstats_"+name+" "+value { + foundValue = true + } + if line == "# TYPE memstats_"+name+" "+typ { + foundType = true + } + if foundValue && foundType { + return + } + } + t.Errorf("memstats_%s foundType=%v foundValue=%v", name, foundType, foundValue) + } + + t.Logf("memstats:\n %s", buf.String()) + + checkFor("heap_alloc", "gauge", "7") + checkFor("total_alloc", "counter", "2") + checkFor("sys", "gauge", "3") + checkFor("mallocs", "counter", "5") + checkFor("frees", "counter", "6") + checkFor("num_gc", "counter", "26") + checkFor("gc_cpu_fraction", "gauge", "0.28") + + if !racebuild.On { + if allocs := testing.AllocsPerRun(1000, func() { + buf.Reset() + writeMemstats(&buf, memstats) + }); allocs != 1 { + t.Errorf("allocs = %v; want max %v", allocs, 1) + } + } +} diff --git a/types/appctype/appconnector.go b/types/appctype/appconnector.go index f4ced65a4..567ab755f 100644 --- a/types/appctype/appconnector.go +++ b/types/appctype/appconnector.go @@ -73,3 +73,23 @@ type AppConnectorAttr struct { // tag of the form tag:. Connectors []string `json:"connectors,omitempty"` } + +// RouteInfo is a data structure used to persist the in memory state of an AppConnector +// so that we can know, even after a restart, which routes came from ACLs and which were +// learned from domains. +type RouteInfo struct { + // Control is the routes from the 'routes' section of an app connector acl. + Control []netip.Prefix `json:",omitempty"` + // Domains are the routes discovered by observing DNS lookups for configured domains. + Domains map[string][]netip.Addr `json:",omitempty"` + // Wildcards are the configured DNS lookup domains to observe. When a DNS query matches Wildcards, + // its result is added to Domains. + Wildcards []string `json:",omitempty"` +} + +// RouteUpdate records a set of routes that should be advertised and a set of +// routes that should be unadvertised in event bus updates. +type RouteUpdate struct { + Advertise []netip.Prefix + Unadvertise []netip.Prefix +} diff --git a/types/bools/bools.go b/types/bools/bools.go new file mode 100644 index 000000000..e64068746 --- /dev/null +++ b/types/bools/bools.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package bools contains the [Int], [Compare], and [IfElse] functions. +package bools + +// Int returns 1 for true and 0 for false. +func Int(v bool) int { + if v { + return 1 + } else { + return 0 + } +} + +// Compare compares two boolean values as if false is ordered before true. +func Compare[T ~bool](x, y T) int { + switch { + case x == false && y == true: + return -1 + case x == true && y == false: + return +1 + default: + return 0 + } +} + +// IfElse is a ternary operator that returns trueVal if condExpr is true +// otherwise it returns falseVal. +// IfElse(c, a, b) is roughly equivalent to (c ? a : b) in languages like C. +func IfElse[T any](condExpr bool, trueVal T, falseVal T) T { + if condExpr { + return trueVal + } else { + return falseVal + } +} diff --git a/types/bools/bools_test.go b/types/bools/bools_test.go new file mode 100644 index 000000000..67faf3bcc --- /dev/null +++ b/types/bools/bools_test.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package bools + +import "testing" + +func TestInt(t *testing.T) { + if got := Int(true); got != 1 { + t.Errorf("Int(true) = %v, want 1", got) + } + if got := Int(false); got != 0 { + t.Errorf("Int(false) = %v, want 0", got) + } +} + +func TestCompare(t *testing.T) { + if got := Compare(false, false); got != 0 { + t.Errorf("Compare(false, false) = %v, want 0", got) + } + if got := Compare(false, true); got != -1 { + t.Errorf("Compare(false, true) = %v, want -1", got) + } + if got := Compare(true, false); got != +1 { + t.Errorf("Compare(true, false) = %v, want +1", got) + } + if got := Compare(true, true); got != 0 { + t.Errorf("Compare(true, true) = %v, want 0", got) + } +} + +func TestIfElse(t *testing.T) { + if got := IfElse(true, 0, 1); got != 0 { + t.Errorf("IfElse(true, 0, 1) = %v, want 0", got) + } + if got := IfElse(false, 0, 1); got != 1 { + t.Errorf("IfElse(false, 0, 1) = %v, want 1", got) + } +} diff --git a/types/dnstype/dnstype.go b/types/dnstype/dnstype.go index b7f5b9d02..a3ba1b0a9 100644 --- a/types/dnstype/dnstype.go +++ b/types/dnstype/dnstype.go @@ -35,6 +35,12 @@ type Resolver struct { // // As of 2022-09-08, BootstrapResolution is not yet used. BootstrapResolution []netip.Addr `json:",omitempty"` + + // UseWithExitNode designates that this resolver should continue to be used when an + // exit node is in use. Normally, DNS resolution is delegated to the exit node but + // there are situations where it is preferable to still use a Split DNS server and/or + // global DNS server instead of the exit node. + UseWithExitNode bool `json:",omitempty"` } // IPPort returns r.Addr as an IP address and port if either @@ -64,5 +70,7 @@ func (r *Resolver) Equal(other *Resolver) bool { return true } - return r.Addr == other.Addr && slices.Equal(r.BootstrapResolution, other.BootstrapResolution) + return r.Addr == other.Addr && + slices.Equal(r.BootstrapResolution, other.BootstrapResolution) && + r.UseWithExitNode == other.UseWithExitNode } diff --git a/types/dnstype/dnstype_clone.go b/types/dnstype/dnstype_clone.go index 86ca0535f..3985704aa 100644 --- a/types/dnstype/dnstype_clone.go +++ b/types/dnstype/dnstype_clone.go @@ -25,6 +25,7 @@ func (src *Resolver) Clone() *Resolver { var _ResolverCloneNeedsRegeneration = Resolver(struct { Addr string BootstrapResolution []netip.Addr + UseWithExitNode bool }{}) // Clone duplicates src into dst and reports whether it succeeded. diff --git a/types/dnstype/dnstype_test.go b/types/dnstype/dnstype_test.go index e3a941a20..ada5f687d 100644 --- a/types/dnstype/dnstype_test.go +++ b/types/dnstype/dnstype_test.go @@ -17,7 +17,7 @@ func TestResolverEqual(t *testing.T) { fieldNames = append(fieldNames, field.Name) } sort.Strings(fieldNames) - if !slices.Equal(fieldNames, []string{"Addr", "BootstrapResolution"}) { + if !slices.Equal(fieldNames, []string{"Addr", "BootstrapResolution", "UseWithExitNode"}) { t.Errorf("Resolver fields changed; update test") } @@ -68,6 +68,18 @@ func TestResolverEqual(t *testing.T) { }, want: false, }, + { + name: "equal UseWithExitNode", + a: &Resolver{Addr: "dns.example.com", UseWithExitNode: true}, + b: &Resolver{Addr: "dns.example.com", UseWithExitNode: true}, + want: true, + }, + { + name: "not equal UseWithExitNode", + a: &Resolver{Addr: "dns.example.com", UseWithExitNode: true}, + b: &Resolver{Addr: "dns.example.com", UseWithExitNode: false}, + want: false, + }, } for _, tt := range tests { diff --git a/types/dnstype/dnstype_view.go b/types/dnstype/dnstype_view.go index c0e2b28ff..a983864d0 100644 --- a/types/dnstype/dnstype_view.go +++ b/types/dnstype/dnstype_view.go @@ -6,16 +6,18 @@ package dnstype import ( - "encoding/json" + jsonv1 "encoding/json" "errors" "net/netip" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/views" ) //go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=Resolver -// View returns a readonly view of Resolver. +// View returns a read-only view of Resolver. func (p *Resolver) View() ResolverView { return ResolverView{Đļ: p} } @@ -31,7 +33,7 @@ type ResolverView struct { Đļ *Resolver } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ResolverView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -43,8 +45,17 @@ func (v ResolverView) AsStruct() *Resolver { return v.Đļ.Clone() } -func (v ResolverView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ResolverView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ResolverView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *ResolverView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -53,21 +64,61 @@ func (v *ResolverView) UnmarshalJSON(b []byte) error { return nil } var x Resolver - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *ResolverView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Resolver + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x return nil } +// Addr is the address of the DNS resolver, one of: +// - A plain IP address for a "classic" UDP+TCP DNS resolver. +// This is the common format as sent by the control plane. +// - An IP:port, for tests. +// - "https://resolver.com/path" for DNS over HTTPS; currently +// as of 2022-09-08 only used for certain well-known resolvers +// (see the publicdns package) for which the IP addresses to dial DoH are +// known ahead of time, so bootstrap DNS resolution is not required. +// - "http://node-address:port/path" for DNS over HTTP over WireGuard. This +// is implemented in the PeerAPI for exit nodes and app connectors. +// - [TODO] "tls://resolver.com" for DNS over TCP+TLS func (v ResolverView) Addr() string { return v.Đļ.Addr } + +// BootstrapResolution is an optional suggested resolution for the +// DoT/DoH resolver, if the resolver URL does not reference an IP +// address directly. +// BootstrapResolution may be empty, in which case clients should +// look up the DoT/DoH server using their local "classic" DNS +// resolver. +// +// As of 2022-09-08, BootstrapResolution is not yet used. func (v ResolverView) BootstrapResolution() views.Slice[netip.Addr] { return views.SliceOf(v.Đļ.BootstrapResolution) } + +// UseWithExitNode designates that this resolver should continue to be used when an +// exit node is in use. Normally, DNS resolution is delegated to the exit node but +// there are situations where it is preferable to still use a Split DNS server and/or +// global DNS server instead of the exit node. +func (v ResolverView) UseWithExitNode() bool { return v.Đļ.UseWithExitNode } func (v ResolverView) Equal(v2 ResolverView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ResolverViewNeedsRegeneration = Resolver(struct { Addr string BootstrapResolution []netip.Addr + UseWithExitNode bool }{}) diff --git a/types/dnstype/messagetypes-string.go b/types/dnstype/messagetypes-string.go deleted file mode 100644 index 34abea1ba..000000000 --- a/types/dnstype/messagetypes-string.go +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package dnstype - -import ( - "errors" - "strings" - - "golang.org/x/net/dns/dnsmessage" -) - -// StringForType returns the string representation of a dnsmessage.Type. -// For example, StringForType(dnsmessage.TypeA) returns "A". -func StringForDNSMessageType(t dnsmessage.Type) string { - switch t { - case dnsmessage.TypeAAAA: - return "AAAA" - case dnsmessage.TypeALL: - return "ALL" - case dnsmessage.TypeA: - return "A" - case dnsmessage.TypeCNAME: - return "CNAME" - case dnsmessage.TypeHINFO: - return "HINFO" - case dnsmessage.TypeMINFO: - return "MINFO" - case dnsmessage.TypeMX: - return "MX" - case dnsmessage.TypeNS: - return "NS" - case dnsmessage.TypeOPT: - return "OPT" - case dnsmessage.TypePTR: - return "PTR" - case dnsmessage.TypeSOA: - return "SOA" - case dnsmessage.TypeSRV: - return "SRV" - case dnsmessage.TypeTXT: - return "TXT" - case dnsmessage.TypeWKS: - return "WKS" - } - return "UNKNOWN" -} - -// DNSMessageTypeForString returns the dnsmessage.Type for the given string. -// For example, DNSMessageTypeForString("A") returns dnsmessage.TypeA. -func DNSMessageTypeForString(s string) (t dnsmessage.Type, err error) { - s = strings.TrimSpace(strings.ToUpper(s)) - switch s { - case "AAAA": - return dnsmessage.TypeAAAA, nil - case "ALL": - return dnsmessage.TypeALL, nil - case "A": - return dnsmessage.TypeA, nil - case "CNAME": - return dnsmessage.TypeCNAME, nil - case "HINFO": - return dnsmessage.TypeHINFO, nil - case "MINFO": - return dnsmessage.TypeMINFO, nil - case "MX": - return dnsmessage.TypeMX, nil - case "NS": - return dnsmessage.TypeNS, nil - case "OPT": - return dnsmessage.TypeOPT, nil - case "PTR": - return dnsmessage.TypePTR, nil - case "SOA": - return dnsmessage.TypeSOA, nil - case "SRV": - return dnsmessage.TypeSRV, nil - case "TXT": - return dnsmessage.TypeTXT, nil - case "WKS": - return dnsmessage.TypeWKS, nil - } - return 0, errors.New("unknown DNS message type: " + s) -} diff --git a/types/geo/doc.go b/types/geo/doc.go new file mode 100644 index 000000000..749c63080 --- /dev/null +++ b/types/geo/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package geo provides functionality to represent and process geographical +// locations on a spherical Earth. +package geo diff --git a/types/geo/point.go b/types/geo/point.go new file mode 100644 index 000000000..d7160ac59 --- /dev/null +++ b/types/geo/point.go @@ -0,0 +1,279 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package geo + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "strconv" +) + +// ErrBadPoint indicates that the point is malformed. +var ErrBadPoint = errors.New("not a valid point") + +// Point represents a pair of latitude and longitude coordinates. +type Point struct { + lat Degrees + // lng180 is the longitude offset by +180° so the zero value is invalid + // and +0+0/ is Point{lat: +0.0, lng180: +180.0}. + lng180 Degrees +} + +// MakePoint returns a Point representing a given latitude and longitude on +// a WGS 84 ellipsoid. The Coordinate Reference System is EPSG:4326. +// Latitude is wrapped to [-90°, +90°] and longitude to (-180°, +180°]. +func MakePoint(latitude, longitude Degrees) Point { + lat, lng := float64(latitude), float64(longitude) + + switch { + case math.IsNaN(lat) || math.IsInf(lat, 0): + // don’t wrap + case lat < -90 || lat > 90: + // Latitude wraps by flipping the longitude + lat = math.Mod(lat, 360.0) + switch { + case lat == 0.0: + lat = 0.0 // -0.0 == 0.0, but -0° is not valid + case lat < -270.0: + lat = +360.0 + lat + case lat < -90.0: + lat = -180.0 - lat + lng += 180.0 + case lat > +270.0: + lat = -360.0 + lat + case lat > +90.0: + lat = +180.0 - lat + lng += 180.0 + } + } + + switch { + case lat == -90.0 || lat == +90.0: + // By convention, the north and south poles have longitude 0°. + lng = 0 + case math.IsNaN(lng) || math.IsInf(lng, 0): + // don’t wrap + case lng <= -180.0 || lng > 180.0: + // Longitude wraps around normally + lng = math.Mod(lng, 360.0) + switch { + case lng == 0.0: + lng = 0.0 // -0.0 == 0.0, but -0° is not valid + case lng <= -180.0: + lng = +360.0 + lng + case lng > +180.0: + lng = -360.0 + lng + } + } + + return Point{ + lat: Degrees(lat), + lng180: Degrees(lng + 180.0), + } +} + +// Valid reports if p is a valid point. +func (p Point) Valid() bool { + return !p.IsZero() +} + +// LatLng reports the latitude and longitude. +func (p Point) LatLng() (lat, lng Degrees, err error) { + if p.IsZero() { + return 0 * Degree, 0 * Degree, ErrBadPoint + } + return p.lat, p.lng180 - 180.0*Degree, nil +} + +// LatLng reports the latitude and longitude in float64. If err is nil, then lat +// and lng will never both be 0.0 to disambiguate between an empty struct and +// Null Island (0° 0°). +func (p Point) LatLngFloat64() (lat, lng float64, err error) { + dlat, dlng, err := p.LatLng() + if err != nil { + return 0.0, 0.0, err + } + if dlat == 0.0 && dlng == 0.0 { + // dlng must survive conversion to float32. + dlng = math.SmallestNonzeroFloat32 + } + return float64(dlat), float64(dlng), err +} + +// SphericalAngleTo returns the angular distance from p to q, calculated on a +// spherical Earth. +func (p Point) SphericalAngleTo(q Point) (Radians, error) { + pLat, pLng, pErr := p.LatLng() + qLat, qLng, qErr := q.LatLng() + switch { + case pErr != nil && qErr != nil: + return 0.0, fmt.Errorf("spherical distance from %v to %v: %w", p, q, errors.Join(pErr, qErr)) + case pErr != nil: + return 0.0, fmt.Errorf("spherical distance from %v: %w", p, pErr) + case qErr != nil: + return 0.0, fmt.Errorf("spherical distance to %v: %w", q, qErr) + } + // The spherical law of cosines is accurate enough for close points when + // using float64. + // + // The haversine formula is an alternative, but it is poorly behaved + // when points are on opposite sides of the sphere. + rLat, rLng := float64(pLat.Radians()), float64(pLng.Radians()) + sLat, sLng := float64(qLat.Radians()), float64(qLng.Radians()) + cosA := math.Sin(rLat)*math.Sin(sLat) + + math.Cos(rLat)*math.Cos(sLat)*math.Cos(rLng-sLng) + return Radians(math.Acos(cosA)), nil +} + +// DistanceTo reports the great-circle distance between p and q, in meters. +func (p Point) DistanceTo(q Point) (Distance, error) { + r, err := p.SphericalAngleTo(q) + if err != nil { + return 0, err + } + return DistanceOnEarth(r.Turns()), nil +} + +// String returns a space-separated pair of latitude and longitude, in decimal +// degrees. Positive latitudes are in the northern hemisphere, and positive +// longitudes are east of the prime meridian. If p was not initialized, this +// will return "nowhere". +func (p Point) String() string { + lat, lng, err := p.LatLng() + if err != nil { + if err == ErrBadPoint { + return "nowhere" + } + panic(err) + } + + return lat.String() + " " + lng.String() +} + +// AppendBinary implements [encoding.BinaryAppender]. The output consists of two +// float32s in big-endian byte order: latitude and longitude offset by 180°. +// If p is not a valid, the output will be an 8-byte zero value. +func (p Point) AppendBinary(b []byte) ([]byte, error) { + end := binary.BigEndian + b = end.AppendUint32(b, math.Float32bits(float32(p.lat))) + b = end.AppendUint32(b, math.Float32bits(float32(p.lng180))) + return b, nil +} + +// MarshalBinary implements [encoding.BinaryMarshaller]. The output matches that +// of calling [Point.AppendBinary]. +func (p Point) MarshalBinary() ([]byte, error) { + var b [8]byte + return p.AppendBinary(b[:0]) +} + +// UnmarshalBinary implements [encoding.BinaryUnmarshaler]. It expects input +// that was formatted by [Point.AppendBinary]: in big-endian byte order, a +// float32 representing latitude followed by a float32 representing longitude +// offset by 180°. If latitude and longitude fall outside valid ranges, then +// an error is returned. +func (p *Point) UnmarshalBinary(data []byte) error { + if len(data) < 8 { // Two uint32s are 8 bytes long + return fmt.Errorf("%w: not enough data: %q", ErrBadPoint, data) + } + + end := binary.BigEndian + lat := Degrees(math.Float32frombits(end.Uint32(data[0:]))) + if lat < -90*Degree || lat > 90*Degree { + return fmt.Errorf("%w: latitude outside [-90°, +90°]: %s", ErrBadPoint, lat) + } + lng180 := Degrees(math.Float32frombits(end.Uint32(data[4:]))) + if lng180 != 0 && (lng180 < 0*Degree || lng180 > 360*Degree) { + // lng180 == 0 is OK: the zero value represents invalid points. + lng := lng180 - 180*Degree + return fmt.Errorf("%w: longitude outside (-180°, +180°]: %s", ErrBadPoint, lng) + } + + p.lat = lat + p.lng180 = lng180 + return nil +} + +// AppendText implements [encoding.TextAppender]. The output is a point +// formatted as OGC Well-Known Text, as "POINT (longitude latitude)" where +// longitude and latitude are in decimal degrees. If p is not valid, the output +// will be "POINT EMPTY". +func (p Point) AppendText(b []byte) ([]byte, error) { + if p.IsZero() { + b = append(b, []byte("POINT EMPTY")...) + return b, nil + } + + lat, lng, err := p.LatLng() + if err != nil { + return b, err + } + + b = append(b, []byte("POINT (")...) + b = strconv.AppendFloat(b, float64(lng), 'f', -1, 64) + b = append(b, ' ') + b = strconv.AppendFloat(b, float64(lat), 'f', -1, 64) + b = append(b, ')') + return b, nil +} + +// MarshalText implements [encoding.TextMarshaller]. The output matches that +// of calling [Point.AppendText]. +func (p Point) MarshalText() ([]byte, error) { + var b [8]byte + return p.AppendText(b[:0]) +} + +// MarshalUint64 produces the same output as MashalBinary, encoded in a uint64. +func (p Point) MarshalUint64() (uint64, error) { + b, err := p.MarshalBinary() + return binary.NativeEndian.Uint64(b), err +} + +// UnmarshalUint64 expects input formatted by MarshalUint64. +func (p *Point) UnmarshalUint64(v uint64) error { + b := binary.NativeEndian.AppendUint64(nil, v) + return p.UnmarshalBinary(b) +} + +// IsZero reports if p is the zero value. +func (p Point) IsZero() bool { + return p == Point{} +} + +// EqualApprox reports if p and q are approximately equal: that is the absolute +// difference of both latitude and longitude are less than tol. If tol is +// negative, then tol defaults to a reasonably small number (10âģâĩ). If tol is +// zero, then p and q must be exactly equal. +func (p Point) EqualApprox(q Point, tol float64) bool { + if tol == 0 { + return p == q + } + + if p.IsZero() && q.IsZero() { + return true + } else if p.IsZero() || q.IsZero() { + return false + } + + plat, plng, err := p.LatLng() + if err != nil { + panic(err) + } + qlat, qlng, err := q.LatLng() + if err != nil { + panic(err) + } + + if tol < 0 { + tol = 1e-5 + } + + dlat := float64(plat) - float64(qlat) + dlng := float64(plng) - float64(qlng) + return ((dlat < 0 && -dlat < tol) || (dlat >= 0 && dlat < tol)) && + ((dlng < 0 && -dlng < tol) || (dlng >= 0 && dlng < tol)) +} diff --git a/types/geo/point_test.go b/types/geo/point_test.go new file mode 100644 index 000000000..308c1a183 --- /dev/null +++ b/types/geo/point_test.go @@ -0,0 +1,541 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package geo_test + +import ( + "fmt" + "math" + "testing" + "testing/quick" + + "tailscale.com/types/geo" +) + +func TestPointZero(t *testing.T) { + var zero geo.Point + + if got := zero.IsZero(); !got { + t.Errorf("IsZero() got %t", got) + } + + if got := zero.Valid(); got { + t.Errorf("Valid() got %t", got) + } + + wantErr := geo.ErrBadPoint.Error() + if _, _, err := zero.LatLng(); err.Error() != wantErr { + t.Errorf("LatLng() err %q, want %q", err, wantErr) + } + + wantStr := "nowhere" + if got := zero.String(); got != wantStr { + t.Errorf("String() got %q, want %q", got, wantStr) + } + + wantB := []byte{0, 0, 0, 0, 0, 0, 0, 0} + if b, err := zero.MarshalBinary(); err != nil { + t.Errorf("MarshalBinary() err %q, want nil", err) + } else if string(b) != string(wantB) { + t.Errorf("MarshalBinary got %q, want %q", b, wantB) + } + + wantI := uint64(0x00000000) + if i, err := zero.MarshalUint64(); err != nil { + t.Errorf("MarshalUint64() err %q, want nil", err) + } else if i != wantI { + t.Errorf("MarshalUint64 got %v, want %v", i, wantI) + } +} + +func TestPoint(t *testing.T) { + for _, tt := range []struct { + name string + lat geo.Degrees + lng geo.Degrees + wantLat geo.Degrees + wantLng geo.Degrees + wantString string + wantText string + }{ + { + name: "null-island", + lat: +0.0, + lng: +0.0, + wantLat: +0.0, + wantLng: +0.0, + wantString: "+0° +0°", + wantText: "POINT (0 0)", + }, + { + name: "north-pole", + lat: +90.0, + lng: +0.0, + wantLat: +90.0, + wantLng: +0.0, + wantString: "+90° +0°", + wantText: "POINT (0 90)", + }, + { + name: "south-pole", + lat: -90.0, + lng: +0.0, + wantLat: -90.0, + wantLng: +0.0, + wantString: "-90° +0°", + wantText: "POINT (0 -90)", + }, + { + name: "north-pole-weird-longitude", + lat: +90.0, + lng: +1.0, + wantLat: +90.0, + wantLng: +0.0, + wantString: "+90° +0°", + wantText: "POINT (0 90)", + }, + { + name: "south-pole-weird-longitude", + lat: -90.0, + lng: +1.0, + wantLat: -90.0, + wantLng: +0.0, + wantString: "-90° +0°", + wantText: "POINT (0 -90)", + }, + { + name: "almost-north", + lat: +89.0, + lng: +0.0, + wantLat: +89.0, + wantLng: +0.0, + wantString: "+89° +0°", + wantText: "POINT (0 89)", + }, + { + name: "past-north", + lat: +91.0, + lng: +0.0, + wantLat: +89.0, + wantLng: +180.0, + wantString: "+89° +180°", + wantText: "POINT (180 89)", + }, + { + name: "almost-south", + lat: -89.0, + lng: +0.0, + wantLat: -89.0, + wantLng: +0.0, + wantString: "-89° +0°", + wantText: "POINT (0 -89)", + }, + { + name: "past-south", + lat: -91.0, + lng: +0.0, + wantLat: -89.0, + wantLng: +180.0, + wantString: "-89° +180°", + wantText: "POINT (180 -89)", + }, + { + name: "antimeridian-north", + lat: +180.0, + lng: +0.0, + wantLat: +0.0, + wantLng: +180.0, + wantString: "+0° +180°", + wantText: "POINT (180 0)", + }, + { + name: "antimeridian-south", + lat: -180.0, + lng: +0.0, + wantLat: +0.0, + wantLng: +180.0, + wantString: "+0° +180°", + wantText: "POINT (180 0)", + }, + { + name: "almost-antimeridian-north", + lat: +179.0, + lng: +0.0, + wantLat: +1.0, + wantLng: +180.0, + wantString: "+1° +180°", + wantText: "POINT (180 1)", + }, + { + name: "past-antimeridian-north", + lat: +181.0, + lng: +0.0, + wantLat: -1.0, + wantLng: +180.0, + wantString: "-1° +180°", + wantText: "POINT (180 -1)", + }, + { + name: "almost-antimeridian-south", + lat: -179.0, + lng: +0.0, + wantLat: -1.0, + wantLng: +180.0, + wantString: "-1° +180°", + wantText: "POINT (180 -1)", + }, + { + name: "past-antimeridian-south", + lat: -181.0, + lng: +0.0, + wantLat: +1.0, + wantLng: +180.0, + wantString: "+1° +180°", + wantText: "POINT (180 1)", + }, + { + name: "circumnavigate-north", + lat: +360.0, + lng: +1.0, + wantLat: +0.0, + wantLng: +1.0, + wantString: "+0° +1°", + wantText: "POINT (1 0)", + }, + { + name: "circumnavigate-south", + lat: -360.0, + lng: +1.0, + wantLat: +0.0, + wantLng: +1.0, + wantString: "+0° +1°", + wantText: "POINT (1 0)", + }, + { + name: "almost-circumnavigate-north", + lat: +359.0, + lng: +1.0, + wantLat: -1.0, + wantLng: +1.0, + wantString: "-1° +1°", + wantText: "POINT (1 -1)", + }, + { + name: "past-circumnavigate-north", + lat: +361.0, + lng: +1.0, + wantLat: +1.0, + wantLng: +1.0, + wantString: "+1° +1°", + wantText: "POINT (1 1)", + }, + { + name: "almost-circumnavigate-south", + lat: -359.0, + lng: +1.0, + wantLat: +1.0, + wantLng: +1.0, + wantString: "+1° +1°", + wantText: "POINT (1 1)", + }, + { + name: "past-circumnavigate-south", + lat: -361.0, + lng: +1.0, + wantLat: -1.0, + wantLng: +1.0, + wantString: "-1° +1°", + wantText: "POINT (1 -1)", + }, + { + name: "antimeridian-east", + lat: +0.0, + lng: +180.0, + wantLat: +0.0, + wantLng: +180.0, + wantString: "+0° +180°", + wantText: "POINT (180 0)", + }, + { + name: "antimeridian-west", + lat: +0.0, + lng: -180.0, + wantLat: +0.0, + wantLng: +180.0, + wantString: "+0° +180°", + wantText: "POINT (180 0)", + }, + { + name: "almost-antimeridian-east", + lat: +0.0, + lng: +179.0, + wantLat: +0.0, + wantLng: +179.0, + wantString: "+0° +179°", + wantText: "POINT (179 0)", + }, + { + name: "past-antimeridian-east", + lat: +0.0, + lng: +181.0, + wantLat: +0.0, + wantLng: -179.0, + wantString: "+0° -179°", + wantText: "POINT (-179 0)", + }, + { + name: "almost-antimeridian-west", + lat: +0.0, + lng: -179.0, + wantLat: +0.0, + wantLng: -179.0, + wantString: "+0° -179°", + wantText: "POINT (-179 0)", + }, + { + name: "past-antimeridian-west", + lat: +0.0, + lng: -181.0, + wantLat: +0.0, + wantLng: +179.0, + wantString: "+0° +179°", + wantText: "POINT (179 0)", + }, + { + name: "montreal", + lat: +45.508888, + lng: -73.561668, + wantLat: +45.508888, + wantLng: -73.561668, + wantString: "+45.508888° -73.561668°", + wantText: "POINT (-73.561668 45.508888)", + }, + { + name: "canada", + lat: 57.550480044655636, + lng: -98.41680517868062, + wantLat: 57.550480044655636, + wantLng: -98.41680517868062, + wantString: "+57.550480044655636° -98.41680517868062°", + wantText: "POINT (-98.41680517868062 57.550480044655636)", + }, + } { + t.Run(tt.name, func(t *testing.T) { + p := geo.MakePoint(tt.lat, tt.lng) + + lat, lng, err := p.LatLng() + if !approx(lat, tt.wantLat) { + t.Errorf("MakePoint: lat %v, want %v", lat, tt.wantLat) + } + if !approx(lng, tt.wantLng) { + t.Errorf("MakePoint: lng %v, want %v", lng, tt.wantLng) + } + if err != nil { + t.Fatalf("LatLng: err %q, expected nil", err) + } + + if got := p.String(); got != tt.wantString { + t.Errorf("String: got %q, wantString %q", got, tt.wantString) + } + + txt, err := p.MarshalText() + if err != nil { + t.Errorf("Text: err %q, expected nil", err) + } else if string(txt) != tt.wantText { + t.Errorf("Text: got %q, wantText %q", txt, tt.wantText) + } + + b, err := p.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary: err %q, expected nil", err) + } + + var q geo.Point + if err := q.UnmarshalBinary(b); err != nil { + t.Fatalf("UnmarshalBinary: err %q, expected nil", err) + } + if !q.EqualApprox(p, -1) { + t.Errorf("UnmarshalBinary: roundtrip failed: %#v != %#v", q, p) + } + + i, err := p.MarshalUint64() + if err != nil { + t.Fatalf("MarshalUint64: err %q, expected nil", err) + } + + var r geo.Point + if err := r.UnmarshalUint64(i); err != nil { + t.Fatalf("UnmarshalUint64: err %r, expected nil", err) + } + if !q.EqualApprox(r, -1) { + t.Errorf("UnmarshalUint64: roundtrip failed: %#v != %#v", r, p) + } + }) + } +} + +func TestPointMarshalBinary(t *testing.T) { + roundtrip := func(p geo.Point) error { + b, err := p.MarshalBinary() + if err != nil { + return fmt.Errorf("marshal: %v", err) + } + var q geo.Point + if err := q.UnmarshalBinary(b); err != nil { + return fmt.Errorf("unmarshal: %v", err) + } + if q != p { + return fmt.Errorf("%#v != %#v", q, p) + } + return nil + } + + t.Run("nowhere", func(t *testing.T) { + var nowhere geo.Point + if err := roundtrip(nowhere); err != nil { + t.Errorf("roundtrip: %v", err) + } + }) + + t.Run("quick-check", func(t *testing.T) { + f := func(lat geo.Degrees, lng geo.Degrees) (ok bool) { + pt := geo.MakePoint(lat, lng) + if err := roundtrip(pt); err != nil { + t.Errorf("roundtrip: %v", err) + } + return !t.Failed() + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } + }) +} + +func TestPointMarshalUint64(t *testing.T) { + t.Skip("skip") + roundtrip := func(p geo.Point) error { + i, err := p.MarshalUint64() + if err != nil { + return fmt.Errorf("marshal: %v", err) + } + var q geo.Point + if err := q.UnmarshalUint64(i); err != nil { + return fmt.Errorf("unmarshal: %v", err) + } + if q != p { + return fmt.Errorf("%#v != %#v", q, p) + } + return nil + } + + t.Run("nowhere", func(t *testing.T) { + var nowhere geo.Point + if err := roundtrip(nowhere); err != nil { + t.Errorf("roundtrip: %v", err) + } + }) + + t.Run("quick-check", func(t *testing.T) { + f := func(lat geo.Degrees, lng geo.Degrees) (ok bool) { + if err := roundtrip(geo.MakePoint(lat, lng)); err != nil { + t.Errorf("roundtrip: %v", err) + } + return !t.Failed() + } + if err := quick.Check(f, nil); err != nil { + t.Error(err) + } + }) +} + +func TestPointSphericalAngleTo(t *testing.T) { + const earthRadius = 6371.000 // volumetric mean radius (km) + const kmToRad = 1 / earthRadius + for _, tt := range []struct { + name string + x geo.Point + y geo.Point + want geo.Radians + wantErr string + }{ + { + name: "same-point-null-island", + x: geo.MakePoint(0, 0), + y: geo.MakePoint(0, 0), + want: 0.0 * geo.Radian, + }, + { + name: "same-point-north-pole", + x: geo.MakePoint(+90, 0), + y: geo.MakePoint(+90, +90), + want: 0.0 * geo.Radian, + }, + { + name: "same-point-south-pole", + x: geo.MakePoint(-90, 0), + y: geo.MakePoint(-90, -90), + want: 0.0 * geo.Radian, + }, + { + name: "north-pole-to-south-pole", + x: geo.MakePoint(+90, 0), + y: geo.MakePoint(-90, -90), + want: math.Pi * geo.Radian, + }, + { + name: "toronto-to-montreal", + x: geo.MakePoint(+43.6532, -79.3832), + y: geo.MakePoint(+45.5019, -73.5674), + want: 504.26 * kmToRad * geo.Radian, + }, + { + name: "sydney-to-san-francisco", + x: geo.MakePoint(-33.8727, +151.2057), + y: geo.MakePoint(+37.7749, -122.4194), + want: 11948.18 * kmToRad * geo.Radian, + }, + { + name: "new-york-to-paris", + x: geo.MakePoint(+40.7128, -74.0060), + y: geo.MakePoint(+48.8575, +2.3514), + want: 5837.15 * kmToRad * geo.Radian, + }, + { + name: "seattle-to-tokyo", + x: geo.MakePoint(+47.6061, -122.3328), + y: geo.MakePoint(+35.6764, +139.6500), + want: 7700.00 * kmToRad * geo.Radian, + }, + } { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.x.SphericalAngleTo(tt.y) + if tt.wantErr == "" && err != nil { + t.Fatalf("err %q, expected nil", err) + } + if tt.wantErr != "" && (err == nil || err.Error() != tt.wantErr) { + t.Fatalf("err %q, expected %q", err, tt.wantErr) + } + if tt.wantErr != "" { + return + } + + if !approx(got, tt.want) { + t.Errorf("x to y: got %v, want %v", got, tt.want) + } + + // Distance should be commutative + got, err = tt.y.SphericalAngleTo(tt.x) + if err != nil { + t.Fatalf("err %q, expected nil", err) + } + if !approx(got, tt.want) { + t.Errorf("y to x: got %v, want %v", got, tt.want) + } + t.Logf("x to y: %v km", got/kmToRad) + }) + } +} + +func approx[T ~float64](x, y T) bool { + return math.Abs(float64(x)-float64(y)) <= 1e-5 +} diff --git a/types/geo/quantize.go b/types/geo/quantize.go new file mode 100644 index 000000000..18ec11f9f --- /dev/null +++ b/types/geo/quantize.go @@ -0,0 +1,106 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package geo + +import ( + "math" + "sync" +) + +// MinSeparation is the minimum separation between two points after quantizing. +// [Point.Quantize] guarantees that two points will either be snapped to exactly +// the same point, which conflates multiple positions together, or that the two +// points will be far enough apart that successfully performing most reverse +// lookups would be highly improbable. +const MinSeparation = 50_000 * Meter + +// Latitude +var ( + // numSepsEquatorToPole is the number of separations between a point on + // the equator to a point on a pole, that satisfies [minPointSep]. In + // other words, the number of separations between 0° and +90° degrees + // latitude. + numSepsEquatorToPole = int(math.Floor(float64( + earthPolarCircumference / MinSeparation / 4))) + + // latSep is the number of degrees between two adjacent latitudinal + // points. In other words, the next point going straight north of + // 0° would be latSep°. + latSep = Degrees(90.0 / float64(numSepsEquatorToPole)) +) + +// snapToLat returns the number of the nearest latitudinal separation to +// lat. A positive result is north of the equator, a negative result is south, +// and zero is the equator itself. For example, a result of -1 would mean a +// point that is [latSep] south of the equator. +func snapToLat(lat Degrees) int { + return int(math.Round(float64(lat / latSep))) +} + +// lngSep is a lookup table for the number of degrees between two adjacent +// longitudinal separations. where the index corresponds to the absolute value +// of the latitude separation. The first value corresponds to the equator and +// the last value corresponds to the separation before the pole. There is no +// value for the pole itself, because longitude has no meaning there. +// +// [lngSep] is calculated on init, which is so quick and will be used so often +// that the startup cost is negligible. +var lngSep = sync.OnceValue(func() []Degrees { + lut := make([]Degrees, numSepsEquatorToPole) + + // i ranges from the equator to a pole + for i := range len(lut) { + // lat ranges from [0°, 90°], because the southern hemisphere is + // a reflection of the northern one. + lat := Degrees(i) * latSep + ratio := math.Cos(float64(lat.Radians())) + circ := Distance(ratio) * earthEquatorialCircumference + num := int(math.Floor(float64(circ / MinSeparation))) + // We define lut[0] as 0°, lut[len(lut)] to be the north pole, + // which means -lut[len(lut)] is the south pole. + lut[i] = Degrees(360.0 / float64(num)) + } + return lut +}) + +// snapToLatLng returns the number of the nearest latitudinal separation to lat, +// and the nearest longitudinal separation to lng. +func snapToLatLng(lat, lng Degrees) (Degrees, Degrees) { + latN := snapToLat(lat) + + // absolute index into lngSep + n := latN + if n < 0 { + n = -latN + } + + lngSep := lngSep() + if n < len(lngSep) { + sep := lngSep[n] + lngN := int(math.Round(float64(lng / sep))) + return Degrees(latN) * latSep, Degrees(lngN) * sep + } + if latN < 0 { // south pole + return -90 * Degree, 0 * Degree + } else { // north pole + return +90 * Degree, 0 * Degree + } +} + +// Quantize returns a new [Point] after throwing away enough location data in p +// so that it would be difficult to distinguish a node among all the other nodes +// in its general vicinity. One caveat is that if there’s only one point in an +// obscure location, someone could triangulate the node using additional data. +// +// This method is stable: given the same p, it will always return the same +// result. It is equivalent to snapping to points on Earth that are at least +// [MinSeparation] apart. +func (p Point) Quantize() Point { + if p.IsZero() { + return p + } + + lat, lng := snapToLatLng(p.lat, p.lng180-180) + return MakePoint(lat, lng) +} diff --git a/types/geo/quantize_test.go b/types/geo/quantize_test.go new file mode 100644 index 000000000..bc1f62c9b --- /dev/null +++ b/types/geo/quantize_test.go @@ -0,0 +1,130 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package geo_test + +import ( + "testing" + "testing/quick" + + "tailscale.com/types/geo" +) + +func TestPointAnonymize(t *testing.T) { + t.Run("nowhere", func(t *testing.T) { + var zero geo.Point + p := zero.Quantize() + want := zero.Valid() + if got := p.Valid(); got != want { + t.Fatalf("zero.Valid %t, want %t", got, want) + } + }) + + t.Run("separation", func(t *testing.T) { + // Walk from the south pole to the north pole and check that each + // point on the latitude is approximately MinSeparation apart. + const southPole = -90 * geo.Degree + const northPole = 90 * geo.Degree + const dateLine = 180 * geo.Degree + + llat := southPole + for lat := llat; lat <= northPole; lat += 0x1p-4 { + last := geo.MakePoint(llat, 0) + cur := geo.MakePoint(lat, 0) + anon := cur.Quantize() + switch latlng, g, err := anon.LatLng(); { + case err != nil: + t.Fatal(err) + case lat == southPole: + // initialize llng, to the first snapped longitude + llat = latlng + goto Lng + case g != 0: + t.Fatalf("%v is west or east of %v", anon, last) + case latlng < llat: + t.Fatalf("%v is south of %v", anon, last) + case latlng == llat: + continue + case latlng > llat: + switch dist, err := last.DistanceTo(anon); { + case err != nil: + t.Fatal(err) + case dist == 0.0: + continue + case dist < geo.MinSeparation: + t.Logf("lat=%v last=%v cur=%v anon=%v", lat, last, cur, anon) + t.Fatalf("%v is too close to %v", anon, last) + default: + llat = latlng + } + } + + Lng: + llng := dateLine + for lng := llng; lng <= dateLine && lng >= -dateLine; lng -= 0x1p-3 { + last := geo.MakePoint(llat, llng) + cur := geo.MakePoint(lat, lng) + anon := cur.Quantize() + switch latlng, g, err := anon.LatLng(); { + case err != nil: + t.Fatal(err) + case lng == dateLine: + // initialize llng, to the first snapped longitude + llng = g + continue + case latlng != llat: + t.Fatalf("%v is north or south of %v", anon, last) + case g != llng: + const tolerance = geo.MinSeparation * 0x1p-9 + switch dist, err := last.DistanceTo(anon); { + case err != nil: + t.Fatal(err) + case dist < tolerance: + continue + case dist < (geo.MinSeparation - tolerance): + t.Logf("lat=%v lng=%v last=%v cur=%v anon=%v", lat, lng, last, cur, anon) + t.Fatalf("%v is too close to %v: %v", anon, last, dist) + default: + llng = g + } + + } + } + } + if llat == southPole { + t.Fatal("llat never incremented") + } + }) + + t.Run("quick-check", func(t *testing.T) { + f := func(lat, lng geo.Degrees) bool { + p := geo.MakePoint(lat, lng) + q := p.Quantize() + t.Logf("quantize %v = %v", p, q) + + lat, lng, err := q.LatLng() + if err != nil { + t.Errorf("err %v, want nil", err) + return !t.Failed() + } + + if lat < -90*geo.Degree || lat > 90*geo.Degree { + t.Errorf("lat outside [-90°, +90°]: %v", lat) + } + if lng < -180*geo.Degree || lng > 180*geo.Degree { + t.Errorf("lng outside [-180°, +180°], %v", lng) + } + + if dist, err := p.DistanceTo(q); err != nil { + t.Error(err) + } else if dist > (geo.MinSeparation * 2) { + t.Errorf("moved too far: %v", dist) + } + + return !t.Failed() + } + if err := quick.Check(f, nil); err != nil { + t.Fatal(err) + } + }) +} diff --git a/types/geo/units.go b/types/geo/units.go new file mode 100644 index 000000000..76a4c02f7 --- /dev/null +++ b/types/geo/units.go @@ -0,0 +1,191 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package geo + +import ( + "math" + "strconv" + "strings" + "unicode" +) + +const ( + Degree Degrees = 1 + Radian Radians = 1 + Turn Turns = 1 + Meter Distance = 1 +) + +// Degrees represents a latitude or longitude, in decimal degrees. +type Degrees float64 + +// ParseDegrees parses s as decimal degrees. +func ParseDegrees(s string) (Degrees, error) { + s = strings.TrimSuffix(s, "°") + f, err := strconv.ParseFloat(s, 64) + return Degrees(f), err +} + +// MustParseDegrees parses s as decimal degrees, but panics on error. +func MustParseDegrees(s string) Degrees { + d, err := ParseDegrees(s) + if err != nil { + panic(err) + } + return d +} + +// String implements the [Stringer] interface. The output is formatted in +// decimal degrees, prefixed by either the appropriate + or - sign, and suffixed +// by a ° degree symbol. +func (d Degrees) String() string { + b, _ := d.AppendText(nil) + b = append(b, []byte("°")...) + return string(b) +} + +// AppendText implements [encoding.TextAppender]. The output is formatted in +// decimal degrees, prefixed by either the appropriate + or - sign. +func (d Degrees) AppendText(b []byte) ([]byte, error) { + b = d.AppendZeroPaddedText(b, 0) + return b, nil +} + +// AppendZeroPaddedText appends d formatted as decimal degrees to b. The number of +// integer digits will be zero-padded to nint. +func (d Degrees) AppendZeroPaddedText(b []byte, nint int) []byte { + n := float64(d) + + if math.IsInf(n, 0) || math.IsNaN(n) { + return strconv.AppendFloat(b, n, 'f', -1, 64) + } + + sign := byte('+') + if math.Signbit(n) { + sign = '-' + n = -n + } + b = append(b, sign) + + pad := nint - 1 + for nn := n / 10; nn >= 1 && pad > 0; nn /= 10 { + pad-- + } + for range pad { + b = append(b, '0') + } + return strconv.AppendFloat(b, n, 'f', -1, 64) +} + +// Radians converts d into radians. +func (d Degrees) Radians() Radians { + return Radians(d * math.Pi / 180.0) +} + +// Turns converts d into a number of turns. +func (d Degrees) Turns() Turns { + return Turns(d / 360.0) +} + +// Radians represents a latitude or longitude, in radians. +type Radians float64 + +// ParseRadians parses s as radians. +func ParseRadians(s string) (Radians, error) { + s = strings.TrimSuffix(s, "rad") + s = strings.TrimRightFunc(s, unicode.IsSpace) + f, err := strconv.ParseFloat(s, 64) + return Radians(f), err +} + +// MustParseRadians parses s as radians, but panics on error. +func MustParseRadians(s string) Radians { + r, err := ParseRadians(s) + if err != nil { + panic(err) + } + return r +} + +// String implements the [Stringer] interface. +func (r Radians) String() string { + return strconv.FormatFloat(float64(r), 'f', -1, 64) + " rad" +} + +// Degrees converts r into decimal degrees. +func (r Radians) Degrees() Degrees { + return Degrees(r * 180.0 / math.Pi) +} + +// Turns converts r into a number of turns. +func (r Radians) Turns() Turns { + return Turns(r / 2 / math.Pi) +} + +// Turns represents a number of complete revolutions around a sphere. +type Turns float64 + +// String implements the [Stringer] interface. +func (o Turns) String() string { + return strconv.FormatFloat(float64(o), 'f', -1, 64) +} + +// Degrees converts t into decimal degrees. +func (o Turns) Degrees() Degrees { + return Degrees(o * 360.0) +} + +// Radians converts t into radians. +func (o Turns) Radians() Radians { + return Radians(o * 2 * math.Pi) +} + +// Distance represents a great-circle distance in meters. +type Distance float64 + +// ParseDistance parses s as distance in meters. +func ParseDistance(s string) (Distance, error) { + s = strings.TrimSuffix(s, "m") + s = strings.TrimRightFunc(s, unicode.IsSpace) + f, err := strconv.ParseFloat(s, 64) + return Distance(f), err +} + +// MustParseDistance parses s as distance in meters, but panics on error. +func MustParseDistance(s string) Distance { + d, err := ParseDistance(s) + if err != nil { + panic(err) + } + return d +} + +// String implements the [Stringer] interface. +func (d Distance) String() string { + return strconv.FormatFloat(float64(d), 'f', -1, 64) + "m" +} + +// DistanceOnEarth converts t turns into the great-circle distance, in meters. +func DistanceOnEarth(t Turns) Distance { + return Distance(t) * EarthMeanCircumference +} + +// Earth Fact Sheet +// https://nssdc.gsfc.nasa.gov/planetary/factsheet/earthfact.html +const ( + // EarthMeanRadius is the volumetric mean radius of the Earth. + EarthMeanRadius = 6_371_000 * Meter + // EarthMeanCircumference is the volumetric mean circumference of the Earth. + EarthMeanCircumference = 2 * math.Pi * EarthMeanRadius + + // earthEquatorialRadius is the equatorial radius of the Earth. + earthEquatorialRadius = 6_378_137 * Meter + // earthEquatorialCircumference is the equatorial circumference of the Earth. + earthEquatorialCircumference = 2 * math.Pi * earthEquatorialRadius + + // earthPolarRadius is the polar radius of the Earth. + earthPolarRadius = 6_356_752 * Meter + // earthPolarCircumference is the polar circumference of the Earth. + earthPolarCircumference = 2 * math.Pi * earthPolarRadius +) diff --git a/types/geo/units_test.go b/types/geo/units_test.go new file mode 100644 index 000000000..b6f724ce0 --- /dev/null +++ b/types/geo/units_test.go @@ -0,0 +1,395 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package geo_test + +import ( + "math" + "strings" + "testing" + + "tailscale.com/types/geo" +) + +func TestDegrees(t *testing.T) { + for _, tt := range []struct { + name string + degs geo.Degrees + wantStr string + wantText string + wantPad string + wantRads geo.Radians + wantTurns geo.Turns + }{ + { + name: "zero", + degs: 0.0 * geo.Degree, + wantStr: "+0°", + wantText: "+0", + wantPad: "+000", + wantRads: 0.0 * geo.Radian, + wantTurns: 0 * geo.Turn, + }, + { + name: "quarter-turn", + degs: 90.0 * geo.Degree, + wantStr: "+90°", + wantText: "+90", + wantPad: "+090", + wantRads: 0.5 * math.Pi * geo.Radian, + wantTurns: 0.25 * geo.Turn, + }, + { + name: "half-turn", + degs: 180.0 * geo.Degree, + wantStr: "+180°", + wantText: "+180", + wantPad: "+180", + wantRads: 1.0 * math.Pi * geo.Radian, + wantTurns: 0.5 * geo.Turn, + }, + { + name: "full-turn", + degs: 360.0 * geo.Degree, + wantStr: "+360°", + wantText: "+360", + wantPad: "+360", + wantRads: 2.0 * math.Pi * geo.Radian, + wantTurns: 1.0 * geo.Turn, + }, + { + name: "negative-zero", + degs: geo.MustParseDegrees("-0.0"), + wantStr: "-0°", + wantText: "-0", + wantPad: "-000", + wantRads: 0 * geo.Radian * -1, + wantTurns: 0 * geo.Turn * -1, + }, + { + name: "small-degree", + degs: -1.2003 * geo.Degree, + wantStr: "-1.2003°", + wantText: "-1.2003", + wantPad: "-001.2003", + wantRads: -0.020949187011687936 * geo.Radian, + wantTurns: -0.0033341666666666663 * geo.Turn, + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.degs.String(); got != tt.wantStr { + t.Errorf("String got %q, want %q", got, tt.wantStr) + } + + d, err := geo.ParseDegrees(tt.wantStr) + if err != nil { + t.Fatalf("ParseDegrees err %q, want nil", err.Error()) + } + if d != tt.degs { + t.Errorf("ParseDegrees got %q, want %q", d, tt.degs) + } + + b, err := tt.degs.AppendText(nil) + if err != nil { + t.Fatalf("AppendText err %q, want nil", err.Error()) + } + if string(b) != tt.wantText { + t.Errorf("AppendText got %q, want %q", b, tt.wantText) + } + + b = tt.degs.AppendZeroPaddedText(nil, 3) + if string(b) != tt.wantPad { + t.Errorf("AppendZeroPaddedText got %q, want %q", b, tt.wantPad) + } + + r := tt.degs.Radians() + if r != tt.wantRads { + t.Errorf("Radian got %v, want %v", r, tt.wantRads) + } + if d := r.Degrees(); d != tt.degs { // Roundtrip + t.Errorf("Degrees got %v, want %v", d, tt.degs) + } + + o := tt.degs.Turns() + if o != tt.wantTurns { + t.Errorf("Turns got %v, want %v", o, tt.wantTurns) + } + }) + } +} + +func TestRadians(t *testing.T) { + for _, tt := range []struct { + name string + rads geo.Radians + wantStr string + wantText string + wantDegs geo.Degrees + wantTurns geo.Turns + }{ + { + name: "zero", + rads: 0.0 * geo.Radian, + wantStr: "0 rad", + wantDegs: 0.0 * geo.Degree, + wantTurns: 0 * geo.Turn, + }, + { + name: "quarter-turn", + rads: 0.5 * math.Pi * geo.Radian, + wantStr: "1.5707963267948966 rad", + wantDegs: 90.0 * geo.Degree, + wantTurns: 0.25 * geo.Turn, + }, + { + name: "half-turn", + rads: 1.0 * math.Pi * geo.Radian, + wantStr: "3.141592653589793 rad", + wantDegs: 180.0 * geo.Degree, + wantTurns: 0.5 * geo.Turn, + }, + { + name: "full-turn", + rads: 2.0 * math.Pi * geo.Radian, + wantStr: "6.283185307179586 rad", + wantDegs: 360.0 * geo.Degree, + wantTurns: 1.0 * geo.Turn, + }, + { + name: "negative-zero", + rads: geo.MustParseRadians("-0"), + wantStr: "-0 rad", + wantDegs: 0 * geo.Degree * -1, + wantTurns: 0 * geo.Turn * -1, + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.rads.String(); got != tt.wantStr { + t.Errorf("String got %q, want %q", got, tt.wantStr) + } + + r, err := geo.ParseRadians(tt.wantStr) + if err != nil { + t.Fatalf("ParseDegrees err %q, want nil", err.Error()) + } + if r != tt.rads { + t.Errorf("ParseDegrees got %q, want %q", r, tt.rads) + } + + d := tt.rads.Degrees() + if d != tt.wantDegs { + t.Errorf("Degrees got %v, want %v", d, tt.wantDegs) + } + if r := d.Radians(); r != tt.rads { // Roundtrip + t.Errorf("Radians got %v, want %v", r, tt.rads) + } + + o := tt.rads.Turns() + if o != tt.wantTurns { + t.Errorf("Turns got %v, want %v", o, tt.wantTurns) + } + }) + } +} + +func TestTurns(t *testing.T) { + for _, tt := range []struct { + name string + turns geo.Turns + wantStr string + wantText string + wantDegs geo.Degrees + wantRads geo.Radians + }{ + { + name: "zero", + turns: 0.0, + wantStr: "0", + wantDegs: 0.0 * geo.Degree, + wantRads: 0 * geo.Radian, + }, + { + name: "quarter-turn", + turns: 0.25, + wantStr: "0.25", + wantDegs: 90.0 * geo.Degree, + wantRads: 0.5 * math.Pi * geo.Radian, + }, + { + name: "half-turn", + turns: 0.5, + wantStr: "0.5", + wantDegs: 180.0 * geo.Degree, + wantRads: 1.0 * math.Pi * geo.Radian, + }, + { + name: "full-turn", + turns: 1.0, + wantStr: "1", + wantDegs: 360.0 * geo.Degree, + wantRads: 2.0 * math.Pi * geo.Radian, + }, + { + name: "negative-zero", + turns: geo.Turns(math.Copysign(0, -1)), + wantStr: "-0", + wantDegs: 0 * geo.Degree * -1, + wantRads: 0 * geo.Radian * -1, + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.turns.String(); got != tt.wantStr { + t.Errorf("String got %q, want %q", got, tt.wantStr) + } + + d := tt.turns.Degrees() + if d != tt.wantDegs { + t.Errorf("Degrees got %v, want %v", d, tt.wantDegs) + } + if o := d.Turns(); o != tt.turns { // Roundtrip + t.Errorf("Turns got %v, want %v", o, tt.turns) + } + + r := tt.turns.Radians() + if r != tt.wantRads { + t.Errorf("Turns got %v, want %v", r, tt.wantRads) + } + }) + } +} + +func TestDistance(t *testing.T) { + for _, tt := range []struct { + name string + dist geo.Distance + wantStr string + }{ + { + name: "zero", + dist: 0.0 * geo.Meter, + wantStr: "0m", + }, + { + name: "random", + dist: 4 * geo.Meter, + wantStr: "4m", + }, + { + name: "light-second", + dist: 299_792_458 * geo.Meter, + wantStr: "299792458m", + }, + { + name: "planck-length", + dist: 1.61625518e-35 * geo.Meter, + wantStr: "0.0000000000000000000000000000000000161625518m", + }, + { + name: "negative-zero", + dist: geo.Distance(math.Copysign(0, -1)), + wantStr: "-0m", + }, + } { + t.Run(tt.name, func(t *testing.T) { + if got := tt.dist.String(); got != tt.wantStr { + t.Errorf("String got %q, want %q", got, tt.wantStr) + } + + r, err := geo.ParseDistance(tt.wantStr) + if err != nil { + t.Fatalf("ParseDegrees err %q, want nil", err.Error()) + } + if r != tt.dist { + t.Errorf("ParseDegrees got %q, want %q", r, tt.dist) + } + }) + } +} + +func TestDistanceOnEarth(t *testing.T) { + for _, tt := range []struct { + name string + here geo.Point + there geo.Point + want geo.Distance + wantErr string + }{ + { + name: "no-points", + here: geo.Point{}, + there: geo.Point{}, + wantErr: "not a valid point", + }, + { + name: "not-here", + here: geo.Point{}, + there: geo.MakePoint(0, 0), + wantErr: "not a valid point", + }, + { + name: "not-there", + here: geo.MakePoint(0, 0), + there: geo.Point{}, + wantErr: "not a valid point", + }, + { + name: "null-island", + here: geo.MakePoint(0, 0), + there: geo.MakePoint(0, 0), + want: 0 * geo.Meter, + }, + { + name: "equator-to-south-pole", + here: geo.MakePoint(0, 0), + there: geo.MakePoint(-90, 0), + want: geo.EarthMeanCircumference / 4, + }, + { + name: "north-pole-to-south-pole", + here: geo.MakePoint(+90, 0), + there: geo.MakePoint(-90, 0), + want: geo.EarthMeanCircumference / 2, + }, + { + name: "meridian-to-antimeridian", + here: geo.MakePoint(0, 0), + there: geo.MakePoint(0, -180), + want: geo.EarthMeanCircumference / 2, + }, + { + name: "positive-to-negative-antimeridian", + here: geo.MakePoint(0, 180), + there: geo.MakePoint(0, -180), + want: 0 * geo.Meter, + }, + { + name: "toronto-to-montreal", + here: geo.MakePoint(+43.70011, -79.41630), + there: geo.MakePoint(+45.50884, -73.58781), + want: 503_200 * geo.Meter, + }, + { + name: "montreal-to-san-francisco", + here: geo.MakePoint(+45.50884, -73.58781), + there: geo.MakePoint(+37.77493, -122.41942), + want: 4_082_600 * geo.Meter, + }, + } { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.here.DistanceTo(tt.there) + if tt.wantErr == "" && err != nil { + t.Fatalf("err %q, want nil", err) + } + if tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("err %q, want %q", err, tt.wantErr) + } + + approx := func(x, y geo.Distance) bool { + return math.Abs(float64(x)-float64(y)) <= 10 + } + if !approx(got, tt.want) { + t.Fatalf("got %v, want %v", got, tt.want) + } + }) + } +} diff --git a/types/iox/io.go b/types/iox/io.go new file mode 100644 index 000000000..a5ca1be43 --- /dev/null +++ b/types/iox/io.go @@ -0,0 +1,23 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package iox provides types to implement [io] functionality. +package iox + +// TODO(https://go.dev/issue/21670): Deprecate or remove this functionality +// once the Go language supports implementing an 1-method interface directly +// using a function value of a matching signature. + +// ReaderFunc implements [io.Reader] using the underlying function value. +type ReaderFunc func([]byte) (int, error) + +func (f ReaderFunc) Read(b []byte) (int, error) { + return f(b) +} + +// WriterFunc implements [io.Writer] using the underlying function value. +type WriterFunc func([]byte) (int, error) + +func (f WriterFunc) Write(b []byte) (int, error) { + return f(b) +} diff --git a/types/iox/io_test.go b/types/iox/io_test.go new file mode 100644 index 000000000..9fba39605 --- /dev/null +++ b/types/iox/io_test.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package iox + +import ( + "bytes" + "io" + "testing" + "testing/iotest" + + "tailscale.com/util/must" +) + +func TestCopy(t *testing.T) { + const testdata = "the quick brown fox jumped over the lazy dog" + src := testdata + bb := new(bytes.Buffer) + if got := must.Get(io.Copy(bb, ReaderFunc(func(b []byte) (n int, err error) { + n = copy(b[:min(len(b), 7)], src) + src = src[n:] + if len(src) == 0 { + err = io.EOF + } + return n, err + }))); int(got) != len(testdata) { + t.Errorf("copy = %d, want %d", got, len(testdata)) + } + var dst []byte + if got := must.Get(io.Copy(WriterFunc(func(b []byte) (n int, err error) { + dst = append(dst, b...) + return len(b), nil + }), iotest.OneByteReader(bb))); int(got) != len(testdata) { + t.Errorf("copy = %d, want %d", got, len(testdata)) + } + if string(dst) != testdata { + t.Errorf("copy = %q, want %q", dst, testdata) + } +} diff --git a/types/jsonx/json.go b/types/jsonx/json.go new file mode 100644 index 000000000..3f01ea358 --- /dev/null +++ b/types/jsonx/json.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonx contains helper types and functionality to use with +// [github.com/go-json-experiment/json], which is positioned to be +// merged into the Go standard library as [encoding/json/v2]. +// +// See https://go.dev/issues/71497 +package jsonx + +import ( + "errors" + "fmt" + "reflect" + + "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" +) + +var ( + errUnknownTypeName = errors.New("unknown type name") + errNonSingularValue = errors.New("dynamic value must only have exactly one member") +) + +// MakeInterfaceCoders constructs a pair of marshal and unmarshal functions +// to serialize a Go interface type T. A bijective mapping for the set +// of concrete types that implement T is provided, +// where the key is a stable type name to use in the JSON representation, +// while the value is any value of a concrete type that implements T. +// By convention, only the zero value of concrete types is passed. +// +// The JSON representation for a dynamic value is a JSON object +// with a single member, where the member name is the type name, +// and the value is the JSON representation for the Go value. +// For example, the JSON serialization for a concrete type named Foo +// would be {"Foo": ...}, where ... is the JSON representation +// of the concrete value of the Foo type. +// +// Example instantiation: +// +// // Interface is a union type implemented by [FooType] and [BarType]. +// type Interface interface { ... } +// +// var interfaceCoders = MakeInterfaceCoders(map[string]Interface{ +// "FooType": FooType{}, +// "BarType": (*BarType)(nil), +// }) +// +// The pair of Marshal and Unmarshal functions can be used with the [json] +// package with either type-specified or caller-specified serialization. +// The result of this constructor is usually stored into a global variable. +// +// Example usage with type-specified serialization: +// +// // InterfaceWrapper is a concrete type that wraps [Interface]. +// // It extends [Interface] to implement +// // [json.MarshalerTo] and [json.UnmarshalerFrom]. +// type InterfaceWrapper struct{ Interface } +// +// func (w InterfaceWrapper) MarshalJSONTo(enc *jsontext.Encoder) error { +// return interfaceCoders.Marshal(enc, &w.Interface) +// } +// +// func (w *InterfaceWrapper) UnmarshalJSONFrom(dec *jsontext.Decoder) error { +// return interfaceCoders.Unmarshal(dec, &w.Interface) +// } +// +// Example usage with caller-specified serialization: +// +// var opts json.Options = json.JoinOptions( +// json.WithMarshalers(json.MarshalToFunc(interfaceCoders.Marshal)), +// json.WithUnmarshalers(json.UnmarshalFromFunc(interfaceCoders.Unmarshal)), +// ) +// +// var v Interface +// ... := json.Marshal(v, opts) +// ... := json.Unmarshal(&v, opts) +// +// The function panics if T is not a named interface kind, +// or if valuesByName contains distinct entries with the same concrete type. +func MakeInterfaceCoders[T any](valuesByName map[string]T) (c struct { + Marshal func(*jsontext.Encoder, *T) error + Unmarshal func(*jsontext.Decoder, *T) error +}) { + // Verify that T is a named interface. + switch t := reflect.TypeFor[T](); { + case t.Kind() != reflect.Interface: + panic(fmt.Sprintf("%v must be an interface kind", t)) + case t.Name() == "": + panic(fmt.Sprintf("%v must be a named type", t)) + } + + // Construct a bijective mapping of names to types. + typesByName := make(map[string]reflect.Type) + namesByType := make(map[reflect.Type]string) + for name, value := range valuesByName { + t := reflect.TypeOf(value) + if t == nil { + panic(fmt.Sprintf("nil value for %s", name)) + } + if name2, ok := namesByType[t]; ok { + panic(fmt.Sprintf("type %v cannot have multiple names %s and %v", t, name, name2)) + } + typesByName[name] = t + namesByType[t] = name + } + + // Construct the marshal and unmarshal functions. + c.Marshal = func(enc *jsontext.Encoder, val *T) error { + t := reflect.TypeOf(*val) + if t == nil { + return enc.WriteToken(jsontext.Null) + } + name := namesByType[t] + if name == "" { + return fmt.Errorf("Go type %v: %w", t, errUnknownTypeName) + } + + if err := enc.WriteToken(jsontext.BeginObject); err != nil { + return err + } + if err := enc.WriteToken(jsontext.String(name)); err != nil { + return err + } + if err := json.MarshalEncode(enc, *val); err != nil { + return err + } + if err := enc.WriteToken(jsontext.EndObject); err != nil { + return err + } + return nil + } + c.Unmarshal = func(dec *jsontext.Decoder, val *T) error { + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() == 'n': + var zero T + *val = zero // store nil interface value for JSON null + return nil + case tok.Kind() != '{': + return &json.SemanticError{JSONKind: tok.Kind(), GoType: reflect.TypeFor[T]()} + } + var v reflect.Value + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() != '"': + return errNonSingularValue + default: + t := typesByName[tok.String()] + if t == nil { + return errUnknownTypeName + } + v = reflect.New(t) + } + if err := json.UnmarshalDecode(dec, v.Interface()); err != nil { + return err + } + *val = v.Elem().Interface().(T) + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() != '}': + return errNonSingularValue + } + return nil + } + + return c +} diff --git a/types/jsonx/json_test.go b/types/jsonx/json_test.go new file mode 100644 index 000000000..0f2a646c4 --- /dev/null +++ b/types/jsonx/json_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonx + +import ( + "errors" + "testing" + + "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "github.com/google/go-cmp/cmp" + "tailscale.com/types/ptr" +) + +type Interface interface { + implementsInterface() +} + +type Foo string + +func (Foo) implementsInterface() {} + +type Bar int + +func (Bar) implementsInterface() {} + +type Baz struct{ Fizz, Buzz string } + +func (*Baz) implementsInterface() {} + +var interfaceCoders = MakeInterfaceCoders(map[string]Interface{ + "Foo": Foo(""), + "Bar": (*Bar)(nil), + "Baz": (*Baz)(nil), +}) + +type InterfaceWrapper struct{ Interface } + +func (w InterfaceWrapper) MarshalJSONTo(enc *jsontext.Encoder) error { + return interfaceCoders.Marshal(enc, &w.Interface) +} + +func (w *InterfaceWrapper) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + return interfaceCoders.Unmarshal(dec, &w.Interface) +} + +func TestInterfaceCoders(t *testing.T) { + var opts json.Options = json.JoinOptions( + json.WithMarshalers(json.MarshalToFunc(interfaceCoders.Marshal)), + json.WithUnmarshalers(json.UnmarshalFromFunc(interfaceCoders.Unmarshal)), + ) + + errSkipMarshal := errors.New("skip marshal") + makeFiller := func() InterfaceWrapper { + return InterfaceWrapper{&Baz{"fizz", "buzz"}} + } + + for _, tt := range []struct { + label string + wantVal InterfaceWrapper + wantJSON string + wantMarshalError error + wantUnmarshalError error + }{{ + label: "Null", + wantVal: InterfaceWrapper{}, + wantJSON: `null`, + }, { + label: "Foo", + wantVal: InterfaceWrapper{Foo("hello")}, + wantJSON: `{"Foo":"hello"}`, + }, { + label: "BarPointer", + wantVal: InterfaceWrapper{ptr.To(Bar(5))}, + wantJSON: `{"Bar":5}`, + }, { + label: "BarValue", + wantVal: InterfaceWrapper{Bar(5)}, + // NOTE: We could handle BarValue just like BarPointer, + // but round-trip marshal/unmarshal would not be identical. + wantMarshalError: errUnknownTypeName, + }, { + label: "Baz", + wantVal: InterfaceWrapper{&Baz{"alpha", "omega"}}, + wantJSON: `{"Baz":{"Fizz":"alpha","Buzz":"omega"}}`, + }, { + label: "Unknown", + wantVal: makeFiller(), + wantJSON: `{"Unknown":[1,2,3]}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errUnknownTypeName, + }, { + label: "Empty", + wantVal: makeFiller(), + wantJSON: `{}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errNonSingularValue, + }, { + label: "Duplicate", + wantVal: InterfaceWrapper{Foo("hello")}, // first entry wins + wantJSON: `{"Foo":"hello","Bar":5}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errNonSingularValue, + }} { + t.Run(tt.label, func(t *testing.T) { + if tt.wantMarshalError != errSkipMarshal { + switch gotJSON, err := json.Marshal(&tt.wantVal); { + case !errors.Is(err, tt.wantMarshalError): + t.Fatalf("json.Marshal(%v) error = %v, want %v", tt.wantVal, err, tt.wantMarshalError) + case string(gotJSON) != tt.wantJSON: + t.Fatalf("json.Marshal(%v) = %s, want %s", tt.wantVal, gotJSON, tt.wantJSON) + } + switch gotJSON, err := json.Marshal(&tt.wantVal.Interface, opts); { + case !errors.Is(err, tt.wantMarshalError): + t.Fatalf("json.Marshal(%v) error = %v, want %v", tt.wantVal, err, tt.wantMarshalError) + case string(gotJSON) != tt.wantJSON: + t.Fatalf("json.Marshal(%v) = %s, want %s", tt.wantVal, gotJSON, tt.wantJSON) + } + } + + if tt.wantJSON != "" { + gotVal := makeFiller() + if err := json.Unmarshal([]byte(tt.wantJSON), &gotVal); !errors.Is(err, tt.wantUnmarshalError) { + t.Fatalf("json.Unmarshal(%v) error = %v, want %v", tt.wantJSON, err, tt.wantUnmarshalError) + } + if d := cmp.Diff(gotVal, tt.wantVal); d != "" { + t.Fatalf("json.Unmarshal(%v):\n%s", tt.wantJSON, d) + } + gotVal = makeFiller() + if err := json.Unmarshal([]byte(tt.wantJSON), &gotVal.Interface, opts); !errors.Is(err, tt.wantUnmarshalError) { + t.Fatalf("json.Unmarshal(%v) error = %v, want %v", tt.wantJSON, err, tt.wantUnmarshalError) + } + if d := cmp.Diff(gotVal, tt.wantVal); d != "" { + t.Fatalf("json.Unmarshal(%v):\n%s", tt.wantJSON, d) + } + } + }) + } +} diff --git a/types/key/derp.go b/types/key/derp.go new file mode 100644 index 000000000..1466b85bc --- /dev/null +++ b/types/key/derp.go @@ -0,0 +1,90 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "crypto/subtle" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strings" + + "go4.org/mem" + "tailscale.com/types/structs" +) + +var ErrInvalidMeshKey = errors.New("invalid mesh key") + +// DERPMesh is a mesh key, used for inter-DERP-node communication and for +// privileged DERP clients. +type DERPMesh struct { + _ structs.Incomparable // == isn't constant-time + k [32]byte // 64-digit hexadecimal numbers fit in 32 bytes +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (k DERPMesh) MarshalJSON() ([]byte, error) { + return json.Marshal(k.String()) +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (k *DERPMesh) UnmarshalJSON(data []byte) error { + var s string + json.Unmarshal(data, &s) + + if hex.DecodedLen(len(s)) != len(k.k) { + return fmt.Errorf("types/key/derp: cannot unmarshal, incorrect size mesh key len: %d, must be %d, %w", hex.DecodedLen(len(s)), len(k.k), ErrInvalidMeshKey) + } + _, err := hex.Decode(k.k[:], []byte(s)) + if err != nil { + return fmt.Errorf("types/key/derp: cannot unmarshal, invalid mesh key: %w", err) + } + + return nil +} + +// DERPMeshFromRaw32 parses a 32-byte raw value as a DERP mesh key. +func DERPMeshFromRaw32(raw mem.RO) DERPMesh { + if raw.Len() != 32 { + panic("input has wrong size") + } + var ret DERPMesh + raw.Copy(ret.k[:]) + return ret +} + +// ParseDERPMesh parses a DERP mesh key from a string. +// This function trims whitespace around the string. +// If the key is not a 64-digit hexadecimal number, ErrInvalidMeshKey is returned. +func ParseDERPMesh(key string) (DERPMesh, error) { + key = strings.TrimSpace(key) + if len(key) != 64 { + return DERPMesh{}, fmt.Errorf("%w: must be 64-digit hexadecimal number", ErrInvalidMeshKey) + } + decoded, err := hex.DecodeString(key) + if err != nil { + return DERPMesh{}, fmt.Errorf("%w: %v", ErrInvalidMeshKey, err) + } + return DERPMeshFromRaw32(mem.B(decoded)), nil +} + +// IsZero reports whether k is the zero value. +func (k DERPMesh) IsZero() bool { + return k.Equal(DERPMesh{}) +} + +// Equal reports whether k and other are the same key. +func (k DERPMesh) Equal(other DERPMesh) bool { + // Compare mesh keys in constant time to prevent timing attacks. + // Since mesh keys are a fixed length, we don’t need to be concerned + // about timing attacks on client mesh keys that are the wrong length. + // See https://github.com/tailscale/corp/issues/28720 + return subtle.ConstantTimeCompare(k.k[:], other.k[:]) == 1 +} + +// String returns k as a hex-encoded 64-digit number. +func (k DERPMesh) String() string { + return hex.EncodeToString(k.k[:]) +} diff --git a/types/key/derp_test.go b/types/key/derp_test.go new file mode 100644 index 000000000..b91cbbf8c --- /dev/null +++ b/types/key/derp_test.go @@ -0,0 +1,133 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "errors" + "testing" + + "go4.org/mem" +) + +func TestDERPMeshIsValid(t *testing.T) { + for name, tt := range map[string]struct { + input string + want string + wantErr error + }{ + "good": { + input: "0123456789012345678901234567890123456789012345678901234567890123", + want: "0123456789012345678901234567890123456789012345678901234567890123", + wantErr: nil, + }, + "hex": { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + want: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantErr: nil, + }, + "uppercase": { + input: "0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF0123456789ABCDEF", + want: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + wantErr: nil, + }, + "whitespace": { + input: " 0123456789012345678901234567890123456789012345678901234567890123 ", + want: "0123456789012345678901234567890123456789012345678901234567890123", + wantErr: nil, + }, + "short": { + input: "0123456789abcdef", + wantErr: ErrInvalidMeshKey, + }, + "long": { + input: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef0", + wantErr: ErrInvalidMeshKey, + }, + } { + t.Run(name, func(t *testing.T) { + k, err := ParseDERPMesh(tt.input) + if !errors.Is(err, tt.wantErr) { + t.Errorf("err %v, want %v", err, tt.wantErr) + } + + got := k.String() + if got != tt.want && tt.wantErr == nil { + t.Errorf("got %q, want %q", got, tt.want) + } + + }) + } + +} + +func TestDERPMesh(t *testing.T) { + t.Parallel() + + for name, tt := range map[string]struct { + str string + hex []byte + equal bool // are str and hex equal? + }{ + "zero": { + str: "0000000000000000000000000000000000000000000000000000000000000000", + hex: []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + equal: true, + }, + "equal": { + str: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + hex: []byte{ + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, + }, + equal: true, + }, + "unequal": { + str: "0badc0de00000000000000000000000000000000000000000000000000000000", + hex: []byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + equal: false, + }, + } { + t.Run(name, func(t *testing.T) { + t.Parallel() + + k, err := ParseDERPMesh(tt.str) + if err != nil { + t.Fatal(err) + } + + // string representation should round-trip + s := k.String() + if s != tt.str { + t.Fatalf("string %s, want %s", s, tt.str) + } + + // if tt.equal, then tt.hex is intended to be equal + if k.k != [32]byte(tt.hex) && tt.equal { + t.Fatalf("decoded %x, want %x", k.k, tt.hex) + } + + h := DERPMeshFromRaw32(mem.B(tt.hex)) + if k.Equal(h) != tt.equal { + if tt.equal { + t.Fatalf("%v != %v", k, h) + } else { + t.Fatalf("%v == %v", k, h) + } + } + + }) + } +} diff --git a/types/key/disco.go b/types/key/disco.go index 1013ce5bf..52b40c766 100644 --- a/types/key/disco.go +++ b/types/key/disco.go @@ -73,6 +73,44 @@ func (k DiscoPrivate) Shared(p DiscoPublic) DiscoShared { return ret } +// SortedPairOfDiscoPublic is a lexicographically sorted container of two +// [DiscoPublic] keys. +type SortedPairOfDiscoPublic struct { + k [2]DiscoPublic +} + +// Get returns the underlying keys. +func (s SortedPairOfDiscoPublic) Get() [2]DiscoPublic { + return s.k +} + +// NewSortedPairOfDiscoPublic returns a SortedPairOfDiscoPublic from a and b. +func NewSortedPairOfDiscoPublic(a, b DiscoPublic) SortedPairOfDiscoPublic { + s := SortedPairOfDiscoPublic{} + if a.Compare(b) < 0 { + s.k[0] = a + s.k[1] = b + } else { + s.k[0] = b + s.k[1] = a + } + return s +} + +func (s SortedPairOfDiscoPublic) String() string { + return fmt.Sprintf("%s <=> %s", s.k[0].ShortString(), s.k[1].ShortString()) +} + +// Equal returns true if s and b are equal, otherwise it returns false. +func (s SortedPairOfDiscoPublic) Equal(b SortedPairOfDiscoPublic) bool { + for i := range s.k { + if s.k[i].Compare(b.k[i]) != 0 { + return false + } + } + return true +} + // DiscoPublic is the public portion of a DiscoPrivate. type DiscoPublic struct { k [32]byte @@ -129,11 +167,11 @@ func (k DiscoPublic) String() string { } // Compare returns an integer comparing DiscoPublic k and l lexicographically. -// The result will be 0 if k == l, -1 if k < l, and +1 if k > l. This is useful -// for situations requiring only one node in a pair to perform some operation, -// e.g. probing UDP path lifetime. -func (k DiscoPublic) Compare(l DiscoPublic) int { - return bytes.Compare(k.k[:], l.k[:]) +// The result will be 0 if k == other, -1 if k < other, and +1 if k > other. +// This is useful for situations requiring only one node in a pair to perform +// some operation, e.g. probing UDP path lifetime. +func (k DiscoPublic) Compare(other DiscoPublic) int { + return bytes.Compare(k.k[:], other.k[:]) } // AppendText implements encoding.TextAppender. diff --git a/types/key/disco_test.go b/types/key/disco_test.go index c62c13cbf..131fe350f 100644 --- a/types/key/disco_test.go +++ b/types/key/disco_test.go @@ -81,3 +81,21 @@ func TestDiscoShared(t *testing.T) { t.Error("k1.Shared(k2) != k2.Shared(k1)") } } + +func TestSortedPairOfDiscoPublic(t *testing.T) { + pubA := DiscoPublic{} + pubA.k[0] = 0x01 + pubB := DiscoPublic{} + pubB.k[0] = 0x02 + sortedInput := NewSortedPairOfDiscoPublic(pubA, pubB) + unsortedInput := NewSortedPairOfDiscoPublic(pubB, pubA) + if sortedInput.Get() != unsortedInput.Get() { + t.Fatal("sortedInput.Get() != unsortedInput.Get()") + } + if unsortedInput.Get()[0] != pubA { + t.Fatal("unsortedInput.Get()[0] != pubA") + } + if unsortedInput.Get()[1] != pubB { + t.Fatal("unsortedInput.Get()[1] != pubB") + } +} diff --git a/types/key/hardware_attestation.go b/types/key/hardware_attestation.go new file mode 100644 index 000000000..9d4a21ee4 --- /dev/null +++ b/types/key/hardware_attestation.go @@ -0,0 +1,181 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package key + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/subtle" + "encoding/json" + "fmt" + "io" + + "go4.org/mem" +) + +var ErrUnsupported = fmt.Errorf("key type not supported on this platform") + +const hardwareAttestPublicHexPrefix = "hwattestpub:" + +const pubkeyLength = 65 // uncompressed P-256 + +// HardwareAttestationKey describes a hardware-backed key that is used to +// identify a node. Implementation details will +// vary based on the platform in use (SecureEnclave for Apple, TPM for +// Windows/Linux, Android Hardware-backed Keystore). +// This key can only be marshalled and unmarshaled on the same machine. +type HardwareAttestationKey interface { + crypto.Signer + json.Marshaler + json.Unmarshaler + io.Closer + Clone() HardwareAttestationKey + IsZero() bool +} + +// HardwareAttestationPublicFromPlatformKey creates a HardwareAttestationPublic +// for communicating the public component of the hardware attestation key +// with control and other nodes. +func HardwareAttestationPublicFromPlatformKey(k HardwareAttestationKey) HardwareAttestationPublic { + if k == nil { + return HardwareAttestationPublic{} + } + pub := k.Public() + ecdsaPub, ok := pub.(*ecdsa.PublicKey) + if !ok { + panic("hardware attestation key is not ECDSA") + } + bytes, err := ecdsaPub.Bytes() + if err != nil { + panic(err) + } + if len(bytes) != pubkeyLength { + panic("hardware attestation key is not uncompressed ECDSA P-256") + } + var ecdsaPubArr [pubkeyLength]byte + copy(ecdsaPubArr[:], bytes) + return HardwareAttestationPublic{k: ecdsaPubArr} +} + +// HardwareAttestationPublic is the public key counterpart to +// HardwareAttestationKey. +type HardwareAttestationPublic struct { + k [pubkeyLength]byte +} + +func (k *HardwareAttestationPublic) Clone() *HardwareAttestationPublic { + if k == nil { + return nil + } + var out HardwareAttestationPublic + copy(out.k[:], k.k[:]) + return &out +} + +func (k HardwareAttestationPublic) Equal(o HardwareAttestationPublic) bool { + return subtle.ConstantTimeCompare(k.k[:], o.k[:]) == 1 +} + +// IsZero reports whether k is the zero value. +func (k HardwareAttestationPublic) IsZero() bool { + var zero [pubkeyLength]byte + return k.k == zero +} + +// String returns the hex-encoded public key with a type prefix. +func (k HardwareAttestationPublic) String() string { + bs, err := k.MarshalText() + if err != nil { + panic(err) + } + return string(bs) +} + +// MarshalText implements encoding.TextMarshaler. +func (k HardwareAttestationPublic) MarshalText() ([]byte, error) { + if k.IsZero() { + return nil, nil + } + return k.AppendText(nil) +} + +// UnmarshalText implements encoding.TextUnmarshaler. It expects a typed prefix +// followed by a hex encoded representation of k. +func (k *HardwareAttestationPublic) UnmarshalText(b []byte) error { + if len(b) == 0 { + *k = HardwareAttestationPublic{} + return nil + } + + kb := make([]byte, pubkeyLength) + if err := parseHex(kb, mem.B(b), mem.S(hardwareAttestPublicHexPrefix)); err != nil { + return err + } + + _, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), kb) + if err != nil { + return err + } + copy(k.k[:], kb) + return nil +} + +func (k HardwareAttestationPublic) AppendText(dst []byte) ([]byte, error) { + return appendHexKey(dst, hardwareAttestPublicHexPrefix, k.k[:]), nil +} + +// Verifier returns the ECDSA public key for verifying signatures made by k. +func (k HardwareAttestationPublic) Verifier() *ecdsa.PublicKey { + pk, err := ecdsa.ParseUncompressedPublicKey(elliptic.P256(), k.k[:]) + if err != nil { + panic(err) + } + return pk +} + +// emptyHardwareAttestationKey is a function that returns an empty +// HardwareAttestationKey suitable for use with JSON unmarshaling. +var emptyHardwareAttestationKey func() HardwareAttestationKey + +// createHardwareAttestationKey is a function that creates a new +// HardwareAttestationKey for the current platform. +var createHardwareAttestationKey func() (HardwareAttestationKey, error) + +// HardwareAttestationKeyFn is a callback function type that returns a HardwareAttestationKey +// and an error. It is used to register platform-specific implementations of +// HardwareAttestationKey. +type HardwareAttestationKeyFn func() (HardwareAttestationKey, error) + +// RegisterHardwareAttestationKeyFns registers a hardware attestation +// key implementation for the current platform. +func RegisterHardwareAttestationKeyFns(emptyFn func() HardwareAttestationKey, createFn HardwareAttestationKeyFn) { + if emptyHardwareAttestationKey != nil { + panic("emptyPlatformHardwareAttestationKey already registered") + } + emptyHardwareAttestationKey = emptyFn + + if createHardwareAttestationKey != nil { + panic("createPlatformHardwareAttestationKey already registered") + } + createHardwareAttestationKey = createFn +} + +// NewEmptyHardwareAttestationKey returns an empty HardwareAttestationKey +// suitable for JSON unmarshaling. +func NewEmptyHardwareAttestationKey() (HardwareAttestationKey, error) { + if emptyHardwareAttestationKey == nil { + return nil, ErrUnsupported + } + return emptyHardwareAttestationKey(), nil +} + +// NewHardwareAttestationKey returns a newly created HardwareAttestationKey for +// the current platform. +func NewHardwareAttestationKey() (HardwareAttestationKey, error) { + if createHardwareAttestationKey == nil { + return nil, ErrUnsupported + } + return createHardwareAttestationKey() +} diff --git a/types/key/nl.go b/types/key/nl.go index e0b4e5ca6..50caed98c 100644 --- a/types/key/nl.go +++ b/types/key/nl.go @@ -131,10 +131,10 @@ func NLPublicFromEd25519Unsafe(public ed25519.PublicKey) NLPublic { // is able to decode both the CLI form (tlpub:) & the // regular form (nlpub:). func (k *NLPublic) UnmarshalText(b []byte) error { - if mem.HasPrefix(mem.B(b), mem.S(nlPublicHexPrefixCLI)) { - return parseHex(k.k[:], mem.B(b), mem.S(nlPublicHexPrefixCLI)) + if mem.HasPrefix(mem.B(b), mem.S(nlPublicHexPrefix)) { + return parseHex(k.k[:], mem.B(b), mem.S(nlPublicHexPrefix)) } - return parseHex(k.k[:], mem.B(b), mem.S(nlPublicHexPrefix)) + return parseHex(k.k[:], mem.B(b), mem.S(nlPublicHexPrefixCLI)) } // AppendText implements encoding.TextAppender. diff --git a/types/key/util.go b/types/key/util.go index bdb2a06f6..50fac8275 100644 --- a/types/key/util.go +++ b/types/key/util.go @@ -10,9 +10,12 @@ import ( "errors" "fmt" "io" + "reflect" "slices" "go4.org/mem" + "tailscale.com/util/set" + "tailscale.com/util/testenv" ) // rand fills b with cryptographically strong random bytes. Panics if @@ -115,3 +118,18 @@ func debug32(k [32]byte) string { dst[6] = ']' return string(dst[:7]) } + +// PrivateTypesForTest returns the set of private key types +// in this package, for testing purposes. +func PrivateTypesForTest() set.Set[reflect.Type] { + testenv.AssertInTest() + return set.Of( + reflect.TypeFor[ChallengePrivate](), + reflect.TypeFor[ControlPrivate](), + reflect.TypeFor[DiscoPrivate](), + reflect.TypeFor[MachinePrivate](), + reflect.TypeFor[NodePrivate](), + reflect.TypeFor[NLPrivate](), + reflect.TypeFor[HardwareAttestationKey](), + ) +} diff --git a/types/lazy/deferred.go b/types/lazy/deferred.go new file mode 100644 index 000000000..973082914 --- /dev/null +++ b/types/lazy/deferred.go @@ -0,0 +1,105 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "sync" + "sync/atomic" + + "tailscale.com/types/ptr" +) + +// DeferredInit allows one or more funcs to be deferred +// until [DeferredInit.Do] is called for the first time. +// +// DeferredInit is safe for concurrent use. +type DeferredInit struct { + DeferredFuncs +} + +// DeferredFuncs allows one or more funcs to be deferred +// until the owner's [DeferredInit.Do] method is called +// for the first time. +// +// DeferredFuncs is safe for concurrent use. The execution +// order of functions deferred by different goroutines is +// unspecified and must not be relied upon. +// However, functions deferred by the same goroutine are +// executed in the same relative order they were deferred. +// Warning: this is the opposite of the behavior of Go's +// defer statement, which executes deferred functions in +// reverse order. +type DeferredFuncs struct { + m sync.Mutex + funcs []func() error + + // err is either: + // * nil, if deferred init has not yet been completed + // * nilErrPtr, if initialization completed successfully + // * non-nil and not nilErrPtr, if there was an error + // + // It is an atomic.Pointer so it can be read without m held. + err atomic.Pointer[error] +} + +// Defer adds a function to be called when [DeferredInit.Do] +// is called for the first time. It returns true on success, +// or false if [DeferredInit.Do] has already been called. +func (d *DeferredFuncs) Defer(f func() error) bool { + d.m.Lock() + defer d.m.Unlock() + if d.err.Load() != nil { + return false + } + d.funcs = append(d.funcs, f) + return true +} + +// MustDefer is like [DeferredFuncs.Defer], but panics +// if [DeferredInit.Do] has already been called. +func (d *DeferredFuncs) MustDefer(f func() error) { + if !d.Defer(f) { + panic("deferred init already completed") + } +} + +// Do calls previously deferred init functions if it is being called +// for the first time on this instance of [DeferredInit]. +// It stops and returns an error if any init function returns an error. +// +// It is safe for concurrent use, and the deferred init is guaranteed +// to have been completed, either successfully or with an error, +// when Do() returns. +func (d *DeferredInit) Do() error { + err := d.err.Load() + if err == nil { + err = d.doSlow() + } + return *err +} + +func (d *DeferredInit) doSlow() (err *error) { + d.m.Lock() + defer d.m.Unlock() + if err := d.err.Load(); err != nil { + return err + } + defer func() { + d.err.Store(err) + d.funcs = nil // do not keep funcs alive after invoking + }() + for _, f := range d.funcs { + if err := f(); err != nil { + return ptr.To(err) + } + } + return nilErrPtr +} + +// Funcs is a shorthand for &d.DeferredFuncs. +// The returned value can safely be passed to external code, +// allowing to defer init funcs without also exposing [DeferredInit.Do]. +func (d *DeferredInit) Funcs() *DeferredFuncs { + return &d.DeferredFuncs +} diff --git a/types/lazy/deferred_test.go b/types/lazy/deferred_test.go new file mode 100644 index 000000000..98cacbfce --- /dev/null +++ b/types/lazy/deferred_test.go @@ -0,0 +1,297 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" +) + +func ExampleDeferredInit() { + // DeferredInit allows both registration and invocation of the + // deferred funcs. It should remain internal to the code that "owns" it. + var di DeferredInit + // Deferred funcs will not be executed until [DeferredInit.Do] is called. + deferred := di.Defer(func() error { + fmt.Println("Internal init") + return nil + }) + // [DeferredInit.Defer] reports whether the function was successfully deferred. + // A func can only fail to defer if [DeferredInit.Do] has already been called. + if deferred { + fmt.Printf("Internal init has been deferred\n\n") + } + + // If necessary, the value returned by [DeferredInit.Funcs] + // can be shared with external code to facilitate deferring + // funcs without allowing it to call [DeferredInit.Do]. + df := di.Funcs() + // If a certain init step must be completed for the program + // to function correctly, and failure to defer it indicates + // a coding error, use [DeferredFuncs.MustDefer] instead of + // [DeferredFuncs.Defer]. It panics if Do() has already been called. + df.MustDefer(func() error { + fmt.Println("External init - 1") + return nil + }) + // A deferred func may return an error to indicate a failed init. + // If a deferred func returns an error, execution stops + // and the error is propagated to the caller. + df.Defer(func() error { + fmt.Println("External init - 2") + return errors.New("bang!") + }) + // The deferred function below won't be executed. + df.Defer(func() error { + fmt.Println("Unreachable") + return nil + }) + + // When [DeferredInit]'s owner needs initialization to be completed, + // it can call [DeferredInit.Do]. When called for the first time, + // it invokes the deferred funcs. + err := di.Do() + if err != nil { + fmt.Printf("Deferred init failed: %v\n", err) + } + // [DeferredInit.Do] is safe for concurrent use and can be called + // multiple times by the same or different goroutines. + // However, the deferred functions are never invoked more than once. + // If the deferred init fails on the first attempt, all subsequent + // [DeferredInit.Do] calls will return the same error. + if err = di.Do(); err != nil { + fmt.Printf("Deferred init failed: %v\n\n", err) + } + + // Additionally, all subsequent attempts to defer a function will fail + // after [DeferredInit.Do] has been called. + deferred = di.Defer(func() error { + fmt.Println("Unreachable") + return nil + }) + if !deferred { + fmt.Println("Cannot defer a func once init has been completed") + } + + // Output: + // Internal init has been deferred + // + // Internal init + // External init - 1 + // External init - 2 + // Deferred init failed: bang! + // Deferred init failed: bang! + // + // Cannot defer a func once init has been completed +} + +func TestDeferredInit(t *testing.T) { + tests := []struct { + name string + numFuncs int + }{ + { + name: "no-funcs", + numFuncs: 0, + }, + { + name: "one-func", + numFuncs: 1, + }, + { + name: "many-funcs", + numFuncs: 1000, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var di DeferredInit + + calls := make([]atomic.Bool, tt.numFuncs) // whether N-th func has been called + checkCalls := func() { + t.Helper() + for i := range calls { + if !calls[i].Load() { + t.Errorf("Func #%d has never been called", i) + } + } + } + + // Defer funcs concurrently across multiple goroutines. + var wg sync.WaitGroup + wg.Add(tt.numFuncs) + for i := range tt.numFuncs { + go func() { + f := func() error { + if calls[i].Swap(true) { + t.Errorf("Func #%d has already been called", i) + } + return nil + } + if !di.Defer(f) { + t.Errorf("Func #%d cannot be deferred", i) + return + } + wg.Done() + }() + } + // Wait for all funcs to be deferred. + wg.Wait() + + // Call [DeferredInit.Do] concurrently. + const N = 10000 + for range N { + wg.Add(1) + go func() { + gotErr := di.Do() + checkError(t, gotErr, nil, false) + checkCalls() + wg.Done() + }() + } + wg.Wait() + }) + } +} + +func TestDeferredErr(t *testing.T) { + tests := []struct { + name string + funcs []func() error + wantErr error + }{ + { + name: "no-funcs", + wantErr: nil, + }, + { + name: "no-error", + funcs: []func() error{func() error { return nil }}, + wantErr: nil, + }, + { + name: "error", + funcs: []func() error{ + func() error { return nil }, + func() error { return errors.New("bang!") }, + func() error { return errors.New("unreachable") }, + }, + wantErr: errors.New("bang!"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var di DeferredInit + for _, f := range tt.funcs { + di.MustDefer(f) + } + + var wg sync.WaitGroup + N := 10000 + for range N { + wg.Add(1) + go func() { + gotErr := di.Do() + checkError(t, gotErr, tt.wantErr, false) + wg.Done() + }() + } + wg.Wait() + }) + } +} + +// TestDeferAfterDo checks all of the following: +// - Deferring a function before [DeferredInit.Do] is called should always succeed. +// - All successfully deferred functions are executed by the time [DeferredInit.Do] completes. +// - No functions can be deferred after [DeferredInit.Do] is called, meaning: +// - [DeferredInit.Defer] should return false. +// - The deferred function should not be executed. +// +// This test is intentionally racy as it attempts to defer functions from multiple goroutines +// and then calls [DeferredInit.Do] without waiting for them to finish. Waiting would alter +// the observable behavior and render the test pointless. +func TestDeferAfterDo(t *testing.T) { + var di DeferredInit + var deferred, called atomic.Int32 + + // deferOnce defers a test function once and fails the test + // if [DeferredInit.Defer] returns true after [DeferredInit.Do] + // has already been called and any deferred functions have been executed. + // It's called concurrently by multiple goroutines. + deferOnce := func() bool { + // canDefer is whether it's acceptable for Defer to return true. + // (but not it necessarily must return true) + // If its func has run before, it's definitely not okay for it to + // accept more Defer funcs. + canDefer := called.Load() == 0 + ok := di.Defer(func() error { + called.Add(1) + return nil + }) + if ok { + if !canDefer { + t.Error("An init function was deferred after DeferredInit.Do() was already called") + } + deferred.Add(1) + } + return ok + } + + // Deferring a func before calling [DeferredInit.Do] should always succeed. + if !deferOnce() { + t.Fatal("Failed to defer a func") + } + + // Defer up to N funcs concurrently while [DeferredInit.Do] is being called by the main goroutine. + // Since we'll likely attempt to defer some funcs after [DeferredInit.Do] has been called, + // we expect these late defers to fail, and the funcs will not be deferred or executed. + // However, the number of the deferred and called funcs should always be equal when [DeferredInit.Do] exits. + const N = 10000 + var wg sync.WaitGroup + for range N { + wg.Add(1) + go func() { + deferOnce() + wg.Done() + }() + } + + if err := di.Do(); err != nil { + t.Fatalf("DeferredInit.Do() failed: %v", err) + } + // The number of called funcs should remain unchanged after [DeferredInit.Do] returns. + wantCalled := called.Load() + + if deferOnce() { + t.Error("An init func was deferred after DeferredInit.Do() returned") + } + + // Wait for the goroutines deferring init funcs to exit. + // No funcs should be called after DeferredInit.Do() has returned, + // and the number of called funcs should be equal to the number of deferred funcs. + wg.Wait() + if gotCalled := called.Load(); gotCalled != wantCalled { + t.Errorf("An init func was called after DeferredInit.Do() returned. Got %d, want %d", gotCalled, wantCalled) + } + if deferred, called := deferred.Load(), called.Load(); deferred != called { + t.Errorf("Deferred: %d; Called: %d", deferred, called) + } +} + +func checkError(tb testing.TB, got, want error, fatal bool) { + tb.Helper() + f := tb.Errorf + if fatal { + f = tb.Fatalf + } + if (want == nil && got != nil) || + (want != nil && got == nil) || + (want != nil && got != nil && want.Error() != got.Error()) { + f("gotErr: %v; wantErr: %v", got, want) + } +} diff --git a/types/lazy/lazy.go b/types/lazy/lazy.go index 43325512d..f537758fa 100644 --- a/types/lazy/lazy.go +++ b/types/lazy/lazy.go @@ -23,6 +23,9 @@ var nilErrPtr = ptr.To[error](nil) // Recursive use of a SyncValue from its own fill function will deadlock. // // SyncValue is safe for concurrent use. +// +// Unlike [sync.OnceValue], the linker can do better dead code elimination +// with SyncValue. See https://github.com/golang/go/issues/62202. type SyncValue[T any] struct { once sync.Once v T @@ -120,44 +123,9 @@ func (z *SyncValue[T]) PeekErr() (v T, err error, ok bool) { return zero, nil, false } -// SyncFunc wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's result on every subsequent call. -// -// The returned function is safe for concurrent use. -func SyncFunc[T any](fill func() T) func() T { - var ( - once sync.Once - v T - ) - return func() T { - once.Do(func() { v = fill() }) - return v - } -} - -// SyncFuncErr wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's results on every subsequent call. -// -// The returned function is safe for concurrent use. -func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) { - var ( - once sync.Once - v T - err error - ) - return func() (T, error) { - once.Do(func() { v, err = fill() }) - return v, err - } -} - -// TB is a subset of testing.TB that we use to set up test helpers. +// testing_TB is a subset of testing.TB that we use to set up test helpers. // It's defined here to avoid pulling in the testing package. -type TB interface { +type testing_TB interface { Helper() Cleanup(func()) } @@ -167,7 +135,9 @@ type TB interface { // subtests complete. // It is not safe for concurrent use and must not be called concurrently with // any SyncValue methods, including another call to itself. -func (z *SyncValue[T]) SetForTest(tb TB, val T, err error) { +// +// The provided tb should be a [*testing.T] or [*testing.B]. +func (z *SyncValue[T]) SetForTest(tb testing_TB, val T, err error) { tb.Helper() oldErr, oldVal := z.err.Load(), z.v diff --git a/types/lazy/map.go b/types/lazy/map.go new file mode 100644 index 000000000..75a1dd739 --- /dev/null +++ b/types/lazy/map.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import "tailscale.com/util/mak" + +// GMap is a map of lazily computed [GValue] pointers, keyed by a comparable +// type. +// +// Use either Get or GetErr, depending on whether your fill function returns an +// error. +// +// GMap is not safe for concurrent use. +type GMap[K comparable, V any] struct { + store map[K]*GValue[V] +} + +// Len returns the number of entries in the map. +func (s *GMap[K, V]) Len() int { + return len(s.store) +} + +// Set attempts to set the value of k to v, and reports whether it succeeded. +// Set only succeeds if k has never been called with Get/GetErr/Set before. +func (s *GMap[K, V]) Set(k K, v V) bool { + z, ok := s.store[k] + if !ok { + z = new(GValue[V]) + mak.Set(&s.store, k, z) + } + return z.Set(v) +} + +// MustSet sets the value of k to v, or panics if k already has a value. +func (s *GMap[K, V]) MustSet(k K, v V) { + if !s.Set(k, v) { + panic("Set after already filled") + } +} + +// Get returns the value for k, computing it with fill if it's not already +// present. +func (s *GMap[K, V]) Get(k K, fill func() V) V { + z, ok := s.store[k] + if !ok { + z = new(GValue[V]) + mak.Set(&s.store, k, z) + } + return z.Get(fill) +} + +// GetErr returns the value for k, computing it with fill if it's not already +// present. +func (s *GMap[K, V]) GetErr(k K, fill func() (V, error)) (V, error) { + z, ok := s.store[k] + if !ok { + z = new(GValue[V]) + mak.Set(&s.store, k, z) + } + return z.GetErr(fill) +} diff --git a/types/lazy/map_test.go b/types/lazy/map_test.go new file mode 100644 index 000000000..ec1152b0b --- /dev/null +++ b/types/lazy/map_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lazy + +import ( + "errors" + "testing" +) + +func TestGMap(t *testing.T) { + var gm GMap[string, int] + n := int(testing.AllocsPerRun(1000, func() { + got := gm.Get("42", fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGMapErr(t *testing.T) { + var gm GMap[string, int] + n := int(testing.AllocsPerRun(1000, func() { + got, err := gm.GetErr("42", func() (int, error) { + return 42, nil + }) + if got != 42 || err != nil { + t.Fatalf("got %v, %v; want 42, nil", got, err) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } + + var gmErr GMap[string, int] + wantErr := errors.New("test error") + n = int(testing.AllocsPerRun(1000, func() { + got, err := gmErr.GetErr("42", func() (int, error) { + return 0, wantErr + }) + if got != 0 || err != wantErr { + t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGMapSet(t *testing.T) { + var gm GMap[string, int] + if !gm.Set("42", 42) { + t.Fatalf("Set failed") + } + if gm.Set("42", 43) { + t.Fatalf("Set succeeded after first Set") + } + n := int(testing.AllocsPerRun(1000, func() { + got := gm.Get("42", fortyTwo) + if got != 42 { + t.Fatalf("got %v; want 42", got) + } + })) + if n != 0 { + t.Errorf("allocs = %v; want 0", n) + } +} + +func TestGMapMustSet(t *testing.T) { + var gm GMap[string, int] + gm.MustSet("42", 42) + defer func() { + if e := recover(); e == nil { + t.Errorf("unexpected success; want panic") + } + }() + gm.MustSet("42", 43) +} + +func TestGMapRecursivePanic(t *testing.T) { + defer func() { + if e := recover(); e != nil { + t.Logf("got panic, as expected") + } else { + t.Errorf("unexpected success; want panic") + } + }() + gm := GMap[string, int]{} + gm.Get("42", func() int { + return gm.Get("42", func() int { return 42 }) + }) +} diff --git a/types/lazy/sync_test.go b/types/lazy/sync_test.go index 5578eee0c..4d1278253 100644 --- a/types/lazy/sync_test.go +++ b/types/lazy/sync_test.go @@ -354,46 +354,3 @@ func TestSyncValueSetForTest(t *testing.T) { }) } } - -func TestSyncFunc(t *testing.T) { - f := SyncFunc(fortyTwo) - - n := int(testing.AllocsPerRun(1000, func() { - got := f() - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestSyncFuncErr(t *testing.T) { - f := SyncFuncErr(func() (int, error) { - return 42, nil - }) - n := int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - wantErr := errors.New("test error") - f = SyncFuncErr(func() (int, error) { - return 0, wantErr - }) - n = int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} diff --git a/types/logger/logger.go b/types/logger/logger.go index 11596b357..6c4edf633 100644 --- a/types/logger/logger.go +++ b/types/logger/logger.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "log" + "runtime" "strings" "sync" "time" @@ -23,6 +24,7 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/util/ctxkey" + "tailscale.com/util/testenv" ) // Logf is the basic Tailscale logger type: a printf-like func. @@ -162,6 +164,10 @@ func RateLimitedFnWithClock(logf Logf, f time.Duration, burst int, maxCache int, if envknob.String("TS_DEBUG_LOG_RATE") == "all" { return logf } + if runtime.GOOS == "plan9" { + // To ease bring-up. + return logf + } var ( mu sync.Mutex msgLim = make(map[string]*limitData) // keyed by logf format @@ -317,6 +323,7 @@ func (fn ArgWriter) Format(f fmt.State, _ rune) { bw.Reset(f) fn(bw) bw.Flush() + bw.Reset(io.Discard) argBufioPool.Put(bw) } @@ -379,16 +386,10 @@ func (a asJSONResult) Format(s fmt.State, verb rune) { s.Write(v) } -// TBLogger is the testing.TB subset needed by TestLogger. -type TBLogger interface { - Helper() - Logf(format string, args ...any) -} - // TestLogger returns a logger that logs to tb.Logf // with a prefix to make it easier to distinguish spam // from explicit test failures. -func TestLogger(tb TBLogger) Logf { +func TestLogger(tb testenv.TB) Logf { return func(format string, args ...any) { tb.Helper() tb.Logf(" ... "+format, args...) diff --git a/types/mapx/ordered.go b/types/mapx/ordered.go new file mode 100644 index 000000000..1991f039d --- /dev/null +++ b/types/mapx/ordered.go @@ -0,0 +1,111 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mapx contains extra map types and functions. +package mapx + +import ( + "iter" + "slices" +) + +// OrderedMap is a map that maintains the order of its keys. +// +// It is meant for maps that only grow or that are small; +// is it not optimized for deleting keys. +// +// The zero value is ready to use. +// +// Locking-wise, it has the same rules as a regular Go map: +// concurrent reads are safe, but not writes. +type OrderedMap[K comparable, V any] struct { + // m is the underlying map. + m map[K]V + + // keys is the order of keys in the map. + keys []K +} + +func (m *OrderedMap[K, V]) init() { + if m.m == nil { + m.m = make(map[K]V) + } +} + +// Set sets the value for the given key in the map. +// +// If the key already exists, it updates the value and keeps the order. +func (m *OrderedMap[K, V]) Set(key K, value V) { + m.init() + len0 := len(m.keys) + m.m[key] = value + if len(m.m) > len0 { + // New key (not an update) + m.keys = append(m.keys, key) + } +} + +// Get returns the value for the given key in the map. +// If the key does not exist, it returns the zero value for V. +func (m *OrderedMap[K, V]) Get(key K) V { + return m.m[key] +} + +// GetOk returns the value for the given key in the map +// and whether it was present in the map. +func (m *OrderedMap[K, V]) GetOk(key K) (_ V, ok bool) { + v, ok := m.m[key] + return v, ok +} + +// Contains reports whether the map contains the given key. +func (m *OrderedMap[K, V]) Contains(key K) bool { + _, ok := m.m[key] + return ok +} + +// Delete removes the key from the map. +// +// The cost is O(n) in the number of keys in the map. +func (m *OrderedMap[K, V]) Delete(key K) { + len0 := len(m.m) + delete(m.m, key) + if len(m.m) == len0 { + // Wasn't present; no need to adjust keys. + return + } + was := m.keys + m.keys = m.keys[:0] + for _, k := range was { + if k != key { + m.keys = append(m.keys, k) + } + } +} + +// All yields all the keys and values, in the order they were inserted. +func (m *OrderedMap[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + for _, k := range m.keys { + if !yield(k, m.m[k]) { + return + } + } + } +} + +// Keys yields the map keys, in the order they were inserted. +func (m *OrderedMap[K, V]) Keys() iter.Seq[K] { + return slices.Values(m.keys) +} + +// Values yields the map values, in the order they were inserted. +func (m *OrderedMap[K, V]) Values() iter.Seq[V] { + return func(yield func(V) bool) { + for _, k := range m.keys { + if !yield(m.m[k]) { + return + } + } + } +} diff --git a/types/mapx/ordered_test.go b/types/mapx/ordered_test.go new file mode 100644 index 000000000..7dcb7e405 --- /dev/null +++ b/types/mapx/ordered_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package mapx + +import ( + "fmt" + "slices" + "testing" +) + +func TestOrderedMap(t *testing.T) { + // Test the OrderedMap type and its methods. + var m OrderedMap[string, int] + m.Set("d", 4) + m.Set("a", 1) + m.Set("b", 1) + m.Set("b", 2) + m.Set("c", 3) + m.Delete("d") + m.Delete("e") + + want := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 0, + } + for k, v := range want { + if m.Get(k) != v { + t.Errorf("Get(%q) = %d, want %d", k, m.Get(k), v) + continue + } + got, ok := m.GetOk(k) + if got != v { + t.Errorf("GetOk(%q) = %d, want %d", k, got, v) + } + if ok != m.Contains(k) { + t.Errorf("GetOk and Contains don't agree for %q", k) + } + } + + if got, want := slices.Collect(m.Keys()), []string{"a", "b", "c"}; !slices.Equal(got, want) { + t.Errorf("Keys() = %q, want %q", got, want) + } + if got, want := slices.Collect(m.Values()), []int{1, 2, 3}; !slices.Equal(got, want) { + t.Errorf("Values() = %v, want %v", got, want) + } + var allGot []string + for k, v := range m.All() { + allGot = append(allGot, fmt.Sprintf("%s:%d", k, v)) + } + if got, want := allGot, []string{"a:1", "b:2", "c:3"}; !slices.Equal(got, want) { + t.Errorf("All() = %q, want %q", got, want) + } +} diff --git a/types/netlogfunc/netlogfunc.go b/types/netlogfunc/netlogfunc.go new file mode 100644 index 000000000..6185fcb71 --- /dev/null +++ b/types/netlogfunc/netlogfunc.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netlogfunc defines types for network logging. +package netlogfunc + +import ( + "net/netip" + + "tailscale.com/types/ipproto" +) + +// ConnectionCounter is a function for counting packets and bytes +// for a particular connection. +type ConnectionCounter func(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, recv bool) diff --git a/types/netlogtype/netlogtype.go b/types/netlogtype/netlogtype.go index f2fa2bda9..cc38684a3 100644 --- a/types/netlogtype/netlogtype.go +++ b/types/netlogtype/netlogtype.go @@ -5,38 +5,44 @@ package netlogtype import ( + "maps" "net/netip" + "sync" "time" "tailscale.com/tailcfg" "tailscale.com/types/ipproto" ) -// TODO(joetsai): Remove "omitempty" if "omitzero" is ever supported in both -// the v1 and v2 "json" packages. - // Message is the log message that captures network traffic. type Message struct { - NodeID tailcfg.StableNodeID `json:"nodeId" cbor:"0,keyasint"` // e.g., "n123456CNTRL" + NodeID tailcfg.StableNodeID `json:"nodeId"` // e.g., "n123456CNTRL" + + Start time.Time `json:"start"` // inclusive + End time.Time `json:"end"` // inclusive - Start time.Time `json:"start" cbor:"12,keyasint"` // inclusive - End time.Time `json:"end" cbor:"13,keyasint"` // inclusive + SrcNode Node `json:"srcNode,omitzero"` + DstNodes []Node `json:"dstNodes,omitempty"` - VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty" cbor:"14,keyasint,omitempty"` - SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty" cbor:"15,keyasint,omitempty"` - ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty" cbor:"16,keyasint,omitempty"` - PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty" cbor:"17,keyasint,omitempty"` + VirtualTraffic []ConnectionCounts `json:"virtualTraffic,omitempty"` + SubnetTraffic []ConnectionCounts `json:"subnetTraffic,omitempty"` + ExitTraffic []ConnectionCounts `json:"exitTraffic,omitempty"` + PhysicalTraffic []ConnectionCounts `json:"physicalTraffic,omitempty"` } const ( - messageJSON = `{"nodeId":"n0123456789abcdefCNTRL",` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + messageJSON = `{"nodeId":` + maxJSONStableID + `,` + minJSONNodes + `,` + maxJSONTimeRange + `,` + minJSONTraffic + `}` + maxJSONStableID = `"n0123456789abcdefCNTRL"` + minJSONNodes = `"srcNode":{},"dstNodes":[]` maxJSONTimeRange = `"start":` + maxJSONRFC3339 + `,"end":` + maxJSONRFC3339 maxJSONRFC3339 = `"0001-01-01T00:00:00.000000000Z"` minJSONTraffic = `"virtualTraffic":{},"subnetTraffic":{},"exitTraffic":{},"physicalTraffic":{}` - // MaxMessageJSONSize is the overhead size of Message when it is - // serialized as JSON assuming that each traffic map is populated. - MaxMessageJSONSize = len(messageJSON) + // MinMessageJSONSize is the overhead size of Message when it is + // serialized as JSON assuming that each field is minimally populated. + // Each [Node] occupies at least [MinNodeJSONSize]. + // Each [ConnectionCounts] occupies at most [MaxConnectionCountsJSONSize]. + MinMessageJSONSize = len(messageJSON) maxJSONConnCounts = `{` + maxJSONConn + `,` + maxJSONCounts + `}` maxJSONConn = `"proto":` + maxJSONProto + `,"src":` + maxJSONAddrPort + `,"dst":` + maxJSONAddrPort @@ -51,19 +57,30 @@ const ( // this object is nested within an array. // It assumes that netip.Addr never has IPv6 zones. MaxConnectionCountsJSONSize = len(maxJSONConnCounts) +) - maxCBORConnCounts = "\xbf" + maxCBORConn + maxCBORCounts + "\xff" - maxCBORConn = "\x00" + maxCBORProto + "\x01" + maxCBORAddrPort + "\x02" + maxCBORAddrPort - maxCBORProto = "\x18\xff" - maxCBORAddrPort = "\x52\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff" - maxCBORCounts = "\x0c" + maxCBORCount + "\x0d" + maxCBORCount + "\x0e" + maxCBORCount + "\x0f" + maxCBORCount - maxCBORCount = "\x1b\xff\xff\xff\xff\xff\xff\xff\xff" +// Node is information about a node. +type Node struct { + // NodeID is the stable ID of the node. + NodeID tailcfg.StableNodeID `json:"nodeId"` - // MaxConnectionCountsCBORSize is the maximum size of a ConnectionCounts - // when it is serialized as CBOR. - // It assumes that netip.Addr never has IPv6 zones. - MaxConnectionCountsCBORSize = len(maxCBORConnCounts) -) + // Name is the fully-qualified name of the node. + Name string `json:"name,omitzero"` // e.g., "carbonite.example.ts.net" + + // Addresses are the Tailscale IP addresses of the node. + Addresses []netip.Addr `json:"addresses,omitempty"` + + // OS is the operating system of the node. + OS string `json:"os,omitzero"` // e.g., "linux" + + // User is the user that owns the node. + // It is not populated if the node is tagged. + User string `json:"user,omitzero"` // e.g., "johndoe@example.com" + + // Tags are the tags of the node. + // It is not populated if the node is owned by a user. + Tags []string `json:"tags,omitempty"` // e.g., ["tag:prod","tag:logs"] +} // ConnectionCounts is a flattened struct of both a connection and counts. type ConnectionCounts struct { @@ -73,19 +90,19 @@ type ConnectionCounts struct { // Connection is a 5-tuple of proto, source and destination IP and port. type Connection struct { - Proto ipproto.Proto `json:"proto,omitzero,omitempty" cbor:"0,keyasint,omitempty"` - Src netip.AddrPort `json:"src,omitzero,omitempty" cbor:"1,keyasint,omitempty"` - Dst netip.AddrPort `json:"dst,omitzero,omitempty" cbor:"2,keyasint,omitempty"` + Proto ipproto.Proto `json:"proto,omitzero"` + Src netip.AddrPort `json:"src,omitzero"` + Dst netip.AddrPort `json:"dst,omitzero"` } func (c Connection) IsZero() bool { return c == Connection{} } // Counts are statistics about a particular connection. type Counts struct { - TxPackets uint64 `json:"txPkts,omitzero,omitempty" cbor:"12,keyasint,omitempty"` - TxBytes uint64 `json:"txBytes,omitzero,omitempty" cbor:"13,keyasint,omitempty"` - RxPackets uint64 `json:"rxPkts,omitzero,omitempty" cbor:"14,keyasint,omitempty"` - RxBytes uint64 `json:"rxBytes,omitzero,omitempty" cbor:"15,keyasint,omitempty"` + TxPackets uint64 `json:"txPkts,omitzero"` + TxBytes uint64 `json:"txBytes,omitzero"` + RxPackets uint64 `json:"rxPkts,omitzero"` + RxBytes uint64 `json:"rxBytes,omitzero"` } func (c Counts) IsZero() bool { return c == Counts{} } @@ -98,3 +115,43 @@ func (c1 Counts) Add(c2 Counts) Counts { c1.RxBytes += c2.RxBytes return c1 } + +// CountsByConnection is a count of packets and bytes for each connection. +// All methods are safe for concurrent calls. +type CountsByConnection struct { + mu sync.Mutex + m map[Connection]Counts +} + +// Add adds packets and bytes for the specified connection. +func (c *CountsByConnection) Add(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, recv bool) { + conn := Connection{Proto: proto, Src: src, Dst: dst} + c.mu.Lock() + defer c.mu.Unlock() + if c.m == nil { + c.m = make(map[Connection]Counts) + } + cnts := c.m[conn] + if recv { + cnts.RxPackets += uint64(packets) + cnts.RxBytes += uint64(bytes) + } else { + cnts.TxPackets += uint64(packets) + cnts.TxBytes += uint64(bytes) + } + c.m[conn] = cnts +} + +// Clone deep copies the map. +func (c *CountsByConnection) Clone() map[Connection]Counts { + c.mu.Lock() + defer c.mu.Unlock() + return maps.Clone(c.m) +} + +// Reset clear the map. +func (c *CountsByConnection) Reset() { + c.mu.Lock() + defer c.mu.Unlock() + clear(c.m) +} diff --git a/types/netlogtype/netlogtype_test.go b/types/netlogtype/netlogtype_test.go index 7f29090c5..00f89b228 100644 --- a/types/netlogtype/netlogtype_test.go +++ b/types/netlogtype/netlogtype_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_tailnetlock + package netlogtype import ( @@ -9,7 +11,6 @@ import ( "net/netip" "testing" - "github.com/fxamacker/cbor/v2" "github.com/google/go-cmp/cmp" "tailscale.com/util/must" ) @@ -30,10 +31,4 @@ func TestMaxSize(t *testing.T) { if string(outJSON) != maxJSONConnCounts { t.Errorf("JSON mismatch (-got +want):\n%s", cmp.Diff(string(outJSON), maxJSONConnCounts)) } - - outCBOR := must.Get(cbor.Marshal(cc)) - maxCBORConnCountsAlt := "\xa7" + maxCBORConnCounts[1:len(maxCBORConnCounts)-1] // may use a definite encoding of map - if string(outCBOR) != maxCBORConnCounts && string(outCBOR) != maxCBORConnCountsAlt { - t.Errorf("CBOR mismatch (-got +want):\n%s", cmp.Diff(string(outCBOR), maxCBORConnCounts)) - } } diff --git a/types/netmap/netmap.go b/types/netmap/netmap.go index 5e0622922..c54562f4d 100644 --- a/types/netmap/netmap.go +++ b/types/netmap/netmap.go @@ -26,14 +26,9 @@ import ( // The fields should all be considered read-only. They might // alias parts of previous NetworkMap values. type NetworkMap struct { - SelfNode tailcfg.NodeView - AllCaps set.Set[tailcfg.NodeCapability] // set version of SelfNode.Capabilities + SelfNode.CapMap - NodeKey key.NodePublic - PrivateKey key.NodePrivate - Expiry time.Time - // Name is the DNS name assigned to this node. - // It is the MapResponse.Node.Name value and ends with a period. - Name string + SelfNode tailcfg.NodeView + AllCaps set.Set[tailcfg.NodeCapability] // set version of SelfNode.Capabilities + SelfNode.CapMap + NodeKey key.NodePublic MachineKey key.MachinePublic @@ -54,12 +49,12 @@ type NetworkMap struct { // between updates and should not be modified. DERPMap *tailcfg.DERPMap - // ControlHealth are the list of health check problems for this + // DisplayMessages are the list of health check problems for this // node from the perspective of the control plane. // If empty, there are no known problems from the control plane's // point of view, but the node might know about its own health // check problems. - ControlHealth []string + DisplayMessages map[tailcfg.DisplayMessageID]tailcfg.DisplayMessage // TKAEnabled indicates whether the tailnet key authority should be // enabled, from the perspective of the control plane. @@ -76,10 +71,9 @@ type NetworkMap struct { // If this is empty, then data-plane audit logging is disabled. DomainAuditLogID string - UserProfiles map[tailcfg.UserID]tailcfg.UserProfile - - // MaxKeyDuration describes the MaxKeyDuration setting for the tailnet. - MaxKeyDuration time.Duration + // UserProfiles contains the profile information of UserIDs referenced + // in SelfNode and Peers. + UserProfiles map[tailcfg.UserID]tailcfg.UserProfileView } // User returns nm.SelfNode.User if nm.SelfNode is non-nil, otherwise it returns @@ -101,6 +95,62 @@ func (nm *NetworkMap) GetAddresses() views.Slice[netip.Prefix] { return nm.SelfNode.Addresses() } +// GetVIPServiceIPMap returns a map of service names to the slice of +// VIP addresses that correspond to the service. The service names are +// with the prefix "svc:". +// +// TODO(tailscale/corp##25997): cache the result of decoding the capmap so that +// we don't have to decode it multiple times after each netmap update. +func (nm *NetworkMap) GetVIPServiceIPMap() tailcfg.ServiceIPMappings { + if nm == nil { + return nil + } + if !nm.SelfNode.Valid() { + return nil + } + + ipMaps, err := tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceIPMappings](nm.SelfNode.CapMap(), tailcfg.NodeAttrServiceHost) + if len(ipMaps) != 1 || err != nil { + return nil + } + + return ipMaps[0] +} + +// GetIPVIPServiceMap returns a map of VIP addresses to the service +// names that has the VIP address. The service names are with the +// prefix "svc:". +func (nm *NetworkMap) GetIPVIPServiceMap() IPServiceMappings { + var res IPServiceMappings + if nm == nil { + return res + } + + if !nm.SelfNode.Valid() { + return res + } + + serviceIPMap := nm.GetVIPServiceIPMap() + if serviceIPMap == nil { + return res + } + res = make(IPServiceMappings) + for svc, addrs := range serviceIPMap { + for _, addr := range addrs { + res[addr] = svc + } + } + return res +} + +// SelfNodeOrZero returns the self node, or a zero value if nm is nil. +func (nm *NetworkMap) SelfNodeOrZero() tailcfg.NodeView { + if nm == nil { + return tailcfg.NodeView{} + } + return nm.SelfNode +} + // AnyPeersAdvertiseRoutes reports whether any peer is advertising non-exit node routes. func (nm *NetworkMap) AnyPeersAdvertiseRoutes() bool { for _, p := range nm.Peers { @@ -181,10 +231,25 @@ func MagicDNSSuffixOfNodeName(nodeName string) string { // // It will neither start nor end with a period. func (nm *NetworkMap) MagicDNSSuffix() string { - if nm == nil { + return MagicDNSSuffixOfNodeName(nm.SelfName()) +} + +// SelfName returns nm.SelfNode.Name, or the empty string +// if nm is nil or nm.SelfNode is invalid. +func (nm *NetworkMap) SelfName() string { + if nm == nil || !nm.SelfNode.Valid() { return "" } - return MagicDNSSuffixOfNodeName(nm.Name) + return nm.SelfNode.Name() +} + +// SelfKeyExpiry returns nm.SelfNode.KeyExpiry, or the zero +// value if nil or nm.SelfNode is invalid. +func (nm *NetworkMap) SelfKeyExpiry() time.Time { + if nm == nil || !nm.SelfNode.Valid() { + return time.Time{} + } + return nm.SelfNode.KeyExpiry() } // DomainName returns the name of the NetworkMap's @@ -197,21 +262,27 @@ func (nm *NetworkMap) DomainName() string { return nm.Domain } -// SelfCapabilities returns SelfNode.Capabilities if nm and nm.SelfNode are -// non-nil. This is a method so we can use it in envknob/logknob without a -// circular dependency. -func (nm *NetworkMap) SelfCapabilities() views.Slice[tailcfg.NodeCapability] { - var zero views.Slice[tailcfg.NodeCapability] +// TailnetDisplayName returns the admin-editable name contained in +// NodeAttrTailnetDisplayName. If the capability is not present it +// returns an empty string. +func (nm *NetworkMap) TailnetDisplayName() string { if nm == nil || !nm.SelfNode.Valid() { - return zero + return "" } - out := nm.SelfNode.Capabilities().AsSlice() - nm.SelfNode.CapMap().Range(func(k tailcfg.NodeCapability, _ views.Slice[tailcfg.RawMessage]) (cont bool) { - out = append(out, k) - return true - }) - return views.SliceOf(out) + tailnetDisplayNames, err := tailcfg.UnmarshalNodeCapViewJSON[string](nm.SelfNode.CapMap(), tailcfg.NodeAttrTailnetDisplayName) + if err != nil || len(tailnetDisplayNames) == 0 { + return "" + } + + return tailnetDisplayNames[0] +} + +// HasSelfCapability reports whether nm.SelfNode contains capability c. +// +// It exists to satisify an unused (as of 2025-01-04) interface in the logknob package. +func (nm *NetworkMap) HasSelfCapability(c tailcfg.NodeCapability) bool { + return nm.AllCaps.Contains(c) } func (nm *NetworkMap) String() string { @@ -251,7 +322,12 @@ func (nm *NetworkMap) PeerWithStableID(pid tailcfg.StableNodeID) (_ tailcfg.Node func (nm *NetworkMap) printConciseHeader(buf *strings.Builder) { fmt.Fprintf(buf, "netmap: self: %v auth=%v", nm.NodeKey.ShortString(), nm.GetMachineStatus()) - login := nm.UserProfiles[nm.User()].LoginName + + var login string + up, ok := nm.UserProfiles[nm.User()] + if ok { + login = up.LoginName() + } if login == "" { if nm.User().IsZero() { login = "?" @@ -279,15 +355,14 @@ func (a *NetworkMap) equalConciseHeader(b *NetworkMap) bool { // in nodeConciseEqual in sync. func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { aip := make([]string, p.AllowedIPs().Len()) - for i := range aip { - a := p.AllowedIPs().At(i) - s := strings.TrimSuffix(fmt.Sprint(a), "/32") + for i, a := range p.AllowedIPs().All() { + s := strings.TrimSuffix(a.String(), "/32") aip[i] = s } - ep := make([]string, p.Endpoints().Len()) - for i := range ep { - e := p.Endpoints().At(i).String() + epStrs := make([]string, p.Endpoints().Len()) + for i, ep := range p.Endpoints().All() { + e := ep.String() // Align vertically on the ':' between IP and port colon := strings.IndexByte(e, ':') spaces := 0 @@ -295,14 +370,11 @@ func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { spaces++ colon-- } - ep[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) + epStrs[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) } - derp := p.DERP() - const derpPrefix = "127.3.3.40:" - if strings.HasPrefix(derp, derpPrefix) { - derp = "D" + derp[len(derpPrefix):] - } + derp := fmt.Sprintf("D%d", p.HomeDERP()) + var discoShort string if !p.DiscoKey().IsZero() { discoShort = p.DiscoKey().ShortString() + " " @@ -316,13 +388,13 @@ func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { discoShort, derp, strings.Join(aip, " "), - strings.Join(ep, " ")) + strings.Join(epStrs, " ")) } // nodeConciseEqual reports whether a and b are equal for the fields accessed by printPeerConcise. func nodeConciseEqual(a, b tailcfg.NodeView) bool { return a.Key() == b.Key() && - a.DERP() == b.DERP() && + a.HomeDERP() == b.HomeDERP() && a.DiscoKey() == b.DiscoKey() && views.SliceEqual(a.AllowedIPs(), b.AllowedIPs()) && views.SliceEqual(a.Endpoints(), b.Endpoints()) @@ -391,3 +463,19 @@ const ( _ WGConfigFlags = 1 << iota AllowSubnetRoutes ) + +// IPServiceMappings maps IP addresses to service names. This is the inverse of +// [tailcfg.ServiceIPMappings], and is used to inform track which service a VIP +// is associated with. This is set to b.ipVIPServiceMap every time the netmap is +// updated. This is used to reduce the cost for looking up the service name for +// the dst IP address in the netStack packet processing workflow. +// +// This is of the form: +// +// { +// "100.65.32.1": "svc:samba", +// "fd7a:115c:a1e0::1234": "svc:samba", +// "100.102.42.3": "svc:web", +// "fd7a:115c:a1e0::abcd": "svc:web", +// } +type IPServiceMappings map[netip.Addr]tailcfg.ServiceName diff --git a/types/netmap/netmap_test.go b/types/netmap/netmap_test.go index e7e2d1957..ee4fecdb4 100644 --- a/types/netmap/netmap_test.go +++ b/types/netmap/netmap_test.go @@ -6,11 +6,13 @@ package netmap import ( "encoding/hex" "net/netip" + "reflect" "testing" "go4.org/mem" "tailscale.com/net/netaddr" "tailscale.com/tailcfg" + "tailscale.com/tstest/typewalk" "tailscale.com/types/key" ) @@ -63,12 +65,12 @@ func TestNetworkMapConcise(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { Key: testNodeKey(3), - DERP: "127.3.3.40:4", + HomeDERP: 4, Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), }, }), @@ -102,7 +104,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -112,7 +114,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -126,7 +128,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -136,7 +138,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -151,7 +153,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -162,19 +164,19 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 1, Key: testNodeKey(1), - DERP: "127.3.3.40:1", + HomeDERP: 1, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 3, Key: testNodeKey(3), - DERP: "127.3.3.40:3", + HomeDERP: 3, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -189,19 +191,19 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 1, Key: testNodeKey(1), - DERP: "127.3.3.40:1", + HomeDERP: 1, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 3, Key: testNodeKey(3), - DERP: "127.3.3.40:3", + HomeDERP: 3, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -212,7 +214,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -227,7 +229,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), }, }), @@ -238,7 +240,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), }, }), @@ -253,7 +255,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), DiscoKey: testDiscoKey("f00f00f00f"), AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, @@ -266,7 +268,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), DiscoKey: testDiscoKey("ba4ba4ba4b"), AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, @@ -316,3 +318,10 @@ func TestPeerIndexByNodeID(t *testing.T) { } } } + +func TestNoPrivateKeyMaterial(t *testing.T) { + private := key.PrivateTypesForTest() + for path := range typewalk.MatchingPaths(reflect.TypeFor[NetworkMap](), private.Contains) { + t.Errorf("NetworkMap contains private key material at path: %q", path.Name) + } +} diff --git a/types/netmap/nodemut.go b/types/netmap/nodemut.go index 46fbaefc6..f4de1bf0b 100644 --- a/types/netmap/nodemut.go +++ b/types/netmap/nodemut.go @@ -5,7 +5,6 @@ package netmap import ( "cmp" - "fmt" "net/netip" "reflect" "slices" @@ -35,10 +34,10 @@ type NodeMutationDERPHome struct { } func (m NodeMutationDERPHome) Apply(n *tailcfg.Node) { - n.DERP = fmt.Sprintf("127.3.3.40:%v", m.DERPRegion) + n.HomeDERP = m.DERPRegion } -// NodeMutation is a NodeMutation that says a node's endpoints have changed. +// NodeMutationEndpoints is a NodeMutation that says a node's endpoints have changed. type NodeMutationEndpoints struct { mutatingNodeID Endpoints []netip.AddrPort @@ -164,6 +163,7 @@ func mapResponseContainsNonPatchFields(res *tailcfg.MapResponse) bool { res.PacketFilters != nil || res.UserProfiles != nil || res.Health != nil || + res.DisplayMessages != nil || res.SSHPolicy != nil || res.TKAInfo != nil || res.DomainDataPlaneAuditLogID != "" || @@ -177,6 +177,5 @@ func mapResponseContainsNonPatchFields(res *tailcfg.MapResponse) bool { // function is called, so it should never be set anyway. But for // completedness, and for tests, check it too: res.PeersChanged != nil || - res.DefaultAutoUpdate != "" || - res.MaxKeyDuration > 0 + res.DefaultAutoUpdate != "" } diff --git a/types/opt/bool.go b/types/opt/bool.go index 0a3ee67ad..fbc39e1dc 100644 --- a/types/opt/bool.go +++ b/types/opt/bool.go @@ -18,6 +18,22 @@ import ( // field without it being dropped. type Bool string +const ( + // True is the encoding of an explicit true. + True = Bool("true") + + // False is the encoding of an explicit false. + False = Bool("false") + + // ExplicitlyUnset is the encoding used by a null + // JSON value. It is a synonym for the empty string. + ExplicitlyUnset = Bool("unset") + + // Empty means the Bool is unset and it's neither + // true nor false. + Empty = Bool("") +) + // NewBool constructs a new Bool value equal to b. The returned Bool is set, // unless Set("") or Clear() methods are called. func NewBool(b bool) Bool { @@ -50,16 +66,16 @@ func (b *Bool) Scan(src any) error { switch src := src.(type) { case bool: if src { - *b = "true" + *b = True } else { - *b = "false" + *b = False } return nil case int64: if src == 0 { - *b = "false" + *b = False } else { - *b = "true" + *b = True } return nil default: @@ -67,6 +83,17 @@ func (b *Bool) Scan(src any) error { } } +// Normalized returns the normalized form of b, mapping "unset" to "" +// and leaving other values unchanged. +func (b Bool) Normalized() Bool { + switch b { + case ExplicitlyUnset: + return Empty + default: + return b + } +} + // EqualBool reports whether b is equal to v. // If b is empty or not a valid bool, it reports false. func (b Bool) EqualBool(v bool) bool { @@ -75,18 +102,18 @@ func (b Bool) EqualBool(v bool) bool { } var ( - trueBytes = []byte("true") - falseBytes = []byte("false") + trueBytes = []byte(True) + falseBytes = []byte(False) nullBytes = []byte("null") ) func (b Bool) MarshalJSON() ([]byte, error) { switch b { - case "true": + case True: return trueBytes, nil - case "false": + case False: return falseBytes, nil - case "", "unset": + case Empty, ExplicitlyUnset: return nullBytes, nil } return nil, fmt.Errorf("invalid opt.Bool value %q", string(b)) @@ -95,11 +122,11 @@ func (b Bool) MarshalJSON() ([]byte, error) { func (b *Bool) UnmarshalJSON(j []byte) error { switch string(j) { case "true": - *b = "true" + *b = True case "false": - *b = "false" + *b = False case "null": - *b = "unset" + *b = ExplicitlyUnset default: return fmt.Errorf("invalid opt.Bool value %q", j) } diff --git a/types/opt/bool_test.go b/types/opt/bool_test.go index dddbcfc19..e61d66dbe 100644 --- a/types/opt/bool_test.go +++ b/types/opt/bool_test.go @@ -106,6 +106,8 @@ func TestBoolEqualBool(t *testing.T) { }{ {"", true, false}, {"", false, false}, + {"unset", true, false}, + {"unset", false, false}, {"sdflk;", true, false}, {"sldkf;", false, false}, {"true", true, true}, @@ -122,6 +124,24 @@ func TestBoolEqualBool(t *testing.T) { } } +func TestBoolNormalized(t *testing.T) { + tests := []struct { + in Bool + want Bool + }{ + {"", ""}, + {"true", "true"}, + {"false", "false"}, + {"unset", ""}, + {"foo", "foo"}, + } + for _, tt := range tests { + if got := tt.in.Normalized(); got != tt.want { + t.Errorf("(%q).Normalized() = %q; want %q", string(tt.in), string(got), string(tt.want)) + } + } +} + func TestUnmarshalAlloc(t *testing.T) { b := json.Unmarshaler(new(Bool)) n := testing.AllocsPerRun(10, func() { b.UnmarshalJSON(trueBytes) }) diff --git a/types/opt/value.go b/types/opt/value.go index 54fab7a53..c71c53e51 100644 --- a/types/opt/value.go +++ b/types/opt/value.go @@ -36,7 +36,7 @@ func ValueOf[T any](v T) Value[T] { } // String implements [fmt.Stringer]. -func (o *Value[T]) String() string { +func (o Value[T]) String() string { if !o.set { return fmt.Sprintf("(empty[%T])", o.value) } @@ -100,31 +100,31 @@ func (o Value[T]) Equal(v Value[T]) bool { return false } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (o Value[T]) MarshalJSONV2(enc *jsontext.Encoder, opts jsonv2.Options) error { +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (o Value[T]) MarshalJSONTo(enc *jsontext.Encoder) error { if !o.set { return enc.WriteToken(jsontext.Null) } - return jsonv2.MarshalEncode(enc, &o.value, opts) + return jsonv2.MarshalEncode(enc, &o.value) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (o *Value[T]) UnmarshalJSONV2(dec *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (o *Value[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { if dec.PeekKind() == 'n' { *o = Value[T]{} _, err := dec.ReadToken() // read null return err } o.set = true - return jsonv2.UnmarshalDecode(dec, &o.value, opts) + return jsonv2.UnmarshalDecode(dec, &o.value) } // MarshalJSON implements [json.Marshaler]. func (o Value[T]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(o) // uses MarshalJSONV2 + return jsonv2.Marshal(o) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (o *Value[T]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, o) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, o) // uses UnmarshalJSONFrom } diff --git a/types/opt/value_test.go b/types/opt/value_test.go index 93d935e27..890f9a579 100644 --- a/types/opt/value_test.go +++ b/types/opt/value_test.go @@ -9,6 +9,13 @@ import ( "testing" jsonv2 "github.com/go-json-experiment/json" + "tailscale.com/types/bools" + "tailscale.com/util/must" +) + +var ( + _ jsonv2.MarshalerTo = (*Value[bool])(nil) + _ jsonv2.UnmarshalerFrom = (*Value[bool])(nil) ) type testStruct struct { @@ -87,7 +94,14 @@ func TestValue(t *testing.T) { False: ValueOf(false), ExplicitUnset: Value[bool]{}, }, - want: `{"True":true,"False":false,"Unset":null,"ExplicitUnset":null}`, + want: bools.IfElse( + // Detect whether v1 "encoding/json" supports `omitzero` or not. + // TODO(Go1.24): Remove this after `omitzero` is supported. + string(must.Get(json.Marshal(struct { + X int `json:",omitzero"` + }{}))) == `{}`, + `{"True":true,"False":false}`, // omitzero supported + `{"True":true,"False":false,"Unset":null,"ExplicitUnset":null}`), // omitzero not supported wantBack: struct { True Value[bool] `json:",omitzero"` False Value[bool] `json:",omitzero"` diff --git a/types/persist/persist.go b/types/persist/persist.go index 8b555abd4..4b62c79dd 100644 --- a/types/persist/persist.go +++ b/types/persist/persist.go @@ -21,22 +21,12 @@ import ( type Persist struct { _ structs.Incomparable - // LegacyFrontendPrivateMachineKey is here temporarily - // (starting 2020-09-28) during migration of Windows users' - // machine keys from frontend storage to the backend. On the - // first LocalBackend.Start call, the backend will initialize - // the real (backend-owned) machine key from the frontend's - // provided value (if non-zero), picking a new random one if - // needed. This field should be considered read-only from GUI - // frontends. The real value should not be written back in - // this field, lest the frontend persist it to disk. - LegacyFrontendPrivateMachineKey key.MachinePrivate `json:"PrivateMachineKey"` - PrivateNodeKey key.NodePrivate OldPrivateNodeKey key.NodePrivate // needed to request key rotation UserProfile tailcfg.UserProfile NetworkLockKey key.NLPrivate NodeID tailcfg.StableNodeID + AttestationKey key.HardwareAttestationKey `json:",omitempty"` // DisallowedTKAStateIDs stores the tka.State.StateID values which // this node will not operate network lock on. This is used to @@ -95,29 +85,37 @@ func (p *Persist) Equals(p2 *Persist) bool { return false } - return p.LegacyFrontendPrivateMachineKey.Equal(p2.LegacyFrontendPrivateMachineKey) && - p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && + var pub, p2Pub key.HardwareAttestationPublic + if p.AttestationKey != nil && !p.AttestationKey.IsZero() { + pub = key.HardwareAttestationPublicFromPlatformKey(p.AttestationKey) + } + if p2.AttestationKey != nil && !p2.AttestationKey.IsZero() { + p2Pub = key.HardwareAttestationPublicFromPlatformKey(p2.AttestationKey) + } + + return p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) && p.UserProfile.Equal(&p2.UserProfile) && p.NetworkLockKey.Equal(p2.NetworkLockKey) && p.NodeID == p2.NodeID && + pub.Equal(p2Pub) && reflect.DeepEqual(nilIfEmpty(p.DisallowedTKAStateIDs), nilIfEmpty(p2.DisallowedTKAStateIDs)) } func (p *Persist) Pretty() string { var ( - mk key.MachinePublic ok, nk key.NodePublic ) - if !p.LegacyFrontendPrivateMachineKey.IsZero() { - mk = p.LegacyFrontendPrivateMachineKey.Public() - } + akString := "-" if !p.OldPrivateNodeKey.IsZero() { ok = p.OldPrivateNodeKey.Public() } if !p.PrivateNodeKey.IsZero() { nk = p.PublicNodeKey() } - return fmt.Sprintf("Persist{lm=%v, o=%v, n=%v u=%#v}", - mk.ShortString(), ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName) + if p.AttestationKey != nil && !p.AttestationKey.IsZero() { + akString = fmt.Sprintf("%v", p.AttestationKey.Public()) + } + return fmt.Sprintf("Persist{o=%v, n=%v u=%#v ak=%s}", + ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName, akString) } diff --git a/types/persist/persist_clone.go b/types/persist/persist_clone.go index 95dd65ac1..9dbe7e0f6 100644 --- a/types/persist/persist_clone.go +++ b/types/persist/persist_clone.go @@ -19,18 +19,21 @@ func (src *Persist) Clone() *Persist { } dst := new(Persist) *dst = *src + if src.AttestationKey != nil { + dst.AttestationKey = src.AttestationKey.Clone() + } dst.DisallowedTKAStateIDs = append(src.DisallowedTKAStateIDs[:0:0], src.DisallowedTKAStateIDs...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PersistCloneNeedsRegeneration = Persist(struct { - _ structs.Incomparable - LegacyFrontendPrivateMachineKey key.MachinePrivate - PrivateNodeKey key.NodePrivate - OldPrivateNodeKey key.NodePrivate - UserProfile tailcfg.UserProfile - NetworkLockKey key.NLPrivate - NodeID tailcfg.StableNodeID - DisallowedTKAStateIDs []string + _ structs.Incomparable + PrivateNodeKey key.NodePrivate + OldPrivateNodeKey key.NodePrivate + UserProfile tailcfg.UserProfile + NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID + AttestationKey key.HardwareAttestationKey + DisallowedTKAStateIDs []string }{}) diff --git a/types/persist/persist_test.go b/types/persist/persist_test.go index 6b159573d..713114b74 100644 --- a/types/persist/persist_test.go +++ b/types/persist/persist_test.go @@ -21,13 +21,12 @@ func fieldsOf(t reflect.Type) (fields []string) { } func TestPersistEqual(t *testing.T) { - persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "UserProfile", "NetworkLockKey", "NodeID", "DisallowedTKAStateIDs"} + persistHandles := []string{"PrivateNodeKey", "OldPrivateNodeKey", "UserProfile", "NetworkLockKey", "NodeID", "AttestationKey", "DisallowedTKAStateIDs"} if have := fieldsOf(reflect.TypeFor[Persist]()); !reflect.DeepEqual(have, persistHandles) { t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, persistHandles) } - m1 := key.NewMachine() k1 := key.NewNode() nl1 := key.NewNLPrivate() tests := []struct { @@ -39,17 +38,6 @@ func TestPersistEqual(t *testing.T) { {&Persist{}, nil, false}, {&Persist{}, &Persist{}, true}, - { - &Persist{LegacyFrontendPrivateMachineKey: m1}, - &Persist{LegacyFrontendPrivateMachineKey: key.NewMachine()}, - false, - }, - { - &Persist{LegacyFrontendPrivateMachineKey: m1}, - &Persist{LegacyFrontendPrivateMachineKey: m1}, - true, - }, - { &Persist{PrivateNodeKey: k1}, &Persist{PrivateNodeKey: key.NewNode()}, diff --git a/types/persist/persist_view.go b/types/persist/persist_view.go index 1d479b3bf..dbf8294ef 100644 --- a/types/persist/persist_view.go +++ b/types/persist/persist_view.go @@ -6,9 +6,11 @@ package persist import ( - "encoding/json" + jsonv1 "encoding/json" "errors" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/structs" @@ -17,7 +19,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Persist -// View returns a readonly view of Persist. +// View returns a read-only view of Persist. func (p *Persist) View() PersistView { return PersistView{Đļ: p} } @@ -33,7 +35,7 @@ type PersistView struct { Đļ *Persist } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v PersistView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -45,8 +47,17 @@ func (v PersistView) AsStruct() *Persist { return v.Đļ.Clone() } -func (v PersistView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v PersistView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v PersistView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *PersistView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -55,33 +66,51 @@ func (v *PersistView) UnmarshalJSON(b []byte) error { return nil } var x Persist - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v PersistView) LegacyFrontendPrivateMachineKey() key.MachinePrivate { - return v.Đļ.LegacyFrontendPrivateMachineKey +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *PersistView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Persist + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil } -func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.Đļ.PrivateNodeKey } -func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.Đļ.OldPrivateNodeKey } -func (v PersistView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } -func (v PersistView) NetworkLockKey() key.NLPrivate { return v.Đļ.NetworkLockKey } -func (v PersistView) NodeID() tailcfg.StableNodeID { return v.Đļ.NodeID } + +func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.Đļ.PrivateNodeKey } + +// needed to request key rotation +func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.Đļ.OldPrivateNodeKey } +func (v PersistView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } +func (v PersistView) NetworkLockKey() key.NLPrivate { return v.Đļ.NetworkLockKey } +func (v PersistView) NodeID() tailcfg.StableNodeID { return v.Đļ.NodeID } +func (v PersistView) AttestationKey() tailcfg.StableNodeID { panic("unsupported") } + +// DisallowedTKAStateIDs stores the tka.State.StateID values which +// this node will not operate network lock on. This is used to +// prevent bootstrapping TKA onto a key authority which was forcibly +// disabled. func (v PersistView) DisallowedTKAStateIDs() views.Slice[string] { return views.SliceOf(v.Đļ.DisallowedTKAStateIDs) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PersistViewNeedsRegeneration = Persist(struct { - _ structs.Incomparable - LegacyFrontendPrivateMachineKey key.MachinePrivate - PrivateNodeKey key.NodePrivate - OldPrivateNodeKey key.NodePrivate - UserProfile tailcfg.UserProfile - NetworkLockKey key.NLPrivate - NodeID tailcfg.StableNodeID - DisallowedTKAStateIDs []string + _ structs.Incomparable + PrivateNodeKey key.NodePrivate + OldPrivateNodeKey key.NodePrivate + UserProfile tailcfg.UserProfile + NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID + AttestationKey key.HardwareAttestationKey + DisallowedTKAStateIDs []string }{}) diff --git a/types/prefs/item.go b/types/prefs/item.go index 103204147..717a0c76c 100644 --- a/types/prefs/item.go +++ b/types/prefs/item.go @@ -152,15 +152,15 @@ func (iv ItemView[T, V]) Equal(iv2 ItemView[T, V]) bool { return iv.Đļ.Equal(*iv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (iv ItemView[T, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return iv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (iv ItemView[T, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return iv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (iv *ItemView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (iv *ItemView[T, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x Item[T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } iv.Đļ = &x @@ -169,10 +169,10 @@ func (iv *ItemView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Opti // MarshalJSON implements [json.Marshaler]. func (iv ItemView[T, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(iv) // uses MarshalJSONV2 + return jsonv2.Marshal(iv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (iv *ItemView[T, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, iv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, iv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/list.go b/types/prefs/list.go index 9830e79de..ae6b2fae3 100644 --- a/types/prefs/list.go +++ b/types/prefs/list.go @@ -45,36 +45,36 @@ func ListWithOpts[T ImmutableType](opts ...Options) List[T] { // SetValue configures the preference with the specified value. // It fails and returns [ErrManaged] if p is a managed preference, // and [ErrReadOnly] if p is a read-only preference. -func (l *List[T]) SetValue(val []T) error { - return l.preference.SetValue(cloneSlice(val)) +func (ls *List[T]) SetValue(val []T) error { + return ls.preference.SetValue(cloneSlice(val)) } // SetManagedValue configures the preference with the specified value // and marks the preference as managed. -func (l *List[T]) SetManagedValue(val []T) { - l.preference.SetManagedValue(cloneSlice(val)) +func (ls *List[T]) SetManagedValue(val []T) { + ls.preference.SetManagedValue(cloneSlice(val)) } // View returns a read-only view of l. -func (l *List[T]) View() ListView[T] { - return ListView[T]{l} +func (ls *List[T]) View() ListView[T] { + return ListView[T]{ls} } // Clone returns a copy of l that aliases no memory with l. -func (l List[T]) Clone() *List[T] { - res := ptr.To(l) - if v, ok := l.s.Value.GetOk(); ok { +func (ls List[T]) Clone() *List[T] { + res := ptr.To(ls) + if v, ok := ls.s.Value.GetOk(); ok { res.s.Value.Set(append(v[:0:0], v...)) } return res } // Equal reports whether l and l2 are equal. -func (l List[T]) Equal(l2 List[T]) bool { - if l.s.Metadata != l2.s.Metadata { +func (ls List[T]) Equal(l2 List[T]) bool { + if ls.s.Metadata != l2.s.Metadata { return false } - v1, ok1 := l.s.Value.GetOk() + v1, ok1 := ls.s.Value.GetOk() v2, ok2 := l2.s.Value.GetOk() if ok1 != ok2 { return false @@ -157,15 +157,20 @@ func (lv ListView[T]) Equal(lv2 ListView[T]) bool { return lv.Đļ.Equal(*lv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (lv ListView[T]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return lv.Đļ.MarshalJSONV2(out, opts) +var ( + _ jsonv2.MarshalerTo = (*ListView[bool])(nil) + _ jsonv2.UnmarshalerFrom = (*ListView[bool])(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (lv ListView[T]) MarshalJSONTo(out *jsontext.Encoder) error { + return lv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (lv *ListView[T]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (lv *ListView[T]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x List[T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } lv.Đļ = &x @@ -174,10 +179,10 @@ func (lv *ListView[T]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options // MarshalJSON implements [json.Marshaler]. func (lv ListView[T]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(lv) // uses MarshalJSONV2 + return jsonv2.Marshal(lv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (lv *ListView[T]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/map.go b/types/prefs/map.go index 2bd32bfbd..4b64690ed 100644 --- a/types/prefs/map.go +++ b/types/prefs/map.go @@ -133,15 +133,15 @@ func (mv MapView[K, V]) Equal(mv2 MapView[K, V]) bool { return mv.Đļ.Equal(*mv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (mv MapView[K, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return mv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (mv MapView[K, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return mv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (mv *MapView[K, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (mv *MapView[K, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x Map[K, V] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } mv.Đļ = &x @@ -150,10 +150,10 @@ func (mv *MapView[K, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Optio // MarshalJSON implements [json.Marshaler]. func (mv MapView[K, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(mv) // uses MarshalJSONV2 + return jsonv2.Marshal(mv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (mv *MapView[K, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/prefs.go b/types/prefs/prefs.go index 3bbd237fe..a6caf1283 100644 --- a/types/prefs/prefs.go +++ b/types/prefs/prefs.go @@ -29,8 +29,8 @@ import ( var ( // ErrManaged is the error returned when attempting to modify a managed preference. ErrManaged = errors.New("cannot modify a managed preference") - // ErrReadOnly is the error returned when attempting to modify a readonly preference. - ErrReadOnly = errors.New("cannot modify a readonly preference") + // ErrReadOnly is the error returned when attempting to modify a read-only preference. + ErrReadOnly = errors.New("cannot modify a read-only preference") ) // metadata holds type-agnostic preference metadata. @@ -158,22 +158,27 @@ func (p *preference[T]) SetReadOnly(readonly bool) { p.s.Metadata.ReadOnly = readonly } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (p preference[T]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return jsonv2.MarshalEncode(out, &p.s, opts) +var ( + _ jsonv2.MarshalerTo = (*preference[struct{}])(nil) + _ jsonv2.UnmarshalerFrom = (*preference[struct{}])(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (p preference[T]) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &p.s) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (p *preference[T]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { - return jsonv2.UnmarshalDecode(in, &p.s, opts) +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (p *preference[T]) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &p.s) } // MarshalJSON implements [json.Marshaler]. func (p preference[T]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(p) // uses MarshalJSONV2 + return jsonv2.Marshal(p) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (p *preference[T]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONFrom } diff --git a/types/prefs/prefs_example/prefs_example_view.go b/types/prefs/prefs_example/prefs_example_view.go index 0256bd7e6..6a1a36865 100644 --- a/types/prefs/prefs_example/prefs_example_view.go +++ b/types/prefs/prefs_example/prefs_example_view.go @@ -6,10 +6,12 @@ package prefs_example import ( - "encoding/json" + jsonv1 "encoding/json" "errors" "net/netip" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "tailscale.com/drive" "tailscale.com/tailcfg" "tailscale.com/types/opt" @@ -20,7 +22,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Prefs,AutoUpdatePrefs,AppConnectorPrefs -// View returns a readonly view of Prefs. +// View returns a read-only view of Prefs. func (p *Prefs) View() PrefsView { return PrefsView{Đļ: p} } @@ -36,7 +38,7 @@ type PrefsView struct { Đļ *Prefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v PrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -48,8 +50,17 @@ func (v PrefsView) AsStruct() *Prefs { return v.Đļ.Clone() } -func (v PrefsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v PrefsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v PrefsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *PrefsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -58,45 +69,88 @@ func (v *PrefsView) UnmarshalJSON(b []byte) error { return nil } var x Prefs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v PrefsView) ControlURL() prefs.Item[string] { return v.Đļ.ControlURL } -func (v PrefsView) RouteAll() prefs.Item[bool] { return v.Đļ.RouteAll } -func (v PrefsView) ExitNodeID() prefs.Item[tailcfg.StableNodeID] { return v.Đļ.ExitNodeID } -func (v PrefsView) ExitNodeIP() prefs.Item[netip.Addr] { return v.Đļ.ExitNodeIP } -func (v PrefsView) ExitNodePrior() tailcfg.StableNodeID { return v.Đļ.ExitNodePrior } -func (v PrefsView) ExitNodeAllowLANAccess() prefs.Item[bool] { return v.Đļ.ExitNodeAllowLANAccess } -func (v PrefsView) CorpDNS() prefs.Item[bool] { return v.Đļ.CorpDNS } -func (v PrefsView) RunSSH() prefs.Item[bool] { return v.Đļ.RunSSH } -func (v PrefsView) RunWebClient() prefs.Item[bool] { return v.Đļ.RunWebClient } -func (v PrefsView) WantRunning() prefs.Item[bool] { return v.Đļ.WantRunning } -func (v PrefsView) LoggedOut() prefs.Item[bool] { return v.Đļ.LoggedOut } -func (v PrefsView) ShieldsUp() prefs.Item[bool] { return v.Đļ.ShieldsUp } -func (v PrefsView) AdvertiseTags() prefs.ListView[string] { return v.Đļ.AdvertiseTags.View() } -func (v PrefsView) Hostname() prefs.Item[string] { return v.Đļ.Hostname } -func (v PrefsView) NotepadURLs() prefs.Item[bool] { return v.Đļ.NotepadURLs } -func (v PrefsView) ForceDaemon() prefs.Item[bool] { return v.Đļ.ForceDaemon } -func (v PrefsView) Egg() prefs.Item[bool] { return v.Đļ.Egg } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *PrefsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x Prefs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v PrefsView) ControlURL() prefs.Item[string] { return v.Đļ.ControlURL } +func (v PrefsView) RouteAll() prefs.Item[bool] { return v.Đļ.RouteAll } +func (v PrefsView) ExitNodeID() prefs.Item[tailcfg.StableNodeID] { return v.Đļ.ExitNodeID } +func (v PrefsView) ExitNodeIP() prefs.Item[netip.Addr] { return v.Đļ.ExitNodeIP } + +// ExitNodePrior is an internal state rather than a preference. +// It can be kept in the Prefs structure but should not be wrapped +// and is ignored by the [prefs] package. +func (v PrefsView) ExitNodePrior() tailcfg.StableNodeID { return v.Đļ.ExitNodePrior } +func (v PrefsView) ExitNodeAllowLANAccess() prefs.Item[bool] { return v.Đļ.ExitNodeAllowLANAccess } +func (v PrefsView) CorpDNS() prefs.Item[bool] { return v.Đļ.CorpDNS } +func (v PrefsView) RunSSH() prefs.Item[bool] { return v.Đļ.RunSSH } +func (v PrefsView) RunWebClient() prefs.Item[bool] { return v.Đļ.RunWebClient } +func (v PrefsView) WantRunning() prefs.Item[bool] { return v.Đļ.WantRunning } +func (v PrefsView) LoggedOut() prefs.Item[bool] { return v.Đļ.LoggedOut } +func (v PrefsView) ShieldsUp() prefs.Item[bool] { return v.Đļ.ShieldsUp } + +// AdvertiseTags is a preference whose value is a slice of strings. +// The value is atomic, and individual items in the slice should +// not be modified after the preference is set. +// Since the item type (string) is immutable, we can use [prefs.List]. +func (v PrefsView) AdvertiseTags() prefs.ListView[string] { return v.Đļ.AdvertiseTags.View() } +func (v PrefsView) Hostname() prefs.Item[string] { return v.Đļ.Hostname } +func (v PrefsView) NotepadURLs() prefs.Item[bool] { return v.Đļ.NotepadURLs } +func (v PrefsView) ForceDaemon() prefs.Item[bool] { return v.Đļ.ForceDaemon } +func (v PrefsView) Egg() prefs.Item[bool] { return v.Đļ.Egg } + +// AdvertiseRoutes is a preference whose value is a slice of netip.Prefix. +// The value is atomic, and individual items in the slice should +// not be modified after the preference is set. +// Since the item type (netip.Prefix) is immutable, we can use [prefs.List]. func (v PrefsView) AdvertiseRoutes() prefs.ListView[netip.Prefix] { return v.Đļ.AdvertiseRoutes.View() } func (v PrefsView) NoSNAT() prefs.Item[bool] { return v.Đļ.NoSNAT } func (v PrefsView) NoStatefulFiltering() prefs.Item[opt.Bool] { return v.Đļ.NoStatefulFiltering } func (v PrefsView) NetfilterMode() prefs.Item[preftype.NetfilterMode] { return v.Đļ.NetfilterMode } func (v PrefsView) OperatorUser() prefs.Item[string] { return v.Đļ.OperatorUser } func (v PrefsView) ProfileName() prefs.Item[string] { return v.Đļ.ProfileName } -func (v PrefsView) AutoUpdate() AutoUpdatePrefs { return v.Đļ.AutoUpdate } -func (v PrefsView) AppConnector() AppConnectorPrefs { return v.Đļ.AppConnector } -func (v PrefsView) PostureChecking() prefs.Item[bool] { return v.Đļ.PostureChecking } -func (v PrefsView) NetfilterKind() prefs.Item[string] { return v.Đļ.NetfilterKind } + +// AutoUpdate contains auto-update preferences. +// Each preference in the group can be configured and managed individually. +func (v PrefsView) AutoUpdate() AutoUpdatePrefs { return v.Đļ.AutoUpdate } + +// AppConnector contains app connector-related preferences. +// Each preference in the group can be configured and managed individually. +func (v PrefsView) AppConnector() AppConnectorPrefs { return v.Đļ.AppConnector } +func (v PrefsView) PostureChecking() prefs.Item[bool] { return v.Đļ.PostureChecking } +func (v PrefsView) NetfilterKind() prefs.Item[string] { return v.Đļ.NetfilterKind } + +// DriveShares is a preference whose value is a slice of *[drive.Share]. +// The value is atomic, and individual items in the slice should +// not be modified after the preference is set. +// Since the item type (*drive.Share) is mutable and implements [views.ViewCloner], +// we need to use [prefs.StructList] instead of [prefs.List]. func (v PrefsView) DriveShares() prefs.StructListView[*drive.Share, drive.ShareView] { return prefs.StructListViewOf(&v.Đļ.DriveShares) } func (v PrefsView) AllowSingleHosts() prefs.Item[marshalAsTrueInJSON] { return v.Đļ.AllowSingleHosts } -func (v PrefsView) Persist() persist.PersistView { return v.Đļ.Persist.View() } + +// Persist is an internal state rather than a preference. +// It can be kept in the Prefs structure but should not be wrapped +// and is ignored by the [prefs] package. +func (v PrefsView) Persist() persist.PersistView { return v.Đļ.Persist.View() } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _PrefsViewNeedsRegeneration = Prefs(struct { @@ -132,7 +186,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { Persist *persist.Persist }{}) -// View returns a readonly view of AutoUpdatePrefs. +// View returns a read-only view of AutoUpdatePrefs. func (p *AutoUpdatePrefs) View() AutoUpdatePrefsView { return AutoUpdatePrefsView{Đļ: p} } @@ -148,7 +202,7 @@ type AutoUpdatePrefsView struct { Đļ *AutoUpdatePrefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v AutoUpdatePrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -160,8 +214,17 @@ func (v AutoUpdatePrefsView) AsStruct() *AutoUpdatePrefs { return v.Đļ.Clone() } -func (v AutoUpdatePrefsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v AutoUpdatePrefsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v AutoUpdatePrefsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *AutoUpdatePrefsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -170,7 +233,20 @@ func (v *AutoUpdatePrefsView) UnmarshalJSON(b []byte) error { return nil } var x AutoUpdatePrefs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *AutoUpdatePrefsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x AutoUpdatePrefs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -186,7 +262,7 @@ var _AutoUpdatePrefsViewNeedsRegeneration = AutoUpdatePrefs(struct { Apply prefs.Item[opt.Bool] }{}) -// View returns a readonly view of AppConnectorPrefs. +// View returns a read-only view of AppConnectorPrefs. func (p *AppConnectorPrefs) View() AppConnectorPrefsView { return AppConnectorPrefsView{Đļ: p} } @@ -202,7 +278,7 @@ type AppConnectorPrefsView struct { Đļ *AppConnectorPrefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v AppConnectorPrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -214,8 +290,17 @@ func (v AppConnectorPrefsView) AsStruct() *AppConnectorPrefs { return v.Đļ.Clone() } -func (v AppConnectorPrefsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v AppConnectorPrefsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v AppConnectorPrefsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *AppConnectorPrefsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -224,7 +309,20 @@ func (v *AppConnectorPrefsView) UnmarshalJSON(b []byte) error { return nil } var x AppConnectorPrefs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *AppConnectorPrefsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x AppConnectorPrefs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x diff --git a/types/prefs/prefs_example/prefs_types.go b/types/prefs/prefs_example/prefs_types.go index 49f0d8c3c..c35f1f62f 100644 --- a/types/prefs/prefs_example/prefs_types.go +++ b/types/prefs/prefs_example/prefs_types.go @@ -48,10 +48,10 @@ import ( // the `omitzero` JSON tag option. This option is not supported by the // [encoding/json] package as of 2024-08-21; see golang/go#45669. // It is recommended that a prefs type implements both -// [jsonv2.MarshalerV2]/[jsonv2.UnmarshalerV2] and [json.Marshaler]/[json.Unmarshaler] +// [jsonv2.MarshalerTo]/[jsonv2.UnmarshalerFrom] and [json.Marshaler]/[json.Unmarshaler] // to ensure consistent and more performant marshaling, regardless of the JSON package // used at the call sites; the standard marshalers can be implemented via [jsonv2]. -// See [Prefs.MarshalJSONV2], [Prefs.UnmarshalJSONV2], [Prefs.MarshalJSON], +// See [Prefs.MarshalJSONTo], [Prefs.UnmarshalJSONFrom], [Prefs.MarshalJSON], // and [Prefs.UnmarshalJSON] for an example implementation. type Prefs struct { ControlURL prefs.Item[string] `json:",omitzero"` @@ -128,34 +128,39 @@ type AppConnectorPrefs struct { Advertise prefs.Item[bool] `json:",omitzero"` } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +var ( + _ jsonv2.MarshalerTo = (*Prefs)(nil) + _ jsonv2.UnmarshalerFrom = (*Prefs)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. // It is implemented as a performance improvement and to enable omission of // unconfigured preferences from the JSON output. See the [Prefs] doc for details. -func (p Prefs) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { +func (p Prefs) MarshalJSONTo(out *jsontext.Encoder) error { // The prefs type shadows the Prefs's method set, // causing [jsonv2] to use the default marshaler and avoiding // infinite recursion. type prefs Prefs - return jsonv2.MarshalEncode(out, (*prefs)(&p), opts) + return jsonv2.MarshalEncode(out, (*prefs)(&p)) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (p *Prefs) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (p *Prefs) UnmarshalJSONFrom(in *jsontext.Decoder) error { // The prefs type shadows the Prefs's method set, // causing [jsonv2] to use the default unmarshaler and avoiding // infinite recursion. type prefs Prefs - return jsonv2.UnmarshalDecode(in, (*prefs)(p), opts) + return jsonv2.UnmarshalDecode(in, (*prefs)(p)) } // MarshalJSON implements [json.Marshaler]. func (p Prefs) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(p) // uses MarshalJSONV2 + return jsonv2.Marshal(p) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (p *Prefs) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONFrom } type marshalAsTrueInJSON struct{} diff --git a/types/prefs/prefs_test.go b/types/prefs/prefs_test.go index ea4729366..dc1213adb 100644 --- a/types/prefs/prefs_test.go +++ b/types/prefs/prefs_test.go @@ -19,6 +19,20 @@ import ( //go:generate go run tailscale.com/cmd/viewer --tags=test --type=TestPrefs,TestBundle,TestValueStruct,TestGenericStruct,TestPrefsGroup +var ( + _ jsonv2.MarshalerTo = (*ItemView[*TestBundle, TestBundleView])(nil) + _ jsonv2.UnmarshalerFrom = (*ItemView[*TestBundle, TestBundleView])(nil) + + _ jsonv2.MarshalerTo = (*MapView[string, string])(nil) + _ jsonv2.UnmarshalerFrom = (*MapView[string, string])(nil) + + _ jsonv2.MarshalerTo = (*StructListView[*TestBundle, TestBundleView])(nil) + _ jsonv2.UnmarshalerFrom = (*StructListView[*TestBundle, TestBundleView])(nil) + + _ jsonv2.MarshalerTo = (*StructMapView[string, *TestBundle, TestBundleView])(nil) + _ jsonv2.UnmarshalerFrom = (*StructMapView[string, *TestBundle, TestBundleView])(nil) +) + type TestPrefs struct { Int32Item Item[int32] `json:",omitzero"` UInt64Item Item[uint64] `json:",omitzero"` @@ -53,32 +67,37 @@ type TestPrefs struct { Group TestPrefsGroup `json:",omitzero"` } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (p TestPrefs) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { +var ( + _ jsonv2.MarshalerTo = (*TestPrefs)(nil) + _ jsonv2.UnmarshalerFrom = (*TestPrefs)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (p TestPrefs) MarshalJSONTo(out *jsontext.Encoder) error { // The testPrefs type shadows the TestPrefs's method set, // causing jsonv2 to use the default marshaler and avoiding // infinite recursion. type testPrefs TestPrefs - return jsonv2.MarshalEncode(out, (*testPrefs)(&p), opts) + return jsonv2.MarshalEncode(out, (*testPrefs)(&p)) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (p *TestPrefs) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (p *TestPrefs) UnmarshalJSONFrom(in *jsontext.Decoder) error { // The testPrefs type shadows the TestPrefs's method set, // causing jsonv2 to use the default unmarshaler and avoiding // infinite recursion. type testPrefs TestPrefs - return jsonv2.UnmarshalDecode(in, (*testPrefs)(p), opts) + return jsonv2.UnmarshalDecode(in, (*testPrefs)(p)) } // MarshalJSON implements [json.Marshaler]. func (p TestPrefs) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(p) // uses MarshalJSONV2 + return jsonv2.Marshal(p) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (p *TestPrefs) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONFrom } // TestBundle is an example structure type that, @@ -468,31 +487,31 @@ func TestItemView(t *testing.T) { } func TestListView(t *testing.T) { - l := ListOf([]int{4, 8, 15, 16, 23, 42}, ReadOnly) + ls := ListOf([]int{4, 8, 15, 16, 23, 42}, ReadOnly) - lv := l.View() + lv := ls.View() checkIsSet(t, lv, true) checkIsManaged(t, lv, false) checkIsReadOnly(t, lv, true) - checkValue(t, lv, views.SliceOf(l.Value())) - checkValueOk(t, lv, views.SliceOf(l.Value()), true) + checkValue(t, lv, views.SliceOf(ls.Value())) + checkValueOk(t, lv, views.SliceOf(ls.Value()), true) l2 := *lv.AsStruct() - checkEqual(t, l, l2, true) + checkEqual(t, ls, l2, true) } func TestStructListView(t *testing.T) { - l := StructListOf([]*TestBundle{{Name: "E1"}, {Name: "E2"}}, ReadOnly) + ls := StructListOf([]*TestBundle{{Name: "E1"}, {Name: "E2"}}, ReadOnly) - lv := StructListViewOf(&l) + lv := StructListViewOf(&ls) checkIsSet(t, lv, true) checkIsManaged(t, lv, false) checkIsReadOnly(t, lv, true) - checkValue(t, lv, views.SliceOfViews(l.Value())) - checkValueOk(t, lv, views.SliceOfViews(l.Value()), true) + checkValue(t, lv, views.SliceOfViews(ls.Value())) + checkValueOk(t, lv, views.SliceOfViews(ls.Value()), true) l2 := *lv.AsStruct() - checkEqual(t, l, l2, true) + checkEqual(t, ls, l2, true) } func TestStructMapView(t *testing.T) { diff --git a/types/prefs/prefs_view_test.go b/types/prefs/prefs_view_test.go index d76eebb43..8993cb535 100644 --- a/types/prefs/prefs_view_test.go +++ b/types/prefs/prefs_view_test.go @@ -6,14 +6,17 @@ package prefs import ( - "encoding/json" + jsonv1 "encoding/json" "errors" "net/netip" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" ) //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=TestPrefs,TestBundle,TestValueStruct,TestGenericStruct,TestPrefsGroup -tags=test -// View returns a readonly view of TestPrefs. +// View returns a read-only view of TestPrefs. func (p *TestPrefs) View() TestPrefsView { return TestPrefsView{Đļ: p} } @@ -29,7 +32,7 @@ type TestPrefsView struct { Đļ *TestPrefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestPrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -41,8 +44,17 @@ func (v TestPrefsView) AsStruct() *TestPrefs { return v.Đļ.Clone() } -func (v TestPrefsView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v TestPrefsView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v TestPrefsView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *TestPrefsView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -51,7 +63,20 @@ func (v *TestPrefsView) UnmarshalJSON(b []byte) error { return nil } var x TestPrefs - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *TestPrefsView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x TestPrefs + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -70,6 +95,9 @@ func (v TestPrefsView) AddrItem() Item[netip.Addr] { return v.Đļ.A func (v TestPrefsView) StringStringMap() MapView[string, string] { return v.Đļ.StringStringMap.View() } func (v TestPrefsView) IntStringMap() MapView[int, string] { return v.Đļ.IntStringMap.View() } func (v TestPrefsView) AddrIntMap() MapView[netip.Addr, int] { return v.Đļ.AddrIntMap.View() } + +// Bundles are complex preferences that usually consist of +// multiple parameters that must be configured atomically. func (v TestPrefsView) Bundle1() ItemView[*TestBundle, TestBundleView] { return ItemViewOf(&v.Đļ.Bundle1) } @@ -91,6 +119,10 @@ func (v TestPrefsView) IntBundleMap() StructMapView[int, *TestBundle, TestBundle func (v TestPrefsView) AddrBundleMap() StructMapView[netip.Addr, *TestBundle, TestBundleView] { return StructMapViewOf(&v.Đļ.AddrBundleMap) } + +// Group is a nested struct that contains one or more preferences. +// Each preference in a group can be configured individually. +// Preference groups should be included directly rather than by pointers. func (v TestPrefsView) Group() TestPrefsGroup { return v.Đļ.Group } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -117,7 +149,7 @@ var _TestPrefsViewNeedsRegeneration = TestPrefs(struct { Group TestPrefsGroup }{}) -// View returns a readonly view of TestBundle. +// View returns a read-only view of TestBundle. func (p *TestBundle) View() TestBundleView { return TestBundleView{Đļ: p} } @@ -133,7 +165,7 @@ type TestBundleView struct { Đļ *TestBundle } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestBundleView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -145,8 +177,17 @@ func (v TestBundleView) AsStruct() *TestBundle { return v.Đļ.Clone() } -func (v TestBundleView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v TestBundleView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v TestBundleView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *TestBundleView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -155,22 +196,28 @@ func (v *TestBundleView) UnmarshalJSON(b []byte) error { return nil } var x TestBundle - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { return err } v.Đļ = &x return nil } -func (v TestBundleView) Name() string { return v.Đļ.Name } -func (v TestBundleView) Nested() *TestValueStruct { - if v.Đļ.Nested == nil { - return nil +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *TestBundleView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x TestBundle + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err } - x := *v.Đļ.Nested - return &x + v.Đļ = &x + return nil } +func (v TestBundleView) Name() string { return v.Đļ.Name } +func (v TestBundleView) Nested() TestValueStructView { return v.Đļ.Nested.View() } func (v TestBundleView) Equal(v2 TestBundleView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -179,7 +226,7 @@ var _TestBundleViewNeedsRegeneration = TestBundle(struct { Nested *TestValueStruct }{}) -// View returns a readonly view of TestValueStruct. +// View returns a read-only view of TestValueStruct. func (p *TestValueStruct) View() TestValueStructView { return TestValueStructView{Đļ: p} } @@ -195,7 +242,7 @@ type TestValueStructView struct { Đļ *TestValueStruct } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestValueStructView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -207,8 +254,17 @@ func (v TestValueStructView) AsStruct() *TestValueStruct { return v.Đļ.Clone() } -func (v TestValueStructView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v TestValueStructView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v TestValueStructView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *TestValueStructView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -217,7 +273,20 @@ func (v *TestValueStructView) UnmarshalJSON(b []byte) error { return nil } var x TestValueStruct - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *TestValueStructView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x TestValueStruct + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -232,7 +301,7 @@ var _TestValueStructViewNeedsRegeneration = TestValueStruct(struct { Value int }{}) -// View returns a readonly view of TestGenericStruct. +// View returns a read-only view of TestGenericStruct. func (p *TestGenericStruct[T]) View() TestGenericStructView[T] { return TestGenericStructView[T]{Đļ: p} } @@ -248,7 +317,7 @@ type TestGenericStructView[T ImmutableType] struct { Đļ *TestGenericStruct[T] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestGenericStructView[T]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -260,8 +329,17 @@ func (v TestGenericStructView[T]) AsStruct() *TestGenericStruct[T] { return v.Đļ.Clone() } -func (v TestGenericStructView[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v TestGenericStructView[T]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v TestGenericStructView[T]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *TestGenericStructView[T]) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -270,7 +348,20 @@ func (v *TestGenericStructView[T]) UnmarshalJSON(b []byte) error { return nil } var x TestGenericStruct[T] - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *TestGenericStructView[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x TestGenericStruct[T] + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x @@ -287,7 +378,7 @@ func _TestGenericStructViewNeedsRegeneration[T ImmutableType](TestGenericStruct[ }{}) } -// View returns a readonly view of TestPrefsGroup. +// View returns a read-only view of TestPrefsGroup. func (p *TestPrefsGroup) View() TestPrefsGroupView { return TestPrefsGroupView{Đļ: p} } @@ -303,7 +394,7 @@ type TestPrefsGroupView struct { Đļ *TestPrefsGroup } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestPrefsGroupView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -315,8 +406,17 @@ func (v TestPrefsGroupView) AsStruct() *TestPrefsGroup { return v.Đļ.Clone() } -func (v TestPrefsGroupView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v TestPrefsGroupView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v TestPrefsGroupView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} +// UnmarshalJSON implements [jsonv1.Unmarshaler]. func (v *TestPrefsGroupView) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") @@ -325,7 +425,20 @@ func (v *TestPrefsGroupView) UnmarshalJSON(b []byte) error { return nil } var x TestPrefsGroup - if err := json.Unmarshal(b, &x); err != nil { + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *TestPrefsGroupView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + var x TestPrefsGroup + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { return err } v.Đļ = &x diff --git a/types/prefs/struct_list.go b/types/prefs/struct_list.go index 872cb2326..ba145e2cf 100644 --- a/types/prefs/struct_list.go +++ b/types/prefs/struct_list.go @@ -33,20 +33,20 @@ func StructListWithOpts[T views.Cloner[T]](opts ...Options) StructList[T] { // SetValue configures the preference with the specified value. // It fails and returns [ErrManaged] if p is a managed preference, // and [ErrReadOnly] if p is a read-only preference. -func (l *StructList[T]) SetValue(val []T) error { - return l.preference.SetValue(deepCloneSlice(val)) +func (ls *StructList[T]) SetValue(val []T) error { + return ls.preference.SetValue(deepCloneSlice(val)) } // SetManagedValue configures the preference with the specified value // and marks the preference as managed. -func (l *StructList[T]) SetManagedValue(val []T) { - l.preference.SetManagedValue(deepCloneSlice(val)) +func (ls *StructList[T]) SetManagedValue(val []T) { + ls.preference.SetManagedValue(deepCloneSlice(val)) } // Clone returns a copy of l that aliases no memory with l. -func (l StructList[T]) Clone() *StructList[T] { - res := ptr.To(l) - if v, ok := l.s.Value.GetOk(); ok { +func (ls StructList[T]) Clone() *StructList[T] { + res := ptr.To(ls) + if v, ok := ls.s.Value.GetOk(); ok { res.s.Value.Set(deepCloneSlice(v)) } return res @@ -56,11 +56,11 @@ func (l StructList[T]) Clone() *StructList[T] { // If the template type T implements an Equal(T) bool method, it will be used // instead of the == operator for value comparison. // It panics if T is not comparable. -func (l StructList[T]) Equal(l2 StructList[T]) bool { - if l.s.Metadata != l2.s.Metadata { +func (ls StructList[T]) Equal(l2 StructList[T]) bool { + if ls.s.Metadata != l2.s.Metadata { return false } - v1, ok1 := l.s.Value.GetOk() + v1, ok1 := ls.s.Value.GetOk() v2, ok2 := l2.s.Value.GetOk() if ok1 != ok2 { return false @@ -105,8 +105,8 @@ type StructListView[T views.ViewCloner[T, V], V views.StructView[T]] struct { // StructListViewOf returns a read-only view of l. // It is used by [tailscale.com/cmd/viewer]. -func StructListViewOf[T views.ViewCloner[T, V], V views.StructView[T]](l *StructList[T]) StructListView[T, V] { - return StructListView[T, V]{l} +func StructListViewOf[T views.ViewCloner[T, V], V views.StructView[T]](ls *StructList[T]) StructListView[T, V] { + return StructListView[T, V]{ls} } // Valid reports whether the underlying [StructList] is non-nil. @@ -169,15 +169,15 @@ func (lv StructListView[T, V]) Equal(lv2 StructListView[T, V]) bool { return lv.Đļ.Equal(*lv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (lv StructListView[T, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return lv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (lv StructListView[T, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return lv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (lv *StructListView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (lv *StructListView[T, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x StructList[T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } lv.Đļ = &x @@ -186,10 +186,10 @@ func (lv *StructListView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv // MarshalJSON implements [json.Marshaler]. func (lv StructListView[T, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(lv) // uses MarshalJSONV2 + return jsonv2.Marshal(lv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (lv *StructListView[T, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/struct_map.go b/types/prefs/struct_map.go index 2003eebe3..83cc7447b 100644 --- a/types/prefs/struct_map.go +++ b/types/prefs/struct_map.go @@ -31,14 +31,14 @@ func StructMapWithOpts[K MapKeyType, V views.Cloner[V]](opts ...Options) StructM // SetValue configures the preference with the specified value. // It fails and returns [ErrManaged] if p is a managed preference, // and [ErrReadOnly] if p is a read-only preference. -func (l *StructMap[K, V]) SetValue(val map[K]V) error { - return l.preference.SetValue(deepCloneMap(val)) +func (m *StructMap[K, V]) SetValue(val map[K]V) error { + return m.preference.SetValue(deepCloneMap(val)) } // SetManagedValue configures the preference with the specified value // and marks the preference as managed. -func (l *StructMap[K, V]) SetManagedValue(val map[K]V) { - l.preference.SetManagedValue(deepCloneMap(val)) +func (m *StructMap[K, V]) SetManagedValue(val map[K]V) { + m.preference.SetManagedValue(deepCloneMap(val)) } // Clone returns a copy of m that aliases no memory with m. @@ -83,7 +83,7 @@ type StructMapView[K MapKeyType, T views.ViewCloner[T, V], V views.StructView[T] Đļ *StructMap[K, T] } -// StructMapViewOf returns a readonly view of m. +// StructMapViewOf returns a read-only view of m. // It is used by [tailscale.com/cmd/viewer]. func StructMapViewOf[K MapKeyType, T views.ViewCloner[T, V], V views.StructView[T]](m *StructMap[K, T]) StructMapView[K, T, V] { return StructMapView[K, T, V]{m} @@ -149,15 +149,15 @@ func (mv StructMapView[K, T, V]) Equal(mv2 StructMapView[K, T, V]) bool { return mv.Đļ.Equal(*mv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (mv StructMapView[K, T, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return mv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (mv StructMapView[K, T, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return mv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (mv *StructMapView[K, T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (mv *StructMapView[K, T, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x StructMap[K, T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } mv.Đļ = &x @@ -166,10 +166,10 @@ func (mv *StructMapView[K, T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jso // MarshalJSON implements [json.Marshaler]. func (mv StructMapView[K, T, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(mv) // uses MarshalJSONV2 + return jsonv2.Marshal(mv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (mv *StructMapView[K, T, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONFrom } diff --git a/types/result/result.go b/types/result/result.go new file mode 100644 index 000000000..6bd1c2ea6 --- /dev/null +++ b/types/result/result.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package result contains the Of result type, which is +// either a value or an error. +package result + +// Of is either a T value or an error. +// +// Think of it like Rust or Swift's result types. +// It's named "Of" because the fully qualified name +// for callers reads result.Of[T]. +type Of[T any] struct { + v T // valid if Err is nil; invalid if Err is non-nil + err error +} + +// Value returns a new result with value v, +// without an error. +func Value[T any](v T) Of[T] { + return Of[T]{v: v} +} + +// Error returns a new result with error err. +// If err is nil, the returned result is equivalent +// to calling Value with T's zero value. +func Error[T any](err error) Of[T] { + return Of[T]{err: err} +} + +// MustValue returns r's result value. +// It panics if r.Err returns non-nil. +func (r Of[T]) MustValue() T { + if r.err != nil { + panic(r.err) + } + return r.v +} + +// Value returns r's result value and error. +func (r Of[T]) Value() (T, error) { + return r.v, r.err +} + +// Err returns r's error, if any. +// When r.Err returns nil, it's safe to call r.MustValue without it panicking. +func (r Of[T]) Err() error { + return r.err +} diff --git a/types/views/views.go b/types/views/views.go index 5fe88fa6c..252f126a7 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -7,7 +7,8 @@ package views import ( "bytes" - "encoding/json" + "cmp" + jsonv1 "encoding/json" "errors" "fmt" "iter" @@ -15,19 +16,12 @@ import ( "reflect" "slices" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" "go4.org/mem" + "tailscale.com/types/ptr" ) -func unmarshalSliceFromJSON[T any](b []byte, x *[]T) error { - if *x != nil { - return errors.New("already initialized") - } - if len(b) == 0 { - return nil - } - return json.Unmarshal(b, x) -} - // ByteSlice is a read-only accessor for types that are backed by a []byte. type ByteSlice[T ~[]byte] struct { // Đļ is the underlying mutable value, named with a hard-to-type @@ -92,15 +86,32 @@ func (v ByteSlice[T]) SliceTo(i int) ByteSlice[T] { return ByteSlice[T]{v.Đļ[:i] // Slice returns v[i:j] func (v ByteSlice[T]) Slice(i, j int) ByteSlice[T] { return ByteSlice[T]{v.Đļ[i:j]} } -// MarshalJSON implements json.Marshaler. -func (v ByteSlice[T]) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v ByteSlice[T]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v ByteSlice[T]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} -// UnmarshalJSON implements json.Unmarshaler. +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +// It must only be called on an uninitialized ByteSlice. func (v *ByteSlice[T]) UnmarshalJSON(b []byte) error { if v.Đļ != nil { return errors.New("already initialized") } - return json.Unmarshal(b, &v.Đļ) + return jsonv1.Unmarshal(b, &v.Đļ) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +// It must only be called on an uninitialized ByteSlice. +func (v *ByteSlice[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + return jsonv2.UnmarshalDecode(dec, &v.Đļ) } // StructView represents the corresponding StructView of a Viewable. The concrete types are @@ -158,11 +169,35 @@ func (v SliceView[T, V]) All() iter.Seq2[int, V] { } } -// MarshalJSON implements json.Marshaler. -func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } +// MarshalJSON implements [jsonv1.Marshaler]. +func (v SliceView[T, V]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v SliceView[T, V]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} -// UnmarshalJSON implements json.Unmarshaler. -func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { return unmarshalSliceFromJSON(b, &v.Đļ) } +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +// It must only be called on an uninitialized SliceView. +func (v *SliceView[T, V]) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } else if len(b) == 0 { + return nil + } + return jsonv1.Unmarshal(b, &v.Đļ) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +// It must only be called on an uninitialized SliceView. +func (v *SliceView[T, V]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + return jsonv2.UnmarshalDecode(dec, &v.Đļ) +} // IsNil reports whether the underlying slice is nil. func (v SliceView[T, V]) IsNil() bool { return v.Đļ == nil } @@ -251,14 +286,34 @@ func SliceOf[T any](x []T) Slice[T] { return Slice[T]{x} } -// MarshalJSON implements json.Marshaler. +// MarshalJSON implements [jsonv1.Marshaler]. func (v Slice[T]) MarshalJSON() ([]byte, error) { - return json.Marshal(v.Đļ) + return jsonv1.Marshal(v.Đļ) } -// UnmarshalJSON implements json.Unmarshaler. +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v Slice[T]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +// It must only be called on an uninitialized Slice. func (v *Slice[T]) UnmarshalJSON(b []byte) error { - return unmarshalSliceFromJSON(b, &v.Đļ) + if v.Đļ != nil { + return errors.New("already initialized") + } else if len(b) == 0 { + return nil + } + return jsonv1.Unmarshal(b, &v.Đļ) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +// It must only be called on an uninitialized Slice. +func (v *Slice[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + return jsonv2.UnmarshalDecode(dec, &v.Đļ) } // IsNil reports whether the underlying slice is nil. @@ -309,6 +364,20 @@ func (v Slice[T]) ContainsFunc(f func(T) bool) bool { return slices.ContainsFunc(v.Đļ, f) } +// MaxFunc returns the maximal value in v, using cmp to compare elements. It +// panics if v is empty. If there is more than one maximal element according to +// the cmp function, MaxFunc returns the first one. See also [slices.MaxFunc]. +func (v Slice[T]) MaxFunc(cmp func(a, b T) int) T { + return slices.MaxFunc(v.Đļ, cmp) +} + +// MinFunc returns the minimal value in v, using cmp to compare elements. It +// panics if v is empty. If there is more than one minimal element according to +// the cmp function, MinFunc returns the first one. See also [slices.MinFunc]. +func (v Slice[T]) MinFunc(cmp func(a, b T) int) T { + return slices.MinFunc(v.Đļ, cmp) +} + // AppendStrings appends the string representation of each element in v to dst. func AppendStrings[T fmt.Stringer](dst []string, v Slice[T]) []string { for _, x := range v.Đļ { @@ -329,6 +398,26 @@ func SliceEqual[T comparable](a, b Slice[T]) bool { return slices.Equal(a.Đļ, b.Đļ) } +// SliceMax returns the maximal value in v. It panics if v is empty. For +// floating point T, SliceMax propagates NaNs (any NaN value in v forces the +// output to be NaN). See also [slices.Max]. +func SliceMax[T cmp.Ordered](v Slice[T]) T { + return slices.Max(v.Đļ) +} + +// SliceMin returns the minimal value in v. It panics if v is empty. For +// floating point T, SliceMin propagates NaNs (any NaN value in v forces the +// output to be NaN). See also [slices.Min]. +func SliceMin[T cmp.Ordered](v Slice[T]) T { + return slices.Min(v.Đļ) +} + +// shortOOOLen (short Out-of-Order length) is the slice length at or +// under which we attempt to compare two slices quadratically rather +// than allocating memory for a map in SliceEqualAnyOrder and +// SliceEqualAnyOrderFunc. +const shortOOOLen = 5 + // SliceEqualAnyOrder reports whether a and b contain the same elements, regardless of order. // The underlying slices for a and b can be nil. func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { @@ -346,13 +435,63 @@ func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { return true } - // count the occurrences of remaining values and compare - valueCount := make(map[T]int) - for i, n := diffStart, a.Len(); i < n; i++ { - valueCount[a.At(i)]++ - valueCount[b.At(i)]-- + a, b = a.SliceFrom(diffStart), b.SliceFrom(diffStart) + cmp := func(v T) T { return v } + + // For a small number of items, avoid the allocation of a map and just + // do the quadratic thing. + if a.Len() <= shortOOOLen { + return unorderedSliceEqualAnyOrderSmall(a, b, cmp) + } + return unorderedSliceEqualAnyOrder(a, b, cmp) +} + +// SliceEqualAnyOrderFunc reports whether a and b contain the same elements, +// regardless of order. The underlying slices for a and b can be nil. +// +// The provided function should return a comparable value for each element. +func SliceEqualAnyOrderFunc[T any, V comparable](a, b Slice[T], cmp func(T) V) bool { + if a.Len() != b.Len() { + return false + } + + var diffStart int // beginning index where a and b differ + for n := a.Len(); diffStart < n; diffStart++ { + av := cmp(a.At(diffStart)) + bv := cmp(b.At(diffStart)) + if av != bv { + break + } + } + if diffStart == a.Len() { + return true + } + + a, b = a.SliceFrom(diffStart), b.SliceFrom(diffStart) + // For a small number of items, avoid the allocation of a map and just + // do the quadratic thing. + if a.Len() <= shortOOOLen { + return unorderedSliceEqualAnyOrderSmall(a, b, cmp) + } + return unorderedSliceEqualAnyOrder(a, b, cmp) +} + +// unorderedSliceEqualAnyOrder reports whether a and b contain the same elements +// using a map. The cmp function maps from a T slice element to a comparable +// value. +func unorderedSliceEqualAnyOrder[T any, V comparable](a, b Slice[T], cmp func(T) V) bool { + if a.Len() != b.Len() { + panic("internal error") } - for _, count := range valueCount { + if a.Len() == 0 { + return true + } + m := make(map[V]int) + for i := range a.Len() { + m[cmp(a.At(i))]++ + m[cmp(b.At(i))]-- + } + for _, count := range m { if count != 0 { return false } @@ -360,6 +499,60 @@ func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { return true } +// unorderedSliceEqualAnyOrderSmall reports whether a and b (which must be the +// same length, and shortOOOLen or shorter) contain the same elements (using cmp +// to map from T to a comparable value) in some order. +// +// This is the quadratic-time implementation for small slices that doesn't +// allocate. +func unorderedSliceEqualAnyOrderSmall[T any, V comparable](a, b Slice[T], cmp func(T) V) bool { + if a.Len() != b.Len() || a.Len() > shortOOOLen { + panic("internal error") + } + + // These track which elements in a and b have been matched, so + // that we don't treat arrays with differing number of + // duplicate elements as equal (e.g. [1, 1, 2] and [1, 2, 2]). + var aMatched, bMatched [shortOOOLen]bool + + // Compare each element in a to each element in b + for i := range a.Len() { + av := cmp(a.At(i)) + found := false + for j := range a.Len() { + // Skip elements in b that have already been + // used to match an item in a. + if bMatched[j] { + continue + } + + bv := cmp(b.At(j)) + if av == bv { + // Mark these elements as already + // matched, so that a future loop + // iteration (of a duplicate element) + // doesn't match it again. + aMatched[i] = true + bMatched[j] = true + found = true + break + } + } + if !found { + return false + } + } + + // Verify all elements were matched exactly once. + for i := range a.Len() { + if !aMatched[i] || !bMatched[i] { + return false + } + } + + return true +} + // MapSlice is a view over a map whose values are slices. type MapSlice[K comparable, V any] struct { // Đļ is the underlying mutable value, named with a hard-to-type @@ -401,28 +594,32 @@ func (m MapSlice[K, V]) GetOk(k K) (Slice[V], bool) { return SliceOf(v), ok } -// MarshalJSON implements json.Marshaler. +// MarshalJSON implements [jsonv1.Marshaler]. func (m MapSlice[K, V]) MarshalJSON() ([]byte, error) { - return json.Marshal(m.Đļ) + return jsonv1.Marshal(m.Đļ) } -// UnmarshalJSON implements json.Unmarshaler. +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (m MapSlice[K, V]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, m.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. // It should only be called on an uninitialized Map. func (m *MapSlice[K, V]) UnmarshalJSON(b []byte) error { if m.Đļ != nil { return errors.New("already initialized") } - return json.Unmarshal(b, &m.Đļ) + return jsonv1.Unmarshal(b, &m.Đļ) } -// Range calls f for every k,v pair in the underlying map. -// It stops iteration immediately if f returns false. -func (m MapSlice[K, V]) Range(f MapRangeFn[K, Slice[V]]) { - for k, v := range m.Đļ { - if !f(k, SliceOf(v)) { - return - } +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +// It should only be called on an uninitialized MapSlice. +func (m *MapSlice[K, V]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if m.Đļ != nil { + return errors.New("already initialized") } + return jsonv2.UnmarshalDecode(dec, &m.Đļ) } // AsMap returns a shallow-clone of the underlying map. @@ -440,6 +637,17 @@ func (m MapSlice[K, V]) AsMap() map[K][]V { return out } +// All returns an iterator iterating over the keys and values of m. +func (m MapSlice[K, V]) All() iter.Seq2[K, Slice[V]] { + return func(yield func(K, Slice[V]) bool) { + for k, v := range m.Đļ { + if !yield(k, SliceOf(v)) { + return + } + } + } +} + // Map provides a read-only view of a map. It is the caller's responsibility to // make sure V is immutable. type Map[K comparable, V any] struct { @@ -488,18 +696,32 @@ func (m Map[K, V]) GetOk(k K) (V, bool) { return v, ok } -// MarshalJSON implements json.Marshaler. +// MarshalJSON implements [jsonv1.Marshaler]. func (m Map[K, V]) MarshalJSON() ([]byte, error) { - return json.Marshal(m.Đļ) + return jsonv1.Marshal(m.Đļ) } -// UnmarshalJSON implements json.Unmarshaler. +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (m Map[K, V]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, m.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. // It should only be called on an uninitialized Map. func (m *Map[K, V]) UnmarshalJSON(b []byte) error { if m.Đļ != nil { return errors.New("already initialized") } - return json.Unmarshal(b, &m.Đļ) + return jsonv1.Unmarshal(b, &m.Đļ) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +// It must only be called on an uninitialized Map. +func (m *Map[K, V]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if m.Đļ != nil { + return errors.New("already initialized") + } + return jsonv2.UnmarshalDecode(dec, &m.Đļ) } // AsMap returns a shallow-clone of the underlying map. @@ -512,16 +734,59 @@ func (m Map[K, V]) AsMap() map[K]V { return maps.Clone(m.Đļ) } +// NOTE: the type constraints for MapViewsEqual and MapViewsEqualFunc are based +// on those for maps.Equal and maps.EqualFunc. + +// MapViewsEqual returns whether the two given [Map]s are equal. Both K and V +// must be comparable; if V is non-comparable, use [MapViewsEqualFunc] instead. +func MapViewsEqual[K, V comparable](a, b Map[K, V]) bool { + if a.Len() != b.Len() || a.IsNil() != b.IsNil() { + return false + } + if a.IsNil() { + return true // both nil; can exit early + } + + for k, v := range a.All() { + bv, ok := b.GetOk(k) + if !ok || v != bv { + return false + } + } + return true +} + +// MapViewsEqualFunc returns whether the two given [Map]s are equal, using the +// given function to compare two values. +func MapViewsEqualFunc[K comparable, V1, V2 any](a Map[K, V1], b Map[K, V2], eq func(V1, V2) bool) bool { + if a.Len() != b.Len() || a.IsNil() != b.IsNil() { + return false + } + if a.IsNil() { + return true // both nil; can exit early + } + + for k, v := range a.All() { + bv, ok := b.GetOk(k) + if !ok || !eq(v, bv) { + return false + } + } + return true +} + // MapRangeFn is the func called from a Map.Range call. // Implementations should return false to stop range. type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) -// Range calls f for every k,v pair in the underlying map. -// It stops iteration immediately if f returns false. -func (m Map[K, V]) Range(f MapRangeFn[K, V]) { - for k, v := range m.Đļ { - if !f(k, v) { - return +// All returns an iterator iterating over the keys +// and values of m. +func (m Map[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + for k, v := range m.Đļ { + if !yield(k, v) { + return + } } } } @@ -577,16 +842,111 @@ func (m MapFn[K, T, V]) GetOk(k K) (V, bool) { return m.wrapv(v), ok } -// Range calls f for every k,v pair in the underlying map. -// It stops iteration immediately if f returns false. -func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) { - for k, v := range m.Đļ { - if !f(k, m.wrapv(v)) { - return +// All returns an iterator iterating over the keys and value views of m. +func (m MapFn[K, T, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + for k, v := range m.Đļ { + if !yield(k, m.wrapv(v)) { + return + } } } } +// ValuePointer provides a read-only view of a pointer to a value type, +// such as a primitive type or an immutable struct. Its Value and ValueOk +// methods return a stack-allocated shallow copy of the underlying value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +type ValuePointer[T any] struct { + // Đļ is the underlying value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *T +} + +// Valid reports whether the underlying pointer is non-nil. +func (p ValuePointer[T]) Valid() bool { + return p.Đļ != nil +} + +// Get returns a shallow copy of the value if the underlying pointer is non-nil. +// Otherwise, it returns a zero value. +func (p ValuePointer[T]) Get() T { + v, _ := p.GetOk() + return v +} + +// GetOk returns a shallow copy of the underlying value and true if the underlying +// pointer is non-nil. Otherwise, it returns a zero value and false. +func (p ValuePointer[T]) GetOk() (value T, ok bool) { + if p.Đļ == nil { + return value, false // value holds a zero value + } + return *p.Đļ, true +} + +// GetOr returns a shallow copy of the underlying value if it is non-nil. +// Otherwise, it returns the provided default value. +func (p ValuePointer[T]) GetOr(def T) T { + if p.Đļ == nil { + return def + } + return *p.Đļ +} + +// Clone returns a shallow copy of the underlying value. +func (p ValuePointer[T]) Clone() *T { + if p.Đļ == nil { + return nil + } + return ptr.To(*p.Đļ) +} + +// String implements [fmt.Stringer]. +func (p ValuePointer[T]) String() string { + if p.Đļ == nil { + return "nil" + } + return fmt.Sprint(p.Đļ) +} + +// ValuePointerOf returns an immutable view of a pointer to an immutable value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +func ValuePointerOf[T any](v *T) ValuePointer[T] { + return ValuePointer[T]{v} +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (p ValuePointer[T]) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(p.Đļ) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (p ValuePointer[T]) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, p.Đļ) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +// It must only be called on an uninitialized ValuePointer. +func (p *ValuePointer[T]) UnmarshalJSON(b []byte) error { + if p.Đļ != nil { + return errors.New("already initialized") + } + return jsonv1.Unmarshal(b, &p.Đļ) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +// It must only be called on an uninitialized ValuePointer. +func (p *ValuePointer[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if p.Đļ != nil { + return errors.New("already initialized") + } + return jsonv2.UnmarshalDecode(dec, &p.Đļ) +} + // ContainsPointers reports whether T contains any pointers, // either explicitly or implicitly. // It has special handling for some types that contain pointers diff --git a/types/views/views_test.go b/types/views/views_test.go index ec7dcec4c..5a30c11a1 100644 --- a/types/views/views_test.go +++ b/types/views/views_test.go @@ -4,8 +4,7 @@ package views import ( - "bytes" - "encoding/json" + jsonv1 "encoding/json" "fmt" "net/netip" "reflect" @@ -15,8 +14,27 @@ import ( "unsafe" qt "github.com/frankban/quicktest" + jsonv2 "github.com/go-json-experiment/json" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/types/structs" ) +// Statically verify that each type implements the following interfaces. +var _ = []interface { + jsonv1.Marshaler + jsonv1.Unmarshaler + jsonv2.MarshalerTo + jsonv2.UnmarshalerFrom +}{ + (*ByteSlice[[]byte])(nil), + (*SliceView[*testStruct, testStructView])(nil), + (*Slice[testStruct])(nil), + (*MapSlice[*testStruct, testStructView])(nil), + (*Map[*testStruct, testStructView])(nil), + (*ValuePointer[testStruct])(nil), +} + type viewStruct struct { Int int Addrs Slice[netip.Prefix] @@ -82,14 +100,16 @@ func TestViewsJSON(t *testing.T) { ipp := SliceOf(mustCIDR("192.168.0.0/24")) ss := SliceOf([]string{"bar"}) tests := []struct { - name string - in viewStruct - wantJSON string + name string + in viewStruct + wantJSONv1 string + wantJSONv2 string }{ { - name: "empty", - in: viewStruct{}, - wantJSON: `{"Int":0,"Addrs":null,"Strings":null}`, + name: "empty", + in: viewStruct{}, + wantJSONv1: `{"Int":0,"Addrs":null,"Strings":null}`, + wantJSONv2: `{"Int":0,"Addrs":[],"Strings":[]}`, }, { name: "everything", @@ -100,30 +120,49 @@ func TestViewsJSON(t *testing.T) { StringsPtr: &ss, Strings: ss, }, - wantJSON: `{"Int":1234,"Addrs":["192.168.0.0/24"],"Strings":["bar"],"AddrsPtr":["192.168.0.0/24"],"StringsPtr":["bar"]}`, + wantJSONv1: `{"Int":1234,"Addrs":["192.168.0.0/24"],"Strings":["bar"],"AddrsPtr":["192.168.0.0/24"],"StringsPtr":["bar"]}`, + wantJSONv2: `{"Int":1234,"Addrs":["192.168.0.0/24"],"Strings":["bar"],"AddrsPtr":["192.168.0.0/24"],"StringsPtr":["bar"]}`, }, } - var buf bytes.Buffer - encoder := json.NewEncoder(&buf) - encoder.SetIndent("", "") for _, tc := range tests { - buf.Reset() - if err := encoder.Encode(&tc.in); err != nil { - t.Fatal(err) - } - b := buf.Bytes() - gotJSON := strings.TrimSpace(string(b)) - if tc.wantJSON != gotJSON { - t.Fatalf("JSON: %v; want: %v", gotJSON, tc.wantJSON) - } - var got viewStruct - if err := json.Unmarshal(b, &got); err != nil { - t.Fatal(err) - } - if !reflect.DeepEqual(got, tc.in) { - t.Fatalf("unmarshal resulted in different output: %+v; want %+v", got, tc.in) + cmpOpts := cmp.Options{ + cmp.AllowUnexported(Slice[string]{}), + cmp.AllowUnexported(Slice[netip.Prefix]{}), + cmpopts.EquateComparable(netip.Prefix{}), } + t.Run("JSONv1", func(t *testing.T) { + gotJSON, err := jsonv1.Marshal(tc.in) + if err != nil { + t.Fatal(err) + } + if string(gotJSON) != tc.wantJSONv1 { + t.Fatalf("JSON: %s; want: %s", gotJSON, tc.wantJSONv1) + } + var got viewStruct + if err := jsonv1.Unmarshal(gotJSON, &got); err != nil { + t.Fatal(err) + } + if d := cmp.Diff(got, tc.in, cmpOpts); d != "" { + t.Fatalf("unmarshal mismatch (-got +want):\n%s", d) + } + }) + t.Run("JSONv2", func(t *testing.T) { + gotJSON, err := jsonv2.Marshal(tc.in) + if err != nil { + t.Fatal(err) + } + if string(gotJSON) != tc.wantJSONv2 { + t.Fatalf("JSON: %s; want: %s", gotJSON, tc.wantJSONv2) + } + var got viewStruct + if err := jsonv2.Unmarshal(gotJSON, &got); err != nil { + t.Fatal(err) + } + if d := cmp.Diff(got, tc.in, cmpOpts, cmpopts.EquateEmpty()); d != "" { + t.Fatalf("unmarshal mismatch (-got +want):\n%s", d) + } + }) } } @@ -152,6 +191,161 @@ func TestViewUtils(t *testing.T) { qt.Equals, true) } +func TestSliceEqualAnyOrderFunc(t *testing.T) { + type nc struct { + _ structs.Incomparable + v string + } + + // ncFrom returns a Slice[nc] from a slice of []string + ncFrom := func(s ...string) Slice[nc] { + var out []nc + for _, v := range s { + out = append(out, nc{v: v}) + } + return SliceOf(out) + } + + // cmp returns a comparable value for a nc + cmp := func(a nc) string { return a.v } + + v := ncFrom("foo", "bar") + c := qt.New(t) + + // Simple case of slice equal to itself. + c.Check(SliceEqualAnyOrderFunc(v, v, cmp), qt.Equals, true) + + // Different order. + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("bar", "foo"), cmp), qt.Equals, true) + + // Different values, same length + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("foo", "baz"), cmp), qt.Equals, false) + + // Different values, different length + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("foo"), cmp), qt.Equals, false) + + // Nothing shared + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("baz", "qux"), cmp), qt.Equals, false) + + // Long slice that matches + longSlice := ncFrom("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + longSame := ncFrom("b", "a", "c", "d", "e", "f", "g", "h", "i", "j") // first 2 elems swapped + c.Check(SliceEqualAnyOrderFunc(longSlice, longSame, cmp), qt.Equals, true) + + // Long difference; past the quadratic limit + longDiff := ncFrom("b", "a", "c", "d", "e", "f", "g", "h", "i", "k") // differs at end + c.Check(SliceEqualAnyOrderFunc(longSlice, longDiff, cmp), qt.Equals, false) + + // The short slice optimization had a bug where it wouldn't handle + // duplicate elements; test various cases here driven by code coverage. + shortTestCases := []struct { + name string + s1, s2 Slice[nc] + want bool + }{ + { + name: "duplicates_same_length", + s1: ncFrom("a", "a", "b"), + s2: ncFrom("a", "b", "b"), + want: false, + }, + { + name: "duplicates_different_matched", + s1: ncFrom("x", "y", "a", "a", "b"), + s2: ncFrom("x", "y", "b", "a", "a"), + want: true, + }, + { + name: "item_in_a_not_b", + s1: ncFrom("x", "y", "a", "b", "c"), + s2: ncFrom("x", "y", "b", "c", "q"), + want: false, + }, + } + for _, tc := range shortTestCases { + t.Run("short_"+tc.name, func(t *testing.T) { + c.Check(SliceEqualAnyOrderFunc(tc.s1, tc.s2, cmp), qt.Equals, tc.want) + }) + } +} + +func TestSliceEqualAnyOrderAllocs(t *testing.T) { + ss := func(s ...string) Slice[string] { return SliceOf(s) } + cmp := func(s string) string { return s } + + t.Run("no-allocs-short-unordered", func(t *testing.T) { + // No allocations for short comparisons + short1 := ss("a", "b", "c") + short2 := ss("c", "b", "a") + if n := testing.AllocsPerRun(1000, func() { + if !SliceEqualAnyOrder(short1, short2) { + t.Fatal("not equal") + } + if !SliceEqualAnyOrderFunc(short1, short2, cmp) { + t.Fatal("not equal") + } + }); n > 0 { + t.Fatalf("allocs = %v; want 0", n) + } + }) + + t.Run("no-allocs-long-match", func(t *testing.T) { + long1 := ss("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + long2 := ss("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + + if n := testing.AllocsPerRun(1000, func() { + if !SliceEqualAnyOrder(long1, long2) { + t.Fatal("not equal") + } + if !SliceEqualAnyOrderFunc(long1, long2, cmp) { + t.Fatal("not equal") + } + }); n > 0 { + t.Fatalf("allocs = %v; want 0", n) + } + }) + + t.Run("allocs-long-unordered", func(t *testing.T) { + // We do unfortunately allocate for long comparisons. + long1 := ss("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + long2 := ss("c", "b", "a", "e", "d", "f", "g", "h", "i", "j") + + if n := testing.AllocsPerRun(1000, func() { + if !SliceEqualAnyOrder(long1, long2) { + t.Fatal("not equal") + } + if !SliceEqualAnyOrderFunc(long1, long2, cmp) { + t.Fatal("not equal") + } + }); n == 0 { + t.Fatalf("unexpectedly didn't allocate") + } + }) +} + +func BenchmarkSliceEqualAnyOrder(b *testing.B) { + b.Run("short", func(b *testing.B) { + b.ReportAllocs() + s1 := SliceOf([]string{"foo", "bar"}) + s2 := SliceOf([]string{"bar", "foo"}) + for range b.N { + if !SliceEqualAnyOrder(s1, s2) { + b.Fatal() + } + } + }) + b.Run("long", func(b *testing.B) { + b.ReportAllocs() + s1 := SliceOf([]string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}) + s2 := SliceOf([]string{"c", "b", "a", "e", "d", "f", "g", "h", "i", "j"}) + for range b.N { + if !SliceEqualAnyOrder(s1, s2) { + b.Fatal() + } + } + }) +} + func TestSliceEqual(t *testing.T) { a := SliceOf([]string{"foo", "bar"}) b := SliceOf([]string{"foo", "bar"}) @@ -446,6 +640,7 @@ func (v testStructView) AsStruct() *testStruct { } return v.p.Clone() } +func (v testStructView) ValueForTest() string { return v.p.value } func TestSliceViewRange(t *testing.T) { vs := SliceOfViews([]*testStruct{{value: "foo"}, {value: "bar"}}) @@ -458,3 +653,129 @@ func TestSliceViewRange(t *testing.T) { t.Errorf("got %q; want %q", got, want) } } + +func TestMapIter(t *testing.T) { + m := MapOf(map[string]int{"foo": 1, "bar": 2}) + var got []string + for k, v := range m.All() { + got = append(got, fmt.Sprintf("%s-%d", k, v)) + } + slices.Sort(got) + want := []string{"bar-2", "foo-1"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestMapSliceIter(t *testing.T) { + m := MapSliceOf(map[string][]int{"foo": {3, 4}, "bar": {1, 2}}) + var got []string + for k, v := range m.All() { + got = append(got, fmt.Sprintf("%s-%d", k, v)) + } + slices.Sort(got) + want := []string{"bar-{[1 2]}", "foo-{[3 4]}"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestMapFnIter(t *testing.T) { + m := MapFnOf[string, *testStruct, testStructView](map[string]*testStruct{ + "foo": {value: "fooVal"}, + "bar": {value: "barVal"}, + }, func(p *testStruct) testStructView { return testStructView{p} }) + var got []string + for k, v := range m.All() { + got = append(got, fmt.Sprintf("%v-%v", k, v.ValueForTest())) + } + slices.Sort(got) + want := []string{"bar-barVal", "foo-fooVal"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestMapViewsEqual(t *testing.T) { + testCases := []struct { + name string + a, b map[string]string + want bool + }{ + { + name: "both_nil", + a: nil, + b: nil, + want: true, + }, + { + name: "both_empty", + a: map[string]string{}, + b: map[string]string{}, + want: true, + }, + { + name: "one_nil", + a: nil, + b: map[string]string{"a": "1"}, + want: false, + }, + { + name: "different_length", + a: map[string]string{"a": "1"}, + b: map[string]string{"a": "1", "b": "2"}, + want: false, + }, + { + name: "different_values", + a: map[string]string{"a": "1"}, + b: map[string]string{"a": "2"}, + want: false, + }, + { + name: "different_keys", + a: map[string]string{"a": "1"}, + b: map[string]string{"b": "1"}, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := MapViewsEqual(MapOf(tc.a), MapOf(tc.b)) + if got != tc.want { + t.Errorf("MapViewsEqual: got=%v, want %v", got, tc.want) + } + + got = MapViewsEqualFunc(MapOf(tc.a), MapOf(tc.b), func(a, b string) bool { + return a == b + }) + if got != tc.want { + t.Errorf("MapViewsEqualFunc: got=%v, want %v", got, tc.want) + } + }) + } +} + +func TestMapViewsEqualFunc(t *testing.T) { + // Test that we can compare maps with two different non-comparable + // values using a custom comparison function. + type customStruct1 struct { + _ structs.Incomparable + Field1 string + } + type customStruct2 struct { + _ structs.Incomparable + Field2 string + } + + a := map[string]customStruct1{"a": {Field1: "1"}} + b := map[string]customStruct2{"a": {Field2: "1"}} + + got := MapViewsEqualFunc(MapOf(a), MapOf(b), func(a customStruct1, b customStruct2) bool { + return a.Field1 == b.Field2 + }) + if !got { + t.Errorf("MapViewsEqualFunc: got=%v, want true", got) + } +} diff --git a/update-flake.sh b/update-flake.sh index 4561183b8..c22572b86 100755 --- a/update-flake.sh +++ b/update-flake.sh @@ -10,6 +10,14 @@ rm -rf "$OUT" ./tool/go run tailscale.com/cmd/nardump --sri "$OUT" >go.mod.sri rm -rf "$OUT" +GOOUT=$(mktemp -d -t gocross-XXXXXX) +GOREV=$(xargs < ./go.toolchain.rev) +TARBALL="$GOOUT/go-$GOREV.tar.gz" +curl -Ls -o "$TARBALL" "https://github.com/tailscale/go/archive/$GOREV.tar.gz" +tar -xzf "$TARBALL" -C "$GOOUT" +./tool/go run tailscale.com/cmd/nardump --sri "$GOOUT/go-$GOREV" > go.toolchain.rev.sri +rm -rf "$GOOUT" + # nix-direnv only watches the top-level nix file for changes. As a # result, when we change a referenced SRI file, we have to cause some # change to shell.nix and flake.nix as well, so that nix-direnv diff --git a/logtail/backoff/backoff.go b/util/backoff/backoff.go similarity index 94% rename from logtail/backoff/backoff.go rename to util/backoff/backoff.go index c6aeae998..95089fc24 100644 --- a/logtail/backoff/backoff.go +++ b/util/backoff/backoff.go @@ -78,3 +78,9 @@ func (b *Backoff) BackOff(ctx context.Context, err error) { case <-tChannel: } } + +// Reset resets the backoff schedule, equivalent to calling BackOff with a nil +// error. +func (b *Backoff) Reset() { + b.n = 0 +} diff --git a/util/cache/cache_test.go b/util/cache/cache_test.go deleted file mode 100644 index a6683e12d..000000000 --- a/util/cache/cache_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -import ( - "errors" - "testing" - "time" -) - -var startTime = time.Date(2023, time.March, 1, 0, 0, 0, 0, time.UTC) - -func TestSingleCache(t *testing.T) { - testTime := startTime - timeNow := func() time.Time { return testTime } - c := &Single[string, int]{ - timeNow: timeNow, - } - - t.Run("NoServeExpired", func(t *testing.T) { - testCacheImpl(t, c, &testTime, false) - }) - - t.Run("ServeExpired", func(t *testing.T) { - c.Empty() - c.ServeExpired = true - testTime = startTime - testCacheImpl(t, c, &testTime, true) - }) -} - -func TestLocking(t *testing.T) { - testTime := startTime - timeNow := func() time.Time { return testTime } - c := NewLocking(&Single[string, int]{ - timeNow: timeNow, - }) - - // Just verify that the inner cache's behaviour hasn't changed. - testCacheImpl(t, c, &testTime, false) -} - -func testCacheImpl(t *testing.T, c Cache[string, int], testTime *time.Time, serveExpired bool) { - var fillTime time.Time - t.Run("InitialFill", func(t *testing.T) { - fillTime = testTime.Add(time.Hour) - val, err := c.Get("key", func() (int, time.Time, error) { - return 123, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - }) - - // Fetching again won't call our fill function - t.Run("SecondFetch", func(t *testing.T) { - *testTime = fillTime.Add(-1 * time.Second) - called := false - val, err := c.Get("key", func() (int, time.Time, error) { - called = true - return -1, fillTime, nil - }) - if called { - t.Fatal("wanted no call to fill function") - } - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - }) - - // Fetching after the expiry time will re-fill - t.Run("ReFill", func(t *testing.T) { - *testTime = fillTime.Add(1) - fillTime = fillTime.Add(time.Hour) - val, err := c.Get("key", func() (int, time.Time, error) { - return 999, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 999 { - t.Fatalf("got val=%d; want 999", val) - } - }) - - // An error on fetch will serve the expired value. - t.Run("FetchError", func(t *testing.T) { - if !serveExpired { - t.Skipf("not testing ServeExpired") - } - - *testTime = fillTime.Add(time.Hour + 1) - val, err := c.Get("key", func() (int, time.Time, error) { - return 0, time.Time{}, errors.New("some error") - }) - if err != nil { - t.Fatal(err) - } - if val != 999 { - t.Fatalf("got val=%d; want 999", val) - } - }) - - // Fetching a different key re-fills - t.Run("DifferentKey", func(t *testing.T) { - *testTime = fillTime.Add(time.Hour + 1) - - var calls int - val, err := c.Get("key1", func() (int, time.Time, error) { - calls++ - return 123, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - if calls != 1 { - t.Errorf("got %d, want 1 call", calls) - } - - val, err = c.Get("key2", func() (int, time.Time, error) { - calls++ - return 456, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 456 { - t.Fatalf("got val=%d; want 456", val) - } - if calls != 2 { - t.Errorf("got %d, want 2 call", calls) - } - }) - - // Calling Forget with the wrong key does nothing, and with the correct - // key will drop the cache. - t.Run("Forget", func(t *testing.T) { - // Add some time so that previously-cached values don't matter. - fillTime = testTime.Add(2 * time.Hour) - *testTime = fillTime.Add(-1 * time.Second) - - const key = "key" - - var calls int - val, err := c.Get(key, func() (int, time.Time, error) { - calls++ - return 123, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - if calls != 1 { - t.Errorf("got %d, want 1 call", calls) - } - - // Forgetting the wrong key does nothing - c.Forget("other") - val, err = c.Get(key, func() (int, time.Time, error) { - t.Fatal("should not be called") - panic("unreachable") - }) - if err != nil { - t.Fatal(err) - } - if val != 123 { - t.Fatalf("got val=%d; want 123", val) - } - - // Forgetting the correct key re-fills - c.Forget(key) - - val, err = c.Get("key2", func() (int, time.Time, error) { - calls++ - return 456, fillTime, nil - }) - if err != nil { - t.Fatal(err) - } - if val != 456 { - t.Fatalf("got val=%d; want 456", val) - } - if calls != 2 { - t.Errorf("got %d, want 2 call", calls) - } - }) -} diff --git a/util/cache/interface.go b/util/cache/interface.go deleted file mode 100644 index 0db87ba0e..000000000 --- a/util/cache/interface.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package cache contains an interface for a cache around a typed value, and -// various cache implementations that implement that interface. -package cache - -import "time" - -// Cache is the interface for the cache types in this package. -// -// Functions in this interface take a key parameter, but it is valid for a -// cache type to hold a single value associated with a key, and simply drop the -// cached value if provided with a different key. -// -// It is valid for Cache implementations to be concurrency-safe or not, and -// each implementation should document this. If you need a concurrency-safe -// cache, an existing cache can be wrapped with a lock using NewLocking(inner). -// -// K and V should be types that can be successfully passed to json.Marshal. -type Cache[K comparable, V any] interface { - // Get should return a previously-cached value or call the provided - // FillFunc to obtain a new one. The provided key can be used either to - // allow multiple cached values, or to drop the cache if the key - // changes; either is valid. - Get(K, FillFunc[V]) (V, error) - - // Forget should remove the given key from the cache, if it is present. - // If it is not present, nothing should be done. - Forget(K) - - // Empty should empty the cache such that the next call to Get should - // call the provided FillFunc for all possible keys. - Empty() -} - -// FillFunc is the signature of a function for filling a cache. It should -// return the value to be cached, the time that the cached value is valid -// until, or an error. -type FillFunc[T any] func() (T, time.Time, error) diff --git a/util/cache/locking.go b/util/cache/locking.go deleted file mode 100644 index 85e44b360..000000000 --- a/util/cache/locking.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -import "sync" - -// Locking wraps an inner Cache implementation with a mutex, making it -// safe for concurrent use. All methods are serialized on the same mutex. -type Locking[K comparable, V any, C Cache[K, V]] struct { - sync.Mutex - inner C -} - -// NewLocking creates a new Locking cache wrapping inner. -func NewLocking[K comparable, V any, C Cache[K, V]](inner C) *Locking[K, V, C] { - return &Locking[K, V, C]{inner: inner} -} - -// Get implements Cache. -// -// The cache's mutex is held for the entire duration of this function, -// including while the FillFunc is being called. This function is not -// reentrant; attempting to call Get from a FillFunc will deadlock. -func (c *Locking[K, V, C]) Get(key K, f FillFunc[V]) (V, error) { - c.Lock() - defer c.Unlock() - return c.inner.Get(key, f) -} - -// Forget implements Cache. -func (c *Locking[K, V, C]) Forget(key K) { - c.Lock() - defer c.Unlock() - c.inner.Forget(key) -} - -// Empty implements Cache. -func (c *Locking[K, V, C]) Empty() { - c.Lock() - defer c.Unlock() - c.inner.Empty() -} diff --git a/util/cache/none.go b/util/cache/none.go deleted file mode 100644 index c4073e0d9..000000000 --- a/util/cache/none.go +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -// None provides no caching and always calls the provided FillFunc. -// -// It is safe for concurrent use if the underlying FillFunc is. -type None[K comparable, V any] struct{} - -var _ Cache[int, int] = None[int, int]{} - -// Get always calls the provided FillFunc and returns what it does. -func (c None[K, V]) Get(_ K, f FillFunc[V]) (V, error) { - v, _, e := f() - return v, e -} - -// Forget implements Cache. -func (None[K, V]) Forget(K) {} - -// Empty implements Cache. -func (None[K, V]) Empty() {} diff --git a/util/cache/single.go b/util/cache/single.go deleted file mode 100644 index 6b9ac2c11..000000000 --- a/util/cache/single.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package cache - -import ( - "time" -) - -// Single is a simple in-memory cache that stores a single value until a -// defined time before it is re-fetched. It also supports returning a -// previously-expired value if refreshing the value in the cache fails. -// -// Single is not safe for concurrent use. -type Single[K comparable, V any] struct { - key K - val V - goodUntil time.Time - timeNow func() time.Time // for tests - - // ServeExpired indicates that if an error occurs when filling the - // cache, an expired value can be returned instead of an error. - // - // This value should only be set when this struct is created. - ServeExpired bool -} - -var _ Cache[int, int] = (*Single[int, int])(nil) - -// Get will return the cached value, if any, or fill the cache by calling f and -// return the corresponding value. If f returns an error and c.ServeExpired is -// true, then a previous expired value can be returned with no error. -func (c *Single[K, V]) Get(key K, f FillFunc[V]) (V, error) { - var now time.Time - if c.timeNow != nil { - now = c.timeNow() - } else { - now = time.Now() - } - - if c.key == key && now.Before(c.goodUntil) { - return c.val, nil - } - - // Re-fill cached entry - val, until, err := f() - if err == nil { - c.key = key - c.val = val - c.goodUntil = until - return val, nil - } - - // Never serve an expired entry for the wrong key. - if c.key == key && c.ServeExpired && !c.goodUntil.IsZero() { - return c.val, nil - } - - var zero V - return zero, err -} - -// Forget implements Cache. -func (c *Single[K, V]) Forget(key K) { - if c.key != key { - return - } - - c.Empty() -} - -// Empty implements Cache. -func (c *Single[K, V]) Empty() { - c.goodUntil = time.Time{} - - var zeroKey K - c.key = zeroKey - - var zeroVal V - c.val = zeroVal -} diff --git a/util/checkchange/checkchange.go b/util/checkchange/checkchange.go new file mode 100644 index 000000000..8ba64720d --- /dev/null +++ b/util/checkchange/checkchange.go @@ -0,0 +1,25 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package checkchange defines a utility for determining whether a value +// has changed since the last time it was checked. +package checkchange + +// EqualCloner is an interface for types that can be compared for equality +// and can be cloned. +type EqualCloner[T any] interface { + Equal(T) bool + Clone() T +} + +// Update sets *old to a clone of new if they are not equal, returning whether +// they were different. +// +// It only modifies *old if they are different. old must be non-nil. +func Update[T EqualCloner[T]](old *T, new T) (changed bool) { + if (*old).Equal(new) { + return false + } + *old = new.Clone() + return true +} diff --git a/util/clientmetric/clientmetric.go b/util/clientmetric/clientmetric.go index b2d356b60..9e6b03a15 100644 --- a/util/clientmetric/clientmetric.go +++ b/util/clientmetric/clientmetric.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_clientmetrics + // Package clientmetric provides client-side metrics whose values // get occasionally logged. package clientmetric @@ -9,6 +11,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "expvar" "fmt" "io" "sort" @@ -16,6 +19,9 @@ import ( "sync" "sync/atomic" "time" + + "tailscale.com/feature/buildfeatures" + "tailscale.com/util/set" ) var ( @@ -127,15 +133,20 @@ func (m *Metric) Publish() { metrics[m.name] = m sortedDirty = true - if m.f != nil { - lastLogVal = append(lastLogVal, scanEntry{f: m.f}) - } else { + if m.f == nil { if len(valFreeList) == 0 { valFreeList = make([]int64, 256) } m.v = &valFreeList[0] valFreeList = valFreeList[1:] - lastLogVal = append(lastLogVal, scanEntry{v: m.v}) + } + + if buildfeatures.HasLogTail { + if m.f != nil { + lastLogVal = append(lastLogVal, scanEntry{f: m.f}) + } else { + lastLogVal = append(lastLogVal, scanEntry{v: m.v}) + } } m.regIdx = len(unsorted) @@ -223,6 +234,54 @@ func NewGaugeFunc(name string, f func() int64) *Metric { return m } +// AggregateCounter returns a sum of expvar counters registered with it. +type AggregateCounter struct { + mu sync.RWMutex + counters set.Set[*expvar.Int] +} + +func (c *AggregateCounter) Value() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + var sum int64 + for cnt := range c.counters { + sum += cnt.Value() + } + return sum +} + +// Register registers provided expvar counter. +// When a counter is added to the counter, it will be reset +// to start counting from 0. This is to avoid incrementing the +// counter with an unexpectedly large value. +func (c *AggregateCounter) Register(counter *expvar.Int) { + c.mu.Lock() + defer c.mu.Unlock() + // No need to do anything if it's already registered. + if c.counters.Contains(counter) { + return + } + counter.Set(0) + c.counters.Add(counter) +} + +// UnregisterAll unregisters all counters resulting in it +// starting back down at zero. This is to ensure monotonicity +// and respect the semantics of the counter. +func (c *AggregateCounter) UnregisterAll() { + c.mu.Lock() + defer c.mu.Unlock() + c.counters = set.Set[*expvar.Int]{} +} + +// NewAggregateCounter returns a new aggregate counter that returns +// a sum of expvar variables registered with it. +func NewAggregateCounter(name string) *AggregateCounter { + c := &AggregateCounter{counters: set.Set[*expvar.Int]{}} + NewCounterFunc(name, c.Value) + return c +} + // WritePrometheusExpositionFormat writes all client metrics to w in // the Prometheus text-based exposition format. // @@ -268,6 +327,9 @@ const ( // - increment a metric: (decrements if negative) // 'I' + hex(varint(wireid)) + hex(varint(value)) func EncodeLogTailMetricsDelta() string { + if !buildfeatures.HasLogTail { + return "" + } mu.Lock() defer mu.Unlock() diff --git a/util/clientmetric/clientmetric_test.go b/util/clientmetric/clientmetric_test.go index ab6c4335a..555d7a711 100644 --- a/util/clientmetric/clientmetric_test.go +++ b/util/clientmetric/clientmetric_test.go @@ -4,8 +4,11 @@ package clientmetric import ( + "expvar" "testing" "time" + + qt "github.com/frankban/quicktest" ) func TestDeltaEncBuf(t *testing.T) { @@ -107,3 +110,49 @@ func TestWithFunc(t *testing.T) { t.Errorf("second = %q; want %q", got, want) } } + +func TestAggregateCounter(t *testing.T) { + clearMetrics() + + c := qt.New(t) + + expv1 := &expvar.Int{} + expv2 := &expvar.Int{} + expv3 := &expvar.Int{} + + aggCounter := NewAggregateCounter("agg_counter") + + aggCounter.Register(expv1) + c.Assert(aggCounter.Value(), qt.Equals, int64(0)) + + expv1.Add(1) + c.Assert(aggCounter.Value(), qt.Equals, int64(1)) + + aggCounter.Register(expv2) + c.Assert(aggCounter.Value(), qt.Equals, int64(1)) + + expv1.Add(1) + expv2.Add(1) + c.Assert(aggCounter.Value(), qt.Equals, int64(3)) + + // Adding a new expvar should not change the value + // and any value the counter already had is reset + expv3.Set(5) + aggCounter.Register(expv3) + c.Assert(aggCounter.Value(), qt.Equals, int64(3)) + + // Registering the same expvar multiple times should not change the value + aggCounter.Register(expv3) + c.Assert(aggCounter.Value(), qt.Equals, int64(3)) + + aggCounter.UnregisterAll() + c.Assert(aggCounter.Value(), qt.Equals, int64(0)) + + // Start over + expv3.Set(5) + aggCounter.Register(expv3) + c.Assert(aggCounter.Value(), qt.Equals, int64(0)) + + expv3.Set(5) + c.Assert(aggCounter.Value(), qt.Equals, int64(5)) +} diff --git a/util/clientmetric/omit.go b/util/clientmetric/omit.go new file mode 100644 index 000000000..5349fc724 --- /dev/null +++ b/util/clientmetric/omit.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_clientmetrics + +package clientmetric + +type Metric struct{} + +func (*Metric) Add(int64) {} +func (*Metric) Set(int64) {} +func (*Metric) Value() int64 { return 0 } +func (*Metric) Register(expvarInt any) {} +func (*Metric) UnregisterAll() {} + +func HasPublished(string) bool { panic("unreachable") } +func EncodeLogTailMetricsDelta() string { return "" } +func WritePrometheusExpositionFormat(any) {} + +var zeroMetric Metric + +func NewCounter(string) *Metric { return &zeroMetric } +func NewGauge(string) *Metric { return &zeroMetric } +func NewAggregateCounter(string) *Metric { return &zeroMetric } diff --git a/util/cloudenv/cloudenv.go b/util/cloudenv/cloudenv.go index be60ca007..f55f7dfb0 100644 --- a/util/cloudenv/cloudenv.go +++ b/util/cloudenv/cloudenv.go @@ -16,6 +16,7 @@ import ( "strings" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/syncs" "tailscale.com/types/lazy" ) @@ -51,6 +52,9 @@ const ( // ResolverIP returns the cloud host's recursive DNS server or the // empty string if not available. func (c Cloud) ResolverIP() string { + if !buildfeatures.HasCloud { + return "" + } switch c { case GCP: return GoogleMetadataAndDNSIP @@ -92,6 +96,9 @@ var cloudAtomic syncs.AtomicValue[Cloud] // Get returns the current cloud, or the empty string if unknown. func Get() Cloud { + if !buildfeatures.HasCloud { + return "" + } if c, ok := cloudAtomic.LoadOk(); ok { return c } diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index d998d925d..ec02d652b 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -85,23 +85,35 @@ func NewImportTracker(thisPkg *types.Package) *ImportTracker { } } +type namePkgPath struct { + name string // optional import name + pkgPath string +} + // ImportTracker provides a mechanism to track and build import paths. type ImportTracker struct { thisPkg *types.Package - packages map[string]bool + packages map[namePkgPath]bool } -func (it *ImportTracker) Import(pkg string) { - if pkg != "" && !it.packages[pkg] { - mak.Set(&it.packages, pkg, true) +// Import imports pkgPath under an optional import name. +func (it *ImportTracker) Import(name, pkgPath string) { + if pkgPath != "" && !it.packages[namePkgPath{name, pkgPath}] { + mak.Set(&it.packages, namePkgPath{name, pkgPath}, true) } } +// Has reports whether the specified package path has been imported +// under the particular import name. +func (it *ImportTracker) Has(name, pkgPath string) bool { + return it.packages[namePkgPath{name, pkgPath}] +} + func (it *ImportTracker) qualifier(pkg *types.Package) string { if it.thisPkg == pkg { return "" } - it.Import(pkg.Path()) + it.Import("", pkg.Path()) // TODO(maisem): handle conflicts? return pkg.Name() } @@ -123,7 +135,11 @@ func (it *ImportTracker) PackagePrefix(pkg *types.Package) string { func (it *ImportTracker) Write(w io.Writer) { fmt.Fprintf(w, "import (\n") for s := range it.packages { - fmt.Fprintf(w, "\t%q\n", s) + if s.name == "" { + fmt.Fprintf(w, "\t%q\n", s.pkgPath) + } else { + fmt.Fprintf(w, "\t%s %q\n", s.name, s.pkgPath) + } } fmt.Fprintf(w, ")\n\n") } @@ -272,11 +288,16 @@ func IsInvalid(t types.Type) bool { // It has special handling for some types that contain pointers // that we know are free from memory aliasing/mutation concerns. func ContainsPointers(typ types.Type) bool { - switch typ.String() { + s := typ.String() + switch s { case "time.Time": - // time.Time contains a pointer that does not need copying + // time.Time contains a pointer that does not need cloning. return false - case "inet.af/netip.Addr", "net/netip.Addr", "net/netip.Prefix", "net/netip.AddrPort": + case "inet.af/netip.Addr": + return false + } + if strings.HasPrefix(s, "unique.Handle[") { + // unique.Handle contains a pointer that does not need cloning. return false } switch ft := typ.Underlying().(type) { diff --git a/util/codegen/codegen_test.go b/util/codegen/codegen_test.go index 28ddaed2b..74715eeca 100644 --- a/util/codegen/codegen_test.go +++ b/util/codegen/codegen_test.go @@ -10,6 +10,8 @@ import ( "strings" "sync" "testing" + "time" + "unique" "unsafe" "golang.org/x/exp/constraints" @@ -84,6 +86,16 @@ type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct { V T } +type StructWithUniqueHandle struct{ _ unique.Handle[[32]byte] } + +type StructWithTime struct{ _ time.Time } + +type StructWithNetipTypes struct { + _ netip.Addr + _ netip.AddrPort + _ netip.Prefix +} + type Interface interface { Method() } @@ -161,6 +173,18 @@ func TestGenericContainsPointers(t *testing.T) { typ: "PointerUnionParam", wantPointer: true, }, + { + typ: "StructWithUniqueHandle", + wantPointer: false, + }, + { + typ: "StructWithTime", + wantPointer: false, + }, + { + typ: "StructWithNetipTypes", + wantPointer: false, + }, } for _, tt := range tests { diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go index 464dc5dc3..4d1d0a98b 100644 --- a/util/cstruct/cstruct.go +++ b/util/cstruct/cstruct.go @@ -6,10 +6,9 @@ package cstruct import ( + "encoding/binary" "errors" "io" - - "github.com/josharian/native" ) // Size of a pointer-typed value, in bits @@ -120,7 +119,7 @@ func (d *Decoder) Uint16() uint16 { d.err = err return 0 } - return native.Endian.Uint16(d.dbuf[0:2]) + return binary.NativeEndian.Uint16(d.dbuf[0:2]) } // Uint32 returns a uint32 decoded from the buffer. @@ -133,7 +132,7 @@ func (d *Decoder) Uint32() uint32 { d.err = err return 0 } - return native.Endian.Uint32(d.dbuf[0:4]) + return binary.NativeEndian.Uint32(d.dbuf[0:4]) } // Uint64 returns a uint64 decoded from the buffer. @@ -146,7 +145,7 @@ func (d *Decoder) Uint64() uint64 { d.err = err return 0 } - return native.Endian.Uint64(d.dbuf[0:8]) + return binary.NativeEndian.Uint64(d.dbuf[0:8]) } // Uintptr returns a uintptr decoded from the buffer. diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index d5584def3..413893ff9 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -23,18 +23,11 @@ import ( "go4.org/mem" "go4.org/netipx" "tailscale.com/tailcfg" - "tailscale.com/types/dnstype" - "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/ptr" - "tailscale.com/types/views" "tailscale.com/util/deephash/testtype" - "tailscale.com/util/dnsname" "tailscale.com/util/hashx" "tailscale.com/version" - "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/router" - "tailscale.com/wgengine/wgcfg" ) type appendBytes []byte @@ -197,21 +190,6 @@ func TestHash(t *testing.T) { } } -func TestDeepHash(t *testing.T) { - // v contains the types of values we care about for our current callers. - // Mostly we're just testing that we don't panic on handled types. - v := getVal() - hash1 := Hash(v) - t.Logf("hash: %v", hash1) - for range 20 { - v := getVal() - hash2 := Hash(v) - if hash1 != hash2 { - t.Error("second hash didn't match") - } - } -} - // Tests that we actually hash map elements. Whoops. func TestIssue4868(t *testing.T) { m1 := map[int]string{1: "foo"} @@ -255,110 +233,6 @@ func TestQuick(t *testing.T) { } } -type tailscaleTypes struct { - WGConfig *wgcfg.Config - RouterConfig *router.Config - MapFQDNAddrs map[dnsname.FQDN][]netip.Addr - MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort - MapDiscoPublics map[key.DiscoPublic]bool - MapResponse *tailcfg.MapResponse - FilterMatch filter.Match -} - -func getVal() *tailscaleTypes { - return &tailscaleTypes{ - &wgcfg.Config{ - Name: "foo", - Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)}, - Peers: []wgcfg.Peer{ - { - PublicKey: key.NodePublic{}, - }, - }, - }, - &router.Config{ - Routes: []netip.Prefix{ - netip.MustParsePrefix("1.2.3.0/24"), - netip.MustParsePrefix("1234::/64"), - }, - }, - map[dnsname.FQDN][]netip.Addr{ - dnsname.FQDN("a."): {netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("4.3.2.1")}, - dnsname.FQDN("b."): {netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("9.9.9.9")}, - dnsname.FQDN("c."): {netip.MustParseAddr("6.6.6.6"), netip.MustParseAddr("7.7.7.7")}, - dnsname.FQDN("d."): {netip.MustParseAddr("6.7.6.6"), netip.MustParseAddr("7.7.7.8")}, - dnsname.FQDN("e."): {netip.MustParseAddr("6.8.6.6"), netip.MustParseAddr("7.7.7.9")}, - dnsname.FQDN("f."): {netip.MustParseAddr("6.9.6.6"), netip.MustParseAddr("7.7.7.0")}, - }, - map[dnsname.FQDN][]netip.AddrPort{ - dnsname.FQDN("a."): {netip.MustParseAddrPort("1.2.3.4:11"), netip.MustParseAddrPort("4.3.2.1:22")}, - dnsname.FQDN("b."): {netip.MustParseAddrPort("8.8.8.8:11"), netip.MustParseAddrPort("9.9.9.9:22")}, - dnsname.FQDN("c."): {netip.MustParseAddrPort("8.8.8.8:12"), netip.MustParseAddrPort("9.9.9.9:23")}, - dnsname.FQDN("d."): {netip.MustParseAddrPort("8.8.8.8:13"), netip.MustParseAddrPort("9.9.9.9:24")}, - dnsname.FQDN("e."): {netip.MustParseAddrPort("8.8.8.8:14"), netip.MustParseAddrPort("9.9.9.9:25")}, - }, - map[key.DiscoPublic]bool{ - key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 31: 0})): true, - key.DiscoPublicFromRaw32(mem.B([]byte{1: 2, 31: 0})): false, - key.DiscoPublicFromRaw32(mem.B([]byte{1: 3, 31: 0})): true, - key.DiscoPublicFromRaw32(mem.B([]byte{1: 4, 31: 0})): false, - }, - &tailcfg.MapResponse{ - DERPMap: &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - RegionID: 1, - RegionCode: "foo", - Nodes: []*tailcfg.DERPNode{ - { - Name: "n1", - RegionID: 1, - HostName: "foo.com", - }, - { - Name: "n2", - RegionID: 1, - HostName: "bar.com", - }, - }, - }, - }, - }, - DNSConfig: &tailcfg.DNSConfig{ - Resolvers: []*dnstype.Resolver{ - {Addr: "10.0.0.1"}, - }, - }, - PacketFilter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"1.2.3.4"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "1.2.3.4/32", - Ports: tailcfg.PortRange{First: 1, Last: 2}, - }, - }, - }, - }, - Peers: []*tailcfg.Node{ - { - ID: 1, - }, - { - ID: 2, - }, - }, - UserProfiles: []tailcfg.UserProfile{ - {ID: 1, LoginName: "foo@bar.com"}, - {ID: 2, LoginName: "bar@foo.com"}, - }, - }, - filter.Match{ - IPProto: views.SliceOf([]ipproto.Proto{1, 2, 3}), - }, - } -} - type IntThenByte struct { _ int _ byte @@ -758,14 +632,6 @@ func TestInterfaceCycle(t *testing.T) { var sink Sum -func BenchmarkHash(b *testing.B) { - b.ReportAllocs() - v := getVal() - for range b.N { - sink = Hash(v) - } -} - // filterRules is a packet filter that has both everything populated (in its // first element) and also a few entries that are the typical shape for regular // packet filters as sent to clients. @@ -1072,16 +938,6 @@ func FuzzAddr(f *testing.F) { }) } -func TestAppendTo(t *testing.T) { - v := getVal() - h := Hash(v) - sum := h.AppendTo(nil) - - if s := h.String(); s != string(sum) { - t.Errorf("hash sum mismatch; h.String()=%q h.AppendTo()=%q", s, string(sum)) - } -} - func TestFilterFields(t *testing.T) { type T struct { A int @@ -1126,15 +982,3 @@ func TestFilterFields(t *testing.T) { } } } - -func BenchmarkAppendTo(b *testing.B) { - b.ReportAllocs() - v := getVal() - h := Hash(v) - - hashBuf := make([]byte, 0, 100) - b.ResetTimer() - for range b.N { - hashBuf = h.AppendTo(hashBuf[:0]) - } -} diff --git a/util/deephash/tailscale_types_test.go b/util/deephash/tailscale_types_test.go new file mode 100644 index 000000000..eeb7fdf84 --- /dev/null +++ b/util/deephash/tailscale_types_test.go @@ -0,0 +1,176 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains tests and benchmarks that use types from other packages +// in the Tailscale codebase. Unlike other deephash tests, these are in the _test +// package to avoid circular dependencies. + +package deephash_test + +import ( + "net/netip" + "testing" + + "go4.org/mem" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + "tailscale.com/types/views" + "tailscale.com/util/dnsname" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg" + + . "tailscale.com/util/deephash" +) + +var sink Sum + +func BenchmarkHash(b *testing.B) { + b.ReportAllocs() + v := getVal() + for range b.N { + sink = Hash(v) + } +} + +func BenchmarkAppendTo(b *testing.B) { + b.ReportAllocs() + v := getVal() + h := Hash(v) + + hashBuf := make([]byte, 0, 100) + b.ResetTimer() + for range b.N { + hashBuf = h.AppendTo(hashBuf[:0]) + } +} + +func TestDeepHash(t *testing.T) { + // v contains the types of values we care about for our current callers. + // Mostly we're just testing that we don't panic on handled types. + v := getVal() + hash1 := Hash(v) + t.Logf("hash: %v", hash1) + for range 20 { + v := getVal() + hash2 := Hash(v) + if hash1 != hash2 { + t.Error("second hash didn't match") + } + } +} + +func TestAppendTo(t *testing.T) { + v := getVal() + h := Hash(v) + sum := h.AppendTo(nil) + + if s := h.String(); s != string(sum) { + t.Errorf("hash sum mismatch; h.String()=%q h.AppendTo()=%q", s, string(sum)) + } +} + +type tailscaleTypes struct { + WGConfig *wgcfg.Config + RouterConfig *router.Config + MapFQDNAddrs map[dnsname.FQDN][]netip.Addr + MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort + MapDiscoPublics map[key.DiscoPublic]bool + MapResponse *tailcfg.MapResponse + FilterMatch filter.Match +} + +func getVal() *tailscaleTypes { + return &tailscaleTypes{ + &wgcfg.Config{ + Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)}, + Peers: []wgcfg.Peer{ + { + PublicKey: key.NodePublic{}, + }, + }, + }, + &router.Config{ + Routes: []netip.Prefix{ + netip.MustParsePrefix("1.2.3.0/24"), + netip.MustParsePrefix("1234::/64"), + }, + }, + map[dnsname.FQDN][]netip.Addr{ + dnsname.FQDN("a."): {netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("4.3.2.1")}, + dnsname.FQDN("b."): {netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("9.9.9.9")}, + dnsname.FQDN("c."): {netip.MustParseAddr("6.6.6.6"), netip.MustParseAddr("7.7.7.7")}, + dnsname.FQDN("d."): {netip.MustParseAddr("6.7.6.6"), netip.MustParseAddr("7.7.7.8")}, + dnsname.FQDN("e."): {netip.MustParseAddr("6.8.6.6"), netip.MustParseAddr("7.7.7.9")}, + dnsname.FQDN("f."): {netip.MustParseAddr("6.9.6.6"), netip.MustParseAddr("7.7.7.0")}, + }, + map[dnsname.FQDN][]netip.AddrPort{ + dnsname.FQDN("a."): {netip.MustParseAddrPort("1.2.3.4:11"), netip.MustParseAddrPort("4.3.2.1:22")}, + dnsname.FQDN("b."): {netip.MustParseAddrPort("8.8.8.8:11"), netip.MustParseAddrPort("9.9.9.9:22")}, + dnsname.FQDN("c."): {netip.MustParseAddrPort("8.8.8.8:12"), netip.MustParseAddrPort("9.9.9.9:23")}, + dnsname.FQDN("d."): {netip.MustParseAddrPort("8.8.8.8:13"), netip.MustParseAddrPort("9.9.9.9:24")}, + dnsname.FQDN("e."): {netip.MustParseAddrPort("8.8.8.8:14"), netip.MustParseAddrPort("9.9.9.9:25")}, + }, + map[key.DiscoPublic]bool{ + key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 31: 0})): true, + key.DiscoPublicFromRaw32(mem.B([]byte{1: 2, 31: 0})): false, + key.DiscoPublicFromRaw32(mem.B([]byte{1: 3, 31: 0})): true, + key.DiscoPublicFromRaw32(mem.B([]byte{1: 4, 31: 0})): false, + }, + &tailcfg.MapResponse{ + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "foo", + Nodes: []*tailcfg.DERPNode{ + { + Name: "n1", + RegionID: 1, + HostName: "foo.com", + }, + { + Name: "n2", + RegionID: 1, + HostName: "bar.com", + }, + }, + }, + }, + }, + DNSConfig: &tailcfg.DNSConfig{ + Resolvers: []*dnstype.Resolver{ + {Addr: "10.0.0.1"}, + }, + }, + PacketFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"1.2.3.4"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "1.2.3.4/32", + Ports: tailcfg.PortRange{First: 1, Last: 2}, + }, + }, + }, + }, + Peers: []*tailcfg.Node{ + { + ID: 1, + }, + { + ID: 2, + }, + }, + UserProfiles: []tailcfg.UserProfile{ + {ID: 1, LoginName: "foo@bar.com"}, + {ID: 2, LoginName: "bar@foo.com"}, + }, + }, + filter.Match{ + IPProto: views.SliceOf([]ipproto.Proto{1, 2, 3}), + }, + } +} diff --git a/util/dnsname/dnsname.go b/util/dnsname/dnsname.go index dde0baaed..ef898ebbd 100644 --- a/util/dnsname/dnsname.go +++ b/util/dnsname/dnsname.go @@ -5,16 +5,16 @@ package dnsname import ( - "errors" - "fmt" "strings" + + "tailscale.com/util/vizerror" ) const ( // maxLabelLength is the maximum length of a label permitted by RFC 1035. maxLabelLength = 63 // maxNameLength is the maximum length of a DNS name. - maxNameLength = 253 + maxNameLength = 254 ) // A FQDN is a fully-qualified DNS name or name suffix. @@ -36,7 +36,7 @@ func ToFQDN(s string) (FQDN, error) { totalLen += 1 // account for missing dot } if totalLen > maxNameLength { - return "", fmt.Errorf("%q is too long to be a DNS name", s) + return "", vizerror.Errorf("%q is too long to be a DNS name", s) } st := 0 @@ -54,7 +54,7 @@ func ToFQDN(s string) (FQDN, error) { // // See https://github.com/tailscale/tailscale/issues/2024 for more. if len(label) == 0 || len(label) > maxLabelLength { - return "", fmt.Errorf("%q is not a valid DNS label", label) + return "", vizerror.Errorf("%q is not a valid DNS label", label) } st = i + 1 } @@ -94,26 +94,27 @@ func (f FQDN) Contains(other FQDN) bool { return strings.HasSuffix(other.WithTrailingDot(), cmp) } -// ValidLabel reports whether label is a valid DNS label. +// ValidLabel reports whether label is a valid DNS label. All errors are +// [vizerror.Error]. func ValidLabel(label string) error { if len(label) == 0 { - return errors.New("empty DNS label") + return vizerror.New("empty DNS label") } if len(label) > maxLabelLength { - return fmt.Errorf("%q is too long, max length is %d bytes", label, maxLabelLength) + return vizerror.Errorf("%q is too long, max length is %d bytes", label, maxLabelLength) } if !isalphanum(label[0]) { - return fmt.Errorf("%q is not a valid DNS label: must start with a letter or number", label) + return vizerror.Errorf("%q is not a valid DNS label: must start with a letter or number", label) } if !isalphanum(label[len(label)-1]) { - return fmt.Errorf("%q is not a valid DNS label: must end with a letter or number", label) + return vizerror.Errorf("%q is not a valid DNS label: must end with a letter or number", label) } if len(label) < 2 { return nil } for i := 1; i < len(label)-1; i++ { if !isdnschar(label[i]) { - return fmt.Errorf("%q is not a valid DNS label: contains invalid character %q", label, label[i]) + return vizerror.Errorf("%q is not a valid DNS label: contains invalid character %q", label, label[i]) } } return nil diff --git a/util/dnsname/dnsname_test.go b/util/dnsname/dnsname_test.go index 719e28be3..b038bb1bd 100644 --- a/util/dnsname/dnsname_test.go +++ b/util/dnsname/dnsname_test.go @@ -59,6 +59,38 @@ func TestFQDN(t *testing.T) { } } +func TestFQDNTooLong(t *testing.T) { + // RFC 1035 says a dns name has a max size of 255 octets, and is represented as labels of len+ASCII chars so + // example.com + // is represented as + // 7example3com0 + // which is to say that if we have a trailing dot then the dots cancel out all the len bytes except the first and + // we can accept 254 chars. + + // This name is max length + name := "aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.aaaaaaaaaaaaaaaaaaaaa.example.com." + if len(name) != 254 { + t.Fatalf("name should be 254 chars including trailing . (len is %d)", len(name)) + } + got, err := ToFQDN(name) + if err != nil { + t.Fatalf("want: no error, got: %v", err) + } + if string(got) != name { + t.Fatalf("want: %s, got: %s", name, got) + } + + // This name is too long + name = "x" + name + got, err = ToFQDN(name) + if got != "" { + t.Fatalf("want: \"\", got: %s", got) + } + if err == nil || !strings.HasSuffix(err.Error(), "is too long to be a DNS name") { + t.Fatalf("want: error to end with \"is too long to be a DNS name\", got: %v", err) + } +} + func TestFQDNContains(t *testing.T) { tests := []struct { a, b string diff --git a/util/eventbus/assets/event.html b/util/eventbus/assets/event.html new file mode 100644 index 000000000..8e016f583 --- /dev/null +++ b/util/eventbus/assets/event.html @@ -0,0 +1,6 @@ +
      • +
        + {{.Count}}: {{.Type}} from {{.Event.From.Name}}, {{len .Event.To}} recipients + {{.Event.Event}} +
        +
      • diff --git a/util/eventbus/assets/htmx-websocket.min.js.gz b/util/eventbus/assets/htmx-websocket.min.js.gz new file mode 100644 index 000000000..4ed53be49 Binary files /dev/null and b/util/eventbus/assets/htmx-websocket.min.js.gz differ diff --git a/util/eventbus/assets/htmx.min.js.gz b/util/eventbus/assets/htmx.min.js.gz new file mode 100644 index 000000000..b75fea8d1 Binary files /dev/null and b/util/eventbus/assets/htmx.min.js.gz differ diff --git a/util/eventbus/assets/main.html b/util/eventbus/assets/main.html new file mode 100644 index 000000000..51d6b22ad --- /dev/null +++ b/util/eventbus/assets/main.html @@ -0,0 +1,97 @@ + + + + + + + + +

        Event bus

        + +
        +

        General

        + {{with $.PublishQueue}} + {{len .}} pending + {{end}} + + +
        + +
        +

        Clients

        + + + + + + + + + + + {{range .Clients}} + + + + + + + {{end}} +
        NamePublishingSubscribingPending
        {{.Name}} +
          + {{range .Publish}} +
        • {{.}}
        • + {{end}} +
        +
        +
          + {{range .Subscribe}} +
        • {{.}}
        • + {{end}} +
        +
        + {{len ($.SubscribeQueue .Client)}} +
        +
        + +
        +

        Types

        + + {{range .Types}} + +
        +

        {{.Name}}

        +

        Definition

        + {{prettyPrintStruct .}} + +

        Published by:

        + {{if len (.Publish)}} +
          + {{range .Publish}} +
        • {{.Name}}
        • + {{end}} +
        + {{else}} +
          +
        • No publishers.
        • +
        + {{end}} + +

        Received by:

        + {{if len (.Subscribe)}} +
          + {{range .Subscribe}} +
        • {{.Name}}
        • + {{end}} +
        + {{else}} +
          +
        • No subscribers.
        • +
        + {{end}} +
        + {{end}} + +
        + + diff --git a/util/eventbus/assets/monitor.html b/util/eventbus/assets/monitor.html new file mode 100644 index 000000000..1af5bdce6 --- /dev/null +++ b/util/eventbus/assets/monitor.html @@ -0,0 +1,5 @@ +
        +
          +
        + +
        diff --git a/util/eventbus/assets/style.css b/util/eventbus/assets/style.css new file mode 100644 index 000000000..690bd4f17 --- /dev/null +++ b/util/eventbus/assets/style.css @@ -0,0 +1,90 @@ +/* CSS reset, thanks Josh Comeau: https://www.joshwcomeau.com/css/custom-css-reset/ */ +*, *::before, *::after { box-sizing: border-box; } +* { margin: 0; } +input, button, textarea, select { font: inherit; } +p, h1, h2, h3, h4, h5, h6 { overflow-wrap: break-word; } +p { text-wrap: pretty; } +h1, h2, h3, h4, h5, h6 { text-wrap: balance; } +#root, #__next { isolation: isolate; } +body { + line-height: 1.5; + -webkit-font-smoothing: antialiased; +} +img, picture, video, canvas, svg { + display: block; + max-width: 100%; +} + +/* Local styling begins */ + +body { + padding: 12px; +} + +div { + width: 100%; +} + +section { + display: flex; + flex-direction: column; + flex-gap: 6px; + align-items: flex-start; + padding: 12px 0; +} + +section > * { + margin-left: 24px; +} + +section > h2, section > h3 { + margin-left: 0; + padding-bottom: 6px; + padding-top: 12px; +} + +details { + padding-bottom: 12px; +} + +table { + table-layout: fixed; + width: calc(100% - 48px); + border-collapse: collapse; + border: 1px solid black; +} + +th, td { + padding: 12px; + border: 1px solid black; +} + +td.list { + vertical-align: top; +} + +ul { + list-style: none; +} + +td ul { + margin: 0; + padding: 0; +} + +code { + padding: 12px; + white-space: pre; +} + +#monitor { + width: calc(100% - 48px); + resize: vertical; + padding: 12px; + overflow: scroll; + height: 15lh; + border: 1px inset; + min-height: 1em; + display: flex; + flex-direction: column-reverse; +} diff --git a/util/eventbus/bench_test.go b/util/eventbus/bench_test.go new file mode 100644 index 000000000..25f5b8002 --- /dev/null +++ b/util/eventbus/bench_test.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus_test + +import ( + "math/rand/v2" + "testing" + + "tailscale.com/util/eventbus" +) + +func BenchmarkBasicThroughput(b *testing.B) { + bus := eventbus.New() + pcli := bus.Client(b.Name() + "-pub") + scli := bus.Client(b.Name() + "-sub") + + type emptyEvent [0]byte + + // One publisher and a corresponding subscriber shoveling events as fast as + // they can through the plumbing. + pub := eventbus.Publish[emptyEvent](pcli) + sub := eventbus.Subscribe[emptyEvent](scli) + + go func() { + for { + select { + case <-sub.Events(): + continue + case <-sub.Done(): + return + } + } + }() + + for b.Loop() { + pub.Publish(emptyEvent{}) + } + bus.Close() +} + +func BenchmarkSubsThroughput(b *testing.B) { + bus := eventbus.New() + pcli := bus.Client(b.Name() + "-pub") + scli1 := bus.Client(b.Name() + "-sub1") + scli2 := bus.Client(b.Name() + "-sub2") + + type emptyEvent [0]byte + + // One publisher and two subscribers shoveling events as fast as they can + // through the plumbing. + pub := eventbus.Publish[emptyEvent](pcli) + sub1 := eventbus.Subscribe[emptyEvent](scli1) + sub2 := eventbus.Subscribe[emptyEvent](scli2) + + for _, sub := range []*eventbus.Subscriber[emptyEvent]{sub1, sub2} { + go func() { + for { + select { + case <-sub.Events(): + continue + case <-sub.Done(): + return + } + } + }() + } + + for b.Loop() { + pub.Publish(emptyEvent{}) + } + bus.Close() +} + +func BenchmarkMultiThroughput(b *testing.B) { + bus := eventbus.New() + cli := bus.Client(b.Name()) + + type eventA struct{} + type eventB struct{} + + // Two disjoint event streams routed through the global order. + apub := eventbus.Publish[eventA](cli) + asub := eventbus.Subscribe[eventA](cli) + bpub := eventbus.Publish[eventB](cli) + bsub := eventbus.Subscribe[eventB](cli) + + go func() { + for { + select { + case <-asub.Events(): + continue + case <-asub.Done(): + return + } + } + }() + go func() { + for { + select { + case <-bsub.Events(): + continue + case <-bsub.Done(): + return + } + } + }() + + var rng uint64 + var bits int + for b.Loop() { + if bits == 0 { + rng = rand.Uint64() + bits = 64 + } + if rng&1 == 0 { + apub.Publish(eventA{}) + } else { + bpub.Publish(eventB{}) + } + rng >>= 1 + bits-- + } + bus.Close() +} diff --git a/util/eventbus/bus.go b/util/eventbus/bus.go new file mode 100644 index 000000000..aa6880d01 --- /dev/null +++ b/util/eventbus/bus.go @@ -0,0 +1,338 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "context" + "log" + "reflect" + "slices" + + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/set" +) + +type PublishedEvent struct { + Event any + From *Client +} + +type RoutedEvent struct { + Event any + From *Client + To []*Client +} + +// Bus is an event bus that distributes published events to interested +// subscribers. +type Bus struct { + router *worker + write chan PublishedEvent + snapshot chan chan []PublishedEvent + routeDebug hook[RoutedEvent] + logf logger.Logf + + topicsMu syncs.Mutex + topics map[reflect.Type][]*subscribeState + + // Used for introspection/debugging only, not in the normal event + // publishing path. + clientsMu syncs.Mutex + clients set.Set[*Client] +} + +// New returns a new bus with default options. It is equivalent to +// calling [NewWithOptions] with zero [BusOptions]. +func New() *Bus { return NewWithOptions(BusOptions{}) } + +// NewWithOptions returns a new [Bus] with the specified [BusOptions]. +// Use [Bus.Client] to construct clients on the bus. +// Use [Publish] to make event publishers. +// Use [Subscribe] and [SubscribeFunc] to make event subscribers. +func NewWithOptions(opts BusOptions) *Bus { + ret := &Bus{ + write: make(chan PublishedEvent), + snapshot: make(chan chan []PublishedEvent), + topics: map[reflect.Type][]*subscribeState{}, + clients: set.Set[*Client]{}, + logf: opts.logger(), + } + ret.router = runWorker(ret.pump) + return ret +} + +// BusOptions are optional parameters for a [Bus]. A zero value is ready for +// use and provides defaults as described. +type BusOptions struct { + // Logf, if non-nil, is used for debug logs emitted by the bus and clients, + // publishers, and subscribers under its care. If it is nil, logs are sent + // to [log.Printf]. + Logf logger.Logf +} + +func (o BusOptions) logger() logger.Logf { + if o.Logf == nil { + return log.Printf + } + return o.Logf +} + +// Client returns a new client with no subscriptions. Use [Subscribe] +// to receive events, and [Publish] to emit events. +// +// The client's name is used only for debugging, to tell humans what +// piece of code a publisher/subscriber belongs to. Aim for something +// short but unique, for example "kernel-route-monitor" or "taildrop", +// not "watcher". +func (b *Bus) Client(name string) *Client { + ret := &Client{ + name: name, + bus: b, + pub: set.Set[publisher]{}, + } + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + b.clients.Add(ret) + return ret +} + +// Debugger returns the debugging facility for the bus. +func (b *Bus) Debugger() *Debugger { + return &Debugger{b} +} + +// Close closes the bus. It implicitly closes all clients, publishers and +// subscribers attached to the bus. +// +// Close blocks until the bus is fully shut down. The bus is +// permanently unusable after closing. +func (b *Bus) Close() { + b.router.StopAndWait() + + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + for c := range b.clients { + c.Close() + } + b.clients = nil +} + +func (b *Bus) pump(ctx context.Context) { + var vals queue[PublishedEvent] + acceptCh := func() chan PublishedEvent { + if vals.Full() { + return nil + } + return b.write + } + for { + // Drain all pending events. Note that while we're draining + // events into subscriber queues, we continue to + // opportunistically accept more incoming events, if we have + // queue space for it. + for !vals.Empty() { + val := vals.Peek() + dests := b.dest(reflect.TypeOf(val.Event)) + + if b.routeDebug.active() { + clients := make([]*Client, len(dests)) + for i := range len(dests) { + clients[i] = dests[i].client + } + b.routeDebug.run(RoutedEvent{ + Event: val.Event, + From: val.From, + To: clients, + }) + } + + for _, d := range dests { + evt := DeliveredEvent{ + Event: val.Event, + From: val.From, + To: d.client, + } + deliverOne: + for { + select { + case d.write <- evt: + break deliverOne + case <-d.closed(): + // Queue closed, don't block but continue + // delivering to others. + break deliverOne + case in := <-acceptCh(): + vals.Add(in) + in.From.publishDebug.run(in) + case <-ctx.Done(): + return + case ch := <-b.snapshot: + ch <- vals.Snapshot() + } + } + } + vals.Drop() + } + + // Inbound queue empty, wait for at least 1 work item before + // resuming. + for vals.Empty() { + select { + case <-ctx.Done(): + return + case in := <-b.write: + vals.Add(in) + in.From.publishDebug.run(in) + case ch := <-b.snapshot: + ch <- nil + } + } + } +} + +// logger returns a [logger.Logf] to which logs related to bus activity should be written. +func (b *Bus) logger() logger.Logf { return b.logf } + +func (b *Bus) dest(t reflect.Type) []*subscribeState { + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + return b.topics[t] +} + +func (b *Bus) shouldPublish(t reflect.Type) bool { + if b.routeDebug.active() { + return true + } + + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + return len(b.topics[t]) > 0 +} + +func (b *Bus) listClients() []*Client { + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + return b.clients.Slice() +} + +func (b *Bus) snapshotPublishQueue() []PublishedEvent { + resp := make(chan []PublishedEvent) + select { + case b.snapshot <- resp: + return <-resp + case <-b.router.Done(): + return nil + } +} + +func (b *Bus) subscribe(t reflect.Type, q *subscribeState) (cancel func()) { + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + b.topics[t] = append(b.topics[t], q) + return func() { + b.unsubscribe(t, q) + } +} + +func (b *Bus) unsubscribe(t reflect.Type, q *subscribeState) { + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + // Topic slices are accessed by pump without holding a lock, so we + // have to replace the entire slice when unsubscribing. + // Unsubscribing should be infrequent enough that this won't + // matter. + i := slices.Index(b.topics[t], q) + if i < 0 { + return + } + b.topics[t] = slices.Delete(slices.Clone(b.topics[t]), i, i+1) +} + +// A worker runs a worker goroutine and helps coordinate its shutdown. +type worker struct { + ctx context.Context + stop context.CancelFunc + stopped chan struct{} +} + +// runWorker creates a worker goroutine running fn. The context passed +// to fn is canceled by [worker.Stop]. +func runWorker(fn func(context.Context)) *worker { + ctx, stop := context.WithCancel(context.Background()) + ret := &worker{ + ctx: ctx, + stop: stop, + stopped: make(chan struct{}), + } + go ret.run(fn) + return ret +} + +func (w *worker) run(fn func(context.Context)) { + defer close(w.stopped) + fn(w.ctx) +} + +// Stop signals the worker goroutine to shut down. +func (w *worker) Stop() { w.stop() } + +// Done returns a channel that is closed when the worker goroutine +// exits. +func (w *worker) Done() <-chan struct{} { return w.stopped } + +// Wait waits until the worker goroutine has exited. +func (w *worker) Wait() { <-w.stopped } + +// StopAndWait signals the worker goroutine to shut down, then waits +// for it to exit. +func (w *worker) StopAndWait() { + w.stop() + <-w.stopped +} + +// stopFlag is a value that can be watched for a notification. The +// zero value is ready for use. +// +// The flag is notified by running [stopFlag.Stop]. Stop can be called +// multiple times. Upon the first call to Stop, [stopFlag.Done] is +// closed, all pending [stopFlag.Wait] calls return, and future Wait +// calls return immediately. +// +// A stopFlag can only notify once, and is intended for use as a +// one-way shutdown signal that's lighter than a cancellable +// context.Context. +type stopFlag struct { + // guards the lazy construction of stopped, and the value of + // alreadyStopped. + mu syncs.Mutex + stopped chan struct{} + alreadyStopped bool +} + +func (s *stopFlag) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.alreadyStopped { + return + } + s.alreadyStopped = true + if s.stopped == nil { + s.stopped = make(chan struct{}) + } + close(s.stopped) +} + +func (s *stopFlag) Done() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped == nil { + s.stopped = make(chan struct{}) + } + return s.stopped +} + +func (s *stopFlag) Wait() { + <-s.Done() +} diff --git a/util/eventbus/bus_test.go b/util/eventbus/bus_test.go new file mode 100644 index 000000000..61728fbfd --- /dev/null +++ b/util/eventbus/bus_test.go @@ -0,0 +1,628 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus_test + +import ( + "bytes" + "errors" + "fmt" + "log" + "regexp" + "testing" + "testing/synctest" + "time" + + "github.com/creachadair/taskgroup" + "github.com/google/go-cmp/cmp" + "tailscale.com/util/eventbus" +) + +type EventA struct { + Counter int +} + +type EventB struct { + Counter int +} + +func TestBus(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestSub") + cdone := c.Done() + defer func() { + c.Close() + select { + case <-cdone: + t.Log("Client close signal received (OK)") + case <-time.After(time.Second): + t.Error("timed out waiting for client close signal") + } + }() + s := eventbus.Subscribe[EventA](c) + + go func() { + p := b.Client("TestPub") + defer p.Close() + pa := eventbus.Publish[EventA](p) + defer pa.Close() + pb := eventbus.Publish[EventB](p) + defer pb.Close() + pa.Publish(EventA{1}) + pb.Publish(EventB{2}) + pa.Publish(EventA{3}) + }() + + want := expectEvents(t, EventA{1}, EventA{3}) + for !want.Empty() { + select { + case got := <-s.Events(): + want.Got(got) + case <-s.Done(): + t.Fatalf("queue closed unexpectedly") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for event") + } + } +} + +func TestSubscriberFunc(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestClient") + + exp := expectEvents(t, EventA{12345}) + eventbus.SubscribeFunc[EventA](c, func(e EventA) { exp.Got(e) }) + + p := eventbus.Publish[EventA](c) + p.Publish(EventA{12345}) + + synctest.Wait() + c.Close() + + if !exp.Empty() { + t.Errorf("unexpected extra events: %+v", exp.want) + } + }) + + t.Run("CloseWait", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client(t.Name()) + + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + time.Sleep(2 * time.Second) + }) + + p := eventbus.Publish[EventA](c) + p.Publish(EventA{12345}) + + synctest.Wait() // subscriber has the event + c.Close() + + // If close does not wait for the subscriber, the test will fail + // because an active goroutine remains in the bubble. + }) + }) + + t.Run("CloseWait/Belated", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + buf := swapLogBuf(t) + + b := eventbus.New() + defer b.Close() + + c := b.Client(t.Name()) + + // This subscriber stalls for a long time, so that when we try to + // close the client it gives up and returns in the timeout condition. + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + time.Sleep(time.Minute) // notably, longer than the wait period + }) + + p := eventbus.Publish[EventA](c) + p.Publish(EventA{12345}) + + synctest.Wait() // subscriber has the event + c.Close() + + // Verify that the logger recorded that Close gave up on the slowpoke. + want := regexp.MustCompile(`^.* tailscale.com/util/eventbus_test bus_test.go:\d+: ` + + `giving up on subscriber for eventbus_test.EventA after \d+s at close.*`) + if got := buf.String(); !want.MatchString(got) { + t.Errorf("Wrong log output\ngot: %q\nwant %s", got, want) + } + + // Wait for the subscriber to actually finish to clean up the goroutine. + time.Sleep(2 * time.Minute) + }) + }) + + t.Run("SubscriberPublishes", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestClient") + pa := eventbus.Publish[EventA](c) + pb := eventbus.Publish[EventB](c) + exp := expectEvents(t, EventA{127}, EventB{128}) + eventbus.SubscribeFunc[EventA](c, func(e EventA) { + exp.Got(e) + pb.Publish(EventB{Counter: e.Counter + 1}) + }) + eventbus.SubscribeFunc[EventB](c, func(e EventB) { + exp.Got(e) + }) + + pa.Publish(EventA{127}) + + synctest.Wait() + c.Close() + if !exp.Empty() { + t.Errorf("unepxected extra events: %+v", exp.want) + } + }) + }) +} + +func TestBusMultipleConsumers(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c1 := b.Client("TestSubA") + defer c1.Close() + s1 := eventbus.Subscribe[EventA](c1) + + c2 := b.Client("TestSubB") + defer c2.Close() + s2A := eventbus.Subscribe[EventA](c2) + s2B := eventbus.Subscribe[EventB](c2) + + go func() { + p := b.Client("TestPub") + defer p.Close() + pa := eventbus.Publish[EventA](p) + defer pa.Close() + pb := eventbus.Publish[EventB](p) + defer pb.Close() + pa.Publish(EventA{1}) + pb.Publish(EventB{2}) + pa.Publish(EventA{3}) + }() + + wantA := expectEvents(t, EventA{1}, EventA{3}) + wantB := expectEvents(t, EventA{1}, EventB{2}, EventA{3}) + for !wantA.Empty() || !wantB.Empty() { + select { + case got := <-s1.Events(): + wantA.Got(got) + case got := <-s2A.Events(): + wantB.Got(got) + case got := <-s2B.Events(): + wantB.Got(got) + case <-s1.Done(): + t.Fatalf("queue closed unexpectedly") + case <-s2A.Done(): + t.Fatalf("queue closed unexpectedly") + case <-s2B.Done(): + t.Fatalf("queue closed unexpectedly") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for event") + } + } +} + +func TestClientMixedSubscribers(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestClient") + + var gotA EventA + s1 := eventbus.Subscribe[EventA](c) + + var gotB EventB + eventbus.SubscribeFunc[EventB](c, func(e EventB) { + t.Logf("func sub received %[1]T %+[1]v", e) + gotB = e + }) + + go func() { + for { + select { + case <-s1.Done(): + return + case e := <-s1.Events(): + t.Logf("chan sub received %[1]T %+[1]v", e) + gotA = e + } + } + }() + + p1 := eventbus.Publish[EventA](c) + p2 := eventbus.Publish[EventB](c) + + go p1.Publish(EventA{12345}) + go p2.Publish(EventB{67890}) + + synctest.Wait() + c.Close() + synctest.Wait() + + if diff := cmp.Diff(gotB, EventB{67890}); diff != "" { + t.Errorf("Chan sub (-got, +want):\n%s", diff) + } + if diff := cmp.Diff(gotA, EventA{12345}); diff != "" { + t.Errorf("Func sub (-got, +want):\n%s", diff) + } + }) +} + +func TestSpam(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + b := eventbus.New() + defer b.Close() + + const ( + publishers = 100 + eventsPerPublisher = 20 + wantEvents = publishers * eventsPerPublisher + subscribers = 100 + ) + + var g taskgroup.Group + + // A bunch of subscribers receiving on channels. + chanReceived := make([][]EventA, subscribers) + for i := range subscribers { + c := b.Client(fmt.Sprintf("Subscriber%d", i)) + defer c.Close() + + s := eventbus.Subscribe[EventA](c) + g.Go(func() error { + for range wantEvents { + select { + case evt := <-s.Events(): + chanReceived[i] = append(chanReceived[i], evt) + case <-s.Done(): + t.Errorf("queue done before expected number of events received") + return errors.New("queue prematurely closed") + case <-time.After(5 * time.Second): + t.Logf("timed out waiting for expected bus event after %d events", len(chanReceived[i])) + return errors.New("timeout") + } + } + return nil + }) + } + + // A bunch of subscribers receiving via a func. + funcReceived := make([][]EventA, subscribers) + for i := range subscribers { + c := b.Client(fmt.Sprintf("SubscriberFunc%d", i)) + defer c.Close() + eventbus.SubscribeFunc(c, func(e EventA) { + funcReceived[i] = append(funcReceived[i], e) + }) + } + + published := make([][]EventA, publishers) + for i := range publishers { + c := b.Client(fmt.Sprintf("Publisher%d", i)) + p := eventbus.Publish[EventA](c) + g.Run(func() { + defer c.Close() + for j := range eventsPerPublisher { + evt := EventA{i*eventsPerPublisher + j} + p.Publish(evt) + published[i] = append(published[i], evt) + } + }) + } + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + synctest.Wait() + + tests := []struct { + name string + recv [][]EventA + }{ + {"Subscriber", chanReceived}, + {"SubscriberFunc", funcReceived}, + } + for _, tc := range tests { + for i, got := range tc.recv { + if len(got) != wantEvents { + t.Errorf("%s %d: got %d events, want %d", tc.name, i, len(got), wantEvents) + } + if i == 0 { + continue + } + if diff := cmp.Diff(got, tc.recv[i-1]); diff != "" { + t.Errorf("%s %d did not see the same events as %d (-got+want):\n%s", tc.name, i, i-1, diff) + } + } + } + for i, sent := range published { + if got := len(sent); got != eventsPerPublisher { + t.Fatalf("Publisher %d sent %d events, want %d", i, got, eventsPerPublisher) + } + } + + // TODO: check that the published sequences are proper + // subsequences of the received slices. + }) +} + +func TestClient_Done(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client(t.Name()) + s := eventbus.Subscribe[string](c) + + // The client is not Done until closed. + select { + case <-c.Done(): + t.Fatal("Client done before being closed") + default: + // OK + } + + go c.Close() + + // Once closed, the client becomes Done. + select { + case <-c.Done(): + // OK + case <-time.After(time.Second): + t.Fatal("timeout waiting for Client to be done") + } + + // Thereafter, the subscriber should also be closed. + select { + case <-s.Done(): + // OK + case <-time.After(time.Second): + t.Fatal("timoeout waiting for Subscriber to be done") + } +} + +func TestMonitor(t *testing.T) { + t.Run("ZeroWait", func(t *testing.T) { + var zero eventbus.Monitor + + ready := make(chan struct{}) + go func() { zero.Wait(); close(ready) }() + + select { + case <-ready: + // OK + case <-time.After(time.Second): + t.Fatal("timeout waiting for Wait to return") + } + }) + + t.Run("ZeroDone", func(t *testing.T) { + var zero eventbus.Monitor + + select { + case <-zero.Done(): + // OK + case <-time.After(time.Second): + t.Fatal("timeout waiting for zero monitor to be done") + } + }) + + t.Run("ZeroClose", func(t *testing.T) { + var zero eventbus.Monitor + + ready := make(chan struct{}) + go func() { zero.Close(); close(ready) }() + + select { + case <-ready: + // OK + case <-time.After(time.Second): + t.Fatal("timeout waiting for Close to return") + } + }) + + testMon := func(t *testing.T, release func(*eventbus.Client, eventbus.Monitor)) func(t *testing.T) { + t.Helper() + return func(t *testing.T) { + bus := eventbus.New() + cli := bus.Client("test client") + + // The monitored goroutine runs until the client or test subscription ends. + sub := eventbus.Subscribe[string](cli) + m := cli.Monitor(func(c *eventbus.Client) { + select { + case <-c.Done(): + t.Log("client closed") + case <-sub.Done(): + t.Log("subscription closed") + } + }) + + done := make(chan struct{}) + go func() { + defer close(done) + m.Wait() + }() + + // While the goroutine is running, Wait does not complete. + select { + case <-done: + t.Error("monitor is ready before its goroutine is finished (Wait)") + default: + // OK + } + select { + case <-m.Done(): + t.Error("monitor is ready before its goroutine is finished (Done)") + default: + // OK + } + + release(cli, m) + select { + case <-done: + // OK + case <-time.After(time.Second): + t.Fatal("timeout waiting for monitor to complete (Wait)") + } + select { + case <-m.Done(): + // OK + case <-time.After(time.Second): + t.Fatal("timeout waiting for monitor to complete (Done)") + } + } + } + t.Run("Close", testMon(t, func(_ *eventbus.Client, m eventbus.Monitor) { m.Close() })) + t.Run("Wait", testMon(t, func(c *eventbus.Client, m eventbus.Monitor) { c.Close(); m.Wait() })) +} + +func TestSlowSubs(t *testing.T) { + t.Run("Subscriber", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + buf := swapLogBuf(t) + + b := eventbus.New() + defer b.Close() + + pc := b.Client("pub") + p := eventbus.Publish[EventA](pc) + + sc := b.Client("sub") + s := eventbus.Subscribe[EventA](sc) + + go func() { + time.Sleep(6 * time.Second) // trigger the slow check at 5s. + t.Logf("Subscriber accepted %v", <-s.Events()) + }() + + p.Publish(EventA{12345}) + + time.Sleep(7 * time.Second) // advance time... + synctest.Wait() // subscriber is done + + want := regexp.MustCompile(`^.* tailscale.com/util/eventbus_test bus_test.go:\d+: ` + + `subscriber for eventbus_test.EventA is slow.*`) + if got := buf.String(); !want.MatchString(got) { + t.Errorf("Wrong log output\ngot: %q\nwant: %s", got, want) + } + }) + }) + + t.Run("SubscriberFunc", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + buf := swapLogBuf(t) + + b := eventbus.New() + defer b.Close() + + pc := b.Client("pub") + p := eventbus.Publish[EventB](pc) + + sc := b.Client("sub") + eventbus.SubscribeFunc[EventB](sc, func(e EventB) { + time.Sleep(6 * time.Second) // trigger the slow check at 5s. + t.Logf("SubscriberFunc processed %v", e) + }) + + p.Publish(EventB{67890}) + + time.Sleep(7 * time.Second) // advance time... + synctest.Wait() // subscriber is done + + want := regexp.MustCompile(`^.* tailscale.com/util/eventbus_test bus_test.go:\d+: ` + + `subscriber for eventbus_test.EventB is slow.*`) + if got := buf.String(); !want.MatchString(got) { + t.Errorf("Wrong log output\ngot: %q\nwant: %s", got, want) + } + }) + }) +} + +func TestRegression(t *testing.T) { + bus := eventbus.New() + t.Cleanup(bus.Close) + + t.Run("SubscribeClosed", func(t *testing.T) { + c := bus.Client("test sub client") + c.Close() + + var v any + func() { + defer func() { v = recover() }() + eventbus.Subscribe[string](c) + }() + if v == nil { + t.Fatal("Expected a panic from Subscribe on a closed client") + } else { + t.Logf("Got expected panic: %v", v) + } + }) + + t.Run("PublishClosed", func(t *testing.T) { + c := bus.Client("test pub client") + c.Close() + + var v any + func() { + defer func() { v = recover() }() + eventbus.Publish[string](c) + }() + if v == nil { + t.Fatal("expected a panic from Publish on a closed client") + } else { + t.Logf("Got expected panic: %v", v) + } + }) +} + +type queueChecker struct { + t *testing.T + want []any +} + +func expectEvents(t *testing.T, want ...any) *queueChecker { + return &queueChecker{t, want} +} + +func (q *queueChecker) Got(v any) { + q.t.Helper() + if q.Empty() { + q.t.Errorf("queue got unexpected %v", v) + return + } + if v != q.want[0] { + q.t.Errorf("queue got %#v, want %#v", v, q.want[0]) + return + } + q.want = q.want[1:] +} + +func (q *queueChecker) Empty() bool { + return len(q.want) == 0 +} + +func swapLogBuf(t *testing.T) *bytes.Buffer { + logBuf := new(bytes.Buffer) + save := log.Writer() + log.SetOutput(logBuf) + t.Cleanup(func() { log.SetOutput(save) }) + return logBuf +} diff --git a/util/eventbus/client.go b/util/eventbus/client.go new file mode 100644 index 000000000..a7a5ab673 --- /dev/null +++ b/util/eventbus/client.go @@ -0,0 +1,182 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "reflect" + + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/set" +) + +// A Client can publish and subscribe to events on its attached +// bus. See [Publish] to publish events, and [Subscribe] to receive +// events. +// +// Subscribers that share the same client receive events one at a +// time, in the order they were published. +type Client struct { + name string + bus *Bus + publishDebug hook[PublishedEvent] + + mu syncs.Mutex + pub set.Set[publisher] + sub *subscribeState // Lazily created on first subscribe + stop stopFlag // signaled on Close +} + +func (c *Client) Name() string { return c.name } + +func (c *Client) logger() logger.Logf { return c.bus.logger() } + +// Close closes the client. It implicitly closes all publishers and +// subscribers obtained from this client. +func (c *Client) Close() { + var ( + pub set.Set[publisher] + sub *subscribeState + ) + + c.mu.Lock() + pub, c.pub = c.pub, nil + sub, c.sub = c.sub, nil + c.mu.Unlock() + + if sub != nil { + sub.close() + } + for p := range pub { + p.Close() + } + c.stop.Stop() +} + +func (c *Client) isClosed() bool { return c.pub == nil && c.sub == nil } + +// Done returns a channel that is closed when [Client.Close] is called. +// The channel is closed after all the publishers and subscribers governed by +// the client have been closed. +func (c *Client) Done() <-chan struct{} { return c.stop.Done() } + +func (c *Client) snapshotSubscribeQueue() []DeliveredEvent { + return c.peekSubscribeState().snapshotQueue() +} + +func (c *Client) peekSubscribeState() *subscribeState { + c.mu.Lock() + defer c.mu.Unlock() + return c.sub +} + +func (c *Client) publishTypes() []reflect.Type { + c.mu.Lock() + defer c.mu.Unlock() + ret := make([]reflect.Type, 0, len(c.pub)) + for pub := range c.pub { + ret = append(ret, pub.publishType()) + } + return ret +} + +func (c *Client) subscribeTypes() []reflect.Type { + return c.peekSubscribeState().subscribeTypes() +} + +func (c *Client) subscribeState() *subscribeState { + c.mu.Lock() + defer c.mu.Unlock() + return c.subscribeStateLocked() +} + +func (c *Client) subscribeStateLocked() *subscribeState { + if c.sub == nil { + c.sub = newSubscribeState(c) + } + return c.sub +} + +func (c *Client) addPublisher(pub publisher) { + c.mu.Lock() + defer c.mu.Unlock() + if c.isClosed() { + panic("cannot Publish on a closed client") + } + c.pub.Add(pub) +} + +func (c *Client) deletePublisher(pub publisher) { + c.mu.Lock() + defer c.mu.Unlock() + c.pub.Delete(pub) +} + +func (c *Client) addSubscriber(t reflect.Type, s *subscribeState) { + c.bus.subscribe(t, s) +} + +func (c *Client) deleteSubscriber(t reflect.Type, s *subscribeState) { + c.bus.unsubscribe(t, s) +} + +func (c *Client) publish() chan<- PublishedEvent { + return c.bus.write +} + +func (c *Client) shouldPublish(t reflect.Type) bool { + return c.publishDebug.active() || c.bus.shouldPublish(t) +} + +// Subscribe requests delivery of events of type T through the given client. +// It panics if c already has a subscriber for type T, or if c is closed. +func Subscribe[T any](c *Client) *Subscriber[T] { + // Hold the client lock throughout the subscription process so that a caller + // attempting to subscribe on a closed client will get a useful diagnostic + // instead of a random panic from inside the subscriber plumbing. + c.mu.Lock() + defer c.mu.Unlock() + + // The caller should not race subscriptions with close, give them a useful + // diagnostic at the call site. + if c.isClosed() { + panic("cannot Subscribe on a closed client") + } + + r := c.subscribeStateLocked() + s := newSubscriber[T](r, logfForCaller(c.logger())) + r.addSubscriber(s) + return s +} + +// SubscribeFunc is like [Subscribe], but calls the provided func for each +// event of type T. +// +// A SubscriberFunc calls f synchronously from the client's goroutine. +// This means the callback must not block for an extended period of time, +// as this will block the subscriber and slow event processing for all +// subscriptions on c. +func SubscribeFunc[T any](c *Client, f func(T)) *SubscriberFunc[T] { + c.mu.Lock() + defer c.mu.Unlock() + + // The caller should not race subscriptions with close, give them a useful + // diagnostic at the call site. + if c.isClosed() { + panic("cannot SubscribeFunc on a closed client") + } + + r := c.subscribeStateLocked() + s := newSubscriberFunc[T](r, f, logfForCaller(c.logger())) + r.addSubscriber(s) + return s +} + +// Publish returns a publisher for event type T using the given client. +// It panics if c is closed. +func Publish[T any](c *Client) *Publisher[T] { + p := newPublisher[T](c) + c.addPublisher(p) + return p +} diff --git a/util/eventbus/debug-demo/main.go b/util/eventbus/debug-demo/main.go new file mode 100644 index 000000000..71894d2ea --- /dev/null +++ b/util/eventbus/debug-demo/main.go @@ -0,0 +1,107 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// debug-demo is a program that serves a bus's debug interface over +// HTTP, then generates some fake traffic from a handful of +// clients. It is an aid to development, to have something to present +// on the debug interfaces while writing them. +package main + +import ( + "log" + "math/rand/v2" + "net/http" + "net/netip" + "time" + + "tailscale.com/feature/buildfeatures" + "tailscale.com/tsweb" + "tailscale.com/types/key" + "tailscale.com/util/eventbus" +) + +func main() { + if !buildfeatures.HasDebugEventBus { + log.Fatalf("debug-demo requires the \"debugeventbus\" feature enabled") + } + b := eventbus.New() + c := b.Client("RouteMonitor") + go testPub[RouteAdded](c, 5*time.Second) + go testPub[RouteRemoved](c, 5*time.Second) + c = b.Client("ControlClient") + go testPub[PeerAdded](c, 3*time.Second) + go testPub[PeerRemoved](c, 6*time.Second) + c = b.Client("Portmapper") + go testPub[PortmapAcquired](c, 10*time.Second) + go testPub[PortmapLost](c, 15*time.Second) + go testSub[RouteAdded](c) + c = b.Client("WireguardConfig") + go testSub[PeerAdded](c) + go testSub[PeerRemoved](c) + c = b.Client("Magicsock") + go testPub[PeerPathChanged](c, 5*time.Second) + go testSub[RouteAdded](c) + go testSub[RouteRemoved](c) + go testSub[PortmapAcquired](c) + go testSub[PortmapLost](c) + + m := http.NewServeMux() + d := tsweb.Debugger(m) + b.Debugger().RegisterHTTP(d) + + m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/debug/bus", http.StatusFound) + }) + log.Printf("Serving debug interface at http://localhost:8185/debug/bus") + http.ListenAndServe(":8185", m) +} + +func testPub[T any](c *eventbus.Client, every time.Duration) { + p := eventbus.Publish[T](c) + for { + jitter := time.Duration(rand.N(2000)) * time.Millisecond + time.Sleep(jitter) + var zero T + log.Printf("%s publish: %T", c.Name(), zero) + p.Publish(zero) + time.Sleep(every) + } +} + +func testSub[T any](c *eventbus.Client) { + s := eventbus.Subscribe[T](c) + for v := range s.Events() { + log.Printf("%s received: %T", c.Name(), v) + } +} + +type RouteAdded struct { + Prefix netip.Prefix + Via netip.Addr + Priority int +} +type RouteRemoved struct { + Prefix netip.Addr +} + +type PeerAdded struct { + ID int + Key key.NodePublic +} +type PeerRemoved struct { + ID int + Key key.NodePublic +} + +type PortmapAcquired struct { + Endpoint netip.Addr +} +type PortmapLost struct { + Endpoint netip.Addr +} + +type PeerPathChanged struct { + ID int + EndpointID int + Quality int +} diff --git a/util/eventbus/debug.go b/util/eventbus/debug.go new file mode 100644 index 000000000..0453defb1 --- /dev/null +++ b/util/eventbus/debug.go @@ -0,0 +1,242 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "cmp" + "fmt" + "path/filepath" + "reflect" + "runtime" + "slices" + "strings" + "sync/atomic" + "time" + + "tailscale.com/syncs" + "tailscale.com/types/logger" +) + +// slowSubscriberTimeout is a timeout after which a subscriber that does not +// accept a pending event will be flagged as being slow. +const slowSubscriberTimeout = 5 * time.Second + +// A Debugger offers access to a bus's privileged introspection and +// debugging facilities. +// +// The debugger's functionality is intended for humans and their tools +// to examine and troubleshoot bus clients, and should not be used in +// normal codepaths. +// +// In particular, the debugger provides access to information that is +// deliberately withheld from bus clients to encourage more robust and +// maintainable code - for example, the sender of an event, or the +// event streams of other clients. Please don't use the debugger to +// circumvent these restrictions for purposes other than debugging. +type Debugger struct { + bus *Bus +} + +// Clients returns a list of all clients attached to the bus. +func (d *Debugger) Clients() []*Client { + ret := d.bus.listClients() + slices.SortFunc(ret, func(a, b *Client) int { + return cmp.Compare(a.Name(), b.Name()) + }) + return ret +} + +// PublishQueue returns the contents of the publish queue. +// +// The publish queue contains events that have been accepted by the +// bus from Publish() calls, but have not yet been routed to relevant +// subscribers. +// +// This queue is expected to be almost empty in normal operation. A +// full publish queue indicates that a slow subscriber downstream is +// causing backpressure and stalling the bus. +func (d *Debugger) PublishQueue() []PublishedEvent { + return d.bus.snapshotPublishQueue() +} + +// checkClient verifies that client is attached to the same bus as the +// Debugger, and panics if not. +func (d *Debugger) checkClient(client *Client) { + if client.bus != d.bus { + panic(fmt.Errorf("SubscribeQueue given client belonging to wrong bus")) + } +} + +// SubscribeQueue returns the contents of the given client's subscribe +// queue. +// +// The subscribe queue contains events that are to be delivered to the +// client, but haven't yet been handed off to the relevant +// [Subscriber]. +// +// This queue is expected to be almost empty in normal operation. A +// full subscribe queue indicates that the client is accepting events +// too slowly, and may be causing the rest of the bus to stall. +func (d *Debugger) SubscribeQueue(client *Client) []DeliveredEvent { + d.checkClient(client) + return client.snapshotSubscribeQueue() +} + +// WatchBus streams information about all events passing through the +// bus. +// +// Monitored events are delivered in the bus's global publication +// order (see "Concurrency properties" in the package docs). +// +// The caller must consume monitoring events promptly to avoid +// stalling the bus (see "Expected subscriber behavior" in the package +// docs). +func (d *Debugger) WatchBus() *Subscriber[RoutedEvent] { + return newMonitor(d.bus.routeDebug.add) +} + +// WatchPublish streams information about all events published by the +// given client. +// +// Monitored events are delivered in the bus's global publication +// order (see "Concurrency properties" in the package docs). +// +// The caller must consume monitoring events promptly to avoid +// stalling the bus (see "Expected subscriber behavior" in the package +// docs). +func (d *Debugger) WatchPublish(client *Client) *Subscriber[PublishedEvent] { + d.checkClient(client) + return newMonitor(client.publishDebug.add) +} + +// WatchSubscribe streams information about all events received by the +// given client. +// +// Monitored events are delivered in the bus's global publication +// order (see "Concurrency properties" in the package docs). +// +// The caller must consume monitoring events promptly to avoid +// stalling the bus (see "Expected subscriber behavior" in the package +// docs). +func (d *Debugger) WatchSubscribe(client *Client) *Subscriber[DeliveredEvent] { + d.checkClient(client) + return newMonitor(client.subscribeState().debug.add) +} + +// PublishTypes returns the list of types being published by client. +// +// The returned types are those for which the client has obtained a +// [Publisher]. The client may not have ever sent the type in +// question. +func (d *Debugger) PublishTypes(client *Client) []reflect.Type { + d.checkClient(client) + return client.publishTypes() +} + +// SubscribeTypes returns the list of types being subscribed to by +// client. +// +// The returned types are those for which the client has obtained a +// [Subscriber]. The client may not have ever received the type in +// question, and here may not be any publishers of the type. +func (d *Debugger) SubscribeTypes(client *Client) []reflect.Type { + d.checkClient(client) + return client.subscribeTypes() +} + +// A hook collects hook functions that can be run as a group. +type hook[T any] struct { + syncs.Mutex + fns []hookFn[T] +} + +var hookID atomic.Uint64 + +// add registers fn to be called when the hook is run. Returns an +// unregistration function that removes fn from the hook when called. +func (h *hook[T]) add(fn func(T)) (remove func()) { + id := hookID.Add(1) + h.Lock() + defer h.Unlock() + h.fns = append(h.fns, hookFn[T]{id, fn}) + return func() { h.remove(id) } +} + +// remove removes the function with the given ID from the hook. +func (h *hook[T]) remove(id uint64) { + h.Lock() + defer h.Unlock() + h.fns = slices.DeleteFunc(h.fns, func(f hookFn[T]) bool { return f.ID == id }) +} + +// active reports whether any functions are registered with the +// hook. This can be used to skip expensive work when the hook is +// inactive. +func (h *hook[T]) active() bool { + h.Lock() + defer h.Unlock() + return len(h.fns) > 0 +} + +// run calls all registered functions with the value v. +func (h *hook[T]) run(v T) { + h.Lock() + defer h.Unlock() + for _, fn := range h.fns { + fn.Fn(v) + } +} + +type hookFn[T any] struct { + ID uint64 + Fn func(T) +} + +// DebugEvent is a representation of an event used for debug clients. +type DebugEvent struct { + Count int + Type string + From string + To []string + Event any +} + +// DebugTopics provides the JSON encoding as a wrapper for a collection of [DebugTopic]. +type DebugTopics struct { + Topics []DebugTopic +} + +// DebugTopic provides the JSON encoding of publishers and subscribers for a +// given topic. +type DebugTopic struct { + Name string + Publisher string + Subscribers []string +} + +// logfForCaller returns a [logger.Logf] that prefixes its output with the +// package, filename, and line number of the caller's caller. +// If logf == nil, it returns [logger.Discard]. +// If the caller location could not be determined, it returns logf unmodified. +func logfForCaller(logf logger.Logf) logger.Logf { + if logf == nil { + return logger.Discard + } + pc, fpath, line, _ := runtime.Caller(2) // +1 for my caller, +1 for theirs + if f := runtime.FuncForPC(pc); f != nil { + return logger.WithPrefix(logf, fmt.Sprintf("%s %s:%d: ", funcPackageName(f.Name()), filepath.Base(fpath), line)) + } + return logf +} + +func funcPackageName(funcName string) string { + ls := max(strings.LastIndex(funcName, "/"), 0) + for { + i := strings.LastIndex(funcName, ".") + if i <= ls { + return funcName + } + funcName = funcName[:i] + } +} diff --git a/util/eventbus/debughttp.go b/util/eventbus/debughttp.go new file mode 100644 index 000000000..9e03676d0 --- /dev/null +++ b/util/eventbus/debughttp.go @@ -0,0 +1,240 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android && !ts_omit_debugeventbus + +package eventbus + +import ( + "bytes" + "cmp" + "embed" + "fmt" + "html/template" + "io" + "io/fs" + "log" + "net/http" + "path/filepath" + "reflect" + "slices" + "strings" + "sync" + + "github.com/coder/websocket" + "tailscale.com/tsweb" +) + +type httpDebugger struct { + *Debugger +} + +func (d *Debugger) RegisterHTTP(td *tsweb.DebugHandler) { + dh := httpDebugger{d} + td.Handle("bus", "Event bus", dh) + td.HandleSilent("bus/monitor", http.HandlerFunc(dh.serveMonitor)) + td.HandleSilent("bus/style.css", serveStatic("style.css")) + td.HandleSilent("bus/htmx.min.js", serveStatic("htmx.min.js.gz")) + td.HandleSilent("bus/htmx-websocket.min.js", serveStatic("htmx-websocket.min.js.gz")) +} + +//go:embed assets/*.html +var templatesSrc embed.FS + +var templates = sync.OnceValue(func() *template.Template { + d, err := fs.Sub(templatesSrc, "assets") + if err != nil { + panic(fmt.Errorf("getting eventbus debughttp templates subdir: %w", err)) + } + ret := template.New("").Funcs(map[string]any{ + "prettyPrintStruct": prettyPrintStruct, + }) + return template.Must(ret.ParseFS(d, "*")) +}) + +//go:generate go run fetch-htmx.go + +//go:embed assets/*.css assets/*.min.js.gz +var static embed.FS + +func serveStatic(name string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(name, ".css"): + w.Header().Set("Content-Type", "text/css") + case strings.HasSuffix(name, ".min.js.gz"): + w.Header().Set("Content-Type", "text/javascript") + w.Header().Set("Content-Encoding", "gzip") + case strings.HasSuffix(name, ".js"): + w.Header().Set("Content-Type", "text/javascript") + default: + http.Error(w, "not found", http.StatusNotFound) + return + } + + f, err := static.Open(filepath.Join("assets", name)) + if err != nil { + http.Error(w, fmt.Sprintf("opening asset: %v", err), http.StatusInternalServerError) + return + } + defer f.Close() + if _, err := io.Copy(w, f); err != nil { + http.Error(w, fmt.Sprintf("serving asset: %v", err), http.StatusInternalServerError) + return + } + }) +} + +func render(w http.ResponseWriter, name string, data any) { + err := templates().ExecuteTemplate(w, name+".html", data) + if err != nil { + err := fmt.Errorf("rendering template: %v", err) + log.Print(err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (h httpDebugger) ServeHTTP(w http.ResponseWriter, r *http.Request) { + type clientInfo struct { + *Client + Publish []reflect.Type + Subscribe []reflect.Type + } + type typeInfo struct { + reflect.Type + Publish []*Client + Subscribe []*Client + } + type info struct { + *Debugger + Clients map[string]*clientInfo + Types map[string]*typeInfo + } + + data := info{ + Debugger: h.Debugger, + Clients: map[string]*clientInfo{}, + Types: map[string]*typeInfo{}, + } + + getTypeInfo := func(t reflect.Type) *typeInfo { + if data.Types[t.Name()] == nil { + data.Types[t.Name()] = &typeInfo{ + Type: t, + } + } + return data.Types[t.Name()] + } + + for _, c := range h.Clients() { + ci := &clientInfo{ + Client: c, + Publish: h.PublishTypes(c), + Subscribe: h.SubscribeTypes(c), + } + slices.SortFunc(ci.Publish, func(a, b reflect.Type) int { return cmp.Compare(a.Name(), b.Name()) }) + slices.SortFunc(ci.Subscribe, func(a, b reflect.Type) int { return cmp.Compare(a.Name(), b.Name()) }) + data.Clients[c.Name()] = ci + + for _, t := range ci.Publish { + ti := getTypeInfo(t) + ti.Publish = append(ti.Publish, c) + } + for _, t := range ci.Subscribe { + ti := getTypeInfo(t) + ti.Subscribe = append(ti.Subscribe, c) + } + } + + render(w, "main", data) +} + +func (h httpDebugger) serveMonitor(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + h.serveMonitorStream(w, r) + return + } + + render(w, "monitor", nil) +} + +func (h httpDebugger) serveMonitorStream(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.CloseNow() + wsCtx := conn.CloseRead(r.Context()) + + mon := h.WatchBus() + defer mon.Close() + + i := 0 + for { + select { + case <-r.Context().Done(): + return + case <-wsCtx.Done(): + return + case <-mon.Done(): + return + case event := <-mon.Events(): + msg, err := conn.Writer(r.Context(), websocket.MessageText) + if err != nil { + return + } + data := map[string]any{ + "Count": i, + "Type": reflect.TypeOf(event.Event), + "Event": event, + } + i++ + if err := templates().ExecuteTemplate(msg, "event.html", data); err != nil { + log.Println(err) + return + } + if err := msg.Close(); err != nil { + return + } + } + } +} + +func prettyPrintStruct(t reflect.Type) string { + if t.Kind() != reflect.Struct { + return t.String() + } + var rec func(io.Writer, int, reflect.Type) + rec = func(out io.Writer, indent int, t reflect.Type) { + ind := strings.Repeat(" ", indent) + fmt.Fprintf(out, "%s", t.String()) + fs := collectFields(t) + if len(fs) > 0 { + io.WriteString(out, " {\n") + for _, f := range fs { + fmt.Fprintf(out, "%s %s ", ind, f.Name) + if f.Type.Kind() == reflect.Struct { + rec(out, indent+1, f.Type) + } else { + fmt.Fprint(out, f.Type) + } + io.WriteString(out, "\n") + } + fmt.Fprintf(out, "%s}", ind) + } + } + + var ret bytes.Buffer + rec(&ret, 0, t) + return ret.String() +} + +func collectFields(t reflect.Type) (ret []reflect.StructField) { + for _, f := range reflect.VisibleFields(t) { + if !f.IsExported() { + continue + } + ret = append(ret, f) + } + return ret +} diff --git a/util/eventbus/debughttp_off.go b/util/eventbus/debughttp_off.go new file mode 100644 index 000000000..332525262 --- /dev/null +++ b/util/eventbus/debughttp_off.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios || android || ts_omit_debugeventbus + +package eventbus + +type tswebDebugHandler = any // actually *tsweb.DebugHandler; any to avoid import tsweb with ts_omit_debugeventbus + +func (*Debugger) RegisterHTTP(td tswebDebugHandler) {} diff --git a/util/eventbus/doc.go b/util/eventbus/doc.go new file mode 100644 index 000000000..f95f9398c --- /dev/null +++ b/util/eventbus/doc.go @@ -0,0 +1,102 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package eventbus provides an in-process event bus. +// +// An event bus connects publishers of typed events with subscribers +// interested in those events. Typically, there is one global event +// bus per process. +// +// # Usage +// +// To send or receive events, first use [Bus.Client] to register with +// the bus. Clients should register with a human-readable name that +// identifies the code using the client, to aid in debugging. +// +// To publish events, use [Publish] on a Client to get a typed +// publisher for your event type, then call [Publisher.Publish] as +// needed. If your event is expensive to construct, you can optionally +// use [Publisher.ShouldPublish] to skip the work if nobody is +// listening for the event. +// +// To receive events, use [Subscribe] to get a typed subscriber for +// each event type you're interested in. Receive the events themselves +// by selecting over all your [Subscriber.Events] channels, as well as +// [Subscriber.Done] for shutdown notifications. +// +// # Concurrency properties +// +// The bus serializes all published events across all publishers, and +// preserves that ordering when delivering to subscribers that are +// attached to the same Client. In more detail: +// +// - An event is published to the bus at some instant between the +// start and end of the call to [Publisher.Publish]. +// - Two events cannot be published at the same instant, and so are +// totally ordered by their publication time. Given two events E1 +// and E2, either E1 happens before E2, or E2 happens before E1. +// - Clients dispatch events to their Subscribers in publication +// order: if E1 happens before E2, the client always delivers E1 +// before E2. +// - Clients do not synchronize subscriptions with each other: given +// clients C1 and C2, both subscribed to events E1 and E2, C1 may +// deliver both E1 and E2 before C2 delivers E1. +// +// Less formally: there is one true timeline of all published events. +// If you make a Client and subscribe to events, you will receive +// events one at a time, in the same order as the one true +// timeline. You will "skip over" events you didn't subscribe to, but +// your view of the world always moves forward in time, never +// backwards, and you will observe events in the same order as +// everyone else. +// +// However, you cannot assume that what your client see as "now" is +// the same as what other clients. They may be further behind you in +// working through the timeline, or running ahead of you. This means +// you should be careful about reaching out to another component +// directly after receiving an event, as its view of the world may not +// yet (or ever) be exactly consistent with yours. +// +// To make your code more testable and understandable, you should try +// to structure it following the actor model: you have some local +// state over which you have authority, but your only way to interact +// with state elsewhere in the program is to receive and process +// events coming from elsewhere, or to emit events of your own. +// +// # Expected subscriber behavior +// +// Subscribers are expected to promptly receive their events on +// [Subscriber.Events]. The bus has a small, fixed amount of internal +// buffering, meaning that a slow subscriber will eventually cause +// backpressure and block publication of all further events. +// +// In general, you should receive from your subscriber(s) in a loop, +// and only do fast state updates within that loop. Any heavier work +// should be offloaded to another goroutine. +// +// Causing publishers to block from backpressure is considered a bug +// in the slow subscriber causing the backpressure, and should be +// addressed there. Publishers should assume that Publish will not +// block for extended periods of time, and should not make exceptional +// effort to behave gracefully if they do get blocked. +// +// These blocking semantics are provisional and subject to +// change. Please speak up if this causes development pain, so that we +// can adapt the semantics to better suit our needs. +// +// # Debugging facilities +// +// The [Debugger], obtained through [Bus.Debugger], provides +// introspection facilities to monitor events flowing through the bus, +// and inspect publisher and subscriber state. +// +// Additionally, a debug command exists for monitoring the eventbus: +// +// tailscale debug daemon-bus-events +// +// # Testing facilities +// +// Helpers for testing code with the eventbus can be found in: +// +// eventbus/eventbustest +package eventbus diff --git a/util/eventbus/eventbustest/doc.go b/util/eventbus/eventbustest/doc.go new file mode 100644 index 000000000..1e9928b9d --- /dev/null +++ b/util/eventbus/eventbustest/doc.go @@ -0,0 +1,59 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package eventbustest provides helper methods for testing an [eventbus.Bus]. +// +// # Usage +// +// A [Watcher] presents a set of generic helpers for testing events. +// +// To test code that generates events, create a [Watcher] from the [eventbus.Bus] +// used by the code under test, run the code to generate events, then use the watcher +// to verify that the expected events were produced. In outline: +// +// bus := eventbustest.NewBus(t) +// tw := eventbustest.NewWatcher(t, bus) +// somethingThatEmitsSomeEvent() +// if err := eventbustest.Expect(tw, eventbustest.Type[EventFoo]()); err != nil { +// t.Error(err.Error()) +// } +// +// As shown, [Expect] checks that at least one event of the given type occurs +// in the stream generated by the code under test. +// +// The following functions all take an any parameter representing a function. +// This function will take an argument of the expected type and is used to test +// for the events on the eventbus being of the given type. The function can +// take the shape described in [Expect]. +// +// [Type] is a helper for only testing event type. +// +// To check for specific properties of an event, use [Expect], and pass a function +// as the second argument that tests for those properties. +// +// To test for multiple events, use [Expect], which checks that the stream +// contains the given events in the given order, possibly with other events +// interspersed. +// +// To test the complete contents of the stream, use [ExpectExactly], which +// checks that the stream contains exactly the given events in the given order, +// and no others. +// +// To test for the absence of events, use [ExpectExactly] without any +// expected events, along side [testing/synctest] to avoid waiting for timers +// to ensure that no events are produced. This will look like: +// +// synctest.Test(t, func(t *testing.T) { +// bus := eventbustest.NewBus(t) +// tw := eventbustest.NewWatcher(t, bus) +// somethingThatShouldNotEmitsSomeEvent() +// synctest.Wait() +// if err := eventbustest.ExpectExactly(tw); err != nil { +// t.Errorf("Expected no events or errors, got %v", err) +// } +// }) +// +// See the [usage examples]. +// +// [usage examples]: https://github.com/tailscale/tailscale/blob/main/util/eventbus/eventbustest/examples_test.go +package eventbustest diff --git a/util/eventbus/eventbustest/eventbustest.go b/util/eventbus/eventbustest/eventbustest.go new file mode 100644 index 000000000..fd8a15081 --- /dev/null +++ b/util/eventbus/eventbustest/eventbustest.go @@ -0,0 +1,295 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbustest + +import ( + "errors" + "fmt" + "reflect" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "tailscale.com/util/eventbus" +) + +// NewBus constructs an [eventbus.Bus] that will be shut automatically when +// its controlling test ends. +func NewBus(t testing.TB) *eventbus.Bus { + bus := eventbus.New() + t.Cleanup(bus.Close) + return bus +} + +// NewWatcher constructs a [Watcher] that can be used to check the stream of +// events generated by code under test. After construction the caller may use +// [Expect] and [ExpectExactly], to verify that the desired events were captured. +func NewWatcher(t *testing.T, bus *eventbus.Bus) *Watcher { + tw := &Watcher{ + mon: bus.Debugger().WatchBus(), + chDone: make(chan bool, 1), + events: make(chan any, 100), + } + t.Cleanup(tw.done) + go tw.watch() + return tw +} + +// Watcher monitors and holds events for test expectations. +// The Watcher works with [synctest], and some scenarios does require the use of +// [synctest]. This is amongst others true if you are testing for the absence of +// events. +// +// For usage examples, see the documentation in the top of the package. +type Watcher struct { + mon *eventbus.Subscriber[eventbus.RoutedEvent] + events chan any + chDone chan bool +} + +// Type is a helper representing the expectation to see an event of type T, without +// caring about the content of the event. +// It makes it possible to use helpers like: +// +// eventbustest.ExpectFilter(tw, eventbustest.Type[EventFoo]()) +func Type[T any]() func(T) { return func(T) {} } + +// Expect verifies that the given events are a subsequence of the events +// observed by tw. That is, tw must contain at least one event matching the type +// of each argument in the given order, other event types are allowed to occur in +// between without error. The given events are represented by a function +// that must have one of the following forms: +// +// // Tests for the event type only +// func(e ExpectedType) +// +// // Tests for event type and whatever is defined in the body. +// // If return is false, the test will look for other events of that type +// // If return is true, the test will look for the next given event +// // if a list is given +// func(e ExpectedType) bool +// +// // Tests for event type and whatever is defined in the body. +// // The boolean return works as above. +// // The if error != nil, the test helper will return that error immediately. +// func(e ExpectedType) (bool, error) +// +// // Tests for event type and whatever is defined in the body. +// // If a non-nil error is reported, the test helper will return that error +// // immediately; otherwise the expectation is considered to be met. +// func(e ExpectedType) error +// +// If the list of events must match exactly with no extra events, +// use [ExpectExactly]. +func Expect(tw *Watcher, filters ...any) error { + if len(filters) == 0 { + return errors.New("no event filters were provided") + } + eventCount := 0 + head := 0 + for head < len(filters) { + eventFunc := eventFilter(filters[head]) + select { + case event := <-tw.events: + eventCount++ + if ok, err := eventFunc(event); err != nil { + return err + } else if ok { + head++ + } + // Use synctest when you want an error here. + case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock + return fmt.Errorf( + "timed out waiting for event, saw %d events, %d was expected", + eventCount, len(filters)) + case <-tw.chDone: + return errors.New("watcher closed while waiting for events") + } + } + return nil +} + +// ExpectExactly checks for some number of events showing up on the event bus +// in a given order, returning an error if the events does not match the given list +// exactly. The given events are represented by a function as described in +// [Expect]. Use [Expect] if other events are allowed. +// +// If you are expecting ExpectExactly to fail because of a missing event, or if +// you are testing for the absence of events, call [synctest.Wait] after +// actions that would publish an event, but before calling ExpectExactly. +func ExpectExactly(tw *Watcher, filters ...any) error { + if len(filters) == 0 { + select { + case event := <-tw.events: + return fmt.Errorf("saw event type %s, expected none", reflect.TypeOf(event)) + case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock + return nil + } + } + eventCount := 0 + for pos, next := range filters { + eventFunc := eventFilter(next) + fnType := reflect.TypeOf(next) + argType := fnType.In(0) + select { + case event := <-tw.events: + eventCount++ + typeEvent := reflect.TypeOf(event) + if typeEvent != argType { + return fmt.Errorf( + "expected event type %s, saw %s, at index %d", + argType, typeEvent, pos) + } else if ok, err := eventFunc(event); err != nil { + return err + } else if !ok { + return fmt.Errorf( + "expected test ok for type %s, at index %d", argType, pos) + } + case <-time.After(100 * time.Second): // "indefinitely", to advance a synctest clock + return fmt.Errorf( + "timed out waiting for event, saw %d events, %d was expected", + eventCount, len(filters)) + case <-tw.chDone: + return errors.New("watcher closed while waiting for events") + } + } + return nil +} + +func (tw *Watcher) watch() { + for { + select { + case event := <-tw.mon.Events(): + tw.events <- event.Event + case <-tw.mon.Done(): + tw.done() + return + case <-tw.chDone: + tw.mon.Close() + return + } + } +} + +// done tells the watcher to stop monitoring for new events. +func (tw *Watcher) done() { + close(tw.chDone) +} + +type filter = func(any) (bool, error) + +func eventFilter(f any) filter { + ft := reflect.TypeOf(f) + if ft.Kind() != reflect.Func { + panic("filter is not a function") + } else if ft.NumIn() != 1 { + panic(fmt.Sprintf("function takes %d arguments, want 1", ft.NumIn())) + } + var fixup func([]reflect.Value) []reflect.Value + switch ft.NumOut() { + case 0: + fixup = func([]reflect.Value) []reflect.Value { + return []reflect.Value{reflect.ValueOf(true), reflect.Zero(reflect.TypeFor[error]())} + } + case 1: + switch ft.Out(0) { + case reflect.TypeFor[bool](): + fixup = func(vals []reflect.Value) []reflect.Value { + return append(vals, reflect.Zero(reflect.TypeFor[error]())) + } + case reflect.TypeFor[error](): + fixup = func(vals []reflect.Value) []reflect.Value { + pass := vals[0].IsZero() + return append([]reflect.Value{reflect.ValueOf(pass)}, vals...) + } + default: + panic(fmt.Sprintf("result is %v, want bool or error", ft.Out(0))) + } + case 2: + if ft.Out(0) != reflect.TypeFor[bool]() || ft.Out(1) != reflect.TypeFor[error]() { + panic(fmt.Sprintf("results are %v, %v; want bool, error", ft.Out(0), ft.Out(1))) + } + fixup = func(vals []reflect.Value) []reflect.Value { return vals } + default: + panic(fmt.Sprintf("function returns %d values", ft.NumOut())) + } + fv := reflect.ValueOf(f) + return reflect.MakeFunc(reflect.TypeFor[filter](), func(args []reflect.Value) []reflect.Value { + if !args[0].IsValid() || args[0].Elem().Type() != ft.In(0) { + return []reflect.Value{reflect.ValueOf(false), reflect.Zero(reflect.TypeFor[error]())} + } + return fixup(fv.Call([]reflect.Value{args[0].Elem()})) + }).Interface().(filter) +} + +// Injector holds a map with [eventbus.Publisher], tied to an [eventbus.Client] +// for testing purposes. +type Injector struct { + client *eventbus.Client + publishers map[reflect.Type]any + // The value for a key is an *eventbus.Publisher[T] for the corresponding type. +} + +// NewInjector constructs an [Injector] that can be used to inject events into +// the the stream of events used by code under test. After construction the +// caller may use [Inject] to insert events into the bus. +func NewInjector(t *testing.T, b *eventbus.Bus) *Injector { + inj := &Injector{ + client: b.Client(t.Name()), + publishers: make(map[reflect.Type]any), + } + t.Cleanup(inj.client.Close) + + return inj +} + +// Inject inserts events of T onto an [eventbus.Bus]. If an [eventbus.Publisher] +// for the type does not exist, it will be initialized lazily. Calling inject is +// synchronous, and the event will as such have been published to the eventbus +// by the time the function returns. +func Inject[T any](inj *Injector, event T) { + eventType := reflect.TypeFor[T]() + + pub, ok := inj.publishers[eventType] + if !ok { + pub = eventbus.Publish[T](inj.client) + inj.publishers[eventType] = pub + } + pub.(*eventbus.Publisher[T]).Publish(event) +} + +// EqualTo returns an event-matching function for use with [Expect] and +// [ExpectExactly] that matches on an event of the given type that is equal to +// want by comparison with [cmp.Diff]. The expectation fails with an error +// message including the diff, if present. +func EqualTo[T any](want T) func(T) error { + return func(got T) error { + if diff := cmp.Diff(got, want); diff != "" { + return fmt.Errorf("wrong result (-got, +want):\n%s", diff) + } + return nil + } +} + +// LogAllEvents logs summaries of all the events routed via the specified bus +// during the execution of the test governed by t. This is intended to support +// development and debugging of tests. +func LogAllEvents(t testing.TB, bus *eventbus.Bus) { + dw := bus.Debugger().WatchBus() + done := make(chan struct{}) + go func() { + defer close(done) + var i int + for { + select { + case <-dw.Done(): + return + case re := <-dw.Events(): + i++ + t.Logf("[eventbus] #%[1]d: %[2]T | %+[2]v", i, re.Event) + } + } + }() + t.Cleanup(func() { dw.Close(); <-done }) +} diff --git a/util/eventbus/eventbustest/eventbustest_test.go b/util/eventbus/eventbustest/eventbustest_test.go new file mode 100644 index 000000000..ac454023c --- /dev/null +++ b/util/eventbus/eventbustest/eventbustest_test.go @@ -0,0 +1,408 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbustest_test + +import ( + "flag" + "fmt" + "strings" + "testing" + "testing/synctest" + + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" +) + +var doDebug = flag.Bool("debug", false, "Enable debug logging") + +type EventFoo struct { + Value int +} + +type EventBar struct { + Value string +} + +type EventBaz struct { + Value []float64 +} + +func TestExpectFilter(t *testing.T) { + tests := []struct { + name string + events []int + expectFunc any + wantErr string // if non-empty, an error is expected containing this text + }{ + { + name: "single event", + events: []int{42}, + expectFunc: eventbustest.Type[EventFoo](), + }, + { + name: "multiple events, single expectation", + events: []int{42, 1, 2, 3, 4, 5}, + expectFunc: eventbustest.Type[EventFoo](), + }, + { + name: "filter on event with function", + events: []int{24, 42}, + expectFunc: func(event EventFoo) (bool, error) { + if event.Value == 42 { + return true, nil + } + return false, nil + }, + }, + { + name: "filter-with-nil-error", + events: []int{1, 2, 3}, + expectFunc: func(event EventFoo) error { + if event.Value > 10 { + return fmt.Errorf("value > 10: %d", event.Value) + } + return nil + }, + }, + { + name: "filter-with-non-nil-error", + events: []int{100, 200, 300}, + expectFunc: func(event EventFoo) error { + if event.Value > 10 { + return fmt.Errorf("value > 10: %d", event.Value) + } + return nil + }, + wantErr: "value > 10", + }, + { + name: "first event has to be func", + events: []int{24, 42}, + expectFunc: func(event EventFoo) (bool, error) { + if event.Value != 42 { + return false, fmt.Errorf("expected 42, got %d", event.Value) + } + return false, nil + }, + wantErr: "expected 42, got 24", + }, + { + name: "equal-values", + events: []int{23}, + expectFunc: eventbustest.EqualTo(EventFoo{Value: 23}), + }, + { + name: "unequal-values", + events: []int{37}, + expectFunc: eventbustest.EqualTo(EventFoo{Value: 23}), + wantErr: "wrong result (-got, +want)", + }, + { + name: "no events", + events: []int{}, + expectFunc: func(event EventFoo) (bool, error) { + return true, nil + }, + wantErr: "timed out waiting", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + bus := eventbustest.NewBus(t) + + if *doDebug { + eventbustest.LogAllEvents(t, bus) + } + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updater := eventbus.Publish[EventFoo](client) + + for _, i := range tt.events { + updater.Publish(EventFoo{i}) + } + + synctest.Wait() + + if err := eventbustest.Expect(tw, tt.expectFunc); err != nil { + if tt.wantErr == "" { + t.Errorf("Expect[EventFoo]: unexpected error: %v", err) + } else if !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("Expect[EventFoo]: err = %v, want %q", err, tt.wantErr) + } else { + t.Logf("Got expected error: %v (OK)", err) + } + } else if tt.wantErr != "" { + t.Errorf("Expect[EventFoo]: unexpectedly succeeded, want error %q", tt.wantErr) + } + }) + }) + } +} + +func TestExpectEvents(t *testing.T) { + tests := []struct { + name string + events []any + expectEvents []any + wantErr bool + }{ + { + name: "No expectations", + events: []any{EventFoo{}}, + expectEvents: []any{}, + wantErr: true, + }, + { + name: "One event", + events: []any{EventFoo{}}, + expectEvents: []any{eventbustest.Type[EventFoo]()}, + wantErr: false, + }, + { + name: "Two events", + events: []any{EventFoo{}, EventBar{}}, + expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, + wantErr: false, + }, + { + name: "Two expected events with another in the middle", + events: []any{EventFoo{}, EventBaz{}, EventBar{}}, + expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, + wantErr: false, + }, + { + name: "Missing event", + events: []any{EventFoo{}, EventBaz{}}, + expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, + wantErr: true, + }, + { + name: "One event with specific value", + events: []any{EventFoo{42}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + }, + wantErr: false, + }, + { + name: "Two event with one specific value", + events: []any{EventFoo{43}, EventFoo{42}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + }, + wantErr: false, + }, + { + name: "One event with wrong value", + events: []any{EventFoo{43}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + }, + wantErr: true, + }, + { + name: "Two events with specific values", + events: []any{EventFoo{42}, EventFoo{42}, EventBar{"42"}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + func(ev EventBar) (bool, error) { + if ev.Value == "42" { + return true, nil + } + return false, nil + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + bus := eventbustest.NewBus(t) + + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updaterFoo := eventbus.Publish[EventFoo](client) + updaterBar := eventbus.Publish[EventBar](client) + updaterBaz := eventbus.Publish[EventBaz](client) + + for _, ev := range tt.events { + switch ev := ev.(type) { + case EventFoo: + evCast := ev + updaterFoo.Publish(evCast) + case EventBar: + evCast := ev + updaterBar.Publish(evCast) + case EventBaz: + evCast := ev + updaterBaz.Publish(evCast) + } + } + + synctest.Wait() + if err := eventbustest.Expect(tw, tt.expectEvents...); (err != nil) != tt.wantErr { + t.Errorf("ExpectEvents: error = %v, wantErr %v", err, tt.wantErr) + } + }) + }) + } +} + +func TestExpectExactlyEventsFilter(t *testing.T) { + tests := []struct { + name string + events []any + expectEvents []any + wantErr bool + }{ + { + name: "No expectations", + events: []any{EventFoo{}}, + expectEvents: []any{}, + wantErr: true, + }, + { + name: "One event", + events: []any{EventFoo{}}, + expectEvents: []any{eventbustest.Type[EventFoo]()}, + wantErr: false, + }, + { + name: "Two events", + events: []any{EventFoo{}, EventBar{}}, + expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, + wantErr: false, + }, + { + name: "Two expected events with another in the middle", + events: []any{EventFoo{}, EventBaz{}, EventBar{}}, + expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, + wantErr: true, + }, + { + name: "Missing event", + events: []any{EventFoo{}, EventBaz{}}, + expectEvents: []any{eventbustest.Type[EventFoo](), eventbustest.Type[EventBar]()}, + wantErr: true, + }, + { + name: "One event with value", + events: []any{EventFoo{42}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + }, + wantErr: false, + }, + { + name: "Two event with one specific value", + events: []any{EventFoo{43}, EventFoo{42}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + }, + wantErr: true, + }, + { + name: "One event with wrong value", + events: []any{EventFoo{43}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + }, + wantErr: true, + }, + { + name: "Two events with specific values", + events: []any{EventFoo{42}, EventFoo{42}, EventBar{"42"}}, + expectEvents: []any{ + func(ev EventFoo) (bool, error) { + if ev.Value == 42 { + return true, nil + } + return false, nil + }, + func(ev EventBar) (bool, error) { + if ev.Value == "42" { + return true, nil + } + return false, nil + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + bus := eventbustest.NewBus(t) + + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updaterFoo := eventbus.Publish[EventFoo](client) + updaterBar := eventbus.Publish[EventBar](client) + updaterBaz := eventbus.Publish[EventBaz](client) + + for _, ev := range tt.events { + switch ev := ev.(type) { + case EventFoo: + evCast := ev + updaterFoo.Publish(evCast) + case EventBar: + evCast := ev + updaterBar.Publish(evCast) + case EventBaz: + evCast := ev + updaterBaz.Publish(evCast) + } + } + + synctest.Wait() + if err := eventbustest.ExpectExactly(tw, tt.expectEvents...); (err != nil) != tt.wantErr { + t.Errorf("ExpectEvents: error = %v, wantErr %v", err, tt.wantErr) + } + }) + }) + } +} diff --git a/util/eventbus/eventbustest/examples_test.go b/util/eventbus/eventbustest/examples_test.go new file mode 100644 index 000000000..c84811317 --- /dev/null +++ b/util/eventbus/eventbustest/examples_test.go @@ -0,0 +1,260 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbustest_test + +import ( + "testing" + "testing/synctest" + "time" + + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" +) + +func TestExample_Expect(t *testing.T) { + type eventOfInterest struct{} + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updater := eventbus.Publish[eventOfInterest](client) + updater.Publish(eventOfInterest{}) + + if err := eventbustest.Expect(tw, eventbustest.Type[eventOfInterest]()); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } + // Output: + // OK +} + +func TestExample_Expect_WithFunction(t *testing.T) { + type eventOfInterest struct { + value int + } + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updater := eventbus.Publish[eventOfInterest](client) + updater.Publish(eventOfInterest{43}) + updater.Publish(eventOfInterest{42}) + + // Look for an event of eventOfInterest with a specific value + if err := eventbustest.Expect(tw, func(event eventOfInterest) (bool, error) { + if event.value != 42 { + return false, nil // Look for another event with the expected value. + // You could alternatively return an error here to ensure that the + // first seen eventOfInterest matches the value: + // return false, fmt.Errorf("expected 42, got %d", event.value) + } + return true, nil + }); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } + // Output: + // OK +} + +func TestExample_Expect_MultipleEvents(t *testing.T) { + type eventOfInterest struct{} + type eventOfNoConcern struct{} + type eventOfCuriosity struct{} + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updaterInterest := eventbus.Publish[eventOfInterest](client) + updaterConcern := eventbus.Publish[eventOfNoConcern](client) + updaterCuriosity := eventbus.Publish[eventOfCuriosity](client) + updaterInterest.Publish(eventOfInterest{}) + updaterConcern.Publish(eventOfNoConcern{}) + updaterCuriosity.Publish(eventOfCuriosity{}) + + // Even though three events was published, we just care about the two + if err := eventbustest.Expect(tw, + eventbustest.Type[eventOfInterest](), + eventbustest.Type[eventOfCuriosity]()); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } + // Output: + // OK +} + +func TestExample_ExpectExactly_MultipleEvents(t *testing.T) { + type eventOfInterest struct{} + type eventOfNoConcern struct{} + type eventOfCuriosity struct{} + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updaterInterest := eventbus.Publish[eventOfInterest](client) + updaterConcern := eventbus.Publish[eventOfNoConcern](client) + updaterCuriosity := eventbus.Publish[eventOfCuriosity](client) + updaterInterest.Publish(eventOfInterest{}) + updaterConcern.Publish(eventOfNoConcern{}) + updaterCuriosity.Publish(eventOfCuriosity{}) + + // Will fail as more events than the two expected comes in + if err := eventbustest.ExpectExactly(tw, + eventbustest.Type[eventOfInterest](), + eventbustest.Type[eventOfCuriosity]()); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } +} + +func TestExample_Expect_WithMultipleFunctions(t *testing.T) { + type eventOfInterest struct { + value int + } + type eventOfNoConcern struct{} + type eventOfCuriosity struct { + value string + } + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updaterInterest := eventbus.Publish[eventOfInterest](client) + updaterConcern := eventbus.Publish[eventOfNoConcern](client) + updaterCuriosity := eventbus.Publish[eventOfCuriosity](client) + updaterInterest.Publish(eventOfInterest{42}) + updaterConcern.Publish(eventOfNoConcern{}) + updaterCuriosity.Publish(eventOfCuriosity{"42"}) + + interest := func(event eventOfInterest) (bool, error) { + if event.value == 42 { + return true, nil + } + return false, nil + } + curiosity := func(event eventOfCuriosity) (bool, error) { + if event.value == "42" { + return true, nil + } + return false, nil + } + + // Will fail as more events than the two expected comes in + if err := eventbustest.Expect(tw, interest, curiosity); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } + // Output: + // OK +} + +func TestExample_ExpectExactly_WithMultipleFunctions(t *testing.T) { + type eventOfInterest struct { + value int + } + type eventOfNoConcern struct{} + type eventOfCuriosity struct { + value string + } + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + client := bus.Client("testClient") + updaterInterest := eventbus.Publish[eventOfInterest](client) + updaterConcern := eventbus.Publish[eventOfNoConcern](client) + updaterCuriosity := eventbus.Publish[eventOfCuriosity](client) + updaterInterest.Publish(eventOfInterest{42}) + updaterConcern.Publish(eventOfNoConcern{}) + updaterCuriosity.Publish(eventOfCuriosity{"42"}) + + interest := func(event eventOfInterest) (bool, error) { + if event.value == 42 { + return true, nil + } + return false, nil + } + curiosity := func(event eventOfCuriosity) (bool, error) { + if event.value == "42" { + return true, nil + } + return false, nil + } + + // Will fail as more events than the two expected comes in + if err := eventbustest.ExpectExactly(tw, interest, curiosity); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } + // Output: + // expected event type eventbustest.eventOfCuriosity, saw eventbustest.eventOfNoConcern, at index 1 +} + +func TestExample_ExpectExactly_NoEvents(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + + go func() { + // Do some work that does not produce an event + time.Sleep(10 * time.Second) + t.Log("Not producing events") + }() + + // Wait for all other routines to be stale before continuing to ensure that + // there is nothing running that would produce an event at a later time. + synctest.Wait() + + if err := eventbustest.ExpectExactly(tw); err != nil { + t.Error(err.Error()) + } else { + t.Log("OK") + } + // Output: + // OK + }) +} + +func TestExample_ExpectExactly_OneEventExpectingTwo(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + type eventOfInterest struct{} + + bus := eventbustest.NewBus(t) + tw := eventbustest.NewWatcher(t, bus) + client := bus.Client("testClient") + updater := eventbus.Publish[eventOfInterest](client) + + go func() { + // Do some work that does not produce an event + time.Sleep(10 * time.Second) + updater.Publish(eventOfInterest{}) + }() + + // Wait for all other routines to be stale before continuing to ensure that + // there is nothing running that would produce an event at a later time. + synctest.Wait() + + if err := eventbustest.ExpectExactly(tw, + eventbustest.Type[eventOfInterest](), + eventbustest.Type[eventOfInterest](), + ); err != nil { + t.Log(err.Error()) + } else { + t.Log("OK") + } + // Output: + // timed out waiting for event, saw 1 events, 2 was expected + }) +} diff --git a/util/eventbus/fetch-htmx.go b/util/eventbus/fetch-htmx.go new file mode 100644 index 000000000..f80d50257 --- /dev/null +++ b/util/eventbus/fetch-htmx.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Program fetch-htmx fetches and installs local copies of the HTMX +// library and its dependencies, used by the debug UI. It is meant to +// be run via go generate. +package main + +import ( + "compress/gzip" + "crypto/sha512" + "encoding/base64" + "fmt" + "io" + "log" + "net/http" + "os" +) + +func main() { + // Hash from https://htmx.org/docs/#installing + htmx, err := fetchHashed("https://unpkg.com/htmx.org@2.0.4", "HGfztofotfshcF7+8n44JQL2oJmowVChPTg48S+jvZoztPfvwD79OC/LTtG6dMp+") + if err != nil { + log.Fatalf("fetching htmx: %v", err) + } + + // Hash SHOULD be from https://htmx.org/extensions/ws/ , but the + // hash is currently incorrect, see + // https://github.com/bigskysoftware/htmx-extensions/issues/153 + // + // Until that bug is resolved, hash was obtained by rebuilding the + // extension from git source, and verifying that the hash matches + // what unpkg is serving. + ws, err := fetchHashed("https://unpkg.com/htmx-ext-ws@2.0.2", "932iIqjARv+Gy0+r6RTGrfCkCKS5MsF539Iqf6Vt8L4YmbnnWI2DSFoMD90bvXd0") + if err != nil { + log.Fatalf("fetching htmx-websockets: %v", err) + } + + if err := writeGz("assets/htmx.min.js.gz", htmx); err != nil { + log.Fatalf("writing htmx.min.js.gz: %v", err) + } + if err := writeGz("assets/htmx-websocket.min.js.gz", ws); err != nil { + log.Fatalf("writing htmx-websocket.min.js.gz: %v", err) + } +} + +func writeGz(path string, bs []byte) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + g, err := gzip.NewWriterLevel(f, gzip.BestCompression) + if err != nil { + return err + } + + if _, err := g.Write(bs); err != nil { + return err + } + + if err := g.Flush(); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + return nil +} + +func fetchHashed(url, wantHash string) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("fetching %q returned error status: %s", url, resp.Status) + } + ret, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading file from %q: %v", url, err) + } + h := sha512.Sum384(ret) + got := base64.StdEncoding.EncodeToString(h[:]) + if got != wantHash { + return nil, fmt.Errorf("wrong hash for %q: got %q, want %q", url, got, wantHash) + } + return ret, nil +} diff --git a/util/eventbus/monitor.go b/util/eventbus/monitor.go new file mode 100644 index 000000000..db6fe1be4 --- /dev/null +++ b/util/eventbus/monitor.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import "tailscale.com/syncs" + +// A Monitor monitors the execution of a goroutine processing events from a +// [Client], allowing the caller to block until it is complete. The zero value +// of m is valid; its Close and Wait methods return immediately, and its Done +// method returns an already-closed channel. +type Monitor struct { + // These fields are immutable after initialization + cli *Client + done <-chan struct{} +} + +// Close closes the client associated with m and blocks until the processing +// goroutine is complete. +func (m Monitor) Close() { + if m.cli == nil { + return + } + m.cli.Close() + <-m.done +} + +// Wait blocks until the goroutine monitored by m has finished executing, but +// does not close the associated client. It is safe to call Wait repeatedly, +// and from multiple concurrent goroutines. +func (m Monitor) Wait() { + if m.done == nil { + return + } + <-m.done +} + +// Done returns a channel that is closed when the monitored goroutine has +// finished executing. +func (m Monitor) Done() <-chan struct{} { + if m.done == nil { + return syncs.ClosedChan() + } + return m.done +} + +// Monitor executes f in a new goroutine attended by a [Monitor]. The caller +// is responsible for waiting for the goroutine to complete, by calling either +// [Monitor.Close] or [Monitor.Wait]. +func (c *Client) Monitor(f func(*Client)) Monitor { + done := make(chan struct{}) + go func() { defer close(done); f(c) }() + return Monitor{cli: c, done: done} +} diff --git a/util/eventbus/publish.go b/util/eventbus/publish.go new file mode 100644 index 000000000..348bb9dff --- /dev/null +++ b/util/eventbus/publish.go @@ -0,0 +1,74 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "reflect" +) + +// publisher is a uniformly typed wrapper around Publisher[T], so that +// debugging facilities can look at active publishers. +type publisher interface { + publishType() reflect.Type + Close() +} + +// A Publisher publishes typed events on a bus. +type Publisher[T any] struct { + client *Client + stop stopFlag +} + +func newPublisher[T any](c *Client) *Publisher[T] { + return &Publisher[T]{client: c} +} + +// Close closes the publisher. +// +// Calls to Publish after Close silently do nothing. +// +// If the Bus or Client from which the Publisher was created is closed, +// the Publisher is implicitly closed and does not need to be closed +// separately. +func (p *Publisher[T]) Close() { + // Just unblocks any active calls to Publish, no other + // synchronization needed. + p.stop.Stop() + p.client.deletePublisher(p) +} + +func (p *Publisher[T]) publishType() reflect.Type { + return reflect.TypeFor[T]() +} + +// Publish publishes event v on the bus. +func (p *Publisher[T]) Publish(v T) { + // Check for just a stopped publisher or bus before trying to + // write, so that once closed Publish consistently does nothing. + select { + case <-p.stop.Done(): + return + default: + } + + evt := PublishedEvent{ + Event: v, + From: p.client, + } + + select { + case p.client.publish() <- evt: + case <-p.stop.Done(): + } +} + +// ShouldPublish reports whether anyone is subscribed to the events +// that this publisher emits. +// +// ShouldPublish can be used to skip expensive event construction if +// nobody seems to care. Publishers must not assume that someone will +// definitely receive an event if ShouldPublish returns true. +func (p *Publisher[T]) ShouldPublish() bool { + return p.client.shouldPublish(reflect.TypeFor[T]()) +} diff --git a/util/eventbus/queue.go b/util/eventbus/queue.go new file mode 100644 index 000000000..a62bf3c62 --- /dev/null +++ b/util/eventbus/queue.go @@ -0,0 +1,85 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "slices" +) + +const maxQueuedItems = 16 + +// queue is an ordered queue of length up to maxQueuedItems. +type queue[T any] struct { + vals []T + start int +} + +// canAppend reports whether a value can be appended to q.vals without +// shifting values around. +func (q *queue[T]) canAppend() bool { + return cap(q.vals) < maxQueuedItems || len(q.vals) < cap(q.vals) +} + +func (q *queue[T]) Full() bool { + return q.start == 0 && !q.canAppend() +} + +func (q *queue[T]) Empty() bool { + return q.start == len(q.vals) +} + +func (q *queue[T]) Len() int { + return len(q.vals) - q.start +} + +// Add adds v to the end of the queue. Blocks until append can be +// done. +func (q *queue[T]) Add(v T) { + if !q.canAppend() { + if q.start == 0 { + panic("Add on a full queue") + } + + // Slide remaining values back to the start of the array. + n := copy(q.vals, q.vals[q.start:]) + toClear := len(q.vals) - n + clear(q.vals[len(q.vals)-toClear:]) + q.vals = q.vals[:n] + q.start = 0 + } + + q.vals = append(q.vals, v) +} + +// Peek returns the first value in the queue, without removing it from +// the queue, or nil if the queue is empty. +func (q *queue[T]) Peek() T { + if q.Empty() { + var zero T + return zero + } + + return q.vals[q.start] +} + +// Drop discards the first value in the queue, if any. +func (q *queue[T]) Drop() { + if q.Empty() { + return + } + + var zero T + q.vals[q.start] = zero + q.start++ + if q.Empty() { + // Reset cursor to start of array, it's free to do. + q.start = 0 + q.vals = q.vals[:0] + } +} + +// Snapshot returns a copy of the queue's contents. +func (q *queue[T]) Snapshot() []T { + return slices.Clone(q.vals[q.start:]) +} diff --git a/util/eventbus/subscribe.go b/util/eventbus/subscribe.go new file mode 100644 index 000000000..b0348e125 --- /dev/null +++ b/util/eventbus/subscribe.go @@ -0,0 +1,356 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "context" + "fmt" + "reflect" + "runtime" + "time" + + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/cibuild" +) + +type DeliveredEvent struct { + Event any + From *Client + To *Client +} + +// subscriber is a uniformly typed wrapper around Subscriber[T], so +// that debugging facilities can look at active subscribers. +type subscriber interface { + subscribeType() reflect.Type + // dispatch is a function that dispatches the head value in vals to + // a subscriber, while also handling stop and incoming queue write + // events. + // + // dispatch exists because of the strongly typed Subscriber[T] + // wrapper around subscriptions: within the bus events are boxed in an + // 'any', and need to be unpacked to their full type before delivery + // to the subscriber. This involves writing to a strongly-typed + // channel, so subscribeState cannot handle that dispatch by itself - + // but if that strongly typed send blocks, we also need to keep + // processing other potential sources of wakeups, which is how we end + // up at this awkward type signature and sharing of internal state + // through dispatch. + dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool + Close() +} + +// subscribeState handles dispatching of events received from a Bus. +type subscribeState struct { + client *Client + + dispatcher *worker + write chan DeliveredEvent + snapshot chan chan []DeliveredEvent + debug hook[DeliveredEvent] + + outputsMu syncs.Mutex + outputs map[reflect.Type]subscriber +} + +func newSubscribeState(c *Client) *subscribeState { + ret := &subscribeState{ + client: c, + write: make(chan DeliveredEvent), + snapshot: make(chan chan []DeliveredEvent), + outputs: map[reflect.Type]subscriber{}, + } + ret.dispatcher = runWorker(ret.pump) + return ret +} + +func (s *subscribeState) pump(ctx context.Context) { + var vals queue[DeliveredEvent] + acceptCh := func() chan DeliveredEvent { + if vals.Full() { + return nil + } + return s.write + } + for { + if !vals.Empty() { + val := vals.Peek() + sub := s.subscriberFor(val.Event) + if sub == nil { + // Raced with unsubscribe. + vals.Drop() + continue + } + if !sub.dispatch(ctx, &vals, acceptCh, s.snapshot) { + return + } + + if s.debug.active() { + s.debug.run(DeliveredEvent{ + Event: val.Event, + From: val.From, + To: s.client, + }) + } + } else { + // Keep the cases in this select in sync with + // Subscriber.dispatch and SubscriberFunc.dispatch below. + // The only difference should be that this select doesn't deliver + // queued values to anyone, and unconditionally accepts new values. + select { + case val := <-s.write: + vals.Add(val) + case <-ctx.Done(): + return + case ch := <-s.snapshot: + ch <- vals.Snapshot() + } + } + } +} + +func (s *subscribeState) snapshotQueue() []DeliveredEvent { + if s == nil { + return nil + } + + resp := make(chan []DeliveredEvent) + select { + case s.snapshot <- resp: + return <-resp + case <-s.dispatcher.Done(): + return nil + } +} + +func (s *subscribeState) subscribeTypes() []reflect.Type { + if s == nil { + return nil + } + + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + ret := make([]reflect.Type, 0, len(s.outputs)) + for t := range s.outputs { + ret = append(ret, t) + } + return ret +} + +func (s *subscribeState) addSubscriber(sub subscriber) { + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + t := sub.subscribeType() + if s.outputs[t] != nil { + panic(fmt.Errorf("double subscription for event %s", t)) + } + s.outputs[t] = sub + s.client.addSubscriber(t, s) +} + +func (s *subscribeState) deleteSubscriber(t reflect.Type) { + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + delete(s.outputs, t) + s.client.deleteSubscriber(t, s) +} + +func (s *subscribeState) subscriberFor(val any) subscriber { + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + return s.outputs[reflect.TypeOf(val)] +} + +// Close closes the subscribeState. It implicitly closes all Subscribers +// linked to this state, and any pending events are discarded. +func (s *subscribeState) close() { + s.dispatcher.StopAndWait() + + var subs map[reflect.Type]subscriber + s.outputsMu.Lock() + subs, s.outputs = s.outputs, nil + s.outputsMu.Unlock() + for _, sub := range subs { + sub.Close() + } +} + +func (s *subscribeState) closed() <-chan struct{} { + return s.dispatcher.Done() +} + +// A Subscriber delivers one type of event from a [Client]. +// Events are sent to the [Subscriber.Events] channel. +type Subscriber[T any] struct { + stop stopFlag + read chan T + unregister func() + logf logger.Logf + slow *time.Timer // used to detect slow subscriber service +} + +func newSubscriber[T any](r *subscribeState, logf logger.Logf) *Subscriber[T] { + slow := time.NewTimer(0) + slow.Stop() // reset in dispatch + return &Subscriber[T]{ + read: make(chan T), + unregister: func() { r.deleteSubscriber(reflect.TypeFor[T]()) }, + logf: logf, + slow: slow, + } +} + +func newMonitor[T any](attach func(fn func(T)) (cancel func())) *Subscriber[T] { + ret := &Subscriber[T]{ + read: make(chan T, 100), // arbitrary, large + } + ret.unregister = attach(ret.monitor) + return ret +} + +func (s *Subscriber[T]) subscribeType() reflect.Type { + return reflect.TypeFor[T]() +} + +func (s *Subscriber[T]) monitor(debugEvent T) { + select { + case s.read <- debugEvent: + case <-s.stop.Done(): + } +} + +func (s *Subscriber[T]) dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool { + t := vals.Peek().Event.(T) + + start := time.Now() + s.slow.Reset(slowSubscriberTimeout) + defer s.slow.Stop() + + for { + // Keep the cases in this select in sync with subscribeState.pump + // above. The only difference should be that this select + // delivers a value on s.read. + select { + case s.read <- t: + vals.Drop() + return true + case val := <-acceptCh(): + vals.Add(val) + case <-ctx.Done(): + return false + case ch := <-snapshot: + ch <- vals.Snapshot() + case <-s.slow.C: + s.logf("subscriber for %T is slow (%v elapsed)", t, time.Since(start)) + s.slow.Reset(slowSubscriberTimeout) + } + } +} + +// Events returns a channel on which the subscriber's events are +// delivered. +func (s *Subscriber[T]) Events() <-chan T { + return s.read +} + +// Done returns a channel that is closed when the subscriber is +// closed. +func (s *Subscriber[T]) Done() <-chan struct{} { + return s.stop.Done() +} + +// Close closes the Subscriber, indicating the caller no longer wishes +// to receive this event type. After Close, receives on +// [Subscriber.Events] block for ever. +// +// If the Bus from which the Subscriber was created is closed, +// the Subscriber is implicitly closed and does not need to be closed +// separately. +func (s *Subscriber[T]) Close() { + s.stop.Stop() // unblock receivers + s.unregister() +} + +// A SubscriberFunc delivers one type of event from a [Client]. +// Events are forwarded synchronously to a function provided at construction. +type SubscriberFunc[T any] struct { + stop stopFlag + read func(T) + unregister func() + logf logger.Logf + slow *time.Timer // used to detect slow subscriber service +} + +func newSubscriberFunc[T any](r *subscribeState, f func(T), logf logger.Logf) *SubscriberFunc[T] { + slow := time.NewTimer(0) + slow.Stop() // reset in dispatch + return &SubscriberFunc[T]{ + read: f, + unregister: func() { r.deleteSubscriber(reflect.TypeFor[T]()) }, + logf: logf, + slow: slow, + } +} + +// Close closes the SubscriberFunc, indicating the caller no longer wishes to +// receive this event type. After Close, no further events will be passed to +// the callback. +// +// If the [Bus] from which s was created is closed, s is implicitly closed and +// does not need to be closed separately. +func (s *SubscriberFunc[T]) Close() { s.stop.Stop(); s.unregister() } + +// subscribeType implements part of the subscriber interface. +func (s *SubscriberFunc[T]) subscribeType() reflect.Type { return reflect.TypeFor[T]() } + +// dispatch implements part of the subscriber interface. +func (s *SubscriberFunc[T]) dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool { + t := vals.Peek().Event.(T) + callDone := make(chan struct{}) + go s.runCallback(t, callDone) + + start := time.Now() + s.slow.Reset(slowSubscriberTimeout) + defer s.slow.Stop() + + // Keep the cases in this select in sync with subscribeState.pump + // above. The only difference should be that this select + // delivers a value by calling s.read. + for { + select { + case <-callDone: + vals.Drop() + return true + case val := <-acceptCh(): + vals.Add(val) + case <-ctx.Done(): + // Wait for the callback to be complete, but not forever. + s.slow.Reset(5 * slowSubscriberTimeout) + select { + case <-s.slow.C: + s.logf("giving up on subscriber for %T after %v at close", t, time.Since(start)) + if cibuild.On() { + all := make([]byte, 2<<20) + n := runtime.Stack(all, true) + s.logf("goroutine stacks:\n%s", all[:n]) + } + case <-callDone: + } + return false + case ch := <-snapshot: + ch <- vals.Snapshot() + case <-s.slow.C: + s.logf("subscriber for %T is slow (%v elapsed)", t, time.Since(start)) + s.slow.Reset(slowSubscriberTimeout) + } + } +} + +// runCallback invokes the callback on v and closes ch when it returns. +// This should be run in a goroutine. +func (s *SubscriberFunc[T]) runCallback(v T, ch chan struct{}) { + defer close(ch) + s.read(v) +} diff --git a/util/execqueue/execqueue.go b/util/execqueue/execqueue.go index 889cea255..2ea0c1f2f 100644 --- a/util/execqueue/execqueue.go +++ b/util/execqueue/execqueue.go @@ -7,11 +7,14 @@ package execqueue import ( "context" "errors" - "sync" + + "tailscale.com/syncs" ) type ExecQueue struct { - mu sync.Mutex + mu syncs.Mutex + ctx context.Context // context.Background + closed on Shutdown + cancel context.CancelFunc // closes ctx closed bool inFlight bool // whether a goroutine is running q.run doneWaiter chan struct{} // non-nil if waiter is waiting, then closed @@ -24,6 +27,7 @@ func (q *ExecQueue) Add(f func()) { if q.closed { return } + q.initCtxLocked() if q.inFlight { q.queue = append(q.queue, f) } else { @@ -79,18 +83,32 @@ func (q *ExecQueue) Shutdown() { q.mu.Lock() defer q.mu.Unlock() q.closed = true + if q.cancel != nil { + q.cancel() + } +} + +func (q *ExecQueue) initCtxLocked() { + if q.ctx == nil { + q.ctx, q.cancel = context.WithCancel(context.Background()) + } } -// Wait waits for the queue to be empty. +// Wait waits for the queue to be empty or shut down. func (q *ExecQueue) Wait(ctx context.Context) error { q.mu.Lock() + q.initCtxLocked() waitCh := q.doneWaiter if q.inFlight && waitCh == nil { waitCh = make(chan struct{}) q.doneWaiter = waitCh } + closed := q.closed q.mu.Unlock() + if closed { + return errors.New("execqueue shut down") + } if waitCh == nil { return nil } @@ -98,6 +116,8 @@ func (q *ExecQueue) Wait(ctx context.Context) error { select { case <-waitCh: return nil + case <-q.ctx.Done(): + return errors.New("execqueue shut down") case <-ctx.Done(): return ctx.Err() } diff --git a/util/expvarx/expvarx.go b/util/expvarx/expvarx.go index 762f65d06..bcdc4a91a 100644 --- a/util/expvarx/expvarx.go +++ b/util/expvarx/expvarx.go @@ -7,9 +7,9 @@ package expvarx import ( "encoding/json" "expvar" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/types/lazy" ) @@ -20,7 +20,7 @@ type SafeFunc struct { limit time.Duration onSlow func(time.Duration, any) - mu sync.Mutex + mu syncs.Mutex inflight *lazy.SyncValue[any] } diff --git a/util/expvarx/expvarx_test.go b/util/expvarx/expvarx_test.go index 74ec152f4..9ed2e8f20 100644 --- a/util/expvarx/expvarx_test.go +++ b/util/expvarx/expvarx_test.go @@ -9,6 +9,7 @@ import ( "sync" "sync/atomic" "testing" + "testing/synctest" "time" ) @@ -52,18 +53,21 @@ func ExampleNewSafeFunc() { } func TestSafeFuncHappyPath(t *testing.T) { - var count int - f := NewSafeFunc(expvar.Func(func() any { - count++ - return count - }), time.Millisecond, nil) - - if got, want := f.Value(), 1; got != want { - t.Errorf("got %v, want %v", got, want) - } - if got, want := f.Value(), 2; got != want { - t.Errorf("got %v, want %v", got, want) - } + synctest.Test(t, func(t *testing.T) { + var count int + f := NewSafeFunc(expvar.Func(func() any { + count++ + return count + }), time.Second, nil) + + if got, want := f.Value(), 1; got != want { + t.Errorf("got %v, want %v", got, want) + } + time.Sleep(5 * time.Second) // (fake time in synctest) + if got, want := f.Value(), 2; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) } func TestSafeFuncSlow(t *testing.T) { diff --git a/util/fastuuid/fastuuid.go b/util/fastuuid/fastuuid.go deleted file mode 100644 index 4b115ea4e..000000000 --- a/util/fastuuid/fastuuid.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package fastuuid implements a UUID construction using an in process CSPRNG. -package fastuuid - -import ( - crand "crypto/rand" - "encoding/binary" - "io" - "math/rand/v2" - "sync" - - "github.com/google/uuid" -) - -// NewUUID returns a new UUID using a pool of generators, good for highly -// concurrent use. -func NewUUID() uuid.UUID { - g := pool.Get().(*generator) - defer pool.Put(g) - return g.newUUID() -} - -var pool = sync.Pool{ - New: func() any { - return newGenerator() - }, -} - -type generator struct { - rng rand.ChaCha8 -} - -func seed() [32]byte { - var r [32]byte - if _, err := io.ReadFull(crand.Reader, r[:]); err != nil { - panic(err) - } - return r -} - -func newGenerator() *generator { - return &generator{ - rng: *rand.NewChaCha8(seed()), - } -} - -func (g *generator) newUUID() uuid.UUID { - var u uuid.UUID - binary.NativeEndian.PutUint64(u[:8], g.rng.Uint64()) - binary.NativeEndian.PutUint64(u[8:], g.rng.Uint64()) - u[6] = (u[6] & 0x0f) | 0x40 // Version 4 - u[8] = (u[8] & 0x3f) | 0x80 // Variant 10 - return u -} diff --git a/util/fastuuid/fastuuid_test.go b/util/fastuuid/fastuuid_test.go deleted file mode 100644 index f0d993904..000000000 --- a/util/fastuuid/fastuuid_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package fastuuid - -import ( - "testing" - - "github.com/google/uuid" -) - -func TestNewUUID(t *testing.T) { - g := pool.Get().(*generator) - defer pool.Put(g) - u := g.newUUID() - if u[6] != (u[6]&0x0f)|0x40 { - t.Errorf("version bits are incorrect") - } - if u[8] != (u[8]&0x3f)|0x80 { - t.Errorf("variant bits are incorrect") - } -} - -func BenchmarkBasic(b *testing.B) { - b.Run("NewUUID", func(b *testing.B) { - for range b.N { - NewUUID() - } - }) - - b.Run("uuid.New-unpooled", func(b *testing.B) { - uuid.DisableRandPool() - for range b.N { - uuid.New() - } - }) - - b.Run("uuid.New-pooled", func(b *testing.B) { - uuid.EnableRandPool() - for range b.N { - uuid.New() - } - }) -} - -func BenchmarkParallel(b *testing.B) { - b.Run("NewUUID", func(b *testing.B) { - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - NewUUID() - } - }) - }) - - b.Run("uuid.New-unpooled", func(b *testing.B) { - uuid.DisableRandPool() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - uuid.New() - } - }) - }) - - b.Run("uuid.New-pooled", func(b *testing.B) { - uuid.EnableRandPool() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - uuid.New() - } - }) - }) -} diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index 9758b0758..d40cbecb1 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// The goroutines package contains utilities for getting active goroutines. +// The goroutines package contains utilities for tracking and getting active goroutines. package goroutines import ( diff --git a/util/goroutines/tracker.go b/util/goroutines/tracker.go new file mode 100644 index 000000000..c2a0cb8c3 --- /dev/null +++ b/util/goroutines/tracker.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package goroutines + +import ( + "sync/atomic" + + "tailscale.com/syncs" + "tailscale.com/util/set" +) + +// Tracker tracks a set of goroutines. +type Tracker struct { + started atomic.Int64 // counter + running atomic.Int64 // gauge + + mu syncs.Mutex + onDone set.HandleSet[func()] +} + +func (t *Tracker) Go(f func()) { + t.started.Add(1) + t.running.Add(1) + go t.goAndDecr(f) +} + +func (t *Tracker) goAndDecr(f func()) { + defer t.decr() + f() +} + +func (t *Tracker) decr() { + t.running.Add(-1) + + t.mu.Lock() + defer t.mu.Unlock() + for _, f := range t.onDone { + go f() + } +} + +// AddDoneCallback adds a callback to be called in a new goroutine +// whenever a goroutine managed by t (excluding ones from this method) +// finishes. It returns a function to remove the callback. +func (t *Tracker) AddDoneCallback(f func()) (remove func()) { + t.mu.Lock() + defer t.mu.Unlock() + if t.onDone == nil { + t.onDone = set.HandleSet[func()]{} + } + h := t.onDone.Add(f) + return func() { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.onDone, h) + } +} + +func (t *Tracker) RunningGoroutines() int64 { + return t.running.Load() +} + +func (t *Tracker) StartedGoroutines() int64 { + return t.started.Load() +} diff --git a/util/jsonutil/types.go b/util/jsonutil/types.go deleted file mode 100644 index 057473249..000000000 --- a/util/jsonutil/types.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsonutil - -// Bytes is a byte slice in a json-encoded struct. -// encoding/json assumes that []byte fields are hex-encoded. -// Bytes are not hex-encoded; they are treated the same as strings. -// This can avoid unnecessary allocations due to a round trip through strings. -type Bytes []byte - -func (b *Bytes) UnmarshalText(text []byte) error { - // Copy the contexts of text. - *b = append(*b, text...) - return nil -} diff --git a/util/jsonutil/unmarshal.go b/util/jsonutil/unmarshal.go deleted file mode 100644 index b1eb4ea87..000000000 --- a/util/jsonutil/unmarshal.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package jsonutil provides utilities to improve JSON performance. -// It includes an Unmarshal wrapper that amortizes allocated garbage over subsequent runs -// and a Bytes type to reduce allocations when unmarshalling a non-hex-encoded string into a []byte. -package jsonutil - -import ( - "bytes" - "encoding/json" - "sync" -) - -// decoder is a re-usable json decoder. -type decoder struct { - dec *json.Decoder - r *bytes.Reader -} - -var readerPool = sync.Pool{ - New: func() any { - return bytes.NewReader(nil) - }, -} - -var decoderPool = sync.Pool{ - New: func() any { - var d decoder - d.r = readerPool.Get().(*bytes.Reader) - d.dec = json.NewDecoder(d.r) - return &d - }, -} - -// Unmarshal is similar to encoding/json.Unmarshal. -// There are three major differences: -// -// On error, encoding/json.Unmarshal zeros v. -// This Unmarshal may leave partial data in v. -// Always check the error before using v! -// (Future improvements may remove this bug.) -// -// The errors they return don't always match perfectly. -// If you do error matching more precise than err != nil, -// don't use this Unmarshal. -// -// This Unmarshal allocates considerably less memory. -func Unmarshal(b []byte, v any) error { - d := decoderPool.Get().(*decoder) - d.r.Reset(b) - off := d.dec.InputOffset() - err := d.dec.Decode(v) - d.r.Reset(nil) // don't keep a reference to b - // In case of error, report the offset in this byte slice, - // instead of in the totality of all bytes this decoder has processed. - // It is not possible to make all errors match json.Unmarshal exactly, - // but we can at least try. - switch jsonerr := err.(type) { - case *json.SyntaxError: - jsonerr.Offset -= off - case *json.UnmarshalTypeError: - jsonerr.Offset -= off - case nil: - // json.Unmarshal fails if there's any extra junk in the input. - // json.Decoder does not; see https://github.com/golang/go/issues/36225. - // We need to check for anything left over in the buffer. - if d.dec.More() { - // TODO: Provide a better error message. - // Unfortunately, we can't set the msg field. - // The offset doesn't perfectly match json: - // Ours is at the end of the valid data, - // and theirs is at the beginning of the extra data after whitespace. - // Close enough, though. - err = &json.SyntaxError{Offset: d.dec.InputOffset() - off} - - // TODO: zero v. This is hard; see encoding/json.indirect. - } - } - if err == nil { - decoderPool.Put(d) - } else { - // There might be junk left in the decoder's buffer. - // There's no way to flush it, no Reset method. - // Abandoned the decoder but reuse the reader. - readerPool.Put(d.r) - } - return err -} diff --git a/util/jsonutil/unmarshal_test.go b/util/jsonutil/unmarshal_test.go deleted file mode 100644 index 32f8402f0..000000000 --- a/util/jsonutil/unmarshal_test.go +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package jsonutil - -import ( - "encoding/json" - "reflect" - "testing" -) - -func TestCompareToStd(t *testing.T) { - tests := []string{ - `{}`, - `{"a": 1}`, - `{]`, - `"abc"`, - `5`, - `{"a": 1} `, - `{"a": 1} {}`, - `{} bad data`, - `{"a": 1} "hello"`, - `[]`, - ` {"x": {"t": [3,4,5]}}`, - } - - for _, test := range tests { - b := []byte(test) - var ourV, stdV any - ourErr := Unmarshal(b, &ourV) - stdErr := json.Unmarshal(b, &stdV) - if (ourErr == nil) != (stdErr == nil) { - t.Errorf("Unmarshal(%q): our err = %#[2]v (%[2]T), std err = %#[3]v (%[3]T)", test, ourErr, stdErr) - } - // if !reflect.DeepEqual(ourErr, stdErr) { - // t.Logf("Unmarshal(%q): our err = %#[2]v (%[2]T), std err = %#[3]v (%[3]T)", test, ourErr, stdErr) - // } - if ourErr != nil { - // TODO: if we zero ourV on error, remove this continue. - continue - } - if !reflect.DeepEqual(ourV, stdV) { - t.Errorf("Unmarshal(%q): our val = %v, std val = %v", test, ourV, stdV) - } - } -} - -func BenchmarkUnmarshal(b *testing.B) { - var m any - j := []byte("5") - b.ReportAllocs() - for range b.N { - Unmarshal(j, &m) - } -} - -func BenchmarkStdUnmarshal(b *testing.B) { - var m any - j := []byte("5") - b.ReportAllocs() - for range b.N { - json.Unmarshal(j, &m) - } -} diff --git a/util/limiter/limiter.go b/util/limiter/limiter.go index 5af5f7bd1..b86efdf29 100644 --- a/util/limiter/limiter.go +++ b/util/limiter/limiter.go @@ -8,9 +8,9 @@ import ( "fmt" "html" "io" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/util/lru" ) @@ -75,7 +75,7 @@ type Limiter[K comparable] struct { // perpetually in debt and cannot proceed at all. Overdraft int64 - mu sync.Mutex + mu syncs.Mutex cache *lru.Cache[K, *bucket] } @@ -94,59 +94,59 @@ type bucket struct { // Allow charges the key one token (up to the overdraft limit), and // reports whether the key can perform an action. -func (l *Limiter[K]) Allow(key K) bool { - return l.allow(key, time.Now()) +func (lm *Limiter[K]) Allow(key K) bool { + return lm.allow(key, time.Now()) } -func (l *Limiter[K]) allow(key K, now time.Time) bool { - l.mu.Lock() - defer l.mu.Unlock() - return l.allowBucketLocked(l.getBucketLocked(key, now), now) +func (lm *Limiter[K]) allow(key K, now time.Time) bool { + lm.mu.Lock() + defer lm.mu.Unlock() + return lm.allowBucketLocked(lm.getBucketLocked(key, now), now) } -func (l *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket { - if l.cache == nil { - l.cache = &lru.Cache[K, *bucket]{MaxEntries: l.Size} - } else if b := l.cache.Get(key); b != nil { +func (lm *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket { + if lm.cache == nil { + lm.cache = &lru.Cache[K, *bucket]{MaxEntries: lm.Size} + } else if b := lm.cache.Get(key); b != nil { return b } b := &bucket{ - cur: l.Max, - lastUpdate: now.Truncate(l.RefillInterval), + cur: lm.Max, + lastUpdate: now.Truncate(lm.RefillInterval), } - l.cache.Set(key, b) + lm.cache.Set(key, b) return b } -func (l *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool { +func (lm *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool { // Only update the bucket quota if needed to process request. if b.cur <= 0 { - l.updateBucketLocked(b, now) + lm.updateBucketLocked(b, now) } ret := b.cur > 0 - if b.cur > -l.Overdraft { + if b.cur > -lm.Overdraft { b.cur-- } return ret } -func (l *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) { - now = now.Truncate(l.RefillInterval) +func (lm *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) { + now = now.Truncate(lm.RefillInterval) if now.Before(b.lastUpdate) { return } timeDelta := max(now.Sub(b.lastUpdate), 0) - tokenDelta := int64(timeDelta / l.RefillInterval) - b.cur = min(b.cur+tokenDelta, l.Max) + tokenDelta := int64(timeDelta / lm.RefillInterval) + b.cur = min(b.cur+tokenDelta, lm.Max) b.lastUpdate = now } // peekForTest returns the number of tokens for key, also reporting // whether key was present. -func (l *Limiter[K]) tokensForTest(key K) (int64, bool) { - l.mu.Lock() - defer l.mu.Unlock() - if b, ok := l.cache.PeekOk(key); ok { +func (lm *Limiter[K]) tokensForTest(key K) (int64, bool) { + lm.mu.Lock() + defer lm.mu.Unlock() + if b, ok := lm.cache.PeekOk(key); ok { return b.cur, true } return 0, false @@ -159,12 +159,12 @@ func (l *Limiter[K]) tokensForTest(key K) (int64, bool) { // DumpHTML blocks other callers of the limiter while it collects the // state for dumping. It should not be called on large limiters // involved in hot codepaths. -func (l *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) { - l.dumpHTML(w, onlyLimited, time.Now()) +func (lm *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) { + lm.dumpHTML(w, onlyLimited, time.Now()) } -func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { - dump := l.collectDump(now) +func (lm *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { + dump := lm.collectDump(now) io.WriteString(w, "") for _, line := range dump { if onlyLimited && line.Tokens > 0 { @@ -183,13 +183,13 @@ func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) { } // collectDump grabs a copy of the limiter state needed by DumpHTML. -func (l *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] { - l.mu.Lock() - defer l.mu.Unlock() +func (lm *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] { + lm.mu.Lock() + defer lm.mu.Unlock() - ret := make([]dumpEntry[K], 0, l.cache.Len()) - l.cache.ForEach(func(k K, v *bucket) { - l.updateBucketLocked(v, now) // so stats are accurate + ret := make([]dumpEntry[K], 0, lm.cache.Len()) + lm.cache.ForEach(func(k K, v *bucket) { + lm.updateBucketLocked(v, now) // so stats are accurate ret = append(ret, dumpEntry[K]{k, v.cur}) }) return ret diff --git a/util/limiter/limiter_test.go b/util/limiter/limiter_test.go index 1f466d882..77b1d562b 100644 --- a/util/limiter/limiter_test.go +++ b/util/limiter/limiter_test.go @@ -16,7 +16,7 @@ const testRefillInterval = time.Second func TestLimiter(t *testing.T) { // 1qps, burst of 10, 2 keys tracked - l := &Limiter[string]{ + limiter := &Limiter[string]{ Size: 2, Max: 10, RefillInterval: testRefillInterval, @@ -24,48 +24,48 @@ func TestLimiter(t *testing.T) { // Consume entire burst now := time.Now().Truncate(testRefillInterval) - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) - allowed(t, l, "bar", 10, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", 0) + allowed(t, limiter, "bar", 10, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", 0) // Refill 1 token for both foo and bar now = now.Add(time.Second + time.Millisecond) - allowed(t, l, "foo", 1, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "foo", 1, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) - allowed(t, l, "bar", 1, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", 0) + allowed(t, limiter, "bar", 1, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", 0) // Refill 2 tokens for foo and bar now = now.Add(2*time.Second + time.Millisecond) - allowed(t, l, "foo", 2, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "foo", 2, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) - allowed(t, l, "bar", 2, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", 0) + allowed(t, limiter, "bar", 2, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", 0) // qux can burst 10, evicts foo so it can immediately burst 10 again too - allowed(t, l, "qux", 10, now) - denied(t, l, "qux", 1, now) - notInLimiter(t, l, "foo") - denied(t, l, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled - - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", 0) + allowed(t, limiter, "qux", 10, now) + denied(t, limiter, "qux", 1, now) + notInLimiter(t, limiter, "foo") + denied(t, limiter, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled + + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", 0) } func TestLimiterOverdraft(t *testing.T) { // 1qps, burst of 10, overdraft of 2, 2 keys tracked - l := &Limiter[string]{ + limiter := &Limiter[string]{ Size: 2, Max: 10, Overdraft: 2, @@ -74,51 +74,51 @@ func TestLimiterOverdraft(t *testing.T) { // Consume entire burst, go 1 into debt now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond) - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -1) + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -1) - allowed(t, l, "bar", 10, now) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", -1) + allowed(t, limiter, "bar", 10, now) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", -1) // Refill 1 token for both foo and bar. // Still denied, still in debt. now = now.Add(time.Second) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -1) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", -1) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -1) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", -1) // Refill 2 tokens for foo and bar (1 available after debt), try // to consume 4. Overdraft is capped to 2. now = now.Add(2 * time.Second) - allowed(t, l, "foo", 1, now) - denied(t, l, "foo", 3, now) - hasTokens(t, l, "foo", -2) + allowed(t, limiter, "foo", 1, now) + denied(t, limiter, "foo", 3, now) + hasTokens(t, limiter, "foo", -2) - allowed(t, l, "bar", 1, now) - denied(t, l, "bar", 3, now) - hasTokens(t, l, "bar", -2) + allowed(t, limiter, "bar", 1, now) + denied(t, limiter, "bar", 3, now) + hasTokens(t, limiter, "bar", -2) // Refill 1, not enough to allow. now = now.Add(time.Second) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -2) - denied(t, l, "bar", 1, now) - hasTokens(t, l, "bar", -2) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -2) + denied(t, limiter, "bar", 1, now) + hasTokens(t, limiter, "bar", -2) // qux evicts foo, foo can immediately burst 10 again. - allowed(t, l, "qux", 1, now) - hasTokens(t, l, "qux", 9) - notInLimiter(t, l, "foo") - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 1, now) - hasTokens(t, l, "foo", -1) + allowed(t, limiter, "qux", 1, now) + hasTokens(t, limiter, "qux", 9) + notInLimiter(t, limiter, "foo") + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 1, now) + hasTokens(t, limiter, "foo", -1) } func TestDumpHTML(t *testing.T) { - l := &Limiter[string]{ + limiter := &Limiter[string]{ Size: 3, Max: 10, Overdraft: 10, @@ -126,13 +126,13 @@ func TestDumpHTML(t *testing.T) { } now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond) - allowed(t, l, "foo", 10, now) - denied(t, l, "foo", 2, now) - allowed(t, l, "bar", 4, now) - allowed(t, l, "qux", 1, now) + allowed(t, limiter, "foo", 10, now) + denied(t, limiter, "foo", 2, now) + allowed(t, limiter, "bar", 4, now) + allowed(t, limiter, "qux", 1, now) var out bytes.Buffer - l.DumpHTML(&out, false) + limiter.DumpHTML(&out, false) want := strings.Join([]string{ "
        KeyTokens
        ", "", @@ -146,7 +146,7 @@ func TestDumpHTML(t *testing.T) { } out.Reset() - l.DumpHTML(&out, true) + limiter.DumpHTML(&out, true) want = strings.Join([]string{ "
        KeyTokens
        ", "", @@ -161,7 +161,7 @@ func TestDumpHTML(t *testing.T) { // organically. now = now.Add(3 * time.Second) out.Reset() - l.dumpHTML(&out, false, now) + limiter.dumpHTML(&out, false, now) want = strings.Join([]string{ "
        KeyTokens
        ", "", @@ -175,29 +175,29 @@ func TestDumpHTML(t *testing.T) { } } -func allowed(t *testing.T, l *Limiter[string], key string, count int, now time.Time) { +func allowed(t *testing.T, limiter *Limiter[string], key string, count int, now time.Time) { t.Helper() for i := range count { - if !l.allow(key, now) { - toks, ok := l.tokensForTest(key) + if !limiter.allow(key, now) { + toks, ok := limiter.tokensForTest(key) t.Errorf("after %d times: allow(%q, %q) = false, want true (%d tokens available, in cache = %v)", i, key, now, toks, ok) } } } -func denied(t *testing.T, l *Limiter[string], key string, count int, now time.Time) { +func denied(t *testing.T, limiter *Limiter[string], key string, count int, now time.Time) { t.Helper() for i := range count { - if l.allow(key, now) { - toks, ok := l.tokensForTest(key) + if limiter.allow(key, now) { + toks, ok := limiter.tokensForTest(key) t.Errorf("after %d times: allow(%q, %q) = true, want false (%d tokens available, in cache = %v)", i, key, now, toks, ok) } } } -func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) { +func hasTokens(t *testing.T, limiter *Limiter[string], key string, want int64) { t.Helper() - got, ok := l.tokensForTest(key) + got, ok := limiter.tokensForTest(key) if !ok { t.Errorf("key %q missing from limiter", key) } else if got != want { @@ -205,9 +205,9 @@ func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) { } } -func notInLimiter(t *testing.T, l *Limiter[string], key string) { +func notInLimiter(t *testing.T, limiter *Limiter[string], key string) { t.Helper() - if tokens, ok := l.tokensForTest(key); ok { + if tokens, ok := limiter.tokensForTest(key); ok { t.Errorf("key %q unexpectedly tracked by limiter, with %d tokens", key, tokens) } } diff --git a/util/lineiter/lineiter.go b/util/lineiter/lineiter.go new file mode 100644 index 000000000..5cb1eeef3 --- /dev/null +++ b/util/lineiter/lineiter.go @@ -0,0 +1,72 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lineiter iterates over lines in things. +package lineiter + +import ( + "bufio" + "bytes" + "io" + "iter" + "os" + + "tailscale.com/types/result" +) + +// File returns an iterator that reads lines from the named file. +// +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func File(name string) iter.Seq[result.Of[[]byte]] { + f, err := os.Open(name) + return reader(f, f, err) +} + +// Bytes returns an iterator over the lines in bs. +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func Bytes(bs []byte) iter.Seq[[]byte] { + return func(yield func([]byte) bool) { + for len(bs) > 0 { + i := bytes.IndexByte(bs, '\n') + if i < 0 { + yield(bs) + return + } + if !yield(bs[:i]) { + return + } + bs = bs[i+1:] + } + } +} + +// Reader returns an iterator over the lines in r. +// +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func Reader(r io.Reader) iter.Seq[result.Of[[]byte]] { + return reader(r, nil, nil) +} + +func reader(r io.Reader, c io.Closer, err error) iter.Seq[result.Of[[]byte]] { + return func(yield func(result.Of[[]byte]) bool) { + if err != nil { + yield(result.Error[[]byte](err)) + return + } + if c != nil { + defer c.Close() + } + bs := bufio.NewScanner(r) + for bs.Scan() { + if !yield(result.Value(bs.Bytes())) { + return + } + } + if err := bs.Err(); err != nil { + yield(result.Error[[]byte](err)) + } + } +} diff --git a/util/lineiter/lineiter_test.go b/util/lineiter/lineiter_test.go new file mode 100644 index 000000000..3373d5fe7 --- /dev/null +++ b/util/lineiter/lineiter_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lineiter + +import ( + "slices" + "strings" + "testing" +) + +func TestBytesLines(t *testing.T) { + var got []string + for line := range Bytes([]byte("foo\n\nbar\nbaz")) { + got = append(got, string(line)) + } + want := []string{"foo", "", "bar", "baz"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestReader(t *testing.T) { + var got []string + for line := range Reader(strings.NewReader("foo\n\nbar\nbaz")) { + got = append(got, string(line.MustValue())) + } + want := []string{"foo", "", "bar", "baz"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} diff --git a/util/linuxfw/detector.go b/util/linuxfw/detector.go index f3ee4aa0b..149e0c960 100644 --- a/util/linuxfw/detector.go +++ b/util/linuxfw/detector.go @@ -10,6 +10,8 @@ import ( "os/exec" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/hostinfo" "tailscale.com/types/logger" "tailscale.com/version/distro" @@ -23,6 +25,11 @@ func detectFirewallMode(logf logger.Logf, prefHint string) FirewallMode { hostinfo.SetFirewallMode("nft-gokrazy") return FirewallModeNfTables } + if distro.Get() == distro.JetKVM { + // JetKVM doesn't have iptables. + hostinfo.SetFirewallMode("nft-jetkvm") + return FirewallModeNfTables + } mode := envknob.String("TS_DEBUG_FIREWALL_MODE") // If the envknob isn't set, fall back to the pref suggested by c2n or @@ -37,10 +44,12 @@ func detectFirewallMode(logf logger.Logf, prefHint string) FirewallMode { var det linuxFWDetector if mode == "" { // We have no preference, so check if `iptables` is even available. - _, err := det.iptDetect() - if err != nil && errors.Is(err, exec.ErrNotFound) { - logf("iptables not found: %v; falling back to nftables", err) - mode = "nftables" + if buildfeatures.HasIPTables { + _, err := det.iptDetect() + if err != nil && errors.Is(err, exec.ErrNotFound) { + logf("iptables not found: %v; falling back to nftables", err) + mode = "nftables" + } } } @@ -54,11 +63,16 @@ func detectFirewallMode(logf logger.Logf, prefHint string) FirewallMode { return FirewallModeNfTables case "iptables": hostinfo.SetFirewallMode("ipt-forced") - default: + return FirewallModeIPTables + } + if buildfeatures.HasIPTables { logf("default choosing iptables") hostinfo.SetFirewallMode("ipt-default") + return FirewallModeIPTables } - return FirewallModeIPTables + logf("default choosing nftables") + hostinfo.SetFirewallMode("nft-default") + return FirewallModeNfTables } // tableDetector abstracts helpers to detect the firewall mode. @@ -71,23 +85,37 @@ type tableDetector interface { type linuxFWDetector struct{} // iptDetect returns the number of iptables rules in the current namespace. -func (l linuxFWDetector) iptDetect() (int, error) { +func (ld linuxFWDetector) iptDetect() (int, error) { return detectIptables() } +var hookDetectNetfilter feature.Hook[func() (int, error)] + +// ErrUnsupported is the error returned from all functions on non-Linux +// platforms. +var ErrUnsupported = errors.New("linuxfw:unsupported") + // nftDetect returns the number of nftables rules in the current namespace. -func (l linuxFWDetector) nftDetect() (int, error) { - return detectNetfilter() +func (ld linuxFWDetector) nftDetect() (int, error) { + if f, ok := hookDetectNetfilter.GetOk(); ok { + return f() + } + return 0, ErrUnsupported } // pickFirewallModeFromInstalledRules returns the firewall mode to use based on // the environment and the system's capabilities. func pickFirewallModeFromInstalledRules(logf logger.Logf, det tableDetector) FirewallMode { + if !buildfeatures.HasIPTables { + hostinfo.SetFirewallMode("nft-noipt") + return FirewallModeNfTables + } if distro.Get() == distro.Gokrazy { // Reduce startup logging on gokrazy. There's no way to do iptables on // gokrazy anyway. return FirewallModeNfTables } + iptAva, nftAva := true, true iptRuleCount, err := det.iptDetect() if err != nil { diff --git a/util/linuxfw/fake.go b/util/linuxfw/fake.go index 63a728d55..d01849a2e 100644 --- a/util/linuxfw/fake.go +++ b/util/linuxfw/fake.go @@ -128,7 +128,7 @@ func (n *fakeIPTables) DeleteChain(table, chain string) error { } } -func NewFakeIPTablesRunner() *iptablesRunner { +func NewFakeIPTablesRunner() NetfilterRunner { ipt4 := newFakeIPTables() v6Available := false var ipt6 iptablesInterface diff --git a/util/linuxfw/fake_netfilter.go b/util/linuxfw/fake_netfilter.go new file mode 100644 index 000000000..a998ed765 --- /dev/null +++ b/util/linuxfw/fake_netfilter.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "net/netip" + + "tailscale.com/types/logger" +) + +// FakeNetfilterRunner is a fake netfilter runner for tests. +type FakeNetfilterRunner struct { + // services is a map that tracks the firewall rules added/deleted via + // EnsureDNATRuleForSvc/DeleteDNATRuleForSvc. + services map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + } +} + +// NewFakeNetfilterRunner creates a new FakeNetfilterRunner. +func NewFakeNetfilterRunner() *FakeNetfilterRunner { + return &FakeNetfilterRunner{ + services: make(map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }), + } +} + +func (f *FakeNetfilterRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + f.services[svcName] = struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{origDst, dst} + return nil +} + +func (f *FakeNetfilterRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + delete(f.services, svcName) + return nil +} + +func (f *FakeNetfilterRunner) GetServiceState() map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr +} { + return f.services +} + +func (f *FakeNetfilterRunner) HasIPV6() bool { + return true +} + +func (f *FakeNetfilterRunner) HasIPV6Filter() bool { + return true +} + +func (f *FakeNetfilterRunner) HasIPV6NAT() bool { + return true +} + +func (f *FakeNetfilterRunner) AddBase(tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelBase() error { return nil } +func (f *FakeNetfilterRunner) AddChains() error { return nil } +func (f *FakeNetfilterRunner) DelChains() error { return nil } +func (f *FakeNetfilterRunner) AddHooks() error { return nil } +func (f *FakeNetfilterRunner) DelHooks(logf logger.Logf) error { return nil } +func (f *FakeNetfilterRunner) AddSNATRule() error { return nil } +func (f *FakeNetfilterRunner) DelSNATRule() error { return nil } +func (f *FakeNetfilterRunner) AddStatefulRule(tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelStatefulRule(tunname string) error { return nil } +func (f *FakeNetfilterRunner) AddLoopbackRule(addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DelLoopbackRule(addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) AddDNATRule(origDst, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { + return nil +} +func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error { return nil } +func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error { return nil } +func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { + return nil +} +func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error { + return nil +} +func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { + return nil +} diff --git a/util/linuxfw/iptables.go b/util/linuxfw/iptables.go index 234fa526c..76c5400be 100644 --- a/util/linuxfw/iptables.go +++ b/util/linuxfw/iptables.go @@ -1,21 +1,31 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// TODO(#8502): add support for more architectures -//go:build linux && (arm64 || amd64) +//go:build linux && !ts_omit_iptables package linuxfw import ( + "bytes" + "errors" "fmt" + "os" "os/exec" "strings" "unicode" + "github.com/coreos/go-iptables/iptables" "tailscale.com/types/logger" - "tailscale.com/util/multierr" + "tailscale.com/version/distro" ) +func init() { + isNotExistError = func(err error) bool { + var e *iptables.Error + return errors.As(err, &e) && e.IsNotExist() + } +} + // DebugNetfilter prints debug information about iptables rules to the // provided log function. func DebugIptables(logf logger.Logf) error { @@ -54,7 +64,7 @@ func detectIptables() (int, error) { default: return 0, FWModeNotSupportedError{ Mode: FirewallModeIPTables, - Err: fmt.Errorf("iptables command run fail: %w", multierr.New(err, ip6err)), + Err: fmt.Errorf("iptables command run fail: %w", errors.Join(err, ip6err)), } } @@ -71,3 +81,153 @@ func detectIptables() (int, error) { // return the count of non-default rules return count, nil } + +// newIPTablesRunner constructs a NetfilterRunner that programs iptables rules. +// If the underlying iptables library fails to initialize, that error is +// returned. The runner probes for IPv6 support once at initialization time and +// if not found, no IPv6 rules will be modified for the lifetime of the runner. +func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { + ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + if err != nil { + return nil, err + } + + supportsV6, supportsV6NAT, supportsV6Filter := false, false, false + v6err := CheckIPv6(logf) + ip6terr := checkIP6TablesExists() + var ipt6 *iptables.IPTables + switch { + case v6err != nil: + logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) + case ip6terr != nil: + logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr) + default: + supportsV6 = true + ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + if err != nil { + return nil, err + } + supportsV6Filter = checkSupportsV6Filter(ipt6, logf) + supportsV6NAT = checkSupportsV6NAT(ipt6, logf) + logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT) + } + return &iptablesRunner{ + ipt4: ipt4, + ipt6: ipt6, + v6Available: supportsV6, + v6NATAvailable: supportsV6NAT, + v6FilterAvailable: supportsV6Filter}, nil +} + +// checkSupportsV6Filter returns whether the system has a "filter" table in the +// IPv6 tables. Some container environments such as GitHub codespaces have +// limited local IPv6 support, and containers containing ip6tables, but do not +// have kernel support for IPv6 filtering. +// We will not set ip6tables rules in these instances. +func checkSupportsV6Filter(ipt *iptables.IPTables, logf logger.Logf) bool { + if ipt == nil { + return false + } + _, filterListErr := ipt.ListChains("filter") + if filterListErr == nil { + return true + } + logf("ip6tables filtering is not supported on this host: %v", filterListErr) + return false +} + +// checkSupportsV6NAT returns whether the system has a "nat" table in the +// IPv6 netfilter stack. +// +// The nat table was added after the initial release of ipv6 +// netfilter, so some older distros ship a kernel that can't NAT IPv6 +// traffic. +// ipt must be initialized for IPv6. +func checkSupportsV6NAT(ipt *iptables.IPTables, logf logger.Logf) bool { + if ipt == nil || ipt.Proto() != iptables.ProtocolIPv6 { + return false + } + _, natListErr := ipt.ListChains("nat") + if natListErr == nil { + return true + } + + // TODO (irbekrm): the following two checks were added before the check + // above that verifies that nat chains can be listed. It is a + // container-friendly check (see + // https://github.com/tailscale/tailscale/issues/11344), but also should + // be good enough on its own in other environments. If we never observe + // it falsely succeed, let's remove the other two checks. + + bs, err := os.ReadFile("/proc/net/ip6_tables_names") + if err != nil { + return false + } + if bytes.Contains(bs, []byte("nat\n")) { + logf("[unexpected] listing nat chains failed, but /proc/net/ip6_tables_name reports a nat table existing") + return true + } + if exec.Command("modprobe", "ip6table_nat").Run() == nil { + logf("[unexpected] listing nat chains failed, but modprobe ip6table_nat succeeded") + return true + } + return false +} + +func init() { + hookIPTablesCleanup.Set(ipTablesCleanUp) +} + +// ipTablesCleanUp removes all Tailscale added iptables rules. +// Any errors that occur are logged to the provided logf. +func ipTablesCleanUp(logf logger.Logf) { + switch distro.Get() { + case distro.Gokrazy, distro.JetKVM: + // These use nftables and don't have the "iptables" command. + // Avoid log spam on cleanup. (#12277) + return + } + err := clearRules(iptables.ProtocolIPv4, logf) + if err != nil { + logf("linuxfw: clear iptables: %v", err) + } + + err = clearRules(iptables.ProtocolIPv6, logf) + if err != nil { + logf("linuxfw: clear ip6tables: %v", err) + } +} + +// clearRules clears all the iptables rules created by Tailscale +// for the given protocol. If error occurs, it's logged but not returned. +func clearRules(proto iptables.Protocol, logf logger.Logf) error { + ipt, err := iptables.NewWithProtocol(proto) + if err != nil { + return err + } + + var errs []error + + if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { + errs = append(errs, err) + } + if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil { + errs = append(errs, err) + } + if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { + errs = append(errs, err) + } + + if err := delChain(ipt, "filter", "ts-input"); err != nil { + errs = append(errs, err) + } + if err := delChain(ipt, "filter", "ts-forward"); err != nil { + errs = append(errs, err) + } + + if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { + errs = append(errs, err) + } + + return errors.Join(errs...) +} diff --git a/util/linuxfw/iptables_disabled.go b/util/linuxfw/iptables_disabled.go new file mode 100644 index 000000000..538e33647 --- /dev/null +++ b/util/linuxfw/iptables_disabled.go @@ -0,0 +1,20 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && ts_omit_iptables + +package linuxfw + +import ( + "errors" + + "tailscale.com/types/logger" +) + +func detectIptables() (int, error) { + return 0, nil +} + +func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { + return nil, errors.New("iptables disabled in build") +} diff --git a/util/linuxfw/iptables_for_svcs.go b/util/linuxfw/iptables_for_svcs.go index 8e0f5d48d..2cd8716e4 100644 --- a/util/linuxfw/iptables_for_svcs.go +++ b/util/linuxfw/iptables_for_svcs.go @@ -13,6 +13,7 @@ import ( // This file contains functionality to insert portmapping rules for a 'service'. // These are currently only used by the Kubernetes operator proxies. // An iptables rule for such a service contains a comment with the service name. +// A 'service' corresponds to a VIPService as used by the Kubernetes operator. // EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received // on match port and NOT on the provided interface to target IP and target port. @@ -24,10 +25,10 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip if err != nil { return fmt.Errorf("error checking if rule exists: %w", err) } - if !exists { - return table.Append("nat", "PREROUTING", args...) + if exists { + return nil } - return nil + return table.Append("nat", "PREROUTING", args...) } // DeleteMapRuleForSvc constructs a prerouting rule as would be created by @@ -40,10 +41,41 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP if err != nil { return fmt.Errorf("error checking if rule exists: %w", err) } + if !exists { + return nil + } + return table.Delete("nat", "PREROUTING", args...) +} + +// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the +// VIPService IP address to a local address. This is used by the Kubernetes +// operator's network layer proxies to forward tailnet traffic for VIPServices +// to Kubernetes Services. +func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table := i.getIPTByAddr(dst) + args := argsForIngressRule(svcName, origDst, dst) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } if exists { - return table.Delete("nat", "PREROUTING", args...) + return nil } - return nil + return table.Append("nat", "PREROUTING", args...) +} + +// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. +func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table := i.getIPTByAddr(dst) + args := argsForIngressRule(svcName, origDst, dst) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if !exists { + return nil + } + return table.Delete("nat", "PREROUTING", args...) } // DeleteSvc constructs all possible rules that would have been created by @@ -72,8 +104,24 @@ func argsForPortMapRule(svc, excludeI string, targetIP netip.Addr, pm PortMap) [ } } +func argsForIngressRule(svcName string, origDst, targetIP netip.Addr) []string { + c := commentForIngressSvc(svcName, origDst, targetIP) + return []string{ + "--destination", origDst.String(), + "-m", "comment", "--comment", c, + "-j", "DNAT", + "--to-destination", targetIP.String(), + } +} + // commentForSvc generates a comment to be added to an iptables DNAT rule for a // service. This is for iptables debugging/readability purposes only. func commentForSvc(svc string, pm PortMap) string { return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort) } + +// commentForIngressSvc generates a comment to be added to an iptables DNAT rule for a +// service. This is for iptables debugging/readability purposes only. +func commentForIngressSvc(svc string, vip, clusterIP netip.Addr) string { + return fmt.Sprintf("svc: %s, %s -> %s", svc, vip.String(), clusterIP.String()) +} diff --git a/util/linuxfw/iptables_for_svcs_test.go b/util/linuxfw/iptables_for_svcs_test.go index 99b2f517f..0e56d70ba 100644 --- a/util/linuxfw/iptables_for_svcs_test.go +++ b/util/linuxfw/iptables_for_svcs_test.go @@ -10,6 +10,10 @@ import ( "testing" ) +func newFakeIPTablesRunner() *iptablesRunner { + return NewFakeIPTablesRunner().(*iptablesRunner) +} + func Test_iptablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { v4Addr := netip.MustParseAddr("10.0.0.4") v6Addr := netip.MustParseAddr("fd7a:115c:a1e0::701:b62a") @@ -45,7 +49,7 @@ func Test_iptablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() table := iptr.getIPTByAddr(tt.targetIP) for _, ruleset := range tt.precreateSvcRules { mustPrecreatePortMapRule(t, ruleset, table) @@ -103,7 +107,7 @@ func Test_iptablesRunner_DeletePortMapRuleForSvc(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() table := iptr.getIPTByAddr(tt.targetIP) for _, ruleset := range tt.precreateSvcRules { mustPrecreatePortMapRule(t, ruleset, table) @@ -127,7 +131,7 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) { v4Addr := netip.MustParseAddr("10.0.0.4") v6Addr := netip.MustParseAddr("fd7a:115c:a1e0::701:b62a") testPM := PortMap{Protocol: "tcp", MatchPort: 4003, TargetPort: 80} - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() // create two rules that will consitute svc1 s1R1 := argsForPortMapRule("svc1", "tailscale0", v4Addr, testPM) @@ -153,6 +157,135 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) { svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr) } +func Test_iptablesRunner_EnsureDNATRuleForSvc(t *testing.T) { + v4OrigDst := netip.MustParseAddr("10.0.0.1") + v4Target := netip.MustParseAddr("10.0.0.2") + v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target) + + tests := []struct { + name string + svcName string + origDst netip.Addr + targetIP netip.Addr + precreateSvcRules [][]string + }{ + { + name: "dnat_for_ipv4", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + }, + { + name: "dnat_for_ipv6", + svcName: "svc:test-2", + origDst: v6OrigDst, + targetIP: v6Target, + }, + { + name: "add_existing_rule", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v4Rule}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptr := newFakeIPTablesRunner() + table := iptr.getIPTByAddr(tt.targetIP) + for _, ruleset := range tt.precreateSvcRules { + mustPrecreateDNATRule(t, ruleset, table) + } + if err := iptr.EnsureDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil { + t.Errorf("[unexpected error] iptablesRunner.EnsureDNATRuleForSvc() = %v", err) + } + args := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if !exists { + t.Errorf("expected rule was not created") + } + }) + } +} + +func Test_iptablesRunner_DeleteDNATRuleForSvc(t *testing.T) { + v4OrigDst := netip.MustParseAddr("10.0.0.1") + v4Target := netip.MustParseAddr("10.0.0.2") + v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target) + v6Rule := argsForIngressRule("svc:test", v6OrigDst, v6Target) + + tests := []struct { + name string + svcName string + origDst netip.Addr + targetIP netip.Addr + precreateSvcRules [][]string + }{ + { + name: "multiple_rules_ipv4_deleted", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v4Rule, v6Rule}, + }, + { + name: "multiple_rules_ipv6_deleted", + svcName: "svc:test", + origDst: v6OrigDst, + targetIP: v6Target, + precreateSvcRules: [][]string{v4Rule, v6Rule}, + }, + { + name: "non-existent_rule_deleted", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v6Rule}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptr := newFakeIPTablesRunner() + table := iptr.getIPTByAddr(tt.targetIP) + for _, ruleset := range tt.precreateSvcRules { + mustPrecreateDNATRule(t, ruleset, table) + } + if err := iptr.DeleteDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil { + t.Errorf("iptablesRunner.DeleteDNATRuleForSvc() errored: %v ", err) + } + deletedRule := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP) + exists, err := table.Exists("nat", "PREROUTING", deletedRule...) + if err != nil { + t.Fatalf("error verifying that rule does not exist after deletion: %v", err) + } + if exists { + t.Errorf("DNAT rule exists after deletion") + } + }) + } +} + +func mustPrecreateDNATRule(t *testing.T, rules []string, table iptablesInterface) { + t.Helper() + exists, err := table.Exists("nat", "PREROUTING", rules...) + if err != nil { + t.Fatalf("error ensuring that nat PREROUTING table exists: %v", err) + } + if exists { + return + } + if err := table.Append("nat", "PREROUTING", rules...); err != nil { + t.Fatalf("error precreating DNAT rule: %v", err) + } +} + func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) { t.Helper() for dst, ruleset := range rules { diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index 9a6fc0224..4443a9071 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -6,31 +6,22 @@ package linuxfw import ( - "bytes" - "errors" "fmt" "log" "net/netip" - "os" "os/exec" "slices" "strconv" "strings" - "github.com/coreos/go-iptables/iptables" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" - "tailscale.com/util/multierr" - "tailscale.com/version/distro" ) // isNotExistError needs to be overridden in tests that rely on distinguishing // this error, because we don't have a good way how to create a new // iptables.Error of that type. -var isNotExistError = func(err error) bool { - var e *iptables.Error - return errors.As(err, &e) && e.IsNotExist() -} +var isNotExistError = func(err error) bool { return false } type iptablesInterface interface { // Adding this interface for testing purposes so we can mock out @@ -62,98 +53,6 @@ func checkIP6TablesExists() error { return nil } -// newIPTablesRunner constructs a NetfilterRunner that programs iptables rules. -// If the underlying iptables library fails to initialize, that error is -// returned. The runner probes for IPv6 support once at initialization time and -// if not found, no IPv6 rules will be modified for the lifetime of the runner. -func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { - ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) - if err != nil { - return nil, err - } - - supportsV6, supportsV6NAT, supportsV6Filter := false, false, false - v6err := CheckIPv6(logf) - ip6terr := checkIP6TablesExists() - var ipt6 *iptables.IPTables - switch { - case v6err != nil: - logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) - case ip6terr != nil: - logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr) - default: - supportsV6 = true - ipt6, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - return nil, err - } - supportsV6Filter = checkSupportsV6Filter(ipt6, logf) - supportsV6NAT = checkSupportsV6NAT(ipt6, logf) - logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT) - } - return &iptablesRunner{ - ipt4: ipt4, - ipt6: ipt6, - v6Available: supportsV6, - v6NATAvailable: supportsV6NAT, - v6FilterAvailable: supportsV6Filter}, nil -} - -// checkSupportsV6Filter returns whether the system has a "filter" table in the -// IPv6 tables. Some container environments such as GitHub codespaces have -// limited local IPv6 support, and containers containing ip6tables, but do not -// have kernel support for IPv6 filtering. -// We will not set ip6tables rules in these instances. -func checkSupportsV6Filter(ipt *iptables.IPTables, logf logger.Logf) bool { - if ipt == nil { - return false - } - _, filterListErr := ipt.ListChains("filter") - if filterListErr == nil { - return true - } - logf("ip6tables filtering is not supported on this host: %v", filterListErr) - return false -} - -// checkSupportsV6NAT returns whether the system has a "nat" table in the -// IPv6 netfilter stack. -// -// The nat table was added after the initial release of ipv6 -// netfilter, so some older distros ship a kernel that can't NAT IPv6 -// traffic. -// ipt must be initialized for IPv6. -func checkSupportsV6NAT(ipt *iptables.IPTables, logf logger.Logf) bool { - if ipt == nil || ipt.Proto() != iptables.ProtocolIPv6 { - return false - } - _, natListErr := ipt.ListChains("nat") - if natListErr == nil { - return true - } - - // TODO (irbekrm): the following two checks were added before the check - // above that verifies that nat chains can be listed. It is a - // container-friendly check (see - // https://github.com/tailscale/tailscale/issues/11344), but also should - // be good enough on its own in other environments. If we never observe - // it falsely succeed, let's remove the other two checks. - - bs, err := os.ReadFile("/proc/net/ip6_tables_names") - if err != nil { - return false - } - if bytes.Contains(bs, []byte("nat\n")) { - logf("[unexpected] listing nat chains failed, but /proc/net/ip6_tables_name reports a nat table existing") - return true - } - if exec.Command("modprobe", "ip6table_nat").Run() == nil { - logf("[unexpected] listing nat chains failed, but modprobe ip6table_nat succeeded") - return true - } - return false -} - // HasIPV6 reports true if the system supports IPv6. func (i *iptablesRunner) HasIPV6() bool { return i.v6Available @@ -347,11 +246,11 @@ func (i *iptablesRunner) addBase4(tunname string) error { // POSTROUTING. So instead, we match on the inbound interface in // filter/FORWARD, and set a packet mark that nat/POSTROUTING can // use to effectively run that same test again. - args = []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask} + args = []string{"-i", tunname, "-j", "MARK", "--set-mark", subnetRouteMark + "/" + fwmarkMask} if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) } - args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"} + args = []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "ACCEPT"} if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) } @@ -453,11 +352,11 @@ func (i *iptablesRunner) addBase6(tunname string) error { return fmt.Errorf("adding %v in v6/filter/ts-input: %w", args, err) } - args = []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask} + args = []string{"-i", tunname, "-j", "MARK", "--set-mark", subnetRouteMark + "/" + fwmarkMask} if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) } - args = []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"} + args = []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "ACCEPT"} if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) } @@ -546,7 +445,7 @@ func (i *iptablesRunner) DelHooks(logf logger.Logf) error { // AddSNATRule adds a netfilter rule to SNAT traffic destined for // local subnets. func (i *iptablesRunner) AddSNATRule() error { - args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"} + args := []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "MASQUERADE"} for _, ipt := range i.getNATTables() { if err := ipt.Append("nat", "ts-postrouting", args...); err != nil { return fmt.Errorf("adding %v in nat/ts-postrouting: %w", args, err) @@ -558,7 +457,7 @@ func (i *iptablesRunner) AddSNATRule() error { // DelSNATRule removes the netfilter rule to SNAT traffic destined for // local subnets. An error is returned if the rule does not exist. func (i *iptablesRunner) DelSNATRule() error { - args := []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"} + args := []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "MASQUERADE"} for _, ipt := range i.getNATTables() { if err := ipt.Delete("nat", "ts-postrouting", args...); err != nil { return fmt.Errorf("deleting %v in nat/ts-postrouting: %w", args, err) @@ -685,25 +584,6 @@ func (i *iptablesRunner) DelMagicsockPortRule(port uint16, network string) error return nil } -// IPTablesCleanUp removes all Tailscale added iptables rules. -// Any errors that occur are logged to the provided logf. -func IPTablesCleanUp(logf logger.Logf) { - if distro.Get() == distro.Gokrazy { - // Gokrazy uses nftables and doesn't have the "iptables" command. - // Avoid log spam on cleanup. (#12277) - return - } - err := clearRules(iptables.ProtocolIPv4, logf) - if err != nil { - logf("linuxfw: clear iptables: %v", err) - } - - err = clearRules(iptables.ProtocolIPv6, logf) - if err != nil { - logf("linuxfw: clear ip6tables: %v", err) - } -} - // delTSHook deletes hook in a chain that jumps to a ts-chain. If the hook does not // exist, it's a no-op since the desired state is already achieved but we log the // error because error code from the iptables module resists unwrapping. @@ -732,40 +612,6 @@ func delChain(ipt iptablesInterface, table, chain string) error { return nil } -// clearRules clears all the iptables rules created by Tailscale -// for the given protocol. If error occurs, it's logged but not returned. -func clearRules(proto iptables.Protocol, logf logger.Logf) error { - ipt, err := iptables.NewWithProtocol(proto) - if err != nil { - return err - } - - var errs []error - - if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { - errs = append(errs, err) - } - if err := delTSHook(ipt, "filter", "FORWARD", logf); err != nil { - errs = append(errs, err) - } - if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { - errs = append(errs, err) - } - - if err := delChain(ipt, "filter", "ts-input"); err != nil { - errs = append(errs, err) - } - if err := delChain(ipt, "filter", "ts-forward"); err != nil { - errs = append(errs, err) - } - - if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { - errs = append(errs, err) - } - - return multierr.New(errs...) -} - // argsFromPostRoutingRule accepts a rule as returned by iptables.List and, if it is a rule from POSTROUTING chain, // returns the args part, else returns the original rule. func argsFromPostRoutingRule(r string) string { diff --git a/util/linuxfw/iptables_runner_test.go b/util/linuxfw/iptables_runner_test.go index 56f13c78a..ce905aef3 100644 --- a/util/linuxfw/iptables_runner_test.go +++ b/util/linuxfw/iptables_runner_test.go @@ -11,6 +11,7 @@ import ( "testing" "tailscale.com/net/tsaddr" + "tailscale.com/tsconst" ) var testIsNotExistErr = "exitcode:1" @@ -20,7 +21,7 @@ func init() { } func TestAddAndDeleteChains(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() err := iptr.AddChains() if err != nil { t.Fatal(err) @@ -59,7 +60,7 @@ func TestAddAndDeleteChains(t *testing.T) { } func TestAddAndDeleteHooks(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() // don't need to test what happens if the chains don't exist, because // this is handled by fake iptables, in realife iptables would return error. if err := iptr.AddChains(); err != nil { @@ -113,7 +114,7 @@ func TestAddAndDeleteHooks(t *testing.T) { } func TestAddAndDeleteBase(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() tunname := "tun0" if err := iptr.AddChains(); err != nil { t.Fatal(err) @@ -132,8 +133,8 @@ func TestAddAndDeleteBase(t *testing.T) { tsRulesCommon := []fakeRule{ // table/chain/rule {"filter", "ts-input", []string{"-i", tunname, "-j", "ACCEPT"}}, - {"filter", "ts-forward", []string{"-i", tunname, "-j", "MARK", "--set-mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask}}, - {"filter", "ts-forward", []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "ACCEPT"}}, + {"filter", "ts-forward", []string{"-i", tunname, "-j", "MARK", "--set-mark", tsconst.LinuxSubnetRouteMark + "/" + tsconst.LinuxFwmarkMask}}, + {"filter", "ts-forward", []string{"-m", "mark", "--mark", tsconst.LinuxSubnetRouteMark + "/" + tsconst.LinuxFwmarkMask, "-j", "ACCEPT"}}, {"filter", "ts-forward", []string{"-o", tunname, "-j", "ACCEPT"}}, } @@ -176,7 +177,7 @@ func TestAddAndDeleteBase(t *testing.T) { } func TestAddAndDelLoopbackRule(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() // We don't need to test for malformed addresses, AddLoopbackRule // takes in a netip.Addr, which is already valid. fakeAddrV4 := netip.MustParseAddr("192.168.0.2") @@ -247,14 +248,14 @@ func TestAddAndDelLoopbackRule(t *testing.T) { } func TestAddAndDelSNATRule(t *testing.T) { - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() if err := iptr.AddChains(); err != nil { t.Fatal(err) } rule := fakeRule{ // table/chain/rule - "nat", "ts-postrouting", []string{"-m", "mark", "--mark", TailscaleSubnetRouteMark + "/" + TailscaleFwmarkMask, "-j", "MASQUERADE"}, + "nat", "ts-postrouting", []string{"-m", "mark", "--mark", tsconst.LinuxSubnetRouteMark + "/" + tsconst.LinuxFwmarkMask, "-j", "MASQUERADE"}, } // Add SNAT rule @@ -292,7 +293,7 @@ func TestAddAndDelSNATRule(t *testing.T) { func TestEnsureSNATForDst_ipt(t *testing.T) { ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77") - iptr := NewFakeIPTablesRunner() + iptr := newFakeIPTablesRunner() // 1. A new rule gets added mustCreateSNATRule_ipt(t, iptr, ip1, ip2) diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index be520e7a4..ec73aacee 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -14,6 +14,8 @@ import ( "strings" "github.com/tailscale/netlink" + "tailscale.com/feature" + "tailscale.com/tsconst" "tailscale.com/types/logger" ) @@ -69,23 +71,12 @@ const ( // matching and setting the bits, so they can be directly embedded in // commands. const ( - // The mask for reading/writing the 'firewall mask' bits on a packet. - // See the comment on the const block on why we only use the third byte. - // - // We claim bits 16:23 entirely. For now we only use the lower four - // bits, leaving the higher 4 bits for future use. - TailscaleFwmarkMask = "0xff0000" - TailscaleFwmarkMaskNum = 0xff0000 - - // Packet is from Tailscale and to a subnet route destination, so - // is allowed to be routed through this machine. - TailscaleSubnetRouteMark = "0x40000" - TailscaleSubnetRouteMarkNum = 0x40000 - - // Packet was originated by tailscaled itself, and must not be - // routed over the Tailscale network. - TailscaleBypassMark = "0x80000" - TailscaleBypassMarkNum = 0x80000 + fwmarkMask = tsconst.LinuxFwmarkMask + fwmarkMaskNum = tsconst.LinuxFwmarkMaskNum + subnetRouteMark = tsconst.LinuxSubnetRouteMark + subnetRouteMarkNum = tsconst.LinuxSubnetRouteMarkNum + bypassMark = tsconst.LinuxBypassMark + bypassMarkNum = tsconst.LinuxBypassMarkNum ) // getTailscaleFwmarkMaskNeg returns the negation of TailscaleFwmarkMask in bytes. @@ -169,7 +160,7 @@ func CheckIPRuleSupportsV6(logf logger.Logf) error { // Try to actually create & delete one as a test. rule := netlink.NewRule() rule.Priority = 1234 - rule.Mark = TailscaleBypassMarkNum + rule.Mark = bypassMarkNum rule.Table = 52 rule.Family = netlink.FAMILY_V6 // First delete the rule unconditionally, and don't check for @@ -180,3 +171,13 @@ func CheckIPRuleSupportsV6(logf logger.Logf) error { defer netlink.RuleDel(rule) return netlink.RuleAdd(rule) } + +var hookIPTablesCleanup feature.Hook[func(logger.Logf)] + +// IPTablesCleanUp removes all Tailscale added iptables rules. +// Any errors that occur are logged to the provided logf. +func IPTablesCleanUp(logf logger.Logf) { + if f, ok := hookIPTablesCleanup.GetOk(); ok { + f(logf) + } +} diff --git a/util/linuxfw/linuxfw_unsupported.go b/util/linuxfw/linuxfw_unsupported.go deleted file mode 100644 index 7bfb4fd01..000000000 --- a/util/linuxfw/linuxfw_unsupported.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// NOTE: linux_{arm64, amd64} are the only two currently supported archs due to missing -// support in upstream dependencies. - -// TODO(#8502): add support for more architectures -//go:build linux && !(arm64 || amd64) - -package linuxfw - -import ( - "errors" - - "tailscale.com/types/logger" -) - -// ErrUnsupported is the error returned from all functions on non-Linux -// platforms. -var ErrUnsupported = errors.New("linuxfw:unsupported") - -// DebugNetfilter is not supported on non-Linux platforms. -func DebugNetfilter(logf logger.Logf) error { - return ErrUnsupported -} - -// DetectNetfilter is not supported on non-Linux platforms. -func detectNetfilter() (int, error) { - return 0, ErrUnsupported -} - -// DebugIptables is not supported on non-Linux platforms. -func debugIptables(logf logger.Logf) error { - return ErrUnsupported -} - -// DetectIptables is not supported on non-Linux platforms. -func detectIptables() (int, error) { - return 0, ErrUnsupported -} diff --git a/util/linuxfw/nftables.go b/util/linuxfw/nftables.go index 056563071..94ce51a14 100644 --- a/util/linuxfw/nftables.go +++ b/util/linuxfw/nftables.go @@ -8,6 +8,7 @@ package linuxfw import ( "cmp" + "encoding/binary" "fmt" "sort" "strings" @@ -15,7 +16,6 @@ import ( "github.com/google/nftables" "github.com/google/nftables/expr" "github.com/google/nftables/xt" - "github.com/josharian/native" "golang.org/x/sys/unix" "tailscale.com/types/logger" ) @@ -103,6 +103,10 @@ func DebugNetfilter(logf logger.Logf) error { return nil } +func init() { + hookDetectNetfilter.Set(detectNetfilter) +} + // detectNetfilter returns the number of nftables rules present in the system. func detectNetfilter() (int, error) { // Frist try creating a dummy postrouting chain. Emperically, we have @@ -235,8 +239,8 @@ func printMatchInfo(name string, info xt.InfoAny) string { break } - pkttype := int(native.Endian.Uint32(data[0:4])) - invert := int(native.Endian.Uint32(data[4:8])) + pkttype := int(binary.NativeEndian.Uint32(data[0:4])) + invert := int(binary.NativeEndian.Uint32(data[4:8])) var invertPrefix string if invert != 0 { invertPrefix = "!" diff --git a/util/linuxfw/nftables_for_svcs.go b/util/linuxfw/nftables_for_svcs.go index 130585b22..474b98086 100644 --- a/util/linuxfw/nftables_for_svcs.go +++ b/util/linuxfw/nftables_for_svcs.go @@ -119,6 +119,63 @@ func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm [ return n.conn.Flush() } +// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the +// VIPService IP address to a local address. This is used by the Kubernetes +// operator's network layer proxies to forward tailnet traffic for VIPServices +// to Kubernetes Services. +func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error { + t, ch, err := n.ensurePreroutingChain(origDst) + if err != nil { + return fmt.Errorf("error ensuring chain for %s: %w", svc, err) + } + meta := svcRuleMeta(svc, origDst, dst) + rule, err := n.findRuleByMetadata(t, ch, meta) + if err != nil { + return fmt.Errorf("error looking up rule: %w", err) + } + if rule != nil { + return nil + } + rule = dnatRuleForChain(t, ch, origDst, dst, meta) + n.conn.InsertRule(rule) + return n.conn.Flush() +} + +// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. +// We use the metadata attached to the rule to look it up. +func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table, err := n.getNFTByAddr(origDst) + if err != nil { + return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err) + } + t, err := getTableIfExists(n.conn, table.Proto, "nat") + if err != nil { + return fmt.Errorf("error checking if nat table exists: %w", err) + } + if t == nil { + return nil + } + ch, err := getChainFromTable(n.conn, t, "PREROUTING") + if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) { + return nil + } + if err != nil { + return fmt.Errorf("error checking if chain PREROUTING exists: %w", err) + } + meta := svcRuleMeta(svcName, origDst, dst) + rule, err := n.findRuleByMetadata(t, ch, meta) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if rule == nil { + return nil + } + if err := n.conn.DelRule(rule); err != nil { + return fmt.Errorf("error deleting rule: %w", err) + } + return n.conn.Flush() +} + func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule { var fam uint32 if targetIP.Is4() { @@ -243,3 +300,10 @@ func protoFromString(s string) (uint8, error) { return 0, fmt.Errorf("unrecognized protocol: %q", s) } } + +// svcRuleMeta generates metadata for a rule. +// This metadata can then be used to find the rule. +// https://github.com/google/nftables/issues/48 +func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte { + return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String())) +} diff --git a/util/linuxfw/nftables_for_svcs_test.go b/util/linuxfw/nftables_for_svcs_test.go index d2df6e4bd..73472ce20 100644 --- a/util/linuxfw/nftables_for_svcs_test.go +++ b/util/linuxfw/nftables_for_svcs_test.go @@ -14,8 +14,9 @@ import ( // This test creates a temporary network namespace for the nftables rules being // set up, so it needs to run in a privileged mode. Locally it needs to be run -// by root, else it will be silently skipped. In CI it runs in a privileged -// container. +// by root, else it will be silently skipped. +// sudo go test -v -run Test_nftablesRunner_EnsurePortMapRuleForSvc ./util/linuxfw/... +// In CI it runs in a privileged container. func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { conn := newSysConn(t) runner := newFakeNftablesRunnerWithConn(t, conn, true) @@ -23,51 +24,215 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"} pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"} - // Create a rule for service 'foo' to forward TCP traffic to IPv4 endpoint - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) + // Create a rule for service 'svc:foo' to forward TCP traffic to IPv4 endpoint + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP) svcChains(t, 1, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) - // Create another rule for service 'foo' to forward TCP traffic to the + // Create another rule for service 'svc:foo' to forward TCP traffic to the // same IPv4 endpoint, but to a different port. - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1) + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP1) svcChains(t, 1, conn) - chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 2, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) - // Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP) + // Create a rule for service 'svc:foo' to forward TCP traffic to an IPv6 endpoint + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv6, pmTCP) svcChains(t, 2, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6) - checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv6) + checkPortMapRule(t, "svc:foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) - // Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint - runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP) + // Create a rule for service 'svc:bar' to forward TCP traffic to IPv4 endpoint + runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv4, pmTCP) svcChains(t, 3, conn) - chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) - // Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint - runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP) + // Create a rule for service 'svc:bar' to forward TCP traffic to an IPv6 endpoint + runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv6, pmTCP) svcChains(t, 4, conn) - chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6) - checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv6) + checkPortMapRule(t, "svc:bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) - // Delete service bar - runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) + // Delete service svc:bar + runner.DeleteSvc("svc:bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) svcChains(t, 2, conn) - // Delete a rule from service foo - runner.DeletePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) + // Delete a rule from service svc:foo + runner.DeletePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP) svcChains(t, 2, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4) - // Delete service foo - runner.DeleteSvc("foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) + // Delete service svc:foo + runner.DeleteSvc("svc:foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) svcChains(t, 0, conn) } +func Test_nftablesRunner_EnsureDNATRuleForSvc(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + + // Test IPv4 DNAT rule + ipv4OrigDst := netip.MustParseAddr("10.0.0.1") + ipv4Target := netip.MustParseAddr("10.0.0.2") + + // Create DNAT rule for service 'svc:foo' to forward IPv4 traffic + err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating IPv4 DNAT rule: %v", err) + } + checkDNATRule(t, "svc:foo", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4) + + // Test IPv6 DNAT rule + ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + + // Create DNAT rule for service 'svc:foo' to forward IPv6 traffic + err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error creating IPv6 DNAT rule: %v", err) + } + checkDNATRule(t, "svc:foo", ipv6OrigDst, ipv6Target, runner, nftables.TableFamilyIPv6) + + // Test creating rule for another service + err = runner.EnsureDNATRuleForSvc("svc:bar", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating DNAT rule for service 'svc:bar': %v", err) + } + checkDNATRule(t, "svc:bar", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4) +} + +func Test_nftablesRunner_DeleteDNATRuleForSvc(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + + // Test IPv4 DNAT rule deletion + ipv4OrigDst := netip.MustParseAddr("10.0.0.1") + ipv4Target := netip.MustParseAddr("10.0.0.2") + + // Create and then delete IPv4 DNAT rule + err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating IPv4 DNAT rule: %v", err) + } + + // Verify rule exists before deletion + table, err := runner.getNFTByAddr(ipv4OrigDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + meta := svcRuleMeta("svc:foo", ipv4OrigDst, ipv4Target) + rule, err := runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("rule does not exist before deletion") + } + + err = runner.DeleteDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error deleting IPv4 DNAT rule: %v", err) + } + + // Verify rule is deleted + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule != nil { + t.Fatal("rule still exists after deletion") + } + + // Test IPv6 DNAT rule deletion + ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + + // Create and then delete IPv6 DNAT rule + err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error creating IPv6 DNAT rule: %v", err) + } + + // Verify rule exists before deletion + table, err = runner.getNFTByAddr(ipv6OrigDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err = getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + ch, err = getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + meta = svcRuleMeta("svc:foo", ipv6OrigDst, ipv6Target) + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("rule does not exist before deletion") + } + + err = runner.DeleteDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error deleting IPv6 DNAT rule: %v", err) + } + + // Verify rule is deleted + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule != nil { + t.Fatal("rule still exists after deletion") + } +} + +// checkDNATRule verifies that a DNAT rule exists for the given service, original destination, and target IP. +func checkDNATRule(t *testing.T, svc string, origDst, targetIP netip.Addr, runner *nftablesRunner, fam nftables.TableFamily) { + t.Helper() + table, err := runner.getNFTByAddr(origDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + if nftTable == nil { + t.Fatal("nat table not found") + } + + ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + if ch == nil { + t.Fatal("PREROUTING chain not found") + } + + meta := svcRuleMeta(svc, origDst, targetIP) + rule, err := runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("DNAT rule not found") + } +} + // svcChains verifies that the expected number of chains exist (for either IP // family) and that each of them is configured as NAT prerouting chain. func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) { diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 0f411521b..faa02f7c7 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -107,6 +107,12 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { if err != nil { return err } + rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil) + n.conn.InsertRule(rule) + return n.conn.Flush() +} + +func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule { var daddrOffset, fam, dadderLen uint32 if origDst.Is4() { daddrOffset = 16 @@ -117,9 +123,9 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { dadderLen = 16 fam = unix.NFPROTO_IPV6 } - dnatRule := &nftables.Rule{ - Table: nat, - Chain: preroutingCh, + rule := &nftables.Rule{ + Table: t, + Chain: ch, Exprs: []expr.Any{ &expr.Payload{ DestRegister: 1, @@ -143,8 +149,10 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { }, }, } - n.conn.InsertRule(dnatRule) - return n.conn.Flush() + if len(meta) > 0 { + rule.UserData = meta + } + return rule } // DNATWithLoadBalancer currently just forwards all traffic destined for origDst @@ -555,6 +563,8 @@ type NetfilterRunner interface { EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error + EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error + DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error @@ -1710,57 +1720,45 @@ func (n *nftablesRunner) AddSNATRule() error { return nil } +func delMatchSubnetRouteMarkMasqRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error { + + rule, err := createMatchSubnetRouteMarkRule(table, chain, Masq) + if err != nil { + return fmt.Errorf("create match subnet route mark rule: %w", err) + } + + SNATRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find SNAT rule v4: %w", err) + } + + if SNATRule != nil { + _ = conn.DelRule(SNATRule) + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del SNAT rule: %w", err) + } + + return nil +} + // DelSNATRule removes the netfilter rule to SNAT traffic destined for // local subnets. An error is returned if the rule does not exist. func (n *nftablesRunner) DelSNATRule() error { conn := n.conn - hexTSFwmarkMask := getTailscaleFwmarkMask() - hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() - - exprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: hexTSFwmarkMask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: hexTSSubnetRouteMark, - }, - &expr.Counter{}, - &expr.Masq{}, - } - for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { - return fmt.Errorf("get postrouting chain v4: %w", err) - } - - rule := &nftables.Rule{ - Table: table.Nat, - Chain: chain, - Exprs: exprs, + return fmt.Errorf("get postrouting chain: %w", err) } - - SNATRule, err := findRule(conn, rule) + err = delMatchSubnetRouteMarkMasqRule(conn, table.Nat, chain) if err != nil { - return fmt.Errorf("find SNAT rule v4: %w", err) - } - - if SNATRule != nil { - _ = conn.DelRule(SNATRule) + return err } } - if err := conn.Flush(); err != nil { - return fmt.Errorf("flush del SNAT rule: %w", err) - } - return nil } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index 712a7b939..6fb180ed6 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" "runtime" + "slices" "strings" "testing" @@ -24,21 +25,21 @@ import ( "tailscale.com/types/logger" ) +func toAnySlice[T any](s []T) []any { + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + // nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing // users to make sense of large byte literals more easily. func nfdump(b []byte) string { var buf bytes.Buffer - i := 0 - for ; i < len(b); i += 4 { - // TODO: show printable characters as ASCII - fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", - b[i], - b[i+1], - b[i+2], - b[i+3]) - } - for ; i < len(b); i++ { - fmt.Fprintf(&buf, "%02x ", b[i]) + for c := range slices.Chunk(b, 4) { + format := strings.Repeat("%02x ", len(c)) + fmt.Fprintf(&buf, format+"\n", toAnySlice(c)...) } return buf.String() } @@ -75,7 +76,7 @@ func linediff(a, b string) string { return buf.String() } -func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { +func newTestConn(t *testing.T, want [][]byte, reply [][]netlink.Message) *nftables.Conn { conn, err := nftables.New(nftables.WithTestDial( func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { @@ -96,7 +97,13 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { } want = want[1:] } - return req, nil + // no reply for batch end message + if len(want) == 0 { + return nil, nil + } + rep := reply[0] + reply = reply[1:] + return rep, nil })) if err != nil { t.Fatal(err) @@ -120,7 +127,7 @@ func TestInsertHookRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -160,7 +167,7 @@ func TestInsertLoopbackRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -196,7 +203,7 @@ func TestInsertLoopbackRuleV6(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) tableV6 := testConn.AddTable(&nftables.Table{ Family: protoV6, Name: "ts-filter-test", @@ -232,7 +239,7 @@ func TestAddReturnChromeOSVMRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -264,7 +271,7 @@ func TestAddDropCGNATRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -296,7 +303,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -328,7 +335,7 @@ func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -360,7 +367,7 @@ func TestAddAcceptOutgoingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -392,7 +399,7 @@ func TestAddAcceptIncomingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -420,11 +427,11 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { // nft add chain ip ts-nat-test ts-postrouting-test { type nat hook postrouting priority 100; } []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x03\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"), // nft add rule ip ts-nat-test ts-postrouting-test meta mark & 0x00ff0000 == 0x00040000 counter masquerade - []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xd8\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"), // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-nat-test", @@ -436,7 +443,46 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { Hooknum: nftables.ChainHookPostrouting, Priority: nftables.ChainPriorityNATSource, }) - err := addMatchSubnetRouteMarkRule(testConn, table, chain, Accept) + err := addMatchSubnetRouteMarkRule(testConn, table, chain, Masq) + if err != nil { + t.Fatal(err) + } +} + +func TestDelMatchSubnetRouteMarkMasqRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + reply := [][]netlink.Message{ + nil, + {{Header: netlink.Header{Length: 0x128, Type: 0xa06, Flags: 0x802, Sequence: 0xa213d55d, PID: 0x11e79}, Data: []uint8{0x2, 0x0, 0x0, 0x8c, 0xd, 0x0, 0x1, 0x0, 0x6e, 0x61, 0x74, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x0, 0x0, 0x0, 0x0, 0x18, 0x0, 0x2, 0x0, 0x74, 0x73, 0x2d, 0x70, 0x6f, 0x73, 0x74, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0xe0, 0x0, 0x4, 0x0, 0x24, 0x0, 0x1, 0x0, 0x9, 0x0, 0x1, 0x0, 0x6d, 0x65, 0x74, 0x61, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x2, 0x0, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x3, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x4c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x62, 0x69, 0x74, 0x77, 0x69, 0x73, 0x65, 0x0, 0x3c, 0x0, 0x2, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x4, 0x8, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x4, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0xff, 0x0, 0x0, 0xc, 0x0, 0x5, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2c, 0x0, 0x1, 0x0, 0x8, 0x0, 0x1, 0x0, 0x63, 0x6d, 0x70, 0x0, 0x20, 0x0, 0x2, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x3, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x4, 0x0, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x1, 0x0, 0x9, 0x0, 0x1, 0x0, 0x6d, 0x61, 0x73, 0x71, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x2, 0x0}}}, + {{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x311fdccb, PID: 0x11e79}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, + {{Header: netlink.Header{Length: 0x24, Type: 0x2, Flags: 0x100, Sequence: 0x311fdccb, PID: 0x11e79}, Data: []uint8{0x0, 0x0, 0x0, 0x0, 0x48, 0x0, 0x0, 0x0, 0x8, 0xa, 0x5, 0x0, 0xcb, 0xdc, 0x1f, 0x31, 0x79, 0x1e, 0x1, 0x0}}}, + } + want := [][]byte{ + // get rules in nat-test table ts-postrouting-test chain + []byte("\x02\x00\x00\x00\x0d\x00\x01\x00\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00"), + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft delete rule ip nat-test ts-postrouting-test handle 4 + []byte("\x02\x00\x00\x00\x0d\x00\x01\x00\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x0c\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x04"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + conn := newTestConn(t, want, reply) + + table := &nftables.Table{ + Family: proto, + Name: "nat-test", + } + chain := &nftables.Chain{ + Name: "ts-postrouting-test", + Table: table, + Type: nftables.ChainTypeNAT, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource, + } + + err := delMatchSubnetRouteMarkMasqRule(conn, table, chain) if err != nil { t.Fatal(err) } @@ -456,7 +502,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", diff --git a/util/lru/lru_test.go b/util/lru/lru_test.go index fb538efbe..04de2e507 100644 --- a/util/lru/lru_test.go +++ b/util/lru/lru_test.go @@ -10,7 +10,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - xmaps "golang.org/x/exp/maps" + "tailscale.com/util/slicesx" ) func TestLRU(t *testing.T) { @@ -75,7 +75,7 @@ func TestStressEvictions(t *testing.T) { for len(vm) < numKeys { vm[rand.Uint64()] = true } - vals := xmaps.Keys(vm) + vals := slicesx.MapKeys(vm) c := Cache[uint64, bool]{ MaxEntries: cacheSize, @@ -84,8 +84,8 @@ func TestStressEvictions(t *testing.T) { for range numProbes { v := vals[rand.Intn(len(vals))] c.Set(v, true) - if l := c.Len(); l > cacheSize { - t.Fatalf("Cache size now %d, want max %d", l, cacheSize) + if ln := c.Len(); ln > cacheSize { + t.Fatalf("Cache size now %d, want max %d", ln, cacheSize) } } } @@ -106,7 +106,7 @@ func TestStressBatchedEvictions(t *testing.T) { for len(vm) < numKeys { vm[rand.Uint64()] = true } - vals := xmaps.Keys(vm) + vals := slicesx.MapKeys(vm) c := Cache[uint64, bool]{} @@ -119,8 +119,8 @@ func TestStressBatchedEvictions(t *testing.T) { c.DeleteOldest() } } - if l := c.Len(); l > cacheSizeMax { - t.Fatalf("Cache size now %d, want max %d", l, cacheSizeMax) + if ln := c.Len(); ln > cacheSizeMax { + t.Fatalf("Cache size now %d, want max %d", ln, cacheSizeMax) } } } diff --git a/util/mak/mak.go b/util/mak/mak.go index b421fb0ed..fbdb40b0a 100644 --- a/util/mak/mak.go +++ b/util/mak/mak.go @@ -5,11 +5,6 @@ // things, notably to maps, but also slices. package mak -import ( - "fmt" - "reflect" -) - // Set populates an entry in a map, making the map if necessary. // // That is, it assigns (*m)[k] = v, making *m if it was nil. @@ -20,35 +15,6 @@ func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { (*m)[k] = v } -// NonNil takes a pointer to a Go data structure -// (currently only a slice or a map) and makes sure it's non-nil for -// JSON serialization. (In particular, JavaScript clients usually want -// the field to be defined after they decode the JSON.) -// -// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. -func NonNil(ptr any) { - if ptr == nil { - panic("nil interface") - } - rv := reflect.ValueOf(ptr) - if rv.Kind() != reflect.Ptr { - panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) - } - if rv.Pointer() == 0 { - panic("nil pointer") - } - rv = rv.Elem() - if rv.Pointer() != 0 { - return - } - switch rv.Type().Kind() { - case reflect.Slice: - rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) - case reflect.Map: - rv.Set(reflect.MakeMap(rv.Type())) - } -} - // NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will // won't be omitted from JSON serialization and possibly confuse JavaScript // clients expecting it to be present. diff --git a/util/mak/mak_test.go b/util/mak/mak_test.go index 4de499a9d..e47839a3c 100644 --- a/util/mak/mak_test.go +++ b/util/mak/mak_test.go @@ -40,35 +40,6 @@ func TestSet(t *testing.T) { }) } -func TestNonNil(t *testing.T) { - var s []string - NonNil(&s) - if len(s) != 0 { - t.Errorf("slice len = %d; want 0", len(s)) - } - if s == nil { - t.Error("slice still nil") - } - - s = append(s, "foo") - NonNil(&s) - if len(s) != 1 { - t.Errorf("len = %d; want 1", len(s)) - } - if s[0] != "foo" { - t.Errorf("value = %q; want foo", s) - } - - var m map[string]string - NonNil(&m) - if len(m) != 0 { - t.Errorf("map len = %d; want 0", len(s)) - } - if m == nil { - t.Error("map still nil") - } -} - func TestNonNilMapForJSON(t *testing.T) { type M map[string]int var m M diff --git a/util/must/must.go b/util/must/must.go index 21965daa9..a292da226 100644 --- a/util/must/must.go +++ b/util/must/must.go @@ -23,3 +23,11 @@ func Get[T any](v T, err error) T { } return v } + +// Get2 returns v1 and v2 as is. It panics if err is non-nil. +func Get2[T any, U any](v1 T, v2 U, err error) (T, U) { + if err != nil { + panic(err) + } + return v1, v2 +} diff --git a/util/osdiag/zsyscall_windows.go b/util/osdiag/zsyscall_windows.go index ab0d18d3f..2a11b4644 100644 --- a/util/osdiag/zsyscall_windows.go +++ b/util/osdiag/zsyscall_windows.go @@ -51,7 +51,7 @@ var ( ) func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) { - r0, _, _ := syscall.Syscall9(procRegEnumValueW.Addr(), 8, uintptr(key), uintptr(index), uintptr(unsafe.Pointer(valueName)), uintptr(unsafe.Pointer(valueNameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valueType)), uintptr(unsafe.Pointer(pData)), uintptr(unsafe.Pointer(cbData)), 0) + r0, _, _ := syscall.SyscallN(procRegEnumValueW.Addr(), uintptr(key), uintptr(index), uintptr(unsafe.Pointer(valueName)), uintptr(unsafe.Pointer(valueNameLen)), uintptr(unsafe.Pointer(reserved)), uintptr(unsafe.Pointer(valueType)), uintptr(unsafe.Pointer(pData)), uintptr(unsafe.Pointer(cbData))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -59,7 +59,7 @@ func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLe } func globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) { - r1, _, e1 := syscall.Syscall(procGlobalMemoryStatusEx.Addr(), 1, uintptr(unsafe.Pointer(memStatus)), 0, 0) + r1, _, e1 := syscall.SyscallN(procGlobalMemoryStatusEx.Addr(), uintptr(unsafe.Pointer(memStatus))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -67,19 +67,19 @@ func globalMemoryStatusEx(memStatus *_MEMORYSTATUSEX) (err error) { } func wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) { - r0, _, _ := syscall.Syscall6(procWSCEnumProtocols.Addr(), 4, uintptr(unsafe.Pointer(iProtocols)), uintptr(unsafe.Pointer(protocolBuffer)), uintptr(unsafe.Pointer(bufLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + r0, _, _ := syscall.SyscallN(procWSCEnumProtocols.Addr(), uintptr(unsafe.Pointer(iProtocols)), uintptr(unsafe.Pointer(protocolBuffer)), uintptr(unsafe.Pointer(bufLen)), uintptr(unsafe.Pointer(errno))) ret = int32(r0) return } func wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) { - r0, _, _ := syscall.Syscall6(procWSCGetProviderInfo.Addr(), 6, uintptr(unsafe.Pointer(providerId)), uintptr(infoType), uintptr(info), uintptr(unsafe.Pointer(infoSize)), uintptr(flags), uintptr(unsafe.Pointer(errno))) + r0, _, _ := syscall.SyscallN(procWSCGetProviderInfo.Addr(), uintptr(unsafe.Pointer(providerId)), uintptr(infoType), uintptr(info), uintptr(unsafe.Pointer(infoSize)), uintptr(flags), uintptr(unsafe.Pointer(errno))) ret = int32(r0) return } func wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) { - r0, _, _ := syscall.Syscall6(procWSCGetProviderPath.Addr(), 4, uintptr(unsafe.Pointer(providerId)), uintptr(unsafe.Pointer(providerDllPath)), uintptr(unsafe.Pointer(providerDllPathLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + r0, _, _ := syscall.SyscallN(procWSCGetProviderPath.Addr(), uintptr(unsafe.Pointer(providerId)), uintptr(unsafe.Pointer(providerDllPath)), uintptr(unsafe.Pointer(providerDllPathLen)), uintptr(unsafe.Pointer(errno))) ret = int32(r0) return } diff --git a/util/osshare/filesharingstatus_windows.go b/util/osshare/filesharingstatus_windows.go index 999fc1cf7..c125de159 100644 --- a/util/osshare/filesharingstatus_windows.go +++ b/util/osshare/filesharingstatus_windows.go @@ -9,30 +9,31 @@ import ( "fmt" "os" "path/filepath" - "sync" + "runtime" "golang.org/x/sys/windows/registry" + "tailscale.com/types/lazy" "tailscale.com/types/logger" + "tailscale.com/util/winutil" ) const ( sendFileShellKey = `*\shell\tailscale` ) -var ipnExePath struct { - sync.Mutex - cache string // absolute path of tailscale-ipn.exe, populated lazily on first use -} +var ipnExePath lazy.SyncValue[string] // absolute path of the GUI executable func getIpnExePath(logf logger.Logf) string { - ipnExePath.Lock() - defer ipnExePath.Unlock() - - if ipnExePath.cache != "" { - return ipnExePath.cache + exe, err := winutil.GUIPathFromReg() + if err == nil { + return exe } - // Find the absolute path of tailscale-ipn.exe assuming that it's in the same + return findGUIInSameDirAsThisExe(logf) +} + +func findGUIInSameDirAsThisExe(logf logger.Logf) string { + // Find the absolute path of the GUI, assuming that it's in the same // directory as this executable (tailscaled.exe). p, err := os.Executable() if err != nil { @@ -43,14 +44,23 @@ func getIpnExePath(logf logger.Logf) string { logf("filepath.EvalSymlinks error: %v", err) return "" } - p = filepath.Join(filepath.Dir(p), "tailscale-ipn.exe") if p, err = filepath.Abs(p); err != nil { logf("filepath.Abs error: %v", err) return "" } - ipnExePath.cache = p - - return p + d := filepath.Dir(p) + candidates := []string{"tailscale-ipn.exe"} + if runtime.GOARCH == "arm64" { + // This name may be used on Windows 10 ARM64. + candidates = append(candidates, "tailscale-gui-386.exe") + } + for _, c := range candidates { + testPath := filepath.Join(d, c) + if _, err := os.Stat(testPath); err == nil { + return testPath + } + } + return "" } // SetFileSharingEnabled adds/removes "Send with Tailscale" from the Windows shell menu. @@ -64,7 +74,9 @@ func SetFileSharingEnabled(enabled bool, logf logger.Logf) { } func enableFileSharing(logf logger.Logf) { - path := getIpnExePath(logf) + path := ipnExePath.Get(func() string { + return getIpnExePath(logf) + }) if path == "" { return } @@ -79,7 +91,7 @@ func enableFileSharing(logf logger.Logf) { logf("k.SetStringValue error: %v", err) return } - if err := k.SetStringValue("Icon", path+",0"); err != nil { + if err := k.SetStringValue("Icon", path+",1"); err != nil { logf("k.SetStringValue error: %v", err) return } diff --git a/util/osuser/group_ids.go b/util/osuser/group_ids.go index f25861dbb..7c2b5b090 100644 --- a/util/osuser/group_ids.go +++ b/util/osuser/group_ids.go @@ -19,6 +19,10 @@ import ( // an error. It will first try to use the 'id' command to get the group IDs, // and if that fails, it will fall back to the user.GroupIds method. func GetGroupIds(user *user.User) ([]string, error) { + if runtime.GOOS == "plan9" { + return nil, nil + } + if runtime.GOOS != "linux" { return user.GroupIds() } diff --git a/util/osuser/user.go b/util/osuser/user.go index 2c7f2e24b..8b96194d7 100644 --- a/util/osuser/user.go +++ b/util/osuser/user.go @@ -54,9 +54,18 @@ func lookup(usernameOrUID string, std lookupStd, wantShell bool) (*user.User, st // Skip getent entirely on Non-Unix platforms that won't ever have it. // (Using HasPrefix for "wasip1", anticipating that WASI support will // move beyond "preview 1" some day.) - if runtime.GOOS == "windows" || runtime.GOOS == "js" || runtime.GOARCH == "wasm" { + if runtime.GOOS == "windows" || runtime.GOOS == "js" || runtime.GOARCH == "wasm" || runtime.GOOS == "plan9" { + var shell string + if wantShell && runtime.GOOS == "plan9" { + shell = "/bin/rc" + } + if runtime.GOOS == "plan9" { + if u, err := user.Current(); err == nil { + return u, shell, nil + } + } u, err := std(usernameOrUID) - return u, "", err + return u, shell, err } // No getent on Gokrazy. So hard-code the login shell. @@ -78,6 +87,16 @@ func lookup(usernameOrUID string, std lookupStd, wantShell bool) (*user.User, st return u, shell, nil } + if runtime.GOOS == "plan9" { + return &user.User{ + Uid: "0", + Gid: "0", + Username: "glenda", + Name: "Glenda", + HomeDir: "/", + }, "/bin/rc", nil + } + // Start with getent if caller wants to get the user shell. if wantShell { return userLookupGetent(usernameOrUID, std) diff --git a/util/pidowner/pidowner_linux.go b/util/pidowner/pidowner_linux.go index 2a5181f14..a07f51242 100644 --- a/util/pidowner/pidowner_linux.go +++ b/util/pidowner/pidowner_linux.go @@ -8,26 +8,26 @@ import ( "os" "strings" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) func ownerOfPID(pid int) (userID string, err error) { file := fmt.Sprintf("/proc/%d/status", pid) - err = lineread.File(file, func(line []byte) error { + for lr := range lineiter.File(file) { + line, err := lr.Value() + if err != nil { + if os.IsNotExist(err) { + return "", ErrProcessNotFound + } + return "", err + } if len(line) < 4 || string(line[:4]) != "Uid:" { - return nil + continue } f := strings.Fields(string(line)) if len(f) >= 2 { userID = f[1] // real userid } - return nil - }) - if os.IsNotExist(err) { - return "", ErrProcessNotFound - } - if err != nil { - return } if userID == "" { return "", fmt.Errorf("missing Uid line in %s", file) diff --git a/util/prompt/prompt.go b/util/prompt/prompt.go new file mode 100644 index 000000000..a6d86fb48 --- /dev/null +++ b/util/prompt/prompt.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package prompt provides a simple way to prompt the user for input. +package prompt + +import ( + "fmt" + "os" + "strings" + + "github.com/mattn/go-isatty" +) + +// YesNo takes a question and prompts the user to answer the +// question with a yes or no. It appends a [y/n] to the message. +// +// If there is no TTY on both Stdin and Stdout, assume that we're in a script +// and return the dflt result. +func YesNo(msg string, dflt bool) bool { + if !(isatty.IsTerminal(os.Stdin.Fd()) && isatty.IsTerminal(os.Stdout.Fd())) { + return dflt + } + if dflt { + fmt.Print(msg + " [Y/n] ") + } else { + fmt.Print(msg + " [y/N] ") + } + var resp string + fmt.Scanln(&resp) + resp = strings.ToLower(resp) + switch resp { + case "y", "yes", "sure": + return true + case "": + return dflt + } + return false +} diff --git a/util/ringbuffer/ringbuffer.go b/util/ringbuffer/ringbuffer.go deleted file mode 100644 index baca2afe8..000000000 --- a/util/ringbuffer/ringbuffer.go +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package ringbuffer contains a fixed-size concurrency-safe generic ring -// buffer. -package ringbuffer - -import "sync" - -// New creates a new RingBuffer containing at most max items. -func New[T any](max int) *RingBuffer[T] { - return &RingBuffer[T]{ - max: max, - } -} - -// RingBuffer is a concurrency-safe ring buffer. -type RingBuffer[T any] struct { - mu sync.Mutex - pos int - buf []T - max int -} - -// Add appends a new item to the RingBuffer, possibly overwriting the oldest -// item in the buffer if it is already full. -// -// It does nothing if rb is nil. -func (rb *RingBuffer[T]) Add(t T) { - if rb == nil { - return - } - rb.mu.Lock() - defer rb.mu.Unlock() - if len(rb.buf) < rb.max { - rb.buf = append(rb.buf, t) - } else { - rb.buf[rb.pos] = t - rb.pos = (rb.pos + 1) % rb.max - } -} - -// GetAll returns a copy of all the entries in the ring buffer in the order they -// were added. -// -// It returns nil if rb is nil. -func (rb *RingBuffer[T]) GetAll() []T { - if rb == nil { - return nil - } - rb.mu.Lock() - defer rb.mu.Unlock() - out := make([]T, len(rb.buf)) - for i := range len(rb.buf) { - x := (rb.pos + i) % rb.max - out[i] = rb.buf[x] - } - return out -} - -// Len returns the number of elements in the ring buffer. Note that this value -// could change immediately after being returned if a concurrent caller -// modifies the buffer. -func (rb *RingBuffer[T]) Len() int { - if rb == nil { - return 0 - } - rb.mu.Lock() - defer rb.mu.Unlock() - return len(rb.buf) -} - -// Clear will empty the ring buffer. -func (rb *RingBuffer[T]) Clear() { - rb.mu.Lock() - defer rb.mu.Unlock() - rb.pos = 0 - rb.buf = nil -} diff --git a/util/ringlog/ringlog.go b/util/ringlog/ringlog.go new file mode 100644 index 000000000..62dfbae5b --- /dev/null +++ b/util/ringlog/ringlog.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ringlog contains a limited-size concurrency-safe generic ring log. +package ringlog + +import "tailscale.com/syncs" + +// New creates a new [RingLog] containing at most max items. +func New[T any](max int) *RingLog[T] { + return &RingLog[T]{ + max: max, + } +} + +// RingLog is a concurrency-safe fixed size log window containing entries of [T]. +type RingLog[T any] struct { + mu syncs.Mutex + pos int + buf []T + max int +} + +// Add appends a new item to the [RingLog], possibly overwriting the oldest +// item in the log if it is already full. +// +// It does nothing if rb is nil. +func (rb *RingLog[T]) Add(t T) { + if rb == nil { + return + } + rb.mu.Lock() + defer rb.mu.Unlock() + if len(rb.buf) < rb.max { + rb.buf = append(rb.buf, t) + } else { + rb.buf[rb.pos] = t + rb.pos = (rb.pos + 1) % rb.max + } +} + +// GetAll returns a copy of all the entries in the ring log in the order they +// were added. +// +// It returns nil if rb is nil. +func (rb *RingLog[T]) GetAll() []T { + if rb == nil { + return nil + } + rb.mu.Lock() + defer rb.mu.Unlock() + out := make([]T, len(rb.buf)) + for i := range len(rb.buf) { + x := (rb.pos + i) % rb.max + out[i] = rb.buf[x] + } + return out +} + +// Len returns the number of elements in the ring log. Note that this value +// could change immediately after being returned if a concurrent caller +// modifies the log. +func (rb *RingLog[T]) Len() int { + if rb == nil { + return 0 + } + rb.mu.Lock() + defer rb.mu.Unlock() + return len(rb.buf) +} + +// Clear will empty the ring log. +func (rb *RingLog[T]) Clear() { + rb.mu.Lock() + defer rb.mu.Unlock() + rb.pos = 0 + rb.buf = nil +} diff --git a/util/ringbuffer/ringbuffer_test.go b/util/ringlog/ringlog_test.go similarity index 95% rename from util/ringbuffer/ringbuffer_test.go rename to util/ringlog/ringlog_test.go index e10096bfb..d6776e181 100644 --- a/util/ringbuffer/ringbuffer_test.go +++ b/util/ringlog/ringlog_test.go @@ -1,14 +1,14 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package ringbuffer +package ringlog import ( "reflect" "testing" ) -func TestRingBuffer(t *testing.T) { +func TestRingLog(t *testing.T) { const numItems = 10 rb := New[int](numItems) diff --git a/util/safediff/diff.go b/util/safediff/diff.go new file mode 100644 index 000000000..cf8add94b --- /dev/null +++ b/util/safediff/diff.go @@ -0,0 +1,280 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package safediff computes the difference between two lists. +// +// It is guaranteed to run in O(n), but may not produce an optimal diff. +// Most diffing algorithms produce optimal diffs but run in O(n²). +// It is safe to pass in untrusted input. +package safediff + +import ( + "bytes" + "fmt" + "math" + "strings" + "unicode" + + "github.com/google/go-cmp/cmp" +) + +var diffTest = false + +// Lines constructs a humanly readable line-by-line diff from x to y. +// The output (if multiple lines) is guaranteed to be no larger than maxSize, +// by truncating the output if necessary. A negative maxSize enforces no limit. +// +// Example diff: +// +// â€Ļ 440 identical lines +// "ssh": [ +// â€Ļ 35 identical lines +// { +// - "src": ["maisem@tailscale.com"], +// - "dst": ["tag:maisem-test"], +// - "users": ["maisem", "root"], +// - "action": "check", +// - // "recorder": ["100.12.34.56:80"], +// + "src": ["maisem@tailscale.com"], +// + "dst": ["tag:maisem-test"], +// + "users": ["maisem", "root"], +// + "action": "check", +// + "recorder": ["node:recorder-2"], +// }, +// â€Ļ 77 identical lines +// ], +// â€Ļ 345 identical lines +// +// Meaning of each line prefix: +// +// - 'â€Ļ' precedes a summary statement +// - ' ' precedes an identical line printed for context +// - '-' precedes a line removed from x +// - '+' precedes a line inserted from y +// +// The diffing algorithm runs in O(n) and is safe to use with untrusted inputs. +func Lines(x, y string, maxSize int) (out string, truncated bool) { + // Convert x and y into a slice of lines and compute the edit-script. + xs := strings.Split(x, "\n") + ys := strings.Split(y, "\n") + es := diffStrings(xs, ys) + + // Modify the edit-script to support printing identical lines of context. + const identicalContext edit = '*' // special edit code to indicate printed line + var xi, yi int // index into xs or ys + isIdentical := func(e edit) bool { return e == identical || e == identicalContext } + indentOf := func(s string) string { return s[:len(s)-len(strings.TrimLeftFunc(s, unicode.IsSpace))] } + for i, e := range es { + if isIdentical(e) { + // Print current line if adjacent symbols are non-identical. + switch { + case i-1 >= 0 && !isIdentical(es[i-1]): + es[i] = identicalContext + case i+1 < len(es) && !isIdentical(es[i+1]): + es[i] = identicalContext + } + } else { + // Print any preceding or succeeding lines, + // where the leading indent is a prefix of the current indent. + // Indentation often indicates a parent-child relationship + // in structured source code. + addParents := func(ss []string, si, direction int) { + childIndent := indentOf(ss[si]) + for j := direction; i+j >= 0 && i+j < len(es) && isIdentical(es[i+j]); j += direction { + parentIndent := indentOf(ss[si+j]) + if strings.HasPrefix(childIndent, parentIndent) && len(parentIndent) < len(childIndent) && parentIndent != "" { + es[i+j] = identicalContext + childIndent = parentIndent + } + } + } + switch e { + case removed, modified: // arbitrarily use the x value for modified values + addParents(xs, xi, -1) + addParents(xs, xi, +1) + case inserted: + addParents(ys, yi, -1) + addParents(ys, yi, +1) + } + } + if e != inserted { + xi++ + } + if e != removed { + yi++ + } + } + + // Show the line for a single hidden identical line, + // since it occupies the same vertical height. + for i, e := range es { + if e == identical { + prevNotIdentical := i-1 < 0 || es[i-1] != identical + nextNotIdentical := i+1 >= len(es) || es[i+1] != identical + if prevNotIdentical && nextNotIdentical { + es[i] = identicalContext + } + } + } + + // Adjust the maxSize, reserving space for the final summary. + if maxSize < 0 { + maxSize = math.MaxInt + } + maxSize -= len(stats{len(xs) + len(ys), len(xs), len(ys)}.appendText(nil)) + + // mayAppendLine appends a line if it does not exceed maxSize. + // Otherwise, it just updates prevStats. + var buf []byte + var prevStats stats + mayAppendLine := func(edit edit, line string) { + // Append the stats (if non-zero) and the line text. + // The stats reports the number of preceding identical lines. + if !truncated { + bufLen := len(buf) // original length (in case we exceed maxSize) + if !prevStats.isZero() { + buf = prevStats.appendText(buf) + prevStats = stats{} // just printed, so clear the stats + } + buf = fmt.Appendf(buf, "%c %s\n", edit, line) + truncated = len(buf) > maxSize + if !truncated { + return + } + buf = buf[:bufLen] // restore original buffer contents + } + + // Output is truncated, so just update the statistics. + switch edit { + case identical: + prevStats.numIdentical++ + case removed: + prevStats.numRemoved++ + case inserted: + prevStats.numInserted++ + } + } + + // Process the entire edit script. + for len(es) > 0 { + num := len(es) - len(bytes.TrimLeft(es, string(es[:1]))) + switch es[0] { + case identical: + prevStats.numIdentical += num + xs, ys = xs[num:], ys[num:] + case identicalContext: + for n := len(xs) - num; len(xs) > n; xs, ys = xs[1:], ys[1:] { + mayAppendLine(identical, xs[0]) // implies xs[0] == ys[0] + } + case modified: + for n := len(xs) - num; len(xs) > n; xs = xs[1:] { + mayAppendLine(removed, xs[0]) + } + for n := len(ys) - num; len(ys) > n; ys = ys[1:] { + mayAppendLine(inserted, ys[0]) + } + case removed: + for n := len(xs) - num; len(xs) > n; xs = xs[1:] { + mayAppendLine(removed, xs[0]) + } + case inserted: + for n := len(ys) - num; len(ys) > n; ys = ys[1:] { + mayAppendLine(inserted, ys[0]) + } + } + es = es[num:] + } + if len(xs)+len(ys)+len(es) > 0 { + panic("BUG: slices not fully consumed") + } + + if !prevStats.isZero() { + buf = prevStats.appendText(buf) // may exceed maxSize + } + return string(buf), truncated +} + +type stats struct{ numIdentical, numRemoved, numInserted int } + +func (s stats) isZero() bool { return s.numIdentical+s.numRemoved+s.numInserted == 0 } + +func (s stats) appendText(b []byte) []byte { + switch { + case s.numIdentical > 0 && s.numRemoved > 0 && s.numInserted > 0: + return fmt.Appendf(b, "â€Ļ %d identical, %d removed, and %d inserted lines\n", s.numIdentical, s.numRemoved, s.numInserted) + case s.numIdentical > 0 && s.numRemoved > 0: + return fmt.Appendf(b, "â€Ļ %d identical and %d removed lines\n", s.numIdentical, s.numRemoved) + case s.numIdentical > 0 && s.numInserted > 0: + return fmt.Appendf(b, "â€Ļ %d identical and %d inserted lines\n", s.numIdentical, s.numInserted) + case s.numRemoved > 0 && s.numInserted > 0: + return fmt.Appendf(b, "â€Ļ %d removed and %d inserted lines\n", s.numRemoved, s.numInserted) + case s.numIdentical > 0: + return fmt.Appendf(b, "â€Ļ %d identical lines\n", s.numIdentical) + case s.numRemoved > 0: + return fmt.Appendf(b, "â€Ļ %d removed lines\n", s.numRemoved) + case s.numInserted > 0: + return fmt.Appendf(b, "â€Ļ %d inserted lines\n", s.numInserted) + default: + return fmt.Appendf(b, "â€Ļ\n") + } +} + +// diffStrings computes an edit-script of two slices of strings. +// +// This calls cmp.Equal to access the "github.com/go-cmp/cmp/internal/diff" +// implementation, which has an O(N) diffing algorithm. It is not guaranteed +// to produce an optimal edit-script, but protects our runtime against +// adversarial inputs that would wreck the optimal O(N²) algorithm used by +// most diffing packages available in open-source. +// +// TODO(https://go.dev/issue/58893): Use "golang.org/x/tools/diff" instead? +func diffStrings(xs, ys []string) []edit { + d := new(diffRecorder) + cmp.Equal(xs, ys, cmp.Reporter(d)) + if diffTest { + numRemoved := bytes.Count(d.script, []byte{removed}) + numInserted := bytes.Count(d.script, []byte{inserted}) + if len(xs) != len(d.script)-numInserted || len(ys) != len(d.script)-numRemoved { + panic("BUG: edit-script is inconsistent") + } + } + return d.script +} + +type edit = byte + +const ( + identical edit = ' ' // equal symbol in both x and y + modified edit = '~' // modified symbol in both x and y + removed edit = '-' // removed symbol from x + inserted edit = '+' // inserted symbol from y +) + +// diffRecorder reproduces an edit-script, essentially recording +// the edit-script from "github.com/google/go-cmp/cmp/internal/diff". +// This implements the cmp.Reporter interface. +type diffRecorder struct { + last cmp.PathStep + script []edit +} + +func (d *diffRecorder) PushStep(ps cmp.PathStep) { d.last = ps } + +func (d *diffRecorder) Report(rs cmp.Result) { + if si, ok := d.last.(cmp.SliceIndex); ok { + if rs.Equal() { + d.script = append(d.script, identical) + } else { + switch xi, yi := si.SplitKeys(); { + case xi >= 0 && yi >= 0: + d.script = append(d.script, modified) + case xi >= 0: + d.script = append(d.script, removed) + case yi >= 0: + d.script = append(d.script, inserted) + } + } + } +} + +func (d *diffRecorder) PopStep() { d.last = nil } diff --git a/util/safediff/diff_test.go b/util/safediff/diff_test.go new file mode 100644 index 000000000..e580bd922 --- /dev/null +++ b/util/safediff/diff_test.go @@ -0,0 +1,196 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safediff + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func init() { diffTest = true } + +func TestLines(t *testing.T) { + // The diffs shown below technically depend on the stability of cmp, + // but that should be fine for sufficiently simple diffs like these. + // If the output does change, that would suggest a significant regression + // in the optimality of cmp's diffing algorithm. + + x := `{ + "firstName": "John", + "lastName": "Smith", + "isAlive": true, + "age": 27, + "address": { + "streetAddress": "21 2nd Street", + "city": "New York", + "state": "NY", + "postalCode": "10021-3100" + }, + "phoneNumbers": [{ + "type": "home", + "number": "212 555-1234" + }, { + "type": "office", + "number": "646 555-4567" + }], + "children": [ + "Catherine", + "Thomas", + "Trevor" + ], + "spouse": null +}` + y := x + y = strings.ReplaceAll(y, `"New York"`, `"Los Angeles"`) + y = strings.ReplaceAll(y, `"NY"`, `"CA"`) + y = strings.ReplaceAll(y, `"646 555-4567"`, `"315 252-8888"`) + + wantDiff := ` +â€Ļ 5 identical lines + "address": { + "streetAddress": "21 2nd Street", +- "city": "New York", +- "state": "NY", ++ "city": "Los Angeles", ++ "state": "CA", + "postalCode": "10021-3100" + }, +â€Ļ 3 identical lines + }, { + "type": "office", +- "number": "646 555-4567" ++ "number": "315 252-8888" + }], +â€Ļ 7 identical lines +`[1:] + gotDiff, gotTrunc := Lines(x, y, -1) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == true { + t.Errorf("Lines: output unexpectedly truncated") + } + + wantDiff = ` +â€Ļ 5 identical lines + "address": { + "streetAddress": "21 2nd Street", +- "city": "New York", +- "state": "NY", ++ "city": "Los Angeles", +â€Ļ 15 identical, 1 removed, and 2 inserted lines +`[1:] + gotDiff, gotTrunc = Lines(x, y, 200) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == false { + t.Errorf("Lines: output unexpectedly not truncated") + } + + wantDiff = "â€Ļ 17 identical, 3 removed, and 3 inserted lines\n" + gotDiff, gotTrunc = Lines(x, y, 0) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == false { + t.Errorf("Lines: output unexpectedly not truncated") + } + + x = `{ + "unrelated": [ + "unrelated", + ], + "related": { + "unrelated": [ + "unrelated", + ], + "related": { + "unrelated": [ + "unrelated", + ], + "related": { + "related": "changed", + }, + "unrelated": [ + "unrelated", + ], + }, + "unrelated": [ + "unrelated", + ], + }, + "unrelated": [ + "unrelated", + ], +}` + y = strings.ReplaceAll(x, "changed", "CHANGED") + + wantDiff = ` +â€Ļ 4 identical lines + "related": { +â€Ļ 3 identical lines + "related": { +â€Ļ 3 identical lines + "related": { +- "related": "changed", ++ "related": "CHANGED", + }, +â€Ļ 3 identical lines + }, +â€Ļ 3 identical lines + }, +â€Ļ 4 identical lines +`[1:] + gotDiff, gotTrunc = Lines(x, y, -1) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == true { + t.Errorf("Lines: output unexpectedly truncated") + } + + x = `{ + "ACLs": [ + { + "Action": "accept", + "Users": ["group:all"], + "Ports": ["tag:tmemes:80"], + }, + ], +}` + y = strings.ReplaceAll(x, "tag:tmemes:80", "tag:tmemes:80,8383") + wantDiff = ` + { + "ACLs": [ + { + "Action": "accept", + "Users": ["group:all"], +- "Ports": ["tag:tmemes:80"], ++ "Ports": ["tag:tmemes:80,8383"], + }, + ], + } +`[1:] + gotDiff, gotTrunc = Lines(x, y, -1) + if d := cmp.Diff(gotDiff, wantDiff); d != "" { + t.Errorf("Lines mismatch (-got +want):\n%s\ngot:\n%s\nwant:\n%s", d, gotDiff, wantDiff) + } else if gotTrunc == true { + t.Errorf("Lines: output unexpectedly truncated") + } +} + +func FuzzDiff(f *testing.F) { + f.Fuzz(func(t *testing.T, x, y string, maxSize int) { + const maxInput = 1e3 + if len(x) > maxInput { + x = x[:maxInput] + } + if len(y) > maxInput { + y = y[:maxInput] + } + diff, _ := Lines(x, y, maxSize) // make sure this does not panic + if strings.Count(diff, "\n") > 1 && maxSize >= 0 && len(diff) > maxSize { + t.Fatal("maxSize exceeded") + } + }) +} diff --git a/util/set/handle.go b/util/set/handle.go index 471ceeba2..9c6b6dab0 100644 --- a/util/set/handle.go +++ b/util/set/handle.go @@ -9,20 +9,28 @@ package set type HandleSet[T any] map[Handle]T // Handle is an opaque comparable value that's used as the map key in a -// HandleSet. The only way to get one is to call HandleSet.Add. +// HandleSet. type Handle struct { v *byte } +// NewHandle returns a new handle value. +func NewHandle() Handle { + return Handle{new(byte)} +} + // Add adds the element (map value) e to the set. // -// It returns the handle (map key) with which e can be removed, using a map -// delete. +// It returns a new handle (map key) with which e can be removed, using a map +// delete or the [HandleSet.Delete] method. func (s *HandleSet[T]) Add(e T) Handle { - h := Handle{new(byte)} + h := NewHandle() if *s == nil { *s = make(HandleSet[T]) } (*s)[h] = e return h } + +// Delete removes the element with handle h from the set. +func (s HandleSet[T]) Delete(h Handle) { delete(s, h) } diff --git a/util/set/intset.go b/util/set/intset.go new file mode 100644 index 000000000..d32524691 --- /dev/null +++ b/util/set/intset.go @@ -0,0 +1,198 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "iter" + "maps" + "math/bits" + "math/rand/v2" + + "golang.org/x/exp/constraints" + "tailscale.com/util/mak" +) + +// IntSet is a set optimized for integer values close to zero +// or set of integers that are close in value. +type IntSet[T constraints.Integer] struct { + // bits is a [bitSet] for numbers less than [bits.UintSize]. + bits bitSet + + // extra is a mapping of [bitSet] for numbers not in bits, + // where the key is a number modulo [bits.UintSize]. + extra map[uint64]bitSet + + // extraLen is the count of numbers in extra since len(extra) + // does not reflect that each bitSet may have multiple numbers. + extraLen int +} + +// IntsOf constructs an [IntSet] with the provided elements. +func IntsOf[T constraints.Integer](slice ...T) IntSet[T] { + var s IntSet[T] + for _, e := range slice { + s.Add(e) + } + return s +} + +// Values returns an iterator over the elements of the set. +// The iterator will yield the elements in no particular order. +func (s IntSet[T]) Values() iter.Seq[T] { + return func(yield func(T) bool) { + if s.bits != 0 { + for i := range s.bits.values() { + if !yield(decodeZigZag[T](i)) { + return + } + } + } + if s.extra != nil { + for hi, bs := range s.extra { + for lo := range bs.values() { + if !yield(decodeZigZag[T](hi*bits.UintSize + lo)) { + return + } + } + } + } + } +} + +// Contains reports whether e is in the set. +func (s IntSet[T]) Contains(e T) bool { + if v := encodeZigZag(e); v < bits.UintSize { + return s.bits.contains(v) + } else { + hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize) + return s.extra[hi].contains(lo) + } +} + +// Add adds e to the set. +// +// When storing a IntSet in a map as a value type, +// it is important to re-assign the map entry after calling Add or Delete, +// as the IntSet's representation may change. +func (s *IntSet[T]) Add(e T) { + if v := encodeZigZag(e); v < bits.UintSize { + s.bits.add(v) + } else { + hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize) + if bs := s.extra[hi]; !bs.contains(lo) { + bs.add(lo) + mak.Set(&s.extra, hi, bs) + s.extra[hi] = bs + s.extraLen++ + } + } +} + +// AddSeq adds the values from seq to the set. +func (s *IntSet[T]) AddSeq(seq iter.Seq[T]) { + for e := range seq { + s.Add(e) + } +} + +// Len reports the number of elements in the set. +func (s IntSet[T]) Len() int { + return s.bits.len() + s.extraLen +} + +// Delete removes e from the set. +// +// When storing a IntSet in a map as a value type, +// it is important to re-assign the map entry after calling Add or Delete, +// as the IntSet's representation may change. +func (s *IntSet[T]) Delete(e T) { + if v := encodeZigZag(e); v < bits.UintSize { + s.bits.delete(v) + } else { + hi, lo := v/uint64(bits.UintSize), v%uint64(bits.UintSize) + if bs := s.extra[hi]; bs.contains(lo) { + bs.delete(lo) + mak.Set(&s.extra, hi, bs) + s.extra[hi] = bs + s.extraLen-- + } + } +} + +// DeleteSeq deletes the values in seq from the set. +func (s *IntSet[T]) DeleteSeq(seq iter.Seq[T]) { + for e := range seq { + s.Delete(e) + } +} + +// Equal reports whether s is equal to other. +func (s IntSet[T]) Equal(other IntSet[T]) bool { + for hi, bits := range s.extra { + if other.extra[hi] != bits { + return false + } + } + return s.extraLen == other.extraLen && s.bits == other.bits +} + +// Clone returns a copy of s that doesn't alias the original. +func (s IntSet[T]) Clone() IntSet[T] { + return IntSet[T]{ + bits: s.bits, + extra: maps.Clone(s.extra), + extraLen: s.extraLen, + } +} + +type bitSet uint + +func (s bitSet) values() iter.Seq[uint64] { + return func(yield func(uint64) bool) { + // Hyrum-proofing: randomly iterate in forwards or reverse. + if rand.Uint64()%2 == 0 { + for i := 0; i < bits.UintSize; i++ { + if s.contains(uint64(i)) && !yield(uint64(i)) { + return + } + } + } else { + for i := bits.UintSize; i >= 0; i-- { + if s.contains(uint64(i)) && !yield(uint64(i)) { + return + } + } + } + } +} +func (s bitSet) len() int { return bits.OnesCount(uint(s)) } +func (s bitSet) contains(i uint64) bool { return s&(1< 0 } +func (s *bitSet) add(i uint64) { *s |= 1 << i } +func (s *bitSet) delete(i uint64) { *s &= ^(1 << i) } + +// encodeZigZag encodes an integer as an unsigned integer ensuring that +// negative integers near zero still have a near zero positive value. +// For unsigned integers, it returns the value verbatim. +func encodeZigZag[T constraints.Integer](v T) uint64 { + var zero T + if ^zero >= 0 { // must be constraints.Unsigned + return uint64(v) + } else { // must be constraints.Signed + // See [google.golang.org/protobuf/encoding/protowire.EncodeZigZag] + return uint64(int64(v)<<1) ^ uint64(int64(v)>>63) + } +} + +// decodeZigZag decodes an unsigned integer as an integer ensuring that +// negative integers near zero still have a near zero positive value. +// For unsigned integers, it returns the value verbatim. +func decodeZigZag[T constraints.Integer](v uint64) T { + var zero T + if ^zero >= 0 { // must be constraints.Unsigned + return T(v) + } else { // must be constraints.Signed + // See [google.golang.org/protobuf/encoding/protowire.DecodeZigZag] + return T(int64(v>>1) ^ int64(v)<<63>>63) + } +} diff --git a/util/set/intset_test.go b/util/set/intset_test.go new file mode 100644 index 000000000..d838215c9 --- /dev/null +++ b/util/set/intset_test.go @@ -0,0 +1,180 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "maps" + "math" + "slices" + "testing" + + "golang.org/x/exp/constraints" +) + +func TestIntSet(t *testing.T) { + t.Run("Int64", func(t *testing.T) { + ss := make(Set[int64]) + var si IntSet[int64] + intValues(t, ss, si) + deleteInt(t, ss, &si, -5) + deleteInt(t, ss, &si, 2) + deleteInt(t, ss, &si, 75) + intValues(t, ss, si) + addInt(t, ss, &si, 2) + addInt(t, ss, &si, 75) + addInt(t, ss, &si, 75) + addInt(t, ss, &si, -3) + addInt(t, ss, &si, -3) + addInt(t, ss, &si, -3) + addInt(t, ss, &si, math.MinInt64) + addInt(t, ss, &si, 8) + intValues(t, ss, si) + addInt(t, ss, &si, 77) + addInt(t, ss, &si, 76) + addInt(t, ss, &si, 76) + addInt(t, ss, &si, 76) + intValues(t, ss, si) + addInt(t, ss, &si, -5) + addInt(t, ss, &si, 7) + addInt(t, ss, &si, -83) + addInt(t, ss, &si, math.MaxInt64) + intValues(t, ss, si) + deleteInt(t, ss, &si, -5) + deleteInt(t, ss, &si, 2) + deleteInt(t, ss, &si, 75) + intValues(t, ss, si) + deleteInt(t, ss, &si, math.MinInt64) + deleteInt(t, ss, &si, math.MaxInt64) + intValues(t, ss, si) + if !si.Equal(IntsOf(ss.Slice()...)) { + t.Errorf("{%v}.Equal({%v}) = false, want true", si, ss) + } + }) + + t.Run("Uint64", func(t *testing.T) { + ss := make(Set[uint64]) + var si IntSet[uint64] + intValues(t, ss, si) + deleteInt(t, ss, &si, 5) + deleteInt(t, ss, &si, 2) + deleteInt(t, ss, &si, 75) + intValues(t, ss, si) + addInt(t, ss, &si, 2) + addInt(t, ss, &si, 75) + addInt(t, ss, &si, 75) + addInt(t, ss, &si, 3) + addInt(t, ss, &si, 3) + addInt(t, ss, &si, 8) + intValues(t, ss, si) + addInt(t, ss, &si, 77) + addInt(t, ss, &si, 76) + addInt(t, ss, &si, 76) + addInt(t, ss, &si, 76) + intValues(t, ss, si) + addInt(t, ss, &si, 5) + addInt(t, ss, &si, 7) + addInt(t, ss, &si, 83) + addInt(t, ss, &si, math.MaxInt64) + intValues(t, ss, si) + deleteInt(t, ss, &si, 5) + deleteInt(t, ss, &si, 2) + deleteInt(t, ss, &si, 75) + intValues(t, ss, si) + deleteInt(t, ss, &si, math.MaxInt64) + intValues(t, ss, si) + if !si.Equal(IntsOf(ss.Slice()...)) { + t.Errorf("{%v}.Equal({%v}) = false, want true", si, ss) + } + }) +} + +func intValues[T constraints.Integer](t testing.TB, ss Set[T], si IntSet[T]) { + got := slices.Collect(maps.Keys(ss)) + slices.Sort(got) + want := slices.Collect(si.Values()) + slices.Sort(want) + if !slices.Equal(got, want) { + t.Fatalf("Values mismatch:\n\tgot %v\n\twant %v", got, want) + } + if got, want := si.Len(), ss.Len(); got != want { + t.Fatalf("Len() = %v, want %v", got, want) + } +} + +func addInt[T constraints.Integer](t testing.TB, ss Set[T], si *IntSet[T], v T) { + t.Helper() + if got, want := si.Contains(v), ss.Contains(v); got != want { + t.Fatalf("Contains(%v) = %v, want %v", v, got, want) + } + ss.Add(v) + si.Add(v) + if !si.Contains(v) { + t.Fatalf("Contains(%v) = false, want true", v) + } + if got, want := si.Len(), ss.Len(); got != want { + t.Fatalf("Len() = %v, want %v", got, want) + } +} + +func deleteInt[T constraints.Integer](t testing.TB, ss Set[T], si *IntSet[T], v T) { + t.Helper() + if got, want := si.Contains(v), ss.Contains(v); got != want { + t.Fatalf("Contains(%v) = %v, want %v", v, got, want) + } + ss.Delete(v) + si.Delete(v) + if si.Contains(v) { + t.Fatalf("Contains(%v) = true, want false", v) + } + if got, want := si.Len(), ss.Len(); got != want { + t.Fatalf("Len() = %v, want %v", got, want) + } +} + +func TestZigZag(t *testing.T) { + t.Run("Int64", func(t *testing.T) { + for _, tt := range []struct { + decoded int64 + encoded uint64 + }{ + {math.MinInt64, math.MaxUint64}, + {-2, 3}, + {-1, 1}, + {0, 0}, + {1, 2}, + {2, 4}, + {math.MaxInt64, math.MaxUint64 - 1}, + } { + encoded := encodeZigZag(tt.decoded) + if encoded != tt.encoded { + t.Errorf("encodeZigZag(%v) = %v, want %v", tt.decoded, encoded, tt.encoded) + } + decoded := decodeZigZag[int64](tt.encoded) + if decoded != tt.decoded { + t.Errorf("decodeZigZag(%v) = %v, want %v", tt.encoded, decoded, tt.decoded) + } + } + }) + t.Run("Uint64", func(t *testing.T) { + for _, tt := range []struct { + decoded uint64 + encoded uint64 + }{ + {0, 0}, + {1, 1}, + {2, 2}, + {math.MaxInt64, math.MaxInt64}, + {math.MaxUint64, math.MaxUint64}, + } { + encoded := encodeZigZag(tt.decoded) + if encoded != tt.encoded { + t.Errorf("encodeZigZag(%v) = %v, want %v", tt.decoded, encoded, tt.encoded) + } + decoded := decodeZigZag[uint64](tt.encoded) + if decoded != tt.decoded { + t.Errorf("decodeZigZag(%v) = %v, want %v", tt.encoded, decoded, tt.decoded) + } + } + }) +} diff --git a/util/set/slice.go b/util/set/slice.go index 38551aee1..2fc65b82d 100644 --- a/util/set/slice.go +++ b/util/set/slice.go @@ -67,7 +67,7 @@ func (ss *Slice[T]) Add(vs ...T) { // AddSlice adds all elements in vs to the set. func (ss *Slice[T]) AddSlice(vs views.Slice[T]) { - for i := range vs.Len() { - ss.Add(vs.At(i)) + for _, v := range vs.All() { + ss.Add(v) } } diff --git a/util/set/smallset.go b/util/set/smallset.go new file mode 100644 index 000000000..1b77419d2 --- /dev/null +++ b/util/set/smallset.go @@ -0,0 +1,148 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "iter" + "maps" + + "tailscale.com/types/structs" +) + +// SmallSet is a set that is optimized for reducing memory overhead when the +// expected size of the set is 0 or 1 elements. +// +// The zero value of SmallSet is a usable empty set. +// +// When storing a SmallSet in a map as a value type, it is important to re-assign +// the map entry after calling Add or Delete, as the SmallSet's representation +// may change. +// +// Copying a SmallSet by value may alias the previous value. Use the Clone method +// to create a new SmallSet with the same contents. +type SmallSet[T comparable] struct { + _ structs.Incomparable // to prevent == mistakes + one T // if non-zero, then single item in set + m Set[T] // if non-nil, the set of items, which might be size 1 if it's the zero value of T +} + +// Values returns an iterator over the elements of the set. +// The iterator will yield the elements in no particular order. +func (s SmallSet[T]) Values() iter.Seq[T] { + if s.m != nil { + return maps.Keys(s.m) + } + var zero T + return func(yield func(T) bool) { + if s.one != zero { + yield(s.one) + } + } +} + +// Contains reports whether e is in the set. +func (s SmallSet[T]) Contains(e T) bool { + if s.m != nil { + return s.m.Contains(e) + } + var zero T + return e != zero && s.one == e +} + +// SoleElement returns the single value in the set, if the set has exactly one +// element. +// +// If the set is empty or has more than one element, ok will be false and e will +// be the zero value of T. +func (s SmallSet[T]) SoleElement() (e T, ok bool) { + return s.one, s.Len() == 1 +} + +// Add adds e to the set. +// +// When storing a SmallSet in a map as a value type, it is important to +// re-assign the map entry after calling Add or Delete, as the SmallSet's +// representation may change. +func (s *SmallSet[T]) Add(e T) { + var zero T + if s.m != nil { + s.m.Add(e) + return + } + // Non-zero elements can go into s.one. + if e != zero { + if s.one == zero { + s.one = e // Len 0 to Len 1 + return + } + if s.one == e { + return // dup + } + } + // Need to make a multi map, either + // because we now have two items, or + // because e is the zero value. + s.m = Set[T]{} + if s.one != zero { + s.m.Add(s.one) // move single item to multi + } + s.m.Add(e) // add new item, possibly zero + s.one = zero +} + +// Len reports the number of elements in the set. +func (s SmallSet[T]) Len() int { + var zero T + if s.m != nil { + return s.m.Len() + } + if s.one != zero { + return 1 + } + return 0 +} + +// Delete removes e from the set. +// +// When storing a SmallSet in a map as a value type, it is important to +// re-assign the map entry after calling Add or Delete, as the SmallSet's +// representation may change. +func (s *SmallSet[T]) Delete(e T) { + var zero T + if s.m == nil { + if s.one == e { + s.one = zero + } + return + } + s.m.Delete(e) + + // If the map size drops to zero, that means + // it only contained the zero value of T. + if s.m.Len() == 0 { + s.m = nil + return + } + + // If the map size drops to one element and doesn't + // contain the zero value, we can switch back to the + // single-item representation. + if s.m.Len() == 1 { + for v := range s.m { + if v != zero { + s.one = v + s.m = nil + } + } + } + return +} + +// Clone returns a copy of s that doesn't alias the original. +func (s SmallSet[T]) Clone() SmallSet[T] { + return SmallSet[T]{ + one: s.one, + m: maps.Clone(s.m), // preserves nilness + } +} diff --git a/util/set/smallset_test.go b/util/set/smallset_test.go new file mode 100644 index 000000000..d6f446df0 --- /dev/null +++ b/util/set/smallset_test.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package set + +import ( + "fmt" + "iter" + "maps" + "reflect" + "slices" + "testing" +) + +func TestSmallSet(t *testing.T) { + t.Parallel() + + wantSize := reflect.TypeFor[int64]().Size() + reflect.TypeFor[map[int]struct{}]().Size() + if wantSize > 16 { + t.Errorf("wantSize should be no more than 16") // it might be smaller on 32-bit systems + } + if size := reflect.TypeFor[SmallSet[int64]]().Size(); size != wantSize { + t.Errorf("SmallSet[int64] size is %d, want %v", size, wantSize) + } + + type op struct { + add bool + v int + } + ops := iter.Seq[op](func(yield func(op) bool) { + for _, add := range []bool{false, true} { + for v := range 4 { + if !yield(op{add: add, v: v}) { + return + } + } + } + }) + type setLike interface { + Add(int) + Delete(int) + } + apply := func(s setLike, o op) { + if o.add { + s.Add(o.v) + } else { + s.Delete(o.v) + } + } + + // For all combinations of 4 operations, + // apply them to both a regular map and SmallSet + // and make sure all the invariants hold. + + for op1 := range ops { + for op2 := range ops { + for op3 := range ops { + for op4 := range ops { + + normal := Set[int]{} + small := &SmallSet[int]{} + for _, op := range []op{op1, op2, op3, op4} { + apply(normal, op) + apply(small, op) + } + + name := func() string { + return fmt.Sprintf("op1=%v, op2=%v, op3=%v, op4=%v", op1, op2, op3, op4) + } + if normal.Len() != small.Len() { + t.Errorf("len mismatch after ops %s: normal=%d, small=%d", name(), normal.Len(), small.Len()) + } + if got := small.Clone().Len(); normal.Len() != got { + t.Errorf("len mismatch after ops %s: normal=%d, clone=%d", name(), normal.Len(), got) + } + + normalEle := slices.Sorted(maps.Keys(normal)) + smallEle := slices.Sorted(small.Values()) + if !slices.Equal(normalEle, smallEle) { + t.Errorf("elements mismatch after ops %s: normal=%v, small=%v", name(), normalEle, smallEle) + } + for e := range 5 { + if normal.Contains(e) != small.Contains(e) { + t.Errorf("contains(%v) mismatch after ops %s: normal=%v, small=%v", e, name(), normal.Contains(e), small.Contains(e)) + } + } + + if err := small.checkInvariants(); err != nil { + t.Errorf("checkInvariants failed after ops %s: %v", name(), err) + } + + if !t.Failed() { + sole, ok := small.SoleElement() + if ok != (small.Len() == 1) { + t.Errorf("SoleElement ok mismatch after ops %s: SoleElement ok=%v, want=%v", name(), ok, !ok) + } + if ok && sole != smallEle[0] { + t.Errorf("SoleElement value mismatch after ops %s: SoleElement=%v, want=%v", name(), sole, smallEle[0]) + t.Errorf("Internals: %+v", small) + } + } + } + } + } + } +} + +func (s *SmallSet[T]) checkInvariants() error { + var zero T + if s.m != nil && s.one != zero { + return fmt.Errorf("both m and one are non-zero") + } + if s.m != nil { + switch len(s.m) { + case 0: + return fmt.Errorf("m is non-nil but empty") + case 1: + for k := range s.m { + if k != zero { + return fmt.Errorf("m contains exactly 1 non-zero element, %v", k) + } + } + } + } + return nil +} diff --git a/util/slicesx/slicesx.go b/util/slicesx/slicesx.go index e0b820eb7..ff9d47375 100644 --- a/util/slicesx/slicesx.go +++ b/util/slicesx/slicesx.go @@ -95,6 +95,17 @@ func Filter[S ~[]T, T any](dst, src S, fn func(T) bool) S { return dst } +// AppendNonzero appends all non-zero elements of src to dst. +func AppendNonzero[S ~[]T, T comparable](dst, src S) S { + var zero T + for _, v := range src { + if v != zero { + dst = append(dst, v) + } + } + return dst +} + // AppendMatching appends elements in ps to dst if f(x) is true. func AppendMatching[T any](dst, ps []T, f func(T) bool) []T { for _, p := range ps { @@ -148,3 +159,43 @@ func FirstEqual[T comparable](s []T, v T) bool { func LastEqual[T comparable](s []T, v T) bool { return len(s) > 0 && s[len(s)-1] == v } + +// MapKeys returns the values of the map m. +// +// The keys will be in an indeterminate order. +// +// It's equivalent to golang.org/x/exp/maps.Keys, which +// unfortunately has the package name "maps", shadowing +// the std "maps" package. This version exists for clarity +// when reading call sites. +// +// As opposed to slices.Collect(maps.Keys(m)), this allocates +// the returned slice once to exactly the right size, rather than +// appending larger backing arrays as it goes. +func MapKeys[M ~map[K]V, K comparable, V any](m M) []K { + r := make([]K, 0, len(m)) + for k := range m { + r = append(r, k) + } + return r +} + +// MapValues returns the values of the map m. +// +// The values will be in an indeterminate order. +// +// It's equivalent to golang.org/x/exp/maps.Values, which +// unfortunately has the package name "maps", shadowing +// the std "maps" package. This version exists for clarity +// when reading call sites. +// +// As opposed to slices.Collect(maps.Values(m)), this allocates +// the returned slice once to exactly the right size, rather than +// appending larger backing arrays as it goes. +func MapValues[M ~map[K]V, K comparable, V any](m M) []V { + r := make([]V, 0, len(m)) + for _, v := range m { + r = append(r, v) + } + return r +} diff --git a/util/slicesx/slicesx_test.go b/util/slicesx/slicesx_test.go index 597b22b83..346449284 100644 --- a/util/slicesx/slicesx_test.go +++ b/util/slicesx/slicesx_test.go @@ -137,6 +137,19 @@ func TestFilterNoAllocations(t *testing.T) { } } +func TestAppendNonzero(t *testing.T) { + v := []string{"one", "two", "", "four"} + got := AppendNonzero(nil, v) + want := []string{"one", "two", "four"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + got = AppendNonzero(v[:0], v) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + func TestAppendMatching(t *testing.T) { v := []string{"one", "two", "three", "four"} got := AppendMatching(v[:0], v, func(s string) bool { return len(s) > 3 }) diff --git a/util/stringsx/stringsx.go b/util/stringsx/stringsx.go new file mode 100644 index 000000000..6c7a8d20d --- /dev/null +++ b/util/stringsx/stringsx.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package stringsx provides additional string manipulation functions +// that aren't in the standard library's strings package or go4.org/mem. +package stringsx + +import ( + "unicode" + "unicode/utf8" +) + +// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b, +// like cmp.Compare, but case insensitively. +func CompareFold(a, b string) int { + // Track our position in both strings + ia, ib := 0, 0 + for ia < len(a) && ib < len(b) { + ra, wa := nextRuneLower(a[ia:]) + rb, wb := nextRuneLower(b[ib:]) + if ra < rb { + return -1 + } + if ra > rb { + return 1 + } + ia += wa + ib += wb + if wa == 0 || wb == 0 { + break + } + } + + // If we've reached here, one or both strings are exhausted + // The shorter string is "less than" if they match up to this point + switch { + case ia == len(a) && ib == len(b): + return 0 + case ia == len(a): + return -1 + default: + return 1 + } +} + +// nextRuneLower returns the next rune in the string, lowercased, along with its +// original (consumed) width in bytes. If the string is empty, it returns +// (utf8.RuneError, 0) +func nextRuneLower(s string) (r rune, width int) { + r, width = utf8.DecodeRuneInString(s) + return unicode.ToLower(r), width +} diff --git a/util/stringsx/stringsx_test.go b/util/stringsx/stringsx_test.go new file mode 100644 index 000000000..8575c0b27 --- /dev/null +++ b/util/stringsx/stringsx_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package stringsx + +import ( + "cmp" + "strings" + "testing" +) + +func TestCompareFold(t *testing.T) { + tests := []struct { + a, b string + }{ + // Basic ASCII cases + {"", ""}, + {"a", "a"}, + {"a", "A"}, + {"A", "a"}, + {"a", "b"}, + {"b", "a"}, + {"abc", "ABC"}, + {"ABC", "abc"}, + {"abc", "abd"}, + {"abd", "abc"}, + + // Length differences + {"abc", "ab"}, + {"ab", "abc"}, + + // Unicode cases + {"ä¸–į•Œ", "ä¸–į•Œ"}, + {"Helloä¸–į•Œ", "helloä¸–į•Œ"}, + {"ä¸–į•ŒHello", "ä¸–į•Œhello"}, + {"ä¸–į•Œ", "ä¸–į•Œx"}, + {"ä¸–į•Œx", "ä¸–į•Œ"}, + + // Special case folding examples + {"ß", "ss"}, // German sharp s + {"īŦ", "fi"}, // fi ligature + {"ÎŖ", "΃"}, // Greek sigma + {"İ", "i\u0307"}, // Turkish dotted I + + // Mixed cases + {"HelloWorld", "helloworld"}, + {"HELLOWORLD", "helloworld"}, + {"helloworld", "HELLOWORLD"}, + {"HelloWorld", "helloworld"}, + {"helloworld", "HelloWorld"}, + + // Edge cases + {" ", " "}, + {"1", "1"}, + {"123", "123"}, + {"!@#", "!@#"}, + } + + wants := []int{} + for _, tt := range tests { + got := CompareFold(tt.a, tt.b) + want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b)) + if got != want { + t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want) + } + wants = append(wants, want) + } + + if n := testing.AllocsPerRun(1000, func() { + for i, tt := range tests { + if CompareFold(tt.a, tt.b) != wants[i] { + panic("unexpected") + } + } + }); n > 0 { + t.Errorf("allocs = %v; want 0", int(n)) + } +} diff --git a/util/syspolicy/caching_handler.go b/util/syspolicy/caching_handler.go deleted file mode 100644 index 5192958bc..000000000 --- a/util/syspolicy/caching_handler.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "sync" -) - -// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested -// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached, -// otherwise the actual error is returned and the next read for that key will retry using the handler. -type CachingHandler struct { - mu sync.Mutex - strings map[string]string - uint64s map[string]uint64 - bools map[string]bool - strArrs map[string][]string - notFound map[string]bool - handler Handler -} - -// NewCachingHandler creates a CachingHandler given a handler. -func NewCachingHandler(handler Handler) *CachingHandler { - return &CachingHandler{ - handler: handler, - strings: make(map[string]string), - uint64s: make(map[string]uint64), - bools: make(map[string]bool), - strArrs: make(map[string][]string), - notFound: make(map[string]bool), - } -} - -// ReadString reads the policy settings value string given the key. -// ReadString first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadString(key string) (string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strings[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return "", ErrNoSuchKey - } - val, err := ch.handler.ReadString(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return "", err - } else if err != nil { - return "", err - } - ch.strings[key] = val - return val, nil -} - -// ReadUInt64 reads the policy settings uint64 value given the key. -// ReadUInt64 first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.uint64s[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return 0, ErrNoSuchKey - } - val, err := ch.handler.ReadUInt64(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return 0, err - } else if err != nil { - return 0, err - } - ch.uint64s[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadBoolean(key string) (bool, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.bools[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return false, ErrNoSuchKey - } - val, err := ch.handler.ReadBoolean(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return false, err - } else if err != nil { - return false, err - } - ch.bools[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strArrs[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return nil, ErrNoSuchKey - } - val, err := ch.handler.ReadStringArray(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return nil, err - } else if err != nil { - return nil, err - } - ch.strArrs[key] = val - return val, nil -} diff --git a/util/syspolicy/caching_handler_test.go b/util/syspolicy/caching_handler_test.go deleted file mode 100644 index 881f6ff83..000000000 --- a/util/syspolicy/caching_handler_test.go +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "testing" -) - -func TestHandlerReadString(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue string - handlerError error - preserveHandler bool - wantValue string - wantErr error - strings map[string]string - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - strings: map[string]string{"test": "foo"}, - wantValue: "foo", - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: "foo", - wantValue: "foo", - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - s: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.strings != nil { - cache.strings = tt.strings - } - got, err := cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } -} - -func TestHandlerReadUint64(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue uint64 - handlerError error - preserveHandler bool - wantValue uint64 - wantErr error - uint64s map[string]uint64 - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - uint64s: map[string]uint64{"test": 1}, - wantValue: 1, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: 1, - wantValue: 1, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - u64: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.uint64s != nil { - cache.uint64s = tt.uint64s - } - got, err := cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} - -func TestHandlerReadBool(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue bool - handlerError error - preserveHandler bool - wantValue bool - wantErr error - bools map[string]bool - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - bools: map[string]bool{"test": true}, - wantValue: true, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: true, - wantValue: true, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - b: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.bools != nil { - cache.bools = tt.bools - } - got, err := cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} diff --git a/util/syspolicy/handler.go b/util/syspolicy/handler.go deleted file mode 100644 index f1fad9770..000000000 --- a/util/syspolicy/handler.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "sync/atomic" -) - -var ( - handlerUsed atomic.Bool - handler Handler = defaultHandler{} -) - -// Handler reads system policies from OS-specific storage. -type Handler interface { - // ReadString reads the policy setting's string value for the given key. - // It should return ErrNoSuchKey if the key does not have a value set. - ReadString(key string) (string, error) - // ReadUInt64 reads the policy setting's uint64 value for the given key. - // It should return ErrNoSuchKey if the key does not have a value set. - ReadUInt64(key string) (uint64, error) - // ReadBool reads the policy setting's boolean value for the given key. - // It should return ErrNoSuchKey if the key does not have a value set. - ReadBoolean(key string) (bool, error) - // ReadStringArray reads the policy setting's string array value for the given key. - // It should return ErrNoSuchKey if the key does not have a value set. - ReadStringArray(key string) ([]string, error) -} - -// ErrNoSuchKey is returned by a Handler when the specified key does not have a -// value set. -var ErrNoSuchKey = errors.New("no such key") - -// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple. -type defaultHandler struct{} - -func (defaultHandler) ReadString(_ string) (string, error) { - return "", ErrNoSuchKey -} - -func (defaultHandler) ReadUInt64(_ string) (uint64, error) { - return 0, ErrNoSuchKey -} - -func (defaultHandler) ReadBoolean(_ string) (bool, error) { - return false, ErrNoSuchKey -} - -func (defaultHandler) ReadStringArray(_ string) ([]string, error) { - return nil, ErrNoSuchKey -} - -// markHandlerInUse is called before handler methods are called. -func markHandlerInUse() { - handlerUsed.Store(true) -} - -// RegisterHandler initializes the policy handler and ensures registration will happen once. -func RegisterHandler(h Handler) { - // Technically this assignment is not concurrency safe, but in the - // event that there was any risk of a data race, we will panic due to - // the CompareAndSwap failing. - handler = h - if !handlerUsed.CompareAndSwap(false, true) { - panic("handler was already used before registration") - } -} - -// TB is a subset of testing.TB that we use to set up test helpers. -// It's defined here to avoid pulling in the testing package. -type TB interface { - Helper() - Cleanup(func()) -} - -func SetHandlerForTest(tb TB, h Handler) { - tb.Helper() - oldHandler := handler - handler = h - tb.Cleanup(func() { handler = oldHandler }) -} diff --git a/util/syspolicy/handler_test.go b/util/syspolicy/handler_test.go deleted file mode 100644 index 39b18936f..000000000 --- a/util/syspolicy/handler_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import "testing" - -func TestDefaultHandlerReadValues(t *testing.T) { - var h defaultHandler - - got, err := h.ReadString(string(AdminConsoleVisibility)) - if got != "" || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", got, err) - } - result, err := h.ReadUInt64(string(LogSCMInteractions)) - if result != 0 || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", result, err) - } -} diff --git a/util/syspolicy/handler_windows.go b/util/syspolicy/handler_windows.go deleted file mode 100644 index 661853ead..000000000 --- a/util/syspolicy/handler_windows.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "fmt" - - "tailscale.com/util/clientmetric" - "tailscale.com/util/winutil" -) - -var ( - windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors") - windowsAny = clientmetric.NewGauge("windows_syspolicy_any") -) - -type windowsHandler struct{} - -func init() { - RegisterHandler(NewCachingHandler(windowsHandler{})) - - keyList := []struct { - isSet func(Key) bool - keys []Key - }{ - { - isSet: func(k Key) bool { - _, err := handler.ReadString(string(k)) - return err == nil - }, - keys: stringKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadBoolean(string(k)) - return err == nil - }, - keys: boolKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadUInt64(string(k)) - return err == nil - }, - keys: uint64Keys, - }, - } - - var anySet bool - for _, l := range keyList { - for _, k := range l.keys { - if !l.isSet(k) { - continue - } - clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1) - anySet = true - } - } - if anySet { - windowsAny.Set(1) - } -} - -func (windowsHandler) ReadString(key string) (string, error) { - s, err := winutil.GetPolicyString(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - - return s, err -} - -func (windowsHandler) ReadUInt64(key string) (uint64, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} - -func (windowsHandler) ReadBoolean(key string) (bool, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value != 0, err -} - -func (windowsHandler) ReadStringArray(key string) ([]string, error) { - value, err := winutil.GetPolicyStringArray(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} diff --git a/util/syspolicy/internal/internal.go b/util/syspolicy/internal/internal.go index 4c3e28d39..6ab147de6 100644 --- a/util/syspolicy/internal/internal.go +++ b/util/syspolicy/internal/internal.go @@ -10,9 +10,13 @@ import ( "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/lazy" + "tailscale.com/util/testenv" "tailscale.com/version" ) +// Init facilitates deferred invocation of initializers. +var Init lazy.DeferredInit + // OSForTesting is the operating system override used for testing. // It follows the same naming convention as [version.OS]. var OSForTesting lazy.SyncValue[string] @@ -22,22 +26,10 @@ func OS() string { return OSForTesting.Get(version.OS) } -// TB is a subset of testing.TB that we use to set up test helpers. -// It's defined here to avoid pulling in the testing package. -type TB interface { - Helper() - Cleanup(func()) - Logf(format string, args ...any) - Error(args ...any) - Errorf(format string, args ...any) - Fatal(args ...any) - Fatalf(format string, args ...any) -} - // EqualJSONForTest compares the JSON in j1 and j2 for semantic equality. // It returns "", "", true if j1 and j2 are equal. Otherwise, it returns // indented versions of j1 and j2 and false. -func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) { +func EqualJSONForTest(tb testenv.TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) { tb.Helper() j1 = j1.Clone() j2 = j2.Clone() @@ -53,10 +45,10 @@ func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) return "", "", true } // Otherwise, format the values for display and return false. - if err := j1.Indent("", "\t"); err != nil { + if err := j1.Indent(); err != nil { tb.Fatal(err) } - if err := j2.Indent("", "\t"); err != nil { + if err := j2.Indent(); err != nil { tb.Fatal(err) } return j1.String(), j2.String(), false diff --git a/util/syspolicy/internal/loggerx/logger.go b/util/syspolicy/internal/loggerx/logger.go index b28610826..d1f48cbb4 100644 --- a/util/syspolicy/internal/loggerx/logger.go +++ b/util/syspolicy/internal/loggerx/logger.go @@ -6,41 +6,59 @@ package loggerx import ( "log" + "sync/atomic" "tailscale.com/types/lazy" "tailscale.com/types/logger" - "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/testenv" ) const ( - errorPrefix = "syspolicy: " + normalPrefix = "syspolicy: " verbosePrefix = "syspolicy: [v2] " ) var ( - lazyErrorf lazy.SyncValue[logger.Logf] + debugLogging atomic.Bool // whether debugging logging is enabled + + lazyPrintf lazy.SyncValue[logger.Logf] lazyVerbosef lazy.SyncValue[logger.Logf] ) +// SetDebugLoggingEnabled controls whether spammy debug logging is enabled. +func SetDebugLoggingEnabled(v bool) { + debugLogging.Store(v) +} + // Errorf formats and writes an error message to the log. func Errorf(format string, args ...any) { - errorf := lazyErrorf.Get(func() logger.Logf { - return logger.WithPrefix(log.Printf, errorPrefix) - }) - errorf(format, args...) + printf(format, args...) } // Verbosef formats and writes an optional, verbose message to the log. func Verbosef(format string, args ...any) { - verbosef := lazyVerbosef.Get(func() logger.Logf { + if debugLogging.Load() { + printf(format, args...) + } else { + verbosef(format, args...) + } +} + +func printf(format string, args ...any) { + lazyPrintf.Get(func() logger.Logf { + return logger.WithPrefix(log.Printf, normalPrefix) + })(format, args...) +} + +func verbosef(format string, args ...any) { + lazyVerbosef.Get(func() logger.Logf { return logger.WithPrefix(log.Printf, verbosePrefix) - }) - verbosef(format, args...) + })(format, args...) } -// SetForTest sets the specified errorf and verbosef functions for the duration +// SetForTest sets the specified printf and verbosef functions for the duration // of tb and its subtests. -func SetForTest(tb internal.TB, errorf, verbosef logger.Logf) { - lazyErrorf.SetForTest(tb, errorf, nil) +func SetForTest(tb testenv.TB, printf, verbosef logger.Logf) { + lazyPrintf.SetForTest(tb, printf, nil) lazyVerbosef.SetForTest(tb, verbosef, nil) } diff --git a/util/syspolicy/internal/loggerx/logger_test.go b/util/syspolicy/internal/loggerx/logger_test.go new file mode 100644 index 000000000..9735b5d30 --- /dev/null +++ b/util/syspolicy/internal/loggerx/logger_test.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package loggerx + +import ( + "fmt" + "io" + "strings" + "testing" + + "tailscale.com/types/logger" +) + +func TestDebugLogging(t *testing.T) { + var normal, verbose strings.Builder + SetForTest(t, logfTo(&normal), logfTo(&verbose)) + + checkOutput := func(wantNormal, wantVerbose string) { + t.Helper() + if gotNormal := normal.String(); gotNormal != wantNormal { + t.Errorf("Unexpected normal output: got %q; want %q", gotNormal, wantNormal) + } + if gotVerbose := verbose.String(); gotVerbose != wantVerbose { + t.Errorf("Unexpected verbose output: got %q; want %q", gotVerbose, wantVerbose) + } + normal.Reset() + verbose.Reset() + } + + Errorf("This is an error message: %v", 42) + checkOutput("This is an error message: 42", "") + Verbosef("This is a verbose message: %v", 17) + checkOutput("", "This is a verbose message: 17") + + SetDebugLoggingEnabled(true) + Errorf("This is an error message: %v", 42) + checkOutput("This is an error message: 42", "") + Verbosef("This is a verbose message: %v", 17) + checkOutput("This is a verbose message: 17", "") + + SetDebugLoggingEnabled(false) + Errorf("This is an error message: %v", 42) + checkOutput("This is an error message: 42", "") + Verbosef("This is a verbose message: %v", 17) + checkOutput("", "This is a verbose message: 17") +} + +func logfTo(w io.Writer) logger.Logf { + return func(format string, args ...any) { + fmt.Fprintf(w, format, args...) + } +} diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go index 2ea02278a..8f2745673 100644 --- a/util/syspolicy/internal/metrics/metrics.go +++ b/util/syspolicy/internal/metrics/metrics.go @@ -17,6 +17,7 @@ import ( "tailscale.com/util/slicesx" "tailscale.com/util/syspolicy/internal" "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/testenv" ) @@ -209,7 +210,7 @@ func scopeMetrics(origin *setting.Origin) *policyScopeMetrics { var ( settingMetricsMu sync.RWMutex - settingMetricsMap map[setting.Key]*settingMetrics + settingMetricsMap map[pkey.Key]*settingMetrics ) func settingMetricsFor(setting *setting.Definition) *settingMetrics { @@ -259,7 +260,7 @@ var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn] // SetHooksForTest sets the specified addMetric and setMetric functions // as the metric functions for the duration of tb and all its subtests. -func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { +func SetHooksForTest(tb testenv.TB, addMetric, setMetric metricFn) { oldAddMetric := addMetricTestHook.Swap(addMetric) oldSetMetric := setMetricTestHook.Swap(setMetric) tb.Cleanup(func() { @@ -283,13 +284,14 @@ func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { lazyUserMetrics.SetForTest(tb, newScopeMetrics(setting.UserSetting), nil) } -func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { - name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_") +func newSettingMetric(key pkey.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { + name := strings.ReplaceAll(string(key), string(pkey.KeyPathSeparator), "_") + name = strings.ReplaceAll(name, ".", "_") // dots are not allowed in metric names return newMetric([]string{name, metricScopeName(scope), suffix}, typ) } func newMetric(nameParts []string, typ clientmetric.Type) metric { - name := strings.Join(slicesx.Filter([]string{internal.OS(), "syspolicy"}, nameParts, isNonEmpty), "_") + name := strings.Join(slicesx.AppendNonzero([]string{internal.OS(), "syspolicy"}, nameParts), "_") switch { case !ShouldReport(): return &funcMetric{name: name, typ: typ} @@ -304,8 +306,6 @@ func newMetric(nameParts []string, typ clientmetric.Type) metric { } } -func isNonEmpty(s string) bool { return s != "" } - func metricScopeName(scope setting.Scope) string { switch scope { case setting.DeviceSetting: diff --git a/util/syspolicy/internal/metrics/metrics_test.go b/util/syspolicy/internal/metrics/metrics_test.go index 07be4773c..a99938769 100644 --- a/util/syspolicy/internal/metrics/metrics_test.go +++ b/util/syspolicy/internal/metrics/metrics_test.go @@ -10,13 +10,14 @@ import ( "tailscale.com/types/lazy" "tailscale.com/util/clientmetric" "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/setting" ) func TestSettingMetricNames(t *testing.T) { tests := []struct { name string - key setting.Key + key pkey.Key scope setting.Scope suffix string typ clientmetric.Type diff --git a/util/syspolicy/internal/metrics/test_handler.go b/util/syspolicy/internal/metrics/test_handler.go index f9e484609..36c3f2cad 100644 --- a/util/syspolicy/internal/metrics/test_handler.go +++ b/util/syspolicy/internal/metrics/test_handler.go @@ -9,6 +9,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/set" "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/testenv" ) // TestState represents a metric name and its expected value. @@ -19,13 +20,13 @@ type TestState struct { // TestHandler facilitates testing of the code that uses metrics. type TestHandler struct { - t internal.TB + t testenv.TB m map[string]int64 } // NewTestHandler returns a new TestHandler. -func NewTestHandler(t internal.TB) *TestHandler { +func NewTestHandler(t testenv.TB) *TestHandler { return &TestHandler{t, make(map[string]int64)} } diff --git a/util/syspolicy/pkey/pkey.go b/util/syspolicy/pkey/pkey.go new file mode 100644 index 000000000..e450625cd --- /dev/null +++ b/util/syspolicy/pkey/pkey.go @@ -0,0 +1,190 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package pkey defines the keys used to store system policies in the registry. +// +// This is a leaf package meant to only contain string constants, not code. +package pkey + +// Key is a string that uniquely identifies a policy and must remain unchanged +// once established and documented for a given policy setting. It may contain +// alphanumeric characters and zero or more [KeyPathSeparator]s to group +// individual policy settings into categories. +type Key string + +// KeyPathSeparator allows logical grouping of policy settings into categories. +const KeyPathSeparator = '/' + +// The const block below lists known policy keys. +// When adding a key to this list, remember to add a corresponding +// [setting.Definition] to [implicitDefinitions] in util/syspolicy/policy_keys.go. +// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder. + +const ( + // Keys with a string value + ControlURL Key = "LoginURL" // default ""; if blank, ipn uses ipn.DefaultControlURL. + LogTarget Key = "LogTarget" // default ""; if blank logging uses logtail.DefaultHost. + Tailnet Key = "Tailnet" // default ""; if blank, no tailnet name is sent to the server. + + // AlwaysOn is a boolean key that controls whether Tailscale + // should always remain in a connected state, and the user should + // not be able to disconnect at their discretion. + // + // Warning: This policy setting is experimental and may change or be removed in the future. + // It may also not be fully supported by all Tailscale clients until it is out of experimental status. + // See tailscale/corp#26247, tailscale/corp#26248 and tailscale/corp#26249 for more information. + AlwaysOn Key = "AlwaysOn.Enabled" + + // AlwaysOnOverrideWithReason is a boolean key that alters the behavior + // of [AlwaysOn]. When true, the user is allowed to disconnect Tailscale + // by providing a reason. The reason is logged and sent to the control + // for auditing purposes. It has no effect when [AlwaysOn] is false. + AlwaysOnOverrideWithReason Key = "AlwaysOn.OverrideWithReason" + + // ReconnectAfter is a string value formatted for use with time.ParseDuration() + // that defines the duration after which the client should automatically reconnect + // to the Tailscale network following a user-initiated disconnect. + // An empty string or a zero duration disables automatic reconnection. + ReconnectAfter Key = "ReconnectAfter" + + // AllowTailscaledRestart is a boolean key that controls whether users with write access + // to the LocalAPI are allowed to shutdown tailscaled with the intention of restarting it. + // On Windows, tailscaled will be restarted automatically by the service process + // (see babysitProc in cmd/tailscaled/tailscaled_windows.go). + // On other platforms, it is the client's responsibility to restart tailscaled. + AllowTailscaledRestart Key = "AllowTailscaledRestart" + + // ExitNodeID is the exit node's node id. default ""; if blank, no exit node is forced. + // Exit node ID takes precedence over exit node IP. + // To find the node ID, go to /api.md#device. + ExitNodeID Key = "ExitNodeID" + ExitNodeIP Key = "ExitNodeIP" // default ""; if blank, no exit node is forced. Value is exit node IP. + + // AllowExitNodeOverride is a boolean key that allows the user to override exit node policy settings + // and manually select an exit node. It does not allow disabling exit node usage entirely. + // It is typically used in conjunction with [ExitNodeID] set to "auto:any". + // + // Warning: This policy setting is experimental and may change, be renamed or removed in the future. + // It may also not be fully supported by all Tailscale clients until it is out of experimental status. + // See tailscale/corp#29969. + AllowExitNodeOverride Key = "ExitNode.AllowOverride" + + // Keys with a string value that specifies an option: "always", "never", "user-decides". + // The default is "user-decides" unless otherwise stated. Enforcement of + // these policies is typically performed in ipnlocal.applySysPolicy(). GUIs + // typically hide menu items related to policies that are enforced. + EnableIncomingConnections Key = "AllowIncomingConnections" + EnableServerMode Key = "UnattendedMode" + ExitNodeAllowLANAccess Key = "ExitNodeAllowLANAccess" + EnableTailscaleDNS Key = "UseTailscaleDNSSettings" + EnableTailscaleSubnets Key = "UseTailscaleSubnets" + + // EnableDNSRegistration is a string value that can be set to "always", "never" + // or "user-decides". It controls whether DNS registration and dynamic DNS + // updates are enabled for the Tailscale interface. For historical reasons + // and to maintain compatibility with existing setups, the default is "never". + // It is only used on Windows. + EnableDNSRegistration Key = "EnableDNSRegistration" + + // CheckUpdates is the key to signal if the updater should periodically + // check for updates. + CheckUpdates Key = "CheckUpdates" + // ApplyUpdates is the key to signal if updates should be automatically + // installed. Its value is "InstallUpdates" because of an awkwardly-named + // visibility option "ApplyUpdates" on MacOS. + ApplyUpdates Key = "InstallUpdates" + // EnableRunExitNode controls if the device acts as an exit node. Even when + // running as an exit node, the device must be approved by a tailnet + // administrator. Its name is slightly awkward because RunExitNodeVisibility + // predates this option but is preserved for backwards compatibility. + EnableRunExitNode Key = "AdvertiseExitNode" + + // Keys with a string value that controls visibility: "show", "hide". + // The default is "show" unless otherwise stated. Enforcement of these + // policies is typically performed by the UI code for the relevant operating + // system. + AdminConsoleVisibility Key = "AdminConsole" + NetworkDevicesVisibility Key = "NetworkDevices" + TestMenuVisibility Key = "TestMenu" + UpdateMenuVisibility Key = "UpdateMenu" + ResetToDefaultsVisibility Key = "ResetToDefaults" + // RunExitNodeVisibility controls if the "run as exit node" menu item is + // visible, without controlling the setting itself. This is preserved for + // backwards compatibility but prefer EnableRunExitNode in new deployments. + RunExitNodeVisibility Key = "RunExitNode" + PreferencesMenuVisibility Key = "PreferencesMenu" + ExitNodeMenuVisibility Key = "ExitNodesPicker" + // AutoUpdateVisibility is the key to signal if the menu item for automatic + // installation of updates should be visible. It is only used by macsys + // installations and uses the Sparkle naming convention, even though it does + // not actually control updates, merely the UI for that setting. + AutoUpdateVisibility Key = "ApplyUpdates" + // SuggestedExitNodeVisibility controls the visibility of suggested exit nodes in the client GUI. + // When this system policy is set to 'hide', an exit node suggestion won't be presented to the user as part of the exit nodes picker. + SuggestedExitNodeVisibility Key = "SuggestedExitNode" + // OnboardingFlowVisibility controls the visibility of the onboarding flow in the client GUI. + // When this system policy is set to 'hide', the onboarding flow is never shown to the user. + OnboardingFlowVisibility Key = "OnboardingFlow" + + // Keys with a string value formatted for use with time.ParseDuration(). + KeyExpirationNoticeTime Key = "KeyExpirationNotice" // default 24 hours + + // Boolean Keys that are only applicable on Windows. Booleans are stored in the registry as + // DWORD or QWORD (either is acceptable). 0 means false, and anything else means true. + // The default is 0 unless otherwise stated. + LogSCMInteractions Key = "LogSCMInteractions" + FlushDNSOnSessionUnlock Key = "FlushDNSOnSessionUnlock" + + // EncryptState is a boolean setting that specifies whether to encrypt the + // tailscaled state file. + // Windows and Linux use a TPM device, Apple uses the Keychain. + // It's a noop on other platforms. + EncryptState Key = "EncryptState" + + // HardwareAttestation is a boolean key that controls whether to use a + // hardware-backed key to bind the node identity to this device. + HardwareAttestation Key = "HardwareAttestation" + + // PostureChecking indicates if posture checking is enabled and the client shall gather + // posture data. + // Key is a string value that specifies an option: "always", "never", "user-decides". + // The default is "user-decides" unless otherwise stated. + PostureChecking Key = "PostureChecking" + // DeviceSerialNumber is the serial number of the device that is running Tailscale. + // This is used on Android, iOS and tvOS to allow IT administrators to manually give us a serial number via MDM. + // We are unable to programmatically get the serial number on mobile due to sandboxing restrictions. + DeviceSerialNumber Key = "DeviceSerialNumber" + + // ManagedByOrganizationName indicates the name of the organization managing the Tailscale + // install. It is displayed inside the client UI in a prominent location. + ManagedByOrganizationName Key = "ManagedByOrganizationName" + // ManagedByCaption is an info message displayed inside the client UI as a caption when + // ManagedByOrganizationName is set. It can be used to provide a pointer to support resources + // for Tailscale within the organization. + ManagedByCaption Key = "ManagedByCaption" + // ManagedByURL is a valid URL pointing to a support help desk for Tailscale within the + // organization. A button in the client UI provides easy access to this URL. + ManagedByURL Key = "ManagedByURL" + + // AuthKey is an auth key that will be used to login whenever the backend starts. This can be used to + // automatically authenticate managed devices, without requiring user interaction. + AuthKey Key = "AuthKey" + + // MachineCertificateSubject is the exact name of a Subject that needs + // to be present in an identity's certificate chain to sign a RegisterRequest, + // formatted as per pkix.Name.String(). The Subject may be that of the identity + // itself, an intermediate CA or the root CA. + // + // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" + MachineCertificateSubject Key = "MachineCertificateSubject" + + // Hostname is the hostname of the device that is running Tailscale. + // When this policy is set, it overrides the hostname that the client + // would otherwise obtain from the OS, e.g. by calling os.Hostname(). + Hostname Key = "Hostname" + + // Keys with a string array value. + + // AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes. + AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes" +) diff --git a/util/syspolicy/policy_keys.go b/util/syspolicy/policy_keys.go index ec0556a94..3a54f9dde 100644 --- a/util/syspolicy/policy_keys.go +++ b/util/syspolicy/policy_keys.go @@ -3,110 +3,78 @@ package syspolicy -import "tailscale.com/util/syspolicy/setting" - -type Key = setting.Key - -const ( - // Keys with a string value - ControlURL Key = "LoginURL" // default ""; if blank, ipn uses ipn.DefaultControlURL. - LogTarget Key = "LogTarget" // default ""; if blank logging uses logtail.DefaultHost. - Tailnet Key = "Tailnet" // default ""; if blank, no tailnet name is sent to the server. - // ExitNodeID is the exit node's node id. default ""; if blank, no exit node is forced. - // Exit node ID takes precedence over exit node IP. - // To find the node ID, go to /api.md#device. - ExitNodeID Key = "ExitNodeID" - ExitNodeIP Key = "ExitNodeIP" // default ""; if blank, no exit node is forced. Value is exit node IP. - - // Keys with a string value that specifies an option: "always", "never", "user-decides". - // The default is "user-decides" unless otherwise stated. Enforcement of - // these policies is typically performed in ipnlocal.applySysPolicy(). GUIs - // typically hide menu items related to policies that are enforced. - EnableIncomingConnections Key = "AllowIncomingConnections" - EnableServerMode Key = "UnattendedMode" - ExitNodeAllowLANAccess Key = "ExitNodeAllowLANAccess" - EnableTailscaleDNS Key = "UseTailscaleDNSSettings" - EnableTailscaleSubnets Key = "UseTailscaleSubnets" - // CheckUpdates is the key to signal if the updater should periodically - // check for updates. - CheckUpdates Key = "CheckUpdates" - // ApplyUpdates is the key to signal if updates should be automatically - // installed. Its value is "InstallUpdates" because of an awkwardly-named - // visibility option "ApplyUpdates" on MacOS. - ApplyUpdates Key = "InstallUpdates" - // EnableRunExitNode controls if the device acts as an exit node. Even when - // running as an exit node, the device must be approved by a tailnet - // administrator. Its name is slightly awkward because RunExitNodeVisibility - // predates this option but is preserved for backwards compatibility. - EnableRunExitNode Key = "AdvertiseExitNode" - - // Keys with a string value that controls visibility: "show", "hide". - // The default is "show" unless otherwise stated. Enforcement of these - // policies is typically performed by the UI code for the relevant operating - // system. - AdminConsoleVisibility Key = "AdminConsole" - NetworkDevicesVisibility Key = "NetworkDevices" - TestMenuVisibility Key = "TestMenu" - UpdateMenuVisibility Key = "UpdateMenu" - ResetToDefaultsVisibility Key = "ResetToDefaults" - // RunExitNodeVisibility controls if the "run as exit node" menu item is - // visible, without controlling the setting itself. This is preserved for - // backwards compatibility but prefer EnableRunExitNode in new deployments. - RunExitNodeVisibility Key = "RunExitNode" - PreferencesMenuVisibility Key = "PreferencesMenu" - ExitNodeMenuVisibility Key = "ExitNodesPicker" - // AutoUpdateVisibility is the key to signal if the menu item for automatic - // installation of updates should be visible. It is only used by macsys - // installations and uses the Sparkle naming convention, even though it does - // not actually control updates, merely the UI for that setting. - AutoUpdateVisibility Key = "ApplyUpdates" - // SuggestedExitNodeVisibility controls the visibility of suggested exit nodes in the client GUI. - // When this system policy is set to 'hide', an exit node suggestion won't be presented to the user as part of the exit nodes picker. - SuggestedExitNodeVisibility Key = "SuggestedExitNode" - - // Keys with a string value formatted for use with time.ParseDuration(). - KeyExpirationNoticeTime Key = "KeyExpirationNotice" // default 24 hours - - // Boolean Keys that are only applicable on Windows. Booleans are stored in the registry as - // DWORD or QWORD (either is acceptable). 0 means false, and anything else means true. - // The default is 0 unless otherwise stated. - LogSCMInteractions Key = "LogSCMInteractions" - FlushDNSOnSessionUnlock Key = "FlushDNSOnSessionUnlock" - - // PostureChecking indicates if posture checking is enabled and the client shall gather - // posture data. - // Key is a string value that specifies an option: "always", "never", "user-decides". - // The default is "user-decides" unless otherwise stated. - PostureChecking Key = "PostureChecking" - // DeviceSerialNumber is the serial number of the device that is running Tailscale. - // This is used on iOS/tvOS to allow IT administrators to manually give us a serial number via MDM. - // We are unable to programmatically get the serial number from IOKit due to sandboxing restrictions. - DeviceSerialNumber Key = "DeviceSerialNumber" - - // ManagedByOrganizationName indicates the name of the organization managing the Tailscale - // install. It is displayed inside the client UI in a prominent location. - ManagedByOrganizationName Key = "ManagedByOrganizationName" - // ManagedByCaption is an info message displayed inside the client UI as a caption when - // ManagedByOrganizationName is set. It can be used to provide a pointer to support resources - // for Tailscale within the organization. - ManagedByCaption Key = "ManagedByCaption" - // ManagedByURL is a valid URL pointing to a support help desk for Tailscale within the - // organization. A button in the client UI provides easy access to this URL. - ManagedByURL Key = "ManagedByURL" +import ( + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" +) - // AuthKey is an auth key that will be used to login whenever the backend starts. This can be used to - // automatically authenticate managed devices, without requiring user interaction. - AuthKey Key = "AuthKey" +// implicitDefinitions is a list of [setting.Definition] that will be registered +// automatically when the policy setting definitions are first used by the syspolicy package hierarchy. +// This includes the first time a policy needs to be read from any source. +var implicitDefinitions = []*setting.Definition{ + // Device policy settings (can only be configured on a per-device basis): + setting.NewDefinition(pkey.AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition(pkey.AllowExitNodeOverride, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.AllowTailscaledRestart, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.AlwaysOn, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.AlwaysOnOverrideWithReason, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.AuthKey, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.ControlURL, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.DeviceSerialNumber, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.EnableDNSRegistration, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.ExitNodeAllowLANAccess, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.ExitNodeID, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.ExitNodeIP, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.EncryptState, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.Hostname, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(pkey.LogTarget, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.MachineCertificateSubject, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(pkey.ReconnectAfter, setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition(pkey.Tailnet, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(pkey.HardwareAttestation, setting.DeviceSetting, setting.BooleanValue), - // MachineCertificateSubject is the exact name of a Subject that needs - // to be present in an identity's certificate chain to sign a RegisterRequest, - // formatted as per pkix.Name.String(). The Subject may be that of the identity - // itself, an intermediate CA or the root CA. - // - // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" - MachineCertificateSubject Key = "MachineCertificateSubject" + // User policy settings (can be configured on a user- or device-basis): + setting.NewDefinition(pkey.AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue), + setting.NewDefinition(pkey.ManagedByCaption, setting.UserSetting, setting.StringValue), + setting.NewDefinition(pkey.ManagedByOrganizationName, setting.UserSetting, setting.StringValue), + setting.NewDefinition(pkey.ManagedByURL, setting.UserSetting, setting.StringValue), + setting.NewDefinition(pkey.NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.RunExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.TestMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(pkey.OnboardingFlowVisibility, setting.UserSetting, setting.VisibilityValue), +} - // Keys with a string array value. - // AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes. - AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes" -) +func init() { + internal.Init.MustDefer(func() error { + // Avoid implicit [setting.Definition] registration during tests. + // Each test should control which policy settings to register. + // Use [setting.SetDefinitionsForTest] to specify necessary definitions, + // or [setWellKnownSettingsForTest] to set implicit definitions for the test duration. + if testenv.InTest() { + return nil + } + for _, d := range implicitDefinitions { + setting.RegisterDefinition(d) + } + return nil + }) +} diff --git a/util/syspolicy/policy_keys_test.go b/util/syspolicy/policy_keys_test.go new file mode 100644 index 000000000..c2b8d5741 --- /dev/null +++ b/util/syspolicy/policy_keys_test.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "os" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/setting" +) + +func TestKnownKeysRegistered(t *testing.T) { + const file = "pkey/pkey.go" + keyConsts, err := listStringConsts[pkey.Key](file) + if err != nil { + t.Fatalf("listStringConsts failed: %v", err) + } + if len(keyConsts) == 0 { + t.Fatalf("no key constants found in %s", file) + } + + m, err := setting.DefinitionMapOf(implicitDefinitions) + if err != nil { + t.Fatalf("definitionMapOf failed: %v", err) + } + + for _, key := range keyConsts { + t.Run(string(key), func(t *testing.T) { + d := m[key] + if d == nil { + t.Fatalf("%q was not registered", key) + } + if d.Key() != key { + t.Fatalf("d.Key got: %s, want %s", d.Key(), key) + } + }) + } +} + +func listStringConsts[T ~string](filename string) (map[string]T, error) { + fset := token.NewFileSet() + src, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + f, err := parser.ParseFile(fset, filename, src, 0) + if err != nil { + return nil, err + } + + consts := make(map[string]T) + typeName := reflect.TypeFor[T]().Name() + for _, d := range f.Decls { + g, ok := d.(*ast.GenDecl) + if !ok || g.Tok != token.CONST { + continue + } + + for _, s := range g.Specs { + vs, ok := s.(*ast.ValueSpec) + if !ok || len(vs.Names) != len(vs.Values) { + continue + } + if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName { + continue + } + + for i, n := range vs.Names { + lit, ok := vs.Values[i].(*ast.BasicLit) + if !ok { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i])) + } + val, err := strconv.Unquote(lit.Value) + if err != nil { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value) + } + consts[n.Name] = T(val) + } + } + } + + return consts, nil +} diff --git a/util/syspolicy/policy_keys_windows.go b/util/syspolicy/policy_keys_windows.go deleted file mode 100644 index 5e9a71695..000000000 --- a/util/syspolicy/policy_keys_windows.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -var stringKeys = []Key{ - ControlURL, - LogTarget, - Tailnet, - ExitNodeID, - ExitNodeIP, - EnableIncomingConnections, - EnableServerMode, - ExitNodeAllowLANAccess, - EnableTailscaleDNS, - EnableTailscaleSubnets, - AdminConsoleVisibility, - NetworkDevicesVisibility, - TestMenuVisibility, - UpdateMenuVisibility, - RunExitNodeVisibility, - PreferencesMenuVisibility, - ExitNodeMenuVisibility, - AutoUpdateVisibility, - ResetToDefaultsVisibility, - KeyExpirationNoticeTime, - PostureChecking, - ManagedByOrganizationName, - ManagedByCaption, - ManagedByURL, -} - -var boolKeys = []Key{ - LogSCMInteractions, - FlushDNSOnSessionUnlock, -} - -var uint64Keys = []Key{} diff --git a/util/syspolicy/policyclient/policyclient.go b/util/syspolicy/policyclient/policyclient.go new file mode 100644 index 000000000..728a16718 --- /dev/null +++ b/util/syspolicy/policyclient/policyclient.go @@ -0,0 +1,145 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policyclient contains the minimal syspolicy interface as needed by +// client code using syspolicy. It's the part that's always linked in, even if the rest +// of syspolicy is omitted from the build. +package policyclient + +import ( + "time" + + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" + "tailscale.com/util/testenv" +) + +// Client is the interface between code making questions about the system policy +// and the actual implementation. +type Client interface { + // GetString returns a string policy setting with the specified key, + // or defaultValue (and a nil error) if it does not exist. + GetString(key pkey.Key, defaultValue string) (string, error) + + // GetStringArray returns a string array policy setting with the specified key, + // or defaultValue (and a nil error) if it does not exist. + GetStringArray(key pkey.Key, defaultValue []string) ([]string, error) + + // GetBoolean returns a boolean policy setting with the specified key, + // or defaultValue (and a nil error) if it does not exist. + GetBoolean(key pkey.Key, defaultValue bool) (bool, error) + + // GetUint64 returns a numeric policy setting with the specified key, + // or defaultValue (and a nil error) if it does not exist. + GetUint64(key pkey.Key, defaultValue uint64) (uint64, error) + + // GetDuration loads a policy from the registry that can be managed by an + // enterprise policy management system and describes a duration for some + // action. The registry value should be a string that time.ParseDuration + // understands. If the registry value is "" or can not be processed, + // defaultValue (and a nil error) is returned instead. + GetDuration(key pkey.Key, defaultValue time.Duration) (time.Duration, error) + + // GetPreferenceOption loads a policy from the registry that can be + // managed by an enterprise policy management system and allows administrative + // overrides of users' choices in a way that we do not want tailcontrol to have + // the authority to set. It describes user-decides/always/never options, where + // "always" and "never" remove the user's ability to make a selection. If not + // present or set to a different value, defaultValue (and a nil error) is returned. + GetPreferenceOption(key pkey.Key, defaultValue ptype.PreferenceOption) (ptype.PreferenceOption, error) + + // GetVisibility returns whether a UI element should be visible based on + // the system's configuration. + // If unconfigured, implementations should return [ptype.VisibleByPolicy] + // and a nil error. + GetVisibility(key pkey.Key) (ptype.Visibility, error) + + // SetDebugLoggingEnabled enables or disables debug logging for the policy client. + SetDebugLoggingEnabled(enabled bool) + + // HasAnyOf returns whether at least one of the specified policy settings is + // configured, or an error if no keys are provided or the check fails. + HasAnyOf(keys ...pkey.Key) (bool, error) + + // RegisterChangeCallback registers a callback function that will be called + // whenever a policy change is detected. It returns a function to unregister + // the callback and an error if the registration fails. + RegisterChangeCallback(cb func(PolicyChange)) (unregister func(), err error) +} + +// Get returns a non-nil [Client] implementation as a function of the +// build tags. It returns a no-op implementation if the full syspolicy +// package is omitted from the build, or in tests. +func Get() Client { + if testenv.InTest() { + // This is a little redundant (the Windows implementation at least + // already does this) but it's here for redundancy and clarity, that we + // don't want to accidentally use the real system policy when running + // tests. + return NoPolicyClient{} + } + return client +} + +// RegisterClientImpl registers a [Client] implementation to be returned by +// [Get]. +func RegisterClientImpl(c Client) { + client = c +} + +var client Client = NoPolicyClient{} + +// PolicyChange is the interface representing a change in policy settings. +type PolicyChange interface { + // HasChanged reports whether the policy setting identified by the given key + // has changed. + HasChanged(pkey.Key) bool + + // HasChangedAnyOf reports whether any of the provided policy settings + // changed in this change. + HasChangedAnyOf(keys ...pkey.Key) bool +} + +// NoPolicyClient is a no-op implementation of [Client] that only +// returns default values. +type NoPolicyClient struct{} + +var _ Client = NoPolicyClient{} + +func (NoPolicyClient) GetBoolean(key pkey.Key, defaultValue bool) (bool, error) { + return defaultValue, nil +} + +func (NoPolicyClient) GetString(key pkey.Key, defaultValue string) (string, error) { + return defaultValue, nil +} + +func (NoPolicyClient) GetStringArray(key pkey.Key, defaultValue []string) ([]string, error) { + return defaultValue, nil +} + +func (NoPolicyClient) GetUint64(key pkey.Key, defaultValue uint64) (uint64, error) { + return defaultValue, nil +} + +func (NoPolicyClient) GetDuration(name pkey.Key, defaultValue time.Duration) (time.Duration, error) { + return defaultValue, nil +} + +func (NoPolicyClient) GetPreferenceOption(name pkey.Key, defaultValue ptype.PreferenceOption) (ptype.PreferenceOption, error) { + return defaultValue, nil +} + +func (NoPolicyClient) GetVisibility(name pkey.Key) (ptype.Visibility, error) { + return ptype.VisibleByPolicy, nil +} + +func (NoPolicyClient) HasAnyOf(keys ...pkey.Key) (bool, error) { + return false, nil +} + +func (NoPolicyClient) SetDebugLoggingEnabled(enabled bool) {} + +func (NoPolicyClient) RegisterChangeCallback(cb func(PolicyChange)) (unregister func(), err error) { + return func() {}, nil +} diff --git a/util/syspolicy/policytest/policytest.go b/util/syspolicy/policytest/policytest.go new file mode 100644 index 000000000..e5c1c7856 --- /dev/null +++ b/util/syspolicy/policytest/policytest.go @@ -0,0 +1,249 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package policytest contains test helpers for the syspolicy packages. +package policytest + +import ( + "fmt" + "maps" + "slices" + "sync" + "time" + + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/ptype" +) + +// Config is a [policyclient.Client] implementation with a static mapping of +// values. +// +// It is used for testing purposes to simulate policy client behavior. +// +// It panics if a value is Set with one type and then accessed with a different +// expected type and/or value. Some accessors such as GetPreferenceOption and +// GetVisibility support either a ptype.PreferenceOption/ptype.Visibility in the +// map, or the string representation as supported by their UnmarshalText +// methods. +// +// The map value may be an error to return that error value from the accessor. +type Config map[pkey.Key]any + +var _ policyclient.Client = Config{} + +// Set sets key to value. The value should be of the correct type that it will +// be read as later. For PreferenceOption and Visibility, you may also set them +// to 'string' values and they'll be UnmarshalText'ed into their correct value +// at Get time. +// +// As a special case, the value can also be of type error to make the accessors +// return that error value. +func (c *Config) Set(key pkey.Key, value any) { + if *c == nil { + *c = make(map[pkey.Key]any) + } + (*c)[key] = value + + if w, ok := (*c)[watchersKey].(*watchers); ok && key != watchersKey { + w.mu.Lock() + vals := slices.Collect(maps.Values(w.s)) + w.mu.Unlock() + for _, f := range vals { + f(policyChange(key)) + } + } +} + +// SetMultiple is a batch version of [Config.Set]. It copies the contents of o +// into c and does at most one notification wake-up for the whole batch. +func (c *Config) SetMultiple(o Config) { + if *c == nil { + *c = make(map[pkey.Key]any) + } + + maps.Copy(*c, o) + + if w, ok := (*c)[watchersKey].(*watchers); ok { + w.mu.Lock() + vals := slices.Collect(maps.Values(w.s)) + w.mu.Unlock() + for _, f := range vals { + f(policyChanges(o)) + } + } +} + +type policyChange pkey.Key + +func (pc policyChange) HasChanged(v pkey.Key) bool { return pkey.Key(pc) == v } +func (pc policyChange) HasChangedAnyOf(keys ...pkey.Key) bool { + return slices.Contains(keys, pkey.Key(pc)) +} + +type policyChanges map[pkey.Key]any + +func (pc policyChanges) HasChanged(v pkey.Key) bool { + _, ok := pc[v] + return ok +} +func (pc policyChanges) HasChangedAnyOf(keys ...pkey.Key) bool { + for _, k := range keys { + if pc.HasChanged(k) { + return true + } + } + return false +} + +const watchersKey = "_policytest_watchers" + +type watchers struct { + mu sync.Mutex + s set.HandleSet[func(policyclient.PolicyChange)] +} + +// EnableRegisterChangeCallback makes c support the RegisterChangeCallback +// for testing. Without calling this, the RegisterChangeCallback does nothing. +// For watchers to be notified, use the [Config.Set] method. Changing the map +// directly obviously wouldn't work. +func (c *Config) EnableRegisterChangeCallback() { + if _, ok := (*c)[watchersKey]; !ok { + c.Set(watchersKey, new(watchers)) + } +} + +func (c Config) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case []string: + return val, nil + case error: + return nil, val + default: + panic(fmt.Sprintf("key %s is not a []string; got %T", key, val)) + } + } + return defaultVal, nil +} + +func (c Config) GetString(key pkey.Key, defaultVal string) (string, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case string: + return val, nil + case error: + return "", val + default: + panic(fmt.Sprintf("key %s is not a string; got %T", key, val)) + } + } + return defaultVal, nil +} + +func (c Config) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case bool: + return val, nil + case error: + return false, val + default: + panic(fmt.Sprintf("key %s is not a bool; got %T", key, val)) + } + } + return defaultVal, nil +} + +func (c Config) GetUint64(key pkey.Key, defaultVal uint64) (uint64, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case uint64: + return val, nil + case error: + return 0, val + default: + panic(fmt.Sprintf("key %s is not a uint64; got %T", key, val)) + } + } + return defaultVal, nil +} + +func (c Config) GetDuration(key pkey.Key, defaultVal time.Duration) (time.Duration, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case time.Duration: + return val, nil + case error: + return 0, val + default: + panic(fmt.Sprintf("key %s is not a time.Duration; got %T", key, val)) + } + } + return defaultVal, nil +} + +func (c Config) GetPreferenceOption(key pkey.Key, defaultVal ptype.PreferenceOption) (ptype.PreferenceOption, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case ptype.PreferenceOption: + return val, nil + case error: + var zero ptype.PreferenceOption + return zero, val + case string: + var p ptype.PreferenceOption + err := p.UnmarshalText(([]byte)(val)) + return p, err + default: + panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key)) + } + } + return defaultVal, nil +} + +func (c Config) GetVisibility(key pkey.Key) (ptype.Visibility, error) { + if val, ok := c[key]; ok { + switch val := val.(type) { + case ptype.Visibility: + return val, nil + case error: + var zero ptype.Visibility + return zero, val + case string: + var p ptype.Visibility + err := p.UnmarshalText(([]byte)(val)) + return p, err + default: + panic(fmt.Sprintf("key %s is not a ptype.Visibility", key)) + } + } + return ptype.Visibility(ptype.ShowChoiceByPolicy), nil +} + +func (c Config) HasAnyOf(keys ...pkey.Key) (bool, error) { + for _, key := range keys { + if _, ok := c[key]; ok { + return true, nil + } + } + return false, nil +} + +func (c Config) RegisterChangeCallback(callback func(policyclient.PolicyChange)) (func(), error) { + w, ok := c[watchersKey].(*watchers) + if !ok { + return func() {}, nil + } + w.mu.Lock() + defer w.mu.Unlock() + h := w.s.Add(callback) + return func() { + w.mu.Lock() + defer w.mu.Unlock() + delete(w.s, h) + }, nil +} + +func (sp Config) SetDebugLoggingEnabled(enabled bool) {} diff --git a/util/syspolicy/setting/types.go b/util/syspolicy/ptype/ptype.go similarity index 88% rename from util/syspolicy/setting/types.go rename to util/syspolicy/ptype/ptype.go index 9f110ab03..65ca9e631 100644 --- a/util/syspolicy/setting/types.go +++ b/util/syspolicy/ptype/ptype.go @@ -1,11 +1,11 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package setting - -import ( - "encoding" -) +// Package ptype contains types used by syspolicy. +// +// It's a leaf package for dependency reasons and should not contain much if any +// code, and should not import much (or anything). +package ptype // PreferenceOption is a policy that governs whether a boolean variable // is forcibly assigned an administrator-defined value, or allowed to receive @@ -18,9 +18,10 @@ const ( AlwaysByPolicy ) -// Show returns if the UI option that controls the choice administered by this -// policy should be shown. Currently this is true if and only if the policy is -// [ShowChoiceByPolicy]. +// Show reports whether the UI option that controls the choice administered by +// this policy should be shown (that is, available for users to change). +// +// Currently this is true if and only if the policy is [ShowChoiceByPolicy]. func (p PreferenceOption) Show() bool { return p == ShowChoiceByPolicy } @@ -91,11 +92,6 @@ func (p *PreferenceOption) UnmarshalText(text []byte) error { // component of a user interface is to be shown. type Visibility byte -var ( - _ encoding.TextMarshaler = (*Visibility)(nil) - _ encoding.TextUnmarshaler = (*Visibility)(nil) -) - const ( VisibleByPolicy Visibility = 'v' HiddenByPolicy Visibility = 'h' diff --git a/util/syspolicy/ptype/ptype_test.go b/util/syspolicy/ptype/ptype_test.go new file mode 100644 index 000000000..7c963398b --- /dev/null +++ b/util/syspolicy/ptype/ptype_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ptype + +import ( + "encoding" + "testing" + + "tailscale.com/tstest/deptest" +) + +var ( + _ encoding.TextMarshaler = (*Visibility)(nil) + _ encoding.TextUnmarshaler = (*Visibility)(nil) +) + +func TestImports(t *testing.T) { + deptest.DepChecker{ + OnDep: func(dep string) { + t.Errorf("unexpected dep %q in leaf package; this package should not contain much code", dep) + }, + }.Check(t) +} diff --git a/util/syspolicy/rsop/change_callbacks.go b/util/syspolicy/rsop/change_callbacks.go new file mode 100644 index 000000000..71135bb2a --- /dev/null +++ b/util/syspolicy/rsop/change_callbacks.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "reflect" + "slices" + "sync" + "time" + + "tailscale.com/syncs" + "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/ptype" + "tailscale.com/util/syspolicy/setting" +) + +// Change represents a change from the Old to the New value of type T. +type Change[T any] struct { + New, Old T +} + +// PolicyChangeCallback is a function called whenever a policy changes. +type PolicyChangeCallback func(policyclient.PolicyChange) + +// PolicyChange describes a policy change. +type PolicyChange struct { + snapshots Change[*setting.Snapshot] +} + +// New returns the [setting.Snapshot] after the change. +func (c PolicyChange) New() *setting.Snapshot { + return c.snapshots.New +} + +// Old returns the [setting.Snapshot] before the change. +func (c PolicyChange) Old() *setting.Snapshot { + return c.snapshots.Old +} + +// HasChanged reports whether a policy setting with the specified [pkey.Key], has changed. +func (c PolicyChange) HasChanged(key pkey.Key) bool { + new, newErr := c.snapshots.New.GetErr(key) + old, oldErr := c.snapshots.Old.GetErr(key) + if newErr != nil && oldErr != nil { + return false + } + if newErr != nil || oldErr != nil { + return true + } + switch newVal := new.(type) { + case bool, uint64, string, ptype.Visibility, ptype.PreferenceOption, time.Duration: + return newVal != old + case []string: + oldVal, ok := old.([]string) + return !ok || !slices.Equal(newVal, oldVal) + default: + loggerx.Errorf("[unexpected] %q has an unsupported value type: %T", key, newVal) + return !reflect.DeepEqual(new, old) + } +} + +// HasChangedAnyOf reports whether any of the specified policy settings has changed. +func (c PolicyChange) HasChangedAnyOf(keys ...pkey.Key) bool { + return slices.ContainsFunc(keys, c.HasChanged) +} + +// policyChangeCallbacks are the callbacks to invoke when the effective policy changes. +// It is safe for concurrent use. +type policyChangeCallbacks struct { + mu syncs.Mutex + cbs set.HandleSet[PolicyChangeCallback] +} + +// Register adds the specified callback to be invoked whenever the policy changes. +func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) { + c.mu.Lock() + handle := c.cbs.Add(callback) + c.mu.Unlock() + return func() { + c.mu.Lock() + delete(c.cbs, handle) + c.mu.Unlock() + } +} + +// Invoke calls the registered callback functions with the specified policy change info. +func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) { + var wg sync.WaitGroup + defer wg.Wait() + + c.mu.Lock() + defer c.mu.Unlock() + + wg.Add(len(c.cbs)) + change := &PolicyChange{snapshots: snapshots} + for _, cb := range c.cbs { + go func() { + defer wg.Done() + cb(change) + }() + } +} + +// Close awaits the completion of active callbacks and prevents any further invocations. +func (c *policyChangeCallbacks) Close() { + c.mu.Lock() + defer c.mu.Unlock() + if c.cbs != nil { + clear(c.cbs) + c.cbs = nil + } +} diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go new file mode 100644 index 000000000..bdda90976 --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy.go @@ -0,0 +1,456 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "fmt" + "slices" + "sync/atomic" + "time" + + "tailscale.com/syncs" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" + + "tailscale.com/util/syspolicy/source" +) + +// ErrPolicyClosed is returned by [Policy.Reload], [Policy.addSource], +// [Policy.removeSource] and [Policy.replaceSource] if the policy has been closed. +var ErrPolicyClosed = errors.New("effective policy closed") + +// The minimum and maximum wait times after detecting a policy change +// before reloading the policy. This only affects policy reloads triggered +// by a change in the underlying [source.Store] and does not impact +// synchronous, caller-initiated reloads, such as when [Policy.Reload] is called. +// +// Policy changes occurring within [policyReloadMinDelay] of each other +// will be batched together, resulting in a single policy reload +// no later than [policyReloadMaxDelay] after the first detected change. +// In other words, the effective policy will be reloaded no more often than once +// every 5 seconds, but at most 15 seconds after an underlying [source.Store] +// has issued a policy change callback. +// +// See [Policy.watchReload]. +var ( + policyReloadMinDelay = 5 * time.Second + policyReloadMaxDelay = 15 * time.Second +) + +// Policy provides access to the current effective [setting.Snapshot] for a given +// scope and allows to reload it from the underlying [source.Store] list. It also allows to +// subscribe and receive a callback whenever the effective [setting.Snapshot] is changed. +// +// It is safe for concurrent use. +type Policy struct { + scope setting.PolicyScope + + reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required + closeCh chan struct{} // closed to signal that the Policy is being closed + doneCh chan struct{} // closed by [Policy.closeInternal] + + // effective is the most recent version of the [setting.Snapshot] + // containing policy settings merged from all applicable sources. + effective atomic.Pointer[setting.Snapshot] + + changeCallbacks policyChangeCallbacks + + mu syncs.Mutex + watcherStarted bool // whether [Policy.watchReload] was started + sources source.ReadableSources + closing bool // whether [Policy.Close] was called (even if we're still closing) +} + +// newPolicy returns a new [Policy] for the specified [setting.PolicyScope] +// that tracks changes and merges policy settings read from the specified sources. +func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (_ *Policy, err error) { + readableSources := make(source.ReadableSources, 0, len(sources)) + defer func() { + if err != nil { + readableSources.Close() + } + }() + for _, s := range sources { + reader, err := s.Reader() + if err != nil { + return nil, fmt.Errorf("failed to get a store reader: %w", err) + } + session, err := reader.OpenSession() + if err != nil { + return nil, fmt.Errorf("failed to open a reading session: %w", err) + } + readableSources = append(readableSources, source.ReadableSource{Source: s, ReadingSession: session}) + } + + // Sort policy sources by their precedence from lower to higher. + // For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}. + readableSources.StableSort() + + p := &Policy{ + scope: scope, + sources: readableSources, + reloadCh: make(chan reloadRequest, 1), + closeCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + if _, err := p.reloadNow(false); err != nil { + p.Close() + return nil, err + } + p.startWatchReloadIfNeeded() + return p, nil +} + +// IsValid reports whether p is in a valid state and has not been closed. +// +// Since p's state can be changed by other goroutines at any time, this should +// only be used as an optimization. +func (p *Policy) IsValid() bool { + select { + case <-p.closeCh: + return false + default: + return true + } +} + +// Scope returns the [setting.PolicyScope] that this policy applies to. +func (p *Policy) Scope() setting.PolicyScope { + return p.scope +} + +// Get returns the effective [setting.Snapshot]. +func (p *Policy) Get() *setting.Snapshot { + return p.effective.Load() +} + +// RegisterChangeCallback adds a function to be called whenever the effective +// policy changes. The returned function can be used to unregister the callback. +func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) { + return p.changeCallbacks.Register(callback) +} + +// Reload synchronously re-reads policy settings from the underlying list of policy sources, +// constructing a new merged [setting.Snapshot] even if the policy remains unchanged. +// In most scenarios, there's no need to re-read the policy manually. +// Instead, it is recommended to register a policy change callback, or to use +// the most recent [setting.Snapshot] returned by the [Policy.Get] method. +// +// It must not be called with p.mu held. +func (p *Policy) Reload() (*setting.Snapshot, error) { + return p.reload(true) +} + +// reload is like Reload, but allows to specify whether to re-read policy settings +// from unchanged policy sources. +// +// It must not be called with p.mu held. +func (p *Policy) reload(force bool) (*setting.Snapshot, error) { + if !p.startWatchReloadIfNeeded() { + return p.Get(), nil + } + + respCh := make(chan reloadResponse, 1) + select { + case p.reloadCh <- reloadRequest{force: force, respCh: respCh}: + // continue + case <-p.closeCh: + return nil, ErrPolicyClosed + } + select { + case resp := <-respCh: + return resp.policy, resp.err + case <-p.closeCh: + return nil, ErrPolicyClosed + } +} + +// reloadAsync requests an asynchronous background policy reload. +// The policy will be reloaded no later than in [policyReloadMaxDelay]. +// +// It must not be called with p.mu held. +func (p *Policy) reloadAsync() { + if !p.startWatchReloadIfNeeded() { + return + } + select { + case p.reloadCh <- reloadRequest{}: + // Sent. + default: + // A reload request is already en route. + } +} + +// reloadNow loads and merges policies from all sources, updating the effective policy. +// If the force parameter is true, it forcibly reloads policies +// from the underlying policy store, even if no policy changes were detected. +// +// Except for the initial policy reload during the [Policy] creation, +// this method should only be called from the [Policy.watchReload] goroutine. +func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) { + new, err := p.readAndMerge(force) + if err != nil { + return nil, err + } + old := p.effective.Swap(new) + // A nil old value indicates the initial policy load rather than a policy change. + // Additionally, we should not invoke the policy change callbacks unless the + // policy items have actually changed. + if old != nil && !old.EqualItems(new) { + snapshots := Change[*setting.Snapshot]{New: new, Old: old} + p.changeCallbacks.Invoke(snapshots) + } + return new, nil +} + +// Done returns a channel that is closed when the [Policy] is closed. +func (p *Policy) Done() <-chan struct{} { + return p.doneCh +} + +// readAndMerge reads and merges policy settings from all applicable sources, +// returning a [setting.Snapshot] with the merged result. +// If the force parameter is true, it re-reads policy settings from each source +// even if no policy change was observed, and returns an error if the read +// operation fails. +func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) { + p.mu.Lock() + defer p.mu.Unlock() + // Start with an empty policy in the target scope. + effective := setting.NewSnapshot(nil, setting.SummaryWith(p.scope)) + // Then merge policy settings from all sources. + // Policy sources with the highest precedence (e.g., the device policy) are merged last, + // overriding any conflicting policy settings with lower precedence. + for _, s := range p.sources { + var policy *setting.Snapshot + if force { + var err error + if policy, err = s.ReadSettings(); err != nil { + return nil, err + } + } else { + policy = s.GetSettings() + } + effective = setting.MergeSnapshots(effective, policy) + } + return effective, nil +} + +// addSource adds the specified source to the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this effective policy, +// or if the effective policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) addSource(source *source.Source) error { + return p.applySourcesChange(source, nil) +} + +// removeSource removes the specified source from the list of sources used by p, +// and triggers a synchronous policy refresh. It returns an error if the +// effective policy is being closed, or if policy refresh fails with an error. +func (p *Policy) removeSource(source *source.Source) error { + return p.applySourcesChange(nil, source) +} + +// replaceSource replaces the old source with the new source atomically, +// and triggers a synchronous policy refresh. It returns an error +// if the source is not a valid source for this effective policy, +// or if the effective policy is being closed, +// or if policy refresh fails with an error. +func (p *Policy) replaceSource(old, new *source.Source) error { + return p.applySourcesChange(new, old) +} + +func (p *Policy) applySourcesChange(toAdd, toRemove *source.Source) error { + if toAdd == toRemove { + return nil + } + if toAdd != nil && !toAdd.Scope().Contains(p.scope) { + return errors.New("scope mismatch") + } + + changed, err := func() (changed bool, err error) { + p.mu.Lock() + defer p.mu.Unlock() + if toAdd != nil && !p.sources.Contains(toAdd) { + reader, err := toAdd.Reader() + if err != nil { + return false, fmt.Errorf("failed to get a store reader: %w", err) + } + session, err := reader.OpenSession() + if err != nil { + return false, fmt.Errorf("failed to open a reading session: %w", err) + } + + addAt := p.sources.InsertionIndexOf(toAdd) + toAdd := source.ReadableSource{ + Source: toAdd, + ReadingSession: session, + } + p.sources = slices.Insert(p.sources, addAt, toAdd) + go p.watchPolicyChanges(toAdd) + changed = true + } + if toRemove != nil { + if deleteAt := p.sources.IndexOf(toRemove); deleteAt != -1 { + p.sources.DeleteAt(deleteAt) + changed = true + } + } + return changed, nil + }() + if changed { + _, err = p.reload(false) + } + return err // may be nil or non-nil +} + +func (p *Policy) watchPolicyChanges(s source.ReadableSource) { + for { + select { + case _, ok := <-s.ReadingSession.PolicyChanged(): + if !ok { + p.mu.Lock() + abruptlyClosed := slices.Contains(p.sources, s) + p.mu.Unlock() + if abruptlyClosed { + // The underlying [source.Source] was closed abruptly without + // being properly removed or replaced by another policy source. + // We can't keep this [Policy] up to date, so we should close it. + p.Close() + } + return + } + // The PolicyChanged channel was signaled. + // Request an asynchronous policy reload. + p.reloadAsync() + case <-p.closeCh: + // The [Policy] is being closed. + return + } + } +} + +// startWatchReloadIfNeeded starts [Policy.watchReload] in a new goroutine +// if the list of policy sources is not empty, it hasn't been started yet, +// and the [Policy] is not being closed. +// It reports whether [Policy.watchReload] has ever been started. +// +// It must not be called with p.mu held. +func (p *Policy) startWatchReloadIfNeeded() bool { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.sources) != 0 && !p.watcherStarted && !p.closing { + go p.watchReload() + for i := range p.sources { + go p.watchPolicyChanges(p.sources[i]) + } + p.watcherStarted = true + } + return p.watcherStarted +} + +// reloadRequest describes a policy reload request. +type reloadRequest struct { + // force policy reload regardless of whether a policy change was detected. + force bool + // respCh is an optional channel. If non-nil, it makes the reload request + // synchronous and receives the result. + respCh chan<- reloadResponse +} + +// reloadResponse is a result of a synchronous policy reload. +type reloadResponse struct { + policy *setting.Snapshot + err error +} + +// watchReload processes incoming synchronous and asynchronous policy reload requests. +// +// Synchronous requests (with a non-nil respCh) are served immediately. +// +// Asynchronous requests are debounced and throttled: they are executed at least +// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay] +// after the first request in a batch. +func (p *Policy) watchReload() { + defer p.closeInternal() + + force := false // whether a forced refresh was requested + var delayCh, timeoutCh <-chan time.Time + reload := func(respCh chan<- reloadResponse) { + delayCh, timeoutCh = nil, nil + policy, err := p.reloadNow(force) + if err != nil { + loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err) + } + if respCh != nil { + respCh <- reloadResponse{policy: policy, err: err} + } + force = false + } + +loop: + for { + select { + case req := <-p.reloadCh: + if req.force { + force = true + } + if req.respCh != nil { + reload(req.respCh) + continue + } + if delayCh == nil { + timeoutCh = time.After(policyReloadMinDelay) + } + delayCh = time.After(policyReloadMaxDelay) + case <-delayCh: + reload(nil) + case <-timeoutCh: + reload(nil) + case <-p.closeCh: + break loop + } + } +} + +func (p *Policy) closeInternal() { + p.mu.Lock() + defer p.mu.Unlock() + p.sources.Close() + p.changeCallbacks.Close() + close(p.doneCh) + deletePolicy(p) +} + +// Close initiates the closing of the policy. +// The [Policy.Done] channel is closed to signal that the operation has been completed. +func (p *Policy) Close() { + p.mu.Lock() + alreadyClosing := p.closing + watcherStarted := p.watcherStarted + p.closing = true + p.mu.Unlock() + + if alreadyClosing { + return + } + + close(p.closeCh) + if !watcherStarted { + // Normally, closing p.closeCh signals [Policy.watchReload] to exit, + // and [Policy.closeInternal] performs the actual closing when + // [Policy.watchReload] returns. However, if the watcher was never + // started, we need to call [Policy.closeInternal] manually. + go p.closeInternal() + } +} + +func setForTest[T any](tb testenv.TB, target *T, newValue T) { + oldValue := *target + tb.Cleanup(func() { *target = oldValue }) + *target = newValue +} diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go new file mode 100644 index 000000000..3ff142119 --- /dev/null +++ b/util/syspolicy/rsop/resultant_policy_test.go @@ -0,0 +1,983 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "slices" + "sort" + "strconv" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tstest" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/setting" + + "tailscale.com/util/syspolicy/source" +) + +func TestGetEffectivePolicyNoSource(t *testing.T) { + tests := []struct { + name string + scope setting.PolicyScope + }{ + { + name: "DevicePolicy", + scope: setting.DeviceScope, + }, + { + name: "CurrentProfilePolicy", + scope: setting.CurrentProfileScope, + }, + { + name: "CurrentUserPolicy", + scope: setting.CurrentUserScope, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var policy *Policy + t.Cleanup(func() { + if policy != nil { + policy.Close() + <-policy.Done() + } + }) + + // Make sure we don't create any goroutines. + // We intentionally call ResourceCheck after t.Cleanup, so that when the test exits, + // the resource check runs before the test cleanup closes the policy. + // This helps to report any unexpectedly created goroutines. + // The goal is to ensure that using the syspolicy package, and particularly + // the rsop sub-package, is not wasteful and does not create unnecessary goroutines + // on platforms without registered policy sources. + tstest.ResourceCheck(t) + + policy, err := PolicyFor(tt.scope) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) + } + + if got := policy.Get(); got.Len() != 0 { + t.Errorf("Snapshot: got %v; want empty", got) + } + + if got, err := policy.Reload(); err != nil { + t.Errorf("Reload failed: %v", err) + } else if got.Len() != 0 { + t.Errorf("Snapshot: got %v; want empty", got) + } + }) + } +} + +func TestRegisterSourceAndGetEffectivePolicy(t *testing.T) { + type sourceConfig struct { + name string + scope setting.PolicyScope + settingKey pkey.Key + settingValue string + wantEffective bool + } + tests := []struct { + name string + scope setting.PolicyScope + initialSources []sourceConfig + additionalSources []sourceConfig + wantSnapshot *setting.Snapshot + }{ + { + name: "DevicePolicy/NoSources", + scope: setting.DeviceScope, + wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope), + }, + { + name: "UserScope/NoSources", + scope: setting.CurrentUserScope, + wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/OneInitialSource", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/OneAdditionalSource", + scope: setting.DeviceScope, + additionalSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + }, + { + name: "DevicePolicy/ManyInitialSources/NoConflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/ManyInitialSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "DevicePolicy/MixedSources/Conflicts", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceA", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueA", + wantEffective: true, + }, + { + name: "TestSourceB", + scope: setting.DeviceScope, + settingKey: "TestKeyB", + settingValue: "TestValueB", + wantEffective: true, + }, + { + name: "TestSourceC", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueC", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceD", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueD", + wantEffective: true, + }, + { + name: "TestSourceE", + scope: setting.DeviceScope, + settingKey: "TestKeyC", + settingValue: "TestValueE", + wantEffective: true, + }, + { + name: "TestSourceF", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "TestValueF", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), + "TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)), + }, setting.DeviceScope), + }, + { + name: "UserScope/Init-DeviceSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)), + }, setting.CurrentUserScope), + }, + { + name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource", + scope: setting.CurrentUserScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceProfile", + scope: setting.CurrentProfileScope, + settingKey: "TestKeyB", + settingValue: "ProfileValue", + wantEffective: true, + }, + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyB", + settingValue: "UserValue", + wantEffective: true, + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + "TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)), + }, setting.CurrentUserScope), + }, + { + name: "DevicePolicy/User-Source-does-not-apply", + scope: setting.DeviceScope, + initialSources: []sourceConfig{ + { + name: "TestSourceDevice", + scope: setting.DeviceScope, + settingKey: "TestKeyA", + settingValue: "DeviceValue", + wantEffective: true, + }, + }, + additionalSources: []sourceConfig{ + { + name: "TestSourceUser", + scope: setting.CurrentUserScope, + settingKey: "TestKeyA", + settingValue: "UserValue", + wantEffective: false, // Registering a user source should have no impact on the device policy. + }, + }, + wantSnapshot: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ + "TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Register all settings that we use in this test. + var definitions []*setting.Definition + for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) { + definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue)) + } + if err := setting.SetDefinitionsForTest(t, definitions...); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Add the initial policy sources. + var wantSources []*source.Source + for _, s := range tt.initialSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("Failed to register policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + // Retrieve the effective policy. + policy, err := policyForTest(t, tt.scope) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) + } + + checkPolicySources(t, policy, wantSources) + + // Add additional setting sources. + for _, s := range tt.additionalSources { + store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) + source := source.NewSource(s.name, s.scope, store) + if err := registerSource(source); err != nil { + t.Fatalf("Failed to register additional policy source: %v", source) + } + if s.wantEffective { + wantSources = append(wantSources, source) + } + t.Cleanup(func() { unregisterSource(source) }) + } + + checkPolicySources(t, policy, wantSources) + + // Verify the final effective settings snapshots. + if got := policy.Get(); !got.Equal(tt.wantSnapshot) { + t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot) + } + }) + } +} + +func TestPolicyFor(t *testing.T) { + tests := []struct { + name string + scopeA, scopeB setting.PolicyScope + closePolicy bool // indicates whether to close policyA before retrieving policyB + wantSame bool // specifies whether policyA and policyB should reference the same [Policy] instance + }{ + { + name: "Device/Device", + scopeA: setting.DeviceScope, + scopeB: setting.DeviceScope, + wantSame: true, + }, + { + name: "Device/CurrentProfile", + scopeA: setting.DeviceScope, + scopeB: setting.CurrentProfileScope, + wantSame: false, + }, + { + name: "Device/CurrentUser", + scopeA: setting.DeviceScope, + scopeB: setting.CurrentUserScope, + wantSame: false, + }, + { + name: "CurrentProfile/CurrentProfile", + scopeA: setting.CurrentProfileScope, + scopeB: setting.CurrentProfileScope, + wantSame: true, + }, + { + name: "CurrentProfile/CurrentUser", + scopeA: setting.CurrentProfileScope, + scopeB: setting.CurrentUserScope, + wantSame: false, + }, + { + name: "CurrentUser/CurrentUser", + scopeA: setting.CurrentUserScope, + scopeB: setting.CurrentUserScope, + wantSame: true, + }, + { + name: "UserA/UserA", + scopeA: setting.UserScopeOf("UserA"), + scopeB: setting.UserScopeOf("UserA"), + wantSame: true, + }, + { + name: "UserA/UserB", + scopeA: setting.UserScopeOf("UserA"), + scopeB: setting.UserScopeOf("UserB"), + wantSame: false, + }, + { + name: "New-after-close", + scopeA: setting.DeviceScope, + scopeB: setting.DeviceScope, + closePolicy: true, + wantSame: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policyA, err := policyForTest(t, tt.scopeA) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeA, err) + } + + if tt.closePolicy { + policyA.Close() + } + + policyB, err := policyForTest(t, tt.scopeB) + if err != nil { + t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeB, err) + } + + if gotSame := policyA == policyB; gotSame != tt.wantSame { + t.Fatalf("Got same: %v; want same %v", gotSame, tt.wantSame) + } + }) + } +} + +func TestPolicyChangeHasChanged(t *testing.T) { + tests := []struct { + name string + old, new map[pkey.Key]setting.RawItem + wantChanged []pkey.Key + wantUnchanged []pkey.Key + }{ + { + name: "String-Settings", + old: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf("Old"), + "UnchangedSetting": setting.RawItemOf("Value"), + }, + new: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf("New"), + "UnchangedSetting": setting.RawItemOf("Value"), + }, + wantChanged: []pkey.Key{"ChangedSetting"}, + wantUnchanged: []pkey.Key{"UnchangedSetting"}, + }, + { + name: "UInt64-Settings", + old: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(uint64(0)), + "UnchangedSetting": setting.RawItemOf(uint64(42)), + }, + new: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(uint64(1)), + "UnchangedSetting": setting.RawItemOf(uint64(42)), + }, + wantChanged: []pkey.Key{"ChangedSetting"}, + wantUnchanged: []pkey.Key{"UnchangedSetting"}, + }, + { + name: "StringSlice-Settings", + old: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf([]string{"Chicago"}), + "UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), + }, + new: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf([]string{"New York"}), + "UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), + }, + wantChanged: []pkey.Key{"ChangedSetting"}, + wantUnchanged: []pkey.Key{"UnchangedSetting"}, + }, + { + name: "Int8-Settings", // We don't have actual int8 settings, but this should still work. + old: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(int8(0)), + "UnchangedSetting": setting.RawItemOf(int8(42)), + }, + new: map[pkey.Key]setting.RawItem{ + "ChangedSetting": setting.RawItemOf(int8(1)), + "UnchangedSetting": setting.RawItemOf(int8(42)), + }, + wantChanged: []pkey.Key{"ChangedSetting"}, + wantUnchanged: []pkey.Key{"UnchangedSetting"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + old := setting.NewSnapshot(tt.old) + new := setting.NewSnapshot(tt.new) + change := PolicyChange{Change[*setting.Snapshot]{old, new}} + for _, wantChanged := range tt.wantChanged { + if !change.HasChanged(wantChanged) { + t.Errorf("%q changed: got false; want true", wantChanged) + } + } + for _, wantUnchanged := range tt.wantUnchanged { + if change.HasChanged(wantUnchanged) { + t.Errorf("%q unchanged: got true; want false", wantUnchanged) + } + } + }) + } +} + +func TestChangePolicySetting(t *testing.T) { + // Register policy settings used in this test. + settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Register a test policy store and create a effective policy that reads the policy settings from it. + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // The policy setting is not configured yet. + if _, ok := policy.Get().GetSetting(settingA.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) + } + + // Subscribe to the policy change callback... + policyChanged := make(chan policyclient.PolicyChange) + unregister := policy.RegisterChangeCallback(func(pc policyclient.PolicyChange) { policyChanged <- pc }) + t.Cleanup(unregister) + + // ...make the change, and measure the time between initiating the change + // and receiving the callback. + start := time.Now() + const wantValueA = "TestValueA" + store.SetStrings(source.TestSettingOf(settingA.Key(), wantValueA)) + change := (<-policyChanged).(*PolicyChange) + gotDelay := time.Since(start) + + // Ensure there is at least a [policyReloadMinDelay] delay between + // a change and the policy reload along with the callback invocation. + // This prevents reloading policy settings too frequently + // when multiple settings change within a short period of time. + if gotDelay < policyReloadMinDelay { + t.Errorf("Delay: got %v; want >= %v", gotDelay, policyReloadMinDelay) + } + + // Verify that the [PolicyChange] passed to the policy change callback + // contains the correct information regarding the policy setting changes. + if !change.HasChanged(settingA.Key()) { + t.Errorf("Policy setting %q has not changed", settingA.Key()) + } + if change.HasChanged(settingB.Key()) { + t.Errorf("Policy setting %q was unexpectedly changed", settingB.Key()) + } + if _, ok := change.Old().GetSetting(settingA.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) + } + if gotValue := change.New().Get(settingA.Key()); gotValue != wantValueA { + t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) + } + + // And also verify that the current (most recent) [setting.Snapshot] + // includes the change we just made. + if gotValue := policy.Get().Get(settingA.Key()); gotValue != wantValueA { + t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) + } + + // Now, let's change another policy setting value N times. + const N = 10 + wantValueB := strconv.Itoa(N) + start = time.Now() + for i := range N { + store.SetStrings(source.TestSettingOf(settingB.Key(), strconv.Itoa(i+1))) + } + + // The callback should be invoked only once, even though the policy setting + // has changed N times. + change = (<-policyChanged).(*PolicyChange) + gotDelay = time.Since(start) + gotCallbacks := 1 +drain: + for { + select { + case <-policyChanged: + gotCallbacks++ + case <-time.After(policyReloadMaxDelay): + break drain + } + } + if wantCallbacks := 1; gotCallbacks > wantCallbacks { + t.Errorf("Callbacks: got %d; want %d", gotCallbacks, wantCallbacks) + } + + // Additionally, the policy change callback should be received no sooner + // than [policyReloadMinDelay] and no later than [policyReloadMaxDelay]. + if gotDelay < policyReloadMinDelay || gotDelay > policyReloadMaxDelay { + t.Errorf("Delay: got %v; want >= %v && <= %v", gotDelay, policyReloadMinDelay, policyReloadMaxDelay) + } + + // Verify that the [PolicyChange] received via the callback + // contains the final policy setting value. + if !change.HasChanged(settingB.Key()) { + t.Errorf("Policy setting %q has not changed", settingB.Key()) + } + if change.HasChanged(settingA.Key()) { + t.Errorf("Policy setting %q was unexpectedly changed", settingA.Key()) + } + if _, ok := change.Old().GetSetting(settingB.Key()); ok { + t.Fatalf("Policy setting %q unexpectedly exists", settingB.Key()) + } + if gotValue := change.New().Get(settingB.Key()); gotValue != wantValueB { + t.Errorf("Policy setting %q: got %q; want %q", settingB.Key(), gotValue, wantValueB) + } + + // Lastly, if a policy store issues a change notification, but the effective policy + // remains unchanged, the [Policy] should ignore it without invoking the change callbacks. + store.NotifyPolicyChanged() + select { + case <-policyChanged: + t.Fatal("Unexpected policy changed notification") + case <-time.After(policyReloadMaxDelay): + } +} + +func TestClosePolicySource(t *testing.T) { + testSetting := setting.NewDefinition("TestSetting", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + wantSettingValue := "TestValue" + store := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), wantSettingValue)) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + initialSnapshot, err := policy.Reload() + if err != nil { + t.Fatalf("Failed to reload policy: %v", err) + } + if gotSettingValue, err := initialSnapshot.GetErr(testSetting.Key()); err != nil { + t.Fatalf("Failed to get %q setting value: %v", testSetting.Key(), err) + } else if gotSettingValue != wantSettingValue { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) + } + + store.Close() + + // Closing a policy source abruptly without removing it first should invalidate and close the policy. + <-policy.Done() + if policy.IsValid() { + t.Fatal("The policy was not properly closed") + } + + // The resulting policy snapshot should remain valid and unchanged. + finalSnapshot := policy.Get() + if !finalSnapshot.Equal(initialSnapshot) { + t.Fatal("Policy snapshot has changed") + } + if gotSettingValue, err := finalSnapshot.GetErr(testSetting.Key()); err != nil { + t.Fatalf("Failed to get final %q setting value: %v", testSetting.Key(), err) + } else if gotSettingValue != wantSettingValue { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) + } + + // However, any further requests to reload the policy should fail. + if _, err := policy.Reload(); err == nil || !errors.Is(err, ErrPolicyClosed) { + t.Fatalf("Reload: gotErr: %v; wantErr: %v", err, ErrPolicyClosed) + } +} + +func TestRemovePolicySource(t *testing.T) { + // Register policy settings used in this test. + settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Register two policy stores. + storeA := source.NewTestStoreOf(t, source.TestSettingOf(settingA.Key(), "A")) + storeRegA, err := RegisterStoreForTest(t, "TestSourceA", setting.DeviceScope, storeA) + if err != nil { + t.Fatalf("Failed to register policy store A: %v", err) + } + storeB := source.NewTestStoreOf(t, source.TestSettingOf(settingB.Key(), "B")) + storeRegB, err := RegisterStoreForTest(t, "TestSourceB", setting.DeviceScope, storeB) + if err != nil { + t.Fatalf("Failed to register policy store A: %v", err) + } + + // Create a effective [Policy] that reads policy settings from the two stores. + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // Verify that the [Policy] uses both stores and includes policy settings from each. + if gotSources, wantSources := len(policy.sources), 2; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got, want := policy.Get().Get(settingA.Key()), "A"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) + } + if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) + } + + // Unregister Store A and verify that the effective policy remains valid. + // It should no longer use the removed store or include any policy settings from it. + if err := storeRegA.Unregister(); err != nil { + t.Fatalf("Failed to unregister Store A: %v", err) + } + if !policy.IsValid() { + t.Fatalf("Policy was unexpectedly closed") + } + if gotSources, wantSources := len(policy.sources), 1; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got, want := policy.Get().Get(settingA.Key()), any(nil); got != want { + t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) + } + if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { + t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) + } + + // Unregister Store B and verify that the effective policy is still valid. + // However, it should be empty since there are no associated policy sources. + if err := storeRegB.Unregister(); err != nil { + t.Fatalf("Failed to unregister Store B: %v", err) + } + if !policy.IsValid() { + t.Fatalf("Policy was unexpectedly closed") + } + if gotSources, wantSources := len(policy.sources), 0; gotSources != wantSources { + t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) + } + if got := policy.Get(); got.Len() != 0 { + t.Fatalf("Settings: got %v; want {Empty}", got) + } +} + +func TestReplacePolicySource(t *testing.T) { + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + + // Register policy settings used in this test. + testSetting := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) + if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + // Create two policy stores. + initialStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "InitialValue")) + newStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) + unchangedStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) + + // Register the initial store and create a effective [Policy] that reads policy settings from it. + reg, err := RegisterStoreForTest(t, "TestStore", setting.DeviceScope, initialStore) + if err != nil { + t.Fatalf("Failed to register the initial store: %v", err) + } + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("Failed to get effective policy: %v", err) + } + + // Verify that the test setting has its initial value. + if got, want := policy.Get().Get(testSetting.Key()), "InitialValue"; got != want { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) + } + + // Subscribe to the policy change callback. + policyChanged := make(chan policyclient.PolicyChange, 1) + unregister := policy.RegisterChangeCallback(func(pc policyclient.PolicyChange) { policyChanged <- pc }) + t.Cleanup(unregister) + + // Now, let's replace the initial store with the new store. + reg, err = reg.ReplaceStore(newStore) + if err != nil { + t.Fatalf("Failed to replace the policy store: %v", err) + } + t.Cleanup(func() { reg.Unregister() }) + + // We should receive a policy change notification as the setting value has changed. + <-policyChanged + + // Verify that the test setting has the new value. + if got, want := policy.Get().Get(testSetting.Key()), "NewValue"; got != want { + t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) + } + + // Replacing a policy store with an identical one containing the same + // values for the same settings should not be considered a policy change. + reg, err = reg.ReplaceStore(unchangedStore) + if err != nil { + t.Fatalf("Failed to replace the policy store: %v", err) + } + t.Cleanup(func() { reg.Unregister() }) + + select { + case <-policyChanged: + t.Fatal("Unexpected policy changed notification") + default: + <-time.After(policyReloadMaxDelay) + } +} + +func TestAddClosedPolicySource(t *testing.T) { + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + store.Close() + + _, err := policyForTest(t, setting.DeviceScope) + if err == nil || !errors.Is(err, source.ErrStoreClosed) { + t.Fatalf("got: %v; want: %v", err, source.ErrStoreClosed) + } +} + +func TestClosePolicyMoreThanOnce(t *testing.T) { + tests := []struct { + name string + numSources int + }{ + { + name: "NoSources", + numSources: 0, + }, + { + name: "OneSource", + numSources: 1, + }, + { + name: "ManySources", + numSources: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i := range tt.numSources { + store := source.NewTestStoreOf[string](t) + if _, err := RegisterStoreForTest(t, "TestSource #"+strconv.Itoa(i), setting.DeviceScope, store); err != nil { + t.Fatalf("Failed to register policy store: %v", err) + } + } + + policy, err := policyForTest(t, setting.DeviceScope) + if err != nil { + t.Fatalf("failed to get effective policy: %v", err) + } + + const N = 10000 + var wg sync.WaitGroup + for range N { + wg.Add(1) + go func() { + wg.Done() + policy.Close() + <-policy.Done() + }() + } + wg.Wait() + }) + } +} + +func checkPolicySources(tb testing.TB, gotPolicy *Policy, wantSources []*source.Source) { + tb.Helper() + sort.SliceStable(wantSources, func(i, j int) bool { + return wantSources[i].Compare(wantSources[j]) < 0 + }) + gotSources := make([]*source.Source, len(gotPolicy.sources)) + for i := range gotPolicy.sources { + gotSources[i] = gotPolicy.sources[i].Source + } + type sourceSummary struct{ Name, Scope string } + toSourceSummary := cmp.Transformer("source", func(s *source.Source) sourceSummary { return sourceSummary{s.Name(), s.Scope().String()} }) + if diff := cmp.Diff(wantSources, gotSources, toSourceSummary, cmpopts.EquateEmpty()); diff != "" { + tb.Errorf("Policy Sources mismatch: %v", diff) + } +} + +// policyForTest is like [PolicyFor], but it deletes the policy +// when tb and all its subtests complete. +func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { + tb.Helper() + + policy, err := PolicyFor(target) + if err != nil { + return nil, err + } + tb.Cleanup(func() { + policy.Close() + <-policy.Done() + deletePolicy(policy) + }) + return policy, nil +} diff --git a/util/syspolicy/rsop/rsop.go b/util/syspolicy/rsop/rsop.go new file mode 100644 index 000000000..333dca643 --- /dev/null +++ b/util/syspolicy/rsop/rsop.go @@ -0,0 +1,173 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package rsop facilitates [source.Store] registration via [RegisterStore] +// and provides access to the effective policy merged from all registered sources +// via [PolicyFor]. +package rsop + +import ( + "errors" + "fmt" + "slices" + + "tailscale.com/syncs" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" +) + +var ( + policyMu syncs.Mutex // protects [policySources] and [effectivePolicies] + policySources []*source.Source // all registered policy sources + effectivePolicies []*Policy // all active (non-closed) effective policies returned by [PolicyFor] + + // effectivePolicyLRU is an LRU cache of [Policy] by [setting.Scope]. + // Although there could be multiple [setting.PolicyScope] instances with the same [setting.Scope], + // such as two user scopes for different users, there is only one [setting.DeviceScope], only one + // [setting.CurrentProfileScope], and in most cases, only one active user scope. + // Therefore, cache misses that require falling back to [effectivePolicies] are extremely rare. + // It's a fixed-size array of atomic values and can be accessed without [policyMu] held. + effectivePolicyLRU [setting.NumScopes]syncs.AtomicValue[*Policy] +) + +// PolicyFor returns the [Policy] for the specified scope, +// creating it from the registered [source.Store]s if it doesn't already exist. +func PolicyFor(scope setting.PolicyScope) (*Policy, error) { + if err := internal.Init.Do(); err != nil { + return nil, err + } + policy := effectivePolicyLRU[scope.Kind()].Load() + if policy != nil && policy.Scope() == scope && policy.IsValid() { + return policy, nil + } + return policyForSlow(scope) +} + +func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) { + defer func() { + // Always update the LRU cache on exit if we found (or created) + // a policy for the specified scope. + if policy != nil { + effectivePolicyLRU[scope.Kind()].Store(policy) + } + }() + + policyMu.Lock() + defer policyMu.Unlock() + if policy, ok := findPolicyByScopeLocked(scope); ok { + return policy, nil + } + + // If there is no existing effective policy for the specified scope, + // we need to create one using the policy sources registered for that scope. + sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool { + return source.Scope().Contains(scope) + }) + policy, err = newPolicy(scope, sources...) + if err != nil { + return nil, err + } + effectivePolicies = append(effectivePolicies, policy) + return policy, nil +} + +// findPolicyByScopeLocked returns a policy with the specified scope and true if +// one exists in the [effectivePolicies] list, otherwise it returns nil, false. +// [policyMu] must be held. +func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) { + for _, policy := range effectivePolicies { + if policy.Scope() == target && policy.IsValid() { + return policy, true + } + } + return nil, false +} + +// deletePolicy deletes the specified effective policy from [effectivePolicies] +// and [effectivePolicyLRU]. +func deletePolicy(policy *Policy) { + policyMu.Lock() + defer policyMu.Unlock() + if i := slices.Index(effectivePolicies, policy); i != -1 { + effectivePolicies = slices.Delete(effectivePolicies, i, i+1) + } + effectivePolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil) +} + +// registerSource registers the specified [source.Source] to be used by the package. +// It updates existing [Policy]s returned by [PolicyFor] to use this source if +// they are within the source's [setting.PolicyScope]. +func registerSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + if slices.Contains(policySources, source) { + // already registered + return nil + } + policySources = append(policySources, source) + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !source.Scope().Contains(policy.Scope()) { + // Policy settings in the specified source do not apply + // to the scope of this effective policy. + // For example, a user policy source is being registered + // while the effective policy is for the device (or another user). + return nil + } + return policy.addSource(source) + }) +} + +// replaceSource is like [unregisterSource](old) followed by [registerSource](new), +// but performed atomically: the effective policy will contain settings +// either from the old source or the new source, never both and never neither. +func replaceSource(old, new *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + oldIndex := slices.Index(policySources, old) + if oldIndex == -1 { + return fmt.Errorf("the source is not registered: %v", old) + } + policySources[oldIndex] = new + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !old.Scope().Contains(policy.Scope()) || !new.Scope().Contains(policy.Scope()) { + return nil + } + return policy.replaceSource(old, new) + }) +} + +// unregisterSource unregisters the specified [source.Source], +// so that it won't be used by any new or existing [Policy]. +func unregisterSource(source *source.Source) error { + policyMu.Lock() + defer policyMu.Unlock() + index := slices.Index(policySources, source) + if index == -1 { + return nil + } + policySources = slices.Delete(policySources, index, index+1) + return forEachEffectivePolicyLocked(func(policy *Policy) error { + if !source.Scope().Contains(policy.Scope()) { + return nil + } + return policy.removeSource(source) + }) +} + +// forEachEffectivePolicyLocked calls fn for every non-closed [Policy] in [effectivePolicies]. +// It accumulates the returned errors and returns an error that wraps all errors returned by fn. +// The [policyMu] mutex must be held while this function is executed. +func forEachEffectivePolicyLocked(fn func(p *Policy) error) error { + var errs []error + for _, policy := range effectivePolicies { + if policy.IsValid() { + err := fn(policy) + if err != nil && !errors.Is(err, ErrPolicyClosed) { + errs = append(errs, err) + } + } + } + return errors.Join(errs...) +} diff --git a/util/syspolicy/rsop/store_registration.go b/util/syspolicy/rsop/store_registration.go new file mode 100644 index 000000000..a7c354b6d --- /dev/null +++ b/util/syspolicy/rsop/store_registration.go @@ -0,0 +1,98 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package rsop + +import ( + "errors" + "sync" + "sync/atomic" + "time" + + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" +) + +// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore] +// or [StoreRegistration.Unregister] is called more than once. +var ErrAlreadyConsumed = errors.New("the store registration is no longer valid") + +// StoreRegistration is a [source.Store] registered for use in the specified scope. +// It can be used to unregister the store, or replace it with another one. +type StoreRegistration struct { + source *source.Source + m sync.Mutex // protects the [StoreRegistration.consumeSlow] path + consumed atomic.Bool // can be read without holding m, but must be written with m held +} + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + return newStoreRegistration(name, scope, store) +} + +// RegisterStoreForTest is like [RegisterStore], but unregisters the store when +// tb and all its subtests complete. +func RegisterStoreForTest(tb testenv.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + setForTest(tb, &policyReloadMinDelay, 10*time.Millisecond) + setForTest(tb, &policyReloadMaxDelay, 500*time.Millisecond) + + reg, err := RegisterStore(name, scope, store) + if err == nil { + tb.Cleanup(func() { + if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) { + tb.Fatalf("Unregister failed: %v", err) + } + }) + } + return reg, err // may be nil or non-nil +} + +func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + source := source.NewSource(name, scope, store) + if err := registerSource(source); err != nil { + return nil, err + } + return &StoreRegistration{source: source}, nil +} + +// ReplaceStore replaces the registered store with the new one, +// returning a new [StoreRegistration] or an error. +func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) { + var res *StoreRegistration + err := r.consume(func() error { + newSource := source.NewSource(r.source.Name(), r.source.Scope(), new) + if err := replaceSource(r.source, newSource); err != nil { + return err + } + res = &StoreRegistration{source: newSource} + return nil + }) + return res, err +} + +// Unregister reverts the registration. +func (r *StoreRegistration) Unregister() error { + return r.consume(func() error { return unregisterSource(r.source) }) +} + +// consume invokes fn, consuming r if no error is returned. +// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call. +func (r *StoreRegistration) consume(fn func() error) (err error) { + if r.consumed.Load() { + return ErrAlreadyConsumed + } + return r.consumeSlow(fn) +} + +func (r *StoreRegistration) consumeSlow(fn func() error) (err error) { + r.m.Lock() + defer r.m.Unlock() + if r.consumed.Load() { + return ErrAlreadyConsumed + } + if err = fn(); err == nil { + r.consumed.Store(true) + } + return err // may be nil or non-nil +} diff --git a/util/syspolicy/setting/key.go b/util/syspolicy/setting/key.go deleted file mode 100644 index 406fde132..000000000 --- a/util/syspolicy/setting/key.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package setting - -// Key is a string that uniquely identifies a policy and must remain unchanged -// once established and documented for a given policy setting. It may contain -// alphanumeric characters and zero or more [KeyPathSeparator]s to group -// individual policy settings into categories. -type Key string - -// KeyPathSeparator allows logical grouping of policy settings into categories. -const KeyPathSeparator = "/" diff --git a/util/syspolicy/setting/origin.go b/util/syspolicy/setting/origin.go index 078ef758e..4c7cc7025 100644 --- a/util/syspolicy/setting/origin.go +++ b/util/syspolicy/setting/origin.go @@ -50,22 +50,27 @@ func (s Origin) String() string { return s.Scope().String() } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return jsonv2.MarshalEncode(out, &s.data, opts) +var ( + _ jsonv2.MarshalerTo = (*Origin)(nil) + _ jsonv2.UnmarshalerFrom = (*Origin)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (s Origin) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &s.data) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { - return jsonv2.UnmarshalDecode(in, &s.data, opts) +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (s *Origin) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &s.data) } // MarshalJSON implements [json.Marshaler]. func (s Origin) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(s) // uses MarshalJSONV2 + return jsonv2.Marshal(s) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (s *Origin) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONFrom } diff --git a/util/syspolicy/setting/policy_scope.go b/util/syspolicy/setting/policy_scope.go index 55fa339e7..c2039fdda 100644 --- a/util/syspolicy/setting/policy_scope.go +++ b/util/syspolicy/setting/policy_scope.go @@ -8,6 +8,7 @@ import ( "strings" "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal" ) var ( @@ -35,6 +36,8 @@ type PolicyScope struct { // when querying policy settings. // It returns [DeviceScope], unless explicitly changed with [SetDefaultScope]. func DefaultScope() PolicyScope { + // Allow deferred package init functions to override the default scope. + internal.Init.Do() return lazyDefaultScope.Get(func() PolicyScope { return DeviceScope }) } diff --git a/util/syspolicy/setting/raw_item.go b/util/syspolicy/setting/raw_item.go index 30480d892..ea97865f5 100644 --- a/util/syspolicy/setting/raw_item.go +++ b/util/syspolicy/setting/raw_item.go @@ -5,8 +5,13 @@ package setting import ( "fmt" + "reflect" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/types/opt" "tailscale.com/types/structs" + "tailscale.com/util/syspolicy/pkey" ) // RawItem contains a raw policy setting value as read from a policy store, or an @@ -17,10 +22,15 @@ import ( // or converted from strings, these setting types predate the typed policy // hierarchies, and must be supported at this layer. type RawItem struct { - _ structs.Incomparable - value any - err *ErrorText - origin *Origin // or nil + _ structs.Incomparable + data rawItemJSON +} + +// rawItemJSON holds JSON-marshallable data for [RawItem]. +type rawItemJSON struct { + Value RawValue `json:",omitzero"` + Error *ErrorText `json:",omitzero"` // or nil + Origin *Origin `json:",omitzero"` // or nil } // RawItemOf returns a [RawItem] with the specified value. @@ -30,20 +40,20 @@ func RawItemOf(value any) RawItem { // RawItemWith returns a [RawItem] with the specified value, error and origin. func RawItemWith(value any, err *ErrorText, origin *Origin) RawItem { - return RawItem{value: value, err: err, origin: origin} + return RawItem{data: rawItemJSON{Value: RawValue{opt.ValueOf(value)}, Error: err, Origin: origin}} } // Value returns the value of the policy setting, or nil if the policy setting // is not configured, or an error occurred while reading it. func (i RawItem) Value() any { - return i.value + return i.data.Value.Get() } // Error returns the error that occurred when reading the policy setting, // or nil if no error occurred. func (i RawItem) Error() error { - if i.err != nil { - return i.err + if i.data.Error != nil { + return i.data.Error } return nil } @@ -51,17 +61,113 @@ func (i RawItem) Error() error { // Origin returns an optional [Origin] indicating where the policy setting is // configured. func (i RawItem) Origin() *Origin { - return i.origin + return i.data.Origin } // String implements [fmt.Stringer]. func (i RawItem) String() string { var suffix string - if i.origin != nil { - suffix = fmt.Sprintf(" - {%v}", i.origin) + if i.data.Origin != nil { + suffix = fmt.Sprintf(" - {%v}", i.data.Origin) + } + if i.data.Error != nil { + return fmt.Sprintf("Error{%q}%s", i.data.Error.Error(), suffix) + } + return fmt.Sprintf("%v%s", i.data.Value.Value, suffix) +} + +var ( + _ jsonv2.MarshalerTo = (*RawItem)(nil) + _ jsonv2.UnmarshalerFrom = (*RawItem)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (i RawItem) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &i.data) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (i *RawItem) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &i.data) +} + +// MarshalJSON implements [json.Marshaler]. +func (i RawItem) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(i) // uses MarshalJSONTo +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (i *RawItem) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, i) // uses UnmarshalJSONFrom +} + +// RawValue represents a raw policy setting value read from a policy store. +// It is JSON-marshallable and facilitates unmarshalling of JSON values +// into corresponding policy setting types, with special handling for JSON numbers +// (unmarshalled as float64) and JSON string arrays (unmarshalled as []string). +// See also [RawValue.UnmarshalJSONFrom]. +type RawValue struct { + opt.Value[any] +} + +// RawValueType is a constraint that permits raw setting value types. +type RawValueType interface { + bool | uint64 | string | []string +} + +// RawValueOf returns a new [RawValue] holding the specified value. +func RawValueOf[T RawValueType](v T) RawValue { + return RawValue{opt.ValueOf[any](v)} +} + +var ( + _ jsonv2.MarshalerTo = (*RawValue)(nil) + _ jsonv2.UnmarshalerFrom = (*RawValue)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v RawValue) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, v.Value) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom] by attempting to unmarshal +// a JSON value as one of the supported policy setting value types (bool, string, uint64, or []string), +// based on the JSON value type. It fails if the JSON value is an object, if it's a JSON number that +// cannot be represented as a uint64, or if a JSON array contains anything other than strings. +func (v *RawValue) UnmarshalJSONFrom(in *jsontext.Decoder) error { + var valPtr any + switch k := in.PeekKind(); k { + case 't', 'f': + valPtr = new(bool) + case '"': + valPtr = new(string) + case '0': + valPtr = new(uint64) // unmarshal JSON numbers as uint64 + case '[', 'n': + valPtr = new([]string) // unmarshal arrays as string slices + case '{': + return fmt.Errorf("unexpected token: %v", k) + default: + panic("unreachable") } - if i.err != nil { - return fmt.Sprintf("Error{%q}%s", i.err.Error(), suffix) + if err := jsonv2.UnmarshalDecode(in, valPtr); err != nil { + v.Value.Clear() + return err } - return fmt.Sprintf("%v%s", i.value, suffix) + value := reflect.ValueOf(valPtr).Elem().Interface() + v.Value = opt.ValueOf(value) + return nil +} + +// MarshalJSON implements [json.Marshaler]. +func (v RawValue) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(v) // uses MarshalJSONTo +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (v *RawValue) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, v) // uses UnmarshalJSONFrom } + +// RawValues is a map of keyed setting values that can be read from a JSON. +type RawValues map[pkey.Key]RawValue diff --git a/util/syspolicy/setting/raw_item_test.go b/util/syspolicy/setting/raw_item_test.go new file mode 100644 index 000000000..05562d78c --- /dev/null +++ b/util/syspolicy/setting/raw_item_test.go @@ -0,0 +1,101 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "math" + "reflect" + "strconv" + "testing" + + jsonv2 "github.com/go-json-experiment/json" +) + +func TestMarshalUnmarshalRawValue(t *testing.T) { + tests := []struct { + name string + json string + want RawValue + wantErr bool + }{ + { + name: "Bool/True", + json: `true`, + want: RawValueOf(true), + }, + { + name: "Bool/False", + json: `false`, + want: RawValueOf(false), + }, + { + name: "String/Empty", + json: `""`, + want: RawValueOf(""), + }, + { + name: "String/NonEmpty", + json: `"Test"`, + want: RawValueOf("Test"), + }, + { + name: "StringSlice/Null", + json: `null`, + want: RawValueOf([]string(nil)), + }, + { + name: "StringSlice/Empty", + json: `[]`, + want: RawValueOf([]string{}), + }, + { + name: "StringSlice/NonEmpty", + json: `["A", "B", "C"]`, + want: RawValueOf([]string{"A", "B", "C"}), + }, + { + name: "StringSlice/NonStrings", + json: `[1, 2, 3]`, + wantErr: true, + }, + { + name: "Number/Integer/0", + json: `0`, + want: RawValueOf(uint64(0)), + }, + { + name: "Number/Integer/1", + json: `1`, + want: RawValueOf(uint64(1)), + }, + { + name: "Number/Integer/MaxUInt64", + json: strconv.FormatUint(math.MaxUint64, 10), + want: RawValueOf(uint64(math.MaxUint64)), + }, + { + name: "Number/Integer/Negative", + json: `-1`, + wantErr: true, + }, + { + name: "Object", + json: `{}`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got RawValue + gotErr := jsonv2.Unmarshal([]byte(tt.json), &got) + if (gotErr != nil) != tt.wantErr { + t.Fatalf("Error: got %v; want %v", gotErr, tt.wantErr) + } + + if !tt.wantErr && !reflect.DeepEqual(got, tt.want) { + t.Fatalf("Value: got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go index 93be287b1..97362b1dc 100644 --- a/util/syspolicy/setting/setting.go +++ b/util/syspolicy/setting/setting.go @@ -11,11 +11,14 @@ import ( "fmt" "slices" "strings" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/types/lazy" "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" + "tailscale.com/util/testenv" ) // Scope indicates the broadest scope at which a policy setting may apply, @@ -128,12 +131,12 @@ func (t Type) String() string { // ValueType is a constraint that allows Go types corresponding to [Type]. type ValueType interface { - bool | uint64 | string | []string | Visibility | PreferenceOption | time.Duration + bool | uint64 | string | []string | ptype.Visibility | ptype.PreferenceOption | time.Duration } // Definition defines policy key, scope and value type. type Definition struct { - key Key + key pkey.Key scope Scope typ Type platforms PlatformList @@ -141,12 +144,12 @@ type Definition struct { // NewDefinition returns a new [Definition] with the specified // key, scope, type and supported platforms (see [PlatformList]). -func NewDefinition(k Key, s Scope, t Type, platforms ...string) *Definition { +func NewDefinition(k pkey.Key, s Scope, t Type, platforms ...string) *Definition { return &Definition{key: k, scope: s, typ: t, platforms: platforms} } // Key returns a policy setting's identifier. -func (d *Definition) Key() Key { +func (d *Definition) Key() pkey.Key { if d == nil { return "" } @@ -207,12 +210,12 @@ func (d *Definition) Equal(d2 *Definition) bool { } // DefinitionMap is a map of setting [Definition] by [Key]. -type DefinitionMap map[Key]*Definition +type DefinitionMap map[pkey.Key]*Definition var ( definitions lazy.SyncValue[DefinitionMap] - definitionsMu sync.Mutex + definitionsMu syncs.Mutex definitionsList []*Definition definitionsUsed bool ) @@ -223,7 +226,7 @@ var ( // invoking any functions that use the registered policy definitions. This // includes calling [Definitions] or [DefinitionOf] directly, or reading any // policy settings via syspolicy. -func Register(k Key, s Scope, t Type, platforms ...string) { +func Register(k pkey.Key, s Scope, t Type, platforms ...string) { RegisterDefinition(NewDefinition(k, s, t, platforms...)) } @@ -243,6 +246,9 @@ func registerLocked(d *Definition) { func settingDefinitions() (DefinitionMap, error) { return definitions.GetErr(func() (DefinitionMap, error) { + if err := internal.Init.Do(); err != nil { + return nil, err + } definitionsMu.Lock() defer definitionsMu.Unlock() definitionsUsed = true @@ -274,7 +280,7 @@ func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) { // for the test duration. It is not concurrency-safe, but unlike [Register], // it does not panic and can be called anytime. // It returns an error if ds contains two different settings with the same [Key]. -func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error { +func SetDefinitionsForTest(tb testenv.TB, ds ...*Definition) error { m, err := DefinitionMapOf(ds) if err != nil { return err @@ -286,7 +292,7 @@ func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error { // DefinitionOf returns a setting definition by key, // or [ErrNoSuchKey] if the specified key does not exist, // or an error if there are conflicting policy definitions. -func DefinitionOf(k Key) (*Definition, error) { +func DefinitionOf(k pkey.Key) (*Definition, error) { ds, err := settingDefinitions() if err != nil { return nil, err @@ -316,33 +322,33 @@ func Definitions() ([]*Definition, error) { type PlatformList []string // Has reports whether l contains the target platform. -func (l PlatformList) Has(target string) bool { - if len(l) == 0 { +func (ls PlatformList) Has(target string) bool { + if len(ls) == 0 { return true } - return slices.ContainsFunc(l, func(os string) bool { + return slices.ContainsFunc(ls, func(os string) bool { return strings.EqualFold(os, target) }) } // HasCurrent is like Has, but for the current platform. -func (l PlatformList) HasCurrent() bool { - return l.Has(internal.OS()) +func (ls PlatformList) HasCurrent() bool { + return ls.Has(internal.OS()) } // mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions, // if either l or l2 is empty, the merged result in l will also be empty. -func (l *PlatformList) mergeFrom(l2 PlatformList) { +func (ls *PlatformList) mergeFrom(l2 PlatformList) { switch { - case len(*l) == 0: + case len(*ls) == 0: // No-op. An empty list indicates no platform restrictions. case len(l2) == 0: // Merging with an empty list results in an empty list. - *l = l2 + *ls = l2 default: // Append, sort and dedup. - *l = append(*l, l2...) - slices.Sort(*l) - *l = slices.Compact(*l) + *ls = append(*ls, l2...) + slices.Sort(*ls) + *ls = slices.Compact(*ls) } } diff --git a/util/syspolicy/setting/setting_test.go b/util/syspolicy/setting/setting_test.go index 3cc08e7da..9d99884f6 100644 --- a/util/syspolicy/setting/setting_test.go +++ b/util/syspolicy/setting/setting_test.go @@ -11,6 +11,7 @@ import ( "tailscale.com/types/lazy" "tailscale.com/types/ptr" "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/pkey" ) func TestSettingDefinition(t *testing.T) { @@ -18,7 +19,7 @@ func TestSettingDefinition(t *testing.T) { name string setting *Definition osOverride string - wantKey Key + wantKey pkey.Key wantScope Scope wantType Type wantIsSupported bool @@ -163,10 +164,10 @@ func TestSettingDefinition(t *testing.T) { } func TestRegisterSettingDefinition(t *testing.T) { - const testPolicySettingKey Key = "TestPolicySetting" + const testPolicySettingKey pkey.Key = "TestPolicySetting" tests := []struct { name string - key Key + key pkey.Key wantEq *Definition wantErr error }{ @@ -310,8 +311,8 @@ func TestListSettingDefinitions(t *testing.T) { t.Fatalf("SetDefinitionsForTest failed: %v", err) } - cmp := func(l, r *Definition) int { - return strings.Compare(string(l.Key()), string(r.Key())) + cmp := func(a, b *Definition) int { + return strings.Compare(string(a.Key()), string(b.Key())) } want := append([]*Definition{}, definitions...) slices.SortFunc(want, cmp) diff --git a/util/syspolicy/setting/snapshot.go b/util/syspolicy/setting/snapshot.go index 306bf759e..94c7ecadb 100644 --- a/util/syspolicy/setting/snapshot.go +++ b/util/syspolicy/setting/snapshot.go @@ -4,41 +4,46 @@ package setting import ( + "errors" + "iter" + "maps" "slices" "strings" + "time" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" xmaps "golang.org/x/exp/maps" "tailscale.com/util/deephash" + "tailscale.com/util/syspolicy/pkey" ) // Snapshot is an immutable collection of ([Key], [RawItem]) pairs, representing // a set of policy settings applied at a specific moment in time. // A nil pointer to [Snapshot] is valid. type Snapshot struct { - m map[Key]RawItem + m map[pkey.Key]RawItem sig deephash.Sum // of m summary Summary } // NewSnapshot returns a new [Snapshot] with the specified items and options. -func NewSnapshot(items map[Key]RawItem, opts ...SummaryOption) *Snapshot { +func NewSnapshot(items map[pkey.Key]RawItem, opts ...SummaryOption) *Snapshot { return &Snapshot{m: xmaps.Clone(items), sig: deephash.Hash(&items), summary: SummaryWith(opts...)} } -// All returns a map of all policy settings in s. -// The returned map must not be modified. -func (s *Snapshot) All() map[Key]RawItem { +// All returns an iterator over policy settings in s. The iteration order is not +// specified and is not guaranteed to be the same from one call to the next. +func (s *Snapshot) All() iter.Seq2[pkey.Key, RawItem] { if s == nil { - return nil + return func(yield func(pkey.Key, RawItem) bool) {} } - // TODO(nickkhyl): return iter.Seq2[[Key], [RawItem]] in Go 1.23, - // and remove [keyItemPair]. - return s.m + return maps.All(s.m) } // Get returns the value of the policy setting with the specified key // or nil if it is not configured or has an error. -func (s *Snapshot) Get(k Key) any { +func (s *Snapshot) Get(k pkey.Key) any { v, _ := s.GetErr(k) return v } @@ -46,7 +51,7 @@ func (s *Snapshot) Get(k Key) any { // GetErr returns the value of the policy setting with the specified key, // [ErrNotConfigured] if it is not configured, or an error returned by // the policy Store if the policy setting could not be read. -func (s *Snapshot) GetErr(k Key) (any, error) { +func (s *Snapshot) GetErr(k pkey.Key) (any, error) { if s != nil { if s, ok := s.m[k]; ok { return s.Value(), s.Error() @@ -58,13 +63,16 @@ func (s *Snapshot) GetErr(k Key) (any, error) { // GetSetting returns the untyped policy setting with the specified key and true // if a policy setting with such key has been configured; // otherwise, it returns zero, false. -func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) { +func (s *Snapshot) GetSetting(k pkey.Key) (setting RawItem, ok bool) { setting, ok = s.m[k] return setting, ok } // Equal reports whether s and s2 are equal. func (s *Snapshot) Equal(s2 *Snapshot) bool { + if s == s2 { + return true + } if !s.EqualItems(s2) { return false } @@ -87,12 +95,11 @@ func (s *Snapshot) EqualItems(s2 *Snapshot) bool { // Keys return an iterator over keys in s. The iteration order is not specified // and is not guaranteed to be the same from one call to the next. -func (s *Snapshot) Keys() []Key { +func (s *Snapshot) Keys() iter.Seq[pkey.Key] { if s.m == nil { - return nil + return func(yield func(pkey.Key) bool) {} } - // TODO(nickkhyl): return iter.Seq[Key] in Go 1.23. - return xmaps.Keys(s.m) + return maps.Keys(s.m) } // Len reports the number of [RawItem]s in s. @@ -116,8 +123,6 @@ func (s *Snapshot) String() string { if s.Len() == 0 && s.Summary().IsEmpty() { return "{Empty}" } - keys := s.Keys() - slices.Sort(keys) var sb strings.Builder if !s.summary.IsEmpty() { sb.WriteRune('{') @@ -127,7 +132,7 @@ func (s *Snapshot) String() string { sb.WriteString(s.summary.String()) sb.WriteRune('}') } - for _, k := range keys { + for _, k := range slices.Sorted(s.Keys()) { if sb.Len() != 0 { sb.WriteRune('\n') } @@ -138,6 +143,68 @@ func (s *Snapshot) String() string { return sb.String() } +// snapshotJSON holds JSON-marshallable data for [Snapshot]. +type snapshotJSON struct { + Summary Summary `json:",omitzero"` + Settings map[pkey.Key]RawItem `json:",omitempty"` +} + +var ( + _ jsonv2.MarshalerTo = (*Snapshot)(nil) + _ jsonv2.UnmarshalerFrom = (*Snapshot)(nil) +) + +// As of 2025-07-28, jsonv2 no longer has a default representation for [time.Duration], +// so we need to provide a custom marshaler. +// +// This is temporary until the decision on the default representation is made +// (see https://github.com/golang/go/issues/71631#issuecomment-2981670799). +// +// In the future, we might either use the default representation (if compatible with +// [time.Duration.String]) or specify something like json.WithFormat[time.Duration]("units") +// when golang/go#71664 is implemented. +// +// TODO(nickkhyl): revisit this when the decision on the default [time.Duration] +// representation is made in golang/go#71631 and/or golang/go#71664 is implemented. +var formatDurationAsUnits = jsonv2.JoinOptions( + jsonv2.WithMarshalers(jsonv2.MarshalToFunc(func(e *jsontext.Encoder, t time.Duration) error { + return e.WriteToken(jsontext.String(t.String())) + })), +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (s *Snapshot) MarshalJSONTo(out *jsontext.Encoder) error { + data := &snapshotJSON{} + if s != nil { + data.Summary = s.summary + data.Settings = s.m + } + return jsonv2.MarshalEncode(out, data, formatDurationAsUnits) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (s *Snapshot) UnmarshalJSONFrom(in *jsontext.Decoder) error { + if s == nil { + return errors.New("s must not be nil") + } + data := &snapshotJSON{} + if err := jsonv2.UnmarshalDecode(in, data); err != nil { + return err + } + *s = Snapshot{m: data.Settings, sig: deephash.Hash(&data.Settings), summary: data.Summary} + return nil +} + +// MarshalJSON implements [json.Marshaler]. +func (s *Snapshot) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(s) // uses MarshalJSONTo +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (s *Snapshot) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONFrom +} + // MergeSnapshots returns a [Snapshot] that contains all [RawItem]s // from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope]. // If there's a conflict between policy settings in the two snapshots, @@ -166,7 +233,7 @@ func MergeSnapshots(snapshot1, snapshot2 *Snapshot) *Snapshot { } return &Snapshot{snapshot2.m, snapshot2.sig, SummaryWith(summaryOpts...)} } - m := make(map[Key]RawItem, snapshot1.Len()+snapshot2.Len()) + m := make(map[pkey.Key]RawItem, snapshot1.Len()+snapshot2.Len()) xmaps.Copy(m, snapshot1.m) xmaps.Copy(m, snapshot2.m) // snapshot2 has higher precedence return &Snapshot{m, deephash.Hash(&m), SummaryWith(summaryOpts...)} diff --git a/util/syspolicy/setting/snapshot_test.go b/util/syspolicy/setting/snapshot_test.go index e198d4a58..762a9681c 100644 --- a/util/syspolicy/setting/snapshot_test.go +++ b/util/syspolicy/setting/snapshot_test.go @@ -4,8 +4,20 @@ package setting import ( + "cmp" + "encoding/json" "testing" "time" + + jsonv2 "github.com/go-json-experiment/json" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" +) + +const ( + VisibleByPolicy = ptype.VisibleByPolicy + ShowChoiceByPolicy = ptype.ShowChoiceByPolicy ) func TestMergeSnapshots(t *testing.T) { @@ -18,180 +30,179 @@ func TestMergeSnapshots(t *testing.T) { name: "both-nil", s1: nil, s2: nil, - want: NewSnapshot(map[Key]RawItem{}), + want: NewSnapshot(map[pkey.Key]RawItem{}), }, { name: "both-empty", - s1: NewSnapshot(map[Key]RawItem{}), - s2: NewSnapshot(map[Key]RawItem{}), - want: NewSnapshot(map[Key]RawItem{}), + s1: NewSnapshot(map[pkey.Key]RawItem{}), + s2: NewSnapshot(map[pkey.Key]RawItem{}), + want: NewSnapshot(map[pkey.Key]RawItem{}), }, { name: "first-nil", s1: nil, - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), }, { name: "first-empty", - s1: NewSnapshot(map[Key]RawItem{}), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{}), + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), }, { name: "second-nil", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: nil, - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), }, { name: "second-empty", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), - s2: NewSnapshot(map[Key]RawItem{}), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s2: NewSnapshot(map[pkey.Key]RawItem{}), + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), }, { name: "no-conflicts", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), - s2: NewSnapshot(map[Key]RawItem{ - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }), }, { name: "with-conflicts", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(456), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }), }, { name: "with-scope-first-wins", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, DeviceScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, - "Setting4": {value: 2 * time.Hour}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), }, { name: "with-scope-second-wins", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, DeviceScope), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(456), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), }, { name: "with-scope-both-empty", - s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), - s2: NewSnapshot(map[Key]RawItem{}, DeviceScope), - want: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), + s1: NewSnapshot(map[pkey.Key]RawItem{}, CurrentUserScope), + s2: NewSnapshot(map[pkey.Key]RawItem{}, DeviceScope), + want: NewSnapshot(map[pkey.Key]RawItem{}, CurrentUserScope), }, { name: "with-scope-first-empty", - s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}}, - DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{}, CurrentUserScope), + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true)}, DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)), + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope, NewNamedOrigin("TestPolicy", DeviceScope)), }, { name: "with-scope-second-empty", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), - s2: NewSnapshot(map[Key]RawItem{}), - want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s2: NewSnapshot(map[pkey.Key]RawItem{}), + want: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), }, } @@ -222,52 +233,52 @@ func TestSnapshotEqual(t *testing.T) { { name: "nil-empty", s1: nil, - s2: NewSnapshot(map[Key]RawItem{}), + s2: NewSnapshot(map[pkey.Key]RawItem{}), wantEqual: true, wantEqualItems: true, }, { name: "empty-nil", - s1: NewSnapshot(map[Key]RawItem{}), + s1: NewSnapshot(map[pkey.Key]RawItem{}), s2: nil, wantEqual: true, wantEqualItems: true, }, { name: "empty-empty", - s1: NewSnapshot(map[Key]RawItem{}), - s2: NewSnapshot(map[Key]RawItem{}), + s1: NewSnapshot(map[pkey.Key]RawItem{}), + s2: NewSnapshot(map[pkey.Key]RawItem{}), wantEqual: true, wantEqualItems: true, }, { name: "first-nil", s1: nil, - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: false, wantEqualItems: false, }, { name: "first-empty", - s1: NewSnapshot(map[Key]RawItem{}), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{}), + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: false, wantEqualItems: false, }, { name: "second-nil", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: nil, wantEqual: false, @@ -275,86 +286,86 @@ func TestSnapshotEqual(t *testing.T) { }, { name: "second-empty", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), - s2: NewSnapshot(map[Key]RawItem{}), + s2: NewSnapshot(map[pkey.Key]RawItem{}), wantEqual: false, wantEqualItems: false, }, { name: "same-items-same-order-no-scope", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: true, wantEqualItems: true, }, { name: "same-items-same-order-same-scope", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), wantEqual: true, wantEqualItems: true, }, { name: "same-items-different-order-same-scope", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting3": {value: false}, - "Setting1": {value: 123}, - "Setting2": {value: "String"}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting3": RawItemOf(false), + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), }, DeviceScope), wantEqual: true, wantEqualItems: true, }, { name: "same-items-same-order-different-scope", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, CurrentUserScope), wantEqual: false, wantEqualItems: true, }, { name: "different-items-same-scope", - s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + s1: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), - s2: NewSnapshot(map[Key]RawItem{ - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + s2: NewSnapshot(map[pkey.Key]RawItem{ + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }, DeviceScope), wantEqual: false, wantEqualItems: false, @@ -400,10 +411,10 @@ func TestSnapshotString(t *testing.T) { }, { name: "non-empty", - snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 2 * time.Hour}, - "Setting2": {value: VisibleByPolicy}, - "Setting3": {value: ShowChoiceByPolicy}, + snapshot: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemOf(2 * time.Hour), + "Setting2": RawItemOf(VisibleByPolicy), + "Setting3": RawItemOf(ShowChoiceByPolicy), }, NewNamedOrigin("Test Policy", DeviceScope)), wantString: `{Test Policy (Device)} Setting1 = 2h0m0s @@ -412,15 +423,15 @@ Setting3 = user-decides`, }, { name: "non-empty-with-item-origin", - snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 42, origin: NewNamedOrigin("Test Policy", DeviceScope)}, + snapshot: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemWith(42, nil, NewNamedOrigin("Test Policy", DeviceScope)), }), wantString: `Setting1 = 42 - {Test Policy (Device)}`, }, { name: "non-empty-with-item-error", - snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {err: NewErrorText("bang!")}, + snapshot: NewSnapshot(map[pkey.Key]RawItem{ + "Setting1": RawItemWith(nil, NewErrorText("bang!"), nil), }), wantString: `Setting1 = Error{"bang!"}`, }, @@ -433,3 +444,145 @@ Setting3 = user-decides`, }) } } + +func TestMarshalUnmarshalSnapshot(t *testing.T) { + tests := []struct { + name string + snapshot *Snapshot + wantJSON string + wantBack *Snapshot + }{ + { + name: "Nil", + snapshot: (*Snapshot)(nil), + wantJSON: "null", + wantBack: NewSnapshot(nil), + }, + { + name: "Zero", + snapshot: &Snapshot{}, + wantJSON: "{}", + }, + { + name: "Bool/True", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"BoolPolicy": RawItemOf(true)}), + wantJSON: `{"Settings": {"BoolPolicy": {"Value": true}}}`, + }, + { + name: "Bool/False", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"BoolPolicy": RawItemOf(false)}), + wantJSON: `{"Settings": {"BoolPolicy": {"Value": false}}}`, + }, + { + name: "String/Non-Empty", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"StringPolicy": RawItemOf("StringValue")}), + wantJSON: `{"Settings": {"StringPolicy": {"Value": "StringValue"}}}`, + }, + { + name: "String/Empty", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"StringPolicy": RawItemOf("")}), + wantJSON: `{"Settings": {"StringPolicy": {"Value": ""}}}`, + }, + { + name: "Integer/NonZero", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"IntPolicy": RawItemOf(uint64(42))}), + wantJSON: `{"Settings": {"IntPolicy": {"Value": 42}}}`, + }, + { + name: "Integer/Zero", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"IntPolicy": RawItemOf(uint64(0))}), + wantJSON: `{"Settings": {"IntPolicy": {"Value": 0}}}`, + }, + { + name: "String-List", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"ListPolicy": RawItemOf([]string{"Value1", "Value2"})}), + wantJSON: `{"Settings": {"ListPolicy": {"Value": ["Value1", "Value2"]}}}`, + }, + { + name: "Duration/Zero", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"DurationPolicy": RawItemOf(time.Duration(0))}), + wantJSON: `{"Settings": {"DurationPolicy": {"Value": "0s"}}}`, + wantBack: NewSnapshot(map[pkey.Key]RawItem{"DurationPolicy": RawItemOf("0s")}), + }, + { + name: "Duration/NonZero", + snapshot: NewSnapshot(map[pkey.Key]RawItem{"DurationPolicy": RawItemOf(2 * time.Hour)}), + wantJSON: `{"Settings": {"DurationPolicy": {"Value": "2h0m0s"}}}`, + wantBack: NewSnapshot(map[pkey.Key]RawItem{"DurationPolicy": RawItemOf("2h0m0s")}), + }, + { + name: "Empty/With-Summary", + snapshot: NewSnapshot( + map[pkey.Key]RawItem{}, + SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)), + ), + wantJSON: `{"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}}`, + }, + { + name: "Setting/With-Summary", + snapshot: NewSnapshot( + map[pkey.Key]RawItem{"PolicySetting": RawItemOf(uint64(42))}, + SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)), + ), + wantJSON: `{ + "Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}, + "Settings": {"PolicySetting": {"Value": 42}} + }`, + }, + { + name: "Settings/With-Origins", + snapshot: NewSnapshot( + map[pkey.Key]RawItem{ + "SettingA": RawItemWith(uint64(42), nil, NewNamedOrigin("SourceA", DeviceScope)), + "SettingB": RawItemWith("B", nil, NewNamedOrigin("SourceB", CurrentProfileScope)), + "SettingC": RawItemWith(true, nil, NewNamedOrigin("SourceC", CurrentUserScope)), + }, + ), + wantJSON: `{ + "Settings": { + "SettingA": {"Value": 42, "Origin": {"Name": "SourceA", "Scope": "Device"}}, + "SettingB": {"Value": "B", "Origin": {"Name": "SourceB", "Scope": "Profile"}}, + "SettingC": {"Value": true, "Origin": {"Name": "SourceC", "Scope": "User"}} + } + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doTest := func(t *testing.T, useJSONv2 bool) { + var gotJSON []byte + var err error + if useJSONv2 { + gotJSON, err = jsonv2.Marshal(tt.snapshot) + } else { + gotJSON, err = json.Marshal(tt.snapshot) + } + if err != nil { + t.Fatal(err) + } + + if got, want, equal := internal.EqualJSONForTest(t, gotJSON, []byte(tt.wantJSON)); !equal { + t.Errorf("JSON: got %s; want %s", got, want) + } + + gotBack := &Snapshot{} + if useJSONv2 { + err = jsonv2.Unmarshal(gotJSON, &gotBack) + } else { + err = json.Unmarshal(gotJSON, &gotBack) + } + if err != nil { + t.Fatal(err) + } + + if wantBack := cmp.Or(tt.wantBack, tt.snapshot); !gotBack.Equal(wantBack) { + t.Errorf("Snapshot: got %+v; want %+v", gotBack, wantBack) + } + } + + t.Run("json", func(t *testing.T) { doTest(t, false) }) + t.Run("jsonv2", func(t *testing.T) { doTest(t, true) }) + }) + } +} diff --git a/util/syspolicy/setting/summary.go b/util/syspolicy/setting/summary.go index 5ff20e0aa..9864822f7 100644 --- a/util/syspolicy/setting/summary.go +++ b/util/syspolicy/setting/summary.go @@ -54,24 +54,29 @@ func (s Summary) String() string { return s.data.Scope.String() } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return jsonv2.MarshalEncode(out, &s.data, opts) +var ( + _ jsonv2.MarshalerTo = (*Summary)(nil) + _ jsonv2.UnmarshalerFrom = (*Summary)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (s Summary) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &s.data) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { - return jsonv2.UnmarshalDecode(in, &s.data, opts) +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (s *Summary) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &s.data) } // MarshalJSON implements [json.Marshaler]. func (s Summary) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(s) // uses MarshalJSONV2 + return jsonv2.Marshal(s) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (s *Summary) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONFrom } // SummaryOption is an option that configures [Summary] diff --git a/util/syspolicy/source/env_policy_store.go b/util/syspolicy/source/env_policy_store.go new file mode 100644 index 000000000..be363b79a --- /dev/null +++ b/util/syspolicy/source/env_policy_store.go @@ -0,0 +1,160 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + "unicode/utf8" + + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/setting" +) + +var lookupEnv = os.LookupEnv // test hook + +var _ Store = (*EnvPolicyStore)(nil) + +// EnvPolicyStore is a [Store] that reads policy settings from environment variables. +type EnvPolicyStore struct{} + +// ReadString implements [Store]. +func (s *EnvPolicyStore) ReadString(key pkey.Key) (string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil { + return "", err + } + return str, nil +} + +// ReadUInt64 implements [Store]. +func (s *EnvPolicyStore) ReadUInt64(key pkey.Key) (uint64, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return 0, err + } + if str == "" { + return 0, setting.ErrNotConfigured + } + value, err := strconv.ParseUint(str, 0, 64) + if err != nil { + return 0, fmt.Errorf("%s: %w: %q is not a valid uint64", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadBoolean implements [Store]. +func (s *EnvPolicyStore) ReadBoolean(key pkey.Key) (bool, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return false, err + } + if str == "" { + return false, setting.ErrNotConfigured + } + value, err := strconv.ParseBool(str) + if err != nil { + return false, fmt.Errorf("%s: %w: %q is not a valid bool", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadStringArray implements [Store]. +func (s *EnvPolicyStore) ReadStringArray(key pkey.Key) ([]string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil || str == "" { + return nil, err + } + var dst int + res := strings.Split(str, ",") + for src := range res { + res[dst] = strings.TrimSpace(res[src]) + if res[dst] != "" { + dst++ + } + } + return res[0:dst], nil +} + +func (s *EnvPolicyStore) lookupSettingVariable(key pkey.Key) (name, value string, err error) { + name, err = keyToEnvVarName(key) + if err != nil { + return "", "", err + } + value, ok := lookupEnv(name) + if !ok { + return name, "", setting.ErrNotConfigured + } + return name, value, nil +} + +var ( + errEmptyKey = errors.New("key must not be empty") + errInvalidKey = errors.New("key must consist of alphanumeric characters and slashes") +) + +// keyToEnvVarName returns the environment variable name for a given policy +// setting key, or an error if the key is invalid. It converts CamelCase keys into +// underscore-separated words and prepends the variable name with the TS prefix. +// For example: AuthKey => TS_AUTH_KEY, ExitNodeAllowLANAccess => TS_EXIT_NODE_ALLOW_LAN_ACCESS, etc. +// +// It's fine to use this in [EnvPolicyStore] without caching variable names since it's not a hot path. +// [EnvPolicyStore] is not a [Changeable] policy store, so the conversion will only happen once. +func keyToEnvVarName(key pkey.Key) (string, error) { + if len(key) == 0 { + return "", errEmptyKey + } + + isLower := func(c byte) bool { return 'a' <= c && c <= 'z' } + isUpper := func(c byte) bool { return 'A' <= c && c <= 'Z' } + isLetter := func(c byte) bool { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') } + isDigit := func(c byte) bool { return '0' <= c && c <= '9' } + + words := make([]string, 0, 8) + words = append(words, "TS_DEBUGSYSPOLICY") + var currentWord strings.Builder + for i := 0; i < len(key); i++ { + c := key[i] + if c >= utf8.RuneSelf { + return "", errInvalidKey + } + + var split bool + switch { + case isLower(c): + c -= 'a' - 'A' // make upper + split = currentWord.Len() > 0 && !isLetter(key[i-1]) + case isUpper(c): + if currentWord.Len() > 0 { + prevUpper := isUpper(key[i-1]) + nextLower := i < len(key)-1 && isLower(key[i+1]) + split = !prevUpper || nextLower // split on case transition + } + case isDigit(c): + split = currentWord.Len() > 0 && !isDigit(key[i-1]) + case c == pkey.KeyPathSeparator: + words = append(words, currentWord.String()) + currentWord.Reset() + continue + default: + return "", errInvalidKey + } + + if split { + words = append(words, currentWord.String()) + currentWord.Reset() + } + + currentWord.WriteByte(c) + } + + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + } + + return strings.Join(words, "_"), nil +} diff --git a/util/syspolicy/source/env_policy_store_test.go b/util/syspolicy/source/env_policy_store_test.go new file mode 100644 index 000000000..3255095b2 --- /dev/null +++ b/util/syspolicy/source/env_policy_store_test.go @@ -0,0 +1,360 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "errors" + "math" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/setting" +) + +func TestKeyToEnvVarName(t *testing.T) { + tests := []struct { + name string + key pkey.Key + want string // suffix after "TS_DEBUGSYSPOLICY_" + wantErr error + }{ + { + name: "empty", + key: "", + wantErr: errEmptyKey, + }, + { + name: "lowercase", + key: "tailnet", + want: "TAILNET", + }, + { + name: "CamelCase", + key: "AuthKey", + want: "AUTH_KEY", + }, + { + name: "LongerCamelCase", + key: "ManagedByOrganizationName", + want: "MANAGED_BY_ORGANIZATION_NAME", + }, + { + name: "UPPERCASE", + key: "UPPERCASE", + want: "UPPERCASE", + }, + { + name: "WithAbbrev/Front", + key: "DNSServer", + want: "DNS_SERVER", + }, + { + name: "WithAbbrev/Middle", + key: "ExitNodeAllowLANAccess", + want: "EXIT_NODE_ALLOW_LAN_ACCESS", + }, + { + name: "WithAbbrev/Back", + key: "ExitNodeID", + want: "EXIT_NODE_ID", + }, + { + name: "WithDigits/Single/Front", + key: "0TestKey", + want: "0_TEST_KEY", + }, + { + name: "WithDigits/Multi/Front", + key: "64TestKey", + want: "64_TEST_KEY", + }, + { + name: "WithDigits/Single/Middle", + key: "Test0Key", + want: "TEST_0_KEY", + }, + { + name: "WithDigits/Multi/Middle", + key: "Test64Key", + want: "TEST_64_KEY", + }, + { + name: "WithDigits/Single/Back", + key: "TestKey0", + want: "TEST_KEY_0", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TEST_KEY_64", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TEST_KEY_64", + }, + { + name: "WithPathSeparators/Single", + key: "Key/Subkey", + want: "KEY_SUBKEY", + }, + { + name: "WithPathSeparators/Multi", + key: "Root/Level1/Level2", + want: "ROOT_LEVEL_1_LEVEL_2", + }, + { + name: "Mixed", + key: "Network/DNSServer/IPAddress", + want: "NETWORK_DNS_SERVER_IP_ADDRESS", + }, + { + name: "Non-Alphanumeric/NonASCII/1", + key: "Đļ", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/NonASCII/2", + key: "KeyĐļName", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Space", + key: "Key Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Punct", + key: "Key!Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Backslash", + key: `Key\Name`, + wantErr: errInvalidKey, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + got, err := keyToEnvVarName(tt.key) + checkError(t, err, tt.wantErr, true) + + want := tt.want + if want != "" { + want = "TS_DEBUGSYSPOLICY_" + want + } + if got != want { + t.Fatalf("got %q; want %q", got, want) + } + }) + } +} + +func TestEnvPolicyStore(t *testing.T) { + blankEnv := func(string) (string, bool) { return "", false } + makeEnv := func(wantName, value string) func(string) (string, bool) { + wantName = "TS_DEBUGSYSPOLICY_" + wantName + return func(gotName string) (string, bool) { + if gotName != wantName { + return "", false + } + return value, true + } + } + tests := []struct { + name string + key pkey.Key + lookup func(string) (string, bool) + want any + wantErr error + }{ + { + name: "NotConfigured/String", + key: "AuthKey", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: "", + }, + { + name: "Configured/String/Empty", + key: "AuthKey", + lookup: makeEnv("AUTH_KEY", ""), + want: "", + }, + { + name: "Configured/String/NonEmpty", + key: "AuthKey", + lookup: makeEnv("AUTH_KEY", "ABC123"), + want: "ABC123", + }, + { + name: "NotConfigured/UInt64", + key: "IntegerSetting", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Empty", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", ""), + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Zero", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "0"), + want: uint64(0), + }, + { + name: "Configured/UInt64/NonZero", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "12345"), + want: uint64(12345), + }, + { + name: "Configured/UInt64/MaxUInt64", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)), + want: uint64(math.MaxUint64), + }, + { + name: "Configured/UInt64/Negative", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "-1"), + wantErr: setting.ErrTypeMismatch, + want: uint64(0), + }, + { + name: "Configured/UInt64/Hex", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "0xDEADBEEF"), + want: uint64(0xDEADBEEF), + }, + { + name: "NotConfigured/Bool", + key: "LogSCMInteractions", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/Empty", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", ""), + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/True", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "true"), + want: true, + }, + { + name: "Configured/Bool/False", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "False"), + want: false, + }, + { + name: "Configured/Bool/1", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "1"), + want: true, + }, + { + name: "Configured/Bool/0", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "0"), + want: false, + }, + { + name: "Configured/Bool/Invalid", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "NotABool"), + wantErr: setting.ErrTypeMismatch, + want: false, + }, + { + name: "NotConfigured/StringArray", + key: "AllowedSuggestedExitNodes", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: []string(nil), + }, + { + name: "Configured/StringArray/Empty", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", ""), + want: []string(nil), + }, + { + name: "Configured/StringArray/Spaces", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", " \t "), + want: []string{}, + }, + { + name: "Configured/StringArray/Single", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"), + want: []string{"NodeA"}, + }, + { + name: "Configured/StringArray/Multi", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"), + want: []string{"NodeA", "NodeB", "NodeC"}, + }, + { + name: "Configured/StringArray/WithBlank", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"), + want: []string{"NodeA", "NodeB"}, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + oldLookupEnv := lookupEnv + t.Cleanup(func() { lookupEnv = oldLookupEnv }) + lookupEnv = tt.lookup + + var got any + var err error + var store EnvPolicyStore + switch tt.want.(type) { + case string: + got, err = store.ReadString(tt.key) + case uint64: + got, err = store.ReadUInt64(tt.key) + case bool: + got, err = store.ReadBoolean(tt.key) + case []string: + got, err = store.ReadStringArray(tt.key) + } + checkError(t, err, tt.wantErr, false) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func checkError(tb testing.TB, got, want error, fatal bool) { + tb.Helper() + f := tb.Errorf + if fatal { + f = tb.Fatalf + } + if (want == nil && got != nil) || + (want != nil && got == nil) || + (want != nil && got != nil && !errors.Is(got, want) && want.Error() != got.Error()) { + f("gotErr: %v; wantErr: %v", got, want) + } +} diff --git a/util/syspolicy/source/policy_reader.go b/util/syspolicy/source/policy_reader.go index a1bd3147e..33ef22912 100644 --- a/util/syspolicy/source/policy_reader.go +++ b/util/syspolicy/source/policy_reader.go @@ -16,6 +16,8 @@ import ( "tailscale.com/util/set" "tailscale.com/util/syspolicy/internal/loggerx" "tailscale.com/util/syspolicy/internal/metrics" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" "tailscale.com/util/syspolicy/setting" ) @@ -138,9 +140,9 @@ func (r *Reader) reload(force bool) (*setting.Snapshot, error) { metrics.Reset(r.origin) - var m map[setting.Key]setting.RawItem + var m map[pkey.Key]setting.RawItem if lastPolicyCount := r.lastPolicy.Len(); lastPolicyCount > 0 { - m = make(map[setting.Key]setting.RawItem, lastPolicyCount) + m = make(map[pkey.Key]setting.RawItem, lastPolicyCount) } for _, s := range r.settings { if !r.origin.Scope().IsConfigurableSetting(s) { @@ -364,21 +366,21 @@ func readPolicySettingValue(store Store, s *setting.Definition) (value any, err case setting.PreferenceOptionValue: s, err := store.ReadString(key) if err == nil { - var value setting.PreferenceOption + var value ptype.PreferenceOption if err = value.UnmarshalText([]byte(s)); err == nil { return value, nil } } - return setting.ShowChoiceByPolicy, err + return ptype.ShowChoiceByPolicy, err case setting.VisibilityValue: s, err := store.ReadString(key) if err == nil { - var value setting.Visibility + var value ptype.Visibility if err = value.UnmarshalText([]byte(s)); err == nil { return value, nil } } - return setting.VisibleByPolicy, err + return ptype.VisibleByPolicy, err case setting.DurationValue: s, err := store.ReadString(key) if err == nil { diff --git a/util/syspolicy/source/policy_reader_test.go b/util/syspolicy/source/policy_reader_test.go index 57676e67d..32e8c51a6 100644 --- a/util/syspolicy/source/policy_reader_test.go +++ b/util/syspolicy/source/policy_reader_test.go @@ -9,6 +9,8 @@ import ( "time" "tailscale.com/util/must" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" "tailscale.com/util/syspolicy/setting" ) @@ -72,7 +74,7 @@ func TestReaderLifecycle(t *testing.T) { initWant: setting.NewSnapshot(nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), addStrings: []TestSetting[string]{TestSettingOf("StringValue", "S1")}, addStringLists: []TestSetting[[]string]{TestSettingOf("StringListValue", []string{"S1", "S2", "S3"})}, - newWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + newWant: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ "StringValue": setting.RawItemWith("S1", nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), "StringListValue": setting.RawItemWith([]string{"S1", "S2", "S3"}, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), }, setting.NewNamedOrigin("Test", setting.DeviceScope)), @@ -136,10 +138,10 @@ func TestReaderLifecycle(t *testing.T) { TestSettingOf("PreferenceOptionValue", "always"), TestSettingOf("VisibilityValue", "show"), }, - initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + initWant: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ "DurationValue": setting.RawItemWith(must.Get(time.ParseDuration("2h30m")), nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), - "PreferenceOptionValue": setting.RawItemWith(setting.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), - "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "PreferenceOptionValue": setting.RawItemWith(ptype.AlwaysByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), + "VisibilityValue": setting.RawItemWith(ptype.VisibleByPolicy, nil, setting.NewNamedOrigin("Test", setting.DeviceScope)), }, setting.NewNamedOrigin("Test", setting.DeviceScope)), }, { @@ -165,11 +167,11 @@ func TestReaderLifecycle(t *testing.T) { initUInt64s: []TestSetting[uint64]{ TestSettingOf[uint64]("VisibilityValue", 42), // type mismatch }, - initWant: setting.NewSnapshot(map[setting.Key]setting.RawItem{ + initWant: setting.NewSnapshot(map[pkey.Key]setting.RawItem{ "DurationValue1": setting.RawItemWith(nil, setting.NewErrorText("time: invalid duration \"soon\""), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), "DurationValue2": setting.RawItemWith(nil, setting.NewErrorText("bang!"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), - "PreferenceOptionValue": setting.RawItemWith(setting.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), - "VisibilityValue": setting.RawItemWith(setting.VisibleByPolicy, setting.NewErrorText("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "PreferenceOptionValue": setting.RawItemWith(ptype.ShowChoiceByPolicy, nil, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), + "VisibilityValue": setting.RawItemWith(ptype.VisibleByPolicy, setting.NewErrorText("type mismatch in ReadString: got uint64"), setting.NewNamedOrigin("Test", setting.CurrentUserScope)), }, setting.NewNamedOrigin("Test", setting.CurrentUserScope)), }, } @@ -277,7 +279,7 @@ func TestReadingSession(t *testing.T) { t.Fatalf("the session was closed prematurely") } - want := setting.NewSnapshot(map[setting.Key]setting.RawItem{ + want := setting.NewSnapshot(map[pkey.Key]setting.RawItem{ "StringValue": setting.RawItemWith("S1", nil, origin), }, origin) if got := session.GetSettings(); !got.Equal(want) { diff --git a/util/syspolicy/source/policy_source.go b/util/syspolicy/source/policy_source.go index 7f2821b59..c4774217c 100644 --- a/util/syspolicy/source/policy_source.go +++ b/util/syspolicy/source/policy_source.go @@ -13,6 +13,7 @@ import ( "io" "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/setting" ) @@ -31,19 +32,19 @@ type Store interface { // ReadString returns the value of a [setting.StringValue] with the specified key, // an [setting.ErrNotConfigured] if the policy setting is not configured, or // an error on failure. - ReadString(key setting.Key) (string, error) + ReadString(key pkey.Key) (string, error) // ReadUInt64 returns the value of a [setting.IntegerValue] with the specified key, // an [setting.ErrNotConfigured] if the policy setting is not configured, or // an error on failure. - ReadUInt64(key setting.Key) (uint64, error) + ReadUInt64(key pkey.Key) (uint64, error) // ReadBoolean returns the value of a [setting.BooleanValue] with the specified key, // an [setting.ErrNotConfigured] if the policy setting is not configured, or // an error on failure. - ReadBoolean(key setting.Key) (bool, error) + ReadBoolean(key pkey.Key) (bool, error) // ReadStringArray returns the value of a [setting.StringListValue] with the specified key, // an [setting.ErrNotConfigured] if the policy setting is not configured, or // an error on failure. - ReadStringArray(key setting.Key) ([]string, error) + ReadStringArray(key pkey.Key) ([]string, error) } // Lockable is an optional interface that [Store] implementations may support. diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go index f526b4ce1..f97b17f3a 100644 --- a/util/syspolicy/source/policy_store_windows.go +++ b/util/syspolicy/source/policy_store_windows.go @@ -12,6 +12,8 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/winutil/gp" ) @@ -29,6 +31,18 @@ var ( _ Expirable = (*PlatformPolicyStore)(nil) ) +// lockableCloser is a [Lockable] that can also be closed. +// It is implemented by [gp.PolicyLock] and [optionalPolicyLock]. +type lockableCloser interface { + Lockable + Close() error +} + +var ( + _ lockableCloser = (*gp.PolicyLock)(nil) + _ lockableCloser = (*optionalPolicyLock)(nil) +) + // PlatformPolicyStore implements [Store] by providing read access to // Registry-based Tailscale policies, such as those configured via Group Policy or MDM. // For better performance and consistency, it is recommended to lock it when @@ -55,7 +69,7 @@ type PlatformPolicyStore struct { // they are being read. // // When both policyLock and mu need to be taken, mu must be taken before policyLock. - policyLock *gp.PolicyLock + policyLock lockableCloser mu sync.Mutex tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked. @@ -108,7 +122,7 @@ func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, policyLock scope: scope, softwareKey: softwareKey, done: make(chan struct{}), - policyLock: policyLock, + policyLock: &optionalPolicyLock{PolicyLock: policyLock}, } } @@ -238,7 +252,7 @@ func (ps *PlatformPolicyStore) onChange() { // ReadString retrieves a string policy with the specified key. // It returns [setting.ErrNotConfigured] if the policy setting does not exist. -func (ps *PlatformPolicyStore) ReadString(key setting.Key) (val string, err error) { +func (ps *PlatformPolicyStore) ReadString(key pkey.Key) (val string, err error) { return getPolicyValue(ps, key, func(key registry.Key, valueName string) (string, error) { val, _, err := key.GetStringValue(valueName) @@ -248,7 +262,7 @@ func (ps *PlatformPolicyStore) ReadString(key setting.Key) (val string, err erro // ReadUInt64 retrieves an integer policy with the specified key. // It returns [setting.ErrNotConfigured] if the policy setting does not exist. -func (ps *PlatformPolicyStore) ReadUInt64(key setting.Key) (uint64, error) { +func (ps *PlatformPolicyStore) ReadUInt64(key pkey.Key) (uint64, error) { return getPolicyValue(ps, key, func(key registry.Key, valueName string) (uint64, error) { val, _, err := key.GetIntegerValue(valueName) @@ -258,7 +272,7 @@ func (ps *PlatformPolicyStore) ReadUInt64(key setting.Key) (uint64, error) { // ReadBoolean retrieves a boolean policy with the specified key. // It returns [setting.ErrNotConfigured] if the policy setting does not exist. -func (ps *PlatformPolicyStore) ReadBoolean(key setting.Key) (bool, error) { +func (ps *PlatformPolicyStore) ReadBoolean(key pkey.Key) (bool, error) { return getPolicyValue(ps, key, func(key registry.Key, valueName string) (bool, error) { val, _, err := key.GetIntegerValue(valueName) @@ -270,8 +284,8 @@ func (ps *PlatformPolicyStore) ReadBoolean(key setting.Key) (bool, error) { } // ReadString retrieves a multi-string policy with the specified key. -// It returns [setting.ErrNotConfigured] if the policy setting does not exist. -func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error) { +// It returns [pkey.ErrNotConfigured] if the policy setting does not exist. +func (ps *PlatformPolicyStore) ReadStringArray(key pkey.Key) ([]string, error) { return getPolicyValue(ps, key, func(key registry.Key, valueName string) ([]string, error) { val, _, err := key.GetStringsValue(valueName) @@ -309,25 +323,25 @@ func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error }) } -// splitSettingKey extracts the registry key name and value name from a [setting.Key]. -// The [setting.Key] format allows grouping settings into nested categories using one -// or more [setting.KeyPathSeparator]s in the path. How individual policy settings are +// splitSettingKey extracts the registry key name and value name from a [pkey.Key]. +// The [pkey.Key] format allows grouping settings into nested categories using one +// or more [pkey.KeyPathSeparator]s in the path. How individual policy settings are // stored is an implementation detail of each [Store]. In the [PlatformPolicyStore] // for Windows, we map nested policy categories onto the Registry key hierarchy. -// The last component after a [setting.KeyPathSeparator] is treated as the value name, +// The last component after a [pkey.KeyPathSeparator] is treated as the value name, // while everything preceding it is considered a subpath (relative to the {HKLM,HKCU}\Software\Policies\Tailscale key). -// If there are no [setting.KeyPathSeparator]s in the key, the policy setting value +// If there are no [pkey.KeyPathSeparator]s in the key, the policy setting value // is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale. -func splitSettingKey(key setting.Key) (path, valueName string) { - if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 { - path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`) - valueName = string(key[idx+len(setting.KeyPathSeparator):]) +func splitSettingKey(key pkey.Key) (path, valueName string) { + if idx := strings.LastIndexByte(string(key), pkey.KeyPathSeparator); idx != -1 { + path = strings.ReplaceAll(string(key[:idx]), string(pkey.KeyPathSeparator), `\`) + valueName = string(key[idx+1:]) return path, valueName } return "", string(key) } -func getPolicyValue[T any](ps *PlatformPolicyStore, key setting.Key, getter registryValueGetter[T]) (T, error) { +func getPolicyValue[T any](ps *PlatformPolicyStore, key pkey.Key, getter registryValueGetter[T]) (T, error) { var zero T ps.mu.Lock() @@ -448,3 +462,68 @@ func tailscaleKeyNamesFor(scope gp.Scope) []string { panic("unreachable") } } + +type gpLockState int + +const ( + gpUnlocked = gpLockState(iota) + gpLocked + gpLockRestricted // the lock could not be acquired due to a restriction in place +) + +// optionalPolicyLock is a wrapper around [gp.PolicyLock] that locks +// and unlocks the underlying [gp.PolicyLock]. +// +// If the [gp.PolicyLock.Lock] returns [gp.ErrLockRestricted], the error is ignored, +// and calling [optionalPolicyLock.Unlock] is a no-op. +// +// The underlying GP lock is kinda optional: it is safe to read policy settings +// from the Registry without acquiring it, but it is recommended to lock it anyway +// when reading multiple policy settings to avoid potentially inconsistent results. +// +// It is not safe for concurrent use. +type optionalPolicyLock struct { + *gp.PolicyLock + state gpLockState +} + +// Lock acquires the underlying [gp.PolicyLock], returning an error on failure. +// If the lock cannot be acquired due to a restriction in place +// (e.g., attempting to acquire a lock while the service is starting), +// the lock is considered to be held, the method returns nil, and a subsequent +// call to [Unlock] is a no-op. +// It is a runtime error to call Lock when the lock is already held. +func (o *optionalPolicyLock) Lock() error { + if o.state != gpUnlocked { + panic("already locked") + } + switch err := o.PolicyLock.Lock(); err { + case nil: + o.state = gpLocked + return nil + case gp.ErrLockRestricted: + loggerx.Errorf("GP lock not acquired: %v", err) + o.state = gpLockRestricted + return nil + default: + return err + } +} + +// Unlock releases the underlying [gp.PolicyLock], if it was previously acquired. +// It is a runtime error to call Unlock when the lock is not held. +func (o *optionalPolicyLock) Unlock() { + switch o.state { + case gpLocked: + o.PolicyLock.Unlock() + case gpLockRestricted: + // The GP lock wasn't acquired due to a restriction in place + // when [optionalPolicyLock.Lock] was called. Unlock is a no-op. + case gpUnlocked: + panic("not locked") + default: + panic("unreachable") + } + + o.state = gpUnlocked +} diff --git a/util/syspolicy/source/policy_store_windows_test.go b/util/syspolicy/source/policy_store_windows_test.go index 33f85dc0b..4ab1da805 100644 --- a/util/syspolicy/source/policy_store_windows_test.go +++ b/util/syspolicy/source/policy_store_windows_test.go @@ -19,6 +19,7 @@ import ( "tailscale.com/tstest" "tailscale.com/util/cibuild" "tailscale.com/util/mak" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/winutil" "tailscale.com/util/winutil/gp" @@ -31,7 +32,7 @@ import ( type subkeyStrings []string type testPolicyValue struct { - name setting.Key + name pkey.Key value any } @@ -100,7 +101,7 @@ func TestReadPolicyStore(t *testing.T) { t.Skipf("test requires running as elevated user") } tests := []struct { - name setting.Key + name pkey.Key newValue any legacyValue any want any @@ -269,7 +270,7 @@ func TestPolicyStoreChangeNotifications(t *testing.T) { func TestSplitSettingKey(t *testing.T) { tests := []struct { name string - key setting.Key + key pkey.Key wantPath string wantValue string }{ diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go index bb8e164fb..ddec9efbb 100644 --- a/util/syspolicy/source/test_store.go +++ b/util/syspolicy/source/test_store.go @@ -11,8 +11,10 @@ import ( xmaps "golang.org/x/exp/maps" "tailscale.com/util/mak" "tailscale.com/util/set" - "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/slicesx" + "tailscale.com/util/syspolicy/pkey" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" ) var ( @@ -30,7 +32,7 @@ type TestValueType interface { // TestSetting is a policy setting in a [TestStore]. type TestSetting[T TestValueType] struct { // Key is the setting's unique identifier. - Key setting.Key + Key pkey.Key // Error is the error to be returned by the [TestStore] when reading // a policy setting with the specified key. Error error @@ -42,20 +44,20 @@ type TestSetting[T TestValueType] struct { // TestSettingOf returns a [TestSetting] representing a policy setting // configured with the specified key and value. -func TestSettingOf[T TestValueType](key setting.Key, value T) TestSetting[T] { +func TestSettingOf[T TestValueType](key pkey.Key, value T) TestSetting[T] { return TestSetting[T]{Key: key, Value: value} } // TestSettingWithError returns a [TestSetting] representing a policy setting // with the specified key and error. -func TestSettingWithError[T TestValueType](key setting.Key, err error) TestSetting[T] { +func TestSettingWithError[T TestValueType](key pkey.Key, err error) TestSetting[T] { return TestSetting[T]{Key: key, Error: err} } // testReadOperation describes a single policy setting read operation. type testReadOperation struct { // Key is the setting's unique identifier. - Key setting.Key + Key pkey.Key // Type is a value type of a read operation. // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue] Type setting.Type @@ -64,7 +66,7 @@ type testReadOperation struct { // TestExpectedReads is the number of read operations with the specified details. type TestExpectedReads struct { // Key is the setting's unique identifier. - Key setting.Key + Key pkey.Key // Type is a value type of a read operation. // [setting.BooleanValue], [setting.IntegerValue], [setting.StringValue] or [setting.StringListValue] Type setting.Type @@ -78,7 +80,7 @@ func (r TestExpectedReads) operation() testReadOperation { // TestStore is a [Store] that can be used in tests. type TestStore struct { - tb internal.TB + tb testenv.TB done chan struct{} @@ -86,9 +88,10 @@ type TestStore struct { storeLockCount atomic.Int32 mu sync.RWMutex - suspendCount int // change callback are suspended if > 0 - mr, mw map[setting.Key]any // maps for reading and writing; they're the same unless the store is suspended. + suspendCount int // change callback are suspended if > 0 + mr, mw map[pkey.Key]any // maps for reading and writing; they're the same unless the store is suspended. cbs set.HandleSet[func()] + closed bool readsMu sync.Mutex reads map[testReadOperation]int // how many times a policy setting was read @@ -96,26 +99,22 @@ type TestStore struct { // NewTestStore returns a new [TestStore]. // The tb will be used to report coding errors detected by the [TestStore]. -func NewTestStore(tb internal.TB) *TestStore { - m := make(map[setting.Key]any) - return &TestStore{ +func NewTestStore(tb testenv.TB) *TestStore { + m := make(map[pkey.Key]any) + store := &TestStore{ tb: tb, done: make(chan struct{}), mr: m, mw: m, } + tb.Cleanup(store.Close) + return store } // NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], // [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. -func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { - m := make(map[setting.Key]any) - store := &TestStore{ - tb: tb, - done: make(chan struct{}), - mr: m, - mw: m, - } +func NewTestStoreOf[T TestValueType](tb testenv.TB, settings ...TestSetting[T]) *TestStore { + store := NewTestStore(tb) switch settings := any(settings).(type) { case []TestSetting[bool]: store.SetBooleans(settings...) @@ -156,8 +155,15 @@ func (s *TestStore) RegisterChangeCallback(callback func()) (unregister func(), }, nil } +// IsEmpty reports whether the store does not contain any settings. +func (s *TestStore) IsEmpty() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.mr) == 0 +} + // ReadString implements [Store]. -func (s *TestStore) ReadString(key setting.Key) (string, error) { +func (s *TestStore) ReadString(key pkey.Key) (string, error) { defer s.recordRead(key, setting.StringValue) s.mu.RLock() defer s.mu.RUnlock() @@ -176,7 +182,7 @@ func (s *TestStore) ReadString(key setting.Key) (string, error) { } // ReadUInt64 implements [Store]. -func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) { +func (s *TestStore) ReadUInt64(key pkey.Key) (uint64, error) { defer s.recordRead(key, setting.IntegerValue) s.mu.RLock() defer s.mu.RUnlock() @@ -195,7 +201,7 @@ func (s *TestStore) ReadUInt64(key setting.Key) (uint64, error) { } // ReadBoolean implements [Store]. -func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) { +func (s *TestStore) ReadBoolean(key pkey.Key) (bool, error) { defer s.recordRead(key, setting.BooleanValue) s.mu.RLock() defer s.mu.RUnlock() @@ -214,7 +220,7 @@ func (s *TestStore) ReadBoolean(key setting.Key) (bool, error) { } // ReadStringArray implements [Store]. -func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) { +func (s *TestStore) ReadStringArray(key pkey.Key) ([]string, error) { defer s.recordRead(key, setting.StringListValue) s.mu.RLock() defer s.mu.RUnlock() @@ -232,7 +238,7 @@ func (s *TestStore) ReadStringArray(key setting.Key) ([]string, error) { return slice, nil } -func (s *TestStore) recordRead(key setting.Key, typ setting.Type) { +func (s *TestStore) recordRead(key pkey.Key, typ setting.Type) { s.readsMu.Lock() op := testReadOperation{key, typ} num := s.reads[op] @@ -308,7 +314,7 @@ func (s *TestStore) Resume() { s.mr = s.mw s.mu.Unlock() s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() case s.suspendCount < 0: s.tb.Fatal("negative suspendCount") default: @@ -333,7 +339,7 @@ func (s *TestStore) SetBooleans(settings ...TestSetting[bool]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetUInt64s sets the specified integer settings in s. @@ -352,7 +358,7 @@ func (s *TestStore) SetUInt64s(settings ...TestSetting[uint64]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetStrings sets the specified string settings in s. @@ -371,7 +377,7 @@ func (s *TestStore) SetStrings(settings ...TestSetting[string]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // SetStrings sets the specified string list settings in s. @@ -390,11 +396,11 @@ func (s *TestStore) SetStringLists(settings ...TestSetting[[]string]) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // Delete deletes the specified settings from s. -func (s *TestStore) Delete(keys ...setting.Key) { +func (s *TestStore) Delete(keys ...pkey.Key) { s.storeLock.Lock() for _, key := range keys { s.mu.Lock() @@ -402,7 +408,7 @@ func (s *TestStore) Delete(keys ...setting.Key) { s.mu.Unlock() } s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } // Clear deletes all settings from s. @@ -412,16 +418,16 @@ func (s *TestStore) Clear() { clear(s.mw) s.mu.Unlock() s.storeLock.Unlock() - s.notifyPolicyChanged() + s.NotifyPolicyChanged() } -func (s *TestStore) notifyPolicyChanged() { +func (s *TestStore) NotifyPolicyChanged() { s.mu.RLock() if s.suspendCount != 0 { s.mu.RUnlock() return } - cbs := xmaps.Values(s.cbs) + cbs := slicesx.MapValues(s.cbs) s.mu.RUnlock() var wg sync.WaitGroup @@ -439,9 +445,9 @@ func (s *TestStore) notifyPolicyChanged() { func (s *TestStore) Close() { s.mu.Lock() defer s.mu.Unlock() - if s.done != nil { + if !s.closed { close(s.done) - s.done = nil + s.closed = true } } diff --git a/util/syspolicy/syspolicy.go b/util/syspolicy/syspolicy.go index ccfd83347..48e430b67 100644 --- a/util/syspolicy/syspolicy.go +++ b/util/syspolicy/syspolicy.go @@ -1,98 +1,174 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package syspolicy provides functions to retrieve system settings of a device. +// Package syspolicy contains the implementation of system policy management. +// Calling code should use the client interface in +// tailscale.com/util/syspolicy/policyclient. package syspolicy import ( "errors" + "fmt" + "reflect" "time" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/policyclient" + "tailscale.com/util/syspolicy/ptype" + "tailscale.com/util/syspolicy/rsop" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" ) -func GetString(key Key, defaultValue string) (string, error) { - markHandlerInUse() - v, err := handler.ReadString(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err +var ( + // ErrNotConfigured is returned when the requested policy setting is not configured. + ErrNotConfigured = setting.ErrNotConfigured + // ErrTypeMismatch is returned when there's a type mismatch between the actual type + // of the setting value and the expected type. + ErrTypeMismatch = setting.ErrTypeMismatch + // ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting + // has been registered with the specified key. + // + // This error is also returned by a (now deprecated) [Handler] when the specified + // key does not have a value set. While the package maintains compatibility with this + // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer + // [source.Store] implementations. + ErrNoSuchKey = setting.ErrNoSuchKey +) + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +// +// It is a shorthand for [rsop.RegisterStore]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*rsop.StoreRegistration, error) { + return rsop.RegisterStore(name, scope, store) } -func GetUint64(key Key, defaultValue uint64) (uint64, error) { - markHandlerInUse() - v, err := handler.ReadUInt64(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil +// hasAnyOf returns whether at least one of the specified policy settings is configured, +// or an error if no keys are provided or the check fails. +func hasAnyOf(keys ...pkey.Key) (bool, error) { + if len(keys) == 0 { + return false, errors.New("at least one key must be specified") + } + policy, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return false, err } - return v, err + effective := policy.Get() + for _, k := range keys { + _, err := effective.GetErr(k) + if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) { + continue + } + if err != nil { + return false, err + } + return true, nil + } + return false, nil } -func GetBoolean(key Key, defaultValue bool) (bool, error) { - markHandlerInUse() - v, err := handler.ReadBoolean(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err +// getString returns a string policy setting with the specified key, +// or defaultValue if it does not exist. +func getString(key pkey.Key, defaultValue string) (string, error) { + return getCurrentPolicySettingValue(key, defaultValue) } -func GetStringArray(key Key, defaultValue []string) ([]string, error) { - markHandlerInUse() - v, err := handler.ReadStringArray(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err +// getUint64 returns a numeric policy setting with the specified key, +// or defaultValue if it does not exist. +func getUint64(key pkey.Key, defaultValue uint64) (uint64, error) { + return getCurrentPolicySettingValue(key, defaultValue) +} + +// getBoolean returns a boolean policy setting with the specified key, +// or defaultValue if it does not exist. +func getBoolean(key pkey.Key, defaultValue bool) (bool, error) { + return getCurrentPolicySettingValue(key, defaultValue) +} + +// getStringArray returns a multi-string policy setting with the specified key, +// or defaultValue if it does not exist. +func getStringArray(key pkey.Key, defaultValue []string) ([]string, error) { + return getCurrentPolicySettingValue(key, defaultValue) } -// GetPreferenceOption loads a policy from the registry that can be +// getPreferenceOption loads a policy from the registry that can be // managed by an enterprise policy management system and allows administrative // overrides of users' choices in a way that we do not want tailcontrol to have // the authority to set. It describes user-decides/always/never options, where // "always" and "never" remove the user's ability to make a selection. If not -// present or set to a different value, "user-decides" is the default. -func GetPreferenceOption(name Key) (setting.PreferenceOption, error) { - s, err := GetString(name, "user-decides") - if err != nil { - return setting.ShowChoiceByPolicy, err - } - var opt setting.PreferenceOption - err = opt.UnmarshalText([]byte(s)) - return opt, err +// present or set to a different value, defaultValue (and a nil error) is returned. +func getPreferenceOption(name pkey.Key, defaultValue ptype.PreferenceOption) (ptype.PreferenceOption, error) { + return getCurrentPolicySettingValue(name, defaultValue) } -// GetVisibility loads a policy from the registry that can be managed +// getVisibility loads a policy from the registry that can be managed // by an enterprise policy management system and describes show/hide decisions // for UI elements. The registry value should be a string set to "show" (return // true) or "hide" (return true). If not present or set to a different value, // "show" (return false) is the default. -func GetVisibility(name Key) (setting.Visibility, error) { - s, err := GetString(name, "show") - if err != nil { - return setting.VisibleByPolicy, err - } - var visibility setting.Visibility - visibility.UnmarshalText([]byte(s)) - return visibility, nil +func getVisibility(name pkey.Key) (ptype.Visibility, error) { + return getCurrentPolicySettingValue(name, ptype.VisibleByPolicy) } -// GetDuration loads a policy from the registry that can be managed +// getDuration loads a policy from the registry that can be managed // by an enterprise policy management system and describes a duration for some // action. The registry value should be a string that time.ParseDuration // understands. If the registry value is "" or can not be processed, // defaultValue is returned instead. -func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) { - opt, err := GetString(name, "") - if opt == "" || err != nil { - return defaultValue, err +func getDuration(name pkey.Key, defaultValue time.Duration) (time.Duration, error) { + d, err := getCurrentPolicySettingValue(name, defaultValue) + if err != nil { + return d, err } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { + if d < 0 { return defaultValue, nil } - return v, nil + return d, nil +} + +// registerChangeCallback adds a function that will be called whenever the effective policy +// for the default scope changes. The returned function can be used to unregister the callback. +func registerChangeCallback(cb rsop.PolicyChangeCallback) (unregister func(), err error) { + effective, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return nil, err + } + return effective.RegisterChangeCallback(cb), nil +} + +// getCurrentPolicySettingValue returns the value of the policy setting +// specified by its key from the [rsop.Policy] of the [setting.DefaultScope]. It +// returns def if the policy setting is not configured, or an error if it has +// an error or could not be converted to the specified type T. +func getCurrentPolicySettingValue[T setting.ValueType](key pkey.Key, def T) (T, error) { + effective, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return def, err + } + value, err := effective.Get().GetErr(key) + if err != nil { + if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) { + return def, nil + } + return def, err + } + if res, ok := value.(T); ok { + return res, nil + } + return convertPolicySettingValueTo(value, def) +} + +func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) { + // Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string + // if someone requests a string instead of the actual setting's value. + // TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests. + if reflect.TypeFor[T]().Kind() == reflect.String { + if str, ok := value.(fmt.Stringer); ok { + return any(str.String()).(T), nil + } + } + return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def) } // SelectControlURL returns the ControlURL to use based on a value in @@ -135,3 +211,54 @@ func SelectControlURL(reg, disk string) string { } return def } + +func init() { + policyclient.RegisterClientImpl(globalSyspolicy{}) +} + +// globalSyspolicy implements [policyclient.Client] using the syspolicy global +// functions and global registrations. +// +// TODO: de-global-ify. This implementation using the old global functions +// is an intermediate stage while changing policyclient to be modular. +type globalSyspolicy struct{} + +func (globalSyspolicy) GetBoolean(key pkey.Key, defaultValue bool) (bool, error) { + return getBoolean(key, defaultValue) +} + +func (globalSyspolicy) GetString(key pkey.Key, defaultValue string) (string, error) { + return getString(key, defaultValue) +} + +func (globalSyspolicy) GetStringArray(key pkey.Key, defaultValue []string) ([]string, error) { + return getStringArray(key, defaultValue) +} + +func (globalSyspolicy) SetDebugLoggingEnabled(enabled bool) { + loggerx.SetDebugLoggingEnabled(enabled) +} + +func (globalSyspolicy) GetUint64(key pkey.Key, defaultValue uint64) (uint64, error) { + return getUint64(key, defaultValue) +} + +func (globalSyspolicy) GetDuration(name pkey.Key, defaultValue time.Duration) (time.Duration, error) { + return getDuration(name, defaultValue) +} + +func (globalSyspolicy) GetPreferenceOption(name pkey.Key, defaultValue ptype.PreferenceOption) (ptype.PreferenceOption, error) { + return getPreferenceOption(name, defaultValue) +} + +func (globalSyspolicy) GetVisibility(name pkey.Key) (ptype.Visibility, error) { + return getVisibility(name) +} + +func (globalSyspolicy) HasAnyOf(keys ...pkey.Key) (bool, error) { + return hasAnyOf(keys...) +} + +func (globalSyspolicy) RegisterChangeCallback(cb func(policyclient.PolicyChange)) (unregister func(), err error) { + return registerChangeCallback(cb) +} diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go index 8280aa1df..10f8da486 100644 --- a/util/syspolicy/syspolicy_test.go +++ b/util/syspolicy/syspolicy_test.go @@ -9,110 +9,106 @@ import ( "testing" "time" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/internal/metrics" + "tailscale.com/util/syspolicy/pkey" + "tailscale.com/util/syspolicy/ptype" + "tailscale.com/util/syspolicy/rsop" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" ) -// testHandler encompasses all data types returned when testing any of the syspolicy -// methods that involve getting a policy value. -// For keys and the corresponding values, check policy_keys.go. -type testHandler struct { - t *testing.T - key Key - s string - u64 uint64 - b bool - sArr []string - err error - calls int // used for testing reads from cache vs. handler -} - var someOtherError = errors.New("error other than not found") -func (th *testHandler) ReadString(key string) (string, error) { - if key != string(th.key) { - th.t.Errorf("ReadString(%q) want %q", key, th.key) - } - th.calls++ - return th.s, th.err -} - -func (th *testHandler) ReadUInt64(key string) (uint64, error) { - if key != string(th.key) { - th.t.Errorf("ReadUint64(%q) want %q", key, th.key) - } - th.calls++ - return th.u64, th.err -} - -func (th *testHandler) ReadBoolean(key string) (bool, error) { - if key != string(th.key) { - th.t.Errorf("ReadBool(%q) want %q", key, th.key) +// registerWellKnownSettingsForTest registers all implicit setting definitions +// for the duration of the test. +func registerWellKnownSettingsForTest(tb testenv.TB) { + tb.Helper() + err := setting.SetDefinitionsForTest(tb, implicitDefinitions...) + if err != nil { + tb.Fatalf("Failed to register well-known settings: %v", err) } - th.calls++ - return th.b, th.err -} - -func (th *testHandler) ReadStringArray(key string) ([]string, error) { - if key != string(th.key) { - th.t.Errorf("ReadStringArray(%q) want %q", key, th.key) - } - th.calls++ - return th.sArr, th.err } func TestGetString(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue string handlerError error defaultValue string wantValue string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", - key: AdminConsoleVisibility, + key: pkey.AdminConsoleVisibility, handlerValue: "hide", wantValue: "hide", + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", - key: EnableServerMode, - handlerError: ErrNoSuchKey, + key: pkey.EnableServerMode, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non-blank default", - key: EnableServerMode, - handlerError: ErrNoSuchKey, + key: pkey.EnableServerMode, + handlerError: ErrNotConfigured, defaultValue: "test", wantValue: "test", wantError: nil, }, { name: "reading value returns other error", - key: NetworkDevicesVisibility, + key: pkey.NetworkDevicesVisibility, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_NetworkDevices_error", Value: 1}, + }, }, } + registerWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) - value, err := GetString(tt.key, tt.defaultValue) - if err != tt.wantError { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + value, err := getString(tt.key, tt.defaultValue) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -120,7 +116,7 @@ func TestGetString(t *testing.T) { func TestGetUint64(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue uint64 handlerError error defaultValue uint64 @@ -129,27 +125,27 @@ func TestGetUint64(t *testing.T) { }{ { name: "read existing value", - key: KeyExpirationNoticeTime, + key: pkey.LogSCMInteractions, handlerValue: 1, wantValue: 1, }, { name: "read non-existing value", - key: LogSCMInteractions, + key: pkey.LogSCMInteractions, handlerValue: 0, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0, }, { name: "read non-existing value, non-zero default", - key: LogSCMInteractions, + key: pkey.LogSCMInteractions, defaultValue: 2, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 2, }, { name: "reading value returns other error", - key: FlushDNSOnSessionUnlock, + key: pkey.FlushDNSOnSessionUnlock, handlerError: someOtherError, wantError: someOtherError, }, @@ -157,14 +153,23 @@ func TestGetUint64(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - u64: tt.handlerValue, - err: tt.handlerError, - }) - value, err := GetUint64(tt.key, tt.defaultValue) - if err != tt.wantError { + // None of the policy settings tested here are integers. + // In fact, we don't have any integer policies as of 2024-10-08. + // However, we can register each of them as an integer policy setting + // for the duration of the test, providing us with something to test against. + if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + s := source.TestSetting[uint64]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + value, err := getUint64(tt.key, tt.defaultValue) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { @@ -177,51 +182,75 @@ func TestGetUint64(t *testing.T) { func TestGetBoolean(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue bool handlerError error defaultValue bool wantValue bool wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", - key: FlushDNSOnSessionUnlock, + key: pkey.FlushDNSOnSessionUnlock, handlerValue: true, wantValue: true, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1}, + }, }, { name: "read non-existing value", - key: LogSCMInteractions, + key: pkey.LogSCMInteractions, handlerValue: false, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: false, }, { name: "reading value returns other error", - key: FlushDNSOnSessionUnlock, + key: pkey.FlushDNSOnSessionUnlock, handlerError: someOtherError, - wantError: someOtherError, + wantError: someOtherError, // expect error... defaultValue: true, - wantValue: false, + wantValue: true, // ...AND default value if the handler fails. + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1}, + }, }, } + registerWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - b: tt.handlerValue, - err: tt.handlerError, - }) - value, err := GetBoolean(tt.key, tt.defaultValue) - if err != tt.wantError { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[bool]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + value, err := getBoolean(tt.key, tt.defaultValue) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -229,60 +258,92 @@ func TestGetBoolean(t *testing.T) { func TestGetPreferenceOption(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue string handlerError error - wantValue setting.PreferenceOption + wantValue ptype.PreferenceOption wantError error + wantMetrics []metrics.TestState }{ { name: "always by policy", - key: EnableIncomingConnections, + key: pkey.EnableIncomingConnections, handlerValue: "always", - wantValue: setting.AlwaysByPolicy, + wantValue: ptype.AlwaysByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "never by policy", - key: EnableIncomingConnections, + key: pkey.EnableIncomingConnections, handlerValue: "never", - wantValue: setting.NeverByPolicy, + wantValue: ptype.NeverByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "use default", - key: EnableIncomingConnections, + key: pkey.EnableIncomingConnections, handlerValue: "", - wantValue: setting.ShowChoiceByPolicy, + wantValue: ptype.ShowChoiceByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "read non-existing value", - key: EnableIncomingConnections, - handlerError: ErrNoSuchKey, - wantValue: setting.ShowChoiceByPolicy, + key: pkey.EnableIncomingConnections, + handlerError: ErrNotConfigured, + wantValue: ptype.ShowChoiceByPolicy, }, { name: "other error is returned", - key: EnableIncomingConnections, + key: pkey.EnableIncomingConnections, handlerError: someOtherError, - wantValue: setting.ShowChoiceByPolicy, + wantValue: ptype.ShowChoiceByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1}, + }, }, } + registerWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) - option, err := GetPreferenceOption(tt.key) - if err != tt.wantError { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + option, err := getPreferenceOption(tt.key, ptype.ShowChoiceByPolicy) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if option != tt.wantValue { t.Errorf("option=%v, want %v", option, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -290,56 +351,84 @@ func TestGetPreferenceOption(t *testing.T) { func TestGetVisibility(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue string handlerError error - wantValue setting.Visibility + wantValue ptype.Visibility wantError error + wantMetrics []metrics.TestState }{ { name: "hidden by policy", - key: AdminConsoleVisibility, + key: pkey.AdminConsoleVisibility, handlerValue: "hide", - wantValue: setting.HiddenByPolicy, + wantValue: ptype.HiddenByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "visibility default", - key: AdminConsoleVisibility, + key: pkey.AdminConsoleVisibility, handlerValue: "show", - wantValue: setting.VisibleByPolicy, + wantValue: ptype.VisibleByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", - key: AdminConsoleVisibility, + key: pkey.AdminConsoleVisibility, handlerValue: "show", - handlerError: ErrNoSuchKey, - wantValue: setting.VisibleByPolicy, + handlerError: ErrNotConfigured, + wantValue: ptype.VisibleByPolicy, }, { name: "other error is returned", - key: AdminConsoleVisibility, + key: pkey.AdminConsoleVisibility, handlerValue: "show", handlerError: someOtherError, - wantValue: setting.VisibleByPolicy, + wantValue: ptype.VisibleByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AdminConsole_error", Value: 1}, + }, }, } + registerWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) - visibility, err := GetVisibility(tt.key) - if err != tt.wantError { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + visibility, err := getVisibility(tt.key) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if visibility != tt.wantValue { t.Errorf("visibility=%v, want %v", visibility, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -347,66 +436,95 @@ func TestGetVisibility(t *testing.T) { func TestGetDuration(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue string handlerError error defaultValue time.Duration wantValue time.Duration wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", - key: KeyExpirationNoticeTime, + key: pkey.KeyExpirationNoticeTime, handlerValue: "2h", wantValue: 2 * time.Hour, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice", Value: 1}, + }, }, { name: "invalid duration value", - key: KeyExpirationNoticeTime, + key: pkey.KeyExpirationNoticeTime, handlerValue: "-20", wantValue: 24 * time.Hour, + wantError: errors.New(`time: missing unit in duration "-20"`), defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, { name: "read non-existing value", - key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + key: pkey.KeyExpirationNoticeTime, + handlerError: ErrNotConfigured, wantValue: 24 * time.Hour, defaultValue: 24 * time.Hour, }, { name: "read non-existing value different default", - key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + key: pkey.KeyExpirationNoticeTime, + handlerError: ErrNotConfigured, wantValue: 0 * time.Second, defaultValue: 0 * time.Second, }, { name: "other error is returned", - key: KeyExpirationNoticeTime, + key: pkey.KeyExpirationNoticeTime, handlerError: someOtherError, wantValue: 24 * time.Hour, wantError: someOtherError, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, } + registerWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) - duration, err := GetDuration(tt.key, tt.defaultValue) - if err != tt.wantError { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + duration, err := getDuration(tt.key, tt.defaultValue) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if duration != tt.wantValue { t.Errorf("duration=%v, want %v", duration, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -414,60 +532,115 @@ func TestGetDuration(t *testing.T) { func TestGetStringArray(t *testing.T) { tests := []struct { name string - key Key + key pkey.Key handlerValue []string handlerError error defaultValue []string wantValue []string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", - key: AllowedSuggestedExitNodes, + key: pkey.AllowedSuggestedExitNodes, handlerValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1}, + }, }, { name: "read non-existing value", - key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + key: pkey.AllowedSuggestedExitNodes, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non nil default", - key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + key: pkey.AllowedSuggestedExitNodes, + handlerError: ErrNotConfigured, defaultValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, wantError: nil, }, { name: "reading value returns other error", - key: AllowedSuggestedExitNodes, + key: pkey.AllowedSuggestedExitNodes, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1}, + }, }, } + registerWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - sArr: tt.handlerValue, - err: tt.handlerError, - }) - value, err := GetStringArray(tt.key, tt.defaultValue) - if err != tt.wantError { + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[[]string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + + value, err := getStringArray(tt.key, tt.defaultValue) + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if !slices.Equal(tt.wantValue, value) { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } +// mustRegisterStoreForTest is like [rsop.RegisterStoreForTest], but it fails the test if the store could not be registered. +func mustRegisterStoreForTest(tb testenv.TB, name string, scope setting.PolicyScope, store source.Store) *rsop.StoreRegistration { + tb.Helper() + reg, err := rsop.RegisterStoreForTest(tb, name, scope, store) + if err != nil { + tb.Fatalf("Failed to register policy store %q as a %v policy source: %v", name, scope, err) + } + return reg +} + +func registerSingleSettingStoreForTest[T source.TestValueType](tb testenv.TB, s source.TestSetting[T]) { + policyStore := source.NewTestStoreOf(tb, s) + mustRegisterStoreForTest(tb, "TestStore", setting.DeviceScope, policyStore) +} + +func BenchmarkGetString(b *testing.B) { + loggerx.SetForTest(b, logger.Discard, logger.Discard) + registerWellKnownSettingsForTest(b) + + wantControlURL := "https://login.tailscale.com" + registerSingleSettingStoreForTest(b, source.TestSettingOf(pkey.ControlURL, wantControlURL)) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + gotControlURL, _ := getString(pkey.ControlURL, "https://controlplane.tailscale.com") + if gotControlURL != wantControlURL { + b.Fatalf("got %v; want %v", gotControlURL, wantControlURL) + } + } +} + func TestSelectControlURL(t *testing.T) { tests := []struct { reg, disk, want string @@ -499,3 +672,13 @@ func TestSelectControlURL(t *testing.T) { } } } + +func errorsMatchForTest(got, want error) bool { + if got == nil && want == nil { + return true + } + if got == nil || want == nil { + return false + } + return errors.Is(got, want) || got.Error() == want.Error() +} diff --git a/util/syspolicy/syspolicy_windows.go b/util/syspolicy/syspolicy_windows.go new file mode 100644 index 000000000..ca0fd329a --- /dev/null +++ b/util/syspolicy/syspolicy_windows.go @@ -0,0 +1,92 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "errors" + "fmt" + "os/user" + + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" +) + +func init() { + // On Windows, we should automatically register the Registry-based policy + // store for the device. If we are running in a user's security context + // (e.g., we're the GUI), we should also register the Registry policy store for + // the user. In the future, we should register (and unregister) user policy + // stores whenever a user connects to (or disconnects from) the local backend. + // This ensures the backend is aware of the user's policy settings and can send + // them to the GUI/CLI/Web clients on demand or whenever they change. + // + // Other platforms, such as macOS, iOS and Android, should register their + // platform-specific policy stores via [RegisterStore] + // (or [RegisterHandler] until they implement the [source.Store] interface). + // + // External code, such as the ipnlocal package, may choose to register + // additional policy stores, such as config files and policies received from + // the control plane. + internal.Init.MustDefer(func() error { + // Do not register or use default policy stores during tests. + // Each test should set up its own necessary configurations. + if testenv.InTest() { + return nil + } + return configureSyspolicy(nil) + }) +} + +// configureSyspolicy configures syspolicy for use on Windows, +// either in test or regular builds depending on whether tb has a non-nil value. +func configureSyspolicy(tb testenv.TB) error { + const localSystemSID = "S-1-5-18" + // Always create and register a machine policy store that reads + // policy settings from the HKEY_LOCAL_MACHINE registry hive. + machineStore, err := source.NewMachinePlatformPolicyStore() + if err != nil { + return fmt.Errorf("failed to create the machine policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore) + } + if err != nil { + return err + } + // Check whether the current process is running as Local System or not. + u, err := user.Current() + if err != nil { + return err + } + if u.Uid == localSystemSID { + return nil + } + // If it's not a Local System's process (e.g., it's the GUI rather than the tailscaled service), + // we should create and use a policy store for the current user that reads + // policy settings from that user's registry hive (HKEY_CURRENT_USER). + userStore, err := source.NewUserPlatformPolicyStore(0) + if err != nil { + return fmt.Errorf("failed to create the current user's policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore) + } + if err != nil { + return err + } + // And also set [setting.CurrentUserScope] as the [setting.DefaultScope], so [GetString], + // [GetVisibility] and similar functions would be returning a merged result + // of the machine's and user's policies. + if !setting.SetDefaultScope(setting.CurrentUserScope) { + return errors.New("current scope already set") + } + return nil +} diff --git a/util/systemd/systemd_nonlinux.go b/util/systemd/systemd_nonlinux.go deleted file mode 100644 index 36214020c..000000000 --- a/util/systemd/systemd_nonlinux.go +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package systemd - -func Ready() {} -func Status(string, ...any) {} diff --git a/util/testenv/testenv.go b/util/testenv/testenv.go index 12ada9003..aa6660411 100644 --- a/util/testenv/testenv.go +++ b/util/testenv/testenv.go @@ -6,6 +6,7 @@ package testenv import ( + "context" "flag" "tailscale.com/types/lazy" @@ -19,3 +20,48 @@ func InTest() bool { return flag.Lookup("test.v") != nil }) } + +// TB is testing.TB, to avoid importing "testing" in non-test code. +type TB interface { + Cleanup(func()) + Error(args ...any) + Errorf(format string, args ...any) + Fail() + FailNow() + Failed() bool + Fatal(args ...any) + Fatalf(format string, args ...any) + Helper() + Log(args ...any) + Logf(format string, args ...any) + Name() string + Setenv(key, value string) + Chdir(dir string) + Skip(args ...any) + SkipNow() + Skipf(format string, args ...any) + Skipped() bool + TempDir() string + Context() context.Context +} + +// InParallelTest reports whether t is running as a parallel test. +// +// Use of this function taints t such that its Parallel method (assuming t is an +// actual *testing.T) will panic if called after this function. +func InParallelTest(t TB) (isParallel bool) { + defer func() { + if r := recover(); r != nil { + isParallel = true + } + }() + t.Chdir(".") // panics in a t.Parallel test + return false +} + +// AssertInTest panics if called outside of a test binary. +func AssertInTest() { + if !InTest() { + panic("func called outside of test binary") + } +} diff --git a/util/testenv/testenv_test.go b/util/testenv/testenv_test.go index 43c332b26..c647d9aec 100644 --- a/util/testenv/testenv_test.go +++ b/util/testenv/testenv_test.go @@ -16,3 +16,16 @@ func TestDeps(t *testing.T) { }, }.Check(t) } + +func TestInParallelTestTrue(t *testing.T) { + t.Parallel() + if !InParallelTest(t) { + t.Fatal("InParallelTest should return true once t.Parallel has been called") + } +} + +func TestInParallelTestFalse(t *testing.T) { + if InParallelTest(t) { + t.Fatal("InParallelTest should return false before t.Parallel has been called") + } +} diff --git a/util/uniq/slice.go b/util/uniq/slice.go deleted file mode 100644 index 4ab933a9d..000000000 --- a/util/uniq/slice.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package uniq provides removal of adjacent duplicate elements in slices. -// It is similar to the unix command uniq. -package uniq - -// ModifySlice removes adjacent duplicate elements from the given slice. It -// adjusts the length of the slice appropriately and zeros the tail. -// -// ModifySlice does O(len(*slice)) operations. -func ModifySlice[E comparable](slice *[]E) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if (*slice)[i] == (*slice)[dst] { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} - -// ModifySliceFunc is the same as ModifySlice except that it allows using a -// custom comparison function. -// -// eq should report whether the two provided elements are equal. -func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if eq((*slice)[dst], (*slice)[i]) { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} diff --git a/util/uniq/slice_test.go b/util/uniq/slice_test.go deleted file mode 100644 index 564fc0866..000000000 --- a/util/uniq/slice_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package uniq_test - -import ( - "reflect" - "strconv" - "testing" - - "tailscale.com/util/uniq" -) - -func runTests(t *testing.T, cb func(*[]uint32)) { - tests := []struct { - // Use uint32 to be different from an int-typed slice index - in []uint32 - want []uint32 - }{ - {in: []uint32{0, 1, 2}, want: []uint32{0, 1, 2}}, - {in: []uint32{0, 1, 2, 2}, want: []uint32{0, 1, 2}}, - {in: []uint32{0, 0, 1, 2}, want: []uint32{0, 1, 2}}, - {in: []uint32{0, 1, 0, 2}, want: []uint32{0, 1, 0, 2}}, - {in: []uint32{0}, want: []uint32{0}}, - {in: []uint32{0, 0}, want: []uint32{0}}, - {in: []uint32{}, want: []uint32{}}, - } - - for _, test := range tests { - in := make([]uint32, len(test.in)) - copy(in, test.in) - cb(&test.in) - if !reflect.DeepEqual(test.in, test.want) { - t.Errorf("uniq.Slice(%v) = %v, want %v", in, test.in, test.want) - } - start := len(test.in) - test.in = test.in[:cap(test.in)] - for i := start; i < len(in); i++ { - if test.in[i] != 0 { - t.Errorf("uniq.Slice(%v): non-0 in tail of %v at index %v", in, test.in, i) - } - } - } -} - -func TestModifySlice(t *testing.T) { - runTests(t, func(slice *[]uint32) { - uniq.ModifySlice(slice) - }) -} - -func TestModifySliceFunc(t *testing.T) { - runTests(t, func(slice *[]uint32) { - uniq.ModifySliceFunc(slice, func(i, j uint32) bool { - return i == j - }) - }) -} - -func Benchmark(b *testing.B) { - benches := []struct { - name string - reset func(s []byte) - }{ - {name: "AllDups", - reset: func(s []byte) { - for i := range s { - s[i] = '*' - } - }, - }, - {name: "NoDups", - reset: func(s []byte) { - for i := range s { - s[i] = byte(i) - } - }, - }, - } - - for _, bb := range benches { - b.Run(bb.name, func(b *testing.B) { - for size := 1; size <= 4096; size *= 16 { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchmark(b, 64, bb.reset) - }) - } - }) - } -} - -func benchmark(b *testing.B, size int64, reset func(s []byte)) { - b.ReportAllocs() - b.SetBytes(size) - s := make([]byte, size) - b.ResetTimer() - for range b.N { - s = s[:size] - reset(s) - uniq.ModifySlice(&s) - } -} diff --git a/util/usermetric/metrics.go b/util/usermetric/metrics.go new file mode 100644 index 000000000..be425fb87 --- /dev/null +++ b/util/usermetric/metrics.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains user-facing metrics that are used by multiple packages. +// Use it to define more common metrics. Any changes to the registry and +// metric types should be in usermetric.go. + +package usermetric + +import ( + "sync" + + "tailscale.com/feature/buildfeatures" +) + +// Metrics contains user-facing metrics that are used by multiple packages. +type Metrics struct { + initOnce sync.Once + + droppedPacketsInbound *MultiLabelMap[DropLabels] + droppedPacketsOutbound *MultiLabelMap[DropLabels] +} + +// DropReason is the reason why a packet was dropped. +type DropReason string + +const ( + // ReasonACL means that the packet was not permitted by ACL. + ReasonACL DropReason = "acl" + + // ReasonMulticast means that the packet was dropped because it was a multicast packet. + ReasonMulticast DropReason = "multicast" + + // ReasonLinkLocalUnicast means that the packet was dropped because it was a link-local unicast packet. + ReasonLinkLocalUnicast DropReason = "link_local_unicast" + + // ReasonTooShort means that the packet was dropped because it was a bad packet, + // this could be due to a short packet. + ReasonTooShort DropReason = "too_short" + + // ReasonFragment means that the packet was dropped because it was an IP fragment. + ReasonFragment DropReason = "fragment" + + // ReasonUnknownProtocol means that the packet was dropped because it was an unknown protocol. + ReasonUnknownProtocol DropReason = "unknown_protocol" + + // ReasonError means that the packet was dropped because of an error. + ReasonError DropReason = "error" +) + +// DropLabels contains common label(s) for dropped packet counters. +type DropLabels struct { + Reason DropReason +} + +// initOnce initializes the common metrics. +func (r *Registry) initOnce() { + if !buildfeatures.HasUserMetrics { + return + } + r.m.initOnce.Do(func() { + r.m.droppedPacketsInbound = NewMultiLabelMapWithRegistry[DropLabels]( + r, + "tailscaled_inbound_dropped_packets_total", + "counter", + "Counts the number of dropped packets received by the node from other peers", + ) + r.m.droppedPacketsOutbound = NewMultiLabelMapWithRegistry[DropLabels]( + r, + "tailscaled_outbound_dropped_packets_total", + "counter", + "Counts the number of packets dropped while being sent to other peers", + ) + }) +} + +// DroppedPacketsOutbound returns the outbound dropped packet metric, creating it +// if necessary. +func (r *Registry) DroppedPacketsOutbound() *MultiLabelMap[DropLabels] { + r.initOnce() + return r.m.droppedPacketsOutbound +} + +// DroppedPacketsInbound returns the inbound dropped packet metric. +func (r *Registry) DroppedPacketsInbound() *MultiLabelMap[DropLabels] { + r.initOnce() + return r.m.droppedPacketsInbound +} diff --git a/util/usermetric/omit.go b/util/usermetric/omit.go new file mode 100644 index 000000000..0611990ab --- /dev/null +++ b/util/usermetric/omit.go @@ -0,0 +1,29 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_usermetrics + +package usermetric + +type Registry struct { + m Metrics +} + +func (*Registry) NewGauge(name, help string) *Gauge { return nil } + +type MultiLabelMap[T comparable] = noopMap[T] + +type noopMap[T comparable] struct{} + +type Gauge struct{} + +func (*Gauge) Set(float64) {} + +func NewMultiLabelMapWithRegistry[T comparable](m *Registry, name string, promType, helpText string) *MultiLabelMap[T] { + return nil +} + +func (*noopMap[T]) Add(T, int64) {} +func (*noopMap[T]) Set(T, any) {} + +func (r *Registry) Handler(any, any) {} // no-op HTTP handler diff --git a/util/usermetric/usermetric.go b/util/usermetric/usermetric.go index c964e08a7..1805a5dbe 100644 --- a/util/usermetric/usermetric.go +++ b/util/usermetric/usermetric.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_usermetrics + // Package usermetric provides a container and handler // for user-facing metrics. package usermetric @@ -14,13 +16,21 @@ import ( "tailscale.com/metrics" "tailscale.com/tsweb/varz" + "tailscale.com/util/set" ) // Registry tracks user-facing metrics of various Tailscale subsystems. type Registry struct { vars expvar.Map + + // m contains common metrics owned by the registry. + m Metrics } +// MultiLabelMap is an alias for metrics.MultiLabelMap in the common case, +// or an alias to a lighter type when usermetrics are omitted from the build. +type MultiLabelMap[T comparable] = metrics.MultiLabelMap[T] + // NewMultiLabelMapWithRegistry creates and register a new // MultiLabelMap[T] variable with the given name and returns it. // The variable is registered with the userfacing metrics package. @@ -103,3 +113,13 @@ func (r *Registry) String() string { return sb.String() } + +// Metrics returns the name of all the metrics in the registry. +func (r *Registry) MetricNames() []string { + ret := make(set.Set[string]) + r.vars.Do(func(kv expvar.KeyValue) { + ret.Add(kv.Key) + }) + + return ret.Slice() +} diff --git a/util/vizerror/vizerror.go b/util/vizerror/vizerror.go index 158786494..919d765d0 100644 --- a/util/vizerror/vizerror.go +++ b/util/vizerror/vizerror.go @@ -12,35 +12,67 @@ import ( // Error is an error that is safe to display to end users. type Error struct { - err error + publicErr error // visible to end users + wrapped error // internal } -// Error implements the error interface. +// Error implements the error interface. The returned string is safe to display +// to end users. func (e Error) Error() string { - return e.err.Error() + return e.publicErr.Error() } // New returns an error that formats as the given text. It always returns a vizerror.Error. -func New(text string) error { - return Error{errors.New(text)} +func New(publicMsg string) error { + err := errors.New(publicMsg) + return Error{ + publicErr: err, + wrapped: err, + } } -// Errorf returns an Error with the specified format and values. It always returns a vizerror.Error. -func Errorf(format string, a ...any) error { - return Error{fmt.Errorf(format, a...)} +// Errorf returns an Error with the specified publicMsgFormat and values. It always returns a vizerror.Error. +// +// Warning: avoid using an error as one of the format arguments, as this will cause the text +// of that error to be displayed to the end user (which is probably not what you want). +func Errorf(publicMsgFormat string, a ...any) error { + err := fmt.Errorf(publicMsgFormat, a...) + return Error{ + publicErr: err, + wrapped: err, + } } // Unwrap returns the underlying error. +// +// If the Error was constructed using [WrapWithMessage], this is the wrapped (internal) error +// and not the user-visible error message. func (e Error) Unwrap() error { - return e.err + return e.wrapped } -// Wrap wraps err with a vizerror.Error. -func Wrap(err error) error { - if err == nil { +// Wrap wraps publicErr with a vizerror.Error. +// +// Deprecated: this is almost always the wrong thing to do. Are you really sure +// you know exactly what err.Error() will stringify to and be safe to show to +// users? [WrapWithMessage] is probably what you want. +func Wrap(publicErr error) error { + if publicErr == nil { return nil } - return Error{err} + return Error{publicErr: publicErr, wrapped: publicErr} +} + +// WrapWithMessage wraps the given error with a message that's safe to display +// to end users. The text of the wrapped error will not be displayed to end +// users. +// +// WrapWithMessage should almost always be preferred to [Wrap]. +func WrapWithMessage(wrapped error, publicMsg string) error { + return Error{ + publicErr: errors.New(publicMsg), + wrapped: wrapped, + } } // As returns the first vizerror.Error in err's chain. diff --git a/util/vizerror/vizerror_test.go b/util/vizerror/vizerror_test.go index bbd2c07e5..242ca6462 100644 --- a/util/vizerror/vizerror_test.go +++ b/util/vizerror/vizerror_test.go @@ -42,3 +42,25 @@ func TestAs(t *testing.T) { t.Errorf("As() returned error %v, want %v", got, verr) } } + +func TestWrap(t *testing.T) { + wrapped := errors.New("wrapped") + err := Wrap(wrapped) + if err.Error() != "wrapped" { + t.Errorf(`Wrap(wrapped).Error() = %q, want %q`, err.Error(), "wrapped") + } + if errors.Unwrap(err) != wrapped { + t.Errorf("Unwrap = %q, want %q", errors.Unwrap(err), wrapped) + } +} + +func TestWrapWithMessage(t *testing.T) { + wrapped := errors.New("wrapped") + err := WrapWithMessage(wrapped, "safe") + if err.Error() != "safe" { + t.Errorf(`WrapWithMessage(wrapped, "safe").Error() = %q, want %q`, err.Error(), "safe") + } + if errors.Unwrap(err) != wrapped { + t.Errorf("Unwrap = %q, want %q", errors.Unwrap(err), wrapped) + } +} diff --git a/util/winutil/authenticode/zsyscall_windows.go b/util/winutil/authenticode/zsyscall_windows.go index 643721e06..f1fba2828 100644 --- a/util/winutil/authenticode/zsyscall_windows.go +++ b/util/winutil/authenticode/zsyscall_windows.go @@ -56,7 +56,7 @@ var ( ) func cryptMsgClose(cryptMsg windows.Handle) (err error) { - r1, _, e1 := syscall.Syscall(procCryptMsgClose.Addr(), 1, uintptr(cryptMsg), 0, 0) + r1, _, e1 := syscall.SyscallN(procCryptMsgClose.Addr(), uintptr(cryptMsg)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -64,7 +64,7 @@ func cryptMsgClose(cryptMsg windows.Handle) (err error) { } func cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, data unsafe.Pointer, dataLen *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procCryptMsgGetParam.Addr(), 5, uintptr(cryptMsg), uintptr(paramType), uintptr(index), uintptr(data), uintptr(unsafe.Pointer(dataLen)), 0) + r1, _, e1 := syscall.SyscallN(procCryptMsgGetParam.Addr(), uintptr(cryptMsg), uintptr(paramType), uintptr(index), uintptr(data), uintptr(unsafe.Pointer(dataLen))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -72,7 +72,7 @@ func cryptMsgGetParam(cryptMsg windows.Handle, paramType uint32, index uint32, d } func cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signerIndex uint32, pbSignedBlob *byte, cbSignedBlob uint32, pbDecoded *byte, pdbDecoded *uint32, ppSignerCert **windows.CertContext) (err error) { - r1, _, e1 := syscall.Syscall9(procCryptVerifyMessageSignature.Addr(), 7, uintptr(unsafe.Pointer(pVerifyPara)), uintptr(signerIndex), uintptr(unsafe.Pointer(pbSignedBlob)), uintptr(cbSignedBlob), uintptr(unsafe.Pointer(pbDecoded)), uintptr(unsafe.Pointer(pdbDecoded)), uintptr(unsafe.Pointer(ppSignerCert)), 0, 0) + r1, _, e1 := syscall.SyscallN(procCryptVerifyMessageSignature.Addr(), uintptr(unsafe.Pointer(pVerifyPara)), uintptr(signerIndex), uintptr(unsafe.Pointer(pbSignedBlob)), uintptr(cbSignedBlob), uintptr(unsafe.Pointer(pbDecoded)), uintptr(unsafe.Pointer(pdbDecoded)), uintptr(unsafe.Pointer(ppSignerCert))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -80,13 +80,13 @@ func cryptVerifyMessageSignature(pVerifyPara *_CRYPT_VERIFY_MESSAGE_PARA, signer } func msiGetFileSignatureInformation(signedObjectPath *uint16, flags uint32, certCtx **windows.CertContext, pbHashData *byte, cbHashData *uint32) (ret wingoes.HRESULT) { - r0, _, _ := syscall.Syscall6(procMsiGetFileSignatureInformationW.Addr(), 5, uintptr(unsafe.Pointer(signedObjectPath)), uintptr(flags), uintptr(unsafe.Pointer(certCtx)), uintptr(unsafe.Pointer(pbHashData)), uintptr(unsafe.Pointer(cbHashData)), 0) + r0, _, _ := syscall.SyscallN(procMsiGetFileSignatureInformationW.Addr(), uintptr(unsafe.Pointer(signedObjectPath)), uintptr(flags), uintptr(unsafe.Pointer(certCtx)), uintptr(unsafe.Pointer(pbHashData)), uintptr(unsafe.Pointer(cbHashData))) ret = wingoes.HRESULT(r0) return } func cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GUID, hashAlgorithm *uint16, strongHashPolicy *windows.CertStrongSignPara, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procCryptCATAdminAcquireContext2.Addr(), 5, uintptr(unsafe.Pointer(hCatAdmin)), uintptr(unsafe.Pointer(pgSubsystem)), uintptr(unsafe.Pointer(hashAlgorithm)), uintptr(unsafe.Pointer(strongHashPolicy)), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminAcquireContext2.Addr(), uintptr(unsafe.Pointer(hCatAdmin)), uintptr(unsafe.Pointer(pgSubsystem)), uintptr(unsafe.Pointer(hashAlgorithm)), uintptr(unsafe.Pointer(strongHashPolicy)), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -94,7 +94,7 @@ func cryptCATAdminAcquireContext2(hCatAdmin *_HCATADMIN, pgSubsystem *windows.GU } func cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Handle, pcbHash *uint32, pbHash *byte, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procCryptCATAdminCalcHashFromFileHandle2.Addr(), 5, uintptr(hCatAdmin), uintptr(file), uintptr(unsafe.Pointer(pcbHash)), uintptr(unsafe.Pointer(pbHash)), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminCalcHashFromFileHandle2.Addr(), uintptr(hCatAdmin), uintptr(file), uintptr(unsafe.Pointer(pcbHash)), uintptr(unsafe.Pointer(pbHash)), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -102,7 +102,7 @@ func cryptCATAdminCalcHashFromFileHandle2(hCatAdmin _HCATADMIN, file windows.Han } func cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash uint32, flags uint32, prevCatInfo *_HCATINFO) (ret _HCATINFO, err error) { - r0, _, e1 := syscall.Syscall6(procCryptCATAdminEnumCatalogFromHash.Addr(), 5, uintptr(hCatAdmin), uintptr(unsafe.Pointer(pbHash)), uintptr(cbHash), uintptr(flags), uintptr(unsafe.Pointer(prevCatInfo)), 0) + r0, _, e1 := syscall.SyscallN(procCryptCATAdminEnumCatalogFromHash.Addr(), uintptr(hCatAdmin), uintptr(unsafe.Pointer(pbHash)), uintptr(cbHash), uintptr(flags), uintptr(unsafe.Pointer(prevCatInfo))) ret = _HCATINFO(r0) if ret == 0 { err = errnoErr(e1) @@ -111,7 +111,7 @@ func cryptCATAdminEnumCatalogFromHash(hCatAdmin _HCATADMIN, pbHash *byte, cbHash } func cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseCatalogContext.Addr(), 3, uintptr(hCatAdmin), uintptr(hCatInfo), uintptr(flags)) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminReleaseCatalogContext.Addr(), uintptr(hCatAdmin), uintptr(hCatInfo), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -119,7 +119,7 @@ func cryptCATAdminReleaseCatalogContext(hCatAdmin _HCATADMIN, hCatInfo _HCATINFO } func cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procCryptCATAdminReleaseContext.Addr(), 2, uintptr(hCatAdmin), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procCryptCATAdminReleaseContext.Addr(), uintptr(hCatAdmin), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -127,7 +127,7 @@ func cryptCATAdminReleaseContext(hCatAdmin _HCATADMIN, flags uint32) (err error) } func cryptCATAdminCatalogInfoFromContext(hCatInfo _HCATINFO, catInfo *_CATALOG_INFO, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procCryptCATCatalogInfoFromContext.Addr(), 3, uintptr(hCatInfo), uintptr(unsafe.Pointer(catInfo)), uintptr(flags)) + r1, _, e1 := syscall.SyscallN(procCryptCATCatalogInfoFromContext.Addr(), uintptr(hCatInfo), uintptr(unsafe.Pointer(catInfo)), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/util/winutil/gp/gp_windows_test.go b/util/winutil/gp/gp_windows_test.go index e2520b46d..f89206883 100644 --- a/util/winutil/gp/gp_windows_test.go +++ b/util/winutil/gp/gp_windows_test.go @@ -182,16 +182,16 @@ func doWithMachinePolicyLocked(t *testing.T, f func()) { f() } -func doWithCustomEnterLeaveFuncs(t *testing.T, f func(l *PolicyLock), enter func(bool) (policyLockHandle, error), leave func(policyLockHandle) error) { +func doWithCustomEnterLeaveFuncs(t *testing.T, f func(*PolicyLock), enter func(bool) (policyLockHandle, error), leave func(policyLockHandle) error) { t.Helper() - l := NewMachinePolicyLock() - l.enterFn, l.leaveFn = enter, leave + lock := NewMachinePolicyLock() + lock.enterFn, lock.leaveFn = enter, leave t.Cleanup(func() { - if err := l.Close(); err != nil { + if err := lock.Close(); err != nil { t.Fatalf("(*PolicyLock).Close failed: %v", err) } }) - f(l) + f(lock) } diff --git a/util/winutil/gp/policylock_windows.go b/util/winutil/gp/policylock_windows.go index 95453aa16..6c3ca0baf 100644 --- a/util/winutil/gp/policylock_windows.go +++ b/util/winutil/gp/policylock_windows.go @@ -48,10 +48,35 @@ type policyLockResult struct { } var ( - // ErrInvalidLockState is returned by (*PolicyLock).Lock if the lock has a zero value or has already been closed. + // ErrInvalidLockState is returned by [PolicyLock.Lock] if the lock has a zero value or has already been closed. ErrInvalidLockState = errors.New("the lock has not been created or has already been closed") + // ErrLockRestricted is returned by [PolicyLock.Lock] if the lock cannot be acquired due to a restriction in place, + // such as when [RestrictPolicyLocks] has been called. + ErrLockRestricted = errors.New("the lock cannot be acquired due to a restriction in place") ) +var policyLockRestricted atomic.Int32 + +// RestrictPolicyLocks forces all [PolicyLock.Lock] calls to return [ErrLockRestricted] +// until the returned function is called to remove the restriction. +// +// It is safe to call the returned function multiple times, but the restriction will only +// be removed once. If [RestrictPolicyLocks] is called multiple times, each call must be +// matched by a corresponding call to the returned function to fully remove the restrictions. +// +// It is primarily used to prevent certain deadlocks, such as when tailscaled attempts to acquire +// a policy lock during startup. If the service starts due to Tailscale being installed by GPSI, +// the write lock will be held by the Group Policy service throughout the installation, +// preventing tailscaled from acquiring the read lock. Since Group Policy waits for the installation +// to complete, and therefore for tailscaled to start, before releasing the write lock, this scenario +// would result in a deadlock. See tailscale/tailscale#14416 for more information. +func RestrictPolicyLocks() (removeRestriction func()) { + policyLockRestricted.Add(1) + return sync.OnceFunc(func() { + policyLockRestricted.Add(-1) + }) +} + // NewMachinePolicyLock creates a PolicyLock that facilitates pausing the // application of computer policy. To avoid deadlocks when acquiring both // machine and user locks, acquire the user lock before the machine lock. @@ -102,27 +127,32 @@ func NewUserPolicyLock(token windows.Token) (*PolicyLock, error) { return lock, nil } -// Lock locks l. -// It returns ErrNotInitialized if l has a zero value or has already been closed, -// or an Errno if the underlying Group Policy lock cannot be acquired. +// Lock locks lk. +// It returns [ErrInvalidLockState] if lk has a zero value or has already been closed, +// [ErrLockRestricted] if the lock cannot be acquired due to a restriction in place, +// or a [syscall.Errno] if the underlying Group Policy lock cannot be acquired. // -// As a special case, it fails with windows.ERROR_ACCESS_DENIED -// if l is a user policy lock, and the corresponding user is not logged in +// As a special case, it fails with [windows.ERROR_ACCESS_DENIED] +// if lk is a user policy lock, and the corresponding user is not logged in // interactively at the time of the call. -func (l *PolicyLock) Lock() error { - l.mu.Lock() - defer l.mu.Unlock() - if l.lockCnt.Add(2)&1 == 0 { +func (lk *PolicyLock) Lock() error { + if policyLockRestricted.Load() > 0 { + return ErrLockRestricted + } + + lk.mu.Lock() + defer lk.mu.Unlock() + if lk.lockCnt.Add(2)&1 == 0 { // The lock cannot be acquired because it has either never been properly // created or its Close method has already been called. However, we need // to call Unlock to both decrement lockCnt and leave the underlying // CriticalPolicySection if we won the race with another goroutine and // now own the lock. - l.Unlock() + lk.Unlock() return ErrInvalidLockState } - if l.handle != 0 { + if lk.handle != 0 { // The underlying CriticalPolicySection is already acquired. // It is an R-Lock (with the W-counterpart owned by the Group Policy service), // meaning that it can be acquired by multiple readers simultaneously. @@ -130,20 +160,20 @@ func (l *PolicyLock) Lock() error { return nil } - return l.lockSlow() + return lk.lockSlow() } // lockSlow calls enterCriticalPolicySection to acquire the underlying GP read lock. // It waits for either the lock to be acquired, or for the Close method to be called. // // l.mu must be held. -func (l *PolicyLock) lockSlow() (err error) { +func (lk *PolicyLock) lockSlow() (err error) { defer func() { if err != nil { // Decrement the counter if the lock cannot be acquired, // and complete the pending close request if we're the last owner. - if l.lockCnt.Add(-2) == 0 { - l.closeInternal() + if lk.lockCnt.Add(-2) == 0 { + lk.closeInternal() } } }() @@ -160,12 +190,12 @@ func (l *PolicyLock) lockSlow() (err error) { resultCh := make(chan policyLockResult) go func() { - closing := l.closing - if l.scope == UserPolicy && l.token != 0 { + closing := lk.closing + if lk.scope == UserPolicy && lk.token != 0 { // Impersonate the user whose critical policy section we want to acquire. runtime.LockOSThread() defer runtime.UnlockOSThread() - if err := impersonateLoggedOnUser(l.token); err != nil { + if err := impersonateLoggedOnUser(lk.token); err != nil { initCh <- err return } @@ -179,10 +209,10 @@ func (l *PolicyLock) lockSlow() (err error) { close(initCh) var machine bool - if l.scope == MachinePolicy { + if lk.scope == MachinePolicy { machine = true } - handle, err := l.enterFn(machine) + handle, err := lk.enterFn(machine) send_result: for { @@ -196,7 +226,7 @@ func (l *PolicyLock) lockSlow() (err error) { // The lock is being closed, and we lost the race to l.closing // it the calling goroutine. if err == nil { - l.leaveFn(handle) + lk.leaveFn(handle) } break send_result default: @@ -217,21 +247,21 @@ func (l *PolicyLock) lockSlow() (err error) { select { case result := <-resultCh: if result.err == nil { - l.handle = result.handle + lk.handle = result.handle } return result.err - case <-l.closing: + case <-lk.closing: return ErrInvalidLockState } } // Unlock unlocks l. // It panics if l is not locked on entry to Unlock. -func (l *PolicyLock) Unlock() { - l.mu.Lock() - defer l.mu.Unlock() +func (lk *PolicyLock) Unlock() { + lk.mu.Lock() + defer lk.mu.Unlock() - lockCnt := l.lockCnt.Add(-2) + lockCnt := lk.lockCnt.Add(-2) if lockCnt < 0 { panic("negative lockCnt") } @@ -243,33 +273,33 @@ func (l *PolicyLock) Unlock() { return } - if l.handle != 0 { + if lk.handle != 0 { // Impersonation is not required to unlock a critical policy section. // The handle we pass determines which mutex will be unlocked. - leaveCriticalPolicySection(l.handle) - l.handle = 0 + leaveCriticalPolicySection(lk.handle) + lk.handle = 0 } if lockCnt == 0 { // Complete the pending close request if there's no more readers. - l.closeInternal() + lk.closeInternal() } } // Close releases resources associated with l. // It is a no-op for the machine policy lock. -func (l *PolicyLock) Close() error { - lockCnt := l.lockCnt.Load() +func (lk *PolicyLock) Close() error { + lockCnt := lk.lockCnt.Load() if lockCnt&1 == 0 { // The lock has never been initialized, or close has already been called. return nil } - close(l.closing) + close(lk.closing) // Unset the LSB to indicate a pending close request. - for !l.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) { - lockCnt = l.lockCnt.Load() + for !lk.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) { + lockCnt = lk.lockCnt.Load() } if lockCnt != 0 { @@ -277,16 +307,16 @@ func (l *PolicyLock) Close() error { return nil } - return l.closeInternal() + return lk.closeInternal() } -func (l *PolicyLock) closeInternal() error { - if l.token != 0 { - if err := l.token.Close(); err != nil { +func (lk *PolicyLock) closeInternal() error { + if lk.token != 0 { + if err := lk.token.Close(); err != nil { return err } - l.token = 0 + lk.token = 0 } - l.closing = nil + lk.closing = nil return nil } diff --git a/util/winutil/gp/zsyscall_windows.go b/util/winutil/gp/zsyscall_windows.go index 5e40ec3d1..41c240c26 100644 --- a/util/winutil/gp/zsyscall_windows.go +++ b/util/winutil/gp/zsyscall_windows.go @@ -50,7 +50,7 @@ var ( ) func impersonateLoggedOnUser(token windows.Token) (err error) { - r1, _, e1 := syscall.Syscall(procImpersonateLoggedOnUser.Addr(), 1, uintptr(token), 0, 0) + r1, _, e1 := syscall.SyscallN(procImpersonateLoggedOnUser.Addr(), uintptr(token)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -62,7 +62,7 @@ func enterCriticalPolicySection(machine bool) (handle policyLockHandle, err erro if machine { _p0 = 1 } - r0, _, e1 := syscall.Syscall(procEnterCriticalPolicySection.Addr(), 1, uintptr(_p0), 0, 0) + r0, _, e1 := syscall.SyscallN(procEnterCriticalPolicySection.Addr(), uintptr(_p0)) handle = policyLockHandle(r0) if int32(handle) == 0 { err = errnoErr(e1) @@ -71,7 +71,7 @@ func enterCriticalPolicySection(machine bool) (handle policyLockHandle, err erro } func leaveCriticalPolicySection(handle policyLockHandle) (err error) { - r1, _, e1 := syscall.Syscall(procLeaveCriticalPolicySection.Addr(), 1, uintptr(handle), 0, 0) + r1, _, e1 := syscall.SyscallN(procLeaveCriticalPolicySection.Addr(), uintptr(handle)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -83,7 +83,7 @@ func refreshPolicyEx(machine bool, flags uint32) (err error) { if machine { _p0 = 1 } - r1, _, e1 := syscall.Syscall(procRefreshPolicyEx.Addr(), 2, uintptr(_p0), uintptr(flags), 0) + r1, _, e1 := syscall.SyscallN(procRefreshPolicyEx.Addr(), uintptr(_p0), uintptr(flags)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -95,7 +95,7 @@ func registerGPNotification(event windows.Handle, machine bool) (err error) { if machine { _p0 = 1 } - r1, _, e1 := syscall.Syscall(procRegisterGPNotification.Addr(), 2, uintptr(event), uintptr(_p0), 0) + r1, _, e1 := syscall.SyscallN(procRegisterGPNotification.Addr(), uintptr(event), uintptr(_p0)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -103,7 +103,7 @@ func registerGPNotification(event windows.Handle, machine bool) (err error) { } func unregisterGPNotification(event windows.Handle) (err error) { - r1, _, e1 := syscall.Syscall(procUnregisterGPNotification.Addr(), 1, uintptr(event), 0, 0) + r1, _, e1 := syscall.SyscallN(procUnregisterGPNotification.Addr(), uintptr(event)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/util/winutil/restartmgr_windows.go b/util/winutil/restartmgr_windows.go index a52e2fee9..6f549de55 100644 --- a/util/winutil/restartmgr_windows.go +++ b/util/winutil/restartmgr_windows.go @@ -19,7 +19,6 @@ import ( "github.com/dblohm7/wingoes" "golang.org/x/sys/windows" "tailscale.com/types/logger" - "tailscale.com/util/multierr" ) var ( @@ -538,7 +537,7 @@ func (rps RestartableProcesses) Terminate(logf logger.Logf, exitCode uint32, tim } if len(errs) != 0 { - return multierr.New(errs...) + return errors.Join(errs...) } return nil } diff --git a/util/winutil/s4u/lsa_windows.go b/util/winutil/s4u/lsa_windows.go index 3ff2171f9..3276b2676 100644 --- a/util/winutil/s4u/lsa_windows.go +++ b/util/winutil/s4u/lsa_windows.go @@ -256,8 +256,8 @@ func checkDomainAccount(username string) (sanitizedUserName string, isDomainAcco // errors.Is to check for it. When capLevel == CapCreateProcess, the logon // enforces the user's logon hours policy (when present). func (ls *lsaSession) logonAs(srcName string, u *user.User, capLevel CapabilityLevel) (token windows.Token, err error) { - if l := len(srcName); l == 0 || l > _TOKEN_SOURCE_LENGTH { - return 0, fmt.Errorf("%w, actual length is %d", ErrBadSrcName, l) + if ln := len(srcName); ln == 0 || ln > _TOKEN_SOURCE_LENGTH { + return 0, fmt.Errorf("%w, actual length is %d", ErrBadSrcName, ln) } if err := checkASCII(srcName); err != nil { return 0, fmt.Errorf("%w: %v", ErrBadSrcName, err) diff --git a/util/winutil/s4u/s4u_windows.go b/util/winutil/s4u/s4u_windows.go index a12b4786a..8c8e02dbe 100644 --- a/util/winutil/s4u/s4u_windows.go +++ b/util/winutil/s4u/s4u_windows.go @@ -17,6 +17,7 @@ import ( "slices" "strconv" "strings" + "sync" "sync/atomic" "unsafe" @@ -128,9 +129,10 @@ func Login(logf logger.Logf, srcName string, u *user.User, capLevel CapabilityLe if err != nil { return nil, err } + tokenCloseOnce := sync.OnceFunc(func() { token.Close() }) defer func() { if err != nil { - token.Close() + tokenCloseOnce() } }() @@ -162,6 +164,7 @@ func Login(logf logger.Logf, srcName string, u *user.User, capLevel CapabilityLe sessToken.Close() } }() + tokenCloseOnce() } userProfile, err := winutil.LoadUserProfile(sessToken, u) @@ -935,10 +938,10 @@ func mergeEnv(existingEnv []string, extraEnv map[string]string) []string { result = append(result, strings.Join([]string{k, v}, "=")) } - slices.SortFunc(result, func(l, r string) int { - kl, _, _ := strings.Cut(l, "=") - kr, _, _ := strings.Cut(r, "=") - return strings.Compare(kl, kr) + slices.SortFunc(result, func(a, b string) int { + ka, _, _ := strings.Cut(a, "=") + kb, _, _ := strings.Cut(b, "=") + return strings.Compare(ka, kb) }) return result } diff --git a/util/winutil/s4u/zsyscall_windows.go b/util/winutil/s4u/zsyscall_windows.go index 6a8c78427..db647dee4 100644 --- a/util/winutil/s4u/zsyscall_windows.go +++ b/util/winutil/s4u/zsyscall_windows.go @@ -52,7 +52,7 @@ var ( ) func allocateLocallyUniqueId(luid *windows.LUID) (err error) { - r1, _, e1 := syscall.Syscall(procAllocateLocallyUniqueId.Addr(), 1, uintptr(unsafe.Pointer(luid)), 0, 0) + r1, _, e1 := syscall.SyscallN(procAllocateLocallyUniqueId.Addr(), uintptr(unsafe.Pointer(luid))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -60,7 +60,7 @@ func allocateLocallyUniqueId(luid *windows.LUID) (err error) { } func impersonateLoggedOnUser(token windows.Token) (err error) { - r1, _, e1 := syscall.Syscall(procImpersonateLoggedOnUser.Addr(), 1, uintptr(token), 0, 0) + r1, _, e1 := syscall.SyscallN(procImpersonateLoggedOnUser.Addr(), uintptr(token)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -68,37 +68,37 @@ func impersonateLoggedOnUser(token windows.Token) (err error) { } func lsaConnectUntrusted(lsaHandle *_LSAHANDLE) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaConnectUntrusted.Addr(), 1, uintptr(unsafe.Pointer(lsaHandle)), 0, 0) + r0, _, _ := syscall.SyscallN(procLsaConnectUntrusted.Addr(), uintptr(unsafe.Pointer(lsaHandle))) ret = windows.NTStatus(r0) return } func lsaDeregisterLogonProcess(lsaHandle _LSAHANDLE) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaDeregisterLogonProcess.Addr(), 1, uintptr(lsaHandle), 0, 0) + r0, _, _ := syscall.SyscallN(procLsaDeregisterLogonProcess.Addr(), uintptr(lsaHandle)) ret = windows.NTStatus(r0) return } func lsaFreeReturnBuffer(buffer uintptr) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaFreeReturnBuffer.Addr(), 1, uintptr(buffer), 0, 0) + r0, _, _ := syscall.SyscallN(procLsaFreeReturnBuffer.Addr(), uintptr(buffer)) ret = windows.NTStatus(r0) return } func lsaLogonUser(lsaHandle _LSAHANDLE, originName *windows.NTString, logonType _SECURITY_LOGON_TYPE, authenticationPackage uint32, authenticationInformation unsafe.Pointer, authenticationInformationLength uint32, localGroups *windows.Tokengroups, sourceContext *_TOKEN_SOURCE, profileBuffer *uintptr, profileBufferLength *uint32, logonID *windows.LUID, token *windows.Token, quotas *_QUOTA_LIMITS, subStatus *windows.NTStatus) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall15(procLsaLogonUser.Addr(), 14, uintptr(lsaHandle), uintptr(unsafe.Pointer(originName)), uintptr(logonType), uintptr(authenticationPackage), uintptr(authenticationInformation), uintptr(authenticationInformationLength), uintptr(unsafe.Pointer(localGroups)), uintptr(unsafe.Pointer(sourceContext)), uintptr(unsafe.Pointer(profileBuffer)), uintptr(unsafe.Pointer(profileBufferLength)), uintptr(unsafe.Pointer(logonID)), uintptr(unsafe.Pointer(token)), uintptr(unsafe.Pointer(quotas)), uintptr(unsafe.Pointer(subStatus)), 0) + r0, _, _ := syscall.SyscallN(procLsaLogonUser.Addr(), uintptr(lsaHandle), uintptr(unsafe.Pointer(originName)), uintptr(logonType), uintptr(authenticationPackage), uintptr(authenticationInformation), uintptr(authenticationInformationLength), uintptr(unsafe.Pointer(localGroups)), uintptr(unsafe.Pointer(sourceContext)), uintptr(unsafe.Pointer(profileBuffer)), uintptr(unsafe.Pointer(profileBufferLength)), uintptr(unsafe.Pointer(logonID)), uintptr(unsafe.Pointer(token)), uintptr(unsafe.Pointer(quotas)), uintptr(unsafe.Pointer(subStatus))) ret = windows.NTStatus(r0) return } func lsaLookupAuthenticationPackage(lsaHandle _LSAHANDLE, packageName *windows.NTString, authenticationPackage *uint32) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaLookupAuthenticationPackage.Addr(), 3, uintptr(lsaHandle), uintptr(unsafe.Pointer(packageName)), uintptr(unsafe.Pointer(authenticationPackage))) + r0, _, _ := syscall.SyscallN(procLsaLookupAuthenticationPackage.Addr(), uintptr(lsaHandle), uintptr(unsafe.Pointer(packageName)), uintptr(unsafe.Pointer(authenticationPackage))) ret = windows.NTStatus(r0) return } func lsaRegisterLogonProcess(logonProcessName *windows.NTString, lsaHandle *_LSAHANDLE, securityMode *_LSA_OPERATIONAL_MODE) (ret windows.NTStatus) { - r0, _, _ := syscall.Syscall(procLsaRegisterLogonProcess.Addr(), 3, uintptr(unsafe.Pointer(logonProcessName)), uintptr(unsafe.Pointer(lsaHandle)), uintptr(unsafe.Pointer(securityMode))) + r0, _, _ := syscall.SyscallN(procLsaRegisterLogonProcess.Addr(), uintptr(unsafe.Pointer(logonProcessName)), uintptr(unsafe.Pointer(lsaHandle)), uintptr(unsafe.Pointer(securityMode))) ret = windows.NTStatus(r0) return } diff --git a/util/winutil/startupinfo_windows.go b/util/winutil/startupinfo_windows.go index e04e9ea9b..edf48fa65 100644 --- a/util/winutil/startupinfo_windows.go +++ b/util/winutil/startupinfo_windows.go @@ -83,8 +83,8 @@ func (sib *StartupInfoBuilder) Resolve() (startupInfo *windows.StartupInfo, inhe // Always create a Unicode environment. createProcessFlags = windows.CREATE_UNICODE_ENVIRONMENT - if l := uint32(len(sib.attrs)); l > 0 { - attrCont, err := windows.NewProcThreadAttributeList(l) + if ln := uint32(len(sib.attrs)); ln > 0 { + attrCont, err := windows.NewProcThreadAttributeList(ln) if err != nil { return nil, false, 0, err } diff --git a/util/winutil/winenv/zsyscall_windows.go b/util/winutil/winenv/zsyscall_windows.go index 2bdfdd9b1..7e93c7952 100644 --- a/util/winutil/winenv/zsyscall_windows.go +++ b/util/winutil/winenv/zsyscall_windows.go @@ -55,7 +55,7 @@ func isDeviceRegisteredWithManagement(isMDMRegistered *bool, upnBufLen uint32, u if *isMDMRegistered { _p0 = 1 } - r0, _, e1 := syscall.Syscall(procIsDeviceRegisteredWithManagement.Addr(), 3, uintptr(unsafe.Pointer(&_p0)), uintptr(upnBufLen), uintptr(unsafe.Pointer(upnBuf))) + r0, _, e1 := syscall.SyscallN(procIsDeviceRegisteredWithManagement.Addr(), uintptr(unsafe.Pointer(&_p0)), uintptr(upnBufLen), uintptr(unsafe.Pointer(upnBuf))) *isMDMRegistered = _p0 != 0 hr = int32(r0) if hr == 0 { @@ -65,13 +65,13 @@ func isDeviceRegisteredWithManagement(isMDMRegistered *bool, upnBufLen uint32, u } func verSetConditionMask(condMask verCondMask, typ verTypeMask, cond verCond) (res verCondMask) { - r0, _, _ := syscall.Syscall(procVerSetConditionMask.Addr(), 3, uintptr(condMask), uintptr(typ), uintptr(cond)) + r0, _, _ := syscall.SyscallN(procVerSetConditionMask.Addr(), uintptr(condMask), uintptr(typ), uintptr(cond)) res = verCondMask(r0) return } func verifyVersionInfo(verInfo *osVersionInfoEx, typ verTypeMask, cond verCondMask) (res bool) { - r0, _, _ := syscall.Syscall(procVerifyVersionInfoW.Addr(), 3, uintptr(unsafe.Pointer(verInfo)), uintptr(typ), uintptr(cond)) + r0, _, _ := syscall.SyscallN(procVerifyVersionInfoW.Addr(), uintptr(unsafe.Pointer(verInfo)), uintptr(typ), uintptr(cond)) res = r0 != 0 return } diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index 5dde9a347..c935b210e 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -8,8 +8,10 @@ import ( "fmt" "log" "math" + "os" "os/exec" "os/user" + "path/filepath" "reflect" "runtime" "strings" @@ -33,6 +35,10 @@ var ErrNoShell = errors.New("no Shell process is present") // ErrNoValue is returned when the value doesn't exist in the registry. var ErrNoValue = registry.ErrNotExist +// ErrBadRegValueFormat is returned when a string value does not match the +// expected format. +var ErrBadRegValueFormat = errors.New("registry value formatted incorrectly") + // GetDesktopPID searches the PID of the process that's running the // currently active desktop. Returns ErrNoShell if the shell is not present. // Usually the PID will be for explorer.exe. @@ -947,3 +953,22 @@ func IsDomainName(name string) (bool, error) { return isDomainName(name16) } + +// GUIPathFromReg obtains the path to the client GUI executable from the +// registry value that was written during installation. +func GUIPathFromReg() (string, error) { + regPath, err := GetRegString("GUIPath") + if err != nil { + return "", err + } + + if !filepath.IsAbs(regPath) { + return "", ErrBadRegValueFormat + } + + if _, err := os.Stat(regPath); err != nil { + return "", err + } + + return regPath, nil +} diff --git a/util/winutil/winutil_windows_test.go b/util/winutil/winutil_windows_test.go index d437ffa38..ead10a45d 100644 --- a/util/winutil/winutil_windows_test.go +++ b/util/winutil/winutil_windows_test.go @@ -68,8 +68,8 @@ func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, p if gotLen := int(ptLen); gotLen != expectedLen { t.Errorf("allocation length got %d, want %d", gotLen, expectedLen) } - if l := len(slcs); l != 1 { - t.Errorf("len(slcs) got %d, want 1", l) + if ln := len(slcs); ln != 1 { + t.Errorf("len(slcs) got %d, want 1", ln) } if len(extra) == 0 && slcs[0] != nil { t.Error("slcs[0] got non-nil, want nil") diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go index b4674dff3..56aedb4c7 100644 --- a/util/winutil/zsyscall_windows.go +++ b/util/winutil/zsyscall_windows.go @@ -62,7 +62,7 @@ var ( ) func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procQueryServiceConfig2W.Addr(), 5, uintptr(hService), uintptr(infoLevel), uintptr(unsafe.Pointer(buf)), uintptr(bufLen), uintptr(unsafe.Pointer(bytesNeeded)), 0) + r1, _, e1 := syscall.SyscallN(procQueryServiceConfig2W.Addr(), uintptr(hService), uintptr(infoLevel), uintptr(unsafe.Pointer(buf)), uintptr(bufLen), uintptr(unsafe.Pointer(bytesNeeded))) if r1 == 0 { err = errnoErr(e1) } @@ -70,19 +70,19 @@ func queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, b } func getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) { - r0, _, _ := syscall.Syscall6(procGetApplicationRestartSettings.Addr(), 4, uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags)), 0, 0) + r0, _, _ := syscall.SyscallN(procGetApplicationRestartSettings.Addr(), uintptr(process), uintptr(unsafe.Pointer(commandLine)), uintptr(unsafe.Pointer(commandLineLen)), uintptr(unsafe.Pointer(flags))) ret = wingoes.HRESULT(r0) return } func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) { - r0, _, _ := syscall.Syscall(procRegisterApplicationRestart.Addr(), 2, uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags), 0) + r0, _, _ := syscall.SyscallN(procRegisterApplicationRestart.Addr(), uintptr(unsafe.Pointer(cmdLineExclExeName)), uintptr(flags)) ret = wingoes.HRESULT(r0) return } func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) { - r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo))) + r0, _, _ := syscall.SyscallN(procDsGetDcNameW.Addr(), uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -90,7 +90,7 @@ func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.G } func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) { - r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0) + r0, _, _ := syscall.SyscallN(procNetValidateName.Addr(), uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType)) if r0 != 0 { ret = syscall.Errno(r0) } @@ -98,7 +98,7 @@ func netValidateName(server *uint16, name *uint16, account *uint16, password *ui } func rmEndSession(session _RMHANDLE) (ret error) { - r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0) + r0, _, _ := syscall.SyscallN(procRmEndSession.Addr(), uintptr(session)) if r0 != 0 { ret = syscall.Errno(r0) } @@ -106,7 +106,7 @@ func rmEndSession(session _RMHANDLE) (ret error) { } func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rgAffectedApps *_RM_PROCESS_INFO, pRebootReasons *uint32) (ret error) { - r0, _, _ := syscall.Syscall6(procRmGetList.Addr(), 5, uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons)), 0) + r0, _, _ := syscall.SyscallN(procRmGetList.Addr(), uintptr(session), uintptr(unsafe.Pointer(nProcInfoNeeded)), uintptr(unsafe.Pointer(nProcInfo)), uintptr(unsafe.Pointer(rgAffectedApps)), uintptr(unsafe.Pointer(pRebootReasons))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -114,7 +114,7 @@ func rmGetList(session _RMHANDLE, nProcInfoNeeded *uint32, nProcInfo *uint32, rg } func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) { - r0, _, _ := syscall.Syscall(procRmJoinSession.Addr(), 2, uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey)), 0) + r0, _, _ := syscall.SyscallN(procRmJoinSession.Addr(), uintptr(unsafe.Pointer(pSession)), uintptr(unsafe.Pointer(sessionKey))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -122,7 +122,7 @@ func rmJoinSession(pSession *_RMHANDLE, sessionKey *uint16) (ret error) { } func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16, nApplications uint32, rgApplications *_RM_UNIQUE_PROCESS, nServices uint32, rgsServiceNames **uint16) (ret error) { - r0, _, _ := syscall.Syscall9(procRmRegisterResources.Addr(), 7, uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames)), 0, 0) + r0, _, _ := syscall.SyscallN(procRmRegisterResources.Addr(), uintptr(session), uintptr(nFiles), uintptr(unsafe.Pointer(rgsFileNames)), uintptr(nApplications), uintptr(unsafe.Pointer(rgApplications)), uintptr(nServices), uintptr(unsafe.Pointer(rgsServiceNames))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -130,7 +130,7 @@ func rmRegisterResources(session _RMHANDLE, nFiles uint32, rgsFileNames **uint16 } func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret error) { - r0, _, _ := syscall.Syscall(procRmStartSession.Addr(), 3, uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey))) + r0, _, _ := syscall.SyscallN(procRmStartSession.Addr(), uintptr(unsafe.Pointer(pSession)), uintptr(flags), uintptr(unsafe.Pointer(sessionKey))) if r0 != 0 { ret = syscall.Errno(r0) } @@ -138,7 +138,7 @@ func rmStartSession(pSession *_RMHANDLE, flags uint32, sessionKey *uint16) (ret } func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procExpandEnvironmentStringsForUserW.Addr(), 4, uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen), 0, 0) + r1, _, e1 := syscall.SyscallN(procExpandEnvironmentStringsForUserW.Addr(), uintptr(token), uintptr(unsafe.Pointer(src)), uintptr(unsafe.Pointer(dst)), uintptr(dstLen)) if int32(r1) == 0 { err = errnoErr(e1) } @@ -146,7 +146,7 @@ func expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint } func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) { - r1, _, e1 := syscall.Syscall(procLoadUserProfileW.Addr(), 2, uintptr(token), uintptr(unsafe.Pointer(profileInfo)), 0) + r1, _, e1 := syscall.SyscallN(procLoadUserProfileW.Addr(), uintptr(token), uintptr(unsafe.Pointer(profileInfo))) if int32(r1) == 0 { err = errnoErr(e1) } @@ -154,7 +154,7 @@ func loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) } func unloadUserProfile(token windows.Token, profile registry.Key) (err error) { - r1, _, e1 := syscall.Syscall(procUnloadUserProfile.Addr(), 2, uintptr(token), uintptr(profile), 0) + r1, _, e1 := syscall.SyscallN(procUnloadUserProfile.Addr(), uintptr(token), uintptr(profile)) if int32(r1) == 0 { err = errnoErr(e1) } diff --git a/version-embed.go b/version-embed.go index 40c2e7cef..17bf578dd 100644 --- a/version-embed.go +++ b/version-embed.go @@ -4,7 +4,10 @@ // Package tailscaleroot embeds VERSION.txt into the binary. package tailscaleroot -import _ "embed" +import ( + _ "embed" + "runtime/debug" +) // VersionDotTxt is the contents of VERSION.txt. Despite the tempting filename, // this does not necessarily contain the accurate version number of the build, which @@ -22,3 +25,17 @@ var AlpineDockerTag string // //go:embed go.toolchain.rev var GoToolchainRev string + +//lint:ignore U1000 used by tests + assert_ts_toolchain_match.go w/ right build tags +func tailscaleToolchainRev() (gitHash string, ok bool) { + bi, ok := debug.ReadBuildInfo() + if !ok { + return "", false + } + for _, s := range bi.Settings { + if s.Key == "tailscale.toolchain.rev" { + return s.Value, true + } + } + return "", false +} diff --git a/version/cmdname.go b/version/cmdname.go index 51e065438..c38544ce1 100644 --- a/version/cmdname.go +++ b/version/cmdname.go @@ -12,7 +12,7 @@ import ( "io" "os" "path" - "path/filepath" + "runtime" "strings" ) @@ -30,7 +30,7 @@ func CmdName() string { func cmdName(exe string) string { // fallbackName, the lowercase basename of the executable, is what we return if // we can't find the Go module metadata embedded in the file. - fallbackName := filepath.Base(strings.TrimSuffix(strings.ToLower(exe), ".exe")) + fallbackName := prepExeNameForCmp(exe, runtime.GOARCH) var ret string info, err := findModuleInfo(exe) @@ -45,10 +45,10 @@ func cmdName(exe string) string { break } } - if strings.HasPrefix(ret, "wg") && fallbackName == "tailscale-ipn" { - // The tailscale-ipn.exe binary for internal build system packaging reasons - // has a path of "tailscale.io/win/wg64", "tailscale.io/win/wg32", etc. - // Ignore that name and use "tailscale-ipn" instead. + if runtime.GOOS == "windows" && strings.HasPrefix(ret, "gui") && checkPreppedExeNameForGUI(fallbackName) { + // The GUI binary for internal build system packaging reasons + // has a path of "tailscale.io/win/gui". + // Ignore that name and use fallbackName instead. return fallbackName } if ret == "" { diff --git a/version/distro/distro.go b/version/distro/distro.go index 8865a834b..0e88bdd2f 100644 --- a/version/distro/distro.go +++ b/version/distro/distro.go @@ -6,13 +6,13 @@ package distro import ( "bytes" - "io" "os" "runtime" "strconv" + "strings" "tailscale.com/types/lazy" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) type Distro string @@ -31,6 +31,8 @@ const ( WDMyCloud = Distro("wdmycloud") Unraid = Distro("unraid") Alpine = Distro("alpine") + UBNT = Distro("ubnt") // Ubiquiti Networks + JetKVM = Distro("jetkvm") ) var distro lazy.SyncValue[Distro] @@ -76,6 +78,12 @@ func linuxDistro() Distro { case have("/usr/local/bin/freenas-debug"): // TrueNAS Scale runs on debian return TrueNAS + case have("/usr/bin/ubnt-device-info"): + // UBNT runs on Debian-based systems. This MUST be checked before Debian. + // + // Currently supported product families: + // - UDM (UniFi Dream Machine, UDM-Pro) + return UBNT case have("/etc/debian_version"): return Debian case have("/etc/arch-release"): @@ -96,10 +104,20 @@ func linuxDistro() Distro { return Unraid case have("/etc/alpine-release"): return Alpine + case runtime.GOARCH == "arm" && isDeviceModel("JetKVM"): + return JetKVM } return "" } +func isDeviceModel(want string) bool { + if runtime.GOOS != "linux" { + return false + } + v, _ := os.ReadFile("/sys/firmware/devicetree/base/model") + return want == strings.Trim(string(v), "\x00\r\n\t ") +} + func freebsdDistro() Distro { switch { case have("/etc/pfSense-rc"): @@ -132,18 +150,19 @@ func DSMVersion() int { return v } // But when run from the command line, we have to read it from the file: - lineread.File("/etc/VERSION", func(line []byte) error { + for lr := range lineiter.File("/etc/VERSION") { + line, err := lr.Value() + if err != nil { + break // but otherwise ignore + } line = bytes.TrimSpace(line) if string(line) == `majorversion="7"` { - v = 7 - return io.EOF + return 7 } if string(line) == `majorversion="6"` { - v = 6 - return io.EOF + return 6 } - return nil - }) - return v + } + return 0 }) } diff --git a/version/exename.go b/version/exename.go new file mode 100644 index 000000000..d5047c203 --- /dev/null +++ b/version/exename.go @@ -0,0 +1,25 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +import ( + "path/filepath" + "strings" +) + +// prepExeNameForCmp strips any extension and arch suffix from exe, and +// lowercases it. +func prepExeNameForCmp(exe, arch string) string { + baseNoExt := strings.ToLower(strings.TrimSuffix(filepath.Base(exe), filepath.Ext(exe))) + archSuffix := "-" + arch + return strings.TrimSuffix(baseNoExt, archSuffix) +} + +func checkPreppedExeNameForGUI(preppedExeName string) bool { + return preppedExeName == "tailscale-ipn" || preppedExeName == "tailscale-gui" +} + +func isGUIExeName(exe, arch string) bool { + return checkPreppedExeNameForGUI(prepExeNameForCmp(exe, arch)) +} diff --git a/version/print.go b/version/print.go index 7d8554279..43ee2b559 100644 --- a/version/print.go +++ b/version/print.go @@ -7,11 +7,10 @@ import ( "fmt" "runtime" "strings" - - "tailscale.com/types/lazy" + "sync" ) -var stringLazy = lazy.SyncFunc(func() string { +var stringLazy = sync.OnceValue(func() string { var ret strings.Builder ret.WriteString(Short()) ret.WriteByte('\n') @@ -21,6 +20,7 @@ var stringLazy = lazy.SyncFunc(func() string { if gitCommit() != "" { fmt.Fprintf(&ret, " tailscale commit: %s%s\n", gitCommit(), dirtyString()) } + fmt.Fprintf(&ret, " long version: %s\n", Long()) if extraGitCommitStamp != "" { fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) } diff --git a/version/prop.go b/version/prop.go index fee76c65f..0d6a5c00d 100644 --- a/version/prop.go +++ b/version/prop.go @@ -9,6 +9,7 @@ import ( "runtime" "strconv" "strings" + "sync" "tailscale.com/tailcfg" "tailscale.com/types/lazy" @@ -61,26 +62,21 @@ func IsSandboxedMacOS() bool { // Tailscale for macOS, either the main GUI process (non-sandboxed) or the // system extension (sandboxed). func IsMacSys() bool { - return IsMacSysExt() || IsMacSysApp() + return IsMacSysExt() || IsMacSysGUI() } var isMacSysApp lazy.SyncValue[bool] -// IsMacSysApp reports whether this process is the main, non-sandboxed GUI process +// IsMacSysGUI reports whether this process is the main, non-sandboxed GUI process // that ships with the Standalone variant of Tailscale for macOS. -func IsMacSysApp() bool { +func IsMacSysGUI() bool { if runtime.GOOS != "darwin" { return false } return isMacSysApp.Get(func() bool { - exe, err := os.Executable() - if err != nil { - return false - } - // Check that this is the GUI binary, and it is not sandboxed. The GUI binary - // shipped in the App Store will always have the App Sandbox enabled. - return strings.HasSuffix(exe, "/Contents/MacOS/Tailscale") && !IsMacAppStore() + return strings.Contains(os.Getenv("HOME"), "/Containers/io.tailscale.ipn.macsys/") || + strings.Contains(os.Getenv("XPC_SERVICE_NAME"), "io.tailscale.ipn.macsys") }) } @@ -94,10 +90,6 @@ func IsMacSysExt() bool { return false } return isMacSysExt.Get(func() bool { - if strings.Contains(os.Getenv("HOME"), "/Containers/io.tailscale.ipn.macsys/") || - strings.Contains(os.Getenv("XPC_SERVICE_NAME"), "io.tailscale.ipn.macsys") { - return true - } exe, err := os.Executable() if err != nil { return false @@ -108,8 +100,8 @@ func IsMacSysExt() bool { var isMacAppStore lazy.SyncValue[bool] -// IsMacAppStore whether this binary is from the App Store version of Tailscale -// for macOS. +// IsMacAppStore returns whether this binary is from the App Store version of Tailscale +// for macOS. Returns true for both the network extension and the GUI app. func IsMacAppStore() bool { if runtime.GOOS != "darwin" { return false @@ -123,6 +115,25 @@ func IsMacAppStore() bool { }) } +var isMacAppStoreGUI lazy.SyncValue[bool] + +// IsMacAppStoreGUI reports whether this binary is the GUI app from the App Store +// version of Tailscale for macOS. +func IsMacAppStoreGUI() bool { + if runtime.GOOS != "darwin" { + return false + } + return isMacAppStoreGUI.Get(func() bool { + exe, err := os.Executable() + if err != nil { + return false + } + // Check that this is the GUI binary, and it is not sandboxed. The GUI binary + // shipped in the App Store will always have the App Sandbox enabled. + return strings.Contains(exe, "/Tailscale") && !IsMacSysGUI() + }) +} + var isAppleTV lazy.SyncValue[bool] // IsAppleTV reports whether this binary is part of the Tailscale network extension for tvOS. @@ -148,7 +159,9 @@ func IsWindowsGUI() bool { if err != nil { return false } - return strings.EqualFold(exe, "tailscale-ipn.exe") || strings.EqualFold(exe, "tailscale-ipn") + // It is okay to use GOARCH here because we're checking whether our + // _own_ process is the GUI. + return isGUIExeName(exe, runtime.GOARCH) }) } @@ -174,7 +187,7 @@ func IsUnstableBuild() bool { }) } -var isDev = lazy.SyncFunc(func() bool { +var isDev = sync.OnceValue(func() bool { return strings.Contains(Short(), "-dev") }) diff --git a/version/version.go b/version/version.go index 4b96d15ea..2add25689 100644 --- a/version/version.go +++ b/version/version.go @@ -7,7 +7,9 @@ package version import ( "fmt" "runtime/debug" + "strconv" "strings" + "sync" tailscaleroot "tailscale.com" "tailscale.com/types/lazy" @@ -116,7 +118,7 @@ func (i embeddedInfo) commitAbbrev() string { return i.commit } -var getEmbeddedInfo = lazy.SyncFunc(func() embeddedInfo { +var getEmbeddedInfo = sync.OnceValue(func() embeddedInfo { bi, ok := debug.ReadBuildInfo() if !ok { return embeddedInfo{} @@ -169,3 +171,42 @@ func majorMinorPatch() string { ret, _, _ := strings.Cut(Short(), "-") return ret } + +func isValidLongWithTwoRepos(v string) bool { + s := strings.Split(v, "-") + if len(s) != 3 { + return false + } + hexChunk := func(s string) bool { + if len(s) < 6 { + return false + } + for i := range len(s) { + b := s[i] + if (b < '0' || b > '9') && (b < 'a' || b > 'f') { + return false + } + } + return true + } + + v, t, g := s[0], s[1], s[2] + if !strings.HasPrefix(t, "t") || !strings.HasPrefix(g, "g") || + !hexChunk(t[1:]) || !hexChunk(g[1:]) { + return false + } + nums := strings.Split(v, ".") + if len(nums) != 3 { + return false + } + for i, n := range nums { + bits := 8 + if i == 2 { + bits = 16 + } + if _, err := strconv.ParseUint(n, 10, bits); err != nil { + return false + } + } + return true +} diff --git a/version/version_checkformat.go b/version/version_checkformat.go new file mode 100644 index 000000000..05a97d191 --- /dev/null +++ b/version/version_checkformat.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build tailscale_go && android + +package version + +import "fmt" + +func init() { + // For official Android builds using the tailscale_go toolchain, + // panic if the builder is screwed up and we fail to stamp a valid + // version string. + if !isValidLongWithTwoRepos(Long()) { + panic(fmt.Sprintf("malformed version.Long value %q", Long())) + } +} diff --git a/version/version_internal_test.go b/version/version_internal_test.go new file mode 100644 index 000000000..b3b848276 --- /dev/null +++ b/version/version_internal_test.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package version + +import "testing" + +func TestIsValidLongWithTwoRepos(t *testing.T) { + tests := []struct { + long string + want bool + }{ + {"1.2.3-t01234abcde-g01234abcde", true}, + {"1.2.259-t01234abcde-g01234abcde", true}, // big patch version + {"1.2.3-t01234abcde", false}, // missing repo + {"1.2.3-g01234abcde", false}, // missing repo + {"-t01234abcde-g01234abcde", false}, + {"1.2.3", false}, + {"1.2.3-t01234abcde-g", false}, + {"1.2.3-t01234abcde-gERRBUILDINFO", false}, + } + for _, tt := range tests { + if got := isValidLongWithTwoRepos(tt.long); got != tt.want { + t.Errorf("IsValidLongWithTwoRepos(%q) = %v; want %v", tt.long, got, tt.want) + } + } +} + +func TestPrepExeNameForCmp(t *testing.T) { + cases := []struct { + exe string + want string + }{ + { + "tailscale-ipn.exe", + "tailscale-ipn", + }, + { + "tailscale-gui-amd64.exe", + "tailscale-gui", + }, + { + "tailscale-gui-amd64", + "tailscale-gui", + }, + { + "tailscale-ipn", + "tailscale-ipn", + }, + { + "TaIlScAlE-iPn.ExE", + "tailscale-ipn", + }, + } + + for _, c := range cases { + got := prepExeNameForCmp(c.exe, "amd64") + if got != c.want { + t.Errorf("prepExeNameForCmp(%q) = %q; want %q", c.exe, got, c.want) + } + } +} diff --git a/version_tailscale_test.go b/version_tailscale_test.go index c15e0cbee..0a690e312 100644 --- a/version_tailscale_test.go +++ b/version_tailscale_test.go @@ -7,23 +7,15 @@ package tailscaleroot import ( "os" - "runtime/debug" "strings" "testing" ) func TestToolchainMatches(t *testing.T) { - bi, ok := debug.ReadBuildInfo() + tsRev, ok := tailscaleToolchainRev() if !ok { t.Fatal("failed to read build info") } - var tsRev string - for _, s := range bi.Settings { - if s.Key == "tailscale.toolchain.rev" { - tsRev = s.Value - break - } - } want := strings.TrimSpace(GoToolchainRev) if tsRev != want { if os.Getenv("TS_PERMIT_TOOLCHAIN_MISMATCH") == "1" { diff --git a/version_test.go b/version_test.go index 1f434e682..3d983a19d 100644 --- a/version_test.go +++ b/version_test.go @@ -6,21 +6,16 @@ package tailscaleroot import ( "fmt" "os" - "regexp" + "os/exec" + "runtime" "strings" "testing" + + "golang.org/x/mod/modfile" ) func TestDockerfileVersion(t *testing.T) { - goMod, err := os.ReadFile("go.mod") - if err != nil { - t.Fatal(err) - } - m := regexp.MustCompile(`(?m)^go (\d\.\d+)\r?($|\.)`).FindStringSubmatch(string(goMod)) - if m == nil { - t.Fatalf("didn't find go version in go.mod") - } - goVersion := m[1] + goVersion := mustGetGoModVersion(t, false) dockerFile, err := os.ReadFile("Dockerfile") if err != nil { @@ -31,3 +26,60 @@ func TestDockerfileVersion(t *testing.T) { t.Errorf("didn't find %q in Dockerfile", wantSub) } } + +// TestGoVersion tests that the Go version specified in go.mod matches ./tool/go version. +func TestGoVersion(t *testing.T) { + // We could special-case ./tool/go path for Windows, but really there is no + // need to run it there. + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows") + } + goModVersion := mustGetGoModVersion(t, true) + + goToolCmd := exec.Command("./tool/go", "version") + goToolOutput, err := goToolCmd.Output() + if err != nil { + t.Fatalf("Failed to get ./tool/go version: %v", err) + } + + // Version info will approximately look like 'go version go1.24.4 linux/amd64'. + parts := strings.Fields(string(goToolOutput)) + if len(parts) < 4 { + t.Fatalf("Unexpected ./tool/go version output format: %s", goToolOutput) + } + + goToolVersion := strings.TrimPrefix(parts[2], "go") + + if goModVersion != goToolVersion { + t.Errorf("Go version in go.mod (%q) does not match the version of ./tool/go (%q).\nEnsure that the go.mod refers to the same Go version as ./go.toolchain.rev.", + goModVersion, goToolVersion) + } +} + +func mustGetGoModVersion(t *testing.T, includePatchVersion bool) string { + t.Helper() + + goModBytes, err := os.ReadFile("go.mod") + if err != nil { + t.Fatal(err) + } + + modFile, err := modfile.Parse("go.mod", goModBytes, nil) + if err != nil { + t.Fatal(err) + } + + if modFile.Go == nil { + t.Fatal("no Go version found in go.mod") + } + + version := modFile.Go.Version + + parts := strings.Split(version, ".") + if !includePatchVersion { + if len(parts) >= 2 { + version = parts[0] + "." + parts[1] + } + } + return version +} diff --git a/wf/firewall.go b/wf/firewall.go index 076944c8d..07e160eb3 100644 --- a/wf/firewall.go +++ b/wf/firewall.go @@ -18,7 +18,7 @@ import ( // Known addresses. var ( - linkLocalRange = netip.MustParsePrefix("ff80::/10") + linkLocalRange = netip.MustParsePrefix("fe80::/10") linkLocalDHCPMulticast = netip.MustParseAddr("ff02::1:2") siteLocalDHCPMulticast = netip.MustParseAddr("ff05::1:3") linkLocalRouterMulticast = netip.MustParseAddr("ff02::2") @@ -66,8 +66,8 @@ func (p protocol) getLayers(d direction) []wf.LayerID { return layers } -func ruleName(action wf.Action, l wf.LayerID, name string) string { - switch l { +func ruleName(action wf.Action, layerID wf.LayerID, name string) string { + switch layerID { case wf.LayerALEAuthConnectV4: return fmt.Sprintf("%s outbound %s (IPv4)", action, name) case wf.LayerALEAuthConnectV6: @@ -307,8 +307,8 @@ func (f *Firewall) newRule(name string, w weight, layer wf.LayerID, conditions [ func (f *Firewall) addRules(name string, w weight, conditions []*wf.Match, action wf.Action, p protocol, d direction) ([]*wf.Rule, error) { var rules []*wf.Rule - for _, l := range p.getLayers(d) { - r, err := f.newRule(name, w, l, conditions, action) + for _, layer := range p.getLayers(d) { + r, err := f.newRule(name, w, layer, conditions, action) if err != nil { return nil, err } diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go index 45823dd56..ce6add866 100644 --- a/wgengine/bench/wg.go +++ b/wgengine/bench/wg.go @@ -38,7 +38,6 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. k1 := key.NewNode() c1 := wgcfg.Config{ - Name: "e1", PrivateKey: k1, Addresses: []netip.Prefix{a1}, } @@ -46,14 +45,14 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. logf: logger.WithPrefix(logf, "tun1: "), traf: traf, } - s1 := new(tsd.System) + s1 := tsd.NewSystem() e1, err := wgengine.NewUserspaceEngine(l1, wgengine.Config{ Router: router.NewFake(l1), NetMon: nil, ListenPort: 0, Tun: t1, SetSubsystem: s1.Set, - HealthTracker: s1.HealthTracker(), + HealthTracker: s1.HealthTracker.Get(), }) if err != nil { log.Fatalf("e1 init: %v", err) @@ -65,7 +64,6 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. l2 := logger.WithPrefix(logf, "e2: ") k2 := key.NewNode() c2 := wgcfg.Config{ - Name: "e2", PrivateKey: k2, Addresses: []netip.Prefix{a2}, } @@ -73,14 +71,14 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. logf: logger.WithPrefix(logf, "tun2: "), traf: traf, } - s2 := new(tsd.System) + s2 := tsd.NewSystem() e2, err := wgengine.NewUserspaceEngine(l2, wgengine.Config{ Router: router.NewFake(l2), NetMon: nil, ListenPort: 0, Tun: t2, SetSubsystem: s2.Set, - HealthTracker: s2.HealthTracker(), + HealthTracker: s2.HealthTracker.Get(), }) if err != nil { log.Fatalf("e2 init: %v", err) @@ -113,9 +111,8 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. Endpoints: epFromTyped(st.LocalAddrs), } e2.SetNetworkMap(&netmap.NetworkMap{ - NodeKey: k2.Public(), - PrivateKey: k2, - Peers: []tailcfg.NodeView{n.View()}, + NodeKey: k2.Public(), + Peers: []tailcfg.NodeView{n.View()}, }) p := wgcfg.Peer{ @@ -145,9 +142,8 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. Endpoints: epFromTyped(st.LocalAddrs), } e1.SetNetworkMap(&netmap.NetworkMap{ - NodeKey: k1.Public(), - PrivateKey: k1, - Peers: []tailcfg.NodeView{n.View()}, + NodeKey: k1.Public(), + Peers: []tailcfg.NodeView{n.View()}, }) p := wgcfg.Peer{ diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 56224ac5d..987fcee01 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -24,6 +24,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/mak" "tailscale.com/util/slicesx" + "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter/filtertype" ) @@ -202,16 +203,17 @@ func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, } f := &Filter{ - logf: logf, - matches4: matchesFamily(matches, netip.Addr.Is4), - matches6: matchesFamily(matches, netip.Addr.Is6), - cap4: capMatchesFunc(matches, netip.Addr.Is4), - cap6: capMatchesFunc(matches, netip.Addr.Is6), - local4: ipset.FalseContainsIPFunc(), - local6: ipset.FalseContainsIPFunc(), - logIPs4: ipset.FalseContainsIPFunc(), - logIPs6: ipset.FalseContainsIPFunc(), - state: state, + logf: logf, + matches4: matchesFamily(matches, netip.Addr.Is4), + matches6: matchesFamily(matches, netip.Addr.Is6), + cap4: capMatchesFunc(matches, netip.Addr.Is4), + cap6: capMatchesFunc(matches, netip.Addr.Is6), + local4: ipset.FalseContainsIPFunc(), + local6: ipset.FalseContainsIPFunc(), + logIPs4: ipset.FalseContainsIPFunc(), + logIPs6: ipset.FalseContainsIPFunc(), + state: state, + srcIPHasCap: capTest, } if localNets != nil { p := localNets.Prefixes() @@ -409,7 +411,7 @@ func (f *Filter) ShieldsUp() bool { return f.shieldsUp } // Tailscale peer. func (f *Filter) RunIn(q *packet.Parsed, rf RunFlags) Response { dir := in - r := f.pre(q, rf, dir) + r, _ := f.pre(q, rf, dir) if r == Accept || r == Drop { // already logged return r @@ -430,16 +432,16 @@ func (f *Filter) RunIn(q *packet.Parsed, rf RunFlags) Response { // RunOut determines whether this node is allowed to send q to a // Tailscale peer. -func (f *Filter) RunOut(q *packet.Parsed, rf RunFlags) Response { +func (f *Filter) RunOut(q *packet.Parsed, rf RunFlags) (Response, usermetric.DropReason) { dir := out - r := f.pre(q, rf, dir) + r, reason := f.pre(q, rf, dir) if r == Accept || r == Drop { // already logged - return r + return r, reason } r, why := f.runOut(q) f.logRateLimit(rf, q, dir, r, why) - return r + return r, "" } var unknownProtoStringCache sync.Map // ipproto.Proto -> string @@ -609,33 +611,38 @@ var gcpDNSAddr = netaddr.IPv4(169, 254, 169, 254) // pre runs the direction-agnostic filter logic. dir is only used for // logging. -func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) Response { +func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) (Response, usermetric.DropReason) { if len(q.Buffer()) == 0 { // wireguard keepalive packet, always permit. - return Accept + return Accept, "" } if len(q.Buffer()) < 20 { f.logRateLimit(rf, q, dir, Drop, "too short") - return Drop + return Drop, usermetric.ReasonTooShort + } + + if q.IPProto == ipproto.Unknown { + f.logRateLimit(rf, q, dir, Drop, "unknown proto") + return Drop, usermetric.ReasonUnknownProtocol } if q.Dst.Addr().IsMulticast() { f.logRateLimit(rf, q, dir, Drop, "multicast") - return Drop + return Drop, usermetric.ReasonMulticast } if q.Dst.Addr().IsLinkLocalUnicast() && q.Dst.Addr() != gcpDNSAddr { f.logRateLimit(rf, q, dir, Drop, "link-local-unicast") - return Drop + return Drop, usermetric.ReasonLinkLocalUnicast } if q.IPProto == ipproto.Fragment { // Fragments after the first always need to be passed through. // Very small fragments are considered Junk by Parsed. f.logRateLimit(rf, q, dir, Accept, "fragment") - return Accept + return Accept, "" } - return noVerdict + return noVerdict, "" } // loggingAllowed reports whether p can appear in logs at all. diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go index f2796d71f..ae39eeb08 100644 --- a/wgengine/filter/filter_test.go +++ b/wgengine/filter/filter_test.go @@ -18,7 +18,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "go4.org/netipx" - xmaps "golang.org/x/exp/maps" "tailscale.com/net/flowtrack" "tailscale.com/net/ipset" "tailscale.com/net/packet" @@ -30,6 +29,8 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/views" "tailscale.com/util/must" + "tailscale.com/util/slicesx" + "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter/filtertype" ) @@ -211,7 +212,7 @@ func TestUDPState(t *testing.T) { t.Fatalf("incoming initial packet not dropped, got=%v: %v", got, a4) } // We talk to that peer - if got := acl.RunOut(&b4, flags); got != Accept { + if got, _ := acl.RunOut(&b4, flags); got != Accept { t.Fatalf("outbound packet didn't egress, got=%v: %v", got, b4) } // Now, the same packet as before is allowed back. @@ -227,7 +228,7 @@ func TestUDPState(t *testing.T) { t.Fatalf("incoming initial packet not dropped: %v", a4) } // We talk to that peer - if got := acl.RunOut(&b6, flags); got != Accept { + if got, _ := acl.RunOut(&b6, flags); got != Accept { t.Fatalf("outbound packet didn't egress: %v", b4) } // Now, the same packet as before is allowed back. @@ -382,25 +383,27 @@ func BenchmarkFilter(b *testing.B) { func TestPreFilter(t *testing.T) { packets := []struct { - desc string - want Response - b []byte + desc string + want Response + wantReason usermetric.DropReason + b []byte }{ - {"empty", Accept, []byte{}}, - {"short", Drop, []byte("short")}, - {"junk", Drop, raw4default(ipproto.Unknown, 10)}, - {"fragment", Accept, raw4default(ipproto.Fragment, 40)}, - {"tcp", noVerdict, raw4default(ipproto.TCP, 0)}, - {"udp", noVerdict, raw4default(ipproto.UDP, 0)}, - {"icmp", noVerdict, raw4default(ipproto.ICMPv4, 0)}, + {"empty", Accept, "", []byte{}}, + {"short", Drop, usermetric.ReasonTooShort, []byte("short")}, + {"short-junk", Drop, usermetric.ReasonTooShort, raw4default(ipproto.Unknown, 10)}, + {"long-junk", Drop, usermetric.ReasonUnknownProtocol, raw4default(ipproto.Unknown, 21)}, + {"fragment", Accept, "", raw4default(ipproto.Fragment, 40)}, + {"tcp", noVerdict, "", raw4default(ipproto.TCP, 0)}, + {"udp", noVerdict, "", raw4default(ipproto.UDP, 0)}, + {"icmp", noVerdict, "", raw4default(ipproto.ICMPv4, 0)}, } f := NewAllowNone(t.Logf, &netipx.IPSet{}) for _, testPacket := range packets { p := &packet.Parsed{} p.Decode(testPacket.b) - got := f.pre(p, LogDrops|LogAccepts, in) - if got != testPacket.want { - t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b)) + got, gotReason := f.pre(p, LogDrops|LogAccepts, in) + if got != testPacket.want || gotReason != testPacket.wantReason { + t.Errorf("%q got=%v want=%v gotReason=%s wantReason=%s packet:\n%s", testPacket.desc, got, testPacket.want, gotReason, testPacket.wantReason, packet.Hexdump(testPacket.b)) } } } @@ -768,7 +771,7 @@ func ports(s string) PortRange { if err != nil { panic(fmt.Sprintf("invalid NetPortRange %q", s)) } - return PortRange{uint16(first), uint16(last)} + return PortRange{First: uint16(first), Last: uint16(last)} } func netports(netPorts ...string) (ret []NetPortRange) { @@ -814,11 +817,11 @@ func TestMatchesFromFilterRules(t *testing.T) { Dsts: []NetPortRange{ { Net: netip.MustParsePrefix("0.0.0.0/0"), - Ports: PortRange{22, 22}, + Ports: PortRange{First: 22, Last: 22}, }, { Net: netip.MustParsePrefix("::0/0"), - Ports: PortRange{22, 22}, + Ports: PortRange{First: 22, Last: 22}, }, }, Srcs: []netip.Prefix{ @@ -848,7 +851,7 @@ func TestMatchesFromFilterRules(t *testing.T) { Dsts: []NetPortRange{ { Net: netip.MustParsePrefix("1.2.0.0/16"), - Ports: PortRange{22, 22}, + Ports: PortRange{First: 22, Last: 22}, }, }, Srcs: []netip.Prefix{ @@ -997,7 +1000,7 @@ func TestPeerCaps(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := xmaps.Keys(filt.CapsWithValues(netip.MustParseAddr(tt.src), netip.MustParseAddr(tt.dst))) + got := slicesx.MapKeys(filt.CapsWithValues(netip.MustParseAddr(tt.src), netip.MustParseAddr(tt.dst))) slices.Sort(got) slices.Sort(tt.want) if !slices.Equal(got, tt.want) { diff --git a/wgengine/magicsock/batching_conn.go b/wgengine/magicsock/batching_conn.go deleted file mode 100644 index 5320d1caf..000000000 --- a/wgengine/magicsock/batching_conn.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package magicsock - -import ( - "net/netip" - - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" - "tailscale.com/types/nettype" -) - -var ( - // This acts as a compile-time check for our usage of ipv6.Message in - // batchingConn for both IPv6 and IPv4 operations. - _ ipv6.Message = ipv4.Message{} -) - -// batchingConn is a nettype.PacketConn that provides batched i/o. -type batchingConn interface { - nettype.PacketConn - ReadBatch(msgs []ipv6.Message, flags int) (n int, err error) - WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error -} diff --git a/wgengine/magicsock/batching_conn_default.go b/wgengine/magicsock/batching_conn_default.go deleted file mode 100644 index 519cf8082..000000000 --- a/wgengine/magicsock/batching_conn_default.go +++ /dev/null @@ -1,14 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux - -package magicsock - -import ( - "tailscale.com/types/nettype" -) - -func tryUpgradeToBatchingConn(pconn nettype.PacketConn, _ string, _ int) nettype.PacketConn { - return pconn -} diff --git a/wgengine/magicsock/blockforever_conn.go b/wgengine/magicsock/blockforever_conn.go index f2e85dcd5..272a12513 100644 --- a/wgengine/magicsock/blockforever_conn.go +++ b/wgengine/magicsock/blockforever_conn.go @@ -10,11 +10,13 @@ import ( "sync" "syscall" "time" + + "tailscale.com/syncs" ) // blockForeverConn is a net.PacketConn whose reads block until it is closed. type blockForeverConn struct { - mu sync.Mutex + mu syncs.Mutex cond *sync.Cond closed bool } diff --git a/wgengine/magicsock/cloudinfo.go b/wgengine/magicsock/cloudinfo.go index 1de369631..0db56b3f6 100644 --- a/wgengine/magicsock/cloudinfo.go +++ b/wgengine/magicsock/cloudinfo.go @@ -17,6 +17,7 @@ import ( "strings" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/types/logger" "tailscale.com/util/cloudenv" ) @@ -34,6 +35,9 @@ type cloudInfo struct { } func newCloudInfo(logf logger.Logf) *cloudInfo { + if !buildfeatures.HasCloud { + return nil + } tr := &http.Transport{ DisableKeepAlives: true, Dial: (&net.Dialer{ @@ -53,6 +57,9 @@ func newCloudInfo(logf logger.Logf) *cloudInfo { // if the tailscaled process is running in a known cloud and there are any such // IPs present. func (ci *cloudInfo) GetPublicIPs(ctx context.Context) ([]netip.Addr, error) { + if !buildfeatures.HasCloud { + return nil, nil + } switch ci.cloud { case cloudenv.AWS: ret, err := ci.getAWS(ctx) diff --git a/wgengine/magicsock/debughttp.go b/wgengine/magicsock/debughttp.go index 6c07b0d5e..9aecab74b 100644 --- a/wgengine/magicsock/debughttp.go +++ b/wgengine/magicsock/debughttp.go @@ -13,6 +13,8 @@ import ( "strings" "time" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/key" @@ -24,6 +26,11 @@ import ( // /debug/magicsock) or via peerapi to a peer that's owned by the same // user (so they can e.g. inspect their phones). func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + c.mu.Lock() defer c.mu.Unlock() @@ -72,18 +79,18 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "

        # ip:port to endpoint

          ") { type kv struct { - ipp netip.AddrPort - pi *peerInfo + addr epAddr + pi *peerInfo } - ent := make([]kv, 0, len(c.peerMap.byIPPort)) - for k, v := range c.peerMap.byIPPort { + ent := make([]kv, 0, len(c.peerMap.byEpAddr)) + for k, v := range c.peerMap.byEpAddr { ent = append(ent, kv{k, v}) } - sort.Slice(ent, func(i, j int) bool { return ipPortLess(ent[i].ipp, ent[j].ipp) }) + sort.Slice(ent, func(i, j int) bool { return epAddrLess(ent[i].addr, ent[j].addr) }) for _, e := range ent { ep := e.pi.ep shortStr := ep.publicKey.ShortString() - fmt.Fprintf(w, "
        • %v: %v
        • \n", e.ipp, strings.Trim(shortStr, "[]"), shortStr) + fmt.Fprintf(w, "
        • %v: %v
        • \n", e.addr, strings.Trim(shortStr, "[]"), shortStr) } } @@ -102,8 +109,7 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { sort.Slice(ent, func(i, j int) bool { return ent[i].pub.Less(ent[j].pub) }) peers := map[key.NodePublic]tailcfg.NodeView{} - for i := range c.peers.Len() { - p := c.peers.At(i) + for _, p := range c.peers.All() { peers[p.Key()] = p } @@ -149,11 +155,11 @@ func printEndpointHTML(w io.Writer, ep *endpoint) { for ipp := range ep.endpointState { eps = append(eps, ipp) } - sort.Slice(eps, func(i, j int) bool { return ipPortLess(eps[i], eps[j]) }) + sort.Slice(eps, func(i, j int) bool { return addrPortLess(eps[i], eps[j]) }) io.WriteString(w, "

          Endpoints:

            ") for _, ipp := range eps { s := ep.endpointState[ipp] - if ipp == ep.bestAddr.AddrPort { + if ipp == ep.bestAddr.ap && !ep.bestAddr.vni.IsSet() { fmt.Fprintf(w, "
          • %s: (best)
              ", ipp) } else { fmt.Fprintf(w, "
            • %s: ...
                ", ipp) @@ -197,9 +203,19 @@ func peerDebugName(p tailcfg.NodeView) string { return p.Hostinfo().Hostname() } -func ipPortLess(a, b netip.AddrPort) bool { +func addrPortLess(a, b netip.AddrPort) bool { if v := a.Addr().Compare(b.Addr()); v != 0 { return v < 0 } return a.Port() < b.Port() } + +func epAddrLess(a, b epAddr) bool { + if v := a.ap.Addr().Compare(b.ap.Addr()); v != 0 { + return v < 0 + } + if a.ap.Port() == b.ap.Port() { + return a.vni.Get() < b.vni.Get() + } + return a.ap.Port() < b.ap.Port() +} diff --git a/wgengine/magicsock/debugknobs.go b/wgengine/magicsock/debugknobs.go index f8fd9f040..b0a47ff87 100644 --- a/wgengine/magicsock/debugknobs.go +++ b/wgengine/magicsock/debugknobs.go @@ -62,6 +62,9 @@ var ( // //lint:ignore U1000 used on Linux/Darwin only debugPMTUD = envknob.RegisterBool("TS_DEBUG_PMTUD") + // debugNeverDirectUDP disables the use of direct UDP connections, forcing + // all peer communication over DERP or peer relay. + debugNeverDirectUDP = envknob.RegisterBool("TS_DEBUG_NEVER_DIRECT_UDP") // Hey you! Adding a new debugknob? Make sure to stub it out in the // debugknobs_stubs.go file too. ) diff --git a/wgengine/magicsock/debugknobs_stubs.go b/wgengine/magicsock/debugknobs_stubs.go index 336d7baa1..7dee1d6b0 100644 --- a/wgengine/magicsock/debugknobs_stubs.go +++ b/wgengine/magicsock/debugknobs_stubs.go @@ -31,3 +31,4 @@ func debugRingBufferMaxSizeBytes() int { return 0 } func inTest() bool { return false } func debugPeerMap() bool { return false } func pretendpoints() []netip.AddrPort { return []netip.AddrPort{} } +func debugNeverDirectUDP() bool { return false } diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 69c5cbc90..37a4f1a64 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -11,9 +11,7 @@ import ( "net" "net/netip" "reflect" - "runtime" "slices" - "sync" "time" "unsafe" @@ -21,7 +19,6 @@ import ( "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/health" - "tailscale.com/logtail/backoff" "tailscale.com/net/dnscache" "tailscale.com/net/netcheck" "tailscale.com/net/tsaddr" @@ -30,9 +27,9 @@ import ( "tailscale.com/tstime/mono" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/backoff" "tailscale.com/util/mak" "tailscale.com/util/rands" - "tailscale.com/util/sysresources" "tailscale.com/util/testenv" ) @@ -64,17 +61,37 @@ func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, regionID int, dc *derpht // addDerpPeerRoute adds a DERP route entry, noting that peer was seen // on DERP node derpID, at least on the connection identified by dc. // See issue 150 for details. -func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { +func (c *Conn) addDerpPeerRoute(peer key.NodePublic, regionID int, dc *derphttp.Client) { c.mu.Lock() defer c.mu.Unlock() - mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc}) + mak.Set(&c.derpRoute, peer, derpRoute{regionID, dc}) +} + +// fallbackDERPRegionForPeer returns the DERP region ID we might be able to use +// to contact peer, learned from observing recent DERP traffic from them. +// +// This is used as a fallback when a peer receives a packet from a peer +// over DERP but doesn't known that peer's home DERP or any UDP endpoints. +// This is particularly useful for large one-way nodes (such as hello.ts.net) +// that don't actively reach out to other nodes, so don't need to be told +// the DERP home of peers. They can instead learn the DERP home upon getting the +// first connection. +// +// This can also help nodes from a slow or misbehaving control plane. +func (c *Conn) fallbackDERPRegionForPeer(peer key.NodePublic) (regionID int) { + c.mu.Lock() + defer c.mu.Unlock() + if dr, ok := c.derpRoute[peer]; ok { + return dr.regionID + } + return 0 } // activeDerp contains fields for an active DERP connection. type activeDerp struct { c *derphttp.Client cancel context.CancelFunc - writeCh chan<- derpWriteRequest + writeCh chan derpWriteRequest // lastWrite is the time of the last request for its write // channel (currently even if there was no write). // It is always non-nil and initialized to a non-zero Time. @@ -158,10 +175,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) } else { connectedToControl = c.health.GetInPollNetMap() } + c.mu.Lock() + myDerp := c.myDerp + c.mu.Unlock() if !connectedToControl { - c.mu.Lock() - myDerp := c.myDerp - c.mu.Unlock() if myDerp != 0 { metricDERPHomeNoChangeNoControl.Add(1) return myDerp @@ -178,6 +195,11 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) // one. preferredDERP = c.pickDERPFallback() } + if preferredDERP != myDerp { + c.logf( + "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]", + c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds()) + } if !c.setNearestDERP(preferredDERP) { preferredDERP = 0 } @@ -257,59 +279,20 @@ func (c *Conn) goDerpConnect(regionID int) { go c.derpWriteChanForRegion(regionID, key.NodePublic{}) } -var ( - bufferedDerpWrites int - bufferedDerpWritesOnce sync.Once -) - -// bufferedDerpWritesBeforeDrop returns how many packets writes can be queued -// up the DERP client to write on the wire before we start dropping. -func bufferedDerpWritesBeforeDrop() int { - // For mobile devices, always return the previous minimum value of 32; - // we can do this outside the sync.Once to avoid that overhead. - if runtime.GOOS == "ios" || runtime.GOOS == "android" { - return 32 - } - - bufferedDerpWritesOnce.Do(func() { - // Some rough sizing: for the previous fixed value of 32, the - // total consumed memory can be: - // = numDerpRegions * messages/region * sizeof(message) - // - // For sake of this calculation, assume 100 DERP regions; at - // time of writing (2023-04-03), we have 24. - // - // A reasonable upper bound for the worst-case average size of - // a message is a *disco.CallMeMaybe message with 16 endpoints; - // since sizeof(netip.AddrPort) = 32, that's 512 bytes. Thus: - // = 100 * 32 * 512 - // = 1638400 (1.6MiB) - // - // On a reasonably-small node with 4GiB of memory that's - // connected to each region and handling a lot of load, 1.6MiB - // is about 0.04% of the total system memory. - // - // For sake of this calculation, then, let's double that memory - // usage to 0.08% and scale based on total system memory. - // - // For a 16GiB Linux box, this should buffer just over 256 - // messages. - systemMemory := sysresources.TotalMemory() - memoryUsable := float64(systemMemory) * 0.0008 - - const ( - theoreticalDERPRegions = 100 - messageMaximumSizeBytes = 512 - ) - bufferedDerpWrites = int(memoryUsable / (theoreticalDERPRegions * messageMaximumSizeBytes)) - - // Never drop below the previous minimum value. - if bufferedDerpWrites < 32 { - bufferedDerpWrites = 32 - } - }) - return bufferedDerpWrites -} +// derpWriteQueueDepth is the depth of the in-process write queue to a single +// DERP region. DERP connections are TCP, and so the actual write queue depth is +// substantially larger than this suggests - often scaling into megabytes +// depending on dynamic TCP parameters and platform TCP tuning. This queue is +// excess of the TCP buffer depth, which means it's almost pure buffer bloat, +// and does not want to be deep - if there are key situations where a node can't +// keep up, either the TCP link to DERP is too slow, or there is a +// synchronization issue in the write path, fixes should be focused on those +// paths, rather than extending this queue. +// TODO(raggi): make this even shorter, ideally this should be a fairly direct +// line into a socket TCP buffer. The challenge at present is that connect and +// reconnect are in the write path and we don't want to block other write +// operations on those. +const derpWriteQueueDepth = 32 // derpWriteChanForRegion returns a channel to which to send DERP packet write // requests. It creates a new DERP connection to regionID if necessary. @@ -319,7 +302,7 @@ func bufferedDerpWritesBeforeDrop() int { // // It returns nil if the network is down, the Conn is closed, or the regionID is // not known. -func (c *Conn) derpWriteChanForRegion(regionID int, peer key.NodePublic) chan<- derpWriteRequest { +func (c *Conn) derpWriteChanForRegion(regionID int, peer key.NodePublic) chan derpWriteRequest { if c.networkDown() { return nil } @@ -404,7 +387,7 @@ func (c *Conn) derpWriteChanForRegion(regionID int, peer key.NodePublic) chan<- dc.DNSCache = dnscache.Get() ctx, cancel := context.WithCancel(c.connCtx) - ch := make(chan derpWriteRequest, bufferedDerpWritesBeforeDrop()) + ch := make(chan derpWriteRequest, derpWriteQueueDepth) ad.c = dc ad.writeCh = ch @@ -627,7 +610,7 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli // Do nothing. case derp.PeerGoneReasonNotHere: metricRecvDiscoDERPPeerNotHere.Add(1) - c.logf("[unexpected] magicsock: derp-%d does not know about peer %s, removing route", + c.logf("magicsock: derp-%d does not know about peer %s, removing route", regionID, key.NodePublic(m.Peer).ShortString()) default: metricRecvDiscoDERPPeerGoneUnknown.Add(1) @@ -644,9 +627,10 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli } type derpWriteRequest struct { - addr netip.AddrPort - pubKey key.NodePublic - b []byte // copied; ownership passed to receiver + addr netip.AddrPort + pubKey key.NodePublic + b []byte // copied; ownership passed to receiver + isDisco bool } // runDerpWriter runs in a goroutine for the life of a DERP @@ -668,8 +652,12 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan if err != nil { c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) metricSendDERPError.Add(1) - } else { - metricSendDERP.Add(1) + if !wr.isDisco { + c.metrics.outboundPacketsDroppedErrors.Add(1) + } + } else if !wr.isDisco { + c.metrics.outboundPacketsDERPTotal.Add(1) + c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b))) } } } @@ -690,7 +678,6 @@ func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) // No data read occurred. Wait for another packet. continue } - metricRecvDataDERP.Add(1) sizes[0] = n eps[0] = ep return 1, nil @@ -711,8 +698,11 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en return 0, nil } - ipp := netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) - if c.handleDiscoMessage(b[:n], ipp, dm.src, discoRXPathDERP) { + srcAddr := epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID))} + pt, isGeneveEncap := packetLooksLike(b[:n]) + if pt == packetLooksLikeDisco && + !isGeneveEncap { // We should never receive Geneve-encapsulated disco over DERP. + c.handleDiscoMessage(b[:n], srcAddr, false, dm.src, discoRXPathDERP) return 0, nil } @@ -726,10 +716,13 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en return 0, nil } - ep.noteRecvActivity(ipp, mono.Now()) - if stats := c.stats.Load(); stats != nil { - stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n) + ep.noteRecvActivity(srcAddr, mono.Now()) + if update := c.connCounter.Load(); update != nil { + update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true) } + + c.metrics.inboundPacketsDERPTotal.Add(1) + c.metrics.inboundBytesDERPTotal.Add(int64(n)) return n, ep } @@ -843,7 +836,6 @@ func (c *Conn) maybeCloseDERPsOnRebind(okayLocalIPs []netip.Prefix) { c.closeOrReconnectDERPLocked(regionID, "rebind-default-route-change") continue } - regionID := regionID dc := ad.c go func() { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) diff --git a/wgengine/magicsock/disco_atomic.go b/wgengine/magicsock/disco_atomic.go new file mode 100644 index 000000000..5b765fbc2 --- /dev/null +++ b/wgengine/magicsock/disco_atomic.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "sync/atomic" + + "tailscale.com/types/key" +) + +type discoKeyPair struct { + private key.DiscoPrivate + public key.DiscoPublic + short string // public.ShortString() +} + +// discoAtomic is an atomic container for a disco private key, public key, and +// the public key's ShortString. The private and public keys are always kept +// synchronized. +// +// The zero value is not ready for use. Use [Set] to provide a usable value. +type discoAtomic struct { + pair atomic.Pointer[discoKeyPair] +} + +// Pair returns the private and public keys together atomically. +// Code that needs both the private and public keys synchronized should +// use Pair instead of calling Private and Public separately. +func (dk *discoAtomic) Pair() (key.DiscoPrivate, key.DiscoPublic) { + p := dk.pair.Load() + return p.private, p.public +} + +// Private returns the private key. +func (dk *discoAtomic) Private() key.DiscoPrivate { + return dk.pair.Load().private +} + +// Public returns the public key. +func (dk *discoAtomic) Public() key.DiscoPublic { + return dk.pair.Load().public +} + +// Short returns the short string of the public key (see [DiscoPublic.ShortString]). +func (dk *discoAtomic) Short() string { + return dk.pair.Load().short +} + +// Set updates the private key (and the cached public key and short string). +func (dk *discoAtomic) Set(private key.DiscoPrivate) { + public := private.Public() + dk.pair.Store(&discoKeyPair{ + private: private, + public: public, + short: public.ShortString(), + }) +} diff --git a/wgengine/magicsock/disco_atomic_test.go b/wgengine/magicsock/disco_atomic_test.go new file mode 100644 index 000000000..a1de9b843 --- /dev/null +++ b/wgengine/magicsock/disco_atomic_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "testing" + + "tailscale.com/types/key" +) + +func TestDiscoAtomic(t *testing.T) { + var dk discoAtomic + dk.Set(key.NewDisco()) + + private := dk.Private() + public := dk.Public() + short := dk.Short() + + if private.IsZero() { + t.Fatal("DiscoKey private key should not be zero") + } + if public.IsZero() { + t.Fatal("DiscoKey public key should not be zero") + } + if short == "" { + t.Fatal("DiscoKey short string should not be empty") + } + + if public != private.Public() { + t.Fatal("DiscoKey public key doesn't match private key") + } + if short != public.ShortString() { + t.Fatal("DiscoKey short string doesn't match public key") + } + + gotPrivate, gotPublic := dk.Pair() + if !gotPrivate.Equal(private) { + t.Fatal("Pair() returned different private key") + } + if gotPublic != public { + t.Fatal("Pair() returned different public key") + } +} + +func TestDiscoAtomicSet(t *testing.T) { + var dk discoAtomic + dk.Set(key.NewDisco()) + oldPrivate := dk.Private() + oldPublic := dk.Public() + + newPrivate := key.NewDisco() + dk.Set(newPrivate) + + currentPrivate := dk.Private() + currentPublic := dk.Public() + + if currentPrivate.Equal(oldPrivate) { + t.Fatal("DiscoKey private key should have changed after Set") + } + if currentPublic == oldPublic { + t.Fatal("DiscoKey public key should have changed after Set") + } + if !currentPrivate.Equal(newPrivate) { + t.Fatal("DiscoKey private key doesn't match the set key") + } + if currentPublic != newPrivate.Public() { + t.Fatal("DiscoKey public key doesn't match derived from set private key") + } +} diff --git a/wgengine/magicsock/discopingpurpose_string.go b/wgengine/magicsock/discopingpurpose_string.go index 3dc327de1..8eebf97a2 100644 --- a/wgengine/magicsock/discopingpurpose_string.go +++ b/wgengine/magicsock/discopingpurpose_string.go @@ -22,8 +22,9 @@ const _discoPingPurpose_name = "DiscoveryHeartbeatCLIHeartbeatForUDPLifetime" var _discoPingPurpose_index = [...]uint8{0, 9, 18, 21, 44} func (i discoPingPurpose) String() string { - if i < 0 || i >= discoPingPurpose(len(_discoPingPurpose_index)-1) { + idx := int(i) - 0 + if i < 0 || idx >= len(_discoPingPurpose_index)-1 { return "discoPingPurpose(" + strconv.FormatInt(int64(i), 10) + ")" } - return _discoPingPurpose_name[_discoPingPurpose_index[i]:_discoPingPurpose_index[i+1]] + return _discoPingPurpose_name[_discoPingPurpose_index[idx]:_discoPingPurpose_index[idx+1]] } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 53ecb84de..eda589e14 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "errors" "fmt" + "iter" "math" "math/rand/v2" "net" @@ -16,23 +17,24 @@ import ( "reflect" "runtime" "slices" - "sync" "sync/atomic" "time" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "tailscale.com/disco" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/packet" "tailscale.com/net/stun" "tailscale.com/net/tstun" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/util/mak" - "tailscale.com/util/ringbuffer" + "tailscale.com/util/ringlog" + "tailscale.com/util/slicesx" ) var mtuProbePingSizesV4 []int @@ -58,7 +60,7 @@ type endpoint struct { lastRecvWG mono.Time // last time there were incoming packets from this peer destined for wireguard-go (e.g. not disco) lastRecvUDPAny mono.Time // last time there were incoming UDP packets from this peer of any kind numStopAndResetAtomic int64 - debugUpdates *ringbuffer.RingBuffer[EndpointChange] + debugUpdates *ringlog.RingLog[EndpointChange] // These fields are initialized once and never modified. c *Conn @@ -71,19 +73,20 @@ type endpoint struct { disco atomic.Pointer[endpointDisco] // if the peer supports disco, the key and short string // mu protects all following fields. - mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu + mu syncs.Mutex // Lock ordering: Conn.mu, then endpoint.mu - heartBeatTimer *time.Timer // nil when idle - lastSendExt mono.Time // last time there were outgoing packets sent to this peer from an external trigger (e.g. wireguard-go or disco pingCLI) - lastSendAny mono.Time // last time there were outgoing packets sent this peer from any trigger, internal or external to magicsock - lastFullPing mono.Time // last time we pinged all disco or wireguard only endpoints - derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) + heartBeatTimer *time.Timer // nil when idle + lastSendExt mono.Time // last time there were outgoing packets sent to this peer from an external trigger (e.g. wireguard-go or disco pingCLI) + lastSendAny mono.Time // last time there were outgoing packets sent this peer from any trigger, internal or external to magicsock + lastFullPing mono.Time // last time we pinged all disco or wireguard only endpoints + lastUDPRelayPathDiscovery mono.Time // last time we ran UDP relay path discovery + derpAddr netip.AddrPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) bestAddr addrQuality // best non-DERP path; zero if none; mutate via setBestAddrLocked() bestAddrAt mono.Time // time best address re-confirmed trustBestAddrUntil mono.Time // time when bestAddr expires sentPing map[stun.TxID]sentPing - endpointState map[netip.AddrPort]*endpointState + endpointState map[netip.AddrPort]*endpointState // netip.AddrPort type for key (instead of [epAddr]) as [endpointState] is irrelevant for Geneve-encapsulated paths isCallMeMaybeEP map[netip.AddrPort]bool // The following fields are related to the new "silent disco" @@ -94,10 +97,40 @@ type endpoint struct { expired bool // whether the node has expired isWireguardOnly bool // whether the endpoint is WireGuard only + relayCapable bool // whether the node is capable of speaking via a [tailscale.com/net/udprelay.Server] +} + +// udpRelayEndpointReady determines whether the given relay [addrQuality] should +// be installed as de.bestAddr. It is only called by [relayManager] once it has +// determined maybeBest is functional via [disco.Pong] reception. +func (de *endpoint) udpRelayEndpointReady(maybeBest addrQuality) { + de.mu.Lock() + defer de.mu.Unlock() + now := mono.Now() + curBestAddrTrusted := now.Before(de.trustBestAddrUntil) + sameRelayServer := de.bestAddr.vni.IsSet() && maybeBest.relayServerDisco.Compare(de.bestAddr.relayServerDisco) == 0 + + if !curBestAddrTrusted || + sameRelayServer || + betterAddr(maybeBest, de.bestAddr) { + // We must set maybeBest as de.bestAddr if: + // 1. de.bestAddr is untrusted. betterAddr does not consider + // time-based trust. + // 2. maybeBest & de.bestAddr are on the same relay. If the maybeBest + // handshake happened to use a different source address/transport, + // the relay will drop packets from the 'old' de.bestAddr's. + // 3. maybeBest is a 'betterAddr'. + // + // TODO(jwhited): add observability around !curBestAddrTrusted and sameRelayServer + // TODO(jwhited): collapse path change logging with endpoint.handlePongConnLocked() + de.c.logf("magicsock: disco: node %v %v now using %v mtu=%v", de.publicKey.ShortString(), de.discoShort(), maybeBest.epAddr, maybeBest.wireMTU) + de.setBestAddrLocked(maybeBest) + de.trustBestAddrUntil = now.Add(trustUDPAddrDuration) + } } func (de *endpoint) setBestAddrLocked(v addrQuality) { - if v.AddrPort != de.bestAddr.AddrPort { + if v.epAddr != de.bestAddr.epAddr { de.probeUDPLifetime.resetCycleEndpointLocked() } de.bestAddr = v @@ -133,11 +166,11 @@ type probeUDPLifetime struct { // timeout cliff in the future. timer *time.Timer - // bestAddr contains the endpoint.bestAddr.AddrPort at the time a cycle was + // bestAddr contains the endpoint.bestAddr.epAddr at the time a cycle was // scheduled to start. A probing cycle is 1:1 with the current - // endpoint.bestAddr.AddrPort in the interest of simplicity. When - // endpoint.bestAddr.AddrPort changes, any active probing cycle will reset. - bestAddr netip.AddrPort + // endpoint.bestAddr.epAddr in the interest of simplicity. When + // endpoint.bestAddr.epAddr changes, any active probing cycle will reset. + bestAddr epAddr // cycleStartedAt contains the time at which the first cliff // (ProbeUDPLifetimeConfig.Cliffs[0]) was pinged for the current/last cycle. cycleStartedAt time.Time @@ -189,7 +222,7 @@ func (p *probeUDPLifetime) resetCycleEndpointLocked() { } p.cycleActive = false p.currentCliff = 0 - p.bestAddr = netip.AddrPort{} + p.bestAddr = epAddr{} } // ProbeUDPLifetimeConfig represents the configuration for probing UDP path @@ -332,7 +365,7 @@ type endpointDisco struct { } type sentPing struct { - to netip.AddrPort + to epAddr at mono.Time timer *time.Timer // timeout timer purpose discoPingPurpose @@ -444,7 +477,8 @@ func (de *endpoint) deleteEndpointLocked(why string, ep netip.AddrPort) { From: ep, }) delete(de.endpointState, ep) - if de.bestAddr.AddrPort == ep { + asEpAddr := epAddr{ap: ep} + if de.bestAddr.epAddr == asEpAddr { de.debugUpdates.Add(EndpointChange{ When: time.Now(), What: "deleteEndpointLocked-bestAddr-" + why, @@ -466,11 +500,12 @@ func (de *endpoint) initFakeUDPAddr() { } // noteRecvActivity records receive activity on de, and invokes -// Conn.noteRecvActivity no more than once every 10s. -func (de *endpoint) noteRecvActivity(ipp netip.AddrPort, now mono.Time) { +// Conn.noteRecvActivity no more than once every 10s, returning true if it +// was called, otherwise false. +func (de *endpoint) noteRecvActivity(src epAddr, now mono.Time) bool { if de.isWireguardOnly { de.mu.Lock() - de.bestAddr.AddrPort = ipp + de.bestAddr.ap = src.ap de.bestAddrAt = now de.trustBestAddrUntil = now.Add(5 * time.Second) de.mu.Unlock() @@ -480,7 +515,7 @@ func (de *endpoint) noteRecvActivity(ipp netip.AddrPort, now mono.Time) { // kick off discovery disco pings every trustUDPAddrDuration and mirror // to DERP. de.mu.Lock() - if de.heartbeatDisabled && de.bestAddr.AddrPort == ipp { + if de.heartbeatDisabled && de.bestAddr.epAddr == src { de.trustBestAddrUntil = now.Add(trustUDPAddrDuration) } de.mu.Unlock() @@ -491,10 +526,12 @@ func (de *endpoint) noteRecvActivity(ipp netip.AddrPort, now mono.Time) { de.lastRecvWG.StoreAtomic(now) if de.c.noteRecvActivity == nil { - return + return false } de.c.noteRecvActivity(de.publicKey) + return true } + return false } func (de *endpoint) discoShort() string { @@ -528,10 +565,10 @@ func (de *endpoint) DstToBytes() []byte { return packIPPort(de.fakeWGAddr) } // de.mu must be held. // // TODO(val): Rewrite the addrFor*Locked() variations to share code. -func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.AddrPort, sendWGPing bool) { - udpAddr = de.bestAddr.AddrPort +func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr epAddr, derpAddr netip.AddrPort, sendWGPing bool) { + udpAddr = de.bestAddr.epAddr - if udpAddr.IsValid() && !now.After(de.trustBestAddrUntil) { + if udpAddr.ap.IsValid() && !now.After(de.trustBestAddrUntil) { return udpAddr, netip.AddrPort{}, false } @@ -550,12 +587,12 @@ func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.Ad // addrForWireGuardSendLocked returns the address that should be used for // sending the next packet. If a packet has never or not recently been sent to // the endpoint, then a randomly selected address for the endpoint is returned, -// as well as a bool indiciating that WireGuard discovery pings should be started. +// as well as a bool indicating that WireGuard discovery pings should be started. // If the addresses have latency information available, then the address with the // best latency is used. // // de.mu must be held. -func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { +func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr epAddr, shouldPing bool) { if len(de.endpointState) == 0 { de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint") return udpAddr, false @@ -579,22 +616,22 @@ func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.Add // TODO(catzkorn): Consider a small increase in latency to use // IPv6 in comparison to IPv4, when possible. lowestLatency = latency - udpAddr = ipp + udpAddr.ap = ipp } } } needPing := len(de.endpointState) > 1 && now.Sub(oldestPing) > wireguardPingInterval - if !udpAddr.IsValid() { - candidates := xmaps.Keys(de.endpointState) + if !udpAddr.ap.IsValid() { + candidates := slicesx.MapKeys(de.endpointState) // Randomly select an address to use until we retrieve latency information // and give it a short trustBestAddrUntil time so we avoid flapping between // addresses while waiting on latency information to be populated. - udpAddr = candidates[rand.IntN(len(candidates))] + udpAddr.ap = candidates[rand.IntN(len(candidates))] } - de.bestAddr.AddrPort = udpAddr + de.bestAddr.epAddr = epAddr{ap: udpAddr.ap} // Only extend trustBestAddrUntil by one second to avoid packet // reordering and/or CPU usage from random selection during the first // second. We should receive a response due to a WireGuard handshake in @@ -612,18 +649,18 @@ func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.Add // both of the returned UDP address and DERP address may be non-zero. // // de.mu must be held. -func (de *endpoint) addrForPingSizeLocked(now mono.Time, size int) (udpAddr, derpAddr netip.AddrPort) { +func (de *endpoint) addrForPingSizeLocked(now mono.Time, size int) (udpAddr epAddr, derpAddr netip.AddrPort) { if size == 0 { udpAddr, derpAddr, _ = de.addrForSendLocked(now) return } - udpAddr = de.bestAddr.AddrPort + udpAddr = de.bestAddr.epAddr pathMTU := de.bestAddr.wireMTU - requestedMTU := pingSizeToPktLen(size, udpAddr.Addr().Is6()) + requestedMTU := pingSizeToPktLen(size, udpAddr) mtuOk := requestedMTU <= pathMTU - if udpAddr.IsValid() && mtuOk { + if udpAddr.ap.IsValid() && mtuOk { if !now.After(de.trustBestAddrUntil) { return udpAddr, netip.AddrPort{} } @@ -636,7 +673,7 @@ func (de *endpoint) addrForPingSizeLocked(now mono.Time, size int) (udpAddr, der // for the packet. Return a zero-value udpAddr to signal that we should // keep probing the path MTU to all addresses for this endpoint, and a // valid DERP addr to signal that we should also send via DERP. - return netip.AddrPort{}, de.derpAddr + return epAddr{}, de.derpAddr } // maybeProbeUDPLifetimeLocked returns an afterInactivityFor duration and true @@ -647,7 +684,7 @@ func (de *endpoint) maybeProbeUDPLifetimeLocked() (afterInactivityFor time.Durat if p == nil { return afterInactivityFor, false } - if !de.bestAddr.IsValid() { + if !de.bestAddr.ap.IsValid() { return afterInactivityFor, false } epDisco := de.disco.Load() @@ -660,7 +697,7 @@ func (de *endpoint) maybeProbeUDPLifetimeLocked() (afterInactivityFor time.Durat // shuffling probing probability where the local node ends up with a large // key value lexicographically relative to the other nodes it tends to // communicate with. If de's disco key changes, the cycle will reset. - if de.c.discoPublic.Compare(epDisco.key) >= 0 { + if de.c.discoAtomic.Public().Compare(epDisco.key) >= 0 { // lower disco pub key node probes higher return afterInactivityFor, false } @@ -699,7 +736,7 @@ func (de *endpoint) scheduleHeartbeatForLifetimeLocked(after time.Duration, via } de.c.dlogf("[v1] magicsock: disco: scheduling UDP lifetime probe for cliff=%v via=%v to %v (%v)", p.currentCliffDurationEndpointLocked(), via, de.publicKey.ShortString(), de.discoShort()) - p.bestAddr = de.bestAddr.AddrPort + p.bestAddr = de.bestAddr.epAddr p.timer = time.AfterFunc(after, de.heartbeatForLifetime) if via == heartbeatForLifetimeViaSelf { metricUDPLifetimeCliffsRescheduled.Add(1) @@ -727,7 +764,7 @@ func (de *endpoint) heartbeatForLifetime() { return } p.timer = nil - if !p.bestAddr.IsValid() || de.bestAddr.AddrPort != p.bestAddr { + if !p.bestAddr.ap.IsValid() || de.bestAddr.epAddr != p.bestAddr { // best path changed p.resetCycleEndpointLocked() return @@ -759,7 +796,7 @@ func (de *endpoint) heartbeatForLifetime() { } de.c.dlogf("[v1] magicsock: disco: sending disco ping for UDP lifetime probe cliff=%v to %v (%v)", p.currentCliffDurationEndpointLocked(), de.publicKey.ShortString(), de.discoShort()) - de.startDiscoPingLocked(de.bestAddr.AddrPort, mono.Now(), pingHeartbeatForUDPLifetime, 0, nil) + de.startDiscoPingLocked(de.bestAddr.epAddr, mono.Now(), pingHeartbeatForUDPLifetime, 0, nil) } // heartbeat is called every heartbeatInterval to keep the best UDP path alive, @@ -817,8 +854,8 @@ func (de *endpoint) heartbeat() { } udpAddr, _, _ := de.addrForSendLocked(now) - if udpAddr.IsValid() { - // We have a preferred path. Ping that every 2 seconds. + if udpAddr.ap.IsValid() { + // We have a preferred path. Ping that every 'heartbeatInterval'. de.startDiscoPingLocked(udpAddr, now, pingHeartbeat, 0, nil) } @@ -826,6 +863,10 @@ func (de *endpoint) heartbeat() { de.sendDiscoPingsLocked(now, true) } + if de.wantUDPRelayPathDiscoveryLocked(now) { + de.discoverUDPRelayPathsLocked(now) + } + de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) } @@ -836,6 +877,53 @@ func (de *endpoint) setHeartbeatDisabled(v bool) { de.heartbeatDisabled = v } +// discoverUDPRelayPathsLocked starts UDP relay path discovery. +func (de *endpoint) discoverUDPRelayPathsLocked(now mono.Time) { + de.lastUDPRelayPathDiscovery = now + lastBest := de.bestAddr + lastBestIsTrusted := mono.Now().Before(de.trustBestAddrUntil) + de.c.relayManager.startUDPRelayPathDiscoveryFor(de, lastBest, lastBestIsTrusted) +} + +// wantUDPRelayPathDiscoveryLocked reports whether we should kick off UDP relay +// path discovery. +func (de *endpoint) wantUDPRelayPathDiscoveryLocked(now mono.Time) bool { + if runtime.GOOS == "js" { + return false + } + if !de.c.hasPeerRelayServers.Load() { + // Changes in this value between its access and a call to + // [endpoint.discoverUDPRelayPathsLocked] are fine, we will eventually + // do the "right" thing during future path discovery. The worst case is + // we suppress path discovery for the current cycle, or we unnecessarily + // call into [relayManager] and do some wasted work. + return false + } + if !de.relayCapable { + return false + } + if de.bestAddr.isDirect() && now.Before(de.trustBestAddrUntil) { + return false + } + if !de.lastUDPRelayPathDiscovery.IsZero() && now.Sub(de.lastUDPRelayPathDiscovery) < discoverUDPRelayPathsInterval { + return false + } + // TODO(jwhited): consider applying 'goodEnoughLatency' suppression here, + // but not until we have a strategy for triggering CallMeMaybeVia regularly + // and/or enabling inbound packets to act as a UDP relay path discovery + // trigger, otherwise clients without relay servers may fall off a UDP + // relay path and never come back. They are dependent on the remote side + // regularly TX'ing CallMeMaybeVia, which currently only happens as part + // of full UDP relay path discovery. + if now.After(de.trustBestAddrUntil) { + return true + } + if !de.lastUDPRelayPathDiscovery.IsZero() && now.Sub(de.lastUDPRelayPathDiscovery) >= upgradeUDPRelayInterval { + return true + } + return false +} + // wantFullPingLocked reports whether we should ping to all our peers looking for // a better path. // @@ -844,7 +932,7 @@ func (de *endpoint) wantFullPingLocked(now mono.Time) bool { if runtime.GOOS == "js" { return false } - if !de.bestAddr.IsValid() || de.lastFullPing.IsZero() { + if !de.bestAddr.isDirect() || de.lastFullPing.IsZero() { return true } if now.After(de.trustBestAddrUntil) { @@ -853,7 +941,7 @@ func (de *endpoint) wantFullPingLocked(now mono.Time) bool { if de.bestAddr.latency <= goodEnoughLatency { return false } - if now.Sub(de.lastFullPing) >= upgradeInterval { + if now.Sub(de.lastFullPing) >= upgradeUDPDirectInterval { return true } return false @@ -904,17 +992,38 @@ func (de *endpoint) discoPing(res *ipnstate.PingResult, size int, cb func(*ipnst udpAddr, derpAddr := de.addrForPingSizeLocked(now, size) if derpAddr.IsValid() { - de.startDiscoPingLocked(derpAddr, now, pingCLI, size, resCB) + de.startDiscoPingLocked(epAddr{ap: derpAddr}, now, pingCLI, size, resCB) } - if udpAddr.IsValid() && now.Before(de.trustBestAddrUntil) { - // Already have an active session, so just ping the address we're using. - // Otherwise "tailscale ping" results to a node on the local network - // can look like they're bouncing between, say 10.0.0.0/9 and the peer's - // IPv6 address, both 1ms away, and it's random who replies first. + + switch { + case udpAddr.ap.IsValid() && now.Before(de.trustBestAddrUntil): + // We have a "trusted" direct OR peer relay address, ping it. de.startDiscoPingLocked(udpAddr, now, pingCLI, size, resCB) - } else { + if !udpAddr.vni.IsSet() { + // If the path is direct we do not want to fallthrough to pinging + // all candidate direct paths, otherwise "tailscale ping" results to + // a node on the local network can look like they're bouncing + // between, say 10.0.0.0/8 and the peer's IPv6 address, both 1ms + // away, and it's random who replies first. cb() is called with the + // first reply, vs background path discovery that is subject to + // betterAddr() comparison and hysteresis + break + } + // If the trusted path is via a peer relay we want to fallthrough in + // order to also try all candidate direct paths. + fallthrough + default: + // Ping all candidate direct paths and start peer relay path discovery, + // if appropriate. This work overlaps with what [de.heartbeat] will + // periodically fire when it calls [de.sendDiscoPingsLocked] and + // [de.discoveryUDPRelayPathsLocked], but a user-initiated [pingCLI] is + // a "do it now" operation that should not be subject to + // [heartbeatInterval] tick or [discoPingInterval] rate-limiting. for ep := range de.endpointState { - de.startDiscoPingLocked(ep, now, pingCLI, size, resCB) + de.startDiscoPingLocked(epAddr{ap: ep}, now, pingCLI, size, resCB) + } + if de.wantUDPRelayPathDiscoveryLocked(now) { + de.discoverUDPRelayPathsLocked(now) } } } @@ -925,7 +1034,7 @@ var ( errPingTooBig = errors.New("ping size too big") ) -func (de *endpoint) send(buffs [][]byte) error { +func (de *endpoint) send(buffs [][]byte, offset int) error { de.mu.Lock() if de.expired { de.mu.Unlock() @@ -939,19 +1048,30 @@ func (de *endpoint) send(buffs [][]byte) error { if startWGPing { de.sendWireGuardOnlyPingsLocked(now) } - } else if !udpAddr.IsValid() || now.After(de.trustBestAddrUntil) { + } else if !udpAddr.isDirect() || now.After(de.trustBestAddrUntil) { de.sendDiscoPingsLocked(now, true) + if de.wantUDPRelayPathDiscoveryLocked(now) { + de.discoverUDPRelayPathsLocked(now) + } } de.noteTxActivityExtTriggerLocked(now) de.lastSendAny = now de.mu.Unlock() - if !udpAddr.IsValid() && !derpAddr.IsValid() { - return errNoUDPOrDERP + if !udpAddr.ap.IsValid() && !derpAddr.IsValid() { + // Make a last ditch effort to see if we have a DERP route for them. If + // they contacted us over DERP and we don't know their UDP endpoints or + // their DERP home, we can at least assume they're reachable over the + // DERP they used to contact us. + if rid := de.c.fallbackDERPRegionForPeer(de.publicKey); rid != 0 { + derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(rid)) + } else { + return errNoUDPOrDERP + } } var err error - if udpAddr.IsValid() { - _, err = de.c.sendUDPBatch(udpAddr, buffs) + if udpAddr.ap.IsValid() { + _, err = de.c.sendUDPBatch(udpAddr, buffs, offset) // If the error is known to indicate that the endpoint is no longer // usable, clear the endpoint statistics so that the next send will @@ -960,26 +1080,52 @@ func (de *endpoint) send(buffs [][]byte) error { de.noteBadEndpoint(udpAddr) } - // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. - if stats := de.c.stats.Load(); err == nil && stats != nil { - var txBytes int - for _, b := range buffs { - txBytes += len(b) + var txBytes int + for _, b := range buffs { + txBytes += len(b[offset:]) + } + + switch { + case udpAddr.ap.Addr().Is4(): + if udpAddr.vni.IsSet() { + de.c.metrics.outboundPacketsPeerRelayIPv4Total.Add(int64(len(buffs))) + de.c.metrics.outboundBytesPeerRelayIPv4Total.Add(int64(txBytes)) + } else { + de.c.metrics.outboundPacketsIPv4Total.Add(int64(len(buffs))) + de.c.metrics.outboundBytesIPv4Total.Add(int64(txBytes)) + } + case udpAddr.ap.Addr().Is6(): + if udpAddr.vni.IsSet() { + de.c.metrics.outboundPacketsPeerRelayIPv6Total.Add(int64(len(buffs))) + de.c.metrics.outboundBytesPeerRelayIPv6Total.Add(int64(txBytes)) + } else { + de.c.metrics.outboundPacketsIPv6Total.Add(int64(len(buffs))) + de.c.metrics.outboundBytesIPv6Total.Add(int64(txBytes)) } - stats.UpdateTxPhysical(de.nodeAddr, udpAddr, txBytes) + } + + // TODO(raggi): needs updating for accuracy, as in error conditions we may have partial sends. + if update := de.c.connCounter.Load(); err == nil && update != nil { + update(0, netip.AddrPortFrom(de.nodeAddr, 0), udpAddr.ap, len(buffs), txBytes, false) } } if derpAddr.IsValid() { allOk := true + var txBytes int for _, buff := range buffs { - ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) - if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buff)) - } + buff = buff[offset:] + const isDisco = false + const isGeneveEncap = false + ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff, isDisco, isGeneveEncap) + txBytes += len(buff) if !ok { allOk = false } } + + if update := de.c.connCounter.Load(); update != nil { + update(0, netip.AddrPortFrom(de.nodeAddr, 0), derpAddr, len(buffs), txBytes, false) + } if allOk { return nil } @@ -1030,7 +1176,12 @@ func (de *endpoint) discoPingTimeout(txid stun.TxID) { if !ok { return } - if debugDisco() || !de.bestAddr.IsValid() || mono.Now().After(de.trustBestAddrUntil) { + bestUntrusted := mono.Now().After(de.trustBestAddrUntil) + if sp.to == de.bestAddr.epAddr && sp.to.vni.IsSet() && bestUntrusted { + // TODO(jwhited): consider applying this to direct UDP paths as well + de.clearBestAddrLocked() + } + if debugDisco() || !de.bestAddr.ap.IsValid() || bestUntrusted { de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort()) } de.removeSentDiscoPingLocked(txid, sp, discoPingTimedOut) @@ -1084,7 +1235,7 @@ const discoPingSize = len(disco.Magic) + key.DiscoPublicRawLen + disco.NonceLen // // The caller should use de.discoKey as the discoKey argument. // It is passed in so that sendDiscoPing doesn't need to lock de.mu. -func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, txid stun.TxID, size int, logLevel discoLogLevel) { +func (de *endpoint) sendDiscoPing(ep epAddr, discoKey key.DiscoPublic, txid stun.TxID, size int, logLevel discoLogLevel) { size = min(size, MaxDiscoPingSize) padding := max(size-discoPingSize, 0) @@ -1100,7 +1251,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t if size != 0 { metricSentDiscoPeerMTUProbes.Add(1) - metricSentDiscoPeerMTUProbeBytes.Add(int64(pingSizeToPktLen(size, ep.Addr().Is6()))) + metricSentDiscoPeerMTUProbeBytes.Add(int64(pingSizeToPktLen(size, ep))) } } @@ -1131,16 +1282,20 @@ const ( // if non-nil, means that a caller external to the magicsock package internals // is interested in the result (such as a CLI "tailscale ping" or a c2n ping // request, etc) -func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpose discoPingPurpose, size int, resCB *pingResultAndCallback) { +func (de *endpoint) startDiscoPingLocked(ep epAddr, now mono.Time, purpose discoPingPurpose, size int, resCB *pingResultAndCallback) { if runtime.GOOS == "js" { return } + if debugNeverDirectUDP() && !ep.vni.IsSet() && ep.ap.Addr() != tailcfg.DerpMagicIPAddr { + return + } epDisco := de.disco.Load() if epDisco == nil { return } - if purpose != pingCLI { - st, ok := de.endpointState[ep] + if purpose != pingCLI && + !ep.vni.IsSet() { // de.endpointState is only relevant for direct/non-vni epAddr's + st, ok := de.endpointState[ep.ap] if !ok { // Shouldn't happen. But don't ping an endpoint that's // not active for us. @@ -1157,11 +1312,11 @@ func (de *endpoint) startDiscoPingLocked(ep netip.AddrPort, now mono.Time, purpo // Default to sending a single ping of the specified size sizes := []int{size} if de.c.PeerMTUEnabled() { - isDerp := ep.Addr() == tailcfg.DerpMagicIPAddr + isDerp := ep.ap.Addr() == tailcfg.DerpMagicIPAddr if !isDerp && ((purpose == pingDiscovery) || (purpose == pingCLI && size == 0)) { de.c.dlogf("[v1] magicsock: starting MTU probe") sizes = mtuProbePingSizesV4 - if ep.Addr().Is6() { + if ep.ap.Addr().Is6() { sizes = mtuProbePingSizesV6 } } @@ -1216,7 +1371,7 @@ func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort()) } - de.startDiscoPingLocked(ep, now, pingDiscovery, 0, nil) + de.startDiscoPingLocked(epAddr{ap: ep}, now, pingDiscovery, 0, nil) } derpAddr := de.derpAddr if sentAny && sendCallMeMaybe && derpAddr.IsValid() { @@ -1230,7 +1385,7 @@ func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { } // sendWireGuardOnlyPingsLocked evaluates all available addresses for -// a WireGuard only endpoint and initates an ICMP ping for useable +// a WireGuard only endpoint and initiates an ICMP ping for useable // addresses. func (de *endpoint) sendWireGuardOnlyPingsLocked(now mono.Time) { if runtime.GOOS == "js" { @@ -1344,7 +1499,7 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p }) de.resetLocked() } - if n.DERP() == "" { + if n.HomeDERP() == 0 { if de.derpAddr.IsValid() { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -1354,7 +1509,7 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p } de.derpAddr = netip.AddrPort{} } else { - newDerp, _ := netip.ParseAddrPort(n.DERP()) + newDerp := netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(n.HomeDERP())) if de.derpAddr != newDerp { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -1367,23 +1522,23 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p } de.setEndpointsLocked(n.Endpoints()) + + de.relayCapable = capVerIsRelayCapable(n.Cap()) } func (de *endpoint) setEndpointsLocked(eps interface { - Len() int - At(i int) netip.AddrPort + All() iter.Seq2[int, netip.AddrPort] }) { for _, st := range de.endpointState { st.index = indexSentinelDeleted // assume deleted until updated in next loop } var newIpps []netip.AddrPort - for i := range eps.Len() { + for i, ipp := range eps.All() { if i > math.MaxInt16 { // Seems unlikely. break } - ipp := eps.At(i) if !ipp.IsValid() { de.c.logf("magicsock: bogus netmap endpoint from %v", eps) continue @@ -1451,7 +1606,7 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T } } size2 := len(de.endpointState) - de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) + de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v (%s) candidate set from %v to %v entries", de.discoShort(), de.publicKey.ShortString(), size, size2) } return false } @@ -1466,17 +1621,19 @@ func (de *endpoint) clearBestAddrLocked() { de.trustBestAddrUntil = 0 } -// noteBadEndpoint marks ipp as a bad endpoint that would need to be +// noteBadEndpoint marks udpAddr as a bad endpoint that would need to be // re-evaluated before future use, this should be called for example if a send -// to ipp fails due to a host unreachable error or similar. -func (de *endpoint) noteBadEndpoint(ipp netip.AddrPort) { +// to udpAddr fails due to a host unreachable error or similar. +func (de *endpoint) noteBadEndpoint(udpAddr epAddr) { de.mu.Lock() defer de.mu.Unlock() de.clearBestAddrLocked() - if st, ok := de.endpointState[ipp]; ok { - st.clear() + if !udpAddr.vni.IsSet() { + if st, ok := de.endpointState[udpAddr.ap]; ok { + st.clear() + } } } @@ -1496,17 +1653,20 @@ func (de *endpoint) noteConnectivityChange() { // pingSizeToPktLen calculates the minimum path MTU that would permit // a disco ping message of length size to reach its target at -// addr. size is the length of the entire disco message including +// udpAddr. size is the length of the entire disco message including // disco headers. If size is zero, assume it is the safe wire MTU. -func pingSizeToPktLen(size int, is6 bool) tstun.WireMTU { +func pingSizeToPktLen(size int, udpAddr epAddr) tstun.WireMTU { if size == 0 { return tstun.SafeWireMTU() } headerLen := ipv4.HeaderLen - if is6 { + if udpAddr.ap.Addr().Is6() { headerLen = ipv6.HeaderLen } headerLen += 8 // UDP header length + if udpAddr.vni.IsSet() { + headerLen += packet.GeneveFixedHeaderLength + } return tstun.WireMTU(size + headerLen) } @@ -1533,11 +1693,11 @@ func pktLenToPingSize(mtu tstun.WireMTU, is6 bool) int { // It should be called with the Conn.mu held. // // It reports whether m.TxID corresponds to a ping that this endpoint sent. -func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip.AddrPort) (knownTxID bool) { +func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAddr) (knownTxID bool) { de.mu.Lock() defer de.mu.Unlock() - isDerp := src.Addr() == tailcfg.DerpMagicIPAddr + isDerp := src.ap.Addr() == tailcfg.DerpMagicIPAddr sp, ok := de.sentPing[m.TxID] if !ok { @@ -1547,7 +1707,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip knownTxID = true // for naked returns below de.removeSentDiscoPingLocked(m.TxID, sp, discoPongReceived) - pktLen := int(pingSizeToPktLen(sp.size, sp.to.Addr().Is6())) + pktLen := int(pingSizeToPktLen(sp.size, src)) if sp.size != 0 { m := getPeerMTUsProbedMetric(tstun.WireMTU(pktLen)) m.Add(1) @@ -1559,25 +1719,27 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip now := mono.Now() latency := now.Sub(sp.at) - if !isDerp { - st, ok := de.endpointState[sp.to] + if !isDerp && !src.vni.IsSet() { + // Note: we check vni.isSet() as relay [epAddr]'s are not stored in + // endpointState, they are either de.bestAddr or not. + st, ok := de.endpointState[sp.to.ap] if !ok { // This is no longer an endpoint we care about. return } - de.c.peerMap.setNodeKeyForIPPort(src, de.publicKey) + de.c.peerMap.setNodeKeyForEpAddr(src, de.publicKey) st.addPongReplyLocked(pongReply{ latency: latency, pongAt: now, - from: src, + from: src.ap, pongSrc: m.Src, }) } if sp.purpose != pingHeartbeat && sp.purpose != pingHeartbeatForUDPLifetime { - de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) { + de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoAtomic.Short(), de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) { if sp.to != src { fmt.Fprintf(bw, " ping.to=%v", sp.to) } @@ -1595,21 +1757,30 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip // Promote this pong response to our current best address if it's lower latency. // TODO(bradfitz): decide how latency vs. preference order affects decision if !isDerp { - thisPong := addrQuality{sp.to, latency, tstun.WireMTU(pingSizeToPktLen(sp.size, sp.to.Addr().Is6()))} + thisPong := addrQuality{ + epAddr: sp.to, + latency: latency, + wireMTU: pingSizeToPktLen(sp.size, sp.to), + } + // TODO(jwhited): consider checking de.trustBestAddrUntil as well. If + // de.bestAddr is untrusted we may want to clear it, otherwise we could + // get stuck with a forever untrusted bestAddr that blackholes, since + // we don't clear direct UDP paths on disco ping timeout (see + // discoPingTimeout). if betterAddr(thisPong, de.bestAddr) { de.c.logf("magicsock: disco: node %v %v now using %v mtu=%v tx=%x", de.publicKey.ShortString(), de.discoShort(), sp.to, thisPong.wireMTU, m.TxID[:6]) de.debugUpdates.Add(EndpointChange{ When: time.Now(), - What: "handlePingLocked-bestAddr-update", + What: "handlePongConnLocked-bestAddr-update", From: de.bestAddr, To: thisPong, }) de.setBestAddrLocked(thisPong) } - if de.bestAddr.AddrPort == thisPong.AddrPort { + if de.bestAddr.epAddr == thisPong.epAddr { de.debugUpdates.Add(EndpointChange{ When: time.Now(), - What: "handlePingLocked-bestAddr-latency", + What: "handlePongConnLocked-bestAddr-latency", From: de.bestAddr, To: thisPong, }) @@ -1621,20 +1792,43 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip return } -// addrQuality is an IPPort with an associated latency and path mtu. +// epAddr is a [netip.AddrPort] with an optional Geneve header (RFC8926) +// [packet.VirtualNetworkID]. +type epAddr struct { + ap netip.AddrPort // if ap == tailcfg.DerpMagicIPAddr then vni is never set + vni packet.VirtualNetworkID // vni.IsSet() indicates if this [epAddr] involves a Geneve header +} + +// isDirect returns true if e.ap is valid and not tailcfg.DerpMagicIPAddr, +// and a VNI is not set. +func (e epAddr) isDirect() bool { + return e.ap.IsValid() && e.ap.Addr() != tailcfg.DerpMagicIPAddr && !e.vni.IsSet() +} + +func (e epAddr) String() string { + if !e.vni.IsSet() { + return e.ap.String() + } + return fmt.Sprintf("%v:vni:%d", e.ap.String(), e.vni.Get()) +} + +// addrQuality is an [epAddr], an optional [key.DiscoPublic] if a relay server +// is associated, a round-trip latency measurement, and path mtu. type addrQuality struct { - netip.AddrPort - latency time.Duration - wireMTU tstun.WireMTU + epAddr + relayServerDisco key.DiscoPublic // only relevant if epAddr.vni.isSet(), otherwise zero value + latency time.Duration + wireMTU tstun.WireMTU } func (a addrQuality) String() string { - return fmt.Sprintf("%v@%v+%v", a.AddrPort, a.latency, a.wireMTU) + // TODO(jwhited): consider including relayServerDisco + return fmt.Sprintf("%v@%v+%v", a.epAddr, a.latency, a.wireMTU) } // betterAddr reports whether a is a better addr to use than b. func betterAddr(a, b addrQuality) bool { - if a.AddrPort == b.AddrPort { + if a.epAddr == b.epAddr { if a.wireMTU > b.wireMTU { // TODO(val): Think harder about the case of lower // latency and smaller or unknown MTU, and higher @@ -1645,10 +1839,19 @@ func betterAddr(a, b addrQuality) bool { } return false } - if !b.IsValid() { + if !b.ap.IsValid() { + return true + } + if !a.ap.IsValid() { + return false + } + + // Geneve-encapsulated paths (UDP relay servers) are lower preference in + // relation to non. + if !a.vni.IsSet() && b.vni.IsSet() { return true } - if !a.IsValid() { + if a.vni.IsSet() && !b.vni.IsSet() { return false } @@ -1672,27 +1875,27 @@ func betterAddr(a, b addrQuality) bool { // addresses, and prefer link-local unicast addresses over other types // of private IP addresses since it's definitionally more likely that // they'll be on the same network segment than a general private IP. - if a.Addr().IsLoopback() { + if a.ap.Addr().IsLoopback() { aPoints += 50 - } else if a.Addr().IsLinkLocalUnicast() { + } else if a.ap.Addr().IsLinkLocalUnicast() { aPoints += 30 - } else if a.Addr().IsPrivate() { + } else if a.ap.Addr().IsPrivate() { aPoints += 20 } - if b.Addr().IsLoopback() { + if b.ap.Addr().IsLoopback() { bPoints += 50 - } else if b.Addr().IsLinkLocalUnicast() { + } else if b.ap.Addr().IsLinkLocalUnicast() { bPoints += 30 - } else if b.Addr().IsPrivate() { + } else if b.ap.Addr().IsPrivate() { bPoints += 20 } // Prefer IPv6 for being a bit more robust, as long as // the latencies are roughly equivalent. - if a.Addr().Is6() { + if a.ap.Addr().Is6() { aPoints += 10 } - if b.Addr().Is6() { + if b.ap.Addr().Is6() { bPoints += 10 } @@ -1776,7 +1979,25 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { for _, st := range de.endpointState { st.lastPing = 0 } - de.sendDiscoPingsLocked(mono.Now(), false) + monoNow := mono.Now() + de.sendDiscoPingsLocked(monoNow, false) + + // This hook is required to trigger peer relay path discovery around + // disco "tailscale ping" initiated by de. We may be configured with peer + // relay servers that differ from de. + // + // The only other peer relay path discovery hook is in [endpoint.heartbeat], + // which is kicked off around outbound WireGuard packet flow, or if you are + // the "tailscale ping" initiator. Disco "tailscale ping" does not propagate + // into wireguard-go. + // + // We choose not to hook this around disco ping reception since peer relay + // path discovery can also trigger disco ping transmission, which *could* + // lead to an infinite loop of peer relay path discovery between two peers, + // absent intended triggers. + if de.wantUDPRelayPathDiscoveryLocked(monoNow) { + de.discoverUDPRelayPathsLocked(monoNow) + } } func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { @@ -1793,8 +2014,12 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) { ps.LastWrite = de.lastSendExt.WallTime() ps.Active = now.Sub(de.lastSendExt) < sessionActiveTimeout - if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.IsValid() && !derpAddr.IsValid() { - ps.CurAddr = udpAddr.String() + if udpAddr, derpAddr, _ := de.addrForSendLocked(now); udpAddr.ap.IsValid() && !derpAddr.IsValid() { + if udpAddr.vni.IsSet() { + ps.PeerRelay = udpAddr.String() + } else { + ps.CurAddr = udpAddr.String() + } } } @@ -1842,14 +2067,22 @@ func (de *endpoint) resetLocked() { } } de.probeUDPLifetime.resetCycleEndpointLocked() + de.c.relayManager.stopWork(de) } func (de *endpoint) numStopAndReset() int64 { return atomic.LoadInt64(&de.numStopAndResetAtomic) } +// setDERPHome sets the provided regionID as home for de. Calls to setDERPHome +// must never run concurrent to [Conn.updateRelayServersSet], otherwise +// [candidatePeerRelay] DERP home changes may be missed from the perspective of +// [relayManager]. func (de *endpoint) setDERPHome(regionID uint16) { de.mu.Lock() defer de.mu.Unlock() de.derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) + if de.c.hasPeerRelayServers.Load() { + de.c.relayManager.handleDERPHomeChange(de.publicKey, regionID) + } } diff --git a/wgengine/magicsock/endpoint_test.go b/wgengine/magicsock/endpoint_test.go index 1e2de8967..f1dab924f 100644 --- a/wgengine/magicsock/endpoint_test.go +++ b/wgengine/magicsock/endpoint_test.go @@ -8,7 +8,9 @@ import ( "testing" "time" - "github.com/dsnet/try" + "tailscale.com/net/packet" + "tailscale.com/tailcfg" + "tailscale.com/tstime/mono" "tailscale.com/types/key" ) @@ -144,17 +146,24 @@ func TestProbeUDPLifetimeConfig_Valid(t *testing.T) { } func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { + var lowerPriv, higherPriv key.DiscoPrivate var lower, higher key.DiscoPublic - a := key.NewDisco().Public() - b := key.NewDisco().Public() + privA := key.NewDisco() + privB := key.NewDisco() + a := privA.Public() + b := privB.Public() if a.String() < b.String() { lower = a higher = b + lowerPriv = privA + higherPriv = privB } else { lower = b higher = a + lowerPriv = privB + higherPriv = privA } - addr := addrQuality{AddrPort: try.E1[netip.AddrPort](netip.ParseAddrPort("1.1.1.1:1"))} + addr := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:1")}} newProbeUDPLifetime := func() *probeUDPLifetime { return &probeUDPLifetime{ config: *defaultProbeUDPLifetimeConfig, @@ -171,138 +180,126 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { wantMaybe bool }{ { - "nil probeUDPLifetime", - higher, - &lower, - func() *probeUDPLifetime { + name: "nil probeUDPLifetime", + localDisco: higher, + remoteDisco: &lower, + probeUDPLifetimeFn: func() *probeUDPLifetime { return nil }, - addr, - func(lifetime *probeUDPLifetime) time.Duration { - return 0 - }, - false, + bestAddr: addr, }, { - "local higher disco key", - higher, - &lower, - newProbeUDPLifetime, - addr, - func(lifetime *probeUDPLifetime) time.Duration { - return 0 - }, - false, + name: "local higher disco key", + localDisco: higher, + remoteDisco: &lower, + probeUDPLifetimeFn: newProbeUDPLifetime, + bestAddr: addr, }, { - "remote no disco key", - higher, - nil, - newProbeUDPLifetime, - addr, - func(lifetime *probeUDPLifetime) time.Duration { - return 0 - }, - false, + name: "remote no disco key", + localDisco: higher, + remoteDisco: nil, + probeUDPLifetimeFn: newProbeUDPLifetime, + bestAddr: addr, }, { - "invalid bestAddr", - lower, - &higher, - newProbeUDPLifetime, - addrQuality{}, - func(lifetime *probeUDPLifetime) time.Duration { - return 0 - }, - false, + name: "invalid bestAddr", + localDisco: lower, + remoteDisco: &higher, + probeUDPLifetimeFn: newProbeUDPLifetime, + bestAddr: addrQuality{}, }, { - "cycle started too recently", - lower, - &higher, - func() *probeUDPLifetime { - l := newProbeUDPLifetime() - l.cycleActive = false - l.cycleStartedAt = time.Now() - return l - }, - addr, - func(lifetime *probeUDPLifetime) time.Duration { - return 0 + name: "cycle started too recently", + localDisco: lower, + remoteDisco: &higher, + probeUDPLifetimeFn: func() *probeUDPLifetime { + lt := newProbeUDPLifetime() + lt.cycleActive = false + lt.cycleStartedAt = time.Now() + return lt }, - false, + bestAddr: addr, }, { - "maybe cliff 0 cycle not active", - lower, - &higher, - func() *probeUDPLifetime { - l := newProbeUDPLifetime() - l.cycleActive = false - l.cycleStartedAt = time.Now().Add(-l.config.CycleCanStartEvery).Add(-time.Second) - return l + name: "maybe cliff 0 cycle not active", + localDisco: lower, + remoteDisco: &higher, + probeUDPLifetimeFn: func() *probeUDPLifetime { + lt := newProbeUDPLifetime() + lt.cycleActive = false + lt.cycleStartedAt = time.Now().Add(-lt.config.CycleCanStartEvery).Add(-time.Second) + return lt }, - addr, - func(lifetime *probeUDPLifetime) time.Duration { + bestAddr: addr, + wantAfterInactivityForFn: func(lifetime *probeUDPLifetime) time.Duration { return lifetime.config.Cliffs[0] - udpLifetimeProbeCliffSlack }, - true, + wantMaybe: true, }, { - "maybe cliff 0", - lower, - &higher, - func() *probeUDPLifetime { - l := newProbeUDPLifetime() - l.cycleActive = true - l.currentCliff = 0 - return l + name: "maybe cliff 0", + localDisco: lower, + remoteDisco: &higher, + probeUDPLifetimeFn: func() *probeUDPLifetime { + lt := newProbeUDPLifetime() + lt.cycleActive = true + lt.currentCliff = 0 + return lt }, - addr, - func(lifetime *probeUDPLifetime) time.Duration { + bestAddr: addr, + wantAfterInactivityForFn: func(lifetime *probeUDPLifetime) time.Duration { return lifetime.config.Cliffs[0] - udpLifetimeProbeCliffSlack }, - true, + wantMaybe: true, }, { - "maybe cliff 1", - lower, - &higher, - func() *probeUDPLifetime { - l := newProbeUDPLifetime() - l.cycleActive = true - l.currentCliff = 1 - return l + name: "maybe cliff 1", + localDisco: lower, + remoteDisco: &higher, + probeUDPLifetimeFn: func() *probeUDPLifetime { + lt := newProbeUDPLifetime() + lt.cycleActive = true + lt.currentCliff = 1 + return lt }, - addr, - func(lifetime *probeUDPLifetime) time.Duration { + bestAddr: addr, + wantAfterInactivityForFn: func(lifetime *probeUDPLifetime) time.Duration { return lifetime.config.Cliffs[1] - udpLifetimeProbeCliffSlack }, - true, + wantMaybe: true, }, { - "maybe cliff 2", - lower, - &higher, - func() *probeUDPLifetime { - l := newProbeUDPLifetime() - l.cycleActive = true - l.currentCliff = 2 - return l + name: "maybe cliff 2", + localDisco: lower, + remoteDisco: &higher, + probeUDPLifetimeFn: func() *probeUDPLifetime { + lt := newProbeUDPLifetime() + lt.cycleActive = true + lt.currentCliff = 2 + return lt }, - addr, - func(lifetime *probeUDPLifetime) time.Duration { + bestAddr: addr, + wantAfterInactivityForFn: func(lifetime *probeUDPLifetime) time.Duration { return lifetime.config.Cliffs[2] - udpLifetimeProbeCliffSlack }, - true, + wantMaybe: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + c := &Conn{} + if tt.localDisco.IsZero() { + c.discoAtomic.Set(key.NewDisco()) + } else if tt.localDisco.Compare(lower) == 0 { + c.discoAtomic.Set(lowerPriv) + } else if tt.localDisco.Compare(higher) == 0 { + c.discoAtomic.Set(higherPriv) + } else { + t.Fatalf("unexpected localDisco value") + } de := &endpoint{ - c: &Conn{ - discoPublic: tt.localDisco, - }, + c: c, bestAddr: tt.bestAddr, } if tt.remoteDisco != nil { @@ -314,7 +311,10 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { p := tt.probeUDPLifetimeFn() de.probeUDPLifetime = p gotAfterInactivityFor, gotMaybe := de.maybeProbeUDPLifetimeLocked() - wantAfterInactivityFor := tt.wantAfterInactivityForFn(p) + var wantAfterInactivityFor time.Duration + if tt.wantAfterInactivityForFn != nil { + wantAfterInactivityFor = tt.wantAfterInactivityForFn(p) + } if gotAfterInactivityFor != wantAfterInactivityFor { t.Errorf("maybeProbeUDPLifetimeLocked() gotAfterInactivityFor = %v, want %v", gotAfterInactivityFor, wantAfterInactivityFor) } @@ -324,3 +324,132 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { }) } } + +func Test_epAddr_isDirectUDP(t *testing.T) { + vni := packet.VirtualNetworkID{} + vni.Set(7) + tests := []struct { + name string + ap netip.AddrPort + vni packet.VirtualNetworkID + want bool + }{ + { + name: "true", + ap: netip.MustParseAddrPort("192.0.2.1:7"), + vni: packet.VirtualNetworkID{}, + want: true, + }, + { + name: "false derp magic addr", + ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 0), + vni: packet.VirtualNetworkID{}, + want: false, + }, + { + name: "false vni set", + ap: netip.MustParseAddrPort("192.0.2.1:7"), + vni: vni, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := epAddr{ + ap: tt.ap, + vni: tt.vni, + } + if got := e.isDirect(); got != tt.want { + t.Errorf("isDirect() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_endpoint_udpRelayEndpointReady(t *testing.T) { + directAddrQuality := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.1:7")}} + peerRelayAddrQuality := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.2:77")}, latency: time.Second} + peerRelayAddrQuality.vni.Set(1) + peerRelayAddrQualityHigherLatencySameServer := addrQuality{ + epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.3:77"), vni: peerRelayAddrQuality.vni}, + latency: peerRelayAddrQuality.latency * 10, + } + peerRelayAddrQualityHigherLatencyDiffServer := addrQuality{ + epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.3:77"), vni: peerRelayAddrQuality.vni}, + latency: peerRelayAddrQuality.latency * 10, + relayServerDisco: key.NewDisco().Public(), + } + peerRelayAddrQualityLowerLatencyDiffServer := addrQuality{ + epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.4:77"), vni: peerRelayAddrQuality.vni}, + latency: peerRelayAddrQuality.latency / 10, + relayServerDisco: key.NewDisco().Public(), + } + peerRelayAddrQualityEqualLatencyDiffServer := addrQuality{ + epAddr: epAddr{ap: netip.MustParseAddrPort("192.0.2.4:77"), vni: peerRelayAddrQuality.vni}, + latency: peerRelayAddrQuality.latency, + relayServerDisco: key.NewDisco().Public(), + } + tests := []struct { + name string + curBestAddr addrQuality + trustBestAddrUntil mono.Time + maybeBest addrQuality + wantBestAddr addrQuality + }{ + { + name: "bestAddr trusted direct", + curBestAddr: directAddrQuality, + trustBestAddrUntil: mono.Now().Add(1 * time.Hour), + maybeBest: peerRelayAddrQuality, + wantBestAddr: directAddrQuality, + }, + { + name: "bestAddr untrusted direct", + curBestAddr: directAddrQuality, + trustBestAddrUntil: mono.Now().Add(-1 * time.Hour), + maybeBest: peerRelayAddrQuality, + wantBestAddr: peerRelayAddrQuality, + }, + { + name: "maybeBest same relay server higher latency bestAddr trusted", + curBestAddr: peerRelayAddrQuality, + trustBestAddrUntil: mono.Now().Add(1 * time.Hour), + maybeBest: peerRelayAddrQualityHigherLatencySameServer, + wantBestAddr: peerRelayAddrQualityHigherLatencySameServer, + }, + { + name: "maybeBest diff relay server higher latency bestAddr trusted", + curBestAddr: peerRelayAddrQuality, + trustBestAddrUntil: mono.Now().Add(1 * time.Hour), + maybeBest: peerRelayAddrQualityHigherLatencyDiffServer, + wantBestAddr: peerRelayAddrQuality, + }, + { + name: "maybeBest diff relay server lower latency bestAddr trusted", + curBestAddr: peerRelayAddrQuality, + trustBestAddrUntil: mono.Now().Add(1 * time.Hour), + maybeBest: peerRelayAddrQualityLowerLatencyDiffServer, + wantBestAddr: peerRelayAddrQualityLowerLatencyDiffServer, + }, + { + name: "maybeBest diff relay server equal latency bestAddr trusted", + curBestAddr: peerRelayAddrQuality, + trustBestAddrUntil: mono.Now().Add(1 * time.Hour), + maybeBest: peerRelayAddrQualityEqualLatencyDiffServer, + wantBestAddr: peerRelayAddrQuality, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + de := &endpoint{ + c: &Conn{logf: func(msg string, args ...any) { return }}, + bestAddr: tt.curBestAddr, + trustBestAddrUntil: tt.trustBestAddrUntil, + } + de.udpRelayEndpointReady(tt.maybeBest) + if de.bestAddr != tt.wantBestAddr { + t.Errorf("de.bestAddr = %v, want %v", de.bestAddr, tt.wantBestAddr) + } + }) + } +} diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go index 5caddd1a0..e95852d24 100644 --- a/wgengine/magicsock/endpoint_tracker.go +++ b/wgengine/magicsock/endpoint_tracker.go @@ -6,9 +6,9 @@ package magicsock import ( "net/netip" "slices" - "sync" "time" + "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tempfork/heap" "tailscale.com/util/mak" @@ -107,7 +107,7 @@ func (eh endpointHeap) Min() *endpointTrackerEntry { // // See tailscale/tailscale#7877 for more information. type endpointTracker struct { - mu sync.Mutex + mu syncs.Mutex endpoints map[netip.Addr]*endpointHeap } diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 08aff842d..064838a2d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "errors" "fmt" "io" @@ -16,31 +17,34 @@ import ( "net/netip" "reflect" "runtime" + "slices" "strconv" "strings" "sync" "sync/atomic" - "syscall" "time" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" "go4.org/mem" "golang.org/x/net/ipv6" - "tailscale.com/control/controlknobs" "tailscale.com/disco" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" + "tailscale.com/feature/condlite/expvar" "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn/ipnstate" - "tailscale.com/net/connstats" + "tailscale.com/net/batching" "tailscale.com/net/netcheck" "tailscale.com/net/neterror" "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/packet" "tailscale.com/net/ping" - "tailscale.com/net/portmapper" + "tailscale.com/net/portmapper/portmappertype" + "tailscale.com/net/sockopts" "tailscale.com/net/sockstats" "tailscale.com/net/stun" "tailscale.com/net/tstun" @@ -51,17 +55,19 @@ import ( "tailscale.com/types/key" "tailscale.com/types/lazy" "tailscale.com/types/logger" + "tailscale.com/types/netlogfunc" "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/views" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" - "tailscale.com/util/ringbuffer" + "tailscale.com/util/ringlog" "tailscale.com/util/set" "tailscale.com/util/testenv" - "tailscale.com/util/uniq" "tailscale.com/util/usermetric" - "tailscale.com/wgengine/capture" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgint" ) @@ -80,11 +86,77 @@ const ( socketBufferSize = 7 << 20 ) +// Path is a label indicating the type of path a packet took. +type Path string + +const ( + PathDirectIPv4 Path = "direct_ipv4" + PathDirectIPv6 Path = "direct_ipv6" + PathDERP Path = "derp" + PathPeerRelayIPv4 Path = "peer_relay_ipv4" + PathPeerRelayIPv6 Path = "peer_relay_ipv6" +) + +type pathLabel struct { + // Path indicates the path that the packet took: + // - direct_ipv4 + // - direct_ipv6 + // - derp + // - peer_relay_ipv4 + // - peer_relay_ipv6 + Path Path +} + +// metrics in wgengine contains the usermetrics counters for magicsock, it +// is however a bit special. All them metrics are labeled, but looking up +// the metric everytime we need to record it has an overhead, and includes +// a lock in MultiLabelMap. The metrics are therefore instead created with +// wgengine and the underlying expvar.Int is stored to be used directly. +type metrics struct { + // inboundPacketsTotal is the total number of inbound packets received, + // labeled by the path the packet took. + inboundPacketsIPv4Total expvar.Int + inboundPacketsIPv6Total expvar.Int + inboundPacketsDERPTotal expvar.Int + inboundPacketsPeerRelayIPv4Total expvar.Int + inboundPacketsPeerRelayIPv6Total expvar.Int + + // inboundBytesTotal is the total number of inbound bytes received, + // labeled by the path the packet took. + inboundBytesIPv4Total expvar.Int + inboundBytesIPv6Total expvar.Int + inboundBytesDERPTotal expvar.Int + inboundBytesPeerRelayIPv4Total expvar.Int + inboundBytesPeerRelayIPv6Total expvar.Int + + // outboundPacketsTotal is the total number of outbound packets sent, + // labeled by the path the packet took. + outboundPacketsIPv4Total expvar.Int + outboundPacketsIPv6Total expvar.Int + outboundPacketsDERPTotal expvar.Int + outboundPacketsPeerRelayIPv4Total expvar.Int + outboundPacketsPeerRelayIPv6Total expvar.Int + + // outboundBytesTotal is the total number of outbound bytes sent, + // labeled by the path the packet took. + outboundBytesIPv4Total expvar.Int + outboundBytesIPv6Total expvar.Int + outboundBytesDERPTotal expvar.Int + outboundBytesPeerRelayIPv4Total expvar.Int + outboundBytesPeerRelayIPv6Total expvar.Int + + // outboundPacketsDroppedErrors is the total number of outbound packets + // dropped due to errors. + outboundPacketsDroppedErrors expvar.Int +} + // A Conn routes UDP packets and actively manages a list of its endpoints. type Conn struct { // This block mirrors the contents and field order of the Options // struct. Initialized once at construction, then constant. + eventBus *eventbus.Bus + eventClient *eventbus.Client logf logger.Logf epFunc func([]tailcfg.Endpoint) derpActiveFunc func() @@ -104,6 +176,12 @@ type Conn struct { connCtxCancel func() // closes connCtx donec <-chan struct{} // connCtx.Done()'s to avoid context.cancelCtx.Done()'s mutex per call + // A publisher for synchronization points to ensure correct ordering of + // config changes between magicsock and wireguard. + syncPub *eventbus.Publisher[syncPoint] + allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] + portUpdatePub *eventbus.Publisher[router.PortUpdate] + // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock // protocols. @@ -124,7 +202,8 @@ type Conn struct { // portMapper is the NAT-PMP/PCP/UPnP prober/client, for requesting // port mappings from NAT devices. - portMapper *portmapper.Client + // If nil, the portmapper is disabled. + portMapper portmappertype.Client // derpRecvCh is used by receiveDERP to read DERP messages. // It must have buffer size > 0; see issue 3736. @@ -182,26 +261,26 @@ type Conn struct { //lint:ignore U1000 used on Linux/Darwin only peerMTUEnabled atomic.Bool - // stats maintains per-connection counters. - stats atomic.Pointer[connstats.Statistics] + // connCounter maintains per-connection counters. + connCounter syncs.AtomicValue[netlogfunc.ConnectionCounter] // captureHook, if non-nil, is the pcap logging callback when capturing. - captureHook syncs.AtomicValue[capture.Callback] - - // discoPrivate is the private naclbox key used for active - // discovery traffic. It is always present, and immutable. - discoPrivate key.DiscoPrivate - // public of discoPrivate. It is always present and immutable. - discoPublic key.DiscoPublic - // ShortString of discoPublic (to save logging work later). It is always - // present and immutable. - discoShort string + captureHook syncs.AtomicValue[packet.CaptureCallback] + + // hasPeerRelayServers is whether [relayManager] is configured with at least + // one peer relay server via [relayManager.handleRelayServersSet]. It exists + // to suppress calls into [relayManager] leading to wasted work involving + // channel operations and goroutine creation. + hasPeerRelayServers atomic.Bool + + // discoAtomic is the current disco private and public keypair for this conn. + discoAtomic discoAtomic // ============================================================ // mu guards all following fields; see userspaceEngine lock // ordering rules against the engine. For derphttp, mu must // be held before derphttp.Client.mu. - mu sync.Mutex + mu syncs.Mutex muCond *sync.Cond onlyTCP443 atomic.Bool @@ -258,7 +337,11 @@ type Conn struct { // by node key, node ID, and discovery key. peerMap peerMap - // discoInfo is the state for an active DiscoKey. + // relayManager manages allocation and handshaking of + // [tailscale.com/net/udprelay.Server] endpoints. + relayManager relayManager + + // discoInfo is the state for an active peer DiscoKey. discoInfo map[key.DiscoPublic]*discoInfo // netInfoFunc is a callback that provides a tailcfg.NetInfo when @@ -276,17 +359,19 @@ type Conn struct { // magicsock could do with any complexity reduction it can get. netInfoLast *tailcfg.NetInfo - derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled - peers views.Slice[tailcfg.NodeView] // from last SetNetworkMap update - lastFlags debugFlags // at time of last SetNetworkMap - firstAddrForTest netip.Addr // from last SetNetworkMap update; for tests only - privateKey key.NodePrivate // WireGuard private key for this node - everHadKey bool // whether we ever had a non-zero private key - myDerp int // nearest DERP region ID; 0 means none/unknown - homeless bool // if true, don't try to find & stay conneted to a DERP home (myDerp will stay 0) - derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close - activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region - prevDerp map[int]*syncs.WaitGroupChan + derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled + self tailcfg.NodeView // from last onNodeViewsUpdate + peers views.Slice[tailcfg.NodeView] // from last onNodeViewsUpdate, sorted by Node.ID; Note: [netmap.NodeMutation]'s rx'd in onNodeMutationsUpdate are never applied + filt *filter.Filter // from last onFilterUpdate + relayClientEnabled bool // whether we can allocate UDP relay endpoints on UDP relay servers or receive CallMeMaybeVia messages from peers + lastFlags debugFlags // at time of last onNodeViewsUpdate + privateKey key.NodePrivate // WireGuard private key for this node + everHadKey bool // whether we ever had a non-zero private key + myDerp int // nearest DERP region ID; 0 means none/unknown + homeless bool // if true, don't try to find & stay conneted to a DERP home (myDerp will stay 0) + derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close + activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region + prevDerp map[int]*syncs.WaitGroupChan // derpRoute contains optional alternate routes to use as an // optimization instead of contacting a peer via their home @@ -304,23 +389,22 @@ type Conn struct { // wgPinger is the WireGuard only pinger used for latency measurements. wgPinger lazy.SyncValue[*ping.Pinger] - // onPortUpdate is called with the new port when magicsock rebinds to - // a new port. - onPortUpdate func(port uint16, network string) - // getPeerByKey optionally specifies a function to look up a peer's // wireguard state by its public key. If nil, it's not used. getPeerByKey func(key.NodePublic) (_ wgint.Peer, ok bool) - // lastEPERMRebind tracks the last time a rebind was performed - // after experiencing a syscall.EPERM. - lastEPERMRebind syncs.AtomicValue[time.Time] + // lastErrRebind tracks the last time a rebind was performed after + // experiencing a write error, and is used to throttle the rate of rebinds. + lastErrRebind syncs.AtomicValue[time.Time] // staticEndpoints are user set endpoints that this node should // advertise amongst its wireguard endpoints. It is user's // responsibility to ensure that traffic from these endpoints is routed // to the node. staticEndpoints views.Slice[netip.AddrPort] + + // metrics contains the metrics for the magicsock instance. + metrics *metrics } // SetDebugLoggingEnabled controls whether spammy debug logging is enabled. @@ -343,8 +427,13 @@ func (c *Conn) dlogf(format string, a ...any) { // Options contains options for Listen. type Options struct { - // Logf optionally provides a log function to use. - // Must not be nil. + // EventBus, if non-nil, is used for event publication and subscription by + // each Conn created from these Options. It must not be nil outside of + // tests. + EventBus *eventbus.Bus + + // Logf provides a log function to use. It must not be nil. + // Use [logger.Discard] to disrcard logs. Logf logger.Logf // Port is the port to listen on. @@ -370,7 +459,8 @@ type Options struct { // NoteRecvActivity, if provided, is a func for magicsock to call // whenever it receives a packet from a a peer if it's been more // than ~10 seconds since the last one. (10 seconds is somewhat - // arbitrary; the sole user just doesn't need or want it called on + // arbitrary; the sole user, lazy WireGuard configuration, + // just doesn't need or want it called on // every packet, just every minute or two for WireGuard timeouts, // and 10 seconds seems like a good trade-off between often enough // and not too often.) @@ -394,10 +484,6 @@ type Options struct { // If nil, they're ignored and not updated. ControlKnobs *controlknobs.Knobs - // OnPortUpdate is called with the new port when magicsock rebinds to - // a new port. - OnPortUpdate func(port uint16, network string) - // PeerByKeyFunc optionally specifies a function to look up a peer's // WireGuard state by its public key. If nil, it's not used. // In regular use, this will be wgengine.(*userspaceEngine).PeerByKey. @@ -429,6 +515,77 @@ func (o *Options) derpActiveFunc() func() { return o.DERPActiveFunc } +// NodeViewsUpdate represents an update event of [tailcfg.NodeView] for all +// nodes. This event is published over an [eventbus.Bus]. It may be published +// with an invalid SelfNode, and/or zero/nil Peers. [magicsock.Conn] is the sole +// subscriber as of 2025-06. If you are adding more subscribers consider moving +// this type out of magicsock. +type NodeViewsUpdate struct { + SelfNode tailcfg.NodeView + Peers []tailcfg.NodeView // sorted by Node.ID +} + +// NodeMutationsUpdate represents an update event of one or more +// [netmap.NodeMutation]. This event is published over an [eventbus.Bus]. +// [magicsock.Conn] is the sole subscriber as of 2025-06. If you are adding more +// subscribers consider moving this type out of magicsock. +type NodeMutationsUpdate struct { + Mutations []netmap.NodeMutation +} + +// FilterUpdate represents an update event for a [*filter.Filter]. This event is +// signaled over an [eventbus.Bus]. [magicsock.Conn] is the sole subscriber as +// of 2025-06. If you are adding more subscribers consider moving this type out +// of magicsock. +type FilterUpdate struct { + *filter.Filter +} + +// syncPoint is an event published over an [eventbus.Bus] by [Conn.Synchronize]. +// It serves as a synchronization point, allowing to wait until magicsock +// has processed all pending events. +type syncPoint chan struct{} + +// Wait blocks until [syncPoint.Signal] is called. +func (s syncPoint) Wait() { + <-s +} + +// Signal signals the sync point, unblocking the [syncPoint.Wait] call. +func (s syncPoint) Signal() { + close(s) +} + +// UDPRelayAllocReq represents a [*disco.AllocateUDPRelayEndpointRequest] +// reception event. This is signaled over an [eventbus.Bus] from +// [magicsock.Conn] towards [relayserver.extension]. +type UDPRelayAllocReq struct { + // RxFromNodeKey is the unauthenticated (DERP server claimed src) node key + // of the transmitting party, noted at disco message reception time over + // DERP. This node key is unambiguously-aligned with RxFromDiscoKey being + // that the disco message is received over DERP. + RxFromNodeKey key.NodePublic + // RxFromDiscoKey is the disco key of the transmitting party, noted and + // authenticated at reception time. + RxFromDiscoKey key.DiscoPublic + // Message is the disco message. + Message *disco.AllocateUDPRelayEndpointRequest +} + +// UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse] +// that is yet to be transmitted over DERP (or delivered locally if +// ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from +// [relayserver.extension] towards [magicsock.Conn]. +type UDPRelayAllocResp struct { + // ReqRxFromNodeKey is copied from [UDPRelayAllocReq.RxFromNodeKey]. It + // enables peer lookup leading up to transmission over DERP. + ReqRxFromNodeKey key.NodePublic + // ReqRxFromDiscoKey is copied from [UDPRelayAllocReq.RxFromDiscoKey]. + ReqRxFromDiscoKey key.DiscoPublic + // Message is the disco message. + Message *disco.AllocateUDPRelayEndpointResponse +} + // newConn is the error-free, network-listening-side-effect-free based // of NewConn. Mostly for tests. func newConn(logf logger.Logf) *Conn { @@ -440,17 +597,15 @@ func newConn(logf logger.Logf) *Conn { peerLastDerp: make(map[key.NodePublic]int), peerMap: newPeerMap(), discoInfo: make(map[key.DiscoPublic]*discoInfo), - discoPrivate: discoPrivate, - discoPublic: discoPrivate.Public(), cloudInfo: newCloudInfo(logf), } - c.discoShort = c.discoPublic.ShortString() + c.discoAtomic.Set(discoPrivate) c.bind = &connBind{Conn: c, closed: true} c.receiveBatchPool = sync.Pool{New: func() any { msgs := make([]ipv6.Message, c.bind.BatchSize()) for i := range msgs { msgs[i].Buffers = make([][]byte, 1) - msgs[i].OOB = make([]byte, controlMessageSize) + msgs[i].OOB = make([]byte, batching.MinControlMessageSize()) } batch := &receiveBatch{ msgs: msgs, @@ -462,15 +617,65 @@ func newConn(logf logger.Logf) *Conn { return c } +func (c *Conn) onUDPRelayAllocResp(allocResp UDPRelayAllocResp) { + c.mu.Lock() + defer c.mu.Unlock() + ep, ok := c.peerMap.endpointForNodeKey(allocResp.ReqRxFromNodeKey) + if !ok { + // If it's not a peer, it might be for self (we can peer relay through + // ourselves), in which case we want to hand it down to [relayManager] + // now versus taking a network round-trip through DERP. + selfNodeKey := c.publicKeyAtomic.Load() + if selfNodeKey.Compare(allocResp.ReqRxFromNodeKey) == 0 && + allocResp.ReqRxFromDiscoKey.Compare(c.discoAtomic.Public()) == 0 { + c.relayManager.handleRxDiscoMsg(c, allocResp.Message, selfNodeKey, allocResp.ReqRxFromDiscoKey, epAddr{}) + metricLocalDiscoAllocUDPRelayEndpointResponse.Add(1) + } + return + } + disco := ep.disco.Load() + if disco == nil { + return + } + if disco.key.Compare(allocResp.ReqRxFromDiscoKey) != 0 { + return + } + ep.mu.Lock() + defer ep.mu.Unlock() + derpAddr := ep.derpAddr + if derpAddr.IsValid() { + go c.sendDiscoMessage(epAddr{ap: derpAddr}, ep.publicKey, disco.key, allocResp.Message, discoVerboseLog) + } +} + +// Synchronize waits for all [eventbus] events published +// prior to this call to be processed by the receiver. +func (c *Conn) Synchronize() { + if c.syncPub == nil { + // Eventbus is not used; no need to synchronize (in certain tests). + return + } + sp := syncPoint(make(chan struct{})) + c.syncPub.Publish(sp) + select { + case <-sp: + case <-c.donec: + } +} + // NewConn creates a magic Conn listening on opts.Port. // As the set of possible endpoints for a Conn changes, the // callback opts.EndpointsFunc is called. func NewConn(opts Options) (*Conn, error) { - if opts.NetMon == nil { + switch { + case opts.NetMon == nil: return nil, errors.New("magicsock.Options.NetMon must be non-nil") + case opts.EventBus == nil: + return nil, errors.New("magicsock.Options.EventBus must be non-nil") } c := newConn(opts.logf()) + c.eventBus = opts.EventBus c.port.Store(uint32(opts.Port)) c.controlKnobs = opts.ControlKnobs c.epFunc = opts.endpointsFunc() @@ -478,22 +683,56 @@ func NewConn(opts Options) (*Conn, error) { c.idleFunc = opts.IdleFunc c.testOnlyPacketListener = opts.TestOnlyPacketListener c.noteRecvActivity = opts.NoteRecvActivity - portMapOpts := &portmapper.DebugKnobs{ - DisableAll: func() bool { return opts.DisablePortMapper || c.onlyTCP443.Load() }, + + // Set up publishers and subscribers. Subscribe calls must return before + // NewConn otherwise published events can be missed. + ec := c.eventBus.Client("magicsock.Conn") + c.eventClient = ec + c.syncPub = eventbus.Publish[syncPoint](ec) + c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec) + c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec) + eventbus.SubscribeFunc(ec, c.onPortMapChanged) + eventbus.SubscribeFunc(ec, c.onFilterUpdate) + eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate) + eventbus.SubscribeFunc(ec, c.onNodeMutationsUpdate) + eventbus.SubscribeFunc(ec, func(sp syncPoint) { + c.dlogf("magicsock: received sync point after reconfig") + sp.Signal() + }) + eventbus.SubscribeFunc(ec, c.onUDPRelayAllocResp) + + c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) + c.donec = c.connCtx.Done() + + // Don't log the same log messages possibly every few seconds in our + // portmapper. + if buildfeatures.HasPortMapper && !opts.DisablePortMapper { + portmapperLogf := logger.WithPrefix(c.logf, "portmapper: ") + portmapperLogf = netmon.LinkChangeLogLimiter(c.connCtx, portmapperLogf, opts.NetMon) + var disableUPnP func() bool + if c.controlKnobs != nil { + disableUPnP = c.controlKnobs.DisableUPnP.Load + } + newPortMapper, ok := portmappertype.HookNewPortMapper.GetOk() + if ok { + c.portMapper = newPortMapper(portmapperLogf, opts.EventBus, opts.NetMon, disableUPnP, c.onlyTCP443.Load) + } + // If !ok, the HookNewPortMapper hook is not set (so feature/portmapper + // isn't linked), but the build tag to explicitly omit the portmapper + // isn't set either. This should only happen to js/wasm builds, where + // the portmapper is a no-op even if linked (but it's no longer linked, + // since the move to feature/portmapper), or if people are wiring up + // their own Tailscale build from pieces. } - c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), opts.NetMon, portMapOpts, opts.ControlKnobs, c.onPortMapChanged) - c.portMapper.SetGatewayLookupFunc(opts.NetMon.GatewayAndSelfIP) + c.netMon = opts.NetMon c.health = opts.HealthTracker - c.onPortUpdate = opts.OnPortUpdate c.getPeerByKey = opts.PeerByKeyFunc if err := c.rebind(keepCurrentPort); err != nil { return nil, err } - c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) - c.donec = c.connCtx.Done() c.netChecker = &netcheck.Client{ Logf: logger.WithPrefix(c.logf, "netcheck: "), NetMon: c.netMon, @@ -503,6 +742,8 @@ func NewConn(opts Options) (*Conn, error) { UseDNSCache: true, } + c.metrics = registerMetrics(opts.Metrics) + if d4, err := c.listenRawDisco("ip4"); err == nil { c.logf("[v1] using BPF disco receiver for IPv4") c.closeDisco4 = d4 @@ -516,15 +757,139 @@ func NewConn(opts Options) (*Conn, error) { c.logf("[v1] couldn't create raw v6 disco listener, using regular listener instead: %v", err) } - c.logf("magicsock: disco key = %v", c.discoShort) + c.logf("magicsock: disco key = %v", c.discoAtomic.Short()) return c, nil } +// registerMetrics wires up the metrics for wgengine, instead of +// registering the label metric directly, the underlying expvar is exposed. +// See metrics for more info. +func registerMetrics(reg *usermetric.Registry) *metrics { + pathDirectV4 := pathLabel{Path: PathDirectIPv4} + pathDirectV6 := pathLabel{Path: PathDirectIPv6} + pathDERP := pathLabel{Path: PathDERP} + pathPeerRelayV4 := pathLabel{Path: PathPeerRelayIPv4} + pathPeerRelayV6 := pathLabel{Path: PathPeerRelayIPv6} + inboundPacketsTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_inbound_packets_total", + "counter", + "Counts the number of packets received from other peers", + ) + inboundBytesTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_inbound_bytes_total", + "counter", + "Counts the number of bytes received from other peers", + ) + outboundPacketsTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_outbound_packets_total", + "counter", + "Counts the number of packets sent to other peers", + ) + outboundBytesTotal := usermetric.NewMultiLabelMapWithRegistry[pathLabel]( + reg, + "tailscaled_outbound_bytes_total", + "counter", + "Counts the number of bytes sent to other peers", + ) + outboundPacketsDroppedErrors := reg.DroppedPacketsOutbound() + + m := new(metrics) + + // Map clientmetrics to the usermetric counters. + metricRecvDataPacketsIPv4.Register(&m.inboundPacketsIPv4Total) + metricRecvDataPacketsIPv6.Register(&m.inboundPacketsIPv6Total) + metricRecvDataPacketsDERP.Register(&m.inboundPacketsDERPTotal) + metricRecvDataPacketsPeerRelayIPv4.Register(&m.inboundPacketsPeerRelayIPv4Total) + metricRecvDataPacketsPeerRelayIPv6.Register(&m.inboundPacketsPeerRelayIPv6Total) + metricRecvDataBytesIPv4.Register(&m.inboundBytesIPv4Total) + metricRecvDataBytesIPv6.Register(&m.inboundBytesIPv6Total) + metricRecvDataBytesDERP.Register(&m.inboundBytesDERPTotal) + metricRecvDataBytesPeerRelayIPv4.Register(&m.inboundBytesPeerRelayIPv4Total) + metricRecvDataBytesPeerRelayIPv6.Register(&m.inboundBytesPeerRelayIPv6Total) + metricSendDataPacketsIPv4.Register(&m.outboundPacketsIPv4Total) + metricSendDataPacketsIPv6.Register(&m.outboundPacketsIPv6Total) + metricSendDataPacketsDERP.Register(&m.outboundPacketsDERPTotal) + metricSendDataPacketsPeerRelayIPv4.Register(&m.outboundPacketsPeerRelayIPv4Total) + metricSendDataPacketsPeerRelayIPv6.Register(&m.outboundPacketsPeerRelayIPv6Total) + metricSendDataBytesIPv4.Register(&m.outboundBytesIPv4Total) + metricSendDataBytesIPv6.Register(&m.outboundBytesIPv6Total) + metricSendDataBytesDERP.Register(&m.outboundBytesDERPTotal) + metricSendDataBytesPeerRelayIPv4.Register(&m.outboundBytesPeerRelayIPv4Total) + metricSendDataBytesPeerRelayIPv6.Register(&m.outboundBytesPeerRelayIPv6Total) + metricSendUDP.Register(&m.outboundPacketsIPv4Total) + metricSendUDP.Register(&m.outboundPacketsIPv6Total) + metricSendDERP.Register(&m.outboundPacketsDERPTotal) + metricSendPeerRelay.Register(&m.outboundPacketsPeerRelayIPv4Total) + metricSendPeerRelay.Register(&m.outboundPacketsPeerRelayIPv6Total) + + inboundPacketsTotal.Set(pathDirectV4, &m.inboundPacketsIPv4Total) + inboundPacketsTotal.Set(pathDirectV6, &m.inboundPacketsIPv6Total) + inboundPacketsTotal.Set(pathDERP, &m.inboundPacketsDERPTotal) + inboundPacketsTotal.Set(pathPeerRelayV4, &m.inboundPacketsPeerRelayIPv4Total) + inboundPacketsTotal.Set(pathPeerRelayV6, &m.inboundPacketsPeerRelayIPv6Total) + + inboundBytesTotal.Set(pathDirectV4, &m.inboundBytesIPv4Total) + inboundBytesTotal.Set(pathDirectV6, &m.inboundBytesIPv6Total) + inboundBytesTotal.Set(pathDERP, &m.inboundBytesDERPTotal) + inboundBytesTotal.Set(pathPeerRelayV4, &m.inboundBytesPeerRelayIPv4Total) + inboundBytesTotal.Set(pathPeerRelayV6, &m.inboundBytesPeerRelayIPv6Total) + + outboundPacketsTotal.Set(pathDirectV4, &m.outboundPacketsIPv4Total) + outboundPacketsTotal.Set(pathDirectV6, &m.outboundPacketsIPv6Total) + outboundPacketsTotal.Set(pathDERP, &m.outboundPacketsDERPTotal) + outboundPacketsTotal.Set(pathPeerRelayV4, &m.outboundPacketsPeerRelayIPv4Total) + outboundPacketsTotal.Set(pathPeerRelayV6, &m.outboundPacketsPeerRelayIPv6Total) + + outboundBytesTotal.Set(pathDirectV4, &m.outboundBytesIPv4Total) + outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total) + outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal) + outboundBytesTotal.Set(pathPeerRelayV4, &m.outboundBytesPeerRelayIPv4Total) + outboundBytesTotal.Set(pathPeerRelayV6, &m.outboundBytesPeerRelayIPv6Total) + + outboundPacketsDroppedErrors.Set(usermetric.DropLabels{Reason: usermetric.ReasonError}, &m.outboundPacketsDroppedErrors) + + return m +} + +// deregisterMetrics unregisters the underlying usermetrics expvar counters +// from clientmetrics. +func deregisterMetrics() { + metricRecvDataPacketsIPv4.UnregisterAll() + metricRecvDataPacketsIPv6.UnregisterAll() + metricRecvDataPacketsDERP.UnregisterAll() + metricRecvDataPacketsPeerRelayIPv4.UnregisterAll() + metricRecvDataPacketsPeerRelayIPv6.UnregisterAll() + metricRecvDataBytesIPv4.UnregisterAll() + metricRecvDataBytesIPv6.UnregisterAll() + metricRecvDataBytesDERP.UnregisterAll() + metricRecvDataBytesPeerRelayIPv4.UnregisterAll() + metricRecvDataBytesPeerRelayIPv6.UnregisterAll() + metricSendDataPacketsIPv4.UnregisterAll() + metricSendDataPacketsIPv6.UnregisterAll() + metricSendDataPacketsDERP.UnregisterAll() + metricSendDataPacketsPeerRelayIPv4.UnregisterAll() + metricSendDataPacketsPeerRelayIPv6.UnregisterAll() + metricSendDataBytesIPv4.UnregisterAll() + metricSendDataBytesIPv6.UnregisterAll() + metricSendDataBytesDERP.UnregisterAll() + metricSendDataBytesPeerRelayIPv4.UnregisterAll() + metricSendDataBytesPeerRelayIPv6.UnregisterAll() + metricSendUDP.UnregisterAll() + metricSendDERP.UnregisterAll() + metricSendPeerRelay.UnregisterAll() +} + // InstallCaptureHook installs a callback which is called to // log debug information into the pcap stream. This function // can be called with a nil argument to uninstall the capture // hook. -func (c *Conn) InstallCaptureHook(cb capture.Callback) { +func (c *Conn) InstallCaptureHook(cb packet.CaptureCallback) { + if !buildfeatures.HasCapture { + return + } c.captureHook.Store(cb) } @@ -580,7 +945,7 @@ func (c *Conn) updateEndpoints(why string) { c.muCond.Broadcast() }() c.dlogf("[v1] magicsock: starting endpoint update (%s)", why) - if c.noV4Send.Load() && runtime.GOOS != "js" && !c.onlyTCP443.Load() { + if c.noV4Send.Load() && runtime.GOOS != "js" && !c.onlyTCP443.Load() && !hostinfo.IsInVM86() { c.mu.Lock() closed := c.closed c.mu.Unlock() @@ -650,6 +1015,7 @@ func (c *Conn) setEndpoints(endpoints []tailcfg.Endpoint) (changed bool) { func (c *Conn) SetStaticEndpoints(ep views.Slice[netip.AddrPort]) { c.mu.Lock() if reflect.DeepEqual(c.staticEndpoints.AsSlice(), ep.AsSlice()) { + c.mu.Unlock() return } c.staticEndpoints = ep @@ -713,7 +1079,9 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { UPnP: report.UPnP, PMP: report.PMP, PCP: report.PCP, - HavePortMap: c.portMapper.HaveMapping(), + } + if c.portMapper != nil { + ni.HavePortMap = c.portMapper.HaveMapping() } for rid, d := range report.RegionV4Latency { ni.DERPLatency[fmt.Sprintf("%d-v4", rid)] = d.Seconds() @@ -763,7 +1131,7 @@ func (c *Conn) callNetInfoCallbackLocked(ni *tailcfg.NetInfo) { func (c *Conn) addValidDiscoPathForTest(nodeKey key.NodePublic, addr netip.AddrPort) { c.mu.Lock() defer c.mu.Unlock() - c.peerMap.setNodeKeyForIPPort(addr, nodeKey) + c.peerMap.setNodeKeyForEpAddr(epAddr{ap: addr}, nodeKey) } // SetNetInfoCallback sets the func to be called whenever the network conditions @@ -832,13 +1200,17 @@ func (c *Conn) Ping(peer tailcfg.NodeView, res *ipnstate.PingResult, size int, c } // c.mu must be held -func (c *Conn) populateCLIPingResponseLocked(res *ipnstate.PingResult, latency time.Duration, ep netip.AddrPort) { +func (c *Conn) populateCLIPingResponseLocked(res *ipnstate.PingResult, latency time.Duration, ep epAddr) { res.LatencySeconds = latency.Seconds() - if ep.Addr() != tailcfg.DerpMagicIPAddr { - res.Endpoint = ep.String() + if ep.ap.Addr() != tailcfg.DerpMagicIPAddr { + if ep.vni.IsSet() { + res.PeerRelay = ep.String() + } else { + res.Endpoint = ep.String() + } return } - regionID := int(ep.Port()) + regionID := int(ep.ap.Port()) res.DERPRegionID = regionID res.DERPRegionCode = c.derpRegionCodeLocked(regionID) } @@ -864,19 +1236,44 @@ func (c *Conn) GetEndpointChanges(peer tailcfg.NodeView) ([]EndpointChange, erro // DiscoPublicKey returns the discovery public key. func (c *Conn) DiscoPublicKey() key.DiscoPublic { - return c.discoPublic + return c.discoAtomic.Public() +} + +// RotateDiscoKey generates a new discovery key pair and updates the connection +// to use it. This invalidates all existing disco sessions and will cause peers +// to re-establish discovery sessions with the new key. +// +// This is primarily for debugging and testing purposes, a future enhancement +// should provide a mechanism for seamless rotation by supporting short term use +// of the old key. +func (c *Conn) RotateDiscoKey() { + oldShort := c.discoAtomic.Short() + newPrivate := key.NewDisco() + + c.mu.Lock() + c.discoAtomic.Set(newPrivate) + newShort := c.discoAtomic.Short() + c.discoInfo = make(map[key.DiscoPublic]*discoInfo) + connCtx := c.connCtx + c.mu.Unlock() + + c.logf("magicsock: rotated disco key from %v to %v", oldShort, newShort) + + if connCtx != nil { + c.ReSTUN("disco-key-rotation") + } } // determineEndpoints returns the machine's endpoint addresses. It does a STUN -// lookup (via netcheck) to determine its public address. Additionally any -// static enpoints provided by user are always added to the returned endpoints +// lookup (via netcheck) to determine its public address. Additionally, any +// static endpoints provided by user are always added to the returned endpoints // without validating if the node can be reached via those endpoints. // // c.mu must NOT be held. func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) { var havePortmap bool var portmapExt netip.AddrPort - if runtime.GOOS != "js" { + if runtime.GOOS != "js" && c.portMapper != nil { portmapExt, havePortmap = c.portMapper.GetCachedMappingOrStartCreatingOne() } @@ -916,7 +1313,7 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro } // If we didn't have a portmap earlier, maybe it's done by now. - if !havePortmap { + if !havePortmap && c.portMapper != nil { portmapExt, havePortmap = c.portMapper.GetCachedMappingOrStartCreatingOne() } if havePortmap { @@ -988,8 +1385,8 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro // re-run. eps = c.endpointTracker.update(time.Now(), eps) - for i := range c.staticEndpoints.Len() { - addAddr(c.staticEndpoints.At(i), tailcfg.EndpointExplicitConf) + for _, ep := range c.staticEndpoints.All() { + addAddr(ep, tailcfg.EndpointExplicitConf) } if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() { @@ -1077,20 +1474,31 @@ func (c *Conn) networkDown() bool { return !c.networkUp.Load() } // Send implements conn.Bind. // -// See https://pkg.go.dev/golang.zx2c4.com/wireguard/conn#Bind.Send -func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) error { +// See https://pkg.go.dev/github.com/tailscale/wireguard-go/conn#Bind.Send +func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint, offset int) (err error) { n := int64(len(buffs)) + defer func() { + if err != nil { + c.metrics.outboundPacketsDroppedErrors.Add(n) + } + }() metricSendData.Add(n) if c.networkDown() { metricSendDataNetworkDown.Add(n) return errNetworkDown } - if ep, ok := ep.(*endpoint); ok { - return ep.send(buffs) + switch ep := ep.(type) { + case *endpoint: + return ep.send(buffs, offset) + case *lazyEndpoint: + // A [*lazyEndpoint] may end up on this TX codepath when wireguard-go is + // deemed "under handshake load" and ends up transmitting a cookie reply + // using the received [conn.Endpoint] in [device.SendHandshakeCookie]. + if ep.src.ap.Addr().Is6() { + return c.pconn6.WriteWireGuardBatchTo(buffs, ep.src, offset) + } + return c.pconn4.WriteWireGuardBatchTo(buffs, ep.src, offset) } - // If it's not of type *endpoint, it's probably *lazyEndpoint, which means - // we don't actually know who the peer is and we're waiting for wireguard-go - // to switch the endpoint. See go/corp/20732. return nil } @@ -1102,19 +1510,19 @@ var errNoUDP = errors.New("no UDP available on platform") var errUnsupportedConnType = errors.New("unsupported connection type") -func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err error) { +func (c *Conn) sendUDPBatch(addr epAddr, buffs [][]byte, offset int) (sent bool, err error) { isIPv6 := false switch { - case addr.Addr().Is4(): - case addr.Addr().Is6(): + case addr.ap.Addr().Is4(): + case addr.ap.Addr().Is6(): isIPv6 = true default: panic("bogus sendUDPBatch addr type") } if isIPv6 { - err = c.pconn6.WriteBatchTo(buffs, addr) + err = c.pconn6.WriteWireGuardBatchTo(buffs, addr, offset) } else { - err = c.pconn4.WriteBatchTo(buffs, addr) + err = c.pconn4.WriteWireGuardBatchTo(buffs, addr, offset) } if err != nil { var errGSO neterror.ErrUDPGSODisabled @@ -1122,7 +1530,7 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err c.logf("magicsock: %s", errGSO.Error()) err = errGSO.RetryErr } else { - _ = c.maybeRebindOnError(runtime.GOOS, err) + c.maybeRebindOnError(err) } } return err == nil, err @@ -1130,48 +1538,59 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err // sendUDP sends UDP packet b to ipp. // See sendAddr's docs on the return value meanings. -func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { +func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool, isGeneveEncap bool) (sent bool, err error) { if runtime.GOOS == "js" { return false, errNoUDP } sent, err = c.sendUDPStd(ipp, b) if err != nil { - metricSendUDPError.Add(1) - _ = c.maybeRebindOnError(runtime.GOOS, err) + if isGeneveEncap { + metricSendPeerRelayError.Add(1) + } else { + metricSendUDPError.Add(1) + } + c.maybeRebindOnError(err) } else { - if sent { - metricSendUDP.Add(1) + if sent && !isDisco { + switch { + case ipp.Addr().Is4(): + if isGeneveEncap { + c.metrics.outboundPacketsPeerRelayIPv4Total.Add(1) + c.metrics.outboundBytesPeerRelayIPv4Total.Add(int64(len(b))) + } else { + c.metrics.outboundPacketsIPv4Total.Add(1) + c.metrics.outboundBytesIPv4Total.Add(int64(len(b))) + } + case ipp.Addr().Is6(): + if isGeneveEncap { + c.metrics.outboundPacketsPeerRelayIPv6Total.Add(1) + c.metrics.outboundBytesPeerRelayIPv6Total.Add(int64(len(b))) + } else { + c.metrics.outboundPacketsIPv6Total.Add(1) + c.metrics.outboundBytesIPv6Total.Add(int64(len(b))) + } + } } } return } -// maybeRebindOnError performs a rebind and restun if the error is defined and -// any conditionals are met. -func (c *Conn) maybeRebindOnError(os string, err error) bool { - switch { - case errors.Is(err, syscall.EPERM): - why := "operation-not-permitted-rebind" - switch os { - // We currently will only rebind and restun on a syscall.EPERM if it is experienced - // on a client running darwin. - // TODO(charlotte, raggi): expand os options if required. - case "darwin": - // TODO(charlotte): implement a backoff, so we don't end up in a rebind loop for persistent - // EPERMs. - if c.lastEPERMRebind.Load().Before(time.Now().Add(-5 * time.Second)) { - c.logf("magicsock: performing %q", why) - c.lastEPERMRebind.Store(time.Now()) - c.Rebind() - go c.ReSTUN(why) - return true - } - default: - c.logf("magicsock: not performing %q", why) - return false - } +// maybeRebindOnError performs a rebind and restun if the error is one that is +// known to be healed by a rebind, and the rebind is not throttled. +func (c *Conn) maybeRebindOnError(err error) { + ok, reason := shouldRebind(err) + if !ok { + return + } + + if c.lastErrRebind.Load().Before(time.Now().Add(-5 * time.Second)) { + c.logf("magicsock: performing rebind due to %q", reason) + c.lastErrRebind.Store(time.Now()) + c.Rebind() + go c.ReSTUN(reason) + } else { + c.logf("magicsock: not performing %q rebind due to throttle", reason) } - return false } // sendUDPNetcheck sends b via UDP to addr. It is used exclusively by netcheck. @@ -1225,9 +1644,9 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) // An example of when they might be different: sending to an // IPv6 address when the local machine doesn't have IPv6 support // returns (false, nil); it's not an error, but nothing was sent. -func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) { +func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool, isGeneveEncap bool) (sent bool, err error) { if addr.Addr() != tailcfg.DerpMagicIPAddr { - return c.sendUDP(addr, b) + return c.sendUDP(addr, b, isDisco, isGeneveEncap) } regionID := int(addr.Port()) @@ -1244,18 +1663,27 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (s // internal locks. pkt := bytes.Clone(b) - select { - case <-c.donec: - metricSendDERPErrorClosed.Add(1) - return false, errConnClosed - case ch <- derpWriteRequest{addr, pubKey, pkt}: - metricSendDERPQueued.Add(1) - return true, nil - default: - metricSendDERPErrorQueue.Add(1) - // Too many writes queued. Drop packet. - return false, errDropDerpPacket + wr := derpWriteRequest{addr, pubKey, pkt, isDisco} + for range 3 { + select { + case <-c.donec: + metricSendDERPErrorClosed.Add(1) + return false, errConnClosed + case ch <- wr: + metricSendDERPQueued.Add(1) + return true, nil + default: + select { + case <-ch: + metricSendDERPDropped.Add(1) + default: + } + } } + // gave up after 3 write attempts + metricSendDERPErrorQueue.Add(1) + // Too many writes queued. Drop packet. + return false, errDropDerpPacket } type receiveBatch struct { @@ -1278,24 +1706,33 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) { c.receiveBatchPool.Put(batch) } -// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4. func (c *Conn) receiveIPv4() conn.ReceiveFunc { - return c.mkReceiveFunc(&c.pconn4, c.health.ReceiveFuncStats(health.ReceiveIPv4), metricRecvDataIPv4) + return c.mkReceiveFunc(&c.pconn4, c.health.ReceiveFuncStats(health.ReceiveIPv4), + &c.metrics.inboundPacketsIPv4Total, + &c.metrics.inboundPacketsPeerRelayIPv4Total, + &c.metrics.inboundBytesIPv4Total, + &c.metrics.inboundBytesPeerRelayIPv4Total, + ) } // receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6. func (c *Conn) receiveIPv6() conn.ReceiveFunc { - return c.mkReceiveFunc(&c.pconn6, c.health.ReceiveFuncStats(health.ReceiveIPv6), metricRecvDataIPv6) + return c.mkReceiveFunc(&c.pconn6, c.health.ReceiveFuncStats(health.ReceiveIPv6), + &c.metrics.inboundPacketsIPv6Total, + &c.metrics.inboundPacketsPeerRelayIPv6Total, + &c.metrics.inboundBytesIPv6Total, + &c.metrics.inboundBytesPeerRelayIPv6Total, + ) } // mkReceiveFunc creates a ReceiveFunc reading from ruc. -// The provided healthItem and metric are updated if non-nil. -func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc { - // epCache caches an IPPort->endpoint for hot flows. - var epCache ippEndpointCache +// The provided healthItem and metrics are updated if non-nil. +func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, directPacketMetric, peerRelayPacketMetric, directBytesMetric, peerRelayBytesMetric *expvar.Int) conn.ReceiveFunc { + // epCache caches an epAddr->endpoint for hot flows. + var epCache epAddrEndpointCache return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (_ int, retErr error) { - if healthItem != nil { + if buildfeatures.HasHealth && healthItem != nil { healthItem.Enter() defer healthItem.Exit() defer func() { @@ -1326,12 +1763,24 @@ func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFu continue } ipp := msg.Addr.(*net.UDPAddr).AddrPort() - if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok { - if metric != nil { - metric.Add(1) + if ep, size, isGeneveEncap, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok { + if isGeneveEncap { + if peerRelayPacketMetric != nil { + peerRelayPacketMetric.Add(1) + } + if peerRelayBytesMetric != nil { + peerRelayBytesMetric.Add(int64(msg.N)) + } + } else { + if directPacketMetric != nil { + directPacketMetric.Add(1) + } + if directBytesMetric != nil { + directBytesMetric.Add(int64(msg.N)) + } } eps[i] = ep - sizes[i] = msg.N + sizes[i] = size reportToCaller = true } else { sizes[i] = 0 @@ -1344,49 +1793,114 @@ func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFu } } +// looksLikeInitiationMsg returns true if b looks like a WireGuard initiation +// message, otherwise it returns false. +func looksLikeInitiationMsg(b []byte) bool { + return len(b) == device.MessageInitiationSize && + binary.LittleEndian.Uint32(b) == device.MessageInitiationType +} + // receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // +// size is the length of 'b' to report up to wireguard-go (only relevant if +// 'ok' is true). +// +// isGeneveEncap is whether 'b' is encapsulated by a Geneve header (only +// relevant if 'ok' is true). +// // ok is whether this read should be reported up to wireguard-go (our // caller). -func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (_ conn.Endpoint, ok bool) { +func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *epAddrEndpointCache) (_ conn.Endpoint, size int, isGeneveEncap bool, ok bool) { var ep *endpoint - if stun.Is(b) { - c.netChecker.ReceiveSTUNPacket(b, ipp) - return nil, false + size = len(b) + + var geneve packet.GeneveHeader + pt, isGeneveEncap := packetLooksLike(b) + src := epAddr{ap: ipp} + if isGeneveEncap { + err := geneve.Decode(b) + if err != nil { + // Decode only returns an error when 'b' is too short, and + // 'isGeneveEncap' indicates it's a sufficient length. + c.logf("[unexpected] geneve header decoding error: %v", err) + return nil, 0, false, false + } + src.vni = geneve.VNI } - if c.handleDiscoMessage(b, ipp, key.NodePublic{}, discoRXPathUDP) { - return nil, false + switch pt { + case packetLooksLikeDisco: + if isGeneveEncap { + b = b[packet.GeneveFixedHeaderLength:] + } + // The Geneve header control bit should only be set for relay handshake + // messages terminating on or originating from a UDP relay server. We + // have yet to open the encrypted disco payload to determine the + // [disco.MessageType], but we assert it should be handshake-related. + shouldByRelayHandshakeMsg := geneve.Control == true + c.handleDiscoMessage(b, src, shouldByRelayHandshakeMsg, key.NodePublic{}, discoRXPathUDP) + return nil, 0, false, false + case packetLooksLikeSTUNBinding: + c.netChecker.ReceiveSTUNPacket(b, ipp) + return nil, 0, false, false + default: + // Fall through for all other packet types as they are assumed to + // be potentially WireGuard. } + if !c.havePrivateKey.Load() { // If we have no private key, we're logged out or // stopped. Don't try to pass these wireguard packets // up to wireguard-go; it'll just complain (issue 1167). - return nil, false + return nil, 0, false, false + } + + // geneveInclusivePacketLen holds the packet length prior to any potential + // Geneve header stripping. + geneveInclusivePacketLen := len(b) + if src.vni.IsSet() { + // Strip away the Geneve header before returning the packet to + // wireguard-go. + // + // TODO(jwhited): update [github.com/tailscale/wireguard-go/conn.ReceiveFunc] + // to support returning start offset in order to get rid of this memmove perf + // penalty. + size = copy(b, b[packet.GeneveFixedHeaderLength:]) + b = b[:size] } - if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() { + + if cache.epAddr == src && cache.de != nil && cache.gen == cache.de.numStopAndReset() { ep = cache.de } else { c.mu.Lock() - de, ok := c.peerMap.endpointForIPPort(ipp) + de, ok := c.peerMap.endpointForEpAddr(src) c.mu.Unlock() if !ok { - if c.controlKnobs != nil && c.controlKnobs.DisableCryptorouting.Load() { - return nil, false - } - return &lazyEndpoint{c: c, src: ipp}, true + // TODO(jwhited): reuse [lazyEndpoint] across calls to receiveIP() + // for the same batch & [epAddr] src. + return &lazyEndpoint{c: c, src: src}, size, isGeneveEncap, true } - cache.ipp = ipp + cache.epAddr = src cache.de = de cache.gen = de.numStopAndReset() ep = de } now := mono.Now() ep.lastRecvUDPAny.StoreAtomic(now) - ep.noteRecvActivity(ipp, now) - if stats := c.stats.Load(); stats != nil { - stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b)) + connNoted := ep.noteRecvActivity(src, now) + if buildfeatures.HasNetLog { + if update := c.connCounter.Load(); update != nil { + update(0, netip.AddrPortFrom(ep.nodeAddr, 0), ipp, 1, geneveInclusivePacketLen, true) + } + } + if src.vni.IsSet() && (connNoted || looksLikeInitiationMsg(b)) { + // connNoted is periodic, but we also want to verify if the peer is who + // we believe for all initiation messages, otherwise we could get + // unlucky and fail to JIT configure the "correct" peer. + // TODO(jwhited): relax this to include direct connections + // See http://go/corp/29422 & http://go/corp/30042 + return &lazyEndpoint{c: c, maybeEP: ep, src: src}, size, isGeneveEncap, true } - return ep, true + return ep, size, isGeneveEncap, true } // discoLogLevel controls the verbosity of discovery log messages. @@ -1407,28 +1921,88 @@ const ( // speeds. var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DELAY") +// sendDiscoAllocateUDPRelayEndpointRequest is primarily an alias for +// sendDiscoMessage, but it will alternatively send m over the eventbus if dst +// is a DERP IP:port, and dstKey is self. This saves a round-trip through DERP +// when we are attempting to allocate on a self (in-process) peer relay server. +func (c *Conn) sendDiscoAllocateUDPRelayEndpointRequest(dst epAddr, dstKey key.NodePublic, dstDisco key.DiscoPublic, allocReq *disco.AllocateUDPRelayEndpointRequest, logLevel discoLogLevel) (sent bool, err error) { + isDERP := dst.ap.Addr() == tailcfg.DerpMagicIPAddr + selfNodeKey := c.publicKeyAtomic.Load() + if isDERP && dstKey.Compare(selfNodeKey) == 0 { + c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{ + RxFromNodeKey: selfNodeKey, + RxFromDiscoKey: c.discoAtomic.Public(), + Message: allocReq, + }) + metricLocalDiscoAllocUDPRelayEndpointRequest.Add(1) + return true, nil + } + return c.sendDiscoMessage(dst, dstKey, dstDisco, allocReq, logLevel) +} + // sendDiscoMessage sends discovery message m to dstDisco at dst. // -// If dst is a DERP IP:port, then dstKey must be non-zero. +// If dst.ap is a DERP IP:port, then dstKey must be non-zero. +// +// If dst.vni.isSet(), the [disco.Message] will be preceded by a Geneve header +// with the VNI field set to the value returned by vni.get(). // // The dstKey should only be non-zero if the dstDisco key // unambiguously maps to exactly one peer. -func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { - isDERP := dst.Addr() == tailcfg.DerpMagicIPAddr - if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.Addr().Is4() { +func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { + isDERP := dst.ap.Addr() == tailcfg.DerpMagicIPAddr + if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.ap.Addr().Is4() { time.Sleep(debugIPv4DiscoPingPenalty()) } + isRelayHandshakeMsg := false + switch m.(type) { + case *disco.BindUDPRelayEndpoint, *disco.BindUDPRelayEndpointAnswer: + isRelayHandshakeMsg = true + } + c.mu.Lock() if c.closed { c.mu.Unlock() return false, errConnClosed } + var di *discoInfo + switch { + case isRelayHandshakeMsg: + var ok bool + di, ok = c.relayManager.discoInfo(dstDisco) + if !ok { + c.mu.Unlock() + return false, errors.New("unknown relay server") + } + case c.peerMap.knownPeerDiscoKey(dstDisco): + di = c.discoInfoForKnownPeerLocked(dstDisco) + default: + // This is an attempt to send to an unknown peer that is not a relay + // server. This can happen when a call to the current function, which is + // often via a new goroutine, races with applying a change in the + // netmap, e.g. the associated peer(s) for dstDisco goes away. + c.mu.Unlock() + return false, errors.New("unknown peer") + } + c.mu.Unlock() + pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters. + if dst.vni.IsSet() { + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: dst.vni, + Control: isRelayHandshakeMsg, + } + pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...) + err := gh.Encode(pkt) + if err != nil { + return false, err + } + } pkt = append(pkt, disco.Magic...) - pkt = c.discoPublic.AppendTo(pkt) - di := c.discoInfoLocked(dstDisco) - c.mu.Unlock() + pkt = c.discoAtomic.Public().AppendTo(pkt) if isDERP { metricSendDiscoDERP.Add(1) @@ -1438,14 +2012,15 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi box := di.sharedKey.Seal(m.AppendMarshal(nil)) pkt = append(pkt, box...) - sent, err = c.sendAddr(dst, dstKey, pkt) + const isDisco = true + sent, err = c.sendAddr(dst.ap, dstKey, pkt, isDisco, dst.vni.IsSet()) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco()) { node := "?" if !dstKey.IsZero() { node = dstKey.ShortString() } - c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) + c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) } if isDERP { metricSentDiscoDERP.Add(1) @@ -1459,12 +2034,22 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi metricSentDiscoPong.Add(1) case *disco.CallMeMaybe: metricSentDiscoCallMeMaybe.Add(1) + case *disco.CallMeMaybeVia: + metricSentDiscoCallMeMaybeVia.Add(1) + case *disco.BindUDPRelayEndpoint: + metricSentDiscoBindUDPRelayEndpoint.Add(1) + case *disco.BindUDPRelayEndpointAnswer: + metricSentDiscoBindUDPRelayEndpointAnswer.Add(1) + case *disco.AllocateUDPRelayEndpointRequest: + metricSentDiscoAllocUDPRelayEndpointRequest.Add(1) + case *disco.AllocateUDPRelayEndpointResponse: + metricSentDiscoAllocUDPRelayEndpointResponse.Add(1) } } else if err == nil { // Can't send. (e.g. no IPv6 locally) } else { if !c.networkDown() && pmtuShouldLogDiscoTxErr(m, err) { - c.logf("magicsock: disco: failed to send %v to %v: %v", disco.MessageSummary(m), dst, err) + c.logf("magicsock: disco: failed to send %v to %v %s: %v", disco.MessageSummary(m), dst, dstKey.ShortString(), err) } } return sent, err @@ -1478,8 +2063,98 @@ const ( discoRXPathRawSocket discoRXPath = "raw socket" ) -// handleDiscoMessage handles a discovery message and reports whether -// msg was a Tailscale inter-node discovery message. +const discoHeaderLen = len(disco.Magic) + key.DiscoPublicRawLen + +type packetLooksLikeType int + +const ( + packetLooksLikeWireGuard packetLooksLikeType = iota + packetLooksLikeSTUNBinding + packetLooksLikeDisco +) + +// packetLooksLike reports a [packetsLooksLikeType] for 'msg', and whether +// 'msg' is encapsulated by a Geneve header (or naked). +// +// [packetLooksLikeSTUNBinding] is never Geneve-encapsulated. +// +// Naked STUN binding, Naked Disco, Geneve followed by Disco, naked WireGuard, +// and Geneve followed by WireGuard can be confidently distinguished based on +// the following: +// +// 1. STUN binding @ msg[1] (0x01) is sufficiently non-overlapping with the +// Geneve header where the LSB is always 0 (part of 6 "reserved" bits). +// +// 2. STUN binding @ msg[1] (0x01) is sufficiently non-overlapping with naked +// WireGuard, which is always a 0 byte value (WireGuard message type +// occupies msg[0:4], and msg[1:4] are always 0). +// +// 3. STUN binding @ msg[1] (0x01) is sufficiently non-overlapping with the +// second byte of [disco.Magic] (0x53). +// +// 4. [disco.Magic] @ msg[2:4] (0xf09f) is sufficiently non-overlapping with a +// Geneve protocol field value of [packet.GeneveProtocolDisco] or +// [packet.GeneveProtocolWireGuard] . +// +// 5. [disco.Magic] @ msg[0] (0x54) is sufficiently non-overlapping with the +// first byte of a WireGuard packet (0x01-0x04). +// +// 6. [packet.GeneveHeader] with a Geneve protocol field value of +// [packet.GeneveProtocolDisco] or [packet.GeneveProtocolWireGuard] +// (msg[2:4]) is sufficiently non-overlapping with the second 2 bytes of a +// WireGuard packet which are always 0x0000. +func packetLooksLike(msg []byte) (t packetLooksLikeType, isGeneveEncap bool) { + if stun.Is(msg) && + msg[1] == 0x01 { // method binding + return packetLooksLikeSTUNBinding, false + } + + // TODO(jwhited): potentially collapse into disco.LooksLikeDiscoWrapper() + // if safe to do so. + looksLikeDisco := func(msg []byte) bool { + if len(msg) >= discoHeaderLen && string(msg[:len(disco.Magic)]) == disco.Magic { + return true + } + return false + } + + // Do we have a Geneve header? + if len(msg) >= packet.GeneveFixedHeaderLength && + msg[0]&0xC0 == 0 && // version bits that we always transmit as 0s + msg[1]&0x3F == 0 && // reserved bits that we always transmit as 0s + msg[7] == 0 { // reserved byte that we always transmit as 0 + switch binary.BigEndian.Uint16(msg[2:4]) { + case packet.GeneveProtocolDisco: + if looksLikeDisco(msg[packet.GeneveFixedHeaderLength:]) { + return packetLooksLikeDisco, true + } else { + // The Geneve header is well-formed, and it indicated this + // was disco, but it's not. The evaluated bytes at this point + // are always distinct from naked WireGuard (msg[2:4] are always + // 0x0000) and naked Disco (msg[2:4] are always 0xf09f), but + // maintain pre-Geneve behavior and fall back to assuming it's + // naked WireGuard. + return packetLooksLikeWireGuard, false + } + case packet.GeneveProtocolWireGuard: + return packetLooksLikeWireGuard, true + default: + // The Geneve header is well-formed, but the protocol field value is + // unknown to us. The evaluated bytes at this point are not + // necessarily distinct from naked WireGuard or naked Disco, fall + // through. + } + } + + if looksLikeDisco(msg) { + return packetLooksLikeDisco, false + } else { + return packetLooksLikeWireGuard, false + } +} + +// handleDiscoMessage handles a discovery message. The caller is assumed to have +// verified 'msg' returns [packetLooksLikeDisco] from packetLooksLike(). // // A discovery message has the form: // @@ -1488,23 +2163,18 @@ const ( // - nonce [24]byte // - naclbox of payload (see tailscale.com/disco package for inner payload format) // -// For messages received over DERP, the src.Addr() will be derpMagicIP (with -// src.Port() being the region ID) and the derpNodeSrc will be the node key +// For messages received over DERP, the src.ap.Addr() will be derpMagicIP (with +// src.ap.Port() being the region ID) and the derpNodeSrc will be the node key // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. -func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { - const headerLen = len(disco.Magic) + key.DiscoPublicRawLen - if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { - return false - } - - // If the first four parts are the prefix of disco.Magic - // (0x5453f09f) then it's definitely not a valid WireGuard - // packet (which starts with little-endian uint32 1, 2, 3, 4). - // Use naked returns for all following paths. - isDiscoMsg = true - - sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):headerLen])) +// +// If 'msg' was encapsulated by a Geneve header it is assumed to have already +// been stripped. +// +// 'shouldBeRelayHandshakeMsg' will be true if 'msg' was encapsulated +// by a Geneve header with the control bit set. +func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshakeMsg bool, derpNodeSrc key.NodePublic, via discoRXPath) { + sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):discoHeaderLen])) c.mu.Lock() defer c.mu.Unlock() @@ -1517,11 +2187,23 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } if c.privateKey.IsZero() { // Ignore disco messages when we're stopped. - // Still return true, to not pass it down to wireguard. return } - if !c.peerMap.knownPeerDiscoKey(sender) { + var di *discoInfo + switch { + case shouldBeRelayHandshakeMsg: + var ok bool + di, ok = c.relayManager.discoInfo(sender) + if !ok { + if debugDisco() { + c.logf("magicsock: disco: ignoring disco-looking relay handshake frame, no active handshakes with key %v over %v", sender.ShortString(), src) + } + return + } + case c.peerMap.knownPeerDiscoKey(sender): + di = c.discoInfoForKnownPeerLocked(sender) + default: metricRecvDiscoBadPeer.Add(1) if debugDisco() { c.logf("magicsock: disco: ignoring disco-looking frame, don't know of key %v", sender.ShortString()) @@ -1529,26 +2211,22 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } - isDERP := src.Addr() == tailcfg.DerpMagicIPAddr - if !isDERP { + isDERP := src.ap.Addr() == tailcfg.DerpMagicIPAddr + if !isDERP && !shouldBeRelayHandshakeMsg { // Record receive time for UDP transport packets. - pi, ok := c.peerMap.byIPPort[src] + pi, ok := c.peerMap.byEpAddr[src] if ok { pi.ep.lastRecvUDPAny.StoreAtomic(mono.Now()) } } - // We're now reasonably sure we're expecting communication from - // this peer, do the heavy crypto lifting to see what they want. - // - // From here on, peerNode and de are non-nil. - - di := c.discoInfoLocked(sender) + // We're now reasonably sure we're expecting communication from 'sender', + // do the heavy crypto lifting to see what they want. - sealedBox := msg[headerLen:] + sealedBox := msg[discoHeaderLen:] payload, ok := di.sharedKey.Open(sealedBox) if !ok { - // This might be have been intended for a previous + // This might have been intended for a previous // disco key. When we restart we get a new disco key // and old packets might've still been in flight (or // scheduled). This is particularly the case for LANs @@ -1568,7 +2246,8 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // Emit information about the disco frame into the pcap stream // if a capture hook is installed. if cb := c.captureHook.Load(); cb != nil { - cb(capture.PathDisco, time.Now(), disco.ToPCAPFrame(src, derpNodeSrc, payload), packet.CaptureMeta{}) + // TODO(jwhited): include VNI context? + cb(packet.PathDisco, time.Now(), disco.ToPCAPFrame(src.ap, derpNodeSrc, payload), packet.CaptureMeta{}) } dm, err := disco.Parse(payload) @@ -1591,6 +2270,20 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke metricRecvDiscoUDP.Add(1) } + if shouldBeRelayHandshakeMsg { + challenge, ok := dm.(*disco.BindUDPRelayEndpointChallenge) + if !ok { + // We successfully parsed the disco message, but it wasn't a + // challenge. We should never receive other message types + // from a relay server with the Geneve header control bit set. + c.logf("[unexpected] %T packets should not come from a relay server with Geneve control bit set", dm) + return + } + c.relayManager.handleRxDiscoMsg(c, challenge, key.NodePublic{}, di.discoKey, src) + metricRecvDiscoBindUDPRelayEndpointChallenge.Add(1) + return + } + switch dm := dm.(type) { case *disco.Ping: metricRecvDiscoPing.Add(1) @@ -1600,24 +2293,65 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // There might be multiple nodes for the sender's DiscoKey. // Ask each to handle it, stopping once one reports that // the Pong's TxID was theirs. + knownTxID := false c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) (keepGoing bool) { if ep.handlePongConnLocked(dm, di, src) { + knownTxID = true return false } return true }) - case *disco.CallMeMaybe: - metricRecvDiscoCallMeMaybe.Add(1) + if !knownTxID && src.vni.IsSet() { + // If it's an unknown TxID, and it's Geneve-encapsulated, then + // make [relayManager] aware. It might be in the middle of probing + // src. + c.relayManager.handleRxDiscoMsg(c, dm, key.NodePublic{}, di.discoKey, src) + } + case *disco.CallMeMaybe, *disco.CallMeMaybeVia: + var via *disco.CallMeMaybeVia + isVia := false + msgType := "CallMeMaybe" + cmm, ok := dm.(*disco.CallMeMaybe) + if ok { + metricRecvDiscoCallMeMaybe.Add(1) + } else { + metricRecvDiscoCallMeMaybeVia.Add(1) + via = dm.(*disco.CallMeMaybeVia) + msgType = "CallMeMaybeVia" + isVia = true + } + if !isDERP || derpNodeSrc.IsZero() { - // CallMeMaybe messages should only come via DERP. - c.logf("[unexpected] CallMeMaybe packets should only come via DERP") + // CallMeMaybe{Via} messages should only come via DERP. + c.logf("[unexpected] %s packets should only come via DERP", msgType) return } nodeKey := derpNodeSrc ep, ok := c.peerMap.endpointForNodeKey(nodeKey) if !ok { - metricRecvDiscoCallMeMaybeBadNode.Add(1) - c.logf("magicsock: disco: ignoring CallMeMaybe from %v; %v is unknown", sender.ShortString(), derpNodeSrc.ShortString()) + if isVia { + metricRecvDiscoCallMeMaybeViaBadNode.Add(1) + } else { + metricRecvDiscoCallMeMaybeBadNode.Add(1) + } + c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString()) + return + } + // If the "disable-relay-client" node attr is set for this node, it + // can't be a UDP relay client, so drop any CallMeMaybeVia messages it + // receives. + if isVia && !c.relayClientEnabled { + c.logf("magicsock: disco: ignoring %s from %v; disable-relay-client node attr is set", msgType, sender.ShortString()) + return + } + + ep.mu.Lock() + relayCapable := ep.relayCapable + lastBest := ep.bestAddr + lastBestIsTrusted := mono.Now().Before(ep.trustBestAddrUntil) + ep.mu.Unlock() + if isVia && !relayCapable { + c.logf("magicsock: disco: ignoring %s from %v; %v is not known to be relay capable", msgType, sender.ShortString(), sender.ShortString()) return } epDisco := ep.disco.Load() @@ -1625,15 +2359,119 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } if epDisco.key != di.discoKey { - metricRecvDiscoCallMeMaybeBadDisco.Add(1) - c.logf("[unexpected] CallMeMaybe from peer via DERP whose netmap discokey != disco source") + if isVia { + metricRecvDiscoCallMeMaybeViaBadDisco.Add(1) + } else { + metricRecvDiscoCallMeMaybeBadDisco.Add(1) + } + c.logf("[unexpected] %s from peer via DERP whose netmap discokey != disco source", msgType) + return + } + if isVia { + c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", + c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(), + ep.publicKey.ShortString(), derpStr(src.String()), + len(via.AddrPorts)) + c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via) + } else { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", + c.discoAtomic.Short(), epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + len(cmm.MyNumber)) + go ep.handleCallMeMaybe(cmm) + } + case *disco.AllocateUDPRelayEndpointRequest, *disco.AllocateUDPRelayEndpointResponse: + var resp *disco.AllocateUDPRelayEndpointResponse + isResp := false + msgType := "AllocateUDPRelayEndpointRequest" + req, ok := dm.(*disco.AllocateUDPRelayEndpointRequest) + if ok { + metricRecvDiscoAllocUDPRelayEndpointRequest.Add(1) + } else { + metricRecvDiscoAllocUDPRelayEndpointResponse.Add(1) + resp = dm.(*disco.AllocateUDPRelayEndpointResponse) + msgType = "AllocateUDPRelayEndpointResponse" + isResp = true + } + + if !isDERP { + // These messages should only come via DERP. + c.logf("[unexpected] %s packets should only come via DERP", msgType) + return + } + nodeKey := derpNodeSrc + ep, ok := c.peerMap.endpointForNodeKey(nodeKey) + if !ok { + c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString()) return } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", - c.discoShort, epDisco.short, - ep.publicKey.ShortString(), derpStr(src.String()), - len(dm.MyNumber)) - go ep.handleCallMeMaybe(dm) + epDisco := ep.disco.Load() + if epDisco == nil { + return + } + if epDisco.key != di.discoKey { + if isResp { + metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco.Add(1) + } else { + metricRecvDiscoAllocUDPRelayEndpointRequestBadDisco.Add(1) + } + c.logf("[unexpected] %s from peer via DERP whose netmap discokey != disco source", msgType) + return + } + + if isResp { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints", + c.discoAtomic.Short(), epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + msgType, + len(resp.AddrPorts)) + c.relayManager.handleRxDiscoMsg(c, resp, nodeKey, di.discoKey, src) + return + } else if sender.Compare(req.ClientDisco[0]) != 0 && sender.Compare(req.ClientDisco[1]) != 0 { + // An allocation request must contain the sender's disco key in + // ClientDisco. One of the relay participants must be the sender. + c.logf("magicsock: disco: %s from %v; %v does not contain sender's disco key", + msgType, sender.ShortString(), derpNodeSrc.ShortString()) + return + } else { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v", + c.discoAtomic.Short(), epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + msgType, + req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString()) + } + + if c.filt == nil { + return + } + // Binary search of peers is O(log n) while c.mu is held. + // TODO: We might be able to use ep.nodeAddr instead of all addresses, + // or we might be able to release c.mu before doing this work. Keep it + // simple and slow for now. c.peers.AsSlice is a copy. We may need to + // write our own binary search for a [views.Slice]. + peerI, ok := slices.BinarySearchFunc(c.peers.AsSlice(), ep.nodeID, func(peer tailcfg.NodeView, target tailcfg.NodeID) int { + if peer.ID() < target { + return -1 + } else if peer.ID() > target { + return 1 + } + return 0 + }) + if !ok { + // unexpected + return + } + if !nodeHasCap(c.filt, c.peers.At(peerI), c.self, tailcfg.PeerCapabilityRelay) { + return + } + // [Conn.mu] must not be held while publishing, or [Conn.onUDPRelayAllocResp] + // can deadlock as the req sub and resp pub are the same goroutine. + // See #17830. + go c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{ + RxFromDiscoKey: sender, + RxFromNodeKey: nodeKey, + Message: req, + }) } return } @@ -1678,25 +2516,45 @@ func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic // di is the discoInfo of the source of the ping. // derpNodeSrc is non-zero if the ping arrived via DERP. -func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInfo, derpNodeSrc key.NodePublic) { +func (c *Conn) handlePingLocked(dm *disco.Ping, src epAddr, di *discoInfo, derpNodeSrc key.NodePublic) { likelyHeartBeat := src == di.lastPingFrom && time.Since(di.lastPingTime) < 5*time.Second di.lastPingFrom = src di.lastPingTime = time.Now() - isDerp := src.Addr() == tailcfg.DerpMagicIPAddr + isDerp := src.ap.Addr() == tailcfg.DerpMagicIPAddr + + if src.vni.IsSet() { + if isDerp { + c.logf("[unexpected] got Geneve-encapsulated disco ping from %v/%v over DERP", src, derpNodeSrc) + return + } + + // [relayManager] is always responsible for handling (replying) to + // Geneve-encapsulated [disco.Ping] messages in the interest of + // simplicity. It might be in the middle of probing src, so it must be + // made aware. + c.relayManager.handleRxDiscoMsg(c, dm, key.NodePublic{}, di.discoKey, src) + return + } + + // This is a naked [disco.Ping] without a VNI. // If we can figure out with certainty which node key this disco - // message is for, eagerly update our IP<>node and disco<>node + // message is for, eagerly update our [epAddr]<>node and disco<>node // mappings to make p2p path discovery faster in simple // cases. Without this, disco would still work, but would be // reliant on DERP call-me-maybe to establish the disco<>node // mapping, and on subsequent disco handlePongConnLocked to establish - // the IP<>disco mapping. + // the IP:port<>disco mapping. if nk, ok := c.unambiguousNodeKeyOfPingLocked(dm, di.discoKey, derpNodeSrc); ok { if !isDerp { - c.peerMap.setNodeKeyForIPPort(src, nk) + c.peerMap.setNodeKeyForEpAddr(src, nk) } } + // numNodes tracks how many nodes (node keys) are associated with the disco + // key tied to this inbound ping. Multiple nodes may share the same disco + // key in the case of node sharing and users switching accounts. + var numNodes int // If we got a ping over DERP, then derpNodeSrc is non-zero and we reply // over DERP (in which case ipDst is also a DERP address). // But if the ping was over UDP (ipDst is not a DERP address), then dstKey @@ -1705,18 +2563,14 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf dstKey := derpNodeSrc // Remember this route if not present. - var numNodes int var dup bool if isDerp { - if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { - if ep.addCandidateEndpoint(src, dm.TxID) { - return - } + if _, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { numNodes = 1 } } else { c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) (keepGoing bool) { - if ep.addCandidateEndpoint(src, dm.TxID) { + if ep.addCandidateEndpoint(src.ap, dm.TxID) { dup = true return false } @@ -1746,14 +2600,14 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf if numNodes > 1 { pingNodeSrcStr = "[one-of-multi]" } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoShort, di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding) + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoAtomic.Short(), di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding) } ipDst := src discoDest := di.discoKey go c.sendDiscoMessage(ipDst, dstKey, discoDest, &disco.Pong{ TxID: dm.TxID, - Src: src, + Src: src.ap, }, discoVerboseLog) } @@ -1796,25 +2650,30 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) { for _, ep := range c.lastEndpoints { eps = append(eps, ep.Addr) } - go de.c.sendDiscoMessage(derpAddr, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(epAddr{ap: derpAddr}, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) if debugSendCallMeUnknownPeer() { // Send a callMeMaybe packet to a non-existent peer unknownKey := key.NewNode().Public() c.logf("magicsock: sending CallMeMaybe to unknown peer per TS_DEBUG_SEND_CALLME_UNKNOWN_PEER") - go de.c.sendDiscoMessage(derpAddr, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(epAddr{ap: derpAddr}, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) } } -// discoInfoLocked returns the previous or new discoInfo for k. +// discoInfoForKnownPeerLocked returns the previous or new discoInfo for k. +// +// Callers must only pass key.DiscoPublic's that are present in and +// lifetime-managed via [Conn].peerMap. UDP relay server disco keys are discovered +// at relay endpoint allocation time or [disco.CallMeMaybeVia] reception time +// and therefore must never pass through this method. // // c.mu must be held. -func (c *Conn) discoInfoLocked(k key.DiscoPublic) *discoInfo { +func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { di, ok := c.discoInfo[k] if !ok { di = &discoInfo{ discoKey: k, discoShort: k.ShortString(), - sharedKey: c.discoPrivate.Shared(k), + sharedKey: c.discoAtomic.Private().Shared(k), } c.discoInfo[k] = di } @@ -1834,7 +2693,9 @@ func (c *Conn) SetNetworkUp(up bool) { if up { c.startDerpHomeConnectLocked() } else { - c.portMapper.NoteNetworkDown() + if c.portMapper != nil { + c.portMapper.NoteNetworkDown() + } c.closeAllDerpLocked("network-down") } } @@ -2017,34 +2878,184 @@ func (c *Conn) SetProbeUDPLifetime(v bool) { }) } -// SetNetworkMap is called when the control client gets a new network -// map from the control server. It must always be non-nil. +// capVerIsRelayCapable returns true if version is relay client and server +// capable, otherwise it returns false. +func capVerIsRelayCapable(version tailcfg.CapabilityVersion) bool { + return version >= 121 +} + +// onFilterUpdate is called when a [FilterUpdate] is received over the +// [eventbus.Bus]. +func (c *Conn) onFilterUpdate(f FilterUpdate) { + c.mu.Lock() + c.filt = f.Filter + self := c.self + peers := c.peers + relayClientEnabled := c.relayClientEnabled + c.mu.Unlock() // release c.mu before potentially calling c.updateRelayServersSet which is O(m * n) + + if !relayClientEnabled { + // Early return if we cannot operate as a relay client. + return + } + + // The filter has changed, and we are operating as a relay server client. + // Re-evaluate it in order to produce an updated relay server set. + c.updateRelayServersSet(f.Filter, self, peers) +} + +// updateRelayServersSet iterates all peers and self, evaluating filt for each +// one in order to determine which are relay server candidates. filt, self, and +// peers are passed as args (vs c.mu-guarded fields) to enable callers to +// release c.mu before calling as this is O(m * n) (we iterate all cap rules 'm' +// in filt for every peer 'n'). // -// It should not use the DERPMap field of NetworkMap; that's -// conditionally sent to SetDERPMap instead. -func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { +// Calls to updateRelayServersSet must never run concurrent to +// [endpoint.setDERPHome], otherwise [candidatePeerRelay] DERP home changes may +// be missed from the perspective of [relayManager]. +// +// TODO: Optimize this so that it's not O(m * n). This might involve: +// 1. Changes to [filter.Filter], e.g. adding a CapsWithValues() to check for +// a given capability instead of building and returning a map of all of +// them. +// 2. Moving this work upstream into [nodeBackend] or similar, and publishing +// the computed result over the eventbus instead. +func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, peers views.Slice[tailcfg.NodeView]) { + relayServers := make(set.Set[candidatePeerRelay]) + nodes := append(peers.AsSlice(), self) + for _, maybeCandidate := range nodes { + if maybeCandidate.ID() != self.ID() && !capVerIsRelayCapable(maybeCandidate.Cap()) { + // If maybeCandidate's [tailcfg.CapabilityVersion] is not relay-capable, + // we skip it. If maybeCandidate happens to be self, then this check is + // unnecessary as self is always capable from this point (the statically + // compiled [tailcfg.CurrentCapabilityVersion]) forward. + continue + } + if !nodeHasCap(filt, maybeCandidate, self, tailcfg.PeerCapabilityRelayTarget) { + continue + } + relayServers.Add(candidatePeerRelay{ + nodeKey: maybeCandidate.Key(), + discoKey: maybeCandidate.DiscoKey(), + derpHomeRegionID: uint16(maybeCandidate.HomeDERP()), + }) + } + c.relayManager.handleRelayServersSet(relayServers) + if len(relayServers) > 0 { + c.hasPeerRelayServers.Store(true) + } else { + c.hasPeerRelayServers.Store(false) + } +} + +// nodeHasCap returns true if src has cap on dst, otherwise it returns false. +func nodeHasCap(filt *filter.Filter, src, dst tailcfg.NodeView, cap tailcfg.PeerCapability) bool { + if filt == nil || + !src.Valid() || + !dst.Valid() { + return false + } + for _, srcPrefix := range src.Addresses().All() { + if !srcPrefix.IsSingleIP() { + continue + } + srcAddr := srcPrefix.Addr() + for _, dstPrefix := range dst.Addresses().All() { + if !dstPrefix.IsSingleIP() { + continue + } + dstAddr := dstPrefix.Addr() + if dstAddr.BitLen() == srcAddr.BitLen() { // same address family + // [nodeBackend.peerCapsLocked] only returns/considers the + // [tailcfg.PeerCapMap] between the passed src and the _first_ + // host (/32 or /128) address for self. We are consistent with + // that behavior here. If src and dst host addresses are of the + // same address family they either have the capability or not. + // We do not check against additional host addresses of the same + // address family. + return filt.CapsWithValues(srcAddr, dstAddr).HasCapability(cap) + } + } + } + return false +} + +// candidatePeerRelay represents the identifiers and DERP home region ID for a +// peer relay server. +type candidatePeerRelay struct { + nodeKey key.NodePublic + discoKey key.DiscoPublic + derpHomeRegionID uint16 +} + +func (c *candidatePeerRelay) isValid() bool { + return !c.nodeKey.IsZero() && !c.discoKey.IsZero() +} + +// onNodeViewsUpdate is called when a [NodeViewsUpdate] is received over the +// [eventbus.Bus]. +func (c *Conn) onNodeViewsUpdate(update NodeViewsUpdate) { + peersChanged := c.updateNodes(update) + + relayClientEnabled := update.SelfNode.Valid() && + !update.SelfNode.HasCap(tailcfg.NodeAttrDisableRelayClient) && + !update.SelfNode.HasCap(tailcfg.NodeAttrOnlyTCP443) + + c.mu.Lock() + relayClientChanged := c.relayClientEnabled != relayClientEnabled + c.relayClientEnabled = relayClientEnabled + filt := c.filt + self := c.self + peers := c.peers + isClosed := c.closed + c.mu.Unlock() // release c.mu before potentially calling c.updateRelayServersSet which is O(m * n) + + if isClosed { + return // nothing to do here, the conn is closed and the update is no longer relevant + } + + if peersChanged || relayClientChanged { + if !relayClientEnabled { + c.relayManager.handleRelayServersSet(nil) + c.hasPeerRelayServers.Store(false) + } else { + c.updateRelayServersSet(filt, self, peers) + } + } +} + +// updateNodes updates [Conn] to reflect the [tailcfg.NodeView]'s contained +// in update. It returns true if update.Peers was unequal to c.peers, otherwise +// false. +func (c *Conn) updateNodes(update NodeViewsUpdate) (peersChanged bool) { c.mu.Lock() defer c.mu.Unlock() if c.closed { - return + return false } priorPeers := c.peers - metricNumPeers.Set(int64(len(nm.Peers))) + metricNumPeers.Set(int64(len(update.Peers))) - // Update c.netMap regardless, before the following early return. - curPeers := views.SliceOf(nm.Peers) + // Update c.self & c.peers regardless, before the following early return. + c.self = update.SelfNode + curPeers := views.SliceOf(update.Peers) c.peers = curPeers + // [debugFlags] are mutable in [Conn.SetSilentDisco] & + // [Conn.SetProbeUDPLifetime]. These setters are passed [controlknobs.Knobs] + // values by [ipnlocal.LocalBackend] around netmap reception. + // [controlknobs.Knobs] are simply self [tailcfg.NodeCapability]'s. They are + // useful as a global view of notable feature toggles, but the magicsock + // setters are completely unnecessary as we have the same values right here + // (update.SelfNode.Capabilities) at a time they are considered most + // up-to-date. + // TODO: mutate [debugFlags] here instead of in various [Conn] setters. flags := c.debugFlagsLocked() - if addrs := nm.GetAddresses(); addrs.Len() > 0 { - c.firstAddrForTest = addrs.At(0).Addr() - } else { - c.firstAddrForTest = netip.Addr{} - } - if nodesEqual(priorPeers, curPeers) && c.lastFlags == flags { + peersChanged = !nodesEqual(priorPeers, curPeers) + if !peersChanged && c.lastFlags == flags { // The rest of this function is all adjusting state for peers that have // changed. But if the set of peers is equal and the debug flags (for // silent disco and probe UDP lifetime) haven't changed, there is no @@ -2054,16 +3065,16 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { c.lastFlags = flags - c.logf("[v1] magicsock: got updated network map; %d peers", len(nm.Peers)) + c.logf("[v1] magicsock: got updated network map; %d peers", len(update.Peers)) - entriesPerBuffer := debugRingBufferSize(len(nm.Peers)) + entriesPerBuffer := debugRingBufferSize(len(update.Peers)) // Try a pass of just upserting nodes and creating missing // endpoints. If the set of nodes is the same, this is an // efficient alloc-free update. If the set of nodes is different, // we'll fall through to the next pass, which allocates but can // handle full set updates. - for _, n := range nm.Peers { + for _, n := range update.Peers { if n.ID() == 0 { devPanicf("node with zero ID") continue @@ -2140,7 +3151,7 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { // ~1MB on mobile but we never used the data so the memory was just // wasted. default: - ep.debugUpdates = ringbuffer.New[EndpointChange](entriesPerBuffer) + ep.debugUpdates = ringlog.New[EndpointChange](entriesPerBuffer) } if n.Addresses().Len() > 0 { ep.nodeAddr = n.Addresses().At(0).Addr() @@ -2163,14 +3174,14 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) } - // If the set of nodes changed since the last SetNetworkMap, the + // If the set of nodes changed since the last onNodeViewsUpdate, the // upsert loop just above made c.peerMap contain the union of the // old and new peers - which will be larger than the set from the // current netmap. If that happens, go through the allocful // deletion path to clean up moribund nodes. - if c.peerMap.nodeCount() != len(nm.Peers) { + if c.peerMap.nodeCount() != len(update.Peers) { keep := set.Set[key.NodePublic]{} - for _, n := range nm.Peers { + for _, n := range update.Peers { keep.Add(n.Key()) } c.peerMap.forEachEndpoint(func(ep *endpoint) { @@ -2186,6 +3197,8 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { delete(c.discoInfo, dk) } } + + return peersChanged } func devPanicf(format string, a ...any) { @@ -2196,10 +3209,7 @@ func devPanicf(format string, a ...any) { func (c *Conn) logEndpointCreated(n tailcfg.NodeView) { c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) { - const derpPrefix = "127.3.3.40:" - if strings.HasPrefix(n.DERP(), derpPrefix) { - ipp, _ := netip.ParseAddrPort(n.DERP()) - regionID := int(ipp.Port()) + if regionID := n.HomeDERP(); regionID != 0 { code := c.derpRegionCodeLocked(regionID) if code != "" { code = "(" + code + ")" @@ -2207,16 +3217,14 @@ func (c *Conn) logEndpointCreated(n tailcfg.NodeView) { fmt.Fprintf(w, "derp=%v%s ", regionID, code) } - for i := range n.AllowedIPs().Len() { - a := n.AllowedIPs().At(i) + for _, a := range n.AllowedIPs().All() { if a.IsSingleIP() { fmt.Fprintf(w, "aip=%v ", a.Addr()) } else { fmt.Fprintf(w, "aip=%v ", a) } } - for i := range n.Endpoints().Len() { - ep := n.Endpoints().At(i) + for _, ep := range n.Endpoints().All() { fmt.Fprintf(w, "ep=%v ", ep) } })) @@ -2335,6 +3343,13 @@ func (c *connBind) isClosed() bool { // // Only the first close does anything. Any later closes return nil. func (c *Conn) Close() error { + // Close the [eventbus.Client] to wait for subscribers to + // return before acquiring c.mu: + // 1. Event handlers also acquire c.mu, they can deadlock with c.Close(). + // 2. Event handlers may not guard against undesirable post/in-progress + // Conn.Close() behaviors. + c.eventClient.Close() + c.mu.Lock() defer c.mu.Unlock() if c.closed { @@ -2345,7 +3360,9 @@ func (c *Conn) Close() error { c.derpCleanupTimer.Stop() } c.stopPeriodicReSTUNTimerLocked() - c.portMapper.Close() + if c.portMapper != nil { + c.portMapper.Close() + } c.peerMap.forEachEndpoint(func(ep *endpoint) { ep.stopAndReset() @@ -2364,7 +3381,6 @@ func (c *Conn) Close() error { if c.closeDisco6 != nil { c.closeDisco6.Close() } - // Wait on goroutines updating right at the end, once everything is // already closed. We want everything else in the Conn to be // consistently in the closed state before we release mu to wait @@ -2377,6 +3393,8 @@ func (c *Conn) Close() error { pinger.Close() } + deregisterMetrics() + return nil } @@ -2427,7 +3445,7 @@ func (c *Conn) shouldDoPeriodicReSTUNLocked() bool { return true } -func (c *Conn) onPortMapChanged() { c.ReSTUN("portmap-changed") } +func (c *Conn) onPortMapChanged(portmappertype.Mapping) { c.ReSTUN("portmap-changed") } // ReSTUN triggers an address discovery. // The provided why string is for debug logging only. @@ -2484,9 +3502,9 @@ func (c *Conn) listenPacket(network string, port uint16) (nettype.PacketConn, er return nettype.MakePacketListenerWithNetIP(netns.Listener(c.logf, c.netMon)).ListenPacket(ctx, network, addr) } -// bindSocket initializes rucPtr if necessary and binds a UDP socket to it. +// bindSocket binds a UDP socket to ruc. // Network indicates the UDP socket type; it must be "udp4" or "udp6". -// If rucPtr had an existing UDP socket bound, it closes that socket. +// If ruc had an existing UDP socket bound, it closes that socket. // The caller is responsible for informing the portMapper of any changes. // If curPortFate is set to dropCurrentPort, no attempt is made to reuse // the current port. @@ -2525,7 +3543,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur } ports = append(ports, 0) // Remove duplicates. (All duplicates are consecutive.) - uniq.ModifySlice(&ports) + ports = slices.Compact(ports) if debugBindSocket() { c.logf("magicsock: bindSocket: candidate ports: %+v", ports) @@ -2544,7 +3562,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur c.logf("magicsock: unable to bind %v port %d: %v", network, port, err) continue } - if c.onPortUpdate != nil { + if c.portUpdatePub.ShouldPublish() { _, gotPortStr, err := net.SplitHostPort(pconn.LocalAddr().String()) if err != nil { c.logf("could not parse port from %s: %w", pconn.LocalAddr().String(), err) @@ -2553,11 +3571,13 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur if err != nil { c.logf("could not parse port from %s: %w", gotPort, err) } else { - c.onPortUpdate(uint16(gotPort), network) + c.portUpdatePub.Publish(router.PortUpdate{ + UDPPort: uint16(gotPort), + EndpointNetwork: network, + }) } } } - trySetSocketBuffer(pconn, c.logf) trySetUDPSocketOptions(pconn, c.logf) // Success. @@ -2598,7 +3618,9 @@ func (c *Conn) rebind(curPortFate currentPortFate) error { if err := c.bindSocket(&c.pconn4, "udp4", curPortFate); err != nil { return fmt.Errorf("magicsock: Rebind IPv4 failed: %w", err) } - c.portMapper.SetLocalPort(c.LocalPort()) + if c.portMapper != nil { + c.portMapper.SetLocalPort(c.LocalPort()) + } c.UpdatePMTUD() return nil } @@ -2620,7 +3642,9 @@ func (c *Conn) Rebind() { c.logf("Rebind; defIf=%q, ips=%v", defIf, ifIPs) } - c.maybeCloseDERPsOnRebind(ifIPs) + if len(ifIPs) > 0 { + c.maybeCloseDERPsOnRebind(ifIPs) + } c.resetEndpointStates() } @@ -2692,12 +3716,13 @@ func simpleDur(d time.Duration) time.Duration { return d.Round(time.Minute) } -// UpdateNetmapDelta implements controlclient.NetmapDeltaUpdater. -func (c *Conn) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { +// onNodeMutationsUpdate is called when a [NodeMutationsUpdate] is received over +// the [eventbus.Bus]. Note: It does not apply these mutations to c.peers. +func (c *Conn) onNodeMutationsUpdate(update NodeMutationsUpdate) { c.mu.Lock() defer c.mu.Unlock() - for _, m := range muts { + for _, m := range update.Mutations { nodeID := m.NodeIDBeingMutated() ep, ok := c.peerMap.endpointForNodeID(nodeID) if !ok { @@ -2712,10 +3737,9 @@ func (c *Conn) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { ep.mu.Unlock() } } - return true } -// UpdateStatus implements the interface nede by ipnstate.StatusBuilder. +// UpdateStatus implements the interface needed by ipnstate.StatusBuilder. // // This method adds in the magicsock-specific information only. Most // of the status is otherwise populated by LocalBackend. @@ -2750,10 +3774,12 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) { }) } -// SetStatistics specifies a per-connection statistics aggregator. +// SetConnectionCounter specifies a per-connection statistics aggregator. // Nil may be specified to disable statistics gathering. -func (c *Conn) SetStatistics(stats *connstats.Statistics) { - c.stats.Store(stats) +func (c *Conn) SetConnectionCounter(fn netlogfunc.ConnectionCounter) { + if buildfeatures.HasNetLog { + c.connCounter.Store(fn) + } } // SetHomeless sets whether magicsock should idle harder and not have a DERP @@ -2781,9 +3807,17 @@ const ( // keep NAT mappings alive. sessionActiveTimeout = 45 * time.Second - // upgradeInterval is how often we try to upgrade to a better path - // even if we have some non-DERP route that works. - upgradeInterval = 1 * time.Minute + // upgradeUDPDirectInterval is how often we try to upgrade to a better, + // direct UDP path even if we have some direct UDP path that works. + upgradeUDPDirectInterval = 1 * time.Minute + + // upgradeUDPRelayInterval is how often we try to discover UDP relay paths + // even if we have a UDP relay path that works. + upgradeUDPRelayInterval = 1 * time.Minute + + // discoverUDPRelayPathsInterval is the minimum time between UDP relay path + // discovery. + discoverUDPRelayPathsInterval = 30 * time.Second // heartbeatInterval is how often pings to the best UDP address // are sent. @@ -2860,40 +3894,56 @@ func (c *Conn) DebugPickNewDERP() error { return errors.New("too few regions") } -// portableTrySetSocketBuffer sets SO_SNDBUF and SO_RECVBUF on pconn to socketBufferSize, -// logging an error if it occurs. -func portableTrySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { - if c, ok := pconn.(*net.UDPConn); ok { - // Attempt to increase the buffer size, and allow failures. - if err := c.SetReadBuffer(socketBufferSize); err != nil { - logf("magicsock: failed to set UDP read buffer size to %d: %v", socketBufferSize, err) +func (c *Conn) DebugForcePreferDERP(n int) { + c.mu.Lock() + defer c.mu.Unlock() + + c.logf("magicsock: [debug] force preferred DERP set to: %d", n) + c.netChecker.SetForcePreferredDERP(n) +} + +func trySetUDPSocketOptions(pconn nettype.PacketConn, logf logger.Logf) { + directions := []sockopts.BufferDirection{sockopts.ReadDirection, sockopts.WriteDirection} + for _, direction := range directions { + forceErr, portableErr := sockopts.SetBufferSize(pconn, direction, socketBufferSize) + if forceErr != nil { + logf("magicsock: [warning] failed to force-set UDP %v buffer size to %d: %v; using kernel default values (impacts throughput only)", direction, socketBufferSize, forceErr) } - if err := c.SetWriteBuffer(socketBufferSize); err != nil { - logf("magicsock: failed to set UDP write buffer size to %d: %v", socketBufferSize, err) + if portableErr != nil { + logf("magicsock: failed to set UDP %v buffer size to %d: %v", direction, socketBufferSize, portableErr) } } + + err := sockopts.SetICMPErrImmunity(pconn) + if err != nil { + logf("magicsock: %v", err) + } } // derpStr replaces DERP IPs in s with "derp-". func derpStr(s string) string { return strings.ReplaceAll(s, "127.3.3.40:", "derp-") } -// ippEndpointCache is a mutex-free single-element cache, mapping from -// a single netip.AddrPort to a single endpoint. -type ippEndpointCache struct { - ipp netip.AddrPort - gen int64 - de *endpoint +// epAddrEndpointCache is a mutex-free single-element cache, mapping from +// a single [epAddr] to a single [*endpoint]. +type epAddrEndpointCache struct { + epAddr epAddr + gen int64 + de *endpoint } // discoInfo is the info and state for the DiscoKey -// in the Conn.discoInfo map key. +// in the [Conn.discoInfo] and [relayManager.discoInfoByServerDisco] map keys. +// +// When the disco protocol is used to handshake with a peer relay server, the +// corresponding discoInfo is held in [relayManager.discoInfoByServerDisco] +// instead of [Conn.discoInfo]. // // Note that a DiscoKey does not necessarily map to exactly one // node. In the case of shared nodes and users switching accounts, two // nodes in the NetMap may legitimately have the same DiscoKey. As // such, no fields in here should be considered node-specific. type discoInfo struct { - // discoKey is the same as the Conn.discoInfo map key, + // discoKey is the same as the corresponding map key, // just so you can pass around a *discoInfo alone. // Not modified once initialized. discoKey key.DiscoPublic @@ -2904,14 +3954,16 @@ type discoInfo struct { // sharedKey is the precomputed key for communication with the // peer that has the DiscoKey used to look up this *discoInfo in - // Conn.discoInfo. + // the corresponding map. // Not modified once initialized. sharedKey key.DiscoShared - // Mutable fields follow, owned by Conn.mu: + // Mutable fields follow, owned by [Conn.mu]. These are irrelevant when + // discoInfo is a peer relay server disco key in the + // [relayManager.discoInfoByServerDisco] map: // lastPingFrom is the src of a ping for discoKey. - lastPingFrom netip.AddrPort + lastPingFrom epAddr // lastPingTime is the last time of a ping for discoKey. lastPingTime time.Time @@ -2930,41 +3982,85 @@ var ( metricSendDERPErrorChan = clientmetric.NewCounter("magicsock_send_derp_error_chan") metricSendDERPErrorClosed = clientmetric.NewCounter("magicsock_send_derp_error_closed") metricSendDERPErrorQueue = clientmetric.NewCounter("magicsock_send_derp_error_queue") - metricSendUDP = clientmetric.NewCounter("magicsock_send_udp") + metricSendDERPDropped = clientmetric.NewCounter("magicsock_send_derp_dropped") metricSendUDPError = clientmetric.NewCounter("magicsock_send_udp_error") - metricSendDERP = clientmetric.NewCounter("magicsock_send_derp") + metricSendPeerRelayError = clientmetric.NewCounter("magicsock_send_peer_relay_error") metricSendDERPError = clientmetric.NewCounter("magicsock_send_derp_error") + // Sends (data) + // + // Note: Prior to v1.78 metricSendUDP & metricSendDERP counted sends of data + // AND disco packets. They were updated in v1.78 to only count data packets. + // metricSendPeerRelay was added in v1.86 and has always counted only data + // packets. + metricSendUDP = clientmetric.NewAggregateCounter("magicsock_send_udp") + metricSendPeerRelay = clientmetric.NewAggregateCounter("magicsock_send_peer_relay") + metricSendDERP = clientmetric.NewAggregateCounter("magicsock_send_derp") + // Data packets (non-disco) - metricSendData = clientmetric.NewCounter("magicsock_send_data") - metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down") - metricRecvDataDERP = clientmetric.NewCounter("magicsock_recv_data_derp") - metricRecvDataIPv4 = clientmetric.NewCounter("magicsock_recv_data_ipv4") - metricRecvDataIPv6 = clientmetric.NewCounter("magicsock_recv_data_ipv6") + metricSendData = clientmetric.NewCounter("magicsock_send_data") + metricSendDataNetworkDown = clientmetric.NewCounter("magicsock_send_data_network_down") + metricRecvDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_derp") + metricRecvDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv4") + metricRecvDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_ipv6") + metricRecvDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv4") + metricRecvDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_peer_relay_ipv6") + metricSendDataPacketsDERP = clientmetric.NewAggregateCounter("magicsock_send_data_derp") + metricSendDataPacketsIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv4") + metricSendDataPacketsIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_ipv6") + metricSendDataPacketsPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv4") + metricSendDataPacketsPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_peer_relay_ipv6") + + // Data bytes (non-disco) + metricRecvDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_derp") + metricRecvDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv4") + metricRecvDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_ipv6") + metricRecvDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv4") + metricRecvDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_recv_data_bytes_peer_relay_ipv6") + metricSendDataBytesDERP = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_derp") + metricSendDataBytesIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv4") + metricSendDataBytesIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_ipv6") + metricSendDataBytesPeerRelayIPv4 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv4") + metricSendDataBytesPeerRelayIPv6 = clientmetric.NewAggregateCounter("magicsock_send_data_bytes_peer_relay_ipv6") // Disco packets - metricSendDiscoUDP = clientmetric.NewCounter("magicsock_disco_send_udp") - metricSendDiscoDERP = clientmetric.NewCounter("magicsock_disco_send_derp") - metricSentDiscoUDP = clientmetric.NewCounter("magicsock_disco_sent_udp") - metricSentDiscoDERP = clientmetric.NewCounter("magicsock_disco_sent_derp") - metricSentDiscoPing = clientmetric.NewCounter("magicsock_disco_sent_ping") - metricSentDiscoPong = clientmetric.NewCounter("magicsock_disco_sent_pong") - metricSentDiscoPeerMTUProbes = clientmetric.NewCounter("magicsock_disco_sent_peer_mtu_probes") - metricSentDiscoPeerMTUProbeBytes = clientmetric.NewCounter("magicsock_disco_sent_peer_mtu_probe_bytes") - metricSentDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_sent_callmemaybe") - metricRecvDiscoBadPeer = clientmetric.NewCounter("magicsock_disco_recv_bad_peer") - metricRecvDiscoBadKey = clientmetric.NewCounter("magicsock_disco_recv_bad_key") - metricRecvDiscoBadParse = clientmetric.NewCounter("magicsock_disco_recv_bad_parse") - - metricRecvDiscoUDP = clientmetric.NewCounter("magicsock_disco_recv_udp") - metricRecvDiscoDERP = clientmetric.NewCounter("magicsock_disco_recv_derp") - metricRecvDiscoPing = clientmetric.NewCounter("magicsock_disco_recv_ping") - metricRecvDiscoPong = clientmetric.NewCounter("magicsock_disco_recv_pong") - metricRecvDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe") - metricRecvDiscoCallMeMaybeBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_node") - metricRecvDiscoCallMeMaybeBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_disco") - metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here") - metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown") + metricSendDiscoUDP = clientmetric.NewCounter("magicsock_disco_send_udp") + metricSendDiscoDERP = clientmetric.NewCounter("magicsock_disco_send_derp") + metricSentDiscoUDP = clientmetric.NewCounter("magicsock_disco_sent_udp") + metricSentDiscoDERP = clientmetric.NewCounter("magicsock_disco_sent_derp") + metricSentDiscoPing = clientmetric.NewCounter("magicsock_disco_sent_ping") + metricSentDiscoPong = clientmetric.NewCounter("magicsock_disco_sent_pong") + metricSentDiscoPeerMTUProbes = clientmetric.NewCounter("magicsock_disco_sent_peer_mtu_probes") + metricSentDiscoPeerMTUProbeBytes = clientmetric.NewCounter("magicsock_disco_sent_peer_mtu_probe_bytes") + metricSentDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_sent_callmemaybe") + metricSentDiscoCallMeMaybeVia = clientmetric.NewCounter("magicsock_disco_sent_callmemaybevia") + metricSentDiscoBindUDPRelayEndpoint = clientmetric.NewCounter("magicsock_disco_sent_bind_udp_relay_endpoint") + metricSentDiscoBindUDPRelayEndpointAnswer = clientmetric.NewCounter("magicsock_disco_sent_bind_udp_relay_endpoint_answer") + metricSentDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_sent_alloc_udp_relay_endpoint_request") + metricLocalDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_request") + metricSentDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_sent_alloc_udp_relay_endpoint_response") + metricRecvDiscoBadPeer = clientmetric.NewCounter("magicsock_disco_recv_bad_peer") + metricRecvDiscoBadKey = clientmetric.NewCounter("magicsock_disco_recv_bad_key") + metricRecvDiscoBadParse = clientmetric.NewCounter("magicsock_disco_recv_bad_parse") + + metricRecvDiscoUDP = clientmetric.NewCounter("magicsock_disco_recv_udp") + metricRecvDiscoDERP = clientmetric.NewCounter("magicsock_disco_recv_derp") + metricRecvDiscoPing = clientmetric.NewCounter("magicsock_disco_recv_ping") + metricRecvDiscoPong = clientmetric.NewCounter("magicsock_disco_recv_pong") + metricRecvDiscoCallMeMaybe = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe") + metricRecvDiscoCallMeMaybeVia = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia") + metricRecvDiscoCallMeMaybeBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_node") + metricRecvDiscoCallMeMaybeViaBadNode = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia_bad_node") + metricRecvDiscoCallMeMaybeBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybe_bad_disco") + metricRecvDiscoCallMeMaybeViaBadDisco = clientmetric.NewCounter("magicsock_disco_recv_callmemaybevia_bad_disco") + metricRecvDiscoBindUDPRelayEndpointChallenge = clientmetric.NewCounter("magicsock_disco_recv_bind_udp_relay_endpoint_challenge") + metricRecvDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request") + metricRecvDiscoAllocUDPRelayEndpointRequestBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request_bad_disco") + metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response_bad_disco") + metricRecvDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response") + metricLocalDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_response") + metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here") + metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown") // metricDERPHomeChange is how many times our DERP home region DI has // changed from non-zero to a different non-zero. metricDERPHomeChange = clientmetric.NewCounter("derp_home_change") @@ -3038,34 +4134,133 @@ func (c *Conn) SetLastNetcheckReportForTest(ctx context.Context, report *netchec c.lastNetCheckReport.Store(report) } -// lazyEndpoint is a wireguard conn.Endpoint for when magicsock received a +// lazyEndpoint is a wireguard [conn.Endpoint] for when magicsock received a // non-disco (presumably WireGuard) packet from a UDP address from which we -// can't map to a Tailscale peer. But Wireguard most likely can, once it -// decrypts it. So we implement the conn.PeerAwareEndpoint interface -// from https://github.com/tailscale/wireguard-go/pull/27 to allow WireGuard -// to tell us who it is later and get the correct conn.Endpoint. +// can't map to a Tailscale peer. But WireGuard most likely can, once it +// decrypts it. So we implement the [conn.InitiationAwareEndpoint] and +// [conn.PeerAwareEndpoint] interfaces, to allow WireGuard to tell us who it is +// later, just-in-time to configure the peer, and set the associated [epAddr] +// in the [peerMap]. Future receives on the associated [epAddr] will then +// resolve directly to an [*endpoint]. +// +// We also sometimes (see [Conn.receiveIP]) return a [*lazyEndpoint] to +// wireguard-go to verify an [epAddr] resolves to the [*endpoint] (maybeEP) we +// believe it to be, to resolve [epAddr] collisions across peers. [epAddr] +// collisions have a higher chance of occurrence for packets received over peer +// relays versus direct connections, as peer relay connections do not upsert +// into [peerMap] around disco packet reception, but direct connections do. type lazyEndpoint struct { - c *Conn - src netip.AddrPort + c *Conn + maybeEP *endpoint // or nil if unknown + src epAddr } +var _ conn.InitiationAwareEndpoint = (*lazyEndpoint)(nil) var _ conn.PeerAwareEndpoint = (*lazyEndpoint)(nil) var _ conn.Endpoint = (*lazyEndpoint)(nil) -func (le *lazyEndpoint) ClearSrc() {} -func (le *lazyEndpoint) SrcIP() netip.Addr { return le.src.Addr() } -func (le *lazyEndpoint) DstIP() netip.Addr { return netip.Addr{} } -func (le *lazyEndpoint) SrcToString() string { return le.src.String() } -func (le *lazyEndpoint) DstToString() string { return "dst" } -func (le *lazyEndpoint) DstToBytes() []byte { return nil } -func (le *lazyEndpoint) GetPeerEndpoint(peerPublicKey [32]byte) conn.Endpoint { +// InitiationMessagePublicKey implements [conn.InitiationAwareEndpoint]. +// wireguard-go calls us here if we passed it a [*lazyEndpoint] for an +// initiation message, for which it might not have the relevant peer configured, +// enabling us to just-in-time configure it and note its activity via +// [*endpoint.noteRecvActivity], before it performs peer lookup and attempts +// decryption. +// +// Reception of all other WireGuard message types implies pre-existing knowledge +// of the peer by wireguard-go for it to do useful work. See +// [userspaceEngine.maybeReconfigWireguardLocked] & +// [userspaceEngine.noteRecvActivity] for more details around just-in-time +// wireguard-go peer (de)configuration. +func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { pubKey := key.NodePublicFromRaw32(mem.B(peerPublicKey[:])) + if le.maybeEP != nil && pubKey.Compare(le.maybeEP.publicKey) == 0 { + return + } + le.c.mu.Lock() + ep, ok := le.c.peerMap.endpointForNodeKey(pubKey) + // [Conn.mu] must not be held while [Conn.noteRecvActivity] is called, which + // [endpoint.noteRecvActivity] can end up calling. See + // [Options.NoteRecvActivity] docs. + le.c.mu.Unlock() + if !ok { + return + } + now := mono.Now() + ep.lastRecvUDPAny.StoreAtomic(now) + ep.noteRecvActivity(le.src, now) + // [ep.noteRecvActivity] may end up JIT configuring the peer, but we don't + // update [peerMap] as wireguard-go hasn't decrypted the initiation + // message yet. wireguard-go will call us below in [lazyEndpoint.FromPeer] + // if it successfully decrypts the message, at which point it's safe to + // insert le.src into the [peerMap] for ep. +} + +func (le *lazyEndpoint) ClearSrc() {} +func (le *lazyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } + +// DstIP returns the remote address of the peer. +// +// Note: DstIP is used internally by wireguard-go as part of handshake DoS +// mitigation. +func (le *lazyEndpoint) DstIP() netip.Addr { return le.src.ap.Addr() } + +func (le *lazyEndpoint) SrcToString() string { return "" } +func (le *lazyEndpoint) DstToString() string { return le.src.String() } + +// DstToBytes returns a binary representation of the remote address of the peer. +// +// Note: DstToBytes is used internally by wireguard-go as part of handshake DoS +// mitigation. +func (le *lazyEndpoint) DstToBytes() []byte { + b, _ := le.src.ap.MarshalBinary() + return b +} + +// FromPeer implements [conn.PeerAwareEndpoint]. We return a [*lazyEndpoint] in +// [Conn.receiveIP] when we are unable to identify the peer at WireGuard +// packet reception time, pre-decryption, or we want wireguard-go to verify who +// we believe it to be (le.maybeEP). If wireguard-go successfully decrypts the +// packet it calls us here, and we update our [peerMap] to associate le.src with +// peerPublicKey. +func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) { + pubKey := key.NodePublicFromRaw32(mem.B(peerPublicKey[:])) + if le.maybeEP != nil && pubKey.Compare(le.maybeEP.publicKey) == 0 { + return + } le.c.mu.Lock() defer le.c.mu.Unlock() ep, ok := le.c.peerMap.endpointForNodeKey(pubKey) if !ok { - return nil + return + } + // TODO(jwhited): Consider [lazyEndpoint] effectiveness as a means to make + // this the sole call site for setNodeKeyForEpAddr. If this is the sole + // call site, and we always update the mapping based on successful + // Cryptokey Routing identification events, then we can go ahead and make + // [epAddr]s singular per peer (like they are for Geneve-encapsulated ones + // already). + // See http://go/corp/29422 & http://go/corp/30042 + le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey) + le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr) +} + +// PeerRelays returns the current set of candidate peer relays. +func (c *Conn) PeerRelays() set.Set[netip.Addr] { + candidatePeerRelays := c.relayManager.getServers() + servers := make(set.Set[netip.Addr], len(candidatePeerRelays)) + c.mu.Lock() + defer c.mu.Unlock() + for relay := range candidatePeerRelays { + pi, ok := c.peerMap.byNodeKey[relay.nodeKey] + if !ok { + if c.self.Key().Compare(relay.nodeKey) == 0 { + if c.self.Addresses().Len() > 0 { + servers.Add(c.self.Addresses().At(0).Addr()) + } + } + continue + } + servers.Add(pi.ep.nodeAddr) } - le.c.logf("magicsock: lazyEndpoint.GetPeerEndpoint(%v) found: %v", pubKey.ShortString(), ep.nodeAddr) - return ep + return servers } diff --git a/wgengine/magicsock/magicsock_default.go b/wgengine/magicsock/magicsock_default.go index 7614c64c9..88759d3ac 100644 --- a/wgengine/magicsock/magicsock_default.go +++ b/wgengine/magicsock/magicsock_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || ts_omit_listenrawdisco package magicsock @@ -9,19 +9,8 @@ import ( "errors" "fmt" "io" - - "tailscale.com/types/logger" - "tailscale.com/types/nettype" ) func (c *Conn) listenRawDisco(family string) (io.Closer, error) { return nil, fmt.Errorf("raw disco listening not supported on this OS: %w", errors.ErrUnsupported) } - -func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { - portableTrySetSocketBuffer(pconn, logf) -} - -const ( - controlMessageSize = 0 -) diff --git a/wgengine/magicsock/magicsock_linux.go b/wgengine/magicsock/magicsock_linux.go index c5df555cd..f37e19165 100644 --- a/wgengine/magicsock/magicsock_linux.go +++ b/wgengine/magicsock/magicsock_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !ts_omit_listenrawdisco + package magicsock import ( @@ -13,7 +15,6 @@ import ( "net" "net/netip" "strings" - "syscall" "time" "github.com/mdlayher/socket" @@ -28,7 +29,6 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" - "tailscale.com/types/nettype" ) const ( @@ -66,10 +66,10 @@ var ( // fragmented, and we don't want to handle reassembly. bpf.LoadAbsolute{Off: 6, Size: 2}, // More Fragments bit set means this is part of a fragmented packet. - bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x2000, SkipTrue: 7, SkipFalse: 0}, + bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x2000, SkipTrue: 8, SkipFalse: 0}, // Non-zero fragment offset with MF=0 means this is the last // fragment of packet. - bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1fff, SkipTrue: 6, SkipFalse: 0}, + bpf.JumpIf{Cond: bpf.JumpBitsSet, Val: 0x1fff, SkipTrue: 7, SkipFalse: 0}, // Load IP header length into X register. bpf.LoadMemShift{Off: 0}, @@ -453,7 +453,13 @@ func (c *Conn) receiveDisco(pc *socket.Conn, isIPV6 bool) { metricRecvDiscoPacketIPv4.Add(1) } - c.handleDiscoMessage(payload, srcAddr, key.NodePublic{}, discoRXPathRawSocket) + pt, isGeneveEncap := packetLooksLike(payload) + if pt == packetLooksLikeDisco && !isGeneveEncap { + // The BPF program matching on disco does not currently support + // Geneve encapsulation. isGeneveEncap should not return true if + // payload is disco. + c.handleDiscoMessage(payload, epAddr{ap: srcAddr}, false, key.NodePublic{}, discoRXPathRawSocket) + } } } @@ -483,38 +489,3 @@ func printSockaddr(sa unix.Sockaddr) string { return fmt.Sprintf("unknown(%T)", sa) } } - -// trySetSocketBuffer attempts to set SO_SNDBUFFORCE and SO_RECVBUFFORCE which -// can overcome the limit of net.core.{r,w}mem_max, but require CAP_NET_ADMIN. -// It falls back to the portable implementation if that fails, which may be -// silently capped to net.core.{r,w}mem_max. -func trySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { - if c, ok := pconn.(*net.UDPConn); ok { - var errRcv, errSnd error - rc, err := c.SyscallConn() - if err == nil { - rc.Control(func(fd uintptr) { - errRcv = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_RCVBUFFORCE, socketBufferSize) - if errRcv != nil { - logf("magicsock: [warning] failed to force-set UDP read buffer size to %d: %v; using kernel default values (impacts throughput only)", socketBufferSize, errRcv) - } - errSnd = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_SNDBUFFORCE, socketBufferSize) - if errSnd != nil { - logf("magicsock: [warning] failed to force-set UDP write buffer size to %d: %v; using kernel default values (impacts throughput only)", socketBufferSize, errSnd) - } - }) - } - - if err != nil || errRcv != nil || errSnd != nil { - portableTrySetSocketBuffer(pconn, logf) - } - } -} - -var controlMessageSize = -1 // bomb if used for allocation before init - -func init() { - // controlMessageSize is set to hold a UDP_GRO or UDP_SEGMENT control - // message. These contain a single uint16 of data. - controlMessageSize = unix.CmsgSpace(2) -} diff --git a/wgengine/magicsock/magicsock_linux_test.go b/wgengine/magicsock/magicsock_linux_test.go index 6b86b04f2..28ccd220e 100644 --- a/wgengine/magicsock/magicsock_linux_test.go +++ b/wgengine/magicsock/magicsock_linux_test.go @@ -9,6 +9,7 @@ import ( "net/netip" "testing" + "golang.org/x/net/bpf" "golang.org/x/sys/cpu" "golang.org/x/sys/unix" "tailscale.com/disco" @@ -146,3 +147,78 @@ func TestEthernetProto(t *testing.T) { } } } + +func TestBpfDiscardV4(t *testing.T) { + // Good packet as a reference for what should not be rejected + udp4Packet := []byte{ + // IPv4 header + 0x45, 0x00, 0x00, 0x26, 0x00, 0x00, 0x00, 0x00, + 0x40, 0x11, 0x00, 0x00, + 0x7f, 0x00, 0x00, 0x01, // source ip + 0x7f, 0x00, 0x00, 0x02, // dest ip + + // UDP header + 0x30, 0x39, // src port + 0xd4, 0x31, // dest port + 0x00, 0x12, // length; 8 bytes header + 10 bytes payload = 18 bytes + 0x00, 0x00, // checksum; unused + + // Payload: disco magic plus 32 bytes for key and 24 bytes for nonce + 0x54, 0x53, 0xf0, 0x9f, 0x92, 0xac, 0x00, 0x01, 0x02, 0x03, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + } + + vm, err := bpf.NewVM(magicsockFilterV4) + if err != nil { + t.Fatalf("failed creating BPF VM: %v", err) + } + + tests := []struct { + name string + replace map[int]byte + accept bool + }{ + { + name: "base accepted datagram", + replace: map[int]byte{}, + accept: true, + }, + { + name: "more fragments", + replace: map[int]byte{ + 6: 0x20, + }, + accept: false, + }, + { + name: "some fragment", + replace: map[int]byte{ + 7: 0x01, + }, + accept: false, + }, + } + + udp4PacketChanged := make([]byte, len(udp4Packet)) + copy(udp4PacketChanged, udp4Packet) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.replace { + udp4PacketChanged[k] = v + } + ret, err := vm.Run(udp4PacketChanged) + if err != nil { + t.Fatalf("BPF VM error: %v", err) + } + + if (ret != 0) != tt.accept { + t.Errorf("expected accept=%v, got ret=%v", tt.accept, ret) + } + }) + } +} diff --git a/wgengine/magicsock/magicsock_notplan9.go b/wgengine/magicsock/magicsock_notplan9.go new file mode 100644 index 000000000..86d099ee7 --- /dev/null +++ b/wgengine/magicsock/magicsock_notplan9.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package magicsock + +import ( + "errors" + "syscall" +) + +// shouldRebind returns if the error is one that is known to be healed by a +// rebind, and if so also returns a resason string for the rebind. +func shouldRebind(err error) (ok bool, reason string) { + switch { + // EPIPE/ENOTCONN are common errors when a send fails due to a closed + // socket. There is some platform and version inconsistency in which + // error is returned, but the meaning is the same. + case errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ENOTCONN): + return true, "broken-pipe" + + // EPERM is typically caused by EDR software, and has been observed to be + // transient, it seems that some versions of some EDR lose track of sockets + // at times, and return EPERM, but reconnects will establish appropriate + // rights associated with a new socket. + case errors.Is(err, syscall.EPERM): + return true, "operation-not-permitted" + } + return false, "" +} diff --git a/wgengine/magicsock/magicsock_plan9.go b/wgengine/magicsock/magicsock_plan9.go new file mode 100644 index 000000000..65714c3e1 --- /dev/null +++ b/wgengine/magicsock/magicsock_plan9.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build plan9 + +package magicsock + +// shouldRebind returns if the error is one that is known to be healed by a +// rebind, and if so also returns a resason string for the rebind. +func shouldRebind(err error) (ok bool, reason string) { + return false, "" +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 6b2d961b9..7ae422906 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -9,6 +9,7 @@ import ( crand "crypto/rand" "crypto/tls" "encoding/binary" + "encoding/hex" "errors" "fmt" "io" @@ -18,6 +19,7 @@ import ( "net/http/httptest" "net/netip" "os" + "reflect" "runtime" "strconv" "strings" @@ -25,30 +27,31 @@ import ( "sync/atomic" "syscall" "testing" + "testing/synctest" "time" "unsafe" + qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp" wgconn "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun/tuntest" "go4.org/mem" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/control/controlknobs" - "tailscale.com/derp" - "tailscale.com/derp/derphttp" + "tailscale.com/derp/derpserver" "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/ipn/ipnstate" - "tailscale.com/net/connstats" "tailscale.com/net/netaddr" "tailscale.com/net/netcheck" "tailscale.com/net/netmon" "tailscale.com/net/packet" "tailscale.com/net/ping" + "tailscale.com/net/stun" "tailscale.com/net/stun/stuntest" "tailscale.com/net/tstun" "tailscale.com/tailcfg" @@ -62,10 +65,16 @@ import ( "tailscale.com/types/nettype" "tailscale.com/types/ptr" "tailscale.com/util/cibuild" + "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" + "tailscale.com/util/must" "tailscale.com/util/racebuild" "tailscale.com/util/set" + "tailscale.com/util/slicesx" "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/filter/filtertype" "tailscale.com/wgengine/wgcfg" "tailscale.com/wgengine/wgcfg/nmcfg" "tailscale.com/wgengine/wglog" @@ -102,15 +111,15 @@ func (c *Conn) WaitReady(t testing.TB) { } } -func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, stunIP netip.Addr) (derpMap *tailcfg.DERPMap, cleanup func()) { - d := derp.NewServer(key.NewNode(), logf) +func runDERPAndStun(t *testing.T, logf logger.Logf, ln nettype.PacketListener, stunIP netip.Addr) (derpMap *tailcfg.DERPMap, cleanup func()) { + d := derpserver.New(key.NewNode(), logf) - httpsrv := httptest.NewUnstartedServer(derphttp.Handler(d)) + httpsrv := httptest.NewUnstartedServer(derpserver.Handler(d)) httpsrv.Config.ErrorLog = logger.StdLogger(logf) httpsrv.Config.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) httpsrv.StartTLS() - stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, l) + stunAddr, stunCleanup := stuntest.ServeWithPacketListener(t, ln) m := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ @@ -149,41 +158,46 @@ func runDERPAndStun(t *testing.T, logf logger.Logf, l nettype.PacketListener, st // happiness. type magicStack struct { privateKey key.NodePrivate - epCh chan []tailcfg.Endpoint // endpoint updates produced by this peer - stats *connstats.Statistics // per-connection statistics - conn *Conn // the magicsock itself - tun *tuntest.ChannelTUN // TUN device to send/receive packets - tsTun *tstun.Wrapper // wrapped tun that implements filtering and wgengine hooks - dev *device.Device // the wireguard-go Device that connects the previous things - wgLogger *wglog.Logger // wireguard-go log wrapper - netMon *netmon.Monitor // always non-nil + epCh chan []tailcfg.Endpoint // endpoint updates produced by this peer + counts netlogtype.CountsByConnection // per-connection statistics + conn *Conn // the magicsock itself + tun *tuntest.ChannelTUN // TUN device to send/receive packets + tsTun *tstun.Wrapper // wrapped tun that implements filtering and wgengine hooks + dev *device.Device // the wireguard-go Device that connects the previous things + wgLogger *wglog.Logger // wireguard-go log wrapper + netMon *netmon.Monitor // always non-nil metrics *usermetric.Registry } // newMagicStack builds and initializes an idle magicsock and -// friends. You need to call conn.SetNetworkMap and dev.Reconfig +// friends. You need to call conn.onNodeViewsUpdate and dev.Reconfig // before anything interesting happens. -func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap) *magicStack { +func newMagicStack(t testing.TB, logf logger.Logf, ln nettype.PacketListener, derpMap *tailcfg.DERPMap) *magicStack { privateKey := key.NewNode() - return newMagicStackWithKey(t, logf, l, derpMap, privateKey) + return newMagicStackWithKey(t, logf, ln, derpMap, privateKey) } -func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap, privateKey key.NodePrivate) *magicStack { +func newMagicStackWithKey(t testing.TB, logf logger.Logf, ln nettype.PacketListener, derpMap *tailcfg.DERPMap, privateKey key.NodePrivate) *magicStack { t.Helper() - netMon, err := netmon.New(logf) + bus := eventbustest.NewBus(t) + + netMon, err := netmon.New(bus, logf) if err != nil { t.Fatalf("netmon.New: %v", err) } + ht := health.NewTracker(bus) var reg usermetric.Registry epCh := make(chan []tailcfg.Endpoint, 100) // arbitrary conn, err := NewConn(Options{ NetMon: netMon, + EventBus: bus, Metrics: ®, Logf: logf, + HealthTracker: ht, DisablePortMapper: true, - TestOnlyPacketListener: l, + TestOnlyPacketListener: ln, EndpointsFunc: func(eps []tailcfg.Endpoint) { epCh <- eps }, @@ -265,7 +279,10 @@ func (s *magicStack) Status() *ipnstate.Status { func (s *magicStack) IP() netip.Addr { for deadline := time.Now().Add(5 * time.Second); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) { s.conn.mu.Lock() - addr := s.conn.firstAddrForTest + var addr netip.Addr + if s.conn.self.Valid() && s.conn.self.Addresses().Len() > 0 { + addr = s.conn.self.Addresses().At(0).Addr() + } s.conn.mu.Unlock() if addr.IsValid() { return addr @@ -291,8 +308,7 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM buildNetmapLocked := func(myIdx int) *netmap.NetworkMap { me := ms[myIdx] nm := &netmap.NetworkMap{ - PrivateKey: me.privateKey, - NodeKey: me.privateKey.Public(), + NodeKey: me.privateKey.Public(), SelfNode: (&tailcfg.Node{ Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(1, 0, 0, byte(myIdx+1)), 32)}, }).View(), @@ -310,7 +326,7 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM Addresses: addrs, AllowedIPs: addrs, Endpoints: epFromTyped(eps[i]), - DERP: "127.3.3.40:1", + HomeDERP: 1, } nm.Peers = append(nm.Peers, peer.View()) } @@ -329,13 +345,17 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM for i, m := range ms { nm := buildNetmapLocked(i) - m.conn.SetNetworkMap(nm) - peerSet := make(set.Set[key.NodePublic], len(nm.Peers)) - for _, peer := range nm.Peers { + nv := NodeViewsUpdate{ + SelfNode: nm.SelfNode, + Peers: nm.Peers, + } + m.conn.onNodeViewsUpdate(nv) + peerSet := make(set.Set[key.NodePublic], len(nv.Peers)) + for _, peer := range nv.Peers { peerSet.Add(peer.Key()) } m.conn.UpdatePeers(peerSet) - wg, err := nmcfg.WGCfg(nm, logf, 0, "") + wg, err := nmcfg.WGCfg(ms[i].privateKey, nm, logf, 0, "") if err != nil { // We're too far from the *testing.T to be graceful, // blow up. Shouldn't happen anyway. @@ -386,7 +406,10 @@ func TestNewConn(t *testing.T) { } } - netMon, err := netmon.New(logger.WithPrefix(t.Logf, "... netmon: ")) + bus := eventbus.New() + t.Cleanup(bus.Close) + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, "... netmon: ")) if err != nil { t.Fatalf("netmon.New: %v", err) } @@ -402,6 +425,7 @@ func TestNewConn(t *testing.T) { EndpointsFunc: epFunc, Logf: t.Logf, NetMon: netMon, + EventBus: bus, Metrics: new(usermetric.Registry), }) if err != nil { @@ -519,7 +543,10 @@ func TestDeviceStartStop(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) - netMon, err := netmon.New(logger.WithPrefix(t.Logf, "... netmon: ")) + bus := eventbus.New() + t.Cleanup(bus.Close) + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, "... netmon: ")) if err != nil { t.Fatalf("netmon.New: %v", err) } @@ -529,6 +556,7 @@ func TestDeviceStartStop(t *testing.T) { EndpointsFunc: func(eps []tailcfg.Endpoint) {}, Logf: t.Logf, NetMon: netMon, + EventBus: bus, Metrics: new(usermetric.Registry), }) if err != nil { @@ -659,13 +687,13 @@ func (localhostListener) ListenPacket(ctx context.Context, network, address stri func TestTwoDevicePing(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/11762") - l, ip := localhostListener{}, netaddr.IPv4(127, 0, 0, 1) + ln, ip := localhostListener{}, netaddr.IPv4(127, 0, 0, 1) n := &devices{ - m1: l, + m1: ln, m1IP: ip, - m2: l, + m2: ln, m2IP: ip, - stun: l, + stun: ln, stunIP: ip, } testTwoDevicePing(t, n) @@ -1030,7 +1058,6 @@ func testTwoDevicePing(t *testing.T, d *devices) { }) m1cfg := &wgcfg.Config{ - Name: "peer1", PrivateKey: m1.privateKey, Addresses: []netip.Prefix{netip.MustParsePrefix("1.0.0.1/32")}, Peers: []wgcfg.Peer{ @@ -1042,7 +1069,6 @@ func testTwoDevicePing(t *testing.T, d *devices) { }, } m2cfg := &wgcfg.Config{ - Name: "peer2", PrivateKey: m2.privateKey, Addresses: []netip.Prefix{netip.MustParsePrefix("1.0.0.2/32")}, Peers: []wgcfg.Peer{ @@ -1114,22 +1140,19 @@ func testTwoDevicePing(t *testing.T, d *devices) { } } - m1.stats = connstats.NewStatistics(0, 0, nil) - defer m1.stats.Shutdown(context.Background()) - m1.conn.SetStatistics(m1.stats) - m2.stats = connstats.NewStatistics(0, 0, nil) - defer m2.stats.Shutdown(context.Background()) - m2.conn.SetStatistics(m2.stats) + m1.conn.SetConnectionCounter(m1.counts.Add) + m2.conn.SetConnectionCounter(m2.counts.Add) checkStats := func(t *testing.T, m *magicStack, wantConns []netlogtype.Connection) { - _, stats := m.stats.TestExtract() + defer m.counts.Reset() + counts := m.counts.Clone() for _, conn := range wantConns { - if _, ok := stats[conn]; ok { + if _, ok := counts[conn]; ok { return } } t.Helper() - t.Errorf("missing any connection to %s from %s", wantConns, xmaps.Keys(stats)) + t.Errorf("missing any connection to %s from %s", wantConns, slicesx.MapKeys(counts)) } addrPort := netip.MustParseAddrPort @@ -1188,41 +1211,96 @@ func testTwoDevicePing(t *testing.T, d *devices) { checkStats(t, m1, m1Conns) checkStats(t, m2, m2Conns) }) -} - -func TestDiscoMessage(t *testing.T) { - c := newConn(t.Logf) - c.privateKey = key.NewNode() - - peer1Pub := c.DiscoPublicKey() - peer1Priv := c.discoPrivate - n := &tailcfg.Node{ - Key: key.NewNode().Public(), - DiscoKey: peer1Pub, - } - ep := &endpoint{ - nodeID: 1, - publicKey: n.Key, - } - ep.disco.Store(&endpointDisco{ - key: n.DiscoKey, - short: n.DiscoKey.ShortString(), + t.Run("compare-metrics-stats", func(t *testing.T) { + setT(t) + defer setT(outerT) + m1.conn.resetMetricsForTest() + m1.counts.Reset() + m2.conn.resetMetricsForTest() + m2.counts.Reset() + t.Logf("Metrics before: %s\n", m1.metrics.String()) + ping1(t) + ping2(t) + assertConnStatsAndUserMetricsEqual(t, m1) + assertConnStatsAndUserMetricsEqual(t, m2) + t.Logf("Metrics after: %s\n", m1.metrics.String()) }) - c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) - - const payload = "why hello" - - var nonce [24]byte - crand.Read(nonce[:]) +} - pkt := peer1Pub.AppendTo([]byte("TSđŸ’Ŧ")) +func (c *Conn) resetMetricsForTest() { + c.metrics.inboundBytesIPv4Total.Set(0) + c.metrics.inboundPacketsIPv4Total.Set(0) + c.metrics.outboundBytesIPv4Total.Set(0) + c.metrics.outboundPacketsIPv4Total.Set(0) + c.metrics.inboundBytesIPv6Total.Set(0) + c.metrics.inboundPacketsIPv6Total.Set(0) + c.metrics.outboundBytesIPv6Total.Set(0) + c.metrics.outboundPacketsIPv6Total.Set(0) + c.metrics.inboundBytesDERPTotal.Set(0) + c.metrics.inboundPacketsDERPTotal.Set(0) + c.metrics.outboundBytesDERPTotal.Set(0) + c.metrics.outboundPacketsDERPTotal.Set(0) +} - box := peer1Priv.Shared(c.discoPrivate.Public()).Seal([]byte(payload)) - pkt = append(pkt, box...) - got := c.handleDiscoMessage(pkt, netip.AddrPort{}, key.NodePublic{}, discoRXPathUDP) - if !got { - t.Error("failed to open it") +func assertConnStatsAndUserMetricsEqual(t *testing.T, ms *magicStack) { + physIPv4RxBytes := int64(0) + physIPv4TxBytes := int64(0) + physDERPRxBytes := int64(0) + physDERPTxBytes := int64(0) + physIPv4RxPackets := int64(0) + physIPv4TxPackets := int64(0) + physDERPRxPackets := int64(0) + physDERPTxPackets := int64(0) + for conn, count := range ms.counts.Clone() { + t.Logf("physconn src: %s, dst: %s", conn.Src.String(), conn.Dst.String()) + if conn.Dst.String() == "127.3.3.40:1" { + physDERPRxBytes += int64(count.RxBytes) + physDERPTxBytes += int64(count.TxBytes) + physDERPRxPackets += int64(count.RxPackets) + physDERPTxPackets += int64(count.TxPackets) + } else { + physIPv4RxBytes += int64(count.RxBytes) + physIPv4TxBytes += int64(count.TxBytes) + physIPv4RxPackets += int64(count.RxPackets) + physIPv4TxPackets += int64(count.TxPackets) + } } + ms.counts.Reset() + + metricIPv4RxBytes := ms.conn.metrics.inboundBytesIPv4Total.Value() + metricIPv4RxPackets := ms.conn.metrics.inboundPacketsIPv4Total.Value() + metricIPv4TxBytes := ms.conn.metrics.outboundBytesIPv4Total.Value() + metricIPv4TxPackets := ms.conn.metrics.outboundPacketsIPv4Total.Value() + + metricDERPRxBytes := ms.conn.metrics.inboundBytesDERPTotal.Value() + metricDERPRxPackets := ms.conn.metrics.inboundPacketsDERPTotal.Value() + metricDERPTxBytes := ms.conn.metrics.outboundBytesDERPTotal.Value() + metricDERPTxPackets := ms.conn.metrics.outboundPacketsDERPTotal.Value() + + c := qt.New(t) + c.Assert(physDERPRxBytes, qt.Equals, metricDERPRxBytes) + c.Assert(physDERPTxBytes, qt.Equals, metricDERPTxBytes) + c.Assert(physIPv4RxBytes, qt.Equals, metricIPv4RxBytes) + c.Assert(physIPv4TxBytes, qt.Equals, metricIPv4TxBytes) + c.Assert(physDERPRxPackets, qt.Equals, metricDERPRxPackets) + c.Assert(physDERPTxPackets, qt.Equals, metricDERPTxPackets) + c.Assert(physIPv4RxPackets, qt.Equals, metricIPv4RxPackets) + c.Assert(physIPv4TxPackets, qt.Equals, metricIPv4TxPackets) + + // Validate that the usermetrics and clientmetrics are in sync + // Note: the clientmetrics are global, this means that when they are registering with the + // wgengine, multiple in-process nodes used by this test will be updating the same metrics. This is why we need to multiply + // the metrics by 2 to get the expected value. + // TODO(kradalby): https://github.com/tailscale/tailscale/issues/13420 + c.Assert(metricSendUDP.Value(), qt.Equals, metricIPv4TxPackets*2) + c.Assert(metricSendDataPacketsIPv4.Value(), qt.Equals, metricIPv4TxPackets*2) + c.Assert(metricSendDataPacketsDERP.Value(), qt.Equals, metricDERPTxPackets*2) + c.Assert(metricSendDataBytesIPv4.Value(), qt.Equals, metricIPv4TxBytes*2) + c.Assert(metricSendDataBytesDERP.Value(), qt.Equals, metricDERPTxBytes*2) + c.Assert(metricRecvDataPacketsIPv4.Value(), qt.Equals, metricIPv4RxPackets*2) + c.Assert(metricRecvDataPacketsDERP.Value(), qt.Equals, metricDERPRxPackets*2) + c.Assert(metricRecvDataBytesIPv4.Value(), qt.Equals, metricIPv4RxBytes*2) + c.Assert(metricRecvDataBytesDERP.Value(), qt.Equals, metricDERPRxBytes*2) } // tests that having a endpoint.String prevents wireguard-go's @@ -1258,11 +1336,11 @@ func Test32bitAlignment(t *testing.T) { t.Fatalf("endpoint.lastRecvWG is not 8-byte aligned") } - de.noteRecvActivity(netip.AddrPort{}, mono.Now()) // verify this doesn't panic on 32-bit + de.noteRecvActivity(epAddr{}, mono.Now()) // verify this doesn't panic on 32-bit if called != 1 { t.Fatal("expected call to noteRecvActivity") } - de.noteRecvActivity(netip.AddrPort{}, mono.Now()) + de.noteRecvActivity(epAddr{}, mono.Now()) if called != 1 { t.Error("expected no second call to noteRecvActivity") } @@ -1273,7 +1351,9 @@ func newTestConn(t testing.TB) *Conn { t.Helper() port := pickPort(t) - netMon, err := netmon.New(logger.WithPrefix(t.Logf, "... netmon: ")) + bus := eventbustest.NewBus(t) + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, "... netmon: ")) if err != nil { t.Fatalf("netmon.New: %v", err) } @@ -1281,7 +1361,8 @@ func newTestConn(t testing.TB) *Conn { conn, err := NewConn(Options{ NetMon: netMon, - HealthTracker: new(health.Tracker), + EventBus: bus, + HealthTracker: health.NewTracker(bus), Metrics: new(usermetric.Registry), DisablePortMapper: true, Logf: t.Logf, @@ -1297,7 +1378,7 @@ func newTestConn(t testing.TB) *Conn { return conn } -// addTestEndpoint sets conn's network map to a single peer expected +// addTestEndpoint sets conn's node views to a single peer expected // to receive packets from sendConn (or DERP), and returns that peer's // nodekey and discokey. func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (key.NodePublic, key.DiscoPublic) { @@ -1306,7 +1387,7 @@ func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (key.No // codepath. discoKey := key.DiscoPublicFromRaw32(mem.B([]byte{31: 1})) nodeKey := key.NodePublicFromRaw32(mem.B([]byte{0: 'N', 1: 'K', 31: 0})) - conn.SetNetworkMap(&netmap.NetworkMap{ + conn.onNodeViewsUpdate(NodeViewsUpdate{ Peers: nodeViews([]*tailcfg.Node{ { ID: 1, @@ -1495,11 +1576,11 @@ func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { return nv } -// Test that a netmap update where node changes its node key but +// Test that a node views update where node changes its node key but // doesn't change its disco key doesn't result in a broken state. // // https://github.com/tailscale/tailscale/issues/1391 -func TestSetNetworkMapChangingNodeKey(t *testing.T) { +func TestOnNodeViewsUpdateChangingNodeKey(t *testing.T) { conn := newTestConn(t) t.Cleanup(func() { conn.Close() }) var buf tstest.MemLogger @@ -1511,7 +1592,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { nodeKey1 := key.NodePublicFromRaw32(mem.B([]byte{0: 'N', 1: 'K', 2: '1', 31: 0})) nodeKey2 := key.NodePublicFromRaw32(mem.B([]byte{0: 'N', 1: 'K', 2: '2', 31: 0})) - conn.SetNetworkMap(&netmap.NetworkMap{ + conn.onNodeViewsUpdate(NodeViewsUpdate{ Peers: nodeViews([]*tailcfg.Node{ { ID: 1, @@ -1527,7 +1608,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { } for range 3 { - conn.SetNetworkMap(&netmap.NetworkMap{ + conn.onNodeViewsUpdate(NodeViewsUpdate{ Peers: nodeViews([]*tailcfg.Node{ { ID: 2, @@ -1696,10 +1777,15 @@ func TestEndpointSetsEqual(t *testing.T) { func TestBetterAddr(t *testing.T) { const ms = time.Millisecond al := func(ipps string, d time.Duration) addrQuality { - return addrQuality{AddrPort: netip.MustParseAddrPort(ipps), latency: d} + return addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort(ipps)}, latency: d} } almtu := func(ipps string, d time.Duration, mtu tstun.WireMTU) addrQuality { - return addrQuality{AddrPort: netip.MustParseAddrPort(ipps), latency: d, wireMTU: mtu} + return addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort(ipps)}, latency: d, wireMTU: mtu} + } + avl := func(ipps string, vni uint32, d time.Duration) addrQuality { + q := al(ipps, d) + q.vni.Set(vni) + return q } zero := addrQuality{} @@ -1805,6 +1891,18 @@ func TestBetterAddr(t *testing.T) { b: al("[::1]:555", 100*ms), want: false, }, + + // Prefer non-Geneve over Geneve-encapsulated + { + a: al(publicV4, 100*ms), + b: avl(publicV4, 1, 100*ms), + want: true, + }, + { + a: avl(publicV4, 1, 100*ms), + b: al(publicV4, 100*ms), + want: false, + }, } for i, tt := range tests { got := betterAddr(tt.a, tt.b) @@ -1835,7 +1933,7 @@ func eps(s ...string) []netip.AddrPort { return eps } -func TestStressSetNetworkMap(t *testing.T) { +func TestStressOnNodeViewsUpdate(t *testing.T) { t.Parallel() conn := newTestConn(t) @@ -1883,15 +1981,15 @@ func TestStressSetNetworkMap(t *testing.T) { allPeers[j].Key = randNodeKey() } } - // Clone existing peers into a new netmap. + // Clone existing peers. peers := make([]*tailcfg.Node, 0, len(allPeers)) for peerIdx, p := range allPeers { if present[peerIdx] { peers = append(peers, p.Clone()) } } - // Set the netmap. - conn.SetNetworkMap(&netmap.NetworkMap{ + // Set the node views. + conn.onNodeViewsUpdate(NodeViewsUpdate{ Peers: nodeViews(peers), }) // Check invariants. @@ -1916,9 +2014,9 @@ func (m *peerMap) validate() error { return fmt.Errorf("duplicate endpoint present: %v", pi.ep.publicKey) } seenEps[pi.ep] = true - for ipp := range pi.ipPorts { - if got := m.byIPPort[ipp]; got != pi { - return fmt.Errorf("m.byIPPort[%v] = %v, want %v", ipp, got, pi) + for addr := range pi.epAddrs { + if got := m.byEpAddr[addr]; got != pi { + return fmt.Errorf("m.byEpAddr[%v] = %v, want %v", addr, got, pi) } } } @@ -1934,13 +2032,13 @@ func (m *peerMap) validate() error { } } - for ipp, pi := range m.byIPPort { - if !pi.ipPorts.Contains(ipp) { - return fmt.Errorf("ipPorts[%v] for %v is false", ipp, pi.ep.publicKey) + for addr, pi := range m.byEpAddr { + if !pi.epAddrs.Contains(addr) { + return fmt.Errorf("epAddrs[%v] for %v is false", addr, pi.ep.publicKey) } pi2 := m.byNodeKey[pi.ep.publicKey] if pi != pi2 { - return fmt.Errorf("byNodeKey[%v]=%p doesn't match byIPPort[%v]=%p", pi, pi, pi.ep.publicKey, pi2) + return fmt.Errorf("byNodeKey[%v]=%p doesn't match byEpAddr[%v]=%p", pi, pi, pi.ep.publicKey, pi2) } } @@ -2016,10 +2114,10 @@ func TestRebindingUDPConn(t *testing.T) { } // https://github.com/tailscale/tailscale/issues/6680: don't ignore -// SetNetworkMap calls when there are no peers. (A too aggressive fast path was +// onNodeViewsUpdate calls when there are no peers. (A too aggressive fast path was // previously bailing out early, thinking there were no changes since all zero -// peers didn't change, but the netmap has non-peer info in it too we shouldn't discard) -func TestSetNetworkMapWithNoPeers(t *testing.T) { +// peers didn't change, but the node views has non-peer info in it too we shouldn't discard) +func TestOnNodeViewsUpdateWithNoPeers(t *testing.T) { var c Conn knobs := &controlknobs.Knobs{} c.logf = logger.Discard @@ -2028,23 +2126,15 @@ func TestSetNetworkMapWithNoPeers(t *testing.T) { for i := 1; i <= 3; i++ { v := !debugEnableSilentDisco() envknob.Setenv("TS_DEBUG_ENABLE_SILENT_DISCO", fmt.Sprint(v)) - nm := &netmap.NetworkMap{} - c.SetNetworkMap(nm) - t.Logf("ptr %d: %p", i, nm) + nv := NodeViewsUpdate{} + c.onNodeViewsUpdate(nv) + t.Logf("ptr %d: %p", i, nv) if c.lastFlags.heartbeatDisabled != v { t.Fatalf("call %d: didn't store netmap", i) } } } -func TestBufferedDerpWritesBeforeDrop(t *testing.T) { - vv := bufferedDerpWritesBeforeDrop() - if vv < 32 { - t.Fatalf("got bufferedDerpWritesBeforeDrop=%d, which is < 32", vv) - } - t.Logf("bufferedDerpWritesBeforeDrop = %d", vv) -} - // newWireguard starts up a new wireguard-go device attached to a test tun, and // returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions. func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) { @@ -2110,10 +2200,9 @@ func TestIsWireGuardOnlyPeer(t *testing.T) { defer m.Close() nm := &netmap.NetworkMap{ - Name: "ts", - PrivateKey: m.privateKey, - NodeKey: m.privateKey.Public(), + NodeKey: m.privateKey.Public(), SelfNode: (&tailcfg.Node{ + Name: "ts.", Addresses: []netip.Prefix{tsaip}, }).View(), Peers: nodeViews([]*tailcfg.Node{ @@ -2127,9 +2216,13 @@ func TestIsWireGuardOnlyPeer(t *testing.T) { }, }), } - m.conn.SetNetworkMap(nm) + nv := NodeViewsUpdate{ + SelfNode: nm.SelfNode, + Peers: nm.Peers, + } + m.conn.onNodeViewsUpdate(nv) - cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSubnetRoutes, "") + cfg, err := nmcfg.WGCfg(m.privateKey, nm, t.Logf, netmap.AllowSubnetRoutes, "") if err != nil { t.Fatal(err) } @@ -2171,10 +2264,9 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { defer m.Close() nm := &netmap.NetworkMap{ - Name: "ts", - PrivateKey: m.privateKey, - NodeKey: m.privateKey.Public(), + NodeKey: m.privateKey.Public(), SelfNode: (&tailcfg.Node{ + Name: "ts.", Addresses: []netip.Prefix{tsaip}, }).View(), Peers: nodeViews([]*tailcfg.Node{ @@ -2189,9 +2281,13 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { }, }), } - m.conn.SetNetworkMap(nm) + nv := NodeViewsUpdate{ + SelfNode: nm.SelfNode, + Peers: nm.Peers, + } + m.conn.onNodeViewsUpdate(nv) - cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSubnetRoutes, "") + cfg, err := nmcfg.WGCfg(m.privateKey, nm, t.Logf, netmap.AllowSubnetRoutes, "") if err != nil { t.Fatal(err) } @@ -2226,12 +2322,16 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { // configures WG. func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) { t.Helper() - m.conn.SetNetworkMap(nm) + nv := NodeViewsUpdate{ + SelfNode: nm.SelfNode, + Peers: nm.Peers, + } + m.conn.onNodeViewsUpdate(nv) // Make sure we can't use v6 to avoid test failures. m.conn.noV6.Store(true) // Turn the network map into a wireguard config (for the tailscale internal wireguard device). - cfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSubnetRoutes, "") + cfg, err := nmcfg.WGCfg(m.privateKey, nm, t.Logf, netmap.AllowSubnetRoutes, "") if err != nil { t.Fatal(err) } @@ -2300,10 +2400,9 @@ func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { wgEpV6 := netip.MustParseAddrPort(v6.LocalAddr().String()) nm := &netmap.NetworkMap{ - Name: "ts", - PrivateKey: m.privateKey, - NodeKey: m.privateKey.Public(), + NodeKey: m.privateKey.Public(), SelfNode: (&tailcfg.Node{ + Name: "ts.", Addresses: []netip.Prefix{tsaip}, }).View(), Peers: nodeViews([]*tailcfg.Node{ @@ -2341,7 +2440,7 @@ func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { // Check that we got a valid address set on the first send - this // will be randomly selected, but because we have noV6 set to true, // it will be the IPv4 address. - if !pi.ep.bestAddr.Addr().IsValid() { + if !pi.ep.bestAddr.ap.Addr().IsValid() { t.Fatal("bestaddr was nil") } @@ -2363,7 +2462,7 @@ func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { if len(state.recentPongs) != 1 { t.Errorf("IPv4 address did not have a recentPong entry: got %v, want %v", len(state.recentPongs), 1) } - // Set the latency extremely high so we dont choose endpoint during the next + // Set the latency extremely high so we don't choose endpoint during the next // addrForSendLocked call. state.recentPongs[state.recentPong].latency = time.Second } @@ -2401,12 +2500,12 @@ func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) { t.Fatal("wgkey doesn't exist in peer map") } - if !pi.ep.bestAddr.Addr().IsValid() { + if !pi.ep.bestAddr.ap.Addr().IsValid() { t.Error("no bestAddr address was set") } - if pi.ep.bestAddr.Addr() != wgEp.Addr() { - t.Errorf("bestAddr was not set to the expected IPv4 address: got %v, want %v", pi.ep.bestAddr.Addr().String(), wgEp.Addr()) + if pi.ep.bestAddr.ap.Addr() != wgEp.Addr() { + t.Errorf("bestAddr was not set to the expected IPv4 address: got %v, want %v", pi.ep.bestAddr.ap.Addr().String(), wgEp.Addr()) } if pi.ep.trustBestAddrUntil.IsZero() { @@ -2567,7 +2666,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { sendFollowUpPing bool pingTime mono.Time ep []endpointDetails - want netip.AddrPort + want epAddr }{ { name: "no endpoints", @@ -2576,7 +2675,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { sendFollowUpPing: false, pingTime: testTime, ep: []endpointDetails{}, - want: netip.AddrPort{}, + want: epAddr{}, }, { name: "singular endpoint does not request ping", @@ -2590,7 +2689,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { latency: 100 * time.Millisecond, }, }, - want: netip.MustParseAddrPort("1.1.1.1:111"), + want: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:111")}, }, { name: "ping sent within wireguardPingInterval should not request ping", @@ -2608,7 +2707,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { latency: 2000 * time.Millisecond, }, }, - want: netip.MustParseAddrPort("1.1.1.1:111"), + want: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:111")}, }, { name: "ping sent outside of wireguardPingInterval should request ping", @@ -2626,7 +2725,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { latency: 150 * time.Millisecond, }, }, - want: netip.MustParseAddrPort("1.1.1.1:111"), + want: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:111")}, }, { name: "choose lowest latency for useable IPv4 and IPv6", @@ -2644,7 +2743,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { latency: 10 * time.Millisecond, }, }, - want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + want: epAddr{ap: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222")}, }, { name: "choose IPv6 address when latency is the same for v4 and v6", @@ -2662,7 +2761,7 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { latency: 100 * time.Millisecond, }, }, - want: netip.MustParseAddrPort("[1::1]:567"), + want: epAddr{ap: netip.MustParseAddrPort("[1::1]:567")}, }, } @@ -2682,8 +2781,8 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { endpoint.endpointState[epd.addrPort] = &endpointState{} } udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime) - if udpAddr.IsValid() != test.validAddr { - t.Errorf("udpAddr validity is incorrect; got %v, want %v", udpAddr.IsValid(), test.validAddr) + if udpAddr.ap.IsValid() != test.validAddr { + t.Errorf("udpAddr validity is incorrect; got %v, want %v", udpAddr.ap.IsValid(), test.validAddr) } if shouldPing != test.sendInitialPing { t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendInitialPing) @@ -2715,8 +2814,8 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { if shouldPing != test.sendFollowUpPing { t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendFollowUpPing) } - if endpoint.bestAddr.AddrPort != test.want { - t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want) + if endpoint.bestAddr.epAddr != test.want { + t.Errorf("bestAddr.epAddr is not as expected: got %v, want %v", endpoint.bestAddr.epAddr, test.want) } }) } @@ -2803,7 +2902,7 @@ func TestAddrForPingSizeLocked(t *testing.T) { t.Run(test.desc, func(t *testing.T) { bestAddr := addrQuality{wireMTU: test.mtu} if test.bestAddr { - bestAddr.AddrPort = validUdpAddr + bestAddr.epAddr.ap = validUdpAddr } ep := &endpoint{ derpAddr: validDerpAddr, @@ -2815,10 +2914,10 @@ func TestAddrForPingSizeLocked(t *testing.T) { udpAddr, derpAddr := ep.addrForPingSizeLocked(testTime, test.size) - if test.wantUDP && !udpAddr.IsValid() { + if test.wantUDP && !udpAddr.ap.IsValid() { t.Errorf("%s: udpAddr returned is not valid, won't be sent to UDP address", test.desc) } - if !test.wantUDP && udpAddr.IsValid() { + if !test.wantUDP && udpAddr.ap.IsValid() { t.Errorf("%s: udpAddr returned is valid, discovery will not start", test.desc) } if test.wantDERP && !derpAddr.IsValid() { @@ -2934,7 +3033,7 @@ func TestMaybeSetNearestDERP(t *testing.T) { } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - ht := new(health.Tracker) + ht := health.NewTracker(eventbustest.NewBus(t)) c := newConn(t.Logf) c.myDerp = tt.old c.derpMap = derpMap @@ -2961,37 +3060,1248 @@ func TestMaybeSetNearestDERP(t *testing.T) { } } +func TestShouldRebind(t *testing.T) { + tests := []struct { + err error + ok bool + reason string + }{ + {nil, false, ""}, + {io.EOF, false, ""}, + {io.ErrUnexpectedEOF, false, ""}, + {io.ErrShortBuffer, false, ""}, + {&net.OpError{Err: syscall.EPERM}, true, "operation-not-permitted"}, + {&net.OpError{Err: syscall.EPIPE}, true, "broken-pipe"}, + {&net.OpError{Err: syscall.ENOTCONN}, true, "broken-pipe"}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%s-%v", tt.err, tt.ok), func(t *testing.T) { + if got, reason := shouldRebind(tt.err); got != tt.ok || reason != tt.reason { + t.Errorf("errShouldRebind(%v) = %v, %q; want %v, %q", tt.err, got, reason, tt.ok, tt.reason) + } + }) + } +} + func TestMaybeRebindOnError(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) - err := fmt.Errorf("outer err: %w", syscall.EPERM) + var rebindErrs []error + if runtime.GOOS != "plan9" { + rebindErrs = append(rebindErrs, + &net.OpError{Err: syscall.EPERM}, + &net.OpError{Err: syscall.EPIPE}, + &net.OpError{Err: syscall.ENOTCONN}, + ) + } + + for _, rebindErr := range rebindErrs { + t.Run(fmt.Sprintf("rebind-%s", rebindErr), func(t *testing.T) { + conn := newTestConn(t) + defer conn.Close() + + before := metricRebindCalls.Value() + conn.maybeRebindOnError(rebindErr) + after := metricRebindCalls.Value() + if before+1 != after { + t.Errorf("should rebind on %#v", rebindErr) + } + }) + } - t.Run("darwin-rebind", func(t *testing.T) { - conn := newTestConn(t) - defer conn.Close() - rebound := conn.maybeRebindOnError("darwin", err) - if !rebound { - t.Errorf("darwin should rebind on syscall.EPERM") - } + t.Run("no-frequent-rebind", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + if runtime.GOOS != "plan9" { + err := fmt.Errorf("outer err: %w", syscall.EPERM) + conn := newTestConn(t) + defer conn.Close() + lastRebindTime := time.Now().Add(-1 * time.Second) + conn.lastErrRebind.Store(lastRebindTime) + before := metricRebindCalls.Value() + conn.maybeRebindOnError(err) + after := metricRebindCalls.Value() + if before != after { + t.Errorf("should not rebind within 5 seconds of last") + } + + // ensure that rebinds are performed and store an updated last + // rebind time. + time.Sleep(6 * time.Second) + + conn.maybeRebindOnError(err) + newTime := conn.lastErrRebind.Load() + if newTime == lastRebindTime { + t.Errorf("expected a rebind to occur") + } + if newTime.Sub(lastRebindTime) < 5*time.Second { + t.Errorf("expected at least 5 seconds between %s and %s", lastRebindTime, newTime) + } + } + + }) }) +} + +func newTestConnAndRegistry(t *testing.T) (*Conn, *usermetric.Registry, func()) { + t.Helper() + bus := eventbus.New() + netMon := must.Get(netmon.New(bus, t.Logf)) + + reg := new(usermetric.Registry) + + conn := must.Get(NewConn(Options{ + DisablePortMapper: true, + Logf: t.Logf, + NetMon: netMon, + EventBus: bus, + Metrics: reg, + })) + + return conn, reg, func() { + bus.Close() + netMon.Close() + conn.Close() + } +} + +func TestNetworkSendErrors(t *testing.T) { + t.Run("network-down", func(t *testing.T) { + // TODO(alexc): This test case fails on Windows because it never + // successfully sends the first packet: + // + // expected successful Send, got err: "write udp4 0.0.0.0:57516->127.0.0.1:9999: + // wsasendto: The requested address is not valid in its context." + // + // It would be nice to run this test on Windows, but I was already + // on a side quest and it was unclear if this test has ever worked + // correctly on Windows. + if runtime.GOOS == "windows" { + t.Skipf("skipping on %s", runtime.GOOS) + } + + conn, reg, close := newTestConnAndRegistry(t) + defer close() + + buffs := [][]byte{{00, 00, 00, 00, 00, 00, 00, 00}} + ep := &lazyEndpoint{ + src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:9999")}, + } + offset := 8 - t.Run("linux-not-rebind", func(t *testing.T) { - conn := newTestConn(t) - defer conn.Close() - rebound := conn.maybeRebindOnError("linux", err) - if rebound { - t.Errorf("linux should not rebind on syscall.EPERM") + // Check this is a valid payload to send when the network is up + conn.SetNetworkUp(true) + if err := conn.Send(buffs, ep, offset); err != nil { + t.Errorf("expected successful Send, got err: %q", err) + } + + // Now we know the payload would be sent if the network is up, + // send it again when the network is down + conn.SetNetworkUp(false) + err := conn.Send(buffs, ep, offset) + if err == nil { + t.Error("expected error, got nil") + } + resp := httptest.NewRecorder() + reg.Handler(resp, new(http.Request)) + if !strings.Contains(resp.Body.String(), `tailscaled_outbound_dropped_packets_total{reason="error"} 1`) { + t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String()) } }) - t.Run("no-frequent-rebind", func(t *testing.T) { - conn := newTestConn(t) - defer conn.Close() - conn.lastEPERMRebind.Store(time.Now().Add(-1 * time.Second)) - rebound := conn.maybeRebindOnError("darwin", err) - if rebound { - t.Errorf("darwin should not rebind on syscall.EPERM within 5 seconds of last") + t.Run("invalid-payload", func(t *testing.T) { + conn, reg, close := newTestConnAndRegistry(t) + defer close() + + conn.SetNetworkUp(false) + err := conn.Send([][]byte{{00}}, &lazyEndpoint{}, 0) + if err == nil { + t.Error("expected error, got nil") + } + resp := httptest.NewRecorder() + reg.Handler(resp, new(http.Request)) + if !strings.Contains(resp.Body.String(), `tailscaled_outbound_dropped_packets_total{reason="error"} 1`) { + t.Errorf("expected invalid payload to increment packet dropped metric; got %q", resp.Body.String()) } }) } + +func Test_packetLooksLike(t *testing.T) { + discoPub := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 30: 30, 31: 31})) + nakedDisco := make([]byte, 0, 512) + nakedDisco = append(nakedDisco, disco.Magic...) + nakedDisco = discoPub.AppendTo(nakedDisco) + + geneveEncapDisco := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + Control: true, + } + gh.VNI.Set(1) + err := gh.Encode(geneveEncapDisco) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDisco[packet.GeneveFixedHeaderLength:], nakedDisco) + + nakedWireGuardInitiation := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardInitiation, device.MessageInitiationType) + nakedWireGuardResponse := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardResponse, device.MessageResponseType) + nakedWireGuardCookieReply := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardCookieReply, device.MessageCookieReplyType) + nakedWireGuardTransport := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardTransport, device.MessageTransportType) + + geneveEncapWireGuard := make([]byte, packet.GeneveFixedHeaderLength+len(nakedWireGuardInitiation)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolWireGuard, + Control: true, + } + gh.VNI.Set(1) + err = gh.Encode(geneveEncapWireGuard) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapWireGuard[packet.GeneveFixedHeaderLength:], nakedWireGuardInitiation) + + geneveEncapDiscoNonZeroGeneveVersion := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 1, + Protocol: packet.GeneveProtocolDisco, + Control: true, + } + gh.VNI.Set(1) + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVersion) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDiscoNonZeroGeneveVersion[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveReservedBits := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + Control: true, + } + gh.VNI.Set(1) + err = gh.Encode(geneveEncapDiscoNonZeroGeneveReservedBits) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveReservedBits[1] |= 0x3F + copy(geneveEncapDiscoNonZeroGeneveReservedBits[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveVNILSB := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + Control: true, + } + gh.VNI.Set(1) + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVNILSB) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveVNILSB[7] |= 0xFF + copy(geneveEncapDiscoNonZeroGeneveVNILSB[packet.GeneveFixedHeaderLength:], nakedDisco) + + tests := []struct { + name string + msg []byte + wantPacketLooksLikeType packetLooksLikeType + wantIsGeneveEncap bool + }{ + { + name: "STUN binding success response", + msg: stun.Response(stun.NewTxID(), netip.MustParseAddrPort("127.0.0.1:1")), + wantPacketLooksLikeType: packetLooksLikeSTUNBinding, + wantIsGeneveEncap: false, + }, + { + name: "naked disco", + msg: nakedDisco, + wantPacketLooksLikeType: packetLooksLikeDisco, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco", + msg: geneveEncapDisco, + wantPacketLooksLikeType: packetLooksLikeDisco, + wantIsGeneveEncap: true, + }, + { + name: "geneve encap too short disco", + msg: geneveEncapDisco[:len(geneveEncapDisco)-key.DiscoPublicRawLen], + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve version", + msg: geneveEncapDiscoNonZeroGeneveVersion, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve reserved bits", + msg: geneveEncapDiscoNonZeroGeneveReservedBits, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve vni lsb", + msg: geneveEncapDiscoNonZeroGeneveVNILSB, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap wireguard", + msg: geneveEncapWireGuard, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: true, + }, + { + name: "naked WireGuard Initiation type", + msg: nakedWireGuardInitiation, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Response type", + msg: nakedWireGuardResponse, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Cookie Reply type", + msg: nakedWireGuardCookieReply, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Transport type", + msg: nakedWireGuardTransport, + wantPacketLooksLikeType: packetLooksLikeWireGuard, + wantIsGeneveEncap: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPacketLooksLikeType, gotIsGeneveEncap := packetLooksLike(tt.msg) + if gotPacketLooksLikeType != tt.wantPacketLooksLikeType { + t.Errorf("packetLooksLike() gotPacketLooksLikeType = %v, want %v", gotPacketLooksLikeType, tt.wantPacketLooksLikeType) + } + if gotIsGeneveEncap != tt.wantIsGeneveEncap { + t.Errorf("packetLooksLike() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap) + } + }) + } +} + +func Test_looksLikeInitiationMsg(t *testing.T) { + // initMsg was captured as the first packet from a WireGuard "session" + initMsg, err := hex.DecodeString("01000000d9205f67915a500e377b409e0c3d97ca91e68654b95952de965e75df491000cce00632678cd9e8c8525556aa8daf24e6cfc44c48812bb560ff3c1c5dee061b3f833dfaa48acf13b64bd1e0027aa4d977a3721b82fd6072338702fc3193651404980ad46dae2869ba6416cc0eb38621a4140b5b918eb6402b697202adb3002a6d00000000000000000000000000000000") + if err != nil { + t.Fatal(err) + } + if len(initMsg) != device.MessageInitiationSize { + t.Fatalf("initMsg is not %d bytes long", device.MessageInitiationSize) + } + initMsgSizeTransportType := make([]byte, len(initMsg)) + copy(initMsgSizeTransportType, initMsg) + binary.LittleEndian.PutUint32(initMsgSizeTransportType, device.MessageTransportType) + tests := []struct { + name string + b []byte + want bool + }{ + { + name: "valid initiation", + b: initMsg, + want: true, + }, + { + name: "invalid message type field", + b: initMsgSizeTransportType, + want: false, + }, + { + name: "too small", + b: initMsg[:device.MessageInitiationSize-1], + want: false, + }, + { + name: "too big", + b: append(initMsg, 0), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := looksLikeInitiationMsg(tt.b); got != tt.want { + t.Errorf("looksLikeInitiationMsg() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_nodeHasCap(t *testing.T) { + nodeAOnlyIPv4 := &tailcfg.Node{ + ID: 1, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + }, + } + nodeBOnlyIPv6 := nodeAOnlyIPv4.Clone() + nodeBOnlyIPv6.Addresses[0] = netip.MustParsePrefix("::1/128") + + nodeCOnlyIPv4 := &tailcfg.Node{ + ID: 2, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("2.2.2.2/32"), + }, + } + nodeDOnlyIPv6 := nodeCOnlyIPv4.Clone() + nodeDOnlyIPv6.Addresses[0] = netip.MustParsePrefix("::2/128") + + tests := []struct { + name string + filt *filter.Filter + src tailcfg.NodeView + dst tailcfg.NodeView + cap tailcfg.PeerCapability + want bool + }{ + { + name: "match v4", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("1.1.1.1/32"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeCOnlyIPv4.View(), + dst: nodeAOnlyIPv4.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: true, + }, + { + name: "match v6", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("::1/128"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeDOnlyIPv6.View(), + dst: nodeBOnlyIPv6.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: true, + }, + { + name: "no match CapMatch Dst", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("::3/128"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeDOnlyIPv6.View(), + dst: nodeBOnlyIPv6.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + { + name: "no match peer cap", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("::2/128")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("::1/128"), + Cap: tailcfg.PeerCapabilityIngress, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeDOnlyIPv6.View(), + dst: nodeBOnlyIPv6.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + { + name: "nil src", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("1.1.1.1/32"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: tailcfg.NodeView{}, + dst: nodeAOnlyIPv4.View(), + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + { + name: "nil dst", + filt: filter.New([]filtertype.Match{ + { + Srcs: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + Caps: []filtertype.CapMatch{ + { + Dst: netip.MustParsePrefix("1.1.1.1/32"), + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + src: nodeCOnlyIPv4.View(), + dst: tailcfg.NodeView{}, + cap: tailcfg.PeerCapabilityRelayTarget, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := nodeHasCap(tt.filt, tt.src, tt.dst, tt.cap); got != tt.want { + t.Errorf("nodeHasCap() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestConn_onNodeViewsUpdate_updateRelayServersSet(t *testing.T) { + peerNodeCandidateRelay := &tailcfg.Node{ + Cap: 121, + ID: 1, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + }, + HomeDERP: 1, + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + } + + peerNodeNotCandidateRelayCapVer := &tailcfg.Node{ + Cap: 120, // intentionally lower to fail capVer check + ID: 1, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("1.1.1.1/32"), + }, + HomeDERP: 1, + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + } + + selfNode := &tailcfg.Node{ + Cap: 120, // intentionally lower than capVerIsRelayCapable to verify self check + ID: 2, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("2.2.2.2/32"), + }, + HomeDERP: 2, + Key: key.NewNode().Public(), + DiscoKey: key.NewDisco().Public(), + } + + selfNodeNodeAttrDisableRelayClient := selfNode.Clone() + selfNodeNodeAttrDisableRelayClient.CapMap = make(tailcfg.NodeCapMap) + selfNodeNodeAttrDisableRelayClient.CapMap[tailcfg.NodeAttrDisableRelayClient] = nil + + selfNodeNodeAttrOnlyTCP443 := selfNode.Clone() + selfNodeNodeAttrOnlyTCP443.CapMap = make(tailcfg.NodeCapMap) + selfNodeNodeAttrOnlyTCP443.CapMap[tailcfg.NodeAttrOnlyTCP443] = nil + + tests := []struct { + name string + filt *filter.Filter + self tailcfg.NodeView + peers []tailcfg.NodeView + wantRelayServers set.Set[candidatePeerRelay] + wantRelayClientEnabled bool + }{ + { + name: "candidate relay server", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeCandidateRelay.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNode.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNode.View(), + peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()}, + wantRelayServers: set.SetOf([]candidatePeerRelay{ + { + nodeKey: peerNodeCandidateRelay.Key, + discoKey: peerNodeCandidateRelay.DiscoKey, + derpHomeRegionID: 1, + }, + }), + wantRelayClientEnabled: true, + }, + { + name: "no candidate relay server because self has tailcfg.NodeAttrDisableRelayClient", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeCandidateRelay.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNodeNodeAttrDisableRelayClient.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNodeNodeAttrDisableRelayClient.View(), + peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()}, + wantRelayServers: make(set.Set[candidatePeerRelay]), + wantRelayClientEnabled: false, + }, + { + name: "no candidate relay server because self has tailcfg.NodeAttrOnlyTCP443", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeCandidateRelay.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNodeNodeAttrOnlyTCP443.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNodeNodeAttrOnlyTCP443.View(), + peers: []tailcfg.NodeView{peerNodeCandidateRelay.View()}, + wantRelayServers: make(set.Set[candidatePeerRelay]), + wantRelayClientEnabled: false, + }, + { + name: "self candidate relay server", + filt: filter.New([]filtertype.Match{ + { + Srcs: selfNode.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNode.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNode.View(), + peers: []tailcfg.NodeView{selfNode.View()}, + wantRelayServers: set.SetOf([]candidatePeerRelay{ + { + nodeKey: selfNode.Key, + discoKey: selfNode.DiscoKey, + derpHomeRegionID: 2, + }, + }), + wantRelayClientEnabled: true, + }, + { + name: "no candidate relay server", + filt: filter.New([]filtertype.Match{ + { + Srcs: peerNodeNotCandidateRelayCapVer.Addresses, + Caps: []filtertype.CapMatch{ + { + Dst: selfNode.Addresses[0], + Cap: tailcfg.PeerCapabilityRelayTarget, + }, + }, + }, + }, nil, nil, nil, nil, nil), + self: selfNode.View(), + peers: []tailcfg.NodeView{peerNodeNotCandidateRelayCapVer.View()}, + wantRelayServers: make(set.Set[candidatePeerRelay]), + wantRelayClientEnabled: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := newConn(t.Logf) + c.filt = tt.filt + if len(tt.wantRelayServers) == 0 { + // So we can verify it gets flipped back. + c.hasPeerRelayServers.Store(true) + } + + c.onNodeViewsUpdate(NodeViewsUpdate{ + SelfNode: tt.self, + Peers: tt.peers, + }) + got := c.relayManager.getServers() + if !got.Equal(tt.wantRelayServers) { + t.Fatalf("got: %v != want: %v", got, tt.wantRelayServers) + } + if len(tt.wantRelayServers) > 0 != c.hasPeerRelayServers.Load() { + t.Fatalf("c.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", c.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0) + } + if c.relayClientEnabled != tt.wantRelayClientEnabled { + t.Fatalf("c.relayClientEnabled: %v != wantRelayClientEnabled: %v", c.relayClientEnabled, tt.wantRelayClientEnabled) + } + }) + } +} + +func TestConn_receiveIP(t *testing.T) { + looksLikeNakedDisco := make([]byte, 0, len(disco.Magic)+key.DiscoPublicRawLen) + looksLikeNakedDisco = append(looksLikeNakedDisco, disco.Magic...) + looksLikeNakedDisco = looksLikeNakedDisco[:cap(looksLikeNakedDisco)] + + looksLikeGeneveDisco := make([]byte, packet.GeneveFixedHeaderLength+len(looksLikeNakedDisco)) + gh := packet.GeneveHeader{ + Protocol: packet.GeneveProtocolDisco, + } + gh.VNI.Set(1) + err := gh.Encode(looksLikeGeneveDisco) + if err != nil { + t.Fatal(err) + } + copy(looksLikeGeneveDisco[packet.GeneveFixedHeaderLength:], looksLikeNakedDisco) + + looksLikeSTUNBinding := stun.Response(stun.NewTxID(), netip.MustParseAddrPort("127.0.0.1:7777")) + + findMetricByName := func(name string) *clientmetric.Metric { + for _, metric := range clientmetric.Metrics() { + if metric.Name() == name { + return metric + } + } + t.Fatalf("failed to find metric with name: %v", name) + return nil + } + + looksLikeNakedWireGuardInit := make([]byte, device.MessageInitiationSize) + binary.LittleEndian.PutUint32(looksLikeNakedWireGuardInit, device.MessageInitiationType) + + looksLikeGeneveWireGuardInit := make([]byte, packet.GeneveFixedHeaderLength+device.MessageInitiationSize) + gh = packet.GeneveHeader{ + Protocol: packet.GeneveProtocolWireGuard, + } + gh.VNI.Set(1) + err = gh.Encode(looksLikeGeneveWireGuardInit) + if err != nil { + t.Fatal(err) + } + copy(looksLikeGeneveWireGuardInit[packet.GeneveFixedHeaderLength:], looksLikeNakedWireGuardInit) + + newPeerMapInsertableEndpoint := func(lastRecvWG mono.Time) *endpoint { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + lastRecvWG: lastRecvWG, + } + ep.disco.Store(&endpointDisco{ + key: key.NewDisco().Public(), + }) + return ep + } + + tests := []struct { + name string + // A copy of b is used as input, tests may re-use the same value. + b []byte + ipp netip.AddrPort + // cache must be non-nil, and must not be reused across tests. If + // cache.de is non-nil after receiveIP(), then we verify it is equal to + // wantEndpointType. + cache *epAddrEndpointCache + // If true, wantEndpointType is inserted into the [peerMap]. + insertWantEndpointTypeInPeerMap bool + // If insertWantEndpointTypeInPeerMap is true, use this [epAddr] for it + // in the [peerMap.setNodeKeyForEpAddr] call. + peerMapEpAddr epAddr + // If [*endpoint] then we expect 'got' to be the same [*endpoint]. If + // [*lazyEndpoint] and [*lazyEndpoint.maybeEP] is non-nil, we expect + // got.maybeEP to also be non-nil. Must not be reused across tests. + wantEndpointType wgconn.Endpoint + wantSize int + wantIsGeneveEncap bool + wantOk bool + wantMetricInc *clientmetric.Metric + wantNoteRecvActivityCalled bool + }{ + { + name: "naked disco", + b: looksLikeNakedDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, + wantNoteRecvActivityCalled: false, + }, + { + name: "geneve encap disco", + b: looksLikeGeneveDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, + wantNoteRecvActivityCalled: false, + }, + { + name: "STUN binding", + b: looksLikeSTUNBinding, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"), + wantNoteRecvActivityCalled: false, + }, + { + name: "naked WireGuard init lazyEndpoint empty peerMap", + b: looksLikeNakedWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeNakedWireGuardInit), + wantIsGeneveEncap: false, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: false, + }, + { + name: "naked WireGuard init endpoint matching peerMap entry", + b: looksLikeNakedWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + insertWantEndpointTypeInPeerMap: true, + peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")}, + wantEndpointType: newPeerMapInsertableEndpoint(0), + wantSize: len(looksLikeNakedWireGuardInit), + wantIsGeneveEncap: false, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: true, + }, + { + name: "geneve WireGuard init lazyEndpoint empty peerMap", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: false, + }, + { + name: "geneve WireGuard init lazyEndpoint matching peerMap activity noted", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + insertWantEndpointTypeInPeerMap: true, + peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: gh.VNI}, + wantEndpointType: &lazyEndpoint{ + maybeEP: newPeerMapInsertableEndpoint(0), + }, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: true, + }, + { + name: "geneve WireGuard init lazyEndpoint matching peerMap no activity noted", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + insertWantEndpointTypeInPeerMap: true, + peerMapEpAddr: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777"), vni: gh.VNI}, + wantEndpointType: &lazyEndpoint{ + maybeEP: newPeerMapInsertableEndpoint(mono.Now().Add(time.Hour * 24)), + }, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, + wantNoteRecvActivityCalled: false, + }, + // TODO(jwhited): verify cache.de is used when conditions permit + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + noteRecvActivityCalled := false + metricBefore := int64(0) + if tt.wantMetricInc != nil { + metricBefore = tt.wantMetricInc.Value() + } + + // Init Conn. + c := &Conn{ + privateKey: key.NewNode(), + netChecker: &netcheck.Client{}, + peerMap: newPeerMap(), + } + c.havePrivateKey.Store(true) + c.noteRecvActivity = func(public key.NodePublic) { + noteRecvActivityCalled = true + } + var counts netlogtype.CountsByConnection + c.SetConnectionCounter(counts.Add) + + if tt.insertWantEndpointTypeInPeerMap { + var insertEPIntoPeerMap *endpoint + switch ep := tt.wantEndpointType.(type) { + case *endpoint: + insertEPIntoPeerMap = ep + case *lazyEndpoint: + insertEPIntoPeerMap = ep.maybeEP + default: + t.Fatal("unexpected tt.wantEndpointType concrete type") + } + insertEPIntoPeerMap.c = c + c.peerMap.upsertEndpoint(insertEPIntoPeerMap, key.DiscoPublic{}) + c.peerMap.setNodeKeyForEpAddr(tt.peerMapEpAddr, insertEPIntoPeerMap.publicKey) + } + + // Allow the same input packet to be used across tests, receiveIP() + // may mutate. + inputPacket := make([]byte, len(tt.b)) + copy(inputPacket, tt.b) + + got, gotSize, gotIsGeneveEncap, gotOk := c.receiveIP(inputPacket, tt.ipp, tt.cache) + if (tt.wantEndpointType == nil) != (got == nil) { + t.Errorf("receiveIP() (tt.wantEndpointType == nil): %v != (got == nil): %v", tt.wantEndpointType == nil, got == nil) + } + if tt.wantEndpointType != nil && reflect.TypeOf(got).String() != reflect.TypeOf(tt.wantEndpointType).String() { + t.Errorf("receiveIP() got = %v, want %v", reflect.TypeOf(got).String(), reflect.TypeOf(tt.wantEndpointType).String()) + } else { + switch ep := tt.wantEndpointType.(type) { + case *endpoint: + if ep != got.(*endpoint) { + t.Errorf("receiveIP() want [*endpoint]: %p != got [*endpoint]: %p", ep, got) + } + case *lazyEndpoint: + if ep.maybeEP != nil && ep.maybeEP != got.(*lazyEndpoint).maybeEP { + t.Errorf("receiveIP() want [*lazyEndpoint.maybeEP]: %p != got [*lazyEndpoint.maybeEP] %p", ep, got) + } + } + } + + if gotSize != tt.wantSize { + t.Errorf("receiveIP() gotSize = %v, want %v", gotSize, tt.wantSize) + } + if gotIsGeneveEncap != tt.wantIsGeneveEncap { + t.Errorf("receiveIP() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap) + } + if gotOk != tt.wantOk { + t.Errorf("receiveIP() gotOk = %v, want %v", gotOk, tt.wantOk) + } + if tt.wantMetricInc != nil && tt.wantMetricInc.Value() != metricBefore+1 { + t.Errorf("receiveIP() metric %v not incremented", tt.wantMetricInc.Name()) + } + if tt.wantNoteRecvActivityCalled != noteRecvActivityCalled { + t.Errorf("receiveIP() noteRecvActivityCalled = %v, want %v", noteRecvActivityCalled, tt.wantNoteRecvActivityCalled) + } + + if tt.cache.de != nil { + switch ep := got.(type) { + case *endpoint: + if tt.cache.de != ep { + t.Errorf("receiveIP() cache populated with [*endpoint] %p, want %p", tt.cache.de, ep) + } + case *lazyEndpoint: + if tt.cache.de != ep.maybeEP { + t.Errorf("receiveIP() cache populated with [*endpoint] %p, want (lazyEndpoint.maybeEP) %p", tt.cache.de, ep.maybeEP) + } + default: + t.Fatal("receiveIP() unexpected [conn.Endpoint] type") + } + } + + // Verify physical rx stats + wantNonzeroRxStats := false + gotPhy := counts.Clone() + switch ep := tt.wantEndpointType.(type) { + case *lazyEndpoint: + if ep.maybeEP != nil { + wantNonzeroRxStats = true + } + case *endpoint: + wantNonzeroRxStats = true + } + if tt.wantOk && wantNonzeroRxStats { + wantRxBytes := uint64(tt.wantSize) + if tt.wantIsGeneveEncap { + wantRxBytes += packet.GeneveFixedHeaderLength + } + wantPhy := map[netlogtype.Connection]netlogtype.Counts{ + {Dst: tt.ipp}: { + RxPackets: 1, + RxBytes: wantRxBytes, + }, + } + if d := cmp.Diff(gotPhy, wantPhy); d != "" { + t.Errorf("receiveIP() stats mismatch (-got +want):\n%s", d) + } + } else { + if len(gotPhy) != 0 { + t.Errorf("receiveIP() unexpected nonzero physical count stats: %+v", gotPhy) + } + } + }) + } +} + +func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { + tests := []struct { + name string + callWithPeerMapKey bool + maybeEPMatchingKey bool + wantNoteRecvActivityCalled bool + }{ + { + name: "noteRecvActivity called", + callWithPeerMapKey: true, + maybeEPMatchingKey: false, + wantNoteRecvActivityCalled: true, + }, + { + name: "maybeEP early return", + callWithPeerMapKey: true, + maybeEPMatchingKey: true, + wantNoteRecvActivityCalled: false, + }, + { + name: "not in peerMap early return", + callWithPeerMapKey: false, + maybeEPMatchingKey: false, + wantNoteRecvActivityCalled: false, + }, + { + name: "not in peerMap maybeEP early return", + callWithPeerMapKey: false, + maybeEPMatchingKey: true, + wantNoteRecvActivityCalled: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + } + ep.disco.Store(&endpointDisco{ + key: key.NewDisco().Public(), + }) + + var noteRecvActivityCalledFor key.NodePublic + conn := newConn(t.Logf) + conn.noteRecvActivity = func(public key.NodePublic) { + // wireguard-go will call into ParseEndpoint if the "real" + // noteRecvActivity ends up JIT configuring the peer. Mimic that + // to ensure there are no deadlocks around conn.mu. + // See tailscale/tailscale#16651 & http://go/corp#30836 + _, err := conn.ParseEndpoint(ep.publicKey.UntypedHexString()) + if err != nil { + t.Fatalf("ParseEndpoint() err: %v", err) + } + noteRecvActivityCalledFor = public + } + ep.c = conn + + var pubKey [32]byte + if tt.callWithPeerMapKey { + copy(pubKey[:], ep.publicKey.AppendTo(nil)) + } + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + le := &lazyEndpoint{ + c: conn, + } + if tt.maybeEPMatchingKey { + le.maybeEP = ep + } + le.InitiationMessagePublicKey(pubKey) + want := key.NodePublic{} + if tt.wantNoteRecvActivityCalled { + want = ep.publicKey + } + if noteRecvActivityCalledFor.Compare(want) != 0 { + t.Fatalf("noteRecvActivityCalledFor = %v, want %v", noteRecvActivityCalledFor, want) + } + }) + } +} + +func Test_lazyEndpoint_FromPeer(t *testing.T) { + tests := []struct { + name string + callWithPeerMapKey bool + maybeEPMatchingKey bool + wantEpAddrInPeerMap bool + }{ + { + name: "epAddr in peerMap", + callWithPeerMapKey: true, + maybeEPMatchingKey: false, + wantEpAddrInPeerMap: true, + }, + { + name: "maybeEP early return", + callWithPeerMapKey: true, + maybeEPMatchingKey: true, + wantEpAddrInPeerMap: false, + }, + { + name: "not in peerMap early return", + callWithPeerMapKey: false, + maybeEPMatchingKey: false, + wantEpAddrInPeerMap: false, + }, + { + name: "not in peerMap maybeEP early return", + callWithPeerMapKey: false, + maybeEPMatchingKey: true, + wantEpAddrInPeerMap: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + } + ep.disco.Store(&endpointDisco{ + key: key.NewDisco().Public(), + }) + conn := newConn(t.Logf) + ep.c = conn + + var pubKey [32]byte + if tt.callWithPeerMapKey { + copy(pubKey[:], ep.publicKey.AppendTo(nil)) + } + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + le := &lazyEndpoint{ + c: conn, + src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")}, + } + if tt.maybeEPMatchingKey { + le.maybeEP = ep + } + le.FromPeer(pubKey) + if tt.wantEpAddrInPeerMap { + gotEP, ok := conn.peerMap.endpointForEpAddr(le.src) + if !ok { + t.Errorf("lazyEndpoint epAddr not found in peerMap") + } else if gotEP != ep { + t.Errorf("gotEP: %p != ep: %p", gotEP, ep) + } + } else if len(conn.peerMap.byEpAddr) != 0 { + t.Errorf("unexpected epAddr in peerMap") + } + }) + } +} + +func TestRotateDiscoKey(t *testing.T) { + c := newConn(t.Logf) + + oldPrivate, oldPublic := c.discoAtomic.Pair() + oldShort := c.discoAtomic.Short() + + if oldPublic != oldPrivate.Public() { + t.Fatalf("old public key doesn't match old private key") + } + if oldShort != oldPublic.ShortString() { + t.Fatalf("old short string doesn't match old public key") + } + + testDiscoKey := key.NewDisco().Public() + c.mu.Lock() + c.discoInfo[testDiscoKey] = &discoInfo{ + discoKey: testDiscoKey, + discoShort: testDiscoKey.ShortString(), + } + if len(c.discoInfo) != 1 { + t.Fatalf("expected 1 discoInfo entry, got %d", len(c.discoInfo)) + } + c.mu.Unlock() + + c.RotateDiscoKey() + + newPrivate, newPublic := c.discoAtomic.Pair() + newShort := c.discoAtomic.Short() + + if newPublic.Compare(oldPublic) == 0 { + t.Fatalf("disco key didn't change after rotation") + } + if newShort == oldShort { + t.Fatalf("short string didn't change after rotation") + } + + if newPublic != newPrivate.Public() { + t.Fatalf("new public key doesn't match new private key") + } + if newShort != newPublic.ShortString() { + t.Fatalf("new short string doesn't match new public key") + } + + c.mu.Lock() + if len(c.discoInfo) != 0 { + t.Fatalf("expected discoInfo to be cleared, got %d entries", len(c.discoInfo)) + } + c.mu.Unlock() +} + +func TestRotateDiscoKeyMultipleTimes(t *testing.T) { + c := newConn(t.Logf) + + keys := make([]key.DiscoPublic, 0, 5) + keys = append(keys, c.discoAtomic.Public()) + + for i := 0; i < 4; i++ { + c.RotateDiscoKey() + newKey := c.discoAtomic.Public() + + for j, oldKey := range keys { + if newKey.Compare(oldKey) == 0 { + t.Fatalf("rotation %d produced same key as rotation %d", i+1, j) + } + } + + keys = append(keys, newKey) + } +} diff --git a/wgengine/magicsock/peermap.go b/wgengine/magicsock/peermap.go index e1c7db1f6..136353563 100644 --- a/wgengine/magicsock/peermap.go +++ b/wgengine/magicsock/peermap.go @@ -4,8 +4,6 @@ package magicsock import ( - "net/netip" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/util/set" @@ -15,17 +13,17 @@ import ( // peer. type peerInfo struct { ep *endpoint // always non-nil. - // ipPorts is an inverted version of peerMap.byIPPort (below), so + // epAddrs is an inverted version of peerMap.byEpAddr (below), so // that when we're deleting this node, we can rapidly find out the - // keys that need deleting from peerMap.byIPPort without having to - // iterate over every IPPort known for any peer. - ipPorts set.Set[netip.AddrPort] + // keys that need deleting from peerMap.byEpAddr without having to + // iterate over every epAddr known for any peer. + epAddrs set.Set[epAddr] } func newPeerInfo(ep *endpoint) *peerInfo { return &peerInfo{ ep: ep, - ipPorts: set.Set[netip.AddrPort]{}, + epAddrs: set.Set[epAddr]{}, } } @@ -35,9 +33,21 @@ func newPeerInfo(ep *endpoint) *peerInfo { // It doesn't do any locking; all access must be done with Conn.mu held. type peerMap struct { byNodeKey map[key.NodePublic]*peerInfo - byIPPort map[netip.AddrPort]*peerInfo + byEpAddr map[epAddr]*peerInfo byNodeID map[tailcfg.NodeID]*peerInfo + // relayEpAddrByNodeKey ensures we only hold a single relay + // [epAddr] (vni.isSet()) for a given node key in byEpAddr, vs letting them + // grow unbounded. Relay [epAddr]'s are dynamically created by + // [relayManager] during path discovery, and are only useful to track in + // peerMap so long as they are the endpoint.bestAddr. [relayManager] handles + // all creation and initial probing responsibilities otherwise, and it does + // not depend on [peerMap]. + // + // Note: This doesn't address unbounded growth of non-relay epAddr's in + // byEpAddr. That issue is being tracked in http://go/corp/29422. + relayEpAddrByNodeKey map[key.NodePublic]epAddr + // nodesOfDisco contains the set of nodes that are using a // DiscoKey. Usually those sets will be just one node. nodesOfDisco map[key.DiscoPublic]set.Set[key.NodePublic] @@ -45,10 +55,11 @@ type peerMap struct { func newPeerMap() peerMap { return peerMap{ - byNodeKey: map[key.NodePublic]*peerInfo{}, - byIPPort: map[netip.AddrPort]*peerInfo{}, - byNodeID: map[tailcfg.NodeID]*peerInfo{}, - nodesOfDisco: map[key.DiscoPublic]set.Set[key.NodePublic]{}, + byNodeKey: map[key.NodePublic]*peerInfo{}, + byEpAddr: map[epAddr]*peerInfo{}, + byNodeID: map[tailcfg.NodeID]*peerInfo{}, + relayEpAddrByNodeKey: map[key.NodePublic]epAddr{}, + nodesOfDisco: map[key.DiscoPublic]set.Set[key.NodePublic]{}, } } @@ -88,10 +99,10 @@ func (m *peerMap) endpointForNodeID(nodeID tailcfg.NodeID) (ep *endpoint, ok boo return nil, false } -// endpointForIPPort returns the endpoint for the peer we -// believe to be at ipp, or nil if we don't know of any such peer. -func (m *peerMap) endpointForIPPort(ipp netip.AddrPort) (ep *endpoint, ok bool) { - if info, ok := m.byIPPort[ipp]; ok { +// endpointForEpAddr returns the endpoint for the peer we +// believe to be at addr, or nil if we don't know of any such peer. +func (m *peerMap) endpointForEpAddr(addr epAddr) (ep *endpoint, ok bool) { + if info, ok := m.byEpAddr[addr]; ok { return info.ep, true } return nil, false @@ -148,10 +159,10 @@ func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { // TODO(raggi,catzkorn): this could mean that if a "isWireguardOnly" // peer has, say, 192.168.0.2 and so does a tailscale peer, the // wireguard one will win. That may not be the outcome that we want - - // perhaps we should prefer bestAddr.AddrPort if it is set? + // perhaps we should prefer bestAddr.epAddr.ap if it is set? // see tailscale/tailscale#7994 for ipp := range ep.endpointState { - m.setNodeKeyForIPPort(ipp, ep.publicKey) + m.setNodeKeyForEpAddr(epAddr{ap: ipp}, ep.publicKey) } return } @@ -163,20 +174,31 @@ func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { discoSet.Add(ep.publicKey) } -// setNodeKeyForIPPort makes future peer lookups by ipp return the +// setNodeKeyForEpAddr makes future peer lookups by addr return the // same endpoint as a lookup by nk. // -// This should only be called with a fully verified mapping of ipp to +// This should only be called with a fully verified mapping of addr to // nk, because calling this function defines the endpoint we hand to -// WireGuard for packets received from ipp. -func (m *peerMap) setNodeKeyForIPPort(ipp netip.AddrPort, nk key.NodePublic) { - if pi := m.byIPPort[ipp]; pi != nil { - delete(pi.ipPorts, ipp) - delete(m.byIPPort, ipp) +// WireGuard for packets received from addr. +func (m *peerMap) setNodeKeyForEpAddr(addr epAddr, nk key.NodePublic) { + if pi := m.byEpAddr[addr]; pi != nil { + delete(pi.epAddrs, addr) + delete(m.byEpAddr, addr) + if addr.vni.IsSet() { + delete(m.relayEpAddrByNodeKey, pi.ep.publicKey) + } } if pi, ok := m.byNodeKey[nk]; ok { - pi.ipPorts.Add(ipp) - m.byIPPort[ipp] = pi + if addr.vni.IsSet() { + relay, ok := m.relayEpAddrByNodeKey[nk] + if ok { + delete(pi.epAddrs, relay) + delete(m.byEpAddr, relay) + } + m.relayEpAddrByNodeKey[nk] = addr + } + pi.epAddrs.Add(addr) + m.byEpAddr[addr] = pi } } @@ -203,7 +225,8 @@ func (m *peerMap) deleteEndpoint(ep *endpoint) { // Unexpected. But no logger plumbed here to log so. return } - for ip := range pi.ipPorts { - delete(m.byIPPort, ip) + for ip := range pi.epAddrs { + delete(m.byEpAddr, ip) } + delete(m.relayEpAddrByNodeKey, ep.publicKey) } diff --git a/wgengine/magicsock/peermap_test.go b/wgengine/magicsock/peermap_test.go new file mode 100644 index 000000000..171e22a6d --- /dev/null +++ b/wgengine/magicsock/peermap_test.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "testing" + + "tailscale.com/net/packet" + "tailscale.com/types/key" +) + +func Test_peerMap_oneRelayEpAddrPerNK(t *testing.T) { + pm := newPeerMap() + nk := key.NewNode().Public() + ep := &endpoint{ + nodeID: 1, + publicKey: nk, + } + ed := &endpointDisco{key: key.NewDisco().Public()} + ep.disco.Store(ed) + pm.upsertEndpoint(ep, key.DiscoPublic{}) + vni := packet.VirtualNetworkID{} + vni.Set(1) + relayEpAddrA := epAddr{ap: netip.MustParseAddrPort("127.0.0.1:1"), vni: vni} + relayEpAddrB := epAddr{ap: netip.MustParseAddrPort("127.0.0.1:2"), vni: vni} + pm.setNodeKeyForEpAddr(relayEpAddrA, nk) + pm.setNodeKeyForEpAddr(relayEpAddrB, nk) + if len(pm.byEpAddr) != 1 { + t.Fatalf("expected 1 epAddr in byEpAddr, got: %d", len(pm.byEpAddr)) + } + got := pm.relayEpAddrByNodeKey[nk] + if got != relayEpAddrB { + t.Fatalf("expected relay epAddr %v, got: %v", relayEpAddrB, got) + } +} diff --git a/wgengine/magicsock/rebinding_conn.go b/wgengine/magicsock/rebinding_conn.go index c27abbadc..c98e64570 100644 --- a/wgengine/magicsock/rebinding_conn.go +++ b/wgengine/magicsock/rebinding_conn.go @@ -5,14 +5,17 @@ package magicsock import ( "errors" + "fmt" "net" "net/netip" - "sync" "sync/atomic" "syscall" "golang.org/x/net/ipv6" + "tailscale.com/net/batching" "tailscale.com/net/netaddr" + "tailscale.com/net/packet" + "tailscale.com/syncs" "tailscale.com/types/nettype" ) @@ -28,7 +31,7 @@ type RebindingUDPConn struct { // Neither is expected to be nil, sockets are bound on creation. pconnAtomic atomic.Pointer[nettype.PacketConn] - mu sync.Mutex // held while changing pconn (and pconnAtomic) + mu syncs.Mutex // held while changing pconn (and pconnAtomic) pconn nettype.PacketConn port uint16 } @@ -40,7 +43,7 @@ type RebindingUDPConn struct { // disrupting surrounding code that assumes nettype.PacketConn is a // *net.UDPConn. func (c *RebindingUDPConn) setConnLocked(p nettype.PacketConn, network string, batchSize int) { - upc := tryUpgradeToBatchingConn(p, network, batchSize) + upc := batching.TryUpgradeToConn(p, network, batchSize) c.pconn = upc c.pconnAtomic.Store(&upc) c.port = uint16(c.localAddrLocked().Port) @@ -70,21 +73,39 @@ func (c *RebindingUDPConn) ReadFromUDPAddrPort(b []byte) (int, netip.AddrPort, e return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b) } -// WriteBatchTo writes buffs to addr. -func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) error { +// WriteWireGuardBatchTo writes buffs to addr. It serves primarily as an alias +// for [batching.Conn.WriteBatchTo], with fallback to single packet operations +// if c.pconn is not a [batching.Conn]. +// +// WriteWireGuardBatchTo assumes buffs are WireGuard packets, which is notable +// for Geneve encapsulation: Geneve protocol is set to [packet.GeneveProtocolWireGuard], +// and the control bit is left unset. +func (c *RebindingUDPConn) WriteWireGuardBatchTo(buffs [][]byte, addr epAddr, offset int) error { + if offset != packet.GeneveFixedHeaderLength { + return fmt.Errorf("RebindingUDPConn.WriteWireGuardBatchTo: [unexpected] offset (%d) != Geneve header length (%d)", offset, packet.GeneveFixedHeaderLength) + } + gh := packet.GeneveHeader{ + Protocol: packet.GeneveProtocolWireGuard, + VNI: addr.vni, + } for { pconn := *c.pconnAtomic.Load() - b, ok := pconn.(batchingConn) + b, ok := pconn.(batching.Conn) if !ok { for _, buf := range buffs { - _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr) + if gh.VNI.IsSet() { + gh.Encode(buf) + } else { + buf = buf[offset:] + } + _, err := c.writeToUDPAddrPortWithInitPconn(pconn, buf, addr.ap) if err != nil { return err } } return nil } - err := b.WriteBatchTo(buffs, addr) + err := b.WriteBatchTo(buffs, addr.ap, gh, offset) if err != nil { if pconn != c.currentConn() { continue @@ -95,13 +116,12 @@ func (c *RebindingUDPConn) WriteBatchTo(buffs [][]byte, addr netip.AddrPort) err } } -// ReadBatch reads messages from c into msgs. It returns the number of messages -// the caller should evaluate for nonzero len, as a zero len message may fall -// on either side of a nonzero. +// ReadBatch is an alias for [batching.Conn.ReadBatch] with fallback to single +// packet operations if c.pconn is not a [batching.Conn]. func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error) { for { pconn := *c.pconnAtomic.Load() - b, ok := pconn.(batchingConn) + b, ok := pconn.(batching.Conn) if !ok { n, ap, err := c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) if err == nil { diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go new file mode 100644 index 000000000..69831a4df --- /dev/null +++ b/wgengine/magicsock/relaymanager.go @@ -0,0 +1,1071 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "context" + "errors" + "fmt" + "net/netip" + "sync" + "time" + + "tailscale.com/disco" + "tailscale.com/net/packet" + "tailscale.com/net/stun" + udprelay "tailscale.com/net/udprelay/endpoint" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/util/set" +) + +// relayManager manages allocation, handshaking, and initial probing (disco +// ping/pong) of [tailscale.com/net/udprelay.Server] endpoints. The zero value +// is ready for use. +// +// [relayManager] methods can be called by [Conn] and [endpoint] while their .mu +// mutexes are held. Therefore, in order to avoid deadlocks, [relayManager] must +// never attempt to acquire those mutexes synchronously from its runLoop(), +// including synchronous calls back towards [Conn] or [endpoint] methods that +// acquire them. +type relayManager struct { + initOnce sync.Once + + // =================================================================== + // The following fields are owned by a single goroutine, runLoop(). + serversByNodeKey map[key.NodePublic]candidatePeerRelay + allocWorkByCandidatePeerRelayByEndpoint map[*endpoint]map[candidatePeerRelay]*relayEndpointAllocWork + allocWorkByDiscoKeysByServerNodeKey map[key.NodePublic]map[key.SortedPairOfDiscoPublic]*relayEndpointAllocWork + handshakeWorkByServerDiscoByEndpoint map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork + handshakeWorkByServerDiscoVNI map[serverDiscoVNI]*relayHandshakeWork + handshakeWorkAwaitingPong map[*relayHandshakeWork]addrPortVNI + addrPortVNIToHandshakeWork map[addrPortVNI]*relayHandshakeWork + handshakeGeneration uint32 + allocGeneration uint32 + + // =================================================================== + // The following chan fields serve event inputs to a single goroutine, + // runLoop(). + startDiscoveryCh chan endpointWithLastBest + allocateWorkDoneCh chan relayEndpointAllocWorkDoneEvent + handshakeWorkDoneCh chan relayEndpointHandshakeWorkDoneEvent + cancelWorkCh chan *endpoint + newServerEndpointCh chan newRelayServerEndpointEvent + rxDiscoMsgCh chan relayDiscoMsgEvent + serversCh chan set.Set[candidatePeerRelay] + getServersCh chan chan set.Set[candidatePeerRelay] + derpHomeChangeCh chan derpHomeChangeEvent + + discoInfoMu syncs.Mutex // guards the following field + discoInfoByServerDisco map[key.DiscoPublic]*relayHandshakeDiscoInfo + + // runLoopStoppedCh is written to by runLoop() upon return, enabling event + // writers to restart it when they are blocked (see + // relayManagerInputEvent()). + runLoopStoppedCh chan struct{} +} + +// serverDiscoVNI represents a [tailscale.com/net/udprelay.Server] disco key +// and Geneve header VNI value for a given [udprelay.ServerEndpoint]. +type serverDiscoVNI struct { + serverDisco key.DiscoPublic + vni uint32 +} + +// relayHandshakeWork serves to track in-progress relay handshake work for a +// [udprelay.ServerEndpoint]. This structure is immutable once initialized. +type relayHandshakeWork struct { + wlb endpointWithLastBest + se udprelay.ServerEndpoint + server candidatePeerRelay + + handshakeGen uint32 + + // handshakeServerEndpoint() always writes to doneCh (len 1) when it + // returns. It may end up writing the same event afterward to + // relayManager.handshakeWorkDoneCh if runLoop() can receive it. runLoop() + // must select{} read on doneCh to prevent deadlock when attempting to write + // to rxDiscoMsgCh. + rxDiscoMsgCh chan relayDiscoMsgEvent + doneCh chan relayEndpointHandshakeWorkDoneEvent + + ctx context.Context + cancel context.CancelFunc +} + +func (r *relayHandshakeWork) dlogf(format string, args ...any) { + if !r.wlb.ep.c.debugLogging.Load() { + return + } + var relay string + if r.server.nodeKey.IsZero() { + relay = "from-call-me-maybe-via" + } else { + relay = r.server.nodeKey.ShortString() + } + r.wlb.ep.c.logf("%s node=%v relay=%v handshakeGen=%d disco[0]=%v disco[1]=%v", + fmt.Sprintf(format, args...), + r.wlb.ep.publicKey.ShortString(), + relay, + r.handshakeGen, + r.se.ClientDisco[0].ShortString(), + r.se.ClientDisco[1].ShortString(), + ) +} + +// newRelayServerEndpointEvent indicates a new [udprelay.ServerEndpoint] has +// become known either via allocation with a relay server, or via +// [disco.CallMeMaybeVia] reception. This structure is immutable once +// initialized. +type newRelayServerEndpointEvent struct { + wlb endpointWithLastBest + se udprelay.ServerEndpoint + server candidatePeerRelay // zero value if learned via [disco.CallMeMaybeVia] +} + +// relayEndpointAllocWorkDoneEvent indicates relay server endpoint allocation +// work for an [*endpoint] has completed. This structure is immutable once +// initialized. +type relayEndpointAllocWorkDoneEvent struct { + work *relayEndpointAllocWork + allocated udprelay.ServerEndpoint // !allocated.ServerDisco.IsZero() if allocation succeeded +} + +// relayEndpointHandshakeWorkDoneEvent indicates relay server endpoint handshake +// work for an [*endpoint] has completed. This structure is immutable once +// initialized. +type relayEndpointHandshakeWorkDoneEvent struct { + work *relayHandshakeWork + pongReceivedFrom netip.AddrPort // or zero value if handshake or ping/pong did not complete + latency time.Duration // only relevant if pongReceivedFrom.IsValid() +} + +// hasActiveWorkRunLoop returns true if there is outstanding allocation or +// handshaking work for any endpoint, otherwise it returns false. +func (r *relayManager) hasActiveWorkRunLoop() bool { + return len(r.allocWorkByCandidatePeerRelayByEndpoint) > 0 || len(r.handshakeWorkByServerDiscoByEndpoint) > 0 +} + +// hasActiveWorkForEndpointRunLoop returns true if there is outstanding +// allocation or handshaking work for the provided endpoint, otherwise it +// returns false. +func (r *relayManager) hasActiveWorkForEndpointRunLoop(ep *endpoint) bool { + _, handshakeWork := r.handshakeWorkByServerDiscoByEndpoint[ep] + _, allocWork := r.allocWorkByCandidatePeerRelayByEndpoint[ep] + return handshakeWork || allocWork +} + +// derpHomeChangeEvent represents a change in the DERP home region for the +// node identified by nodeKey. This structure is immutable once initialized. +type derpHomeChangeEvent struct { + nodeKey key.NodePublic + regionID uint16 +} + +// handleDERPHomeChange handles a DERP home change event for nodeKey and +// regionID. +func (r *relayManager) handleDERPHomeChange(nodeKey key.NodePublic, regionID uint16) { + relayManagerInputEvent(r, nil, &r.derpHomeChangeCh, derpHomeChangeEvent{ + nodeKey: nodeKey, + regionID: regionID, + }) +} + +func (r *relayManager) handleDERPHomeChangeRunLoop(event derpHomeChangeEvent) { + c, ok := r.serversByNodeKey[event.nodeKey] + if ok { + c.derpHomeRegionID = event.regionID + r.serversByNodeKey[event.nodeKey] = c + } +} + +// runLoop is a form of event loop. It ensures exclusive access to most of +// [relayManager] state. +func (r *relayManager) runLoop() { + defer func() { + r.runLoopStoppedCh <- struct{}{} + }() + + for { + select { + case startDiscovery := <-r.startDiscoveryCh: + if !r.hasActiveWorkForEndpointRunLoop(startDiscovery.ep) { + r.allocateAllServersRunLoop(startDiscovery) + } + if !r.hasActiveWorkRunLoop() { + return + } + case done := <-r.allocateWorkDoneCh: + r.handleAllocWorkDoneRunLoop(done) + if !r.hasActiveWorkRunLoop() { + return + } + case ep := <-r.cancelWorkCh: + r.stopWorkRunLoop(ep) + if !r.hasActiveWorkRunLoop() { + return + } + case newServerEndpoint := <-r.newServerEndpointCh: + r.handleNewServerEndpointRunLoop(newServerEndpoint) + if !r.hasActiveWorkRunLoop() { + return + } + case done := <-r.handshakeWorkDoneCh: + r.handleHandshakeWorkDoneRunLoop(done) + if !r.hasActiveWorkRunLoop() { + return + } + case discoMsgEvent := <-r.rxDiscoMsgCh: + r.handleRxDiscoMsgRunLoop(discoMsgEvent) + if !r.hasActiveWorkRunLoop() { + return + } + case serversUpdate := <-r.serversCh: + r.handleServersUpdateRunLoop(serversUpdate) + if !r.hasActiveWorkRunLoop() { + return + } + case getServersCh := <-r.getServersCh: + r.handleGetServersRunLoop(getServersCh) + if !r.hasActiveWorkRunLoop() { + return + } + case derpHomeChange := <-r.derpHomeChangeCh: + r.handleDERPHomeChangeRunLoop(derpHomeChange) + if !r.hasActiveWorkRunLoop() { + return + } + } + } +} + +func (r *relayManager) handleGetServersRunLoop(getServersCh chan set.Set[candidatePeerRelay]) { + servers := make(set.Set[candidatePeerRelay], len(r.serversByNodeKey)) + for _, v := range r.serversByNodeKey { + servers.Add(v) + } + getServersCh <- servers +} + +func (r *relayManager) getServers() set.Set[candidatePeerRelay] { + ch := make(chan set.Set[candidatePeerRelay]) + relayManagerInputEvent(r, nil, &r.getServersCh, ch) + return <-ch +} + +func (r *relayManager) handleServersUpdateRunLoop(update set.Set[candidatePeerRelay]) { + for _, v := range r.serversByNodeKey { + if !update.Contains(v) { + delete(r.serversByNodeKey, v.nodeKey) + } + } + for _, v := range update.Slice() { + r.serversByNodeKey[v.nodeKey] = v + } +} + +type relayDiscoMsgEvent struct { + conn *Conn // for access to [Conn] if there is no associated [relayHandshakeWork] + msg disco.Message + relayServerNodeKey key.NodePublic // nonzero if msg is a [*disco.AllocateUDPRelayEndpointResponse] + disco key.DiscoPublic + from netip.AddrPort + vni uint32 + at time.Time +} + +// relayEndpointAllocWork serves to track in-progress relay endpoint allocation +// for an [*endpoint]. This structure is immutable once initialized. +type relayEndpointAllocWork struct { + wlb endpointWithLastBest + discoKeys key.SortedPairOfDiscoPublic + candidatePeerRelay candidatePeerRelay // zero value if learned via [disco.CallMeMaybeVia] + + allocGen uint32 + + // allocateServerEndpoint() always writes to doneCh (len 1) when it + // returns. It may end up writing the same event afterward to + // [relayManager.allocateWorkDoneCh] if runLoop() can receive it. runLoop() + // must select{} read on doneCh to prevent deadlock when attempting to write + // to rxDiscoMsgCh. + rxDiscoMsgCh chan *disco.AllocateUDPRelayEndpointResponse + doneCh chan relayEndpointAllocWorkDoneEvent + + ctx context.Context + cancel context.CancelFunc +} + +func (r *relayEndpointAllocWork) dlogf(format string, args ...any) { + if !r.wlb.ep.c.debugLogging.Load() { + return + } + r.wlb.ep.c.logf("%s node=%v relay=%v allocGen=%d disco[0]=%v disco[1]=%v", + fmt.Sprintf(format, args...), + r.wlb.ep.publicKey.ShortString(), + r.candidatePeerRelay.nodeKey.ShortString(), + r.allocGen, + r.discoKeys.Get()[0].ShortString(), + r.discoKeys.Get()[1].ShortString(), + ) +} + +// init initializes [relayManager] if it is not already initialized. +func (r *relayManager) init() { + r.initOnce.Do(func() { + r.discoInfoByServerDisco = make(map[key.DiscoPublic]*relayHandshakeDiscoInfo) + r.serversByNodeKey = make(map[key.NodePublic]candidatePeerRelay) + r.allocWorkByCandidatePeerRelayByEndpoint = make(map[*endpoint]map[candidatePeerRelay]*relayEndpointAllocWork) + r.allocWorkByDiscoKeysByServerNodeKey = make(map[key.NodePublic]map[key.SortedPairOfDiscoPublic]*relayEndpointAllocWork) + r.handshakeWorkByServerDiscoByEndpoint = make(map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork) + r.handshakeWorkByServerDiscoVNI = make(map[serverDiscoVNI]*relayHandshakeWork) + r.handshakeWorkAwaitingPong = make(map[*relayHandshakeWork]addrPortVNI) + r.addrPortVNIToHandshakeWork = make(map[addrPortVNI]*relayHandshakeWork) + r.startDiscoveryCh = make(chan endpointWithLastBest) + r.allocateWorkDoneCh = make(chan relayEndpointAllocWorkDoneEvent) + r.handshakeWorkDoneCh = make(chan relayEndpointHandshakeWorkDoneEvent) + r.cancelWorkCh = make(chan *endpoint) + r.newServerEndpointCh = make(chan newRelayServerEndpointEvent) + r.rxDiscoMsgCh = make(chan relayDiscoMsgEvent) + r.serversCh = make(chan set.Set[candidatePeerRelay]) + r.getServersCh = make(chan chan set.Set[candidatePeerRelay]) + r.derpHomeChangeCh = make(chan derpHomeChangeEvent) + r.runLoopStoppedCh = make(chan struct{}, 1) + r.runLoopStoppedCh <- struct{}{} + }) +} + +// relayHandshakeDiscoInfo serves to cache a [*discoInfo] for outstanding +// [*relayHandshakeWork] against a given relay server. +type relayHandshakeDiscoInfo struct { + work set.Set[*relayHandshakeWork] // guarded by relayManager.discoInfoMu + di *discoInfo // immutable once initialized +} + +// ensureDiscoInfoFor ensures a [*discoInfo] will be returned by discoInfo() for +// the server disco key associated with 'work'. Callers must also call +// derefDiscoInfoFor() when 'work' is complete. +func (r *relayManager) ensureDiscoInfoFor(work *relayHandshakeWork) { + r.discoInfoMu.Lock() + defer r.discoInfoMu.Unlock() + di, ok := r.discoInfoByServerDisco[work.se.ServerDisco] + if !ok { + di = &relayHandshakeDiscoInfo{} + di.work.Make() + r.discoInfoByServerDisco[work.se.ServerDisco] = di + } + di.work.Add(work) + if di.di == nil { + di.di = &discoInfo{ + discoKey: work.se.ServerDisco, + discoShort: work.se.ServerDisco.ShortString(), + sharedKey: work.wlb.ep.c.discoAtomic.Private().Shared(work.se.ServerDisco), + } + } +} + +// derefDiscoInfoFor decrements the reference count of the [*discoInfo] +// associated with 'work'. +func (r *relayManager) derefDiscoInfoFor(work *relayHandshakeWork) { + r.discoInfoMu.Lock() + defer r.discoInfoMu.Unlock() + di, ok := r.discoInfoByServerDisco[work.se.ServerDisco] + if !ok { + // TODO(jwhited): unexpected + return + } + di.work.Delete(work) + if di.work.Len() == 0 { + delete(r.discoInfoByServerDisco, work.se.ServerDisco) + } +} + +// discoInfo returns a [*discoInfo] for 'serverDisco' if there is an +// active/ongoing handshake with it, otherwise it returns nil, false. +func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) { + r.discoInfoMu.Lock() + defer r.discoInfoMu.Unlock() + di, ok := r.discoInfoByServerDisco[serverDisco] + if ok { + return di.di, ok + } + return nil, false +} + +func (r *relayManager) handleCallMeMaybeVia(ep *endpoint, lastBest addrQuality, lastBestIsTrusted bool, dm *disco.CallMeMaybeVia) { + se := udprelay.ServerEndpoint{ + ServerDisco: dm.ServerDisco, + ClientDisco: dm.ClientDisco, + LamportID: dm.LamportID, + AddrPorts: dm.AddrPorts, + VNI: dm.VNI, + } + se.BindLifetime.Duration = dm.BindLifetime + se.SteadyStateLifetime.Duration = dm.SteadyStateLifetime + relayManagerInputEvent(r, nil, &r.newServerEndpointCh, newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ + ep: ep, + lastBest: lastBest, + lastBestIsTrusted: lastBestIsTrusted, + }, + se: se, + }) +} + +// handleRxDiscoMsg handles reception of disco messages that [relayManager] +// may be interested in. This includes all Geneve-encapsulated disco messages +// and [*disco.AllocateUDPRelayEndpointResponse]. If dm is a +// [*disco.AllocateUDPRelayEndpointResponse] then relayServerNodeKey must be +// nonzero. +func (r *relayManager) handleRxDiscoMsg(conn *Conn, dm disco.Message, relayServerNodeKey key.NodePublic, discoKey key.DiscoPublic, src epAddr) { + relayManagerInputEvent(r, nil, &r.rxDiscoMsgCh, relayDiscoMsgEvent{ + conn: conn, + msg: dm, + relayServerNodeKey: relayServerNodeKey, + disco: discoKey, + from: src.ap, + vni: src.vni.Get(), + at: time.Now(), + }) +} + +// handleRelayServersSet handles an update of the complete relay server set. +func (r *relayManager) handleRelayServersSet(servers set.Set[candidatePeerRelay]) { + relayManagerInputEvent(r, nil, &r.serversCh, servers) +} + +// relayManagerInputEvent initializes [relayManager] if necessary, starts +// relayManager.runLoop() if it is not running, and writes 'event' on 'eventCh'. +// +// [relayManager] initialization will make `*eventCh`, so it must be passed as +// a pointer to a channel. +// +// 'ctx' can be used for returning when runLoop is waiting for the calling +// goroutine to return, i.e. the calling goroutine was birthed by runLoop and is +// cancelable via 'ctx'. 'ctx' may be nil. +func relayManagerInputEvent[T any](r *relayManager, ctx context.Context, eventCh *chan T, event T) { + r.init() + var ctxDoneCh <-chan struct{} + if ctx != nil { + ctxDoneCh = ctx.Done() + } + for { + select { + case <-ctxDoneCh: + return + case *eventCh <- event: + return + case <-r.runLoopStoppedCh: + go r.runLoop() + } + } +} + +// endpointWithLastBest represents an [*endpoint], its last bestAddr, and if +// the last bestAddr was trusted (see endpoint.trustBestAddrUntil) at the time +// of init. This structure is immutable once initialized. +type endpointWithLastBest struct { + ep *endpoint + lastBest addrQuality + lastBestIsTrusted bool +} + +// startUDPRelayPathDiscoveryFor starts UDP relay path discovery for ep on all +// known relay servers if ep has no in-progress work. +func (r *relayManager) startUDPRelayPathDiscoveryFor(ep *endpoint, lastBest addrQuality, lastBestIsTrusted bool) { + relayManagerInputEvent(r, nil, &r.startDiscoveryCh, endpointWithLastBest{ + ep: ep, + lastBest: lastBest, + lastBestIsTrusted: lastBestIsTrusted, + }) +} + +// stopWork stops all outstanding allocation & handshaking work for 'ep'. +func (r *relayManager) stopWork(ep *endpoint) { + relayManagerInputEvent(r, nil, &r.cancelWorkCh, ep) +} + +// stopWorkRunLoop cancels & clears outstanding allocation and handshaking +// work for 'ep'. +func (r *relayManager) stopWorkRunLoop(ep *endpoint) { + byDiscoKeys, ok := r.allocWorkByCandidatePeerRelayByEndpoint[ep] + if ok { + for _, work := range byDiscoKeys { + work.cancel() + done := <-work.doneCh + r.handleAllocWorkDoneRunLoop(done) + } + } + byServerDisco, ok := r.handshakeWorkByServerDiscoByEndpoint[ep] + if ok { + for _, handshakeWork := range byServerDisco { + handshakeWork.cancel() + done := <-handshakeWork.doneCh + r.handleHandshakeWorkDoneRunLoop(done) + } + } +} + +// addrPortVNI represents a combined netip.AddrPort and Geneve header virtual +// network identifier. +type addrPortVNI struct { + addrPort netip.AddrPort + vni uint32 +} + +func (r *relayManager) handleRxDiscoMsgRunLoop(event relayDiscoMsgEvent) { + var ( + work *relayHandshakeWork + ok bool + ) + apv := addrPortVNI{event.from, event.vni} + switch msg := event.msg.(type) { + case *disco.AllocateUDPRelayEndpointResponse: + sorted := key.NewSortedPairOfDiscoPublic(msg.ClientDisco[0], msg.ClientDisco[1]) + byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[event.relayServerNodeKey] + if !ok { + // No outstanding work tied to this relay sever, discard. + return + } + allocWork, ok := byDiscoKeys[sorted] + if !ok { + // No outstanding work tied to these disco keys, discard. + return + } + select { + case done := <-allocWork.doneCh: + // allocateServerEndpoint returned, clean up its state + r.handleAllocWorkDoneRunLoop(done) + return + case allocWork.rxDiscoMsgCh <- msg: + return + } + case *disco.BindUDPRelayEndpointChallenge: + work, ok = r.handshakeWorkByServerDiscoVNI[serverDiscoVNI{event.disco, event.vni}] + if !ok { + // No outstanding work tied to this challenge, discard. + return + } + _, ok = r.handshakeWorkAwaitingPong[work] + if ok { + // We've seen a challenge for this relay endpoint previously, + // discard. Servers only respond to the first src ip:port they see + // binds from. + return + } + _, ok = r.addrPortVNIToHandshakeWork[apv] + if ok { + // There is existing work for the same [addrPortVNI] that is not + // 'work'. If both instances happen to be on the same server we + // could attempt to resolve event order using LamportID. For now + // just leave both work instances alone and take no action other + // than to discard this challenge msg. + return + } + // Update state so that future ping/pong will route to 'work'. + r.handshakeWorkAwaitingPong[work] = apv + r.addrPortVNIToHandshakeWork[apv] = work + case *disco.Ping: + // Always TX a pong. We might not have any associated work if ping + // reception raced with our call to [endpoint.udpRelayEndpointReady()], so + // err on the side of enabling the remote side to use this path. + // + // Conn.handlePingLocked() makes efforts to suppress duplicate pongs + // where the same ping can be received both via raw socket and UDP + // socket on Linux. We make no such efforts here as the raw socket BPF + // program does not support Geneve-encapsulated disco, and is also + // disabled by default. + vni := packet.VirtualNetworkID{} + vni.Set(event.vni) + go event.conn.sendDiscoMessage(epAddr{ap: event.from, vni: vni}, key.NodePublic{}, event.disco, &disco.Pong{ + TxID: msg.TxID, + Src: event.from, + }, discoVerboseLog) + + work, ok = r.addrPortVNIToHandshakeWork[apv] + if !ok { + // No outstanding work tied to this [addrPortVNI], return early. + return + } + case *disco.Pong: + work, ok = r.addrPortVNIToHandshakeWork[apv] + if !ok { + // No outstanding work tied to this [addrPortVNI], discard. + return + } + default: + // Unexpected message type, discard. + return + } + select { + case done := <-work.doneCh: + // handshakeServerEndpoint() returned, clean up its state. + r.handleHandshakeWorkDoneRunLoop(done) + return + case work.rxDiscoMsgCh <- event: + return + } +} + +func (r *relayManager) handleAllocWorkDoneRunLoop(done relayEndpointAllocWorkDoneEvent) { + byCandidatePeerRelay, ok := r.allocWorkByCandidatePeerRelayByEndpoint[done.work.wlb.ep] + if !ok { + return + } + work, ok := byCandidatePeerRelay[done.work.candidatePeerRelay] + if !ok || work != done.work { + return + } + delete(byCandidatePeerRelay, done.work.candidatePeerRelay) + if len(byCandidatePeerRelay) == 0 { + delete(r.allocWorkByCandidatePeerRelayByEndpoint, done.work.wlb.ep) + } + byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[done.work.candidatePeerRelay.nodeKey] + if !ok { + // unexpected + return + } + delete(byDiscoKeys, done.work.discoKeys) + if len(byDiscoKeys) == 0 { + delete(r.allocWorkByDiscoKeysByServerNodeKey, done.work.candidatePeerRelay.nodeKey) + } + if !done.allocated.ServerDisco.IsZero() { + r.handleNewServerEndpointRunLoop(newRelayServerEndpointEvent{ + wlb: done.work.wlb, + se: done.allocated, + server: done.work.candidatePeerRelay, + }) + } +} + +func (r *relayManager) handleHandshakeWorkDoneRunLoop(done relayEndpointHandshakeWorkDoneEvent) { + byServerDisco, ok := r.handshakeWorkByServerDiscoByEndpoint[done.work.wlb.ep] + if !ok { + return + } + work, ok := byServerDisco[done.work.se.ServerDisco] + if !ok || work != done.work { + return + } + delete(byServerDisco, done.work.se.ServerDisco) + if len(byServerDisco) == 0 { + delete(r.handshakeWorkByServerDiscoByEndpoint, done.work.wlb.ep) + } + delete(r.handshakeWorkByServerDiscoVNI, serverDiscoVNI{done.work.se.ServerDisco, done.work.se.VNI}) + apv, ok := r.handshakeWorkAwaitingPong[work] + if ok { + delete(r.handshakeWorkAwaitingPong, work) + delete(r.addrPortVNIToHandshakeWork, apv) + } + if !done.pongReceivedFrom.IsValid() { + // The handshake or ping/pong probing timed out. + return + } + // This relay endpoint is functional. + vni := packet.VirtualNetworkID{} + vni.Set(done.work.se.VNI) + addr := epAddr{ap: done.pongReceivedFrom, vni: vni} + // ep.udpRelayEndpointReady() must be called in a new goroutine to prevent + // deadlocks as it acquires [endpoint] & [Conn] mutexes. See [relayManager] + // docs for details. + go done.work.wlb.ep.udpRelayEndpointReady(addrQuality{ + epAddr: addr, + relayServerDisco: done.work.se.ServerDisco, + latency: done.latency, + wireMTU: pingSizeToPktLen(0, addr), + }) +} + +func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelayServerEndpointEvent) { + // Check for duplicate work by server disco + VNI. + sdv := serverDiscoVNI{newServerEndpoint.se.ServerDisco, newServerEndpoint.se.VNI} + existingWork, ok := r.handshakeWorkByServerDiscoVNI[sdv] + if ok { + // There's in-progress handshake work for the server disco + VNI, which + // uniquely identify a [udprelay.ServerEndpoint]. Compare Lamport + // IDs to determine which is newer. + if existingWork.se.LamportID >= newServerEndpoint.se.LamportID { + // The existing work is a duplicate or newer. Return early. + return + } + + // The existing work is no longer valid, clean it up. + existingWork.cancel() + done := <-existingWork.doneCh + r.handleHandshakeWorkDoneRunLoop(done) + } + + // Check for duplicate work by [*endpoint] + server disco. + byServerDisco, ok := r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] + if ok { + existingWork, ok := byServerDisco[newServerEndpoint.se.ServerDisco] + if ok { + if newServerEndpoint.se.LamportID <= existingWork.se.LamportID { + // The "new" server endpoint is outdated or duplicate in + // consideration against existing handshake work. Return early. + return + } + // Cancel existing handshake that has a lower lamport ID. + existingWork.cancel() + done := <-existingWork.doneCh + r.handleHandshakeWorkDoneRunLoop(done) + } + } + + // We're now reasonably sure we're dealing with the latest + // [udprelay.ServerEndpoint] from a server event order perspective + // (LamportID). + + if newServerEndpoint.server.isValid() { + // Send a [disco.CallMeMaybeVia] to the remote peer if we allocated this + // endpoint, regardless of if we start a handshake below. + go r.sendCallMeMaybeVia(newServerEndpoint.wlb.ep, newServerEndpoint.se) + } + + lastBestMatchingServer := newServerEndpoint.se.ServerDisco.Compare(newServerEndpoint.wlb.lastBest.relayServerDisco) == 0 + if lastBestMatchingServer && newServerEndpoint.wlb.lastBestIsTrusted { + // This relay server endpoint is the same as [endpoint]'s bestAddr at + // the time UDP relay path discovery was started, and it was also a + // trusted path (see endpoint.trustBestAddrUntil), so return early. + // + // If we were to start a new handshake, there is a chance that we + // cause [endpoint] to blackhole some packets on its bestAddr if we end + // up shifting to a new address family or src, e.g. IPv4 to IPv6, due to + // the window of time between the handshake completing, and our call to + // udpRelayEndpointReady(). The relay server can only forward packets + // from us on a single [epAddr]. + return + } + + // TODO(jwhited): if lastBest is untrusted, consider some strategies + // to reduce the chance we blackhole if it were to transition to + // trusted during/before the new handshake: + // 1. Start by attempting a handshake with only lastBest.epAddr. If + // that fails then try the remaining [epAddr]s. + // 2. Signal bestAddr trust transitions between [endpoint] and + // [relayManager] in order to prevent a handshake from starting + // and/or stop one that is running. + + // We're ready to start a new handshake. + ctx, cancel := context.WithCancel(context.Background()) + work := &relayHandshakeWork{ + wlb: newServerEndpoint.wlb, + se: newServerEndpoint.se, + server: newServerEndpoint.server, + rxDiscoMsgCh: make(chan relayDiscoMsgEvent), + doneCh: make(chan relayEndpointHandshakeWorkDoneEvent, 1), + ctx: ctx, + cancel: cancel, + } + // We must look up byServerDisco again. The previous value may have been + // deleted from the outer map when cleaning up duplicate work. + byServerDisco, ok = r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] + if !ok { + byServerDisco = make(map[key.DiscoPublic]*relayHandshakeWork) + r.handshakeWorkByServerDiscoByEndpoint[newServerEndpoint.wlb.ep] = byServerDisco + } + byServerDisco[newServerEndpoint.se.ServerDisco] = work + r.handshakeWorkByServerDiscoVNI[sdv] = work + + r.handshakeGeneration++ + if r.handshakeGeneration == 0 { // generation must be nonzero + r.handshakeGeneration++ + } + work.handshakeGen = r.handshakeGeneration + + go r.handshakeServerEndpoint(work) +} + +// sendCallMeMaybeVia sends a [disco.CallMeMaybeVia] to ep over DERP. It must be +// called as part of a goroutine independent from runLoop(), for 2 reasons: +// 1. it acquires ep.mu (refer to [relayManager] docs for reasoning) +// 2. it makes a networking syscall, which can introduce unwanted backpressure +func (r *relayManager) sendCallMeMaybeVia(ep *endpoint, se udprelay.ServerEndpoint) { + ep.mu.Lock() + derpAddr := ep.derpAddr + ep.mu.Unlock() + epDisco := ep.disco.Load() + if epDisco == nil || !derpAddr.IsValid() { + return + } + callMeMaybeVia := &disco.CallMeMaybeVia{ + UDPRelayEndpoint: disco.UDPRelayEndpoint{ + ServerDisco: se.ServerDisco, + ClientDisco: se.ClientDisco, + LamportID: se.LamportID, + VNI: se.VNI, + BindLifetime: se.BindLifetime.Duration, + SteadyStateLifetime: se.SteadyStateLifetime.Duration, + AddrPorts: se.AddrPorts, + }, + } + ep.c.sendDiscoMessage(epAddr{ap: derpAddr}, ep.publicKey, epDisco.key, callMeMaybeVia, discoVerboseLog) +} + +func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) { + done := relayEndpointHandshakeWorkDoneEvent{work: work} + r.ensureDiscoInfoFor(work) + + defer func() { + r.derefDiscoInfoFor(work) + work.doneCh <- done + relayManagerInputEvent(r, work.ctx, &r.handshakeWorkDoneCh, done) + work.cancel() + }() + + ep := work.wlb.ep + epDisco := ep.disco.Load() + if epDisco == nil { + return + } + + common := disco.BindUDPRelayEndpointCommon{ + VNI: work.se.VNI, + Generation: work.handshakeGen, + RemoteKey: epDisco.key, + } + + work.dlogf("[v1] magicsock: relayManager: starting handshake addrPorts=%v", + work.se.AddrPorts, + ) + sentBindAny := false + bind := &disco.BindUDPRelayEndpoint{ + BindUDPRelayEndpointCommon: common, + } + vni := packet.VirtualNetworkID{} + vni.Set(work.se.VNI) + for _, addrPort := range work.se.AddrPorts { + if addrPort.IsValid() { + sentBindAny = true + go ep.c.sendDiscoMessage(epAddr{ap: addrPort, vni: vni}, key.NodePublic{}, work.se.ServerDisco, bind, discoVerboseLog) + } + } + if !sentBindAny { + return + } + + // Limit goroutine lifetime to a reasonable duration. This is intentionally + // detached and independent of 'BindLifetime' to prevent relay server + // (mis)configuration from negatively impacting client resource usage. + const maxHandshakeLifetime = time.Second * 30 + timer := time.NewTimer(min(work.se.BindLifetime.Duration, maxHandshakeLifetime)) + defer timer.Stop() + + // Limit the number of pings we will transmit. Inbound pings trigger + // outbound pings, so we want to be a little defensive. + const limitPings = 10 + + var ( + handshakeState disco.BindUDPRelayHandshakeState = disco.BindUDPRelayHandshakeStateBindSent + sentPingAt = make(map[stun.TxID]time.Time) + ) + + txPing := func(to netip.AddrPort, withAnswer *[32]byte) { + if len(sentPingAt) == limitPings { + return + } + txid := stun.NewTxID() + sentPingAt[txid] = time.Now() + ping := &disco.Ping{ + TxID: txid, + NodeKey: ep.c.publicKeyAtomic.Load(), + } + go func() { + if withAnswer != nil { + answer := &disco.BindUDPRelayEndpointAnswer{BindUDPRelayEndpointCommon: common} + answer.Challenge = *withAnswer + ep.c.sendDiscoMessage(epAddr{ap: to, vni: vni}, key.NodePublic{}, work.se.ServerDisco, answer, discoVerboseLog) + } + ep.c.sendDiscoMessage(epAddr{ap: to, vni: vni}, key.NodePublic{}, epDisco.key, ping, discoVerboseLog) + }() + } + + validateVNIAndRemoteKey := func(common disco.BindUDPRelayEndpointCommon) error { + if common.VNI != work.se.VNI { + return errors.New("mismatching VNI") + } + if common.RemoteKey.Compare(epDisco.key) != 0 { + return errors.New("mismatching RemoteKey") + } + return nil + } + + // This for{select{}} is responsible for handshaking and tx'ing ping/pong + // when the handshake is complete. + for { + select { + case <-work.ctx.Done(): + work.dlogf("[v1] magicsock: relayManager: handshake canceled") + return + case msgEvent := <-work.rxDiscoMsgCh: + switch msg := msgEvent.msg.(type) { + case *disco.BindUDPRelayEndpointChallenge: + err := validateVNIAndRemoteKey(msg.BindUDPRelayEndpointCommon) + if err != nil { + continue + } + if handshakeState >= disco.BindUDPRelayHandshakeStateAnswerSent { + continue + } + work.dlogf("[v1] magicsock: relayManager: got handshake challenge from %v", msgEvent.from) + txPing(msgEvent.from, &msg.Challenge) + handshakeState = disco.BindUDPRelayHandshakeStateAnswerSent + case *disco.Ping: + if handshakeState < disco.BindUDPRelayHandshakeStateAnswerSent { + continue + } + work.dlogf("[v1] magicsock: relayManager: got relayed ping from %v", msgEvent.from) + // An inbound ping from the remote peer indicates we completed a + // handshake with the relay server (our answer msg was + // received). Chances are our ping was dropped before the remote + // handshake was complete. We need to rx a pong to determine + // latency, so send another ping. Since the handshake is + // complete we do not need to send an answer in front of this + // one. + // + // We don't need to TX a pong, that was already handled for us + // in handleRxDiscoMsgRunLoop(). + txPing(msgEvent.from, nil) + case *disco.Pong: + at, ok := sentPingAt[msg.TxID] + if !ok { + continue + } + // The relay server endpoint is functional! Record the + // round-trip latency and return. + done.pongReceivedFrom = msgEvent.from + done.latency = time.Since(at) + work.dlogf("[v1] magicsock: relayManager: got relayed pong from %v latency=%v", + msgEvent.from, + done.latency.Round(time.Millisecond), + ) + return + default: + // unexpected message type, silently discard + continue + } + case <-timer.C: + // The handshake timed out. + work.dlogf("[v1] magicsock: relayManager: handshake timed out") + return + } + } +} + +const allocateUDPRelayEndpointRequestTimeout = time.Second * 10 + +func (r *relayManager) allocateServerEndpoint(work *relayEndpointAllocWork) { + done := relayEndpointAllocWorkDoneEvent{work: work} + + defer func() { + work.doneCh <- done + relayManagerInputEvent(r, work.ctx, &r.allocateWorkDoneCh, done) + work.cancel() + }() + + dm := &disco.AllocateUDPRelayEndpointRequest{ + ClientDisco: work.discoKeys.Get(), + Generation: work.allocGen, + } + + sendAllocReq := func() { + work.wlb.ep.c.sendDiscoAllocateUDPRelayEndpointRequest( + epAddr{ + ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, work.candidatePeerRelay.derpHomeRegionID), + }, + work.candidatePeerRelay.nodeKey, + work.candidatePeerRelay.discoKey, + dm, + discoVerboseLog, + ) + work.dlogf("[v1] magicsock: relayManager: sent alloc request") + } + go sendAllocReq() + + returnAfterTimer := time.NewTimer(allocateUDPRelayEndpointRequestTimeout) + defer returnAfterTimer.Stop() + // While connections to DERP are over TCP, they can be lossy on the DERP + // server when data moves between the two independent streams. Also, the + // peer relay server may not be "ready" (see [tailscale.com/net/udprelay.ErrServerNotReady]). + // So, start a timer to retry once if needed. + retryAfterTimer := time.NewTimer(udprelay.ServerRetryAfter) + defer retryAfterTimer.Stop() + + for { + select { + case <-work.ctx.Done(): + work.dlogf("[v1] magicsock: relayManager: alloc request canceled") + return + case <-returnAfterTimer.C: + work.dlogf("[v1] magicsock: relayManager: alloc request timed out") + return + case <-retryAfterTimer.C: + go sendAllocReq() + case resp := <-work.rxDiscoMsgCh: + if resp.Generation != work.allocGen || + !work.discoKeys.Equal(key.NewSortedPairOfDiscoPublic(resp.ClientDisco[0], resp.ClientDisco[1])) { + continue + } + work.dlogf("[v1] magicsock: relayManager: got alloc response") + done.allocated = udprelay.ServerEndpoint{ + ServerDisco: resp.ServerDisco, + ClientDisco: resp.ClientDisco, + LamportID: resp.LamportID, + AddrPorts: resp.AddrPorts, + VNI: resp.VNI, + BindLifetime: tstime.GoDuration{Duration: resp.BindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: resp.SteadyStateLifetime}, + } + return + } + } +} + +func (r *relayManager) allocateAllServersRunLoop(wlb endpointWithLastBest) { + if len(r.serversByNodeKey) == 0 { + return + } + remoteDisco := wlb.ep.disco.Load() + if remoteDisco == nil { + return + } + discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoAtomic.Public(), remoteDisco.key) + for _, v := range r.serversByNodeKey { + byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey] + if !ok { + byDiscoKeys = make(map[key.SortedPairOfDiscoPublic]*relayEndpointAllocWork) + r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey] = byDiscoKeys + } else { + _, ok = byDiscoKeys[discoKeys] + if ok { + // If there is an existing key, a disco key collision may have + // occurred across peers ([*endpoint]). Do not overwrite the + // existing work, let it finish. + wlb.ep.c.logf("[unexpected] magicsock: relayManager: suspected disco key collision on server %v for keys: %v", v.nodeKey.ShortString(), discoKeys) + continue + } + } + ctx, cancel := context.WithCancel(context.Background()) + started := &relayEndpointAllocWork{ + wlb: wlb, + discoKeys: discoKeys, + candidatePeerRelay: v, + rxDiscoMsgCh: make(chan *disco.AllocateUDPRelayEndpointResponse), + doneCh: make(chan relayEndpointAllocWorkDoneEvent, 1), + ctx: ctx, + cancel: cancel, + } + byDiscoKeys[discoKeys] = started + byCandidatePeerRelay, ok := r.allocWorkByCandidatePeerRelayByEndpoint[wlb.ep] + if !ok { + byCandidatePeerRelay = make(map[candidatePeerRelay]*relayEndpointAllocWork) + r.allocWorkByCandidatePeerRelayByEndpoint[wlb.ep] = byCandidatePeerRelay + } + byCandidatePeerRelay[v] = started + r.allocGeneration++ + started.allocGen = r.allocGeneration + go r.allocateServerEndpoint(started) + } +} diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go new file mode 100644 index 000000000..e8fddfd91 --- /dev/null +++ b/wgengine/magicsock/relaymanager_test.go @@ -0,0 +1,260 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "testing" + + "tailscale.com/disco" + udprelay "tailscale.com/net/udprelay/endpoint" + "tailscale.com/types/key" + "tailscale.com/util/set" +) + +func TestRelayManagerInitAndIdle(t *testing.T) { + rm := relayManager{} + rm.startUDPRelayPathDiscoveryFor(&endpoint{}, addrQuality{}, false) + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.stopWork(&endpoint{}) + <-rm.runLoopStoppedCh + + rm = relayManager{} + c1 := &Conn{} + c1.discoAtomic.Set(key.NewDisco()) + rm.handleCallMeMaybeVia(&endpoint{c: c1}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}}) + <-rm.runLoopStoppedCh + + rm = relayManager{} + c2 := &Conn{} + c2.discoAtomic.Set(key.NewDisco()) + rm.handleRxDiscoMsg(c2, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{}) + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.handleRelayServersSet(make(set.Set[candidatePeerRelay])) + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.getServers() + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.handleDERPHomeChange(key.NodePublic{}, 1) + <-rm.runLoopStoppedCh +} + +func TestRelayManagerHandleDERPHomeChange(t *testing.T) { + rm := relayManager{} + servers := make(set.Set[candidatePeerRelay], 1) + c := candidatePeerRelay{ + nodeKey: key.NewNode().Public(), + discoKey: key.NewDisco().Public(), + derpHomeRegionID: 1, + } + servers.Add(c) + rm.handleRelayServersSet(servers) + want := c + want.derpHomeRegionID = 2 + rm.handleDERPHomeChange(c.nodeKey, 2) + got := rm.getServers() + if len(got) != 1 { + t.Fatalf("got %d servers, want 1", len(got)) + } + _, ok := got[want] + if !ok { + t.Fatal("DERP home change failed to propagate") + } +} + +func TestRelayManagerGetServers(t *testing.T) { + rm := relayManager{} + servers := make(set.Set[candidatePeerRelay], 1) + c := candidatePeerRelay{ + nodeKey: key.NewNode().Public(), + discoKey: key.NewDisco().Public(), + } + servers.Add(c) + rm.handleRelayServersSet(servers) + got := rm.getServers() + if !servers.Equal(got) { + t.Errorf("got %v != want %v", got, servers) + } +} + +func TestRelayManager_handleNewServerEndpointRunLoop(t *testing.T) { + wantHandshakeWorkCount := func(t *testing.T, rm *relayManager, n int) { + t.Helper() + byServerDiscoByEndpoint := 0 + for _, v := range rm.handshakeWorkByServerDiscoByEndpoint { + byServerDiscoByEndpoint += len(v) + } + byServerDiscoVNI := len(rm.handshakeWorkByServerDiscoVNI) + if byServerDiscoByEndpoint != n || + byServerDiscoVNI != n || + byServerDiscoByEndpoint != byServerDiscoVNI { + t.Fatalf("want handshake work count %d byServerDiscoByEndpoint=%d byServerDiscoVNI=%d", + n, + byServerDiscoByEndpoint, + byServerDiscoVNI, + ) + } + } + + conn := newConn(t.Logf) + epA := &endpoint{c: conn} + epB := &endpoint{c: conn} + serverDiscoA := key.NewDisco().Public() + serverDiscoB := key.NewDisco().Public() + + serverAendpointALamport1VNI1 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 1, VNI: 1}, + } + serverAendpointALamport1VNI1LastBestMatching := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA, lastBestIsTrusted: true, lastBest: addrQuality{relayServerDisco: serverDiscoA}}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 1, VNI: 1}, + } + serverAendpointALamport2VNI1 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 2, VNI: 1}, + } + serverAendpointALamport2VNI2 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 2, VNI: 2}, + } + serverAendpointBLamport1VNI2 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epB}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoA, LamportID: 1, VNI: 2}, + } + serverBendpointALamport1VNI1 := newRelayServerEndpointEvent{ + wlb: endpointWithLastBest{ep: epA}, + se: udprelay.ServerEndpoint{ServerDisco: serverDiscoB, LamportID: 1, VNI: 1}, + } + + tests := []struct { + name string + events []newRelayServerEndpointEvent + want []newRelayServerEndpointEvent + }{ + { + // Test for http://go/corp/32978 + name: "eq server+ep neq VNI higher lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverAendpointALamport2VNI2, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+ep neq VNI lower lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + serverAendpointALamport1VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+vni neq ep lower lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + serverAendpointBLamport1VNI2, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+vni neq ep higher lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointBLamport1VNI2, + serverAendpointALamport2VNI2, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI2, + }, + }, + { + name: "eq server+endpoint+vni higher lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverAendpointALamport2VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI1, + }, + }, + { + name: "eq server+endpoint+vni lower lamport", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI1, + serverAendpointALamport1VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport2VNI1, + }, + }, + { + name: "eq endpoint+vni+lamport neq server", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverBendpointALamport1VNI1, + }, + want: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1, + serverBendpointALamport1VNI1, + }, + }, + { + name: "trusted last best with matching server", + events: []newRelayServerEndpointEvent{ + serverAendpointALamport1VNI1LastBestMatching, + }, + want: []newRelayServerEndpointEvent{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rm := &relayManager{} + rm.init() + <-rm.runLoopStoppedCh // prevent runLoop() from starting + + // feed events + for _, event := range tt.events { + rm.handleNewServerEndpointRunLoop(event) + } + + // validate state + wantHandshakeWorkCount(t, rm, len(tt.want)) + for _, want := range tt.want { + byServerDisco, ok := rm.handshakeWorkByServerDiscoByEndpoint[want.wlb.ep] + if !ok { + t.Fatal("work not found by endpoint") + } + workByServerDiscoByEndpoint, ok := byServerDisco[want.se.ServerDisco] + if !ok { + t.Fatal("work not found by server disco by endpoint") + } + workByServerDiscoVNI, ok := rm.handshakeWorkByServerDiscoVNI[serverDiscoVNI{want.se.ServerDisco, want.se.VNI}] + if !ok { + t.Fatal("work not found by server disco + VNI") + } + if workByServerDiscoByEndpoint != workByServerDiscoVNI { + t.Fatal("workByServerDiscoByEndpoint != workByServerDiscoVNI") + } + } + + // cleanup + for _, event := range tt.events { + rm.stopWorkRunLoop(event.wlb.ep) + } + wantHandshakeWorkCount(t, rm, 0) + }) + } +} diff --git a/wgengine/netlog/logger.go b/wgengine/netlog/logger.go deleted file mode 100644 index 3a696b246..000000000 --- a/wgengine/netlog/logger.go +++ /dev/null @@ -1,274 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package netlog provides a logger that monitors a TUN device and -// periodically records any traffic into a log stream. -package netlog - -import ( - "context" - "encoding/json" - "fmt" - "io" - "log" - "net/http" - "net/netip" - "sync" - "time" - - "tailscale.com/health" - "tailscale.com/logpolicy" - "tailscale.com/logtail" - "tailscale.com/net/connstats" - "tailscale.com/net/netmon" - "tailscale.com/net/sockstats" - "tailscale.com/net/tsaddr" - "tailscale.com/tailcfg" - "tailscale.com/types/logid" - "tailscale.com/types/netlogtype" - "tailscale.com/util/multierr" - "tailscale.com/wgengine/router" -) - -// pollPeriod specifies how often to poll for network traffic. -const pollPeriod = 5 * time.Second - -// Device is an abstraction over a tunnel device or a magic socket. -// Both *tstun.Wrapper and *magicsock.Conn implement this interface. -type Device interface { - SetStatistics(*connstats.Statistics) -} - -type noopDevice struct{} - -func (noopDevice) SetStatistics(*connstats.Statistics) {} - -// Logger logs statistics about every connection. -// At present, it only logs connections within a tailscale network. -// Exit node traffic is not logged for privacy reasons. -// The zero value is ready for use. -type Logger struct { - mu sync.Mutex // protects all fields below - - logger *logtail.Logger - stats *connstats.Statistics - tun Device - sock Device - - addrs map[netip.Addr]bool - prefixes map[netip.Prefix]bool -} - -// Running reports whether the logger is running. -func (nl *Logger) Running() bool { - nl.mu.Lock() - defer nl.mu.Unlock() - return nl.logger != nil -} - -var testClient *http.Client - -// Startup starts an asynchronous network logger that monitors -// statistics for the provided tun and/or sock device. -// -// The tun Device captures packets within the tailscale network, -// where at least one address is a tailscale IP address. -// The source is always from the perspective of the current node. -// If one of the other endpoint is not a tailscale IP address, -// then it suggests the use of a subnet router or exit node. -// For example, when using a subnet router, the source address is -// the tailscale IP address of the current node, and -// the destination address is an IP address within the subnet range. -// In contrast, when acting as a subnet router, the source address is -// an IP address within the subnet range, and the destination is a -// tailscale IP address that initiated the subnet proxy connection. -// In this case, the node acting as a subnet router is acting on behalf -// of some remote endpoint within the subnet range. -// The tun is used to populate the VirtualTraffic, SubnetTraffic, -// and ExitTraffic fields in Message. -// -// The sock Device captures packets at the magicsock layer. -// The source is always a tailscale IP address and the destination -// is a non-tailscale IP address to contact for that particular tailscale node. -// The IP protocol and source port are always zero. -// The sock is used to populated the PhysicalTraffic field in Message. -// The netMon parameter is optional; if non-nil it's used to do faster interface lookups. -func (nl *Logger) Startup(nodeID tailcfg.StableNodeID, nodeLogID, domainLogID logid.PrivateID, tun, sock Device, netMon *netmon.Monitor, health *health.Tracker, logExitFlowEnabledEnabled bool) error { - nl.mu.Lock() - defer nl.mu.Unlock() - if nl.logger != nil { - return fmt.Errorf("network logger already running for %v", nl.logger.PrivateID().Public()) - } - - // Startup a log stream to Tailscale's logging service. - logf := log.Printf - httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon, health, logf)} - if testClient != nil { - httpc = testClient - } - nl.logger = logtail.NewLogger(logtail.Config{ - Collection: "tailtraffic.log.tailscale.io", - PrivateID: nodeLogID, - CopyPrivateID: domainLogID, - Stderr: io.Discard, - CompressLogs: true, - HTTPC: httpc, - // TODO(joetsai): Set Buffer? Use an in-memory buffer for now. - - // Include process sequence numbers to identify missing samples. - IncludeProcID: true, - IncludeProcSequence: true, - }, logf) - nl.logger.SetSockstatsLabel(sockstats.LabelNetlogLogger) - - // Startup a data structure to track per-connection statistics. - // There is a maximum size for individual log messages that logtail - // can upload to the Tailscale log service, so stay below this limit. - const maxLogSize = 256 << 10 - const maxConns = (maxLogSize - netlogtype.MaxMessageJSONSize) / netlogtype.MaxConnectionCountsJSONSize - nl.stats = connstats.NewStatistics(pollPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) { - nl.mu.Lock() - addrs := nl.addrs - prefixes := nl.prefixes - nl.mu.Unlock() - recordStatistics(nl.logger, nodeID, start, end, virtual, physical, addrs, prefixes, logExitFlowEnabledEnabled) - }) - - // Register the connection tracker into the TUN device. - if tun == nil { - tun = noopDevice{} - } - nl.tun = tun - nl.tun.SetStatistics(nl.stats) - - // Register the connection tracker into magicsock. - if sock == nil { - sock = noopDevice{} - } - nl.sock = sock - nl.sock.SetStatistics(nl.stats) - - return nil -} - -func recordStatistics(logger *logtail.Logger, nodeID tailcfg.StableNodeID, start, end time.Time, connstats, sockStats map[netlogtype.Connection]netlogtype.Counts, addrs map[netip.Addr]bool, prefixes map[netip.Prefix]bool, logExitFlowEnabled bool) { - m := netlogtype.Message{NodeID: nodeID, Start: start.UTC(), End: end.UTC()} - - classifyAddr := func(a netip.Addr) (isTailscale, withinRoute bool) { - // NOTE: There could be mis-classifications where an address is treated - // as a Tailscale IP address because the subnet range overlaps with - // the subnet range that Tailscale IP addresses are allocated from. - // This should never happen for IPv6, but could happen for IPv4. - withinRoute = addrs[a] - for p := range prefixes { - if p.Contains(a) && p.Bits() > 0 { - withinRoute = true - break - } - } - return withinRoute && tsaddr.IsTailscaleIP(a), withinRoute && !tsaddr.IsTailscaleIP(a) - } - - exitTraffic := make(map[netlogtype.Connection]netlogtype.Counts) - for conn, cnts := range connstats { - srcIsTailscaleIP, srcWithinSubnet := classifyAddr(conn.Src.Addr()) - dstIsTailscaleIP, dstWithinSubnet := classifyAddr(conn.Dst.Addr()) - switch { - case srcIsTailscaleIP && dstIsTailscaleIP: - m.VirtualTraffic = append(m.VirtualTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) - case srcWithinSubnet || dstWithinSubnet: - m.SubnetTraffic = append(m.SubnetTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) - default: - const anonymize = true - if anonymize && !logExitFlowEnabled { - // Only preserve the address if it is a Tailscale IP address. - srcOrig, dstOrig := conn.Src, conn.Dst - conn = netlogtype.Connection{} // scrub everything by default - if srcIsTailscaleIP { - conn.Src = netip.AddrPortFrom(srcOrig.Addr(), 0) - } - if dstIsTailscaleIP { - conn.Dst = netip.AddrPortFrom(dstOrig.Addr(), 0) - } - } - exitTraffic[conn] = exitTraffic[conn].Add(cnts) - } - } - for conn, cnts := range exitTraffic { - m.ExitTraffic = append(m.ExitTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) - } - for conn, cnts := range sockStats { - m.PhysicalTraffic = append(m.PhysicalTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) - } - - if len(m.VirtualTraffic)+len(m.SubnetTraffic)+len(m.ExitTraffic)+len(m.PhysicalTraffic) > 0 { - if b, err := json.Marshal(m); err != nil { - logger.Logf("json.Marshal error: %v", err) - } else { - logger.Logf("%s", b) - } - } -} - -func makeRouteMaps(cfg *router.Config) (addrs map[netip.Addr]bool, prefixes map[netip.Prefix]bool) { - addrs = make(map[netip.Addr]bool) - for _, p := range cfg.LocalAddrs { - if p.IsSingleIP() { - addrs[p.Addr()] = true - } - } - prefixes = make(map[netip.Prefix]bool) - insertPrefixes := func(rs []netip.Prefix) { - for _, p := range rs { - if p.IsSingleIP() { - addrs[p.Addr()] = true - } else { - prefixes[p] = true - } - } - } - insertPrefixes(cfg.Routes) - insertPrefixes(cfg.SubnetRoutes) - return addrs, prefixes -} - -// ReconfigRoutes configures the network logger with updated routes. -// The cfg is used to classify the types of connections captured by -// the tun Device passed to Startup. -func (nl *Logger) ReconfigRoutes(cfg *router.Config) { - nl.mu.Lock() - defer nl.mu.Unlock() - // TODO(joetsai): There is a race where deleted routes are not known at - // the time of extraction. We need to keep old routes around for a bit. - nl.addrs, nl.prefixes = makeRouteMaps(cfg) -} - -// Shutdown shuts down the network logger. -// This attempts to flush out all pending log messages. -// Even if an error is returned, the logger is still shut down. -func (nl *Logger) Shutdown(ctx context.Context) error { - nl.mu.Lock() - defer nl.mu.Unlock() - if nl.logger == nil { - return nil - } - - // Shutdown in reverse order of Startup. - // Do not hold lock while shutting down since this may flush one last time. - nl.mu.Unlock() - nl.sock.SetStatistics(nil) - nl.tun.SetStatistics(nil) - err1 := nl.stats.Shutdown(ctx) - err2 := nl.logger.Shutdown(ctx) - nl.mu.Lock() - - // Purge state. - nl.logger = nil - nl.stats = nil - nl.tun = nil - nl.sock = nil - nl.addrs = nil - nl.prefixes = nil - - return multierr.New(err1, err2) -} diff --git a/wgengine/netlog/netlog.go b/wgengine/netlog/netlog.go new file mode 100644 index 000000000..12fe9c797 --- /dev/null +++ b/wgengine/netlog/netlog.go @@ -0,0 +1,494 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netlog && !ts_omit_logtail + +// Package netlog provides a logger that monitors a TUN device and +// periodically records any traffic into a log stream. +package netlog + +import ( + "cmp" + "context" + "fmt" + "io" + "log" + "net/http" + "net/netip" + "time" + + "tailscale.com/health" + "tailscale.com/logpolicy" + "tailscale.com/logtail" + "tailscale.com/net/netmon" + "tailscale.com/net/sockstats" + "tailscale.com/net/tsaddr" + "tailscale.com/syncs" + "tailscale.com/types/ipproto" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/types/netlogfunc" + "tailscale.com/types/netlogtype" + "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" + "tailscale.com/util/set" + "tailscale.com/wgengine/router" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" +) + +// pollPeriod specifies how often to poll for network traffic. +const pollPeriod = 5 * time.Second + +// Device is an abstraction over a tunnel device or a magic socket. +// Both *tstun.Wrapper and *magicsock.Conn implement this interface. +type Device interface { + SetConnectionCounter(netlogfunc.ConnectionCounter) +} + +type noopDevice struct{} + +func (noopDevice) SetConnectionCounter(netlogfunc.ConnectionCounter) {} + +// Logger logs statistics about every connection. +// At present, it only logs connections within a tailscale network. +// By default, exit node traffic is not logged for privacy reasons +// unless the Tailnet administrator opts-into explicit logging. +// The zero value is ready for use. +type Logger struct { + mu syncs.Mutex // protects all fields below + logf logger.Logf + + // shutdownLocked shuts down the logger. + // The mutex must be held when calling. + shutdownLocked func(context.Context) error + + record record // the current record of network connection flows + recordLen int // upper bound on JSON length of record + recordsChan chan record // set to nil when shutdown + flushTimer *time.Timer // fires when record should flush to recordsChan + + // Information about Tailscale nodes. + // These are read-only once updated by ReconfigNetworkMap. + selfNode nodeUser + allNodes map[netip.Addr]nodeUser // includes selfNode; nodeUser values are always valid + + // Information about routes. + // These are read-only once updated by ReconfigRoutes. + routeAddrs set.Set[netip.Addr] + routePrefixes []netip.Prefix +} + +// Running reports whether the logger is running. +func (nl *Logger) Running() bool { + nl.mu.Lock() + defer nl.mu.Unlock() + return nl.shutdownLocked != nil +} + +var testClient *http.Client + +// Startup starts an asynchronous network logger that monitors +// statistics for the provided tun and/or sock device. +// +// The tun [Device] captures packets within the tailscale network, +// where at least one address is usually a tailscale IP address. +// The source is usually from the perspective of the current node. +// If one of the other endpoint is not a tailscale IP address, +// then it suggests the use of a subnet router or exit node. +// For example, when using a subnet router, the source address is +// the tailscale IP address of the current node, and +// the destination address is an IP address within the subnet range. +// In contrast, when acting as a subnet router, the source address is +// an IP address within the subnet range, and the destination is a +// tailscale IP address that initiated the subnet proxy connection. +// In this case, the node acting as a subnet router is acting on behalf +// of some remote endpoint within the subnet range. +// The tun is used to populate the VirtualTraffic, SubnetTraffic, +// and ExitTraffic fields in [netlogtype.Message]. +// +// The sock [Device] captures packets at the magicsock layer. +// The source is always a tailscale IP address and the destination +// is a non-tailscale IP address to contact for that particular tailscale node. +// The IP protocol and source port are always zero. +// The sock is used to populated the PhysicalTraffic field in [netlogtype.Message]. +// +// The netMon parameter is optional; if non-nil it's used to do faster interface lookups. +func (nl *Logger) Startup(logf logger.Logf, nm *netmap.NetworkMap, nodeLogID, domainLogID logid.PrivateID, tun, sock Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus, logExitFlowEnabledEnabled bool) error { + nl.mu.Lock() + defer nl.mu.Unlock() + + if nl.shutdownLocked != nil { + return fmt.Errorf("network logger already running") + } + nl.selfNode, nl.allNodes = makeNodeMaps(nm) + + // Startup a log stream to Tailscale's logging service. + if logf == nil { + logf = log.Printf + } + httpc := &http.Client{Transport: logpolicy.NewLogtailTransport(logtail.DefaultHost, netMon, health, logf)} + if testClient != nil { + httpc = testClient + } + logger := logtail.NewLogger(logtail.Config{ + Collection: "tailtraffic.log.tailscale.io", + PrivateID: nodeLogID, + CopyPrivateID: domainLogID, + Bus: bus, + Stderr: io.Discard, + CompressLogs: true, + HTTPC: httpc, + // TODO(joetsai): Set Buffer? Use an in-memory buffer for now. + + // Include process sequence numbers to identify missing samples. + IncludeProcID: true, + IncludeProcSequence: true, + }, logf) + logger.SetSockstatsLabel(sockstats.LabelNetlogLogger) + + // Register the connection tracker into the TUN device. + tun = cmp.Or[Device](tun, noopDevice{}) + tun.SetConnectionCounter(nl.updateVirtConn) + + // Register the connection tracker into magicsock. + sock = cmp.Or[Device](sock, noopDevice{}) + sock.SetConnectionCounter(nl.updatePhysConn) + + // Startup a goroutine to record log messages. + // This is done asynchronously so that the cost of serializing + // the network flow log message never stalls processing of packets. + nl.record = record{} + nl.recordLen = 0 + nl.recordsChan = make(chan record, 100) + recorderDone := make(chan struct{}) + go func(recordsChan chan record) { + defer close(recorderDone) + for rec := range recordsChan { + msg := rec.toMessage(false, !logExitFlowEnabledEnabled) + if b, err := jsonv2.Marshal(msg, jsontext.AllowInvalidUTF8(true)); err != nil { + if nl.logf != nil { + nl.logf("netlog: json.Marshal error: %v", err) + } + } else { + logger.Logf("%s", b) + } + } + }(nl.recordsChan) + + // Register the mechanism for shutting down. + nl.shutdownLocked = func(ctx context.Context) error { + tun.SetConnectionCounter(nil) + sock.SetConnectionCounter(nil) + + // Flush and process all pending records. + nl.flushRecordLocked() + close(nl.recordsChan) + nl.recordsChan = nil + <-recorderDone + recorderDone = nil + + // Try to upload all pending records. + err := logger.Shutdown(ctx) + + // Purge state. + nl.shutdownLocked = nil + nl.selfNode = nodeUser{} + nl.allNodes = nil + nl.routeAddrs = nil + nl.routePrefixes = nil + + return err + } + + return nil +} + +var ( + tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP() + tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6() +) + +func (nl *Logger) updateVirtConn(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, recv bool) { + // Network logging is defined as traffic between two Tailscale nodes. + // Traffic with the internal Tailscale service is not with another node + // and should not be logged. It also happens to be a high volume + // amount of discrete traffic flows (e.g., DNS lookups). + switch dst.Addr() { + case tailscaleServiceIPv4, tailscaleServiceIPv6: + return + } + + nl.mu.Lock() + defer nl.mu.Unlock() + + // Lookup the connection and increment the counts. + nl.initRecordLocked() + conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst} + cnts, found := nl.record.virtConns[conn] + if !found { + cnts.connType = nl.addNewVirtConnLocked(conn) + } + if recv { + cnts.RxPackets += uint64(packets) + cnts.RxBytes += uint64(bytes) + } else { + cnts.TxPackets += uint64(packets) + cnts.TxBytes += uint64(bytes) + } + nl.record.virtConns[conn] = cnts +} + +// addNewVirtConnLocked adds the first insertion of a physical connection. +// The [Logger.mu] must be held. +func (nl *Logger) addNewVirtConnLocked(c netlogtype.Connection) connType { + // Check whether this is the first insertion of the src and dst node. + // If so, compute the additional JSON bytes that would be added + // to the record for the node information. + var srcNodeLen, dstNodeLen int + srcNode, srcSeen := nl.record.seenNodes[c.Src.Addr()] + if !srcSeen { + srcNode = nl.allNodes[c.Src.Addr()] + if srcNode.Valid() { + srcNodeLen = srcNode.jsonLen() + } + } + dstNode, dstSeen := nl.record.seenNodes[c.Dst.Addr()] + if !dstSeen { + dstNode = nl.allNodes[c.Dst.Addr()] + if dstNode.Valid() { + dstNodeLen = dstNode.jsonLen() + } + } + + // Check whether the additional [netlogtype.ConnectionCounts] + // and [netlogtype.Node] information would exceed [maxLogSize]. + if nl.recordLen+netlogtype.MaxConnectionCountsJSONSize+srcNodeLen+dstNodeLen > maxLogSize { + nl.flushRecordLocked() + nl.initRecordLocked() + } + + // Insert newly seen src and/or dst nodes. + if !srcSeen && srcNode.Valid() { + nl.record.seenNodes[c.Src.Addr()] = srcNode + } + if !dstSeen && dstNode.Valid() { + nl.record.seenNodes[c.Dst.Addr()] = dstNode + } + nl.recordLen += netlogtype.MaxConnectionCountsJSONSize + srcNodeLen + dstNodeLen + + // Classify the traffic type. + var srcIsSelfNode bool + if nl.selfNode.Valid() { + srcIsSelfNode = nl.selfNode.Addresses().ContainsFunc(func(p netip.Prefix) bool { + return c.Src.Addr() == p.Addr() && p.IsSingleIP() + }) + } + switch { + case srcIsSelfNode && dstNode.Valid(): + return virtualTraffic + case srcIsSelfNode: + // TODO: Should we swap src for the node serving as the proxy? + // It is relatively useless always using the self IP address. + if nl.withinRoutesLocked(c.Dst.Addr()) { + return subnetTraffic // a client using another subnet router + } else { + return exitTraffic // a client using exit an exit node + } + case dstNode.Valid(): + if nl.withinRoutesLocked(c.Src.Addr()) { + return subnetTraffic // serving as a subnet router + } else { + return exitTraffic // serving as an exit node + } + default: + return unknownTraffic + } +} + +func (nl *Logger) updatePhysConn(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, recv bool) { + nl.mu.Lock() + defer nl.mu.Unlock() + + // Lookup the connection and increment the counts. + nl.initRecordLocked() + conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst} + cnts, found := nl.record.physConns[conn] + if !found { + nl.addNewPhysConnLocked(conn) + } + if recv { + cnts.RxPackets += uint64(packets) + cnts.RxBytes += uint64(bytes) + } else { + cnts.TxPackets += uint64(packets) + cnts.TxBytes += uint64(bytes) + } + nl.record.physConns[conn] = cnts +} + +// addNewPhysConnLocked adds the first insertion of a physical connection. +// The [Logger.mu] must be held. +func (nl *Logger) addNewPhysConnLocked(c netlogtype.Connection) { + // Check whether this is the first insertion of the src node. + var srcNodeLen int + srcNode, srcSeen := nl.record.seenNodes[c.Src.Addr()] + if !srcSeen { + srcNode = nl.allNodes[c.Src.Addr()] + if srcNode.Valid() { + srcNodeLen = srcNode.jsonLen() + } + } + + // Check whether the additional [netlogtype.ConnectionCounts] + // and [netlogtype.Node] information would exceed [maxLogSize]. + if nl.recordLen+netlogtype.MaxConnectionCountsJSONSize+srcNodeLen > maxLogSize { + nl.flushRecordLocked() + nl.initRecordLocked() + } + + // Insert newly seen src and/or dst nodes. + if !srcSeen && srcNode.Valid() { + nl.record.seenNodes[c.Src.Addr()] = srcNode + } + nl.recordLen += netlogtype.MaxConnectionCountsJSONSize + srcNodeLen +} + +// initRecordLocked initialize the current record if uninitialized. +// The [Logger.mu] must be held. +func (nl *Logger) initRecordLocked() { + if nl.recordLen != 0 { + return + } + nl.record = record{ + selfNode: nl.selfNode, + start: time.Now().UTC(), + seenNodes: make(map[netip.Addr]nodeUser), + virtConns: make(map[netlogtype.Connection]countsType), + physConns: make(map[netlogtype.Connection]netlogtype.Counts), + } + nl.recordLen = netlogtype.MinMessageJSONSize + nl.selfNode.jsonLen() + + // Start a time to auto-flush the record. + // Avoid tickers since continually waking up a goroutine + // is expensive on battery powered devices. + nl.flushTimer = time.AfterFunc(pollPeriod, func() { + nl.mu.Lock() + defer nl.mu.Unlock() + if !nl.record.start.IsZero() && time.Since(nl.record.start) > pollPeriod/2 { + nl.flushRecordLocked() + } + }) +} + +// flushRecordLocked flushes the current record if initialized. +// The [Logger.mu] must be held. +func (nl *Logger) flushRecordLocked() { + if nl.recordLen == 0 { + return + } + nl.record.end = time.Now().UTC() + if nl.recordsChan != nil { + select { + case nl.recordsChan <- nl.record: + default: + if nl.logf != nil { + nl.logf("netlog: dropped record due to processing backlog") + } + } + } + if nl.flushTimer != nil { + nl.flushTimer.Stop() + nl.flushTimer = nil + } + nl.record = record{} + nl.recordLen = 0 +} + +func makeNodeMaps(nm *netmap.NetworkMap) (selfNode nodeUser, allNodes map[netip.Addr]nodeUser) { + if nm == nil { + return + } + allNodes = make(map[netip.Addr]nodeUser) + if nm.SelfNode.Valid() { + selfNode = nodeUser{nm.SelfNode, nm.UserProfiles[nm.SelfNode.User()]} + for _, addr := range nm.SelfNode.Addresses().All() { + if addr.IsSingleIP() { + allNodes[addr.Addr()] = selfNode + } + } + } + for _, peer := range nm.Peers { + if peer.Valid() { + for _, addr := range peer.Addresses().All() { + if addr.IsSingleIP() { + allNodes[addr.Addr()] = nodeUser{peer, nm.UserProfiles[peer.User()]} + } + } + } + } + return selfNode, allNodes +} + +// ReconfigNetworkMap configures the network logger with an updated netmap. +func (nl *Logger) ReconfigNetworkMap(nm *netmap.NetworkMap) { + selfNode, allNodes := makeNodeMaps(nm) // avoid holding lock while making maps + nl.mu.Lock() + nl.selfNode, nl.allNodes = selfNode, allNodes + nl.mu.Unlock() +} + +func makeRouteMaps(cfg *router.Config) (addrs set.Set[netip.Addr], prefixes []netip.Prefix) { + addrs = make(set.Set[netip.Addr]) + insertPrefixes := func(rs []netip.Prefix) { + for _, p := range rs { + if p.IsSingleIP() { + addrs.Add(p.Addr()) + } else { + prefixes = append(prefixes, p) + } + } + } + insertPrefixes(cfg.LocalAddrs) + insertPrefixes(cfg.Routes) + insertPrefixes(cfg.SubnetRoutes) + return addrs, prefixes +} + +// ReconfigRoutes configures the network logger with updated routes. +// The cfg is used to classify the types of connections captured by +// the tun Device passed to Startup. +func (nl *Logger) ReconfigRoutes(cfg *router.Config) { + addrs, prefixes := makeRouteMaps(cfg) // avoid holding lock while making maps + nl.mu.Lock() + nl.routeAddrs, nl.routePrefixes = addrs, prefixes + nl.mu.Unlock() +} + +// withinRoutesLocked reports whether a is within the configured routes, +// which should only contain Tailscale addresses and subnet routes. +// The [Logger.mu] must be held. +func (nl *Logger) withinRoutesLocked(a netip.Addr) bool { + if nl.routeAddrs.Contains(a) { + return true + } + for _, p := range nl.routePrefixes { + if p.Contains(a) && p.Bits() > 0 { + return true + } + } + return false +} + +// Shutdown shuts down the network logger. +// This attempts to flush out all pending log messages. +// Even if an error is returned, the logger is still shut down. +func (nl *Logger) Shutdown(ctx context.Context) error { + nl.mu.Lock() + defer nl.mu.Unlock() + if nl.shutdownLocked == nil { + return nil + } + return nl.shutdownLocked(ctx) +} diff --git a/wgengine/netlog/netlog_omit.go b/wgengine/netlog/netlog_omit.go new file mode 100644 index 000000000..03610a1ef --- /dev/null +++ b/wgengine/netlog/netlog_omit.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_netlog || ts_omit_logtail + +package netlog + +type Logger struct{} + +func (*Logger) Startup(...any) error { return nil } +func (*Logger) Running() bool { return false } +func (*Logger) Shutdown(any) error { return nil } +func (*Logger) ReconfigNetworkMap(any) {} +func (*Logger) ReconfigRoutes(any) {} diff --git a/wgengine/netlog/netlog_test.go b/wgengine/netlog/netlog_test.go new file mode 100644 index 000000000..b4758c7ec --- /dev/null +++ b/wgengine/netlog/netlog_test.go @@ -0,0 +1,237 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netlog && !ts_omit_logtail + +package netlog + +import ( + "encoding/binary" + "math/rand/v2" + "net/netip" + "sync" + "testing" + "testing/synctest" + "time" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/bools" + "tailscale.com/types/ipproto" + "tailscale.com/types/netlogtype" + "tailscale.com/types/netmap" + "tailscale.com/wgengine/router" +) + +func TestEmbedNodeInfo(t *testing.T) { + // Initialize the logger with a particular view of the netmap. + var logger Logger + logger.ReconfigNetworkMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: "n123456CNTL", + ID: 123456, + Name: "test.tail123456.ts.net", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Tags: []string{"tag:foo", "tag:bar"}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + StableID: "n123457CNTL", + ID: 123457, + Name: "peer1.tail123456.ts.net", + Addresses: []netip.Prefix{prefix("100.1.2.4")}, + Tags: []string{"tag:peer"}, + }).View(), + (&tailcfg.Node{ + StableID: "n123458CNTL", + ID: 123458, + Name: "peer2.tail123456.ts.net", + Addresses: []netip.Prefix{prefix("100.1.2.5")}, + User: 54321, + }).View(), + }, + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + 54321: (&tailcfg.UserProfile{ID: 54321, LoginName: "peer@example.com"}).View(), + }, + }) + logger.ReconfigRoutes(&router.Config{ + SubnetRoutes: []netip.Prefix{ + prefix("172.16.1.1/16"), + prefix("192.168.1.1/24"), + }, + }) + + // Update the counters for a few connections. + var group sync.WaitGroup + defer group.Wait() + conns := []struct { + virt bool + proto ipproto.Proto + src, dst netip.AddrPort + txP, txB, rxP, rxB int + }{ + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("100.1.2.4:1812"), 88, 278, 34, 887}, + {true, 0x6, addrPort("100.1.2.3:443"), addrPort("100.1.2.5:1742"), 96, 635, 23, 790}, + {true, 0x6, addrPort("100.1.2.3:443"), addrPort("100.1.2.6:1175"), 48, 94, 86, 618}, // unknown peer (in Tailscale IP space, but not a known peer) + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("192.168.1.241:713"), 43, 154, 66, 883}, + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("192.168.2.241:713"), 43, 154, 66, 883}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("172.16.5.18:713"), 7, 243, 40, 59}, + {true, 0x6, addrPort("100.1.2.3:80"), addrPort("172.20.5.18:713"), 61, 753, 42, 492}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("192.168.1.241:713"), addrPort("100.1.2.3:80"), 43, 154, 66, 883}, + {true, 0x6, addrPort("192.168.2.241:713"), addrPort("100.1.2.3:80"), 43, 154, 66, 883}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("172.16.5.18:713"), addrPort("100.1.2.3:80"), 7, 243, 40, 59}, + {true, 0x6, addrPort("172.20.5.18:713"), addrPort("100.1.2.3:80"), 61, 753, 42, 492}, // not in the subnet, must be exit traffic + {true, 0x6, addrPort("14.255.192.128:39230"), addrPort("243.42.106.193:48206"), 81, 791, 79, 316}, // unknown connection + {false, 0x6, addrPort("100.1.2.4:0"), addrPort("35.92.180.165:9743"), 63, 136, 61, 409}, // physical traffic with peer1 + {false, 0x6, addrPort("100.1.2.5:0"), addrPort("131.19.35.17:9743"), 88, 452, 2, 716}, // physical traffic with peer2 + } + for range 10 { + for _, conn := range conns { + update := bools.IfElse(conn.virt, logger.updateVirtConn, logger.updatePhysConn) + group.Go(func() { update(conn.proto, conn.src, conn.dst, conn.txP, conn.txB, false) }) + group.Go(func() { update(conn.proto, conn.src, conn.dst, conn.rxP, conn.rxB, true) }) + } + } + group.Wait() + + // Verify that the counters match. + got := logger.record.toMessage(false, false) + got.Start = time.Time{} // avoid flakiness + want := netlogtype.Message{ + NodeID: "n123456CNTL", + SrcNode: netlogtype.Node{ + NodeID: "n123456CNTL", + Name: "test.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.3")}, + Tags: []string{"tag:bar", "tag:foo"}, + }, + DstNodes: []netlogtype.Node{{ + NodeID: "n123457CNTL", + Name: "peer1.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.4")}, + Tags: []string{"tag:peer"}, + }, { + NodeID: "n123458CNTL", + Name: "peer2.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.5")}, + User: "peer@example.com", + }}, + VirtualTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "100.1.2.3:80", "100.1.2.4:1812"), Counts: counts(880, 2780, 340, 8870)}, + {Connection: conn(0x6, "100.1.2.3:443", "100.1.2.5:1742"), Counts: counts(960, 6350, 230, 7900)}, + }, + SubnetTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "100.1.2.3:80", "172.16.5.18:713"), Counts: counts(70, 2430, 400, 590)}, + {Connection: conn(0x6, "100.1.2.3:80", "192.168.1.241:713"), Counts: counts(430, 1540, 660, 8830)}, + {Connection: conn(0x6, "172.16.5.18:713", "100.1.2.3:80"), Counts: counts(70, 2430, 400, 590)}, + {Connection: conn(0x6, "192.168.1.241:713", "100.1.2.3:80"), Counts: counts(430, 1540, 660, 8830)}, + }, + ExitTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "14.255.192.128:39230", "243.42.106.193:48206"), Counts: counts(810, 7910, 790, 3160)}, + {Connection: conn(0x6, "100.1.2.3:80", "172.20.5.18:713"), Counts: counts(610, 7530, 420, 4920)}, + {Connection: conn(0x6, "100.1.2.3:80", "192.168.2.241:713"), Counts: counts(430, 1540, 660, 8830)}, + {Connection: conn(0x6, "100.1.2.3:443", "100.1.2.6:1175"), Counts: counts(480, 940, 860, 6180)}, + {Connection: conn(0x6, "172.20.5.18:713", "100.1.2.3:80"), Counts: counts(610, 7530, 420, 4920)}, + {Connection: conn(0x6, "192.168.2.241:713", "100.1.2.3:80"), Counts: counts(430, 1540, 660, 8830)}, + }, + PhysicalTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "100.1.2.4:0", "35.92.180.165:9743"), Counts: counts(630, 1360, 610, 4090)}, + {Connection: conn(0x6, "100.1.2.5:0", "131.19.35.17:9743"), Counts: counts(880, 4520, 20, 7160)}, + }, + } + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("Message (-got +want):\n%s", d) + } +} + +func TestUpdateRace(t *testing.T) { + var logger Logger + logger.recordsChan = make(chan record, 1) + go func(recordsChan chan record) { + for range recordsChan { + } + }(logger.recordsChan) + + var group sync.WaitGroup + defer group.Wait() + for i := range 1000 { + group.Go(func() { + src, dst := randAddrPort(), randAddrPort() + for j := range 1000 { + if i%2 == 0 { + logger.updateVirtConn(0x1, src, dst, rand.IntN(10), rand.IntN(1000), j%2 == 0) + } else { + logger.updatePhysConn(0x1, src, dst, rand.IntN(10), rand.IntN(1000), j%2 == 0) + } + } + }) + group.Go(func() { + for range 1000 { + logger.ReconfigNetworkMap(new(netmap.NetworkMap)) + } + }) + group.Go(func() { + for range 1000 { + logger.ReconfigRoutes(new(router.Config)) + } + }) + } + + group.Wait() + logger.mu.Lock() + close(logger.recordsChan) + logger.recordsChan = nil + logger.mu.Unlock() +} + +func randAddrPort() netip.AddrPort { + var b [4]uint8 + binary.LittleEndian.PutUint32(b[:], rand.Uint32()) + return netip.AddrPortFrom(netip.AddrFrom4(b), uint16(rand.Uint32())) +} + +func TestAutoFlushMaxConns(t *testing.T) { + var logger Logger + logger.recordsChan = make(chan record, 1) + for i := 0; len(logger.recordsChan) == 0; i++ { + logger.updateVirtConn(0, netip.AddrPortFrom(netip.Addr{}, uint16(i)), netip.AddrPort{}, 1, 1, false) + } + b, _ := jsonv2.Marshal(logger.recordsChan) + if len(b) > maxLogSize { + t.Errorf("len(Message) = %v, want <= %d", len(b), maxLogSize) + } +} + +func TestAutoFlushTimeout(t *testing.T) { + var logger Logger + logger.recordsChan = make(chan record, 1) + synctest.Test(t, func(t *testing.T) { + logger.updateVirtConn(0, netip.AddrPort{}, netip.AddrPort{}, 1, 1, false) + time.Sleep(pollPeriod) + }) + rec := <-logger.recordsChan + if d := rec.end.Sub(rec.start); d != pollPeriod { + t.Errorf("window = %v, want %v", d, pollPeriod) + } + if len(rec.virtConns) != 1 { + t.Errorf("len(virtConns) = %d, want 1", len(rec.virtConns)) + } +} + +func BenchmarkUpdateSameConn(b *testing.B) { + var logger Logger + b.ReportAllocs() + for range b.N { + logger.updateVirtConn(0, netip.AddrPort{}, netip.AddrPort{}, 1, 1, false) + } +} + +func BenchmarkUpdateNewConns(b *testing.B) { + var logger Logger + b.ReportAllocs() + for i := range b.N { + logger.updateVirtConn(0, netip.AddrPortFrom(netip.Addr{}, uint16(i)), netip.AddrPort{}, 1, 1, false) + } +} diff --git a/wgengine/netlog/record.go b/wgengine/netlog/record.go new file mode 100644 index 000000000..25b6b1148 --- /dev/null +++ b/wgengine/netlog/record.go @@ -0,0 +1,218 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netlog && !ts_omit_logtail + +package netlog + +import ( + "cmp" + "net/netip" + "slices" + "strings" + "time" + "unicode/utf8" + + "tailscale.com/tailcfg" + "tailscale.com/types/bools" + "tailscale.com/types/netlogtype" + "tailscale.com/util/set" +) + +// maxLogSize is the maximum number of bytes for a log message. +const maxLogSize = 256 << 10 + +// record is the in-memory representation of a [netlogtype.Message]. +// It uses maps to efficiently look-up addresses and connections. +// In contrast, [netlogtype.Message] is designed to be JSON serializable, +// where complex keys types are not well support in JSON objects. +type record struct { + selfNode nodeUser + + start time.Time + end time.Time + + seenNodes map[netip.Addr]nodeUser + + virtConns map[netlogtype.Connection]countsType + physConns map[netlogtype.Connection]netlogtype.Counts +} + +// nodeUser is a node with additional user profile information. +type nodeUser struct { + tailcfg.NodeView + user tailcfg.UserProfileView // UserProfileView for NodeView.User +} + +// countsType is a counts with classification information about the connection. +type countsType struct { + netlogtype.Counts + connType connType +} + +type connType uint8 + +const ( + unknownTraffic connType = iota + virtualTraffic + subnetTraffic + exitTraffic +) + +// toMessage converts a [record] into a [netlogtype.Message]. +func (r record) toMessage(excludeNodeInfo, anonymizeExitTraffic bool) netlogtype.Message { + if !r.selfNode.Valid() { + return netlogtype.Message{} + } + + m := netlogtype.Message{ + NodeID: r.selfNode.StableID(), + Start: r.start.UTC(), + End: r.end.UTC(), + } + + // Convert node fields. + if !excludeNodeInfo { + m.SrcNode = r.selfNode.toNode() + seenIDs := set.Of(r.selfNode.ID()) + for _, node := range r.seenNodes { + if _, ok := seenIDs[node.ID()]; !ok && node.Valid() { + m.DstNodes = append(m.DstNodes, node.toNode()) + seenIDs.Add(node.ID()) + } + } + slices.SortFunc(m.DstNodes, func(x, y netlogtype.Node) int { + return cmp.Compare(x.NodeID, y.NodeID) + }) + } + + // Converter traffic fields. + anonymizedExitTraffic := make(map[netlogtype.Connection]netlogtype.Counts) + for conn, cnts := range r.virtConns { + switch cnts.connType { + case virtualTraffic: + m.VirtualTraffic = append(m.VirtualTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts.Counts}) + case subnetTraffic: + m.SubnetTraffic = append(m.SubnetTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts.Counts}) + default: + if anonymizeExitTraffic { + conn = netlogtype.Connection{ // scrub the IP protocol type + Src: netip.AddrPortFrom(conn.Src.Addr(), 0), // scrub the port number + Dst: netip.AddrPortFrom(conn.Dst.Addr(), 0), // scrub the port number + } + if !r.seenNodes[conn.Src.Addr()].Valid() { + conn.Src = netip.AddrPort{} // not a Tailscale node, so scrub the address + } + if !r.seenNodes[conn.Dst.Addr()].Valid() { + conn.Dst = netip.AddrPort{} // not a Tailscale node, so scrub the address + } + anonymizedExitTraffic[conn] = anonymizedExitTraffic[conn].Add(cnts.Counts) + continue + } + m.ExitTraffic = append(m.ExitTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts.Counts}) + } + } + for conn, cnts := range anonymizedExitTraffic { + m.ExitTraffic = append(m.ExitTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) + } + for conn, cnts := range r.physConns { + m.PhysicalTraffic = append(m.PhysicalTraffic, netlogtype.ConnectionCounts{Connection: conn, Counts: cnts}) + } + + // Sort the connections for deterministic results. + slices.SortFunc(m.VirtualTraffic, compareConnCnts) + slices.SortFunc(m.SubnetTraffic, compareConnCnts) + slices.SortFunc(m.ExitTraffic, compareConnCnts) + slices.SortFunc(m.PhysicalTraffic, compareConnCnts) + + return m +} + +func compareConnCnts(x, y netlogtype.ConnectionCounts) int { + return cmp.Or( + netip.AddrPort.Compare(x.Src, y.Src), + netip.AddrPort.Compare(x.Dst, y.Dst), + cmp.Compare(x.Proto, y.Proto)) +} + +// jsonLen computes an upper-bound on the size of the JSON representation. +func (nu nodeUser) jsonLen() (n int) { + if !nu.Valid() { + return len(`{"nodeId":""}`) + } + n += len(`{}`) + n += len(`"nodeId":`) + jsonQuotedLen(string(nu.StableID())) + len(`,`) + if len(nu.Name()) > 0 { + n += len(`"name":`) + jsonQuotedLen(nu.Name()) + len(`,`) + } + if nu.Addresses().Len() > 0 { + n += len(`"addresses":[]`) + for _, addr := range nu.Addresses().All() { + n += bools.IfElse(addr.Addr().Is4(), len(`"255.255.255.255"`), len(`"ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"`)) + len(",") + } + } + if nu.Hostinfo().Valid() && len(nu.Hostinfo().OS()) > 0 { + n += len(`"os":`) + jsonQuotedLen(nu.Hostinfo().OS()) + len(`,`) + } + if nu.Tags().Len() > 0 { + n += len(`"tags":[]`) + for _, tag := range nu.Tags().All() { + n += jsonQuotedLen(tag) + len(",") + } + } else if nu.user.Valid() && nu.user.ID() == nu.User() && len(nu.user.LoginName()) > 0 { + n += len(`"user":`) + jsonQuotedLen(nu.user.LoginName()) + len(",") + } + return n +} + +// toNode converts the [nodeUser] into a [netlogtype.Node]. +func (nu nodeUser) toNode() netlogtype.Node { + if !nu.Valid() { + return netlogtype.Node{} + } + n := netlogtype.Node{ + NodeID: nu.StableID(), + Name: strings.TrimSuffix(nu.Name(), "."), + } + var ipv4, ipv6 netip.Addr + for _, addr := range nu.Addresses().All() { + switch { + case addr.IsSingleIP() && addr.Addr().Is4(): + ipv4 = addr.Addr() + case addr.IsSingleIP() && addr.Addr().Is6(): + ipv6 = addr.Addr() + } + } + n.Addresses = []netip.Addr{ipv4, ipv6} + n.Addresses = slices.DeleteFunc(n.Addresses, func(a netip.Addr) bool { return !a.IsValid() }) + if nu.Hostinfo().Valid() { + n.OS = nu.Hostinfo().OS() + } + if nu.Tags().Len() > 0 { + n.Tags = nu.Tags().AsSlice() + slices.Sort(n.Tags) + n.Tags = slices.Compact(n.Tags) + } else if nu.user.Valid() && nu.user.ID() == nu.User() { + n.User = nu.user.LoginName() + } + return n +} + +// jsonQuotedLen computes the length of the JSON serialization of s +// according to [jsontext.AppendQuote]. +func jsonQuotedLen(s string) int { + n := len(`"`) + len(s) + len(`"`) + for i, r := range s { + switch { + case r == '\b', r == '\t', r == '\n', r == '\f', r == '\r', r == '"', r == '\\': + n += len(`\X`) - 1 + case r < ' ': + n += len(`\uXXXX`) - 1 + case r == utf8.RuneError: + if _, m := utf8.DecodeRuneInString(s[i:]); m == 1 { // exactly an invalid byte + n += len("īŋŊ") - 1 + } + } + } + return n +} diff --git a/wgengine/netlog/record_test.go b/wgengine/netlog/record_test.go new file mode 100644 index 000000000..ec0229534 --- /dev/null +++ b/wgengine/netlog/record_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_netlog && !ts_omit_logtail + +package netlog + +import ( + "net/netip" + "testing" + "time" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/types/netlogtype" + "tailscale.com/util/must" +) + +func addr(s string) netip.Addr { + if s == "" { + return netip.Addr{} + } + return must.Get(netip.ParseAddr(s)) +} +func addrPort(s string) netip.AddrPort { + if s == "" { + return netip.AddrPort{} + } + return must.Get(netip.ParseAddrPort(s)) +} +func prefix(s string) netip.Prefix { + if p, err := netip.ParsePrefix(s); err == nil { + return p + } + a := addr(s) + return netip.PrefixFrom(a, a.BitLen()) +} + +func conn(proto ipproto.Proto, src, dst string) netlogtype.Connection { + return netlogtype.Connection{Proto: proto, Src: addrPort(src), Dst: addrPort(dst)} +} + +func counts(txP, txB, rxP, rxB uint64) netlogtype.Counts { + return netlogtype.Counts{TxPackets: txP, TxBytes: txB, RxPackets: rxP, RxBytes: rxB} +} + +func TestToMessage(t *testing.T) { + rec := record{ + selfNode: nodeUser{NodeView: (&tailcfg.Node{ + ID: 123456, + StableID: "n123456CNTL", + Name: "src.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Tags: []string{"tag:src"}, + }).View()}, + start: time.Now(), + end: time.Now().Add(5 * time.Second), + + seenNodes: map[netip.Addr]nodeUser{ + addr("100.1.2.4"): {NodeView: (&tailcfg.Node{ + ID: 123457, + StableID: "n123457CNTL", + Name: "dst1.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.4")}, + Tags: []string{"tag:dst1"}, + }).View()}, + addr("100.1.2.5"): {NodeView: (&tailcfg.Node{ + ID: 123458, + StableID: "n123458CNTL", + Name: "dst2.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.5")}, + Tags: []string{"tag:dst2"}, + }).View()}, + }, + + virtConns: map[netlogtype.Connection]countsType{ + conn(0x1, "100.1.2.3:1234", "100.1.2.4:80"): {Counts: counts(12, 34, 56, 78), connType: virtualTraffic}, + conn(0x1, "100.1.2.3:1234", "100.1.2.5:80"): {Counts: counts(23, 45, 78, 790), connType: virtualTraffic}, + conn(0x6, "172.16.1.1:80", "100.1.2.4:1234"): {Counts: counts(91, 54, 723, 621), connType: subnetTraffic}, + conn(0x6, "172.16.1.2:443", "100.1.2.5:1234"): {Counts: counts(42, 813, 3, 1823), connType: subnetTraffic}, + conn(0x6, "172.16.1.3:80", "100.1.2.6:1234"): {Counts: counts(34, 52, 78, 790), connType: subnetTraffic}, + conn(0x6, "100.1.2.3:1234", "12.34.56.78:80"): {Counts: counts(11, 110, 10, 100), connType: exitTraffic}, + conn(0x6, "100.1.2.4:1234", "23.34.56.78:80"): {Counts: counts(423, 1, 6, 123), connType: exitTraffic}, + conn(0x6, "100.1.2.4:1234", "23.34.56.78:443"): {Counts: counts(22, 220, 20, 200), connType: exitTraffic}, + conn(0x6, "100.1.2.5:1234", "45.34.56.78:80"): {Counts: counts(33, 330, 30, 300), connType: exitTraffic}, + conn(0x6, "100.1.2.6:1234", "67.34.56.78:80"): {Counts: counts(44, 440, 40, 400), connType: exitTraffic}, + conn(0x6, "42.54.72.42:555", "18.42.7.1:777"): {Counts: counts(44, 440, 40, 400)}, + }, + + physConns: map[netlogtype.Connection]netlogtype.Counts{ + conn(0, "100.1.2.4:0", "4.3.2.1:1234"): counts(12, 34, 56, 78), + conn(0, "100.1.2.5:0", "4.3.2.10:1234"): counts(78, 56, 34, 12), + }, + } + rec.seenNodes[rec.selfNode.toNode().Addresses[0]] = rec.selfNode + + got := rec.toMessage(false, false) + want := netlogtype.Message{ + NodeID: rec.selfNode.StableID(), + Start: rec.start, + End: rec.end, + SrcNode: rec.selfNode.toNode(), + DstNodes: []netlogtype.Node{ + rec.seenNodes[addr("100.1.2.4")].toNode(), + rec.seenNodes[addr("100.1.2.5")].toNode(), + }, + VirtualTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x1, "100.1.2.3:1234", "100.1.2.4:80"), Counts: counts(12, 34, 56, 78)}, + {Connection: conn(0x1, "100.1.2.3:1234", "100.1.2.5:80"), Counts: counts(23, 45, 78, 790)}, + }, + SubnetTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "172.16.1.1:80", "100.1.2.4:1234"), Counts: counts(91, 54, 723, 621)}, + {Connection: conn(0x6, "172.16.1.2:443", "100.1.2.5:1234"), Counts: counts(42, 813, 3, 1823)}, + {Connection: conn(0x6, "172.16.1.3:80", "100.1.2.6:1234"), Counts: counts(34, 52, 78, 790)}, + }, + ExitTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0x6, "42.54.72.42:555", "18.42.7.1:777"), Counts: counts(44, 440, 40, 400)}, + {Connection: conn(0x6, "100.1.2.3:1234", "12.34.56.78:80"), Counts: counts(11, 110, 10, 100)}, + {Connection: conn(0x6, "100.1.2.4:1234", "23.34.56.78:80"), Counts: counts(423, 1, 6, 123)}, + {Connection: conn(0x6, "100.1.2.4:1234", "23.34.56.78:443"), Counts: counts(22, 220, 20, 200)}, + {Connection: conn(0x6, "100.1.2.5:1234", "45.34.56.78:80"), Counts: counts(33, 330, 30, 300)}, + {Connection: conn(0x6, "100.1.2.6:1234", "67.34.56.78:80"), Counts: counts(44, 440, 40, 400)}, + }, + PhysicalTraffic: []netlogtype.ConnectionCounts{ + {Connection: conn(0, "100.1.2.4:0", "4.3.2.1:1234"), Counts: counts(12, 34, 56, 78)}, + {Connection: conn(0, "100.1.2.5:0", "4.3.2.10:1234"), Counts: counts(78, 56, 34, 12)}, + }, + } + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("toMessage(false, false) mismatch (-got +want):\n%s", d) + } + + got = rec.toMessage(true, false) + want.SrcNode = netlogtype.Node{} + want.DstNodes = nil + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("toMessage(true, false) mismatch (-got +want):\n%s", d) + } + + got = rec.toMessage(true, true) + want.ExitTraffic = []netlogtype.ConnectionCounts{ + {Connection: conn(0, "", ""), Counts: counts(44+44, 440+440, 40+40, 400+400)}, + {Connection: conn(0, "100.1.2.3:0", ""), Counts: counts(11, 110, 10, 100)}, + {Connection: conn(0, "100.1.2.4:0", ""), Counts: counts(423+22, 1+220, 6+20, 123+200)}, + {Connection: conn(0, "100.1.2.5:0", ""), Counts: counts(33, 330, 30, 300)}, + } + if d := cmp.Diff(got, want, cmpopts.EquateComparable(netip.Addr{}, netip.AddrPort{})); d != "" { + t.Errorf("toMessage(true, true) mismatch (-got +want):\n%s", d) + } +} + +func TestToNode(t *testing.T) { + tests := []struct { + node *tailcfg.Node + user *tailcfg.UserProfile + want netlogtype.Node + }{ + {}, + { + node: &tailcfg.Node{ + StableID: "n123456CNTL", + Name: "test.tail123456.ts.net.", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Tags: []string{"tag:dupe", "tag:test", "tag:dupe"}, + User: 12345, // should be ignored + }, + want: netlogtype.Node{ + NodeID: "n123456CNTL", + Name: "test.tail123456.ts.net", + Addresses: []netip.Addr{addr("100.1.2.3")}, + Tags: []string{"tag:dupe", "tag:test"}, + }, + }, + { + node: &tailcfg.Node{ + StableID: "n123456CNTL", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + User: 12345, + }, + want: netlogtype.Node{ + NodeID: "n123456CNTL", + Addresses: []netip.Addr{addr("100.1.2.3")}, + }, + }, + { + node: &tailcfg.Node{ + StableID: "n123456CNTL", + Addresses: []netip.Prefix{prefix("100.1.2.3")}, + Hostinfo: (&tailcfg.Hostinfo{OS: "linux"}).View(), + User: 12345, + }, + user: &tailcfg.UserProfile{ + ID: 12345, + LoginName: "user@domain", + }, + want: netlogtype.Node{ + NodeID: "n123456CNTL", + Addresses: []netip.Addr{addr("100.1.2.3")}, + OS: "linux", + User: "user@domain", + }, + }, + } + for _, tt := range tests { + nu := nodeUser{tt.node.View(), tt.user.View()} + got := nu.toNode() + b := must.Get(jsonv2.Marshal(got)) + if len(b) > nu.jsonLen() { + t.Errorf("jsonLen = %v, want >= %d", nu.jsonLen(), len(b)) + } + if d := cmp.Diff(got, tt.want, cmpopts.EquateComparable(netip.Addr{})); d != "" { + t.Errorf("toNode mismatch (-got +want):\n%s", d) + } + } +} + +func FuzzQuotedLen(f *testing.F) { + for _, s := range quotedLenTestdata { + f.Add(s) + } + f.Fuzz(func(t *testing.T, s string) { + testQuotedLen(t, s) + }) +} + +func TestQuotedLen(t *testing.T) { + for _, s := range quotedLenTestdata { + testQuotedLen(t, s) + } +} + +var quotedLenTestdata = []string{ + "", // empty string + func() string { + b := make([]byte, 128) + for i := range b { + b[i] = byte(i) + } + return string(b) + }(), // all ASCII + "īŋŊ", // replacement rune + "\xff", // invalid UTF-8 + "Ę•â—”Ī–â—”Ę”", // Unicode gopher +} + +func testQuotedLen(t *testing.T, in string) { + got := jsonQuotedLen(in) + b, _ := jsontext.AppendQuote(nil, in) + want := len(b) + if got != want { + t.Errorf("jsonQuotedLen(%q) = %v, want %v", in, got, want) + } +} diff --git a/wgengine/netstack/gro/gro.go b/wgengine/netstack/gro/gro.go index b268534eb..c8e5e56e1 100644 --- a/wgengine/netstack/gro/gro.go +++ b/wgengine/netstack/gro/gro.go @@ -1,11 +1,14 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_netstack + // Package gro implements GRO for the receive (write) path into gVisor. package gro import ( "bytes" + "github.com/tailscale/wireguard-go/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/wgengine/netstack/gro/gro_default.go b/wgengine/netstack/gro/gro_default.go index f92ee15ec..c70e19f7c 100644 --- a/wgengine/netstack/gro/gro_default.go +++ b/wgengine/netstack/gro/gro_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ios +//go:build !ios && !ts_omit_gro package gro diff --git a/wgengine/netstack/gro/gro_ios.go b/wgengine/netstack/gro/gro_disabled.go similarity index 59% rename from wgengine/netstack/gro/gro_ios.go rename to wgengine/netstack/gro/gro_disabled.go index 627b42d7e..d7ffbd913 100644 --- a/wgengine/netstack/gro/gro_ios.go +++ b/wgengine/netstack/gro/gro_disabled.go @@ -1,22 +1,27 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build ios +//go:build ios || ts_omit_gro package gro import ( - "gvisor.dev/gvisor/pkg/tcpip/stack" + "runtime" + "tailscale.com/net/packet" ) type GRO struct{} func NewGRO() *GRO { - panic("unsupported on iOS") + if runtime.GOOS == "ios" { + panic("unsupported on iOS") + } + panic("GRO disabled in build") + } -func (g *GRO) SetDispatcher(_ stack.NetworkDispatcher) {} +func (g *GRO) SetDispatcher(any) {} func (g *GRO) Enqueue(_ *packet.Parsed) {} diff --git a/wgengine/netstack/gro/netstack_disabled.go b/wgengine/netstack/gro/netstack_disabled.go new file mode 100644 index 000000000..a0f56fa44 --- /dev/null +++ b/wgengine/netstack/gro/netstack_disabled.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_netstack + +package gro + +func RXChecksumOffload(any) any { + panic("unreachable") +} diff --git a/wgengine/netstack/link_endpoint.go b/wgengine/netstack/link_endpoint.go index 485d829a3..c5a9dbcbc 100644 --- a/wgengine/netstack/link_endpoint.go +++ b/wgengine/netstack/link_endpoint.go @@ -10,25 +10,34 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "tailscale.com/feature/buildfeatures" "tailscale.com/net/packet" "tailscale.com/types/ipproto" "tailscale.com/wgengine/netstack/gro" ) type queue struct { - // TODO(jwhited): evaluate performance with mu as Mutex and/or alternative - // non-channel buffer. - c chan *stack.PacketBuffer - mu sync.RWMutex // mu guards closed + // TODO(jwhited): evaluate performance with a non-channel buffer. + c chan *stack.PacketBuffer + + closeOnce sync.Once + closedCh chan struct{} + + mu sync.RWMutex closed bool } func (q *queue) Close() { + q.closeOnce.Do(func() { + close(q.closedCh) + }) + q.mu.Lock() defer q.mu.Unlock() - if !q.closed { - close(q.c) + if q.closed { + return } + close(q.c) q.closed = true } @@ -51,26 +60,27 @@ func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer { } func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error { - // q holds the PacketBuffer. q.mu.RLock() defer q.mu.RUnlock() if q.closed { return &tcpip.ErrClosedForSend{} } - - wrote := false select { case q.c <- pkt.IncRef(): - wrote = true - default: - // TODO(jwhited): reconsider/count + return nil + case <-q.closedCh: pkt.DecRef() + return &tcpip.ErrClosedForSend{} } +} - if wrote { - return nil +func (q *queue) Drain() int { + c := 0 + for pkt := range q.c { + pkt.DecRef() + c++ } - return &tcpip.ErrNoBufferSpace{} + return c } func (q *queue) Num() int { @@ -107,7 +117,8 @@ func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, supported le := &linkEndpoint{ supportedGRO: supportedGRO, q: &queue{ - c: make(chan *stack.PacketBuffer, size), + c: make(chan *stack.PacketBuffer, size), + closedCh: make(chan struct{}), }, mtu: mtu, linkAddr: linkAddr, @@ -115,24 +126,24 @@ func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, supported return le } -// gro attempts to enqueue p on g if l supports a GRO kind matching the +// gro attempts to enqueue p on g if ep supports a GRO kind matching the // transport protocol carried in p. gro may allocate g if it is nil. gro can // either return the existing g, a newly allocated one, or nil. Callers are // responsible for calling Flush() on the returned value if it is non-nil once // they have finished iterating through all GRO candidates for a given vector. -// If gro allocates a *gro.GRO it will have l's stack.NetworkDispatcher set via +// If gro allocates a *gro.GRO it will have ep's stack.NetworkDispatcher set via // SetDispatcher(). -func (l *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO { - if l.supportedGRO == groNotSupported || p.IPProto != ipproto.TCP { +func (ep *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO { + if !buildfeatures.HasGRO || ep.supportedGRO == groNotSupported || p.IPProto != ipproto.TCP { // IPv6 may have extension headers preceding a TCP header, but we trade // for a fast path and assume p cannot be coalesced in such a case. - l.injectInbound(p) + ep.injectInbound(p) return g } if g == nil { - l.mu.RLock() - d := l.dispatcher - l.mu.RUnlock() + ep.mu.RLock() + d := ep.dispatcher + ep.mu.RUnlock() g = gro.NewGRO() g.SetDispatcher(d) } @@ -143,45 +154,40 @@ func (l *linkEndpoint) gro(p *packet.Parsed, g *gro.GRO) *gro.GRO { // Close closes l. Further packet injections will return an error, and all // pending packets are discarded. Close may be called concurrently with // WritePackets. -func (l *linkEndpoint) Close() { - l.mu.Lock() - l.dispatcher = nil - l.mu.Unlock() - l.q.Close() - l.Drain() +func (ep *linkEndpoint) Close() { + ep.mu.Lock() + ep.dispatcher = nil + ep.mu.Unlock() + ep.q.Close() + ep.Drain() } // Read does non-blocking read one packet from the outbound packet queue. -func (l *linkEndpoint) Read() *stack.PacketBuffer { - return l.q.Read() +func (ep *linkEndpoint) Read() *stack.PacketBuffer { + return ep.q.Read() } // ReadContext does blocking read for one packet from the outbound packet queue. // It can be cancelled by ctx, and in this case, it returns nil. -func (l *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { - return l.q.ReadContext(ctx) +func (ep *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { + return ep.q.ReadContext(ctx) } // Drain removes all outbound packets from the channel and counts them. -func (l *linkEndpoint) Drain() int { - c := 0 - for pkt := l.Read(); pkt != nil; pkt = l.Read() { - pkt.DecRef() - c++ - } - return c +func (ep *linkEndpoint) Drain() int { + return ep.q.Drain() } // NumQueued returns the number of packets queued for outbound. -func (l *linkEndpoint) NumQueued() int { - return l.q.Num() +func (ep *linkEndpoint) NumQueued() int { + return ep.q.Num() } -func (l *linkEndpoint) injectInbound(p *packet.Parsed) { - l.mu.RLock() - d := l.dispatcher - l.mu.RUnlock() - if d == nil { +func (ep *linkEndpoint) injectInbound(p *packet.Parsed) { + ep.mu.RLock() + d := ep.dispatcher + ep.mu.RUnlock() + if d == nil || !buildfeatures.HasNetstack { return } pkt := gro.RXChecksumOffload(p) @@ -194,35 +200,35 @@ func (l *linkEndpoint) injectInbound(p *packet.Parsed) { // Attach saves the stack network-layer dispatcher for use later when packets // are injected. -func (l *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) { - l.mu.Lock() - defer l.mu.Unlock() - l.dispatcher = dispatcher +func (ep *linkEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + ep.mu.Lock() + defer ep.mu.Unlock() + ep.dispatcher = dispatcher } // IsAttached implements stack.LinkEndpoint.IsAttached. -func (l *linkEndpoint) IsAttached() bool { - l.mu.RLock() - defer l.mu.RUnlock() - return l.dispatcher != nil +func (ep *linkEndpoint) IsAttached() bool { + ep.mu.RLock() + defer ep.mu.RUnlock() + return ep.dispatcher != nil } // MTU implements stack.LinkEndpoint.MTU. -func (l *linkEndpoint) MTU() uint32 { - l.mu.RLock() - defer l.mu.RUnlock() - return l.mtu +func (ep *linkEndpoint) MTU() uint32 { + ep.mu.RLock() + defer ep.mu.RUnlock() + return ep.mtu } // SetMTU implements stack.LinkEndpoint.SetMTU. -func (l *linkEndpoint) SetMTU(mtu uint32) { - l.mu.Lock() - defer l.mu.Unlock() - l.mtu = mtu +func (ep *linkEndpoint) SetMTU(mtu uint32) { + ep.mu.Lock() + defer ep.mu.Unlock() + ep.mtu = mtu } // Capabilities implements stack.LinkEndpoint.Capabilities. -func (l *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities { +func (ep *linkEndpoint) Capabilities() stack.LinkEndpointCapabilities { // We are required to offload RX checksum validation for the purposes of // GRO. return stack.CapabilityRXChecksumOffload @@ -236,8 +242,8 @@ func (*linkEndpoint) GSOMaxSize() uint32 { } // SupportedGSO implements stack.GSOEndpoint. -func (l *linkEndpoint) SupportedGSO() stack.SupportedGSO { - return l.SupportedGSOKind +func (ep *linkEndpoint) SupportedGSO() stack.SupportedGSO { + return ep.SupportedGSOKind } // MaxHeaderLength returns the maximum size of the link layer header. Given it @@ -247,22 +253,22 @@ func (*linkEndpoint) MaxHeaderLength() uint16 { } // LinkAddress returns the link address of this endpoint. -func (l *linkEndpoint) LinkAddress() tcpip.LinkAddress { - l.mu.RLock() - defer l.mu.RUnlock() - return l.linkAddr +func (ep *linkEndpoint) LinkAddress() tcpip.LinkAddress { + ep.mu.RLock() + defer ep.mu.RUnlock() + return ep.linkAddr } // SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress. -func (l *linkEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { - l.mu.Lock() - defer l.mu.Unlock() - l.linkAddr = addr +func (ep *linkEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { + ep.mu.Lock() + defer ep.mu.Unlock() + ep.linkAddr = addr } // WritePackets stores outbound packets into the channel. // Multiple concurrent calls are permitted. -func (l *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { +func (ep *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { n := 0 // TODO(jwhited): evaluate writing a stack.PacketBufferList instead of a // single packet. We can split 2 x 64K GSO across @@ -272,7 +278,7 @@ func (l *linkEndpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Err // control MTU (and by effect TCP MSS in gVisor) we *shouldn't* expect to // ever overflow 128 slots (see wireguard-go/tun.ErrTooManySegments usage). for _, pkt := range pkts.AsSlice() { - if err := l.q.Write(pkt); err != nil { + if err := ep.q.Write(pkt); err != nil { if _, ok := err.(*tcpip.ErrNoBufferSpace); !ok && n == 0 { return 0, err } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index efb328102..c2b5d8a32 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -32,13 +32,14 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" - "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn/ipnlocal" "tailscale.com/metrics" "tailscale.com/net/dns" "tailscale.com/net/ipset" "tailscale.com/net/netaddr" + "tailscale.com/net/netx" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" @@ -51,6 +52,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" + "tailscale.com/util/set" "tailscale.com/version" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" @@ -174,19 +176,18 @@ type Impl struct { // It can only be set before calling Start. ProcessSubnets bool - ipstack *stack.Stack - linkEP *linkEndpoint - tundev *tstun.Wrapper - e wgengine.Engine - pm *proxymap.Mapper - mc *magicsock.Conn - logf logger.Logf - dialer *tsdial.Dialer - ctx context.Context // alive until Close - ctxCancel context.CancelFunc // called on Close - lb *ipnlocal.LocalBackend // or nil - dns *dns.Manager - driveForLocal drive.FileSystemForLocal // or nil + ipstack *stack.Stack + linkEP *linkEndpoint + tundev *tstun.Wrapper + e wgengine.Engine + pm *proxymap.Mapper + mc *magicsock.Conn + logf logger.Logf + dialer *tsdial.Dialer + ctx context.Context // alive until Close + ctxCancel context.CancelFunc // called on Close + lb *ipnlocal.LocalBackend // or nil + dns *dns.Manager // loopbackPort, if non-nil, will enable Impl to loop back (dnat to // :loopbackPort) TCP & UDP flows originally @@ -202,12 +203,14 @@ type Impl struct { // updates. atomicIsLocalIPFunc syncs.AtomicValue[func(netip.Addr) bool] + atomicIsVIPServiceIPFunc syncs.AtomicValue[func(netip.Addr) bool] + // forwardDialFunc, if non-nil, is the net.Dialer.DialContext-style // function that is used to make outgoing connections when forwarding a // TCP connection to another host (e.g. in subnet router mode). // // This is currently only used in tests. - forwardDialFunc func(context.Context, string, string) (net.Conn, error) + forwardDialFunc netx.DialFunc // forwardInFlightPerClientDropped is a metric that tracks how many // in-flight TCP forward requests were dropped due to the per-client @@ -288,7 +291,7 @@ func setTCPBufSizes(ipstack *stack.Stack) error { } // Create creates and populates a new Impl. -func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper, driveForLocal drive.FileSystemForLocal) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -316,16 +319,24 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi if tcpipErr != nil { return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } - if runtime.GOOS == "windows" { - // See https://github.com/tailscale/tailscale/issues/9707 - // Windows w/RACK performs poorly. ACKs do not appear to be handled in a - // timely manner, leading to spurious retransmissions and a reduced - // congestion window. - tcpRecoveryOpt := tcpip.TCPRecovery(0) - tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt) - if tcpipErr != nil { - return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr) - } + // See https://github.com/tailscale/tailscale/issues/9707 + // gVisor's RACK performs poorly. ACKs do not appear to be handled in a + // timely manner, leading to spurious retransmissions and a reduced + // congestion window. + tcpRecoveryOpt := tcpip.TCPRecovery(0) + tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr) + } + // gVisor defaults to reno at the time of writing. We explicitly set reno + // congestion control in order to prevent unexpected changes. Netstack + // has an int overflow in sender congestion window arithmetic that is more + // prone to trigger with cubic congestion control. + // See https://github.com/google/gvisor/issues/11632 + renoOpt := tcpip.CongestionControlOption("reno") + tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &renoOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not set reno congestion control: %v", tcpipErr) } err := setTCPBufSizes(ipstack) if err != nil { @@ -333,7 +344,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi } supportedGSOKind := stack.GSONotSupported supportedGROKind := groNotSupported - if runtime.GOOS == "linux" { + if runtime.GOOS == "linux" && buildfeatures.HasGRO { // TODO(jwhited): add Windows support https://github.com/tailscale/corp/issues/21874 supportedGROKind = tcpGROSupported supportedGSOKind = stack.HostGSOSupported @@ -382,7 +393,6 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi connsInFlightByClient: make(map[netip.Addr]int), packetsInFlight: make(map[stack.TransportEndpointID]struct{}), dns: dns, - driveForLocal: driveForLocal, } loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT") if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 { @@ -390,6 +400,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi } ns.ctx, ns.ctxCancel = context.WithCancel(context.Background()) ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) + ns.atomicIsVIPServiceIPFunc.Store(ipset.FalseContainsIPFunc()) ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets stacksForMetrics.Store(ns, struct{}{}) @@ -404,6 +415,14 @@ func (ns *Impl) Close() error { return nil } +// SetTransportProtocolOption forwards to the underlying +// [stack.Stack.SetTransportProtocolOption]. Callers are responsible for +// ensuring that the options are valid, compatible and appropriate for their use +// case. Compatibility may change at any version. +func (ns *Impl) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error { + return ns.ipstack.SetTransportProtocolOption(transport, option) +} + // A single process might have several netstacks running at the same time. // Exported clientmetric counters will have a sum of counters of all of them. var stacksForMetrics syncs.Map[*Impl, struct{}] @@ -414,15 +433,14 @@ func init() { // endpoint, and name collisions will result in Prometheus scraping errors. clientmetric.NewCounterFunc("netstack_tcp_forward_dropped_attempts", func() int64 { var total uint64 - stacksForMetrics.Range(func(ns *Impl, _ struct{}) bool { + for ns := range stacksForMetrics.Keys() { delta := ns.ipstack.Stats().TCP.ForwardMaxInFlightDrop.Value() if total+delta > math.MaxInt64 { total = math.MaxInt64 - return false + break } total += delta - return true - }) + } return int64(total) }) } @@ -536,7 +554,7 @@ func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFun // Dynamically reconfigure ns's subnet addresses as needed for // outbound traffic. - if !ns.isLocalIP(localIP) { + if !ns.isLocalIP(localIP) && !ns.isVIPServiceIP(localIP) { ns.addSubnetAddress(localIP) } @@ -560,9 +578,16 @@ func (ns *Impl) decrementInFlightTCPForward(tei stack.TransportEndpointID, remot } } +// LocalBackend is a fake name for *ipnlocal.LocalBackend to avoid an import cycle. +type LocalBackend = any + // Start sets up all the handlers so netstack can start working. Implements // wgengine.FakeImpl. -func (ns *Impl) Start(lb *ipnlocal.LocalBackend) error { +func (ns *Impl) Start(b LocalBackend) error { + if b == nil { + panic("nil LocalBackend interface") + } + lb := b.(*ipnlocal.LocalBackend) if lb == nil { panic("nil LocalBackend") } @@ -624,11 +649,21 @@ var v4broadcast = netaddr.IPv4(255, 255, 255, 255) // address slice views. func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { var selfNode tailcfg.NodeView + var serviceAddrSet set.Set[netip.Addr] if nm != nil { ns.atomicIsLocalIPFunc.Store(ipset.NewContainsIPFunc(nm.GetAddresses())) + if buildfeatures.HasServe { + vipServiceIPMap := nm.GetVIPServiceIPMap() + serviceAddrSet = make(set.Set[netip.Addr], len(vipServiceIPMap)*2) + for _, addrs := range vipServiceIPMap { + serviceAddrSet.AddSlice(addrs) + } + ns.atomicIsVIPServiceIPFunc.Store(serviceAddrSet.Contains) + } selfNode = nm.SelfNode } else { ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) + ns.atomicIsVIPServiceIPFunc.Store(ipset.FalseContainsIPFunc()) } oldPfx := make(map[netip.Prefix]bool) @@ -647,18 +682,21 @@ func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { newPfx := make(map[netip.Prefix]bool) if selfNode.Valid() { - for i := range selfNode.Addresses().Len() { - p := selfNode.Addresses().At(i) + for _, p := range selfNode.Addresses().All() { newPfx[p] = true } if ns.ProcessSubnets { - for i := range selfNode.AllowedIPs().Len() { - p := selfNode.AllowedIPs().At(i) + for _, p := range selfNode.AllowedIPs().All() { newPfx[p] = true } } } + for addr := range serviceAddrSet { + p := netip.PrefixFrom(addr, addr.BitLen()) + newPfx[p] = true + } + pfxToAdd := make(map[netip.Prefix]bool) for p := range newPfx { if !oldPfx[p] { @@ -821,6 +859,27 @@ func (ns *Impl) DialContextTCP(ctx context.Context, ipp netip.AddrPort) (*gonet. return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType) } +// DialContextTCPWithBind creates a new gonet.TCPConn connected to the specified +// remoteAddress with its local address bound to localAddr on an available port. +func (ns *Impl) DialContextTCPWithBind(ctx context.Context, localAddr netip.Addr, remoteAddr netip.AddrPort) (*gonet.TCPConn, error) { + remoteAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(remoteAddr.Addr().AsSlice()), + Port: remoteAddr.Port(), + } + localAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(localAddr.AsSlice()), + } + var ipType tcpip.NetworkProtocolNumber + if remoteAddr.Addr().Is4() { + ipType = ipv4.ProtocolNumber + } else { + ipType = ipv6.ProtocolNumber + } + return gonet.DialTCPWithBind(ctx, ns.ipstack, localAddress, remoteAddress, ipType) +} + func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.UDPConn, error) { remoteAddress := &tcpip.FullAddress{ NIC: nicID, @@ -837,6 +896,28 @@ func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet. return gonet.DialUDP(ns.ipstack, nil, remoteAddress, ipType) } +// DialContextUDPWithBind creates a new gonet.UDPConn. Connected to remoteAddr. +// With its local address bound to localAddr on an available port. +func (ns *Impl) DialContextUDPWithBind(ctx context.Context, localAddr netip.Addr, remoteAddr netip.AddrPort) (*gonet.UDPConn, error) { + remoteAddress := &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(remoteAddr.Addr().AsSlice()), + Port: remoteAddr.Port(), + } + localAddress := &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(localAddr.AsSlice()), + } + var ipType tcpip.NetworkProtocolNumber + if remoteAddr.Addr().Is4() { + ipType = ipv4.ProtocolNumber + } else { + ipType = ipv6.ProtocolNumber + } + + return gonet.DialUDP(ns.ipstack, localAddress, remoteAddress, ipType) +} + // getInjectInboundBuffsSizes returns packet memory and a sizes slice for usage // when calling tstun.Wrapper.InjectInboundPacketBuffer(). These are sized with // consideration for MTU and GSO support on ns.linkEP. They should be recycled @@ -958,6 +1039,15 @@ func (ns *Impl) isLocalIP(ip netip.Addr) bool { return ns.atomicIsLocalIPFunc.Load()(ip) } +// isVIPServiceIP reports whether ip is an IP address that's +// assigned to a VIP service. +func (ns *Impl) isVIPServiceIP(ip netip.Addr) bool { + if !buildfeatures.HasServe { + return false + } + return ns.atomicIsVIPServiceIPFunc.Load()(ip) +} + func (ns *Impl) peerAPIPortAtomic(ip netip.Addr) *atomic.Uint32 { if ip.Is4() { return &ns.peerapiPort4Atomic @@ -974,6 +1064,7 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { // Handle incoming peerapi connections in netstack. dstIP := p.Dst.Addr() isLocal := ns.isLocalIP(dstIP) + isService := ns.isVIPServiceIP(dstIP) // Handle TCP connection to the Tailscale IP(s) in some cases: if ns.lb != nil && p.IPProto == ipproto.TCP && isLocal { @@ -996,6 +1087,19 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { return true } } + if buildfeatures.HasServe && isService { + if p.IsEchoRequest() { + return true + } + if ns.lb != nil && p.IPProto == ipproto.TCP { + // An assumption holds for this to work: when tun mode is on for a service, + // its tcp and web are not set. This is enforced in b.setServeConfigLocked. + if ns.lb.ShouldInterceptVIPServiceTCPPort(p.Dst) { + return true + } + } + return false + } if p.IPVersion == 6 && !isLocal && viaRange.Contains(dstIP) { return ns.lb != nil && ns.lb.ShouldHandleViaIP(dstIP) } @@ -1344,6 +1448,13 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } } +// tcpCloser is an interface to abstract around various TCPConn types that +// allow closing of the read and write streams independently of each other. +type tcpCloser interface { + CloseRead() error + CloseWrite() error +} + func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.TCPConn, clientRemoteIP netip.Addr, wq *waiter.Queue, dialAddr netip.AddrPort) (handled bool) { dialAddrStr := dialAddr.String() if debugNetstack() { @@ -1372,7 +1483,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. }() // Attempt to dial the outbound connection before we accept the inbound one. - var dialFunc func(context.Context, string, string) (net.Conn, error) + var dialFunc netx.DialFunc if ns.forwardDialFunc != nil { dialFunc = ns.forwardDialFunc } else { @@ -1410,18 +1521,48 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. } defer client.Close() + // As of 2025-07-03, backend is always either a net.TCPConn + // from stdDialer.DialContext (which has the requisite functions), + // or nil from hangDialer in tests (in which case we would have + // errored out by now), so this conversion should always succeed. + backendTCPCloser, backendIsTCPCloser := backend.(tcpCloser) connClosed := make(chan error, 2) go func() { _, err := io.Copy(backend, client) + if err != nil { + err = fmt.Errorf("client -> backend: %w", err) + } connClosed <- err + err = nil + if backendIsTCPCloser { + err = backendTCPCloser.CloseWrite() + } + err = errors.Join(err, client.CloseRead()) + if err != nil { + ns.logf("client -> backend close connection: %v", err) + } }() go func() { _, err := io.Copy(client, backend) + if err != nil { + err = fmt.Errorf("backend -> client: %w", err) + } connClosed <- err + err = nil + if backendIsTCPCloser { + err = backendTCPCloser.CloseRead() + } + err = errors.Join(err, client.CloseWrite()) + if err != nil { + ns.logf("backend -> client close connection: %v", err) + } }() - err = <-connClosed - if err != nil { - ns.logf("proxy connection closed with error: %v", err) + // Wait for both ends of the connection to close. + for range 2 { + err = <-connClosed + if err != nil { + ns.logf("proxy connection closed with error: %v", err) + } } ns.logf("[v2] netstack: forwarder connection to %s closed", dialAddrStr) return @@ -1764,7 +1905,6 @@ func (ns *Impl) ExpVar() expvar.Var { {"option_unknown_received", ipStats.OptionUnknownReceived}, } for _, metric := range ipMetrics { - metric := metric m.Set("counter_ip_"+metric.name, expvar.Func(func() any { return readStatCounter(metric.field) })) @@ -1791,7 +1931,6 @@ func (ns *Impl) ExpVar() expvar.Var { {"errors", fwdStats.Errors}, } for _, metric := range fwdMetrics { - metric := metric m.Set("counter_ip_forward_"+metric.name, expvar.Func(func() any { return readStatCounter(metric.field) })) @@ -1835,7 +1974,6 @@ func (ns *Impl) ExpVar() expvar.Var { {"forward_max_in_flight_drop", tcpStats.ForwardMaxInFlightDrop}, } for _, metric := range tcpMetrics { - metric := metric m.Set("counter_tcp_"+metric.name, expvar.Func(func() any { return readStatCounter(metric.field) })) @@ -1862,7 +2000,6 @@ func (ns *Impl) ExpVar() expvar.Var { {"checksum_errors", udpStats.ChecksumErrors}, } for _, metric := range udpMetrics { - metric := metric m.Set("counter_udp_"+metric.name, expvar.Func(func() any { return readStatCounter(metric.field) })) diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index 1bfc76fef..93022811c 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -22,6 +22,7 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/metrics" + "tailscale.com/net/netx" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" @@ -44,13 +45,14 @@ func TestInjectInboundLeak(t *testing.T) { t.Logf(format, args...) } } - sys := new(tsd.System) + sys := tsd.NewSystem() eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ Tun: tunDev, Dialer: dialer, SetSubsystem: sys.Set, - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { t.Fatal(err) @@ -64,8 +66,9 @@ func TestInjectInboundLeak(t *testing.T) { if err != nil { t.Fatal(err) } + t.Cleanup(lb.Shutdown) - ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { t.Fatal(err) } @@ -99,7 +102,7 @@ func getMemStats() (ms runtime.MemStats) { func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { tunDev := tstun.NewFake() - sys := &tsd.System{} + sys := tsd.NewSystem() sys.Set(new(mem.Store)) dialer := new(tsdial.Dialer) logf := tstest.WhileTestRunningLogger(tb) @@ -107,8 +110,9 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { Tun: tunDev, Dialer: dialer, SetSubsystem: sys.Set, - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { tb.Fatal(err) @@ -116,7 +120,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { tb.Cleanup(func() { eng.Close() }) sys.Set(eng) - ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { tb.Fatal(err) } @@ -126,6 +130,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { if err != nil { tb.Fatalf("NewLocalBackend: %v", err) } + tb.Cleanup(lb.Shutdown) ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) if config != nil { @@ -510,7 +515,7 @@ func tcp4syn(tb testing.TB, src, dst netip.Addr, sport, dport uint16) []byte { // makeHangDialer returns a dialer that notifies the returned channel when a // connection is dialed and then hangs until the test finishes. -func makeHangDialer(tb testing.TB) (func(context.Context, string, string) (net.Conn, error), chan struct{}) { +func makeHangDialer(tb testing.TB) (netx.DialFunc, chan struct{}) { done := make(chan struct{}) tb.Cleanup(func() { close(done) diff --git a/wgengine/netstack/netstack_userping.go b/wgengine/netstack/netstack_userping.go index ee635bd87..b35a6eca9 100644 --- a/wgengine/netstack/netstack_userping.go +++ b/wgengine/netstack/netstack_userping.go @@ -13,6 +13,7 @@ import ( "runtime" "time" + "tailscale.com/feature/buildfeatures" "tailscale.com/version/distro" ) @@ -20,7 +21,7 @@ import ( // CAP_NET_RAW from tailscaled's binary. var setAmbientCapsRaw func(*exec.Cmd) -var isSynology = runtime.GOOS == "linux" && distro.Get() == distro.Synology +var isSynology = runtime.GOOS == "linux" && buildfeatures.HasSynology && distro.Get() == distro.Synology // sendOutboundUserPing sends a non-privileged ICMP (or ICMPv6) ping to dstIP with the given timeout. func (ns *Impl) sendOutboundUserPing(dstIP netip.Addr, timeout time.Duration) error { @@ -61,7 +62,7 @@ func (ns *Impl) sendOutboundUserPing(dstIP netip.Addr, timeout time.Duration) er ping = "/bin/ping" } cmd := exec.Command(ping, "-c", "1", "-W", "3", dstIP.String()) - if isSynology && os.Getuid() != 0 { + if buildfeatures.HasSynology && isSynology && os.Getuid() != 0 { // On DSM7 we run as non-root and need to pass // CAP_NET_RAW if our binary has it. setAmbientCapsRaw(cmd) diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index 59b1fccda..7eaf43e52 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -1,22 +1,29 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_debug + package wgengine import ( "fmt" + "net/netip" "runtime" + "strings" + "sync" "time" + "github.com/gaissmai/bart" "tailscale.com/net/flowtrack" "tailscale.com/net/packet" - "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" "tailscale.com/types/ipproto" "tailscale.com/util/mak" "tailscale.com/wgengine/filter" ) +type flowtrackTuple = flowtrack.Tuple + const tcpTimeoutBeforeDebug = 5 * time.Second type pendingOpenFlow struct { @@ -53,6 +60,10 @@ func (e *userspaceEngine) noteFlowProblemFromPeer(f flowtrack.Tuple, problem pac of.problem = problem } +func tsRejectFlow(rh packet.TailscaleRejectedHeader) flowtrack.Tuple { + return flowtrack.MakeTuple(rh.Proto, rh.Src, rh.Dst) +} + func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.Wrapper) (res filter.Response) { res = filter.Accept // always @@ -63,8 +74,8 @@ func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.Wrapp return } if rh.MaybeBroken { - e.noteFlowProblemFromPeer(rh.Flow(), rh.Reason) - } else if f := rh.Flow(); e.removeFlow(f) { + e.noteFlowProblemFromPeer(tsRejectFlow(rh), rh.Reason) + } else if f := tsRejectFlow(rh); e.removeFlow(f) { e.logf("open-conn-track: flow %v %v > %v rejected due to %v", rh.Proto, rh.Src, rh.Dst, rh.Reason) } return @@ -86,6 +97,57 @@ func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.Wrapp return } +var ( + appleIPRange = netip.MustParsePrefix("17.0.0.0/8") + canonicalIPs = sync.OnceValue(func() (checkIPFunc func(netip.Addr) bool) { + // https://bgp.he.net/AS41231#_prefixes + t := &bart.Table[bool]{} + for _, s := range strings.Fields(` + 91.189.89.0/24 + 91.189.91.0/24 + 91.189.92.0/24 + 91.189.93.0/24 + 91.189.94.0/24 + 91.189.95.0/24 + 162.213.32.0/24 + 162.213.34.0/24 + 162.213.35.0/24 + 185.125.188.0/23 + 185.125.190.0/24 + 194.169.254.0/24`) { + t.Insert(netip.MustParsePrefix(s), true) + } + return func(ip netip.Addr) bool { + v, _ := t.Lookup(ip) + return v + } + }) +) + +// isOSNetworkProbe reports whether the target is likely a network +// connectivity probe target from e.g. iOS or Ubuntu network-manager. +// +// iOS likes to probe Apple IPs on all interfaces to check for connectivity. +// Don't start timers tracking those. They won't succeed anyway. Avoids log +// spam like: +func (e *userspaceEngine) isOSNetworkProbe(dst netip.AddrPort) bool { + // iOS had log spam like: + // open-conn-track: timeout opening (100.115.73.60:52501 => 17.125.252.5:443); no associated peer node + if runtime.GOOS == "ios" && dst.Port() == 443 && appleIPRange.Contains(dst.Addr()) { + if _, ok := e.PeerForIP(dst.Addr()); !ok { + return true + } + } + // NetworkManager; https://github.com/tailscale/tailscale/issues/13687 + // open-conn-track: timeout opening (TCP 100.96.229.119:42798 => 185.125.190.49:80); no associated peer node + if runtime.GOOS == "linux" && dst.Port() == 80 && canonicalIPs()(dst.Addr()) { + if _, ok := e.PeerForIP(dst.Addr()); !ok { + return true + } + } + return false +} + func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wrapper) (res filter.Response) { res = filter.Accept // always @@ -95,19 +157,12 @@ func (e *userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wra pp.TCPFlags&packet.TCPSyn == 0 { return } + if e.isOSNetworkProbe(pp.Dst) { + return + } flow := flowtrack.MakeTuple(pp.IPProto, pp.Src, pp.Dst) - // iOS likes to probe Apple IPs on all interfaces to check for connectivity. - // Don't start timers tracking those. They won't succeed anyway. Avoids log spam - // like: - // open-conn-track: timeout opening (100.115.73.60:52501 => 17.125.252.5:443); no associated peer node - if runtime.GOOS == "ios" && flow.DstPort() == 443 && !tsaddr.IsTailscaleIP(flow.DstAddr()) { - if _, ok := e.PeerForIP(flow.DstAddr()); !ok { - return - } - } - e.mu.Lock() defer e.mu.Unlock() if _, dup := e.pendOpen[flow]; dup { @@ -151,7 +206,7 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { e.logf("open-conn-track: timeout opening %v; peer node %v running pre-0.100", flow, n.Key().ShortString()) return } - if n.DERP() == "" { + if n.HomeDERP() == 0 { e.logf("open-conn-track: timeout opening %v; peer node %v not connected to any DERP relay", flow, n.Key().ShortString()) return } @@ -160,8 +215,7 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { ps, found := e.getPeerStatusLite(n.Key()) if !found { onlyZeroRoute := true // whether peerForIP returned n only because its /0 route matched - for i := range n.AllowedIPs().Len() { - r := n.AllowedIPs().At(i) + for _, r := range n.AllowedIPs().All() { if r.Bits() != 0 && r.Contains(flow.DstAddr()) { onlyZeroRoute = false break @@ -193,15 +247,15 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { if n.IsWireGuardOnly() { online = "wg" } else { - if v := n.Online(); v != nil { - if *v { + if v, ok := n.Online().GetOk(); ok { + if v { online = "yes" } else { online = "no" } } - if n.LastSeen() != nil && online != "yes" { - online += fmt.Sprintf(", lastseen=%v", durFmt(*n.LastSeen())) + if lastSeen, ok := n.LastSeen().GetOk(); ok && online != "yes" { + online += fmt.Sprintf(", lastseen=%v", durFmt(lastSeen)) } } e.logf("open-conn-track: timeout opening %v to node %v; online=%v, lastRecv=%v", diff --git a/wgengine/pendopen_omit.go b/wgengine/pendopen_omit.go new file mode 100644 index 000000000..013425d35 --- /dev/null +++ b/wgengine/pendopen_omit.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_debug + +package wgengine + +import ( + "tailscale.com/net/packet" + "tailscale.com/net/tstun" + "tailscale.com/wgengine/filter" +) + +type flowtrackTuple = struct{} + +type pendingOpenFlow struct{} + +func (*userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.Wrapper) (res filter.Response) { + panic("unreachable") +} + +func (*userspaceEngine) trackOpenPostFilterOut(pp *packet.Parsed, t *tstun.Wrapper) (res filter.Response) { + panic("unreachable") +} diff --git a/wgengine/router/callback.go b/wgengine/router/callback.go index 1d9091277..c1838539b 100644 --- a/wgengine/router/callback.go +++ b/wgengine/router/callback.go @@ -56,13 +56,6 @@ func (r *CallbackRouter) Set(rcfg *Config) error { return r.SetBoth(r.rcfg, r.dcfg) } -// UpdateMagicsockPort implements the Router interface. This implementation -// does nothing and returns nil because this router does not currently need -// to know what the magicsock UDP port is. -func (r *CallbackRouter) UpdateMagicsockPort(_ uint16, _ string) error { - return nil -} - // SetDNS implements dns.OSConfigurator. func (r *CallbackRouter) SetDNS(dcfg dns.OSConfig) error { r.mu.Lock() diff --git a/wgengine/router/consolidating_router_test.go b/wgengine/router/consolidating_router_test.go index 871682d13..ba2e4d07a 100644 --- a/wgengine/router/consolidating_router_test.go +++ b/wgengine/router/consolidating_router_test.go @@ -4,7 +4,6 @@ package router import ( - "log" "net/netip" "testing" @@ -56,7 +55,7 @@ func TestConsolidateRoutes(t *testing.T) { }, } - cr := &consolidatingRouter{logf: log.Printf} + cr := &consolidatingRouter{logf: t.Logf} for _, test := range tests { t.Run(test.name, func(t *testing.T) { got := cr.consolidateRoutes(test.cfg) diff --git a/wgengine/router/ifconfig_windows.go b/wgengine/router/osrouter/ifconfig_windows.go similarity index 99% rename from wgengine/router/ifconfig_windows.go rename to wgengine/router/osrouter/ifconfig_windows.go index 40e9dc6e0..cb87ad5f2 100644 --- a/wgengine/router/ifconfig_windows.go +++ b/wgengine/router/osrouter/ifconfig_windows.go @@ -3,7 +3,7 @@ * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. */ -package router +package osrouter import ( "errors" @@ -18,7 +18,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" - "tailscale.com/util/multierr" + "tailscale.com/wgengine/router" "tailscale.com/wgengine/winnet" ole "github.com/go-ole/go-ole" @@ -246,7 +246,7 @@ var networkCategoryWarnable = health.Register(&health.Warnable{ MapDebugFlag: "warn-network-category-unhealthy", }) -func configureInterface(cfg *Config, tun *tun.NativeTun, ht *health.Tracker) (retErr error) { +func configureInterface(cfg *router.Config, tun *tun.NativeTun, ht *health.Tracker) (retErr error) { var mtu = tstun.DefaultTUNMTU() luid := winipcfg.LUID(tun.LUID()) iface, err := interfaceFromLUID(luid, @@ -830,5 +830,5 @@ func syncRoutes(ifc *winipcfg.IPAdapterAddresses, want []*routeData, dontDelete } } - return multierr.New(errs...) + return errors.Join(errs...) } diff --git a/wgengine/router/ifconfig_windows_test.go b/wgengine/router/osrouter/ifconfig_windows_test.go similarity index 99% rename from wgengine/router/ifconfig_windows_test.go rename to wgengine/router/osrouter/ifconfig_windows_test.go index 11b98d1d7..b858ef4f6 100644 --- a/wgengine/router/ifconfig_windows_test.go +++ b/wgengine/router/osrouter/ifconfig_windows_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +package osrouter import ( "fmt" diff --git a/wgengine/router/osrouter/osrouter.go b/wgengine/router/osrouter/osrouter.go new file mode 100644 index 000000000..281454b06 --- /dev/null +++ b/wgengine/router/osrouter/osrouter.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package osrouter contains OS-specific router implementations. +// This package has no API; it exists purely to import +// for the side effect of it registering itself with the wgengine/router +// package. +package osrouter + +import "tailscale.com/wgengine/router" + +// shutdownConfig is a routing configuration that removes all router +// state from the OS. It's the config used when callers pass in a nil +// Config. +var shutdownConfig router.Config diff --git a/wgengine/router/osrouter/osrouter_test.go b/wgengine/router/osrouter/osrouter_test.go new file mode 100644 index 000000000..d0cb3db69 --- /dev/null +++ b/wgengine/router/osrouter/osrouter_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osrouter + +import "net/netip" + +//lint:ignore U1000 used in Windows/Linux tests only +func mustCIDRs(ss ...string) []netip.Prefix { + var ret []netip.Prefix + for _, s := range ss { + ret = append(ret, netip.MustParsePrefix(s)) + } + return ret +} diff --git a/wgengine/router/router_freebsd.go b/wgengine/router/osrouter/router_freebsd.go similarity index 56% rename from wgengine/router/router_freebsd.go rename to wgengine/router/osrouter/router_freebsd.go index 40523b4fd..a142e7a84 100644 --- a/wgengine/router/router_freebsd.go +++ b/wgengine/router/osrouter/router_freebsd.go @@ -1,22 +1,18 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +package osrouter import ( - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/wgengine/router" ) -// For now this router only supports the userspace WireGuard implementations. -// -// Work is currently underway for an in-kernel FreeBSD implementation of wireguard -// https://svnweb.freebsd.org/base?view=revision&revision=357986 - -func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { - return newUserspaceBSDRouter(logf, tundev, netMon, health) +func init() { + router.HookCleanUp.Set(func(logf logger.Logf, netMon *netmon.Monitor, ifName string) { + cleanUp(logf, ifName) + }) } func cleanUp(logf logger.Logf, interfaceName string) { diff --git a/wgengine/router/router_linux.go b/wgengine/router/osrouter/router_linux.go similarity index 89% rename from wgengine/router/router_linux.go rename to wgengine/router/osrouter/router_linux.go index 2af73e26d..7442c045e 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/osrouter/router_linux.go @@ -1,7 +1,9 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +//go:build !android + +package osrouter import ( "errors" @@ -12,6 +14,7 @@ import ( "os/exec" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -24,14 +27,27 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/netmon" + "tailscale.com/tsconst" "tailscale.com/types/logger" "tailscale.com/types/opt" "tailscale.com/types/preftype" + "tailscale.com/util/eventbus" "tailscale.com/util/linuxfw" - "tailscale.com/util/multierr" "tailscale.com/version/distro" + "tailscale.com/wgengine/router" ) +func init() { + router.HookNewUserspaceRouter.Set(func(opts router.NewOpts) (router.Router, error) { + return newUserspaceRouter(opts.Logf, opts.Tun, opts.NetMon, opts.Health, opts.Bus) + }) + router.HookCleanUp.Set(func(logf logger.Logf, netMon *netmon.Monitor, ifName string) { + cleanUp(logf, ifName) + }) +} + +var getDistroFunc = distro.Get + const ( netfilterOff = preftype.NetfilterOff netfilterNoDivert = preftype.NetfilterNoDivert @@ -39,19 +55,14 @@ const ( ) type linuxRouter struct { - closed atomic.Bool - logf func(fmt string, args ...any) - tunname string - netMon *netmon.Monitor - health *health.Tracker - unregNetMon func() - addrs map[netip.Prefix]bool - routes map[netip.Prefix]bool - localRoutes map[netip.Prefix]bool - snatSubnetRoutes bool - statefulFiltering bool - netfilterMode preftype.NetfilterMode - netfilterKind string + closed atomic.Bool + logf func(fmt string, args ...any) + tunname string + netMon *netmon.Monitor + health *health.Tracker + eventClient *eventbus.Client + rulesAddedPub *eventbus.Publisher[AddIPRules] + unregNetMon func() // ruleRestorePending is whether a timer has been started to // restore deleted ip rules. @@ -69,11 +80,19 @@ type linuxRouter struct { cmd commandRunner nfr linuxfw.NetfilterRunner - magicsockPortV4 uint16 - magicsockPortV6 uint16 + mu sync.Mutex + addrs map[netip.Prefix]bool + routes map[netip.Prefix]bool + localRoutes map[netip.Prefix]bool + snatSubnetRoutes bool + statefulFiltering bool + netfilterMode preftype.NetfilterMode + netfilterKind string + magicsockPortV4 uint16 + magicsockPortV6 uint16 } -func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { +func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) { tunname, err := tunDev.Name() if err != nil { return nil, err @@ -83,10 +102,10 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni ambientCapNetAdmin: useAmbientCaps(), } - return newUserspaceRouterAdvanced(logf, tunname, netMon, cmd, health) + return newUserspaceRouterAdvanced(logf, tunname, netMon, cmd, health, bus) } -func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, cmd commandRunner, health *health.Tracker) (Router, error) { +func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, cmd commandRunner, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) { r := &linuxRouter{ logf: logf, tunname: tunname, @@ -99,6 +118,19 @@ func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon ipRuleFixLimiter: rate.NewLimiter(rate.Every(5*time.Second), 10), ipPolicyPrefBase: 5200, } + ec := bus.Client("router-linux") + r.rulesAddedPub = eventbus.Publish[AddIPRules](ec) + eventbus.SubscribeFunc(ec, func(rs netmon.RuleDeleted) { + r.onIPRuleDeleted(rs.Table, rs.Priority) + }) + eventbus.SubscribeFunc(ec, func(pu router.PortUpdate) { + r.logf("portUpdate(port=%v, network=%s)", pu.UDPPort, pu.EndpointNetwork) + if err := r.updateMagicsockPort(pu.UDPPort, pu.EndpointNetwork); err != nil { + r.logf("updateMagicsockPort(port=%v, network=%s) failed: %v", pu.UDPPort, pu.EndpointNetwork, err) + } + }) + r.eventClient = ec + if r.useIPCommand() { r.ipRuleAvailable = (cmd.run("ip", "rule") == nil) } else { @@ -222,7 +254,7 @@ func busyboxParseVersion(output string) (major, minor, patch int, err error) { } func useAmbientCaps() bool { - if distro.Get() != distro.Synology { + if getDistroFunc() != distro.Synology { return false } return distro.DSMVersion() >= 7 @@ -272,6 +304,10 @@ func (r *linuxRouter) fwmaskWorks() bool { return v } +// AddIPRules is used as an event signal to signify that rules have been added. +// It is added to aid testing, but could be extended if there's a reason for it. +type AddIPRules struct{} + // onIPRuleDeleted is the callback from the network monitor for when an IP // policy rule is deleted. See Issue 1591. // @@ -299,6 +335,9 @@ func (r *linuxRouter) onIPRuleDeleted(table uint8, priority uint32) { r.ruleRestorePending.Swap(false) return } + + r.rulesAddedPub.Publish(AddIPRules{}) + time.AfterFunc(rr.Delay()+250*time.Millisecond, func() { if r.ruleRestorePending.Swap(false) && !r.closed.Load() { r.logf("somebody (likely systemd-networkd) deleted ip rules; restoring Tailscale's") @@ -308,10 +347,9 @@ func (r *linuxRouter) onIPRuleDeleted(table uint8, priority uint32) { } func (r *linuxRouter) Up() error { - if r.unregNetMon == nil && r.netMon != nil { - r.unregNetMon = r.netMon.RegisterRuleDeleteCallback(r.onIPRuleDeleted) - } - if err := r.setNetfilterMode(netfilterOff); err != nil { + r.mu.Lock() + defer r.mu.Unlock() + if err := r.setNetfilterModeLocked(netfilterOff); err != nil { return fmt.Errorf("setting netfilter mode: %w", err) } if err := r.addIPRules(); err != nil { @@ -325,17 +363,20 @@ func (r *linuxRouter) Up() error { } func (r *linuxRouter) Close() error { + r.mu.Lock() + defer r.mu.Unlock() r.closed.Store(true) if r.unregNetMon != nil { r.unregNetMon() } + r.eventClient.Close() if err := r.downInterface(); err != nil { return err } if err := r.delIPRules(); err != nil { return err } - if err := r.setNetfilterMode(netfilterOff); err != nil { + if err := r.setNetfilterModeLocked(netfilterOff); err != nil { return err } if err := r.delRoutes(); err != nil { @@ -349,10 +390,10 @@ func (r *linuxRouter) Close() error { return nil } -// setupNetfilter initializes the NetfilterRunner in r.nfr. It expects r.nfr +// setupNetfilterLocked initializes the NetfilterRunner in r.nfr. It expects r.nfr // to be nil, or the current netfilter to be set to netfilterOff. // kind should be either a linuxfw.FirewallMode, or the empty string for auto. -func (r *linuxRouter) setupNetfilter(kind string) error { +func (r *linuxRouter) setupNetfilterLocked(kind string) error { r.netfilterKind = kind var err error @@ -365,25 +406,27 @@ func (r *linuxRouter) setupNetfilter(kind string) error { } // Set implements the Router interface. -func (r *linuxRouter) Set(cfg *Config) error { +func (r *linuxRouter) Set(cfg *router.Config) error { + r.mu.Lock() + defer r.mu.Unlock() var errs []error if cfg == nil { cfg = &shutdownConfig } if cfg.NetfilterKind != r.netfilterKind { - if err := r.setNetfilterMode(netfilterOff); err != nil { + if err := r.setNetfilterModeLocked(netfilterOff); err != nil { err = fmt.Errorf("could not disable existing netfilter: %w", err) errs = append(errs, err) } else { r.nfr = nil - if err := r.setupNetfilter(cfg.NetfilterKind); err != nil { + if err := r.setupNetfilterLocked(cfg.NetfilterKind); err != nil { errs = append(errs, err) } } } - if err := r.setNetfilterMode(cfg.NetfilterMode); err != nil { + if err := r.setNetfilterModeLocked(cfg.NetfilterMode); err != nil { errs = append(errs, err) } @@ -425,11 +468,11 @@ func (r *linuxRouter) Set(cfg *Config) error { case cfg.StatefulFiltering == r.statefulFiltering: // state already correct, nothing to do. case cfg.StatefulFiltering: - if err := r.addStatefulRule(); err != nil { + if err := r.addStatefulRuleLocked(); err != nil { errs = append(errs, err) } default: - if err := r.delStatefulRule(); err != nil { + if err := r.delStatefulRuleLocked(); err != nil { errs = append(errs, err) } } @@ -438,11 +481,11 @@ func (r *linuxRouter) Set(cfg *Config) error { // Issue 11405: enable IP forwarding on gokrazy. advertisingRoutes := len(cfg.SubnetRoutes) > 0 - if distro.Get() == distro.Gokrazy && advertisingRoutes { + if getDistroFunc() == distro.Gokrazy && advertisingRoutes { r.enableIPForwarding() } - return multierr.New(errs...) + return errors.Join(errs...) } var dockerStatefulFilteringWarnable = health.Register(&health.Warnable{ @@ -452,7 +495,7 @@ var dockerStatefulFilteringWarnable = health.Register(&health.Warnable{ Text: health.StaticMessage("Stateful filtering is enabled and Docker was detected; this may prevent Docker containers on this host from resolving DNS and connecting to Tailscale nodes. See https://tailscale.com/s/stateful-docker"), }) -func (r *linuxRouter) updateStatefulFilteringWithDockerWarning(cfg *Config) { +func (r *linuxRouter) updateStatefulFilteringWithDockerWarning(cfg *router.Config) { // If stateful filtering is disabled, clear the warning. if !r.statefulFiltering { r.health.SetHealthy(dockerStatefulFilteringWarnable) @@ -493,10 +536,12 @@ func (r *linuxRouter) updateStatefulFilteringWithDockerWarning(cfg *Config) { r.health.SetHealthy(dockerStatefulFilteringWarnable) } -// UpdateMagicsockPort implements the Router interface. -func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error { +// updateMagicsockPort implements the Router interface. +func (r *linuxRouter) updateMagicsockPort(port uint16, network string) error { + r.mu.Lock() + defer r.mu.Unlock() if r.nfr == nil { - if err := r.setupNetfilter(r.netfilterKind); err != nil { + if err := r.setupNetfilterLocked(r.netfilterKind); err != nil { return fmt.Errorf("could not setup netfilter: %w", err) } } @@ -536,7 +581,7 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error { } if port != 0 { - if err := r.nfr.AddMagicsockPortRule(*magicsockPort, network); err != nil { + if err := r.nfr.AddMagicsockPortRule(port, network); err != nil { return fmt.Errorf("add magicsock port rule: %w", err) } } @@ -545,19 +590,17 @@ func (r *linuxRouter) UpdateMagicsockPort(port uint16, network string) error { return nil } -// setNetfilterMode switches the router to the given netfilter +// setNetfilterModeLocked switches the router to the given netfilter // mode. Netfilter state is created or deleted appropriately to // reflect the new mode, and r.snatSubnetRoutes is updated to reflect // the current state of subnet SNATing. -func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { +func (r *linuxRouter) setNetfilterModeLocked(mode preftype.NetfilterMode) error { if !platformCanNetfilter() { mode = netfilterOff } if r.nfr == nil { - var err error - r.nfr, err = linuxfw.New(r.logf, r.netfilterKind) - if err != nil { + if err := r.setupNetfilterLocked(r.netfilterKind); err != nil { return err } } @@ -1181,7 +1224,9 @@ var ( tailscaleRouteTable = newRouteTable("tailscale", 52) ) -// ipRules are the policy routing rules that Tailscale uses. +// baseIPRules are the policy routing rules that Tailscale uses, when not +// running on a UBNT device. +// // The priority is the value represented here added to r.ipPolicyPrefBase, // which is usually 5200. // @@ -1196,19 +1241,19 @@ var ( // and 'ip rule' implementations (including busybox), don't support // checking for the lack of a fwmark, only the presence. The technique // below works even on very old kernels. -var ipRules = []netlink.Rule{ +var baseIPRules = []netlink.Rule{ // Packets from us, tagged with our fwmark, first try the kernel's // main routing table. { Priority: 10, - Mark: linuxfw.TailscaleBypassMarkNum, + Mark: tsconst.LinuxBypassMarkNum, Table: mainRouteTable.Num, }, // ...and then we try the 'default' table, for correctness, // even though it's been empty on every Linux system I've ever seen. { Priority: 30, - Mark: linuxfw.TailscaleBypassMarkNum, + Mark: tsconst.LinuxBypassMarkNum, Table: defaultRouteTable.Num, }, // If neither of those matched (no default route on this system?) @@ -1216,7 +1261,7 @@ var ipRules = []netlink.Rule{ // to the tailscale routes, because that would create routing loops. { Priority: 50, - Mark: linuxfw.TailscaleBypassMarkNum, + Mark: tsconst.LinuxBypassMarkNum, Type: unix.RTN_UNREACHABLE, }, // If we get to this point, capture all packets and send them @@ -1232,6 +1277,34 @@ var ipRules = []netlink.Rule{ // usual rules (pref 32766 and 32767, ie. main and default). } +// ubntIPRules are the policy routing rules that Tailscale uses, when running +// on a UBNT device. +// +// The priority is the value represented here added to +// r.ipPolicyPrefBase, which is usually 5200. +// +// This represents an experiment that will be used to gather more information. +// If this goes well, Tailscale may opt to use this for all of Linux. +var ubntIPRules = []netlink.Rule{ + // non-fwmark packets fall through to the usual rules (pref 32766 and 32767, + // ie. main and default). + { + Priority: 70, + Invert: true, + Mark: tsconst.LinuxBypassMarkNum, + Table: tailscaleRouteTable.Num, + }, +} + +// ipRules returns the appropriate list of ip rules to be used by Tailscale. See +// comments on baseIPRules and ubntIPRules for more details. +func ipRules() []netlink.Rule { + if getDistroFunc() == distro.UBNT { + return ubntIPRules + } + return baseIPRules +} + // justAddIPRules adds policy routing rule without deleting any first. func (r *linuxRouter) justAddIPRules() error { if !r.ipRuleAvailable { @@ -1242,12 +1315,11 @@ func (r *linuxRouter) justAddIPRules() error { } var errAcc error for _, family := range r.addrFamilies() { - - for _, ru := range ipRules { + for _, ru := range ipRules() { // Note: r is a value type here; safe to mutate it. ru.Family = family.netlinkInt() if ru.Mark != 0 { - ru.Mask = linuxfw.TailscaleFwmarkMaskNum + ru.Mask = tsconst.LinuxFwmarkMaskNum } ru.Goto = -1 ru.SuppressIfgroup = -1 @@ -1272,7 +1344,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error { rg := newRunGroup(nil, r.cmd) for _, family := range r.addrFamilies() { - for _, rule := range ipRules { + for _, rule := range ipRules() { args := []string{ "ip", family.dashArg(), "rule", "add", @@ -1280,7 +1352,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error { } if rule.Mark != 0 { if r.fwmaskWorks() { - args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, linuxfw.TailscaleFwmarkMask)) + args = append(args, "fwmark", fmt.Sprintf("0x%x/%s", rule.Mark, tsconst.LinuxFwmarkMask)) } else { args = append(args, "fwmark", fmt.Sprintf("0x%x", rule.Mark)) } @@ -1320,7 +1392,7 @@ func (r *linuxRouter) delIPRules() error { } var errAcc error for _, family := range r.addrFamilies() { - for _, ru := range ipRules { + for _, ru := range ipRules() { // Note: r is a value type here; safe to mutate it. // When deleting rules, we want to be a bit specific (mention which // table we were routing to) but not *too* specific (fwmarks, etc). @@ -1363,7 +1435,7 @@ func (r *linuxRouter) delIPRulesWithIPCommand() error { // That leaves us some flexibility to change these values in later // versions without having ongoing hacks for every possible // combination. - for _, rule := range ipRules { + for _, rule := range ipRules() { args := []string{ "ip", family.dashArg(), "rule", "del", @@ -1407,9 +1479,9 @@ func (r *linuxRouter) delSNATRule() error { return nil } -// addStatefulRule adds a netfilter rule to perform stateful filtering from +// addStatefulRuleLocked adds a netfilter rule to perform stateful filtering from // subnets onto the tailnet. -func (r *linuxRouter) addStatefulRule() error { +func (r *linuxRouter) addStatefulRuleLocked() error { if r.netfilterMode == netfilterOff { return nil } @@ -1417,9 +1489,9 @@ func (r *linuxRouter) addStatefulRule() error { return r.nfr.AddStatefulRule(r.tunname) } -// delStatefulRule removes the netfilter rule to perform stateful filtering +// delStatefulRuleLocked removes the netfilter rule to perform stateful filtering // from subnets onto the tailnet. -func (r *linuxRouter) delStatefulRule() error { +func (r *linuxRouter) delStatefulRuleLocked() error { if r.netfilterMode == netfilterOff { return nil } @@ -1500,7 +1572,7 @@ func normalizeCIDR(cidr netip.Prefix) string { // platformCanNetfilter reports whether the current distro/environment supports // running iptables/nftables commands. func platformCanNetfilter() bool { - switch distro.Get() { + switch getDistroFunc() { case distro.Synology: // Synology doesn't support iptables or nftables. Attempting to run it // just blocks for a long time while it logs about failures. @@ -1526,7 +1598,7 @@ func cleanUp(logf logger.Logf, interfaceName string) { // of the config file being present as well as a policy rule with a specific // priority (2000 + 1 - first interface mwan3 manages) and non-zero mark. func checkOpenWRTUsingMWAN3() (bool, error) { - if distro.Get() != distro.OpenWrt { + if getDistroFunc() != distro.OpenWrt { return false, nil } @@ -1545,7 +1617,7 @@ func checkOpenWRTUsingMWAN3() (bool, error) { // We want to match on a rule like this: // 2001: from all fwmark 0x100/0x3f00 lookup 1 // - // We dont match on the mask because it can vary, or the + // We don't match on the mask because it can vary, or the // table because I'm not sure if it can vary. if r.Priority >= 2001 && r.Priority <= 2004 && r.Mark != 0 { return true, nil diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/osrouter/router_linux_test.go similarity index 89% rename from wgengine/router/router_linux_test.go rename to wgengine/router/osrouter/router_linux_test.go index dce69550d..68ed8dbb2 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/osrouter/router_linux_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +package osrouter import ( "errors" @@ -25,11 +25,18 @@ import ( "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" + "tailscale.com/tsconst" "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/linuxfw" + "tailscale.com/version/distro" + "tailscale.com/wgengine/router" ) +type Config = router.Config + func TestRouterStates(t *testing.T) { basic := ` ip rule add -4 pref 5210 fwmark 0x80000/0xff0000 table main @@ -362,7 +369,9 @@ ip route add throw 192.168.0.0/24 table 52` + basic, }, } - mon, err := netmon.New(logger.Discard) + bus := eventbus.New() + defer bus.Close() + mon, err := netmon.New(bus, logger.Discard) if err != nil { t.Fatal(err) } @@ -370,8 +379,8 @@ ip route add throw 192.168.0.0/24 table 52` + basic, defer mon.Close() fake := NewFakeOS(t) - ht := new(health.Tracker) - router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake, ht) + ht := health.NewTracker(bus) + router, err := newUserspaceRouterAdvanced(t.Logf, "tailscale0", mon, fake, ht, bus) router.(*linuxRouter).nfr = fake.nfr if err != nil { t.Fatalf("failed to create router: %v", err) @@ -410,7 +419,7 @@ type fakeIPTablesRunner struct { t *testing.T ipt4 map[string][]string ipt6 map[string][]string - //we always assume ipv6 and ipv6 nat are enabled when testing + // we always assume ipv6 and ipv6 nat are enabled when testing } func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner { @@ -537,6 +546,7 @@ func (n *fakeIPTablesRunner) EnsureSNATForDst(src, dst netip.Addr) error { func (n *fakeIPTablesRunner) DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error { return errors.New("not implemented") } + func (n *fakeIPTablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm linuxfw.PortMap) error { return errors.New("not implemented") } @@ -553,13 +563,21 @@ func (n *fakeIPTablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return errors.New("not implemented") } +func (n *fakeIPTablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + return errors.New("not implemented") +} + +func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + return errors.New("not implemented") +} + func (n *fakeIPTablesRunner) addBase4(tunname string) error { curIPT := n.ipt4 newRules := []struct{ chain, rule string }{ {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, - {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, - {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, {"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)}, } @@ -574,8 +592,8 @@ func (n *fakeIPTablesRunner) addBase4(tunname string) error { func (n *fakeIPTablesRunner) addBase6(tunname string) error { curIPT := n.ipt6 newRules := []struct{ chain, rule string }{ - {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, - {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, + {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)}, } for _, rule := range newRules { @@ -659,7 +677,7 @@ func (n *fakeIPTablesRunner) DelBase() error { } func (n *fakeIPTablesRunner) AddSNATRule() error { - newRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask) + newRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask) for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { if err := appendRule(n, ipt, "nat/ts-postrouting", newRule); err != nil { return err @@ -669,7 +687,7 @@ func (n *fakeIPTablesRunner) AddSNATRule() error { } func (n *fakeIPTablesRunner) DelSNATRule() error { - delRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask) + delRule := fmt.Sprintf("-m mark --mark %s/%s -j MASQUERADE", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask) for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { if err := deleteRule(n, ipt, "nat/ts-postrouting", delRule); err != nil { return err @@ -769,8 +787,8 @@ type fakeOS struct { ips []string routes []string rules []string - //This test tests on the router level, so we will not bother - //with using iptables or nftables, chose the simpler one. + // This test tests on the router level, so we will not bother + // with using iptables or nftables, chose the simpler one. nfr linuxfw.NetfilterRunner } @@ -852,7 +870,7 @@ func (o *fakeOS) run(args ...string) error { rest = family + " " + strings.Join(args[3:], " ") } - var l *[]string + var ls *[]string switch args[1] { case "link": got := strings.Join(args[2:], " ") @@ -866,31 +884,31 @@ func (o *fakeOS) run(args ...string) error { } return nil case "addr": - l = &o.ips + ls = &o.ips case "route": - l = &o.routes + ls = &o.routes case "rule": - l = &o.rules + ls = &o.rules default: return unexpected() } switch args[2] { case "add": - for _, el := range *l { + for _, el := range *ls { if el == rest { o.t.Errorf("can't add %q, already present", rest) return errors.New("already exists") } } - *l = append(*l, rest) - sort.Strings(*l) + *ls = append(*ls, rest) + sort.Strings(*ls) case "del": found := false - for i, el := range *l { + for i, el := range *ls { if el == rest { found = true - *l = append((*l)[:i], (*l)[i+1:]...) + *ls = append((*ls)[:i], (*ls)[i+1:]...) break } } @@ -962,7 +980,7 @@ func (lt *linuxTest) Close() error { return nil } -func newLinuxRootTest(t *testing.T) *linuxTest { +func newLinuxRootTest(t *testing.T) (*linuxTest, *eventbus.Bus) { if os.Getuid() != 0 { t.Skip("test requires root") } @@ -972,7 +990,9 @@ func newLinuxRootTest(t *testing.T) *linuxTest { logf := lt.logOutput.Logf - mon, err := netmon.New(logger.Discard) + bus := eventbustest.NewBus(t) + + mon, err := netmon.New(bus, logger.Discard) if err != nil { lt.Close() t.Fatal(err) @@ -980,7 +1000,7 @@ func newLinuxRootTest(t *testing.T) *linuxTest { mon.Start() lt.mon = mon - r, err := newUserspaceRouter(logf, lt.tun, mon, nil) + r, err := newUserspaceRouter(logf, lt.tun, mon, nil, bus) if err != nil { lt.Close() t.Fatal(err) @@ -991,11 +1011,31 @@ func newLinuxRootTest(t *testing.T) *linuxTest { t.Fatal(err) } lt.r = lr - return lt + return lt, bus +} + +func TestRuleDeletedEvent(t *testing.T) { + fake := NewFakeOS(t) + lt, bus := newLinuxRootTest(t) + lt.r.nfr = fake.nfr + defer lt.Close() + event := netmon.RuleDeleted{ + Table: 52, + Priority: 5210, + } + tw := eventbustest.NewWatcher(t, bus) + + t.Logf("Value before: %t", lt.r.ruleRestorePending.Load()) + if lt.r.ruleRestorePending.Load() { + t.Errorf("rule deletion already ongoing") + } + injector := eventbustest.NewInjector(t, bus) + eventbustest.Inject(injector, event) + eventbustest.Expect(tw, eventbustest.Type[AddIPRules]()) } func TestDelRouteIdempotent(t *testing.T) { - lt := newLinuxRootTest(t) + lt, _ := newLinuxRootTest(t) defer lt.Close() for _, s := range []string{ @@ -1021,7 +1061,7 @@ func TestDelRouteIdempotent(t *testing.T) { } func TestAddRemoveRules(t *testing.T) { - lt := newLinuxRootTest(t) + lt, _ := newLinuxRootTest(t) defer lt.Close() r := lt.r @@ -1039,14 +1079,12 @@ func TestAddRemoveRules(t *testing.T) { t.Logf("Rule: %+v", r) } } - } step("init_del_and_add", r.addIPRules) step("dup_add", r.justAddIPRules) step("del", r.delIPRules) step("dup_del", r.delIPRules) - } func TestDebugListLinks(t *testing.T) { @@ -1231,3 +1269,64 @@ func adjustFwmask(t *testing.T, s string) string { return fwmaskAdjustRe.ReplaceAllString(s, "$1") } + +func TestIPRulesForUBNT(t *testing.T) { + // Override the global getDistroFunc + getDistroFunc = func() distro.Distro { + return distro.UBNT + } + defer func() { getDistroFunc = distro.Get }() // Restore original after the test + + expected := ubntIPRules + actual := ipRules() + + if len(expected) != len(actual) { + t.Fatalf("Expected %d rules, got %d", len(expected), len(actual)) + } + + for i, rule := range expected { + if rule != actual[i] { + t.Errorf("Rule mismatch at index %d: expected %+v, got %+v", i, rule, actual[i]) + } + } +} + +func TestUpdateMagicsockPortChange(t *testing.T) { + nfr := &fakeIPTablesRunner{ + t: t, + ipt4: make(map[string][]string), + ipt6: make(map[string][]string), + } + nfr.ipt4["filter/ts-input"] = []string{} + + r := &linuxRouter{ + logf: logger.Discard, + health: new(health.Tracker), + netfilterMode: netfilterOn, + nfr: nfr, + } + + if err := r.updateMagicsockPort(12345, "udp4"); err != nil { + t.Fatalf("failed to set initial port: %v", err) + } + + if err := r.updateMagicsockPort(54321, "udp4"); err != nil { + t.Fatalf("failed to update port: %v", err) + } + + newPortRule := buildMagicsockPortRule(54321) + hasNewRule := slices.Contains(nfr.ipt4["filter/ts-input"], newPortRule) + + if !hasNewRule { + t.Errorf("firewall rule for NEW port 54321 not found.\nExpected: %s\nActual rules: %v", + newPortRule, nfr.ipt4["filter/ts-input"]) + } + + oldPortRule := buildMagicsockPortRule(12345) + hasOldRule := slices.Contains(nfr.ipt4["filter/ts-input"], oldPortRule) + + if hasOldRule { + t.Errorf("firewall rule for OLD port 12345 still exists (should be deleted).\nFound: %s\nAll rules: %v", + oldPortRule, nfr.ipt4["filter/ts-input"]) + } +} diff --git a/wgengine/router/router_openbsd.go b/wgengine/router/osrouter/router_openbsd.go similarity index 90% rename from wgengine/router/router_openbsd.go rename to wgengine/router/osrouter/router_openbsd.go index 6fdd47ac9..55b485f0e 100644 --- a/wgengine/router/router_openbsd.go +++ b/wgengine/router/osrouter/router_openbsd.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +package osrouter import ( "errors" @@ -15,11 +15,20 @@ import ( "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" "tailscale.com/util/set" + "tailscale.com/wgengine/router" ) -// For now this router only supports the WireGuard userspace implementation. -// There is an experimental kernel version in the works for OpenBSD: +func init() { + router.HookNewUserspaceRouter.Set(func(opts router.NewOpts) (router.Router, error) { + return newUserspaceRouter(opts.Logf, opts.Tun, opts.NetMon, opts.Health, opts.Bus) + }) + router.HookCleanUp.Set(func(logf logger.Logf, netMon *netmon.Monitor, ifName string) { + cleanUp(logf, ifName) + }) +} + // https://git.zx2c4.com/wireguard-openbsd. type openbsdRouter struct { @@ -31,7 +40,7 @@ type openbsdRouter struct { routes set.Set[netip.Prefix] } -func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { +func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) { tunname, err := tundev.Name() if err != nil { return nil, err @@ -67,7 +76,7 @@ func inet(p netip.Prefix) string { return "inet" } -func (r *openbsdRouter) Set(cfg *Config) error { +func (r *openbsdRouter) Set(cfg *router.Config) error { if cfg == nil { cfg = &shutdownConfig } @@ -229,13 +238,6 @@ func (r *openbsdRouter) Set(cfg *Config) error { return errq } -// UpdateMagicsockPort implements the Router interface. This implementation -// does nothing and returns nil because this router does not currently need -// to know what the magicsock UDP port is. -func (r *openbsdRouter) UpdateMagicsockPort(_ uint16, _ string) error { - return nil -} - func (r *openbsdRouter) Close() error { cleanUp(r.logf, r.tunname) return nil diff --git a/wgengine/router/osrouter/router_plan9.go b/wgengine/router/osrouter/router_plan9.go new file mode 100644 index 000000000..a5b461a6f --- /dev/null +++ b/wgengine/router/osrouter/router_plan9.go @@ -0,0 +1,155 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package osrouter + +import ( + "bufio" + "bytes" + "fmt" + "net/netip" + "os" + "strings" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/health" + "tailscale.com/net/netmon" + "tailscale.com/types/logger" + "tailscale.com/wgengine/router" +) + +func init() { + router.HookCleanUp.Set(func(logf logger.Logf, netMon *netmon.Monitor, ifName string) { + cleanAllTailscaleRoutes(logf) + }) + router.HookNewUserspaceRouter.Set(func(opts router.NewOpts) (router.Router, error) { + return newUserspaceRouter(opts.Logf, opts.Tun, opts.NetMon) + }) +} + +func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor) (router.Router, error) { + r := &plan9Router{ + logf: logf, + tundev: tundev, + netMon: netMon, + } + cleanAllTailscaleRoutes(logf) + return r, nil +} + +type plan9Router struct { + logf logger.Logf + tundev tun.Device + netMon *netmon.Monitor + health *health.Tracker +} + +func (r *plan9Router) Up() error { + return nil +} + +func (r *plan9Router) Set(cfg *router.Config) error { + if cfg == nil { + cleanAllTailscaleRoutes(r.logf) + return nil + } + + var self4, self6 netip.Addr + for _, addr := range cfg.LocalAddrs { + ctl := r.tundev.File() + maskBits := addr.Bits() + if addr.Addr().Is4() { + // The mask sizes in Plan9 are in IPv6 bits, even for IPv4. + maskBits += (128 - 32) + self4 = addr.Addr() + } + if addr.Addr().Is6() { + self6 = addr.Addr() + } + _, err := fmt.Fprintf(ctl, "add %s /%d\n", addr.Addr().String(), maskBits) + r.logf("route/plan9: add %s /%d = %v", addr.Addr().String(), maskBits, err) + } + + ipr, err := os.OpenFile("/net/iproute", os.O_RDWR, 0) + if err != nil { + return fmt.Errorf("open /net/iproute: %w", err) + } + defer ipr.Close() + + // TODO(bradfitz): read existing routes, delete ones tagged "tail" + // that aren't in cfg.LocalRoutes. + + if _, err := fmt.Fprintf(ipr, "tag tail\n"); err != nil { + return fmt.Errorf("tag tail: %w", err) + } + + for _, route := range cfg.Routes { + maskBits := route.Bits() + if route.Addr().Is4() { + // The mask sizes in Plan9 are in IPv6 bits, even for IPv4. + maskBits += (128 - 32) + } + var nextHop netip.Addr + if route.Addr().Is4() { + nextHop = self4 + } else if route.Addr().Is6() { + nextHop = self6 + } + if !nextHop.IsValid() { + r.logf("route/plan9: skipping route %s: no next hop (no self addr)", route.String()) + continue + } + r.logf("route/plan9: plan9.router: add %s /%d %s", route.Addr(), maskBits, nextHop) + if _, err := fmt.Fprintf(ipr, "add %s /%d %s\n", route.Addr(), maskBits, nextHop); err != nil { + return fmt.Errorf("add %s: %w", route.String(), err) + } + } + + if len(cfg.LocalRoutes) > 0 { + r.logf("route/plan9: TODO: Set LocalRoutes %v", cfg.LocalRoutes) + } + if len(cfg.SubnetRoutes) > 0 { + r.logf("route/plan9: TODO: Set SubnetRoutes %v", cfg.SubnetRoutes) + } + + return nil +} + +func (r *plan9Router) Close() error { + // TODO(bradfitz): unbind + return nil +} + +func cleanAllTailscaleRoutes(logf logger.Logf) { + routes, err := os.OpenFile("/net/iproute", os.O_RDWR, 0) + if err != nil { + logf("cleaning routes: %v", err) + return + } + defer routes.Close() + + // Using io.ReadAll or os.ReadFile on /net/iproute fails; it results in a + // 511 byte result when the actual /net/iproute contents are over 1k. + // So do it in one big read instead. Who knows. + routeBuf := make([]byte, 1<<20) + n, err := routes.Read(routeBuf) + if err != nil { + logf("cleaning routes: %v", err) + return + } + routeBuf = routeBuf[:n] + + bs := bufio.NewScanner(bytes.NewReader(routeBuf)) + for bs.Scan() { + f := strings.Fields(bs.Text()) + if len(f) < 6 { + continue + } + tag := f[4] + if tag != "tail" { + continue + } + _, err := fmt.Fprintf(routes, "remove %s %s\n", f[0], f[1]) + logf("router: cleaning route %s %s: %v", f[0], f[1], err) + } +} diff --git a/wgengine/router/router_userspace_bsd.go b/wgengine/router/osrouter/router_userspace_bsd.go similarity index 92% rename from wgengine/router/router_userspace_bsd.go rename to wgengine/router/osrouter/router_userspace_bsd.go index 0b7e4f36a..70ef2b6bf 100644 --- a/wgengine/router/router_userspace_bsd.go +++ b/wgengine/router/osrouter/router_userspace_bsd.go @@ -3,7 +3,7 @@ //go:build darwin || freebsd -package router +package osrouter import ( "fmt" @@ -19,8 +19,15 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/version" + "tailscale.com/wgengine/router" ) +func init() { + router.HookNewUserspaceRouter.Set(func(opts router.NewOpts) (router.Router, error) { + return newUserspaceBSDRouter(opts.Logf, opts.Tun, opts.NetMon, opts.Health) + }) +} + type userspaceBSDRouter struct { logf logger.Logf netMon *netmon.Monitor @@ -30,7 +37,7 @@ type userspaceBSDRouter struct { routes map[netip.Prefix]bool } -func newUserspaceBSDRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { +func newUserspaceBSDRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (router.Router, error) { tunname, err := tundev.Name() if err != nil { return nil, err @@ -99,7 +106,7 @@ func inet(p netip.Prefix) string { return "inet" } -func (r *userspaceBSDRouter) Set(cfg *Config) (reterr error) { +func (r *userspaceBSDRouter) Set(cfg *router.Config) (reterr error) { if cfg == nil { cfg = &shutdownConfig } @@ -199,13 +206,6 @@ func (r *userspaceBSDRouter) Set(cfg *Config) (reterr error) { return reterr } -// UpdateMagicsockPort implements the Router interface. This implementation -// does nothing and returns nil because this router does not currently need -// to know what the magicsock UDP port is. -func (r *userspaceBSDRouter) UpdateMagicsockPort(_ uint16, _ string) error { - return nil -} - func (r *userspaceBSDRouter) Close() error { return nil } diff --git a/wgengine/router/router_windows.go b/wgengine/router/osrouter/router_windows.go similarity index 95% rename from wgengine/router/router_windows.go rename to wgengine/router/osrouter/router_windows.go index 64163660d..a1acbe3b6 100644 --- a/wgengine/router/router_windows.go +++ b/wgengine/router/osrouter/router_windows.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +package osrouter import ( "bufio" @@ -23,12 +23,20 @@ import ( "golang.org/x/sys/windows" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/health" - "tailscale.com/logtail/backoff" "tailscale.com/net/dns" "tailscale.com/net/netmon" "tailscale.com/types/logger" + "tailscale.com/util/backoff" + "tailscale.com/util/eventbus" + "tailscale.com/wgengine/router" ) +func init() { + router.HookNewUserspaceRouter.Set(func(opts router.NewOpts) (router.Router, error) { + return newUserspaceRouter(opts.Logf, opts.Tun, opts.NetMon, opts.Health, opts.Bus) + }) +} + type winRouter struct { logf func(fmt string, args ...any) netMon *netmon.Monitor // may be nil @@ -38,7 +46,7 @@ type winRouter struct { firewall *firewallTweaker } -func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { +func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker, bus *eventbus.Bus) (router.Router, error) { nativeTun := tundev.(*tun.NativeTun) luid := winipcfg.LUID(nativeTun.LUID()) guid, err := luid.GUID() @@ -72,7 +80,7 @@ func (r *winRouter) Up() error { return nil } -func (r *winRouter) Set(cfg *Config) error { +func (r *winRouter) Set(cfg *router.Config) error { if cfg == nil { cfg = &shutdownConfig } @@ -106,13 +114,6 @@ func hasDefaultRoute(routes []netip.Prefix) bool { return false } -// UpdateMagicsockPort implements the Router interface. This implementation -// does nothing and returns nil because this router does not currently need -// to know what the magicsock UDP port is. -func (r *winRouter) UpdateMagicsockPort(_ uint16, _ string) error { - return nil -} - func (r *winRouter) Close() error { r.firewall.clear() @@ -123,10 +124,6 @@ func (r *winRouter) Close() error { return nil } -func cleanUp(logf logger.Logf, interfaceName string) { - // Nothing to do here. -} - // firewallTweaker changes the Windows firewall. Normally this wouldn't be so complicated, // but it can be REALLY SLOW to change the Windows firewall for reasons not understood. // Like 4 minutes slow. But usually it's tens of milliseconds. diff --git a/wgengine/router/router_windows_test.go b/wgengine/router/osrouter/router_windows_test.go similarity index 95% rename from wgengine/router/router_windows_test.go rename to wgengine/router/osrouter/router_windows_test.go index 9989ddbc7..119b6a778 100644 --- a/wgengine/router/router_windows_test.go +++ b/wgengine/router/osrouter/router_windows_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package router +package osrouter import ( "path/filepath" diff --git a/wgengine/router/runner.go b/wgengine/router/osrouter/runner.go similarity index 99% rename from wgengine/router/runner.go rename to wgengine/router/osrouter/runner.go index 8fa068e33..7afb7fdc2 100644 --- a/wgengine/router/runner.go +++ b/wgengine/router/osrouter/runner.go @@ -3,7 +3,7 @@ //go:build linux -package router +package osrouter import ( "errors" diff --git a/wgengine/router/router.go b/wgengine/router/router.go index 423008978..04cc89887 100644 --- a/wgengine/router/router.go +++ b/wgengine/router/router.go @@ -6,14 +6,21 @@ package router import ( + "errors" + "fmt" "net/netip" "reflect" + "runtime" + "slices" "github.com/tailscale/wireguard-go/tun" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/types/logger" "tailscale.com/types/preftype" + "tailscale.com/util/eventbus" ) // Router is responsible for managing the system network stack. @@ -28,33 +35,70 @@ type Router interface { // implementation should handle gracefully. Set(*Config) error - // UpdateMagicsockPort tells the OS network stack what port magicsock - // is currently listening on, so it can be threaded through firewalls - // and such. This is distinct from Set() since magicsock may rebind - // ports independently from the Config changing. - // - // network should be either "udp4" or "udp6". - UpdateMagicsockPort(port uint16, network string) error - // Close closes the router. Close() error } +// NewOpts are the options passed to the NewUserspaceRouter hook. +type NewOpts struct { + Logf logger.Logf // required + Tun tun.Device // required + NetMon *netmon.Monitor // optional + Health *health.Tracker // required (but TODO: support optional later) + Bus *eventbus.Bus // required +} + +// PortUpdate is an eventbus value, reporting the port and address family +// magicsock is currently listening on, so it can be threaded through firewalls +// and such. +type PortUpdate struct { + UDPPort uint16 + EndpointNetwork string // either "udp4" or "udp6". +} + +// HookNewUserspaceRouter is the registration point for router implementations +// to register a constructor for userspace routers. It's meant for implementations +// in wgengine/router/osrouter. +// +// If no implementation is registered, [New] will return an error. +var HookNewUserspaceRouter feature.Hook[func(NewOpts) (Router, error)] + // New returns a new Router for the current platform, using the // provided tun device. // // If netMon is nil, it's not used. It's currently (2021-07-20) only // used on Linux in some situations. -func New(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { +func New(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, + health *health.Tracker, bus *eventbus.Bus, +) (Router, error) { logf = logger.WithPrefix(logf, "router: ") - return newUserspaceRouter(logf, tundev, netMon, health) + if f, ok := HookNewUserspaceRouter.GetOk(); ok { + return f(NewOpts{ + Logf: logf, + Tun: tundev, + NetMon: netMon, + Health: health, + Bus: bus, + }) + } + if !buildfeatures.HasOSRouter { + return nil, errors.New("router: tailscaled was built without OSRouter support") + } + return nil, fmt.Errorf("unsupported OS %q", runtime.GOOS) } +// HookCleanUp is the optional registration point for router implementations +// to register a cleanup function for [CleanUp] to use. It's meant for +// implementations in wgengine/router/osrouter. +var HookCleanUp feature.Hook[func(_ logger.Logf, _ *netmon.Monitor, ifName string)] + // CleanUp restores the system network configuration to its original state // in case the Tailscale daemon terminated without closing the router. // No other state needs to be instantiated before this runs. func CleanUp(logf logger.Logf, netMon *netmon.Monitor, interfaceName string) { - cleanUp(logf, interfaceName) + if f, ok := HookCleanUp.GetOk(); ok { + f(logf, netMon, interfaceName) + } } // Config is the subset of Tailscale configuration that is relevant to @@ -91,7 +135,7 @@ type Config struct { SNATSubnetRoutes bool // SNAT traffic to local subnets StatefulFiltering bool // Apply stateful filtering to inbound connections NetfilterMode preftype.NetfilterMode // how much to manage netfilter rules - NetfilterKind string // what kind of netfilter to use (nftables, iptables) + NetfilterKind string // what kind of netfilter to use ("nftables", "iptables", or "" to auto-detect) } func (a *Config) Equal(b *Config) bool { @@ -104,7 +148,14 @@ func (a *Config) Equal(b *Config) bool { return reflect.DeepEqual(a, b) } -// shutdownConfig is a routing configuration that removes all router -// state from the OS. It's the config used when callers pass in a nil -// Config. -var shutdownConfig = Config{} +func (c *Config) Clone() *Config { + if c == nil { + return nil + } + c2 := *c + c2.LocalAddrs = slices.Clone(c.LocalAddrs) + c2.Routes = slices.Clone(c.Routes) + c2.LocalRoutes = slices.Clone(c.LocalRoutes) + c2.SubnetRoutes = slices.Clone(c.SubnetRoutes) + return &c2 +} diff --git a/wgengine/router/router_darwin.go b/wgengine/router/router_darwin.go deleted file mode 100644 index 73e394b04..000000000 --- a/wgengine/router/router_darwin.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package router - -import ( - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/health" - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { - return newUserspaceBSDRouter(logf, tundev, netMon, health) -} - -func cleanUp(logger.Logf, string) { - // Nothing to do. -} diff --git a/wgengine/router/router_default.go b/wgengine/router/router_default.go deleted file mode 100644 index 1e675d1fc..000000000 --- a/wgengine/router/router_default.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !linux && !darwin && !openbsd && !freebsd - -package router - -import ( - "fmt" - "runtime" - - "github.com/tailscale/wireguard-go/tun" - "tailscale.com/health" - "tailscale.com/net/netmon" - "tailscale.com/types/logger" -) - -func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { - return nil, fmt.Errorf("unsupported OS %q", runtime.GOOS) -} - -func cleanUp(logf logger.Logf, interfaceName string) { - // Nothing to do here. -} diff --git a/wgengine/router/router_fake.go b/wgengine/router/router_fake.go index 549867eca..db35fc9ee 100644 --- a/wgengine/router/router_fake.go +++ b/wgengine/router/router_fake.go @@ -27,11 +27,6 @@ func (r fakeRouter) Set(cfg *Config) error { return nil } -func (r fakeRouter) UpdateMagicsockPort(_ uint16, _ string) error { - r.logf("[v1] warning: fakeRouter.UpdateMagicsockPort: not implemented.") - return nil -} - func (r fakeRouter) Close() error { r.logf("[v1] warning: fakeRouter.Close: not implemented.") return nil diff --git a/wgengine/router/router_test.go b/wgengine/router/router_test.go index 8842173d7..fd17b8c5d 100644 --- a/wgengine/router/router_test.go +++ b/wgengine/router/router_test.go @@ -11,15 +11,6 @@ import ( "tailscale.com/types/preftype" ) -//lint:ignore U1000 used in Windows/Linux tests only -func mustCIDRs(ss ...string) []netip.Prefix { - var ret []netip.Prefix - for _, s := range ss { - ret = append(ret, netip.MustParsePrefix(s)) - } - return ret -} - func TestConfigEqual(t *testing.T) { testedFields := []string{ "LocalAddrs", "Routes", "LocalRoutes", "NewMTU", diff --git a/wgengine/userspace.go b/wgengine/userspace.go index fc204736a..8ad771fc5 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -10,8 +10,10 @@ import ( "errors" "fmt" "io" + "maps" "math" "net/netip" + "reflect" "runtime" "slices" "strings" @@ -23,17 +25,18 @@ import ( "tailscale.com/control/controlknobs" "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" "tailscale.com/health" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" - "tailscale.com/net/flowtrack" + "tailscale.com/net/dns/resolver" "tailscale.com/net/ipset" "tailscale.com/net/netmon" "tailscale.com/net/packet" "tailscale.com/net/sockstats" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" - "tailscale.com/net/tshttpproxy" "tailscale.com/net/tstun" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -44,14 +47,15 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" + "tailscale.com/util/backoff" + "tailscale.com/util/checkchange" "tailscale.com/util/clientmetric" - "tailscale.com/util/deephash" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/testenv" "tailscale.com/util/usermetric" "tailscale.com/version" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/netlog" @@ -90,23 +94,27 @@ const statusPollInterval = 1 * time.Minute const networkLoggerUploadTimeout = 5 * time.Second type userspaceEngine struct { - logf logger.Logf - wgLogger *wglog.Logger //a wireguard-go logging wrapper - reqCh chan struct{} - waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool - timeNow func() mono.Time - tundev *tstun.Wrapper - wgdev *device.Device - router router.Router - confListenPort uint16 // original conf.ListenPort - dns *dns.Manager - magicConn *magicsock.Conn - netMon *netmon.Monitor - health *health.Tracker - netMonOwned bool // whether we created netMon (and thus need to close it) - netMonUnregister func() // unsubscribes from changes; used regardless of netMonOwned - birdClient BIRDClient // or nil - controlKnobs *controlknobs.Knobs // or nil + // eventBus will eventually become required, but for now may be nil. + eventBus *eventbus.Bus + eventClient *eventbus.Client + + logf logger.Logf + wgLogger *wglog.Logger // a wireguard-go logging wrapper + reqCh chan struct{} + waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool + timeNow func() mono.Time + tundev *tstun.Wrapper + wgdev *device.Device + router router.Router + dialer *tsdial.Dialer + confListenPort uint16 // original conf.ListenPort + dns *dns.Manager + magicConn *magicsock.Conn + netMon *netmon.Monitor + health *health.Tracker + netMonOwned bool // whether we created netMon (and thus need to close it) + birdClient BIRDClient // or nil + controlKnobs *controlknobs.Knobs // or nil testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called @@ -122,11 +130,11 @@ type userspaceEngine struct { wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below lastCfgFull wgcfg.Config lastNMinPeers int - lastRouterSig deephash.Sum // of router.Config - lastEngineSigFull deephash.Sum // of full wireguard config - lastEngineSigTrim deephash.Sum // of trimmed wireguard config - lastDNSConfig *dns.Config - lastIsSubnetRouter bool // was the node a primary subnet router in the last run. + lastRouter *router.Config + lastEngineFull *wgcfg.Config // of full wireguard config, not trimmed + lastEngineInputs *maybeReconfigInputs + lastDNSConfig dns.ConfigView // or invalid if none + lastIsSubnetRouter bool // was the node a primary subnet router in the last run. recvActivityAt map[key.NodePublic]mono.Time trimmedNodes map[key.NodePublic]bool // set of node keys of peers currently excluded from wireguard config sentActivityAt map[netip.Addr]*mono.Time // value is accessed atomically @@ -138,9 +146,9 @@ type userspaceEngine struct { netMap *netmap.NetworkMap // or nil closing bool // Close was called (even if we're still closing) statusCallback StatusCallback - peerSequence []key.NodePublic + peerSequence views.Slice[key.NodePublic] endpoints []tailcfg.Endpoint - pendOpen map[flowtrack.Tuple]*pendingOpenFlow // see pendopen.go + pendOpen map[flowtrackTuple]*pendingOpenFlow // see pendopen.go // pongCallback is the map of response handlers waiting for disco or TSMP // pong callbacks. The map key is a random slice of bytes. @@ -228,6 +236,13 @@ type Config struct { // DriveForLocal, if populated, will cause the engine to expose a Taildrive // listener at 100.100.100.100:8080. DriveForLocal drive.FileSystemForLocal + + // EventBus, if non-nil, is used for event publication and subscription by + // the Engine and its subsystems. + // + // TODO(creachadair): As of 2025-03-19 this is optional, but is intended to + // become required non-nil. + EventBus *eventbus.Bus } // NewFakeUserspaceEngine returns a new userspace engine for testing. @@ -256,6 +271,8 @@ func NewFakeUserspaceEngine(logf logger.Logf, opts ...any) (Engine, error) { conf.HealthTracker = v case *usermetric.Registry: conf.Metrics = v + case *eventbus.Bus: + conf.EventBus = v default: return nil, fmt.Errorf("unknown option type %T", v) } @@ -296,6 +313,9 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } if conf.Dialer == nil { conf.Dialer = &tsdial.Dialer{Logf: logf} + if conf.EventBus != nil { + conf.Dialer.SetBus(conf.EventBus) + } } var tsTUNDev *tstun.Wrapper @@ -324,12 +344,14 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e := &userspaceEngine{ + eventBus: conf.EventBus, timeNow: mono.Now, logf: logf, reqCh: make(chan struct{}, 1), waitCh: make(chan struct{}), tundev: tsTUNDev, router: rtr, + dialer: conf.Dialer, confListenPort: conf.ListenPort, birdClient: conf.BIRDClient, controlKnobs: conf.ControlKnobs, @@ -349,7 +371,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) if conf.NetMon != nil { e.netMon = conf.NetMon } else { - mon, err := netmon.New(logf) + mon, err := netmon.New(conf.EventBus, logf) if err != nil { return nil, err } @@ -361,6 +383,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) tunName, _ := conf.Tun.Name() conf.Dialer.SetTUNName(tunName) conf.Dialer.SetNetMon(e.netMon) + conf.Dialer.SetBus(e.eventBus) e.dns = dns.NewManager(logf, conf.DNS, e.health, conf.Dialer, fwdDNSLinkSelector{e, tunName}, conf.ControlKnobs, runtime.GOOS) // TODO: there's probably a better place for this @@ -368,13 +391,6 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) logf("link state: %+v", e.netMon.InterfaceState()) - unregisterMonWatch := e.netMon.RegisterChangeCallback(func(delta *netmon.ChangeDelta) { - tshttpproxy.InvalidateCache() - e.linkChange(delta) - }) - closePool.addFunc(unregisterMonWatch) - e.netMonUnregister = unregisterMonWatch - endpointsFn := func(endpoints []tailcfg.Endpoint) { e.mu.Lock() e.endpoints = append(e.endpoints[:0], endpoints...) @@ -382,26 +398,21 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) e.RequestStatus() } - onPortUpdate := func(port uint16, network string) { - e.logf("onPortUpdate(port=%v, network=%s)", port, network) - - if err := e.router.UpdateMagicsockPort(port, network); err != nil { - e.logf("UpdateMagicsockPort(port=%v, network=%s) failed: %v", port, network, err) - } - } magicsockOpts := magicsock.Options{ - Logf: logf, - Port: conf.ListenPort, - EndpointsFunc: endpointsFn, - DERPActiveFunc: e.RequestStatus, - IdleFunc: e.tundev.IdleDuration, - NoteRecvActivity: e.noteRecvActivity, - NetMon: e.netMon, - HealthTracker: e.health, - Metrics: conf.Metrics, - ControlKnobs: conf.ControlKnobs, - OnPortUpdate: onPortUpdate, - PeerByKeyFunc: e.PeerByKey, + EventBus: e.eventBus, + Logf: logf, + Port: conf.ListenPort, + EndpointsFunc: endpointsFn, + DERPActiveFunc: e.RequestStatus, + IdleFunc: e.tundev.IdleDuration, + NetMon: e.netMon, + HealthTracker: e.health, + Metrics: conf.Metrics, + ControlKnobs: conf.ControlKnobs, + PeerByKeyFunc: e.PeerByKey, + } + if buildfeatures.HasLazyWG { + magicsockOpts.NoteRecvActivity = e.noteRecvActivity } var err error @@ -419,7 +430,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e.tundev.PreFilterPacketOutboundToWireGuardEngineIntercept = e.handleLocalPackets - if envknob.BoolDefaultTrue("TS_DEBUG_CONNECT_FAILURES") { + if buildfeatures.HasDebug && envknob.BoolDefaultTrue("TS_DEBUG_CONNECT_FAILURES") { if e.tundev.PreFilterPacketInboundFromWireGuard != nil { return nil, errors.New("unexpected PreFilterIn already set") } @@ -528,6 +539,14 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } } + ec := e.eventBus.Client("userspaceEngine") + eventbus.SubscribeFunc(ec, func(cd netmon.ChangeDelta) { + if f, ok := feature.HookProxyInvalidateCache.GetOk(); ok { + f() + } + e.linkChange(&cd) + }) + e.eventClient = ec e.logf("Engine created.") return e, nil } @@ -570,6 +589,17 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) return filter.Drop } } + if runtime.GOOS == "plan9" { + isLocalAddr, ok := e.isLocalAddr.LoadOk() + if ok { + if isLocalAddr(p.Dst.Addr()) { + // On Plan9's "tun" equivalent, everything goes back in and out + // the tun, even when the kernel's replying to itself. + t.InjectInboundCopy(p.Buffer()) + return filter.Drop + } + } + } return filter.Accept } @@ -673,6 +703,29 @@ func (e *userspaceEngine) isActiveSinceLocked(nk key.NodePublic, ip netip.Addr, return timePtr.LoadAtomic().After(t) } +// maybeReconfigInputs holds the inputs to the maybeReconfigWireguardLocked +// function. If these things don't change between calls, there's nothing to do. +type maybeReconfigInputs struct { + WGConfig *wgcfg.Config + TrimmedNodes map[key.NodePublic]bool + TrackNodes views.Slice[key.NodePublic] + TrackIPs views.Slice[netip.Addr] +} + +func (i *maybeReconfigInputs) Equal(o *maybeReconfigInputs) bool { + return reflect.DeepEqual(i, o) +} + +func (i *maybeReconfigInputs) Clone() *maybeReconfigInputs { + if i == nil { + return nil + } + v := *i + v.WGConfig = i.WGConfig.Clone() + v.TrimmedNodes = maps.Clone(i.TrimmedNodes) + return &v +} + // discoChanged are the set of peers whose disco keys have changed, implying they've restarted. // If a peer is in this set and was previously in the live wireguard config, // it needs to be first removed and then re-added to flush out its wireguard session key. @@ -698,15 +751,22 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node // the past 5 minutes. That's more than WireGuard's key // rotation time anyway so it's no harm if we remove it // later if it's been inactive. - activeCutoff := e.timeNow().Add(-lazyPeerIdleThreshold) + var activeCutoff mono.Time + if buildfeatures.HasLazyWG { + activeCutoff = e.timeNow().Add(-lazyPeerIdleThreshold) + } // Not all peers can be trimmed from the network map (see // isTrimmablePeer). For those that are trimmable, keep track of // their NodeKey and Tailscale IPs. These are the ones we'll need // to install tracking hooks for to watch their send/receive // activity. - trackNodes := make([]key.NodePublic, 0, len(full.Peers)) - trackIPs := make([]netip.Addr, 0, len(full.Peers)) + var trackNodes []key.NodePublic + var trackIPs []netip.Addr + if buildfeatures.HasLazyWG { + trackNodes = make([]key.NodePublic, 0, len(full.Peers)) + trackIPs = make([]netip.Addr, 0, len(full.Peers)) + } // Don't re-alloc the map; the Go compiler optimizes map clears as of // Go 1.11, so we can re-use the existing + allocated map. @@ -720,7 +780,7 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node for i := range full.Peers { p := &full.Peers[i] nk := p.PublicKey - if !e.isTrimmablePeer(p, len(full.Peers)) { + if !buildfeatures.HasLazyWG || !e.isTrimmablePeer(p, len(full.Peers)) { min.Peers = append(min.Peers, *p) if discoChanged[nk] { needRemoveStep = true @@ -744,16 +804,18 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node } e.lastNMinPeers = len(min.Peers) - if changed := deephash.Update(&e.lastEngineSigTrim, &struct { - WGConfig *wgcfg.Config - TrimmedNodes map[key.NodePublic]bool - TrackNodes []key.NodePublic - TrackIPs []netip.Addr - }{&min, e.trimmedNodes, trackNodes, trackIPs}); !changed { + if changed := checkchange.Update(&e.lastEngineInputs, &maybeReconfigInputs{ + WGConfig: &min, + TrimmedNodes: e.trimmedNodes, + TrackNodes: views.SliceOf(trackNodes), + TrackIPs: views.SliceOf(trackIPs), + }); !changed { return nil } - e.updateActivityMapsLocked(trackNodes, trackIPs) + if buildfeatures.HasLazyWG { + e.updateActivityMapsLocked(trackNodes, trackIPs) + } if needRemoveStep { minner := min @@ -789,6 +851,9 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node // // e.wgLock must be held. func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, trackIPs []netip.Addr) { + if !buildfeatures.HasLazyWG { + return + } // Generate the new map of which nodekeys we want to track // receive times for. mr := map[key.NodePublic]mono.Time{} // TODO: only recreate this if set of keys changed @@ -852,8 +917,7 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, // hasOverlap checks if there is a IPPrefix which is common amongst the two // provided slices. func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { - for i := range aips.Len() { - aip := aips.At(i) + for _, aip := range aips.All() { if views.SliceContains(rips, aip) { return true } @@ -861,6 +925,32 @@ func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { return false } +// ResetAndStop resets the engine to a clean state (like calling Reconfig +// with all pointers to zero values) and waits for it to be fully stopped, +// with no live peers or DERPs. +// +// Unlike Reconfig, it does not return ErrNoChanges. +// +// If the engine stops, returns the status. NB that this status will not be sent +// to the registered status callback, it is on the caller to ensure this status +// is handled appropriately. +func (e *userspaceEngine) ResetAndStop() (*Status, error) { + if err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}); err != nil && !errors.Is(err, ErrNoChanges) { + return nil, err + } + bo := backoff.NewBackoff("UserspaceEngineResetAndStop", e.logf, 1*time.Second) + for { + st, err := e.getStatus() + if err != nil { + return nil, err + } + if len(st.Peers) == 0 && st.DERPs == 0 { + return st, nil + } + bo.BackOff(context.Background(), fmt.Errorf("waiting for engine to stop: peers=%d derps=%d", len(st.Peers), st.DERPs)) + } +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") @@ -874,15 +964,17 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, e.wgLock.Lock() defer e.wgLock.Unlock() e.tundev.SetWGConfig(cfg) - e.lastDNSConfig = dnsCfg peerSet := make(set.Set[key.NodePublic], len(cfg.Peers)) + e.mu.Lock() - e.peerSequence = e.peerSequence[:0] + seq := make([]key.NodePublic, 0, len(cfg.Peers)) for _, p := range cfg.Peers { - e.peerSequence = append(e.peerSequence, p.PublicKey) + seq = append(seq, p.PublicKey) peerSet.Add(p.PublicKey) } + e.peerSequence = views.SliceOf(seq) + nm := e.netMap e.mu.Unlock() @@ -894,22 +986,24 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, peerMTUEnable := e.magicConn.ShouldPMTUD() isSubnetRouter := false - if e.birdClient != nil && nm != nil && nm.SelfNode.Valid() { + if buildfeatures.HasBird && e.birdClient != nil && nm != nil && nm.SelfNode.Valid() { isSubnetRouter = hasOverlap(nm.SelfNode.PrimaryRoutes(), nm.SelfNode.Hostinfo().RoutableIPs()) e.logf("[v1] Reconfig: hasOverlap(%v, %v) = %v; isSubnetRouter=%v lastIsSubnetRouter=%v", nm.SelfNode.PrimaryRoutes(), nm.SelfNode.Hostinfo().RoutableIPs(), isSubnetRouter, isSubnetRouter, e.lastIsSubnetRouter) } - isSubnetRouterChanged := isSubnetRouter != e.lastIsSubnetRouter + isSubnetRouterChanged := buildfeatures.HasAdvertiseRoutes && isSubnetRouter != e.lastIsSubnetRouter + + engineChanged := checkchange.Update(&e.lastEngineFull, cfg) + routerChanged := checkchange.Update(&e.lastRouter, routerCfg) + dnsChanged := buildfeatures.HasDNS && !e.lastDNSConfig.Equal(dnsCfg.View()) + if dnsChanged { + e.lastDNSConfig = dnsCfg.View() + } - engineChanged := deephash.Update(&e.lastEngineSigFull, cfg) - routerChanged := deephash.Update(&e.lastRouterSig, &struct { - RouterConfig *router.Config - DNSConfig *dns.Config - }{routerCfg, dnsCfg}) listenPortChanged := listenPort != e.magicConn.LocalPort() peerMTUChanged := peerMTUEnable != e.magicConn.PeerMTUEnabled() - if !engineChanged && !routerChanged && !listenPortChanged && !isSubnetRouterChanged && !peerMTUChanged { + if !engineChanged && !routerChanged && !dnsChanged && !listenPortChanged && !isSubnetRouterChanged && !peerMTUChanged { return ErrNoChanges } newLogIDs := cfg.NetworkLogging @@ -918,7 +1012,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, netLogIDsWasValid := !oldLogIDs.NodeID.IsZero() && !oldLogIDs.DomainID.IsZero() netLogIDsChanged := netLogIDsNowValid && netLogIDsWasValid && newLogIDs != oldLogIDs netLogRunning := netLogIDsNowValid && !routerCfg.Equal(&router.Config{}) - if envknob.NoLogsNoSupport() { + if !buildfeatures.HasNetLog || envknob.NoLogsNoSupport() { netLogRunning = false } @@ -927,7 +1021,9 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // instead have ipnlocal populate a map of DNS IP => linkName and // put that in the *dns.Config instead, and plumb it down to the // dns.Manager. Maybe also with isLocalAddr above. - e.isDNSIPOverTailscale.Store(ipset.NewContainsIPFunc(views.SliceOf(dnsIPsOverTailscale(dnsCfg, routerCfg)))) + if buildfeatures.HasDNS { + e.isDNSIPOverTailscale.Store(ipset.NewContainsIPFunc(views.SliceOf(dnsIPsOverTailscale(dnsCfg, routerCfg)))) + } // See if any peers have changed disco keys, which means they've restarted. // If so, we need to update the wireguard-go/device.Device in two phases: @@ -973,7 +1069,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // Shutdown the network logger because the IDs changed. // Let it be started back up by subsequent logic. - if netLogIDsChanged && e.networkLogger.Running() { + if buildfeatures.HasNetLog && netLogIDsChanged && e.networkLogger.Running() { e.logf("wgengine: Reconfig: shutting down network logger") ctx, cancel := context.WithTimeout(context.Background(), networkLoggerUploadTimeout) defer cancel() @@ -984,12 +1080,12 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // Startup the network logger. // Do this before configuring the router so that we capture initial packets. - if netLogRunning && !e.networkLogger.Running() { + if buildfeatures.HasNetLog && netLogRunning && !e.networkLogger.Running() { nid := cfg.NetworkLogging.NodeID tid := cfg.NetworkLogging.DomainID logExitFlowEnabled := cfg.NetworkLogging.LogExitFlowEnabled e.logf("wgengine: Reconfig: starting up network logger (node:%s tailnet:%s)", nid.Public(), tid.Public()) - if err := e.networkLogger.Startup(cfg.NodeID, nid, tid, e.tundev, e.magicConn, e.netMon, e.health, logExitFlowEnabled); err != nil { + if err := e.networkLogger.Startup(e.logf, nm, nid, tid, e.tundev, e.magicConn, e.netMon, e.health, e.eventBus, logExitFlowEnabled); err != nil { e.logf("wgengine: Reconfig: error starting up network logger: %v", err) } e.networkLogger.ReconfigRoutes(routerCfg) @@ -1003,11 +1099,30 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, if err != nil { return err } + } + + // We've historically re-set DNS even after just a router change. While + // refactoring in tailscale/tailscale#17448 and and + // tailscale/tailscale#17499, I'm erring on the side of keeping that + // historical quirk for now (2025-10-08), lest it's load bearing in + // unexpected ways + // + // TODO(bradfitz): try to do the "configuring DNS" part below only if + // dnsChanged, not routerChanged. The "resolver.ShouldUseRoutes" part + // probably needs to keep happening for both. + if buildfeatures.HasDNS && (routerChanged || dnsChanged) { + if resolver.ShouldUseRoutes(e.controlKnobs) { + e.logf("wgengine: Reconfig: user dialer") + e.dialer.SetRoutes(routerCfg.Routes, routerCfg.LocalRoutes) + } else { + e.dialer.SetRoutes(nil, nil) + } + // Keep DNS configuration after router configuration, as some // DNS managers refuse to apply settings if the device has no // assigned address. e.logf("wgengine: Reconfig: configuring DNS") - err = e.dns.Set(*dnsCfg) + err := e.dns.Set(*dnsCfg) e.health.SetDNSHealth(err) if err != nil { return err @@ -1029,7 +1144,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } } - if isSubnetRouterChanged && e.birdClient != nil { + if buildfeatures.HasBird && isSubnetRouterChanged && e.birdClient != nil { e.logf("wgengine: Reconfig: configuring BIRD") var err error if isSubnetRouter { @@ -1114,7 +1229,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) { e.mu.Lock() closing := e.closing - peerKeys := slices.Clone(e.peerSequence) + peerKeys := e.peerSequence localAddrs := slices.Clone(e.endpoints) e.mu.Unlock() @@ -1122,8 +1237,8 @@ func (e *userspaceEngine) getStatus() (*Status, error) { return nil, ErrEngineClosing } - peers := make([]ipnstate.PeerStatusLite, 0, len(peerKeys)) - for _, key := range peerKeys { + peers := make([]ipnstate.PeerStatusLite, 0, peerKeys.Len()) + for _, key := range peerKeys.All() { if status, ok := e.getPeerStatusLite(key); ok { peers = append(peers, status) } @@ -1172,6 +1287,7 @@ func (e *userspaceEngine) RequestStatus() { } func (e *userspaceEngine) Close() { + e.eventClient.Close() e.mu.Lock() if e.closing { e.mu.Unlock() @@ -1183,7 +1299,6 @@ func (e *userspaceEngine) Close() { r := bufio.NewReader(strings.NewReader("")) e.wgdev.IpcSetOperation(r) e.magicConn.Close() - e.netMonUnregister() if e.netMonOwned { e.netMon.Close() } @@ -1236,12 +1351,12 @@ func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) { // and Apple platforms. if changed { switch runtime.GOOS { - case "linux", "android", "ios", "darwin": + case "linux", "android", "ios", "darwin", "openbsd": e.wgLock.Lock() dnsCfg := e.lastDNSConfig e.wgLock.Unlock() - if dnsCfg != nil { - if err := e.dns.Set(*dnsCfg); err != nil { + if dnsCfg.Valid() { + if err := e.dns.Set(*dnsCfg.AsStruct()); err != nil { e.logf("wgengine: error setting DNS config after major link change: %v", err) } else if err := e.reconfigureVPNIfNecessary(); err != nil { e.logf("wgengine: error reconfiguring VPN after major link change: %v", err) @@ -1264,10 +1379,12 @@ func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) { } func (e *userspaceEngine) SetNetworkMap(nm *netmap.NetworkMap) { - e.magicConn.SetNetworkMap(nm) e.mu.Lock() e.netMap = nm e.mu.Unlock() + if e.networkLogger.Running() { + e.networkLogger.ReconfigNetworkMap(nm) + } } func (e *userspaceEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { @@ -1329,9 +1446,9 @@ func (e *userspaceEngine) mySelfIPMatchingFamily(dst netip.Addr) (src netip.Addr if addrs.Len() == 0 { return zero, errors.New("no self address in netmap") } - for i := range addrs.Len() { - if a := addrs.At(i); a.IsSingleIP() && a.Addr().BitLen() == dst.BitLen() { - return a.Addr(), nil + for _, p := range addrs.All() { + if p.IsSingleIP() && p.Addr().BitLen() == dst.BitLen() { + return p.Addr(), nil } } return zero, errors.New("no self address in netmap matching address family") @@ -1582,6 +1699,12 @@ type fwdDNSLinkSelector struct { } func (ls fwdDNSLinkSelector) PickLink(ip netip.Addr) (linkName string) { + // sandboxed macOS does not automatically bind to the loopback interface so + // we must be explicit about it. + if runtime.GOOS == "darwin" && ip.IsLoopback() { + return "lo0" + } + if ls.ue.isDNSIPOverTailscale.Load()(ip) { return ls.tunName } @@ -1595,7 +1718,10 @@ var ( metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes") ) -func (e *userspaceEngine) InstallCaptureHook(cb capture.Callback) { +func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) { + if !buildfeatures.HasCapture { + return + } e.tundev.InstallCaptureHook(cb) e.magicConn.InstallCaptureHook(cb) } diff --git a/wgengine/userspace_ext_test.go b/wgengine/userspace_ext_test.go index cc29be234..8e7bbb7a9 100644 --- a/wgengine/userspace_ext_test.go +++ b/wgengine/userspace_ext_test.go @@ -16,13 +16,14 @@ import ( ) func TestIsNetstack(t *testing.T) { - sys := new(tsd.System) + sys := tsd.NewSystem() e, err := wgengine.NewUserspaceEngine( tstest.WhileTestRunningLogger(t), wgengine.Config{ SetSubsystem: sys.Set, - HealthTracker: sys.HealthTracker(), + HealthTracker: sys.HealthTracker.Get(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }, ) if err != nil { @@ -66,14 +67,15 @@ func TestIsNetstackRouter(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sys := &tsd.System{} + sys := tsd.NewSystem() if tt.setNetstackRouter { sys.NetstackRouter.Set(true) } conf := tt.conf conf.SetSubsystem = sys.Set - conf.HealthTracker = sys.HealthTracker() + conf.HealthTracker = sys.HealthTracker.Get() conf.Metrics = sys.UserMetricsRegistry() + conf.EventBus = sys.Bus.Get() e, err := wgengine.NewUserspaceEngine(logger.Discard, conf) if err != nil { t.Fatal(err) diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 051421862..89d75b98a 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -25,6 +25,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/opt" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/usermetric" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" @@ -100,9 +101,11 @@ func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { } func TestUserspaceEngineReconfig(t *testing.T) { - ht := new(health.Tracker) + bus := eventbustest.NewBus(t) + + ht := health.NewTracker(bus) reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) if err != nil { t.Fatal(err) } @@ -166,13 +169,15 @@ func TestUserspaceEnginePortReconfig(t *testing.T) { var knobs controlknobs.Knobs + bus := eventbustest.NewBus(t) + // Keep making a wgengine until we find an unused port var ue *userspaceEngine - ht := new(health.Tracker) + ht := health.NewTracker(bus) reg := new(usermetric.Registry) for i := range 100 { attempt := uint16(defaultPort + i) - e, err := NewFakeUserspaceEngine(t.Logf, attempt, &knobs, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, attempt, &knobs, ht, reg, bus) if err != nil { t.Fatal(err) } @@ -251,9 +256,10 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) { var knobs controlknobs.Knobs - ht := new(health.Tracker) + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs, ht, reg, bus) if err != nil { t.Fatal(err) } diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 232591f5e..9cc4ed3b5 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !js +//go:build !js && !ts_omit_debug package wgengine @@ -15,12 +15,13 @@ import ( "time" "tailscale.com/envknob" + "tailscale.com/feature/buildfeatures" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/netmap" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" @@ -123,6 +124,12 @@ func (e *watchdogEngine) watchdog(name string, fn func()) { func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, routerCfg, dnsCfg) }) } +func (e *watchdogEngine) ResetAndStop() (st *Status, err error) { + e.watchdog("ResetAndStop", func() { + st, err = e.wrap.ResetAndStop() + }) + return st, err +} func (e *watchdogEngine) GetFilter() *filter.Filter { return e.wrap.GetFilter() } @@ -162,7 +169,10 @@ func (e *watchdogEngine) Done() <-chan struct{} { return e.wrap.Done() } -func (e *watchdogEngine) InstallCaptureHook(cb capture.Callback) { +func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { + if !buildfeatures.HasCapture { + return + } e.wrap.InstallCaptureHook(cb) } diff --git a/wgengine/watchdog_js.go b/wgengine/watchdog_js.go deleted file mode 100644 index 872ce36d5..000000000 --- a/wgengine/watchdog_js.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build js - -package wgengine - -import "tailscale.com/net/dns/resolver" - -type watchdogEngine struct { - Engine - wrap Engine -} - -func (e *watchdogEngine) GetResolver() (r *resolver.Resolver, ok bool) { - return nil, false -} diff --git a/wgengine/watchdog_omit.go b/wgengine/watchdog_omit.go new file mode 100644 index 000000000..1d175b41a --- /dev/null +++ b/wgengine/watchdog_omit.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build js || ts_omit_debug + +package wgengine + +func NewWatchdog(e Engine) Engine { return e } diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go index b05cd421f..35fd8f331 100644 --- a/wgengine/watchdog_test.go +++ b/wgengine/watchdog_test.go @@ -9,6 +9,7 @@ import ( "time" "tailscale.com/health" + "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/usermetric" ) @@ -24,9 +25,10 @@ func TestWatchdog(t *testing.T) { t.Run("default watchdog does not fire", func(t *testing.T) { t.Parallel() - ht := new(health.Tracker) + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) if err != nil { t.Fatal(err) } diff --git a/wgengine/wgcfg/config.go b/wgengine/wgcfg/config.go index 154dc0a30..2734f6c6e 100644 --- a/wgengine/wgcfg/config.go +++ b/wgengine/wgcfg/config.go @@ -6,8 +6,8 @@ package wgcfg import ( "net/netip" + "slices" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logid" ) @@ -17,8 +17,6 @@ import ( // Config is a WireGuard configuration. // It only supports the set of things Tailscale uses. type Config struct { - Name string - NodeID tailcfg.StableNodeID PrivateKey key.NodePrivate Addresses []netip.Prefix MTU uint16 @@ -35,6 +33,18 @@ type Config struct { } } +func (c *Config) Equal(o *Config) bool { + if c == nil || o == nil { + return c == o + } + return c.PrivateKey.Equal(o.PrivateKey) && + c.MTU == o.MTU && + c.NetworkLogging == o.NetworkLogging && + slices.Equal(c.Addresses, o.Addresses) && + slices.Equal(c.DNS, o.DNS) && + slices.EqualFunc(c.Peers, o.Peers, Peer.Equal) +} + type Peer struct { PublicKey key.NodePublic DiscoKey key.DiscoPublic // present only so we can handle restarts within wgengine, not passed to WireGuard @@ -50,6 +60,24 @@ type Peer struct { WGEndpoint key.NodePublic } +func addrPtrEq(a, b *netip.Addr) bool { + if a == nil || b == nil { + return a == b + } + return *a == *b +} + +func (p Peer) Equal(o Peer) bool { + return p.PublicKey == o.PublicKey && + p.DiscoKey == o.DiscoKey && + slices.Equal(p.AllowedIPs, o.AllowedIPs) && + p.IsJailed == o.IsJailed && + p.PersistentKeepalive == o.PersistentKeepalive && + addrPtrEq(p.V4MasqAddr, o.V4MasqAddr) && + addrPtrEq(p.V6MasqAddr, o.V6MasqAddr) && + p.WGEndpoint == o.WGEndpoint +} + // PeerWithKey returns the Peer with key k and reports whether it was found. func (config Config) PeerWithKey(k key.NodePublic) (Peer, bool) { for _, p := range config.Peers { diff --git a/wgengine/wgcfg/config_test.go b/wgengine/wgcfg/config_test.go new file mode 100644 index 000000000..5ac3b7cd5 --- /dev/null +++ b/wgengine/wgcfg/config_test.go @@ -0,0 +1,41 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package wgcfg + +import ( + "reflect" + "testing" +) + +// Tests that [Config.Equal] tests all fields of [Config], even ones +// that might get added in the future. +func TestConfigEqual(t *testing.T) { + rt := reflect.TypeFor[Config]() + for i := range rt.NumField() { + sf := rt.Field(i) + switch sf.Name { + case "Name", "NodeID", "PrivateKey", "MTU", "Addresses", "DNS", "Peers", + "NetworkLogging": + // These are compared in [Config.Equal]. + default: + t.Errorf("Have you added field %q to Config.Equal? Do so if not, and then update TestConfigEqual", sf.Name) + } + } +} + +// Tests that [Peer.Equal] tests all fields of [Peer], even ones +// that might get added in the future. +func TestPeerEqual(t *testing.T) { + rt := reflect.TypeFor[Peer]() + for i := range rt.NumField() { + sf := rt.Field(i) + switch sf.Name { + case "PublicKey", "DiscoKey", "AllowedIPs", "IsJailed", + "PersistentKeepalive", "V4MasqAddr", "V6MasqAddr", "WGEndpoint": + // These are compared in [Peer.Equal]. + default: + t.Errorf("Have you added field %q to Peer.Equal? Do so if not, and then update TestPeerEqual", sf.Name) + } + } +} diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index 80fa159e3..ee7eb91c9 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -4,6 +4,7 @@ package wgcfg import ( + "errors" "io" "sort" @@ -11,7 +12,6 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "tailscale.com/types/logger" - "tailscale.com/util/multierr" ) // NewDevice returns a wireguard-go Device configured for Tailscale use. @@ -31,7 +31,7 @@ func DeviceConfig(d *device.Device) (*Config, error) { cfg, fromErr := FromUAPI(r) r.Close() getErr := <-errc - err := multierr.New(getErr, fromErr) + err := errors.Join(getErr, fromErr) if err != nil { return nil, err } @@ -64,5 +64,5 @@ func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) toErr := cfg.ToUAPI(logf, w, prev) w.Close() setErr := <-errc - return multierr.New(setErr, toErr) + return errors.Join(setErr, toErr) } diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index d54282e4b..9138d6e5a 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -242,9 +242,9 @@ type noopBind struct{} func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { return nil, 1, nil } -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b [][]byte, ep conn.Endpoint) error { return nil } +func (noopBind) Close() error { return nil } +func (noopBind) SetMark(mark uint32) error { return nil } +func (noopBind) Send(b [][]byte, ep conn.Endpoint, offset int) error { return nil } func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { return dummyEndpoint(s), nil } diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index d156f7fcb..487e78d81 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -5,12 +5,14 @@ package nmcfg import ( - "bytes" + "bufio" + "cmp" "fmt" "net/netip" "strings" "tailscale.com/tailcfg" + "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -18,16 +20,7 @@ import ( ) func nodeDebugName(n tailcfg.NodeView) string { - name := n.Name() - if name == "" { - name = n.Hostinfo().Hostname() - } - if i := strings.Index(name, "."); i != -1 { - name = name[:i] - } - if name == "" && n.Addresses().Len() != 0 { - return n.Addresses().At(0).String() - } + name, _, _ := strings.Cut(cmp.Or(n.Name(), n.Hostinfo().Hostname()), ".") return name } @@ -40,8 +33,7 @@ func cidrIsSubnet(node tailcfg.NodeView, cidr netip.Prefix) bool { if !cidr.IsSingleIP() { return true } - for i := range node.Addresses().Len() { - selfCIDR := node.Addresses().At(i) + for _, selfCIDR := range node.Addresses().All() { if cidr == selfCIDR { return false } @@ -50,17 +42,15 @@ func cidrIsSubnet(node tailcfg.NodeView, cidr netip.Prefix) bool { } // WGCfg returns the NetworkMaps's WireGuard configuration. -func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, exitNode tailcfg.StableNodeID) (*wgcfg.Config, error) { +func WGCfg(pk key.NodePrivate, nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, exitNode tailcfg.StableNodeID) (*wgcfg.Config, error) { cfg := &wgcfg.Config{ - Name: "tailscale", - PrivateKey: nm.PrivateKey, + PrivateKey: pk, Addresses: nm.GetAddresses().AsSlice(), Peers: make([]wgcfg.Peer, 0, len(nm.Peers)), } // Setup log IDs for data plane audit logging. if nm.SelfNode.Valid() { - cfg.NodeID = nm.SelfNode.StableID() canNetworkLog := nm.SelfNode.HasCap(tailcfg.CapabilityDataPlaneAuditLogs) logExitFlowEnabled := nm.SelfNode.HasCap(tailcfg.NodeAttrLogExitFlows) if canNetworkLog && nm.SelfNode.DataPlaneAuditLogID() != "" && nm.DomainAuditLogID != "" { @@ -80,13 +70,10 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, } } - // Logging buffers - skippedUnselected := new(bytes.Buffer) - skippedIPs := new(bytes.Buffer) - skippedSubnets := new(bytes.Buffer) + var skippedExitNode, skippedSubnetRouter, skippedExpired []tailcfg.NodeView for _, peer := range nm.Peers { - if peer.DiscoKey().IsZero() && peer.DERP() == "" && !peer.IsWireGuardOnly() { + if peer.DiscoKey().IsZero() && peer.HomeDERP() == 0 && !peer.IsWireGuardOnly() { // Peer predates both DERP and active discovery, we cannot // communicate with it. logf("[v1] wgcfg: skipped peer %s, doesn't offer DERP or disco", peer.Key().ShortString()) @@ -96,7 +83,7 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, // anyway, since control intentionally breaks node keys for // expired peers so that we can't discover endpoints via DERP. if peer.Expired() { - logf("[v1] wgcfg: skipped expired peer %s", peer.Key().ShortString()) + skippedExpired = append(skippedExpired, peer) continue } @@ -106,29 +93,22 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, }) cpeer := &cfg.Peers[len(cfg.Peers)-1] - didExitNodeWarn := false - cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer() - cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer() + didExitNodeLog := false + cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer().Clone() + cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer().Clone() cpeer.IsJailed = peer.IsJailed() - for i := range peer.AllowedIPs().Len() { - allowedIP := peer.AllowedIPs().At(i) + for _, allowedIP := range peer.AllowedIPs().All() { if allowedIP.Bits() == 0 && peer.StableID() != exitNode { - if didExitNodeWarn { + if didExitNodeLog { // Don't log about both the IPv4 /0 and IPv6 /0. continue } - didExitNodeWarn = true - if skippedUnselected.Len() > 0 { - skippedUnselected.WriteString(", ") - } - fmt.Fprintf(skippedUnselected, "%q (%v)", nodeDebugName(peer), peer.Key().ShortString()) + didExitNodeLog = true + skippedExitNode = append(skippedExitNode, peer) continue } else if cidrIsSubnet(peer, allowedIP) { if (flags & netmap.AllowSubnetRoutes) == 0 { - if skippedSubnets.Len() > 0 { - skippedSubnets.WriteString(", ") - } - fmt.Fprintf(skippedSubnets, "%v from %q (%v)", allowedIP, nodeDebugName(peer), peer.Key().ShortString()) + skippedSubnetRouter = append(skippedSubnetRouter, peer) continue } } @@ -136,15 +116,27 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, } } - if skippedUnselected.Len() > 0 { - logf("[v1] wgcfg: skipped unselected default routes from: %s", skippedUnselected.Bytes()) - } - if skippedIPs.Len() > 0 { - logf("[v1] wgcfg: skipped node IPs: %s", skippedIPs) - } - if skippedSubnets.Len() > 0 { - logf("[v1] wgcfg: did not accept subnet routes: %s", skippedSubnets) + logList := func(title string, nodes []tailcfg.NodeView) { + if len(nodes) == 0 { + return + } + logf("[v1] wgcfg: %s from %d nodes: %s", title, len(nodes), logger.ArgWriter(func(bw *bufio.Writer) { + const max = 5 + for i, n := range nodes { + if i == max { + fmt.Fprintf(bw, "... +%d", len(nodes)-max) + return + } + if i > 0 { + bw.WriteString(", ") + } + fmt.Fprintf(bw, "%s (%s)", nodeDebugName(n), n.StableID()) + } + })) } + logList("skipped unselected exit nodes", skippedExitNode) + logList("did not accept subnet routes", skippedSubnetRouter) + logList("skipped expired peers", skippedExpired) return cfg, nil } diff --git a/wgengine/wgcfg/wgcfg_clone.go b/wgengine/wgcfg/wgcfg_clone.go index 749d8d816..9f3cabde1 100644 --- a/wgengine/wgcfg/wgcfg_clone.go +++ b/wgengine/wgcfg/wgcfg_clone.go @@ -8,7 +8,6 @@ package wgcfg import ( "net/netip" - "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logid" "tailscale.com/types/ptr" @@ -35,8 +34,6 @@ func (src *Config) Clone() *Config { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _ConfigCloneNeedsRegeneration = Config(struct { - Name string - NodeID tailcfg.StableNodeID PrivateKey key.NodePrivate Addresses []netip.Prefix MTU uint16 diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index c165ccdf3..be7873147 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -11,10 +11,10 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/netmap" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" @@ -69,6 +69,13 @@ type Engine interface { // The returned error is ErrNoChanges if no changes were made. Reconfig(*wgcfg.Config, *router.Config, *dns.Config) error + // ResetAndStop resets the engine to a clean state (like calling Reconfig + // with all pointers to zero values) and waits for it to be fully stopped, + // with no live peers or DERPs. + // + // Unlike Reconfig, it does not return ErrNoChanges. + ResetAndStop() (*Status, error) + // PeerForIP returns the node to which the provided IP routes, // if any. If none is found, (nil, false) is returned. PeerForIP(netip.Addr) (_ PeerForIP, ok bool) @@ -129,5 +136,5 @@ type Engine interface { // InstallCaptureHook registers a function to be called to capture // packets traversing the data path. The hook can be uninstalled by // calling this function with a nil value. - InstallCaptureHook(capture.Callback) + InstallCaptureHook(packet.CaptureCallback) } diff --git a/words/scales.txt b/words/scales.txt index f27dfc5c4..bb623fb6f 100644 --- a/words/scales.txt +++ b/words/scales.txt @@ -391,3 +391,63 @@ godzilla sirius vector cherimoya +shilling +kettle +kitchen +fahrenheit +rankine +piano +ruler +scoville +oratrice +teeth +cliff +degree +company +economy +court +justitia +themis +carat +carob +karat +barley +corn +penny +pound +mark +pence +mine +stairs +escalator +elevator +skilift +gondola +firefighter +newton +smoot +city +truck +everest +wall +fence +fort +trench +matrix +census +likert +sidemirror +wage +salary +fujita +caiman +cichlid +logarithm +exponential +geological +cosmological +barometric +ph +pain +temperature +wyrm diff --git a/words/tails.txt b/words/tails.txt index 497533241..b0119a756 100644 --- a/words/tails.txt +++ b/words/tails.txt @@ -694,3 +694,81 @@ ussuri kitty tanuki neko +wind +airplane +time +gumiho +eel +moray +twin +hair +braid +gate +end +queue +miku +at +fin +solarflare +asymptote +reverse +bone +stern +quaver +note +mining +coat +follow +stalk +caudal +chronicle +trout +sturgeon +swordfish +catfish +pike +angler +anchovy +angelfish +cod +icefish +carp +mackarel +salmon +grayling +lungfish +dragonfish +barracuda +barreleye +bass +ridgehead +bigscale +blowfish +bream +bullhead +pufferfish +sardine +sunfish +mullet +snapper +pipefish +seahorse +flounder +tilapia +dorado +shad +lionfish +crayfish +sailfish +billfish +taimen +sargo +story +tale +gecko +wyrm +meteor +ribbon +echo +lemming +worm
        KeyTokens