diff --git a/net/netstat/netstat.go b/net/netstat/netstat.go index 13edf0afb..cef2ba6b2 100644 --- a/net/netstat/netstat.go +++ b/net/netstat/netstat.go @@ -17,6 +17,7 @@ type Entry struct { Local, Remote netip.AddrPort Pid int State string // TODO: type? + OSMetadata OSMetadata } // Table contains local machine's TCP connection entries. diff --git a/net/netstat/netstat_noimpl.go b/net/netstat/netstat_noimpl.go index fc16661fd..82a2ef262 100644 --- a/net/netstat/netstat_noimpl.go +++ b/net/netstat/netstat_noimpl.go @@ -6,6 +6,10 @@ package netstat +// OSMetadata includes any additional OS-specific information that may be +// obtained during the retrieval of a given Entry. +type OSMetadata struct{} + func get() (*Table, error) { return nil, ErrNotImplemented } diff --git a/net/netstat/netstat_windows.go b/net/netstat/netstat_windows.go index 0f7f1fed9..09756fe1d 100644 --- a/net/netstat/netstat_windows.go +++ b/net/netstat/netstat_windows.go @@ -10,7 +10,6 @@ import ( "fmt" "math/bits" "net/netip" - "syscall" "unsafe" "github.com/josharian/native" @@ -18,37 +17,99 @@ import ( "tailscale.com/net/netaddr" ) +// OSMetadata includes any additional OS-specific information that may be +// obtained during the retrieval of a given Entry. +type OSMetadata interface { + GetModule() (string, error) +} + // See https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getextendedtcptable -// TCP_TABLE_OWNER_PID_ALL means to include the PID info. The table type +// TCP_TABLE_OWNER_MODULE_ALL means to include the PID and module. The table type // we get back from Windows depends on AF_INET vs AF_INET6: -// MIB_TCPTABLE_OWNER_PID for v4 or MIB_TCP6TABLE_OWNER_PID for v6. -const tcpTableOwnerPidAll = 5 +// MIB_TCPTABLE_OWNER_MODULE for v4 or MIB_TCP6TABLE_OWNER_MODULE for v6. +const tcpTableOwnerModuleAll = 8 + +// TCPIP_OWNER_MODULE_BASIC_INFO means to request "basic information" about the +// owner module. +const tcpipOwnerModuleBasicInfo = 0 var ( - iphlpapi = syscall.NewLazyDLL("iphlpapi.dll") - getTCPTable = iphlpapi.NewProc("GetExtendedTcpTable") + iphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + getTCPTable = iphlpapi.NewProc("GetExtendedTcpTable") + getOwnerModuleFromTcpEntry = iphlpapi.NewProc("GetOwnerModuleFromTcpEntry") + getOwnerModuleFromTcp6Entry = iphlpapi.NewProc("GetOwnerModuleFromTcp6Entry") // TODO: GetExtendedUdpTable also? if/when needed. ) -type _MIB_TCPROW_OWNER_PID struct { - state uint32 - localAddr uint32 - localPort uint32 - remoteAddr uint32 - remotePort uint32 - pid uint32 +// See https://web.archive.org/web/20221219211913/https://learn.microsoft.com/en-us/windows/win32/api/tcpmib/ns-tcpmib-mib_tcprow_owner_module +type _MIB_TCPROW_OWNER_MODULE struct { + state uint32 + localAddr uint32 + localPort uint32 + remoteAddr uint32 + remotePort uint32 + pid uint32 + createTimestamp int64 + owningModuleInfo [16]uint64 +} + +func (row *_MIB_TCPROW_OWNER_MODULE) asEntry() Entry { + return Entry{ + Local: ipport4(row.localAddr, port(&row.localPort)), + Remote: ipport4(row.remoteAddr, port(&row.remotePort)), + Pid: int(row.pid), + State: state(row.state), + OSMetadata: row, + } +} + +type _MIB_TCPTABLE_OWNER_MODULE struct { + numEntries uint32 + table _MIB_TCPROW_OWNER_MODULE +} + +func (m *_MIB_TCPTABLE_OWNER_MODULE) getRows() []_MIB_TCPROW_OWNER_MODULE { + return unsafe.Slice(&m.table, m.numEntries) +} + +// See https://web.archive.org/web/20221219212442/https://learn.microsoft.com/en-us/windows/win32/api/tcpmib/ns-tcpmib-mib_tcp6row_owner_module +type _MIB_TCP6ROW_OWNER_MODULE struct { + localAddr [16]byte + localScope uint32 + localPort uint32 + remoteAddr [16]byte + remoteScope uint32 + remotePort uint32 + state uint32 + pid uint32 + createTimestamp int64 + owningModuleInfo [16]uint64 +} + +func (row *_MIB_TCP6ROW_OWNER_MODULE) asEntry() Entry { + return Entry{ + Local: ipport6(row.localAddr, row.localScope, port(&row.localPort)), + Remote: ipport6(row.remoteAddr, row.remoteScope, port(&row.remotePort)), + Pid: int(row.pid), + State: state(row.state), + OSMetadata: row, + } +} + +type _MIB_TCP6TABLE_OWNER_MODULE struct { + numEntries uint32 + table _MIB_TCP6ROW_OWNER_MODULE } -type _MIB_TCP6ROW_OWNER_PID struct { - localAddr [16]byte - localScope uint32 - localPort uint32 - remoteAddr [16]byte - remoteScope uint32 - remotePort uint32 - state uint32 - pid uint32 +func (m *_MIB_TCP6TABLE_OWNER_MODULE) getRows() []_MIB_TCP6ROW_OWNER_MODULE { + return unsafe.Slice(&m.table, m.numEntries) +} + +// See https://web.archive.org/web/20221219213143/https://learn.microsoft.com/en-us/windows/win32/api/iprtrmib/ns-iprtrmib-tcpip_owner_module_basic_info +type _TCPIP_OWNER_MODULE_BASIC_INFO struct { + moduleName *uint16 + modulePath *uint16 } func get() (*Table, error) { @@ -72,13 +133,13 @@ func (t *Table) addEntries(fam int) error { uintptr(unsafe.Pointer(&size)), 1, // sorted uintptr(fam), - tcpTableOwnerPidAll, + tcpTableOwnerModuleAll, 0, // reserved; "must be zero" ) if err == 0 { break } - if err == uintptr(syscall.ERROR_INSUFFICIENT_BUFFER) { + if err == uintptr(windows.ERROR_INSUFFICIENT_BUFFER) { const maxSize = 10 << 20 if size > maxSize || size < 4 { return fmt.Errorf("unreasonable kernel-reported size %d", size) @@ -87,48 +148,28 @@ func (t *Table) addEntries(fam int) error { addr = unsafe.Pointer(&buf[0]) continue } - return syscall.Errno(err) + return windows.Errno(err) } if len(buf) < int(size) { return errors.New("unexpected size growth from system call") } buf = buf[:size] - numEntries := native.Endian.Uint32(buf[:4]) - buf = buf[4:] - - var recSize int switch fam { case windows.AF_INET: - recSize = 6 * 4 + info := (*_MIB_TCPTABLE_OWNER_MODULE)(unsafe.Pointer(&buf[0])) + rows := info.getRows() + for _, row := range rows { + t.Entries = append(t.Entries, row.asEntry()) + } case windows.AF_INET6: - recSize = 6*4 + 16*2 - } - dataLen := numEntries * uint32(recSize) - if uint32(len(buf)) > dataLen { - buf = buf[:dataLen] - } - for len(buf) >= recSize { - switch fam { - case windows.AF_INET: - row := (*_MIB_TCPROW_OWNER_PID)(unsafe.Pointer(&buf[0])) - t.Entries = append(t.Entries, Entry{ - Local: ipport4(row.localAddr, port(&row.localPort)), - Remote: ipport4(row.remoteAddr, port(&row.remotePort)), - Pid: int(row.pid), - State: state(row.state), - }) - case windows.AF_INET6: - row := (*_MIB_TCP6ROW_OWNER_PID)(unsafe.Pointer(&buf[0])) - t.Entries = append(t.Entries, Entry{ - Local: ipport6(row.localAddr, row.localScope, port(&row.localPort)), - Remote: ipport6(row.remoteAddr, row.remoteScope, port(&row.remotePort)), - Pid: int(row.pid), - State: state(row.state), - }) + info := (*_MIB_TCP6TABLE_OWNER_MODULE)(unsafe.Pointer(&buf[0])) + rows := info.getRows() + for _, row := range rows { + t.Entries = append(t.Entries, row.asEntry()) } - buf = buf[recSize:] } + return nil } @@ -178,3 +219,43 @@ func port(v *uint32) uint16 { } return uint16(*v >> 16) } + +type moduleInfoConstraint interface { + _MIB_TCPROW_OWNER_MODULE | _MIB_TCP6ROW_OWNER_MODULE +} + +func moduleInfo[entryType moduleInfoConstraint](entry *entryType, proc *windows.LazyProc) (string, error) { + var buf []byte + var desiredLen uint32 + var addr unsafe.Pointer + + for { + e, _, _ := proc.Call( + uintptr(unsafe.Pointer(entry)), + uintptr(tcpipOwnerModuleBasicInfo), + uintptr(addr), + uintptr(unsafe.Pointer(&desiredLen)), + ) + err := windows.Errno(e) + if err == windows.ERROR_SUCCESS { + break + } + if err != windows.ERROR_INSUFFICIENT_BUFFER { + return "", err + } + + buf = make([]byte, desiredLen) + addr = unsafe.Pointer(&buf[0]) + } + + basicInfo := (*_TCPIP_OWNER_MODULE_BASIC_INFO)(addr) + return windows.UTF16PtrToString(basicInfo.moduleName), nil +} + +func (m *_MIB_TCPROW_OWNER_MODULE) GetModule() (string, error) { + return moduleInfo(m, getOwnerModuleFromTcpEntry) +} + +func (m *_MIB_TCP6ROW_OWNER_MODULE) GetModule() (string, error) { + return moduleInfo(m, getOwnerModuleFromTcp6Entry) +} diff --git a/portlist/portlist_windows.go b/portlist/portlist_windows.go index 811e4f0aa..c79461ef2 100644 --- a/portlist/portlist_windows.go +++ b/portlist/portlist_windows.go @@ -5,12 +5,8 @@ package portlist import ( - "path/filepath" - "strings" - "syscall" "time" - "golang.org/x/sys/windows" "tailscale.com/net/netstat" ) @@ -26,7 +22,7 @@ func init() { type famPort struct { proto string port uint16 - pid uintptr + pid uint32 } type windowsImpl struct { @@ -69,19 +65,25 @@ func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { fp := famPort{ proto: "tcp", // TODO(bradfitz): UDP too; add to netstat port: e.Local.Port(), - pid: uintptr(e.Pid), + pid: uint32(e.Pid), } pm, ok := im.known[fp] if ok { pm.keep = true continue } + var process string + if e.OSMetadata != nil { + if module, err := e.OSMetadata.GetModule(); err == nil { + process = module + } + } pm = &portMeta{ keep: true, port: Port{ Proto: "tcp", Port: e.Local.Port(), - Process: procNameOfPid(e.Pid), + Process: process, }, } im.known[fp] = pm @@ -94,27 +96,6 @@ func (im *windowsImpl) AppendListeningPorts(base []Port) ([]Port, error) { } ret = append(ret, m.port) } - return sortAndDedup(ret), nil -} -func procNameOfPid(pid int) string { - const da = windows.PROCESS_QUERY_LIMITED_INFORMATION - h, err := syscall.OpenProcess(da, false, uint32(pid)) - if err != nil { - return "" - } - defer syscall.CloseHandle(h) - - var buf [512]uint16 - var size = uint32(len(buf)) - if err := windows.QueryFullProcessImageName(windows.Handle(h), 0, &buf[0], &size); err != nil { - return "" - } - name := filepath.Base(windows.UTF16ToString(buf[:])) - if name == "." { - return "" - } - name = strings.TrimSuffix(name, ".exe") - name = strings.TrimSuffix(name, ".EXE") - return name + return sortAndDedup(ret), nil }