wgengine: fix crash reading long UAPI lines from legacy peers

Also don't log.Fatalf in a function returning an error.

Fixes #1204

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/1209/head
Brad Fitzpatrick 4 years ago committed by Brad Fitzpatrick
parent a7edcd0872
commit e970ed0995

@ -11,7 +11,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"net" "net"
"os" "os"
"os/exec" "os/exec"
@ -112,6 +111,7 @@ type userspaceEngine struct {
trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config trimmedDisco map[tailcfg.DiscoKey]bool // set of disco keys of peers currently excluded from wireguard config
sentActivityAt map[netaddr.IP]*int64 // value is atomic int64 of unixtime sentActivityAt map[netaddr.IP]*int64 // value is atomic int64 of unixtime
destIPActivityFuncs map[netaddr.IP]func() destIPActivityFuncs map[netaddr.IP]func()
statusBufioReader *bufio.Reader // reusable for UAPI
mu sync.Mutex // guards following; see lock order comment below mu sync.Mutex // guards following; see lock order comment below
closing bool // Close was called (even if we're still closing) closing bool // Close was called (even if we're still closing)
@ -1035,8 +1035,6 @@ func (e *userspaceEngine) getStatusCallback() StatusCallback {
return e.statusCallback return e.statusCallback
} }
// TODO: this function returns an error but it's always nil, and when
// there's actually a problem it just calls log.Fatal. Why?
func (e *userspaceEngine) getStatus() (*Status, error) { func (e *userspaceEngine) getStatus() (*Status, error) {
// Grab derpConns before acquiring wgLock to not violate lock ordering; // Grab derpConns before acquiring wgLock to not violate lock ordering;
// the DERPs method acquires magicsock.Conn.mu. // the DERPs method acquires magicsock.Conn.mu.
@ -1061,15 +1059,11 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
return nil, nil return nil, nil
} }
// lineLen is the max UAPI line we expect. The longest I see is
// len("preshared_key=")+64 hex+"\n" == 79. Add some slop.
const lineLen = 100
pr, pw := io.Pipe() pr, pw := io.Pipe()
errc := make(chan error, 1) errc := make(chan error, 1)
go func() { go func() {
defer pw.Close() defer pw.Close()
bw := bufio.NewWriterSize(pw, lineLen)
// TODO(apenwarr): get rid of silly uapi stuff for in-process comms // TODO(apenwarr): get rid of silly uapi stuff for in-process comms
// FIXME: get notified of status changes instead of polling. // FIXME: get notified of status changes instead of polling.
filter := device.IPCGetFilter{ filter := device.IPCGetFilter{
@ -1077,23 +1071,34 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
// unused below; request that they not be sent instead. // unused below; request that they not be sent instead.
FilterAllowedIPs: true, FilterAllowedIPs: true,
} }
if err := e.wgdev.IpcGetOperationFiltered(bw, filter); err != nil { err := e.wgdev.IpcGetOperationFiltered(pw, filter)
errc <- fmt.Errorf("IpcGetOperation: %w", err) if err != nil {
return err = fmt.Errorf("IpcGetOperation: %w", err)
} }
errc <- bw.Flush() errc <- err
}() }()
pp := make(map[wgkey.Key]*PeerStatus) pp := make(map[wgkey.Key]*PeerStatus)
p := &PeerStatus{} p := &PeerStatus{}
var hst1, hst2, n int64 var hst1, hst2, n int64
var err error
bs := bufio.NewScanner(pr) br := e.statusBufioReader
bs.Buffer(make([]byte, lineLen), lineLen) if br != nil {
for bs.Scan() { br.Reset(pr)
line := bs.Bytes() } else {
br = bufio.NewReaderSize(pr, 1<<10)
e.statusBufioReader = br
}
for {
line, err := br.ReadSlice('\n')
if err == io.EOF {
break
}
if err != nil {
pr.Close()
return nil, fmt.Errorf("reading from UAPI pipe: %w", err)
}
k := line k := line
var v mem.RO var v mem.RO
if i := bytes.IndexByte(line, '='); i != -1 { if i := bytes.IndexByte(line, '='); i != -1 {
@ -1104,7 +1109,7 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
case "public_key": case "public_key":
pk, err := key.NewPublicFromHexMem(v) pk, err := key.NewPublicFromHexMem(v)
if err != nil { if err != nil {
log.Fatalf("IpcGetOperation: invalid key %#v", v) return nil, fmt.Errorf("IpcGetOperation: invalid key %#v", v)
} }
p = &PeerStatus{} p = &PeerStatus{}
pp[wgkey.Key(pk)] = p pp[wgkey.Key(pk)] = p
@ -1115,34 +1120,31 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
n, err = mem.ParseInt(v, 10, 64) n, err = mem.ParseInt(v, 10, 64)
p.RxBytes = ByteCount(n) p.RxBytes = ByteCount(n)
if err != nil { if err != nil {
log.Fatalf("IpcGetOperation: rx_bytes invalid: %#v", line) return nil, fmt.Errorf("IpcGetOperation: rx_bytes invalid: %#v", line)
} }
case "tx_bytes": case "tx_bytes":
n, err = mem.ParseInt(v, 10, 64) n, err = mem.ParseInt(v, 10, 64)
p.TxBytes = ByteCount(n) p.TxBytes = ByteCount(n)
if err != nil { if err != nil {
log.Fatalf("IpcGetOperation: tx_bytes invalid: %#v", line) return nil, fmt.Errorf("IpcGetOperation: tx_bytes invalid: %#v", line)
} }
case "last_handshake_time_sec": case "last_handshake_time_sec":
hst1, err = mem.ParseInt(v, 10, 64) hst1, err = mem.ParseInt(v, 10, 64)
if err != nil { if err != nil {
log.Fatalf("IpcGetOperation: hst1 invalid: %#v", line) return nil, fmt.Errorf("IpcGetOperation: hst1 invalid: %#v", line)
} }
case "last_handshake_time_nsec": case "last_handshake_time_nsec":
hst2, err = mem.ParseInt(v, 10, 64) hst2, err = mem.ParseInt(v, 10, 64)
if err != nil { if err != nil {
log.Fatalf("IpcGetOperation: hst2 invalid: %#v", line) return nil, fmt.Errorf("IpcGetOperation: hst2 invalid: %#v", line)
} }
if hst1 != 0 || hst2 != 0 { if hst1 != 0 || hst2 != 0 {
p.LastHandshake = time.Unix(hst1, hst2) p.LastHandshake = time.Unix(hst1, hst2)
} // else leave at time.IsZero() } // else leave at time.IsZero()
} }
} }
if err := bs.Err(); err != nil {
log.Fatalf("reading IpcGetOperation output: %v", err)
}
if err := <-errc; err != nil { if err := <-errc; err != nil {
log.Fatalf("IpcGetOperation: %v", err) return nil, fmt.Errorf("IpcGetOperation: %v", err)
} }
e.mu.Lock() e.mu.Lock()

Loading…
Cancel
Save