From 9939374c48aff28ea9bee63a749869312d0954ef Mon Sep 17 00:00:00 2001
From: Andrew Dunham
Date: Wed, 10 Jul 2024 16:46:31 -0500
Subject: [PATCH] wgengine/magicsock: use cloud metadata to get public IPs

Updates #12774

Signed-off-by: Andrew Dunham
Change-Id: I1661b6a2da7966ab667b075894837afd96f4742f
---
 wgengine/magicsock/cloudinfo.go         | 208 ++++++++++++++++++++++++
 wgengine/magicsock/cloudinfo_nocloud.go |  23 +++
 wgengine/magicsock/cloudinfo_test.go    | 162 +++++++++++++++++++
 wgengine/magicsock/magicsock.go         |  31 +++-
 wgengine/magicsock/magicsock_test.go    |  10 +-
 5 files changed, 425 insertions(+), 9 deletions(-)
 create mode 100644 wgengine/magicsock/cloudinfo.go
 create mode 100644 wgengine/magicsock/cloudinfo_nocloud.go
 create mode 100644 wgengine/magicsock/cloudinfo_test.go

diff --git a/wgengine/magicsock/cloudinfo.go b/wgengine/magicsock/cloudinfo.go
new file mode 100644
index 000000000..1de369631
--- /dev/null
+++ b/wgengine/magicsock/cloudinfo.go
@@ -0,0 +1,208 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build !(ios || android || js)
+
+package magicsock
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/http"
+	"net/netip"
+	"slices"
+	"strings"
+	"time"
+
+	"tailscale.com/types/logger"
+	"tailscale.com/util/cloudenv"
+)
+
+const (
+	// maxCloudInfoWait is the timeout for dialing the metadata service, and
+	// for the full set of requests made by a single AWS lookup.
+	maxCloudInfoWait = 2 * time.Second
+
+	// maxCloudInfoBody caps how many bytes of a metadata response are read,
+	// so that a misbehaving responder cannot make us buffer unbounded data.
+	maxCloudInfoBody = 1 << 20
+)
+
+// cloudInfo queries a cloud provider's instance metadata service.
+type cloudInfo struct {
+	client http.Client
+	logf   logger.Logf
+
+	// The following parameters are fixed for the lifetime of the cloudInfo
+	// object, but are used for testing.
+	cloud    cloudenv.Cloud
+	endpoint string
+}
+
+// newCloudInfo returns a cloudInfo for the cloud reported by cloudenv.Get
+// that logs to logf.
+func newCloudInfo(logf logger.Logf) *cloudInfo {
+	tr := &http.Transport{
+		DisableKeepAlives: true,
+		// Transport.Dial is deprecated; DialContext also lets a canceled
+		// request abort an in-progress dial.
+		DialContext: (&net.Dialer{
+			Timeout: maxCloudInfoWait,
+		}).DialContext,
+	}
+
+	return &cloudInfo{
+		client:   http.Client{Transport: tr},
+		logf:     logf,
+		cloud:    cloudenv.Get(),
+		endpoint: "http://" + cloudenv.CommonNonRoutableMetadataIP,
+	}
+}
+
+// GetPublicIPs returns any public IPs attached to the current cloud instance,
+// if the tailscaled process is running in a known cloud and there are any such
+// IPs present.
+func (ci *cloudInfo) GetPublicIPs(ctx context.Context) ([]netip.Addr, error) {
+	switch ci.cloud {
+	case cloudenv.AWS:
+		ret, err := ci.getAWS(ctx)
+		ci.logf("[v1] cloudinfo.GetPublicIPs: AWS: %v, %v", ret, err)
+		return ret, err
+	}
+
+	return nil, nil
+}
+
+// getAWSMetadata makes a request to the AWS metadata service at the given
+// path, authenticating with the provided IMDSv2 token. The returned metadata
+// is split by newline and returned as a slice, with blank lines dropped. A
+// 404 response is not an error; it yields a nil slice.
+func (ci *cloudInfo) getAWSMetadata(ctx context.Context, token, path string) ([]string, error) {
+	req, err := http.NewRequestWithContext(ctx, "GET", ci.endpoint+path, nil)
+	if err != nil {
+		return nil, fmt.Errorf("creating request to %q: %w", path, err)
+	}
+	req.Header.Set("X-aws-ec2-metadata-token", token)
+
+	resp, err := ci.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("making request to metadata service %q: %w", path, err)
+	}
+	defer resp.Body.Close()
+
+	switch resp.StatusCode {
+	case http.StatusOK:
+		// Good
+	case http.StatusNotFound:
+		// Nothing found, but this isn't an error; just return
+		return nil, nil
+	default:
+		return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
+	}
+
+	body, err := io.ReadAll(io.LimitReader(resp.Body, maxCloudInfoBody))
+	if err != nil {
+		return nil, fmt.Errorf("reading response body for %q: %w", path, err)
+	}
+
+	// Splitting an empty body yields a single "" element, which callers
+	// would then fail to parse, so keep only non-blank lines.
+	var lines []string
+	for _, line := range strings.Split(string(body), "\n") {
+		if line = strings.TrimSpace(line); line != "" {
+			lines = append(lines, line)
+		}
+	}
+	return lines, nil
+}
+
+// getAWS returns all public IPv4 and IPv6 addresses present in the AWS instance metadata.
+func (ci *cloudInfo) getAWS(ctx context.Context) ([]netip.Addr, error) {
+	ctx, cancel := context.WithTimeout(ctx, maxCloudInfoWait)
+	defer cancel()
+
+	// Get a token so we can query the metadata service.
+	req, err := http.NewRequestWithContext(ctx, "PUT", ci.endpoint+"/latest/api/token", nil)
+	if err != nil {
+		return nil, fmt.Errorf("creating token request: %w", err)
+	}
+	req.Header.Set("X-Aws-Ec2-Metadata-Token-Ttl-Seconds", "10")
+
+	resp, err := ci.client.Do(req)
+	if err != nil {
+		return nil, fmt.Errorf("making token request to metadata service: %w", err)
+	}
+	body, err := io.ReadAll(io.LimitReader(resp.Body, maxCloudInfoBody))
+	resp.Body.Close()
+	if err != nil {
+		return nil, fmt.Errorf("reading token response body: %w", err)
+	}
+
+	server := resp.Header.Get("Server")
+	if server != "EC2ws" {
+		return nil, fmt.Errorf("unexpected server header: %q", server)
+	}
+	// Only a 200 response carries a token; anything else is an error body
+	// that must not be sent back as credentials.
+	if resp.StatusCode != http.StatusOK {
+		return nil, fmt.Errorf("unexpected status code for token request: %d", resp.StatusCode)
+	}
+	token := string(body)
+
+	// Iterate over all interfaces and get their public IP addresses, both IPv4 and IPv6.
+	const macsPath = "/latest/meta-data/network/interfaces/macs/"
+	macAddrs, err := ci.getAWSMetadata(ctx, token, macsPath)
+	if err != nil {
+		return nil, fmt.Errorf("getting interface MAC addresses: %w", err)
+	}
+
+	var (
+		addrs []netip.Addr
+		errs  []error
+	)
+
+	// Each interface is queried for both families independently, so that a
+	// failure for one family does not hide the addresses of the other.
+	queries := []struct{ family, leaf string }{
+		{"IPv4", "public-ipv4s"},
+		{"IPv6", "ipv6s"},
+	}
+	for _, mac := range macAddrs {
+		// The listing may name each interface as a directory ("<mac>/");
+		// drop the slash so the paths built below have no empty segment.
+		mac = strings.TrimSuffix(mac, "/")
+
+		for _, q := range queries {
+			lines, err := ci.getAWSMetadata(ctx, token, macsPath+mac+"/"+q.leaf)
+			if err != nil {
+				errs = append(errs, fmt.Errorf("getting %s addresses for %q: %w", q.family, mac, err))
+				continue
+			}
+			for _, line := range lines {
+				ip, err := netip.ParseAddr(line)
+				if err != nil {
+					errs = append(errs, fmt.Errorf("parsing IP address %q: %w", line, err))
+					continue
+				}
+				addrs = append(addrs, ip)
+			}
+		}
+	}
+
+	// Sort the returned addresses for determinism.
+	slices.SortFunc(addrs, func(a, b netip.Addr) int {
+		return a.Compare(b)
+	})
+
+	// Preferentially return any addresses we found, even if there were errors.
+	if len(addrs) > 0 {
+		return addrs, nil
+	}
+	if len(errs) > 0 {
+		return nil, fmt.Errorf("getting IP addresses: %w", errors.Join(errs...))
+	}
+	return nil, nil
+}
diff --git a/wgengine/magicsock/cloudinfo_nocloud.go b/wgengine/magicsock/cloudinfo_nocloud.go
new file mode 100644
index 000000000..b4414d318
--- /dev/null
+++ b/wgengine/magicsock/cloudinfo_nocloud.go
@@ -0,0 +1,23 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+//go:build ios || android || js
+
+package magicsock
+
+import (
+	"context"
+	"net/netip"
+
+	"tailscale.com/types/logger"
+)
+
+type cloudInfo struct{}
+
+func newCloudInfo(_ logger.Logf) *cloudInfo {
+	return &cloudInfo{}
+}
+
+func (ci *cloudInfo) GetPublicIPs(_ context.Context) ([]netip.Addr, error) {
+	return nil, nil
+}
diff --git a/wgengine/magicsock/cloudinfo_test.go b/wgengine/magicsock/cloudinfo_test.go
new file mode 100644
index 000000000..15191aeef
--- /dev/null
+++ b/wgengine/magicsock/cloudinfo_test.go
@@ -0,0 +1,162 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package magicsock
+
+import (
+	"context"
+	"net/http"
+	"net/http/httptest"
+	"net/netip"
+	"slices"
+	"testing"
+
+	"tailscale.com/util/cloudenv"
+)
+
+func TestCloudInfo_AWS(t *testing.T) {
+	const (
+		mac1      = "06:1d:00:00:00:00"
+		mac2      = "06:1d:00:00:00:01"
+		publicV4  = "1.2.3.4"
+		otherV4_1 = "5.6.7.8"
+		otherV4_2 = "11.12.13.14"
+		v6addr    = "2001:db8::1"
+
+		macsPrefix = "/latest/meta-data/network/interfaces/macs/"
+	)
+	// Launch a fake AWS IMDS server
+	fake := &fakeIMDS{
+		tb: t,
+		paths: map[string]string{
+			macsPrefix: mac1 + "\n" + mac2,
+			// This is the "main" public IP address for the instance
+			macsPrefix + mac1 + "/public-ipv4s": publicV4,
+
+			// There's another interface with two public IPs
+			// attached to it and an IPv6 address, all of which we
+			// should discover.
+			macsPrefix + mac2 + "/public-ipv4s": otherV4_1 + "\n" + otherV4_2,
+			macsPrefix + mac2 + "/ipv6s":        v6addr,
+		},
+	}
+
+	srv := httptest.NewServer(fake)
+	defer srv.Close()
+
+	ci := newCloudInfo(t.Logf)
+	ci.cloud = cloudenv.AWS
+	ci.endpoint = srv.URL
+
+	ips, err := ci.GetPublicIPs(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	wantIPs := []netip.Addr{
+		netip.MustParseAddr(publicV4),
+		netip.MustParseAddr(otherV4_1),
+		netip.MustParseAddr(otherV4_2),
+		netip.MustParseAddr(v6addr),
+	}
+	if !slices.Equal(ips, wantIPs) {
+		t.Fatalf("got %v, want %v", ips, wantIPs)
+	}
+}
+
+func TestCloudInfo_AWSNotPublic(t *testing.T) {
+	returns404 := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if r.Method == "PUT" && r.URL.Path == "/latest/api/token" {
+			w.Header().Set("Server", "EC2ws")
+			w.Write([]byte("fake-imds-token"))
+			return
+		}
+		http.NotFound(w, r)
+	})
+	srv := httptest.NewServer(returns404)
+	defer srv.Close()
+
+	ci := newCloudInfo(t.Logf)
+	ci.cloud = cloudenv.AWS
+	ci.endpoint = srv.URL
+
+	// If the IMDS server doesn't return any public IPs, it's not an error
+	// and we should just get an empty list.
+	ips, err := ci.GetPublicIPs(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(ips) != 0 {
+		t.Fatalf("got %v, want none", ips)
+	}
+}
+
+func TestCloudInfo_AWSListingFormat(t *testing.T) {
+	const (
+		mac1     = "06:1d:00:00:00:00"
+		mac2     = "06:1d:00:00:00:01"
+		publicV4 = "1.2.3.4"
+
+		macsPrefix = "/latest/meta-data/network/interfaces/macs/"
+	)
+	fake := &fakeIMDS{
+		tb: t,
+		paths: map[string]string{
+			// Interfaces listed as directories, with a trailing newline.
+			macsPrefix: mac1 + "/\n" + mac2 + "/\n",
+
+			// An empty 200 response means "no addresses", not an
+			// unparseable one.
+			macsPrefix + mac1 + "/public-ipv4s": "",
+			macsPrefix + mac2 + "/public-ipv4s": publicV4 + "\n",
+		},
+	}
+
+	srv := httptest.NewServer(fake)
+	defer srv.Close()
+
+	ci := newCloudInfo(t.Logf)
+	ci.cloud = cloudenv.AWS
+	ci.endpoint = srv.URL
+
+	ips, err := ci.GetPublicIPs(context.Background())
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	wantIPs := []netip.Addr{netip.MustParseAddr(publicV4)}
+	if !slices.Equal(ips, wantIPs) {
+		t.Fatalf("got %v, want %v", ips, wantIPs)
+	}
+}
+
+type fakeIMDS struct {
+	tb    testing.TB
+	paths map[string]string
+}
+
+func (f *fakeIMDS) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+	f.tb.Logf("%s %s", r.Method, r.URL.Path)
+	path := r.URL.Path
+
+	// Handle the /latest/api/token case
+	const token = "fake-imds-token"
+	if r.Method == "PUT" && path == "/latest/api/token" {
+		w.Header().Set("Server", "EC2ws")
+		w.Write([]byte(token))
+		return
+	}
+
+	// Otherwise, require the IMDSv2 token to be set
+	if r.Header.Get("X-aws-ec2-metadata-token") != token {
+		f.tb.Errorf("missing or invalid IMDSv2 token")
+		http.Error(w, "missing or invalid IMDSv2 token", http.StatusForbidden)
+		return
+	}
+
+	if v, ok := f.paths[path]; ok {
+		w.Write([]byte(v))
+		return
+	}
+	http.NotFound(w, r)
+}
diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index cd7fb23da..5ac53c771 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -133,6 +133,9 @@ type Conn struct {
 	// bind is the wireguard-go conn.Bind for Conn.
 	bind *connBind
 
+	// cloudInfo is used to query cloud metadata services.
+	cloudInfo *cloudInfo
+
 	// ============================================================
 	// Fields that must be accessed via atomic load/stores.
 
@@ -425,9 +428,10 @@ func (o *Options) derpActiveFunc() func() {
 
 // newConn is the error-free, network-listening-side-effect-free based
 // of NewConn. Mostly for tests.
-func newConn() *Conn {
+func newConn(logf logger.Logf) *Conn {
 	discoPrivate := key.NewDisco()
 	c := &Conn{
+		logf:         logf,
 		derpRecvCh:   make(chan derpReadResult, 1), // must be buffered, see issue 3736
 		derpStarted:  make(chan struct{}),
 		peerLastDerp: make(map[key.NodePublic]int),
@@ -435,6 +439,7 @@ func newConn() *Conn {
 		discoInfo:    make(map[key.DiscoPublic]*discoInfo),
 		discoPrivate: discoPrivate,
 		discoPublic:  discoPrivate.Public(),
+		cloudInfo:    newCloudInfo(logf),
 	}
 	c.discoShort = c.discoPublic.ShortString()
 	c.bind = &connBind{Conn: c, closed: true}
@@ -462,10 +467,9 @@ func NewConn(opts Options) (*Conn, error) {
 		return nil, errors.New("magicsock.Options.NetMon must be non-nil")
 	}
 
-	c := newConn()
+	c := newConn(opts.logf())
 	c.port.Store(uint32(opts.Port))
 	c.controlKnobs = opts.ControlKnobs
-	c.logf = opts.logf()
 	c.epFunc = opts.endpointsFunc()
 	c.derpActiveFunc = opts.derpActiveFunc()
 	c.idleFunc = opts.IdleFunc
@@ -952,6 +956,27 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
 		addAddr(ap, tailcfg.EndpointExplicitConf)
 	}
 
+	// If we're on a cloud instance, we might have a public IPv4 or IPv6
+	// address that we can be reached at. Find those, if they exist, and
+	// add them.
+	if addrs, err := c.cloudInfo.GetPublicIPs(ctx); err == nil {
+		var port4, port6 uint16
+		if addr := c.pconn4.LocalAddr(); addr != nil {
+			port4 = uint16(addr.Port)
+		}
+		if addr := c.pconn6.LocalAddr(); addr != nil {
+			port6 = uint16(addr.Port)
+		}
+
+		for _, addr := range addrs {
+			if addr.Is4() && port4 > 0 {
+				addAddr(netip.AddrPortFrom(addr, port4), tailcfg.EndpointLocal)
+			} else if addr.Is6() && port6 > 0 {
+				addAddr(netip.AddrPortFrom(addr, port6), tailcfg.EndpointLocal)
+			}
+		}
+	}
+
 	// Update our set of endpoints by adding any endpoints that we
 	// previously found but haven't expired yet. This also updates the
 	// cache with the set of endpoints discovered in this function.
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index cec05dffc..a721c24e4 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -452,7 +452,7 @@ func TestPickDERPFallback(t *testing.T) {
 	tstest.PanicOnLog()
 	tstest.ResourceCheck(t)
 
-	c := newConn()
+	c := newConn(t.Logf)
 	dm := &tailcfg.DERPMap{
 		Regions: map[int]*tailcfg.DERPRegion{
 			1: {},
@@ -483,7 +483,7 @@ func TestPickDERPFallback(t *testing.T) {
 	// distribution over nodes works.
 	got := map[int]int{}
 	for range 50 {
-		c = newConn()
+		c = newConn(t.Logf)
 		c.derpMap = dm
 		got[c.pickDERPFallback()]++
 	}
@@ -1185,8 +1185,7 @@ func testTwoDevicePing(t *testing.T, d *devices) {
 }
 
 func TestDiscoMessage(t *testing.T) {
-	c := newConn()
-	c.logf = t.Logf
+	c := newConn(t.Logf)
 	c.privateKey = key.NewNode()
 
 	peer1Pub := c.DiscoPublicKey()
@@ -3161,8 +3160,7 @@ func TestMaybeSetNearestDERP(t *testing.T) {
 	for _, tt := range testCases {
 		t.Run(tt.name, func(t *testing.T) {
 			ht := new(health.Tracker)
-			c := newConn()
-			c.logf = t.Logf
+			c := newConn(t.Logf)
 			c.myDerp = tt.old
 			c.derpMap = derpMap
 			c.health = ht