diff --git a/wgengine/winnet/winnet.go b/wgengine/winnet/winnet.go index 582566367..a2abc9791 100644 --- a/wgengine/winnet/winnet.go +++ b/wgengine/winnet/winnet.go @@ -9,6 +9,7 @@ package winnet import ( "fmt" + "syscall" "unsafe" "github.com/go-ole/go-ole" @@ -45,6 +46,23 @@ type INetwork struct { ole.IDispatch } +type INetworkVtbl struct { + ole.IDispatchVtbl + GetName uintptr + SetName uintptr + GetDescription uintptr + SetDescription uintptr + GetNetworkId uintptr + GetDomainType uintptr + GetNetworkConnections uintptr + GetTimeCreatedAndConnected uintptr + Get_IsConnectedToInternet uintptr + Get_IsConnected uintptr + GetConnectivity uintptr + GetCategory uintptr + SetCategory uintptr +} + func NewNetworkListManager(c *ole.Connection) (*NetworkListManager, error) { err := c.Create(CLSID_NetworkListManager) if err != nil { @@ -124,16 +142,35 @@ func (n *INetwork) GetName() (string, error) { } func (n *INetwork) GetCategory() (int32, error) { - v, err := n.CallMethod("GetCategory") - if err != nil { - return 0, err + var result int32 + + r, _, _ := syscall.SyscallN( + n.VTable().GetCategory, + uintptr(unsafe.Pointer(n)), + uintptr(unsafe.Pointer(&result)), + ) + if int32(r) < 0 { + return 0, ole.NewError(r) } - return v.Value().(int32), err + + return result, nil } -func (n *INetwork) SetCategory(v uint32) error { - _, err := n.CallMethod("SetCategory", v) - return err +func (n *INetwork) SetCategory(v int32) error { + r, _, _ := syscall.SyscallN( + n.VTable().SetCategory, + uintptr(unsafe.Pointer(n)), + uintptr(v), + ) + if int32(r) < 0 { + return ole.NewError(r) + } + + return nil +} + +func (n *INetwork) VTable() *INetworkVtbl { + return (*INetworkVtbl)(unsafe.Pointer(n.RawVTable)) } func (v *INetworkConnection) VTable() *INetworkConnectionVtbl { @@ -141,17 +178,16 @@ func (v *INetworkConnection) VTable() *INetworkConnectionVtbl { } func (v *INetworkConnection) GetNetwork() (*INetwork, error) { - nraw, err := v.CallMethod("GetNetwork") - if err != nil { - return nil, err + var result *INetwork + + r, _, _ := syscall.SyscallN( + v.VTable().GetNetwork, + uintptr(unsafe.Pointer(v)), + uintptr(unsafe.Pointer(&result)), + ) + if int32(r) < 0 { + return nil, ole.NewError(r) } - n := nraw.ToIDispatch() - if n == nil { - return nil, fmt.Errorf("GetNetwork: nil IDispatch") - } - if err != nil { - return nil, err - } - return (*INetwork)(unsafe.Pointer(n)), nil + return result, nil }