From a44687e71f54dc0bc5c60148350dc75de3600b71 Mon Sep 17 00:00:00 2001 From: Aaron Klotz Date: Wed, 26 Oct 2022 14:41:29 -0600 Subject: [PATCH] wgengine/winnet: invoke some COM methods directly instead of through IDispatch. Intermittently in the wild we are seeing failures when calling `INetworkConnection::GetNetwork`. It is unclear what the root cause is, but what is clear is that the error is happening inside the object's `IDispatch` invoker (as opposed to the method implementation itself). This patch replaces our wrapper for `INetworkConnection::GetNetwork` with an alternate implementation that directly invokes the method, instead of using `IDispatch`. I also replaced the implementations of `INetwork::SetCategory` and `INetwork::GetCategory` while I was there. This patch is speculative and tightly-scoped so that we could possibly add it to a dot-release if necessary. Updates https://github.com/tailscale/tailscale/issues/4134 Updates https://github.com/tailscale/tailscale/issues/6037 Signed-off-by: Aaron Klotz --- wgengine/winnet/winnet.go | 72 +++++++++++++++++++++++++++++---------- 1 file changed, 54 insertions(+), 18 deletions(-) diff --git a/wgengine/winnet/winnet.go b/wgengine/winnet/winnet.go index 582566367..a2abc9791 100644 --- a/wgengine/winnet/winnet.go +++ b/wgengine/winnet/winnet.go @@ -9,6 +9,7 @@ package winnet import ( "fmt" + "syscall" "unsafe" "github.com/go-ole/go-ole" @@ -45,6 +46,23 @@ type INetwork struct { ole.IDispatch } +type INetworkVtbl struct { + ole.IDispatchVtbl + GetName uintptr + SetName uintptr + GetDescription uintptr + SetDescription uintptr + GetNetworkId uintptr + GetDomainType uintptr + GetNetworkConnections uintptr + GetTimeCreatedAndConnected uintptr + Get_IsConnectedToInternet uintptr + Get_IsConnected uintptr + GetConnectivity uintptr + GetCategory uintptr + SetCategory uintptr +} + func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) { err := c.Create(CLSID_NetworkListManager) if err != nil { @@ -124,16 +142,35 @@ func (n *INetwork) GetName() (string, error) { } func (n *INetwork) GetCategory() (int32, error) { - v, err := n.CallMethod("GetCategory") - if err != nil { - return 0, err + var result int32 + + r, _, _ := syscall.SyscallN( + n.VTable().GetCategory, + uintptr(unsafe.Pointer(n)), + uintptr(unsafe.Pointer(&result)), + ) + if int32(r) < 0 { + return 0, ole.NewError(r) } - return v.Value().(int32), err + + return result, nil } -func (n *INetwork) SetCategory(v uint32) error { - _, err := n.CallMethod("SetCategory", v) - return err +func (n *INetwork) SetCategory(v int32) error { + r, _, _ := syscall.SyscallN( + n.VTable().SetCategory, + uintptr(unsafe.Pointer(n)), + uintptr(v), + ) + if int32(r) < 0 { + return ole.NewError(r) + } + + return nil +} + +func (n *INetwork) VTable() *INetworkVtbl { + return (*INetworkVtbl)(unsafe.Pointer(n.RawVTable)) } func (v *INetworkConnection) VTable() *INetworkConnectionVtbl { @@ -141,17 +178,16 @@ func (v *INetworkConnection) VTable() *INetworkConnectionVtbl { } func (v *INetworkConnection) GetNetwork() (*INetwork, error) { - nraw, err := v.CallMethod("GetNetwork") - if err != nil { - return nil, err + var result *INetwork + + r, _, _ := syscall.SyscallN( + v.VTable().GetNetwork, + uintptr(unsafe.Pointer(v)), + uintptr(unsafe.Pointer(&result)), + ) + if int32(r) < 0 { + return nil, ole.NewError(r) } - n := nraw.ToIDispatch() - if n == nil { - return nil, fmt.Errorf("GetNetwork: nil IDispatch") - } - if err != nil { - return nil, err - } - return (*INetwork)(unsafe.Pointer(n)), nil + return result, nil }