wgengine/magicsock: factor out receiveIPv4 & receiveIPv6 common code

Updates #2331

Change-Id: I801df38b217f5d17203e8dc3b8654f44747e0f4b
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/7870/head
Brad Fitzpatrick 2 years ago committed by Brad Fitzpatrick
parent c889254b42
commit 6866aaeab3

@ -322,11 +322,6 @@ type Conn struct {
// bind is the wireguard-go conn.Bind for Conn. // bind is the wireguard-go conn.Bind for Conn.
bind *connBind bind *connBind
// ippEndpoint4 and ippEndpoint6 are owned by receiveIPv4 and
// receiveIPv6, respectively, to cache an IPPort->endpoint for
// hot flows.
ippEndpoint4, ippEndpoint6 ippEndpointCache
// ============================================================ // ============================================================
// Fields that must be accessed via atomic load/stores. // Fields that must be accessed via atomic load/stores.
@ -1851,56 +1846,37 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) {
c.receiveBatchPool.Put(batch) c.receiveBatchPool.Put(batch)
} }
func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) { // receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4.
health.ReceiveIPv6.Enter() func (c *Conn) receiveIPv4() conn.ReceiveFunc {
defer health.ReceiveIPv6.Exit() return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4)
}
batch := c.getReceiveBatchForBuffs(buffs) // receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6.
defer c.putReceiveBatch(batch) func (c *Conn) receiveIPv6() conn.ReceiveFunc {
for { return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6)
numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0) }
if err != nil {
if neterror.PacketWasTruncated(err) {
// TODO(raggi): discuss whether to log?
continue
}
return 0, err
}
reportToCaller := false // mkReceiveFunc creates a ReceiveFunc reading from ruc.
for i, msg := range batch.msgs[:numMsgs] { // The provided healthItem and metric are updated if non-nil.
if msg.N == 0 { func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc {
sizes[i] = 0 // epCache caches an IPPort->endpoint for hot flows.
continue var epCache ippEndpointCache
}
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok {
metricRecvDataIPv6.Add(1)
eps[i] = ep
sizes[i] = msg.N
reportToCaller = true
} else {
sizes[i] = 0
}
}
if reportToCaller { return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
return numMsgs, nil if healthItem != nil {
healthItem.Enter()
defer healthItem.Exit()
} }
if ruc == nil {
panic("nil RebindingUDPConn")
} }
}
func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
health.ReceiveIPv4.Enter()
defer health.ReceiveIPv4.Exit()
batch := c.getReceiveBatchForBuffs(buffs) batch := c.getReceiveBatchForBuffs(buffs)
defer c.putReceiveBatch(batch) defer c.putReceiveBatch(batch)
for { for {
numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0) numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0)
if err != nil { if err != nil {
if neterror.PacketWasTruncated(err) { if neterror.PacketWasTruncated(err) {
// TODO(raggi): discuss whether to log?
continue continue
} }
return 0, err return 0, err
@ -1913,8 +1889,10 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
continue continue
} }
ipp := msg.Addr.(*net.UDPAddr).AddrPort() ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok { if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok {
metricRecvDataIPv4.Add(1) if metric != nil {
metric.Add(1)
}
eps[i] = ep eps[i] = ep
sizes[i] = msg.N sizes[i] = msg.N
reportToCaller = true reportToCaller = true
@ -1926,6 +1904,7 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
return numMsgs, nil return numMsgs, nil
} }
} }
}
} }
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
@ -3044,7 +3023,7 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error)
return nil, 0, errors.New("magicsock: connBind already open") return nil, 0, errors.New("magicsock: connBind already open")
} }
c.closed = false c.closed = false
fns := []conn.ReceiveFunc{c.receiveIPv4, c.receiveIPv6, c.receiveDERP} fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP}
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
fns = []conn.ReceiveFunc{c.receiveDERP} fns = []conn.ReceiveFunc{c.receiveDERP}
} }

@ -374,8 +374,9 @@ func TestNewConn(t *testing.T) {
sizes := make([]int, 1) sizes := make([]int, 1)
eps := make([]wgconn.Endpoint, 1) eps := make([]wgconn.Endpoint, 1)
pkts[0] = make([]byte, 64<<10) pkts[0] = make([]byte, 64<<10)
receiveIPv4 := conn.receiveIPv4()
for { for {
_, err := conn.receiveIPv4(pkts, sizes, eps) _, err := receiveIPv4(pkts, sizes, eps)
if err != nil { if err != nil {
return return
} }
@ -1284,11 +1285,12 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) {
buffs[0] = make([]byte, 2<<10) buffs[0] = make([]byte, 2<<10)
sizes := make([]int, 1) sizes := make([]int, 1)
eps := make([]wgconn.Endpoint, 1) eps := make([]wgconn.Endpoint, 1)
receiveIPv4 := conn.receiveIPv4()
return func() { return func() {
if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil { if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil {
tb.Fatalf("WriteTo: %v", err) tb.Fatalf("WriteTo: %v", err)
} }
n, err := conn.receiveIPv4(buffs, sizes, eps) n, err := receiveIPv4(buffs, sizes, eps)
if err != nil { if err != nil {
tb.Fatal(err) tb.Fatal(err)
} }
@ -1513,8 +1515,9 @@ func TestRebindStress(t *testing.T) {
sizes := make([]int, 1) sizes := make([]int, 1)
eps := make([]wgconn.Endpoint, 1) eps := make([]wgconn.Endpoint, 1)
buffs[0] = make([]byte, 1500) buffs[0] = make([]byte, 1500)
receiveIPv4 := conn.receiveIPv4()
for { for {
_, err := conn.receiveIPv4(buffs, sizes, eps) _, err := receiveIPv4(buffs, sizes, eps)
if ctx.Err() != nil { if ctx.Err() != nil {
errc <- nil errc <- nil
return return

Loading…
Cancel
Save