From 41ac4a79d65a6834f64d05b589d2958a21f240aa Mon Sep 17 00:00:00 2001 From: David Crawshaw Date: Thu, 12 Mar 2020 11:01:58 -0400 Subject: [PATCH] net/nettest: new package with net-like testing primitives This is a lot like wiring up a local UDP socket, read and write deadlines work. The big difference is the Block feature, which lets you stop the packet flow without breaking the connection. This lets you emulate broken sockets and test timeouts actually work. Signed-off-by: David Crawshaw --- net/nettest/conn.go | 86 +++++++++++++ net/nettest/pipe.go | 261 +++++++++++++++++++++++++++++++++++++++ net/nettest/pipe_test.go | 116 +++++++++++++++++ 3 files changed, 463 insertions(+) create mode 100644 net/nettest/conn.go create mode 100644 net/nettest/pipe.go create mode 100644 net/nettest/pipe_test.go diff --git a/net/nettest/conn.go b/net/nettest/conn.go new file mode 100644 index 000000000..9cbbd3e34 --- /dev/null +++ b/net/nettest/conn.go @@ -0,0 +1,86 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "io" + "time" +) + +// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn. +type Conn interface { + io.Reader + io.Writer + io.Closer + + // The *Deadline methods follow the semantics of net.Conn. + + SetDeadline(t time.Time) error + SetReadDeadline(t time.Time) error + SetWriteDeadline(t time.Time) error + + // SetReadBlock blocks or unblocks the Read method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetReadBlock(bool) error + + // SetWriteBlock blocks or unblocks the Write method of this Conn. + // It reports an error if the existing value matches the new value, + // or if the Conn has been Closed. + SetWriteBlock(bool) error +} + +// NewConn creates a pair of Conns that are wired together by pipes. +func NewConn(name string, maxBuf int) (Conn, Conn) { + r := NewPipe(name+"|0", maxBuf) + w := NewPipe(name+"|1", maxBuf) + + return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} +} + +type connHalf struct { + r, w *Pipe +} + +func (c *connHalf) Read(b []byte) (n int, err error) { + return c.r.Read(b) +} +func (c *connHalf) Write(b []byte) (n int, err error) { + return c.w.Write(b) +} +func (c *connHalf) Close() error { + err1 := c.r.Close() + err2 := c.w.Close() + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetDeadline(t time.Time) error { + err1 := c.SetReadDeadline(t) + err2 := c.SetWriteDeadline(t) + if err1 != nil { + return err1 + } + return err2 +} +func (c *connHalf) SetReadDeadline(t time.Time) error { + return c.r.SetReadDeadline(t) +} +func (c *connHalf) SetWriteDeadline(t time.Time) error { + return c.w.SetWriteDeadline(t) +} +func (c *connHalf) SetReadBlock(b bool) error { + if b { + return c.r.Block() + } + return c.r.Unblock() +} +func (c *connHalf) SetWriteBlock(b bool) error { + if b { + return c.w.Block() + } + return c.w.Unblock() +} diff --git a/net/nettest/pipe.go b/net/nettest/pipe.go new file mode 100644 index 000000000..e4c6e18ac --- /dev/null +++ b/net/nettest/pipe.go @@ -0,0 +1,261 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "sync" + "time" +) + +const debugPipe = false + +// Pipe implements an in-memory FIFO with timeouts. +type Pipe struct { + name string + maxBuf int + rCh chan struct{} + wCh chan struct{} + + mu sync.Mutex + closed bool + blocked bool + buf []byte + readTimeout time.Time + writeTimeout time.Time + cancelReadTimer func() + cancelWriteTimer func() +} + +// NewPipe creates a Pipe with a buffer size fixed at maxBuf. +func NewPipe(name string, maxBuf int) *Pipe { + return &Pipe{ + name: name, + maxBuf: maxBuf, + rCh: make(chan struct{}, 1), + wCh: make(chan struct{}, 1), + } +} + +var ( + ErrTimeout = errors.New("timeout") + ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout) + ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout) +) + +// Read implements io.Reader. +func (p *Pipe) Read(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err) + }() + } + for { + p.mu.Lock() + closed := p.closed + timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout) + blocked := p.blocked + if !closed && !timedout && len(p.buf) > 0 { + n2 := copy(b, p.buf) + p.buf = p.buf[n2:] + b = b[n2:] + n += n2 + } + p.mu.Unlock() + + if closed { + return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF) + } + if timedout { + return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout) + } + if blocked { + <-p.rCh + continue + } + if n > 0 { + p.signalWrite() + return n, nil + } + <-p.rCh + } +} + +// Write implements io.Writer. +func (p *Pipe) Write(b []byte) (n int, err error) { + if debugPipe { + orig := b + defer func() { + log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) + }() + } + for { + p.mu.Lock() + closed := p.closed + timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout) + blocked := p.blocked + if !closed && !timedout { + n2 := len(b) + if limit := p.maxBuf - len(p.buf); limit < n2 { + n2 = limit + } + p.buf = append(p.buf, b[:n2]...) + b = b[n2:] + n += n2 + } + p.mu.Unlock() + + if closed { + return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF) + } + if timedout { + return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout) + } + if blocked { + <-p.wCh + continue + } + if n > 0 { + p.signalRead() + } + if len(b) == 0 { + return n, nil + } + <-p.wCh + } +} + +// Close implements io.Closer. +func (p *Pipe) Close() error { + p.mu.Lock() + closed := p.closed + p.closed = true + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + p.mu.Unlock() + + if closed { + return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name) + } + + p.signalRead() + p.signalWrite() + return nil +} + +// SetReadDeadline sets the deadline for future Read calls. +func (p *Pipe) SetReadDeadline(t time.Time) error { + p.mu.Lock() + p.readTimeout = t + if p.cancelReadTimer != nil { + p.cancelReadTimer() + p.cancelReadTimer = nil + } + if d := time.Until(t); !t.IsZero() && d > 0 { + ctx, cancel := context.WithCancel(context.Background()) + p.cancelReadTimer = cancel + go func() { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-t.C: + p.signalRead() + case <-ctx.Done(): + } + }() + } + p.mu.Unlock() + + p.signalRead() + return nil +} + +// SetWriteDeadline sets the deadline for future Write calls. +func (p *Pipe) SetWriteDeadline(t time.Time) error { + p.mu.Lock() + p.writeTimeout = t + if p.cancelWriteTimer != nil { + p.cancelWriteTimer() + p.cancelWriteTimer = nil + } + if d := time.Until(t); !t.IsZero() && d > 0 { + ctx, cancel := context.WithCancel(context.Background()) + p.cancelWriteTimer = cancel + go func() { + t := time.NewTimer(d) + defer t.Stop() + select { + case <-t.C: + p.signalWrite() + case <-ctx.Done(): + } + }() + } + p.mu.Unlock() + + p.signalWrite() + return nil +} + +func (p *Pipe) Block() error { + p.mu.Lock() + closed := p.closed + blocked := p.blocked + p.blocked = true + p.mu.Unlock() + + if closed { + return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) + } + if blocked { + return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name) + } + p.signalRead() + p.signalWrite() + return nil +} + +func (p *Pipe) Unblock() error { + p.mu.Lock() + closed := p.closed + blocked := p.blocked + p.blocked = false + p.mu.Unlock() + + if closed { + return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) + } + if !blocked { + return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name) + } + p.signalRead() + p.signalWrite() + return nil +} + +func (p *Pipe) signalRead() { + select { + case p.rCh <- struct{}{}: + default: + } +} + +func (p *Pipe) signalWrite() { + select { + case p.wCh <- struct{}{}: + default: + } +} diff --git a/net/nettest/pipe_test.go b/net/nettest/pipe_test.go new file mode 100644 index 000000000..f40d27c53 --- /dev/null +++ b/net/nettest/pipe_test.go @@ -0,0 +1,116 @@ +package nettest + +import ( + "errors" + "fmt" + "testing" + "time" +) + +func TestPipeHello(t *testing.T) { + p := NewPipe("p1", 1<<16) + msg := "Hello, World!" + if n, err := p.Write([]byte(msg)); err != nil { + t.Fatal(err) + } else if n != len(msg) { + t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg)) + } + b := make([]byte, len(msg)) + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != len(b) { + t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b)) + } + if got := string(b); got != msg { + t.Errorf("p.Read: %q, want %q", got, msg) + } +} + +func TestPipeTimeout(t *testing.T) { + t.Run("write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) + n, err := p.Write([]byte{'h'}) + if err == nil || !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) { + t.Errorf("missing write timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h'}) + + p.SetReadDeadline(time.Now().Add(-1 * time.Second)) + b := make([]byte, 1) + n, err := p.Read(b) + if err == nil || !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) { + t.Errorf("missing read timeout got err: %v", err) + } + if n != 0 { + t.Errorf("n=%d on timeout", n) + } + }) + t.Run("block-write", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond)) + if _, err := p.Write([]byte{'h'}); err != nil { + t.Fatal(err) + } + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Write([]byte{'h'}); err == nil || !errors.Is(err, ErrWriteTimeout) { + t.Fatalf("want write timeout got: %v", err) + } + }) + t.Run("block-read", func(t *testing.T) { + p := NewPipe("p1", 1<<16) + p.Write([]byte{'h', 'i'}) + p.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + b := make([]byte, 1) + if _, err := p.Read(b); err != nil { + t.Fatal(err) + } + if err := p.Block(); err != nil { + t.Fatal(err) + } + if _, err := p.Read(b); err == nil || !errors.Is(err, ErrReadTimeout) { + t.Fatalf("want read timeout got: %v", err) + } + }) + +} + +func TestLimit(t *testing.T) { + p := NewPipe("p1", 1) + errCh := make(chan error) + go func() { + n, err := p.Write([]byte{'a', 'b', 'c'}) + if err != nil { + errCh <- err + } else if n != 3 { + errCh <- fmt.Errorf("p.Write n=%d, want 3", n) + } else { + errCh <- nil + } + }() + b := make([]byte, 3) + + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } + if n, err := p.Read(b); err != nil { + t.Fatal(err) + } else if n != 1 { + t.Errorf("Read(%q): n=%d want 1", string(b), n) + } +}