diff --git a/net/nettest/conn.go b/net/nettest/conn.go new file mode 100644 index 000000000..9cbbd3e34 --- /dev/null +++ b/net/nettest/conn.go @@ -0,0 +1,86 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "io" + "time" +) + +// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn. +type Conn interface { + io.Reader + io.Writer + io.Closer + + // The *Deadline methods follow the semantics of net.Conn. + + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + + // SetReadBlock blocks or unblocks the Read method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetReadBlock(bool) error + + // SetWriteBlock blocks or unblocks the Write method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetWriteBlock(bool) error +} + +// NewConn creates a pair of Conns that are wired together by pipes. +func NewConn(name string, maxBuf int) (Conn, Conn) { + r := NewPipe(name+"|0", maxBuf) + w := NewPipe(name+"|1", maxBuf) + + return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} +} + +type connHalf struct { + r, w *Pipe +} + +func (c *connHalf) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} +func (c *connHalf) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} +func (c *connHalf) Close() error { + err1 := c.r.Close() + err2 := c.w.Close() + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetReadDeadline(t time.Time) error { + return c.r.SetReadDeadline(t) +} +func (c *connHalf) SetWriteDeadline(t time.Time) error { + return c.w.SetWriteDeadline(t) +} +func (c *connHalf) SetReadBlock(b bool) error { + if b { + return c.r.Block() + } + return c.r.Unblock() +} +func (c *connHalf) SetWriteBlock(b bool) error { + if b { + return c.w.Block() + } + return c.w.Unblock() +} diff --git a/net/nettest/pipe.go b/net/nettest/pipe.go new file mode 100644 index 000000000..e4c6e18ac --- /dev/null +++ b/net/nettest/pipe.go @@ -0,0 +1,261 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "sync" + "time" +) + +const debugPipe = false + +// Pipe implements an in-memory FIFO with timeouts. +type Pipe struct { + name string + maxBuf int + rCh chan struct{} + wCh chan struct{} + + mu sync.Mutex + closed bool + blocked bool + buf []byte + readTimeout time.Time + writeTimeout time.Time + cancelReadTimer func() + cancelWriteTimer func() +} + +// NewPipe creates a Pipe with a buffer size fixed at maxBuf. +func NewPipe(name string, maxBuf int) *Pipe { + return &Pipe{ + name: name, + maxBuf: maxBuf, + rCh: make(chan struct{}, 1), + wCh: make(chan struct{}, 1), + } +} + +var ( + ErrTimeout = errors.New("timeout") + ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout) + ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout) +) + +// Read implements io.Reader. +func (p *Pipe) Read(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err) + }() + } + for { + p.mu.Lock() + closed := p.closed + timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout) + blocked := p.blocked + if !closed && !timedout && len(p.buf) > 0 { + n2 := copy(b, p.buf) + p.buf = p.buf[n2:] + b = b[n2:] + n += n2 + } + p.mu.Unlock() + + if closed { + return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF) + } + if timedout { + return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout) + } + if blocked { + <-p.rCh + continue + } + if n > 0 { + p.signalWrite() + return n, nil + } + <-p.rCh + } +} + +// Write implements io.Writer. +func (p *Pipe) Write(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) + }() + } + for { + p.mu.Lock() + closed := p.closed + timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout) + blocked := p.blocked + if !closed && !timedout { + n2 := len(b) + if limit := p.maxBuf - len(p.buf); limit < n2 { + n2 = limit + } + p.buf = append(p.buf, b[:n2]...) + b = b[n2:] + n += n2 + } + p.mu.Unlock() + + if closed { + return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF) + } + if timedout { + return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout) + } + if blocked { + <-p.wCh + continue + } + if n > 0 { + p.signalRead() + } + if len(b) == 0 { + return n, nil + } + <-p.wCh + } +} + +// Close implements io.Closer. +func (p *Pipe) Close() error { + p.mu.Lock() + closed := p.closed + p.closed = true + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.mu.Unlock() + + if closed { + return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name) + } + + p.signalRead() + p.signalWrite() + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (p *Pipe) SetReadDeadline(t time.Time) error { + p.mu.Lock() + p.readTimeout = t + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + if d := time.Until(t); !t.IsZero() && d > 0 { + ctx, cancel := context.WithCancel(context.Background()) + p.cancelReadTimer = cancel + go func() { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-t.C: + p.signalRead() + case <-ctx.Done(): + } + }() + } + p.mu.Unlock() + + p.signalRead() + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (p *Pipe) SetWriteDeadline(t time.Time) error { + p.mu.Lock() + p.writeTimeout = t + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if d := time.Until(t); !t.IsZero() && d > 0 { + ctx, cancel := context.WithCancel(context.Background()) + p.cancelWriteTimer = cancel + go func() { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-t.C: + p.signalWrite() + case <-ctx.Done(): + } + }() + } + p.mu.Unlock() + + p.signalWrite() + return nil +} + +func (p *Pipe) Block() error { + p.mu.Lock() + closed := p.closed + blocked := p.blocked + p.blocked = true + p.mu.Unlock() + + if closed { + return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) + } + if blocked { + return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name) + } + p.signalRead() + p.signalWrite() + return nil +} + +func (p *Pipe) Unblock() error { + p.mu.Lock() + closed := p.closed + blocked := p.blocked + p.blocked = false + p.mu.Unlock() + + if closed { + return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) + } + if !blocked { + return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name) + } + p.signalRead() + p.signalWrite() + return nil +} + +func (p *Pipe) signalRead() { + select { + case p.rCh <- struct{}{}: + default: + } +} + +func (p *Pipe) signalWrite() { + select { + case p.wCh <- struct{}{}: + default: + } +} diff --git a/net/nettest/pipe_test.go b/net/nettest/pipe_test.go new file mode 100644 index 000000000..f40d27c53 --- /dev/null +++ b/net/nettest/pipe_test.go @@ -0,0 +1,116 @@ +package nettest + +import ( + "errors" + "fmt" + "testing" + "time" +) + +func TestPipeHello(t *testing.T) { + p := NewPipe("p1", 1<<16) + msg := "Hello, World!" + if n, err := p.Write([]byte(msg)); err != nil { + t.Fatal(err) + } else if n != len(msg) { + t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) + } + b := make([]byte, len(msg)) + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != len(b) { + t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) + } + if got := string(b); got != msg { + t.Errorf("p.Read: %q, want %q", got, msg) + } +} + +func TestPipeTimeout(t *testing.T) { + t.Run("write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + n, err := p.Write([]byte{'h'}) + if err == nil || !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) { + t.Errorf("missing write timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h'}) + + p.SetReadDeadline(time.Now().Add(-1 * time.Second)) + b := make([]byte, 1) + n, err := p.Read(b) + if err == nil || !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) { + t.Errorf("missing read timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("block-write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if _, err := p.Write([]byte{'h'}); err != nil { + t.Fatal(err) + } + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Write([]byte{'h'}); err == nil || !errors.Is(err, ErrWriteTimeout) { + t.Fatalf("want write timeout got: %v", err) + } + }) + t.Run("block-read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h', 'i'}) + p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + b := make([]byte, 1) + if _, err := p.Read(b); err != nil { + t.Fatal(err) + } + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Read(b); err == nil || !errors.Is(err, ErrReadTimeout) { + t.Fatalf("want read timeout got: %v", err) + } + }) + +} + +func TestLimit(t *testing.T) { + p := NewPipe("p1", 1) + errCh := make(chan error) + go func() { + n, err := p.Write([]byte{'a', 'b', 'c'}) + if err != nil { + errCh <- err + } else if n != 3 { + errCh <- fmt.Errorf("p.Write n=%d, want 3", n) + } else { + errCh <- nil + } + }() + b := make([]byte, 3) + + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } +}