diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index da38ff9db..b7ce9c401 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -152,7 +152,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/ipv4 from golang.org/x/net/icmp+ - golang.org/x/net/ipv6 from golang.org/x/net/icmp + golang.org/x/net/ipv6 from golang.org/x/net/icmp+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from net+ golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index 603fa3a41..ac52a7a41 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -1321,10 +1321,7 @@ func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, nee ctx, done := context.WithTimeout(ctx, icmpProbeTimeout) defer done() - p, err := ping.New(ctx, c.logf, c.NetMon) - if err != nil { - return err - } + p := ping.New(ctx, c.logf, netns.Listener(c.logf, c.NetMon)) defer p.Close() c.logf("UDP is blocked, trying ICMP") diff --git a/net/ping/ping.go b/net/ping/ping.go index 9b9618e0f..170d87fb9 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -11,16 +11,25 @@ import ( "crypto/rand" "encoding/binary" "fmt" + "io" "log" "net" + "net/netip" "sync" + "sync/atomic" "time" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "tailscale.com/net/netmon" - "tailscale.com/net/netns" + "golang.org/x/net/ipv6" "tailscale.com/types/logger" + "tailscale.com/util/mak" + "tailscale.com/util/multierr" +) + +const ( + v4Type = "ip4:icmp" + v6Type = "ip6:icmp" ) type response struct { @@ -33,12 +42,21 @@ type outstanding struct { data []byte } +// PacketListener defines the interface required to listen to packages +// on an address. +type ListenPacketer interface { + ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) +} + // Pinger represents a set of ICMP echo requests to be sent at a single time. // // A new instance should be created for each concurrent set of ping requests; // this type should not be reused. type Pinger struct { - c net.PacketConn + lp ListenPacketer + + // closed guards against send incrementing the waitgroup concurrently with close. + closed atomic.Bool Logf logger.Logf Verbose bool timeNow func() time.Time @@ -46,16 +64,37 @@ type Pinger struct { wg sync.WaitGroup // Following fields protected by mu - mu sync.Mutex + mu sync.Mutex + // conns is a map of "type" to net.PacketConn, type is either + // "ip4:icmp" or "ip6:icmp" + conns map[string]net.PacketConn seq uint16 // uint16 per RFC 792 pings map[uint16]outstanding } // New creates a new Pinger. The Context provided will be used to create // network listeners, and to set an absolute deadline (if any) on the net.Conn -// The netMon parameter is optional; if non-nil it's used to do faster interface lookups. -func New(ctx context.Context, logf logger.Logf, netMon *netmon.Monitor) (*Pinger, error) { - p, err := newUnstarted(ctx, logf, netMon) +func New(ctx context.Context, logf logger.Logf, lp ListenPacketer) *Pinger { + var id [2]byte + if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { + panic("net/ping: New:" + err.Error()) + } + + return &Pinger{ + lp: lp, + Logf: logf, + timeNow: time.Now, + id: binary.LittleEndian.Uint16(id[:]), + pings: make(map[uint16]outstanding), + } +} + +func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, error) { + if p.closed.Load() { + return nil, net.ErrClosed + } + + c, err := p.lp.ListenPacket(ctx, typ, addr) if err != nil { return nil, err } @@ -64,35 +103,36 @@ func New(ctx context.Context, logf logger.Logf, netMon *netmon.Monitor) (*Pinger // applies to all future I/O, so we only need to do it once. deadline, ok := ctx.Deadline() if ok { - if err := p.c.SetReadDeadline(deadline); err != nil { + if err := c.SetReadDeadline(deadline); err != nil { return nil, err } } p.wg.Add(1) - go p.run(ctx) - return p, nil + go p.run(ctx, c, typ) + + return c, err } -func newUnstarted(ctx context.Context, logf logger.Logf, netMon *netmon.Monitor) (*Pinger, error) { - var id [2]byte - _, err := rand.Read(id[:]) - if err != nil { - return nil, err +// getConn creates or returns a conn matching typ which is ip4:icmp +// or ip6:icmp. +func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if c, ok := p.conns[typ]; ok { + return c, nil } - conn, err := netns.Listener(logf, netMon).ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + var addr = "0.0.0.0" + if typ == v6Type { + addr = "::" + } + c, err := p.mkconn(ctx, typ, addr) if err != nil { return nil, err } - - return &Pinger{ - c: conn, - Logf: logf, - timeNow: time.Now, - id: binary.LittleEndian.Uint16(id[:]), - pings: make(map[uint16]outstanding), - }, nil + mak.Set(&p.conns, typ, c) + return c, nil } func (p *Pinger) logf(format string, a ...any) { @@ -110,13 +150,34 @@ func (p *Pinger) vlogf(format string, a ...any) { } func (p *Pinger) Close() error { - err := p.c.Close() + p.closed.Store(true) + + p.mu.Lock() + conns := p.conns + p.conns = nil + p.mu.Unlock() + + var errors []error + for _, c := range conns { + if err := c.Close(); err != nil { + errors = append(errors, err) + } + } + p.wg.Wait() - return err + p.cleanupOutstanding() + + return multierr.New(errors...) } -func (p *Pinger) run(ctx context.Context) { +func (p *Pinger) run(ctx context.Context, conn net.PacketConn, typ string) { defer p.wg.Done() + defer func() { + conn.Close() + p.mu.Lock() + delete(p.conns, typ) + p.mu.Unlock() + }() buf := make([]byte, 1500) loop: @@ -127,7 +188,7 @@ loop: default: } - n, addr, err := p.c.ReadFrom(buf) + n, _, err := conn.ReadFrom(buf) if err != nil { // Ignore temporary errors; everything else is fatal if netErr, ok := err.(net.Error); !ok || !netErr.Temporary() { @@ -136,10 +197,8 @@ loop: continue } - p.handleResponse(buf[:n], addr, p.timeNow()) + p.handleResponse(buf[:n], p.timeNow(), typ) } - - p.cleanupOutstanding() } func (p *Pinger) cleanupOutstanding() { @@ -151,16 +210,28 @@ func (p *Pinger) cleanupOutstanding() { } } -func (p *Pinger) handleResponse(buf []byte, addr net.Addr, now time.Time) { - const ProtocolICMP = 1 - m, err := icmp.ParseMessage(ProtocolICMP, buf) +func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { + // We need to handle responding to both IPv4 + // and IPv6. + var icmpType icmp.Type + switch typ { + case v4Type: + icmpType = ipv4.ICMPTypeEchoReply + case v6Type: + icmpType = ipv6.ICMPTypeEchoReply + default: + p.vlogf("handleResponse: unknown icmp.Type") + return + } + + m, err := icmp.ParseMessage(icmpType.Protocol(), buf) if err != nil { p.vlogf("handleResponse: invalid packet: %v", err) return } - if m.Type != ipv4.ICMPTypeEchoReply { - p.vlogf("handleResponse: wanted m.Type=%d; got %d", ipv4.ICMPTypeEchoReply, m.Type) + if m.Type != icmpType { + p.vlogf("handleResponse: wanted m.Type=%d; got %d", icmpType, m.Type) return } @@ -212,9 +283,27 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur seq := p.seq p.mu.Unlock() + // Check whether the address is IPv4 or IPv6 to + // determine the icmp.Type and conn to use. + var conn net.PacketConn + var icmpType icmp.Type = ipv4.ICMPTypeEcho + ap, err := netip.ParseAddr(dest.String()) + if err != nil { + return 0, err + } + if ap.Is6() { + icmpType = ipv6.ICMPTypeEchoRequest + conn, err = p.getConn(ctx, v6Type) + } else { + conn, err = p.getConn(ctx, v4Type) + } + if err != nil { + return 0, err + } + m := icmp.Message{ - Type: ipv4.ICMPTypeEcho, - Code: 0, + Type: icmpType, + Code: icmpType.Protocol(), Body: &icmp.Echo{ ID: int(p.id), Seq: int(seq), @@ -234,7 +323,7 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur p.mu.Unlock() start := p.timeNow() - n, err := p.c.WriteTo(b, dest) + n, err := conn.WriteTo(b, dest) if err != nil { return 0, err } else if n != len(b) { diff --git a/net/ping/ping_test.go b/net/ping/ping_test.go index 1fdefc6e7..bbedbcad8 100644 --- a/net/ping/ping_test.go +++ b/net/ping/ping_test.go @@ -6,18 +6,20 @@ package ping import ( "context" "errors" + "fmt" "net" "testing" "time" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" "tailscale.com/tstest" + "tailscale.com/util/mak" ) var ( - localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} - localhostUDP = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345} + localhost = &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)} ) func TestPinger(t *testing.T) { @@ -35,7 +37,7 @@ func TestPinger(t *testing.T) { // Start a ping in the background r := make(chan time.Duration, 1) go func() { - dur, err := p.Send(ctx, localhostUDP, bodyData) + dur, err := p.Send(ctx, localhost, bodyData) if err != nil { t.Errorf("p.Send: %v", err) r <- 0 @@ -49,7 +51,7 @@ func TestPinger(t *testing.T) { // Fake a response from ourself fakeResponse := mustMarshal(t, &icmp.Message{ Type: ipv4.ICMPTypeEchoReply, - Code: 0, + Code: ipv4.ICMPTypeEchoReply.Protocol(), Body: &icmp.Echo{ ID: 1234, Seq: 1, @@ -58,7 +60,65 @@ func TestPinger(t *testing.T) { }) const fakeDuration = 100 * time.Millisecond - p.handleResponse(fakeResponse, localhost, clock.Now().Add(fakeDuration)) + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v4Type) + + select { + case dur := <-r: + want := fakeDuration + if dur != want { + t.Errorf("wanted ping response time = %d; got %d", want, dur) + } + case <-ctx.Done(): + t.Fatal("did not get response by timeout") + } +} + +func TestV6Pinger(t *testing.T) { + if c, err := net.ListenPacket("udp6", "::1"); err != nil { + // skip test if we can't use IPv6. + t.Skipf("IPv6 not supported: %s", err) + } else { + c.Close() + } + + clock := &tstest.Clock{} + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + p, closeP := mockPinger(t, clock) + defer closeP() + + bodyData := []byte("data goes here") + + // Start a ping in the background + r := make(chan time.Duration, 1) + go func() { + dur, err := p.Send(ctx, &net.IPAddr{IP: net.ParseIP("::")}, bodyData) + if err != nil { + t.Errorf("p.Send: %v", err) + r <- 0 + } else { + r <- dur + } + }() + + p.waitOutstanding(t, ctx, 1) + + // Fake a response from ourself + fakeResponse := mustMarshal(t, &icmp.Message{ + Type: ipv6.ICMPTypeEchoReply, + Code: ipv6.ICMPTypeEchoReply.Protocol(), + Body: &icmp.Echo{ + ID: 1234, + Seq: 1, + Data: bodyData, + }, + }) + + const fakeDuration = 100 * time.Millisecond + p.handleResponse(fakeResponse, clock.Now().Add(fakeDuration), v6Type) select { case dur := <-r: @@ -83,7 +143,7 @@ func TestPingerTimeout(t *testing.T) { // Send a ping in the background r := make(chan error, 1) go func() { - _, err := p.Send(ctx, localhostUDP, []byte("data goes here")) + _, err := p.Send(ctx, localhost, []byte("data goes here")) r <- err }() @@ -115,7 +175,7 @@ func TestPingerMismatch(t *testing.T) { // Start a ping in the background r := make(chan time.Duration, 1) go func() { - dur, err := p.Send(ctx, localhostUDP, bodyData) + dur, err := p.Send(ctx, localhost, bodyData) if err != nil && !errors.Is(err, context.DeadlineExceeded) { t.Errorf("p.Send: %v", err) r <- 0 @@ -185,11 +245,11 @@ func TestPingerMismatch(t *testing.T) { for _, tt := range badPackets { fakeResponse := mustMarshal(t, tt.pkt) - p.handleResponse(fakeResponse, localhost, tm) + p.handleResponse(fakeResponse, tm, v4Type) } // Also "receive" a packet that does not unmarshal as an ICMP packet - p.handleResponse([]byte("foo"), localhost, tm) + p.handleResponse([]byte("foo"), tm, v4Type) select { case <-r: @@ -199,23 +259,59 @@ func TestPingerMismatch(t *testing.T) { } } +// udpingPacketConn will convert potentially ICMP destination addrs to UDP +// destination addrs in WriteTo so that a test that is intending to send ICMP +// traffic will instead send UDP traffic, without the higher level Pinger being +// aware of this difference. +type udpingPacketConn struct { + net.PacketConn + // destPort will be configured by the test to be the peer expected to respond to a ping. + destPort uint16 +} + +func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) { + switch d := dest.(type) { + case *net.IPAddr: + udpAddr := &net.UDPAddr{ + IP: d.IP, + Port: int(u.destPort), + Zone: d.Zone, + } + return u.PacketConn.WriteTo(body, udpAddr) + } + return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest) +} + func mockPinger(t *testing.T, clock *tstest.Clock) (*Pinger, func()) { + p := New(context.Background(), t.Logf, nil) + p.timeNow = clock.Now + p.Verbose = true + p.id = 1234 + // In tests, we use UDP so that we can test without being root; this // doesn't matter because we mock out the ICMP reply below to be a real // ICMP echo reply packet. - conn, err := net.ListenPacket("udp4", "127.0.0.1:0") + conn4, err := net.ListenPacket("udp4", "127.0.0.1:0") if err != nil { t.Fatalf("net.ListenPacket: %v", err) } - p := &Pinger{ - c: conn, - Logf: t.Logf, - Verbose: true, - timeNow: clock.Now, - id: 1234, - pings: make(map[uint16]outstanding), + conn6, err := net.ListenPacket("udp6", "[::]:0") + if err != nil { + t.Fatalf("net.ListenPacket: %v", err) } + + conn4 = &udpingPacketConn{ + destPort: 12345, + PacketConn: conn4, + } + conn6 = &udpingPacketConn{ + PacketConn: conn6, + destPort: 12345, + } + + mak.Set(&p.conns, v4Type, conn4) + mak.Set(&p.conns, v6Type, conn6) done := func() { if err := p.Close(); err != nil { t.Errorf("error on close: %v", err)