wgengine/...: split into multiple receive functions

Upstream wireguard-go has changed its receive model.
NewDevice now accepts a conn.Bind interface.

The conn.Bind is stateless; magicsock.Conns are stateful.
To work around this, we add a connBind type that supports
cheap teardown and bring-up, backed by a Conn.

The new conn.Bind allows us to specify a set of receive functions,
rather than having to shoehorn everything into ReceiveIPv4 and ReceiveIPv6.
This lets us plumbing DERP messages directly into wireguard-go,
instead of having to mux them via ReceiveIPv4.

One consequence of the new conn.Bind layer is that
closing the wireguard-go device is now indistinguishable
from the routine bring-up and tear-down normally experienced
by a conn.Bind. We thus have to explicitly close the magicsock.Conn
when the close the wireguard-go device.

One downside of this change is that we are reliant on wireguard-go
to call receiveDERP to process DERP messages. This is fine for now,
but is perhaps something we should fix in the future.

Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
pull/1652/head
Josh Bleecher Snyder 4 years ago committed by Josh Bleecher Snyder
parent 2074dfa5e0
commit b3ceca1dd7

@ -22,6 +22,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
L github.com/mdlayher/sdnotify from tailscale.com/util/systemd L github.com/mdlayher/sdnotify from tailscale.com/util/systemd
W github.com/pkg/errors from github.com/github/certstore W github.com/pkg/errors from github.com/github/certstore
💣 github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ 💣 github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+
W 💣 github.com/tailscale/wireguard-go/conn/winrio from github.com/tailscale/wireguard-go/conn
💣 github.com/tailscale/wireguard-go/device from tailscale.com/wgengine+ 💣 github.com/tailscale/wireguard-go/device from tailscale.com/wgengine+
💣 github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device 💣 github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device
W 💣 github.com/tailscale/wireguard-go/ipc/winpipe from github.com/tailscale/wireguard-go/ipc W 💣 github.com/tailscale/wireguard-go/ipc/winpipe from github.com/tailscale/wireguard-go/ipc
@ -124,7 +125,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de
tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/netmap from tailscale.com/control/controlclient+
tailscale.com/types/nettype from tailscale.com/wgengine/magicsock tailscale.com/types/nettype from tailscale.com/wgengine/magicsock
tailscale.com/types/opt from tailscale.com/control/controlclient+ tailscale.com/types/opt from tailscale.com/control/controlclient+
tailscale.com/types/pad32 from tailscale.com/wgengine/magicsock
tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/persist from tailscale.com/control/controlclient+
tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/preftype from tailscale.com/ipn+
tailscale.com/types/strbuilder from tailscale.com/net/packet tailscale.com/types/strbuilder from tailscale.com/net/packet

