From 57756ef673cd642721d70a8ca72ca4b53859bb0e Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Fri, 2 Apr 2021 20:04:52 -0700 Subject: [PATCH] net/nettest: make nettest.NewConn pass x/net/nettest.TestConn. Signed-off-by: Maisem Ali --- net/nettest/conn.go | 38 +++--- net/nettest/conn_test.go | 22 +++ net/nettest/listener.go | 83 ++++++++++++ net/nettest/listener_test.go | 34 +++++ net/nettest/pipe.go | 250 ++++++++++++++++------------------- net/nettest/pipe_test.go | 9 +- 6 files changed, 283 insertions(+), 153 deletions(-) create mode 100644 net/nettest/conn_test.go create mode 100644 net/nettest/listener.go create mode 100644 net/nettest/listener_test.go diff --git a/net/nettest/conn.go b/net/nettest/conn.go index 9cbbd3e34..90727c4a8 100644 --- a/net/nettest/conn.go +++ b/net/nettest/conn.go @@ -5,21 +5,13 @@ package nettest import ( - "io" + "net" "time" ) -// Conn is a bi-directional in-memory stream that looks like a TCP net.Conn. +// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked. type Conn interface { - io.Reader - io.Writer - io.Closer - - // The *Deadline methods follow the semantics of net.Conn. - - SetDeadline(t time.Time) error - SetReadDeadline(t time.Time) error - SetWriteDeadline(t time.Time) error + net.Conn // SetReadBlock blocks or unblocks the Read method of this Conn. // It reports an error if the existing value matches the new value, @@ -40,24 +32,37 @@ func NewConn(name string, maxBuf int) (Conn, Conn) { return &connHalf{r: r, w: w}, &connHalf{r: w, w: r} } +type connAddr string + +func (a connAddr) Network() string { return "mem" } +func (a connAddr) String() string { return string(a) } + type connHalf struct { r, w *Pipe } +func (c *connHalf) LocalAddr() net.Addr { + return connAddr(c.r.name) +} + +func (c *connHalf) RemoteAddr() net.Addr { + return connAddr(c.w.name) +} + func (c *connHalf) Read(b []byte) (n int, err error) { return c.r.Read(b) } func (c *connHalf) Write(b []byte) (n int, err error) { return c.w.Write(b) } + func (c *connHalf) Close() error { - err1 := c.r.Close() - err2 := c.w.Close() - if err1 != nil { - return err1 + if err := c.w.Close(); err != nil { + return err } - return err2 + return c.r.Close() } + func (c *connHalf) SetDeadline(t time.Time) error { err1 := c.SetReadDeadline(t) err2 := c.SetWriteDeadline(t) @@ -72,6 +77,7 @@ func (c *connHalf) SetReadDeadline(t time.Time) error { func (c *connHalf) SetWriteDeadline(t time.Time) error { return c.w.SetWriteDeadline(t) } + func (c *connHalf) SetReadBlock(b bool) error { if b { return c.r.Block() diff --git a/net/nettest/conn_test.go b/net/nettest/conn_test.go new file mode 100644 index 000000000..76c189198 --- /dev/null +++ b/net/nettest/conn_test.go @@ -0,0 +1,22 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "net" + "testing" + + "golang.org/x/net/nettest" +) + +func TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) { + c1, c2 = NewConn("test", bufferSize) + return c1, c2, func() { + c1.Close() + c2.Close() + }, nil + }) +} diff --git a/net/nettest/listener.go b/net/nettest/listener.go new file mode 100644 index 000000000..31e16c0a9 --- /dev/null +++ b/net/nettest/listener.go @@ -0,0 +1,83 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "context" + "net" + "strings" + "sync" +) + +const ( + bufferSize = 256 * 1024 +) + +// Listener is a net.Listener using using NewConn to create pairs of network +// connections connected in memory using a buffered pipe. It also provides a +// Dial method to establish new connections. +type Listener struct { + addr connAddr + ch chan Conn + closeOnce sync.Once + closed chan struct{} +} + +// Listen returns a new Listener for the provided address. +func Listen(addr string) *Listener { + return &Listener{ + addr: connAddr(addr), + ch: make(chan Conn), + closed: make(chan struct{}), + } +} + +// Addr implements net.Listener.Addr. +func (l *Listener) Addr() net.Addr { + return l.addr +} + +// Close closes the pipe listener. +func (l *Listener) Close() error { + l.closeOnce.Do(func() { + close(l.closed) + }) + return nil +} + +// Accept blocks until a new connection is available or the listener is closed. +func (l *Listener) Accept() (net.Conn, error) { + select { + case c := <-l.ch: + return c, nil + case <-l.closed: + return nil, net.ErrClosed + } +} + +// Dial connects to the listener using the provided context. +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected +// any expiration of the context will not affect the connection. +func (l *Listener) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + if !strings.HasSuffix(network, "tcp") { + return nil, net.UnknownNetworkError(network) + } + if connAddr(addr) != l.addr { + return nil, &net.AddrError{ + Err: "invalid address", + Addr: addr, + } + } + c, s := NewConn(addr, bufferSize) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-l.closed: + return nil, net.ErrClosed + case l.ch <- s: + return c, nil + } +} diff --git a/net/nettest/listener_test.go b/net/nettest/listener_test.go new file mode 100644 index 000000000..09b7bed5f --- /dev/null +++ b/net/nettest/listener_test.go @@ -0,0 +1,34 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package nettest + +import ( + "context" + "testing" +) + +func TestListener(t *testing.T) { + l := Listen("srv.local") + defer l.Close() + go func() { + c, err := l.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + }() + + if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil { + c.Close() + t.Fatalf("dial to invalid address succeeded") + } + c, err := l.Dial(context.Background(), "tcp", "srv.local") + if err != nil { + t.Fatalf("dial failed: %v", err) + return + } + c.Close() +} diff --git a/net/nettest/pipe.go b/net/nettest/pipe.go index 366a23069..671118bea 100644 --- a/net/nettest/pipe.go +++ b/net/nettest/pipe.go @@ -5,11 +5,13 @@ package nettest import ( + "bytes" "context" - "errors" "fmt" "io" "log" + "net" + "os" "sync" "time" ) @@ -20,13 +22,12 @@ const debugPipe = false type Pipe struct { name string maxBuf int - rCh chan struct{} - wCh chan struct{} + mu sync.Mutex + cnd *sync.Cond - mu sync.Mutex - closed bool blocked bool - buf []byte + closed bool + buf bytes.Buffer readTimeout time.Time writeTimeout time.Time cancelReadTimer func() @@ -35,21 +36,42 @@ type Pipe struct { // NewPipe creates a Pipe with a buffer size fixed at maxBuf. func NewPipe(name string, maxBuf int) *Pipe { - return &Pipe{ + p := &Pipe{ name: name, maxBuf: maxBuf, - rCh: make(chan struct{}, 1), - wCh: make(chan struct{}, 1), } + p.cnd = sync.NewCond(&p.mu) + return p } -var ( - ErrTimeout = errors.New("timeout") - ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout) - ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout) -) +// readOrBlock attempts to read from the buffer, if the buffer is empty and +// the connection hasn't been closed it will block until there is a change. +func (p *Pipe) readOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + n, err := p.buf.Read(b) + // err will either be nil or io.EOF. + if err == io.EOF { + if p.closed { + return n, err + } + // Wait for something to change. + p.cnd.Wait() + } + return n, nil +} // Read implements io.Reader. +// Once the buffer is drained (i.e. after Close), subsequent calls will +// return io.EOF. func (p *Pipe) Read(b []byte) (n int, err error) { if debugPipe { orig := b @@ -57,35 +79,48 @@ func (p *Pipe) Read(b []byte) (n int, err error) { log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err) }() } - for { - p.mu.Lock() - closed := p.closed - timedout := !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) - blocked := p.blocked - if !closed && !timedout && len(p.buf) > 0 { - n2 := copy(b, p.buf) - p.buf = p.buf[n2:] - b = b[n2:] - n += n2 + for n == 0 { + n2, err := p.readOrBlock(b) + if err != nil { + return n2, err } - p.mu.Unlock() + n += n2 + } + p.cnd.Signal() + return n, nil +} - if closed { - return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF) - } - if timedout { - return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout) - } - if blocked { - <-p.rCh - continue - } - if n > 0 { - p.signalWrite() - return n, nil - } - <-p.rCh +// writeOrBlock attempts to write to the buffer, if the buffer is full it will +// block until there is a change. +func (p *Pipe) writeOrBlock(b []byte) (int, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.closed { + return 0, net.ErrClosed } + if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) { + return 0, os.ErrDeadlineExceeded + } + if p.blocked { + p.cnd.Wait() + return 0, nil + } + + // Optimistically we want to write the entire slice. + n := len(b) + if limit := p.maxBuf - p.buf.Len(); limit < n { + // However, we don't have enough capacity to write everything. + n = limit + } + if n == 0 { + // Wait for something to change. + p.cnd.Wait() + return 0, nil + } + + p.buf.Write(b[:n]) + p.cnd.Signal() + return n, nil } // Write implements io.Writer. @@ -96,47 +131,23 @@ func (p *Pipe) Write(b []byte) (n int, err error) { log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err) }() } - for { - p.mu.Lock() - closed := p.closed - timedout := !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) - blocked := p.blocked - if !closed && !timedout { - n2 := len(b) - if limit := p.maxBuf - len(p.buf); limit < n2 { - n2 = limit - } - p.buf = append(p.buf, b[:n2]...) - b = b[n2:] - n += n2 - } - p.mu.Unlock() - - if closed { - return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF) - } - if timedout { - return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout) + for len(b) > 0 { + n2, err := p.writeOrBlock(b) + if err != nil { + return n + n2, err } - if blocked { - <-p.wCh - continue - } - if n > 0 { - p.signalRead() - } - if len(b) == 0 { - return n, nil - } - <-p.wCh + n += n2 + b = b[n2:] } + return n, nil } -// Close implements io.Closer. +// Close closes the pipe. func (p *Pipe) Close() error { p.mu.Lock() - closed := p.closed + defer p.mu.Unlock() p.closed = true + p.blocked = false if p.cancelWriteTimer != nil { p.cancelWriteTimer() p.cancelWriteTimer = nil @@ -145,77 +156,65 @@ func (p *Pipe) Close() error { p.cancelReadTimer() p.cancelReadTimer = nil } - p.mu.Unlock() + p.cnd.Broadcast() - if closed { - return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name) - } - - p.signalRead() - p.signalWrite() return nil } +func (p *Pipe) deadlineTimer(t time.Time) func() { + if t.IsZero() { + return nil + } + if t.Before(time.Now()) { + p.cnd.Broadcast() + return nil + } + ctx, cancel := context.WithDeadline(context.Background(), t) + go func() { + <-ctx.Done() + if ctx.Err() == context.DeadlineExceeded { + p.cnd.Broadcast() + } + }() + return cancel +} + // SetReadDeadline sets the deadline for future Read calls. func (p *Pipe) SetReadDeadline(t time.Time) error { p.mu.Lock() + defer p.mu.Unlock() p.readTimeout = t + // If we already have a deadline, cancel it and create a new one. if p.cancelReadTimer != nil { p.cancelReadTimer() p.cancelReadTimer = nil } - if d := time.Until(t); !t.IsZero() && d > 0 { - ctx, cancel := context.WithCancel(context.Background()) - p.cancelReadTimer = cancel - go func() { - t := time.NewTimer(d) - defer t.Stop() - select { - case <-t.C: - p.signalRead() - case <-ctx.Done(): - } - }() - } - p.mu.Unlock() - - p.signalRead() + p.cancelReadTimer = p.deadlineTimer(t) return nil } // SetWriteDeadline sets the deadline for future Write calls. func (p *Pipe) SetWriteDeadline(t time.Time) error { p.mu.Lock() + defer p.mu.Unlock() p.writeTimeout = t + // If we already have a deadline, cancel it and create a new one. if p.cancelWriteTimer != nil { p.cancelWriteTimer() p.cancelWriteTimer = nil } - if d := time.Until(t); !t.IsZero() && d > 0 { - ctx, cancel := context.WithCancel(context.Background()) - p.cancelWriteTimer = cancel - go func() { - t := time.NewTimer(d) - defer t.Stop() - select { - case <-t.C: - p.signalWrite() - case <-ctx.Done(): - } - }() - } - p.mu.Unlock() - - p.signalWrite() + p.cancelWriteTimer = p.deadlineTimer(t) return nil } +// Block will cause all calls to Read and Write to block until they either +// timeout, are unblocked or the pipe is closed. func (p *Pipe) Block() error { p.mu.Lock() + defer p.mu.Unlock() closed := p.closed blocked := p.blocked p.blocked = true - p.mu.Unlock() if closed { return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) @@ -223,17 +222,17 @@ func (p *Pipe) Block() error { if blocked { return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name) } - p.signalRead() - p.signalWrite() + p.cnd.Broadcast() return nil } +// Unblock will cause all blocked Read/Write calls to continue execution. func (p *Pipe) Unblock() error { p.mu.Lock() + defer p.mu.Unlock() closed := p.closed blocked := p.blocked p.blocked = false - p.mu.Unlock() if closed { return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name) @@ -241,21 +240,6 @@ func (p *Pipe) Unblock() error { if !blocked { return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name) } - p.signalRead() - p.signalWrite() + p.cnd.Broadcast() return nil } - -func (p *Pipe) signalRead() { - select { - case p.rCh <- struct{}{}: - default: - } -} - -func (p *Pipe) signalWrite() { - select { - case p.wCh <- struct{}{}: - default: - } -} diff --git a/net/nettest/pipe_test.go b/net/nettest/pipe_test.go index 2094b202d..caa7d4b43 100644 --- a/net/nettest/pipe_test.go +++ b/net/nettest/pipe_test.go @@ -7,6 +7,7 @@ package nettest import ( "errors" "fmt" + "os" "testing" "time" ) @@ -35,7 +36,7 @@ func TestPipeTimeout(t *testing.T) { p := NewPipe("p1", 1<<16) p.SetWriteDeadline(time.Now().Add(-1 * time.Second)) n, err := p.Write([]byte{'h'}) - if !errors.Is(err, ErrWriteTimeout) || !errors.Is(err, ErrTimeout) { + if !errors.Is(err, os.ErrDeadlineExceeded) { t.Errorf("missing write timeout got err: %v", err) } if n != 0 { @@ -49,7 +50,7 @@ func TestPipeTimeout(t *testing.T) { p.SetReadDeadline(time.Now().Add(-1 * time.Second)) b := make([]byte, 1) n, err := p.Read(b) - if !errors.Is(err, ErrReadTimeout) || !errors.Is(err, ErrTimeout) { + if !errors.Is(err, os.ErrDeadlineExceeded) { t.Errorf("missing read timeout got err: %v", err) } if n != 0 { @@ -65,7 +66,7 @@ func TestPipeTimeout(t *testing.T) { if err := p.Block(); err != nil { t.Fatal(err) } - if _, err := p.Write([]byte{'h'}); !errors.Is(err, ErrWriteTimeout) { + if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) { t.Fatalf("want write timeout got: %v", err) } }) @@ -80,7 +81,7 @@ func TestPipeTimeout(t *testing.T) { if err := p.Block(); err != nil { t.Fatal(err) } - if _, err := p.Read(b); !errors.Is(err, ErrReadTimeout) { + if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) { t.Fatalf("want read timeout got: %v", err) } })