You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
tailscale/net/nettest/pipe.go

262 lines
4.9 KiB
Go

// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package nettest
import (
"context"
"errors"
"fmt"
"io"
"log"
"sync"
"time"
)
const debugPipe = false
// Pipe implements an in-memory FIFO with timeouts.
type Pipe struct {
name string
maxBuf int
rCh chan struct{}
wCh chan struct{}
mu sync.Mutex
closed bool
blocked bool
buf []byte
readTimeout time.Time
writeTimeout time.Time
cancelReadTimer func()
cancelWriteTimer func()
}
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
func NewPipe(name string, maxBuf int) *Pipe {
return &Pipe{
name: name,
maxBuf: maxBuf,
rCh: make(chan struct{}, 1),
wCh: make(chan struct{}, 1),
}
}
var (
ErrTimeout = errors.New("timeout")
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
)
// Read implements io.Reader.
func (p *Pipe) Read(b []byte) (n int, err error) {
if debugPipe {
orig := b
defer func() {
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
}()
}
for {
p.mu.Lock()
closed := p.closed
timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
blocked := p.blocked
if !closed && !timedout && len(p.buf) > 0 {
n2 := copy(b, p.buf)
p.buf = p.buf[n2:]
b = b[n2:]
n += n2
}
p.mu.Unlock()
if closed {
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
}
if timedout {
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
}
if blocked {
<-p.rCh
continue
}
if n > 0 {
p.signalWrite()
return n, nil
}
<-p.rCh
}
}
// Write implements io.Writer.
func (p *Pipe) Write(b []byte) (n int, err error) {
if debugPipe {
orig := b
defer func() {
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
}()
}
for {
p.mu.Lock()
closed := p.closed
timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
blocked := p.blocked
if !closed && !timedout {
n2 := len(b)
if limit := p.maxBuf - len(p.buf); limit < n2 {
n2 = limit
}
p.buf = append(p.buf, b[:n2]...)
b = b[n2:]
n += n2
}
p.mu.Unlock()
if closed {
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
}
if timedout {
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
}
if blocked {
<-p.wCh
continue
}
if n > 0 {
p.signalRead()
}
if len(b) == 0 {
return n, nil
}
<-p.wCh
}
}
// Close implements io.Closer.
func (p *Pipe) Close() error {
p.mu.Lock()
closed := p.closed
p.closed = true
if p.cancelWriteTimer != nil {
p.cancelWriteTimer()
p.cancelWriteTimer = nil
}
if p.cancelReadTimer != nil {
p.cancelReadTimer()
p.cancelReadTimer = nil
}
p.mu.Unlock()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
}
p.signalRead()
p.signalWrite()
return nil
}
// SetReadDeadline sets the deadline for future Read calls.
func (p *Pipe) SetReadDeadline(t time.Time) error {
p.mu.Lock()
p.readTimeout = t
if p.cancelReadTimer != nil {
p.cancelReadTimer()
p.cancelReadTimer = nil
}
if d := time.Until(t); !t.IsZero() && d > 0 {
ctx, cancel := context.WithCancel(context.Background())
p.cancelReadTimer = cancel
go func() {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-t.C:
p.signalRead()
case <-ctx.Done():
}
}()
}
p.mu.Unlock()
p.signalRead()
return nil
}
// SetWriteDeadline sets the deadline for future Write calls.
func (p *Pipe) SetWriteDeadline(t time.Time) error {
p.mu.Lock()
p.writeTimeout = t
if p.cancelWriteTimer != nil {
p.cancelWriteTimer()
p.cancelWriteTimer = nil
}
if d := time.Until(t); !t.IsZero() && d > 0 {
ctx, cancel := context.WithCancel(context.Background())
p.cancelWriteTimer = cancel
go func() {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-t.C:
p.signalWrite()
case <-ctx.Done():
}
}()
}
p.mu.Unlock()
p.signalWrite()
return nil
}
func (p *Pipe) Block() error {
p.mu.Lock()
closed := p.closed
blocked := p.blocked
p.blocked = true
p.mu.Unlock()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
}
if blocked {
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
}
p.signalRead()
p.signalWrite()
return nil
}
func (p *Pipe) Unblock() error {
p.mu.Lock()
closed := p.closed
blocked := p.blocked
p.blocked = false
p.mu.Unlock()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
}
if !blocked {
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
}
p.signalRead()
p.signalWrite()
return nil
}
func (p *Pipe) signalRead() {
select {
case p.rCh <- struct{}{}:
default:
}
}
func (p *Pipe) signalWrite() {
select {
case p.wCh <- struct{}{}:
default:
}
}