wgengine/magicsock: factor out receiveIPv4 & receiveIPv6 common code

Updates #2331

Change-Id: I801df38b217f5d17203e8dc3b8654f44747e0f4b
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/7870/head
Brad Fitzpatrick 1 year ago committed by Brad Fitzpatrick
parent c889254b42
commit 6866aaeab3

@ -322,11 +322,6 @@ type Conn struct {
// bind is the wireguard-go conn.Bind for Conn.
bind *connBind
// ippEndpoint4 and ippEndpoint6 are owned by receiveIPv4 and
// receiveIPv6, respectively, to cache an IPPort->endpoint for
// hot flows.
ippEndpoint4, ippEndpoint6 ippEndpointCache
// ============================================================
// Fields that must be accessed via atomic load/stores.
@ -1851,80 +1846,64 @@ func (c *Conn) putReceiveBatch(batch *receiveBatch) {
c.receiveBatchPool.Put(batch)
}
func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
health.ReceiveIPv6.Enter()
defer health.ReceiveIPv6.Exit()
// receiveIPv4 creates an IPv4 ReceiveFunc reading from c.pconn4.
func (c *Conn) receiveIPv4() conn.ReceiveFunc {
return c.mkReceiveFunc(&c.pconn4, &health.ReceiveIPv4, metricRecvDataIPv4)
}
batch := c.getReceiveBatchForBuffs(buffs)
defer c.putReceiveBatch(batch)
for {
numMsgs, err := c.pconn6.ReadBatch(batch.msgs[:len(buffs)], 0)
if err != nil {
if neterror.PacketWasTruncated(err) {
// TODO(raggi): discuss whether to log?
continue
}
return 0, err
}
// receiveIPv6 creates an IPv6 ReceiveFunc reading from c.pconn6.
func (c *Conn) receiveIPv6() conn.ReceiveFunc {
return c.mkReceiveFunc(&c.pconn6, &health.ReceiveIPv6, metricRecvDataIPv6)
}
reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] {
if msg.N == 0 {
sizes[i] = 0
continue
}
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok {
metricRecvDataIPv6.Add(1)
eps[i] = ep
sizes[i] = msg.N
reportToCaller = true
} else {
sizes[i] = 0
}
}
// mkReceiveFunc creates a ReceiveFunc reading from ruc.
// The provided healthItem and metric are updated if non-nil.
func (c *Conn) mkReceiveFunc(ruc *RebindingUDPConn, healthItem *health.ReceiveFuncStats, metric *clientmetric.Metric) conn.ReceiveFunc {
// epCache caches an IPPort->endpoint for hot flows.
var epCache ippEndpointCache
if reportToCaller {
return numMsgs, nil
return func(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
if healthItem != nil {
healthItem.Enter()
defer healthItem.Exit()
}
if ruc == nil {
panic("nil RebindingUDPConn")
}
}
}
func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (int, error) {
health.ReceiveIPv4.Enter()
defer health.ReceiveIPv4.Exit()
batch := c.getReceiveBatchForBuffs(buffs)
defer c.putReceiveBatch(batch)
for {
numMsgs, err := c.pconn4.ReadBatch(batch.msgs[:len(buffs)], 0)
if err != nil {
if neterror.PacketWasTruncated(err) {
// TODO(raggi): discuss whether to log?
continue
batch := c.getReceiveBatchForBuffs(buffs)
defer c.putReceiveBatch(batch)
for {
numMsgs, err := ruc.ReadBatch(batch.msgs[:len(buffs)], 0)
if err != nil {
if neterror.PacketWasTruncated(err) {
continue
}
return 0, err
}
return 0, err
}
reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] {
if msg.N == 0 {
sizes[i] = 0
continue
reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] {
if msg.N == 0 {
sizes[i] = 0
continue
}
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &epCache); ok {
if metric != nil {
metric.Add(1)
}
eps[i] = ep
sizes[i] = msg.N
reportToCaller = true
} else {
sizes[i] = 0
}
}
ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok {
metricRecvDataIPv4.Add(1)
eps[i] = ep
sizes[i] = msg.N
reportToCaller = true
} else {
sizes[i] = 0
if reportToCaller {
return numMsgs, nil
}
}
if reportToCaller {
return numMsgs, nil
}
}
}
@ -3044,7 +3023,7 @@ func (c *connBind) Open(ignoredPort uint16) ([]conn.ReceiveFunc, uint16, error)
return nil, 0, errors.New("magicsock: connBind already open")
}
c.closed = false
fns := []conn.ReceiveFunc{c.receiveIPv4, c.receiveIPv6, c.receiveDERP}
fns := []conn.ReceiveFunc{c.receiveIPv4(), c.receiveIPv6(), c.receiveDERP}
if runtime.GOOS == "js" {
fns = []conn.ReceiveFunc{c.receiveDERP}
}

@ -374,8 +374,9 @@ func TestNewConn(t *testing.T) {
sizes := make([]int, 1)
eps := make([]wgconn.Endpoint, 1)
pkts[0] = make([]byte, 64<<10)
receiveIPv4 := conn.receiveIPv4()
for {
_, err := conn.receiveIPv4(pkts, sizes, eps)
_, err := receiveIPv4(pkts, sizes, eps)
if err != nil {
return
}
@ -1284,11 +1285,12 @@ func setUpReceiveFrom(tb testing.TB) (roundTrip func()) {
buffs[0] = make([]byte, 2<<10)
sizes := make([]int, 1)
eps := make([]wgconn.Endpoint, 1)
receiveIPv4 := conn.receiveIPv4()
return func() {
if _, err := sendConn.WriteTo(sendBuf, dstAddr); err != nil {
tb.Fatalf("WriteTo: %v", err)
}
n, err := conn.receiveIPv4(buffs, sizes, eps)
n, err := receiveIPv4(buffs, sizes, eps)
if err != nil {
tb.Fatal(err)
}
@ -1513,8 +1515,9 @@ func TestRebindStress(t *testing.T) {
sizes := make([]int, 1)
eps := make([]wgconn.Endpoint, 1)
buffs[0] = make([]byte, 1500)
receiveIPv4 := conn.receiveIPv4()
for {
_, err := conn.receiveIPv4(buffs, sizes, eps)
_, err := receiveIPv4(buffs, sizes, eps)
if ctx.Err() != nil {
errc <- nil
return

Loading…
Cancel
Save