diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index e2e3b8066..1ce098809 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -3258,19 +3258,23 @@ func (c *RebindingUDPConn) currentConn() nettype.PacketConn { return c.pconn } -// ReadFrom reads a packet from c into b. -// It returns the number of bytes copied and the source address. -func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { +func (c *RebindingUDPConn) readFromWithInitPconn(pconn nettype.PacketConn, b []byte) (int, net.Addr, error) { for { - pconn := *c.pconnAtomic.Load() n, addr, err := pconn.ReadFrom(b) if err != nil && pconn != c.currentConn() { + pconn = *c.pconnAtomic.Load() continue } return n, addr, err } } +// ReadFrom reads a packet from c into b. +// It returns the number of bytes copied and the source address. +func (c *RebindingUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + return c.readFromWithInitPconn(*c.pconnAtomic.Load(), b) +} + // ReadFromNetaddr reads a packet from c into b. // It returns the number of bytes copied and the return address. // It is identical to c.ReadFrom, except that it returns a netip.AddrPort instead of a net.Addr. @@ -3321,7 +3325,7 @@ func (c *RebindingUDPConn) WriteBatch(msgs []ipv6.Message, flags int) (int, erro bw, ok := pconn.(batchWriter) if !ok { for _, msg := range msgs { - _, err = pconn.WriteTo(msg.Buffers[0], msg.Addr) + _, err = c.writeToWithInitPconn(pconn, msg.Buffers[0], msg.Addr) if err != nil { return n, err } @@ -3350,7 +3354,7 @@ func (c *RebindingUDPConn) ReadBatch(msgs []ipv6.Message, flags int) (int, error br, ok := pconn.(batchReader) if !ok { var err error - msgs[0].N, msgs[0].Addr, err = c.ReadFrom(msgs[0].Buffers[0]) + msgs[0].N, msgs[0].Addr, err = c.readFromWithInitPconn(pconn, msgs[0].Buffers[0]) if err == nil { return 1, nil } @@ -3398,17 +3402,21 @@ func (c *RebindingUDPConn) closeLocked() error { return c.pconn.Close() } -func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { +func (c *RebindingUDPConn) writeToWithInitPconn(pconn nettype.PacketConn, b []byte, addr net.Addr) (int, error) { for { - pconn := *c.pconnAtomic.Load() n, err := pconn.WriteTo(b, addr) if err != nil && pconn != c.currentConn() { + pconn = *c.pconnAtomic.Load() continue } return n, err } } +func (c *RebindingUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + return c.writeToWithInitPconn(*c.pconnAtomic.Load(), b, addr) +} + func (c *RebindingUDPConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) { for { pconn := *c.pconnAtomic.Load()