@ -24,7 +24,7 @@ require (
github.com/peterbourgon/ff/v2 v2.0.0 github.com/peterbourgon/ff/v2 v2.0.0
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/tailscale/depaware v0.0.0-20201214215404-77d1e9757027 github.com/tailscale/depaware v0.0.0-20201214215404-77d1e9757027
github.com/tailscale/wireguard-go v0.0.0-20210401164443-2d6878b6b30d github.com/tailscale/wireguard-go v0.0.0-20210402173217-0a47c6e64d15
github.com/tcnksm/go-httpstat v0.2.0 github.com/tcnksm/go-httpstat v0.2.0
github.com/toqueteos/webbrowser v1.2.0 github.com/toqueteos/webbrowser v1.2.0
go4.org/mem v0.0.0-20201119185036-c04c5a6ff174 go4.org/mem v0.0.0-20201119185036-c04c5a6ff174

@ -129,6 +129,12 @@ github.com/tailscale/wireguard-go v0.0.0-20210330200845-4914b4a944c4 h1:7Y0H5Nzr
github.com/tailscale/wireguard-go v0.0.0-20210330200845-4914b4a944c4/go.mod h1:6t0OVdJwFOKFnvaHaVMKG6GznWaHqkmiR2n3kH0t924= github.com/tailscale/wireguard-go v0.0.0-20210330200845-4914b4a944c4/go.mod h1:6t0OVdJwFOKFnvaHaVMKG6GznWaHqkmiR2n3kH0t924=
github.com/tailscale/wireguard-go v0.0.0-20210401164443-2d6878b6b30d h1:zbDBqtYvc492gcRL5BB7AO5Aed+aVht2jbYg8SKoMYs= github.com/tailscale/wireguard-go v0.0.0-20210401164443-2d6878b6b30d h1:zbDBqtYvc492gcRL5BB7AO5Aed+aVht2jbYg8SKoMYs=
github.com/tailscale/wireguard-go v0.0.0-20210401164443-2d6878b6b30d/go.mod h1:6t0OVdJwFOKFnvaHaVMKG6GznWaHqkmiR2n3kH0t924= github.com/tailscale/wireguard-go v0.0.0-20210401164443-2d6878b6b30d/go.mod h1:6t0OVdJwFOKFnvaHaVMKG6GznWaHqkmiR2n3kH0t924=
github.com/tailscale/wireguard-go v0.0.0-20210401172819-1aca620a8afb h1:6TGRROCOrjTKbt1ucBTZaDMBeScG6yVEXEjuabOiBzU=
github.com/tailscale/wireguard-go v0.0.0-20210401172819-1aca620a8afb/go.mod h1:jy12FSeiDLRvS7VQvSoiaqH9WtpapbrC8YSzyZ7fUAk=
github.com/tailscale/wireguard-go v0.0.0-20210401194826-bb7bc2f24083 h1:e3k65apTVs7NM6mhQ1c94XISLe+2gdizPfRdsImNL8Y=
github.com/tailscale/wireguard-go v0.0.0-20210401194826-bb7bc2f24083/go.mod h1:jy12FSeiDLRvS7VQvSoiaqH9WtpapbrC8YSzyZ7fUAk=
github.com/tailscale/wireguard-go v0.0.0-20210402173217-0a47c6e64d15 h1:13GZsTKbCmPGwDBurcSXT+ssYID2IfcX0MfsvhaaagY=
github.com/tailscale/wireguard-go v0.0.0-20210402173217-0a47c6e64d15/go.mod h1:jy12FSeiDLRvS7VQvSoiaqH9WtpapbrC8YSzyZ7fUAk=
github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0= github.com/tcnksm/go-httpstat v0.2.0 h1:rP7T5e5U2HfmOBmZzGgGZjBQ5/GluWUylujl0tJ04I0=
github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8= github.com/tcnksm/go-httpstat v0.2.0/go.mod h1:s3JVJFtQxtBEBC9dwcdTTXS9xFnM3SXAZwPG41aurT8=
github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ= github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ=

@ -12,7 +12,6 @@ import (
crand "crypto/rand" crand "crypto/rand"
"encoding/binary" "encoding/binary"
"errors" "errors"
"expvar"
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"math" "math"
@ -25,7 +24,6 @@ import (
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn"
@ -53,7 +51,6 @@ import (
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/types/pad32"
"tailscale.com/types/wgkey" "tailscale.com/types/wgkey"
"tailscale.com/version" "tailscale.com/version"
"tailscale.com/wgengine/monitor" "tailscale.com/wgengine/monitor"
@ -161,17 +158,14 @@ type Conn struct {
// Its Loaded value is always non-nil. // Its Loaded value is always non-nil.
stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr) stunReceiveFunc atomic.Value // of func(p []byte, fromAddr *net.UDPAddr)
// derpRecvCh is used by ReceiveIPv4 to read DERP messages. // derpRecvCh is used by receiveDERP to read DERP messages.
derpRecvCh chan derpReadResult derpRecvCh chan derpReadResult
_ pad32.Four // bind is the wireguard-go conn.Bind for Conn.
// derpRecvCountAtomic is how many derpRecvCh sends are pending. bind *connBind
// It's incremented by runDerpReader whenever a DERP message
// arrives and decremented when they're read.
derpRecvCountAtomic int64
// ippEndpoint4 and ippEndpoint6 are owned by ReceiveIPv4 and // ippEndpoint4 and ippEndpoint6 are owned by receiveIPv4 and
// ReceiveIPv6, respectively, to cache an IPPort->endpoint for // receiveIPv6, respectively, to cache an IPPort->endpoint for
// hot flows. // hot flows.
ippEndpoint4, ippEndpoint6 ippEndpointCache ippEndpoint4, ippEndpoint6 ippEndpointCache
@ -467,6 +461,7 @@ func newConn() *Conn {
sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte), sharedDiscoKey: make(map[tailcfg.DiscoKey]*[32]byte),
discoOfAddr: make(map[netaddr.IPPort]tailcfg.DiscoKey), discoOfAddr: make(map[netaddr.IPPort]tailcfg.DiscoKey),
} }
c.bind = &connBind{Conn: c, closed: true}
c.muCond = sync.NewCond(&c.mu) c.muCond = sync.NewCond(&c.mu)
c.networkUp.Set(true) // assume up until told otherwise c.networkUp.Set(true) // assume up until told otherwise
return c return c
@ -1499,58 +1494,18 @@ func (c *Conn) runDerpReader(ctx context.Context, derpFakeAddr netaddr.IPPort, d
continue continue
} }
if !c.sendDerpReadResult(ctx, res) {
return
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return return
case <-didCopy: case c.derpRecvCh <- res:
continue
}
}
} }
var (
testCounterZeroDerpReadResultSend expvar.Int
testCounterZeroDerpReadResultRecv expvar.Int
)
// sendDerpReadResult sends res to c.derpRecvCh and reports whether it
// was sent. (It reports false if ctx was done first.)
//
// This includes doing the whole wake-up dance to interrupt
// ReceiveIPv4's blocking UDP read.
func (c *Conn) sendDerpReadResult(ctx context.Context, res derpReadResult) (sent bool) {
// Before we wake up ReceiveIPv4 with SetReadDeadline,
// note that a DERP packet has arrived. ReceiveIPv4
// will read this field to note that its UDP read
// error is due to us.
atomic.AddInt64(&c.derpRecvCountAtomic, 1)
// Cancel the pconn read goroutine.
c.pconn4.SetReadDeadline(aLongTimeAgo)
select { select {
case <-ctx.Done(): case <-ctx.Done():
select { return
case <-c.donec: case <-didCopy:
// The whole Conn shut down. The reader of continue
// c.derpRecvCh also selects on c.donec, so it's
// safe to abort now.
case c.derpRecvCh <- (derpReadResult{}):
// Just this DERP reader is closing (perhaps
// the user is logging out, or the DERP
// connection is too idle for sends). Since we
// already incremented c.derpRecvCountAtomic,
// we need to send on the channel (unless the
// conn is going down).
// The receiver treats a derpReadResult zero value
// message as a skip.
testCounterZeroDerpReadResultSend.Add(1)
} }
return false
case c.derpRecvCh <- res:
return true
} }
} }
@ -1623,10 +1578,8 @@ func (c *Conn) noteRecvActivityFromEndpoint(e conn.Endpoint) {
} }
} }
func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, error) { // receiveIPv6 receives a UDP IPv6 packet. It is called by wireguard-go.
if c.pconn6 == nil { func (c *Conn) receiveIPv6(b []byte) (int, conn.Endpoint, error) {
return 0, nil, syscall.EAFNOSUPPORT
}
for { for {
n, ipp, err := c.pconn6.ReadFromNetaddr(b) n, ipp, err := c.pconn6.ReadFromNetaddr(b)
if err != nil { if err != nil {
@ -1638,43 +1591,16 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, error) {
} }
} }
func (c *Conn) derpPacketArrived() bool { // receiveIPv4 receives a UDP IPv4 packet. It is called by wireguard-go.
return atomic.LoadInt64(&c.derpRecvCountAtomic) > 0 func (c *Conn) receiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
}
// ReceiveIPv4 is called by wireguard-go to receive an IPv4 packet.
// In Tailscale's case, that packet might also arrive via DERP. A DERP packet arrival
// aborts the pconn4 read deadline to make it fail.
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
var ipp netaddr.IPPort
for { for {
// Drain DERP queues before reading new UDP packets. n, ipp, err := c.pconn4.ReadFromNetaddr(b)
if c.derpPacketArrived() {
goto ReadDERP
}
n, ipp, err = c.pconn4.ReadFromNetaddr(b)
if err != nil { if err != nil {
// If the pconn4 read failed, the likely reason is a DERP reader received
// a packet and interrupted us.
// It's possible for ReadFrom to return a non deadline exceeded error
// and for there to have also had a DERP packet arrive, but that's fine:
// we'll get the same error from ReadFrom later.
if c.derpPacketArrived() {
goto ReadDERP
}
return 0, nil, err return 0, nil, err
} }
if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4); ok { if ep, ok := c.receiveIP(b[:n], ipp, &c.ippEndpoint4); ok {
return n, ep, nil return n, ep, nil
} else {
continue
} }
ReadDERP:
n, ep, err = c.receiveIPv4DERP(b)
if err == errLoopAgain {
continue
}
return n, ep, err
} }
} }
@ -1693,8 +1619,7 @@ func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache)
if !c.havePrivateKey.Get() { if !c.havePrivateKey.Get() {
// If we have no private key, we're logged out or // If we have no private key, we're logged out or
// stopped. Don't try to pass these wireguard packets // stopped. Don't try to pass these wireguard packets
// up to wireguard-go; it'll just complain (Issue // up to wireguard-go; it'll just complain (issue 1167).
// 1167).
return nil, false return nil, false
} }
if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() { if cache.ipp == ipp && cache.de != nil && cache.gen == cache.de.numStopAndReset() {
@ -1714,50 +1639,42 @@ func (c *Conn) receiveIP(b []byte, ipp netaddr.IPPort, cache *ippEndpointCache)
return ep, true return ep, true
} }
var errLoopAgain = errors.New("received packet was not a wireguard-go packet or no endpoint found") // receiveDERP reads a packet from c.derpRecvCh into b and returns the associated endpoint.
// It is called by wireguard-go.
// receiveIPv4DERP reads a packet from c.derpRecvCh into b and returns the associated endpoint.
// //
// If the packet was a disco message or the peer endpoint wasn't // If the packet was a disco message or the peer endpoint wasn't
// found, the returned error is errLoopAgain. // found, the returned error is errLoopAgain.
func (c *Conn) receiveIPv4DERP(b []byte) (n int, ep conn.Endpoint, err error) { func (c *connBind) receiveDERP(b []byte) (n int, ep conn.Endpoint, err error) {
var dm derpReadResult for dm := range c.derpRecvCh {
select { if c.Closed() {
case <-c.donec: break
// Socket has been shut down. All the producers of packets
// respond to the context cancellation and go away, so we have
// to also unblock and return an error, to inform wireguard-go
// that this socket has gone away.
//
// Specifically, wireguard-go depends on its bind.Conn having
// the standard socket behavior, which is that a Close()
// unblocks any concurrent Read()s. wireguard-go itself calls
// Close() on magicsock, and expects ReceiveIPv4 to unblock
// with an error so it can clean up.
return 0, nil, errors.New("socket closed")
case dm = <-c.derpRecvCh:
// Below.
}
if atomic.AddInt64(&c.derpRecvCountAtomic, -1) == 0 {
c.pconn4.SetReadDeadline(time.Time{})
} }
if dm.copyBuf == nil { n, ep := c.processDERPReadResult(dm, b)
testCounterZeroDerpReadResultRecv.Add(1) if n == 0 {
return 0, nil, errLoopAgain // No data read occurred. Wait for another packet.
continue
}
return n, ep, nil
}
return 0, nil, net.ErrClosed
} }
func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep conn.Endpoint) {
if dm.copyBuf == nil {
return 0, nil
}
var regionID int var regionID int
n, regionID = dm.n, dm.regionID n, regionID = dm.n, dm.regionID
ncopy := dm.copyBuf(b) ncopy := dm.copyBuf(b)
if ncopy != n { if ncopy != n {
err = fmt.Errorf("received DERP packet of length %d that's too big for WireGuard ReceiveIPv4 buf size %d", n, ncopy) err := fmt.Errorf("received DERP packet of length %d that's too big for WireGuard buf size %d", n, ncopy)
c.logf("magicsock: %v", err) c.logf("magicsock: %v", err)
return 0, nil, err return 0, nil
} }
ipp := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)} ipp := netaddr.IPPort{IP: derpMagicIPAddr, Port: uint16(regionID)}
if c.handleDiscoMessage(b[:n], ipp) { if c.handleDiscoMessage(b[:n], ipp) {
return 0, nil, errLoopAgain return 0, nil
} }
var ( var (
@ -1799,14 +1716,14 @@ func (c *Conn) receiveIPv4DERP(b []byte) (n int, ep conn.Endpoint, err error) {
c.logf("magicsock: DERP packet from unknown key: %s", key.ShortString()) c.logf("magicsock: DERP packet from unknown key: %s", key.ShortString())
ep = c.findEndpoint(ipp, b[:n]) ep = c.findEndpoint(ipp, b[:n])
if ep == nil { if ep == nil {
return 0, nil, errLoopAgain return 0, nil
} }
} }
if !didNoteRecvActivity { if !didNoteRecvActivity {
c.noteRecvActivityFromEndpoint(ep) c.noteRecvActivityFromEndpoint(ep)
} }
return n, ep, nil return n, ep
} }
// discoLogLevel controls the verbosity of discovery log messages. // discoLogLevel controls the verbosity of discovery log messages.
@ -2468,8 +2385,86 @@ func (c *Conn) DERPs() int {
return len(c.activeDerp) return len(c.activeDerp)
} }
func (c *Conn) SetMark(value uint32) error { return nil } // Bind returns the wireguard-go conn.Bind for c.
func (c *Conn) LastMark() uint32 { return 0 } func (c *Conn) Bind() conn.Bind {
return c.bind
}
// connBind is a wireguard-go conn.Bind for a Conn.
//
// wireguard-go wants binds to be stateless.
// It wants to be able to Close and re-Open them cheaply.
// And Close must cause all receive functions to immediately return an error.
//
// Conns are very stateful.
// A connBind is intended to be a cheap, stateless abstraction over top of a Conn.
//
// connBind must implement the Close-unblocking.
// For DERP connections, it sends a zero value on the DERP channel;
// receiveDERP checks whether the connBind is closed on every iteration.
// For UDP connections, we push the implementation of cheap Close and Open to RebindingUDPConns.
// RebindingUDPConns have a "fake close", which allows them to close and unblock
// and then re-open without actually releasing any resources.
type connBind struct {
*Conn
mu sync.Mutex
closed bool
}
// Open is called by WireGuard to create a UDP binding.
// The ignoredPort comes from wireguard-go, via the wgcfg config.
// We ignore that port value here, since we have the local port available easily.
func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error) {
c.mu.Lock()
defer c.mu.Unlock()
if !c.closed {
return nil, 0, errors.New("magicsock: connBind already open")
}
c.closed = false
// Restore all receive calls.
c.pconn4.SetFakeClosed(false)
fns := []conn.ReceiveFunc{c.receiveIPv4, c.receiveDERP}
if c.pconn6 != nil {
c.pconn6.SetFakeClosed(false)
fns = append(fns, c.receiveIPv6)
}
// TODO: Combine receiveIPv4 and receiveIPv6 and receiveIP into a single
// closure that closes over a *RebindingUDPConn?
return fns, c.LocalPort(), nil
}
// SetMark is used by wireguard-go to set a mark bit for packets to avoid routing loops.
// We handle that ourselves elsewhere.
func (c *connBind) SetMark(value uint32) error {
return nil
}
func (c *connBind) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return nil
}
c.closed = true
// Unblock all outstanding receives.
c.pconn4.SetFakeClosed(true)
if c.pconn6 != nil {
c.pconn6.SetFakeClosed(true)
}
// Send an empty read result to unblock receiveDERP,
// which will then check connBind.Closed.
c.derpRecvCh <- derpReadResult{}
return nil
}
// Closed reports whether c is closed.
func (c *connBind) Closed() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.closed
}
// Close closes the connection. // Close closes the connection.
// //
@ -2731,12 +2726,7 @@ func packIPPort(ua netaddr.IPPort) []byte {
return b return b
} }
// CreateBind is called by WireGuard to create a UDP binding. // ParseEndpoint is called by WireGuard to connect to an endpoint.
func (c *Conn) CreateBind(uint16) (conn.Bind, uint16, error) {
return c, c.LocalPort(), nil
}
// CreateEndpoint is called by WireGuard to connect to an endpoint.
// //
// keyAddrs is the 32 byte public key of the peer followed by addrs. // keyAddrs is the 32 byte public key of the peer followed by addrs.
// Addrs is either: // Addrs is either:
@ -2745,9 +2735,9 @@ func (c *Conn) CreateBind(uint16) (conn.Bind, uint16, error) {
// 2) "<hex-discovery-key>.disco.tailscale:12345", a magic value that means the peer // 2) "<hex-discovery-key>.disco.tailscale:12345", a magic value that means the peer
// is running code that supports active discovery, so CreateEndpoint returns // is running code that supports active discovery, so CreateEndpoint returns
// a discoEndpoint. // a discoEndpoint.
func (c *Conn) CreateEndpoint(keyAddrs string) (conn.Endpoint, error) { func (c *Conn) ParseEndpoint(keyAddrs string) (conn.Endpoint, error) {
if len(keyAddrs) < 32 { if len(keyAddrs) < 32 {
c.logf("[unexpected] CreateEndpoint keyAddrs too short: %q", keyAddrs) c.logf("[unexpected] ParseEndpoint keyAddrs too short: %q", keyAddrs)
return nil, errors.New("endpoint string too short") return nil, errors.New("endpoint string too short")
} }
var pk key.Public var pk key.Public
@ -2756,7 +2746,7 @@ func (c *Conn) CreateEndpoint(keyAddrs string) (conn.Endpoint, error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.logf("magicsock: CreateEndpoint: key=%s: %s", pk.ShortString(), derpStr(addrs)) c.logf("magicsock: ParseEndpoint: key=%s: %s", pk.ShortString(), derpStr(addrs))
if !strings.HasSuffix(addrs, wgcfg.EndpointDiscoSuffix) { if !strings.HasSuffix(addrs, wgcfg.EndpointDiscoSuffix) {
return c.createLegacyEndpointLocked(pk, addrs) return c.createLegacyEndpointLocked(pk, addrs)
@ -2787,6 +2777,30 @@ func (c *Conn) CreateEndpoint(keyAddrs string) (conn.Endpoint, error) {
type RebindingUDPConn struct { type RebindingUDPConn struct {
mu sync.Mutex mu sync.Mutex
pconn net.PacketConn pconn net.PacketConn
fakeClosed bool // whether to pretend that the conn is closed; see type connBind
}
// currentConn returns c's current pconn and whether it is (fake) closed.
func (c *RebindingUDPConn) currentConn() (pconn net.PacketConn, fakeClosed bool) {
c.mu.Lock()
defer c.mu.Unlock()
return c.pconn, c.fakeClosed
}
// SetFakeClosed fake closes/opens c.
// Fake closing c unblocks all receives.
// See connBind for details about how this is used.
func (c *RebindingUDPConn) SetFakeClosed(b bool) {
c.mu.Lock()
defer c.mu.Unlock()
c.fakeClosed = b
if b {
// Unblock any existing reads so that they can discover that c is closed.
c.pconn.SetReadDeadline(aLongTimeAgo)
} else {
// Make reads blocking again.
c.pconn.SetReadDeadline(time.Time{})
}
} }
func (c *RebindingUDPConn) Reset(pconn net.PacketConn) { func (c *RebindingUDPConn) Reset(pconn net.PacketConn) {
@ -2804,16 +2818,17 @@ func (c *RebindingUDPConn) Reset(pconn net.PacketConn) {
// It returns the number of bytes copied and the source address. // It returns the number of bytes copied and the source address.
func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
for { for {
c.mu.Lock() pconn, closed := c.currentConn()
pconn := c.pconn if closed {
c.mu.Unlock() return 0, nil, net.ErrClosed
}
n, addr, err := pconn.ReadFrom(b) n, addr, err := pconn.ReadFrom(b)
if err != nil { if err != nil {
c.mu.Lock() pconn2, closed := c.currentConn()
pconn2 := c.pconn if closed {
c.mu.Unlock() return 0, nil, net.ErrClosed
}
if pconn != pconn2 { if pconn != pconn2 {
continue continue
} }
@ -2831,9 +2846,10 @@ func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
// when c's underlying connection is a net.UDPConn. // when c's underlying connection is a net.UDPConn.
func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort, err error) { func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort, err error) {
for { for {
c.mu.Lock() pconn, closed := c.currentConn()
pconn := c.pconn if closed {
c.mu.Unlock() return 0, netaddr.IPPort{}, net.ErrClosed
}
// Optimization: Treat *net.UDPConn specially. // Optimization: Treat *net.UDPConn specially.
// ReadFromUDP gets partially inlined, avoiding allocating a *net.UDPAddr, // ReadFromUDP gets partially inlined, avoiding allocating a *net.UDPAddr,
@ -2854,10 +2870,10 @@ func (c *RebindingUDPConn) ReadFromNetaddr(b []byte) (n int, ipp netaddr.IPPort,
} }
if err != nil { if err != nil {
c.mu.Lock() pconn2, closed := c.currentConn()
pconn2 := c.pconn if closed {
c.mu.Unlock() return 0, netaddr.IPPort{}, net.ErrClosed
}
if pconn != pconn2 { if pconn != pconn2 {
continue continue
} }
@ -2890,12 +2906,6 @@ func (c *RebindingUDPConn) Close() error {
return c.pconn.Close() return c.pconn.Close()
} }
func (c *RebindingUDPConn) SetReadDeadline(t time.Time) {
c.mu.Lock()
defer c.mu.Unlock()
c.pconn.SetReadDeadline(t)
}
func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { func (c *RebindingUDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
for { for {
c.mu.Lock() c.mu.Lock()

@ -22,7 +22,6 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
"unsafe" "unsafe"
@ -170,11 +169,7 @@ func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, der
tsTun.SetFilter(filter.NewAllowAllForTest(logf)) tsTun.SetFilter(filter.NewAllowAllForTest(logf))
wgLogger := wglog.NewLogger(logf) wgLogger := wglog.NewLogger(logf)
opts := &device.DeviceOptions{ dev := device.NewDevice(tsTun, conn.Bind(), wgLogger.DeviceLogger, new(device.DeviceOptions))
CreateEndpoint: conn.CreateEndpoint,
CreateBind: conn.CreateBind,
}
dev := device.NewDevice(tsTun, wgLogger.DeviceLogger, opts)
dev.Up() dev.Up()
// Wait for magicsock to connect up to DERP. // Wait for magicsock to connect up to DERP.
@ -363,7 +358,7 @@ func TestNewConn(t *testing.T) {
go func() { go func() {
var pkt [64 << 10]byte var pkt [64 << 10]byte
for { for {
_, _, err := conn.ReceiveIPv4(pkt[:]) _, _, err := conn.receiveIPv4(pkt[:])
if err != nil { if err != nil {
return return
} }
@ -521,11 +516,7 @@ func TestDeviceStartStop(t *testing.T) {
tun := tuntest.NewChannelTUN() tun := tuntest.NewChannelTUN()
wgLogger := wglog.NewLogger(t.Logf) wgLogger := wglog.NewLogger(t.Logf)
opts := &device.DeviceOptions{ dev := device.NewDevice(tun.TUN(), conn.Bind(), wgLogger.DeviceLogger, new(device.DeviceOptions))
CreateEndpoint: conn.CreateEndpoint,
CreateBind: conn.CreateBind,
}
dev := device.NewDevice(tun.TUN(), wgLogger.DeviceLogger, opts)
dev.Up() dev.Up()
dev.Close() dev.Close()
} }
@ -1382,14 +1373,10 @@ func stringifyConfig(cfg wgcfg.Config) string {
func Test32bitAlignment(t *testing.T) { func Test32bitAlignment(t *testing.T) {
var de discoEndpoint var de discoEndpoint
var c Conn
if off := unsafe.Offsetof(de.lastRecvUnixAtomic); off%8 != 0 { if off := unsafe.Offsetof(de.lastRecvUnixAtomic); off%8 != 0 {
t.Fatalf("discoEndpoint.lastRecvUnixAtomic is not 8-byte aligned") t.Fatalf("discoEndpoint.lastRecvUnixAtomic is not 8-byte aligned")
} }
if off := unsafe.Offsetof(c.derpRecvCountAtomic); off%8 != 0 {
t.Fatalf("Conn.derpRecvCountAtomic is not 8-byte aligned")
}
if !de.isFirstRecvActivityInAwhile() { // verify this doesn't panic on 32-bit if !de.isFirstRecvActivityInAwhile() { // verify this doesn't panic on 32-bit
t.Error("expected true") t.Error("expected true")
@ -1397,7 +1384,6 @@ func Test32bitAlignment(t *testing.T) {
if de.isFirstRecvActivityInAwhile() { if de.isFirstRecvActivityInAwhile() {
t.Error("expected false on second call") t.Error("expected false on second call")
} }
atomic.AddInt64(&c.derpRecvCountAtomic, 1)
} }
// newNonLegacyTestConn returns a new Conn with DisableLegacyNetworking set true. // newNonLegacyTestConn returns a new Conn with DisableLegacyNetworking set true.
@ -1418,92 +1404,6 @@ func newNonLegacyTestConn(t testing.TB) *Conn {
return conn return conn
} }
// Tests concurrent DERP readers pushing DERP data into ReceiveIPv4
// (which should blend all DERP reads into UDP reads).
func TestDerpReceiveFromIPv4(t *testing.T) {
conn := newNonLegacyTestConn(t)
defer conn.Close()
sendConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
defer sendConn.Close()
nodeKey, _ := addTestEndpoint(t, conn, sendConn)
var sends int = 250e3 // takes about a second
if testing.Short() {
sends /= 10
}
senders := runtime.NumCPU()
sends -= (sends % senders)
var wg sync.WaitGroup
defer wg.Wait()
t.Logf("doing %v sends over %d senders", sends, senders)
ctx, cancel := context.WithCancel(context.Background())
defer conn.Close()
defer cancel()
doneCtx, cancelDoneCtx := context.WithCancel(context.Background())
cancelDoneCtx()
for i := 0; i < senders; i++ {
wg.Add(1)
regionID := i + 1
go func() {
defer wg.Done()
for i := 0; i < sends/senders; i++ {
res := derpReadResult{
regionID: regionID,
n: 123,
src: key.Public(nodeKey),
copyBuf: func(dst []byte) int { return 123 },
}
// First send with the closed context. ~50% of
// these should end up going through the
// send-a-zero-derpReadResult path, returning
// true, in which case we don't want to send again.
// We test later that we hit the other path.
if conn.sendDerpReadResult(doneCtx, res) {
continue
}
if !conn.sendDerpReadResult(ctx, res) {
t.Error("unexpected false")
return
}
}
}()
}
zeroSendsStart := testCounterZeroDerpReadResultSend.Value()
buf := make([]byte, 1500)
for i := 0; i < sends; i++ {
n, ep, err := conn.ReceiveIPv4(buf)
if err != nil {
t.Fatal(err)
}
_ = n
_ = ep
}
t.Logf("did %d ReceiveIPv4 calls", sends)
zeroSends, zeroRecv := testCounterZeroDerpReadResultSend.Value(), testCounterZeroDerpReadResultRecv.Value()
if zeroSends != zeroRecv {
t.Errorf("did %d zero sends != %d corresponding receives", zeroSends, zeroRecv)
}
zeroSendDelta := zeroSends - zeroSendsStart
if zeroSendDelta == 0 {
t.Errorf("didn't see any sends of derpReadResult zero value")
}
if zeroSendDelta == int64(sends) {
t.Errorf("saw %v sends of the derpReadResult zero value which was unexpectedly high (100%% of our %v sends)", zeroSendDelta, sends)
}
}
// addTestEndpoint sets conn's network map to a single peer expected // addTestEndpoint sets conn's network map to a single peer expected
// to receive packets from sendConn (or DERP), and returns that peer's // to receive packets from sendConn (or DERP), and returns that peer's
// nodekey and discokey. // nodekey and discokey.
@ -1523,7 +1423,7 @@ func addTestEndpoint(tb testing.TB, conn *Conn, sendConn net.PacketConn) (tailcf
}, },
}) })
conn.SetPrivateKey(wgkey.Private{0: 1}) conn.SetPrivateKey(wgkey.Private{0: 1})
_, err := conn.CreateEndpoint(string(nodeKey[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345") _, err := conn.ParseEndpoint(string(nodeKey[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345")
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
@ -1554,7 +1454,7 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) {
if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil { if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil {
tb.Fatalf("WriteTo: %v", err) tb.Fatalf("WriteTo: %v", err)
} }
n, ep, err := conn.ReceiveIPv4(buf) n, ep, err := conn.receiveIPv4(buf)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
@ -1697,7 +1597,7 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) {
}, },
}, },
}) })
_, err := conn.CreateEndpoint(string(nodeKey1[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345") _, err := conn.ParseEndpoint(string(nodeKey1[:]) + "0000000000000000000000000000000000000000000000000000000000000001.disco.tailscale:12345")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1755,7 +1655,7 @@ func TestRebindStress(t *testing.T) {
go func() { go func() {
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
_, _, err := conn.ReceiveIPv4(buf) _, _, err := conn.receiveIPv4(buf)
if ctx.Err() != nil { if ctx.Err() != nil {
errc <- nil errc <- nil
return return

@ -320,8 +320,6 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
logf("[unexpected] peer %s has no single-IP routes: %v", peerWGKey.ShortString(), allowedIPs) logf("[unexpected] peer %s has no single-IP routes: %v", peerWGKey.ShortString(), allowedIPs)
} }
}, },
CreateBind: e.magicConn.CreateBind,
CreateEndpoint: e.magicConn.CreateEndpoint,
} }
e.tundev.OnTSMPPongReceived = func(pong packet.TSMPPongReply) { e.tundev.OnTSMPPongReceived = func(pong packet.TSMPPongReply) {
@ -336,8 +334,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
// wgdev takes ownership of tundev, will close it when closed. // wgdev takes ownership of tundev, will close it when closed.
e.logf("Creating wireguard device...") e.logf("Creating wireguard device...")
e.wgdev = device.NewDevice(e.tundev, e.wgLogger.DeviceLogger, opts) e.wgdev = device.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger, opts)
closePool.addFunc(e.wgdev.Close) closePool.addFunc(e.wgdev.Close)
closePool.addFunc(func() {
if err := e.magicConn.Close(); err != nil {
e.logf("error closing magicconn: %v", err)
}
})
go func() { go func() {
up := false up := false

@ -14,6 +14,7 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/device"
"github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/tun"
"inet.af/netaddr" "inet.af/netaddr"
@ -55,8 +56,8 @@ func TestDeviceConfig(t *testing.T) {
}}, }},
} }
device1 := device.NewDevice(newNilTun(), device.NewLogger(device.LogLevelError, "device1")) device1 := device.NewDevice(newNilTun(), conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "device1"))
device2 := device.NewDevice(newNilTun(), device.NewLogger(device.LogLevelError, "device2")) device2 := device.NewDevice(newNilTun(), conn.NewDefaultBind(), device.NewLogger(device.LogLevelError, "device2"))
defer device1.Close() defer device1.Close()
defer device2.Close() defer device2.Close()

Loading…
Cancel
Save