mirror of https://github.com/tailscale/tailscale/
net/nettest: new package with net-like testing primitives
This is a lot like wiring up a local UDP socket, read and write deadlines work. The big difference is the Block feature, which lets you stop the packet flow without breaking the connection. This lets you emulate broken sockets and test timeouts actually work. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>pull/181/head
parent
52c0cb12fb
commit
41ac4a79d6
@ -0,0 +1,86 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn.
|
||||||
|
type Conn interface {
|
||||||
|
io.Reader
|
||||||
|
io.Writer
|
||||||
|
io.Closer
|
||||||
|
|
||||||
|
// The *Deadline methods follow the semantics of net.Conn.
|
||||||
|
|
||||||
|
SetDeadline(t time.Time) error
|
||||||
|
SetReadDeadline(t time.Time) error
|
||||||
|
SetWriteDeadline(t time.Time) error
|
||||||
|
|
||||||
|
// SetReadBlock blocks or unblocks the Read method of this Conn.
|
||||||
|
// It reports an error if the existing value matches the new value,
|
||||||
|
// or if the Conn has been Closed.
|
||||||
|
SetReadBlock(bool) error
|
||||||
|
|
||||||
|
// SetWriteBlock blocks or unblocks the Write method of this Conn.
|
||||||
|
// It reports an error if the existing value matches the new value,
|
||||||
|
// or if the Conn has been Closed.
|
||||||
|
SetWriteBlock(bool) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConn creates a pair of Conns that are wired together by pipes.
|
||||||
|
func NewConn(name string, maxBuf int) (Conn, Conn) {
|
||||||
|
r := NewPipe(name+"|0", maxBuf)
|
||||||
|
w := NewPipe(name+"|1", maxBuf)
|
||||||
|
|
||||||
|
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connHalf struct {
|
||||||
|
r, w *Pipe
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connHalf) Read(b []byte) (n int, err error) {
|
||||||
|
return c.r.Read(b)
|
||||||
|
}
|
||||||
|
func (c *connHalf) Write(b []byte) (n int, err error) {
|
||||||
|
return c.w.Write(b)
|
||||||
|
}
|
||||||
|
func (c *connHalf) Close() error {
|
||||||
|
err1 := c.r.Close()
|
||||||
|
err2 := c.w.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetDeadline(t time.Time) error {
|
||||||
|
err1 := c.SetReadDeadline(t)
|
||||||
|
err2 := c.SetWriteDeadline(t)
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetReadDeadline(t time.Time) error {
|
||||||
|
return c.r.SetReadDeadline(t)
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetWriteDeadline(t time.Time) error {
|
||||||
|
return c.w.SetWriteDeadline(t)
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetReadBlock(b bool) error {
|
||||||
|
if b {
|
||||||
|
return c.r.Block()
|
||||||
|
}
|
||||||
|
return c.r.Unblock()
|
||||||
|
}
|
||||||
|
func (c *connHalf) SetWriteBlock(b bool) error {
|
||||||
|
if b {
|
||||||
|
return c.w.Block()
|
||||||
|
}
|
||||||
|
return c.w.Unblock()
|
||||||
|
}
|
@ -0,0 +1,261 @@
|
|||||||
|
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const debugPipe = false
|
||||||
|
|
||||||
|
// Pipe implements an in-memory FIFO with timeouts.
|
||||||
|
type Pipe struct {
|
||||||
|
name string
|
||||||
|
maxBuf int
|
||||||
|
rCh chan struct{}
|
||||||
|
wCh chan struct{}
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
blocked bool
|
||||||
|
buf []byte
|
||||||
|
readTimeout time.Time
|
||||||
|
writeTimeout time.Time
|
||||||
|
cancelReadTimer func()
|
||||||
|
cancelWriteTimer func()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||||||
|
func NewPipe(name string, maxBuf int) *Pipe {
|
||||||
|
return &Pipe{
|
||||||
|
name: name,
|
||||||
|
maxBuf: maxBuf,
|
||||||
|
rCh: make(chan struct{}, 1),
|
||||||
|
wCh: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTimeout = errors.New("timeout")
|
||||||
|
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
|
||||||
|
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Read implements io.Reader.
|
||||||
|
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||||||
|
if debugPipe {
|
||||||
|
orig := b
|
||||||
|
defer func() {
|
||||||
|
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
timedout := !p.readTimeout.IsZero() && time.Now().After(p.readTimeout)
|
||||||
|
blocked := p.blocked
|
||||||
|
if !closed && !timedout && len(p.buf) > 0 {
|
||||||
|
n2 := copy(b, p.buf)
|
||||||
|
p.buf = p.buf[n2:]
|
||||||
|
b = b[n2:]
|
||||||
|
n += n2
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||||
|
}
|
||||||
|
if timedout {
|
||||||
|
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-p.rCh
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
p.signalWrite()
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
<-p.rCh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.Writer.
|
||||||
|
func (p *Pipe) Write(b []byte) (n int, err error) {
|
||||||
|
if debugPipe {
|
||||||
|
orig := b
|
||||||
|
defer func() {
|
||||||
|
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
timedout := !p.writeTimeout.IsZero() && time.Now().After(p.writeTimeout)
|
||||||
|
blocked := p.blocked
|
||||||
|
if !closed && !timedout {
|
||||||
|
n2 := len(b)
|
||||||
|
if limit := p.maxBuf - len(p.buf); limit < n2 {
|
||||||
|
n2 = limit
|
||||||
|
}
|
||||||
|
p.buf = append(p.buf, b[:n2]...)
|
||||||
|
b = b[n2:]
|
||||||
|
n += n2
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
|
||||||
|
}
|
||||||
|
if timedout {
|
||||||
|
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
<-p.wCh
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
p.signalRead()
|
||||||
|
}
|
||||||
|
if len(b) == 0 {
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
<-p.wCh
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements io.Closer.
|
||||||
|
func (p *Pipe) Close() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
p.closed = true
|
||||||
|
if p.cancelWriteTimer != nil {
|
||||||
|
p.cancelWriteTimer()
|
||||||
|
p.cancelWriteTimer = nil
|
||||||
|
}
|
||||||
|
if p.cancelReadTimer != nil {
|
||||||
|
p.cancelReadTimer()
|
||||||
|
p.cancelReadTimer = nil
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
p.signalRead()
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetReadDeadline sets the deadline for future Read calls.
|
||||||
|
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
p.readTimeout = t
|
||||||
|
if p.cancelReadTimer != nil {
|
||||||
|
p.cancelReadTimer()
|
||||||
|
p.cancelReadTimer = nil
|
||||||
|
}
|
||||||
|
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
p.cancelReadTimer = cancel
|
||||||
|
go func() {
|
||||||
|
t := time.NewTimer(d)
|
||||||
|
defer t.Stop()
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
p.signalRead()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
p.signalRead()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetWriteDeadline sets the deadline for future Write calls.
|
||||||
|
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||||||
|
p.mu.Lock()
|
||||||
|
p.writeTimeout = t
|
||||||
|
if p.cancelWriteTimer != nil {
|
||||||
|
p.cancelWriteTimer()
|
||||||
|
p.cancelWriteTimer = nil
|
||||||
|
}
|
||||||
|
if d := time.Until(t); !t.IsZero() && d > 0 {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
p.cancelWriteTimer = cancel
|
||||||
|
go func() {
|
||||||
|
t := time.NewTimer(d)
|
||||||
|
defer t.Stop()
|
||||||
|
select {
|
||||||
|
case <-t.C:
|
||||||
|
p.signalWrite()
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) Block() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
blocked := p.blocked
|
||||||
|
p.blocked = true
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||||
|
}
|
||||||
|
if blocked {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
|
||||||
|
}
|
||||||
|
p.signalRead()
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) Unblock() error {
|
||||||
|
p.mu.Lock()
|
||||||
|
closed := p.closed
|
||||||
|
blocked := p.blocked
|
||||||
|
p.blocked = false
|
||||||
|
p.mu.Unlock()
|
||||||
|
|
||||||
|
if closed {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
|
||||||
|
}
|
||||||
|
if !blocked {
|
||||||
|
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
|
||||||
|
}
|
||||||
|
p.signalRead()
|
||||||
|
p.signalWrite()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) signalRead() {
|
||||||
|
select {
|
||||||
|
case p.rCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Pipe) signalWrite() {
|
||||||
|
select {
|
||||||
|
case p.wCh <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,116 @@
|
|||||||
|
package nettest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPipeHello(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
msg := "Hello, World!"
|
||||||
|
if n, err := p.Write([]byte(msg)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != len(msg) {
|
||||||
|
t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
|
||||||
|
}
|
||||||
|
b := make([]byte, len(msg))
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != len(b) {
|
||||||
|
t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
|
||||||
|
}
|
||||||
|
if got := string(b); got != msg {
|
||||||
|
t.Errorf("p.Read: %q, want %q", got, msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPipeTimeout(t *testing.T) {
|
||||||
|
t.Run("write", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
||||||
|
n, err := p.Write([]byte{'h'})
|
||||||
|
if err == nil || !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) {
|
||||||
|
t.Errorf("missing write timeout got err: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("n=%d on timeout", n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("read", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.Write([]byte{'h'})
|
||||||
|
|
||||||
|
p.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||||
|
b := make([]byte, 1)
|
||||||
|
n, err := p.Read(b)
|
||||||
|
if err == nil || !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) {
|
||||||
|
t.Errorf("missing read timeout got err: %v", err)
|
||||||
|
}
|
||||||
|
if n != 0 {
|
||||||
|
t.Errorf("n=%d on timeout", n)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("block-write", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
if _, err := p.Write([]byte{'h'}); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Block(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := p.Write([]byte{'h'}); err == nil || !errors.Is(err, ErrWriteTimeout) {
|
||||||
|
t.Fatalf("want write timeout got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("block-read", func(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1<<16)
|
||||||
|
p.Write([]byte{'h', 'i'})
|
||||||
|
p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||||
|
b := make([]byte, 1)
|
||||||
|
if _, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := p.Block(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if _, err := p.Read(b); err == nil || !errors.Is(err, ErrReadTimeout) {
|
||||||
|
t.Fatalf("want read timeout got: %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLimit(t *testing.T) {
|
||||||
|
p := NewPipe("p1", 1)
|
||||||
|
errCh := make(chan error)
|
||||||
|
go func() {
|
||||||
|
n, err := p.Write([]byte{'a', 'b', 'c'})
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
} else if n != 3 {
|
||||||
|
errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
|
||||||
|
} else {
|
||||||
|
errCh <- nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
b := make([]byte, 3)
|
||||||
|
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != 1 {
|
||||||
|
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||||
|
}
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != 1 {
|
||||||
|
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||||
|
}
|
||||||
|
if n, err := p.Read(b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if n != 1 {
|
||||||
|
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue