diff --git a/control/controlbase/conn.go b/control/controlbase/conn.go
index 231a54044..cccdd40ab 100644
--- a/control/controlbase/conn.go
+++ b/control/controlbase/conn.go
@@ -52,10 +52,11 @@ type rxState struct {
 	sync.Mutex
 	cipher    cipher.AEAD
 	nonce     nonce
-	buf       [maxMessageSize]byte
-	n         int    // number of valid bytes in buf
-	next      int    // offset of next undecrypted packet
-	plaintext []byte // slice into buf of decrypted bytes
+	buf       *maxMsgBuffer   // or nil when reads exhausted
+	n         int             // number of valid bytes in buf
+	next      int             // offset of next undecrypted packet
+	plaintext []byte          // slice into buf of decrypted bytes
+	hdrBuf    [headerLen]byte // small buffer used when buf is nil
 }
 
 // txState is all the Conn state that Write uses.
@@ -88,6 +89,10 @@ func (c *Conn) Peer() key.MachinePublic {
 // readNLocked reads into c.rx.buf until buf contains at least total
 // bytes. Returns a slice of the total bytes in rxBuf, or an
 // error if fewer than total bytes are available.
+//
+// It may be called with a nil c.rx.buf only if total == headerLen.
+//
+// On success, c.rx.buf will be non-nil.
 func (c *Conn) readNLocked(total int) ([]byte, error) {
 	if total > maxMessageSize {
 		return nil, errReadTooBig{total}
@@ -96,8 +101,26 @@ func (c *Conn) readNLocked(total int) ([]byte, error) {
 	if total <= c.rx.n {
 		return c.rx.buf[:total], nil
 	}
-
-	n, err := c.conn.Read(c.rx.buf[c.rx.n:])
+	var n int
+	var err error
+	if c.rx.buf == nil {
+		if c.rx.n != 0 || total != headerLen {
+			panic("unexpected")
+		}
+		// Optimization to reduce memory usage.
+		// Most connections are blocked forever waiting for
+		// a read, so we don't want c.rx.buf to be allocated until
+		// we know there's data to read. Instead, when we're
+		// waiting for data to arrive here, read into the
+		// 3 byte hdrBuf:
+		n, err = c.conn.Read(c.rx.hdrBuf[:])
+		if n > 0 {
+			c.rx.buf = getMaxMsgBuffer()
+			copy(c.rx.buf[:], c.rx.hdrBuf[:n])
+		}
+	} else {
+		n, err = c.conn.Read(c.rx.buf[c.rx.n:])
+	}
 	c.rx.n += n
 	if err != nil {
 		return nil, err
@@ -190,6 +213,14 @@ func (c *Conn) decryptOneLocked() error {
 		c.rx.next = 0
 	}
 
+	// Return our buffer to the pool if it's empty, lest we be
+	// blocked in a long Read call, reading the 3 byte header. We
+	// don't want to keep that buffer unnecessarily alive.
+	if c.rx.n == 0 && c.rx.next == 0 && c.rx.buf != nil {
+		bufPool.Put(c.rx.buf)
+		c.rx.buf = nil
+	}
+
 	bs, err := c.readNLocked(headerLen)
 	if err != nil {
 		return err
@@ -226,6 +257,12 @@ func (c *Conn) Read(bs []byte) (int, error) {
 	}
 	n := copy(bs, c.rx.plaintext)
 	c.rx.plaintext = c.rx.plaintext[n:]
+
+	// Lose the slice's underlying array pointer to unneeded memory
+	// so the GC can collect more.
+	if len(c.rx.plaintext) == 0 {
+		c.rx.plaintext = nil
+	}
 	return n, nil
 }
 
@@ -256,7 +293,7 @@ func (c *Conn) Write(bs []byte) (n int, err error) {
 		return 0, net.ErrClosed
 	}
 
-	buf := bufPool.Get().(*maxMsgBuffer)
+	buf := getMaxMsgBuffer()
 	defer bufPool.Put(buf)
 
 	var sent int
@@ -366,3 +403,7 @@ var bufPool = &sync.Pool{
 		return new(maxMsgBuffer)
 	},
 }
+
+func getMaxMsgBuffer() *maxMsgBuffer {
+	return bufPool.Get().(*maxMsgBuffer)
+}
diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go
index c0dfa9940..f25ae2850 100644
--- a/control/controlbase/conn_test.go
+++ b/control/controlbase/conn_test.go
@@ -13,10 +13,12 @@ import (
 	"fmt"
 	"io"
 	"net"
+	"runtime"
 	"strings"
 	"sync"
 	"testing"
 	"testing/iotest"
+	"time"
 
 	chp "golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/net/nettest"
@@ -224,6 +226,81 @@ func TestConnStd(t *testing.T) {
 	})
 }
 
+// tests that the idle memory overhead of a Conn blocked in a read is
+// reasonable (under 2K). It was previously over 8KB with two 4KB
+// buffers for rx/tx. This makes sure we don't regress. Hopefully it
+// doesn't turn into a flaky test. If so, const max can be adjusted,
+// or it can be deleted or reworked.
+func TestConnMemoryOverhead(t *testing.T) {
+	num := 1000
+	if testing.Short() {
+		num = 100
+	}
+	ng0 := runtime.NumGoroutine()
+
+	runtime.GC()
+	var ms0 runtime.MemStats
+	runtime.ReadMemStats(&ms0)
+
+	var closers []io.Closer
+	closeAll := func() {
+		for _, c := range closers {
+			c.Close()
+		}
+		closers = nil
+	}
+	defer closeAll()
+
+	for i := 0; i < num; i++ {
+		client, server := pair(t)
+		closers = append(closers, client, server)
+		go func() {
+			var buf [1]byte
+			client.Read(buf[:])
+		}()
+	}
+
+	t0 := time.Now()
+	deadline := t0.Add(3 * time.Second)
+	var ngo int
+	for time.Now().Before(deadline) {
+		runtime.GC()
+		ngo = runtime.NumGoroutine()
+		if ngo >= num {
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	if ngo < num {
+		t.Fatalf("only %v goroutines; expected %v+", ngo, num)
+	}
+	runtime.GC()
+	var ms runtime.MemStats
+	runtime.ReadMemStats(&ms)
+	growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc)
+	growthEach := float64(growthTotal) / float64(num)
+	t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach)
+	const max = 2000
+	if growthEach > max {
+		t.Errorf("allocated more than expected; want max %v bytes/each", max)
+	}
+
+	closeAll()
+
+	// And make sure our goroutines go away too.
+	deadline = time.Now().Add(3 * time.Second)
+	for time.Now().Before(deadline) {
+		ngo = runtime.NumGoroutine()
+		if ngo < ng0+num/10 {
+			break
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	if ngo >= ng0+num/10 {
+		t.Errorf("goroutines didn't go back down; started at %v, now %v", ng0, ngo)
+	}
+}
+
 // mkConns creates synthetic Noise Conns wrapping the given net.Conns.
 // This function is for testing just the Conn transport logic without
 // having to muck about with Noise handshakes.
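
For reference, the buffer reuse the diff relies on is the standard sync.Pool pattern: a full-size receive buffer is fetched from the pool only once a read has actually returned data, and is put back as soon as it is drained, so a Conn idling in a blocking read holds only the 3-byte hdrBuf. Below is a minimal, self-contained sketch of that pattern. It is not the package's actual code; the names mirror those in the diff (maxMsgBuffer, bufPool, getMaxMsgBuffer), and the 4096-byte size is assumed purely for illustration.

// Sketch of the pool-backed buffer reuse pattern (assumed sizes, not the
// package's actual code).
package main

import (
	"fmt"
	"sync"
)

const maxMessageSize = 4096 // assumed message size, for illustration only

type maxMsgBuffer [maxMessageSize]byte

// bufPool hands out full-size buffers on demand and recycles them.
var bufPool = &sync.Pool{
	New: func() any { return new(maxMsgBuffer) },
}

func getMaxMsgBuffer() *maxMsgBuffer {
	return bufPool.Get().(*maxMsgBuffer)
}

func main() {
	buf := getMaxMsgBuffer() // take a buffer only once data is ready
	n := copy(buf[:], "payload arrives")
	fmt.Printf("buffered %d bytes\n", n)
	bufPool.Put(buf) // return it once drained, so an idle reader keeps nothing large alive
}

Returning the buffer the moment it empties is what the new TestConnMemoryOverhead exercises: a thousand connections blocked in Read should each retain well under the 2000-byte budget, rather than pinning a full receive buffer apiece.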