diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 9df0d4e1a..655c4ad46 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -181,6 +181,16 @@ func (t *Wrapper) Close() error { return err } +// isClosed reports whether t is closed. +func (t *Wrapper) isClosed() bool { + select { + case <-t.closed: + return true + default: + return false + } +} + // pumpEvents copies events from t.tdev to t.eventsUpDown and t.eventsOther. // pumpEvents exits when t.tdev.events or t.closed is closed. // pumpEvents closes t.eventsUpDown and t.eventsOther when it exits. @@ -266,28 +276,24 @@ func allowSendOnClosedChannel() { // so packets may be stuck in t.outbound if t.Read called t.tdev.Read directly. func (t *Wrapper) poll() { defer allowSendOnClosedChannel() // for send to t.outbound - for { - <-t.bufferConsumed - + for range t.bufferConsumed { + var n int + var err error // Read may use memory in t.buffer before PacketStartOffset for mandatory headers. // This is the rationale behind the tun.Wrapper.{Read,Write} interfaces // and the reason t.buffer has size MaxMessageSize and not MaxContentSize. - n, err := t.tdev.Read(t.buffer[:], PacketStartOffset) - if err != nil { - t.outbound <- tunReadResult{err: err} - // In principle, read errors are not fatal (but wireguard-go disagrees). - t.bufferConsumed <- struct{}{} - continue - } - - // Wireguard will skip an empty read, - // so we might as well do it here to avoid the send through t.outbound. - if n == 0 { - t.bufferConsumed <- struct{}{} - continue + // In principle, read errors are not fatal (but wireguard-go disagrees). + // We loop here until we get a non-empty (or failed) read. + // We don't need this loop for correctness, + // but wireguard-go will skip an empty read, + // so we might as well avoid the send through t.outbound. + for n == 0 && err == nil { + if t.isClosed() { + return + } + n, err = t.tdev.Read(t.buffer[:], PacketStartOffset) } - - t.outbound <- tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n]} + t.outbound <- tunReadResult{data: t.buffer[PacketStartOffset : PacketStartOffset+n], err: err} } }