From f65eb4e5c1825c277a99c9f3a4631bf83f7bfd5a Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 10 Sep 2020 15:21:32 -0700 Subject: [PATCH] net/netstat: start of new netstat package, with Windows for now This will be used in a future change to do localhost connection authentication. This lets us quickly map a localhost TCP connection to a PID. (A future change will then map a pid to a user) TODO: pull portlist's netstat code into this package. Then portlist will be fast on Windows without requiring shelling out to netstat.exe. --- net/netstat/netstat.go | 36 +++++++ net/netstat/netstat_noimpl.go | 11 ++ net/netstat/netstat_test.go | 22 ++++ net/netstat/netstat_windows.go | 178 +++++++++++++++++++++++++++++++++ 4 files changed, 247 insertions(+) create mode 100644 net/netstat/netstat.go create mode 100644 net/netstat/netstat_noimpl.go create mode 100644 net/netstat/netstat_test.go create mode 100644 net/netstat/netstat_windows.go diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go new file mode 100644 index 000000000..3aef7ea1e --- /dev/null +++ b/net/netstat/netstat.go @@ -0,0 +1,36 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package netstat returns the local machine's network connection table. +package netstat + +import ( + "errors" + "runtime" + + "inet.af/netaddr" +) + +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +type Entry struct { + Local, Remote netaddr.IPPort + Pid int + State string // TODO: type? +} + +// Table contains local machine's TCP connection entries. +// +// Currently only TCP (IPv4 and IPv6) are included. +type Table struct { + Entries []Entry +} + +// Get returns the connection table. +// +// It returns ErrNotImplemented if the table is not available for the +// current operating system. +func Get() (*Table, error) { + return get() +} diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go new file mode 100644 index 000000000..65732f0d7 --- /dev/null +++ b/net/netstat/netstat_noimpl.go @@ -0,0 +1,11 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !windows + +package netstat + +func get() (*Table, error) { + return nil, ErrNotImplemented +} diff --git a/net/netstat/netstat_test.go b/net/netstat/netstat_test.go new file mode 100644 index 000000000..d75f2e777 --- /dev/null +++ b/net/netstat/netstat_test.go @@ -0,0 +1,22 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package netstat + +import ( + "testing" +) + +func TestGet(t *testing.T) { + nt, err := Get() + if err == ErrNotImplemented { + t.Skip("TODO: not implemented") + } + if err != nil { + t.Fatal(err) + } + for _, e := range nt.Entries { + t.Logf("Entry: %+v", e) + } +} diff --git a/net/netstat/netstat_windows.go b/net/netstat/netstat_windows.go new file mode 100644 index 000000000..893d264e9 --- /dev/null +++ b/net/netstat/netstat_windows.go @@ -0,0 +1,178 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package netstat returns the local machine's network connection table. +package netstat + +import ( + "encoding/binary" + "errors" + "fmt" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "inet.af/netaddr" +) + +// See https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable + +// TCP_TABLE_OWNER_PID_ALL means to include the PID info. The table type +// we get back from Windows depends on AF_INET vs AF_INET6: +// MIB_TCPTABLE_OWNER_PID for v4 or MIB_TCP6TABLE_OWNER_PID for v6. +const tcpTableOwnerPidAll = 5 + +var ( + iphlpapi = syscall.NewLazyDLL("iphlpapi.dll") + getTCPTable = iphlpapi.NewProc("GetExtendedTcpTable") + // TODO: GetExtendedUdpTable also? if/when needed. +) + +type _MIB_TCPROW_OWNER_PID struct { + state uint32 + localAddr uint32 + localPort uint32 + remoteAddr uint32 + remotePort uint32 + pid uint32 +} + +type _MIB_TCP6ROW_OWNER_PID struct { + localAddr [16]byte + localScope uint32 + localPort uint32 + remoteAddr [16]byte + remoteScope uint32 + remotePort uint32 + state uint32 + pid uint32 +} + +func get() (*Table, error) { + t := new(Table) + if err := t.addEntries(windows.AF_INET); err != nil { + return nil, fmt.Errorf("failed to get IPv4 entries: %w", err) + } + if err := t.addEntries(windows.AF_INET6); err != nil { + return nil, fmt.Errorf("failed to get IPv6 entries: %w", err) + } + return t, nil +} + +func (t *Table) addEntries(fam int) error { + var size uint32 + var addr unsafe.Pointer + var buf []byte + for { + err, _, _ := getTCPTable.Call( + uintptr(addr), + uintptr(unsafe.Pointer(&size)), + 1, // sorted + uintptr(fam), + tcpTableOwnerPidAll, + 0, // reserved; "must be zero" + ) + if err == 0 { + break + } + if err == uintptr(syscall.ERROR_INSUFFICIENT_BUFFER) { + const maxSize = 10 << 20 + if size > maxSize || size < 4 { + return fmt.Errorf("unreasonable kernel-reported size %d", size) + } + buf = make([]byte, size) + addr = unsafe.Pointer(&buf[0]) + continue + } + return syscall.Errno(err) + } + if len(buf) < int(size) { + return errors.New("unexpected size growth from system call") + } + buf = buf[:size] + + numEntries := *(*uint32)(unsafe.Pointer(&buf[0])) + buf = buf[4:] + + var recSize int + switch fam { + case windows.AF_INET: + recSize = 6 * 4 + case windows.AF_INET6: + recSize = 6*4 + 16*2 + } + dataLen := numEntries * uint32(recSize) + if uint32(len(buf)) > dataLen { + buf = buf[:dataLen] + } + for len(buf) >= recSize { + switch fam { + case windows.AF_INET: + row := (*_MIB_TCPROW_OWNER_PID)(unsafe.Pointer(&buf[0])) + t.Entries = append(t.Entries, Entry{ + Local: ipport4(row.localAddr, port(&row.localPort)), + Remote: ipport4(row.remoteAddr, port(&row.remotePort)), + Pid: int(row.pid), + State: state(row.state), + }) + case windows.AF_INET6: + row := (*_MIB_TCP6ROW_OWNER_PID)(unsafe.Pointer(&buf[0])) + t.Entries = append(t.Entries, Entry{ + Local: ipport6(row.localAddr, row.localScope, port(&row.localPort)), + Remote: ipport6(row.remoteAddr, row.remoteScope, port(&row.remotePort)), + Pid: int(row.pid), + State: state(row.state), + }) + } + buf = buf[recSize:] + } + return nil +} + +var states = []string{ + "", + "CLOSED", + "LISTEN", + "SYN-SENT", + "SYN-RECEIVED", + "ESTABLISHED", + "FIN-WAIT-1", + "FIN-WAIT-2", + "CLOSE-WAIT", + "CLOSING", + "LAST-ACK", + "DELETE-TCB", +} + +func state(v uint32) string { + if v < uint32(len(states)) { + return states[v] + } + return fmt.Sprintf("unknown-state-%d", v) +} + +func ipport4(addr uint32, port uint16) netaddr.IPPort { + a4 := (*[4]byte)(unsafe.Pointer(&addr)) + return netaddr.IPPort{ + IP: netaddr.IPv4(a4[0], a4[1], a4[2], a4[3]), + Port: port, + } +} + +func ipport6(addr [16]byte, scope uint32, port uint16) netaddr.IPPort { + ip := netaddr.IPFrom16(addr) + if scope != 0 { + // TODO: something better here? + ip = ip.WithZone(fmt.Sprint(scope)) + } + return netaddr.IPPort{ + IP: ip, + Port: port, + } +} + +func port(v *uint32) uint16 { + p := (*[4]byte)(unsafe.Pointer(v)) + return binary.BigEndian.Uint16(p[:2]) +}