diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..03d5932c0 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,17 @@ +# This is the official list of Tailscale +# authors for copyright purposes. +# +# Names should be added to this file as one of +# Organization's name +# Individual's name +# Individual's name +# +# Please keep the list sorted. +# +# You do not need to add entries to this list, and we don't actively +# populate this list. If you do want to be acknowledged explicitly as +# a copyright holder, though, then please send a PR referencing your +# earlier contributions and clarifying whether it's you or your +# company that owns the rights to your contribution. + +Tailscale Inc. diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..eb7c3f6b2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2020 Tailscale & AUTHORS. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Tailscale Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/PATENTS b/PATENTS new file mode 100644 index 000000000..560a2b8f0 --- /dev/null +++ b/PATENTS @@ -0,0 +1,24 @@ +Additional IP Rights Grant (Patents) + +"This implementation" means the copyrightable works distributed by +Tailscale Inc. as part of the Tailscale project. + +Tailscale Inc. hereby grants to You a perpetual, worldwide, +non-exclusive, no-charge, royalty-free, irrevocable (except as stated +in this section) patent license to make, have made, use, offer to +sell, sell, import, transfer and otherwise run, modify and propagate +the contents of this implementation of Tailscale, where such license +applies only to those patent claims, both currently owned or +controlled by Tailscale Inc. and acquired in the future, licensable +by Tailscale Inc. that are necessarily infringed by this +implementation of Tailscale. This grant does not include claims that +would be infringed only as a consequence of further modification of +this implementation. 
If you or your agent or exclusive licensee +institute or order or agree to the institution of patent litigation +against any entity (including a cross-claim or counterclaim in a +lawsuit) alleging that this implementation of Tailscale or any code +incorporated within this implementation of Tailscale constitutes +direct or contributory patent infringement, or inducement of patent +infringement, then any patent rights granted to you under this License +for this implementation of Tailscale shall terminate as of the date +such litigation is filed. diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go new file mode 100644 index 000000000..dcf3c3235 --- /dev/null +++ b/atomicfile/atomicfile.go @@ -0,0 +1,28 @@ +// Copyright 2019 Tailscale & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package atomicfile contains code related to writing to filesystems +// atomically. +// +// This package should be considered internal; its API is not stable. +package atomicfile // import "tailscale.com/atomicfile" + +import ( + "fmt" + "io/ioutil" + "os" +) + +// WriteFile writes data to filename+some suffix, then renames it +// into filename. 
+func WriteFile(filename string, data []byte, perm os.FileMode) error { + tmpname := filename + ".new.tmp" + if err := ioutil.WriteFile(tmpname, data, perm); err != nil { + return fmt.Errorf("%#v: %v", tmpname, err) + } + if err := os.Rename(tmpname, filename); err != nil { + return fmt.Errorf("%#v->%#v: %v", tmpname, filename, err) + } + return nil +} diff --git a/cmd/relaynode/.gitignore b/cmd/relaynode/.gitignore new file mode 100644 index 000000000..d1e727f82 --- /dev/null +++ b/cmd/relaynode/.gitignore @@ -0,0 +1,14 @@ +/*.tar.gz +/*.deb +/*.rpm +/*.spec +/pkgver +debian/changelog +debian/debhelper-build-stamp +debian/files +debian/*.log +debian/*.substvars +debian/*.debhelper +debian/tailscale-relay +/tailscale-relay/ +/tailscale-relay-* diff --git a/cmd/relaynode/acl.json b/cmd/relaynode/acl.json new file mode 100644 index 000000000..29aff21df --- /dev/null +++ b/cmd/relaynode/acl.json @@ -0,0 +1,63 @@ +{ + // Declare static groups of users beyond those in the identity service + "Groups": { + "group:eng": ["u1@example.com", "u2@example.com"] + }, + + // Declare convenient hostname aliases to use in place of IP addresses + "Hosts": { + "h222": "100.2.2.2" + }, + + // Access control list + "ACLs": [ + { + "Action": "accept", + // Match any of several users + "Users": ["a@example.com", "b@example.com"], + // Match any port on h222, and port 22 of 10.1.2.3 + "Ports": ["h222:*", "10.1.2.3:22"] + }, + { + "Action": "accept", + // Match any user at all + "Users": ["*"], + // Match port 80 on one machine, ports 53 and 5353 on a second one, + // and ports 8000 through 8080 (a port range) on a third one. 
+ "Ports": ["h222:80", "10.8.8.8:53,5353", "10.2.3.4:8000-8080"] + }, + { + "Action": "accept", + // Match all users in the "Admin" role (network administrators) + "Users": ["role:Admin", "group:eng"], + // Allow access to port 22 on all servers + "Ports": ["*:22"] + }, + { + "Action": "accept", + "Users": ["role:User"], + // Match only windows and linux workstations (not implemented yet) + "OS": ["windows", "linux"], + // Only desktop machines are allowed to access this server + "Ports": ["10.1.1.1:443"] + }, + { + "Action": "accept", + "Users": ["*"], + // Match machines which have never been authorized, or which expired. + // (not implemented yet) + "MachineAuth": ["unauthorized", "expired"], + // Logged-in users on unauthorized machines can access the email server. + // Open the TLS ports for SMTP, IMAP, and HTTP. + "Ports": ["10.1.2.3:465", "10.1.2.3:993", "10.1.2.3:443"] + }, + + // Match absolutely everything. Comment out this section if you want + // the above ACLs to apply. + { "Action": "accept", "Users": ["*"], "Ports": ["*:*"] }, + + // Leave this line here so that every rule can end in a comma. + // It has no effect since it has no matching rules. 
+ {"Action": "accept"} + ] +} diff --git a/cmd/relaynode/clean.do b/cmd/relaynode/clean.do new file mode 100644 index 000000000..e2d5edc3c --- /dev/null +++ b/cmd/relaynode/clean.do @@ -0,0 +1 @@ +rm -f debian/changelog *~ debian/*~ diff --git a/cmd/relaynode/clean.od b/cmd/relaynode/clean.od new file mode 100644 index 000000000..2caeabc60 --- /dev/null +++ b/cmd/relaynode/clean.od @@ -0,0 +1,13 @@ +exec >&2 +read -r package &2 +dir=${1%/*} +redo-ifchange "$S/$dir/package" "$S/oss/version/short.txt" +read -r package <"$S/$dir/package" +read -r version <"$S/oss/version/short.txt" +arch=$(dpkg --print-architecture) + +redo-ifchange "$dir/${package}_$arch.deb" +rm -f "$dir/${package}"_*_"$arch.deb" +ln -sf "${package}_$arch.deb" "$dir/${package}_${version}_$arch.deb" diff --git a/cmd/relaynode/debian/README.Debian b/cmd/relaynode/debian/README.Debian new file mode 100644 index 000000000..f8ac2950c --- /dev/null +++ b/cmd/relaynode/debian/README.Debian @@ -0,0 +1 @@ +Tailscale IPN relay daemon. diff --git a/cmd/relaynode/debian/changelog.do b/cmd/relaynode/debian/changelog.do new file mode 100644 index 000000000..40d871b2d --- /dev/null +++ b/cmd/relaynode/debian/changelog.do @@ -0,0 +1,5 @@ +redo-ifchange ../../../version/short.txt gen-changelog +( + cd .. 
+ debian/gen-changelog +) >$3 diff --git a/cmd/relaynode/debian/clean b/cmd/relaynode/debian/clean new file mode 100644 index 000000000..e69de29bb diff --git a/cmd/relaynode/debian/compat b/cmd/relaynode/debian/compat new file mode 100644 index 000000000..ec635144f --- /dev/null +++ b/cmd/relaynode/debian/compat @@ -0,0 +1 @@ +9 diff --git a/cmd/relaynode/debian/control b/cmd/relaynode/debian/control new file mode 100644 index 000000000..8d7054f66 --- /dev/null +++ b/cmd/relaynode/debian/control @@ -0,0 +1,14 @@ +Source: tailscale-relay +Section: net +Priority: extra +Maintainer: Avery Pennarun +Build-Depends: debhelper (>= 10.2.5), dh-systemd (>= 1.5) +Standards-Version: 3.9.2 +Homepage: https://tailscale.com/ +Vcs-Git: https://github.com/tailscale/tailscale +Vcs-Browser: https://github.com/tailscale/tailscale + +Package: tailscale-relay +Architecture: any +Depends: ${shlibs:Depends}, ${misc:Depends} +Description: Traffic relay node for Tailscale IPN diff --git a/cmd/relaynode/debian/copyright b/cmd/relaynode/debian/copyright new file mode 100644 index 000000000..dfae96f4c --- /dev/null +++ b/cmd/relaynode/debian/copyright @@ -0,0 +1,11 @@ +Format: http://svn.debian.org/wsvn/dep/web/deps/dep5.mdwn?op=file&rev=173 +Upstream-Name: tailscale-relay +Upstream-Contact: Avery Pennarun +Source: https://github.com/tailscale/tailscale/ + +Files: * +Copyright: © 2019 Tailscale Inc. +License: Proprietary + * + * Copyright 2019 Tailscale Inc. All rights reserved. + * diff --git a/cmd/relaynode/debian/gen-changelog b/cmd/relaynode/debian/gen-changelog new file mode 100755 index 000000000..19a2795be --- /dev/null +++ b/cmd/relaynode/debian/gen-changelog @@ -0,0 +1,25 @@ +#!/bin/sh +read junk pkgname %aD +' . 
| +python -Sc ' +import os, re, subprocess, sys + +first = True +def Describe(g): + global first + if first: + s = sys.argv[1] + first = False + else: + sha = g.group(1) + s = subprocess.check_output(["git", "describe", "--", sha]).strip().decode("utf-8") + return re.sub(r"^\D*", "", s) + +print(re.sub(r"SHA:([0-9a-f]+)", Describe, sys.stdin.read())) +' "$shortver" diff --git a/cmd/relaynode/debian/install b/cmd/relaynode/debian/install new file mode 100644 index 000000000..d2b64573f --- /dev/null +++ b/cmd/relaynode/debian/install @@ -0,0 +1,4 @@ +relaynode /usr/sbin +tailscale-login /usr/sbin +taillogin /usr/sbin +acl.json /etc/tailscale diff --git a/cmd/relaynode/debian/postinst b/cmd/relaynode/debian/postinst new file mode 100644 index 000000000..2042f4b4e --- /dev/null +++ b/cmd/relaynode/debian/postinst @@ -0,0 +1,8 @@ +#DEBHELPER# + +f=/var/lib/tailscale/relay.conf +if ! [ -e "$f" ]; then + echo + echo "Note: Run tailscale-login to configure $f." >&2 + echo +fi diff --git a/cmd/relaynode/debian/rules b/cmd/relaynode/debian/rules new file mode 100755 index 000000000..cb8a4dfc0 --- /dev/null +++ b/cmd/relaynode/debian/rules @@ -0,0 +1,10 @@ +#!/usr/bin/make -f +DESTDIR=debian/tailscale-relay + +override_dh_auto_test: +override_dh_auto_install: + mkdir -p "${DESTDIR}/etc/default" + cp tailscale-relay.defaults "${DESTDIR}/etc/default/tailscale-relay" + +%: + dh $@ --with=systemd diff --git a/cmd/relaynode/debian/tailscale-relay.service b/cmd/relaynode/debian/tailscale-relay.service new file mode 100644 index 000000000..446295a30 --- /dev/null +++ b/cmd/relaynode/debian/tailscale-relay.service @@ -0,0 +1,12 @@ +[Unit] +Description=Traffic relay node for Tailscale IPN +After=network.target +ConditionPathExists=/var/lib/tailscale/relay.conf + +[Service] +EnvironmentFile=/etc/default/tailscale-relay +ExecStart=/usr/sbin/relaynode --config=/var/lib/tailscale/relay.conf --tun=wg0 $PORT $ACL_FILE $FLAGS +Restart=on-failure + +[Install] +WantedBy=multi-user.target diff 
--git a/cmd/relaynode/default.deb.od b/cmd/relaynode/default.deb.od new file mode 100644 index 000000000..09cd35502 --- /dev/null +++ b/cmd/relaynode/default.deb.od @@ -0,0 +1,20 @@ +exec >&2 +dir=${1%/*} +redo-ifchange "$S/oss/version/short.txt" "$S/$dir/package" "$dir/debtmp.dir" +read -r package <"$S/$dir/package" +read -r version <"$S/oss/version/short.txt" +arch=$(dpkg --print-architecture) + +( + cd "$S/$dir" + git ls-files debian | xargs redo-ifchange debian/changelog +) +cp -a "$S/$dir/debian" "$dir/debtmp/" +rm -f "$dir/debtmp/debian/$package.debhelper.log" +( + cd "$dir/debtmp" && + debian/rules build && + fakeroot debian/rules binary +) + +mv "$dir/${package}_${version}_${arch}.deb" "$3" diff --git a/cmd/relaynode/default.dir.od b/cmd/relaynode/default.dir.od new file mode 100644 index 000000000..90e963ff9 --- /dev/null +++ b/cmd/relaynode/default.dir.od @@ -0,0 +1,21 @@ +# Generate a directory tree suitable for forming a tarball of +# this package. +exec >&2 +dir=${1%/*} +outdir=$PWD/${1%.dir} +rm -rf "$outdir" +mkdir "$outdir" +touch $outdir/.stamp +sfiles=" + tailscale-login + acl.json + debian/*.service + *.defaults +" +ofiles=" + relaynode + ../taillogin/taillogin +" +redo-ifchange "$outdir/.stamp" +(cd "$S/$dir" && redo-ifchange $sfiles && cp $sfiles "$outdir/") +(cd "$dir" && redo-ifchange $ofiles && cp $ofiles "$outdir/") diff --git a/cmd/relaynode/default.rpm.od b/cmd/relaynode/default.rpm.od new file mode 100644 index 000000000..254ba1fcf --- /dev/null +++ b/cmd/relaynode/default.rpm.od @@ -0,0 +1,14 @@ +exec >&2 +dir=${1%/*} +pkg=${1##*/} +pkg=${pkg%.rpm} +redo-ifchange "$S/oss/version/short.txt" "$dir/$pkg.tar.gz" "$dir/$pkg.spec" +read -r pkgver junk <"$S/oss/version/short.txt" + +machine=$(uname -m) +rpmbase=$HOME/rpmbuild + +mkdir -p "$rpmbase/SOURCES/" +cp "$dir/$pkg.tar.gz" "$rpmbase/SOURCES/" +rpmbuild -bb "$dir/$pkg.spec" +mv "$rpmbase/RPMS/$machine/$pkg-$pkgver.$machine.rpm" $3 diff --git a/cmd/relaynode/default.spec.od 
b/cmd/relaynode/default.spec.od new file mode 100644 index 000000000..30f2be2fa --- /dev/null +++ b/cmd/relaynode/default.spec.od @@ -0,0 +1,7 @@ +redo-ifchange "$S/$1.in" "$S/oss/version/short.txt" +read -r pkgver junk <"$S/oss/version/short.txt" +basever=${pkgver%-*} +subver=${pkgver#*-} +sed -e "s/Version: 0.00$/Version: $basever/" \ + -e "s/Release: 0$/Release: $subver/" \ + <"$S/$1.in" >"$3" diff --git a/cmd/relaynode/default.tar.gz.od b/cmd/relaynode/default.tar.gz.od new file mode 100644 index 000000000..aabe90ef9 --- /dev/null +++ b/cmd/relaynode/default.tar.gz.od @@ -0,0 +1,8 @@ +exec >&2 +xdir=${1%.tar.gz} +base=${xdir##*/} +updir=${xdir%/*} +redo-ifchange "$xdir.dir" +OUT="$PWD/$3" + +cd "$updir" && tar -czvf "$OUT" --exclude "$base/.stamp" "$base" diff --git a/cmd/relaynode/dist.od b/cmd/relaynode/dist.od new file mode 100644 index 000000000..ef258d44a --- /dev/null +++ b/cmd/relaynode/dist.od @@ -0,0 +1,15 @@ +# Build packages for customer distribution. +dir=${1%/*} +cd "$dir" +targets="tarball" +if which dh_clean fakeroot dpkg >/dev/null; then + targets="$targets deb" +else + echo "Skipping debian packages: debhelper and/or dpkg build tools missing." >&2 +fi +if which rpm >/dev/null; then + targets="$targets rpm" +else + echo "Skipping rpm packages: rpm build tools missing." >&2 +fi +redo-ifchange $targets diff --git a/cmd/relaynode/docker/.gitignore b/cmd/relaynode/docker/.gitignore new file mode 100644 index 000000000..09a5bf85a --- /dev/null +++ b/cmd/relaynode/docker/.gitignore @@ -0,0 +1 @@ +/relaynode diff --git a/cmd/relaynode/docker/Dockerfile b/cmd/relaynode/docker/Dockerfile new file mode 100644 index 000000000..714abc216 --- /dev/null +++ b/cmd/relaynode/docker/Dockerfile @@ -0,0 +1,17 @@ +# Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +# Build with: docker build -t tailcontrol-alpine . 
+# Run with: docker run --cap-add=NET_ADMIN --device=/dev/net/tun:/dev/net/tun -it tailcontrol-alpine + +FROM debian:stretch-slim + +RUN apt-get update && apt-get -y install iproute2 iptables +RUN apt-get -y install ca-certificates +RUN apt-get -y install nginx-light + +COPY relaynode / + +# tailcontrol -tun=wg0 -dbdir=$HOME/taildb >> tailcontrol.log 2>&1 & +CMD ["/relaynode", "-R", "--config", "relay.conf"] diff --git a/cmd/relaynode/docker/all.do b/cmd/relaynode/docker/all.do new file mode 100644 index 000000000..941435df8 --- /dev/null +++ b/cmd/relaynode/docker/all.do @@ -0,0 +1 @@ +redo-ifchange build diff --git a/cmd/relaynode/docker/build.do b/cmd/relaynode/docker/build.do new file mode 100644 index 000000000..29cb0e6bc --- /dev/null +++ b/cmd/relaynode/docker/build.do @@ -0,0 +1,3 @@ +exec >&2 +redo-ifchange Dockerfile relaynode +docker build -t tailscale . diff --git a/cmd/relaynode/docker/relaynode.do b/cmd/relaynode/docker/relaynode.do new file mode 100644 index 000000000..11adac7f4 --- /dev/null +++ b/cmd/relaynode/docker/relaynode.do @@ -0,0 +1,2 @@ +redo-ifchange ../relaynode +cp ../relaynode $3 diff --git a/cmd/relaynode/docker/run.sh b/cmd/relaynode/docker/run.sh new file mode 100755 index 000000000..97a528fa2 --- /dev/null +++ b/cmd/relaynode/docker/run.sh @@ -0,0 +1,10 @@ +#!/bin/sh +# Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. 
+ +set -e +redo-ifchange build +docker run --cap-add=NET_ADMIN \ + --device=/dev/net/tun:/dev/net/tun \ + -it tailscale diff --git a/cmd/relaynode/package b/cmd/relaynode/package new file mode 100644 index 000000000..ecdb13319 --- /dev/null +++ b/cmd/relaynode/package @@ -0,0 +1 @@ +tailscale-relay diff --git a/cmd/relaynode/relaynode.go b/cmd/relaynode/relaynode.go new file mode 100644 index 000000000..2848ac171 --- /dev/null +++ b/cmd/relaynode/relaynode.go @@ -0,0 +1,300 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Relaynode is the old Linux Tailscale daemon. +// +// Deprecated: this program will be soon deleted. The replacement is +// cmd/tailscaled. +package main + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/pprof" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/apenwarr/fixconsole" + "github.com/google/go-cmp/cmp" + "github.com/klauspost/compress/zstd" + "github.com/pborman/getopt/v2" + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/atomicfile" + "tailscale.com/control/controlclient" + "tailscale.com/control/policy" + "tailscale.com/logpolicy" + "tailscale.com/version" + "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" +) + +func main() { + err := fixconsole.FixConsoleIfNeeded() + if err != nil { + log.Printf("fixConsoleOutput: %v\n", err) + } + config := getopt.StringLong("config", 'f', "", "path to config file") + server := getopt.StringLong("server", 's', "https://login.tailscale.com", "URL to tailcontrol server") + listenport := getopt.Uint16Long("port", 'p', magicsock.DefaultPort, "WireGuard port (0=autoselect)") + tunname := getopt.StringLong("tun", 0, "wg0", "tunnel interface name") + alwaysrefresh := getopt.BoolLong("always-refresh", 0, "force key refresh at startup") + fake := 
getopt.BoolLong("fake", 0, "fake tunnel+routing instead of tuntap") + nuroutes := getopt.BoolLong("no-single-routes", 'N', "disallow (non-subnet) routes to single nodes") + rroutes := getopt.BoolLong("remote-routes", 'R', "allow routing subnets to remote nodes") + droutes := getopt.BoolLong("default-routes", 'D', "allow default route on remote node") + routes := getopt.StringLong("routes", 0, "", "list of IP ranges this node can relay") + aclfile := getopt.StringLong("acl-file", 0, "", "restrict traffic relaying according to json ACL file") + derp := getopt.BoolLong("derp", 0, "enable bypass via Detour Encrypted Routing Protocol (DERP)", "false") + debug := getopt.StringLong("debug", 0, "", "Address of debug server") + getopt.Parse() + if len(getopt.Args()) > 0 { + log.Fatalf("too many non-flag arguments: %#v", getopt.Args()[0]) + } + uflags := controlclient.UFlagsHelper(!*nuroutes, *rroutes, *droutes) + if *config == "" { + log.Fatal("no --config file specified") + } + if *tunname == "" { + log.Printf("Warning: no --tun device specified; routing disabled.\n") + } + + pol := logpolicy.New("tailnode.log.tailscale.io", *config) + + logf := wgengine.RusagePrefixLog(log.Printf) + + // The wgengine takes a wireguard configuration produced by the + // controlclient, and runs the actual tunnels and packets. 
+ var e wgengine.Engine + if *fake { + e, err = wgengine.NewFakeUserspaceEngine(logf, *listenport, *derp) + } else { + e, err = wgengine.NewUserspaceEngine(logf, *tunname, *listenport, *derp) + } + if err != nil { + log.Fatalf("Error starting wireguard engine: %v\n", err) + } + + e = wgengine.NewWatchdog(e) + var lastacljson string + var p *policy.Policy + + if *aclfile == "" { + e.SetFilter(nil) + } else { + lastacljson = readOrFatal(*aclfile) + p = installFilterOrFatal(e, *aclfile, lastacljson, nil) + } + + var lastNetMap *controlclient.NetworkMap + var lastUserMap map[string][]filter.IP + statusFunc := func(new controlclient.Status) { + if new.URL != "" { + fmt.Fprintf(os.Stderr, "To authenticate, visit:\n\n\t%s\n\n", new.URL) + return + } + if new.Err != "" { + log.Print(new.Err) + return + } + if new.Persist != nil { + if err := saveConfig(*config, *new.Persist); err != nil { + log.Println(err) + } + } + + if m := new.NetMap; m != nil { + if lastNetMap != nil { + s1 := strings.Split(lastNetMap.Concise(), "\n") + s2 := strings.Split(new.NetMap.Concise(), "\n") + logf("netmap diff:\n%v\n", cmp.Diff(s1, s2)) + } + lastNetMap = m + + if m.Equal(&controlclient.NetworkMap{}) { + return + } + + wgcfg, err := m.WGCfg(uflags, m.DNS) + if err != nil { + log.Fatalf("Error getting wg config: %v\n", err) + } + err = e.Reconfig(wgcfg, m.DNSDomains) + if err != nil { + log.Fatalf("Error reconfiguring engine: %v\n", err) + } + lastUserMap = m.UserMap() + if p != nil { + matches, err := p.Expand(lastUserMap) + if err != nil { + log.Fatalf("Error expanding ACLs: %v\n", err) + } + e.SetFilter(filter.New(matches)) + } + } + } + + cfg, err := loadConfig(*config) + if err != nil { + log.Fatal(err) + } + + hi := controlclient.NewHostinfo() + hi.FrontendLogID = pol.PublicID.String() + hi.BackendLogID = pol.PublicID.String() + if *routes != "" { + for _, routeStr := range strings.Split(*routes, ",") { + cidr, err := wgcfg.ParseCIDR(routeStr) + if err != nil { + log.Fatalf("--routes: 
not an IP range: %s", routeStr) + } + hi.RoutableIPs = append(hi.RoutableIPs, *cidr) + } + } + + c, err := controlclient.New(controlclient.Options{ + Persist: cfg, + ServerURL: *server, + Hostinfo: &hi, + NewDecompressor: func() (controlclient.Decompressor, error) { + return zstd.NewReader(nil) + }, + KeepAlive: true, + }) + c.SetStatusFunc(statusFunc) + if err != nil { + log.Fatal(err) + } + lf := controlclient.LoginDefault + if *alwaysrefresh { + lf |= controlclient.LoginInteractive + } + c.Login(nil, lf) + + // Print the wireguard status when we get an update. + e.SetStatusCallback(func(s *wgengine.Status, err error) { + if err != nil { + log.Fatalf("Wireguard engine status error: %v\n", err) + } + var ss []string + for _, p := range s.Peers { + if p.LastHandshake.IsZero() { + ss = append(ss, "x") + } else { + ss = append(ss, fmt.Sprintf("%d/%d", p.RxBytes, p.TxBytes)) + } + } + logf("v%v peers: %v\n", version.LONG, strings.Join(ss, " ")) + c.UpdateEndpoints(0, s.LocalAddrs) + }) + + if *debug != "" { + go runDebugServer(*debug) + } + + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + signal.Notify(sigCh, syscall.SIGTERM) + + t := time.NewTicker(5 * time.Second) +loop: + for { + select { + case <-t.C: + // For the sake of curiosity, request a status + // update periodically. + e.RequestStatus() + + // check if aclfile has changed. + // TODO(apenwarr): use fsnotify instead of polling? + if *aclfile != "" { + json := readOrFatal(*aclfile) + if json != lastacljson { + logf("ACL file (%v) changed. 
Reloading filter.\n", *aclfile) + lastacljson = json + p = installFilterOrFatal(e, *aclfile, json, lastUserMap) + } + } + case <-sigCh: + logf("signal received, exiting") + t.Stop() + break loop + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + e.Close() + pol.Shutdown(ctx) +} + +func loadConfig(path string) (cfg controlclient.Persist, err error) { + b, err := ioutil.ReadFile(path) + if os.IsNotExist(err) { + log.Printf("config %s does not exist", path) + return controlclient.Persist{}, nil + } + if err := json.Unmarshal(b, &cfg); err != nil { + return controlclient.Persist{}, fmt.Errorf("load config: %v", err) + } + return cfg, nil +} + +func saveConfig(path string, cfg controlclient.Persist) error { + b, err := json.MarshalIndent(cfg, "", "\t") + if err != nil { + return fmt.Errorf("save config: %v", err) + } + if err := atomicfile.WriteFile(path, b, 0666); err != nil { + return fmt.Errorf("save config: %v", err) + } + return nil +} + +func readOrFatal(filename string) string { + b, err := ioutil.ReadFile(filename) + if err != nil { + log.Fatalf("%v: ReadFile: %v\n", filename, err) + } + return string(b) +} + +func installFilterOrFatal(e wgengine.Engine, filename, acljson string, usermap map[string][]filter.IP) *policy.Policy { + p, err := policy.Parse(acljson) + if err != nil { + log.Fatalf("%v: json filter: %v\n", filename, err) + } + + matches, err := p.Expand(usermap) + if err != nil { + log.Fatalf("%v: json filter: %v\n", filename, err) + } + + e.SetFilter(filter.New(matches)) + return p +} + +func runDebugServer(addr string) { + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + srv := http.Server{ + Addr: addr, + Handler: mux, + } + if err := srv.ListenAndServe(); 
err != nil { + log.Fatal(err) + } +} diff --git a/cmd/relaynode/rpm.od b/cmd/relaynode/rpm.od new file mode 100644 index 000000000..1594bb796 --- /dev/null +++ b/cmd/relaynode/rpm.od @@ -0,0 +1,9 @@ +exec >&2 +dir=${2%/*} +redo-ifchange "$S/$dir/package" "$S/oss/version/short.txt" +read -r package <"$S/$dir/package" +read -r pkgver <"$S/oss/version/short.txt" +machine=$(uname -m) +redo-ifchange "$dir/$package.rpm" +rm -f "$dir/${package}"-*."$machine.rpm" +ln -sf "$package.rpm" "$dir/$package-$pkgver.$machine.rpm" diff --git a/cmd/relaynode/tailscale-login b/cmd/relaynode/tailscale-login new file mode 100755 index 000000000..f4ec4e816 --- /dev/null +++ b/cmd/relaynode/tailscale-login @@ -0,0 +1,4 @@ +#!/bin/sh +cfg=/var/lib/tailscale/relay.conf +dir=$(dirname "$0") +"$dir/taillogin" --config="$cfg" diff --git a/cmd/relaynode/tailscale-relay.defaults b/cmd/relaynode/tailscale-relay.defaults new file mode 100644 index 000000000..077d112a2 --- /dev/null +++ b/cmd/relaynode/tailscale-relay.defaults @@ -0,0 +1,14 @@ +# Set the port to listen on for incoming VPN packets. +# Remote nodes will automatically be informed about the new port number, +# but you might want to configure this in order to set external firewall +# settings. +PORT="--port=41641" + +# Comment out this line to allow all traffic to be relayed. +# Or edit the given file to allow specific traffic. +# The example file is unlikely to match any users on your network, so it +# will block all incoming traffic by default. +ACL_FILE="--acl-file=/etc/tailscale/acl.json" + +# Extra flags you might want to pass to relaynode. 
+FLAGS="" diff --git a/cmd/relaynode/tailscale-relay.spec.in b/cmd/relaynode/tailscale-relay.spec.in new file mode 100644 index 000000000..351947e36 --- /dev/null +++ b/cmd/relaynode/tailscale-relay.spec.in @@ -0,0 +1,42 @@ +Name: tailscale-relay +Version: 0.00 +Release: 0 +Summary: Traffic relay node for Tailscale +Group: Network +License: Proprietary +URL: https://tailscale.com/ +Vendor: Tailscale Inc. +#Source: https://github.com/tailscale/tailscale +Source0: tailscale-relay.tar.gz +#Prefix: %{_prefix} +Packager: Avery Pennarun +BuildRoot: %{_tmppath}/%{name}-root + +%description +Traffic relay node for Tailscale. + +%prep +%setup -n tailscale-relay + +%build + +%install +D=$RPM_BUILD_ROOT +[ "$D" = "/" -o -z "$D" ] && exit 99 +rm -rf "$D" +mkdir -p $D/usr/sbin $D/lib/systemd/system $D/etc/default $D/etc/tailscale +cp taillogin tailscale-login relaynode $D/usr/sbin +cp tailscale-relay.service $D/lib/systemd/system/ +cp tailscale-relay.defaults $D/etc/default/tailscale-relay +cp acl.json $D/etc/tailscale/acl.json + +%clean + +%files +%defattr(-,root,root) +%config(noreplace) /etc/default/tailscale-relay +%config(noreplace) /etc/tailscale/acl.json +/lib/systemd/system/tailscale-relay.service +/usr/sbin/taillogin +/usr/sbin/tailscale-login +/usr/sbin/relaynode diff --git a/cmd/relaynode/tarball.od b/cmd/relaynode/tarball.od new file mode 100644 index 000000000..290fc5480 --- /dev/null +++ b/cmd/relaynode/tarball.od @@ -0,0 +1,7 @@ +dir=${1%/*} +redo-ifchange "$S/$dir/package" "$S/oss/version/short.txt" +read -r package <"$S/$dir/package" +read -r version <"$S/oss/version/short.txt" +redo-ifchange "$dir/$package.tar.gz" +rm -f "$dir/$package"-*.tar.gz +ln -sf "$package.tar.gz" "$dir/$package-$version.tar.gz" diff --git a/cmd/taillogin/taillogin.go b/cmd/taillogin/taillogin.go new file mode 100644 index 000000000..605cd3d01 --- /dev/null +++ b/cmd/taillogin/taillogin.go @@ -0,0 +1,96 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. 
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The taillogin command, invoked via the tailscale-login shell script, is shipped
+// with the current (old) Linux client, to log in to Tailscale on a Linux box.
+//
+// Deprecated: this will be deleted, to be replaced by cmd/tailscale.
+package main
+
+import (
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"os"
+
+	"github.com/pborman/getopt/v2"
+	"tailscale.com/atomicfile"
+	"tailscale.com/control/controlclient"
+	"tailscale.com/logpolicy"
+)
+
+// main parses the command line, loads the persisted login state from the
+// --config file, starts a control client against --server and asks it to
+// log in. It prints the authentication URL when one is reported, writes
+// updated persisted state back to the config file, and exits once the
+// first status carrying a network map arrives.
+func main() {
+	config := getopt.StringLong("config", 'f', "", "path to config file")
+	server := getopt.StringLong("server", 's', "https://login.tailscale.com", "URL to tailgate server")
+	getopt.Parse()
+	if len(getopt.Args()) > 0 {
+		log.Fatal("too many non-flag arguments")
+	}
+	if *config == "" {
+		log.Fatal("no --config file specified")
+	}
+	pol := logpolicy.New("tailnode.log.tailscale.io", *config)
+	defer pol.Close()
+
+	cfg, err := loadConfig(*config)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	hi := controlclient.NewHostinfo()
+	hi.FrontendLogID = pol.PublicID.String()
+	hi.BackendLogID = pol.PublicID.String()
+
+	// Buffered so that the status callback does not block on the first
+	// netmap if main has not reached the receive yet.
+	done := make(chan struct{}, 1)
+	c, err := controlclient.New(controlclient.Options{
+		Persist:   cfg,
+		ServerURL: *server,
+		Hostinfo:  &hi,
+	})
+	if err != nil {
+		// New returns a nil client on failure; using it would panic.
+		log.Fatal(err)
+	}
+	c.SetStatusFunc(func(new controlclient.Status) {
+		if new.URL != "" {
+			fmt.Fprintf(os.Stderr, "To authenticate, visit:\n\n\t%s\n\n", new.URL)
+			return
+		}
+		if new.Err != "" {
+			log.Print(new.Err)
+			return
+		}
+		if new.Persist != nil {
+			if err := saveConfig(*config, *new.Persist); err != nil {
+				log.Println(err)
+			}
+		}
+		if new.NetMap != nil {
+			done <- struct{}{}
+		}
+	})
+	c.Login(nil, 0)
+	<-done
+	log.Printf("Success.\n")
+}
+
+// loadConfig reads the JSON-encoded persisted state stored at path.
+//
+// A missing file is not an error: it is logged and a zero Persist is
+// returned. Any other read failure, or malformed JSON, is returned as a
+// "load config" error.
+func loadConfig(path string) (cfg controlclient.Persist, err error) {
+	b, err := ioutil.ReadFile(path)
+	if os.IsNotExist(err) {
+		log.Printf("config %s does not exist", path)
+		return controlclient.Persist{}, nil
+	}
+	if err != nil {
+		// Report the real I/O error (permissions, is-a-directory, ...)
+		// instead of letting json.Unmarshal fail on a nil buffer.
+		return controlclient.Persist{}, fmt.Errorf("load config: %v", err)
+	}
+	if err := json.Unmarshal(b, &cfg); err != nil {
+		return controlclient.Persist{}, fmt.Errorf("load config: %v", err)
+	}
+	return cfg, nil
+}
+
+// saveConfig writes cfg to path as tab-indented JSON using
+// atomicfile.WriteFile. Marshal and write failures are both returned as
+// "save config" errors.
+func saveConfig(path string, cfg controlclient.Persist) error {
+	b, err := json.MarshalIndent(cfg, "", "\t")
+	if err != nil {
+		return fmt.Errorf("save config: %v", err)
+	}
+	if err := atomicfile.WriteFile(path, b, 0666); err != nil {
+		return fmt.Errorf("save config: %v", err)
+	}
+	return nil
+}
diff --git a/cmd/tailscale/ipn.go b/cmd/tailscale/ipn.go
new file mode 100644
index 000000000..685401c51
--- /dev/null
+++ b/cmd/tailscale/ipn.go
@@ -0,0 +1,149 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The tailscale command is the Tailscale command-line client. It interacts
+// with the tailscaled client daemon.
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"io/ioutil"
+	"log"
+	"net"
+	"os"
+	"os/signal"
+	"syscall"
+
+	"github.com/apenwarr/fixconsole"
+	"github.com/pborman/getopt/v2"
+	"tailscale.com/atomicfile"
+	"tailscale.com/control/controlclient"
+	"tailscale.com/ipn"
+	"tailscale.com/logpolicy"
+	"tailscale.com/safesocket"
+)
+
+// pump reads messages from c with ipn.ReadMsg and hands each one to
+// bc.GotNotifyMsg until ctx is done or a read fails. It closes c and logs
+// a final message before returning.
+//
+// ctx is only checked between reads, so a cancellation takes effect after
+// the read in progress returns.
+func pump(ctx context.Context, bc *ipn.BackendClient, c net.Conn) {
+	defer log.Printf("Control connection done.\n")
+	defer c.Close()
+	for ctx.Err() == nil {
+		msg, err := ipn.ReadMsg(c)
+		if err != nil {
+			log.Printf("ReadMsg: %v\n", err)
+			break
+		}
+		bc.GotNotifyMsg(msg)
+	}
+}
+
+// main parses the command line, loads prefs from the --config file,
+// connects to the tailscaled socket and starts the backend with those
+// prefs. Notifications from the backend drive interactive login, print
+// the URLs the user must visit, and persist updated prefs. The process
+// keeps pumping messages until the context is cancelled (backend reached
+// Starting or Running) or the connection is closed (SIGINT/SIGTERM).
+func main() {
+	err := fixconsole.FixConsoleIfNeeded()
+	if err != nil {
+		log.Printf("fixConsoleOutput: %v\n", err)
+	}
+	config := getopt.StringLong("config", 'f', "", "path to config file")
+	server := getopt.StringLong("server", 's', "https://login.tailscale.com", "URL to tailcontrol server")
+	alwaysrefresh := getopt.BoolLong("always-refresh", 0, "force key refresh at startup")
+	nuroutes := getopt.BoolLong("no-single-routes", 'N', "disallow (non-subnet) routes to single nodes")
+	rroutes := getopt.BoolLong("remote-routes", 'R', "allow routing subnets to remote nodes")
+	droutes := getopt.BoolLong("default-routes", 'D', "allow default route on remote node")
+	getopt.Parse()
+	if *config == "" {
+		logpolicy.New("tailnode.log.tailscale.io", "tailscale")
+		log.Fatal("no --config file specified")
+	}
+	if len(getopt.Args()) > 0 {
+		log.Fatalf("too many non-flag arguments: %#v", getopt.Args()[0])
+	}
+
+	pol := logpolicy.New("tailnode.log.tailscale.io", *config)
+	defer pol.Close()
+
+	prefs, err := loadConfig(*config)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// TODO(apenwarr): fix different semantics between prefs and uflags
+	// TODO(apenwarr): allow setting/using CorpDNS
+	prefs.WantRunning = true
+	prefs.RouteAll = *rroutes || *droutes
+	prefs.AllowSingleHosts = !*nuroutes
+
+	c, err := safesocket.Connect("", "Tailscale", "tailscaled", 41112)
+	if err != nil {
+		log.Fatalf("safesocket.Connect: %v\n", err)
+	}
+	clientToServer := func(b []byte) {
+		ipn.WriteMsg(c, b)
+	}
+
+	ctx, cancel := context.WithCancel(context.Background())
+	lf := controlclient.LoginDefault
+	if *alwaysrefresh {
+		lf |= controlclient.LoginInteractive
+	}
+
+	// Closing the connection on SIGINT/SIGTERM makes the ReadMsg call in
+	// pump fail, which ends the program.
+	go func() {
+		interrupt := make(chan os.Signal, 1)
+		signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM)
+		<-interrupt
+		c.Close()
+	}()
+
+	bc := ipn.NewBackendClient(log.Printf, clientToServer)
+	opts := ipn.Options{
+		Prefs:      prefs,
+		ServerURL:  *server,
+		LoginFlags: lf,
+		Notify: func(n ipn.Notify) {
+			log.Printf("Notify: %v\n", n)
+			if n.ErrMessage != nil {
+				log.Fatalf("backend error: %v\n", *n.ErrMessage)
+			}
+			if s := n.State; s != nil {
+				switch *s {
+				case ipn.NeedsLogin:
+					bc.StartLoginInteractive()
+				case ipn.NeedsMachineAuth:
+					fmt.Fprintf(os.Stderr, "\nTo authorize your machine, visit (as admin):\n\n\t%s/admin/machines\n\n", *server)
+				case ipn.Starting, ipn.Running:
+					// Done full authentication process
+					cancel()
+				}
+			}
+			if url := n.BrowseToURL; url != nil {
+				fmt.Fprintf(os.Stderr, "\nTo authenticate, visit:\n\n\t%s\n\n", *url)
+			}
+			if p := n.Prefs; p != nil {
+				prefs = *p
+				// A failed save must not go unnoticed: the next run
+				// would start from stale prefs.
+				if err := saveConfig(*config, *p); err != nil {
+					log.Println(err)
+				}
+			}
+		},
+	}
+	bc.Start(opts)
+	pump(ctx, bc, c)
+}
+
+// loadConfig reads the prefs stored at path and decodes them with
+// ipn.PrefsFromBytes.
+//
+// A missing file is not an error: it is logged and ipn.NewPrefs() is
+// returned. Any other read failure is returned as a "load config" error.
+func loadConfig(path string) (ipn.Prefs, error) {
+	b, err := ioutil.ReadFile(path)
+	if os.IsNotExist(err) {
+		log.Printf("config %s does not exist", path)
+		return ipn.NewPrefs(), nil
+	}
+	if err != nil {
+		// Report the real I/O error instead of a confusing decode
+		// failure on a nil buffer.
+		return ipn.Prefs{}, fmt.Errorf("load config: %v", err)
+	}
+	return ipn.PrefsFromBytes(b, false)
+}
+
+// saveConfig writes prefs to path as tab-indented JSON using
+// atomicfile.WriteFile. Marshal and write failures are both returned as
+// "save config" errors.
+func saveConfig(path string, prefs ipn.Prefs) error {
+	b, err := json.MarshalIndent(prefs, "", "\t")
+	if err != nil {
+		return fmt.Errorf("save config: %v", err)
+	}
+	if err := atomicfile.WriteFile(path, b, 0666); err != nil {
+		return fmt.Errorf("save config: %v", err)
+	}
+	return nil
+}
diff --git a/cmd/tailscaled/ipnd.go b/cmd/tailscaled/ipnd.go
new file mode 100644
index 000000000..6f0a470ad
--- /dev/null
+++ b/cmd/tailscaled/ipnd.go
@@ -0,0 +1,88 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The tailscaled program is the Tailscale client daemon. It's configured
+// and controlled via the tailscale CLI program.
+//
+// It primarily supports Linux, though other systems will likely be
+// supported in the future.
+package main
+
+import (
+	"context"
+	"log"
+	"net/http"
+	"net/http/pprof"
+
+	"github.com/apenwarr/fixconsole"
+	"github.com/pborman/getopt/v2"
+	"tailscale.com/ipn/ipnserver"
+	"tailscale.com/logpolicy"
+	"tailscale.com/wgengine"
+)
+
+// main is the tailscaled entry point. It sets up logging, parses flags,
+// optionally starts the pprof debug server, creates the wireguard engine
+// (a fake one with --fake, otherwise a userspace engine on "ts0"), wraps
+// it in a watchdog and runs the IPN server until it returns. On a clean
+// return it shuts the log policy down.
+func main() {
+	fake := getopt.BoolLong("fake", 0, "fake tunnel+routing instead of tuntap")
+	debug := getopt.StringLong("debug", 0, "", "Address of debug server")
+
+	logf := wgengine.RusagePrefixLog(log.Printf)
+
+	err := fixconsole.FixConsoleIfNeeded()
+	if err != nil {
+		logf("fixConsoleOutput: %v\n", err)
+	}
+	pol := logpolicy.New("tailnode.log.tailscale.io", "tailscaled")
+
+	getopt.Parse()
+	if len(getopt.Args()) > 0 {
+		log.Fatalf("too many non-flag arguments: %#v", getopt.Args()[0])
+	}
+
+	// The debug server runs for the life of the process; it is never
+	// stopped explicitly.
+	if *debug != "" {
+		go runDebugServer(*debug)
+	}
+
+	var e wgengine.Engine
+	if *fake {
+		e, err = wgengine.NewFakeUserspaceEngine(logf, 0, false)
+	} else {
+		e, err = wgengine.NewUserspaceEngine(logf, "ts0", 0, false)
+	}
+	if err != nil {
+		log.Fatalf("wgengine.New: %v\n", err)
+	}
+	e = wgengine.NewWatchdog(e)
+
+	opts := ipnserver.Options{
+		SurviveDisconnects: true,
+		AllowQuit:          false,
+	}
+	err = ipnserver.Run(context.Background(), logf, pol.PublicID.String(), opts, e)
+	if err != nil {
+		log.Fatalf("tailscaled: %v\n", err)
+	}
+
+	// TODO(crawshaw): It would be nice to start a timeout context the moment a signal
+	// is received and use that timeout to give us a moment to finish uploading logs
+	// here. But the signal is handled inside ipnserver.Run, so some plumbing is needed.
+	// Until then, Shutdown is handed a context that is already cancelled.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	pol.Shutdown(ctx)
+}
+
+// runDebugServer serves the net/http/pprof handlers under /debug/pprof/
+// on addr. It blocks in ListenAndServe and terminates the process with
+// log.Fatal if the server stops with an error, so it is meant to be run
+// in its own goroutine.
+func runDebugServer(addr string) {
+	mux := http.NewServeMux()
+	mux.HandleFunc("/debug/pprof/", pprof.Index)
+	mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
+	mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
+	mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
+	mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
+	srv := http.Server{
+		Addr:    addr,
+		Handler: mux,
+	}
+	if err := srv.ListenAndServe(); err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go
new file mode 100644
index 000000000..67f187f59
--- /dev/null
+++ b/control/controlclient/auto.go
@@ -0,0 +1,594 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package controlclient implements the client for the IPN control plane.
+//
+// It handles authentication, port picking, and collects the local
+// network configuration.
+package controlclient
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"reflect"
+	"sync"
+	"time"
+
+	"golang.org/x/oauth2"
+	"tailscale.com/logger"
+	"tailscale.com/logtail/backoff"
+	"tailscale.com/tailcfg"
+)
+
+// TODO(apenwarr): eliminate the 'state' variable, as it's now obsolete.
+// It's used only by the unit tests.
+type state int
+
+// The values a state can take.
+const (
+	stateNew state = iota
+	stateNotAuthenticated
+	stateAuthenticating
+	stateURLVisitRequired
+	stateAuthenticated
+	stateSynchronized // connected and received map update
+)
+
+// MarshalText encodes s as its String form. It never fails.
+func (s state) MarshalText() ([]byte, error) {
+	text := s.String()
+	return []byte(text), nil
+}
+
+// String returns the "state:<name>" label for s, or
+// "state:unknown:<n>" for a value outside the declared constants.
+func (s state) String() string {
+	names := [...]string{
+		stateNew:              "state:new",
+		stateNotAuthenticated: "state:not-authenticated",
+		stateAuthenticating:   "state:authenticating",
+		stateURLVisitRequired: "state:url-visit-required",
+		stateAuthenticated:    "state:authenticated",
+		stateSynchronized:     "state:synchronized",
+	}
+	if s < 0 || int(s) >= len(names) {
+		return fmt.Sprintf("state:unknown:%d", int(s))
+	}
+	return names[s]
+}
+
+// Status is the snapshot handed to the callback registered with
+// Client.SetStatusFunc.
+type Status struct {
+	LoginFinished *struct{}        // non-nil when the status was produced in stateAuthenticated
+	Err           string           // error text, empty if there was no error
+	URL           string           // auth URL reported with this status, if any
+	Persist       *Persist         // locally persisted configuration
+	NetMap        *NetworkMap      // server-pushed configuration
+	Hostinfo      tailcfg.Hostinfo // current Hostinfo data
+	state         state
+}
+
+// Equal reports whether s and s2 are equal.
+//
+// Two nil pointers are equal; a nil and a non-nil pointer are not.
+// LoginFinished is compared only by whether it is set.
+func (s *Status) Equal(s2 *Status) bool {
+	if s == nil || s2 == nil {
+		return s == s2
+	}
+	if (s.LoginFinished == nil) != (s2.LoginFinished == nil) {
+		return false
+	}
+	if s.Err != s2.Err || s.URL != s2.URL {
+		return false
+	}
+	if !reflect.DeepEqual(s.Persist, s2.Persist) {
+		return false
+	}
+	if !reflect.DeepEqual(s.NetMap, s2.NetMap) {
+		return false
+	}
+	if !reflect.DeepEqual(s.Hostinfo, s2.Hostinfo) {
+		return false
+	}
+	return s.state == s2.state
+}
+
+// String returns the state label followed by a space and the
+// tab-indented JSON encoding of s. It panics if s cannot be marshaled.
+func (s Status) String() string {
+	encoded, err := json.MarshalIndent(s, "", "\t")
+	if err != nil {
+		panic(err)
+	}
+	return fmt.Sprintf("%v %s", s.state, encoded)
+}
+
+// LoginGoal describes the login or logout the Client is currently
+// trying to achieve.
+type LoginGoal struct {
+	wantLoggedIn bool          // true if we *want* to be logged in
+	token        *oauth2.Token // oauth token to use when logging in
+	flags        LoginFlags    // flags to use when logging in
+	url          string        // auth url that needs to be visited
+}
+
+// Client connects to a tailcontrol server for a node.
+type Client struct {
+	direct   *Direct          // our interface to the server APIs
+	timeNow  func() time.Time // clock used for key expiry checks; never nil
+	logf     logger.Logf      // log sink; never nil
+	expiry   *time.Time       // key expiry taken from the last netmap; accessed under mu
+	closed   bool             // set by Shutdown; accessed under mu
+	newMapCh chan struct{}    // readable when we must restart a map request
+
+	mu         sync.Mutex   // mutex guards the following fields
+	statusFunc func(Status) // called to update Client status
+
+	loggedIn     bool       // true if currently logged in
+	loginGoal    *LoginGoal // non-nil if some login activity is desired
+	synced       bool       // true if our netmap is up-to-date
+	hostinfo     tailcfg.Hostinfo
+	inPollNetMap bool // true if currently running a PollNetMap
+	inSendStatus int  // number of sendStatus calls currently in progress
+	state        state
+
+	authCtx    context.Context // context used for auth requests
+	mapCtx     context.Context // context used for netmap requests
+	authCancel func()          // cancel the auth context
+	mapCancel  func()          // cancel the netmap context
+	quit       chan struct{}   // when closed, goroutines should all exit
+	authDone   chan struct{}   // when closed, auth goroutine is done
+	mapDone    chan struct{}   // when closed, map goroutine is done
+}
+
+// New creates and starts a new Client.
+//
+// It returns whatever NewNoStart returns; the client's goroutines are
+// started only when a non-nil Client was created.
+func New(opts Options) (*Client, error) {
+	c, err := NewNoStart(opts)
+	if c != nil {
+		c.Start()
+	}
+	return c, err
+}
+
+// NewNoStart creates a new Client, but without calling Start on it.
+//
+// It returns the error from NewDirect if the underlying Direct cannot be
+// created. A nil opts.TimeNow defaults to time.Now and a nil opts.Logf
+// to a logger that discards its input, so the client's goroutines never
+// call a nil function.
+func NewNoStart(opts Options) (*Client, error) {
+	direct, err := NewDirect(opts)
+	if err != nil {
+		return nil, err
+	}
+
+	// opts is a by-value copy, so any defaulting NewDirect performs on
+	// its own copy is not visible here. The defaults are applied only to
+	// the Client's fields so that what NewDirect sees is unchanged.
+	timeNow := opts.TimeNow
+	if timeNow == nil {
+		timeNow = time.Now
+	}
+	logf := opts.Logf
+	if logf == nil {
+		logf = func(format string, args ...interface{}) {}
+	}
+
+	c := &Client{
+		direct:   direct,
+		timeNow:  timeNow,
+		logf:     logf,
+		newMapCh: make(chan struct{}, 1),
+		quit:     make(chan struct{}),
+		authDone: make(chan struct{}),
+		mapDone:  make(chan struct{}),
+	}
+	c.authCtx, c.authCancel = context.WithCancel(context.Background())
+	c.mapCtx, c.mapCancel = context.WithCancel(context.Background())
+	return c, nil
+}
+
+// Start starts the client's goroutines.
+//
+// It should only be called for clients created by NewNoStart.
+func (c *Client) Start() {
+	go c.authRoutine()
+	go c.mapRoutine()
+}
+
+// cancelAuth cancels the current auth context and, unless the client has
+// been shut down, installs a fresh one. This interrupts whatever
+// authRoutine is waiting on so that it re-reads the login goal.
+func (c *Client) cancelAuth() {
+	c.mu.Lock()
+	if c.authCancel != nil {
+		c.authCancel()
+	}
+	if !c.closed {
+		c.authCtx, c.authCancel = context.WithCancel(context.Background())
+	}
+	c.mu.Unlock()
+}
+
+// cancelMapLocked cancels the current netmap context and, unless the
+// client has been shut down, installs a fresh one.
+//
+// c.mu must be held by the caller.
+func (c *Client) cancelMapLocked() {
+	if c.mapCancel != nil {
+		c.mapCancel()
+	}
+	if !c.closed {
+		c.mapCtx, c.mapCancel = context.WithCancel(context.Background())
+	}
+}
+
+// cancelMapUnsafely cancels the netmap context immediately, regardless
+// of whether a map request is still waiting for its first response.
+// See cancelMapSafely for why that can be a problem.
+func (c *Client) cancelMapUnsafely() {
+	c.mu.Lock()
+	c.cancelMapLocked()
+	c.mu.Unlock()
+}
+
+// cancelMapSafely asks mapRoutine to restart its map request. If a
+// netmap has already been received from the current request, the request
+// is cancelled right away. Otherwise a restart is queued on newMapCh and
+// takes effect once the outstanding request has produced a response.
+func (c *Client) cancelMapSafely() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	c.logf("cancelMapSafely: synced=%v\n", c.synced)
+
+	if c.inPollNetMap {
+		// received at least one netmap since the last
+		// interruption. That means the server has already
+		// fully processed our last request, which might
+		// include UpdateEndpoints(). Interrupt it and try
+		// again.
+		c.cancelMapLocked()
+	} else {
+		// !synced means we either haven't done a netmap
+		// request yet, or it hasn't answered yet. So the
+		// server is in an undefined state. If we send
+		// another netmap request too soon, it might race
+		// with the last one, and if we're very unlucky,
+		// the new request will be applied before the old one,
+		// and the wrong endpoints will get registered. We
+		// have to tell the client to abort politely, only
+		// after it receives a response to its existing netmap
+		// request.
+		select {
+		case c.newMapCh <- struct{}{}:
+			c.logf("cancelMapSafely: wrote to channel\n")
+		default:
+			// if channel write failed, then there was already
+			// an outstanding newMapCh request. One is enough,
+			// since it'll always use the latest endpoints.
+			c.logf("cancelMapSafely: channel was full\n")
+		}
+	}
+}
+
+// authRoutine is the goroutine that drives login and logout. On each
+// iteration it snapshots the login goal and auth context under c.mu and
+// then either waits (no goal), logs out, or logs in, reporting progress
+// through sendStatus. It exits when c.quit is closed and closes
+// c.authDone on the way out.
+func (c *Client) authRoutine() {
+	defer close(c.authDone)
+	bo := backoff.Backoff{Name: "authRoutine"}
+
+	for {
+		c.mu.Lock()
+		c.logf("authRoutine: %s\n", c.state)
+		expiry := c.expiry
+		goal := c.loginGoal
+		ctx := c.authCtx
+		synced := c.synced
+		c.mu.Unlock()
+
+		select {
+		case <-c.quit:
+			c.logf("authRoutine: quit\n")
+			return
+		default:
+		}
+
+		report := func(err error, msg string) {
+			c.logf("%s: %v\n", msg, err)
+			err = fmt.Errorf("%s: %v", msg, err)
+			// don't send status updates for context errors,
+			// since context cancelation is always on purpose.
+			if ctx.Err() == nil {
+				c.sendStatus("authRoutine1", err, "", nil)
+			}
+		}
+
+		if goal == nil {
+			// Wait for something interesting to happen
+			var exp <-chan time.Time
+			if expiry != nil && !expiry.IsZero() {
+				// if expiry is in the future, don't delay
+				// past that time.
+				// If it's in the past, then it's already
+				// being handled by someone, so no need to
+				// wake ourselves up again.
+				//
+				// NOTE(review): Before(now) is true only when
+				// expiry is already in the past, which makes
+				// delay negative; this does not match the
+				// comment above. Confirm the intended condition
+				// before changing it.
+				now := c.timeNow()
+				if expiry.Before(now) {
+					delay := expiry.Sub(now)
+					if delay > 5*time.Second {
+						delay = time.Second
+					}
+					exp = time.After(delay)
+				}
+			}
+			select {
+			case <-ctx.Done():
+				c.logf("authRoutine: context done.\n")
+			case <-exp:
+				// Unfortunately the key expiry isn't provided
+				// by the control server until mapRequest.
+				// So we have to do some hackery with c.expiry
+				// in here.
+				// TODO(apenwarr): add a key expiry field in RegisterResponse.
+				c.logf("authRoutine: key expiration check.\n")
+				if synced && expiry != nil && !expiry.IsZero() && expiry.Before(c.timeNow()) {
+					c.logf("Key expired; setting loggedIn=false.")
+
+					c.mu.Lock()
+					c.loginGoal = &LoginGoal{
+						wantLoggedIn: c.loggedIn,
+					}
+					c.loggedIn = false
+					c.expiry = nil
+					c.mu.Unlock()
+				}
+			}
+		} else if !goal.wantLoggedIn {
+			// Use the context snapshotted under c.mu above, like
+			// the login path does; reading c.authCtx here would
+			// race with cancelAuth replacing it.
+			err := c.direct.TryLogout(ctx)
+			if err != nil {
+				report(err, "TryLogout")
+				bo.BackOff(ctx, err)
+				continue
+			}
+
+			// success
+			c.mu.Lock()
+			c.loggedIn = false
+			c.loginGoal = nil
+			c.state = stateNotAuthenticated
+			c.synced = false
+			c.mu.Unlock()
+
+			c.sendStatus("authRoutine2", nil, "", nil)
+			bo.BackOff(ctx, nil)
+		} else { // ie. goal.wantLoggedIn
+			c.mu.Lock()
+			if goal.url != "" {
+				c.state = stateURLVisitRequired
+			} else {
+				c.state = stateAuthenticating
+			}
+			c.mu.Unlock()
+
+			var url string
+			var err error
+			var f string
+			if goal.url != "" {
+				url, err = c.direct.WaitLoginURL(ctx, goal.url)
+				f = "WaitLoginURL"
+			} else {
+				url, err = c.direct.TryLogin(ctx, goal.token, goal.flags)
+				f = "TryLogin"
+			}
+			if err != nil {
+				report(err, f)
+				bo.BackOff(ctx, err)
+				continue
+			} else if url != "" {
+				// The server wants the user to visit url; keep
+				// the goal, now waiting on that URL.
+				if goal.url != "" {
+					err = fmt.Errorf("weird: server required a new url?")
+					report(err, "WaitLoginURL")
+				}
+				goal.url = url
+				goal.token = nil
+				goal.flags = LoginDefault
+
+				c.mu.Lock()
+				c.loginGoal = goal
+				c.state = stateURLVisitRequired
+				c.synced = false
+				c.mu.Unlock()
+
+				c.sendStatus("authRoutine3", err, url, nil)
+				bo.BackOff(ctx, err)
+				continue
+			}
+
+			// success
+			c.mu.Lock()
+			c.loggedIn = true
+			c.loginGoal = nil
+			c.state = stateAuthenticated
+			c.mu.Unlock()
+
+			c.sendStatus("authRoutine4", nil, "", nil)
+			c.cancelMapSafely()
+			bo.BackOff(ctx, nil)
+		}
+	}
+}
+
+// mapRoutine is the goroutine that keeps the network map up to date.
+// While logged out it idles until the map context is cancelled or a new
+// map is requested. While logged in it runs PollNetMap and forwards each
+// received netmap through sendStatus. It exits when c.quit is closed and
+// closes c.mapDone on the way out.
+func (c *Client) mapRoutine() {
+	defer close(c.mapDone)
+	bo := backoff.Backoff{Name: "mapRoutine"}
+
+	for {
+		c.mu.Lock()
+		c.logf("mapRoutine: %s\n", c.state)
+		loggedIn := c.loggedIn
+		ctx := c.mapCtx
+		c.mu.Unlock()
+
+		select {
+		case <-c.quit:
+			c.logf("mapRoutine: quit\n")
+			return
+		default:
+		}
+
+		report := func(err error, msg string) {
+			c.logf("%s: %v\n", msg, err)
+			err = fmt.Errorf("%s: %v", msg, err)
+			// don't send status updates for context errors,
+			// since context cancelation is always on purpose.
+			if ctx.Err() == nil {
+				c.sendStatus("mapRoutine1", err, "", nil)
+			}
+		}
+
+		if !loggedIn {
+			// Wait for something interesting to happen
+			c.mu.Lock()
+			c.synced = false
+			// c.state is set by authRoutine()
+			c.mu.Unlock()
+
+			select {
+			case <-ctx.Done():
+				c.logf("mapRoutine: context done.\n")
+			case <-c.newMapCh:
+				c.logf("mapRoutine: new map needed while idle.\n")
+			}
+		} else {
+			// Be sure this is false when we're not inside
+			// PollNetMap, so that cancelMapSafely() can notify
+			// us correctly.
+			c.mu.Lock()
+			c.inPollNetMap = false
+			c.mu.Unlock()
+
+			err := c.direct.PollNetMap(ctx, -1, func(nm *NetworkMap) {
+				c.mu.Lock()
+
+				select {
+				case <-c.newMapCh:
+					c.logf("mapRoutine: new map request during PollNetMap. canceling.\n")
+					c.cancelMapLocked()
+
+					// Don't emit this netmap; we're
+					// about to request a fresh one.
+					c.mu.Unlock()
+					return
+				default:
+				}
+
+				c.synced = true
+				c.inPollNetMap = true
+				if c.loggedIn {
+					c.state = stateSynchronized
+				}
+				exp := nm.Expiry
+				c.expiry = &exp
+				stillAuthed := c.loggedIn
+				state := c.state
+
+				c.mu.Unlock()
+
+				c.logf("mapRoutine: netmap received: %s\n", state)
+				if stillAuthed {
+					c.sendStatus("mapRoutine2", nil, "", nm)
+				}
+			})
+
+			// PollNetMap returned: we are no longer synchronized.
+			c.mu.Lock()
+			c.synced = false
+			c.inPollNetMap = false
+			if c.state == stateSynchronized {
+				c.state = stateAuthenticated
+			}
+			c.mu.Unlock()
+
+			if err != nil {
+				report(err, "PollNetMap")
+				bo.BackOff(ctx, err)
+				continue
+			}
+			bo.BackOff(ctx, nil)
+		}
+	}
+}
+
+// AuthCantContinue reports whether the client is logged out and is
+// either not trying to log in at all or is waiting for an auth URL to be
+// visited.
+func (c *Client) AuthCantContinue() bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	return !c.loggedIn && (c.loginGoal == nil || c.loginGoal.url != "")
+}
+
+// SetStatusFunc registers fn as the callback that receives every Status
+// produced by the client. It replaces any previously registered
+// callback.
+func (c *Client) SetStatusFunc(fn func(Status)) {
+	c.mu.Lock()
+	c.statusFunc = fn
+	c.mu.Unlock()
+}
+
+// SetHostinfo passes hi to the underlying Direct and restarts the map
+// request so the new value is sent to the server.
+func (c *Client) SetHostinfo(hi tailcfg.Hostinfo) {
+	c.direct.SetHostinfo(hi)
+	// Send new Hostinfo to server
+	c.cancelMapSafely()
+}
+
+// sendStatus builds a Status from the client's current state and passes
+// it to the registered status callback, if any. who identifies the call
+// site in the log. err and url are copied into the Status. nm and the
+// persisted state are included only when nm is non-nil and the client is
+// both logged in and synced.
+//
+// The callback is invoked without c.mu held.
+func (c *Client) sendStatus(who string, err error, url string, nm *NetworkMap) {
+	c.mu.Lock()
+	state := c.state
+	loggedIn := c.loggedIn
+	synced := c.synced
+	statusFunc := c.statusFunc
+	// NOTE(review): no assignment to c.hostinfo is visible in this part
+	// of the file, so this may always be the zero value; confirm.
+	hi := c.hostinfo
+	c.inSendStatus++
+	c.mu.Unlock()
+
+	c.logf("sendStatus: %s: %v\n", who, state)
+
+	var p *Persist
+	var fin *struct{}
+	if state == stateAuthenticated {
+		fin = &struct{}{}
+	}
+	if nm != nil && loggedIn && synced {
+		pp := c.direct.GetPersist()
+		p = &pp
+	} else {
+		// don't send netmap status, as it's misleading when we're
+		// not logged in.
+		nm = nil
+	}
+	new := Status{
+		LoginFinished: fin,
+		URL:           url,
+		Persist:       p,
+		NetMap:        nm,
+		Hostinfo:      hi,
+		state:         state,
+	}
+	if err != nil {
+		new.Err = err.Error()
+	}
+	if statusFunc != nil {
+		statusFunc(new)
+	}
+
+	c.mu.Lock()
+	c.inSendStatus--
+	c.mu.Unlock()
+}
+
+// Login sets the goal of being logged in, using token t (which may be
+// nil) and the given flags, and interrupts authRoutine so that it acts
+// on the new goal.
+func (c *Client) Login(t *oauth2.Token, flags LoginFlags) {
+	c.logf("client.Login(%v, %v)\n", t != nil, flags)
+
+	c.mu.Lock()
+	c.loginGoal = &LoginGoal{
+		wantLoggedIn: true,
+		token:        t,
+		flags:        flags,
+	}
+	c.mu.Unlock()
+
+	c.cancelAuth()
+}
+
+// Logout sets the goal of being logged out and interrupts authRoutine so
+// that it acts on the new goal.
+func (c *Client) Logout() {
+	c.logf("client.Logout()\n")
+
+	c.mu.Lock()
+	c.loginGoal = &LoginGoal{
+		wantLoggedIn: false,
+	}
+	c.mu.Unlock()
+
+	c.cancelAuth()
+}
+
+// UpdateEndpoints records the local port and endpoints with the
+// underlying Direct. If they changed, the map request is restarted so
+// the server learns about them; if recording them failed, the error is
+// reported through the status callback.
+func (c *Client) UpdateEndpoints(localPort uint16, endpoints []string) {
+	changed, err := c.direct.SetEndpoints(localPort, endpoints)
+	if err != nil {
+		c.sendStatus("updateEndpoints", err, "", nil)
+	} else if changed {
+		c.cancelMapSafely()
+	}
+}
+
+// Shutdown stops the client. The first call clears the status callback,
+// closes c.quit, cancels the auth and map contexts and waits for both
+// goroutines to finish. Later calls only log and return.
+func (c *Client) Shutdown() {
+	c.logf("client.Shutdown()\n")
+
+	c.mu.Lock()
+	inSendStatus := c.inSendStatus
+	closed := c.closed
+	if !closed {
+		c.closed = true
+		c.statusFunc = nil
+	}
+	c.mu.Unlock()
+
+	c.logf("client.Shutdown: inSendStatus=%v\n", inSendStatus)
+	if !closed {
+		close(c.quit)
+		c.cancelAuth()
+		<-c.authDone
+		c.cancelMapUnsafely()
+		<-c.mapDone
+		c.logf("Client.Shutdown done.\n")
+	}
+}
diff --git a/control/controlclient/auto_test.go b/control/controlclient/auto_test.go
new file mode 100644
index 000000000..eab5d7d76
--- /dev/null
+++ b/control/controlclient/auto_test.go
@@ -0,0 +1,1107 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +// +build depends_on_currently_unreleased + +package controlclient + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "reflect" + "runtime/pprof" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/tailcfg" + "tailscale.com/testy" + "tailscale.io/control" // not yet released +) + +func TestTest(t *testing.T) { + check := testy.NewResourceCheck() + defer check.Assert(t) +} + +func TestServerStartStop(t *testing.T) { + s := newServer(t) + defer s.close() +} + +func TestControlBasics(t *testing.T) { + s := newServer(t) + defer s.close() + + c := s.newClient(t, "c") + c.Login(nil, 0) + status := c.waitStatus(t, stateURLVisitRequired) + c.postAuthURL(t, "foo@tailscale.com", status.New) +} + +func TestControl(t *testing.T) { + log.SetFlags(log.Ltime | log.Lshortfile) + s := newServer(t) + defer s.close() + + c1 := s.newClient(t, "c1") + + t.Run("authorize first tailscale.com client", func(t *testing.T) { + const loginName = "testuser1@tailscale.com" + c1.checkNoStatus(t) + c1.loginAs(t, loginName) + c1.waitStatus(t, stateAuthenticated) + status := c1.waitStatus(t, stateSynchronized) + if got, want := status.New.NetMap.MachineStatus, tailcfg.MachineUnauthorized; got != want { + t.Errorf("MachineStatus=%v, want %v", got, want) + } + c1.checkNoStatus(t) + affectedPeers, err := s.control.AuthorizeMachine(c1.mkey, c1.nkey) + if err != nil { + t.Fatal(err) + } + status = c1.status(t) + if got := status.New.Persist.LoginName; got != loginName { + t.Errorf("LoginName=%q, want %q", got, loginName) + } + if got := status.New.Persist.Provider; got != "google" { + t.Errorf("Provider=%q, want google", got) + } + if len(affectedPeers) != 1 || affectedPeers[0] != c1.id { + t.Errorf("authorization should notify the node being authorized (%v), got: %v", c1.id, affectedPeers) + } + if peers := 
status.New.NetMap.Peers; len(peers) != 0 { + t.Errorf("peers=%v, want none", peers) + } + if userID := status.New.NetMap.User; userID == 0 { + t.Errorf("NetMap.User is missing") + } else { + profile := status.New.NetMap.UserProfiles[userID] + if profile.LoginName != loginName { + t.Errorf("NetMap user LoginName=%q, want %q", profile.LoginName, loginName) + } + } + c1.checkNoStatus(t) + }) + + c2 := s.newClient(t, "c2") + + t.Run("authorize second tailscale.io client", func(t *testing.T) { + c2.loginAs(t, "testuser2@tailscale.com") + c2.waitStatus(t, stateAuthenticated) + c2.waitStatus(t, stateSynchronized) + c2.checkNoStatus(t) + + // Make sure not to call operations like this on a client in a + // test until the initial map read is done. Otherwise the + // initial map read will trigger a map update to peers, and + // there will sometimes be a spurious map update. + affectedPeers, err := s.control.AuthorizeMachine(c2.mkey, c2.nkey) + if err != nil { + t.Fatal(err) + } + status := c2.waitStatus(t, stateSynchronized) + c1Status := c1.waitStatus(t, stateSynchronized) + + if len(affectedPeers) != 2 { + t.Errorf("affectedPeers=%v, want two entries", affectedPeers) + } + if want := []tailcfg.NodeID{c1.id, c2.id}; !nodeIDsEqual(affectedPeers, want) { + t.Errorf("affectedPeers=%v, want %v", affectedPeers, want) + } + + c1NetMap := c1Status.New.NetMap + c2NetMap := status.New.NetMap + if len(c1NetMap.Peers) != 1 || len(c2NetMap.Peers) != 1 { + t.Error("wrong number of peers") + } else { + if c2NetMap.Peers[0].Key != c1.nkey { + t.Errorf("c2 has wrong peer key %v, want %v", c2NetMap.Peers[0].Key, c1.nkey) + } + if c1NetMap.Peers[0].Key != c2.nkey { + t.Errorf("c1 has wrong peer key %v, want %v", c1NetMap.Peers[0].Key, c2.nkey) + } + } + if t.Failed() { + t.Errorf("client1 network map:\n%s", c1Status.New.NetMap) + t.Errorf("client2 network map:\n%s", status.New.NetMap) + } + + c1.checkNoStatus(t) + c2.checkNoStatus(t) + }) + + // c3/c4 are on a different domain to c1/c2. 
+ // The two domains should never affect one another. + c3 := s.newClient(t, "c3") + + t.Run("authorize first onmicrosoft client", func(t *testing.T) { + c3.loginAs(t, "testuser1@tailscale.onmicrosoft.com") + c3.waitStatus(t, stateAuthenticated) + c3Status := c3.waitStatus(t, stateSynchronized) + // no machine authorization for tailscale.onmicrosoft.com + c1.checkNoStatus(t) + c2.checkNoStatus(t) + + netMap := c3Status.New.NetMap + if netMap.NodeKey != c3.nkey { + t.Errorf("netMap.NodeKey=%v, want %v", netMap.NodeKey, c3.nkey) + } + if len(netMap.Peers) != 0 { + t.Errorf("netMap.Peers=%v, want none", netMap.Peers) + } + + c1.checkNoStatus(t) + c2.checkNoStatus(t) + c3.checkNoStatus(t) + }) + + c4 := s.newClient(t, "c4") + + t.Run("authorize second onmicrosoft client", func(t *testing.T) { + c4.loginAs(t, "testuser2@tailscale.onmicrosoft.com") + c4.waitStatus(t, stateAuthenticated) + c3Status := c3.waitStatus(t, stateSynchronized) + c4Status := c4.waitStatus(t, stateSynchronized) + c3NetMap := c3Status.New.NetMap + c4NetMap := c4Status.New.NetMap + + c1.checkNoStatus(t) + c2.checkNoStatus(t) + + if len(c3NetMap.Peers) != 1 { + t.Errorf("wrong number of c3 peers: %d", len(c3NetMap.Peers)) + } else if len(c4NetMap.Peers) != 1 { + t.Errorf("wrong number of c4 peers: %d", len(c4NetMap.Peers)) + } else { + if c3NetMap.Peers[0].Key != c4.nkey || c4NetMap.Peers[0].Key != c3.nkey { + t.Error("wrong peer key") + } + } + if t.Failed() { + t.Errorf("client3 network map:\n%s", c3NetMap) + t.Errorf("client4 network map:\n%s", c4NetMap) + } + }) + + var c1NetMap *NetworkMap + t.Run("update c1 and c2 endpoints", func(t *testing.T) { + c1Endpoints := []string{"172.16.1.5:12345", "4.4.4.4:4444"} + c1.checkNoStatus(t) + c1.UpdateEndpoints(1234, c1Endpoints) + c1NetMap = c1.status(t).New.NetMap + c2NetMap := c2.status(t).New.NetMap + c1.checkNoStatus(t) + c2.checkNoStatus(t) + + if c1NetMap.LocalPort != 1234 { + t.Errorf("c1 netmap localport=%d, want 1234", c1NetMap.LocalPort) + } + 
if len(c2NetMap.Peers) != 1 { + t.Fatalf("wrong peer count: %d", len(c2NetMap.Peers)) + } + if got := c2NetMap.Peers[0].Endpoints; !reflect.DeepEqual(c1Endpoints, got) { + t.Errorf("c2 peer endpoints=%v, want %v", got, c1Endpoints) + } + c3.checkNoStatus(t) + c4.checkNoStatus(t) + + c2Endpoints := []string{"172.16.1.7:6543", "5.5.5.5.3333"} + c2.UpdateEndpoints(9876, c2Endpoints) + c1NetMap = c1.status(t).New.NetMap + c2NetMap = c2.status(t).New.NetMap + + if c1NetMap.LocalPort != 1234 { + t.Errorf("c1 netmap localport=%d, want 1234", c1NetMap.LocalPort) + } + if c2NetMap.LocalPort != 9876 { + t.Errorf("c2 netmap localport=%d, want 9876", c2NetMap.LocalPort) + } + if got := c2NetMap.Peers[0].Endpoints; !reflect.DeepEqual(c1Endpoints, got) { + t.Errorf("c2 peer endpoints=%v, want %v", got, c1Endpoints) + } + if got := c1NetMap.Peers[0].Endpoints; !reflect.DeepEqual(c2Endpoints, got) { + t.Errorf("c1 peer endpoints=%v, want %v", got, c2Endpoints) + } + + c1.checkNoStatus(t) + c2.checkNoStatus(t) + c3.checkNoStatus(t) + c4.checkNoStatus(t) + }) + + allZeros, err := wgcfg.ParseCIDR("0.0.0.0/0") + if err != nil { + t.Fatal(err) + } + + t.Run("route all traffic via client 1", func(t *testing.T) { + aips := []wgcfg.CIDR{} + aips = append(aips, c1NetMap.Addresses...) 
+ aips = append(aips, *allZeros) + + affectedPeers, err := s.control.SetAllowedIPs(c1.nkey, aips) + if err != nil { + t.Fatal(err) + } + c2Status := c2.status(t) + c2NetMap := c2Status.New.NetMap + + if want := []tailcfg.NodeID{c2.id}; !nodeIDsEqual(affectedPeers, want) { + t.Errorf("affectedPeers=%v, want %v", affectedPeers, want) + } + + _ = c2NetMap + foundAllZeros := false + for _, cidr := range c2NetMap.Peers[0].AllowedIPs { + if cidr == *allZeros { + foundAllZeros = true + } + } + if !foundAllZeros { + t.Errorf("client2 peer does not contain %s: %v", allZeros, c2NetMap.Peers[0].AllowedIPs) + } + + c1.checkNoStatus(t) + c3.checkNoStatus(t) + c4.checkNoStatus(t) + }) + + t.Run("remove route all traffic", func(t *testing.T) { + affectedPeers, err := s.control.SetAllowedIPs(c1.nkey, c1NetMap.Addresses) + if err != nil { + t.Fatal(err) + } + c2NetMap := c2.status(t).New.NetMap + + if want := []tailcfg.NodeID{c2.id}; !nodeIDsEqual(affectedPeers, want) { + t.Errorf("affectedPeers=%v, want %v", affectedPeers, want) + } + + foundAllZeros := false + for _, cidr := range c2NetMap.Peers[0].AllowedIPs { + if cidr == *allZeros { + foundAllZeros = true + } + } + if foundAllZeros { + t.Errorf("client2 peer still contains %s: %v", allZeros, c2NetMap.Peers[0].AllowedIPs) + } + + c1.checkNoStatus(t) + c3.checkNoStatus(t) + c4.checkNoStatus(t) + }) + + t.Run("refresh client key", func(t *testing.T) { + oldKey := c1.nkey + + c1.Login(nil, LoginInteractive) + status := c1.waitStatus(t, stateURLVisitRequired) + authURL := status.New.URL + + resp, err := c1.httpc.Get(authURL) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Errorf("GET %s failed: %q", authURL, resp.Status) + } + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + t.Fatal(err) + } + cookies := resp.Cookies() + if len(cookies) == 0 || cookies[0].Name != "tailcontrol" { + t.Logf("GET %s: %s", authURL, string(body)) + t.Fatalf("GET %s: bad cookie: %v", authURL, cookies) + 
} + c1.waitStatus(t, stateAuthenticated) + status = c1.waitStatus(t, stateSynchronized) + if status.New.Err != "" { + t.Fatal(status.New.Err) + } + + c1NetMap := status.New.NetMap + c1.nkey = c1NetMap.NodeKey + if c1.nkey == oldKey { + t.Errorf("new key is the same as the old key: %s", oldKey) + } + c2NetMap := c2.status(t).New.NetMap + if len(c2NetMap.Peers) != 1 || c2NetMap.Peers[0].Key != c1.nkey { + t.Errorf("c2 peer: %v, want new node key %v", c1.nkey, c2NetMap.Peers[0].Key) + } + + c3.checkNoStatus(t) + c4.checkNoStatus(t) + }) +} + +func TestLoginInterrupt(t *testing.T) { + s := newServer(t) + defer s.close() + + c := s.newClient(t, "c") + + const loginName = "testuser1@tailscale.com" + c.checkNoStatus(t) + c.loginAs(t, loginName) + c.waitStatus(t, stateAuthenticated) + c.waitStatus(t, stateSynchronized) + t.Logf("authorizing: %v %v %v %v\n", s, s.control, c.mkey, c.nkey) + if _, err := s.control.AuthorizeMachine(c.mkey, c.nkey); err != nil { + t.Fatal(err) + } + status := c.waitStatus(t, stateSynchronized) + if got, want := status.New.NetMap.MachineStatus, tailcfg.MachineAuthorized; got != want { + t.Errorf("MachineStatus=%v, want %v", got, want) + } + origAddrs := status.New.NetMap.Addresses + if len(origAddrs) == 0 { + t.Errorf("Addresses empty, want something") + } + + c.Logout() + c.waitStatus(t, stateNotAuthenticated) + c.Login(nil, 0) + status = c.waitStatus(t, stateURLVisitRequired) + authURL := status.New.URL + + // Interrupt, and do login again. 
+ c.Login(nil, 0) + status = c.waitStatus(t, stateURLVisitRequired) + authURL2 := status.New.URL + + if authURL == authURL2 { + t.Errorf("auth URLs match for subsequent logins: %s", authURL) + } + + form := url.Values{"user": []string{loginName}} + req, err := http.NewRequest("POST", authURL2, strings.NewReader(form.Encode())) + if err != nil { + t.Fatal(err) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + resp, err := c.httpc.Do(req.WithContext(c.ctx)) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Fatalf("POST %s failed: %q", authURL2, resp.Status) + } + cookies := resp.Cookies() + if len(cookies) == 0 || cookies[0].Name != "tailcontrol" { + t.Fatalf("POST %s: bad cookie: %v", authURL2, cookies) + } + + c.waitStatus(t, stateAuthenticated) + status = c.status(t) + if got := status.New.NetMap.NodeKey; got != c.nkey { + t.Errorf("netmap has wrong node key: %v, want %v", got, c.nkey) + } + if got := status.New.NetMap.Addresses; len(got) == 0 { + t.Errorf("Addresses empty after re-login, want something") + } else if len(origAddrs) > 0 && origAddrs[0] != got[0] { + t.Errorf("Addresses=%v after re-login, originally was %v, want IP to be unchanged", got, origAddrs) + } +} + +func TestSpinUpdateEndpoints(t *testing.T) { + s := newServer(t) + defer s.close() + + c1 := s.newClient(t, "c1") + c2 := s.newClient(t, "c2") + + const loginName = "testuser1@tailscale.com" + c1.loginAs(t, loginName) + c1.waitStatus(t, stateAuthenticated) + c1.waitStatus(t, stateSynchronized) + if _, err := s.control.AuthorizeMachine(c1.mkey, c1.nkey); err != nil { + t.Fatal(err) + } + c1.waitStatus(t, stateSynchronized) + + c2.loginAs(t, loginName) + c2.waitStatus(t, stateAuthenticated) + c2.waitStatus(t, stateSynchronized) + if _, err := s.control.AuthorizeMachine(c2.mkey, c2.nkey); err != nil { + t.Fatal(err) + } + c2.waitStatus(t, stateSynchronized) + c1.waitStatus(t, stateSynchronized) + + const portBase = 1200 + const portCount = 50 + const 
portLast = portBase + portCount - 1 + + errCh := make(chan error, 1) + collectPorts := func() error { + t := time.After(10 * time.Second) + var port int + for i := 0; i < portCount; i++ { + var status statusChange + select { + case status = <-c2.statusCh: + case <-t: + return fmt.Errorf("c2 status timeout (i=%d)", i) + } + peers := status.New.NetMap.Peers + if len(peers) != 1 { + return fmt.Errorf("c2 len(peers)=%d, want 1", len(peers)) + } + eps := peers[0].Endpoints + if len(eps) != 2 { + return fmt.Errorf("c2 peer len(eps)=%d, want 2", len(eps)) + } + ep := eps[1] + const prefix = "192.168.1.45:" + if !strings.HasPrefix(ep, prefix) { + return fmt.Errorf("c2 peer endpoint=%s, want prefix %s", ep, prefix) + } + var err error + port, err = strconv.Atoi(strings.TrimPrefix(ep, prefix)) + if err != nil { + return fmt.Errorf("c2 peer endpoint port: %v", err) + } + if port == portLast { + return nil // got it + } + } + return fmt.Errorf("c2 peer endpoint did not see portLast (saw %d)", port) + } + go func() { + errCh <- collectPorts() + }() + + // Very quickly call UpdateEndpoints several times. + // Some (most) of these calls will never make it to the server, they + // will be canceled by subsequent calls. + // The last call goes through, so we can see portLast. 
// TestLogout logs a client in, authorizes its machine, logs it out and logs
// it back in as the same user. It expects the machine to still be reported
// as authorized after the re-login and the node key to have changed.
func TestLogout(t *testing.T) {
	s := newServer(t)
	defer s.close()

	c1 := s.newClient(t, "c1")

	const loginName = "testuser1@tailscale.com"
	c1.loginAs(t, loginName)

	c1.waitStatus(t, stateAuthenticated)
	c1.waitStatus(t, stateSynchronized)
	if _, err := s.control.AuthorizeMachine(c1.mkey, c1.nkey); err != nil {
		t.Fatal(err)
	}
	// Node key of the first session, taken from the next status update.
	nkey1 := c1.status(t).New.NetMap.NodeKey

	c1.Logout()
	c1.waitStatus(t, stateNotAuthenticated)

	c1.loginAs(t, loginName)
	c1.waitStatus(t, stateAuthenticated)
	status := c1.waitStatus(t, stateSynchronized)
	if got, want := status.New.NetMap.MachineStatus, tailcfg.MachineAuthorized; got != want {
		t.Errorf("re-login MachineStatus=%v, want %v", got, want)
	}
	// The second session must not reuse the node key of the first one.
	nkey2 := status.New.NetMap.NodeKey
	if nkey1 == nkey2 {
		t.Errorf("key not changed after re-login: %v", nkey1)
	}

	// No further status updates are expected once synchronized.
	c1.checkNoStatus(t)
}
!nkey1Expiry.Equal(wantExpiry) { + t.Errorf("node key expiry = %v, want %v", nkey1Expiry, wantExpiry) + } + + timeInc(1 * time.Hour) // move the clock forward + c1.Login(nil, LoginInteractive) // refresh the key + status = c1.waitStatus(t, stateURLVisitRequired).New + c1.postAuthURL(t, loginName, status) + c1.waitStatus(t, stateAuthenticated) + status = c1.waitStatus(t, stateSynchronized).New + if newKey := c1.direct.persist.PrivateNodeKey; newKey == nkey1 { + t.Errorf("node key unchanged after LoginInteractive: %v", nkey1) + } + if want, got := timeNow().Add(180*24*time.Hour), status.NetMap.Expiry; !got.Equal(want) { + t.Errorf("node key expiry = %v, want %v", got, want) + } + + timeInc(2 * time.Hour) // move the clock forward + c1.Login(nil, 0) + c1.waitStatus(t, stateAuthenticated) + c1.waitStatus(t, stateSynchronized) + c1.checkNoStatus(t) // nothing happens, network map stays the same + + timeInc(180 * 24 * time.Hour) // move the clock past expiry + c1.loginAs(t, loginName) + c1.waitStatus(t, stateAuthenticated) + status = c1.waitStatus(t, stateSynchronized).New + if got, want := c1.expiry, timeNow(); got.Equal(want) { + t.Errorf("node key expiry = %v, want %v", got, want) + } + if c1.direct.persist.PrivateNodeKey == nkey1 { + t.Errorf("node key after 37 hours is still %v", status.NetMap.NodeKey) + } +} + +func TestRefresh(t *testing.T) { + var nowMu sync.Mutex + now := time.Now() // Server and Client use this variable as the current time + timeNow := func() time.Time { + nowMu.Lock() + defer nowMu.Unlock() + return now + } + + s := newServer(t) + s.control.TimeNow = timeNow + defer s.close() + + c1 := s.newClient(t, "c1") + + const loginName = "testuser1@versabank.com" // versabank cfgdb has 72 hour key expiry configured + c1.loginAs(t, loginName) + + c1.waitStatus(t, stateAuthenticated) + c1.waitStatus(t, stateSynchronized) + if _, err := s.control.AuthorizeMachine(c1.mkey, c1.nkey); err != nil { + t.Fatal(err) + } + status := c1.status(t).New + nkey1 := 
status.NetMap.NodeKey + nkey1Expiry := status.NetMap.Expiry + if wantExpiry := timeNow().Add(72 * time.Hour); !nkey1Expiry.Equal(wantExpiry) { + t.Errorf("node key expiry = %v, want %v", nkey1Expiry, wantExpiry) + } + + c1.Login(nil, LoginInteractive) + c1.waitStatus(t, stateURLVisitRequired) + // Until authorization happens, old netmap is still valid. + exp := c1.expiry + if exp == nil { + t.Errorf("expiry==nil during refresh\n") + } + if got := *exp; !nkey1Expiry.Equal(got) { + t.Errorf("node key expiry = %v, want %v", got, nkey1Expiry) + } + k := tailcfg.NodeKey(*c1.direct.persist.PrivateNodeKey.Public()) + if k != nkey1 { + t.Errorf("node key after 2 hours is %v, want %v", k, nkey1) + } + c1.Shutdown() +} + +func TestExpectedProvider(t *testing.T) { + s := newServer(t) + defer s.close() + + c := s.newClient(t, "c1") + + c.direct.persist.LoginName = "testuser1@tailscale.com" + c.direct.persist.Provider = "microsoft" + c.Login(nil, 0) + status := c.readStatus(t) + if e, substr := status.New.Err, `provider "microsoft" is not supported`; !strings.Contains(e, substr) { + t.Errorf("Err=%q, expect substring %q", e, substr) + } +} + +func TestNewUserWebFlow(t *testing.T) { + s := newServer(t) + defer s.close() + s.control.DB().SetSegmentAPIKey(segmentKey) + + c := s.newClient(t, "c1") + c.Login(nil, 0) + status := c.waitStatus(t, stateURLVisitRequired) + authURL := status.New.URL + resp, err := c.httpc.Get(authURL) + if err != nil { + t.Fatal(err) + } + if resp.StatusCode != 200 { + t.Errorf("statuscode=%d, want 200", resp.StatusCode) + } + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + got := string(b) + if !strings.Contains(got, ` stateAuthenticated +// TODO: test os/hostname gets sent to server +// TODO: test vpn IP not assigned until machine is authorized +// TODO: test overlapping calls to RefreshLogin +// TODO: test registering a new node for a user+machine key replaces the old +// node even if the OldNodeKey is not specified by the 
// fieldsOf lists the names of the fields of struct type t, in declaration
// order. The result is nil when t declares no fields.
func fieldsOf(t reflect.Type) (fields []string) {
	count := t.NumField()
	for idx := 0; idx != count; idx++ {
		field := t.Field(idx)
		fields = append(fields, field.Name)
	}
	return fields
}
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package controlclient + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "net/http" + "os" + "runtime" + "strings" + "sync" + "time" + + "github.com/tailscale/wireguard-go/wgcfg" + "golang.org/x/crypto/nacl/box" + "golang.org/x/oauth2" + "tailscale.com/logger" + "tailscale.com/tailcfg" + "tailscale.com/version" + "tailscale.com/wgengine/filter" +) + +type Persist struct { + PrivateMachineKey wgcfg.PrivateKey + PrivateNodeKey wgcfg.PrivateKey + OldPrivateNodeKey wgcfg.PrivateKey // needed to request key rotation + Provider string + LoginName string +} + +func (p *Persist) Pretty() string { + var mk, ok, nk wgcfg.Key + if !p.PrivateMachineKey.IsZero() { + mk = *p.PrivateMachineKey.Public() + } + if !p.OldPrivateNodeKey.IsZero() { + ok = *p.OldPrivateNodeKey.Public() + } + if !p.PrivateNodeKey.IsZero() { + nk = *p.PrivateNodeKey.Public() + } + return fmt.Sprintf("Persist{m=%v, o=%v, n=%v u=%#v}", + mk.ShortString(), ok.ShortString(), nk.ShortString(), + p.LoginName) +} + +// Direct is the client that connects to a tailcontrol server for a node. 
type Direct struct {
	httpc           *http.Client                 // HTTP client used to talk to tailcontrol
	serverURL       string                       // URL of the tailcontrol server
	timeNow         func() time.Time             // clock used to decide whether the node key has expired
	newDecompressor func() (Decompressor, error) // if non-nil, map responses are requested zstd-compressed and decoded with it
	keepAlive       bool                         // sent to the server as MapRequest.KeepAlive
	logf            logger.Logf                  // destination for log output

	mu           sync.Mutex       // mutex guards the following fields
	serverKey    wgcfg.Key        // control server public key; fetched on first login while still zero
	persist      Persist          // keys and login identity, returned to callers by GetPersist
	tryingNewKey wgcfg.PrivateKey // node key of a registration that is still waiting on an auth URL
	expiry       *time.Time       // key expiry taken from the most recent network map; nil before the first one
	hostinfo     tailcfg.Hostinfo // host description sent in register and map requests
	endpoints    []string         // locally advertised endpoints, sent in map requests
	localPort    uint16           // local port recorded together with endpoints
	cmdCh        chan interface{} // NOTE(review): not referenced in the code visible here; confirm purpose
	doneCh       chan struct{}    // NOTE(review): not referenced in the code visible here; confirm purpose
}
+ opts.Logf = log.Printf + } + + c := &Direct{ + httpc: opts.HTTPC, + serverURL: opts.ServerURL, + timeNow: opts.TimeNow, + logf: opts.Logf, + newDecompressor: opts.NewDecompressor, + keepAlive: opts.KeepAlive, + persist: opts.Persist, + } + if opts.Hostinfo == nil { + c.SetHostinfo(NewHostinfo()) + } else { + c.SetHostinfo(*opts.Hostinfo) + } + + return c, nil +} + +func NewHostinfo() tailcfg.Hostinfo { + hostname, _ := os.Hostname() + os := runtime.GOOS + switch os { + case "darwin": + switch runtime.GOARCH { + case "arm", "arm64": + os = "iOS" + default: + os = "macOS" + } + } + + return tailcfg.Hostinfo{ + IPNVersion: version.LONG, + Hostname: hostname, + OS: os, + } +} + +func (c *Direct) SetHostinfo(hi tailcfg.Hostinfo) { + c.mu.Lock() + defer c.mu.Unlock() + + c.logf("Hostinfo: %v\n", hi) + c.hostinfo = hi +} + +func (c *Direct) GetPersist() Persist { + c.mu.Lock() + defer c.mu.Unlock() + return c.persist +} + +type LoginFlags int + +const ( + LoginDefault = LoginFlags(0) + LoginInteractive = LoginFlags(1 << iota) // force user login and key refresh +) + +func (c *Direct) TryLogout(ctx context.Context) error { + c.logf("direct.TryLogout()\n") + + c.mu.Lock() + defer c.mu.Unlock() + + if c.persist.PrivateNodeKey != (wgcfg.PrivateKey{}) { + // TODO(crawshaw): Tell the server. This node key should be immediately invalidated. 
+ } + c.persist = Persist{ + PrivateMachineKey: c.persist.PrivateMachineKey, + } + return nil +} + +func (c *Direct) TryLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags) (url string, err error) { + c.logf("direct.TryLogin(%v, %v)\n", t != nil, flags) + return c.doLoginOrRegen(ctx, t, flags, false, "") +} + +func (c *Direct) WaitLoginURL(ctx context.Context, url string) (newUrl string, err error) { + c.logf("direct.WaitLoginURL\n") + return c.doLoginOrRegen(ctx, nil, LoginDefault, false, url) +} + +func (c *Direct) doLoginOrRegen(ctx context.Context, t *oauth2.Token, flags LoginFlags, regen bool, url string) (newUrl string, err error) { + mustregen, url, err := c.doLogin(ctx, t, flags, regen, url) + if err != nil { + return url, err + } + if mustregen { + _, url, err = c.doLogin(ctx, t, flags, true, url) + } + return url, err +} + +func (c *Direct) doLogin(ctx context.Context, t *oauth2.Token, flags LoginFlags, regen bool, url string) (mustregen bool, newurl string, err error) { + c.mu.Lock() + persist := c.persist + tryingNewKey := c.tryingNewKey + serverKey := c.serverKey + expired := c.expiry != nil && !c.expiry.IsZero() && c.expiry.Before(c.timeNow()) + c.mu.Unlock() + + if persist.PrivateMachineKey == (wgcfg.PrivateKey{}) { + c.logf("Generating a new machinekey.\n") + mkey, err := wgcfg.NewPrivateKey() + if err != nil { + log.Fatal(err) + } + persist.PrivateMachineKey = *mkey + } + + if expired { + c.logf("Old key expired -> regen=true\n") + regen = true + } + if (flags & LoginInteractive) != 0 { + c.logf("LoginInteractive -> regen=true\n") + regen = true + } + + c.logf("doLogin(regen=%v, hasUrl=%v)\n", regen, url != "") + if serverKey == (wgcfg.Key{}) { + var err error + serverKey, err = loadServerKey(ctx, c.httpc, c.serverURL) + if err != nil { + return regen, url, err + } + + c.mu.Lock() + c.serverKey = serverKey + c.mu.Unlock() + } + + var oldNodeKey wgcfg.Key + if url != "" { + } else if regen || persist.PrivateNodeKey == (wgcfg.PrivateKey{}) { 
+ c.logf("Generating a new nodekey.\n") + persist.OldPrivateNodeKey = persist.PrivateNodeKey + key, err := wgcfg.NewPrivateKey() + if err != nil { + c.logf("login keygen: %v", err) + return regen, url, err + } + tryingNewKey = *key + } else { + // Try refreshing the current key first + tryingNewKey = persist.PrivateNodeKey + } + if persist.OldPrivateNodeKey != (wgcfg.PrivateKey{}) { + oldNodeKey = *persist.OldPrivateNodeKey.Public() + } + + if tryingNewKey == (wgcfg.PrivateKey{}) { + log.Fatalf("tryingNewKey is empty, give up\n") + } + if c.hostinfo.BackendLogID == "" { + err = errors.New("hostinfo: BackendLogID missing") + return regen, url, err + } + request := tailcfg.RegisterRequest{ + Version: 1, + OldNodeKey: tailcfg.NodeKey(oldNodeKey), + NodeKey: tailcfg.NodeKey(*tryingNewKey.Public()), + Hostinfo: c.hostinfo, + Followup: url, + } + c.logf("RegisterReq: onode=%v node=%v fup=%v\n", + request.OldNodeKey.AbbrevString(), + request.NodeKey.AbbrevString(), url != "") + request.Auth.Oauth2Token = t + request.Auth.Provider = persist.Provider + request.Auth.LoginName = persist.LoginName + bodyData, err := encode(request, &serverKey, &persist.PrivateMachineKey) + if err != nil { + return regen, url, err + } + body := bytes.NewReader(bodyData) + + u := fmt.Sprintf("%s/machine/%s", c.serverURL, persist.PrivateMachineKey.Public().HexString()) + req, err := http.NewRequest("POST", u, body) + if err != nil { + return regen, url, err + } + req = req.WithContext(ctx) + + res, err := c.httpc.Do(req) + if err != nil { + return regen, url, fmt.Errorf("register request: %v", err) + } + c.logf("RegisterReq: returned.\n") + resp := tailcfg.RegisterResponse{} + if err := decode(res, &resp, &serverKey, &persist.PrivateMachineKey); err != nil { + return regen, url, fmt.Errorf("register request: %v", err) + } + + if resp.NodeKeyExpired { + if regen { + return true, "", fmt.Errorf("weird: regen=true but server says NodeKeyExpired: %v", request.NodeKey) + } + c.logf("server reports new 
// sameStrings reports whether a and b hold the same strings in the same
// order. A nil slice and an empty slice compare as equal.
func sameStrings(a, b []string) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	for idx := 0; idx < n; idx++ {
		if a[idx] != b[idx] {
			return false
		}
	}
	return true
}
+func (c *Direct) SetEndpoints(localPort uint16, endpoints []string) (changed bool, err error) { + // (no log message on function entry, because it clutters the logs + // if endpoints haven't changed. newEndpoints() will log it.) + changed = c.newEndpoints(localPort, endpoints) + return changed, nil +} + +func (c *Direct) PollNetMap(ctx context.Context, maxPolls int, cb func(*NetworkMap)) error { + c.mu.Lock() + persist := c.persist + serverURL := c.serverURL + serverKey := c.serverKey + hostinfo := c.hostinfo + localPort := c.localPort + ep := append([]string(nil), c.endpoints...) + c.mu.Unlock() + + if hostinfo.BackendLogID == "" { + return errors.New("hostinfo: BackendLogID missing") + } + + allowStream := maxPolls != 1 + c.logf("PollNetMap: stream=%v :%v %v\n", maxPolls, localPort, ep) + + request := tailcfg.MapRequest{ + Version: 4, + KeepAlive: c.keepAlive, + NodeKey: tailcfg.NodeKey(*persist.PrivateNodeKey.Public()), + Endpoints: ep, + Stream: allowStream, + Hostinfo: hostinfo, + } + if c.newDecompressor != nil { + request.Compress = "zstd" + } + + bodyData, err := encode(request, &serverKey, &persist.PrivateMachineKey) + if err != nil { + return err + } + + u := fmt.Sprintf("%s/machine/%s/map", serverURL, persist.PrivateMachineKey.Public().HexString()) + req, err := http.NewRequest("POST", u, bytes.NewReader(bodyData)) + if err != nil { + return err + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + req = req.WithContext(ctx) + + res, err := c.httpc.Do(req) + if err != nil { + return err + } + if res.StatusCode != 200 { + msg, _ := ioutil.ReadAll(res.Body) + res.Body.Close() + return fmt.Errorf("initial fetch failed %d: %s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + defer res.Body.Close() + + // If we go more than pollTimeout without hearing from the server, + // end the long poll. We should be receiving a keep alive ping + // every minute. 
+ const pollTimeout = 120 * time.Second + timeout := time.NewTimer(pollTimeout) + timeoutReset := make(chan struct{}) + defer close(timeoutReset) + go func() { + for { + select { + case <-timeout.C: + c.logf("map response long-poll timed out!") + cancel() + return + case _, ok := <-timeoutReset: + if !ok { + return // channel closed, shut down goroutine + } + if !timeout.Stop() { + <-timeout.C + } + timeout.Reset(pollTimeout) + } + } + }() + + // If allowStream, then the server will use an HTTP long poll to + // return incremental results. There is always one response right + // away, followed by a delay, and eventually others. + // If !allowStream, it'll still send the first result in exactly + // the same format before just closing the connection. + // We can use this same read loop either way. + var msg []byte + for i := 0; i < maxPolls || maxPolls < 0; i++ { + var siz [4]byte + if _, err := io.ReadFull(res.Body, siz[:]); err != nil { + return err + } + size := binary.LittleEndian.Uint32(siz[:]) + msg = append(msg[:0], make([]byte, size)...) + if _, err := io.ReadFull(res.Body, msg); err != nil { + return err + } + + var resp tailcfg.MapResponse + + // Default filter if the key is missing from the incoming + // json (ie. old tailcontrol server without PacketFilter + // support). If even an empty PacketFilter is provided, this + // will be overwritten. + // TODO(apenwarr 2020-02-01): remove after tailcontrol is fully deployed. 
+ resp.PacketFilter = filter.MatchAllowAll + + if err := c.decodeMsg(msg, &resp); err != nil { + return err + } + if resp.KeepAlive { + c.logf("map response keep alive received") + timeoutReset <- struct{}{} + continue + } + + nm := &NetworkMap{ + NodeKey: tailcfg.NodeKey(*persist.PrivateNodeKey.Public()), + PrivateKey: persist.PrivateNodeKey, + Expiry: resp.Node.KeyExpiry, + Addresses: resp.Node.Addresses, + Peers: resp.Peers, + LocalPort: localPort, + User: resp.Node.User, + UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), + Domain: resp.Domain, + Roles: resp.Roles, + DNS: resp.DNS, + DNSDomains: resp.SearchPaths, + Hostinfo: resp.Node.Hostinfo, + PacketFilter: resp.PacketFilter, + } + for _, profile := range resp.UserProfiles { + nm.UserProfiles[profile.ID] = profile + } + if resp.Node.MachineAuthorized { + nm.MachineStatus = tailcfg.MachineAuthorized + } else { + nm.MachineStatus = tailcfg.MachineUnauthorized + } + //c.logf("new network map[%d]:\n%s", i, nm.Concise()) + + c.mu.Lock() + c.expiry = &nm.Expiry + c.mu.Unlock() + + cb(nm) + } + if ctx.Err() != nil { + return ctx.Err() + } + return nil +} + +func decode(res *http.Response, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) error { + defer res.Body.Close() + msg, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<20)) + if err != nil { + return err + } + if res.StatusCode != 200 { + return fmt.Errorf("%d: %v", res.StatusCode, string(msg)) + } + return decodeMsg(msg, v, serverKey, mkey) +} + +func (c *Direct) decodeMsg(msg []byte, v interface{}) error { + mkey := c.persist.PrivateMachineKey + serverKey := c.serverKey + + decrypted, err := decryptMsg(msg, &serverKey, &mkey) + if err != nil { + return err + } + var b []byte + if c.newDecompressor == nil { + b = decrypted + } else { + //decoder, err := zstd.NewReader(nil) + decoder, err := c.newDecompressor() + if err != nil { + return err + } + defer decoder.Close() + b, err = decoder.DecodeAll(decrypted, nil) + if err != nil { + 
return err + } + } + if err := json.Unmarshal(b, v); err != nil { + return fmt.Errorf("response: %v", err) + } + return nil + +} + +func decodeMsg(msg []byte, v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) error { + decrypted, err := decryptMsg(msg, serverKey, mkey) + if err != nil { + return err + } + if err := json.Unmarshal(decrypted, v); err != nil { + return fmt.Errorf("response: %v", err) + } + return nil +} + +func decryptMsg(msg []byte, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) ([]byte, error) { + var nonce [24]byte + if len(msg) < len(nonce)+1 { + return nil, fmt.Errorf("response missing nonce, len=%d", len(msg)) + } + copy(nonce[:], msg) + msg = msg[len(nonce):] + + pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey) + decrypted, ok := box.Open(nil, msg, &nonce, pub, pri) + if !ok { + return nil, fmt.Errorf("cannot decrypt response") + } + return decrypted, nil +} + +func encode(v interface{}, serverKey *wgcfg.Key, mkey *wgcfg.PrivateKey) ([]byte, error) { + b, err := json.Marshal(v) + if err != nil { + return nil, err + } + var nonce [24]byte + if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil { + panic(err) + } + pub, pri := (*[32]byte)(serverKey), (*[32]byte)(mkey) + msg := box.Seal(nonce[:], b, &nonce, pub, pri) + return msg, nil +} + +func loadServerKey(ctx context.Context, httpc *http.Client, serverURL string) (wgcfg.Key, error) { + req, err := http.NewRequest("GET", serverURL+"/key", nil) + if err != nil { + return wgcfg.Key{}, fmt.Errorf("create control key request: %v", err) + } + req = req.WithContext(ctx) + res, err := httpc.Do(req) + if err != nil { + return wgcfg.Key{}, fmt.Errorf("fetch control key: %v", err) + } + defer res.Body.Close() + b, err := ioutil.ReadAll(io.LimitReader(res.Body, 1<<16)) + if err != nil { + return wgcfg.Key{}, fmt.Errorf("fetch control key response: %v", err) + } + if res.StatusCode != 200 { + return wgcfg.Key{}, fmt.Errorf("fetch control key: %d: %s", res.StatusCode, string(b)) + } + 
key, err := wgcfg.ParseHexKey(string(b)) + if err != nil { + return wgcfg.Key{}, fmt.Errorf("fetch control key: %v", err) + } + return *key, nil +} diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go new file mode 100644 index 000000000..6e852db1c --- /dev/null +++ b/control/controlclient/direct_test.go @@ -0,0 +1,305 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build depends_on_currently_unreleased + +package controlclient + +import ( + "context" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/tailcfg" + "tailscale.io/control" // not yet released +) + +func TestClientsReusingKeys(t *testing.T) { + tmpdir, err := ioutil.TempDir("", "control-test-") + if err != nil { + t.Fatal(err) + } + var server *control.Server + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server.ServeHTTP(w, r) + })) + server, err = control.New(tmpdir, httpsrv.URL, true) + if err != nil { + t.Fatal(err) + } + server.QuietLogging = true + defer func() { + httpsrv.CloseClientConnections() + httpsrv.Close() + os.RemoveAll(tmpdir) + }() + + hi := NewHostinfo() + hi.FrontendLogID = "go-test-only" + hi.BackendLogID = "go-test-only" + c1, err := NewDirect(Options{ + ServerURL: httpsrv.URL, + HTTPC: httpsrv.Client(), + //TimeNow: s.control.TimeNow, + Logf: func(fmt string, args ...interface{}) { + t.Helper() + t.Logf("c1: "+fmt, args...) 
+ }, + Hostinfo: &hi, + }) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + authURL, err := c1.TryLogin(ctx, nil, 0) + if err != nil { + t.Fatal(err) + } + const user = "testuser1@tailscale.onmicrosoft.com" + postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + newURL, err := c1.WaitLoginURL(ctx, authURL) + if err != nil { + t.Fatal(err) + } + if newURL != "" { + t.Fatalf("unexpected newURL: %s", newURL) + } + + pollErrCh := make(chan error) + go func() { + err := c1.PollNetMap(ctx, -1, func(netMap *NetworkMap) {}) + pollErrCh <- err + }() + + select { + case err := <-pollErrCh: + t.Fatal(err) + default: + } + + c2, err := NewDirect(Options{ + ServerURL: httpsrv.URL, + HTTPC: httpsrv.Client(), + Logf: func(fmt string, args ...interface{}) { + t.Helper() + t.Logf("c2: "+fmt, args...) + }, + Persist: c1.GetPersist(), + Hostinfo: &hi, + NewDecompressor: func() (Decompressor, error) { + return zstd.NewReader(nil) + }, + KeepAlive: true, + }) + if err != nil { + t.Fatal(err) + } + authURL, err = c2.TryLogin(ctx, nil, 0) + if err != nil { + t.Fatal(err) + } + if authURL != "" { + t.Errorf("unexpected authURL %s", authURL) + } + + err = c2.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}) + if err != nil { + t.Fatal(err) + } + + select { + case err := <-pollErrCh: + t.Logf("expected poll error: %v", err) + case <-time.After(5 * time.Second): + t.Fatal("first client poll failed to close") + } +} + +func TestClientsReusingOldKey(t *testing.T) { + tmpdir, err := ioutil.TempDir("", "control-test-") + if err != nil { + t.Fatal(err) + } + var server *control.Server + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server.ServeHTTP(w, r) + })) + server, err = control.New(tmpdir, httpsrv.URL, true) + if err != nil { + t.Fatal(err) + } + server.QuietLogging = true + defer func() { + httpsrv.CloseClientConnections() + httpsrv.Close() + os.RemoveAll(tmpdir) + }() + + hi 
:= NewHostinfo() + hi.FrontendLogID = "go-test-only" + hi.BackendLogID = "go-test-only" + genOpts := func() Options { + return Options{ + ServerURL: httpsrv.URL, + HTTPC: httpsrv.Client(), + //TimeNow: s.control.TimeNow, + Logf: func(fmt string, args ...interface{}) { + t.Helper() + t.Logf("c1: "+fmt, args...) + }, + Hostinfo: &hi, + } + } + + // Login with a new node key. This requires authorization. + c1, err := NewDirect(genOpts()) + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + authURL, err := c1.TryLogin(ctx, nil, 0) + if err != nil { + t.Fatal(err) + } + const user = "testuser1@tailscale.onmicrosoft.com" + postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + newURL, err := c1.WaitLoginURL(ctx, authURL) + if err != nil { + t.Fatal(err) + } + if newURL != "" { + t.Fatalf("unexpected newURL: %s", newURL) + } + + if err := c1.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}); err != nil { + t.Fatal(err) + } + + newPrivKey := func(t *testing.T) wgcfg.PrivateKey { + t.Helper() + k, err := wgcfg.NewPrivateKey() + if err != nil { + t.Fatal(err) + } + return *k + } + + // Replace the previous key with a new key. 
+ persist1 := c1.GetPersist() + persist2 := Persist{ + PrivateMachineKey: persist1.PrivateMachineKey, + OldPrivateNodeKey: persist1.PrivateNodeKey, + PrivateNodeKey: newPrivKey(t), + } + opts := genOpts() + opts.Persist = persist2 + + c1, err = NewDirect(opts) + if err != nil { + t.Fatal(err) + } + if authURL, err := c1.TryLogin(ctx, nil, 0); err != nil { + t.Fatal(err) + } else if authURL == "" { + t.Fatal("expected authURL for reused oldNodeKey, got none") + } else { + postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil { + t.Fatal(err) + } else if newURL != "" { + t.Fatalf("unexpected newURL: %s", newURL) + } + } + if p := c1.GetPersist(); p.PrivateNodeKey != opts.Persist.PrivateNodeKey { + t.Error("unexpected node key change") + } else { + persist2 = p + } + + // Here we simulate a client using using old persistant data. + // We use the key we have already replaced as the old node key. + // This requires the user to authenticate. + persist3 := Persist{ + PrivateMachineKey: persist1.PrivateMachineKey, + OldPrivateNodeKey: persist1.PrivateNodeKey, + PrivateNodeKey: newPrivKey(t), + } + opts = genOpts() + opts.Persist = persist3 + + c1, err = NewDirect(opts) + if err != nil { + t.Fatal(err) + } + if authURL, err := c1.TryLogin(ctx, nil, 0); err != nil { + t.Fatal(err) + } else if authURL == "" { + t.Fatal("expected authURL for reused oldNodeKey, got none") + } else { + postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil { + t.Fatal(err) + } else if newURL != "" { + t.Fatalf("unexpected newURL: %s", newURL) + } + } + if err := c1.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}); err != nil { + t.Fatal(err) + } + + // At this point, there should only be one node for the machine key + // registered as active in the server. 
+ mkey := tailcfg.MachineKey(*persist1.PrivateMachineKey.Public()) + nodeIDs, err := server.DB().MachineNodes(mkey) + if err != nil { + t.Fatal(err) + } + if len(nodeIDs) != 1 { + t.Logf("active nodes for machine key %v:", mkey) + for i, nodeID := range nodeIDs { + nodeKey := server.DB().NodeKey(nodeID) + t.Logf("\tnode %d: id=%v, key=%v", i, nodeID, nodeKey) + } + t.Fatalf("want 1 active node for the client machine, got %d", len(nodeIDs)) + } + + // Now try the previous node key. It should fail. + opts = genOpts() + opts.Persist = persist2 + c1, err = NewDirect(opts) + if err != nil { + t.Fatal(err) + } + // TODO(crawshaw): make this return an actual error. + // Have cfgdb track expired keys, and when an expired key is reused + // produce an error. + if authURL, err := c1.TryLogin(ctx, nil, 0); err != nil { + t.Fatal(err) + } else if authURL == "" { + t.Fatal("expected authURL for reused nodeKey, got none") + } else { + postAuthURL(t, ctx, httpsrv.Client(), user, authURL) + if newURL, err := c1.WaitLoginURL(ctx, authURL); err != nil { + t.Fatal(err) + } else if newURL != "" { + t.Fatalf("unexpected newURL: %s", newURL) + } + } + if err := c1.PollNetMap(ctx, 1, func(netMap *NetworkMap) {}); err != nil { + t.Fatal(err) + } + if nodeIDs, err := server.DB().MachineNodes(mkey); err != nil { + t.Fatal(err) + } else if len(nodeIDs) != 1 { + t.Fatalf("want 1 active node for the client machine, got %d", len(nodeIDs)) + } +} diff --git a/control/controlclient/netmap.go b/control/controlclient/netmap.go new file mode 100644 index 000000000..3e25cefde --- /dev/null +++ b/control/controlclient/netmap.go @@ -0,0 +1,294 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package controlclient + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "log" + "net" + "runtime" + "strings" + "time" + + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/tailcfg" + "tailscale.com/wgengine/filter" +) + +type NetworkMap struct { + // Core networking + + NodeKey tailcfg.NodeKey + PrivateKey wgcfg.PrivateKey + Expiry time.Time + Addresses []wgcfg.CIDR + LocalPort uint16 // used for debugging + MachineStatus tailcfg.MachineStatus + Peers []tailcfg.Node + DNS []wgcfg.IP + DNSDomains []string + Hostinfo tailcfg.Hostinfo + PacketFilter filter.Matches + + // ACLs + + User tailcfg.UserID + Domain string + // TODO(crawshaw): reduce UserProfiles to []tailcfg.UserProfile? + // There are lots of ways to slice this data, leave it up to users. + UserProfiles map[tailcfg.UserID]tailcfg.UserProfile + Roles []tailcfg.Role + // TODO(crawshaw): Groups []tailcfg.Group + // TODO(crawshaw): Capabilities []tailcfg.Capability +} + +func (n *NetworkMap) Equal(n2 *NetworkMap) bool { + // TODO(crawshaw): this is crude, but is an easy way to avoid bugs. 
+ b, err := json.Marshal(n) + if err != nil { + panic(err) + } + b2, err := json.Marshal(n2) + if err != nil { + panic(err) + } + return bytes.Equal(b, b2) +} + +func (n *NetworkMap) isEmpty() bool { + if n == nil { + return true + } + return n.Equal(&NetworkMap{}) +} + +func (nm NetworkMap) String() string { + return nm.Concise() +} + +func keyString(key [32]byte) string { + b64 := base64.StdEncoding.EncodeToString(key[:]) + abbrev := "invalid" + if len(b64) == 44 { + abbrev = b64[0:4] + "…" + b64[39:43] + } + return fmt.Sprintf("[%s]", abbrev) +} + +func (nm *NetworkMap) Concise() string { + buf := new(strings.Builder) + fmt.Fprintf(buf, "NetworkMap: self: %v auth=%v :%v %v\n", + keyString(nm.NodeKey), nm.MachineStatus, + nm.LocalPort, nm.Addresses) + for _, p := range nm.Peers { + aip := make([]string, len(p.AllowedIPs)) + for i, a := range p.AllowedIPs { + aip[i] = fmt.Sprint(a) + } + u := fmt.Sprint(p.User) + if strings.HasPrefix(u, "userid:") { + u = "u:" + u[7:] + } + f1 := fmt.Sprintf(" %v %-6v %v", + keyString(p.Key), u, p.Endpoints) + f2 := fmt.Sprintf(" %*v\n", 70-len(f1), + strings.Join(aip, " ")) + fmt.Fprintf(buf, "%s%s", f1, f2) + } + return buf.String() +} + +func (nm *NetworkMap) JSON() string { + b, err := json.MarshalIndent(*nm, "", " ") + if err != nil { + return fmt.Sprintf("[json error: %v]", err) + } + return string(b) +} + +// TODO(apenwarr): delete me once relaynode doesn't need this anymore. +// control.go:userMap() supercedes it. This does not belong in the client. +func (nm *NetworkMap) UserMap() map[string][]filter.IP { + // Make a lookup table of roles + log.Printf("roles list is: %v\n", nm.Roles) + roles := make(map[tailcfg.RoleID]tailcfg.Role) + for _, role := range nm.Roles { + roles[role.ID] = role + } + + // First, go through each node's addresses and make a lookup table + // of IP->User. 
+ fwd := make(map[wgcfg.IP]string) + for _, node := range nm.Peers { + for _, addr := range node.Addresses { + if addr.Mask == 32 && addr.IP.Is4() { + user, ok := nm.UserProfiles[node.User] + if ok { + fwd[addr.IP] = user.LoginName + } + } + } + } + + // Next, reverse the mapping into User->IP. + rev := make(map[string][]filter.IP) + for ip, username := range fwd { + ip4 := ip.To4() + if ip4 != nil { + fip := filter.NewIP(net.IP(ip4)) + rev[username] = append(rev[username], fip) + } + } + + // Now add roles, which are lists of users, and therefore lists + // of those users' IP addresses. + for _, user := range nm.UserProfiles { + for _, roleid := range user.Roles { + role, ok := roles[roleid] + if ok { + rolename := "role:" + role.Name + rev[rolename] = append(rev[rolename], rev[user.LoginName]...) + } + } + } + + //log.Printf("Usermap is: %v\n", rev) + return rev +} + +var iOS = runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") +var keepalive = !iOS + +const ( + UAllowSingleHosts = 1 << iota + UAllowSubnetRoutes + UAllowDefaultRoute + UHackDefaultRoute + + UDefault = 0 +) + +// Several programs need to parse these arguments into uflags, so let's +// centralize it here. 
+func UFlagsHelper(uroutes, rroutes, droutes bool) int { + uflags := 0 + if uroutes { + uflags |= UAllowSingleHosts + } + if rroutes { + uflags |= UAllowSubnetRoutes + } + if droutes { + uflags |= UAllowDefaultRoute + } + return uflags +} + +func (nm *NetworkMap) UAPI(uflags int, dnsOverride []wgcfg.IP) string { + wgcfg, err := nm.WGCfg(uflags, dnsOverride) + if err != nil { + log.Fatalf("WGCfg() failed unexpectedly: %v\n", err) + } + s, err := wgcfg.ToUAPI() + if err != nil { + log.Fatalf("ToUAPI() failed unexpectedly: %v\n", err) + } + return s +} + +func (nm *NetworkMap) WGCfg(uflags int, dnsOverride []wgcfg.IP) (*wgcfg.Config, error) { + s := nm._WireGuardConfig(uflags, dnsOverride, true) + return wgcfg.FromWgQuick(s, "tailscale") +} + +// TODO(apenwarr): This mode is dangerous. +// Discarding the extra endpoints is almost universally the wrong choice. +// Except that plain wireguard can't handle a peer with multiple endpoints. +// (Yet?) +func (nm *NetworkMap) WireGuardConfigOneEndpoint(uflags int, dnsOverride []wgcfg.IP) string { + return nm._WireGuardConfig(uflags, dnsOverride, false) +} + +func (nm *NetworkMap) _WireGuardConfig(uflags int, dnsOverride []wgcfg.IP, allEndpoints bool) string { + buf := new(strings.Builder) + fmt.Fprintf(buf, "[Interface]\n") + fmt.Fprintf(buf, "PrivateKey = %s\n", base64.StdEncoding.EncodeToString(nm.PrivateKey[:])) + if len(nm.Addresses) > 0 { + fmt.Fprintf(buf, "Address = ") + for i, cidr := range nm.Addresses { + if i > 0 { + fmt.Fprintf(buf, ", ") + } + fmt.Fprintf(buf, "%s", cidr) + } + fmt.Fprintf(buf, "\n") + } + fmt.Fprintf(buf, "ListenPort = %d\n", nm.LocalPort) + if len(dnsOverride) > 0 { + dnss := []string{} + for _, ip := range dnsOverride { + dnss = append(dnss, ip.String()) + } + fmt.Fprintf(buf, "DNS = %s\n", strings.Join(dnss, ",")) + } + fmt.Fprintf(buf, "\n") + + for i, peer := range nm.Peers { + if (uflags&UAllowSingleHosts) == 0 && len(peer.AllowedIPs) < 2 { + log.Printf("wgcfg: %v skipping a single-host 
peer.\n", peer.Key.AbbrevString()) + continue + } + if i > 0 { + fmt.Fprintf(buf, "\n") + } + fmt.Fprintf(buf, "[Peer]\n") + fmt.Fprintf(buf, "PublicKey = %s\n", base64.StdEncoding.EncodeToString(peer.Key[:])) + if len(peer.Endpoints) > 0 { + if len(peer.Endpoints) == 1 { + fmt.Fprintf(buf, "Endpoint = %s", peer.Endpoints[0]) + } else if allEndpoints { + // TODO(apenwarr): This mode is incompatible. + // Normal wireguard clients don't know how to + // parse it (yet?) + fmt.Fprintf(buf, "Endpoint = %s", + strings.Join(peer.Endpoints, ",")) + } else { + fmt.Fprintf(buf, "Endpoint = %s # other endpoints: %s", + peer.Endpoints[0], + strings.Join(peer.Endpoints[1:], ", ")) + } + buf.WriteByte('\n') + } + var aips []string + for _, allowedIP := range peer.AllowedIPs { + aip := allowedIP.String() + if allowedIP.Mask == 0 { + if (uflags & UAllowDefaultRoute) == 0 { + log.Printf("wgcfg: %v skipping default route\n", peer.Key.AbbrevString()) + continue + } + if (uflags & UHackDefaultRoute) != 0 { + aip = "10.0.0.0/8" + log.Printf("wgcfg: %v converting default route => %v\n", peer.Key.AbbrevString(), aip) + } + } else if allowedIP.Mask < 32 { + if (uflags & UAllowSubnetRoutes) == 0 { + log.Printf("wgcfg: %v skipping subnet route\n", peer.Key.AbbrevString()) + continue + } + } + aips = append(aips, aip) + } + fmt.Fprintf(buf, "AllowedIPs = %s\n", strings.Join(aips, ", ")) + if keepalive { + fmt.Fprintf(buf, "PersistentKeepalive = 25\n") + } + } + + return buf.String() +} diff --git a/control/policy/policy.go b/control/policy/policy.go new file mode 100644 index 000000000..8e3d59b81 --- /dev/null +++ b/control/policy/policy.go @@ -0,0 +1,227 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package policy + +import ( + "bytes" + "errors" + "fmt" + "github.com/tailscale/hujson" + "net" + "strconv" + "strings" + "tailscale.com/wgengine/filter" +) + +type IP = filter.IP + +const IPAny = filter.IPAny + +type row struct { + Action string + Users []string + Ports []string +} + +type Policy struct { + ACLs []row + Groups map[string][]string + Hosts map[string]IP +} + +func lineAndColumn(b []byte, ofs int64) (line, col int) { + line = 1 + for _, c := range b[:ofs] { + if c == '\n' { + col = 1 + line++ + } else { + col++ + } + } + return line, col +} + +func betterUnmarshal(b []byte, obj interface{}) error { + bio := bytes.NewReader(b) + d := hujson.NewDecoder(bio) + d.DisallowUnknownFields() + err := d.Decode(obj) + if err != nil { + switch ee := err.(type) { + case *hujson.SyntaxError: + row, col := lineAndColumn(b, ee.Offset) + return fmt.Errorf("line %d col %d: %v", row, col, ee) + default: + return fmt.Errorf("parser: %v", err) + } + } + return nil +} + +func Parse(acljson string) (*Policy, error) { + p := &Policy{} + err := betterUnmarshal([]byte(acljson), p) + if err != nil { + return nil, err + } + + // Check syntax with an empty usermap to start with. + // The caller might not have a valid usermap at startup, but we still + // want to check that the acljson doesn't have any syntax errors + // as early as possible. When the usermap updates later, it won't + // add any new syntax errors. + // + // TODO(apenwarr): change unmarshal code to detect syntax errors above. + // Right now some of the sub-objects aren't parsed until .Expand(). 
+ emptyUserMap := make(map[string][]IP) + _, err = p.Expand(emptyUserMap) + if err != nil { + return nil, err + } + + return p, nil +} + +func parseHostPortRange(hostport string) (host string, ports []filter.PortRange, err error) { + hl := strings.Split(hostport, ":") + if len(hl) != 2 { + return "", nil, errors.New("hostport must have exactly one colon(:)") + } + host = hl[0] + portlist := hl[1] + + if portlist == "*" { + // Special case: permit hostname:* as a port wildcard. + ports = append(ports, filter.PortRangeAny) + return host, ports, nil + } + + pl := strings.Split(portlist, ",") + for _, pp := range pl { + if len(pp) == 0 { + return "", nil, fmt.Errorf("invalid port list: %#v", portlist) + } + + pr := strings.Split(pp, "-") + if len(pr) > 2 { + return "", nil, fmt.Errorf("port range %#v: too many dashes(-)", pp) + } + + var first, last uint64 + first, err := strconv.ParseUint(pr[0], 10, 16) + if err != nil { + return "", nil, fmt.Errorf("port range %#v: invalid first integer", pp) + } + + if len(pr) >= 2 { + last, err = strconv.ParseUint(pr[1], 10, 16) + if err != nil { + return "", nil, fmt.Errorf("port range %#v: invalid last integer", pp) + } + } else { + last = first + } + + if first == 0 { + return "", nil, fmt.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", pp) + } + + if first > last { + return "", nil, fmt.Errorf("port range %#v: first port must be >= last port", pp) + } + + ports = append(ports, filter.PortRange{uint16(first), uint16(last)}) + } + + return host, ports, nil +} + +func (p *Policy) Expand(usermap map[string][]IP) (filter.Matches, error) { + lcusermap := make(map[string][]IP) + for k, v := range usermap { + k = strings.ToLower(k) + lcusermap[k] = v + } + + for k, userlist := range p.Groups { + k = strings.ToLower(k) + if !strings.HasPrefix(k, "group:") { + return nil, fmt.Errorf("Group[%#v]: group names must start with 'group:'", k) + } + for _, u := range userlist { + uips := lcusermap[u] + lcusermap[k] = 
append(lcusermap[k], uips...) + } + } + + hosts := p.Hosts + + var out filter.Matches + for _, acl := range p.ACLs { + if acl.Action != "accept" { + return nil, fmt.Errorf("Action=%#v is not supported", acl.Action) + } + + var srcs []IP + for _, user := range acl.Users { + user = strings.ToLower(user) + if user == "*" { + srcs = append(srcs, IPAny) + continue + } else if strings.Contains(user, "@") || + strings.HasPrefix(user, "role:") || + strings.HasPrefix(user, "group:") { + // fine if the requested user doesn't exist. + // we don't want to crash ACL parsing just + // because a previously authed user gets + // deleted. We'll silently ignore it and + // no firewall rules are needed. + // TODO(apenwarr): maybe print a warning? + for _, ip := range lcusermap[user] { + if ip != IPAny { + srcs = append(srcs, ip) + } + } + } else { + return nil, fmt.Errorf("wgengine/filter: invalid username: %q: needs @domain or group: or role:", user) + } + } + + var dsts []filter.IPPortRange + for _, hostport := range acl.Ports { + host, ports, err := parseHostPortRange(hostport) + if err != nil { + return nil, fmt.Errorf("Ports=%#v: %v", hostport, err) + } + ip := net.ParseIP(host) + ipv, ok := hosts[host] + if ok { + // matches an alias; ipv is now valid + } else if ip != nil && ip.IsUnspecified() { + // For clarity, reject 0.0.0.0 as an input + return nil, fmt.Errorf("Ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", hostport) + } else if ip == nil && host == "*" { + // User explicitly requested wildcard dst ip + ipv = IPAny + } else { + if ip != nil { + ip = ip.To4() + } + if ip == nil || len(ip) != 4 { + return nil, fmt.Errorf("Ports=%#v: %#v: invalid IPv4 address", hostport, host) + } + ipv = filter.NewIP(ip) + } + + for _, pr := range ports { + dsts = append(dsts, filter.IPPortRange{ipv, pr}) + } + } + + out = append(out, filter.Match{DstPorts: dsts, SrcIPs: srcs}) + } + return out, nil +} diff --git a/control/policy/policy_test.go 
b/control/policy/policy_test.go new file mode 100644 index 000000000..4d20e350d --- /dev/null +++ b/control/policy/policy_test.go @@ -0,0 +1,156 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package policy + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/wgengine/filter" +) + +type PortRange = filter.PortRange +type IPPortRange = filter.IPPortRange + +var syntax_errors = []string{ + `{ "ACLs": []! }`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "xPorts": ["100.122.98.50:22"]} + ]}`, + + `{ "ACLs": [ + {"Action": "drop", "Users": [], "Ports": ["100.122.98.50:22"]} + ]}`, + + `{ "ACLs": [ + {"Users": [], "Ports": ["100.122.98.50:22"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4:0"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["0.0.0.0:12"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["*:0"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4:5:6"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4.5:12"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4::12"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4:0-0"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4:1-10,2-"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4:1-10,*"]} + ]}`, + + `{ "ACLs": [ + {"Action": "accept", "Users": [], "Ports": ["1.2.3.4,5.6.7.8:1-10"]} + ]}`, + + `{ "Hosts": {"mailserver": "not-an-ip"} }`, + + `{ "Hosts": {"mailserver": "1.2.3.4:55"} }`, + + `{ "xGroups": { + "bob": ["user1", "user2"] + }}`, 
+} + +func TestSyntaxErrors(t *testing.T) { + for _, s := range syntax_errors { + _, err := Parse(s) + if err == nil { + t.Fatalf("Parse passed when it shouldn't. json:\n---\n%v\n---", s) + } + } +} + +func ippr(ip IP, start, end uint16) []IPPortRange { + return []IPPortRange{ + IPPortRange{ip, PortRange{start, end}}, + } +} + +func TestPolicy(t *testing.T) { + // Check ACL table parsing + + usermap := map[string][]IP{ + "A@b.com": []IP{0x08010101, 0x08020202}, + "role:admin": []IP{0x02020202}, + "user1@org": []IP{0x99010101, 0x99010102}, + // user2 is intentionally missing + "user3@org": []IP{0x99030303}, + "user4@org": []IP{}, + } + want := filter.Matches{ + {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{ + IPPortRange{0x01020304, PortRange{22, 22}}, + IPPortRange{0x05060708, PortRange{23, 24}}, + IPPortRange{0x05060708, PortRange{27, 28}}, + }}, + {SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)}, + {SrcIPs: []IP{0}, DstPorts: []IPPortRange{ + IPPortRange{0x647a6232, PortRange{0, 65535}}, + IPPortRange{0, PortRange{443, 443}}, + }}, + {SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)}, + } + + p, err := Parse(` +{ + // Test comment + "Hosts": { + "h1": "1.2.3.4", /* test comment */ + "h2": "5.6.7.8" + }, + "Groups": { + "group:eng": ["user1@org", "user2@org", "user3@org", "user4@org"] + }, + "ACLs": [ + {"Action": "accept", "Users": ["a@b.com"], "Ports": ["h1:22", "h2:23-24,27-28"]}, + {"Action": "accept", "Users": ["role:Admin"], "Ports": ["8.1.1.1:22"]}, + {"Action": "accept", "Users": ["*"], "Ports": ["100.122.98.50:*", "*:443"]}, + {"Action": "accept", "Users": ["group:eng"], "Ports": ["h1:999"]}, + ]} +`) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + matches, err := p.Expand(usermap) + if err != nil { + t.Fatalf("Expand failed: %v", err) + } + if diff := cmp.Diff(want, matches); diff != "" { + t.Fatalf("Expand mismatch (-want +got):\n%s", diff) + } +} diff --git 
a/derp/derp_client.go b/derp/derp_client.go new file mode 100644 index 000000000..551548b92 --- /dev/null +++ b/derp/derp_client.go @@ -0,0 +1,182 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package derp + +import ( + "bufio" + "crypto/rand" + "encoding/json" + "fmt" + "io" + "net" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" +) + +type Client struct { + serverKey [32]byte + privateKey [32]byte // TODO(crawshaw): make this wgcfg.PrivateKey? + publicKey [32]byte + logf func(format string, args ...interface{}) + netConn net.Conn + conn *bufio.ReadWriter +} + +func NewClient(privateKey [32]byte, netConn net.Conn, conn *bufio.ReadWriter, logf func(format string, args ...interface{})) (*Client, error) { + c := &Client{ + privateKey: privateKey, + logf: logf, + netConn: netConn, + conn: conn, + } + curve25519.ScalarBaseMult(&c.publicKey, &c.privateKey) + + if err := c.recvServerKey(); err != nil { + return nil, fmt.Errorf("derp.Client: failed to receive server key: %v", err) + } + if err := c.sendClientKey(); err != nil { + return nil, fmt.Errorf("derp.Client: failed to send client key: %v", err) + } + _, err := c.recvServerInfo() + if err != nil { + return nil, fmt.Errorf("derp.Client: failed to receive server info: %v", err) + } + + return c, nil +} + +func (c *Client) recvServerKey() error { + gotMagic, err := readUint32(c.conn, 0xffffffff) + if err != nil { + return err + } + if gotMagic != magic { + return fmt.Errorf("bad magic %x, want %x", gotMagic, magic) + } + if err := readType(c.conn.Reader, typeServerKey); err != nil { + return err + } + if _, err := io.ReadFull(c.conn, c.serverKey[:]); err != nil { + return err + } + return nil +} + +func (c *Client) recvServerInfo() (*serverInfo, error) { + if err := readType(c.conn.Reader, typeServerInfo); err != nil { + return nil, err + } + var nonce 
[24]byte + if _, err := io.ReadFull(c.conn, nonce[:]); err != nil { + return nil, fmt.Errorf("nonce: %v", err) + } + msgLen, err := readUint32(c.conn, oneMB) + if err != nil { + return nil, fmt.Errorf("msglen: %v", err) + } + msgbox := make([]byte, msgLen) + if _, err := io.ReadFull(c.conn, msgbox); err != nil { + return nil, fmt.Errorf("msgbox: %v", err) + } + msg, ok := box.Open(nil, msgbox, &nonce, &c.serverKey, &c.privateKey) + if !ok { + return nil, fmt.Errorf("msgbox: cannot open len=%d with server key %x", msgLen, c.serverKey[:]) + } + info := new(serverInfo) + if err := json.Unmarshal(msg, info); err != nil { + return nil, fmt.Errorf("msg: %v", err) + } + return info, nil +} + +func (c *Client) sendClientKey() error { + var nonce [24]byte + if _, err := rand.Read(nonce[:]); err != nil { + return err + } + msg := []byte("{}") // no clientInfo for now + msgbox := box.Seal(nil, msg, &nonce, &c.serverKey, &c.privateKey) + + if _, err := c.conn.Write(c.publicKey[:]); err != nil { + return err + } + if _, err := c.conn.Write(nonce[:]); err != nil { + return err + } + if err := putUint32(c.conn.Writer, uint32(len(msgbox))); err != nil { + return err + } + if _, err := c.conn.Write(msgbox); err != nil { + return err + } + return c.conn.Flush() +} + +func (c *Client) Send(dstKey [32]byte, msg []byte) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("derp.Send: %v", err) + } + }() + + if err := c.conn.WriteByte(typeSendPacket); err != nil { + return err + } + if _, err := c.conn.Write(dstKey[:]); err != nil { + return err + } + msgLen := uint32(len(msg)) + if int(msgLen) != len(msg) { + return fmt.Errorf("packet too big: %d", len(msg)) + } + if err := putUint32(c.conn.Writer, msgLen); err != nil { + return err + } + if _, err := c.conn.Write(msg); err != nil { + return err + } + return c.conn.Flush() +} + +func (c *Client) Recv(b []byte) (n int, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("derp.Recv: %v", err) + } + }() + 
+loop: + for { + c.netConn.SetReadDeadline(time.Now().Add(120 * time.Second)) + packetType, err := c.conn.ReadByte() + if err != nil { + return 0, err + } + switch packetType { + case typeKeepAlive: + continue + case typeRecvPacket: + break loop + default: + return 0, fmt.Errorf("derp.Recv: unknown packet type %d", packetType) + } + } + + packetLen, err := readUint32(c.conn.Reader, oneMB) + if err != nil { + return 0, err + } + if int(packetLen) > len(b) { + // TODO(crawshaw): discard the packet + return 0, io.ErrShortBuffer + } + b = b[:packetLen] + if _, err := io.ReadFull(c.conn, b); err != nil { + return 0, err + } + return int(packetLen), nil +} diff --git a/derp/derp_server.go b/derp/derp_server.go new file mode 100644 index 000000000..14948836c --- /dev/null +++ b/derp/derp_server.go @@ -0,0 +1,380 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package derp + +// TODO(crawshaw): revise protocol so unknown type packets have a predictable length for skipping. +// TODO(crawshaw): send srcKey with packets to clients? +// TODO(crawshaw): with predefined serverKey in clients and HMAC on packets we could skip TLS + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "math/big" + "net" + "sync" + "time" + + "golang.org/x/crypto/curve25519" + "golang.org/x/crypto/nacl/box" +) + +const magic = 0x44c55250 // "DERP" with a non-ASCII high-bit + +const ( + typeServerKey = 0x01 + typeServerInfo = 0x02 + typeSendPacket = 0x03 + typeRecvPacket = 0x04 + typeKeepAlive = 0x05 +) + +const keepAlive = 60 * time.Second + +var bin = binary.BigEndian + +const oneMB = 1 << 20 + +type Server struct { + privateKey [32]byte // TODO(crawshaw): make this wgcfg.PrivateKey? 
+ publicKey [32]byte + logf func(format string, args ...interface{}) + + mu sync.Mutex + netConns map[net.Conn]chan struct{} + clients map[[32]byte]*client +} + +func NewServer(privateKey [32]byte, logf func(format string, args ...interface{})) *Server { + s := &Server{ + privateKey: privateKey, + logf: logf, + clients: make(map[[32]byte]*client), + netConns: make(map[net.Conn]chan struct{}), + } + curve25519.ScalarBaseMult(&s.publicKey, &s.privateKey) + return s +} + +func (s *Server) Close() error { + var closedChs []chan struct{} + + s.mu.Lock() + for netConn, closed := range s.netConns { + netConn.Close() + closedChs = append(closedChs, closed) + } + s.mu.Unlock() + + for _, closed := range closedChs { + <-closed + } + + return nil +} + +func (s *Server) Accept(netConn net.Conn, conn *bufio.ReadWriter) { + closed := make(chan struct{}) + + s.mu.Lock() + s.netConns[netConn] = closed + s.mu.Unlock() + + defer func() { + netConn.Close() + close(closed) + + s.mu.Lock() + delete(s.netConns, netConn) + s.mu.Unlock() + }() + + if err := s.accept(netConn, conn); err != nil { + s.logf("derp: %s: %v", netConn.RemoteAddr(), err) + } +} + +func (s *Server) accept(netConn net.Conn, conn *bufio.ReadWriter) error { + netConn.SetDeadline(time.Now().Add(10 * time.Second)) + if err := s.sendServerKey(conn); err != nil { + return fmt.Errorf("send server key: %v", err) + } + netConn.SetDeadline(time.Now().Add(10 * time.Second)) + clientKey, clientInfo, err := s.recvClientKey(conn) + if err != nil { + return fmt.Errorf("receive client key: %v", err) + } + if err := s.verifyClient(clientKey, clientInfo); err != nil { + return fmt.Errorf("client %x rejected: %v", clientKey, err) + } + + // At this point we trust the client so we don't time out. 
+ netConn.SetDeadline(time.Time{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := &client{ + key: clientKey, + netConn: netConn, + conn: conn, + } + if clientInfo != nil { + c.info = *clientInfo + } + go func() { + if err := c.keepAlive(ctx); err != nil { + s.logf("derp: %s: client %x: keep alive failed: %v", netConn.RemoteAddr(), c.key, err) + } + }() + + defer func() { + s.mu.Lock() + curClient := s.clients[c.key] + if curClient != nil && curClient.conn == conn { + s.logf("derp: %s: client %x: removing connection", netConn.RemoteAddr(), c.key) + delete(s.clients, c.key) + } + s.mu.Unlock() + }() + + // Hold mu while we add the new client to the clients list and under + // the same acquisition send server info. This ensure that both: + // 1. by the time the client receives the server info, it can be addressed. + // 2. the server info is the very first + c.mu.Lock() + s.mu.Lock() + oldClient := s.clients[c.key] + s.clients[c.key] = c + s.mu.Unlock() + if err := s.sendServerInfo(conn, clientKey); err != nil { + return fmt.Errorf("send server info: %v", err) + } + c.mu.Unlock() + + if oldClient == nil { + s.logf("derp: %s: client %x: adding connection", netConn.RemoteAddr(), c.key) + } else { + oldClient.netConn.Close() + s.logf("derp: %s: client %x: adding connection, replacing %s", netConn.RemoteAddr(), c.key, oldClient.netConn.RemoteAddr()) + } + + for { + dstKey, contents, err := s.recvPacket(c.conn) + if err != nil { + return fmt.Errorf("client %x: recv: %v", c.key, err) + } + + s.mu.Lock() + dst := s.clients[dstKey] + s.mu.Unlock() + + if dst == nil { + s.logf("derp: %s: client %x: dropping packet for unknown %x", netConn.RemoteAddr(), c.key, dstKey) + continue + } + + dst.mu.Lock() + err = s.sendPacket(dst.conn, c.key, contents) + dst.mu.Unlock() + + if err != nil { + s.logf("derp: %s: client %x: dropping packet for %x: %v", netConn.RemoteAddr(), c.key, dstKey, err) + + // If we cannot send to a destination, shut it down. 
+ // Let its receive loop do the cleanup. + s.mu.Lock() + if s.clients[dstKey].conn == dst.conn { + s.clients[dstKey].netConn.Close() + } + s.mu.Unlock() + } + } +} + +func (s *Server) verifyClient(clientKey [32]byte, info *clientInfo) error { + // TODO(crawshaw): implement policy constraints on who can use the DERP server + return nil +} + +func (s *Server) sendServerKey(conn *bufio.ReadWriter) error { + if err := putUint32(conn, magic); err != nil { + return err + } + if err := conn.WriteByte(typeServerKey); err != nil { + return err + } + if _, err := conn.Write(s.publicKey[:]); err != nil { + return err + } + return conn.Flush() +} + +func (s *Server) sendServerInfo(conn *bufio.ReadWriter, clientKey [32]byte) error { + var nonce [24]byte + if _, err := rand.Read(nonce[:]); err != nil { + return err + } + msg := []byte("{}") // no serverInfo for now + msgbox := box.Seal(nil, msg, &nonce, &clientKey, &s.privateKey) + + if err := conn.WriteByte(typeServerInfo); err != nil { + return err + } + if _, err := conn.Write(nonce[:]); err != nil { + return err + } + if err := putUint32(conn, uint32(len(msgbox))); err != nil { + return err + } + if _, err := conn.Write(msgbox); err != nil { + return err + } + return conn.Flush() +} + +func (s *Server) recvClientKey(conn *bufio.ReadWriter) (clientKey [32]byte, info *clientInfo, err error) { + if _, err := io.ReadFull(conn, clientKey[:]); err != nil { + return [32]byte{}, nil, err + } + var nonce [24]byte + if _, err := io.ReadFull(conn, nonce[:]); err != nil { + return [32]byte{}, nil, fmt.Errorf("nonce: %v", err) + } + msgLen, err := readUint32(conn, oneMB) + if err != nil { + return [32]byte{}, nil, fmt.Errorf("msglen: %v", err) + } + msgbox := make([]byte, msgLen) + if _, err := io.ReadFull(conn, msgbox); err != nil { + return [32]byte{}, nil, fmt.Errorf("msgbox: %v", err) + } + msg, ok := box.Open(nil, msgbox, &nonce, &clientKey, &s.privateKey) + if !ok { + return [32]byte{}, nil, fmt.Errorf("msgbox: cannot open len=%d 
with client key %x", msgLen, clientKey[:]) + } + info = new(clientInfo) + if err := json.Unmarshal(msg, info); err != nil { + return [32]byte{}, nil, fmt.Errorf("msg: %v", err) + } + return clientKey, info, nil +} + +func (s *Server) sendPacket(conn *bufio.ReadWriter, srcKey [32]byte, contents []byte) error { + if err := conn.WriteByte(typeRecvPacket); err != nil { + return err + } + if err := putUint32(conn.Writer, uint32(len(contents))); err != nil { + return err + } + if _, err := conn.Write(contents); err != nil { + return err + } + return conn.Flush() +} + +func (s *Server) recvPacket(conn *bufio.ReadWriter) (dstKey [32]byte, contents []byte, err error) { + if err := readType(conn.Reader, typeSendPacket); err != nil { + return [32]byte{}, nil, err + } + if _, err := io.ReadFull(conn, dstKey[:]); err != nil { + return [32]byte{}, nil, err + } + packetLen, err := readUint32(conn.Reader, oneMB) + if err != nil { + return [32]byte{}, nil, err + } + contents = make([]byte, packetLen) + if _, err := io.ReadFull(conn, contents); err != nil { + return [32]byte{}, nil, err + } + return dstKey, contents, nil +} + +type client struct { + netConn net.Conn + key [32]byte + info clientInfo + + keepAliveTimer *time.Timer + keepAliveReset chan struct{} + + mu sync.Mutex + conn *bufio.ReadWriter +} + +func (c *client) keepAlive(ctx context.Context) error { + jitterMs, err := rand.Int(rand.Reader, big.NewInt(5000)) + if err != nil { + panic(err) + } + jitter := time.Duration(jitterMs.Int64()) * time.Millisecond + c.keepAliveTimer = time.NewTimer(keepAlive + jitter) + + for { + select { + case <-ctx.Done(): + return nil + case <-c.keepAliveReset: + if c.keepAliveTimer.Stop() { + <-c.keepAliveTimer.C + } + c.keepAliveTimer.Reset(keepAlive + jitter) + case <-c.keepAliveTimer.C: + c.mu.Lock() + err := c.conn.WriteByte(typeKeepAlive) + if err == nil { + err = c.conn.Flush() + } + c.mu.Unlock() + + if err != nil { + // TODO log + c.netConn.Close() + return err + } + } + } +} + +type 
clientInfo struct { +} + +type serverInfo struct { +} + +func readType(r *bufio.Reader, t uint8) error { + packetType, err := r.ReadByte() + if err != nil { + return err + } + if packetType != t { + return fmt.Errorf("bad packet type 0x%X, want 0x%X", packetType, t) + } + return nil +} + +func putUint32(w io.Writer, v uint32) error { + var b [4]byte + bin.PutUint32(b[:], v) + _, err := w.Write(b[:]) + return err +} + +func readUint32(r io.Reader, maxVal uint32) (uint32, error) { + b := make([]byte, 4) + if _, err := io.ReadFull(r, b); err != nil { + return 0, err + } + val := bin.Uint32(b) + if val > maxVal { + return 0, fmt.Errorf("uint32 %d exceeds limit %d", val, maxVal) + } + return val, nil +} diff --git a/derp/derp_test.go b/derp/derp_test.go new file mode 100644 index 000000000..940748f36 --- /dev/null +++ b/derp/derp_test.go @@ -0,0 +1,125 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package derp + +import ( + "bufio" + "crypto/rand" + "net" + "testing" + "time" + + "golang.org/x/crypto/curve25519" +) + +func TestSendRecv(t *testing.T) { + const numClients = 3 + var serverPrivateKey [32]byte + if _, err := rand.Read(serverPrivateKey[:]); err != nil { + t.Fatal(err) + } + var clientPrivateKeys [][32]byte + for i := 0; i < numClients; i++ { + var key [32]byte + if _, err := rand.Read(key[:]); err != nil { + t.Fatal(err) + } + clientPrivateKeys = append(clientPrivateKeys, key) + } + var clientKeys [][32]byte + for _, privKey := range clientPrivateKeys { + var key [32]byte + curve25519.ScalarBaseMult(&key, &privKey) + clientKeys = append(clientKeys, key) + } + + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + + var clientConns []net.Conn + for i := 0; i < numClients; i++ { + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + clientConns = append(clientConns, conn) + } + s := NewServer(serverPrivateKey, t.Logf) + defer s.Close() + for i := 0; i < numClients; i++ { + netConn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)) + go s.Accept(netConn, conn) + } + + var clients []*Client + var recvChs []chan []byte + errCh := make(chan error, 3) + for i := 0; i < numClients; i++ { + key := clientPrivateKeys[i] + netConn := clientConns[i] + conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)) + c, err := NewClient(key, netConn, conn, t.Logf) + if err != nil { + t.Fatalf("client %d: %v", i, err) + } + clients = append(clients, c) + recvChs = append(recvChs, make(chan []byte)) + + go func(i int) { + for { + b := make([]byte, 1<<16) + n, err := c.Recv(b) + if err != nil { + errCh <- err + return + } + b = b[:n] + recvChs[i] <- b + } + }(i) + } + + recv := func(i int, want string) { + t.Helper() + select { + case b := <-recvChs[i]: + if got := string(b); got != want { + 
t.Errorf("client1.Recv=%q, want %q", got, want) + } + case <-time.After(1 * time.Second): + t.Errorf("client%d.Recv, got nothing, want %q", i, want) + } + } + recvNothing := func(i int) { + t.Helper() + select { + case b := <-recvChs[0]: + t.Errorf("client%d.Recv=%q, want nothing", i, string(b)) + default: + } + } + + msg1 := []byte("hello 0->1\n") + if err := clients[0].Send(clientKeys[1], msg1); err != nil { + t.Fatal(err) + } + recv(1, string(msg1)) + recvNothing(0) + recvNothing(2) + + msg2 := []byte("hello 1->2\n") + if err := clients[1].Send(clientKeys[2], msg2); err != nil { + t.Fatal(err) + } + recv(2, string(msg2)) + recvNothing(0) + recvNothing(1) +} diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go new file mode 100644 index 000000000..16d747b02 --- /dev/null +++ b/derp/derphttp/derphttp_client.go @@ -0,0 +1,203 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package derphttp implements DERP-over-HTTP. +// +// This makes DERP look exactly like WebSockets. +// A server can implement DERP over HTTPS and even if the TLS connection +// intercepted using a fake root CA, unless the interceptor knows how to +// detect DERP packets, it will look like a web socket. +package derphttp + +import ( + "bufio" + "bytes" + "crypto/tls" + "errors" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/url" + "sync" + + "tailscale.com/derp" +) + +// Client is a DERP-over-HTTP client. +// +// It automatically reconnects on error retry. That is, a failed Send or +// Recv will report the error and not retry, but subsequent calls to +// Send/Recv will completely re-establish the connection. 
+type Client struct { + privateKey [32]byte + logf func(format string, args ...interface{}) + closed chan struct{} + url *url.URL + resp *http.Response + + netConnMu sync.Mutex + netConn net.Conn + + clientMu sync.Mutex + client *derp.Client +} + +func NewClient(privateKey [32]byte, serverURL string, logf func(format string, args ...interface{})) (c *Client, err error) { + u, err := url.Parse(serverURL) + if err != nil { + return nil, fmt.Errorf("derphttp.NewClient: %v", err) + } + + c = &Client{ + privateKey: privateKey, + logf: logf, + url: u, + closed: make(chan struct{}), + } + if _, err := c.connect("derphttp.NewClient"); err != nil { + c.logf("%v", err) + } + return c, nil +} + +func (c *Client) connect(caller string) (client *derp.Client, err error) { + select { + case <-c.closed: + return nil, ErrClientClosed + default: + } + + c.clientMu.Lock() + defer c.clientMu.Unlock() + + if c.client != nil { + return c.client, nil + } + + c.logf("%s: connecting", caller) + + var netConn net.Conn + defer func() { + if err != nil { + err = fmt.Errorf("%s connect: %v", caller, err) + if netConn := netConn; netConn != nil { + netConn.Close() + } + } + }() + + if c.url.Scheme == "https" { + port := c.url.Port() + if port == "" { + port = "443" + } + config := &tls.Config{} + var tlsConn *tls.Conn + tlsConn, err = tls.Dial("tcp", net.JoinHostPort(c.url.Host, port), config) + if tlsConn != nil { + netConn = tlsConn + } + } else { + netConn, err = net.Dial("tcp", c.url.Host) + } + if err != nil { + return nil, err + } + + c.netConnMu.Lock() + c.netConn = netConn + c.netConnMu.Unlock() + + conn := bufio.NewReadWriter(bufio.NewReader(netConn), bufio.NewWriter(netConn)) + + req, err := http.NewRequest("GET", c.url.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("Upgrade", "WebSocket") + req.Header.Set("Connection", "Upgrade") + if err := req.Write(conn); err != nil { + return nil, err + } + if err := conn.Flush(); err != nil { + return nil, err + } + + 
resp, err := http.ReadResponse(conn.Reader, req) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusSwitchingProtocols { + b, _ := ioutil.ReadAll(resp.Body) + resp.Body.Close() + return nil, fmt.Errorf("GET failed: %v: %s", err, b) + } + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + + derpClient, err := derp.NewClient(c.privateKey, netConn, conn, c.logf) + if err != nil { + return nil, err + } + c.resp = resp + c.client = derpClient + return c.client, nil +} + +func (c *Client) Send(dstKey [32]byte, b []byte) error { + client, err := c.connect("derphttp.Client.Send") + if err != nil { + return err + } + if err := client.Send(dstKey, b); err != nil { + c.close() + } + return err +} + +func (c *Client) Recv(b []byte) (int, error) { + client, err := c.connect("derphttp.Client.Recv") + if err != nil { + return 0, err + } + n, err := client.Recv(b) + if err != nil { + c.close() + } + return n, err +} + +func (c *Client) Close() error { + select { + case <-c.closed: + return ErrClientClosed + default: + } + close(c.closed) + c.close() + return nil +} + +func (c *Client) close() { + c.netConnMu.Lock() + netConn := c.netConn + c.netConnMu.Unlock() + + if netConn != nil { + netConn.Close() + } + + c.clientMu.Lock() + defer c.clientMu.Unlock() + if c.client == nil { + return + } + c.resp = nil + c.client = nil + c.netConnMu.Lock() + c.netConn = nil + c.netConnMu.Unlock() +} + +var ErrClientClosed = errors.New("derphttp.Client closed") diff --git a/derp/derphttp/derphttp_server.go b/derp/derphttp/derphttp_server.go new file mode 100644 index 000000000..0aef1c69e --- /dev/null +++ b/derp/derphttp/derphttp_server.go @@ -0,0 +1,35 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package derphttp + +import ( + "net/http" + + "tailscale.com/derp" +) + +func Handler(s *derp.Server) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") != "WebSocket" { + http.Error(w, "DERP requires connection upgrade", http.StatusUpgradeRequired) + return + } + w.Header().Set("Upgrade", "WebSocket") + w.Header().Set("Connection", "Upgrade") + w.WriteHeader(http.StatusSwitchingProtocols) + + h, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "HTTP does not support general TCP support", 500) + return + } + netConn, conn, err := h.Hijack() + if err != nil { + http.Error(w, "HTTP does not support general TCP support", 500) + return + } + s.Accept(netConn, conn) + }) +} diff --git a/derp/derphttp/derphttp_test.go b/derp/derphttp/derphttp_test.go new file mode 100644 index 000000000..4d8a147c6 --- /dev/null +++ b/derp/derphttp/derphttp_test.go @@ -0,0 +1,142 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package derphttp + +import ( + "crypto/rand" + "crypto/tls" + "net" + "net/http" + "sync" + "testing" + "time" + + "golang.org/x/crypto/curve25519" + "tailscale.com/derp" +) + +func TestSendRecv(t *testing.T) { + const numClients = 3 + var serverPrivateKey [32]byte + if _, err := rand.Read(serverPrivateKey[:]); err != nil { + t.Fatal(err) + } + var clientPrivateKeys [][32]byte + for i := 0; i < numClients; i++ { + var key [32]byte + if _, err := rand.Read(key[:]); err != nil { + t.Fatal(err) + } + clientPrivateKeys = append(clientPrivateKeys, key) + } + var clientKeys [][32]byte + for _, privKey := range clientPrivateKeys { + var key [32]byte + curve25519.ScalarBaseMult(&key, &privKey) + clientKeys = append(clientKeys, key) + } + + s := derp.NewServer(serverPrivateKey, t.Logf) + defer s.Close() + + httpsrv := &http.Server{ + TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler)), + Handler: Handler(s), + } + + ln, err := net.Listen("tcp4", "localhost:0") + if err != nil { + t.Fatal(err) + } + serverURL := "http://" + ln.Addr().String() + t.Logf("server URL: %s", serverURL) + + go func() { + if err := httpsrv.Serve(ln); err != nil { + if err == http.ErrServerClosed { + return + } + panic(err) + } + }() + + var clients []*Client + var recvChs []chan []byte + done := make(chan struct{}) + var wg sync.WaitGroup + defer func() { + close(done) + for _, c := range clients { + c.Close() + } + wg.Wait() + }() + for i := 0; i < numClients; i++ { + key := clientPrivateKeys[i] + c, err := NewClient(key, serverURL, t.Logf) + if err != nil { + t.Fatalf("client %d: %v", i, err) + } + clients = append(clients, c) + recvChs = append(recvChs, make(chan []byte)) + + wg.Add(1) + go func(i int) { + defer wg.Done() + for { + select { + case <-done: + return + default: + } + b := make([]byte, 1<<16) + n, err := c.Recv(b) + if err != nil { + t.Logf("client%d: %v", i, err) + break + } + b = b[:n] + recvChs[i] <- b + } + }(i) + } + + recv := func(i int, want string) { 
+ t.Helper() + select { + case b := <-recvChs[i]: + if got := string(b); got != want { + t.Errorf("client1.Recv=%q, want %q", got, want) + } + case <-time.After(1 * time.Second): + t.Errorf("client%d.Recv, got nothing, want %q", i, want) + } + } + recvNothing := func(i int) { + t.Helper() + select { + case b := <-recvChs[0]: + t.Errorf("client%d.Recv=%q, want nothing", i, string(b)) + default: + } + } + + msg1 := []byte("hello 0->1\n") + if err := clients[0].Send(clientKeys[1], msg1); err != nil { + t.Fatal(err) + } + recv(1, string(msg1)) + recvNothing(0) + recvNothing(2) + + msg2 := []byte("hello 1->2\n") + if err := clients[1].Send(clientKeys[2], msg2); err != nil { + t.Fatal(err) + } + recv(2, string(msg2)) + recvNothing(0) + recvNothing(1) + +} diff --git a/derp/doc.go b/derp/doc.go new file mode 100644 index 000000000..9b92eb79e --- /dev/null +++ b/derp/doc.go @@ -0,0 +1,13 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package derp implements DERP, the Detour Encrypted Routing Protocol. +// +// DERP routes packets to clients using curve25519 keys as addresses. +// +// DERP is used by Tailscale nodes to proxy encrypted WireGuard +// packets through the Tailscale cloud servers when a direct path +// cannot be found or opened. DERP is a last resort. Both sides +// between very aggressive NATs, firewalls, no IPv6, etc? Well, DERP. 
+package derp diff --git a/go.mod b/go.mod new file mode 100644 index 000000000..716262722 --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module tailscale.com + +go 1.13 + +require ( + github.com/apenwarr/fixconsole v0.0.0-20191012055117-5a9f6489cc29 + github.com/go-ole/go-ole v1.2.4 + github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e + github.com/google/go-cmp v0.4.0 + github.com/klauspost/compress v1.9.8 + github.com/mdlayher/netlink v1.1.0 + github.com/pborman/getopt v0.0.0-20190409184431-ee0cd42419d3 + github.com/tailscale/hujson v0.0.0-20190930033718-5098e564d9b3 + github.com/tailscale/wireguard-go v0.0.0-20200208214841-2981baf46731 + golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 + golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d + golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 + gortc.io/stun v1.22.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 000000000..50c838299 --- /dev/null +++ b/go.sum @@ -0,0 +1,76 @@ +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/apenwarr/fixconsole v0.0.0-20191012055117-5a9f6489cc29 h1:muXWUcay7DDy1/hEQWrYlBy+g0EuwT70sBHg65SeUc4= +github.com/apenwarr/fixconsole v0.0.0-20191012055117-5a9f6489cc29/go.mod h1:JYWahgHer+Z2xbsgHPtaDYVWzeHDminu+YIBWkxpCAY= +github.com/apenwarr/w32 v0.0.0-20190407065021-aa00fece76ab h1:CMGzRRCjnD50RjUFSArBLuCxiDvdp7b8YPAcikBEQ+k= +github.com/apenwarr/w32 v0.0.0-20190407065021-aa00fece76ab/go.mod h1:nfFtvHn2Hgs9G1u0/J6LHQv//EksNC+7G8vXmd1VTJ8= +github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= +github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/protobuf v1.2.0 
h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw= +github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4 h1:nwOc1YaOrYJ37sEBrtWZrdqzK22hiJs3GpDmP3sR2Yw= +github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ= +github.com/klauspost/compress v1.9.8 h1:VMAMUUOh+gaxKTMk+zqbjsSjsIcUcL/LF4o63i82QyA= +github.com/klauspost/compress v1.9.8/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= +github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= +github.com/mdlayher/netlink v1.1.0 h1:mpdLgm+brq10nI9zM1BpX1kpDbh3NLl3RSnVq6ZSkfg= +github.com/mdlayher/netlink v1.1.0/go.mod h1:H4WCitaheIsdF9yOYu8CFmCgQthAPIWZmcKp9uZHgmY= +github.com/pborman/getopt v0.0.0-20190409184431-ee0cd42419d3 h1:YtFkrqsMEj7YqpIhRteVxJxCeC3jJBieuLr0d4C4rSA= +github.com/pborman/getopt v0.0.0-20190409184431-ee0cd42419d3/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= +github.com/tailscale/hujson v0.0.0-20190930033718-5098e564d9b3 h1:rdtXEo9yffOjh4vZQJw3heaY+ggXKp+zvMX5fihh6lI= +github.com/tailscale/hujson v0.0.0-20190930033718-5098e564d9b3/go.mod h1:STqf+YV0ADdzk4ejtXFsGqDpATP9JoL0OB+hiFQbkdE= +github.com/tailscale/wireguard-go v0.0.0-20191108062213-b93cdd0582db h1:oP0crfwOb3WZSVrMVm/o51NXN2JirDlcdlNEIPTmgI0= 
+github.com/tailscale/wireguard-go v0.0.0-20200207221558-a158079b156a h1:5TWA3nl2QUfL9OiE3tlBpqJd4GYd4hbGtDNkWQQ2fyc= +github.com/tailscale/wireguard-go v0.0.0-20200207221558-a158079b156a/go.mod h1:QPS8HjBzzAXoQNndUNx2efJaQbCCz8nI2Cv1ksTUHyY= +github.com/tailscale/wireguard-go v0.0.0-20200208161837-3cd0a483944a h1:vIyObUBvnXB1XTKTBM4AgoUFR9RHiz/kslGHClkXQVg= +github.com/tailscale/wireguard-go v0.0.0-20200208161837-3cd0a483944a/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4= +github.com/tailscale/wireguard-go v0.0.0-20200208214841-2981baf46731 h1:sNmny/5pHqHdm081Fx8rcNFnwt0zTGuee/0+Jz+tXCA= +github.com/tailscale/wireguard-go v0.0.0-20200208214841-2981baf46731/go.mod h1:JPm5cTfu1K+qDFRbiHy0sOlHUylYQbpl356sdYFD8V4= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200206161412-a0c6ece9d31a h1:aczoJ0HPNE92XKa7DrIzkNN6esOKO2TBwiiYoKcINhA= +golang.org/x/crypto v0.0.0-20200206161412-a0c6ece9d31a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340 h1:KOcEaR10tFr7gdJV2GCKw8Os5yED1u1aOqHjOAb6d2Y= +golang.org/x/crypto v0.0.0-20200208060501-ecb85df21340/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod 
h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20191007182048-72f939374954/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw= +golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190405154228-4b34438f7a67/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190411185658-b44545bcd369/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 
h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/wireguard v0.0.20200121 h1:vcswa5Q6f+sylDfjqyrVNNrjsFUUbPsgAQTBCAg/Qf8= +golang.zx2c4.com/wireguard v0.0.20200121/go.mod h1:P2HsVp8SKwZEufsnezXZA4GRX/T49/HlU7DGuelXsU4= +google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +gortc.io/stun v1.22.1 h1:96mOdDATYRqhYB+TZdenWBg4CzL2Ye5kPyBXQ8KAB+8= +gortc.io/stun v1.22.1/go.mod h1:XD5lpONVyjvV3BgOyJFNo0iv6R2oZB4L+weMqxts+zg= diff --git a/ipn/backend.go b/ipn/backend.go new file mode 100644 index 000000000..7b0a35e12 --- /dev/null +++ b/ipn/backend.go @@ -0,0 +1,79 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "tailscale.com/control/controlclient" + "tailscale.com/tailcfg" + "tailscale.com/wgengine" + "time" +) + +type State int + +const ( + NoState = State(iota) + NeedsLogin + NeedsMachineAuth + Stopped + Starting + Running +) + +func (s State) String() string { + return [...]string{"NoState", "NeedsLogin", "NeedsMachineAuth", + "Stopped", "Starting", "Running"}[s] +} + +type EngineStatus struct { + RBytes, WBytes wgengine.ByteCount + NumLive int + LivePeers map[tailcfg.NodeKey]wgengine.PeerStatus +} + +type NetworkMap = controlclient.NetworkMap + +// In any given notification, any or all of these may be nil, meaning +// that they have not changed. 
+type Notify struct { + Version string // version number of IPN backend + ErrMessage *string // critical error message, if any + LoginFinished *struct{} // event: login process succeeded + State *State // current IPN state has changed + Prefs *Prefs // preferences were changed + NetMap *NetworkMap // new netmap received + Engine *EngineStatus // wireguard engine stats + BrowseToURL *string // UI should open a browser right now + BackendLogID *string // public logtail id used by backend +} + +type Options struct { + FrontendLogID string // public logtail id used by frontend + ServerURL string + Prefs Prefs + LoginFlags controlclient.LoginFlags + Notify func(n Notify) `json:"-"` +} + +type Backend interface { + // Start or restart the backend, because a new Handle has connected. + Start(opts Options) error + // Start a new interactive login. This should trigger a new + // BrowseToURL notification eventually. + StartLoginInteractive() + // Terminate the current login session and stop the wireguard engine. + Logout() + // Install a new set of user preferences, including WantRunning. + // This may cause the wireguard engine to reconfigure or stop. + SetPrefs(new Prefs) + // Poll for an update from the wireguard engine. Only needed if + // you want to display byte counts. Connection events are emitted + // automatically without polling. + RequestEngineStatus() + // Pretend the current key is going to expire after duration x. + // This is useful for testing GUIs to make sure they react properly + // with keys that are going to expire. + FakeExpireAfter(x time.Duration) +} diff --git a/ipn/doc.go b/ipn/doc.go new file mode 100644 index 000000000..0cc326a9a --- /dev/null +++ b/ipn/doc.go @@ -0,0 +1,11 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// Package ipn implements the interactions between the Tailscale cloud +// control plane and the local network stack. +// +// IPN is the abbreviated name for a Tailscale network. What's less +// clear is what it's an abbreviation for: Identified Private Network? +// IP Network? Internet Private Network? I Privately Network? +package ipn diff --git a/ipn/e2e_test.go b/ipn/e2e_test.go new file mode 100644 index 000000000..9f3ab3712 --- /dev/null +++ b/ipn/e2e_test.go @@ -0,0 +1,207 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build depends_on_currently_unreleased + +package ipn + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/tailscale/wireguard-go/tun/tuntest" + "tailscale.com/control/controlclient" + "tailscale.com/tailcfg" + "tailscale.com/testy" + "tailscale.com/wgengine" + "tailscale.com/wgengine/magicsock" + "tailscale.io/control" // not yet released +) + +func TestIPN(t *testing.T) { + testy.FixLogs(t) + defer testy.UnfixLogs(t) + + // Turn off STUN for the test to make it hermitic. + // TODO(crawshaw): add a test that runs against a local STUN server. + origDefaultSTUN := magicsock.DefaultSTUN + magicsock.DefaultSTUN = nil + defer func() { + magicsock.DefaultSTUN = origDefaultSTUN + }() + + // TODO(apenwarr): Make resource checks actually pass. + // They don't right now, because (at least) wgengine doesn't fully + // shut down. 
+ // rc := testy.NewResourceCheck() + // defer rc.Assert(t) + + var ctl *control.Server + + ctlHandler := func(w http.ResponseWriter, r *http.Request) { + ctl.ServeHTTP(w, r) + } + https := httptest.NewServer(http.HandlerFunc(ctlHandler)) + serverURL := https.URL + defer https.Close() + defer https.CloseClientConnections() + + tmpdir, err := ioutil.TempDir("", "ipntest") + if err != nil { + t.Fatalf("create tempdir: %v\n", err) + } + ctl, err = control.New(tmpdir, serverURL, true) + if err != nil { + t.Fatalf("create control server: %v\n", ctl) + } + + n1 := newNode(t, "n1", https) + defer n1.Backend.Shutdown() + n1.Backend.StartLoginInteractive() + + n2 := newNode(t, "n2", https) + defer n2.Backend.Shutdown() + n2.Backend.StartLoginInteractive() + + var s1, s2 State + for { + t.Logf("\n\nn1.state=%v n2.state=%v\n\n", s1, s2) + + // TODO(crawshaw): switch from || to &&. To do this we need to + // transmit some data so that the handshake completes on both + // sides. (Beacuse handshakes are 1RTT, it is the data + // transmission that completes the handshake.) + if s1 == Running || s2 == Running { + // TODO(apenwarr): ensure state sequence. + // Right now we'll just exit as soon as + // state==Running, even if the backend is lying or + // something. Not a great test. + break + } + + select { + case n := <-n1.NotifyCh: + t.Logf("n1n: %v\n", n) + if n.State != nil { + s1 = *n.State + if s1 == NeedsMachineAuth { + authNode(t, ctl, n1.Backend) + } + } + case n := <-n2.NotifyCh: + t.Logf("n2n: %v\n", n) + if n.State != nil { + s2 = *n.State + if s2 == NeedsMachineAuth { + authNode(t, ctl, n2.Backend) + } + } + case <-time.After(3 * time.Second): + t.Fatalf("\n\n\nFATAL: timed out waiting for notifications.\n\n\n") + } + } + + t.Skip("skipping ping tests, they are flaky") // TODO(crawshaw): this exposes a real bug! 
+ + n1addr := n1.Backend.NetMap().Addresses[0].IP + n2addr := n2.Backend.NetMap().Addresses[0].IP + t.Run("ping n2", func(t *testing.T) { + msg := tuntest.Ping(n2addr.IP(), n1addr.IP()) + n1.ChannelTUN.Outbound <- msg + select { + case msgRecv := <-n2.ChannelTUN.Inbound: + if !bytes.Equal(msg, msgRecv) { + t.Error("bad ping") + } + case <-time.After(1 * time.Second): + t.Error("no ping seen") + } + }) + t.Run("ping n1", func(t *testing.T) { + msg := tuntest.Ping(n1addr.IP(), n2addr.IP()) + n2.ChannelTUN.Outbound <- msg + select { + case msgRecv := <-n1.ChannelTUN.Inbound: + if !bytes.Equal(msg, msgRecv) { + t.Error("bad ping") + } + case <-time.After(1 * time.Second): + t.Error("no ping seen") + } + }) +} + +type testNode struct { + Backend *LocalBackend + ChannelTUN *tuntest.ChannelTUN + NotifyCh <-chan Notify +} + +// Create a new IPN node. +func newNode(t *testing.T, prefix string, https *httptest.Server) testNode { + t.Helper() + logfe := func(fmt string, args ...interface{}) { + t.Logf(prefix+".e: "+fmt, args...) + } + logf := func(fmt string, args ...interface{}) { + t.Logf(prefix+": "+fmt, args...) + } + + derp := false + tun := tuntest.NewChannelTUN() + e1, err := wgengine.NewUserspaceEngineAdvanced(logfe, tun.TUN(), wgengine.NewFakeRouter, 0, derp) + if err != nil { + t.Fatalf("NewFakeEngine: %v\n", err) + } + n, err := NewLocalBackend(logf, prefix, e1) + if err != nil { + t.Fatalf("NewLocalBackend: %v\n", err) + } + nch := make(chan Notify, 1000) + c := controlclient.Persist{ + Provider: "google", + LoginName: "test1@tailscale.com", + } + n.Start(Options{ + FrontendLogID: prefix + "-f", + ServerURL: https.URL, + Prefs: Prefs{ + RouteAll: true, + AllowSingleHosts: true, + CorpDNS: true, + WantRunning: true, + Persist: &c, + }, + LoginFlags: controlclient.LoginDefault, + Notify: func(n Notify) { + // Automatically visit auth URLs + if n.BrowseToURL != nil { + t.Logf("\n\n\nURL! 
%vv\n", *n.BrowseToURL) + hc := https.Client() + _, err := hc.Get(*n.BrowseToURL) + if err != nil { + t.Logf("BrowseToURL: %v\n", err) + } + } + nch <- n + }, + }) + + return testNode{ + Backend: n, + ChannelTUN: tun, + NotifyCh: nch, + } +} + +// Tell the control server to authorize the given node. +func authNode(t *testing.T, ctl *control.Server, n *LocalBackend) { + mk := *n.prefs.Persist.PrivateMachineKey.Public() + nk := *n.prefs.Persist.PrivateNodeKey.Public() + ctl.AuthorizeMachine(tailcfg.MachineKey(mk), tailcfg.NodeKey(nk)) +} diff --git a/ipn/fake.go b/ipn/fake.go new file mode 100644 index 000000000..3e885d1af --- /dev/null +++ b/ipn/fake.go @@ -0,0 +1,72 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "log" + "time" +) + +type FakeBackend struct { + serverURL string + notify func(n Notify) + live bool +} + +func (b *FakeBackend) Start(opts Options) error { + b.serverURL = opts.ServerURL + if opts.Notify == nil { + log.Fatalf("FakeBackend.Start: opts.Notify is nil\n") + } + b.notify = opts.Notify + b.notify(Notify{Prefs: &opts.Prefs}) + nl := NeedsLogin + b.notify(Notify{State: &nl}) + return nil +} + +func (b *FakeBackend) newState(s State) { + b.notify(Notify{State: &s}) + if s == Running { + b.live = true + } else { + b.live = false + } +} + +func (b *FakeBackend) StartLoginInteractive() { + u := b.serverURL + "/this/is/fake" + b.notify(Notify{BrowseToURL: &u}) + b.newState(NeedsMachineAuth) + b.newState(Stopped) + // TODO(apenwarr): Fill in a more interesting netmap here. + b.notify(Notify{NetMap: &NetworkMap{}}) + b.newState(Starting) + // TODO(apenwarr): Fill in a more interesting status. 
+ b.notify(Notify{Engine: &EngineStatus{}}) + b.newState(Running) +} + +func (b *FakeBackend) Logout() { + b.newState(NeedsLogin) +} + +func (b *FakeBackend) SetPrefs(new Prefs) { + b.notify(Notify{Prefs: &new}) + if new.WantRunning && !b.live { + b.newState(Starting) + b.newState(Running) + } else if !new.WantRunning && b.live { + b.newState(Stopped) + } +} + +func (b *FakeBackend) RequestEngineStatus() { + b.notify(Notify{Engine: &EngineStatus{}}) +} + +func (b *FakeBackend) FakeExpireAfter(x time.Duration) { + b.notify(Notify{NetMap: &NetworkMap{}}) +} diff --git a/ipn/handle.go b/ipn/handle.go new file mode 100644 index 000000000..4d5f9020c --- /dev/null +++ b/ipn/handle.go @@ -0,0 +1,166 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "strings" + "sync" + "time" + + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/logger" +) + +type Handle struct { + serverURL string + frontendLogID string + b Backend + xnotify func(n Notify) + logf logger.Logf + + // Mutex protects everything below + mu sync.Mutex + netmapCache *NetworkMap + engineStatusCache EngineStatus + stateCache State + prefsCache Prefs +} + +func NewHandle(b Backend, logf logger.Logf, opts Options) (*Handle, error) { + h := &Handle{ + b: b, + logf: logf, + } + + err := h.Start(opts) + if err != nil { + return nil, err + } + + return h, nil +} + +func (h *Handle) Start(opts Options) error { + h.serverURL = strings.TrimRight(opts.ServerURL, "/") + h.frontendLogID = opts.FrontendLogID + h.xnotify = opts.Notify + h.netmapCache = nil + h.engineStatusCache = EngineStatus{} + h.stateCache = NoState + h.prefsCache = opts.Prefs + xopts := opts + xopts.Notify = h.notify + return h.b.Start(xopts) +} + +func (h *Handle) Reset() { + st := NoState + h.notify(Notify{State: &st}) +} + +func (h *Handle) notify(n Notify) { + h.mu.Lock() + if 
n.BackendLogID != nil { + h.logf("Handle: logs: be:%v fe:%v\n", + *n.BackendLogID, h.frontendLogID) + } + if n.State != nil { + h.stateCache = *n.State + } + if n.Prefs != nil { + h.prefsCache = *n.Prefs + } + if n.NetMap != nil { + h.netmapCache = n.NetMap + } + if n.Engine != nil { + h.engineStatusCache = *n.Engine + } + h.mu.Unlock() + + if h.xnotify != nil { + // Forward onward to our parent's notifier + h.xnotify(n) + } +} + +func (h *Handle) Prefs() Prefs { + h.mu.Lock() + defer h.mu.Unlock() + + return h.prefsCache +} + +func (h *Handle) UpdatePrefs(updateFn func(old Prefs) (new Prefs)) { + h.mu.Lock() + defer h.mu.Unlock() + + new := updateFn(h.prefsCache) + h.prefsCache = new + h.b.SetPrefs(new) +} + +func (h *Handle) State() State { + h.mu.Lock() + defer h.mu.Unlock() + + return h.stateCache +} + +func (h *Handle) EngineStatus() EngineStatus { + h.mu.Lock() + defer h.mu.Unlock() + + return h.engineStatusCache +} + +func (h *Handle) LocalAddrs() []wgcfg.CIDR { + h.mu.Lock() + defer h.mu.Unlock() + + nm := h.netmapCache + if nm != nil { + return nm.Addresses + } + return []wgcfg.CIDR{} +} + +func (h *Handle) NetMap() *NetworkMap { + h.mu.Lock() + defer h.mu.Unlock() + + return h.netmapCache +} + +func (h *Handle) Expiry() time.Time { + h.mu.Lock() + defer h.mu.Unlock() + + nm := h.netmapCache + if nm != nil { + return nm.Expiry + } + return time.Time{} +} + +func (h *Handle) AdminPageURL() string { + return h.serverURL + "/admin/machines" +} + +func (h *Handle) StartLoginInteractive() { + h.b.StartLoginInteractive() +} + +func (h *Handle) Logout() { + h.b.Logout() +} + +func (h *Handle) RequestEngineStatus() { + h.b.RequestEngineStatus() +} + +func (h *Handle) FakeExpireAfter(x time.Duration) { + h.b.FakeExpireAfter(x) +} diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go new file mode 100644 index 000000000..8cbbf3ef4 --- /dev/null +++ b/ipn/ipnserver/server.go @@ -0,0 +1,253 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights 
reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipnserver + +import ( + "bufio" + "context" + "fmt" + "log" + "net" + "os" + "os/exec" + "os/signal" + "strings" + "sync" + "syscall" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/klauspost/compress/zstd" + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/logger" + "tailscale.com/logtail/backoff" + "tailscale.com/safesocket" + "tailscale.com/wgengine" +) + +type Options struct { + SurviveDisconnects bool + AllowQuit bool +} + +func pump(logf logger.Logf, ctx context.Context, bs *ipn.BackendServer, s net.Conn) { + defer logf("Control connection done.\n") + + for ctx.Err() == nil && !bs.GotQuit { + msg, err := ipn.ReadMsg(s) + if err != nil { + logf("ReadMsg: %v\n", err) + break + } + err = bs.GotCommandMsg(msg) + if err != nil { + logf("GotCommandMsg: %v\n", err) + break + } + } +} + +func Run(rctx context.Context, logf logger.Logf, logid string, opts Options, e wgengine.Engine) error { + bo := backoff.Backoff{Name: "ipnserver"} + + listen, _, err := safesocket.Listen("", "Tailscale", "tailscaled", 41112) + if err != nil { + return fmt.Errorf("safesocket.Listen: %v", err) + } + + b, err := ipn.NewLocalBackend(logf, logid, e) + if err != nil { + return fmt.Errorf("NewLocalBackend: %v", err) + } + b.SetDecompressor(func() (controlclient.Decompressor, error) { + return zstd.NewReader(nil) + }) + b.SetCmpDiff(func(x, y interface{}) string { return cmp.Diff(x, y) }) + + var s net.Conn + serverToClient := func(b []byte) { + if s != nil { + ipn.WriteMsg(s, b) + } + } + + bs := ipn.NewBackendServer(logf, b, serverToClient) + + logf("Listening on %v\n", listen.Addr()) + + // Go listeners can't take a context, close it instead. 
+ go func() { + <-rctx.Done() + listen.Close() + }() + + var oldS net.Conn + ctx, cancel := context.WithCancel(rctx) + + stopAll := func() { + // Currently we only support one client connection at a time. + // Theoretically we could allow multiple clients, by passing + // notifications to all of them and accepting commands from + // any of them, but there doesn't seem to be much need for + // that right now. + if oldS != nil { + cancel() + safesocket.ConnCloseRead(oldS) + safesocket.ConnCloseWrite(oldS) + } + } + + for i := 1; rctx.Err() == nil; i++ { + s, err = listen.Accept() + if err != nil { + logf("%d: Accept: %v\n", i, err) + bo.BackOff(rctx, err) + continue + } + logf("%d: Incoming control connection.\n", i) + stopAll() + + ctx, cancel = context.WithCancel(context.Background()) + oldS = s + + go func(ctx context.Context, bs *ipn.BackendServer, s net.Conn, i int) { + si := fmt.Sprintf("%d: ", i) + pump(func(fmt string, args ...interface{}) { + logf(si+fmt, args...) + }, ctx, bs, s) + if !opts.SurviveDisconnects || bs.GotQuit { + bs.Reset() + s.Close() + } + if opts.AllowQuit { + os.Exit(0) + } else { + bs.GotQuit = false + } + }(ctx, bs, s, i) + + bo.BackOff(ctx, nil) + } + stopAll() + + return rctx.Err() +} + +func BabysitProc(ctx context.Context, args []string, logf logger.Logf) { + + executable, err := os.Executable() + if err != nil { + panic("cannot determine executable: " + err.Error()) + } + + var proc struct { + mu sync.Mutex + p *os.Process + } + + done := make(chan struct{}) + go func() { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) + var sig os.Signal + select { + case sig = <-interrupt: + logf("BabysitProc: got signal: %v\n", sig) + close(done) + case <-ctx.Done(): + logf("BabysitProc: context done\n") + sig = os.Kill + close(done) + } + + proc.mu.Lock() + proc.p.Signal(sig) + proc.mu.Unlock() + }() + + bo := backoff.Backoff{Name: "BabysitProc"} + + for { + startTime := time.Now() + 
log.Printf("exec: %#v %v\n", executable, args) + cmd := exec.Command(executable, args...) + + // Create a pipe object to use as the subproc's stdin. + // When the writer goes away, the reader gets EOF. + // A subproc can watch its stdin and exit when it gets EOF; + // this is a very reliable way to have a subproc die when + // its parent (us) disappears. + // We never need to actually write to wStdin. + rStdin, wStdin, err := os.Pipe() + if err != nil { + log.Printf("os.Pipe 1: %v\n", err) + return + } + + // Create a pipe object to use as the subproc's stdout/stderr. + // We'll read from this pipe and send it to logf, line by line. + // We can't use os.exec's io.Writer for this because it + // doesn't care about lines, and thus ends up merging multiple + // log lines into one or splitting one line into multiple + // logf() calls. bufio is more appropriate. + rStdout, wStdout, err := os.Pipe() + if err != nil { + log.Printf("os.Pipe 2: %v\n", err) + } + go func(r *os.File) { + defer r.Close() + rb := bufio.NewReader(r) + for { + s, err := rb.ReadString('\n') + if s != "" { + logf("%s\n", strings.TrimSuffix(s, "\n")) + } + if err != nil { + break + } + } + }(rStdout) + + cmd.Stdin = rStdin + cmd.Stdout = wStdout + cmd.Stderr = wStdout + err = cmd.Start() + + // Now that the subproc is started, get rid of our copy of the + // pipe reader. Bad things happen on Windows if more than one + // process owns the read side of a pipe. + rStdin.Close() + wStdout.Close() + + if err != nil { + log.Printf("starting subprocess failed: %v", err) + } else { + proc.mu.Lock() + proc.p = cmd.Process + proc.mu.Unlock() + + err = cmd.Wait() + log.Printf("subprocess exited: %v", err) + } + + // If the process finishes, clean up the write side of the + // pipe. We'll make a new one when we restart the subproc. 
+ wStdin.Close() + + if time.Since(startTime) < 60*time.Second { + bo.BackOff(ctx, fmt.Errorf("subproc early exit: %v", err)) + } else { + // Reset the timeout, since the process ran for a while. + bo.BackOff(ctx, nil) + } + + select { + case <-done: + return + default: + } + } +} diff --git a/ipn/local.go b/ipn/local.go new file mode 100644 index 000000000..8b73d155d --- /dev/null +++ b/ipn/local.go @@ -0,0 +1,635 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "fmt" + "log" + "strings" + "sync" + "time" + + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/control/controlclient" + "tailscale.com/logger" + "tailscale.com/portlist" + "tailscale.com/tailcfg" + "tailscale.com/version" + "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" +) + +// LocalBackend is the scaffolding between the Tailscale cloud control +// plane and the local network stack. +type LocalBackend struct { + logf logger.Logf + notify func(n Notify) + c *controlclient.Client + e wgengine.Engine + serverURL string + backendLogID string + portpoll *portlist.Poller // may be nil + newDecompressor func() (controlclient.Decompressor, error) + cmpDiff func(x, y interface{}) string + + // The mutex protects the following elements. + mu sync.Mutex + prefs Prefs + state State + hiCache tailcfg.Hostinfo + netMapCache *controlclient.NetworkMap + engineStatus EngineStatus + endPoints []string + blocked bool + authURL string + interact int + + // statusLock must be held before calling statusChanged.Lock() or + // statusChanged.Broadcast(). + statusLock sync.Mutex + statusChanged *sync.Cond +} + +func NewLocalBackend(logf logger.Logf, logid string, e wgengine.Engine) (*LocalBackend, error) { + + if e == nil { + panic("ipn.NewLocalBackend: wgengine must not be nil") + } + + // Default filter blocks everything, until Start() is called. 
+ e.SetFilter(filter.NewAllowNone()) + + portpoll, err := portlist.NewPoller() + if err != nil { + logf("skipping portlist: %s\n", err) + } + + b := LocalBackend{ + logf: logf, + e: e, + backendLogID: logid, + state: NoState, + portpoll: portpoll, + } + b.statusChanged = sync.NewCond(&b.statusLock) + + if b.portpoll != nil { + go b.portpoll.Run() + go b.runPoller() + } + + return &b, nil +} + +func (b *LocalBackend) Shutdown() { + if b.portpoll != nil { + b.portpoll.Close() + } + b.c.Shutdown() + b.e.Close() + b.e.Wait() +} + +// SetDecompressor sets a decompression function, which must be a zstd +// reader. +// +// This exists because the iOS/Mac NetworkExtension is very resource +// constrained, and the zstd package is too heavy to fit in the +// constrained RSS limit. +func (b *LocalBackend) SetDecompressor(fn func() (controlclient.Decompressor, error)) { + b.newDecompressor = fn +} + +// SetCmpDiff sets a comparison function used to generate logs of what +// has changed in the network map. +// +// Typically the comparison function comes from go-cmp. +// We don't wire it in directly here because the go-cmp package adds +// 1.77mb to the binary size of the iOS NetworkExtension, which takes +// away from its precious RSS limit. +func (b *LocalBackend) SetCmpDiff(cmpDiff func(x, y interface{}) string) { + b.cmpDiff = cmpDiff +} + +func (b *LocalBackend) Start(opts Options) error { + if b.c != nil { + // TODO(apenwarr): avoid the need to reinit controlclient. + // This will trigger a full relogin/reconfigure cycle every + // time a Handle reconnects to the backend. Ideally, we + // would send the new Prefs and everything would get back + // into sync with the minimal changes. But that's not how it + // is right now, which is a sign that the code is still too + // complicated. 
+ b.c.Shutdown() + } + + b.logf("Start: %v\n", opts.Prefs.Pretty()) + + hi := controlclient.NewHostinfo() + hi.BackendLogID = b.backendLogID + hi.FrontendLogID = opts.FrontendLogID + + b.mu.Lock() + hi.Services = b.hiCache.Services // keep any previous session + b.hiCache = hi + b.state = NoState + b.serverURL = opts.ServerURL + b.prefs = opts.Prefs + b.notify = opts.Notify + b.netMapCache = nil + b.mu.Unlock() + + b.updateFilter() + + var err error + persist := b.prefs.Persist + if persist == nil { + // let controlclient initialize it + persist = &controlclient.Persist{} + } + cli, err := controlclient.New(controlclient.Options{ + Logf: func(fmt string, args ...interface{}) { + b.logf("control: "+fmt, args...) + }, + Persist: *persist, + ServerURL: b.serverURL, + Hostinfo: &hi, + KeepAlive: true, + NewDecompressor: b.newDecompressor, + }) + if err != nil { + return err + } + + b.mu.Lock() + b.c = cli + b.mu.Unlock() + + if b.endPoints != nil { + cli.UpdateEndpoints(0, b.endPoints) + } + + cli.SetStatusFunc(func(new controlclient.Status) { + if new.LoginFinished != nil { + // Auth completed, unblock the engine + b.blockEngineUpdates(false) + b.authReconfig() + noargs := struct{}{} + b.send(Notify{LoginFinished: &noargs}) + } + if new.Persist != nil { + persist := *new.Persist // copy + b.prefs.Persist = &persist + np := b.prefs + b.send(Notify{Prefs: &np}) + } + if new.NetMap != nil { + if b.netMapCache != nil && b.cmpDiff != nil { + s1 := strings.Split(b.netMapCache.Concise(), "\n") + s2 := strings.Split(new.NetMap.Concise(), "\n") + b.logf("netmap diff:\n%v\n", b.cmpDiff(s1, s2)) + } + b.netMapCache = new.NetMap + b.send(Notify{NetMap: new.NetMap}) + b.updateFilter() + } + if new.URL != "" { + b.logf("Received auth URL: %.20v...\n", new.URL) + + b.mu.Lock() + interact := b.interact + b.authURL = new.URL + b.mu.Unlock() + + if interact > 0 { + b.popBrowserAuthNow() + } + } + if new.Err != "" { + // TODO(crawshaw): display in the UI. 
+ log.Print(new.Err) + return + } + if new.NetMap != nil { + if b.prefs.WantRunning || b.State() == NeedsLogin { + b.prefs.WantRunning = true + } + b.SetPrefs(b.prefs) + } + b.stateMachine() + }) + + b.e.SetStatusCallback(func(s *wgengine.Status, err error) { + if err != nil { + b.logf("wgengine status error: %#v", err) + return + } + if s == nil { + log.Fatalf("weird: non-error wgengine update with status=nil\n") + } + + b.mu.Lock() + es := b.parseWgStatus(s) + b.mu.Unlock() + + b.engineStatus = es + + if b.c != nil { + b.c.UpdateEndpoints(0, s.LocalAddrs) + } + b.endPoints = append([]string{}, s.LocalAddrs...) + b.stateMachine() + + b.statusLock.Lock() + b.statusChanged.Broadcast() + b.statusLock.Unlock() + + b.send(Notify{Engine: &es}) + }) + + blid := b.backendLogID + b.logf("Backend: logs: be:%v fe:%v\n", blid, opts.FrontendLogID) + b.send(Notify{BackendLogID: &blid}) + + cli.Login(nil, opts.LoginFlags) + return nil +} + +func (b *LocalBackend) updateFilter() { + if !b.Prefs().UsePacketFilter { + b.e.SetFilter(filter.NewAllowAll()) + } else if b.netMapCache == nil { + // Not configured yet, block everything + b.e.SetFilter(filter.NewAllowNone()) + } else { + b.logf("netmap packet filter: %v\n", b.netMapCache.PacketFilter) + b.e.SetFilter(filter.New(b.netMapCache.PacketFilter)) + } +} + +func (b *LocalBackend) runPoller() { + for { + ports := <-b.portpoll.C + if ports == nil { + break + } + sl := []tailcfg.Service{} + for _, p := range ports { + var proto tailcfg.ServiceProto + if p.Proto == "tcp" { + proto = tailcfg.TCP + } else if p.Proto == "udp" { + proto = tailcfg.UDP + } + if p.Port == 53 || p.Port == 68 || + p.Port == 5353 || p.Port == 5355 { + // uninteresting system services + continue + } + s := tailcfg.Service{ + Proto: proto, + Port: p.Port, + Description: p.Process, + } + sl = append(sl, s) + } + + b.mu.Lock() + hi := b.hiCache + hi.Services = sl + b.hiCache = hi + cli := b.c + b.mu.Unlock() + + // b.c might not be started yet + if cli != nil { + 
cli.SetHostinfo(hi) + } + } +} + +func (b *LocalBackend) send(n Notify) { + if b.notify != nil { + n.Version = version.LONG + b.notify(n) + } +} + +func (b *LocalBackend) popBrowserAuthNow() { + b.mu.Lock() + url := b.authURL + b.interact = 0 + b.authURL = "" + b.mu.Unlock() + b.logf("popBrowserAuthNow: url=%v\n", url != "") + + b.blockEngineUpdates(true) + b.stopEngineAndWait() + b.send(Notify{BrowseToURL: &url}) + if b.State() == Running { + b.enterState(Starting) + } +} + +func (b *LocalBackend) State() State { + b.mu.Lock() + defer b.mu.Unlock() + + return b.state +} + +func (b *LocalBackend) EngineStatus() EngineStatus { + b.mu.Lock() + defer b.mu.Unlock() + + return b.engineStatus +} + +func (b *LocalBackend) StartLoginInteractive() { + b.assertClient() + b.mu.Lock() + b.interact++ + url := b.authURL + b.mu.Unlock() + b.logf("StartLoginInteractive: url=%v\n", url != "") + + if url != "" { + b.popBrowserAuthNow() + } else { + b.c.Login(nil, controlclient.LoginInteractive) + } +} + +func (b *LocalBackend) FakeExpireAfter(x time.Duration) { + b.logf("FakeExpireAfter: %v\n", x) + if b.netMapCache != nil { + e := b.netMapCache.Expiry + if e.IsZero() || time.Until(e) > x { + b.netMapCache.Expiry = time.Now().Add(x) + } + b.send(Notify{NetMap: b.netMapCache}) + } +} + +func (b *LocalBackend) LocalAddrs() []wgcfg.CIDR { + if b.netMapCache != nil { + return b.netMapCache.Addresses + } else { + return nil + } +} + +func (b *LocalBackend) Expiry() time.Time { + if b.netMapCache != nil { + return b.netMapCache.Expiry + } else { + return time.Time{} + } +} + +func (b *LocalBackend) parseWgStatus(s *wgengine.Status) EngineStatus { + var ss []string + var rx, tx wgengine.ByteCount + peers := make(map[tailcfg.NodeKey]wgengine.PeerStatus) + + live := 0 + for _, p := range s.Peers { + if p.LastHandshake.IsZero() { + ss = append(ss, "x") + } else { + ss = append(ss, fmt.Sprintf("%d/%d", p.RxBytes, p.TxBytes)) + live++ + peers[p.NodeKey] = p + } + rx += p.RxBytes + tx += 
p.TxBytes + } + b.logf("v%v peers: %v\n", version.LONG, strings.Join(ss, " ")) + return EngineStatus{ + RBytes: rx, + WBytes: tx, + NumLive: live, + LivePeers: peers, + } +} + +func (b *LocalBackend) AdminPageURL() string { + return b.serverURL + "/admin/machines" +} + +func (b *LocalBackend) Prefs() Prefs { + b.mu.Lock() + defer b.mu.Unlock() + + return b.prefs +} + +func (b *LocalBackend) SetPrefs(new Prefs) { + b.mu.Lock() + old := b.prefs + new.Persist = old.Persist // caller isn't allowed to override this + b.prefs = new + b.mu.Unlock() + + if old.WantRunning != new.WantRunning { + b.stateMachine() + } else { + b.authReconfig() + } + + b.logf("SetPrefs: %v\n", new.Pretty()) + b.send(Notify{Prefs: &new}) +} + +// Note: return value may be nil, if we haven't received a netmap yet. +func (b *LocalBackend) NetMap() *controlclient.NetworkMap { + return b.netMapCache +} + +func (b *LocalBackend) blockEngineUpdates(block bool) { + // TODO(apenwarr): probably need mutex here (and several other places) + b.logf("blockEngineUpdates(%v)\n", block) + + b.mu.Lock() + b.blocked = block + b.mu.Unlock() +} + +func (b *LocalBackend) authReconfig() { + b.mu.Lock() + blocked := b.blocked + uc := b.prefs + nm := b.netMapCache + b.mu.Unlock() + + if blocked { + b.logf("authReconfig: blocked, skipping.\n") + return + } + if nm == nil { + b.logf("authReconfig: netmap not yet valid. Skipping.\n") + return + } + if !uc.WantRunning { + b.logf("authReconfig: skipping because !WantRunning.\n") + return + } + b.logf("Configuring wireguard connection.\n") + + uflags := controlclient.UDefault + if uc.RouteAll { + uflags |= controlclient.UAllowDefaultRoute + // TODO(apenwarr): Make subnet routes a different pref? + uflags |= controlclient.UAllowSubnetRoutes + // TODO(apenwarr): Remove this once we sort out subnet routes. + // Right now default routes are broken in Windows, but + // controlclient doesn't properly send subnet routes. 
So + // let's convert a default route into a subnet route in order + // to allow experimentation. + uflags |= controlclient.UHackDefaultRoute + } + if uc.AllowSingleHosts { + uflags |= controlclient.UAllowSingleHosts + } + b.logf("reconfig: ra=%v dns=%v 0x%02x\n", uc.RouteAll, uc.CorpDNS, uflags) + + if nm != nil { + dns := nm.DNS + dom := nm.DNSDomains + if !uc.CorpDNS { + dns = []wgcfg.IP{} + dom = []string{} + } + cfg, err := nm.WGCfg(uflags, dns) + if err != nil { + log.Fatalf("WGCfg: %v\n", err) + } + + err = b.e.Reconfig(cfg, dom) + if err != nil { + b.logf("reconfig: %v", err) + } + } +} + +func (b *LocalBackend) enterState(newState State) { + b.mu.Lock() + state := b.state + prefs := b.prefs + b.mu.Unlock() + + if state == newState { + return + } + b.logf("Switching ipn state %v -> %v (WantRunning=%v)\n", + state, newState, prefs.WantRunning) + if b.notify != nil { + b.send(Notify{State: &newState}) + } + + b.state = newState + switch newState { + case NeedsLogin: + b.blockEngineUpdates(true) + fallthrough + case Stopped: + err := b.e.Reconfig(&wgcfg.Config{}, nil) + if err != nil { + b.logf("Reconfig(down): %v\n", err) + } + case Starting, NeedsMachineAuth: + b.authReconfig() + // Needed so that UpdateEndpoints can run + b.e.RequestStatus() + case Running: + break + default: + b.logf("Weird: unknown newState %#v\n", newState) + } + +} + +func (b *LocalBackend) nextState() State { + b.assertClient() + state := b.State() + + if b.netMapCache == nil { + if b.c.AuthCantContinue() { + // Auth was interrupted or waiting for URL visit, + // so it won't proceed without human help. 
+ return NeedsLogin + } else { + // Auth or map request needs to finish + return state + } + } else if !b.prefs.WantRunning { + return Stopped + } else if e := b.netMapCache.Expiry; !e.IsZero() && time.Until(e) <= 0 { + return NeedsLogin + } else if b.netMapCache.MachineStatus != tailcfg.MachineAuthorized { + // TODO(crawshaw): handle tailcfg.MachineInvalid + return NeedsMachineAuth + } else if state == NeedsMachineAuth { + // (if we get here, we know MachineAuthorized == true) + return Starting + } else if state == Starting { + if b.EngineStatus().NumLive > 0 { + return Running + } else { + return state + } + } else if state == Running { + return Running + } else { + return Starting + } +} + +func (b *LocalBackend) RequestEngineStatus() { + b.e.RequestStatus() +} + +// TODO(apenwarr): use a channel or something to prevent re-entrancy? +// Or maybe just call the state machine from fewer places. +func (b *LocalBackend) stateMachine() { + b.enterState(b.nextState()) +} + +func (b *LocalBackend) stopEngineAndWait() { + b.logf("stopEngineAndWait...\n") + b.e.Reconfig(&wgcfg.Config{}, nil) + b.requestEngineStatusAndWait() + b.logf("stopEngineAndWait: done.\n") +} + +// Requests the wgengine status, and does not return until the status +// was delivered (to the usual callback). +func (b *LocalBackend) requestEngineStatusAndWait() { + b.logf("requestEngineStatusAndWait\n") + + b.statusLock.Lock() + go b.e.RequestStatus() + b.logf("requestEngineStatusAndWait: waiting...\n") + b.statusChanged.Wait() // temporarily releases lock while waiting + b.logf("requestEngineStatusAndWait: got status update.\n") + b.statusLock.Unlock() +} + +// NOTE(apenwarr): No easy way to persist logged-out status. +// Maybe that's for the better; if someone logs out accidentally, +// rebooting will fix it. 
+func (b *LocalBackend) Logout() { + b.assertClient() + b.netMapCache = nil + b.c.Logout() + b.netMapCache = nil + b.stateMachine() +} + +func (b *LocalBackend) assertClient() { + if b.c == nil { + panic("LocalBackend.assertClient: b.c == nil") + } +} diff --git a/ipn/message.go b/ipn/message.go new file mode 100644 index 000000000..e75599dbe --- /dev/null +++ b/ipn/message.go @@ -0,0 +1,249 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "time" + + "tailscale.com/logger" + "tailscale.com/version" +) + +type NoArgs struct{} + +type StartArgs struct { + Opts Options +} + +type SetPrefsArgs struct { + New Prefs +} + +type FakeExpireAfterArgs struct { + Duration time.Duration +} + +// A command message sent to the server. Exactly one of these must be non-nil. +type Command struct { + Version string + Quit *NoArgs + Start *StartArgs + StartLoginInteractive *NoArgs + Logout *NoArgs + SetPrefs *SetPrefsArgs + RequestEngineStatus *NoArgs + FakeExpireAfter *FakeExpireAfterArgs +} + +type BackendServer struct { + logf logger.Logf + b Backend // the Backend we are serving up + sendNotifyMsg func(b []byte) // send a notification message + GotQuit bool // a Quit command was received +} + +func NewBackendServer(logf logger.Logf, b Backend, sendNotifyMsg func(b []byte)) *BackendServer { + return &BackendServer{ + logf: logf, + b: b, + sendNotifyMsg: sendNotifyMsg, + } +} + +func (bs *BackendServer) send(n Notify) { + n.Version = version.LONG + b, err := json.Marshal(n) + if err != nil { + log.Fatalf("Failed json.Marshal(notify): %v\n%#v\n", err, n) + } + bs.sendNotifyMsg(b) +} + +// Inform the BackendServer of an incoming message. 
+func (bs *BackendServer) GotCommandMsg(b []byte) error { + cmd := Command{} + if err := json.Unmarshal(b, &cmd); err != nil { + return err + } + return bs.GotCommand(&cmd) +} + +func (bs *BackendServer) GotCommand(cmd *Command) error { + if cmd.Version != version.LONG { + vs := fmt.Sprintf("Version mismatch! frontend=%#v backend=%#v\n", + cmd.Version, version.LONG) + bs.logf("%s\n", vs) + // ignore the command, but send a message back to the + // caller so it can realize the version mismatch too. + // We don't want to exit because it might cause a crash + // loop, and restarting won't fix the problem. + bs.send(Notify{ + ErrMessage: &vs, + }) + return nil + } + if cmd.Quit != nil { + bs.GotQuit = true + return errors.New("Quit command received") + } + + if c := cmd.Start; c != nil { + opts := c.Opts + opts.Notify = bs.send + return bs.b.Start(opts) + } else if c := cmd.StartLoginInteractive; c != nil { + bs.b.StartLoginInteractive() + return nil + } else if c := cmd.Logout; c != nil { + bs.b.Logout() + return nil + } else if c := cmd.SetPrefs; c != nil { + bs.b.SetPrefs(c.New) + return nil + } else if c := cmd.RequestEngineStatus; c != nil { + bs.b.RequestEngineStatus() + return nil + } else if c := cmd.FakeExpireAfter; c != nil { + bs.b.FakeExpireAfter(c.Duration) + return nil + } else { + return fmt.Errorf("BackendServer.Do: no command specified") + } +} + +func (bs *BackendServer) Reset() error { + // Tell the backend we got a Logout command, which will cause it + // to forget all its authentication information. 
// MSG_MAX is the largest message payload, in bytes, that ReadMsg and
// WriteMsg accept. Both directions share this one limit so that anything
// WriteMsg emits can be read back by ReadMsg.
const MSG_MAX = 1024 * 1024

// ReadMsg reads one framed message from r and returns its payload.
//
// A frame is a 4-byte little-endian length followed by that many payload
// bytes. An error is returned if the header or payload cannot be read in
// full, or if the declared length exceeds MSG_MAX (in which case no
// payload bytes are consumed).
//
// TODO(apenwarr): incremental json decode?
// That would let us avoid storing the whole byte array uselessly in RAM.
func ReadMsg(r io.Reader) ([]byte, error) {
	var hdr [4]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}
	n := binary.LittleEndian.Uint32(hdr[:])
	// Check against the shared constant rather than a repeated literal,
	// so the reader and writer limits cannot drift apart.
	if n > MSG_MAX {
		return nil, fmt.Errorf("ipn.Read: message too large: %v bytes", n)
	}
	b := make([]byte, n)
	if _, err := io.ReadFull(r, b); err != nil {
		return nil, err
	}
	return b, nil
}

// WriteMsg writes b to w as one framed message: a 4-byte little-endian
// length followed by the payload.
//
// Payloads longer than MSG_MAX are rejected before anything is written.
// A short write of either the header or the payload is reported as an
// error.
//
// TODO(apenwarr): incremental json encode?
// That would save RAM, at the expense of having to encode once so that
// we can produce the initial byte count.
func WriteMsg(w io.Writer, b []byte) error {
	// Validate first so an oversize payload costs no allocation.
	if len(b) > MSG_MAX {
		return fmt.Errorf("ipn.Write: message too large: %v bytes", len(b))
	}
	var hdr [4]byte
	binary.LittleEndian.PutUint32(hdr[:], uint32(len(b)))
	n, err := w.Write(hdr[:])
	if err != nil {
		return err
	}
	if n != len(hdr) {
		return fmt.Errorf("ipn.Write: short write: %v bytes (wanted 4)", n)
	}
	n, err = w.Write(b)
	if err != nil {
		return err
	}
	if n != len(b) {
		return fmt.Errorf("ipn.Write: short write: %v bytes (wanted %v)", n, len(b))
	}
	return nil
}
+ +package ipn + +import ( + "bytes" + "tailscale.com/testy" + "testing" + "time" +) + +func TestReadWrite(t *testing.T) { + testy.FixLogs(t) + defer testy.UnfixLogs(t) + + rc := testy.NewResourceCheck() + defer rc.Assert(t) + + buf := bytes.Buffer{} + err := WriteMsg(&buf, []byte("Test string1")) + if err != nil { + t.Fatalf("write1: %v\n", err) + } + err = WriteMsg(&buf, []byte("")) + if err != nil { + t.Fatalf("write2: %v\n", err) + } + err = WriteMsg(&buf, []byte("Test3")) + if err != nil { + t.Fatalf("write3: %v\n", err) + } + + b, err := ReadMsg(&buf) + if want, got := "Test string1", string(b); want != got { + t.Fatalf("read1: %#v != %#v\n", want, got) + } + b, err = ReadMsg(&buf) + if want, got := "", string(b); want != got { + t.Fatalf("read2: %#v != %#v\n", want, got) + } + b, err = ReadMsg(&buf) + if want, got := "Test3", string(b); want != got { + t.Fatalf("read3: %#v != %#v\n", want, got) + } + + b, err = ReadMsg(&buf) + if err == nil { + t.Fatalf("read4: expected error, got %#v\n", b) + } +} + +func TestClientServer(t *testing.T) { + testy.FixLogs(t) + defer testy.UnfixLogs(t) + + rc := testy.NewResourceCheck() + defer rc.Assert(t) + + b := &FakeBackend{} + var bs *BackendServer + var bc *BackendClient + serverToClientCh := make(chan []byte, 16) + defer close(serverToClientCh) + go func() { + for b := range serverToClientCh { + bc.GotNotifyMsg(b) + } + }() + serverToClient := func(b []byte) { + serverToClientCh <- append([]byte{}, b...) + } + clientToServer := func(b []byte) { + bs.GotCommandMsg(b) + } + slogf := func(fmt string, args ...interface{}) { + t.Logf("s: "+fmt, args...) + } + clogf := func(fmt string, args ...interface{}) { + t.Logf("c: "+fmt, args...) 
+ } + bs = NewBackendServer(slogf, b, serverToClient) + bc = NewBackendClient(clogf, clientToServer) + + ch := make(chan Notify, 256) + h, err := NewHandle(bc, clogf, Options{ + ServerURL: "http://example.com/fake", + Notify: func(n Notify) { + ch <- n + }, + }) + if err != nil { + t.Fatalf("NewHandle error: %v\n", err) + } + + notes := Notify{} + nn := []Notify{} + processNote := func(n Notify) { + nn = append(nn, n) + if n.State != nil { + t.Logf("state change: %v", *n.State) + notes.State = n.State + } + if n.Prefs != nil { + notes.Prefs = n.Prefs + } + if n.NetMap != nil { + notes.NetMap = n.NetMap + } + if n.Engine != nil { + notes.Engine = n.Engine + } + if n.BrowseToURL != nil { + notes.BrowseToURL = n.BrowseToURL + } + } + notesState := func() State { + if notes.State != nil { + return *notes.State + } + return NoState + } + + flushUntil := func(wantFlush State) { + t.Helper() + timer := time.NewTimer(1 * time.Second) + loop: + for { + select { + case n := <-ch: + processNote(n) + if notesState() == wantFlush { + break loop + } + case <-timer.C: + t.Fatalf("timeout waiting for state %v, got %v", wantFlush, notes.State) + } + } + timer.Stop() + loop2: + for { + select { + case n := <-ch: + processNote(n) + default: + break loop2 + } + } + if got, want := h.State(), notesState(); got != want { + t.Errorf("h.State()=%v, notes.State=%v (on flush until %v)\n", got, want, wantFlush) + } + } + + flushUntil(NeedsLogin) + + h.StartLoginInteractive() + flushUntil(Running) + if notes.NetMap == nil && h.NetMap() != nil { + t.Errorf("notes.NetMap == nil while h.NetMap != nil\nnotes:\n%v", nn) + } + + h.UpdatePrefs(func(p Prefs) Prefs { + p.WantRunning = false + return p + }) + flushUntil(Stopped) + + h.Logout() + flushUntil(NeedsLogin) +} diff --git a/ipn/prefs.go b/ipn/prefs.go new file mode 100644 index 000000000..922f30729 --- /dev/null +++ b/ipn/prefs.go @@ -0,0 +1,149 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "log" + "os" + "path/filepath" + + "tailscale.com/atomicfile" + "tailscale.com/control/controlclient" +) + +type Prefs struct { + RouteAll bool + AllowSingleHosts bool + CorpDNS bool + WantRunning bool + NotepadURLs bool + UsePacketFilter bool + + // The Persist field is named 'Config' in the file for backward + // compatibility with earlier versions. + // TODO(apenwarr): We should move this out of here, it's not a pref. + // We can maybe do that once we're sure which module should persist + // it (backend or frontend?) + Persist *controlclient.Persist `json:"Config"` +} + +func (uc *Prefs) Pretty() string { + var ucp string + if uc.Persist != nil { + ucp = uc.Persist.Pretty() + } else { + ucp = "Persist=nil" + } + return fmt.Sprintf("Prefs{ra=%v mesh=%v dns=%v want=%v notepad=%v %v}", + uc.RouteAll, uc.AllowSingleHosts, uc.CorpDNS, uc.WantRunning, + uc.NotepadURLs, ucp) +} + +func (uc *Prefs) ToBytes() []byte { + data, err := json.MarshalIndent(uc, "", "\t") + if err != nil { + log.Fatalf("Prefs marshal: %v\n", err) + } + return data +} + +func (uc *Prefs) Equals(uc2 *Prefs) bool { + b1 := uc.ToBytes() + b2 := uc2.ToBytes() + return bytes.Equal(b1, b2) +} + +func NewPrefs() Prefs { + return Prefs{ + // Provide default values for options which are normally + // true, but might be missing from the json data for any + // reason. The json can still override them to false. 
+ RouteAll: true, + AllowSingleHosts: true, + CorpDNS: true, + WantRunning: true, + UsePacketFilter: true, + } +} + +func PrefsFromBytes(b []byte, enforceDefaults bool) (Prefs, error) { + uc := NewPrefs() + if len(b) == 0 { + return uc, nil + } + persist := &controlclient.Persist{} + err := json.Unmarshal(b, persist) + if err == nil && (persist.Provider != "" || persist.LoginName != "") { + // old-style relaynode config; import it + uc.Persist = persist + } else { + err = json.Unmarshal(b, &uc) + if err != nil { + log.Printf("Prefs parse: %v: %v\n", err, b) + } + } + if enforceDefaults { + uc.RouteAll = true + uc.AllowSingleHosts = true + } + return uc, err +} + +func (uc *Prefs) Copy() *Prefs { + uc2, err := PrefsFromBytes(uc.ToBytes(), false) + if err != nil { + log.Fatalf("Prefs was uncopyable: %v\n", err) + } + return &uc2 +} + +func LoadPrefs(filename string, enforceDefaults bool) Prefs { + log.Printf("Loading prefs %v\n", filename) + data, err := ioutil.ReadFile(filename) + uc := NewPrefs() + if err != nil { + log.Printf("Read: %v: %v\n", filename, err) + goto fail + } + uc, err = PrefsFromBytes(data, enforceDefaults) + if err != nil { + log.Printf("Parse: %v: %v\n", filename, err) + goto fail + } + goto post +fail: + log.Printf("failed to load config. Generating a new one.\n") + uc = NewPrefs() + uc.WantRunning = true +post: + // Update: we changed our minds :) + // Versabank would like to persist the setting across reboots, for now, + // because they don't fully trust the system and want to be able to + // leave it turned off when not in use. Eventually we need to make + // all motivation for this go away. + if false { + // Usability note: we always want WantRunning = true on startup. + // That way, if someone accidentally disables their VPN and doesn't + // know how, rebooting will fix it. + // We still persist WantRunning just in case we change our minds on + // this topic. 
+ uc.WantRunning = true + } + log.Printf("Loaded prefs %v %v\n", filename, uc.Pretty()) + return uc +} + +func SavePrefs(filename string, uc *Prefs) { + log.Printf("Saving prefs %v %v\n", filename, uc.Pretty()) + data := uc.ToBytes() + os.MkdirAll(filepath.Dir(filename), 0700) + if err := atomicfile.WriteFile(filename, data, 0666); err != nil { + log.Printf("SavePrefs: %v\n", err) + } +} diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go new file mode 100644 index 000000000..5163abaa0 --- /dev/null +++ b/ipn/prefs_test.go @@ -0,0 +1,68 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ipn + +import ( + "testing" + + "tailscale.com/control/controlclient" +) + +func checkPrefs(t *testing.T, p Prefs) { + var err error + var p2, p2c Prefs + var p2b Prefs + + pp := p.Pretty() + if pp == "" { + t.Fatalf("default p.Pretty() failed\n") + } + t.Logf("\npp: %#v\n", pp) + b := p.ToBytes() + if len(b) == 0 { + t.Fatalf("default p.ToBytes() failed\n") + } + if p != p { + t.Fatalf("p != p\n") + } + p2 = p + p2.RouteAll = true + if p == p2 { + t.Fatalf("p == p2\n") + } + p2b, err = PrefsFromBytes(p2.ToBytes(), false) + if err != nil { + t.Fatalf("PrefsFromBytes(p2) failed\n") + } + p2p := p2.Pretty() + p2bp := p2b.Pretty() + t.Logf("\np2p: %#v\np2bp: %#v\n", p2p, p2bp) + if p2p != p2bp { + t.Fatalf("p2p != p2bp\n%#v\n%#v\n", p2p, p2bp) + } + if !p2.Equals(&p2b) { + t.Fatalf("p2 != p2b\n%#v\n%#v\n", p2, p2b) + } + p2c = *p2.Copy() + if !p2b.Equals(&p2c) { + t.Fatalf("p2b != p2c\n") + } +} + +func TestBasicPrefs(t *testing.T) { + p := Prefs{} + checkPrefs(t, p) +} + +func TestPrefsPersist(t *testing.T) { + c := controlclient.Persist{ + LoginName: "test@example.com", + } + p := Prefs{ + CorpDNS: true, + Persist: &c, + } + checkPrefs(t, p) +} diff --git a/logger/logger.go b/logger/logger.go new file mode 100644 index 000000000..6c5d25292 --- 
/dev/null +++ b/logger/logger.go @@ -0,0 +1,10 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package logger defines a type for writing to logs. It's just a +// convenience type so that we don't have to pass verbose func(...) +// types around. +package logger + +type Logf func(fmt string, args ...interface{}) diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go new file mode 100644 index 000000000..0ae040c6a --- /dev/null +++ b/logpolicy/logpolicy.go @@ -0,0 +1,171 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package logpolicy + +import ( + "context" + "encoding/json" + "io/ioutil" + "log" + "os" + "path/filepath" + "runtime" + + "github.com/klauspost/compress/zstd" + "golang.org/x/crypto/ssh/terminal" + "tailscale.com/atomicfile" + "tailscale.com/logtail" + "tailscale.com/logtail/filch" + "tailscale.com/version" +) + +type Config struct { + Collection string + PrivateID logtail.PrivateID + PublicID logtail.PublicID +} + +type Policy struct { + Logtail logtail.Logger + PublicID logtail.PublicID +} + +func (c *Config) ToBytes() []byte { + data, err := json.MarshalIndent(c, "", "\t") + if err != nil { + log.Fatalf("logpolicy.Config marshal: %v\n", err) + } + return data +} + +func (c *Config) Save(statefile string) { + c.PublicID = c.PrivateID.Public() + os.MkdirAll(filepath.Dir(statefile), 0777) + data := c.ToBytes() + if err := atomicfile.WriteFile(statefile, data, 0600); err != nil { + log.Printf("logpolicy.Config write: %v\n", err) + } +} + +func ConfigFromBytes(b []byte) (*Config, error) { + c := &Config{} + if err := json.Unmarshal(b, c); err != nil { + return nil, err + } + return c, nil +} + +type stderrWriter struct{} + +// Always writes to the latest os.Stderr, even if os.Stderr 
changes +// during the lifetime of this object. +func (l *stderrWriter) Write(buf []byte) (int, error) { + return os.Stderr.Write(buf) +} + +type logWriter struct { + logger *log.Logger +} + +func (l *logWriter) Write(buf []byte) (int, error) { + l.logger.Print(string(buf)) + return len(buf), nil +} + +func New(collection string, filePrefix string) *Policy { + statefile := filePrefix + ".log.conf" + var lflags int + if terminal.IsTerminal(2) || runtime.GOOS == "windows" { + lflags = 0 + } else { + lflags = log.LstdFlags + } + console := log.New(&stderrWriter{}, "", lflags) + + var oldc *Config + data, err := ioutil.ReadFile(statefile) + if err != nil { + log.Printf("logpolicy.Read %v: %v\n", statefile, err) + oldc = &Config{} + oldc.Collection = collection + } else { + oldc, err = ConfigFromBytes(data) + if err != nil { + log.Printf("logpolicy.Config unmarshal: %v\n", err) + oldc = &Config{} + } + } + + newc := *oldc + if newc.Collection != collection { + log.Printf("logpolicy.Config: config collection %q does not match %q", newc.Collection, collection) + // We picked up an incompatible config file. + // Regenerate the private ID. + newc.PrivateID = logtail.PrivateID{} + newc.Collection = collection + } + if newc.PrivateID == (logtail.PrivateID{}) { + newc.PrivateID, err = logtail.NewPrivateID() + if err != nil { + log.Fatalf("logpolicy: NewPrivateID() should never fail") + } + } + newc.PublicID = newc.PrivateID.Public() + if newc != *oldc { + newc.Save(statefile) + } + + c := logtail.Config{ + Collection: newc.Collection, + PrivateID: newc.PrivateID, + Stderr: &logWriter{console}, + NewZstdEncoder: func() logtail.Encoder { + w, err := zstd.NewWriter(nil) + if err != nil { + panic(err) + } + return w + }, + } + + // TODO(crawshaw): filePrefix is a place meant to store configuration. + // OS policies usually have other preferred places to + // store logs. Use one of them? 
+ filchBuf, filchErr := filch.New(filePrefix, filch.Options{}) + if filchBuf != nil { + c.Buffer = filchBuf + } + lw := logtail.Log(c) + log.SetFlags(0) // other logflags are set on console, not here + log.SetOutput(lw) + + log.Printf("Program starting: v%v: %#v\n", version.LONG, os.Args) + log.Printf("LogID: %v\n", newc.PublicID) + if filchErr != nil { + log.Printf("filch failed: %v", err) + } + + return &Policy{ + Logtail: lw, + PublicID: newc.PublicID, + } +} + +// Close immediately shuts down the logger. +func (p *Policy) Close() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + p.Shutdown(ctx) +} + +// Shutdown gracefully shuts down the logger, finishing any current +// log upload if it can be done before ctx is canceled. +func (p *Policy) Shutdown(ctx context.Context) error { + log.Printf("flushing log.\n") + if p.Logtail != nil { + return p.Logtail.Shutdown(ctx) + } + return nil +} diff --git a/logtail/.gitignore b/logtail/.gitignore new file mode 100644 index 000000000..0b29b4aca --- /dev/null +++ b/logtail/.gitignore @@ -0,0 +1,6 @@ +*~ +*.out +/example/logadopt/logadopt +/example/logreprocess/logreprocess +/example/logtail/logtail +/logtail diff --git a/logtail/README.md b/logtail/README.md new file mode 100644 index 000000000..20d22c350 --- /dev/null +++ b/logtail/README.md @@ -0,0 +1,10 @@ +# Tailscale Logs Service + +This github repository contains libraries, documentation, and examples +for working with the public API of the tailscale logs service. + +For a very quick introduction to the core features, read the +[API docs](api.md) and peruse the +[logs reprocessing](./example/logreprocess/demo.sh) example. + +For more information, write to info@tailscale.io. 
\ No newline at end of file diff --git a/logtail/api.md b/logtail/api.md new file mode 100644 index 000000000..b2aae98ec --- /dev/null +++ b/logtail/api.md @@ -0,0 +1,195 @@ +# Tailscale Logs Service + +The Tailscale Logs Service defines a REST interface for configuring, storing, +retrieving, and processing log entries. + +# Overview + +HTTP requests are received at the service **base URL** +[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded +responses using standard HTTP response codes. + +Authorization for the configuration and retrieval APIs is done with a secret +API key passed as the HTTP basic auth username. Secret keys are generated via +the web UI at base URL. An example of using basic auth with curl: + + curl -u : https://log.tailscale.io/collections + +In the future, an HTTP header will allow using MessagePack instead of JSON. + +## Collections + +Logs are organized into collections. Inside each collection is any number of +instances. + +A collection is a domain name. It is a grouping of related logs. As a +guideline, create one collection per product using subdomains of your +company's domain name. Collections must be registered with the logs service +before any attempt is made to store logs. + +## Instances + +Each collection is a set of instances. There is one instance per machine +writing logs. + +An instance has a name and a number. An instance has a **private** and +**public** ID. The private ID is a 32-byte random number encoded as hex. +The public ID is the SHA-256 hash of the private ID, encoded as hex. + +The private ID is used to write logs. The only copy of the private ID +should be on the machine sending logs. Ideally it is generated on the +machine. Logs can be written as soon as a private ID is generated. + +The public ID is used to read and adopt logs. It is designed to be sent +to a service that also holds a logs service API key. + +The tailscale logs service will store any logs for a short period of time. 
+To enable logs retention, the log can be **adopted** using the public ID +and a logs service API key. +Once this is done, logs will be retained long-term (for the configured +retention period). + +Unadopted instance logs are stored temporarily to help with debugging: +a misconfigured machine writing logs with a bad ID can be spotted by +reading the logs. +If a public ID is not adopted, storage is tightly capped and logs are +deleted after 12 hours. + +# APIs + +## Storage + +### `POST /c/<collection-name>/<private-ID>` — send a log + +The body of the request is JSON. + +A **single message** is an object with properties: + +`{ }` + +The client may send any properties it wants in the JSON message, except +for the `logtail` property which has special meaning. Inside the logtail +object the client may only set the following properties: + +- `client_time` in the format of RFC3339: "2006-01-02T15:04:05.999999999Z07:00" + +A future version of the logs service API will also support: + +- `client_time_offset` an integer of nanoseconds since the client was reset +- `client_time_reset` a boolean that, if set to true, resets the time offset counter + +On receipt by the server the `client_time_offset` is transformed into a +`client_time` based on the `server_time` when the first (or +client_time_reset) event was received. + +If any other properties are set in the logtail object they are moved into +the "error" field, the message is saved and a 4xx status code is returned. + +A **batch of messages** is a JSON array filled with single message objects: + +`[ { }, { }, ... ]` + +If any of the array entries are not objects, the content is converted +into a message with a `"logtail": { "error": ...}` property, saved, and +a 4xx status code is returned. + +Similarly any other request content not matching one of these formats is +saved in a logtail error field, and a 4xx status code is returned. + +An invalid collection name returns `{"error": "invalid collection name"}` +along with a 403 status code. 
+ +Clients are encouraged to: + +- POST as rapidly as possible (if not battery constrained). This minimizes + both the time necessary to see logs in a log viewer and the chance of + losing logs. +- Use HTTP/2 when streaming logs, as it does a much better job of + maintaining a TLS connection to minimize overhead for subsequent posts. + +A future version of the logs service API will support sending requests with +`Content-Encoding: zstd`. + +## Retrieval + +### `GET /collections` — query the set of collections and instances + +Returns a JSON object listing all of the named collections. + +The caller can query-encode the following fields: + +- `collection-name` — limit the results to one collection + + ``` + { + "collections": { + "collection1.yourcompany.com": { + "instances": { + "<instance-public-ID-1>" :{ + "first-seen": "timestamp", + "size": 4096 + }, + "<instance-public-ID-2>" :{ + "first-seen": "timestamp", + "size": 512000, + "orphan": true + } + } + } + } + } + ``` + +### `GET /c/<collection-name>` — query stored logs + +The caller can query-encode the following fields: + +- `instances` — zero or more log collection instances to limit results to +- `time-start` — the earliest log to include +- One of: + - `time-end` — the latest log to include + - `max-count` — maximum number of logs to return, allows paging + - `stream` — boolean that keeps the response dangling, streaming in + logs like `tail -f`. Incompatible with `time-end`. + +In **stream=false** mode, the response is a single JSON object: + + { + // TODO: header fields + "logs": [ {}, {}, ... ] + } + +In **stream=true** mode, the response begins with a JSON header object +similar to the storage format, and then is a sequence of JSON log +objects, `{...}`, one per line. The server continues to send these until +the client closes the connection. + +## Configuration + +For organizations with a small number of instances writing logs, the +Configuration API is best used by a trusted human operator, usually +through a GUI. 
// MAX_BACKOFF_MSEC caps the base delay computed by BackOff, in
// milliseconds, before randomization is applied.
const MAX_BACKOFF_MSEC = 30000

// Backoff counts consecutive failures and sleeps for a growing,
// randomized delay after each one.
type Backoff struct {
	n int // consecutive failures since the last reset

	// Name is prepended to the log line written for every delay.
	Name string
	// NewTimer, when non-nil, is used in place of time.NewTimer to
	// create the timer that BackOff waits on.
	NewTimer func(d time.Duration) *time.Timer
}

// BackOff records the outcome err of an attempt.
//
// If ctx is already done or err is nil, the failure count is reset and
// BackOff returns immediately. Otherwise the count is incremented and
// BackOff blocks until the delay elapses or ctx is done, whichever comes
// first.
func (b *Backoff) BackOff(ctx context.Context, err error) {
	if ctx.Err() != nil || err == nil {
		// not a regular error
		b.n = 0
		return
	}

	b.n++
	// n^2 backoff timer is a little smoother than the
	// common choice of 2^n.
	base := b.n * b.n * 10
	if base > MAX_BACKOFF_MSEC {
		base = MAX_BACKOFF_MSEC
	}
	// Randomize the delay between 0.5-1.5 x base, in order
	// to prevent accidental "thundering herd" problems.
	delay := rand.Intn(base) + base/2
	log.Printf("%s: backoff: %d msec\n", b.Name, delay)

	makeTimer := b.NewTimer
	if makeTimer == nil {
		makeTimer = time.NewTimer
	}
	timer := makeTimer(time.Duration(delay) * time.Millisecond)
	select {
	case <-timer.C:
	case <-ctx.Done():
		timer.Stop()
	}
}
// memBuffer is an in-memory queue of log lines backed by a buffered
// channel. Lines written while the channel is full are dropped and
// counted; the count is reported to the reader as a synthetic line ahead
// of the next line that is successfully queued.
type memBuffer struct {
	next    []byte      // line held back while a drop notice is returned first
	pending chan qentry // queued lines

	dropMu    sync.Mutex // guards dropCount
	dropCount int        // lines dropped since the last successful Write
}

// TryReadLine returns the next queued line without blocking.
// It returns a nil slice and a nil error when nothing is queued. If
// lines were dropped before the dequeued entry was written, a
// "logs dropped" notice is returned first and the entry's own line is
// returned by the following call.
func (m *memBuffer) TryReadLine() ([]byte, error) {
	if m.next != nil {
		msg := m.next
		m.next = nil
		return msg, nil
	}

	select {
	case ent := <-m.pending:
		if ent.dropCount > 0 {
			m.next = ent.msg
			buf := new(bytes.Buffer)
			fmt.Fprintf(buf, "----------- %d logs dropped ----------", ent.dropCount)
			return buf.Bytes(), nil
		}
		return ent.msg, nil
	default:
		return nil, nil
	}
}

// Write queues a copy of b as one log line. If the queue is full the
// line is dropped, the drop is counted, and errBufferFull is returned
// with a zero byte count.
func (m *memBuffer) Write(b []byte) (int, error) {
	m.dropMu.Lock()
	defer m.dropMu.Unlock()

	// Queue a private copy: io.Writer implementations must not retain
	// the caller's slice, and a caller that reuses its buffer would
	// otherwise overwrite lines that are still waiting to be read.
	ent := qentry{
		msg:       append([]byte(nil), b...),
		dropCount: m.dropCount,
	}
	select {
	case m.pending <- ent:
		m.dropCount = 0
		return len(b), nil
	default:
		m.dropCount++
		return 0, errBufferFull
	}
}

// qentry is one queued log line together with the number of lines that
// were dropped immediately before it was written.
type qentry struct {
	msg       []byte
	dropCount int
}

// errBufferFull is returned by memBuffer.Write when the queue is full.
var errBufferFull = errors.New("logtail: buffer full")
// main posts a request to the logs service asking it to adopt the
// instance given by -m into the collection given by -c, authenticating
// with the API key given by -p as the basic-auth username. The response
// body is printed on success; any failure is fatal.
func main() {
	var (
		collection = flag.String("c", "", "logtail collection name")
		publicID   = flag.String("m", "", "machine public identifier")
		apiKey     = flag.String("p", "", "logtail API key")
	)
	flag.Parse()
	if flag.NArg() != 0 {
		flag.Usage()
		os.Exit(1)
	}
	log.SetFlags(0)

	form := url.Values{
		"collection": {*collection},
		"instances":  {*publicID},
		"adopt":      {"true"},
	}
	req, err := http.NewRequest("POST", "https://log.tailscale.io/instances", strings.NewReader(form.Encode()))
	if err != nil {
		log.Fatal(err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.SetBasicAuth(*apiKey, "")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Fatal(err)
	}
	body, err := ioutil.ReadAll(resp.Body)
	resp.Body.Close()
	switch {
	case err != nil:
		log.Fatalf("logadopt: response read failed %d: %v", resp.StatusCode, err)
	case resp.StatusCode != 200:
		log.Fatalf("adoption failed: %d: %s", resp.StatusCode, string(body))
	}
	log.Printf("%s", string(body))
}
+trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT +set -e + +LOG_TEXT='server starting +config file loaded +answering queries +Traceback (most recent call last): + File "/Users/crawshaw/junk.py", line 6, in + main() + File "/Users/crawshaw/junk.py", line 4, in main + raise Exception("oops") +Exception: oops' + +die() { + echo "$0: $*" >&2 + exit 1 +} + +msg() { + echo "-- $*" >&2 +} + +if [ -z "$LOGTAIL_API_KEY" ]; then + die "LOGTAIL_API_KEY is not set" +fi + +if [ -z "$COLLECTION_IN" ]; then + die "COLLECTION_IN is not set" +fi + +if [ -z "$COLLECTION_OUT" ]; then + die "COLLECTION_OUT is not set" +fi + +# Private IDs are 32-bytes of random hex. +# Normally you'd keep the same private IDs from one run to the next, but +# this is just an example. +msg "Generating keys..." +privateid1=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid2=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) +privateid3=$(hexdump -n 32 -e '8/4 "%08X"' /dev/urandom) + +# Public IDs are the SHA-256 of the private ID. +publicid1=$(echo -n $privateid1 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid2=$(echo -n $privateid2 | xxd -r -p - | shasum -a 256 | sed 's/ -//') +publicid3=$(echo -n $privateid3 | xxd -r -p - | shasum -a 256 | sed 's/ -//') + +# Write the machine logs to the input collection. +# Notice that this doesn't require an API key. +msg "Producing new logs..." +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid1 >/dev/null +echo "$LOG_TEXT" | logtail -c $COLLECTION_IN -k $privateid2 >/dev/null + +# Adopt the logs, so they will be kept and are readable. +msg "Adopting logs..." +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid1 +logadopt -p "$LOGTAIL_API_KEY" -c "$COLLECTION_IN" -m $publicid2 + +# Reprocess the logs, amalgamating python tracebacks. +# +# We'll take that reprocessed output and write it to a separate collection, +# again via logtail. 
+# +# Time out quickly because all our "interesting" logs (generated +# above) have already been processed. +msg "Reprocessing logs..." +logreprocess -t 3s -c "$COLLECTION_IN" -p "$LOGTAIL_API_KEY" 2>&1 | + logtail -c "$COLLECTION_OUT" -k $privateid3 diff --git a/logtail/example/logreprocess/logreprocess.go b/logtail/example/logreprocess/logreprocess.go new file mode 100644 index 000000000..c65565b9e --- /dev/null +++ b/logtail/example/logreprocess/logreprocess.go @@ -0,0 +1,116 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The logreprocess program tails a log and reprocesses it. +package main + +import ( + "bufio" + "encoding/json" + "flag" + "io/ioutil" + "log" + "net/http" + "os" + "strings" + "time" + + "tailscale.com/logtail" +) + +func main() { + collection := flag.String("c", "", "logtail collection name to read") + apiKey := flag.String("p", "", "logtail API key") + timeout := flag.Duration("t", 0, "timeout after which logreprocess quits") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + log.SetFlags(0) + + if *timeout != 0 { + go func() { + <-time.After(*timeout) + log.Printf("logreprocess: timeout reached, quitting") + os.Exit(1) + }() + } + + req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) + if err != nil { + log.Fatal(err) + } + req.SetBasicAuth(*apiKey, "") + resp, err := http.DefaultClient.Do(req) + if err != nil { + log.Fatal(err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + log.Fatalf("logreprocess: read error %d: %v", resp.StatusCode, err) + } + log.Fatalf("logreprocess: read error %d: %s", resp.StatusCode, string(b)) + } + + tracebackCache := make(map[logtail.PublicID]*ProcessedMsg) + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + var msg Msg 
+ if err := json.Unmarshal(scanner.Bytes(), &msg); err != nil { + log.Fatalf("logreprocess of %q: %v", string(scanner.Bytes()), err) + } + var pMsg *ProcessedMsg + if pMsg = tracebackCache[msg.Logtail.Instance]; pMsg != nil { + pMsg.Text += "\n" + msg.Text + if strings.HasPrefix(msg.Text, "Exception: ") { + delete(tracebackCache, msg.Logtail.Instance) + } else { + continue // write later + } + } else { + pMsg = &ProcessedMsg{ + OrigInstance: msg.Logtail.Instance, + Text: msg.Text, + } + pMsg.Logtail.ClientTime = msg.Logtail.ClientTime + } + + if strings.HasPrefix(msg.Text, "Traceback (most recent call last):") { + tracebackCache[msg.Logtail.Instance] = pMsg + continue // write later + } + + b, err := json.Marshal(pMsg) + if err != nil { + log.Fatal(err) + } + log.Printf("%s", b) + } + if err := scanner.Err(); err != nil { + log.Fatal(err) + } +} + +type Msg struct { + Logtail struct { + Instance logtail.PublicID `json:"instance"` + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + Text string `json:"text"` +} + +type ProcessedMsg struct { + Logtail struct { + ClientTime time.Time `json:"client_time"` + } `json:"logtail"` + + OrigInstance logtail.PublicID `json:"orig_instance"` + Text string `json:"text"` +} diff --git a/logtail/example/logtail/logtail.go b/logtail/example/logtail/logtail.go new file mode 100644 index 000000000..d56ee4585 --- /dev/null +++ b/logtail/example/logtail/logtail.go @@ -0,0 +1,46 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The logtail program logs stdin. 
+package main + +import ( + "bufio" + "flag" + "io" + "log" + "os" + + "tailscale.com/logtail" +) + +func main() { + collection := flag.String("c", "", "logtail collection name") + privateID := flag.String("k", "", "machine private identifier, 32-bytes in hex") + flag.Parse() + if len(flag.Args()) != 0 { + flag.Usage() + os.Exit(1) + } + + log.SetFlags(0) + + var id logtail.PrivateID + if err := id.UnmarshalText([]byte(*privateID)); err != nil { + log.Fatalf("logtail: bad -privateid: %v", err) + } + + logger := logtail.Log(logtail.Config{ + Collection: *collection, + PrivateID: id, + }) + log.SetOutput(io.MultiWriter(logger, os.Stdout)) + defer logger.Flush() + defer log.Printf("logtail exited") + + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + log.Println(scanner.Text()) + } +} diff --git a/logtail/filch/filch.go b/logtail/filch/filch.go new file mode 100644 index 000000000..64c983626 --- /dev/null +++ b/logtail/filch/filch.go @@ -0,0 +1,238 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package filch is a file system queue that pilfers your stderr. +// (A FILe CHannel that filches.) 
+package filch + +import ( + "bufio" + "bytes" + "fmt" + "io" + "os" + "sync" +) + +var stderrFD = 2 // a variable for testing + +type Options struct { + ReplaceStderr bool // dup over fd 2 so everything written to stderr comes here +} + +type Filch struct { + OrigStderr *os.File + + mu sync.Mutex + cur *os.File + alt *os.File + altscan *bufio.Scanner + recovered int64 +} + +func (f *Filch) TryReadLine() ([]byte, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.altscan != nil { + if b, err := f.scan(); b != nil || err != nil { + return b, err + } + } + + f.cur, f.alt = f.alt, f.cur + if f.OrigStderr != nil { + if err := dup2Stderr(f.cur); err != nil { + return nil, err + } + } + if _, err := f.alt.Seek(0, os.SEEK_SET); err != nil { + return nil, err + } + f.altscan = bufio.NewScanner(f.alt) + f.altscan.Split(splitLines) + return f.scan() +} + +func (f *Filch) scan() ([]byte, error) { + if f.altscan.Scan() { + return f.altscan.Bytes(), nil + } + err := f.altscan.Err() + err2 := f.alt.Truncate(0) + _, err3 := f.alt.Seek(0, os.SEEK_SET) + f.altscan = nil + if err != nil { + return nil, err + } + if err2 != nil { + return nil, err2 + } + if err3 != nil { + return nil, err3 + } + return nil, nil +} + +func (f *Filch) Write(b []byte) (int, error) { + f.mu.Lock() + defer f.mu.Unlock() + + if len(b) == 0 || b[len(b)-1] != '\n' { + bnl := make([]byte, len(b)+1) + copy(bnl, b) + bnl[len(bnl)-1] = '\n' + return f.cur.Write(bnl) + } + return f.cur.Write(b) +} + +func (f *Filch) Close() (err error) { + f.mu.Lock() + defer f.mu.Unlock() + + if f.OrigStderr != nil { + if err2 := unsaveStderr(f.OrigStderr); err == nil { + err = err2 + } + f.OrigStderr = nil + } + + if err2 := f.cur.Close(); err == nil { + err = err2 + } + if err2 := f.alt.Close(); err == nil { + err = err2 + } + + return err +} + +func New(filePrefix string, opts Options) (f *Filch, err error) { + var f1, f2 *os.File + defer func() { + if err != nil { + if f1 != nil { + f1.Close() + } + if f2 != nil { + 
f2.Close() + } + err = fmt.Errorf("filch: %s", err) + } + }() + + path1 := filePrefix + ".log1.txt" + path2 := filePrefix + ".log2.txt" + + f1, err = os.OpenFile(path1, os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return nil, err + } + f2, err = os.OpenFile(path2, os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return nil, err + } + + fi1, err := f1.Stat() + if err != nil { + return nil, err + } + fi2, err := f2.Stat() + if err != nil { + return nil, err + } + + f = &Filch{ + OrigStderr: os.Stderr, // temporary, for past logs recovery + } + + // Neither, either, or both files may exist and contain logs from + // the last time the process ran. The three cases are: + // + // - neither: all logs were read out and files were truncated + // - either: logs were being written into one of the files + // - both: the files were swapped and were starting to be + // read out, while new logs streamed into the other + // file, but the read out did not complete + if n := fi1.Size() + fi2.Size(); n > 0 { + f.recovered = n + } + switch { + case fi1.Size() > 0 && fi2.Size() == 0: + f.cur, f.alt = f2, f1 + case fi2.Size() > 0 && fi1.Size() == 0: + f.cur, f.alt = f1, f2 + case fi1.Size() > 0 && fi2.Size() > 0: // both + // We need to pick one of the files to be the elder, + // which we do using the mtime. 
// moveContents appends the entire contents of src to the end of dst and
// then empties src.
//
// Whether or not the copy succeeds, on return src has been rewound and
// truncated to zero length and dst has been rewound to its start. The
// first error encountered is returned.
func moveContents(dst, src *os.File) (err error) {
	defer func() {
		_, err2 := src.Seek(0, io.SeekStart)
		err3 := src.Truncate(0)
		_, err4 := dst.Seek(0, io.SeekStart)
		if err == nil {
			err = err2
		}
		if err == nil {
			err = err3
		}
		if err == nil {
			err = err4
		}
	}()
	if _, err := src.Seek(0, io.SeekStart); err != nil {
		return err
	}
	// Position dst at its end so the copy appends. Without this, a
	// freshly opened dst (offset 0) has its existing content overwritten.
	if _, err := dst.Seek(0, io.SeekEnd); err != nil {
		return err
	}
	if _, err := io.Copy(dst, src); err != nil {
		return err
	}
	return nil
}

// splitLines is a bufio.SplitFunc that yields one line per token. Unlike
// bufio.ScanLines, the token keeps its trailing '\n'. A final unterminated
// line is returned as-is at EOF.
func splitLines(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}
	if i := bytes.IndexByte(data, '\n'); i >= 0 {
		return i + 1, data[0 : i+1], nil
	}
	if atEOF {
		return len(data), data, nil
	}
	// No complete line yet: request more data.
	return 0, nil, nil
}
+ +package filch + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + "testing" + "unicode" +) + +type filchTest struct { + *Filch +} + +func newFilchTest(t *testing.T, filePrefix string, opts Options) *filchTest { + f, err := New(filePrefix, opts) + if err != nil { + t.Fatal(err) + } + return &filchTest{Filch: f} +} + +func (f *filchTest) write(t *testing.T, s string) { + t.Helper() + if _, err := f.Write([]byte(s)); err != nil { + t.Fatal(err) + } +} + +func (f *filchTest) read(t *testing.T, want string) { + t.Helper() + if b, err := f.TryReadLine(); err != nil { + t.Fatalf("r.ReadLine() err=%v", err) + } else if got := strings.TrimRightFunc(string(b), unicode.IsSpace); got != want { + t.Errorf("r.ReadLine()=%q, want %q", got, want) + } +} + +func (f *filchTest) readEOF(t *testing.T) { + t.Helper() + if b, err := f.TryReadLine(); b != nil || err != nil { + t.Fatalf("r.ReadLine()=%q err=%v, want nil slice", string(b), err) + } +} + +func (f *filchTest) close(t *testing.T) { + t.Helper() + if err := f.Close(); err != nil { + t.Fatal(err) + } +} + +func genFilePrefix(t *testing.T) string { + t.Helper() + filePrefix, err := ioutil.TempDir("", "filch") + if err != nil { + t.Fatal(err) + } + return filepath.Join(filePrefix, "ringbuffer-") +} + +func TestQueue(t *testing.T) { + filePrefix := genFilePrefix(t) + defer os.RemoveAll(filepath.Dir(filePrefix)) + + f := newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + + f.readEOF(t) + const line1 = "Hello, World!" + const line2 = "This is a test." + const line3 = "Of filch." 
+ f.write(t, line1) + f.write(t, line2) + f.read(t, line1) + f.write(t, line3) + f.read(t, line2) + f.read(t, line3) + f.readEOF(t) + f.write(t, line1) + f.read(t, line1) + f.readEOF(t) + f.close(t) +} + +func TestRecover(t *testing.T) { + t.Run("empty", func(t *testing.T) { + filePrefix := genFilePrefix(t) + defer os.RemoveAll(filepath.Dir(filePrefix)) + f := newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + f.write(t, "hello") + f.read(t, "hello") + f.readEOF(t) + f.close(t) + + f = newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + f.readEOF(t) + f.close(t) + }) + + t.Run("cur", func(t *testing.T) { + filePrefix := genFilePrefix(t) + defer os.RemoveAll(filepath.Dir(filePrefix)) + f := newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + f.write(t, "hello") + f.close(t) + + f = newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + f.read(t, "hello") + f.readEOF(t) + f.close(t) + }) + + t.Run("alt", func(t *testing.T) { + t.Skip("currently broken on linux, passes on macOS") + /* --- FAIL: TestRecover/alt (0.00s) + filch_test.go:128: r.ReadLine()="world", want "hello" + filch_test.go:129: r.ReadLine()="hello", want "world" + */ + + filePrefix := genFilePrefix(t) + defer os.RemoveAll(filepath.Dir(filePrefix)) + f := newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + f.write(t, "hello") + f.read(t, "hello") + f.write(t, "world") + f.close(t) + + f = newFilchTest(t, filePrefix, Options{ReplaceStderr: false}) + // TODO(crawshaw): The "hello" log is replayed in recovery. + // We could reduce replays by risking some logs loss. + // What should our policy here be? 
+ f.read(t, "hello") + f.read(t, "world") + f.readEOF(t) + f.close(t) + }) +} + +func TestFilchStderr(t *testing.T) { + pipeR, pipeW, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + defer pipeR.Close() + defer pipeW.Close() + + stderrFD = int(pipeW.Fd()) + defer func() { + stderrFD = 2 + }() + + filePrefix := genFilePrefix(t) + defer os.RemoveAll(filepath.Dir(filePrefix)) + f := newFilchTest(t, filePrefix, Options{ReplaceStderr: true}) + f.write(t, "hello") + if _, err := fmt.Fprintf(pipeW, "filch\n"); err != nil { + t.Fatal(err) + } + f.read(t, "hello") + f.read(t, "filch") + f.readEOF(t) + f.close(t) + + pipeW.Close() + b, err := ioutil.ReadAll(pipeR) + if err != nil { + t.Fatal(err) + } + if len(b) > 0 { + t.Errorf("unexpected write to fake stderr: %s", b) + } +} diff --git a/logtail/filch/filch_unix.go b/logtail/filch/filch_unix.go new file mode 100644 index 000000000..aadf66986 --- /dev/null +++ b/logtail/filch/filch_unix.go @@ -0,0 +1,30 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//+build !windows + +package filch + +import ( + "os" + "syscall" +) + +func saveStderr() (*os.File, error) { + fd, err := syscall.Dup(stderrFD) + if err != nil { + return nil, err + } + return os.NewFile(uintptr(fd), "stderr"), nil +} + +func unsaveStderr(f *os.File) error { + err := dup2Stderr(f) + f.Close() + return err +} + +func dup2Stderr(f *os.File) error { + return syscall.Dup2(int(f.Fd()), stderrFD) +} diff --git a/logtail/filch/filch_windows.go b/logtail/filch/filch_windows.go new file mode 100644 index 000000000..1dba9d50c --- /dev/null +++ b/logtail/filch/filch_windows.go @@ -0,0 +1,44 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package filch + +import ( + "fmt" + "os" + "syscall" +) + +var kernel32 = syscall.MustLoadDLL("kernel32.dll") +var procSetStdHandle = kernel32.MustFindProc("SetStdHandle") + +func setStdHandle(stdHandle int32, handle syscall.Handle) error { + r, _, e := syscall.Syscall(procSetStdHandle.Addr(), 2, uintptr(stdHandle), uintptr(handle), 0) + if r == 0 { + if e != 0 { + return error(e) + } + return syscall.EINVAL + } + return nil +} + +func saveStderr() (*os.File, error) { + return os.Stderr, nil +} + +func unsaveStderr(f *os.File) error { + os.Stderr = f + return nil +} + +func dup2Stderr(f *os.File) error { + fd := int(f.Fd()) + err := setStdHandle(syscall.STD_ERROR_HANDLE, syscall.Handle(fd)) + if err != nil { + return fmt.Errorf("dup2Stderr: %w", err) + } + os.Stderr = f + return nil +} diff --git a/logtail/id.go b/logtail/id.go new file mode 100644 index 000000000..5f78bd744 --- /dev/null +++ b/logtail/id.go @@ -0,0 +1,103 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package logtail + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" +) + +// PrivateID represents an instance that write logs. +// Private IDs are only shared with the server when writing logs. +type PrivateID [32]byte + +// Safely generate a new PrivateId for use in Config objects. +// You should persist this across runs of an instance of your app, so that +// it can append to the same log file on each run. +func NewPrivateID() (id PrivateID, err error) { + _, err = rand.Read(id[:]) + if err != nil { + return PrivateID{}, err + } + // Clamping, for future use. 
// PrivateID represents an instance that write logs.
// Private IDs are only shared with the server when writing logs.
type PrivateID [32]byte

// NewPrivateID safely generates a new PrivateID for use in Config objects.
// You should persist this across runs of an instance of your app, so that
// it can append to the same log file on each run.
func NewPrivateID() (id PrivateID, err error) {
	if _, err = rand.Read(id[:]); err != nil {
		return PrivateID{}, err
	}
	// Clamping, for future use.
	id[0] &= 248
	id[31] = (id[31] & 127) | 64
	return id, nil
}

// MarshalText returns the lowercase hex encoding of id. The error is
// always nil.
func (id PrivateID) MarshalText() ([]byte, error) {
	b := make([]byte, hex.EncodedLen(len(id)))
	hex.Encode(b, id[:]) // always writes exactly len(b) bytes
	return b, nil
}

// UnmarshalText parses s as the hex encoding of a PrivateID.
// id is left unchanged if s is not valid hex of the right length.
func (id *PrivateID) UnmarshalText(s []byte) error {
	b, err := hex.DecodeString(string(s))
	if err != nil {
		return fmt.Errorf("logtail.PrivateID.UnmarshalText: %w", err)
	}
	if len(b) != len(id) {
		return fmt.Errorf("logtail.PrivateID.UnmarshalText: invalid hex length: %d", len(b))
	}
	copy(id[:], b)
	return nil
}

// String returns the hex encoding of id.
func (id PrivateID) String() string {
	return hex.EncodeToString(id[:])
}

// Public returns the PublicID derived from id: the SHA-256 hash of the
// private ID bytes. It panics if id is the zero value.
func (id PrivateID) Public() (pub PublicID) {
	var emptyID PrivateID
	if id == emptyID {
		panic("invalid logtail.Public() on an empty private ID")
	}
	return PublicID(sha256.Sum256(id[:]))
}

// PublicID represents an instance in the logs service for reading and adoption.
// The public ID value is a SHA-256 hash of a private ID.
type PublicID [sha256.Size]byte

// MarshalText returns the lowercase hex encoding of id. The error is
// always nil.
func (id PublicID) MarshalText() ([]byte, error) {
	b := make([]byte, hex.EncodedLen(len(id)))
	hex.Encode(b, id[:]) // always writes exactly len(b) bytes
	return b, nil
}

// UnmarshalText parses s as the hex encoding of a PublicID.
// id is left unchanged if s is not valid hex of the right length.
func (id *PublicID) UnmarshalText(s []byte) error {
	b, err := hex.DecodeString(string(s))
	if err != nil {
		return fmt.Errorf("logtail.PublicID.UnmarshalText: %w", err)
	}
	if len(b) != len(id) {
		return fmt.Errorf("logtail.PublicID.UnmarshalText: invalid hex length: %d", len(b))
	}
	copy(id[:], b)
	return nil
}

// String returns the hex encoding of id.
func (id PublicID) String() string {
	return hex.EncodeToString(id[:])
}
+ +package logtail + +import ( + "testing" +) + +func TestIDs(t *testing.T) { + id1, err := NewPrivateID() + if err != nil { + t.Fatal(err) + } + pub1 := id1.Public() + + id2, err := NewPrivateID() + if err != nil { + t.Fatal(err) + } + pub2 := id2.Public() + + if id1 == id2 { + t.Fatalf("subsequent private IDs match: %v", id1) + } + if pub1 == pub2 { + t.Fatalf("subsequent public IDs match: %v", id1) + } + if id1.String() == id2.String() { + t.Fatalf("id1.String()=%v equals id2.String()", id1.String()) + } + if pub1.String() == pub2.String() { + t.Fatalf("pub1.String()=%v equals pub2.String()", pub1.String()) + } + + id1txt, err := id1.MarshalText() + if err != nil { + t.Fatal(err) + } + var id3 PrivateID + if err := id3.UnmarshalText(id1txt); err != nil { + t.Fatal(err) + } + if id1 != id3 { + t.Fatalf("id1 %v: marshal and unmarshal gives different key: %v", id1, id3) + } + if want, got := id1.Public(), id3.Public(); want != got { + t.Fatalf("id1.Public()=%v does not match id3.Public()=%v", want, got) + } + if id1.String() != id3.String() { + t.Fatalf("id1.String()=%v does not match id3.String()=%v", id1.String(), id3.String()) + } +} diff --git a/logtail/logtail.go b/logtail/logtail.go new file mode 100644 index 000000000..41dae40fa --- /dev/null +++ b/logtail/logtail.go @@ -0,0 +1,464 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package logtail sends logs to log.tailscale.io. +package logtail + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "math/big" + "net/http" + "os" + "strconv" + "sync" + "time" + + "tailscale.com/logtail/backoff" +) + +type Logger interface { + // Write logs an encoded JSON blob. + // + // If the []byte passed to Write is not an encoded JSON blob, + // then contents is fit into a JSON blob and written. 
+ // + // This is intended as an interface for the stdlib "log" package. + Write([]byte) (int, error) + + // Flush uploads all logs to the server. + // It blocks until complete or there is an unrecoverable error. + Flush() error + + // Shutdown gracefully shuts down the logger while completing any + // remaining uploads. + // + // It will block, continuing to try and upload unless the passed + // context object interrupts it by being done. + // If the shutdown is interrupted, an error is returned. + Shutdown(context.Context) error + + // Close shuts down this logger object, the background log uploader + // process, and any associated goroutines. + // + // DEPRECATED: use Shutdown + Close() +} + +type Encoder interface { + EncodeAll(src, dst []byte) []byte + Close() error +} + +type Config struct { + Collection string // collection name, a domain name + PrivateID PrivateID // machine-specific private identifier + BaseURL string // if empty defaults to "https://log.tailscale.io" + HTTPC *http.Client // if empty defaults to http.DefaultClient + SkipClientTime bool // if true, client_time is not written to logs + LowMemory bool // if true, logtail minimizes memory use + TimeNow func() time.Time // if set, subsitutes uses of time.Now + Stderr io.Writer // if set, logs are sent here instead of os.Stderr + Buffer Buffer // temp storage, if nil a MemoryBuffer + CheckLogs <-chan struct{} // signals Logger to check for filched logs to upload + NewZstdEncoder func() Encoder // if set, used to compress logs for transmission +} + +func Log(cfg Config) Logger { + if cfg.BaseURL == "" { + cfg.BaseURL = "https://log.tailscale.io" + } + if cfg.HTTPC == nil { + cfg.HTTPC = http.DefaultClient + } + if cfg.TimeNow == nil { + cfg.TimeNow = time.Now + } + if cfg.Stderr == nil { + cfg.Stderr = os.Stderr + } + if cfg.Buffer == nil { + pendingSize := 256 + if cfg.LowMemory { + pendingSize = 64 + } + cfg.Buffer = NewMemoryBuffer(pendingSize) + } + if cfg.CheckLogs == nil { + cfg.CheckLogs = 
make(chan struct{}) + } + l := &logger{ + stderr: cfg.Stderr, + httpc: cfg.HTTPC, + url: cfg.BaseURL + "/c/" + cfg.Collection + "/" + cfg.PrivateID.String(), + lowMem: cfg.LowMemory, + buffer: cfg.Buffer, + skipClientTime: cfg.SkipClientTime, + sent: make(chan struct{}, 1), + sentinel: make(chan int32, 16), + checkLogs: cfg.CheckLogs, + timeNow: cfg.TimeNow, + bo: backoff.Backoff{ + Name: "logtail", + }, + + shutdownStart: make(chan struct{}), + shutdownDone: make(chan struct{}), + } + if cfg.NewZstdEncoder != nil { + l.zstdEncoder = cfg.NewZstdEncoder() + } + + ctx, cancel := context.WithCancel(context.Background()) + l.uploadCancel = cancel + + go l.uploading(ctx) + l.Write([]byte("logtail started")) + return l +} + +type logger struct { + stderr io.Writer + httpc *http.Client + url string + lowMem bool + skipClientTime bool + buffer Buffer + sent chan struct{} // signal to speed up drain + checkLogs <-chan struct{} // external signal to attempt a drain + sentinel chan int32 + timeNow func() time.Time + bo backoff.Backoff + zstdEncoder Encoder + uploadCancel func() + + shutdownStart chan struct{} // closed when shutdown begins + shutdownDone chan struct{} // closd when shutdown complete + + dropMu sync.Mutex + dropCount int +} + +func (l *logger) Shutdown(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + l.uploadCancel() + <-l.shutdownDone + case <-l.shutdownDone: + } + close(done) + }() + + close(l.shutdownStart) + io.WriteString(l, "logger closing down\n") + <-done + + if l.zstdEncoder != nil { + return l.zstdEncoder.Close() + } + return nil +} + +func (l *logger) Close() { + l.Shutdown(nil) +} + +func (l *logger) drainPending() (res []byte) { + buf := new(bytes.Buffer) + entries := 0 + + var batchDone bool + for buf.Len() < 1<<18 && !batchDone { + b, err := l.buffer.TryReadLine() + if err == io.EOF { + break + } else if err != nil { + b = 
[]byte(fmt.Sprintf("reading ringbuffer: %v", err)) + batchDone = true + } else if b == nil { + if entries > 0 { + break + } + + select { + case <-l.shutdownStart: + batchDone = true + case <-l.checkLogs: + case <-l.sent: + } + continue + } + + if len(b) == 0 { + continue + } + if b[0] != '{' || !json.Valid(b) { + // This is probably a log added to stderr by filch + // outside of the logtail logger. Encode it. + // Do not add a client time, as it could have been + // been written a long time ago. + b = l.encodeText(b, true) + } + + switch { + case entries == 0: + buf.Write(b) + case entries == 1: + buf2 := new(bytes.Buffer) + buf2.WriteByte('[') + buf2.Write(buf.Bytes()) + buf2.WriteByte(',') + buf2.Write(b) + buf.Reset() + buf.Write(buf2.Bytes()) + default: + buf.WriteByte(',') + buf.Write(b) + } + entries++ + } + + if entries > 1 { + buf.WriteByte(']') + } + if buf.Len() == 0 { + return nil + } + return buf.Bytes() +} + +var clientSentinelPrefix = []byte(`{"logtail":{"client_sentinel":`) + +const ( + noSentinel = 0 + stopSentinel = 1 +) + +// newSentinel creates a client sentinel between 2 and maxint32. +// It does not generate the reserved values: +// 0 is no sentinel +// 1 is stop the logger +func newSentinel() ([]byte, int32) { + val, err := rand.Int(rand.Reader, big.NewInt(1<<31-2)) + if err != nil { + panic(err) + } + v := int32(val.Int64()) + 2 + + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "%s%d}}\n", clientSentinelPrefix, v) + return buf.Bytes(), v +} + +// readSentinel reads a sentinel. +// If it is not a sentinel it reports 0. +func readSentinel(b []byte) int32 { + if !bytes.HasPrefix(b, clientSentinelPrefix) { + return 0 + } + b = bytes.TrimPrefix(b, clientSentinelPrefix) + b = bytes.TrimSuffix(bytes.TrimSpace(b), []byte("}}")) + v, err := strconv.Atoi(string(b)) + if err != nil { + return 0 + } + return int32(v) +} + +// This is the goroutine that repeatedly uploads logs in the background. 
+func (l *logger) uploading(ctx context.Context) { + defer close(l.shutdownDone) + + for { + body := l.drainPending() + if l.zstdEncoder != nil { + body = l.zstdEncoder.EncodeAll(body, nil) + } + + for len(body) > 0 { + select { + case <-ctx.Done(): + return + default: + } + uploaded, err := l.upload(ctx, body) + if err != nil { + fmt.Fprintf(l.stderr, "logtail: upload: %v\n", err) + } + if uploaded { + break + } + l.bo.BackOff(ctx, err) + } + + select { + case <-l.shutdownStart: + return + default: + } + } +} + +func (l *logger) upload(ctx context.Context, body []byte) (uploaded bool, err error) { + req, err := http.NewRequest("POST", l.url, bytes.NewReader(body)) + if err != nil { + // I know of no conditions under which this could fail. + // Report it very loudly. + // TODO record logs to disk + panic("logtail: cannot build http request: " + err.Error()) + } + if l.zstdEncoder != nil { + req.Header.Add("Content-Encoding", "zstd") + } + + maxUploadTime := 45 * time.Second + ctx, cancel := context.WithTimeout(ctx, maxUploadTime) + defer cancel() + req = req.WithContext(ctx) + + compressedNote := "not-compressed" + if l.zstdEncoder != nil { + compressedNote = "compressed" + } + + resp, err := l.httpc.Do(req) + if err != nil { + return false, fmt.Errorf("log upload of %d bytes %s failed: %v", len(body), compressedNote, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + uploaded = resp.StatusCode == 400 // the server saved the logs anyway + b, _ := ioutil.ReadAll(resp.Body) + return uploaded, fmt.Errorf("log upload of %d bytes %s failed %d: %q", len(body), compressedNote, resp.StatusCode, string(b)) + } + return true, nil +} + +func (l *logger) Flush() error { + return nil +} + +var errHasLogtail = errors.New("logtail: JSON log message contains reserved 'logtail' property") + +func (l *logger) send(jsonBlob []byte) (int, error) { + n, err := l.buffer.Write(jsonBlob) + select { + case l.sent <- struct{}{}: + default: + } + return n, err +} + +func (l 
*logger) encodeText(buf []byte, skipClientTime bool) []byte { + now := l.timeNow() + + b := make([]byte, 0, len(buf)+16) + b = append(b, '{') + + if !skipClientTime { + b = append(b, `"logtail": {"client_time": "`...) + b = now.AppendFormat(b, time.RFC3339Nano) + b = append(b, "\"}, "...) + } + + b = append(b, "\"text\": \""...) + for i, c := range buf { + switch c { + case '\b': + b = append(b, '\\', 'b') + case '\f': + b = append(b, '\\', 'f') + case '\n': + b = append(b, '\\', 'n') + case '\r': + b = append(b, '\\', 'r') + case '\t': + b = append(b, '\\', 't') + case '"': + b = append(b, '\\', '"') + case '\\': + b = append(b, '\\', '\\') + default: + b = append(b, c) + } + if l.lowMem && i > 254 { + b = append(b, "…"...) + break + } + } + b = append(b, "\"}\n"...) + return b +} + +func (l *logger) encode(buf []byte) []byte { + if buf[0] != '{' { + return l.encodeText(buf, l.skipClientTime) // text fast-path + } + + now := l.timeNow() + + obj := make(map[string]interface{}) + if err := json.Unmarshal(buf, &obj); err != nil { + for k := range obj { + delete(obj, k) + } + obj["text"] = string(buf) + } + if txt, isStr := obj["text"].(string); l.lowMem && isStr && len(txt) > 254 { + // TODO(crawshaw): trim to unicode code point + obj["text"] = txt[:254] + "…" + } + + hasLogtail := obj["logtail"] != nil + if hasLogtail { + obj["error_has_logtail"] = obj["logtail"] + obj["logtail"] = nil + } + if !l.skipClientTime { + obj["logtail"] = map[string]string{ + "client_time": now.Format(time.RFC3339Nano), + } + } + + b, err := json.Marshal(obj) + if err != nil { + fmt.Fprintf(l.stderr, "logtail: re-encoding JSON failed: %v\n", err) + // I know of no conditions under which this could fail. + // Report it very loudly. 
+ panic("logtail: re-encoding JSON failed: " + err.Error()) + } + b = append(b, '\n') + return b +} + +func (l *logger) Write(buf []byte) (int, error) { + if len(buf) == 0 { + return 0, nil + } + if l.stderr != nil && l.stderr != ioutil.Discard { + if buf[len(buf)-1] == '\n' { + l.stderr.Write(buf) + } else { + // The log package always line-terminates logs, + // so this is an uncommon path. + bufnl := make([]byte, len(buf)+1) + copy(bufnl, buf) + bufnl[len(bufnl)-1] = '\n' + l.stderr.Write(bufnl) + } + } + b := l.encode(buf) + return l.send(b) +} diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go new file mode 100644 index 000000000..13c115d8c --- /dev/null +++ b/logtail/logtail_test.go @@ -0,0 +1,20 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package logtail + +import ( + "context" + "testing" +) + +func TestFastShutdown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + l := Log(Config{ + BaseURL: "http://localhost:1234", + }) + l.Shutdown(ctx) +} diff --git a/portlist/netstat.go b/portlist/netstat.go new file mode 100644 index 000000000..1d1bc8731 --- /dev/null +++ b/portlist/netstat.go @@ -0,0 +1,155 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
// parsePort extracts the port from a netstat address column.
//
// The port is whatever follows the last ':' or '.' in s, which covers
// "a.b.c.d:1234", "[a:b:c:d]:1234", "a.b.c.d.1234" and "[a:b:c:d].1234".
// It returns 0 for a wildcard port ("*") and -1 if s has no separator or
// the port is not a valid 16-bit unsigned number.
func parsePort(s string) int {
	i := strings.LastIndexAny(s, ":.")
	if i < 0 {
		// no match; weird
		return -1
	}

	portstr := s[i+1:]
	if portstr == "*" {
		return 0
	}

	port, err := strconv.ParseUint(portstr, 10, 16)
	if err != nil {
		// invalid port; weird
		return -1
	}

	return int(port)
}

// nothing is the zero-width value type for maps used as sets.
type nothing struct{}
+ p := lastport + delete(m, lastport) + proc := trimline[1 : len(trimline)-1] + if proc == "svchost.exe" && lastline != "" { + p.Process = lastline + } else { + if strings.HasSuffix(proc, ".exe") { + p.Process = proc[:len(proc)-4] + } else { + p.Process = proc + } + } + m[p] = nothing{} + } else { + // not interested in other protocols + lastline = trimline + continue + } + + lport := parsePort(laddr) + rport := parsePort(raddr) + if rport != 0 || lport <= 0 { + // not interested in "connected" sockets + continue + } + + p := Port{ + Proto: proto, + Port: uint16(lport), + } + m[p] = nothing{} + lastport = p + lastline = "" + } + + l := []Port{} + for p := range m { + l = append(l, p) + } + sort.Slice(l, func(i, j int) bool { + return (&l[i]).lessThan(&l[j]) + }) + + return l +} + +func listPortsNetstat(args string) (List, error) { + exe, err := exec.LookPath("netstat") + if err != nil { + return nil, fmt.Errorf("netstat: lookup: %v", err) + } + c := exec.Cmd{ + Path: exe, + Args: []string{exe, args}, + } + output, err := c.Output() + if err != nil { + xe, ok := err.(*exec.ExitError) + stderr := "" + if ok { + stderr = strings.TrimSpace(string(xe.Stderr)) + } + return nil, fmt.Errorf("netstat: %v (%q)", err, stderr) + } + + return parsePortsNetstat(string(output)), nil +} diff --git a/portlist/netstat_test.go b/portlist/netstat_test.go new file mode 100644 index 000000000..e39909e7f --- /dev/null +++ b/portlist/netstat_test.go @@ -0,0 +1,89 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package portlist + +import ( + "fmt" + "testing" +) + +func TestParsePort(t *testing.T) { + type InOut struct { + in string + expect int + } + tests := []InOut{ + InOut{"1.2.3.4:5678", 5678}, + InOut{"0.0.0.0.999", 999}, + InOut{"1.2.3.4:*", 0}, + InOut{"5.5.5.5:0", 0}, + InOut{"[1::2]:5", 5}, + InOut{"[1::2].5", 5}, + InOut{"gibberish", -1}, + } + + for _, io := range tests { + got := parsePort(io.in) + if got != io.expect { + t.Fatalf("input:%#v expect:%v got:%v\n", io.in, io.expect, got) + } + } +} + +var netstat_output = ` +// linux +tcp 0 0 0.0.0.0:22 0.0.0.0:* LISTEN +udp 0 0 0.0.0.0:5353 0.0.0.0:* +udp6 0 0 :::5353 :::* +udp6 0 0 :::5354 :::* + +// macOS +tcp4 0 0 *.23 *.* LISTEN +tcp6 0 0 *.24 *.* LISTEN +udp6 0 0 *.5453 *.* +udp4 0 0 *.5553 *.* + +// Windows 10 + Proto Local Address Foreign Address State + TCP 0.0.0.0:32 0.0.0.0:0 LISTENING + [sshd.exe] + UDP 0.0.0.0:5050 *:* + CDPSvc + [svchost.exe] + UDP 0.0.0.0:53 *:* + [chrome.exe] + UDP 10.0.1.43:9353 *:* + [iTunes.exe] + UDP [::]:53 *:* + UDP [::]:53 *:* + [funball.exe] +` + +func TestParsePortsNetstat(t *testing.T) { + expect := List{ + Port{"tcp", 22, "", ""}, + Port{"tcp", 23, "", ""}, + Port{"tcp", 24, "", ""}, + Port{"tcp", 32, "", "sshd"}, + Port{"udp", 53, "", "chrome"}, + Port{"udp", 53, "", "funball"}, + Port{"udp", 5050, "", "CDPSvc"}, + Port{"udp", 5353, "", ""}, + Port{"udp", 5354, "", ""}, + Port{"udp", 5453, "", ""}, + Port{"udp", 5553, "", ""}, + Port{"udp", 9353, "", "iTunes"}, + } + + pl := parsePortsNetstat(netstat_output) + fmt.Printf("--- expect:\n%v\n", expect) + fmt.Printf("--- got:\n%v\n", pl) + for i := range pl { + if expect[i] != pl[i] { + t.Fatalf("row#%d\n expect=%v\n got=%v\n", + i, expect[i], pl[i]) + } + } +} diff --git a/portlist/poller.go b/portlist/poller.go new file mode 100644 index 000000000..d6fd7036e --- /dev/null +++ b/portlist/poller.go @@ -0,0 +1,59 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package portlist + +import ( + "time" +) + +type Poller struct { + C chan List // new data when it arrives; closed when done + quitCh chan struct{} // close this to force exit + Err error // last returned error code, if any + prev List // most recent data +} + +func NewPoller() (*Poller, error) { + p := &Poller{ + C: make(chan List), + quitCh: make(chan struct{}), + } + // Do one initial poll synchronously, so the caller can react + // to any obvious errors. + p.prev, p.Err = GetList(nil) + return p, p.Err +} + +func (p *Poller) Close() { + close(p.quitCh) + <-p.C +} + +// Poll periodically. Run this in a goroutine if you want. +func (p *Poller) Run() error { + defer close(p.C) + tick := time.NewTicker(POLL_SECONDS * time.Second) + defer tick.Stop() + + // Send out the pre-generated initial value + p.C <- p.prev + + for { + select { + case <-tick.C: + pl, err := GetList(p.prev) + if err != nil { + p.Err = err + return p.Err + } + if !pl.SameInodes(p.prev) { + p.prev = pl + p.C <- pl + } + case <-p.quitCh: + return nil + } + } +} diff --git a/portlist/portlist.go b/portlist/portlist.go new file mode 100644 index 000000000..99e04891a --- /dev/null +++ b/portlist/portlist.go @@ -0,0 +1,87 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package portlist + +import ( + "fmt" + "strings" +) + +type Port struct { + Proto string + Port uint16 + inode string + Process string +} + +type List []Port + +var protos = []string{"tcp", "udp"} + +func (a *Port) lessThan(b *Port) bool { + if a.Port < b.Port { + return true + } else if a.Port > b.Port { + return false + } + + if a.Proto < b.Proto { + return true + } else if a.Proto > b.Proto { + return false + } + + if a.inode < b.inode { + return true + } else if a.inode > b.inode { + return false + } + + if a.Process < b.Process { + return true + } else if a.Process > b.Process { + return false + } + return false +} + +func (a List) SameInodes(b List) bool { + if a == nil || b == nil || len(a) != len(b) { + return false + } + for i := range a { + if a[i].Proto != b[i].Proto || + a[i].Port != b[i].Port || + a[i].inode != b[i].inode { + return false + } + } + return true +} + +func (pl List) String() string { + out := []string{} + for _, v := range pl { + out = append(out, fmt.Sprintf("%-3s %5d %-17s %#v", + v.Proto, v.Port, v.inode, v.Process)) + } + return strings.Join(out, "\n") +} + +func GetList(prev List) (List, error) { + pl, err := listPorts() + if err != nil { + return nil, fmt.Errorf("listPorts: %s", err) + } + if pl.SameInodes(prev) { + // Nothing changed, skip inode lookup + return prev, nil + } + pl, err = addProcesses(pl) + if err != nil { + return nil, fmt.Errorf("addProcesses: %s", err) + } + return pl, nil +} diff --git a/portlist/portlist_darwin.go b/portlist/portlist_darwin.go new file mode 100644 index 000000000..767f4f7e4 --- /dev/null +++ b/portlist/portlist_darwin.go @@ -0,0 +1,99 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build !linux,!windows + +package portlist + +import ( + "bufio" + "bytes" + "fmt" + "log" + "os" + "strings" + + exec "tailscale.com/tempfork/osexec" +) + +// We have to run netstat, which is a bit expensive, so don't do it too often. +const POLL_SECONDS = 5 + +func listPorts() (List, error) { + return listPortsNetstat("-na") +} + +// In theory, lsof could replace the function of both listPorts() and +// addProcesses(), since it provides a superset of the netstat output. +// However, "netstat -na" runs ~100x faster than lsof on my machine, so +// we should do it only if the list of open ports has actually changed. +// +// TODO(apenwarr): this fails in a macOS sandbox (ie. our usual case). +// We might as well just delete this code if we can't find a solution. +func addProcesses(pl []Port) ([]Port, error) { + exe, err := exec.LookPath("lsof") + if err != nil { + return nil, fmt.Errorf("lsof: lookup: %v", err) + } + c := exec.Cmd{ + Path: exe, + Args: []string{exe, "-F", "-n", "-P", "-O", "-S2", "-T", "-i4", "-i6"}, + } + output, err := c.Output() + if err != nil { + xe, ok := err.(*exec.ExitError) + stderr := "" + if ok { + stderr = strings.TrimSpace(string(xe.Stderr)) + } + // fails when run in a macOS sandbox, so make this non-fatal. 
+ log.Printf("portlist: lsof: %v (%q)\n", err, stderr) + return pl, nil + } + + type ProtoPort struct { + proto string + port uint16 + } + m := map[ProtoPort]*Port{} + for i := range pl { + pp := ProtoPort{pl[i].Proto, pl[i].Port} + m[pp] = &pl[i] + } + + r := bytes.NewReader(output) + scanner := bufio.NewScanner(r) + + var cmd, proto string + for scanner.Scan() { + line := scanner.Text() + if line[0] == 'p' { + // starting a new process + cmd = "" + proto = "" + } else if line[0] == 'c' { + cmd = line[1:len(line)] + } else if line[0] == 'P' { + proto = strings.ToLower(line[1:len(line)]) + } else if line[0] == 'n' { + rest := line[1:len(line)] + i := strings.Index(rest, "->") + if i < 0 { + // a listening port + port := parsePort(rest) + if port > 0 { + pp := ProtoPort{proto, uint16(port)} + p := m[pp] + if p != nil { + p.Process = cmd + } else { + fmt.Fprintf(os.Stderr, "weird: missing %v\n", pp) + } + } + } + } + } + + return pl, nil +} diff --git a/portlist/portlist_linux.go b/portlist/portlist_linux.go new file mode 100644 index 000000000..53e03ee56 --- /dev/null +++ b/portlist/portlist_linux.go @@ -0,0 +1,155 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package portlist + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "os" + "sort" + "strconv" + "strings" +) + +// Reading the sockfiles on Linux is very fast, so we can do it often. +const POLL_SECONDS = 1 + +// TODO(apenwarr): Include IPv6 ports eventually. +// Right now we don't route IPv6 anyway so it's better to exclude them. 
+var sockfiles = []string{"/proc/net/tcp", "/proc/net/udp"} + +func listPorts() (List, error) { + l := []Port{} + + for pi, fname := range sockfiles { + proto := protos[pi] + + f, err := os.Open(fname) + if err != nil { + return nil, fmt.Errorf("%s: %s", fname, err) + } + defer f.Close() + r := bufio.NewReader(f) + + // skip header row + _, err = r.ReadString('\n') + if err != nil { + return nil, err + } + + for err == nil { + line, err := r.ReadString('\n') + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + + // sl local rem ... inode + words := strings.Fields(line) + local := words[1] + rem := words[2] + inode := words[9] + + if rem != "00000000:0000" { + // not a "listener" port + continue + } + + portv, err := strconv.ParseUint(local[9:], 16, 16) + if err != nil { + return nil, fmt.Errorf("%#v: %s", local[9:], err) + } + inodev := fmt.Sprintf("socket:[%s]", inode) + l = append(l, Port{ + Proto: proto, + Port: uint16(portv), + inode: inodev, + }) + } + } + + sort.Slice(l, func(i, j int) bool { + return (&l[i]).lessThan(&l[j]) + }) + + return l, nil +} + +func addProcesses(pl []Port) ([]Port, error) { + pm := map[string]*Port{} + for k := range pl { + pm[pl[k].inode] = &pl[k] + } + + pdir, err := os.Open("/proc") + if err != nil { + return nil, fmt.Errorf("/proc: %s", err) + } + defer pdir.Close() + + for { + pids, err := pdir.Readdirnames(100) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("/proc: %s", err) + } + + for _, pid := range pids { + _, err := strconv.ParseInt(pid, 10, 64) + if err != nil { + // not a pid, ignore it. + // /proc has lots of non-pid stuff in it. + continue + } + fddir, err := os.Open(fmt.Sprintf("/proc/%s/fd", pid)) + if err != nil { + // Can't open fd list for this pid. Maybe + // don't have access. Ignore it. 
+ continue + } + defer fddir.Close() + + for { + fds, err := fddir.Readdirnames(100) + if err == io.EOF { + break + } + if err != nil { + return nil, fmt.Errorf("readdir: %s", err) + } + for _, fd := range fds { + target, err := os.Readlink(fmt.Sprintf("/proc/%s/fd/%s", pid, fd)) + if err != nil { + // Not a symlink or no permission. + // Skip it. + continue + } + + // TODO(apenwarr): use /proc/*/cmdline instead of /comm? + // Unsure right now whether users will want the extra detail + // or not. + pe := pm[target] + if pe != nil { + comm, err := ioutil.ReadFile(fmt.Sprintf("/proc/%s/comm", pid)) + if err != nil { + // Usually shouldn't happen. One possibility is + // the process has gone away, so let's skip it. + continue + } + pe.Process = strings.TrimSpace(string(comm)) + } + } + } + } + } + + return pl, nil +} diff --git a/portlist/portlist_other.go b/portlist/portlist_other.go new file mode 100644 index 000000000..ffcbecb0b --- /dev/null +++ b/portlist/portlist_other.go @@ -0,0 +1,20 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !linux,!windows,!darwin + +package portlist + +// We have to run netstat, which is a bit expensive, so don't do it too often. +const POLL_SECONDS = 5 + +func listPorts() (List, error) { + return listPortsNetstat("-na") +} + +func addProcesses(pl []Port) ([]Port, error) { + // Generic version has no way to get process mappings. + // This has to be OS-specific. + return pl, nil +} diff --git a/portlist/portlist_windows.go b/portlist/portlist_windows.go new file mode 100644 index 000000000..8e2885d25 --- /dev/null +++ b/portlist/portlist_windows.go @@ -0,0 +1,16 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
// Bucket is a token bucket. It starts full with Burst tokens on first use
// and gains one token every FillInterval, never holding more than Burst.
//
// FillInterval and Burst must be set before the first call to TryGet. A
// Bucket must not be copied after first use because it contains a mutex.
type Bucket struct {
	mu           sync.Mutex
	FillInterval time.Duration // time between one-token refills
	Burst        int           // maximum number of tokens held
	v            int           // tokens currently available
	quitCh       chan struct{} // closed by Close to stop the refill goroutine
	started      bool          // startLocked has run
	closed       bool          // Close has been called
}

// startLocked fills the bucket and, unless the bucket was already closed,
// starts the goroutine that refills it. b.mu must be held.
func (b *Bucket) startLocked() {
	b.v = b.Burst
	b.quitCh = make(chan struct{})
	b.started = true

	if b.closed {
		// Close was called before first use. Nothing could ever stop a
		// refill goroutine started now, so the bucket simply never refills.
		return
	}

	t := time.NewTicker(b.FillInterval)
	quit := b.quitCh // captured so the goroutine never reads b's fields unlocked
	go func() {
		defer t.Stop() // release the ticker once the bucket is closed
		for {
			select {
			case <-quit:
				return
			case <-t.C:
				b.tick()
			}
		}
	}()
}

// tick adds one token to the bucket unless it already holds Burst tokens.
func (b *Bucket) tick() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.v < b.Burst {
		b.v++
	}
}

// Close stops the periodic refill. It does not block and may be called more
// than once, and before or after the first TryGet. Tokens still in the
// bucket remain available to TryGet.
func (b *Bucket) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	if b.closed {
		return
	}
	b.closed = true
	if b.started {
		// Closing (rather than sending on) the channel cannot block even
		// if the refill goroutine is busy.
		close(b.quitCh)
	}
}

// TryGet takes one token if any is available. It returns the number of
// tokens that were available before the call, or 0 if the bucket was empty
// and nothing was taken. The first call initializes the bucket.
func (b *Bucket) TryGet() int {
	b.mu.Lock()
	defer b.mu.Unlock()

	if !b.started {
		b.startLocked()
	}
	if b.v > 0 {
		b.v--
		return b.v + 1
	}
	return 0
}
+ +package ratelimit + +import ( + "testing" + "time" +) + +func TestBucket(t *testing.T) { + b := Bucket{ + FillInterval: time.Second, + Burst: 3, + } + expect := []int{3, 2, 1, 0, 0} + for i, want := range expect { + got := b.TryGet() + if want != got { + t.Errorf("#%d want=%d got=%d\n", i, want, got) + } + } + b.tick() + if want, got := 1, b.TryGet(); want != got { + t.Errorf("after tick: want=%d got=%d\n", want, got) + } +} diff --git a/safesocket/basic_test.go b/safesocket/basic_test.go new file mode 100644 index 000000000..8675d3379 --- /dev/null +++ b/safesocket/basic_test.go @@ -0,0 +1,63 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package safesocket + +import ( + "fmt" + "testing" +) + +func TestBasics(t *testing.T) { + fmt.Printf("listening2...\n") + l, port, err := Listen("COOKIE", "Tailscale", "test", 0) + if err != nil { + t.Fatal(err) + } + fmt.Printf("listened.\n") + + go func() { + fmt.Printf("accepting...\n") + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + fmt.Printf("accepted.\n") + l.Close() + s.Write([]byte("hello")) + fmt.Printf("server wrote.\n") + + b := make([]byte, 1024) + n, err := s.Read(b) + if err != nil { + t.Fatal(err) + } + fmt.Printf("server read %d bytes.\n", n) + if string(b[:n]) != "world" { + t.Fatalf("got %#v, expected %#v\n", string(b[:n]), "world") + } + s.Close() + }() + + fmt.Printf("connecting...\n") + c, err := Connect("COOKIE", "Tailscale", "test", port) + if err != nil { + t.Fatal(err) + } + fmt.Printf("connected.\n") + c.Write([]byte("world")) + fmt.Printf("client wrote.\n") + + b := make([]byte, 1024) + n, err := c.Read(b) + if err != nil { + t.Fatal(err) + } + fmt.Printf("client read %d bytes.\n", n) + if string(b[:n]) != "hello" { + t.Fatalf("got %#v, expected %#v\n", string(b[:n]), "hello") + } + + c.Close() +} diff --git a/safesocket/pipe_windows.go 
b/safesocket/pipe_windows.go new file mode 100644 index 000000000..49dd3fed7 --- /dev/null +++ b/safesocket/pipe_windows.go @@ -0,0 +1,59 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package safesocket + +import ( + "context" + "fmt" + "net" + "syscall" +) + +func path(vendor, name string, port uint16) string { + return fmt.Sprintf("127.0.0.1:%v", port) +} + +func ConnCloseRead(c net.Conn) error { + return c.(*net.TCPConn).CloseRead() +} + +func ConnCloseWrite(c net.Conn) error { + return c.(*net.TCPConn).CloseWrite() +} + +// TODO(apenwarr): handle magic cookie auth +func Connect(cookie, vendor, name string, port uint16) (net.Conn, error) { + p := path(vendor, name, port) + pipe, err := net.Dial("tcp", p) + if err != nil { + return nil, err + } + return pipe, err +} + +func setFlags(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + syscall.SetsockoptInt(syscall.Handle(fd), syscall.SOL_SOCKET, + syscall.SO_REUSEADDR, 1) + }) +} + +// TODO(apenwarr): use named pipes instead of sockets? +// I tried to use winio.ListenPipe() here, but that code is a disaster, +// built on top of an API that's a disaster. So for now we'll hack it by +// just always using a TCP session on a fixed port on localhost. As a +// result, on Windows we ignore the vendor and name strings. 
+// TODO(apenwarr): handle magic cookie auth +func Listen(cookie, vendor, name string, port uint16) (net.Listener, uint16, error) { + lc := net.ListenConfig{ + Control: setFlags, + } + p := path(vendor, name, port) + pipe, err := lc.Listen(context.Background(), "tcp", p) + if err != nil { + return nil, 0, err + } + return pipe, uint16(pipe.Addr().(*net.TCPAddr).Port), err +} diff --git a/safesocket/unixsocket.go b/safesocket/unixsocket.go new file mode 100644 index 000000000..8093a64e7 --- /dev/null +++ b/safesocket/unixsocket.go @@ -0,0 +1,61 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package safesocket + +import ( + "fmt" + "net" + "os" +) + +func path(vendor, name string) string { + return fmt.Sprintf("%s-%s.sock", vendor, name) +} + +func ConnCloseRead(c net.Conn) error { + return c.(*net.UnixConn).CloseRead() +} + +func ConnCloseWrite(c net.Conn) error { + return c.(*net.UnixConn).CloseWrite() +} + +// TODO(apenwarr): handle magic cookie auth +func Connect(cookie, vendor, name string, port uint16) (net.Conn, error) { + pipe, err := net.Dial("unix", path(vendor, name)) + if err != nil { + return nil, err + } + return pipe, err +} + +// TODO(apenwarr): handle magic cookie auth +func Listen(cookie, vendor, name string, port uint16) (net.Listener, uint16, error) { + // Unix sockets hang around in the filesystem even after nobody + // is listening on them. (Which is really unfortunate but long- + // entrenched semantics.) Try connecting first; if it works, then + // the socket is still live, so let's not replace it. If it doesn't + // work, then replace it. + // + // Note that there's a race condition between these two steps. 
var (
	bindingRequest = []byte{0x00, 0x01}
	magicCookie    = []byte{0x21, 0x12, 0xa4, 0x42}
	attrSoftware   = append([]byte{
		0x80, 0x22, // software header
		0x00, byte(len("tailnode")), // attr length
	}, "tailnode"...)
	lenMsg = byte(len(attrSoftware) + lenFingerprint) // number of bytes following header
)

const lenFingerprint = 8 // 2-byte header + 2-byte length + 4-byte crc32

// Request generates a binding request STUN packet.
// The transaction ID, tID, should be a random sequence of bytes.
func Request(tID [12]byte) []byte {
	// STUN header, RFC5389 Section 6.
	b := make([]byte, 0, 20+len(attrSoftware)+lenFingerprint)
	b = append(b, bindingRequest...)
	b = append(b, 0x00, lenMsg)
	b = append(b, magicCookie...)
	b = append(b, tID[:]...)

	// Attribute SOFTWARE, RFC5389 Section 15.10.
	b = append(b, attrSoftware...)

	// Attribute FINGERPRINT, RFC5389 Section 15.5.
	fp := crc32.ChecksumIEEE(b) ^ 0x5354554e
	b = append(b, 0x80, 0x28) // fingerprint header
	b = append(b, 0x00, 0x04) // fingerprint length
	b = append(b,
		byte(fp>>24),
		byte(fp>>16),
		byte(fp>>8),
		byte(fp),
	)

	return b
}

var (
	ErrNotSTUN            = errors.New("response is not a STUN packet")
	ErrNotSuccessResponse = errors.New("STUN response error")
	ErrMalformedAttrs     = errors.New("STUN response has malformed attributes")
)

// ParseResponse parses a successful binding response STUN packet.
//
// The address is taken from the XOR-MAPPED-ADDRESS attribute, falling back
// to MAPPED-ADDRESS if the former is absent. An IPv4 result (4 bytes) is
// preferred over an IPv6 one (16 bytes). The returned addr never aliases b.
//
// It returns ErrNotSTUN if b is not a STUN message, ErrNotSuccessResponse if
// it is not a binding success response, and ErrMalformedAttrs if the
// attributes are truncated, malformed, or contain no address. The
// transaction ID is returned whenever b is a STUN message, even on error.
func ParseResponse(b []byte) (tID [12]byte, addr []byte, port uint16, err error) {
	const (
		typeMappedAddress    = 0x0001
		typeXorMappedAddress = 0x0020
		// This alternative attribute type is not
		// mentioned in the RFC, but the shift into
		// the "comprehension-optional" range seems
		// like an easy mistake for a server to make.
		// And servers appear to send it.
		typeXorMappedAddressAlt = 0x8020
	)

	if !Is(b) {
		return tID, nil, 0, ErrNotSTUN
	}
	copy(tID[:], b[8:20])
	if b[0] != 0x01 || b[1] != 0x01 {
		return tID, nil, 0, ErrNotSuccessResponse
	}
	attrsLen := int(b[2])<<8 | int(b[3])
	b = b[20:] // remove STUN header
	if attrsLen > len(b) {
		return tID, nil, 0, ErrMalformedAttrs
	}
	b = b[:attrsLen] // trim trailing packet bytes

	var addr6, fallbackAddr, fallbackAddr6 []byte
	var port6, fallbackPort, fallbackPort6 uint16

	// Read through the attributes.
	// Take the addr+port reported by XOR-MAPPED-ADDRESS
	// as the canonical value. If the attribute is not
	// present but the STUN server responds with
	// MAPPED-ADDRESS we fall back to it.
	for len(b) > 0 {
		if len(b) < 4 {
			return tID, nil, 0, ErrMalformedAttrs
		}
		attrType := uint16(b[0])<<8 | uint16(b[1])
		attrLen := int(b[2])<<8 | int(b[3])
		// Attribute values are padded up to the next multiple of 4
		// bytes (RFC5389 Section 15); the padding is not counted in
		// attrLen.
		attrLenPad := (4 - attrLen%4) % 4
		if attrLen+attrLenPad > len(b)-4 {
			return tID, nil, 0, ErrMalformedAttrs
		}
		b = b[4:]

		switch attrType {
		case typeXorMappedAddress, typeXorMappedAddressAlt:
			a, p, err := xorMappedAddress(tID, b[:attrLen])
			if err != nil {
				return tID, nil, 0, ErrMalformedAttrs
			}
			if len(a) == 16 {
				addr6, port6 = a, p
			} else {
				addr, port = a, p
			}
		case typeMappedAddress:
			a, p, err := mappedAddress(b[:attrLen])
			if err != nil {
				return tID, nil, 0, ErrMalformedAttrs
			}
			if len(a) == 16 {
				fallbackAddr6, fallbackPort6 = a, p
			} else {
				fallbackAddr, fallbackPort = a, p
			}
		}

		b = b[attrLen+attrLenPad:]
	}

	if addr != nil {
		return tID, addr, port, nil
	}
	if fallbackAddr != nil {
		return tID, append([]byte{}, fallbackAddr...), fallbackPort, nil
	}
	if addr6 != nil {
		return tID, addr6, port6, nil
	}
	if fallbackAddr6 != nil {
		return tID, append([]byte{}, fallbackAddr6...), fallbackPort6, nil
	}
	return tID, nil, 0, ErrMalformedAttrs
}

// xorMappedAddress decodes the value of an XOR-MAPPED-ADDRESS attribute
// (RFC5389 Section 15.2). The returned addr is a newly allocated 4-byte
// (IPv4) or 16-byte (IPv6) slice. ErrMalformedAttrs is returned if b is too
// short for its address family or the family is unknown.
func xorMappedAddress(tID [12]byte, b []byte) (addr []byte, port uint16, err error) {
	if len(b) < 4 {
		return nil, 0, ErrMalformedAttrs
	}
	xorPort := uint16(b[2])<<8 | uint16(b[3])
	port = xorPort ^ 0x2112 // first half of magicCookie

	var addrLen int
	switch ipFamily := b[1]; ipFamily { // RFC5389 Section 15.1
	case 0x01: // IPv4
		addrLen = 4
	case 0x02: // IPv6
		addrLen = 16
	default:
		return nil, 0, ErrMalformedAttrs
	}
	// Check the length before slicing so a short attribute is reported
	// as malformed instead of panicking.
	if len(b) < 4+addrLen {
		return nil, 0, ErrMalformedAttrs
	}

	// The address is XORed with the magic cookie; for IPv6 the remaining
	// 12 bytes are XORed with the transaction ID.
	xorAddr := b[4 : 4+addrLen]
	addr = make([]byte, addrLen)
	for i := range xorAddr {
		if i < len(magicCookie) {
			addr[i] = xorAddr[i] ^ magicCookie[i]
		} else {
			addr[i] = xorAddr[i] ^ tID[i-len(magicCookie)]
		}
	}
	return addr, port, nil
}

// mappedAddress decodes the value of a MAPPED-ADDRESS attribute (RFC5389
// Section 15.1). The returned addr is a 4-byte or 16-byte sub-slice of b,
// not a copy. ErrMalformedAttrs is returned if b is too short for its
// address family or the family is unknown.
func mappedAddress(b []byte) (addr []byte, port uint16, err error) {
	if len(b) < 4 {
		return nil, 0, ErrMalformedAttrs
	}
	port = uint16(b[2])<<8 | uint16(b[3])

	var addrLen int
	switch ipFamily := b[1]; ipFamily { // RFC5389 Section 15.1
	case 0x01: // IPv4
		addrLen = 4
	case 0x02: // IPv6
		addrLen = 16
	default:
		return nil, 0, ErrMalformedAttrs
	}
	if len(b) < 4+addrLen {
		return nil, 0, ErrMalformedAttrs
	}
	return b[4 : 4+addrLen], port, nil
}

// Is reports whether b is a STUN message: at least 20 bytes long with the
// magic cookie in bytes 4 through 7.
func Is(b []byte) bool {
	if len(b) < 20 {
		return false // every STUN message must have a 20-byte header
	}
	// TODO RFC5389 suggests checking the first 2 bits of the header are zero.
	return bytes.Equal(b[4:8], magicCookie)
}
0xde, 0x7d, 0x7c, 0x75, + 0x92, 0x3c, 0xe2, 0x71, + }, + wantAddr: []byte{72, 69, 33, 45}, + wantPort: uint16(59029), + }, + { + name: "stun.sipgate.net:10000", + data: []byte{ + 0x01, 0x01, 0x00, 0x44, 0x21, 0x12, 0xa4, 0x42, + 0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e, + 0xae, 0xad, 0x64, 0x44, 0x00, 0x01, 0x00, 0x08, + 0x00, 0x01, 0xe4, 0xab, 0x48, 0x45, 0x21, 0x2d, + 0x00, 0x04, 0x00, 0x08, 0x00, 0x01, 0x27, 0x10, + 0xd9, 0x0a, 0x44, 0x98, 0x00, 0x05, 0x00, 0x08, + 0x00, 0x01, 0x27, 0x11, 0xd9, 0x74, 0x7a, 0x8a, + 0x80, 0x20, 0x00, 0x08, 0x00, 0x01, 0xc5, 0xb9, + 0x69, 0x57, 0x85, 0x6f, 0x80, 0x22, 0x00, 0x10, + 0x56, 0x6f, 0x76, 0x69, 0x64, 0x61, 0x2e, 0x6f, + 0x72, 0x67, 0x20, 0x30, 0x2e, 0x39, 0x36, 0x00, + }, + wantTID: []byte{ + 0x48, 0x2e, 0xb6, 0x47, 0x15, 0xe8, 0xb2, 0x8e, + 0xae, 0xad, 0x64, 0x44, + }, + wantAddr: []byte{72, 69, 33, 45}, + wantPort: uint16(58539), + }, + { + name: "stun.powervoip.com:3478", + data: []byte{ + 0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42, + 0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60, + 0x9d, 0x1d, 0xea, 0xa6, 0x00, 0x01, 0x00, 0x08, + 0x00, 0x01, 0xe9, 0xd3, 0x48, 0x45, 0x21, 0x2d, + 0x00, 0x04, 0x00, 0x08, 0x00, 0x01, 0x0d, 0x96, + 0x4d, 0x48, 0xa9, 0xd4, 0x00, 0x05, 0x00, 0x08, + 0x00, 0x01, 0x0d, 0x97, 0x4d, 0x48, 0xa9, 0xd5, + }, + wantTID: []byte{ + 0x7e, 0x57, 0x96, 0x68, 0x29, 0xf4, 0x44, 0x60, + 0x9d, 0x1d, 0xea, 0xa6, + }, + wantAddr: []byte{72, 69, 33, 45}, + wantPort: uint16(59859), + }, + { + name: "in-process pion server", + data: []byte{ + 0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42, + 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 0x7c, + 0x4f, 0x3e, 0x30, 0x8e, 0x80, 0x22, 0x00, 0x0a, + 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x00, 0x00, 0x00, 0x20, 0x00, 0x08, + 0x00, 0x01, 0xce, 0x66, 0x5e, 0x12, 0xa4, 0x43, + 0x80, 0x28, 0x00, 0x04, 0xb6, 0x99, 0xbb, 0x02, + 0x01, 0x01, 0x00, 0x24, 0x21, 0x12, 0xa4, 0x42, + }, + wantTID: []byte{ + 0xeb, 0xc2, 0xd3, 0x6e, 0xf4, 0x71, 0x21, 
0x7c, + 0x4f, 0x3e, 0x30, 0x8e, + }, + wantAddr: []byte{127, 0, 0, 1}, + wantPort: uint16(61300), + }, +} + +func TestParseResponse(t *testing.T) { + subtest := func(t *testing.T, i int) { + test := responseTests[i] + tID, addr, port, err := stun.ParseResponse(test.data) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(tID[:], test.wantTID) { + t.Errorf("tid=%v, want %v", tID[:], test.wantTID) + } + if !bytes.Equal(addr, test.wantAddr) { + t.Errorf("addr=%v, want %v", addr, test.wantAddr) + } + if port != test.wantPort { + t.Errorf("port=%d, want %d", port, test.wantPort) + } + } + for i, test := range responseTests { + t.Run(test.name, func(t *testing.T) { + subtest(t, i) + }) + } +} diff --git a/stunner/stunner.go b/stunner/stunner.go new file mode 100644 index 000000000..a236b0260 --- /dev/null +++ b/stunner/stunner.go @@ -0,0 +1,197 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stunner + +import ( + "bytes" + "context" + "crypto/rand" + "fmt" + "log" + "net" + "strconv" + "sync" + "time" + + "tailscale.com/stun" +) + +// Stunner sends a STUN request to several servers and handles a response. +// +// It is designed to used on a connection owned by other code and so does +// not directly reference a net.Conn of any sort. Instead, the user should +// provide Send function to send packets, and call Receive when a new +// STUN response is received. +// +// In response, a Stunner will call Endpoint with any endpoints determined +// for the connection. (An endpoint may be reported multiple times if +// multiple servers are provided.) 
+type Stunner struct { + Send func([]byte, net.Addr) (int, error) // sends a packet + Endpoint func(endpoint string) // reports an endpoint + Servers []string // STUN servers to contact + Resolver *net.Resolver + Logf func(format string, args ...interface{}) + + sessions map[string]*session + tIDs map[string][][12]byte +} + +type session struct { + replied chan struct{} // closed when server responds + tIDs [][12]byte // transaction IDs sent to a server +} + +// Receive delivers a STUN packet to the stunner. +func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { + if !stun.Is(p) { + log.Println("stunner: received non-STUN packet") + return + } + + responseTID, addr, port, err := stun.ParseResponse(p) + if err != nil { + log.Printf("stunner: received bad STUN response: %v", err) + return + } + + // Accept any of the tIDs from any of the active sessions. + for server, session := range s.sessions { + for _, tID := range session.tIDs { + if bytes.Equal(tID[:], responseTID[:]) { + select { + case <-session.replied: + return // already got a reply from this server + default: + } + close(session.replied) + + // TODO(crawshaw): use different endpoints returned from + // different STUN servers to detect NAT types. + portStr := fmt.Sprintf("%d", port) + host := net.JoinHostPort(net.IP(addr).String(), portStr) + if s.Logf != nil { + s.Logf("STUN server %s reports public endpoint %s", server, host) + } + s.Endpoint(host) + return + } + } + } + log.Printf("stunner: received STUN packet for unknown transaction: %x", responseTID) +} + +// Run starts a Stunner and blocks until all servers either respond +// or are tried multiple times and timeout. +func (s *Stunner) Run(ctx context.Context) error { + if s.Resolver == nil { + s.Resolver = net.DefaultResolver + } + for _, server := range s.Servers { + // Generate the transaction IDs for this session. 
+ tIDs := make([][12]byte, len(retryDurations)) + for i := range tIDs { + if _, err := rand.Read(tIDs[i][:]); err != nil { + return fmt.Errorf("stunner: rand failed: %v", err) + } + } + if s.sessions == nil { + s.sessions = make(map[string]*session) + } + s.sessions[server] = &session{ + replied: make(chan struct{}), + tIDs: tIDs, + } + } + // after this point, the s.sessions map is read-only + + var wg sync.WaitGroup + for _, server := range s.Servers { + wg.Add(1) + go func(server string) { + defer wg.Done() + s.runServer(ctx, server) + }(server) + } + wg.Wait() + + return nil +} + +func (s *Stunner) runServer(ctx context.Context, server string) { // sends up to len(retryDurations) requests to server, one transaction ID per attempt + session := s.sessions[server] + + for i, d := range retryDurations { + attemptCtx, cancel := context.WithTimeout(ctx, d) // per-attempt deadline; ctx remains the caller's context + if err := s.sendSTUN(attemptCtx, session.tIDs[i], server); err != nil && s.Logf != nil { + s.Logf("stunner: %s: %v", server, err) + } + + select { + case <-attemptCtx.Done(): + cancel() + if ctx.Err() != nil { + return // the caller's context is done: stop instead of running through the remaining attempts + } + case <-session.replied: + cancel() + if i > 0 && s.Logf != nil { + s.Logf("stunner: slow STUN response from %s: %d retries", server, i) + } + return + } + } + if s.Logf != nil { + s.Logf("stunner: no STUN response from %s", server) + } +} + +func (s *Stunner) sendSTUN(ctx context.Context, tID [12]byte, server string) error { + host, port, err := net.SplitHostPort(server) + if err != nil { + return err + } + addrPort, err := strconv.Atoi(port) + if err != nil { + return fmt.Errorf("port: %v", err) + } + if addrPort == 0 { + addrPort = 3478 + } + addr := &net.UDPAddr{Port: addrPort} + + ipAddrs, err := s.Resolver.LookupIPAddr(ctx, host) + if err != nil { + return fmt.Errorf("lookup ip addr: %v", err) + } + for _, ipAddr := range ipAddrs { + if ip4 := ipAddr.IP.To4(); ip4 != nil { + addr.IP = ip4 + addr.Zone = ipAddr.Zone + break + } + } + if addr.IP == nil { + return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) + } + + req := stun.Request(tID) + if _, err := s.Send(req, addr); err != nil { + 
return fmt.Errorf("Send: %v", err) + } + return nil +} + +var retryDurations = []time.Duration{ + 100 * time.Millisecond, + 100 * time.Millisecond, + 100 * time.Millisecond, + 200 * time.Millisecond, + 200 * time.Millisecond, + 400 * time.Millisecond, + 800 * time.Millisecond, + 1600 * time.Millisecond, + 3200 * time.Millisecond, +} diff --git a/stunner/stunner_test.go b/stunner/stunner_test.go new file mode 100644 index 000000000..839104d71 --- /dev/null +++ b/stunner/stunner_test.go @@ -0,0 +1,150 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stunner + +import ( + "context" + "errors" + "fmt" + "net" + "sort" + "testing" + "time" + + "gortc.io/stun" +) + +func TestStun(t *testing.T) { + conn1, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer conn1.Close() + conn2, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer conn2.Close() + stunServers := []string{ + conn1.LocalAddr().String(), conn2.LocalAddr().String(), + } + + epCh := make(chan string, 16) + + localConn, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + s := &Stunner{ + Send: localConn.WriteTo, + Endpoint: func(ep string) { epCh <- ep }, + Servers: stunServers, + } + + stun1Err := make(chan error) + go func() { + stun1Err <- startSTUN(conn1, s.Receive) + }() + stun2Err := make(chan error) + go func() { + stun2Err <- startSTUNDrop1(conn2, s.Receive) + }() + + errCh := make(chan error) + go func() { + errCh <- s.Run(context.Background()) + }() + + var eps []string + select { + case ep := <-epCh: + eps = append(eps, ep) + case <-time.After(100 * time.Millisecond): + t.Fatal("missing first endpoint response") + } + select { + case ep := <-epCh: + eps = append(eps, ep) + case <-time.After(500 * time.Millisecond): + t.Fatal("missing second endpoint 
response") + } + sort.Strings(eps) + if want := "1.2.3.4:1234"; eps[0] != want { + t.Errorf("eps[0]=%q, want %q", eps[0], want) + } + if want := "4.5.6.7:4567"; eps[1] != want { + t.Errorf("eps[1]=%q, want %q", eps[1], want) + } + + if err := <-errCh; err != nil { + t.Fatal(err) + } +} + +func startSTUNDrop1(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error { + if _, _, err := conn.ReadFrom(make([]byte, 1024)); err != nil { + return fmt.Errorf("first stun server read failed: %v", err) + } + req := new(stun.Message) + res := new(stun.Message) + + p := make([]byte, 1024) + n, addr, err := conn.ReadFrom(p) + if err != nil { + return err + } + p = p[:n] + if !stun.IsMessage(p) { + return errors.New("not a STUN message") + } + if _, err := req.Write(p); err != nil { + return err + } + mappedAddr := &stun.XORMappedAddress{ + IP: net.ParseIP("1.2.3.4"), + Port: 1234, + } + software := stun.NewSoftware("endpointer") + err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint) + if err != nil { + return err + } + writeTo(res.Raw, addr.(*net.UDPAddr)) + return nil +} + +func startSTUN(conn net.PacketConn, writeTo func([]byte, *net.UDPAddr)) error { + req := new(stun.Message) + res := new(stun.Message) + + p := make([]byte, 1024) + n, addr, err := conn.ReadFrom(p) + if err != nil { + return err + } + p = p[:n] + if !stun.IsMessage(p) { + return errors.New("not a STUN message") + } + if _, err := req.Write(p); err != nil { + return err + } + mappedAddr := &stun.XORMappedAddress{ + IP: net.ParseIP("4.5.6.7"), + Port: 4567, + } + software := stun.NewSoftware("endpointer") + err = res.Build(req, stun.BindingSuccess, software, mappedAddr, stun.Fingerprint) + if err != nil { + return err + } + writeTo(res.Raw, addr.(*net.UDPAddr)) + return nil +} + +// TODO: test retry timeout (overwrite the retryDurations) +// TODO: test canceling context passed to Run +// TODO: test sending bad packets diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go new 
file mode 100644 index 000000000..6ce555ad0 --- /dev/null +++ b/tailcfg/tailcfg.go @@ -0,0 +1,359 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tailcfg + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/tailscale/wireguard-go/wgcfg" + "golang.org/x/oauth2" + "tailscale.com/wgengine/filter" +) + +type ID int64 + +type UserID ID + +type LoginID ID + +type NodeID ID + +type GroupID ID + +type RoleID ID + +type CapabilityID ID + +type MachineKey [32]byte + +type NodeKey [32]byte + +type Group struct { + ID GroupID + Name string + Members []ID +} + +type Role struct { + ID RoleID + Name string + Capabilities []CapabilityID +} + +type CapType string + +const ( + CapRead = CapType("read") + CapWrite = CapType("write") +) + +type Capability struct { + ID CapabilityID + Type CapType + Val ID +} + +// User is an IPN user. +// +// A user can have multiple logins associated with it (e.g. gmail and github oauth). +// (Note: none of our UIs support this yet.) +// +// Some properties are inherited from the logins and can be overridden, such as +// display name and profile picture. +// +// Other properties must be the same for all logins associated with a user. +// In particular: domain. If a user has a "tailscale.io" domain login, they cannot +// have a general gmail address login associated with the user. +type User struct { + ID UserID + LoginName string `json:"-"` // not stored, filled from Login // TODO REMOVE + DisplayName string // if non-empty overrides Login field + ProfilePicURL string // if non-empty overrides Login field + Domain string + Logins []LoginID + Roles []RoleID + Created time.Time +} + +type Login struct { + ID LoginID + Provider string + LoginName string + DisplayName string + ProfilePicURL string + Domain string +} + +// A UserProfile is display-friendly data for a user. 
+// It includes the LoginName for display purposes but *not* the Provider. +// It also includes derived data from one of the user's logins. +type UserProfile struct { + ID UserID + LoginName string // for display purposes only (provider is not listed) + DisplayName string + ProfilePicURL string + Roles []RoleID +} + +type Node struct { + ID NodeID + Name string // DNS + User UserID + Key NodeKey + KeyExpiry time.Time + Machine MachineKey + Addresses []wgcfg.CIDR // IP addresses of this Node directly + AllowedIPs []wgcfg.CIDR // range of IP addresses to route to this node + Endpoints []string `json:",omitempty"` // IP+port (public via STUN, and local LANs) + Hostinfo Hostinfo + Created time.Time + LastSeen *time.Time `json:",omitempty"` + + MachineAuthorized bool // TODO(crawshaw): replace with MachineStatus + + // NOTE: any new fields containing pointers in this type + // require changes to Node.Copy. +} + +// Copy makes a deep copy of Node. +// The result aliases no memory with the original. +func (n *Node) Copy() (res *Node) { + res = new(Node) + *res = *n + + res.Addresses = append([]wgcfg.CIDR{}, res.Addresses...) + res.AllowedIPs = append([]wgcfg.CIDR{}, res.AllowedIPs...) + res.Endpoints = append([]string{}, res.Endpoints...) 
+ if res.LastSeen != nil { + lastSeen := *res.LastSeen + res.LastSeen = &lastSeen + } + res.Hostinfo = *res.Hostinfo.Copy() + return res +} + +type MachineStatus int + +const ( + MachineUnknown = MachineStatus(iota) + MachineUnauthorized // server has yet to approve + MachineAuthorized // server has approved + MachineInvalid // server has explicitly rejected this machine key +) + +func (m MachineStatus) MarshalText() ([]byte, error) { + return []byte(m.String()), nil +} + +func (m *MachineStatus) UnmarshalText(b []byte) error { // inverse of MarshalText; always returns nil + switch string(b) { + case "machine-unknown": + *m = MachineUnknown + case "machine-unauthorized": + *m = MachineUnauthorized + case "machine-authorized": + *m = MachineAuthorized + case "machine-invalid": + *m = MachineInvalid + default: // may be the numeric form String emits for other values, e.g. "machine-unknown(7)" + var val int + if _, err := fmt.Sscanf(string(b), "machine-unknown(%d)", &val); err == nil { // parsed: keep the numeric status so it round-trips + *m = MachineStatus(val) + } else { // unrecognized text degrades to MachineUnknown rather than failing + *m = MachineUnknown + } + } + return nil +} + +func (m MachineStatus) String() string { + switch m { + case MachineUnknown: + return "machine-unknown" + case MachineUnauthorized: + return "machine-unauthorized" + case MachineAuthorized: + return "machine-authorized" + case MachineInvalid: + return "machine-invalid" + default: + return fmt.Sprintf("machine-unknown(%d)", int(m)) + } +} + +type ServiceProto string + +const ( + TCP = ServiceProto("tcp") + UDP = ServiceProto("udp") +) + +type Service struct { + Proto ServiceProto // TCP or UDP + Port uint16 // port number service is listening on + Description string // text description of service + // TODO(apenwarr): allow advertising services on subnet IPs? + // TODO(apenwarr): add "tags" here for each service? + + // NOTE: any new fields containing pointers in this type + // require changes to Hostinfo.Copy. +} + +type Hostinfo struct { + // TODO(crawshaw): mark all these fields ",omitempty" when all the + // iOS apps are updated with the latest swift version of this struct. 
+ IPNVersion string // version number of this code + FrontendLogID string // logtail ID of frontend instance + BackendLogID string // logtail ID of backend instance + OS string // operating system the client runs on + Hostname string // name of the host the client runs on + RoutableIPs []wgcfg.CIDR `json:",omitempty"` // set of IP ranges this client can route + Services []Service `json:",omitempty"` // services advertised by this machine + + // NOTE: any new fields containing pointers in this type + // require changes to Hostinfo.Copy. +} + +// Copy makes a deep copy of Hostinfo. +// The result aliases no memory with the original. +func (hinfo *Hostinfo) Copy() (res *Hostinfo) { + res = new(Hostinfo) + *res = *hinfo + + res.RoutableIPs = append([]wgcfg.CIDR{}, res.RoutableIPs...) + res.Services = append([]Service{}, res.Services...) + return res +} + +type RegisterRequest struct { + Version int + NodeKey NodeKey + OldNodeKey NodeKey + Auth struct { + Provider string + LoginName string + // One of LoginName or Oauth2Token is set. 
+ Oauth2Token *oauth2.Token + } + Expiry time.Time // requested key expiry, server policy may override + Followup string // response waits until AuthURL is visited + Hostinfo Hostinfo +} + +type RegisterResponse struct { + User User + Login Login + NodeKeyExpired bool // if true, the NodeKey needs to be replaced + MachineAuthorized bool // TODO(crawshaw): move to using MachineStatus + AuthURL string // if set, authorization pending +} + +type MapRequest struct { + Version int // current version is 4 + Compress string // "zstd" or "" (no compression) + KeepAlive bool // server sends keep-alives + NodeKey NodeKey + Endpoints []string + Stream bool + Hostinfo Hostinfo +} + +type MapResponse struct { + KeepAlive bool // if set, all other fields are ignored + + // Networking + Node Node + Peers []Node + DNS []wgcfg.IP + SearchPaths []string + + // ACLs + Domain string + PacketFilter filter.Matches + UserProfiles []UserProfile + Roles []Role + // TODO: Groups []Group + // TODO: Capabilities []Capability +} + +func (k MachineKey) String() string { return fmt.Sprintf("mkey:%x", k[:]) } + +func (k MachineKey) MarshalText() ([]byte, error) { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "mkey:%x", k[:]) + return buf.Bytes(), nil +} + +func (k *MachineKey) UnmarshalText(text []byte) error { + s := string(text) + if !strings.HasPrefix(s, "mkey:") { // the text form must carry the prefix MarshalText writes + return errors.New(`MachineKey.UnmarshalText: missing prefix`) + } + s = strings.TrimPrefix(s, `mkey:`) + key, err := wgcfg.ParseHexKey(s) + if err != nil { + return fmt.Errorf("MachineKey.UnmarshalText: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (k NodeKey) String() string { return fmt.Sprintf("nodekey:%x", k[:]) } + +func (k NodeKey) AbbrevString() string { + pk := wgcfg.Key(k) + return pk.ShortString() +} + +func (k NodeKey) MarshalText() ([]byte, error) { + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "nodekey:%x", k[:]) + return buf.Bytes(), nil +} + +func (k *NodeKey) UnmarshalText(text []byte) error { + s := 
string(text) + if !strings.HasPrefix(s, "nodekey:") { // the text form must carry the prefix MarshalText writes + return errors.New(`NodeKey.UnmarshalText: missing prefix`) + } + s = strings.TrimPrefix(s, "nodekey:") + key, err := wgcfg.ParseHexKey(s) + if err != nil { + return fmt.Errorf("NodeKey.UnmarshalText: %v", err) + } + copy(k[:], key[:]) + return nil +} + +func (k *NodeKey) IsZero() bool { + z := NodeKey{} + return bytes.Equal(k[:], z[:]) +} + +func (id ID) String() string { return fmt.Sprintf("id:%x", int64(id)) } +func (id UserID) String() string { return fmt.Sprintf("userid:%x", int64(id)) } +func (id LoginID) String() string { return fmt.Sprintf("loginid:%x", int64(id)) } +func (id NodeID) String() string { return fmt.Sprintf("nodeid:%x", int64(id)) } +func (id GroupID) String() string { return fmt.Sprintf("groupid:%x", int64(id)) } +func (id RoleID) String() string { return fmt.Sprintf("roleid:%x", int64(id)) } +func (id CapabilityID) String() string { return fmt.Sprintf("capid:%x", int64(id)) } + +func (n *Node) Equal(n2 *Node) bool { + // TODO(crawshaw): this is crude, but is an easy way to avoid bugs. + b, err := json.Marshal(n) + if err != nil { + panic(err) + } + b2, err := json.Marshal(n2) + if err != nil { + panic(err) + } + return bytes.Equal(b, b2) +} diff --git a/tempfork/osexec/README.md b/tempfork/osexec/README.md new file mode 100644 index 000000000..06a60130c --- /dev/null +++ b/tempfork/osexec/README.md @@ -0,0 +1,47 @@ +This is a temporary fork of Go 1.13's os/exec package, +to work around https://github.com/golang/go/issues/36644. + +The main modification (outside of removing some tests that require +internal-only packages to run) is: + +``` +commit 3c66be240f1ee1f1b5f03bed79eb0d9f8c08965a +Author: Avery Pennarun +Date: Sun Jan 19 03:17:30 2020 -0500 + +Cmd.Wait(): handle EINTR return code from os.Process.Wait(). + +This is probably not actually the correct fix; most likely +os.Process.Wait() itself should be fixed to retry on EINTR so that it +never leaks out of that function. 
But if we're going to patch a +particular module, it's safer to patch a higher-level one like os/exec +rather than the os module itself. + +diff --git a/exec.go b/exec.go +index 17ef003e..5375e673 100644 +--- a/exec.go ++++ b/exec.go +@@ -498,7 +498,21 @@ func (c *Cmd) Wait() error { + } + c.finished = true + +- state, err := c.Process.Wait() ++ var err error ++ var state *os.ProcessState ++ for { ++ state, err = c.Process.Wait() ++ if err != nil { ++ xe, ok := err.(*os.SyscallError) ++ if ok { ++ if xe.Unwrap() == syscall.EINTR { ++ // temporary error, retry wait syscall ++ continue ++ } ++ } ++ } ++ break ++ } + if c.waitDone != nil { + close(c.waitDone) + } +``` diff --git a/tempfork/osexec/bench_test.go b/tempfork/osexec/bench_test.go new file mode 100644 index 000000000..9a94001e8 --- /dev/null +++ b/tempfork/osexec/bench_test.go @@ -0,0 +1,23 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package exec + +import ( + "testing" +) + +func BenchmarkExecHostname(b *testing.B) { + b.ReportAllocs() + path, err := LookPath("hostname") + if err != nil { + b.Fatalf("could not find hostname: %v", err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := Command(path).Run(); err != nil { + b.Fatalf("hostname: %v", err) + } + } +} diff --git a/tempfork/osexec/env_test.go b/tempfork/osexec/env_test.go new file mode 100644 index 000000000..b5ac398c2 --- /dev/null +++ b/tempfork/osexec/env_test.go @@ -0,0 +1,39 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package exec + +import ( + "reflect" + "testing" +) + +func TestDedupEnv(t *testing.T) { + tests := []struct { + noCase bool + in []string + want []string + }{ + { + noCase: true, + in: []string{"k1=v1", "k2=v2", "K1=v3"}, + want: []string{"K1=v3", "k2=v2"}, + }, + { + noCase: false, + in: []string{"k1=v1", "K1=V2", "k1=v3"}, + want: []string{"k1=v3", "K1=V2"}, + }, + { + in: []string{"=a", "=b", "foo", "bar"}, + want: []string{"=b", "foo", "bar"}, + }, + } + for _, tt := range tests { + got := dedupEnvCase(tt.noCase, tt.in) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Dedup(%v, %q) = %q; want %q", tt.noCase, tt.in, got, tt.want) + } + } +} diff --git a/tempfork/osexec/example_test.go b/tempfork/osexec/example_test.go new file mode 100644 index 000000000..62866fa71 --- /dev/null +++ b/tempfork/osexec/example_test.go @@ -0,0 +1,156 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package exec_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "strings" + "time" +) + +func ExampleLookPath() { + path, err := exec.LookPath("fortune") + if err != nil { + log.Fatal("installing fortune is in your future") + } + fmt.Printf("fortune is available at %s\n", path) +} + +func ExampleCommand() { + cmd := exec.Command("tr", "a-z", "A-Z") + cmd.Stdin = strings.NewReader("some input") + var out bytes.Buffer + cmd.Stdout = &out + err := cmd.Run() + if err != nil { + log.Fatal(err) + } + fmt.Printf("in all caps: %q\n", out.String()) +} + +func ExampleCommand_environment() { + cmd := exec.Command("prog") + cmd.Env = append(os.Environ(), + "FOO=duplicate_value", // ignored + "FOO=actual_value", // this value is used + ) + if err := cmd.Run(); err != nil { + log.Fatal(err) + } +} + +func ExampleCmd_Output() { + out, err := exec.Command("date").Output() + if err != nil { + log.Fatal(err) + } + fmt.Printf("The date is %s\n", out) +} + +func ExampleCmd_Run() { + cmd := exec.Command("sleep", "1") + log.Printf("Running command and waiting for it to finish...") + err := cmd.Run() + log.Printf("Command finished with error: %v", err) +} + +func ExampleCmd_Start() { + cmd := exec.Command("sleep", "5") + err := cmd.Start() + if err != nil { + log.Fatal(err) + } + log.Printf("Waiting for command to finish...") + err = cmd.Wait() + log.Printf("Command finished with error: %v", err) +} + +func ExampleCmd_StdoutPipe() { + cmd := exec.Command("echo", "-n", `{"Name": "Bob", "Age": 32}`) + stdout, err := cmd.StdoutPipe() + if err != nil { + log.Fatal(err) + } + if err := cmd.Start(); err != nil { + log.Fatal(err) + } + var person struct { + Name string + Age int + } + if err := json.NewDecoder(stdout).Decode(&person); err != nil { + log.Fatal(err) + } + if err := cmd.Wait(); err != nil { + log.Fatal(err) + } + fmt.Printf("%s is %d years old\n", person.Name, person.Age) +} + +func ExampleCmd_StdinPipe() { + cmd 
:= exec.Command("cat") + stdin, err := cmd.StdinPipe() + if err != nil { + log.Fatal(err) + } + + go func() { + defer stdin.Close() + io.WriteString(stdin, "values written to stdin are passed to cmd's standard input") + }() + + out, err := cmd.CombinedOutput() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s\n", out) +} + +func ExampleCmd_StderrPipe() { + cmd := exec.Command("sh", "-c", "echo stdout; echo 1>&2 stderr") + stderr, err := cmd.StderrPipe() + if err != nil { + log.Fatal(err) + } + + if err := cmd.Start(); err != nil { + log.Fatal(err) + } + + slurp, _ := ioutil.ReadAll(stderr) + fmt.Printf("%s\n", slurp) + + if err := cmd.Wait(); err != nil { + log.Fatal(err) + } +} + +func ExampleCmd_CombinedOutput() { + cmd := exec.Command("sh", "-c", "echo stdout; echo 1>&2 stderr") + stdoutStderr, err := cmd.CombinedOutput() + if err != nil { + log.Fatal(err) + } + fmt.Printf("%s\n", stdoutStderr) +} + +func ExampleCommandContext() { + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + if err := exec.CommandContext(ctx, "sleep", "5").Run(); err != nil { + // This will fail after 100 milliseconds. The 5 second sleep + // will be interrupted. + } +} diff --git a/tempfork/osexec/exec.go b/tempfork/osexec/exec.go new file mode 100644 index 000000000..5375e6738 --- /dev/null +++ b/tempfork/osexec/exec.go @@ -0,0 +1,797 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package exec runs external commands. It wraps os.StartProcess to make it +// easier to remap stdin and stdout, connect I/O with pipes, and do other +// adjustments. +// +// Unlike the "system" library call from C and other languages, the +// os/exec package intentionally does not invoke the system shell and +// does not expand any glob patterns or handle other expansions, +// pipelines, or redirections typically done by shells. 
The package +// behaves more like C's "exec" family of functions. To expand glob +// patterns, either call the shell directly, taking care to escape any +// dangerous input, or use the path/filepath package's Glob function. +// To expand environment variables, use package os's ExpandEnv. +// +// Note that the examples in this package assume a Unix system. +// They may not run on Windows, and they do not run in the Go Playground +// used by golang.org and godoc.org. +package exec + +import ( + "bytes" + "context" + "errors" + "io" + "os" + "path/filepath" + "runtime" + "strconv" + "strings" + "sync" + "syscall" +) + +// Error is returned by LookPath when it fails to classify a file as an +// executable. +type Error struct { + // Name is the file name for which the error occurred. + Name string + // Err is the underlying error. + Err error +} + +func (e *Error) Error() string { + return "exec: " + strconv.Quote(e.Name) + ": " + e.Err.Error() +} + +func (e *Error) Unwrap() error { return e.Err } + +// Cmd represents an external command being prepared or run. +// +// A Cmd cannot be reused after calling its Run, Output or CombinedOutput +// methods. +type Cmd struct { + // Path is the path of the command to run. + // + // This is the only field that must be set to a non-zero + // value. If Path is relative, it is evaluated relative + // to Dir. + Path string + + // Args holds command line arguments, including the command as Args[0]. + // If the Args field is empty or nil, Run uses {Path}. + // + // In typical use, both Path and Args are set by calling Command. + Args []string + + // Env specifies the environment of the process. + // Each entry is of the form "key=value". + // If Env is nil, the new process uses the current process's + // environment. + // If Env contains duplicate environment keys, only the last + // value in the slice for each duplicate key is used. 
+ // As a special case on Windows, SYSTEMROOT is always added if + // missing and not explicitly set to the empty string. + Env []string + + // Dir specifies the working directory of the command. + // If Dir is the empty string, Run runs the command in the + // calling process's current directory. + Dir string + + // Stdin specifies the process's standard input. + // + // If Stdin is nil, the process reads from the null device (os.DevNull). + // + // If Stdin is an *os.File, the process's standard input is connected + // directly to that file. + // + // Otherwise, during the execution of the command a separate + // goroutine reads from Stdin and delivers that data to the command + // over a pipe. In this case, Wait does not complete until the goroutine + // stops copying, either because it has reached the end of Stdin + // (EOF or a read error) or because writing to the pipe returned an error. + Stdin io.Reader + + // Stdout and Stderr specify the process's standard output and error. + // + // If either is nil, Run connects the corresponding file descriptor + // to the null device (os.DevNull). + // + // If either is an *os.File, the corresponding output from the process + // is connected directly to that file. + // + // Otherwise, during the execution of the command a separate goroutine + // reads from the process over a pipe and delivers that data to the + // corresponding Writer. In this case, Wait does not complete until the + // goroutine reaches EOF or encounters an error. + // + // If Stdout and Stderr are the same writer, and have a type that can + // be compared with ==, at most one goroutine at a time will call Write. + Stdout io.Writer + Stderr io.Writer + + // ExtraFiles specifies additional open files to be inherited by the + // new process. It does not include standard input, standard output, or + // standard error. If non-nil, entry i becomes file descriptor 3+i. + // + // ExtraFiles is not supported on Windows. 
+ ExtraFiles []*os.File + + // SysProcAttr holds optional, operating system-specific attributes. + // Run passes it to os.StartProcess as the os.ProcAttr's Sys field. + SysProcAttr *syscall.SysProcAttr + + // Process is the underlying process, once started. + Process *os.Process + + // ProcessState contains information about an exited process, + // available after a call to Wait or Run. + ProcessState *os.ProcessState + + ctx context.Context // nil means none + lookPathErr error // LookPath error, if any. + finished bool // when Wait was called + childFiles []*os.File + closeAfterStart []io.Closer + closeAfterWait []io.Closer + goroutine []func() error + errch chan error // one send per goroutine + waitDone chan struct{} +} + +// Command returns the Cmd struct to execute the named program with +// the given arguments. +// +// It sets only the Path and Args in the returned structure. +// +// If name contains no path separators, Command uses LookPath to +// resolve name to a complete path if possible. Otherwise it uses name +// directly as Path. +// +// The returned Cmd's Args field is constructed from the command name +// followed by the elements of arg, so arg should not include the +// command name itself. For example, Command("echo", "hello"). +// Args[0] is always name, not the possibly resolved Path. +// +// On Windows, processes receive the whole command line as a single string +// and do their own parsing. Command combines and quotes Args into a command +// line string with an algorithm compatible with applications using +// CommandLineToArgvW (which is the most common way). Notable exceptions are +// msiexec.exe and cmd.exe (and thus, all batch files), which have a different +// unquoting algorithm. In these or other similar cases, you can do the +// quoting yourself and provide the full command line in SysProcAttr.CmdLine, +// leaving Args empty. 
+func Command(name string, arg ...string) *Cmd { + cmd := &Cmd{ + Path: name, + Args: append([]string{name}, arg...), + } + if filepath.Base(name) == name { + if lp, err := LookPath(name); err != nil { + cmd.lookPathErr = err + } else { + cmd.Path = lp + } + } + return cmd +} + +// CommandContext is like Command but includes a context. +// +// The provided context is used to kill the process (by calling +// os.Process.Kill) if the context becomes done before the command +// completes on its own. +func CommandContext(ctx context.Context, name string, arg ...string) *Cmd { + if ctx == nil { + panic("nil Context") + } + cmd := Command(name, arg...) + cmd.ctx = ctx + return cmd +} + +// String returns a human-readable description of c. +// It is intended only for debugging. +// In particular, it is not suitable for use as input to a shell. +// The output of String may vary across Go releases. +func (c *Cmd) String() string { + if c.lookPathErr != nil { + // failed to resolve path; report the original requested path (plus args) + return strings.Join(c.Args, " ") + } + // report the exact executable path (plus args) + b := new(strings.Builder) + b.WriteString(c.Path) + for _, a := range c.Args[1:] { + b.WriteByte(' ') + b.WriteString(a) + } + return b.String() +} + +// interfaceEqual protects against panics from doing equality tests on +// two interfaces with non-comparable underlying types. +func interfaceEqual(a, b interface{}) bool { + defer func() { + recover() + }() + return a == b +} + +func (c *Cmd) envv() []string { + if c.Env != nil { + return c.Env + } + return os.Environ() +} + +func (c *Cmd) argv() []string { + if len(c.Args) > 0 { + return c.Args + } + return []string{c.Path} +} + +// skipStdinCopyError optionally specifies a function which reports +// whether the provided stdin copy error should be ignored. +// It is non-nil everywhere but Plan 9, which lacks EPIPE. See exec_posix.go. 
+var skipStdinCopyError func(error) bool + +func (c *Cmd) stdin() (f *os.File, err error) { + if c.Stdin == nil { + f, err = os.Open(os.DevNull) + if err != nil { + return + } + c.closeAfterStart = append(c.closeAfterStart, f) + return + } + + if f, ok := c.Stdin.(*os.File); ok { + return f, nil + } + + pr, pw, err := os.Pipe() + if err != nil { + return + } + + c.closeAfterStart = append(c.closeAfterStart, pr) + c.closeAfterWait = append(c.closeAfterWait, pw) + c.goroutine = append(c.goroutine, func() error { + _, err := io.Copy(pw, c.Stdin) + if skip := skipStdinCopyError; skip != nil && skip(err) { + err = nil + } + if err1 := pw.Close(); err == nil { + err = err1 + } + return err + }) + return pr, nil +} + +func (c *Cmd) stdout() (f *os.File, err error) { + return c.writerDescriptor(c.Stdout) +} + +func (c *Cmd) stderr() (f *os.File, err error) { + if c.Stderr != nil && interfaceEqual(c.Stderr, c.Stdout) { + return c.childFiles[1], nil + } + return c.writerDescriptor(c.Stderr) +} + +func (c *Cmd) writerDescriptor(w io.Writer) (f *os.File, err error) { + if w == nil { + f, err = os.OpenFile(os.DevNull, os.O_WRONLY, 0) + if err != nil { + return + } + c.closeAfterStart = append(c.closeAfterStart, f) + return + } + + if f, ok := w.(*os.File); ok { + return f, nil + } + + pr, pw, err := os.Pipe() + if err != nil { + return + } + + c.closeAfterStart = append(c.closeAfterStart, pw) + c.closeAfterWait = append(c.closeAfterWait, pr) + c.goroutine = append(c.goroutine, func() error { + _, err := io.Copy(w, pr) + pr.Close() // in case io.Copy stopped due to write error + return err + }) + return pw, nil +} + +func (c *Cmd) closeDescriptors(closers []io.Closer) { + for _, fd := range closers { + fd.Close() + } +} + +// Run starts the specified command and waits for it to complete. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. 
+// +// If the command starts but does not complete successfully, the error is of +// type *ExitError. Other error types may be returned for other situations. +// +// If the calling goroutine has locked the operating system thread +// with runtime.LockOSThread and modified any inheritable OS-level +// thread state (for example, Linux or Plan 9 name spaces), the new +// process will inherit the caller's thread state. +func (c *Cmd) Run() error { + if err := c.Start(); err != nil { + return err + } + return c.Wait() +} + +// lookExtensions finds windows executable by its dir and path. +// It uses LookPath to try appropriate extensions. +// lookExtensions does not search PATH, instead it converts `prog` into `.\prog`. +func lookExtensions(path, dir string) (string, error) { + if filepath.Base(path) == path { + path = filepath.Join(".", path) + } + if dir == "" { + return LookPath(path) + } + if filepath.VolumeName(path) != "" { + return LookPath(path) + } + if len(path) > 1 && os.IsPathSeparator(path[0]) { + return LookPath(path) + } + dirandpath := filepath.Join(dir, path) + // We assume that LookPath will only add file extension. + lp, err := LookPath(dirandpath) + if err != nil { + return "", err + } + ext := strings.TrimPrefix(lp, dirandpath) + return path + ext, nil +} + +// Start starts the specified command but does not wait for it to complete. +// +// The Wait method will return the exit code and release associated resources +// once the command exits. 
+func (c *Cmd) Start() error { + if c.lookPathErr != nil { + c.closeDescriptors(c.closeAfterStart) + c.closeDescriptors(c.closeAfterWait) + return c.lookPathErr + } + if runtime.GOOS == "windows" { + lp, err := lookExtensions(c.Path, c.Dir) + if err != nil { + c.closeDescriptors(c.closeAfterStart) + c.closeDescriptors(c.closeAfterWait) + return err + } + c.Path = lp + } + if c.Process != nil { + return errors.New("exec: already started") + } + if c.ctx != nil { + select { + case <-c.ctx.Done(): + c.closeDescriptors(c.closeAfterStart) + c.closeDescriptors(c.closeAfterWait) + return c.ctx.Err() + default: + } + } + + c.childFiles = make([]*os.File, 0, 3+len(c.ExtraFiles)) + type F func(*Cmd) (*os.File, error) + for _, setupFd := range []F{(*Cmd).stdin, (*Cmd).stdout, (*Cmd).stderr} { + fd, err := setupFd(c) + if err != nil { + c.closeDescriptors(c.closeAfterStart) + c.closeDescriptors(c.closeAfterWait) + return err + } + c.childFiles = append(c.childFiles, fd) + } + c.childFiles = append(c.childFiles, c.ExtraFiles...) + + var err error + c.Process, err = os.StartProcess(c.Path, c.argv(), &os.ProcAttr{ + Dir: c.Dir, + Files: c.childFiles, + Env: addCriticalEnv(dedupEnv(c.envv())), + Sys: c.SysProcAttr, + }) + if err != nil { + c.closeDescriptors(c.closeAfterStart) + c.closeDescriptors(c.closeAfterWait) + return err + } + + c.closeDescriptors(c.closeAfterStart) + + // Don't allocate the channel unless there are goroutines to fire. + if len(c.goroutine) > 0 { + c.errch = make(chan error, len(c.goroutine)) + for _, fn := range c.goroutine { + go func(fn func() error) { + c.errch <- fn() + }(fn) + } + } + + if c.ctx != nil { + c.waitDone = make(chan struct{}) + go func() { + select { + case <-c.ctx.Done(): + c.Process.Kill() + case <-c.waitDone: + } + }() + } + + return nil +} + +// An ExitError reports an unsuccessful exit by a command. 
type ExitError struct {
	// ProcessState is embedded, so its methods are available directly on
	// the error value.
	*os.ProcessState

	// Stderr holds a subset of the standard error output from the
	// Cmd.Output method if standard error was not otherwise being
	// collected.
	//
	// If the error output is long, Stderr may contain only a prefix
	// and suffix of the output, with the middle replaced with
	// text about the number of omitted bytes.
	//
	// Stderr is provided for debugging, for inclusion in error messages.
	// Users with other needs should redirect Cmd.Stderr as needed.
	Stderr []byte
}

// Error implements the error interface by rendering the embedded
// process state.
func (e *ExitError) Error() string {
	// String is promoted from the embedded *os.ProcessState.
	return e.String()
}

// Wait waits for the command to exit and waits for any copying to
// stdin or copying from stdout or stderr to complete.
//
// The command must have been started by Start.
//
// The returned error is nil if the command runs, has no problems
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command fails to run or doesn't complete successfully, the
// error is of type *ExitError. Other error types may be
// returned for I/O problems.
//
// If any of c.Stdin, c.Stdout or c.Stderr are not an *os.File, Wait also waits
// for the respective I/O loop copying to or from the process to complete.
//
// Wait releases any resources associated with the Cmd.
+func (c *Cmd) Wait() error { + if c.Process == nil { + return errors.New("exec: not started") + } + if c.finished { + return errors.New("exec: Wait was already called") + } + c.finished = true + + var err error + var state *os.ProcessState + for { + state, err = c.Process.Wait() + if err != nil { + xe, ok := err.(*os.SyscallError) + if ok { + if xe.Unwrap() == syscall.EINTR { + // temporary error, retry wait syscall + continue + } + } + } + break + } + if c.waitDone != nil { + close(c.waitDone) + } + c.ProcessState = state + + var copyError error + for range c.goroutine { + if err := <-c.errch; err != nil && copyError == nil { + copyError = err + } + } + + c.closeDescriptors(c.closeAfterWait) + + if err != nil { + return err + } else if !state.Success() { + return &ExitError{ProcessState: state} + } + + return copyError +} + +// Output runs the command and returns its standard output. +// Any returned error will usually be of type *ExitError. +// If c.Stderr was nil, Output populates ExitError.Stderr. +func (c *Cmd) Output() ([]byte, error) { + if c.Stdout != nil { + return nil, errors.New("exec: Stdout already set") + } + var stdout bytes.Buffer + c.Stdout = &stdout + + captureErr := c.Stderr == nil + if captureErr { + c.Stderr = &prefixSuffixSaver{N: 32 << 10} + } + + err := c.Run() + if err != nil && captureErr { + if ee, ok := err.(*ExitError); ok { + ee.Stderr = c.Stderr.(*prefixSuffixSaver).Bytes() + } + } + return stdout.Bytes(), err +} + +// CombinedOutput runs the command and returns its combined standard +// output and standard error. +func (c *Cmd) CombinedOutput() ([]byte, error) { + if c.Stdout != nil { + return nil, errors.New("exec: Stdout already set") + } + if c.Stderr != nil { + return nil, errors.New("exec: Stderr already set") + } + var b bytes.Buffer + c.Stdout = &b + c.Stderr = &b + err := c.Run() + return b.Bytes(), err +} + +// StdinPipe returns a pipe that will be connected to the command's +// standard input when the command starts. 
+// The pipe will be closed automatically after Wait sees the command exit. +// A caller need only call Close to force the pipe to close sooner. +// For example, if the command being run will not exit until standard input +// is closed, the caller must close the pipe. +func (c *Cmd) StdinPipe() (io.WriteCloser, error) { + if c.Stdin != nil { + return nil, errors.New("exec: Stdin already set") + } + if c.Process != nil { + return nil, errors.New("exec: StdinPipe after process started") + } + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + c.Stdin = pr + c.closeAfterStart = append(c.closeAfterStart, pr) + wc := &closeOnce{File: pw} + c.closeAfterWait = append(c.closeAfterWait, wc) + return wc, nil +} + +type closeOnce struct { + *os.File + + once sync.Once + err error +} + +func (c *closeOnce) Close() error { + c.once.Do(c.close) + return c.err +} + +func (c *closeOnce) close() { + c.err = c.File.Close() +} + +// StdoutPipe returns a pipe that will be connected to the command's +// standard output when the command starts. +// +// Wait will close the pipe after seeing the command exit, so most callers +// need not close the pipe themselves; however, an implication is that +// it is incorrect to call Wait before all reads from the pipe have completed. +// For the same reason, it is incorrect to call Run when using StdoutPipe. +// See the example for idiomatic usage. +func (c *Cmd) StdoutPipe() (io.ReadCloser, error) { + if c.Stdout != nil { + return nil, errors.New("exec: Stdout already set") + } + if c.Process != nil { + return nil, errors.New("exec: StdoutPipe after process started") + } + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + c.Stdout = pw + c.closeAfterStart = append(c.closeAfterStart, pw) + c.closeAfterWait = append(c.closeAfterWait, pr) + return pr, nil +} + +// StderrPipe returns a pipe that will be connected to the command's +// standard error when the command starts. 
+// +// Wait will close the pipe after seeing the command exit, so most callers +// need not close the pipe themselves; however, an implication is that +// it is incorrect to call Wait before all reads from the pipe have completed. +// For the same reason, it is incorrect to use Run when using StderrPipe. +// See the StdoutPipe example for idiomatic usage. +func (c *Cmd) StderrPipe() (io.ReadCloser, error) { + if c.Stderr != nil { + return nil, errors.New("exec: Stderr already set") + } + if c.Process != nil { + return nil, errors.New("exec: StderrPipe after process started") + } + pr, pw, err := os.Pipe() + if err != nil { + return nil, err + } + c.Stderr = pw + c.closeAfterStart = append(c.closeAfterStart, pw) + c.closeAfterWait = append(c.closeAfterWait, pr) + return pr, nil +} + +// prefixSuffixSaver is an io.Writer which retains the first N bytes +// and the last N bytes written to it. The Bytes() methods reconstructs +// it with a pretty error message. +type prefixSuffixSaver struct { + N int // max size of prefix or suffix + prefix []byte + suffix []byte // ring buffer once len(suffix) == N + suffixOff int // offset to write into suffix + skipped int64 + + // TODO(bradfitz): we could keep one large []byte and use part of it for + // the prefix, reserve space for the '... Omitting N bytes ...' message, + // then the ring buffer suffix, and just rearrange the ring buffer + // suffix when Bytes() is called, but it doesn't seem worth it for + // now just for error messages. It's only ~64KB anyway. +} + +func (w *prefixSuffixSaver) Write(p []byte) (n int, err error) { + lenp := len(p) + p = w.fill(&w.prefix, p) + + // Only keep the last w.N bytes of suffix data. + if overage := len(p) - w.N; overage > 0 { + p = p[overage:] + w.skipped += int64(overage) + } + p = w.fill(&w.suffix, p) + + // w.suffix is full now if p is non-empty. Overwrite it in a circle. + for len(p) > 0 { // 0, 1, or 2 iterations. 
+ n := copy(w.suffix[w.suffixOff:], p) + p = p[n:] + w.skipped += int64(n) + w.suffixOff += n + if w.suffixOff == w.N { + w.suffixOff = 0 + } + } + return lenp, nil +} + +// fill appends up to len(p) bytes of p to *dst, such that *dst does not +// grow larger than w.N. It returns the un-appended suffix of p. +func (w *prefixSuffixSaver) fill(dst *[]byte, p []byte) (pRemain []byte) { + if remain := w.N - len(*dst); remain > 0 { + add := minInt(len(p), remain) + *dst = append(*dst, p[:add]...) + p = p[add:] + } + return p +} + +func (w *prefixSuffixSaver) Bytes() []byte { + if w.suffix == nil { + return w.prefix + } + if w.skipped == 0 { + return append(w.prefix, w.suffix...) + } + var buf bytes.Buffer + buf.Grow(len(w.prefix) + len(w.suffix) + 50) + buf.Write(w.prefix) + buf.WriteString("\n... omitting ") + buf.WriteString(strconv.FormatInt(w.skipped, 10)) + buf.WriteString(" bytes ...\n") + buf.Write(w.suffix[w.suffixOff:]) + buf.Write(w.suffix[:w.suffixOff]) + return buf.Bytes() +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +// dedupEnv returns a copy of env with any duplicates removed, in favor of +// later values. +// Items not of the normal environment "key=value" form are preserved unchanged. +func dedupEnv(env []string) []string { + return dedupEnvCase(runtime.GOOS == "windows", env) +} + +// dedupEnvCase is dedupEnv with a case option for testing. +// If caseInsensitive is true, the case of keys is ignored. 
+func dedupEnvCase(caseInsensitive bool, env []string) []string { + out := make([]string, 0, len(env)) + saw := make(map[string]int, len(env)) // key => index into out + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + out = append(out, kv) + continue + } + k := kv[:eq] + if caseInsensitive { + k = strings.ToLower(k) + } + if dupIdx, isDup := saw[k]; isDup { + out[dupIdx] = kv + continue + } + saw[k] = len(out) + out = append(out, kv) + } + return out +} + +// addCriticalEnv adds any critical environment variables that are required +// (or at least almost always required) on the operating system. +// Currently this is only used for Windows. +func addCriticalEnv(env []string) []string { + if runtime.GOOS != "windows" { + return env + } + for _, kv := range env { + eq := strings.Index(kv, "=") + if eq < 0 { + continue + } + k := kv[:eq] + if strings.EqualFold(k, "SYSTEMROOT") { + // We already have it. + return env + } + } + return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT")) +} diff --git a/tempfork/osexec/exec_unix.go b/tempfork/osexec/exec_unix.go new file mode 100644 index 000000000..9c3e17d23 --- /dev/null +++ b/tempfork/osexec/exec_unix.go @@ -0,0 +1,24 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,!windows + +package exec + +import ( + "os" + "syscall" +) + +func init() { + skipStdinCopyError = func(err error) bool { + // Ignore EPIPE errors copying to stdin if the program + // completed successfully otherwise. + // See Issue 9173. + pe, ok := err.(*os.PathError) + return ok && + pe.Op == "write" && pe.Path == "|1" && + pe.Err == syscall.EPIPE + } +} diff --git a/tempfork/osexec/exec_windows.go b/tempfork/osexec/exec_windows.go new file mode 100644 index 000000000..af8cd9721 --- /dev/null +++ b/tempfork/osexec/exec_windows.go @@ -0,0 +1,23 @@ +// Copyright 2017 The Go Authors. 
All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package exec + +import ( + "os" + "syscall" +) + +func init() { + skipStdinCopyError = func(err error) bool { + // Ignore ERROR_BROKEN_PIPE and ERROR_NO_DATA errors copying + // to stdin if the program completed successfully otherwise. + // See Issue 20445. + const _ERROR_NO_DATA = syscall.Errno(0xe8) + pe, ok := err.(*os.PathError) + return ok && + pe.Op == "write" && pe.Path == "|1" && + (pe.Err == syscall.ERROR_BROKEN_PIPE || pe.Err == _ERROR_NO_DATA) + } +} diff --git a/tempfork/osexec/internal_test.go b/tempfork/osexec/internal_test.go new file mode 100644 index 000000000..68d517ffb --- /dev/null +++ b/tempfork/osexec/internal_test.go @@ -0,0 +1,61 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package exec + +import ( + "io" + "testing" +) + +func TestPrefixSuffixSaver(t *testing.T) { + tests := []struct { + N int + writes []string + want string + }{ + { + N: 2, + writes: nil, + want: "", + }, + { + N: 2, + writes: []string{"a"}, + want: "a", + }, + { + N: 2, + writes: []string{"abc", "d"}, + want: "abcd", + }, + { + N: 2, + writes: []string{"abc", "d", "e"}, + want: "ab\n... omitting 1 bytes ...\nde", + }, + { + N: 2, + writes: []string{"ab______________________yz"}, + want: "ab\n... omitting 22 bytes ...\nyz", + }, + { + N: 2, + writes: []string{"ab_______________________y", "z"}, + want: "ab\n... omitting 23 bytes ...\nyz", + }, + } + for i, tt := range tests { + w := &prefixSuffixSaver{N: tt.N} + for _, s := range tt.writes { + n, err := io.WriteString(w, s) + if err != nil || n != len(s) { + t.Errorf("%d. WriteString(%q) = %v, %v; want %v, %v", i, s, n, err, len(s), nil) + } + } + if got := string(w.Bytes()); got != tt.want { + t.Errorf("%d. 
Bytes = %q; want %q", i, got, tt.want) + } + } +} diff --git a/tempfork/osexec/lp_js.go b/tempfork/osexec/lp_js.go new file mode 100644 index 000000000..6750fb99b --- /dev/null +++ b/tempfork/osexec/lp_js.go @@ -0,0 +1,23 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build js,wasm + +package exec + +import ( + "errors" +) + +// ErrNotFound is the error resulting if a path search failed to find an executable file. +var ErrNotFound = errors.New("executable file not found in $PATH") + +// LookPath searches for an executable named file in the +// directories named by the PATH environment variable. +// If file contains a slash, it is tried directly and the PATH is not consulted. +// The result may be an absolute path or a path relative to the current directory. +func LookPath(file string) (string, error) { + // Wasm can not execute processes, so act as if there are no executables at all. + return "", &Error{file, ErrNotFound} +} diff --git a/tempfork/osexec/lp_plan9.go b/tempfork/osexec/lp_plan9.go new file mode 100644 index 000000000..5860cbca4 --- /dev/null +++ b/tempfork/osexec/lp_plan9.go @@ -0,0 +1,55 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package exec + +import ( + "errors" + "os" + "path/filepath" + "strings" +) + +// ErrNotFound is the error resulting if a path search failed to find an executable file. +var ErrNotFound = errors.New("executable file not found in $path") + +func findExecutable(file string) error { + d, err := os.Stat(file) + if err != nil { + return err + } + if m := d.Mode(); !m.IsDir() && m&0111 != 0 { + return nil + } + return os.ErrPermission +} + +// LookPath searches for an executable named file in the +// directories named by the path environment variable. 
+// If file begins with "/", "#", "./", or "../", it is tried +// directly and the path is not consulted. +// The result may be an absolute path or a path relative to the current directory. +func LookPath(file string) (string, error) { + // skip the path lookup for these prefixes + skip := []string{"/", "#", "./", "../"} + + for _, p := range skip { + if strings.HasPrefix(file, p) { + err := findExecutable(file) + if err == nil { + return file, nil + } + return "", &Error{file, err} + } + } + + path := os.Getenv("path") + for _, dir := range filepath.SplitList(path) { + path := filepath.Join(dir, file) + if err := findExecutable(path); err == nil { + return path, nil + } + } + return "", &Error{file, ErrNotFound} +} diff --git a/tempfork/osexec/lp_test.go b/tempfork/osexec/lp_test.go new file mode 100644 index 000000000..77d8e848c --- /dev/null +++ b/tempfork/osexec/lp_test.go @@ -0,0 +1,33 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package exec + +import ( + "testing" +) + +var nonExistentPaths = []string{ + "some-non-existent-path", + "non-existent-path/slashed", +} + +func TestLookPathNotFound(t *testing.T) { + for _, name := range nonExistentPaths { + path, err := LookPath(name) + if err == nil { + t.Fatalf("LookPath found %q in $PATH", name) + } + if path != "" { + t.Fatalf("LookPath path == %q when err != nil", path) + } + perr, ok := err.(*Error) + if !ok { + t.Fatal("LookPath error is not an exec.Error") + } + if perr.Name != name { + t.Fatalf("want Error name %q, got %q", name, perr.Name) + } + } +} diff --git a/tempfork/osexec/lp_unix.go b/tempfork/osexec/lp_unix.go new file mode 100644 index 000000000..799e0b4ee --- /dev/null +++ b/tempfork/osexec/lp_unix.go @@ -0,0 +1,58 @@ +// Copyright 2010 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build aix darwin dragonfly freebsd linux nacl netbsd openbsd solaris + +package exec + +import ( + "errors" + "os" + "path/filepath" + "strings" +) + +// ErrNotFound is the error resulting if a path search failed to find an executable file. +var ErrNotFound = errors.New("executable file not found in $PATH") + +func findExecutable(file string) error { + d, err := os.Stat(file) + if err != nil { + return err + } + if m := d.Mode(); !m.IsDir() && m&0111 != 0 { + return nil + } + return os.ErrPermission +} + +// LookPath searches for an executable named file in the +// directories named by the PATH environment variable. +// If file contains a slash, it is tried directly and the PATH is not consulted. +// The result may be an absolute path or a path relative to the current directory. +func LookPath(file string) (string, error) { + // NOTE(rsc): I wish we could use the Plan 9 behavior here + // (only bypass the path if file begins with / or ./ or ../) + // but that would not match all the Unix shells. + + if strings.Contains(file, "/") { + err := findExecutable(file) + if err == nil { + return file, nil + } + return "", &Error{file, err} + } + path := os.Getenv("PATH") + for _, dir := range filepath.SplitList(path) { + if dir == "" { + // Unix shell semantics: path element "" means "." + dir = "." + } + path := filepath.Join(dir, file) + if err := findExecutable(path); err == nil { + return path, nil + } + } + return "", &Error{file, ErrNotFound} +} diff --git a/tempfork/osexec/lp_unix_test.go b/tempfork/osexec/lp_unix_test.go new file mode 100644 index 000000000..e4656cafb --- /dev/null +++ b/tempfork/osexec/lp_unix_test.go @@ -0,0 +1,55 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build aix darwin dragonfly freebsd linux netbsd openbsd solaris + +package exec + +import ( + "io/ioutil" + "os" + "testing" +) + +func TestLookPathUnixEmptyPath(t *testing.T) { + tmp, err := ioutil.TempDir("", "TestLookPathUnixEmptyPath") + if err != nil { + t.Fatal("TempDir failed: ", err) + } + defer os.RemoveAll(tmp) + wd, err := os.Getwd() + if err != nil { + t.Fatal("Getwd failed: ", err) + } + err = os.Chdir(tmp) + if err != nil { + t.Fatal("Chdir failed: ", err) + } + defer os.Chdir(wd) + + f, err := os.OpenFile("exec_me", os.O_CREATE|os.O_EXCL, 0700) + if err != nil { + t.Fatal("OpenFile failed: ", err) + } + err = f.Close() + if err != nil { + t.Fatal("Close failed: ", err) + } + + pathenv := os.Getenv("PATH") + defer os.Setenv("PATH", pathenv) + + err = os.Setenv("PATH", "") + if err != nil { + t.Fatal("Setenv failed: ", err) + } + + path, err := LookPath("exec_me") + if err == nil { + t.Fatal("LookPath found exec_me in empty $PATH") + } + if path != "" { + t.Fatalf("LookPath path == %q when err != nil", path) + } +} diff --git a/tempfork/osexec/lp_windows.go b/tempfork/osexec/lp_windows.go new file mode 100644 index 000000000..9ea3d7657 --- /dev/null +++ b/tempfork/osexec/lp_windows.go @@ -0,0 +1,93 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package exec + +import ( + "errors" + "os" + "path/filepath" + "strings" +) + +// ErrNotFound is the error resulting if a path search failed to find an executable file. 
+var ErrNotFound = errors.New("executable file not found in %PATH%") + +func chkStat(file string) error { + d, err := os.Stat(file) + if err != nil { + return err + } + if d.IsDir() { + return os.ErrPermission + } + return nil +} + +func hasExt(file string) bool { + i := strings.LastIndex(file, ".") + if i < 0 { + return false + } + return strings.LastIndexAny(file, `:\/`) < i +} + +func findExecutable(file string, exts []string) (string, error) { + if len(exts) == 0 { + return file, chkStat(file) + } + if hasExt(file) { + if chkStat(file) == nil { + return file, nil + } + } + for _, e := range exts { + if f := file + e; chkStat(f) == nil { + return f, nil + } + } + return "", os.ErrNotExist +} + +// LookPath searches for an executable named file in the +// directories named by the PATH environment variable. +// If file contains a slash, it is tried directly and the PATH is not consulted. +// LookPath also uses PATHEXT environment variable to match +// a suitable candidate. +// The result may be an absolute path or a path relative to the current directory. +func LookPath(file string) (string, error) { + var exts []string + x := os.Getenv(`PATHEXT`) + if x != "" { + for _, e := range strings.Split(strings.ToLower(x), `;`) { + if e == "" { + continue + } + if e[0] != '.' { + e = "." 
+ e + } + exts = append(exts, e) + } + } else { + exts = []string{".com", ".exe", ".bat", ".cmd"} + } + + if strings.ContainsAny(file, `:\/`) { + if f, err := findExecutable(file, exts); err == nil { + return f, nil + } else { + return "", &Error{file, err} + } + } + if f, err := findExecutable(filepath.Join(".", file), exts); err == nil { + return f, nil + } + path := os.Getenv("path") + for _, dir := range filepath.SplitList(path) { + if f, err := findExecutable(filepath.Join(dir, file), exts); err == nil { + return f, nil + } + } + return "", &Error{file, ErrNotFound} +} diff --git a/testy/log.go b/testy/log.go new file mode 100644 index 000000000..43674afce --- /dev/null +++ b/testy/log.go @@ -0,0 +1,30 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testy + +import ( + "log" + "os" + "testing" +) + +type testLogWriter struct { + t *testing.T +} + +func (w *testLogWriter) Write(b []byte) (int, error) { + w.t.Helper() + w.t.Logf("%s", b) + return len(b), nil +} + +func FixLogs(t *testing.T) { + log.SetFlags(log.Ltime | log.Lshortfile) + log.SetOutput(&testLogWriter{t}) +} + +func UnfixLogs(t *testing.T) { + defer log.SetOutput(os.Stderr) +} diff --git a/testy/resource.go b/testy/resource.go new file mode 100644 index 000000000..2c12bb57d --- /dev/null +++ b/testy/resource.go @@ -0,0 +1,72 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package testy + +import ( + "bytes" + "runtime" + "runtime/pprof" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +type ResourceCheck struct { + startNumRoutines int + startDump string +} + +func NewResourceCheck() *ResourceCheck { + // NOTE(apenwarr): I'd rather not pre-generate a goroutine dump here. 
+ // However, it turns out to be tricky to debug when eg. the initial + // goroutine count > the ending goroutine count, because of course + // the missing ones are not in the final dump. Also, we have to + // render the profile as a string right away, because the + // pprof.Profile object doesn't stay stable over time. Every time + // you render the string, you might get a different answer. + r := &ResourceCheck{} + r.startNumRoutines, r.startDump = goroutineDump() + return r +} + +func goroutineDump() (int, string) { + p := pprof.Lookup("goroutine") + b := &bytes.Buffer{} + p.WriteTo(b, 1) + return p.Count(), b.String() +} + +func (r *ResourceCheck) Assert(t *testing.T) { + t.Helper() + want := r.startNumRoutines + + // Some goroutines might be still exiting, so give them a chance + got := runtime.NumGoroutine() + if want != got { + _, dump := goroutineDump() + for i := 0; i < 100; i++ { + got = runtime.NumGoroutine() + if want == got { + break + } + time.Sleep(1 * time.Millisecond) + } + + // If the count is *still* wrong, that's a failure. + if want != got { + t.Logf("goroutine diff:\n%v\n", cmp.Diff(r.startDump, dump)) + t.Logf("goroutine count: expected %d, got %d\n", want, got) + // Don't fail if there are *fewer* goroutines than + // expected. That just might be some leftover ones + // from the previous test, which are pretty hard to + // eliminate. + if want < got { + t.Fatalf("goroutine count: expected %d, got %d\n", want, got) + } + } + } + t.Logf("Assert: goroutines before=%d after=%d - ok\n", got, want) +} diff --git a/version/.gitignore b/version/.gitignore new file mode 100644 index 000000000..e6b904135 --- /dev/null +++ b/version/.gitignore @@ -0,0 +1,6 @@ +long.txt +short.txt +version.h +version.xcconfig +ver.go +version diff --git a/version/GENERATE.go b/version/GENERATE.go new file mode 100644 index 000000000..675a497cb --- /dev/null +++ b/version/GENERATE.go @@ -0,0 +1,8 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Placeholder that indicates this directory is a valid go package, +// but that redo must 'redo all' in this directory before it can +// be imported. +package version diff --git a/version/all.do b/version/all.do new file mode 100644 index 000000000..d0bdd8ead --- /dev/null +++ b/version/all.do @@ -0,0 +1,2 @@ +redo-ifchange ver.go version.xcconfig version.h + diff --git a/version/clean.do b/version/clean.do new file mode 100644 index 000000000..5ddcb4f31 --- /dev/null +++ b/version/clean.do @@ -0,0 +1 @@ +rm -f *~ .*~ long.txt short.txt version.xcconfig ver.go version.h version diff --git a/version/long.txt.do b/version/long.txt.do new file mode 100644 index 000000000..c62151806 --- /dev/null +++ b/version/long.txt.do @@ -0,0 +1,10 @@ +ver=$(git describe | sed 's/^v//') +if [ "$ver" = "${ver%-*}" ]; then + # no sub-version. ie. it's 0.05 and not 0.05-341 + # so add a sub-version. + ver=$ver-0 +fi +echo "$ver" >$3 + +redo-always +redo-stamp <$3 diff --git a/version/short.txt.do b/version/short.txt.do new file mode 100644 index 000000000..54540d3ec --- /dev/null +++ b/version/short.txt.do @@ -0,0 +1,18 @@ +redo-ifchange long.txt +read -r LONGVER junk $3 + ;; + *-*) + echo "$LONGVER" >$3 + ;; + *) + echo "Fatal: long version in invalid format." 
>&2 + exit 44 +esac + +redo-stamp <$3 diff --git a/version/ver.go.do b/version/ver.go.do new file mode 100644 index 000000000..f7a8e1f6d --- /dev/null +++ b/version/ver.go.do @@ -0,0 +1,8 @@ +redo-ifchange long.txt short.txt ver.go.in + +read -r LONGVER $3 diff --git a/version/ver.go.in b/version/ver.go.in new file mode 100644 index 000000000..d9dd25ed4 --- /dev/null +++ b/version/ver.go.in @@ -0,0 +1,4 @@ +package version + +const LONG = "{LONGVER}" +const SHORT = "{SHORTVER}" diff --git a/version/version.h.do b/version/version.h.do new file mode 100644 index 000000000..013a0b7d3 --- /dev/null +++ b/version/version.h.do @@ -0,0 +1,11 @@ +redo-ifchange long.txt short.txt +read -r long $3 diff --git a/version/version.xcconfig.do b/version/version.xcconfig.do new file mode 100644 index 000000000..7ba1a854d --- /dev/null +++ b/version/version.xcconfig.do @@ -0,0 +1,17 @@ +redo-ifchange short.txt +read -r ver $3 +# CFBundleVersion: the build number. Needs to increment each release. +# start counting at 100 because we submitted using raw build numbers +# before (and Apple doesn't let you start over). +# eg. 100.92.98 + +major=$((${ver%%.*} + 100)) +minor=${ver#*.} +echo "VERSION_ID = $major.$minor" >>$3 diff --git a/wgengine/faketun.go b/wgengine/faketun.go new file mode 100644 index 000000000..01a059f59 --- /dev/null +++ b/wgengine/faketun.go @@ -0,0 +1,55 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package wgengine + +import ( + "github.com/tailscale/wireguard-go/tun" + "io" + "os" +) + +type fakeTun struct { + datachan chan []byte + evchan chan tun.Event + closechan chan struct{} +} + +func NewFakeTun() tun.Device { + return &fakeTun{ + datachan: make(chan []byte), + evchan: make(chan tun.Event), + closechan: make(chan struct{}), + } +} + +func (t *fakeTun) File() *os.File { + panic("fakeTun.File() called, which makes no sense") +} + +func (t *fakeTun) Close() error { + close(t.closechan) + close(t.datachan) + return nil +} + +func (t *fakeTun) InsertRead(b []byte) { + t.datachan <- b +} + +func (t *fakeTun) Read(out []byte, offset int) (int, error) { + select { + case <-t.closechan: + return 0, io.EOF + case b := <-t.datachan: + copy(out[offset:offset+len(b)], b) + return len(b), nil + } +} + +func (t *fakeTun) Write(b []byte, n int) (int, error) { return len(b), nil } +func (t *fakeTun) Flush() error { return nil } +func (t *fakeTun) MTU() (int, error) { return 1500, nil } +func (t *fakeTun) Name() (string, error) { return "FakeTun", nil } +func (t *fakeTun) Events() chan tun.Event { return t.evchan } diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go new file mode 100644 index 000000000..ac0fd1282 --- /dev/null +++ b/wgengine/filter/filter.go @@ -0,0 +1,218 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package filter + +import ( + "fmt" + "log" + "sync" + "time" + + "github.com/golang/groupcache/lru" + "tailscale.com/ratelimit" + "tailscale.com/wgengine/packet" +) + +type Filter struct { + matches Matches + + udpMu sync.Mutex + udplru *lru.Cache +} + +type Response int + +const ( + Drop Response = iota + Accept + noVerdict // Returned from subfilters to continue processing. 
+) + +func (r Response) String() string { + switch r { + case Drop: + return "Drop" + case Accept: + return "Accept" + case noVerdict: + return "noVerdict" + default: + return "???" + } +} + +type RunFlags int + +const ( + LogDrops RunFlags = 1 << iota + LogAccepts + HexdumpDrops + HexdumpAccepts +) + +type tuple struct { + SrcIP IP + DstIP IP + SrcPort uint16 + DstPort uint16 +} + +const LRU_MAX = 512 // max entries in UDP LRU cache + +var MatchAllowAll = Matches{ + Match{[]IPPortRange{IPPortRangeAny}, []IP{IPAny}}, +} + +func NewAllowAll() *Filter { + return New(MatchAllowAll) +} + +func NewAllowNone() *Filter { + return New(nil) +} + +func New(matches Matches) *Filter { + f := &Filter{ + matches: matches, + udplru: lru.New(LRU_MAX), + } + return f +} + +func maybeHexdump(flag RunFlags, b []byte) string { + if flag != 0 { + return packet.Hexdump(b) + "\n" + } else { + return "" + } +} + +// TODO(apenwarr): use a bigger bucket for specifically TCP SYN accept logging? +// Logging is a quick way to record every newly opened TCP connection, but +// we have to be cautious about flooding the logs vs letting people use +// flood protection to hide their traffic. We could use a rate limiter in +// the actual *filter* for SYN accepts, perhaps. 
+var acceptBucket = ratelimit.Bucket{ + Burst: 3, + FillInterval: 10 * time.Second, +} +var dropBucket = ratelimit.Bucket{ + Burst: 10, + FillInterval: 5 * time.Second, +} + +func logRateLimit(runflags RunFlags, b []byte, q *packet.QDecode, r Response, why string) { + if r == Drop && (runflags&LogDrops) != 0 && dropBucket.TryGet() > 0 { + var qs string + if q == nil { + qs = fmt.Sprintf("(%d bytes)", len(b)) + } else { + qs = q.String() + } + log.Printf("Drop: %v %v %s\n%s", qs, len(b), why, maybeHexdump(runflags&HexdumpDrops, b)) + } else if r == Accept && (runflags&LogAccepts) != 0 && acceptBucket.TryGet() > 0 { + log.Printf("Accept: %v %v %s\n%s", q, len(b), why, maybeHexdump(runflags&HexdumpAccepts, b)) + } +} + +func (f *Filter) RunIn(b []byte, q *packet.QDecode, rf RunFlags) Response { + r := pre(b, q, rf) + if r == Accept || r == Drop { + // already logged + return r + } + + r, why := f.runIn(q) + logRateLimit(rf, b, q, r, why) + return r +} + +func (f *Filter) RunOut(b []byte, q *packet.QDecode, rf RunFlags) Response { + r := pre(b, q, rf) + if r == Drop || r == Accept { + // already logged + return r + } + r, why := f.runOut(q) + logRateLimit(rf, b, q, r, why) + return r +} + +func (f *Filter) runIn(q *packet.QDecode) (r Response, why string) { + switch q.IPProto { + case packet.ICMP: + // If any port is open to an IP, allow ICMP to it. + if matchIPWithoutPorts(f.matches, q) { + return Accept, "icmp ok" + } + case packet.TCP: + // For TCP, we want to allow *outgoing* connections, + // which means we want to allow return packets on those + // connections. To make this restriction work, we need to + // allow non-SYN packets (continuation of an existing session) + // to arrive. This should be okay since a new incoming session + // can't be initiated without first sending a SYN. + // It happens to also be much faster. + // TODO(apenwarr): Skip the rest of decoding in this path? 
+ if q.IPProto == packet.TCP && !q.IsTCPSyn() { + return Accept, "tcp non-syn" + } + if matchIPPorts(f.matches, q) { + return Accept, "tcp ok" + } + case packet.UDP: + t := tuple{q.SrcIP, q.DstIP, q.SrcPort, q.DstPort} + + f.udpMu.Lock() + _, ok := f.udplru.Get(t) + f.udpMu.Unlock() + + if ok { + return Accept, "udp cached" + } + if matchIPPorts(f.matches, q) { + return Accept, "udp ok" + } + default: + return Drop, "Unknown proto" + } + return Drop, "no rules matched" +} + +func (f *Filter) runOut(q *packet.QDecode) (r Response, why string) { + if q.IPProto == packet.UDP { + t := tuple{q.DstIP, q.SrcIP, q.DstPort, q.SrcPort} + + f.udpMu.Lock() + f.udplru.Add(t, t) + f.udpMu.Unlock() + } + return Accept, "ok out" +} + +func pre(b []byte, q *packet.QDecode, rf RunFlags) Response { + if len(b) == 0 { + // wireguard keepalive packet, always permit. + return Accept + } + if len(b) < 20 { + logRateLimit(rf, b, nil, Drop, "too short") + return Drop + } + q.Decode(b) + + if q.IPProto == packet.Junk { + // Junk packets are dangerous; always drop them. + logRateLimit(rf, b, q, Drop, "junk!") + return Drop + } else if q.IPProto == packet.Fragment { + // Fragments after the first always need to be passed through. + // Very small fragments are considered Junk by QDecode. + logRateLimit(rf, b, q, Accept, "fragment") + return Accept + } + + return noVerdict +} diff --git a/wgengine/filter/filter_test.go b/wgengine/filter/filter_test.go new file mode 100644 index 000000000..1833e9232 --- /dev/null +++ b/wgengine/filter/filter_test.go @@ -0,0 +1,162 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package filter + +import ( + "encoding/binary" + "encoding/json" + "net" + "testing" + + "tailscale.com/wgengine/packet" +) + +type QDecode = packet.QDecode + +var Junk = packet.Junk +var ICMP = packet.ICMP +var TCP = packet.TCP +var UDP = packet.UDP +var Fragment = packet.Fragment + +func ippr(ip IP, start, end uint16) []IPPortRange { + return []IPPortRange{ + IPPortRange{ip, PortRange{start, end}}, + } +} + +func TestFilter(t *testing.T) { + mm := Matches{ + {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: []IPPortRange{ + IPPortRange{0x01020304, PortRange{22, 22}}, + IPPortRange{0x05060708, PortRange{23, 24}}, + }}, + {SrcIPs: []IP{0x08010101, 0x08020202}, DstPorts: ippr(0x05060708, 27, 28)}, + {SrcIPs: []IP{0x02020202}, DstPorts: ippr(0x08010101, 22, 22)}, + {SrcIPs: []IP{0}, DstPorts: ippr(0x647a6232, 0, 65535)}, + {SrcIPs: []IP{0}, DstPorts: ippr(0, 443, 443)}, + {SrcIPs: []IP{0x99010101, 0x99010102, 0x99030303}, DstPorts: ippr(0x01020304, 999, 999)}, + } + acl := New(mm) + + for _, ent := range []Matches{Matches{mm[0]}, mm} { + b, err := json.Marshal(ent) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + mm2 := Matches{} + if err := json.Unmarshal(b, &mm2); err != nil { + t.Fatalf("unmarshal: %v (%v)", err, string(b)) + } + } + + // check packet filtering based on the table + + type InOut struct { + want Response + p QDecode + } + tests := []InOut{ + // Basic + {Accept, qdecode(TCP, 0x08010101, 0x01020304, 999, 22)}, + {Accept, qdecode(UDP, 0x08010101, 0x01020304, 999, 22)}, + {Accept, qdecode(ICMP, 0x08010101, 0x01020304, 0, 0)}, + {Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 0)}, + {Accept, qdecode(TCP, 0x08010101, 0x01020304, 0, 22)}, + {Drop, qdecode(TCP, 0x08010101, 0x01020304, 0, 21)}, + {Accept, qdecode(TCP, 0x11223344, 0x22334455, 0, 443)}, + {Drop, qdecode(TCP, 0x11223344, 0x22334455, 0, 444)}, + {Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 999)}, + {Accept, qdecode(TCP, 0x11223344, 0x647a6232, 0, 0)}, + + // Stateful UDP. 
+ // Initially empty cache + {Drop, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)}, + // Return packet from previous attempt is allowed + {Accept, qdecode(UDP, 0x66666666, 0x77777777, 4343, 4242)}, + // Because of the return above, initial attempt is allowed now + {Accept, qdecode(UDP, 0x77777777, 0x66666666, 4242, 4343)}, + } + for i, test := range tests { + if got, _ := acl.runIn(&test.p); test.want != got { + t.Errorf("#%d got=%v want=%v packet:%v\n", i, got, test.want, test.p) + } + // Update UDP state + _, _ = acl.runOut(&test.p) + } +} + +func TestPreFilter(t *testing.T) { + packets := []struct { + desc string + want Response + b []byte + }{ + {"empty", Accept, []byte{}}, + {"short", Drop, []byte("short")}, + {"junk", Drop, rawpacket(Junk, 10)}, + {"fragment", Accept, rawpacket(Fragment, 40)}, + {"tcp", noVerdict, rawpacket(TCP, 200)}, + {"udp", noVerdict, rawpacket(UDP, 200)}, + {"icmp", noVerdict, rawpacket(ICMP, 200)}, + } + for _, testPacket := range packets { + got := pre([]byte(testPacket.b), &QDecode{}, LogDrops|LogAccepts) + if got != testPacket.want { + t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b)) + } + } +} + +func qdecode(proto packet.IPProto, src, dst packet.IP, sport, dport uint16) QDecode { + return QDecode{ + IPProto: proto, + SrcIP: src, + DstIP: dst, + SrcPort: sport, + DstPort: dport, + TCPFlags: packet.TCPSyn, + } +} + +func rawpacket(proto packet.IPProto, len uint16) []byte { + bl := len + if len < 24 { + bl = 24 + } + bin := binary.BigEndian + hdr := make([]byte, bl) + hdr[0] = 0x45 + bin.PutUint16(hdr[2:4], len) + hdr[8] = 64 + ip := net.IPv4(8, 8, 8, 8).To4() + copy(hdr[12:16], ip) + copy(hdr[16:20], ip) + // ports + bin.PutUint16(hdr[20:22], 53) + bin.PutUint16(hdr[22:24], 53) + + switch proto { + case ICMP: + hdr[9] = 1 + case TCP: + hdr[9] = 6 + case UDP: + hdr[9] = 17 + case Fragment: + hdr[9] = 6 + // flags + fragOff + bin.PutUint16(hdr[6:8], (1<<13)|1234) + case 
Junk: + default: + panic("unknown protocol") + } + + // Truncate the header if requested + hdr = hdr[:len] + + return hdr +} diff --git a/wgengine/filter/match.go b/wgengine/filter/match.go new file mode 100644 index 000000000..27e0b8f98 --- /dev/null +++ b/wgengine/filter/match.go @@ -0,0 +1,121 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package filter + +import ( + "fmt" + "strings" + "tailscale.com/wgengine/packet" +) + +type IP = packet.IP + +const IPAny = IP(0) + +var NewIP = packet.NewIP + +type PortRange struct { + First, Last uint16 +} + +var PortRangeAny = PortRange{0, 65535} + +func (pr PortRange) String() string { + if pr.First == 0 && pr.Last == 65535 { + return "*" + } else if pr.First == pr.Last { + return fmt.Sprintf("%d", pr.First) + } else { + return fmt.Sprintf("%d-%d", pr.First, pr.Last) + } +} + +type IPPortRange struct { + IP IP + Ports PortRange +} + +var IPPortRangeAny = IPPortRange{IPAny, PortRangeAny} + +func (ipr IPPortRange) String() string { + return fmt.Sprintf("%v:%v", ipr.IP, ipr.Ports) +} + +type Match struct { + DstPorts []IPPortRange + SrcIPs []IP +} + +func (m Match) String() string { + srcs := []string{} + for _, srcip := range m.SrcIPs { + srcs = append(srcs, srcip.String()) + } + dsts := []string{} + for _, dst := range m.DstPorts { + dsts = append(dsts, dst.String()) + } + + var ss, ds string + if len(srcs) == 1 { + ss = srcs[0] + } else { + ss = "[" + strings.Join(srcs, ",") + "]" + } + if len(dsts) == 1 { + ds = dsts[0] + } else { + ds = "[" + strings.Join(dsts, ",") + "]" + } + return fmt.Sprintf("%v=>%v", ss, ds) +} + +type Matches []Match + +func ipInList(ip IP, iplist []IP) bool { + for _, ipp := range iplist { + if ipp == IPAny || ipp == ip { + return true + } + } + return false +} + +func matchIPPorts(mm Matches, q *packet.QDecode) bool { + for _, acl := range mm { + for _, dst := 
range acl.DstPorts { + if dst.IP != IPAny && dst.IP != q.DstIP { + continue + } + if q.DstPort < dst.Ports.First || q.DstPort > dst.Ports.Last { + continue + } + if !ipInList(q.SrcIP, acl.SrcIPs) { + // Skip other dests in this acl, since + // the src will never match. + break + } + return true + } + } + return false +} + +func matchIPWithoutPorts(mm Matches, q *packet.QDecode) bool { + for _, acl := range mm { + for _, dst := range acl.DstPorts { + if dst.IP != IPAny && dst.IP != q.DstIP { + continue + } + if !ipInList(q.SrcIP, acl.SrcIPs) { + // Skip other dests in this acl, since + // the src will never match. + break + } + return true + } + } + return false +} diff --git a/wgengine/ifconfig_windows.go b/wgengine/ifconfig_windows.go new file mode 100644 index 000000000..707a21e38 --- /dev/null +++ b/wgengine/ifconfig_windows.go @@ -0,0 +1,411 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package wgengine + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "log" + "net" + "sort" + "time" + "unsafe" + + "github.com/go-ole/go-ole" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "github.com/tailscale/wireguard-go/wgcfg" + "golang.org/x/sys/windows" + "golang.org/x/sys/windows/registry" + "golang.zx2c4.com/winipcfg" + "tailscale.com/wgengine/winnet" +) + +const ( + sockoptIP_UNICAST_IF = 31 + sockoptIPV6_UNICAST_IF = 31 +) + +func htonl(val uint32) uint32 { + bytes := make([]byte, 4) + binary.BigEndian.PutUint32(bytes, val) + return *(*uint32)(unsafe.Pointer(&bytes[0])) +} + +func bindSocketRoute(family winipcfg.AddressFamily, device *device.Device, ourLuid uint64, lastLuid *uint64) error { + routes, err := winipcfg.GetRoutes(family) + if err != nil { + return err + } + lowestMetric := ^uint32(0) + index := uint32(0) // Zero is "unspecified", which for IP_UNICAST_IF resets the value, which is what we want. 
+ luid := uint64(0) // Hopefully luid zero is unspecified, but hard to find docs saying so. + for _, route := range routes { + if route.DestinationPrefix.PrefixLength != 0 || route.InterfaceLuid == ourLuid { + continue + } + if route.Metric < lowestMetric { + lowestMetric = route.Metric + index = route.InterfaceIndex + luid = route.InterfaceLuid + } + } + if luid == *lastLuid { + return nil + } + *lastLuid = luid + if false { + // TODO(apenwarr): doesn't work with magic socket yet. + if family == winipcfg.AF_INET { + return device.BindSocketToInterface4(index, false) + } else if family == winipcfg.AF_INET6 { + return device.BindSocketToInterface6(index, false) + } + } else { + log.Printf("WARNING: skipping windows socket binding.\n") + } + return nil +} + +// MonitorDefaultRoutes re-evaluates the lowest-metric default route of each +// address family (ignoring routes through our own interface) once immediately +// and again whenever a default route changes. When autoMTU is set it also +// derives the tunnel MTU from the MTU of the interfaces carrying those routes. +func MonitorDefaultRoutes(device *device.Device, autoMTU bool, tun *tun.NativeTun) (*winipcfg.RouteChangeCallback, error) { + guid := tun.GUID() + ourLuid, err := winipcfg.InterfaceGuidToLuid(&guid) + lastLuid4 := uint64(0) + lastLuid6 := uint64(0) + lastMtu := uint32(0) + if err != nil { + return nil, err + } + doIt := func() error { + err = bindSocketRoute(winipcfg.AF_INET, device, ourLuid, &lastLuid4) + if err != nil { + return err + } + err = bindSocketRoute(winipcfg.AF_INET6, device, ourLuid, &lastLuid6) + if err != nil { + return err + } + if !autoMTU { + return nil + } + mtu := uint32(0) + if lastLuid4 != 0 { + iface, err := winipcfg.InterfaceFromLUID(lastLuid4) + if err != nil { + return err + } + if iface.Mtu > 0 { + mtu = iface.Mtu + } + } + if lastLuid6 != 0 { + iface, err := winipcfg.InterfaceFromLUID(lastLuid6) + if err != nil { + return err + } + // Use the v6 MTU when v4 supplied none, otherwise the smaller of the two. + // mtu is unsigned and starts at 0, so a bare "iface.Mtu < mtu" could never + // be true on a host that only has an IPv6 default route. + if iface.Mtu > 0 && (mtu == 0 || iface.Mtu < mtu) { + mtu = iface.Mtu + } + } + if mtu > 0 && lastMtu != mtu { + iface, err := winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET) + if err != nil { + return err + } + iface.NlMtu = mtu - 80 + if iface.NlMtu < 576 { + iface.NlMtu = 576 + } + err = iface.Set() + if err != nil { + return err + } + 
tun.ForceMTU(int(iface.NlMtu)) //TODO: it sort of breaks the model with v6 mtu and v4 mtu being different. Just set v4 one for now. + iface, err = winipcfg.GetIpInterface(ourLuid, winipcfg.AF_INET6) + if err != nil { + return err + } + iface.NlMtu = mtu - 80 + if iface.NlMtu < 1280 { + iface.NlMtu = 1280 + } + err = iface.Set() + if err != nil { + return err + } + lastMtu = mtu + } + return nil + } + err = doIt() + if err != nil { + return nil, err + } + cb, err := winipcfg.RegisterRouteChangeCallback(func(notificationType winipcfg.MibNotificationType, route *winipcfg.Route) { + //fmt.Printf("MonitorDefaultRoutes: changed: %v\n", route.DestinationPrefix) + if route.DestinationPrefix.PrefixLength == 0 { + _ = doIt() + } + }) + if err != nil { + return nil, err + } + return cb, nil +} + +func setDNSDomains(g windows.GUID, dnsDomains []string) { + gs := g.String() + log.Printf("setDNSDomains(%v) guid=%v\n", dnsDomains, gs) + p := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + gs + key, err := registry.OpenKey(registry.LOCAL_MACHINE, p, registry.READ|registry.SET_VALUE) + if err != nil { + log.Printf("setDNSDomains(%v): open: %v\n", p, err) + return + } + defer key.Close() + + // Windows only supports a single per-interface DNS domain. 
+ dom := "" + if len(dnsDomains) > 0 { + dom = dnsDomains[0] + } + err = key.SetStringValue("Domain", dom) + if err != nil { + log.Printf("setDNSDomains(%v): SetStringValue: %v\n", p, err) + } +} + +// setFirewall looks for the network connection whose adapter id equals ifcGUID +// and, if that network's category is 0, sets it to 1 (presumably public -> +// private in the Windows network list manager; confirm against winnet). +// It reports whether a matching connection was found. Every failure is +// returned as an error rather than panicking, because the caller polls this +// from a goroutine and treats errors as non-fatal. +func setFirewall(ifcGUID *windows.GUID) (bool, error) { + c := ole.Connection{} + err := c.Initialize() + if err != nil { + return false, fmt.Errorf("Initialize: %v", err) + } + defer c.Uninitialize() + + m, err := winnet.NewNetworkListManager(&c) + if err != nil { + return false, fmt.Errorf("NewNetworkListManager: %v", err) + } + defer m.Release() + + cl, err := m.GetNetworkConnections() + if err != nil { + return false, fmt.Errorf("GetNetworkConnections: %v", err) + } + defer cl.Release() + + for _, nco := range cl { + aid, err := nco.GetAdapterId() + if err != nil { + return false, fmt.Errorf("GetAdapterId: %v", err) + } + if aid != ifcGUID.String() { + log.Printf("skipping adapter id: %v\n", aid) + continue + } + log.Printf("found! adapter id: %v\n", aid) + + n, err := nco.GetNetwork() + if err != nil { + return false, fmt.Errorf("GetNetwork: %v", err) + } + defer n.Release() + + cat, err := n.GetCategory() + if err != nil { + return false, fmt.Errorf("GetCategory: %v", err) + } + + if cat == 0 { + err = n.SetCategory(1) + if err != nil { + return false, fmt.Errorf("SetCategory: %v", err) + } + } else { + log.Printf("setFirewall: already category %v\n", cat) + } + + return true, nil + } + + return false, nil +} + +func ConfigureInterface(m *wgcfg.Config, tun *tun.NativeTun, dns []net.IP, dnsDomains []string) error { + const mtu = 0 + guid := tun.GUID() + log.Printf("wintun GUID is %v\n", guid) + iface, err := winipcfg.InterfaceFromGUID(&guid) + if err != nil { + return err + } + + go func() { + // It takes a weirdly long time for Windows to notice the + // new interface has come up. Poll periodically until it + // does. + for i := 0; i < 20; i++ { + found, err := setFirewall(&guid) + if err != nil { + log.Printf("setFirewall: %v\n", err) + // fall through anyway, this isn't fatal.
+ } + if found { + break + } + time.Sleep(1 * time.Second) + } + }() + + setDNSDomains(guid, dnsDomains) + + routes := []winipcfg.RouteData{} + var firstGateway4 *net.IP + var firstGateway6 *net.IP + addresses := make([]*net.IPNet, len(m.Interface.Addresses)) + for i, addr := range m.Interface.Addresses { + ipnet := addr.IPNet() + addresses[i] = ipnet + gateway := ipnet.IP + if addr.IP.Is4() && firstGateway4 == nil { + firstGateway4 = &gateway + } else if addr.IP.Is6() && firstGateway6 == nil { + firstGateway6 = &gateway + } + } + + foundDefault4 := false + foundDefault6 := false + for _, peer := range m.Peers { + for _, allowedip := range peer.AllowedIPs { + if (allowedip.IP.Is4() && firstGateway4 == nil) || (allowedip.IP.Is6() && firstGateway6 == nil) { + return errors.New("Due to a Windows limitation, one cannot have interface routes without an interface address") + } + + ipn := allowedip.IPNet() + var gateway net.IP + if allowedip.IP.Is4() { + gateway = *firstGateway4 + } else if allowedip.IP.Is6() { + gateway = *firstGateway6 + } + r := winipcfg.RouteData{ + Destination: net.IPNet{ + IP: ipn.IP.Mask(ipn.Mask), + Mask: ipn.Mask, + }, + NextHop: gateway, + Metric: 0, + } + if bytes.Compare(r.Destination.IP, gateway) == 0 { + // no need to add a route for the interface's + // own IP. The kernel does that for us. + // If we try to replace it, we'll fail to + // add the route unless NextHop is set, but + // then the interface's IP won't be pingable. 
+ continue + } + if allowedip.IP.Is4() { + if allowedip.Mask == 0 { + foundDefault4 = true + } + r.NextHop = *firstGateway4 + } else if allowedip.IP.Is6() { + if allowedip.Mask == 0 { + foundDefault6 = true + } + r.NextHop = *firstGateway6 + } + routes = append(routes, r) + } + } + + err = iface.SetAddresses(addresses) + if err != nil { + return err + } + + sort.Slice(routes, func(i, j int) bool { + return (bytes.Compare(routes[i].Destination.IP, routes[j].Destination.IP) == -1 || + // Narrower masks first + bytes.Compare(routes[i].Destination.Mask, routes[j].Destination.Mask) == 1 || + // No nexthop before non-empty nexthop + bytes.Compare(routes[i].NextHop, routes[j].NextHop) == -1 || + // Lower metrics first + routes[i].Metric < routes[j].Metric) + }) + + deduplicatedRoutes := []*winipcfg.RouteData{} + for i := 0; i < len(routes); i++ { + // There's only one way to get to a given IP+Mask, so delete + // all matches after the first. + if i > 0 && + bytes.Equal(routes[i].Destination.IP, routes[i-1].Destination.IP) && + bytes.Equal(routes[i].Destination.Mask, routes[i-1].Destination.Mask) { + continue + } + deduplicatedRoutes = append(deduplicatedRoutes, &routes[i]) + } + log.Printf("routes: %v\n", routes) + + var errAcc error + err = iface.SetRoutes(deduplicatedRoutes) + if err != nil && errAcc == nil { + log.Printf("setroutes: %v\n", err) + errAcc = err + } + + err = iface.SetDNS(dns) + if err != nil && errAcc == nil { + log.Printf("setdns: %v\n", err) + errAcc = err + } + + ipif, err := iface.GetIpInterface(winipcfg.AF_INET) + if err != nil { + log.Printf("getipif: %v\n", err) + return err + } + log.Printf("foundDefault4: %v\n", foundDefault4) + if foundDefault4 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + if mtu > 0 { + ipif.NlMtu = uint32(mtu) + tun.ForceMTU(int(ipif.NlMtu)) + } + err = ipif.Set() + if err != nil && errAcc == nil { + errAcc = err + } + + ipif, err = iface.GetIpInterface(winipcfg.AF_INET6) + if err != nil { + return err + } + if 
err != nil && errAcc == nil { + errAcc = err + } + if foundDefault6 { + ipif.UseAutomaticMetric = false + ipif.Metric = 0 + } + if mtu > 0 { + ipif.NlMtu = uint32(mtu) + } + ipif.DadTransmits = 0 + ipif.RouterDiscoveryBehavior = winipcfg.RouterDiscoveryDisabled + err = ipif.Set() + if err != nil && errAcc == nil { + errAcc = err + } + + return errAcc +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go new file mode 100644 index 000000000..661fc2427 --- /dev/null +++ b/wgengine/magicsock/magicsock.go @@ -0,0 +1,815 @@ +// Copyright 2019 Tailscale & AUTHORS. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package magicsock implements a socket that can change its communication path while +// in use, actively searching for the best way to communicate. +package magicsock + +import ( + "context" + "encoding/binary" + "fmt" + "log" + "net" + "strings" + "sync" + "syscall" + "time" + + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/derp/derphttp" + "tailscale.com/stun" + "tailscale.com/stunner" +) + +// A Conn routes UDP packets and actively manages a list of its endpoints. +// It implements wireguard/device.Bind. +type Conn struct { + pconn *RebindingUDPConn + pconnPort uint16 + stunServers []string + derpServer string + startEpUpdate chan struct{} // send to trigger endpoint update + epUpdateCancel func() + epFunc func(endpoints []string) + logf func(format string, args ...interface{}) + + // indexedAddrs is a map of every remote ip:port to a priority + // list of endpoint addresses for a peer. + // The priority list is provided by wgengine configuration. 
+ // + // Given a wgcfg describing: + // machineA: 10.0.0.1:1, 10.0.0.2:2 + // machineB: 10.0.0.3:3 + // the indexedAddrs map contains: + // 10.0.0.1:1 -> [10.0.0.1:1, 10.0.0.2:2], index:0 + // 10.0.0.2:2 -> [10.0.0.1:1, 10.0.0.2:2], index:1 + // 10.0.0.3:3 -> [10.0.0.3:3], index:0 + indexedAddrsMu sync.Mutex + indexedAddrs map[udpAddr]indexedAddrSet + + stunReceiveMu sync.Mutex + stunReceive func(p []byte, fromAddr *net.UDPAddr) + + derpMu sync.Mutex + derp *derphttp.Client +} + +// udpAddr is the key in the indexedAddrs map. +// It maps an ip:port onto an indexedAddr. +type udpAddr struct { + ip wgcfg.IP + port uint16 +} + +// indexedAddrSet is an AddrSet (a priority list of ip:ports for a peer and the +// current favored ip:port for communicating with the peer) and an index +// number saying which element of the priority list is this map entry. +type indexedAddrSet struct { + addr *AddrSet + index int // index of map key in addr.Addrs +} + +const DefaultPort = 0 + +const DefaultDERP = "https://derp.tailscale.com/derp" + +var DefaultSTUN = []string{ + "stun.l.google.com:19302", + "stun3.l.google.com:19302", +} + +// Options contains options for Listen. +type Options struct { + // Port is the port to listen on. + // Zero means to pick one automatically. + Port uint16 + + STUN []string + DERP string + + // EndpointsFunc optionally provides a func to be called when + // endpoints change. The called func does not own the slice. + EndpointsFunc func(endpoint []string) +} + +func (o *Options) endpointsFunc() func([]string) { + if o == nil || o.EndpointsFunc == nil { + return func([]string) {} + } + return o.EndpointsFunc +} + +// Listen creates a magic Conn listening on opts.Port. +// As the set of possible endpoints for a Conn changes, the +// callback opts.EndpointsFunc is called. +func Listen(opts Options) (*Conn, error) { + var packetConn net.PacketConn + var err error + if opts.Port == 0 { + // Our choice of port. Start with DefaultPort. 
+ // If unavailable, pick any port. + want := fmt.Sprintf(":%d", DefaultPort) + log.Printf("magicsock: bind: trying %v\n", want) + packetConn, err = net.ListenPacket("udp4", want) + if err != nil { + want = ":0" + log.Printf("magicsock: bind: falling back to %v (%v)\n", want, err) + packetConn, err = net.ListenPacket("udp4", want) + } + } else { + packetConn, err = net.ListenPacket("udp4", fmt.Sprintf(":%d", opts.Port)) + } + if err != nil { + return nil, fmt.Errorf("magicsock.Listen: %v", err) + } + + epUpdateCtx, epUpdateCancel := context.WithCancel(context.Background()) + c := &Conn{ + pconn: new(RebindingUDPConn), + stunServers: append([]string{}, opts.STUN...), + derpServer: opts.DERP, + startEpUpdate: make(chan struct{}, 1), + epUpdateCancel: epUpdateCancel, + epFunc: opts.endpointsFunc(), + logf: log.Printf, + indexedAddrs: make(map[udpAddr]indexedAddrSet), + } + c.pconn.Reset(packetConn.(*net.UDPConn)) + c.startEpUpdate <- struct{}{} // STUN immediately on start + go c.epUpdate(epUpdateCtx) + return c, nil +} + +func (c *Conn) epUpdate(ctx context.Context) { + var lastEndpoints []string + var lastCancel func() + var lastDone chan struct{} + for { + select { + case <-ctx.Done(): + if lastCancel != nil { + lastCancel() + } + return + case <-c.startEpUpdate: + } + + if lastCancel != nil { + lastCancel() + <-lastDone + } + var epCtx context.Context + epCtx, lastCancel = context.WithCancel(ctx) + lastDone = make(chan struct{}) + + go func() { + defer close(lastDone) + endpoints, err := c.determineEndpoints(epCtx) + if err != nil { + c.logf("magicsock.Conn: endpoint update failed: %v", err) + // TODO(crawshaw): are there any conditions under which + // we should trigger a retry based on the error here? 
+ return + } + if stringsEqual(endpoints, lastEndpoints) { + return + } + lastEndpoints = endpoints + c.epFunc(endpoints) + }() + } +} + +func (c *Conn) determineEndpoints(ctx context.Context) ([]string, error) { + var alreadyMu sync.Mutex + already := make(map[string]struct{}) + var eps []string + + addAddr := func(s, reason string) { + log.Printf("magicsock: found local %s (%s)\n", s, reason) + + alreadyMu.Lock() + defer alreadyMu.Unlock() + if _, ok := already[s]; !ok { + already[s] = struct{}{} + eps = append(eps, s) + } + } + + s := &stunner.Stunner{ + Send: c.pconn.WriteTo, + Endpoint: func(s string) { addAddr(s, "stun") }, + Servers: c.stunServers, + Logf: c.logf, + } + + c.stunReceiveMu.Lock() + c.stunReceive = s.Receive + c.stunReceiveMu.Unlock() + + if err := s.Run(ctx); err != nil { + return nil, err + } + + c.stunReceiveMu.Lock() + c.stunReceive = nil + c.stunReceiveMu.Unlock() + + if localAddr := c.pconn.LocalAddr(); localAddr.IP.IsUnspecified() { + localPort := fmt.Sprintf("%d", localAddr.Port) + loopbacks, err := localAddresses(localPort, func(s string) { + addAddr(s, "localAddresses") + }) + if err != nil { + return nil, err + } + if len(eps) == 0 { + // Only include loopback addresses if we have no + // interfaces at all to use as endpoints. This allows + // for localhost testing when you're on a plane and + // offline, for example. + for _, s := range loopbacks { + addAddr(s, "loopback") + } + } + } else { + // Our local endpoint is bound to a particular address. + // Do not offer addresses on other local interfaces. + addAddr(localAddr.String(), "socket") + } + + // Note: the endpoints are intentionally returned in priority order, + // from "farthest but most reliable" to "closest but least + // reliable." Addresses returned from STUN should be globally + // addressable, but might go farther on the network than necessary. + // Local interface addresses might have lower latency, but not be + // globally addressable. 
+ // + // The STUN address(es) are always first so that legacy wireguard + // can use eps[0] as its only known endpoint address (although that's + // obviously non-ideal). + return eps, nil +} + +func stringsEqual(x, y []string) bool { + if len(x) != len(y) { + return false + } + for i := range x { + if x[i] != y[i] { + return false + } + } + return true +} + +func localAddresses(localPort string, addAddr func(s string)) ([]string, error) { + var loopback []string + + // TODO(crawshaw): don't serve interface addresses that we are routing + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, i := range ifaces { + if (i.Flags & net.FlagUp) == 0 { + // Down interfaces don't count + continue + } + ifcIsLoopback := (i.Flags & net.FlagLoopback) != 0 + + addrs, err := i.Addrs() + if err != nil { + return nil, err + } + for _, a := range addrs { + switch v := a.(type) { + case *net.IPNet: + // TODO(crawshaw): IPv6 support. + // Easy to do here, but we need good endpoint ordering logic. + ip := v.IP.To4() + if ip == nil { + continue + } + // TODO(apenwarr): don't special case cgNAT. + // In the general wireguard case, it might + // very well be something we can route to + // directly, because both nodes are + // behind the same CGNAT router. 
+ if cgNAT.Contains(ip) { + continue + } + if linkLocalIPv4.Contains(ip) { + continue + } + ep := net.JoinHostPort(ip.String(), localPort) + if ip.IsLoopback() || ifcIsLoopback { + loopback = append(loopback, ep) + continue + } + addAddr(ep) + } + } + } + return loopback, nil +} + +var cgNAT = func() *net.IPNet { + _, ipNet, err := net.ParseCIDR("100.64.0.0/10") + if err != nil { + panic(err) + } + return ipNet +}() + +var linkLocalIPv4 = func() *net.IPNet { + _, ipNet, err := net.ParseCIDR("169.254.0.0/16") + if err != nil { + panic(err) + } + return ipNet +}() + +func (c *Conn) LocalPort() uint16 { + laddr := c.pconn.LocalAddr() + return uint16(laddr.Port) +} + +func (c *Conn) Send(b []byte, ep device.Endpoint) error { + a := ep.(*AddrSet) + + msgType := binary.LittleEndian.Uint32(b[:4]) + switch msgType { + case device.MessageInitiationType, device.MessageResponseType, device.MessageCookieReplyType: + // Part of the wireguard handshake. + // Send to every potential endpoint we have for a peer. + a.mu.Lock() + roamAddr := a.roamAddr + a.mu.Unlock() + + var err error + var success bool + if roamAddr != nil { + _, err = c.pconn.WriteTo(b, roamAddr) + if err == nil { + success = true + } + } + for i := len(a.addrs) - 1; i >= 0; i-- { + addr := &a.addrs[i] + _, err = c.pconn.WriteTo(b, addr) + if err == nil { + success = true + } + } + + if msgType == device.MessageInitiationType { + // Send initial handshake messages via DERP. + c.derpMu.Lock() + derp := c.derp + c.derpMu.Unlock() + + if derp != nil { + if err := derp.Send(a.publicKey, b); err != nil { + log.Printf("derp send failed: %v", err) + } + } + } + + if success { + return nil + } + } + + // Write to the highest-priority address we have seen so far. 
+ _, err := c.pconn.WriteTo(b, a.dst()) + return err +} + +func (c *Conn) findIndexedAddrSet(addr *net.UDPAddr) (addrSet *AddrSet, index int) { + var epAddr udpAddr + copy(epAddr.ip.Addr[:], addr.IP.To16()) + epAddr.port = uint16(addr.Port) + + c.indexedAddrsMu.Lock() + defer c.indexedAddrsMu.Unlock() + + indAddr := c.indexedAddrs[epAddr] + if indAddr.addr == nil { + return nil, 0 + } + return indAddr.addr, indAddr.index +} + +func (c *Conn) ReceiveIPv4(b []byte) (n int, ep device.Endpoint, addr *net.UDPAddr, err error) { + // Read a packet, and process any STUN packets before returning. + for { + var pAddr net.Addr + n, pAddr, err = c.pconn.ReadFrom(b) + if err != nil { + return n, nil, nil, err + } + addr = pAddr.(*net.UDPAddr) + addr.IP = addr.IP.To4() + + if !stun.Is(b[:n]) { + break + } + c.stunReceiveMu.Lock() + fn := c.stunReceive + c.stunReceiveMu.Unlock() + + if fn != nil { + fn(b, addr) + } + } + + // TODO(crawshaw): remove all the indexed-addr logic + addrSet, _ := c.findIndexedAddrSet(addr) + if addrSet == nil { + // The peer that sent this packet has roamed beyond the + // knowledge provided by the control server. + // If the packet is valid wireguard will call UpdateDst + // on the original endpoint using this addr. + return n, (*singleEndpoint)(addr), addr, nil + } + return n, addrSet, addr, nil +} + +func (c *Conn) ReceiveIPv6(buff []byte) (int, device.Endpoint, *net.UDPAddr, error) { + // TODO(crawshaw): IPv6 support + return 0, nil, nil, syscall.EAFNOSUPPORT +} + +func (c *Conn) SetPrivateKey(privateKey [32]byte) error { + if c.derpServer == "" { + return nil + } + + derp, err := derphttp.NewClient(privateKey, c.derpServer, log.Printf) + if err != nil { + return err + } + go func() { + var b [1 << 16]byte + for { + n, err := derp.Recv(b[:]) + if err != nil { + if err == derphttp.ErrClientClosed { + return + } + log.Printf("%v", err) + time.Sleep(250 * time.Millisecond) + } + + // Trigger re-STUN. 
+ c.startEpUpdate <- struct{}{} + + addr := c.pconn.LocalAddr() + if _, err := c.pconn.WriteToUDP(b[:n], addr); err != nil { + log.Printf("%v", err) + } + } + }() + + c.derpMu.Lock() + if c.derp != nil { + if err := c.derp.Close(); err != nil { + log.Printf("derp.Close: %v", err) + } + } + c.derp = derp + c.derpMu.Unlock() + + return nil +} + +func (c *Conn) SetMark(value uint32) error { return nil } + +func (c *Conn) Close() error { + c.epUpdateCancel() + return c.pconn.Close() +} + +func (c *Conn) LinkChange() { + defer func() { + c.startEpUpdate <- struct{}{} // re-STUN + }() + + if c.pconnPort != 0 { + c.pconn.mu.Lock() + if err := c.pconn.pconn.Close(); err != nil { + log.Printf("magicsock: link change close failed: %v", err) + } + packetConn, err := net.ListenPacket("udp4", fmt.Sprintf(":%d", c.pconnPort)) + if err == nil { + log.Printf("magicsock: link change rebound port: %d", c.pconnPort) + c.pconn.pconn = packetConn.(*net.UDPConn) + c.pconn.mu.Unlock() + return + } + log.Printf("magicsock: link change unable to bind fixed port %d: %v, falling back to random port", c.pconnPort, err) + c.pconn.mu.Unlock() + } + + log.Printf("magicsock: link change, binding new port") + packetConn, err := net.ListenPacket("udp4", ":0") + if err != nil { + log.Printf("magicsock: link change failed to bind new port: %v", err) + return + } + c.pconn.Reset(packetConn.(*net.UDPConn)) +} + +// AddrSet is a set of UDP addresses that implements wireguard/device.Endpoint. +type AddrSet struct { + publicKey [32]byte // peer public key used for DERP communication + addrs []net.UDPAddr // ordered priority list provided by wgengine + + mu sync.Mutex // guards roamAddr and curAddr + roamAddr *net.UDPAddr // peer addr determined from incoming packets + // curAddr is an index into addrs of the highest-priority + // address a valid packet has been received from so far. + // If no valid packet from addrs has been received, curAddr is -1. 
+ curAddr int +} + +var noAddr = &net.UDPAddr{ + IP: net.ParseIP("127.127.127.127"), + Port: 127, +} + +func (a *AddrSet) dst() *net.UDPAddr { + a.mu.Lock() + defer a.mu.Unlock() + + if a.roamAddr != nil { + return a.roamAddr + } + if len(a.addrs) == 0 { + return noAddr + } + i := a.curAddr + if i == -1 { + i = 0 + } + return &a.addrs[i] +} + +func (a *AddrSet) DstToBytes() []byte { + dst := a.dst() + b := append([]byte(nil), dst.IP.To4()...) + if len(b) == 0 { + b = append([]byte(nil), dst.IP...) + } + b = append(b, byte(dst.Port&0xff)) + b = append(b, byte((dst.Port>>8)&0xff)) + return b +} +func (a *AddrSet) DstToString() string { + dst := a.dst() + return dst.String() +} +func (a *AddrSet) DstIP() net.IP { + return a.dst().IP +} +func (a *AddrSet) SrcIP() net.IP { return nil } +func (a *AddrSet) SrcToString() string { return "" } +func (a *AddrSet) ClearSrc() {} + +func (a *AddrSet) UpdateDst(new *net.UDPAddr) error { + a.mu.Lock() + defer a.mu.Unlock() + + if a.roamAddr != nil { + if equalUDPAddr(a.roamAddr, new) { + // Packet from the current roaming address, no logging. + // This is a hot path for established connections. + return nil + } + } else if a.curAddr >= 0 && equalUDPAddr(new, &a.addrs[a.curAddr]) { + // Packet from current-priority address, no logging. + // This is a hot path for established connections. 
+ return nil + } + + index := -1 + for i := range a.addrs { + if equalUDPAddr(new, &a.addrs[i]) { + index = i + break + } + } + + publicKey := wgcfg.Key(a.publicKey) + pk := publicKey.ShortString() + old := "" + if a.curAddr >= 0 { + old = a.addrs[a.curAddr].String() + } + + switch { + case index == -1: + if a.roamAddr == nil { + log.Printf("magicsock: rx %s from roaming address %s, set as new priority", pk, new) + } else { + log.Printf("magicsock: rx %s from roaming address %s, replaces roaming address %s", pk, new, a.roamAddr) + } + a.roamAddr = new + + case a.roamAddr != nil: + log.Printf("magicsock: rx %s from known %s (%d), replacs roaming address %s", pk, new, index, a.roamAddr) + a.roamAddr = nil + a.curAddr = index + + case a.curAddr == -1: + log.Printf("magicsock: rx %s from %s (%d/%d), set as new priority", pk, new, index, len(a.addrs)) + a.curAddr = index + + case index < a.curAddr: + log.Printf("magicsock: rx %s from low-pri %s (%d), keeping current %s (%d)", pk, new, index, old, a.curAddr) + + default: // index > a.curAddr + log.Printf("magicsock: rx %s from %s (%d/%d), replaces old priority %s", pk, new, index, len(a.addrs), old) + a.curAddr = index + } + + return nil +} + +func equalUDPAddr(x, y *net.UDPAddr) bool { + return x.Port == y.Port && x.IP.Equal(y.IP) +} + +func (a *AddrSet) String() string { + a.mu.Lock() + defer a.mu.Unlock() + + buf := new(strings.Builder) + buf.WriteByte('[') + if a.roamAddr != nil { + fmt.Fprintf(buf, "roam:%s:%d", a.roamAddr.IP, a.roamAddr.Port) + } + for i, addr := range a.addrs { + if i > 0 || a.roamAddr != nil { + buf.WriteString(", ") + } + fmt.Fprintf(buf, "%s:%d", addr.IP, addr.Port) + if a.curAddr == i { + buf.WriteByte('*') + } + } + buf.WriteByte(']') + + return buf.String() +} + +func (c *Conn) CreateEndpoint(key [32]byte, s string) (device.Endpoint, error) { + pk := wgcfg.Key(key) + log.Printf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), s) + a := &AddrSet{ + publicKey: key, + curAddr: -1, + } 
+ + if s != "" { + for _, ep := range strings.Split(s, ",") { + addr, err := net.ResolveUDPAddr("udp", ep) + if err != nil { + return nil, err + } + if ip4 := addr.IP.To4(); ip4 != nil { + addr.IP = ip4 + } + a.addrs = append(a.addrs, *addr) + } + } + + c.indexedAddrsMu.Lock() + for i, addr := range a.addrs { + var epAddr udpAddr + copy(epAddr.ip.Addr[:], addr.IP.To16()) + epAddr.port = uint16(addr.Port) + c.indexedAddrs[epAddr] = indexedAddrSet{ + addr: a, + index: i, + } + } + c.indexedAddrsMu.Unlock() + + return a, nil +} + +type singleEndpoint net.UDPAddr + +func (e *singleEndpoint) ClearSrc() {} +func (e *singleEndpoint) DstIP() net.IP { return (*net.UDPAddr)(e).IP } +func (e *singleEndpoint) SrcIP() net.IP { return nil } +func (e *singleEndpoint) SrcToString() string { return "" } +func (e *singleEndpoint) DstToString() string { return (*net.UDPAddr)(e).String() } +func (e *singleEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP.To4() + if out == nil { + out = addr.IP + } + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} +func (e *singleEndpoint) UpdateDst(dst *net.UDPAddr) error { + return fmt.Errorf("magicsock.singleEndpoint(%s).UpdateDst(%s): should never be called", (*net.UDPAddr)(e), dst) +} + +// RebindingUDPConn is a UDP socket that can be re-bound. +// Unix has no notion of re-binding a socket, so we swap it out for a new one. 
+type RebindingUDPConn struct { + mu sync.Mutex + pconn *net.UDPConn +} + +func (c *RebindingUDPConn) Reset(pconn *net.UDPConn) { + c.mu.Lock() + old := c.pconn + c.pconn = pconn + c.mu.Unlock() + + if old != nil { + old.Close() + } +} + +func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + for { + c.mu.Lock() + pconn := c.pconn + c.mu.Unlock() + + n, addr, err := pconn.ReadFrom(b) + if err != nil { + c.mu.Lock() + pconn2 := c.pconn + c.mu.Unlock() + + if pconn != pconn2 { + continue + } + } + return n, addr, err + } +} + +func (c *RebindingUDPConn) LocalAddr() *net.UDPAddr { + c.mu.Lock() + defer c.mu.Unlock() + return c.pconn.LocalAddr().(*net.UDPAddr) +} + +func (c *RebindingUDPConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.pconn.Close() +} + +func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { + for { + c.mu.Lock() + pconn := c.pconn + c.mu.Unlock() + + n, err := pconn.WriteToUDP(b, addr) + if err != nil { + c.mu.Lock() + pconn2 := c.pconn + c.mu.Unlock() + + if pconn != pconn2 { + continue + } + } + return n, err + } +} + +func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + for { + c.mu.Lock() + pconn := c.pconn + c.mu.Unlock() + + n, err := pconn.WriteTo(b, addr) + if err != nil { + c.mu.Lock() + pconn2 := c.pconn + c.mu.Unlock() + + if pconn != pconn2 { + continue + } + } + return n, err + } +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go new file mode 100644 index 000000000..299783fa4 --- /dev/null +++ b/wgengine/magicsock/magicsock_test.go @@ -0,0 +1,73 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package magicsock + +import ( + "fmt" + "net" + "strings" + "testing" + "time" +) + +func TestListen(t *testing.T) { + epCh := make(chan string, 16) + epFunc := func(endpoints []string) { + for _, ep := range endpoints { + epCh <- ep + } + } + + // TODO(crawshaw): break test dependency on the network + // using "gortc.io/stun" (like stunner_test.go). + stunServers := DefaultSTUN + + port := pickPort(t) + conn, err := Listen(Options{ + Port: port, + STUN: stunServers, + EndpointsFunc: epFunc, + }) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + go func() { + var pkt [1 << 16]byte + for { + _, _, _, err := conn.ReceiveIPv4(pkt[:]) + if err != nil { + return + } + } + }() + + timeout := time.After(10 * time.Second) + var endpoints []string + suffix := fmt.Sprintf(":%d", port) +collectEndpoints: + for { + select { + case ep := <-epCh: + endpoints = append(endpoints, ep) + if strings.HasSuffix(ep, suffix) { + break collectEndpoints + } + case <-timeout: + t.Fatalf("timeout with endpoints: %v", endpoints) + } + } +} + +func pickPort(t *testing.T) uint16 { + t.Helper() + conn, err := net.ListenPacket("udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + return uint16(conn.LocalAddr().(*net.UDPAddr).Port) +} diff --git a/wgengine/packet/packet.go b/wgengine/packet/packet.go new file mode 100644 index 000000000..6af7b6dcb --- /dev/null +++ b/wgengine/packet/packet.go @@ -0,0 +1,363 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package packet + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "log" + "net" + "strings" +) + +type IPProto int + +const ( + Junk IPProto = iota + Fragment + ICMP + UDP + TCP +) + +// RFC1858: prevent overlapping fragment attacks. 
+const MIN_FRAG = 60 + 20 // max IPv4 header + basic TCP header + +func (p IPProto) String() string { + switch p { + case Fragment: + return "Frag" + case ICMP: + return "ICMP" + case UDP: + return "UDP" + case TCP: + return "TCP" + default: + return "Junk" + } +} + +type IP uint32 + +const IPAny = IP(0) + +func NewIP(b net.IP) IP { + b4 := b.To4() + if b4 == nil { + panic(fmt.Sprintf("To4(%v) failed", b)) + } + return IP(binary.BigEndian.Uint32(b4)) +} + +func (ip IP) String() string { + if ip == 0 { + return "*" + } + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, uint32(ip)) + return fmt.Sprintf("%d.%d.%d.%d", b[0], b[1], b[2], b[3]) +} + +func (ipp *IP) MarshalJSON() ([]byte, error) { + s := "\"" + (*ipp).String() + "\"" + return []byte(s), nil +} + +func (ipp *IP) UnmarshalJSON(b []byte) error { + var hostp *string + err := json.Unmarshal(b, &hostp) + if err != nil { + return err + } + host := *hostp + ip := net.ParseIP(host) + if ip != nil && ip.IsUnspecified() { + // For clarity, reject 0.0.0.0 as an input + return fmt.Errorf("Ports=%#v: to allow all IP addresses, use *:port, not 0.0.0.0:port", host) + } else if ip == nil && host == "*" { + // User explicitly requested wildcard dst ip + *ipp = IPAny + } else { + if ip != nil { + ip = ip.To4() + } + if ip == nil || len(ip) != 4 { + return fmt.Errorf("Ports=%#v: invalid IPv4 address", host) + } + *ipp = NewIP(ip) + } + return nil +} + +const ( + EchoReply uint8 = 0x00 + EchoRequest uint8 = 0x08 +) + +const ( + TCPSyn uint8 = 0x02 + TCPAck uint8 = 0x10 + TCPSynAck uint8 = TCPSyn | TCPAck +) + +type QDecode struct { + b []byte // Packet buffer that this decodes + subofs int // byte offset of IP subprotocol + + IPProto IPProto // IP subprotocol (UDP, TCP, etc) + SrcIP IP // IP source address + DstIP IP // IP destination address + SrcPort uint16 // TCP/UDP source port + DstPort uint16 // TCP/UDP destination port + TCPFlags uint8 // TCP flags (SYN, ACK, etc) +} + +func (q QDecode) String() string { + if 
q.IPProto == Junk { + return "Junk{}" + } + srcip := make([]byte, 4) + dstip := make([]byte, 4) + binary.BigEndian.PutUint32(srcip, uint32(q.SrcIP)) + binary.BigEndian.PutUint32(dstip, uint32(q.DstIP)) + return fmt.Sprintf("%v{%d.%d.%d.%d:%d > %d.%d.%d.%d:%d}", + q.IPProto, + srcip[0], srcip[1], srcip[2], srcip[3], q.SrcPort, + dstip[0], dstip[1], dstip[2], dstip[3], q.DstPort) +} + +// based on https://tools.ietf.org/html/rfc1071 +func ipChecksum(b []byte) uint16 { + var ac uint32 + i := 0 + n := len(b) + for n >= 2 { + ac += uint32(binary.BigEndian.Uint16(b[i : i+2])) + n -= 2 + i += 2 + } + if n == 1 { + ac += uint32(b[i]) << 8 + } + for (ac >> 16) > 0 { + ac = (ac >> 16) + (ac & 0xffff) + } + return uint16(^ac) +} + +func GenICMP(srcIP, dstIP IP, ipid uint16, icmpType uint8, icmpCode uint8, payload []byte) []byte { + if len(payload) < 4 { + return nil + } + if len(payload) > 65535-24 { + return nil + } + + sz := 24 + len(payload) + out := make([]byte, 24+len(payload)) + out[0] = 0x45 // IPv4, 20-byte header + out[1] = 0x00 // DHCP, ECN + binary.BigEndian.PutUint16(out[2:4], uint16(sz)) + binary.BigEndian.PutUint16(out[4:6], ipid) + binary.BigEndian.PutUint16(out[6:8], 0) // flags, offset + out[8] = 64 // TTL + out[9] = 0x01 // ICMPv4 + // out[10:12] = 0x00 // blank IP header checksum + binary.BigEndian.PutUint32(out[12:16], uint32(srcIP)) + binary.BigEndian.PutUint32(out[16:20], uint32(dstIP)) + + out[20] = icmpType + out[21] = icmpCode + //out[22:24] = 0x00 // blank ICMP checksum + copy(out[24:len(out)], payload) + + binary.BigEndian.PutUint16(out[10:12], ipChecksum(out[0:20])) + binary.BigEndian.PutUint16(out[22:24], ipChecksum(out)) + return out +} + +// An extremely simple packet decoder for basic IPv4 packet types. +// It extracts only the subprotocol id, IP addresses, and (if any) ports, +// and shouldn't need any memory allocation. 
+func (q *QDecode) Decode(b []byte) { + q.b = nil + + if len(b) < 20 { + q.IPProto = Junk + return + } + // Check that it's IPv4. + // TODO(apenwarr): consider IPv6 support + if ((b[0] & 0xF0) >> 4) != 4 { + q.IPProto = Junk + return + } + + n := int(binary.BigEndian.Uint16(b[2:4])) + if len(b) < n { + // Packet was cut off before full IPv4 length. + q.IPProto = Junk + return + } + + // If it's valid IPv4, then the IP addresses are valid + q.SrcIP = IP(binary.BigEndian.Uint32(b[12:16])) + q.DstIP = IP(binary.BigEndian.Uint32(b[16:20])) + + q.subofs = int((b[0] & 0x0F) * 4) + sub := b[q.subofs:] + + // We don't care much about IP fragmentation, except insofar as it's + // used for firewall bypass attacks. The trick is make the first + // fragment of a TCP or UDP packet so short that it doesn't fit + // the TCP or UDP header, so we can't read the port, in hope that + // it'll sneak past. Then subsequent fragments fill it in, but we're + // missing the first part of the header, so we can't read that either. + // + // A "perfectly correct" implementation would have to reassemble + // fragments before deciding what to do. But the truth is there's + // zero reason to send such a short first fragment, so we can treat + // it as Junk. We can also treat any subsequent fragment that starts + // at such a low offset as Junk. + fragFlags := binary.BigEndian.Uint16(b[6:8]) + moreFrags := (fragFlags & 0x20) != 0 + fragOfs := fragFlags & 0x1FFF + if fragOfs == 0 { + // This is the first fragment + if moreFrags && len(sub) < MIN_FRAG { + // Suspiciously short first fragment, dump it. + log.Printf("junk1!\n") + q.IPProto = Junk + return + } + // otherwise, this is either non-fragmented (the usual case) + // or a big enough initial fragment that we can read the + // whole subprotocol header. 
+ proto := b[9] + switch proto { + case 1: // ICMPv4 + if len(sub) < 8 { + q.IPProto = Junk + return + } + q.IPProto = ICMP + q.SrcPort = 0 + q.DstPort = 0 + q.b = b + return + case 6: // TCP + if len(sub) < 20 { + q.IPProto = Junk + return + } + q.IPProto = TCP + q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) + q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.TCPFlags = sub[13] & 0x3F + q.b = b + return + case 17: // UDP + if len(sub) < 8 { + q.IPProto = Junk + return + } + q.IPProto = UDP + q.SrcPort = binary.BigEndian.Uint16(sub[0:2]) + q.DstPort = binary.BigEndian.Uint16(sub[2:4]) + q.b = b + return + default: + q.IPProto = Junk + return + } + } else { + // This is a fragment other than the first one. + if fragOfs < MIN_FRAG { + // First frag was suspiciously short, so we can't + // trust the followup either. + q.IPProto = Junk + return + } + // otherwise, we have to permit the fragment to slide through. + // Second and later fragments don't have sub-headers. + // Ideally, we would drop fragments that we can't identify, + // but that would require statefulness. Anyway, receivers' + // kernels know to drop fragments where the initial fragment + // doesn't arrive. + q.IPProto = Fragment + return + } +} + +// Returns a subset of the IP subprotocol section. +func (q *QDecode) Sub(begin, n int) []byte { + return q.b[q.subofs+begin : q.subofs+begin+n] +} + +// For a packet that is known to be IPv4, trim the buffer to its IPv4 length. +// Sometimes packets arrive from an interface with extra bytes on the end. +// This removes them. +func (q *QDecode) Trim() []byte { + n := binary.BigEndian.Uint16(q.b[2:4]) + return q.b[0:n] +} + +// For a decoded TCP packet, return true if it's a TCP SYN packet (ie. the +// first packet in a new connection). 
+func (q *QDecode) IsTCPSyn() bool { + const Syn = 0x02 + const Ack = 0x10 + const SynAck = Syn | Ack + return (q.TCPFlags & SynAck) == Syn +} + +// For a packet that has already been decoded, check if it's an IPv4 ICMP +// Echo Request. +func (q *QDecode) IsEchoRequest() bool { + if q.IPProto == ICMP && len(q.b) >= q.subofs+8 { + return q.b[q.subofs] == EchoRequest && q.b[q.subofs+1] == 0 + } + return false +} + +func (q *QDecode) EchoRespond() []byte { + icmpid := binary.BigEndian.Uint16(q.Sub(4, 2)) + b := q.Trim() + return GenICMP(q.DstIP, q.SrcIP, icmpid, EchoReply, 0, b[q.subofs+4:]) +} + +func Hexdump(b []byte) string { + out := new(strings.Builder) + for i := 0; i < len(b); i += 16 { + if i > 0 { + fmt.Fprintf(out, "\n") + } + fmt.Fprintf(out, " %04x ", i) + j := 0 + for ; j < 16 && i+j < len(b); j++ { + if j == 8 { + fmt.Fprintf(out, " ") + } + fmt.Fprintf(out, "%02x ", b[i+j]) + } + for ; j < 16; j++ { + if j == 8 { + fmt.Fprintf(out, " ") + } + fmt.Fprintf(out, " ") + } + fmt.Fprintf(out, " ") + for j = 0; j < 16 && i+j < len(b); j++ { + if b[i+j] >= 32 && b[i+j] < 128 { + fmt.Fprintf(out, "%c", b[i+j]) + } else { + fmt.Fprintf(out, ".") + } + } + } + return out.String() +} diff --git a/wgengine/router_darwin.go b/wgengine/router_darwin.go new file mode 100644 index 000000000..c4c09ace4 --- /dev/null +++ b/wgengine/router_darwin.go @@ -0,0 +1,36 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package wgengine + +import ( + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/logger" +) + +type darwinRouter struct { + tunname string +} + +func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router { + r := darwinRouter{ + tunname: tunname, + } + return &r +} + +func (r *darwinRouter) Up() error { + return nil +} + +func (r *darwinRouter) SetRoutes(rs RouteSettings) error { + if SetRoutesFunc != nil { + return SetRoutesFunc(rs) + } + return nil +} + +func (r *darwinRouter) Close() { +} diff --git a/wgengine/router_default.go b/wgengine/router_default.go new file mode 100644 index 000000000..74f993b39 --- /dev/null +++ b/wgengine/router_default.go @@ -0,0 +1,17 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows,!linux,!darwin + +package wgengine + +import ( + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/logger" +) + +func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router { + return NewFakeRouter(logf, tunname, dev, tuntap) +} diff --git a/wgengine/router_fake.go b/wgengine/router_fake.go new file mode 100644 index 000000000..8157e929a --- /dev/null +++ b/wgengine/router_fake.go @@ -0,0 +1,38 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package wgengine + +import ( + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/logger" +) + +type fakeRouter struct { + tunname string + logf logger.Logf +} + +func NewFakeRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router { + r := fakeRouter{ + logf: logf, + tunname: tunname, + } + return &r +} + +func (r *fakeRouter) Up() error { + r.logf("Warning: fakeRouter.Up: not implemented.\n") + return nil +} + +func (r *fakeRouter) SetRoutes(rs RouteSettings) error { + r.logf("Warning: fakeRouter.SetRoutes: not implemented.\n") + return nil +} + +func (r *fakeRouter) Close() { + r.logf("Warning: fakeRouter.Close: not implemented.\n") +} diff --git a/wgengine/router_linux.go b/wgengine/router_linux.go new file mode 100644 index 000000000..edd93d570 --- /dev/null +++ b/wgengine/router_linux.go @@ -0,0 +1,267 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package wgengine + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" + "net" + "os" + "os/exec" + "path/filepath" + "strings" + + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "github.com/tailscale/wireguard-go/wgcfg" + "tailscale.com/atomicfile" + "tailscale.com/logger" + "tailscale.com/wgengine/rtnlmon" +) + +type linuxRouter struct { + logf func(fmt string, args ...interface{}) + tunname string + mon *rtnlmon.Mon + netChanged func() + local wgcfg.CIDR + routes map[wgcfg.CIDR]struct{} +} + +func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router { + mon, err := rtnlmon.New(logf, netChanged) + if err != nil { + log.Fatalf("rtnlmon.New() failed: %v", err) + } + + r := linuxRouter{ + logf: logf, + tunname: tunname, + mon: mon, + netChanged: netChanged, + } + return &r +} + +func cmd(args ...string) *exec.Cmd { + if len(args) == 0 { + log.Fatalf("exec.Cmd(%#v) invalid; need argv[0]\n", args) + } + return exec.Command(args[0], args[1:]...) +} + +func (r *linuxRouter) Up() error { + out, err := cmd("ip", "link", "set", r.tunname, "up").CombinedOutput() + if err != nil { + log.Fatalf("running ip link failed: %v\n%s", err, out) + } + + // TODO(apenwarr): This never cleans up after itself! + out, err = cmd("iptables", + "-A", "FORWARD", + "-i", r.tunname, + "-j", "ACCEPT").CombinedOutput() + if err != nil { + r.logf("iptables forward failed: %v\n%s", err, out) + } + // TODO(apenwarr): hardcoded eth0 interface is obviously not right. 
+ out, err = cmd("iptables", + "-t", "nat", + "-A", "POSTROUTING", + "-o", "eth0", + "-j", "MASQUERADE").CombinedOutput() + if err != nil { + r.logf("iptables nat failed: %v\n%s", err, out) + } + return nil +} + +func (r *linuxRouter) SetRoutes(rs RouteSettings) error { + var errq error + + if rs.LocalAddr != r.local { + if r.local != (wgcfg.CIDR{}) { + addrdel := []string{"ip", "addr", + "del", r.local.String(), + "dev", r.tunname} + out, err := cmd(addrdel...).CombinedOutput() + if err != nil { + r.logf("addr del failed: %v: %v\n%s", addrdel, err, out) + if errq == nil { + errq = err + } + } + } + addradd := []string{"ip", "addr", + "add", rs.LocalAddr.String(), + "dev", r.tunname} + out, err := cmd(addradd...).CombinedOutput() + if err != nil { + r.logf("addr add failed: %v: %v\n%s", addradd, err, out) + if errq == nil { + errq = err + } + } + } + + newRoutes := make(map[wgcfg.CIDR]struct{}) + for _, peer := range rs.Cfg.Peers { + for _, route := range peer.AllowedIPs { + newRoutes[route] = struct{}{} + } + } + for route := range r.routes { + if _, keep := newRoutes[route]; !keep { + net := route.IPNet() + nip := net.IP.Mask(net.Mask) + nstr := fmt.Sprintf("%v/%d", nip, route.Mask) + addrdel := []string{"ip", "route", + "del", nstr, + "via", r.local.IP.String(), + "dev", r.tunname} + out, err := cmd(addrdel...).CombinedOutput() + if err != nil { + r.logf("addr del failed: %v: %v\n%s", addrdel, err, out) + if errq == nil { + errq = err + } + } + } + } + for route := range newRoutes { + if _, exists := r.routes[route]; !exists { + net := route.IPNet() + nip := net.IP.Mask(net.Mask) + nstr := fmt.Sprintf("%v/%d", nip, route.Mask) + addradd := []string{"ip", "route", + "add", nstr, + "via", rs.LocalAddr.IP.String(), + "dev", r.tunname} + out, err := cmd(addradd...).CombinedOutput() + if err != nil { + r.logf("addr add failed: %v: %v\n%s", addradd, err, out) + if errq == nil { + errq = err + } + } + } + } + + r.local = rs.LocalAddr + r.routes = newRoutes + + if false 
{ + if err := r.replaceResolvConf(rs.DNS, rs.DNSDomains); err != nil { + errq = fmt.Errorf("replacing resolv.conf failed: %v", err) + } + } + return errq +} + +func (r *linuxRouter) Close() { + r.mon.Close() + if err := r.restoreResolvConf(); err != nil { + r.logf("failed to restore system resolv.conf: %v", err) + } + // TODO(apenwarr): clean up iptables etc. +} + +const ( + tsConf = "/etc/resolv.tailscale.conf" + backupConf = "/etc/resolv.pre-tailscale-backup.conf" + resolvConf = "/etc/resolv.conf" +) + +func (r *linuxRouter) replaceResolvConf(servers []net.IP, domains []string) error { + if len(servers) == 0 { + return r.restoreResolvConf() + } + + // First write the tsConf file. + buf := new(bytes.Buffer) + fmt.Fprintf(buf, "# resolv.conf(5) file generated by tailscale\n") + fmt.Fprintf(buf, "# DO NOT EDIT THIS FILE BY HAND -- CHANGES WILL BE OVERWRITTEN\n\n") + for _, ns := range servers { + fmt.Fprintf(buf, "nameserver %s\n", ns) + } + if len(domains) > 0 { + fmt.Fprintf(buf, "search "+strings.Join(domains, " ")+"\n") + } + f, err := ioutil.TempFile(filepath.Dir(tsConf), filepath.Base(tsConf)+".*") + if err != nil { + return err + } + f.Close() + if err := atomicfile.WriteFile(f.Name(), buf.Bytes(), 0644); err != nil { + return err + } + os.Chmod(f.Name(), 0644) // ioutil.TempFile creates the file with 0600 + if err := os.Rename(f.Name(), tsConf); err != nil { + return err + } + + if linkPath, err := os.Readlink(resolvConf); err != nil { + // Remove any old backup that may exist. + os.Remove(backupConf) + + // Backup the existing /etc/resolv.conf file. + contents, err := ioutil.ReadFile(resolvConf) + if os.IsNotExist(err) { + // No existing /etc/resolve.conf file to backup. + // Nothing to do. + return nil + } else if err != nil { + return err + } + if err := atomicfile.WriteFile(backupConf, contents, 0644); err != nil { + return err + } + } else if linkPath != tsConf { + // Backup the existing symlink. 
+ os.Remove(backupConf) + if err := os.Symlink(linkPath, backupConf); err != nil { + return err + } + } else { + // Nothing to do, resolvConf already points to tsConf. + return nil + } + + os.Remove(resolvConf) + if err := os.Symlink(tsConf, resolvConf); err != nil { + return nil + } + + out, _ := exec.Command("service", "systemd-resolved", "restart").CombinedOutput() + if len(out) > 0 { + r.logf("service systemd-resolved restart: %s", out) + } + return nil +} + +func (r *linuxRouter) restoreResolvConf() error { + if _, err := os.Stat(backupConf); err != nil { + if os.IsNotExist(err) { + return nil // no backup resolve.conf to restore + } + return err + } + if ln, err := os.Readlink(resolvConf); err != nil { + return err + } else if ln != tsConf { + return fmt.Errorf("resolve.conf is not a symlink to %s", tsConf) + } + if err := os.Rename(backupConf, resolvConf); err != nil { + return err + } + os.Remove(tsConf) // best effort removal of tsConf file + out, _ := exec.Command("service", "systemd-resolved", "restart").CombinedOutput() + if len(out) > 0 { + r.logf("service systemd-resolved restart: %s", out) + } + return nil +} diff --git a/wgengine/router_windows.go b/wgengine/router_windows.go new file mode 100644 index 000000000..c05b81b3e --- /dev/null +++ b/wgengine/router_windows.go @@ -0,0 +1,58 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+
+package wgengine
+
+import (
+	"fmt"
+
+	"github.com/tailscale/wireguard-go/device"
+	"github.com/tailscale/wireguard-go/tun"
+	"golang.zx2c4.com/winipcfg"
+	"tailscale.com/logger"
+)
+
+// winRouter is the Router implementation used on Windows.
+type winRouter struct {
+	logf                func(fmt string, args ...interface{})
+	tunname             string
+	dev                 *device.Device
+	nativeTun           *tun.NativeTun
+	routeChangeCallback *winipcfg.RouteChangeCallback
+}
+
+// NewUserspaceRouter returns a Router for the given tun device.
+// tuntap must be a *tun.NativeTun; the type assertion panics otherwise.
+// netChanged is not used by this implementation.
+func NewUserspaceRouter(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netChanged func()) Router {
+	r := winRouter{
+		logf:      logf,
+		tunname:   tunname,
+		dev:       dev,
+		nativeTun: tuntap.(*tun.NativeTun),
+	}
+	return &r
+}
+
+// Up starts default-route monitoring. It returns an error, rather than
+// exiting the process, when monitoring cannot be started, so the caller
+// can clean up.
+func (r *winRouter) Up() error {
+	// MonitorDefaultRoutes handles making sure our wireguard UDP
+	// traffic goes through the old route, not recursively through the VPN.
+	cb, err := MonitorDefaultRoutes(r.dev, true, r.nativeTun)
+	if err != nil {
+		return fmt.Errorf("MonitorDefaultRoutes: %v", err)
+	}
+	r.routeChangeCallback = cb
+	return nil
+}
+
+// SetRoutes applies rs to the tunnel interface via ConfigureInterface.
+// A failure is logged and returned.
+func (r *winRouter) SetRoutes(rs RouteSettings) error {
+	err := ConfigureInterface(&rs.Cfg, r.nativeTun, rs.DNS, rs.DNSDomains)
+	if err != nil {
+		r.logf("ConfigureInterface: %v\n", err)
+		return err
+	}
+	return nil
+}
+
+// Close unregisters the route-change callback, if Up registered one.
+func (r *winRouter) Close() {
+	if r.routeChangeCallback != nil {
+		r.routeChangeCallback.Unregister()
+	}
+}
diff --git a/wgengine/rtnlmon/mon.go b/wgengine/rtnlmon/mon.go
new file mode 100644
index 000000000..080ed2362
--- /dev/null
+++ b/wgengine/rtnlmon/mon.go
@@ -0,0 +1,114 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package rtnlmon watches for "interesting" changes to the network
+// stack and fires a callback.
+package rtnlmon
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/mdlayher/netlink"
+	"golang.org/x/sys/unix"
+	"tailscale.com/logger"
+)
+
+// Netlink is not a great protocol for *knowing* things.
The protocol
+// design makes it impossible to track changes precisely. You can see
+// this by looking at things like Quagga or Bird, which all include
+// keeping a local impression of what they think is in the kernel, and
+// periodically doing a full state dump to find errors. They do use
+// events, but explicitly only as an optimization, because they can't
+// be trusted.
+//
+// Fortunately, we don't really need to know what exactly changed. We
+// just want to know that network conditions may have changed, and we
+// should re-explore connectivity. This is why we subscribe to events,
+// and then blindly fire our callback without looking at the content
+// of the notifications.
+
+// ChangeFunc is the callback invoked when a change may have happened.
+type ChangeFunc func()
+
+// Mon listens on a netlink socket and calls cb after messages arrive.
+type Mon struct {
+	logf   logger.Logf
+	cb     ChangeFunc
+	nl     *netlink.Conn
+	change chan struct{} // buffered (1): coalesces pending notifications
+	stop   chan struct{} // closed by Close
+}
+
+// New dials a NETLINK_ROUTE socket subscribed to IPv4 address and route
+// groups and starts the pump and debounce goroutines. callback is run
+// on the debounce goroutine.
+func New(logf logger.Logf, callback ChangeFunc) (*Mon, error) {
+	conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{
+		// IPv4 address and route changes. Routes get us most of the
+		// events of interest, but we need address as well to cover
+		// things like DHCP deciding to give us a new address upon
+		// renewal - routing wouldn't change, but all reachability
+		// would.
+		//
+		// Why magic numbers? These aren't exposed in x/sys/unix
+		// yet. The values come from rtnetlink.h, RTMGRP_IPV4_IFADDR
+		// and RTMGRP_IPV4_ROUTE.
+		Groups: 0x10 | 0x40,
+	})
+	if err != nil {
+		return nil, fmt.Errorf("dialing netlink socket: %v", err)
+	}
+
+	ret := &Mon{
+		logf:   logf,
+		cb:     callback,
+		nl:     conn,
+		change: make(chan struct{}, 1),
+		stop:   make(chan struct{}),
+	}
+	go ret.pump()
+	go ret.debounce()
+	return ret, nil
+}
+
+// Close signals both goroutines to stop and closes the netlink socket.
+// It closes the stop channel, so a second call would panic.
+func (m *Mon) Close() error {
+	close(m.stop)
+	return m.nl.Close()
+}
+
+// pump receives netlink messages and, for each one, posts a
+// non-blocking notification on m.change. Message contents are ignored.
+func (m *Mon) pump() {
+	for {
+		_, err := m.nl.Receive()
+		if err != nil {
+			// A receive error after Close is the normal exit path.
+			select {
+			case <-m.stop:
+				return
+			default:
+			}
+			// Keep retrying while we're not closed.
+			m.logf("Error receiving from netlink: %v", err)
+			time.Sleep(time.Second)
+			continue
+		}
+
+		// Drop the notification if one is already pending.
+		select {
+		case m.change <- struct{}{}:
+		default:
+		}
+	}
+}
+
+// debounce runs the callback for each pending notification, then waits
+// 100ms before looking for the next one. It returns once stop is closed.
+func (m *Mon) debounce() {
+	for {
+		select {
+		case <-m.stop:
+			return
+		case <-m.change:
+		}
+
+		m.cb()
+
+		select {
+		case <-m.stop:
+			return
+		case <-time.After(100 * time.Millisecond):
+		}
+	}
+}
diff --git a/wgengine/rusage.go b/wgengine/rusage.go
new file mode 100644
index 000000000..35756b46c
--- /dev/null
+++ b/wgengine/rusage.go
@@ -0,0 +1,21 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wgengine
+
+import (
+	"fmt"
+	"runtime"
+)
+
+// RusagePrefixLog wraps logf so that every message is prefixed with
+// "<goMem>M/<maxRSS>M ", where goMem is HeapInuse+StackInuse in MiB and
+// maxRSS is the value returned by rusageMaxRSS. Memory stats are read
+// on every call.
+func RusagePrefixLog(logf func(f string, argv ...interface{})) func(f string, argv ...interface{}) {
+	return func(f string, argv ...interface{}) {
+		var m runtime.MemStats
+		runtime.ReadMemStats(&m)
+		goMem := float64(m.HeapInuse+m.StackInuse) / (1 << 20)
+		maxRSS := rusageMaxRSS()
+		// The prefix is prepended to the format string itself.
+		pf := fmt.Sprintf("%.1fM/%.1fM %s", goMem, maxRSS, f)
+		logf(pf, argv...)
+	}
+}
diff --git a/wgengine/rusage_nowindows.go b/wgengine/rusage_nowindows.go
new file mode 100644
index 000000000..f6dd89b4b
--- /dev/null
+++ b/wgengine/rusage_nowindows.go
@@ -0,0 +1,29 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build !windows
+
+package wgengine
+
+import (
+	"runtime"
+	"syscall"
+)
+
+// rusageMaxRSS returns this process's peak resident set size in MiB,
+// or 0 if getrusage fails.
+func rusageMaxRSS() float64 {
+	var ru syscall.Rusage
+	err := syscall.Getrusage(syscall.RUSAGE_SELF, &ru)
+	if err != nil {
+		return 0
+	}
+
+	rss := float64(ru.Maxrss)
+	if runtime.GOOS == "darwin" {
+		rss /= 1 << 20 // ru_maxrss is bytes on darwin
+	} else {
+		// ru_maxrss is kilobytes elsewhere (linux, openbsd, etc)
+		rss /= 1024
+	}
+	return rss
+}
diff --git a/wgengine/rusage_windows.go b/wgengine/rusage_windows.go
new file mode 100644
index 000000000..f4bf15119
--- /dev/null
+++ b/wgengine/rusage_windows.go
@@ -0,0 +1,10 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wgengine
+
+// rusageMaxRSS is a placeholder on Windows and always returns 0.
+func rusageMaxRSS() float64 {
+	// TODO(apenwarr): Substitute Windows equivalent of Getrusage() here.
+	return 0
+}
diff --git a/wgengine/userspace.go b/wgengine/userspace.go
new file mode 100644
index 000000000..3057d3f0a
--- /dev/null
+++ b/wgengine/userspace.go
@@ -0,0 +1,477 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wgengine
+
+import (
+	"bufio"
+	"fmt"
+	"log"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/tailscale/wireguard-go/device"
+	"github.com/tailscale/wireguard-go/tun"
+	"github.com/tailscale/wireguard-go/wgcfg"
+	"tailscale.com/logger"
+	"tailscale.com/tailcfg"
+	"tailscale.com/wgengine/filter"
+	"tailscale.com/wgengine/magicsock"
+	"tailscale.com/wgengine/packet"
+)
+
+// userspaceEngine is an Engine backed by an in-process wireguard-go
+// device, a tun device, a Router and a magicsock connection.
+type userspaceEngine struct {
+	logf           logger.Logf
+	statusCallback StatusCallback
+	reqCh          chan struct{}
+	waitCh         chan struct{}
+	tuntap         tun.Device
+	wgdev          *device.Device
+	router         Router
+	magicConn      *magicsock.Conn
+
+	wgLock       sync.Mutex // serializes all wgdev operations
+	lastReconfig string
+	lastRoutes   string
+
+	mu           sync.Mutex
+	peerSequence []wgcfg.Key
+	endpoints    []string
+}
+
+// Loggify adapts a logger.Logf to an io.Writer.
+type Loggify struct {
+	f logger.Logf
+}
+
+// Write passes b to the wrapped log function as the format string and
+// reports all of b as written.
+func (l *Loggify) Write(b []byte) (int, error) {
+	l.f(string(b))
+	return len(b), nil
+}
+
+// NewFakeUserspaceEngine returns an engine built on a fake tun device
+// and a fake router.
+func NewFakeUserspaceEngine(logf logger.Logf, listenPort uint16, derp bool) (Engine, error) {
+	logf("Starting userspace wireguard engine (FAKE tuntap device).")
+	tun := NewFakeTun()
+	return NewUserspaceEngineAdvanced(logf, tun, NewFakeRouter, listenPort, derp)
+}
+
+// NewUserspaceEngine creates the tun device named tunname and returns
+// an engine using it with the platform router. tunname must not be
+// empty.
+func NewUserspaceEngine(logf logger.Logf, tunname string, listenPort uint16, derp bool) (Engine, error) {
+	// Validate before announcing anything as "enabled".
+	if tunname == "" {
+		return nil, fmt.Errorf("--tun name must not be blank.")
+	}
+
+	logf("Starting userspace wireguard engine.")
+	logf("external packet routing via --tun=%s enabled", tunname)
+
+	tuntap, err := tun.CreateTUN(tunname, device.DefaultMTU)
+	if err != nil {
+		log.Printf("CreateTUN: %v\n", err)
+		return nil, err
+	}
+	log.Printf("CreateTUN ok.\n")
+
+	e, err := NewUserspaceEngineAdvanced(logf, tuntap, NewUserspaceRouter, listenPort, derp)
+	if err != nil {
+		log.Printf("NewUserspaceEngineAdv: %v\n", err)
+		return nil, err
+	}
+	return e, err
+}
+
+// RouterGen is a constructor for the Router an engine should use.
+type RouterGen func(logf logger.Logf, tunname string, dev *device.Device, tuntap tun.Device, netStateChanged func()) Router
+
+// NewUserspaceEngineAdvanced builds an engine on the given tun device
+// and router constructor. It opens the magicsock listener on
+// listenPort (with DERP when derp is set), creates the wireguard
+// device, brings the router up and installs empty routes.
+func NewUserspaceEngineAdvanced(logf logger.Logf, tuntap tun.Device, routerGen RouterGen, listenPort uint16, derp bool) (Engine, error) {
+	e := &userspaceEngine{
+		logf:   logf,
+		reqCh:  make(chan struct{}, 1),
+		waitCh: make(chan struct{}),
+		tuntap: tuntap,
+	}
+
+	tunname, err := tuntap.Name()
+	if err != nil {
+		return nil, err
+	}
+
+	endpointsFn := func(endpoints []string) {
+		e.mu.Lock()
+		if e.endpoints != nil {
+			e.endpoints = e.endpoints[:0]
+		}
+		e.endpoints = append(e.endpoints, endpoints...)
+		e.mu.Unlock()
+
+		e.RequestStatus()
+	}
+	magicsockOpts := magicsock.Options{
+		Port: listenPort,
+		STUN: magicsock.DefaultSTUN,
+		// TODO(crawshaw): DERP: magicsock.DefaultDERP,
+		EndpointsFunc: endpointsFn,
+	}
+	if derp {
+		magicsockOpts.DERP = magicsock.DefaultDERP
+	}
+	e.magicConn, err = magicsock.Listen(magicsockOpts)
+	if err != nil {
+		return nil, fmt.Errorf("wgengine: %v", err)
+	}
+
+	// flags==0 because logf is already nested in another logger.
+	// The outer one can display the preferred log prefixes, etc.
+	dlog := log.New(&Loggify{logf}, "", 0)
+	logger := device.Logger{
+		Debug: dlog,
+		Info:  dlog,
+		Error: dlog,
+	}
+	nofilter := func(b []byte) device.FilterResult {
+		// for safety, default to dropping all packets
+		logf("Warning: you forgot to use wgengine.SetFilterInOut()! Packet dropped.\n")
+		return device.FilterDrop
+	}
+
+	opts := &device.DeviceOptions{
+		Logger:    &logger,
+		FilterIn:  nofilter,
+		FilterOut: nofilter,
+		HandshakeDone: func() {
+			// Send an unsolicited status event every time a
+			// handshake completes. This makes sure our UI can
+			// update quickly as soon as it connects to a peer.
+			//
+			// We use a goroutine here to avoid deadlocking
+			// wireguard, since RequestStatus() will call back
+			// into it, and wireguard is what called us to get
+			// here.
+			go e.RequestStatus()
+		},
+		CreateBind: func(uint16) (device.Bind, uint16, error) {
+			return e.magicConn, e.magicConn.LocalPort(), nil
+		},
+		CreateEndpoint: e.magicConn.CreateEndpoint,
+		SkipBindUpdate: true,
+	}
+
+	// magicsock may already be calling endpointsFn, and through it
+	// getStatus, which reads e.wgdev under wgLock. Publish the device
+	// under the same lock, without holding it during NewDevice.
+	dev := device.NewDevice(e.tuntap, opts)
+	e.wgLock.Lock()
+	e.wgdev = dev
+	e.wgLock.Unlock()
+
+	go func() {
+		up := false
+		for event := range e.tuntap.Events() {
+			if event&tun.EventMTUUpdate != 0 {
+				mtu, err := e.tuntap.MTU()
+				e.logf("external route MTU: %d (%v)", mtu, err)
+			}
+			if event&tun.EventUp != 0 && !up {
+				e.logf("external route: up")
+				e.RequestStatus()
+				up = true
+			}
+			if event&tun.EventDown != 0 && up {
+				e.logf("external route: down")
+				e.RequestStatus()
+				up = false
+			}
+		}
+	}()
+
+	e.router = routerGen(logf, tunname, dev, e.tuntap, func() { e.LinkChange(false) })
+	dev.Up()
+	if err := e.router.Up(); err != nil {
+		dev.Close()
+		return nil, err
+	}
+	if err := e.router.SetRoutes(RouteSettings{}); err != nil {
+		dev.Close()
+		return nil, err
+	}
+
+	return e, nil
+}
+
+// TODO(apenwarr): dnsDomains really ought to be in wgcfg.Config.
+// However, we don't actually ever provide it to wireguard and it's not in
+// the traditional wireguard config format. On the other hand, wireguard
+// itself doesn't use the traditional 'dns =' setting either.
+func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error {
+	e.logf("Reconfig(): configuring userspace wireguard engine.\n")
+	e.wgLock.Lock()
+	defer e.wgLock.Unlock()
+
+	e.peerSequence = make([]wgcfg.Key, len(cfg.Peers))
+	for i, p := range cfg.Peers {
+		e.peerSequence[i] = p.PublicKey
+	}
+
+	// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
+	uapi, err := cfg.ToUAPI()
+	if err != nil {
+		return err
+	}
+
+	rc := uapi + "\x00" + strings.Join(dnsDomains, "\x00")
+	if rc == e.lastReconfig {
+		e.logf("...unchanged config, skipping.\n")
+		return nil
+	}
+
+	r := bufio.NewReader(strings.NewReader(uapi))
+	if err = e.wgdev.IpcSetOperation(r); err != nil {
+		e.logf("IpcSetOperation: %v\n", err)
+		return err
+	}
+	// Remember the config only once the device accepted it, so that a
+	// failed attempt is retried instead of skipped as "unchanged".
+	e.lastReconfig = rc
+
+	if err := e.magicConn.SetPrivateKey(cfg.Interface.PrivateKey); err != nil {
+		e.logf("magicsock: %v\n", err)
+	}
+
+	// TODO(apenwarr): only handling the first local address.
+	//   Currently we never use more than one anyway.
+	var cidr wgcfg.CIDR
+	if len(cfg.Interface.Addresses) > 0 {
+		cidr = cfg.Interface.Addresses[0]
+		// TODO(apenwarr): this shouldn't be hardcoded in the client
+		cidr.Mask = 10 // route the whole cgnat range
+	}
+
+	rs := RouteSettings{
+		LocalAddr:  cidr,
+		Cfg:        *cfg,
+		DNS:        cfg.Interface.Dns,
+		DNSDomains: dnsDomains,
+	}
+	e.logf("Reconfiguring router. la=%v dns=%v dom=%v\n",
+		rs.LocalAddr, rs.DNS, rs.DNSDomains)
+
+	// TODO(apenwarr): all the parts of RouteSettings should be "relevant."
+	//   We're checking only the "relevant" parts to see if they have
+	//   changed, and if not, skipping SetRoutes(). But if SetRoutes()
+	//   is getting the non-relevant parts of Cfg, it might act on them,
+	//   and this optimization is unsafe. Probably we should not pass
+	//   a whole Cfg object as part of RouteSettings; instead, trim it to
+	//   just what's absolutely needed (the set of actual routes).
+	rss := rs.OnlyRelevantParts()
+	e.logf("New routes: %v\n", rss)
+	if rss == e.lastRoutes {
+		e.logf("...unchanged routes, skipping.\n")
+		return nil
+	}
+	err = e.router.SetRoutes(rs)
+	if err == nil {
+		// Only routes that were actually installed count as current.
+		e.lastRoutes = rss
+	}
+	e.logf("Reconfig() done.\n")
+	return err
+}
+
+// SetFilter installs filt as the inbound and outbound packet filter.
+// A nil filt installs nil filter functions. On a fake tun device,
+// accepted echo requests are answered locally and then dropped.
+func (e *userspaceEngine) SetFilter(filt *filter.Filter) {
+	var filtin, filtout func(b []byte) device.FilterResult
+	if filt == nil {
+		e.logf("wgengine: nil filter provided; no access restrictions.\n")
+	} else {
+		ft, ft_ok := e.tuntap.(*fakeTun)
+		filtin = func(b []byte) device.FilterResult {
+			runf := filter.LogDrops
+			//runf |= filter.HexdumpDrops
+			runf |= filter.LogAccepts
+			//runf |= filter.HexdumpAccepts
+			q := &packet.QDecode{}
+			if filt.RunIn(b, q, runf) == filter.Accept {
+				// Only in fake mode, answer any incoming pings
+				if ft_ok && q.IsEchoRequest() {
+					pb := q.EchoRespond()
+					ft.InsertRead(pb)
+					// We already handled it, stop.
+					return device.FilterDrop
+				}
+				return device.FilterAccept
+			}
+			return device.FilterDrop
+		}
+
+		filtout = func(b []byte) device.FilterResult {
+			runf := filter.LogDrops
+			//runf |= filter.HexdumpDrops
+			runf |= filter.LogAccepts
+			//runf |= filter.HexdumpAccepts
+			q := &packet.QDecode{}
+			if filt.RunOut(b, q, runf) == filter.Accept {
+				return device.FilterAccept
+			}
+			return device.FilterDrop
+		}
+	}
+
+	e.wgLock.Lock()
+	defer e.wgLock.Unlock()
+
+	e.wgdev.SetFilterInOut(filtin, filtout)
+}
+
+// SetStatusCallback sets the function RequestStatus reports to.
+func (e *userspaceEngine) SetStatusCallback(cb StatusCallback) {
+	e.statusCallback = cb
+}
+
+// getStatus queries the wireguard device and returns one PeerStatus per
+// key in e.peerSequence, in that order, plus the current endpoints.
+// It returns (nil, nil) while the engine is still initializing, and an
+// error when the device query fails or its output cannot be parsed.
+func (e *userspaceEngine) getStatus() (*Status, error) {
+	e.wgLock.Lock()
+	defer e.wgLock.Unlock()
+
+	if e.wgdev == nil {
+		// RequestStatus was invoked before the wgengine has
+		// finished initializing. This can happen when wgengine
+		// provides a callback to magicsock for endpoint
+		// updates that calls RequestStatus.
+		return nil, nil
+	}
+
+	// TODO(apenwarr): get rid of silly uapi stuff for in-process comms
+	// FIXME: get notified of status changes instead of polling.
+	var bb strings.Builder
+	bio := bufio.NewWriter(&bb)
+	if ipcErr := e.wgdev.IpcGetOperation(bio); ipcErr != nil {
+		return nil, fmt.Errorf("IpcGetOperation: %v", ipcErr)
+	}
+	bio.Flush()
+
+	pp := make(map[wgcfg.Key]*PeerStatus)
+	// Values seen before the first public_key land in this throwaway.
+	p := &PeerStatus{}
+	var hst1, hst2 int64
+	for _, line := range strings.Split(bb.String(), "\n") {
+		kv := strings.SplitN(line, "=", 2)
+		var k, v string
+		k = kv[0]
+		if len(kv) > 1 {
+			v = kv[1]
+		}
+		switch k {
+		case "public_key":
+			pk, err := wgcfg.ParseHexKey(v)
+			if err != nil {
+				return nil, fmt.Errorf("IpcGetOperation: invalid key %#v", v)
+			}
+			p = &PeerStatus{}
+			pp[*pk] = p
+
+			key := tailcfg.NodeKey(*pk)
+			p.NodeKey = key
+		case "rx_bytes":
+			n, err := strconv.ParseInt(v, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("IpcGetOperation: rx_bytes invalid: %#v", line)
+			}
+			p.RxBytes = ByteCount(n)
+		case "tx_bytes":
+			n, err := strconv.ParseInt(v, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("IpcGetOperation: tx_bytes invalid: %#v", line)
+			}
+			p.TxBytes = ByteCount(n)
+		case "last_handshake_time_sec":
+			n, err := strconv.ParseInt(v, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("IpcGetOperation: hst1 invalid: %#v", line)
+			}
+			hst1 = n
+		case "last_handshake_time_nsec":
+			n, err := strconv.ParseInt(v, 10, 64)
+			if err != nil {
+				return nil, fmt.Errorf("IpcGetOperation: hst2 invalid: %#v", line)
+			}
+			hst2 = n
+			if hst1 != 0 || hst2 != 0 {
+				p.LastHandshake = time.Unix(hst1, hst2)
+			} // else leave at time.IsZero()
+		}
+	}
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	var peers []PeerStatus
+	for _, pk := range e.peerSequence {
+		p := pp[pk]
+		if p == nil {
+			p = &PeerStatus{}
+		}
+		peers = append(peers, *p)
+	}
+
+	if len(pp) != len(e.peerSequence) {
+		// Report the number of peers wireguard actually returned.
+		e.logf("wg status returned %v peers, expected %v\n", len(pp), len(e.peerSequence))
+	}
+
+	return &Status{
+		LocalAddrs: append([]string(nil), e.endpoints...),
+		Peers:      peers,
+	}, nil
+}
+
+func (e *userspaceEngine) RequestStatus() {
+	// This is slightly
tricky. e.getStatus() can theoretically get
+	// blocked inside wireguard for a while, and RequestStatus() is
+	// sometimes called from a goroutine, so we don't want a lot of
+	// them hanging around. On the other hand, requesting multiple
+	// status updates simultaneously is pointless anyway; they will
+	// all say the same thing.
+
+	// Enqueue at most one request. If one is in progress already, this
+	// adds one more to the queue. If one has been requested but not
+	// started, it is a no-op.
+	select {
+	case e.reqCh <- struct{}{}:
+	default:
+	}
+
+	// Dequeue at most one request. Another thread may have already
+	// dequeued the request we enqueued above, which is fine, since the
+	// information is guaranteed to be at least as recent as the current
+	// call to RequestStatus().
+	select {
+	case <-e.reqCh:
+		s, err := e.getStatus()
+		if s == nil && err == nil {
+			e.logf("RequestStatus: weird: both s and err are nil\n")
+			return
+		}
+		if e.statusCallback != nil {
+			e.statusCallback(s, err)
+		}
+	default:
+	}
+}
+
+// Close reconfigures wireguard with an empty config, closes the router
+// and the magicsock connection, and unblocks Wait.
+func (e *userspaceEngine) Close() {
+	e.Reconfig(&wgcfg.Config{}, nil)
+	e.router.Close()
+	e.magicConn.Close()
+	close(e.waitCh)
+}
+
+// Wait blocks until Close has been called.
+func (e *userspaceEngine) Wait() {
+	<-e.waitCh
+}
+
+// LinkChange rebinds magicsock and re-applies the last accepted
+// wireguard configuration, if there is one.
+func (e *userspaceEngine) LinkChange(isExpensive bool) {
+	e.logf("LinkChange(isExpensive=%v): rebinding socket", isExpensive)
+	e.wgLock.Lock()
+	defer e.wgLock.Unlock()
+
+	// TODO(crawshaw): use isExpensive=true to switch into "client mode" on macOS?
+	e.magicConn.LinkChange()
+
+	// TODO(crawshaw): when we have an incremental notion of reconfig,
+	// be gentler here. No need to smash in-progress connections,
+	// we just need to handshake again.
+	if e.lastReconfig == "" {
+		return
+	}
+	// lastReconfig is "<uapi>\x00<domains...>"; keep only the uapi part.
+	uapi := e.lastReconfig[:strings.Index(e.lastReconfig, "\x00")]
+	r := bufio.NewReader(strings.NewReader(uapi))
+	if err := e.wgdev.IpcSetOperation(r); err != nil {
+		e.logf("IpcSetOperation: %v\n", err)
+	}
+}
diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go
new file mode 100644
index 000000000..ee4b23d98
--- /dev/null
+++ b/wgengine/watchdog.go
@@ -0,0 +1,83 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wgengine
+
+import (
+	"bytes"
+	"log"
+	"runtime/pprof"
+	"time"
+
+	"github.com/tailscale/wireguard-go/wgcfg"
+	"tailscale.com/wgengine/filter"
+)
+
+// NewWatchdog wraps an Engine and makes sure that all methods complete
+// within a reasonable amount of time.
+//
+// If they do not, the watchdog crashes the process.
+func NewWatchdog(e Engine) Engine {
+	return &watchdogEngine{
+		wrap:    e,
+		logf:    log.Printf,
+		fatalf:  log.Fatalf,
+		maxWait: 45 * time.Second,
+	}
+}
+
+// watchdogEngine forwards every call to wrap, bounded by maxWait.
+// logf and fatalf are fields so tests can replace them.
+type watchdogEngine struct {
+	wrap    Engine
+	logf    func(format string, args ...interface{})
+	fatalf  func(format string, args ...interface{})
+	maxWait time.Duration
+}
+
+// watchdogErr runs fn on its own goroutine and returns its result.
+// If fn has not finished within maxWait, it logs all goroutine stacks
+// and calls fatalf with name; if fatalf returns, the result is nil.
+func (e *watchdogEngine) watchdogErr(name string, fn func() error) error {
+	// Buffered so the worker can always deliver its result and exit,
+	// even after we stopped listening because of a timeout.
+	errCh := make(chan error, 1)
+	go func() {
+		errCh <- fn()
+	}()
+	t := time.NewTimer(e.maxWait)
+	defer t.Stop()
+	select {
+	case err := <-errCh:
+		return err
+	case <-t.C:
+		buf := new(bytes.Buffer)
+		pprof.Lookup("goroutine").WriteTo(buf, 1)
+		e.logf("wgengine watchdog stacks:\n%s", buf.String())
+		e.fatalf("wgengine: watchdog timeout on %s", name)
+		return nil
+	}
+}
+
+// watchdog is watchdogErr for functions that return nothing.
+func (e *watchdogEngine) watchdog(name string, fn func()) {
+	e.watchdogErr(name, func() error {
+		fn()
+		return nil
+	})
+}
+
+func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, dnsDomains []string) error {
+	return e.watchdogErr("Reconfig", func() error { return e.wrap.Reconfig(cfg, dnsDomains) })
+}
+
+// The methods below forward to the wrapped engine under the watchdog.
+func (e *watchdogEngine) SetFilter(filt *filter.Filter) {
+	e.watchdog("SetFilter", func() { e.wrap.SetFilter(filt) })
+}
+func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) {
+	e.watchdog("SetStatusCallback", func() { e.wrap.SetStatusCallback(cb) })
+}
+func (e *watchdogEngine) RequestStatus() {
+	e.watchdog("RequestStatus", func() { e.wrap.RequestStatus() })
+}
+func (e *watchdogEngine) LinkChange(isExpensive bool) {
+	e.watchdog("LinkChange", func() { e.wrap.LinkChange(isExpensive) })
+}
+func (e *watchdogEngine) Close() {
+	e.watchdog("Close", e.wrap.Close)
+}
+
+// Wait is forwarded directly, without a watchdog timeout.
+func (e *watchdogEngine) Wait() {
+	e.wrap.Wait()
+}
diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go
new file mode 100644
index 000000000..dbd4a7794
--- /dev/null
+++ b/wgengine/watchdog_test.go
@@ -0,0 +1,71 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wgengine
+
+import (
+	"bytes"
+	"fmt"
+	"strings"
+	"testing"
+	"time"
+)
+
+// TestWatchdog checks that the watchdog stays quiet for calls that
+// return promptly and fires, with a goroutine dump, for a blocked one.
+func TestWatchdog(t *testing.T) {
+	t.Parallel()
+
+	t.Run("default watchdog does not fire", func(t *testing.T) {
+		t.Parallel()
+		tun := NewFakeTun()
+		e, err := NewUserspaceEngineAdvanced(t.Logf, tun, NewFakeRouter, 0, false)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		e = NewWatchdog(e)
+		e.(*watchdogEngine).maxWait = 150 * time.Millisecond
+
+		e.RequestStatus()
+		e.RequestStatus()
+		e.RequestStatus()
+		e.Close()
+	})
+
+	t.Run("watchdog fires on blocked getStatus", func(t *testing.T) {
+		t.Parallel()
+		tun := NewFakeTun()
+		e, err := NewUserspaceEngineAdvanced(t.Logf, tun, NewFakeRouter, 0, false)
+		if err != nil {
+			t.Fatal(err)
+		}
+		usEngine := e.(*userspaceEngine)
+		e = NewWatchdog(e)
+		wdEngine := e.(*watchdogEngine)
+		wdEngine.maxWait = 100 * time.Millisecond
+
+		// Capture the watchdog's output instead of exiting the test.
+		logBuf := new(bytes.Buffer)
+		fatalCalled := make(chan struct{})
+		wdEngine.logf = func(format string, args ...interface{}) {
+			fmt.Fprintf(logBuf, format+"\n", args...)
+		}
+		wdEngine.fatalf = func(format string, args ...interface{}) {
+			t.Logf("FATAL: %s", fmt.Sprintf(format, args...))
+			fatalCalled <- struct{}{}
+		}
+
+		usEngine.wgLock.Lock() // blocks getStatus so the watchdog will fire
+
+		go e.RequestStatus()
+
+		select {
+		case <-fatalCalled:
+			if !strings.Contains(logBuf.String(), "goroutine profile: total ") {
+				t.Errorf("fatal called without watchdog stacks, got: %s", logBuf.String())
+			}
+			// expected
+		case <-time.After(3 * time.Second):
+			t.Fatalf("watchdog failed to fire")
+		}
+	})
+}
diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go
new file mode 100644
index 000000000..f395dcc33
--- /dev/null
+++ b/wgengine/wgengine.go
@@ -0,0 +1,79 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package wgengine
+
+import (
+	"fmt"
+	"net"
+	"time"
+
+	"github.com/tailscale/wireguard-go/wgcfg"
+	"tailscale.com/tailcfg"
+	"tailscale.com/wgengine/filter"
+)
+
+// ByteCount is a number of bytes.
+type ByteCount int64
+
+// PeerStatus is the traffic and handshake state of one peer.
+type PeerStatus struct {
+	TxBytes, RxBytes ByteCount
+	LastHandshake    time.Time
+	NodeKey          tailcfg.NodeKey
+}
+
+// Status is what an Engine reports to its StatusCallback.
+type Status struct {
+	Peers      []PeerStatus
+	LocalAddrs []string // TODO(crawshaw): []wgcfg.Endpoint?
+}
+
+// StatusCallback receives a status update or the error from getting one.
+type StatusCallback func(s *Status, err error)
+
+// RouteSettings is the configuration handed to Router.SetRoutes.
+type RouteSettings struct {
+	LocalAddr  wgcfg.CIDR
+	DNS        []net.IP
+	DNSDomains []string
+	Cfg        wgcfg.Config
+}
+
+// Only used on darwin for now
+// TODO(apenwarr): This probably belongs in the darwinRouter struct.
+var SetRoutesFunc func(rs RouteSettings) error
+
+// OnlyRelevantParts renders the local address, DNS servers, DNS domains
+// and each peer's AllowedIPs as one string, so that two settings can be
+// compared for changes in those fields.
+func (rs *RouteSettings) OnlyRelevantParts() string {
+	var peers [][]wgcfg.CIDR
+	for _, p := range rs.Cfg.Peers {
+		peers = append(peers, p.AllowedIPs)
+	}
+	return fmt.Sprintf("%v %v %v %v",
+		rs.LocalAddr, rs.DNS, rs.DNSDomains, peers)
+}
+
+// Router is the per-platform routing and DNS configuration backend.
+type Router interface {
+	Up() error
+	SetRoutes(rs RouteSettings) error
+	Close()
+}
+
+// Engine is the interface to a wireguard instance.
+type Engine interface {
+	// Reconfigure wireguard and make sure it's running.
+	// This also handles setting up any kernel routes.
+	Reconfig(cfg *wgcfg.Config, dnsDomains []string) error
+	// Update the packet filter.
+	SetFilter(filt *filter.Filter)
+	// Set the function to call when wireguard status changes.
+	SetStatusCallback(cb StatusCallback)
+	// Request a wireguard status update right away, sent to the callback.
+	RequestStatus()
+	// Shut down this wireguard instance, remove any routes it added, etc.
+	// To bring it up again later, you'll need a new Engine.
+	Close()
+	// Wait until the Engine is .Close()ed or aborts with an error.
+	// You don't have to call this.
+	Wait()
+	// LinkChange informs the engine that the system network
+	// link has changed. The isExpensive parameter is set on links
+	// where sending packets uses substantial power or dollars
+	// (such as LTE on a phone).
+	LinkChange(isExpensive bool)
+}
diff --git a/wgengine/winnet/winnet.go b/wgengine/winnet/winnet.go
new file mode 100644
index 000000000..be76fd9ca
--- /dev/null
+++ b/wgengine/winnet/winnet.go
@@ -0,0 +1,153 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package winnet
+
+import (
+	"fmt"
+	"unsafe"
+
+	"github.com/go-ole/go-ole"
+	"github.com/go-ole/go-ole/oleutil"
+)
+
+const CLSID_NetworkListManager = "{DCB00C01-570F-4A9B-8D69-199FDBA5723B}"
+
+var IID_INetwork = ole.NewGUID("{8A40A45D-055C-4B62-ABD7-6D613E2CEAEC}")
+var IID_INetworkConnection = ole.NewGUID("{DCB00005-570F-4A9B-8D69-199FDBA5723B}")
+
+// NetworkListManager wraps the dispatch interface of the COM object
+// created from CLSID_NetworkListManager.
+type NetworkListManager struct {
+	d *ole.Dispatch
+}
+
+// INetworkConnection is one entry returned by GetNetworkConnections.
+type INetworkConnection struct {
+	ole.IDispatch
+}
+
+// ConnectionList is a list of connections released together.
+type ConnectionList []*INetworkConnection
+
+// INetworkConnectionVtbl is the vtable layout VTable casts to.
+type INetworkConnectionVtbl struct {
+	ole.IDispatchVtbl
+	GetNetwork                uintptr
+	Get_IsConnectedToInternet uintptr
+	Get_IsConnected           uintptr
+	GetConnectivity           uintptr
+	GetConnectionId           uintptr
+	GetAdapterId              uintptr
+	GetDomainType             uintptr
+}
+
+// INetwork is the object returned by INetworkConnection.GetNetwork.
+type INetwork struct {
+	ole.IDispatch
+}
+
+// NewNetworkListManager creates the network list manager object on c
+// and returns a wrapper around its dispatch interface. c is released
+// before returning.
+func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) {
+	err := c.Create(CLSID_NetworkListManager)
+	if err != nil {
+		return nil, err
+	}
+	defer c.Release()
+
+	d, err := c.Dispatch()
+	if err != nil {
+		return nil, err
+	}
+
+	return &NetworkListManager{
+		d: d,
+	}, nil
+}
+
+// Release releases the underlying dispatch interface.
+func (m *NetworkListManager) Release() {
+	m.d.Release()
+}
+
+// Release releases every connection in the list.
+func (cl ConnectionList) Release() {
+	for _, v := range cl {
+		v.Release()
+	}
+}
+
+// asIID queries u for the interface iid and releases u, whether or not
+// the query succeeds.
+func asIID(u ole.UnknownLike, iid *ole.GUID) (*ole.IDispatch, error) {
+	if u == nil {
+		return nil, fmt.Errorf("asIID: nil UnknownLike")
+	}
+
+	d, err := u.QueryInterface(iid)
+	u.Release()
+	if err != nil {
+		return nil, err
+	}
+	return d, nil
+}
+
+// GetNetworkConnections returns all connections reported by the
+// manager. The caller must Release the returned list. On error the
+// connections collected so far are released.
+func (m *NetworkListManager) GetNetworkConnections() (ConnectionList, error) {
+	ncraw, err := m.d.Call("GetNetworkConnections")
+	if err != nil {
+		return nil, err
+	}
+
+	nli := ncraw.ToIDispatch()
+	if nli == nil {
+		return nil, fmt.Errorf("GetNetworkConnections: not IDispatch")
+	}
+
+	cl := ConnectionList{}
+
+	err = oleutil.ForEach(nli, func(v *ole.VARIANT) error {
+		nc, err := asIID(v.ToIUnknown(), IID_INetworkConnection)
+		if err != nil {
+			return err
+		}
+		nco := (*INetworkConnection)(unsafe.Pointer(nc))
+		cl = append(cl, nco)
+		return nil
+	})
+
+	if err != nil {
+		cl.Release()
+		return nil, err
+	}
+	return cl, nil
+}
+
+// GetName returns the result of the network's GetName method.
+func (n *INetwork) GetName() (string, error) {
+	v, err := n.CallMethod("GetName")
+	if err != nil {
+		return "", err
+	}
+	return v.ToString(), nil
+}
+
+// GetCategory returns the result of the network's GetCategory method.
+// It returns an error, instead of panicking, when the result is not
+// an int32.
+func (n *INetwork) GetCategory() (int32, error) {
+	v, err := n.CallMethod("GetCategory")
+	if err != nil {
+		return 0, err
+	}
+	val := v.Value()
+	cat, ok := val.(int32)
+	if !ok {
+		return 0, fmt.Errorf("GetCategory: unexpected result type %T", val)
+	}
+	return cat, nil
+}
+
+// SetCategory calls the network's SetCategory method with v.
+func (n *INetwork) SetCategory(v uint32) error {
+	_, err := n.CallMethod("SetCategory", v)
+	return err
+}
+
+// VTable returns the connection's raw vtable as INetworkConnectionVtbl.
+func (v *INetworkConnection) VTable() *INetworkConnectionVtbl {
+	return (*INetworkConnectionVtbl)(unsafe.Pointer(v.RawVTable))
+}
+
+// GetNetwork returns the network this connection belongs to.
+func (v *INetworkConnection) GetNetwork() (*INetwork, error) {
+	nraw, err := v.CallMethod("GetNetwork")
+	if err != nil {
+		return nil, err
+	}
+
+	n := nraw.ToIDispatch()
+	if n == nil {
+		return nil, fmt.Errorf("GetNetwork: nil IDispatch")
+	}
+	return (*INetwork)(unsafe.Pointer(n)), nil
+}
diff --git a/wgengine/winnet/winnet_windows.go b/wgengine/winnet/winnet_windows.go
new file mode 100644
index 000000000..8744aec53
--- /dev/null
+++ b/wgengine/winnet/winnet_windows.go
@@ -0,0 +1,26 @@
+// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package winnet
+
+import (
+	"fmt"
+	"syscall"
+	"unsafe"
+
+	"github.com/go-ole/go-ole"
+)
+
+// GetAdapterId calls the GetAdapterId vtable slot directly and returns
+// the resulting GUID as a string. A non-zero HRESULT is returned as an
+// error.
+func (v *INetworkConnection) GetAdapterId() (string, error) {
+	buf := ole.GUID{}
+	hr, _, _ := syscall.Syscall(
+		v.VTable().GetAdapterId,
+		2,
+		uintptr(unsafe.Pointer(v)),
+		uintptr(unsafe.Pointer(&buf)),
+		0)
+	if hr != 0 {
+		return "", fmt.Errorf("GetAdapterId failed: %08x", hr)
+	}
+	return buf.String(), nil
+}