diff --git a/util/winutil/mksyscall.go b/util/winutil/mksyscall.go index 5fb915b41..afee73998 100644 --- a/util/winutil/mksyscall.go +++ b/util/winutil/mksyscall.go @@ -6,9 +6,11 @@ package winutil //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go //go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go +//sys dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) = netapi32.DsGetDcNameW //sys expandEnvironmentStringsForUser(token windows.Token, src *uint16, dst *uint16, dstLen uint32) (err error) [int32(failretval)==0] = userenv.ExpandEnvironmentStringsForUserW //sys getApplicationRestartSettings(process windows.Handle, commandLine *uint16, commandLineLen *uint32, flags *uint32) (ret wingoes.HRESULT) = kernel32.GetApplicationRestartSettings //sys loadUserProfile(token windows.Token, profileInfo *_PROFILEINFO) (err error) [int32(failretval)==0] = userenv.LoadUserProfileW +//sys netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) = netapi32.NetValidateName //sys queryServiceConfig2(hService windows.Handle, infoLevel uint32, buf *byte, bufLen uint32, bytesNeeded *uint32) (err error) [failretval==0] = advapi32.QueryServiceConfig2W //sys registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret wingoes.HRESULT) = kernel32.RegisterApplicationRestart //sys rmEndSession(session _RMHANDLE) (ret error) = rstrtmgr.RmEndSession diff --git a/util/winutil/userprofile_windows.go b/util/winutil/userprofile_windows.go index 6bedf420b..d2e6067c7 100644 --- a/util/winutil/userprofile_windows.go +++ b/util/winutil/userprofile_windows.go @@ -135,9 +135,36 @@ func (up *UserProfile) Close() error { } func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, userName *uint16) (path *uint16, err error) { - // logf is for debugging/testing. - if logf == nil { - logf = logger.Discard + // logf is for debugging/testing. While we would normally replace a nil logf + // with logger.Discard, we're using explicit checks within this func so that + // we don't waste time allocating and converting UTF-16 strings unnecessarily. + var comp string + if logf != nil { + comp = windows.UTF16PtrToString(computerName) + user := windows.UTF16PtrToString(userName) + logf("BEGIN getRoamingProfilePath(%q, %q)", comp, user) + defer logf("END getRoamingProfilePath(%q, %q)", comp, user) + } + + isDomainName, err := isDomainName(computerName) + if err != nil { + return nil, err + } + if isDomainName { + if logf != nil { + logf("computerName %q is a domain, resolving...", comp) + } + dcInfo, err := resolveDomainController(computerName, nil) + if err != nil { + return nil, err + } + defer dcInfo.Close() + + computerName = dcInfo.DomainControllerName + if logf != nil { + dom := windows.UTF16PtrToString(computerName) + logf("%q resolved to %q", comp, dom) + } } var pbuf *byte @@ -147,7 +174,9 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, defer windows.NetApiBufferFree(pbuf) ui4 := (*_USER_INFO_4)(unsafe.Pointer(pbuf)) - logf("getRoamingProfilePath: got %#v", *ui4) + if logf != nil { + logf("getRoamingProfilePath: got %#v", *ui4) + } profilePath := ui4.Profile if profilePath == nil { return nil, nil @@ -162,6 +191,10 @@ func getRoamingProfilePath(logf logger.Logf, token windows.Token, computerName, return nil, err } + if logf != nil { + logf("returning %q", windows.UTF16ToString(expanded[:])) + } + // This buffer is only used briefly, so we don't bother copying it into a shorter slice. return &expanded[0], nil } diff --git a/util/winutil/userprofile_windows_test.go b/util/winutil/userprofile_windows_test.go new file mode 100644 index 000000000..09dcfd596 --- /dev/null +++ b/util/winutil/userprofile_windows_test.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package winutil + +import ( + "testing" + + "golang.org/x/sys/windows" +) + +func TestGetRoamingProfilePath(t *testing.T) { + token := windows.GetCurrentProcessToken() + computerName, userName, err := getComputerAndUserName(token, nil) + if err != nil { + t.Fatal(err) + } + + if _, err := getRoamingProfilePath(t.Logf, token, computerName, userName); err != nil { + t.Error(err) + } + + // TODO(aaron): Flesh out better once can run tests under domain accounts. +} diff --git a/util/winutil/winutil_windows.go b/util/winutil/winutil_windows.go index f464d01d4..46fac4633 100644 --- a/util/winutil/winutil_windows.go +++ b/util/winutil/winutil_windows.go @@ -784,3 +784,147 @@ func SetNTString[NTS NTStr, BU BufUnit](nts *NTS, buf []BU) { panic("unknown type") } } + +type domainControllerAddressType uint32 + +const ( + //lint:ignore U1000 maps to a win32 API + _DS_INET_ADDRESS domainControllerAddressType = 1 + _DS_NETBIOS_ADDRESS domainControllerAddressType = 2 +) + +type domainControllerFlag uint32 + +const ( + //lint:ignore U1000 maps to a win32 API + _DS_PDC_FLAG domainControllerFlag = 0x00000001 + _DS_GC_FLAG domainControllerFlag = 0x00000004 + _DS_LDAP_FLAG domainControllerFlag = 0x00000008 + _DS_DS_FLAG domainControllerFlag = 0x00000010 + _DS_KDC_FLAG domainControllerFlag = 0x00000020 + _DS_TIMESERV_FLAG domainControllerFlag = 0x00000040 + _DS_CLOSEST_FLAG domainControllerFlag = 0x00000080 + _DS_WRITABLE_FLAG domainControllerFlag = 0x00000100 + _DS_GOOD_TIMESERV_FLAG domainControllerFlag = 0x00000200 + _DS_NDNC_FLAG domainControllerFlag = 0x00000400 + _DS_SELECT_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00000800 + _DS_FULL_SECRET_DOMAIN_6_FLAG domainControllerFlag = 0x00001000 + _DS_WS_FLAG domainControllerFlag = 0x00002000 + _DS_DS_8_FLAG domainControllerFlag = 0x00004000 + _DS_DS_9_FLAG domainControllerFlag = 0x00008000 + _DS_DS_10_FLAG domainControllerFlag = 0x00010000 + _DS_KEY_LIST_FLAG domainControllerFlag = 0x00020000 + _DS_PING_FLAGS domainControllerFlag = 0x000FFFFF + _DS_DNS_CONTROLLER_FLAG domainControllerFlag = 0x20000000 + _DS_DNS_DOMAIN_FLAG domainControllerFlag = 0x40000000 + _DS_DNS_FOREST_FLAG domainControllerFlag = 0x80000000 +) + +type _DOMAIN_CONTROLLER_INFO struct { + DomainControllerName *uint16 + DomainControllerAddress *uint16 + DomainControllerAddressType domainControllerAddressType + DomainGuid windows.GUID + DomainName *uint16 + DnsForestName *uint16 + Flags domainControllerFlag + DcSiteName *uint16 + ClientSiteName *uint16 +} + +func (dci *_DOMAIN_CONTROLLER_INFO) Close() error { + if dci == nil { + return nil + } + return windows.NetApiBufferFree((*byte)(unsafe.Pointer(dci))) +} + +type dsGetDcNameFlag uint32 + +const ( + //lint:ignore U1000 maps to a win32 API + _DS_FORCE_REDISCOVERY dsGetDcNameFlag = 0x00000001 + _DS_DIRECTORY_SERVICE_REQUIRED dsGetDcNameFlag = 0x00000010 + _DS_DIRECTORY_SERVICE_PREFERRED dsGetDcNameFlag = 0x00000020 + _DS_GC_SERVER_REQUIRED dsGetDcNameFlag = 0x00000040 + _DS_PDC_REQUIRED dsGetDcNameFlag = 0x00000080 + _DS_BACKGROUND_ONLY dsGetDcNameFlag = 0x00000100 + _DS_IP_REQUIRED dsGetDcNameFlag = 0x00000200 + _DS_KDC_REQUIRED dsGetDcNameFlag = 0x00000400 + _DS_TIMESERV_REQUIRED dsGetDcNameFlag = 0x00000800 + _DS_WRITABLE_REQUIRED dsGetDcNameFlag = 0x00001000 + _DS_GOOD_TIMESERV_PREFERRED dsGetDcNameFlag = 0x00002000 + _DS_AVOID_SELF dsGetDcNameFlag = 0x00004000 + _DS_ONLY_LDAP_NEEDED dsGetDcNameFlag = 0x00008000 + _DS_IS_FLAT_NAME dsGetDcNameFlag = 0x00010000 + _DS_IS_DNS_NAME dsGetDcNameFlag = 0x00020000 + _DS_TRY_NEXTCLOSEST_SITE dsGetDcNameFlag = 0x00040000 + _DS_DIRECTORY_SERVICE_6_REQUIRED dsGetDcNameFlag = 0x00080000 + _DS_WEB_SERVICE_REQUIRED dsGetDcNameFlag = 0x00100000 + _DS_DIRECTORY_SERVICE_8_REQUIRED dsGetDcNameFlag = 0x00200000 + _DS_DIRECTORY_SERVICE_9_REQUIRED dsGetDcNameFlag = 0x00400000 + _DS_DIRECTORY_SERVICE_10_REQUIRED dsGetDcNameFlag = 0x00800000 + _DS_KEY_LIST_SUPPORT_REQUIRED dsGetDcNameFlag = 0x01000000 + _DS_RETURN_DNS_NAME dsGetDcNameFlag = 0x40000000 + _DS_RETURN_FLAT_NAME dsGetDcNameFlag = 0x80000000 +) + +func resolveDomainController(domainName *uint16, domainGUID *windows.GUID) (*_DOMAIN_CONTROLLER_INFO, error) { + const flags = _DS_DIRECTORY_SERVICE_REQUIRED | _DS_IS_FLAT_NAME | _DS_RETURN_DNS_NAME + var dcInfo *_DOMAIN_CONTROLLER_INFO + if err := dsGetDcName(nil, domainName, domainGUID, nil, flags, &dcInfo); err != nil { + return nil, err + } + return dcInfo, nil +} + +// ResolveDomainController resolves the DNS name of the nearest available +// domain controller for the domain specified by domainName. +func ResolveDomainController(domainName string) (string, error) { + domainName16, err := windows.UTF16PtrFromString(domainName) + if err != nil { + return "", err + } + + dcInfo, err := resolveDomainController(domainName16, nil) + if err != nil { + return "", err + } + defer dcInfo.Close() + + return windows.UTF16PtrToString(dcInfo.DomainControllerName), nil +} + +type _NETSETUP_NAME_TYPE int32 + +const ( + _NetSetupUnknown _NETSETUP_NAME_TYPE = 0 + _NetSetupMachine _NETSETUP_NAME_TYPE = 1 + _NetSetupWorkgroup _NETSETUP_NAME_TYPE = 2 + _NetSetupDomain _NETSETUP_NAME_TYPE = 3 + _NetSetupNonExistentDomain _NETSETUP_NAME_TYPE = 4 + _NetSetupDnsMachine _NETSETUP_NAME_TYPE = 5 +) + +func isDomainName(name *uint16) (bool, error) { + err := netValidateName(nil, name, nil, nil, _NetSetupDomain) + switch err { + case nil: + return true, nil + case windows.ERROR_NO_SUCH_DOMAIN: + return false, nil + default: + return false, err + } +} + +// IsDomainName checks whether name represents an existing domain reachable by +// the current machine. +func IsDomainName(name string) (bool, error) { + name16, err := windows.UTF16PtrFromString(name) + if err != nil { + return false, err + } + + return isDomainName(name16) +} diff --git a/util/winutil/zsyscall_windows.go b/util/winutil/zsyscall_windows.go index d5d2d8721..b4674dff3 100644 --- a/util/winutil/zsyscall_windows.go +++ b/util/winutil/zsyscall_windows.go @@ -42,12 +42,15 @@ func errnoErr(e syscall.Errno) error { var ( modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modnetapi32 = windows.NewLazySystemDLL("netapi32.dll") modrstrtmgr = windows.NewLazySystemDLL("rstrtmgr.dll") moduserenv = windows.NewLazySystemDLL("userenv.dll") procQueryServiceConfig2W = modadvapi32.NewProc("QueryServiceConfig2W") procGetApplicationRestartSettings = modkernel32.NewProc("GetApplicationRestartSettings") procRegisterApplicationRestart = modkernel32.NewProc("RegisterApplicationRestart") + procDsGetDcNameW = modnetapi32.NewProc("DsGetDcNameW") + procNetValidateName = modnetapi32.NewProc("NetValidateName") procRmEndSession = modrstrtmgr.NewProc("RmEndSession") procRmGetList = modrstrtmgr.NewProc("RmGetList") procRmJoinSession = modrstrtmgr.NewProc("RmJoinSession") @@ -78,6 +81,22 @@ func registerApplicationRestart(cmdLineExclExeName *uint16, flags uint32) (ret w return } +func dsGetDcName(computerName *uint16, domainName *uint16, domainGuid *windows.GUID, siteName *uint16, flags dsGetDcNameFlag, dcInfo **_DOMAIN_CONTROLLER_INFO) (ret error) { + r0, _, _ := syscall.Syscall6(procDsGetDcNameW.Addr(), 6, uintptr(unsafe.Pointer(computerName)), uintptr(unsafe.Pointer(domainName)), uintptr(unsafe.Pointer(domainGuid)), uintptr(unsafe.Pointer(siteName)), uintptr(flags), uintptr(unsafe.Pointer(dcInfo))) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + +func netValidateName(server *uint16, name *uint16, account *uint16, password *uint16, nameType _NETSETUP_NAME_TYPE) (ret error) { + r0, _, _ := syscall.Syscall6(procNetValidateName.Addr(), 5, uintptr(unsafe.Pointer(server)), uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(account)), uintptr(unsafe.Pointer(password)), uintptr(nameType), 0) + if r0 != 0 { + ret = syscall.Errno(r0) + } + return +} + func rmEndSession(session _RMHANDLE) (ret error) { r0, _, _ := syscall.Syscall(procRmEndSession.Addr(), 1, uintptr(session), 0, 0) if r0 != 0 {