diff --git a/util/osdiag/mksyscall.go b/util/osdiag/mksyscall.go index 72e4475cb..03d531d0d 100644 --- a/util/osdiag/mksyscall.go +++ b/util/osdiag/mksyscall.go @@ -7,3 +7,6 @@ package osdiag //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go //sys regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) [failretval!=0] = advapi32.RegEnumValueW +//sys wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) = ws2_32.WSCEnumProtocols +//sys wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) = ws2_32.WSCGetProviderInfo +//sys wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) = ws2_32.WSCGetProviderPath diff --git a/util/osdiag/osdiag_windows.go b/util/osdiag/osdiag_windows.go index 2b23df073..12ea366b1 100644 --- a/util/osdiag/osdiag_windows.go +++ b/util/osdiag/osdiag_windows.go @@ -22,6 +22,10 @@ import ( "tailscale.com/util/winutil/authenticode" ) +var ( + errUnexpectedResult = errors.New("API call returned an unexpected value") +) + const ( maxBinaryValueLen = 128 // we'll truncate any binary values longer than this maxRegValueNameLen = 16384 // maximum length supported by Windows + 1 @@ -38,8 +42,9 @@ func logSupportInfo(logf logger.Logf, reason LogSupportInfoReason) { } const ( - supportInfoKeyModules = "modules" - supportInfoKeyRegistry = "registry" + supportInfoKeyModules = "modules" + supportInfoKeyRegistry = "registry" + supportInfoKeyWinsockLSP = "winsockLSP" ) func getSupportInfo(w io.Writer, reason LogSupportInfoReason) error { @@ -59,6 +64,13 @@ func getSupportInfo(w io.Writer, reason LogSupportInfoReason) error { } else { output[supportInfoKeyModules] = err } + + lspInfo, err := getWinsockLSPInfo() + if err == nil { + output[supportInfoKeyWinsockLSP] = lspInfo + } else { + output[supportInfoKeyWinsockLSP] = err + } } enc := json.NewEncoder(w) @@ -228,8 +240,8 @@ func (mi *moduleInfo) setVersionInfo() { var errAssertingType = errors.New("asserting DataDirectory type") -func (mi *moduleInfo) setDebugInfo(base uintptr, size uint32) { - pem, err := pe.NewPEFromBaseAddressAndSize(base, size) +func (mi *moduleInfo) setDebugInfo() { + pem, err := pe.NewPEFromBaseAddressAndSize(mi.BaseAddress, mi.Size) if err != nil { mi.DebugInfoErr = err return @@ -320,7 +332,7 @@ func getModuleInfo() (map[string]moduleInfo, error) { } entry.setVersionInfo() - entry.setDebugInfo(base, size) + entry.setDebugInfo() entry.setAuthenticodeInfo() result[name] = entry @@ -328,3 +340,145 @@ func getModuleInfo() (map[string]moduleInfo, error) { return result, nil } + +type _WSC_PROVIDER_INFO_TYPE int32 + +const ( + providerInfoLspCategories _WSC_PROVIDER_INFO_TYPE = 0 +) + +const ( + _SOCKET_ERROR = -1 +) + +// Note that wsaProtocolInfo needs to be identical to windows.WSAProtocolInfo; +// the purpose of this type is to have the ability to use it as a reciever in +// the path and categoryFlags funcs defined below. +type wsaProtocolInfo windows.WSAProtocolInfo + +func (pi *wsaProtocolInfo) path() (string, error) { + var errno int32 + var buf [windows.MAX_PATH]uint16 + bufCount := int32(len(buf)) + ret := wscGetProviderPath(&pi.ProviderId, &buf[0], &bufCount, &errno) + if ret == _SOCKET_ERROR { + return "", windows.Errno(errno) + } + if ret != 0 { + return "", errUnexpectedResult + } + + return windows.UTF16ToString(buf[:bufCount]), nil +} + +func (pi *wsaProtocolInfo) categoryFlags() (uint32, error) { + var errno int32 + var result uint32 + bufLen := uintptr(unsafe.Sizeof(result)) + ret := wscGetProviderInfo(&pi.ProviderId, providerInfoLspCategories, unsafe.Pointer(&result), &bufLen, 0, &errno) + if ret == _SOCKET_ERROR { + return 0, windows.Errno(errno) + } + if ret != 0 { + return 0, errUnexpectedResult + } + + return result, nil +} + +type wsaProtocolInfoOutput struct { + Description string `json:"description,omitempty"` + Version int32 `json:"version"` + AddressFamily int32 `json:"addressFamily"` + SocketType int32 `json:"socketType"` + Protocol int32 `json:"protocol"` + ServiceFlags1 string `json:"serviceFlags1"` + ProviderFlags string `json:"providerFlags"` + Path string `json:"path,omitempty"` + PathErr error `json:"pathErr,omitempty"` + Category string `json:"category,omitempty"` + CategoryErr error `json:"categoryErr,omitempty"` + BaseProviderID string `json:"baseProviderID,omitempty"` + LayerProviderID string `json:"layerProviderID,omitempty"` + Chain []uint32 `json:"chain,omitempty"` +} + +func getWinsockLSPInfo() (map[uint32]wsaProtocolInfoOutput, error) { + protocols, err := enumWinsockProtocols() + if err != nil { + return nil, err + } + + result := make(map[uint32]wsaProtocolInfoOutput, len(protocols)) + for _, p := range protocols { + v := wsaProtocolInfoOutput{ + Description: windows.UTF16ToString(p.ProtocolName[:]), + Version: p.Version, + AddressFamily: p.AddressFamily, + SocketType: p.SocketType, + Protocol: p.Protocol, + ServiceFlags1: fmt.Sprintf("0x%08X", p.ServiceFlags1), // Serializing as hex string to make the flags easier to decode by human inspection + ProviderFlags: fmt.Sprintf("0x%08X", p.ProviderFlags), + } + + switch p.ProtocolChain.ChainLen { + case windows.BASE_PROTOCOL: + v.BaseProviderID = p.ProviderId.String() + case windows.LAYERED_PROTOCOL: + v.LayerProviderID = p.ProviderId.String() + default: + v.Chain = p.ProtocolChain.ChainEntries[:p.ProtocolChain.ChainLen] + } + + // Queries that are only valid for base and layered protocols (not chains) + if v.Chain == nil { + path, err := p.path() + if err == nil { + v.Path = strings.ToLower(path) + } else { + v.PathErr = err + } + + category, err := p.categoryFlags() + if err == nil { + v.Category = fmt.Sprintf("0x%08X", category) + } else if !errors.Is(err, windows.WSAEINVALIDPROVIDER) { + // WSAEINVALIDPROVIDER == "no category info found", so we only log + // errors other than that one. + v.CategoryErr = err + } + } + + // Chains reference other providers using catalog entry IDs, so we use that + // value as the key in our map. + result[p.CatalogEntryId] = v + } + + return result, nil +} + +func enumWinsockProtocols() ([]wsaProtocolInfo, error) { + // Get the required size + var errno int32 + var bytesReqd uint32 + ret := wscEnumProtocols(nil, nil, &bytesReqd, &errno) + if ret != _SOCKET_ERROR { + return nil, errUnexpectedResult + } + if e := windows.Errno(errno); e != windows.WSAENOBUFS { + return nil, e + } + + // Allocate + szEntry := uint32(unsafe.Sizeof(wsaProtocolInfo{})) + buf := make([]wsaProtocolInfo, bytesReqd/szEntry) + + // Now do the query for real + bufLen := uint32(len(buf)) * szEntry + ret = wscEnumProtocols(nil, &buf[0], &bufLen, &errno) + if ret == _SOCKET_ERROR { + return nil, windows.Errno(errno) + } + + return buf, nil +} diff --git a/util/osdiag/zsyscall_windows.go b/util/osdiag/zsyscall_windows.go index caeb245d2..f9d482931 100644 --- a/util/osdiag/zsyscall_windows.go +++ b/util/osdiag/zsyscall_windows.go @@ -40,8 +40,12 @@ func errnoErr(e syscall.Errno) error { var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") - procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW") + procRegEnumValueW = modadvapi32.NewProc("RegEnumValueW") + procWSCEnumProtocols = modws2_32.NewProc("WSCEnumProtocols") + procWSCGetProviderInfo = modws2_32.NewProc("WSCGetProviderInfo") + procWSCGetProviderPath = modws2_32.NewProc("WSCGetProviderPath") ) func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLen *uint32, reserved *uint32, valueType *uint32, pData *byte, cbData *uint32) (ret error) { @@ -51,3 +55,21 @@ func regEnumValue(key registry.Key, index uint32, valueName *uint16, valueNameLe } return } + +func wscEnumProtocols(iProtocols *int32, protocolBuffer *wsaProtocolInfo, bufLen *uint32, errno *int32) (ret int32) { + r0, _, _ := syscall.Syscall6(procWSCEnumProtocols.Addr(), 4, uintptr(unsafe.Pointer(iProtocols)), uintptr(unsafe.Pointer(protocolBuffer)), uintptr(unsafe.Pointer(bufLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + ret = int32(r0) + return +} + +func wscGetProviderInfo(providerId *windows.GUID, infoType _WSC_PROVIDER_INFO_TYPE, info unsafe.Pointer, infoSize *uintptr, flags uint32, errno *int32) (ret int32) { + r0, _, _ := syscall.Syscall6(procWSCGetProviderInfo.Addr(), 6, uintptr(unsafe.Pointer(providerId)), uintptr(infoType), uintptr(info), uintptr(unsafe.Pointer(infoSize)), uintptr(flags), uintptr(unsafe.Pointer(errno))) + ret = int32(r0) + return +} + +func wscGetProviderPath(providerId *windows.GUID, providerDllPath *uint16, providerDllPathLen *int32, errno *int32) (ret int32) { + r0, _, _ := syscall.Syscall6(procWSCGetProviderPath.Addr(), 4, uintptr(unsafe.Pointer(providerId)), uintptr(unsafe.Pointer(providerDllPath)), uintptr(unsafe.Pointer(providerDllPathLen)), uintptr(unsafe.Pointer(errno)), 0, 0) + ret = int32(r0) + return +}